# Imports
Perform the necessary imports and setup.

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
import torchvision
import numpy as np

from torchviz import make_dot
from tqdm import tqdm

import matplotlib.pyplot as plt

from skindataset import SkinDataset
from skindiffuser import NoiseScheduler, SkinUnet
from util import HParams

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Notebook is running on device: {device}')

To make tuning and experimenting with hyperparameters easier, we keep them in a JSON file. The following utility class loads the values and makes them accessible to the code.

In [None]:
# Load the hyperparameters from the JSON config via util.HParams.
hparams = HParams()
# Last expression in the cell: the notebook displays the loaded values.
hparams

Load the trained model checkpoint and switch the model to evaluation mode.

In [None]:
# Rebuild the architecture and load the trained weights.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can also be loaded on a CPU-only machine (and vice versa).
checkpoint_path = "checkpoints/unet_epoch_125.pt"
model = SkinUnet()
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
# Switch to inference mode (affects dropout/batch-norm layers, if the model has
# any); as the cell's last expression it also displays the model.
model.eval()

## Inference

In [None]:
@torch.no_grad()
def sample_timestep(x, t, unet, ns=None):
    """
    Perform one reverse-diffusion (denoising) step.

    Calls ``unet`` to predict the noise in ``x`` at timestep ``t``, removes
    it to obtain the model mean, and, for every sample whose timestep is not
    0, adds fresh Gaussian noise scaled by the posterior standard deviation.

    Args:
        x: Current noisy image batch (assumed batch-first, e.g.
            ``(B, C, H, W)`` as in the sampling cell — confirm for other uses).
        t: Long tensor of timesteps, assumed to hold one entry per sample
            in ``x`` (or a single entry shared by the whole batch).
        unet: Noise-prediction model, called as ``unet(x, t)``.
        ns: Noise scheduler providing ``betas``,
            ``sqrt_one_minus_alphas_cumprod``, ``sqrt_recip_alphas`` and
            ``posterior_variance``. Defaults to a global named ``ns`` for
            backward compatibility with the original implementation.

    Returns:
        The image batch for the previous timestep. Samples at ``t == 0``
        receive the model mean without added noise.

    Raises:
        NameError: If ``ns`` is not passed and no global ``ns`` is defined.
    """
    if ns is None:
        # The original implementation silently depended on a global `ns`.
        try:
            ns = globals()["ns"]
        except KeyError:
            raise NameError(
                "sample_timestep needs a NoiseScheduler: pass ns=... or "
                "define a global `ns` before sampling"
            ) from None

    betas_t = NoiseScheduler.get_index_from_list(ns.betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = NoiseScheduler.get_index_from_list(
        ns.sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = NoiseScheduler.get_index_from_list(
        ns.sqrt_recip_alphas, t, x.shape
    )

    # Model mean: current image minus the scaled noise prediction.
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * unet(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = NoiseScheduler.get_index_from_list(
        ns.posterior_variance, t, x.shape
    )

    # Per-sample "add noise?" mask, shaped to broadcast over x. A plain
    # `if t == 0` fails for tensors with more than one element.
    add_noise = (t != 0).reshape(-1, *([1] * (x.dim() - 1)))
    if not add_noise.any():
        # Last step for every sample: return the mean exactly, without noise.
        return model_mean

    noise = torch.randn_like(x)
    noisy = model_mean + torch.sqrt(posterior_variance_t) * noise
    # torch.where keeps model_mean untouched for any samples already at t == 0.
    return torch.where(add_noise, noisy, model_mean)
    

In [None]:
# Start from pure Gaussian noise and run the reverse process from t = T-1 down to 0.
# NOTE(review): sample_timestep uses a NoiseScheduler named `ns`, which no
# cell shown here creates — make sure it is defined before running this cell.
img = torch.randn((1, 3, 64, 64), device=device)

for i in reversed(range(hparams.data['T'])):
    # One timestep per sample in the batch (batch size 1 here).
    t = torch.full((1,), i, device=device, dtype=torch.long)
    img = sample_timestep(img, t, model)

# Convert CHW -> HWC for matplotlib.
# NOTE(review): imshow clips float input to [0, 1]; if the training images were
# normalized to [-1, 1], rescale first, e.g. (img + 1) / 2. Confirm against
# SkinDataset's transforms.
npimg = img[0].cpu().numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
