In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from glob import glob
import pandas as pd
import pickle
from torch.utils.data import RandomSampler
import random
import scipy
import torch.nn.functional as F
from PIL import Image
from glob import glob

from MedSAM_HCP.utils_hcp import *

In [18]:
def read_format_table(path, read_gt = False, img_width = 256, img_height = 256):
    """Load a space-delimited box label file and scale its boxes by the image size.

    Each line holds ``class x_center y_center width height`` and, for
    prediction files, a trailing ``confidence``. The four box columns are
    multiplied by the image size (presumably they are stored as fractions of
    the image size, YOLO style -- confirm against the label writer).

    Args:
        path: Path of the label file to read.
        read_gt: If True the file is read with 5 columns (no ``confidence``);
            otherwise with 6 columns.
        img_width: Factor applied to ``x_center`` and ``width``.
        img_height: Factor applied to ``y_center`` and ``height``.

    Returns:
        pandas.DataFrame sorted by ``class`` with a fresh 0..n-1 index and the
        box columns scaled.
    """
    cols = ['class', 'x_center', 'y_center', 'width', 'height']
    if not read_gt:
        cols.append('confidence')

    df = pd.read_csv(path, delimiter=' ', header=None, names=cols)
    df = df.sort_values('class').reset_index(drop=True)

    # Horizontal quantities scale with the width, vertical ones with the height.
    scale_by_column = (
        ('x_center', img_width),
        ('y_center', img_height),
        ('width', img_width),
        ('height', img_height),
    )
    for col, scale in scale_by_column:
        df[col] = df[col] * scale

    return df
def yolov7_format_to_bbox_format(x_center, y_center, width, height):
    """Turn a centre/size box into its two corners.

    Returns:
        Tuple ``(x0, y0, x1, y1)`` where ``(x0, y0)`` is the centre minus half
        the extent on each axis and ``(x1, y1)`` is the centre plus half.
    """
    half_w = width / 2.0
    half_h = height / 2.0
    return (
        x_center - half_w,
        y_center - half_h,
        x_center + half_w,
        y_center + half_h,
    )

def extract_box_np_from_df(df, label):
    """Return the corner box ``[x0, y0, x1, y1]`` stored for one class label.

    Args:
        df: DataFrame with ``class``, ``x_center``, ``y_center``, ``width`` and
            ``height`` columns (and optionally ``confidence``).
        label: Class value to look up in the ``class`` column.

    Returns:
        numpy array of 4 values. All four are NaN when no row has this label.
        When several rows share the label, a warning is issued and a single
        row is used: the one with the highest ``confidence`` if that column
        exists, otherwise the first one.
    """
    import warnings

    rows = df.loc[df['class'] == label]
    if len(rows) == 0:
        return np.full((4,), np.nan)

    if len(rows) > 1:
        # Previously this case printed the rows and then crashed on `.item()`.
        warnings.warn(f'{len(rows)} rows found for class {label}; using a single row')
        if 'confidence' in rows.columns:
            # Ascending sort with NaNs first, so the last row is the most confident.
            rows = rows.sort_values('confidence', na_position='first').iloc[[-1]]
        else:
            rows = rows.iloc[[0]]

    best = rows.iloc[0]
    return np.array(yolov7_format_to_bbox_format(best['x_center'],
                                                 best['y_center'],
                                                 best['width'],
                                                 best['height']))

def run_tag(collect_lists, tag='train', conf_thresh = 0.225, force_stop = None,
            results_root = '/gpfs/data/luilab/karthik/pediatric_seg_proj/yolov7_results/train_hcp_bbox'):
    """Append the thresholded YOLOv7 boxes of one run to per-class lists.

    Every ``*.txt`` file under ``{results_root}/{tag}_run/labels`` is read in
    sorted order. The subject id and slice number are parsed from the file
    name (``<id>_..._slice<slice>.txt``). Detections below ``conf_thresh`` are
    dropped, only the most confident detection per class is kept, and class 0
    is skipped.

    Args:
        collect_lists: List indexed by class id; each entry is a list that
            receives one ``[id, slice, x0, y0, x1, y1]`` integer array per kept
            detection. It is modified in place.
        tag: Run name; selects the ``{tag}_run`` directory.
        conf_thresh: Minimum confidence (inclusive) for a detection to be kept.
        force_stop: If not None, stop after this many files.
        results_root: Directory that contains the ``{tag}_run`` folders.

    Returns:
        The same ``collect_lists`` object that was passed in.
    """
    label_files = sorted(glob(os.path.join(results_root, f'{tag}_run', 'labels', '*.txt')))

    for ix, file in enumerate(tqdm(label_files)):
        if force_stop is not None and ix >= force_stop:
            break
        basename = os.path.basename(file)
        id_num = int(basename.split('_')[0])
        slice_num = int(basename.split('_slice')[-1].split('.txt')[0])

        dfo = read_format_table(file, read_gt=False)
        dfo = dfo[dfo['confidence'] >= conf_thresh]
        # Ascending sort + keep='last' leaves the most confident row per class.
        dfo = dfo.sort_values('confidence').drop_duplicates('class', keep='last')
        dfo['bbox_0'] = (dfo['x_center'] - dfo['width']/2.0).round().astype(int)
        dfo['bbox_1'] = (dfo['y_center'] - dfo['height']/2.0).round().astype(int)
        dfo['bbox_2'] = (dfo['x_center'] + dfo['width']/2.0).round().astype(int)
        dfo['bbox_3'] = (dfo['y_center'] + dfo['height']/2.0).round().astype(int)

        # Iterate over plain arrays instead of iterrows(): same rows in the same
        # order, without building a Series (and upcasting to float) per row.
        kept = dfo[dfo['class'] != 0]
        class_ids = kept['class'].to_numpy()
        boxes = kept[['bbox_0', 'bbox_1', 'bbox_2', 'bbox_3']].to_numpy()
        for class_id, box in zip(class_ids, boxes):
            ans = np.array([id_num, slice_num, *box]).astype(int)
            collect_lists[int(class_id)].append(ans)

    return collect_lists
    
def merge_collect_lists(collect_lists, label,
                        path_df_csv = '/gpfs/data/luilab/karthik/pediatric_seg_proj/path_df.csv'):
    """Join the boxes collected for one class onto the path table.

    Args:
        collect_lists: List indexed by class id; each entry is a list of
            ``[id, slice, bbox_0, bbox_1, bbox_2, bbox_3]`` integer records.
        label: Class id whose records are merged.
        path_df_csv: CSV file providing at least the ``id`` and ``slice``
            columns to join on.

    Returns:
        DataFrame with one row per collected box (right join on ``id`` and
        ``slice``), carrying the columns of the path table plus the four
        ``bbox_*`` columns.
    """
    bbox_cols = ['id', 'slice', 'bbox_0', 'bbox_1', 'bbox_2', 'bbox_3']
    ori_df = pd.read_csv(path_df_csv)
    # Index by the `label` argument (the body used to read an unrelated global).
    # The reshape keeps an empty class as a (0, 6) integer table so the join
    # keys keep an integer dtype.
    records = np.asarray(collect_lists[label], dtype=int).reshape(-1, len(bbox_cols))
    this_class_df = pd.DataFrame(records, columns=bbox_cols)
    merged = ori_df.merge(this_class_df, how = 'right', on = ['id', 'slice'])
    return merged

def run_medsam_tag(collect_lists, tag='train', force_stop = None):
    """Append the saved round-1 boxes of every labelled slice to per-class lists.

    The slices to visit are taken from the YOLOv7 label files of the
    ``{tag}_run`` directory (sorted by file name); for each one the matching
    ``{id}/{slice}.npy`` array is loaded. Row ``i`` of that array is treated as
    the box of class ``i``. Row 0 and rows containing NaN are skipped.

    Args:
        collect_lists: List indexed by class id; each entry receives one
            ``[id, slice, ...4 box values]`` integer array per valid row.
            Modified in place.
        tag: Run name; selects the ``{tag}_run`` label directory.
        force_stop: If not None, stop after this many label files.

    Returns:
        The same ``collect_lists`` object that was passed in.
    """
    label_glob = f'/gpfs/data/luilab/karthik/pediatric_seg_proj/yolov7_results/train_hcp_bbox/{tag}_run/labels/*.txt'
    for ix, label_file in enumerate(tqdm(sorted(glob(label_glob)))):
        if force_stop is not None and ix >= force_stop:
            break

        stem = os.path.basename(label_file)
        id_num = int(stem.split('_')[0])
        slice_num = int(stem.split('_slice')[-1].split('.txt')[0])

        boxes = np.load(f'/gpfs/data/luilab/karthik/pediatric_seg_proj/saved_round1_segmentations_bbox/{id_num}/{slice_num}.npy')
        # Swap the two middle columns (presumably x0,x1,y0,y1 -> x0,y0,x1,y1;
        # confirm against the code that wrote these arrays).
        boxes = boxes[:, [0, 2, 1, 3]]
        has_nan = np.isnan(boxes).any(axis=1)

        # Start at 1: row 0 is never collected.
        for class_id in range(1, boxes.shape[0]):
            if has_nan[class_id]:
                continue
            record = np.array([id_num, slice_num, *boxes[class_id]]).astype(int)
            collect_lists[class_id].append(record)

    return collect_lists

In [24]:
# NUM_CLASSES is first needed here but is otherwise only assigned in a later
# cell; define it at first use so a top-to-bottom run does not raise NameError.
NUM_CLASSES = 103

# One (initially empty) list of box records per class id.
collect_lists_medsam = [[] for _ in range(NUM_CLASSES)]
collect_lists_medsam = run_medsam_tag(collect_lists_medsam, 'train', force_stop=10000)

# Total number of boxes over the first 10000 label files, class 0 excluded.
true_lens = sum(len(class_boxes) for class_boxes in collect_lists_medsam[1:])
print(f'medsam bboxes have total calls ~{true_lens}')

  4%|▍         | 10000/228096 [00:10<03:51, 943.58it/s]


In [41]:
# Re-display the MedSAM box total; `true_lens` must already have been computed
# by the cell that calls run_medsam_tag.
print(f'medsam bboxes have total calls ~{true_lens}')

medsam bboxes have total calls ~139977


In [38]:
# Confidence-threshold sweep over the first 10000 'train' label files.
# Each run gets its own fresh set of empty per-class lists because run_tag
# appends to the lists it is given. Despite the "_lens" suffix, each result is
# the list of per-class box lists, not a count. The generator is consumed in
# order, so the runs execute sequentially from 0.15 to 0.30.
(yolov7_lens_15,
 yolov7_lens_20,
 yolov7_lens_225,
 yolov7_lens_25,
 yolov7_lens_30) = (
    run_tag([[] for _ in range(NUM_CLASSES)], tag='train', conf_thresh=thresh, force_stop=10000)
    for thresh in (0.15, 0.2, 0.225, 0.25, 0.30)
)

# Keep this name bound to the lists of the last (0.30) run.
collect_lists_yolov7 = yolov7_lens_30

  4%|▍         | 10000/228096 [01:08<24:47, 146.58it/s]
  4%|▍         | 10000/228096 [01:00<22:07, 164.30it/s]
  4%|▍         | 10000/228096 [00:58<21:05, 172.41it/s]
  4%|▍         | 10000/228096 [00:56<20:34, 176.70it/s]
  4%|▍         | 10000/228096 [00:56<20:42, 175.50it/s]


In [42]:
# Additional sweep point at conf_thresh=0.175, again starting from empty
# per-class lists and limited to the first 10000 'train' label files.
collect_lists_yolov7 = [[] for _ in range(NUM_CLASSES)]
yolov7_lens_175 = run_tag(collect_lists_yolov7, tag='train', conf_thresh=0.175, force_stop=10000)

  4%|▍         | 10000/228096 [01:10<25:38, 141.78it/s]


In [43]:
# Total number of kept boxes per threshold (ascending threshold order),
# leaving out the class-0 list of each run.
for per_class_lists in (yolov7_lens_15, yolov7_lens_175, yolov7_lens_20,
                        yolov7_lens_225, yolov7_lens_25, yolov7_lens_30):
    print(sum(len(class_boxes) for class_boxes in per_class_lists[1:]))

140427
138651
137050
135502
134049
131120


In [5]:
# Full (unthrottled) collection of YOLOv7 boxes at the chosen threshold.
NUM_CLASSES = 103
conf_thresh = 0.225

# One list per class id; run_tag appends to these, so the three runs
# accumulate into the same lists.
collect_lists = [[] for _ in range(NUM_CLASSES)]
for split_tag in ('train', 'val', 'test'):
    collect_lists = run_tag(collect_lists, split_tag, conf_thresh)

  1%|          | 1281/228096 [00:18<55:40, 67.91it/s]  


KeyboardInterrupt: 

In [None]:
# Write one CSV per class id from 1 upward (class 0 is skipped): the collected
# boxes of that class right-joined onto the path table by (id, slice).
ori_df = pd.read_csv('/gpfs/data/luilab/karthik/pediatric_seg_proj/path_df.csv')

out_dir = '/gpfs/data/luilab/karthik/pediatric_seg_proj/per_class_isolated_df/yolov7'
# DataFrame.to_csv does not create missing directories, so make sure it exists.
os.makedirs(out_dir, exist_ok=True)

bbox_cols = ['id', 'slice', 'bbox_0', 'bbox_1', 'bbox_2', 'bbox_3']
for class_num in range(1, NUM_CLASSES):
    this_class_df = pd.DataFrame(collect_lists[class_num], columns = bbox_cols)
    merged = ori_df.merge(this_class_df, how = 'right', on = ['id', 'slice'])

    save_path = os.path.join(out_dir, f'path_df_label{class_num}_only_with_bbox_yolov7.csv')
    merged.to_csv(save_path)
