In [None]:
import sys
import os

# Make the project's packages importable from this notebook.
# NOTE: despite its name, SCRIPT_DIR is the *parent* of the current working
# directory (dirname of the absolute path of "."), not the notebook's own
# folder. Presumably the notebook is launched from a sub-folder of the
# repository root -- confirm before moving it.
SCRIPT_DIR = os.path.dirname(os.path.abspath("."))
# Guard the append so re-running this cell does not pile up duplicate entries.
if SCRIPT_DIR not in sys.path:
    sys.path.append(SCRIPT_DIR)

In [None]:
from trainer.models import ldm
from torchvision import transforms, datasets
from utils.loader import ConfigKey, DataloaderConfig, DatasetLoader

In [None]:
class NotebookLoader(DatasetLoader):
    """MNIST data source for this notebook, layered on the project's ``DatasetLoader``.

    Reads ``img_size``, ``dataset_path``, ``batch_size`` and ``shuffle`` from
    ``self.args``.
    """

    def get_transform(self):
        """Assemble the per-image preprocessing pipeline.

        The pipeline resizes to an ``img_size`` x ``img_size`` square, converts
        to a tensor, and finally applies uniform dequantization.

        Returns:
            A ``transforms.Compose`` wrapping the three steps above.
        """
        side = self.args.img_size

        def dequantize(x, nvals=256):
            """Scale ``x`` by ``nvals - 1``, add U[0, 1) noise, divide by ``nvals``."""
            # Noise tensor shares x's dtype and device and matches its shape.
            jitter = x.new_empty(x.shape).uniform_()
            return (x * (nvals - 1) + jitter) / nvals

        steps = [
            # Disabled alternative to Resize, kept for reference:
            # transforms.RandomResizedCrop(args.img_size, scale=(0.8, 1.0)),
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            # Disabled normalisation step, kept for reference:
            # transforms.Normalize((0.5,), (0.5,)),
            dequantize,
        ]
        return transforms.Compose(steps)

    def get_dataloader_configs(self):
        """Describe the dataloaders this notebook provides.

        Builds a torchvision MNIST dataset rooted at ``args.dataset_path``
        (``download=True``; no ``train`` flag is passed, so torchvision's
        default split applies) using the pipeline from ``get_transform``.

        Returns:
            A dict with the single key ``ConfigKey.train`` mapped to a
            ``DataloaderConfig`` carrying the dataset, ``args.batch_size`` and
            ``args.shuffle``.
        """
        opts = self.args

        mnist = datasets.MNIST(
            opts.dataset_path,
            transform=self.get_transform(),
            download=True,
        )

        return {
            ConfigKey.train: DataloaderConfig(
                dataset=mnist,
                batch_size=opts.batch_size,
                shuffle=opts.shuffle,
            )
        }

In [None]:
# Create the LDM trainer; no constructor arguments are passed here.
ldm_trainer = ldm.LDMTrainer()
# Keep a handle on the trainer's argument object so the following cell can
# override individual fields before training starts.
args = ldm_trainer.args

In [None]:
# Notebook-specific overrides for the trainer's argument object, applied in
# the order listed. `img_size`, `batch_size`, `shuffle` and `dataset_path` are
# read by NotebookLoader; the remaining fields are presumably consumed inside
# the trainer -- confirm against trainer.models.ldm.
_overrides = {
    "model_type": "sde",
    "img_size": 32,
    "in_channels": 1,
    "z_channels": 32,
    "batch_size": 256,
    "shuffle": True,
    "save_freq": 2,
    "dataset_path": "./data",
    "beta_min": 0.1,
    "beta_max": 1,
    "num_classes": 10,
    "loader": NotebookLoader,
    "vae_checkpoint": "ckpt-100.pt",
}
for _field, _setting in _overrides.items():
    setattr(args, _field, _setting)
# Drop the temporaries so the notebook namespace is left as it was.
del _overrides, _field, _setting

In [None]:
# Start training.
# NOTE(review): assumes `ldm_trainer.args` is the same object that the
# configuration cell mutated through `args` (i.e. not a copy) -- confirm.
ldm_trainer.train()