# Test diffusion models

## 1. Workspace setup

In [None]:
# Enable the autoreload extension in mode 2 so edits to imported project
# modules (e.g. spinediffusion) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Load jupyter_black so code cells are formatted with black.
%load_ext jupyter_black

In [None]:
import yaml
import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torchmetrics
import pandas as pd
from diffusers import UNet2DModel, DDPMScheduler
from spinediffusion.models.diffusion_models import UnconditionalDiffusionModel
from pathlib import Path
from tensorflow.python.summary.summary_iterator import summary_iterator

## 2. Load event data from TensorBoard

This cell reads every scalar event from the TensorBoard event files into a pandas DataFrame. Each event has a wall time, a tag and a value. We also record a `run_name` (the name of the run's log directory) to tell the runs apart, so that different runs can be compared in the same plot.

In [None]:
log_paths = [
    Path(
        f"P:\\Projects\\LMB_4Dspine\\Iship_Pau_Altur_Pastor\\4_training_logs\\logs\\depthmap\\version_{i}"
    )
    for i in range(6, 11)
]

# Collect one tuple per scalar event and build the DataFrame once at the end.
# Appending with `df.loc[len(df)] = ...` re-allocates the frame on every row,
# which is quadratic in the number of events.
records = []

for path in log_paths:
    # Sorted so the choice is deterministic when a directory holds several
    # event files (glob order is filesystem-dependent).
    event_files = sorted(path.glob("events.out.tfevents.*"))
    if not event_files:
        raise FileNotFoundError(f"No TensorBoard event file found in {path}")
    # `glob` already yields paths rooted at `path`; use them as-is rather than
    # joining them onto `path` a second time.
    event_file = event_files[0]
    run_name = path.stem

    for e in summary_iterator(str(event_file)):
        # Events without a summary value (e.g. file-version headers) carry no data.
        if len(e.summary.value) == 0:
            continue

        summary_value = e.summary.value[0]
        records.append(
            (
                run_name,
                e.wall_time,
                summary_value.tag,
                summary_value.simple_value,
            )
        )

df_tf = pd.DataFrame(records, columns=["run_name", "time", "tag", "value"])

df_tf

In [None]:
import importlib

# NOTE(review): `log_path` was not defined anywhere in this notebook (only
# `log_paths`), so this cell raised NameError. It is assumed here to be the
# last run in `log_paths` -- confirm this is the run that owns the checkpoint
# loaded in the next cell.
log_path = log_paths[-1]

with open(log_path / "config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# All components are configured under config["model"]["init_args"].
init_args = config["model"]["init_args"]

model = UNet2DModel(**init_args["model"]["init_args"])
scheduler = DDPMScheduler(**init_args["scheduler"]["init_args"])
loss = torch.nn.MSELoss(**init_args["loss"]["init_args"])


def _resolve_class(class_path):
    """Return the class named by a dotted ``class_path`` string.

    Dotted paths (``"package.module.ClassName"``) are imported with importlib.
    A bare name is looked up in the notebook's global namespace. This replaces
    ``eval()``, which would execute arbitrary code found in the config file.
    """
    module_name, _, class_name = class_path.rpartition(".")
    if not module_name:
        return globals()[class_name]
    return getattr(importlib.import_module(module_name), class_name)


# One metric instance per entry under "metrics" in the config.
metrics = [
    _resolve_class(metric_dict["class_path"])(**metric_dict["init_args"])
    for metric_dict in init_args["metrics"].values()
]

In [None]:
# Restore the trained LightningModule from its checkpoint. The components built
# from config.yaml in the previous cell are passed as keyword arguments.
lightning_module = UnconditionalDiffusionModel.load_from_checkpoint(
    log_path.joinpath("checkpoints", "epoch=10-step=2519.ckpt"),
    metrics=metrics,
    loss=loss,
    scheduler=scheduler,
    model=model,
)

## 3. Visualize TensorBoard logs

In [None]:
# Intentionally empty cell.

In [None]:
# Load all scalar events of a single run into a wide DataFrame:
# one row per event wall time, one column per tag.
# NOTE(review): this rebinds `df_tf`, replacing the long-format frame from
# section 2; it also relies on `log_path` from the config cell.
event_path = log_path / "events.out.tfevents.1719496170.Portatil_Pau.95140.0"

# wall_time -> {tag: value}; dict insertion order preserves event order.
rows_by_time = {}
for e in summary_iterator(str(event_path)):
    # Events without a summary value carry no data.
    if len(e.summary.value) == 0:
        continue
    summary_value = e.summary.value[0]
    # A repeated (wall_time, tag) pair keeps the last value seen.
    rows_by_time.setdefault(e.wall_time, {})[
        summary_value.tag
    ] = summary_value.simple_value

# Build the frame once instead of enlarging it with `.loc` for every event,
# which re-allocates on each new row or column. Tags missing at a given wall
# time become NaN.
df_tf = pd.DataFrame(list(rows_by_time.values()), index=list(rows_by_time.keys()))

In [None]:
# Show the current `df_tf`; as the cell's last expression, Jupyter renders it.
df_tf

In [None]:
# Gaps between consecutive event wall times in the index of `df_tf`
# (presumably seconds, if wall_time is a Unix timestamp -- confirm).
# `Index.values` is a property returning an ndarray, so it must not be called;
# numpy is already imported in the setup cell.
np.diff(df_tf.index.values)

## 4. Test inference