In [3]:
import sys
sys.path.append('../src')

import torch
from pathlib import Path
from trainer import get_new_idx_split, run_one_epoch
from transformer import Transformer
from torchmetrics import MeanMetric
from torch.optim.lr_scheduler import StepLR
from copy import deepcopy
from datetime import datetime

from utils import set_seed, get_loss
from utils.get_data import get_data_loader, get_dataset

In [4]:
device = 'cuda:4'
dataset_name = 'tracking-6k'
model_name = 'trans_hept'
seed = 0
batch_size = 1
main_metric = 'accuracy@0.9'
set_seed(seed)

In [5]:
model_configs = {'block_size': 100, 'n_hashes': 3, 'num_buckets': 15, 'pe_type': 'none', 'num_heads': 8, 'h_dim': 24, 'n_layers': 4, 'num_w_per_dist': 10}

In [6]:
dataset_dir = Path('../data/') / dataset_name.split("-")[0]
dataset = get_dataset(dataset_name, dataset_dir)
dataset.idx_split = get_new_idx_split(dataset)
loaders = get_data_loader(dataset, dataset.idx_split, batch_size=batch_size)

In [7]:
model = Transformer(
            attn_type=model_name.split("_")[1],
            in_dim=dataset.x_dim,
            coords_dim=dataset.coords_dim,
            task=dataset.dataset_name,
            **model_configs).to(device)
model.name = model_name
model_dir = '../data/tracking/logs/' + (datetime.now().strftime("%m_%d_%Y-%H_%M_%S") + '-' + dataset_name + '-' + model_name + '-seed' + str(seed) + '-' + main_metric)

In [8]:
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_s = StepLR(opt, step_size=500, gamma=0.5)

criterion = get_loss('infonce', {'dist_metric': 'l2_rbf', 'tau': 0.05})
pt_thres = [0, 0.5, 0.9]
metric_names = ["accuracy", "precision", "recall"]
metrics = {f"{name}@{pt}": MeanMetric(nan_strategy="error") for name in metric_names for pt in pt_thres}
metrics["loss"] = MeanMetric(nan_strategy="error")
metrics["pt_thres"] = pt_thres
best_epoch, best_train = 0, {metric: -1 * float("inf") for metric in metrics.keys()}
best_valid, best_test = deepcopy(best_train), deepcopy(best_train)

In [None]:
for epoch in range(2000):
    train_res = run_one_epoch(model, opt, criterion, loaders["train"], "train", epoch, device, metrics, lr_s)
    valid_res = run_one_epoch(model, opt, criterion, loaders["valid"], "valid", epoch, device, metrics, lr_s)
    test_res = run_one_epoch(model, opt, criterion, loaders["test"], "test", epoch, device, metrics, lr_s)
    
    if (valid_res[main_metric]) > (best_valid[main_metric]):
        best_epoch, best_train, best_valid, best_test = epoch, train_res, valid_res, test_res
        torch.save(model.state_dict(), model_dir + "best_model.pt")

    print(
        f"[Epoch {epoch}] Best epoch: {best_epoch}, train: {best_train[main_metric]:.4f}, "
            f"valid: {best_valid[main_metric]:.4f}, test: {best_test[main_metric]:.4f}"
    )
    print("=" * 50), print("=" * 50)
    
    lr_s.step()

In [12]:
checkpoint = torch.load("./best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint, strict=True)
model = model.to(device)

with torch.no_grad():
    model.eval()
    test_res = run_one_epoch(model, opt, criterion, loaders["test"], "test", 0, device, metrics, lr_s)

print(f"Test accuracy@0.9: {test_res['accuracy@0.9']:.4f}")

[Epoch 0] test , loss: 0.6026, acc: 0.9184, prec: 0.3777, recall: 0.9789: 100%|██████████| 50/50 [00:11<00:00,  4.47it/s]

Test accuracy@0.9: 0.9184



