In [1]:
from tqdm import tqdm
from glob import glob
import os

import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import random

from torchvision.models import *
from torch.utils.data import DataLoader
from torchvision.transforms import AutoAugment
from torchvision.transforms import AutoAugmentPolicy
from torchvision.transforms import transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.model_selection import train_test_split, StratifiedKFold
import numpy as np

import albumentations as A
from albumentations.imgaug.transforms import IAAPiecewiseAffine
from albumentations.pytorch import ToTensorV2

from CustomLoader import CustomLoader
import timm
from utils.CosineAnnealingWarmUpRestarts import  CosineAnnealingWarmUpRestarts

from utils.utils import *
from utils.AugMix import *
from utils.CutMix import *
from utils.MyModel import *

  warn(f"Failed to load image Python extension: {e}")


In [2]:
training_set, test_set = getDataSet('./dataset/')
set_random_seed(1813)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Stratic Split

In [3]:
classes = {}

for ts in training_set:
    class_n = ts.split('/')[3]

    if class_n in classes.keys():
        classes[class_n].append(ts)
    else:
        classes[class_n] = [ts]

In [5]:
dist = getDistributionDataSet(training_set)
dist = dict(sorted(dist.items(), key=lambda x:int(x[0])))

In [6]:
dist

{'0': 500,
 '1': 500,
 '2': 500,
 '3': 450,
 '4': 450,
 '5': 450,
 '6': 400,
 '7': 400,
 '8': 400,
 '9': 350,
 '10': 350,
 '11': 350,
 '12': 300,
 '13': 300,
 '14': 300,
 '15': 250,
 '16': 250,
 '17': 250,
 '18': 200,
 '19': 200,
 '20': 200,
 '21': 150,
 '22': 150,
 '23': 150,
 '24': 100,
 '25': 100,
 '26': 100,
 '27': 50,
 '28': 50,
 '29': 50}

In [7]:
fold_k = 5

fold_dict = {}

for i in range(fold_k):
    fold_dict[i] = []
    
for k in classes.keys():
    ls = classes[k]
    random.shuffle(ls)
    ls = np.array(ls)
    splited_ls = np.split(ls, fold_k)
    
    for i, l in enumerate(splited_ls):
        fold_dict[i] += list(l)
    
    
training_set = []

for i in range(fold_k):
    training_set += fold_dict[i]

In [8]:
CFG = {
    'AUG_MODE' : 'albu', # or [albu,transform]
    'MEAN' : [0.5051, 0.4853, 0.4409],
    'STD' : [0.2774, 0.2568, 0.2795],
    'CutMix': False,
    'mix_prob' : 0.5
}

## Class Weight

In [9]:
class_weight = []

dist = getDistributionDataSet(training_set)
dist = dict(sorted(dist.items(), key=lambda x:int(x[0])))

sum_of_dist = 0

for k in dist.keys():
    sum_of_dist += dist[k]

for k in dist.keys():
    class_weight.append(1 - (dist[k] / sum_of_dist))

# Transform

In [11]:
def getTransform(train_mean = [.5, .5, .5], train_std= [.5, .5, .5], val_mean= [.5, .5, .5], val_std= [.5, .5, .5], aug_mode='albu'):
    if aug_mode == 'albu':
        transform_train = A.Compose([
            A.Resize(32,32),
            A.Rotate(limit=(-360,360), interpolation=1, border_mode=1, always_apply=True),
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.RandomToneCurve(p=0.2, scale=0.15),
            A.CLAHE(p=0.4, clip_limit=(1, 4), tile_grid_size=(8, 8)),
            A.CoarseDropout(max_height=3, max_width=3, p = 0.7),
            A.Normalize(train_mean, train_std),
            ToTensorV2(),
        ])
    
        transform_test = A.Compose([
            A.Resize(32,32),
            A.Normalize(val_mean, val_std),
            ToTensorV2(),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.Resize(32),
            AutoAugment(AutoAugmentPolicy.CIFAR10),
            transforms.Normalize(train_mean, train_std),
            transforms.ToTensor(),
        ])
        
        transform_test = transforms.Compose([
            transforms.Resize(32),
            transforms.Normalize(val_mean, val_std),
            transforms.ToTensor(),
        ])
        
    return transform_train, transform_test

## Kfold

In [12]:
train_norm, train_std = getNormStd(training_set)
transform_train, transform_valid = getTransform(train_norm, train_std, train_norm, train_std)

In [13]:
all_labels = []
for file in training_set:
    all_labels.append(int(file.split('/')[3]))
len(all_labels), len(training_set)

(8250, 8250)

In [14]:
kf = StratifiedKFold(n_splits=5, shuffle=False) #, random_state=1813)

for foldk, (train_idx, val_idx) in enumerate(kf.split(X=training_set, y=all_labels)):
    print(f'============= Fold-{foldk} strat =============')
    model = timm.create_model(model_name='vit_base_resnet50d_224', pretrained=False, num_classes=30, img_size=32, drop_rate=0.2)
    
    lr = 0.000
    
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    valid_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)
    
    train_dataset = CustomLoader(training_set, transforms=transform_train, is_train=True, aug_mode=CFG['AUG_MODE'])
    valid_dataset = CustomLoader(training_set, transforms=transform_valid, is_train=True, aug_mode=CFG['AUG_MODE'])
    
    train_loader = DataLoader(train_dataset, batch_size = 64, num_workers=4, sampler=train_subsampler)
    valid_loader = DataLoader(valid_dataset, batch_size = 64, num_workers=4, sampler=valid_subsampler)
    
    criterion = nn.CrossEntropyLoss(weight = torch.tensor(class_weight).to(device))
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9,0.999), weight_decay=0.001)
    scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0 = 40, T_mult=1, eta_max=0.001, T_up=5, gamma=0.75)
    
    # Train
    n_epochs = 500
    EPOCH_FROM = 0
    
    train_loss = torch.zeros(n_epochs)
    valid_loss = torch.zeros(n_epochs)
    
    train_acc = torch.zeros(n_epochs)
    valid_acc = torch.zeros(n_epochs)
    
    valid_loss_min = np.Inf
    past_lr = lr
    model.to(device)
    
    last_loss_update = 0
    
    for e in range(n_epochs):
        # Trian
        model.train()
        for (image, labels, _) in tqdm(train_loader):
            image, labels = image.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            if CFG['CutMix'] and e < n_epochs - 10:
                mix_decision = np.random.rand()
                if mix_decision  < CFG['mix_prob']:
                    image, labels = cutmix(image, labels, 1.0)
            else:
                mix_decision = 1
                
            pred = model(image)
            
            if mix_decision < CFG['mix_prob']:
                loss = criterion(pred, labels[0]) * labels[2] + criterion(pred, labels[1]) * (1-labels[2])
            else:
                loss = criterion(pred, labels)
            
            loss.backward()
            optimizer.step()
            train_loss[e] += loss.item()
    
            ps = F.softmax(pred, dim=1)
            top_p, top_class = ps.topk(1, dim=1)
            
        
        train_loss[e] /= len(train_loader)
        train_acc[e] /= len(train_loader)
        
        # Validation
        with torch.no_grad():
            model.eval()
            for image, labels, _ in tqdm(valid_loader):
                image, labels = image.to(device), labels.to(device)
                
                logits = model(image)
                loss = criterion(logits, labels)
                valid_loss[e] += loss.item()
    
                ps = F.softmax(logits, dim=1)
                top_p, top_class = ps.topk(1, dim=1)
                equals = top_class == labels.reshape(top_class.shape)
                valid_acc[e] += torch.mean(equals.type(torch.float)).detach().cpu()
    
        valid_loss[e] /= len(valid_loader)
        valid_acc[e] /= len(valid_loader)
        
        scheduler.step(epoch=e)
        
        print('Fold: {} \tEpoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(foldk, e, train_loss[e], valid_loss[e]))
        print('Fold: {} \tEpoch: {} \tTraining accuracy: {:.6f} \tValidation accuracy: {:.6f}'.format(foldk, e, train_acc[e], valid_acc[e]))
        
        print(optimizer.param_groups[-1]['lr'])
            
        if valid_loss[e] <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,valid_loss[e]))
            torch.save(model.state_dict(), f'./models/mymodel/best_model_fold{foldk}.pth')
            valid_loss_min = valid_loss[e]
            last_loss_update = e
        



100%|██████████| 104/104 [00:15<00:00,  6.89it/s]
100%|██████████| 26/26 [00:04<00:00,  6.18it/s]


Fold: 0 	Epoch: 0 	Training Loss: 3.536098 	Validation Loss: 3.545161
Fold: 0 	Epoch: 0 	Training accuracy: 0.000000 	Validation accuracy: 0.042404
0.0
Validation loss decreased (inf --> 3.545161).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.67it/s]
100%|██████████| 26/26 [00:03<00:00,  7.03it/s]


Fold: 0 	Epoch: 1 	Training Loss: 3.545936 	Validation Loss: 3.547640
Fold: 0 	Epoch: 1 	Training accuracy: 0.000000 	Validation accuracy: 0.044207
0.0002


100%|██████████| 104/104 [00:11<00:00,  8.75it/s]
100%|██████████| 26/26 [00:03<00:00,  6.95it/s]


Fold: 0 	Epoch: 2 	Training Loss: 2.869291 	Validation Loss: 2.616351
Fold: 0 	Epoch: 2 	Training accuracy: 0.000000 	Validation accuracy: 0.234423
0.0004
Validation loss decreased (3.545161 --> 2.616351).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.72it/s]
100%|██████████| 26/26 [00:03<00:00,  7.04it/s]


Fold: 0 	Epoch: 3 	Training Loss: 2.720421 	Validation Loss: 2.670020
Fold: 0 	Epoch: 3 	Training accuracy: 0.000000 	Validation accuracy: 0.218029
0.0006000000000000001


100%|██████████| 104/104 [00:11<00:00,  8.71it/s]
100%|██████████| 26/26 [00:03<00:00,  6.73it/s]


Fold: 0 	Epoch: 4 	Training Loss: 2.718459 	Validation Loss: 2.553934
Fold: 0 	Epoch: 4 	Training accuracy: 0.000000 	Validation accuracy: 0.261106
0.0008
Validation loss decreased (2.616351 --> 2.553934).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.69it/s]
100%|██████████| 26/26 [00:03<00:00,  6.65it/s]


Fold: 0 	Epoch: 5 	Training Loss: 2.665587 	Validation Loss: 2.482176
Fold: 0 	Epoch: 5 	Training accuracy: 0.000000 	Validation accuracy: 0.241106
0.001
Validation loss decreased (2.553934 --> 2.482176).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.67it/s]
100%|██████████| 26/26 [00:03<00:00,  6.63it/s]


Fold: 0 	Epoch: 6 	Training Loss: 2.587035 	Validation Loss: 2.516703
Fold: 0 	Epoch: 6 	Training accuracy: 0.000000 	Validation accuracy: 0.265817
0.0009979871469976197


100%|██████████| 104/104 [00:11<00:00,  8.69it/s]
100%|██████████| 26/26 [00:03<00:00,  6.75it/s]


Fold: 0 	Epoch: 7 	Training Loss: 2.533258 	Validation Loss: 2.589780
Fold: 0 	Epoch: 7 	Training accuracy: 0.000000 	Validation accuracy: 0.259567
0.0009919647942993148


100%|██████████| 104/104 [00:11<00:00,  8.76it/s]
100%|██████████| 26/26 [00:03<00:00,  6.74it/s]


Fold: 0 	Epoch: 8 	Training Loss: 2.462719 	Validation Loss: 2.364062
Fold: 0 	Epoch: 8 	Training accuracy: 0.000000 	Validation accuracy: 0.313918
0.0009819814303479266
Validation loss decreased (2.482176 --> 2.364062).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.77it/s]
100%|██████████| 26/26 [00:03<00:00,  6.75it/s]


Fold: 0 	Epoch: 9 	Training Loss: 2.381591 	Validation Loss: 2.357682
Fold: 0 	Epoch: 9 	Training accuracy: 0.000000 	Validation accuracy: 0.311851
0.0009681174353198686
Validation loss decreased (2.364062 --> 2.357682).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.75it/s]
100%|██████████| 26/26 [00:03<00:00,  6.78it/s]


Fold: 0 	Epoch: 10 	Training Loss: 2.363885 	Validation Loss: 2.317262
Fold: 0 	Epoch: 10 	Training accuracy: 0.000000 	Validation accuracy: 0.314856
0.0009504844339512095
Validation loss decreased (2.357682 --> 2.317262).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.78it/s]
100%|██████████| 26/26 [00:03<00:00,  6.73it/s]


Fold: 0 	Epoch: 11 	Training Loss: 2.304745 	Validation Loss: 2.219729
Fold: 0 	Epoch: 11 	Training accuracy: 0.000000 	Validation accuracy: 0.336995
0.000929224396800933
Validation loss decreased (2.317262 --> 2.219729).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.80it/s]
100%|██████████| 26/26 [00:03<00:00,  6.66it/s]


Fold: 0 	Epoch: 12 	Training Loss: 2.285015 	Validation Loss: 2.157266
Fold: 0 	Epoch: 12 	Training accuracy: 0.000000 	Validation accuracy: 0.362139
0.0009045084971874737
Validation loss decreased (2.219729 --> 2.157266).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.75it/s]
100%|██████████| 26/26 [00:03<00:00,  6.71it/s]


Fold: 0 	Epoch: 13 	Training Loss: 2.265102 	Validation Loss: 2.164890
Fold: 0 	Epoch: 13 	Training accuracy: 0.000000 	Validation accuracy: 0.351226
0.0008765357330018055


100%|██████████| 104/104 [00:11<00:00,  8.74it/s]
100%|██████████| 26/26 [00:03<00:00,  6.74it/s]


Fold: 0 	Epoch: 14 	Training Loss: 2.197180 	Validation Loss: 2.162970
Fold: 0 	Epoch: 14 	Training accuracy: 0.000000 	Validation accuracy: 0.371514
0.0008455313244934324


100%|██████████| 104/104 [00:11<00:00,  8.78it/s]
100%|██████████| 26/26 [00:03<00:00,  6.74it/s]


Fold: 0 	Epoch: 15 	Training Loss: 2.191818 	Validation Loss: 2.122521
Fold: 0 	Epoch: 15 	Training accuracy: 0.000000 	Validation accuracy: 0.377500
0.0008117449009293668
Validation loss decreased (2.157266 --> 2.122521).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.75it/s]
100%|██████████| 26/26 [00:03<00:00,  7.03it/s]


Fold: 0 	Epoch: 16 	Training Loss: 2.160924 	Validation Loss: 2.087918
Fold: 0 	Epoch: 16 	Training accuracy: 0.000000 	Validation accuracy: 0.387981
0.0007754484907260512
Validation loss decreased (2.122521 --> 2.087918).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.75it/s]
100%|██████████| 26/26 [00:03<00:00,  6.80it/s]


Fold: 0 	Epoch: 17 	Training Loss: 2.104029 	Validation Loss: 2.069508
Fold: 0 	Epoch: 17 	Training accuracy: 0.000000 	Validation accuracy: 0.403437
0.0007369343312364993
Validation loss decreased (2.087918 --> 2.069508).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.77it/s]
100%|██████████| 26/26 [00:03<00:00,  7.06it/s]


Fold: 0 	Epoch: 18 	Training Loss: 2.081517 	Validation Loss: 2.013718
Fold: 0 	Epoch: 18 	Training accuracy: 0.000000 	Validation accuracy: 0.408077
0.0006965125158269618
Validation loss decreased (2.069508 --> 2.013718).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.78it/s]
100%|██████████| 26/26 [00:03<00:00,  7.11it/s]


Fold: 0 	Epoch: 19 	Training Loss: 2.039505 	Validation Loss: 1.980966
Fold: 0 	Epoch: 19 	Training accuracy: 0.000000 	Validation accuracy: 0.424832
0.0006545084971874737
Validation loss decreased (2.013718 --> 1.980966).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.73it/s]
100%|██████████| 26/26 [00:03<00:00,  7.17it/s]


Fold: 0 	Epoch: 20 	Training Loss: 2.037303 	Validation Loss: 1.966071
Fold: 0 	Epoch: 20 	Training accuracy: 0.000000 	Validation accuracy: 0.423437
0.0006112604669781572
Validation loss decreased (1.980966 --> 1.966071).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.79it/s]
100%|██████████| 26/26 [00:03<00:00,  6.82it/s]


Fold: 0 	Epoch: 21 	Training Loss: 1.974529 	Validation Loss: 1.913010
Fold: 0 	Epoch: 21 	Training accuracy: 0.000000 	Validation accuracy: 0.431418
0.0005671166329088278
Validation loss decreased (1.966071 --> 1.913010).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.79it/s]
100%|██████████| 26/26 [00:03<00:00,  6.77it/s]


Fold: 0 	Epoch: 22 	Training Loss: 1.929165 	Validation Loss: 1.911641
Fold: 0 	Epoch: 22 	Training accuracy: 0.000000 	Validation accuracy: 0.431346
0.0005224324151752575
Validation loss decreased (1.913010 --> 1.911641).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.68it/s]
100%|██████████| 26/26 [00:03<00:00,  6.82it/s]


Fold: 0 	Epoch: 23 	Training Loss: 1.918368 	Validation Loss: 1.899196
Fold: 0 	Epoch: 23 	Training accuracy: 0.000000 	Validation accuracy: 0.438029
0.0004775675848247426
Validation loss decreased (1.911641 --> 1.899196).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.78it/s]
100%|██████████| 26/26 [00:03<00:00,  6.75it/s]


Fold: 0 	Epoch: 24 	Training Loss: 1.900738 	Validation Loss: 1.885210
Fold: 0 	Epoch: 24 	Training accuracy: 0.000000 	Validation accuracy: 0.444712
0.0004328833670911723
Validation loss decreased (1.899196 --> 1.885210).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.76it/s]
100%|██████████| 26/26 [00:03<00:00,  6.80it/s]


Fold: 0 	Epoch: 25 	Training Loss: 1.832010 	Validation Loss: 1.873663
Fold: 0 	Epoch: 25 	Training accuracy: 0.000000 	Validation accuracy: 0.443702
0.00038873953302184284
Validation loss decreased (1.885210 --> 1.873663).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.77it/s]
100%|██████████| 26/26 [00:03<00:00,  6.74it/s]


Fold: 0 	Epoch: 26 	Training Loss: 1.833918 	Validation Loss: 1.842960
Fold: 0 	Epoch: 26 	Training accuracy: 0.000000 	Validation accuracy: 0.454591
0.00034549150281252633
Validation loss decreased (1.873663 --> 1.842960).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.76it/s]
100%|██████████| 26/26 [00:03<00:00,  6.75it/s]


Fold: 0 	Epoch: 27 	Training Loss: 1.811484 	Validation Loss: 1.812263
Fold: 0 	Epoch: 27 	Training accuracy: 0.000000 	Validation accuracy: 0.456659
0.0003034874841730383
Validation loss decreased (1.842960 --> 1.812263).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.77it/s]
100%|██████████| 26/26 [00:03<00:00,  7.08it/s]


Fold: 0 	Epoch: 28 	Training Loss: 1.770743 	Validation Loss: 1.772101
Fold: 0 	Epoch: 28 	Training accuracy: 0.000000 	Validation accuracy: 0.461034
0.0002630656687635007
Validation loss decreased (1.812263 --> 1.772101).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.71it/s]
100%|██████████| 26/26 [00:03<00:00,  6.69it/s]


Fold: 0 	Epoch: 29 	Training Loss: 1.731652 	Validation Loss: 1.763829
Fold: 0 	Epoch: 29 	Training accuracy: 0.000000 	Validation accuracy: 0.476058
0.0002245515092739488
Validation loss decreased (1.772101 --> 1.763829).  Saving model ...


  0%|          | 0/104 [00:02<?, ?it/s]


KeyboardInterrupt: 