In [1]:
import os
import sys

# Make packages in the parent directory importable (presumably where MedicalNet /
# yolov9_head_func_3d live - they are imported in the next cell). Index 1 keeps the
# notebook's own directory (sys.path[0]) first in the search order.
_parent_dir = os.path.realpath(os.path.pardir)
sys.path.insert(1, _parent_dir)

In [2]:
import os
import copy
import torch

import random
from torch import nn

from typing import List, Sequence, Tuple, Union, Dict
from scipy import ndimage

from MedicalNet.MedicalNet import Struct, MedNet
import yolov9_head_func_3d as y9

from PIL import Image
import pandas as pd

from collections import defaultdict
import numpy as np
from tqdm.notebook import tqdm


from torchvision.tv_tensors import BoundingBoxes as BB
from torchmetrics.detection import MeanAveragePrecision
import torchvision.transforms.v2 as v2

import matplotlib.pyplot as plt
import matplotlib.patches as patches

In [None]:
# Sanity check: does this kernel see a CUDA device? The value is shown as the cell output.
cuda_available = torch.cuda.is_available()
cuda_available

Post-processing of detected bboxes:
Use the information about bboxes that is guaranteed to hold in every image:
-> 5 classes, each with one bbox, or 1 class with 5 bboxes
-> If levels n-1 and n+1 exist in the image, level n must also exist
-> The boxes of levels n-1, n, n+1 must be adjacent to one another in the height dimension
-> The boxes must overlap in the x dimension to some extent
-> If boxes exist for consecutive levels n, n+1, n+2 and the image height is larger than the mean level height, then level n+3 must also exist (and symmetrically, in the reverse direction, level n-1)


# DATASET

In [4]:
class Bbox3d():
    """3D box stored in the coordinate frame of a sagittal series.

    ``x``, ``y`` and ``z`` are ``[min, max]`` pairs. The ``get_*`` methods re-order them
    into ``[u0, v0, u1, v1]`` (or ``[u0, v0, u1, v1, w0, w1]`` when ``d3`` is True) for
    the requested viewing plane.
    """

    def __init__(self, x, y, z) -> None:
        # 3d box in coordinates of sagittal series
        self.x = x
        self.y = y
        self.z = z

    def get_box_in_view_type(self, view_type, d3:bool=False):
        """Return the box in the axis order of ``view_type``.

        Raises:
            ValueError: if ``view_type`` is not 'sagittal', 'sagittal_t2', 'coronal' or 'axial'.
                (Previously this silently returned None and failed later in numpy code.)
        """
        if view_type in ['sagittal', 'sagittal_t2']:
            return self.get_sagittal(d3)
        elif view_type == 'coronal':
            return self.get_coronal(d3)
        elif view_type == 'axial':
            return self.get_axial(d3)
        raise ValueError(f"Unknown view type: {view_type!r}")

    def get_sagittal(self, d3:bool=False):
        """In-plane (x, y); depth z."""
        if d3:
            return [self.x[0], self.y[0], self.x[1], self.y[1], self.z[0], self.z[1]]
        return [self.x[0], self.y[0], self.x[1], self.y[1]]

    def get_coronal(self, d3:bool=False):
        """In-plane (z, y); depth x."""
        if d3:
            return [self.z[0], self.y[0], self.z[1], self.y[1], self.x[0], self.x[1]]
        return [self.z[0], self.y[0], self.z[1], self.y[1]]

    def get_axial(self, d3:bool=False):
        """In-plane (z, x); depth y."""
        if d3:
            return [self.z[0], self.x[0], self.z[1], self.x[1], self.y[0], self.y[1]]
        return [self.z[0], self.x[0], self.z[1], self.x[1]]

In [8]:
class Dataset(torch.utils.data.Dataset):
    """Base dataset turning per-series coordinate tables into crop/box descriptions.

    ``data_info`` maps a series type ('sagittal', 'sagittal_t2', 'axial') to a DataFrame
    with one row per (series, level). ``prepare_data`` fills ``self.data`` with one dict
    per study, mapping each series type to a list of per-series dicts (see ``info2dict``).
    Subclasses provide ``__getitem__``; this base implementation returns None.
    """

    def __init__(self, data_info:Dict[str, pd.DataFrame], config:Dict):
        # data_info: dict consisting of series types and dataframe with their info
        # config: dict - dataset configuration
        #TODO: Better overlap in instance dimension
        #TODO: smart limit - selecting important slices based on series type and where they usually lay

        self.supress_warnings = config['supress_warinings']
        self.used_series_types = config['load_series']
        self.study_ids = np.unique(np.concatenate([series.study_id.unique() for series in data_info.values() if series is not None]))
        self.data_info = data_info
        if not np.all([tp in list(self.data_info.keys()) for tp in self.used_series_types]):
            raise Exception("Series types to use do not match provided data information.")

        self.series_out_types = config['series_out_types']
        self.current_view = self.series_out_types[0]
        self.return_series_type = config['return_series_type']

        # 3D boxes are always produced; config['3d_box'] is currently ignored.
        self.d3 = True #config['3d_box'] if '3d_box' in list(config.keys()) else False
        self.preload = config['preload']
        self.rng = np.random.default_rng()
        if self.preload and len(self.series_out_types)>1:
            self.preload = False
            if not self.supress_warnings:
                print("Preloading with mixed series type outputs is not supported. Preloading was turned off.")

        self.im_size = config['im_size']
        self.og_im_size = [20, 512, 512]

        self.image_type = config['image_type']
        self.rev_lr = config['rev'] if 'rev' in list(config.keys()) else False
        if self.preload and not self.supress_warnings:
            print("Warning! Preloading of images is turned on. The program will attempt to load whole dataset into memory!")

        self.transforms = config['transforms']
        self.normalize = config['normalize']
        self.vsa = config['vsa']  # enables coarse 3D dropout in prepare_series

        self.dataset_type = config['dataset_type']
        self.dataset_path = config['dataset_path']

        self._condition_list = ['Left Neural Foraminal Narrowing',
                                'Left Subarticular Stenosis',
                                'Right Neural Foraminal Narrowing',
                                'Right Subarticular Stenosis',
                                'Spinal Canal Stenosis']

        self.get_condition = [self._condition_list.index(cond) for cond in config['get_conditions']]
        self.condition_to_get = config['get_conditions']
        self._status_map = {'Normal/Mild': [1., 0., 0.],
                            'Moderate': [0., 1., 0.],
                            'Severe': [0., 0., 1.]}

        self.level_ind = {'L1/L2': 0, 'L2/L3': 1, 'L3/L4': 2, 'L4/L5': 3, 'L5/S1':4}

        self.box_labels = torch.tensor([list(self.level_ind.values())])

        self.limit = config['limit_series_len']
        self.series_length = 15

        # Level-box margins; x/y margins are divided by pixel spacing in get_level_boxes.
        self.x_overhead = config['x_overhead'] if config['x_overhead'] else [30, 30]
        self.z_overhead = config['z_overhead']
        self.overlap_levels = config['overlap_levels']
        self.y_overlap = config['y_overlap'] if self.overlap_levels else 0

        # Per-condition margins: dict condition -> [before, after].
        self.cond_x_overhead = config['cond_x_overhead']
        self.cond_y_overhead = config['cond_y_overhead']
        self.cond_z_overhead = config['cond_z_overhead']

        if self.limit and not config['series_len'] and not self.supress_warnings:
            print("Series length is not specified. The limit is set to 15.")
        elif self.limit and config['series_len']:
            self.series_length = config['series_len']

        self.series_type_ind = {}
        for s in self.used_series_types:
            self.series_type_ind[s] = []

        self.data = []
        self.prepare_data()

    def get_level_boxes(self, series_info:pd.DataFrame, stype='sagittal'):
        """Build one 3D level box (plus its condition boxes) per row of ``series_info``.

        Returns:
            (list[Bbox3d] level boxes, np.ndarray of level indices,
             list of per-level condition-box lists, list of per-level condition-label arrays)
        """
        #default for each scan has 5 visible levels
        bboxes = []
        labels = []
        cond_boxes = []
        cond_labels = []
        if stype in ['sagittal', 'sagittal_t2']:
            for _, row in series_info.iterrows():
                sl = np.array(row.slice_locations)

                # Slice-location range covered by the annotated instances of this level.
                fr_a = sl[row.present_instances.index(min(row.instance_number))]
                too_a = sl[row.present_instances.index(max(row.instance_number))]
                fr = min(fr_a, too_a)
                too = max(fr_a, too_a)

                # Widen by z_overhead and snap to the nearest existing slice indices.
                z_min = np.argmin(abs(sl - (fr-self.z_overhead )))
                z_max = np.argmin(abs(sl - (too+self.z_overhead)))

                labels.append(self.level_ind[row.level])
                bboxes.append(Bbox3d(
                    x=[max(0, min(row.x)-self.x_overhead[0]/row.pixel_spacing[0]), min(row.image_width, max(row.x)+self.x_overhead[1]/row.pixel_spacing[0])],
                    y=[max(0, row.level_boundaries[0]-self.y_overlap/row.pixel_spacing[1]), min(row.image_height, row.level_boundaries[1]+self.y_overlap/row.pixel_spacing[1])],
                    z=[min(z_max, z_min), max(z_max, z_min)]))

                c, cl = self.get_condition_boxes(row, stype=stype)
                cond_boxes.append(c)
                cond_labels.append(cl)

        elif stype=='axial':
            y_overlap = self.y_overlap + 0.5
            for _, row in series_info.iterrows():
                labels.append(self.level_ind[row.level])
                bboxes.append(Bbox3d(
                    x=[max(0, min(row.y)-self.x_overhead[0]/row.pixel_spacing[1]), min(row.image_width, max(row.y)+self.x_overhead[1]/row.pixel_spacing[1])],
                    y=[max(0, row.present_instances.index(min(row.level_slices))-y_overlap), min(len(row.present_instances), row.present_instances.index(max(row.level_slices))+y_overlap)],
                    z=[max(0, min(row.x)-self.z_overhead/row.pixel_spacing[0]), min(row.image_height, max(row.x)+self.z_overhead/row.pixel_spacing[0])]
                                ))

                c, cl = self.get_condition_boxes(row, stype=stype)
                cond_boxes.append(c)
                cond_labels.append(cl)

        return bboxes, np.array(labels), cond_boxes, cond_labels

    def get_condition_boxes(self, row:pd.DataFrame, stype='sagittal'):
        """Build 3D boxes for the requested conditions annotated in one level row.

        Returns:
            (list[Bbox3d], np.ndarray of indices into ``self.condition_to_get``)
        """
        bboxes = []
        labels = []

        if stype in ['sagittal', 'sagittal_t2']:
            for condition, x,y,z in zip(row.condition, row.x, row.y, row.instance_number):
                if condition in self.condition_to_get:
                    sl = np.array(row.slice_locations)
                    fr_a = sl[row.present_instances.index(z)]
                    # NOTE(review): z_min / z_max are not used below (the z extent comes from `zz`);
                    # kept so that missing-instance / missing-key errors still surface here.
                    z_min = np.argmin(abs(sl - (fr_a-self.cond_z_overhead[condition][0])))
                    z_max = np.argmin(abs(sl - (fr_a+self.cond_z_overhead[condition][1])))

                    tmp = [(b,a) for a,b in zip(row.IPP, row.present_instances)]

                    # The asymmetric slice margin is mirrored when the first IPP coordinate
                    # decreases from the first to the last present instance.
                    is_reversed = tmp[0][1][0] > tmp[-1][1][0]
                    if is_reversed:
                        zz = [3, 2] if "Left" in condition else [2, 3]
                    else:
                        zz = [2, 3] if "Left" in condition else [3, 2]

                    labels.append(self.condition_to_get.index(condition))
                    # NOTE(review): z is the raw instance number here, whereas level boxes use
                    # indices into present_instances - confirm the two coincide for this data.
                    bboxes.append(Bbox3d(
                        x=[max(0, x-self.cond_x_overhead[condition][0]/row.pixel_spacing[0]), min(row.image_width, x+self.cond_x_overhead[condition][1]/row.pixel_spacing[0])],
                        y=[max(0, y-self.cond_y_overhead[condition][0]/row.pixel_spacing[1]), min(row.image_height, y+self.cond_y_overhead[condition][1]/row.pixel_spacing[1])],
                        z=[max(0, z-zz[0]), min(len(sl)-1, z+zz[1])]))

        elif stype=='axial':
            for condition, x,y,z in zip(row.condition, row.x, row.y, row.instance_number):
                if condition in self.condition_to_get:
                    sl = np.array(row.slice_locations)
                    fr_a = sl[row.present_instances.index(z)]
                    z_min = np.argmin(abs(sl - (fr_a-self.cond_y_overhead[condition][0])))
                    z_max = np.argmin(abs(sl - (fr_a+self.cond_y_overhead[condition][1])))

                    labels.append(self.condition_to_get.index(condition))
                    bboxes.append(Bbox3d(
                        x=[max(0, y-self.cond_x_overhead[condition][0]/row.pixel_spacing[1]), min(row.image_width, y+self.cond_x_overhead[condition][1]/row.pixel_spacing[1])],
                        y=[min(z_max, z_min), max(z_max, z_min)],
                        z=[max(0, x-self.cond_z_overhead[condition][0]/row.pixel_spacing[0]), min(row.image_height, x+self.cond_z_overhead[condition][1]/row.pixel_spacing[0])]
                                    ))

        return bboxes, np.array(labels)

    def set_view(self, new_view):
        """Set the view used as ``current_view``."""
        self.current_view = new_view

    def get_condition_labels(self, series_info:pd.DataFrame):
        """Collect status and condition-presence mask for each level present in ``series_info``.

        Returns:
            (statuses of present levels in level order,
             condition presence masks of present levels,
             boolean mask over all 5 levels telling which are present)
        """
        labels = []
        cond_presence_masks = []
        level_presence_mask = []

        for level, _ in self.level_ind.items():
            level_rows = series_info[series_info['level']==level]
            if not level_rows.empty:
                labels.append(level_rows.iloc[0].status)
                cond_presence_masks.append(level_rows.iloc[0].presence_mask)
                level_presence_mask.append(True)
            else:
                level_presence_mask.append(False)

        return np.array(labels), np.array(cond_presence_masks), np.array(level_presence_mask)

    def info2dict(self, series_info, stype=None): #remember axials can be combination of different serieses (sagittals can't)
        """Summarize one series (all its level rows) as a dict used by ``load_series``."""
        level0 = series_info.iloc[0]
        data_dict = {}
        boxes, box_labels, cond_boxes, cond_labels = self.get_level_boxes(series_info, stype=stype)
        # get_condition_labels returns (labels, condition masks, level mask), in that order.
        labels, cond_presence_mask, level_presence_mask = self.get_condition_labels(series_info)

        data_dict['study_id'] = level0.study_id
        data_dict['series_id'] = level0.series_id
        data_dict['width'] = level0.image_width
        data_dict['height'] = level0.image_height
        data_dict['reversed'] = level0.reversed
        data_dict['series_type'] = stype
        data_dict['pixel_spacing'] = level0.pixel_spacing

        data_dict['boxes'] = boxes
        data_dict['files'] = [f"{self.dataset_path}/{data_dict['study_id']}/{data_dict['series_id']}/{instance}.{self.image_type}" for instance in level0.present_instances]
        data_dict['box_labels'] = box_labels
        data_dict['labels'] = labels
        # Stores the per-level condition presence masks (second return value above).
        data_dict['label_presence_mask'] = cond_presence_mask
        data_dict['cond_boxes'] = cond_boxes
        data_dict['cond_labels'] = cond_labels

        return data_dict

    def prepare_data(self):
        """Build ``self.data``: one dict per study mapping series type -> list of series dicts."""
        with tqdm(total=len(self.study_ids), desc="Preparing data: ") as pbar:
            for study_id in self.study_ids:
                study_dict = dict(
                                sagittal=[],
                                sagittal_t2=[],
                                axial=[])
                for stype in self.used_series_types:
                    for series_id in self.data_info[stype].query(f'study_id == {study_id}').series_id.unique():
                        ddict = self.info2dict(self.data_info[stype][self.data_info[stype].series_id==series_id], stype=stype)
                        if self.preload:
                            # NOTE(review): preload_series is not defined in the visible code, so
                            # preload=True raises AttributeError unless a subclass provides it.
                            ddict = self.preload_series(ddict)
                        study_dict[stype].append(ddict)

                self.data.append(study_dict)
                pbar.update(1)

    def split_to_series(self):
        """Flatten ``self.data`` to one entry per series.

        ``self.batches`` keeps, per study, the indices of its series in the flattened list.
        """
        temp = []
        batches = []
        i=0
        for data in self.data:
            batch = []
            for series in data.values():
                if series:
                    temp+=series
                    batch.append(i)
                    i+=1
            batches.append(batch)

        self.data = temp
        self.batches = batches

    def change_img_view(self, img, current_view, new_view):
        """Transpose an (n, h, w) volume from ``current_view`` to ``new_view`` axis order.

        Raises:
            ValueError: for view combinations that are not supported.
        """
        if current_view in ['sagittal', 'sagittal_t2']:
            if new_view in ['sagittal', 'sagittal_t2']:
                return img
            elif new_view=='coronal':
                return img.transpose(2, 1, 0) # n, h, w -> w, h, n
            elif new_view=='axial':
                return img.transpose(1, 2, 0) # n, h, w -> h, w, n
        elif current_view=='axial':
            if new_view in ['sagittal', 'sagittal_t2']:
                return img.transpose(2, 0, 1) # n, h, w -> w, n, h
            elif new_view =='coronal':
                return img.transpose(1, 0, 2) # n, h, w -> h, n, w
            elif new_view=='axial':
                return img
        raise ValueError(f"Unsupported view change: {current_view} -> {new_view}")

    def load_series(self, data) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Load a series from disk and crop/resize one sub-volume per level.

        Everything is written into the level's slot (``label``) so that images, boxes and
        condition labels of the same level share the same index.

        Returns:
            imgs: (5, series_length, im_size[1], im_size[0]) normalized level crops.
            condition_boxes: (5, n_conditions, 6) boxes relative to each crop, resized.
            condition_labels: (5, n_conditions) argmax of each condition's status vector.
        Raises:
            ZeroDivisionError: from __resize_data__ when a crop has a zero-length axis
                (BoxDatasetUnited.__getitem__ catches this).
        """
        st = data['series_type']
        boxes = np.array([box.get_box_in_view_type(st, d3 = True) for box in data['boxes']], dtype=int)
        level_labels = data['box_labels']
        cond_labels = data['labels']

        oimg = np.zeros((len(data['files']), data['height'], data['width']), dtype = np.uint8)
        for i, path in enumerate(data['files']):
            # Decode each slice once and close the file handle deterministically.
            with Image.open(path) as slice_img:
                slice_arr = np.array(slice_img, dtype = np.uint8)
            try:
                oimg[i,:,:] = slice_arr
            except ValueError:
                # Slice smaller than the series' nominal size: copy it into the top-left corner.
                oimg[i,:slice_arr.shape[0],:slice_arr.shape[1]] = slice_arr

        # Source and target view are the same here, so this currently returns oimg unchanged.
        oimg = self.change_img_view(oimg, data['series_type'], st)

        imgs = np.zeros((5, self.series_length, self.im_size[1], self.im_size[0]))
        condition_boxes = np.zeros((5, len(self.condition_to_get), 6))
        condition_labels = np.zeros((5, len(self.condition_to_get)))

        # NOTE(review): cond_labels is ordered by level (missing levels skipped) while boxes follow
        # the row order of the series table - this relies on rows being sorted by level.
        for i, (box, label, cond_label) in enumerate(zip(boxes, level_labels, cond_labels)):
            try:
                # Random in-plane jitter of the level crop (x0, y0, x1, y1).
                box[0] = int(max(0, box[0]+np.random.randint(-20, 20)))
                box[1] = int(max(0, box[1]+np.random.randint(-20, 20)))
                box[2] = int(min(oimg.shape[2], box[2]+np.random.randint(-20, 20)))
                box[3] = int(min(oimg.shape[1], box[3]+np.random.randint(-20, 20)))
                # Condition boxes relative to the crop origin.
                lb = np.array([bb.get_box_in_view_type(st, d3 = True) for bb in data['cond_boxes'][i]], dtype=int)-np.array([box[0], box[1], box[0], box[1], box[4], box[4]])
            except ValueError:
                # e.g. no condition boxes for this level: the empty array cannot broadcast.
                continue
            cl = data['cond_labels'][i]
            for cll in cl:
                condition_labels[label,cll]=np.argmax(cond_label[self._condition_list.index(self.condition_to_get[cll])])
            ig, lb = self.__resize_data__(oimg[box[4]: box[5], box[1]:box[3], box[0]:box[2]], lb)
            if 0 in ig.shape:
                continue
            imgs[label] = self.__itensity_normalize_one_volume__(ig)
            condition_boxes[label,cl,:] = lb

        return imgs, condition_boxes, condition_labels

    def __itensity_normalize_one_volume__(self, volume):
        """
        Normalize the intensity of an nd volume using the mean and std of its non-zero voxels.
        inputs:
            volume: the input nd volume
        outputs:
            out: the normalized nd volume (returned unchanged when it has no non-zero voxels;
                 only mean-centred when all non-zero voxels are equal, to avoid dividing by 0)
        """
        pixels = volume[volume > 0]
        if pixels.size==0:
            return volume

        mean = pixels.mean()
        std  = pixels.std()
        if std == 0:
            return volume - mean
        return (volume - mean)/std

    def __resize_data__(self, data, bb=None):
        """
        Resize a (depth, height, width) volume to (series_length, im_size[1], im_size[0])
        with linear interpolation. If ``bb`` is given, its columns [x0, y0, x1, y1, z0, z1]
        are rescaled in place (truncated when ``bb`` has an integer dtype) and returned too.
        Raises ZeroDivisionError when an axis of ``data`` has length 0.
        """
        [depth, height, width] = data.shape
        scale = [self.series_length*1.0/depth, self.im_size[1]*1.0/height, self.im_size[0]*1.0/width]
        if bb is not None:
            bb[..., [0,2]] = bb[..., [0,2]]/width*self.im_size[0]
            bb[..., [1,3]] = bb[..., [1,3]]/height*self.im_size[1]
            bb[..., [4,5]] = bb[..., [4,5]]/depth*self.series_length

        data = ndimage.zoom(data, scale, order=1)

        if bb is not None:
            return data, bb
        return data

    def limit_series(self, img, z_bb=None):
        """Sub-sample (evenly) or zero-pad ``img`` along axis 0 to ``series_length`` slices.

        When ``z_bb`` is given it is rescaled on sub-sampling and returned with the image.
        """
        if self.series_length < img.shape[0]:
            # `is not None`: truth-testing an array with several elements raises ValueError.
            if self.d3 and z_bb is not None:
                z_bb = z_bb* self.series_length/img.shape[0]
            st = np.round(np.linspace(0, img.shape[0] - 1, self.series_length)).astype(int)
            img = img[st,:,:]
        elif self.series_length > img.shape[0]:
            img = np.pad(img, ((0, self.series_length-img.shape[0]), (0, 0), (0, 0)))
        if z_bb is not None:
            return img, z_bb
        return img

    def coarse_dropout_3d(self, volume, max_holes_num, max_hole_size):
        """Zero out between 1 and ``max_holes_num`` (inclusive) random cuboids of ``volume``.

        Each hole has a random corner inside the volume and per-axis size below
        ``max_hole_size``; ``volume`` is modified in place and returned.
        """
        # endpoint=True makes max_holes_num reachable (and max_holes_num == 1 valid).
        hol_num = self.rng.integers(1, max_holes_num, endpoint=True)
        # for each hole draw its placement on image and size
        placement_and_size = self.rng.integers(low=0,
                                      high = [[volume.shape[0], volume.shape[1], volume.shape[2],
                                               max_hole_size[0], max_hole_size[1], max_hole_size[2]]]*hol_num)
        for i in range(hol_num):
            c = placement_and_size[i,:]
            volume[c[0]:min(volume.shape[0], c[0]+c[3]), c[1]:min(volume.shape[1], c[1]+c[4]), c[2]: min(volume.shape[2], c[2]+c[5])] = 0.
        return volume

    def prepare_series(self, img, boxes, labels):
        """Turn one level crop into an (image, boxes, labels) training sample.

        Args:
            img: (series_length, H, W) crop from load_series.
            boxes: (n_conditions, 6) boxes [x0, y0, x1, y1, z0, z1] in crop coordinates.
            labels: unused; labels are derived from the order of ``condition_to_get``.

        Returns:
            img: float tensor, possibly augmented / coarse-dropped.
            boxes: (k, 6) float tensor [x0, y0, z0, x1, y1, z1] clamped to the volume. A row is
                zeroed when it starts in the last fifth or ends in the first fifth of x or y.
            level_labels: (k,) labels of the non-empty input boxes.
        """
        level_labels = torch.arange(0,len(self.condition_to_get))
        # Unify foraminal labels with no regard to side.
        # NOTE(review): with get_conditions = [Left NFN, Right NFN] only the first remap applies
        # (1 -> 0); the 3 -> 1 and 2 -> 1 remaps assume a different condition order - confirm.
        level_labels[level_labels==1] = 0
        level_labels[level_labels==3] = 1
        level_labels[level_labels==2] = 1
        boxes = torch.tensor(boxes)

        # Drop all-zero (missing) boxes together with their labels.
        all_data = torch.cat([level_labels.unsqueeze(-1), boxes], dim = -1)
        all_data = all_data[all_data[:,1:].sum(dim=-1) != 0]
        boxes = all_data[:,1:]
        level_labels = all_data[:,0]

        # z extents are kept aside and re-attached after the in-plane transforms.
        z_bb = boxes[..., [-2, -1]]
        boxes = boxes[..., 0:4]

        boxes = BB(boxes, format='XYXY', canvas_size=(img.shape[1] , img.shape[2]), dtype=torch.float)
        img = torch.tensor(img,dtype=torch.float)
        target = {
                "boxes": boxes,
                "labels": level_labels}
        if self.transforms:
            img, target = self.transforms(img, target)

        boxes = target['boxes']
        level_labels = target['labels']
        boxes = torch.cat([boxes, z_bb], dim=-1)[:, [0,1,4,2,3,5]]

        if torch.rand(1) < 0.5 and self.vsa:
            img = self.coarse_dropout_3d(img,12,[int(self.series_length/3),int(self.im_size[1]/3), int(self.im_size[0]/3)])

        boxes[torch.logical_or(boxes[:,0] > img.shape[2]-img.shape[2]/5, boxes[:,3]-img.shape[2]/5 < 0)] *= 0
        boxes[torch.logical_or(boxes[:,1] > img.shape[1]-img.shape[1]/5, boxes[:,4]-img.shape[1]/5 < 0)] *= 0

        boxes = boxes.clamp(
            min=torch.tensor([0.,0.,0.,0.,0.,0.], dtype = torch.float), max=torch.tensor([img.shape[2]-1, img.shape[1]-1, img.shape[0]-1, img.shape[2]-1, img.shape[1]-1, img.shape[0]-1], dtype = torch.float))
        # NOTE(review): argmax() < 1 means the maximum sits at flat index 0 - true for a constant
        # (e.g. all-zero) image, but also for any image whose first voxel is its maximum.
        if img.argmax() < 1:
            boxes *=0

        return img, boxes, level_labels

    def __getitem__(self, index: int):
        """Placeholder; subclasses implement item access."""
        return None

    def __len__(self):
        return(len(self.data))

In [9]:
class BoxDatasetUnited(Dataset):
    """Dataset with one item per series (studies are flattened by ``split_to_series``).

    An item is ``(images, boxes, labels)``, plus the series type when
    ``config['return_series_type']`` is set. Shapes: images
    (5, series_len, im_size[1], im_size[0]), boxes (5, n_conditions, 6),
    labels (5, n_conditions, 1) - one row per spine level.
    """
    def __init__(self, data_info:Dict[str, pd.DataFrame], config:Dict):
        super().__init__(data_info, config)

        # Stored from config; not read by the methods of this class.
        self.mix_strategy = config['mix_strategy']
        if config['one_label']:
            self.box_labels = torch.tensor([0,0,0,0,0], dtype=torch.float)
        else:
            self.box_labels = torch.tensor([list(self.level_ind.values())])

        # split data to individual series from dict of study-series_type pairs
        self.split_to_series()

    def _pack(self, imgs, boxes, labels, series_type):
        """Build the item tuple, appending the series type when configured to."""
        if self.return_series_type:
            return imgs, boxes, labels, series_type
        return imgs, boxes, labels

    def __getitem__(self, index: int) -> tuple:
        """Return the level crops, condition boxes and labels of series ``index``.

        Levels whose crop is empty (all zeros) stay all-zero in every output.
        """
        adata = self.data[index]
        all_view_img = torch.zeros(5, self.series_length, self.im_size[1], self.im_size[0])
        all_view_boxes = torch.zeros((5, len(self.condition_to_get), 6))
        all_view_labels = torch.zeros((5, len(self.condition_to_get), 1))
        try:
            imgs, boxes, labels = self.load_series(adata)
        except ZeroDivisionError:
            # load_series hit a zero-sized crop (a series that was effectively skipped during
            # data preparation): return all-zero outputs instead of failing the batch.
            return self._pack(all_view_img.to(dtype=torch.float), all_view_boxes.to(dtype=torch.float),
                              all_view_labels.to(dtype=torch.float), adata['series_type'])

        for i, (img, box, label) in enumerate(zip(imgs, boxes, labels)):
            img, box, label = self.prepare_series(img, box, label)
            if torch.count_nonzero(img) > 0:
                all_view_img[i] = img
                all_view_boxes[i,:box.shape[0]] = box
                all_view_labels[i,:label.shape[0]] = label.unsqueeze(-1)

        return self._pack(all_view_img, all_view_boxes, all_view_labels, adata['series_type'])

    def __len__(self):
        return(len(self.data))

    def get_all_random(self):
        """Draw one random series per used series type and stack the results.

        Returns:
            (images, boxes, labels) stacked on a new leading axis, and the list of series types.
        """
        ai, ab, al, asl = [], [], [], []
        for stype in self.used_series_types:
            sample = self.get_random_by_stype(stype)
            all_view_img, all_view_boxes, all_view_labels = sample[:3]
            # Items only carry their series type when return_series_type is set; otherwise use
            # the requested type (the sample was selected by it).
            s = sample[3] if len(sample) > 3 else stype
            ai.append(all_view_img.unsqueeze(0))
            ab.append(all_view_boxes.unsqueeze(0))
            al.append(all_view_labels.unsqueeze(0))
            asl.append(s)
        return torch.cat(ai, dim=0), torch.cat(ab, dim=0),  torch.cat(al, dim=0),  asl

    def get_random_by_stype(self, series_type):
        """Return a uniformly random item whose series type equals ``series_type``.

        Raises:
            IndexError: if the dataset holds no series of that type.
        """
        # NOTE(review): 'o_series_type' is not set anywhere in the visible code (preload path).
        key = 'o_series_type' if self.preload else 'series_type'
        same_s_types = [ind for ind, data in enumerate(self.data) if data[key] == series_type]
        if not same_s_types:
            raise IndexError(f"No series of type '{series_type}' in dataset.")
        return self[random.choice(same_s_types)]

In [None]:
im_size = 128

# Input locations (all data is read from below INPUTS_DIR).
INPUTS_DIR = "/workspaces/RSNA_LSDC/inputs"
COORDINATES_DIR = f"{INPUTS_DIR}/box_data/coordinates"
# Only the first rows of train_series_descriptions.csv are used; set to None for all rows.
N_SERIES_ROWS = 100

# Training augmentations; every RandomChoice applies exactly one of its transforms.
# NOTE(review): the 90-degree RandomRotation is listed twice, doubling its selection
# probability - confirm whether one of them was meant to be (-90, -90).
train_transforms = v2.Compose([
    v2.RandomChoice([v2.RandomVerticalFlip(p = 0.5), v2.RandomHorizontalFlip(p = 0.5), v2.RandomAffine(degrees=5),
                     v2.RandomRotation(degrees=(90,90)), v2.RandomRotation(degrees=(90,90))]),
    v2.RandomChoice([v2.RandomPerspective(distortion_scale=0.2, p=1.0), v2.RandomAffine(degrees=0, translate=(0.3,0.3), shear=(-5,5,-5,5))]), # translation + shearing
    v2.RandomAffine(degrees=0, scale=(0.8,1.2)), #scaling
    v2.RandomChoice([v2.ElasticTransform(alpha=40.0), v2.GaussianBlur(kernel_size=(5,5), sigma=(0.7, 0.7))]),
])


val_transforms = v2.Compose([
     v2.RandomRotation(degrees=(90,90)),
])

# Per-condition box margins [before, after], used by Dataset.get_condition_boxes.
cond_x_overhead = {
    'Left Neural Foraminal Narrowing': [6,9],
    'Right Neural Foraminal Narrowing': [6,9],
    'Left Subarticular Stenosis': [5,5],
    'Right Subarticular Stenosis': [5,5],
    'Spinal Canal Stenosis': [10, 10]
}
cond_y_overhead = {
    'Left Neural Foraminal Narrowing': [12,12],
    'Right Neural Foraminal Narrowing': [12,12],
    'Left Subarticular Stenosis': [4,4],
    'Right Subarticular Stenosis': [4,4],
    'Spinal Canal Stenosis': [200, 200]
}
cond_z_overhead = {
    'Left Neural Foraminal Narrowing': [8,8],
    'Right Neural Foraminal Narrowing': [8,8],
    'Left Subarticular Stenosis': [8,8],
    'Right Subarticular Stenosis': [8,8],
    'Spinal Canal Stenosis': [20, 20]
}

train_dataset_config ={
    'preload': False, # preload data into memory (WARNING IT MAY TAKE A LOT OF SPACE - DEPENDS ON THE DATASET USED)
    'im_size': [im_size, im_size],
    'dataset_path': f"{INPUTS_DIR}/dataset",
    'load_series': ['sagittal'], # series types to load into dataset ['sagittal', 'axial', 'sagittal_t2']
    'united': True,

    'transforms': train_transforms,
    'vsa': True,
    'one_label': False, # use one label for every level (do not differentiate between levels)
    'series_out_types': ['sagittal', 'axial', 'sagittal_t2'], # output series types (views: 'sagittal', 'sagittal_t2', 'axial', 'coronal')
    'get_conditions':['Left Neural Foraminal Narrowing', 'Right Neural Foraminal Narrowing'], # condition to output from series
    'mix_strategy': 'combined', # strategy for mixing output series types: 'random' - randomly select output type,
                              # 'manual' (also called 'custom' in earlier notes - confirm) - return data for the currently chosen view,
                              # 'combined' - return all views in one call
    'return_series_type': False, # If True getitem will also return series original type

    'normalize': False,
    'image_type': 'png',

    'dataset_type': 'boxes', #'boxes', 'conditions'
    'supress_warinings': False,

    'limit_series_len': True, # if True the series length will be limited to number N specified by 'series_len' parameter
    'series_len': 32, #  maximal number of slices in series

    'x_overhead': [50,50], # overhead for levels in x-dim (in mm)
    'z_overhead':59,
    'overlap_levels': True, # if true the level upper and lower boundary will overlap with value specified in 'y_overlap'
    'y_overlap': 50, # overlap size of levels boundaries (in mm)
    'cond_x_overhead': cond_x_overhead,
    'cond_y_overhead': cond_y_overhead,
    'cond_z_overhead': cond_z_overhead,

    '3d_box': True
}

tsd = pd.read_csv(f'{INPUTS_DIR}/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv').iloc[0:N_SERIES_ROWS]

# pd.read_pickle can execute arbitrary code - only load these trusted, locally generated files.
data_sagittal = pd.read_pickle(f"{COORDINATES_DIR}/coordinates_unified_sagital_t1.pkl")
data_sagittal_t2 = pd.read_pickle(f"{COORDINATES_DIR}/coordinates_unified_sagital_t2.pkl")
data_axial = pd.read_pickle(f"{COORDINATES_DIR}/coordinates_axial_unified.pkl")
train_ids = tsd.study_id.unique()

train_data={'sagittal': data_sagittal[data_sagittal.study_id.isin(train_ids)],
            'sagittal_t2': data_sagittal_t2[data_sagittal_t2.study_id.isin(train_ids)],
            'axial': data_axial[data_axial.study_id.isin(train_ids)]}

bb = BoxDatasetUnited(train_data, train_dataset_config)

In [None]:
# Box outline colour per condition label.
color_dict = {0: 'g', 1: 'b', 2: 'r'}

fig, ax = plt.subplots(5, 3, figsize=(30, 30))
# With 'return_series_type' False an item is (images, boxes, labels):
# images (5, series_len, H, W); boxes (5, n_conditions, 6) as [x0, y0, z0, x1, y1, z1];
# labels (5, n_conditions, 1).
inputs, boxes, labels = bb[4]

# Ground truth: for every level row show three slices of the crop with the condition boxes on top.
# labels[..., 0] instead of squeeze(): squeeze would also drop the condition axis when n_conditions == 1.
for i, (img, box, label) in enumerate(zip(inputs, boxes, labels[..., 0])):
    z_slice = int(box[0, 5])                    # last slice covered by the first condition box
    x_slice = int(box[0, 0] + box[0, 3]) // 2   # centre column (W index) of the first condition box
    ax[i, 0].imshow(img[z_slice, :, :].detach().cpu().numpy())
    # permute(2, 1, 0): (D, H, W) -> (W, H, D), so the slice index must be an x (W) coordinate.
    ax[i, 1].imshow(img.permute(2, 1, 0)[x_slice, :, :].detach().cpu().numpy())
    ax[i, 2].imshow(img.permute(2, 0, 1)[16, :, :].detach().cpu().numpy())
    ax[i, 0].set_title(f"level row {i}: (y, x) plane at z={z_slice}")
    ax[i, 1].set_title(f"level row {i}: (y, z) plane at x={x_slice}")
    ax[i, 2].set_title(f"level row {i}: (z, y) plane at x=16")

    for j in range(box.shape[0]):
        edge = color_dict[int(label[j])]

        b_sag = box[j, [0, 1, 3, 4]].numpy()
        b_cor = box[j, [2, 1, 5, 4]].numpy()
        b_ax  = box[j, [1, 2, 4, 5]].numpy()
        ax[i, 0].add_patch(patches.Rectangle((b_sag[0], b_sag[1]), b_sag[2]-b_sag[0], b_sag[3]-b_sag[1], linewidth=1, edgecolor=edge, facecolor='none'))
        ax[i, 1].add_patch(patches.Rectangle((b_cor[0], b_cor[1]), b_cor[2]-b_cor[0], b_cor[3]-b_cor[1], linewidth=1, edgecolor=edge, facecolor='none'))
        ax[i, 2].add_patch(patches.Rectangle((b_ax[0], b_ax[1]), b_ax[2]-b_ax[0], b_ax[3]-b_ax[1], linewidth=1, edgecolor=edge, facecolor='none'))

plt.show()

# MODEL TRAINER

In [12]:
class SagittalTrainer():
    """Trainer for the sagittal 3D box detector: training loop, validation mAP and plots.

    Args:
        model: model class (not an instance); built as ``model(**model_params)`` and moved
            to ``self.device``. It is used through ``get_loss(inputs, boxes, labels)``
            (returns ``(loss, dict_of_loss_terms)``) and ``predict(inputs)`` (returns a list
            of dicts with ``'boxes'``, ``'scores'``, ``'labels'`` and ``'logits'``).
        model_params: keyword arguments for the model constructor.
        config: trainer configuration. Besides the trainer keys it holds
            ``train_dataset_config`` / ``val_dataset_config`` for ``BoxDatasetUnited``.
            Optional ``num_workers`` (default 12) sets the DataLoader worker count.
        train_data, eval_data: data passed to ``BoxDatasetUnited`` for the train /
            validation loaders.
    """
    def __init__(self, model, model_params, config, train_data, eval_data) -> None:

        self.print_evaluation = config.get("print_evaluation", False)
        self.steps_per_plot = config["steps_per_plot"]

        train_cfg = config['train_dataset_config']
        val_cfg = config['val_dataset_config']
        # The warning concerns the 'combined' strategy only; checking the truthiness of
        # mix_strategy fired it for every strategy name.
        if train_cfg['mix_strategy'] == 'combined' and len(train_cfg['series_out_types']) > 1:
            print("Warning! With mix strategy set to 'combined' and multiple series out types the batch size will be (len(series_out_types)) times bigger.")

        self.checkpoints = config['checkpoints']
        self.save_path = config['save_path']
        self.step_per_save = config['step_per_save']

        self.config = config

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Device set to {self.device}")

        self.model_name = config['model_name']
        self.model = model(**model_params).to(self.device)

        self.optimizer = config["optimizer"](self.model.parameters(), **config["optimizer_params"])

        self.series_len = val_cfg['series_len']
        num_workers = config.get('num_workers', 12)
        loader_kwargs = {'batch_size': config["batch_size"], 'num_workers': num_workers}
        if num_workers > 0:
            # prefetch_factor is only accepted together with worker processes
            loader_kwargs['prefetch_factor'] = 1
        self.dataloaders = {
            'train': torch.utils.data.DataLoader(BoxDatasetUnited(train_data, train_cfg), shuffle=True, **loader_kwargs),
            'val': torch.utils.data.DataLoader(BoxDatasetUnited(eval_data, val_cfg), shuffle=False, **loader_kwargs),
        }

        self.max_epochs = config["epochs"]
        self.early_stopping = config['early_stopping']
        self.early_stopping_tresh = config['early_stopping_treshold']

        # scheduler: OneCycleLR is stepped per batch, any other scheduler once per epoch
        self.one_cycle_sched = False
        if config["scheduler"]:
            # Copy so that filling in run-dependent values does not mutate the caller's
            # config (which may be shared between folds).
            sched_params = dict(config['scheduler_params'])
            if 'epochs' in sched_params:
                sched_params['epochs'] = self.max_epochs
            if 'steps_per_epoch' in sched_params:
                sched_params['steps_per_epoch'] = len(self.dataloaders['train'])
            self.scheduler = config["scheduler"](self.optimizer, **sched_params)
            self.one_cycle_sched = self.scheduler.__class__.__name__ == 'OneCycleLR'
        else:
            self.scheduler = None

        ## Evaluation metrics
        self.series_in_types = train_cfg['load_series']

        self.out_stypes = train_cfg['series_out_types']
        self.batch_size = config["batch_size"] * len(self.out_stypes)
        self.condition_to_get = train_cfg['get_conditions']
        self.mAP = MeanAveragePrecision(box_format='xyxy', iou_type='bbox', extended_summary=True).to(self.device)
        self.all_maps = []
        self.mAP_split = []
        self.best_map = 0.
        self.metrics_to_print = ['map', 'map_50', 'map_75']
        # depth-sorted labels / class probabilities / 2**label weights of the last validation pass
        self.eval_outputs = {}

    def save_model(self):
        """Save the model state dict as ``<save_path>/<model_name>_best.pt`` (creating the dir)."""
        os.makedirs(self.save_path, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(self.save_path, f"{self.model_name}_best.pt"))

    def load_model(self, model_path):
        """Load a state dict, mapping its tensors onto this trainer's device."""
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))

    def train(self):
        """Run ``max_epochs`` epochs of training + validation, plotting every ``steps_per_plot`` epochs."""
        for epoch in range(self.max_epochs):
            print(f"Epoch {epoch+1}/{self.max_epochs}")
            self.train_one_epoch()
            self.eval_one_epoch()
            if (epoch + 1) % self.steps_per_plot == 0:
                self.plot_examples()
            # TODO: periodic checkpointing (``self.checkpoints`` / ``self.step_per_save``) is
            # configured but not implemented; the best-mAP model is saved in eval_one_epoch.

    @staticmethod
    def _valid_indices(inputs, boxes):
        """Indices of samples whose input volume and boxes are not all zeros (padding)."""
        return [i for i in range(inputs.shape[0]) if bool(inputs[i].any()) and bool(boxes[i].any())]

    def train_one_epoch(self):
        """Run one optimisation pass over the train loader.

        Samples whose input volume or boxes are all zeros are skipped. A ``OneCycleLR``
        scheduler is stepped per batch, any other scheduler once at the end of the epoch.
        """
        self.model.train()  # Set model to training mode
        metrics = defaultdict(list)
        n_cond = len(self.condition_to_get)
        loader = self.dataloaders['train']

        with tqdm(loader, unit="batch", total=len(loader)) as tepoch:
            for inputs, boxes, labels in tepoch:
                inputs = inputs.to(self.device).reshape(-1, self.series_len, inputs.shape[-2], inputs.shape[-1])
                boxes = boxes.to(self.device).reshape(-1, n_cond, 6)
                labels = labels.to(self.device).reshape(-1, n_cond, 1)
                valid = self._valid_indices(inputs, boxes)
                if not valid:
                    continue
                with torch.set_grad_enabled(True):
                    loss, loss_info = self.model.get_loss(inputs[valid], boxes[valid], labels[valid])
                    for loss_t, loss_v in loss_info.items():
                        metrics[loss_t].append(loss_v.clone().detach().cpu().numpy())
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    if self.scheduler is not None and self.one_cycle_sched:
                        self.scheduler.step()

                tepoch.set_description(self.metrics_description(metrics, 'train'))

        if self.scheduler is not None and not self.one_cycle_sched:
            self.scheduler.step()

    def eval_one_epoch(self):
        """Validate, append the mAP dict to ``all_maps`` and save the model on a new best mAP.

        Predictions and targets are reduced to sagittal 2D boxes (coordinates 0, 1, 3, 4)
        before computing torchmetrics ``MeanAveragePrecision``. Depth-sorted labels,
        per-prediction class probabilities and ``2**label`` weights are stored in
        ``self.eval_outputs`` for further analysis (e.g. confusion matrices / log-loss).
        """
        self.model.eval()
        n_cond = len(self.condition_to_get)
        targets = []
        preds = []

        alabels = []
        apreds = []
        aweight = []

        loader = self.dataloaders['val']
        with tqdm(loader, unit="batch", total=len(loader)) as tepoch:
            for inputs, boxes, labels, _stype in tepoch:
                with torch.set_grad_enabled(False):
                    inputs = inputs.to(self.device).reshape(-1, self.series_len, inputs.shape[-2], inputs.shape[-1])
                    boxes = boxes.to(self.device).reshape(-1, n_cond, 6)
                    labels = labels.to(self.device).reshape(-1, n_cond, 1)

                    valid = self._valid_indices(inputs, boxes)
                    if not valid:
                        continue
                    inputs = inputs[valid]
                    boxes = boxes[valid]
                    labels = labels[valid]

                    out = self.model.predict(inputs)

                    # class probabilities of every kept prediction, ordered by box x1 (depth)
                    for o in out:
                        apreds.append(o['logits'][o['boxes'][:, 0].argsort()])

                    # ground-truth labels ordered the same way so they line up with apreds
                    sorted_labels = [label[box[:, 0].argsort(-1)]
                                     for label, box in zip(labels.reshape(-1, n_cond), boxes)]
                    slabels = torch.cat(sorted_labels)
                    alabels.append(slabels)
                    aweight.append(2 ** slabels)

                    # mAP is computed on sagittal 2D boxes: coordinates 0, 1, 3, 4
                    for pred in out:
                        pred['boxes'] = pred['boxes'][..., [0, 1, 3, 4]]
                    preds += out

                    # targets: drop all-zero (absent) boxes of every sample
                    for box, label in zip(boxes[..., [0, 1, 3, 4]], labels):
                        present = (box != 0).any(dim=-1).nonzero()
                        targets.append(dict(boxes=box[present].squeeze(1),
                                            labels=label[present].squeeze((1, 2)).to(torch.int)))

        if alabels:
            self.eval_outputs = {
                'labels': torch.cat(alabels, dim=0).cpu().numpy(),
                'probs': torch.cat(apreds, dim=0).cpu().numpy(),
                'weights': torch.cat(aweight, dim=0).cpu().numpy(),
            }
        else:
            # no valid validation batch: torch.cat on an empty list would raise
            self.eval_outputs = {}

        self.mAP.update(preds=preds, target=targets)
        all_maps = dict(self.mAP.compute())
        self.all_maps.append(all_maps)
        self.mAP.reset()
        if all_maps['map'] > self.best_map:
            self.save_model()
            self.best_map = all_maps['map']

        if self.print_evaluation:
            print("All mAP validation metrics: " + self.get_map_str(all_maps))

    def get_map_str(self, map: Dict) -> str:
        """Format the metrics listed in ``metrics_to_print`` as ``"key: value || "`` pairs.

        Scalar tensors are printed as plain numbers with 4 decimals.
        """
        parts = []
        for key, value in map.items():
            if key not in self.metrics_to_print:
                continue
            if torch.is_tensor(value) and value.numel() == 1:
                value = f"{float(value):.4f}"
            parts.append(f"{key}: {value} || ")
        return "".join(parts)

    @staticmethod
    def _draw_box(axis, box, color):
        """Draw the outline of an ``[x1, y1, x2, y2]`` box on a matplotlib axis."""
        axis.add_patch(patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                                         linewidth=1, edgecolor=color, facecolor='none'))

    def plot_examples(self, plot_num:int=1):
        """Plot ``plot_num`` random validation samples with ground-truth and predicted boxes.

        Per sample (two columns each, up to 5 samples): row 0 shows ground truth, rows 1-2
        predictions; the even column is a slice along dim 0 of the volume, the odd column the
        middle slice of ``img.permute(1, 2, 0)``.
        """
        color_dict = dict([(0,'r'),(1,'g'), (2,'b'), (3,'m'), (4, 'y'), (5, 'r'),  (6, 'r'),  (7, 'r'),  (8, 'r'), (9, 'r')])
        self.model.eval()
        mid = int(self.series_len / 2)
        for _ in range(plot_num):
            fig, ax = plt.subplots(3, 2 * 5, figsize=(16, 8))
            for a in ax.ravel():
                a.set_axis_off()
                a.set_yticklabels([])
                a.set_xticklabels([])

            inputs, boxes, labels, _ = self.dataloaders['val'].dataset.get_all_random()
            with torch.set_grad_enabled(False):
                preds = self.model.predict(inputs.to(self.device).reshape(-1, self.series_len, inputs.shape[-2], inputs.shape[-1]))

            for col, (img, box, label, pred) in enumerate(zip(inputs.squeeze(), boxes.squeeze(), labels.squeeze(), preds)):
                sag_col, cor_col = 2 * col, 2 * col + 1
                first_slice = img[int(box[0, 2] + box[0, 5]) // 2, :, :].detach().cpu().numpy()
                ax[0, sag_col].imshow(first_slice)
                ax[1, sag_col].imshow(first_slice)
                ax[2, sag_col].imshow(img[int(box[1, 2] + box[1, 5]) // 2, :, :].detach().cpu().numpy())

                cor_slice = img.permute(1, 2, 0)[mid, :, :].detach().cpu().numpy()
                for row in range(3):
                    ax[row, cor_col].imshow(cor_slice)

                # ground truth (row 0)
                for j in range(box.shape[0]):
                    color = color_dict[int(label[j].numpy())]
                    self._draw_box(ax[0, sag_col], box[j, [0, 1, 3, 4]].numpy(), color)
                    self._draw_box(ax[0, cor_col], box[j, [2, 0, 5, 3]].numpy(), color)

                # predictions (rows 1 and 2)
                for j in range(len(pred['boxes'])):
                    b_sag = pred['boxes'][j, [0, 1, 3, 4]].cpu().numpy()
                    b_cor = pred['boxes'][j, [2, 0, 5, 3]].cpu().numpy()
                    color = color_dict[int(pred['labels'][j].cpu().numpy())]
                    for row in (1, 2):
                        self._draw_box(ax[row, sag_col], b_sag, color)
                        self._draw_box(ax[row, cor_col], b_cor, color)

            plt.subplots_adjust(wspace=0.1, hspace=0)
            plt.show()

    def metrics_description(self, metrics:dict, phase:str)->str:
        """Format the running mean of every collected loss term for the tqdm bar."""
        outputs = phase + ": ||"
        for k in metrics.keys():
            outputs += (" {}: {:4f} ||".format(k, np.mean(metrics[k])))
        return outputs

    def get_summary(self, desc:str = None):
        """Summarise the run at the epoch with the best validation mAP.

        Args:
            desc: optional free-text description stored under ``'description'``.
        Returns:
            dict with model settings, ``best_epoch`` and the ``metrics_to_print`` values of
            that epoch (prefixed ``general_``), plus split metrics if ``mAP_split`` is filled.
        Raises:
            ValueError: if no epoch has been evaluated yet.
        """
        if not self.all_maps:
            raise ValueError("get_summary() needs at least one evaluated epoch (self.all_maps is empty).")

        summary = {
            'model_name': self.model_name,
            'backbone': self.model.backbone_name,
            'conv_type': self.model.head_conv_mode,
            'use_reg_for_cls': self.model.use_reg_for_cls,
            'mix_strategy': self.config['train_dataset_config']['mix_strategy']
        }
        # float() also handles scalar tensors on the GPU, which np.argmax cannot convert
        best_epoch = int(np.argmax([float(epm['map']) for epm in self.all_maps]))
        summary['best_epoch'] = best_epoch

        for key, val in self.all_maps[best_epoch].items():
            if key in self.metrics_to_print:
                summary[f"general_{key}"] = val

        if self.mAP_split:
            for map_split_in_name, map_split_in in self.mAP_split[best_epoch].items():
                for map_split_out_name, map_split_out in map_split_in.items():
                    for key, val in map_split_out.items():
                        if key in self.metrics_to_print:
                            summary[f"{map_split_in_name}_{map_split_out_name}_{key}"] = val
        if desc:
            summary['description'] = desc

        return summary


In [13]:
def lineinter(line1, line2):
    """1-D IoU of the extent at coordinates 2 and 5 of one box against many boxes.

    In the ``[x, y, z, x, y, z]`` box layout used by the NMS below these are the z-limits.

    Args:
        line1: 1-D tensor with at least 6 entries (only indices 2 and 5 are used).
        line2: 2-D tensor ``(N, >=6)`` in the same layout.
    Returns:
        Tensor ``(N,)``: intersection of the ``[c2, c5]`` intervals divided by their union
        (0 for disjoint or zero-length intervals).
    """
    start1, end1 = line1[2], line1[5]
    start2, end2 = line2[:, 2], line2[:, 5]
    inter = (torch.min(end1, end2) - torch.max(start1, start2)).clamp(0)
    # Union of two intervals is len1 + len2 - inter. (Hull length minus the intersection
    # is not the union and divides by ~0 for identical intervals.)
    union = (end1 - start1).clamp(0) + (end2 - start2).clamp(0) - inter
    return inter / (union + 1e-7)

class Nms3dSagittalForamina(nn.Module):
    """Class-wise greedy NMS for 3D boxes keeping at most ``max_per_cls`` boxes per class.

    Args:
        iou_treshold: a box is suppressed when its 3D IoU (``y9.bbox_iou``) with an already
            kept box of the same class is above this value.
        score_treshold: boxes with a lower score are discarded before NMS.
        del_same_depth: if True, additionally suppress boxes whose ``lineinter`` overlap
            with a kept box is >= 0.1.
        unique_cls: class ids ``0 .. unique_cls-1`` are processed; other ids are dropped.
        max_per_cls: maximal number of boxes kept per class (default 2, "max 2 foraminas
            per image").
    """
    def __init__(self, iou_treshold=0.3, score_treshold = 0., del_same_depth=False, unique_cls = 2, max_per_cls=2) -> None:
        super().__init__()
        self.score_tr = score_treshold
        self.iou_tr = iou_treshold
        self.del_same_depth = del_same_depth
        self.unique_cls = unique_cls
        self.max_per_cls = max_per_cls

    def forward(self, results):
        """Apply NMS to rows ``[score, label, x, y, z, x, y, z, ...]``.

        Returns the kept rows (highest score first within each class, classes in id order);
        an empty ``(0, C)`` tensor when nothing survives.
        """
        aresult = results.clone().detach()
        out = []
        for cls_id in range(self.unique_cls):
            result = aresult[aresult[:, 1] == cls_id]
            result = result[result[:, 0].argsort(dim=0, descending=True)]  # sort by score
            result = result[result[:, 0] >= self.score_tr]
            if result.nelement() == 0:
                # Skip the empty class; returning here would drop boxes already kept
                # for previous classes.
                continue

            kept = []
            while result.nelement() != 0:
                kept.append(result[[0]])
                # reshape(-1) keeps a 1-D mask even when a single row is left
                ious = y9.bbox_iou(result[0, 2:2+6], result[:, 2:2+6], iou_mode=True).reshape(-1)
                keep = ious <= self.iou_tr
                if self.del_same_depth:
                    keep = torch.logical_and(keep, lineinter(result[0, 2:2+6], result[:, 2:2+6]) < 0.1)
                # always drop the reference box itself so the loop terminates even for
                # degenerate boxes whose self-IoU is not above the threshold
                keep[0] = False
                result = result[keep]

            out.append(torch.cat(kept, dim=0)[:self.max_per_cls])

        if not out:
            return aresult[:0]
        return torch.cat(out, dim=0)

In [14]:
class BoxModel(nn.Module):
    """3D box detector: MedicalNet ``resnet_18_23dataset`` backbone + ``y9.Detect3d`` head.

    Args:
        backbone_name: stored as ``self.backbone_name`` (reported in training summaries);
            the feature extractor itself is always MedicalNet ``resnet_18_23dataset``.
        series_dim: stored as ``self.series_dim``; not used to build the network.
        num_classes: number of classes of the detection head.
        use_features: backbone levels fed to the head; level indices count from the end:
            ``'last'`` or ``[0]`` -> last level, ``'all'`` -> every level,
            ``[0, 1, 2]`` -> last three levels.
        reg_max: passed to ``y9.Detect3d`` as ``reg_max``.
        pretrained: unused (``init_FE`` is always called with the pretrained ``.pth`` path).
        head_conv_mode, use_reg_for_cls: stored as attributes; not used to build the network.

    Raises:
        ValueError: if ``use_features`` is not ``'all'``, ``'last'`` or a list/tuple of
            level indices in range.
    """
    def __init__(self, backbone_name:str, series_dim:list[int], num_classes:int=1,
                 use_features:Union[str, list[int]]=[0], reg_max:int=16, pretrained:bool=False,
                 head_conv_mode:str = '2d', use_reg_for_cls:bool = False):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.num_classes = num_classes
        self.series_dim = series_dim
        self.backbone_name = backbone_name
        self.joint_train = False

        # MedicalNet options shared by every pretrained variant
        opts = {
            'model': 'resnet',
            'input_W': 128,
            'input_H': 128,
            'input_D': 32,
            'device': self.device,
            'phase': 'train',
        }
        model_pretrained_params = {
            'resnet_10': {'model_depth': 10, 'resnet_shortcut': 'B'},
            'resnet_10_23dataset': {'model_depth': 10, 'resnet_shortcut': 'B'},
            'resnet_18': {'model_depth': 18, 'resnet_shortcut': 'A'},
            'resnet_18_23dataset': {'model_depth': 18, 'resnet_shortcut': 'A'},
            'resnet_34': {'model_depth': 34, 'resnet_shortcut': 'A'},
            'resnet_34_23dataset': {'model_depth': 34, 'resnet_shortcut': 'A'}
        }
        for model_name, model_dict in model_pretrained_params.items():
            model_pretrained_params[model_name] = Struct({**model_dict, **opts})

        self.feature_extractor = MedNet('resnet_18_23dataset', model_pretrained_params).to(self.device)
        self.feature_extractor.init_FE(self.device, '/workspaces/RSNA_LSDC/MedicalNet/Pretrained/resnet_18_23dataset.pth')

        # hard-coded channels / strides of the four backbone output levels
        # (a timm_3d backbone used to provide these via feature_info)
        self.all_channels = [64, 128, 256, 512]
        self.reduction = [4, 8, 8, 8]

        print(self.all_channels, self.reduction)

        self.in_channels, self.featmap_stride, self.fl = self._select_features(use_features)

        self.head_conv_mode = head_conv_mode
        self.use_reg_for_cls = use_reg_for_cls

        # NMS: IoU 0.1, score 0.1, also drop boxes overlapping in depth, class id 0 only
        self.postprocess = Nms3dSagittalForamina(0.1, 0.1, True, 1)
        self.head = y9.Detect3d(nc=num_classes, ch=self.in_channels, strides=self.featmap_stride, reg_max=reg_max, return_logits=True)
        h = {
            "device": self.device,
            "cls_pw":None,
            "label_smoothing": 0.0,
            "fl_gamma": 0.0,
            'bbox_weight': 5.5,
            'class_weight': .5,
            'dfl_weight':3.,
        }
        self.CL = y9.ComputeLoss(self.head, h)

    def _select_features(self, use_features):
        """Map ``use_features`` to (head input channels, strides, backbone level indices)."""
        n_levels = len(self.all_channels)
        if isinstance(use_features, str):
            if use_features == 'all':
                levels = list(range(n_levels))
            elif use_features == 'last':
                levels = [n_levels - 1]
            else:
                raise ValueError(f"use_features must be 'all', 'last' or a list of level indices, got {use_features!r}")
        elif isinstance(use_features, (list, tuple)):
            if any(not 0 <= i < n_levels for i in use_features):
                raise ValueError(f"use_features indices must be in [0, {n_levels - 1}], got {use_features!r}")
            # index i counts from the end: 0 = last level, 1 = second to last, ...
            levels = [n_levels - (1 + i) for i in sorted(use_features, reverse=True)]
        else:
            raise ValueError(f"use_features must be 'all', 'last' or a list of level indices, got {use_features!r}")
        return [self.all_channels[lvl] for lvl in levels], [self.reduction[lvl] for lvl in levels], levels

    def forward(self, x: torch.Tensor) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tensor): input series; a channel dim is inserted at position 1 before the backbone.
        Returns:
            Output of ``self.head`` on the selected backbone levels (``self.fl``).
        """
        x = self.feature_extractor(x.unsqueeze(1))
        x = [x[i] for i in self.fl]
        return self.head(x)

    def get_loss(self, series, gt_boxes, gt_labels):
        """Return ``(loss, {'loss_cls', 'loss_bbox', 'loss_dfl'})`` from ``self.CL``."""
        head_out = self.forward(series)
        loss, loss_split, pred_a_boxes = self.CL(head_out,
                                        gt_boxes, gt_labels)

        metrics = dict(loss_cls=loss_split[1], loss_bbox=loss_split[0], loss_dfl=loss_split[2])

        return loss, metrics

    def get_predictions(self, series, gt_boxes, gt_labels):
        """Return the assigned predictions produced by ``self.CL`` as a list of dicts per sample."""
        result_list = []
        head_out = self.forward(series)
        loss, loss_split, pred_a_boxes = self.CL(head_out,
                                        gt_boxes, gt_labels)
        for result in pred_a_boxes:
            if result.nelement() == 0:
                result_list.append({'boxes': result, 'scores': result, 'labels': result.int(), 'logits': result})
            else:
                result_list.append({'boxes': result[:, 2:], 'scores': result[:, 0], 'labels': result[:, 1].int()})
        return result_list

    def predict(self, series):
        """Run the detector and post-process every sample with ``self.postprocess`` (3D NMS).

        Returns:
            list (one per sample) of dicts with ``'boxes'`` (the 6 box coordinates, columns
            2:8 of the NMS output), ``'scores'``, ``'labels'`` (int) and ``'logits'``
            (softmax over the columns after the box).
        """
        scores, labels, dbox, logits = self.forward(series)
        result_list = []
        for score, label, boxes, logit in zip(scores, labels, dbox, logits):
            # assumes per-sample head outputs are (rows, num_anchors); stacked and transposed to
            # (num_anchors, [score, label, 6 box coords, logits...]) — TODO confirm with y9.Detect3d
            result = self.postprocess(torch.cat((score, label, boxes, logit), dim=0).permute(1, 0))
            result_list.append({'boxes': result[:, 2:8], 'scores': result[:, 0], 'labels': result[:, 1].int(), 'logits': result[:, 8:].softmax(-1)})
        return result_list


# Training model

In [15]:
im_size = 96

# Training augmentations (torchvision v2): one flip / small affine / +-90 deg rotation,
# one perspective or translate+shear distortion, random scaling, then either an
# elastic deformation or a Gaussian blur.
train_transforms = v2.Compose([
    v2.RandomChoice([v2.RandomVerticalFlip(p = 0.5), v2.RandomHorizontalFlip(p = 0.5), v2.RandomAffine(degrees=5), 
                     v2.RandomRotation(degrees=(90,90)), v2.RandomRotation(degrees=(-90,-90))]),
    v2.RandomChoice([v2.RandomPerspective(distortion_scale=0.2, p=1.0), v2.RandomAffine(degrees=0, translate=(0.3,0.3), shear=(-5,5,-5,5))]), # translation + shearing
    v2.RandomAffine(degrees=0, scale=(0.8,1.2)), #scaling
    v2.RandomChoice([v2.ElasticTransform(alpha=40.0), v2.GaussianBlur(kernel_size=(5,5), sigma=(0.7, 0.7))]),
])

# Validation: a vertical flip with p=0 never fires, i.e. no augmentation.
val_transforms = v2.Compose([
    v2.RandomVerticalFlip(p = 0.),
    #v2.Resize((im_size,im_size)), #resize
])

# Per-condition overheads for each axis as two-value lists, consumed by BoxDatasetUnited
# (their exact semantics/units are defined there).
cond_x_overhead = {
    'Left Neural Foraminal Narrowing': [6,9],
    'Right Neural Foraminal Narrowing': [6,9],
    'Left Subarticular Stenosis': [5,5],
    'Right Subarticular Stenosis': [5,5],
    'Spinal Canal Stenosis': [10, 10]
}
cond_y_overhead = {
    'Left Neural Foraminal Narrowing': [11,15],
    'Right Neural Foraminal Narrowing': [11,15],
    'Left Subarticular Stenosis': [4,4],
    'Right Subarticular Stenosis': [4,4],
    'Spinal Canal Stenosis': [200, 200]
}
cond_z_overhead = {
    'Left Neural Foraminal Narrowing': [12,12],
    'Right Neural Foraminal Narrowing': [12,12],
    'Left Subarticular Stenosis': [8,8],
    'Right Subarticular Stenosis': [8,8],
    'Spinal Canal Stenosis': [20, 20]
}
train_dataset_config ={
    # preload data into memory (WARNING IT MAY TAKE A LOT OF SPACE - DEPENDS ON THE DATASET USED).
    # This key used to appear twice (True, then False); the later False was the effective value.
    'preload': False,
    'im_size': [im_size, im_size],
    'dataset_path': "/workspaces/RSNA_LSDC/inputs/dataset",
    'load_series': ['sagittal'], # series types to load into dataset ['sagittal', 'axial', 'sagittal_t2']
    'united': True,

    'transforms': train_transforms,
    'vsa': False,
    'one_label': False, # use one label for every level (do not differentiate between levels)
    'series_out_types': ['sagittal'], # mix output series types to get views necessary to create 3d box ['sagittal', 'coronal']
    'get_conditions':['Left Neural Foraminal Narrowing', 'Right Neural Foraminal Narrowing'], # condition to output from series
    # strategy for mixing output series types: 'random' - randomly select output type,
    # 'custom'/'manual' - return data based on the currently chosen view
    # (NOTE(review): the option list said 'custom' but its description said 'manual' — confirm the accepted name),
    # 'combined' - return all views in one call
    'mix_strategy': 'combined',
    'return_series_type': False, # If True getitem will also return the series' original type
    
    'normalize': True,
    'image_type': 'png',

    'dataset_type': 'boxes', #'boxes', 'conditions'
    'supress_warinings': False, 

    'limit_series_len': True, # if True the series length will be limited to number N specified by 'series_len' parameter
    'series_len': 48, #  maximal number of slices in series

    'x_overhead': [25,25], # overhead for levels in x-dim (in mm)
    'z_overhead':25,
    'overlap_levels': True, # if true the level upper and lower boundary will overlap with value specified in 'y_overlap'
    'y_overlap': 5, # overlap size of levels boundaries (in mm)
    'cond_x_overhead': cond_x_overhead,
    'cond_y_overhead': cond_y_overhead,
    'cond_z_overhead': cond_z_overhead,

    '3d_box': True
}

# Validation: same data settings, no augmentation, and the series type is returned
# (SagittalTrainer.eval_one_epoch unpacks 4 values per batch).
val_dataset_config = copy.deepcopy(train_dataset_config)
val_dataset_config['transforms'] = val_transforms
val_dataset_config['mix_strategy'] = 'combined'
val_dataset_config['vsa'] = False
val_dataset_config['rev'] = True
val_dataset_config['return_series_type'] = True

In [16]:
trainer_config = {
    "print_evaluation": True,
    "steps_per_plot": 1, # plot validation examples every N epochs

    "checkpoints": False,
    "save_path": "/workspaces/RSNA_LSDC/model_weight", # best-mAP weights are written here
    "step_per_save":100,

    "model_name": "foramina_detect_sagittal_mednet18_96_96_48",
    "train_dataset_config": train_dataset_config,
    "val_dataset_config": val_dataset_config,

    "optimizer": torch.optim.AdamW,
    "optimizer_params": {'lr': 2e-4},
    # Non-OneCycleLR schedulers are stepped once per epoch, OneCycleLR once per batch.
    # Alternative: torch.optim.lr_scheduler.OneCycleLR with
    # {'max_lr': 0.001, 'epochs': None, 'steps_per_epoch': None} (the None values are filled in by the trainer).
    "scheduler": torch.optim.lr_scheduler.ExponentialLR,
    "scheduler_params":{'gamma':0.99},

    "epochs": 15, 
    "batch_size":2,
    "early_stopping": False,
    "early_stopping_treshold": 0.1
}

model_config = {
    # BoxModel always builds MedicalNet 'resnet_18_23dataset'; backbone_name is only
    # reported in the training summary, so it must name the backbone actually used.
    'backbone_name': 'resnet_18_23dataset', 
    'series_dim': [train_dataset_config['series_len']]+train_dataset_config['im_size'],
    'use_features': [0,1], # last two backbone levels

    'reg_max': 16, 
    'pretrained': False, # NOTE: ignored by BoxModel, which always calls init_FE with the pretrained .pth
    'num_classes': 1, #1 if train_dataset_config['one_label'] else 5,
    'head_conv_mode': '3d',
    'use_reg_for_cls': False
    }


In [17]:
# append .iloc[0:300] for a quick debug subset
tsd = pd.read_csv('/workspaces/RSNA_LSDC/inputs/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv')

def kfoldCV(k, trainer_config, model_config, seed=None, series_descriptions=None):
    """Cross-validate ``BoxModel`` with ``SagittalTrainer``.

    Args:
        k: number of folds. ``k == 1`` uses the fixed split stored in
            ``test_unique_studies.npy`` (validation) / ``train_unique_studies.npy`` (training).
        trainer_config, model_config: passed to ``SagittalTrainer`` / ``BoxModel``.
        seed: optional seed for the study shuffle used to build the folds when ``k > 1``;
            ``None`` keeps using the global NumPy RNG.
        series_descriptions: frame with a ``study_id`` column used to build the folds;
            defaults to the module-level ``tsd``.
    Returns:
        list with one ``SagittalTrainer.get_summary()`` dict per fold.
    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # NOTE: pd.read_pickle can execute arbitrary code — these files must be trusted.
    coords_dir = '/workspaces/RSNA_LSDC/inputs/box_data/coordinates'
    coordinates = {
        'sagittal': pd.read_pickle(os.path.join(coords_dir, 'coordinates_unified_sagital_t1.pkl')),
        'sagittal_t2': pd.read_pickle(os.path.join(coords_dir, 'coordinates_unified_sagital_t2.pkl')),
        'axial': pd.read_pickle(os.path.join(coords_dir, 'coordinates_axial_unified.pkl')),
    }

    if k == 1:
        with open('/workspaces/RSNA_LSDC/inputs/train_unique_studies.npy', 'rb') as f:
            train = np.load(f)
        with open('/workspaces/RSNA_LSDC/inputs/test_unique_studies.npy', 'rb') as f:
            test = np.load(f)
        # fold 0 (validated on) is the test list, the rest is trained on
        folds = [test, train]
    else:
        descriptions = tsd if series_descriptions is None else series_descriptions
        all_study_ids = np.asarray(descriptions.study_id.unique())
        if seed is None:
            shuffled = np.random.permutation(all_study_ids)
        else:
            shuffled = np.random.default_rng(seed).permutation(all_study_ids)
        folds = np.array_split(shuffled, k)

    def _subset(study_ids):
        """Restrict every coordinate frame to the given study ids."""
        return {name: frame[frame.study_id.isin(study_ids)] for name, frame in coordinates.items()}

    model_summaries = []
    for i in range(k):
        print(f"Fold: {i}")
        train_ids = np.concatenate(folds[:i] + folds[i + 1:], axis=0)
        trainer = SagittalTrainer(BoxModel, model_config, trainer_config, _subset(train_ids), _subset(folds[i]))
        trainer.train()
        model_summaries.append(trainer.get_summary())

    return model_summaries
# convformer_s18

In [None]:
# Print tensors in full (no "..." truncation) — can produce very large outputs.
torch.set_printoptions(profile="full")
# k == 1: train on the stored train study list, validate on the stored test study list.
kfoldCV(1, trainer_config, model_config)

In [None]:
# %%
# NOTE(review): inside a Jupyter kernel __name__ == '__main__', so this guard is always
# true here and running this cell repeats the full training run of the cell above.
# It only makes sense when the notebook is exported to a .py script.
if __name__ == '__main__':
    kfoldCV(1, trainer_config, model_config)



In [None]:
# Sanity check: does torch see a CUDA device?
torch.cuda.is_available()

In [None]:
# Scratch check of a MedicalNet backbone outside BoxModel.
# NOTE(review): these imports duplicate the top-of-notebook import cell.
from MedicalNet.MedicalNet import Struct, MedNet
import torch

# Options merged into every variant below; 'device' is hard-coded to 'cuda', and the
# model is moved to 'cuda' further down, so this cell needs a GPU.
opts = {
'model': 'resnet',
'input_W': 224,
'input_H': 224,
'input_D': 32,
'device': 'cuda',
'n_seg_classes': 4,
'phase': 'train',
}

# Depth / shortcut type per pretrained variant (same table as in BoxModel.__init__).
model_pretrained_params = {
    'resnet_10': {'model_depth': 10, 'resnet_shortcut': 'B'},
    'resnet_10_23dataset': {'model_depth': 10, 'resnet_shortcut': 'B'},
    'resnet_18': {'model_depth': 18, 'resnet_shortcut': 'A'},
    'resnet_18_23dataset': {'model_depth': 18, 'resnet_shortcut': 'A'},
    'resnet_34': {'model_depth': 34, 'resnet_shortcut': 'A'},
    'resnet_34_23dataset': {'model_depth': 34, 'resnet_shortcut': 'A'}
}


# wrap each variant's settings (plus the shared opts) into a Struct for MedNet
for model_name, model_dict in model_pretrained_params.items():
    model_pretrained_params[model_name] = Struct({**model_dict, **opts})

def construct_network(feature_extractor, model_pretrained_params):
    model = MedNet(feature_extractor, model_pretrained_params)
    return model

# Build the ResNet-34 variant. init_FE is called here without a weights path (BoxModel passes
# an explicit .pth path) — presumably MedNet falls back to a default path; confirm.
model = construct_network('resnet_34_23dataset', model_pretrained_params).to('cuda')
model.init_FE('cuda')

In [5]:
# Smoke test on random input; shape assumed to be (batch, channel, D, H, W), matching
# input_D=32, input_H=input_W=224 in `opts` — TODO confirm. (Cell count In [5] shows it
# was run out of order relative to the cells above.)
x = model(torch.rand((1,1,32,224,224)).to('cuda'))

In [None]:
# Print the shape of every backbone output level (x is iterated as a sequence of tensors,
# as BoxModel.forward does when selecting levels).
for i in x:
    print(i.shape)

In [None]:
#TIMM 3D