In [None]:
import yaml
from PIL import Image

import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as transforms

from models.vae import VAE
from models.unet import Unet
from utils.dataset import Dataset
from utils.train_ddpm import train
from utils.scheduler import LinearNoiseSchedule


# Paths to the YAML configuration files read by the next cell.
DDPM_CONFIG = "configs/ddpm.yaml"
VAE_CONFIG = "configs/vae.yaml"

# Path of a VAE checkpoint; currently left empty.
VAE_CKPT = ""

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

In [2]:
def _load_yaml_config(path):
    """Parse the YAML file at ``path`` and return its top-level mapping.

    Errors are deliberately not swallowed: continuing after a failed parse
    would leave the caller with a missing or stale configuration.

    Args:
        path: Location of the YAML file to read.

    Returns:
        The parsed document as a dict.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the document's top level is not a mapping
            (for example, the file is empty).
    """
    with open(path, 'r') as file:
        loaded = yaml.safe_load(file)
    if not isinstance(loaded, dict):
        raise ValueError(
            f"{path}: expected a mapping at the top level, "
            f"got {type(loaded).__name__}"
        )
    return loaded


# DDPM settings, split into the sections used by the following cells.
ddpm_config = _load_yaml_config(DDPM_CONFIG)
ddpm_model_config = ddpm_config['model_config']
ddpm_dataset_config = ddpm_config['dataset_config']
ddpm_training_config = ddpm_config['training_config']

# VAE settings live in their own file; a separate name is used so that a
# problem with this file can never fall back to the DDPM values.
vae_config = _load_yaml_config(VAE_CONFIG)
vae_model_config = vae_config['model_config']

In [3]:
# Build the noise schedule over the number of timesteps given in the
# training section of the DDPM config.
num_timesteps = ddpm_training_config['NUM_TIMESTEPS']
scheduler = LinearNoiseSchedule(T=num_timesteps)

In [4]:
# Images are converted to tensors and then resized to a square of
# IMG_SIZE x IMG_SIZE using bicubic interpolation.
img_size = ddpm_dataset_config['IMG_SIZE']

# NOTE(review): the transforms are handed over as a plain list rather than a
# transforms.Compose — presumably utils.dataset.Dataset applies them itself;
# confirm against its implementation.
transform = [
    transforms.ToTensor(),
    transforms.Resize((img_size, img_size), Image.BICUBIC),
    # Per-channel normalization is currently disabled:
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Dataset rooted at the configured directory, served in shuffled batches by
# two worker processes.
dataset = Dataset(ddpm_dataset_config['ROOT'], transform)
data_loader = DataLoader(
    dataset,
    batch_size=ddpm_dataset_config['BATCH_SIZE'],
    shuffle=True,
    num_workers=2,
)

In [None]:
# Denoising network, built from the DDPM model config and moved to DEVICE.
# NOTE(review): 3 input channels but 2 output channels — confirm this is the
# intended pairing for the tensors the training loop feeds the model.
model = Unet(
    in_channels=3,
    out_channels=2,
    model_config=ddpm_model_config,
)
model = model.to(DEVICE)

In [6]:
# Build the VAE on DEVICE and switch it to evaluation mode.
vae = VAE(model_config=vae_model_config).to(DEVICE)
vae.eval()

# Load pretrained weights whenever a checkpoint path is configured. With the
# default empty VAE_CKPT nothing is loaded and the VAE keeps its freshly
# initialised weights.
# NOTE(review): assumes the checkpoint file holds a bare state dict — confirm
# against whatever script saved it.
if VAE_CKPT:
    vae.load_state_dict(torch.load(VAE_CKPT, map_location=DEVICE))

# Freeze every VAE parameter so no gradients are computed for them.
for param in vae.parameters():
    param.requires_grad = False

In [None]:
# Mean-squared-error loss used as the training criterion.
criterion = torch.nn.MSELoss()

# Adam over the U-Net parameters with a fixed learning rate.
LEARNING_RATE = 1E-5
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Start DDPM training with the objects built in the cells above.
# NOTE(review): neither `vae` nor `DEVICE` is passed here — confirm against
# the signature of utils.train_ddpm.train whether it expects them.
train(
    num_epochs = ddpm_training_config["NUM_EPOCHS"],
    data_loader = data_loader,
    optimizer = optimizer,
    T = ddpm_training_config["NUM_TIMESTEPS"],
    scheduler = scheduler,  # the comma was missing here, making the cell a SyntaxError
    model = model,
    criterion = criterion,
)