In [None]:
import sys
import ssl

# WARNING: the next line replaces ssl._create_default_https_context globally,
# which turns off HTTPS certificate verification for every later request in this
# kernel that relies on Python's default SSL context (e.g. urllib downloads).
# Presumably it was added so the torchvision MNIST download works behind a broken
# certificate chain; it also exposes those requests to man-in-the-middle attacks.
# Prefer fixing the local CA bundle and deleting this line.
ssl._create_default_https_context = ssl._create_unverified_context
# Make modules two directories above the current working directory importable
# (the project-local `utils`, `train`, `diffusion_models`, `metrics` imports below).
sys.path.append('../../')

import numpy as np
import random
from IPython.display import Image

import torch
from torch.utils.data import DataLoader

from torchvision.datasets.mnist import MNIST

from utils import show_first_batch, transform_data, show_images, plot_metrics_iddpm, show_tensor_images, calculate_metrics
from train import Trainer
from diffusion_models.ddpm_classifier_free_guidance import GaussianDiffusion
from metrics.fid_score import fid_score
from metrics.inception_score import inception_score
from diffusion_models.iddpm import Classifier, classifier_cond_fn

from diffusion_models.ddpm_classifier_free_guidance import Unet

import ipywidgets as widgets
# NOTE(review): this IntSlider is created but neither displayed (it is not the
# cell's last expression) nor assigned to a name, so it has no visible effect.
# Presumably a leftover check that ipywidgets is installed (e.g. for notebook
# progress bars) — consider removing it.
widgets.IntSlider()

# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so runs repeat.
SEED = 0
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")
device

In [None]:
# Build the (transformed) MNIST dataset under ../../datasets and wrap it in a
# shuffled loader; `transform_data` is project-local — presumably it downloads
# MNIST and applies the resize/normalisation used for training (confirm).
dataset = transform_data(MNIST, store_path="../../datasets")
train_dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=128,
    num_workers=0,
)
# Quick visual sanity check of one batch.
show_first_batch(train_dataloader)

In [None]:
num_classes = 10  # one class per MNIST digit

# Class-conditional U-Net. cond_drop_prob presumably is the probability of
# dropping the class label during training, so the same network also learns the
# unconditional prediction used by classifier-free guidance (confirm in Unet).
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    num_classes = num_classes,
    cond_drop_prob = 0.5
)

# NOTE(review): raw MNIST images are 28x28 grayscale; image_size=32 and the
# Unet's input channel count must match what transform_data produces — confirm.
# Move to `device` (not a hard-coded .cuda()) so the notebook also runs on the
# CPU fallback chosen in the setup cell.
diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    timesteps = 1000
).to(device)


In [None]:
# Training hyperparameters.
TRAIN_LR = 2e-4
TRAIN_NUM_STEPS = 700_000  # very long run; lower it for a quick smoke test

trainer = Trainer(
    diffusion,
    train_dataloader,
    device=device,
    train_lr=TRAIN_LR,
    train_num_steps=TRAIN_NUM_STEPS,
)

In [None]:
# Run the Trainer's training loop (presumably with label dropout for
# classifier-free guidance, per the method name) for the train_num_steps set
# when the Trainer was built.
# NOTE(review): this is the expensive cell — consider guarding it behind a check
# for saved weights so Restart & Run All does not retrain from scratch.
trainer.train_guidance_free()

In [None]:
from pathlib import Path

# Persist only the diffusion wrapper's parameters (state_dict).
# NOTE(review): the file name says "iddpm", but the model is built from
# diffusion_models.ddpm_classifier_free_guidance; kept unchanged so existing
# loaders still find the file.
store_path = '../../model_weights/iddpm_mnist_steps_1000_guidance_free.pt'
# torch.save does not create missing directories, so ensure the folder exists
# first instead of failing after a long training run.
Path(store_path).parent.mkdir(parents=True, exist_ok=True)
torch.save(diffusion.state_dict(), store_path)

In [None]:
# Draw a batch of class-conditional samples with classifier-free guidance.
batch_size = 100
# Labels are drawn on the CPU (same RNG stream as before) and then moved to
# `device` rather than a hard-coded .cuda(), so CPU-only kernels work too.
image_classes = torch.randint(0, num_classes, (batch_size,)).to(device)

# Sampling needs no gradients; disabling autograd keeps the reverse process from
# holding a computation graph across every denoising step.
with torch.no_grad():
    sampled_images = diffusion.sample(
        classes = image_classes,
        cond_scale = 3.                # condition scaling, anything greater than 1 strengthens the classifier free guidance. reportedly 3-8 is good empirically
    )

show_images(sampled_images)