## Denoising Diffusion Probabilistic Models

[34m[1mwandb[0m: Currently logged in as: [33ms204078[0m ([33mdiffusion_[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [29]:
import torch
from data import DiffSet
import pytorch_lightning as pl
from model import DiffusionModel
from torch.utils.data import DataLoader
import imageio
import glob
import wandb
import torchvision
import torchvision.utils

### Set model parameters

In [7]:
# Training hyperparameters
diffusion_steps = 1000
dataset_choice = "CIFAR"
max_epoch = 10
batch_size = 128

# Loading parameters
load_model = False
load_version_num = 1

### Load dataset and train model

In [8]:
# Code for optionally loading model
pass_version = None
last_checkpoint = None

if load_model:
    pass_version = load_version_num
    last_checkpoint = glob.glob(
        f"./lightning_logs/{dataset_choice}/version_{load_version_num}/checkpoints/*.ckpt"
    )[-1]

In [12]:
# Create datasets and data loaders
train_dataset = DiffSet(True, dataset_choice)
val_dataset = DiffSet(False, dataset_choice)

train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4, shuffle=True)

# Create model and trainer
if load_model:
    model = DiffusionModel.load_from_checkpoint(last_checkpoint, in_size=train_dataset.size*train_dataset.size, t_range=diffusion_steps, img_depth=train_dataset.depth)
else:
    model = DiffusionModel(train_dataset.size*train_dataset.size, diffusion_steps, train_dataset.depth)

Files already downloaded and verified
Files already downloaded and verified


In [22]:
# Load Trainer model

wandb_logger = pl.loggers.WandbLogger(
    name=dataset_choice,
    version=pass_version,
    project="Diffusion_Project",
    log_model=True,
)

trainer = pl.Trainer(
    max_epochs=max_epoch, 
    log_every_n_steps=10, 
    # gpus=1, 
    # auto_select_gpus=True,
    # resume_from_checkpoint=last_checkpoint, 
    logger=wandb_logger,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [23]:
# Train model
trainer.fit(model, train_loader, val_loader)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33ms204078[0m ([33mdiffusion_[0m). Use [1m`wandb login --relogin`[0m to force relogin



   | Name  | Type       | Params | Mode 
----------------------------------------------
0  | inc   | DoubleConv | 38.8 K | train
1  | down1 | Down       | 295 K  | train
2  | down2 | Down       | 1.2 M  | train
3  | down3 | Down       | 2.4 M  | train
4  | up1   | Up         | 6.2 M  | train
5  | up2   | Up         | 1.5 M  | train
6  | up3   | Up         | 406 K  | train
7  | outc  | OutConv    | 195    | train
8  | sa1   | SAWrapper  | 395 K  | train
9  | sa2   | SAWrapper  | 395 K  | train
10 | sa3   | SAWrapper  | 99.6 K | train
----------------------------------------------
12.9 M    Trainable params
0         Non-trainable params
12.9 M    Total params
51.681    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/fredmac/Documents/DTU-FredMac/pytorch-diffusion/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/Users/fredmac/Documents/DTU-FredMac/pytorch-diffusion/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                           

/Users/fredmac/Documents/DTU-FredMac/pytorch-diffusion/.venv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 9: 100%|██████████| 391/391 [04:54<00:00,  1.33it/s, v_num=55d8]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 391/391 [04:54<00:00,  1.33it/s, v_num=55d8]


### Sample from model

In [40]:
gif_shape = [3, 3]
sample_batch_size = gif_shape[0] * gif_shape[1]
n_hold_final = 10

# Generate samples from denoising process
gen_samples = []
x = torch.randn((sample_batch_size, train_dataset.depth, train_dataset.size, train_dataset.size))
sample_steps = torch.arange(model.t_range-1, 0, -1)
for t in sample_steps:
    x = model.denoise_sample(x, t)
    if t % 50 == 0:
        gen_samples.append(x)
for _ in range(n_hold_final):
    gen_samples.append(x)
gen_samples = torch.stack(gen_samples, dim=0).moveaxis(2, 4).squeeze(-1)
gen_samples_ = (gen_samples.clamp(-1, 1) + 1) / 2

In [50]:
# Process samples and save as gif
gen_samples = (gen_samples * 255).type(torch.uint8)
gen_samples = gen_samples.reshape(-1, gif_shape[0], gif_shape[1], train_dataset.size, train_dataset.size, train_dataset.depth)

def stack_samples(gen_samples, stack_dim):
    gen_samples = list(torch.split(gen_samples, 1, dim=1))
    for i in range(len(gen_samples)):
        gen_samples[i] = gen_samples[i].squeeze(1)
    return torch.cat(gen_samples, dim=stack_dim)

gen_samples = stack_samples(gen_samples, 2)
gen_samples = stack_samples(gen_samples, 2)
image_tensor = gen_samples[-1].permute(2, 0, 1)
print(f"Shape: {image_tensor.shape}, Max value: {image_tensor.max()}, Data type: {image_tensor.dtype}")

# Normalize tensor to range [0, 1] and convert to float32 if necessary
if image_tensor.dtype != torch.float32:
    print("Converting tensor to torch.float32 and normalizing")
    image_tensor = image_tensor.float() / 255.0

# Save the image
torchvision.utils.save_image(image_tensor, "imgs_own/pred.png", normalize=False)
# imageio.mimsave(
#     f"imgs_own/pred2.gif",
#     list(gen_samples),
#     fps=5,
# )

Shape: torch.Size([3, 96, 96]), Max value: 255, Data type: torch.uint8
Converting tensor to torch.float32 and normalizing
