In [1]:
# !pip install pytorch_ranger

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

%matplotlib inline

In [3]:
# List the competition input directory to see which files and folders are available.
os.listdir('../input/cassava-leaf-disease-classification/')

['train_tfrecords',
 'sample_submission.csv',
 'test_tfrecords',
 'label_num_to_disease_map.json',
 'train_images',
 'train.csv',
 'test_images']

In [4]:
# Training labels come from the merged dataset; `test` is the sample
# submission, which lists the test image ids; `label_map` maps each numeric
# label to its disease name.
train = pd.read_csv('../input/cassava-leaf-disease-merged/merged.csv')
test = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
label_map = pd.read_json('../input/cassava-leaf-disease-classification/label_num_to_disease_map.json', orient='index')

display(train.head())
display(test.head())
display(label_map)

Unnamed: 0,image_id,label,source
0,1000015157.jpg,0,2020
1,1000201771.jpg,3,2020
2,100042118.jpg,1,2020
3,1000723321.jpg,1,2020
4,1000812911.jpg,3,2020


Unnamed: 0,image_id,label
0,2216849948.jpg,4


Unnamed: 0,0
0,Cassava Bacterial Blight (CBB)
1,Cassava Brown Streak Disease (CBSD)
2,Cassava Green Mottle (CGM)
3,Cassava Mosaic Disease (CMD)
4,Healthy


## Directory settings

In [5]:
# Directory that receives logs, checkpoints, out-of-fold predictions and the submission file.
OUTPUT_DIR = './'
# exist_ok=True makes this a no-op when the directory is already there,
# without a separate check-then-create step.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Image folders read by TrainDataset and TestDataset respectively.
TRAIN_PATH = '../input/cassava-leaf-disease-merged/train'
TEST_PATH = '../input/cassava-leaf-disease-classification/test_images'

## CFG

In [6]:
class CFG:
    """Central configuration read by the data, model, loss and training code."""
    debug = False                     # when True: fewer epochs and a 1000-row sample of `train`
    apex = False                      # when True: import apex.amp and use it for loss scaling
    print_freq = 20                   # log a progress line every `print_freq` batches
    num_workers = 4                   # DataLoader worker processes
    model_name = 'tf_efficientnet_b2_ns'  # passed to timm.create_model
    size = 440                        # side length (pixels) used by the crop / resize transforms
    scheduler = 'CosineAnnealingWarmRestarts'  # name dispatched on in train_loop's get_scheduler
    loss_train = 'BiTemperedLoss'     # name dispatched on in train_loop's get_loss_train
    epochs = 10
    T_0 = 10                          # CosineAnnealingWarmRestarts: epochs until the first restart
    T_max = 10                        # CosineAnnealingLR: T_max (read only when that scheduler is selected)
    factor = 0.2                      # ReduceLROnPlateau: LR multiplier (read only when that scheduler is selected)
    patience = 4                      # ReduceLROnPlateau: epochs without improvement before reducing
    eps = 1e-6                        # ReduceLROnPlateau: minimal LR change that is still applied
    lr_1 = 5e-4                       # learning rate for the first epoch (only the classifier is trainable)
    lr_2 = 5e-5                       # learning rate once every layer is unfrozen
    t1 = 0.9                          # bi-tempered loss temperature 1
    t2 = 1.5                          # bi-tempered loss temperature 2
    smooth = 1e-2                     # label smoothing used by the smoothing-aware losses
    min_lr = 1e-6                     # eta_min of the cosine schedulers
    batch_size = 16
    weight_decay = 1e-6
    gradient_accumulation_steps = 1
    max_grad_norm = 1000              # max_norm passed to clip_grad_norm_
    seed = 42
    target_size = 5                   # number of classes (output width of the classifier)
    target_col = 'label'              # column holding the class label
    n_fold = 5                        # number of stratified folds
    trn_fold = [0, 2]                 # folds that are actually trained / used for inference
    train = True                      # run the training branch of main()
    inference = False                 # run the inference branch of main()
    
# Debug mode: shorten the run to 3 epochs and replace `train` with a
# 1000-row sample drawn with CFG.seed, so a full pass finishes quickly.
if CFG.debug:
    CFG.epochs = 3
    train = train.sample(n=1000, random_state=CFG.seed).reset_index(drop=True)

## Library

In [7]:
import sys
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')

import os
import math
import time
import random
import shutil
from pathlib import Path
from contextlib import contextmanager
from collections import defaultdict, Counter

import scipy as sp
import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold

from tqdm.auto import tqdm
from functools import partial

import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

sys.path.append('../input/bi-tempered-loss-pytorch')
from bi_tempered_loss import *

# sys.path.append('../input/pytorch-optimizer')
# import torch_optimizer as optim

sys.path.append('../input/pytorch-sam')
from sam import SAM

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import timm

import warnings 
warnings.filterwarnings('ignore')

if CFG.apex:
    from apex import amp

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Utils

In [8]:
def get_score(y_true, y_pred):
    """Return the competition metric (plain accuracy) for the given labels.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
    """
    accuracy = accuracy_score(y_true, y_pred)
    return accuracy

@contextmanager
def timer(name):
    """Context manager that logs the start and the elapsed seconds of a block.

    Args:
        name: Label inserted into both log lines.
    """
    started_at = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    elapsed = time.time() - started_at
    LOGGER.info(f'[{name}] done in {elapsed:.0f}')
    
def init_logger(log_file=OUTPUT_DIR+'train.log'):
    """Create the module logger that writes to both the console and `log_file`.

    Args:
        log_file: Path of the log file; defaults to ``train.log`` inside OUTPUT_DIR.

    Returns:
        The configured ``logging.Logger``.
    """
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # getLogger returns the same object on every call, so re-running this cell
    # would stack another pair of handlers and print every message several
    # times. Drop whatever handlers are already attached first.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = init_logger()

def seed_torch(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # The individual generators are independent of each other, so they can be
    # seeded in one pass.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed every RNG once up front with the configured seed.
seed_torch(seed=CFG.seed)

## CV split

In [9]:
# Work on a copy so the fold column is not added to `train` itself.
folds = train.copy()
Fold = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
# Tag each row with the index of the fold in which it is held out for validation.
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds[CFG.target_col])):
    folds.loc[val_index, 'fold'] = int(n)
# The column was created through .loc assignment above; cast it to int so it can be used as a fold id.
folds['fold'] = folds['fold'].astype(int)
# Per-fold class counts, to check that stratification kept the label distribution even.
print(folds.groupby(['fold', CFG.target_col]).size())

fold  label
0     0         299
      1         695
      2         604
      3        3092
      4         578
1     0         299
      1         695
      2         604
      3        3092
      4         578
2     0         298
      1         695
      2         603
      3        3093
      4         578
3     0         298
      1         695
      2         603
      3        3093
      4         578
4     0         298
      1         696
      2         603
      3        3092
      4         578
dtype: int64


## Dataset

In [10]:
class TrainDataset(Dataset):
    """Labelled image dataset backed by a DataFrame.

    Args:
        df: DataFrame with an ``image_id`` column (file names under TRAIN_PATH)
            and a ``label`` column (class index).
        transform: Optional callable invoked as ``transform(image=...)`` whose
            result exposes the transformed image under the ``'image'`` key.
    """
    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.labels = df['label'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row `idx`.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        file_name = self.file_names[idx]
        file_path = f'{TRAIN_PATH}/{file_name}'
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = torch.tensor(self.labels[idx]).long()
        return image, label
    
class TestDataset(Dataset):
    """Unlabelled image dataset backed by a DataFrame.

    Args:
        df: DataFrame with an ``image_id`` column (file names under TEST_PATH).
        transform: Optional callable invoked as ``transform(image=...)`` whose
            result exposes the transformed image under the ``'image'`` key.
    """
    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['image_id'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the image for row `idx`.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        file_name = self.file_names[idx]
        file_path = f'{TEST_PATH}/{file_name}'
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image

In [11]:
# train_dataset = TrainDataset(train, transform=None)

# for i in range(1):
#     image, label = train_dataset[i]
#     plt.imshow(image)
#     plt.title(f'label: {label}')
#     plt.show()

## Transforms

In [12]:
def get_transforms(*, data):
    """Build the albumentations pipeline for the requested split.

    Args:
        data: ``'train'`` for the augmenting pipeline, ``'valid'`` for the
            deterministic resize-and-normalize pipeline.

    Returns:
        A ``Compose`` pipeline ending in ``ToTensorV2``.

    Raises:
        ValueError: If `data` is neither ``'train'`` nor ``'valid'``.
    """
    if data == 'train':
        return Compose([
            RandomResizedCrop(CFG.size, CFG.size), 
            Transpose(p=0.5), 
            HorizontalFlip(p=0.5), 
            VerticalFlip(p=0.5), 
            ShiftScaleRotate(p=0.5), 
            # NOTE(review): shift limits of 0.2 look very small if the library
            # interprets them on an integer 0-255 scale — confirm intent.
            HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5), 
            RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5), 
            Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
            ), 
            ToTensorV2(),
        ])

    elif data == 'valid':
        return Compose([
            Resize(CFG.size, CFG.size), 
            Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
            ), 
            ToTensorV2(),
        ])

    # Falling through used to return None, which the datasets treat as
    # "no transform" and then hand raw arrays to the model.
    raise ValueError(f"Unknown data split {data!r}; expected 'train' or 'valid'")

In [13]:
# train_dataset = TrainDataset(train, transform=get_transforms(data='train'))

# for i in range(1):
#     image, label = train_dataset[i]
#     plt.imshow(image[0])
#     plt.title(f'label: {label}')
#     plt.show()

## MODEL

In [14]:
class CustomEfficientNetB2ns(nn.Module):
    """timm backbone whose classifier is replaced by a CFG.target_size-way linear head."""

    def __init__(self, model_name='tf_efficientnet_b2_ns', pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Swap the stock head for one sized to this task's number of classes.
        head_inputs = backbone.classifier.in_features
        backbone.classifier = nn.Linear(head_inputs, CFG.target_size)
        self.model = backbone

    def forward(self, x):
        """Return the backbone's output for the batch `x`."""
        return self.model(x)

In [15]:
# model = CustomEfficientNetB2ns(model_name=CFG.model_name, pretrained=False)
# train_dataset = TrainDataset(train, transform=get_transforms(data='train'))
# train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, 
#                           num_workers=4, pin_memory=True, drop_last=True)

# for image, label in train_loader:
#     print(image.size())
#     output = model(image)
#     print(output)
#     break

## Loss Functions

In [16]:
# ====================================================
# Label Smoothing
# ====================================================
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a smoothed target distribution.

    The true class receives ``1 - smoothing``; the remaining mass is spread
    evenly over the other ``classes - 1`` classes.
    """

    def __init__(self, classes=5, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy of logits `pred` for class indices `target`."""
        log_probs = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            off_target_mass = self.smoothing / (self.cls - 1)
            soft_targets = torch.full_like(log_probs, off_target_mass)
            soft_targets.scatter_(1, target.data.unsqueeze(1), self.confidence)
        per_sample = torch.sum(-soft_targets * log_probs, dim=self.dim)
        return torch.mean(per_sample)

In [17]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha * (1 - p_t) ** gamma * CE`` per sample.

    Args:
        alpha: Global scale applied to the loss.
        gamma: Focusing exponent; larger values down-weight easy samples more.
        reduce: If True return the batch mean, otherwise the per-sample losses.
    """
    def __init__(self, alpha=1, gamma=2, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, inputs, targets):
        """Compute the focal loss of logits `inputs` for class indices `targets`."""
        # The modulating factor must be computed per sample. The default
        # mean-reduced cross-entropy collapses the batch to one number first,
        # so every sample got the same weight and reduce=False returned a scalar.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        pt = torch.exp(-ce_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * ce_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

In [18]:
class FocalCosineLoss(nn.Module):
    """Cosine-embedding loss plus a down-weighted focal loss on normalized logits.

    Args:
        alpha: Scale of the focal term.
        gamma: Focusing exponent of the focal term.
        xent: Weight of the focal term relative to the cosine term.
    """
    def __init__(self, alpha=1, gamma=2, xent=.1):
        super(FocalCosineLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

        self.xent = xent

        # Target "+1" for the cosine embedding loss. It is created on the CPU
        # and moved next to the logits in forward(), so the module can be
        # constructed and used on machines without CUDA.
        self.y = torch.ones(1)

    def forward(self, input, target, reduction="mean"):
        """Compute the loss of logits `input` for class indices `target`.

        Args:
            input: Logits; the class axis is the last one.
            target: Integer class indices.
            reduction: ``'mean'``, ``'sum'`` or ``'none'``.
        """
        # cosine_embedding_loss compares two float tensors, so cast the one-hot target.
        onehot = F.one_hot(target, num_classes=input.size(-1)).to(input.dtype)
        # One "+1" per sample, on the same device and dtype as the logits.
        y = self.y.to(device=input.device, dtype=input.dtype).expand(input.size(0))
        cosine_loss = F.cosine_embedding_loss(input, onehot, y, reduction=reduction)

        cent_loss = F.cross_entropy(F.normalize(input), target, reduction='none')
        pt = torch.exp(-cent_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * cent_loss

        if reduction == "mean":
            focal_loss = torch.mean(focal_loss)
        elif reduction == "sum":
            # Keep the focal term's reduction in line with the cosine term.
            focal_loss = torch.sum(focal_loss)

        return cosine_loss + self.xent * focal_loss

In [19]:
class SymmetricCrossEntropy(nn.Module):
    """Weighted sum ``alpha * CE + beta * RCE`` over logits and integer targets.

    Args:
        alpha: Weight of the standard cross-entropy term.
        beta: Weight of the second ("reverse") term.
        num_classes: Number of classes used to one-hot encode the targets.
    """

    def __init__(self, alpha=0.1, beta=1.0, num_classes=5):
        super(SymmetricCrossEntropy, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.num_classes = num_classes

    def forward(self, logits, targets, reduction='mean'):
        """Compute the loss.

        Args:
            logits: Tensor of shape (batch, num_classes).
            targets: Integer class indices of shape (batch,).
            reduction: ``'mean'``, ``'sum'``, or anything else for per-sample values.
        """
        # Build the one-hot matrix on the targets' own device instead of
        # indexing a CPU identity matrix and forcing the result onto CUDA.
        onehot_targets = F.one_hot(targets, self.num_classes).to(logits.dtype)
        ce_loss = F.cross_entropy(logits, targets, reduction=reduction)
        # NOTE(review): as written this term is -onehot * log(softmax), i.e. the
        # per-sample cross-entropy again. Confirm whether the reverse term
        # (-softmax * log(clamped onehot)) was intended.
        rce_loss = (-onehot_targets*logits.softmax(1).clamp(1e-7, 1.0).log()).sum(1)
        if reduction == 'mean':
            rce_loss = rce_loss.mean()
        elif reduction == 'sum':
            rce_loss = rce_loss.sum()
        return self.alpha * ce_loss + self.beta * rce_loss

In [20]:
def log_t(u, t):
    """Tempered logarithm of tensor `u` at temperature `t` (plain log when t == 1)."""
    if t==1.0:
        return u.log()
    exponent = 1.0 - t
    return (u.pow(exponent) - 1.0) / exponent

def exp_t(u, t):
    """Tempered exponential of tensor `u` at temperature `t` (plain exp when t == 1)."""
    if t==1:
        return u.exp()
    # relu clamps the base at zero before the power is taken.
    base = (1.0 + (1.0-t)*u).relu()
    return base.pow(1.0 / (1.0 - t))

def compute_normalization_fixed_point(activations, t, num_iters):

    """Returns the normalization value for each example (t > 1.0).

    Runs `num_iters` rounds of a fixed-point iteration on the max-shifted
    activations, then converts the final partition sum into the constant.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same shape as activation with the last dimension being 1.
    """
    # Shift by the maximum along the last axis so the largest entry becomes 0.
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations_step_0 = activations - mu

    normalized_activations = normalized_activations_step_0

    # Each round rescales the *original* shifted activations by the current
    # partition sum raised to the power (1 - t).
    for _ in range(num_iters):
        logt_partition = torch.sum(
                exp_t(normalized_activations, t), -1, keepdim=True)
        normalized_activations = normalized_activations_step_0 * \
                logt_partition.pow(1.0-t)

    # Partition sum after the last round, turned into the constant and
    # un-shifted by adding mu back.
    logt_partition = torch.sum(
            exp_t(normalized_activations, t), -1, keepdim=True)
    normalization_constants = - log_t(1.0 / logt_partition, t) + mu

    return normalization_constants

def compute_normalization_binary_search(activations, t, num_iters):

    """Returns the normalization value for each example (t < 1.0).

    Bisects, for `num_iters` rounds, on the value subtracted from the
    max-shifted activations, driven by whether the tempered exponentials sum
    to less than 1.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (< 1.0 for finite support).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """

    # Shift by the maximum along the last axis so the largest entry becomes 0.
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations = activations - mu

    # Number of entries per example lying above -1 / (1 - t).
    effective_dim = \
        torch.sum(
                (normalized_activations > -1.0 / (1.0-t)).to(torch.int32),
            dim=-1, keepdim=True).to(activations.dtype)

    # Search interval: [0, -log_t(1 / effective_dim)], one interval per example.
    shape_partition = activations.shape[:-1] + (1,)
    lower = torch.zeros(shape_partition, dtype=activations.dtype, device=activations.device)
    upper = -log_t(1.0/effective_dim, t) * torch.ones_like(lower)

    for _ in range(num_iters):
        logt_partition = (upper + lower)/2.0
        sum_probs = torch.sum(
                exp_t(normalized_activations - logt_partition, t),
                dim=-1, keepdim=True)
        # update == 1 where the sum is below 1: keep `lower`, move `upper` to
        # the midpoint. update == 0 otherwise: move `lower` to the midpoint,
        # keep `upper`.
        update = (sum_probs < 1.0).to(activations.dtype)
        lower = torch.reshape(
                lower * update + (1.0-update) * logt_partition,
                shape_partition)
        upper = torch.reshape(
                upper * (1.0 - update) + update * logt_partition,
                shape_partition)

    # Midpoint of the final interval, un-shifted by adding mu back.
    logt_partition = (upper + lower)/2.0
    return logt_partition + mu

class ComputeNormalization(torch.autograd.Function):
    """
    Class implementing custom backward pass for compute_normalization. See compute_normalization.
    """
    @staticmethod
    def forward(ctx, activations, t, num_iters):
        """Compute the normalization constants and stash what backward needs.

        Uses the binary-search routine when t < 1.0 and the fixed-point
        routine otherwise.
        """
        if t < 1.0:
            normalization_constants = compute_normalization_binary_search(activations, t, num_iters)
        else:
            normalization_constants = compute_normalization_fixed_point(activations, t, num_iters)

        ctx.save_for_backward(activations, normalization_constants)
        ctx.t=t
        return normalization_constants

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to `activations`.

        The incoming gradient is weighted by the tempered probabilities raised
        to the power t and renormalized over the last axis. `t` and
        `num_iters` receive no gradient.
        """
        activations, normalization_constants = ctx.saved_tensors
        t = ctx.t
        normalized_activations = activations - normalization_constants 
        probabilities = exp_t(normalized_activations, t)
        escorts = probabilities.pow(t)
        escorts = escorts / escorts.sum(dim=-1, keepdim=True)
        grad_input = escorts * grad_output
        
        return grad_input, None, None

def compute_normalization(activations, t, num_iters=5):
    """Normalization constant per example, with a custom backward pass.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """
    normalization_constants = ComputeNormalization.apply(activations, t, num_iters)
    return normalization_constants

def tempered_sigmoid(activations, t, num_iters = 5):
    """Tempered sigmoid, computed as a two-class tempered softmax.

    Args:
      activations: Activations for the positive class for binary classification.
      t: Temperature tensor > 0.0.
      num_iters: Number of iterations to run the method.
    Returns:
      A probabilities tensor.
    """
    # Pair every activation with a zero activation for the negative class.
    negative_class = torch.zeros_like(activations)
    two_class_activations = torch.stack([activations, negative_class], dim=-1)
    two_class_probabilities = tempered_softmax(two_class_activations, t, num_iters)
    # Keep only the positive-class probability.
    return two_class_probabilities[..., 0]


def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax over the last axis.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature > 1.0.
      num_iters: Number of iterations to run the method.
    Returns:
      A probabilities tensor.
    """
    if t != 1.0:
        shift = compute_normalization(activations, t, num_iters)
        return exp_t(activations - shift, t)
    # t == 1.0 is the ordinary softmax.
    return activations.softmax(dim=-1)

def bi_tempered_binary_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing = 0.0,
        num_iters=5,
        reduction='mean'):
    """Bi-tempered logistic loss for binary problems.

    The binary case is expressed as a two-class problem and delegated to
    `bi_tempered_logistic_loss`.

    Args:
      activations: A tensor containing activations for class 1.
      labels: A tensor with shape as activations, containing probabilities for class 1
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing
      num_iters: Number of iterations to run the method.
    Returns:
      A loss tensor.
    """
    positive = labels.to(activations.dtype)
    # Class 1 keeps its activation and label; class 0 gets a zero activation
    # and the complementary label.
    two_class_activations = torch.stack(
        [activations, torch.zeros_like(activations)], dim=-1)
    two_class_labels = torch.stack([positive, 1.0 - positive], dim=-1)
    return bi_tempered_logistic_loss(
        two_class_activations,
        two_class_labels,
        t1,
        t2,
        label_smoothing=label_smoothing,
        num_iters=num_iters,
        reduction=reduction,
    )

def bi_tempered_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing=0.0,
        num_iters=5,
        reduction = 'mean'):

    """Bi-Tempered Logistic Loss.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot), 
        or a long tensor of one dimension less than activations (pytorch standard)
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over all examples; returns a scalar tensor.
        ``'sum'``: Loss is summed over all examples; returns a scalar tensor.
    Returns:
      A loss tensor.
    Raises:
      ValueError: If `reduction` is not one of the three supported values.
    """
    # Validate up front: an unknown value used to fall off the end of the
    # function and silently return None.
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(
            f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")

    if len(labels.shape)<len(activations.shape): #not one-hot
        labels_onehot = torch.zeros_like(activations)
        # The class axis is the last one, so scatter along -1. This equals the
        # previous hard-coded dim 1 for 2-D input and stays correct for
        # higher-rank activations.
        labels_onehot.scatter_(-1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        labels_onehot = ( 1 - label_smoothing * num_classes / (num_classes - 1) ) \
                * labels_onehot + \
                label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    # The 1e-10 keeps log_t away from exact zeros in the label tensor.
    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
            - labels_onehot * log_t(probabilities, t1) \
            - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
            + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim = -1) #sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    return loss_values.mean()

In [21]:
class BiTemperedLogisticLoss(nn.Module):
    """nn.Module wrapper around `bi_tempered_logistic_loss` returning the batch mean.

    Args:
        t1: Temperature 1 forwarded to the loss.
        t2: Temperature 2 forwarded to the loss.
        smoothing: Label-smoothing amount forwarded to the loss.
    """

    def __init__(self, t1, t2, smoothing=0.0):
        super(BiTemperedLogisticLoss, self).__init__()
        self.t1, self.t2 = t1, t2
        self.smoothing = smoothing

    def forward(self, logit_label, truth_label):
        """Return the mean bi-tempered loss of `logit_label` against `truth_label`."""
        per_sample = bi_tempered_logistic_loss(
            logit_label,
            truth_label,
            t1=self.t1,
            t2=self.t2,
            label_smoothing=self.smoothing,
            reduction='none',
        )
        return per_sample.mean()

In [22]:
class AverageMeter(object):
    """Tracks the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the running sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
        
def asMinutes(s):
    """Format a duration in seconds as ``'<minutes>m <seconds>s'``."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)

def timeSince(since, percent):
    """Report the elapsed time and estimate the time remaining.

    Parameters
    ----------
    since : float
        Timestamp (``time.time()``) at which the work started.
    percent : float
        Fraction of the work completed so far.

    Returns
    -------
    str
        ``'<elapsed> (remain <remaining>)'``, both formatted by `asMinutes`.
    """
    elapsed = time.time() - since
    # Extrapolate the total duration from the fraction done so far.
    projected_total = elapsed / percent
    remaining = projected_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))

def train_fn(train_loader, model, loss_train, loss_metric, optimizer, epoch, shechduler, device):
    """Run one training epoch with a two-step (SAM-style) optimizer.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` batches.
        model: Network being trained; switched to train mode here.
        loss_train: Loss that is back-propagated.
        loss_metric: Loss that is only recorded for logging and the return value.
        optimizer: Must expose ``first_step(zero_grad=...)`` and
            ``second_step(zero_grad=...)``.
        epoch: Zero-based epoch index, used only in the progress line.
        shechduler: Not used inside this function; the (misspelled) name is
            kept so existing calls keep working.
        device: Device the batches are moved to.

    Returns:
        Sample-weighted average of ``loss_metric`` over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    for step, (images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        y_preds = model(images)
        # The metric is only logged, so keep it out of the autograd graph.
        metric = loss_metric(y_preds.detach(), labels)
        loss = loss_train(y_preds, labels)
        # record loss
        losses.update(metric.item(), batch_size)
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        if CFG.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else: 
            loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            # The two optimizer steps belong together: second_step must only
            # run after first_step. Previously second_step ran on every batch
            # even when accumulation skipped first_step.
            optimizer.first_step(zero_grad=True)
            # Second forward/backward pass at the perturbed weights.
            # NOTE(review): this pass is not routed through amp.scale_loss and
            # only uses the current batch — confirm if apex or accumulation is used.
            loss_train(model(images), labels).backward()
            optimizer.second_step(zero_grad=True)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} ' 
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) ' 
                  'Grad: {grad_norm:.4f}  '
                  .format(epoch+1, step, len(train_loader), batch_time=batch_time, 
                          data_time=data_time, loss=losses, 
                          remain=timeSince(start, float(step+1)/len(train_loader)), 
                          grad_norm=grad_norm))
    return losses.avg

def valid_fn(valid_loader, model, loss_metric, device):
    """Evaluate `model` on the validation loader.

    Args:
        valid_loader: Iterable yielding ``(images, labels)`` batches.
        model: Network to evaluate; switched to eval mode here.
        loss_metric: Loss recorded for each batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, predictions)``: the sample-weighted average of
        ``loss_metric`` and the concatenated softmax probabilities as a NumPy array.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # switch to evaluation mode
    model.eval()
    preds = []
    start = end = time.time()
    for step, (images, labels) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        # compute loss (nothing here needs gradients)
        with torch.no_grad():
            y_preds = model(images)
            loss = loss_metric(y_preds, labels)
        losses.update(loss.item(), batch_size)
        # record the class probabilities
        preds.append(y_preds.softmax(1).to('cpu').numpy())
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  .format(
                   step, len(valid_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses,
                   remain=timeSince(start, float(step+1)/len(valid_loader)),
                   ))

    predictions = np.concatenate(preds)
    return losses.avg, predictions

def inference(model, states, test_loader, device):
    """Predict class probabilities for the test set, averaged over checkpoints.

    Args:
        model: Network whose weights are overwritten by each checkpoint in turn.
        states: Iterable of checkpoint dicts, each with the weights under ``'model'``.
        test_loader: Iterable yielding image batches.
        device: Device the model and batches are moved to.

    Returns:
        NumPy array of softmax probabilities, one row per test image, averaged
        over all checkpoints.
    """
    model.to(device)
    tk0 = tqdm(enumerate(test_loader), total=len(test_loader))
    probs = []
    for i, (images) in tk0:
        images = images.to(device)
        # One probability matrix per checkpoint for this batch. The list used
        # to be created as `avgpreds` but appended to as `avg_preds`, which
        # raised UnboundLocalError on the first batch.
        avg_preds = []
        for state in states:
            model.load_state_dict(state['model'])
            model.eval()
            with torch.no_grad():
                y_preds = model(images)
            avg_preds.append(y_preds.softmax(1).to('cpu').numpy())
        avg_preds = np.mean(avg_preds, axis=0)
        probs.append(avg_preds)
    probs = np.concatenate(probs)
    return probs

## Train loop

In [23]:
# ======================================================
# Train loop
# ======================================================

def train_loop(folds, fold):
    """Train one cross-validation fold and return its out-of-fold predictions.

    The backbone is frozen for the first epoch (only the classifier trains, at
    CFG.lr_1); from the second epoch on every layer is trainable and a fresh
    optimizer and scheduler are created at CFG.lr_2. A checkpoint is written
    after every epoch, plus a ``*_best.pth`` checkpoint whenever validation
    accuracy improves.

    Args:
        folds: DataFrame holding the image ids, the label column and a ``fold`` column.
        fold: Index of the fold held out for validation.

    Returns:
        The validation rows of `folds` with one probability column per class
        (named ``'0'`` .. ``str(CFG.target_size - 1)``) and a ``preds`` column
        holding the arg-max class, taken from the best checkpoint.

    Raises:
        ValueError: If CFG.scheduler or CFG.loss_train names an unsupported option.
    """
    seed_torch(seed=CFG.seed)

    LOGGER.info(f'========== fold: {fold} training ============')

    # ======================================================
    # loader
    # ======================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)

    train_dataset = TrainDataset(train_folds, 
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds, 
                                 transform=get_transforms(data='valid'))

    train_loader = DataLoader(train_dataset, 
                              batch_size=CFG.batch_size, 
                              shuffle=True, 
                              num_workers=CFG.num_workers, 
                              pin_memory=True, 
                              drop_last=False)
    valid_loader = DataLoader(valid_dataset, 
                              batch_size=CFG.batch_size, 
                              shuffle=False, 
                              num_workers=CFG.num_workers, 
                              pin_memory=True, 
                              drop_last=False)

    # ===============================================
    # scheduler
    # ===============================================
    def get_scheduler(optimizer):
        """Build the LR scheduler named by CFG.scheduler for `optimizer`."""
        if CFG.scheduler=='ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
        elif CFG.scheduler=='CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
        elif CFG.scheduler=='CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
        else:
            # Previously an unknown name ended in an UnboundLocalError.
            raise ValueError(f'Unsupported scheduler: {CFG.scheduler!r}')
        return scheduler

    # ===============================================
    # model & optimizer
    # ===============================================
    model = CustomEfficientNetB2ns(CFG.model_name, pretrained=True)

    # Freeze everything except the classifier; the layers are unfrozen again
    # at the start of the second epoch (see the epoch loop below).
    for name, param in model.model.named_parameters():
        if 'classifier' not in name:
            param.requires_grad=False

    model.to(device)

    base_optimizer = Adam
    optimizer = SAM(model.parameters(), base_optimizer, lr=CFG.lr_1, weight_decay=CFG.weight_decay, amsgrad=False)

    scheduler = get_scheduler(optimizer)

    # ===============================================
    # apex 
    # ===============================================
    if CFG.apex:
        # amp.initialize returns the wrapped (model, optimizer) pair; it was
        # previously assigned to `model.optimizer`, so neither name was updated.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    # ===============================================
    # loop
    # ===============================================
    def get_loss_train():
        """Build the training loss named by CFG.loss_train."""
        if CFG.loss_train == 'CrossEntropyLoss':
            loss_train = nn.CrossEntropyLoss()
        elif CFG.loss_train == 'LabelSmoothing':
            loss_train = LabelSmoothingLoss(classes=CFG.target_size, smoothing=CFG.smooth)
        elif CFG.loss_train == 'FocalLoss':
            loss_train = FocalLoss().to(device)
        elif CFG.loss_train == 'FocalCosineLoss':
            loss_train = FocalCosineLoss()
        elif CFG.loss_train == 'SymmetricCrossEntropyLoss':
            loss_train = SymmetricCrossEntropy().to(device)
        elif CFG.loss_train == 'BiTemperedLoss':
            loss_train = BiTemperedLogisticLoss(t1=CFG.t1, t2=CFG.t2, smoothing=CFG.smooth)
        else:
            # Previously an unknown name ended in an UnboundLocalError.
            raise ValueError(f'Unsupported loss_train: {CFG.loss_train!r}')
        return loss_train

    loss_train = get_loss_train()
    LOGGER.info(f'loss_train: {loss_train}')
    # Plain cross-entropy is what gets logged, whatever loss is optimized.
    loss_metric = nn.CrossEntropyLoss()

    best_score = 0.
    best_loss = np.inf

    for epoch in range(CFG.epochs):

        start_time = time.time()

        if epoch == 1:

            # From the second epoch on, unfreeze every weight.
            for param in model.model.parameters():
                param.requires_grad = True

            # Switch from CFG.lr_1 to CFG.lr_2 with a fresh optimizer and scheduler.
            base_optimizer = Adam
            optimizer = SAM(model.parameters(), base_optimizer, lr=CFG.lr_2, weight_decay=CFG.weight_decay, amsgrad=False)
            scheduler = get_scheduler(optimizer)

            LOGGER.info('requires_grad of all parameters are unlocked')


        # train
        avg_loss = train_fn(train_loader, model, loss_train, loss_metric, optimizer, epoch, scheduler, device)

        # eval
        avg_val_loss, preds = valid_fn(valid_loader, model, loss_metric, device)
        valid_labels = valid_folds[CFG.target_col].values

        # ReduceLROnPlateau needs the monitored value; the cosine schedulers do not.
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, (CosineAnnealingLR, CosineAnnealingWarmRestarts)):
            scheduler.step()

        # scoring
        score = get_score(valid_labels, preds.argmax(1))

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Accuracy: {score}')

        if score > best_score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(), 
                        'preds': preds}, 
                        OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth')

        # Keep every epoch's weights so any of them can be used for inference.
        torch.save({'model': model.state_dict()}, OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')

    check_point = torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth')
    # One probability column per class; driven by CFG.target_size instead of a literal 5.
    valid_folds[[str(c) for c in range(CFG.target_size)]] = check_point['preds']
    valid_folds['preds'] = check_point['preds'].argmax(1)

    return valid_folds

In [24]:
# ====================================================
# main
# ====================================================
def main():
    """Run the training and/or inference stages enabled in ``CFG``.

    Training (``CFG.train``): calls ``train_loop(folds, fold)`` for every fold
    in ``range(CFG.n_fold)`` that is listed in ``CFG.trn_fold``, logs the score
    of each fold and of the combined out-of-fold predictions, and writes the
    combined frame to ``OUTPUT_DIR + 'oof_df.csv'``.

    Inference (``CFG.inference``): loads the ``*_best.pth`` checkpoint of each
    fold in ``CFG.trn_fold``, predicts on ``test`` and writes
    ``OUTPUT_DIR + 'submission.csv'``.
    """

    def get_result(result_df):
        """Log the score of ``result_df['preds']`` against its target column."""
        preds = result_df['preds'].values
        labels = result_df[CFG.target_col].values
        score = get_score(labels, preds)
        LOGGER.info(f'Score: {score:<.5f}')

    if CFG.train:
        # train
        # Gather the per-fold frames and concatenate once at the end instead of
        # repeatedly concatenating onto an initially empty DataFrame.
        fold_results = []
        for fold in range(CFG.n_fold):
            if fold not in CFG.trn_fold:
                continue
            _oof_df = train_loop(folds, fold)
            fold_results.append(_oof_df)
            LOGGER.info(f'=============== fold: {fold} result ================')
            get_result(_oof_df)

        if fold_results:
            oof_df = pd.concat(fold_results)
            # CV result
            LOGGER.info(f'============ CV ============')
            get_result(oof_df)
            # save result
            oof_df.to_csv(OUTPUT_DIR+'oof_df.csv', index=False)
        else:
            # Nothing was trained, so there are no predictions to score or save.
            LOGGER.info('No fold in CFG.trn_fold was trained; skipping CV result')

    if CFG.inference:
        # inference
        model = CustomEfficientNetB2ns(CFG.model_name, pretrained=False)
        states = [torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth') for fold in CFG.trn_fold]
        # The loader options (batch_size/shuffle/pin_memory) belong to the
        # DataLoader, not the dataset, and `test_loader` must exist before it
        # is handed to `inference`.
        # NOTE(review): assumes `DataLoader` (torch.utils.data) and
        # `get_transforms(data='valid')` are defined earlier in the notebook and
        # that `TestDataset` accepts a `transform` keyword -- confirm.
        test_dataset = TestDataset(test, transform=get_transforms(data='valid'))
        test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False,
                                 num_workers=CFG.num_workers, pin_memory=True)
        predictions = inference(model, states, test_loader, device)
        # submission
        test['label'] = predictions.argmax(1)
        test[['image_id', 'label']].to_csv(OUTPUT_DIR+'submission.csv', index=False)

In [25]:
# Write the compute device selected for this run to the log.
LOGGER.info(f'used device: {device}')

used device: cuda


In [26]:
# Entry point: run the pipeline stages (training / inference) enabled in CFG.
if __name__ == '__main__':
    main()

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b2_ns-00306e48.pth
loss_train: BiTemperedLogisticLoss()


Epoch: [1][0/1317]Data 2.051 (2.051)Elapsed 0m 3s (remain 82m 2s)Loss: 1.6490(1.6490)Grad: 0.6519  
Epoch: [1][20/1317]Data 0.000 (0.101)Elapsed 0m 9s (remain 9m 24s)Loss: 1.3508(1.4228)Grad: 0.5230  
Epoch: [1][40/1317]Data 0.000 (0.052)Elapsed 0m 15s (remain 8m 1s)Loss: 1.5710(1.3855)Grad: 0.4310  
Epoch: [1][60/1317]Data 0.000 (0.036)Elapsed 0m 20s (remain 7m 10s)Loss: 0.8743(1.4140)Grad: 0.4157  
Epoch: [1][80/1317]Data 0.000 (0.027)Elapsed 0m 26s (remain 6m 49s)Loss: 1.0298(1.4214)Grad: 0.4736  
Epoch: [1][100/1317]Data 0.000 (0.027)Elapsed 0m 32s (remain 6m 35s)Loss: 0.5246(1.3964)Grad: 0.4674  
Epoch: [1][120/1317]Data 0.001 (0.025)Elapsed 0m 39s (remain 6m 32s)Loss: 1.9470(1.3889)Grad: 0.3552  
Epoch: [1][140/1317]Data 0.000 (0.027)Elapsed 0m 45s (remain 6m 22s)Loss: 1.8895(1.4030)Grad: 0.4727  
Epoch: [1][160/1317]Data 0.000 (0.025)Elapsed 0m 51s (remain 6m 12s)Loss: 2.1853(1.3904)Grad: 0.4707  
Epoch: [1][180/1317]Data 0.010 (0.026)Elapsed 0m 58s (remain 6m 5s)Loss: 1.1053(1.

Epoch 1 - avg_train_loss: 1.0561 avg_val_loss: 0.8585 time: 486s
Epoch 1 - Accuracy: 0.7201974183750949
Epoch 1 - Save Best Score: 0.7202 Model


EVAL: [329/330] Data 0.000 (0.132) Elapsed 1m 14s (remain 0m 0s) Loss: 0.4515(0.8585) 


requires_grad of all parameters are unlocked


Epoch: [2][0/1317]Data 1.209 (1.209)Elapsed 0m 2s (remain 54m 37s)Loss: 1.0423(1.0423)Grad: 2.8351  
Epoch: [2][20/1317]Data 0.000 (0.058)Elapsed 0m 16s (remain 17m 21s)Loss: 1.6622(1.0624)Grad: 3.4300  
Epoch: [2][40/1317]Data 0.000 (0.030)Elapsed 0m 31s (remain 16m 16s)Loss: 1.3478(0.9950)Grad: 2.4754  
Epoch: [2][60/1317]Data 0.000 (0.020)Elapsed 0m 45s (remain 15m 43s)Loss: 0.6971(1.0044)Grad: 1.8901  
Epoch: [2][80/1317]Data 0.000 (0.015)Elapsed 1m 0s (remain 15m 22s)Loss: 0.3124(0.9682)Grad: 2.5805  
Epoch: [2][100/1317]Data 0.000 (0.012)Elapsed 1m 14s (remain 15m 0s)Loss: 0.6269(0.9181)Grad: 1.6354  
Epoch: [2][120/1317]Data 0.000 (0.010)Elapsed 1m 29s (remain 14m 41s)Loss: 0.5956(0.8888)Grad: 2.1576  
Epoch: [2][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 24s)Loss: 1.3247(0.8682)Grad: 1.8946  
Epoch: [2][160/1317]Data 0.000 (0.008)Elapsed 1m 57s (remain 14m 7s)Loss: 0.4398(0.8433)Grad: 2.1300  
Epoch: [2][180/1317]Data 0.000 (0.007)Elapsed 2m 12s (remain 13m 52s)Loss:

Epoch 2 - avg_train_loss: 0.6941 avg_val_loss: 0.5476 time: 1023s
Epoch 2 - Accuracy: 0.8559225512528473
Epoch 2 - Save Best Score: 0.8559 Model


EVAL: [329/330] Data 0.000 (0.117) Elapsed 1m 10s (remain 0m 0s) Loss: 0.0192(0.5476) 
Epoch: [3][0/1317]Data 1.219 (1.219)Elapsed 0m 2s (remain 50m 43s)Loss: 0.5405(0.5405)Grad: 2.0238  
Epoch: [3][20/1317]Data 0.000 (0.058)Elapsed 0m 16s (remain 17m 15s)Loss: 1.0601(0.6228)Grad: 1.7210  
Epoch: [3][40/1317]Data 0.000 (0.030)Elapsed 0m 31s (remain 16m 8s)Loss: 1.4291(0.5643)Grad: 2.1532  
Epoch: [3][60/1317]Data 0.000 (0.020)Elapsed 0m 45s (remain 15m 40s)Loss: 0.1261(0.6548)Grad: 0.9231  
Epoch: [3][80/1317]Data 0.000 (0.015)Elapsed 0m 59s (remain 15m 15s)Loss: 1.0948(0.6432)Grad: 1.6315  
Epoch: [3][100/1317]Data 0.000 (0.012)Elapsed 1m 14s (remain 14m 54s)Loss: 0.1590(0.6531)Grad: 1.7971  
Epoch: [3][120/1317]Data 0.000 (0.010)Elapsed 1m 28s (remain 14m 38s)Loss: 0.0731(0.6359)Grad: 1.2223  
Epoch: [3][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 21s)Loss: 0.5502(0.6384)Grad: 1.7403  
Epoch: [3][160/1317]Data 0.000 (0.008)Elapsed 1m 57s (remain 14m 4s)Loss: 0.3792(0.6135)G

Epoch 3 - avg_train_loss: 0.6065 avg_val_loss: 0.5380 time: 1021s
Epoch 3 - Accuracy: 0.8722475322703113
Epoch 3 - Save Best Score: 0.8722 Model


EVAL: [329/330] Data 0.000 (0.115) Elapsed 1m 9s (remain 0m 0s) Loss: 0.0075(0.5380) 
Epoch: [4][0/1317]Data 1.468 (1.468)Elapsed 0m 2s (remain 65m 43s)Loss: 0.1839(0.1839)Grad: 1.1179  
Epoch: [4][20/1317]Data 0.000 (0.070)Elapsed 0m 17s (remain 18m 3s)Loss: 0.8936(0.5109)Grad: 1.6150  
Epoch: [4][40/1317]Data 0.000 (0.036)Elapsed 0m 31s (remain 16m 33s)Loss: 0.3296(0.5536)Grad: 2.0263  
Epoch: [4][60/1317]Data 0.000 (0.024)Elapsed 0m 46s (remain 15m 56s)Loss: 0.2212(0.5561)Grad: 1.2609  
Epoch: [4][80/1317]Data 0.000 (0.018)Elapsed 1m 0s (remain 15m 27s)Loss: 0.1875(0.5419)Grad: 1.7527  
Epoch: [4][100/1317]Data 0.000 (0.015)Elapsed 1m 15s (remain 15m 6s)Loss: 0.4302(0.5201)Grad: 1.3373  
Epoch: [4][120/1317]Data 0.000 (0.012)Elapsed 1m 29s (remain 14m 46s)Loss: 0.3563(0.5412)Grad: 0.6299  
Epoch: [4][140/1317]Data 0.000 (0.011)Elapsed 1m 44s (remain 14m 28s)Loss: 0.5577(0.5489)Grad: 1.3837  
Epoch: [4][160/1317]Data 0.000 (0.009)Elapsed 1m 58s (remain 14m 11s)Loss: 0.2650(0.5751)Gra

Epoch 4 - avg_train_loss: 0.5803 avg_val_loss: 0.5541 time: 1024s
Epoch 4 - Accuracy: 0.873006833712984
Epoch 4 - Save Best Score: 0.8730 Model


EVAL: [329/330] Data 0.000 (0.118) Elapsed 1m 10s (remain 0m 0s) Loss: 0.0051(0.5541) 
Epoch: [5][0/1317]Data 1.323 (1.323)Elapsed 0m 2s (remain 51m 35s)Loss: 0.0157(0.0157)Grad: 0.5198  
Epoch: [5][20/1317]Data 0.000 (0.063)Elapsed 0m 16s (remain 17m 19s)Loss: 0.2814(0.6444)Grad: 1.1891  
Epoch: [5][40/1317]Data 0.000 (0.032)Elapsed 0m 31s (remain 16m 15s)Loss: 0.2902(0.5806)Grad: 1.5265  
Epoch: [5][60/1317]Data 0.000 (0.022)Elapsed 0m 45s (remain 15m 43s)Loss: 0.1772(0.5097)Grad: 0.9567  
Epoch: [5][80/1317]Data 0.000 (0.016)Elapsed 1m 0s (remain 15m 20s)Loss: 0.0487(0.5360)Grad: 1.0419  
Epoch: [5][100/1317]Data 0.000 (0.013)Elapsed 1m 14s (remain 14m 59s)Loss: 0.6039(0.5987)Grad: 1.4003  
Epoch: [5][120/1317]Data 0.000 (0.011)Elapsed 1m 29s (remain 14m 41s)Loss: 0.6566(0.5801)Grad: 1.6765  
Epoch: [5][140/1317]Data 0.000 (0.010)Elapsed 1m 43s (remain 14m 23s)Loss: 0.8976(0.6035)Grad: 0.8069  
Epoch: [5][160/1317]Data 0.000 (0.008)Elapsed 1m 57s (remain 14m 7s)Loss: 1.3745(0.5984)G

Epoch 5 - avg_train_loss: 0.5610 avg_val_loss: 0.5303 time: 1024s
Epoch 5 - Accuracy: 0.8777524677296887
Epoch 5 - Save Best Score: 0.8778 Model


EVAL: [329/330] Data 0.000 (0.119) Elapsed 1m 11s (remain 0m 0s) Loss: 0.0018(0.5303) 
Epoch: [6][0/1317]Data 1.278 (1.278)Elapsed 0m 2s (remain 49m 56s)Loss: 0.0707(0.0707)Grad: 0.9670  
Epoch: [6][20/1317]Data 0.000 (0.061)Elapsed 0m 16s (remain 17m 8s)Loss: 0.3713(0.3996)Grad: 0.7168  
Epoch: [6][40/1317]Data 0.000 (0.031)Elapsed 0m 31s (remain 16m 10s)Loss: 0.8530(0.4976)Grad: 1.9020  
Epoch: [6][60/1317]Data 0.000 (0.021)Elapsed 0m 45s (remain 15m 39s)Loss: 0.2841(0.5179)Grad: 1.8207  
Epoch: [6][80/1317]Data 0.000 (0.016)Elapsed 1m 0s (remain 15m 16s)Loss: 0.1294(0.5252)Grad: 1.4789  
Epoch: [6][100/1317]Data 0.000 (0.013)Elapsed 1m 14s (remain 14m 57s)Loss: 0.6526(0.5430)Grad: 2.1945  
Epoch: [6][120/1317]Data 0.000 (0.011)Elapsed 1m 28s (remain 14m 39s)Loss: 0.0379(0.5271)Grad: 1.1804  
Epoch: [6][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 22s)Loss: 0.1524(0.5578)Grad: 1.0707  
Epoch: [6][160/1317]Data 0.000 (0.008)Elapsed 1m 57s (remain 14m 5s)Loss: 0.8865(0.5494)Gr

Epoch 6 - avg_train_loss: 0.5391 avg_val_loss: 0.5445 time: 1023s
Epoch 6 - Accuracy: 0.8775626423690205


EVAL: [329/330] Data 0.000 (0.119) Elapsed 1m 10s (remain 0m 0s) Loss: 0.0081(0.5445) 
Epoch: [7][0/1317]Data 1.062 (1.062)Elapsed 0m 2s (remain 49m 0s)Loss: 0.1598(0.1598)Grad: 0.7708  
Epoch: [7][20/1317]Data 0.000 (0.051)Elapsed 0m 16s (remain 17m 22s)Loss: 0.1260(0.3371)Grad: 0.9602  
Epoch: [7][40/1317]Data 0.000 (0.026)Elapsed 0m 31s (remain 16m 13s)Loss: 0.9872(0.3696)Grad: 1.8911  
Epoch: [7][60/1317]Data 0.000 (0.018)Elapsed 0m 45s (remain 15m 46s)Loss: 0.2726(0.3979)Grad: 1.0648  
Epoch: [7][80/1317]Data 0.000 (0.013)Elapsed 1m 0s (remain 15m 19s)Loss: 0.9465(0.3954)Grad: 1.3820  
Epoch: [7][100/1317]Data 0.000 (0.011)Elapsed 1m 14s (remain 14m 58s)Loss: 0.4414(0.4521)Grad: 0.8257  
Epoch: [7][120/1317]Data 0.000 (0.009)Elapsed 1m 29s (remain 14m 43s)Loss: 0.8913(0.4568)Grad: 1.1593  
Epoch: [7][140/1317]Data 0.000 (0.008)Elapsed 1m 43s (remain 14m 25s)Loss: 0.0515(0.4398)Grad: 0.6600  
Epoch: [7][160/1317]Data 0.000 (0.007)Elapsed 1m 58s (remain 14m 9s)Loss: 0.8651(0.4508)Gr

Epoch 7 - avg_train_loss: 0.5287 avg_val_loss: 0.5462 time: 1023s
Epoch 7 - Accuracy: 0.8779422930903569
Epoch 7 - Save Best Score: 0.8779 Model


EVAL: [329/330] Data 0.000 (0.119) Elapsed 1m 10s (remain 0m 0s) Loss: 0.0030(0.5462) 
Epoch: [8][0/1317]Data 1.135 (1.135)Elapsed 0m 2s (remain 49m 7s)Loss: 0.1874(0.1874)Grad: 1.4590  
Epoch: [8][20/1317]Data 0.000 (0.054)Elapsed 0m 16s (remain 16m 59s)Loss: 0.2893(0.4315)Grad: 1.6848  
Epoch: [8][40/1317]Data 0.000 (0.028)Elapsed 0m 30s (remain 16m 4s)Loss: 0.5805(0.4414)Grad: 1.2283  
Epoch: [8][60/1317]Data 0.000 (0.019)Elapsed 0m 45s (remain 15m 38s)Loss: 1.6922(0.4647)Grad: 2.1380  
Epoch: [8][80/1317]Data 0.000 (0.014)Elapsed 0m 59s (remain 15m 14s)Loss: 0.0557(0.5227)Grad: 0.9320  
Epoch: [8][100/1317]Data 0.000 (0.011)Elapsed 1m 14s (remain 14m 56s)Loss: 0.5669(0.5019)Grad: 1.0199  
Epoch: [8][120/1317]Data 0.000 (0.010)Elapsed 1m 28s (remain 14m 37s)Loss: 0.6031(0.4941)Grad: 1.2637  
Epoch: [8][140/1317]Data 0.000 (0.008)Elapsed 1m 43s (remain 14m 21s)Loss: 0.0430(0.4975)Grad: 0.7004  
Epoch: [8][160/1317]Data 0.000 (0.007)Elapsed 1m 57s (remain 14m 4s)Loss: 0.6454(0.5090)Gr

Epoch 8 - avg_train_loss: 0.5168 avg_val_loss: 0.5032 time: 1023s
Epoch 8 - Accuracy: 0.8845861807137434
Epoch 8 - Save Best Score: 0.8846 Model


EVAL: [329/330] Data 0.000 (0.118) Elapsed 1m 11s (remain 0m 0s) Loss: 0.0035(0.5032) 
Epoch: [9][0/1317]Data 1.257 (1.257)Elapsed 0m 2s (remain 49m 24s)Loss: 0.6662(0.6662)Grad: 1.6702  
Epoch: [9][20/1317]Data 0.000 (0.060)Elapsed 0m 16s (remain 17m 15s)Loss: 0.5523(0.5945)Grad: 1.4514  
Epoch: [9][40/1317]Data 0.000 (0.031)Elapsed 0m 31s (remain 16m 7s)Loss: 1.0606(0.5838)Grad: 1.3508  
Epoch: [9][60/1317]Data 0.000 (0.021)Elapsed 0m 45s (remain 15m 40s)Loss: 0.8037(0.5211)Grad: 1.3202  
Epoch: [9][80/1317]Data 0.000 (0.016)Elapsed 0m 59s (remain 15m 15s)Loss: 1.1288(0.5066)Grad: 1.7666  
Epoch: [9][100/1317]Data 0.000 (0.013)Elapsed 1m 14s (remain 14m 54s)Loss: 0.1645(0.5211)Grad: 1.0749  
Epoch: [9][120/1317]Data 0.000 (0.011)Elapsed 1m 28s (remain 14m 38s)Loss: 1.0556(0.5380)Grad: 1.0225  
Epoch: [9][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 20s)Loss: 0.0309(0.5223)Grad: 0.7374  
Epoch: [9][160/1317]Data 0.000 (0.008)Elapsed 1m 57s (remain 14m 5s)Loss: 0.7796(0.5216)G

Epoch 9 - avg_train_loss: 0.5095 avg_val_loss: 0.5258 time: 1023s
Epoch 9 - Accuracy: 0.8834472285497342


EVAL: [329/330] Data 0.000 (0.118) Elapsed 1m 10s (remain 0m 0s) Loss: 0.0016(0.5258) 
Epoch: [10][0/1317]Data 1.276 (1.276)Elapsed 0m 2s (remain 49m 39s)Loss: 0.6106(0.6106)Grad: 0.6086  
Epoch: [10][20/1317]Data 0.000 (0.061)Elapsed 0m 16s (remain 17m 7s)Loss: 0.2735(0.4108)Grad: 1.6551  
Epoch: [10][40/1317]Data 0.000 (0.031)Elapsed 0m 31s (remain 16m 7s)Loss: 0.3644(0.3985)Grad: 1.3562  
Epoch: [10][60/1317]Data 0.000 (0.021)Elapsed 0m 45s (remain 15m 36s)Loss: 0.5165(0.4535)Grad: 1.6259  
Epoch: [10][80/1317]Data 0.000 (0.016)Elapsed 0m 59s (remain 15m 14s)Loss: 0.0939(0.4656)Grad: 0.9998  
Epoch: [10][100/1317]Data 0.000 (0.013)Elapsed 1m 14s (remain 14m 53s)Loss: 0.0364(0.4659)Grad: 0.7087  
Epoch: [10][120/1317]Data 0.000 (0.011)Elapsed 1m 28s (remain 14m 38s)Loss: 0.1194(0.4497)Grad: 0.6898  
Epoch: [10][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 20s)Loss: 0.1334(0.4665)Grad: 1.1596  
Epoch: [10][160/1317]Data 0.000 (0.008)Elapsed 1m 57s (remain 14m 3s)Loss: 0.3042(

Epoch 10 - avg_train_loss: 0.4934 avg_val_loss: 0.5300 time: 1023s
Epoch 10 - Accuracy: 0.8834472285497342


EVAL: [329/330] Data 0.000 (0.120) Elapsed 1m 11s (remain 0m 0s) Loss: 0.0010(0.5300) 


Score: 0.88459
loss_train: BiTemperedLogisticLoss()


Epoch: [1][0/1317]Data 1.233 (1.233)Elapsed 0m 1s (remain 33m 3s)Loss: 1.6439(1.6439)Grad: 0.6666  
Epoch: [1][20/1317]Data 0.005 (0.064)Elapsed 0m 7s (remain 7m 46s)Loss: 1.3754(1.4283)Grad: 0.5217  
Epoch: [1][40/1317]Data 0.000 (0.034)Elapsed 0m 13s (remain 7m 5s)Loss: 1.3934(1.4033)Grad: 0.4839  
Epoch: [1][60/1317]Data 0.001 (0.023)Elapsed 0m 19s (remain 6m 34s)Loss: 0.8121(1.3831)Grad: 0.4426  
Epoch: [1][80/1317]Data 0.005 (0.019)Elapsed 0m 26s (remain 6m 37s)Loss: 0.5829(1.3983)Grad: 0.4533  
Epoch: [1][100/1317]Data 0.003 (0.015)Elapsed 0m 31s (remain 6m 24s)Loss: 1.6178(1.3843)Grad: 0.4220  
Epoch: [1][120/1317]Data 0.004 (0.013)Elapsed 0m 37s (remain 6m 10s)Loss: 1.9615(1.3643)Grad: 0.4048  
Epoch: [1][140/1317]Data 0.001 (0.012)Elapsed 0m 43s (remain 6m 0s)Loss: 1.0626(1.3544)Grad: 0.3376  
Epoch: [1][160/1317]Data 0.001 (0.010)Elapsed 0m 48s (remain 5m 50s)Loss: 1.7247(1.3698)Grad: 0.3753  
Epoch: [1][180/1317]Data 0.000 (0.009)Elapsed 0m 54s (remain 5m 43s)Loss: 2.0363(1.

Epoch 1 - avg_train_loss: 1.0645 avg_val_loss: 0.8562 time: 456s
Epoch 1 - Accuracy: 0.7222327700778431
Epoch 1 - Save Best Score: 0.7222 Model


EVAL: [329/330] Data 0.000 (0.105) Elapsed 1m 7s (remain 0m 0s) Loss: 0.1007(0.8562) 


requires_grad of all parameters are unlocked


Epoch: [2][0/1317]Data 1.307 (1.307)Elapsed 0m 2s (remain 50m 2s)Loss: 0.9578(0.9578)Grad: 2.9754  
Epoch: [2][20/1317]Data 0.000 (0.062)Elapsed 0m 16s (remain 17m 16s)Loss: 0.7383(0.7751)Grad: 2.4965  
Epoch: [2][40/1317]Data 0.000 (0.032)Elapsed 0m 31s (remain 16m 12s)Loss: 1.0386(0.8451)Grad: 2.7628  
Epoch: [2][60/1317]Data 0.000 (0.022)Elapsed 0m 45s (remain 15m 39s)Loss: 1.6557(0.8776)Grad: 2.3663  
Epoch: [2][80/1317]Data 0.000 (0.016)Elapsed 1m 0s (remain 15m 18s)Loss: 2.0192(0.8799)Grad: 3.3416  
Epoch: [2][100/1317]Data 0.000 (0.013)Elapsed 1m 14s (remain 14m 58s)Loss: 0.3124(0.8549)Grad: 2.0318  
Epoch: [2][120/1317]Data 0.000 (0.011)Elapsed 1m 29s (remain 14m 40s)Loss: 0.7224(0.8559)Grad: 2.5422  
Epoch: [2][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 22s)Loss: 1.0370(0.8498)Grad: 2.0316  
Epoch: [2][160/1317]Data 0.000 (0.008)Elapsed 1m 57s (remain 14m 6s)Loss: 0.6045(0.8398)Grad: 1.8060  
Epoch: [2][180/1317]Data 0.000 (0.007)Elapsed 2m 12s (remain 13m 51s)Loss:

Epoch 2 - avg_train_loss: 0.6998 avg_val_loss: 0.5673 time: 1020s
Epoch 2 - Accuracy: 0.8576039491171444
Epoch 2 - Save Best Score: 0.8576 Model


EVAL: [329/330] Data 0.000 (0.105) Elapsed 1m 6s (remain 0m 0s) Loss: 0.0002(0.5673) 
Epoch: [3][0/1317]Data 1.056 (1.056)Elapsed 0m 2s (remain 48m 9s)Loss: 0.4859(0.4859)Grad: 1.0691  
Epoch: [3][20/1317]Data 0.000 (0.050)Elapsed 0m 16s (remain 17m 4s)Loss: 1.1466(0.4661)Grad: 1.8928  
Epoch: [3][40/1317]Data 0.000 (0.026)Elapsed 0m 31s (remain 16m 13s)Loss: 0.8037(0.5062)Grad: 1.8311  
Epoch: [3][60/1317]Data 0.000 (0.017)Elapsed 0m 45s (remain 15m 39s)Loss: 0.6977(0.6084)Grad: 1.6056  
Epoch: [3][80/1317]Data 0.000 (0.013)Elapsed 0m 59s (remain 15m 14s)Loss: 0.6605(0.6484)Grad: 2.0036  
Epoch: [3][100/1317]Data 0.000 (0.011)Elapsed 1m 14s (remain 14m 59s)Loss: 0.8501(0.6586)Grad: 2.0408  
Epoch: [3][120/1317]Data 0.000 (0.009)Elapsed 1m 29s (remain 14m 40s)Loss: 0.8243(0.6605)Grad: 1.9044  
Epoch: [3][140/1317]Data 0.000 (0.008)Elapsed 1m 43s (remain 14m 24s)Loss: 2.5770(0.6806)Grad: 1.3092  
Epoch: [3][160/1317]Data 0.000 (0.007)Elapsed 1m 58s (remain 14m 7s)Loss: 0.0947(0.6578)Gra

Epoch 3 - avg_train_loss: 0.6061 avg_val_loss: 0.5546 time: 1020s
Epoch 3 - Accuracy: 0.8746914752230871
Epoch 3 - Save Best Score: 0.8747 Model


EVAL: [329/330] Data 0.000 (0.108) Elapsed 1m 7s (remain 0m 0s) Loss: 0.0001(0.5546) 
Epoch: [4][0/1317]Data 1.249 (1.249)Elapsed 0m 2s (remain 49m 26s)Loss: 0.4113(0.4113)Grad: 1.7427  
Epoch: [4][20/1317]Data 0.000 (0.060)Elapsed 0m 16s (remain 17m 19s)Loss: 0.0952(0.5119)Grad: 1.1430  
Epoch: [4][40/1317]Data 0.000 (0.031)Elapsed 0m 31s (remain 16m 10s)Loss: 0.4492(0.5445)Grad: 1.1026  
Epoch: [4][60/1317]Data 0.000 (0.021)Elapsed 0m 45s (remain 15m 45s)Loss: 0.6545(0.5717)Grad: 1.5265  
Epoch: [4][80/1317]Data 0.000 (0.016)Elapsed 1m 0s (remain 15m 19s)Loss: 0.3070(0.5406)Grad: 1.0450  
Epoch: [4][100/1317]Data 0.000 (0.013)Elapsed 1m 14s (remain 14m 59s)Loss: 0.4530(0.5207)Grad: 1.0444  
Epoch: [4][120/1317]Data 0.000 (0.010)Elapsed 1m 29s (remain 14m 41s)Loss: 0.1452(0.5301)Grad: 1.5257  
Epoch: [4][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 23s)Loss: 0.2410(0.5521)Grad: 1.6865  
Epoch: [4][160/1317]Data 0.000 (0.008)Elapsed 1m 58s (remain 14m 7s)Loss: 0.1376(0.5741)Gr

Epoch 4 - avg_train_loss: 0.5848 avg_val_loss: 0.5587 time: 1022s
Epoch 4 - Accuracy: 0.87450161382191


EVAL: [329/330] Data 0.000 (0.110) Elapsed 1m 8s (remain 0m 0s) Loss: 0.0000(0.5587) 
Epoch: [5][0/1317]Data 1.258 (1.258)Elapsed 0m 2s (remain 50m 11s)Loss: 1.0731(1.0731)Grad: 1.4749  
Epoch: [5][20/1317]Data 0.000 (0.060)Elapsed 0m 16s (remain 17m 12s)Loss: 0.1240(0.6330)Grad: 0.6912  
Epoch: [5][40/1317]Data 0.000 (0.031)Elapsed 0m 31s (remain 16m 11s)Loss: 0.2236(0.5845)Grad: 0.8951  
Epoch: [5][60/1317]Data 0.000 (0.021)Elapsed 0m 45s (remain 15m 40s)Loss: 0.3602(0.5566)Grad: 0.7087  
Epoch: [5][80/1317]Data 0.000 (0.016)Elapsed 1m 0s (remain 15m 18s)Loss: 1.5584(0.5517)Grad: 1.5320  
Epoch: [5][100/1317]Data 0.000 (0.013)Elapsed 1m 14s (remain 14m 57s)Loss: 0.0981(0.5722)Grad: 0.6330  
Epoch: [5][120/1317]Data 0.000 (0.011)Elapsed 1m 29s (remain 14m 41s)Loss: 0.4388(0.5922)Grad: 1.9392  
Epoch: [5][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 23s)Loss: 0.2874(0.5626)Grad: 1.6966  
Epoch: [5][160/1317]Data 0.000 (0.008)Elapsed 1m 58s (remain 14m 7s)Loss: 0.8674(0.5557)Gr

Epoch 5 - avg_train_loss: 0.5602 avg_val_loss: 0.5215 time: 1020s
Epoch 5 - Accuracy: 0.8767799506360356
Epoch 5 - Save Best Score: 0.8768 Model


EVAL: [329/330] Data 0.000 (0.104) Elapsed 1m 7s (remain 0m 0s) Loss: 0.0000(0.5215) 
Epoch: [6][0/1317]Data 1.261 (1.261)Elapsed 0m 2s (remain 52m 45s)Loss: 1.2576(1.2576)Grad: 0.7774  
Epoch: [6][20/1317]Data 0.000 (0.060)Elapsed 0m 17s (remain 17m 34s)Loss: 0.7338(0.6741)Grad: 1.0532  
Epoch: [6][40/1317]Data 0.000 (0.031)Elapsed 0m 31s (remain 16m 19s)Loss: 0.1951(0.5455)Grad: 1.7367  
Epoch: [6][60/1317]Data 0.000 (0.021)Elapsed 0m 46s (remain 15m 47s)Loss: 1.7303(0.5539)Grad: 2.2205  
Epoch: [6][80/1317]Data 0.000 (0.016)Elapsed 1m 0s (remain 15m 22s)Loss: 0.0748(0.5575)Grad: 0.9821  
Epoch: [6][100/1317]Data 0.000 (0.013)Elapsed 1m 14s (remain 15m 0s)Loss: 1.0181(0.5716)Grad: 1.2280  
Epoch: [6][120/1317]Data 0.000 (0.011)Elapsed 1m 29s (remain 14m 43s)Loss: 1.0507(0.5809)Grad: 0.9709  
Epoch: [6][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 25s)Loss: 0.3566(0.5779)Grad: 1.5135  
Epoch: [6][160/1317]Data 0.000 (0.008)Elapsed 1m 58s (remain 14m 10s)Loss: 0.8509(0.5731)Gr

Epoch 6 - avg_train_loss: 0.5434 avg_val_loss: 0.5464 time: 1022s
Epoch 6 - Accuracy: 0.8822859312701727
Epoch 6 - Save Best Score: 0.8823 Model


EVAL: [329/330] Data 0.000 (0.107) Elapsed 1m 7s (remain 0m 0s) Loss: 0.0000(0.5464) 
Epoch: [7][0/1317]Data 1.194 (1.194)Elapsed 0m 2s (remain 49m 4s)Loss: 0.3060(0.3060)Grad: 1.3431  
Epoch: [7][20/1317]Data 0.000 (0.057)Elapsed 0m 16s (remain 17m 12s)Loss: 0.0465(0.2940)Grad: 0.7991  
Epoch: [7][40/1317]Data 0.000 (0.029)Elapsed 0m 31s (remain 16m 12s)Loss: 0.7603(0.4931)Grad: 1.1538  
Epoch: [7][60/1317]Data 0.000 (0.020)Elapsed 0m 45s (remain 15m 40s)Loss: 0.1730(0.5543)Grad: 1.3681  
Epoch: [7][80/1317]Data 0.000 (0.015)Elapsed 1m 0s (remain 15m 19s)Loss: 0.0038(0.5513)Grad: 0.2360  
Epoch: [7][100/1317]Data 0.000 (0.012)Elapsed 1m 14s (remain 14m 57s)Loss: 1.2401(0.5630)Grad: 1.2543  
Epoch: [7][120/1317]Data 0.000 (0.010)Elapsed 1m 28s (remain 14m 39s)Loss: 0.4415(0.6092)Grad: 0.9724  
Epoch: [7][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 22s)Loss: 0.8134(0.5998)Grad: 1.4776  
Epoch: [7][160/1317]Data 0.000 (0.008)Elapsed 1m 57s (remain 14m 5s)Loss: 0.7851(0.5963)Gra

Epoch 7 - avg_train_loss: 0.5244 avg_val_loss: 0.5506 time: 1020s
Epoch 7 - Accuracy: 0.8855135750901841
Epoch 7 - Save Best Score: 0.8855 Model


EVAL: [329/330] Data 0.000 (0.106) Elapsed 1m 7s (remain 0m 0s) Loss: 0.0000(0.5506) 
Epoch: [8][0/1317]Data 1.199 (1.199)Elapsed 0m 2s (remain 49m 47s)Loss: 0.6560(0.6560)Grad: 1.0933  
Epoch: [8][20/1317]Data 0.000 (0.057)Elapsed 0m 16s (remain 17m 19s)Loss: 0.2247(0.4505)Grad: 0.9950  
Epoch: [8][40/1317]Data 0.000 (0.029)Elapsed 0m 31s (remain 16m 9s)Loss: 0.5544(0.5137)Grad: 1.9093  
Epoch: [8][60/1317]Data 0.000 (0.020)Elapsed 0m 45s (remain 15m 43s)Loss: 0.1217(0.5147)Grad: 1.2054  
Epoch: [8][80/1317]Data 0.000 (0.015)Elapsed 1m 0s (remain 15m 18s)Loss: 0.3476(0.5039)Grad: 1.3962  
Epoch: [8][100/1317]Data 0.000 (0.012)Elapsed 1m 14s (remain 15m 1s)Loss: 1.8774(0.5118)Grad: 1.4102  
Epoch: [8][120/1317]Data 0.000 (0.010)Elapsed 1m 29s (remain 14m 42s)Loss: 0.4268(0.5033)Grad: 1.3043  
Epoch: [8][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 24s)Loss: 0.2277(0.4952)Grad: 1.3824  
Epoch: [8][160/1317]Data 0.000 (0.008)Elapsed 1m 58s (remain 14m 9s)Loss: 0.5898(0.4979)Grad

Epoch 8 - avg_train_loss: 0.5019 avg_val_loss: 0.5599 time: 1021s
Epoch 8 - Accuracy: 0.8862730206948928
Epoch 8 - Save Best Score: 0.8863 Model


EVAL: [329/330] Data 0.000 (0.110) Elapsed 1m 7s (remain 0m 0s) Loss: 0.0000(0.5599) 
Epoch: [9][0/1317]Data 0.962 (0.962)Elapsed 0m 2s (remain 46m 2s)Loss: 0.1610(0.1610)Grad: 0.4324  
Epoch: [9][20/1317]Data 0.000 (0.046)Elapsed 0m 16s (remain 17m 13s)Loss: 0.5637(0.5644)Grad: 1.7021  
Epoch: [9][40/1317]Data 0.000 (0.024)Elapsed 0m 31s (remain 16m 9s)Loss: 1.8795(0.6171)Grad: 1.7956  
Epoch: [9][60/1317]Data 0.000 (0.016)Elapsed 0m 45s (remain 15m 38s)Loss: 0.3050(0.5853)Grad: 1.5377  
Epoch: [9][80/1317]Data 0.000 (0.012)Elapsed 1m 0s (remain 15m 16s)Loss: 1.4168(0.6210)Grad: 1.1321  
Epoch: [9][100/1317]Data 0.000 (0.010)Elapsed 1m 14s (remain 14m 56s)Loss: 0.3350(0.5898)Grad: 0.9254  
Epoch: [9][120/1317]Data 0.000 (0.008)Elapsed 1m 28s (remain 14m 38s)Loss: 0.1384(0.5473)Grad: 0.9930  
Epoch: [9][140/1317]Data 0.000 (0.007)Elapsed 1m 43s (remain 14m 21s)Loss: 0.7384(0.5314)Grad: 1.1875  
Epoch: [9][160/1317]Data 0.000 (0.006)Elapsed 1m 57s (remain 14m 4s)Loss: 0.1650(0.5167)Grad

Epoch 9 - avg_train_loss: 0.5029 avg_val_loss: 0.5230 time: 1021s
Epoch 9 - Accuracy: 0.8847541294854756


EVAL: [329/330] Data 0.000 (0.109) Elapsed 1m 7s (remain 0m 0s) Loss: 0.0000(0.5230) 
Epoch: [10][0/1317]Data 1.278 (1.278)Elapsed 0m 2s (remain 49m 25s)Loss: 1.5785(1.5785)Grad: 1.7828  
Epoch: [10][20/1317]Data 0.000 (0.061)Elapsed 0m 16s (remain 17m 7s)Loss: 0.5509(0.5559)Grad: 1.2453  
Epoch: [10][40/1317]Data 0.000 (0.031)Elapsed 0m 31s (remain 16m 6s)Loss: 0.3969(0.5326)Grad: 1.3742  
Epoch: [10][60/1317]Data 0.000 (0.021)Elapsed 0m 45s (remain 15m 39s)Loss: 0.2770(0.5244)Grad: 1.4718  
Epoch: [10][80/1317]Data 0.000 (0.016)Elapsed 0m 59s (remain 15m 14s)Loss: 0.1648(0.5325)Grad: 0.8399  
Epoch: [10][100/1317]Data 0.000 (0.013)Elapsed 1m 14s (remain 14m 57s)Loss: 0.3845(0.5434)Grad: 0.9115  
Epoch: [10][120/1317]Data 0.000 (0.011)Elapsed 1m 28s (remain 14m 38s)Loss: 0.0414(0.5304)Grad: 0.8205  
Epoch: [10][140/1317]Data 0.000 (0.009)Elapsed 1m 43s (remain 14m 22s)Loss: 0.2310(0.5141)Grad: 1.5014  
Epoch: [10][160/1317]Data 0.000 (0.008)Elapsed 1m 57s (remain 14m 5s)Loss: 1.8510(0

Epoch 10 - avg_train_loss: 0.4991 avg_val_loss: 0.5366 time: 1021s
Epoch 10 - Accuracy: 0.8847541294854756


EVAL: [329/330] Data 0.000 (0.107) Elapsed 1m 7s (remain 0m 0s) Loss: 0.0000(0.5366) 


Score: 0.88627
Score: 0.88543
