In [6]:
import sys
import os
# Insert the resolved parent of the current working directory at position 1 of the import path
# (presumably so project modules one level above the notebook can be imported - confirm).
sys.path.insert(1, os.path.realpath(os.path.pardir))

In [None]:
import os
import cv2
import glob
import copy
import torch
import random
from torch import nn
import timm
import timm_3d
from typing import List, Sequence, Tuple, Union, Dict
from scipy import ndimage

from sklearn.metrics import log_loss
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from PIL import Image
import pandas as pd
from typing import Literal
from collections import defaultdict
import numpy as np
from tqdm.notebook import tqdm

import albumentations as A
from torchvision.tv_tensors import BoundingBoxes as BB
from torchmetrics.detection import MeanAveragePrecision
import torchvision.transforms.v2 as v2
import torch.nn.functional as F
from torchvision.ops import nms

import matplotlib.pyplot as plt
import matplotlib.patches as patches

In [None]:
# Display whether PyTorch reports a usable CUDA device in this kernel.
torch.cuda.is_available()

Post-processing of the found bboxes:
Include guaranteed information about the bboxes in an image.
-> Either 5 classes with one bbox each, or 1 class with 5 bboxes
-> If levels n-1 and n+1 exist in the image, level n must also exist
-> The boxes of levels n-1, n, n+1 must be adjacent to one another along the height dimension
-> The boxes must overlap in the x dimension to some extent
-> If there is a series of boxes for levels n, n+1, n+2 and the image height is bigger than the mean level height, then level n+3 must also exist - the same applies in reverse


# DATASET

In [10]:
class Bbox3d():
    """3D box stored as three two-element ranges ``x``, ``y`` and ``z``.

    The ranges are kept in the coordinate frame of the sagittal series; the
    ``get_*`` methods only reorder them for the requested view.
    """

    def __init__(self, x, y, z) -> None:
        # 3d box in coordinates of sagittal series
        self.x, self.y, self.z = x, y, z

    @staticmethod
    def _pack(horizontal, vertical, depth, d3):
        # Lay out [h0, v0, h1, v1] and, for 3D output, append [d0, d1].
        packed = [horizontal[0], vertical[0], horizontal[1], vertical[1]]
        if d3:
            packed.extend([depth[0], depth[1]])
        return packed

    def get_box_in_view_type(self, view_type, d3:bool=False):
        """Return the box for ``view_type``.

        'sagittal' and 'sagittal_t2' share the sagittal layout; 'coronal' and
        'axial' use their own. Any other value yields ``None``.
        """
        if view_type in ['sagittal', 'sagittal_t2']:
            return self.get_sagittal(d3)
        if view_type == 'coronal':
            return self.get_coronal(d3)
        if view_type == 'axial':
            return self.get_axial(d3)
        return None

    def get_sagittal(self, d3:bool=False):
        """Return [x0, y0, x1, y1], followed by [z0, z1] when ``d3`` is set."""
        return self._pack(self.x, self.y, self.z, d3)

    def get_coronal(self, d3:bool=False):
        """Return [z0, y0, z1, y1], followed by [x0, x1] when ``d3`` is set."""
        return self._pack(self.z, self.y, self.x, d3)

    def get_axial(self, d3:bool=False):
        """Return [z0, x0, z1, x1], followed by [y0, y1] when ``d3`` is set."""
        return self._pack(self.z, self.x, self.y, d3)

In [11]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, data_info:Dict[str, pd.DataFrame], config:Dict):
        """Read the dataset configuration and index every series via ``prepare_data``.

        Args:
            data_info: dict mapping a series type name to a dataframe with that
                type's series information (entries may be ``None``).
            config: dict with the dataset configuration. Keys are read with
                plain indexing, so a missing key raises ``KeyError``; only
                'rev' is optional.

        Raises:
            Exception: if a type listed in ``config['load_series']`` is not a
                key of ``data_info``.
        """
        #TODO: Better overlap in instance dimension
        #TODO: smart limit - selecting important slices based on series type and where they usually lay
        
        # NOTE: the config key really is spelled 'supress_warinings'.
        self.supress_warnings = config['supress_warinings']
        self.used_series_types = config['load_series']
        # Union of the study ids found in all provided (non-None) dataframes.
        self.study_ids = np.unique(np.concatenate([series.study_id.unique() for series in data_info.values() if series is not None]))
        self.data_info = data_info
        if not np.all([tp in list(self.data_info.keys()) for tp in self.used_series_types]):
            raise Exception("Series types to use do not match provided data information.")
        
        self.series_out_types = config['series_out_types']
        # The first configured output type is the initial view (see set_view).
        self.current_view = self.series_out_types[0]
        self.return_series_type = config['return_series_type']

        # 3D boxes are always on; config['3d_box'] is currently ignored (previous lookup kept below).
        self.d3 = True #config['3d_box'] if '3d_box' in list(config.keys()) else False
        self.preload = config['preload']
                 
        # Preloading is disabled whenever more than one output series type is requested.
        if self.preload and len(self.series_out_types)>1:
            self.preload = False
            if not self.supress_warnings:
                print("Preloading with mixed series type outputs is not supported. Preloading was turned off.")

        # im_size is indexed as [0] = width, [1] = height (Resize receives (im_size[1], im_size[0])).
        self.im_size = config['im_size']
        if self.im_size:
            self.resize = v2.Resize((self.im_size[1], self.im_size[0]))

        self.image_type = config['image_type']
        # 'rev' is the only optional key; it defaults to False.
        self.rev_lr = config['rev'] if 'rev' in list(config.keys()) else False
        if self.preload and not self.supress_warnings:
            print("Warning! Preloading of images is turned on. The program will attempt to load whole dataset into memory!")
        
        self.transforms = config['transforms']
        self.dtransforms = config['dtransforms']
        self.normalize = config['normalize']
        self.vsa = config['vsa']
        
        self.dataset_type = config['dataset_type']
        self.dataset_path = config['dataset_path']

        # Canonical condition order; positions in this list are used as condition label indices.
        self._condition_list = ['Left Neural Foraminal Narrowing', 
                                'Left Subarticular Stenosis', 
                                'Right Neural Foraminal Narrowing', 
                                'Right Subarticular Stenosis', 
                                'Spinal Canal Stenosis']
        
        # Indices (into _condition_list) of the conditions requested in the config.
        self.get_condition = [self._condition_list.index(cond) for cond in config['get_conditions']]
        self.used_conditions = config['get_conditions']
        # One-hot encoding of the three severity grades.
        self._status_map = {'Normal/Mild': [1., 0., 0.],
                            'Moderate': [0., 1., 0.],
                            'Severe': [0., 0., 1.]}
        
        self.level_ind = {'L1/L2': 0, 'L2/L3': 1, 'L3/L4': 2, 'L4/L5': 3, 'L5/S1':4}

        # Shape (1, 5) tensor holding the level indices 0..4.
        self.box_labels = torch.tensor([list(self.level_ind.values())]) 

        self.limit = config['limit_series_len']
        # Default series length; replaced below when the limit is on and 'series_len' is set.
        self.series_length = 15
        
        # Falls back to [30, 30] when config['x_overhead'] is falsy (the key itself must exist).
        self.x_overhead = config['x_overhead'] if config['x_overhead'] else [30, 30]
        # NOTE(review): fixed to 0 here; config['z_overhead'] is never read - confirm this is intended.
        self.z_overhead = 0
        self.overlap_levels = config['overlap_levels']
        # The y overlap only applies when level overlapping is switched on.
        self.y_overlap = config['y_overlap'] if self.overlap_levels else 0

        self.cond_x_overhead = config['cond_x_overhead']
        self.cond_y_overhead = config['cond_y_overhead']
        self.cond_z_overhead = config['cond_z_overhead']

        if self.limit and not config['series_len'] and not self.supress_warnings:
            print("Series length is not specified. The limit is set to 15.")
        elif self.limit and config['series_len']:
            self.series_length = config['series_len']
            
        # One (initially empty) index list per loaded series type.
        self.series_type_ind = {}
        for s in self.used_series_types:
            self.series_type_ind[s] = []

        # Filled by prepare_data() with one dict per study.
        self.data = []
        self.prepare_data()

    def get_level_boxes(self, series_info:pd.DataFrame, stype='sagittal'):
        #default for each scan has 5 visible levels
        bboxes = []
        labels = []
        cond_boxes = []
        cond_labels = []
        cond_state = []
        z_overlap = self.z_overhead
        if stype in ['sagittal', 'sagittal_t2']:
            for _, row in series_info.iterrows():
                labels.append(self.level_ind[row.level])
                z_min = min(row.present_instances.index(min(row.instance_number)), row.present_instances.index(max(row.instance_number)))
                z_max = max(row.present_instances.index(min(row.instance_number)), row.present_instances.index(max(row.instance_number)))
                bboxes.append(Bbox3d(
                    x=[max(0, min(row.x)-self.x_overhead[0]/row.pixel_spacing[0]), min(row.image_width, max(row.x)+self.x_overhead[1]/row.pixel_spacing[0])],
                    y=[max(0, row.level_boundaries[0]-self.y_overlap/row.pixel_spacing[1]), min(row.image_height, row.level_boundaries[1]+self.y_overlap/row.pixel_spacing[1])],
                    z=[max(0, z_min-z_overlap), min(len(row.present_instances), z_max+z_overlap)]
                                ))
                c, cl, cs = self.get_condition_boxes(row, stype=stype)
                cond_boxes.append(c)
                cond_labels.append(cl)
                cond_state.append(cs)

        elif stype=='axial':
            if z_overlap == 0:
                z_overlap = 0.5 
            for _, row in series_info.iterrows():
                labels.append(self.level_ind[row.level])
                bboxes.append(Bbox3d(
                    x=[max(0, min(row.y)-self.x_overhead[0]/row.pixel_spacing[1]), min(row.image_width, max(row.y)+self.x_overhead[1]/row.pixel_spacing[1])],
                    y=[max(0, row.present_instances.index(min(row.level_slices))-z_overlap), min(len(row.present_instances)-z_overlap, row.present_instances.index(max(row.level_slices))+z_overlap)],
                    z=[max(0, min(row.x)-self.y_overlap/row.pixel_spacing[0]), min(row.image_height, max(row.x)+self.y_overlap/row.pixel_spacing[0])]
                                ))
                c, cl, cs = self.get_condition_boxes(row, stype=stype)
                cond_boxes.append(c)
                cond_labels.append(cl)
                cond_state.append(cs)

        return bboxes, np.array(labels), cond_boxes, cond_labels, cond_state
    
    def get_condition_boxes(self, row:pd.DataFrame, stype='sagittal'):
        #default for each scan has 5 visible levels
        bboxes = []
        labels = []
        cond_state = []
        z_overlap = 0
        if stype in ['sagittal', 'sagittal_t2']:
            for condition, x,y,z in zip(row.condition, row.x, row.y, row.instance_number):
                cond_z_overhead = self.cond_z_overhead
                if 'Left' in condition:
                    if row['reversed']:
                        cond_z_overhead = [self.cond_z_overhead[1], self.cond_z_overhead[0]]
                    else:
                        cond_z_overhead = [self.cond_z_overhead[0], self.cond_z_overhead[1]]
                if 'Right' in condition:
                    if not row['reversed']:
                        cond_z_overhead = [self.cond_z_overhead[1], self.cond_z_overhead[0]]
                    else:
                        cond_z_overhead = [self.cond_z_overhead[0], self.cond_z_overhead[1]]

                labels.append(self._condition_list.index(condition))
                bboxes.append(Bbox3d(
                    x=[max(0, x-self.cond_x_overhead[0]/row.pixel_spacing[0]), min(row.image_width, x+self.cond_x_overhead[1]/row.pixel_spacing[0])],
                    y=[max(0, y-self.cond_y_overhead[0]/row.pixel_spacing[1]), min(row.image_height, y+self.cond_y_overhead[1]/row.pixel_spacing[1])],
                    z=[max(0, row.present_instances.index(z)-cond_z_overhead[0]), min(len(row.present_instances), row.present_instances.index(z)+cond_z_overhead[1])]
                                ))
                cond_state.append(row.status[row.all_conditions.index(condition)])
        elif stype=='axial':
            z_overlap = 0.5
            for condition, x,y,z in zip(row.condition, row.x, row.y, row.instance_number):
                labels.append(self._condition_list.index(condition))
                bboxes.append(Bbox3d(
                    x=[max(0, y-self.cond_x_overhead[0]/row.pixel_spacing[1]), min(row.image_width, y+self.cond_x_overhead[1]/row.pixel_spacing[1])],
                    y=[max(0, row.present_instances.index(min(row.level_slices))-z_overlap), min(len(row.present_instances)-z_overlap, row.present_instances.index(max(row.level_slices))+z_overlap)],
                    z=[max(0, x-self.cond_y_overhead[0]/row.pixel_spacing[0]), min(row.image_height, x+self.cond_y_overhead[1]/row.pixel_spacing[0])]
                                ))
                cond_state.append(row.status[row.all_conditions.index(condition)])

        return bboxes, np.array(labels), cond_state

    def set_view(self, new_view):
        # Overwrite self.current_view (initialised in __init__ from config['series_out_types'][0]).
        self.current_view = new_view
        
    def get_condition_labels(self, series_info:pd.DataFrame):
        labels = []
        cond_presence_masks = []
        level_presence_mask = []

        for level, _ in self.level_ind.items():
            if not series_info[series_info['level']==level].empty:
                labels.append(series_info[series_info['level']==level].iloc[0].status)
                cond_presence_masks.append(series_info[series_info['level']==level].iloc[0].presence_mask)
                level_presence_mask.append(True)
            else:
                level_presence_mask.append(False)

        return np.array(labels), np.array(cond_presence_masks), np.array(level_presence_mask)

    def info2dict(self, series_info, stype=None): #remember axials can be combination of different serieses (sagittals can't)
        level0 = series_info.iloc[0]
        data_dict = {}
        boxes, box_labels, cond_boxes, cond_labels, cond_state = self.get_level_boxes(series_info, stype=stype)
        labels, label_level_mask, label_cond_mask = self.get_condition_labels(series_info)

        data_dict['study_id'] = level0.study_id
        data_dict['series_id'] = level0.series_id
        data_dict['width'] = level0.image_width
        data_dict['height'] = level0.image_height
        data_dict['reversed'] = level0.reversed
        data_dict['series_type'] = stype
        data_dict['pixel_spacing'] = level0.pixel_spacing

        data_dict['boxes'] = boxes 
        data_dict['files'] = [f"{self.dataset_path}/{data_dict['study_id']}/{data_dict['series_id']}/{instance}.{self.image_type}" for instance in level0.present_instances]
        data_dict['box_labels'] = box_labels
        data_dict['labels'] = labels
        data_dict['label_presence_mask'] = label_level_mask
        data_dict['cond_boxes'] = cond_boxes
        data_dict['cond_labels'] = cond_labels
        data_dict['label_cond_mask'] = label_cond_mask
        data_dict['cond_state'] = cond_state
        
        return data_dict
   
    def prepare_data(self):
        """Index every series of every study and append the result to ``self.data``.

        For each study id a dict with the keys 'sagittal', 'sagittal_t2' and
        'axial' is built; each value is the list of ``info2dict`` results for
        that study's series of the given type (only types listed in
        ``self.used_series_types`` are filled). Studies without any series
        are not appended.
        """
        # prepare paths for every image to load
        # Iterating through tqdm advances the bar once per study, including studies
        # that end up skipped, so the bar always reaches its total.
        for study_id in tqdm(self.study_ids, desc="Preparing data: "):
            study_dict = dict(
                            sagittal=[],
                            sagittal_t2=[],
                            axial=[])
            for stype in self.used_series_types:
                type_info = self.data_info[stype]
                for series_id in type_info.query(f'study_id == {study_id}').series_id.unique():
                    ddict = self.info2dict(type_info[type_info.series_id==series_id], stype=stype)
                    if self.preload:
                        # NOTE(review): preload_series is not defined in the visible part of
                        # this notebook - confirm it exists before enabling 'preload'.
                        ddict = self.preload_series(ddict)
                    study_dict[stype].append(ddict)

            # Keep the study only when at least one series was indexed for it.
            if any(study_dict.values()):
                self.data.append(study_dict)

    def split_to_series(self):
        temp = []
        batches = []
        i=0
        for data in self.data:
            batch = []
            for series in data.values():
                if series:
                    temp+=series
                    batch.append(i)
                    i+=1
            batches.append(batch)

        self.data = temp
        self.batches = batches
    
    def change_img_view(self, img, current_view, new_view):
        if current_view in ['sagittal', 'sagittal_t2']:
            if new_view == 'sagittal':
                return img
            elif new_view=='coronal':
                return img.transpose(2, 1, 0) # n, h, w -> w, h, n
            elif new_view=='axial':
                return img.transpose(1, 2, 0 ) #n, h, w -> h, n, w
        elif current_view=='axial':
            if new_view =='sagittal':
                return img.transpose(2, 0, 1) # n, h, w -> w, n, h
            elif new_view =='coronal':
                return img.transpose(1, 0, 2) # n, h, w -> h, n, w
            elif new_view=='axial':
                return img
            
    def load_series(self, data) -> np.ndarray:
        boxes = np.array([box.get_box_in_view_type(data['series_type'], d3 = True) for box in data['boxes']], dtype=int)
        level_labels = data['box_labels']
        oimg = np.zeros((len(data['files']), data['height'], data['width']), dtype = np.uint8)
        for i, path in enumerate(data['files']):  
            try:      
                oimg[i,:,:] = np.array(Image.open(path), dtype = np.uint8)
            except ValueError:
                temp = np.array(Image.open(path), dtype = np.uint8) 
                oimg[i,:temp.shape[0],:temp.shape[1]] = temp
                del temp

        # split to boxes 
        imgs = np.zeros((5,len(self.used_conditions),self.series_length, self.im_size[1], self.im_size[0]))
        labels = np.zeros((5,len(self.used_conditions),3))
        masks = np.zeros((5,len(self.used_conditions),1))

        for i, (box, label, cond_box, cond_labels, cond_state) in enumerate(zip(boxes, level_labels, data['cond_boxes'], data['cond_labels'], data['cond_state'])):
            cb = np.array([cb.get_box_in_view_type(data['series_type'], d3 = True) for cb in cond_box], dtype=int)
            for c, cl, cs in zip(cb, cond_labels, cond_state):
                if self._condition_list[cl] in self.used_conditions:
                    labels[i, self.used_conditions.index(self._condition_list[cl]),:] = cs
                    masks[i, self.used_conditions.index(self._condition_list[cl]),0] = data['label_presence_mask'][i,cl]
                    imgs[i, self.used_conditions.index(self._condition_list[cl]),:,:,:] =  self.__itensity_normalize_one_volume__(self.__resize_data__(oimg[
                                        max(0, c[4]+np.random.randint(-2,1)): min(oimg.shape[0], c[5]+np.random.randint(-1,2)),

                                        int(max(0, c[1]+np.random.randint(-10,2)*data['pixel_spacing'][0])): 
                                        int(min(oimg.shape[1], c[3]+np.random.randint(-2, 10)*data['pixel_spacing'][0])), 

                                        int(max(0, c[0]+np.random.randint(-10,2)*data['pixel_spacing'][1])): 
                                        int(min(oimg.shape[2], c[2]+np.random.randint(-2,10)*data['pixel_spacing'][1]))
                                        ]))
                    
                    #imgs[i, self.used_conditions.index(self._condition_list[cl]),:,:,:]=  self.__itensity_normalize_one_volume__(self.__resize_data__(oimg[c[4]: c[5], c[1]:c[3], c[0]:c[2]]))
                    
        return imgs, labels, masks
    

    def __itensity_normalize_one_volume__(self, volume):
        """
        normalize the itensity of an nd volume based on the mean and std of nonzeor region
        inputs:
            volume: the input nd volume
        outputs:
            out: the normalized nd volume
        """
        #out = (volume - 0.5*255.)/(0.5 *255.)

        pixels = volume[volume > 0]
        if pixels.size==0:
            return volume
        
        mean = pixels.mean()
        std  = pixels.std()
        out = (volume - mean)/std
        #out_random = np.random.normal(0, 1, size = volume.shape)
        #out[volume == 0] = out_random[volume == 0]

        return out

    def __resize_data__(self, data, bb=None):
        """
        Resize the data to the input size
        """ 
        
        [depth, height, width] = data.shape
        scale = [self.series_length*1.0/depth, self.im_size[1]*1.0/height, self.im_size[0]*1.0/width]
        #scale_d = [self.series_length*1.0/depth, 1, 1]
        if bb:
            bb[..., [0,2]] = bb[..., [0,2]]/width*self.im_size[1]
            bb[..., [1,3]] = bb[..., [1,3]]/height*self.im_size[0]
            bb[..., [4,5]] = bb[..., [4,5]]/depth*self.series_length

        data = ndimage.zoom(data, scale, order=1)
        #data = self.limit_series(data)
        
        if bb:
            return data, bb
        return data
    
    def limit_series(self, img, z_bb=None):
        if self.series_length < img.shape[0]:
                if self.d3 and z_bb:
                    z_bb = z_bb* self.series_length/img.shape[0]
                st = np.round(np.linspace(0, img.shape[0] - 1, self.series_length)).astype(int)
                img = img[st,:,:]
        elif self.series_length > img.shape[0]:
                img = np.pad(img, ((0, self.series_length-img.shape[0]), (0, 0), (0, 0)))
        if z_bb:
            return img, z_bb
        return img

    
    def prepare_series(self, img):

        img = torch.tensor(img,dtype=torch.float)
        if self.transforms:
            # transform in height axis
            img = self.transforms(img)
            # transform in width axis
            if torch.rand(1)<0.5 and self.dtransforms:
                img = self.dtransforms(img.permute(0, 1, 3, 2, 4))
                img = img.permute(0, 1, 3, 2, 4)
        
        return img
    
    def __getitem__(self, index: int):
        # Placeholder: the base class yields no sample for any index.
        # BoxDatasetUnited (defined later in this notebook) overrides it with real loading.
        return None
        
    def __len__(self):
        return(len(self.data))

In [12]:
class BoxDatasetUnited(Dataset):
    def __init__(self, data_info:Dict[str, pd.DataFrame], config:Dict):
        """Run the base ``Dataset`` setup, then read 'mix_strategy' and 'one_label'.

        With ``config['one_label']`` set, ``box_labels`` becomes five zeros
        (float); otherwise it is the (1, 5) tensor of level indices.
        """
        super().__init__(data_info, config)
        self.mix_strategy = config['mix_strategy']

        if config['one_label']:
            level_targets = torch.tensor([0,0,0,0,0], dtype=torch.float)
        else:
            level_targets = torch.tensor([list(self.level_ind.values())])
        self.box_labels = level_targets

        #split data to individual serieses from dict of study-series_type pairs
    
    def drop_series(self, volume,  masks, mask_dropped:bool=True):
        # drop one of the serieses and possibly mask conditions primary estimated with it (if mask_dropped)
        to_drop = random.choice(list(range(0, len(self.used_series_types))))
        volume[to_drop] = torch.zeros((5, self.series_length, self.im_size[1], self.im_size[0]))
        if mask_dropped:
            if to_drop == 0:
                masks[:,[0,2]] = 0
            elif to_drop==1:
                masks[:,4] = 0
            elif to_drop==2:
                masks[:,[1,3]] = 0
                
        return volume, masks

    def mirror_sides(self, volume, labels, masks):
        # mirror left to right or right to left
        if torch.rand(1)>0.5:
            #right to left
            volume[:, :, :int(volume.size(2)/2), ...] = volume[:, :, int(volume.size(2)/2):, ...].flip(2)
            labels[:, [2,3]] = labels[:, [0,1]]
            masks[:, [2,3]] = masks[:, [0,1]]
        else:
            #left to right
            volume[:, :, int(volume.size(2)/2):, ...] = volume[:, :, :int(volume.size(2)/2), ...].flip(2)
            labels[:, [0,1]] = labels[:, [2,3]]
            masks[:, [0,1]] = masks[:, [2,3]]

        return volume, labels, masks
        
    def flip_sides(self, volume, labels, masks):
        # flip left/right sides and their labels
        volume = volume.flip(2)
        labels = labels[:, [2,3,0,1,4]]
        masks = masks[:, [2,3,0,1,4]]
        return volume, labels, masks
    

    def __getitem__(self, index: int)->tuple[np.ndarray, np.ndarray]:

        adata = self.data[index]
        
        # limit series (if limit==True)
        all_type_img = []
        all_type_mask = []
        all_type_labels = []
        type_to_ind = {}
        ind = 0
        for st in self.used_series_types:
            type_to_ind[st] = ind
            ind+=1
            
        all_type_img = torch.zeros((len(self.used_series_types), 5, len(self.used_conditions), self.series_length, self.im_size[1], self.im_size[0]), dtype=torch.float)
        for key, data_list in adata.items():
            if data_list:
                for data in data_list: #schuffle if you want random in some cases
                    if not self.preload:
                        try:
                            oimg, labels, masks = self.load_series(data)
                            oimg = self.prepare_series(oimg)
                        except ZeroDivisionError or UnboundLocalError: # one series was skipped in data preparation and now it raises exception ;_;
                            print('Series ', data['series_id'], ' dropped.')
                            return all_type_img.permute(1,2,0,3,4,5), torch.zeros((5,2)).to(dtype=torch.int64), torch.zeros((5,2,1)).to(dtype=torch.int64)
                    all_type_img[type_to_ind[key]] = oimg
        # try:
        #     oimg = self.prepare_series(oimg)
        # except UnboundLocalError:
        #     return all_type_img, torch.zeros((5,2)).to(dtype=torch.int64), torch.zeros((5,2,1)).to(dtype=torch.int64)
        
        labels, masks = self.prepare_labels(labels, masks)
        
        return all_type_img.permute(1,2,0,3,4,5), labels, masks
    
    def prepare_labels(self, labels, masks):
        labels = torch.tensor(labels)
        masks = torch.tensor(masks)
        return labels.to(dtype=torch.int64).argmax(-1), masks.to(dtype=torch.int64)

    def __len__(self):
        return(len(self.data))
    
    def get_random_by_stype(self):
        return self[random.randint(0, len(self)-1)]

In [None]:
# Edge length used for both image dimensions (train_dataset_config passes [im_size, im_size]).
im_size = 64

# Training augmentation pipeline, applied in this order:
#   1. one randomly chosen flip / small-angle affine / fixed +-90 degree rotation,
#   2. one randomly chosen perspective warp / resized crop / translate+shear,
#   3. random scaling,
#   4. one randomly chosen elastic deformation / Gaussian blur.
train_transforms = v2.Compose([
    v2.RandomChoice([v2.RandomVerticalFlip(p = 0.5), v2.RandomHorizontalFlip(p = 0.5), v2.RandomAffine(degrees=5), 
                     v2.RandomRotation(degrees=(90,90)), v2.RandomRotation(degrees=(-90,-90))]),
    v2.RandomChoice([v2.RandomPerspective(distortion_scale=0.1, p=1.0), v2.RandomResizedCrop(size=(im_size,im_size)), v2.RandomAffine(degrees=0, translate=(0.1,0.1), shear=(-2,2,-2,2))]), # translation + shearing
    v2.RandomAffine(degrees=0, scale=(0.8,1.2)), #scaling
    v2.RandomChoice([v2.ElasticTransform(alpha=40.0), v2.GaussianBlur(kernel_size=(3,7), sigma=(0.1, 0.7))]),
])
# Lighter pipeline handed to the dataset as 'dtransforms'; Dataset.prepare_series applies it
# to the tensor with axes 2 and 3 swapped.
dtransforms = v2.Compose([
    v2.RandomChoice([v2.RandomVerticalFlip(p = 0.5)]),
    v2.RandomAffine(degrees=0, scale=(0.8,1.2)), #scaling
    v2.RandomChoice([v2.GaussianBlur(kernel_size=(3,7), sigma=(0.1, 0.7))]),
])
# Validation only resizes to the model input size.
val_transforms = v2.Compose([
    v2.Resize((im_size,im_size)), #resize
])

# Configuration consumed by Dataset / BoxDatasetUnited.
# 'preload' is listed exactly once: a dict literal keeps only the last value of a repeated key,
# so an earlier `'preload': True` entry never took effect. The effective value (False) is kept.
train_dataset_config ={
    'preload': False, # preload data into memory (WARNING IT MAY TAKE A LOT OF SPACE - DEPENDS ON THE DATASET USED)
    'im_size': [im_size, im_size],
    'dataset_path': "/workspaces/RSNA_LSDC/inputs/dataset",
    'load_series': ['axial', 'sagittal'], # series types to load into dataset ['sagittal', 'axial', 'sagittal_t2']
    'united': True,

    'transforms': train_transforms,
    'dtransforms': dtransforms,
    'vsa': False,
    'one_label': False, # use one label for every level (do not differentiate between levels)
    'series_out_types': ['axial'], # mix output series types to get views necessary to create 3d box ['sagittal', 'coronal']
    'get_conditions':['Left Neural Foraminal Narrowing', 'Right Neural Foraminal Narrowing'], # condition to output from series
    'mix_strategy': 'combined', # strategy for mixing output series types ['random', 'custom', 'combined'] random - randomly select output type,
                              #'manual' - will return data based on currently choosen view, 'combined' will return all views in one call
                              # NOTE(review): the option list says 'custom' but the description says 'manual' - confirm the accepted name
    'return_series_type': False, # If True getitem will also return series orignial type

    'normalize': True,
    'image_type': 'png',

    'dataset_type': 'boxes', #'boxes', 'conditions'
    'supress_warinings': False, # key spelling matches the lookup in Dataset.__init__

    'limit_series_len': True, # if True the series length will be limited to number N specified by 'series_len' parameter
    'series_len': 6, #  maximal number of slices in series

    'x_overhead': [20,20], # overhead for levels in x-dim (in mm)
    'z_overhead':10, # NOTE(review): Dataset.__init__ sets z_overhead to 0 and never reads this key
    'overlap_levels': True, # if true the level upper and lower boundary will overlap with value specified in 'y_overlap'
    'y_overlap': 15, # overlap size of levels boundaries (in mm)
    'cond_x_overhead': [18,18],
    'cond_y_overhead': [8,8],
    'cond_z_overhead': [6,16],

    '3d_box': True # NOTE(review): Dataset.__init__ hard-codes d3 = True and does not read this key
}

# Series descriptions, limited to the first 300 rows; their study ids define the training subset.
tsd = pd.read_csv('/workspaces/RSNA_LSDC/inputs/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv').iloc[:300]

# Pre-computed coordinate tables, one per series type.
# NOTE: pd.read_pickle executes pickle code - only load files from a trusted source.
data_sagittal = pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t1.pkl")
data_sagittal_t2 = pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t2.pkl")
data_axial = pd.read_pickle('/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_axial_unified.pkl')
train_ids = tsd.study_id.unique()

# Keep only the rows of each table that belong to the selected studies.
train_data = {
    series_type: table[table.study_id.isin(train_ids)]
    for series_type, table in (('sagittal', data_sagittal),
                               ('sagittal_t2', data_sagittal_t2),
                               ('axial', data_axial))
}

bb = BoxDatasetUnited(train_data, train_dataset_config)

In [None]:
# Visual sanity check of one dataset sample.
# NOTE(review): color_dict is not used anywhere in this cell.
color_dict = dict([(0,'r'),(1,'g'), (2,'b'), (3,'m'), (4, 'y')]) # 4646740 41477684


# BoxDatasetUnited.__getitem__ returns images with axes
# (level, condition, series type, slice, height, width) plus labels and masks.
inputs, labels, masks = bb[3]


# Level index 3, condition index 0 -> (series type, slice, height, width).
inp = inputs[3, 0]
print(labels[3, 0])
#imgs.append(inputs)

# print ground truth
# One figure per slice: series type 0 on the left, series type 1 on the right
# (order follows the 'load_series' entry of the dataset config).
for i in range(inp.shape[1]):
    fig, ax = plt.subplots(1, 2, figsize = (10, 10))
    ax[0].imshow(inp[0, i,:,:].detach().cpu().numpy())
    ax[1].imshow(inp[1,i,:,:].detach().cpu().numpy())
    plt.show()

In [None]:
# Shape of the label tensor of the sample loaded in the preview cell.
labels.shape

In [None]:
# Class indices of the sample loaded in the preview cell.
labels

# MODEL TRAINER

In [15]:
class SagittalTrainer():
    """Train and validate one of the notebook's 3-class models on ``BoxDatasetUnited`` data.

    The wrapped model must provide ``get_loss(inputs, labels, masks)`` returning
    ``(loss, info_dict)`` and ``predict(inputs)`` returning per-class probabilities;
    both are called by the epoch loops below.

    Args:
        model: model class (not an instance); built as ``model(**model_params)``.
        model_params: keyword arguments for the model constructor.
        config: trainer settings. Keys read here: ``steps_per_plot``, ``checkpoints``,
            ``save_path``, ``step_per_save``, ``model_name``, ``optimizer``,
            ``optimizer_params``, ``scheduler``, ``scheduler_params``, ``batch_size``,
            ``epochs``, ``early_stopping``, ``early_stopping_treshold``,
            ``train_dataset_config`` and ``val_dataset_config``; optional keys are
            ``print_evaluation`` (default False) and ``num_workers`` (default 12).
        train_data: passed unchanged to ``BoxDatasetUnited`` for the training loader.
        eval_data: passed unchanged to ``BoxDatasetUnited`` for the validation loader.
    """

    # Titles used for the confusion matrices when the dataset config names no conditions.
    DEFAULT_CONDITIONS = ('Left Neural Foraminal Narrowing', 'Right Neural Foraminal Narrowing')
    # Class ids the models predict (severity 0, 1, 2 -> sample weights 1, 2, 4).
    CLASS_IDS = (0, 1, 2)

    def __init__(self, model, model_params, config, train_data, eval_data) -> None:

        self.print_evaluation = config.get("print_evaluation", False)
        self.steps_per_plot = config["steps_per_plot"]

        train_ds_config = config['train_dataset_config']
        # The warning is only meaningful for the 'combined' strategy it talks about.
        if train_ds_config['mix_strategy'] == 'combined' and len(train_ds_config['series_out_types']) > 1:
            print("Warning! With mix strategy set to 'combined' and multiple series out types the batch size will be (len(series_out_types)) times bigger.")

        self.checkpoints = config['checkpoints']
        self.save_path = config['save_path']
        self.step_per_save = config['step_per_save']

        self.config = config

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Device set to {self.device}")

        self.model_name = config['model_name']
        self.model = model(**model_params).to(self.device)

        self.optimizer = config["optimizer"](self.model.parameters(), **config["optimizer_params"])

        self.series_len = config['val_dataset_config']['series_len']

        # Worker count is configurable; prefetch_factor is only valid with worker processes.
        num_workers = config.get("num_workers", 12)
        loader_kwargs = dict(batch_size=config["batch_size"], num_workers=num_workers)
        if num_workers > 0:
            loader_kwargs["prefetch_factor"] = 1
        self.dataloaders = {
            'train': torch.utils.data.DataLoader(
                BoxDatasetUnited(train_data, train_ds_config), shuffle=True, **loader_kwargs),
            'val': torch.utils.data.DataLoader(
                BoxDatasetUnited(eval_data, config['val_dataset_config']), shuffle=False, **loader_kwargs),
        }

        self.max_epochs = config["epochs"]
        self.early_stopping = config['early_stopping']
        self.early_stopping_tresh = config['early_stopping_treshold']

        # scheduler
        self.scheduler = None
        self.one_cycle_sched = False
        if config["scheduler"]:
            # Work on a copy so the caller's config (reused across folds) is not modified.
            scheduler_params = dict(config['scheduler_params'])
            if 'epochs' in scheduler_params:
                scheduler_params['epochs'] = self.max_epochs
            if 'steps_per_epoch' in scheduler_params:
                scheduler_params['steps_per_epoch'] = len(self.dataloaders['train'])
            self.scheduler = config["scheduler"](self.optimizer, **scheduler_params)
            # OneCycleLR is stepped per batch, every other scheduler once per epoch.
            self.one_cycle_sched = self.scheduler.__class__.__name__ == 'OneCycleLR'

        ## Evaluation metrics
        self.series_in_types = train_ds_config['load_series']
        self.out_stypes = train_ds_config['series_out_types']
        self.batch_size = config["batch_size"] * len(self.out_stypes)

        self.mAP = MeanAveragePrecision(box_format = 'xyxy', iou_type='bbox', extended_summary=True).to(self.device)
        self.all_maps = []
        self.mAP_split = []
        # A checkpoint is written only once the weighted validation log loss drops below 1.
        self.best_ll = 1
        self.metrics_to_print = ['map', 'map_50', 'map_75', 'map_per_class']

    def save_model(self):
        """Write the model state dict to ``<save_path>/<model_name>_best.pt``."""
        os.makedirs(self.save_path, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(self.save_path, f"{self.model_name}_best.pt"))

    def load_model(self, model_path):
        """Load a state dict from ``model_path``, remapped onto the trainer's device."""
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))

    def get_summary(self):
        """Return a small dict describing the run (model name and best validation log loss).

        ``kfoldCV`` in this notebook collects one such summary per fold.
        """
        return {'model_name': self.model_name, 'best_ll': float(self.best_ll)}

    def train(self):
        """Run ``max_epochs`` epochs, each a training pass followed by a validation pass."""
        for epoch in range(self.max_epochs):
            print(f"Epoch {epoch+1}/{self.max_epochs}")
            self.train_one_epoch()
            self.eval_one_epoch()

    def _prepare_batch(self, inputs, labels, masks):
        """Move a batch to the device, flatten it to per-sample rows and drop empty samples.

        Returns ``(inputs, labels, masks)`` restricted to the kept samples, or ``None``
        when no sample is kept.
        """
        inputs = inputs.to(self.device).reshape(-1, 1, self.series_len, inputs.shape[-2], inputs.shape[-1])
        labels = labels.to(self.device).reshape(-1, 1)
        masks = masks.to(self.device).reshape(-1, 1)
        # Same rule as before: a sample is kept when its flattened argmax is not index 0
        # (an all-zero sample has argmax 0).
        # NOTE(review): a non-empty sample whose maximum sits at flat index 0 is dropped too.
        keep = inputs.flatten(1).argmax(dim=1) > 0
        if not bool(keep.any()):
            return None
        return inputs[keep], labels[keep], masks[keep]

    def train_one_epoch(self):
        """One optimisation pass over the training loader."""
        self.model.train()  # Set model to training mode
        metrics = defaultdict(list)
        loader = self.dataloaders['train']

        with tqdm(total=len(loader), unit="batch") as tepoch:
            for inputs, labels, masks in loader:
                batch = self._prepare_batch(inputs, labels, masks)
                if batch is None:
                    tepoch.update(1)
                    continue
                with torch.set_grad_enabled(True):
                    loss, loss_info = self.model.get_loss(*batch)
                    for loss_t, loss_v in loss_info.items():
                        metrics[loss_t].append(loss_v.detach().cpu().numpy())
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    if self.scheduler and self.one_cycle_sched:
                        self.scheduler.step()

                #update tqdm data
                tepoch.set_description(self.metrics_description(metrics, 'train'))
                tepoch.update(1)

        if self.scheduler and not self.one_cycle_sched:
            self.scheduler.step()

    def eval_one_epoch(self):
        """Validate the model: plot per-condition confusion matrices and score with weighted log loss.

        The model is saved whenever the score improves on ``self.best_ll``.
        """
        self.model.eval()
        alabels = []
        apreds = []
        aweight = []
        metrics_d = defaultdict(list)
        loader = self.dataloaders['val']
        n_in_types = len(self.series_in_types)

        with tqdm(total=len(loader), unit="batch") as tepoch:
            for inputs, labels, masks in loader:
                batch = self._prepare_batch(inputs, labels, masks)
                if batch is None:
                    tepoch.update(1)
                    continue
                inputs, labels, masks = batch

                with torch.no_grad():
                    # Regroup rows so dim 1 holds one entry per loaded series type.
                    model_inputs = inputs.reshape(-1, n_in_types, self.series_len, inputs.shape[-2], inputs.shape[-1])
                    preds = self.model.predict(model_inputs)
                    _, loss_info = self.model.get_loss(model_inputs, labels, masks)
                    for loss_t, loss_v in loss_info.items():
                        metrics_d[loss_t].append(loss_v.detach().cpu().numpy())
                    apreds.append(preds.reshape(-1, 3))

                labels = labels.reshape(-1)
                weights = 2 ** labels  # sample weights 1, 2, 4 for classes 0, 1, 2
                weights[torch.logical_not(masks.reshape(-1))] = 0.  # masked-out entries do not count
                alabels.append(labels)
                aweight.append(weights)

                #update tqdm data
                tepoch.set_description(self.metrics_description(metrics_d, 'val'))
                tepoch.update(1)

        if not apreds:
            print("Validation produced no valid samples; skipping metrics.")
            return

        alabels = torch.cat(alabels, dim=0).cpu().numpy()
        apreds = torch.cat(apreds, dim=0).cpu().numpy()
        aweight = torch.cat(aweight, dim=0).cpu().numpy()

        # print a confusion matrix for every condition
        # assumes the flattened labels interleave the conditions, one entry per condition
        # per sample (as the strided slicing below already did) -- TODO confirm against dataset
        conditions = list(self.config['train_dataset_config'].get('get_conditions') or self.DEFAULT_CONDITIONS)
        class_ids = list(self.CLASS_IDS)

        fig, ax = plt.subplots(nrows=1, ncols=len(conditions), figsize=(15,5))
        ax = np.atleast_1d(ax).ravel()
        for i, condition in enumerate(conditions):
            cl = alabels[i::len(conditions)]
            cpred = apreds[i::len(conditions),:].argmax(-1)
            # Fixed label set keeps the matrix 3x3 even when a class is absent from this fold.
            cm = confusion_matrix(cl, cpred, labels=class_ids)
            ax[i].set_title(condition)
            ConfusionMatrixDisplay(
                confusion_matrix=cm).plot(ax=ax[i], colorbar=False)
        plt.show()
        plt.close(fig)

        if aweight.sum() <= 0:
            print("All validation samples are masked out; skipping log loss.")
            return

        # `labels` is required by sklearn when not every class occurs in `alabels`.
        ll = log_loss(alabels, apreds, normalize=True, sample_weight=aweight, labels=class_ids)
        print("Score:", ll)
        if ll < self.best_ll:
            self.save_model()
            self.best_ll = ll

    def metrics_description(self, metrics:dict, phase:str)->str:
        """Format the running mean of every tracked metric for the progress bar."""
        outputs = phase + ": ||"
        for k in metrics.keys():
            outputs += (" {}: {:4f} ||".format(k, np.mean(metrics[k])))
        return outputs


In [16]:
class MlpHead(nn.Module):
    """Two-layer perceptron classification head.

    Data flow: ``Linear(dim -> mlp_ratio * dim)`` -> activation -> normalisation
    -> dropout -> ``Linear(-> num_classes)``.
    """
    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=nn.ReLU,
        norm_layer=nn.LayerNorm, head_dropout=0., bias=True):
        super().__init__()
        width = int(mlp_ratio * dim)
        # Sub-modules are created in the order they are applied in forward().
        self.fc1 = nn.Linear(dim, width, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(width)
        self.fc2 = nn.Linear(width, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        hidden = self.norm(self.act(self.fc1(x)))
        return self.fc2(self.head_dropout(hidden))

In [17]:
class ClassModelTimm2d(nn.Module):
    """3-class classifier that feeds a short image series to a 2D timm backbone as channels.

    Args:
        backbone_name: timm model name passed to ``timm.create_model``.
        series_dim: sequence whose first entry is the number of slices per series;
            it becomes the backbone's ``in_chans``.
        pretrained: forwarded to ``timm.create_model``.
    """
    def __init__(self, backbone_name, series_dim, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
                                    backbone_name,
                                    pretrained=pretrained, 
                                    features_only=False,
                                    in_chans=series_dim[0]*1,
                                    num_classes=3,
                                    global_pool='avg'
                                    )
    
    def forward(self, x):
        """Return logits of shape ``(b, 3, n)`` for a 5-D input ``(b, s, d, h, w)``.

        Dims 1 and 2 are swapped and merged into the channel axis before the backbone.
        """
        b, s, d, h, w = x.shape
        x = x.permute(0,2,1,3,4).reshape(b, s*d, h, w)
        y = self.model(x)
        return y.reshape(b, 3, -1)

    def get_loss(self, input, labels, masks):
        """Class-weighted, masked cross-entropy summed over the batch.

        Classes 0/1/2 are weighted 1/2/4 and entries whose mask is zero contribute
        nothing. Returns ``(loss, {'cross_entropy': detached loss})``.
        """
        preds = self.forward(input)
        # new_tensor puts the class weights on the dtype and device of the logits.
        class_weights = preds.new_tensor([1., 2., 4.])
        loss = F.cross_entropy(preds, labels, reduction='none', label_smoothing=0.01, weight=class_weights) * masks
        # mean * batch size, as before: the per-batch sum of the masked losses
        loss = loss.mean() * preds.shape[0]

        return loss, dict(cross_entropy=loss.clone().detach())
    
    def predict(self, input):
        """Return class probabilities of shape ``(b, n, 3)`` (softmax over the class axis)."""
        return self.forward(input).permute(0,2,1).softmax(-1)

In [18]:
class ClassModelTimm2dLSTM(nn.Module):
    """Per-series 2D timm feature extractor followed by a bidirectional LSTM and an MLP head.

    The head is sized for 10 sequence steps (``512 * 2 directions * 10``) and emits
    10 predictions of 3 classes each; both numbers are fixed in the code below.

    Args:
        backbone_name: timm model name passed to ``timm.create_model``.
        series_dim: ``[in_chans, height, width]`` of one series as fed to the backbone.
        pretrained: forwarded to ``timm.create_model``.
    """
    def __init__(self, backbone_name, series_dim, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
                                    backbone_name,
                                    pretrained=pretrained, 
                                    features_only=True,
                                    in_chans=series_dim[0],
                                    num_classes=3,
                                    global_pool='avg'
                                    )
        self.all_channels = self.model.feature_info.channels()
        self.reduction = self.model.feature_info.reduction()
        print(self.all_channels, self.reduction)
        # Flattened size of the last feature map.
        # NOTE(review): the int() truncation presumably only matches the real map size
        # when height and width are multiples of the last reduction -- confirm.
        dim = self.all_channels[-1]*int((series_dim[1]/self.reduction[-1])*(series_dim[2]/self.reduction[-1]))
        self.lstm = nn.LSTM(dim, 512, 2, batch_first=True, bidirectional=True)
        self.head = MlpHead(512*20, 10*3, 1)
        
    def forward(self, x):
        """Return logits of shape ``(b, 10, 3)`` for a 5-D input ``(b, s, d, h, w)``."""
        b, s, d, h, w = x.shape
        x = x.reshape(b*s, d, h, w)
        y = self.model(x)[-1]  # last (deepest) feature map of the backbone
        x = y.reshape(b, s, -1)
        x, _ = self.lstm(x)
        x = self.head(x.reshape(b,-1)).reshape(b,10,3)
        return x

    def get_loss(self, input, labels, masks):
        """Masked cross-entropy with sample weights ``2 ** label``, summed over all predictions.

        Returns ``(loss, {'cross_entropy': detached loss})``.
        """
        preds = self.forward(input).reshape(-1,3).unsqueeze(-1)
        labels = labels.reshape(-1,1)
        masks = masks.reshape(-1,1)
        w = 2 ** labels # sample_weight w = (1, 2, 4) for y = 0, 1, 2
        w[torch.logical_not(masks)] = 0. # unannotated conditions get weight 0
        loss = F.cross_entropy(preds, labels, reduction='none', label_smoothing=0.0) * w
        # mean * number of predictions, as before: the sum of the weighted losses
        loss = loss.mean() * preds.shape[0]

        return loss, dict(cross_entropy=loss.clone().detach())
    
    def predict(self, input):
        """Return class probabilities of shape ``(b, 10, 3)``.

        ``forward`` already puts the 3 classes on the last axis, so softmax is applied
        there directly. The previous ``permute(0, 2, 1)`` moved the 10 outputs to the
        last axis and normalised across outputs instead of across classes.
        """
        return self.forward(input).softmax(-1)

# Training model

In [None]:
im_size = 80

# Training-time augmentation pipeline, applied in the order listed.
train_transforms = v2.Compose([
    # one geometric operation picked at random
    v2.RandomChoice([
        v2.RandomVerticalFlip(p = 0.5),
        v2.RandomHorizontalFlip(p = 0.5),
        v2.RandomAffine(degrees=5),
        v2.RandomRotation(degrees=(90,90)),
        v2.RandomRotation(degrees=(-90,-90)),
    ]),
    # perspective distortion, or translation + shearing
    v2.RandomChoice([
        v2.RandomPerspective(distortion_scale=0.2, p=.5),
        v2.RandomAffine(degrees=0, translate=(0.1,0.1), shear=(-5,5,-5,5)),
    ]),
    # scaling
    v2.RandomAffine(degrees=0, scale=(0.8,1.2)),
    # blur
    v2.RandomChoice([
        v2.GaussianBlur(kernel_size=(3,7), sigma=(0.1, 0.7)),
    ]),
])

# Lighter pipeline: vertical flip and blur only.
dtransforms = v2.Compose([
    v2.RandomChoice([
        v2.RandomVerticalFlip(p = 0.5),
    ]),
    #v2.RandomAffine(degrees=0, scale=(0.8,1.2)), #scaling
    v2.RandomChoice([
        v2.GaussianBlur(kernel_size=(3,7), sigma=(0.1, 0.7)),
    ]),
])

# Validation-time pipeline: resize to the square training size, no augmentation.
val_transforms = v2.Compose([
    v2.Resize((im_size,im_size)),
])

# Settings handed to BoxDatasetUnited for the training split.
train_dataset_config ={
    # 'preload' was listed twice in this literal (True here, False further down); Python keeps
    # the last value, so the effective setting has always been False. It is now stated once.
    'preload': False, # preload data into memory (WARNING IT MAY TAKE A LOT OF SPACE - DEPENDS ON THE DATASET USED)
    'im_size': [im_size, im_size],
    'dataset_path': "/workspaces/RSNA_LSDC/inputs/dataset",
    'load_series': ['axial'], # series types to load into dataset ['sagittal', 'axial', 'sagittal_t2']
    'united': True,

    'transforms': train_transforms,
    'dtransforms': None,
    'vsa': False,
    'one_label': False, # use one label for every level (do not differentiate between levels)
    'series_out_types': ['axial'], # mix output series types to get views necessary to create 3d box ['sagittal', 'coronal']
    'get_conditions':[ 'Left Subarticular Stenosis',  'Right Subarticular Stenosis'], #['Left Neural Foraminal Narrowing', 'Right Neural Foraminal Narrowing'], # condition to output from series
    'mix_strategy': 'combined', # strategy for mixing output series types ['random', 'custom', 'combined'] random - randomly select output type, 
                              #'manual' - will return data based on currently chosen view, 'combined' will return all views in one call
    'return_series_type': False, # If True getitem will also return series original type
    
    'normalize': True,
    'image_type': 'png',

    'dataset_type': 'boxes', #'boxes', 'conditions'
    'supress_warinings': False, 

    'limit_series_len': True, # if True the series length will be limited to number N specified by 'series_len' parameter
    'series_len': 3, #  maximal number of slices in series

    'x_overhead': [20,20], # overhead for levels in x-dim (in mm)
    'z_overhead':10,
    'overlap_levels': True, # if true the level upper and lower boundary will overlap with value specified in 'y_overlap'
    'y_overlap': 15, # overlap size of levels boundaries (in mm)
    'cond_x_overhead': [6,9],
    'cond_y_overhead': [12,14],
    'cond_z_overhead': [4,4],

}


# Validation reuses the training dataset settings with augmentation switched off.
val_dataset_config = copy.deepcopy(train_dataset_config)
val_dataset_config['transforms'] = None
val_dataset_config['dtransforms'] = None
val_dataset_config['mix_strategy'] = 'combined'
val_dataset_config['vsa'] = False
# NOTE(review): presumably this makes the validation crops non-deterministic -- confirm it is intended.
val_dataset_config['randomize_borders'] = True
val_dataset_config['return_series_type'] = True

# Alternative crop margins for validation (disabled). These were kept in a bare
# triple-quoted string, which is an expression and was echoed as the cell's output;
# real comments avoid that.
# val_dataset_config['x_overhead'] = [2, 10]
# val_dataset_config['z_overhead'] = 5
# val_dataset_config['overlap_levels'] = True
# val_dataset_config['y_overlap'] = 5
#
# val_dataset_config['x_overhead_axial'] = [2, 10]
# val_dataset_config['z_overhead_axial'] = 5
# val_dataset_config['overlap_levels_axial'] = True
# val_dataset_config['y_overlap_axial'] = 5

In [46]:
# Settings consumed by SagittalTrainer.
trainer_config = {
    # reporting
    "print_evaluation": True,
    "steps_per_plot": 1,

    # checkpointing
    "checkpoints": False,
    "save_path": "/workspaces/RSNA_LSDC/models_3d_final/model_weight",
    "step_per_save":100,
    "model_name": "densenet121_80_80_3_foraminas",

    # data
    "train_dataset_config": train_dataset_config,
    "val_dataset_config": val_dataset_config,

    # schedule
    "epochs": 30, 
    "batch_size": 5,

    # optimiser (alternative tried: torch.optim.Adam)
    "optimizer": torch.optim.AdamW,
    "optimizer_params": {'lr': 3e-4, 'weight_decay': 1e-2},
    # learning-rate scheduler (alternatives tried: ExponentialLR with {'gamma': 0.9},
    # OneCycleLR with {'max_lr': 0.001, 'epochs': None, 'steps_per_epoch': None})
    "scheduler": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    "scheduler_params": {'T_0': 15, 'T_mult': 1, 'eta_min':1e-8},

    "early_stopping": False,
    "early_stopping_treshold": 0.1,
    'vsa':True,
}

# Keyword arguments for the model class: series_dim is [series_len, height, width].
model_config = {
    'backbone_name': 'densenet121', 
    'series_dim': [train_dataset_config['series_len'], *train_dataset_config['im_size']],
    'pretrained': True, 
    }


In [47]:
# Full train_series_descriptions table (the 300-row limit is commented out); kfoldCV reads its study_id column.
tsd = pd.read_csv(f'/workspaces/RSNA_LSDC/inputs/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv')#.iloc[0:300]
def _select_studies(tables, study_ids):
    """Return ``{view: rows of that view's table whose study_id is in study_ids}``."""
    return {view: table[table.study_id.isin(study_ids)] for view, table in tables.items()}


def kfoldCV(k, trainer_config, model_config, seed=None):
    """Train one ``SagittalTrainer`` per fold and collect a summary for each.

    Args:
        k: number of folds. ``k == 1`` does not shuffle; it uses the fixed train/test
            study-id arrays stored as ``.npy`` files next to the models.
        trainer_config: configuration dict passed to ``SagittalTrainer``.
        model_config: keyword arguments for the model constructor.
        seed: optional seed for the study shuffle when ``k > 1``. ``None`` keeps the
            previous behaviour of drawing from NumPy's global random state.

    Returns:
        List with one summary per fold.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    model_summaries = []
    tables = {
        'sagittal': pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t1.pkl"),
        'sagittal_t2': pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t2.pkl"),
        'axial': pd.read_pickle('/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_axial_unified.pkl'),
    }

    if k == 1:
        with open('/workspaces/RSNA_LSDC/models_3d_final/test_unique_studies.npy', 'rb') as f:
            test = np.load(f)
        with open('/workspaces/RSNA_LSDC/models_3d_final/train_unique_studies.npy', 'rb') as f:
            train = np.load(f)
        # fold 0 is the held-out set, everything after it is used for training
        folds = [test, train]
    else:
        studies = np.array(tsd.study_id.unique())
        if seed is None:
            shuffled = np.random.permutation(studies)
        else:
            shuffled = np.random.RandomState(seed).permutation(studies)
        folds = np.array_split(shuffled, k)

    for i in range(k):
        print(f"Fold: {i}")
        train_ids = np.concatenate([fold for j, fold in enumerate(folds) if j != i], axis=0)
        train_data = _select_studies(tables, train_ids)
        val_data = _select_studies(tables, folds[i])

        trainer = SagittalTrainer(ClassModelTimm2d, model_config, trainer_config, train_data, val_data)
        trainer.train()

        # SagittalTrainer as defined in this notebook has no get_summary(); calling it
        # unconditionally raised AttributeError after the whole fold had been trained.
        summary_fn = getattr(trainer, 'get_summary', None)
        if callable(summary_fn):
            model_summaries.append(summary_fn())
        else:
            model_summaries.append({'fold': i, 'best_ll': trainer.best_ll})

    return model_summaries
# convformer_s18

In [None]:
# k=1: single run on the fixed train/test study-id splits that kfoldCV loads from the .npy files.
kfoldCV(1, trainer_config, model_config)