In [1]:
import os
import time
import wandb
from tqdm.notebook import tqdm
from copy import deepcopy

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader

from utils import seed_everything
from data import ImageDataset, stratified_kfold, get_train_transforms, get_valid_transforms
from model import SpecieClassifier
from scheduler import CosineAnnealingWarmupRestarts

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
class Config:
    checkpoint = '/root/dl_whale_classification/pths_specie'
    test = False
    tta = False
    resume_epoch = 0
    iters_to_accumulate = 1
    resume_root = None
    wandb_log = True
    model_name = 'tf_efficientnet_b4_ns'
    fold = 0
    n_split = 5
    seed = 2022
    data_dir = '/root/data2/' # root/data/train_images
    root_dir = '.'
    batch_size = 16
    lr = 1e-4
    weight_decay = 0.0005
    epoch = 20
    exp_name = 'test'
config = Config()

In [3]:
seed_everything(2022)

In [4]:
df = pd.read_csv('/root/data/train.csv')
df.species.replace({"globis": "short_finned_pilot_whale",
                  "pilot_whale": "short_finned_pilot_whale",
                  "kiler_whale": "killer_whale",
                  "bottlenose_dolpin": "bottlenose_dolphin"}, inplace=True)

In [5]:
specie_unique = df.species.unique()
specie_indices = range(len(specie_unique))
species2idx = {k: v for k, v in zip(specie_unique, specie_indices)}
df.species = df.species.map(species2idx)

individual_unique = df.individual_id.unique()
individual_indices = range(len(individual_unique))
individual2idx = {k : v for k, v in zip(individual_unique, individual_indices)}
df.individual_id = df.individual_id.map(individual2idx)

In [6]:
print('num_species', len(df.species.unique()))
print('num_individual', len(df.individual_id.unique()))
      

num_species 26
num_individual 15587


In [7]:
df_single = df[df['individual_id'].map(df['individual_id'].value_counts()) == 1]
df_others = df[df['individual_id'].map(df['individual_id'].value_counts()) > 1]

In [8]:
train_single, valid_single = stratified_kfold(df=df_single, fold=config.fold, n_split=config.n_split, seed=config.seed, target_col='species')
train_others, valid_others = stratified_kfold(df=df_others, fold=config.fold, n_split=config.n_split, seed=config.seed, target_col='species')



In [9]:
train_single_indices = np.take(df_single.index.to_numpy(), train_single)    
train_others_indices = np.take(df_others.index.to_numpy(), train_others)
valid_single_indices = np.take(df_single.index.to_numpy(), valid_single)
valid_others_indices = np.take(df_others.index.to_numpy(), valid_others)

In [10]:
full_train_indices = np.sort(np.concatenate((train_single_indices, train_others_indices), axis=0))

In [11]:
fnames, labels = df['image'].values, df['species'].values

In [12]:
fnames_train, labels_train = fnames[full_train_indices], labels[full_train_indices]
fnames_valid_single, labels_valid_single = fnames[valid_single_indices], labels[valid_single_indices]
fnames_valid_others, labels_valid_others = fnames[valid_others_indices], labels[valid_others_indices]

In [13]:
train_transforms = get_train_transforms()
valid_transforms = get_valid_transforms()

In [14]:
train_dataset = ImageDataset(path=fnames_train, target=labels_train, transform=train_transforms, root=config.data_dir + '/train_detec_512_v3/')
valid_dataset_single = ImageDataset(path=fnames_valid_single, target=labels_valid_single, transform=valid_transforms, root=config.data_dir + '/train_detec_512_v3/')
valid_dataset_others = ImageDataset(path=fnames_valid_others, target=labels_valid_others, transform=valid_transforms, root=config.data_dir + '/train_detec_512_v3/')

In [15]:
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True, drop_last=True)
valid_loader_single = DataLoader(valid_dataset_single, batch_size=config.batch_size*2, shuffle=False, num_workers=8, pin_memory=True)
valid_loader_others = DataLoader(valid_dataset_others, batch_size=config.batch_size*2, shuffle=False, num_workers=8, pin_memory=True)

In [16]:
model = SpecieClassifier(config.model_name).to(device)

In [17]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

In [18]:
cosine_annealing_scheduler_arg = dict(
    first_cycle_steps=len(train_dataset)//config.batch_size*config.epoch,
    cycle_mult=1.0,
    max_lr=config.lr,
    min_lr=1e-07,
    warmup_steps=len(train_dataset)//config.batch_size*3, # wanrm up 0~3 epoch
    gamma=0.9
)

In [19]:
scheduler = CosineAnnealingWarmupRestarts(optimizer, **cosine_annealing_scheduler_arg)

In [20]:
if config.wandb_log:
    run = wandb.init(config=config.__dict__,
                project=config.model_name, 
                settings=wandb.Settings(start_method="thread"), 
                name=f"{config.exp_name}_fold{config.fold}",
                reinit=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjcdata[0m (use `wandb login --relogin` to force relogin)


In [21]:
grad_scaler = torch.cuda.amp.GradScaler()

best_model = None
best_acc,  best_epoch = 0, 0

In [22]:
if config.resume_root is not None:
    check = torch.load(config.resume_root)
    model.load_state_dict(check['model'])
    optimizer.load_state_dict(check['optimizer'])
    scheduler.load_state_dict(check['scheduler'])
    print('loaded checkpoint')

# Train Functions

In [23]:
def train_one_epoch(model, optimizer, criterion, loader, scheduler, scaler=None, iters_to_accumulate=1):
    model.train()
    
    match = 0
    top_k_match = 0
    
    losses, y_true, y_pred = [], [], []
    for i, (x, y) in enumerate(tqdm(loader)):
        x, y = x.to(device), y.to(device)
        
        if scaler is not None:
            with torch.cuda.amp.autocast():
                output = model(x) 
                loss = criterion(output, y)
                
            scaler.scale(loss).backward()

            if ((i + 1) % iters_to_accumulate == 0) or ((i + 1) == len(loader)):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        

        else:
            output = model(x, y) 
            loss = criterion(output, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        match += (torch.argmax(output.detach().cpu(), dim=-1) == y.detach().cpu()).sum()
        top_k_match += (output.detach().topk(2, dim=-1).indices.cpu() == y.detach().cpu()[:, None]).sum()
        scheduler.step()

        losses.append(loss.detach().cpu().item())
        # if i == 10:
        #     break
        
    return np.mean(losses), match, top_k_match


In [24]:
def valid_one_epoch(model, criterion, loader, tta=False):
    model.eval()
    
    match = 0
    top_k_match = 0

    losses, y_trues, y_preds, tta_embed = [], [], [], []
    with torch.no_grad():
        for idx, (x, y) in tqdm(enumerate(loader)):
            x, y = x.to(device), y.to(device)
            with torch.cuda.amp.autocast():
                output = model(x) 
                loss = criterion(output, y)
  
            match += (torch.argmax(output.detach().cpu(), dim=-1) == y.detach().cpu()).sum()
            top_k_match += (output.detach().topk(2, dim=-1).indices.cpu() == y.detach().cpu()[:, None]).sum()
            losses.append(loss.detach().cpu().item())
            
            y_preds.extend(np.argmax(output.detach().cpu().numpy(), axis=-1))
            y_trues.extend(y.detach().cpu().numpy())
            
    return np.mean(losses), match, top_k_match, np.array(y_preds), np.array(y_trues)

# Train

In [25]:
start = time.time()
os.makedirs(f"{config.checkpoint}/{config.model_name}", exist_ok=True)

print('Start Training')
for epo in range(config.resume_epoch, config.epoch):
    print(f"epoch: {epo}")
    lr = scheduler.get_lr()[0]
    train_loss, train_acc, train_top_k = train_one_epoch(model, optimizer, criterion, train_loader, scheduler, grad_scaler, config.iters_to_accumulate)
    valid_loss1, valid_acc1, valid_top_k1, valid_preds1, valid_trues1 = valid_one_epoch(model, criterion, valid_loader_single, config.tta)
    valid_loss2, valid_acc2, valid_top_k2, valid_preds2, valid_trues2 = valid_one_epoch(model, criterion, valid_loader_others, config.tta)

    print(f"train loss {train_loss :.4f} acc {train_acc :.4f} topk {train_top_k :.4f}")
    print(f"valid loss (single) {valid_loss1 :.4f} acc {valid_acc1/len(valid_loader_single.dataset) :.4f} topk {valid_top_k1/len(valid_loader_single.dataset) :.4f}")
    print(f"valid loss (others) {valid_loss2 :.4f} acc {valid_acc2/len(valid_loader_others.dataset) :.4f} topk {valid_top_k2/len(valid_loader_others.dataset) :.4f}")
    print(f"lr {lr} time {time.time() - start :.2f}s")
    
    print("-"*20, f"class (single)", "-"*20)
    for i in range(26):
        print(f"{i:02d}-#{(valid_trues1==i).sum():04d} : {((valid_preds1==i)&(valid_trues1==i)).sum()/(valid_trues1==i).sum():04f}", end=' / ') 
        if i % 4 == 3 and i > 0 or i == 25:
            print()

    print("-"*20, f"class (others)", "-"*20)
    for i in range(26):
        print(f"{i:02d}-#{(valid_trues2==i).sum():04d} : {((valid_preds2==i)&(valid_trues2==i)).sum()/(valid_trues2==i).sum():04f}", end=' / ') 
        if i % 4 == 3 and i > 0 or i == 25:
            print()
                                     
    if best_acc < valid_acc2/len(valid_loader_others.dataset):
        best_acc = valid_acc2/len(valid_loader_others.dataset)
        best_epoch = epo
        print(f'best acc updated {best_acc}')
        best_model_dict = {
            'model': deepcopy(model.state_dict()),
            'optimizer': deepcopy(optimizer.state_dict()),
            'scheduler': deepcopy(scheduler.state_dict()),
        }
        torch.save(best_model_dict, f"{config.checkpoint}/{config.model_name}/best.pt")

    if config.wandb_log:
        wandb_dict = {
            'train loss': train_loss,
            'train acc': train_acc,
            'valid loss (single)': valid_loss1,
            'valid acc (single)': valid_acc1 / len(valid_loader_single.dataset),
            'valid topk (single)': valid_top_k1 / len(valid_loader_single.dataset),
            'valid loss (others)': valid_loss2,
            'valid acc (others)': valid_acc2 / len(valid_loader_others.dataset),
            'valid topk (others)': valid_top_k2 / len(valid_loader_others.dataset),
            'learning rate': scheduler.get_lr()[0],
        }
        wandb.log(wandb_dict)
        
os.rename(f"{config.checkpoint}/{config.model_name}/best.pt", \
          f"{config.checkpoint}/{config.model_name}/fold{config.fold}_epoch{best_epoch}_{best_acc:.04f}.pt")

Start Training
epoch: 0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 1.4549 acc 25190.0000 topk 29005.0000
valid loss (single) 0.3428 acc 0.9001 topk 0.9460
valid loss (others) 0.2230 acc 0.9363 topk 0.9672
lr 1e-07 time 4407.69s
-------------------- class (single) --------------------
00-#0228 : 0.978070 / 01-#0370 : 0.986486 / 02-#0008 : 0.750000 / 03-#0051 : 0.960784 / 
04-#0052 : 1.000000 / 05-#0007 : 0.857143 / 06-#0062 : 0.887097 / 07-#0185 : 0.962162 / 
08-#0011 : 1.000000 / 09-#0061 : 0.983607 / 10-#0002 : 0.000000 / 11-#0031 : 0.741935 / 
12-#0045 : 0.577778 / 13-#0492 : 0.989837 / 14-#0012 : 0.000000 / 15-#0017 : 0.941176 / 
16-#0104 : 0.951923 / 17-#0021 : 0.190476 / 18-#0036 : 0.194444 / 19-#0004 : 0.000000 / 
20-#0010 : 0.000000 / 21-#0027 : 0.000000 / 22-#0007 : 0.000000 / 23-#0005 : 0.000000 / 
24-#0002 : 0.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 0.954955 / 01-#1109 : 0.971145 / 02-#0657 : 0.955860 / 03-#2105 : 0.981473 / 
04-#1436 : 0.997214 / 05-#0315 : 0.955556 / 0

  print(f"{i:02d}-#{(valid_trues2==i).sum():04d} : {((valid_preds2==i)&(valid_trues2==i)).sum()/(valid_trues2==i).sum():04f}", end=' / ')


epoch: 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.2206 acc 38204.0000 topk 39726.0000
valid loss (single) 0.1216 acc 0.9633 topk 0.9870
valid loss (others) 0.0827 acc 0.9767 topk 0.9913
lr 3.34e-05 time 10763.45s
-------------------- class (single) --------------------
00-#0228 : 0.995614 / 01-#0370 : 0.983784 / 02-#0008 : 0.625000 / 03-#0051 : 0.901961 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.919355 / 07-#0185 : 0.994595 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 0.935484 / 
12-#0045 : 0.866667 / 13-#0492 : 0.993902 / 14-#0012 : 0.833333 / 15-#0017 : 1.000000 / 
16-#0104 : 0.951923 / 17-#0021 : 0.857143 / 18-#0036 : 0.888889 / 19-#0004 : 0.250000 / 
20-#0010 : 0.600000 / 21-#0027 : 1.000000 / 22-#0007 : 0.000000 / 23-#0005 : 0.200000 / 
24-#0002 : 0.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 0.954955 / 01-#1109 : 0.981064 / 02-#0657 : 0.969559 / 03-#2105 : 0.988124 / 
04-#1436 : 0.996518 / 05-#0315 : 0.974603

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.1193 acc 39372.0000 topk 40272.0000
valid loss (single) 0.1043 acc 0.9676 topk 0.9892
valid loss (others) 0.0729 acc 0.9785 topk 0.9937
lr 6.67e-05 time 15833.49s
-------------------- class (single) --------------------
00-#0228 : 1.000000 / 01-#0370 : 0.997297 / 02-#0008 : 0.875000 / 03-#0051 : 0.862745 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.951613 / 07-#0185 : 0.989189 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.555556 / 13-#0492 : 0.995935 / 14-#0012 : 0.750000 / 15-#0017 : 1.000000 / 
16-#0104 : 0.971154 / 17-#0021 : 0.809524 / 18-#0036 : 1.000000 / 19-#0004 : 0.500000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.285714 / 23-#0005 : 0.400000 / 
24-#0002 : 0.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 0.972973 / 01-#1109 : 0.986474 / 02-#0657 : 0.981735 / 03-#2105 : 0.984798 / 
04-#1436 : 0.998607 / 05-#0315 : 0.980952

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0871 acc 39721.0000 topk 40454.0000
valid loss (single) 0.0780 acc 0.9789 topk 0.9941
valid loss (others) 0.0545 acc 0.9866 topk 0.9964
lr 0.0001 time 18644.95s
-------------------- class (single) --------------------
00-#0228 : 0.951754 / 01-#0370 : 0.997297 / 02-#0008 : 0.625000 / 03-#0051 : 0.960784 / 
04-#0052 : 1.000000 / 05-#0007 : 0.857143 / 06-#0062 : 0.983871 / 07-#0185 : 1.000000 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.933333 / 13-#0492 : 0.995935 / 14-#0012 : 0.916667 / 15-#0017 : 1.000000 / 
16-#0104 : 0.980769 / 17-#0021 : 0.904762 / 18-#0036 : 1.000000 / 19-#0004 : 0.500000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.428571 / 23-#0005 : 0.800000 / 
24-#0002 : 0.500000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 0.936937 / 01-#1109 : 0.987376 / 02-#0657 : 0.984779 / 03-#2105 : 0.996200 / 
04-#1436 : 0.996518 / 05-#0315 : 0.977778 /

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0639 acc 39991.0000 topk 40592.0000
valid loss (single) 0.0715 acc 0.9789 topk 0.9951
valid loss (others) 0.0587 acc 0.9834 topk 0.9966
lr 9.914950632921091e-05 time 21113.00s
-------------------- class (single) --------------------
00-#0228 : 0.973684 / 01-#0370 : 0.989189 / 02-#0008 : 0.875000 / 03-#0051 : 0.941176 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.967742 / 07-#0185 : 0.994595 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.933333 / 13-#0492 : 0.993902 / 14-#0012 : 1.000000 / 15-#0017 : 0.882353 / 
16-#0104 : 0.961538 / 17-#0021 : 0.904762 / 18-#0036 : 0.972222 / 19-#0004 : 0.250000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.800000 / 
24-#0002 : 1.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 0.954955 / 01-#1109 : 0.981064 / 02-#0657 : 0.958904 / 03-#2105 : 0.993824 / 
04-#1436 : 0.997911 / 05-#03

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0484 acc 40169.0000 topk 40654.0000
valid loss (single) 0.0701 acc 0.9789 topk 0.9968
valid loss (others) 0.0449 acc 0.9882 topk 0.9968
lr 9.662698785874757e-05 time 23660.96s
-------------------- class (single) --------------------
00-#0228 : 0.964912 / 01-#0370 : 0.994595 / 02-#0008 : 0.625000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.983871 / 07-#0185 : 0.989189 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 0.967742 / 
12-#0045 : 0.911111 / 13-#0492 : 0.997967 / 14-#0012 : 0.916667 / 15-#0017 : 0.941176 / 
16-#0104 : 0.990385 / 17-#0021 : 0.904762 / 18-#0036 : 0.916667 / 19-#0004 : 1.000000 / 
20-#0010 : 0.900000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.200000 / 
24-#0002 : 1.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 0.945946 / 01-#1109 : 0.990081 / 02-#0657 : 0.989346 / 03-#2105 : 0.996200 / 
04-#1436 : 0.997214 / 05-#03

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0370 acc 40333.0000 topk 40715.0000
valid loss (single) 0.0453 acc 0.9865 topk 0.9968
valid loss (others) 0.0559 acc 0.9878 topk 0.9964
lr 9.251834592969423e-05 time 26192.43s
-------------------- class (single) --------------------
00-#0228 : 0.995614 / 01-#0370 : 0.997297 / 02-#0008 : 0.625000 / 03-#0051 : 0.960784 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.967742 / 07-#0185 : 0.994595 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.977778 / 13-#0492 : 0.997967 / 14-#0012 : 0.833333 / 15-#0017 : 1.000000 / 
16-#0104 : 0.990385 / 17-#0021 : 0.952381 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.200000 / 
24-#0002 : 0.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 0.990991 / 01-#1109 : 0.992786 / 02-#0657 : 0.957382 / 03-#2105 : 0.992874 / 
04-#1436 : 0.999304 / 05-#03

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0292 acc 40413.0000 topk 40740.0000
valid loss (single) 0.0557 acc 0.9860 topk 0.9957
valid loss (others) 0.0463 acc 0.9921 topk 0.9972
lr 8.696349541517193e-05 time 28709.73s
-------------------- class (single) --------------------
00-#0228 : 0.986842 / 01-#0370 : 0.991892 / 02-#0008 : 0.750000 / 03-#0051 : 0.941176 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.967742 / 07-#0185 : 1.000000 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.933333 / 13-#0492 : 0.997967 / 14-#0012 : 0.916667 / 15-#0017 : 0.941176 / 
16-#0104 : 0.980769 / 17-#0021 : 0.904762 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 1.000000 / 
24-#0002 : 0.500000 / 25-#0002 : 0.500000 / 
-------------------- class (others) --------------------
00-#0111 : 0.990991 / 01-#1109 : 0.987376 / 02-#0657 : 0.998478 / 03-#2105 : 0.997150 / 
04-#1436 : 0.999304 / 05-#03

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0254 acc 40476.0000 topk 40755.0000
valid loss (single) 0.0625 acc 0.9822 topk 0.9973
valid loss (others) 0.0585 acc 0.9890 topk 0.9970
lr 8.015160008714386e-05 time 31256.80s
-------------------- class (single) --------------------
00-#0228 : 0.916667 / 01-#0370 : 1.000000 / 02-#0008 : 0.750000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.983871 / 07-#0185 : 1.000000 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 1.000000 / 13-#0492 : 0.997967 / 14-#0012 : 0.916667 / 15-#0017 : 0.941176 / 
16-#0104 : 0.990385 / 17-#0021 : 0.952381 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.600000 / 
24-#0002 : 0.500000 / 25-#0002 : 0.500000 / 
-------------------- class (others) --------------------
00-#0111 : 0.855856 / 01-#1109 : 0.996393 / 02-#0657 : 0.978691 / 03-#2105 : 0.997625 / 
04-#1436 : 0.997911 / 05-#03

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0193 acc 40574.0000 topk 40774.0000
valid loss (single) 0.0495 acc 0.9860 topk 0.9946
valid loss (others) 0.0574 acc 0.9901 topk 0.9968
lr 7.23146308710381e-05 time 33719.86s
-------------------- class (single) --------------------
00-#0228 : 0.991228 / 01-#0370 : 0.997297 / 02-#0008 : 0.625000 / 03-#0051 : 0.960784 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.983871 / 07-#0185 : 0.989189 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.933333 / 13-#0492 : 0.993902 / 14-#0012 : 0.916667 / 15-#0017 : 1.000000 / 
16-#0104 : 0.990385 / 17-#0021 : 0.904762 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 0.900000 / 21-#0027 : 1.000000 / 22-#0007 : 1.000000 / 23-#0005 : 0.800000 / 
24-#0002 : 0.000000 / 25-#0002 : 0.500000 / 
-------------------- class (others) --------------------
00-#0111 : 0.990991 / 01-#1109 : 0.990983 / 02-#0657 : 0.987823 / 03-#2105 : 0.992874 / 
04-#1436 : 0.998607 / 05-#031

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0138 acc 40629.0000 topk 40786.0000
valid loss (single) 0.0471 acc 0.9892 topk 0.9957
valid loss (others) 0.0532 acc 0.9922 topk 0.9976
lr 6.371946635410054e-05 time 36488.40s
-------------------- class (single) --------------------
00-#0228 : 0.991228 / 01-#0370 : 1.000000 / 02-#0008 : 0.625000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 0.857143 / 06-#0062 : 1.000000 / 07-#0185 : 1.000000 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.911111 / 13-#0492 : 0.995935 / 14-#0012 : 0.916667 / 15-#0017 : 0.941176 / 
16-#0104 : 1.000000 / 17-#0021 : 0.952381 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 1.000000 / 23-#0005 : 1.000000 / 
24-#0002 : 0.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 1.000000 / 01-#1109 : 0.994590 / 02-#0657 : 0.987823 / 03-#2105 : 0.995724 / 
04-#1436 : 0.997214 / 05-#03

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0110 acc 40675.0000 topk 40803.0000
valid loss (single) 0.0692 acc 0.9854 topk 0.9957
valid loss (others) 0.0382 acc 0.9935 topk 0.9976
lr 5.465880455519194e-05 time 38973.94s
-------------------- class (single) --------------------
00-#0228 : 0.986842 / 01-#0370 : 0.989189 / 02-#0008 : 0.750000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.967742 / 07-#0185 : 1.000000 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.955556 / 13-#0492 : 0.995935 / 14-#0012 : 0.916667 / 15-#0017 : 0.941176 / 
16-#0104 : 1.000000 / 17-#0021 : 0.952381 / 18-#0036 : 0.916667 / 19-#0004 : 0.750000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.800000 / 
24-#0002 : 1.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 1.000000 / 01-#1109 : 0.991885 / 02-#0657 : 0.992390 / 03-#2105 : 0.997150 / 
04-#1436 : 0.997911 / 05-#03

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0080 acc 40718.0000 topk 40802.0000
valid loss (single) 0.0495 acc 0.9887 topk 0.9973
valid loss (others) 0.0445 acc 0.9937 topk 0.9980
lr 4.544119544480807e-05 time 41596.43s
-------------------- class (single) --------------------
00-#0228 : 0.991228 / 01-#0370 : 0.997297 / 02-#0008 : 0.750000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.983871 / 07-#0185 : 0.978378 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.955556 / 13-#0492 : 0.997967 / 14-#0012 : 0.916667 / 15-#0017 : 0.941176 / 
16-#0104 : 0.990385 / 17-#0021 : 0.952381 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 1.000000 / 
24-#0002 : 1.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 1.000000 / 01-#1109 : 0.997295 / 02-#0657 : 0.993912 / 03-#2105 : 0.997625 / 
04-#1436 : 0.997214 / 05-#03

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0056 acc 40745.0000 topk 40803.0000
valid loss (single) 0.0515 acc 0.9865 topk 0.9957
valid loss (others) 0.0490 acc 0.9937 topk 0.9980
lr 3.638053364589948e-05 time 44646.42s
-------------------- class (single) --------------------
00-#0228 : 0.991228 / 01-#0370 : 0.997297 / 02-#0008 : 0.625000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 0.857143 / 06-#0062 : 1.000000 / 07-#0185 : 0.972973 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.955556 / 13-#0492 : 0.997967 / 14-#0012 : 0.916667 / 15-#0017 : 0.941176 / 
16-#0104 : 0.990385 / 17-#0021 : 0.952381 / 18-#0036 : 0.972222 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.800000 / 
24-#0002 : 1.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 0.990991 / 01-#1109 : 0.994590 / 02-#0657 : 0.995434 / 03-#2105 : 0.998575 / 
04-#1436 : 0.998607 / 05-#03

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0043 acc 40755.0000 topk 40809.0000
valid loss (single) 0.0497 acc 0.9870 topk 0.9978
valid loss (others) 0.0460 acc 0.9951 topk 0.9983
lr 2.7785369128961917e-05 time 47655.99s
-------------------- class (single) --------------------
00-#0228 : 0.995614 / 01-#0370 : 0.994595 / 02-#0008 : 0.625000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 0.983871 / 07-#0185 : 0.994595 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.888889 / 13-#0492 : 0.997967 / 14-#0012 : 0.916667 / 15-#0017 : 0.941176 / 
16-#0104 : 1.000000 / 17-#0021 : 0.952381 / 18-#0036 : 0.944444 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.800000 / 
24-#0002 : 1.000000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 1.000000 / 01-#1109 : 0.994590 / 02-#0657 : 0.996956 / 03-#2105 : 0.998575 / 
04-#1436 : 0.999304 / 05-#0

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0030 acc 40777.0000 topk 40809.0000
valid loss (single) 0.0371 acc 0.9919 topk 0.9978
valid loss (others) 0.0397 acc 0.9951 topk 0.9986
lr 1.9948399912856146e-05 time 50228.59s
-------------------- class (single) --------------------
00-#0228 : 0.995614 / 01-#0370 : 0.994595 / 02-#0008 : 0.750000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 0.857143 / 06-#0062 : 0.983871 / 07-#0185 : 1.000000 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.977778 / 13-#0492 : 0.997967 / 14-#0012 : 1.000000 / 15-#0017 : 1.000000 / 
16-#0104 : 0.990385 / 17-#0021 : 1.000000 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.800000 / 
24-#0002 : 0.500000 / 25-#0002 : 0.500000 / 
-------------------- class (others) --------------------
00-#0111 : 0.990991 / 01-#1109 : 0.995491 / 02-#0657 : 0.996956 / 03-#2105 : 0.998575 / 
04-#1436 : 0.999304 / 05-#0

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0026 acc 40780.0000 topk 40814.0000
valid loss (single) 0.0365 acc 0.9924 topk 0.9984
valid loss (others) 0.0391 acc 0.9953 topk 0.9988
lr 1.3136504584828086e-05 time 52910.29s
-------------------- class (single) --------------------
00-#0228 : 0.995614 / 01-#0370 : 1.000000 / 02-#0008 : 0.750000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 0.857143 / 06-#0062 : 0.983871 / 07-#0185 : 0.994595 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.977778 / 13-#0492 : 0.997967 / 14-#0012 : 1.000000 / 15-#0017 : 1.000000 / 
16-#0104 : 0.990385 / 17-#0021 : 1.000000 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 1.000000 / 
24-#0002 : 0.500000 / 25-#0002 : 0.000000 / 
-------------------- class (others) --------------------
00-#0111 : 0.981982 / 01-#1109 : 0.994590 / 02-#0657 : 0.998478 / 03-#2105 : 0.998100 / 
04-#1436 : 0.999304 / 05-#0

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0017 acc 40796.0000 topk 40815.0000
valid loss (single) 0.0337 acc 0.9935 topk 0.9984
valid loss (others) 0.0381 acc 0.9955 topk 0.9986
lr 7.58165407030577e-06 time 55353.34s
-------------------- class (single) --------------------
00-#0228 : 0.991228 / 01-#0370 : 0.997297 / 02-#0008 : 0.750000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 1.000000 / 07-#0185 : 1.000000 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.955556 / 13-#0492 : 0.997967 / 14-#0012 : 1.000000 / 15-#0017 : 1.000000 / 
16-#0104 : 1.000000 / 17-#0021 : 1.000000 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.800000 / 
24-#0002 : 1.000000 / 25-#0002 : 0.500000 / 
-------------------- class (others) --------------------
00-#0111 : 0.981982 / 01-#1109 : 0.994590 / 02-#0657 : 0.998478 / 03-#2105 : 0.999525 / 
04-#1436 : 0.997911 / 05-#031

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0013 acc 40806.0000 topk 40815.0000
valid loss (single) 0.0342 acc 0.9930 topk 0.9989
valid loss (others) 0.0371 acc 0.9959 topk 0.9988
lr 3.473012141252427e-06 time 57796.59s
-------------------- class (single) --------------------
00-#0228 : 0.995614 / 01-#0370 : 1.000000 / 02-#0008 : 0.750000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 1.000000 / 07-#0185 : 1.000000 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.933333 / 13-#0492 : 0.997967 / 14-#0012 : 1.000000 / 15-#0017 : 1.000000 / 
16-#0104 : 1.000000 / 17-#0021 : 0.952381 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.800000 / 
24-#0002 : 0.500000 / 25-#0002 : 0.500000 / 
-------------------- class (others) --------------------
00-#0111 : 0.981982 / 01-#1109 : 0.994590 / 02-#0657 : 1.000000 / 03-#2105 : 0.999525 / 
04-#1436 : 0.999304 / 05-#03

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2551.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


train loss 0.0012 acc 40798.0000 topk 40813.0000
valid loss (single) 0.0363 acc 0.9946 topk 0.9989
valid loss (others) 0.0378 acc 0.9957 topk 0.9984
lr 9.504936707891058e-07 time 60246.04s
-------------------- class (single) --------------------
00-#0228 : 1.000000 / 01-#0370 : 1.000000 / 02-#0008 : 0.750000 / 03-#0051 : 0.980392 / 
04-#0052 : 1.000000 / 05-#0007 : 1.000000 / 06-#0062 : 1.000000 / 07-#0185 : 1.000000 / 
08-#0011 : 1.000000 / 09-#0061 : 1.000000 / 10-#0002 : 1.000000 / 11-#0031 : 1.000000 / 
12-#0045 : 0.933333 / 13-#0492 : 0.997967 / 14-#0012 : 1.000000 / 15-#0017 : 1.000000 / 
16-#0104 : 1.000000 / 17-#0021 : 1.000000 / 18-#0036 : 1.000000 / 19-#0004 : 1.000000 / 
20-#0010 : 1.000000 / 21-#0027 : 1.000000 / 22-#0007 : 0.857143 / 23-#0005 : 0.800000 / 
24-#0002 : 1.000000 / 25-#0002 : 0.500000 / 
-------------------- class (others) --------------------
00-#0111 : 0.981982 / 01-#1109 : 0.995491 / 02-#0657 : 1.000000 / 03-#2105 : 0.999525 / 
04-#1436 : 0.998607 / 05-#03

In [26]:
valid_trues1==0

array([False, False, False, ..., False, False, False])