In [None]:
import gc
import glob
import math
import os
import pickle
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydicom
from tqdm import tqdm
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig, TrainingArguments, Trainer, DefaultDataCollator, get_cosine_schedule_with_warmup, SwinConfig, SwinModel, AutoFeatureExtractor, SwinForImageClassification

import torchvision
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from fastai.vision.all import *
import PIL.Image as Image
!pip install segmentation_models_pytorch
import segmentation_models_pytorch as smp


# Run on the CUDA GPU when one is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [ ]:
# Global hyper-parameters shared by every stage of the notebook.
SEED = 777
CV = 5                          # number of cross-validation folds
FOLDS = list(range(1, CV + 1))  # fold ids are 1-based: [1, 2, 3, 4, 5]
PATCH_SIZE = 512                # side of the square slice fed to the segmentation U-Nets
patch_size = 128                # side of the crop cut around a predicted centre
Lmax = 36                       # number of slices every series is resampled to
BS = 16
EPOCHS = 10

In [ ]:
def seed_everything(seed):
    """Make a run reproducible by seeding every RNG this notebook touches.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and the current CUDA
    device), exports PYTHONHASHSEED and switches cuDNN to deterministic kernels
    with its autotuner disabled.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [ ]:
# Series-level metadata (study_id, series_id, series_description).
df_meta_f = pd.read_csv(
    './datasets/rsna-2024-lumbar-spine-degenerative-classification/'
    'train_series_descriptions.csv'
)
df_meta_f.tail()

In [ ]:
# Annotated points: condition, level, instance_number and (x, y) per study/series.
df_coor = pd.read_csv(
    './datasets/rsna-2024-lumbar-spine-degenerative-classification/'
    'train_label_coordinates.csv'
)
df_coor.tail()

In [ ]:
# Study-level severity labels.
train = pd.read_csv(
    "./datasets/rsna-2024-lumbar-spine-degenerative-classification/"
    "train.csv"
)
train.head()

In [ ]:
# Subarticular-stenosis annotations only (left and right side).
_is_subarticular = df_coor.condition.isin(
    ['Left Subarticular Stenosis', 'Right Subarticular Stenosis']
)
gdf = df_coor.loc[
    _is_subarticular,
    ['study_id', 'series_id', 'instance_number', 'level', 'x', 'y'],
].drop_duplicates()
df = gdf

# One row per annotated slice: mean of the left/right (x, y) points on that slice.
gdf_new = gdf.groupby(
    ['study_id', 'series_id', 'level', 'instance_number'], as_index=False
)[['x', 'y']].mean()

### Localizing ROI for Subarticular Dataset Segmentation


In [ ]:
class AxialSegDataset(Dataset):
    """Axial slices paired with one annotated (x, y) point, for heatmap training.

    Each row of ``df`` must provide ``study_id``, ``series_id``,
    ``instance_number``, ``x`` and ``y``. The DICOM slice is cropped to a
    square along its longer side, resized to PATCH_SIZE x PATCH_SIZE and
    divided by its maximum. The point is mapped into the same frame. Training
    crops are randomly jittered; validation crops are centred.

    Args:
        df: annotation table, one row per sample.
        VALID: if True, use a deterministic centre crop and skip augmentation.
        alpha: crop-jitter / augmentation strength. The training alpha
            callback updates it during training.

    Returns (per item):
        ``(image, centers)`` on ``device``. ``image`` is (1, PATCH_SIZE, PATCH_SIZE)
        float32 and ``centers`` is (1, 2) float32 in (x, y) order.
    """

    def __init__(self, df, VALID=False, alpha=0):
        self.data = df
        self.VALID = VALID
        self.alpha = alpha

    def __len__(self):
        return len(self.data)

    def _crop_offset(self, excess):
        """Start of the square crop along the longer side (``excess`` = long - short)."""
        if self.VALID:
            return excess // 2
        return int(excess * (.5 + self.alpha * (.5 - np.random.rand())))

    def __getitem__(self, index):
        row = self.data.iloc[index]
        centers = torch.as_tensor(row[['x', 'y']].to_numpy(dtype=np.float32)).view(1, 2)

        path = ('./datasets/train_images/' + str(row['study_id']) + '/'
                + str(row['series_id']) + '/' + str(row['instance_number']) + '.dcm')
        image = pydicom.dcmread(path).pixel_array
        H, W = image.shape
        # Crop to a square first so the resize below does not distort proportions.
        if H > W:
            h = self._crop_offset(H - W)
            image = image[h:h + W]
            centers[:, 1] -= h
            H = W
        elif H < W:
            w = self._crop_offset(W - H)
            image = image[:, w:w + H]
            centers[:, 0] -= w
            W = H
        image = cv2.resize(image, (PATCH_SIZE, PATCH_SIZE))
        image = torch.as_tensor(image / np.max(image)).unsqueeze(0).float()

        centers[:, 0] = centers[:, 0] * PATCH_SIZE / W
        centers[:, 1] = centers[:, 1] * PATCH_SIZE / H

        if not self.VALID:
            # NOTE(review): augment_image_and_centers is not defined anywhere in the
            # visible notebook; it must exist before training, or this raises NameError.
            image, centers = augment_image_and_centers(image, centers, self.alpha)

        return image.to(device), centers.to(device)

In [ ]:
# Pixel-coordinate grids of shape (1, 1, 2, PATCH_SIZE, PATCH_SIZE):
# channel 0 holds the column index (x), channel 1 the row index (y).
_axis = torch.arange(PATCH_SIZE).to(device)
idx_map = torch.stack(torch.meshgrid(_axis, _axis, indexing='xy')).view(1, 1, 2, PATCH_SIZE, PATCH_SIZE)
del _axis

In [ ]:
class myLoss(nn.Module):
    """Heatmap-similarity loss for the 2-channel axial U-Net.

    The ideal heatmap is an isotropic 2-D Gaussian (variance PATCH_SIZE/8)
    centred on the target point. Predictions are scaled by the Gaussian's
    normalisation constant. The loss is ``1 - cos^2`` between the ideal and
    predicted heatmaps, summed over the whole batch: 0 when they are
    proportional, 1 when they are orthogonal.

    Args:
        alpha: stored (and copied by ``clone``) but not used by the loss.
    """

    def __init__(
            self,
            alpha=.5
        ):
        super().__init__()
        self.alpha = alpha

    def clone(self):
        # ``type(self)`` rather than the name ``myLoss``: later in the notebook
        # ``myLoss`` is rebound to a cross-entropy function.
        return type(self)(self.alpha)

    def forward(
            self,
            y,  # Predictions: heatmaps broadcastable with (1, 2, 1, 1)
            t   # Targets: (x, y) centres, viewed as (-1, 1, 2, 1, 1)
        ):
        mask_pred = y.to(device)
        mask_true = t.to(device)
        # Gaussian variance, exponent factor and normalisation constant.
        s2 = torch.as_tensor([PATCH_SIZE / 8] * 2).to(device)
        A = -1 / (2 * s2)
        K = 1 / torch.sqrt(2 * math.pi * s2)
        # Predicted heatmaps (values in [0, 1]) rescaled to the ideal peak height K.
        mask_pred = mask_pred * K.view(1, 2, 1, 1)
        # Ideal heatmaps: Gaussian of every pixel's offset from the target (x, y).
        diff = idx_map - mask_true.view(-1, 1, 2, 1, 1)
        mask = torch.exp((A.view(1, 2, 1, 1, 1) * diff * diff).sum(2)) * K.view(1, 2, 1, 1)
        # 1 - cos^2 similarity over the whole batch.
        return 1 - ((mask * mask_pred).sum()) ** 2 / ((mask * mask).sum() * (mask_pred * mask_pred).sum())

In [ ]:
class AxialUnet(nn.Module):
    """ResNet34 U-Net with two output heatmaps, each min-max scaled to [0, 1] per sample.

    Takes a single-channel image batch and returns two channels. The
    min-max reshape assumes the output is PATCH_SIZE x PATCH_SIZE.
    """

    def __init__(self):
        # Was ``super(myUNet, self)``; ``myUNet`` does not exist, so construction
        # raised NameError.
        super().__init__()
        self.UNet = smp.Unet(
            encoder_name="resnet34",
            classes=2,
            in_channels=1
        ).to(device)

    def forward(self, X):
        x = self.UNet(X)
        # Min-max scale every channel over its spatial plane to form a heatmap.
        flat = x.view(-1, 2, PATCH_SIZE * PATCH_SIZE)
        min_values = flat.min(-1)[0].view(-1, 2, 1, 1)
        max_values = flat.max(-1)[0].view(-1, 2, 1, 1)
        return (x - min_values) / (max_values - min_values)

### Segmentation Training

In [ ]:
def nt(nmin, nmax, tcur, tmax):
    """Cosine ramp from ``nmin`` (at tcur=0) up to ``nmax`` (at tcur=tmax), as float32.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    phase = np.cos(np.pi * tcur / tmax)
    value = nmax - .5 * (nmax - nmin) * (1 + phase)
    return value.astype(np.float32)

# Visual check of the cosine ramp: 10 steps from 0 towards 1.
_ramp = nt(0, 1, np.arange(10), 10)
plt.plot(_ramp)
plt.show()

# callback to update alpha during training
# Callback that anneals the training dataset's augmentation strength ``alpha``
# from 0.25 to 1 along the cosine schedule ``nt`` over 10 * n_iter iterations.
# Relies on the module-level ``n_iter`` (batches per epoch) set by the training loop.
def cb(self):
    # ``self`` is the Callback this function is bound to; ``self.learn`` is the
    # Learner running it, so the schedule no longer depends on a global ``learn``.
    alpha = torch.as_tensor(nt(.25, 1, self.learn.train_iter, 10 * n_iter))
    self.learn.dls.train_ds.alpha = alpha
alpha_cb = Callback(before_batch=cb)

# gdf_new carries no fold column of its own: split it by study (contiguous blocks
# of the sorted unique study ids) so no study is in both training and validation.
if 'fold' not in gdf_new.columns:
    _studies = np.unique(gdf_new['study_id'])
    _bounds = np.rint(np.arange(CV + 1) * len(_studies) / CV).astype(int)
    for _k in range(CV):
        _members = gdf_new['study_id'].isin(_studies[_bounds[_k]:_bounds[_k + 1]])
        gdf_new.loc[_members, 'fold'] = _k + 1

# FOLDS is a list of fold ids; the original ``range(FOLDS)`` raised TypeError.
for fold in FOLDS:
    seed_everything(SEED)  # per fold, so each fold is reproducible on its own

    tdf = gdf_new[gdf_new['fold'] != fold]
    vdf = gdf_new[gdf_new['fold'] == fold]

    tds = AxialSegDataset(tdf)
    vds = AxialSegDataset(vdf, VALID=True)
    tdl = torch.utils.data.DataLoader(tds, batch_size=BS, shuffle=True, drop_last=True)
    vdl = torch.utils.data.DataLoader(vds, batch_size=BS, shuffle=False)

    dls = DataLoaders(tdl, vdl)

    n_iter = len(tds) // BS  # batches per epoch, read by the alpha callback

    model = AxialUnet()
    learn = Learner(
        dls,
        model,
        lr=5e-4,
        loss_func=myLoss(alpha=0.5),
        cbs=[
            ShowGraphCallback(),
            alpha_cb
        ]
    )
    learn.fit_one_cycle(10)
    # Spelling matches the later ``SEG_subarticular_`` loads (was "subaritcular").
    torch.save(model, 'SEG_subarticular_' + str(fold))
    del tds, vds, tdl, vdl, dls, model, learn
    gc.collect()

In [ ]:
# Per-fold axial segmentation U-Nets, loaded as an inference ensemble.
# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
# NOTE(review): these are whole pickled modules; recent torch versions default
# to weights_only=True in torch.load and may refuse them -- confirm the torch version.
axial_seg_models = []
for i in FOLDS:
    seg = torch.load('./datasets/axial-segmentation-models/SEG_Axial_' + str(i), map_location=device)
    seg.eval()  # inference only: fixed BatchNorm statistics, no dropout
    axial_seg_models.append(seg)

In [ ]:
# Assign folds 1..CV to contiguous, equal-sized blocks of train.csv rows.
# Positional half-open slices: label-based ``.loc[a:b]`` is inclusive and relied
# on the next iteration overwriting the shared boundary row.
train_len = len(train)
split_len = train_len / CV
split_indices = list(np.rint(np.arange(CV) * split_len).astype(int)) + [train_len]
for i in range(CV):
    train.loc[train.index[split_indices[i]:split_indices[i + 1]], 'fold'] = i + 1


In [ ]:
# First and last annotated axial slice per (study, series, level) for the
# subarticular conditions. A single named aggregation replaces two groupby
# passes whose results were stitched together by index alignment.
df = (
    df_coor[
        df_coor.condition.isin(
            ['Left Subarticular Stenosis', 'Right Subarticular Stenosis']
        )
    ][['study_id', 'series_id', 'instance_number', 'level']]
    .drop_duplicates()
    .groupby(['study_id', 'series_id', 'level'], sort=True)
    .agg(
        instance_number_min=('instance_number', 'min'),
        instance_number_max=('instance_number', 'max'),
    )
    .reset_index()
)

In [ ]:
# Nested lookup: centers[study_id][series_id][level] = [first_slice, last_slice].
centers = {}
for _, rec in df.iterrows():
    by_series = centers.setdefault(rec['study_id'], {})
    by_series.setdefault(rec['series_id'], {})[rec['level']] = [
        rec['instance_number_min'],
        rec['instance_number_max'],
    ]

In [ ]:
# One row per (study_id, series_id), sorted like groupby's default key order.
df = (
    df[['study_id', 'series_id']]
    .drop_duplicates()
    .sort_values(['study_id', 'series_id'])
    .reset_index(drop=True)
)
df.tail()

In [ ]:
# We'll use 0 as missing value
# v[i, 2k] / v[i, 2k+1]: first / last annotated slice of level k (L1/L2 .. L5/S1)
# for series row i of df. 0 marks a level without annotation.
_level_slot = {'L1/L2': 0, 'L2/L3': 2, 'L3/L4': 4, 'L4/L5': 6, 'L5/S1': 8}
v = np.zeros((len(df), 10)).astype(int)
for i, (study_id, series_id) in enumerate(zip(df['study_id'], df['series_id'])):
    for level, (v_min, v_max) in centers[study_id][series_id].items():
        slot = _level_slot[level]
        v[i, slot] = v_min
        v[i, slot + 1] = v_max
        

In [ ]:
# Column names for the first/last annotated slice of each level, in v's column order.
_level_bound_cols = [
    name + suffix
    for name in ('L1L2', 'L2L3', 'L3L4', 'L4L5', 'L5S1')
    for suffix in ('_min', '_max')
]
df[_level_bound_cols] = v
# Series with at least one unannotated level (0 == missing). This is not the last
# expression of the cell, so with Jupyter's default settings it is not displayed.
df[(df[_level_bound_cols] == 0).sum(1) > 0].reset_index(drop=True).tail()

# Every series appears twice: as-is (flip=False) and as a copy with flip=True,
# which the slice dataset reverses along the slice axis.
df['flip'] = False
mirrored = df.copy()
mirrored['flip'] = True
df = pd.concat([df, mirrored]).reset_index(drop=True)
df.tail()

In [ ]:
# Attach each series' cross-validation fold (from train.csv) by study_id.
level_train = pd.merge(df, train[['study_id', 'fold']], on='study_id')
level_train.tail()

### Create Coordinates from Segmentation Model

In [ ]:
# Empty prediction slots: coords[fold][study_id][series_id] = [] for every
# series in level_train; folds are keyed 1..CV.
coords = {fold_id: {} for fold_id in range(1, CV + 1)}
for study_id, series_id, f in zip(level_train['study_id'], level_train['series_id'], level_train['fold']):
    coords[f].setdefault(study_id, {}).setdefault(series_id, [])

In [ ]:
def input_image_resizing(image, alpha=0.0):
    """Crop a 2-D slice to a square, resize it to PATCH_SIZE x PATCH_SIZE and divide by its max.

    By default the crop along the longer side is centred. Previously it was
    always randomly jittered (alpha hard-wired to 0.5), so centres predicted
    on the result lived in a frame shifted by an unknown offset, while they are
    later applied to centre-cropped images. Pass ``alpha=0.5`` to get the old
    random crop.

    Args:
        image: 2-D pixel array (H, W).
        alpha: crop-jitter strength; 0 gives a deterministic centre crop.

    Returns:
        float32 tensor of shape (1, PATCH_SIZE, PATCH_SIZE).
    """
    H, W = image.shape

    def offset(excess):
        # Start of the square crop along the longer side.
        if not alpha:
            return excess // 2
        return int(excess * (.5 + alpha * (.5 - np.random.rand())))

    # Crop to a square first so the resize does not distort proportions.
    if H > W:
        h = offset(H - W)
        image = image[h:h + W]
    elif H < W:
        w = offset(W - H)
        image = image[:, w:w + H]
    image = cv2.resize(image, (PATCH_SIZE, PATCH_SIZE))
    return torch.as_tensor(image / np.max(image)).unsqueeze(0).float()

def create_coords(df, unetModels):
    """Predict one (x, y) centre per axial slice and store it in the global ``coords``.

    For every distinct (fold, study_id, series_id) in ``df`` the series' DICOM
    slices are read in instance-number order. Each slice is prepared with
    ``input_image_resizing`` and passed through every model in ``unetModels``,
    and the outputs are averaged. For each of the two output channels the
    heatmap-weighted mean pixel position is computed, and the two channel
    centroids are averaged into one (x, y) in the PATCH_SIZE x PATCH_SIZE frame.

    Changes from the previous version:
    * one centre per slice; it used to append one per channel (two interleaved
      entries per slice), while the consumer indexes the list by slice;
    * each series is handled once, and its list is replaced rather than
      extended (level_train lists every series twice, flip False/True);
    * inference runs under ``torch.no_grad()`` with the models in eval mode.

    Args:
        df: table with ``fold``, ``study_id`` and ``series_id`` columns.
        unetModels: non-empty sequence of models returning (1, 2, PATCH_SIZE,
            PATCH_SIZE) heatmaps for a (1, 1, PATCH_SIZE, PATCH_SIZE) input.
    """
    image_root = './datasets/train_images/'
    # Pixel-index grids (channel 0: column/x, channel 1: row/y), on the CPU
    # where the averaged outputs are reduced.
    grid = idx_map[0, 0].cpu()
    for net in unetModels:
        net.eval()

    series = df[['fold', 'study_id', 'series_id']].drop_duplicates()
    total = len(series)
    for done, (fold, study_id, series_id) in enumerate(
            zip(series['fold'], series['study_id'], series['series_id']), start=1):
        series_dir = image_root + str(int(study_id)) + '/' + str(int(series_id))
        paths = [p.replace('\\', '/') for p in glob.glob(series_dir + '/*.dcm')]
        paths.sort(key=lambda p: int(p.split('/')[-1].replace('.dcm', '')))

        slice_centres = []
        with torch.no_grad():
            for path in paths:
                img = input_image_resizing(pydicom.dcmread(path).pixel_array)
                img = img.to(device).unsqueeze(0)
                heat = sum(net(img).cpu() for net in unetModels) / len(unetModels)
                heat = heat[0]  # (2 channels, P, P)
                # (channel, coordinate): heatmap-weighted sums of x and y.
                weighted = (heat.unsqueeze(1) * grid.unsqueeze(0)).sum(dim=(-2, -1))
                centroids = weighted / heat.sum(dim=(-2, -1)).unsqueeze(-1)
                slice_centres.append(centroids.mean(0).float())

        coords[fold][study_id][series_id] = slice_centres
        if done % 100 == 0:
            print("Progress: {}/{}".format(done, total))

In [ ]:
# Fill ``coords`` with per-slice centres predicted by the fold ensemble.
create_coords(df=level_train, unetModels=axial_seg_models)

In [ ]:
# Cache the predicted per-slice centres so the slow inference pass can be skipped.
with open('axial_pred_centers.pkl', 'wb') as fh:
    pickle.dump(coords, fh)

In [ ]:
# Reload the cached centres. pickle can execute code: only load files this notebook wrote.
with open('axial_pred_centers.pkl', 'rb') as fh:
    coords = pickle.load(fh)

### Slice Predictor

In [ ]:
class Axial_ViT_Dataset(Dataset):
    """Stacks of ``Lmax`` axial patches per series for the slice-level regressor.

    For each row (one series) the slices are read in instance-number order and
    resampled to exactly ``Lmax``. They are padded to a common size, centre
    cropped to a square, resized to PATCH_SIZE and divided by the stack maximum.
    A ``patch_size`` x ``patch_size`` patch is then cut around each slice's
    predicted centre, read from the module-level ``coords``.

    Returns ``image`` of shape (Lmax, 1, patch_size, patch_size) in INFERENCE
    mode. Otherwise it returns ``(image, label)``, where ``label`` holds, for
    each of the 5 levels, the middle of the annotated slice range in resampled
    slice units, or -1 when unknown. Rows with ``flip`` reverse the slice order
    and the labels.

    Args:
        df: rows with study_id, series_id, fold, flip and (unless INFERENCE)
            the ``*_min`` / ``*_max`` level columns.
        f: fold id (stored; items use the row's own ``fold``).
        VALID: stored but not used here.
        INFERENCE: return images only.
        alpha: accepted for API compatibility; unused.
    """

    def __init__(self, df, f, VALID=False, INFERENCE=False, alpha=0):
        self.data = df
        self.f = f
        self.VALID = VALID
        self.INFERENCE = INFERENCE
        self.resize = torchvision.transforms.Resize((PATCH_SIZE, PATCH_SIZE), antialias=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]

        series_dir = ('./datasets/rsna-2024-lumbar-spine-degenerative-classification/train_images/'
                      + str(int(row['study_id'])) + '/' + str(int(row['series_id'])))
        paths = [x.replace('\\', '/') for x in glob.glob(series_dir + '/*.dcm')]
        paths.sort(key=lambda p: int(p.split('/')[-1].replace('.dcm', '')))
        # Instance numbers of the ORIGINAL slices. Labels are located here and
        # then rescaled by Lmax / D; resampling these too double-scaled labels
        # (and lost instances) whenever D > Lmax.
        instance_numbers = torch.as_tensor([int(p.split('/')[-1].replace('.dcm', '')) for p in paths])
        paths = np.array(paths)
        D = len(paths)

        # Predicted (x, y) per slice (was read from the undefined name ``coord``).
        c = torch.stack(coords[row['fold']][row['study_id']][row['series_id']]).float()
        # Centres too close to the border to fit a whole patch count as missing.
        half = patch_size // 2
        c[c < half] = torch.nan
        c[c > PATCH_SIZE - half] = torch.nan

        # Resample to exactly Lmax slices.
        if D > Lmax:
            slices = torch.round(torch.arange(Lmax) * D / Lmax).long()
        elif D < Lmax:
            N = Lmax // D + 1
            slices = torch.repeat_interleave(torch.arange(D), N)
            slices = slices[torch.round(torch.arange(Lmax) * (D * N) / Lmax).long()]
        else:
            slices = torch.arange(D)
        paths = paths[slices.numpy()]
        c = c[slices]

        frames = [torch.as_tensor(pydicom.dcmread(p).pixel_array.astype('float32')) for p in paths]
        H, W = (int(s) for s in np.array([fr.shape for fr in frames]).max(0))
        # Reflect-pad every slice (centred) to the largest slice size.
        padded = []
        for fr in frames:
            dh, dw = H - fr.shape[-2], W - fr.shape[-1]
            padded.append(torch.nn.functional.pad(
                fr.unsqueeze(0), (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2), mode='reflect'))
        image = torch.concat(padded).float()

        # Centre-crop to a square along the longer side.
        if H > W:
            h = (H - W) // 2
            image = image[:, h:h + W]
        elif H < W:
            w = (W - H) // 2
            image = image[:, :, w:w + H]

        image = self.resize(image / image.max()).float().unsqueeze(1).to(device)

        # Replace missing centres with the series mean, or the image centre if
        # every slice is missing that coordinate.
        missing = torch.isnan(c)
        c_mean = torch.nanmean(c, 0)
        c_mean = torch.where(torch.isnan(c_mean), torch.full_like(c_mean, PATCH_SIZE / 2), c_mean)
        c[missing[:, 0], 0] = c_mean[0]
        c[missing[:, 1], 1] = c_mean[1]
        c = c.long()
        image = torch.stack([
            image[
                i,
                :,
                c[i, 1] - half:c[i, 1] + patch_size - half,
                c[i, 0] - half:c[i, 0] + patch_size - half
            ] for i in range(len(frames))
        ])

        if self.INFERENCE:
            return image.to(device)

        label = -torch.ones(5).float()
        values_min = torch.as_tensor(row[['L1L2_min', 'L2L3_min', 'L3L4_min', 'L4L5_min', 'L5S1_min']].values.astype(int)).view(-1)
        values_max = torch.as_tensor(row[['L1L2_max', 'L2L3_max', 'L3L4_max', 'L4L5_max', 'L5S1_max']].values.astype(int)).view(-1)

        def position(instance):
            """Index of ``instance`` among the original slices, or -1 if absent."""
            hits = torch.nonzero(instance_numbers == instance)
            return int(hits[0, 0]) if len(hits) else -1

        for k in range(5):
            if values_min[k] == 0:
                continue  # level not annotated
            lo, hi = position(values_min[k]), position(values_max[k])
            if lo < 0 or hi < 0:
                continue  # annotated instance not among the files on disk
            # Only real labels are rescaled; scaling the -1 sentinel turned it into
            # a value that the flip below mapped to a positive target.
            label[k] = (lo + hi) / 2 * Lmax / D

        if row['flip']:
            image = image.flip(0)
            known = label >= 0
            label[known] = Lmax - 1 - label[known]
        return image.to(device), label.to(device)



In [ ]:
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding ``[sin(x * w_i), cos(x * w_i)]`` for ``dim // 2`` frequencies.

    ``w_i = M ** (-i / (dim // 2))``. The output gains a trailing axis of size
    ``2 * (dim // 2)``.
    """

    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        n_freq = self.dim // 2
        rate = math.log(self.M) / n_freq
        freqs = torch.exp(torch.arange(n_freq, device=x.device) * (-rate))
        angles = x.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

class Axial_ViT(nn.Module):
    """Slice-sequence regressor: per-slice CNN features -> transformer -> 5 level heatmaps.

    Args:
        ENCODER: image encoder applied to (N, 1, patch_size, patch_size) patches.
            It must return a sequence of feature maps whose last element has
            ``dim`` channels.
        dim: transformer width and size of the pooled per-slice feature.
        depth: number of transformer encoder layers.
        head_size: width of one attention head (``nhead = dim // head_size``).
        **kwargs: accepted and ignored.

    ``forward`` takes ``Lmax`` patches per series and returns (B, 5, Lmax)
    scores, min-max scaled to [0, 1] along the slice axis for each output.
    """

    def __init__(self, ENCODER, dim=512, depth=24, head_size=64, **kwargs):
        super().__init__()
        self.ENCODER = ENCODER
        self.AvgPool = nn.AdaptiveAvgPool2d(output_size=1).to(device)
        # Learnable slice-position embedding, initialised with sinusoids.
        self.slices_enc = nn.Parameter(SinusoidalPosEmb(dim)(torch.arange(Lmax, device=device).unsqueeze(0)))
        self.slices_transformer = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=dim, nhead=dim//head_size, dim_feedforward=4*dim,
                dropout=0.1, activation=nn.GELU(), batch_first=True, norm_first=True, device=device), depth)
        self.proj_out = nn.Linear(dim, 5).to(device)

    def forward(self, x):
        # Width read from proj_out (not a new attribute, so modules pickled by
        # earlier versions still run). It replaces hard-coded 512s that broke
        # any non-default ``dim``.
        dim = self.proj_out.in_features
        x = self.ENCODER(x.view(-1, 1, patch_size, patch_size))[-1]
        x = self.AvgPool(x).view(-1, Lmax, dim)
        x = x + self.slices_enc
        x = self.slices_transformer(x)
        x = self.proj_out(x.reshape(-1, dim)).view(-1, Lmax, 5).permute(0, 2, 1)
        # Min-max scale each output along the slice axis to form a heatmap.
        min_values = x.min(-1)[0].view(-1, 5, 1)
        max_values = x.max(-1)[0].view(-1, 5, 1)
        return (x - min_values) / (max_values - min_values)

In [ ]:
def seed_everything(seed):
    """Make a run reproducible by seeding every RNG this notebook touches.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and the current CUDA
    device), exports PYTHONHASHSEED and switches cuDNN to deterministic kernels
    with its autotuner disabled.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [ ]:
# Slice-index grid 0..Lmax-1 used to build the 1-D target heatmaps.
idx_map1 = torch.arange(0, Lmax, 1, device=device)

In [ ]:
class myLoss1(nn.Module):
    """1-D heatmap-similarity loss for the slice regressor.

    Targets ``t`` hold one slice position per (sample, level); negative values
    mark missing labels and are skipped. The ideal heatmap is a Gaussian over
    the slice axis (variance PATCH_SIZE/16) centred on the target. The loss is
    ``1 - cos^2`` between ideal and predicted heatmaps, averaged (ignoring NaN)
    over the labelled rows.

    Args:
        alpha: stored (and copied by ``clone``) but not used by the loss.
    """

    def __init__(
            self,
            alpha=.5
        ):
        super().__init__()
        self.alpha = alpha

    def clone(self):
        # Previously returned ``myLoss(self.alpha)``: a different loss class, and
        # once ``myLoss`` is rebound to a plain function later in the notebook, a
        # call with the wrong arguments.
        return type(self)(self.alpha)

    def forward(
            self,
            y,  # Predictions: (..., Lmax) heatmaps, indexed like t
            t   # Targets: slice positions, -1 (any negative) for missing
        ):
        available = t >= 0
        mask_true = t[available]
        mask_pred = y[available]
        # Gaussian variance, exponent factor and normalisation constant.
        s2 = torch.as_tensor([PATCH_SIZE / 16]).to(device)
        A = -1 / (2 * s2)
        K = 1 / torch.sqrt(2 * math.pi * s2)
        # Predicted heatmaps rescaled to the ideal peak height K.
        mask_pred = mask_pred * K
        # Ideal heatmaps over the slice index.
        diff = mask_true.view(-1, 1) - idx_map1.view(1, -1)
        mask = torch.exp(A * diff * diff) * K
        # Per-row 1 - cos^2 similarity, averaged over rows.
        D = 1 - ((mask * mask_pred).sum(-1)) ** 2 / ((mask * mask).sum(-1) * (mask_pred * mask_pred).sum(-1))
        return D.nanmean()

In [ ]:
# Train one slice-level regressor per fold, reusing that fold's U-Net encoder.
df = level_train
for f in FOLDS:
    seed_everything(SEED)
    seg_model = torch.load('./datasets/axial-seg-mobilnet/SEG_subarticular_' + str(f), map_location=device)
    model = Axial_ViT(seg_model.UNet.encoder)
    tdf = df[df.fold != f]
    vdf = df[df.fold == f]
    tds = Axial_ViT_Dataset(tdf, f)
    vds = Axial_ViT_Dataset(vdf, f, VALID=True)
    tdl = torch.utils.data.DataLoader(tds, batch_size=BS, shuffle=True, drop_last=True)
    vdl = torch.utils.data.DataLoader(vds, batch_size=BS, shuffle=False)
    dls = DataLoaders(tdl, vdl)

    learn = Learner(
        dls,
        model,
        loss_func=myLoss1(),
        cbs=[
            ShowGraphCallback(),
            GradientClip(3.0)
        ]
    )
    learn.fit_one_cycle(EPOCHS, lr_max=1e-4, wd=0.05, pct_start=0.02)
    torch.save(model, 'axial_T2_levels_' + str(f))
    # ``model`` and ``vdl`` (with ``vds``/``vdf``) are kept: the evaluation cells
    # below use the last fold's model and validation loader. They are rebound,
    # and so freed, on the next iteration.
    del seg_model, tdf, tds, tdl, dls, learn
    gc.collect()

In [None]:
# Validation predictions of the last trained fold: argmax slice per level,
# with slice-order flip as test-time augmentation.
model.eval()  # disable dropout for evaluation
levels = []
y_true = []
with torch.no_grad():
    for X, l in tqdm(vdl):
        scores = model(X) + model(X.flip(1)).flip(-1)
        levels.extend(torch.argmax(scores, -1).cpu().tolist())
        y_true.extend(l.cpu().tolist())

In [ ]:
# Flatten the (sample, level) grids into aligned vectors; scatter predicted vs true slice.
y_true = np.ravel(y_true)
levels = np.ravel(levels)
plt.plot(y_true, levels, '.')

In [ ]:
# Inspect one validation batch: flip-TTA-averaged heatmaps and their argmax slice.
# Fetches the batch directly instead of a tqdm loop that breaks after one step.
model.eval()
X, l = next(iter(vdl))
with torch.no_grad():
    levels = (model(X) + model(X.flip(1)).flip(-1)).cpu() / 2
    x0 = torch.argmax(levels, -1)

In [ ]:
from scipy.optimize import curve_fit

def func(x, x0, b):
    """Unnormalised Gaussian bump ``exp(-b * (x - x0)**2)``; the model fitted by curve_fit."""
    offset = x - x0
    return np.exp(-b * offset * offset)

In [ ]:
# Reload series-level metadata (study_id, series_id, series_description).
df_meta_f = pd.read_csv(
    './datasets/rsna-2024-lumbar-spine-degenerative-classification/'
    'train_series_descriptions.csv'
)
df_meta_f.tail()

In [ ]:
train = pd.read_csv("./datasets/rsna-2024-lumbar-spine-degenerative-classification/train.csv")
CV = 5
# Study-level folds: contiguous blocks of the sorted unique study ids.
v = np.unique(train['study_id'])
L = len(v)
S = L / CV
fold_indices = list(np.rint(np.arange(CV) * S).astype(int)) + [L]
for i in range(CV):
    members = v[fold_indices[i]:fold_indices[i + 1]]
    print(len(members))
    train.loc[train['study_id'].isin(members), 'fold'] = i + 1
# Keep the Axial T2 series of each study; stable sort so order within a fold is deterministic.
axial_t2 = df_meta_f[df_meta_f.series_description == 'Axial T2'][['study_id', 'series_id']]
train = (
    train[['study_id', 'fold']]
    .merge(axial_t2, left_on='study_id', right_on='study_id')
    .sort_values('fold', kind='stable')
    .reset_index(drop=True)
)
train.tail()

In [ ]:
# One row per series (the flip=True copies). .copy() because fitted columns are
# added to this frame later; writing into a filtered view triggers pandas'
# SettingWithCopy behaviour.
train = level_train[level_train['flip'].astype(bool)].copy()
train.head()

In [ ]:
# For every fold model: predict per-level slice heatmaps for all rows of ``train``,
# fit a Gaussian bump (centre x0, width b) to each, and save them per fold.
# NOTE(review): each fold model predicts on ALL rows (not only its own fold),
# and ``train`` keeps the last fold's values afterwards -- confirm this is intended.
indices = np.arange(Lmax)
level_names = ['L1L2', 'L2L3', 'L3L4', 'L4L5', 'L5S1']
x0_cols = [name + '_x0' for name in level_names]
b_cols = [name + '_b' for name in level_names]
for f in FOLDS:
    model = torch.load('./datasets/final-axial-slice-calssifier/New_Axial_classifier_' + str(f), map_location=device)
    model.eval()
    dl = torch.utils.data.DataLoader(Axial_ViT_Dataset(train, f, VALID=True, INFERENCE=True), batch_size=BS, shuffle=False)

    popt = []
    with torch.no_grad():
        for X in tqdm(dl):
            y_pred = (model(X) + model(X.flip(1)).flip(-1)).cpu() / 2
            x0 = torch.argmax(y_pred, -1)

            for i in range(len(y_pred)):
                for k in range(5):
                    try:
                        fit = curve_fit(func, indices, y_pred[i][k].numpy(),
                                        [float(x0[i][k]), 1], maxfev=9999)[0]
                    except (RuntimeError, ValueError):
                        # No convergence, or non-finite heatmap: mark the level unknown
                        # instead of aborting the whole fold.
                        fit = [np.nan, np.nan]
                    popt.extend(fit)

    popt = torch.as_tensor(popt).view(-1, 5, 2).permute(0, 2, 1).reshape(-1, 2 * 5)
    # Discard fits whose centre lies outside the slice range.
    out_of_range = (popt[:, :5] < 0) | (popt[:, :5] > Lmax - 1)
    popt[:, :5][out_of_range] = torch.nan
    popt[:, 5:][out_of_range] = torch.nan
    train.loc[:, x0_cols + b_cols] = np.array(popt)

    train[['study_id', 'fold', 'series_id'] + x0_cols + b_cols].to_csv(
        'Final_axial_T2_levels_' + str(f) + '.csv', index=False)
    del model, dl, popt
    gc.collect()


In [ ]:
# Distribution of the fitted Gaussian widths b over all non-missing levels.
v = train[[lvl + '_b' for lvl in ('L1L2', 'L2L3', 'L3L4', 'L4L5', 'L5S1')]].values
plt.boxplot(v[~np.isnan(v)])

In [ ]:
# Rows where any level's width b exceeds the 99th percentile of all widths.
train[(train[['L1L2_b', 'L2L3_b', 'L3L4_b', 'L4L5_b', 'L5S1_b']] > np.percentile(v[~np.isnan(v)], 99)).any(axis=1)]

In [ ]:
# Same outlier view as the previous cell: any width above the 99th percentile.
_b_threshold = np.percentile(v[~np.isnan(v)], 99)
train[(train[['L1L2_b', 'L2L3_b', 'L3L4_b', 'L4L5_b', 'L5S1_b']] > _b_threshold).any(axis=1)]

In [ ]:
# Keep rows whose widths are all at or below the 99th percentile. A NaN width
# compares False, so rows with any failed or discarded fit are dropped too.
_b_cols = ['L1L2_b', 'L2L3_b', 'L3L4_b', 'L4L5_b', 'L5S1_b']
_keep = (train[_b_cols] <= np.percentile(v[~np.isnan(v)], 99)).all(1)
final_train = train.loc[_keep, ['study_id', 'series_id',
       'fold', 'flip', 'L1L2_x0', 'L2L3_x0', 'L3L4_x0', 'L4L5_x0', 'L5S1_x0',
       'L1L2_b', 'L2L3_b', 'L3L4_b', 'L4L5_b', 'L5S1_b']]

In [ ]:
# Persist the per-series slice predictions.
final_train.to_csv(path_or_buf='final_slice_pred.csv', index=False)

### Training Final Classifier Model

In [ ]:
# Training table for the final classifier; rows with any missing value are dropped.
new_train_df = (
    pd.read_csv("./datasets/data-to-train-on/data_to_train_on.csv")
    .dropna(axis=0)
    .reset_index(drop=True)
)
new_train_df.head()

In [ ]:
# Cross-validation folds for the final classifier, assigned per study (contiguous
# blocks of the sorted unique study ids). The previous row-based split put rows
# of one study in both training and validation, and its inclusive label slices
# relied on overwrites.
_studies = np.unique(new_train_df['study_id'])
_bounds = np.rint(np.arange(CV + 1) * len(_studies) / CV).astype(int)
for i in range(CV):
    _members = new_train_df['study_id'].isin(_studies[_bounds[i]:_bounds[i + 1]])
    new_train_df.loc[_members, 'fold'] = i + 1

In [ ]:
# Severity label -> class index, and the point-coordinate column names.
labels = dict(zip(['Normal/Mild', 'Moderate', 'Severe'], range(3)))
coords_cols = ['x', 'y']

In [ ]:
def augment_image(image, alpha):
    """Rotate ``image`` by a random angle drawn from [-180, 180] degrees, scaled by ``alpha``.

    A leading dimension is added before rotating, so a (C, H, W) input is
    returned as (1, C, H, W).
    """
    degrees = torch.as_tensor(random.uniform(-180, 180) * alpha).item()
    batched = image.unsqueeze(0)
    return torchvision.transforms.functional.rotate(batched, degrees)

In [ ]:
# Name of the label column.
# NOTE(review): relies on it being the second-to-last column -- confirm against the CSV.
target = new_train_df.columns.tolist()[-2]
target

In [ ]:
# Pixel-coordinate grids of shape (1, 1, 2, PATCH_SIZE, PATCH_SIZE):
# channel 0 holds the column index (x), channel 1 the row index (y).
_axis = torch.arange(PATCH_SIZE).to(device)
idx_map = torch.stack(torch.meshgrid(_axis, _axis, indexing='xy')).view(1, 1, 2, PATCH_SIZE, PATCH_SIZE)
del _axis

In [ ]:
class AxialUnet(nn.Module):
    """ResNet34 U-Net with two output heatmaps, each min-max scaled to [0, 1] per sample.

    Takes a single-channel image batch and returns two channels. The
    min-max reshape assumes the output is PATCH_SIZE x PATCH_SIZE.
    """

    def __init__(self):
        # Was ``super(myUNet, self)``; ``myUNet`` does not exist, so construction
        # raised NameError.
        super().__init__()
        self.UNet = smp.Unet(
            encoder_name="resnet34",
            classes=2,
            in_channels=1
        ).to(device)

    def forward(self, X):
        x = self.UNet(X)
        # Min-max scale every channel over its spatial plane to form a heatmap.
        flat = x.view(-1, 2, PATCH_SIZE * PATCH_SIZE)
        min_values = flat.min(-1)[0].view(-1, 2, 1, 1)
        max_values = flat.max(-1)[0].view(-1, 2, 1, 1)
        return (x - min_values) / (max_values - min_values)

In [ ]:
class ViT_T1_Dataset(Dataset):
    """Single-slice severity samples: a patch around the U-Net-predicted centre plus its class.

    For each row the DICOM slice (``study_id``/``series_id``/``instance_number``)
    is centre-cropped to a square, resized to PATCH_SIZE and divided by its
    maximum. ``UNet`` is run with 4-way rot90 test-time augmentation. Pixels
    whose mean output exceeds 0.5 form the masks: channel 0's mask gives the
    centre x and channel 1's mask gives the centre y. A ``P`` x ``P`` patch is
    cut around that centre.

    Args:
        df: rows with the path columns and the label column named by ``target``.
        UNet: segmentation model; the caller should put it in eval mode.
        VALID: if True, skip the random-rotation augmentation.
        P: patch side (defaults to ``patch_size``).
        alpha: rotation strength for ``augment_image``; the alpha callback updates it.

    Returns (per item):
        ``(image, label)`` on ``device``. ``label`` is the class index from ``labels``.
    """

    def __init__(self, df, UNet, VALID=False, P=patch_size, alpha=0):
        self.data = df
        self.UNet = UNet
        self.VALID = VALID
        self.P = P
        self.alpha = alpha

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]

        path = ('./datasets/rsna-2024-lumbar-spine-degenerative-classification/train_images/'
                + str(row['study_id']) + '/' + str(row['series_id']) + '/'
                + str(row['instance_number']) + '.dcm')
        image = pydicom.dcmread(path).pixel_array
        H, W = image.shape
        # Centre-crop to a square so the resize does not distort proportions.
        if H > W:
            h = (H - W) // 2
            image = image[h:h + W]
        elif H < W:
            w = (W - H) // 2
            image = image[:, w:w + H]
        image = cv2.resize(image, (PATCH_SIZE, PATCH_SIZE))
        image = torch.as_tensor(image / np.max(image)).unsqueeze(0).unsqueeze(0).float().to(device)

        # 4-way rot90 test-time augmentation of the U-Net heatmaps.
        with torch.no_grad():
            heat = sum(
                torch.rot90(self.UNet(torch.rot90(image, k, dims=[-2, -1])), -k, dims=[-2, -1])
                for k in range(4)
            )
        mask = (heat / 4 > 0.5)[0]  # (2, PATCH_SIZE, PATCH_SIZE), bool
        # idx_map[0] pairs channel 0 with the x grid and channel 1 with the y grid,
        # so this yields (mean x of mask 0, mean y of mask 1).
        c = (mask.unsqueeze(0) * idx_map[0]).view(1, 2, PATCH_SIZE * PATCH_SIZE).sum(-1).float()
        d = mask.view(2, PATCH_SIZE * PATCH_SIZE).sum(-1).float()
        c = (c.squeeze(0) / d).float().unsqueeze(0)
        c[torch.isnan(c)] = PATCH_SIZE // 2  # empty mask: fall back to the image centre
        c = c.long()
        # Keep the whole patch inside the image; centres near the border used to
        # yield wrong-sized or empty crops.
        half = self.P // 2
        c = c.clamp(half, PATCH_SIZE - (self.P - half))
        image = torch.stack([
            image[0, 0, y - half:y + self.P - half, x - half:x + self.P - half]
            for x, y in c.tolist()
        ])

        if not self.VALID:
            image = augment_image(image, self.alpha)
        label = torch.as_tensor(labels[row[target]])

        return image.to(device), label.to(device)

In [ ]:
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding ``[sin(x * w_i), cos(x * w_i)]`` for ``dim // 2`` frequencies.

    ``w_i = M ** (-i / (dim // 2))``. The output gains a trailing axis of size
    ``2 * (dim // 2)``.
    """

    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        n_freq = self.dim // 2
        rate = math.log(self.M) / n_freq
        freqs = torch.exp(torch.arange(n_freq, device=x.device) * (-rate))
        angles = x.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

class myAxialViT(nn.Module):
    """Severity classifier for one axial patch: ResNet18 features -> transformer -> 3 logits.

    Args:
        dim: transformer width; must equal the ResNet18 feature size (512).
        depth: number of transformer encoder layers.
        head_size: width of one attention head (``nhead = dim // head_size``).
        **kwargs: accepted and ignored.

    ``forward`` takes patches viewable as (N, 1, patch_size, patch_size) and
    returns (N, 3) logits.
    """

    def __init__(self, dim=512, depth=12, head_size=128, **kwargs):
        super().__init__()
        CNN = torchvision.models.resnet18(weights='DEFAULT')
        # Single-channel stem initialised from the pretrained RGB filters summed over colour.
        W = nn.Parameter(CNN.conv1.weight.sum(1, keepdim=True))
        # Out-channels follow the pretrained stem (64 for ResNet18). ``patch_size``
        # (128) was passed here before, which only worked because the weight is
        # overwritten on the next line.
        CNN.conv1 = nn.Conv2d(1, W.shape[0], kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        CNN.conv1.weight = W
        CNN.fc = nn.Identity()
        self.emb = CNN.to(device)
        self.pos_enc = nn.Parameter(SinusoidalPosEmb(dim)(torch.arange(1, device=device).unsqueeze(0)))
        self.transformer = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=dim, nhead=dim//head_size, dim_feedforward=4*dim,
                dropout=0.1, activation=nn.GELU(), batch_first=True, norm_first=True, device=device), depth)
        self.proj_out = nn.Linear(dim, 3).to(device)

    def forward(self, x):
        # Width read from proj_out (no new attribute, so older pickles still run)
        # instead of the hard-coded 512.
        dim = self.proj_out.in_features
        x = self.emb(x.view(-1, 1, patch_size, patch_size))
        x = x.view(-1, 1, dim)
        x = x + self.pos_enc
        x = self.transformer(x)
        return self.proj_out(x.reshape(-1, dim))

In [ ]:
def myLoss(preds, target):
    """Class-weighted cross-entropy: weights 1, 2, 4 for classes 0, 1, 2 (Normal/Mild, Moderate, Severe).

    Note: this rebinds the name ``myLoss`` that earlier cells used for the
    heatmap loss class.
    """
    class_weights = torch.as_tensor([1., 2., 4.]).to(device)
    return nn.functional.cross_entropy(preds, target, weight=class_weights)

In [ ]:
def nt(nmin, nmax, tcur, tmax):
    """Cosine ramp from ``nmin`` (at tcur=0) up to ``nmax`` (at tcur=tmax), as float32.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    phase = np.cos(np.pi * tcur / tmax)
    value = nmax - .5 * (nmax - nmin) * (1 + phase)
    return value.astype(np.float32)

# Visual check of the cosine ramp: 10 steps from 0 towards 1.
_ramp = nt(0, 1, np.arange(10), 10)
plt.plot(_ramp)
plt.show()

# callback to update alpha during training
# Callback that anneals the training dataset's augmentation strength ``alpha``
# from 0.25 to 1 along the cosine schedule ``nt`` over 10 * n_iter iterations.
# Relies on the module-level ``n_iter`` (batches per epoch) set by the training loop.
def cb(self):
    # ``self`` is the Callback this function is bound to; ``self.learn`` is the
    # Learner running it, so the schedule no longer depends on a global ``learn``.
    alpha = torch.as_tensor(nt(.25, 1, self.learn.train_iter, 10 * n_iter))
    self.learn.dls.train_ds.alpha = alpha
alpha_cb = Callback(before_batch=cb)

In [ ]:
# Train the final severity classifier per fold. Each fold's subarticular U-Net
# locates the patch centre inside the datasets.
bs = 32
for fold in FOLDS:
    seed_everything(SEED)
    tdf = new_train_df[new_train_df['fold'] != fold]
    vdf = new_train_df[new_train_df['fold'] == fold]
    UNet = torch.load("./datasets/axial-seg-model/SEG_subarticular_" + str(fold), map_location=device)
    # The U-Net only runs inference inside the datasets. eval() stops BatchNorm
    # from normalising with (and updating running stats from) single-image batches.
    UNet.eval()
    tds = ViT_T1_Dataset(tdf, UNet)
    vds = ViT_T1_Dataset(vdf, UNet, VALID=True)
    tdl = torch.utils.data.DataLoader(tds, batch_size=bs, shuffle=True, drop_last=True)
    vdl = torch.utils.data.DataLoader(vds, batch_size=bs, shuffle=False)

    dls = DataLoaders(tdl, vdl)

    n_iter = len(tds) // bs  # batches per epoch, read by the alpha callback
    model = myAxialViT()

    learn = Learner(
        dls,
        model,
        lr=1e-4,
        loss_func=myLoss,
        cbs=[
            ShowGraphCallback(),
            alpha_cb,
            GradientClip(3.0)
        ]
    )
    learn.fit_one_cycle(10, lr_max=1e-4, wd=0.1)
    torch.save(model, 'VIT_Axial_Classifier_' + str(fold))
    del UNet, tds, vds, tdl, vdl, dls, model, learn
    gc.collect()