# **Import 3rdparty**

In [55]:
import os
import pandas as pd
import numpy as np
import random
import cv2

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

import timm
from matplotlib import pyplot as plt
from sklearn.model_selection import StratifiedKFold

from scipy.special import softmax

import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

In [4]:
# def seed_everything(seed):
#     random.seed(seed)
#     os.environ['PYTHONHASHSEED'] = str(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = True
    
    
def seed_everything(seed: int):
    if not seed:
        seed = 10

    print("[ Using Seed : ", seed, " ]")
    os.environ['PYTHONHASHSEED'] = str(seed)  # set PYTHONHASHSEED env var at fixed value
    np.random.seed(seed) # for numpy pseudo-random generator
    random.seed(seed) # set fixed value for python built-in pseudo-random generator
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False


# **Load Image**

In [10]:
def load_image(image_path):
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

# **CassavaDataset**

In [11]:
class CassavaDataset(Dataset):
    def __init__(self, data_dir, df, transforms=None, output_label=True):
        self.data_dir = data_dir
        self.df = df
        self.transforms = transforms
        self.output_label = output_label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        image_infos = self.df.iloc[index]
        image_path = self.data_dir + image_infos.image_id

        image = load_image(image_path)

        if image is None:
            raise FileNotFoundError(image_path)

        ### augment
        if self.transforms is not None:
            image = self.transforms(image=image)['image']
        else:
            image = torch.from_numpy(image)

        if self.output_label:
            return image, image_infos.label
        else:
            return image       

# **CassavaClassifier**

In [12]:
# ====================================================
# Vit Model
# ====================================================
class CustomViT(nn.Module):
    def __init__(self, model_arch, num_classes, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        ### vit
        num_features = self.model.head.in_features
        self.model.head = nn.Linear(num_features, num_classes)
        '''
        self.model.classifier = nn.Sequential(
            nn.Dropout(0.3),
            #nn.Linear(num_features, hidden_size,bias=True), nn.ELU(),
            nn.Linear(num_features, num_classes, bias=True)
        )
        '''
    def forward(self, x):
        x = self.model(x)
        return x
    
# ====================================================
# ResNext Model
# ====================================================
class CustomResNext(nn.Module):
    def __init__(self, model_arch, num_classes, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        #='resnext50_32x4d',
        n_features = self.model.fc.in_features
        self.model.fc = nn.Linear(n_features, num_classes)

    def forward(self, x):
        x = self.model(x)
        return x
    
# ====================================================
# EfficientNet Model
# ====================================================
class CustomEfficientNet(nn.Module):

    def __init__(self, model_arch, num_classes, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(n_features, num_classes)

    def forward(self, x):
        x = self.model(x)
        return x

# **Train and Val transforms**

In [13]:
def get_train_transforms(CFG):
    return A.Compose([
            A.RandomResizedCrop(height=CFG.image_size, width=CFG.image_size, p=0.5),
            A.Transpose(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            A.CenterCrop(CFG.image_size, CFG.image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            A.CoarseDropout(p=0.5),
            A.Cutout(p=0.5),
            ToTensorV2(),
        ],p=1.0)

# def get_train_transforms(CFG):
#     return A.Compose([
#             #Resize(CFG.size, CFG.size),
#             A.RandomResizedCrop(height=CFG.image_size, width=CFG.image_size, scale=(0.85, 1.0)),
#             A.HorizontalFlip(p=0.5),
#             A.RandomBrightnessContrast(p=0.2, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2)),
#             A.HueSaturationValue(p=0.2, hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2),
#             A.ShiftScaleRotate(p=0.2, shift_limit=0.0625, scale_limit=0.2, rotate_limit=20),
#             A.CoarseDropout(p=0.2),
#             A.Cutout(p=0.2, max_h_size=16, max_w_size=16, fill_value=(0., 0., 0.), num_holes=16),
#             A.Normalize(
#                 mean=[0.485, 0.456, 0.406],
#                 std=[0.229, 0.224, 0.225],
#             ),
#             ToTensorV2(),
#         ], p=1.0)

def get_val_transforms(cfg):
    return A.Compose([
            A.CenterCrop(CFG.image_size, CFG.image_size, p=0.5),
            A.Resize(CFG.image_size, CFG.image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            ToTensorV2(),
        ],p=1.0)


# **Train and Val data loader**

In [44]:
def load_dataloader(CFG, df, train_idx, val_idx):
    df_train = df.loc[train_idx,:].reset_index(drop=True)
    df_val = df.loc[val_idx,:].reset_index(drop=True)

    train_dataset = CassavaDataset(
        CFG.train_data_dir,
        df_train,
        transforms=get_train_transforms(CFG), 
        output_label=True)

    val_dataset = CassavaDataset(
        CFG.train_data_dir,
        df_val,
        transforms=get_val_transforms(CFG), 
        output_label=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=CFG.train_batch_size,
        pin_memory=False,
        drop_last=False,
        shuffle=True,        
        num_workers=CFG.num_workers,
        #sampler=BalanceClassSampler(labels=train_['label'].values, mode="downsampling")
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset, 
        batch_size=CFG.val_batch_size,
        num_workers=CFG.num_workers,
        shuffle=False,
        pin_memory=False,
    )
    
    return train_loader, val_loader


def load_valdataloader(df_val):
    
    val_dataset = CassavaDataset(
        CFG.train_data_dir,
        df_val,
        transforms=get_val_transforms(CFG), 
        output_label=False
    )
        
    val_loader = torch.utils.data.DataLoader(
        val_dataset, 
        batch_size=CFG.val_batch_size,
        num_workers=CFG.num_workers,
        shuffle=False,
        pin_memory=False,
    )
    
    return val_loader

# **Train one epoch**

In [15]:
def train_one_epoch(epoch,model,loss_fn,optimizer,train_loader,device,scheduler=None,schd_batch_update=False):
    model.train()
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    running_loss = None
    pbar = tqdm(enumerate(train_loader),total=len(train_loader))
    for step,(images,targets) in pbar:
        images = images.to(device).float()
        targets = targets.to(device).long()
        
        with autocast():
            preds = model(images)
            loss = loss_fn(preds,targets)
        
            scaler.scale(loss).backward()
            if running_loss is None:
                running_loss = loss.item()
            else:
                running_loss = running_loss* 0.99 + loss.item()*0.01
                
            if ((step + 1) % CFG.accum_iter == 0) or ((step + 1) == len(train_loader)):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                
                if scheduler is not None and schd_batch_update:
                    scheduler.step()
            if ((step + 1) % CFG.accum_iter == 0) or ((step + 1) == len(train_loader)):
                description = f'Train epoch {epoch} loss: {running_loss:.5f}'
                pbar.set_description(description)
                
    if scheduler is not None and schd_batch_update:
        scheduler.step()

# **Valid one epoch**

In [16]:
def valid_one_epoch(epoch,model,loss_fn,val_loader,device,scheduler=None,schd_loss_update=False):
    model.eval()
    
    loss_sum = 0
    sample_num = 0
    preds_all = []
    targets_all = []
    scores = []
    
    pbar = tqdm(enumerate(val_loader),total=len(val_loader))
    for step,(images,targets) in pbar:
        images = images.to(device).float()
        targets = targets.to(device).long()
        preds = model(images)
            
        preds_all += [torch.argmax(preds,1).detach().cpu().numpy()]
        targets_all += [targets.detach().cpu().numpy()]

        loss = loss_fn(preds,targets)
        loss_sum += loss.item()*targets.shape[0]
        sample_num += targets.shape[0]
           
        if ((step + 1) % CFG.accum_iter == 0) or ((step + 1) == len(train_loader)):
            description = f'Val epoch {epoch} loss: {loss_sum/sample_num:.5f}'
            pbar.set_description(description)
            
    preds_all = np.concatenate(preds_all)
    targets_all = np.concatenate(targets_all)
    accuracy = (preds_all == targets_all).mean()
    print(f'Validation multi-class accuracy = {accuracy:.5f}')
    
    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(loss_sum/sample_num)
        else:
            scheduler.step()
    
    return accuracy

# **Label Smoothing Cross Entropy Loss**

In [17]:
class LabelSmoothingCrossEntropy(nn.Module):
    """
    NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        logprobs = F.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
    
class TaylorSoftmax(nn.Module):

    def __init__(self, dim=1, n=2):
        super(TaylorSoftmax, self).__init__()
        assert n % 2 == 0
        self.dim = dim
        self.n = n

    def forward(self, x):
        
        fn = torch.ones_like(x)
        denor = 1.
        for i in range(1, self.n+1):
            denor *= i
            fn = fn + x.pow(i) / denor
        out = fn / fn.sum(dim=self.dim, keepdims=True)
        return out

# class LabelSmoothingLoss(nn.Module):

#     def __init__(self, classes, smoothing=0.0, dim=-1): 
#         super(LabelSmoothingLoss, self).__init__() 
#         self.confidence = 1.0 - smoothing 
#         self.smoothing = smoothing 
#         self.cls = classes 
#         self.dim = dim 
#     def forward(self, pred, target): 
#         """Taylor Softmax and log are already applied on the logits"""
#         #pred = pred.log_softmax(dim=self.dim) 
#         with torch.no_grad(): 
#             true_dist = torch.zeros_like(pred) 
#             true_dist.fill_(self.smoothing / (self.cls - 1)) 
#             true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 
#         return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
    

class TaylorCrossEntropyLoss(nn.Module):

    def __init__(self, n=2, ignore_index=-1, reduction='mean', smoothing=0.2):
        super(TaylorCrossEntropyLoss, self).__init__()
        assert n % 2 == 0
        self.taylor_softmax = TaylorSoftmax(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.lab_smooth = LabelSmoothingCrossEntropy(smoothing=smoothing)

    def forward(self, logits, labels):

        log_probs = self.taylor_softmax(logits).log()
        #loss = F.nll_loss(log_probs, labels, reduction=self.reduction,
        #        ignore_index=self.ignore_index)
        loss = self.lab_smooth(log_probs, labels)
        return loss

# **EVAL**

In [69]:
class Config:
    seed = 42
    data_dir = '../input/cassava-leaf-disease-classification/'
    train_data_dir = data_dir + 'train_images/'
    train_csv_path = data_dir + 'train.csv'
    arch = 'tf_efficientnet_b3_ns' ## model name
    device = 'cuda'
    debug = False                 ##
    
    image_size = 512     ##384
    train_batch_size = 4        ##16
    val_batch_size = 4          ##32
    epochs = 10                  ## total train epochs
    freeze_bn_epochs = 5        ## freeze bn weights before epochs
    
    lr=1e-4                     ## init learning rate
    min_lr = 1e-6               ## min learning rate
    weight_decay = 1e-6
    num_workers = 0            ## 4
    num_splits = 5             ## numbers splits
    num_classes = 5            ## numbers classes
    T_0 = 10
    T_mult = 1
    accum_iter = 2
    verbose_step = 1
    
    criterion = 'LabelSmoothingCrossEntropy' ## CrossEntropy, LabelSmoothingCrossEntropy TaylorCrossEntropyLoss
    label_smoothing = 0.3
    
    train_id = [0,1,2,3,4]
    n_label = 5
    comment = 'init'
    
seed_everything(Config.seed)    

if __name__ == '__main__':
    
    # ====================================================
    # CFG 
    # ====================================================
    CFG = Config
    samples_df = pd.read_csv(CFG.train_csv_path)

    if CFG.debug:
        CFG.epochs = 1
        samples_df = samples_df.sample(2000,random_state=CFG.seed).reset_index(drop=True)
    
    print('CFG seed is ', CFG.seed)
    if CFG.seed is not None:
        seed_everything(CFG.seed)
        
    # ====================================================
    # split data
    # ====================================================
    folds = StratifiedKFold(n_splits=CFG.num_splits, shuffle=True, random_state=CFG.seed).split(np.arange(samples_df.shape[0]), samples_df.label.values)
    for fold,(train_idx,val_idx) in enumerate(folds):
        print('training_df')
        print(samples_df.iloc[train_idx]['label'].value_counts())
        print('validation_df')
        print(samples_df.iloc[val_idx]['label'].value_counts())
        break
    
    # ====================================================
    # load data
    # ====================================================
    train_loader,val_loader = load_dataloader(CFG, samples_df, train_idx, val_idx)
    samples_df.loc[val_idx,:].reset_index(drop=True).to_pickle('df_val.pkl')

    # ====================================================
    # select device
    # ====================================================
    device = torch.device(CFG.device)
    
    # ====================================================
    # build graph
    # ====================================================
    model = CustomEfficientNet('tf_efficientnet_b3_ns', CFG.n_label, pretrained=False).to(device)
    scaler = GradScaler()
    optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, 
        T_0=CFG.T_0, 
        T_mult=CFG.T_mult, 
        eta_min=CFG.min_lr, 
        last_epoch=-1)
    loss_train = TaylorCrossEntropyLoss(smoothing=CFG.label_smoothing)
    loss_val = nn.CrossEntropyLoss().to(device)
    
    # ====================================================
    # train
    # ====================================================
    best_accuracy = 0
    best_epoch = 0
    for epoch in range(CFG.epochs):
        if epoch < CFG.freeze_bn_epochs:
            freeze_batchnorm_stats(model)  
        train_one_epoch(
            epoch, 
            model, 
            loss_train, 
            optimizer, 
            train_loader, 
            device, 
            scheduler=scheduler, 
            schd_batch_update=False)

        with torch.no_grad():
            epoch_accuracy = valid_one_epoch(
                epoch, 
                model, 
                loss_val, 
                val_loader, 
                device, 
                scheduler=None, 
                schd_loss_update=False)

        if epoch_accuracy > best_accuracy:
#             if not CFG.debug:
            torch.save(model.state_dict(),'{}_{}_best.ckpt'.format(CFG.arch, CFG.comment))
            print('Best model is saved')
            best_accuracy = epoch_accuracy
            best_epoch = epoch
        print('accuracy = {} in epoch {}'.format(epoch_accuracy, epoch))
        
    print('best accuracy = {} in epoch {}'.format(best_accuracy,best_epoch))
    del model, optimizer, train_loader, val_loader, scaler, scheduler
    torch.cuda.empty_cache()
    
    
    # ====================================================
    # validation
    # ====================================================

    def inference(model, states, test_loader, device):
        
        model.to(device)
        tk0 = tqdm(enumerate(test_loader), total=len(test_loader))
        probs = []
        for i, (images) in tk0:
            images = images.to(device)
            avg_preds = []
            for state in states:
                model.load_state_dict(state)
                model.eval()
                with torch.no_grad():
                    y_preds = model(images)
                avg_preds.append(y_preds.softmax(1).to('cpu').numpy())
            avg_preds = np.mean(avg_preds, axis=0)
            probs.append(avg_preds)

        probs = np.concatenate(probs)
        return probs
    
    device = torch.device(CFG.device)
    df_val = pd.read_pickle('df_val.pkl')
    val_loader = load_valdataloader(df_val)
    model = CustomEfficientNet(CFG.arch, CFG.n_label, pretrained=False)
    states = [torch.load(p) for p in ['{}_{}_best.ckpt'.format(CFG.arch, CFG.comment)]]
    predictions = inference(model, states, val_loader, device)
    accuracy = sum(df_val.label==softmax(predictions).argmax(1))/df_val.shape[0]
    accuracy
    print('test accuracy = {} '.format(accuracy))

[ Using Seed :  42  ]
CFG seed is  42
[ Using Seed :  42  ]
training_df
3    10527
4     2061
2     1909
1     1751
0      869
Name: label, dtype: int64
validation_df
3    2631
4     516
2     477
1     438
0     218
Name: label, dtype: int64


Train epoch 0 loss: 1.34617: 100%|██████████| 4280/4280 [11:39<00:00,  6.12it/s]
Val epoch 0 loss: 1.47954: 100%|██████████| 1070/1070 [01:13<00:00, 14.61it/s]
  0%|          | 0/4280 [00:00<?, ?it/s]

Validation multi-class accuracy = 0.62570
Best model is saved
accuracy = 0.6257009345794392 in epoch 0


Train epoch 1 loss: 1.28493: 100%|██████████| 4280/4280 [11:38<00:00,  6.13it/s]
Val epoch 1 loss: 1.72804: 100%|██████████| 1070/1070 [01:13<00:00, 14.58it/s]
  0%|          | 0/4280 [00:00<?, ?it/s]

Validation multi-class accuracy = 0.64416
Best model is saved
accuracy = 0.6441588785046729 in epoch 1


Train epoch 2 loss: 1.29045: 100%|██████████| 4280/4280 [12:02<00:00,  5.92it/s]
Val epoch 2 loss: 1.73712: 100%|██████████| 1070/1070 [01:13<00:00, 14.60it/s]
  0%|          | 0/4280 [00:00<?, ?it/s]

Validation multi-class accuracy = 0.65911
Best model is saved
accuracy = 0.6591121495327102 in epoch 2


Train epoch 3 loss: 1.27693: 100%|██████████| 4280/4280 [11:44<00:00,  6.08it/s]
Val epoch 3 loss: 1.22849: 100%|██████████| 1070/1070 [01:17<00:00, 13.77it/s]
  0%|          | 0/4280 [00:00<?, ?it/s]

Validation multi-class accuracy = 0.68294
Best model is saved
accuracy = 0.6829439252336449 in epoch 3


Train epoch 4 loss: 1.25110: 100%|██████████| 4280/4280 [11:50<00:00,  6.02it/s]
Val epoch 4 loss: 1.57116: 100%|██████████| 1070/1070 [01:15<00:00, 14.09it/s]
  0%|          | 0/4280 [00:00<?, ?it/s]

Validation multi-class accuracy = 0.68785
Best model is saved
accuracy = 0.6878504672897197 in epoch 4


Train epoch 5 loss: 1.24977: 100%|██████████| 4280/4280 [11:42<00:00,  6.09it/s]
Val epoch 5 loss: 1.46568: 100%|██████████| 1070/1070 [01:12<00:00, 14.73it/s]
  0%|          | 0/4280 [00:00<?, ?it/s]

Validation multi-class accuracy = 0.69626
Best model is saved
accuracy = 0.6962616822429907 in epoch 5


Train epoch 6 loss: 1.24354: 100%|██████████| 4280/4280 [11:38<00:00,  6.13it/s]
Val epoch 6 loss: 1.31180: 100%|██████████| 1070/1070 [01:14<00:00, 14.35it/s]
  0%|          | 1/4280 [00:00<10:35,  6.73it/s]

Validation multi-class accuracy = 0.64930
accuracy = 0.6492990654205607 in epoch 6


Train epoch 7 loss: 1.19117: 100%|██████████| 4280/4280 [11:58<00:00,  5.96it/s]
Val epoch 7 loss: 0.88048: 100%|██████████| 1070/1070 [01:15<00:00, 14.08it/s]


Validation multi-class accuracy = 0.74089
Best model is saved
accuracy = 0.7408878504672897 in epoch 7
best accuracy = 0.7408878504672897 in epoch 7


100%|██████████| 1070/1070 [01:53<00:00,  9.41it/s]

test accuracy = 0.7404205607476636 





In [None]:
v1 baseline 0.72
v2 baseline change image size 384=>512　0.74

# **TRAIN**

# **INFERENCE**

# **Main Loop**

In [44]:
################ freeze bn 
def freeze_batchnorm_stats(net):
    try:
        for m in net.modules():
            if isinstance(m,nn.BatchNorm2d) or isinstance(m,nn.LayerNorm):
                m.eval()
    except ValuError:
        print('error with batchnorm2d or layernorm')
        return
    
    


if __name__ == '__main__':
    CFG = Config
    train = pd.read_csv(CFG.train_csv_path)
    
    if CFG.debug:
        CFG.epochs = 1
        train = train.sample(100,random_state=CFG.seed).reset_index(drop=True)
    
    print('CFG seed is ', CFG.seed)
    if CFG.seed is not None:
        seed_everything(CFG.seed)
    
    folds = StratifiedKFold(
        n_splits=CFG.num_splits, 
        shuffle=True, 
        random_state=CFG.seed).split(np.arange(train.shape[0]), train.label.values)
    
    cross_accuracy = []
    for fold,(train_idx,val_idx) in enumerate(folds):
        ########
        # load data
        #######
        train_loader,val_loader = load_dataloader(CFG, train, train_idx, val_idx)
        
        device = torch.device(CFG.device)
        
        
        ## 'vit_base_patch16_384' vc 0.66
        ## 'resnext50_32x4d' 0.680
        ## 'tf_efficientnet_b3_ns'  0.6700 => TaylorCrossEntropyLoss 0.67000 rollback
        ## 'tf_efficientnet_b3_ns'  0.6700 => 'tf_efficientnet_b4_ns' 0.56000 rollback
        ## 'tf_efficientnet_b3_ns'  0.6700 => change augumentation  0.60 rollback
        
        
        model = CustomEfficientNet('tf_efficientnet_b4_ns', train.label.nunique(), pretrained=True).to(device)
        
        scaler = GradScaler()
        optimizer = torch.optim.Adam(
            model.parameters(), 
            lr=CFG.lr, 
            weight_decay=CFG.weight_decay)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, 
            T_0=CFG.T_0, 
            T_mult=CFG.T_mult, 
            eta_min=CFG.min_lr, 
            last_epoch=-1)
    
        ########
        # criterion
        #######
        if CFG.criterion in['LabelSmoothingCrossEntropy', 'TaylorCrossEntropyLoss']:  #### label smoothing cross entropy
            loss_train = TaylorCrossEntropyLoss(smoothing=CFG.label_smoothing)
        else:
            loss_train = nn.CrossEntropyLoss().to(device)
        loss_val = nn.CrossEntropyLoss().to(device)
        
        best_accuracy = 0
        best_epoch = 0
        for epoch in range(CFG.epochs):
            if epoch < CFG.freeze_bn_epochs:
                freeze_batchnorm_stats(model)  
            train_one_epoch(
                epoch, 
                model, 
                loss_train, 
                optimizer, 
                train_loader, 
                device, 
                scheduler=scheduler, 
                schd_batch_update=False)

            with torch.no_grad():
                epoch_accuracy = valid_one_epoch(
                    epoch, 
                    model, 
                    loss_val, 
                    val_loader, 
                    device, 
                    scheduler=None, 
                    schd_loss_update=False)

            if epoch_accuracy > best_accuracy:
                torch.save(model.state_dict(),'{}_fold{}_best.ckpt'.format(CFG.arch, fold))
                best_accuracy = epoch_accuracy
                best_epoch = epoch
                print('Best model is saved')
        cross_accuracy += [best_accuracy]
        print('Fold{} best accuracy = {} in epoch {}'.format(fold,best_accuracy,best_epoch))
        del model, optimizer, train_loader, val_loader, scaler, scheduler
        torch.cuda.empty_cache()
    print('{} folds cross validation CV = {:.5f}'.format(CFG.num_splits,np.average(cross_accuracy)))

CFG seed is  42


Train epoch 0 loss: 1.61677: 100%|██████████| 20/20 [00:02<00:00,  7.26it/s]
Val epoch 0 loss: 1.50206: 100%|██████████| 5/5 [00:00<00:00, 19.28it/s]


Validation multi-class accuracy = 0.65000
Best model is saved
Fold0 best accuracy = 0.65 in epoch 0


Train epoch 0 loss: 1.59700: 100%|██████████| 20/20 [00:02<00:00,  7.35it/s]
Val epoch 0 loss: 1.41611: 100%|██████████| 5/5 [00:00<00:00, 18.81it/s]


Validation multi-class accuracy = 0.65000
Best model is saved
Fold1 best accuracy = 0.65 in epoch 0


Train epoch 0 loss: 1.62903: 100%|██████████| 20/20 [00:02<00:00,  6.94it/s]
Val epoch 0 loss: 1.52073: 100%|██████████| 5/5 [00:00<00:00, 18.71it/s]


Validation multi-class accuracy = 0.50000
Best model is saved
Fold2 best accuracy = 0.5 in epoch 0


Train epoch 0 loss: 1.61014: 100%|██████████| 20/20 [00:02<00:00,  7.21it/s]
Val epoch 0 loss: 1.53001: 100%|██████████| 5/5 [00:00<00:00, 18.78it/s]


Validation multi-class accuracy = 0.55000
Best model is saved
Fold3 best accuracy = 0.55 in epoch 0


Train epoch 0 loss: 1.58263: 100%|██████████| 20/20 [00:02<00:00,  7.06it/s]
Val epoch 0 loss: 1.48668: 100%|██████████| 5/5 [00:00<00:00, 18.10it/s]


Validation multi-class accuracy = 0.65000
Best model is saved
Fold4 best accuracy = 0.65 in epoch 0
5 folds cross validation CV = 0.60000
