In [24]:
import os
import time
import wandb
from tqdm.notebook import tqdm
from copy import deepcopy

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader

from sklearn.neighbors import NearestNeighbors

from utils import seed_everything
from data import ImageDataset, stratified_kfold, get_train_transforms, get_valid_transforms
from model import IndividualClassifier
from scheduler import CosineAnnealingWarmupRestarts

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [25]:
class Config:
    checkpoint = '/root/dl_whale_classification/pths_individual'
    test = False
    tta = False
    resume_epoch = 0
    iters_to_accumulate = 1
    resume_root = None
    wandb_log = True
    model_name = 'tf_efficientnet_b6_ns'
    fold = 0
    n_split = 5
    seed = 2022
    data_dir = '/root/data2/' # root/data/train_images
    root_dir = '.'
    batch_size = 8
    lr = 1e-4
    weight_decay = 0.0005
    epoch = 20
    exp_name = 'test'
config = Config()

In [26]:
seed_everything(2022)

In [27]:
df = pd.read_csv('/root/data/train.csv')
df = df.iloc[:len(df)//20]
df.species.replace({"globis": "short_finned_pilot_whale",
                  "pilot_whale": "short_finned_pilot_whale",
                  "kiler_whale": "killer_whale",
                  "bottlenose_dolpin": "bottlenose_dolphin"}, inplace=True)

In [28]:
specie_unique = df.species.unique()
specie_indices = range(len(specie_unique))
species2idx = {k: v for k, v in zip(specie_unique, specie_indices)}
df.species = df.species.map(species2idx)

individual_unique = df.individual_id.unique()
individual_indices = range(len(individual_unique))
individual2idx = {k : v for k, v in zip(individual_unique, individual_indices)}
df.individual_id = df.individual_id.map(individual2idx)

In [29]:
print('num_species', len(df.species.unique()))
print('num_individual', len(df.individual_id.unique()))
      

num_species 25
num_individual 1793


In [30]:
df_single = df[df['individual_id'].map(df['individual_id'].value_counts()) == 1]
df_others = df[df['individual_id'].map(df['individual_id'].value_counts()) > 1]

In [31]:
train_single, valid_single = stratified_kfold(df=df_single, fold=config.fold, n_split=config.n_split, seed=config.seed, target_col='species')
train_others, valid_others = stratified_kfold(df=df_others, fold=config.fold, n_split=config.n_split, seed=config.seed, target_col='individual_id')



In [32]:
train_single_indices = np.take(df_single.index.to_numpy(), train_single) # 실제 index    
train_others_indices = np.take(df_others.index.to_numpy(), train_others)
valid_single_indices = np.take(df_single.index.to_numpy(), valid_single)
valid_others_indices = np.take(df_others.index.to_numpy(), valid_others)

In [33]:
full_train_indices = np.sort(np.concatenate((train_single_indices, train_others_indices), axis=0))
full_valid_indices = np.sort(np.concatenate((valid_single_indices, valid_others_indices), axis=0))

In [34]:
fnames, labels = df['image'].values, df['individual_id'].values

In [35]:
fnames_train, labels_train = fnames[full_train_indices], labels[full_train_indices]
fnames_valid_single, labels_valid_single = fnames[valid_single_indices], labels[valid_single_indices]
fnames_valid_others, labels_valid_others = fnames[valid_others_indices], labels[valid_others_indices]

In [36]:
train_transforms = get_train_transforms()
valid_transforms = get_valid_transforms()

In [37]:
train_dataset = ImageDataset(path=fnames_train, target=labels_train, transform=train_transforms, root=config.data_dir + '/train_detec_512_v2/')
valid_dataset_single = ImageDataset(path=fnames_valid_single, target=labels_valid_single, transform=valid_transforms, root=config.data_dir + '/train_detec_512_v2/')
valid_dataset_others = ImageDataset(path=fnames_valid_others, target=labels_valid_others, transform=valid_transforms, root=config.data_dir + '/train_detec_512_v2/')

In [38]:
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True, drop_last=True)
valid_loader_single = DataLoader(valid_dataset_single, batch_size=config.batch_size, shuffle=False, num_workers=8, pin_memory=True)
valid_loader_others = DataLoader(valid_dataset_others, batch_size=config.batch_size, shuffle=False, num_workers=8, pin_memory=True)

In [39]:
model = IndividualClassifier('tf_efficientnet_b6_ns').to(device)

In [40]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

In [41]:
cosine_annealing_scheduler_arg = dict(
    first_cycle_steps=len(train_dataset)//config.batch_size*config.epoch,
    cycle_mult=1.0,
    max_lr=config.lr,
    min_lr=1e-07,
    warmup_steps=len(train_dataset)//config.batch_size*3, # wanrm up 0~3 epoch
    gamma=0.9
)

In [42]:
scheduler = CosineAnnealingWarmupRestarts(optimizer, **cosine_annealing_scheduler_arg)

In [43]:
if config.wandb_log:
    run = wandb.init(config=config.__dict__,
                project=config.model_name, 
                settings=wandb.Settings(start_method="thread"), 
                name=f"{config.exp_name}_fold{config.fold}",
                reinit=True)




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [44]:
if config.resume_root is not None:
    check = torch.load(config.resume_root)
    model.load_state_dict(check['model'])
    optimizer.load_state_dict(check['optimizer'])
    scheduler.load_state_dict(check['scheduler'])
    print('loaded checkpoint')

# Train Functions

In [45]:
def train_one_epoch(model, optimizer, criterion, loader, scheduler, scaler=None, iters_to_accumulate=1):
    model.train()
    
    match = 0
    top_k_match = 0
    
    losses, y_trues, y_preds, embeds = [], [], [], []
    for i, (x, y) in enumerate(tqdm(loader)):
        x, y = x.to(device), y.to(device)
        
        if scaler is not None:
            with torch.cuda.amp.autocast():
                output, embed = model(x, y) 
                loss = criterion(output, y)
                
            scaler.scale(loss).backward()

            if ((i + 1) % iters_to_accumulate == 0) or ((i + 1) == len(loader)):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        else:
            output, embed = model(x, y) 
            loss = criterion(output, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        y_preds.extend(output.detach().cpu().numpy())
        y_trues.extend(y.detach().cpu().numpy())
        embeds.extend(embed.detach().cpu().numpy())
        
        match += (torch.argmax(output.detach().cpu(), dim=-1) == y.detach().cpu()).sum()
        top_k_match += (output.detach().topk(2, dim=-1).indices.cpu() == y.detach().cpu()[:, None]).sum()
        scheduler.step()

        losses.append(loss.detach().cpu().item())
        # break
        
    return np.mean(losses), match, top_k_match, y_preds, y_trues, embeds


In [46]:
def valid_one_epoch(model, criterion, loader, tta=False):
    model.eval()
    
    match = 0
    top_k_match = 0

    losses, y_trues, y_preds, embeds = [], [], [], []
    with torch.no_grad():
        for idx, (x, y) in tqdm(enumerate(loader)):
            x, y = x.to(device), y.to(device)
            with torch.cuda.amp.autocast():
                output, embed = model(x, y) 
                loss = criterion(output, y)
  
            match += (torch.argmax(output.detach().cpu(), dim=-1) == y.detach().cpu()).sum()
            top_k_match += (output.detach().topk(2, dim=-1).indices.cpu() == y.detach().cpu()[:, None]).sum()
            losses.append(loss.detach().cpu().item())
            
            y_preds.extend(output.detach().cpu().numpy())
            y_trues.extend(y.detach().cpu().numpy())
            embeds.extend(embed.detach().cpu().numpy())
            # break

    return np.mean(losses), match, top_k_match, y_preds, y_trues, embeds

# Train

In [None]:
start = time.time()
os.makedirs(f"{config.checkpoint}/{config.model_name}", exist_ok=True)

grad_scaler = torch.cuda.amp.GradScaler()

best_model = None
best_acc,  best_epoch = 0, 0
print('Start Training')
for i in range(config.resume_epoch, config.epoch):
    print(f"epoch: {i}")
    lr = scheduler.get_lr()[0]
    train_loss, train_acc, train_top_k, train_preds, train_trues, train_embeds = train_one_epoch(model, optimizer, criterion, train_loader, scheduler, grad_scaler, config.iters_to_accumulate)
    
    neigh = NearestNeighbors(n_neighbors=1, metric='cosine')
    neigh.fit(train_embeds) # 80 %
    
    valid_loss1, valid_acc1, valid_top_k1, valid_preds1, valid_trues1, valid_embeds1 = valid_one_epoch(model, criterion, valid_loader_single, config.tta)
    valid_loss2, valid_acc2, valid_top_k2, valid_preds2, valid_trues2, valid_embeds2 = valid_one_epoch(model, criterion, valid_loader_others, config.tta)
    
    distance1, indices1 = neigh.kneighbors(valid_embeds1, n_neighbors=1, return_distance=True) # 80 %에 해당하는 index
    distance2, indices2 = neigh.kneighbors(valid_embeds2, n_neighbors=1, return_distance=True)
    
    # ---------------------------------------------------------------- single valid - distance 분포
    print('-'*20, 'single','-'*20)
    import pdb ; pdb.set_trace()
    neigh_preds1 = np.array(train_trues)[indices1.reshape(-1)] # k neighbor 예측값
    neigh_trues1 = df.individual_id[valid_single_indices].values # 정답
    print('min, mean, max :', distance1.min(), distance1.mean(), distance1.max())
    # print('wrong %', len(distance1[wrong1])/len(distance1))
    
    # ---------------------------------------------------------------- others valid - 맞은애 / 틀린애 distance 분포, 
    
    neigh_preds = np.array(train_trues)[indices2.reshape(-1)] # k neighbor 예측값
    neigh_trues = df.individual_id[valid_others_indices].values
    correct = (neigh_preds == neigh_trues) # neigh 예측값중 맞은 index
    wrong = (neigh_preds != neigh_trues)
    print('-'*20, 'others','-'*20)
    print('min, mean, max (total):', distance2.min(), distance2.mean(), distance2.max())
    print('min, mean, max (correct):', distance2[correct].min(), distance2[correct].mean(), distance2[correct].max())
    print('min, mean, max (wrong):', distance2[wrong].min(), distance2[wrong].mean(), distance2[wrong].max())
    print('correct %', len(distance2[correct])/len(distance2))
    print('wrong %', len(distance2[wrong2])/len(distance2))
    
    # ----------------------------------------------------------------
    
    # print(f"train loss {train_loss :.4f} acc {train_acc/len(train_loader.dataset) :.4f} topk {train_top_k/len(train_loader.dataset) :.4f}")
    # print(f"valid loss (single) {valid_loss1 :.4f} acc {valid_acc1/len(valid_loader_single.dataset) :.4f} topk {valid_top_k1/len(valid_loader_single.dataset) :.4f}")
    # print(f"valid loss (others) {valid_loss2 :.4f} acc {valid_acc2/len(valid_loader_others.dataset) :.4f} topk {valid_top_k2/len(valid_loader_others.dataset) :.4f}")
    print(f"lr {lr} time {time.time() - start :.2f}s")
    
    if best_acc < valid_acc2/len(valid_loader_single.dataset):
        best_acc = valid_acc2/len(valid_loader_single.dataset)
        best_epoch = i
        print(f'best acc updated {best_acc}')
        best_model_dict = {
            'model': deepcopy(model.state_dict()),
            'optimizer': deepcopy(optimizer.state_dict()),
            'scheduler': deepcopy(scheduler.state_dict()),
        }
        torch.save(best_model_dict, f"{config.checkpoint}/{config.model_name}/best.pt")

    if config.wandb_log:
        wandb_dict = {
            'train loss': train_loss,
            'train acc': train_acc/len(train_loader.dataset),
            'valid loss (single)': valid_loss1,
            'valid acc (single)': valid_acc1 / len(valid_loader_single.dataset),
            'valid topk (single)': valid_top_k1 / len(valid_loader_single.dataset),
            'valid loss (others)': valid_loss2,
            'valid acc (others)': valid_acc2 / len(valid_loader_others.dataset),
            'valid topk (others)': valid_top_k2 / len(valid_loader_others.dataset),
            'learning rate': scheduler.get_lr()[0],
        }
        wandb.log(wandb_dict)
        
os.rename(f"{config.checkpoint}/{config.model_name}/best.pt", \
          f"{config.checkpoint}/{config.model_name}/fold{config.fold}_epoch{best_epoch}_{best_acc:.04f}_{config.exp_name}.pt")

Start Training
epoch: 0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


> [0;32m<ipython-input-47-5206b3dbcc35>[0m(25)[0;36m<cell line: 9>[0;34m()[0m
[0;32m     23 [0;31m    [0;31m# ---------------------------------------------------------------- single valid[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m    [0;32mimport[0m [0mpdb[0m [0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m    [0mpreds1[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0marray[0m[0;34m([0m[0mtrain_trues[0m[0;34m)[0m[0;34m[[0m[0mindices1[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m][0m [0;31m# k neighbor 예측값[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m    [0mtrues1[0m [0;34m=[0m [0mdf[0m[0;34m.[0m[0mindividual_id[0m[0;34m[[0m[0mvalid_single_indices[0m[0;34m][0m[0;34m.[0m[0mvalues[0m [0;31m# 정답[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m    [0mcorrect1[0m [0;34m=[0m [0;34m([

ipdb>  preds1 = np.array(train_trues)[indices1.reshape(-1)]
ipdb>  preds1


array([ 873, 1135,  246,   36,  457, 1573,   61,  147,  411,  423, 1448,
        915, 1447,   47,   99, 1388,  224, 1573, 1677,  159,   23,  224,
       1309, 1368,   85,  655, 1573, 1568, 1573,  156,  514, 1443,  528,
       1042,  898,  156,  189, 1382, 1715,  154,   69, 1631, 1458,  101,
       1748, 1573,  265, 1431,  371,  618,  546,   61,  178,  114,  457,
       1243,  156,  862,   86, 1575,  578,   72,  687,  579,  224, 1021,
        411,   95,   50,  457, 1589,  411,  153, 1042,  156,  595,   10,
         61,   95,  612,  612,  411,  655,   47, 1715, 1691, 1573, 1592,
       1431,  154,  595,  870, 1247, 1703,  978,  857, 1715,  862,   36,
        323,  595, 1506,  411,   75,  147,  156, 1294, 1592,  154, 1419,
       1573,  431,  512,  914,  411, 1018,  528, 1366,  760, 1573,  323,
       1505,   77,  423, 1447,   24, 1018,  506,  426, 1366,  528,  182,
        147,  514,  320,  197,   77,  298,  898,  653, 1690, 1592,  514,
        117,   81,  595, 1573, 1079,  224,  156,  2

ipdb>  indices1


array([[1993],
       [1896],
       [1884],
       [1954],
       [1858],
       [2005],
       [1895],
       [2039],
       [2009],
       [1853],
       [1588],
       [1995],
       [1980],
       [1761],
       [1886],
       [1575],
       [1999],
       [2005],
       [1856],
       [ 624],
       [1991],
       [1999],
       [1967],
       [1391],
       [1788],
       [1828],
       [2005],
       [2032],
       [2005],
       [2010],
       [1969],
       [1857],
       [1792],
       [1758],
       [1798],
       [1990],
       [1488],
       [2004],
       [1881],
       [1908],
       [2023],
       [1956],
       [1905],
       [1787],
       [1970],
       [2005],
       [1962],
       [2024],
       [1931],
       [1817],
       [1799],
       [1895],
       [1950],
       [2031],
       [1858],
       [1732],
       [2010],
       [1880],
       [1940],
       [1682],
       [1449],
       [1909],
       [1171],
       [1407],
       [1999],
       [1731],
       [20

ipdb>  trues1 = df.individual_id[valid_single_indices].values
ipdb>  trues1.shape


(288,)


ipdb>  (neigh_preds2 == neigh_trues2).sum()


*** NameError: name 'neigh_preds2' is not defined


ipdb>  preds1 == tures1


*** NameError: name 'tures1' is not defined


ipdb>  (preds1 == trues1).sum()


0


ipdb>  distance1


array([[0.51049492],
       [0.42640319],
       [0.3751685 ],
       [0.59506786],
       [0.48249002],
       [0.44413705],
       [0.39170263],
       [0.44733957],
       [0.5098154 ],
       [0.47215509],
       [0.48885337],
       [0.31122275],
       [0.5850021 ],
       [0.39486537],
       [0.29171792],
       [0.48497398],
       [0.36763347],
       [0.36640579],
       [0.52938198],
       [0.73610344],
       [0.38147021],
       [0.33353376],
       [0.42885046],
       [0.5270371 ],
       [0.55344472],
       [0.3772239 ],
       [0.43738239],
       [0.55726739],
       [0.34144396],
       [0.36388886],
       [0.26487573],
       [0.32403987],
       [0.42770086],
       [0.4286198 ],
       [0.44891944],
       [0.37117771],
       [0.38399401],
       [0.52947791],
       [0.5510192 ],
       [0.54938153],
       [0.45023444],
       [0.43977031],
       [0.41225519],
       [0.48567321],
       [0.45679669],
       [0.15601193],
       [0.40574249],
       [0.382

ipdb>  distance.reshape(-1)[wrong]


*** NameError: name 'distance' is not defined


ipdb>  distance1.reshape(-1)[wrong1]


*** NameError: name 'wrong1' is not defined
--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user
correct1 % 0.0
correct2 % 0.013452914798206279
lr 1e-07 time 10711.36s
epoch: 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


> [0;32m<ipython-input-47-5206b3dbcc35>[0m(24)[0;36m<cell line: 9>[0;34m()[0m
[0;32m     22 [0;31m[0;34m[0m[0m
[0m[0;32m     23 [0;31m    [0;31m# ---------------------------------------------------------------- single valid[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 24 [0;31m    [0;32mimport[0m [0mpdb[0m [0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m    [0mpreds1[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0marray[0m[0;34m([0m[0mtrain_trues[0m[0;34m)[0m[0;34m[[0m[0mindices1[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m][0m [0;31m# k neighbor 예측값[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m    [0mtrues1[0m [0;34m=[0m [0mdf[0m[0;34m.[0m[0mindividual_id[0m[0;34m[[0m[0mvalid_single_indices[0m[0;34m][0m[0;34m.[0m[0mvalues[0m [0;31m# 정답[0m[0;34m[0m[0;34m[0m[0m
[0m
