In [1]:
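# Notebook to create, for each compressed label, a dataframe of every slice that contains at least one
# pixel of that label, plus a random fraction of the remaining slices. Output is written to
# per_class_isolated_df_new/<label>/<frac>/ as isolated_path_df_constant_bbox.csv and isolated_path_df_unet.csv.
# These per-label datasets will be compared against each other during training.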

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
from tqdm import tqdm
from skimage import transform 
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
import sys
sys.path.append('./modified_medsam_repo')
sys.path.append('../../modified_medsam_repo')
import random
from datetime import datetime
import pandas as pd
import nibabel as nib
import pickle   
from MedSAM_HCP.dataset import LabelConverter
from glob import glob

In [2]:
# Slice index: one row per slice, with image / segmentation paths and bbox columns.
total_df = pd.read_csv(
    '/gpfs/data/luilab/karthik/pediatric_seg_proj/path_df_constant_bbox.csv'
)

# Mapping tables consumed by LabelConverter; its hcp_to_compressed method is used
# below to remap raw HCP segmentation values (presumably to a reduced label set).
df_hcp = pd.read_csv(
    '/gpfs/home/kn2347/HCP_MedSAM_project/modified_medsam_repo/hcp_mapping_processed.csv'
)
df_desired = pd.read_csv(
    '/gpfs/home/kn2347/HCP_MedSAM_project/modified_medsam_repo/darts_name_class_mapping_processed.csv'
)
label_converter = LabelConverter(df_hcp, df_desired)


In [3]:
# Inverted index over slices: label value (after hcp_to_compressed) -> list of
# 0-based row positions in total_df whose segmentation slice contains that value.
# Positions come from enumerate (not iterrows labels) so they are valid for .iloc
# regardless of total_df's index.
collector = {}
seg_paths = total_df['segmentation_slice_path']
for row_pos, seg_path in enumerate(tqdm(seg_paths, total=len(seg_paths))):
    seg = label_converter.hcp_to_compressed(np.load(seg_path))
    # np.unique avoids building a Python set over every pixel of the slice.
    for label_value in np.unique(seg):
        # .item() gives a plain Python scalar key, so lookups like collector[1] work.
        collector.setdefault(label_value.item(), []).append(row_pos)



100%|██████████| 284928/284928 [2:41:41<00:00, 29.37it/s]  


In [31]:
# Sanity check: number of slices that contain label 1.
label1_slices = collector[1]
len(label1_slices)

121003

In [4]:
# For every label, write the slices that contain it plus a random fraction
# (`frac_blanks_to_include`) of the slices that do not. Two files per (label, fraction):
# the full path/bbox table and a U-Net-style column subset.
#
# NOTE(review): sampling now uses a per-label Generator over a sorted candidate list, so
# the chosen rows differ from CSVs produced by the earlier global-np.random.seed version.
SEED = 2024
OUT_ROOT = '/gpfs/data/luilab/karthik/pediatric_seg_proj/per_class_isolated_df_new'
LABEL_IDS = range(1, 103)  # labels 1..102; label 0 is not exported

frac_list = [0.1]
all_positions = np.arange(total_df.shape[0])

for frac_blanks_to_include in frac_list:
    for label_num in LABEL_IDS:
        # Row positions (for .iloc) of slices containing this label. A label that
        # never occurs yields an empty selection instead of a KeyError.
        label_positions = np.asarray(collector.get(label_num, []), dtype=np.int64)

        # Slices without this label. setdiff1d returns them sorted, so the draw below
        # does not depend on Python set iteration order.
        unused_idxs = np.setdiff1d(all_positions, label_positions)

        # Seed per label so each label's subset is reproducible on its own, regardless
        # of which labels or fractions were processed before it.
        rng = np.random.default_rng([SEED, label_num])
        n_extra = round(frac_blanks_to_include * len(unused_idxs))
        chosen_idxs = rng.choice(unused_idxs, size=n_extra, replace=False)

        idxs_for_label = np.sort(np.concatenate([label_positions, chosen_idxs]))
        label_df = total_df.iloc[idxs_for_label].reset_index(drop=True)

        targ_dir = f'{OUT_ROOT}/{label_num}/{frac_blanks_to_include}'
        os.makedirs(targ_dir, exist_ok=True)
        label_df.to_csv(f'{targ_dir}/isolated_path_df_constant_bbox.csv', index=False)

        # The U-Net CSV keeps only these columns, with image_path renamed to
        # img_slice_path (presumably the column name the U-Net loader reads).
        unet_version = (
            label_df[['id', 'slice', 'image_path', 'segmentation_slice_path']]
            .rename(columns={'image_path': 'img_slice_path'})
        )
        unet_version.to_csv(f'{targ_dir}/isolated_path_df_unet.csv', index=False)


Unnamed: 0,id,slice,image_embedding_slice_path,segmentation_slice_path,image_path,bbox_0,bbox_1,bbox_2,bbox_3
0,100206,152,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
1,100206,153,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
2,100206,154,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
3,100206,155,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
4,100206,156,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
...,...,...,...,...,...,...,...,...,...
48349,996782,191,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
48350,996782,192,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
48351,996782,193,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
48352,996782,194,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256


In [69]:
# Sanity check: for label 5, print each blank-fraction directory name and the
# row count of its U-Net CSV.
label5_csvs = glob('/gpfs/data/luilab/karthik/pediatric_seg_proj/per_class_isolated_df_new/5/*/isolated_path_df_unet.csv')
for csv_path in label5_csvs:
    frac_dir_name = csv_path.split('/')[-2]
    print(frac_dir_name)
    row_count = len(pd.read_csv(csv_path))
    print(row_count)
    print('===')

0.05
78906
===
0.01
70232
===
0.001
68280
===
0.25
122279
===
1
284928
===
0.1
89749
===
0
68063
===
