In [None]:
import torch
import os
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# Switch into ../src so the `models.*` imports below resolve (presumably via the
# kernel's cwd on sys.path). The checkpoint paths in later cells ('../outputs/...')
# are also relative to this directory.
os.chdir('../src')

from models.diffusion import Diffusion, CondDiffusion
from models.unet import Unet
from models.condunet import condUnet

# Convert a model output in the [-1, 1] convention back to a PIL image.
# Reverse-diffusion samples can land slightly outside [-1, 1]. Clamp to [0, 1]
# before ToPILImage scales by 255 and casts to uint8; otherwise out-of-range
# pixels overflow and appear as speckle artifacts.
reverse_transform = transforms.Compose([
    transforms.Lambda(lambda x: ((x + 1) / 2).clamp(0, 1)),
    transforms.ToPILImage(),
])

In [None]:
# Unconditional model; the checkpoint name indicates Fashion-MNIST, epoch 9.
model_path = '../outputs/models/fashion_mnist_diffusion_epoch_9.pt'

# model params -- the architecture arguments must match the checkpoint,
# otherwise load_state_dict fails on key/shape mismatches.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_shape = (28, 28)
channels = 1
dim_mults = (1, 2, 4, )
T = 300  # diffusion steps; presumably must match the value used in training

unet = Unet(
    dim=data_shape[0],
    channels=channels,
    dim_mults=dim_mults,
)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
# Without it, torch.load raises when CUDA is unavailable, despite the fallback above.
unet.load_state_dict(torch.load(model_path, map_location=device))
# Inference only: put the network on the sampling device and disable any
# training-mode layers (dropout / batch-norm updates).
unet = unet.to(device).eval()

model = Diffusion(
    model=unet,
    data_shape=data_shape,
    T=T,
    device=device,
)

In [None]:
# Generate one unconditional sample and display it.
# no_grad: pure inference, so avoid building an autograd graph over the sampling
# steps (harmless if model.sample already disables gradients).
with torch.no_grad():
    # squeeze(0) drops the leading dim -- assumes sample() returns a batch of one
    sample_tensor = model.sample().squeeze(0)
sample_img = reverse_transform(sample_tensor)

fig, ax = plt.subplots()
ax.imshow(sample_img, cmap='gray')
ax.set_title('Unconditional sample')
ax.axis('off');

In [None]:
# Class-conditional model; the checkpoint name indicates MNIST, epoch 2.
model_path = '../outputs/models/cond_mnist_diffusion_epoch_2.pt'

# model params -- the architecture arguments must match the checkpoint,
# otherwise load_state_dict fails on key/shape mismatches.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_shape = (28, 28)
channels = 1
dim_mults = (1, 2, 4, )
T = 300  # diffusion steps; presumably must match the value used in training

net = condUnet(
    dim=data_shape[0],
    channels=channels,
    dim_mults=dim_mults,
    num_classes=10,  # presumably the 10 MNIST digit classes
)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
# Without it, torch.load raises when CUDA is unavailable, despite the fallback above.
net.load_state_dict(torch.load(model_path, map_location=device))
# Inference only: put the network on the sampling device and disable any
# training-mode layers (dropout / batch-norm updates).
net = net.to(device).eval()

model = CondDiffusion(
    model=net,
    data_shape=data_shape,
    T=T,
    device=device,
)

In [None]:
# Generate one class-conditional sample and display it.
LABEL = 2  # class index to condition on (an MNIST digit, per the checkpoint name)
label = torch.tensor([LABEL], device=device)

# no_grad: pure inference, so avoid building an autograd graph over the sampling
# steps (harmless if model.sample already disables gradients).
with torch.no_grad():
    # squeeze(0) drops the leading dim -- assumes one label yields a batch of one
    sample_tensor = model.sample(label).squeeze(0)
sample_img = reverse_transform(sample_tensor)

fig, ax = plt.subplots()
ax.imshow(sample_img, cmap='gray')
ax.set_title(f'Conditional sample (label={LABEL})')
ax.axis('off');