## Import the Libraries

In [1]:
import os
import gc
import sys
from PIL import Image
import cv2
import math, random
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold

from collections import OrderedDict

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

import timm
from timm.utils import ModelEmaV2
from transformers import get_cosine_schedule_with_warmup

import albumentations as A

from sklearn.model_selection import KFold

import re
import pydicom
from typing import Optional
import glob

  data = fetch_version_info()


## Load & Edit The Data

In [2]:
rd = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'
OUTPUT_DIR = f'/kaggle/input/rsna2024-lsdc-training-baseline/rsna24-results'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [3]:
# Configuration
EXP_NO = "006"
MODEL_DIR = f"/kaggle/input/rsna24-a-{EXP_NO}"
MODEL_NAME = "tf_efficientnet_b5.ns_jft_in1k"

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
N_WORKERS = os.cpu_count()
USE_AMP = True
SEED = 42
NUM_FOLDS = 5


IMG_SIZE = [512, 512]
IN_CHANS = 30
N_LABELS = 25
N_CLASSES = 3 * N_LABELS

BATCH_SIZE = 1

In [4]:
rd = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

In [5]:
df = pd.read_csv(f'{rd}/test_series_descriptions.csv')
df.head()

Unnamed: 0,study_id,series_id,series_description
0,44036939,2828203845,Sagittal T1
1,44036939,3481971518,Axial T2
2,44036939,3844393089,Sagittal T2/STIR


In [6]:
study_ids = list(df['study_id'].unique())

In [7]:
sample_sub = pd.read_csv(f'{rd}/sample_submission.csv')

In [8]:
LABELS = list(sample_sub.columns[1:])
LABELS

['normal_mild', 'moderate', 'severe']

In [9]:
# Conditions and Levels
CONDITIONS = [
    'spinal_canal_stenosis', 
    'left_neural_foraminal_narrowing', 
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis'
]

LEVELS = [
    'l1_l2',
    'l2_l3',
    'l3_l4',
    'l4_l5',
    'l5_s1',
]

In [10]:
# Helper functions
def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split(r'(\d+)', text) ]

## Feature Engineering

In [11]:
class RSNA24TestDataset(Dataset):
    def __init__(self, df, study_ids, phase='test', transform=None):
        self.df = df
        self.study_ids = study_ids
        self.transform = transform
        self.phase = phase
    
    def __len__(self):
        return len(self.study_ids)
    
    def get_img_paths(self, study_id, series_desc):
        pdf = self.df[self.df['study_id']==study_id]
        pdf_ = pdf[pdf['series_description']==series_desc]
        allimgs = []
        for i, row in pdf_.iterrows():
            pimgs = glob.glob(f'{rd}/test_images/{study_id}/{row["series_id"]}/*.dcm')
            pimgs = sorted(pimgs, key=natural_keys)
            allimgs.extend(pimgs)
            
        return allimgs
    
    def read_dcm_ret_arr(self, src_path):
        dicom_data = pydicom.dcmread(src_path)
        image = dicom_data.pixel_array
        image = (image - image.min()) / (image.max() - image.min() + 1e-6) * 255
        img = cv2.resize(image, (IMG_SIZE[0], IMG_SIZE[1]),interpolation=cv2.INTER_CUBIC)
        assert img.shape==(IMG_SIZE[0], IMG_SIZE[1])
        return img

    def __getitem__(self, idx):
        x = np.zeros((IMG_SIZE[0], IMG_SIZE[1], IN_CHANS), dtype=np.uint8)
        st_id = self.study_ids[idx]        
        
        # Sagittal T1
        allimgs_st1 = self.get_img_paths(st_id, 'Sagittal T1')
        if len(allimgs_st1)==0:
            print(st_id, ': Sagittal T1, has no images')
        
        else:
            step = len(allimgs_st1) / 10.0
            st = len(allimgs_st1)/2.0 - 4.0*step
            end = len(allimgs_st1)+0.0001
            for j, i in enumerate(np.arange(st, end, step)):
                try:
                    ind2 = max(0, int((i-0.5001).round()))
                    img = self.read_dcm_ret_arr(allimgs_st1[ind2])
                    x[..., j] = img.astype(np.uint8)
                except:
                    print(f'failed to load on {st_id}, Sagittal T1')
                    pass
            
        # Sagittal T2/STIR
        allimgs_st2 = self.get_img_paths(st_id, 'Sagittal T2/STIR')
        if len(allimgs_st2)==0:
            print(st_id, ': Sagittal T2/STIR, has no images')
            
        else:
            step = len(allimgs_st2) / 10.0
            st = len(allimgs_st2)/2.0 - 4.0*step
            end = len(allimgs_st2)+0.0001
            for j, i in enumerate(np.arange(st, end, step)):
                try:
                    ind2 = max(0, int((i-0.5001).round()))
                    img = self.read_dcm_ret_arr(allimgs_st2[ind2])
                    x[..., j+10] = img.astype(np.uint8)
                except:
                    print(f'failed to load on {st_id}, Sagittal T2/STIR')
                    pass
            
        # Axial T2
        allimgs_at2 = self.get_img_paths(st_id, 'Axial T2')
        if len(allimgs_at2)==0:
            print(st_id, ': Axial T2, has no images')
            
        else:
            step = len(allimgs_at2) / 10.0
            st = len(allimgs_at2)/2.0 - 4.0*step
            end = len(allimgs_at2)+0.0001

            for j, i in enumerate(np.arange(st, end, step)):
                try:
                    ind2 = max(0, int((i-0.5001).round()))
                    img = self.read_dcm_ret_arr(allimgs_at2[ind2])
                    x[..., j+20] = img.astype(np.uint8)
                except:
                    print(f'failed to load on {st_id}, Axial T2')
                    pass  
            
            
        if self.transform is not None:
            x = self.transform(image=x)['image']

        x = x.transpose(2, 0, 1)
                
        return x, str(st_id)

In [12]:
transforms_test = A.Compose([
    A.Resize(IMG_SIZE[0], IMG_SIZE[1]),
    A.Normalize(mean=0.5, std=0.5)
])

In [13]:
test_ds = RSNA24TestDataset(df, study_ids, transform=transforms_test)
test_dl = DataLoader(
    test_ds, 
    batch_size=1, 
    shuffle=False,
    num_workers=N_WORKERS,
    pin_memory=True,
    drop_last=False
)

## Build The Model

In [14]:
class RSNA24Model(nn.Module):
    def __init__(
        self,
        model_name: str,
        pretrained: bool,
        features_only: bool,
        in_chans: int,
        n_classes: int,
        n_labels: int,
        loss_name: str,
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name=model_name,
            pretrained=pretrained, 
            features_only=features_only,
            in_chans=in_chans,
            num_classes=n_classes,
            global_pool='avg'
        )
        self.loss_fn = loss_name
        self.n_labels = n_labels
    
    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        
        logits = self.model(x)
        
        output = {"logits": logits}
        if y is not None:
            loss = 0
            for col in range(self.n_labels):
                pred = logits[:,col*3:col*3+3]
                gt = y[:,col]
                loss = loss + self.loss_fn(pred, gt) / self.n_labels
            output["loss"] = loss
        
        return output

In [15]:
models = []

In [16]:
for i in range(NUM_FOLDS):
    cp = f"{MODEL_DIR}/{EXP_NO}-{i}/best_model.pth"
#     cp = f"{MODEL_DIR}/best_model.pth"
    print(f'loading {cp}...')
    model = RSNA24Model(MODEL_NAME, False, False, IN_CHANS, N_CLASSES, 25, 'dummy')
    state_dict = torch.load(cp)

    # 予期しないキーを削除
    if 'loss_fn.weight' in state_dict:
        del state_dict['loss_fn.weight']
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    models.append(model)

loading /kaggle/input/rsna24-a-006/006-0/best_model.pth...


  state_dict = torch.load(cp)


loading /kaggle/input/rsna24-a-006/006-1/best_model.pth...
loading /kaggle/input/rsna24-a-006/006-2/best_model.pth...
loading /kaggle/input/rsna24-a-006/006-3/best_model.pth...
loading /kaggle/input/rsna24-a-006/006-4/best_model.pth...


In [17]:
autocast = torch.cuda.amp.autocast(enabled=USE_AMP, dtype=torch.half)
y_preds = []
row_names = []

with tqdm(test_dl, leave=True) as pbar:
    with torch.no_grad():
        for idx, (x, si) in enumerate(pbar):
            x = x + torch.randn(x.size()) * 0.15
            x = x.to(device)
            pred_per_study = np.zeros((25, 3))
            
            for cond in CONDITIONS:
                for level in LEVELS:
                    row_names.append(si[0] + '_' + cond + '_' + level)
            
            with autocast:
                for m in models:
                    y = m(x, None)["logits"][0]
                    for col in range(N_LABELS):
                        pred = y[col*3:col*3+3]
                        y_pred = pred.float().softmax(0).cpu().numpy()
                        pred_per_study[col] += y_pred / len(models)
                y_preds.append(pred_per_study)

y_preds = np.concatenate(y_preds, axis=0)

  autocast = torch.cuda.amp.autocast(enabled=USE_AMP, dtype=torch.half)
100%|██████████| 1/1 [00:02<00:00,  2.42s/it]


## Release the output

In [18]:
sub = pd.DataFrame()
sub['row_id'] = row_names
sub[LABELS] = y_preds
sub.head(25)

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.324835,0.371123,0.304042
1,44036939_spinal_canal_stenosis_l2_l3,0.172546,0.41299,0.414463
2,44036939_spinal_canal_stenosis_l3_l4,0.191044,0.426968,0.381988
3,44036939_spinal_canal_stenosis_l4_l5,0.278035,0.32778,0.394185
4,44036939_spinal_canal_stenosis_l5_s1,0.79612,0.135892,0.067988
5,44036939_left_neural_foraminal_narrowing_l1_l2,0.457876,0.502826,0.039298
6,44036939_left_neural_foraminal_narrowing_l2_l3,0.344226,0.511099,0.144675
7,44036939_left_neural_foraminal_narrowing_l3_l4,0.304424,0.381381,0.314196
8,44036939_left_neural_foraminal_narrowing_l4_l5,0.158516,0.356312,0.485172
9,44036939_left_neural_foraminal_narrowing_l5_s1,0.137358,0.390592,0.47205


In [19]:
sub.to_csv('submission.csv', index=False)
pd.read_csv('submission.csv').head()

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.324835,0.371123,0.304042
1,44036939_spinal_canal_stenosis_l2_l3,0.172546,0.41299,0.414463
2,44036939_spinal_canal_stenosis_l3_l4,0.191044,0.426968,0.381988
3,44036939_spinal_canal_stenosis_l4_l5,0.278035,0.32778,0.394185
4,44036939_spinal_canal_stenosis_l5_s1,0.79612,0.135892,0.067988
