In [2]:
import time
import torch
import torch.utils.data
from draugr.numpy_utilities import SplitEnum
from draugr.torch_utilities import (
    TensorBoardPytorchWriter,
    TorchEvalSession,
    global_torch_device,
)
from draugr.writers import Writer
from math import inf
from pathlib import Path
from matplotlib import pyplot

from neodroidvision.utilities import scatter_plot_encoding_space
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from draugr.tqdm_utilities import progress_bar
from neodroidvision import PROJECT_APP_PATH
from neodroidvision.data.classification.vgg_face2 import VggFace2
from neodroidvision.regression.vae.architectures.beta_vae import HigginsVae
from neodroidvision.regression.vae.architectures.vae import VAE

ImportError: cannot import name 'hwc_to_chw' from 'draugr' (/home/heider/Projects/draugr/draugr/__init__.py)

In [None]:
SEED = 8237245329
torch.manual_seed(SEED)  # seed torch's global RNG so model init / sampling is reproducible

# NOTE(review): presumably the lowest loss seen so far; nothing in the visible cells
# reads or updates it.
LOWEST_L = inf

# DataLoader worker processes; 0 loads batches in the main process.
# Alternative tried before: min(8, multiprocessing.cpu_count() - 1)
core_count = 0

GLOBAL_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): GLOBAL_DEVICE is not used below; MODEL and stest_model move tensors
# with draugr's global_torch_device() instead.
DL_KWARGS = (
    {"num_workers": core_count, "pin_memory": True} if torch.cuda.is_available() else {}
)

BASE_PATH = PROJECT_APP_PATH.user_data / "bvae"
# exist_ok=True avoids the exists()/mkdir() race and keeps the cell idempotent on re-run.
BASE_PATH.mkdir(parents=True, exist_ok=True)

INPUT_SIZE = 64  # image height/width used for resizing and for reshaping decoded samples
CHANNELS = 3

BATCH_SIZE = 1024
EPOCHS = 1000
LR = 3e-3  # NOTE(review): unused in the visible cells (no optimiser is constructed)
ENCODING_SIZE = 2  # latent size; a value of 2 enables the manifold plot in main()

DATASET = VggFace2(
    Path.home() / "Data" / "Datasets" / "vggface2",
    split=SplitEnum.training,
    # A plain int: `(INPUT_SIZE)` without a trailing comma was never a tuple.
    resize_s=INPUT_SIZE,
)
MODEL: VAE = HigginsVae(CHANNELS, latent_size=ENCODING_SIZE).to(global_torch_device())
# NOTE(review): presumably the beta (KL weight) of the beta-VAE objective; unused in
# the visible cells.
BETA = 4

Processing: 0 images for train split
Processing: 1000 images for train split
Processing: 2000 images for train split
Processing: 3000 images for train split
Processing: 4000 images for train split
Processing: 5000 images for train split
Processing: 6000 images for train split
Processing: 7000 images for train split
Processing: 8000 images for train split
Processing: 9000 images for train split
Processing: 10000 images for train split
Processing: 11000 images for train split
Processing: 12000 images for train split
Processing: 13000 images for train split
Processing: 14000 images for train split
Processing: 15000 images for train split
Processing: 16000 images for train split
Processing: 17000 images for train split
Processing: 18000 images for train split
Processing: 19000 images for train split
Processing: 20000 images for train split
Processing: 21000 images for train split
Processing: 22000 images for train split
Processing: 23000 images for train split
Processing: 24000 images for 

In [None]:
def stest_model(
    model: VAE,
    epoch_i: int,
    metric_writer: Writer,
    loader: DataLoader,
    save_images: bool = True,
) -> None:
    """Visualise ``model`` on the first batch of ``loader``.

    Runs the model in eval mode without gradient tracking on a single batch and
    writes into ``BASE_PATH``:

    * ``reconstruction_<epoch_i>.png``: up to 8 inputs in the first row and
      their reconstructions in the second row. Written only when
      ``save_images`` is true.
    * ``encoding_space_<epoch_i>.png``: the plot produced by
      ``scatter_plot_encoding_space`` from the batch's ``mean``/``log_var``
      encodings and its labels.

    Only the first batch is consumed. An empty loader writes nothing.

    Args:
      model: VAE whose forward pass returns ``(reconstruction, mean, log_var)``.
      epoch_i: Epoch index, used only to name the output files.
      metric_writer: Currently unused; kept for interface compatibility.
      loader: Iterable yielding batches of ``(images, labels, *extra)``.
      save_images: Whether to write the reconstruction comparison grid.
    """
    with TorchEvalSession(model), torch.no_grad():
        # Only one batch is visualised, so fetch it directly instead of looping and breaking.
        batch = next(iter(loader), None)
        if batch is None:  # empty loader: nothing to visualise
            return

        original, labels, *_ = batch
        original = original.to(global_torch_device())

        reconstruction, mean, log_var = model(original)

        if save_images:
            n = min(original.size(0), 8)
            # Concatenating inputs then reconstructions with nrow=n lays them out as two rows.
            comparison = torch.cat([original[:n], reconstruction[:n]])
            save_image(
                comparison.cpu(),
                str(BASE_PATH / f"reconstruction_{epoch_i}.png"),
                nrow=n,
            )

        scatter_plot_encoding_space(
            str(BASE_PATH / f"encoding_space_{epoch_i}.png"),
            mean.cpu().numpy(),
            log_var.cpu().numpy(),
            labels,
        )

In [None]:
if __name__ == "__main__":

    def main():
        """Decode one sample per epoch and save it (plus a manifold plot) to BASE_PATH.

        Each epoch, ``MODEL.sample()`` is reshaped to ``(CHANNELS, INPUT_SIZE,
        INPUT_SIZE)``, mapped back to an image with ``DATASET.inverse_transform`` and
        saved as ``sample_<epoch>.png``. When ``ENCODING_SIZE == 2``, a latent manifold
        image ``manifold_<epoch>.png`` is also written and shown, and the epoch loop
        stops after that first epoch.

        NOTE(review): this function contains no training step, so ``LR``/``BETA`` are
        unused and ``dataset_loader`` only feeds the disabled evaluation call.
        NOTE(review): the recorded run failed with
        ``TypeError: Cannot handle this data type: (1, 1, 3), <f4``, which looks like PIL
        receiving a float32 HWC array, presumably inside ``inverse_transform``. Confirm
        it converts to uint8 before building the image.
        """
        dataset_loader = DataLoader(
            DATASET, batch_size=BATCH_SIZE, shuffle=True, **DL_KWARGS
        )
        run_log_dir = (
            PROJECT_APP_PATH.user_log / "VggFace2" / "BetaVAE" / f"{time.time()}"
        )

        with TensorBoardPytorchWriter(run_log_dir) as metric_writer:
            for epoch in range(1, EPOCHS + 1):
                # Evaluation pass is currently disabled:
                # stest_model(MODEL, epoch, metric_writer, dataset_loader)
                with TorchEvalSession(MODEL), torch.no_grad():
                    decoded = MODEL.sample().view(CHANNELS, INPUT_SIZE, INPUT_SIZE)
                    sample_image = DATASET.inverse_transform(decoded)
                    sample_image.save(str(BASE_PATH / f"sample_{epoch}.png"))

                    if ENCODING_SIZE != 2:
                        continue  # the manifold grid is only drawn for a 2-D latent

                    from neodroidvision.utilities import plot_manifold

                    plot_manifold(
                        MODEL.decoder,
                        out_path=BASE_PATH / f"manifold_{epoch}.png",
                        img_w=INPUT_SIZE,
                        img_h=INPUT_SIZE,
                    )
                    pyplot.show()
                    break  # ends the whole epoch loop after the first manifold plot

    main()

  return torch.FloatTensor(


TypeError: Cannot handle this data type: (1, 1, 3), <f4