# Example: Evaluating and Benchmarking a Pretrained Model on Tracking-60k

In [1]:
import os
import sys
sys.path.append('../src')

import torch
import torch.utils.benchmark as benchmark
from pathlib import Path

from transformer import Transformer
from trainer import run_one_epoch, init_metrics
from utils import get_loss
from utils.get_data import get_data_loader, get_dataset

# Cap PyTorch's intra-op CPU thread pool at 10 threads for the work in this notebook.
# NOTE(review): torch.utils.benchmark.Timer applies its own num_threads setting while
# timing (the benchmark output further down reports "1 thread"), so this cap does not
# govern the timed region.
torch.set_num_threads(10)

In [2]:
# Runtime configuration for this notebook.
device = 'cuda:1'  # CUDA device index 1, i.e. needs at least two visible GPUs
dataset_name = 'tracking-60k'
batch_size = 1

# Model hyper-parameters, passed as keyword arguments when the Transformer is built.
model_configs = dict(
    block_size=100,
    n_hashes=3,
    num_regions=150,
    num_heads=8,
    h_dim=24,
    n_layers=4,
    num_w_per_dist=10,
)

# Make `device` the current CUDA device for subsequent CUDA work.
torch.cuda.set_device(device)

In [3]:
# Data lives under ../data/<family>, where <family> is the part of the dataset
# name before the first '-' (e.g. 'tracking' for 'tracking-60k').
dataset_dir = Path('..') / 'data' / dataset_name.partition('-')[0]
dataset = get_dataset(dataset_name, dataset_dir)

In [4]:
# Build the data loaders for the dataset's index splits; the result is indexed by
# split name further down (e.g. loaders["test"]).
loaders = get_data_loader(
    dataset,
    dataset.idx_split,
    batch_size=batch_size,
)

In [5]:
# Instantiate the model: input, coordinate and class dimensions come from the
# dataset; the remaining hyper-parameters come from `model_configs`.
# The model is then moved to the target device.
model = Transformer(
    in_dim=dataset.x_dim,
    coords_dim=dataset.coords_dim,
    num_classes=dataset.num_classes,
    **model_configs,
).to(device)

In [6]:
# Load the pretrained weights. The checkpoint is fed straight into
# load_state_dict, so it is a plain state_dict of tensors. weights_only=True is
# therefore sufficient, and it stops torch.load from unpickling arbitrary Python
# objects from the file.
checkpoint = torch.load(
    "./ckpt/tracking-60k-model.pt",
    map_location="cpu",
    weights_only=True,
)
# strict=True: checkpoint keys must match the model's state_dict keys exactly.
model.load_state_dict(checkpoint, strict=True)
model = model.to(device)

# Loss and metrics used by the evaluation loop.
# NOTE(review): 'tau' is presumably the InfoNCE temperature; confirm in utils.get_loss.
criterion = get_loss('infonce', {'dist_metric': 'l2_rbf', 'tau': 0.05})
metrics = init_metrics(dataset_name)

In [7]:
# Evaluate once over the test split in eval mode, without tracking gradients.
# NOTE(review): the positional arguments follow trainer.run_one_epoch's signature;
# the None slots presumably stand in for training-only objects (e.g. optimizer) — confirm.
model.eval()
with torch.no_grad():
    test_loader = loaders["test"]
    test_res = run_one_epoch(model, None, criterion, test_loader, "test", 0, device, metrics, None)

test_acc = test_res['accuracy@0.9']
print(f"Test accuracy@0.9: {test_acc:.4f}")

[Epoch 0] test , loss: 0.5884, acc: 0.9189, prec: 0.3805, recall: 0.9744: 100%|██████████| 5/5 [00:08<00:00,  1.64s/it]

Test accuracy@0.9: 0.9189





# Benchmark Inference Speed

In [8]:
# Wrap the model with torch.compile. Compilation is lazy: it happens on the first
# forward call, not here, so the first calls in the benchmark below include compile cost
# (Timer.blocked_autorange's block-size estimation runs the statement before measuring).
model = torch.compile(model)

In [9]:
# Allow reduced-precision internal computation (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')

# Benchmark on the first test batch whose data.x has more than 60000 rows.
for data in loaders["test"]:
    if data.x.shape[0] > 60000:
        data = data.to(device)
        break
else:
    # Without this guard, `data` would silently remain the LAST batch (never moved to
    # `device`), or be undefined if the loader is empty. The benchmark would then
    # fail with a confusing error or time the wrong input.
    raise RuntimeError(
        "No test batch with more than 60000 rows in data.x found; nothing to benchmark."
    )

model.eval()
with torch.no_grad():
    # Timer executes `stmt` after `setup`, importing the model and batch from __main__;
    # it handles CUDA synchronization around the measured calls.
    t1 = benchmark.Timer(
        stmt="model(data.x, data.coords, data.batch)", setup="from __main__ import model, data"
    )
    # Keep measuring until at least 5 s of total runtime has been collected.
    m = t1.blocked_autorange(min_run_time=5)
print(m)

<torch.utils.benchmark.utils.common.Measurement object at 0x7ef33f5cf430>
model(data.x, data.coords, data.batch)
setup: from __main__ import model, data
  Median: 29.96 ms
  IQR:    0.07 ms (29.92 to 29.99)
  167 measurements, 1 runs per measurement, 1 thread
