In [None]:
# Uncomment to install the notebook's dependencies into the running kernel's environment:
# %pip install matplotlib scipy tqdm kaggle

In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from utils import *
from model import *
from diffusion import *
from config import Config
from train_utils import *

In [None]:
# Instantiate the project configuration. Later cells read img_size, batch_size,
# T, device, time_embd_dim, lr and max_iters from this object.
config = Config()

## Load dataset and display samples

In [None]:
# Directory that holds the dataset files; reused by the transformed loader below.
data_root_dir = "data"

# Raw (untransformed) dataset, used here only to preview a few samples.
dataset = torchvision.datasets.StanfordCars(root=data_root_dir)
show_images(dataset, cols=1, num_samples=4)

In [None]:
# Dataset with the project's preprocessing applied at the configured image size.
data = load_transformed_dataset(config.img_size, data_root_dir)
# tiny_data = torch.utils.data.Subset(data, range(1))

# Shuffled loader shared by the base-model and student training loops.
train_dataloader = DataLoader(
    data,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=False,
    # collate_fn=custom_collate_fn,
)

## Test forward & reverse diffusion

In [None]:
# Pull a single batch from the loader and keep its first element.
# presumably element 0 is the image tensor and the rest are labels — TODO confirm
first_batch = next(iter(train_dataloader))
img = first_batch[0]

In [None]:
# Sanity check: noise `img` forward to timestep `steps`, run the DDIM sampler,
# and display the original, the noised image and the sampler output side by side.
steps = 10
diffusion = Diffusion(timesteps=config.T)
sampler = Sampler(sample_timesteps=config.T)

# One-element long tensor holding the timestep, created on the configured device.
t_test = torch.full((1,), steps, dtype=torch.long, device=config.device)
img_t, _ = diffusion.forward_diffusion_sample(img, t_test)

# NOTE(review): the sampler is given the clean `img` (not `img_t`) and no model;
# presumably `testing=True` makes it noise the input itself — confirm in Sampler.
img_0 = sampler.sample_plot_image(
    None,
    sampling_steps=steps - 1,
    testing=True,
    img=img,
    return_last_img=True,
    sampling="DDIM",
    show=False,
)
show_any_images([img, img_t, img_0], cols=3)

# These helpers were only needed for this check.
del diffusion, sampler

## Train a base model

In [None]:
# Skip ahead to the distillation section if a base (teacher) model has already been trained.
# Step 1
# Define the diffusion object that is handed to train_model in Step 4.
diffusion_ = Diffusion(timesteps=config.T)

In [None]:
# Step 2
# Build the U-Net, place it on the configured device, and create its optimizer.
model = Unet(time_embd_dim=config.time_embd_dim)
model = model.to(config.device)
# The optimizer holds references to this specific model's parameters.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config.lr)


In [None]:
# Step 3
# Run this cell to load a checkpoint, otherwise skip it.
# The checkpoint is restored into the Step 2 `model` itself: `optimizer` was built
# from that model's parameters, so loading into a freshly constructed Unet would
# leave `model` and `optimizer` referring to two different networks.
# NOTE(review): assumes load_model restores state into the objects it is given
# and returns them — confirm against train_utils/utils.
model, optimizer = load_model(path="DDIM_unet_v7_ep24999.pt", model=model, optimizer=optimizer)

In [None]:
# Step 4
# Train the base model; the trained network and the recorded losses are returned.
teacher_model, losses = train_model(
    model,
    diffusion_,
    iters=config.max_iters,
    batch_size=config.batch_size,
    optimizer=optimizer,
    dataloader=train_dataloader,
    T=config.T,
)

In [None]:
# plt.plot(losses)

In [None]:
# del teacher_model

## Distillation: Training a student model

In [None]:
# Steps 1 - 4 above can be skipped when a trained base/teacher model already exists.
# Load the teacher checkpoint into a Unet that is NOT moved to config.device here:
# the teacher is only mapped to cuda when computing the distillation loss, to save memory.
teacher_ckpt_path = "DDIM_unet_v1_st4_ep29999.pt"
tmodel, _ = load_model(
    teacher_ckpt_path,
    Unet(config.time_embd_dim),
)
#tmodel.to(config.device)



In [None]:
# next(tmodel.parameters()).device

In [None]:
# del tmodel

### Student training

In [None]:
# Distil the teacher into a student model.
iters = config.max_iters
# Sampling steps for this distillation round: 128-> 64-> 32-> 16-> 8-> 4-> 2
sm_steps = 2
student_model, losses = train_student(
    iters,
    sm_steps,
    tmodel,
    config.batch_size,
    train_dataloader,
    config.T,
    student_ckpt_path=None,
)

In [None]:
# Plot the losses returned by train_student in the previous cell.
plt.plot(losses)

In [None]:
# Drop the in-memory student; the Sampling section loads a student from a checkpoint file instead.
del student_model

## Sampling

In [None]:
# Fixed starting noise for sampling.
# Seeding once and drawing every tensor up front means the teacher and the student
# start from the same inputs, for a fair comparison and reproducibility.
torch.manual_seed(42)
num_samples = 4
noise_shape = (1, 3, config.img_size, config.img_size)
noisy_imgs = [torch.randn(noise_shape) for _ in range(num_samples)]

In [None]:
# Sampling using teacher model
# NOTE(review): eval mode / frozen parameters are left disabled below, as in the
# original cell — confirm whether Sampler handles this internally.
# tmodel.eval()
# for param in tmodel.parameters():
#             param.requires_grad = False

tt = 1000 # timesteps for diffusion
stps = 256 # sampling steps
tsampler = Sampler(tt)

# Move the teacher to the target device once, instead of on every loop iteration.
tmodel = tmodel.to(config.device)
for i in tqdm(range(num_samples)):
    # Each run starts from the corresponding pre-generated noise tensor.
    img = tsampler.sample_plot_image(tmodel, stps, return_last_img=True, img=noisy_imgs[i].to(config.device))
    #save_tensor_image(img, f"image_{i}_{stps}.png", (256, 256))

In [None]:
# Load the student checkpoint into a Unet that already lives on config.device.
student_ckpt_path = "DDIM_unet_v1_st2_ep29999.pt"
smodel, _ = load_model(
    student_ckpt_path,
    Unet(time_embd_dim=config.time_embd_dim).to(config.device),
)

In [None]:
# Sample with the student from the same fixed noise tensors used for the teacher.
ts = tt  # timesteps for diffusion (same value as the teacher run)
stps = 16  # sampling steps
ssampler = Sampler(ts)
for i in tqdm(range(num_samples)):
    start_noise = noisy_imgs[i].to(config.device)
    img = ssampler.sample_plot_image(smodel, stps, return_last_img=True, img=start_noise)
#     save_tensor_image(img, f"s_image_{i}_{stps}.png", (256, 256))

In [None]:
# Release the training-time models if they are still defined.
# A plain `del` raises NameError when the name is missing: `student_model` is
# already deleted right after student training, and `teacher_model` / `model`
# never exist when the base-training steps were skipped.
for _name in ("teacher_model", "model", "student_model"):
    globals().pop(_name, None)
del _name