# Test diffusion models

## 1. Workspace setup

In [None]:
# Reload imported modules (e.g. spinediffusion) automatically before each cell
# runs, so edits to the package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Enable the jupyter_black extension (black formatting of notebook cells).
%load_ext jupyter_black

In [None]:
import yaml
import glob
import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torchmetrics
import pandas as pd
from diffusers import UNet2DModel, DDPMScheduler
from spinediffusion.models.diffusion_models import UnconditionalDiffusionModel
from pathlib import Path
from tensorflow.python.summary.summary_iterator import summary_iterator

## 2. Load data

### 2.1. CSV logs

This data was converted to CSV at the end of training by a PyTorch callback and saved to disk. For details, see the `GenerateCSVLog` callback in `callbacks.py`.

In [None]:
# Root folder of the training logs; each run lives in its own `version_<i>`
# sub-folder containing an `events.csv` export. A raw string keeps the Windows
# backslashes literal instead of relying on unrecognised escapes such as "\e",
# which Python flags with a SyntaxWarning.
log_root = Path(
    r"P:\Projects\LMB_4Dspine\Iship_Pau_Altur_Pastor\4_training_logs\logs\depthmap"
)
log_paths = [log_root / f"version_{i}" / "events.csv" for i in range(6, 11)]

# Read each run's log, tag its rows with the run name (the folder name) and
# concatenate once at the end. Repeatedly concatenating onto an initially empty
# DataFrame is quadratic and relies on deprecated pandas dtype inference for
# empty frames (which would otherwise turn numeric columns into object dtype).
frames = []
for path in log_paths:
    run_name = path.parent.stem

    df_run = pd.read_csv(path)
    df_run["run_name"] = run_name
    frames.append(df_run)

df_tf = pd.concat(frames)

# Keep the familiar column layout: identifying columns first, any extras after.
leading_columns = ["run_name", "time", "tag", "value"]
df_tf = df_tf[
    leading_columns + [col for col in df_tf.columns if col not in leading_columns]
]

# Order each (run, tag) series chronologically and number its entries.
df_tf = df_tf.sort_values(by=["run_name", "tag", "time"])
df_tf["step"] = df_tf.groupby(["run_name", "tag"]).cumcount()

df_tf

In [None]:
# Show the log table ordered by run, then tag, then time (ascending).
df_tf.sort_values(by=["run_name", "tag", "time"], ascending=True)

### 2.2. Load config files

In [None]:
# Load the training configuration stored next to each run's CSV log,
# keyed by run name (the run's folder name).
configs = {}

for path in log_paths:
    run_name = path.parent.stem

    # Decode explicitly as UTF-8: the platform default (e.g. cp1252 on Windows)
    # can mis-decode non-ASCII characters in the YAML file.
    with open(path.parent / "config.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    configs[run_name] = config

### 2.3. Combine logs and configs

In [None]:
# Attach per-run metadata taken from the configs to every logged row:
#   lr    - initial learning rate of the optimizer
#   epoch - approximate (fractional) epoch at which the value was logged
lr_by_run = {run: cfg["optimizer"]["init_args"]["lr"] for run, cfg in configs.items()}
max_epochs_by_run = {run: cfg["trainer"]["max_epochs"] for run, cfg in configs.items()}

df_tf["lr"] = df_tf["run_name"].map(lr_by_run)

# Tags are logged at different rates (per step vs. per epoch), so each (run, tag)
# series is normalised by its own number of entries rather than by a run-wide
# maximum. True division is required: floor division would map every entry
# except the last to epoch 0.
n_logged = df_tf.groupby(["run_name", "tag"])["step"].transform("count")
df_tf["epoch"] = df_tf["step"] * df_tf["run_name"].map(max_epochs_by_run) / n_logged

df_tf

## 3. Plot training curves

### 3.1. Per run

In [None]:
# Tags from the CSV logs to plot; one figure per (run, tag).
keys = [
    "train_loss_step",
    "MSELoss",
    "PSNR",
    "SSIM",
    "val_loss_step",
    "val_loss_epoch",
    "train_loss_epoch",
]

for run in df_tf.run_name.unique():
    df_run = df_tf[df_tf.run_name == run]
    lr = df_run.lr.unique()[0]

    for key in keys:
        df_run_key = df_run[df_run.tag == key]
        if df_run_key.empty:
            # This run did not log `key`; skip instead of showing an empty figure.
            continue

        # Plot against the logged "time" column so the x-axis matches its label;
        # plotting the bare Series would use the concatenated DataFrame index.
        plt.plot(df_run_key.time, df_run_key.value)
        plt.title(f"{run} - {lr}")
        plt.xlabel("time")
        plt.ylabel(key)
        plt.grid()
        plt.show()

### 3.2. Compare runs

In [None]:
# One panel per tag, one line per run, coloured by learning rate.
# `col` faceting is only available in the figure-level `relplot`; the axes-level
# `lineplot` forwards unknown keyword arguments to matplotlib, which rejects `col`.
# Metrics live on different scales, so the axes are not shared between panels.
sns.relplot(
    data=df_tf,
    x="step",
    y="value",
    hue="lr",
    col="tag",
    col_wrap=3,
    kind="line",
    units="run_name",
    estimator=None,
    facet_kws={"sharey": False, "sharex": False},
)

In [None]:
# For each tag, overlay all runs in a single figure.
for key in keys:
    for run in df_tf.run_name.unique():
        df_run_key = df_tf[(df_tf.run_name == run) & (df_tf.tag == key)]
        lr = configs[run]["optimizer"]["init_args"]["lr"]
        # Use the per-tag `step` column as x so the axis matches its "step" label
        # (the bare Series would be plotted against the DataFrame index). The run
        # name is in the label so runs sharing a learning rate stay distinguishable.
        plt.plot(df_run_key.step, df_run_key.value, label=f"{run} - lr : {lr}")

    plt.title(f"{key} vs step")
    plt.xlabel("step")
    plt.ylabel(key)
    plt.legend()
    plt.grid()
    plt.show()

### 3.3. Compare train and validation curves

In [None]:
# One figure per run, overlaying the per-step training and validation losses.
loss_curves = (
    ("train_loss_step", "train_loss"),
    ("val_loss_step", "val_loss"),
)

for run in df_tf.run_name.unique():
    df_run = df_tf[df_tf.run_name == run]

    for tag, curve_label in loss_curves:
        tagged_rows = df_run[df_run.tag == tag]
        plt.plot(tagged_rows.step, tagged_rows.value, label=curve_label)

    run_lr = df_run.lr.unique()[0]
    plt.title(f"{run} - lr : {run_lr}")
    plt.xlabel("step")
    plt.ylabel("loss")
    plt.legend()
    plt.grid()
    plt.show()

In [None]:
# Overlay training (solid) and validation (dashed) losses of selected runs on a
# single log-scaled axis, with one colour per run.
sns.set_theme(context="talk")
colors = ["r", "g", "b"]
runs = ["version_7", "version_8", "version_9"]

fig, ax = plt.subplots(figsize=(20, 8))

for run, color in zip(runs, colors):
    df_run = df_tf[df_tf.run_name == run]
    lr = df_run.lr.unique()[0]

    for tag, linestyle, split in (
        ("train_loss_step", "-", "train"),
        ("val_loss_step", "--", "val"),
    ):
        df_tag = df_run[df_run.tag == tag]
        # The split name goes in the label so the legend tells train and
        # validation curves apart (they previously shared identical labels).
        ax.plot(
            df_tag.step,
            df_tag.value,
            color + linestyle,
            label=f"{run} - lr : {lr} ({split})",
        )

ax.set_xlabel("step")
ax.set_ylabel("loss")
ax.set_yscale("log")
ax.legend()
ax.grid()

plt.show()

## 4. Test inference

Load model from checkpoint

In [None]:
import importlib

# Run whose checkpoint is loaded for inference.
run_name = "version_7"

config = configs[run_name]


def _resolve_class(class_path):
    """Import and return the object named by a dotted path, e.g. ``torch.nn.MSELoss``.

    Used instead of ``eval`` so that config strings are never executed as code
    and the target module does not have to be imported in the notebook first.
    """
    module_name, _, attr_name = class_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr_name)


# Exact folder-name match (a substring test would confuse e.g. version_1/version_10).
run_dir = next(path.parent for path in log_paths if path.parent.stem == run_name)
ckpt_candidates = list((run_dir / "checkpoints").glob("*.ckpt"))
if not ckpt_candidates:
    raise FileNotFoundError(f"No .ckpt files found in {run_dir / 'checkpoints'}")
# glob order is unspecified; pick the most recently written checkpoint so the
# choice is deterministic.
ckpt_path = str(max(ckpt_candidates, key=lambda p: p.stat().st_mtime))

model_args = config["model"]["init_args"]
model = UNet2DModel(**model_args["model"]["init_args"])
scheduler = DDPMScheduler(**model_args["scheduler"]["init_args"])
loss = _resolve_class(model_args["loss"]["class_path"])(
    **model_args["loss"]["init_args"]
)
metrics = [
    _resolve_class(metric_dict["class_path"])(**metric_dict["init_args"])
    for metric_dict in model_args["metrics"].values()
]

lightning_model = UnconditionalDiffusionModel.load_from_checkpoint(
    ckpt_path, model=model, scheduler=scheduler, loss=loss, metrics=metrics
)

Make some predictions

In [None]:
# Draw Gaussian noise and hand it to the model's predict_step to generate samples.
batch_size = 16
n_channels = 1

# Image size taken from the data module's `project_to_plane` transform settings.
projection = config["data"]["init_args"]["transform_args"]["project_to_plane"]
height = projection["height"]
width = projection["width"]

# NOTE(review): the leading singleton dimension and the batch_idx value 38474
# are passed straight through to predict_step; their meaning depends on
# UnconditionalDiffusionModel.predict_step - confirm there.
noise_shape = (1, batch_size, n_channels, height, width)
input_noise = torch.randn(noise_shape)
generated_images = lightning_model.predict_step(input_noise, 38474)

In [None]:
# Lay the generated samples out on a grid with `n_cols` images per row.
n_cols = 4
n_rows = np.ceil(generated_images.shape[0] / n_cols).astype(int)

fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 20))

for i, ax in enumerate(axs.flat):
    # Only axes that have a corresponding sample receive an image.
    if i < generated_images.shape[0]:
        # Shows channel 0 of sample i in grayscale; assumes generated_images is
        # (N, C, H, W) - TODO confirm against predict_step's output.
        ax.imshow(generated_images[i, 0].cpu().numpy(), cmap="gray")