# Test diffusion models

## 1. Workspace setup

In [None]:
%load_ext autoreload
%autoreload 2
%load_ext jupyter_black

In [None]:
import yaml
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchmetrics
import pandas as pd
from diffusers import UNet2DModel, DDPMScheduler
from spinediffusion.models.diffusion_models import UnconditionalDiffusionModel
from pathlib import Path

## 2. Load model from checkpoint

In [None]:
# Directory of the training run inspected by this notebook.
log_path = Path(
    r"P:\Projects\LMB_4Dspine\Iship_Pau_Altur_Pastor\4_training_logs\depthmap\version_1"
)

In [None]:
import importlib


def _resolve_class(class_path: str) -> type:
    """Return the class named by a fully qualified ``package.module.ClassName`` path.

    Used instead of ``eval`` so that a config entry can only name an importable
    attribute rather than evaluate an arbitrary expression, and so that the
    module does not need to be imported in this notebook beforehand.

    Raises:
        ValueError: if ``class_path`` has no module part.
    """
    module_name, _, class_name = class_path.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted class path, got {class_path!r}")
    return getattr(importlib.import_module(module_name), class_name)


with open(log_path / "config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Constructor arguments of every component live under config["model"]["init_args"].
model_init_args = config["model"]["init_args"]

model = UNet2DModel(**model_init_args["model"]["init_args"])
scheduler = DDPMScheduler(**model_init_args["scheduler"]["init_args"])
loss = torch.nn.MSELoss(**model_init_args["loss"]["init_args"])

# Each metric entry is a {"class_path": ..., "init_args": {...}} mapping;
# a missing "init_args" means the class is built with its defaults.
metrics = [
    _resolve_class(metric_dict["class_path"])(**metric_dict.get("init_args", {}))
    for metric_dict in model_init_args["metrics"].values()
]

In [None]:
# Restore the trained LightningModule; the components built from config.yaml
# are forwarded to load_from_checkpoint as keyword arguments.
checkpoint_file = log_path / "checkpoints" / "epoch=10-step=2519.ckpt"
components = {"model": model, "scheduler": scheduler, "loss": loss, "metrics": metrics}
lightning_module = UnconditionalDiffusionModel.load_from_checkpoint(
    checkpoint_file, **components
)

## 3. Load TensorBoard logs into a DataFrame

In [None]:
from tensorflow.python.summary.summary_iterator import summary_iterator

In [None]:
from collections import defaultdict

# Hard-coded event file of this training run.
event_file = log_path / "events.out.tfevents.1719496170.Portatil_Pau.95140.0"

# Gather scalars as {wall_time: {tag: value}} and build the DataFrame once at
# the end: growing a DataFrame cell by cell with ``.loc`` re-allocates it for
# every new row/column and becomes quadratic for long logs.
scalars_by_time = defaultdict(dict)
for event in summary_iterator(str(event_file)):
    # An event may carry several values. Entries whose oneof is not
    # ``simple_value`` (e.g. hyper-parameter metadata, images, histograms)
    # would otherwise read as the protobuf default 0.0 and add spurious columns.
    for summary_value in event.summary.value:
        if summary_value.WhichOneof("value") != "simple_value":
            continue
        scalars_by_time[event.wall_time][summary_value.tag] = summary_value.simple_value

# One row per wall time (in order of appearance), one column per tag; tags not
# logged at a given wall time are NaN. Later values for the same (time, tag)
# overwrite earlier ones.
df_tf = pd.DataFrame(
    list(scalars_by_time.values()), index=list(scalars_by_time.keys())
)

In [None]:
# Display the collected scalars: rows are indexed by event wall time, columns are summary tags.
df_tf

## 4. Test inference

In [None]:
# Take the input geometry from the restored UNet itself so the noise tensor
# matches its channel count; ``sample_size`` may be an int or a (height, width) pair.
unet_config = lightning_module.model.config
sample_size = unet_config.sample_size
if isinstance(sample_size, int):
    height = width = sample_size
else:
    height, width = sample_size

x = torch.randn(1, unet_config.in_channels, height, width)

# Inference only: switch off training-time behaviour and autograd bookkeeping,
# as Lightning's own predict loop would.
lightning_module.eval()
with torch.no_grad():
    prediction = lightning_module.predict_step(x)
prediction

In [None]:
from diffusers import DiffusionPipeline


class MyPipeline(DiffusionPipeline):
    """Minimal unconditional sampling pipeline around a UNet and a scheduler.

    Args:
        unet: Denoising network, called as ``unet(image, t).sample``.
        scheduler: Diffusers scheduler providing ``set_timesteps``,
            ``timesteps`` and ``step``.
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 50,
        generator: "torch.Generator | None" = None,
    ) -> torch.Tensor:
        """Sample ``batch_size`` images by iteratively denoising Gaussian noise.

        Args:
            batch_size: Number of images to generate.
            num_inference_steps: Number of denoising steps set on the scheduler.
            generator: Optional CPU ``torch.Generator`` for reproducible sampling.

        Returns:
            The batch produced by the last scheduler step, shaped
            ``(batch_size, in_channels, height, width)``.
        """
        unet_config = self.unet.config
        sample_size = unet_config.sample_size
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        shape = (batch_size, unet_config.in_channels, *sample_size)

        # With (sequential) CPU offloading the UNet weights do not live on the
        # device that runs the forward pass, so ``self.unet.device`` is wrong;
        # ``_execution_device`` resolves the actual target device.
        device = self._execution_device

        # Draw the noise on the CPU (like before) so a CPU generator yields the
        # same samples whatever the execution device, then move it over.
        image = torch.randn(shape, generator=generator).to(
            device=device, dtype=self.unet.dtype
        )

        self.scheduler.set_timesteps(num_inference_steps)

        # Only forward the generator when given, so schedulers whose ``step``
        # takes no ``generator`` keep working on the default path.
        step_kwargs = {} if generator is None else {"generator": generator}

        for t in self.progress_bar(self.scheduler.timesteps):
            model_output = self.unet(image, t).sample
            image = self.scheduler.step(model_output, t, image, **step_kwargs).prev_sample

        return image


# Wrap the restored UNet and its scheduler in the sampling pipeline.
pipeline = MyPipeline(
    lightning_module.model,
    lightning_module.scheduler,
)

In [None]:
# Sequential CPU offloading streams weights onto an accelerator (CUDA by
# default) layer by layer, so only enable it when CUDA is actually available.
if torch.cuda.is_available():
    pipeline.enable_sequential_cpu_offload()
else:
    print("CUDA not available: running the pipeline without CPU offloading.")

In [None]:
# Draw one sample with the default 50 denoising steps; the returned tensor is
# shown as the cell output.
pipeline(batch_size=1, num_inference_steps=50)