# 1st Place Solution Training 3D Semantic Segmentation (Stage1)

Hi all,

I'm very excited to share this notebook and the summary of our solution here.

This is the FULL version of the training code for my final models (stage 1), using resnet18d as the backbone, a U-Net as the decoder, and 128x128x128 volumes as input.

NOTE: **You need to run this code locally because the RAM is not enough here.**

NOTE2: **It is highly recommended to pre-process the 3D semantic segmentation training data first and save it locally, which can greatly speed up the loading of the data.**

My brief summary of the winning solution: https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/discussion/362607

* Train Stage1 Notebook: This notebook
* Train Stage2 (Type1) Notebook: https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-train-stage2-type1
* Train Stage2 (Type2) Notebook: https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-train-stage2-type2
* Inference Notebook: https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-inference

**If you find these notebooks helpful, please upvote. Thanks!**

In [None]:
# Global debug switch; when True, `run` skips writing the per-epoch "_last" checkpoint.
DEBUG = False

import os
import sys

# Give the bundled Conv3dSame package priority over everything else on the import path.
sys.path = ['../input/covn3d-same', *sys.path]

In [None]:
import os
import sys
import gc
import ast
import cv2
import time
import timm
import pickle
import random
import pydicom
import argparse
import warnings
import numpy as np
import pandas as pd
from glob import glob
import nibabel as nib
from PIL import Image
from tqdm import tqdm
import albumentations
from pylab import rcParams
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from sklearn.model_selection import KFold, StratifiedKFold
import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from monai.transforms import Resize
import  monai.transforms as transforms

%matplotlib inline
# Default (width, height) in inches for every matplotlib figure created below.
rcParams['figure.figsize'] = 20, 8
# Device the model is moved to in `run`; the train/valid loops also call `.cuda()` directly.
device = torch.device('cuda')
# Let cuDNN benchmark convolution algorithms; input shapes are fixed by `image_sizes`.
torch.backends.cudnn.benchmark = True

# Config

In [None]:
# Experiment tag; used to name the log file and the checkpoint files in `run`.
# NOTE(review): the tag encodes "lr3e4" while the optimizers below are built with lr=3e-3 — confirm which is intended.
kernel_type = 'timm3d_res18d_unet4b_128_128_128_dsv2_flip12_shift333p7_gd1p5_bs4_lr3e4_20x50ep'
# Target volume size; also drives the in-plane resize and the number of slices sampled per study.
image_sizes = [128, 128, 128]
# MONAI resize transform, applied to the segmentation masks in `load_sample`.
R = Resize(image_sizes)
# Root of the competition data (train.csv, train_images/, segmentations/).
data_dir = '../input/rsna-2022-cervical-spine-fracture-detection'
# Output folders for training logs and checkpoints.
os.makedirs('./logs', exist_ok=True)
os.makedirs('./models', exist_ok=True)

In [None]:
# Training augmentations, applied jointly to the "image" and "mask" entries of a sample dict:
#   * random flips along spatial axes 1 and 2, each with probability 0.5
#   * random translation of up to int(0.3 * size) voxels per axis with zero padding (p=0.7);
#     no interpolation `mode` is passed, so the library default applies to both keys
#   * mild random grid distortion in [-0.01, 0.01] with nearest-neighbour sampling (p=0.5)
transforms_train = transforms.Compose([
    transforms.RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=1),
    transforms.RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=2),
    transforms.RandAffined(keys=["image", "mask"], translate_range=[int(x*y) for x, y in zip(image_sizes, [0.3, 0.3, 0.3])], padding_mode='zeros', prob=0.7),
    transforms.RandGridDistortiond(keys=("image", "mask"), prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest"),    
])

# Validation uses no augmentation: an empty Compose is given no transforms to apply.
transforms_valid = transforms.Compose([
])

# DataFrame

In [None]:
# Study-level table from the competition CSV (one row per StudyInstanceUID).
df_train = pd.read_csv(os.path.join(data_dir, 'train.csv'))

# One row per segmentation file found on disk.
mask_files = os.listdir(f'{data_dir}/segmentations')
df_mask = pd.DataFrame({
    'mask_file': mask_files,
})
# Drop the 4-character file extension to recover the study UID, then expand to a full path.
df_mask['StudyInstanceUID'] = df_mask['mask_file'].apply(lambda x: x[:-4])
df_mask['mask_file'] = df_mask['mask_file'].apply(lambda x: os.path.join(data_dir, 'segmentations', x))

# Left join keeps every study; studies without a segmentation get NaN in `mask_file`.
df = df_train.merge(df_mask, on='StudyInstanceUID', how='left')
df['image_folder'] = df['StudyInstanceUID'].apply(lambda x: os.path.join(data_dir, 'train_images', x))
# Assign the result back instead of calling `fillna(..., inplace=True)` on a column selection:
# the chained in-place form does not update `df` when pandas copy-on-write is active.
df['mask_file'] = df['mask_file'].fillna('')

# Only studies that have a segmentation mask are used for stage-1 training.
df_seg = df.query('mask_file != ""').reset_index(drop=True)

# Unshuffled 5-fold split; `fold` holds the index of the fold in which a row is validation data.
kf = KFold(5)
df_seg['fold'] = -1
for fold, (train_idx, valid_idx) in enumerate(kf.split(df_seg, df_seg)):
    df_seg.loc[valid_idx, 'fold'] = fold

df_seg.tail()

# Dataset

In [None]:
# StudyInstanceUIDs whose cached mask is flipped along its last axis when loaded by
# `SEGDataset.__getitem__`.
# NOTE(review): presumably these masks are stored in reverse slice order relative to the
# images — confirm against the raw data.
revert_list = [
    '1.2.826.0.1.3680043.1363',
    '1.2.826.0.1.3680043.20120',
    '1.2.826.0.1.3680043.2243',
    '1.2.826.0.1.3680043.24606',
    '1.2.826.0.1.3680043.32071'
]

In [None]:
def load_dicom(path):
    """Read one DICOM slice and resize it in-plane.

    Args:
        path: Path to a single DICOM file.

    Returns:
        The slice's pixel array, bilinearly resized with
        ``dsize=(image_sizes[0], image_sizes[1])``.
    """
    # `pydicom.read_file` is a deprecated alias of `dcmread` and is gone in pydicom 3.0.
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    data = cv2.resize(data, (image_sizes[0], image_sizes[1]), interpolation=cv2.INTER_LINEAR)
    return data


def load_dicom_line_par(path):
    """Build a fixed-size uint8 volume from one study's DICOM slices.

    The slice files in ``path`` are sorted by the integer in their file name,
    ``image_sizes[2]`` of them are picked at evenly spaced positions, each is
    resized by ``load_dicom`` and they are stacked along the last axis. The
    volume is then min-max normalised to the 0-255 range.

    Args:
        path: Study folder, e.g.
            '../input/rsna-2022-cervical-spine-fracture-detection/train_images/1.2.826.0.1.3680043.26990',
            containing files named '1.dcm', '2.dcm', ...

    Returns:
        ``np.uint8`` array with the slice direction as the last axis
        ((128, 128, 128) with the default ``image_sizes``).

    Raises:
        FileNotFoundError: If ``path`` contains no files.
    """
    # Sort numerically by slice number; os.path.basename keeps this correct on any OS
    # (splitting on '/' breaks with Windows path separators).
    t_paths = sorted(
        glob(os.path.join(path, "*")),
        key=lambda x: int(os.path.basename(x).split(".")[0]),
    )

    n_scans = len(t_paths)
    if n_scans == 0:
        # Fail with a clear message instead of an obscure indexing error further down.
        raise FileNotFoundError(f'No slice files found in {path!r}')

    # Evenly spaced slice indices covering the whole study (first and last slice included).
    indices = np.quantile(list(range(n_scans)), np.linspace(0., 1., image_sizes[2])).round().astype(int)
    t_paths = [t_paths[i] for i in indices]

    images = [load_dicom(filename) for filename in t_paths]
    # Stack with the slice direction last. Promote to float64 before subtracting the minimum
    # so a wide value range stored in a 16-bit integer array cannot wrap around.
    images = np.stack(images, -1).astype(np.float64)

    # Min-max normalise to [0, 1) and quantise to 0-255.
    images = images - np.min(images)
    images = images / (np.max(images) + 1e-4)
    images = (images * 255).astype(np.uint8)

    return images  # (128, 128, 128), values 0-255


def load_sample(row, has_mask=True):
    """Load the image volume and, optionally, the one-hot mask for one study.

    Args:
        row: Record exposing ``image_folder`` and (when ``has_mask``) ``mask_file``.
        has_mask: Whether to also load and return the segmentation mask.

    Returns:
        ``image`` of shape (3, 128, 128, 128) with values 0-255 (the grey volume
        repeated on three channels). When ``has_mask`` is True, the tuple
        ``(image, mask)`` where ``mask`` has shape (7, 128, 128, 128), one channel
        per label 1..7, resized by ``R`` from a 0/255 encoding.
    """
    image = load_dicom_line_par(row.image_folder)
    # image: (128, 128, 128), 0-255
    if image.ndim < 4:
        # Replicate the single channel three times to match the encoder's in_chans=3.
        image = np.expand_dims(image, 0).repeat(3, 0)
        # image: (3, 128, 128, 128)

    if not has_mask:
        return image

    mask_org = nib.load(row.mask_file).get_fdata()
    # Re-orient the label volume so it lines up with the image volume.
    mask_org = mask_org.transpose(1, 0, 2)[::-1, :, ::-1]

    # Size the one-hot buffer from the *re-oriented* volume so non-square slices work too,
    # and allocate it as uint8 directly (a float64 buffer needs 8x the memory).
    mask = np.zeros((7,) + mask_org.shape, dtype=np.uint8)
    for cid in range(7):
        # Channel `cid` marks the voxels labelled cid + 1.
        mask[cid] = (mask_org == (cid + 1))
    # Encode as 0 / 255; the dataset thresholds at 127 after augmentation.
    mask *= 255
    mask = R(mask).numpy()
    # mask: (7, 128, 128, 128)

    return image, mask

    
# Pre-compute every segmented study once and cache it under ./data/<StudyInstanceUID>/
# so the Dataset only has to read two .npy files per sample.
for index in range(len(df_seg)):
    row = df_seg.iloc[index]
    t_path = os.path.join('data', row['StudyInstanceUID'])
    os.makedirs(t_path, exist_ok=True)
    t_path1 = os.path.join(t_path, 'image.npy')
    t_path2 = os.path.join(t_path, 'mask.npy')
    # Require BOTH files: if a previous run was interrupted after writing image.npy,
    # checking the image alone would leave this study without a mask forever.
    if not (os.path.exists(t_path1) and os.path.exists(t_path2)):
        image, mask = load_sample(row, has_mask=True)
        np.save(t_path1, image)
        np.save(t_path2, mask)

In [None]:
class SEGDataset(Dataset):
    """Dataset over the cached ``data/<StudyInstanceUID>/{image,mask}.npy`` files.

    Each item is ``(image, mask)`` as float tensors: ``image`` is the stored
    0-255 volume divided by 255, ``mask`` is the stored 0/255 mask binarised
    at 127 after the transform has been applied.
    """

    def __init__(self, df, mode, transform):
        self.transform = transform
        self.mode = mode
        self.df = df.reset_index()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        uid = row['StudyInstanceUID']
        sample_dir = os.path.join('data', uid)

        # Stored as image (3, 128, 128, 128) in 0-255 and mask (7, 128, 128, 128) in 0/255.
        image = np.load(os.path.join(sample_dir, 'image.npy')).astype(np.float32)
        mask = np.load(os.path.join(sample_dir, 'mask.npy')).astype(np.float32)

        # Studies in `revert_list` get their mask flipped along the last axis.
        if uid in revert_list:
            mask = mask[..., ::-1]

        res = self.transform({'image': image, 'mask': mask})
        image = res['image'] / 255.
        # Re-binarise: the transform may have produced values between 0 and 255.
        mask = (res['mask'] > 127).astype(np.float32)

        return torch.tensor(image).float(), torch.tensor(mask).float()

In [None]:
# Preview dataset: all segmented studies, passed through the training augmentations.
rcParams['figure.figsize'] = 20,8
df_show = df_seg
dataset_show = SEGDataset(df_show, 'train', transform=transforms_train)

In [None]:
# Show 2 rows x 4 augmented samples: slice 60 along the last axis, with the seven
# mask channels folded into three colour channels and blended over the image.
for i in range(2):
    f, axarr = plt.subplots(1, 4)
    for p, ax in enumerate(axarr):
        idx = i * 4 + p
        img, mask = dataset_show[idx]
        img = img[..., 60]    # (3, 128, 128)
        mask = mask[..., 60]  # (7, 128, 128)
        # Channels 0/3/6, 1/4 and 2/5 share one colour channel each.
        mask = torch.stack([
            mask[0] + mask[3] + mask[6],
            mask[1] + mask[4],
            mask[2] + mask[5],
        ])
        img = img * 0.7 + mask * 0.3
        # Channels-first -> channels-last for imshow.
        ax.imshow(img.permute(1, 2, 0).squeeze())

# Model

In [None]:
class TimmSegModel(nn.Module):
    """timm feature encoder + smp U-Net decoder + 7-channel segmentation head.

    The model is built from 2D layers; ``convert_3d`` is applied afterwards to
    turn it into the 3D network used for training.

    Args:
        backbone: timm model name passed to ``timm.create_model``.
        segtype: Decoder type; only ``'unet'`` is supported.
        pretrained: Whether to load timm's pretrained encoder weights.

    Raises:
        ValueError: If ``segtype`` is not ``'unet'``.
    """

    def __init__(self, backbone, segtype='unet', pretrained=False):
        super(TimmSegModel, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            drop_rate=0.,
            drop_path_rate=0.,
            pretrained=pretrained
        )

        # Probe the encoder once to read the channel count of every feature level.
        # Run it in eval mode without gradients so the probe on random noise does not
        # update the (possibly pretrained) BatchNorm running statistics.
        # Feature shapes for resnet18d on a 64x64 input:
        #   [1, 64, 32, 32], [1, 64, 16, 16], [1, 128, 8, 8], [1, 256, 4, 4], [1, 512, 2, 2]
        was_training = self.encoder.training
        self.encoder.eval()
        with torch.no_grad():
            g = self.encoder(torch.rand(1, 3, 64, 64))
        self.encoder.train(was_training)

        encoder_channels = [1] + [_.shape[1] for _ in g]
        # encoder_channels: [1, 64, 64, 128, 256, 512]
        decoder_channels = [256, 128, 64, 32, 16]

        # Stored on the instance because `forward` needs it too; as a local of
        # `__init__` it was undefined there and `forward` raised NameError.
        self.n_blocks = 4
        n_blocks = self.n_blocks

        if segtype == 'unet':
            self.decoder = smp.unet.decoder.UnetDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                # encoder_channels: [1, 64, 64, 128, 256]
                decoder_channels=decoder_channels[:n_blocks],
                # decoder_channels: [256, 128, 64, 32]
                n_blocks=n_blocks,
            )
        else:
            # Fail at construction time rather than with a missing `decoder` in forward().
            raise ValueError(f"Unsupported segtype {segtype!r}; only 'unet' is implemented.")

        # Last decoder width (32) -> 7 output channels.
        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks-1], 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, x):
        """Return per-voxel logits with 7 channels at the input resolution."""
        # The leading 0 fills the slot described by encoder_channels[0];
        # presumably the decoder ignores it — verify against the installed smp version.
        global_features = [0] + self.encoder(x)[:self.n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features
    
    
# batch size = 2
# ===============================================================================================
# Layer (type:depth-idx)                        Output Shape              Param #
# ===============================================================================================
# TimmSegModel                                  [2, 7, 128, 128]          --
# ├─FeatureListNet: 1-1                         [2, 64, 64, 64]           --
# │    └─Sequential: 2-1                        [2, 64, 64, 64]           --
# │    │    └─Conv2d: 3-1                       [2, 32, 64, 64]           864
# │    │    └─BatchNorm2d: 3-2                  [2, 32, 64, 64]           64
# │    │    └─ReLU: 3-3                         [2, 32, 64, 64]           --
# │    │    └─Conv2d: 3-4                       [2, 32, 64, 64]           9,216
# │    │    └─BatchNorm2d: 3-5                  [2, 32, 64, 64]           64
# │    │    └─ReLU: 3-6                         [2, 32, 64, 64]           --
# │    │    └─Conv2d: 3-7                       [2, 64, 64, 64]           18,432
# │    └─BatchNorm2d: 2-2                       [2, 64, 64, 64]           128
# │    └─ReLU: 2-3                              [2, 64, 64, 64]           --
# │    └─MaxPool2d: 2-4                         [2, 64, 32, 32]           --
# │    └─Sequential: 2-5                        [2, 64, 32, 32]           --
# │    │    └─BasicBlock: 3-8                   [2, 64, 32, 32]           73,984
# │    │    └─BasicBlock: 3-9                   [2, 64, 32, 32]           73,984
# │    └─Sequential: 2-6                        [2, 128, 16, 16]          --
# │    │    └─BasicBlock: 3-10                  [2, 128, 16, 16]          230,144
# │    │    └─BasicBlock: 3-11                  [2, 128, 16, 16]          295,424
# │    └─Sequential: 2-7                        [2, 256, 8, 8]            --
# │    │    └─BasicBlock: 3-12                  [2, 256, 8, 8]            919,040
# │    │    └─BasicBlock: 3-13                  [2, 256, 8, 8]            1,180,672
# │    └─Sequential: 2-8                        [2, 512, 4, 4]            --
# │    │    └─BasicBlock: 3-14                  [2, 512, 4, 4]            3,673,088
# │    │    └─BasicBlock: 3-15                  [2, 512, 4, 4]            4,720,640
# ├─UnetDecoder: 1-2                            [2, 32, 128, 128]         --
# │    └─Identity: 2-9                          [2, 256, 8, 8]            --
# │    └─ModuleList: 2-10                       --                        --
# │    │    └─DecoderBlock: 3-16                [2, 256, 16, 16]          1,475,584
# │    │    └─DecoderBlock: 3-17                [2, 128, 32, 32]          516,608
# │    │    └─DecoderBlock: 3-18                [2, 64, 64, 64]           147,712
# │    │    └─DecoderBlock: 3-19                [2, 32, 128, 128]         27,776
# ├─Conv2d: 1-3                                 [2, 7, 128, 128]          2,023
# ===============================================================================================
# Total params: 13,365,447
# Trainable params: 13,365,447
# Non-trainable params: 0
# Total mult-adds (G): 5.33
# ===============================================================================================
# Input size (MB): 0.39
# Forward/backward pass size (MB): 99.09
# Params size (MB): 53.46
# Estimated Total Size (MB): 152.95
# ===============================================================================================

In [None]:
from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


def _inflate_conv_params(conv2d, conv3d):
    """Copy a 2D convolution's parameters into its 3D replacement.

    The 2D kernel is repeated ``kernel_size[0]`` times along a new last axis
    (no rescaling), and the bias, when present, is copied as-is.
    """
    depth = conv2d.kernel_size[0]
    conv3d.weight = torch.nn.Parameter(conv2d.weight.unsqueeze(-1).repeat(1, 1, 1, 1, depth))
    if conv2d.bias is not None:
        # Without this the 3D layer keeps its freshly initialised bias and the
        # source layer's bias is silently lost.
        conv3d.bias = torch.nn.Parameter(conv2d.bias.detach().clone())


def convert_3d(module):
    """Recursively replace 2D layers in ``module`` with their 3D counterparts.

    Handles BatchNorm2d, Conv2dSame, Conv2d, MaxPool2d and AvgPool2d; every
    other module is kept and only its children are converted. Convolutions are
    assumed to be square: ``kernel_size[0]``, ``stride[0]``, ``padding[0]`` and
    ``dilation[0]`` are reused for all three spatial dimensions.

    Args:
        module: Root ``nn.Module`` to convert.

    Returns:
        The converted module: a new layer for the types listed above, otherwise
        ``module`` itself with its children replaced.
    """
    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        # Carry over the affine parameters and running statistics unchanged.
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig

    elif isinstance(module, Conv2dSame):
        # Must be tested before the generic Conv2d branch so the "same" padding variant is kept.
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        _inflate_conv_params(module, module_output)

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        _inflate_conv_params(module, module_output)

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
            # Preserve the averaging semantics of the source layer instead of
            # falling back to the AvgPool3d defaults.
            count_include_pad=module.count_include_pad,
            divisor_override=module.divisor_override,
        )

    # Convert the children and attach them to the (possibly new) parent.
    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output


# Smoke test: build the 2D model, convert it to 3D and push one random 128^3 volume through it.
# `m` is reused further down as the parameter source for the learning-rate schedule preview.
m = TimmSegModel('resnet18d')
m = convert_3d(m)
m(torch.rand(1, 3, 128,128,128)).shape


# batch size = 2
# ===============================================================================================
# Layer (type:depth-idx)                        Output Shape              Param #
# ===============================================================================================
# TimmSegModel                                  [2, 7, 128, 128, 128]     --
# ├─FeatureListNet: 1-1                         [2, 64, 64, 64, 64]       --
# │    └─Sequential: 2-1                        [2, 64, 64, 64, 64]       --
# │    │    └─Conv3d: 3-1                       [2, 32, 64, 64, 64]       2,592
# │    │    └─BatchNorm3d: 3-2                  [2, 32, 64, 64, 64]       64
# │    │    └─ReLU: 3-3                         [2, 32, 64, 64, 64]       --
# │    │    └─Conv3d: 3-4                       [2, 32, 64, 64, 64]       27,648
# │    │    └─BatchNorm3d: 3-5                  [2, 32, 64, 64, 64]       64
# │    │    └─ReLU: 3-6                         [2, 32, 64, 64, 64]       --
# │    │    └─Conv3d: 3-7                       [2, 64, 64, 64, 64]       55,296
# │    └─BatchNorm3d: 2-2                       [2, 64, 64, 64, 64]       128
# │    └─ReLU: 2-3                              [2, 64, 64, 64, 64]       --
# │    └─MaxPool3d: 2-4                         [2, 64, 32, 32, 32]       --
# │    └─Sequential: 2-5                        [2, 64, 32, 32, 32]       --
# │    │    └─BasicBlock: 3-8                   [2, 64, 32, 32, 32]       221,440
# │    │    └─BasicBlock: 3-9                   [2, 64, 32, 32, 32]       221,440
# │    └─Sequential: 2-6                        [2, 128, 16, 16, 16]      --
# │    │    └─BasicBlock: 3-10                  [2, 128, 16, 16, 16]      672,512
# │    │    └─BasicBlock: 3-11                  [2, 128, 16, 16, 16]      885,248
# │    └─Sequential: 2-7                        [2, 256, 8, 8, 8]         --
# │    │    └─BasicBlock: 3-12                  [2, 256, 8, 8, 8]         2,688,512
# │    │    └─BasicBlock: 3-13                  [2, 256, 8, 8, 8]         3,539,968
# │    └─Sequential: 2-8                        [2, 512, 4, 4, 4]         --
# │    │    └─BasicBlock: 3-14                  [2, 512, 4, 4, 4]         10,750,976
# │    │    └─BasicBlock: 3-15                  [2, 512, 4, 4, 4]         14,157,824
# ├─UnetDecoder: 1-2                            [2, 32, 128, 128, 128]    --
# │    └─Identity: 2-9                          [2, 256, 8, 8, 8]         --
# │    └─ModuleList: 2-10                       --                        --
# │    │    └─DecoderBlock: 3-16                [2, 256, 16, 16, 16]      4,424,704
# │    │    └─DecoderBlock: 3-17                [2, 128, 32, 32, 32]      1,548,800
# │    │    └─DecoderBlock: 3-18                [2, 64, 64, 64, 64]       442,624
# │    │    └─DecoderBlock: 3-19                [2, 32, 128, 128, 128]    83,072
# ├─Conv3d: 1-3                                 [2, 7, 128, 128, 128]     6,055
# ===============================================================================================
# Total params: 39,728,967
# Trainable params: 39,728,967
# Non-trainable params: 0
# Total mult-adds (G): 839.07
# ===============================================================================================
# Input size (MB): 50.33
# Forward/backward pass size (MB): 7391.41
# Params size (MB): 158.92
# Estimated Total Size (MB): 7600.66
# ===============================================================================================

# Loss & Metric

In [None]:
from typing import Any, Dict, Optional


def binary_dice_iou_score(
    y_pred: torch.Tensor,
    y_true: torch.Tensor,
    threshold: Optional[float] = None,
    nan_score_on_empty=False,
    eps: float = 1e-7,
) -> float:
    """Dice score ``2 * |P * T| / (|P| + |T| + eps)`` for a single class.

    Args:
        y_pred: Predictions; binarised with ``> threshold`` when a threshold is given.
        y_true: Ground-truth tensor of the same shape.
        threshold: Optional cut-off applied to ``y_pred``.
        nan_score_on_empty: Return NaN instead of 0/1 when ``y_true`` is empty.
        eps: Stabiliser added to the denominator.

    Returns:
        The Dice score. When ``y_true`` has no positives: NaN if
        ``nan_score_on_empty``, otherwise 1.0 for an empty prediction and 0.0
        for a non-empty one.
    """
    if threshold is not None:
        y_pred = (y_pred > threshold).to(y_true.dtype)

    pred_total = torch.sum(y_pred)
    true_total = torch.sum(y_true)

    # Empty ground truth is handled separately from the ratio below.
    if not (true_total > 0):
        if nan_score_on_empty:
            return np.nan
        return float(not (pred_total > 0))

    overlap = torch.sum(y_pred * y_true).item()
    return (2.0 * overlap) / ((pred_total + true_total).item() + eps)


def multilabel_dice_iou_score(
    y_true: torch.Tensor,
    y_pred: torch.Tensor,
    threshold=None,
    eps=1e-7,
    nan_score_on_empty=False,
):
    """Per-class Dice scores for one multi-label sample.

    Args:
        y_true: Ground truth, indexed by class along dimension 0.
        y_pred: Predictions, indexed by class along dimension 0.
        threshold, eps, nan_score_on_empty: Forwarded to ``binary_dice_iou_score``.

    Returns:
        A list with one score per class of ``y_pred``.
    """
    return [
        binary_dice_iou_score(
            y_pred=y_pred[class_index],
            y_true=y_true[class_index],
            threshold=threshold,
            nan_score_on_empty=nan_score_on_empty,
            eps=eps,
        )
        for class_index in range(y_pred.size(0))
    ]


def dice_loss(input, target):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        input: Raw logits; a sigmoid is applied here.
        target: Ground-truth tensor with the same number of elements.

    Returns:
        Scalar tensor ``1 - (2 * sum(p * t) + 1) / (sum(p) + sum(t) + 1)``.
    """
    input = torch.sigmoid(input)
    smooth = 1.0
    # reshape (not view) so non-contiguous tensors are accepted as well.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return 1 - ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))


def bce_dice(input, target, loss_weights=[1, 1]):
    """Weighted average of BCE-with-logits and soft Dice loss.

    Args:
        input: Raw logits.
        target: Ground-truth tensor of the same shape.
        loss_weights: ``[bce_weight, dice_weight]``.

    Returns:
        ``(w_bce * BCE + w_dice * Dice) / (w_bce + w_dice)`` as a scalar tensor.
    """
    bce_weight, dice_weight = loss_weights[0], loss_weights[1]
    bce_term = bce_weight * nn.BCEWithLogitsLoss()(input, target)
    dice_term = dice_weight * dice_loss(input, target)
    total = bce_term + dice_term
    return total / sum(loss_weights)

# Loss used by both the training and the validation loop.
criterion = bce_dice

# Train & Valid func

In [None]:
def mixup(input, truth, clip=[0, 1]):
    """Blend every sample of a batch with a randomly chosen partner.

    Args:
        input: Batch of inputs, batch dimension first.
        truth: Matching batch of targets.
        clip: ``[low, high]`` range the mixing coefficient is drawn from.

    Returns:
        ``(mixed_input, truth, partner_truth, lam)`` where
        ``mixed_input = lam * input + (1 - lam) * input[perm]`` and
        ``partner_truth = truth[perm]``. The targets themselves are not
        blended; the caller mixes the two losses with ``lam``.
    """
    perm = torch.randperm(input.size(0))
    partner_input, partner_truth = input[perm], truth[perm]

    low, high = clip[0], clip[1]
    lam = np.random.uniform(low, high)
    blended = input * lam + partner_input * (1 - lam)
    return blended, truth, partner_truth, lam


def train_func(model, loader_train, optimizer, scaler=None):
    """Train ``model`` for one epoch.

    Args:
        model: Network to train; inputs are moved to the GPU with ``.cuda()``.
        loader_train: Iterable yielding ``(images, gt_masks)`` batches.
        optimizer: Optimizer stepped once per batch.
        scaler: Optional ``GradScaler``. When given, the forward pass runs under
            autocast and the backward pass and step go through the scaler. When
            ``None``, plain full-precision training is used.

    Returns:
        Mean training loss over all batches of the epoch.
    """
    model.train()
    # The signature allows scaler=None, so mixed precision must be optional;
    # using the scaler unconditionally raised AttributeError on None.
    use_amp = scaler is not None
    train_loss = []
    bar = tqdm(loader_train)
    for images, gt_masks in bar:
        optimizer.zero_grad()
        images = images.cuda()
        gt_masks = gt_masks.cuda()

        # Apply mixup to roughly 10% of the batches.
        do_mixup = False
        if random.random() < 0.1:
            do_mixup = True
            images, gt_masks, gt_masks_sfl, lam = mixup(images, gt_masks)

        with amp.autocast(enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, gt_masks)
            if do_mixup:
                # Blend the losses against both target sets with the mixing coefficient.
                loss2 = criterion(logits, gt_masks_sfl)
                loss = loss * lam  + loss2 * (1 - lam)

        train_loss.append(loss.item())
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Progress bar shows the loss smoothed over the last 30 batches.
        bar.set_description(f'smth:{np.mean(train_loss[-30:]):.4f}')

    return np.mean(train_loss)


def valid_func(model, loader_valid):
    """Evaluate ``model`` on the validation loader.

    The per-class Dice score of every sample is computed at each threshold in
    ``ths``; the scores are averaged per threshold and the best threshold is
    reported.

    Args:
        model: Network to evaluate; inputs are moved to the GPU with ``.cuda()``.
        loader_valid: Iterable yielding ``(images, gt_masks)`` batches.

    Returns:
        ``(mean validation loss, best mean Dice over the thresholds)``.
    """
    model.eval()
    valid_loss = []
    ths = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    # One independent list per threshold. `[[]] * 7` created seven references to
    # the SAME list, so every threshold reported the same pooled score.
    batch_metrics = [[] for _ in ths]
    bar = tqdm(loader_valid)
    with torch.no_grad():
        for images, gt_masks in bar:
            images = images.cuda()
            gt_masks = gt_masks.cuda()

            logits = model(images)
            loss = criterion(logits, gt_masks)
            valid_loss.append(loss.item())

            # Compute probabilities and move them to the CPU once per batch,
            # not once per threshold and sample.
            probs = logits.sigmoid().cpu()
            targets = gt_masks.cpu()
            for thi, th in enumerate(ths):
                for i in range(probs.shape[0]):
                    tmp = multilabel_dice_iou_score(
                        y_pred=probs[i],
                        y_true=targets[i],
                        # Use the swept threshold; it was hard-coded to 0.5, which
                        # turned the sweep into a no-op.
                        threshold=th,
                    )
                    batch_metrics[thi].extend(tmp)
            bar.set_description(f'smth:{np.mean(valid_loss[-30:]):.4f}')

    metrics = [np.mean(this_metric) for this_metric in batch_metrics]
    print('best th:', ths[np.argmax(metrics)], 'best dc:', np.max(metrics))

    return np.mean(valid_loss), np.max(metrics)


In [None]:
# Preview of the learning-rate curve: a throwaway optimizer on the smoke-test model `m`
# is stepped through 1000 epochs of cosine annealing and the per-epoch LR is plotted.
# NOTE(review): `run` uses CosineAnnealingWarmRestarts with the same period, not
# CosineAnnealingLR — presumably the curves match over the first 1000 epochs; confirm.
rcParams['figure.figsize'] = 20, 2
optimizer = optim.AdamW(m.parameters(), lr=3e-3)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
lrs = []
for epoch in range(1, 1000+1):
    # Set the schedule position explicitly (0-based epoch index).
    scheduler_cosine.step(epoch-1)
    lrs.append(optimizer.param_groups[0]["lr"])
plt.plot(range(len(lrs)), lrs)

# Training

In [None]:
def run(fold):
    """Train and validate the 3D segmentation model on one cross-validation fold.

    Rows of ``df_seg`` whose ``fold`` column equals ``fold`` form the validation
    set; all other rows are used for training. After every epoch a line is
    appended to the log file, the weights with the best validation metric so far
    are written to ``..._fold{fold}_best.pth`` and, unless ``DEBUG`` is set, a
    full checkpoint is written to ``..._fold{fold}_last.pth``.

    Args:
        fold: Index of the fold to hold out for validation.
    """
    log_file = os.path.join('./logs', f'{kernel_type}.txt')
    model_file = os.path.join('./models', f'{kernel_type}_fold{fold}_best.pth')

    train_ = df_seg[df_seg['fold'] != fold].reset_index(drop=True)
    valid_ = df_seg[df_seg['fold'] == fold].reset_index(drop=True)
    dataset_train = SEGDataset(train_, 'train', transform=transforms_train)
    dataset_valid = SEGDataset(valid_, 'valid', transform=transforms_valid)
    # NOTE(review): the random MONAI transforms run inside torch DataLoader worker
    # processes; presumably each worker starts from a copy of the same transform
    # random state — verify that augmentations differ across workers and epochs.
    loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=4, shuffle=True, num_workers=4)
    loader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=4, shuffle=False, num_workers=4)

    # Build the pretrained 2D network, then convert every layer to 3D.
    model = TimmSegModel('resnet18d', pretrained=True)
    model = convert_3d(model)
    model = model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=3e-3)
    scaler = torch.cuda.amp.GradScaler()
    from_epoch = 0  # not used below
    metric_best = 0.
    loss_min = np.inf  # not used below

    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 1000)

    print(len(dataset_train), len(dataset_valid))

    for epoch in range(1, 1000+1):
        # Set the schedule position explicitly (0-based epoch index) before training the epoch.
        scheduler_cosine.step(epoch-1)

        print(time.ctime(), 'Epoch:', epoch)

        train_loss = train_func(model, loader_train, optimizer, scaler)
        valid_loss, metric = valid_func(model, loader_valid)

        content = time.ctime() + ' ' + f'Fold {fold}, Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {train_loss:.5f}, valid loss: {valid_loss:.5f}, metric: {(metric):.6f}.'
        print(content)
        with open(log_file, 'a') as appender:
            appender.write(content + '\n')

        # Keep the weights with the best validation metric seen so far.
        if metric > metric_best:
            print(f'metric_best ({metric_best:.6f} --> {metric:.6f}). Saving model ...')
            torch.save(model.state_dict(), model_file)
            metric_best = metric

        # Save Last
        if not DEBUG:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict() if scaler else None,
                    'score_best': metric_best,
                },
                model_file.replace('_best', '_last')
            )

    # Release the model's memory before the next fold starts.
    del model
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
# Train the five cross-validation folds one after another.
for fold_id in range(5):
    run(fold_id)