In [None]:
!pip install -q /kaggle/input/iterative-stratification/iterative-stratification-master
!pip -q install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git

In [None]:
import sys
sys.path.append("../input/timm-pytorch-image-models/pytorch-image-models-master")
from fastai.vision.all import *
import os, gc
import time
import random
import numpy as np
import pandas as pd
import subprocess
import cv2
import PIL.Image
import matplotlib.pyplot as plt
%matplotlib inline
from pylab import rcParams
import seaborn as sns
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import functional as FV
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from sklearn.metrics import roc_auc_score, accuracy_score
from warmup_scheduler import GradualWarmupScheduler
import albumentations
import timm
from tqdm import tqdm
import torch.cuda.amp as amp
import warnings

warnings.simplefilter('ignore')
scaler = amp.GradScaler()

In [None]:
print('Python        : ' + sys.version.split('\n')[0])
print('Numpy         : ' + np.__version__)
print('Pandas        : ' + pd.__version__)
print('PyTorch       : ' + torch.__version__)
print('Albumentations: ' + albumentations.__version__)
print('Timm          : ' + timm.__version__)

In [None]:
cls = ["Negative for Pneumonia", "Typical Appearance", "Indeterminate Appearance", "Atypical Appearance"]

class cfg:
    """Experiment configuration, read as class attributes (never instantiated).

    Every per-run list below (``dim``, ``batch_size``, ``valid_fold``,
    ``aggresive_aug``, ``kernel_type``, ``backbone``, ``base_weights``,
    ``pre_trained_weights``) is indexed by the run counter ``0 .. folds-1``
    and must therefore hold at least ``folds`` entries.
    """

    debug = False
    seed = 42
    nfold = 5               # number of stratified CV splits written to df_train["fold"]
    num_classes = 4
    siamese_nfeatures = 128

    train = True
    analyze = False
    pretrained = True
    over_sample = False
    rand_aug = False
    multi_head = False
    freeze = False
    heads = [2,2]
    mode = "all"

    folds = 3               # number of training runs (one per entry of the per-run lists)
    dim = [320, 320, 224]
    batch_size = [32, 32, 32]
    valid_fold = [3, 2, 4]  # CV split used for validation in each run
    aggresive_aug = [False,False,False]

    # MODELS

    # Ensemble weights used at inference time, indexed by run.
    # TODO(review): only 2 weights while folds == 3 -- extend before running inference.
    WGTS = [0.3,0.7]

    kernel_type = ['densenet169_320_lr3e5_bs32_10epo', 
                   'inception_v4_320_lr3e5_bs32_10epo',
                   'tf_efficientnet_b5_ns_224_lr3e5_bs32_10epo']

    backbone = ['densenet169',
                'inception_v4',
                'tf_efficientnet_b5_ns']

    # One entry per run; a 2-element list made cfg.base_weights[2] raise
    # IndexError in the multi-head path.
    base_weights = [None,
                    None,
                    None]

    pre_trained_weights = ["../input/siimcovidpytorchdataset/models/densenet169_320_lr3e5_bs32_8epo_best_fold0.pth",
                           "../input/siimcovidpytorchdataset/models/inception_v4_320_lr3e5_bs32_8epo_best_fold1.pth",
                           "../input/siimcovidptmodels/tf_efficientnet_b5_ns_224_lr3e4_bs32_10epo_best_fold1.pth"]

    #pre_trained_weights = [None,None,None]

    if mode == "binary":
        num_classes = 2

    # Other

    num_workers = 2
    warmup_epo = 2
    init_lr = 3e-5
    cosine_epo = 3 if not debug else 1
    n_epochs = warmup_epo + cosine_epo
    loss_weights = [1., 9.]
    log_dir = './logs'
    model_dir = './models'
    
# Create output folders and derive one log-file name from all kernel types.
for _out_dir in (cfg.log_dir, cfg.model_dir):
    os.makedirs(_out_dir, exist_ok=True)
kt_full = "".join(f"_{kernel}" for kernel in cfg.kernel_type)
log_file = os.path.join(cfg.log_dir, f'log{kt_full}.txt')

In [None]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False  # set True to be faster
    print(f'Setting all seeds to be {seed} to reproduce...')
seed_everything(cfg.seed)

In [None]:
def over_sample(df_train, labels, anchor):
    """Oversample minority one-hot label columns by duplicating their rows.

    For each column in ``labels`` (in order), the rows where that column is 1
    are appended ``n_anchor // n_label`` extra times.  ``n_anchor`` is the
    number of rows whose ``anchor`` column is 1 in the frame built so far.

    Args:
        df_train: DataFrame with one-hot label columns; it is not modified.
        labels: column names to oversample.
        anchor: column whose positive count sets the target size.

    Returns:
        A new DataFrame.  When any rows were appended it has a fresh RangeIndex.
        Labels with no positive rows are skipped.
    """
    df = df_train.copy()
    for label in labels:
        df_lbl = df[df[label] == 1]
        if df_lbl.empty:
            # Nothing to duplicate (this previously raised ZeroDivisionError).
            continue
        n_copies = len(df[df[anchor] == 1]) // len(df_lbl)
        if n_copies:
            # DataFrame.append was removed in pandas 2.0; a single concat is
            # equivalent to appending n_copies times with ignore_index=True.
            df = pd.concat([df] + [df_lbl] * n_copies, ignore_index=True)

    return df

# Load the study labels and attach the on-disk path of each 512px JPG.
df_train = pd.read_csv("../input/siimcovid19-512-jpg-image-dataset/train.csv")
df_train["image_path"] = "../input/siimcovid19-512-jpg-image-dataset/train/" + df_train["image_id"] + ".jpg"

# NOTE(review): oversampling runs *before* the CV split below, so duplicated
# rows of the same image can land in different folds (train/valid leakage).
if cfg.train and cfg.over_sample:
    df_train = over_sample(df_train, cls[1:], cls[0])

# Integer class id = index of the active one-hot column in `cls`.
df_train["class_name"] = np.argmax(df_train[cls].values, axis=1)
print(df_train.class_name.value_counts())

y = df_train[cls].values
X = df_train['image_id'].values

# 'fold' is added last, so column position -1 below refers to it.
df_train['fold'] = np.nan

from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
mskf = MultilabelStratifiedKFold(n_splits=cfg.nfold, random_state=cfg.seed, shuffle=True)
# Each split's held-out indices get that split's number as their fold id.
for i, (_, test_index) in enumerate(mskf.split(X, y)):
    df_train.iloc[test_index, -1] = i
    
df_train['fold'] = df_train['fold'].astype('int')
df_train.head()

In [None]:
if cfg.mode == "binary":
    # Collapse the positive classes into one "others" column:
    # 0 for "Negative for Pneumonia" rows, 1 for every other row.
    df_train["others"] = 1
    # .loc instead of chained indexing: `df[col][mask] = v` does not write back
    # under pandas copy-on-write and raises SettingWithCopyWarning otherwise.
    df_train.loc[df_train["Negative for Pneumonia"] == 1, "others"] = 0
    if not cfg.analyze:
        cls = ["Negative for Pneumonia", "others"]

In [None]:
if cfg.debug:
    # Explicit copy so later column assignments (split_data) act on an
    # independent frame rather than a slice of the full data.
    df_train = df_train.iloc[0:10].copy()
    
def split_data(vld_fold):
    """Flag the validation rows of the global ``df_train`` in place.

    Sets the boolean column ``is_valid`` (True where ``fold == vld_fold``).
    Also casts the current ``cls`` label columns to int.

    Args:
        vld_fold: fold id to use as the validation split.
    """
    # Direct boolean assignment replaces the chained `df[col][mask] = True`,
    # which silently does nothing under pandas copy-on-write.
    df_train["is_valid"] = df_train["fold"] == vld_fold
    df_train[cls] = df_train[cls].astype(int)

In [None]:
def randAugment(N, M, dim=256, p=0.5):
    """Build a RandAugment-style albumentations pipeline.

    Args:
        N: number of ops drawn from the 12 candidates below using
            ``np.random.choice`` (default: with replacement).
        M: magnitude index into the 10-entry tables below; expected in 0..9.
        dim: side length for the final ``Resize``; also scales Cutout holes.
        p: probability passed to every candidate op.

    Returns:
        ``albumentations.Compose`` of the N sampled ops followed by Resize and Normalize.

    NOTE(review): IAAAffine / IAASharpen / RandomContrast / RandomBrightness /
    Cutout are deprecated or removed in newer albumentations releases -- confirm
    the pinned version still provides them.
    """

    # Magnitude tables: entry M of each gives the strength for that op.
    shift_x = np.linspace(0,150,10)
    shift_y = np.linspace(0,150,10)
    rot = np.linspace(0,30,10)
    shear = np.linspace(0,10,10)
    sola = np.linspace(0,256,10)
    post = [4,4,5,5,6,6,7,7,8,8]
    cont = [np.linspace(-0.8,-0.1,10),np.linspace(0.1,2,10)]
    bright = np.linspace(0.1,0.7,10)
    shar = np.linspace(0.1,0.9,10)
    cut = np.linspace(0.05,0.1,10)

    # Candidate ops, each parameterised by magnitude index M.
    Aug = [albumentations.ShiftScaleRotate(shift_limit_x=shift_x[M], rotate_limit=0, shift_limit_y=0, shift_limit=shift_x[M], p=p),
        albumentations.ShiftScaleRotate(shift_limit_y=shift_y[M], rotate_limit=0, shift_limit_x=0, shift_limit=shift_y[M], p=p),
        albumentations.IAAAffine(rotate=rot[M], p=p),
        albumentations.IAAAffine(shear=shear[M], p=p),
        albumentations.InvertImg(p=p),
        albumentations.Equalize(p=p),
        albumentations.Solarize(threshold=sola[M], p=p),
        albumentations.Posterize(num_bits=post[M], p=p),
        albumentations.RandomContrast(limit=[cont[0][M], cont[1][M]], p=p),
        albumentations.RandomBrightness(limit=bright[M], p=p),
        albumentations.IAASharpen(alpha=shar[M], lightness=shar[M], p=p),
        albumentations.Cutout(num_holes=8, max_h_size=int(cut[M]*dim), max_w_size=int(cut[M]*dim), p=p)]

    # Sampling consumes NumPy's global RNG, so the chosen ops follow np.random.seed.
    ops = np.random.choice(Aug, N)
    ops = np.append(ops, [albumentations.Resize(dim, dim), albumentations.Normalize()])
    # Local name shadows the torchvision `transforms` module inside this function only.
    transforms = albumentations.Compose(ops)
    
    return transforms


def get_transforms(image_size, aggresive=True, rand=False, M=3, N=2, P=0.7):
    """Return ``(train_transform, valid_transform)`` albumentations pipelines.

    Args:
        image_size: side length that every pipeline resizes to.
        aggresive: if True (the default), the heavy training pipeline is returned.
            It applies transpose, flips, brightness/contrast, blur/noise,
            distortions, CLAHE, HSV, shift-scale-rotate and 10 cutout holes.
            Otherwise the light pipeline (flips and 3 cutout holes) is returned.
        rand: if True, ignore ``aggresive`` and use ``randAugment(N, M, image_size, P)``.
        M, N, P: RandAugment magnitude index, op count and per-op probability.

    The validation pipeline is always Resize + Normalize.  All pipelines end in
    ``Normalize``, so outputs are float arrays rather than uint8 images.
    """
    
    # Note: both training pipelines are built on every call even though at
    # most one is returned.
    transforms_train_aggresive = albumentations.Compose([
        albumentations.Transpose(p=0.3),
        albumentations.VerticalFlip(p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.RandomBrightness(limit=0.2, p=0.75),
        albumentations.RandomContrast(limit=0.2, p=0.75),
        
        albumentations.OneOf([
            albumentations.MotionBlur(blur_limit=5),
            albumentations.MedianBlur(blur_limit=5),
            albumentations.GaussianBlur(blur_limit=5),
            albumentations.GaussNoise(var_limit=(5.0, 30.0)),
        ], p=0.5),

        albumentations.OneOf([
            albumentations.OpticalDistortion(distort_limit=1.0),
            albumentations.GridDistortion(num_steps=5, distort_limit=1.),
            albumentations.ElasticTransform(alpha=3),
        ], p=0.5),

        albumentations.CLAHE(clip_limit=4.0, p=0.7),
        albumentations.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),    
        albumentations.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
        albumentations.Resize(image_size, image_size),
        # Cutout after Resize so hole sizes are relative to the final resolution.
        albumentations.Cutout(max_h_size=int(image_size * 0.12), max_w_size=int(image_size * 0.12), num_holes=10, p=0.7),
        albumentations.Normalize()
    ])
    
    
    transforms_train = albumentations.Compose([
        albumentations.VerticalFlip(p=0.5),
        albumentations.HorizontalFlip(p=0.5),
#         albumentations.RandomBrightness(limit=0.2, p=0.65),
#         albumentations.RandomContrast(limit=0.2, p=0.65),
#         albumentations.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.65),
        albumentations.Resize(image_size, image_size),
        albumentations.Cutout(max_h_size=int(image_size * 0.12), max_w_size=int(image_size * 0.12), num_holes=3, p=0.5),
        albumentations.Normalize()
    ])
    
    transforms_val = albumentations.Compose([
        albumentations.Resize(image_size, image_size),
        albumentations.Normalize()
    ])
    
    if rand:
        return randAugment(N, M, image_size, P), transforms_val
    
    return transforms_train_aggresive if aggresive else transforms_train, transforms_val

In [None]:
class SIIMDataset(Dataset):
    """Dataset over the rows of a DataFrame with an ``image_path`` column.

    Args:
        df: rows to serve.  For non-"test" subsets it must also contain the
            label columns named by the notebook-global ``cls``.
        transforms: optional albumentations pipeline, called as ``transforms(image=img)``.
        subset: "test" returns only the image tensor.  Any other value returns
            ``(image, float label vector)``.
    """
    
    def __init__(self, df, transforms=None, subset="train"):
        
        super().__init__()
        self.df = df
        self.transforms = transforms
        self.subset = subset
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        """Load, transform and tensorise row ``idx`` (positional index)."""
        row = self.df.iloc[idx]
        img = cv2.imread(row.image_path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None;
            # fail with the path instead of a cryptic cvtColor error.
            raise FileNotFoundError(f"Could not read image: {row.image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]
            
        # HWC array -> CHW tensor.
        img = FV.to_tensor(img) 
        
        if self.subset == "test":
            return img

        label = torch.as_tensor(row[cls].astype('int').values)
        return img, label.float()

In [None]:
# Build the 320px pipelines with run 0's augmentation setting and display the
# training pipeline's repr for inspection.
transforms_train, transforms_valid = get_transforms(320, cfg.aggresive_aug[0], rand=cfg.rand_aug)
transforms_train

In [None]:
# Mark fold 0 as validation and wrap the full frame with training augmentations
# (used only for the visual check below).
split_data(0)
ds = SIIMDataset(df=df_train, transforms=transforms_train, subset="train")

In [None]:
# Show the first 10 augmented training images as two rows of five.
rcParams['figure.figsize'] = 15, 5
for row in range(2):
    _, axes = plt.subplots(1, 5)
    for col, ax in enumerate(axes):
        image, _label = ds[row * 5 + col]
        # CHW tensor -> HWC for matplotlib.
        ax.imshow(image.permute(1, 2, 0).squeeze())


In [None]:
def get_activation(activ_name: str="relu"):
    """Return a freshly constructed activation module by name.

    Args:
        activ_name: one of "relu" (in-place), "tanh", "sigmoid", "identity".

    Raises:
        NotImplementedError: for any other name.
    """
    factories = {
        "relu": lambda: nn.ReLU(inplace=True),
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "identity": nn.Identity,
    }
    if activ_name not in factories:
        raise NotImplementedError
    return factories[activ_name]()
        
class Conv2dBNActiv(nn.Module):
    """Conv2d -> (BN ->) -> Activation"""

    def __init__(
        self, in_channels: int, out_channels: int,
        kernel_size: int, stride: int=1, padding: int=0,
        bias: bool=False, use_bn: bool=True, activ: str="relu"
    ):
        """Stack a convolution, an optional BatchNorm2d and a named activation.

        Args:
            in_channels, out_channels, kernel_size, stride, padding, bias:
                forwarded to ``nn.Conv2d``.
            use_bn: insert ``nn.BatchNorm2d(out_channels)`` after the conv.
            activ: activation name understood by ``get_activation``.
        """
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        norm = [nn.BatchNorm2d(out_channels)] if use_bn else []
        # Attribute name `layers` fixes the state_dict keys (layers.0.*, ...).
        self.layers = nn.Sequential(conv, *norm, get_activation(activ))

    def forward(self, x):
        """Apply the stacked layers."""
        return self.layers(x)
    
class SSEBlock(nn.Module):
    """channel `S`queeze and `s`patial `E`xcitation Block."""

    def __init__(self, in_channels: int):
        """Create a bias-free 1x1 conv that projects all channels to one gate map."""
        super().__init__()
        self.channel_squeeze = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1, padding=0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Scale every spatial location of ``x`` by its sigmoid gate.

        x: (bs, ch, h, w) -> gate: (bs, 1, h, w), broadcast back over channels.
        """
        gate = self.sigmoid(self.channel_squeeze(x))
        return x * gate
    
    
class SpatialAttentionBlock(nn.Module):
    """Spatial Attention for (C, H, W) feature maps"""
    
    def __init__(
        self, in_channels: int,
        out_channels_list,
    ):
        """Build a chain of 3x3 Conv-BN-Activ layers ending in a 1-channel sigmoid map.

        Args:
            in_channels: channels of the incoming feature map.
            out_channels_list: list of output channels per layer.  It must be
                non-empty and end with 1.  Hidden layers use ReLU; the last uses sigmoid.
        """
        super().__init__()
        self.n_layers = len(out_channels_list)
        channels = [in_channels] + out_channels_list
        assert self.n_layers > 0
        assert channels[-1] == 1

        pairs = zip(channels[:-1], channels[1:])
        for idx, (c_in, c_out) in enumerate(pairs, start=1):
            activ = "sigmoid" if idx == self.n_layers else "relu"
            # Attributes conv1..convN keep the state_dict key layout.
            setattr(self, f"conv{idx}", Conv2dBNActiv(c_in, c_out, 3, 1, 1, activ=activ))

    def forward(self, x):
        """Compute the (bs, 1, H, W) attention map and multiply it into ``x``."""
        attention = x
        for idx in range(1, self.n_layers + 1):
            attention = getattr(self, f"conv{idx}")(attention)
        return attention * x
    
class SingleHeadModel(nn.Module):
    """timm backbone followed by a two-layer MLP classification head."""
    
    def __init__(
        self, base_name: str='resnext50_32x4d', out_dim: int=11, pretrained=False
    ):
        """Build the backbone and the head.

        Args:
            base_name: timm model name.
            out_dim: size of the final linear layer.
            pretrained: passed to ``timm.create_model``.
        """
        # nn.Module.__init__ must run before attributes are assigned; assigning
        # a Module/Parameter earlier raises, and a str only works by accident.
        super(SingleHeadModel, self).__init__()
        self.base_name = base_name
        
        # # load base model
        base_model = timm.create_model(base_name, pretrained=pretrained)
        in_features = base_model.num_features
        
        # Drop only the classifier.  With its default global_pool argument,
        # reset_classifier presumably keeps timm's pooling, so the backbone
        # should return (bs, in_features) -- TODO confirm for new backbones.
        # (reset_classifier(0, '') would also remove pooling.)
        base_model.reset_classifier(0)
        
        # # Shared CNN Backbone
        self.backbone = base_model
        
        # # Single Head.
        self.head_fc = nn.Sequential(
            nn.Linear(in_features, in_features),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(in_features, out_dim))

    def forward(self, x):
        """Return head logits for a batch of images."""
        h = self.backbone(x)
        h = self.head_fc(h)
        return h
    
class MultiHeadModel(nn.Module):
    """Shared backbone with several spatial-attention classification heads.

    Head outputs are concatenated along dim 1, so the output width is
    ``sum(out_dims_head)``.

    Args:
        base_name: either a timm model name (str) or an already-built
            ``Model``-like module whose ``myfc``/``dropout`` get replaced by
            identities.
        out_dims_head: output size of each head.
        weights: optional state_dict path, loaded into ``base_name`` when it is
            a module.
    """
    
    def __init__(self, base_name, out_dims_head, weights=None):
        
        super(MultiHeadModel, self).__init__()
        self.base_name = base_name
        self.weights = weights
        self.n_heads = len(out_dims_head)
        
        # # load base model
        if not isinstance(self.base_name, str):
            base_model = self.base_name

            if self.weights is not None:
                # Load to CPU (module not moved yet) so GPU-saved checkpoints
                # also load on CPU-only machines.
                base_model.load_state_dict(torch.load(self.weights, map_location="cpu"))

            in_features = base_model.myfc.in_features
            # NOTE(review): Model.forward pools and flattens features for
            # non-EfficientNet backbones, but each head starts with Conv2d
            # layers that need 4-D input -- confirm this path for the backbone used.
            base_model.base.global_pool.flatten = False
            base_model.dropout = nn.Identity()
            base_model.myfc = nn.Identity()
        else:
            base_model = timm.create_model(base_name, pretrained=cfg.pretrained)
            in_features = base_model.num_features
            # Remove both classifier and pooling: heads consume spatial maps.
            base_model.reset_classifier(0, '')
        
        # # Shared CNN Backbone
        self.backbone = base_model
        
        for i, out_dim in enumerate(out_dims_head):
            layer_name = f"head_{i}"
            layer = nn.Sequential(
                SpatialAttentionBlock(in_features, [64, 32, 16, 1]),
                nn.AdaptiveAvgPool2d(output_size=1),
                nn.Flatten(start_dim=1),
                nn.Linear(in_features, in_features),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(in_features, out_dim))
            setattr(self, layer_name, layer)

    def forward(self, x):
        """Run every head on the shared features and concatenate their logits."""
        h = self.backbone(x)
        hs = [getattr(self, f"head_{i}")(h) for i in range(self.n_heads)]
        return torch.cat(hs, dim=1)


class Model(nn.Module):
    """timm backbone whose classifier is replaced by one linear layer ``myfc``.

    Args:
        backbone: timm model name.  Substrings select which head attribute is
            stripped, checked in order: "res", "vit", "nfnet", "ception",
            "efficient", otherwise ``classifier``.
        num_classes: output size of ``myfc``.
        dropout: dropout rate.  NOTE: only applied on the EfficientNet path in ``forward``.
        pretrained: forwarded to ``timm.create_model``.

    The notebook-global ``cfg.freeze`` decides whether backbone parameters are trainable.
    """
    
    def __init__(self, backbone, num_classes, dropout=0.5, pretrained=True):
        super(Model, self).__init__()
        
        self.backbone = backbone
        self.num_classes = num_classes
        self.base = timm.create_model(self.backbone, pretrained=pretrained)
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout)
        
        # Substring tests are order-sensitive: a name containing both "res" and
        # "ception" takes the "res" branch.  Assumes timm's head attribute names
        # per family (fc / head / head.fc / last_linear / classifier) -- TODO
        # confirm when adding new backbones.
        # Non-EfficientNet branches also replace global_pool with Identity, so
        # self.base presumably returns unpooled maps and forward pools them.
        if "res" in self.backbone:
            self.myfc = nn.Linear(self.base.fc.in_features, self.num_classes)
            self.base.global_pool = nn.Identity()
            self.base.fc = nn.Identity()
            
        elif "vit" in self.backbone:
            self.myfc = nn.Linear(self.base.head.in_features, self.num_classes)
            self.base.global_pool = nn.Identity()
            self.base.head = nn.Identity()
            
        elif "nfnet" in self.backbone:
            self.myfc = nn.Linear(self.base.head.fc.in_features, self.num_classes)
            self.base.global_pool = nn.Identity()
            self.base.head.fc = nn.Identity()
            
        elif "ception" in self.backbone:
            self.myfc = nn.Linear(self.base.last_linear.in_features, self.num_classes)
            self.base.global_pool = nn.Identity()
            self.base.last_linear = nn.Identity()
        
        elif "efficient" in self.backbone:
            # Keeps timm's own pooling: forward feeds self.base output straight to myfc.
            self.myfc = nn.Linear(self.base.classifier.in_features, self.num_classes)
            self.base.classifier = nn.Identity()
            
        else:
            self.myfc = nn.Linear(self.base.classifier.in_features, self.num_classes)
            self.base.global_pool = nn.Identity()
            self.base.classifier = nn.Identity()
            
        # Freeze or unfreeze every backbone parameter (myfc stays trainable).
        if cfg.freeze:
            for param in self.base.parameters():
                param.requires_grad = False
        else:
            for param in self.base.parameters():
                param.requires_grad = True
        

    def extract(self, x):
            # Raw backbone output (pooled only on the EfficientNet path).
            return self.base(x)

    def forward(self, x):
        # Returns (bs, num_classes) logits.

        if "efficient" in self.backbone:
            x = self.extract(x)
            h = self.myfc(self.dropout(x))
            return h
        else:
            # No dropout on this path; features are average-pooled here.
            bs = x.size(0)
            features = self.base(x)
            pooled_features = self.pooling(features).view(bs, -1)
            output = self.myfc(pooled_features)
            return output
    
class SiameseModel(nn.Module):
    """Embedding network mapping an image to ``cfg.siamese_nfeatures`` features.

    Args:
        model: a timm model name (str) when ``weights`` is None.  Otherwise an
            already-built ``Model``-like module exposing ``myfc`` and ``dropout``.
        pretrained: forwarded to ``timm.create_model`` (name path only).
        weights: optional state_dict path loaded into ``model`` before its
            classifier is replaced.
        dropout: dropout applied to backbone features before the projection.
    """
    
    def __init__(self, model, pretrained=True, weights=None, dropout=0.5):
        super(SiameseModel, self).__init__()
        
        self.model = model
        self.weights = weights
        self.dropout = nn.Dropout(dropout)
        
        if self.weights is None:
            # `model` is an architecture name here.  The branch tests previously
            # read self.backbone, which was never set (AttributeError).
            arch = self.model
            self.base = timm.create_model(arch, pretrained=pretrained)
            
            if "res" in arch:
                self.myfc = nn.Linear(self.base.fc.in_features, cfg.siamese_nfeatures)
                self.base.fc = nn.Identity()

            elif "vit" in arch:
                self.myfc = nn.Linear(self.base.head.in_features, cfg.siamese_nfeatures)
                self.base.head = nn.Identity()

            elif "nfnet" in arch:
                self.myfc = nn.Linear(self.base.head.fc.in_features, cfg.siamese_nfeatures)
                self.base.head.fc = nn.Identity()

            else:
                self.myfc = nn.Linear(self.base.classifier.in_features, cfg.siamese_nfeatures)
                self.base.classifier = nn.Identity()
                
        else:
            self.base = self.model
            # Load to CPU: the module has not been moved to a device yet, and
            # this avoids depending on a notebook-global `device`.
            self.base.load_state_dict(torch.load(self.weights, map_location="cpu"))
            self.base.dropout = nn.Identity()
            self.myfc = nn.Linear(self.base.myfc.in_features, cfg.siamese_nfeatures)
            self.base.myfc = nn.Identity()
            
    def extract(self, x):
        """Backbone features (classifier removed)."""
        return self.base(x)

    def forward(self, x):
        """Return the embedding for a batch of images."""
        x = self.extract(x)
        e = self.myfc(self.dropout(x))
        return e

In [None]:
# Smoke test: run the third backbone (pretrained=True by default) on two
# augmented training images on CPU and display the output shape.
# Note: this rebinds the global `y` defined with the fold labels above.
m = Model(backbone=cfg.backbone[2], num_classes=cfg.num_classes)
x = torch.stack([ds[i][0] for i in range(2)])
y = m(x)
y.shape

In [None]:
class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """Linear LR warm-up whose follow-up scheduler starts from the peak LR.

    For the first ``total_epoch`` epochs the LR ramps linearly.  It goes from
    ``base_lr`` to ``base_lr * multiplier``, or from 0 to ``base_lr`` when
    ``multiplier == 1``.  Afterwards control passes to ``after_scheduler`` (if
    any), whose ``base_lrs`` are rebased once to the peak value.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        """Return the per-param-group learning rates for ``self.last_epoch``."""
        if self.last_epoch <= self.total_epoch:
            if self.multiplier == 1.0:
                return [lr * (float(self.last_epoch) / self.total_epoch) for lr in self.base_lrs]
            return [lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                    for lr in self.base_lrs]

        if not self.after_scheduler:
            return [lr * self.multiplier for lr in self.base_lrs]
        if not self.finished:
            # One-time hand-over: the follow-up scheduler restarts from the peak LR.
            self.after_scheduler.base_lrs = [lr * self.multiplier for lr in self.base_lrs]
            self.finished = True
        return self.after_scheduler.get_lr()

# Preview the LR schedule on the smoke-test model: linear warm-up over
# cfg.warmup_epo epochs up to 10 * init_lr, then CosineAnnealingLR.
# (run() below uses CosineAnnealingWarmRestarts instead.)
optimizer = optim.Adam(m.parameters(), lr=cfg.init_lr)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.cosine_epo)
scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=10, 
                                            total_epoch=cfg.warmup_epo, 
                                            after_scheduler=scheduler_cosine)
lrs = []
for epoch in range(1, cfg.n_epochs+1):
    # Same call pattern as run(): step with the 0-based epoch before each epoch.
    scheduler_warmup.step(epoch-1)
    lrs.append(optimizer.param_groups[0]["lr"])
rcParams['figure.figsize'] = 20,3
plt.plot(lrs)


In [None]:
def train_epoch(model, loader, optimizer, criterion):
    """Run one mixed-precision training epoch.

    Uses the notebook-global ``device`` and ``scaler`` (GradScaler).

    Args:
        model: network to train (switched to train mode).
        loader: yields ``(data, targets)`` batches.
        optimizer: optimizer stepped through the GradScaler.
        criterion: loss taking ``(logits, targets)``.

    Returns:
        Mean of the per-batch training losses.
    """
    model.train()
    losses = []
    progress = tqdm(loader)
    for data, targets in progress:
        optimizer.zero_grad()
        data = data.to(device)
        targets = targets.to(device)

        with amp.autocast():
            loss = criterion(model(data), targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        current = loss.item()
        losses.append(current)
        # Running mean over (up to) the last 50 batches.
        window = losses[-50:]
        progress.set_description('loss: %.4f, smth: %.4f' % (current, sum(window) / len(window)))

    return np.mean(losses)


def valid_epoch(model, loader, get_output=False, criterion=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Uses the notebook-global ``device``.

    Args:
        model: network returning one logit per class.
        loader: yields ``(data, targets)`` with one-hot / multi-hot targets.
        get_output: if True, return the list of per-batch CPU logit tensors
            instead of metrics.
        criterion: loss for ``val_loss``.  Defaults to the notebook-global
            ``criterion``, which was the previous implicit behaviour.

    Returns:
        If ``get_output`` is True, ``LOGITS`` (a list of tensors).  Otherwise
        ``(val_loss, acc, auc)``, where:

        - ``acc`` is the subset accuracy of the arg-max one-hot predictions.
        - ``auc`` is sklearn's macro ROC-AUC over the softmax probabilities,
          or 0.0 when undefined (e.g. a class missing from the split).
    """
    if criterion is None:
        criterion = globals()["criterion"]

    model.eval()
    softmax = nn.Softmax(dim=1)
    val_loss = []
    LOGITS = []
    PROBS = []
    TARGETS = []

    with torch.no_grad():
        for data, targets in tqdm(loader):
            data, targets = data.to(device), targets.to(device)
            logits = model(data)
            val_loss.append(criterion(logits, targets).item())
            logits_cpu = logits.cpu()
            LOGITS.append(logits_cpu)
            PROBS.append(softmax(logits_cpu.float()).numpy())
            TARGETS.append(targets.cpu().numpy().astype(int))

    if get_output:
        return LOGITS

    val_loss = np.mean(val_loss)
    # Concatenate the lists directly: wrapping ragged batches (smaller last
    # batch) in np.array() raises in NumPy >= 1.24.
    probs = np.concatenate(PROBS)
    targets_all = np.concatenate(TARGETS)
    # One-hot arg-max predictions, sized from the model output itself.
    preds = np.eye(probs.shape[1], dtype=int)[probs.argmax(axis=1)]

    # Score every sample (slicing with [:-1] silently dropped the last one).
    acc = accuracy_score(targets_all, preds)
    try:
        # ROC-AUC needs continuous scores; hard one-hot predictions collapse the curve.
        auc = roc_auc_score(targets_all, probs)
    except ValueError:
        # sklearn raises ValueError when a class has only one label value here.
        auc = 0.0

    return val_loss, acc, auc
    


In [None]:
def _log(content):
    """Print ``content`` and append it as a line to the global ``log_file``."""
    print(content)
    with open(log_file, 'a') as appender:
        appender.write(content + '\n')


def _build_fold_loaders(fold):
    """Split ``df_train`` for run ``fold`` and return ``(train_loader, valid_loader)``."""
    split_data(cfg.valid_fold[fold])

    train_ = df_train[~df_train["is_valid"]].copy()
    valid_ = df_train[df_train["is_valid"]].copy()

    transforms_train, transforms_valid = get_transforms(cfg.dim[fold], cfg.aggresive_aug[fold], rand=cfg.rand_aug)

    dataset_train = SIIMDataset(train_, subset='train', transforms=transforms_train)
    dataset_valid = SIIMDataset(valid_, subset='valid', transforms=transforms_valid)

    train_loader = DataLoader(dataset_train,
                              batch_size=cfg.batch_size[fold],
                              shuffle=True,
                              num_workers=cfg.num_workers)
    valid_loader = DataLoader(dataset_valid,
                              batch_size=cfg.batch_size[fold],
                              shuffle=False,
                              num_workers=cfg.num_workers)
    return train_loader, valid_loader


def _build_fold_model(fold):
    """Instantiate run ``fold``'s model, load optional weights and move it to ``device``."""
    if cfg.multi_head:
        if cfg.base_weights[fold] is not None:
            base_model = Model(backbone=cfg.backbone[fold], num_classes=cfg.num_classes, pretrained=cfg.pretrained)
            model = MultiHeadModel(base_model, cfg.heads, cfg.base_weights[fold])
        else:
            model = MultiHeadModel(cfg.backbone[fold], cfg.heads)
    else:
        model = Model(backbone=cfg.backbone[fold], num_classes=cfg.num_classes, pretrained=cfg.pretrained)

    if cfg.pre_trained_weights[fold] is not None:
        print("Loading Pretrained Weights.........")
        # Load to CPU first; the model is moved to `device` just below.
        model.load_state_dict(torch.load(cfg.pre_trained_weights[fold], map_location="cpu"))

    return model.to(device)


def run(fold):
    """Train one model for run index ``fold`` and save its checkpoints.

    Reads the notebook globals ``cfg``, ``df_train``, ``log_file``, ``device``
    and ``criterion``.  The best-validation-loss weights are saved to
    ``{kernel_type}_best_fold{fold}.pth``.  The final weights are saved to
    ``{kernel_type}.pth``, or ``{kernel_type}_multihead.pth`` when
    ``cfg.multi_head`` is set.
    """
    _log('Fold: ' + str(fold))

    train_loader, valid_loader = _build_fold_loaders(fold)
    model = _build_fold_model(fold)

    loss_min = 10000
    model_file = os.path.join(cfg.model_dir, f'{cfg.kernel_type[fold]}_best_fold{fold}.pth')

    optimizer = optim.Adam(model.parameters(), lr=cfg.init_lr)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, cfg.cosine_epo)
    scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=cfg.warmup_epo,
                                                after_scheduler=scheduler_cosine)

    for epoch in range(1, cfg.n_epochs+1):
        print(time.ctime(), 'Epoch:', epoch)
        # The scheduler is stepped with the 0-based epoch *before* training it.
        scheduler_warmup.step(epoch-1)

        gc.collect()
        train_loss = train_epoch(model, train_loader, optimizer, criterion)

        gc.collect()
        val_loss, acc, auc = valid_epoch(model, valid_loader)

        # Implicit concatenation: the former backslash continuation embedded
        # the next line's indentation (~20 spaces) inside the log message.
        content = (
            f'{time.ctime()} Fold {fold} Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, '
            f'train loss: {train_loss:.4f}, valid loss: {val_loss:.4f}, acc: {acc:.4f} auc: {auc:.4f}.'
        )
        _log(content)

        if loss_min > val_loss:
            print('Val Loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(loss_min, val_loss))
            torch.save(model.state_dict(), model_file)
            loss_min = val_loss

    suffix = '_multihead' if cfg.multi_head else ''
    torch.save(model.state_dict(), os.path.join(cfg.model_dir, f'{cfg.kernel_type[fold]}{suffix}.pth'))

    # Drop every reference to the model/optimizer so the cache can actually be
    # released before the next run starts.
    del model, optimizer, scheduler_warmup, scheduler_cosine
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
class FocalCosineLoss(nn.Module):
    """Cosine-embedding loss plus a weighted focal term built on BCE-with-logits.

    Args:
        alpha: focal scaling factor.
        gamma: focal exponent.
        xent: weight of the focal term in the total.
        reduction: passed to ``F.cosine_embedding_loss``.

    NOTE(review): the focal term is computed from the already mean-reduced BCE,
    so it is a single scalar rather than a per-element weighting.  With
    ``reduction="none"`` it is broadcast onto the per-sample cosine loss.
    """
    
    def __init__(self, alpha = 1, gamma = 2, xent = 0.1, reduction = "mean"):
        super(FocalCosineLoss, self).__init__()
        self.alpha     = alpha
        self.gamma     = gamma
        self.xent      = xent
        self.reduction = reduction
        # Similarity target (+1), broadcast over the batch.  It is created
        # without a device and moved to the input's device in forward, so
        # construction no longer needs a notebook-global `device`.
        self.y         = torch.Tensor([1])
        
    def forward(self, input, target):
        """Return ``cosine_loss + xent * focal_loss`` for logits ``input`` and float ``target``."""
        y = self.y.to(input.device)
        cosine_loss = F.cosine_embedding_loss(input, target, y, reduction = self.reduction)
        cent_loss   = nn.BCEWithLogitsLoss()(input, target)
        pt          = torch.exp(-cent_loss)
        focal_loss  = self.alpha * (1-pt)**self.gamma * cent_loss

        if self.reduction == "mean":
            focal_loss = torch.mean(focal_loss)
        
        return cosine_loss + self.xent * focal_loss
    
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    In this implementation ``label == 1`` pulls a pair together (penalises
    ``dist**2``).  ``label == 0`` pushes it apart until ``dist >= margin``.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, dist, label):
        """Mean of the pull and hinge-push terms over all pairs."""
        pull = 1/2 * label * torch.pow(dist, 2)
        hinge = torch.clamp(self.margin - dist, min=0.0)
        push = 1/2 * (1 - label) * torch.pow(hinge, 2)
        return torch.mean(pull + push)
    
class TripletLoss(torch.nn.Module):
    """
    Triplet loss function.

    Hinge on squared Euclidean distances, averaged over the batch:
    ``max(0, margin + |a - p|^2 - |a - n|^2)``.
    """

    def __init__(self, margin=2.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Inputs are (batch, features); distances are summed over dim 1."""
        pos_dist = torch.sum((anchor - positive) ** 2, dim=1)
        neg_dist = torch.sum((anchor - negative) ** 2, dim=1)
        return (self.margin + pos_dist - neg_dist).clamp(min=0).mean()

In [None]:
if cfg.train:    
    # `device` and `criterion` are globals read by run(), train_epoch and valid_epoch.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    criterion = nn.BCEWithLogitsLoss()  # multi-label loss on the one-hot class targets
    # One training run per entry of the per-run cfg lists.
    for fold in range(cfg.folds):
        run(fold)

In [None]:
if not cfg.train and not cfg.analyze:
    # Weighted ensemble of every run's model over the full df_train; writes
    # Prediction_DataFrame.csv for the analysis cell.

    WGTS = cfg.WGTS
    # Fail fast, before any GPU work, if there is not one weight per run.
    if len(WGTS) < cfg.folds:
        raise ValueError(
            f"cfg.WGTS has {len(WGTS)} weights but cfg.folds is {cfg.folds}; "
            "provide one ensemble weight per fold."
        )
    PREDS = np.zeros(shape=(len(df_train), cfg.num_classes))
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def test_epoch(model, loader):
        """Return ``(softmax probabilities, targets)`` for every sample of ``loader``."""
        model.eval()
        softmax = nn.Softmax(dim=1)
        probs, targets = [], []
        with torch.no_grad():
            for data, labels in tqdm(loader):
                logits = model(data.to(device))
                probs.append(softmax(logits.cpu()).numpy())
                targets.append(labels.cpu().numpy())
        # Concatenate the lists directly: np.array() on ragged batches
        # (smaller last batch) raises in NumPy >= 1.24.
        return np.concatenate(probs), np.concatenate(targets)

    for fold in range(cfg.folds):
        # Only the validation (resize + normalize) pipeline is used here.
        _, transforms_test = get_transforms(cfg.dim[fold])

        test_dataset = SIIMDataset(df_train, subset='valid', transforms=transforms_test)
        test_loader = DataLoader(test_dataset,
                                 batch_size=cfg.batch_size[fold],
                                 shuffle=False,
                                 num_workers=cfg.num_workers)

        model = Model(backbone=cfg.backbone[fold], num_classes=cfg.num_classes, pretrained=False)

        if cfg.pre_trained_weights[fold] is not None:
            print("Loading Pretrained Weights.........")
            model.load_state_dict(torch.load(cfg.pre_trained_weights[fold], map_location=device))

        model = model.to(device)
        preds, TARGETS = test_epoch(model, test_loader)
        PREDS += preds * WGTS[fold]

        # Release this run's model before building the next one.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    df_preds = pd.DataFrame({
        "image_id": df_train["image_id"].values,
        "Target": np.argmax(TARGETS, axis=1),
        "Prediction": np.argmax(PREDS, axis=1),
        "Pred_Probability": np.max(PREDS, axis=1),
    })
    df_preds.to_csv("Prediction_DataFrame.csv", index=False)

In [None]:
if cfg.analyze:
    # Error analysis of a saved Prediction_DataFrame.csv.
    
    df = pd.read_csv("../input/siimpytorchmodels/Prediction_DataFrame.csv")
    if cfg.mode == "binary":
        df["Org_Target"]=np.argmax(np.array(df_train[cls]), axis=1)
        
    def image_analysis(target, predicted, n_images=5):
        """Plot two rows of example images.

        Row 1: up to ``n_images`` misclassified images whose true class is
        ``target`` (taken from the global ``df_inc``).
        Row 2: up to ``n_images`` images labelled ``predicted`` in ``df_train``.
        """
        id_inc = df_inc[df_inc["Target"] == target]["image_id"] 
        paths = [df_train[df_train["image_id"].isin(id_inc)]["image_path"].values, 
                 df_train[df_train[cls[predicted]] == 1]["image_path"].values]

        rcParams['figure.figsize'] = 15,5
        titles = [f"Target label\n{cls[target]}", f"Predicted label\n{cls[predicted]}"]
        for row_paths, title in zip(paths, titles):
            # squeeze=False keeps a 2-D axes array even when n_images == 1.
            _, axarr = plt.subplots(1, n_images, squeeze=False)
            # Take the first n_images of each row, which avoids an IndexError
            # when fewer exist.  (The old i*5 + p indexing also skipped to
            # images 5-9 on the second row.)
            for ax, path in zip(axarr[0], row_paths[:n_images]):
                # cv2 loads BGR; convert so matplotlib shows true colours.
                img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
                ax.title.set_text(title)
                ax.imshow(img)
            plt.tight_layout(w_pad=1)
    
    def class_ratio(df, lbl="Target", mode="all"):
        """Print, per class, the share of rows of ``df`` whose ``lbl`` equals it.

        With ``mode="all"`` the share is relative to that class's total count
        in ``df_train``, and raw counts are printed too.  Otherwise the share
        is relative to ``len(df)``.
        """
        for c in range(len(cls)):
            s = len(df[df[lbl]==c])
            l = np.sum(df_train[cls[c]])
            if mode == "all":
                print(f"Class {cls[c]}: {s/l*100}")
                print(f"incorrect : {s} | Total : {l}\n")
            else:
                print(f"Class {cls[c]}: {s/len(df)*100}")

    df_inc = df[df["Target"] != df["Prediction"]]
    df_crr = df[df["Target"] == df["Prediction"]]
    class_ratio(df_inc, "Target")
    print("For label 0")
    class_ratio(df_inc[df_inc["Target"] == 0], "Prediction", "spec")
    print("\nFor label 1")
    class_ratio(df_inc[df_inc["Target"] == 1], "Prediction", "spec")
    image_analysis(0,2)
#     image_analysis(3,0)

In [None]:
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# M = Model(backbone=cfg.backbone[0], num_classes=cfg.num_classes)
# model = SiameseModel(model=M, weights=cfg.pre_trained_weights[0], pretrained=False)
# x = torch.stack([ds[i][0] for i in range(2)])
# y = model(x)
# y.shape