In [None]:
import torch
import sys
import os
import nrrd 
sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/ShapeAXI/src')

import shapeaxi
from shapeaxi import utils

sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/famli-ultra-sim/')
sys.path.append('/mnt/famli_netapp_shared/C1_ML_Analysis/src/famli-ultra-sim/dl')
import dl.transforms.ultrasound_transforms as ultrasound_transforms
import dl.nets.us_simu as us_simu
from dl.loaders.ultrasound_dataset import FluidDataset, FluidDataModule

from torch.utils.data import Dataset, DataLoader
from lightning.pytorch.core import LightningDataModule

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

import pandas as pd
import SimpleITK as sitk
import numpy as np

from torch.nn.utils.rnn import pad_sequence

In [None]:


class FixFluidDataset(Dataset):
    """Dataset that merges per-ROI mask files into a single integer label map.

    Each row of ``df`` points to an image (``img_column``) and to a directory
    (``seg_column``) holding mask files whose names contain an ROI name from
    ``rois``. ``__getitem__`` merges those masks into one label map (label
    ``i + 1`` for ``rois[i]``, 0 = background). When at least one mask is found
    and ``out_dir`` is not None, it also writes the image and the label map to
    ``out_dir`` as compressed ``.nrrd`` files as a side effect.

    Args:
        df: table with one row per sample (accessed via ``.iloc`` / ``.index``).
        mount_point: prefix joined in front of the paths stored in ``df``.
        transform: stored on the instance; not applied by this class.
        img_column: column holding the image path relative to ``mount_point``.
        seg_column: column holding the mask-directory path relative to ``mount_point``.
        rois: ROI names matched as substrings of mask file names; their order
            defines the label values.
        out_dir: directory where converted ``.nrrd`` files are written;
            ``None`` disables writing.
    """

    def __init__(
        self,
        df,
        mount_point="./",
        transform=None,
        img_column="img",
        seg_column="seg",
        rois=("Amniotic_fluid", "Maternal_bladder", "Fetal_chest_abdomen", "Fetal_head",
              "Fetal_heart", "Fetal_limb_other", "Placenta", "Umbilical_cord"),
        out_dir="/mnt/raid/C1_ML_Analysis/raid_data_folder/nrrd",
    ):
        self.df = df
        self.mount_point = mount_point
        self.transform = transform  # stored only; __getitem__ does not apply it
        self.img_column = img_column
        self.seg_column = seg_column
        # Fresh per-instance list: a mutable default would be shared by all
        # instances, while callers may still rely on self.rois being a list.
        self.rois = list(rois)
        self.out_dir = out_dir

        # Reference: CSV indicator column -> ROI name mapping used elsewhere.
        # self.colname_roi_name_map = {'Amniotic_fluid_indicator': 'Amniotic_fluid',
        #                 'Maternal_bladder_indicator': 'Maternal_bladder',
        #                 'fetal_chest_abdomen_indicator': 'Fetal_chest_abdomen',
        #                 'fetal_head_indicator': 'Fetal_head',
        #                 'fetal_heart_indicator': 'Fetal_heart',
        #                 'fetal_limb_other_indicator': 'Fetal_limb_other',
        #                 'placenta_indicator': 'Placenta',
        #                 'umbilical_cord_indicator': 'Umbilical_cord',
        #                 'dropout_indicator': 'Dropout',
        #                 'shadowing_indicator': 'Shadowing'}

    def __len__(self):
        """Number of rows in ``df``."""
        return len(self.df.index)

    def generate_label_map_simpleitk(self, seg_path):
        """Merge the ROI mask files found in ``seg_path`` into one label image.

        Args:
            seg_path: directory containing the per-ROI mask files.

        Returns:
            A SimpleITK image of int32 labels (``i + 1`` wherever a mask for
            ``self.rois[i]`` is > 0, 0 elsewhere), flipped along numpy axis 1
            and carrying the metadata of the last mask read. Returns ``None``
            when no file name in ``seg_path`` contains any ROI name.

        Raises:
            OSError (e.g. FileNotFoundError) from ``os.listdir`` if ``seg_path``
            cannot be listed and ``self.rois`` is non-empty.
        """
        # List the directory once rather than once per ROI.
        dir_files = os.listdir(seg_path) if self.rois else []

        label_map = None
        ref_img = None
        for label, roi in enumerate(self.rois, start=1):
            for seg_file in (f for f in dir_files if roi in f):
                ref_img = sitk.ReadImage(os.path.join(seg_path, seg_file))
                seg_data = sitk.GetArrayFromImage(ref_img)
                if label_map is None:
                    label_map = np.zeros_like(seg_data, dtype=np.int32)
                # Later ROIs overwrite earlier ones where masks overlap.
                label_map[seg_data > 0] = label

        if label_map is None:
            return None

        # NOTE(review): the axis-1 flip presumably corrects an orientation
        # mismatch between the masks and the images — confirm against the data.
        # np.flip returns a negative-stride view; copy to a contiguous buffer.
        flipped = np.ascontiguousarray(np.flip(label_map, axis=1))
        label_map_img = sitk.GetImageFromArray(flipped)
        label_map_img.CopyInformation(ref_img)  # metadata from the last mask read
        return label_map_img

    def _write_nrrd(self, img, seg, img_path):
        """Write ``img`` / ``seg`` to ``self.out_dir`` as ``<stem>.nrrd`` / ``<stem>_seg.nrrd``."""
        os.makedirs(self.out_dir, exist_ok=True)
        fname = os.path.splitext(os.path.basename(img_path))[0]
        sitk.WriteImage(img, os.path.join(self.out_dir, f"{fname}.nrrd"), useCompression=True)
        sitk.WriteImage(seg, os.path.join(self.out_dir, f"{fname}_seg.nrrd"), useCompression=True)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``.

        Multi-component images are reduced to their first component. If no ROI
        mask is found, the label tensor is all zeros with the image's shape and
        dtype int32, matching the label maps built from masks. If masks are
        found and ``self.out_dir`` is set, the image and label map are also
        written there as compressed NRRD files.
        """
        row = self.df.iloc[idx]
        img_path = os.path.join(self.mount_point, row[self.img_column])
        seg_path = os.path.join(self.mount_point, row[self.seg_column])

        img = sitk.ReadImage(img_path)
        if img.GetNumberOfComponentsPerPixel() > 1:
            img = sitk.VectorIndexSelectionCast(img, 0)
        img_np = sitk.GetArrayFromImage(img)

        seg = self.generate_label_map_simpleitk(seg_path)
        if seg is None:
            seg_np = np.zeros_like(img_np, dtype=np.int32)
        else:
            seg_np = sitk.GetArrayFromImage(seg)
            if self.out_dir is not None:
                self._write_nrrd(img, seg, img_path)

        return torch.tensor(img_np), torch.tensor(seg_np)

# df_train = pd.read_csv("/mnt/raid/C1_ML_Analysis/raid_data_folder/fluid_annotation_dcm_LABOR_nrrd_resampled_256.csv")
# ds = FluidDataset(df_train, "/mnt/raid/C1_ML_Analysis/raid_data_folder", img_column="img")

In [None]:
# mount_point = "/mnt/raid/C1_ML_Analysis/raid_data_folder"
# df_train = pd.read_csv("/mnt/raid/C1_ML_Analysis/CSV_files/fluid_annotation_dcm_LABOR.csv")
# seg_paths = []
# for idx, row in df_train.iterrows():
#     img_path = row['img']
#     img_path = img_path.split("/")
    
#     seg_path = os.path.join("annotation_roi_LABOR", img_path[1], img_path[2].split("_")[1], img_path[3], img_path[-1].replace(".dcm", "_mask"))
#     if not os.path.exists(os.path.join(mount_point, seg_path)):
#         seg_paths.append(None)
#     else:
#         seg_paths.append(seg_path)

# df_train["seg"] = seg_paths
# df_train = df_train.dropna()
# df_train = df_train.reset_index(drop=True)
# df_train.to_csv("/mnt/raid/C1_ML_Analysis/CSV_files/fluid_annotation_dcm_LABOR_exist.csv", index=False)

In [None]:
# img_d = ds[0]

# img = img_d["img"]
# seg = img_d["seg"]

# def compute_grid(X):
#     bs, c, d, h, w = X.shape

#     mesh_grid_params = [torch.arange(end=s, device=X.device) for s in (d, h, w)]
#     mesh_grid_idx = torch.stack(torch.meshgrid(mesh_grid_params), dim=-1).squeeze().to(torch.float32)
    
#     return mesh_grid_idx

# def get_grid_VF(X):
        
#     V_ = []
#     VF_ = []

#     V = compute_grid(X)
#     V = V.reshape(-1, 3).to(X.device)

#     for x in X:
#         x = x.reshape(-1, 1)
        
#         V_filtered = V[x.squeeze() > 0]
#         F_filtered = x[x.squeeze() > 0]
#         # V_filtered = V
#         # F_filtered = x
#         V_.append(V_filtered)

#         # if hasattr(self.hparams, "use_v") and self.hparams.use_v:
#         #     F_filtered = torch.cat([V_filtered, F_filtered], dim=-1)

#         VF_.append(F_filtered)

#     V = pad_sequence(V_, batch_first=True, padding_value=0.0) 
#     VF = pad_sequence(VF_, batch_first=True, padding_value=0.0)

#     return V, VF


# V_seg, VF_seg = get_grid_VF(seg.unsqueeze(0).unsqueeze(0).cuda())
# print(V_seg.shape, VF_seg.shape)    

In [None]:
# Split CSVs: the "_train_test" split of the training set serves as validation.
csv_train = "/mnt/raid/C1_ML_Analysis/raid_data_folder/fluid_annotation_dcm_LABOR_nrrd_resampled_256_train_train.csv"
csv_val = "/mnt/raid/C1_ML_Analysis/raid_data_folder/fluid_annotation_dcm_LABOR_nrrd_resampled_256_train_test.csv"
csv_test = "/mnt/raid/C1_ML_Analysis/raid_data_folder/fluid_annotation_dcm_LABOR_nrrd_resampled_256_test.csv"

# Root prepended to the relative paths stored in the CSVs.
mount_point = "/mnt/raid/C1_ML_Analysis/raid_data_folder"

# Loader settings and CSV column names.
batch_size, num_workers = 2, 1
img_column, seg_column = "img", 'seg'

dm = FluidDataModule(
    csv_train=csv_train,
    csv_valid=csv_val,
    csv_test=csv_test,
    mount_point=mount_point,
    batch_size=batch_size,
    num_workers=num_workers,
    img_column=img_column,
    seg_column=seg_column,
    drop_last=False,
)
dm.setup()



In [None]:
dl_train = dm.train_dataloader()  # DataLoader over the training split
batch = next(iter(dl_train))      # only this first batch is inspected below

In [None]:
# Each batch entry unpacks into a (coordinates, features) pair — presumably
# padded point sets; confirm against FluidDataModule's collate function.
X_v, X_f = batch["img"]
print(X_v.shape, X_f.shape)
Y_v, Y_f = batch["seg"]
print(Y_v.shape, Y_f.shape)

In [None]:
# 3-D scatter of a random subset of one sample's segmentation points, coloured
# by label value.
# NOTE(review): assumes Y_v is (batch, n_points, 3); the axis mapping
# (x <- [..., 2], y <- [..., 1], z <- [..., 0]) follows the original cell.
SN = 0          # sample index within the batch
N = 100000      # maximum number of points to plot
SEED = 0        # fixed seed so re-running the cell plots the same subset

generator = torch.Generator().manual_seed(SEED)
random_indices = torch.randperm(Y_f.size(1), generator=generator)[:N]

coords = Y_v[SN, random_indices].cpu().numpy()
labels = Y_f[SN, random_indices].cpu().numpy().squeeze()

fig = go.Figure(data=[go.Scatter3d(
    x=coords[:, 2],
    y=coords[:, 1],
    z=coords[:, 0],
    mode='markers',
    marker=dict(size=2, color=labels, colorscale='jet', opacity=0.5),
)])
fig.update_layout(
    title=f"Segmentation points — batch sample {SN}",
    scene=dict(xaxis_title="coord[2]", yaxis_title="coord[1]", zaxis_title="coord[0]"),
)
fig.show()

In [None]:

# fig = create_figure(img_array, seg_array)
# fig.show()



# px.imshow(img_array, animation_frame=0, binary_string=True).show()


In [None]:

# img_np, head = nrrd.read(os.path.join(mount_point, "nrrd_resampled_256/1.2.840.113619.2.323.107629510516.1572949987.24.512.225.nrrd"), index_order='C')
# Browse one resampled NRRD volume frame by frame (animation over axis 0).
# index_order='C' returns the array axes in C order rather than NRRD's native order.
# Other file inspected here previously:
#   nrrd_resampled_256/1.2.840.113619.2.323.107629510516.1572949987.24.512.225.nrrd
img_np, head = nrrd.read(
    os.path.join(mount_point, "nrrd_resampled_256/1.2.840.113619.2.323.107629510517.1650475705.9307.512.37_seg.nrrd"),
    index_order='C',
)

px.imshow(img_np, animation_frame=0, binary_string=True).show()

In [None]:
# for x in ds:
#     print(x[0].shape, x[1].shape)