## APTOS2019 train kernel (PyTorch)

### Flags for training

In [1]:
FOLD = 0
MODEL = 'efficientnet-b5'

### train params

In [2]:
train_params = {
    'n_splits': 5,
    'n_epochs': 12,
    'lr': 1e-3,
    'base_lr': 1e-4,
    'max_lr': 3e-3,
    'step_factor': 6,
    'train_batch_size': 32,
    'test_batch_size': 32,
    'accumulation_steps': 10,
}

### packages

In [3]:
! pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ../input/nvidia-apex/repository/NVIDIA-apex-665b2dd

  cmdoptions.check_install_build_global(options)
Created temporary directory: /tmp/pip-ephem-wheel-cache-iskw613t
Created temporary directory: /tmp/pip-req-tracker-sqh75l1i
Created requirements tracker '/tmp/pip-req-tracker-sqh75l1i'
Created temporary directory: /tmp/pip-install-tntyfdd3
Processing /kaggle/input/nvidia-apex/repository/NVIDIA-apex-665b2dd
  Created temporary directory: /tmp/pip-req-build-e66udlv4
  Added file:///kaggle/input/nvidia-apex/repository/NVIDIA-apex-665b2dd to build tracker '/tmp/pip-req-tracker-sqh75l1i'
    Running setup.py (path:/tmp/pip-req-build-e66udlv4/setup.py) egg_info for package from file:///kaggle/input/nvidia-apex/repository/NVIDIA-apex-665b2dd
    Running command python setup.py egg_info
    torch.__version__  =  1.1.0
    running egg_info
    creating pip-egg-info/apex.egg-info
    writing pip-egg-info/apex.egg-info/PKG-INFO
    writing dependency_links to pip-egg-info/apex.egg-info/dependency_links.txt
    writing top-level names

In [4]:
import sys
sys.path.append('../input/pretrained-models-pytorch/repository/Cadene-pretrained-models.pytorch-021d978')
sys.path.append('../input/efficientnet-pytorch-repository/repository/lukemelas-EfficientNet-PyTorch-50a2bf2')

In [5]:
import gc
import os
import random
import time
from contextlib import contextmanager
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import scipy as sp
from fastprogress import master_bar, progress_bar
from functools import partial
from sklearn.metrics import cohen_kappa_score

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset

import pretrainedmodels
from efficientnet_pytorch import EfficientNet

from albumentations import (
    Compose, HorizontalFlip, IAAAdditiveGaussianNoise, Normalize, OneOf,
    RandomBrightness, RandomContrast, Resize, VerticalFlip, Rotate, ShiftScaleRotate,
    RandomBrightnessContrast, OpticalDistortion, GridDistortion, ElasticTransform, Cutout
)
from albumentations.pytorch import ToTensor

from apex import amp

from fastai.layers import Flatten, AdaptiveConcatPool2d

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

### utils

In [7]:
@contextmanager
def timer(name):
    t0 = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    LOGGER.info(f'[{name}] done in {time.time() - t0:.0f} s.')

In [8]:
def init_logger(log_file='train.log'):
    from logging import getLogger, DEBUG, FileHandler,  Formatter,  StreamHandler
    
    log_format = '%(asctime)s %(levelname)s %(message)s'
    
    stream_handler = StreamHandler()
    stream_handler.setLevel(DEBUG)
    stream_handler.setFormatter(Formatter(log_format))
    
    file_handler = FileHandler(log_file)
    file_handler.setFormatter(Formatter(log_format))
    
    logger = getLogger('APTOS')
    logger.setLevel(DEBUG)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    
    return logger

LOG_FILE = 'aptos-train.log'
LOGGER = init_logger(LOG_FILE)

In [9]:
def seed_torch(seed=777):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

SEED = 777
seed_torch(SEED)

In [10]:
def quadratic_weighted_kappa(y_hat, y):
    return cohen_kappa_score(y_hat, y, weights='quadratic')

In [11]:
class OptimizedRounder():
    def __init__(self):
        self.coef_ = 0

    def _kappa_loss(self, coef, X, y):
        X_p = np.copy(X)
        for i, pred in enumerate(X_p):
            if pred < coef[0]:
                X_p[i] = 0
            elif pred >= coef[0] and pred < coef[1]:
                X_p[i] = 1
            elif pred >= coef[1] and pred < coef[2]:
                X_p[i] = 2
            elif pred >= coef[2] and pred < coef[3]:
                X_p[i] = 3
            else:
                X_p[i] = 4

        ll = quadratic_weighted_kappa(y, X_p)
        return -ll

    def fit(self, X, y):
        loss_partial = partial(self._kappa_loss, X=X, y=y)
        initial_coef = [0.5, 1.5, 2.5, 3.5]
        self.coef_ = sp.optimize.minimize(loss_partial, initial_coef, method='nelder-mead')

    def predict(self, X, coef):
        X_p = np.copy(X)
        for i, pred in enumerate(X_p):
            if pred < coef[0]:
                X_p[i] = 0
            elif pred >= coef[0] and pred < coef[1]:
                X_p[i] = 1
            elif pred >= coef[1] and pred < coef[2]:
                X_p[i] = 2
            elif pred >= coef[2] and pred < coef[3]:
                X_p[i] = 3
            else:
                X_p[i] = 4
        return X_p

    def coefficients(self):
        return self.coef_['x']

In [12]:
# NOTE: official CyclicLR implementation doesn't work now

from torch.optim.lr_scheduler import _LRScheduler

class CyclicLR(_LRScheduler):
    def __init__(self, optimizer, base_lr, max_lr, step_size, gamma=0.99, mode='triangular', last_epoch=-1):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        self.gamma = gamma
        self.mode = mode
        assert mode in ['triangular', 'triangular2', 'exp_range']
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        new_lr = []
        # make sure that the length of base_lrs doesn't change. Dont care about the actual value
        for base_lr in self.base_lrs:
            cycle = np.floor(1 + self.last_epoch / (2 * self.step_size))
            x = np.abs(float(self.last_epoch) / self.step_size - 2 * cycle + 1)
            if self.mode == 'triangular':
                lr = self.base_lr + (self.max_lr - self.base_lr) * np.maximum(0, (1 - x))
            elif self.mode == 'triangular2':
                lr = self.base_lr + (self.max_lr - self.base_lr) * np.maximum(0, (1 - x)) / float(2 ** (cycle - 1))
            elif self.mode == 'exp_range':
                lr = self.base_lr + (self.max_lr - self.base_lr) * np.maximum(0, (1 - x)) * (self.gamma ** (self.last_epoch))
            new_lr.append(lr)
        return new_lr

### dataset

In [13]:
APTOS_DIR = Path('../input/aptos2019-blindness-detection')
APTOS_TRAIN_DIR = Path('../input/aptos-train-dataset')

APTOS_TRAIN_IMAGES = APTOS_TRAIN_DIR / 'aptos-train-images/aptos-train-images'

#APTOS_FOLDS = Path('../input/aptos-folds/folds.csv')
#APTOS_FOLDS = Path('../input/aptos-folds/jpeg_folds.csv')
#APTOS_TRAIN_FOLDS = Path('../input/aptos-folds/jpeg_folds_all.csv')
#APTOS_VALID_FOLDS = Path('../input/aptos-folds/png_folds_all.csv')
APTOS_TRAIN_FOLDS = Path('../input/aptos-folds/2015_5folds.csv')
APTOS_VALID_FOLDS = Path('../input/aptos-folds/2019_5folds.csv')

ID_COLUMN = 'id_code'
TARGET_COLUMN = 'diagnosis'

In [14]:
PRETRAINED_DIR = Path('../input/pytorch-pretrained-models')
EFFICIENTNET_PRETRAINED_DIR = Path('../input/efficientnet-pytorch')

PRETRAINED_MAPPING = {
    # ResNet
    'resnet18': PRETRAINED_DIR / 'resnet18-5c106cde.pth', 
    'resnet34': PRETRAINED_DIR / 'resnet34-333f7ec4.pth',
    'resnet50': PRETRAINED_DIR / 'resnet50-19c8e357.pth',
    'resnet101': PRETRAINED_DIR / 'resnet101-5d3b4d8f.pth',
    'resnet152': PRETRAINED_DIR / 'resnet152-b121ed2d.pth',

    # ResNeXt
    'resnext101_32x4d': PRETRAINED_DIR / 'resnext101_32x4d-29e315fa.pth',
    'resnext101_64x4d': PRETRAINED_DIR / 'resnext101_64x4d-e77a0586.pth',

    # WideResNet
    #'wideresnet50'

    # DenseNet
    'densenet121': PRETRAINED_DIR / 'densenet121-fbdb23505.pth',
    'densenet169': PRETRAINED_DIR / 'densenet169-f470b90a4.pth',
    'densenet201': PRETRAINED_DIR / 'densenet201-5750cbb1e.pth',
    'densenet161': PRETRAINED_DIR / 'densenet161-347e6b360.pth',

    # SE-ResNet
    'se_resnet50': PRETRAINED_DIR / 'se_resnet50-ce0d4300.pth',
    'se_resnet101': PRETRAINED_DIR / 'se_resnet101-7e38fcc6.pth',
    'se_resnet152': PRETRAINED_DIR / 'se_resnet152-d17c99b7.pth',

    # SE-ResNeXt
    'se_resnext50_32x4d': PRETRAINED_DIR / 'se_resnext50_32x4d-a260b3a4.pth',
    'se_resnext101_32x4d': PRETRAINED_DIR / 'se_resnext101_32x4d-3b2fe3d8.pth',

    # SE-Net
    'senet154': PRETRAINED_DIR / 'senet154-c7b49a05.pth',

    # InceptionV3
    'inceptionv3': PRETRAINED_DIR / 'inception_v3_google-1a9a5a14.pth',

    # InceptionV4
    'inceptionv4': PRETRAINED_DIR / 'inceptionv4-8e4777a0.pth',

    # BNInception
    'bninception': PRETRAINED_DIR / 'bn_inception-52deb4733.pth',

    # InceptionResNetV2
    'inceptionresnetv2': PRETRAINED_DIR / 'inceptionresnetv2-520b38e4.pth',

    # Xception
    'xception': PRETRAINED_DIR / 'xception-43020ad28.pth',

    # DualPathNet
    'dpn68': PRETRAINED_DIR / 'dpn68-4af7d88d2.pth',
    'dpn98': PRETRAINED_DIR / 'dpn98-722954780.pth',
    'dpn131': PRETRAINED_DIR / 'dpn131-7af84be88.pth',
    'dpn68b': PRETRAINED_DIR / 'dpn68b_extra-363ab9c19.pth',
    'dpn92': PRETRAINED_DIR / 'dpn92_extra-fda993c95.pth',
    'dpn107': PRETRAINED_DIR / 'dpn107_extra-b7f9f4cc9.pth',

    # PolyNet
    'polynet': PRETRAINED_DIR / 'polynet-f71d82a5.pth',

    # NasNet-A-Large
    'nasnetalarge': PRETRAINED_DIR / 'nasnetalarge-a1897284.pth',

    # PNasNet-5-Large
    'pnasnet5large': PRETRAINED_DIR / 'pnasnet5large-bf079911.pth',

    # EfficientNet
    'efficientnet-b0': EFFICIENTNET_PRETRAINED_DIR / 'efficientnet-b0-08094119.pth',
    'efficientnet-b1': EFFICIENTNET_PRETRAINED_DIR / 'efficientnet-b1-dbc7070a.pth',
    'efficientnet-b2': EFFICIENTNET_PRETRAINED_DIR / 'efficientnet-b2-27687264.pth',
    'efficientnet-b3': EFFICIENTNET_PRETRAINED_DIR / 'efficientnet-b3-c8376fa2.pth',
    'efficientnet-b4': EFFICIENTNET_PRETRAINED_DIR / 'efficientnet-b4-e116e8b3.pth',
    'efficientnet-b5': EFFICIENTNET_PRETRAINED_DIR / 'efficientnet-b5-586e6cc6.pth',
}

In [15]:
class APTOSTrainDataset(Dataset):
    def __init__(self, image_dir, file_paths, labels, transform=None):
        self.image_dir = image_dir
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = f'{self.image_dir}/{self.file_paths[idx]}'
        label = torch.tensor(self.labels[idx]).float()
        
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        
        return image, label

### transforms

In [16]:
from albumentations import ImageOnlyTransform

def crop_image_from_gray(img, tol=7):
    """
    Crop out black borders
    https://www.kaggle.com/ratthachat/aptos-updated-preprocessing-ben-s-cropping
    """  
    if img.ndim ==2:
        mask = img>tol
        return img[np.ix_(mask.any(1),mask.any(0))]
    elif img.ndim==3:
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        mask = gray_img>tol        
        check_shape = img[:,:,0][np.ix_(mask.any(1),mask.any(0))].shape[0]
        if (check_shape == 0):
            return img
        else:
            img1=img[:,:,0][np.ix_(mask.any(1),mask.any(0))]
            img2=img[:,:,1][np.ix_(mask.any(1),mask.any(0))]
            img3=img[:,:,2][np.ix_(mask.any(1),mask.any(0))]
            img = np.stack([img1,img2,img3],axis=-1)
        return img


class CircleCrop(ImageOnlyTransform):
    def __init__(self, tol=7, always_apply=False, p=1.0):
        super().__init__(always_apply, p)
        self.tol = tol
    
    def apply(self, img, **params):
        img = crop_image_from_gray(img)    
    
        height, width, depth = img.shape    
    
        x = int(width/2)
        y = int(height/2)
        r = np.amin((x,y))
    
        circle_img = np.zeros((height, width), np.uint8)
        cv2.circle(circle_img, (x,y), int(r), 1, thickness=-1)
        img = cv2.bitwise_and(img, img, mask=circle_img)
        img = crop_image_from_gray(img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
        return img 
    

class CircleCropV2(ImageOnlyTransform):
    def __init__(self, tol=7, always_apply=False, p=1.0):
        super().__init__(always_apply, p)
        self.tol = tol
    
    def apply(self, img, **params):
        img = crop_image_from_gray(img)
        
        height, width, depth = img.shape
        largest_side = np.max((height, width))
        img = cv2.resize(img, (largest_side, largest_side))
    
        height, width, depth = img.shape    
    
        x = int(width/2)
        y = int(height/2)
        r = np.amin((x,y))
    
        circle_img = np.zeros((height, width), np.uint8)
        cv2.circle(circle_img, (x,y), int(r), 1, thickness=-1)
        img = cv2.bitwise_and(img, img, mask=circle_img)
        img = crop_image_from_gray(img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
        return img 

In [17]:
def get_transforms(*, data):
    assert data in ('train', 'valid')
    
    if data == 'train':
        return Compose([
            CircleCropV2(),
            Resize(256, 256),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            Rotate(p=0.5), 
            #ShiftScaleRotate(p=0.5),
            #RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=0.5),
            #OpticalDistortion(distort_limit=(0.9,1.0), shift_limit=0.05, interpolation=1, border_mode=4, 
            #                  value=None, always_apply=False, p=0.5),
            #GridDistortion(num_steps=5, distort_limit=0.3, interpolation=1, border_mode=4,
            #               value=None, always_apply=False, p=0.5),
            #ElasticTransform(alpha=1, sigma=50, alpha_affine=50, interpolation=1, border_mode=4,
            #                 value=None, always_apply=True, approximate=False, p=0.5),
            Cutout(p=0.25, max_h_size=25, max_w_size=25, num_holes=8),
            #OneOf([
            #    RandomBrightness(0.1, p=1),
            #    RandomContrast(0.1, p=1),
            #], p=0.25),
            RandomContrast(0.5, p=0.5),
            IAAAdditiveGaussianNoise(p=0.25),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensor(),
        ])
    
    elif data == 'valid':
        return Compose([
            CircleCropV2(),
            Resize(256, 256),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensor(),
        ])

### model

In [18]:
class ClassifierModule(nn.Sequential):
    def __init__(self, n_features):
        super().__init__(
            nn.BatchNorm1d(n_features),
            nn.Dropout(0.5),
            nn.Linear(n_features, n_features),
            nn.PReLU(),
            nn.BatchNorm1d(n_features),
            nn.Dropout(0.2),
            nn.Linear(n_features, 1),
        )

In [19]:
class CustomResNet(nn.Module):
    def __init__(self, model_name='resnet50', weights_path=None):
        assert model_name in ('resnet50', 'resnet101', 'resnet152')
        super().__init__()
        
        self.net = pretrainedmodels.__dict__[model_name](pretrained=None)
        self.net.load_state_dict(torch.load(weights_path))
        
        n_features = self.net.last_linear.in_features
        
        self.net.avgpool = nn.AdaptiveAvgPool2d(1)
        # self.net.avgpool = AdaptiveConcatPool2d(1)
        self.net.last_linear = ClassifierModule(n_features)
        
    def forward(self, x):
        return self.net(x)

In [20]:
class CustomResNeXt(nn.Module):
    def __init__(self, model_name='resnext101_32x4d', weights_path=None):
        assert model_name in ('resnext101_32x4d', 'resnext101_64x4d')
        super().__init__()
        
        self.net = pretrainedmodels.__dict__[model_name](pretrained=None)
        self.net.load_state_dict(torch.load(weights_path))
        
        n_features = self.net.last_linear.in_features
        
        self.net.avg_pool = nn.AdaptiveAvgPool2d(1)
        # self.net.avg_pool = AdaptiveConcatPool2d(1)
        self.net.last_linear = ClassifierModule(n_features)
        
    def forward(self, x):
        return self.net(x)

In [21]:
class CustomSENet(nn.Module):
    def __init__(self, model_name='se_resnet50', weights_path=None):
        assert model_name in ('senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', 'se_resnext101_32x4d')
        super().__init__()
        
        self.net = pretrainedmodels.__dict__[model_name](pretrained=None)
        self.net.load_state_dict(torch.load(weights_path))
        
        n_features = self.net.last_linear.in_features
        
        self.net.avg_pool = nn.AdaptiveAvgPool2d(1)
        # self.net.avg_pool = AdaptiveConcatPool2d(1)
        self.net.last_linear = ClassifierModule(n_features)
        
    def forward(self, x):
        return self.net(x)

In [22]:
class CustomEfficientNet(nn.Module):
    def __init__(self, model_name='efficientnet-b0', weights_path=None):
        assert model_name in ('efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5')
        super().__init__()
        
        self.net = EfficientNet.from_name(model_name)
        self.net.load_state_dict(torch.load(weights_path))
        
        n_features = self.net._fc.in_features
        
        self.net._fc = ClassifierModule(n_features)
        
    def forward(self, x):
        return self.net(x)

### entry point

In [23]:
LOGGER.debug(f'Fold: {FOLD}')
LOGGER.debug(f'Model: {MODEL}')
LOGGER.debug(f'Train params: {train_params}')

2019-08-13 20:29:22,451 DEBUG Fold: 0
2019-08-13 20:29:22,452 DEBUG Model: efficientnet-b5
2019-08-13 20:29:22,454 DEBUG Train params: {'n_splits': 5, 'n_epochs': 12, 'lr': 0.001, 'base_lr': 0.0001, 'max_lr': 0.003, 'step_factor': 6, 'train_batch_size': 32, 'test_batch_size': 32, 'accumulation_steps': 10}


In [24]:
with timer('Prepare train and valid sets'):
    with timer('  * load folds csv'):
        #folds = pd.read_csv(APTOS_FOLDS)
        #train_fold = folds[folds['fold'] != FOLD].reset_index(drop=True)
        #valid_fold = folds[folds['fold'] == FOLD].reset_index(drop=True)
        folds = pd.read_csv(APTOS_TRAIN_FOLDS)
        train_fold = folds[folds['fold'] != FOLD].reset_index(drop=True)
        #valid_fold2015 = folds[folds['fold'] == FOLD].reset_index(drop=True)
        #valid_fold2019 = pd.read_csv(APTOS_VALID_FOLDS)
        #valid_fold = pd.concat([valid_fold2015, valid_fold2019]).reset_index(drop=True)
        valid_fold = pd.read_csv(APTOS_VALID_FOLDS)
    
    with timer('  * define dataset'):
        APTOSTrainDataset = partial(APTOSTrainDataset, image_dir=APTOS_TRAIN_IMAGES)
        train_dataset = APTOSTrainDataset(file_paths=train_fold.id_code.values,
                                          labels=train_fold.diagnosis.values[:, np.newaxis],
                                          transform=get_transforms(data='train'))
        valid_dataset = APTOSTrainDataset(file_paths=valid_fold.id_code.values,
                                          labels=valid_fold.diagnosis.values[:, np.newaxis],
                                          transform=get_transforms(data='valid'))
        
    with timer('  * define dataloader'):
        train_loader = DataLoader(train_dataset,
                                  batch_size=train_params['train_batch_size'],
                                  shuffle=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=train_params['test_batch_size'],
                                  shuffle=False)
        
LOGGER.debug(f'train size: {len(train_dataset)}, valid size: {len(valid_dataset)}')

2019-08-13 20:29:22,470 INFO [Prepare train and valid sets] start
2019-08-13 20:29:22,471 INFO [  * load folds csv] start
2019-08-13 20:29:22,522 INFO [  * load folds csv] done in 0 s.
2019-08-13 20:29:22,523 INFO [  * define dataset] start
2019-08-13 20:29:22,525 INFO [  * define dataset] done in 0 s.
2019-08-13 20:29:22,525 INFO [  * define dataloader] start
2019-08-13 20:29:22,526 INFO [  * define dataloader] done in 0 s.
2019-08-13 20:29:22,527 INFO [Prepare train and valid sets] done in 0 s.
2019-08-13 20:29:22,528 DEBUG train size: 28099, valid size: 3534


In [25]:
with timer('Train model'):
    n_epochs = train_params['n_epochs']
    lr = train_params['lr']
    base_lr = train_params['base_lr']
    max_lr = train_params['max_lr']
    step_factor = train_params['step_factor']
    test_batch_size = train_params['test_batch_size']
    accumulation_steps = train_params['accumulation_steps']
    
    model = CustomEfficientNet(model_name=MODEL, weights_path=PRETRAINED_MAPPING[MODEL])
    model.to(device)
    
    optimizer = Adam(model.parameters(), lr=lr, amsgrad=False)
    #optimizer = SGD(model.parameters(), lr=lr, weight_decay=4e-5, momentum=0.9, nesterov=True)
    scheduler = CyclicLR(optimizer,
                         base_lr=base_lr,
                         max_lr=max_lr,
                         step_size=len(train_loader) * step_factor)

    model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
    
    criterion = nn.MSELoss()
    #criterion = nn.SmoothL1Loss()
    
    optimized_rounder = OptimizedRounder()
    y_true = valid_fold.diagnosis.values
    
    for epoch in range(n_epochs):
        start_time = time.time()

        model.train()
        avg_loss = 0.

        optimizer.zero_grad()

        for i, (images, labels) in enumerate(train_loader):
            if isinstance(scheduler, CyclicLR):
                scheduler.step()

            y_preds = model(images.to(device))
            loss = criterion(y_preds, labels.to(device))

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            if (i+1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            avg_loss += loss.item() / accumulation_steps / len(train_loader)

        if not isinstance(scheduler, CyclicLR):
            scheduler.step()

        model.eval()
        valid_preds = np.zeros((len(valid_dataset)))
        avg_val_loss = 0.

        for i, (images, labels) in enumerate(valid_loader):
            with torch.no_grad():
                y_preds = model(images.to(device)).detach()

            loss = criterion(y_preds, labels.to(device))
            valid_preds[i * test_batch_size: (i+1) * test_batch_size] = y_preds[:, 0].to('cpu').numpy()

            avg_val_loss += loss.item() / len(valid_loader)

        optimized_rounder.fit(valid_preds, y_true)
        coefficients = optimized_rounder.coefficients()
        final_preds = optimized_rounder.predict(valid_preds, coefficients)
        qwk = quadratic_weighted_kappa(y_true, final_preds)

        elapsed = time.time() - start_time

        LOGGER.debug(f'  Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.debug(f'          - qwk: {qwk:.6f}  coefficients: {coefficients}')

        # FIXME: save all epochs for debug
        torch.save(model.state_dict(), f'{MODEL}_fold{FOLD}_epoch{epoch+1}.pth')

2019-08-13 20:29:22,556 INFO [Train model] start
2019-08-13 21:15:08,407 DEBUG   Epoch 1 - avg_train_loss: 0.1044  avg_val_loss: 0.5345  time: 2740s
2019-08-13 21:15:08,410 DEBUG           - qwk: 0.837237  coefficients: [0.593729 0.964637 2.150397 4.286211]
2019-08-13 21:59:22,356 DEBUG   Epoch 2 - avg_train_loss: 0.0632  avg_val_loss: 0.4578  time: 2654s
2019-08-13 21:59:22,357 DEBUG           - qwk: 0.852277  coefficients: [0.573924 1.478256 2.123378 3.624947]
2019-08-13 22:43:14,399 DEBUG   Epoch 3 - avg_train_loss: 0.0569  avg_val_loss: 0.4669  time: 2632s
2019-08-13 22:43:14,401 DEBUG           - qwk: 0.876678  coefficients: [0.488998 1.423123 2.826116 3.701562]
2019-08-13 23:26:27,984 DEBUG   Epoch 4 - avg_train_loss: 0.0558  avg_val_loss: 1.2862  time: 2593s
2019-08-13 23:26:27,986 DEBUG           - qwk: 0.865539  coefficients: [0.371486 1.19794  3.909248 4.590181]
2019-08-14 00:09:34,049 DEBUG   Epoch 5 - avg_train_loss: 0.0553  avg_val_loss: 0.5706  time: 2586s
2019-08-14 00:0