In [1]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mjcdata[0m (use `wandb login --relogin` to force relogin)


In [1]:
import os
import time
import wandb
from tqdm.notebook import tqdm
from copy import deepcopy

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader

from sklearn.neighbors import NearestNeighbors

from utils import seed_everything
from data import ImageDataset, stratified_kfold, get_train_transforms, get_valid_transforms
from model import IndividualClassifier
from scheduler import CosineAnnealingWarmupRestarts

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class Config:
    """Paths and hyper-parameters for one training run, kept as class attributes."""
    # NOTE(review): all settings live on the class, not the instance, so an
    # instance's own `__dict__` is empty; read values through attribute access.
    checkpoint = '/root/dl_whale_classification/pths_individual'  # parent directory for saved checkpoints
    test = False  # not read by any cell visible in this notebook
    tta = False  # forwarded to valid_one_epoch(), which does not use it yet
    resume_epoch = 0  # first epoch index of the training loop
    iters_to_accumulate = 1  # forwarded to train_one_epoch() as the gradient-accumulation count
    resume_root = None  # checkpoint file to resume from; None skips resuming
    wandb_log = True  # enables wandb.init() and per-epoch wandb.log()
    model_name = 'tf_efficientnet_b3_ns'  # passed to IndividualClassifier; also the W&B project and checkpoint sub-directory
    fold = 0  # fold index passed to stratified_kfold()
    n_split = 5  # number of folds passed to stratified_kfold()
    seed = 2022  # seed passed to stratified_kfold()
    data_dir = '/root/beluga/'  # image root handed to ImageDataset (earlier value: root/data/train_images)
    root_dir = '.'  # not read by any cell visible in this notebook
    batch_size = 4  # batch size for every DataLoader and for the scheduler step counts
    lr = 1e-4  # AdamW learning rate and the scheduler's max_lr
    weight_decay = 0.0005  # AdamW weight decay
    epoch = 30  # exclusive upper bound of the epoch loop
    exp_name = 'beluga'  # used in the W&B run name and the final checkpoint file name
config = Config()

In [3]:
# Seed from the config so the global RNG seed and the fold-split seed cannot drift apart.
seed_everything(config.seed)

In [4]:
# Load the training metadata, drop hand-picked rows and normalise species names.
df = pd.read_csv('/root/data/train.csv')
# df = df.iloc[:len(df)//20]
# Row labels removed from the training set (the reason is not recorded in this notebook).
drop_index = [11604,15881,16782,21966,23306, 23626 ,24862,25895,29468,31831,35805,37176,40834,47480,48455,36710,47161]
df = df.drop(drop_index, axis=0).reset_index(drop=True)
# Merge misspelled / alias species names into their canonical label.
# The result is assigned back to the column: `df.species.replace(..., inplace=True)`
# mutates a temporary Series, which is deprecated and has no effect on `df`
# once pandas Copy-on-Write is active.
df['species'] = df['species'].replace({"globis": "short_finned_pilot_whale",
                                       "pilot_whale": "short_finned_pilot_whale",
                                       "kiler_whale": "killer_whale",
                                       "bottlenose_dolpin": "bottlenose_dolphin"})

In [5]:
# Keep only the beluga rows (other candidates: blue_whale, southern_right_whale).
is_beluga = df['species'] == 'beluga'
df = df.loc[is_beluga].reset_index(drop=True)

In [6]:
# Encode species names as consecutive integers, numbered in order of first appearance.
specie_unique = df['species'].unique()
specie_indices = range(len(specie_unique))
species2idx = dict(zip(specie_unique, specie_indices))
df['species'] = df['species'].map(species2idx)

# Same encoding for the individual ids, which are the classification targets.
individual_unique = df['individual_id'].unique()
individual_indices = range(len(individual_unique))
individual2idx = dict(zip(individual_unique, individual_indices))
df['individual_id'] = df['individual_id'].map(individual2idx)

In [7]:
# Report how many distinct classes each encoded column contains.
for label, column in (('num_species', 'species'), ('num_individual', 'individual_id')):
    print(label, len(df[column].unique()))

num_species 1
num_individual 1012


In [8]:
# Number of images of each row's individual, aligned with df's rows.
images_per_individual = df['individual_id'].map(df['individual_id'].value_counts())
# Individuals with exactly one image vs. individuals with several images.
df_single = df[images_per_individual == 1]
df_others = df[images_per_individual > 1]

In [9]:
# Both subsets use the same fold settings; only the target column differs.
fold_kwargs = dict(fold=config.fold, n_split=config.n_split, seed=config.seed, just_kfold=True)
train_single, valid_single = stratified_kfold(df=df_single, target_col='species', **fold_kwargs)
train_others, valid_others = stratified_kfold(df=df_others, target_col='individual_id', **fold_kwargs)

In [10]:
# Translate positions inside df_single / df_others into the index labels of df.
single_index = df_single.index.to_numpy()
others_index = df_others.index.to_numpy()

train_single_indices = single_index.take(train_single)
train_others_indices = others_index.take(train_others)
valid_single_indices = single_index.take(valid_single)
valid_others_indices = others_index.take(valid_others)

In [11]:
# Merge the single-image and multi-image parts of each split, in ascending order.
full_train_indices = np.concatenate([train_single_indices, train_others_indices])
full_train_indices.sort()
full_valid_indices = np.concatenate([valid_single_indices, valid_others_indices])
full_valid_indices.sort()

In [12]:
# Number of distinct individuals that appear in the training split.
len(df['individual_id'].iloc[full_train_indices].unique())

951

In [13]:
# Image file names and encoded individual ids as arrays aligned with df's rows.
fnames = df['image'].values
labels = df['individual_id'].values

In [14]:
# Training split: single-image and multi-image individuals together.
fnames_train = fnames[full_train_indices]
labels_train = labels[full_train_indices]
# Validation is kept separate for single-image and multi-image individuals.
fnames_valid_single = fnames[valid_single_indices]
labels_valid_single = labels[valid_single_indices]
fnames_valid_others = fnames[valid_others_indices]
labels_valid_others = labels[valid_others_indices]

In [15]:
# Transform pipelines for the training and validation datasets.
train_transforms, valid_transforms = get_train_transforms(), get_valid_transforms()

In [16]:
# All three datasets read images from the same root directory.
dataset_root = config.data_dir  # + '/train_detec_512_v5/'
train_dataset = ImageDataset(
    path=fnames_train, target=labels_train, transform=train_transforms, root=dataset_root,
)
valid_dataset_single = ImageDataset(
    path=fnames_valid_single, target=labels_valid_single, transform=valid_transforms, root=dataset_root,
)
valid_dataset_others = ImageDataset(
    path=fnames_valid_others, target=labels_valid_others, transform=valid_transforms, root=dataset_root,
)

In [17]:
# Settings shared by every loader.
loader_kwargs = dict(batch_size=config.batch_size, num_workers=8, pin_memory=True)

# Training batches are shuffled and the trailing partial batch is dropped.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
# Validation keeps dataset order and every sample.
valid_loader_single = DataLoader(valid_dataset_single, shuffle=False, **loader_kwargs)
valid_loader_others = DataLoader(valid_dataset_others, shuffle=False, **loader_kwargs)

In [18]:
# Size the classifier head from the label encoding instead of a hard-coded count,
# so it stays correct when the species filter or the dropped rows change.
model = IndividualClassifier(config.model_name, n_classes=len(individual2idx)).to(device)

In [19]:
# Cross-entropy over the individual-id logits.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# AdamW over every model parameter, configured from `config`.
optimizer = optim.AdamW(
    model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)

In [20]:
# Scheduler steps are counted in batches, so express the lengths via steps per epoch.
steps_per_epoch = len(train_dataset) // config.batch_size
cosine_annealing_scheduler_arg = {
    'first_cycle_steps': steps_per_epoch * config.epoch,
    'cycle_mult': 1.0,
    'max_lr': config.lr,
    'min_lr': 1e-07,
    'warmup_steps': steps_per_epoch * 3,  # warm up over epochs 0~3
    'gamma': 0.9,
}

In [21]:
# Learning-rate schedule built from the arguments prepared in the previous cell.
scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    **cosine_annealing_scheduler_arg,
)

In [22]:
!wandb online

W&B online, running your script from this directory will now sync to the cloud.


In [23]:
if config.wandb_log:
    # `config.__dict__` is empty because Config stores its settings as class
    # attributes, so nothing would be logged. Collect every public, non-callable
    # attribute instead so the run records the actual hyper-parameters.
    wandb_config = {
        key: getattr(config, key)
        for key in dir(config)
        if not key.startswith('_') and not callable(getattr(config, key))
    }
    run = wandb.init(config=wandb_config,
                project=config.model_name, 
                settings=wandb.Settings(start_method="thread"), 
                name=f"{config.exp_name}_fold{config.fold}",
                reinit=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjcdata[0m (use `wandb login --relogin` to force relogin)


In [24]:
# Optionally restore model, optimizer and scheduler state from a checkpoint.
if config.resume_root is not None:
    # map_location keeps the load working when the checkpoint was written on a
    # different device than the one in use now (e.g. CUDA -> CPU).
    check = torch.load(config.resume_root, map_location=device)
    model.load_state_dict(check['model'])
    optimizer.load_state_dict(check['optimizer'])
    scheduler.load_state_dict(check['scheduler'])
    print('loaded checkpoint')

# Train Functions

In [25]:
def train_one_epoch(model, optimizer, criterion, loader, scheduler, scaler=None, iters_to_accumulate=1):
    """Run one training pass over ``loader`` and collect per-sample outputs.

    Gradients are accumulated over ``iters_to_accumulate`` consecutive batches
    before each optimizer step, both with and without a gradient scaler. The
    loss used for the backward pass is divided by ``iters_to_accumulate`` so the
    accumulated gradient is an average rather than a sum; the reported loss is
    the undivided per-batch value. A trailing window shorter than
    ``iters_to_accumulate`` still triggers a step at the last batch.

    Args:
        model: module called as ``model(x, y)`` that returns ``(output, embed)``.
        optimizer: optimizer that updates the model parameters.
        criterion: loss called as ``criterion(output, y)``.
        loader: sized iterable yielding ``(x, y)`` batches.
        scheduler: learning-rate scheduler; stepped once per batch.
        scaler: optional ``torch.cuda.amp.GradScaler``. When given, the forward
            pass runs under ``torch.cuda.amp.autocast`` and the optimizer is
            stepped through the scaler.
        iters_to_accumulate: number of batches per optimizer step (values
            below 1 are treated as 1).

    Returns:
        ``(mean_loss, match, top_k_match, y_preds, y_trues, embeds)``: the mean
        per-batch loss, the count of top-1 hits, the count of top-2 hits, and
        lists with one numpy entry per sample for the model outputs, the
        targets and the embeddings.
    """
    model.train()
    iters_to_accumulate = max(1, int(iters_to_accumulate))
    use_amp = scaler is not None
    n_batches = len(loader)

    match = 0
    top_k_match = 0

    losses, y_trues, y_preds, embeds = [], [], [], []
    # Start from clean gradients in case a previous (interrupted) pass left some behind.
    optimizer.zero_grad()
    for i, (x, y) in enumerate(tqdm(loader)):
        x, y = x.to(device), y.to(device)

        if use_amp:
            with torch.cuda.amp.autocast():
                output, embed = model(x, y)
                loss = criterion(output, y)
            scaler.scale(loss / iters_to_accumulate).backward()
        else:
            output, embed = model(x, y)
            loss = criterion(output, y)
            (loss / iters_to_accumulate).backward()

        # Step once per accumulation window, and always on the final batch.
        if ((i + 1) % iters_to_accumulate == 0) or ((i + 1) == n_batches):
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        logits = output.detach().cpu()
        targets = y.detach().cpu()

        y_preds.extend(logits.numpy())
        y_trues.extend(targets.numpy())
        embeds.extend(embed.detach().cpu().numpy())

        match += (torch.argmax(logits, dim=-1) == targets).sum()
        top_k_match += (logits.topk(2, dim=-1).indices == targets[:, None]).sum()
        scheduler.step()

        losses.append(loss.detach().cpu().item())

    return np.mean(losses), match, top_k_match, y_preds, y_trues, embeds


In [26]:
def valid_one_epoch(model, criterion, loader, tta=False):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: module called as ``model(x, y)`` that returns ``(output, embed)``.
        criterion: loss called as ``criterion(output, y)``.
        loader: iterable yielding ``(x, y)`` batches.
        tta: accepted for interface compatibility; currently unused.

    Returns:
        ``(mean_loss, match, top_k_match, y_preds, y_trues, embeds)``: the mean
        per-batch loss, the count of top-1 hits, the count of top-2 hits, and
        lists with one numpy entry per sample for the model outputs, the
        targets and the embeddings.
    """
    model.eval()

    match = 0
    top_k_match = 0

    losses, y_trues, y_preds, embeds = [], [], [], []
    with torch.no_grad():
        # Wrap the loader itself (not an enumerate object) so tqdm can read its
        # length and show real progress instead of an indeterminate bar.
        for x, y in tqdm(loader):
            x, y = x.to(device), y.to(device)
            with torch.cuda.amp.autocast():
                output, embed = model(x, y)
                loss = criterion(output, y)

            logits = output.detach().cpu()
            targets = y.detach().cpu()

            match += (torch.argmax(logits, dim=-1) == targets).sum()
            top_k_match += (logits.topk(2, dim=-1).indices == targets[:, None]).sum()
            losses.append(loss.detach().cpu().item())

            y_preds.extend(logits.numpy())
            y_trues.extend(targets.numpy())
            embeds.extend(embed.detach().cpu().numpy())

    return np.mean(losses), match, top_k_match, y_preds, y_trues, embeds

# Train

In [None]:
# Training driver: each epoch trains once, fits a 1-nearest-neighbour index on the
# training embeddings, evaluates both validation loaders and keeps the checkpoint
# with the best nearest-neighbour accuracy on the multi-image ("others") split.
start = time.time()
checkpoint_dir = f"{config.checkpoint}/{config.model_name}"
best_path = f"{checkpoint_dir}/best.pt"
os.makedirs(checkpoint_dir, exist_ok=True)

grad_scaler = torch.cuda.amp.GradScaler()

best_model = None
best_model_dict = None  # stays None until this run actually saves a checkpoint
best_acc,  best_epoch = 0, 0
print('Start Training')
for i in range(config.resume_epoch, config.epoch):
    print(f"epoch: {i}")
    lr = scheduler.get_lr()[0]  # learning rate at the start of the epoch
    train_loss, train_acc, train_top_k, train_preds, train_trues, train_embeds = train_one_epoch(model, optimizer, criterion, train_loader, scheduler, grad_scaler, config.iters_to_accumulate)
    
    # Nearest-neighbour index (cosine distance) over the training-split embeddings.
    neigh = NearestNeighbors(n_neighbors=1, metric='cosine')
    neigh.fit(train_embeds)
    
    valid_loss1, valid_acc1, valid_top_k1, valid_preds1, valid_trues1, valid_embeds1 = valid_one_epoch(model, criterion, valid_loader_single, config.tta)
    valid_loss2, valid_acc2, valid_top_k2, valid_preds2, valid_trues2, valid_embeds2 = valid_one_epoch(model, criterion, valid_loader_others, config.tta)
    
    # The returned indices are positions within the training-split embeddings.
    distance1, indices1 = neigh.kneighbors(valid_embeds1, n_neighbors=1, return_distance=True)
    distance2, indices2 = neigh.kneighbors(valid_embeds2, n_neighbors=1, return_distance=True)
    
    # ---------------------------------------------------------------- single valid - distance distribution
    print('-'*20, 'single','-'*20)
    neigh_preds1 = np.array(train_trues)[indices1.reshape(-1)] # nearest-neighbour predictions
    neigh_trues1 = df.individual_id[valid_single_indices].values # ground truth
    print('min, mean, max :', distance1.min(), distance1.mean(), distance1.max())
    
    # ---------------------------------------------------------------- others valid - correct vs. wrong neighbours
    
    neigh_preds = np.array(train_trues)[indices2.reshape(-1)] # nearest-neighbour predictions
    neigh_trues = df.individual_id[valid_others_indices].values
    correct = (neigh_preds == neigh_trues) # mask of samples whose neighbour has the right id
    wrong = (neigh_preds != neigh_trues)
    # Fraction of "others" samples whose nearest training neighbour is the same individual.
    knn_acc = np.count_nonzero(correct) / len(correct)
    print('-'*20, 'others','-'*20)
    # print('min, mean, max (total):', distance2.min(), distance2.mean(), distance2.max())
    # print('min, mean, max (correct):', distance2[correct].min(), distance2[correct].mean(), distance2[correct].max())
    # print('min, mean, max (wrong):', distance2[wrong].min(), distance2[wrong].mean(), distance2[wrong].max())
    print('correct %', knn_acc)
    
    # ----------------------------------------------------------------
    
    # print(f"train loss {train_loss :.4f} acc {train_acc/len(train_loader.dataset) :.4f} topk {train_top_k/len(train_loader.dataset) :.4f}")
    # print(f"valid loss (single) {valid_loss1 :.4f} acc {valid_acc1/len(valid_loader_single.dataset) :.4f} topk {valid_top_k1/len(valid_loader_single.dataset) :.4f}")
    # print(f"valid loss (others) {valid_loss2 :.4f} acc {valid_acc2/len(valid_loader_others.dataset) :.4f} topk {valid_top_k2/len(valid_loader_others.dataset) :.4f}")
    print(f"lr {lr} time {time.time() - start :.2f}s")
    
    if best_acc < knn_acc:
        best_acc = knn_acc
        best_epoch = i
        print(f'best acc updated {best_acc}')
        best_model_dict = {
            'model': deepcopy(model.state_dict()),
            'optimizer': deepcopy(optimizer.state_dict()),
            'scheduler': deepcopy(scheduler.state_dict()),
        }
        torch.save(best_model_dict, best_path)

    if config.wandb_log:
        wandb_dict = {
            'train loss': train_loss,
            'train acc': train_acc/len(train_loader.dataset),
            # 'valid loss (single)': valid_loss1,
            # 'valid acc (single)': valid_acc1 / len(valid_loader_single.dataset),
            # 'valid topk (single)': valid_top_k1 / len(valid_loader_single.dataset),
            # 'valid loss (others)': valid_loss2,
            'valid acc (others)': valid_acc2 / len(valid_loader_others.dataset),
            # 'valid topk (others)': valid_top_k2 / len(valid_loader_others.dataset),
            'correct' : knn_acc,
            'learning rate': scheduler.get_lr()[0],
        }
        wandb.log(wandb_dict)
        
# Give the best checkpoint its final, descriptive name. Skip the rename when this
# run never saved one (no epoch beat the initial best_acc, or the loop ran zero
# epochs): renaming would otherwise raise or pick up a stale best.pt.
if best_model_dict is not None:
    os.rename(best_path, \
              f"{checkpoint_dir}/fold{config.fold}_epoch{best_epoch}_{best_acc:.04f}_{config.exp_name}.pt")
else:
    print('no checkpoint was saved during this run; nothing to rename')

Start Training
epoch: 0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.21721785391314086 0.3463243564484499 0.49496091500952766
-------------------- others --------------------
correct % 0.020876826722338204
lr 1e-07 time 979.80s
best acc updated 0.020876826722338204
epoch: 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.13943958466857476 0.24997444507086014 0.44361050491321397
-------------------- others --------------------
correct % 0.020876826722338204
lr 3.3400000000000005e-05 time 1963.17s
epoch: 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.11219853412335112 0.2698189324272803 0.45546960859349506
-------------------- others --------------------
correct % 0.027835768963117607
lr 6.670000000000001e-05 time 2948.80s
best acc updated 0.027835768963117607
epoch: 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.13092214056385465 0.30178612009506894 0.5260094127046825
-------------------- others --------------------
correct % 0.06680584551148225
lr 0.0001 time 3931.98s
best acc updated 0.06680584551148225
epoch: 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.18807616666569715 0.37074105548808084 0.5029366268465753
-------------------- others --------------------
correct % 0.13221990257480862
lr 9.966225596921006e-05 time 4914.29s
best acc updated 0.13221990257480862
epoch: 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.27608758542114675 0.4210995422524931 0.534654640807485
-------------------- others --------------------
correct % 0.1837160751565762
lr 9.865359128546221e-05 time 5895.50s
best acc updated 0.1837160751565762
epoch: 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.31404124070783357 0.4569952184281506 0.620718684463399
-------------------- others --------------------
correct % 0.22825330549756437
lr 9.698764640825613e-05 time 6875.93s
best acc updated 0.22825330549756437
epoch: 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.33090615735305096 0.4899561607809889 0.575675166472991
-------------------- others --------------------
correct % 0.2776617954070981
lr 9.468695038415445e-05 time 7861.26s
best acc updated 0.2776617954070981
epoch: 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.4058545678277903 0.5199324727256134 0.6135991980840922
-------------------- others --------------------
correct % 0.31454418928322897
lr 9.178261618007618e-05 time 8851.70s
best acc updated 0.31454418928322897
epoch: 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.4144450462682916 0.5455167624571859 0.6617099269505098
-------------------- others --------------------
correct % 0.33890048712595683
lr 8.831391993379295e-05 time 9837.39s
best acc updated 0.33890048712595683
epoch: 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.3882471608474022 0.552792417674659 0.6457746913040645
-------------------- others --------------------
correct % 0.3813500347947112
lr 8.432776981154325e-05 time 10824.98s
best acc updated 0.3813500347947112
epoch: 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.38758355169398484 0.5653046434657321 0.6640487186160574
-------------------- others --------------------
correct % 0.38552540013917885
lr 7.987807165555417e-05 time 11809.09s
best acc updated 0.38552540013917885
epoch: 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.43652257906202074 0.5777220198954729 0.6822904086709997
-------------------- others --------------------
correct % 0.4126652748782185
lr 7.502500000000001e-05 time 12791.70s
best acc updated 0.4126652748782185
epoch: 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


-------------------- single --------------------
min, mean, max : 0.3920670741481532 0.5877779187718777 0.6864427986761823
-------------------- others --------------------
correct % 0.42240779401530965
lr 6.983418431365589e-05 time 13775.91s
best acc updated 0.42240779401530965
epoch: 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1488.0), HTML(value='')))