In [None]:
import torch
import torch.nn.functional as F
from torchvision import transforms as T
import numpy as np

import nrrd
import pandas as pd
import os
import sys

# sys.path.append("/mnt/raid/C1_ML_Analysis/source/autoencoder/src")
sys.path.append("/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/dl/")
sys.path.append("/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/dl/nets")

from nets import diffusion
from transforms import ultrasound_transforms as ust 

from diffusers import DDIMScheduler

import plotly.express as px
import plotly.graph_objects as go

from nets.us_simu import VolumeSamplingBlindSweep, SweepSampling

import lpips
from torchmetrics.image import StructuralSimilarityIndexMeasure



In [None]:
# Root of the shared storage; every checkpoint path below is resolved relative to it
# so that moving the data only requires changing this one value.
mount_point = '/mnt/raid/C1_ML_Analysis'

# Diffusion model used for both unconditional sampling and guided inference.
model = diffusion.DiffusionModel.load_from_checkpoint(
    os.path.join(mount_point, 'train_output/diffusion/0.1/epoch=76-val_loss=0.01.ckpt'))
model = model.eval().cuda()

# Autoencoder applied to the final diffusion output inside `inference`.
AE = diffusion.AutoEncoderKL.load_from_checkpoint(
    os.path.join(mount_point, 'train_output/diffusionAE/extract_frames_Dataset_C_masked_resampled_256_spc075_wscores_meta_BPD01_MACFL025-7mo-9mo/v0.4/epoch=72-val_loss=0.01.ckpt'))
AE = AE.eval().cuda()

In [None]:
# Show the loaded diffusion model (rich repr of the object as the cell output).
model

In [None]:
# Simulated-data locations, resolved under the same `mount_point` as the checkpoints.
sim_export_dir = os.path.join(mount_point, 'simulated_data_export/animation_export')

# Keyword arguments for SweepSampling: input files plus the probe/grid geometry.
args = {
    'probe_paths': os.path.join(sim_export_dir, 'all_poses/frame_0001/probe_paths'),
    'diffusor_fn': os.path.join(sim_export_dir, 'all_poses/frame_0001.nrrd'),
    'params_csv': os.path.join(sim_export_dir, 'shapes_intensity_map_nrrd_speckel.csv'),
    'grid_w': 256,
    'grid_h': 256,
    'center_x': 128.0,
    'center_y': -40.0,
    'r1': 20.0,
    'r2': 255.0,
    'theta': np.pi / 4.25,
    'padding': 55,  # Padding for the simulated ultrasound
}
ss = SweepSampling(**args).cuda()

In [None]:
# Draw a sample with `model.sample()` (no conditioning arguments) and show it as a heatmap.
with torch.no_grad():
    x_hat = model.sample()

fig = go.Figure()
# Plotly heatmaps draw row 0 at the bottom by default; flipping the rows puts it on top.
fig.add_trace(go.Heatmap(z=np.flip(x_hat.squeeze().cpu().numpy(), axis=0), opacity=0.8, colorscale='hot'))

fig.update_layout(title='Diffusion model sample', height=800, width=800)
fig.show()

In [None]:
# The guidance metrics are placed on the GPU when one is available, else on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# data_range=1.0 assumes inputs span a value range of 1 (presumably [0, 1] — TODO confirm
# for unclipped diffusion predictions).
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0)
ssim_metric = ssim_metric.to(device)

def ssim_loss(image, target):
    """Turn the module-level ``ssim_metric`` similarity into a loss.

    SSIM is 1 for identical images, so the loss is ``1 - SSIM`` (0 when identical).
    """
    similarity = ssim_metric(image, target)
    return 1.0 - similarity

# LPIPS perceptual distance with the AlexNet backbone, moved to the metric device.
perceptual_metric = lpips.LPIPS(net='alex')
perceptual_metric = perceptual_metric.to(device)

def perceptual_loss(image, target):
    """Batch-mean LPIPS distance between ``image`` and ``target``.

    Each input is brought to the form LPIPS expects: 4-D, three channels and
    values in [-1, 1]. Single-channel inputs are broadcast to three channels.
    Inputs that already have three channels are left alone instead of being
    tiled to nine channels, which LPIPS cannot consume.

    Args:
        image: tensor compared against ``target``. The calling code passes
            ``(B, 1, H, W)``. Values are assumed to be in [0, 1] — TODO confirm
            for unclipped diffusion predictions.
        target: tensor with the same layout as ``image``.

    Returns:
        Scalar tensor: the mean of the per-sample LPIPS distances.
    """
    def _to_lpips_input(x):
        # Lower-rank inputs get leading singleton dims. This matches what
        # repeat(1, 3, 1, 1) produced for them.
        while x.dim() < 4:
            x = x.unsqueeze(0)
        if x.shape[1] == 1:
            # expand is a view. The values and gradient match repeat(1, 3, 1, 1).
            x = x.expand(-1, 3, -1, -1)
        # LPIPS expects inputs in [-1, 1]
        return x * 2 - 1

    return perceptual_metric(_to_lpips_input(image), _to_lpips_input(target)).mean()


def guidance_loss(image, target, weights=(1.0, 1.0, 0.1)):
    """Weighted sum of L1, SSIM-loss and LPIPS terms between ``image`` and ``target``.

    Args:
        image: prediction being guided.
        target: guide image with the same layout.
        weights: indexable ``(w_l1, w_ssim, w_lpips)``.

    Returns:
        ``w_l1 * L1 + w_ssim * ssim_loss + w_lpips * perceptual_loss``.
    """
    mean_abs_error = (image - target).abs().mean()
    structural = ssim_loss(image, target)
    perceptual = perceptual_loss(image, target)
    w_l1, w_ssim, w_lpips = weights[0], weights[1], weights[2]
    return w_l1 * mean_abs_error + w_ssim * structural + w_lpips * perceptual

def inference(model, scheduler, targets, guidance_scale, chunks=1, weights=(1.0, 1.0, 0.1), noise_epsilon = 0.05, ae=None, normalize_per_sample=False):
    """Loss-guided diffusion sampling towards the frames in ``targets``.

    For each chunk of ``targets``:

    1. Start from a base noise tensor shared by all chunks, plus per-frame noise
       scaled by ``noise_epsilon``.
    2. Run the reverse loop over ``scheduler.timesteps``.
    3. At every step, push ``x`` along the normalised negative gradient of
       ``guidance_loss``, taken between the predicted clean sample and the guide.
    4. Pass the result through the autoencoder.

    Args:
        model: denoiser called as ``model(x, t)``. Its ``hparams.image_size`` is
            indexed as ``(W, H)``, the same order used for the noise tensor shape.
        scheduler: scheduler whose ``step(noise_pred, t, x)`` result exposes
            ``pred_original_sample`` and ``prev_sample``. Timesteps must already
            be set.
        targets: guide tensor, split into ``chunks`` along dim 1. Each chunk is
            permuted so dim 1 becomes the batch dim. Presumably
            ``(1, n_frames, H, W)`` — TODO confirm.
        guidance_scale: step length of the guidance update.
        chunks: number of pieces ``targets`` is split into along dim 1.
        weights: ``(l1, ssim, lpips)`` weights forwarded to ``guidance_loss``.
        noise_epsilon: scale of the per-frame noise added to the shared base noise.
        ae: module applied to the final sample. Element ``[0]`` of its output is
            kept. Defaults to the notebook-level ``AE``.
        normalize_per_sample: if True, normalise the guidance gradient per frame,
            so the step no longer shrinks as more frames share a chunk. The default
            (False) keeps the original whole-batch normalisation.

    Returns:
        CPU tensor of all chunks concatenated along dim 0, with dim 1 squeezed.
    """
    print("Generating image...")
    decoder = AE if ae is None else ae

    # The noise tensor is (N, C, H, W) with H = image_size[1] and W = image_size[0].
    height, width = model.hparams.image_size[1], model.hparams.image_size[0]
    base_noise = torch.randn(1, 1, height, width, device=targets.device)

    # Guides are resized to the model resolution (previously a hard-coded 128x128),
    # so their shape always matches the noise.
    # NOTE(review): assumes Resize2D takes (H, W). It makes no difference for square sizes.
    resize_guide = ust.Resize2D((height, width))
    resize_256 = ust.Resize2D((256, 256))

    stack = []
    for guide in torch.chunk(targets, chunks=chunks, dim=1):
        guide = resize_guide(guide.permute(1, 0, 2, 3))  # dim 1 of targets becomes the batch dim

        x = base_noise + noise_epsilon * torch.randn_like(guide)

        for t in scheduler.timesteps:
            with torch.no_grad():
                noise_pred = model(x, t)

            x = x.detach().requires_grad_()
            x0 = scheduler.step(noise_pred, t, x).pred_original_sample

            # After the normalisation below, only the direction of this gradient
            # matters; the step length comes from guidance_scale.
            loss = guidance_loss(x0, guide, weights) * guidance_scale
            cond_grad = -torch.autograd.grad(loss, x)[0]
            if normalize_per_sample:
                grad_norm = cond_grad.flatten(1).norm(dim=1).view(-1, *([1] * (cond_grad.dim() - 1)))
            else:
                # Whole-batch norm: the per-frame step shrinks as the chunk grows.
                grad_norm = cond_grad.norm()
            cond_grad = cond_grad / (grad_norm + 1e-8)
            x = x.detach() + guidance_scale * cond_grad

            x = scheduler.step(noise_pred, t, x).prev_sample

        with torch.no_grad():
            x = decoder(resize_256(x))[0]
        stack.append(x.cpu())

    # Combine chunks and remove channel dimension
    cat = torch.cat(stack).squeeze(1)
    print("Image generated!")
    return cat


In [None]:
# Sample probe sweeps from the simulated volume. Below, `sweeps` is indexed as
# sweeps[0][k], and each selected sweep is passed to `inference` as its guide target.
sweeps, tags = ss.volume_sampling()

In [None]:
# DDIM sampling configuration for guided inference.
num_inference_steps = 100
guidance_scale = 15.0
SEED = 0  # fixes the base/per-frame noise so guided sweeps are reproducible across runs

scheduler = DDIMScheduler(beta_start=0.0001, beta_end=0.02,
                          beta_schedule="linear")
scheduler.set_timesteps(num_inference_steps)

# torch.manual_seed seeds the CPU and all CUDA generators.
torch.manual_seed(SEED)

# sweeps[0][2:3] selects a single sweep (index 2 of sweeps[0]); widen the slice to
# simulate more sweeps.
sweeps_us = torch.stack([
    inference(model, scheduler, sweep, guidance_scale, chunks=1, weights=(1.0, 10.0, 0.25))
    for sweep in sweeps[0][2:3]
])

In [None]:
# Animate the generated frames of the first sweep; axis 0 is the animation (frame) axis.
px.imshow(sweeps_us[0], animation_frame=0, binary_string=True)