In [ ]:
import functools
import gc
import glob
import math
import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydicom
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset
from tqdm import tqdm
# Keep fastai's star import after the explicit imports above (as originally ordered).
from fastai.vision.all import *
!pip install segmentation_models_pytorch
import segmentation_models_pytorch as smp


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [ ]:
# Root directory of the competition data.
rd = './datasets/rsna-2024-lumbar-spine-degenerative-classification'
# Re-select the device, pinning CUDA device index 0 when a GPU is available.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
device

In [ ]:
# One row per test series; the dataset below reads its study_id, series_id and
# series_description columns.
TEST_SERIES_CSV = "./datasets/rsna-2024-lumbar-spine-degenerative-classification/test_series_descriptions.csv"
# Debug alternative: "./datasets/rsna-lsdc-2024-submission-debug-dataset/debug/test_series_descriptions.csv"
df = pd.read_csv(TEST_SERIES_CSV)
df.head()

In [ ]:
# Distinct test studies, in order of first appearance.
study_ids = list(df['study_id'].unique())
# Every column of the submission template after the first is a prediction label.
sample_sub = pd.read_csv("./datasets/rsna-2024-lumbar-spine-degenerative-classification/sample_submission.csv")
LABELS = [column for column in sample_sub.columns[1:]]
LABELS

In [ ]:
# All graded conditions.
CONDITIONS = [
    'spinal_canal_stenosis',
    'left_neural_foraminal_narrowing',
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis',
]

# Condition subsets handled by each model family (slicing yields independent lists).
SPINAL_CANAL_CONDITION = CONDITIONS[0:1]
NEURAL_FORAMINAL_CONDITION = CONDITIONS[1:3]
SUBARTICULAR_CONDITIONS = CONDITIONS[3:5]

# Spinal levels, from l1_l2 down to l5_s1.
LEVELS = [f'l{k}_l{k + 1}' for k in range(1, 5)] + ['l5_s1']

# Side length every slice is resized to before landmark segmentation.
PATCH_SIZE = 512
# Side length of the square crop taken around each detected landmark.
patch_size = 128
# Number of slices each axial series is resampled to.
Lmax = 36

In [ ]:
class myUNet(nn.Module):
    """5-channel U-Net whose outputs are min-max scaled into per-channel heatmaps.

    NOTE(review): the constructor used to call ``super(foraminalunet, self)``, which hints
    that this class was once named ``foraminalunet``. ``torch.load`` resolves pickled
    modules by their saved class name -- confirm checkpoints map onto this class.
    """

    def __init__(self):
        # Zero-argument super(): the old ``super(foraminalunet, self)`` raised NameError.
        super().__init__()
        self.UNet = smp.Unet(
            encoder_name="resnet34",
            encoder_weights=None,
            classes=5,
            in_channels=1,
        ).to(device)

    def forward(self, X):
        """Run the U-Net and scale every (sample, channel) plane to [0, 1].

        Args:
            X: input batch passed unchanged to the U-Net (presumably ``(B, 1, H, W)``).

        Returns:
            Tensor shaped like the U-Net output; each spatial plane is min-max scaled
            independently. A constant plane becomes all zeros instead of NaN.
        """
        x = self.UNet(X)
        # Flatten the spatial dims so any input size works (no PATCH_SIZE assumption).
        flat = x.flatten(2)
        min_values = flat.min(-1).values[..., None, None]
        max_values = flat.max(-1).values[..., None, None]
        # Clamp only guards an exactly-zero range (0/0); other planes are unaffected.
        span = (max_values - min_values).clamp_min(torch.finfo(x.dtype).tiny)
        return (x - min_values) / span

In [ ]:
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of positions.

    Args:
        dim: width of the returned embedding.
        M: base of the geometric frequency progression.
    """

    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        """Embed every element of ``x`` (at least 1-D); output shape is ``(*x.shape, dim)``.

        The first ``dim // 2`` features are sines and the next ``dim // 2`` cosines of
        ``x * M ** (-k / (dim // 2))``. For odd ``dim`` a zero column is appended so the
        width always equals ``dim`` (previously it silently came out one short).
        """
        half_dim = self.dim // 2
        step = math.log(self.M) / half_dim
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * (-step))
        angles = x[..., None] * freqs[None, ...]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2:
            emb = torch.nn.functional.pad(emb, (0, 1))
        return emb

class myViT(nn.Module):
    """Classify the five landmark crops of one slice with a CNN + transformer.

    A ResNet-34 whose first conv is collapsed to one input channel embeds each crop. A
    fixed sinusoidal table marks the crop's position (5 positions). A transformer encoder
    mixes the crops, and a linear head emits 3 logits per crop.

    Args:
        dim: transformer width; must equal the CNN's pooled feature width.
        depth: number of transformer encoder layers.
        head_size: width per attention head (``nhead = dim // head_size``).
    """

    def __init__(self, dim=512, depth=12, head_size=128, **kwargs):
        super().__init__()
        CNN = torchvision.models.resnet34(weights='DEFAULT')
        # Sum the pretrained RGB filters into a single grayscale filter bank.
        W = nn.Parameter(CNN.conv1.weight.sum(1, keepdim=True))
        # out_channels comes from the pretrained weight (the global ``patch_size`` used
        # before was unrelated and only "worked" because the weight was replaced below).
        CNN.conv1 = nn.Conv2d(1, W.shape[0], kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        CNN.conv1.weight = W
        CNN.fc = nn.Identity()
        self.emb = CNN.to(device)
        self.pos_enc = nn.Parameter(SinusoidalPosEmb(dim)(torch.arange(5, device=device).unsqueeze(0)))
        self.transformer = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=dim, nhead=dim//head_size, dim_feedforward=4*dim,
                dropout=0.1, activation=nn.GELU(), batch_first=True, norm_first=True, device=device), depth)
        self.proj_out = nn.Linear(dim, 3).to(device)

    def forward(self, x):
        """Return logits of shape ``(batch * n_crops, n_classes)``.

        Args:
            x: pair ``(crops, mask)``. ``crops`` is reshaped to ``(-1, 1, h, w)`` using its
                last two dims. ``mask`` is the transformer key-padding mask (True = ignored).
        """
        x, mask = x
        x = self.emb(x.reshape(-1, 1, *x.shape[-2:]))
        # Regroup crop features as (batch, n_crops, dim). The sizes come from the positional
        # table, so the hard-coded 5/512 are no longer needed (works for loaded checkpoints too).
        x = x.view(-1, *self.pos_enc.shape[-2:])
        x = x + self.pos_enc
        x = self.transformer(x, src_key_padding_mask=mask)
        return self.proj_out(x.reshape(-1, x.shape[-1]))

In [ ]:
def get_image_for_canal_foraminal(path, unet, patch_size, TH=0.5):
    """Locate five landmarks in one DICOM slice and crop a square patch around each.

    The slice is center-cropped to a square, resized to ``PATCH_SIZE`` and divided by its
    maximum. ``unet`` is then run under 4-fold rotation TTA. Each of the 5 averaged
    heatmaps is thresholded at ``TH``; the centroid of the surviving pixels is taken as
    that landmark's center.

    Args:
        path: path of a DICOM file.
        unet: model mapping ``(1, 1, PATCH_SIZE, PATCH_SIZE)`` to 5 heatmaps
            (assumed to lie in [0, 1] -- TODO confirm for every checkpoint passed here).
        patch_size: side of the square crop around every center.
        TH: threshold applied to the rotation-averaged heatmaps.

    Returns:
        ``[crops, missing]``:
            - ``crops`` has shape ``(5, patch_size, patch_size)``.
            - ``missing`` is a bool tensor of shape ``(5,)``, True where no pixel passed
              the threshold.
    """
    half = patch_size // 2
    # Centers must keep the whole crop inside the resized image (64..448 for 128 in 512).
    lo, hi = half, PATCH_SIZE - (patch_size - half)

    image = pydicom.dcmread(path).pixel_array
    height, width = image.shape
    # Center-crop to a square so resizing does not distort proportions.
    side = min(height, width)
    top, left = (height - side) // 2, (width - side) // 2
    image = image[top:top + side, left:left + side]
    image = cv2.resize(image, (PATCH_SIZE, PATCH_SIZE))
    peak = float(np.max(image))
    if peak == 0:
        peak = 1.0  # blank slice: keep zeros instead of producing 0/0 = NaN
    image = torch.as_tensor(image / peak).unsqueeze(0).unsqueeze(0).float().to(device)

    with torch.no_grad():
        # Average the predictions of the four 90-degree rotations (each rotated back).
        heat = sum(
            torch.rot90(unet(torch.rot90(image, k, dims=[-2, -1])), -k, dims=[-2, -1])
            for k in range(4)
        ) / 4

    hit = (heat > TH)[0]  # (5, PATCH_SIZE, PATCH_SIZE) bool
    cols = torch.arange(PATCH_SIZE, device=device).expand(PATCH_SIZE, PATCH_SIZE)
    idx_map = torch.stack([cols, cols.T])  # [0] = column (x) index, [1] = row (y) index
    sums = (hit.unsqueeze(1) * idx_map).flatten(2).sum(-1).float()  # (5, 2)
    counts = hit.flatten(1).sum(-1).float()                          # (5,)
    found = counts > 0

    centers = torch.zeros_like(sums)
    centers[found] = sums[found] / counts[found].unsqueeze(-1)
    # Undetected landmarks (0) and centers too close to the border become NaN and are
    # replaced by the mean valid center on that axis.
    centers[(centers < lo) | (centers > hi)] = torch.nan
    fill = torch.nanmean(centers, dim=0)
    # If no landmark is usable on an axis, fall back to the image center rather than
    # casting NaN to an arbitrary integer.
    fill = torch.nan_to_num(fill, nan=PATCH_SIZE / 2)
    centers = torch.where(torch.isnan(centers), fill.expand_as(centers), centers).long()

    crops = torch.stack([
        image[0, 0, y - half:y - half + patch_size, x - half:x - half + patch_size]
        for x, y in centers.tolist()
    ])
    return [crops.to(device), ~found.to(device)]


In [ ]:
class AxialUnet(nn.Module):
    """2-channel U-Net whose outputs are min-max scaled into per-channel heatmaps.

    NOTE(review): ``encoder_weights`` is not passed, so smp's default pretrained encoder
    weights are requested at construction time (not when unpickling with ``torch.load``).
    """

    def __init__(self):
        super().__init__()
        self.UNet = smp.Unet(
            encoder_name="resnet34",
            classes=2,
            in_channels=1,
        ).to(device)

    def forward(self, X):
        """Run the U-Net and scale every (sample, channel) plane to [0, 1].

        Args:
            X: input batch passed unchanged to the U-Net (presumably ``(B, 1, H, W)``).

        Returns:
            Tensor shaped like the U-Net output; each spatial plane is min-max scaled
            independently. A constant plane becomes all zeros instead of NaN.
        """
        x = self.UNet(X)
        # Flatten the spatial dims so any input size works (no PATCH_SIZE assumption).
        flat = x.flatten(2)
        min_values = flat.min(-1).values[..., None, None]
        max_values = flat.max(-1).values[..., None, None]
        # Clamp only guards an exactly-zero range (0/0); other planes are unaffected.
        span = (max_values - min_values).clamp_min(torch.finfo(x.dtype).tiny)
        return (x - min_values) / span


In [ ]:
def input_image_resizing(image, jitter=0.0):
    """Crop a 2-D slice to a square, resize it to ``PATCH_SIZE`` and scale by its maximum.

    Args:
        image: 2-D array (e.g. a DICOM ``pixel_array``).
        jitter: random shift of the crop along the long axis.
            - ``0`` (default) gives a deterministic center crop. This matches the center
              crop used when patches are later cut from the series, so predicted
              coordinates line up.
            - ``0.5`` reproduces the former behavior: the offset is drawn uniformly from
              25%-75% of the slack.

    Returns:
        Float tensor of shape ``(1, PATCH_SIZE, PATCH_SIZE)``.
    """
    height, width = image.shape
    side = min(height, width)
    if jitter == 0 or height == width:
        frac = 0.5
    else:
        frac = 0.5 + jitter * (0.5 - np.random.rand())
    top = int((height - side) * frac)
    left = int((width - side) * frac)
    image = image[top:top + side, left:left + side]
    image = cv2.resize(image, (PATCH_SIZE, PATCH_SIZE))
    peak = float(np.max(image))
    if peak == 0:
        peak = 1.0  # blank slice: keep zeros instead of producing 0/0 = NaN
    return torch.as_tensor(image / peak).unsqueeze(0).float()

def create_coords(unet_model, images):
    """Estimate two heatmap centroids on every slice of an axial series.

    Args:
        unet_model: model mapping ``(1, 1, PATCH_SIZE, PATCH_SIZE)`` to 2 heatmaps.
        images: DICOM paths whose file names are ``<instance_number>.dcm``.

    Returns:
        ``(coords, coords_mean)``:
            - ``coords`` interleaves, per slice in instance-number order, the
              heatmap-weighted centroid ``[x, y]`` of channel 0 and then channel 1. Values
              are float tensors in PATCH_SIZE pixel units, NaN when a heatmap sums to 0.
            - ``coords_mean`` is the per-slice average of those two centroids.
    """
    ordered = sorted(images, key=lambda p: int(p.split('/')[-1].replace('.dcm', '')))
    # Built once on the CPU, where the heatmaps are reduced.
    cols = torch.arange(PATCH_SIZE).expand(PATCH_SIZE, PATCH_SIZE)
    idx_map = torch.stack([cols, cols.T])  # [0] = column (x) index, [1] = row (y) index

    coords = []
    with torch.no_grad():  # inference only: do not build an autograd graph
        for path in ordered:
            img = input_image_resizing(pydicom.dcmread(path).pixel_array)
            heat = unet_model(img.to(device).unsqueeze(0)).cpu()
            for k in range(2):
                plane = heat[0, k]
                centroid = (plane * idx_map).sum(dim=(-2, -1)) / plane.sum()
                coords.append(centroid.float())

    coords_mean = [torch.stack(pair).mean(0) for pair in zip(coords[0::2], coords[1::2])]
    return coords, coords_mean

In [ ]:
def _resample_slice_indices(num_slices, target):
    """Return ``target`` ordered indices into ``range(num_slices)`` spread over the series.

    Longer series are subsampled. Shorter ones have every slice repeated
    ``target // num_slices + 1`` times before the same even subsampling.
    """
    if num_slices > target:
        return torch.round(torch.arange(target) * num_slices / target).long()
    if num_slices < target:
        repeat = target // num_slices + 1
        repeated = torch.repeat_interleave(torch.arange(num_slices), repeat)
        return repeated[torch.round(torch.arange(target) * (num_slices * repeat) / target).long()]
    return torch.arange(num_slices)


def get_images_for_axial_slice_labelling(coord, images):
    """Build a fixed-length stack of axial crops centered on per-slice landmark positions.

    Args:
        coord: one ``[x, y]`` tensor per slice, in the instance-number order of ``images``
            and in PATCH_SIZE pixel units (NaN allowed).
        images: DICOM paths whose file names are ``<instance_number>.dcm``.

    Returns:
        Tensor of shape ``(Lmax, 1, patch_size, patch_size)`` on ``device``.
    """
    resize = torchvision.transforms.Resize((PATCH_SIZE, PATCH_SIZE), antialias=True)
    half = patch_size // 2
    # Valid centers keep the whole crop inside the resized image (64..448 for 128 in 512).
    lo, hi = half, PATCH_SIZE - (patch_size - half)

    c = torch.stack(coord)
    ordered = sorted(images, key=lambda p: int(p.split('/')[-1].replace('.dcm', '')))
    c[(c < lo) | (c > hi)] = torch.nan

    # Resample the series to exactly Lmax slices (paths and centers together).
    idx = _resample_slice_indices(len(ordered), Lmax)
    ordered = [ordered[i] for i in idx.tolist()]
    c = c[idx]

    pixels = [torch.as_tensor(pydicom.dcmread(p).pixel_array.astype('float32')) for p in ordered]
    H = max(img.shape[-2] for img in pixels)
    W = max(img.shape[-1] for img in pixels)

    # Reflect-pad every slice symmetrically up to the largest slice shape.
    padded = []
    for img in pixels:
        dh, dw = H - img.shape[-2], W - img.shape[-1]
        padded.append(torch.nn.functional.pad(
            img.unsqueeze(0), (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2), mode='reflect'))
    image = torch.concat(padded).float()

    # Center-crop to a square so resizing does not distort proportions.
    side = min(H, W)
    top, left = (H - side) // 2, (W - side) // 2
    image = image[:, top:top + side, left:left + side]

    peak = image.max()
    if peak != 0:  # an all-zero series stays zero instead of becoming NaN
        image = image / peak
    image = resize(image).float().unsqueeze(1).to(device)

    # Replace invalid centers with the mean valid center per axis; fall back to the image
    # center if an axis has no valid value at all.
    fill = torch.nan_to_num(torch.nanmean(c, 0), nan=PATCH_SIZE / 2)
    c = torch.where(torch.isnan(c), fill.expand_as(c), c).long()

    crops = torch.stack([
        image[i, :, y - half:y - half + patch_size, x - half:x - half + patch_size]
        for i, (x, y) in enumerate(c.tolist())
    ])
    return crops.to(device)

In [ ]:
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of positions.

    Args:
        dim: width of the returned embedding.
        M: base of the geometric frequency progression.
    """

    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        """Embed every element of ``x`` (at least 1-D); output shape is ``(*x.shape, dim)``.

        The first ``dim // 2`` features are sines and the next ``dim // 2`` cosines of
        ``x * M ** (-k / (dim // 2))``. For odd ``dim`` a zero column is appended so the
        width always equals ``dim``.
        """
        half_dim = self.dim // 2
        step = math.log(self.M) / half_dim
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * (-step))
        angles = x[..., None] * freqs[None, ...]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2:
            emb = torch.nn.functional.pad(emb, (0, 1))
        return emb

class Axial_ViT(nn.Module):
    """Score every slice of an axial series for 5 outputs with a CNN + transformer.

    Args:
        ENCODER: module returning a sequence of feature maps; only the last is used
            (presumably an smp encoder -- TODO confirm).
        dim: channel width of that last feature map and of the transformer.
        depth: number of transformer encoder layers.
        head_size: width per attention head (``nhead = dim // head_size``).
    """

    def __init__(self, ENCODER, dim=512, depth=24, head_size=64, **kwargs):
        super().__init__()
        self.ENCODER = ENCODER
        self.AvgPool = nn.AdaptiveAvgPool2d(output_size=1).to(device)
        self.slices_enc = nn.Parameter(SinusoidalPosEmb(dim)(torch.arange(Lmax, device=device).unsqueeze(0)))
        self.slices_transformer = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=dim, nhead=dim//head_size, dim_feedforward=4*dim,
                dropout=0.1, activation=nn.GELU(), batch_first=True, norm_first=True, device=device), depth)
        self.proj_out = nn.Linear(dim, 5).to(device)

    def forward(self, x):
        """Return scores of shape ``(batch, n_outputs, n_slices)``; each row is min-max scaled.

        Args:
            x: slice stack reshaped to ``(-1, 1, h, w)`` from its last two dims; the number
                of slices per series must equal the positional table's length.
        """
        n_slices, dim = self.slices_enc.shape[-2:]
        feats = self.ENCODER(x.reshape(-1, 1, *x.shape[-2:]))[-1]
        # Sizes come from the positional table instead of the hard-coded Lmax/512, so this
        # also works for loaded checkpoints of other widths.
        feats = self.AvgPool(feats).view(-1, n_slices, dim)
        feats = self.slices_transformer(feats + self.slices_enc)
        scores = self.proj_out(feats).permute(0, 2, 1)
        # Min-max scale along the slice axis to get a heatmap per output; the clamp only
        # guards an exactly-constant row (0/0).
        lo = scores.amin(-1, keepdim=True)
        hi = scores.amax(-1, keepdim=True)
        return (scores - lo) / (hi - lo).clamp_min(torch.finfo(scores.dtype).tiny)

In [ ]:
Lmax=36
indices = np.arange(Lmax)

# Checkpoints whose parameters are averaged into the axial slice-level classifier.
AXIAL_LEVEL_CLASSIFIER_PATHS = tuple(
    './datasets/final-axial-slice-calssifier/New_Axial_classifier_' + str(i) for i in range(1, 6)
)


@functools.lru_cache(maxsize=None)
def _load_averaged_model(paths):
    """Load the checkpoints in ``paths`` once and return a model with averaged parameters.

    The first checkpoint is used as the container; each of its parameters is replaced by
    the element-wise mean over all checkpoints. The previous code double-counted the first
    model and used wrong divisors. The result is cached per ``paths`` tuple, so the files
    are read only once per session.

    NOTE(review): only parameters are averaged; buffers (e.g. normalization running
    statistics, if any) come from ``paths[0]``.
    """
    merged = torch.load(paths[0])
    other_states = [torch.load(p).state_dict() for p in paths[1:]]
    with torch.no_grad():
        for name, param in merged.named_parameters():
            for state in other_states:
                param.add_(state[name])
            param.div_(len(paths))
    merged.eval()
    return merged


def get_axial_slices(X):
    """Pick, for each of the classifier's 5 outputs, the highest-scoring axial slice.

    Args:
        X: slice stack that reshapes to ``(-1, Lmax, h, w)``
            (e.g. ``(Lmax, 1, h, w)`` for a single series).

    Returns:
        Nested list ``[[i_0, ..., i_4]]`` (one inner list per series) of slice indices.
    """
    model = _load_averaged_model(AXIAL_LEVEL_CLASSIFIER_PATHS)
    # Put slices on axis 1 so the flip TTA really reverses slice order. For an
    # (Lmax, 1, h, w) input the old ``X.flip(1)`` flipped the size-1 channel axis (a no-op),
    # so ``.flip(-1)`` merely mirrored the scores.
    series = X.reshape(-1, Lmax, *X.shape[-2:])
    with torch.no_grad():
        scores = model(series) + model(series.flip(1)).flip(-1)
    return torch.argmax(scores, -1).cpu().tolist()

In [ ]:
class RSNA24Dataset(Dataset):
    """Test-time dataset producing model inputs for one series per item.

    Each item is ``[inputs, study_id, kind]``, where ``kind`` is:
        - 0 for "Sagittal T1" (foraminal crops),
        - 1 for "Sagittal T2/STIR" (canal crops),
        - 2 for any other description (axial subarticular crops).

    Args:
        df: rows with ``study_id``, ``series_id`` and ``series_description`` columns.
        foraminal_unet, canal_unet, subarticular_unet: landmark models per series kind.
        patch_size: side of the crop taken around every landmark.
        image_root: directory containing ``<study_id>/<series_id>/<instance>.dcm``.
    """

    def __init__(self, df, foraminal_unet, canal_unet, subarticular_unet, patch_size=128,
                 image_root='./datasets/rsna-2024-lumbar-spine-degenerative-classification/test_images/'):
        self.df = df
        self.foraminal_unet = foraminal_unet
        self.canal_unet = canal_unet
        self.subarticular_unet = subarticular_unet
        self.P = patch_size
        # Debug alternative: './datasets/rsna-lsdc-2024-submission-debug-dataset/debug/test_images/'
        self.image_root = image_root

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        study_id = row['study_id']
        series_id = row['series_id']
        series_description = row['series_description']
        images = self.load_images(study_id, series_id)

        if series_description == "Sagittal T1":
            # Slices are visited in instance-number order using the files that actually
            # exist (no assumption that instances are numbered 1..N). The first
            # min(n // 2, 10) slices are tagged 0 and the rest 1.
            split = min(len(images) // 2, 10)
            slices = [
                get_image_for_canal_foraminal(path, self.foraminal_unet, self.P, 0.25)
                + [torch.tensor(0 if i < split else 1)]
                for i, path in enumerate(self._sorted_by_instance(images))
            ]
            return [slices, torch.tensor(study_id), torch.tensor(0)]

        if series_description == "Sagittal T2/STIR":
            slices = [
                get_image_for_canal_foraminal(path, self.canal_unet, self.P, 0.25)
                for path in self._sorted_by_instance(images)
            ]
            return [slices, torch.tensor(study_id), torch.tensor(1)]

        # Any other description is handled as an axial series: even coords -> left,
        # odd coords -> right.
        coords, _ = create_coords(self.subarticular_unet, images)
        left_images = get_images_for_axial_slice_labelling(coords[0::2], images)
        right_images = get_images_for_axial_slice_labelling(coords[1::2], images)
        left_images = left_images[get_axial_slices(left_images)]
        right_images = right_images[get_axial_slices(right_images)]
        return [torch.stack([left_images, right_images]), torch.tensor(study_id), torch.tensor(2)]

    @staticmethod
    def _sorted_by_instance(paths):
        """Sort DICOM paths named ``<instance_number>.dcm`` numerically."""
        return sorted(paths, key=lambda p: int(p.split('/')[-1].replace('.dcm', '')))

    def load_images(self, study_id, series_id):
        """Return the series' ``*dcm`` paths (forward slashes, in glob order)."""
        series_dir = f"{self.image_root}{int(study_id)}/{int(series_id)}/"
        return [p.replace('\\', '/') for p in glob.glob(series_dir + '*dcm')]

In [ ]:
# Accumulators filled by the inference loop: submission row ids and probability rows.
row_names = []
final_preds = []
# Maps the series-kind tag yielded by the dataset to the condition family it predicts.
conditions = dict(enumerate([
    "Neural Foraminal Stenosis",
    "Canal Stenosis",
    "Subarticular Stenosis",
]))

In [ ]:
funet = torch.load("./datasets/rsna-models/SEG_foraminal_1")
cunet = torch.load("./datasets/rsna-models/SEG_canal_1")
subunet = torch.load("./datasets/axial-models/SEG_subarticular_resnet34_1")

foraminalClassifier = torch.load("./datasets/rsna-models/ViT_1")
canalClassifierModels = [torch.load("./datasets/canal-classifier-models/Final_ViT_canal_"+str(i)) for i in [1,3,4,5]]
subarticularClassifier = torch.load("./datasets/axial-models/VIT_Axial_Classifier_1")

# torch.load restores whole modules in whatever train/eval mode they were saved in;
# force inference mode so dropout is disabled during prediction.
for _model in [funet, cunet, subunet, foraminalClassifier, subarticularClassifier, *canalClassifierModels]:
    _model.eval()
del _model  # drop the loop reference so the final `del` really frees the last model


def _mean_or_uniform(total, count):
    """Average ``count`` accumulated probability rows, or return a uniform distribution
    when nothing was accumulated (instead of 0/0 = NaN)."""
    if count == 0:
        return np.full_like(total, 1.0 / total.shape[-1])
    return total / count


testdf = RSNA24Dataset(df, foraminal_unet=funet, canal_unet=cunet, subarticular_unet=subunet)
testdl = torch.utils.data.DataLoader(testdf, batch_size=1, shuffle=False)
with tqdm(testdl, leave=True) as pbar:
    with torch.no_grad():
        for x, stid, series_id in pbar:
            series_desc = conditions[series_id.item()]
            study_id = stid.item()
            if series_desc == "Neural Foraminal Stenosis":
                # Average per-level softmax probabilities over the slices of each side.
                left_y_preds = np.zeros((5, 3))
                right_y_preds = np.zeros((5, 3))
                left_count = right_count = 0
                for i, m, ori in x:
                    preds = foraminalClassifier([i.to(device), m.to(device)])
                    probs = torch.softmax(preds[:5].cpu(), dim=-1).numpy()
                    if ori:  # slices tagged 1 are accumulated as the left side
                        left_y_preds += probs
                        left_count += 1
                    else:
                        right_y_preds += probs
                        right_count += 1
                left_y_preds = _mean_or_uniform(left_y_preds, left_count)
                right_y_preds = _mean_or_uniform(right_y_preds, right_count)
                for condition in NEURAL_FORAMINAL_CONDITION:
                    for level in LEVELS:
                        row_names.append(f"{study_id}_{condition}_{level}")
                final_preds.append(left_y_preds)
                final_preds.append(right_y_preds)
                del left_y_preds, right_y_preds, x
                gc.collect()
            elif series_desc == "Canal Stenosis":
                canal_y_preds = np.zeros((5, 3))
                for i, m in x:
                    i, m = i.to(device), m.to(device)
                    # Ensemble: average the (float64) logits of every canal classifier.
                    preds = sum(model([i, m]).double() for model in canalClassifierModels) / len(canalClassifierModels)
                    canal_y_preds += torch.softmax(preds[:5].cpu(), dim=-1).numpy()
                canal_y_preds = _mean_or_uniform(canal_y_preds, len(x))
                for condition in SPINAL_CANAL_CONDITION:
                    for level in LEVELS:
                        row_names.append(f"{study_id}_{condition}_{level}")
                final_preds.append(canal_y_preds)
                del canal_y_preds, x
                gc.collect()
            else:
                subarticular_y_preds = np.zeros((10, 3))
                x = x.view(-1, 10, patch_size, patch_size).to(device)
                # Rows 0-4 are written under the left condition names, rows 5-9 the right.
                for j in range(10):
                    preds = subarticularClassifier(x[0][j])
                    subarticular_y_preds[j] += np.array(torch.softmax(preds[0].cpu(), dim=0))
                for condition in SUBARTICULAR_CONDITIONS:
                    for level in LEVELS:
                        row_names.append(f"{study_id}_{condition}_{level}")
                final_preds.append(subarticular_y_preds)
                del subarticular_y_preds
                gc.collect()

del funet, cunet, subunet, foraminalClassifier, canalClassifierModels, subarticularClassifier
gc.collect()

In [ ]:
# Stack the per-series probability blocks and label each row with its submission id.
final_preds = np.concatenate(final_preds, axis=0)
sub = pd.DataFrame(final_preds, columns=LABELS)
sub.insert(0, 'row_id', row_names)
sub.head()

In [ ]:
# Average duplicate rows (e.g. several series of the same kind in one study).
sub = sub.groupby('row_id').mean()
# The submission needs a row for every (study, condition, level). Rows whose series was
# missing, or whose averaged probabilities are NaN, get a uniform distribution.
expected_row_ids = [
    f"{sid}_{condition}_{level}"
    for sid in study_ids
    for condition in CONDITIONS
    for level in LEVELS
]
sub = sub.reindex(sub.index.union(expected_row_ids)).rename_axis('row_id')
sub = sub.fillna(1.0 / len(LABELS))
sub = sub.reset_index()
sub.head()

In [ ]:
# NOTE(review): confirm this is the output location the runner expects.
SUBMISSION_PATH = './datasets/submission.csv'
sub.to_csv(SUBMISSION_PATH, index=False)
# Sanity-check the file that was just written (previously a different path,
# './datasets/working/submission.csv', was read back).
pd.read_csv(SUBMISSION_PATH).head()