In [None]:
"""
tsl: We will start with our baseline TSL (train stage-2 liver),
lrh: but use starting and ending learning rates half as large. This offsets
    the fact that we use a batch size of 4, instead of the batch size of 8
    used in the cervical spine solution.
efb5: https://huggingface.co/timm/tf_efficientnet_b5.ap_in1k
"""

In [None]:
# Install timm, which provides the CNN backbones used by TimmModel below.
# NOTE(review): unpinned install — pin a version for reproducible runs.
!pip -q install timm

In [None]:
from collections import defaultdict
import os
# import sys
import gc
from fastcore.all import Path
# import ast
# import cv2
import time
import timm 
# import pickle
import random
import warnings
import numpy as np
import pandas as pd
from glob import glob
from PIL import Image
from tqdm import tqdm
import albumentations
from pylab import rcParams
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, StratifiedKFold
import sklearn

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

%matplotlib inline
rcParams['figure.figsize'] = 20, 8
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = True
print('device is', device)

# Config

In [None]:
# --- Inputs and run control ---
IN = Path('../input')
# Directory of a previous run whose `models/*_last.pth` checkpoints run() resumes from.
MODEL_INP_PATH = None
# Number of epochs to train when resuming from MODEL_INP_PATH.
EXTRA_EPOCHS = 0
# When True, run() skips writing the *_last.pth checkpoints.
DEBUG = False

# Organs trained in this run (parallel lists). Disabled here: spleen and kidney
# (healthy/low/high), bowel (healthy/injury).
ORGANS = ['liver']
LABELS = [['liver_healthy', 'liver_low', 'liver_high']]
# [start, end) slice range of each organ inside a stage-1 .npy file.
# The disabled organs used [15, 30], [30, 60], [60, 75].
ABS = [[0, 15]]
N_ORGANS = len(ORGANS)

# --- Experiment naming (used for log and checkpoint file names) ---
# NOTE(review): this name says effv2s / bs8 / lr23e5 / 50ep, which does not match
# the settings below (efficientnet_b5, batch_size 4, halved LR, 15 epochs).
kernel_type = '0920_1bonev2_effv2s_224_15_6ch_augv2_mixupp5_drl3_rov1p2_bs8_lr23e5_eta23e6_50ep'
load_kernel = None
load_last = True

# --- Cross-validation: run() holds `test_fold` out as a test set ---
n_folds = 5
test_fold = 4

# --- Model and input geometry ---
backbone = 'tf_efficientnet_b5.ap_in1k'
n_epochs = 15
image_size = 224
n_slice_per_c = 15  # slices per organ stack fed to the model
in_chans = 6        # channels per slice

# --- Optimisation: LR halved relative to the batch-size-8 baseline ---
init_lr = 23e-5 / 2
eta_min = 23e-6 / 2
batch_size = 4
drop_rate = 0.
drop_rate_last = 0.3
drop_path_rate = 0.
p_mixup = 0.5
p_rand_order_v1 = 0.2  # NOTE(review): not referenced anywhere in this notebook

use_amp = True
num_workers = 4
out_dim = 1  # NOTE(review): not referenced; TimmModel uses its own out_dim default

# --- Output directories ---
log_dir = './logs'
model_dir = './models'
os.makedirs(log_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Per-slice training augmentation; slices are passed in as (H, W, C) arrays.
# NOTE(review): RandomBrightness, Cutout and GaussNoise(var_limit=...) are deprecated
# in recent albumentations releases and may be missing — confirm against the
# installed version, or migrate to RandomBrightnessContrast / CoarseDropout.
_blur_or_noise = albumentations.OneOf([
    albumentations.MotionBlur(blur_limit=3),
    albumentations.MedianBlur(blur_limit=3),
    albumentations.GaussianBlur(blur_limit=3),
    albumentations.GaussNoise(var_limit=(3.0, 9.0)),
], p=0.5)
_distortion = albumentations.OneOf([
    albumentations.OpticalDistortion(distort_limit=1.),
    albumentations.GridDistortion(num_steps=5, distort_limit=1.),
], p=0.5)
_cutout_size = int(image_size * 0.5)  # one hole of up to half the image side

transforms_train = albumentations.Compose([
    albumentations.Resize(image_size, image_size),
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    albumentations.Transpose(p=0.5),
    albumentations.RandomBrightness(limit=0.1, p=0.7),
    # border_mode=4 is cv2.BORDER_REFLECT_101
    albumentations.ShiftScaleRotate(shift_limit=0.3, scale_limit=0.3, rotate_limit=45, border_mode=4, p=0.7),
    _blur_or_noise,
    _distortion,
    albumentations.Cutout(max_h_size=_cutout_size, max_w_size=_cutout_size, num_holes=1, p=0.5),
])

# Validation / test: deterministic resize only.
transforms_valid = albumentations.Compose([
    albumentations.Resize(image_size, image_size),
])

# DataFrame

In [None]:
INPUT = '/kaggle/input/rsna-2023-abdominal-trauma-detection'


def load_df(kind='train'):
    """Build a one-row-per-study table from `<kind>_dicom_tags.parquet`.

    The study id is the parent directory of each DICOM `path`. The first row
    of every study is kept (original index preserved).

    Returns:
        DataFrame with columns StudyInstanceUID, path, PatientID, image_folder
        (INPUT + '/' + the study directory), study, patient (PatientID as read)
        and patient_id (PatientID cast to int).
    """
    tags = pd.read_parquet(os.path.join(INPUT, f'{kind}_dicom_tags.parquet'))
    tags['StudyInstanceUID'] = tags['path'].str.split('/').str[-2]

    studies = tags[['StudyInstanceUID', 'path', 'PatientID']].drop_duplicates('StudyInstanceUID')
    study_dirs = studies['path'].str.split('/').str[:-1].map('/'.join)
    return studies.assign(
        image_folder=INPUT + '/' + study_dirs,
        study=studies['StudyInstanceUID'],
        patient=studies['PatientID'],
        patient_id=studies['PatientID'].astype(int),
    )
# Study-level table from the DICOM tags, plus the per-patient label table.
train_labels_path = os.path.join(INPUT, 'train.csv')
df = load_df(kind='train')
df_train = pd.read_csv(train_labels_path)

In [None]:
# Stage-1 inference outputs: every input dataset whose name contains 's1-inf'.
cls_inp_paths = [x for x in IN.ls() if 's1-inf' in x.stem]
# Fail with a clear message instead of an IndexError on `cls_inp_paths[0]`.
assert cls_inp_paths, f"no 's1-inf' input datasets found under {IN}"
len(cls_inp_paths), cls_inp_paths[0]

In [None]:
# Map each study id (the .npy file stem) to its stage-1 crop file. If the same
# study appears in several input datasets, the later file wins, which matches
# the precedence of the previous per-file `.loc` assignment.
# A single `map` replaces one boolean scan of `df` per file. It also always
# creates the column (NaN where no file exists), so the filter cell below does
# not crash when no files are found.
study_to_inp_path = {}
for p in cls_inp_paths:
    for f in p.ls():
        if f.suffix != '.npy':
            continue
        study_to_inp_path[f.stem] = str(f)
df['cls_inp_path'] = df['study'].map(study_to_inp_path)

In [None]:
# Keep only studies that have a stage-1 input file.
print(len(df))
df = df.loc[df['cls_inp_path'].notna()]
print(len(df), 'size after removing studies with no inputs')

In [None]:
# Attach the per-patient labels (inner join on patient_id); no column may be missing afterwards.
df = pd.merge(df, df_train, on='patient_id')
assert not df.isna().any().any()

In [None]:
# Peek at the label columns of each trained organ.
for organ_name, organ_cols in zip(ORGANS, LABELS):
    print(organ_name)
    display(df[['patient', 'study', *organ_cols]].head(2))

In [None]:
# Seeds the global numpy RNG (used later, e.g. by mixup). StratifiedGroupKFold
# with shuffle=False is deterministic, so the seed does not affect this split.
np.random.seed(0)
# Stratify by any_injury and keep every patient's studies in a single fold.
kf = sklearn.model_selection.StratifiedGroupKFold(n_splits=n_folds)
df['fold'] = 0
# Look the column up by name; don't rely on 'fold' being the last column.
fold_col = df.columns.get_loc('fold')
for fold, (_, test_ind) in enumerate(kf.split(df, df['any_injury'], df['patient'])):
    df.iloc[test_ind, fold_col] = fold

In [None]:
# Per-fold label counts, to sanity-check the stratification. The label columns
# are selected explicitly: `list(df)[-15:]` also picked up the 'fold' column
# itself and broke whenever the column order changed.
fold_label_cols = [c for c in df_train.columns if c != 'patient_id']
df.groupby('fold')[fold_label_cols].sum()

# Dataset

In [None]:
class CLSDataset(Dataset):
    """Per-study dataset of stage-1 crops, split into per-organ slice stacks.

    Each item is a dict keyed by organ name (ORGANS order). Each value holds:
      'images': float tensor of the organ's slices ``image_full[a:b]`` (ABS),
                each passed through ``transform``;
      'labels': float tensor of the organ's label columns (LABELS), repeated
                ``n_slice_per_c`` times -> (n_slice_per_c, n_labels).

    Args:
        df: DataFrame with a 'cls_inp_path' column (one .npy per study) and the
            label columns listed in LABELS.
        mode: 'train' / 'valid' / 'test'; stored but not used inside the class.
        transform: albumentations-style callable, ``transform(image=hwc)['image']``,
            applied to every slice. None disables augmentation.
    """

    def __init__(self, df, mode, transform):
        # drop=True: the old index is not needed and would otherwise become an 'index' column.
        self.df = df.reset_index(drop=True)
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _transform_slice(self, image):
        """Apply self.transform to one (C, H, W) slice and return (C, H, W)."""
        if self.transform is None:
            return image
        image = self.transform(image=image.transpose(1, 2, 0))['image']
        return image.transpose(2, 0, 1)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        image_full = np.load(row.cls_inp_path)
        out = defaultdict(dict)
        for organ, cols, (a, b) in zip(ORGANS, LABELS, ABS):
            # Use this dataset's own transform. Previously the global
            # transforms_train was always applied, so valid/test were augmented.
            images = np.stack([self._transform_slice(image) for image in image_full[a:b]], 0)
            if organ == 'kidney':
                # Stack the first and second 15 slices along height (axis 2),
                # presumably pairing the two kidney crops (cf. h=image_size*2).
                images = np.concatenate((images[:15], images[15:]), 2)
            out[organ]['images'] = torch.tensor(images).float()
            # `row` is object dtype (mixed columns); convert the labels explicitly.
            labels = row[cols].to_numpy(dtype=np.float32)
            out[organ]['labels'] = torch.tensor(np.tile(labels, (n_slice_per_c, 1)))
        return out

In [None]:
rcParams['figure.figsize'] = 20, 8

# Preview dataset over the whole table, with training augmentation.
df_show = df
dataset_show = CLSDataset(df_show, 'train', transform=transforms_train)
loader_show = DataLoader(dataset_show, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [None]:
# Show slice 7 of a few augmented samples. Top row: the first 3 channels as RGB.
# Bottom row: the last channel. There is one column per (sample, organ).
# Previously a second figure replaced the first, and the sample loop only kept
# the last sample.
n_show = 2
f, axarr = plt.subplots(2, n_show * N_ORGANS, squeeze=False)
for p in range(n_show):
    idx = min(p * 20, len(dataset_show) - 1)
    out = dataset_show[idx]
    for i, organ in enumerate(out.keys()):
        col = p * N_ORGANS + i
        print('*******', organ, '*******')
        axarr[0, col].imshow(out[organ]['images'][7][:3].permute(1, 2, 0))
        axarr[1, col].imshow(out[organ]['images'][7][-1])
        axarr[0, col].set_title(f'sample {idx}: {organ}')

# Model

In [None]:
class TimmModel(nn.Module):
    """Slice-sequence classifier: timm CNN per slice -> BiLSTM -> per-slice head.

    Each slice is encoded by the timm backbone. A 2-layer bidirectional LSTM
    (256 hidden units per direction) mixes the features across slices. An MLP
    head (512 -> 256 -> out_dim) then produces logits for every slice.

    Input:  (bs, n_slices, in_chans, h, w)
    Output: (bs, n_slices, out_dim)
    """

    def __init__(self, backbone, pretrained=False, out_dim=3, h=image_size, w=image_size):
        super().__init__()
        self.h = h
        self.w = w
        self.in_chans = in_chans

        self.encoder = timm.create_model(
            backbone,
            in_chans=in_chans,
            num_classes=out_dim,
            features_only=False,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )

        # Strip the classifier so the encoder returns pooled features of width hdim.
        if 'efficient' in backbone:
            hdim = self.encoder.conv_head.out_channels
            self.encoder.classifier = nn.Identity()
        elif 'convnext' in backbone:
            hdim = self.encoder.head.fc.in_features
            self.encoder.head.fc = nn.Identity()
        else:
            # Generic timm fallback. Previously hdim was left unbound here,
            # which raised UnboundLocalError.
            hdim = self.encoder.num_features
            self.encoder.reset_classifier(0)

        self.lstm = nn.LSTM(hdim, 256, num_layers=2, dropout=drop_rate, bidirectional=True, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(512, 256),  # 512 = 2 directions x 256 hidden units
            nn.BatchNorm1d(256),
            nn.Dropout(drop_rate_last),
            nn.LeakyReLU(0.1),
            nn.Linear(256, out_dim),
        )

    def forward(self, x):  # (bs, n_slices, ch, h, w)
        # Take the slice count from the input instead of the global n_slice_per_c.
        bs, n_slices = x.shape[0], x.shape[1]
        x = x.view(bs * n_slices, self.in_chans, self.h, self.w)
        feat = self.encoder(x)
        feat = feat.view(bs, n_slices, -1)
        feat, _ = self.lstm(feat)
        feat = feat.contiguous().view(bs * n_slices, -1)
        feat = self.head(feat)
        return feat.view(bs, n_slices, -1).contiguous()

In [None]:
# Smoke-test the architecture on random input (random init, CPU). The expected
# shape is (2, n_slice_per_c, 3). no_grad avoids keeping activations for a
# backward pass that never happens. `m` is reused below to preview the LR schedule.
m = TimmModel(backbone)
with torch.no_grad():
    smoke_shape = m(torch.rand(2, n_slice_per_c, in_chans, image_size, image_size)).shape
smoke_shape

# Loss & Metric

In [None]:
bce = nn.BCEWithLogitsLoss(reduction='none')


def criterion(logits, targets, activated=False):
    """Row-weighted binary cross-entropy, normalised by the total weight.

    Any leading dims are flattened, so inputs become (-1, n_labels), where
    n_labels = targets.shape[-1]. Rows whose label column 1 equals 1 get weight 2;
    rows whose label column 2 equals 1 get weight 4 (only if those columns exist).

    Args:
        logits: raw scores, or probabilities when ``activated`` is True.
        targets: same shape as ``logits``.
        activated: if True, ``logits`` are already sigmoid outputs.
    Returns:
        0-dim tensor: sum(weight * bce) / sum(weight).
    """
    n_labels = targets.shape[-1]
    logits = logits.reshape(-1, n_labels)
    targets = targets.reshape(-1, n_labels)
    # Both branches keep the (rows, n_labels) shape so the row mask below lines
    # up. The activated branch used to flatten to 1-D and crash on `losses[mask]`.
    if activated:
        losses = nn.BCELoss(reduction='none')(logits, targets)
    else:
        losses = bce(logits, targets)
    # Same device/dtype as the losses. This works for GPU batches and for the
    # CPU-gathered tensors in validation.
    norm = torch.ones_like(losses)
    for i, weight in ((1, 2), (2, 4)):
        if i >= n_labels:
            break
        mask = targets[:, i] == 1
        losses[mask] *= weight
        norm[mask] *= weight
    return losses.sum() / norm.sum()

# Train & Valid func

In [None]:
def mixup(input, truth, clip=[0, 1]):
    """Blend every sample with a randomly permuted partner from the same batch.

    A single coefficient ``lam ~ U(clip[0], clip[1])`` is drawn per call, and
    the result is ``input * lam + input[perm] * (1 - lam)``. Labels are not
    blended; the caller combines the two losses with ``lam``.

    Returns:
        (mixed_input, truth, truth[perm], lam)
    """
    perm = torch.randperm(input.size(0))
    lam = np.random.uniform(clip[0], clip[1])
    mixed = input * lam + input[perm] * (1 - lam)
    return mixed, truth, truth[perm], lam


def train_func(models, loader_train, optimizers, scalers=None):
    """Run one training epoch for every per-organ model.

    Each batch is a dict keyed by organ, and model i trains only on
    ``batch[ORGANS[i]]``. With probability ``p_mixup`` the batch is mixed up,
    and the loss becomes the lam-weighted mix of the losses on both targets.

    Args:
        models, optimizers: parallel lists, one entry per organ (ORGANS order).
        scalers: optional parallel list of GradScaler. A None entry (or
            scalers=None, the default) trains that model in fp32 without
            autocast; previously that path crashed.
    Returns:
        list with the mean training loss per organ.
    """
    if scalers is None:
        scalers = [None] * len(models)
    for model in models:
        model.train()
    train_loss = [[] for _ in range(N_ORGANS)]
    bar = tqdm(loader_train)
    for out in bar:
        for i, (optimizer, organ, scaler, model) in enumerate(zip(optimizers, ORGANS, scalers, models)):
            optimizer.zero_grad()
            images = out[organ]['images'].to(device)
            targets = out[organ]['labels'].to(device)

            do_mixup = random.random() < p_mixup
            if do_mixup:
                images, targets, targets_mix, lam = mixup(images, targets)

            # Mixed precision only when a GradScaler is available to go with it.
            with amp.autocast(enabled=scaler is not None):
                logits = model(images)
                loss = criterion(logits, targets)
                if do_mixup:
                    loss = loss * lam + criterion(logits, targets_mix) * (1 - lam)
            train_loss[i].append(loss.item())

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

        # Smoothed loss over the last 30 batches, per organ.
        bar.set_description(' '.join(
            f'{organ}:{np.mean(tl[-30:]):.4f}' for organ, tl in zip(ORGANS, train_loss) if tl
        ))

    return [np.mean(tl) for tl in train_loss]


def valid_func(models, loader_valid):
    """Evaluate every per-organ model on ``loader_valid``.

    Logits and targets are gathered on CPU, and the loss is computed once over
    the whole set. The result is therefore the dataset-level weighted loss, not
    an average of per-batch losses. (The per-batch losses the previous version
    computed were discarded anyway, so they are no longer computed.)

    Returns:
        list with one float loss per organ (ORGANS order).
    """
    for model in models:
        model.eval()
    gts = [[] for _ in range(N_ORGANS)]
    outputs = [[] for _ in range(N_ORGANS)]
    with torch.no_grad():
        for out in tqdm(loader_valid):
            for i, (organ, model) in enumerate(zip(ORGANS, models)):
                logits = model(out[organ]['images'].to(device))
                # Labels come from the DataLoader on CPU, so no device round trip is needed.
                gts[i].append(out[organ]['labels'].cpu())
                outputs[i].append(logits.cpu())

    outputs = [torch.cat(output) for output in outputs]
    gts = [torch.cat(gt) for gt in gts]
    return [criterion(o, g).item() for o, g in zip(outputs, gts)]

In [None]:
rcParams['figure.figsize'] = 20, 2
# Preview the per-epoch cosine LR curve on a throwaway optimizer over `m`.
# Within the first n_epochs this matches the CosineAnnealingWarmRestarts
# (T_0=n_epochs) schedule used for training.
optimizer = optim.AdamW(m.parameters(), lr=init_lr)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs, eta_min=eta_min)


def _lr_after_step(step_epoch):
    """Move the preview scheduler to `step_epoch` and report the resulting LR."""
    scheduler_cosine.step(step_epoch)
    return optimizer.param_groups[0]["lr"]


lrs = [_lr_after_step(e) for e in range(n_epochs)]
plt.plot(list(range(len(lrs))), lrs)

# Training

In [None]:
def _load_last_checkpoints(fold, models, optimizers, scalers, metric_best, epoch_best):
    """Resume each organ's model/optimizer/scaler from MODEL_INP_PATH.

    Reads `{MODEL_INP_PATH}/models/{organ}_{name}_fold{fold}_last.pth`, where
    name is `load_kernel` when set, otherwise `kernel_type`. Fills metric_best
    and epoch_best in place.

    Returns:
        the epoch to resume from (last saved epoch + 1), or 0 if ORGANS is empty.
    """
    ckpt_name = load_kernel or kernel_type
    epoch_start = 0
    for i, (organ, model, scaler, optimizer) in enumerate(zip(ORGANS, models, scalers, optimizers)):
        # Use this run's fold; the checkpoint name used to hard-code fold0.
        sd_file = f'{MODEL_INP_PATH}/models/{organ}_{ckpt_name}_fold{fold}_last.pth'
        sd = torch.load(sd_file, map_location=torch.device('cpu'))
        # Strip a DataParallel 'module.' prefix if present.
        msd = {k[7:] if k.startswith('module.') else k: v for k, v in sd['model_state_dict'].items()}
        model.load_state_dict(msd, strict=True)
        optimizer.load_state_dict(sd['optimizer_state_dict'])
        if scaler is not None and sd.get('scaler_state_dict') is not None:
            scaler.load_state_dict(sd['scaler_state_dict'])
        metric_best[i] = sd['score_best']
        epoch_best[i] = sd.get('epoch_best', 0)
        epoch_start = sd['epoch'] + 1
    return epoch_start


def run(fold):
    """Train the per-organ models on one CV fold, then report test/valid loss.

    Split: fold ``fold`` is validation, ``test_fold`` is a held-out test set,
    and all remaining folds are training data. After every epoch:
      - appends one log line per organ to `log_dir/{kernel_type}.txt`;
      - saves `*_best.pth` (state_dict) when the validation loss improves;
      - unless DEBUG, saves a resumable `*_last.pth` (model/optimizer/scaler/best).
    The final test/valid printout uses the last-epoch weights, not the best ones.
    """
    log_file = os.path.join(log_dir, f'{kernel_type}.txt')
    model_files = [os.path.join(model_dir, f'{organ}_{kernel_type}_fold{fold}_best.pth') for organ in ORGANS]

    train_ = df[~df['fold'].isin([fold, test_fold])].reset_index(drop=True)
    valid_ = df[df['fold'] == fold].reset_index(drop=True)
    test_ = df[df['fold'] == test_fold].reset_index(drop=True)

    dataset_train = CLSDataset(train_, 'train', transform=transforms_train)
    dataset_valid = CLSDataset(valid_, 'valid', transform=transforms_valid)
    dataset_test = CLSDataset(test_, 'test', transform=transforms_valid)

    loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    loader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # One model per entry in ORGANS (same order). The disabled organs used:
    #   spleen: TimmModel(backbone, pretrained=True)
    #   kidney: TimmModel(backbone, pretrained=True, h=image_size*2)
    #   bowel:  TimmModel(backbone, pretrained=True, out_dim=2)
    models = [
        TimmModel(backbone, pretrained=True),  # liver
    ]
    models = [model.to(device) for model in models]

    optimizers = [optim.AdamW(model.parameters(), lr=init_lr) for model in models]
    scalers = [torch.cuda.amp.GradScaler() if use_amp else None for _ in range(N_ORGANS)]

    metric_best = [np.inf for _ in range(N_ORGANS)]
    epoch_best = [0 for _ in range(N_ORGANS)]

    scheduler_cosines = [torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, n_epochs, eta_min=eta_min) for optimizer in optimizers]

    # A separate local name: assigning to EXTRA_EPOCHS here made it a local
    # variable, so the resume branch raised UnboundLocalError.
    if MODEL_INP_PATH:
        print(len(dataset_train), len(dataset_valid))
        epoch_start = _load_last_checkpoints(fold, models, optimizers, scalers, metric_best, epoch_best)
        print(epoch_start, 'epoch start')
        n_run_epochs = EXTRA_EPOCHS
    else:
        epoch_start, n_run_epochs = 0, n_epochs

    print('next')
    for epoch in range(epoch_start, epoch_start + n_run_epochs):
        print(f'************* EPOCH {epoch} *********************')
        # Step THIS run's schedulers. The old code stepped the global preview
        # scheduler, so the training LR stayed at init_lr. Epochs are 0-based,
        # and CosineAnnealingWarmRestarts rejects negative epochs.
        for scheduler in scheduler_cosines:
            scheduler.step(epoch)
        print(time.ctime(), 'Epoch:', epoch)
        train_losss = train_func(models, loader_train, optimizers, scalers)
        valid_losss = valid_func(models, loader_valid)
        metrics = valid_losss

        for organ, optimizer, train_loss, valid_loss in zip(ORGANS, optimizers, train_losss, valid_losss):
            content = time.ctime() + ' ' + f'Fold {fold}, Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {train_loss:.5f}, valid loss: {valid_loss:.5f}, metric: {(valid_loss):.6f}.'
            print(content)
            with open(log_file, 'a') as appender:
                appender.write(content + '\n')

        for i, (organ, metric, model_file, model, scaler, optimizer) in enumerate(zip(ORGANS, metrics, model_files, models, scalers, optimizers)):
            if metric < metric_best[i]:
                print(f'{organ} metric_best ({metric_best[i]:.6f} --> {metric:.6f}). Saving model ...')
                torch.save(model.state_dict(), model_file)
                metric_best[i] = metric
                epoch_best[i] = epoch

            # Resumable "last" checkpoint for every organ, every epoch.
            if not DEBUG:
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scaler_state_dict': scaler.state_dict() if scaler else None,
                        'score_best': metric_best[i],
                        'epoch_best': epoch_best[i],
                    },
                    model_file.replace('_best', '_last')
                )

    print('test set metric for this run', valid_func(models, loader_test))
    print('valid set metric for this run', valid_func(models, loader_valid))
    # Optimizers and schedulers also hold references to the parameters; drop
    # them too so the GPU memory can actually be released.
    del models, optimizers, scalers, scheduler_cosines
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
run(fold=0)