In [1]:
import src
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchvision.transforms import Compose, Lambda, ToPILImage
from pathlib import Path
from torch.optim import Adam
%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Noise schedule (passed to src.LinearScheduler) ---
beta_start = 0.0001
beta_end = 0.02
timesteps = 200

# --- Data ---
image_size = 32
num_channels = 3
batch_size = 128
dataset_name = 'cifar10'

# --- Output ---
results_folder_name = './results'
sample_and_save_freq = 100

# --- Training ---
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 1e-3
# Misspelled legacy name, kept as an alias because the optimizer cell still reads it.
learninig_rate = learning_rate
epochs = 5

In [3]:
# Build the beta schedule from the endpoints and step count set in the config cell.
scheduler = src.LinearScheduler(
    beta_start=beta_start,
    beta_end=beta_end,
    timesteps=timesteps,
)

In [4]:
# Inverse of the training-time preprocessing: turns a single CHW tensor scaled
# to [-1, 1] back into a PIL image.
reverse_transform = Compose([
    Lambda(lambda t: (t + 1) / 2),          # [-1, 1] -> [0, 1]
    # Noised or sampled tensors can fall outside [-1, 1]; without clamping, the
    # uint8 cast below would wrap around (or be undefined for negatives).
    Lambda(lambda t: t.clamp(0, 1)),
    Lambda(lambda t: t.permute(1, 2, 0)),   # CHW to HWC
    Lambda(lambda t: t * 255.),
    # .numpy() raises on CUDA or grad-tracking tensors, so detach and move to
    # host memory first (both are no-ops for a plain CPU tensor).
    Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),
    ToPILImage(),
])

# Forward (noising) process, built from the scheduler's precomputed
# cumulative-alpha terms plus the tensor -> PIL converter defined above.
forward_diffusion = src.ForwardDiffusion(
    sqrt_alphas_cumprod=scheduler.sqrt_alphas_cumprod,
    sqrt_one_minus_alphas_cumprod=scheduler.sqrt_one_minus_alphas_cumprod,
    reverse_transform=reverse_transform,
)


In [5]:
# Fetch the dataset named in the config cell (a cached copy is reused if present).
dataset = load_dataset(dataset_name)

# Training-time preprocessing applied to each PIL image:
# random horizontal flip, conversion to a tensor, then rescaling via x -> 2x - 1.
forward_transform = Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        Lambda(lambda t: 2 * t - 1),
    ]
)

def transforms_(examples):
    """Convert a batch of PIL images into preprocessed tensors.

    Intended for use with ``Dataset.with_transform``. Reads the images under
    ``examples["img"]``, converts each to the PIL mode matching the configured
    ``num_channels``, applies ``forward_transform``, and stores the results
    under ``examples["pixel_values"]``. The ``"img"`` entry is removed.

    Args:
        examples: Batch mapping that contains an ``"img"`` list of PIL images.

    Returns:
        The same mapping, mutated in place, with ``"pixel_values"`` added and
        ``"img"`` deleted.
    """
    # Pick the mode from the channel count the model is built with, not from the
    # image itself. Otherwise any non-"RGB" colour mode (RGBA, P, CMYK, ...)
    # would be collapsed to grayscale and yield 1-channel tensors.
    target_mode = "RGB" if num_channels == 3 else "L"
    examples["pixel_values"] = [
        forward_transform(image.convert(target_mode)) for image in examples["img"]
    ]
    del examples["img"]

    return examples

# Apply the preprocessing lazily at access time and drop the class labels,
# which this notebook does not pass to training.
transformed_dataset = dataset.with_transform(transforms_).remove_columns("label")

# Shuffled mini-batches over the training split.
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)

# Destination passed to src.train as save_folder. Create it up front so saving
# cannot fail on a missing directory; whether src.train creates it itself is not
# visible from this notebook. The call is idempotent.
results_folder = Path(results_folder_name)
results_folder.mkdir(parents=True, exist_ok=True)


Found cached dataset cifar10 (/Users/dhruvsrikanth/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)
100%|██████████| 2/2 [00:00<00:00, 327.33it/s]


In [6]:
# Wrap the "cuda"/"cpu" string chosen in the config cell in a torch.device object.
device = torch.device(device)

In [7]:
# Denoising network sized from the config cell.
model = src.DDPM(
    n_features=image_size,
    in_channels=num_channels,
    channel_scale_factors=(1, 2, 4),
)
# The return value is discarded, so this relies on .to() moving the model in place
# (true for torch.nn.Module; confirm src.DDPM is one).
model.to(device)

# The loss is handed to the training loop as a callable; it is not invoked here.
criterion = src.get_loss
optimizer = Adam(model.parameters(), lr=learninig_rate)

# Sampler built from the same scheduler quantities used for the forward process.
sampler = src.Sampler(
    betas=scheduler.betas,
    sqrt_one_minus_alphas_cumprod=scheduler.sqrt_one_minus_alphas_cumprod,
    sqrt_one_by_alphas=scheduler.sqrt_one_by_alphas,
    posterior_variance=scheduler.posterior_variance,
    timesteps=timesteps,
)


In [8]:
# Run training with everything assembled in the cells above.
src.train(
    # data and models
    dataloader=dataloader,
    forward_diffusion_model=forward_diffusion,
    denoising_model=model,
    sampler=sampler,
    # optimisation
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
    device=device,
    # diffusion length and image geometry
    timesteps=timesteps,
    image_size=image_size,
    num_channels=num_channels,
    # sample output
    sample_and_save_freq=sample_and_save_freq,
    save_folder=results_folder,
)

Training DDPM:   1%|          | 2/391 [00:21<1:10:30, 10.88s/it, Loss=0.5514]


KeyboardInterrupt: 