### Classification notebook for Kaggle RSNA competition

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os
import gc
import random
from glob import glob
from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold,StratifiedGroupKFold
import warnings
import pickle
import json
import re
import time
import sys
from requests import get
import multiprocessing
import joblib
import torch
from torch.utils.data import Dataset, DataLoader
import transformers
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import timm
from sklearn.preprocessing import minmax_scale
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2,torchvision
from ipyexperiments.ipyexperiments import IPyExperimentsPytorch
from timm.optim.optim_factory import create_optimizer_v2
from timm import utils
from fastprogress.fastprogress import format_time
from fastai.vision.all import *
from torch.utils.data import WeightedRandomSampler
from sklearn.metrics import roc_auc_score

class CFG:
    """Global hyper-parameters shared by every cell of this notebook."""

    # --- reproducibility / cross-validation ---
    seed = 46
    n_splits = 4
    debug = False  # True -> train on a 1% sample of the metadata

    # --- input size (height, width) handed to A.Resize ---
    SZ = (1536, 960)

    # --- model and optimisation ---
    MODEL = 'tf_efficientnetv2_s'
    BS = 8
    EP = 12
    LR = 1e-04
    WD = 1e-06
    max_norm = 10  # gradient-norm clipping threshold (clipping is off by default)
# Seed Python's and NumPy's global RNGs (torch is presumably seeded by the
# set_seed(CFG.seed) call in a later cell — confirm).
random.seed(CFG.seed)
# NOTE(review): PYTHONHASHSEED is only read at interpreter start-up, so this
# affects child processes started afterwards, not the running kernel.
os.environ["PYTHONHASHSEED"] = str(CFG.seed)
np.random.seed(CFG.seed)
plt.rcParams["font.size"] = 13
# Silence every warning globally (this also hides deprecation/numeric warnings).
warnings.filterwarnings('ignore')

In [None]:
# List every architecture name timm knows (useful when choosing CFG.MODEL).
timm.list_models()

In [None]:
# set_seed is not defined in this notebook; presumably it comes from the
# `from fastai.vision.all import *` star import — confirm.
set_seed(CFG.seed)

In [None]:
root_dir = '///mnt/c/Personal/Competitions/Kaggle/rsna'
# PNGs read by the dataset, named "<patient_id>_<image_id>.png"
image_dir = f'{root_dir}/data/8bit'

# Derive the CSV directory from root_dir instead of repeating the literal path,
# so relocating the data only requires editing root_dir.
DIR = f'{root_dir}/data/'
submit = pd.read_csv(os.path.join(DIR,'sample_submission.csv'))
train = pd.read_csv(os.path.join(DIR,'Train.csv'))
test_df = pd.read_csv(os.path.join(DIR,'Test.csv'))

if CFG.debug:
    # quick iteration on a 1% sample of the training metadata
    train = train.sample(frac=0.01).reset_index(drop=True)

# All artefacts of this run go under runs/<VERSION>/; KERNEL_TYPE encodes the main
# hyper-parameters and is used as the checkpoint file-name prefix.
VERSION = "self_distill_tf_efficientnetv2_s"
MODEL_FOLDER = Path(f"{root_dir}/runs/{VERSION}/")
os.makedirs(MODEL_FOLDER,exist_ok=True)
KERNEL_TYPE = f"{CFG.MODEL}_{CFG.SZ[0]}_{CFG.SZ[1]}_bs{CFG.BS}_ep{CFG.EP}_lr{str(CFG.LR).replace('-','')}_wd{str(CFG.WD).replace('-','')}"

print(MODEL_FOLDER)
print(KERNEL_TYPE)

In [None]:
# Integer-encode categorical metadata columns (unknown categories become NaN).
train['difficult_negative_case'] = train['difficult_negative_case'].astype(int)
train['laterality_enc'] = train['laterality'].map({'L': 0, 'R': 1})
train['view_enc'] = train['view'].map(
    {'CC': 0, 'MLO': 1, 'ML': 2, 'LM': 3, 'AT': 4, 'LMO': 5}
)

In [None]:
# Missing BIRADS -> 3; missing density -> extra bucket "E"; density A..E -> 0..4.
train['BIRADS'] = (
    train['BIRADS']
    .fillna(3)
    .astype(int)
)
train['density'] = (
    train['density']
    .fillna("E")
    .map({'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4})
)

### Get kfolds

In [None]:
# Patient-grouped, cancer-stratified split: all images of a patient land in the
# same fold. Each row's `fold` is the split in which it is the validation row.
mskf = StratifiedGroupKFold(n_splits=CFG.n_splits, shuffle=True, random_state=1)
train['fold'] = 0
fold_ids = [
    val_rows
    for _, val_rows in mskf.split(train, train['cancer'].values, train['patient_id'].values)
]

for fld, valIx in enumerate(fold_ids):
    train.loc[valIx, 'fold'] = fld

#### Data loader

In [None]:
def crop_blackarea(d, threshold=20, border=5):
    """Load the PNG for row ``d`` and crop it to its largest bright connected region.

    The image ``{patient_id}_{image_id}.png`` is read from ``image_dir``, ``border``
    pixels are trimmed from every edge, and the bounding box of the largest
    8-connected region of pixels brighter than ``threshold`` (first channel) is
    returned; that largest region is taken to be the breast.

    Args:
        d: row object exposing ``patient_id`` and ``image_id`` (e.g. ``df.iloc[i]``).
        threshold: foreground cut-off; pixels strictly greater count as foreground.
        border: pixels trimmed from each edge before searching; 0 disables trimming.

    Returns:
        The cropped image in ``cv2.imread``'s default 3-channel layout. If there is
        no foreground pixel at all, the border-trimmed image is returned uncropped.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read the file.
    """
    path = os.path.join(image_dir, f'{d.patient_id}_{d.image_id}.png')
    X = cv2.imread(path)
    if X is None:  # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError(path)
    if border > 0:  # X[0:-0] would be empty, so only slice for a positive border
        X = X[border:-border, border:-border]

    # regions of non-empty pixels
    mask = (X[:, :, 0] > threshold).astype(np.uint8)
    _, _, stats, _ = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)

    # stats has one row per label (row 0 is the background label) and 5 columns:
    # left, top, width, height, area_size
    if len(stats) < 2:
        return X  # background only -> nothing to crop to

    # the largest foreground area is the breast
    idx = stats[1:, 4].argmax() + 1
    x1, y1, w, h = stats[idx][:4]
    return X[y1:y1 + h, x1:x1 + w]

In [None]:
def read_data(d):
    """Load the full (uncropped) PNG ``{patient_id}_{image_id}.png`` for row ``d``.

    Args:
        d: row object exposing ``patient_id`` and ``image_id``.

    Returns:
        The image as returned by ``cv2.imread`` (default 3-channel mode).

    Raises:
        FileNotFoundError: if the file cannot be read.
    """
    path = os.path.join(image_dir, f'{d.patient_id}_{d.image_id}.png')
    image = cv2.imread(path)
    if image is None:  # cv2.imread returns None on failure instead of raising
        raise FileNotFoundError(path)
    return image

class RsnaDataset(Dataset):
    """Map-style dataset over a metadata DataFrame, one mammogram per row.

    Args:
        df: frame with ``patient_id`` and ``image_id`` columns, plus ``cancer``
            unless ``mode == 'test'``. Rows are accessed positionally (``iloc``).
        augs: optional albumentations pipeline, applied as ``augs(image=...)['image']``.
        mode: ``'test'`` yields ``(image, patient_id)``; any other value yields
            ``(image, cancer)`` with ``cancer`` as a float tensor.
        crop: crop to the largest bright region (``crop_blackarea``) instead of
            loading the whole image (``read_data``).
    """

    def __init__(self, df, augs=None, mode='train', crop=True):
        self.length = len(df)
        self.df = df
        self.augs = augs
        self.mode = mode
        self.crop = crop

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        d = self.df.iloc[index]
        image = crop_blackarea(d) if self.crop else read_data(d)

        if self.augs is not None:
            # Pass the image in its native dtype: albumentations treats float32
            # images as lying in [0, 1] (e.g. RandomBrightnessContrast clips float
            # inputs to that range), which would destroy a 0-255 float image.
            # The pipelines in this notebook end with A.Normalize, which produces
            # float32 output anyway.
            image = self.augs(image=image)['image']
        else:
            image = image.astype(np.float32)

        if self.mode == 'test':
            # rows in test mode may carry no `cancer` label, so never read it here
            return image, d.patient_id

        return image, torch.tensor(d.cancer).float()

In [None]:
def worker_init_fn(worker_id):
    """
    Handles PyTorch x Numpy seeding issues.

    Re-seeds NumPy's global RNG inside a DataLoader worker from the first word of
    its current state offset by ``worker_id``, so workers draw different streams.

    Args:
        worker_id (int): Id of the worker.
    """
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts [0, 2**32 - 1]; the state word is itself a
    # uint32, so adding worker_id could overflow that range without the modulo.
    np.random.seed((base_seed + worker_id) % 2**32)

### Augmentations

In [None]:
# Training augmentation. Both pipelines end with Resize -> Normalize -> ToTensorV2,
# so samples come out as CHW float tensors of spatial size CFG.SZ.
# NOTE(review): A.Normalize(mean=0, std=1) relies on albumentations' default
# max_pixel_value (255), i.e. it only rescales pixel values by 1/255 — confirm.
TRAIN_AUG = A.Compose([
    A.ShiftScaleRotate(shift_limit=0.01, scale_limit=0.03, rotate_limit=15, p=0.5, border_mode=0),
    # at most one flip, applied half of the time (inner p values act as selection weights)
    A.OneOf([A.HorizontalFlip(p = 0.5),
    A.VerticalFlip(p = 0.5)],p=0.5),    
    A.RandomBrightnessContrast(p=0.5),
    # one geometric distortion, 20% of the time
    A.OneOf(
        transforms=[
           A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=0.1),
            A.PiecewiseAffine(p=0.3)],p=0.2),
    A.Affine(translate_percent=0.0, rotate=0, shear=0, scale=[0.8,1.2], p= 0.5),
    # NOTE(review): CoarseDropout and Cutout both punch 8 small (5px) holes;
    # Cutout is deprecated/removed in newer albumentations releases.
    A.CoarseDropout(max_holes=8, max_height= 5, max_width= 5, p=0.5),
    A.Cutout(num_holes = 8,max_h_size = 5, max_w_size = 5, p=0.7),
    A.Resize(CFG.SZ[0],CFG.SZ[1]),
    A.Normalize(mean=0,std=1),
    ToTensorV2(),
])

# Validation / inference: deterministic resize + scaling only.
VALID_AUG = A.Compose([
    A.Resize(CFG.SZ[0],CFG.SZ[1]),
    A.Normalize(mean=0,std=1),
    ToTensorV2(),
])

### Visualization

In [None]:
# Preview one batch of 8 cropped, augmented training images as a grid.
# NOTE(review): `grid` is built but never displayed in this cell.
dataset_show = RsnaDataset(train, augs=TRAIN_AUG, mode='train')
loader_show = torch.utils.data.DataLoader(dataset_show, batch_size=8,shuffle=False)
img,target  = next(iter(loader_show))

grid = torchvision.utils.make_grid(img, normalize=True, padding=2)
# make_grid returns CHW; image display wants HWC
grid = grid.permute(1, 2, 0)

In [None]:
# Same preview without cropping (crop=False loads the whole PNG); each tile is
# titled with its cancer label.
dataset_show = RsnaDataset(train, augs=TRAIN_AUG, mode='train',crop=False)
loader_show = torch.utils.data.DataLoader(dataset_show, batch_size=8,shuffle=False)
img,target  = next(iter(loader_show))

grid = torchvision.utils.make_grid(img, normalize=True, padding=2)
# make_grid returns CHW; image display wants HWC
grid = grid.permute(1, 2, 0)
show_image(grid, figsize=(15,8),title=[x for x in target.numpy()]);

### Model

In [None]:
import torch
from torch import nn
from torch.nn import functional as F


class AdaptiveConcatPool1d(nn.Module):
    """Concatenate global average- and max-pooling over the last axis.

    (N, C, L) -> (N, 2*C, 1): the averaged channels come first, then the maxima.
    """

    def forward(self, x):
        pooled_avg = F.adaptive_avg_pool1d(x, 1)
        pooled_max = F.adaptive_max_pool1d(x, 1)
        return torch.cat([pooled_avg, pooled_max], dim=1)

    def feat_mult(self):
        """Factor by which this pooling multiplies the channel count."""
        return 2

class SwinPooler(nn.Module):
    """Average features over dim 1, e.g. tokens (N, T, C) -> (N, C)."""

    def forward(self, x):
        return torch.mean(x, dim=1)


class AdaptiveConcatPool2d(nn.Module):
    """Concatenate global average- and max-pooling over the two spatial axes.

    (N, C, H, W) -> (N, 2*C, 1, 1): the averaged channels come first, then the maxima.
    """

    def forward(self, x):
        pooled_avg = F.adaptive_avg_pool2d(x, 1)
        pooled_max = F.adaptive_max_pool2d(x, 1)
        return torch.cat([pooled_avg, pooled_max], dim=1)

    def feat_mult(self):
        """Factor by which this pooling multiplies the channel count."""
        return 2


def gem_1d(x, p=3, eps=1e-6):
    """Generalised-mean (GeM) pooling over the last axis.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` and keeps a trailing
    length-1 axis: (N, C, L) -> (N, C, 1).
    """
    window = (x.size(-1),)
    powered = x.clamp(min=eps).pow(p)
    return F.avg_pool1d(powered, window).pow(1.0 / p)


def gem_2d(x, p=3, eps=1e-6):
    """Generalised-mean (GeM) pooling over the two trailing (spatial) axes.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` per channel:
    (N, C, H, W) -> (N, C, 1, 1).
    """
    height, width = x.size(-2), x.size(-1)
    powered = x.clamp(min=eps).pow(p)
    return F.avg_pool2d(powered, (height, width)).pow(1.0 / p)


class GeM1d(nn.Module):
    """Module wrapper around ``gem_1d``.

    Args:
        p: pooling exponent; registered as a learnable 1-element Parameter when
            ``p_trainable`` is True, otherwise stored as a plain number.
        eps: lower clamp applied before the power.
        p_trainable: whether ``p`` is learned.
    """

    def __init__(self, p=3, eps=1e-6, p_trainable=False):
        super().__init__()
        self.p = torch.nn.Parameter(torch.ones(1) * p) if p_trainable else p
        self.eps = eps

    def feat_mult(self):
        """GeM keeps the channel count unchanged."""
        return 1

    def forward(self, x):
        return gem_1d(x, p=self.p, eps=self.eps)


class GeM2d(nn.Module):
    """Module wrapper around ``gem_2d``.

    Args:
        p: pooling exponent; registered as a learnable 1-element Parameter when
            ``p_trainable`` is True, otherwise stored as a plain number.
        eps: lower clamp applied before the power.
        p_trainable: whether ``p`` is learned.
    """

    def __init__(self, p=3, eps=1e-6, p_trainable=False):
        super().__init__()
        self.p = torch.nn.Parameter(torch.ones(1) * p) if p_trainable else p
        self.eps = eps

    def feat_mult(self):
        """GeM keeps the channel count unchanged."""
        return 1

    def forward(self, x):
        return gem_2d(x, p=self.p, eps=self.eps)

In [None]:
# Inspect the backbone's multi-scale outputs: with features_only=True the model
# returns a sequence of feature maps; print each shape for a dummy batch of two
# 3x1024x640 inputs.
# NOTE(review): the probe size (1024, 640) is not the training size CFG.SZ.
m = timm.create_model(CFG.MODEL, features_only=True, pretrained=True)

o = m(torch.randn(2, 3, 1024, 640))
for x in o:
    print(x.shape)

In [None]:
import math

import timm
import torch
import torch.utils.checkpoint as cp
from timm.models.layers import SelectAdaptivePool2d, create_act_layer, create_attn
from timm.models.resnet import Bottleneck, create_aa
from torch import nn
from torch.nn import functional as F


class Bottleneck(nn.Module):
    """ResNet-style bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv + shortcut.

    This definition shadows the ``Bottleneck`` imported from ``timm.models.resnet``
    in the same cell; it appears to be a local copy of timm's block
    (NOTE(review): confirm parity with the installed timm version).

    The block outputs ``planes * expansion`` (= 4 * planes) channels. The shortcut
    is added without projection unless ``downsample`` is given, so without a
    ``downsample`` module callers need ``inplanes == 4 * planes`` and ``stride == 1``
    for the in-place addition to be shape-compatible.

    Args:
        inplanes: input channels.
        planes: base width; output channels are ``planes * expansion``.
        stride: stride of the 3x3 conv, or of the anti-aliasing layer when one is used.
        downsample: optional module applied to the shortcut before the addition.
        cardinality: number of groups in the 3x3 conv.
        base_width: width scaling, ``width = int(floor(planes * base_width / 64) * cardinality)``.
        reduce_first: divisor applied to ``width`` for the first conv's output channels.
        dilation: stored; also compared with ``first_dilation`` to decide on anti-aliasing.
        first_dilation: dilation/padding of the 3x3 conv (defaults to ``dilation``).
        act_layer: activation passed to ``create_act_layer`` (three instances).
        norm_layer: normalisation layer constructor (three instances).
        attn_layer: optional attention layer built by ``create_attn`` on the output
            channels (presumably None when ``attn_layer`` is None; forward checks for None).
        aa_layer: optional anti-aliasing layer built by ``create_aa``.
        drop_block: optional zero-argument factory; its result is applied after bn2.
        drop_path: optional module applied to the residual branch before the addition.
    """
    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        cardinality=1,
        base_width=64,
        reduce_first=1,
        dilation=1,
        first_dilation=None,
        act_layer=nn.ReLU,
        norm_layer=nn.BatchNorm2d,
        attn_layer=None,
        aa_layer=None,
        drop_block=None,
        drop_path=None,
    ):
        super(Bottleneck, self).__init__()

        width = int(math.floor(planes * (base_width / 64)) * cardinality)
        first_planes = width // reduce_first
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation
        # anti-aliasing is only used when the block downsamples (stride 2) or the
        # dilation changes; the aa layer then takes over the stride from conv2
        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)

        self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
        self.bn1 = norm_layer(first_planes)
        self.act1 = create_act_layer(act_layer, inplace=True)

        self.conv2 = nn.Conv2d(
            first_planes,
            width,
            kernel_size=3,
            stride=1 if use_aa else stride,
            padding=first_dilation,
            dilation=first_dilation,
            groups=cardinality,
            bias=False,
        )
        self.bn2 = norm_layer(width)
        # drop_block is a factory, instantiated once here
        self.drop_block = drop_block() if drop_block is not None else nn.Identity()
        self.act2 = create_act_layer(act_layer, inplace=True)
        self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa)

        self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(outplanes)

        self.se = create_attn(attn_layer, outplanes)

        self.act3 = create_act_layer(act_layer, inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.drop_path = drop_path

    def zero_init_last(self):
        """Zero bn3's weight so the residual branch's output starts out as bn3's bias only."""
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x):
        shortcut = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.drop_block(x)
        x = self.act2(x)
        x = self.aa(x)

        x = self.conv3(x)
        x = self.bn3(x)

        if self.se is not None:
            x = self.se(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        # in-place residual addition: shapes of x and shortcut must match
        x += shortcut
        x = self.act3(x)

        return x


class NetSelfDistill(nn.Module):
    """Binary classifier on a timm feature-pyramid backbone with two auxiliary
    ("self-distillation") heads on intermediate stages.

    Main head: concat(avg, max) pooling of the last feature map -> dropout -> 1 logit.
    Auxiliary heads: a Bottleneck block on the third-to-last / second-to-last feature
    map -> global average pooling -> dropout -> 1 logit each.

    Args:
        model_name: timm model name, created with ``features_only=True``.
        pretrained: load pretrained backbone weights.
        act_layer: activation for the backbone and the auxiliary Bottlenecks
            (the Bottlenecks fall back to nn.SiLU when this is None).
        set_grad_checkpointing: run the backbone under ``torch.utils.checkpoint``
            to trade compute for memory.
        pos_weight: optional positive-class weight(s) for ``self.cls_loss``;
            a scalar or a sequence.
        drop_rate: dropout probability used before every head.
        verbose: print the module after construction.
        **kwargs: forwarded to ``timm.create_model``.
    """

    def __init__(
        self,
        model_name=CFG.MODEL,
        pretrained=True,
        act_layer=nn.Mish,
        set_grad_checkpointing=False,
        pos_weight=None,
        drop_rate=0.2,
        verbose=False,
        **kwargs,
    ):
        super().__init__()
        self.backbone = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            act_layer=act_layer,
            features_only=True,
            **kwargs,
        )
        self.drop_rate = drop_rate
        # channel count of each feature map returned by the backbone
        self.feature_info = self.backbone.feature_info.channels()
        self.set_grad_checkpointing = set_grad_checkpointing

        # main head
        self.neck = SelectAdaptivePool2d(pool_type="catavgmax", flatten=True)
        self.dropout = nn.Dropout(self.drop_rate)
        self.cls_head = nn.Linear(self.feature_info[-1] * self.neck.feat_mult(), 1, bias=True)

        # auxiliary heads. Bottleneck adds its input back without projection, so its
        # output width 4 * ceil(C / 4) must equal C, i.e. C must be a multiple of 4.
        bottleneck_act = act_layer if act_layer is not None else nn.SiLU

        self.bottleneck1 = Bottleneck(
            self.feature_info[-3],
            math.ceil(self.feature_info[-3] / 4),
            act_layer=bottleneck_act,
        )
        self.global_pool1 = SelectAdaptivePool2d(pool_type="avg", flatten=True)
        self.fc1 = nn.Linear(self.feature_info[-3], 1)

        self.bottleneck2 = Bottleneck(
            self.feature_info[-2],
            math.ceil(self.feature_info[-2] / 4),
            act_layer=bottleneck_act,
        )
        self.global_pool2 = SelectAdaptivePool2d(pool_type="avg", flatten=True)
        self.fc2 = nn.Linear(self.feature_info[-2], 1)

        self.cls_head.weight.data.normal_(0, 0.01)
        self.cls_head.bias.data.zero_()

        # as_tensor(...).reshape(-1) accepts a scalar as well as a sequence
        # (torch.FloatTensor(3.0) raises).
        self.pos_weight = (
            torch.as_tensor(pos_weight, dtype=torch.float32).reshape(-1)
            if pos_weight is not None
            else None
        )
        self.cls_loss = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)

        if verbose:
            print(self)

    def forward(self, img, return_distill_logits=True):
        """Run the network.

        Returns:
            If ``return_distill_logits`` (default), the 6-tuple
            ``(logit, distill_logit_1, distill_logit_2, feat_1, feat_2, feat_map_fin)``:
            the three logits are 1-D (one value per sample), ``feat_1``/``feat_2``
            are the pooled + dropped-out auxiliary features and ``feat_map_fin`` is
            the raw last feature map. Otherwise only the main ``logit``.
        """
        if self.set_grad_checkpointing:
            # use_reentrant=False: the reentrant variant only back-propagates when
            # an *input* requires grad; the image never does, so the backbone
            # weights would silently receive no gradient.
            feat_maps = cp.checkpoint(self.backbone, img, use_reentrant=False)
        else:
            feat_maps = self.backbone(img)

        feat_map_1, feat_map_2, feat_map_fin = feat_maps[-3], feat_maps[-2], feat_maps[-1]

        x = self.neck(feat_map_fin)
        x = self.dropout(x)
        x = self.cls_head(x).reshape(-1)

        if not return_distill_logits:
            return x

        feat_map_1 = self.bottleneck1(feat_map_1)
        feat_map_1 = self.global_pool1(feat_map_1)
        feat_map_1 = self.dropout(feat_map_1)
        distill_logit_1 = self.fc1(feat_map_1).reshape(-1)

        feat_map_2 = self.bottleneck2(feat_map_2)
        feat_map_2 = self.global_pool2(feat_map_2)
        feat_map_2 = self.dropout(feat_map_2)
        distill_logit_2 = self.fc2(feat_map_2).reshape(-1)

        return x, distill_logit_1, distill_logit_2, feat_map_1, feat_map_2, feat_map_fin

In [None]:
# Smoke-test the training data pipeline: draw one shuffled, augmented batch of
# two images (workers seeded via worker_init_fn).
dl = DataLoader(RsnaDataset(train, augs=TRAIN_AUG, mode='train'),
                          batch_size=2,
                          shuffle=True,
                          num_workers=8,
                          drop_last=True,
                        worker_init_fn=worker_init_fn)

image,cancer = next(iter(dl))
# a.shape,b.shape,c.shape

In [None]:
# m = get_rsna_classification_model(CFG.MODEL)
# m = NetSelfDistill()
# x, distill_logit_1, distill_logit_2,feat_map_1, feat_map_2, feat_map_fin = m(image)
# cancer1 = m(image)
# print(x)

### Sampling

In [None]:
class BalanceSampler(torch.utils.data.Sampler):
    """Yield indices in groups of ``ratio``: one positive, then ``ratio - 1`` negatives.

    Per epoch the negatives are shuffled and each is used at most once (a remainder
    that cannot fill a whole group is dropped); the positives are shuffled and
    repeated cyclically as often as needed to give every group one positive.

    Args:
        dataset: object whose ``df`` attribute has a ``cancer`` column (>0 = positive).
        ratio: group size; must be >= 2.

    Raises:
        ValueError: if ``ratio < 2``, there are no positives, or there are fewer
            than ``ratio - 1`` negatives (no complete group could be formed).
    """

    def __init__(self, dataset, ratio = 3):
        if ratio < 2:
            raise ValueError(f"ratio must be >= 2, got {ratio}")
        self.r = ratio - 1  # negatives per group
        self.dataset = dataset
        self.pos_index = np.where(dataset.df.cancer > 0)[0]
        self.neg_index = np.where(dataset.df.cancer == 0)[0]
        if len(self.pos_index) == 0:
            raise ValueError("BalanceSampler needs at least one positive sample")
        if len(self.neg_index) < self.r:
            raise ValueError(
                f"BalanceSampler needs at least {self.r} negative samples, got {len(self.neg_index)}"
            )

        # negatives used per epoch (whole groups only) and total indices yielded
        self.length = self.r * int(np.floor(len(self.neg_index) / self.r))
        self.ds_len = self.length + (self.length // self.r)

    def __iter__(self):
        pos_index = self.pos_index.copy()
        neg_index = self.neg_index.copy()
        np.random.shuffle(pos_index)
        np.random.shuffle(neg_index)

        neg_groups = neg_index[:self.length].reshape(-1, self.r)
        n_groups = len(neg_groups)
        # repeat the positives cyclically until there is one per group
        reps = (n_groups // len(pos_index)) + 1
        pos_column = np.tile(pos_index, reps)[:n_groups].reshape(-1, 1)

        index = np.concatenate([pos_column, neg_groups], -1).reshape(-1)
        return iter(index)

    def __len__(self):
        return self.ds_len

In [None]:
def class_imbalance_sampler(labels):
    """Build a WeightedRandomSampler under which every class is equally likely.

    Each sample is weighted by ``1 / count(its class)``, so the total weight of
    every present class is 1. ``len(labels)`` samples are drawn per epoch, with
    replacement (the WeightedRandomSampler default).

    Args:
        labels: non-negative integer class ids, shape (N,) or (N, 1); a tensor or
            anything ``torch.as_tensor`` accepts.

    Returns:
        The configured ``WeightedRandomSampler``.
    """
    # Flatten once and use the same 1-D labels for counting AND indexing; indexing
    # with (N, 1) labels would give 2-D weights, which the sampler cannot use.
    flat = torch.as_tensor(labels).reshape(-1).long()
    class_count = torch.bincount(flat)
    class_weighting = 1. / class_count
    sample_weights = class_weighting[flat]
    return WeightedRandomSampler(sample_weights, len(flat))

In [None]:
class ExhaustiveWeightedRandomSampler(WeightedRandomSampler):
    """ExhaustiveWeightedRandomSampler behaves pretty much the same as WeightedRandomSampler
    except that it receives an extra parameter, exaustive_weight, which is the weight of the
    elements that should be sampled exhaustively over multiple iterations.

    This is useful when the dataset is very big and also very imbalanced, like the negative
    sample is way more than positive samples, we want to over sample positive ones, but also
    iterate over all the negative samples as much as we can.

    Each iteration draws ``num_samples`` indices with replacement; every draw that
    lands on an "exhaustive" element (weight == exaustive_weight) is replaced by the
    next index from a shuffled queue of all exhaustive elements. The queue persists
    across iterations and is refilled with a fresh permutation when it runs short,
    so each pass through it visits every exhaustive element once.

    Args:
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw per iteration
        exaustive_weight (int): which weight of samples should be sampled exhaustively,
            normally this is the one that should not been over sampled, like the lowest
            weight of samples in the dataset.
        generator (Generator): Generator used in sampling.
    """

    def __init__(
        self,
        weights: Sequence[float],
        num_samples: int,
        exaustive_weight=1,
        generator=None,
    ) -> None:
        super().__init__(weights, num_samples, True, generator)
        self.all_indices = torch.arange(num_samples)  # kept for backward compatibility
        self.exaustive_weight = exaustive_weight
        # boolean mask over dataset indices: True where the element is "exhaustive"
        self.weights_mapping = torch.as_tensor(weights) == self.exaustive_weight
        self.remaining_indices = torch.tensor([], dtype=torch.long)

    def get_remaining_indices(self) -> torch.Tensor:
        """Return all exhaustive indices in a fresh random order (always 1-D)."""
        # nonzero(as_tuple=True)[0] stays 1-D even for a single match, unlike
        # nonzero().squeeze(), which collapses to a 0-d tensor that has no len().
        remaining_indices = self.weights_mapping.nonzero(as_tuple=True)[0]
        return remaining_indices[torch.randperm(len(remaining_indices))]

    def __iter__(self) -> Iterator[int]:
        rand_tensor = torch.multinomial(
            self.weights, self.num_samples, self.replacement, generator=self.generator
        )
        # positions in this draw that hit an exhaustive element
        is_exhaustive = self.weights_mapping[rand_tensor]
        n_exhaustive = int(is_exhaustive.sum())

        while n_exhaustive > len(self.remaining_indices):
            self.remaining_indices = torch.cat(
                [self.remaining_indices, self.get_remaining_indices()]
            )
        yield_indexes, self.remaining_indices = (
            self.remaining_indices[:n_exhaustive],
            self.remaining_indices[n_exhaustive:],
        )
        rand_tensor[is_exhaustive] = yield_indexes
        yield from iter(rand_tensor.tolist())

### Custom Loss

In [None]:
### Custom loss functions

class FocalLoss(nn.Module):
    """
    Binary focal loss for fighting against class-imbalance.

    loss = mean( -alpha * (1 - p_t) ** gamma * log(p_t) ), where p_t is the
    predicted probability of the target label. ``alpha`` scales positives and
    negatives alike (there is no per-class alpha_t).
    """

    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # kept for backward compatibility; the stable log-sigmoid below needs no epsilon
        self.epsilon = 1e-12

    def forward(self, logits, target):
        """
        logits & target should be tensors of the same shape (e.g. [batch_size, num_classes]);
        target holds 0/1 (or soft) labels.
        """
        # log(p_t) straight from the logits. The former log(sigmoid(x) + 1e-12)
        # form breaks under fp16 autocast: 1e-12 rounds to 0 in half precision, so
        # log(0) = -inf and target * log(...) becomes 0 * -inf = NaN. It also capped
        # the loss of confidently wrong predictions at -log(1e-12).
        log_pt = -F.binary_cross_entropy_with_logits(
            logits, target.type_as(logits), reduction="none"
        )
        pt = torch.exp(log_pt)
        focal_loss = -1.0 * (self.alpha * (1 - pt) ** self.gamma) * log_pt
        return torch.mean(focal_loss)
    
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    The target class receives probability ``1 - smoothing``; the remaining
    ``smoothing`` mass is spread evenly over the other ``classes - 1`` classes.

    Args:
        classes: number of classes (size of ``pred`` along ``dim``).
        smoothing: probability mass moved off the target class.
        dim: class dimension of ``pred``.
    """

    def __init__(self, classes=1, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """pred: logits with classes along ``self.dim``; target: integer class ids
        with ``pred``'s shape minus that dimension. Returns the mean loss."""
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # max(..., 1) avoids ZeroDivisionError for the default classes=1; with a
            # single class the fill value is overwritten by the scatter anyway.
            off_value = self.smoothing / max(self.cls - 1, 1)
            true_dist = torch.full_like(pred, off_value)
            # scatter along the same dim that is normalised and summed over
            true_dist.scatter_(self.dim, target.data.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
  

class CustomAuxLoss(nn.Module):
    """
    Deep-supervision loss for the 6-tuple produced by NetSelfDistill.

    Returns the unweighted sum of three binary cross-entropy-with-logits terms:
    one for the main logit and one for each auxiliary ("distill") logit, all
    against the same label.

    Args:
        alpha: stored on the instance but not used by ``forward``.
    """

    def __init__(self, alpha=0.8):
        super(CustomAuxLoss, self).__init__()
        self.alpha = alpha
        self.epsilon = 1e-12  # unused; kept for attribute compatibility

    def forward(self, x, distill_logit_1, distill_logit_2, feat_map_1, feat_map_2, feat_map_fin, gt_label, batch_ix):
        """Compute the summed loss.

        Args:
            x, distill_logit_1, distill_logit_2: logits, flattened to (B, 1).
            feat_map_1, feat_map_2, feat_map_fin, batch_ix: accepted so the call
                mirrors the model's outputs plus the batch index; not used.
            gt_label: binary targets, flattened to (B, 1).

        Returns:
            Scalar tensor ``BCE(x) + BCE(distill_logit_1) + BCE(distill_logit_2)``,
            each term averaged over the batch.
        """
        gt_label = gt_label.view(-1, 1)
        return sum(
            F.binary_cross_entropy_with_logits(logit.view(-1, 1), gt_label)
            for logit in (x, distill_logit_1, distill_logit_2)
        )


### Train & Validation Functions

In [None]:
def pfbeta(labels, preds, beta=1, clip=True):
    """Probabilistic F-beta score.

    Soft counts: true positives = sum of ``preds`` on positive labels, false
    positives = sum of ``preds`` on negative labels.

    Args:
        labels: 0/1 numpy array or tensor.
        preds: probabilities (or booleans) with the same shape as ``labels``.
        beta: weight of recall relative to precision.
        clip: clip ``preds`` to [0, 1] first.

    Returns:
        The score (numpy scalar or 0-d tensor, following the inputs), or
        ``torch.tensor(0.0)`` when precision or recall is not positive, which
        includes having no positive labels or no positive predictions.
    """
    if clip:
        preds = preds.clip(0, 1)
    y_true_count = labels.sum()
    ctp = preds[labels == 1].sum()
    cfp = preds[labels == 0].sum()
    # Bail out before dividing by zero (numpy would emit a RuntimeWarning and
    # produce NaN); the score is 0 in these cases anyway.
    if y_true_count == 0 or (ctp + cfp) == 0:
        return torch.tensor(0.0)
    beta_squared = beta * beta
    c_precision = ctp / (ctp + cfp)
    c_recall = ctp / y_true_count
    if c_precision > 0 and c_recall > 0:
        return (1 + beta_squared) * (c_precision * c_recall) / (beta_squared * c_precision + c_recall)
    return torch.tensor(0.0)
    
def pfbeta_thresh(labels, preds, beta=1, threshold=0.1):
    """F-beta score after binarising ``preds`` with ``preds > threshold``.

    Args:
        labels: 0/1 numpy array or tensor.
        preds: probabilities with the same shape as ``labels``.
        beta: weight of recall relative to precision.
        threshold: strict cut-off for a positive prediction (default 0.1, the
            previously hard-coded value).

    Returns:
        The score, or ``torch.tensor(0.0)`` when precision or recall is not
        positive (including no positive labels or no positive predictions).
    """
    preds = preds > threshold
    y_true_count = labels.sum()
    ctp = preds[labels == 1].sum()
    cfp = preds[labels == 0].sum()
    # avoid 0/0 (NaN + RuntimeWarning); the score is 0 in these cases anyway
    if y_true_count == 0 or (ctp + cfp) == 0:
        return torch.tensor(0.0)
    beta_squared = beta * beta
    c_precision = ctp / (ctp + cfp)
    c_recall = ctp / y_true_count
    if c_precision > 0 and c_recall > 0:
        return (1 + beta_squared) * (c_precision * c_recall) / (beta_squared * c_precision + c_recall)
    return torch.tensor(0.0)
    
def optimal_f1(labels, predictions):
    """Grid-search the binarisation threshold that maximises ``pfbeta``.

    Tries 100 evenly spaced thresholds in [0, 1] on CPU numpy copies of the
    tensors; ties resolve to the lowest threshold.

    Returns:
        ``(best_score, best_threshold)``.
    """
    y = labels.cpu().numpy()
    p = predictions.cpu().numpy()
    candidates = np.linspace(0, 1, 100)
    scores = [pfbeta(y, p > cut, clip=False) for cut in candidates]
    best = np.argmax(scores)
    return scores[best], candidates[best]

In [None]:
def train_one_epoch(
    model: nn.Module,
    loader: Iterable,
    loss_fn: Callable,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
    mixup_fn: Callable = None,
    grad_scaler: torch.cuda.amp.GradScaler = None,
    mbar: master_bar = None,
    max_norm: float = None,
):
    """Run one mixed-precision training epoch, stepping the LR scheduler per batch.

    Args:
        model: network returning the 6-tuple consumed by ``loss_fn``.
        loader: yields ``(input, target)`` batches; both are moved to CUDA here.
        loss_fn: called as ``loss_fn(x, d1, d2, f1, f2, f_fin, target, batch_idx)``.
        optimizer: optimizer to step after every batch.
        lr_scheduler: optional scheduler stepped after every batch.
        mixup_fn: accepted for API compatibility; currently unused.
        grad_scaler: AMP ``GradScaler``; when None a plain backward/step is used.
        mbar: parent fastprogress master bar.
        max_norm: if given, clip the global gradient norm to this value. With a
            scaler the gradients are unscaled first, so the threshold applies to
            the true gradient magnitudes.

    Returns:
        OrderedDict with the sample-weighted mean training loss under ``"loss"``.
    """
    model.train()

    losses_m = utils.AverageMeter()

    pbar = progress_bar(loader, parent=mbar, leave=False)
    pbar.update(0)

    for batch_idx, (input, target) in enumerate(loader):
        input, target = input.cuda(), target.cuda()

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            x, distill_logit_1, distill_logit_2, feat_map_1, feat_map_2, feat_map_fin = model(input)
            loss = loss_fn(x, distill_logit_1, distill_logit_2, feat_map_1, feat_map_2, feat_map_fin, target, batch_idx)
        losses_m.update(loss.item(), input.size(0))

        if grad_scaler is not None:
            grad_scaler.scale(loss).backward()
            if max_norm is not None:
                # clipping has to see unscaled gradients and happen before step()
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        pbar.update(batch_idx + 1)
        pbar.comment = f"{losses_m.avg:.4f}"

    pbar.on_iter_end()
    return OrderedDict([("loss", losses_m.avg)])


@torch.inference_mode()
def validate(model: nn.Module, loader: Iterable, loss_fn: Callable, mbar: master_bar):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        OrderedDict with
        - ``loss``: sample-weighted mean of ``loss_fn(logits, target)`` over batches,
        - ``metric``: sample-weighted mean of the per-batch ``pfbeta`` on probabilities,
        - ``metric_thresh``: best ``pfbeta`` over thresholds on the whole set,
        - ``auc``: ROC AUC on the whole set.
    """
    model.eval()

    metric_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()

    pbar = progress_bar(loader, parent=mbar, leave=False)
    pbar.update(0)

    out = []
    real = []
    for batch_idx, (input, target) in enumerate(loader):
        input, target = input.cuda(), target.cuda()
        # only the main logit is evaluated, so skip the auxiliary heads
        x = model(input, return_distill_logits=False)

        losses_m.update(loss_fn(x, target).item(), input.size(0))

        output = F.sigmoid(x)
        metric_m.update(pfbeta(target, output).item(), input.size(0))
        # NOTE: the threshold search runs once on the full set below; a per-batch
        # search would cost 100 pfbeta calls per batch for a value never reported.
        pbar.update(batch_idx + 1)
        out.append(output)
        real.append(target)

    all_targets = torch.cat(real)
    all_probs = torch.cat(out)
    optf1, _ = optimal_f1(all_targets, all_probs)
    auc = roc_auc_score(all_targets.cpu().numpy(), all_probs.cpu().numpy())
    pbar.on_iter_end()
    return OrderedDict([("loss", losses_m.avg), ("metric", metric_m.avg), ("metric_thresh", optf1), ('auc', auc)])


### Run!

In [None]:
def training_loop(fold):
    """Train one cross-validation fold: fit on rows with ``fold != fold``, validate on ``fold == fold``.

    Builds the train/valid RsnaDatasets and loaders, a fresh NetSelfDistill, Adam
    (timm ``create_optimizer_v2``) with a OneCycleLR schedule sized for
    ``len(loader_train) * CFG.EP`` steps, CustomAuxLoss for training and plain
    BCE-with-logits for validation, then runs CFG.EP epochs. A checkpoint is written
    after every epoch to
    ``{MODEL_FOLDER}/fold_{fold}/{KERNEL_TYPE}_Epoch_{epoch}_fold_{fold}.pth``.
    """
    
    # NOTE(review): cl_set_seed=42 differs from CFG.seed; presumably this context
    # manager re-seeds on entry — confirm against ipyexperiments.
    with IPyExperimentsPytorch(exp_enable=False, cl_set_seed=42, cl_compact=True):
        print()
        print("*" * 100)
        print(f"Training fold {fold}")
        print("*" * 100)

        torch.backends.cudnn.benchmark = True
      
        dataset_train = RsnaDataset(train.query("fold!=@fold").reset_index(drop=True), augs=TRAIN_AUG, mode="train")
        dataset_valid = RsnaDataset(train.query("fold==@fold").reset_index(drop=True), augs=VALID_AUG, mode="valid")

        print(f"TRAIN: {len(dataset_train)} | VALID: {len(dataset_valid)}")

        # Plain shuffling; alternative class-balancing samplers are left commented out.
        loader_train = torch.utils.data.DataLoader(dataset_train, 
                                                   CFG.BS, 
                                                   num_workers=8, 
                                                   drop_last=True,
                                                  pin_memory=True,
#                                                    sampler = ExhaustiveWeightedRandomSampler(train.query("fold!=@fold").reset_index(drop=True).weight,num_samples=CFG.BS))
                                                   shuffle=True)
#                                                   sampler = class_imbalance_sampler(torch.tensor(train.query("fold!=@fold").reset_index(drop=True)['cancer'].values)))
#                                                    sampler=BalanceSampler(dataset_train))
        loader_valid = torch.utils.data.DataLoader(dataset_valid, CFG.BS * 2, num_workers=8, shuffle=False)

        model = NetSelfDistill()
        model.cuda()
        
        optimizer = create_optimizer_v2(model, "Adam", lr=CFG.LR,weight_decay=CFG.WD)

        # the scheduler is stepped once per batch inside train_one_epoch
        num_train_steps = len(loader_train) * CFG.EP
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                           max_lr=CFG.LR,
                                                           pct_start=0.1,
                                                           total_steps=num_train_steps,
                                                           verbose=False)

        train_loss_fn = CustomAuxLoss()
        valid_loss_fn = nn.BCEWithLogitsLoss()

        grad_scaler = torch.cuda.amp.GradScaler()

        print(f"Scheduled epochs: {CFG.EP}")

        mbar = master_bar(list(range(CFG.EP)))
        # best_metric starts at a large sentinel (it tracks validation loss)
        best_epoch, best_metric = 0, 100
        metric_names = ["epoch", "train_loss", "valid_loss", "metric","metric_thresh", "auc", "time"]
        mbar.write([f"{l:.6f}" if isinstance(l, float) else str(l) for l in metric_names], table=True)
#         alpha = 0.5
        for epoch in range(CFG.EP):
            
#             train_loss_fn = CustomAuxLoss(alpha=alpha)
                        
            start_time = time.time()
            mbar.update(epoch)
            
            train_metrics = train_one_epoch(
                model, loader_train, train_loss_fn, optimizer,
                lr_scheduler=lr_scheduler, mixup_fn=None, grad_scaler=grad_scaler, mbar=mbar)

            valid_metrics = validate(model, loader_valid, valid_loss_fn, mbar=mbar)
            
            elapsed = format_time(time.time() - start_time)
            epoch_log = [epoch,train_metrics["loss"], valid_metrics["loss"], valid_metrics["metric"],
                         valid_metrics["metric_thresh"],valid_metrics["auc"], elapsed]
            mbar.write([f"{l:.6f}" if isinstance(l, float) else str(l) for l in epoch_log], table=True)

            # NOTE(review): the improvement check is disabled (`if 1:`), so every epoch
            # is checkpointed and best_epoch/best_metric end up describing the LAST
            # epoch, not the best one.
#             if (valid_metrics["loss"] < best_metric) or (1):
            if 1:
                best_epoch, best_metric = epoch, valid_metrics["loss"]
                path = Path(f'{MODEL_FOLDER}/fold_{fold}')
                os.makedirs(path,exist_ok=True)
                dirpath = path / (KERNEL_TYPE + f"_Epoch_{epoch}_fold_{fold}.pth")
                torch.save(model.state_dict(), dirpath)
            
                
        mbar.on_iter_end()
        print("*** Best metric: {0} (epoch {1})".format(best_metric, best_epoch))

In [None]:
# Folds to train; the same list is reused by the OOF cells further down.
folds = [0,1,2,3]
if __name__ == "__main__":
    for fold_idx in folds:
        training_loop(fold_idx)
        # release cached GPU memory and collect garbage before the next fold
        torch.cuda.empty_cache()
        gc.collect()

In [None]:
# Free memory after training. Repeated a few times, presumably so objects released
# by one collection are reclaimed by the next — harmless if redundant.
for i in range(5):
    torch.cuda.empty_cache()
    gc.collect()
    
# NOTE(review): `_` is IPython's last-output variable; this raises NameError when
# run as a plain script or before any cell has produced output.
del _
gc.collect()

In [None]:
def gen_oof(fold):
    """Predict cancer probabilities for the validation rows of ``fold``.

    Loads the last-epoch checkpoint written by ``training_loop`` for this fold
    (``{MODEL_FOLDER}/fold_{fold}/{KERNEL_TYPE}_Epoch_{CFG.EP-1}_fold_{fold}.pth``).

    Returns:
        ``(preds, ix)``: 1-D numpy array of sigmoid probabilities in row order, and
        the ``train`` index labels of those rows (aligned with ``preds``).
    """
    torch.backends.cudnn.benchmark = True
    fold_df = train.query("fold==@fold")
    dataset_valid = RsnaDataset(fold_df.reset_index(drop=True), augs=VALID_AUG, mode="valid")
    ix = fold_df.index
    print(f"VALID: {len(dataset_valid)}")

    loader_valid = torch.utils.data.DataLoader(dataset_valid, CFG.BS, num_workers=8, shuffle=False)
    model = NetSelfDistill()
    checkpoint = f'{MODEL_FOLDER}/fold_{fold}/{KERNEL_TYPE}_Epoch_{CFG.EP-1}_fold_{fold}.pth'
    # load on CPU first; the model is moved to the GPU right after
    model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
    model.cuda()
    model.eval()

    preds = []
    # No per-batch torch.cuda.empty_cache()/gc.collect(): they only slow inference.
    for input, _ in tqdm(loader_valid, dynamic_ncols=True, desc="OOF Generation"):
        with torch.cuda.amp.autocast(), torch.no_grad():
            # only the main logit is needed, so skip the auxiliary heads
            logits = model(input.cuda(), return_distill_logits=False)
            preds.append(F.sigmoid(logits).cpu().numpy())
    return np.concatenate(preds, axis=0), ix

In [None]:
# Out-of-fold predictions: each row of `train` gets the prediction of the model
# whose training folds excluded that row.
oof = np.zeros((len(train)))
for k in tqdm(folds):
    oof_fold,ix = gen_oof(k)
    # quick sanity check of this fold's probability range
    print(oof_fold.min(),oof_fold.max())
    # NOTE(review): index labels are used as positions; assumes train has a default RangeIndex
    oof[ix] += oof_fold

In [None]:
def optimal_f1_numpy(oof, fold):
    """Best ``pfbeta`` over 100 thresholds in [0, 1] for rows of ``train`` whose
    ``fold`` is in ``fold`` (an iterable of fold ids).

    ``train``'s index labels of the selected rows are used as positions into
    ``oof`` (assumes ``train`` has a default RangeIndex aligned with ``oof``).

    Returns:
        ``(best_score, best_threshold)``; ties resolve to the lowest threshold.
    """
    selected = train.loc[train['fold'].isin(fold)]
    labels = selected['cancer'].values
    subset = oof[selected.index]
    grid = np.linspace(0, 1, 100)
    scores = [pfbeta(labels, subset > cut, clip=False) for cut in grid]
    best = np.argmax(scores)
    return scores[best], grid[best]

In [None]:
# Tune the decision threshold on the pooled out-of-fold predictions of all folds.
scr, thresh = optimal_f1_numpy(oof,folds)

In [None]:
# Display the best OOF pfbeta and the threshold that achieves it.
scr,thresh

### Fin 