# 1st Place Solution Training 2.5D Classification Type2

Hi all,

I'm very excited to write this notebook and share a summary of our solution here.

This is a small version of the training of my final models (stage 2, type 2), using ConvNeXt-Nano as the backbone and 224x224 images as input.

After all stage 1 models are trained, we use them to predict 3D masks for all training samples (2k).

Then we use those predicted masks to crop out all vertebrae (2k * 7 = 14k).

I'll skip the code for predicting 3D masks and cropping vertebrae, and instead just provide the uploaded dataset of cropped vertebrae (https://www.kaggle.com/datasets/haqishen/rsna-cropped-2d-224-0920-2m).

Now let's use this dataset to train a 2.5D classification model with an LSTM (Type 2).

**NOTE: You should run it locally because it takes too much GPU memory and RAM.**

To see more details of my solution: https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/discussion/362607

* Train Stage1 Notebook: https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-train-stage1
* Train Stage2 (Type1) Notebook: https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-train-stage2-type1
* Train Stage2 (Type2) Notebook: This notebook
* Inference Notebook: https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-inference


**If you find these notebooks helpful, please upvote. Thanks!**

In [None]:
# Smoke-test switch: when True, only a random subset of 16 studies is loaded
# and the per-epoch "last" checkpoints are not written.
DEBUG: bool = False

In [None]:
import os
import sys
import gc
import ast
import cv2
import time
import timm
import pickle
import random
import argparse
import warnings
import numpy as np
import pandas as pd
from glob import glob
from PIL import Image
from tqdm import tqdm
import albumentations
from pylab import rcParams
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, StratifiedKFold
import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

%matplotlib inline
# Notebook-wide plot size and compute setup.
rcParams['figure.figsize'] = 20, 8
device = torch.device('cuda')  # the training code below assumes a CUDA GPU
# Let cuDNN benchmark and cache the fastest conv algorithms; this pays off
# most when input shapes stay the same from batch to batch.
torch.backends.cudnn.benchmark = True

# Config

In [None]:
# Experiment tag: it names the log file and the checkpoints written by run().
# NOTE(review): the tag mentions bs4 / lr6e5 / eta6e6, but the training code
# uses batch size 8, lr 23e-5 and eta_min 23e-6. The tag is left unchanged so
# existing file names stay stable.
kernel_type = '0920_2d_lstmv22headv2_convnn_224_15_6ch_8flip_augv2_drl3_rov1p2_rov3p2_bs4_lr6e5_eta6e6_lw151_50ep'
image_size = 224      # side length every crop is resized to
n_slice_per_c = 15    # crops loaded per vertebra (7 vertebrae per study)
data_dir = '../input/rsna-cropped-2d-224-0920-2m/cropped_2d_224_15_ext0_5ch_0920_2m/cropped_2d_224_15_ext0_5ch_0920_2m'
use_amp = True        # train with mixed precision (GradScaler)
for _out_dir in ('./logs', './models'):
    os.makedirs(_out_dir, exist_ok=True)

In [None]:
# Training-time augmentation. Compose applies the steps in this order, each
# with its own probability p.
transforms_train = albumentations.Compose([
    albumentations.Resize(image_size, image_size),
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    albumentations.Transpose(p=0.5),  # swap the X and Y axes
    # NOTE(review): RandomBrightness and Cutout (below) are deprecated/removed in
    # newer albumentations releases (RandomBrightnessContrast / CoarseDropout
    # replace them) — confirm against the installed version.
    albumentations.RandomBrightness(limit=0.1, p=0.7),
    # Random affine transform: shift, scale and rotate the input.
    # border_mode=4 is OpenCV's BORDER_REFLECT_101.
    albumentations.ShiftScaleRotate(shift_limit=0.3, scale_limit=0.3, rotate_limit=45, border_mode=4, p=0.7),
    
    # With probability 0.5, apply exactly one of the following blur/noise ops.
    albumentations.OneOf([
        albumentations.MotionBlur(blur_limit=3),          # motion blur with a random-sized kernel
        albumentations.MedianBlur(blur_limit=3),          # median filter with a random aperture size
        albumentations.GaussianBlur(blur_limit=3),        # Gaussian filter with a random kernel size
        albumentations.GaussNoise(var_limit=(3.0, 9.0)),  # additive Gaussian noise
    ], p=0.5),
    
    # Non-rigid distortions, a common augmentation for medical images. With
    # probability 0.5, exactly one of OpticalDistortion / GridDistortion is
    # applied (ElasticTransform is not used here). See:
    # https://albumentations.ai/docs/examples/example_kaggle_salt/#opticaldistortion
    albumentations.OneOf([
        albumentations.OpticalDistortion(distort_limit=1.),
        albumentations.GridDistortion(num_steps=5, distort_limit=1.),
    ], p=0.5),

    # Drop out a single rectangular hole of at most half the image side in
    # each dimension.
    albumentations.Cutout(max_h_size=int(image_size * 0.5), max_w_size=int(image_size * 0.5), num_holes=1, p=0.5),
])

# Validation: resize only, no augmentation.
transforms_valid = albumentations.Compose([
    albumentations.Resize(image_size, image_size),
])

# DataFrame

In [None]:
# One row per study (StudyInstanceUID, fold, C1..C7 labels, ...).
df = pd.read_csv('../input/rsna-cropped-2d-224-0920-2m/train_seg.csv')
if DEBUG:
    # Smoke-test mode: keep only 16 randomly chosen studies.
    df = df.sample(16).reset_index(drop=True)
df.head()

# Dataset

In [None]:
class CLSDataset(Dataset):
    """Per-study dataset of cropped 2.5D vertebra slices.

    For the study in one row of ``df`` it loads ``n_slices`` crops for each of
    the 7 cervical vertebrae C1..C7, i.e. ``7 * n_slices`` arrays read from
    ``{data_root}/{StudyInstanceUID}_{cid}_{ind}.npy`` (cid in 1..7,
    ind in 0..n_slices-1).

    Each crop is expected in HWC layout with values in 0-255 (in this notebook
    224x224x6). It is passed through ``transform(image=...)['image']``, moved
    to CHW, converted to float32 and scaled into [0, 1].

    In ``'train'`` mode two independent augmentations may fire, each with
    probability 0.2: the vertebra order is permuted, and all slices are
    permuted. Labels are always permuted together with the images.

    Args:
        df: one row per study; needs ``StudyInstanceUID`` and, unless
            ``mode == 'test'``, label columns ``C1``..``C7``.
        mode: ``'train'`` enables the shuffling augmentations; ``'test'``
            returns images only; any other value (e.g. ``'valid'``) returns
            images and labels without shuffling.
        transform: albumentations-style callable.
        data_root: directory holding the ``.npy`` crops; defaults to the
            module-level ``data_dir``.
        n_slices: crops per vertebra; defaults to the module-level
            ``n_slice_per_c``.

    ``__getitem__`` returns a float tensor ``(7 * n_slices, C, H, W)`` and, if
    not in test mode, a float tensor ``(7 * n_slices,)`` of labels in which
    every slice carries its vertebra's label.
    """

    def __init__(self, df, mode, transform, data_root=None, n_slices=None):
        self.df = df.reset_index()
        self.mode = mode
        self.transform = transform
        self._data_root = data_root
        self._n_slices = n_slices

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        # Fall back to the notebook-level config when no override was given.
        n_slices = self._n_slices if self._n_slices is not None else n_slice_per_c
        root = self._data_root if self._data_root is not None else data_dir

        # Vertebra order (0-based ids); randomly permuted 20% of the time in training.
        order = list(range(7))
        if self.mode == 'train' and random.random() < 0.2:
            random.shuffle(order)

        images = []
        for cid in order:
            for ind in range(n_slices):
                filepath = os.path.join(root, f'{row.StudyInstanceUID}_{cid+1}_{ind}.npy')
                image = self.transform(image=np.load(filepath))['image']
                # HWC in 0-255 -> CHW float32 in [0, 1]
                images.append(image.transpose(2, 0, 1).astype(np.float32) / 255.)
        # from_numpy wraps the stacked array without copying the whole
        # (7 * n_slices, C, H, W) block a second time.
        images = torch.from_numpy(np.stack(images, 0)).float()

        if self.mode == 'test':
            return images

        # One label per vertebra, repeated for each of its slices, in the
        # same (possibly shuffled) order as the images.
        labels = []
        for value in row[[f'C{cid+1}' for cid in order]].tolist():
            labels += [value] * n_slices
        labels = torch.tensor(labels).float()

        if self.mode == 'train' and random.random() < 0.2:
            # Occasionally shuffle individual slices too, keeping pairs aligned.
            perm = torch.randperm(images.size(0))
            images = images[perm]
            labels = labels[perm]

        return images, labels

In [None]:
# Preview dataset/loader over the full dataframe, with training augmentation.
rcParams['figure.figsize'] = (20, 8)
df_show = df
dataset_show = CLSDataset(df_show, 'train', transform=transforms_train)
loader_show = DataLoader(
    dataset_show,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Show slice 35 of four studies (C3, slice 5 when the vertebra order was not
# shuffled): the top row renders channels 0-2 as RGB, the bottom row renders
# the last channel in grayscale.
rcParams['figure.figsize'] = 16, 8
f, axarr = plt.subplots(2, 4)
for p in range(4):
    # Wrap the index so the preview also works on the 16-study DEBUG subset,
    # where p * 20 would run past the end of the dataset.
    idx = (p * 20) % len(dataset_show)
    imgs, lbl = dataset_show[idx]
    axarr[0, p].imshow(imgs[35][:3].permute(1, 2, 0))
    axarr[1, p].imshow(imgs[35][-1], cmap='gray')

# Model

In [None]:
class TimmModelType2(nn.Module):
    """2.5D classifier: a shared timm CNN encoder per slice plus two BiLSTM heads.

    ``forward`` takes ``x`` of shape ``(bs, seq, 6, H, W)``. Every slice is
    encoded independently into a ``hdim`` feature vector. The resulting
    ``(bs, seq, hdim)`` sequence feeds two branches:

    * ``lstm`` + ``head``: one logit per slice, returned as ``(bs * seq, 1)``;
    * ``lstm2`` + ``head2``: one logit per sequence, read from the first time
      step of ``lstm2``, returned as ``(bs, 1)``.

    Only EfficientNet-style (``conv_head``/``classifier``) and ConvNeXt-style
    (``head.fc``) timm backbones are supported; anything else raises
    ``ValueError``.
    """

    def __init__(self, backbone, pretrained=False):
        super().__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=6,
            num_classes=1,
            features_only=False,
            drop_rate=0.,
            drop_path_rate=0.,
            pretrained=pretrained,
        )

        # Replace the classifier with Identity so the encoder returns pooled
        # features of size hdim.
        if 'efficient' in backbone:
            hdim = self.encoder.conv_head.out_channels
            self.encoder.classifier = nn.Identity()
        elif 'convnext' in backbone:
            hdim = self.encoder.head.fc.in_features
            self.encoder.head.fc = nn.Identity()
        else:
            raise ValueError(
                f'Unsupported backbone {backbone!r}: expected an EfficientNet or ConvNeXt variant'
            )

        # Per-slice branch.
        self.lstm = nn.LSTM(hdim, 256, num_layers=2, dropout=0., bidirectional=True, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(512, 256),  # 512 = 2 directions * 256 hidden units
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
            nn.LeakyReLU(0.1),
            nn.Linear(256, 1),
        )
        # Per-sequence branch.
        self.lstm2 = nn.LSTM(hdim, 256, num_layers=2, dropout=0., bidirectional=True, batch_first=True)
        self.head2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
            nn.LeakyReLU(0.1),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        # Sizes come from the input itself, e.g. (8, 7 * 15, 6, 224, 224) in this notebook.
        bs, seq = x.shape[:2]
        x = x.view(bs * seq, *x.shape[2:])          # (bs * seq, C, H, W)
        feat = self.encoder(x)                      # (bs * seq, hdim)
        feat = feat.view(bs, seq, -1)               # (bs, seq, hdim)
        feat1, _ = self.lstm(feat)                  # (bs, seq, 512)
        feat1 = feat1.contiguous().view(bs * seq, 512)
        feat2, _ = self.lstm2(feat)                 # (bs, seq, 512)
        # NOTE(review): in a bidirectional LSTM the forward half of step 0 has
        # only seen the first slice; only the backward half summarizes the
        # whole sequence.
        return self.head(feat1), self.head2(feat2[:, 0])  # (bs * seq, 1), (bs, 1)
    

# batch size = 8
# ====================================================================================================
# Layer (type:depth-idx)                             Output Shape              Param #
# ====================================================================================================
# TimmModelType2                                     [840, 1]                  --
# ├─ConvNeXt: 1-1                                    [840, 640]                --
# │    └─Sequential: 2-1                             [840, 80, 56, 56]         --
# │    │    └─Conv2d: 3-1                            [840, 80, 56, 56]         7,760
# │    │    └─LayerNorm2d: 3-2                       [840, 80, 56, 56]         160
# │    └─Sequential: 2-2                             [840, 640, 7, 7]          --
# │    │    └─ConvNeXtStage: 3-3                     [840, 80, 56, 56]         111,680
# │    │    └─ConvNeXtStage: 3-4                     [840, 160, 28, 28]        479,680
# │    │    └─ConvNeXtStage: 3-5                     [840, 320, 14, 14]        6,907,520
# │    │    └─ConvNeXtStage: 3-6                     [840, 640, 7, 7]          7,448,320
# │    └─Identity: 2-3                               [840, 640, 7, 7]          --
# │    └─Sequential: 2-4                             --                        --
# │    │    └─SelectAdaptivePool2d: 3-7              [840, 640, 1, 1]          --
# │    │    └─LayerNorm2d: 3-8                       [840, 640, 1, 1]          1,280
# │    │    └─Flatten: 3-9                           [840, 640]                --
# │    │    └─Dropout: 3-10                          [840, 640]                --
# │    │    └─Identity: 3-11                         [840, 640]                --
# ├─LSTM: 1-2                                        [8, 105, 512]             3,416,064
# ├─LSTM: 1-3                                        [8, 105, 512]             3,416,064
# ├─Sequential: 1-4                                  [840, 1]                  --
# │    └─Linear: 2-5                                 [840, 256]                131,328
# │    └─BatchNorm1d: 2-6                            [840, 256]                512
# │    └─Dropout: 2-7                                [840, 256]                --
# │    └─LeakyReLU: 2-8                              [840, 256]                --
# │    └─Linear: 2-9                                 [840, 1]                  257
# ├─Sequential: 1-5                                  [8, 1]                    --
# │    └─Linear: 2-10                                [8, 256]                  131,328
# │    └─BatchNorm1d: 2-11                           [8, 256]                  512
# │    └─Dropout: 2-12                               [8, 256]                  --
# │    └─LeakyReLU: 2-13                             [8, 256]                  --
# │    └─Linear: 2-14                                [8, 1]                    257
# ====================================================================================================
# Total params: 22,052,722
# Trainable params: 22,052,722
# Non-trainable params: 0
# Total mult-adds (T): 2.08
# ====================================================================================================
# Input size (MB): 1011.55
# Forward/backward pass size (MB): 69769.34
# Params size (MB): 88.19
# Estimated Total Size (MB): 70869.08
# ====================================================================================================

# Loss & Metric

In [None]:
bce = nn.BCEWithLogitsLoss(reduction='none')


def criterion(logits, targets, activated=False):
    """Positive-weighted binary cross-entropy used for both model heads.

    Elements whose target is > 0 get weight 2 and all others weight 1. The
    result is the weighted mean ``sum(w * loss) / sum(w)``.

    Args:
        logits: predictions of any shape; flattened before use.
        targets: targets with the same number of elements as ``logits``.
        activated: if True, ``logits`` are already probabilities in [0, 1] and
            plain BCE is used instead of BCE-with-logits.

    Returns:
        A 0-dim tensor holding the weighted mean loss.
    """
    logits = logits.reshape(-1)
    targets = targets.reshape(-1)
    if activated:
        losses = F.binary_cross_entropy(logits, targets, reduction='none')
    else:
        losses = bce(logits, targets)
    # Build the weights on the losses' own device/dtype so the loss works on
    # any device rather than only on the notebook's global `device`.
    weights = torch.ones_like(losses)
    weights[targets > 0] = 2.
    return (losses * weights).sum() / weights.sum()

# Train & Valid func

In [None]:
def mixup(input, truth, clip=(0, 1)):
    """Blend each sample with a randomly chosen partner from the same batch.

    Args:
        input: batch tensor whose dim 0 is the batch dimension.
        truth: targets aligned with ``input`` along dim 0.
        clip: ``(low, high)`` range from which the mixing weight ``lam`` is
            drawn uniformly. A tuple default is used to avoid a shared
            mutable default; any indexable pair is accepted.

    Returns:
        ``(mixed, truth, shuffled_truth, lam)``, where
        ``mixed = lam * input + (1 - lam) * input[perm]`` and
        ``shuffled_truth = truth[perm]``. ``truth`` itself is returned
        unchanged; the caller mixes the two losses with ``lam``.
    """
    # `input` shadows the builtin but is kept for keyword-argument callers.
    indices = torch.randperm(input.size(0))
    shuffled_input = input[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    mixed = input * lam + shuffled_input * (1 - lam)
    return mixed, truth, shuffled_labels, lam


def train_func(model, loader_train, optimizer, scaler=None):
    """Train ``model`` for one epoch and return the mean combined loss.

    Batches come from CLSDataset: ``images`` of shape (bs, seq, C, H, W) and
    per-slice ``targets`` of shape (bs, seq). The per-sequence target for the
    second head is ``targets.max(1).values``.

    Per batch the loss is ``(15 * slice_loss + sequence_loss) / 16``. With
    probability 0.5 the batch is mixed up, and the loss becomes
    ``lam * loss(targets) + (1 - lam) * loss(partner_targets)``.

    Args:
        model: module returning ``(per_slice_logits, per_sequence_logits)``.
        loader_train: training DataLoader.
        optimizer: optimizer over ``model``'s parameters.
        scaler: a ``GradScaler`` to train with mixed precision (autocast +
            scaled backward). If None, training runs in full precision with a
            plain ``backward()`` / ``optimizer.step()``.

    Returns:
        The mean of the per-batch combined losses.
    """
    model.train()
    train_loss = []
    train_loss1 = []
    train_loss2 = []
    # Autocast is tied to the scaler: fp16 gradients without loss scaling
    # risk underflow.
    use_scaler = scaler is not None
    bar = tqdm(loader_train)
    for images, targets in bar:
        optimizer.zero_grad()
        images = images.cuda()
        targets = targets.cuda()

        do_mixup = random.random() < 0.5
        if do_mixup:
            images, targets, targets_mix, lam = mixup(images, targets)

        with amp.autocast(enabled=use_scaler):
            logits, logits2 = model(images)
            # logits: (bs * seq, 1)   logits2: (bs, 1)
            loss1 = criterion(logits, targets)
            loss2 = criterion(logits2, targets.max(1).values)
            loss = (loss1 * 15 + loss2 * 1) / 16
            if do_mixup:
                loss11 = criterion(logits, targets_mix)
                loss22 = criterion(logits2, targets_mix.max(1).values)
                loss = loss * lam + (loss11 * 15 + loss22 * 1) / 16 * (1 - lam)
        train_loss1.append(loss1.item())
        train_loss2.append(loss2.item())
        train_loss.append(loss.item())

        if use_scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Running means over the last 30 batches of the unmixed head losses.
        bar.set_description(f'smth:{np.mean(train_loss1[-30:]):.4f} {np.mean(train_loss2[-30:]):.4f}')

    return np.mean(train_loss)


def valid_func(model, loader_valid):
    """Evaluate ``model`` on ``loader_valid`` and return the mean validation loss.

    Runs in eval mode without gradients. Per batch the loss is the unweighted
    average of the per-slice loss (first head) and the per-sequence loss
    (second head, target ``targets.max(1).values``).

    NOTE(review): training weights these two losses 15:1, but this unweighted
    average is what run() uses for best-checkpoint selection — confirm this is
    intended.

    Returns:
        The mean of the per-batch validation losses.
    """
    model.eval()
    valid_loss = []
    valid_loss1 = []
    valid_loss2 = []
    bar = tqdm(loader_valid)
    with torch.no_grad():
        for images, targets in bar:
            images = images.cuda()
            targets = targets.cuda()

            logits, logits2 = model(images)
            loss1 = criterion(logits, targets)
            loss2 = criterion(logits2, targets.max(1).values)
            loss = (loss1 + loss2) / 2.
            valid_loss1.append(loss1.item())
            valid_loss2.append(loss2.item())
            valid_loss.append(loss.item())
            bar.set_description(f'smth:{np.mean(valid_loss1[-30:]):.4f} {np.mean(valid_loss2[-30:]):.4f}')

    return np.mean(valid_loss)

In [None]:
rcParams['figure.figsize'] = 20, 6
# Preview the learning-rate schedule: cosine decay from 23e-5 towards 23e-6
# over 50 epochs, stepped once per epoch. (run() uses
# CosineAnnealingWarmRestarts with T_0=50, which traces the same curve over
# these 50 epochs.) Only the schedule is needed, so the optimizer wraps a
# single dummy parameter instead of a pretrained ~22M-parameter model, which
# would also trigger a weight download.
_lr_probe = nn.Parameter(torch.zeros(1))
optimizer = optim.AdamW([_lr_probe], lr=23e-5)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 50, eta_min=23e-6)

lrs = []
for epoch in range(1, 50 + 1):
    scheduler_cosine.step(epoch - 1)
    lrs.append(optimizer.param_groups[0]["lr"])
plt.plot(range(len(lrs)), lrs)
plt.title('Learning Rate (Epochs: 50   Max: 23e-5   Min: 23e-6)', fontsize=20)
plt.tick_params(labelsize=15)
plt.grid()
_ = plt.xticks(np.arange(0, 55, 5))
_ = plt.yticks([0, 5e-5, 10e-5, 15e-5, 20e-5, 25e-5])

# Training

In [None]:
def run(fold, *, n_epochs=50, init_lr=23e-5, eta_min=23e-6, batch_size=8, num_workers=4):
    """Train and checkpoint one cross-validation fold.

    Rows of ``df`` whose ``fold`` column equals ``fold`` form the validation
    set; all other rows are used for training. Every epoch:

    * appends a line to ``./logs/{kernel_type}.txt``;
    * saves the model to ``./models/{kernel_type}_fold{fold}_best.pth`` when
      the validation loss improves;
    * unless ``DEBUG``, rewrites ``..._last.pth`` with the epoch, the model,
      optimizer and scaler states, and the best score.

    Args:
        fold: index of the held-out fold.
        n_epochs: number of epochs; also the length of the cosine LR cycle.
        init_lr: AdamW learning rate at the start of the cycle.
        eta_min: floor the cosine schedule decays towards.
        batch_size: studies per batch for both loaders.
        num_workers: DataLoader worker processes.
    """
    log_file = os.path.join('./logs', f'{kernel_type}.txt')
    model_file = os.path.join('./models', f'{kernel_type}_fold{fold}_best.pth')

    train_ = df[df['fold'] != fold].reset_index(drop=True)
    valid_ = df[df['fold'] == fold].reset_index(drop=True)
    dataset_train = CLSDataset(train_, 'train', transform=transforms_train)
    dataset_valid = CLSDataset(valid_, 'valid', transform=transforms_valid)
    loader_train = torch.utils.data.DataLoader(
        dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    loader_valid = torch.utils.data.DataLoader(
        dataset_valid, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model = TimmModelType2('convnext_nano', pretrained=True)
    model = model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=init_lr)
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    metric_best = np.inf

    # With T_0 == n_epochs and step(epoch - 1) in [0, n_epochs), no restart
    # ever happens: this is one cosine decay from init_lr towards eta_min.
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, n_epochs, eta_min=eta_min)

    print(len(dataset_train), len(dataset_valid))

    for epoch in range(1, n_epochs + 1):
        scheduler_cosine.step(epoch - 1)
        print(time.ctime(), 'Epoch:', epoch)

        train_loss = train_func(model, loader_train, optimizer, scaler)
        valid_loss = valid_func(model, loader_valid)
        metric = valid_loss

        content = time.ctime() + ' ' + f'Fold {fold}, Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {train_loss:.5f}, valid loss: {valid_loss:.5f}, metric: {(metric):.6f}.'
        print(content)
        with open(log_file, 'a') as appender:
            appender.write(content + '\n')

        if metric < metric_best:
            print(f'metric_best ({metric_best:.6f} --> {metric:.6f}). Saving model ...')
            torch.save(model.state_dict(), model_file)
            metric_best = metric

        # Save the last state so training can be resumed.
        if not DEBUG:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict() if scaler else None,
                    'score_best': metric_best,
                },
                model_file.replace('_best', '_last')
            )

    # The optimizer (and the scheduler, through it) still reference the model
    # parameters, so drop them all, then collect before releasing cached CUDA
    # memory.
    del model, optimizer, scaler, scheduler_cosine
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Train the five cross-validation folds one after another.
for fold in range(5):
    run(fold)