# BoTNet | Res-type bottleneck blocks + Attention 🔥

*Bottleneck Transformers for Visual Recognition*: https://arxiv.org/abs/2101.11605

![](https://user-images.githubusercontent.com/22078438/106106482-f04da900-6188-11eb-8f15-820811c2f908.png)


> We present BoTNet, a conceptually simple yet powerful backbone architecture that incorporates self-attention for multiple computer vision tasks including image classification, object detection and instance segmentation. By just replacing the spatial convolutions with global self-attention in the final three bottleneck blocks of a ResNet and no other changes, our approach improves upon the baselines significantly on instance segmentation and object detection while also reducing the parameters, with minimal overhead in latency. Through the design of BoTNet, we also point out how **ResNet bottleneck blocks with self-attention can be viewed as Transformer blocks**. Without any bells and whistles, BoTNet achieves 44.4% Mask AP and 49.7% Box AP on the COCO Instance Segmentation benchmark using the Mask R-CNN framework; surpassing the previous best published single model and single scale results of ResNeSt evaluated on the COCO validation set. Finally, we present a simple adaptation of the BoTNet design for image classification, resulting in models that achieve a strong performance of 84.7% top-1 accuracy on the ImageNet benchmark while being up to 2.33x faster in compute time than the popular EfficientNet models on TPU-v3 hardware. We hope our simple and effective approach will serve as a strong baseline for future research in self-attention models for vision.

<span style="color: #0087e4; font-family: courier; font-size: 1.5em; font-weight: 300;">Install & Import Required Packages</span>

In [None]:
## with internet 
!pip install timm -q
!pip install adamp -q
!pip install bottleneck-transformer-pytorch
!pip install torch-summary -q

# ## w/o internet 
# import os, sys
# sys.path.append('../input/pytorch-image-models/pytorch-image-models-master') # v0.4.7
# sys.path.append('../input/bottleneck-transformers-pytorch')

In [None]:
import os, sys, gc
import cv2
import copy
import time
import random
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
import torchvision
from torchvision import models
from torch.utils.data import DataLoader, Dataset
from torch.cuda import amp

from sklearn import preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, KFold
from sklearn.utils import class_weight

from tqdm.notebook import tqdm
from collections import defaultdict
import albumentations as A
from albumentations.pytorch import ToTensorV2

import timm
from adamp import AdamP
import bottleneck_transformer_pytorch
from bottleneck_transformer_pytorch import BottleStack
from torchsummary import summary

import warnings 
warnings.filterwarnings('ignore')

print('Timm version:', timm.__version__)
print('Torch version:', torch.__version__)

<span style="color: #0087e4; font-family: courier; font-size: 1.5em; font-weight: 300;">Config</span>

In [None]:
class CFG:
    """Experiment configuration, read as plain class attributes by the rest of the notebook."""

    # ---- model / input -------------------------------------------------
    model_name = 'resnet18'    # timm backbone; author's alternatives: resnet50, resnet101, 'seresnext50_32x4d', ...
    img_size = (256, 256)      # (height, width) passed to A.Resize; author's alternatives: (224, 224), (256, 819)
    num_classes = 1            # output width of the final linear head
    target_col = 'target'      # label column in the training dataframe

    # ---- optimisation --------------------------------------------------
    num_epochs = 10            # author's quick-run alternative: 3
    batch_size = 128           # author's notes: 64, or 12 - 16 for 512-sized inputs
    lr = 1e-4
    min_lr = 1e-6              # eta_min of the cosine schedulers
    weight_decay = 1e-6
    smoothing = 0.2            # not read by any code visible in this notebook
    gradient_accumulation_steps = 1
    max_grad_norm = 1000       # threshold handed to clip_grad_norm_

    # ---- learning-rate schedule ---------------------------------------
    scheduler = 'CosineAnnealingLR'   # one of 'ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts'
    T_max = 12                 # CosineAnnealingLR
    T_0 = 12                   # CosineAnnealingWarmRestarts
    factor = 0.2               # ReduceLROnPlateau
    patience = 4               # ReduceLROnPlateau
    eps = 1e-6                 # ReduceLROnPlateau

    # ---- cross-validation ---------------------------------------------
    n_fold = 4
    trn_fold = [0]             # folds that are actually trained; full run would be [0, 1, 2, 3]
    seed = 2020

    # ---- runtime -------------------------------------------------------
    apex = False               # switches train_fn onto its mixed-precision branch
    debug = False              # not read by any code visible in this notebook
    train = True               # not read by any code visible in this notebook
    print_freq = 1500          # log every N steps (author suggests 100-500 to see output more often)
    num_workers = 4            # DataLoader worker processes
    

In [None]:
# ====================================================
# Utils
# ====================================================
class AverageMeter:
    """Track the latest value and the weighted running mean of a metric.

    Attributes:
        val: most recent value passed to ``update``.
        sum: weighted total of all values seen since the last ``reset``.
        count: total weight (e.g. number of samples) seen since the last ``reset``.
        avg: ``sum / count`` after the most recent ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format a duration given in seconds as ``'<minutes>m <seconds>s'``."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Describe elapsed and estimated remaining time.

    Args:
        since: start timestamp as returned by ``time.time()``.
        percent: fraction of the work completed so far; must be non-zero.

    Returns:
        A string ``'<elapsed> (remain <remaining>)'`` with both parts formatted
        by ``asMinutes``.
    """
    elapsed = time.time() - since
    # Linear extrapolation: projected total minus what has already passed.
    remaining = elapsed / (percent) - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))



def init_logger(log_file='train.log'):
    """Create (or re-initialise) the module logger that writes to console and file.

    Messages are emitted at INFO level with a bare ``%(message)s`` format, once
    to a stream handler and once to ``log_file``.

    ``getLogger(__name__)`` always returns the same logger object, so handlers
    from an earlier call are removed and closed first. Without that, re-running
    the notebook cell would attach another pair of handlers and every message
    would be printed and written multiple times.

    Args:
        log_file: path of the log file to write to.

    Returns:
        The configured ``logging.Logger``.
    """
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler

    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # Drop handlers left over from a previous call (also releases the old file handle).
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    for handler in (StreamHandler(), FileHandler(filename=log_file)):
        handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler)
    return logger



# Helpers 
def get_score(y_true, y_pred):
    """Return the ROC AUC of ``y_pred`` against ``y_true`` via ``roc_auc_score``."""
    auc = roc_auc_score(y_true, y_pred)
    return auc

def get_result(result_df):
    """Log the AUC of the ``'preds'`` column against the target column.

    Args:
        result_df: dataframe holding a ``'preds'`` column and the column named
            by ``CFG.target_col``.
    """
    predictions = result_df['preds'].values
    targets = result_df[CFG.target_col].values
    auc = get_score(targets, predictions)
    LOGGER.info(f'AUC Score: {auc:<.4f}')

In [None]:
def set_seed(seed = 42):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    # Exported as a string; the variable is only written here, not read back.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Python / NumPy / PyTorch (CPU and current CUDA device) generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN: force deterministic kernels and disable the auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed all RNGs before anything random happens, then build the shared logger.
set_seed(CFG.seed)
LOGGER = init_logger()
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

<span style="color: #0087e4; font-family: courier ; font-size:1.5em; font-weight: 300;">Load data & Split CV</span>

In [None]:
# Dataset locations (trailing comments keep the paths of a previously used dataset).
# Only ROOT_DIR is read below; the per-file path helpers spell the train/test paths out again.
ROOT_DIR  = '../input/seti-breakthrough-listen'          # "../input/cassava-leaf-disease-classification"
TRAIN_DIR = '../input/seti-breakthrough-listen/train'    # "../input/cassava-leaf-disease-classification/train_images"
TEST_DIR  = '../input/seti-breakthrough-listen/test'     # "../input/cassava-leaf-disease-classification/test_images"

In [None]:
# Label table for training and the submission template used as the test index.
# Both are read with the default integer index.
train = pd.read_csv(f'{ROOT_DIR}/train_labels.csv')
test = pd.read_csv(f'{ROOT_DIR}/sample_submission.csv')

def get_train_file_path(image_id):
    """Return the ``.npy`` path of a training sample.

    Files are grouped in sub-folders named after the first character of the id.
    """
    shard = image_id[0]
    return f"../input/seti-breakthrough-listen/train/{shard}/{image_id}.npy"

def get_test_file_path(image_id):
    """Return the ``.npy`` path of a test sample.

    Files are grouped in sub-folders named after the first character of the id.
    """
    shard = image_id[0]
    return f"../input/seti-breakthrough-listen/test/{shard}/{image_id}.npy"

# Resolve every sample id to the on-disk location of its .npy file.
train['file_path'] = train['id'].apply(get_train_file_path)
test['file_path'] = test['id'].apply(get_test_file_path)

# display(train.head())
# display(test.head())

In [None]:
# Stratified split on the target column so each fold keeps the class ratio.
skf = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)

df_folds = train.copy()
for n, (train_index, val_index) in enumerate(skf.split(train, train[CFG.target_col])):
    # val_index holds positions; using them with .loc relies on `train` having
    # the default 0..N-1 index it gets from read_csv.
    df_folds.loc[val_index, 'fold'] = int(n)
# The column is created by partial assignment, so cast it to int once every row is filled.
df_folds['fold'] = df_folds['fold'].astype(int)

print(f"Split data into {CFG.n_fold} folds\n")
# Class counts per fold, to eyeball the stratification.
display(df_folds.groupby(['fold', 'target']).size())

<span style="color: #0087e4; font-family: courier; font-size: 1.5em; font-weight: 300;">Dataset</span>

In [None]:
# ====================================================
# Dataset
# ====================================================
class SetiDataset(Dataset):
    """Training/validation dataset yielding ``(image, label)`` tensor pairs.

    Each sample is a ``.npy`` cadence file. Rows 0, 2 and 4 of the first axis
    (the "targets - A" observations, per the author's notes) are stacked into a
    single one-channel image.

    Args:
        df: dataframe with a ``'file_path'`` column and the label column named
            by ``CFG.target_col``.
        transform: optional callable following the albumentations convention
            ``transform(image=arr)['image']``, applied to the channel-last array.
    """

    def __init__(self, df, transform=None, ):
        self.df = df
        self.file_names = df['file_path'].values
        self.labels = df[CFG.target_col].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)`` as float32 tensors."""
        image = np.load(self.file_names[idx])
        image = image.astype(np.float32)     # author's note: (6, 273, 256)

        # 1-channel image (only targets - A)
        image = image[[0, 2, 4]]             # shape: (3, 273, 256)
        image = np.vstack(image)             # shape: (819, 256)
        image = image.transpose(1, 0)        # shape: (256, 819)
        image = image.astype("float")[..., np.newaxis]  # shape: (256, 819, 1)

        if self.transform:
            image = self.transform(image=image)['image']

        # Channel-last -> channel-first. This used to sit inside the `if` above,
        # so a dataset built without a transform returned a channel-last tensor.
        # NOTE(review): axes (2, 1, 0) also swap the two spatial axes; kept as-is
        # so the layout seen by the model with a transform is unchanged.
        image = np.ascontiguousarray(image.transpose((2, 1, 0)))

        image = torch.from_numpy(image).float()
        label = torch.tensor(self.labels[idx]).float()
        return image, label
    
    
#     def _read_cadence_array(self, path: Path):
#         """Read cadence file and reshape"""
#         img = np.load(path)[[0, 2, 4]]  # shape: (3, 273, 256)
#         img = np.vstack(img)            # shape: (819, 256)
#         img = img.transpose(1, 0)       # shape: (256, 819)
#         img = img.astype("float")[..., np.newaxis]  # shape: (256, 819, 1)
#         return img    
    
# ## WiP - select channel configuration
# ds = SetiDataset(df_folds.sample(1000), transform=get_transforms(data='train'))
# img, lab = ds[0]
# img.shape, lab.shape

<span style="color: #0087e4; font-family: courier; font-size: 1.5em; font-weight: 300;">Augmentations</span>

In [None]:
# ====================================================
# Transforms
# ====================================================
def get_transforms(*, data):
    """Build the albumentations pipeline for one data split.

    Args:
        data: ``'train'`` for resize plus random flips and shift/scale/rotate,
            ``'valid'`` for resize only.

    Returns:
        An ``A.Compose`` pipeline.

    Raises:
        ValueError: if ``data`` is neither ``'train'`` nor ``'valid'``.
            Previously such a value fell through and returned ``None``, which
            the datasets treat as "no transform" and therefore skipped the
            resize without any warning.
    """
    if data not in ('train', 'valid'):
        raise ValueError(f"data must be 'train' or 'valid', got {data!r}")

    # Both splits are resized to the configured (height, width).
    steps = [A.Resize(CFG.img_size[0], CFG.img_size[1])]

    if data == 'train':
        steps += [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15, rotate_limit=60, p=0.5),
        ]

    return A.Compose(steps)
    
### extra aug
# A.HueSaturationValue(
#                 hue_shift_limit=0.2, 
#                 sat_shift_limit=0.2, 
#                 val_shift_limit=0.2, 
#                 p=0.5 ),
# A.RandomBrightnessContrast(
#                 brightness_limit=(-0.1,0.1), 
#                 contrast_limit=(-0.1, 0.1), 
#                 p=0.5  ),

<span style="color: #0087e4; font-family: courier; font-size: 1.5em; font-weight: 300;">SAM/AdamP Optimizer</span>

*Sharpness-Aware Minimization for Efficiently Improving Generalization* : https://arxiv.org/abs/2010.01412

*AdamP: Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights*: https://arxiv.org/abs/2006.08217

In [None]:
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around another optimizer.

    One update takes two forward/backward passes: ``first_step`` moves the
    weights along the gradient direction by a step of norm ``rho``, and
    ``second_step`` restores the weights and lets the wrapped optimizer step
    using the gradients computed at the perturbed point.

    Args:
        params: parameters (or parameter groups) to optimise.
        base_optimizer: optimizer *class* (not instance); it is instantiated
            here with the parameter groups and ``**kwargs``.
        rho: size of the perturbation; must be non-negative.
        **kwargs: forwarded both into this optimizer's defaults and to
            ``base_optimizer``.
    """
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # Build the wrapped optimizer on our groups, then share its group list so
        # both objects see the same param_groups.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb every parameter with a gradient by ``rho * grad / ||grad||``.

        The applied offset is stored in ``self.state[p]["e_w"]`` so that
        ``second_step`` can undo it.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # Small constant guards against division by a zero gradient norm.
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the perturbation, then run the wrapped optimizer's ``step``."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full SAM update; ``closure`` is mandatory.

        The closure is called once, between the two steps, and must perform a
        full forward-backward pass. Gradients for ``first_step`` must already be
        populated when this method is called.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the 2-norm over all parameter gradients, computed as the norm of per-parameter norms."""
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

<span style="color: #0087e4; font-family: courier; font-size: 1.5em; font-weight: 300;">BottleNeck Transformer</span>

In [None]:
# Helper: Converts the activation function for the entire network

def convert_act_cls(model, layer_type_old, layer_type_new):
    """Replace, in place, every sub-module of exactly type ``layer_type_old``.

    Args:
        model: root ``nn.Module``; it is modified in place and also returned.
        layer_type_old: module class to look for (exact type match, subclasses
            are not replaced).
        layer_type_new: either a module class, instantiated with no arguments
            for every replacement, or a module instance, deep-copied for every
            replacement. Previously the given object was inserted as-is, which
            put the class itself into the model when a class was passed and
            made every slot share one module object when an instance was passed.

    Returns:
        ``model``, for call chaining.
    """
    for name, module in list(model._modules.items()):
        if module is None:
            # Slots explicitly set to None hold nothing to convert.
            continue

        if type(module) is layer_type_old:
            if isinstance(layer_type_new, type):
                replacement = layer_type_new()
            else:
                replacement = copy.deepcopy(layer_type_new)
            model._modules[name] = replacement
        elif next(module.children(), None) is not None:
            # Container: descend; the child is mutated in place.
            convert_act_cls(module, layer_type_old, layer_type_new)
    return model

In [None]:
from bottleneck_transformer_pytorch import BottleStack

# Three candidate BoT stacks. Only bot_layer_1 is placed in BotStackLayer below;
# bot_layer_2 and bot_layer_3 are constructed but not used by the active code.
# NOTE(review): dim / fmap_size have to match the channel count and spatial size
# of the tensor entering the stack, which depends on the backbone and on
# CFG.img_size - confirm whenever either changes.
bot_layer_1 = BottleStack(
    dim = 256,      #1024,              # channels in
    fmap_size = 16, # 14         # feature map size # need to adjust depend on image size/depth
    dim_out = 512, # 2048,         # channels out
    proj_factor = 4,        # projection factor
    downsample = True,      # downsample on first layer or not
    heads = 4,              # number of heads
    dim_head = 128,         # dimension per head, defaults to 128
    rel_pos_emb = True,     # use relative positional embedding - uses absolute if False
    activation = nn.SiLU()  # activation throughout the network
)

# NOTE(review): dim = 2048 here does not follow on from bot_layer_1's dim_out = 512,
# so chaining all three as in the commented-out Sequential would need the sizes revisited.
bot_layer_2 = BottleStack(
    dim = 2048,              # channels in
    fmap_size = 16,         # feature map size
    dim_out = 2048,         # channels out
    proj_factor = 4,        # projection factor
    downsample = True,      # downsample on first layer or not
    heads = 4,              # number of heads
    dim_head = 128,         # dimension per head, defaults to 128
    rel_pos_emb = True,     # use relative positional embedding - uses absolute if False
    activation = nn.SiLU()  # activation throughout the network
)

bot_layer_3 = BottleStack(
    dim = 2048,             # channels in
    fmap_size = 8,          # feature map size
    dim_out = 2048,         # channels out
    proj_factor = 4,        # projection factor
    downsample = True,      # downsample on first layer or not
    heads = 4,              # number of heads
    dim_head = 128,         # dimension per head, defaults to 128
    rel_pos_emb = True,     # use relative positional embedding - uses absolute if False
    activation = nn.SiLU()  # activation throughout the network
)


## stack of BoT layers with 4xheads + pos emb
# BotStackLayer = nn.Sequential(bot_layer_1, bot_layer_2, bot_layer_3)
# Module-level object that get_BoTModel installs as the backbone's layer4.
BotStackLayer = nn.Sequential(bot_layer_1)

# create model 
def get_BoTModel(n_channels=1):
    """Build the timm backbone with its ``layer4`` replaced by the BoT stack.

    Args:
        n_channels: number of input channels passed to ``timm.create_model``.

    Returns:
        The modified model: pretrained backbone named by ``CFG.model_name``,
        BoT stack as ``layer4``, a fresh ``nn.Linear`` head with
        ``CFG.num_classes`` outputs, and every ``nn.ReLU`` swapped for SiLU.
    """
    # base model
    model = timm.create_model(CFG.model_name, pretrained=True, in_chans=n_channels)
    # Read the head width before the head is replaced.
    num_features = model.fc.in_features

    # BoT layers: install a private copy. Assigning the module-level
    # BotStackLayer directly made every model returned by this function share
    # (and keep training) the same layer4 weights.
    model.layer4 = copy.deepcopy(BotStackLayer)
    # head
    model.fc = nn.Linear(num_features, CFG.num_classes)
    # convert ReLU activation to SiLU
    model = convert_act_cls(model, nn.ReLU, nn.SiLU())
    return model

# ## WiP adjust fmaps for various image resolutions 

# # ## debug model call
# print(CFG.model_name)
# model = get_BoTModel(n_channels=1)
# summary(model, (1, 256, 256))

<span style="color: #0087e4; font-family: courier; font-size: 1.5em; font-weight: 300;">Train Engine</span>

In [None]:
# ====================================================
# LR scheduler 
# ====================================================

def get_scheduler(optimizer):
    """Create the learning-rate scheduler selected by ``CFG.scheduler``.

    Args:
        optimizer: optimizer whose learning rate is scheduled.

    Returns:
        A ``ReduceLROnPlateau``, ``CosineAnnealingLR`` or
        ``CosineAnnealingWarmRestarts`` instance configured from ``CFG``.

    Raises:
        ValueError: if ``CFG.scheduler`` is not one of the three supported
            names (this used to surface as an ``UnboundLocalError`` on return).
    """
    name = CFG.scheduler
    if name == 'ReduceLROnPlateau':
        return ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
    if name == 'CosineAnnealingLR':
        return CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    if name == 'CosineAnnealingWarmRestarts':
        return CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
    raise ValueError(
        f"Unsupported CFG.scheduler {name!r}; expected 'ReduceLROnPlateau', "
        "'CosineAnnealingLR' or 'CosineAnnealingWarmRestarts'"
    )

In [None]:
## Usual Torch train/eval functions 
## adapted from here: https://www.kaggle.com/yasufuminakama/seti-nfnet-l0-starter-training

def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """Train ``model`` for one epoch.

    Args:
        train_loader: iterable of ``(images, labels)`` batches.
        model: network to train; its output is flattened with ``view(-1)``
            before being passed to ``criterion``.
        criterion: loss taking ``(logits, labels)``.
        optimizer: optimizer stepped every ``CFG.gradient_accumulation_steps``
            batches.
        epoch: zero-based epoch index, used only in the progress line.
        scheduler: accepted for signature compatibility; not used here (the
            caller steps it once per epoch).
        device: device the batches are moved to.

    Returns:
        ``(average_loss, predictions)`` where ``predictions`` is the
        concatenation of per-batch sigmoid outputs in the order the loader
        produced them (i.e. shuffled order when the loader shuffles).
    """
    use_amp = CFG.apex
    # `GradScaler` / `autocast` were referenced as bare names that the notebook
    # never imports; they live in the already-imported `torch.cuda.amp` module.
    scaler = amp.GradScaler() if use_amp else None
    losses = AverageMeter()
    preds = []
    # switch to train mode
    model.train()
    start = time.time()
    n_batches = len(train_loader)
    # Shown as 'nan' until the first optimizer step of the epoch has happened.
    grad_norm = float('nan')

    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)

        if use_amp:
            with amp.autocast():
                y_preds = model(images)
                loss = criterion(y_preds.view(-1), labels)
        else:
            y_preds = model(images)
            loss = criterion(y_preds.view(-1), labels)

        # record loss (before any accumulation scaling, so the average is per-sample)
        losses.update(loss.item(), batch_size)
        # save predictions
        preds.append(y_preds.sigmoid().detach().to('cpu').numpy())

        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            if use_amp:
                # Clipping must see real gradient magnitudes, not loss-scaled ones.
                scaler.unscale_(optimizer)
            # Clip once per optimizer step, on the fully accumulated gradient.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        if step % CFG.print_freq == 0 or step == (n_batches - 1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f} ({loss.avg:.4f}) '
                  'Grad: {grad_norm:.4f}  '
                  .format(epoch+1, step, n_batches,
                          loss=losses,
                          remain=timeSince(start, float(step+1)/n_batches),
                          grad_norm=grad_norm,
                          ))
    predictions = np.concatenate(preds)
    return losses.avg, predictions


def valid_fn(valid_loader, model, criterion, device):
    """Evaluate ``model`` on a validation loader without tracking gradients.

    Args:
        valid_loader: iterable of ``(images, labels)`` batches.
        model: network to evaluate; put into eval mode here.
        criterion: loss taking ``(logits, labels)``.
        device: device the batches are moved to.

    Returns:
        ``(average_loss, predictions)`` where ``predictions`` is the
        concatenation of per-batch sigmoid outputs in loader order.
    """
    losses = AverageMeter()
    # switch to evaluation mode
    model.eval()
    preds = []
    start = time.time()
    n_batches = len(valid_loader)

    for step, (images, labels) in enumerate(valid_loader):
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)

        # Forward pass and loss both under no_grad; nothing is back-propagated here,
        # so the old gradient-accumulation rescaling of the loss was dead code.
        with torch.no_grad():
            y_preds = model(images)
            loss = criterion(y_preds.view(-1), labels)
        losses.update(loss.item(), batch_size)
        # save predictions
        preds.append(y_preds.sigmoid().to('cpu').numpy())

        if step % CFG.print_freq == 0 or step == (n_batches - 1):
            print('EVAL: [{0}/{1}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f} ({loss.avg:.4f}) '
                  .format(step, n_batches,
                          loss=losses,
                          remain=timeSince(start, float(step+1)/n_batches),
                          ))
    predictions = np.concatenate(preds)
    return losses.avg, predictions

In [None]:
def run_fold(model, criterion, optimizer, scheduler, device, fold=0, num_epochs=10):
    """Train and validate one cross-validation fold.

    Rows of ``df_folds`` whose ``fold`` differs from ``fold`` are used for
    training, the rest for validation. After every epoch the model is
    checkpointed when the validation AUC or the validation loss improves.

    Args:
        model, criterion, optimizer, scheduler: training objects, used as given
            (they are not re-created here).
        device: device used for training and evaluation.
        fold: index of the validation fold.
        num_epochs: number of epochs to run.

    Returns:
        ``(valid_folds, history)``: the validation dataframe with a ``'preds'``
        column holding the predictions of the best-validation-loss epoch, and a
        dict of per-epoch ``train_loss`` / ``valid_loss`` / ``train_auc`` /
        ``valid_auc`` lists.

    Raises:
        RuntimeError: if no epoch produced a validation loss below infinity
            (e.g. ``num_epochs == 0`` or NaN losses), so there are no
            predictions to return.
    """
    trn_idx = df_folds[df_folds['fold'] != fold].index
    val_idx = df_folds[df_folds['fold'] == fold].index

    train_folds = df_folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = df_folds.loc[val_idx].reset_index(drop=True)
    valid_labels = valid_folds[CFG.target_col].values
    # `np.float` was removed in NumPy 1.24; the builtin float is the same dtype.
    train_labels = train_folds[CFG.target_col].values.astype(float)

    train_ds = SetiDataset(train_folds, transform=get_transforms(data='train'))
    valid_ds = SetiDataset(valid_folds, transform=get_transforms(data='valid'))

    train_dl = DataLoader(train_ds, batch_size=CFG.batch_size, shuffle=True,
                          num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    valid_dl = DataLoader(valid_ds, batch_size=CFG.batch_size, shuffle=False,
                          num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    # ====================================================
    # loop
    # ====================================================
    best_score = 0.
    best_loss = np.inf
    # Kept in memory instead of being re-read from the checkpoint afterwards:
    # the file may be left over from an earlier run, or absent if the loss
    # never improved.
    best_loss_preds = None
    history = defaultdict(list)

    for epoch in range(num_epochs):
        start_time = time.time()

        # train
        avg_loss, train_preds = train_fn(train_dl, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        avg_val_loss, preds = valid_fn(valid_dl, model, criterion, device)

        # ReduceLROnPlateau needs the monitored value; the cosine schedulers do not.
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, (CosineAnnealingLR, CosineAnnealingWarmRestarts)):
            scheduler.step()

        # compute metrics
        score = get_score(valid_labels, preds)
        # NOTE(review): train_dl shuffles, so train_preds are not in the row order
        # of train_labels and this train AUC is not meaningful as computed.
        # Fixing it requires train_fn to also return the labels it saw.
        score_tr = get_score(train_labels, train_preds)

        history['train_loss'].append(avg_loss)
        history['valid_loss'].append(avg_val_loss)
        history['train_auc'].append(score_tr)
        history['valid_auc'].append(score)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')

        if score > best_score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'preds': preds},
                        f'BoT_{CFG.model_name}_fold{fold}_best_score.pth')

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            best_loss_preds = preds
            LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'preds': preds},
                        f'BoT_{CFG.model_name}_fold{fold}_best_loss.pth')

    if best_loss_preds is None:
        raise RuntimeError(
            f'Fold {fold}: no epoch improved the validation loss, so there are no out-of-fold predictions'
        )
    valid_folds['preds'] = best_loss_preds

    return valid_folds, history

<span style="color: #0087e4; font-family: courier; font-size: 1.5em; font-weight: 300;">Run Training</span>

In [None]:
# define model
model = get_BoTModel(n_channels=1)
model.to(device);

# Select Optim, LR, Loss 
# AdamP is the active choice; the SAM and plain Adam variants are left commented out.
# NOTE(review): train_fn calls optimizer.step() without a closure, which SAM.step
# asserts against - enabling the SAM line needs a matching change in train_fn.
optimizer = AdamP(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
# optimizer = SAM(model.parameters(), AdamP, lr=CFG.lr, weight_decay=CFG.weight_decay)
# optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
# Loss on raw logits; the model's outputs are only passed through sigmoid for predictions.
criterion = nn.BCEWithLogitsLoss()
scheduler = get_scheduler(optimizer)

In [None]:
# Accumulators: out-of-fold predictions and per-fold training history.
oof_df = pd.DataFrame()
hist_folds = []
# NOTE(review): model, optimizer and scheduler are created once above and reused
# for every fold, so with more than one entry in CFG.trn_fold later folds would
# start from weights already trained on earlier folds' data.
for fold_id in range(CFG.n_fold):
    if fold_id in CFG.trn_fold:
        
        print(f'Start Training Fold {fold_id} with bs {CFG.batch_size}')
        print('-'*40)
    
        oof_, hist_ = run_fold(model, criterion, optimizer, scheduler, device=device, fold=fold_id, num_epochs=CFG.num_epochs)
        oof_df = pd.concat([oof_df, oof_])
        hist_folds.append(hist_)
        # Fold score
        get_result(oof_)
        # Release cached GPU memory and collect garbage between folds.
        torch.cuda.empty_cache()
        gc.collect()
    
# CV score
# Computed over the folds that were actually trained (those listed in CFG.trn_fold).
get_result(oof_df)
oof_df.to_csv('oof_df.csv', index=False)

<span style="color: #0087e4; font-family: courier; font-size: 1.5em; font-weight: 300;">Visualize Training & Metrics</span>

In [None]:
plt.style.use('fivethirtyeight')

# Train vs. validation loss per epoch for the first trained fold (hist_folds[0]).
fig = plt.figure(figsize=(14,6))
plt.plot(hist_folds[0]['train_loss'], label='train loss')
plt.plot(hist_folds[0]['valid_loss'], label='valid loss')
plt.legend()
plt.title(f'Loss Curve [Fold {0}]');

In [None]:
# Train vs. validation AUC per epoch for the first trained fold (hist_folds[0]).
fig = plt.figure(figsize=(14,6))
plt.plot(hist_folds[0]['train_auc'], label='train auc')
plt.plot(hist_folds[0]['valid_auc'], label='valid auc')
plt.legend()
plt.title(f'AUC Curve [Fold {0}]');

In [None]:
# ====================================================
# Helper functions
# ====================================================

class TestDataset(Dataset):
    """Inference dataset yielding one-channel image tensors (no labels).

    Preprocessing mirrors ``SetiDataset`` step for step so that test images
    reach the model in the same layout and dtype as training images.
    Previously the transform branch returned a 2-D array with no channel axis
    and no tensor conversion, and the no-transform branch used a different
    axis order from the training pipeline.

    Args:
        df: dataframe with a ``'file_path'`` column.
        transform: optional callable following the albumentations convention
            ``transform(image=arr)['image']``, applied to the channel-last array.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['file_path'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a channel-first float32 tensor."""
        image = np.load(self.file_names[idx]).astype(np.float32)
        image = image[[0, 2, 4]]        # keep rows 0, 2, 4 of the first axis
        image = np.vstack(image)        # stack them along the first image axis
        image = image.transpose((1, 0))
        # Same dtype and trailing channel axis the training transform receives.
        image = image.astype("float")[..., np.newaxis]

        if self.transform:
            image = self.transform(image=image)['image']

        # Same (2, 1, 0) axis order as SetiDataset: channel first, spatial axes swapped.
        image = np.ascontiguousarray(image.transpose((2, 1, 0)))
        return torch.from_numpy(image).float()





def inference(model, states, test_loader, device):
    """Predict probabilities for a test loader, averaging over checkpoints.

    Args:
        model: network whose weights are overwritten by each checkpoint.
        states: non-empty sequence of checkpoints. Each entry is either a plain
            state dict or a dict wrapping one under the ``'model'`` key, which
            is the format ``run_fold`` writes with ``torch.save``.
        test_loader: iterable of image batches.
        device: device used for inference.

    Returns:
        Array of sigmoid outputs for the whole loader, averaged element-wise
        across checkpoints.

    Raises:
        ValueError: if ``states`` is empty.
    """
    # Unwrap {'model': state_dict, ...} checkpoints; pass plain state dicts through.
    weights = [
        state['model'] if isinstance(state, dict) and isinstance(state.get('model'), dict) else state
        for state in states
    ]
    if not weights:
        raise ValueError('inference() needs at least one checkpoint in `states`')

    model.to(device)
    model.eval()

    # With a single checkpoint, load it once instead of once per batch.
    single_checkpoint = len(weights) == 1
    if single_checkpoint:
        model.load_state_dict(weights[0])

    probs = []
    tk0 = tqdm(enumerate(test_loader), total=len(test_loader))
    for i, images in tk0:
        images = images.to(device)
        batch_preds = []
        for state_dict in weights:
            if not single_checkpoint:
                model.load_state_dict(state_dict)
            with torch.no_grad():
                y_preds = model(images.float())
            batch_preds.append(y_preds.sigmoid().to('cpu').numpy())
        # Average across checkpoints for this batch.
        probs.append(np.mean(batch_preds, axis=0))
    return np.concatenate(probs)

In [None]:
# # ====================================================
# # inference
# # ====================================================

# model = CustomModel() # ...

# states = [torch.load(f'xxx.pth') for fold in CFG.trn_fold]
# print('no. of checkpoints:', len(states))

# test_dataset = TestDataset(test, transform=get_transforms(data='valid'))
# test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False, 
#                          num_workers=CFG.num_workers, pin_memory=False)

# predictions = inference(model, states, test_loader, device=device)

![Upvote!](https://img.shields.io/badge/Upvote-If%20you%20like%20my%20work-07b3c8?style=for-the-badge&logo=kaggle)

# References: 

- [seresnext50-but-with-attention](https://www.kaggle.com/debarshichanda/seresnext50-but-with-attention) by [debarshichanda](https://www.kaggle.com/debarshichanda) (Model + Optimizer)

- [seti-nfnet-l0-starter-training](https://www.kaggle.com/yasufuminakama/seti-nfnet-l0-starter-training) by [yasufuminakama](https://www.kaggle.com/yasufuminakama) (Train engine + Helpers)

- [Bottleneck Transformers for Visual Recognition](https://arxiv.org/abs/2101.11605)

- [Sharpness-Aware Minimization for Efficiently Improving Generalization](https://arxiv.org/abs/2010.01412)