In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
from skimage import io
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import monai
from monai.networks import one_hot
import sys
sys.path.append('./modified_medsam_repo')
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from utils.SurfaceDice import compute_dice_coefficient
from skimage import io, transform
from glob import glob
from sklearn.model_selection import train_test_split
import pandas as pd
import nibabel as nib
import pickle
from torch.utils.data import RandomSampler
import random
import scipy
import torch.nn.functional as F
import img2pdf
from torchmetrics import F1Score

from MedSAM_HCP.dataset import MRIDataset, load_datasets
from MedSAM_HCP.MedSAM import MedSAM, medsam_inference
from MedSAM_HCP.build_sam import build_sam_vit_b_multiclass
from MedSAM_HCP.utils_hcp import *
from PIL import Image

# Seed every global RNG imported above (Python `random`, NumPy, PyTorch) so a
# Restart & Run All reproduces the same results.
SEED = 2023
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Locations of the per-slice path table and the saved train/val/test split.
path_df_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/path_df_constant_bbox.csv'
train_test_splits_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/train_val_test_split.pickle'

# Label mapping tables; NUM_CLASSES is one class per row of the desired table.
# (presumably LabelConverter maps raw HCP label ids onto the desired classes — confirm)
df_hcp = pd.read_csv('/gpfs/home/kn2347/MedSAM/hcp_mapping_processed.csv')
df_desired = pd.read_csv('/gpfs/home/kn2347/MedSAM/darts_name_class_mapping_processed.csv')
NUM_CLASSES = len(df_desired)
label_converter = LabelConverter(df_hcp, df_desired)

train_dataset, val_dataset, test_dataset = load_datasets(
    path_df_path,
    train_test_splits_path,
    label_id=None,
    bbox_shift=0,
    sample_n_slices=None,
    label_converter=label_converter,
    NUM_CLASSES=NUM_CLASSES,
)

In [12]:
# Sanity check: smallest pixel value of training image 128.
sample_image = train_dataset.load_image(128)
sample_image.min()

0

In [None]:
def proc_arr(arr):
    """Find, for every row, the first and last column index holding the value 1.

    Parameters
    ----------
    arr : np.ndarray, shape (classes, N)
        Per-class occupancy profile, e.g. the boolean output of ``np.any``
        collapsed over one spatial axis.

    Returns
    -------
    (starts, ends) : tuple of two float64 arrays, each of shape (classes,)
        ``starts[k]`` / ``ends[k]`` are the first / last index along axis 1 at
        which row ``k`` reaches its maximum.  Rows whose maximum is not equal
        to 1 (for boolean input: rows with no ``True``) get NaN in both.
    """
    n_cols = arr.shape[1]

    # argmax returns the first occurrence of the maximum; running it on the
    # column-reversed array and mirroring the index gives the last occurrence.
    first_hit = np.argmax(arr, axis=1)
    last_hit = (n_cols - 1) - np.argmax(arr[:, ::-1], axis=1)

    # argmax of an all-zero row is 0, so mask those rows out explicitly.
    occupied = np.max(arr, axis=1) == 1
    starts = np.where(occupied, first_hit, np.nan)
    ends = np.where(occupied, last_hit, np.nan)
    return starts, ends


def get_bounding_box(seg_tens):
    """Compute the tight per-class bounding box of a stack of masks.

    Parameters
    ----------
    seg_tens : np.ndarray, shape (classes, H, W)
        One mask per class; any non-zero value counts as foreground.

    Returns
    -------
    np.ndarray, shape (4, classes)
        Rows are ``rmin, rmax, cmin, cmax`` (inclusive pixel indices).  All
        four entries are NaN for a class whose mask is empty.
    """
    # Project each mask onto its axes: which rows / columns contain foreground.
    row_profile = np.any(seg_tens, axis=2)  # (classes, H)
    col_profile = np.any(seg_tens, axis=1)  # (classes, W)

    row_bounds = proc_arr(row_profile)  # (rmin, rmax)
    col_bounds = proc_arr(col_profile)  # (cmin, cmax)
    return np.array((*row_bounds, *col_bounds))

def conv_format(seg_tens):
    """Convert per-class masks into ``(x, y, width, height)`` boxes.

    Parameters
    ----------
    seg_tens : np.ndarray or torch.Tensor, shape (classes, H, W)
        Per-class masks.  A ``torch.Tensor`` is moved to CPU, detached and
        converted to NumPy first.

    Returns
    -------
    np.ndarray, shape (classes, 4)
        Columns are ``x = cmin``, ``y = rmin``, ``width = cmax - cmin`` and
        ``height = rmax - rmin`` in pixel units, with ``(x, y)`` the top-left
        corner.  Width/height are max-minus-min, so a one-pixel object has
        width 0.  Rows belonging to classes with an empty mask are all NaN.
    """
    masks = seg_tens
    if type(masks) is torch.Tensor:
        masks = masks.cpu().detach().numpy()

    rmin, rmax, cmin, cmax = get_bounding_box(masks)
    return np.vstack((cmin, rmin, cmax - cmin, rmax - rmin)).T

In [40]:
# Toy check of conv_format: class 0 has an empty mask, class 1 covers rows 1-2
# and columns 1-3 of a 4x4 grid.
arr = np.zeros((2, 4, 4))
arr[1, 1:3, 1:4] = 1

demo_boxes = conv_format(torch.Tensor(arr))
print(demo_boxes)

# Expected: an all-NaN row for the empty class; otherwise
# (x, y, width, height) = (cmin, rmin, cmax - cmin, rmax - rmin) = (1, 1, 2, 1).
# assert_array_equal treats NaNs in matching positions as equal.
np.testing.assert_array_equal(
    demo_boxes,
    np.array([[np.nan] * 4, [1.0, 1.0, 2.0, 1.0]]),
)

[[nan nan nan nan]
 [ 1.  1.  2.  1.]]


In [None]:
# Export the training split in YOLO format:
#   labels/<slice>.txt : one "class x_center y_center width height" line per class
#                        present in the slice, coordinates normalised to [0, 1]
#                        by the slice size (the overlay check below reads them so);
#   images/<slice>.png : the slice replicated to 3 channels.
train_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/hcp_ya_slices_npy/dir_structure_for_yolov7/train/labels/'
train_img_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/hcp_ya_slices_npy/dir_structure_for_yolov7/train/images/'
os.makedirs(train_path, exist_ok=True)
os.makedirs(train_img_path, exist_ok=True)

for i in tqdm(range(len(train_dataset))):
    _, seg_tens, _, img_slice_name = train_dataset[i]  # seg_tens: per-class masks (classes, H, W)
    img_h, img_w = (int(d) for d in seg_tens.shape[-2:])

    # conv_format -> (classes, 4): x=cmin, y=rmin, width=cmax-cmin, height=rmax-rmin;
    # rows are NaN for classes absent from this slice.
    bbox = conv_format(seg_tens)
    img_slc = img_slice_name.split('.npy')[0]  # e.g. '100206_slice0.npy' -> '100206_slice0'

    label_lines = []
    for class_num, (x, y, w, h) in enumerate(bbox):
        if np.isnan(x):
            continue  # class not present in this slice
        # +1 turns inclusive max-minus-min extents into pixel counts.
        box_w, box_h = w + 1, h + 1
        label_lines.append(
            f'{class_num} {(x + box_w / 2) / img_w:.6f} {(y + box_h / 2) / img_h:.6f} '
            f'{box_w / img_w:.6f} {box_h / img_h:.6f}\n'
        )
    with open(os.path.join(train_path, img_slc + '.txt'), 'w') as f:
        f.writelines(label_lines)

    # Replicate the grayscale slice into 3 channels and save it next to the labels.
    # NOTE(review): assumes load_image(i) is the (H, W) uint8 slice matching
    # train_dataset[i]; Image.fromarray may reject other dtypes for RGB — confirm.
    img_itself = np.repeat(train_dataset.load_image(i)[:, :, None], 3, axis=-1)  # (H, W, 3)
    Image.fromarray(img_itself).save(os.path.join(train_img_path, img_slc + '.png'))
    


In [None]:
# Export the validation split in YOLO format:
#   labels/<slice>.txt : one "class x_center y_center width height" line per class
#                        present in the slice, coordinates normalised to [0, 1]
#                        by the slice size;
#   images/<slice>.png : the slice replicated to 3 channels, so the split has an
#                        images/ folder to pair with its labels.
val_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/hcp_ya_slices_npy/dir_structure_for_yolov7/val/labels/'
val_img_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/hcp_ya_slices_npy/dir_structure_for_yolov7/val/images/'
os.makedirs(val_path, exist_ok=True)
os.makedirs(val_img_path, exist_ok=True)

for i in tqdm(range(len(val_dataset))):
    _, seg_tens, _, img_slice_name = val_dataset[i]  # seg_tens: per-class masks (classes, H, W)
    img_h, img_w = (int(d) for d in seg_tens.shape[-2:])

    # conv_format -> (classes, 4): x=cmin, y=rmin, width=cmax-cmin, height=rmax-rmin;
    # rows are NaN for classes absent from this slice.
    bbox = conv_format(seg_tens)
    img_slc = img_slice_name.split('.npy')[0]  # e.g. '100206_slice0.npy' -> '100206_slice0'

    label_lines = []
    for class_num, (x, y, w, h) in enumerate(bbox):
        if np.isnan(x):
            continue  # class not present in this slice
        # +1 turns inclusive max-minus-min extents into pixel counts.
        box_w, box_h = w + 1, h + 1
        label_lines.append(
            f'{class_num} {(x + box_w / 2) / img_w:.6f} {(y + box_h / 2) / img_h:.6f} '
            f'{box_w / img_w:.6f} {box_h / img_h:.6f}\n'
        )
    with open(os.path.join(val_path, img_slc + '.txt'), 'w') as f:
        f.writelines(label_lines)

    # NOTE(review): assumes load_image(i) is the (H, W) uint8 slice matching
    # val_dataset[i]; Image.fromarray may reject other dtypes for RGB — confirm.
    img_itself = np.repeat(val_dataset.load_image(i)[:, :, None], 3, axis=-1)  # (H, W, 3)
    Image.fromarray(img_itself).save(os.path.join(val_img_path, img_slc + '.png'))
    
    


In [None]:
# Export the test split in YOLO format:
#   labels/<slice>.txt : one "class x_center y_center width height" line per class
#                        present in the slice, coordinates normalised to [0, 1]
#                        by the slice size;
#   images/<slice>.png : the slice replicated to 3 channels, so the split has an
#                        images/ folder to pair with its labels.
test_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/hcp_ya_slices_npy/dir_structure_for_yolov7/test/labels/'
test_img_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/hcp_ya_slices_npy/dir_structure_for_yolov7/test/images/'
os.makedirs(test_path, exist_ok=True)
os.makedirs(test_img_path, exist_ok=True)

for i in tqdm(range(len(test_dataset))):
    _, seg_tens, _, img_slice_name = test_dataset[i]  # seg_tens: per-class masks (classes, H, W)
    img_h, img_w = (int(d) for d in seg_tens.shape[-2:])

    # conv_format -> (classes, 4): x=cmin, y=rmin, width=cmax-cmin, height=rmax-rmin;
    # rows are NaN for classes absent from this slice.
    bbox = conv_format(seg_tens)
    img_slc = img_slice_name.split('.npy')[0]  # e.g. '100206_slice0.npy' -> '100206_slice0'

    label_lines = []
    for class_num, (x, y, w, h) in enumerate(bbox):
        if np.isnan(x):
            continue  # class not present in this slice
        # +1 turns inclusive max-minus-min extents into pixel counts.
        box_w, box_h = w + 1, h + 1
        label_lines.append(
            f'{class_num} {(x + box_w / 2) / img_w:.6f} {(y + box_h / 2) / img_h:.6f} '
            f'{box_w / img_w:.6f} {box_h / img_h:.6f}\n'
        )
    with open(os.path.join(test_path, img_slc + '.txt'), 'w') as f:
        f.writelines(label_lines)

    # NOTE(review): assumes load_image(i) is the (H, W) uint8 slice matching
    # test_dataset[i]; Image.fromarray may reject other dtypes for RGB — confirm.
    img_itself = np.repeat(test_dataset.load_image(i)[:, :, None], 3, axis=-1)  # (H, W, 3)
    Image.fromarray(img_itself).save(os.path.join(test_img_path, img_slc + '.png'))
    
    


### Sanity check: overlay a written YOLO label file on its exported image

In [8]:
import matplotlib.patches as patches

In [None]:
# Draw the boxes from one written label file on top of its exported image to
# confirm the labels line up with the anatomy.
rt_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/hcp_ya_slices_npy/dir_structure_for_yolov7/train/images/100206_slice128.png'
label_path = '/gpfs/data/luilab/karthik/pediatric_seg_proj/hcp_ya_slices_npy/dir_structure_for_yolov7/train/labels/100206_slice128.txt'

# convert('RGB') copes with grayscale/RGBA PNGs alike; slicing [..., :3] on a
# 2-D grayscale array would instead keep only its first three columns.
im = np.array(Image.open(rt_path).convert('RGB'))
img_h, img_w = im.shape[:2]

fig, ax = plt.subplots()
ax.imshow(im)
ax.set_title(os.path.basename(rt_path))

with open(label_path, 'r') as file:
    lines = file.readlines()
print(lines)

# Each line: "class x_center y_center width height", normalised to [0, 1].
for line in lines:
    fields = line.split()
    if not fields:
        continue  # tolerate blank lines
    center_x, center_y, width, height = (float(v) for v in fields[1:5])
    width *= img_w
    height *= img_h
    min_x = center_x * img_w - width / 2
    min_y = center_y * img_h - height / 2

    rect = patches.Rectangle((min_x, min_y), width, height, linewidth=1, edgecolor='r', facecolor='none')
    ax.add_patch(rect)