In [1]:
import pandas as pd
import numpy as np
import cv2
from tqdm import tqdm
from glob import glob
from scipy.ndimage import zoom
from torch.utils.data import Dataset,DataLoader
import albumentations as A
import torch
import torch.nn as nn
import timm
from sklearn.model_selection import KFold
import math
import os

In [2]:
path = "/kaggle/input/rsna2024-lsdc-part-1-eda-and-preparing-data"

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
comp_dir = "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification"
image_dir = f"{path}/cvt_png"
pd.set_option('future.no_silent_downcasting', True)

# Part 2.1 Dataset and Augs

In [5]:
train_y = pd.read_csv(f'{comp_dir}/train.csv')
study_ids = [int(study_id) for study_id in os.listdir(image_dir)]
train_y = train_y[train_y['study_id'].isin(study_ids)].reset_index(drop=True)
mapping = {"Normal/Mild":0,"Moderate":1,"Severe":2}
train_y = train_y.replace(mapping)

In [6]:
class RSNADataset(Dataset):
    def __init__(self,df:pd.DataFrame, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self,idx):
        study_id = self.df.iloc[idx,0]
        label = self.df.iloc[idx,1:].values.astype(np.int64)
        
        ret_image = []
        
        parent_path = f"{path}/cvt_png/{study_id}/Axial T2/"
        if len(os.listdir(parent_path))<20:
            image = [cv2.imread(image_dir,-1) for image_dir in glob(f"{parent_path}**")]
            image = np.transpose(np.stack(image),(1,2,0)).astype(np.uint8)
            image = zoom(image,(1,1,20/image.shape[2]))
        else:
            image = [cv2.imread(image_dir,-1) for image_dir in glob(f"{parent_path}**")]
            image = np.transpose(np.stack(image),(1,2,0)).astype(np.uint8)
        ret_image.append(image)    
        
        parent_path = f"{path}/cvt_png/{study_id}/Sagittal T1/"
        if len(os.listdir(parent_path))<10:
            image = [cv2.imread(image_dir,-1) for image_dir in glob(f"{parent_path}**")]
            image = np.transpose(np.stack(image),(1,2,0)).astype(np.uint8)
            image = zoom(image,(1,1,10/image.shape[2]))
        else:
            image = [cv2.imread(image_dir,-1) for image_dir in glob(f"{parent_path}**")]
            image = np.transpose(np.stack(image),(1,2,0)).astype(np.uint8)
        ret_image.append(image)
        
        parent_path = f"{path}/cvt_png/{study_id}/Sagittal T2/"
        if len(os.listdir(parent_path))<10:
            image = [cv2.imread(image_dir,-1) for image_dir in glob(f"{parent_path}**")]
            image = np.transpose(np.stack(image),(1,2,0)).astype(np.uint8)
            image = zoom(image,(1,1,10/image.shape[2]))
        else:
            image = [cv2.imread(image_dir,-1) for image_dir in glob(f"{parent_path}**")]
            image = np.transpose(np.stack(image),(1,2,0)).astype(np.uint8)
        ret_image.append(image)

        ret_image = np.concatenate(ret_image, axis=2)
        
        if self.transform is not None:
            image = self.transform(image=ret_image)['image']
        
        ret_image = ret_image.transpose(2,0,1)
        
        return ret_image, label

In [7]:
# Test
ret, label  = RSNADataset(train_y).__getitem__(1)
print(ret.shape)
del ret, label

(40, 224, 224)


In [8]:
transforms_train = A.Compose([
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=10, border_mode=0, p=0.7),
    A.RandomBrightnessContrast(p=0.7),
    A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ], p=0.7),
    A.OneOf([
        A.OpticalDistortion(),
        A.GridDistortion(),
        A.ElasticTransform(),
    ], p=0.7),
    A.CoarseDropout(max_holes=8, max_height=32, max_width=32, min_holes=1, min_height=8, min_width=8, p=0.5),
#     A.CLAHE(clip_limit=4.0,p=0.7),
    A.Resize(224,224),
    A.Normalize(mean=0.5,std=0.5)
])



In [9]:
if not os.path.exists('rsna24-results'):
    os.makedirs('rsna24-results')       

# Part 2.2 Model

In [10]:
class RSNAModel(nn.Module):
    def __init__(self,in_c,n_classes):
        super(RSNAModel,self).__init__()
        self.encoder = timm.create_model("resnet18",
                                        in_chans = in_c,
                                        num_classes = 0,
                                        pretrained=False)
        self.encoder.fc = nn.Identity()
        
        self.lstm = nn.LSTM(512, 256, 3, bidirectional=True, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(512,256),
            nn.Dropout(0.3),
            nn.LeakyReLU(0.1),
            nn.Linear(256,n_classes)
        )        
    def forward(self,x):
        x = self.encoder(x)
        x, _ = self.lstm(x)
        x = self.head(x)
        return x

In [11]:
# Gradient Scaler
scaler = torch.cuda.amp.GradScaler(init_scale=4096)

# K Fold Cross Validation
skf = KFold(n_splits=4, shuffle=True, random_state=8)

# Weights for training
weights = torch.tensor([1.0,3.0,10.0])

num_epochs = 40

# Part 2.3 Training

In [12]:
train_y_len = len(train_y)
for fold, (trn_idx,val_idx) in enumerate(skf.split(range(train_y_len))):
    model = RSNAModel(40,75)
    model = model.to(device)
    
    dataset = RSNADataset(train_y.iloc[trn_idx],
                          transform=transforms_train)
    dataloader = DataLoader(dataset,
                            batch_size=4,
                            shuffle=False,
                            pin_memory=True,
                            drop_last = False,
                            num_workers=0)
    
    v_dataset = RSNADataset(train_y.iloc[val_idx],
                            transform=transforms_train)
    v_dataloader = DataLoader(v_dataset, batch_size=1,
                              shuffle=False,
                              pin_memory=True,
                              drop_last = False,
                              num_workers=0)
    
    n_labels = 25
    
    optimizer = torch.optim.AdamW(model.parameters(),lr=0.01)
    criterion_train = nn.CrossEntropyLoss(weight=weights.to(device))
    criterion_val = nn.CrossEntropyLoss(weight=weights)
    
    best_loss = 10 
    best_wll = 10
    es_step = 0

    for epoch in range(1,1+num_epochs):
        model.train()
        total_loss = 0 
        with tqdm(dataloader, leave=True) as pbar:
            optimizer.zero_grad()
            for idx, (x,t) in enumerate(pbar):
                x = x.to(device).float()
                t = t.to(device).long()
                loss = 0
                y = model(x)
                for col in range(n_labels):
                    pred = y[:,col*3:col*3+3]
                    gt = t[:,col]
                    loss = loss + criterion_train(pred,gt)/n_labels
                total_loss += loss.item()

                if not math.isfinite(loss):
                    print(f"Loss is {loss}, stopping training")
                    sys.exit(1)

                scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1e9)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        train_loss = total_loss/len(dataloader)
        print(f'train_loss: {train_loss:.6f}')

        total_loss = 0
        y_preds = []
        labels = []

        model.eval()
        with tqdm(v_dataloader,leave=True) as pbar:
            with torch.no_grad():
                for idx, (x,t) in enumerate(pbar):
                    x = x.to(device).float()
                    t = t.to(device).long()
                    loss = 0
                    y = model(x)
                    for col in range(n_labels):
                        pred = y[:, col*3:col*3+3]
                        gt = t[:,col]
                        loss = loss + criterion_train(pred,gt)/n_labels
                        y_preds.append(pred.float().cpu())
                        labels.append(gt.cpu())
                    total_loss += loss.item()
        val_loss = total_loss / len(v_dataloader)
        y_preds = torch.cat(y_preds,dim=0)
        labels = torch.cat(labels,dim=0)
        val_wll = criterion_val(y_preds,labels)

        print(f'val_loss: {val_loss:.6f}, val_wll: {val_wll:.6f}')
        if val_loss < best_loss or val_wll<best_wll:
            es_step = 0
            if val_loss<best_loss:
                print(f"epoch: {epoch}, best loss updated from {best_loss:.6f} to {val_loss:.6f}")
                best_loss = val_loss
            if val_wll<best_wll:
                print(f"epoch: {epoch}, best wll_metric updated from {best_wll:.6f} to {val_wll:.6f}")
                best_wll = val_wll
            fname = f'rsna24-results/best_wll_model_fold-{fold}.pt'
            torch.save(model.state_dict(),fname)                
        else:
            es_step += 1
            if es_step >= 3:
                print("early stopping")
                break

100%|██████████| 336/336 [08:06<00:00,  1.45s/it]


train_loss: 0.850750


100%|██████████| 448/448 [02:49<00:00,  2.64it/s]


val_loss: 0.666255, val_wll: 1.038749
epoch: 1, best loss updated from 10.000000 to 0.666255
epoch: 1, best wll_metric updated from 10.000000 to 1.038749


100%|██████████| 336/336 [03:03<00:00,  1.83it/s]


train_loss: 0.819541


100%|██████████| 448/448 [01:03<00:00,  7.04it/s]


val_loss: 0.670233, val_wll: 1.026076
epoch: 2, best wll_metric updated from 1.038749 to 1.026076


100%|██████████| 336/336 [03:00<00:00,  1.86it/s]


train_loss: 0.817695


100%|██████████| 448/448 [01:00<00:00,  7.37it/s]


val_loss: 0.654940, val_wll: 1.039246
epoch: 3, best loss updated from 0.666255 to 0.654940


100%|██████████| 336/336 [03:02<00:00,  1.84it/s]


train_loss: 0.810666


100%|██████████| 448/448 [01:02<00:00,  7.13it/s]


val_loss: 0.656660, val_wll: 1.039581


100%|██████████| 336/336 [03:01<00:00,  1.85it/s]


train_loss: 0.810403


100%|██████████| 448/448 [01:02<00:00,  7.14it/s]


val_loss: 0.652022, val_wll: 1.045124
epoch: 5, best loss updated from 0.654940 to 0.652022


100%|██████████| 336/336 [03:03<00:00,  1.83it/s]


train_loss: 0.811192


100%|██████████| 448/448 [01:01<00:00,  7.32it/s]


val_loss: 0.651085, val_wll: 1.060523
epoch: 6, best loss updated from 0.652022 to 0.651085


100%|██████████| 336/336 [02:58<00:00,  1.88it/s]


train_loss: 0.813242


100%|██████████| 448/448 [01:00<00:00,  7.35it/s]


val_loss: 0.645831, val_wll: 1.059536
epoch: 7, best loss updated from 0.651085 to 0.645831


100%|██████████| 336/336 [03:02<00:00,  1.84it/s]


train_loss: 0.811456


100%|██████████| 448/448 [01:03<00:00,  7.06it/s]


val_loss: 0.653102, val_wll: 1.049292


100%|██████████| 336/336 [03:02<00:00,  1.84it/s]


train_loss: 0.812176


100%|██████████| 448/448 [01:00<00:00,  7.35it/s]


val_loss: 0.660324, val_wll: 1.040800


100%|██████████| 336/336 [03:03<00:00,  1.83it/s]


train_loss: 0.813983


100%|██████████| 448/448 [00:59<00:00,  7.52it/s]


val_loss: 0.645594, val_wll: 1.061666
epoch: 10, best loss updated from 0.645831 to 0.645594


100%|██████████| 336/336 [02:59<00:00,  1.87it/s]


train_loss: 0.810210


100%|██████████| 448/448 [00:59<00:00,  7.50it/s]


val_loss: 0.651643, val_wll: 1.048604


100%|██████████| 336/336 [03:00<00:00,  1.87it/s]


train_loss: 0.809147


100%|██████████| 448/448 [01:02<00:00,  7.22it/s]


val_loss: 0.648377, val_wll: 1.052347


100%|██████████| 336/336 [02:57<00:00,  1.89it/s]


train_loss: 0.808444


100%|██████████| 448/448 [00:58<00:00,  7.60it/s]


val_loss: 0.640855, val_wll: 1.060712
epoch: 13, best loss updated from 0.645594 to 0.640855


100%|██████████| 336/336 [02:55<00:00,  1.92it/s]


train_loss: 0.808583


100%|██████████| 448/448 [01:02<00:00,  7.13it/s]


val_loss: 0.649102, val_wll: 1.051403


100%|██████████| 336/336 [02:58<00:00,  1.89it/s]


train_loss: 0.806590


100%|██████████| 448/448 [01:01<00:00,  7.34it/s]


val_loss: 0.650156, val_wll: 1.047414


100%|██████████| 336/336 [02:52<00:00,  1.94it/s]


train_loss: 0.807089


100%|██████████| 448/448 [01:00<00:00,  7.41it/s]


val_loss: 0.651823, val_wll: 1.051511
early stopping


100%|██████████| 336/336 [02:54<00:00,  1.93it/s]


train_loss: 0.835485


100%|██████████| 447/447 [00:59<00:00,  7.52it/s]


val_loss: 0.652913, val_wll: 1.047600
epoch: 1, best loss updated from 10.000000 to 0.652913
epoch: 1, best wll_metric updated from 10.000000 to 1.047600


100%|██████████| 336/336 [03:01<00:00,  1.85it/s]


train_loss: 0.816526


100%|██████████| 447/447 [01:00<00:00,  7.37it/s]


val_loss: 0.660442, val_wll: 1.026529
epoch: 2, best wll_metric updated from 1.047600 to 1.026529


100%|██████████| 336/336 [03:01<00:00,  1.85it/s]


train_loss: 0.808652


100%|██████████| 447/447 [00:58<00:00,  7.68it/s]


val_loss: 0.658806, val_wll: 1.016414
epoch: 3, best wll_metric updated from 1.026529 to 1.016414


100%|██████████| 336/336 [03:00<00:00,  1.86it/s]


train_loss: 0.808654


100%|██████████| 447/447 [00:59<00:00,  7.48it/s]


val_loss: 0.658932, val_wll: 1.012819
epoch: 4, best wll_metric updated from 1.016414 to 1.012819


100%|██████████| 336/336 [03:01<00:00,  1.85it/s]


train_loss: 0.807055


100%|██████████| 447/447 [00:59<00:00,  7.54it/s]


val_loss: 0.668674, val_wll: 1.008791
epoch: 5, best wll_metric updated from 1.012819 to 1.008791


100%|██████████| 336/336 [02:58<00:00,  1.88it/s]


train_loss: 0.807999


100%|██████████| 447/447 [01:00<00:00,  7.37it/s]


val_loss: 0.665307, val_wll: 1.011990


100%|██████████| 336/336 [03:02<00:00,  1.85it/s]


train_loss: 0.809103


100%|██████████| 447/447 [00:59<00:00,  7.57it/s]


val_loss: 0.668645, val_wll: 1.008918


100%|██████████| 336/336 [03:03<00:00,  1.83it/s]


train_loss: 0.806653


100%|██████████| 447/447 [01:03<00:00,  7.09it/s]


val_loss: 0.670073, val_wll: 1.006775
epoch: 8, best wll_metric updated from 1.008791 to 1.006775


100%|██████████| 336/336 [02:59<00:00,  1.88it/s]


train_loss: 0.831636


100%|██████████| 447/447 [00:59<00:00,  7.52it/s]


val_loss: 0.695003, val_wll: 1.039011


100%|██████████| 336/336 [03:01<00:00,  1.85it/s]


train_loss: 0.838705


100%|██████████| 447/447 [01:00<00:00,  7.36it/s]


val_loss: 0.647818, val_wll: 1.042713
epoch: 10, best loss updated from 0.652913 to 0.647818


100%|██████████| 336/336 [02:55<00:00,  1.91it/s]


train_loss: 0.826069


100%|██████████| 447/447 [01:00<00:00,  7.41it/s]


val_loss: 0.661105, val_wll: 1.020572


100%|██████████| 336/336 [02:59<00:00,  1.87it/s]


train_loss: 0.821267


100%|██████████| 447/447 [01:01<00:00,  7.28it/s]


val_loss: 0.653769, val_wll: 1.030317


100%|██████████| 336/336 [03:00<00:00,  1.86it/s]


train_loss: 0.816631


100%|██████████| 447/447 [00:58<00:00,  7.64it/s]


val_loss: 0.651170, val_wll: 1.023355
early stopping


100%|██████████| 336/336 [02:57<00:00,  1.89it/s]


train_loss: 0.838188


100%|██████████| 447/447 [01:00<00:00,  7.35it/s]


val_loss: 0.635824, val_wll: 1.033930
epoch: 1, best loss updated from 10.000000 to 0.635824
epoch: 1, best wll_metric updated from 10.000000 to 1.033930


100%|██████████| 336/336 [02:56<00:00,  1.91it/s]


train_loss: 0.811516


100%|██████████| 447/447 [00:58<00:00,  7.65it/s]


val_loss: 0.639083, val_wll: 1.025914
epoch: 2, best wll_metric updated from 1.033930 to 1.025914


100%|██████████| 336/336 [02:54<00:00,  1.93it/s]


train_loss: 0.810334


100%|██████████| 447/447 [01:00<00:00,  7.42it/s]


val_loss: 0.635739, val_wll: 1.029668
epoch: 3, best loss updated from 0.635824 to 0.635739


100%|██████████| 336/336 [03:00<00:00,  1.87it/s]


train_loss: 0.809286


100%|██████████| 447/447 [00:59<00:00,  7.50it/s]


val_loss: 0.640699, val_wll: 1.024688
epoch: 4, best wll_metric updated from 1.025914 to 1.024688


100%|██████████| 336/336 [02:58<00:00,  1.88it/s]


train_loss: 0.806498


100%|██████████| 447/447 [00:59<00:00,  7.48it/s]


val_loss: 0.633033, val_wll: 1.036090
epoch: 5, best loss updated from 0.635739 to 0.633033


100%|██████████| 336/336 [03:00<00:00,  1.87it/s]


train_loss: 0.803726


100%|██████████| 447/447 [00:59<00:00,  7.50it/s]


val_loss: 0.642756, val_wll: 1.014553
epoch: 6, best wll_metric updated from 1.024688 to 1.014553


100%|██████████| 336/336 [02:57<00:00,  1.90it/s]


train_loss: 0.802897


100%|██████████| 447/447 [01:00<00:00,  7.44it/s]


val_loss: 0.637560, val_wll: 1.026306


100%|██████████| 336/336 [02:59<00:00,  1.87it/s]


train_loss: 0.804176


100%|██████████| 447/447 [00:58<00:00,  7.62it/s]


val_loss: 0.640961, val_wll: 1.019452


100%|██████████| 336/336 [02:56<00:00,  1.90it/s]


train_loss: 0.801590


100%|██████████| 447/447 [00:57<00:00,  7.73it/s]


val_loss: 0.639729, val_wll: 1.020349
early stopping


100%|██████████| 336/336 [03:05<00:00,  1.82it/s]


train_loss: 0.832602


100%|██████████| 447/447 [01:00<00:00,  7.41it/s]


val_loss: 0.645243, val_wll: 1.143071
epoch: 1, best loss updated from 10.000000 to 0.645243
epoch: 1, best wll_metric updated from 10.000000 to 1.143071


100%|██████████| 336/336 [02:56<00:00,  1.90it/s]


train_loss: 0.810867


100%|██████████| 447/447 [01:00<00:00,  7.38it/s]


val_loss: 0.647718, val_wll: 1.108133
epoch: 2, best wll_metric updated from 1.143071 to 1.108133


100%|██████████| 336/336 [03:00<00:00,  1.86it/s]


train_loss: 0.806725


100%|██████████| 447/447 [00:57<00:00,  7.77it/s]


val_loss: 0.649705, val_wll: 1.128668


100%|██████████| 336/336 [02:58<00:00,  1.89it/s]


train_loss: 0.804079


100%|██████████| 447/447 [00:57<00:00,  7.74it/s]


val_loss: 0.644774, val_wll: 1.096835
epoch: 4, best loss updated from 0.645243 to 0.644774
epoch: 4, best wll_metric updated from 1.108133 to 1.096835


100%|██████████| 336/336 [02:56<00:00,  1.90it/s]


train_loss: 0.801638


100%|██████████| 447/447 [00:58<00:00,  7.67it/s]


val_loss: 0.647557, val_wll: 1.104710


100%|██████████| 336/336 [02:56<00:00,  1.90it/s]


train_loss: 0.800485


100%|██████████| 447/447 [01:01<00:00,  7.24it/s]


val_loss: 0.652372, val_wll: 1.073746
epoch: 6, best wll_metric updated from 1.096835 to 1.073746


100%|██████████| 336/336 [03:03<00:00,  1.83it/s]


train_loss: 0.800356


100%|██████████| 447/447 [00:57<00:00,  7.73it/s]


val_loss: 0.654061, val_wll: 1.069893
epoch: 7, best wll_metric updated from 1.073746 to 1.069893


100%|██████████| 336/336 [02:58<00:00,  1.88it/s]


train_loss: 0.797799


100%|██████████| 447/447 [00:59<00:00,  7.48it/s]


val_loss: 0.654772, val_wll: 1.071568


100%|██████████| 336/336 [02:54<00:00,  1.93it/s]


train_loss: 0.797584


100%|██████████| 447/447 [00:59<00:00,  7.54it/s]


val_loss: 0.651262, val_wll: 1.076146


100%|██████████| 336/336 [02:58<00:00,  1.88it/s]


train_loss: 0.797248


100%|██████████| 447/447 [00:58<00:00,  7.65it/s]


val_loss: 0.655322, val_wll: 1.062122
epoch: 10, best wll_metric updated from 1.069893 to 1.062122


100%|██████████| 336/336 [03:01<00:00,  1.85it/s]


train_loss: 0.798165


100%|██████████| 447/447 [00:58<00:00,  7.60it/s]


val_loss: 0.660487, val_wll: 1.057681
epoch: 11, best wll_metric updated from 1.062122 to 1.057681


100%|██████████| 336/336 [02:59<00:00,  1.87it/s]


train_loss: 0.805952


100%|██████████| 447/447 [00:59<00:00,  7.51it/s]


val_loss: 0.653382, val_wll: 1.117050


100%|██████████| 336/336 [02:58<00:00,  1.88it/s]


train_loss: 0.819752


100%|██████████| 447/447 [00:58<00:00,  7.58it/s]


val_loss: 0.644378, val_wll: 1.106303
epoch: 13, best loss updated from 0.644774 to 0.644378


100%|██████████| 336/336 [02:57<00:00,  1.89it/s]


train_loss: 0.810617


100%|██████████| 447/447 [00:58<00:00,  7.66it/s]


val_loss: 0.640277, val_wll: 1.120474
epoch: 14, best loss updated from 0.644378 to 0.640277


100%|██████████| 336/336 [02:58<00:00,  1.88it/s]


train_loss: 0.807823


100%|██████████| 447/447 [01:00<00:00,  7.39it/s]


val_loss: 0.650736, val_wll: 1.092726


100%|██████████| 336/336 [03:08<00:00,  1.78it/s]


train_loss: 0.808745


100%|██████████| 447/447 [01:04<00:00,  6.91it/s]


val_loss: 0.647151, val_wll: 1.096557


100%|██████████| 336/336 [03:09<00:00,  1.77it/s]


train_loss: 0.807492


100%|██████████| 447/447 [01:04<00:00,  6.97it/s]

val_loss: 0.651283, val_wll: 1.090462
early stopping



