# Example: evaluating a pretrained HEPT Transformer on Tracking-60k

In [None]:
import sys
sys.path.append('../src')

import torch
from pathlib import Path
from trainer import run_one_epoch, init_metrics
from transformer import Transformer

from utils import set_seed, get_loss
from utils.get_data import get_data_loader, get_dataset

In [2]:
# Run configuration.
# Fall back to CPU when no CUDA device is available so the example still runs
# (more slowly) on machines without a GPU; on a CUDA host this is 'cuda:0'.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dataset_name = 'tracking-60k'
model_name = 'trans_hept'
seed = 0
batch_size = 1
main_metric = 'accuracy@0.9'  # name of the headline metric in the results dict
set_seed(seed)

In [3]:
# Model hyper-parameters, forwarded as keyword arguments to Transformer(...).
model_configs = dict(
    block_size=100,
    n_hashes=3,
    num_regions=150,
    num_heads=8,
    h_dim=24,
    n_layers=4,
    num_w_per_dist=10,
)

In [4]:
# Data is read from ../data/<family>, where <family> is the part of the
# dataset name before the first '-' (e.g. 'tracking' for 'tracking-60k').
data_family = dataset_name.partition("-")[0]
dataset_dir = Path('../data/').joinpath(data_family)
dataset = get_dataset(dataset_name, dataset_dir)
# One loader per split, built from the dataset's own index split.
loaders = get_data_loader(dataset, dataset.idx_split, batch_size=batch_size)

In [5]:
# The attention variant is the second '_'-separated field of model_name
# (e.g. 'trans_hept' -> 'hept').
attn_type = model_name.split("_")[1]
model = Transformer(
    attn_type=attn_type,
    in_dim=dataset.x_dim,
    coords_dim=dataset.coords_dim,
    task=dataset.dataset_name,
    **model_configs,
)
model = model.to(device)

In [6]:
# Loss is selected by name ('infonce') via utils.get_loss; 'tau' is presumably
# the softmax temperature and 'dist_metric' the embedding distance — confirm
# against utils.get_loss.
loss_kwargs = {'dist_metric': 'l2_rbf', 'tau': 0.05}
criterion = get_loss('infonce', loss_kwargs)
# Metric accumulators for this dataset, passed to run_one_epoch.
metrics = init_metrics(dataset_name)

In [7]:
# Load trained weights and run a single evaluation pass over the test split.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. This checkpoint is a plain state_dict (it is passed straight
# to load_state_dict), so torch.load(..., weights_only=True) on PyTorch >= 1.13
# would restrict unpickling to tensors — consider enabling it.
checkpoint = torch.load("./ckpt/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint, strict=True)
model = model.to(device)

model.eval()
with torch.no_grad():
    # The two None arguments presumably disable the optimizer / scheduler for
    # an eval-only pass — confirm against trainer.run_one_epoch.
    test_res = run_one_epoch(model, None, criterion, loaders["test"], "test", 0, device, metrics, None)

# Report the metric configured as main_metric rather than re-hard-coding its
# name, so changing the config cell cannot print the wrong (or a missing) key.
print(f"Test {main_metric}: {test_res[main_metric]:.4f}")

[Epoch 0] test , loss: 0.5851, acc: 0.9193, prec: 0.3807, recall: 0.9749: 100%|██████████| 5/5 [00:12<00:00,  2.44s/it]

Test accuracy@0.9: 0.9193



