In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
from skimage import io
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import monai
from monai.networks import one_hot
import sys
sys.path.append('./modified_medsam_repo')
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from utils.SurfaceDice import compute_dice_coefficient
from skimage import io, transform
from glob import glob
from sklearn.model_selection import train_test_split
import pandas as pd
import nibabel as nib
import pickle
from torch.utils.data import RandomSampler
import random
import scipy
import torch.nn.functional as F
import img2pdf
from torchmetrics import F1Score

from MedSAM_HCP.dataset import MRIDatasetForPooled, load_datasets
from MedSAM_HCP.MedSAM import MedSAM, medsam_inference
from MedSAM_HCP.build_sam import build_sam_vit_b_multiclass
from MedSAM_HCP.utils_hcp import *
from typing import List

# set seeds
torch.manual_seed(2023)
np.random.seed(2023)

In [6]:
def get_df_for_ids(id_list: List[int], is_test=False):
    assert len(id_list) == 1
    path_df_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/path_df_constant_bbox.csv' # gets all the files
    df = pd.read_csv(path_df_path)
    df = df[df['id'].isin(id_list)].reset_index(drop=True)
    
    if is_test:
        path_for_bboxes = f'/gpfs/data/luilab/karthik/pediatric_seg_proj/per_class_isolated_df/yolov7/test/path_df_pooled_labels_only_with_bbox_yolov7_TEST.csv'
    else:
        path_for_bboxes = f'/gpfs/data/luilab/karthik/pediatric_seg_proj/per_class_isolated_df/yolov7/path_df_pooled_labels_only_with_bbox_yolov7.csv'
    
    df_bboxes = pd.read_csv(path_for_bboxes,index_col=0)
    df_bboxes = df_bboxes.drop(columns = ['index'])
    
    df_bboxes = df_bboxes[df_bboxes['id'].isin(id_list)].reset_index(drop=True)
    

    df = df.drop(columns = ['bbox_0', 'bbox_1', 'bbox_2', 'bbox_3'])

    #print(df.shape)
    #print(df)
    #print(df_bboxes.shape)
    #print(df_bboxes)

    df = df.merge(df_bboxes, how='left', on=['id','slice','image_embedding_slice_path', 'segmentation_slice_path', 'image_path'])
    return df

def df_to_dataset(df, as_one_hot=True, pool_labels=True):
    df_hcp = pd.read_csv('/gpfs/home/kn2347/MedSAM/hcp_mapping_processed.csv')
    df_desired = pd.read_csv('/gpfs/home/kn2347/MedSAM/darts_name_class_mapping_processed.csv')
    NUM_CLASSES = len(df_desired)
    label_converter = LabelConverter(df_hcp, df_desired)

    dataset = MRIDatasetForPooled(df, 1, 0, label_converter = label_converter, NUM_CLASSES=NUM_CLASSES, as_one_hot=as_one_hot, pool_labels=pool_labels)
    return dataset





In [3]:
#df = get_df_for_ids([162935], is_test=True)
#dataset = df_to_dataset(df, as_one_hot=True, pool_labels=False)
#dataset[0][0]

(256, 5)
         id  slice                         image_embedding_slice_path  \
0    162935      0  /gpfs/data/luilab/karthik/pediatric_seg_proj/h...   
1    162935      1  /gpfs/data/luilab/karthik/pediatric_seg_proj/h...   
2    162935      2  /gpfs/data/luilab/karthik/pediatric_seg_proj/h...   
3    162935      3  /gpfs/data/luilab/karthik/pediatric_seg_proj/h...   
4    162935      4  /gpfs/data/luilab/karthik/pediatric_seg_proj/h...   
..      ...    ...                                                ...   
251  162935    251  /gpfs/data/luilab/karthik/pediatric_seg_proj/h...   
252  162935    252  /gpfs/data/luilab/karthik/pediatric_seg_proj/h...   
253  162935    253  /gpfs/data/luilab/karthik/pediatric_seg_proj/h...   
254  162935    254  /gpfs/data/luilab/karthik/pediatric_seg_proj/h...   
255  162935    255  /gpfs/data/luilab/karthik/pediatric_seg_proj/h...   

                               segmentation_slice_path  \
0    /gpfs/data/luilab/karthik/pediatric_seg_proj/h...  

In [37]:
# code to generate for pooled model
def dataset_to_predictions_for_pooled_model(id_list, checkpoint, save_path, device='cuda', is_test=False):

    df = get_df_for_ids(id_list, is_test=is_test)
    dataset = df_to_dataset(df, as_one_hot=True, pool_labels=True)
    dataloader = DataLoader(
            dataset,
            batch_size = 16,
            shuffle = False,
            num_workers = 0,
            pin_memory = True)

    model = build_sam_vit_b_multiclass(103, checkpoint=checkpoint).to(device)

    collector = []
    ctr = 0
    overall_tensor = torch.zeros((256, 256, 256)).cuda()
        
    pred_collector = []
    label_num_collector = []
    slice_collector = []

    viz = {}
    print(len(dataloader))
    for step, (image_embedding, gt2D, boxes, slice_names, label_nums) in enumerate(tqdm(dataloader)):
        image_embedding, gt2D, boxes = image_embedding.to(device), gt2D.to(device), boxes.to(device)
        medsam_pred = torch.as_tensor(medsam_inference(
            model, image_embedding, boxes, 256, 256,
            as_one_hot=True, model_trained_on_multi_label=False,
            num_classes=1), dtype=torch.uint8).to(device)
        # hopefully this is B, 1, H, W as a binary 0/1 tensor
        pred_collector.append(medsam_pred)
        label_num_collector.append(label_nums)

        slice_nums_list = [int(x.split('_slice')[1].split('.npy')[0]) for x in slice_names]
        slice_collector.extend(slice_nums_list)
    print('done all')
    print(pred_collector[0].shape)
    print(label_num_collector[0].shape)
    total_tensor = torch.cat(pred_collector, dim=0)[:,0,:,:] # now should be N, H, W
    total_labels = torch.cat(label_num_collector, dim=0)
    for i in range(total_tensor.shape[0]):
        this_label_num = total_labels[i].item()
        if np.isnan(this_label_num):
            continue
        

        this_slice_num = slice_collector[i]
        where_mask = torch.where(total_tensor[this_slice_num,:,:] == 1)
        overall_tensor[ [this_slice_num] * where_mask[0].shape[0], where_mask[0], where_mask[1]] = this_label_num
        if this_label_num not in viz:
            viz[this_label_num] = True
            print(f'hello at {this_label_num} {torch.where(total_tensor[this_slice_num,:,:] == 1)[0].shape}')
    
    
    all_np = overall_tensor.cpu().detach().numpy()
    #assert all_np.shape == (len(dataset), len(files), 256, 256)

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    np.save(save_path + '/singletask_seg_all.npy', all_np)

    return all_np






checkpoint = '/gpfs/data/luilab/karthik/pediatric_seg_proj/results_copied_from_kn2347/pooled_labels_ckpt_continue_8-22-23/model_best_20230822-115028.pth'

xd = dataset_to_predictions_for_pooled_model(id_list=[162935], 
                                            checkpoint = checkpoint, 
                                            save_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/pooled_seg_gifs/162935',
                                            device='cuda',
                                            is_test=True)

233


100%|██████████| 233/233 [00:41<00:00,  5.66it/s]


done all
torch.Size([16, 1, 256, 256])
torch.Size([16])
hello at 91.0 torch.Size([40])
hello at 93.0 torch.Size([40])
hello at 98.0 torch.Size([40])
hello at 62.0 torch.Size([38])
hello at 86.0 torch.Size([38])
hello at 55.0 torch.Size([37])
hello at 60.0 torch.Size([37])
hello at 67.0 torch.Size([37])
hello at 94.0 torch.Size([37])
hello at 97.0 torch.Size([37])
hello at 20.0 torch.Size([28])
hello at 1.0 torch.Size([35])
hello at 63.0 torch.Size([35])
hello at 66.0 torch.Size([75])
hello at 42.0 torch.Size([68])
hello at 73.0 torch.Size([68])
hello at 100.0 torch.Size([68])
hello at 77.0 torch.Size([79])
hello at 69.0 torch.Size([82])
hello at 46.0 torch.Size([55])
hello at 96.0 torch.Size([127])
hello at 61.0 torch.Size([81])
hello at 65.0 torch.Size([81])
hello at 92.0 torch.Size([81])
hello at 41.0 torch.Size([58])
hello at 48.0 torch.Size([128])
hello at 79.0 torch.Size([128])
hello at 72.0 torch.Size([17])
hello at 43.0 torch.Size([108])
hello at 87.0 torch.Size([109])
hello at 

In [42]:
for xz in np.unique(xd):
    print(xz, (xd==xz).sum())

0.0 16756038
12.0 3926
24.0 2564
45.0 120
47.0 133
78.0 800
84.0 1548
98.0 937
99.0 2377
100.0 2913
102.0 5860


In [39]:
np.unique(xd)

array([  0.,  12.,  24.,  45.,  47.,  78.,  84.,  98.,  99., 100., 102.],
      dtype=float32)

In [36]:
xd = np.load('/gpfs/data/luilab/karthik/pediatric_seg_proj/pooled_seg_gifs/162935/singletask_seg_all.npy')
np.unique(xd)

array([  0.,  12.,  24.,  45.,  47.,  78.,  84.,  98.,  99., 100., 102.],
      dtype=float32)

In [None]:
# do a val image: 151425
xd2 = dataset_to_predictions_for_singletask_models(id_list=[151425], 
                                            checkpoint_folder_pattern = checkpoint_folder_pattern, 
                                            dataset = dataset,
                                            save_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/saved_round2_segmentations/151425',
                                            device='cuda')

In [27]:
np.sum(xd.cpu().detach().numpy() == 46)

15276