# Test diffusion models

## 1. Workspace setup

Below we are activating two extensions:
- [Autoreload](https://ipython.org/ipython-doc/3/config/extensions/autoreload.html) to automatically reload modules when they change. This is very useful when you are editing code in Python files and want to test it in the notebook.
- [Jupyter-black](https://github.com/drillan/jupyter-black) to format the code cells with the black formatter.

In [None]:
# Enable the autoreload extension; mode 2 reloads all modules before each cell runs,
# so edits to the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Format code cells with the black formatter.
%load_ext jupyter_black

In this cell we import all of the modules needed to visualize the results of the trained diffusion models.

In [None]:
import yaml
import glob
import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torchmetrics
import pandas as pd
import diffusers
from diffusers import UNet2DModel, DDPMScheduler
from spinediffusion.models.diffusion_models import UnconditionalDiffusionModel
from spinediffusion.datamodule.datamodule import SpineDataModule
from spinediffusion.utils.misc import find_test_param
from pathlib import Path
from pytorch_lightning import Trainer

## 2. Define analysis

The cells below control the analysis performed in this notebook. They define the paths from which the model checkpoints and event logs are loaded.

### 2.1. Log paths

The paths to the logs and checkpoints of the models are defined below. They are defined through three variables:
- `versions`: The version number of each training run.
- `logdir`: The path to the parent directory where the logs are stored.
- `logsubdir`: The subdirectory where the versions are stored.

In [None]:
# Training-run version numbers whose logs and checkpoints are analysed.
versions = [20]
# Root folder of the training logs, and the sub-folder that holds the
# "version_<n>" directories.
logdir = Path("P:/Projects/LMB_4Dspine/Iship_Pau_Altur_Pastor/4_training_logs/logs")
logsubdir = Path("depthmap")

# One <logdir>/<logsubdir>/version_<n> directory per requested version.
log_paths = [Path(logdir, logsubdir, f"version_{version}") for version in versions]

### 2.2. Auxiliary variables

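In this cell you can define a number of variables that control which analyses will be performed:
- `test_param`: The name of the parameter that was swept across runs, or `None` if no sweep was performed. It is used to label the runs and to compare the metrics for each parameter value.
- `axis_scale`: The y-axis scale of the training-curve plots. Set it to `"log"` for a logarithmic axis; any other value keeps the default linear scale.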

In [None]:
# test_param: name of the hyper-parameter swept across runs, used to label and
#             compare the runs (None when no parameter sweep was performed).
# axis_scale: y-axis scale of the training-curve plots; "log" gives a
#             logarithmic axis, any other value keeps matplotlib's default.
test_param, axis_scale = None, "log"

## 3. Load data

### 3.1. CSV logs

These logs were converted to CSV format at the end of training by a PyTorch callback and saved to disk. For more information, refer to the source code of the `GenerateCSVLog` callback in the `callbacks.py` file.

In [None]:
# Read every run's CSV event log and stack them into one long-format table.
frames = []
for path in log_paths:
    df_run = pd.read_csv(path / "events.csv")
    # Run name = name of the version directory (e.g. "version_20"); keep it as
    # the first column.
    df_run.insert(0, "run_name", path.stem)
    frames.append(df_run)

# Concatenate once instead of growing a DataFrame inside the loop. This avoids
# quadratic copying, and it avoids concatenating onto an empty object-dtype
# frame, which pandas warns about and which can upcast "value" to object.
if frames:
    df_tf = pd.concat(frames, ignore_index=True)
else:
    df_tf = pd.DataFrame(columns=["run_name", "time", "tag", "value"])

df_tf = df_tf.sort_values(by=["run_name", "tag", "time"])
# "step" = position of each row within its own (run, tag) series, in time order.
df_tf["step"] = df_tf.groupby(["run_name", "tag"]).cumcount()

df_tf.reset_index(drop=True, inplace=True)

df_tf

### 3.2. Load config files

The configuration files are loaded to extract the parameters used to train the models. They contain all of the information needed to reproduce the training of the models and to know exactly which parameters were used in each specific run.

In [None]:
# Map each run name (its version directory, e.g. "version_20") to the parsed
# contents of the config.yaml stored alongside that run's logs.
configs = {}
for log_path in log_paths:
    config_file = log_path / "config.yaml"
    with config_file.open("r") as stream:
        configs[log_path.stem] = yaml.safe_load(stream)

### 3.3. Combine them

Specific data from the config files is combined with the CSV logs to create a single dataframe that contains all of the information needed to analyze the training of the models. This includes the epoch number and the `test_param` value.

In [None]:
# Add per-run columns derived from the configs: the swept parameter value and
# an approximate epoch position for every logged value.
for run, config in configs.items():
    run_mask = df_tf["run_name"] == run

    df_tf.loc[run_mask, test_param] = find_test_param(config, test_param)

    max_epochs = config["trainer"]["max_epochs"]
    steps = df_tf.loc[run_mask, "step"]
    # "step" is a per-(run, tag) counter and tags are logged at different rates
    # (per step vs. per epoch). Normalise each tag by its own last step so every
    # series spans [0, max_epochs]. A run-wide maximum would squash sparsely
    # logged tags into the start of the axis. clip(lower=1) avoids 0/0 for
    # single-entry tags.
    max_steps = (
        df_tf.loc[run_mask].groupby("tag")["step"].transform("max").clip(lower=1)
    )
    df_tf.loc[run_mask, "epoch"] = (steps * max_epochs) // max_steps
    df_tf.loc[run_mask, "epoch_fraction"] = (steps * max_epochs) / max_steps

df_tf

## 4. Plot training curves

### 4.1. Per run

In [None]:
# Tags (metric names in the CSV logs) to plot.
keys = [
    "train_loss_step",
    "MSELoss_step",
    "PSNR_step",
    "SSIM_step",
    "val_loss_step",
    "val_loss_epoch",
    "train_loss_epoch",
]

# One figure per (run, tag): the tag's values against the fractional epoch.
for run in df_tf.run_name.unique():
    df_run = df_tf[df_tf.run_name == run]
    param_val = df_run[test_param].unique()[0]
    figure_title = f"{run} - {test_param} : {param_val}"

    for key in keys:
        # "train_loss_step" -> "train loss", "PSNR_step" -> "PSNR", ...
        pretty_key = key.replace("_step", "").replace("_epoch", "").replace("_", " ")
        curve = df_run.loc[df_run.tag == key]

        plt.plot(curve["epoch_fraction"], curve["value"])
        plt.title(figure_title)
        plt.xlabel("Epoch")
        plt.ylabel(pretty_key)
        if axis_scale == "log":
            plt.yscale("log")
        plt.grid()
        plt.show()

### 4.2. Compare runs

In [None]:
# One figure per tag, overlaying every run; each curve is labelled with the
# run's value of the swept parameter.
for key in keys:
    pretty_key = key.replace("_step", "").replace("_epoch", "").replace("_", " ")
    tag_rows = df_tf[df_tf.tag == key]

    for run in df_tf.run_name.unique():
        run_rows = tag_rows[tag_rows.run_name == run]
        plt.plot(
            run_rows.epoch_fraction,
            run_rows.value,
            label=find_test_param(configs[run], test_param),
        )

    plt.title(f"{pretty_key} vs epoch")
    plt.xlabel("Epoch")
    plt.ylabel(pretty_key)
    if axis_scale == "log":
        plt.yscale("log")
    plt.legend()
    plt.grid()
    plt.show()

### 4.3. Compare train and validation curves

In [None]:
# Per run, overlay the epoch-level training and validation loss curves.
loss_curves = (("train_loss_epoch", "train loss"), ("val_loss_epoch", "val loss"))

for run in df_tf.run_name.unique():
    df_run = df_tf[df_tf.run_name == run]

    for tag, curve_label in loss_curves:
        curve = df_run[df_run.tag == tag]
        plt.plot(curve.epoch_fraction, curve.value, label=curve_label)

    plt.title(f"{run} - {test_param} : {df_run[test_param].unique()[0]}")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    if axis_scale == "log":
        plt.yscale("log")
    plt.legend()
    plt.grid()
    plt.show()

## 5. Test inference

### 5.1. Load models from checkpoint

In [None]:
def _instantiate(spec):
    """Instantiate an object from a jsonargparse-style config entry.

    ``spec`` is either a dict with a ``"class_path"`` and optional
    ``"init_args"``, or a bare class-path string (instantiated with no
    arguments). Class paths are resolved against the notebook's global
    namespace (e.g. ``torchmetrics....`` or ``DDPMScheduler``).

    NOTE(review): ``eval`` executes arbitrary expressions taken from the config
    file; only use this with config files you trust.
    """
    if isinstance(spec, dict):
        init_args = spec.get("init_args") or {}
        return eval(spec["class_path"])(**init_args)
    return eval(spec)()


# Rebuild every run's Lightning module from its config and load the weights
# from a checkpoint of that run.
lightning_models = {}

for run in df_tf.run_name.unique():
    df_run = df_tf[df_tf.run_name == run]
    best_val_loss = df_run.loc[df_run.tag == "val_loss_epoch", "value"].min()
    print(f"{run} - best val loss: {best_val_loss}")

    # Sort so the chosen checkpoint is reproducible, since glob order depends
    # on the filesystem. NOTE(review): the first checkpoint by name is not
    # necessarily the one with the best val loss printed above.
    ckpt_dir = logdir / logsubdir / run / "checkpoints"
    ckpt_paths = sorted(glob.glob(str(ckpt_dir / "*.ckpt")))
    if not ckpt_paths:
        raise FileNotFoundError(f"No *.ckpt file found in {ckpt_dir}")
    ckpt_path = ckpt_paths[0]
    if len(ckpt_paths) > 1:
        print(f"{run} - {len(ckpt_paths)} checkpoints found, using {ckpt_path}")

    config = configs[run]
    module_args = config["model"]["init_args"]
    model = UNet2DModel(**module_args["model"]["init_args"])
    scheduler = _instantiate(module_args["scheduler"])
    loss = _instantiate(module_args["loss"])
    metrics = [_instantiate(spec) for spec in module_args["metrics"].values()]

    lightning_model = UnconditionalDiffusionModel.load_from_checkpoint(
        ckpt_path, model=model, scheduler=scheduler, loss=loss, metrics=metrics
    )
    lightning_models[run] = lightning_model

### 5.2. Perform inference

In [None]:
# Build the datamodule from the reference training config and sample images
# from every loaded model.
with open("../configs/config_uncond.yaml", "r") as f:
    data_config = yaml.safe_load(f)

# Number of images to generate (presumably consumed by the predict dataset of
# SpineDataModule -- confirm there).
data_config["data"]["init_args"]["predict_size"] = 16

datamodule = SpineDataModule(**data_config["data"]["init_args"])
datamodule.setup("test")

predict_dataloader = datamodule.predict_dataloader()

trainer = Trainer()
generated_images = {}

for run, lightning_model in lightning_models.items():
    # trainer.predict returns one output per batch. Concatenate them all so
    # every generated sample is kept, not just the first batch.
    batch_outputs = trainer.predict(lightning_model, dataloaders=predict_dataloader)
    generated_images[run] = torch.cat(batch_outputs, dim=0)

In [None]:
# Show each run's generated samples (channel 0) in a grid of n_cols columns.
n_cols = 4
panel_size = 5  # inches per grid cell; 16 images (4 x 4) give the former 20x20 figure

for run, images in generated_images.items():
    n_images = images.shape[0]
    if n_images == 0:
        continue  # nothing to show; plt.subplots(0, n_cols) would raise
    n_rows = int(np.ceil(n_images / n_cols))

    fig, axs = plt.subplots(
        n_rows,
        n_cols,
        figsize=(panel_size * n_cols, panel_size * n_rows),
        squeeze=False,
    )

    for i, ax in enumerate(axs.flat):
        if i < n_images:
            # Clip into a NEW array. .numpy() of a CPU tensor shares memory, so
            # in-place assignment would silently modify generated_images.
            img = np.clip(images[i, 0].detach().cpu().numpy(), 0, None)
            ax.imshow(img, cmap="gray")
        # Hide ticks on image cells and leave unused cells of the last row blank.
        ax.axis("off")

    fig.suptitle(run)

### 5.3. Compute Frechet Inception Distance (FID)

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance


def _to_fid_input(batch):
    """Convert an image batch to the uint8, 3-channel CPU tensor FID expects.

    Assumes ``batch`` is a float tensor of shape (N, 1, H, W) with intensities
    nominally in [0, 1] (TODO confirm against the datamodule). Values are
    clamped to [0, 1] before scaling to [0, 255] so that out-of-range values
    cannot wrap around in the uint8 conversion.
    """
    batch = batch.detach().cpu().clamp(0, 1)
    # Replicate the channel 3 times, since the Inception backbone takes RGB.
    batch = torch.cat([batch] * 3, dim=1)
    return (batch * 255).byte()


# Reference set: first 16 training images, converted ONCE outside the loop.
# Converting inside the loop would re-scale the already-uint8 tensor on every
# run and overflow.
real_images = _to_fid_input(datamodule.train_data[:16][0])

for run, images in generated_images.items():
    fake_images = _to_fid_input(images)

    # NOTE: with only 16 samples per set, FID is a very noisy estimate; use it
    # only for rough comparisons between runs.
    fid = FrechetInceptionDistance()

    fid.update(real_images, real=True)
    fid.update(fake_images, real=False)

    print(f"{run} - FID : {fid.compute()}")