In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import nrrd
import pandas as pd
import os
import sys

import plotly.express as px
import plotly.graph_objects as go

# sys.path.append("/mnt/raid/C1_ML_Analysis/source/autoencoder/src")
sys.path.append("/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/dl/")
sys.path.append("/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/dl/nets")

from nets import cut
from loaders import ultrasound_dataset as usd
from transforms import ultrasound_transforms as ust

In [None]:
# batch_size = 2
# dm = usd.Cut3DDataModule(
#     csv_train_diffusor="/mnt/raid/C1_ML_Analysis/simulated_data_export/animation_export_resampled_sub_train_test.csv_",
#     csv_valid_diffusor="/mnt/raid/C1_ML_Analysis/simulated_data_export/animation_export_resampled_sub_train_test.csv_",
#     img_column_diffusor="img",    
#     csv_train="/mnt/raid/C1_ML_Analysis/CSV_files/ALL_C2_cines_gt_ga_withmeta_20221031_butterfly_train.csv",
#     csv_valid="/mnt/raid/C1_ML_Analysis/CSV_files/ALL_C2_cines_gt_ga_withmeta_20221031_butterfly_valid.csv",
#     img_column="file_path",    
#     num_frames=64,
#     mount_point="/mnt/raid/C1_ML_Analysis/",
#     batch_size=batch_size,
#     num_workers=1,    
#     prefetch_factor=2,
#     drop_last=False,
#     )
# dm.setup()

# train_dl = dm.val_dataloader()
# train_dl_iter = iter(train_dl)

# batch = next(train_dl_iter)

In [None]:
# Data module for the real ultrasound clips (judging by the class name and CSV
# file names: Butterfly blind sweeps). `batch_size` and `num_frames` are kept as
# module-level names because the model `args` dictionary defined in a later cell
# reuses both values.
batch_size = 1
num_frames = 128
dm = usd.USButterflyBlindSweepDataModule(    
    csv_train="/mnt/raid/C1_ML_Analysis/CSV_files/ALL_C2_cines_gt_ga_withmeta_20221031_butterfly_train.csv",
    csv_valid="/mnt/raid/C1_ML_Analysis/CSV_files/ALL_C2_cines_gt_ga_withmeta_20221031_butterfly_valid.csv",
    img_column="file_path",    
    num_frames=num_frames,
    # NOTE: "continous_frames" is the keyword's actual spelling; do not correct
    # it here. Presumably it requests consecutive frames - confirm in `usd`.
    continous_frames=True,
    mount_point="/mnt/raid/C1_ML_Analysis/",
    batch_size=batch_size,
    num_workers=1,    
    prefetch_factor=2,
    drop_last=False,
    )
dm.setup()

# Keep both the loader and a live iterator in the kernel: a later cell calls
# next() on `train_dl_iter` to pull one batch at a time.
train_dl = dm.train_dataloader()
train_dl_iter = iter(train_dl)



In [None]:
# Keyword arguments forwarded verbatim to cut.CutLabel. The grouping comments
# below are based on the key names only; confirm meanings against cut.CutLabel.
args = {
    # Optimisation / training schedule
    'lr': 1e-4,
    'weight_decay': 0.01,
    'betas': (0.9, 0.999),
    'epochs': 200,
    'patience': 30,
    'steps': -1,
    'batch_size': batch_size,
    'conv3d': 1,
    # Simulation inputs
    'probe_paths': '/mnt/raid/C1_ML_Analysis/simulated_data_export/animation_export/all_poses/frame_0001/probe_paths',
    'diffusor': '/mnt/raid/C1_ML_Analysis/simulated_data_export/animation_export/all_poses/frame_0001.nrrd',
    'params_csv': '/mnt/raid/C1_ML_Analysis/simulated_data_export/animation_export/shapes_intensity_map_nrrd.csv',
    # Ranges for the fan geometry. Each (start, end) pair brackets the nominal
    # value given further down: center_y -40 / -30 / -20, theta pi/5 / pi/4.75 / pi/3.
    'center_y_start': -40.0,
    'center_y_end': -20.0,
    # Was 2150.0, which is ~10x larger than r2_end and nowhere near the nominal
    # r2 of 225.0; 215.0 restores the symmetric 215 / 225 / 235 bracket.
    'r2_start': 215.0,
    'r2_end': 235.0,
    'theta_start': np.pi / 5.0,
    'theta_end': np.pi / 3.0,
    'num_random_sweeps': 1,
    'lambda_y': 1,
    'create_grids': 0,
    'n_grids': 256,
    'num_labels': 333,  # Number of labels for the USR dictionary
    # Nominal fan geometry
    'grid_w': 256,
    'grid_h': 256,
    'center_x': 128.0,
    'center_y': -30.0,
    'r1': 10.0,
    'r2': 225.0,
    'theta': np.pi / 4.75,
    'padding': 70,  # Padding for the simulated ultrasound
    # Must agree with the data module built earlier in the notebook.
    'num_frames': num_frames,
    'continous_frames': True,  # key spelling matches the API; do not correct
}
model = cut.CutLabel(**args)


In [None]:
# Pull the next batch from the training iterator. This cell is intentionally
# stateful: every re-run advances `train_dl_iter` and rebinds `batch`, and
# next() raises StopIteration once the iterator is exhausted.
batch = next(train_dl_iter)

In [None]:
def increase_contrast(tensor, factor=1.5):
    """
    Stretch intensities linearly about the tensor's global mean.

    The mean is taken over every element of ``tensor`` (no axis is singled
    out), each value's distance from that mean is scaled by ``factor``, and
    the result is clipped to the closed interval [0, 1].

    Args:
        tensor (torch.Tensor): Floating-point input of any shape.
        factor (float): Scale applied to the deviation from the mean.
            Values above 1 raise contrast, values below 1 lower it.

    Returns:
        torch.Tensor: New tensor with the same shape as ``tensor`` and all
        values in [0, 1]. The input is not modified.
    """
    pivot = tensor.mean()
    deviation = tensor - pivot
    stretched = deviation * factor + pivot
    # Values pushed outside the displayable range are saturated, not rescaled.
    return stretched.clamp(min=0.0, max=1.0)

In [None]:
# Rebuild the sampling grid with a hand-tuned fan geometry. Note that r1 and r2
# here (20.0 / 210.0) differ from the values the model was constructed with.
model.init_grid(w=256, h=256, center_x=128, center_y=-30, r1=20.0, r2=210.0, theta=np.pi / 4.75, padding=70)
# model.cuda()

# Visualisation only: no autograd graph is needed, and a tensor that tracks
# gradients cannot be converted with .numpy(). This mirrors the inference cell.
with torch.no_grad():
    X, tags = model.volume_sampling(model.diffusor_t, model.diffusor_origin, model.diffusor_end, use_random=False)
    X = increase_contrast(X, factor=1.6)
    X = X[0]

# Overlay one frame sampled from the diffusor volume (X) on one frame of the
# training batch to check that the fan geometry lines up.
img1 = X[0, 0, 0].detach().cpu().numpy()
img2 = batch[0, 0, 0].detach().cpu().numpy()

fig = go.Figure()

# Both images are flipped along axis 0 before plotting so they share the same
# vertical orientation in the heatmap.
fig.add_trace(go.Heatmap(z=np.flip(img1, axis=0), opacity=0.8, colorscale='hot'))
fig.add_trace(go.Heatmap(z=np.flip(img2, axis=0), opacity=0.1, colorscale='ice'))

fig.update_layout(
    title="Sampled diffusor frame (hot) vs. training batch frame (ice)",
    height=800,
    width=800,
)
fig.show()


In [None]:
# Replace the freshly constructed model with one restored from a checkpoint.
# This rebinds the global `model`, so any cell run after this point operates on
# the restored object rather than the one built from `args`.
model = cut.CutLabel.load_from_checkpoint('/mnt/raid/C1_ML_Analysis/train_output/Cut3d/0.5/epoch=17-val_loss=3.08.ckpt')
# The hook is invoked by hand because no trainer is involved here; presumably
# it prepares state the next cell relies on - confirm in cut.CutLabel.
model.on_fit_start()
# Last expression of the cell: whatever eval() returns is what the notebook
# renders as this cell's output.
model.eval()

In [None]:

# Inference only: sample the diffusor volume without randomisation and run the
# first element of the result through the generator, with autograd disabled.
with torch.no_grad():
    X, tags = model.volume_sampling(model.diffusor_t, model.diffusor_origin, model.diffusor_end, use_random=False)
    Y_fake = model.G(model.resize_t(X[0]))

# A bare `Y_fake.shape` in the middle of a cell is evaluated and discarded
# (only a cell's last expression is rendered), so print it explicitly.
print(tuple(Y_fake.shape))

# Scrub through the generated volume: axis 0 of Y_fake[0, 0] drives the
# animation slider and the remaining two axes form each displayed image.
fig = px.imshow(Y_fake[0, 0].detach().cpu().numpy(), animation_frame=0, binary_string=True, height=800, width=800)
fig.show()
