In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import pydicom
import numpy as np
import os
import glob
from tqdm import tqdm
import gc

import torchvision
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from fastai.vision.all import *
import segmentation_models_pytorch as smp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [ ]:
# Number of cross-validation folds (used to build the fold boundaries).
CV = 5
# Seed handed to seed_everything before each training run.
SEED = 777
# NOTE(review): rebound by the fold loops further down; this initial value is not read in the visible code.
fold = 1
# Side length in pixels that every slice is resized to.
PATCH_SIZE = 512
# Default side length in pixels of the per-level crops fed to the classifier.
patch_size = 64
# NOTE(review): not referenced in the visible code (the heatmap threshold is written as a literal 0.25).
TH = .25
# When False, train_segmentation does not fit or save any model.
SEG_TRAIN = True
# Segmentation-stage hyper-parameters (batch size, learning rate, epochs).
SEG = {
    'BS':16,
    'LR':5e-4,
    'EPOCHS':10
}
# Classification-stage hyper-parameters (batch size, learning rate, epochs, weight decay).
# NOTE(review): the training loops pass 1e-4 literally instead of reading 'LR'.
INF = {
    'BS':16,
    'LR':1e-4,
    'EPOCHS':10,
    'WD':0.1
}

In [ ]:
# Label table; its diagnosis columns are joined onto the coordinate tables via `study_id` below.
train = pd.read_csv('./datasets/rsna-2024-lumbar-spine-degenerative-classification/train.csv')

In [ ]:
# Keep only the rows of `train` in which every canal-related label is present.
canal_diagnosis = [name for name in train.columns if 'canal' in name]
train_canal = train[train[canal_diagnosis].notna().all(axis=1)].reset_index(drop=True)

In [ ]:
# Keep only the rows of `train` in which every foraminal-related label is present.
foraminal_diagnosis = [name for name in train.columns if 'foraminal' in name]
train_foraminal = train[train[foraminal_diagnosis].notna().all(axis=1)].reset_index(drop=True)

In [ ]:
# Series metadata; `series_id` / `series_description` are joined onto FDF further down.
df_meta_f = pd.read_csv('./datasets/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv')
# Number of rows per series description (notebook display only).
df_meta_f['series_description'].groupby(df_meta_f['series_description']).count()

In [ ]:
# Annotated landmark coordinates; filtered per `condition` by get_cooridnates_dataset.
df_coor = pd.read_csv('./datasets/rsna-2024-lumbar-spine-degenerative-classification/train_label_coordinates.csv')
# Number of rows per condition (notebook display only).
df_coor['condition'].groupby(df_coor['condition']).count()

In [ ]:
def get_cooridnates_dataset(df:pd.core.frame.DataFrame, condition:str ):
    """Return the de-duplicated coordinate rows of one condition.

    Rows whose ``condition`` equals *condition* are kept, ordered by
    ``study_id``, ``series_id`` and ``level``, and reduced to the columns
    ``study_id, series_id, level, instance_number, x, y`` (in that order).
    The original index labels are preserved.
    """
    sort_keys = ['study_id', 'series_id', 'level']
    payload = ['instance_number', 'x', 'y']
    matching = df.loc[df['condition'] == condition, sort_keys + payload]
    ordered = matching.sort_values(sort_keys)
    return ordered.drop_duplicates()

In [ ]:
# One coordinate table per condition: left foramina, right foramina, spinal canal.
LF = get_cooridnates_dataset(df_coor,'Left Neural Foraminal Narrowing')
RF = get_cooridnates_dataset(df_coor,'Right Neural Foraminal Narrowing')
SCS = get_cooridnates_dataset(df_coor,'Spinal Canal Stenosis')

In [ ]:
# Sanity check: True only if the `level` column of LF is exactly the sequence
# L1/L2 ... L5/S1 repeated, i.e. every study/series has all five levels in order.
# The cell below relies on this when it reshapes LF in groups of five rows.
(['L1/L2','L2/L3','L3/L4','L4/L5','L5/S1']*(len(LF)//5) == LF['level']).sum() == len(LF)

In [ ]:
# Drop `level`; from here on the level is encoded by column name instead.
LF = LF[[
    'study_id',
    'series_id',
    'instance_number',
    'x',
    'y'    
]]
# Widen LF: every group of five consecutive rows is reshaped into one
# 10-value vector (x, y per level) and tiled so that each of the five rows
# carries the coordinates of the whole group.
# This requires len(LF) to be a multiple of 5 with rows ordered by level.
LF[[
    'x_L1L2',
    'y_L1L2',
    'x_L2L3',
    'y_L2L3',
    'x_L3L4',
    'y_L3L4',
    'x_L4L5',
    'y_L4L5',
    'x_L5S1',
    'y_L5S1',    
]] = np.tile(LF[['x','y']].values.reshape(-1,1,5,2),(1,5,1,1)).reshape(-1,10)
# The single-level x / y are now redundant; rows sharing an instance collapse into one.
LF = LF.drop(columns=['x','y']).drop_duplicates().reset_index(drop=True)
LF.tail()

In [ ]:
def merge_centers_with_dataset_for_na_values(df:pd.core.frame.DataFrame):
    """Attach the five per-level centres of a series to each of its rows.

    For every (study_id, series_id, level) the annotated ``x`` / ``y`` points
    are averaged into one centre. Each input row then receives the ten
    centre coordinates of its own study/series (NaN where a level has no
    annotation). The per-row ``x`` / ``y`` / ``level`` columns are dropped,
    duplicate rows are removed and rows with any missing centre are
    discarded.

    Args:
        df: frame with columns ``study_id``, ``series_id``,
            ``instance_number``, ``level``, ``x`` and ``y``.

    Returns:
        A new frame with columns ``study_id, series_id, instance_number``
        followed by ``x_L1L2 ... y_L5S1`` and a fresh 0..n-1 index. The
        input frame is not modified.

    Raises:
        KeyError: if ``level`` holds a value other than the five known levels.
    """
    from collections import defaultdict

    level_offsets = {'L1/L2':0, 'L2/L3':2, 'L3/L4':4, 'L4/L5':6, 'L5/S1':8}
    coordinate_columns = [
        'x_L1L2',
        'y_L1L2',
        'x_L2L3',
        'y_L2L3',
        'x_L3L4',
        'y_L3L4',
        'x_L4L5',
        'y_L4L5',
        'x_L5S1',
        'y_L5S1',
    ]

    # One pass over the columns instead of repeated df.iloc[i] row lookups.
    points = defaultdict(list)
    for study, series, level, x, y in zip(
        df['study_id'], df['series_id'], df['level'], df['x'], df['y']
    ):
        points[(study, series, level_offsets[level])].append((x, y))

    # One 10-value vector per (study, series); levels without points stay NaN.
    centers = {}
    for (study, series, offset), pts in points.items():
        vector = centers.setdefault((study, series), np.full(10, np.nan))
        vector[offset:offset + 2] = np.mean(pts, axis=0)

    coordinates = np.full((len(df), 10), np.nan)
    for i, key in enumerate(zip(df['study_id'], df['series_id'])):
        coordinates[i] = centers[key]

    # Build a fresh frame rather than assigning into a slice of the caller's frame.
    result = pd.concat(
        [
            df[['study_id', 'series_id', 'instance_number']].reset_index(drop=True),
            pd.DataFrame(coordinates, columns=coordinate_columns),
        ],
        axis=1,
    )
    result = result.drop_duplicates().reset_index(drop=True)
    complete = result[coordinate_columns].notna().all(axis=1)
    return result[complete].reset_index(drop=True)

In [ ]:
# Widen RF from one row per level to one row per instance with all ten
# per-level centre coordinates.
RF = merge_centers_with_dataset_for_na_values(RF)
# SCS is later indexed with the same per-level coordinate columns (`coor`),
# so it needs the same widening; without it those lookups raise KeyError.
SCS = merge_centers_with_dataset_for_na_values(SCS)

In [ ]:
def merge_coordinates_with_main_dataset(df:pd.core.frame.DataFrame, train:pd.core.frame.DataFrame, condition:str):
    """Join the label columns matching *condition* onto a coordinate table.

    Every column of *train* whose name contains *condition* is merged onto
    *df* through ``study_id`` (inner join). The merged frame is returned.
    """
    label_columns = [name for name in train.columns if condition in name]
    labelled = train[['study_id'] + label_columns]
    return df.merge(labelled, left_on='study_id', right_on='study_id')

In [ ]:
# Attach the matching label columns to each coordinate table (joined on study_id).
LF = merge_coordinates_with_main_dataset(LF,train_foraminal,'left_neural')
RF = merge_coordinates_with_main_dataset(RF,train_foraminal,'right_neural')
SCS = merge_coordinates_with_main_dataset(SCS,train_canal,'canal')

In [ ]:
def get_diagnosis_list_and_renamed_columns(df:pd.core.frame.DataFrame, orientation:str):
    """Strip the side prefix from the diagnosis column names.

    Every column whose name contains *orientation* loses its first
    ``len(orientation) + 1`` characters (the prefix plus one separator
    character). A renamed copy is returned; *df* itself is left unchanged.
    """
    prefix_length = len(orientation) + 1
    renames = {}
    for name in df.columns:
        if orientation in name:
            renames[name] = name[prefix_length:]
    return df.rename(columns=renames)
    

In [ ]:
# Drop the side prefix so that LF and RF share identical column names and can be stacked.
LF = get_diagnosis_list_and_renamed_columns(LF,'left')
RF = get_diagnosis_list_and_renamed_columns(RF,'right')

In [ ]:
# Stack left and right foraminal rows into one table.
FDF = pd.concat([LF,RF],axis=0,ignore_index=True)
# Append `series_description` as the last column; the fold assignment below
# overwrites it, and the label columns are later selected as columns[-6:-1].
FDF = FDF.merge(df_meta_f[['series_id','series_description']], left_on='series_id', right_on='series_id')
FDF.head()

In [ ]:
def define_cross_validation_index(df:pd.core.frame.DataFrame):
    """Assign every study to one of ``CV`` folds, in place.

    The sorted unique ``study_id`` values are cut into ``CV`` contiguous
    chunks of (nearly) equal size. The fold number (1..CV) is written into
    the ``series_description`` column, which is created if missing and
    overwritten otherwise, so all rows of a study land in the same fold.

    Side effect: plots the number of rows per study with matplotlib.

    Returns:
        The same frame object that was passed in.
    """
    v,c = np.unique(df['study_id'],return_counts=True)
    plt.plot(v,c,'.')
    L = len(v)
    S = L/CV
    # CV + 1 boundaries into the sorted study ids.
    fold_indices = list(np.rint(np.arange(CV)*S).astype(int))+[L]
    # Iterate over CV folds (previously a literal 5, which disagreed with
    # the boundaries above whenever CV != 5).
    for i in range(CV):
        df.loc[df['study_id'].isin(v[fold_indices[i]:fold_indices[i+1]]),'series_description'] = i+1
    return df

In [ ]:
# Write the fold number of every study into the `series_description` column.
FDF = define_cross_validation_index(FDF)
SCS = define_cross_validation_index(SCS)

In [ ]:
# Maps the label strings found in the diagnosis columns to integer class ids.
labels = {
    'Normal/Mild':0,
    'Moderate':1,
    'Severe':2
}


In [ ]:
# Names of the ten per-level centre columns, ordered (x, y) for L1/L2 ... L5/S1.
# The datasets read them in this order and reshape the values to (5, 2).
coor = [
    'x_L1L2',
    'y_L1L2',
    'x_L2L3',
    'y_L2L3',
    'x_L3L4',
    'y_L3L4',
    'x_L4L5',
    'y_L4L5',
    'x_L5S1',
    'y_L5S1',    
]

<h3> Segmentation Dataset </h3>

In [ ]:
def augment_image_and_centers(image,centers,alpha):
    """Rotate an image and its landmark centres by one random angle.

    The angle is drawn uniformly from [-180, 180] degrees and scaled by
    *alpha*. The image is rotated with torchvision; the centres are moved to
    the CPU and rotated about ``PATCH_SIZE//2`` with the matching matrix.

    Returns:
        ``(rotated_image, rotated_centers)``.
    """
    # Angle in degrees, kept as a tensor so the trigonometry below stays in torch.
    degrees = torch.as_tensor(random.uniform(-180, 180)*alpha)
    rotated = torchvision.transforms.functional.rotate(image, degrees.item())
    # https://discuss.pytorch.org/t/rotation-matrix/128260
    radians = -degrees*math.pi/180
    sin_a = torch.sin(radians)
    cos_a = torch.cos(radians)
    rot = torch.stack([cos_a, sin_a, -sin_a, cos_a]).view(2, 2)
    half = PATCH_SIZE//2
    moved = ((centers.cpu() - half) @ rot) + half

    return rotated, moved


In [ ]:
class T1Dataset(Dataset):
    """Dataset yielding one square, resized slice with labels and centres.

    Each item is ``(image, [label, centers])``:

    * ``image``   - tensor of shape (1, PATCH_SIZE, PATCH_SIZE), scaled by its maximum;
    * ``label``   - integer class ids of the five label columns (``columns[-6:-1]``);
    * ``centers`` - tensor of shape (5, 2) with the level centres in resized-image
      pixel coordinates.

    All returned tensors are already moved to ``device``.

    Args:
        df: table with ``study_id``, ``series_id``, ``instance_number``, the
            ``coor`` columns and the five label columns at positions [-6:-1].
        VALID: when True, use a centred crop and skip the rotation augmentation.
        alpha: augmentation strength; scales the crop jitter and the rotation range.
    """

    def __init__(self, df, VALID=False, alpha=0):
        self.data = df
        self.VALID = VALID
        self.alpha = alpha

    def __len__(self):
        return len(self.data)

    def _crop_offset(self, slack):
        """Return where the square crop starts along the longer axis."""
        if self.VALID:
            return slack//2
        # Jitter around the centred position, scaled by alpha.
        return int(slack*(.5 + self.alpha*(.5 - np.random.rand())))

    def __getitem__(self, index):
        row = self.data.iloc[index]
        target = self.data.columns[-6:-1]
        centers = torch.as_tensor([x for x in row[coor]]).view(5,2).float()

        sample = './datasets/rsna-2024-lumbar-spine-degenerative-classification/train_images/'
        sample1 = sample+str(row['study_id'])+'/'+str(row['series_id'])+'/'+str(row['instance_number'])+'.dcm'

        image = pydicom.dcmread(sample1).pixel_array
        H,W = image.shape
        # Crop to a square first: resizing the raw slice would distort its proportions.
        if H > W:
            h = self._crop_offset(H - W)
            image = image[h:h+W]
            centers[:,1] -= h
            H = W
        elif H < W:
            w = self._crop_offset(W - H)
            image = image[:,w:w+H]
            centers[:,0] -= w
            W = H
        image = cv2.resize(image,(PATCH_SIZE,PATCH_SIZE))
        # Scale by the maximum; an all-zero slice is left as zeros rather than
        # becoming NaN through a division by zero.
        image = image.astype(np.float64)
        peak = np.max(image)
        if peak != 0:
            image = image/peak
        image = torch.as_tensor(image).unsqueeze(0).float()

        label = torch.as_tensor([labels[x] for x in row[target]])

        # Express the centres in resized-image pixels.
        centers[:,0] = centers[:,0]*PATCH_SIZE/W
        centers[:,1] = centers[:,1]*PATCH_SIZE/H

        if not self.VALID: image,centers = augment_image_and_centers(image,centers,self.alpha)

        return image.to(device),[label.to(device),centers.to(device)]

In [ ]:
def seed_everything(seed):
    """Seed Python, NumPy and torch (CPU and CUDA) and make cuDNN deterministic.

    Also exports the seed as the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [ ]:
# Pixel-coordinate grid of shape (1, 1, 2, PATCH_SIZE, PATCH_SIZE):
# channel 0 holds the column index of every pixel, channel 1 the row index.
# Used to turn centre coordinates into heatmaps and heatmaps back into centres.
idx_map = torch.stack([torch.arange(PATCH_SIZE)]*PATCH_SIZE).to(device)
idx_map = torch.stack([idx_map,idx_map.T]).view(1,1,2,PATCH_SIZE,PATCH_SIZE)

 ### Segmentation UNet

In [ ]:
class myUNet(nn.Module):
    """ResNet-34 U-Net producing five min-max normalised heatmaps.

    The wrapped ``smp.Unet`` takes a one-channel image and outputs five
    channels. Each output channel is rescaled to the range [0, 1]
    independently per sample.
    """

    def __init__(self):
        super().__init__()

        self.UNet = smp.Unet(
            encoder_name="resnet34",
            classes=5,
            in_channels=1
        ).to(device)

    def forward(self,X):
        x = self.UNet(X)
        # Min-max scale every channel over its own spatial plane. Reducing over
        # the last two dims removes the hard-coded channel count and image size.
        min_values = x.amin(dim=(-2,-1), keepdim=True)
        max_values = x.amax(dim=(-2,-1), keepdim=True)
        # A constant channel would otherwise divide by zero and emit NaN.
        span = (max_values - min_values).clamp_min(torch.finfo(x.dtype).eps)

        return (x - min_values)/span

### Segmentation Loss

In [ ]:
class myLoss(nn.Module):
    """Heatmap loss: one minus the squared cosine similarity to ideal Gaussians.

    The ideal heatmap of each of the five levels is a normal distribution
    centred on the annotated centre with variance ``PATCH_SIZE/8``. The loss
    compares predicted and ideal heatmaps over the whole batch at once.

    Args:
        alpha: stored on the instance; not used by ``forward``.
    """

    def __init__(
            self,
            alpha=.5
        ):
        super().__init__()
        self.alpha = alpha

    def clone(self):
        # type(self) rather than the global name: `myLoss` is rebound to a
        # plain function further down in this file.
        return type(self)(self.alpha)

    def forward(
            self,
            y,# Predictions
            t # Targets
        ):
        """Return the scalar loss for predictions *y* and targets *t*.

        *t* is a ``(labels, centers)`` pair; only the centres are used.
        """
        mask_pred = y
        _,mask_true = t
        # Variance of the ideal Gaussian, one entry per level.
        s2 = torch.as_tensor([PATCH_SIZE/8]*5)
        # Exponent coefficient and normalisation constant of that Gaussian.
        A = -1/(2*s2).to(device)
        K = 1/torch.sqrt(2*math.pi*s2).to(device)
        # Rescale the predicted heatmaps to the same amplitude.
        mask_pred = mask_pred*K.view(1,5,1,1)
        # Ideal heatmaps: per-pixel offset from each centre, squared and summed over (x, y).
        mask = idx_map - mask_true.view(-1,5,2,1,1)
        mask = torch.exp((A.view(-1,5,1,1,1)*mask*mask).sum(2))*K.view(-1,5,1,1)
        # Distance: 1 - <ideal, pred>^2 / (|ideal|^2 |pred|^2)
        D = 1 - ((mask*mask_pred).sum())**2/((mask*mask).sum()*(mask_pred*mask_pred).sum())

        return D

In [ ]:
def train_segmentation(df:pd.core.frame.DataFrame,condition:str,folds:int,first_fold:int=1):
    """Train and save one segmentation U-Net per cross-validation fold.

    For every fold, the rows whose ``series_description`` equals the fold
    number form the validation set and the rest form the training set. The
    model is fitted with the one-cycle policy and pickled to
    ``'SEG_<condition>_<fold>'`` in the working directory.

    Nothing is trained or saved when the global ``SEG_TRAIN`` is False.

    Args:
        df: table accepted by ``T1Dataset``; the fold number must be stored
            in its ``series_description`` column.
        condition: tag used in the checkpoint file name.
        folds: number of the last fold to train (inclusive).
        first_fold: number of the first fold to train. Defaults to 1 so that
            every fold is trained (the loop used to start at 2, which
            silently skipped fold 1). Pass a larger value to resume an
            interrupted run.
    """

    def nt(nmin,nmax,tcur,tmax):
        # Cosine ramp from nmin (tcur=0) to nmax (tcur=tmax).
        return (nmax - .5*(nmax-nmin)*(1+np.cos(tcur*np.pi/tmax))).astype(np.float32)

    # Show the shape of the augmentation schedule.
    plt.plot(nt(0,1,np.arange(SEG['EPOCHS']),SEG['EPOCHS']))
    plt.show()

    if not SEG_TRAIN:
        return

    # Callback raising the training set's augmentation strength from .25 to 1
    # over the run. `learn` and `n_iter` are bound inside the loop below.
    def cb(self):
        alpha = torch.as_tensor(nt(.25,1,learn.train_iter,SEG['EPOCHS']*n_iter))
        learn.dls.train_ds.alpha = alpha
    alpha_cb = Callback(before_batch=cb)

    for fold in range(first_fold,folds+1):
        tdf = df[df['series_description'] != fold]
        vdf = df[df['series_description'] == fold]

        tds = T1Dataset(tdf)
        vds = T1Dataset(vdf,VALID=True)
        tdl = torch.utils.data.DataLoader(tds, batch_size=SEG['BS'], shuffle=True, drop_last=True)
        vdl = torch.utils.data.DataLoader(vds, batch_size=SEG['BS'], shuffle=False)

        seed_everything(SEED)

        dls = DataLoaders(tdl,vdl)

        # Batches per epoch (drop_last=True), used by the alpha schedule.
        n_iter = len(tds)//SEG['BS']

        model = myUNet()
        learn = Learner(
            dls,
            model,
            lr=SEG['LR'],
            loss_func=myLoss(alpha=0.5),
            cbs=[
                ShowGraphCallback(),
                alpha_cb
            ]
        )
        learn.fit_one_cycle(SEG['EPOCHS'],lr_max=5e-4, wd=0.1)
        torch.save(model,'SEG_'+condition+'_'+str(fold))
        # Release the fold's objects before the next fold allocates its own.
        del tdl,vdl,dls,model,learn
        gc.collect()

In [ ]:
# Train the foraminal segmentation models; checkpoints are saved as 'SEG_foraminal_<fold>'.
train_segmentation(FDF,'foraminal',CV)

In [ ]:
# Train the canal segmentation models on the canal table. This call used to
# pass FDF, so the 'SEG_canal_*' checkpoints were fitted on foraminal data.
# SCS must carry the per-level `coor` columns that T1Dataset reads.
train_segmentation(SCS,'canal',CV)

In [ ]:
# Load the five foraminal segmentation models (index 0 holds fold 1).
# map_location lets checkpoints saved on a GPU load on a CPU-only machine;
# eval() fixes the normalisation layers, since these models are only used for inference.
# NOTE: torch.load unpickles arbitrary objects - only load trusted checkpoint files.
foraminal_seg_models = [
    torch.load(
        './datasets/foraminal-segmentation-models/SEG_foraminal_'+str(i),
        map_location=device
    ).eval()
    for i in range(1,6)
]

In [ ]:
# Load the five canal segmentation models (index 0 holds fold 1).
# map_location lets checkpoints saved on a GPU load on a CPU-only machine;
# eval() fixes the normalisation layers, since these models are only used for inference.
# NOTE: torch.load unpickles arbitrary objects - only load trusted checkpoint files.
canal_seg_models = [
    torch.load(
        './datasets/canal-segmentation-models/SEG_canal_'+str(i),
        map_location=device
    ).eval()
    for i in range(1,6)
]

In [ ]:
# model1 = torch.load('./datasets/foraminal-segmentation-models/SEG_foraminal_1')
# i = np.random.randint(len(vds))
# img,centers = vds.__getitem__(i)
# OUT = model1(img.unsqueeze(0)).cpu().detach()
# centers = centers[1].cpu().long()
# print(i)
# fig, axes1 = plt.subplots(1, 5, figsize=(10,10))
# fig, axes2 = plt.subplots(1, 5, figsize=(10,10))
# for k in range(5):
#     image = img[0].cpu() + OUT[0,k].cpu()
#     c = (OUT[0,k].unsqueeze(0)*idx_map[0,0].cpu()).sum(-1).sum(-1)
#     d = OUT[0,k].sum()
#     c = c/d
#     Y,X = centers.cpu().long()[k]
#     YY,XX = c.long()
#     print(Y,X)
#     print(YY,XX)
#     for y in range(height):
#         for x in range(width):
#             # see if we're close to (x-a)**2 + (y-b)**2 == r**2
#             if abs((x-A)**2 + (y-B)**2 - r**2) < EPSILON**2:
#                 image[x+X-5,y+Y-5] = 0
#     axes1[k].imshow(image)
#     axes2[k].imshow(image[XX-64:XX+64,YY-64:YY+64])
# plt.show()

### Inference Dataset

In [1]:
def augment_image(image,alpha):
    """Rotate *image* by a random angle from [-180, 180] degrees scaled by *alpha*."""
    degrees = torch.as_tensor(random.uniform(-180, 180)*alpha).item()
    return torchvision.transforms.functional.rotate(image, degrees)

In [ ]:


class ViT_T1_Dataset(Dataset):
    """Dataset yielding five per-level crops located by a segmentation model.

    Each item is ``([patches, missing], [label, missing])``:

    * ``patches`` - tensor of shape (5, P, P) cut around the predicted centres;
    * ``missing`` - boolean tensor of shape (5,), True where the segmentation
      model found no pixels for that level;
    * ``label``   - integer class ids of the five label columns (``columns[-6:-1]``).

    All returned tensors are on ``device``.

    Args:
        df: table with ``study_id``, ``series_id``, ``instance_number`` and the
            five label columns at positions [-6:-1].
        UNet: segmentation model used to locate the levels.
        Unet_Models: stored on the instance; not used by ``__getitem__``.
        VALID: when True, skip the rotation augmentation.
        P: side length of the crops in pixels.
        alpha: augmentation strength (scales the rotation range).
    """

    def __init__(self, df, UNet, Unet_Models, VALID=False, P=patch_size, alpha=0):
        self.data = df
        self.UNet = UNet
        self.unet_models = Unet_Models
        self.VALID = VALID
        self.P = P
        self.alpha = alpha

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        # Label columns are taken from this dataset's own table (as in
        # T1Dataset) instead of a module-level `target` defined elsewhere.
        target = self.data.columns[-6:-1]

        sample = './datasets/rsna-2024-lumbar-spine-degenerative-classification/train_images/'
        sample1 = sample+str(row['study_id'])+'/'+str(row['series_id'])+'/'+str(row['instance_number'])+'.dcm'
        image = pydicom.dcmread(sample1).pixel_array
        H,W = image.shape
        # Centre-crop to a square first: resizing the raw slice would distort its proportions.
        if H > W:
            h = (H - W)//2
            image = image[h:h+W]
        elif H < W:
            w = (W - H)//2
            image = image[:,w:w+H]
        image = cv2.resize(image,(PATCH_SIZE,PATCH_SIZE))
        # Scale by the maximum; an all-zero slice stays zero instead of becoming NaN.
        image = image.astype(np.float64)
        peak = np.max(image)
        if peak != 0:
            image = image/peak
        image = torch.as_tensor(image).unsqueeze(0).unsqueeze(0).float().to(device)

        # Test-time augmentation: average the heatmaps over the four 90-degree rotations.
        OUT = 0
        with torch.no_grad():
            for rot in [0,1,2,3]:
                OUT += torch.rot90(self.UNet(torch.rot90(image,rot,dims=[-2, -1])),-rot,dims=[-2, -1])

        # Binarise, then take the centroid of every level's mask.
        OUT = (OUT/4 > 0.25)[0]
        c = (OUT.unsqueeze(1)*idx_map[0]).view(5,2,PATCH_SIZE*PATCH_SIZE).sum(-1).float()
        d = OUT.view(5,PATCH_SIZE*PATCH_SIZE).sum(-1).float()
        m = d > 0
        c[m] = (c[m]/d[m].unsqueeze(-1)).float()
        # Levels without any mask pixel get a placeholder centre.
        c[~m] = self.P
        # Coordinates closer than 64 px to a border are replaced by the mean of the valid ones.
        c[c < 64] = torch.nan
        c[c > 448] = torch.nan
        c_mean = torch.nanmean(c, dim=0)
        # If no level has a valid coordinate the mean itself is NaN; fall back
        # to the image centre so the crops below keep a well-defined shape.
        c_mean = torch.where(torch.isnan(c_mean), torch.full_like(c_mean, PATCH_SIZE//2), c_mean)
        mask = torch.isnan(c)
        c[mask[:,0],0] = c_mean[0]
        c[mask[:,1],1] = c_mean[1]
        c = c.long()
        # Cut one P x P patch per level (rows are indexed by y, columns by x).
        image = torch.stack([image[
            0,
            0,
            xy[1]-self.P//2:xy[1]+self.P-self.P//2,
            xy[0]-self.P//2:xy[0]+self.P-self.P//2
        ] for xy in c])
        if not self.VALID:
            image = augment_image(image,self.alpha)

        label = torch.as_tensor([labels[x] for x in row[target]])

        return [image.to(device),~m.to(device)],[label.to(device),~m.to(device)]


In [ ]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions.

    For an input of shape ``(...)`` the output has shape ``(..., dim)``: the
    first ``dim // 2`` entries are sines and the remaining ones cosines of
    the position multiplied by geometrically decreasing frequencies
    ``M ** (-k / (dim // 2))``.
    """

    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        half_dim = self.dim // 2
        step = math.log(self.M) / half_dim
        frequencies = torch.exp(torch.arange(half_dim, device=x.device) * (-step))
        angles = x.unsqueeze(-1) * frequencies.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

class myViT(nn.Module):
    """Per-level classifier: ResNet-18 patch encoder plus a transformer over the five levels.

    ``forward`` takes ``[patches, mask]`` where ``patches`` holds five
    single-channel crops per sample and ``mask`` is a boolean tensor of shape
    (batch, 5) that is True for levels the transformer should ignore. It
    returns logits of shape (batch * 5, 3).

    Args:
        dim: transformer width. ``forward`` reshapes the encoder output to
            512 features, so the default of 512 is the only value it works with.
        depth: number of transformer encoder layers.
        head_size: width of one attention head (``nhead = dim // head_size``).
        **kwargs: accepted and ignored.
    """

    def __init__(self, dim=512, depth=12, head_size=128, **kwargs):
        super().__init__()
        CNN = torchvision.models.resnet18(weights='DEFAULT')
        stem = CNN.conv1
        # Collapse the pretrained RGB stem to one input channel by summing over the colour axis.
        W = nn.Parameter(stem.weight.sum(1, keepdim=True))
        # Keep the stem's own output width. It used to be `patch_size`, which
        # only matched the pretrained weights while patch_size happened to be 64.
        CNN.conv1 = nn.Conv2d(1, stem.out_channels, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        CNN.conv1.weight = W
        CNN.fc = nn.Identity()
        self.emb = CNN.to(device)
        # Sinusoidal initialisation for the five level positions, stored as a trainable parameter.
        self.pos_enc = nn.Parameter(SinusoidalPosEmb(dim)(torch.arange(5, device=device).unsqueeze(0)))
        self.transformer = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=dim, nhead=dim//head_size, dim_feedforward=4*dim,
                dropout=0.1, activation=nn.GELU(), batch_first=True, norm_first=True, device=device), depth)
        self.proj_out = nn.Linear(dim,3).to(device)

    def forward(self, x):
        x,mask = x
        # Encode every crop independently; the crop size is read from the
        # input instead of the global patch_size.
        x = self.emb(x.view(-1,1,*x.shape[-2:]))
        x = x.view(-1,5,512)
        x = x + self.pos_enc
        x = self.transformer(x,src_key_padding_mask=mask)
        x = self.proj_out(x.view(-1,512))
        return x

### Inference ViT

### Inference Loss

In [ ]:
def myLoss(preds,target):
    """Class-weighted cross-entropy over the levels that are not masked out.

    *target* is a ``(labels, mask)`` pair; entries where ``mask`` is True are
    excluded. The class weights are 1, 2 and 4.
    """
    level_labels, mask = target
    keep = ~mask
    class_weights = torch.as_tensor([1.,2.,4.]).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    return criterion(preds[~mask.view(-1)], level_labels[keep])

### Inference Training

In [ ]:
def nt(nmin,nmax,tcur,tmax):
    """Cosine ramp from *nmin* (at tcur=0) to *nmax* (at tcur=tmax), returned as float32."""
    phase = np.cos(tcur*np.pi/tmax)
    half_range = .5*(nmax-nmin)
    return (nmax - half_range*(1+phase)).astype(np.float32)

# Show the shape of the augmentation schedule over the epochs.
plt.plot(nt(0,1,np.arange(INF['EPOCHS']),INF['EPOCHS']))
plt.show()

# callback to update alpha during training
# Reads the module-level `learn` and `n_iter`, which the training loops below
# rebind for every fold; alpha ramps from .25 to 1 over the whole run.
def cb(self):
    alpha = torch.as_tensor(nt(.25,1,learn.train_iter,INF['EPOCHS']*n_iter))
    learn.dls.train_ds.alpha = alpha
alpha_cb = Callback(before_batch=cb)#

### Foraminal Prediction

In [ ]:
# The five label columns of FDF (everything between the coordinates and the fold column).
# ViT_T1_Dataset reads this module-level name when it builds the labels.
target = FDF.columns[-6:-1]
target

In [ ]:
# Train one foraminal classifier per fold; the fold number is stored in the
# `series_description` column of FDF.
for fold in range(1,6):
    # Reseed first so every fold starts from the same random state.
    seed_everything(SEED)
    tdf = FDF[FDF['series_description'] != fold]
    vdf = FDF[FDF['series_description'] == fold]
    # Crops are located with the segmentation model of the same fold.
    tds = ViT_T1_Dataset(tdf,foraminal_seg_models[fold-1],foraminal_seg_models)
    vds = ViT_T1_Dataset(vdf,foraminal_seg_models[fold-1],foraminal_seg_models,VALID=True)
    tdl = torch.utils.data.DataLoader(tds, batch_size=INF['BS'], shuffle=True, drop_last=True)
    vdl = torch.utils.data.DataLoader(vds, batch_size=INF['BS'], shuffle=False)

    dls = DataLoaders(tdl,vdl)

    # Batches per epoch (drop_last=True); read by the alpha callback.
    n_iter = len(tds)//INF['BS']

    model = myViT()
    # NOTE(review): the literal 1e-4 equals INF['LR'] but is not read from it.
    learn = Learner(
        dls,
        model,
        lr=1e-4,
        loss_func=myLoss,
        cbs=[
            ShowGraphCallback(),
            alpha_cb,
            GradientClip(3.0)
        ]
    )
    learn.fit_one_cycle(INF['EPOCHS'],lr_max=1e-4,wd=INF['WD'])
    torch.save(model,'Final_ViT_foraminal_'+str(fold))

In [ ]:
# `import sklearn` alone does not guarantee that the `metrics` submodule is loaded.
import sklearn.metrics

# Confusion matrix of the most recently trained model on its validation fold
# (only the last fold's `model` and `vdl` are still bound at this point).
y_true = []
y_pred = []
# Make inference mode explicit instead of relying on the state the training left behind.
model.eval()
with torch.no_grad():
    # Inputs and targets carry the same mask; bind it once.
    for [X,mask],[Y,_] in tqdm(vdl):
        y_true.extend(Y[~mask].cpu().tolist())
        y_pred.extend(torch.argmax(model([X,mask]),-1)[~mask.view(-1)].cpu().tolist())

sklearn.metrics.confusion_matrix(y_true, y_pred)

In [ ]:
# Drop the evaluation lists; the canal evaluation further down rebuilds them.
del y_true, y_pred

### Canal Prediction

In [ ]:
# Rebind `target` to the five label columns of SCS for the canal run.
target = SCS.columns[-6:-1]
target

In [ ]:
# Train one canal classifier per fold; the fold number is stored in the
# `series_description` column of SCS.
for fold in range(1,6):
    # Reseed first so every fold starts from the same random state.
    seed_everything(SEED)
    tdf = SCS[SCS['series_description'] != fold]
    vdf = SCS[SCS['series_description'] == fold]
    # Locate the crops with the canal segmentation models. This loop used to
    # pass the foraminal ones, leaving canal_seg_models unused.
    tds = ViT_T1_Dataset(tdf, canal_seg_models[fold - 1], canal_seg_models)
    vds = ViT_T1_Dataset(vdf, canal_seg_models[fold - 1], canal_seg_models, VALID=True)
    tdl = torch.utils.data.DataLoader(tds, batch_size=INF['BS'], shuffle=True, drop_last=True)
    vdl = torch.utils.data.DataLoader(vds, batch_size=INF['BS'], shuffle=False)

    dls = DataLoaders(tdl, vdl)

    # Batches per epoch (drop_last=True); read by the alpha callback.
    n_iter = len(tds) // INF['BS']

    model = myViT()
    learn = Learner(
        dls,
        model,
        lr=1e-4,
        loss_func=myLoss,
        cbs=[
            ShowGraphCallback(),
            alpha_cb,
            GradientClip(3.0)
        ]
    )
    learn.fit_one_cycle(INF['EPOCHS'], lr_max=1e-4, wd=INF['WD'])
    # Save under a canal-specific name; the previous 'Final_ViT_foraminal_'
    # prefix overwrote the foraminal checkpoints of the same fold.
    torch.save(model, 'Final_ViT_canal_' + str(fold))

In [ ]:
# Imported here so this cell does not depend on an earlier cell having loaded it.
import sklearn.metrics

# Confusion matrix of the most recently trained model on its validation fold
# (only the last fold's `model` and `vdl` are still bound at this point).
y_true = []
y_pred = []
# Make inference mode explicit instead of relying on the state the training left behind.
model.eval()
with torch.no_grad():
    # Inputs and targets carry the same mask; bind it once.
    for [X, mask], [Y, _] in tqdm(vdl):
        y_true.extend(Y[~mask].cpu().tolist())
        y_pred.extend(torch.argmax(model([X, mask]), -1)[~mask.view(-1)].cpu().tolist())

sklearn.metrics.confusion_matrix(y_true, y_pred)