In [1]:
from tqdm import tqdm
from glob import glob
import os

import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import random

from torchvision.models import *
from torch.utils.data import DataLoader
from torchvision.transforms import AutoAugment
from torchvision.transforms import AutoAugmentPolicy
from torchvision.transforms import transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.model_selection import train_test_split, StratifiedKFold
import numpy as np

import albumentations as A
from albumentations.pytorch import ToTensorV2

from CustomLoader import CustomLoader
import timm
from utils.CosineAnnealingWarmUpRestarts import  CosineAnnealingWarmUpRestarts

from utils.utils import *

  warn(f"Failed to load image Python extension: {e}")


In [2]:
training_set, test_set = getDataSet('./dataset/')
set_random_seed(1813)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Stratic Split

In [5]:
getDistributionDataSet(training_set)

{'0': 500,
 '1': 500,
 '10': 350,
 '11': 350,
 '12': 300,
 '13': 300,
 '14': 300,
 '15': 250,
 '16': 250,
 '17': 250,
 '18': 200,
 '19': 200,
 '2': 500,
 '20': 200,
 '21': 150,
 '22': 150,
 '23': 150,
 '24': 100,
 '25': 100,
 '26': 100,
 '27': 50,
 '28': 50,
 '29': 50,
 '3': 450,
 '4': 450,
 '5': 450,
 '6': 400,
 '7': 400,
 '8': 400,
 '9': 350}

In [6]:
CFG = {
    'AUG_MODE' : 'albu', # or [albu,transform]
    'MEAN' : [0.5051, 0.4853, 0.4409],
    'STD' : [0.2774, 0.2568, 0.2795]
}

## Class Weight

In [7]:
class_weight = []

dist = getDistributionDataSet(training_set)
dist = dict(sorted(dist.items(), key=lambda x:int(x[0])))

sum_of_dist = 0

for k in dist.keys():
    sum_of_dist += dist[k]

for k in dist.keys():
    class_weight.append(1 - (dist[k] / sum_of_dist))

# Transform

In [9]:
if CFG['AUG_MODE'] == 'albu':
    transform_train = A.Compose([
        A.Resize(32,32),
        A.Rotate(limit=(-360,360), interpolation=1, border_mode=1),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.RandomToneCurve(p=0.2, scale=0.15),
        A.Blur(p=0.4, blur_limit=(1, 3)),
        A.CLAHE(p=0.4, clip_limit=(1, 4), tile_grid_size=(8, 8)),
        A.Normalize(CFG['MEAN'], CFG['STD']),
        ToTensorV2(),
    ])

    transform_test = A.Compose([
        A.Resize(32,32),
        A.Normalize(CFG['MEAN'], CFG['STD']),
        ToTensorV2(),
    ])
else:
    transform_train = transforms.Compose([
        transforms.Resize(32),
        AutoAugment(AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
    ])
    
    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])

## Kfold

In [10]:
all_labels = []
for file in training_set:
    all_labels.append(int(file.split('/')[3]))
len(all_labels), len(training_set)

(8250, 8250)

In [11]:
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1813)

train_dataset = CustomLoader(training_set, transforms=transform_train, is_train=True, aug_mode=CFG['AUG_MODE'])

for foldk, (train_idx, val_idx) in enumerate(kf.split(X=training_set, y=all_labels)):
    model = timm.create_model(model_name='vit_base_resnet50d_224', pretrained=False, num_classes=30, img_size=32, drop_rate=0.05)
    
    lr = 0.000
    
    criterion = nn.CrossEntropyLoss(weight = torch.tensor(class_weight).to(device))
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9,0.999), weight_decay=0.0001)
    scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0 = 50, T_mult=1, eta_max=0.0005, T_up=10, gamma=0.25)
    
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    valid_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)
    
    train_loader = DataLoader(train_dataset, batch_size = 64, num_workers=4, sampler=train_subsampler)
    valid_loader = DataLoader(train_dataset, batch_size = 64, num_workers=4, sampler=valid_subsampler)
    
    # Train
    n_epochs = 100
    EPOCH_FROM = 0
    
    train_loss = torch.zeros(n_epochs)
    valid_loss = torch.zeros(n_epochs)
    
    train_acc = torch.zeros(n_epochs)
    valid_acc = torch.zeros(n_epochs)
    
    valid_loss_min = np.Inf
    past_lr = lr
    model.to(device)
    
    last_loss_update = 0
    
    for e in range(n_epochs):
        # Trian
        model.train()
        for (image, labels, _) in tqdm(train_loader):
            image, labels = image.to(device), labels.to(device)
            
            optimizer.zero_grad()
            pred = model(image)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()
            train_loss[e] += loss.item()
    
            ps = F.softmax(pred, dim=1)
            top_p, top_class = ps.topk(1, dim=1)
            equals = top_class == labels.reshape(top_class.shape)
            train_acc[e] += torch.mean(equals.type(torch.float)).detach().cpu()
            
        
        train_loss[e] /= len(train_loader)
        train_acc[e] /= len(train_loader)
        
        # Validation
        with torch.no_grad():
            model.eval()
            for image, labels, _ in tqdm(valid_loader):
                image, labels = image.to(device), labels.to(device)
                
                logits = model(image)
                loss = criterion(logits, labels)
                valid_loss[e] += loss.item()
    
                ps = F.softmax(logits, dim=1)
                top_p, top_class = ps.topk(1, dim=1)
                equals = top_class == labels.reshape(top_class.shape)
                valid_acc[e] += torch.mean(equals.type(torch.float)).detach().cpu()
    
        valid_loss[e] /= len(valid_loader)
        valid_acc[e] /= len(valid_loader)
        
        scheduler.step(epoch=e)
        
        print('Fold: {} \tEpoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(foldk, e, train_loss[e], valid_loss[e]))
        print('Fold: {} \tEpoch: {} \tTraining accuracy: {:.6f} \tValidation accuracy: {:.6f}'.format(foldk, e, train_acc[e], valid_acc[e]))
        
        if past_lr != optimizer.param_groups[0]['lr']:
            print(f"Learning Rete : {optimizer.param_groups[0]['lr']}")
            past_lr = optimizer.param_groups[0]['lr']
            
        if valid_loss[e] <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,valid_loss[e]))
            torch.save(model.state_dict(), f'./models/replay_test/best_model_fold{foldk}.pth')
            valid_loss_min = valid_loss[e]
            last_loss_update = e

100%|██████████| 104/104 [00:15<00:00,  6.57it/s]
100%|██████████| 26/26 [00:04<00:00,  6.21it/s]


Fold: 0 	Epoch: 0 	Training Loss: 3.527152 	Validation Loss: 3.530897
Fold: 0 	Epoch: 0 	Training accuracy: 0.040865 	Validation accuracy: 0.048149
Validation loss decreased (inf --> 3.530897).  Saving model ...


100%|██████████| 104/104 [00:12<00:00,  8.65it/s]
100%|██████████| 26/26 [00:03<00:00,  6.85it/s]


Fold: 0 	Epoch: 1 	Training Loss: 3.535241 	Validation Loss: 3.532308
Fold: 0 	Epoch: 1 	Training accuracy: 0.039814 	Validation accuracy: 0.042837
Learning Rete : 5e-05


100%|██████████| 104/104 [00:11<00:00,  8.76it/s]
100%|██████████| 26/26 [00:03<00:00,  6.53it/s]


Fold: 0 	Epoch: 2 	Training Loss: 2.831267 	Validation Loss: 2.637156
Fold: 0 	Epoch: 2 	Training accuracy: 0.185547 	Validation accuracy: 0.231082
Learning Rete : 0.0001
Validation loss decreased (3.530897 --> 2.637156).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.80it/s]
100%|██████████| 26/26 [00:03<00:00,  6.74it/s]


Fold: 0 	Epoch: 3 	Training Loss: 2.624049 	Validation Loss: 2.569494
Fold: 0 	Epoch: 3 	Training accuracy: 0.231520 	Validation accuracy: 0.236731
Learning Rete : 0.00015000000000000001
Validation loss decreased (2.637156 --> 2.569494).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.79it/s]
100%|██████████| 26/26 [00:03<00:00,  6.60it/s]


Fold: 0 	Epoch: 4 	Training Loss: 2.566082 	Validation Loss: 2.562721
Fold: 0 	Epoch: 4 	Training accuracy: 0.247897 	Validation accuracy: 0.239736
Learning Rete : 0.0002
Validation loss decreased (2.569494 --> 2.562721).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.83it/s]
100%|██████████| 26/26 [00:03<00:00,  6.74it/s]


Fold: 0 	Epoch: 5 	Training Loss: 2.500366 	Validation Loss: 2.455637
Fold: 0 	Epoch: 5 	Training accuracy: 0.260817 	Validation accuracy: 0.269615
Learning Rete : 0.00025
Validation loss decreased (2.562721 --> 2.455637).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.83it/s]
100%|██████████| 26/26 [00:03<00:00,  6.62it/s]


Fold: 0 	Epoch: 6 	Training Loss: 2.408598 	Validation Loss: 2.383157
Fold: 0 	Epoch: 6 	Training accuracy: 0.289062 	Validation accuracy: 0.304543
Learning Rete : 0.00030000000000000003
Validation loss decreased (2.455637 --> 2.383157).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.80it/s]
100%|██████████| 26/26 [00:03<00:00,  6.64it/s]


Fold: 0 	Epoch: 7 	Training Loss: 2.393866 	Validation Loss: 2.360771
Fold: 0 	Epoch: 7 	Training accuracy: 0.291016 	Validation accuracy: 0.297596
Learning Rete : 0.00035
Validation loss decreased (2.383157 --> 2.360771).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.76it/s]
100%|██████████| 26/26 [00:03<00:00,  6.60it/s]


Fold: 0 	Epoch: 8 	Training Loss: 2.342572 	Validation Loss: 2.345756
Fold: 0 	Epoch: 8 	Training accuracy: 0.302734 	Validation accuracy: 0.317861
Learning Rete : 0.0004
Validation loss decreased (2.360771 --> 2.345756).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.77it/s]
100%|██████████| 26/26 [00:03<00:00,  6.50it/s]


Fold: 0 	Epoch: 9 	Training Loss: 2.282842 	Validation Loss: 2.245130
Fold: 0 	Epoch: 9 	Training accuracy: 0.318960 	Validation accuracy: 0.334351
Learning Rete : 0.00045000000000000004
Validation loss decreased (2.345756 --> 2.245130).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.83it/s]
100%|██████████| 26/26 [00:03<00:00,  6.78it/s]


Fold: 0 	Epoch: 10 	Training Loss: 2.268361 	Validation Loss: 2.277185
Fold: 0 	Epoch: 10 	Training accuracy: 0.321665 	Validation accuracy: 0.331683
Learning Rete : 0.0005


100%|██████████| 104/104 [00:11<00:00,  8.82it/s]
100%|██████████| 26/26 [00:03<00:00,  6.89it/s]


Fold: 0 	Epoch: 11 	Training Loss: 2.217800 	Validation Loss: 2.227381
Fold: 0 	Epoch: 11 	Training accuracy: 0.344651 	Validation accuracy: 0.352019
Learning Rete : 0.000499229333433282
Validation loss decreased (2.245130 --> 2.227381).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.83it/s]
100%|██████████| 26/26 [00:03<00:00,  6.77it/s]


Fold: 0 	Epoch: 12 	Training Loss: 2.152088 	Validation Loss: 2.197854
Fold: 0 	Epoch: 12 	Training accuracy: 0.359826 	Validation accuracy: 0.332716
Learning Rete : 0.0004969220851487844
Validation loss decreased (2.227381 --> 2.197854).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.81it/s]
100%|██████████| 26/26 [00:03<00:00,  6.57it/s]


Fold: 0 	Epoch: 13 	Training Loss: 2.074582 	Validation Loss: 2.122983
Fold: 0 	Epoch: 13 	Training accuracy: 0.380258 	Validation accuracy: 0.369784
Learning Rete : 0.0004930924800994192
Validation loss decreased (2.197854 --> 2.122983).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.78it/s]
100%|██████████| 26/26 [00:03<00:00,  6.97it/s]


Fold: 0 	Epoch: 14 	Training Loss: 2.081739 	Validation Loss: 2.094129
Fold: 0 	Epoch: 14 	Training accuracy: 0.381611 	Validation accuracy: 0.384712
Learning Rete : 0.0004877641290737884
Validation loss decreased (2.122983 --> 2.094129).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.79it/s]
100%|██████████| 26/26 [00:04<00:00,  6.50it/s]


Fold: 0 	Epoch: 15 	Training Loss: 2.016150 	Validation Loss: 2.042502
Fold: 0 	Epoch: 15 	Training accuracy: 0.395282 	Validation accuracy: 0.387043
Learning Rete : 0.0004809698831278217
Validation loss decreased (2.094129 --> 2.042502).  Saving model ...


100%|██████████| 104/104 [00:11<00:00,  8.82it/s]
100%|██████████| 26/26 [00:04<00:00,  6.34it/s]


Fold: 0 	Epoch: 16 	Training Loss: 1.942291 	Validation Loss: 2.035476
Fold: 0 	Epoch: 16 	Training accuracy: 0.414814 	Validation accuracy: 0.390721
Learning Rete : 0.00047275163104709196
Validation loss decreased (2.042502 --> 2.035476).  Saving model ...


  0%|          | 0/104 [00:01<?, ?it/s]


KeyboardInterrupt: 