In [None]:
import torch
import os
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
os.chdir('../src')

from models.diffusion import Diffusion
from models.unet import Unet

In [None]:
# Checkpoint to visualise; the training epoch is encoded in the filename.
model_path = '../outputs/models/mnist_diffusion_epoch_0.pt'

# model params
# The architecture params passed to Unet (dim, channels, dim_mults) must match the
# checkpoint: load_state_dict is strict by default and raises on missing/mismatched keys.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_shape = (28, 28)
channels = 1
dim_mults = (1, 2, 4, )
T = 300  # number of diffusion steps handed to Diffusion

unet = Unet(
    dim=data_shape[0],  # NOTE(review): uses the image height as `dim`; confirm this matches the training config
    channels=channels,
    dim_mults=dim_mults,
)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on GPU
# still loads on the CPU fallback selected above.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (on torch >= 1.13 consider weights_only=True).
state_dict = torch.load(model_path, map_location=device)
unet.load_state_dict(state_dict)
# Put the weights on the same device Diffusion is told to use, and switch to
# inference mode (disables training-only behaviour such as dropout, if present).
unet = unet.to(device).eval()

model = Diffusion(
    model=unet,
    data_shape=data_shape,
    T=T,
    device=device,
)

In [None]:
# Seed the sampler so the figure is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Model outputs are presumably in [-1, 1] (matching the training normalisation -- confirm).
# Map to [0, 1] and clamp: ToPILImage multiplies float tensors by 255 and casts to uint8,
# so out-of-range samples would otherwise wrap around and show up as speckle artifacts.
reverse_transform = transforms.Compose([
    transforms.Lambda(lambda x: ((x + 1) / 2).clamp(0, 1)),
    transforms.ToPILImage(),
])

# generate samples -- inference only, so no autograd graph is needed
with torch.no_grad():
    sample_tensor = model.sample().squeeze(0)  # NOTE(review): assumes a leading batch dim of size 1
# Detach and move to CPU explicitly so ToPILImage works across torchvision versions.
sample = reverse_transform(sample_tensor.detach().cpu())

fig, ax = plt.subplots()
ax.imshow(sample, cmap='gray')
ax.set_title(f'Diffusion sample ({os.path.basename(model_path)})')
ax.axis('off')
plt.show()