In [1]:
# Re-import edited modules (e.g. the local MedSAM_HCP package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
from tqdm import tqdm
from skimage import transform
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
import monai
import sys
sys.path.append('./modified_medsam_repo')
from segment_anything import sam_model_registry
import torch.nn.functional as F
import argparse
import random
from datetime import datetime
import shutil
import glob
import pandas as pd
import nibabel as nib
import pickle
import time

from MedSAM_HCP.dataset import MRIDataset, load_datasets
from MedSAM_HCP.MedSAM import MedSAM, medsam_inference
from MedSAM_HCP.build_sam import build_sam_vit_b_multiclass
from MedSAM_HCP.utils_hcp import *

In [12]:
# Build the "desired" (compressed) label mapping: one row per region name with an integer id,
# where the catch-all 'Unknown' class is pinned to id 0.
# NOTE(review): pd.read_pickle can execute arbitrary code -- only load this trusted, locally produced file.
list_tups = pd.read_pickle('/gpfs/home/kn2347/MedSAM/darts_name_class_mapping_raw.p')

# each entry is a (region_index, region_name) pair
region_index = [x[0] for x in list_tups]
region_name = [x[1] for x in list_tups]
df = pd.DataFrame({'region_name':region_name, 'region_index':region_index})

# the raw mapping names one class 'None'; expose it as 'Unknown'
df.loc[df['region_name']=='None', 'region_name'] = 'Unknown'

unknown_mask = df['region_name'] == 'Unknown'
# Require exactly one 'Unknown' row so the comparison below is on a single scalar
# (zero or several rows would otherwise surface as an ambiguous-truth-value error).
n_unknown = int(unknown_mask.sum())
assert n_unknown == 1, f"expected exactly one 'Unknown' row, found {n_unknown}"

# Reserve id 0 for 'Unknown': shift every id up by one, then pin 'Unknown' to 0.
if df.loc[unknown_mask, 'region_index'].item() != 0:
    df['region_index'] += 1
    df.loc[unknown_mask, 'region_index'] = 0

assert df['region_index'].is_unique, 'region ids must be unique after re-indexing'

df_desired = df
df_desired.to_csv('/gpfs/home/kn2347/MedSAM/darts_name_class_mapping_processed.csv', index=False)
print(df_desired)

                      region_name  region_index
0      Left-Cerebral-White-Matter             1
1          Left-Lateral-Ventricle             2
2               Left-Inf-Lat-Vent             3
3    Left-Cerebellum-White-Matter             4
4          Left-Cerebellum-Cortex             5
..                            ...           ...
98        ctx-rh-superiortemporal            99
99           ctx-rh-supramarginal           100
100     ctx-rh-transversetemporal           101
101                 ctx-rh-insula           102
102                       Unknown             0

[103 rows x 2 columns]


In [27]:
# Full list of region names in the desired mapping, in row order.
df_desired.loc[:, 'region_name'].tolist()

['Left-Cerebral-White-Matter',
 'Left-Lateral-Ventricle',
 'Left-Inf-Lat-Vent',
 'Left-Cerebellum-White-Matter',
 'Left-Cerebellum-Cortex',
 'Left-Thalamus-Proper',
 'Left-Caudate',
 'Left-Putamen',
 'Left-Pallidum',
 '3rd-Ventricle',
 '4th-Ventricle',
 'Brain-Stem',
 'Left-Hippocampus',
 'Left-Amygdala',
 'CSF',
 'Left-Accumbens-area',
 'Left-VentralDC',
 'Left-vessel',
 'Left-choroid-plexus',
 'Right-Cerebral-White-Matter',
 'Right-Lateral-Ventricle',
 'Right-Inf-Lat-Vent',
 'Right-Cerebellum-White-Matter',
 'Right-Cerebellum-Cortex',
 'Right-Thalamus-Proper',
 'Right-Caudate',
 'Right-Putamen',
 'Right-Pallidum',
 'Right-Hippocampus',
 'Right-Amygdala',
 'Right-Accumbens-area',
 'Right-VentralDC',
 'Right-vessel',
 'Right-choroid-plexus',
 'Optic-Chiasm',
 'CC_Posterior',
 'CC_Mid_Posterior',
 'CC_Central',
 'CC_Mid_Anterior',
 'CC_Anterior',
 'ctx-lh-caudalanteriorcingulate',
 'ctx-lh-caudalmiddlefrontal',
 'ctx-lh-cuneus',
 'ctx-lh-entorhinal',
 'ctx-lh-fusiform',
 'ctx-lh-inferio

In [21]:
# now read in hcp mapping: keep only the label name and its integer id, under the same
# column names the desired mapping uses, and save it for later cells.
# sep=r'\s+' is the documented equivalent of the deprecated delim_whitespace=True.
df = (
    pd.read_table('/gpfs/home/kn2347/MedSAM/hcp_mapping_raw.txt', sep=r'\s+')
    [['Label_Name', 'No']]
    .rename(columns={'Label_Name': 'region_name', 'No': 'region_index'})
)
df.to_csv('/gpfs/home/kn2347/MedSAM/hcp_mapping_processed.csv', index=False)
df_hcp = df

In [24]:
# Line up each desired region with its HCP id by region name.
# validate='one_to_one' raises if a region name repeats on either side (which would silently
# duplicate rows); with a left join, desired regions lacking an HCP match get NaN, which the
# assert below reports by name.
xd = pd.merge(
    df_desired, df_hcp,
    on='region_name', how='left',
    suffixes=('_desired', '_hcp'),
    validate='one_to_one',
)
unmatched = xd.loc[xd['region_index_hcp'].isna(), 'region_name'].tolist()
assert not unmatched, f'regions missing from the HCP mapping: {unmatched}'
print(xd)

                      region_name  region_index_desired  region_index_hcp
0      Left-Cerebral-White-Matter                     1                 2
1          Left-Lateral-Ventricle                     2                 4
2               Left-Inf-Lat-Vent                     3                 5
3    Left-Cerebellum-White-Matter                     4                 7
4          Left-Cerebellum-Cortex                     5                 8
..                            ...                   ...               ...
98        ctx-rh-superiortemporal                    99              2030
99           ctx-rh-supramarginal                   100              2031
100     ctx-rh-transversetemporal                   101              2034
101                 ctx-rh-insula                   102              2035
102                       Unknown                     0                 0

[103 rows x 3 columns]


In [5]:
# Load subject 100206's aparc+aseg.mgz and keep its voxel data as a float64 array.
patho = '/gpfs/data/cbi/hcp/hcp_seg/data_orig/100206/mri/aparc+aseg.mgz'
aseg_img = (
    nib.freesurfer.mghformat.MGHImage
    .from_filename(patho)
    .get_fdata()
    .astype(np.float64)
)

In [17]:
# Load subject 100206's T1.mgz and keep its voxel data as a float64 array.
patho = '/gpfs/data/cbi/hcp/hcp_seg/data_orig/100206/mri/T1.mgz'
img = (
    nib.freesurfer.mghformat.MGHImage
    .from_filename(patho)
    .get_fdata()
    .astype(np.float64)
)

In [19]:
# Smallest voxel value in the T1 volume.
np.min(img)

0.0

In [15]:
# Load both processed mappings, build the HCP -> compressed label converter, and spot-check it
# on a handful of HCP label ids.
mapping_paths = {
    'hcp': '/gpfs/home/kn2347/MedSAM/hcp_mapping_processed.csv',
    'desired': '/gpfs/home/kn2347/MedSAM/darts_name_class_mapping_processed.csv',
}
df_hcp = pd.read_csv(mapping_paths['hcp'])
df_desired = pd.read_csv(mapping_paths['desired'])

lc = LabelConverter(df_hcp, df_desired)
sample_hcp_labels = np.array([0, 0, 2, 0, 2035, 498])
lc.hcp_to_compressed(sample_hcp_labels)

array([  0,   0,   1,   0, 102,   0])

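### Test label mapping

Verify that the label mapping round-trips: convert an HCP-labeled segmentation slice to compressed label ids, convert it back to HCP ids, and compare each region's mask with the original slice.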

In [3]:
# Reload the processed mappings from disk and build the HCP <-> compressed label converter.
df_hcp = pd.read_csv('/gpfs/home/kn2347/MedSAM/hcp_mapping_processed.csv')
df_desired = pd.read_csv('/gpfs/home/kn2347/MedSAM/darts_name_class_mapping_processed.csv')

# Guard against stale mapping files: both tables must carry the expected columns, and the desired
# mapping must contain 'Unknown' at id 0, as the mapping-processing cell writes it.
for table_name, table in (('df_hcp', df_hcp), ('df_desired', df_desired)):
    missing_cols = {'region_name', 'region_index'} - set(table.columns)
    assert not missing_cols, f'{table_name} is missing columns {sorted(missing_cols)}'
unknown_ids = df_desired.loc[df_desired['region_name'] == 'Unknown', 'region_index'].tolist()
assert unknown_ids == [0], f"expected 'Unknown' -> 0 in the desired mapping, got {unknown_ids}"

lc = LabelConverter(df_hcp, df_desired)

In [10]:
# Load one HCP segmentation slice (subject 100206, file seg_128.npy), show its shape, and map its
# labels hcp -> compressed -> hcp for the comparison that follows.
seg_slice_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/hcp_ya_slices_npy/segmentation_slices/100206/seg_128.npy'
arr = np.load(seg_slice_path)
print(arr.shape)

comp_version = lc.hcp_to_compressed(arr)
back_to_hcp = lc.compressed_to_hcp(comp_version)

(256, 256)


In [13]:
# Inspect the desired mapping as loaded (pandas truncates the rendered view).
df_desired.loc[:]

Unnamed: 0,region_name,region_index
0,Left-Cerebral-White-Matter,1
1,Left-Lateral-Ventricle,2
2,Left-Inf-Lat-Vent,3
3,Left-Cerebellum-White-Matter,4
4,Left-Cerebellum-Cortex,5
...,...,...
98,ctx-rh-superiortemporal,99
99,ctx-rh-supramarginal,100
100,ctx-rh-transversetemporal,101
101,ctx-rh-insula,102


In [14]:
# Inspect the converter's joined name -> (desired id, hcp id) table.
lookup_table = lc.lookup_table
lookup_table

Unnamed: 0,region_name,region_index_desired,region_index_hcp
0,Left-Cerebral-White-Matter,1,2
1,Left-Lateral-Ventricle,2,4
2,Left-Inf-Lat-Vent,3,5
3,Left-Cerebellum-White-Matter,4,7
4,Left-Cerebellum-Cortex,5,8
...,...,...,...
98,ctx-rh-superiortemporal,99,2030
99,ctx-rh-supramarginal,100,2031
100,ctx-rh-transversetemporal,101,2034
101,ctx-rh-insula,102,2035


In [None]:
# Round-trip check: for every region in the desired mapping, the mask of its HCP label id must be
# identical in the original slice (`arr`) and after hcp -> compressed -> hcp (`back_to_hcp`).
# Columns are looked up by name rather than position, so the check does not depend on the
# column order of the mapping CSVs. keep='first' mirrors taking the first HCP row per name.
hcp_unique = df_hcp.drop_duplicates('region_name', keep='first')
hcp_id_by_name = dict(zip(hcp_unique['region_name'], hcp_unique['region_index']))

mismatched_regions = []
for reg_name in df_desired['region_name']:
    if reg_name not in hcp_id_by_name:
        raise KeyError(f'{reg_name!r} from the desired mapping has no entry in the HCP mapping')
    reg_id_hcp = hcp_id_by_name[reg_name]

    mask_arr = arr == reg_id_hcp
    mask_transf = back_to_hcp == reg_id_hcp

    if not (mask_arr == mask_transf).all():
        mismatched_regions.append(reg_name)

print('regions whose masks differ after the round trip:', mismatched_regions)

# note that the unknown region is the only region that differs - this is expected, because unknown
# also takes on the regions that were in HCP mapping but not in the desired mapping
assert set(mismatched_regions) <= {'Unknown'}, f'unexpected mismatches: {mismatched_regions}'