In [1]:
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from glob import glob
import os
from tqdm import tqdm
from scipy.ndimage import zoom
import pydicom
from torch.utils.data import Dataset,DataLoader
import albumentations as A
import torch
import torch.nn as nn
import os
import timm
import math

In [2]:
# Pick the compute device once; the model-loading cell moves the models onto it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Root of the competition data and the CSV listing every test series.
comp_dir = "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification"
test_desc_dir = f"{comp_dir}/test_series_descriptions.csv"

In [3]:
# Load the series listing and count how many files each series has on disk.
test_series_descriptions = pd.read_csv(test_desc_dir)
file_nums = [
    len(glob(f"{comp_dir}/test_images/{study}/{series}/**"))
    for study, series in zip(test_series_descriptions['study_id'],
                             test_series_descriptions['series_id'])
]
file_num_df = pd.concat([test_series_descriptions, pd.Series(file_nums)], axis=1)
file_num_df.columns = ['study_id','series_id','series_description','file_len']

# For every (study, description) pair keep only the series with the most files;
# the boolean mask (rather than .loc) preserves the original row order.
largest_series_rows = (
    file_num_df
    .groupby(['study_id','series_description'])['file_len']
    .idxmax()
    .unique()
)
file_num_df = file_num_df[file_num_df.index.isin(largest_series_rows)]

study_ids = file_num_df['study_id'].unique()
series_ids = file_num_df['series_id'].unique()

# Restrict the listing to the retained studies / series.
in_kept_studies = test_series_descriptions['study_id'].isin(study_ids)
in_kept_series = test_series_descriptions['series_id'].isin(series_ids)
test_series_descriptions = test_series_descriptions[in_kept_studies & in_kept_series]

In [4]:
def keyFunc(e):
    """Sort key for image paths: the integer encoded in the file name.

    Args:
        e: Path to a file whose base name (without extension) is an integer,
            e.g. ``".../12.dcm"``.

    Returns:
        The base name converted to ``int``.

    Raises:
        ValueError: If the base name without its extension is not an integer.
    """
    # basename/splitext instead of split('/') and [:-4]: works for any
    # extension length and for the platform's path separator.
    stem, _ext = os.path.splitext(os.path.basename(e))
    return int(stem)

In [5]:
class RSNA24TestDataset(Dataset):
    """Test-time dataset producing one 30-channel image stack per study.

    For each study, 10 slices are taken from each of the 'Axial T2',
    'Sagittal T1' and 'Sagittal T2/STIR' series (in that order) and
    concatenated along the channel axis.

    Args:
        df: DataFrame with ``study_id``, ``series_id`` and
            ``series_description`` columns.
        study_ids: Indexable collection of study ids; one item per study.
        transform: Optional callable invoked as ``transform(image=arr)`` and
            expected to return a mapping with an ``'image'`` entry.
    """

    # (series description, stride divisor) in channel order.
    # NOTE(review): the divisor of 20 for 'Axial T2' means the 10 sampled
    # slices come from the first half of the stack only. Kept unchanged because
    # it presumably mirrors the training preprocessing -- confirm before editing.
    SERIES_SAMPLING = (
        ('Axial T2', 20),
        ('Sagittal T1', 10),
        ('Sagittal T2/STIR', 10),
    )
    N_SLICES = 10    # slices (channels) contributed by each series
    IMG_SIZE = 512   # side length every slice is resized to

    def __init__(self, df, study_ids, transform=None):
        self.df = df
        self.study_ids = study_ids
        self.transform = transform

    def __len__(self):
        """Number of studies in the dataset."""
        return len(self.study_ids)

    def _fetch_images(self, dirs):
        """Read DICOM files and return them as an (H, W, N) uint8 array.

        Each slice is resized to IMG_SIZE x IMG_SIZE and min-max scaled to
        0..255 independently of the other slices.
        Assumes every ``pixel_array`` is a 2-D image -- TODO confirm.
        """
        images = []
        for dcm_file in dirs:
            image = pydicom.dcmread(dcm_file).pixel_array
            # Cubic interpolation when the slice is at most 512 rows tall,
            # area interpolation otherwise.
            if image.shape[0] <= self.IMG_SIZE:
                interpolation = cv2.INTER_CUBIC
            else:
                interpolation = cv2.INTER_AREA
            resized = cv2.resize(image, (self.IMG_SIZE, self.IMG_SIZE), interpolation=interpolation)
            # The epsilon avoids division by zero on constant images.
            resized = (resized - resized.min())/(resized.max()-resized.min() +1e-6) * 255
            images.append(resized)
        return np.transpose(np.stack(images),(1,2,0)).astype(np.uint8)

    def _load_series(self, study_df, description, divisor):
        """Return an (IMG_SIZE, IMG_SIZE, N_SLICES) stack for one series type.

        Args:
            study_df: Rows of ``self.df`` belonging to a single study.
            description: Value of ``series_description`` to load.
            divisor: Stride divisor; slice ``i`` is taken at index
                ``floor(i * num_files / divisor)``.

        A study without this series, or whose series directory holds no
        files, yields an all-zero stack instead of raising, so the remaining
        series can still be used for prediction.
        """
        rows = study_df[study_df['series_description'] == description]
        dirs = []
        if not rows.empty:
            dirs = glob(f"{comp_dir}/test_images/{rows['study_id'].iloc[0]}/{rows['series_id'].iloc[0]}/**")
            dirs.sort(key=keyFunc)

        if not dirs:
            return np.zeros((self.IMG_SIZE, self.IMG_SIZE, self.N_SLICES), dtype=np.uint8)

        num_files = len(dirs)
        if num_files < self.N_SLICES:
            # Too few slices: load all of them and interpolate along depth.
            img = self._fetch_images(dirs)
            return zoom(img, (1, 1, self.N_SLICES/img.shape[2]))

        interval_len = num_files/divisor
        indexes = [int(np.floor(i*interval_len)) for i in range(self.N_SLICES)]
        return self._fetch_images([dirs[index] for index in indexes])

    def __getitem__(self, idx):
        """Return ``(image, study_id_str)`` for the study at position ``idx``.

        ``image`` is the (possibly transformed) stack moved to channel-first
        layout; ``study_id_str`` is ``str(study_id)``.
        """
        study_id = self.study_ids[idx]
        study_df = self.df[self.df["study_id"]==study_id]

        ret_img = np.concatenate(
            [self._load_series(study_df, description, divisor)
             for description, divisor in self.SERIES_SAMPLING],
            axis=2,
        )

        if self.transform is not None:
            ret_img = self.transform(image=ret_img)['image']

        return np.transpose(ret_img, (2,0,1)), str(study_id)

In [6]:
# Test-time preprocessing: resize to 512x512, then normalise with mean 0.5 / std 0.5.
transforms_test = A.Compose(
    [
        A.Resize(height=512, width=512),
        A.Normalize(mean=0.5, std=0.5),
    ]
)

In [7]:
test_dataset = RSNA24TestDataset(test_series_descriptions, study_ids, transform=transforms_test)

# batch_size stays 1: the inference cell reads only element 0 of each batch.
test_dl = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    # os.cpu_count() may return None, which DataLoader rejects; fall back to
    # loading in the main process.
    num_workers=os.cpu_count() or 0,
    # Pinned host memory only helps when batches are copied to a GPU.
    pin_memory=(device == 'cuda'),
    drop_last=False,
)

In [8]:
class RSNAModel(nn.Module):
    """Thin wrapper around a timm ``convnext_small`` network.

    Args:
        in_c: Number of input channels, passed to timm as ``in_chans``.
        n_classes: Number of outputs, passed to timm as ``num_classes``.
    """

    def __init__(self,in_c,n_classes):
        super().__init__()
        # Built without pretrained weights; a state dict is loaded separately.
        encoder_kwargs = dict(in_chans=in_c, num_classes=n_classes, pretrained=False)
        self.encoder = timm.create_model("convnext_small", **encoder_kwargs)

    def forward(self,x):
        """Run ``x`` through the encoder and return its output unchanged."""
        return self.encoder(x)

In [9]:
models = []
# Sorted so the fold order is deterministic across runs.
model_pths = sorted(glob("/kaggle/input/rsna-lsdc-part-2/rsna24-results/best_wll_model_fold-*.pt"))
if not model_pths:
    # With no models the inference loop would silently emit all-zero predictions.
    raise FileNotFoundError("No model weights matched best_wll_model_fold-*.pt")

for path in model_pths:
    model = RSNAModel(30, 75)
    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()
    models.append(model)

In [10]:
# Conditions and spinal levels used to build the submission row ids
# (row id = study id + condition + level).
CONDITIONS = [
    'spinal_canal_stenosis',
    'left_neural_foraminal_narrowing',
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis',
]

LEVELS = ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']

# One label per (condition, level) pair: 5 * 5 = 25.
n_labels = len(CONDITIONS) * len(LEVELS)

In [11]:
y_preds = []
row_names = []

# Row-id suffixes in condition-major order (all levels of the first condition,
# then the next condition, ...).
# NOTE(review): assumes the model's outputs follow this same order -- confirm
# against the training label layout.
label_suffixes = [cond + '_' + level for cond in CONDITIONS for level in LEVELS]

with torch.no_grad():
    for x, si in tqdm(test_dl, leave=True):
        # Use the configured device rather than a hard-coded 'cuda' so the
        # cell also runs on CPU-only hosts, matching where the models live.
        x = x.to(device).float()
        study_id = si[0]  # batch_size is 1, so only the first element is used
        row_names.extend(study_id + "_" + suffix for suffix in label_suffixes)

        pred_per_study = np.zeros((n_labels, 3))
        for m in models:
            # Each consecutive triple of outputs is softmaxed independently,
            # then averaged over the ensemble.
            logits = m(x)[0]
            probs = logits.float().reshape(n_labels, 3).softmax(dim=1)
            pred_per_study += probs.cpu().numpy() / len(models)
        y_preds.append(pred_per_study)

# Guard against an empty loader: np.concatenate rejects an empty list.
y_preds = np.concatenate(y_preds, axis=0) if y_preds else np.zeros((0, 3))

100%|██████████| 1/1 [00:02<00:00,  2.14s/it]


In [12]:
# Assemble the submission: one row per row id with the three severity probabilities.
LABELS = ['normal_mild', 'moderate', 'severe']
sub = pd.DataFrame({'row_id': row_names})
sub[LABELS] = y_preds

In [13]:
# Display the submission frame for a visual sanity check.
sub

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.002514,0.20049,0.796996
1,44036939_spinal_canal_stenosis_l2_l3,0.02597,0.306551,0.667478
2,44036939_spinal_canal_stenosis_l3_l4,0.066357,0.510091,0.423552
3,44036939_spinal_canal_stenosis_l4_l5,0.102749,0.650451,0.246799
4,44036939_spinal_canal_stenosis_l5_s1,0.002026,0.283956,0.714019
5,44036939_left_neural_foraminal_narrowing_l1_l2,8.1e-05,0.003821,0.996098
6,44036939_left_neural_foraminal_narrowing_l2_l3,0.002115,0.012614,0.985271
7,44036939_left_neural_foraminal_narrowing_l3_l4,0.019595,0.047626,0.932779
8,44036939_left_neural_foraminal_narrowing_l4_l5,0.100027,0.123859,0.776114
9,44036939_left_neural_foraminal_narrowing_l5_s1,0.157681,0.373463,0.468855


In [14]:
# Write the submission file without the DataFrame index column.
sub.to_csv(path_or_buf='submission.csv', index=False)