In [1]:
import sys
sys.path.append('../src')

import torch
from pathlib import Path
from trainer import get_new_idx_split, run_one_epoch
from transformer import Transformer
from torchmetrics import MeanMetric
from torch.optim.lr_scheduler import StepLR

from utils import set_seed, get_loss
from utils.get_data import get_data_loader, get_dataset

In [2]:
# Experiment configuration: everything a reader might tune lives in this cell.

# Preferred GPU index. If this machine does not expose that many GPUs, fall back
# to the default CUDA device, then to CPU, instead of failing later at `.to(device)`.
gpu_index = 3
if torch.cuda.device_count() > gpu_index:
    device = f'cuda:{gpu_index}'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The part before "-" selects the data subdirectory (../data/<part>) in the data cell.
dataset_name = 'tracking-6k'
# The part after "_" is passed to Transformer as `attn_type`.
model_name = 'trans_hept'
seed = 0
batch_size = 1
# Matches one of the "<metric>@<threshold>" keys built in the metrics cell.
# NOTE(review): not referenced by the training loop as written; nothing selects an epoch by it.
main_metric = 'accuracy@0.9'
set_seed(seed)

In [3]:
# Model hyperparameters, forwarded verbatim to Transformer(...) as keyword arguments.
# Note: pe_type is the string 'none', not Python None.
model_configs = dict(
    block_size=100,
    n_hashes=3,
    num_buckets=15,
    pe_type='none',
    num_heads=8,
    h_dim=24,
    n_layers=4,
    num_w_per_dist=10,
)

In [4]:
# Load the dataset and build the per-split data loaders.
# The on-disk directory is named by the prefix of `dataset_name` before the first "-"
# (e.g. 'tracking-6k' -> ../data/tracking).
data_root = Path('../data/')
dataset_family = dataset_name.partition("-")[0]
dataset_dir = data_root / dataset_family
dataset = get_dataset(dataset_name, dataset_dir)

# Overwrite the dataset's split indices with the split produced by get_new_idx_split,
# then build loaders from that split (the training cell indexes "train"/"valid"/"test").
dataset.idx_split = get_new_idx_split(dataset)
loaders = get_data_loader(dataset, dataset.idx_split, batch_size=batch_size)

In [5]:
# Build the transformer. The attention type is the second "_"-separated field of
# model_name (e.g. 'trans_hept' -> 'hept'); input/coordinate sizes and the task come
# from the loaded dataset, and the remaining hyperparameters from model_configs.
attn_type = model_name.split("_")[1]
model = Transformer(
    attn_type=attn_type,
    in_dim=dataset.x_dim,
    coords_dim=dataset.coords_dim,
    task=dataset.dataset_name,
    **model_configs,
)
model = model.to(device)
model.name = model_name

In [6]:
# Optimizer, learning-rate schedule and contrastive loss.
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
# StepLR multiplies the LR by 0.5 every 500 calls to lr_s.step().
# NOTE(review): the training cell calls lr_s.step() once per epoch AND passes lr_s into
# run_one_epoch. If run_one_epoch does not step it, the LR never decays within 200 epochs;
# if it steps per batch, the extra per-epoch step double-counts. Confirm against trainer.py.
lr_s = StepLR(opt, step_size=500, gamma=0.5)
criterion = get_loss('infonce', {'dist_metric': 'l2_rbf', 'tau': 0.05})

# Running averages keyed "<metric>@<pt threshold>" (e.g. "accuracy@0.9"), plus the loss.
# The threshold list itself travels in the same dict under "pt_thres" for run_one_epoch.
pt_thres = [0, 0.5, 0.9]
metric_names = ["accuracy", "precision", "recall"]
metrics = {
    **{
        f"{name}@{pt}": MeanMetric(nan_strategy="error")
        for name in metric_names
        for pt in pt_thres
    },
    "loss": MeanMetric(nan_strategy="error"),
    "pt_thres": pt_thres,
}

In [7]:
def _run_phase(phase, epoch):
    """Call run_one_epoch on loaders[phase] with the shared model/optimizer/metrics state."""
    return run_one_epoch(model, opt, criterion, loaders[phase], phase, epoch, device, metrics, lr_s)


# Train, then evaluate on valid and test, every epoch.
# NOTE(review): main_metric is not consulted and *_res is overwritten each epoch, so no
# best-validation epoch (or its test score) is recorded — TODO if model selection is intended.
for epoch in range(200):
    train_res = _run_phase("train", epoch)
    valid_res = _run_phase("valid", epoch)
    test_res = _run_phase("test", epoch)

    # Per-epoch scheduler step (lr_s is also handed to run_one_epoch — see the optimizer cell).
    lr_s.step()

[Epoch 0] train, loss: 6.2424, acc: 0.0397, prec: 0.0342, recall: 0.0854: 100%|██████████| 400/400 [01:56<00:00,  3.42it/s]
[Epoch 0] valid, loss: 6.1998, acc: 0.3594, prec: 0.2289, recall: 0.5487: 100%|██████████| 50/50 [00:10<00:00,  4.78it/s]
[Epoch 0] test , loss: 6.0572, acc: 0.3707, prec: 0.2374, recall: 0.5667: 100%|██████████| 50/50 [00:10<00:00,  4.70it/s]
[Epoch 1] train, loss: 4.8732, acc: 0.2942, prec: 0.2036, recall: 0.4727: 100%|██████████| 400/400 [01:50<00:00,  3.61it/s]
[Epoch 1] valid, loss: 5.0285, acc: 0.4119, prec: 0.2514, recall: 0.6019: 100%|██████████| 50/50 [00:10<00:00,  4.86it/s]
[Epoch 1] test , loss: 4.8916, acc: 0.4229, prec: 0.2608, recall: 0.6196: 100%|██████████| 50/50 [00:10<00:00,  4.81it/s]
[Epoch 2] train, loss: 4.7832, acc: 0.3180:  43%|████▎     | 173/400 [00:46<01:01,  3.71it/s]


KeyboardInterrupt: 