In [None]:
import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm.auto import tqdm
import nrrd
import SimpleITK as sitk
import plotly.express as px
import plotly.subplots as sp
from diffusers import DDPMScheduler, UNet2DModel
from PIL import Image
import json
import cv2
import os

# Use the GPU when one is available, otherwise fall back to the CPU
# (same policy as the sampling cell further down).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DiffusionModel(pl.LightningModule):
    """Lightning wrapper around a diffusers ``UNet2DModel`` trained to predict DDPM noise.

    Args:
        lr: Learning rate handed to Adam in ``configure_optimizers``.
        num_train_timesteps: Number of diffusion steps of the training
            ``DDPMScheduler``.
    """

    def __init__(self, lr=1e-5, num_train_timesteps=1000):
        super().__init__()
        # Stores lr / num_train_timesteps under self.hparams
        self.save_hyperparameters()
        self.net = UNet2DModel(
            in_channels=1,  # single-channel input images
            out_channels=1,  # single-channel output (the predicted noise)
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=(256, 256, 512, 512, 1024, 1024),  # output channels of each UNet block
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",  # a regular ResNet upsampling block
                "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )
        self.scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps, beta_schedule="linear")
        self.loss_fn = nn.MSELoss()

    def forward(self, x, timesteps):
        """Return the network output (``.sample``) for inputs ``x`` at ``timesteps``."""
        return self.net(x, timesteps).sample

    def _noise_prediction_loss(self, batch):
        """Noise a batch at random timesteps and return the MSE between predicted and true noise."""
        # Loaders that yield (images, labels, ...) are reduced to the images alone.
        x = batch[0] if isinstance(batch, (list, tuple)) else batch

        # One independent timestep per sample in the batch
        timesteps = torch.randint(
            0, self.scheduler.config.num_train_timesteps, (x.size(0),), device=self.device
        )
        noise = torch.randn_like(x)
        noisy_images = self.scheduler.add_noise(x, noise, timesteps)

        pred = self(noisy_images, timesteps)
        return self.loss_fn(pred, noise)

    def training_step(self, batch, batch_idx):
        """Compute and log the noise-prediction loss for one training batch."""
        loss = self._noise_prediction_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the same noise-prediction loss on a validation batch.

        Uses the identical objective as ``training_step`` so that ``val_loss``
        and ``train_loss`` are directly comparable.
        """
        loss = self._noise_prediction_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the UNet parameters."""
        return optim.Adam(self.net.parameters(), lr=self.hparams.lr)

checkpoint_path = "/mnt/raid/home/ajarry/data/outputs_lightning/final53epoch/model.pth"
model = DiffusionModel()

# Deserialize onto the CPU first so loading does not depend on the device the
# checkpoint was saved from; the model is moved to `device` afterwards.
state_dict = torch.load(checkpoint_path, map_location="cpu")

# strict=False tolerates key mismatches, so surface them explicitly: a silent
# mismatch would leave (part of) the network at its random initialization.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Checkpoint/model key mismatch:")
    print("  missing keys:", load_result.missing_keys)
    print("  unexpected keys:", load_result.unexpected_keys)

model = model.to(device)
# The model is only used for sampling below, so put it in inference mode.
model.eval()


In [None]:
# Define the scheduler (make sure it matches the one used in training)
scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")

# Set up sampling parameters
image_size = (1, 1, 256, 256)  # (batch, channels, height, width)
num_inference_steps = 200  # can be reduced for faster generation
scheduler.set_timesteps(num_inference_steps)

# Sample on the device the model actually lives on, so inputs and weights
# can never end up on different devices.
device = model.device

# Start with pure noise
noisy_image = torch.randn(image_size).to(device)

# Timesteps selected by set_timesteps above, moved next to the data
timesteps = scheduler.timesteps.to(device)

# Reverse process (denoising)
with torch.no_grad():
    # Iterating the tensor directly lets tqdm know the total number of steps
    for t in tqdm(timesteps, desc="sampling"):
        # Predict noise (the model is called with a 1-element timestep batch)
        noise_pred = model(noisy_image, t.unsqueeze(0))

        # Remove noise using scheduler
        noisy_image = scheduler.step(noise_pred, t, noisy_image).prev_sample

# Convert to image format: (1, 1, H, W) -> (1, H, W) -> (H, W, 1)
generated_image = noisy_image.cpu().squeeze(0).permute(1, 2, 0)

print(generated_image.shape)

plt.imshow(generated_image, cmap="gray")
plt.axis("off")
plt.show()

In [None]:
# Read the ultrasound sweep into a NumPy array and preview it frame by frame
data = sitk.GetArrayFromImage(
    sitk.ReadImage('/mnt/raid/home/ajarry/data/cephalic_sweeps/frame_0001/C3_us.nrrd')
)
print(data.shape)
fig = px.imshow(data, animation_frame=0, binary_string=True)
fig.show()

# One guidance target per frame along the first axis of `data`: each frame is
# cast to uint8, converted with ToTensor and moved to `device`.
# NOTE(review): the values are cast to uint8 without rescaling — confirm the
# source volume is already in the 0-255 range.
totensor = transforms.ToTensor()
targets = [totensor(np.uint8(frame)).to(device) for frame in data]

In [None]:
def guidance_loss(image, target):
    """Return the mean absolute difference between ``image`` and ``target``."""
    difference = image - target
    return difference.abs().mean()

# The guidance scale determines the strength of the effect: the loss is
# multiplied by it before the gradient is taken, so larger values move x
# further along the guidance gradient at every step.
guidance_loss_scale = 175

# One starting noise sample, reused as the initial state for every target frame
noise = torch.randn(1, 1, 256, 256).to(device)
stack = []
for target in tqdm(targets, desc="frames"):
    x = noise
    # Uses the timesteps configured by the earlier scheduler.set_timesteps call;
    # iterating the tensor directly lets tqdm show the total number of steps.
    for t in tqdm(scheduler.timesteps, desc="denoising", leave=False):

        # Predict the noise residual (no gradient through the network itself)
        with torch.no_grad():
            noise_pred = model(x, t)

        # Make x a fresh leaf so the guidance loss can be differentiated w.r.t. it
        x = x.detach().requires_grad_()

        # Get the predicted x0 for the current x
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample

        # Scaled distance between the predicted clean image and the target frame
        loss = guidance_loss(x0, target) * guidance_loss_scale

        # Negative gradient: the direction in x that decreases the loss
        cond_grad = -torch.autograd.grad(loss, x)[0]

        # Modify x based on this gradient
        x = x.detach() + cond_grad

        # Now take the regular denoising step from the nudged x
        x = scheduler.step(noise_pred, t, x).prev_sample

    # (1, 1, 256, 256) -> (1, 256, 256) -> (256, 256, 1), swapping the two spatial axes
    stack.append(x.squeeze(0).cpu().permute(2, 1, 0))

# print(x.shape)
# img = x.squeeze(0).cpu().permute(1,2,0)

# plt.imshow(img, cmap="gray")
# plt.axis("off")

In [None]:
# Join the per-frame results along a new third axis, then drop singleton dimensions
stacked_frames = np.stack(stack, axis=2)
volume = np.squeeze(stacked_frames)
print(volume.shape)

# Write the assembled volume to disk as NRRD
out = '/mnt/raid/home/ajarry/data/image_capture_output/gpu4diffused.nrrd'
nrrd.write(out, volume)

In [None]:
# Read the label volume into a NumPy array and preview it frame by frame
# (note: this rebinds `data`, which previously held the ultrasound sweep)
data = sitk.GetArrayFromImage(
    sitk.ReadImage('/mnt/raid/home/ajarry/data/cephalic_sweeps/frame_0001/C1_label.nrrd')
)
print(data.shape)
fig = px.imshow(data, animation_frame=0, binary_string=True)
fig.show()

In [None]:
import glob

# Every .nrrd file inside a frame_* directory whose file name contains "us"
# (case-insensitive). glob returns matches in arbitrary order, so the result is
# sorted to keep the listing reproducible between runs and machines.
nrrd_paths = sorted(
    path
    for path in glob.glob("/mnt/raid/home/ajarry/data/cephalic_sweeps/frame_*/*")
    if path.endswith(".nrrd") and "us" in os.path.basename(path).lower()
)

print(nrrd_paths)