# Test diffusion models

In [None]:
# Reload edited project modules (e.g. spinediffusion) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Auto-format code cells with black.
%load_ext jupyter_black

In [None]:
import yaml
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from diffusers import UNet2DModel, DDPMScheduler
from spinediffusion.datamodule.datamodule import SpineDataModule

In [None]:
# Load the project config and shrink the data setup for a quick interactive test.
# An explicit encoding makes reading the YAML independent of the OS locale.
with open("../configs/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Note: data_config is the same dict object as config["data"]["init_args"],
# so the overrides below also mutate `config`.
data_config = config["data"]["init_args"]
data_config["n_subjects"] = 5  # only a handful of subjects to keep the notebook fast
data_config["use_cache"] = False  # presumably bypasses a cached dataset — confirm in SpineDataModule

# Square projection: a single value keeps height and width in sync.
resolution = 128
projection_args = data_config["transform_args"]["project_to_plane"]
projection_args["height"] = resolution
projection_args["width"] = resolution


data_module = SpineDataModule(**data_config)

In [None]:
# Run the data module's setup with no specific stage; the next cell reads
# `data_module.train_data`, which presumably is populated here — confirm.
data_module.setup(stage=None)

In [None]:
# Take one training depth map and prepend a batch dimension so it can be fed
# to the UNet (later cells index it as img[0, 0], i.e. batch, channel, H, W).
depth_map = data_module.train_data[0]["depth_map"]
img = depth_map.unsqueeze(0).type(torch.float32)
img_size = img.shape[-1]  # spatial resolution passed as the UNet's sample_size

# Output channel width of each of the six UNet stages, in encoder order.
block_channels = (128, 128, 256, 256, 512, 512)

# Encoder: plain ResNet downsampling blocks, with spatial self-attention
# only in the fifth stage.
down_blocks = (
    "DownBlock2D",
    "DownBlock2D",
    "DownBlock2D",
    "DownBlock2D",
    "AttnDownBlock2D",
    "DownBlock2D",
)

# Decoder mirrors the encoder, so the attention block is the second up block.
up_blocks = (
    "UpBlock2D",
    "AttnUpBlock2D",
    "UpBlock2D",
    "UpBlock2D",
    "UpBlock2D",
    "UpBlock2D",
)

model = UNet2DModel(
    sample_size=img_size,
    in_channels=1,  # single-channel depth map (RGB would be 3)
    out_channels=1,  # predicted noise has the same channel count as the input
    layers_per_block=2,  # ResNet layers per UNet block
    block_out_channels=block_channels,
    down_block_types=down_blocks,
    up_block_types=up_blocks,
)

# DDPM noise scheduler with 1000 training timesteps.
scheduler = DDPMScheduler(num_train_timesteps=1000)

In [None]:
# Corrupt the sample at a single diffusion timestep and visualise the effect.
# randn_like follows img's dtype/device (same draws as randn(img.shape) on CPU).
noise = torch.randn_like(img)
timesteps = torch.tensor([50], dtype=torch.long)
noisy_image = scheduler.add_noise(img, noise, timesteps)

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
panels = (
    ("original", img[0, 0]),
    (f"noisy (t={timesteps.item()})", noisy_image[0, 0]),
    # Not the pure noise: DDPM add_noise also rescales the clean image.
    ("noisy - original", noisy_image[0, 0] - img[0, 0]),
)
for ax, (title, image) in zip(axs, panels):
    ax.imshow(image, cmap="gray")
    ax.set_title(title)

In [None]:
# Single forward pass: ask the UNet to predict the noise added at `timesteps`.
unet_output = model(noisy_image, timesteps)
noise_pred = unet_output.sample

# DDPM training objective: MSE between predicted and true noise.
loss = F.mse_loss(noise_pred, noise)

In [None]:
# Compare true noise, predicted noise and their difference for the first sample.
# Convert each tensor to NumPy once; detach() drops the autograd graph and
# cpu() keeps the conversion working if the model was run on a GPU.
true_noise = noise[0, 0].cpu().numpy()
pred_noise = noise_pred[0, 0].detach().cpu().numpy()

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for ax, title, image in zip(
    axs,
    ("true noise", "predicted noise", "predicted - true"),
    (true_noise, pred_noise, pred_noise - true_noise),
):
    ax.imshow(image, cmap="gray")
    ax.set_title(title)

In [None]:
import yaml

# Re-read the config from disk and show the `model` section used below.
# Note: this replaces `config`, discarding the in-memory data overrides made
# earlier (the already-built data_module is unaffected).
with open("../configs/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print(config["model"])

In [None]:
from spinediffusion.models.diffusion_models import DepthMapDiffusionModel
from diffusers import UNet2DModel, DDPMScheduler, get_cosine_schedule_with_warmup
import torch

# Build the UNet and noise scheduler from the `model` section of the config.
# This overwrites the hand-built `model` and `scheduler` from the cells above.
model_init_args = config["model"]["init_args"]
model = UNet2DModel(**model_init_args["model"]["init_args"])
scheduler = DDPMScheduler(**model_init_args["scheduler"]["init_args"])

# Optimizer, LR schedule and loss are kept as classes/factories (not instances);
# presumably DepthMapDiffusionModel instantiates them — confirm.
# `loss` here replaces the loss tensor computed earlier.
optimizer = torch.optim.AdamW
lr_scheduler = get_cosine_schedule_with_warmup
loss = torch.nn.MSELoss


In [None]:
# Wrap the UNet, scheduler and training components in the diffusion model
# wrapper. All arguments are positional; the meaning of 0.001 and 100 is not
# visible from this notebook (presumably a learning rate and a step/epoch
# count) — TODO confirm against DepthMapDiffusionModel and switch to keywords.
lightning_model = DepthMapDiffusionModel(
    model,
    scheduler,
    optimizer,
    lr_scheduler,
    0.001,
    loss,
    100,
)

In [None]:
# Show the wrapped model's repr as the cell output for a quick inspection.
lightning_model

In [None]:
# Print the training batches. There is no standalone `train_dataloader` in this
# notebook, so obtain the loader from the data module (assumes the standard
# LightningDataModule `train_dataloader()` method — confirm in SpineDataModule).
for batch in data_module.train_dataloader():
    print(batch)