In [1]:
import sys
sys.path.append('./Humpback-Whale-Identification-1st/')

from models import *
from utils import *
import os
import gc
import cv2
import math
import copy
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold

# For Image Models
import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

import warnings
warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
CONFIG = {"seed": 2022,
          "epochs": 20,
          "img_size": 128,
          "model_name": "tf_efficientnet_b0",
          "num_classes": 15587,
          "train_batch_size": 128,
          "valid_batch_size": 128,
          "learning_rate": 1e-4,
          "scheduler": 'CosineAnnealingLR',
          "min_lr": 1e-6,
          "T_max": 500,
          "weight_decay": 1e-6,
          "n_fold": 5,
          "n_accumulate": 1,
          "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
          # ArcFace Hyperparameters
          "s": 30.0, 
          "m": 0.50,
          "ls_eps": 0.0,
          "easy_margin": False
          }

In [3]:
def set_seed(seed=42):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    
set_seed(CONFIG['seed'])

In [4]:
ROOT_DIR = '../data/'
TRAIN_DIR = '../data/train_images-128-128/'
TEST_DIR = '../data/test_images-128-128/'

def get_train_file_path(id):
    return f"{TRAIN_DIR}/{id}"

df = pd.read_csv(f"{ROOT_DIR}/train.csv")
df['file_path'] = df['image'].apply(get_train_file_path)

encoder = LabelEncoder()
df['individual_id_map'] = encoder.fit_transform(df['individual_id'])

with open("le.pkl", "wb") as fp:
    joblib.dump(encoder, fp)
# encoder.inverse_transform([1636])

skf = StratifiedKFold(n_splits=CONFIG['n_fold'])
for fold, ( _, val_) in enumerate(skf.split(X=df, y=df.individual_id_map)):
      df.loc[val_ , "kfold"] = fold

In [5]:
class HappyWhaleDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.file_names = df['file_path'].values
        self.labels = df['individual_id_map'].values
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        img_path = self.file_names[index]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = self.labels[index]
        
        if self.transforms:
            img = self.transforms(image=img)["image"]
            
        return {
            'image': img,
            'label': torch.tensor(label, dtype=torch.long)
        }

In [6]:
data_transforms = {
    "train": A.Compose([
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=30, p=0.5),
        A.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
                max_pixel_value=255.0, 
                p=1.0
            ),
        ToTensorV2()], p=1.),
    
    "valid": A.Compose([
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
        A.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
                max_pixel_value=255.0, 
                p=1.0
            ),
        ToTensorV2()], p=1.)
}

In [7]:
def prepare_loaders(df, fold):
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)
    
    train_dataset = HappyWhaleDataset(df_train, transforms=data_transforms["train"])
    valid_dataset = HappyWhaleDataset(df_valid, transforms=data_transforms["valid"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG['train_batch_size'], 
                              num_workers=2, shuffle=True, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'], 
                              num_workers=2, shuffle=False, pin_memory=True)
    
    return train_loader, valid_loader

train_loader, valid_loader = prepare_loaders(df, fold=0)

In [8]:
num_classes = 15587
model = model_whale(num_classes=num_classes, inchannels=3, model_name='senet154').cuda()

In [9]:
def fetch_scheduler(optimizer):
    if CONFIG['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,T_max=CONFIG['T_max'], 
                                                   eta_min=CONFIG['min_lr'])
    elif CONFIG['scheduler'] == 'CosineAnnealingWarmRestarts':
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=CONFIG['T_0'], 
                                                             eta_min=CONFIG['min_lr'])
    elif CONFIG['scheduler'] == None:
        return None
        
    return scheduler

optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'], 
                       weight_decay=CONFIG['weight_decay'])
scheduler = fetch_scheduler(optimizer)

In [10]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    model.train()
    dataset_size = 0
    running_loss = 0.0
    sum = 0
    train_loss_sum = 0
    train_top1_sum = 0
    train_map5_sum = 0
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        optimizer.zero_grad()
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)
        batch_size = images.size(0)
        
        global_feat, local_feat, results = model(images)
        model.getLoss(global_feat, local_feat, results, labels)
        loss = model.loss
    
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)
        optimizer.step()
        results = torch.cat([torch.sigmoid(results), torch.ones_like(results[:, :1]).float().cuda() * 0.5], 1)
        top1_batch = accuracy(results, labels, topk=(1,))[0]
        map5_batch = mapk(labels, results, k=5)
        loss = loss.data.cpu().numpy()
        sum += 1
        train_loss_sum += loss
        train_top1_sum += top1_batch
        train_map5_sum += map5_batch
        
        bar.set_postfix(Epoch=epoch, Train_Loss=loss,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()
    
    return loss

In [11]:
# @torch.inference_mode()
def valid_one_epoch(model, dataloader, device, epoch):
    model.eval()
    with torch.no_grad():
        dataset_size = 0
        running_loss = 0.0
        valid_loss, index_valid= 0, 0
        all_results = []
        all_labels = []

        bar = tqdm(enumerate(dataloader), total=len(dataloader))
        for step, data in bar:        
            images = data['image'].to(device, dtype=torch.float)
            labels = data['label'].to(device, dtype=torch.long)
            global_feat, local_feat, results = model(images)
            model.getLoss(global_feat, local_feat, results, labels)

            results = torch.sigmoid(results)

            all_results.append(results)
            all_labels.append(labels)

            b = len(labels)
            valid_loss += model.loss.data.cpu().numpy() * b
            index_valid += b
        all_results = torch.cat(all_results, 0)
        all_labels = torch.cat(all_labels, 0)

        map5s, top1s, top5s = [], [], []
        if 1:
            ts = np.linspace(0.1, 0.9, 9)
            for t in ts:
                results_t = torch.cat([all_results, torch.ones_like(all_results[:, :1]).float().cuda() * t], 1)
                top1_, top5_ = accuracy(results_t, all_labels)
                map5_ = mapk(all_labels, results_t, k=5)
                map5s.append(map5_)
                top1s.append(top1_)
                top5s.append(top5_)
            map5 = max(map5s)
            i_max = map5s.index(map5)
            top1 = top1s[i_max]
            top5 = top5s[i_max]
            best_t = ts[i_max]

        valid_loss /= index_valid

    return valid_loss, top1, top5, map5, best_t

In [12]:
def run_training(model, optimizer, scheduler, device, num_epochs):
    # To automatically log gradients
#     wandb.watch(model, log_freq=100)
    
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))
    
    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_epoch_loss = np.inf
    history = defaultdict(list)
    
    for epoch in range(1, num_epochs + 1): 
        gc.collect()
        train_epoch_loss = train_one_epoch(model, optimizer, scheduler, 
                                           dataloader=train_loader, 
                                           device=CONFIG['device'], epoch=epoch)
        
        val_epoch_loss, top1, top5, map5, best_t = valid_one_epoch(model, valid_loader, device=CONFIG['device'], 
                                         epoch=epoch)
    
        history['Train Loss'].append(train_epoch_loss)
        history['Valid Loss'].append(val_epoch_loss)
        
#         Log the metrics
#         wandb.log({"Train Loss": train_epoch_loss})
#         wandb.log({"Valid Loss": val_epoch_loss})
        print(f"{b_}Validation Loss Improved ({best_epoch_loss} ---> {val_epoch_loss}) map5：{map5}")
        if val_epoch_loss <= best_epoch_loss:
            best_epoch_loss = val_epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = f"model_{epoch}.pth"
            torch.save(model.state_dict(), './weight/'+PATH)
            torch.save({
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'best_t':best_t,
                }, f'./weight/optimizer_{epoch}.pth')
            # Save a model file from the current directory
            print(f"Model Saved{sr_}")
            
        print()
    
    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Loss: {:.4f}".format(best_epoch_loss))
    
    # load best model weights
    model.load_state_dict(best_model_wts)
    
    return model, history

In [None]:
model, history = run_training(model, optimizer, scheduler,
                              device=CONFIG['device'],
                              num_epochs=CONFIG['epochs'])

[INFO] Using GPU: NVIDIA GeForce RTX 3090



100%|██████████| 39/39 [01:02<00:00,  1.60s/it, Epoch=1, LR=0.0001, Train_Loss=1.4263262]
100%|██████████| 8/8 [00:02<00:00,  3.64it/s]


[34mValidation Loss Improved (inf ---> 1.4835968589782715) map5：0.0021666666666666666
Model Saved[0m



100%|██████████| 39/39 [01:03<00:00,  1.64s/it, Epoch=2, LR=0.0001, Train_Loss=1.4158506]
100%|██████████| 8/8 [00:02<00:00,  3.65it/s]

[34mValidation Loss Improved (1.4835968589782715 ---> 1.499460768699646) map5：0.0013333333333333333




100%|██████████| 39/39 [01:02<00:00,  1.59s/it, Epoch=3, LR=0.0001, Train_Loss=1.4493095]
100%|██████████| 8/8 [00:02<00:00,  3.65it/s]


[34mValidation Loss Improved (1.4835968589782715 ---> 1.4983874588012696) map5：0.005666666666666667



100%|██████████| 39/39 [01:02<00:00,  1.61s/it, Epoch=4, LR=0.0001, Train_Loss=1.4516565]
100%|██████████| 8/8 [00:02<00:00,  4.05it/s]