In [1]:
import os
import gc
import cv2
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
from pytorch_toolbelt import losses as L

# Utils
from tqdm.auto import tqdm

# For Image Models
import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

## using gpu:1
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def seed_everything(seed=123):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything()

In [2]:
class Customize_Model(nn.Module):
    def __init__(self, model_name, cls):
        super().__init__()
        self.model = timm.create_model(model_name, 
                                       pretrained=True, 
                                       num_classes=cls, 
                                       drop_rate= CFG['drop_out'], 
                                       drop_path_rate= CFG['drop_path'])
        
    def forward(self, image):
        x = self.model(image)
        return x

In [3]:
def get_train_transform(img_size):
    return A.Compose([
        A.SmallestMaxSize(max_size=img_size, interpolation=3, p=1),
#         A.Resize(img_size, img_size),
        
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
#         A.HorizontalFlip(p=0.5),
#         A.VerticalFlip(p=0.5),
#         A.Blur(blur_limit= 3, p=0.3), 
        A.GaussNoise(p=0.3),
        A.OneOf([
                A.Cutout(max_h_size=10, max_w_size=16),
                A.CoarseDropout(max_holes=4),
            ], p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.15, rotate_limit= 0,
                                        interpolation=cv2.INTER_LINEAR, border_mode=0, p=0.7),
        ToTensorV2(p=1.0),
    ])


def get_test_transform(img_size):
    return A.Compose([
        A.SmallestMaxSize(max_size=img_size, interpolation=3, p=1),
#         A.Resize(img_size, img_size),
        ToTensorV2(p=1.0),
    ])

In [4]:
from toolbox.audio_aug import *

class Customize_Dataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.transforms = transforms
    
    def __getitem__(self, index):
        data = self.df.loc[index]
        img = cv2.imread(data['image_path'])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = data['label']
        
#         if np.random.rand()>0.5:
#             img= cutout(img, cut_width=20, holes=2)
            
#         if np.random.rand()>0.5:
#             while True:
#                 temp_df= self.df[self.df['label']==label]
#                 data = temp_df.sample(n=1, random_state=1).reset_index(drop=True).loc[0]
#                 img_2 = cv2.imread(data['image_path'])
#                 img_2 = cv2.cvtColor(img_2, cv2.COLOR_BGR2RGB)
#                 label_2 = data['label']
#                 if label==label_2: break
#             img= cutmix(img, img_2, cut_width=60)
        
        if self.transforms:
            img = self.transforms(image=img)["image"]
            
        return {
            'image': torch.tensor(img/255, dtype=torch.float32),
            'label': torch.tensor(label, dtype=torch.long),
        }
    
    def __len__(self):
        return len(self.df)

In [5]:
class Customize_loss(nn.Module):
    def  __init__(self):
        super().__init__()
        self.CrossEntropy= nn.CrossEntropyLoss()
        self.FocalCosineLoss= L.FocalCosineLoss()
        self.soft_ce= L.SoftCrossEntropyLoss(smooth_factor=0.25)
        self.bi_temp= L.BiTemperedLogisticLoss(t1=0.8, t2=1.2)
    
    def forward(self, y_pred, y_true):
        loss= 1.0 * self.soft_ce(y_pred, y_true)
        return loss

In [6]:
def train_epoch(dataloader, model, criterion, optimizer):
    scaler= amp.GradScaler()
    model.train()

    ep_loss= []
    for i, data in enumerate(tqdm(dataloader)):

        imgs= data['image'].to('cuda')
        labels= data['label'].to('cuda')
        
        with amp.autocast():
            preds= model(imgs)
            loss= criterion(preds, labels)
            ep_loss.append(loss.item())
            loss/= CFG['gradient_accumulation']
            scaler.scale(loss).backward()
            
            if (i+1) % CFG['gradient_accumulation']== 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                
    return np.mean(ep_loss)

In [7]:
from metrics import *

def valid_epoch(dataloader, model, criterion):
    model.eval()
    
    ep_loss= []
    all_pred= []
    all_label= []
    for i, data in enumerate(tqdm(dataloader)):

        imgs= data['image'].to('cuda')
        labels= data['label'].to('cuda')
        all_label.extend(labels.cpu().numpy())
        
        with torch.no_grad():
            preds= model(imgs)
            loss= criterion(preds, labels)
            ep_loss.append(loss.item())
        all_pred.extend(preds.cpu().softmax(dim=-1).numpy())
        
    
    ## caculate metrics
    all_label= np.array(all_label)
    all_pred= np.array(all_pred)
    
    acc= Accuracy(all_pred, all_label)
    print(f'accuracy: {acc}')
    recall= Mean_Recall(all_pred, all_label)
    print(f'mean_recall: {recall}')
    
    cmap= padded_cmap(all_pred, all_label)
    print(f'cmap: {cmap}')
    
    score= cmap
    return np.mean(ep_loss), score

# CFG

In [8]:
timm.list_models(pretrained=True)

['adv_inception_v3',
 'bat_resnext26ts.ch_in1k',
 'beit_base_patch16_224.in22k_ft_in22k',
 'beit_base_patch16_224.in22k_ft_in22k_in1k',
 'beit_base_patch16_384.in22k_ft_in22k_in1k',
 'beit_large_patch16_224.in22k_ft_in22k',
 'beit_large_patch16_224.in22k_ft_in22k_in1k',
 'beit_large_patch16_384.in22k_ft_in22k_in1k',
 'beit_large_patch16_512.in22k_ft_in22k_in1k',
 'beitv2_base_patch16_224.in1k_ft_in22k',
 'beitv2_base_patch16_224.in1k_ft_in22k_in1k',
 'beitv2_large_patch16_224.in1k_ft_in22k',
 'beitv2_large_patch16_224.in1k_ft_in22k_in1k',
 'botnet26t_256',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'coatnet_0_rw_224.sw_in1k',
 'coatnet_1_rw_224.sw_in1k',
 'coatnet_2_rw_224.sw_in12k',
 'coatnet_2_rw_224.sw_in12k_ft_in1k',
 'coatnet_3_rw_224.sw_in12k',
 'coatnet_bn_0_r

In [9]:
CFG= {
    'fold': 0,
    'epoch': 30,
    'model_name': 'tf_efficientnet_b0_ns',
    'finetune': False,
    
    'img_size': 192,
    'batch_size': 64,
    'gradient_accumulation': 1,
    'gradient_checkpoint': False,
    'drop_out': 0.3,
    'drop_path': 0.2,
    
    'lr': 3e-4,
    'weight_decay': 3e-4,
    
    'num_classes': 264,
    'load_model': False, 
    'save_model': './train_model'
}

if CFG['finetune']:
    CFG['lr']= 3e-5
    CFG['load_model']= f"./train_model/cv{CFG['fold']}_best.pth"
CFG

{'fold': 0,
 'epoch': 30,
 'model_name': 'tf_efficientnet_b0_ns',
 'finetune': False,
 'img_size': 192,
 'batch_size': 64,
 'gradient_accumulation': 1,
 'gradient_checkpoint': False,
 'drop_out': 0.3,
 'drop_path': 0.2,
 'lr': 0.0003,
 'weight_decay': 0.0003,
 'num_classes': 264,
 'load_model': False,
 'save_model': './train_model'}

# Prepare Dataset

In [10]:
df= pd.read_csv('Data/train.csv')

train_dataset= df[df['fold']!=CFG['fold']].reset_index(drop=True)
label= train_dataset['label'].unique().tolist()
n= 500
for i,g in enumerate(tqdm(label)):
    sample_df= train_dataset[train_dataset['label']==g]
    if len(sample_df)>=500: sample_df= sample_df.sample(n=n, replace=False, random_state=1).reset_index(drop=True)
    elif len(sample_df)<50: sample_df= sample_df.sample(n=50, replace=True, random_state=1).reset_index(drop=True)
    if i==0: new_df= sample_df
    else: new_df= pd.concat([new_df, sample_df], axis=0).reset_index(drop=True)
train_dataset= new_df

valid_dataset= df[df['fold']==CFG['fold']].reset_index(drop=True)
print(f'train dataset: {len(train_dataset)}')
print(f'valid dataset: {len(valid_dataset)}')

train_dataset= Customize_Dataset(train_dataset, get_train_transform(CFG['img_size']))
valid_dataset= Customize_Dataset(valid_dataset, get_test_transform(CFG['img_size']))

train_loader= DataLoader(train_dataset, batch_size= CFG['batch_size'], shuffle=True, num_workers=0)
valid_loader= DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=0)
df.head()

  0%|          | 0/264 [00:00<?, ?it/s]

train dataset: 56523
valid dataset: 26194




Unnamed: 0,image_path,label,label_name,group,fold
0,Data/train_img\abethr1\XC128013._0.png,0,abethr1,XC128013.,1.0
1,Data/train_img\abethr1\XC128013._1.png,0,abethr1,XC128013.,1.0
2,Data/train_img\abethr1\XC128013._2.png,0,abethr1,XC128013.,1.0
3,Data/train_img\abethr1\XC128013._3.png,0,abethr1,XC128013.,1.0
4,Data/train_img\abethr1\XC128013._4.png,0,abethr1,XC128013.,1.0


# Train

In [11]:
## create model
if CFG['load_model']:
    print(f"load_model: {CFG['load_model']}")
    model= torch.load(CFG['load_model'], map_location= 'cuda')
else:
    model= Customize_Model(CFG['model_name'], CFG['num_classes'])
    
if CFG['gradient_checkpoint']: 
    print('use gradient checkpoint')
    model.model.set_grad_checkpointing(enable=True)
model.to('cuda')
    
## hyperparameter
criterion= Customize_loss()
optimizer= optim.AdamW(model.parameters(), lr= CFG['lr'], weight_decay= CFG['weight_decay'])

## start training
best_score= 0
for ep in range(1, CFG['epoch']+1):
    print(f'\nep: {ep}')
    
    train_loss= train_epoch(train_loader, model, criterion, optimizer)
    valid_loss, valid_acc= valid_epoch(valid_loader, model, criterion)
    print(f'train loss: {round(train_loss, 5)}')
    print(f'valid loss: {round(valid_loss, 5)}, valid_acc: {round(valid_acc, 5)}')
    
    if valid_acc >= best_score:
        best_score= valid_acc
        torch.save(model, f"{CFG['save_model']}/cv{CFG['fold']}_best.pth")
        print(f'model save at score: {round(best_score, 5)}')
        
    ## save model every epoch
    torch.save(model, f"{CFG['save_model']}/cv{CFG['fold']}_ep{ep}.pth")

  **kwargs,



ep: 1


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.6005573795525693
mean_recall: 0.3333314525310996


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.5853688470207781
train loss: 4.14565
valid loss: 3.31685, valid_acc: 0.58537
model save at score: 0.58537

ep: 2


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.6480873482476903
mean_recall: 0.4142251088731771


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.6419565163734026
train loss: 3.15133
valid loss: 3.14901, valid_acc: 0.64196
model save at score: 0.64196

ep: 3


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.6731694281133084
mean_recall: 0.44236620963293266


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.6687915984421541
train loss: 2.82122
valid loss: 3.06846, valid_acc: 0.66879
model save at score: 0.66879

ep: 4


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.689470871191876
mean_recall: 0.46687680597356485


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.6816991919710995
train loss: 2.64465
valid loss: 3.01387, valid_acc: 0.6817
model save at score: 0.6817

ep: 5


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.697258914255173
mean_recall: 0.4935588571362261


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.6940577260236853
train loss: 2.52009
valid loss: 2.99631, valid_acc: 0.69406
model save at score: 0.69406

ep: 6


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.705695960907078
mean_recall: 0.49291235798554534


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.6985674626564384
train loss: 2.43745
valid loss: 2.96326, valid_acc: 0.69857
model save at score: 0.69857

ep: 7


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7063067878140032
mean_recall: 0.4964496045725966


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.700214513838065
train loss: 2.37285
valid loss: 2.95812, valid_acc: 0.70021
model save at score: 0.70021

ep: 8


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7100862793006032
mean_recall: 0.5034058982193718


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7035486223881015
train loss: 2.32644
valid loss: 2.94644, valid_acc: 0.70355
model save at score: 0.70355

ep: 9


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7190577994960679
mean_recall: 0.4944724445378629


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7025997599705162
train loss: 2.28929
valid loss: 2.94161, valid_acc: 0.7026

ep: 10


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7113842864778194
mean_recall: 0.4947194180572809


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7057854717320997
train loss: 2.25718
valid loss: 2.94819, valid_acc: 0.70579
model save at score: 0.70579

ep: 11


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7213865770787203
mean_recall: 0.5033412796316257


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.706677235005853
train loss: 2.23501
valid loss: 2.9191, valid_acc: 0.70668
model save at score: 0.70668

ep: 12


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7204321600366496
mean_recall: 0.5060486409478532


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7113226457809058
train loss: 2.21304
valid loss: 2.92466, valid_acc: 0.71132
model save at score: 0.71132

ep: 13


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7250897152019546
mean_recall: 0.5016858500642877


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7143153138435808
train loss: 2.19659
valid loss: 2.91287, valid_acc: 0.71432
model save at score: 0.71432

ep: 14


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7230281743910819
mean_recall: 0.5128825439135536


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7101049118209866
train loss: 2.17584
valid loss: 2.91476, valid_acc: 0.7101

ep: 15


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7242498282049324
mean_recall: 0.5116110278467244


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7139011818271851
train loss: 2.16522
valid loss: 2.91207, valid_acc: 0.7139

ep: 16


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7252424219286859
mean_recall: 0.5238783157678265


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7124721724074708
train loss: 2.15279
valid loss: 2.92055, valid_acc: 0.71247

ep: 17


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7222264640757425
mean_recall: 0.513180148980035


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7119838302331966
train loss: 2.14273
valid loss: 2.91804, valid_acc: 0.71198

ep: 18


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7286783232801405
mean_recall: 0.5181522221307768


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7197132895594951
train loss: 2.13632
valid loss: 2.90405, valid_acc: 0.71971
model save at score: 0.71971

ep: 19


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.72745666946629
mean_recall: 0.5146006798446375


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.712871010430474
train loss: 2.12581
valid loss: 2.90076, valid_acc: 0.71287

ep: 20


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7277239062380698
mean_recall: 0.5170117610974251


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7135727678791655
train loss: 2.1162
valid loss: 2.89618, valid_acc: 0.71357

ep: 21


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7304726273192335
mean_recall: 0.5130497094512046


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.713265280583564
train loss: 2.11301
valid loss: 2.90135, valid_acc: 0.71327

ep: 22


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.72745666946629
mean_recall: 0.5218141731571527


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7142806289471814
train loss: 2.10497
valid loss: 2.90833, valid_acc: 0.71428

ep: 23


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7213102237153547
mean_recall: 0.515044059085934


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7151726947357463
train loss: 2.09817
valid loss: 2.90287, valid_acc: 0.71517

ep: 24


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7318469878598153
mean_recall: 0.5197834486073867


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7135682651762785
train loss: 2.09487
valid loss: 2.87968, valid_acc: 0.71357

ep: 25


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7264258990608536
mean_recall: 0.5079943261333458


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7178646909614792
train loss: 2.08902
valid loss: 2.89027, valid_acc: 0.71786

ep: 26


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.725471482018783
mean_recall: 0.5177918101343779


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7150826316482704
train loss: 2.08283
valid loss: 2.89547, valid_acc: 0.71508

ep: 27


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7302053905474536
mean_recall: 0.5277086232005718


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7176641460713252
train loss: 2.07986
valid loss: 2.89477, valid_acc: 0.71766

ep: 28


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7297090936855769
mean_recall: 0.5262739989737745


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7185917376160803
train loss: 2.07778
valid loss: 2.88155, valid_acc: 0.71859

ep: 29


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7287164999618233
mean_recall: 0.5211781628192248


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7157626106323522
train loss: 2.07229
valid loss: 2.89145, valid_acc: 0.71576

ep: 30


  0%|          | 0/884 [00:00<?, ?it/s]



  0%|          | 0/819 [00:00<?, ?it/s]

accuracy: 0.7268840192410476
mean_recall: 0.5294815888795167


  _warn_prf(average, modifier, msg_start, len(result))


cmap: 0.7155882226846
train loss: 2.07012
valid loss: 2.90442, valid_acc: 0.71559
