Here we create a model for spinal-level boundary estimation:  
- estimating which instances (slices) of a series belong to each level  
- estimating the level bounding box (once we find the best-performing boxes)  

In [9]:
import os
import cv2
import glob
import copy
import torch
import random
from torch import nn
import timm
import timm_3d
from typing import List, Sequence, Tuple, Union, Dict
from scipy import ndimage

from sklearn.metrics import log_loss
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from MedicalNet.MedicalNet import Struct, MedNet
from PIL import Image
import pandas as pd
from typing import Literal
from collections import defaultdict
import numpy as np
from tqdm.notebook import tqdm

import albumentations as A
from torchvision.tv_tensors import BoundingBoxes as BB
from torchmetrics.detection import MeanAveragePrecision
import torchvision.transforms.v2 as v2
import torch.nn.functional as F
from torchvision.ops import nms

import matplotlib.pyplot as plt
import matplotlib.patches as patches

In [None]:
# Sanity check: does torch see a CUDA device? (the trainer falls back to CPU otherwise)
torch.cuda.is_available()

Post-processing of the found bboxes:
Use the information about bboxes that is guaranteed to hold in every image:
-> 5 classes, each with one bbox, or 1 class with 5 bboxes
-> If levels n-1 and n+1 exist in the image, level n must also exist
-> The boxes of levels n-1, n, n+1 must lie next to one another in the height dimension
-> The boxes must overlap in the x dimension to some extent
-> If boxes exist for a run of levels n, n+1, n+2 and the image height is bigger than the mean level height, then level n+3 must also exist (and likewise in reverse, level n-1)


# DATASET

In [11]:
class Bbox3d:
    """Axis-aligned 3D box stored in the coordinates of a sagittal series.

    ``x``, ``y`` and ``z`` are ``[min, max]`` pairs. The ``get_*`` methods
    re-order them into the axis order of another view, returning
    ``[a0, b0, a1, b1]`` for 2D boxes and ``[a0, b0, a1, b1, c0, c1]`` when
    ``d3`` is True.
    """

    def __init__(self, x, y, z) -> None:
        # 3d box in coordinates of sagittal series
        self.x = x
        self.y = y
        self.z = z

    def __repr__(self) -> str:
        return f"{type(self).__name__}(x={self.x!r}, y={self.y!r}, z={self.z!r})"

    def get_box_in_view_type(self, view_type, d3: bool = False):
        """Return the box in the axis order of ``view_type``.

        Args:
            view_type: 'sagittal', 'sagittal_t2', 'coronal' or 'axial'.
            d3: if True, also append the extent along the third axis.

        Raises:
            ValueError: for any other ``view_type``. Previously this silently
                returned ``None``, which failed later and far from the cause.
        """
        if view_type in ('sagittal', 'sagittal_t2'):
            return self.get_sagittal(d3)
        if view_type == 'coronal':
            return self.get_coronal(d3)
        if view_type == 'axial':
            return self.get_axial(d3)
        raise ValueError(f"Unknown view type: {view_type!r}")

    def get_sagittal(self, d3: bool = False):
        """Return ``[x0, y0, x1, y1]`` (+ ``[z0, z1]`` if ``d3``)."""
        if d3:
            return [self.x[0], self.y[0], self.x[1], self.y[1], self.z[0], self.z[1]]
        return [self.x[0], self.y[0], self.x[1], self.y[1]]

    def get_coronal(self, d3: bool = False):
        """Return ``[z0, y0, z1, y1]`` (+ ``[x0, x1]`` if ``d3``)."""
        if d3:
            return [self.z[0], self.y[0], self.z[1], self.y[1], self.x[0], self.x[1]]
        return [self.z[0], self.y[0], self.z[1], self.y[1]]

    def get_axial(self, d3: bool = False):
        """Return ``[z0, x0, z1, x1]`` (+ ``[y0, y1]`` if ``d3``)."""
        if d3:
            return [self.z[0], self.x[0], self.z[1], self.x[1], self.y[0], self.y[1]]
        return [self.z[0], self.x[0], self.z[1], self.x[1]]

In [12]:
class Dataset(torch.utils.data.Dataset):
    """Base dataset that cuts per-level 3D crops out of lumbar spine series.

    ``prepare_data`` fills ``self.data`` with one dict per study, mapping the
    series type ('sagittal', 'sagittal_t2', 'axial') to a list of series dicts
    built by ``info2dict``. Subclasses provide ``__getitem__``; the base
    implementation returns ``None``.
    """

    def __init__(self, data_info:Dict[str, pd.DataFrame], config:Dict):
        """
        Args:
            data_info: series type -> dataframe with the per-level rows of that
                type. ``None`` values are skipped when collecting study ids.
            config: dataset configuration dict.

        Raises:
            Exception: if a series type in ``config['load_series']`` is
                missing from ``data_info``.
        """
        #TODO: Better overlap in instance dimension
        #TODO: smart limit - selecting important slices based on series type and where they usually lay

        # 'supress_warinings' (sic) is the key the configs use
        self.supress_warnings = config['supress_warinings']
        self.used_series_types = config['load_series']
        self.study_ids = np.unique(np.concatenate([series.study_id.unique() for series in data_info.values() if series is not None]))
        self.data_info = data_info
        if not all(tp in self.data_info for tp in self.used_series_types):
            raise Exception("Series types to use do not match provided data information.")

        self.series_out_types = config['series_out_types']
        self.current_view = self.series_out_types[0]
        self.return_series_type = config['return_series_type']

        # 3D boxes are always used; config['3d_box'] is currently ignored
        self.d3 = True
        self.preload = config['preload']

        # im_size is [width, height]: v2.Resize takes (height, width)
        self.im_size = config['im_size']
        if self.im_size:
            self.resize = v2.Resize((self.im_size[1], self.im_size[0]))

        self.image_type = config['image_type']

        if self.preload and not self.supress_warnings:
            print("Warning! Preloading of images is turned on. The program will attempt to load whole dataset into memory!")

        self.transforms = config['transforms']
        self.transforms_d = config['transforms_d']
        self.normalize = config['normalize']
        self.vsa = config['vsa']

        self.dataset_type = config['dataset_type']
        self.dataset_path = config['dataset_path']

        self.rng = np.random.default_rng()

        self._condition_list = ['Left Neural Foraminal Narrowing',
                                'Left Subarticular Stenosis',
                                'Right Neural Foraminal Narrowing',
                                'Right Subarticular Stenosis',
                                'Spinal Canal Stenosis']

        self._status_map = {'Normal/Mild': [1., 0., 0.],
                            'Moderate': [0., 1., 0.],
                            'Severe': [0., 0., 1.]}

        self.level_ind = {'L1/L2': 0, 'L2/L3': 1, 'L3/L4': 2, 'L4/L5': 3, 'L5/S1':4}

        self.box_labels = torch.tensor([list(self.level_ind.values())])

        self.limit = config['limit_series_len']
        self.series_length = 15

        self.x_overhead = config.get('x_overhead', [30, 30])
        self.z_overhead = config.get('z_overhead', 0)
        self.overlap_levels = config['overlap_levels']
        self.y_overlap = config['y_overlap'] if self.overlap_levels else 0

        self.x_overhead_axial = config.get('x_overhead_axial', [2, 2])
        self.z_overhead_axial = config.get('z_overhead_axial', 4)
        self.overlap_levels_axial = config.get('overlap_levels_axial', False)
        self.y_overlap_axial = config['y_overlap_axial'] if self.overlap_levels_axial else 0

        if self.limit and not config['series_len'] and not self.supress_warnings:
            print("Series length is not specified. The limit is set to 15.")
        elif self.limit and config['series_len']:
            self.series_length = config['series_len']

        self.series_type_ind = {s: [] for s in self.used_series_types}

        self.rand_bord = config['randomize_borders']
        if self.rand_bord:
            # Double every margin so the random jitter in load_series can move
            # borders both inwards and outwards. The [left, right] lists are
            # doubled element-wise: ``list * 2`` would repeat the list instead.
            self.x_overhead = [v * 2 for v in self.x_overhead]
            self.z_overhead = self.z_overhead * 2
            self.y_overlap = self.y_overlap * 2

            self.x_overhead_axial = [v * 2 for v in self.x_overhead_axial]
            self.z_overhead_axial = self.z_overhead_axial * 2
            self.y_overlap_axial = self.y_overlap_axial * 2

        self.sim_bd = False
        self.data = []
        self.prepare_data()

    def get_level_boxes(self, series_info:pd.DataFrame, stype='sagittal'):
        """Build one ``Bbox3d`` per level row of a single series.

        Returns:
            (boxes, labels): a list of ``Bbox3d`` and an array of level indices
            (values of ``self.level_ind``) in the same order. A ``stype`` other
            than 'sagittal', 'sagittal_t2' or 'axial' yields empty results.
        """
        # default: each scan has 5 visible levels
        bboxes = []
        labels = []
        if stype in ['sagittal', 'sagittal_t2']:
            for _, row in series_info.iterrows():
                sl = np.array(row.slice_locations)

                # slice locations of the first/last annotated instance of this level
                fr_a = sl[row.present_instances.index(min(row.instance_number))]
                too_a = sl[row.present_instances.index(max(row.instance_number))]

                fr = min(fr_a, too_a)
                too = max(fr_a, too_a)

                # widen by z_overhead (slice-location units) and snap to the closest slice index
                z_min = np.argmin(abs(sl - (fr-self.z_overhead )))
                z_max = np.argmin(abs(sl - (too+self.z_overhead)))

                labels.append(self.level_ind[row.level])
                bboxes.append(Bbox3d(
                    x=[max(0, min(row.x)-self.x_overhead[0]/row.pixel_spacing[0]), min(row.image_width, max(row.x)+self.x_overhead[1]/row.pixel_spacing[0])],
                    y=[max(0, row.level_boundaries[0]-self.y_overlap/row.pixel_spacing[1]), min(row.image_height, row.level_boundaries[1]+self.y_overlap/row.pixel_spacing[1])],
                    z=[min(z_max, z_min), max(z_max, z_min)]))

        elif stype=='axial':
            # NOTE(review): this branch uses the sagittal margins (x_overhead,
            # y_overlap, z_overhead), not the *_axial ones - confirm intended.
            y_overlap = self.y_overlap + 0.5
            for _, row in series_info.iterrows():
                labels.append(self.level_ind[row.level])
                bboxes.append(Bbox3d(
                    x=[max(0, min(row.y)-self.x_overhead[0]/row.pixel_spacing[1]), min(row.image_width, max(row.y)+self.x_overhead[1]/row.pixel_spacing[1])],
                    y=[max(0, row.present_instances.index(min(row.level_slices))-y_overlap), min(len(row.present_instances), row.present_instances.index(max(row.level_slices))+y_overlap)],
                    z=[max(0, min(row.x)-self.z_overhead/row.pixel_spacing[0]), min(row.image_height, max(row.x)+self.z_overhead/row.pixel_spacing[0])]
                                ))

        return bboxes, np.array(labels)

    def set_view(self, new_view):
        """Set the view used by consumers of ``self.current_view``."""
        self.current_view = new_view

    def get_condition_labels(self, series_info:pd.DataFrame):
        """Collect per-level condition labels from the first row of each level.

        Returns:
            (labels, cond_presence_masks, level_presence_mask): the ``status``
            and ``presence_mask`` values of every level present in
            ``series_info``, plus a boolean array over all 5 levels telling
            which ones were present.
        """
        labels = []
        cond_presence_masks = []
        level_presence_mask = []
        for level in self.level_ind:
            level_rows = series_info[series_info['level'] == level]
            if level_rows.empty:
                level_presence_mask.append(False)
                continue
            first = level_rows.iloc[0]
            labels.append(first.status)
            cond_presence_masks.append(first.presence_mask)
            level_presence_mask.append(True)

        return np.array(labels), np.array(cond_presence_masks), np.array(level_presence_mask)

    def info2dict(self, series_info, stype=None): #remember axials can be combination of different serieses (sagittals can't)
        """Convert the rows of one series into the series dict used by the loaders."""
        level0 = series_info.iloc[0]
        data_dict = {}
        boxes, box_labels = self.get_level_boxes(series_info, stype=stype)
        labels, cond_presence_mask, level_presence_mask = self.get_condition_labels(series_info)

        data_dict['study_id'] = level0.study_id
        data_dict['series_id'] = level0.series_id
        data_dict['width'] = level0.image_width
        data_dict['height'] = level0.image_height
        data_dict['reversed'] = level0.reversed
        data_dict['series_type'] = stype
        data_dict['pixel_spacing'] = level0.pixel_spacing

        data_dict['boxes'] = boxes
        data_dict['files'] = [f"{self.dataset_path}/{data_dict['study_id']}/{data_dict['series_id']}/{instance}.{self.image_type}" for instance in level0.present_instances]
        data_dict['box_labels'] = box_labels
        data_dict['labels'] = labels
        # per present level: which conditions are annotated
        data_dict['label_presence_mask'] = cond_presence_mask
        #data_dict['level_presence_mask'] = level_presence_mask

        return data_dict

    def prepare_data(self):
        """Build ``self.data``: one dict per study with at least one series.

        Each dict maps 'sagittal'/'sagittal_t2'/'axial' to a list of series
        dicts. Series are preloaded into memory when ``self.preload`` is set.
        """
        with tqdm(total=len(self.study_ids), desc="Preparing data: ") as pbar:
            for study_id in self.study_ids:
                study_dict = dict(
                                sagittal=[],
                                sagittal_t2=[],
                                axial=[])
                present = 0
                for stype in self.used_series_types:
                    for series_id in self.data_info[stype].query(f'study_id == {study_id}').series_id.unique():
                        present += 1
                        ddict = self.info2dict(self.data_info[stype][self.data_info[stype].series_id==series_id], stype=stype)
                        if self.preload:
                            ddict = self.preload_series(ddict)
                        study_dict[stype].append(ddict)

                if present:
                    self.data.append(study_dict)
                # advance for skipped studies too so the bar reaches its total
                pbar.update(1)

    def split_to_series(self):
        """Flatten ``self.data`` from per-study dicts into a list of series dicts.

        ``self.batches`` records, per study, the indices of its series in the
        flattened list.
        """
        temp = []
        batches = []
        i=0
        for data in self.data:
            batch = []
            for series in data.values():
                if series:
                    temp+=series
                    batch.append(i)
                    i+=1
            batches.append(batch)

        self.data = temp
        self.batches = batches

    def drop_healthy(self, drop=0.5):
        """Randomly remove samples whose spinal canal stenosis labels are all normal/mild.

        Args:
            drop: probability of removing each such sample.
        """
        # assumes data['labels'] is (levels, 5 conditions, 3 severities) one-hot - TODO confirm
        to_del = set()
        for i, data in enumerate(self.data):
            labels = torch.tensor(data['labels']).to(dtype=torch.int64).argmax(-1)
            # condition index 4 is 'Spinal Canal Stenosis'; severity 0 is 'Normal/Mild'
            if labels[:,4].sum(0) < 1 and random.random() < drop:
                to_del.add(i)
        self.data = [data for i, data in enumerate(self.data) if i not in to_del]

    def change_img_view(self, img, current_view, new_view):
        """Transpose ``img`` from the axis order of ``current_view`` to ``new_view``.

        Returns ``None`` for combinations that are not handled (for example a
        'coronal' current view, or axial -> 'sagittal_t2').
        """
        # NOTE(review): the sagittal -> coronal/axial branches pass four axes to
        # transpose, so they only accept 4D input, although the "n, h, w"
        # comments describe 3D input. The axial -> coronal comment also does not
        # match transpose(2, 1, 0), which yields (w, h, n). load_series only
        # requests new_view='sagittal'; confirm the orders before using others.
        if current_view in ['sagittal', 'sagittal_t2']:
            if new_view in ['sagittal', 'sagittal_t2']:
                return img
            elif new_view=='coronal':
                return img.transpose(0, 3, 2, 1) # n, h, w -> w, h, n
            elif new_view=='axial':
                return img.transpose(0, 2, 3, 1) #n, h, w -> h, n, w
        elif current_view=='axial':
            if new_view =='sagittal':
                return img.transpose(2, 0, 1) # n, h, w -> w, n, h
            elif new_view =='coronal':
                return img.transpose(2, 1, 0) # n, h, w -> h, n, w
            elif new_view=='axial':
                return img

    def _read_volume(self, data):
        """Stack every image file of a series into a float (slices, height, width) array.

        Images smaller than the series' height/width are written into the
        top-left corner of a zero-filled slice.
        """
        volume = np.zeros((len(data['files']), data['height'], data['width']), dtype=float)
        for i, path in enumerate(data['files']):
            # context manager closes the file handle right away
            with Image.open(path) as im:
                pixels = np.array(im, dtype=float)
            volume[i, :pixels.shape[0], :pixels.shape[1]] = pixels
        return volume

    def preload_series(self, data):
        """Read a series and store 5 per-level crops in ``data['series']``.

        Crops get fixed margins of 2 slices and 10 pixels. Levels without a box
        keep a zero array of shape (series_length, im_size[1], im_size[0]).
        NOTE: ``load_series`` does not read ``data['series']``.
        """
        boxes = np.array([box.get_box_in_view_type(data['series_type'], d3 = True) for box in data['boxes']], dtype=int)
        level_labels = data['box_labels']
        oimg = self._read_volume(data)

        # split to boxes; one independent placeholder per level (no shared array)
        imgs = [np.zeros((self.series_length, self.im_size[1], self.im_size[0])) for _ in range(5)]
        for box, label in zip(boxes, level_labels):
            # copy so the full volume can be freed once the crops are taken
            imgs[label] = oimg[
                                int(max(0, box[4]-2)): int(min(oimg.shape[0], box[5]+2)),
                                int(max(0, box[1]-10)): int(min(oimg.shape[1], box[3]+10)),
                                int(max(0, box[0]-10)): int(min(oimg.shape[2], box[2]+10))
                                ].copy()

        data['series'] = imgs
        return data

    @staticmethod
    def _rand_offset(bound):
        """Draw ``np.random.randint(-bound, bound)``, or 0 when ``bound <= 0`` (an empty range)."""
        if bound <= 0:
            return 0
        return np.random.randint(-bound, bound)

    def _resize_and_normalize(self, crop):
        """Resize a crop to the model input size and normalize it; ``None`` if unusable."""
        if crop is None or crop.size == 0:
            return None
        try:
            return self.__itensity_normalize_one_volume__(self.__resize_data__(crop))
        except (ValueError, RuntimeError):
            # malformed crop (e.g. wrong rank): skip this level, keep the sample
            return None

    def load_series(self, data) -> np.ndarray:
        """Load a series and return its 5 per-level crops.

        Returns:
            float array of shape (5, series_length, im_size[1], im_size[0]).
            Levels without a box, or with an empty crop, stay all-zero.
        """
        boxes = np.array([box.get_box_in_view_type(data['series_type'], d3 = True) for box in data['boxes']], dtype=int)
        level_labels = data['box_labels']
        oimg = self._read_volume(data)

        if self.sim_bd: # resize first to simulate bbox detector output
            oimg, boxes = self.__resize_data__(oimg, boxes)

        is_axial = data['series_type'] == 'axial'
        imgs = np.zeros((5, self.series_length, self.im_size[1], self.im_size[0]), dtype=float)
        for box, label in zip(boxes, level_labels):
            if self.rand_bord:
                x_overhead = self.x_overhead_axial if is_axial else self.x_overhead
                y_overhead = self.y_overlap_axial if is_axial else self.y_overlap
                spacing = data['pixel_spacing']
                # NOTE(review): the height coords (box[1], box[3]) are jittered with
                # x_overhead / spacing[1], the width coords with y_overhead /
                # spacing[0], and the slice range by a fixed draw from [-2, 2).
                # This crop is not passed through change_img_view. Confirm intended.
                z_lo = int(max(0, box[4] + np.random.randint(-2, 2)))
                z_hi = int(min(oimg.shape[0], box[5] + np.random.randint(-2, 2)))
                h_lo = int(max(0, box[1] + self._rand_offset(x_overhead[0]) / spacing[1]))
                h_hi = int(min(oimg.shape[1], box[3] + self._rand_offset(x_overhead[1]) / spacing[1]))
                w_lo = int(max(0, box[0] + self._rand_offset(y_overhead) / spacing[0]))
                w_hi = int(min(oimg.shape[2], box[2] + self._rand_offset(y_overhead) / spacing[0]))
                crop = oimg[z_lo:z_hi, h_lo:h_hi, w_lo:w_hi]
            else:
                crop = self.change_img_view(oimg[box[4]: box[5], box[1]:box[3], box[0]:box[2]], current_view = data['series_type'], new_view='sagittal')

            ig = self._resize_and_normalize(crop)
            if ig is None:
                continue
            imgs[label] = ig
        return imgs

    def coarse_dropout_3d(self, volume, max_holes_num, max_hole_size):
        """Zero out random cuboid "holes" in a 4D volume (in place) and return it.

        Args:
            volume: 4D array/tensor. As called from prepare_series it is
                presumably (levels, depth, height, width) - TODO confirm.
            max_holes_num: the hole count is drawn from
                [12, volume.shape[0] * max_holes_num); this range must be non-empty.
            max_hole_size: exclusive upper bounds on hole extent along axes 1, 2, 3.
        """
        # TODO: leave the regions around points with a condition of interest untouched
        hol_num = self.rng.integers(12, volume.shape[0]*max_holes_num)
        # each row: [index along axis 0, start along axes 1-3, size along axes 1-3]
        placement_and_size = self.rng.integers(low=0,
                                      high = [[volume.shape[0], volume.shape[1], volume.shape[2], volume.shape[3],
                                               max_hole_size[0], max_hole_size[1], max_hole_size[2]]]*hol_num)
        for i in range(hol_num):
            c = placement_and_size[i,:]
            volume[c[0], c[1]:min(volume.shape[1], c[1]+c[4]), c[2]:min(volume.shape[2], c[2]+c[5]), c[3]: min(volume.shape[3], c[3]+c[6])] = 0.
        return volume

    def __itensity_normalize_one_volume__(self, volume):
        """Normalize a volume with the mean and std of its positive (non-zero) voxels.

        Returns the volume unchanged when it has no positive voxels, and only
        mean-centred when they all have the same value (std == 0), which
        avoids inf/nan from a division by zero.
        """
        pixels = volume[volume > 0]
        if pixels.size == 0:
            return volume

        mean = pixels.mean()
        std = pixels.std()
        if std == 0:
            return volume - mean
        return (volume - mean) / std

    def __resize_data__(self, data, bb=None):
        """Resize a (depth, height, width) volume to (series_length, im_size[1], im_size[0]).

        Args:
            data: 3D array.
            bb: optional box array whose last axis is laid out as
                [w0, h0, w1, h1, d0, d1] in the volume's axes (the d3 layout of
                Bbox3d). It is rescaled in place to the resized volume; an int
                array truncates the new coordinates.

        Returns:
            The resized volume, or ``(volume, bb)`` when ``bb`` is given.
        """
        depth, height, width = data.shape
        scale = [self.series_length / depth, self.im_size[1] / height, self.im_size[0] / width]
        if bb is not None:
            # width coordinates follow im_size[0], height coordinates im_size[1]
            bb[..., [0,2]] = bb[..., [0,2]] / width * self.im_size[0]
            bb[..., [1,3]] = bb[..., [1,3]] / height * self.im_size[1]
            bb[..., [4,5]] = bb[..., [4,5]] / depth * self.series_length

        data = ndimage.zoom(data, scale, order=1)

        if bb is not None:
            return data, bb
        return data

    def __drop_invalid_range__(self, volume, boxes):
        """Crop the volume to the bounding range of voxels that differ from volume[0, 0, 0].

        The x/y box coordinates (indices 0-3 of the last axis) are shifted to
        match. A volume without such voxels is returned unchanged.
        """
        # NOTE(review): z coordinates are not shifted by min_z - confirm callers don't need it.
        zero_value = volume[0, 0, 0]
        non_zeros_idx = np.where(volume != zero_value)
        if non_zeros_idx[0].size == 0:
            return volume, boxes

        [max_z, max_h, max_w] = np.max(np.array(non_zeros_idx), axis=1)
        [min_z, min_h, min_w] = np.min(np.array(non_zeros_idx), axis=1)

        boxes[..., [0,2]] -= min_w
        boxes[..., [1,3]] -= min_h

        # the max indices are inclusive
        return volume[min_z:max_z + 1, min_h:max_h + 1, min_w:max_w + 1], boxes

    def limit_series(self, img, z_bb=None):
        """Sample or zero-pad ``img`` along axis 0 to exactly ``series_length`` slices.

        When ``z_bb`` is given (and ``self.d3``), it is rescaled by the same
        factor when slices are dropped, and ``(img, z_bb)`` is returned.
        """
        if self.series_length < img.shape[0]:
            if self.d3 and z_bb is not None:
                z_bb = z_bb * self.series_length / img.shape[0]
            st = np.round(np.linspace(0, img.shape[0] - 1, self.series_length)).astype(int)
            img = img[st,:,:]
        elif self.series_length > img.shape[0]:
            img = np.pad(img, ((0, self.series_length-img.shape[0]), (0, 0), (0, 0)))
        if z_bb is not None:
            return img, z_bb
        return img

    def prepare_series(self, img, data):
        """Convert loaded crops and labels to tensors and apply augmentations.

        Args:
            img: per-level crops, shape (5, depth, height, width).
            data: the series dict.

        Returns:
            (img, level_labels, cond_labels, cond_masks). ``cond_labels`` is
            (5, 5, 3) and ``cond_masks`` is (5, 5). The final loop overrides
            ``cond_masks``: a level whose crop is all-zero gets zeros,
            every other level gets ones.
        """
        level_labels = torch.tensor(data['box_labels'], dtype=torch.int64)
        cond_labels = torch.zeros((5,5,3), dtype=torch.int64)
        cond_masks = torch.zeros((5,5), dtype=torch.int64)
        cond_labels[level_labels,...] = torch.tensor(data['labels'], dtype=torch.int64)
        cond_masks[level_labels,...] = torch.tensor(data['label_presence_mask'], dtype=torch.int64)

        img = torch.tensor(img,dtype=torch.float)
        if self.transforms:
            # transform in height axis
            img = self.transforms(img)
            # transform in width axis
            if self.transforms_d and torch.rand(1)<0.5:
                img = self.transforms_d(img.permute(0, 2, 1, 3))
                img = img.permute(0, 2, 1, 3)

        if torch.rand(1) < 0.5 and self.vsa:
            img = self.coarse_dropout_3d(img,12,[int(self.series_length/3),int(self.im_size[1]/3), int(self.im_size[0]/3)])

        for i, im in enumerate(img):
            if torch.count_nonzero(im)==0:
                cond_masks[i] *= 0
            else: cond_masks[i]=torch.ones_like(cond_masks[i])

        return img, level_labels, cond_labels, cond_masks

    def __getitem__(self, index: int):
        # implemented by subclasses
        return None

    def __len__(self):
        return len(self.data)

In [13]:
class BoxDatasetUnited(Dataset):
    """Per-series dataset: every item holds the 5 level crops of one series.

    ``__getitem__`` returns ``(images, condition_labels, condition_masks)``:
    images of shape (5, 1, series_len, im_size[1], im_size[0]), plus int64
    tensors of shape (5, 5) with the severity index per (level, condition)
    and the matching masks.
    """

    def __init__(self, data_info:Dict[str, pd.DataFrame], config:Dict):
        super().__init__(data_info, config)
        self.mix_strategy = config['mix_strategy']
        if config['one_label']:
            self.box_labels = torch.tensor([0,0,0,0,0], dtype=torch.float)
        else:
            self.box_labels = torch.tensor([list(self.level_ind.values())])

        # split data into individual series from the per-study dicts of series types
        self.split_to_series()
        #self.drop_healthy()

    def drop_series(self, volume,  masks, mask_dropped:bool=True):
        """Zero ``volume[k]`` for a random series-type index k.

        When ``mask_dropped`` is set, also mask the conditions mainly estimated
        from that series type.
        """
        to_drop = random.randrange(len(self.used_series_types))
        volume[to_drop] = torch.zeros((5, self.series_length, self.im_size[1], self.im_size[0]))
        if mask_dropped:
            if to_drop == 0:
                masks[:,[0,2]] = 0
            elif to_drop==1:
                masks[:,4] = 0
            elif to_drop==2:
                masks[:,[1,3]] = 0

        return volume, masks

    def mirror_sides(self, volume, labels, masks):
        """Copy one half of dim 2 onto the other (mirrored), with the matching left/right labels.

        For an odd size, the centre index is left untouched.
        """
        size = volume.size(2)
        half = size // 2
        if torch.rand(1)>0.5:
            # right to left: the last `half` entries, flipped, overwrite the first `half`
            volume[:, :, :half, ...] = volume[:, :, size - half:, ...].flip(2)
            labels[:, [2,3]] = labels[:, [0,1]]
            masks[:, [2,3]] = masks[:, [0,1]]
        else:
            # left to right
            volume[:, :, size - half:, ...] = volume[:, :, :half, ...].flip(2)
            labels[:, [0,1]] = labels[:, [2,3]]
            masks[:, [0,1]] = masks[:, [2,3]]

        return volume, labels, masks

    def flip_sides(self, volume, labels, masks):
        """Flip dim 2 and swap the left/right condition columns (0,1) <-> (2,3)."""
        volume = volume.flip(2)
        labels = labels[:, [2,3,0,1,4]]
        masks = masks[:, [2,3,0,1,4]]
        return volume, labels, masks

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        data = self.data[index]

        all_type_img = torch.zeros((1, 5, self.series_length, self.im_size[1], self.im_size[0]))

        oimg = self.load_series(data)
        img, level_presence, cond_labels, cond_masks = self.prepare_series(oimg, data)

        all_type_img[0, level_presence,...] = img[level_presence,...]

        # one series per item, so each stack has a single entry
        all_type_labels = torch.stack([cond_labels])
        all_type_mask = torch.stack([cond_masks])
        all_type_labels, all_type_mask = self.prepare_labels(all_type_labels, all_type_mask)

        #if self.vsa:
            #all_type_img = self.transforms(all_type_img.permute(1,0,2,3,4).reshape(5, len(self.used_series_types)*self.series_length, ))
            #if torch.rand(1)<0.2:
                #all_type_img, all_type_labels, all_type_mask = self.flip_sides(all_type_img, all_type_labels, all_type_mask)
            #elif torch.rand(1)<0.2:
                #all_type_img, all_type_labels, all_type_mask = self.mirror_sides(all_type_img, all_type_labels, all_type_mask)
            #if torch.rand(1) <0.2:
                #all_type_img, all_type_mask = self.drop_series(all_type_img, all_type_mask, mask_dropped=False)

        return all_type_img.permute(1,0,2,3,4), all_type_labels.to(dtype=torch.int64), all_type_mask.to(dtype=torch.int64)

    def prepare_labels(self, labels, masks):
        """OR the labels/masks over the stacked axis 0, then take the severity index (argmax)."""
        labels = labels.sum(axis=0)>0
        masks = masks.sum(axis=0)>0
        return labels.to(dtype=torch.int64).argmax(-1), masks.to(dtype=torch.int64)

    def __len__(self):
        return len(self.data)

    def get_random_by_stype(self):
        """Return a uniformly random item."""
        return self[random.randint(0, len(self)-1)]

In [None]:
im_size = 128

# NOTE: this augmentation set is immediately replaced by the next one; kept as an alternative.
# Neither is used while the config below sets 'transforms': None.
train_transforms = v2.Compose([
    v2.RandomChoice([v2.RandomVerticalFlip(p = 0.2), v2.RandomHorizontalFlip(p = 0.2)]), #flip
    v2.RandomChoice([v2.RandomAffine(degrees=15), v2.RandomAffine(degrees=0, translate=(0.1,0.1), shear=(-3,3,-3,3))]), # translation + shearing
    v2.RandomAffine(degrees=0, scale=(0.85,1.2)), #scaling
    v2.RandomChoice([v2.GaussianBlur(kernel_size=(3,3), sigma=(0.8, 0.8))]),
])

train_transforms = v2.Compose([
    v2.RandomChoice([v2.GaussianBlur(kernel_size=(3,3), sigma=(1, 1))]),
    v2.RandomAffine(degrees=0, translate=(0.1,0.1), shear=(-3,3,-3,3)), # translation + shearing
])

train_dataset_config ={
    'preload': False, # preload data into memory (WARNING: may take a lot of space - depends on the dataset used)
    'im_size': [128, 128], # [width, height]
    'dataset_path': "/workspaces/RSNA_LSDC/inputs/dataset_vl",
    'load_series': ['sagittal_t2'], # series types to load into dataset ['sagittal', 'axial', 'sagittal_t2']
    'united': True,

    'transforms': None,
    'transforms_d': None,
    'vsa': False,
    'one_label': False, # use one label for every level (do not differentiate between levels)
    'series_out_types': ['sagittal', 'axial'], # mix output series types to get views necessary to create 3d box ['sagittal', 'coronal']
    'mix_strategy': 'combined', # strategy for mixing output series types ['random', 'manual', 'combined']: 'random' - randomly select output type,
                              #'manual' - return data based on the currently chosen view, 'combined' - return all views in one call
    'return_series_type': False, # If True getitem will also return the series' original type
    'randomize_borders':False,
    'normalize': True,
    'image_type': 'png',

    'dataset_type': 'boxes', #'boxes', 'conditions'
    'supress_warinings': False,

    'limit_series_len': True, # if True the series length will be limited to number N specified by 'series_len' parameter
    'series_len': 10, #  maximal number of slices in series

    'x_overhead': [20,20], # overhead for levels in x-dim (in mm)
    'z_overhead': 2,
    'overlap_levels': False, # if true the level upper and lower boundary will overlap with value specified in 'y_overlap'
    'y_overlap': 10, # overlap size of levels boundaries (in mm)

    'x_overhead_axial': [20,20],
    'y_overlap_axial': 40,
    'overlap_levels_axial':True,
    'z_overhead_axial': 50,
    '3d_box': True # currently ignored: the dataset always uses 3D boxes
}

tsd = pd.read_csv('/workspaces/RSNA_LSDC/inputs/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv').iloc[0:100]

data_sagittal = pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t1.pkl")
data_sagittal_t2 = pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t2.pkl")
data_axial = pd.read_pickle('/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_axial_unified.pkl')
train_ids = tsd.study_id.unique()

train_data={'sagittal': data_sagittal[data_sagittal.study_id.isin(train_ids)],
            'sagittal_t2': data_sagittal_t2[data_sagittal_t2.study_id.isin(train_ids)],
            'axial': data_axial[data_axial.study_id.isin(train_ids)]}

bb = BoxDatasetUnited(train_data, train_dataset_config)

In [None]:
color_dict = {0: 'r', 1: 'g', 2: 'b', 3: 'm', 4: 'y'} # 4646740 41477684

# Show every slice of one level crop of a sample.
inputs, labels, masks = bb[12]
i = 3  # level index (0 = L1/L2 ... 4 = L5/S1)
inp = inputs[i]  # (n_series_types, series_len, h, w)
print(labels[i])

# iterate over the actual number of slices instead of a hard-coded 10
for k in range(inp.shape[1]):
    fig, ax = plt.subplots(1, 1, figsize = (8, 8))
    # draw the k-th slice of each series type (own loop name, so `i` is not clobbered)
    for series_img in inp:
        ax.imshow(series_img[k,:,:].detach().cpu().numpy())
    plt.show()

In [None]:
# Expected (5, 1, series_len, im_size[1], im_size[0]) for BoxDatasetUnited items
inputs.shape

In [None]:
# Severity index per (level, condition) after prepare_labels' argmax
labels

# MODEL TRAINER

In [18]:
class SagittalTrainer():
    """Train and validate a per-level classifier on ``BoxDatasetUnited`` data.

    The model class is instantiated as ``model(**model_params)`` and must
    provide ``get_loss(inputs, labels, masks) -> (loss, {name: tensor})`` and
    ``predict(inputs) -> class probabilities``. Loader batches are regrouped to
    ``(-1, n_levels, series_len, H, W)`` for both training and evaluation, and
    only one condition column (``condition_index``) of the label/mask tensors
    is trained and scored.

    Config keys read: ``print_evaluation`` (optional), ``steps_per_plot``,
    ``train_dataset_config``, ``val_dataset_config``, ``checkpoints``,
    ``save_path``, ``step_per_save``, ``model_name``, ``optimizer``,
    ``optimizer_params``, ``batch_size``, ``epochs``, ``early_stopping``,
    ``early_stopping_treshold``, ``scheduler``, ``scheduler_params`` and the
    optional ``condition_index`` (default 4) and ``n_levels`` (default 5).
    """

    def __init__(self, model, model_params, config, train_data, eval_data) -> None:
        self.print_evaluation = config.get("print_evaluation", False)
        self.steps_per_plot = config["steps_per_plot"]

        train_cfg = config['train_dataset_config']
        # Only the 'combined' strategy returns every output view in one item,
        # which is what multiplies the effective batch size.
        if train_cfg['mix_strategy'] == 'combined' and len(train_cfg['series_out_types']) > 1:
            print("Warning! With mix strategy set to 'combined' and multiple series out types the batch size will be (len(series_out_types)) times bigger.")

        self.checkpoints = config['checkpoints']
        self.save_path = config['save_path']
        self.step_per_save = config['step_per_save']

        self.config = config

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Device set to {self.device}")

        self.model_name = config['model_name']
        self.model = model(**model_params).to(self.device)

        self.optimizer = config["optimizer"](self.model.parameters(), **config["optimizer_params"])

        self.series_len = config['val_dataset_config']['series_len']
        # Label/mask column that is trained and scored.
        self.condition_idx = config.get('condition_index', 4)
        # Number of levels grouped into one model sample when reshaping inputs.
        self.n_levels = config.get('n_levels', 5)
        self.dataloaders = {
            'train': torch.utils.data.DataLoader(BoxDatasetUnited(train_data, train_cfg),
                                                 batch_size=config["batch_size"], shuffle=True,
                                                 num_workers=12, prefetch_factor=1),
            'val': torch.utils.data.DataLoader(BoxDatasetUnited(eval_data, config['val_dataset_config']),
                                               batch_size=config["batch_size"], shuffle=False,
                                               num_workers=12, prefetch_factor=1),
        }

        self.max_epochs = config["epochs"]
        self.early_stopping = config['early_stopping']
        self.early_stopping_tresh = config['early_stopping_treshold']

        # scheduler
        self.one_cycle_sched = False
        if config["scheduler"]:
            sched_params = config['scheduler_params']
            # Fill in run-dependent values when the scheduler expects them (mutates the config dict).
            if 'epochs' in sched_params:
                sched_params['epochs'] = self.max_epochs
            if 'steps_per_epoch' in sched_params:
                sched_params['steps_per_epoch'] = len(self.dataloaders['train'])
            self.scheduler = config["scheduler"](self.optimizer, **sched_params)
            # OneCycleLR is stepped per batch; any other scheduler per epoch.
            self.one_cycle_sched = self.scheduler.__class__.__name__ == 'OneCycleLR'
        else:
            self.scheduler = None

        ## Evaluation metrics
        self.series_in_types = train_cfg['load_series']
        self.out_stypes = train_cfg['series_out_types']
        self.batch_size = config["batch_size"] * len(self.out_stypes)
        # +inf so the first evaluated epoch is always saved, even if its log-loss exceeds 1.
        self.best_ll = float('inf')

    def _model_inputs(self, inputs):
        """Move a loader batch to the device as (-1, n_levels, series_len, H, W)."""
        return inputs.to(self.device).reshape(-1, self.n_levels, self.series_len,
                                              inputs.shape[-2], inputs.shape[-1])

    def _targets(self, tensor):
        """Select the scored condition column of labels/masks as an (N, 1) device tensor."""
        return tensor[:, :, self.condition_idx].to(self.device).reshape(-1, 1)

    def _weights_path(self):
        return os.path.join(self.save_path, f"{self.model_name}_best.pt")

    def save_model(self):
        """Save the model state dict to ``<save_path>/<model_name>_best.pt``."""
        os.makedirs(self.save_path, exist_ok=True)
        torch.save(self.model.state_dict(), self._weights_path())

    def load_model(self, model_path):
        """Load a state dict saved by ``save_model`` onto the trainer's device."""
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))

    def get_summary(self) -> dict:
        """Return a summary of the run: model name, best validation log-loss and weights path."""
        return {'model_name': self.model_name,
                'best_log_loss': self.best_ll,
                'weights_path': self._weights_path()}

    def train(self):
        """Alternate one training and one validation epoch for ``max_epochs`` epochs."""
        for epoch in range(self.max_epochs):
            print(f"Epoch {epoch+1}/{self.max_epochs}")
            self.train_one_epoch()
            self.eval_one_epoch()

    def train_one_epoch(self):
        """Run one optimisation pass over the training loader."""
        self.model.train()  # Set model to training mode
        metrics = defaultdict(list)

        with tqdm(self.dataloaders['train'], unit="batch",
                  total=len(self.dataloaders['train'])) as tepoch:
            for inputs, labels, masks in tepoch:
                with torch.set_grad_enabled(True):
                    loss, loss_info = self.model.get_loss(self._model_inputs(inputs),
                                                          self._targets(labels),
                                                          self._targets(masks))
                    for loss_t, loss_v in loss_info.items():
                        metrics[loss_t].append(loss_v.clone().detach().cpu().numpy())
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    if self.scheduler and self.one_cycle_sched:
                        self.scheduler.step()

                tepoch.set_description(self.metrics_description(metrics, 'train'))

        if self.scheduler and not self.one_cycle_sched:
            self.scheduler.step()

    def eval_one_epoch(self):
        """Score the validation loader, plot a confusion matrix and keep the best weights.

        The score is the sample-weighted log-loss with weights ``2 ** label``
        (0 for unannotated entries); the model is saved whenever it improves.
        """
        self.model.eval()
        alabels = []
        apreds = []
        aweight = []
        metrics_d = defaultdict(list)

        with tqdm(self.dataloaders['val'], unit="batch",
                  total=len(self.dataloaders['val'])) as tepoch:
            for inputs, labels, masks in tepoch:
                with torch.no_grad():
                    # Same input layout for predictions and for the loss as in training.
                    model_in = self._model_inputs(inputs)
                    preds = self.model.predict(model_in)
                    _, loss_info = self.model.get_loss(model_in,
                                                       self._targets(labels),
                                                       self._targets(masks))
                    for loss_t, loss_v in loss_info.items():
                        metrics_d[loss_t].append(loss_v.clone().detach().cpu().numpy())
                    # Move to CPU per batch so predictions do not pile up in GPU memory.
                    apreds.append(preds.reshape(-1, 3).cpu())

                flat_labels = labels[:, :, self.condition_idx].reshape(-1)
                weights = 2 ** flat_labels
                weights[torch.logical_not(masks[:, :, self.condition_idx].reshape(-1))] = 0.
                alabels.append(flat_labels)
                aweight.append(weights)

                tepoch.set_description(self.metrics_description(metrics_d, 'val'))

        alabels = torch.cat(alabels, dim=0).cpu().numpy()
        apreds = torch.cat(apreds, dim=0).numpy()
        aweight = torch.cat(aweight, dim=0).cpu().numpy()

        # print confusion matrix for every condition
        conditions = ['Spinal Canal Stenosis']
        class_ids = [0, 1, 2]  # keep matrices 3x3 even if a class is absent from this split

        fig, axes = plt.subplots(nrows=1, ncols=len(conditions), figsize=(15, 5), squeeze=False)
        for i, condition in enumerate(conditions):
            cl = alabels[i::len(conditions)]
            cpred = apreds[i::len(conditions), :].argmax(-1)
            cm = confusion_matrix(cl, cpred, labels=class_ids)
            axes[0, i].set_title(condition)
            ConfusionMatrixDisplay(confusion_matrix=cm).plot(ax=axes[0, i], colorbar=False)
        plt.show()

        # `labels` lets sklearn score splits where not every class occurs.
        ll = log_loss(alabels, apreds, normalize=True, sample_weight=aweight, labels=class_ids)
        if ll < self.best_ll:
            self.save_model()
            self.best_ll = ll

        print("Score:", ll)

    def metrics_description(self, metrics: dict, phase: str) -> str:
        """Format running means as ``"<phase>: || name: value || ..."`` for tqdm."""
        outputs = phase + ": ||"
        for name, values in metrics.items():
            outputs += " {}: {:4f} ||".format(name, np.mean(values))
        return outputs


In [19]:
class MlpHead(nn.Module):
    """Two-layer MLP classification head.

    Pipeline: Linear(dim -> dim*mlp_ratio) -> act -> norm -> dropout ->
    Linear(-> num_classes).

    Args:
        dim: input feature size.
        num_classes: number of outputs.
        mlp_ratio: hidden width as a multiple of ``dim`` (truncated to int).
        act_layer: activation module factory.
        norm_layer: normalization module factory, applied to the hidden layer.
        head_dropout: dropout probability before the output layer.
        bias: whether both linear layers use a bias.
    """
    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=nn.ReLU,
        norm_layer=nn.LayerNorm, head_dropout=0., bias=True):
        super().__init__()
        width = int(mlp_ratio * dim)
        # Creation order (fc1 before fc2) fixes the random init sequence.
        self.fc1 = nn.Linear(dim, width, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(width)
        self.fc2 = nn.Linear(width, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        hidden = self.norm(self.act(self.fc1(x)))
        return self.fc2(self.head_dropout(hidden))

In [20]:
class ClassModelTimm25d(nn.Module):
    """2.5D classifier: a 2D timm encoder per slice followed by 3D conv reducers.

    Every slice of the input is encoded independently by a single-channel timm
    feature extractor. The last two feature maps are restacked along the slice
    axis, reduced with small Conv3d stacks (``reduce`` / ``reduce2``), flattened,
    concatenated and classified by ``MlpHead`` into 3 logits.

    Args:
        backbone_name: timm model name of the per-slice encoder.
        series_dim: ``[depth, height, width]`` of one series.
        pretrained: load pretrained timm weights.
    """

    def __init__(self, backbone_name, series_dim, pretrained=True):
        super().__init__()
        self.feature_extractor = timm.create_model(
                                    backbone_name,
                                    pretrained=pretrained,
                                    features_only=True,
                                    in_chans=1,
                                    num_classes=3,
                                    global_pool='avg'
                                    )

        self.all_channels = self.feature_extractor.feature_info.channels()
        self.reduction = self.feature_extractor.feature_info.reduction()
        print(self.all_channels, self.reduction)
        self.mid_channels = 128
        self.out_channels = 16
        self.series_dim = series_dim
        # Both reducers: Conv3d 3x3x3 (same padding) -> BN -> ReLU -> Conv3d 1x1x1.
        self.reduce = self._make_reducer(self.all_channels[-1])
        self.reduce2 = self._make_reducer(self.all_channels[-2])

        # Flattened size of each reduced map. NOTE(review): counts series_dim[0]
        # slices, i.e. assumes one series per sample (s == 1 in forward()).
        reduce_1_out = self.out_channels * int(series_dim[0]) * self._grid_cells(-1)
        reduce_2_out = self.out_channels * int(series_dim[0]) * self._grid_cells(-2)
        self.head = MlpHead(reduce_1_out + reduce_2_out, 3, 1)

    def _make_reducer(self, in_channels):
        """Conv3d stack mapping ``in_channels`` feature maps to ``self.out_channels``."""
        return nn.Sequential(
            nn.Conv3d(in_channels, self.mid_channels, kernel_size=(3, 3, 3),
                      stride=(1, 1, 1), padding=(1, 1, 1)),
            nn.BatchNorm3d(self.mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(self.mid_channels, self.out_channels, kernel_size=(1, 1, 1),
                      stride=(1, 1, 1), padding=0),
        )

    def _grid_cells(self, level):
        """H*W of the feature map at ``level``; floor division matches the reshape in forward()."""
        r = self.reduction[level]
        return (self.series_dim[1] // r) * (self.series_dim[2] // r)

    def forward(self, x):
        """Return logits reshaped to (batch, 3, -1); ``x`` is (batch, series, depth, H, W)."""
        b, s, d, h, w = x.shape
        x = x.permute(0, 2, 1, 3, 4)
        x = self.feature_extractor(x.reshape(b*s*d, 1, h, w))[-2:]
        x1 = x[-1].reshape(b, s*d, self.all_channels[-1], self.series_dim[1]//self.reduction[-1], self.series_dim[2]//self.reduction[-1]).permute(0, 2, 1, 3, 4)
        x2 = x[-2].reshape(b, s*d, self.all_channels[-2], self.series_dim[1]//self.reduction[-2], self.series_dim[2]//self.reduction[-2]).permute(0, 2, 1, 3, 4)
        x1 = self.reduce(x1).reshape(b, -1)
        x2 = self.reduce2(x2).reshape(b, -1)
        y = self.head(torch.cat([x1, x2], dim=-1))

        return y.reshape(b, 3, -1)

    def get_loss(self, input, labels, masks):
        """Class-weighted cross-entropy.

        Weights are ``3 ** label`` — (1, 3, 9) for y = 0, 1, 2 — and 0 where
        ``masks`` is false (unannotated). Label smoothing is 0.01. The
        per-entry losses are averaged and multiplied by the batch size.

        Returns:
            ``(loss, {'cross_entropy': detached loss})``.
        """
        preds = self.forward(input)
        w = 3 ** labels
        w[torch.logical_not(masks)] = 0.  # set weight for unannotated conds to 0
        loss = F.cross_entropy(preds, labels, reduction='none', label_smoothing=0.01) * w
        loss = loss.mean() * torch.tensor(preds.shape[0], dtype=loss.dtype).to(preds.device)

        return loss, dict(cross_entropy=loss.clone().detach())

    def predict(self, input):
        """Class probabilities as (batch, n, 3)."""
        return self.forward(input).permute(0, 2, 1).softmax(-1)

In [21]:
class ClassModelTimm2d(nn.Module):
    """2D timm classifier that stacks the slices of a series as input channels.

    ``series_dim[0]`` (slices per series) becomes the backbone's ``in_chans``
    and the backbone predicts 3 classes per sample.
    """
    def __init__(self, backbone_name, series_dim, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            features_only=False,
            in_chans=series_dim[0] * 1,
            num_classes=3,
            global_pool='avg',
        )

    def forward(self, x):
        batch, n_series, depth, height, width = x.shape
        # depth first, then merge depth and series into the channel axis
        stacked = x.permute(0, 2, 1, 3, 4).reshape(batch, n_series * depth, height, width)
        logits = self.model(stacked)
        return logits.reshape(batch, 3, -1)

    def get_loss(self, input, labels, masks):
        """Cross-entropy weighted by 2 ** label (1, 2, 4); unannotated entries weigh 0."""
        preds = self.forward(input)
        weights = 2 ** labels
        weights[torch.logical_not(masks)] = 0.
        per_item = F.cross_entropy(preds, labels, reduction='none', label_smoothing=0.0) * weights
        batch_size = torch.tensor(preds.shape[0], dtype=per_item.dtype).to(preds.device)
        total = per_item.mean() * batch_size
        return total.mean(), dict(cross_entropy=total.mean().clone().detach())

    def predict(self, input):
        """Class probabilities as (batch, n, 3)."""
        logits = self.forward(input)
        return logits.permute(0, 2, 1).softmax(-1)

In [71]:
class ClassModelTimm2dLSTM(nn.Module):
    """Per-level 2D timm encoder followed by a bidirectional LSTM across levels.

    Each (level) sample of shape (depth, H, W) is encoded with depth as input
    channels; the flattened last feature maps of all levels in a sample form a
    sequence for a 3-layer bidirectional LSTM (hidden 1024), and ``MlpHead``
    maps every step to 3 logits. Output: (batch * levels, 3, 1).
    """
    def __init__(self, backbone_name, series_dim, pretrained=True):
        super().__init__()
        self.model = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            features_only=True,
            in_chans=series_dim[0],
            num_classes=3,
            global_pool='avg',
        )
        self.all_channels = self.model.feature_info.channels()
        self.reduction = self.model.feature_info.reduction()
        print(self.all_channels, self.reduction)
        last_reduction = self.reduction[-1]
        grid_cells = int((series_dim[1] / last_reduction) * (series_dim[2] / last_reduction))
        dim = self.all_channels[-1] * grid_cells
        print(dim)
        self.lstm = nn.LSTM(dim, 1024, 3, batch_first=True, bidirectional=True)
        self.head = MlpHead(2048, 3, 1)

    def forward(self, x):
        batch, n_levels, depth, height, width = x.shape
        last_features = self.model(x.reshape(batch * n_levels, depth, height, width))[-1]
        sequence, _ = self.lstm(last_features.reshape(batch, n_levels, -1))
        logits = self.head(sequence).reshape(-1, 3)
        return logits.unsqueeze(-1)

    def get_loss(self, input, labels, masks):
        """Cross-entropy weighted by 2 ** label (1, 2, 4); unannotated entries weigh 0."""
        preds = self.forward(input)
        weights = 2 ** labels
        weights[torch.logical_not(masks)] = 0.
        per_item = F.cross_entropy(preds, labels, reduction='none', label_smoothing=0.0) * weights
        batch_size = torch.tensor(preds.shape[0], dtype=per_item.dtype).to(preds.device)
        total = per_item.mean() * batch_size
        return total.mean(), dict(cross_entropy=total.mean().clone().detach())

    def predict(self, input):
        """Class probabilities as (batch * levels, 1, 3)."""
        logits = self.forward(input)
        return logits.permute(0, 2, 1).softmax(-1)

# Training model

In [None]:
# Side length (px) of the square crops; feeds 'im_size' below and model_config['series_dim'].
im_size = 64

# In-plane training augmentation: one random flip (vertical OR horizontal, via
# RandomChoice), a random translation of up to 20%, and random scaling 0.8-1.2.
train_transforms = v2.Compose([
    v2.RandomChoice([v2.RandomVerticalFlip(p = 0.5), v2.RandomHorizontalFlip(p = 0.5)]), #flip
    v2.RandomAffine(degrees=0, translate=(0.2,0.2)),
    v2.RandomAffine(degrees=0, scale=(0.8,1.2)), #scaling
])
# Passed to the dataset as 'transforms_d'; presumably applied along the
# depth/slice axis by BoxDatasetUnited — confirm there.
train_transforms_depth = v2.Compose([
    v2.RandomHorizontalFlip(p = 0.5),
    v2.RandomAffine(degrees=0, translate=(0.2,0.2)),
])

# NOTE(review): not referenced in this cell — val_dataset_config sets 'transforms' to None.
val_transforms = v2.Compose([
    v2.Resize((im_size,im_size)), #resize
])

# Dataset configuration for training. Keys (including spellings such as
# 'supress_warinings') are presumably read verbatim by BoxDatasetUnited.
train_dataset_config ={
    'preload': False, # preload data into memory (WARNING IT MAY TAKE A LOT OF SPACE - DEPENDS ON THE DATASET USED)
    'im_size': [im_size, im_size],
    'dataset_path': "/workspaces/RSNA_LSDC/inputs/dataset",
    'load_series': ['sagittal_t2'], # series types to load into dataset ['sagittal', 'axial', 'sagittal_t2']
    'united': True,

    'transforms': train_transforms,
    'transforms_d': train_transforms_depth,
    
    'vsa': True,
    'one_label': False, # use one label for every level (do not differentiate between levels)
    'series_out_types': ['sagittal'], # mix output series types to get views necessary to create 3d box ['sagittal', 'coronal']
    'mix_strategy': 'combined', # strategy for mixing output series types ['random', 'custom', 'combined'] random - randomly select output type, 
                              #'manual' - will return data based on currently choosen view, 'combined' will return all views in one call
                              # NOTE(review): the option list says 'custom' but the description says 'manual' — confirm the accepted key.
    'return_series_type': False, # If True getitem will also return series orignial type
    
    'normalize': True,
    'image_type': 'png',

    'dataset_type': 'boxes', #'boxes', 'conditions'
    'supress_warinings': False, 

    'limit_series_len': True, # if True the series length will be limited to number N specified by 'series_len' parameter
    'series_len': 10, #  maximal number of slices in series
    'randomize_borders':True,

    'x_overhead': [20,20], # overhead for levels in x-dim (in mm)
    'z_overhead': 20,
    'overlap_levels': True, # if true the level upper and lower boundary will overlap with value specified in 'y_overlap'
    'y_overlap': 5, # overlap size of levels boundaries (in mm)

    'x_overhead_axial': [20,20],
    'y_overlap_axial': 1,
    'overlap_levels_axial':True,
    'z_overhead_axial': 20,
    '3d_box': True
}

# Validation: same config as training but without augmentation, and the dataset
# also returns the original series type.
val_dataset_config = copy.deepcopy(train_dataset_config)
val_dataset_config['transforms'] = None
val_dataset_config['load_series'] = ['sagittal_t2']
val_dataset_config['transforms_d'] = None
val_dataset_config['mix_strategy'] = 'combined'
val_dataset_config['vsa'] = False
# NOTE(review): border randomization stays on for validation, which may make
# validation scores vary between epochs — confirm this is intended.
val_dataset_config['randomize_borders'] = True
val_dataset_config['return_series_type'] = True
# Disabled alternative validation overrides, kept as an unused string literal.
'''
val_dataset_config['x_overhead'] = [2, 10]
val_dataset_config['z_overhead'] = 5
val_dataset_config['overlap_levels'] = True
val_dataset_config['y_overlap'] = 5

val_dataset_config['x_overhead_axial'] = [2, 10]
val_dataset_config['z_overhead_axial'] = 5
val_dataset_config['overlap_levels_axial'] = True
val_dataset_config['y_overlap_axial'] = 5
'''

In [73]:
# Settings consumed by SagittalTrainer.
# NOTE(review): 'steps_per_plot', 'checkpoints', 'step_per_save', 'early_stopping',
# 'early_stopping_treshold' are stored by SagittalTrainer but not used by its
# training loop, and 'vsa' is not read by it at all.
trainer_config = {
    "print_evaluation": True,
    "steps_per_plot": 1,

    "checkpoints": False,
    "save_path": "/workspaces/RSNA_LSDC/models_3d_final/model_weight",
    "step_per_save":100,
    # Weights are saved as <save_path>/<model_name>_best.pt.
    # NOTE(review): the name suggests 80x80 crops while im_size is 64 — confirm.
    "model_name": "densenet121LSTM_80_80_10",
    "train_dataset_config": train_dataset_config,
    "val_dataset_config": val_dataset_config,

    "epochs": 37, 
    "batch_size": 5,

    "optimizer": torch.optim.AdamW, #torch.optim.AdamW,#torch.optim.Adam,
    "optimizer_params": {'lr':1e-3, 'weight_decay': 1e-3},#, 'weight_decay': 1e-3, 'momentum': 0.98},#, 'momentum': 0.98, 'weight_decay': 1e-3},#, 'momentum':0.98, 'weight_decay':1e-5},#, 'momentum':0.9},
    # Not OneCycleLR, so SagittalTrainer steps it once per epoch; T_0 == epochs gives a single cosine cycle.
    "scheduler": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, #torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, #torch.optim.lr_scheduler.ExponentialLR, #torch.optim.lr_scheduler.OneCycleLR
    "scheduler_params": {'T_0': 37, 'T_mult': 1, 'eta_min':3e-8}, #{'T_0': 2, 'T_mult': 2, 'eta_min':3e-5}, #{'max_lr': 0.001, 'epochs': None, 'steps_per_epoch':None}, {'gamma':0.9}

    "early_stopping": False,
    "early_stopping_treshold": 0.1,
    'vsa':False,
}

# Keyword arguments for the model class; series_dim = [series_len, height, width].
model_config = {
    'backbone_name': 'densenet121', 
    'series_dim': [train_dataset_config['series_len']]+train_dataset_config['im_size'],
    'pretrained': True, 
    }


In [74]:
# Series descriptions used by kfoldCV() to draw study ids (first 3000 rows).
tsd = pd.read_csv('/workspaces/RSNA_LSDC/inputs/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv').iloc[0:3000]


def _select_studies(frames, study_ids):
    """Restrict every DataFrame in ``frames`` (series type -> frame) to ``study_ids``."""
    return {series_type: frame[frame.study_id.isin(study_ids)]
            for series_type, frame in frames.items()}


def _trainer_summary(trainer):
    """Return ``trainer.get_summary()`` if available, else basic run fields."""
    get_summary = getattr(trainer, 'get_summary', None)
    if callable(get_summary):
        return get_summary()
    return {'model_name': trainer.model_name, 'best_log_loss': trainer.best_ll}


def kfoldCV(k, trainer_config, model_config, model_cls=None):
    """Train and validate one model per fold and return their summaries.

    Args:
        k: number of folds over the study ids in ``tsd``. ``k == 1`` is a
            fixed split: train on ``train_unique_studies.npy`` and validate on
            the first 300 ids of ``test_unique_studies.npy``.
        trainer_config: configuration passed to ``SagittalTrainer``.
        model_config: keyword arguments for the model class.
        model_cls: model class to train; defaults to ``ClassModelTimm2dLSTM``.

    Returns:
        A list with one summary per fold.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if model_cls is None:
        model_cls = ClassModelTimm2dLSTM

    all_data = {
        'sagittal': pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t1.pkl"),
        'sagittal_t2': pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t2.pkl"),
        'axial': pd.read_pickle('/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_axial_unified.pkl'),
    }

    # Build explicit (train ids, validation ids) pairs so the roles cannot be swapped.
    if k == 1:
        with open('/workspaces/RSNA_LSDC/models_3d_final/train_unique_studies.npy', 'rb') as f:
            train_studies = np.load(f)
        with open('/workspaces/RSNA_LSDC/models_3d_final/test_unique_studies.npy', 'rb') as f:
            test_studies = np.load(f)
        splits = [(train_studies, test_studies[0:300])]
    else:
        unique_studies = np.random.permutation(np.array(tsd.study_id.unique()))
        folds = np.array_split(unique_studies, k)
        splits = [(np.concatenate(folds[:i] + folds[i + 1:], axis=0), folds[i]) for i in range(k)]

    model_summaries = []
    for i, (train_ids, val_ids) in enumerate(splits):
        print(f"Fold: {i}")
        train_data = _select_studies(all_data, train_ids)
        val_data = _select_studies(all_data, val_ids)

        trainer = SagittalTrainer(model_cls, model_config, trainer_config, train_data, val_data)
        trainer.train()
        model_summaries.append(_trainer_summary(trainer))

    return model_summaries
# convformer_s18

In [None]:
# k == 1: single split loaded from the saved train/test study-id .npy files (see kfoldCV).
kfoldCV(1, trainer_config, model_config)

In [None]:
# List the model names available in timm (candidates for 'backbone_name').
timm.list_models()

In [None]:
class SingleSeriesClassificator(nn.Module):
    # NOTE(review): the original comment below says the input holds all 3 series types,
    # but ClassModelMednetSplitSeries passes one series per call (x[:, k].unsqueeze(1))
    # and the extractor is built with 'in_c': 1.
    # takes whole level data combined from 3d slices from every series type avaliable in form [series_types = 3, d, h, w]
    """Single-series 3D classifier on a MedicalNet ResNet feature extractor.

    forward(x) returns ``(reduced_feature_map, logits)`` where logits are
    reshaped to (batch, 3, out_preds).

    NOTE(review): ``backbone_name`` is stored but the extractor is always
    'resnet_18_23dataset'. ``self.pred`` expects the reduced map to be
    out_channels x 4 x 4 x 4, which only holds for particular input sizes —
    confirm against ``series_dim``.
    """
    def __init__(self, backbone_name:str, series_dim:list[int], out_preds, out_channels=32, mid_channels=128):
        super().__init__()
       
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.series_dim = series_dim
        self.backbone_name = backbone_name
        # 3 classes per predicted condition.
        self.out_preds= out_preds*3

        # Shared MedicalNet options: input size (W, H, D from series_dim) and 1 input channel.
        opts = {
        'model': 'resnet',
        'input_W': series_dim[2],
        'input_H': series_dim[1],
        'input_D': series_dim[0],
        'device': self.device,
        'phase': 'train',
        'in_c' : 1,
        }

        # Depth / shortcut type per MedicalNet pretrained checkpoint name.
        model_pretrained_params = {
            'resnet_10': {'model_depth': 10, 'resnet_shortcut': 'B'},
            'resnet_10_23dataset': {'model_depth': 10, 'resnet_shortcut': 'B'},
            'resnet_18': {'model_depth': 18, 'resnet_shortcut': 'A'},
            'resnet_18_23dataset': {'model_depth': 18, 'resnet_shortcut': 'A'},
            'resnet_34': {'model_depth': 34, 'resnet_shortcut': 'A'},
            'resnet_34_23dataset': {'model_depth': 34, 'resnet_shortcut': 'A'}
        }

        # Merge the shared options into every entry and wrap it in a Struct.
        for model_name, model_dict in model_pretrained_params.items():
            model_pretrained_params[model_name] = Struct({**model_dict, **opts})

        self.feature_extractor = MedNet('resnet_18_23dataset', model_pretrained_params, 1).to(self.device)
        self.feature_extractor.init_FE(self.device)

        # Hard-coded channel counts / strides of the extractor's stages (not queried from it).
        self.all_channels = [64, 128, 256, 512]#self.feature_extractor.feature_info.channels()
        self.reduction = [4, 8, 8, 8] #[4,8,16,32] # self.feature_extractor.feature_info.reduction()
        self.mid_channels = mid_channels
        # NOTE(review): singular attribute name, while the argument is out_channels.
        self.out_channel = out_channels
        # Two unpadded 3x3x3 convs: each spatial dim of the last feature map shrinks by 4 in total.
        self.reduce = nn.Sequential(nn.Conv3d(
                                        512,
                                        mid_channels,
                                        kernel_size=(3,3,3),
                                        stride=(1, 1, 1),
                                        padding=0,#(1, 1, 1),
                                        ),
                                        nn.BatchNorm3d(mid_channels),
                                        nn.ReLU(inplace=True),
                                        nn.Conv3d(
                                        mid_channels,
                                        out_channels,
                                        kernel_size=(3,3,3),
                                        stride=(1, 1, 1),
                                        padding=0, #(1, 1, 1),
                                        bias=False), 
                                        nn.BatchNorm3d(out_channels),
                                        nn.ReLU(inplace=True),
                                        )
        
        # Debug print; note this is not the input size self.pred uses (out_channels*4*4*4).
        print(int((series_dim[1]/self.reduction[-1])*(series_dim[2]/self.reduction[-1])*(series_dim[0]/self.reduction[-1])))
        self.pred = nn.Sequential(
            nn.Linear(out_channels*4*4*4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, self.out_preds),
        )                                 
        
        
    def forward(self, x):
        # Returns the reduced last feature map (for fusion) and per-series logits (batch, 3, -1).
        b, s, d, h, w = x.shape
        x = self.feature_extractor(x)
        x = self.reduce(x[-1])
        return x, self.pred(x.reshape(b,-1)).reshape(b,3,-1)

class ClassModelMednetSplitSeries(nn.Module):
    """Three per-series MedicalNet branches fused into a joint classifier.

    ``x[:, 0]``, ``x[:, 1]`` and ``x[:, 2]`` are each fed to their own
    ``SingleSeriesClassificator`` (attribute names suggest sagittal T1,
    sagittal T2 and axial T2 — confirm against the dataset). The branches'
    reduced feature maps are concatenated on the channel axis and passed
    through ``expander`` -> ``reducer`` -> ``pred`` to give logits of shape
    (batch, 3, 5). ``forward`` also returns each branch's own logits.
    """
    # separate feature extractor for each series
    def __init__(self, backbone_name:str, series_dim:list[int], pretrained:bool=True):
        super().__init__()

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.series_dim = series_dim
        self.backbone_name = backbone_name

        self.feature_extractor_st1 = SingleSeriesClassificator(backbone_name, series_dim, 5, 32, 128)
        self.feature_extractor_st2 = SingleSeriesClassificator(backbone_name, series_dim, 5, 32, 128)
        self.feature_extractor_ax2 = SingleSeriesClassificator(backbone_name, series_dim, 5, 32, 128)
        self.all_c_dim = 3*32

        self.expander = nn.Sequential(nn.Conv3d(
                                        self.all_c_dim,
                                        512,
                                        kernel_size=(3,3,3),
                                        stride=(1, 1, 1),
                                        padding=(1, 1, 1),
                                        ),
                                        nn.BatchNorm3d(512),
                                        nn.ReLU(inplace=True),
                                        nn.Conv3d(
                                        512,
                                        512,
                                        kernel_size=(1,1,1),
                                        stride=(1, 1, 1),
                                        padding=0,
                                        bias=False),
                                        nn.BatchNorm3d(512),
                                        nn.ReLU(inplace=True),
                                        )
        self.reducer = nn.Sequential(nn.Conv3d(
                                        512,
                                        128,
                                        kernel_size=(3,3,3),
                                        stride=(1, 1, 1),
                                        padding=(1, 1, 1),
                                        ),
                                        nn.BatchNorm3d(128),
                                        nn.ReLU(inplace=True),
                                        nn.Conv3d(
                                        128,
                                        32,
                                        kernel_size=(1,1,1),
                                        stride=(1, 1, 1),
                                        padding=0,
                                        bias=False),
                                        nn.BatchNorm3d(32),
                                        nn.ReLU(inplace=True),
                                        )

        self.pred = nn.Sequential(
            nn.Linear(32*4*4*4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, 15),
        )

    def forward(self, x):
        """Return (fused logits (b, 3, 5), (st1 logits, st2 logits, ax2 logits))."""
        b, s, d, h, w = x.shape
        st1, pred_st1 = self.feature_extractor_st1(x[:,0,...].unsqueeze(1))
        st2, pred_st2 = self.feature_extractor_st2(x[:,1,...].unsqueeze(1))
        ax2, pred_ax2 = self.feature_extractor_ax2(x[:,2, ...].unsqueeze(1))
        x = torch.cat([st1, st2, ax2], dim=1)
        x = self.reducer(self.expander(x))
        return self.pred(x.reshape(b,-1)).reshape(b,3,5), (pred_st1, pred_st2, pred_ax2)

    @staticmethod
    def _branch_weights(series, labels, masks, downweight_cols):
        """Per-(sample, condition) loss weights for one branch.

        Weights are 2 ** label (1, 2, 4 for y = 0, 1, 2), with the conditions in
        ``downweight_cols`` scaled by 0.2. Everything is 0 when the branch's
        whole input batch is empty, and unannotated entries (mask false) are 0.
        """
        if series.nonzero().size(0) == 0:
            w = torch.zeros_like(labels, dtype=torch.float32)
        else:
            # Float first: scaling an integer tensor by 0.2 would truncate the weights to 0.
            w = (2 ** labels).float()
            w[:, downweight_cols] *= 0.2
        w[torch.logical_not(masks)] = 0.
        return w

    def get_loss(self, input, labels, masks):
        """Sum of the fused loss and the three branch losses (each a weighted-CE mean).

        Returns:
            ``(loss, metrics)`` where metrics hold detached scalar means
            (``sum_loss``, ``all_loss``, ``st1_loss``, ``st2_loss``, ``ax2_loss``).
        """
        a_preds, d_preds = self.forward(input)

        # Branch losses: each branch de-emphasises the condition columns listed
        # (presumably those its series type is less informative for — confirm).
        w = self._branch_weights(input[:, 0, ...], labels, masks, [1, 3, 4])
        loss_st1 = F.cross_entropy(d_preds[0], labels, reduction='none', label_smoothing=0.) * w
        w = self._branch_weights(input[:, 2, ...], labels, masks, [0, 2, 4])
        loss_ax2 = F.cross_entropy(d_preds[2], labels, reduction='none', label_smoothing=0.) * w
        w = self._branch_weights(input[:, 1, ...], labels, masks, [0, 1, 2, 3])
        loss_st2 = F.cross_entropy(d_preds[1], labels, reduction='none', label_smoothing=0.) * w

        # Fused prediction: plain 2 ** label weighting, unannotated entries ignored.
        w = (2 ** labels).float()
        w[torch.logical_not(masks)] = 0.
        aloss = F.cross_entropy(a_preds, labels, reduction='none', label_smoothing=0.) * w

        loss = aloss.mean() + loss_st1.mean() + loss_ax2.mean() + loss_st2.mean()
        # Scalar metrics, so np.mean over batches of different size works in the trainer.
        metric_d = dict(sum_loss=loss.detach(),
                        all_loss=aloss.mean().detach(),
                        st1_loss=loss_st1.mean().detach(),
                        st2_loss=loss_st2.mean().detach(),
                        ax2_loss=loss_ax2.mean().detach())
        return loss, metric_d

    def predict(self, input):
        """Probabilities of the fused head as (batch, 5, 3)."""
        x, _ = self.forward(input)
        return x.permute(0,2,1).softmax(-1)

In [None]:
# TransFormer
class StarReLU(nn.Module):
    """
    StarReLU: s * relu(x) ** 2 + b

    ``scale`` (s) and ``bias`` (b) are single-element parameters initialised
    from ``scale_value`` / ``bias_value``; ``*_learnable`` controls whether they
    receive gradients. ``mode`` is accepted but unused.
    """
    def __init__(self, scale_value=1.0, bias_value=0.0,
        scale_learnable=True, bias_learnable=True,
        mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        scale_init = torch.ones(1) * scale_value
        bias_init = torch.ones(1) * bias_value
        self.scale = nn.Parameter(scale_init, requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_init, requires_grad=bias_learnable)

    def forward(self, x):
        squared = self.relu(x) ** 2
        return squared * self.scale + self.bias
    
class MlpHead(nn.Module):
    """Two-layer MLP classification head.

    Computes ``fc2(dropout(norm(act(fc1(x)))))`` over the last axis of ``x``.

    Args:
        dim: size of the input features (last axis).
        num_classes: number of output logits.
        mlp_ratio: hidden width as a multiple of ``dim`` (truncated to int).
        act_layer: activation class, instantiated with no arguments.
        norm_layer: normalization class, instantiated with the hidden width.
        head_dropout: dropout probability applied just before ``fc2``.
        bias: whether both linear layers carry a bias term.
    """

    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=nn.ReLU,
                 norm_layer=nn.LayerNorm, head_dropout=0., bias=True):
        super().__init__()
        width = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, width, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(width)
        self.fc2 = nn.Linear(width, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        # Applied in this exact order; note dropout comes after the norm.
        for layer in (self.fc1, self.act, self.norm, self.head_dropout, self.fc2):
            x = layer(x)
        return x
    
class Mlp(nn.Module):
    """ MLP as used in MetaFormer models, e.g. Transformer, MLP-Mixer, PoolFormer, MetaFormer baselines and related networks.
    Mostly copied from timm.

    Computes ``drop2(fc2(drop1(act(fc1(x)))))`` over the last axis of ``x``.

    Args:
        dim: input feature size (last axis).
        mlp_ratio: hidden width as a multiple of ``dim`` (truncated to int).
        out_features: output size; defaults to ``dim``.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability. A scalar is used for both dropout layers;
            a 2-element tuple/list sets (after-activation, after-fc2) separately.
        bias: whether both linear layers carry a bias term.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``drop`` is a tuple/list whose length is not 2.
    """
    def __init__(self, dim, mlp_ratio=2, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        # Previously hard-coded to (0., 0.), which silently ignored ``drop``.
        # Expand it like timm's ``to_2tuple(drop)`` so the argument takes effect.
        if isinstance(drop, (tuple, list)):
            drop_probs = tuple(drop)
            if len(drop_probs) != 2:
                raise ValueError(f"drop must be a float or a pair of floats, got {drop!r}")
        else:
            drop_probs = (drop, drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
    
class Attention(nn.Module):
    """
    Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762.
    Modified from timm.

    Operates on channels-last inputs of shape ``(B, *spatial, dim)``. All
    spatial positions are flattened into one token sequence, attended over
    jointly, and then restored to the input layout. Any number of spatial axes
    is supported, e.g. ``(B, N, dim)``, ``(B, H, W, dim)``, ``(B, D, H, W, dim)``.

    Args:
        dim: channel size of the input and output (last axis).
        head_dim: size of each attention head.
        num_heads: number of heads; if falsy, ``dim // head_dim`` (at least 1).
        qkv_bias: bias on the fused q/k/v projection.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.
        proj_bias: bias on the output projection.
        **kwargs: accepted and ignored (e.g. a ``drop`` argument).
    """
    def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False,
        attn_drop=0., proj_drop=0., proj_bias=False, **kwargs):
        super().__init__()

        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.num_heads = num_heads if num_heads else dim // head_dim
        if self.num_heads == 0:
            self.num_heads = 1

        self.attention_dim = self.num_heads * self.head_dim

        self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return self-attention over all spatial positions; output shape equals ``x.shape``."""
        B = x.shape[0]
        spatial = x.shape[1:-1]
        # Token count across all spatial axes (1 when x is just (B, dim)).
        N = spatial.numel()
        # (B, *spatial, 3*A) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale   # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, *spatial, A)
        x = (attn @ v).transpose(1, 2).reshape(B, *spatial, self.attention_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
class MetaFormerBlock(nn.Module):
    """
    Implementation of one MetaFormer block (pre-norm, with residual connections):

        x = x + token_mixer(norm1(x))
        x = x + mlp(norm2(x))

    Args:
        dim: channel size of the last axis of ``x``; the output has the same shape.
        drop: forwarded as ``drop=`` to both the ``Attention`` token mixer and
            the ``Mlp``; whether each sub-layer actually applies it is up to
            that class (``Attention`` receives it through ``**kwargs``).
    """
    def __init__(self, dim, drop=0.):

        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.token_mixer = Attention(dim=dim, drop=drop)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(dim=dim, drop=drop)

    def forward(self, x):
        # Skip connections around both sub-layers. Without them, this pre-norm
        # block overwrites its input with each sub-layer's output, which is not
        # a MetaFormer block and makes stacked blocks hard to train.
        x = x + self.token_mixer(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x