In [1]:
# Run configuration. Built as a plain dict; the next cell wraps it for
# attribute-style access and passes it on to the project helpers.
train_kwargs = dict(
    # --- Encoder and classifier head ---
    enc_arch="resnet50",
    enc_path=None,
    # Compared against "hml" and searched for "one_hot" in the setup cell; for
    # other one-hot variants the third "_"-separated token selects the loss.
    classifier_type="one_hot_conf",
    # --- Reproducibility ---
    seed=42,
    # --- Input pipeline ---
    input_size=224,
    colour_jitter=True,
    cudnn_deterministic=False,
    batch_size=32,
    num_workers=1,
    validation_ratio=0.1,  # a value of 0 (or less) disables the validation loader
    # --- Optimisation ---
    fine_tune=True,
    optimizer_name="adamw",
    warmup_start_lr=1e-6,
    lr=1e-4,
    min_lr=1e-6,
    weight_decay=1e-4,
    scheduler_name="warmup_cosine",
    max_epochs=100,
    warmup_epochs=10,
    # --- Outputs ---
    save_model=True,
    save_curves=True,
    use_benthicnet_normalization=False,
    # --- Hierarchy resources (read only when classifier_type == "hml") ---
    descendent_matrix_path="./cfg/hierarchy/descendent_matrix.npy",
    descendent_matrix=None,  # filled in by the setup cell in HML mode
    hierarchy_dict_path="./cfg/hierarchy/hierarchy_dict.json",
    hierarchy_dict=None,  # filled in by the setup cell in HML mode
    # --- Model construction ---
    custom_trained=True,
    num_classifiers=3,
)

In [2]:
from types import SimpleNamespace

import json
import numpy as np
import pandas as pd
import torch

from utils.cost_weighted_ce import CostWeightedCELossWithLogits, \
    CalcDistance, ConfidenceLossWithLogits
from utils.dataset import FathomNetDataset
from utils.utils import build_model, df_split, get_augs, \
    map_label_to_idx, set_seed, collect_hierarchy, \
    convert_indices_to_label, get_cost_matrix


# Use CUDA when PyTorch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap the config dict so settings can be read as attributes. The isinstance
# guard keeps this cell re-runnable: on a second execution `train_kwargs` is
# already a SimpleNamespace, and `SimpleNamespace(**train_kwargs)` would raise
# TypeError because a namespace is not a mapping.
if isinstance(train_kwargs, dict):
    train_kwargs = SimpleNamespace(**train_kwargs)

# Set seed for reproducibility
set_seed(train_kwargs.seed, cudnn_deterministic=train_kwargs.cudnn_deterministic)

# HML mode switches the label column, the label map, and (further down) the
# loss and output dimension.
is_hml = train_kwargs.classifier_type == "hml"

df = pd.read_csv("./data/train/annotations.csv")
test_df = pd.read_csv("./data/test/annotations.csv")

# Returns the frame together with a label map; the dataset code below reads
# the "label_idx" column.
# NOTE(review): presumably map_label_to_idx is what adds "label_idx" - confirm.
df, label_map = map_label_to_idx(df, "label")

label_col = "label_idx"
if is_hml:
    # Context managers close the JSON files deterministically instead of
    # leaving the handles open until garbage collection.
    with open("./data/train/index_to_taxon.json", "r") as taxon_file:
        label_map = json.load(taxon_file)
    assert train_kwargs.hierarchy_dict_path is not None, (
        "hierarchy_dict_path must be specified for HML classifier."
    )
    with open(train_kwargs.hierarchy_dict_path, "r") as hierarchy_file:
        train_kwargs.hierarchy_dict = json.load(hierarchy_file)

    # Build the HML target column row by row, then convert it with
    # convert_indices_to_label.
    label_col = "label_hml"
    df[label_col] = df.apply(collect_hierarchy, axis=1)
    df[label_col] = df[label_col].apply(convert_indices_to_label)

# Split the annotated frame into training and validation parts.
train_df, val_df = df_split(
    df,
    validation_ratio=train_kwargs.validation_ratio,
    seed=train_kwargs.seed,
)

# One transform pipeline for training; the second one is reused for both
# validation and test data below.
train_augs, val_augs = get_augs(
    colour_jitter=train_kwargs.colour_jitter,
    input_size=train_kwargs.input_size,
    use_benthicnet=train_kwargs.use_benthicnet_normalization,
)

# Settings that are identical for the train, validation and test loaders.
loader_common = dict(
    batch_size=train_kwargs.batch_size,
    num_workers=train_kwargs.num_workers,
    pin_memory=True,
)

# Training data: shuffled, with the final incomplete batch dropped.
train_dataset = FathomNetDataset(
    df=train_df, label_col=label_col, transform=train_augs
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, drop_last=True, **loader_common
)

# Validation data is optional; without it we keep an empty placeholder so the
# summary print below still works, and pass no loader on.
if not (train_kwargs.validation_ratio > 0):
    val_dataset = []
    val_dataloader = None
else:
    val_dataset = FathomNetDataset(
        df=val_df, label_col=label_col, transform=val_augs
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, shuffle=False, **loader_common
    )

# Test data goes through the validation transforms, in file order.
test_dataset = FathomNetDataset(
    df=test_df, label_col=label_col, transform=val_augs, is_test=True
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, shuffle=False, **loader_common
)

# Report the size of each split relative to the full annotated frame.
print("Total samples:", len(df))
print(len(train_dataset), f"training samples, {len(train_dataset)/len(df):.2%} of total")
print(len(val_dataset), f"validation samples, {len(val_dataset)/len(df):.2%} of total")

# Record the number of batches per epoch on the config.
# NOTE(review): presumably read by the LR scheduler inside train() - confirm.
train_kwargs.steps_per_epoch = len(train_dataloader)

# Pick the loss, the evaluation distance metric and the output dimension for
# the configured classifier type. Every branch binds `criterion`,
# `dist_metric` and `output_dim`, so later cells never hit a NameError or
# silently reuse a stale object from an earlier run.
if "one_hot" in train_kwargs.classifier_type:
    # The distance metric always uses the "cce" cost matrix, regardless of
    # which loss is chosen below.
    metric_cost_matrix = get_cost_matrix(mode="cce").to(device)
    dist_metric = CalcDistance(cost_matrix=metric_cost_matrix)
    output_dim = len(label_map)
    if train_kwargs.classifier_type == "one_hot":
        criterion = torch.nn.CrossEntropyLoss()
    else:
        # Third "_"-separated token names the loss variant,
        # e.g. "one_hot_conf" -> "conf".
        mode = train_kwargs.classifier_type.split("_")[2]
        if mode == "conf":
            criterion = ConfidenceLossWithLogits()
        else:
            cost_matrix = get_cost_matrix(mode=mode).to(device)
            criterion = CostWeightedCELossWithLogits(cost_matrix=cost_matrix)
elif is_hml:
    assert train_kwargs.descendent_matrix_path is not None, (
        "descendent_matrix_path must be specified for HML classifier."
    )
    descendent_matrix = torch.from_numpy(
        np.load(train_kwargs.descendent_matrix_path)
    ).to(device)
    output_dim = descendent_matrix.shape[0]
    train_kwargs.descendent_matrix = descendent_matrix
    criterion = torch.nn.BCELoss()
    # No cost-matrix metric is built in HML mode. Bind the name explicitly so
    # the training cell does not raise NameError on a fresh kernel.
    # NOTE(review): confirm train() accepts dist_metric=None for HML runs.
    dist_metric = None
else:
    raise ValueError("Unsupported classifier type.")

# Build the network from the configured encoder and classifier settings, then
# move it to the selected device.
model = build_model(
    encoder_arch=train_kwargs.enc_arch,
    encoder_path=train_kwargs.enc_path,
    classifier_type=train_kwargs.classifier_type,
    num_classifiers=train_kwargs.num_classifiers,
    requires_grad=train_kwargs.fine_tune,
    custom_trained=train_kwargs.custom_trained,
    output_dim=output_dim,
).to(device)

# Show the architecture and the number of trainable parameters (only those
# with requires_grad set are counted).
print("Model summary:")
print(model)
trainable_param_count = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("Model parameters:", trainable_param_count)

Total samples: 23699
21330 training samples, 90.00% of total
2369 validation samples, 10.00% of total
No encoder weights loaded.
Model summary:
FathomNetModel(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm

In [3]:
from utils.utils import train

# Start training with everything prepared in the setup cell. All arguments are
# passed by keyword, so their order here carries no meaning.
train(
    # Model and objective
    model=model,
    criterion=criterion,
    dist_metric=dist_metric,
    # Data
    train_loader=train_dataloader,
    val_loader=val_dataloader,  # None when validation is disabled
    test_loader=test_dataloader,
    label_map=label_map,
    # Runtime settings
    device=device,
    train_kwargs=train_kwargs,
)

inputs: torch.Size([32, 79])
targets: torch.Size([32])
confidence: torch.Size([32])
ce_loss: torch.Size([32])
scaled_loss: torch.Size([32])
loss: torch.Size([])
inputs: torch.Size([32, 79])
targets: torch.Size([32])
confidence: torch.Size([32])
ce_loss: torch.Size([32])
scaled_loss: torch.Size([32])
loss: torch.Size([])
inputs: torch.Size([32, 79])
targets: torch.Size([32])
confidence: torch.Size([32])
ce_loss: torch.Size([32])
scaled_loss: torch.Size([32])
loss: torch.Size([])
inputs: torch.Size([32, 79])
targets: torch.Size([32])
confidence: torch.Size([32])
ce_loss: torch.Size([32])
scaled_loss: torch.Size([32])
loss: torch.Size([])
inputs: torch.Size([32, 79])
targets: torch.Size([32])
confidence: torch.Size([32])
ce_loss: torch.Size([32])
scaled_loss: torch.Size([32])
loss: torch.Size([])
inputs: torch.Size([32, 79])
targets: torch.Size([32])
confidence: torch.Size([32])
ce_loss: torch.Size([32])
scaled_loss: torch.Size([32])
loss: torch.Size([])
inputs: torch.Size([32, 79])
targe

KeyboardInterrupt: 