In [1]:
import os
import sys

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import SimpleITK as sitk
import nrrd
import vtk

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms

import pytorch_lightning as pl
import pickle
import monai 
import glob 
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

sys.path.append('/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/')
sys.path.append('/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/dl')
import dl.loaders.ultrasound_dataset as ultrasound_dataset
import dl.transforms.ultrasound_transforms as ultrasound_transforms
import dl.nets.us_simulation_jit as us_simulation_jit
import dl.nets.us_simu as us_simu

import importlib

sys.path.append('/mnt/raid/C1_ML_Analysis/source/ShapeAXI/src/')




In [2]:
def plot_diffusion(X, marker_size=2, opacity=0.8, colorscale='Viridis'):
    """Plot a sequence of 3D point clouds side by side with plotly.

    Builds a single-row figure with one ``scatter3d`` subplot per element of
    ``X``. Each point is colored by its third coordinate (``x[:, 2]``).

    Args:
        X: sequence (list, array, ...) of point clouds. Each element ``x`` is
            indexed as ``x[:, 0]``, ``x[:, 1]``, ``x[:, 2]``, so it needs at
            least 3 columns; extra columns are ignored.
        marker_size: scatter marker size (default keeps the previous value, 2).
        opacity: marker opacity (default keeps the previous value, 0.8).
        colorscale: plotly colorscale name (default keeps 'Viridis').

    Returns:
        The plotly ``Figure``; call ``.show()`` on it to render.

    Raises:
        ValueError: if ``X`` is empty (plotly cannot build a figure with zero
            columns, so fail early with a clearer message).
    """
    num_surf = len(X)
    if num_surf == 0:
        raise ValueError("plot_diffusion needs at least one point cloud to plot")

    fig = make_subplots(
        rows=1, cols=num_surf,
        # One independent spec dict per column.
        specs=[[{'type': 'scatter3d'} for _ in range(num_surf)]],
    )

    for col, x in enumerate(X, start=1):
        fig.add_trace(
            go.Scatter3d(
                x=x[:, 0], y=x[:, 1], z=x[:, 2],
                mode='markers',
                marker=dict(
                    size=marker_size,
                    color=x[:, 2],  # color by height
                    colorscale=colorscale,
                    opacity=opacity,
                ),
            ),
            row=1, col=col,
        )

    return fig

In [3]:


# Test-time inference here looks at one study at a time (a later cell also
# slices the first batch down to [0:1]), so the loader batch size is 1.
# The variable is passed through so there is a single source of truth.
batch_size = 1
dm = ultrasound_dataset.ImgPCDataModule(
    # NOTE(review): mount_point already ends in "simulated_data_export" and the
    # csv/npy paths repeat that prefix — confirm how ImgPCDataModule joins them.
    mount_point="/mnt/raid/C1_ML_Analysis/simulated_data_export",
    csv_train="simulated_data_export/studies_fetus_train_train.csv",
    np_train="simulated_data_export/studies_fetus_train_train.npy",
    csv_valid="simulated_data_export/studies_fetus_train_test.csv",
    np_valid="simulated_data_export/studies_fetus_train_test.npy",
    csv_test="simulated_data_export/studies_fetus_test.csv",
    np_test="simulated_data_export/studies_fetus_test.npy",
    num_samples_train=5000,
    rescale_factor=10,
    batch_size=batch_size,
    )
dm.setup()



In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
CKPT_PATH = "/mnt/raid/C1_ML_Analysis/train_output/simu_reconstruction/USDDPMPC/v0.3/epoch=85-val_loss=0.00.ckpt"

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be restored on a CPU-only machine.
model = us_simu.USDDPMPC.load_from_checkpoint(CKPT_PATH, map_location=device)
model.to(device)
model.eval();  # trailing ';' suppresses Jupyter's display of the (long) module repr

In [None]:
# DataLoader over the test split of the datamodule configured above; the next cell takes its first batch.
dl_test = dm.test_dataloader()

In [6]:
# Pull the first test batch and keep only its first element. Slicing with
# [0:1] (rather than [0]) keeps the leading batch dimension, so each tensor
# stays shaped [1, ...].
X, X_origin, X_end, X_PC = (item[0:1] for item in next(iter(dl_test)))

In [None]:
# Ground-truth point cloud(s) of the selected sample: one 3D scatter subplot per batch element.
pc_gt = X_PC.cpu().numpy()
fig = plot_diffusion(pc_gt)
fig.show()

In [8]:
def encode_sweeps(model, x_sweeps):
    """Encode simulated sweeps into the conditioning sequence consumed by ``model.sample``.

    Each sweep ``n`` is processed as follows:

    - Encoded with ``model.encode``. Only the mean output is used; the sigma
      output is discarded.
    - Chunked with ``model.attn_chunk``.
    - Flattened per chunk.

    The per-sweep results are stacked on a new sweep axis. The stack is then
    projected with ``model.proj``, positionally encoded with
    ``model.p_encoding``, and flattened into one token sequence per batch
    element.

    Args:
        model: the loaded USDDPMPC model.
        x_sweeps: sweeps tensor indexed as [BS, N, C, T, H, W] (N = number of
            sweeps), as returned by ``model.volume_sampling``.

    Returns:
        Contiguous tensor of shape [BS, -1, model.hparams.embed_dim].
    """
    n_batch, n_sweeps = x_sweeps.shape[:2]
    n_chunks = model.hparams.n_chunks

    per_sweep = []
    for n in range(n_sweeps):
        z_mu, _ = model.encode(x_sweeps[:, n])  # sweep input is [BS, C, T, H, W]
        z_n = model.attn_chunk(z_mu)  # [BS, latent_channels, n_chunks, 64, 64]
        # -> [BS, n_chunks, 64*64*latent_channels]
        z_n = z_n.permute(0, 2, 3, 4, 1).reshape(n_batch, n_chunks, -1)
        per_sweep.append(z_n.unsqueeze(1))

    z = torch.cat(per_sweep, dim=1)  # [BS, N, n_chunks, 64*64*latent_channels]
    z = model.proj(z)  # [BS, N, n_chunks, 1280]
    # No positional-encoding buffer trick needed here: ALL the sweeps are present at inference.
    z = model.p_encoding(z)
    return z.view(n_batch, -1, model.hparams.embed_dim).contiguous()


# x_sweeps: [B, N, C, T, H, W], N = number of sweeps, e.g. torch.Size([2, 2, 1, 200, 256, 256]).
# sweeps_tags: [B, N], e.g. torch.Size([2, 2]); not used below since every sweep is encoded.
with torch.no_grad():
    # The sweep simulation runs under no_grad as well: nothing in this cell needs gradients.
    x_sweeps, sweeps_tags = model.volume_sampling(X.to(device), X_origin.to(device), X_end.to(device))
    z = encode_sweeps(model, x_sweeps)
    pc, intermediates = model.sample(intermediate_steps=5, z=z)
    # pc, intermediates = model.sample_wguidance(intermediate_steps=5, z=z, guidance_scale=7.5)

In [9]:
# Conditional denoising trajectory: one subplot per sample per saved intermediate step.
trajectory = torch.cat(intermediates, dim=0).cpu().numpy()
fig = plot_diffusion(trajectory)
fig.show()

In [None]:
# Unconditional sample: an all-zero context is passed in place of the encoded sweeps.
null_context = torch.zeros((1, 1, model.hparams.embed_dim), device=device)
with torch.no_grad():
    pc_f, intermediates_f = model.sample(z=null_context, intermediate_steps=5)

In [11]:
# Unconditional denoising trajectory: one subplot per sample per saved intermediate step.
trajectory_f = torch.cat(intermediates_f, dim=0).cpu().numpy()
fig = plot_diffusion(trajectory_f)
fig.show()

In [24]:
def sample_guided(self, num_samples=1, guidance_scale=7.5, intermediate_steps=None, z=None):
    """Sample point clouds with classifier-free guidance.

    Written as a free function that receives the model as ``self``, so it can
    be used on the loaded checkpoint without editing the class. At every
    scheduler timestep the network is evaluated twice:

    - once with the conditioning ``z``;
    - once with an all-zero context.

    The two predictions are combined as
    ``uncond + guidance_scale * (cond - uncond)`` and passed to
    ``self.noise_scheduler.step``.

    Args:
        self: model exposing:
            - ``device``;
            - ``hparams.input_dim``, ``hparams.embed_dim`` and
              ``hparams.num_train_steps``;
            - ``noise_scheduler`` (with ``timesteps`` and ``step``);
            - a forward ``self(x, timesteps=..., context=...)``.
        num_samples: number of point clouds to draw.
        guidance_scale: guidance weight. 1.0 yields the conditional prediction,
            0.0 the unconditional one.
        intermediate_steps: if a positive int, the sample produced at every
            timestep divisible by ``num_train_steps // intermediate_steps``
            (clamped to at least 1) is saved.
        z: conditioning context forwarded to the network. Its batch size
            presumably has to match ``num_samples`` — TODO confirm.

    Returns:
        ``(x_t, intermediates)``:
            - ``x_t``: the final samples, shaped ``[num_samples, 64*64, input_dim]``.
            - ``intermediates``: list of saved intermediate samples; empty when
              ``intermediate_steps`` is None or <= 0.
    """
    side = 64  # the network consumes each cloud as a side x side grid
    n_points = side * side
    input_dim = self.hparams.input_dim
    device = self.device

    # Initialize random noise
    x_t = torch.randn(num_samples, n_points, input_dim, device=device)

    # Loop invariants. The null context feeds the unconditional branch.
    # max(1, ...) avoids a ZeroDivisionError when intermediate_steps exceeds
    # num_train_steps; in that case every step is saved.
    null_context = torch.zeros(num_samples, 1, self.hparams.embed_dim, device=device)
    save_every = None
    if intermediate_steps is not None and intermediate_steps > 0:
        save_every = max(1, self.hparams.num_train_steps // intermediate_steps)

    intermediates = []
    for t in self.noise_scheduler.timesteps:
        # [num_samples, n_points, D] -> [num_samples, D, side, side], shared by both passes.
        x_in = x_t.permute(0, 2, 1).view(-1, input_dim, side, side).contiguous()

        # Conditional (with context) and unconditional (zero context) predictions.
        x_cond = self(x_in, timesteps=t, context=z)
        x_uncond = self(x_in, timesteps=t, context=null_context)

        # Back to [num_samples, n_points, D].
        x_cond = x_cond.view(-1, input_dim, n_points).permute(0, 2, 1)
        x_uncond = x_uncond.view(-1, input_dim, n_points).permute(0, 2, 1).contiguous()

        # Classifier-free guidance.
        x_guided = x_uncond + guidance_scale * (x_cond - x_uncond)

        # Update the diffusion step using the guided output.
        x_t = self.noise_scheduler.step(model_output=x_guided, timestep=t, sample=x_t).prev_sample

        if save_every is not None and t % save_every == 0:
            intermediates.append(x_t)

    return x_t, intermediates

In [30]:
# Guided sampling conditioned on the encoded sweeps `z`, with a stronger guidance weight than the default 7.5.
guided_kwargs = dict(z=z, guidance_scale=10, intermediate_steps=5)
with torch.no_grad():
    pc_guided, intermediates_guided = sample_guided(model, **guided_kwargs)

In [31]:
# Guided denoising trajectory: one subplot per sample per saved intermediate step.
trajectory_guided = torch.cat(intermediates_guided, dim=0).cpu().numpy()
fig = plot_diffusion(trajectory_guided)
fig.show()