# 1st Place Solution Training 2.5D Classification Type1

Hi all,

I'm very excited to write this notebook and share a summary of our solution here.

This is a small version of the training of my final models (stage2 type1), using efficientnetv2_s as the backbone and 224x224 as the input size.

After all stage1 models are trained, we can use those models to predict 3D masks for all training samples (2k).

Then we use those predicted masks to crop out all vertebrae (2k * 7 = 14k).

I'll skip the code for predicting the 3D masks and cropping the vertebrae, and have instead uploaded the dataset of cropped vertebrae (https://www.kaggle.com/datasets/haqishen/rsna-cropped-2d-224-0920-2m).

Now let's use this dataset to train a 2.5D classification model with an LSTM (Type1).

**NOTE: The training time is too long for Kaggle kernels so you should run it locally**

To see more details of my solution: https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/discussion/362607

* Train Stage1 Notebook: https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-train-stage1
* Train Stage2 (Type1) Notebook: This notebook
* Train Stage2 (Type2) Notebook: https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-train-stage2-type2
* Inference Notebook: https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-inference


**If you find these notebooks helpful please upvote. Thanks!**

In [None]:
# Quick-run switch. When True, the DataFrame cell keeps only a random sample of
# 16 studies and run() skips writing the "_last" checkpoint.
DEBUG = False

In [None]:
import os
import sys
import gc
import ast
import cv2
import time
import timm
import pickle
import random
import argparse
import warnings
import numpy as np
import pandas as pd
from glob import glob
from PIL import Image
from tqdm import tqdm
import albumentations
from pylab import rcParams
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, StratifiedKFold
import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

%matplotlib inline
# Default matplotlib figure size (width, height) for the plots in this notebook.
rcParams['figure.figsize'] = 20, 8
# Device the model is moved to in run(); hard-coded to CUDA, there is no CPU fallback.
device = torch.device('cuda')
# Enable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

# Config

In [None]:
# Side length (in pixels) that the transforms resize every slice to.
image_size = 224
# Number of slice files CLSDataset loads for each vertebra.
n_slice_per_c = 15
# Directory containing the pre-cropped `.npy` slices read by CLSDataset.
data_dir = '../input/rsna-cropped-2d-224-0920-2m/cropped_2d_224_15_ext0_5ch_0920_2m/cropped_2d_224_15_ext0_5ch_0920_2m'
# When True, run() creates a GradScaler for mixed-precision training.
use_amp = True
# Output folders: run() appends its log under ./logs and saves checkpoints under ./models.
os.makedirs('./logs', exist_ok=True)
os.makedirs('./models', exist_ok=True)

In [None]:
# ---------------------------------------------------------------------------
# Training augmentations, assembled from named groups so that the final
# Compose reads as a summary of the pipeline. The order of transforms matters
# and is kept exactly as listed in the Compose call below.
# NOTE(review): RandomBrightness and Cutout are flagged as deprecated in newer
# albumentations releases - confirm the pinned version before upgrading.
# ---------------------------------------------------------------------------

# Flips and axis transposition, each applied independently with p=0.5.
_flip_ops = [
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    albumentations.Transpose(p=0.5),  # switch X and Y axis
]

# With p=0.5, apply exactly one of the blur / noise transforms.
_blur_or_noise = albumentations.OneOf(
    [
        albumentations.MotionBlur(blur_limit=3),          # motion blur with a random-sized kernel
        albumentations.MedianBlur(blur_limit=3),          # median filter with a random aperture size
        albumentations.GaussianBlur(blur_limit=3),        # Gaussian filter with a random kernel size
        albumentations.GaussNoise(var_limit=(3.0, 9.0)),  # additive Gaussian noise
    ],
    p=0.5,
)

# With p=0.5, apply exactly one non-rigid distortion.
# In medical imaging problems non-rigid transformations can help to augment the data;
# see https://albumentations.ai/docs/examples/example_kaggle_salt/#opticaldistortion
_non_rigid = albumentations.OneOf(
    [
        albumentations.OpticalDistortion(distort_limit=1.),
        albumentations.GridDistortion(num_steps=5, distort_limit=1.),
    ],
    p=0.5,
)

# Side length of the single cut-out hole: half of the image side.
_hole_size = int(image_size * 0.5)

transforms_train = albumentations.Compose([
    albumentations.Resize(image_size, image_size),
    *_flip_ops,
    albumentations.RandomBrightness(limit=0.1, p=0.7),
    # Random affine transform: translate, scale and rotate the input.
    albumentations.ShiftScaleRotate(shift_limit=0.3, scale_limit=0.3, rotate_limit=45, border_mode=4, p=0.7),
    _blur_or_noise,
    _non_rigid,
    # Drop one square region of the image.
    albumentations.Cutout(max_h_size=_hole_size, max_w_size=_hole_size, num_holes=1, p=0.5),
])

# Validation only resizes; no random augmentation.
transforms_valid = albumentations.Compose([albumentations.Resize(image_size, image_size)])

# DataFrame

In [None]:
# Load the per-study table; in DEBUG mode keep only a random sample of 16 studies.
df = pd.read_csv('../input/rsna-cropped-2d-224-0920-2m/train_seg.csv')
if DEBUG:
    df = df.sample(16).reset_index(drop=True)

# Expand every study into seven rows, one per cervical vertebra (C1..C7).
sid, cs, label, fold = [], [], [], []
for _, row in df.iterrows():
    for i in range(1, 8):
        sid.append(row.StudyInstanceUID)
        cs.append(i)
        label.append(row[f'C{i}'])   # label column of vertebra i
        fold.append(row.fold)

df = pd.DataFrame({
    'StudyInstanceUID': sid,
    'c': cs,
    'label': label,
    'fold': fold,
})
# 2018 x 7 samples: each of a patient's 7 cervical vertebrae is its own sample,
# and the label says whether that vertebra is fractured.
df.tail()

# Dataset

In [None]:
class CLSDataset(Dataset):
    '''
    One sample = one vertebra of one study.

    For each sample, `n_slice_per_c` pre-cropped `.npy` files named
    `{StudyInstanceUID}_{c}_{k}.npy` are loaded from `data_dir`, passed through
    `transform` one by one, scaled to [0, 1] and stacked along a new first axis.
    Per the original author's notes, each file is a (224, 224, 6) array made of
    5 neighbouring CT slices plus 1 mask channel, so the stacked result is
    (15, 6, 224, 224) with the default configuration.

    Args:
        df: table with columns `StudyInstanceUID`, `c` and (unless mode is
            'test') `label`.
        mode: 'train', 'valid' or 'test'. 'test' returns images only; 'train'
            additionally shuffles the slice order with probability 0.2.
        transform: albumentations-style callable invoked as
            `transform(image=array)['image']`.
    '''
    def __init__(self, df, mode, transform):
        self.df = df.reset_index()
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        cid = row.c

        def load_slice(k):
            # (H, W, 6) array in 0-255  ->  augmented  ->  (6, H, W) float32 in 0-1
            path = os.path.join(data_dir, f'{row.StudyInstanceUID}_{cid}_{k}.npy')
            augmented = self.transform(image=np.load(path))['image']
            return augmented.transpose(2, 0, 1).astype(np.float32) / 255.

        # (n_slice_per_c, 6, H, W), values in 0-1
        stacked = np.stack([load_slice(k) for k in range(n_slice_per_c)], 0)
        images = torch.tensor(stacked).float()

        if self.mode == 'test':
            return images

        # The vertebra-level label is repeated once per slice.
        labels = torch.tensor([row.label] * n_slice_per_c).float()

        # Training only: with probability 0.2 shuffle the order of the slices.
        if self.mode == 'train' and random.random() < 0.2:
            images = images[torch.randperm(images.size(0))]

        # images: [n_slice_per_c, 6, H, W], labels: [n_slice_per_c]
        return images, labels

In [None]:
# Figure size (width, height) for the preview plots.
rcParams['figure.figsize'] = (20, 8)

# Preview dataset and loader over the whole table, using the training augmentations.
df_show = df
dataset_show = CLSDataset(df_show, 'train', transform=transforms_train)
loader_show = DataLoader(
    dataset_show,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Show four augmented samples (dataset indices 0, 20, 40, 60).
# Top row: first three channels of slice 7 rendered as an RGB image.
# Bottom row: last channel of the same slice (presumably the mask channel - confirm).
f, axarr = plt.subplots(2, 4)
for p, idx in enumerate(range(0, 80, 20)):
    imgs, lbl = dataset_show[idx]
    middle = imgs[7]
    axarr[0, p].imshow(middle[:3].permute(1, 2, 0))
    axarr[1, p].imshow(middle[-1])

# Model

In [None]:
class TimmModel(nn.Module):
    '''
    2.5D classifier: a 2D CNN encoder followed by a bidirectional LSTM.

    Every slice image of a vertebra is encoded separately by the timm backbone
    into one feature vector (1280-d for tf_efficientnetv2_s). The sequence of
    per-slice vectors is then fed through a 2-layer bidirectional LSTM, giving
    one 512-d vector per slice that carries context from the whole sequence.
    A small MLP head turns each 512-d vector into a single logit, so every
    slice of a vertebra is scored as its own sample.

    Args:
        backbone: timm model name. Only names containing 'efficient' or
            'convnext' are supported.
        pretrained: whether timm should load pretrained weights.

    Raises:
        ValueError: if `backbone` is not one of the supported families.
    '''
    def __init__(self, backbone, pretrained=False):
        super(TimmModel, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=6,
            num_classes=1,
            features_only=False,
            drop_rate=0.,
            drop_path_rate=0.,
            pretrained=pretrained
        )

        # Strip the backbone's classifier so the encoder returns pooled features.
        if 'efficient' in backbone:
            hdim = self.encoder.conv_head.out_channels
            # hdim: 1280 for tf_efficientnetv2_s
            self.encoder.classifier = nn.Identity()
        elif 'convnext' in backbone:
            hdim = self.encoder.head.fc.in_features
            self.encoder.head.fc = nn.Identity()
        else:
            # Without this, `hdim` would be unbound and fail with an opaque error below.
            raise ValueError(
                f"Unsupported backbone '{backbone}': expected a name containing "
                "'efficient' or 'convnext'."
            )

        self.lstm = nn.LSTM(hdim, 256, num_layers=2, dropout=0., bidirectional=True, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(512, 256),  # 512 = 2 directions x 256 hidden units
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
            nn.LeakyReLU(0.1),
            nn.Linear(256, 1),
        )

    def forward(self, x):  # e.g. (8, 15, 6, 224, 224)
        '''
        Args:
            x: tensor of shape (batch, n_slices, channels, height, width).

        Returns:
            Logits of shape (batch, n_slices), one per slice.
        '''
        # Take the sizes from the input itself instead of notebook globals.
        bs, n_slices = x.shape[:2]
        x = x.view(bs * n_slices, *x.shape[2:])                           # [120, 6, 224, 224]
        feat = self.encoder(x)                                            # [120, 1280]
        feat = feat.view(bs, n_slices, -1)                                # [8, 15, 1280]
        feat, _ = self.lstm(feat)                                         # [8, 15, 512]
        feat = feat.contiguous().view(bs * n_slices, -1)                  # [120, 512]
        feat = self.head(feat)                                            # [120, 1]
        feat = feat.view(bs, n_slices).contiguous()                       # [8, 15]

        return feat
    

# batch size = 8
# ====================================================================================================
# Layer (type:depth-idx)                             Output Shape              Param #
# ====================================================================================================
# input shape: [8, 15, 6, 224, 224]
# TimmModel                                          [8, 15]                   --
# reshape:  [8, 15, 6, 224, 224]  →  [120, 6, 224, 224]
# ├─EfficientNet: 1-1                                [120, 1280]                --
# │    └─Conv2dSame: 2-1                             [120, 24, 112, 112]        1,296
# │    └─BatchNormAct2d: 2-2                         [120, 24, 112, 112]        48
# │    │    └─Identity: 3-1                          [120, 24, 112, 112]        --
# │    │    └─SiLU: 3-2                              [120, 24, 112, 112]        --
# │    └─Sequential: 2-3                             [120, 256, 7, 7]           --
# │    │    └─Sequential: 3-3                        [120, 24, 112, 112]        10,464
# │    │    └─Sequential: 3-4                        [120, 48, 56, 56]          303,552
# │    │    └─Sequential: 3-5                        [120, 64, 28, 28]          589,184
# │    │    └─Sequential: 3-6                        [120, 128, 14, 14]         917,680
# │    │    └─Sequential: 3-7                        [120, 160, 14, 14]         3,463,840
# │    │    └─Sequential: 3-8                        [120, 256, 7, 7]           14,561,832
# │    └─Conv2d: 2-4                                 [120, 1280, 7, 7]          327,680
# │    └─BatchNormAct2d: 2-5                         [120, 1280, 7, 7]          2,560
# │    │    └─Identity: 3-9                          [120, 1280, 7, 7]          --
# │    │    └─SiLU: 3-10                             [120, 1280, 7, 7]          --
# │    └─SelectAdaptivePool2d: 2-6                   [120, 1280]                --
# │    │    └─AdaptiveAvgPool2d: 3-11                [120, 1280, 1, 1]          --
# │    │    └─Flatten: 3-12                          [120, 1280]                --
# │    └─Identity: 2-7                               [120, 1280]                --
# reshape:  [120, 1280]  →  [8, 15, 1280]   each of the 8 x 15 images of shape [6, 224, 224] is encoded by the encoder into a 1280-d vector
# ├─LSTM: 1-2                                        [8, 15, 512]              4,726,784
# ├─Sequential: 1-3                                  [120, 1]                   --
# │    └─Linear: 2-8                                 [120, 256]                 131,328
# │    └─BatchNorm1d: 2-9                            [120, 256]                 512
# │    └─Dropout: 2-10                               [120, 256]                 --
# │    └─LeakyReLU: 2-11                             [120, 256]                 --
# │    └─Linear: 2-12                                [120, 1]                   257
# ====================================================================================================
# Total params: 25,037,017
# Trainable params: 25,037,017
# Non-trainable params: 0
# Total mult-adds (G): 85.85
# ====================================================================================================
# Input size (MB): 36.13
# Forward/backward pass size (MB): 2926.68
# Params size (MB): 99.53
# Estimated Total Size (MB): 3062.34
# ====================================================================================================

# Loss & Metric

In [None]:
# Element-wise BCE on raw logits; the reduction is done by `criterion` below.
bce = nn.BCEWithLogitsLoss(reduction='none')


def criterion(logits, targets, activated=False):
    """Weighted binary cross-entropy in which positive targets count twice.

    Every element with ``target > 0`` gets weight 2, all others weight 1, and
    the weighted loss sum is divided by the sum of the weights.

    Args:
        logits: model outputs of any shape; raw logits, or probabilities when
            ``activated`` is True.
        targets: tensor with the same number of elements as ``logits``.
        activated: set to True if ``logits`` already went through a sigmoid.

    Returns:
        A 0-dim tensor holding the weighted mean loss, on the same device as
        the inputs.
    """
    flat_logits = logits.view(-1)
    flat_targets = targets.view(-1)

    if activated:
        losses = nn.BCELoss(reduction='none')(flat_logits, flat_targets)
    else:
        losses = bce(flat_logits, flat_targets)

    # Build the weights next to the losses (same device and dtype) instead of on
    # a global device, so the function also works on CPU tensors, and leave
    # `losses` itself unmodified.
    weights = torch.ones_like(losses)
    weights[flat_targets > 0] = 2.

    return (losses * weights).sum() / weights.sum()

# Train & Valid func

In [None]:
def mixup(input, truth, clip=(0, 1)):
    """Blend every sample of a batch with another randomly chosen sample.

    Args:
        input: batch tensor; mixing is done along the first dimension.
        truth: labels aligned with ``input`` along the first dimension.
        clip: ``(low, high)`` bounds from which the mixing coefficient is
            drawn uniformly.

    Returns:
        Tuple ``(mixed_input, truth, shuffled_labels, lam)`` where
        ``mixed_input = input * lam + input[perm] * (1 - lam)``,
        ``truth`` is returned unchanged, ``shuffled_labels = truth[perm]`` and
        ``lam`` is the mixing coefficient.
    """
    # One permutation is shared by inputs and labels so they stay aligned.
    indices = torch.randperm(input.size(0))
    shuffled_input = input[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    input = input * lam + shuffled_input * (1 - lam)
    return input, truth, shuffled_labels, lam


def train_func(model, loader_train, optimizer, scaler=None):
    """Train `model` for one epoch and return the mean batch loss.

    Mixup is applied to a batch with probability 0.5; in that case the loss is
    the `lam`-weighted combination of the losses against the original and the
    shuffled targets.

    Args:
        model: network to train; called as ``model(images)``.
        loader_train: iterable yielding ``(images, targets)`` batches.
        optimizer: optimizer updating the model parameters.
        scaler: optional ``GradScaler``. When given, the forward pass runs
            under autocast and the scaler drives backward and step. When None,
            training runs in full precision without loss scaling.

    Returns:
        Mean of the per-batch training losses as a float.
    """
    model.train()
    mixed_precision = scaler is not None
    train_loss = []
    bar = tqdm(loader_train)
    for images, targets in bar:
        optimizer.zero_grad()
        images = images.cuda()
        targets = targets.cuda()
        # images: [batch, n_slices, 6, H, W]   targets: [batch, n_slices]

        do_mixup = False
        if random.random() < 0.5:
            do_mixup = True
            images, targets, targets_mix, lam = mixup(images, targets)

        # Autocast is only enabled together with a scaler, because half-precision
        # gradients without loss scaling can underflow.
        with amp.autocast(enabled=mixed_precision):
            logits = model(images)
            loss = criterion(logits, targets)
            if do_mixup:
                loss11 = criterion(logits, targets_mix)
                loss = loss * lam  + loss11 * (1 - lam)
        train_loss.append(loss.item())

        if mixed_precision:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # The original code dereferenced `scaler` unconditionally and crashed here.
            loss.backward()
            optimizer.step()

        # Progress bar shows the loss smoothed over the last 30 batches.
        bar.set_description(f'smth:{np.mean(train_loss[-30:]):.4f}')

    return np.mean(train_loss)


def valid_func(model, loader_valid):
    """Evaluate `model` on the validation loader.

    Per-batch losses are only used for the progress bar. The returned value is
    `criterion` evaluated once over all logits and targets of the epoch,
    concatenated on the CPU.

    Args:
        model: network to evaluate; called as ``model(images)``.
        loader_valid: iterable yielding ``(images, targets)`` batches.

    Returns:
        Validation loss over the whole loader as a float.
    """
    model.eval()
    batch_losses, all_targets, all_logits = [], [], []
    progress = tqdm(loader_valid)

    with torch.no_grad():
        for batch_images, batch_targets in progress:
            batch_images = batch_images.cuda()
            batch_targets = batch_targets.cuda()

            batch_logits = model(batch_images)
            batch_loss = criterion(batch_logits, batch_targets)

            # Keep CPU copies so the epoch-level loss can be computed at the end.
            all_targets.append(batch_targets.cpu())
            all_logits.append(batch_logits.cpu())
            batch_losses.append(batch_loss.item())

            # Loss smoothed over the last 30 batches.
            progress.set_description(f'smth:{np.mean(batch_losses[-30:]):.4f}')

    epoch_logits = torch.cat(all_logits)
    epoch_targets = torch.cat(all_targets)
    return criterion(epoch_logits, epoch_targets).item()

In [None]:
# Learning-rate schedule preview
# NOTE(review): this preview uses CosineAnnealingLR while run() builds
# CosineAnnealingWarmRestarts - confirm that the plotted curve matches training.
rcParams['figure.figsize'] = (20, 5)

# A throw-away model / optimizer pair, only needed to drive the scheduler.
m = TimmModel('tf_efficientnetv2_s_in21ft1k')
optimizer = optim.AdamW(m.parameters(), lr=23e-5)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 75, eta_min=23e-6)

# Record the learning rate the optimizer holds at each of the 75 epochs.
lrs = []
for epoch in range(1, 75+1):
    scheduler_cosine.step(epoch-1)
    current_lr = optimizer.param_groups[0]["lr"]
    lrs.append(current_lr)

fig, ax = plt.subplots()
ax.plot(range(len(lrs)), lrs)
ax.set_title('Learning Rate (Epochs: 75   Max: 23e-5   Min: 23e-6)', fontsize=20)
ax.tick_params(labelsize=15)
ax.grid()
_ = ax.set_xticks(np.arange(0, 80, 5))
_ = ax.set_yticks([0, 5e-5, 10e-5, 15e-5, 20e-5, 25e-5])

# Training

In [None]:
def run(fold):
    """Train and validate the model on one cross-validation fold.

    Rows of the global `df` whose `fold` column equals `fold` form the
    validation set; all other rows are used for training. After every epoch a
    line is appended to the log file, the weights with the lowest validation
    loss so far are saved as `*_best.pth`, and (unless DEBUG) a full training
    checkpoint is saved as `*_last.pth`.

    Args:
        fold: value of the `fold` column to hold out for validation.
    """
    # NOTE(review): the name says "50ep" but the loop below runs 75 epochs. The
    # string is kept as-is because it determines the log and checkpoint file names.
    kernel_type = '0920_1bonev2_effv2s_224_15_6ch_augv2_mixupp5_drl3_rov1p2_bs8_lr23e5_eta23e6_50ep'
    log_file = os.path.join('./logs', f'{kernel_type}.txt')
    model_file = os.path.join('./models', f'{kernel_type}_fold{fold}_best.pth')

    train_ = df[df['fold'] != fold].reset_index(drop=True)
    valid_ = df[df['fold'] == fold].reset_index(drop=True)
    dataset_train = CLSDataset(train_, 'train', transform=transforms_train)
    dataset_valid = CLSDataset(valid_, 'valid', transform=transforms_valid)
    # drop_last=True keeps every training batch at full size.
    loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=8, shuffle=True, num_workers=4, drop_last=True)
    loader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=8, shuffle=False, num_workers=4)

    model = TimmModel('tf_efficientnetv2_s_in21ft1k', pretrained=True)
    model = model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=23e-5)
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    # Lowest validation loss seen so far.
    metric_best = np.inf

    n_epochs = 75
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, n_epochs, eta_min=23e-6)

    print(len(dataset_train), len(dataset_valid))

    for epoch in range(1, n_epochs + 1):
        # Set the learning rate for this epoch (the scheduler is indexed from 0).
        scheduler_cosine.step(epoch-1)

        print(time.ctime(), 'Epoch:', epoch)

        train_loss = train_func(model, loader_train, optimizer, scaler)
        valid_loss = valid_func(model, loader_valid)
        metric = valid_loss

        content = time.ctime() + ' ' + f'Fold {fold}, Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {train_loss:.5f}, valid loss: {valid_loss:.5f}, metric: {(metric):.6f}.'
        print(content)
        with open(log_file, 'a') as appender:
            appender.write(content + '\n')

        # Save Best (also in DEBUG mode)
        if metric < metric_best:
            print(f'metric_best ({metric_best:.6f} --> {metric:.6f}). Saving model ...')
            torch.save(model.state_dict(), model_file)
            metric_best = metric

        # Save Last
        if not DEBUG:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict() if scaler else None,
                    'score_best': metric_best,
                },
                model_file.replace('_best', '_last')
            )

    # Drop every object that still references the parameters and collect garbage
    # first; only then can empty_cache() hand the freed GPU memory back.
    del model, optimizer, scheduler_cosine, scaler
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Train the five cross-validation folds one after another.
# The loop variable is deliberately not called `fold`, so that the module-level
# `fold` list built in the DataFrame cell is left untouched.
for fold_id in range(5):
    run(fold_id)