In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
import wandb

import torch
from torch import nn
from torchvision import models
import torch.optim as optim

from dataset import get_dataloaders, get_datasets
from utils import seed_everything
from trainer import Trainer

# Params
# Overrides Pillow's module-level pixel-count limit (its decompression-bomb guard).
# NOTE(review): 1e11 is presumably chosen so the very large training images can be
# opened without warnings/errors — confirm against the actual image sizes.
Image.MAX_IMAGE_PIXELS = 1e11
# Hyper-parameters and run settings. The same dict is handed to the dataset /
# dataloader helpers and to train_model(), and is uploaded as the wandb run config.
CFG = dict(
    # Reproducibility and cross-validation
    seed=42,
    cv_fold=5,
    # Backbone: resnet18/34/50, efficientnet_b0/b1/b2/b3/b4
    base_model='efficientnet_b1',
    # Input size and batching (presumably consumed by get_datasets/get_dataloaders — verify there)
    img_size=1024,
    batch_size=16,
    # Epochs trained before trainer.unfreeze(), then epochs trained after it
    freeze_epochs=1,
    epochs=5,
    # Initial learning rate passed to the optimizer
    base_lr=1e-3,
    # Augmentation settings (presumably random-affine parameters — verify in dataset.py)
    affine_degrees=10,
    affine_translate=(0.1, 0.1),
    affine_scale=(1.0, 1.4),
    dataloader_num_workers=8,
    # StepLR step size
    scheduler_step_size=2,
    # Per-channel normalisation statistics (presumably computed on the training images — TODO confirm)
    img_color_mean=[0.8708488980328596, 0.75677901508938, 0.8545134911215124],
    img_color_std=[0.08086288591996027, 0.11553960008706814, 0.06914169213328555],
    # Optimizer: SGD / AdamW / Adam
    optimizer='AdamW',
    # Scheduler: StepLR / CyclicLR / CosineAnnealingLR / OneCycleLR
    scheduler='CosineAnnealingLR',
    # StepLR decay factor
    lr_gamma=0.1,
    lr_cycl_step_size=3,
    # Momentum, used only when optimizer == 'SGD'
    sgd_momentum=0.9,
)

# wandb run tags
tags = ['torch', 'thumbnails', 'cv']
notes = ''
# Set to True to render two batches of training samples in the plotting cell
plot_samples = False

# Wandb
# SECURITY: the API key must never be written into the notebook. It is read from
# the WANDB_API_KEY environment variable; when that is unset, key=None lets wandb
# fall back to its own credential lookup (e.g. ~/.netrc or an interactive prompt).
# Any key that was ever pasted into a shared copy of this notebook should be revoked.
wandb.login(key=os.environ.get('WANDB_API_KEY'))
run = wandb.init(project='kaggle-ubc-ocean', config=CFG, tags=tags)

# Class name -> integer index, and the inverse lookup used to print predictions
encode = dict(HGSC=0, LGSC=1, EC=2, CC=3, MC=4)
decode = dict(zip(encode.values(), encode.keys()))

# Use the first GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Paths
# The project root can be overridden with the UBC_OCEAN_ROOT environment variable
# so the notebook runs on other machines; the default is the original location.
root = os.environ.get('UBC_OCEAN_ROOT', '/media/latlab/MR/projects/kaggle-ubc-ocean')
data_dir = os.path.join(root, 'data')
results_dir = os.path.join(root, 'results')
# File name of the training table, resolved relative to data_dir when loaded
train_csv = 'train.csv'
train_image_dir = os.path.join(data_dir, 'train_images')
train_thumbnail_dir = os.path.join(data_dir, 'train_thumbnails')

# Seed
seed_everything(CFG['seed'])

# Load data
df = pd.read_csv(os.path.join(data_dir, train_csv))
# Series.map() silently yields NaN for any label that has no entry in `encode`
# (or is missing in the CSV). Fail fast here instead of letting NaN targets
# reach the stratified split and the loss function.
encoded_labels = df['label'].map(encode)
if encoded_labels.isna().any():
    unknown_labels = sorted(df.loc[encoded_labels.isna(), 'label'].astype(str).unique())
    raise ValueError(f'Labels in {train_csv} without an entry in `encode`: {unknown_labels}')
df['label'] = encoded_labels

# Functions
def _build_model(CFG, num_classes):
    """Load the pretrained backbone named by CFG['base_model'], freeze it and attach a new head with `num_classes` outputs."""
    base_model = CFG['base_model']
    # Reject unsupported families up front: otherwise the head would never be
    # replaced and every parameter would stay frozen.
    if not base_model.startswith(('resnet', 'efficientnet')):
        raise ValueError(f"Unsupported base_model {base_model!r}: expected a 'resnet*' or 'efficientnet*' model")

    model = models.get_model(base_model, weights='DEFAULT').to(device)

    # Freeze all pretrained parameters; the head created below is a fresh
    # module, so it is the only trainable part until the trainer unfreezes.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the last fully-connected layer
    if base_model.startswith('resnet'):
        model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes).to(device)
    else:
        # The original classifier is indexed as a sequence (classifier[1]); replacing
        # the whole attribute with one Linear drops whatever precedes that layer.
        model.classifier = nn.Linear(in_features=model.classifier[1].in_features, out_features=num_classes).to(device)
    return model


def _build_optimizer(CFG, model):
    """Create the optimizer selected by CFG['optimizer'] over all model parameters."""
    name = CFG['optimizer']
    if name == 'SGD':
        return optim.SGD(model.parameters(), lr=CFG['base_lr'], momentum=CFG['sgd_momentum'])
    if name == 'AdamW':
        return optim.AdamW(model.parameters(), lr=CFG['base_lr'])
    if name == 'Adam':
        return optim.Adam(model.parameters(), lr=CFG['base_lr'])
    raise ValueError(f"Unsupported optimizer {name!r}: expected 'SGD', 'AdamW' or 'Adam'")


def _build_scheduler(CFG, optimizer):
    """Create the learning-rate scheduler selected by CFG['scheduler']."""
    name = CFG['scheduler']
    # Schedule length covers both training phases (frozen + unfrozen epochs).
    total_epochs = CFG['epochs'] + CFG['freeze_epochs']
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
    # against the installed version before upgrading.
    if name == 'StepLR':
        return optim.lr_scheduler.StepLR(optimizer, step_size=CFG['scheduler_step_size'], gamma=CFG['lr_gamma'], verbose=True)
    if name == 'CyclicLR':
        # step_size_up now comes from the config (was hard-coded to 3, the same
        # value CFG['lr_cycl_step_size'] holds); 3 stays the fallback.
        return optim.lr_scheduler.CyclicLR(optimizer, base_lr=CFG['base_lr'], max_lr=CFG['base_lr']*10,
                                           step_size_up=CFG.get('lr_cycl_step_size', 3), cycle_momentum=False,
                                           mode='triangular2', verbose=True)
    if name == 'CosineAnnealingLR':
        return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs, verbose=True)
    if name == 'OneCycleLR':
        return optim.lr_scheduler.OneCycleLR(optimizer, max_lr=CFG['base_lr'], total_steps=total_epochs, verbose=True)
    raise ValueError(f"Unsupported scheduler {name!r}: expected 'StepLR', 'CyclicLR', 'CosineAnnealingLR' or 'OneCycleLR'")


def train_model(CFG, train_image_dir, train_thumbnail_dir, df_train, df_validation, encode, wandb_log=False):
    """Fine-tune a pretrained torchvision backbone in two phases and return it.

    Phase 1 trains for CFG['freeze_epochs'] epochs with only the new head
    trainable; after `trainer.unfreeze()` phase 2 trains for CFG['epochs'] epochs.

    Args:
        CFG: configuration dict (model, optimizer, scheduler and data settings).
        train_image_dir: directory of the full-size training images.
        train_thumbnail_dir: directory of the training thumbnails.
        df_train: rows used for training.
        df_validation: rows used for validation.
        encode: label-name -> class-index mapping; its length sets the number of outputs.
        wandb_log: forwarded to Trainer (presumably toggles per-epoch wandb logging — verify in trainer.py).

    Returns:
        (model, balanced_acc): the model and the metric value returned by the
        second `train_epochs` call.

    Raises:
        ValueError: if CFG names an unsupported base_model, optimizer or scheduler.
    """
    # Data loaders
    datasets = get_datasets(CFG, train_image_dir, train_thumbnail_dir, df_train, df_validation)
    dataloaders = get_dataloaders(CFG, datasets)

    # Model, loss, optimizer and scheduler
    model = _build_model(CFG, num_classes=len(encode))
    loss_fn = nn.CrossEntropyLoss()
    optimizer = _build_optimizer(CFG, model)
    scheduler = _build_scheduler(CFG, optimizer)

    # Training: head only first, then the whole network
    trainer = Trainer(model, dataloaders, loss_fn, optimizer, scheduler, device, metric='balanced_accuracy', wandb_log=wandb_log)
    model, _ = trainer.train_epochs(num_epochs=CFG['freeze_epochs'])
    trainer.unfreeze()
    model, balanced_acc = trainer.train_epochs(num_epochs=CFG['epochs'])
    return model, balanced_acc

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnaraiadam88[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/latlab/.netrc


In [2]:
# Preview the first two training batches as image grids, 6 images per row,
# each titled with its decoded class name. Skipped unless plot_samples is set.
if plot_samples:
    preview_datasets = get_datasets(CFG, train_image_dir, train_thumbnail_dir, df, df)
    dataloaders = get_dataloaders(CFG, preview_datasets)
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloaders['train']):
            n_images = len(X)
            n_rows = int(np.ceil(n_images / 6))
            plt.figure(figsize=(np.ceil(n_images / 2), 12))
            for idx in range(n_images):
                plt.subplot(n_rows, 6, idx + 1)
                # Tensor is channel-first; imshow expects channel-last
                channels_last = X[idx].permute(1, 2, 0)
                plt.imshow(channels_last.cpu().numpy())
                plt.title(decode[y[idx].item()])
            # Stop after the second batch without fetching a third one
            if batch >= 1:
                break

In [3]:
# Stratified K-fold cross-validation: train one model per fold, save its weights
# and log the fold's balanced accuracy (plus the mean over folds) to wandb.
skf = StratifiedKFold(n_splits=CFG['cv_fold'], random_state=CFG['seed'], shuffle=True)

# torch.save() does not create missing parent directories; make sure the target
# exists before training so a finished fold is never lost at save time.
models_dir = os.path.join(results_dir, 'models')
os.makedirs(models_dir, exist_ok=True)

balanced_acc_list = []
# Only the labels matter for stratification, so a dummy feature array is passed as X.
for cv, (train_index, valid_index) in enumerate(skf.split(np.zeros(len(df['label'])), df['label'])):
    print(f"Cross-validation fold {cv+1}/{CFG['cv_fold']}")
    df_train = df.iloc[train_index]
    df_validation = df.iloc[valid_index]
    run_name = f'{run.name}-cv{cv+1}'
    model, balanced_acc = train_model(CFG, train_image_dir, train_thumbnail_dir, df_train, df_validation, encode)
    balanced_acc_list.append(balanced_acc)
    torch.save(model.state_dict(), os.path.join(models_dir, f'ubc-ocean-{run_name}.pt'))
    wandb.log({f'balanced_acc_cv{cv+1}': balanced_acc})
wandb.log({'mean_balanced_acc': np.mean(balanced_acc_list)})
wandb.finish()

Cross-validation fold 1/5
Adjusting learning rate of group 0 to 1.0000e-04.
Epoch 1/1
----------


100%|██████████| 27/27 [00:34<00:00,  1.29s/it]
100%|██████████| 7/7 [00:11<00:00,  1.60s/it]


Adjusting learning rate of group 0 to 9.7975e-05.
train loss: 1.5995, test loss: 1.5827, balanced_accuracy: 0.3130

Training complete in 0m 46s
Final balanced_accuracy: 0.312954

Epoch 1/10
----------


100%|██████████| 27/27 [00:35<00:00,  1.32s/it]
100%|██████████| 7/7 [00:11<00:00,  1.58s/it]


Adjusting learning rate of group 0 to 9.2063e-05.
train loss: 1.4837, test loss: 1.3499, balanced_accuracy: 0.3070

Epoch 2/10
----------


100%|██████████| 27/27 [00:38<00:00,  1.44s/it]
100%|██████████| 7/7 [00:11<00:00,  1.58s/it]


Adjusting learning rate of group 0 to 8.2743e-05.
train loss: 1.2481, test loss: 1.1631, balanced_accuracy: 0.4443

Epoch 3/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.38s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 7.0771e-05.
train loss: 1.0711, test loss: 1.0630, balanced_accuracy: 0.5621

Epoch 4/10
----------


100%|██████████| 27/27 [00:35<00:00,  1.32s/it]
100%|██████████| 7/7 [00:11<00:00,  1.58s/it]


Adjusting learning rate of group 0 to 5.7116e-05.
train loss: 0.9194, test loss: 0.9620, balanced_accuracy: 0.6010

Epoch 5/10
----------


100%|██████████| 27/27 [00:33<00:00,  1.25s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 4.2884e-05.
train loss: 0.8029, test loss: 0.9104, balanced_accuracy: 0.6332

Epoch 6/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.36s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 2.9229e-05.
train loss: 0.7314, test loss: 0.8924, balanced_accuracy: 0.6710

Epoch 7/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.38s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 1.7257e-05.
train loss: 0.6729, test loss: 0.8650, balanced_accuracy: 0.6817

Epoch 8/10
----------


100%|██████████| 27/27 [00:34<00:00,  1.29s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 7.9373e-06.
train loss: 0.6311, test loss: 0.8545, balanced_accuracy: 0.6872

Epoch 9/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.35s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 2.0254e-06.
train loss: 0.6650, test loss: 0.8618, balanced_accuracy: 0.6984

Epoch 10/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.37s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 0.0000e+00.
train loss: 0.6342, test loss: 0.8581, balanced_accuracy: 0.6848

Training complete in 7m 54s
Final balanced_accuracy: 0.684751

Cross-validation fold 2/5
Adjusting learning rate of group 0 to 1.0000e-04.
Epoch 1/1
----------


100%|██████████| 27/27 [00:38<00:00,  1.42s/it]
100%|██████████| 7/7 [00:09<00:00,  1.39s/it]


Adjusting learning rate of group 0 to 9.7975e-05.
train loss: 1.5954, test loss: 1.5819, balanced_accuracy: 0.3013

Training complete in 0m 48s
Final balanced_accuracy: 0.301255

Epoch 1/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.34s/it]
100%|██████████| 7/7 [00:09<00:00,  1.41s/it]


Adjusting learning rate of group 0 to 9.2063e-05.
train loss: 1.4614, test loss: 1.3166, balanced_accuracy: 0.3551

Epoch 2/10
----------


100%|██████████| 27/27 [00:38<00:00,  1.41s/it]
100%|██████████| 7/7 [00:09<00:00,  1.39s/it]


Adjusting learning rate of group 0 to 8.2743e-05.
train loss: 1.2051, test loss: 1.1022, balanced_accuracy: 0.4626

Epoch 3/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.38s/it]
100%|██████████| 7/7 [00:09<00:00,  1.38s/it]


Adjusting learning rate of group 0 to 7.0771e-05.
train loss: 1.0254, test loss: 0.9902, balanced_accuracy: 0.5864

Epoch 4/10
----------


100%|██████████| 27/27 [00:38<00:00,  1.41s/it]
100%|██████████| 7/7 [00:09<00:00,  1.40s/it]


Adjusting learning rate of group 0 to 5.7116e-05.
train loss: 0.8939, test loss: 0.9313, balanced_accuracy: 0.6154

Epoch 5/10
----------


100%|██████████| 27/27 [00:38<00:00,  1.41s/it]
100%|██████████| 7/7 [00:09<00:00,  1.39s/it]


Adjusting learning rate of group 0 to 4.2884e-05.
train loss: 0.7835, test loss: 0.8761, balanced_accuracy: 0.6426

Epoch 6/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
100%|██████████| 7/7 [00:09<00:00,  1.39s/it]


Adjusting learning rate of group 0 to 2.9229e-05.
train loss: 0.6990, test loss: 0.8573, balanced_accuracy: 0.6585

Epoch 7/10
----------


100%|██████████| 27/27 [00:38<00:00,  1.42s/it]
100%|██████████| 7/7 [00:09<00:00,  1.39s/it]


Adjusting learning rate of group 0 to 1.7257e-05.
train loss: 0.6370, test loss: 0.8337, balanced_accuracy: 0.6904

Epoch 8/10
----------


100%|██████████| 27/27 [00:38<00:00,  1.41s/it]
100%|██████████| 7/7 [00:09<00:00,  1.41s/it]


Adjusting learning rate of group 0 to 7.9373e-06.
train loss: 0.6481, test loss: 0.8399, balanced_accuracy: 0.6574

Epoch 9/10
----------


100%|██████████| 27/27 [00:38<00:00,  1.43s/it]
100%|██████████| 7/7 [00:09<00:00,  1.39s/it]


Adjusting learning rate of group 0 to 2.0254e-06.
train loss: 0.6112, test loss: 0.8395, balanced_accuracy: 0.6646

Epoch 10/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.38s/it]
100%|██████████| 7/7 [00:09<00:00,  1.40s/it]


Adjusting learning rate of group 0 to 0.0000e+00.
train loss: 0.6337, test loss: 0.8318, balanced_accuracy: 0.6588

Training complete in 7m 55s
Final balanced_accuracy: 0.658842

Cross-validation fold 3/5
Adjusting learning rate of group 0 to 1.0000e-04.
Epoch 1/1
----------


100%|██████████| 27/27 [00:37<00:00,  1.37s/it]
100%|██████████| 7/7 [00:09<00:00,  1.40s/it]


Adjusting learning rate of group 0 to 9.7975e-05.
train loss: 1.6049, test loss: 1.5896, balanced_accuracy: 0.2481

Training complete in 0m 47s
Final balanced_accuracy: 0.248135

Epoch 1/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.33s/it]
100%|██████████| 7/7 [00:09<00:00,  1.40s/it]


Adjusting learning rate of group 0 to 9.2063e-05.
train loss: 1.4817, test loss: 1.3306, balanced_accuracy: 0.3254

Epoch 2/10
----------


100%|██████████| 27/27 [00:35<00:00,  1.33s/it]
100%|██████████| 7/7 [00:09<00:00,  1.41s/it]


Adjusting learning rate of group 0 to 8.2743e-05.
train loss: 1.2487, test loss: 1.1375, balanced_accuracy: 0.3376

Epoch 3/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
100%|██████████| 7/7 [00:09<00:00,  1.40s/it]


Adjusting learning rate of group 0 to 7.0771e-05.
train loss: 1.0510, test loss: 1.0313, balanced_accuracy: 0.5975

Epoch 4/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.39s/it]
100%|██████████| 7/7 [00:09<00:00,  1.42s/it]


Adjusting learning rate of group 0 to 5.7116e-05.
train loss: 0.9454, test loss: 0.9648, balanced_accuracy: 0.6106

Epoch 5/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.39s/it]
100%|██████████| 7/7 [00:09<00:00,  1.42s/it]


Adjusting learning rate of group 0 to 4.2884e-05.
train loss: 0.8300, test loss: 0.9122, balanced_accuracy: 0.6422

Epoch 6/10
----------


100%|██████████| 27/27 [00:39<00:00,  1.45s/it]
100%|██████████| 7/7 [00:09<00:00,  1.42s/it]


Adjusting learning rate of group 0 to 2.9229e-05.
train loss: 0.7386, test loss: 0.8938, balanced_accuracy: 0.6613

Epoch 7/10
----------


100%|██████████| 27/27 [00:39<00:00,  1.45s/it]
100%|██████████| 7/7 [00:09<00:00,  1.42s/it]


Adjusting learning rate of group 0 to 1.7257e-05.
train loss: 0.7290, test loss: 0.8704, balanced_accuracy: 0.6665

Epoch 8/10
----------


100%|██████████| 27/27 [00:38<00:00,  1.41s/it]
100%|██████████| 7/7 [00:09<00:00,  1.42s/it]


Adjusting learning rate of group 0 to 7.9373e-06.
train loss: 0.6629, test loss: 0.8548, balanced_accuracy: 0.6901

Epoch 9/10
----------


100%|██████████| 27/27 [00:34<00:00,  1.29s/it]
100%|██████████| 7/7 [00:09<00:00,  1.41s/it]


Adjusting learning rate of group 0 to 2.0254e-06.
train loss: 0.6762, test loss: 0.8537, balanced_accuracy: 0.6854

Epoch 10/10
----------


100%|██████████| 27/27 [00:39<00:00,  1.47s/it]
100%|██████████| 7/7 [00:09<00:00,  1.42s/it]


Adjusting learning rate of group 0 to 0.0000e+00.
train loss: 0.6720, test loss: 0.8645, balanced_accuracy: 0.6758

Training complete in 7m 55s
Final balanced_accuracy: 0.675833

Cross-validation fold 4/5
Adjusting learning rate of group 0 to 1.0000e-04.
Epoch 1/1
----------


100%|██████████| 27/27 [00:35<00:00,  1.30s/it]
100%|██████████| 7/7 [00:11<00:00,  1.69s/it]


Adjusting learning rate of group 0 to 9.7975e-05.
train loss: 1.6027, test loss: 1.5790, balanced_accuracy: 0.2932

Training complete in 0m 47s
Final balanced_accuracy: 0.293197

Epoch 1/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.37s/it]
100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Adjusting learning rate of group 0 to 9.2063e-05.
train loss: 1.4646, test loss: 1.3249, balanced_accuracy: 0.3522

Epoch 2/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.34s/it]
100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Adjusting learning rate of group 0 to 8.2743e-05.
train loss: 1.2387, test loss: 1.1548, balanced_accuracy: 0.4152

Epoch 3/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.39s/it]
100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Adjusting learning rate of group 0 to 7.0771e-05.
train loss: 1.0416, test loss: 1.0561, balanced_accuracy: 0.5125

Epoch 4/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.38s/it]
100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Adjusting learning rate of group 0 to 5.7116e-05.
train loss: 0.9214, test loss: 0.9967, balanced_accuracy: 0.5233

Epoch 5/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.35s/it]
100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Adjusting learning rate of group 0 to 4.2884e-05.
train loss: 0.8197, test loss: 0.9557, balanced_accuracy: 0.5366

Epoch 6/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.34s/it]
100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Adjusting learning rate of group 0 to 2.9229e-05.
train loss: 0.7532, test loss: 0.9348, balanced_accuracy: 0.5289

Epoch 7/10
----------


100%|██████████| 27/27 [00:35<00:00,  1.32s/it]
100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Adjusting learning rate of group 0 to 1.7257e-05.
train loss: 0.6730, test loss: 0.9374, balanced_accuracy: 0.5289

Epoch 8/10
----------


100%|██████████| 27/27 [00:35<00:00,  1.32s/it]
100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Adjusting learning rate of group 0 to 7.9373e-06.
train loss: 0.6643, test loss: 0.9365, balanced_accuracy: 0.5438

Epoch 9/10
----------


100%|██████████| 27/27 [00:38<00:00,  1.41s/it]
100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Adjusting learning rate of group 0 to 2.0254e-06.
train loss: 0.6247, test loss: 0.9173, balanced_accuracy: 0.5438

Epoch 10/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.36s/it]
100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Adjusting learning rate of group 0 to 0.0000e+00.
train loss: 0.6222, test loss: 0.9174, balanced_accuracy: 0.5527

Training complete in 8m 3s
Final balanced_accuracy: 0.552721

Cross-validation fold 5/5
Adjusting learning rate of group 0 to 1.0000e-04.
Epoch 1/1
----------


100%|██████████| 27/27 [00:35<00:00,  1.31s/it]
100%|██████████| 7/7 [00:11<00:00,  1.58s/it]


Adjusting learning rate of group 0 to 9.7975e-05.
train loss: 1.6003, test loss: 1.5916, balanced_accuracy: 0.2770

Training complete in 0m 46s
Final balanced_accuracy: 0.276973

Epoch 1/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.37s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 9.2063e-05.
train loss: 1.4632, test loss: 1.3168, balanced_accuracy: 0.3009

Epoch 2/10
----------


100%|██████████| 27/27 [00:39<00:00,  1.45s/it]
100%|██████████| 7/7 [00:11<00:00,  1.61s/it]


Adjusting learning rate of group 0 to 8.2743e-05.
train loss: 1.2467, test loss: 1.1139, balanced_accuracy: 0.4405

Epoch 3/10
----------


100%|██████████| 27/27 [00:36<00:00,  1.35s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 7.0771e-05.
train loss: 1.0776, test loss: 0.9690, balanced_accuracy: 0.5491

Epoch 4/10
----------


100%|██████████| 27/27 [00:39<00:00,  1.48s/it]
100%|██████████| 7/7 [00:11<00:00,  1.61s/it]


Adjusting learning rate of group 0 to 5.7116e-05.
train loss: 0.9861, test loss: 0.8691, balanced_accuracy: 0.6915

Epoch 5/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.40s/it]
100%|██████████| 7/7 [00:11<00:00,  1.60s/it]


Adjusting learning rate of group 0 to 4.2884e-05.
train loss: 0.8727, test loss: 0.7906, balanced_accuracy: 0.7177

Epoch 6/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.38s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 2.9229e-05.
train loss: 0.8230, test loss: 0.7273, balanced_accuracy: 0.7046

Epoch 7/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.37s/it]
100%|██████████| 7/7 [00:11<00:00,  1.61s/it]


Adjusting learning rate of group 0 to 1.7257e-05.
train loss: 0.7580, test loss: 0.7070, balanced_accuracy: 0.7438

Epoch 8/10
----------


100%|██████████| 27/27 [00:37<00:00,  1.38s/it]
100%|██████████| 7/7 [00:11<00:00,  1.60s/it]


Adjusting learning rate of group 0 to 7.9373e-06.
train loss: 0.7303, test loss: 0.6965, balanced_accuracy: 0.7105

Epoch 9/10
----------


100%|██████████| 27/27 [00:40<00:00,  1.49s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 2.0254e-06.
train loss: 0.6908, test loss: 0.6907, balanced_accuracy: 0.7450

Epoch 10/10
----------


100%|██████████| 27/27 [00:38<00:00,  1.42s/it]
100%|██████████| 7/7 [00:11<00:00,  1.59s/it]


Adjusting learning rate of group 0 to 0.0000e+00.
train loss: 0.7304, test loss: 0.6879, balanced_accuracy: 0.7772

Training complete in 8m 12s
Final balanced_accuracy: 0.777177





VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
balanced_acc_cv1,▁
balanced_acc_cv2,▁
balanced_acc_cv3,▁
balanced_acc_cv4,▁
balanced_acc_cv5,▁
mean_balanced_acc,▁

0,1
balanced_acc_cv1,0.68475
balanced_acc_cv2,0.65884
balanced_acc_cv3,0.67583
balanced_acc_cv4,0.55272
balanced_acc_cv5,0.77718
mean_balanced_acc,0.66986


In [4]:
# Final training on all data
# The full frame is passed as both the training and the validation set, so the
# metric printed during this run is measured on training data.
model, _ = train_model(CFG, train_image_dir, train_thumbnail_dir, df, df, encode, wandb_log=False)
# torch.save() does not create missing parent directories; ensure the target exists.
final_models_dir = os.path.join(results_dir, 'models')
os.makedirs(final_models_dir, exist_ok=True)
# NOTE(review): `run` was already finished by wandb.finish() in the CV cell; this
# relies on run.name staying readable afterwards — confirm for the wandb version in use.
torch.save(model.state_dict(), os.path.join(final_models_dir, f'ubc-ocean-{run.name}.pt'))

Adjusting learning rate of group 0 to 1.0000e-04.
Epoch 1/1
----------


100%|██████████| 34/34 [00:43<00:00,  1.27s/it]
100%|██████████| 34/34 [00:46<00:00,  1.36s/it]


Adjusting learning rate of group 0 to 9.7975e-05.
train loss: 1.5956, test loss: 1.5719, balanced_accuracy: 0.3363

Training complete in 1m 29s
Final balanced_accuracy: 0.336308

Epoch 1/10
----------


 65%|██████▍   | 22/34 [00:29<00:09,  1.25it/s]