In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
import wandb

import torch
from torch import nn
from torchvision import models
import torch.optim as optim

from dataset import get_dataloaders, get_datasets
from utils import seed_everything
from trainer import Trainer

# Params
# Raise PIL's decompression-bomb pixel limit so very large slide images can be opened
# (int form of the previous 1e11).
Image.MAX_IMAGE_PIXELS = 100_000_000_000
CFG = {
    'seed': 42,
    'base_model': 'resnet34',   # resnet18/34/50, efficientnet_v2_s/m/l
    'img_size': 1024,
    'batch_size': 8,
    'freeze_epochs': 1,
    'epochs': 10,
    'base_lr': 1e-3,
    'affine_degrees': 10,
    'affine_translate': (0.1, 0.2),
    'affine_scale': (0.8, 1.2),
    'cv_fold': 5,
    'dataloader_num_workers': 36
}

# Wandb
# Never hard-code the API key: it leaks through git history and shared notebooks.
# With key=None, wandb.login() falls back to the WANDB_API_KEY environment variable
# or the ~/.netrc entry written by `wandb login`.
# TODO: revoke/rotate the key that was previously hard-coded in this cell.
wandb.login(key=os.environ.get('WANDB_API_KEY'))
run = wandb.init(project='kaggle-ubc-ocean', config=CFG, tags=['torch', 'baseline'])

# Label encoder/decoder
encode = {'HGSC': 0, 'LGSC': 1, 'EC': 2, 'CC': 3, 'MC': 4}
decode = {v: k for k, v in encode.items()}

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Paths — the project root can be overridden without editing the notebook.
root = os.environ.get('UBC_OCEAN_ROOT', '/media/latlab/MR/projects/kaggle-ubc-ocean')
data_dir = os.path.join(root, 'data')
results_dir = os.path.join(root, 'results')
train_csv = 'train.csv'
train_image_dir = os.path.join(data_dir, 'train_images')
train_thumbnail_dir = os.path.join(data_dir, 'train_thumbnails')

# Seed
seed_everything(CFG['seed'])

# Load data
df = pd.read_csv(os.path.join(data_dir, train_csv))
df['label'] = df['label'].map(encode)
# .map() turns any label missing from `encode` into NaN; fail here rather than
# training on NaN targets further down.
assert df['label'].notna().all(), 'train.csv contains labels that are not in `encode`'

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnaraiadam88[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/latlab/.netrc


In [2]:
def _replace_classifier_head(model, num_classes):
    """Replace the final ``nn.Linear`` of a torchvision classifier with a new ``num_classes``-way layer.

    Supports the two head layouts of the backbones listed in ``CFG['base_model']``:
    ResNets keep the head in ``model.fc``; EfficientNet(V2) models keep it as the
    last element of the ``model.classifier`` Sequential.

    Returns:
        The same model object, modified in place.

    Raises:
        ValueError: if neither head layout is found.
    """
    head = getattr(model, 'fc', None)
    if isinstance(head, nn.Linear):
        model.fc = nn.Linear(head.in_features, num_classes)
        return model
    classifier = getattr(model, 'classifier', None)
    if isinstance(classifier, nn.Sequential) and len(classifier) > 0 and isinstance(classifier[-1], nn.Linear):
        classifier[-1] = nn.Linear(classifier[-1].in_features, num_classes)
        return model
    raise ValueError(f"Don't know how to replace the classification head of {type(model).__name__}")


def train_model(CFG, train_image_dir, train_thumbnail_dir, df_train, df_validation, encode, wandb_log=False):
    """Fine-tune an ImageNet-pretrained torchvision backbone on one train/validation split.

    Training runs in two phases:
    - ``CFG['freeze_epochs']`` epochs with only the new classification head trainable;
    - ``CFG['epochs']`` epochs after ``Trainer.unfreeze()``.

    Args:
        CFG: run configuration. Reads 'base_model', 'base_lr', 'freeze_epochs' and 'epochs'.
            Optional 'lr_step_size' (default 7) and 'lr_gamma' (default 0.1) set the StepLR
            schedule. Whatever ``get_datasets``/``get_dataloaders`` read is also needed.
        train_image_dir: directory of the training images.
        train_thumbnail_dir: directory of the training thumbnails.
        df_train: rows used for training.
        df_validation: rows evaluated after each epoch.
        encode: label -> class-index mapping; only its size (number of outputs) is used here.
        wandb_log: forwarded to ``Trainer``.

    Returns:
        ``(model, balanced_acc)``: the trained model and the balanced accuracy returned by
        the second (unfrozen) ``Trainer.train_epochs`` call.
    """
    # Data loaders
    datasets = get_datasets(CFG, train_image_dir, train_thumbnail_dir, df_train, df_validation)
    dataloaders = get_dataloaders(CFG, datasets)

    # Model: pretrained backbone with every weight frozen, then a fresh (trainable) head
    # sized to the number of classes. The head is created after freezing so it stays trainable.
    model = models.get_model(CFG['base_model'], weights='DEFAULT')
    for param in model.parameters():
        param.requires_grad = False
    model = _replace_classifier_head(model, len(encode)).to(device)

    loss_fn = nn.CrossEntropyLoss()
    # Every parameter is registered now. Frozen ones get no gradient, so SGD skips them until
    # unfreeze() re-enables their grads (presumably what Trainer.unfreeze does — TODO confirm).
    optimizer = optim.SGD(model.parameters(), lr=CFG['base_lr'], momentum=0.9)
    # One scheduler spans both phases, so its step count includes the frozen epochs
    # (assumes Trainer calls scheduler.step() once per epoch — TODO confirm).
    exp_lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=CFG.get('lr_step_size', 7),
        gamma=CFG.get('lr_gamma', 0.1),
    )

    # Training
    trainer = Trainer(model, dataloaders, loss_fn, optimizer, exp_lr_scheduler, device,
                      metric='balanced_accuracy', wandb_log=wandb_log)
    model, _ = trainer.train_epochs(num_epochs=CFG['freeze_epochs'])
    trainer.unfreeze()
    # NOTE(review): in the recorded runs this value equals the LAST epoch's score, not the
    # maximum over epochs — confirm whether Trainer.train_epochs returns best or final.
    model, balanced_acc = trainer.train_epochs(num_epochs=CFG['epochs'])
    return model, balanced_acc

In [3]:
# Stratified K-fold cross-validation: a fresh model per fold, scored on its held-out fold.
models_dir = os.path.join(results_dir, 'models')
os.makedirs(models_dir, exist_ok=True)  # torch.save does not create missing directories

skf = StratifiedKFold(n_splits=CFG['cv_fold'], random_state=CFG['seed'], shuffle=True)
balanced_acc_list = []
labels = df['label']
# StratifiedKFold only needs X for its length; stratification is driven by `labels`.
for cv, (train_index, valid_index) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    fold = cv + 1
    print(f"Cross-validation fold {fold}/{CFG['cv_fold']}")
    df_train = df.iloc[train_index]
    df_validation = df.iloc[valid_index]
    run_name = f'{run.name}-cv{fold}'
    model, balanced_acc = train_model(CFG, train_image_dir, train_thumbnail_dir, df_train, df_validation, encode)
    balanced_acc_list.append(balanced_acc)
    torch.save(model.state_dict(), os.path.join(models_dir, f'ubc-ocean-{run_name}.pt'))
    wandb.log({f'balanced_acc_cv{fold}': balanced_acc})
    # Drop this fold's model before the next fold builds its own, so two models
    # are never resident on the GPU at once.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
wandb.log({
    'mean_balanced_acc': np.mean(balanced_acc_list),
    'std_balanced_acc': np.std(balanced_acc_list),
})
wandb.finish()

Cross-validation fold 1/5
Epoch 1/1
----------


100%|██████████| 54/54 [00:08<00:00,  6.25it/s]


train loss: 1.5647, test loss: 1.4412, balanced_accuracy: 0.3670

Training complete in 0m 13s
Final balanced_accuracy: 0.366964

Epoch 1/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.08it/s]


train loss: 1.3092, test loss: 1.2171, balanced_accuracy: 0.5760

Epoch 2/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.25it/s]


train loss: 1.1157, test loss: 1.1895, balanced_accuracy: 0.6058

Epoch 3/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.19it/s]


train loss: 1.0337, test loss: 1.7447, balanced_accuracy: 0.4396

Epoch 4/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.10it/s]


train loss: 0.9618, test loss: 1.1644, balanced_accuracy: 0.6077

Epoch 5/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.38it/s]


train loss: 0.8884, test loss: 1.3756, balanced_accuracy: 0.5886

Epoch 6/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.05it/s]


train loss: 0.8853, test loss: 1.3383, balanced_accuracy: 0.6200

Epoch 7/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.30it/s]


train loss: 0.6769, test loss: 1.0486, balanced_accuracy: 0.6435

Epoch 8/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.17it/s]


train loss: 0.6765, test loss: 1.0631, balanced_accuracy: 0.6330

Epoch 9/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.18it/s]


train loss: 0.6003, test loss: 1.0123, balanced_accuracy: 0.6348

Epoch 10/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.14it/s]


train loss: 0.5927, test loss: 1.0537, balanced_accuracy: 0.6638

Training complete in 2m 49s
Final balanced_accuracy: 0.663790

Cross-validation fold 2/5
Epoch 1/1
----------


100%|██████████| 54/54 [00:06<00:00,  7.90it/s]


train loss: 1.5263, test loss: 1.4178, balanced_accuracy: 0.3150

Training complete in 0m 11s
Final balanced_accuracy: 0.314980

Epoch 1/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.11it/s]


train loss: 1.3307, test loss: 1.2694, balanced_accuracy: 0.5232

Epoch 2/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.07it/s]


train loss: 1.0905, test loss: 0.8927, balanced_accuracy: 0.6507

Epoch 3/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.29it/s]


train loss: 0.9891, test loss: 1.0872, balanced_accuracy: 0.5054

Epoch 4/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.06it/s]


train loss: 1.0257, test loss: 0.8234, balanced_accuracy: 0.5892

Epoch 5/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.24it/s]


train loss: 0.9318, test loss: 0.9271, balanced_accuracy: 0.6199

Epoch 6/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.17it/s]


train loss: 0.8835, test loss: 0.9703, balanced_accuracy: 0.6364

Epoch 7/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.04it/s]


train loss: 0.6480, test loss: 0.8199, balanced_accuracy: 0.7216

Epoch 8/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.01it/s]


train loss: 0.6596, test loss: 0.7555, balanced_accuracy: 0.7241

Epoch 9/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.00it/s]


train loss: 0.6064, test loss: 0.8080, balanced_accuracy: 0.7187

Epoch 10/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.23it/s]


train loss: 0.5694, test loss: 0.8019, balanced_accuracy: 0.6948

Training complete in 2m 47s
Final balanced_accuracy: 0.694841

Cross-validation fold 3/5
Epoch 1/1
----------


100%|██████████| 54/54 [00:06<00:00,  7.95it/s]


train loss: 1.5575, test loss: 1.4017, balanced_accuracy: 0.4129

Training complete in 0m 11s
Final balanced_accuracy: 0.412897

Epoch 1/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.32it/s]


train loss: 1.3365, test loss: 1.2226, balanced_accuracy: 0.4192

Epoch 2/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.41it/s]


train loss: 1.3112, test loss: 1.0956, balanced_accuracy: 0.5763

Epoch 3/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.07it/s]


train loss: 1.1012, test loss: 1.0837, balanced_accuracy: 0.5738

Epoch 4/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.02it/s]


train loss: 0.9690, test loss: 0.9071, balanced_accuracy: 0.6243

Epoch 5/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.07it/s]


train loss: 0.8276, test loss: 1.0782, balanced_accuracy: 0.5622

Epoch 6/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.08it/s]


train loss: 0.8226, test loss: 1.1482, balanced_accuracy: 0.6151

Epoch 7/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.01it/s]


train loss: 0.6862, test loss: 0.8195, balanced_accuracy: 0.6651

Epoch 8/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.13it/s]


train loss: 0.6670, test loss: 0.8163, balanced_accuracy: 0.6948

Epoch 9/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.22it/s]


train loss: 0.6035, test loss: 0.8586, balanced_accuracy: 0.7029

Epoch 10/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.30it/s]


train loss: 0.5487, test loss: 0.8453, balanced_accuracy: 0.7341

Training complete in 2m 50s
Final balanced_accuracy: 0.734127

Cross-validation fold 4/5
Epoch 1/1
----------


100%|██████████| 54/54 [00:07<00:00,  7.70it/s]


train loss: 1.5213, test loss: 1.5678, balanced_accuracy: 0.2488

Training complete in 0m 11s
Final balanced_accuracy: 0.248810

Epoch 1/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.17it/s]


train loss: 1.3526, test loss: 1.1580, balanced_accuracy: 0.5071

Epoch 2/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.19it/s]


train loss: 1.0868, test loss: 1.1506, balanced_accuracy: 0.5133

Epoch 3/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.10it/s]


train loss: 1.0441, test loss: 1.9515, balanced_accuracy: 0.4970

Epoch 4/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.20it/s]


train loss: 0.9922, test loss: 1.2843, balanced_accuracy: 0.6329

Epoch 5/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.01it/s]


train loss: 0.8816, test loss: 1.3256, balanced_accuracy: 0.5514

Epoch 6/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.12it/s]


train loss: 0.7856, test loss: 1.3416, balanced_accuracy: 0.5793

Epoch 7/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.16it/s]


train loss: 0.6876, test loss: 0.9975, balanced_accuracy: 0.6618

Epoch 8/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.18it/s]


train loss: 0.6149, test loss: 0.9908, balanced_accuracy: 0.6734

Epoch 9/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.18it/s]


train loss: 0.6307, test loss: 0.9728, balanced_accuracy: 0.6993

Epoch 10/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.09it/s]


train loss: 0.5794, test loss: 0.9599, balanced_accuracy: 0.7138

Training complete in 2m 51s
Final balanced_accuracy: 0.713790

Cross-validation fold 5/5
Epoch 1/1
----------


100%|██████████| 54/54 [00:06<00:00,  8.15it/s]


train loss: 1.5389, test loss: 1.3697, balanced_accuracy: 0.3831

Training complete in 0m 11s
Final balanced_accuracy: 0.383135

Epoch 1/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.24it/s]


train loss: 1.3927, test loss: 1.0865, balanced_accuracy: 0.5089

Epoch 2/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.18it/s]


train loss: 1.1451, test loss: 1.0678, balanced_accuracy: 0.5179

Epoch 3/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.20it/s]


train loss: 1.0604, test loss: 0.9188, balanced_accuracy: 0.5173

Epoch 4/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.06it/s]


train loss: 1.0171, test loss: 0.9965, balanced_accuracy: 0.5542

Epoch 5/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.08it/s]


train loss: 0.8939, test loss: 1.0099, balanced_accuracy: 0.5492

Epoch 6/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.16it/s]


train loss: 0.8625, test loss: 0.9299, balanced_accuracy: 0.6802

Epoch 7/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.18it/s]


train loss: 0.7387, test loss: 0.8046, balanced_accuracy: 0.6310

Epoch 8/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.21it/s]


train loss: 0.6206, test loss: 0.7520, balanced_accuracy: 0.7304

Epoch 9/10
----------


100%|██████████| 54/54 [00:12<00:00,  4.26it/s]


train loss: 0.6150, test loss: 0.7503, balanced_accuracy: 0.6935

Epoch 10/10
----------


100%|██████████| 54/54 [00:13<00:00,  4.04it/s]


train loss: 0.6042, test loss: 0.7275, balanced_accuracy: 0.7232

Training complete in 2m 54s
Final balanced_accuracy: 0.723214



In [4]:
# Final training on all data.
# `df` is passed as both the training and the validation set, so the balanced_accuracy
# printed by this cell is a *training* score. It is not comparable to the CV numbers above.
# NOTE(review): run.name is read after wandb.finish() in the CV cell — confirm it is still
# populated in the installed wandb version; otherwise capture it before finishing.
models_dir = os.path.join(results_dir, 'models')
os.makedirs(models_dir, exist_ok=True)  # torch.save does not create missing directories
model, _ = train_model(CFG, train_image_dir, train_thumbnail_dir, df, df, encode, wandb_log=False)
torch.save(model.state_dict(), os.path.join(models_dir, f'ubc-ocean-{run.name}.pt'))

Epoch 1/1
----------


100%|██████████| 68/68 [00:07<00:00,  8.60it/s]


train loss: 1.4656, test loss: 1.2588, balanced_accuracy: 0.3958

Training complete in 0m 17s
Final balanced_accuracy: 0.395752

Epoch 1/10
----------


100%|██████████| 68/68 [00:15<00:00,  4.42it/s]


train loss: 1.2973, test loss: 1.1971, balanced_accuracy: 0.4566

Epoch 2/10
----------


100%|██████████| 68/68 [00:15<00:00,  4.50it/s]


train loss: 1.1999, test loss: 1.0475, balanced_accuracy: 0.5710

Epoch 3/10
----------


100%|██████████| 68/68 [00:15<00:00,  4.41it/s]


train loss: 0.9922, test loss: 0.7465, balanced_accuracy: 0.6813

Epoch 4/10
----------


100%|██████████| 68/68 [00:15<00:00,  4.31it/s]


train loss: 0.9043, test loss: 0.7112, balanced_accuracy: 0.7026

Epoch 5/10
----------


100%|██████████| 68/68 [00:15<00:00,  4.45it/s]


train loss: 0.8807, test loss: 0.7704, balanced_accuracy: 0.7311

Epoch 6/10
----------


100%|██████████| 68/68 [00:15<00:00,  4.27it/s]


train loss: 0.7583, test loss: 0.5487, balanced_accuracy: 0.7654

Epoch 7/10
----------


100%|██████████| 68/68 [00:15<00:00,  4.33it/s]


train loss: 0.6367, test loss: 0.4248, balanced_accuracy: 0.8415

Epoch 8/10
----------


100%|██████████| 68/68 [00:16<00:00,  4.25it/s]


train loss: 0.6089, test loss: 0.3846, balanced_accuracy: 0.8366

Epoch 9/10
----------


100%|██████████| 68/68 [00:16<00:00,  4.23it/s]


train loss: 0.5853, test loss: 0.4129, balanced_accuracy: 0.8468

Epoch 10/10
----------


100%|██████████| 68/68 [00:15<00:00,  4.28it/s]


train loss: 0.6010, test loss: 0.3729, balanced_accuracy: 0.8590

Training complete in 4m 13s
Final balanced_accuracy: 0.859007





VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
balanced_accuracy,▁▂▄▅▆▆▇████
best_balanced_acc_cv1,▁
best_balanced_acc_cv2,▁
best_balanced_acc_cv3,▁
best_balanced_acc_cv4,▁
best_balanced_acc_cv5,▁
mean_best_balanced_acc,▁
train_loss,█▇▆▄▄▃▂▁▁▁▁
valid_loss,██▆▄▄▄▂▁▁▁▁

0,1
balanced_accuracy,0.85901
best_balanced_acc_cv1,0.66379
best_balanced_acc_cv2,0.69484
best_balanced_acc_cv3,0.73413
best_balanced_acc_cv4,0.71379
best_balanced_acc_cv5,0.72321
mean_best_balanced_acc,0.70595
train_loss,0.60101
valid_loss,0.37292
