In [None]:
%cd /home/jrottmay/ml-dev

import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
from torchsummary import summary

import ipywidgets as widgets
from IPython.display import display

# Local code inclusion
from modules import *

In [None]:
# Pick the compute device once; later cells read the global `device`.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Additional Info when using cuda
if device == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    # report sizes in GiB (2**30 == 1024**3), rounded to one decimal
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 2**30, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / 2**30, 1), 'GB')

# Parameters & Misc

In [None]:
# Interactive switches for the expensive training cells below.
# NOTE: the flags are read when those cells execute, so set the toggles
# *before* running them; a "Run All" on a fresh kernel uses the defaults here.
flag_train = False
train_from_scratch = False

slider = widgets.ToggleButton(value=flag_train, description='Perform Training')
train_from_scratch_slider = widgets.ToggleButton(value=train_from_scratch, description='Train from Scratch')

def on_slider_change(change):
    """Mirror the 'Perform Training' toggle into the global `flag_train`."""
    global flag_train
    flag_train = change['new']

def on_train_from_scratch_button_click(b):
    """Mirror the 'Train from Scratch' toggle into the global `train_from_scratch`.

    `b` is the change dict passed to `observe` handlers; its 'new' entry is the
    toggle's new value. Copying it (instead of always writing True) means that
    switching the toggle back off also clears the flag.
    """
    global train_from_scratch
    train_from_scratch = b['new']

slider.observe(on_slider_change, names='value')
train_from_scratch_slider.observe(on_train_from_scratch_button_click, names='value')

display(widgets.HBox([slider, train_from_scratch_slider]))


In [None]:
# Misc
batch_size = 128
log_dir = "./log"
img_dir = "./img"
wandb_dir = "./tmp"

# VAE PARAMETERS
base_channels = 32
latent_channels = 100
channel_multipliers = (1, 2, 4)
# Must be a tuple: `(1)` is just the int 1. The (misspelt) name matches the
# `attention_resultions=` keyword used when the VAE is constructed.
attention_resultions = (1,)
dropout = 0.2
norm = "bn"
vae_model_checkpoint = f"{log_dir}/MNIST_VAE-test_01-iteration-80000-model.pth"
vae_optim_checkpoint = f"{log_dir}/MNIST_VAE-test_01-iteration-80000-optim.pth"

# DDPM PARAMETERS
activation = F.relu
use_labels = True               # condition the denoiser on the digit label
schedule = "cosine"             # "cosine"; any other value selects the linear schedule
schedule_low = 1e-4             # linear-schedule bounds (unused by the cosine schedule)
schedule_high = 2e-2
num_timesteps = 100
num_res_blocks = 2
loss_type = "l2"

ema_decay = 0.9999
ema_update_rate = 1

learning_rate = 2e-4
iterations = 100000
checkpoint_rate = 10000
log_rate = 1000
num_samples = 10
classes = torch.arange(10)      # MNIST digit labels 0..9
project_name = "Image Space Diffusion"
entity = "jan-rottmayer"
run_name = "testing_2"
log_to_wandb = False
model_checkpoint = f"{log_dir}/MNIST_DDPM-test_01-iteration-100000-model.pth"
optim_checkpoint = f"{log_dir}/MNIST_DDPM-test_01-iteration-100000-optim.pth"

# Data

In [None]:
# MNIST train/test splits as tensors (ToTensor), cached under ~/data.
train_data = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        '~/data',
        transform=torchvision.transforms.ToTensor(),
        download=True,
    ),
    batch_size=batch_size,
    shuffle=True,
)

test_data = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        '~/data',
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    ),
    batch_size=batch_size,
    shuffle=True,
)

# VAE PARAMETERS
# Ratio minibatch size / training-set size (presumably used to weight the KL
# term). Use the true dataset length: len(loader) * batch_size rounds the size
# up to a whole number of batches (469 * 128 = 60032 instead of 60000).
kld_weight = batch_size / len(train_data.dataset)

# Model Definition

In [None]:
def _cosine_betas(num_steps, s=0.008, max_beta=0.999):
    """Cosine noise schedule of Nichol & Dhariwal (2021), "Improved DDPM".

    alpha_bar(t) = f(t) / f(0) with f(t) = cos(((t / T) + s) / (1 + s) * pi / 2) ** 2,
    and beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clipped to ``max_beta``
    to avoid the singularity at t = T.

    Returns a float64 numpy array of length ``num_steps``.
    """
    t = np.arange(num_steps + 1, dtype=np.float64)
    f = np.cos((t / num_steps + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar = f / f[0]
    return np.minimum(1.0 - alpha_bar[1:] / alpha_bar[:-1], max_beta)


# Per-step noise variances beta_1..beta_T passed to GaussianDiffusion.
if schedule == "cosine":
    # Previously this branch built a linear linspace despite its name.
    # schedule_low / schedule_high only parameterise the linear schedule.
    betas = _cosine_betas(num_timesteps)
else:
    # Linear schedule; the bounds are rescaled by 1000 / T (presumably so the
    # configured values keep their meaning from a 1000-step schedule).
    betas = np.linspace(
        schedule_low * 1000 / num_timesteps,
        schedule_high * 1000 / num_timesteps,
        num_timesteps
    )

# VAE

In [None]:
# VAE for 1-channel MNIST; hyper-parameters come from the config cell.
vae = VAE(
    1,  # presumably the number of input image channels (MNIST is grayscale) -- confirm
    base_channels,
    latent_channels,
    channel_multipliers=channel_multipliers,
    attention_resultions=attention_resultions,
    dropout=dropout,
    norm=norm,
).to(device)  # `device` was computed earlier but never used to place the model

# Build the optimizer after the move so it holds the on-device parameters.
vae_optim = torch.optim.Adam(vae.parameters(), lr=learning_rate)

In [None]:
# Either (re)train the VAE or restore it from the configured checkpoint,
# depending on the toggle widgets above.
if flag_train:
    train(
        vae,
        vae_optim,
        train_data,
        lambda m, x, y: m.loss(x, y),
        test_data=test_data,
        iterations=100000,
        checkpoint_rate=10000,
        log_rate=1000,
        run_name="test_01",
        project_name="MNIST_VAE",
        chkpt_callback=visualize_mnist_sample,
        # resume from the saved checkpoint unless "Train from Scratch" is on
        model_checkpoint=None if train_from_scratch else vae_model_checkpoint,
        optim_checkpoint=None if train_from_scratch else vae_optim_checkpoint,
    )
else:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
    # load_state_dict then copies the tensors onto the model's own device.
    vae.load_state_dict(torch.load(vae_model_checkpoint, map_location=device))

# Image Space DDPM

In [None]:
# UNet denoiser working directly on 28x28 MNIST pixels; class-conditional
# (10 digit classes) when `use_labels` is set.
model = UNet(
    img_channels=1,
    base_channels=32,
    channel_mults=(1, 2, 4),
    time_emb_dim=128 * 2,
    norm="gn",
    dropout=0.1,
    activation=activation,  # honour the configured activation instead of a hard-coded F.relu
    attention_resolutions=(1,),
    num_classes=10 if use_labels else None,
    initial_pad=0,
)

# Positional args: denoiser, image size (28, 28), presumably image channels (1) -- confirm.
diffusion = GaussianDiffusion(
    model, (28, 28), 1,
    betas,
    ema_decay=ema_decay,
    ema_update_rate=ema_update_rate,
    ema_start=2000,
).to(device)  # place the diffusion wrapper (and the UNet inside it) on the selected device

# Created after the move so the optimizer references the on-device parameters.
diffusion_opt = torch.optim.Adam(diffusion.parameters(), lr=learning_rate)

In [None]:
def visualize_ddpm_mnist(
    model,
    optimizer=None,
    loss_function=None,
    train_data=None,
    test_data=None,
    iteration=None,
    iterations=None,
    run_name=None,
    log_to_wandb=None,
    wandb_dir=None,
    img_dir="./img",
    log_dir="./log",
    entity=None,
    project_name=None,
    device=None,
    **kwargs
):
    """Save a grid comparing a few training images with the model's outputs.

    Written to be usable as the ``chkpt_callback`` of ``train``. Only ``model``
    and ``train_data`` are required; ``project_name``, ``run_name`` and
    ``iteration`` only label the figure and its file name, and the other
    arguments are accepted and ignored. ``device`` is ignored as well: inputs
    go to the device of the model's parameters.

    The first row shows the first (up to) 5 images of ``train_data.dataset``;
    the second row shows ``model(images, y=labels)``. The figure is written to
    ``{img_dir}/{project_name}-{run_name}-iteration-{iteration}-model.png``
    (``img_dir`` is created if missing). The model's train/eval mode is
    restored afterwards, even on error.

    NOTE(review): for ``GaussianDiffusion`` the forward call may return a
    training loss rather than images -- confirm; if so, this should sample
    from the model instead.

    Returns None.
    """
    import os

    n_samples = 5
    channel = 1  # MNIST images have a single channel

    dataset = train_data.dataset
    subset = torch.utils.data.Subset(dataset, list(range(min(n_samples, len(dataset)))))
    vis_samples_loader = torch.utils.data.DataLoader(subset, batch_size=n_samples, shuffle=False)

    batch = next(iter(vis_samples_loader), None)
    if batch is None:  # empty dataset: nothing to draw
        return None
    samples, y = batch

    was_training = model.training
    model.eval()
    model_device = next(model.parameters()).device  # feed inputs where the weights live
    try:
        with torch.no_grad():
            reconstructions = model(samples.to(model_device), y=y.to(model_device)).to('cpu')

        # Originals on the first row, model outputs on the second.
        comparison = torch.cat([samples, reconstructions])
        # Keep the channel axis: make_grid expects (B, C, H, W). Indexing with a
        # single int would drop it (and index 1 is out of range for C == 1).
        comparison = comparison[:, :channel, :, :]

        grid = torchvision.utils.make_grid(comparison, nrow=len(samples), padding=2, normalize=True)

        plt.clf()
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.text(0.5,0.0,f"Project: {project_name}\nRun: {run_name}\nTime: {datetime.datetime.now().replace(second=0, microsecond=0)}\nIteration: {iteration}\n")
        plt.axis('off')

        os.makedirs(img_dir, exist_ok=True)
        plt.savefig(f"{img_dir}/{project_name}-{run_name}-iteration-{iteration}-model.png")
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)

    return None

# Smoke test: render one comparison grid with the diffusion model built above,
# before the training/loading cell below runs.
visualize_ddpm_mnist(diffusion, train_data=train_data)

In [None]:
def loss_function(model, x, y=None):
    """Return the model's forward output for batch ``x`` (labels ``y``) under the key 'loss'.

    The batch is fed to the model unchanged (image-space diffusion, no encoder);
    the dict form is what this notebook hands to ``train``.
    """
    return dict(loss=model(x, y=y))

# Either (re)train the image-space DDPM or restore it from the configured
# checkpoint, depending on the toggle widgets above.
if flag_train:
    train(
        diffusion,
        diffusion_opt,
        train_data,
        loss_function,
        test_data=test_data,
        iterations=iterations,            # config values (previously duplicated as literals)
        checkpoint_rate=checkpoint_rate,
        log_rate=log_rate,
        run_name="test_01",
        project_name="MNIST_DDPM",
        # The callback defined in this notebook is `visualize_ddpm_mnist`;
        # `visualize_ddpm_mnist_sample` is not defined here (confirm it is not
        # provided by `modules` if you prefer that one).
        chkpt_callback=visualize_ddpm_mnist,
        # resume from the saved checkpoint unless "Train from Scratch" is on
        model_checkpoint=None if train_from_scratch else model_checkpoint,
        optim_checkpoint=None if train_from_scratch else optim_checkpoint,
    )
else:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    diffusion.load_state_dict(torch.load(model_checkpoint, map_location=device))

# Latent Diffusion - based on VAE