In [1]:
import os
import time
import wandb
from tqdm.notebook import tqdm
from copy import deepcopy

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader

from utils import seed_everything
from data import ImageDataset, stratified_kfold, get_train_transforms, get_valid_transforms
from model import SpecieClassifier
from scheduler import CosineAnnealingWarmupRestarts

# CUDA reads these variables only when it is first initialised, so they must be set
# before anything queries the CUDA runtime; torch.cuda.is_available() below may do so.
# (Setting them afterwards, as before, can silently have no effect.)
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # synchronous kernel launches: clearer CUDA error traces, slower training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class Config:
    """Experiment hyper-parameters, read as attributes of the module-level `config` instance.

    Every value is a class attribute, so `config.__dict__` is empty; enumerate
    `vars(Config)` to list the settings.
    """

    # --- experiment / bookkeeping ---
    exp_name = 'test'
    model_name = 'tf_efficientnet_b4_ns'
    checkpoint = '/root/dl_whale_classification/pths_specie'
    wandb_log = True

    # --- data & cross-validation ---
    data_dir = '/root/data2/'  # root/data/train_images
    root_dir = '.'
    n_split = 5
    fold = 0
    seed = 2022

    # --- optimisation ---
    epoch = 20
    batch_size = 16
    lr = 1e-4
    weight_decay = 0.0005
    iters_to_accumulate = 1

    # --- resume / evaluation switches ---
    resume_root = None
    resume_epoch = 0
    test = False
    tta = False


config = Config()

In [3]:
# Seed from the config so the fold split (which also uses config.seed) and training share one source of truth.
seed_everything(config.seed)

In [4]:
# Load the training metadata and normalise misspelled / duplicate species labels.
# The result is assigned back instead of calling `replace(..., inplace=True)` on
# `df.species`: that is chained assignment, which pandas' Copy-on-Write mode does
# not propagate to `df`.
df = pd.read_csv('/root/data/train.csv')
df['species'] = df['species'].replace({
    "globis": "short_finned_pilot_whale",
    "pilot_whale": "short_finned_pilot_whale",
    "kiler_whale": "killer_whale",
    "bottlenose_dolpin": "bottlenose_dolphin",
})

In [5]:
# Encode species and individual ids as contiguous integers, numbered in order of first appearance.
specie_unique = df['species'].unique()
specie_indices = range(len(specie_unique))
species2idx = dict(zip(specie_unique, specie_indices))
df['species'] = df['species'].map(species2idx)

individual_unique = df['individual_id'].unique()
individual_indices = range(len(individual_unique))
individual2idx = dict(zip(individual_unique, individual_indices))
df['individual_id'] = df['individual_id'].map(individual2idx)

In [6]:
# Sanity check: size of each label space after encoding.
for title, column in (('num_species', 'species'), ('num_individual', 'individual_id')):
    print(title, len(df[column].unique()))

num_species 26
num_individual 15587


In [7]:
# Partition rows by how many images their individual has: exactly one vs. more than one.
sightings_per_row = df['individual_id'].map(df['individual_id'].value_counts())
df_single = df[sightings_per_row == 1]
df_others = df[sightings_per_row > 1]

In [8]:
# Both partitions are split with identical fold settings, stratified on species.
split_kwargs = dict(fold=config.fold, n_split=config.n_split, seed=config.seed, target_col='species')
train_single, valid_single = stratified_kfold(df=df_single, **split_kwargs)
train_others, valid_others = stratified_kfold(df=df_others, **split_kwargs)



In [9]:
# Translate the per-partition fold indices (presumably positional — TODO confirm in
# stratified_kfold) into index labels of `df`; df has read_csv's default RangeIndex,
# so those labels double as row positions.
single_index = df_single.index.to_numpy()
others_index = df_others.index.to_numpy()
train_single_indices = np.take(single_index, train_single)
valid_single_indices = np.take(single_index, valid_single)
train_others_indices = np.take(others_index, train_others)
valid_others_indices = np.take(others_index, valid_others)

In [10]:
# Training uses the train folds of both partitions, in ascending row order.
full_train_indices = np.sort(np.concatenate([train_single_indices, train_others_indices]))

In [11]:
# Image file names and encoded species labels as arrays, indexable by row position.
fnames = df['image'].values
labels = df['species'].values

In [12]:
# One training set and two validation sets (single-sighting vs. repeated individuals).
fnames_train = fnames[full_train_indices]
labels_train = labels[full_train_indices]

fnames_valid_single = fnames[valid_single_indices]
labels_valid_single = labels[valid_single_indices]

fnames_valid_others = fnames[valid_others_indices]
labels_valid_others = labels[valid_others_indices]

In [13]:
# Augmenting pipeline for training, deterministic pipeline for validation.
train_transforms, valid_transforms = get_train_transforms(), get_valid_transforms()

In [14]:
# All splits read the same pre-cropped image directory.
image_root = config.data_dir + '/train_detec_512_v3/'

train_dataset = ImageDataset(path=fnames_train, target=labels_train,
                             transform=train_transforms, root=image_root)
valid_dataset_single = ImageDataset(path=fnames_valid_single, target=labels_valid_single,
                                    transform=valid_transforms, root=image_root)
valid_dataset_others = ImageDataset(path=fnames_valid_others, target=labels_valid_others,
                                    transform=valid_transforms, root=image_root)

In [15]:
# Shared worker settings; validation uses doubled batches since it needs no gradients.
loader_kwargs = dict(num_workers=8, pin_memory=True)
valid_batch_size = config.batch_size * 2

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                          drop_last=True, **loader_kwargs)
valid_loader_single = DataLoader(valid_dataset_single, batch_size=valid_batch_size,
                                 shuffle=False, **loader_kwargs)
valid_loader_others = DataLoader(valid_dataset_others, batch_size=valid_batch_size,
                                 shuffle=False, **loader_kwargs)

In [16]:
# Species classifier with the configured backbone, moved to the training device.
model = SpecieClassifier(config.model_name)
model = model.to(device)

In [17]:
# Plain cross-entropy over species logits; AdamW with decoupled weight decay.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)
optimizer = optim.AdamW(params=model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

In [18]:
# The scheduler is stepped once per training batch, so all lengths are in batches.
# With drop_last=True this equals len(train_loader).
steps_per_epoch = len(train_dataset) // config.batch_size

cosine_annealing_scheduler_arg = dict(
    first_cycle_steps=steps_per_epoch * config.epoch,  # one cosine cycle spans the whole run
    cycle_mult=1.0,
    max_lr=config.lr,
    min_lr=1e-07,
    warmup_steps=steps_per_epoch * 3,  # warm up during epochs 0-3
    gamma=0.9,
)

In [19]:
# Cosine schedule with linear warm-up; train_one_epoch calls scheduler.step() after every batch.
scheduler = CosineAnnealingWarmupRestarts(optimizer, **cosine_annealing_scheduler_arg)

In [20]:
# Log every hyper-parameter to wandb. `config.__dict__` is empty because all values
# are class attributes of Config, so collect them from the class (skipping dunders
# and methods) and then overlay anything assigned on the instance itself.
if config.wandb_log:
    wandb_config = {
        key: value
        for key, value in vars(type(config)).items()
        if not key.startswith('__') and not callable(value)
    }
    wandb_config.update(vars(config))
    run = wandb.init(config=wandb_config,
                     project=config.model_name,
                     settings=wandb.Settings(start_method="thread"),
                     name=f"{config.exp_name}_fold{config.fold}",
                     reinit=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjcdata[0m (use `wandb login --relogin` to force relogin)


In [21]:
# Loss scaler for mixed-precision training (passed to train_one_epoch).
grad_scaler = torch.cuda.amp.GradScaler()

# Best-checkpoint bookkeeping, updated by the training loop.
best_model = None
best_acc = 0
best_epoch = 0

In [22]:
# Optionally resume from a checkpoint written by the training loop
# (keys: 'model', 'optimizer', 'scheduler').
# map_location lets a checkpoint saved on a GPU load on a different or GPU-less device.
# NOTE: the epoch is not stored in the checkpoint; set config.resume_epoch by hand.
if config.resume_root is not None:
    check = torch.load(config.resume_root, map_location=device)
    model.load_state_dict(check['model'])
    optimizer.load_state_dict(check['optimizer'])
    scheduler.load_state_dict(check['scheduler'])
    print('loaded checkpoint')

# Train Functions

In [23]:
def train_one_epoch(model, optimizer, criterion, loader, scheduler, scaler=None, iters_to_accumulate=1):
    """Run one training epoch and return loss/accuracy statistics.

    The optimizer steps every `iters_to_accumulate` batches and on the final batch.
    The scheduler steps once per batch.

    Args:
        model: network called as `model(x)`, returning class logits.
        optimizer: optimizer over `model`'s parameters.
        criterion: loss taking (logits, targets).
        loader: iterable of (x, y) batches; must support len().
        scheduler: learning-rate scheduler, stepped after every batch.
        scaler: optional `torch.cuda.amp.GradScaler`. When given, the forward pass
            runs under autocast and the backward pass uses loss scaling.
        iters_to_accumulate: number of batches whose gradients are accumulated per
            optimizer step.

    Returns:
        (mean per-batch loss, top-1 correct count, top-2 correct count).
        The counts are raw totals, not fractions.
    """
    model.train()

    match = 0
    top_k_match = 0
    losses = []
    num_batches = len(loader)

    # Drop gradients left over from earlier work (e.g. an interrupted epoch);
    # otherwise the first optimizer step would include them.
    optimizer.zero_grad()
    for i, (x, y) in enumerate(tqdm(loader)):
        x, y = x.to(device), y.to(device)
        # Step on every `iters_to_accumulate`-th batch and on the last batch, so no gradient is left pending.
        should_step = ((i + 1) % iters_to_accumulate == 0) or ((i + 1) == num_batches)

        if scaler is not None:
            with torch.cuda.amp.autocast():
                output = model(x)
                loss = criterion(output, y)
            # Dividing makes the accumulated gradient an average over the accumulated
            # batches rather than a sum, which would silently scale the effective LR.
            scaler.scale(loss / iters_to_accumulate).backward()
            if should_step:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        else:
            # Inputs only, matching the AMP branch and valid_one_epoch.
            output = model(x)
            loss = criterion(output, y)
            (loss / iters_to_accumulate).backward()
            if should_step:
                optimizer.step()
                optimizer.zero_grad()

        match += (torch.argmax(output.detach().cpu(), dim=-1) == y.detach().cpu()).sum()
        top_k_match += (output.detach().topk(2, dim=-1).indices.cpu() == y.detach().cpu()[:, None]).sum()
        scheduler.step()

        # Log the undivided loss so the value is comparable regardless of accumulation.
        losses.append(loss.detach().item())

    return np.mean(losses), match, top_k_match


In [24]:
def valid_one_epoch(model, criterion, loader, tta=False):
    """Evaluate `model` on `loader` under autocast with gradients disabled.

    Args:
        model: network called as `model(x)`, returning class logits.
        criterion: loss taking (logits, targets).
        loader: iterable of (x, y) batches; must support len().
        tta: accepted for interface compatibility; currently ignored (no
            test-time augmentation is applied).

    Returns:
        (mean per-batch loss, top-1 correct count, top-2 correct count,
         predicted labels as np.ndarray, true labels as np.ndarray).
    """
    model.eval()

    match = 0
    top_k_match = 0
    losses, y_trues, y_preds = [], [], []
    with torch.no_grad():
        # Wrap the loader itself (not enumerate(loader)) so tqdm can read len(loader)
        # and render a progress bar with a known total.
        for x, y in tqdm(loader):
            x, y = x.to(device), y.to(device)
            with torch.cuda.amp.autocast():
                output = model(x)
                loss = criterion(output, y)

            output_cpu = output.detach().cpu()
            y_cpu = y.detach().cpu()
            preds = torch.argmax(output_cpu, dim=-1)

            match += (preds == y_cpu).sum()
            top_k_match += (output.detach().topk(2, dim=-1).indices.cpu() == y_cpu[:, None]).sum()
            losses.append(loss.detach().item())

            y_preds.extend(preds.numpy())
            y_trues.extend(y_cpu.numpy())

    return np.mean(losses), match, top_k_match, np.array(y_preds), np.array(y_trues)

# Train

In [25]:
def _fraction(count, total):
    """Return count / total as a float, or nan when total is 0."""
    return float(count) / total if total else float('nan')


def _print_per_class_accuracy(title, preds, trues, num_classes, per_line=4):
    """Print per-class top-1 accuracy as `cls-#support : acc`, `per_line` classes per row.

    Classes with no samples print `nan` instead of triggering a division by zero.
    """
    print("-" * 20, title, "-" * 20)
    for cls in range(num_classes):
        is_cls = trues == cls
        support = int(is_cls.sum())
        correct = int(((preds == cls) & is_cls).sum())
        acc = _fraction(correct, support)
        print(f"{cls:02d}-#{support:04d} : {acc:.4f}", end=' / ')
        if cls % per_line == per_line - 1 or cls == num_classes - 1:
            print()


start = time.time()
ckpt_dir = f"{config.checkpoint}/{config.model_name}"
best_path = f"{ckpt_dir}/best.pt"
os.makedirs(ckpt_dir, exist_ok=True)

num_classes = len(species2idx)  # previously hard-coded as 26
n_valid_single = len(valid_loader_single.dataset)
n_valid_others = len(valid_loader_others.dataset)
# train_one_epoch returns raw hit counts. With drop_last=True the incomplete final
# batch is never seen, so normalise by the samples actually processed.
n_train_seen = (len(train_loader) * train_loader.batch_size
                if train_loader.drop_last else len(train_loader.dataset))
saved_best = False  # whether *this* run wrote best.pt

print('Start Training')
for epo in range(config.resume_epoch, config.epoch):
    print(f"epoch: {epo}")
    lr = scheduler.get_lr()[0]
    train_loss, train_acc, train_top_k = train_one_epoch(model, optimizer, criterion, train_loader, scheduler, grad_scaler, config.iters_to_accumulate)
    valid_loss1, valid_acc1, valid_top_k1, valid_preds1, valid_trues1 = valid_one_epoch(model, criterion, valid_loader_single, config.tta)
    valid_loss2, valid_acc2, valid_top_k2, valid_preds2, valid_trues2 = valid_one_epoch(model, criterion, valid_loader_others, config.tta)

    # Convert raw hit counts to fractions of the samples evaluated.
    train_acc_frac = _fraction(train_acc, n_train_seen)
    train_top_k_frac = _fraction(train_top_k, n_train_seen)
    valid_acc1_frac = _fraction(valid_acc1, n_valid_single)
    valid_top_k1_frac = _fraction(valid_top_k1, n_valid_single)
    valid_acc2_frac = _fraction(valid_acc2, n_valid_others)
    valid_top_k2_frac = _fraction(valid_top_k2, n_valid_others)

    print(f"train loss {train_loss:.4f} acc {train_acc_frac:.4f} topk {train_top_k_frac:.4f}")
    print(f"valid loss (single) {valid_loss1:.4f} acc {valid_acc1_frac:.4f} topk {valid_top_k1_frac:.4f}")
    print(f"valid loss (others) {valid_loss2:.4f} acc {valid_acc2_frac:.4f} topk {valid_top_k2_frac:.4f}")
    print(f"lr {lr} time {time.time() - start:.2f}s")

    _print_per_class_accuracy("class (single)", valid_preds1, valid_trues1, num_classes)
    _print_per_class_accuracy("class (others)", valid_preds2, valid_trues2, num_classes)

    # Model selection uses accuracy on the validation split of repeated individuals.
    if best_acc < valid_acc2_frac:
        best_acc = valid_acc2_frac
        best_epoch = epo
        print(f'best acc updated {best_acc}')
        # deepcopy keeps best_model_dict an in-memory snapshot that later epochs do not mutate.
        best_model_dict = {
            'model': deepcopy(model.state_dict()),
            'optimizer': deepcopy(optimizer.state_dict()),
            'scheduler': deepcopy(scheduler.state_dict()),
        }
        torch.save(best_model_dict, best_path)
        saved_best = True

    if config.wandb_log:
        wandb.log({
            'train loss': train_loss,
            'train acc': train_acc_frac,
            'train topk': train_top_k_frac,
            'valid loss (single)': valid_loss1,
            'valid acc (single)': valid_acc1_frac,
            'valid topk (single)': valid_top_k1_frac,
            'valid loss (others)': valid_loss2,
            'valid acc (others)': valid_acc2_frac,
            'valid topk (others)': valid_top_k2_frac,
            'learning rate': scheduler.get_lr()[0],
        })

# Rename only a checkpoint written by this run. Otherwise os.rename would raise when
# no best.pt exists, or would mislabel a stale best.pt from an earlier run.
if saved_best:
    os.rename(best_path, f"{ckpt_dir}/fold{config.fold}_epoch{best_epoch}_{best_acc:.04f}.pt")
else:
    print("no checkpoint saved in this run; skipping rename")

Start Training
epoch: 0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




KeyboardInterrupt: 

In [26]:
# Scratch inspection: boolean mask of single-sighting validation samples whose true label is class 0.
np.equal(valid_trues1, 0)

array([False, False, False, ..., False, False, False])