In [1]:
# Notebook that builds, for each label, a DataFrame of the slices containing at least one pixel of that label
# (plus a random fraction of the remaining slices), saved as isolated_path_df_*.csv under
# per_class_isolated_df_new/{label}/{fraction}/. These per-label datasets are compared during training.

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
from tqdm import tqdm
from skimage import transform 
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
import sys
sys.path.append('./modified_medsam_repo')
sys.path.append('../../modified_medsam_repo')
import random
from datetime import datetime
import pandas as pd
import nibabel as nib
import pickle   
from MedSAM_HCP.dataset import LabelConverter
from glob import glob

In [4]:
# Per-slice path table (one row per slice; image/segmentation paths and bbox).
total_df = pd.read_csv('/gpfs/data/luilab/karthik/pediatric_seg_proj/path_df_constant_bbox.csv')

# The two mapping tables LabelConverter needs; the converter is used below to
# turn raw segmentation slices into the compressed label scheme.
df_hcp, df_desired = (
    pd.read_csv('/gpfs/home/kn2347/HCP_MedSAM_project/modified_medsam_repo/hcp_mapping_processed.csv'),
    pd.read_csv('/gpfs/home/kn2347/HCP_MedSAM_project/modified_medsam_repo/darts_name_class_mapping_processed.csv'),
)
label_converter = LabelConverter(df_hcp, df_desired)


In [3]:
# Map every distinct value of a converted segmentation slice to the list of row
# POSITIONS in total_df whose slice contains that value. Positions (not index
# labels) are stored because the next cells select rows with total_df.iloc.
# This reads every slice from disk; the recorded run took ~2h40m.
collector = {}
seg_paths = total_df['segmentation_slice_path']
for pos, seg_path in enumerate(tqdm(seg_paths, total=len(seg_paths))):
    arr = label_converter.hcp_to_compressed(np.load(seg_path))
    # np.unique works on the array directly; set(arr.flatten()) first boxed
    # every pixel into a separate NumPy scalar.
    for label_val in np.unique(arr).tolist():
        collector.setdefault(label_val, []).append(pos)



100%|██████████| 284928/284928 [2:41:41<00:00, 29.37it/s]  


In [31]:
# How many slices contain at least one pixel of label 1.
n_label1_slices = len(collector[1])
n_label1_slices

121003

In [4]:
# For each label, write a table of every slice containing the label plus a
# random `frac_blanks_to_include` fraction of the slices that do not, in both
# the MedSAM (constant bbox) and UNet column layouts.
np.random.seed(2024)
all_positions = np.arange(total_df.shape[0])

frac_list = [0.1]
for frac_blanks_to_include in frac_list:
    for label_num in range(1, 103):
        # Row positions of slices that contain the label. .get() avoids a
        # KeyError for a label that never occurs in any slice.
        label_idxs = collector.get(label_num, [])

        # Positions of slices WITHOUT the label, in ascending order.
        # np.setdiff1d returns a sorted array, so the seeded draw below no longer
        # depends on Python's set iteration order. NOTE: tables written by the
        # earlier set-based version may have drawn a different subset for some labels.
        unused_idxs = np.setdiff1d(all_positions, label_idxs)

        # Random fraction of the label-free slices (without replacement).
        chosen_idxs = np.random.choice(
            a=unused_idxs,
            size=round(frac_blanks_to_include * len(unused_idxs)),
            replace=False,
        ).tolist()

        # Sort so rows keep total_df's original order.
        idxs_for_label = sorted([*label_idxs, *chosen_idxs])
        label_df = total_df.iloc[idxs_for_label, :].reset_index(drop=True)

        targ_dir = f'/gpfs/data/luilab/karthik/pediatric_seg_proj/per_class_isolated_df_new/{label_num}/{frac_blanks_to_include}'
        os.makedirs(targ_dir, exist_ok=True)
        label_df.to_csv(f'{targ_dir}/isolated_path_df_constant_bbox.csv', index=False)

        # UNet layout: subset of columns, image path column renamed.
        unet_version = label_df[['id', 'slice', 'image_path', 'segmentation_slice_path']]
        unet_version = unet_version.rename({'image_path': 'img_slice_path'}, axis=1)
        unet_version.to_csv(f'{targ_dir}/isolated_path_df_unet.csv', index=False)


Unnamed: 0,id,slice,image_embedding_slice_path,segmentation_slice_path,image_path,bbox_0,bbox_1,bbox_2,bbox_3
0,100206,152,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
1,100206,153,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
2,100206,154,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
3,100206,155,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
4,100206,156,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
...,...,...,...,...,...,...,...,...,...
48349,996782,191,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
48350,996782,192,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
48351,996782,193,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
48352,996782,194,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256


In [69]:
# Sanity check: row count of the label-5 UNet table for each blank fraction.
# glob() returns paths in arbitrary filesystem order, so sort by the fraction
# directory name to make the printed order reproducible.
label5_tables = sorted(
    glob('/gpfs/data/luilab/karthik/pediatric_seg_proj/per_class_isolated_df_new/5/*/isolated_path_df_unet.csv'),
    key=lambda p: p.split('/')[-2],
)
for table_path in label5_tables:
    print(table_path.split('/')[-2])
    print(pd.read_csv(table_path).shape[0])
    print('===')

0.05
78906
===
0.01
70232
===
0.001
68280
===
0.25
122279
===
1
284928
===
0.1
89749
===
0
68063
===


Unnamed: 0,id,slice,image_embedding_slice_path,segmentation_slice_path,image_path,bbox_0,bbox_1,bbox_2,bbox_3
0,100206,0,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
1,100206,1,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
2,100206,2,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
3,100206,3,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
4,100206,4,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
...,...,...,...,...,...,...,...,...,...
284923,996782,251,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
284924,996782,252,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
284925,996782,253,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256
284926,996782,254,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/pr...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/se...,/gpfs/data/cbi/hcp/hcp_ya/hcp_ya_slices_npy/di...,0,0,256,256


In [None]:
# Attach per-slice label pixel counts (label1..label102) to every MedSAM path row.
medsam_df = pd.read_csv('/gpfs/data/luilab/karthik/pediatric_seg_proj/path_df_constant_bbox.csv')
# Load class_counts here instead of relying on a later cell having been run first.
class_counts = pd.read_csv('/gpfs/data/luilab/karthik/pediatric_seg_proj/class_counts/class_counts_per_slice.csv')

# Left join: slices absent from class_counts get NaN label counts.
# validate='many_to_one' raises if class_counts repeats an (id, slice) pair,
# which would otherwise duplicate path rows and misalign the saved masks.
merged = medsam_df.merge(class_counts, how='left', on=['id', 'slice'], validate='many_to_one')

In [34]:


# For every label, save a boolean keep-mask aligned row-for-row with `merged`:
# a slice is kept when it has no pixels of the label or at least `threshold`.
# Two variants: (a) the 0.05 quantile (5th percentile) of the label's non-zero
# per-slice counts, (b) a fixed 100-pixel floor.
MASK_ROOT = '/gpfs/data/luilab/karthik/pediatric_seg_proj/per_class_eval_dfs'


def save_keep_mask(counts, threshold, savepath):
    """Save and return the mask ``(counts >= threshold) | (counts == 0)``.

    NaN counts compare False on both sides, so NaN rows are not kept.
    """
    keep = (counts >= threshold) | (counts == 0)
    os.makedirs(os.path.dirname(savepath), exist_ok=True)
    np.save(savepath, keep)
    return keep


for label in range(1, 103):
    counts = merged[f'label{label}']
    positive = counts[counts > 0].to_numpy()
    # np.quantile fails on an empty array. If the label never appears, every
    # non-NaN count is 0, so the threshold does not affect the mask.
    threshold = np.quantile(positive, 0.05) if positive.size else 0
    xx = save_keep_mask(counts, threshold, f'{MASK_ROOT}/quantile5_mask_arrays/label{label}.npy')

for label in range(1, 103):
    xx = save_keep_mask(merged[f'label{label}'], 100, f'{MASK_ROOT}/at_least_100_mask_arrays/label{label}.npy')

In [32]:
# Scratch arithmetic on `xx`; run top-to-bottom, `xx` is the last mask saved by
# the mask cell above. NOTE(review): the meaning of 0.1 and 64 is not stated
# (perhaps a 10% fraction and a batch size of 64) - confirm.
n_kept = xx.sum()
n_kept * 0.1 / 64

435.75156250000003

In [5]:
### Per-slice pixel-count lookup {(id, slice, label): count} for labels 1, 5 and 7
class_counts = pd.read_csv('/gpfs/data/luilab/karthik/pediatric_seg_proj/class_counts/class_counts_per_slice.csv')

# Zip whole columns instead of DataFrame.iterrows, which builds a Series for
# every row; the iterrows version of this cell was interrupted before finishing.
slice_ids = class_counts['id'].tolist()
slice_nums = class_counts['slice'].tolist()
dicto = {}
for label in [1, 5, 7]:
    label_counts = class_counts[f'label{label}'].tolist()
    dicto.update(
        ((sid, snum, label), cnt)
        for sid, snum, cnt in zip(slice_ids, slice_nums, label_counts)
    )

KeyboardInterrupt: 

In [28]:
# Per label: the 0.05 quantile (5th percentile) of the pixel counts over the
# slices that contain the label. Used as a minimum-size cutoff below.
cutoffs = {}
for lab in range(1, 103):
    xy = class_counts[f'label{lab}']
    xy = xy[xy > 0]
    # Guard: np.quantile on an empty selection errors (or yields NaN, depending
    # on NumPy version). With no positive counts the cutoff has no effect.
    cutoffs[lab] = np.quantile(xy, 0.05) if len(xy) else 0

print(cutoffs[7])
# using:
# label 1: 66
# label 5: 20
# label 7: 11

11.0


In [None]:
def operate(pathname, label, cutoff, suffix_name, lookup=None,
            out_root='/gpfs/data/luilab/karthik/pediatric_seg_proj/per_class_eval_dfs'):
    """Keep only slices with zero or at least ``cutoff`` pixels of ``label``; write the result.

    Args:
        pathname: CSV with at least ``id`` and ``slice`` columns.
        label: Label number, looked up as ``(id, slice, label)`` in ``lookup``.
        cutoff: Minimum pixel count for a slice that contains the label to be kept.
        suffix_name: Appended to the input file's name (before ``.csv``).
        lookup: Mapping ``(id, slice, label) -> pixel count``. Defaults to the
            notebook-global ``dicto`` (the original behaviour).
        out_root: Directory under which ``{label}/`` is created.

    Returns:
        The written path, ``{out_root}/{label}/{stem}{suffix_name}.csv``.

    Raises:
        KeyError: if an ``(id, slice)`` of the CSV is missing from ``lookup``.
    """
    if lookup is None:
        lookup = dicto
    df = pd.read_csv(pathname)

    # Zip the key columns rather than iterrows (which builds a Series per row).
    counts = [lookup[(sid, snum, label)] for sid, snum in zip(df['id'], df['slice'])]
    # A bool ndarray (not a Python list) so an empty frame is filtered by rows;
    # df[[]] would select zero *columns* and the header would be lost.
    keep = np.array([c >= cutoff or c == 0 for c in counts], dtype=bool)
    df = df.loc[keep].reset_index(drop=True)

    rt = f'{out_root}/{label}'
    os.makedirs(rt, exist_ok=True)

    prefix = pathname.split('/')[-1].split('.csv')[0]
    to_path = f'{rt}/{prefix}{suffix_name}.csv'
    df.to_csv(to_path, index=False)
    return to_path

# For labels 1, 5 and 7, write filtered copies of both the MedSAM and the UNet
# path tables under two thresholds: the per-label cutoff (0.05 quantile, i.e.
# the 5th percentile, of the label's non-zero counts) and a fixed 100 pixels.
path_tables = [
    '/gpfs/data/luilab/karthik/pediatric_seg_proj/path_df_constant_bbox.csv',
    '/gpfs/data/luilab/karthik/pediatric_seg_proj/path_df_unet.csv',
]
for label in [1, 5, 7]:
    threshold_variants = [
        (cutoffs[label], '_filter_bottom_5%'),
        (100, '_filter_under_100'),
    ]
    for threshold, suffix in threshold_variants:
        for table in path_tables:
            operate(table, label, threshold, suffix)


    

In [6]:
# Inspect the per-slice pixel-count table (id, slice, label1..label102 columns).
class_counts

Unnamed: 0,id,slice,label1,label2,label3,label4,label5,label6,label7,label8,...,label93,label94,label95,label96,label97,label98,label99,label100,label101,label102
81,100206,81,146,0,0,0,0,0,0,0,...,197,20,0,0,278,40,0,0,0,0
82,100206,82,212,0,0,0,0,0,0,0,...,233,28,0,0,285,38,0,0,0,0
83,100206,83,292,0,0,0,0,0,0,0,...,267,36,0,0,301,96,0,0,0,0
84,100206,84,352,0,0,0,0,0,0,0,...,279,46,0,0,351,143,0,0,0,0
85,100206,85,412,0,0,0,0,0,0,0,...,333,50,0,0,393,186,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
284851,989987,179,278,0,0,643,1347,0,0,0,...,0,0,0,0,0,0,0,0,0,0
284852,989987,180,233,0,0,756,1267,0,0,0,...,0,0,0,0,0,0,0,0,0,0
284853,989987,181,179,0,0,906,1092,0,0,0,...,0,0,0,0,0,0,0,0,0,0
284854,989987,182,132,0,0,814,1118,0,0,0,...,0,0,0,0,0,0,0,0,0,0
