Here we create a model for level boundary estimation:  
- estimating the level instances for each series  
- estimating the level box, once we have found the best-performing boxes  

In [1]:
import os
import copy
import torch
import random
from torch import nn
import timm
from typing import List, Union, Dict

from sklearn.metrics import log_loss
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from detection_train.Box_model_mednet import BoxModel, Nms3dAxial, Nms3dSagittalForamina, Nms3d

from PIL import Image
import pandas as pd
from collections import defaultdict
import numpy as np
from tqdm.notebook import tqdm

import torchvision.transforms.v2 as v2
import torch.nn.functional as F

import matplotlib.pyplot as plt
import matplotlib.patches as patches

import matplotlib.animation as animation
from IPython.display import HTML

In [None]:
# Sanity check: display whether torch can see a CUDA device.
torch.cuda.is_available()

Post-processing of the detected bboxes:
Make use of the following properties that are guaranteed to hold for the bboxes in an image:
-> 5 classes with one bbox each, or 1 class with 5 bboxes
-> If levels n-1 and n+1 exist in the image, level n must also exist
-> The boxes of levels n-1, n, n+1 must be aligned next to one another in the height dimension
-> The boxes must overlap in the x dimension to some extent
-> If boxes exist for a run of levels n, n+1, n+2 and the image height is larger than the mean level height, then level n+3 must also exist; the same holds in the reverse direction


# DATASET

In [3]:
class Dataset(torch.utils.data.Dataset):
    """Base dataset: indexes per-study series metadata and loads image stacks from disk.

    ``self.data`` holds one dict per study mapping 'sagittal' / 'sagittal_t2' /
    'axial' to a list of per-series dicts built by ``info2dict``. This base
    class returns None from ``__getitem__``; subclasses build the actual items.
    """

    def __init__(self, data_info: Dict[str, pd.DataFrame], config: Dict):
        # data_info: series type -> dataframe with that type's rows (entries may be None)
        # config: reads 'load_series', 'return_series_type', 'dataset_path', 'image_type'
        self.used_series_types = config['load_series']
        self.study_ids = np.unique(np.concatenate(
            [series.study_id.unique() for series in data_info.values() if series is not None]))
        self.data_info = data_info
        self.return_series_type = config['return_series_type']

        self._condition_list = ['Left Neural Foraminal Narrowing',
                                'Left Subarticular Stenosis',
                                'Right Neural Foraminal Narrowing',
                                'Right Subarticular Stenosis',
                                'Spinal Canal Stenosis']

        self._status_map = {'Normal/Mild': [1., 0., 0.],
                            'Moderate': [0., 1., 0.],
                            'Severe': [0., 0., 1.]}

        self.level_ind = {'L1/L2': 0, 'L2/L3': 1, 'L3/L4': 2, 'L4/L5': 3, 'L5/S1': 4}

        self.dataset_path = config['dataset_path']
        self.image_type = config['image_type']
        self.data = []
        self.prepare_data()

    def get_condition_labels(self, series_info: pd.DataFrame):
        """Collect per-level labels from the rows of one series.

        Returns:
            labels: ``status`` of the first row of every present level, stacked.
            cond_presence_masks: ``presence_mask`` of the same rows, stacked.
            level_presence_mask: bool array with one entry per level, in
                ``self.level_ind`` order.
        """
        labels = []
        cond_presence_masks = []
        level_presence_mask = []
        for level in self.level_ind:
            # Filter once per level instead of once per field.
            rows = series_info[series_info['level'] == level]
            present = not rows.empty
            level_presence_mask.append(present)
            if present:
                first = rows.iloc[0]
                labels.append(first.status)
                cond_presence_masks.append(first.presence_mask)

        return np.array(labels), np.array(cond_presence_masks), np.array(level_presence_mask)

    def info2dict(self, series_info, stype=None):
        """Summarize one series' rows into a dict of metadata, file paths and labels.

        Series-level fields (ids, size, IPP, instances) come from the first row.
        Note: axial series may be combinations of different series (sagittal
        series cannot).
        """
        level0 = series_info.iloc[0]
        labels, masks, level_presence = self.get_condition_labels(series_info)

        return {
            'study_id': level0.study_id,
            'series_id': level0.series_id,
            'width': level0.image_width,
            'height': level0.image_height,
            'reversed': level0.reversed,
            'series_type': stype,
            'files': [f"{self.dataset_path}/{level0.study_id}/{level0.series_id}/{instance}.{self.image_type}"
                      for instance in level0.present_instances],
            'labels': labels,
            'masks': masks,
            # Indices (in level_ind order) of the levels present in this series.
            'level_presence': np.flatnonzero(level_presence),
            'IPP': level0.IPP,
        }

    def prepare_data(self):
        """Fill ``self.data`` with one dict per study (series type -> list of series dicts).

        Studies that have no series of the requested types are skipped.
        Series types whose dataframe is None are ignored.
        """
        with tqdm(total=len(self.study_ids), desc="Preparing data: ") as pbar:
            for study_id in self.study_ids:
                study_dict = dict(
                    sagittal=[],
                    sagittal_t2=[],
                    axial=[])
                for stype in self.used_series_types:
                    frame = self.data_info[stype]
                    if frame is None:
                        continue
                    study_rows = frame[frame.study_id == study_id]
                    for series_id in study_rows.series_id.unique():
                        series_rows = study_rows[study_rows.series_id == series_id]
                        study_dict[stype].append(self.info2dict(series_rows, stype=stype))

                # Advance for every study, including skipped ones, so the bar reaches its total.
                pbar.update(1)
                if any(study_dict.values()):
                    self.data.append(study_dict)

    def load_series(self, data) -> np.ndarray:
        """Read ``data['files']`` into a (n_files, height, width) uint8 array.

        Images whose shape does not match (height, width) are pasted into the
        top-left corner (cropped if larger). For 'sagittal' series the slices
        are reordered by descending first IPP component.
        """
        height, width = data['height'], data['width']
        oimg = np.zeros((len(data['files']), height, width), dtype=np.uint8)
        for i, path in enumerate(data['files']):
            # Read each file once and close its handle immediately.
            with Image.open(path) as im:
                frame = np.array(im, dtype=np.uint8)
            try:
                oimg[i] = frame
            except ValueError:
                h = min(frame.shape[0], height)
                w = min(frame.shape[1], width)
                oimg[i, :h, :w] = frame[:h, :w]

        if data['series_type'] == 'sagittal':
            key = np.argsort([d[0] for d in data['IPP']])[::-1]
            oimg = oimg[key]

        return oimg.copy()

    def prepare_series(self, data):
        """Scatter per-series labels into fixed (5 levels, 5 conditions[, 3]) tensors.

        NOTE(review): reads data['box_labels'] and data['label_presence_mask'],
        which ``info2dict`` does not produce — confirm which caller supplies them.
        """
        level_labels = torch.tensor(data['box_labels'], dtype=torch.int64)
        cond_labels = torch.zeros((5, 5, 3), dtype=torch.int64)
        cond_masks = torch.zeros((5, 5), dtype=torch.int64)
        cond_labels[level_labels, ...] = torch.tensor(data['labels'], dtype=torch.int64)
        cond_masks[level_labels, ...] = torch.tensor(data['label_presence_mask'], dtype=torch.int64)

        return cond_labels, cond_masks

    def __getitem__(self, index: int):
        return None

    def __len__(self):
        return len(self.data)

In [4]:
class BoxDatasetUnited(Dataset):
    """Study-level dataset returning every loaded series plus merged condition labels.

    Each item is one study: a dict of float image stacks per series type, and
    (5 levels x 5 conditions) severity labels and annotation masks accumulated
    over all of the study's series.
    """

    def __init__(self, data_info: Dict[str, pd.DataFrame], config: Dict):
        super().__init__(data_info, config)

    def flip_sides(self, volume, labels, masks):
        """Mirror ``volume`` along dim 2 and swap left/right condition columns.

        Assumes label/mask columns follow the base class's condition order
        (L foraminal, L subarticular, R foraminal, R subarticular, canal), so
        columns 0<->2 and 1<->3 are exchanged and column 4 is kept.
        """
        lr_swap = [2, 3, 0, 1, 4]
        return volume.flip(2), labels[:, lr_swap], masks[:, lr_swap]

    def __getitem__(self, index: int) -> tuple[dict, torch.Tensor, torch.Tensor]:
        """Load all series of study ``index``.

        Returns:
            oimgs: {'study_id', 'sagittal', 'sagittal_t2', 'axial'}; the series
                keys map to lists of float tensors, one per loaded series.
            cond_labels: int64 (5, 5) severity index per level and condition.
            cond_masks: int64 (5, 5) with 1 where any series annotates the entry.
        """
        adata = self.data[index]

        oimgs = {
            'study_id': 0,
            'sagittal': [],
            'sagittal_t2': [],
            'axial': []
        }
        cond_labels = torch.zeros(5, 5, 3)
        cond_masks = torch.zeros(5, 5)
        for key, data_list in adata.items():
            for data in data_list:
                oimgs['study_id'] = data['study_id']
                try:
                    oimg = self.load_series(data)
                except ZeroDivisionError:
                    # One series was skipped during data preparation and raises here; skip it.
                    continue

                oimgs[key].append(torch.tensor(oimg, dtype=torch.float))
                cond_labels[data['level_presence']] += torch.tensor(data['labels'])
                cond_masks[data['level_presence']] += torch.tensor(data['masks'])

        # Cast accumulated one-hot scores to int and take the argmax; all-zero rows map to 0.
        cond_labels = cond_labels.to(dtype=torch.int64).argmax(-1)
        cond_masks = cond_masks.clamp(min=0, max=1).to(dtype=torch.int64)

        return oimgs, cond_labels, cond_masks

    def get_random_by_stype(self):
        """Return a uniformly random item."""
        return self[random.randint(0, len(self) - 1)]

In [None]:
# Dataset configuration; these keys are read by Dataset.__init__.
train_dataset_config = {
    'dataset_path': "/workspaces/RSNA_LSDC/inputs/dataset_vl",
    # Series types to load; any subset of ['sagittal', 'axial', 'sagittal_t2'].
    'load_series': ['sagittal', 'sagittal_t2', 'axial'],
    # Stored by Dataset; the visible __getitem__ implementations do not use it.
    'return_series_type': False,
    'image_type': 'png',
}

# Only the first 100 series-description rows are used for this experiment.
tsd = pd.read_csv('/workspaces/RSNA_LSDC/inputs/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv').iloc[0:100]
# Debug filter for a single study:
# tsd = tsd[tsd['study_id'] == 4003253]

data_sagittal = pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t1.pkl")
data_sagittal_t2 = pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t2.pkl")
data_axial = pd.read_pickle('/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_axial_unified.pkl')
train_ids = tsd.study_id.unique()

# Keep only the rows of the selected studies for every series type.
train_data = {
    series_type: frame[frame.study_id.isin(train_ids)]
    for series_type, frame in (
        ('sagittal', data_sagittal),
        ('sagittal_t2', data_sagittal_t2),
        ('axial', data_axial),
    )
}

bb = BoxDatasetUnited(train_data, train_dataset_config)

In [6]:
# Load train.csv from the RSNA 2024 lumbar-spine competition inputs.
ans = pd.read_csv("/workspaces/RSNA_LSDC/inputs/rsna-2024-lumbar-spine-degenerative-classification/train.csv")

In [None]:
# Index -> matplotlib colour code.
color_dict = {0: 'r', 1: 'g', 2: 'b', 3: 'm', 4: 'y'}
# Inspect one study from the dataset built above and display its labels.
oimgs, labels, masks = bb[7]
labels

# MODEL TRAINER

In [8]:
class SagittalTrainer():
    """Training/evaluation loop for a condition model on ``BoxDatasetUnited`` studies.

    ``model`` is a zero-argument callable that builds the network. From the calls
    below, the network must provide ``get_loss(...)`` returning
    ``(loss, canal_info, subarticular_info, foraminal_info)``, ``predict(oimgs)``
    returning an object with ``get_condition_as_tensor()``, and a
    ``canale_estimator`` submodule (the only part ``save_model`` writes).
    """

    def __init__(self, model, config, train_data, eval_data) -> None:
        self.print_evaluation = config.get("print_evaluation", False)
        self.steps_per_plot = config["steps_per_plot"]

        self.checkpoints = config['checkpoints']
        self.save_path = config['save_path']
        self.step_per_save = config['step_per_save']

        self.config = config

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Device set to {self.device}")

        self.model_name = config['model_name']
        self.model = model().to(self.device)

        self.optimizer = config["optimizer"](self.model.parameters(), **config["optimizer_params"])

        loader_kwargs = dict(batch_size=config["batch_size"], num_workers=12, prefetch_factor=1)
        self.dataloaders = {
            'train': torch.utils.data.DataLoader(
                BoxDatasetUnited(train_data, config['train_dataset_config']),
                shuffle=True, **loader_kwargs),
            'val': torch.utils.data.DataLoader(
                BoxDatasetUnited(eval_data, config['val_dataset_config']),
                shuffle=False, **loader_kwargs),
        }

        self.max_epochs = config["epochs"]
        self.early_stopping = config['early_stopping']
        self.early_stopping_tresh = config['early_stopping_treshold']

        # Scheduler: OneCycleLR is stepped per batch, any other scheduler per epoch.
        self.scheduler = None
        self.one_cycle_sched = False
        if config["scheduler"]:
            # Work on a copy so the caller's config dict is not mutated.
            scheduler_params = dict(config['scheduler_params'])
            if 'epochs' in scheduler_params:
                scheduler_params['epochs'] = self.max_epochs
            if 'steps_per_epoch' in scheduler_params:
                scheduler_params['steps_per_epoch'] = len(self.dataloaders['train'])
            self.scheduler = config["scheduler"](self.optimizer, **scheduler_params)
            self.one_cycle_sched = self.scheduler.__class__.__name__ == 'OneCycleLR'

        # A checkpoint is written only when the validation weighted log loss drops
        # below this value (0.1 unless overridden via 'initial_best_ll').
        self.best_ll = config.get('initial_best_ll', 0.1)

    def save_model(self):
        """Save the ``canale_estimator`` weights to ``<save_path>/<model_name>_best.pt``."""
        # NOTE(review): load_model loads into the whole model, not canale_estimator,
        # so a file written here probably cannot be restored via load_model — confirm.
        torch.save(self.model.canale_estimator.state_dict(),
                   os.path.join(self.save_path, f"{self.model_name}_best.pt"))

    def load_model(self, model_path):
        """Load a full-model state dict from ``model_path``."""
        self.model.load_state_dict(torch.load(model_path))

    def train(self):
        """Run ``max_epochs`` epochs, each a training pass followed by validation."""
        for epoch in range(self.max_epochs):
            print(f"Epoch {epoch+1}/{self.max_epochs}")
            self.train_one_epoch()
            self.eval_one_epoch()
            # TODO: example plots / periodic checkpoints (checkpoint and early-stopping
            # settings are read in __init__ but not used yet).

    def train_one_epoch(self):
        """One optimisation pass over the training loader."""
        self.model.train()
        metrics = defaultdict(list)

        train_loader = self.dataloaders['train']
        with tqdm(train_loader, unit="batch", total=len(train_loader)) as tepoch:
            for oimgs, labels, masks in tepoch:
                with torch.set_grad_enabled(True):
                    # NOTE(review): reshape(5, -1) folds the batch into the level
                    # axis; this only lines up for batch_size == 1 — confirm.
                    labels = labels.reshape(5, -1).to(self.device)
                    masks = masks.reshape(5, -1).to(self.device)
                    # Columns: 4 = canal, [1, 3] = subarticular, [0, 2] = foraminal.
                    loss, loss_info_canale, loss_info_sub, loss_info_foramina = self.model.get_loss(
                        oimgs,
                        labels[:, 4].unsqueeze(-1),
                        labels[:, [1, 3]].reshape(-1, 1),
                        labels[:, [0, 2]].reshape(-1, 1),
                        masks[:, 4],
                        masks[:, [1, 3]].reshape(-1),
                        masks[:, [0, 2]].reshape(-1))
                    for suffix, info in (('_subart', loss_info_sub),
                                         ('_canale', loss_info_canale),
                                         ('_foramina', loss_info_foramina)):
                        for loss_t, loss_v in info.items():
                            metrics[loss_t + suffix].append(loss_v.clone().detach().cpu().numpy())

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    if self.scheduler and self.one_cycle_sched:
                        self.scheduler.step()

                tepoch.set_description(self.metrics_description(metrics, 'train'))

        if self.scheduler and not self.one_cycle_sched:
            self.scheduler.step()

    def eval_one_epoch(self):
        """Validate, plot per-condition confusion matrices and checkpoint on improvement."""
        self.model.eval()
        alabels = []
        apreds = []
        aweight = []

        val_loader = self.dataloaders['val']
        with tqdm(val_loader, unit="batch", total=len(val_loader)) as tepoch:
            for oimgs, labels, masks in tepoch:
                with torch.no_grad():
                    preds = self.model.predict(oimgs).get_condition_as_tensor()
                apreds.append(preds.reshape(-1, 3))

                labels = labels.reshape(-1)
                # Sample weights 1, 2, 4 for labels 0, 1, 2; unannotated entries get 0.
                weights = 2 ** labels
                weights[torch.logical_not(masks.reshape(-1))] = 0.
                alabels.append(labels)
                aweight.append(weights)

        alabels = torch.cat(alabels, dim=0).cpu().numpy()
        apreds = torch.cat(apreds, dim=0).cpu().numpy()
        aweight = torch.cat(aweight, dim=0).cpu().numpy()

        self._plot_confusion_matrices(alabels, apreds)

        # Passing labels keeps log_loss defined when a severity is absent from the split.
        ll = log_loss(alabels, apreds, normalize=True, sample_weight=aweight, labels=[0, 1, 2])
        if ll < self.best_ll:
            self.save_model()
            self.best_ll = ll

        print("Score:", ll)

    def _plot_confusion_matrices(self, alabels, apreds):
        """Plot one 3x3 severity confusion matrix per condition.

        Entries are interleaved by condition (stride = number of conditions),
        matching the flattening of (..., levels, conditions) labels.
        """
        conditions = ['Left Neural Foraminal Narrowing',
                      'Left Subarticular Stenosis',
                      'Right Neural Foraminal Narrowing',
                      'Right Subarticular Stenosis',
                      'Spinal Canal Stenosis']
        n = len(conditions)

        fig, ax = plt.subplots(nrows=1, ncols=n, figsize=(15, 5))
        axes = np.atleast_1d(ax).ravel()
        for i, name in enumerate(conditions):
            cl = alabels[i::n]
            cpred = apreds[i::n, :].argmax(-1)
            cm = confusion_matrix(cl, cpred, labels=[0, 1, 2])
            axes[i].set_title(name)
            ConfusionMatrixDisplay(confusion_matrix=cm).plot(ax=axes[i], colorbar=False)
        plt.show()

    def metrics_description(self, metrics: dict, phase: str) -> str:
        """Format ``phase: || name: mean || ...`` for the progress bar."""
        parts = [phase + ": ||"]
        parts.extend(" {}: {:4f} ||".format(k, np.mean(v)) for k, v in metrics.items())
        return "".join(parts)


In [9]:
class MlpHead(nn.Module):
    """Two-layer MLP classification head.

    Pipeline: Linear(dim -> hidden) -> activation -> norm -> dropout ->
    Linear(hidden -> num_classes), where hidden = int(mlp_ratio * dim).

    Args:
        dim: size of the input's last dimension.
        num_classes: number of output logits.
        mlp_ratio: hidden width as a multiple of ``dim``.
        act_layer: activation module class, instantiated without arguments.
        norm_layer: normalization module class, instantiated with the hidden width.
        head_dropout: dropout probability applied before the output layer.
        bias: whether both linear layers have a bias term.
    """

    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=nn.ReLU,
                 norm_layer=nn.LayerNorm, head_dropout=0., bias=True):
        super().__init__()
        width = int(mlp_ratio * dim)
        # Submodule names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(dim, width, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(width)
        self.fc2 = nn.Linear(width, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        hidden = self.norm(self.act(self.fc1(x)))
        return self.fc2(self.head_dropout(hidden))

In [10]:
class ClassModelTimm2d(nn.Module):
    """2D timm classifier over a stacked series volume.

    The series and depth axes of the input are folded into the channel axis,
    so the backbone sees one multi-channel 2D image per sample.
    """

    def __init__(self, backbone_name, series_dim, pretrained=True):
        super().__init__()
        # NOTE(review): in_chans is series_dim[0]; this must equal s*d of the
        # forward() input — confirm against callers.
        self.model = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            features_only=False,
            in_chans=series_dim[0],
            num_classes=3,
            global_pool='avg',
        )

    def forward(self, x):
        """Map x of shape (b, s, d, h, w) to logits of shape (b, 3, -1)."""
        b, s, d, h, w = x.shape
        # (b, s, d, h, w) -> (b, d, s, h, w) -> (b, d*s, h, w)
        x = x.permute(0, 2, 1, 3, 4).reshape(b, s * d, h, w)
        return self.model(x).reshape(b, 3, -1)

    def get_loss(self, input, labels, masks):
        """Weighted cross-entropy loss.

        Each entry is weighted 1, 2 or 4 for label 0, 1 or 2, and 0 where
        ``masks`` is falsy. The loss is the mean over all entries multiplied by
        the batch size.

        Returns:
            (loss, {'cross_entropy': detached copy of loss})
        """
        preds = self.forward(input)
        weights = (2.0 ** labels).masked_fill(torch.logical_not(masks), 0.)
        per_item = F.cross_entropy(preds, labels, reduction='none', label_smoothing=0.) * weights
        loss = per_item.mean() * preds.shape[0]
        return loss, dict(cross_entropy=loss.clone().detach())

    def predict(self, input):
        """Class probabilities of shape (b, -1, 3)."""
        return self.forward(input).permute(0, 2, 1).softmax(-1)

In [11]:
class Box3d(object):
    """Axis-aligned 3D box in the coordinates of a sagittal series.

    ``x``, ``y`` and ``z`` are (start, end) pairs. The per-view accessors
    return flat 6-vectors:
      sagittal: [x0, y0, z0, x1, y1, z1]
      coronal:  [z0, y0, x0, z1, y1, x1]
      axial:    [z0, x0, y0, z1, x1, y1]
    """

    def __init__(self, x, y, z) -> None:
        self.x = x
        self.y = y
        self.z = z

    def __repr__(self):
        return f"{type(self).__name__}(x={self.x}, y={self.y}, z={self.z})"

    @classmethod
    def bbox_from_view(cls, bbox, view):
        """Build a box from a flat 6-vector in the layout of ``view``.

        Inverse of ``get_box_in_view_type``. Returns None for unknown views.
        """
        if view in ['sagittal', 'sagittal_t2']:
            return cls((bbox[0], bbox[3]), (bbox[1], bbox[4]), (bbox[2], bbox[5]))
        elif view == 'axial':
            return cls((bbox[1], bbox[4]), (bbox[2], bbox[5]), (bbox[0], bbox[3]))
        elif view == 'coronal':
            return cls((bbox[2], bbox[5]), (bbox[1], bbox[4]), (bbox[0], bbox[3]))
        return None

    def get_box_in_view_type(self, view_type):
        """Flat 6-vector in the layout of ``view_type``; None for unknown views."""
        if view_type in ['sagittal', 'sagittal_t2']:
            return self.get_sagittal()
        elif view_type == 'coronal':
            return self.get_coronal()
        elif view_type == 'axial':
            return self.get_axial()
        return None

    def get_dim_in_view_type(self, view_type):
        """Extent (end - start) along each axis in the layout of ``view_type``.

        Raises:
            ValueError: if ``view_type`` is not a supported view.
        """
        box = self.get_box_in_view_type(view_type)
        if box is None:
            raise ValueError(f"Unknown view type: {view_type!r}")
        return np.array(box[3:]) - np.array(box[0:3])

    def get_sagittal(self):
        return [self.x[0], self.y[0], self.z[0], self.x[1], self.y[1], self.z[1]]

    def get_coronal(self):
        return [self.z[0], self.y[0], self.x[0], self.z[1], self.y[1], self.x[1]]

    def get_axial(self):
        return [self.z[0], self.x[0], self.y[0], self.z[1], self.x[1], self.y[1]]

    def _shifted(self, box):
        """(x, y, z) pairs translated by ``box``'s start corner (x0, y0, z0)."""
        dx, dy, dz = box.x[0], box.y[0], box.z[0]
        return ((self.x[0] + dx, self.x[1] + dx),
                (self.y[0] + dy, self.y[1] + dy),
                (self.z[0] + dz, self.z[1] + dz))

    def __add__(self, box):
        """Return a new box translated by ``box``'s start corner; ``self`` is unchanged."""
        return type(self)(*self._shifted(box))

    def __iadd__(self, box):
        """Translate in place by ``box``'s start corner (used by ``box += offset``)."""
        self.x, self.y, self.z = self._shifted(box)
        return self

class Condition(object):
    """A named spinal condition with an optional box and severity scores.

    ``status`` is a score vector over (Normal/Mild, Moderate, Severe);
    ``status_name`` is the name of its argmax, or "Unknown" when status is None.
    """

    def __init__(self, condition_name: str, box: "Box3d" = None, status=None):
        self.condition_name = condition_name
        self.box = box
        self.status_map = dict([(0, 'Normal/Mild'), (1, 'Moderate'), (2, 'Severe')])
        self.set_status(status)

    def get_box_in_view_type(self, view_type):
        """Flat 6-vector of this condition's box in the layout of ``view_type``."""
        return self.box.get_box_in_view_type(view_type)

    def set_status(self, status):
        """Store ``status`` and refresh ``status_name`` (reset to "Unknown" for None)."""
        self.status = status
        if status is None:
            self.status_name = "Unknown"
        else:
            self.status_name = self.status_map[int(np.argmax(status))]

class Level(object):
    """One spinal level (e.g. 'L4/L5') with its box, detection score and conditions.

    The level box is used in sagittal layout. Condition boxes given to
    ``add_condition`` are relative to the level crop and converted to absolute
    coordinates of the scan.
    """

    def __init__(self, level_id, bbox: "Box3d", score):
        self.level_names = dict([(0, 'L1/L2'), (1, 'L2/L3'), (2, 'L3/L4'), (3, 'L4/L5'), (4, 'L5/S1')])
        self.level_id = level_id
        self.level_name = self.level_names[level_id]
        self.box = bbox
        self.score = score
        # (x, y, z) extent of the level box, repeated to scale a normalized 6-vector.
        self.scale = np.concatenate([self.box.get_dim_in_view_type('sagittal')] * 2, 0)

        self.condition_list = {
            'Left Neural Foraminal Narrowing': 0,
            'Left Subarticular Stenosis': 1,
            'Right Neural Foraminal Narrowing': 2,
            'Right Subarticular Stenosis': 3,
            'Spinal Canal Stenosis': 4}

        self.conditions = []
        # severity index -> name
        self.status_map = dict([(0, 'Normal/Mild'), (1, 'Moderate'), (2, 'Severe')])

    @staticmethod
    def _normalize_box(box, boxed_im_size):
        """Divide a detector box by the crop size.

        ``boxed_im_size`` must be a 3-element tuple/list (it is repeated with
        ``* 2``); its entries are reordered [2, 1, 0] before dividing.
        """
        return box / np.array(boxed_im_size * 2, dtype=float)[[2, 1, 0, 5, 4, 3]]

    def get_box_in_view_type(self, view_type):
        """Flat 6-vector of the level box in the layout of ``view_type``."""
        return self.box.get_box_in_view_type(view_type)

    def set_status_to_condition(self, condition_name, status):
        """Set ``status`` on the first condition named ``condition_name`` (no-op if absent)."""
        for condition in self.conditions:
            if condition.condition_name == condition_name:
                condition.set_status(status)
                break

    def add_condition(self, condition, box=None, status=None, boxed_im_size=None, box_view=None):
        """Append a Condition; ``box`` (optional) is in ``box_view`` layout within the level crop."""
        if box is not None:
            relative = self._normalize_box(box, boxed_im_size)
            scaled = Box3d.bbox_from_view(relative, box_view).get_box_in_view_type('sagittal') * self.scale
            box = Box3d.bbox_from_view(scaled, 'sagittal')
            # Convert from level-relative to absolute scan coordinates.
            box += self.box
        self.conditions.append(Condition(condition, box, status=status))

    def add_conditions_from_boxes(self, boxes, map_names, boxed_im_size, box_view):
        """Add Left/Right conditions from detector output.

        ``boxes`` holds 'labels' and 'boxes' tensors; ``map_names`` maps a
        detector label to a condition name without side prefix. A single box
        of a label gets its side from its normalized position; with several
        boxes the lowest two along the side axis become Left and Right.
        """
        labels = boxes['labels'].detach().cpu().numpy()
        coords = boxes['boxes'].detach().cpu().numpy()
        # Coordinate whose position decides Left vs Right below.
        side_axis = 0 if box_view == 'axial' else 2

        for condition in np.unique(labels):
            cbox = coords[labels == condition]
            if len(cbox) == 1:
                relative = self._normalize_box(cbox[0], boxed_im_size)
                side = 'Left ' if relative[side_axis] < 0.5 else 'Right '
                self.add_condition(side + map_names[condition], box=cbox[0],
                                   boxed_im_size=boxed_im_size, box_view=box_view)
            else:
                # NOTE(review): boxes beyond the first two along side_axis are dropped.
                cbox = cbox[cbox[:, side_axis].argsort()]
                self.add_condition("Left " + map_names[condition], box=cbox[0],
                                   boxed_im_size=boxed_im_size, box_view=box_view)
                self.add_condition("Right " + map_names[condition], box=cbox[1],
                                   boxed_im_size=boxed_im_size, box_view=box_view)
            

class Series(object):
    """One image series, stored internally in sagittal orientation.

    Axial volumes are transposed into sagittal orientation on construction, so
    after ``__init__`` ``self.view`` is 'sagittal' for 'sagittal', 'sagittal_t2'
    and 'axial' inputs; ``series_type`` keeps the original type. Level objects
    are kept in ``self.levels`` keyed by level name (None until set).
    """

    # Random jitter for training crops, per axis in (depth, height, width) order:
    # ((start_low, start_high), (end_low, end_high)) passed to random.randint.
    _LEVEL_JITTER = (((-2, 1), (-1, 2)), ((-10, 10), (-10, 10)), ((-10, 10), (-10, 10)))
    _CONDITION_JITTER = (((-2, 0), (0, 2)), ((-5, 5), (-5, 5)), ((-5, 5), (-5, 5)))

    def __init__(self, volume, series_type, series_id=None):
        self.series_id = series_id
        self.series_type = series_type
        self.view = series_type
        self.volume = volume
        if series_type == 'axial':
            # Re-orient axial stacks to sagittal so every series shares one layout.
            self.volume = self.get_in_view('sagittal')
            self.view = 'sagittal'
        elif series_type == 'sagittal_t2':
            self.view = 'sagittal'

        # Volume extent as (w, h, n, w, h, n) for scaling normalized 6-vector boxes.
        self.scale = np.array(self.volume.shape * 2)[[2, 1, 0, 5, 4, 3]]
        self.level_names = dict([(0, 'L1/L2'), (1, 'L2/L3'), (2, 'L3/L4'), (3, 'L4/L5'), (4, 'L5/S1')])
        self.levels = {
            'L1/L2': None,
            'L2/L3': None,
            'L3/L4': None,
            'L4/L5': None,
            'L5/S1': None
        }
        self.color_dict = dict([('L1/L2', 'r'), ('L2/L3', 'darkorange'), ('L3/L4', 'gray'), ('L4/L5', 'm'), ('L5/S1', 'y')])

    def __getitem__(self, level: Union[str, int]):
        """Return the Level for a level name or an index 0-4 (None if not set)."""
        if isinstance(level, (int, np.integer)):
            level = self.level_names[level]
        return self.levels[level]

    def get_in_view(self, new_view=None):
        """Return the volume transposed into ``new_view`` (default: the current view).

        Returns None for view combinations not handled below.
        """
        if not new_view:
            new_view = self.view
        if self.view in ['sagittal', 'sagittal_t2']:
            if new_view in ['sagittal', 'sagittal_t2']:
                return self.volume
            elif new_view == 'coronal':
                return self.volume.transpose(2, 1, 0)  # (n, h, w) -> (w, h, n)
            elif new_view == 'axial':
                return self.volume.transpose(1, 2, 0)  # (n, h, w) -> (h, w, n)
        elif self.view == 'axial':
            if new_view in ['sagittal', 'sagittal_t2']:
                return self.volume.transpose(2, 0, 1)  # (n, h, w) -> (w, n, h)
            elif new_view == 'coronal':
                return self.volume.transpose(2, 1, 0)  # (n, h, w) -> (w, h, n)
            elif new_view == 'axial':
                return self.volume
        return None

    @staticmethod
    def _crop(volume, box, jitter=None):
        """Cut ``volume[d0:d1, h0:h1, w0:w1]`` for a view box ``[w0, h0, d0, w1, h1, d1]``.

        Without ``jitter`` the box end is inclusive and both bounds are clamped
        to the volume. With ``jitter`` (see ``_LEVEL_JITTER``) each bound is
        shifted by ``random.randint``, the start is clamped at 0, and the end is
        forced to be at least start + 1.
        """
        box = np.asarray(box).astype(int)
        bounds = ((box[2], box[5]), (box[1], box[4]), (box[0], box[3]))
        slices = []
        for axis, (start, end) in enumerate(bounds):
            size = volume.shape[axis]
            if jitter is None:
                lo = max(0, start)
                hi = min(size, max(0, end + 1))
            else:
                (s_lo, s_hi), (e_lo, e_hi) = jitter[axis]
                lo = max(0, start + random.randint(s_lo, s_hi))
                hi = max(min(size, end + random.randint(e_lo, e_hi)), lo + 1)
            slices.append(slice(lo, hi))
        return volume[tuple(slices)]

    def get_levels_in_view(self, view, training=False):
        """Crop every level from the volume in ``view``.

        Returns (level names, crops); a crop is None for levels not set.
        ``training`` applies random jitter to the crop bounds.
        """
        volume = self.get_in_view(view)
        jitter = self._LEVEL_JITTER if training else None
        level_ind_list = []
        level_list = []
        for name, level in self.levels.items():
            level_ind_list.append(name)
            if level is None:
                level_list.append(None)
            else:
                level_list.append(self._crop(volume, level.get_box_in_view_type(view), jitter))
        return level_ind_list, level_list

    def get_conditions_in_view(self, conditions_to_get, view, training=False):
        """Crop each requested condition of every level from the volume in ``view``.

        Returns ((level name, condition name) keys, crops); a crop is None when
        the level is unset, the condition is missing, or it has no box.
        """
        volume = self.get_in_view(view)
        jitter = self._CONDITION_JITTER if training else None
        cond_ind_list = []
        cond_list = []
        for name, level in self.levels.items():
            for condition in conditions_to_get:
                cond_ind_list.append((name, condition))
                cd = None
                if level is not None:
                    cd = next((c for c in level.conditions if c.condition_name == condition), None)
                if cd is None or cd.box is None:
                    cond_list.append(None)
                else:
                    cond_list.append(self._crop(volume, cd.get_box_in_view_type(view), jitter))
        return cond_ind_list, cond_list

    def set_levels_from_boxes(self, boxes, box_view, boxed_im_size):
        """Create Level objects from detector output.

        ``boxes`` holds 'boxes', 'labels' (level index 0-4) and 'scores' tensors;
        boxes are in ``box_view`` layout and are scaled to this volume.
        """
        coords = boxes['boxes'].detach().cpu().numpy()
        labels = boxes['labels'].detach().cpu().numpy()
        scores = boxes['scores'].detach().cpu().numpy()
        for box, level, score in zip(coords, labels, scores):
            # Boxes in pixel units (any coordinate > 1) are normalized by the
            # detector crop size first; already-normalized boxes are used as-is.
            if box.max() > 1.:
                box = box / np.array(boxed_im_size * 2, dtype=float)[[2, 1, 0, 5, 4, 3]]
            box = Box3d.bbox_from_view(box, box_view).get_box_in_view_type('sagittal') * self.scale
            box = Box3d.bbox_from_view(box, 'sagittal')
            self.levels[self.level_names[level]] = Level(level, box, score)

    def plot_level_boxes(self, ax=None):
        """Draw level boxes on the middle sagittal and middle coronal slices.

        ``ax`` may be any container of exactly two matplotlib axes; when None a
        new figure is created and shown.

        Raises:
            ValueError: if ``ax`` does not hold exactly two axes.
        """
        own_figure = ax is None
        if own_figure:
            fig, ax = plt.subplots(1, 2, figsize=(8, 8))
        axes = np.asarray(ax).ravel()
        if axes.size != 2:
            raise ValueError(f"plot_level_boxes expects 2 axes, got {axes.size}")

        volume = self.volume
        axes[0].imshow(volume[volume.shape[0] // 2, :, :], aspect='auto')
        axes[1].imshow(volume.transpose(2, 1, 0)[volume.shape[2] // 2, :, :], aspect='auto')

        for level_name, level in self.levels.items():
            if level is None:
                continue
            abox = np.array(level.get_box_in_view_type('sagittal'))
            box = abox[[0, 1, 3, 4]]
            box2 = abox[[2, 1, 5, 4]]
            color = self.color_dict[level_name]
            axes[0].add_patch(patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                                                linewidth=1, edgecolor=color, facecolor='none'))
            axes[0].text(box[2], box[1] + 3, level_name, c=color)
            axes[1].add_patch(patches.Rectangle((box2[0], box2[1]), box2[2] - box2[0], box2[3] - box2[1],
                                                linewidth=1, edgecolor=color, facecolor='none'))
            axes[1].text(box2[2], box2[1] + 3, level_name, c=color)

        axes[0].set_title(self.series_type)
        axes[0].axis('off')
        axes[1].axis('off')
        if own_figure:
            plt.subplots_adjust(wspace=0.1, hspace=0)
            plt.show()

    def interactive_plot(self):
        """Animate the series slice by slice with level and condition boxes overlaid.

        Frames are slices of the volume in the series' original view.
        Returns (HTML widget, FuncAnimation).
        """
        view = self.series_type

        fig, ax = plt.subplots(figsize=(5, 5))
        ax.axis('off')
        volume = self.get_in_view(view)

        def plot_vid(i):
            ax.clear()
            ax.imshow(volume[i].astype(np.uint8))
            ax.axis('off')
            for level in self.levels.values():
                if level is None:
                    continue
                level_box = level.get_box_in_view_type(view)
                if not (level_box[2] <= i <= level_box[5]):
                    continue
                color = self.color_dict[level.level_name]
                ax.annotate(level.level_name, (level_box[3] + 5, level_box[1]), c=color, fontsize='medium')
                ax.add_patch(patches.Rectangle((level_box[0], level_box[1]), level_box[3] - level_box[0],
                                               level_box[4] - level_box[1], linewidth=1, edgecolor=color, fill=False))

                for condition in level.conditions:
                    if condition.box is None:
                        ax.annotate(f"{condition.condition_name}: {condition.status_name}",
                                    (level_box[3] + 5, level_box[1] + 15), c=color, fontsize='medium')
                        continue
                    condition_box = condition.get_box_in_view_type(view)
                    if not (condition_box[2] <= i <= condition_box[5]):
                        continue
                    # Right-side labels go above the box (except in sagittal), others below.
                    if "Right" in condition.condition_name and self.series_type != 'sagittal':
                        x, y = max(0, condition_box[0] - 40), condition_box[1]
                    else:
                        x, y = max(0, condition_box[0] - 40), condition_box[4] + 10
                    an = condition.condition_name + ": " + condition.status_name
                    ax.annotate(an, (x, y), c=color, fontsize='small')
                    ax.add_patch(patches.Rectangle((condition_box[0], condition_box[1]),
                                                   condition_box[3] - condition_box[0],
                                                   condition_box[4] - condition_box[1],
                                                   linewidth=1, edgecolor=color, fill=False))

        amn = animation.FuncAnimation(fig, plot_vid, frames=range(0, volume.shape[0]))
        example = HTML(amn.to_jshtml())
        plt.close(fig)
        return example, amn

from glob import glob
import dicomsdl as dicoml
import os

class Study(object):
    """Container for every series of a single study, grouped by series type.

    ``serieses`` maps the canonical keys ``'sagittal_t1'``, ``'sagittal_t2'``
    and ``'axial_t2'`` to lists of ``Series`` objects.  ``study[name]`` also
    accepts a few aliases of those keys (see ``__getitem__``).
    """

    # Canonical keys of ``self.serieses``; both alternate constructors create all of them.
    _SERIES_KEYS = ('sagittal_t1', 'sagittal_t2', 'axial_t2')

    # Accepted spellings for ``study[...]`` -> canonical key.
    _SERIES_ALIASES = {
        'sagittal': 'sagittal_t1', 'sagittal_t1': 'sagittal_t1', 'sagittal_T1': 'sagittal_t1',
        'sagittal_t2': 'sagittal_t2', 'sagittal_T2': 'sagittal_t2', 'SAGITTAL_T2': 'sagittal_t2',
        'axial': 'axial_t2', 'axial_t2': 'axial_t2', 'axial_T2': 'axial_t2',
    }

    def __init__(self, study_id, serieses: Dict[str, 'Series']):
        self.serieses = serieses
        self.study_id = study_id

        # Position of each condition on axis 1 of get_condition_as_tensor's output.
        self.condition_map = {
            'Left Neural Foraminal Narrowing': 0,
            'Left Subarticular Stenosis': 1,
            'Right Neural Foraminal Narrowing': 2,
            'Right Subarticular Stenosis': 3,
            'Spinal Canal Stenosis': 4,
        }

    def __getitem__(self, series_type: str, view_type=None):
        """Return the list of series stored under ``series_type``.

        Args:
            series_type: a canonical key or one of its aliases
                ('sagittal' -> sagittal_t1, 'axial' -> axial_t2, ...).
            view_type: unused; kept for backward compatibility.

        Returns:
            The list of ``Series`` for that type, or ``None`` for an unknown name.
        """
        key = self._SERIES_ALIASES.get(series_type)
        if key is None:
            return None
        return self.serieses[key]

    @classmethod
    def _series_key(cls, series_type, series_desc):
        """Map an orientation ('sagittal'/'axial') plus a DICOM SeriesDescription to a canonical key.

        The description is searched (case-insensitively) for 't2'/'t1', so values such as
        'Sagittal T2/STIR' resolve to 'sagittal_t2'.  A missing/unrecognised description, or a
        combination that is not a canonical key (e.g. axial T1), falls back to the default:
        'axial_t2' for axial series, 'sagittal_t1' otherwise.
        """
        default = 'axial_t2' if series_type == 'axial' else 'sagittal_t1'
        if not series_desc:
            return default
        desc = str(series_desc).lower()
        for weighting in ('t2', 't1'):
            if weighting in desc:
                key = f'{series_type}_{weighting}'
                return key if key in cls._SERIES_KEYS else default
        return default

    @classmethod
    def from_folder(cls, path):
        """Build a Study by reading DICOM files from disk.

        ``path`` is either a study folder with one sub-folder per series (the RSNA lumbar-spine
        layout) or a single series folder.  Every file name must be an integer instance number
        plus an extension (it is parsed with ``int``).  Series whose orientation is neither
        sagittal nor axial, and empty folders, are skipped.  Each slice is divided by its own
        maximum and scaled to 0-255.
        """
        series_dirs = glob(f'{path}/*/')
        if not series_dirs:
            series_dirs = [path]  # path points at a single series

        study = {key: [] for key in cls._SERIES_KEYS}
        study_id = None
        for series_dir in series_dirs:
            files = os.listdir(series_dir)
            if not files:
                continue
            dataset = [(int(os.path.splitext(file)[0]), dicoml.open(f"{series_dir}/{file}")) for file in files]
            first = dataset[0][1]
            orientation = np.round(first.ImageOrientationPatient)
            series_desc = first.SeriesDescription
            print(series_desc)
            study_id = first.StudyID
            width = first.Columns
            height = first.Rows
            if np.array_equal(orientation, [0., 1., 0., 0., 0., -1.]):
                series_type = 'sagittal'
                # sort by descending x position (ImagePositionPatient[0])
                slices = sorted(dataset, key=lambda s: -s[1].ImagePositionPatient[0])
            elif np.array_equal(orientation, [1., 0., 0., 0., 1., 0.]):
                series_type = 'axial'
                # sort by the instance number parsed from the file name
                slices = sorted(dataset, key=lambda s: s[0])
            else:
                continue

            volume = np.zeros((len(slices), height, width), dtype=float)
            for i, (_, dimg) in enumerate(slices):
                img = dimg.pixelData()
                if np.max(img) != 0:
                    img = img / np.max(img)
                img = (img * 255).astype(np.uint8)
                try:
                    volume[i, :, :] = img
                except ValueError:
                    # slice smaller than the first slice's Rows x Columns: top-left align it
                    volume[i, :img.shape[0], :img.shape[1]] = img
            study[cls._series_key(series_type, series_desc)].append(Series(volume, series_type))
        return cls(study_id, study)

    @classmethod
    def from_tensor_dict(cls, tensor_dict):
        """Build a Study from a dict of tensors (used by tests and training).

        Keys 'sagittal' and 'axial' are renamed to 'sagittal_t1' and 'axial_t2'; the key
        'study_id' supplies the id (``None`` if absent).  Each value is an iterable of tensors;
        all-zero tensors (padding) are skipped.
        """
        study = {key: [] for key in cls._SERIES_KEYS}
        study_id = None
        renames = {'sagittal': 'sagittal_t1', 'axial': 'axial_t2'}
        for key, val in tensor_dict.items():
            if key == 'study_id':
                study_id = val
                continue
            key = renames.get(key, key)
            if len(val) < 1:
                continue
            for series in val:
                if series.count_nonzero() < 1:
                    continue
                study[key].append(Series(series.clone().detach().cpu().squeeze().numpy(), key.split('_')[0]))
        return cls(study_id, study)

    def __iter__(self):
        return iter(self.serieses)

    def items(self):
        return self.serieses.items()

    def keys(self):
        return self.serieses.keys()

    def values(self):
        return self.serieses.values()

    def get_levels_in_view(self, series_type, view, training=False):
        """Collect level crops in ``view`` across every series of ``series_type``.

        The first series provides the level indices and crops; later series only fill
        slots that are still ``None``.

        Returns:
            ``(level_indices, level_crops)`` as produced by ``Series.get_levels_in_view``.
        """
        level_ind_list, level_list = [], []
        for series in self[series_type]:
            ind, levels = series.get_levels_in_view(view, training=training)
            if not level_list and not level_ind_list:
                level_ind_list, level_list = ind, levels
                continue
            for i, (current, candidate) in enumerate(zip(level_list, levels)):
                if current is None and candidate is not None:
                    level_list[i] = candidate
        return level_ind_list, level_list

    def get_conditions_in_view(self, series_type, conditions_to_get, view, training=False):
        """Collect condition crops in ``view`` across every series of ``series_type``.

        Same merging rule as ``get_levels_in_view``: the first series defines the index list
        and crops, later series only fill crops that are still ``None``.  (Previously the fill
        value was written into the index list by mistake.)

        Returns:
            ``(condition_indices, condition_crops)`` as produced by ``Series.get_conditions_in_view``.
        """
        cond_ind_list, cond_list = [], []
        for series in self[series_type]:
            ind, conds = series.get_conditions_in_view(conditions_to_get, view, training=training)
            if not cond_ind_list and not cond_list:
                cond_ind_list, cond_list = ind, conds
                continue
            for i, (current, candidate) in enumerate(zip(cond_list, conds)):
                if current is None and candidate is not None:
                    cond_list[i] = candidate
        return cond_ind_list, cond_list

    def plot_levels(self):
        """Plot the level boxes of every series, one column (two rows) per series."""
        ax_num = sum(len(serieses) for serieses in self.values())
        if ax_num == 0:
            return  # nothing to draw; plt.subplots would reject 0 columns
        # squeeze=False keeps ax 2-D even for a single series, so ax[:, i] always works
        fig, ax = plt.subplots(2, ax_num, figsize=(12, 8), squeeze=False)
        column = 0
        for serieses in self.values():
            for series in serieses:
                if series is not None:
                    series.plot_level_boxes(ax[:, column])
                    column += 1

        plt.subplots_adjust(wspace=0.1, hspace=0)
        plt.show()

    def get_condition_as_tensor(self):
        """Return a (5, 5, 3) tensor of condition statuses, initialised to 1/3 everywhere.

        Axis 0 is the position of the level in ``series.levels`` (enumeration order), axis 1
        follows ``self.condition_map``; axis 2 holds ``condition.status`` (presumably the
        3-class severity scores — confirm).  Later series overwrite earlier ones.
        """
        results = torch.ones((5, 5, 3)) / 3
        for serieses in self.values():
            for series in serieses:
                if series is None:
                    continue
                for i, (_, level) in enumerate(series.levels.items()):
                    if level is None:
                        continue
                    for condition in level.conditions:
                        if condition.status is not None:
                            results[i, self.condition_map[condition.condition_name], :] = torch.tensor(condition.status, dtype=results.dtype)
        return results
    

In [None]:
# Smoke test: build a Study straight from one RSNA training-study DICOM folder.
a = Study.from_folder("/workspaces/RSNA_LSDC/inputs/rsna-2024-lumbar-spine-degenerative-classification/train_images/13317052")

In [13]:
class ModelDetectEstimate(nn.Module):
    """Two-stage pipeline: detect spinal levels/regions, then grade the conditions.

    Detectors (kept in eval mode; ``train``/``eval`` never touch them):
      * ``level_detector``    - level boxes on the sagittal view of every series,
      * ``axial_detector``    - subarticular regions inside axial level crops,
      * ``sagittal_detector`` - neural foramina inside sagittal (T1) level crops.
    Estimators (switched by ``train``/``eval``):
      * ``canale_esitmator``   - spinal canal stenosis from sagittal T2 level crops,
      * ``sub_estimator``      - subarticular stenosis from axial region crops,
      * ``foramina_estimator`` - neural foraminal narrowing from sagittal region crops.

    Args:
        device: torch device the sub-models and all inputs are moved to.
        weights_dir: directory containing the ``*.pt`` state dicts loaded in ``__init__``.
    """

    def __init__(self, device: str = 'cuda', weights_dir: str = '/workspaces/RSNA_LSDC/model_weight'):
        super().__init__()
        self.device = device
        self.weights_dir = weights_dir

        ################################# DETECTORS #################################
        # Level detector: 5-class box model on the whole sagittal view resized to 96^3.
        self.level_detector_model_config = {
            'backbone_name': 'resnet_18',
            'series_dim': [96] * 3,
            'device': self.device,
            'use_features': [0, 1],
            'reg_max': 12,
            'pretrained': False,
            'num_classes': 5,
            'postprocess': Nms3d(0.3, 0.3, True),
        }
        self.level_detector = BoxModel(**self.level_detector_model_config).to(self.device).eval()
        self._load_weights(self.level_detector, 'level_detector_mednet18_96x3_best.pt')
        self.resize_for_ld = torch.nn.Upsample(size=self.level_detector_model_config['series_dim'], mode='trilinear').to(self.device)

        # Subarticular detector: 1-class box model on axial level crops resized to 48x96x96.
        self.sub_detector_model_config = {
            'backbone_name': 'resnet_18',
            'series_dim': [48, 96, 96],
            'device': self.device,
            'use_features': [0, 1],
            'reg_max': 16,
            'pretrained': False,
            'num_classes': 1,
            'postprocess': Nms3dAxial(0.1, 0.05, False, 1),
        }
        self.axial_detector = BoxModel(**self.sub_detector_model_config).to(self.device).eval()
        self._load_weights(self.axial_detector, 'sub_detect_axial_mednet18_96_96_48_best.pt')
        self.resize_for_sub = torch.nn.Upsample(size=self.sub_detector_model_config['series_dim'], mode='trilinear')

        # Foramina detector: 1-class box model on sagittal level crops resized to 48x96x96.
        self.foramina_detector_model_config = {
            'backbone_name': 'resnet_18',
            'series_dim': [48, 96, 96],
            'device': self.device,
            'use_features': [0, 1],
            'reg_max': 16,
            'pretrained': False,
            'num_classes': 1,
            'postprocess': Nms3dSagittalForamina(0.1, 0.05, True, 1),
        }
        self.sagittal_detector = BoxModel(**self.foramina_detector_model_config).to(self.device).eval()
        self._load_weights(self.sagittal_detector, 'foramina_detect_sagittal_mednet18_96_96_48_best.pt')
        self.resize_for_sagittal = torch.nn.Upsample(size=self.foramina_detector_model_config['series_dim'], mode='trilinear')

        ################################# ESTIMATORS #################################
        # Spinal canal estimator (input crops resized to 10x80x80).
        self.canale_estimator_model_config = {
            'backbone_name': 'densenet121',
            'series_dim': [10, 80, 80],
            'pretrained': False,
        }
        self.canale_esitmator = ClassModelTimm2d(**self.canale_estimator_model_config).to(self.device).eval()
        self._load_weights(self.canale_esitmator, 'densenset121_80_80_10_canale_best.pt')
        self.resize_for_canale = torch.nn.Upsample(size=self.canale_estimator_model_config['series_dim'], mode='trilinear').to(self.device)

        # Subarticular estimator (input crops resized to 8x80x80).
        self.sub_estimator_model_config = {
            'backbone_name': 'densenet121',
            'series_dim': [8, 80, 80],
            'pretrained': False,
        }
        self.sub_estimator = ClassModelTimm2d(**self.sub_estimator_model_config).to(self.device).eval()
        self._load_weights(self.sub_estimator, 'densenet121_80_80_8_subs_best.pt')
        self.resize_for_sub_estimator = torch.nn.Upsample(size=self.sub_estimator_model_config['series_dim'], mode='trilinear').to(self.device)

        # Foraminal estimator (input crops resized to 3x80x80).
        self.foramina_estimator_model_config = {
            'backbone_name': 'densenet121',
            'series_dim': [3, 80, 80],
            'pretrained': False,
        }
        self.foramina_estimator = ClassModelTimm2d(**self.foramina_estimator_model_config).to(self.device).eval()
        self._load_weights(self.foramina_estimator, 'densenet121_80_80_3_foramina_best.pt')
        self.resize_for_foramina_estimator = torch.nn.Upsample(size=self.foramina_estimator_model_config['series_dim'], mode='trilinear').to(self.device)

        # Augmentations applied to the stacked estimator batch in training mode only.
        self.train_transforms = v2.Compose([
            v2.RandomChoice([v2.RandomVerticalFlip(p=0.5), v2.RandomHorizontalFlip(p=0.5), v2.RandomAffine(degrees=5),
                             v2.RandomRotation(degrees=(90, 90)), v2.RandomRotation(degrees=(-90, -90))]),
            v2.RandomChoice([v2.RandomAffine(degrees=0, scale=(0.8, 1.2)), v2.RandomPerspective(distortion_scale=0.2, p=1.0), v2.RandomAffine(degrees=0, translate=(0.3, 0.3), shear=(-5, 5, -5, 5))]),  # translation + shearing
            v2.RandomChoice([v2.GaussianBlur(kernel_size=(3, 7), sigma=(0.1, 0.7))]),
        ])

    def _load_weights(self, module: nn.Module, file_name: str) -> None:
        """Load ``weights_dir/file_name`` into ``module``, mapping tensors to ``self.device``."""
        state_dict = torch.load(os.path.join(self.weights_dir, file_name), map_location=self.device)
        module.load_state_dict(state_dict)

    def normalize_volume(self, volume, type=0):
        """Normalise a volume (torch tensor or numpy array).

        type 0: map the 0-255 range linearly onto [-1, 1].
        otherwise: z-score using the mean/std of the strictly positive voxels only.
            A volume with no positive voxels is returned unchanged; when the std is
            zero or undefined the volume is only mean-centred (instead of becoming NaN).
        """
        if type == 0:
            return (volume - 0.5 * 255.) / (0.5 * 255.)

        pixels = volume[volume > 0]
        # torch's `.size` is a method, so count elements explicitly for tensors
        num_pixels = pixels.numel() if torch.is_tensor(pixels) else pixels.size
        if num_pixels == 0:
            return volume

        mean = pixels.mean()
        std = pixels.std()
        if not bool(std > 0):  # zero or NaN std (constant / single-voxel foreground)
            return volume - mean
        return (volume - mean) / std

    def limit_series(self, img, series_length):
        """Force the depth (dim 0) of ``img`` to ``series_length``.

        Longer series keep ``series_length`` evenly spaced slices; shorter ones are
        zero-padded at the end (``F.pad``, so ``img`` must be a tensor in that case).
        """
        if series_length < img.shape[0]:
            st = np.round(np.linspace(0, img.shape[0] - 1, series_length)).astype(int)
            img = img[st, :, :]
        elif series_length > img.shape[0]:
            img = F.pad(img, (0, 0, 0, 0, 0, series_length - img.shape[0]))
        return img

    def resize_data(self, volume, resizer, series_dim=None, dim_reduction='scale'):
        """Resize a (d, h, w) tensor with ``resizer`` and return it as (1, *resizer size).

        dim_reduction:
            'scale' - let ``resizer`` (trilinear) rescale the depth as well.
            'limit' - first select evenly spaced slices / zero-pad to ``series_dim[0]``.
        """
        if dim_reduction == 'limit':
            if series_dim is None:
                raise ValueError("series_dim is required when dim_reduction='limit'")
            volume = self.limit_series(volume, series_dim[0])
        return resizer(volume.reshape(1, 1, *volume.shape)).squeeze(1).to(self.device)

    def _prepare_levels(self, levels, resizer, series_dim, add_channel_dim=False):
        """Resize + z-score every crop in ``levels`` and stack them into one batch.

        ``None`` entries become all-zero volumes so batch index i still matches entry i.
        Each item is (1, *series_dim), or (1, 1, *series_dim) with ``add_channel_dim``.
        """
        prepared = []
        for level in levels:
            if level is not None:
                volume = torch.tensor(level, dtype=torch.float).to(self.device)
                volume = self.normalize_volume(self.resize_data(volume, resizer, dim_reduction='scale'), type=1)
                if add_channel_dim:
                    volume = volume.unsqueeze(1)
            else:
                shape = (1, 1, *series_dim) if add_channel_dim else (1, *series_dim)
                volume = torch.zeros(shape).to(self.device)
            prepared.append(volume)
        return torch.cat(prepared, dim=0)

    def _run_estimator(self, estimator, levels, resizer, series_dim, *args):
        """Prepare ``levels`` and either compute the training loss or predict.

        In training mode the batch is augmented with ``train_transforms`` and
        ``estimator.get_loss(batch, *args)`` is returned; otherwise ``estimator.predict(batch)``.
        """
        batch = self._prepare_levels(levels, resizer, series_dim, add_channel_dim=True)
        if estimator.training:
            return estimator.get_loss(self.train_transforms(batch), *args)
        return estimator.predict(batch)

    def _detect_levels(self, study):
        """Detect level boxes on the sagittal view of every series and store them on the series."""
        for serieses in study.values():
            for item in serieses:
                if item is None:
                    continue
                volume = torch.tensor(item.get_in_view('sagittal'), dtype=torch.float).to(self.device)
                level_boxes = self.get_levels(volume)[0]
                item.set_levels_from_boxes(level_boxes, 'sagittal', self.level_detector_model_config['series_dim'])

    def _detect_subarticular_regions(self, study):
        """Detect subarticular regions in each axial level crop and attach them as conditions."""
        for axial in study['axial']:
            level_ind, levels = axial.get_levels_in_view('axial')
            boxes = self.detect_fs_axial(levels)
            for li, level_boxes in zip(level_ind, boxes):
                if axial.levels[li] is not None:
                    axial.levels[li].add_conditions_from_boxes(level_boxes, {0: 'Subarticular Stenosis'},
                                                               self.sub_detector_model_config['series_dim'], 'axial')

    def _detect_foraminal_regions(self, study):
        """Detect neural foramina in each sagittal (T1) level crop and attach them as conditions."""
        for sagittal in study['sagittal']:
            level_ind, levels = sagittal.get_levels_in_view('sagittal')
            boxes = self.detect_foramina_sagittal(levels)
            for li, level_boxes in zip(level_ind, boxes):
                if sagittal.levels[li] is not None:
                    sagittal.levels[li].add_conditions_from_boxes(level_boxes, {0: 'Neural Foraminal Narrowing'},
                                                                  self.foramina_detector_model_config['series_dim'], 'sagittal')

    @staticmethod
    def _set_condition_statuses(series, cond_ind, statuses):
        """Write estimator outputs onto the matching level conditions of ``series``.

        ``cond_ind`` entries are (level key, condition identifier) pairs as returned by
        ``get_conditions_in_view``; levels that are ``None`` are skipped.
        """
        for ind, status in zip(cond_ind, statuses.squeeze(0).detach().cpu().numpy()):
            if series.levels[ind[0]] is not None:
                series.levels[ind[0]].set_status_to_condition(ind[1], status)

    def predict(self, series):
        """Run the full detection + grading pipeline on one study (estimators in eval mode).

        Args:
            series: path to a study folder (str), or a dict of tensors accepted by
                ``Study.from_tensor_dict``.

        Returns:
            The ``Study`` with level boxes, condition boxes and condition statuses set.
        """
        study = Study.from_folder(series) if isinstance(series, str) else Study.from_tensor_dict(series)

        self._detect_levels(study)
        self._detect_subarticular_regions(study)
        self._detect_foraminal_regions(study)

        # Spinal canal stenosis, graded per level on the sagittal T2 series.
        for sagittal_t2 in study['sagittal_t2']:
            level_ind, levels = sagittal_t2.get_levels_in_view('sagittal')
            statuses = self.estimate_canales(levels).squeeze(0)
            for li, status in zip(level_ind, statuses.detach().cpu().numpy()):
                if sagittal_t2.levels[li] is not None:
                    sagittal_t2.levels[li].add_condition('Spinal Canal Stenosis', status=status)

        # Subarticular stenosis, graded per detected region on the axial series.
        for axial in study['axial']:
            cond_ind, conds = axial.get_conditions_in_view(['Left Subarticular Stenosis', 'Right Subarticular Stenosis'], 'axial')
            self._set_condition_statuses(axial, cond_ind, self.estimate_subarticular(conds))

        # Neural foraminal narrowing, graded per detected region on the sagittal (T1) series.
        for sagittal in study['sagittal']:
            cond_ind, conds = sagittal.get_conditions_in_view(['Left Neural Foraminal Narrowing', 'Right Neural Foraminal Narrowing'], 'sagittal')
            self._set_condition_statuses(sagittal, cond_ind, self.estimate_foraminal(conds))

        return study

    def get_loss(self, series, labels_canale, labels_subarticulars, labels_foramina, masks_canale, mask_subarticulars, mask_foramina):
        """Detect regions in one study (tensor dict) and return the summed estimator losses.

        Call in training mode (see ``train``), so the estimators return (loss, metrics).
        Series types missing from the study contribute a zero loss and empty metrics.

        Returns:
            ``(total_loss, metrics_canale, metrics_sub, metrics_foramina)``
        """
        study = Study.from_tensor_dict(series)
        self._detect_levels(study)

        # Spinal canal: levels from the sagittal T2 series.
        loss_canale = torch.tensor(0., dtype=torch.float)
        metrics_canale = {}
        if len(study['sagittal_t2']) > 0:
            _, sagittal_lvls = study.get_levels_in_view('sagittal_t2', 'sagittal', True)
            loss_canale, metrics_canale = self.estimate_canales(sagittal_lvls, labels_canale, masks_canale)

        self._detect_subarticular_regions(study)
        self._detect_foraminal_regions(study)

        # Subarticular stenosis: regions from the axial series.
        loss_sub = torch.tensor(0., dtype=torch.float)
        metrics_sub = {}
        if len(study['axial']) > 0:
            _, axial_cond = study.get_conditions_in_view('axial', ['Left Subarticular Stenosis', 'Right Subarticular Stenosis'], 'axial', True)
            loss_sub, metrics_sub = self.estimate_subarticular(axial_cond, labels_subarticulars, mask_subarticulars)

        # Neural foraminal narrowing: regions from the sagittal (T1) series.
        loss_foramina = torch.tensor(0., dtype=torch.float)
        metrics_foramina = {}
        if len(study['sagittal']) > 0:
            _, sagittal_cond = study.get_conditions_in_view('sagittal', ['Left Neural Foraminal Narrowing', 'Right Neural Foraminal Narrowing'], 'sagittal', True)
            loss_foramina, metrics_foramina = self.estimate_foraminal(sagittal_cond, labels_foramina, mask_foramina)

        return loss_canale + loss_sub + loss_foramina, metrics_canale, metrics_sub, metrics_foramina

    def get_levels(self, series: torch.Tensor, prepared: bool = False) -> List[Dict]:
        """Run the level detector on a (d, h, w) volume.

        Unless ``prepared`` is True the volume is resized to the detector size and
        normalised to [-1, 1] first.  Returns whatever ``BoxModel.predict`` returns
        (callers take element [0]).
        """
        if not prepared:
            series = self.normalize_volume(self.resize_data(series.to(self.device), self.resize_for_ld))
        return self.level_detector.predict(series)

    def detect_fs_axial(self, levels):
        """Detect subarticular regions in a list of axial level crops (``None`` = missing level)."""
        batch = self._prepare_levels(levels, self.resize_for_sub, self.sub_detector_model_config['series_dim'])
        return self.axial_detector.predict(batch)

    def detect_foramina_sagittal(self, levels):
        """Detect neural foramina in a list of sagittal level crops (``None`` = missing level)."""
        batch = self._prepare_levels(levels, self.resize_for_sagittal, self.foramina_detector_model_config['series_dim'])
        return self.sagittal_detector.predict(batch)

    def estimate_canales(self, series, *args):
        """Grade spinal canal stenosis for a list of level crops; ``*args`` = (labels, masks) in training."""
        return self._run_estimator(self.canale_esitmator, series, self.resize_for_canale,
                                   self.canale_estimator_model_config['series_dim'], *args)

    def estimate_subarticular(self, series, *args):
        """Grade subarticular stenosis for a list of region crops; ``*args`` = (labels, masks) in training."""
        return self._run_estimator(self.sub_estimator, series, self.resize_for_sub_estimator,
                                   self.sub_estimator_model_config['series_dim'], *args)

    def estimate_foraminal(self, series, *args):
        """Grade foraminal narrowing for a list of region crops; ``*args`` = (labels, masks) in training."""
        return self._run_estimator(self.foramina_estimator, series, self.resize_for_foramina_estimator,
                                   self.foramina_estimator_model_config['series_dim'], *args)

    def _estimators(self):
        """The trainable sub-models; the detectors are deliberately excluded."""
        return (self.canale_esitmator, self.sub_estimator, self.foramina_estimator)

    def train(self, mode: bool = True):
        """Switch only the estimators to training (``mode=True``) or eval mode.

        Keeps ``nn.Module.train``'s signature and returns ``self``; the detectors stay
        in eval mode.  Estimators are toggled via their own no-argument ``train()``/``eval()``.
        """
        for estimator in self._estimators():
            if mode:
                estimator.train()
            else:
                estimator.eval()
        self.training = mode
        return self

    def eval(self):
        """Put the estimators in eval mode; returns ``self``."""
        return self.train(False)


In [None]:
# Build the full pipeline; __init__ loads every detector/estimator state dict from disk.
model = ModelDetectEstimate()

In [None]:
# Put the estimators in inference mode, then detect + grade one study folder.
model.eval()
study = model.predict("/workspaces/RSNA_LSDC/inputs/rsna-2024-lumbar-spine-degenerative-classification/train_images/117720278")

In [None]:
# Level boxes for every series of the study (two rows per series).
study.plot_levels()

In [45]:
# Slice-by-slice animation of the first sagittal (T1) series; presumably returns (HTML widget, FuncAnimation) — confirm.
f, ani = study['sagittal'][0].interactive_plot()

In [46]:
#ani.save('subarticular_detection_example.gif', writer='Pillow', fps=2)

In [None]:
f

# Finetune model

In [180]:
# Dataset settings; presumably passed by the trainer to Dataset(data_info, config), which reads
# 'load_series' and 'return_series_type'. Validation reuses an independent copy of the same settings.
train_dataset_config ={
    'dataset_path': "/workspaces/RSNA_LSDC/inputs/dataset_vl",
    'load_series': ['sagittal', 'sagittal_t2', 'axial'], # series types to load into dataset ['sagittal', 'axial', 'sagittal_t2']
    
    'return_series_type': False, # If True, __getitem__ also returns the series' original type
    'image_type': 'png',
}

val_dataset_config = copy.deepcopy(train_dataset_config)

In [181]:
# Fine-tuning configuration consumed by SagittalTrainer (its use of each key is defined outside this file).
trainer_config = {
    # logging
    "print_evaluation": True,
    "steps_per_plot": 1,

    # checkpointing
    # NOTE(review): model_name mentions efficientnet_b2, but ModelDetectEstimate builds densenet121 estimators — confirm the name.
    "checkpoints": False,
    "save_path": "/workspaces/RSNA_LSDC/models_3d_final/model_weight",
    "step_per_save":100,
    "model_name": "efficientnet_b2_80_80_10_finetuned",
    "train_dataset_config": train_dataset_config,
    "val_dataset_config": val_dataset_config,

    # schedule; batch_size 1 presumably means one study per step (get_loss takes a single study dict)
    "epochs": 10, 
    "batch_size": 1,

    # optimiser and LR scheduler (classes plus their constructor kwargs); trailing comments list alternatives tried
    "optimizer": torch.optim.AdamW, #torch.optim.AdamW,#torch.optim.Adam,
    "optimizer_params": {'lr':3e-5}, #,'weight_decay': 1e-3 },#, 'weight_decay': 1e-3, 'momentum': 0.98},#, 'momentum': 0.98, 'weight_decay': 1e-3},#, 'momentum':0.98, 'weight_decay':1e-5},#, 'momentum':0.9},
    "scheduler": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, #torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, #torch.optim.lr_scheduler.ExponentialLR, #torch.optim.lr_scheduler.OneCycleLR
    "scheduler_params": {'T_0': 10, 'T_mult': 2, 'eta_min':3e-9}, #{'T_0': 2, 'T_mult': 2, 'eta_min':3e-5}, #{'max_lr': 0.001, 'epochs': None, 'steps_per_epoch':None}, {'gamma':0.9}

    # early stopping (disabled); the 'vsa' flag's meaning is defined by the trainer, not visible here
    "early_stopping": False,
    "early_stopping_treshold": 0.1,
    'vsa':False,
}


In [186]:
# Training series metadata; kfoldCV draws its study ids from tsd.study_id.
# NOTE(review): only the first 300 rows are kept — looks like a quick-experiment subsample; confirm.
tsd = pd.read_csv(f'/workspaces/RSNA_LSDC/inputs/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv').iloc[0:300]
def split_study_folds(study_ids, k, seed=None):
    """Shuffle ``study_ids`` and split them into ``k`` near-equal folds.

    Args:
        study_ids: 1-D array-like of study ids.
        k: number of folds.
        seed: optional int for a reproducible shuffle; ``None`` keeps using the
            global NumPy RNG (``np.random.permutation``).

    Returns:
        List of ``k`` numpy arrays (``np.array_split`` semantics).
    """
    ids = np.asarray(study_ids)
    if seed is None:
        shuffled = np.random.permutation(ids)
    else:
        shuffled = np.random.default_rng(seed).permutation(ids)
    return np.array_split(shuffled, k)


def kfoldCV(k, trainer_config, seed=None):
    """Train one ``SagittalTrainer`` per fold and return the list of trainer summaries.

    ``k == 1`` uses the fixed train/test study-id split saved in the .npy files.
    Otherwise the study ids of the global ``tsd`` table are shuffled (reproducibly
    when ``seed`` is given) and split into ``k`` folds; in round ``i`` fold ``i`` is
    the validation set and the remaining folds are the training set.
    """
    data_sagittal = pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t1.pkl")
    data_sagittal_t2 = pd.read_pickle("/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_unified_sagital_t2.pkl")
    data_axial = pd.read_pickle('/workspaces/RSNA_LSDC/inputs/box_data/coordinates/coordinates_axial_unified.pkl')

    def select(study_ids):
        # Coordinate tables of each series type, restricted to the given studies.
        return {'sagittal': data_sagittal[data_sagittal.study_id.isin(study_ids)],
                'sagittal_t2': data_sagittal_t2[data_sagittal_t2.study_id.isin(study_ids)],
                'axial': data_axial[data_axial.study_id.isin(study_ids)]}

    if k == 1:
        train = np.load('/workspaces/RSNA_LSDC/inputs/train_unique_studies.npy')
        test = np.load('/workspaces/RSNA_LSDC/inputs/test_unique_studies.npy')
        # NOTE(review): only the first training study (train[:1]) is used — looks like a
        # debugging shortcut; confirm before a real run.
        folds = [test, train[:1]]
    else:
        folds = split_study_folds(tsd.study_id.unique(), k, seed)

    model_summaries = []
    for i in range(k):  # with k == 1 only fold 0 (the test split) is used for validation
        print(f"Fold: {i}")
        train_ids = np.concatenate(folds[:i] + folds[i + 1:], axis=0)
        trainer = SagittalTrainer(ModelDetectEstimate, trainer_config, select(train_ids), select(folds[i]))
        trainer.train()
        model_summaries.append(trainer.get_summary())

    return model_summaries
# convformer_s18

In [None]:
# k=1: one run on the fixed train/test study split loaded from the saved .npy files.
kfoldCV(1, trainer_config)