In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
from skimage import io
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import monai
from monai.networks import one_hot
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from utils.SurfaceDice import compute_dice_coefficient
from skimage import io, transform
from glob import glob
from sklearn.model_selection import train_test_split
import pandas as pd
import nibabel as nib
import pickle
from torch.utils.data import RandomSampler
import random
import scipy
import torch.nn.functional as F
import img2pdf
from torchmetrics import F1Score

from MedSAM_HCP.dataset import MRIDataset, load_datasets
from MedSAM_HCP.MedSAM import MedSAM, medsam_inference
from MedSAM_HCP.build_sam import build_sam_vit_b_multiclass
from MedSAM_HCP.utils_hcp import *
from typing import List

# set seeds
torch.manual_seed(2023)
np.random.seed(2023)

In [12]:
# don't run unless want to regenerate checkpoints
pattern = '/gpfs/data/luilab/karthik/pediatric_seg_proj/results_copied_from_kn2347/second_round_w_bbox_yolov7_finetunes_60epochs_8-20-23/label*/*/medsam_model_best.pth'
for file in glob(pattern):
    prefix = '/'.join(file.split('/')[:-2])
    to_path = os.path.join(prefix, 'best_model_epoch_sam_readable.pth')
    convert_medsam_checkpt_to_readable_for_sam(file, to_save_dir = to_path)

In [2]:
def get_df_for_ids(id_list: List[int], label):
    assert len(id_list) == 1
    path_df_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/path_df_constant_bbox.csv' # gets all the files
    df = pd.read_csv(path_df_path)
    df = df[df['id'].isin(id_list)].reset_index(drop=True)
    
    path_for_bboxes = f'/gpfs/data/luilab/karthik/pediatric_seg_proj/per_class_isolated_df/yolov7/path_df_label{label}_only_with_bbox_yolov7.csv'
    
    df_bboxes = pd.read_csv(path_for_bboxes,index_col=0)
    df_bboxes = df_bboxes[df_bboxes['id'].isin(id_list)].reset_index(drop=True)

    df = df.drop(columns = ['bbox_0', 'bbox_1', 'bbox_2', 'bbox_3'])
    df = df.merge(df_bboxes, how='left', on=['id','slice','image_embedding_slice_path', 'segmentation_slice_path', 'image_path'])
    return df

def df_to_dataset(df, as_one_hot=True, pool_labels=True):
    df_hcp = pd.read_csv('/gpfs/home/kn2347/MedSAM/hcp_mapping_processed.csv')
    df_desired = pd.read_csv('/gpfs/home/kn2347/MedSAM/darts_name_class_mapping_processed.csv')
    NUM_CLASSES = len(df_desired)
    label_converter = LabelConverter(df_hcp, df_desired)

    dataset = MRIDataset(df, None, 0, label_converter = label_converter, NUM_CLASSES=NUM_CLASSES, as_one_hot=as_one_hot, pool_labels=pool_labels)
    return dataset





In [5]:
df = get_df_for_ids([162935],3)
dataset = df_to_dataset(df, as_one_hot=True, pool_labels=False)
#dataset[0][0]

<MedSAM_HCP.dataset.MRIDataset at 0x2aaab4d33490>

In [7]:
dataset[128]

(tensor([[[ 0.0042, -0.0061, -0.0064,  ...,  0.0019,  0.0055,  0.0076],
          [-0.0014, -0.0032, -0.0043,  ...,  0.0064,  0.0081,  0.0103],
          [-0.0027, -0.0020, -0.0025,  ...,  0.0066,  0.0076,  0.0040],
          ...,
          [ 0.0104,  0.0188,  0.0177,  ...,  0.0194,  0.0183,  0.0187],
          [ 0.0104,  0.0185,  0.0185,  ...,  0.0208,  0.0196,  0.0200],
          [-0.0049,  0.0066,  0.0056,  ...,  0.0085,  0.0072,  0.0085]],
 
         [[-0.1364, -0.1092,  0.0378,  ..., -0.0442,  0.0215,  0.0362],
          [-0.1626,  0.0035,  0.0028,  ..., -0.0306, -0.0271, -0.0439],
          [-0.1685, -0.1026, -0.0881,  ..., -0.0748, -0.0442, -0.0085],
          ...,
          [-0.0685, -0.1289, -0.1101,  ..., -0.0591, -0.0833, -0.0531],
          [-0.0895, -0.1075, -0.1017,  ..., -0.0794, -0.0805, -0.0619],
          [-0.0559, -0.1176, -0.0856,  ..., -0.1149, -0.1296, -0.0745]],
 
         [[ 0.0424,  0.0165,  0.0683,  ...,  0.0576,  0.0797,  0.0679],
          [ 0.0223, -0.0314,

In [8]:
# code to generate for round2 models: remember, these are single-task, so we'll have to iterate over all of them to collect models and predictions
def dataset_to_predictions_for_singletask_models(id_list, checkpoint_folder_pattern, save_path, device='cuda'):
    files = glob(checkpoint_folder_pattern)

    collector = []
    ctr = 0
    overall_tensor = torch.zeros((256, 256, 256)).cuda()
    for file in files:
        label_num = int(file.split('/')[-2].split('label')[1])
        if label_num==11:
            continue

        df = get_df_for_ids(id_list, label_num)
        dataset = df_to_dataset(df, as_one_hot=True, pool_labels=False)

        dataloader = DataLoader(
            dataset,
            batch_size = 16,
            shuffle = False,
            num_workers = 0,
            pin_memory = True)
        
        model = build_sam_vit_b_multiclass(3, checkpoint=file).to(device)  
        model.eval()
        
        
        this_label_collector = []
        for step, (image_embedding, gt2D, boxes, slice_names) in enumerate(tqdm(dataloader)):
            image_embedding, gt2D, boxes = image_embedding.to(device), gt2D.to(device), boxes.to(device)
            medsam_pred = torch.as_tensor(medsam_inference(
                model, image_embedding, boxes, 256, 256,
                as_one_hot=True, model_trained_on_multi_label=False,
                num_classes=1), dtype=torch.uint8).to(device)
            # hopefully this is B, 1, H, W as a binary 0/1 tensor
            this_label_collector.append(medsam_pred)
        
        total_tensor = torch.cat(this_label_collector, dim=0)[:,0,:,:] # now should be N, H, W
        overall_tensor[torch.where(total_tensor == 1)] = label_num
        #overall_tensor[:, label_num, :, :] = total_tensor[:, 0, :, :]
        ctr+=1
        print(f'done {ctr}/{len(files)}')
    
    all_np = overall_tensor.cpu().detach().numpy()
    #assert all_np.shape == (len(dataset), len(files), 256, 256)

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    np.save(save_path + '/singletask_seg_all.npy', all_np)

    




# checkpoint_folder_pattern = '/gpfs/data/luilab/karthik/pediatric_seg_proj/results_copied_from_kn2347/second_round_w_bbox_yolov7_finetunes_60epochs_8-20-23/label*/best_model_epoch_sam_readable.pth'

# xd = dataset_to_predictions_for_singletask_models(id_list=[162935], 
#                                             checkpoint_folder_pattern = checkpoint_folder_pattern, 
#                                             save_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/saved_round2_segmentations/162935',
#                                             device='cuda')

In [9]:
# do a val image: 151425


checkpoint_folder_pattern = '/gpfs/data/luilab/karthik/pediatric_seg_proj/results_copied_from_kn2347/second_round_w_bbox_yolov7_finetunes_60epochs_8-20-23/label*/best_model_epoch_sam_readable.pth'
xd2 = dataset_to_predictions_for_singletask_models(id_list=[151425], 
                                            checkpoint_folder_pattern = checkpoint_folder_pattern, 
                                            save_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/saved_round2_segmentations/151425_upd',
                                            device='cuda')

100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 1/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 2/101


100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


done 3/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 4/101


100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


done 5/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 6/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 7/101


100%|██████████| 16/16 [00:14<00:00,  1.11it/s]


done 8/101


100%|██████████| 16/16 [00:15<00:00,  1.01it/s]


done 9/101


100%|██████████| 16/16 [00:18<00:00,  1.13s/it]


done 10/101


100%|██████████| 16/16 [00:15<00:00,  1.07it/s]


done 11/101


100%|██████████| 16/16 [00:15<00:00,  1.06it/s]


done 12/101


100%|██████████| 16/16 [00:14<00:00,  1.09it/s]


done 13/101


100%|██████████| 16/16 [00:14<00:00,  1.08it/s]


done 14/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 15/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 16/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 17/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 18/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 19/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 20/101


100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


done 21/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 22/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 23/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 24/101


100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


done 25/101


100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


done 26/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 27/101


100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


done 28/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 29/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 30/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 31/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 32/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 33/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 34/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 35/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 36/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 37/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 38/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 39/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 40/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 41/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 42/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 43/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 44/101


100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


done 45/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 46/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 47/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 48/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 49/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 50/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 51/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 52/101


100%|██████████| 16/16 [00:14<00:00,  1.11it/s]


done 53/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 54/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 55/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 56/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 57/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 58/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 59/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 60/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 61/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 62/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 63/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 64/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 65/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 66/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 67/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 68/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 69/101


100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


done 70/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 71/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 72/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 73/101


100%|██████████| 16/16 [00:14<00:00,  1.11it/s]


done 74/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 75/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 76/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 77/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 78/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 79/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 80/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 81/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 82/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 83/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 84/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 85/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 86/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 87/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 88/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 89/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 90/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 91/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 92/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 93/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 94/101


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


done 95/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 96/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 97/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 98/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


done 99/101


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]

done 100/101





In [27]:
np.sum(xd.cpu().detach().numpy() == 46)

15276