In [1]:
import albumentations as A

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
import timm
import pandas as pd
import numpy as np
import pickle

import matplotlib.pyplot as plt
from transformers import get_cosine_schedule_with_warmup

from sklearn.model_selection import KFold
from tqdm import tqdm
import math
from collections import OrderedDict

import time

from lib import *

In [2]:


# Augmentation config
AUG = True
AUG_PROB = 0.75

IMG_SIZE = [512, 512]
MAX_SEQ_IMGS = 32

N_LABELS = 25
N_CLASSES = 3 * N_LABELS

N_FOLDS = 5
EPOCHS = 5 

SEED = 8620
N_WORKERS=4

# Model Config
OUTPUT_DIR="models"
MODEL_NAME = "tf_efficientnet_b3.ns_jft_in1k"

# Device Config
USE_AMP = True
DEVICE="cuda:0"

# BATCH SIZE

TGT_BATCH_SIZE = 32
GRAD_ACC = 8
BATCH_SIZE = TGT_BATCH_SIZE // GRAD_ACC
# MAX_GRAD_NORM = None
# EARLY_STOPPING_EPOCH = 3


In [3]:
df = pd.read_csv('train.csv')
df = df.fillna(-100)
label2id = {'Normal/Mild': 0, 'Moderate': 1, 'Severe':2}
df = df.replace(label2id)

  df = df.replace(label2id)


In [4]:
class RSNA24Dataset(Dataset):
    def __init__(self, df, phase='train', transform=None):
        self.df = df
        self.transform = transform
        self.phase = phase
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):

        x = np.zeros((512, 512, MAX_SEQ_IMGS), dtype=np.uint8)

        # 4 = Patient pos (x,y,z) + slice location
        pos_x = np.zeros((MAX_SEQ_IMGS, 3))

        patient_id = self.df.iloc[idx]["study_id"]
        label = self.df.iloc[idx][1:].values.astype(np.int64)
        patient_info_path = f"./processed-dataset/{patient_id}.pkl"

        with open(patient_info_path, "rb") as f:
            data = pickle.load(f)

        scans_used = []

        for series_info in data["series"].values():

            scan_type = series_info["series_description"]

            if scan_type in scans_used:
                continue

            scans = series_info["images"]

            if scan_type == "Sagittal T2/STIR":

                for i in range(min(10, len(scans))):
                    x[..., i] = scans[i]["img"].astype(np.uint8)
                    pos_x[i] = np.array(list(scans[i]["position_patient"]))

            elif scan_type == "Sagittal T1":

                for i in range(min(10, len(scans))):
                    x[..., i+10] = scans[i]["img"].astype(np.uint8)
                    pos_x[i+10] = np.array(list(scans[i]["position_patient"]))

            elif scan_type == "Axial T2":

                for i in range(min(12,len(scans))):
                    x[..., i+20] = scans[i]["img"].astype(np.uint8)
                    pos_x[i+20] = np.array(list(scans[i]["position_patient"]))
            else:
                raise ValueError(f"unknown series_description: {series_info["series_description"]}")

            scans_used.append(scan_type)

        if self.transform is not None:
            x = self.transform(image=x)['image']

        x = x.transpose(2, 0, 1)

        pos_x = pos_x.astype(np.float32)
        pos_x = pos_x - pos_x.min(0)

        return (x, pos_x), label


In [5]:
transforms_train = A.Compose([
    A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=AUG_PROB),
    A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ], p=AUG_PROB),

    A.OneOf([
        A.OpticalDistortion(distort_limit=1.0),
        A.GridDistortion(num_steps=5, distort_limit=1.),
        A.ElasticTransform(alpha=3),
    ], p=AUG_PROB),

    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=AUG_PROB),
    A.Resize(IMG_SIZE[0], IMG_SIZE[1]),
    A.CoarseDropout(max_holes=16, max_height=64, max_width=64, min_holes=1, min_height=8, min_width=8, p=AUG_PROB),    
    A.Normalize(mean=0.5, std=0.5)
])

transforms_val = A.Compose([
    A.Resize(IMG_SIZE[0], IMG_SIZE[1]),
    A.Normalize(mean=0.5, std=0.5)
])



  self.__pydantic_validator__.validate_python(data, self_instance=self)


In [6]:
class Encoder(nn.Module):
    def __init__(self, embed_dim=768, hidden_dim=1536, num_heads=12, dropout=0.1):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.ln_1 = nn.LayerNorm(self.embed_dim)
        self.ln_2 = nn.LayerNorm(self.embed_dim)

        self.key_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.query_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.value_proj = nn.Linear(self.embed_dim, self.embed_dim)
        
        self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )
    def forward(self, x, x_pos):
        x = self.ln_1(x)

        q,k,v = self.query_proj(x), self.key_proj(x), self.value_proj(x)
        x = x + self.dropout(self.mha(q,k,v, need_weights=False)[0])
        
        x = self.ln_2(x)
        x = x + self.dropout(self.mlp(x))
        
        return x

    

class RSNA24Model(nn.Module):
    def __init__(self, model_name, n_classes=0, pos_in_features=3, blocks=2, pretrained=True, features_only=False):
        super().__init__()
        
        self.pos_in_features = pos_in_features
        self.embed_dim = 768

        self.feature_extractor =  timm.create_model( model_name
                              , pretrained=pretrained
                              , num_classes=n_classes
                              , global_pool='avg' 
                              , features_only=features_only
                             )

        self.conv_stem = nn.Sequential( nn.Conv2d(1,3, kernel_size=(1,1))
                                                  
                                      # NOTE: Try with activation
                                      # , stride=(1,1)) ,nn.SiLU(inplace=True) 
                                      )

        self.img_feature_dim  = 1536
        
        self.pos_embedding = nn.Sequential(
            nn.Linear(pos_in_features, self.embed_dim),
            )
        self.img_feature_proj = nn.Linear(self.img_feature_dim, self.embed_dim)
        
        self.encoders = nn.ModuleList([Encoder() for _ in range(blocks)])
        self.classifier = nn.Linear(self.embed_dim, N_CLASSES)


        # NOTE: This is idea for combining kernel weights of the first conv layer
        # self.feature_extractor.conv_stem.weight = nn.Parameter(self.feature_extractor.conv_stem.weight.sum(dim=1, keepdim=True))
        # conv1_weight = self.feature_extractor.state_dict['conv_stem']
        # self.feature_extractor.state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
    
    def forward(self, x, x_pos, freeze_conv=True):

        batch_size, seq_len, h, w = x.shape

        # x_pos: BATCH_SIZE, NUMBER OF SCANS, SCAN POSITIONS

        x = x.view(-1, 1, h, w)
        x = self.conv_stem(x)
        if freeze_conv:
            with torch.no_grad(): x = self.feature_extractor(x)
        else:

            n = 4
            with torch.no_grad():
                x = self.feature_extractor.conv_stem(x)
                x = self.feature_extractor.bn1(x)

                for i in range(n):
                    x = self.feature_extractor.blocks[i](x)

            for i in range(n,7):
                x = self.feature_extractor.blocks[i](x)
            x = self.feature_extractor.conv_head(x)
            x = self.feature_extractor.bn2(x)
            x = self.feature_extractor.global_pool(x)
            x = self.feature_extractor.classifier(x)

        x = x.view(batch_size, seq_len, -1)
        
        # x_pos = self.pos_embedding(x_pos)

        x = self.img_feature_proj(x) + self.pos_embedding(x_pos)

        for enc in self.encoders:
            x = enc(x, None)

        x = self.classifier(x)
        return x

In [7]:
LR = 1e-4
skf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

In [8]:
def train(train_dl, model, optimizer, criterion, freeze=True):

    model.train()
    optimizer.zero_grad()
    
    training_loss = []
        
    for idx, (batch_input, y) in enumerate(pbar := tqdm(train_dl)):
        
        x,x_pos = batch_input[0].to(DEVICE), batch_input[1].to(DEVICE)
        y = y.to(DEVICE)
        
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            pred = model(x,x_pos, freeze)[:, 0, :]
            pred = pred.reshape(-1, 3)
            y = y.view(-1)
            loss = criterion(pred,y) / GRAD_ACC
            
        loss.backward()

        training_loss.append(loss.item() * GRAD_ACC)

        if (idx + 1) % GRAD_ACC == 0:
            optimizer.step()
            optimizer.zero_grad()

        pbar.set_description(f"loss: {loss.item() * GRAD_ACC:.6f}, training_loss: {np.mean(training_loss):.6f}")
                
def validation(val_dl, model, criterion):
    
        model.eval()
    
        y_preds = []
        labels = []
        avg_val_loss = []


        with torch.no_grad():
    
            for idx, (batch_input, y) in enumerate(pbar := tqdm(val_dl)):
                
                x,x_pos = batch_input[0].to(DEVICE), batch_input[1].to(DEVICE)
                y = y.to(DEVICE)
    
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    pred = model(x,x_pos)[:, 0, :]
                    pred = pred.reshape(-1, 3)
                    y = y.view(-1)
                    loss = criterion(pred,y)
    
    
                y_preds.append(pred)
                labels.append(y)
    
                avg_val_loss.append(loss.item())
    
                pbar.set_description(f"loss: {loss.item():.6f}, validation_loss: {np.mean(avg_val_loss):.6f}")
    
        y_preds = torch.cat(y_preds, dim=0)
        labels = torch.cat(labels)
    
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            final_val_loss = criterion(y_preds, labels).item()
            print(f"Final loss on validation set: {final_val_loss:.6f}")

In [9]:


for fold, (trn_idx, val_idx) in enumerate(skf.split(range(len(df)))):

    model = RSNA24Model(MODEL_NAME).to(DEVICE)
    
    optimizer = AdamW(model.parameters(), lr=LR)
    weights = torch.tensor([1.0, 2.0, 4.0])
    criterion = nn.CrossEntropyLoss(weight=weights.to(DEVICE))

    df_train = df.iloc[trn_idx]
    df_valid = df.iloc[val_idx]
    
    train_ds = RSNA24Dataset(df_train, transform=transforms_train)
    train_dl = DataLoader(
                train_ds,
                batch_size=BATCH_SIZE,
                shuffle=True,
                pin_memory=True,
                drop_last=True,
                num_workers=N_WORKERS
                )
    
    valid_ds = RSNA24Dataset(df_valid, transform=transforms_val)
    valid_dl = DataLoader(
                    valid_ds,
                    batch_size=BATCH_SIZE,
                    shuffle=False,
                    pin_memory=True,
                    drop_last=False,
                    num_workers=N_WORKERS
                    )

    for i in range(3):
        print(f"EPOCH: {i + 1}")

        # if i == 0:
        #     train(train_dl, model, optimizer, criterion, freeze=True)
        # else:
        train(train_dl, model, optimizer, criterion, freeze=False)
        validation(valid_dl, model, criterion)
    break

    

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/tf_efficientnet_b3.ns_jft_in1k)
INFO:timm.models._hub:[timm/tf_efficientnet_b3.ns_jft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


EPOCH: 1


loss: 0.922102, training_loss: 0.887149: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 395/395 [02:22<00:00,  2.77it/s]
loss: 0.661901, validation_loss: 0.842632: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:18<00:00,  5.24it/s]


Final loss on validation set: 0.862793
EPOCH: 2


loss: 0.751310, training_loss: 0.799280: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 395/395 [02:22<00:00,  2.78it/s]
loss: 0.740905, validation_loss: 0.769233: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:18<00:00,  5.24it/s]


Final loss on validation set: 0.784941
EPOCH: 3


loss: 0.925840, training_loss: 0.754734: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 395/395 [02:22<00:00,  2.78it/s]
loss: 0.585909, validation_loss: 0.726168: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:18<00:00,  5.26it/s]

Final loss on validation set: 0.742697





632.0