In [None]:
import os
import sys

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import SimpleITK as sitk
import nrrd
import vtk

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms

import pytorch_lightning as pl
import pickle
import monai 
import glob 
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

sys.path.append('/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/')
sys.path.append('/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/dl')
import dl.transforms.ultrasound_transforms as ultrasound_transforms
import dl.nets.us_simulation_jit as us_simulation_jit
import dl.nets.us_simu as us_simu

import importlib

from dl.nets.layers import TimeDistributed

from torchvision import transforms as T


In [None]:
# Root of the RAID volume: the model checkpoint, the sweep files listed in the
# study CSV, and the exported NRRD volumes are all joined onto this path.
mount_point = os.path.join('/mnt/raid', 'C1_ML_Analysis')

In [None]:
# Sweep manifest. Later cells read its `study_id`, `tag` and `file_path` columns
# (`file_path` is joined onto mount_point when loading each sweep).
study_csv = '/mnt/famli_netapp_shared/C1_ML_Analysis/src/diffusion-models/blender/studies/placenta/C_dataset_analysis_protocoltagsonly_gaboe230_masked_resampled_256_spc075.csv'
df = pd.read_csv(study_csv)

# Restrict to the single study analysed in the rest of the notebook.
is_target_study = df['study_id'] == 'FAM-025-0950-4'
df_study = df[is_target_study]

In [None]:
# Segmentation checkpoint. Earlier runs of this notebook loaded reconstruction
# checkpoints from train_output/simu_reconstruction/USPCReconstruction/{v0.2,v0.3}/.
ckpt_relpath = 'train_output/seg/USSeg/0.1/epoch=139-val_loss=0.33.ckpt'
model_fn = os.path.join(mount_point, ckpt_relpath)

# Restore the weights, switch to eval mode and move the model to the GPU.
model = us_simu.USSeg.load_from_checkpoint(model_fn)
model = model.eval().cuda()


In [None]:
def sweep_resample(sweep_t, target_shape=(128, 256, 256)):
    """Resize a batched 3-D volume to ``target_shape`` using nearest-neighbour sampling.

    Args:
        sweep_t: 5-D tensor (N, C, D, H, W); a 3-element ``size`` requires 5-D input
            for ``F.interpolate``.
        target_shape: Output spatial size (D, H, W). Defaults to (128, 256, 256).

    Returns:
        Tensor of shape (N, C, *target_shape) with values copied from the nearest
        source voxels (no new intensity values are created).
    """
    resized = F.interpolate(sweep_t, size=target_shape, mode='nearest')
    return resized

sweeps_fov = []  # currently unused; kept for the field-of-view variant of this pipeline
sweeps_t = []
sweep_tags = []
# Keep only sweeps whose tag appears in model.vs.tags (embed_sweep is later called with these tags).
df_study_tags = df_study[df_study['tag'].isin(model.vs.tags)]

# Pads 40 px at the bottom, then crops a 256x256 window starting at row 40.
# Not applied in the current pipeline; presumably undoes the ultrasound crop — TODO confirm.
inverse_transform_us = T.Compose([T.Pad((0, 0, 0, 40)), T.Lambda(lambda x: T.functional.crop(x, 40, 0, 256, 256))])

for tag, rel_path in zip(df_study_tags['tag'], df_study_tags['file_path']):
    sweep_fn = os.path.join(mount_point, rel_path)
    # index_order='C' returns the array in NumPy (slowest-axis-first) order.
    img_np, _ = nrrd.read(sweep_fn, index_order='C')
    # Single float32 conversion, scaled to [0, 1]; assumes 8-bit intensities — TODO confirm.
    sweep_t = torch.from_numpy(np.ascontiguousarray(img_np, dtype=np.float32)).cuda() / 255.0
    # Add batch and channel axes for 3-D interpolation. Assumes a single-channel
    # 3-D volume — TODO confirm (a channel-last 4-D NRRD would need a permute first).
    sweep_resampled = sweep_resample(sweep_t[None, None])
    sweeps_t.append(sweep_resampled)
    # Record the tag only after its sweep loaded, so tags and tensors stay aligned.
    sweep_tags.append(tag)

if not sweeps_t:
    # torch.cat([]) would fail with an uninformative error; explain the actual cause instead.
    raise ValueError('No sweeps in df_study have a tag listed in model.vs.tags')

# Stack along the batch axis: one entry per sweep, in the same order as sweep_tags.
sweeps_t = torch.cat(sweeps_t)
print(sweep_tags)

In [None]:
# Inference only: disable autograd so no graph or activations are retained for the batch.
with torch.no_grad():
    seg_out = model(sweeps_t.cuda())

# Per-voxel class label = arg-max over dim 1 (assumed to be the class axis — TODO confirm).
# keepdim=True leaves a singleton axis there, so the result keeps the input's rank.
sweeps_seg_t = torch.argmax(seg_out, dim=1, keepdim=True)

In [None]:
def plot_seq(x, x_hat, cmax=11):
    """Animate an image sequence side by side with its label sequence.

    Args:
        x: Sequence of 2-D frames indexed as ``x[t]``; ``x.shape`` must unpack to
            (T, H, W). Drawn in grayscale in the left panel.
        x_hat: Label/prediction frames indexed as ``x_hat[t]`` for every t < T.
            Drawn in the right panel on a shared 'jet' colour axis.
        cmax: Upper bound of the label colour axis (lower bound is 0). Defaults to 11,
            the previously hard-coded value.

    Returns:
        A plotly Figure with Play/Pause buttons and a frame slider.
    """
    n_frames, _, _ = x.shape  # don't name this T: it would shadow the torchvision alias

    # Two separate panels. Heatmaps are opaque, so drawing both on one axis would
    # completely hide the grayscale image under the label map.
    fig = make_subplots(rows=1, cols=2)

    # Initial frame. The gray colorbar is hidden so it doesn't overlap the shared coloraxis bar.
    fig.add_trace(go.Heatmap(z=x[0], colorscale='gray', showscale=False), row=1, col=1)
    fig.add_trace(go.Heatmap(z=x_hat[0], coloraxis='coloraxis'), row=1, col=2)

    # Animation frames update traces 0 and 1 in place; axis refs match the subplots above.
    fig.frames = [
        go.Frame(
            data=[
                go.Heatmap(z=x[i], colorscale='gray', showscale=False, xaxis='x', yaxis='y'),
                go.Heatmap(z=x_hat[i], coloraxis='coloraxis', xaxis='x2', yaxis='y2'),
            ],
            traces=[0, 1],
            name=str(i),
        )
        for i in range(n_frames)
    ]

    # Fixed-size layout (two square-ish panels), global label colour range,
    # Play/Pause buttons and a frame slider.
    fig.update_layout(
        autosize=False,
        width=1200,
        height=600,
        coloraxis={"colorscale": "jet",
                   "cmin": 0,
                   "cmax": cmax},
        updatemenus=[{
            "buttons": [
                {
                    "args": [None, {"frame": {"duration": 500, "redraw": True},
                                    "fromcurrent": True, "mode": "immediate"}],
                    "label": "Play",
                    "method": "animate"
                },
                {
                    "args": [[None], {"frame": {"duration": 0, "redraw": False},
                                      "mode": "immediate"}],
                    "label": "Pause",
                    "method": "animate"
                }
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top"
        }],
        sliders=[{
            "steps": [
                {
                    "args": [[str(k)], {"frame": {"duration": 300, "redraw": True},
                                        "mode": "immediate"}],
                    "label": str(k),
                    "method": "animate"
                } for k in range(n_frames)
            ],
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": "Frame:",
                "visible": True,
                "xanchor": "right"
            },
            "transition": {"duration": 300, "easing": "cubic-in-out"}
        }]
    )
    return fig

In [None]:

# Embed the first sweep's predicted label volume via model.vs.embed_sweep.
# no_grad keeps the result detached, so the later .cpu().numpy() call cannot fail
# on a tensor that requires grad.
with torch.no_grad():
    X_V = model.vs.embed_sweep(sweep_tags[0], sweeps_seg_t[0].unsqueeze(0))

# Flatten to one row per voxel: (1, C, D, H, W) -> (D*H*W, C). Assumes a 5-D output
# with channels on axis 1 — TODO confirm. The scatter cell reads column 0 as the
# label and columns 1-3 as coordinates. The channel count is taken from the tensor
# instead of hard-coding 4, so a mismatch cannot silently scramble rows.
X_V = X_V.squeeze(0)
X_V = X_V.permute(1, 2, 3, 0).reshape(-1, X_V.shape[0])

In [None]:
# Label value to visualise (compared against column 0 of X_V) and max points to draw.
TARGET_LABEL = 7
N = 50000

x_v = X_V.cpu().numpy()
x_v = x_v[x_v[:, 0] == TARGET_LABEL]

# Seeded NumPy subsample: the figure is reproducible across runs, and the NumPy
# array is indexed with a NumPy index rather than a torch tensor.
sample_rng = np.random.default_rng(0)
random_indices = sample_rng.permutation(x_v.shape[0])[:N]
pts = x_v[random_indices]

fig = go.Figure(data=[go.Scatter3d(
    x=pts[:, 1], y=pts[:, 2], z=pts[:, 3],
    mode='markers',
    marker=dict(
        size=2,
        color=pts[:, 0],   # colour by label column
        colorscale='jet',
        opacity=1.0,
    ),
)])
fig.show()

In [None]:
# Destination for each resampled sweep and its predicted label map.
out_dir = os.path.join(mount_point, 'simulated_data_export/test_output/seg/')
# Create the directory up front so writing does not fail when it is missing.
os.makedirs(out_dir, exist_ok=True)

for img, seg, tag in zip(sweeps_t, sweeps_seg_t, sweep_tags):
    img_f = img.squeeze().cpu().numpy()
    # Round before casting: the float32 round trip x/255*255 may land just below the
    # original integer, and astype(uint8) truncates (e.g. 36.99998 -> 36).
    img_u8 = np.clip(np.rint(img_f * 255), 0, 255).astype(np.uint8)
    seg_u8 = seg.squeeze().cpu().numpy().astype(np.uint8)
    # NOTE(review): volumes are written without spacing/origin/direction; the original
    # NRRD geometry is not carried through the resampling.
    sitk.WriteImage(sitk.GetImageFromArray(img_u8), os.path.join(out_dir, tag + '.nrrd'))
    sitk.WriteImage(sitk.GetImageFromArray(seg_u8), os.path.join(out_dir, tag + '_seg.nrrd'))