In [1]:
import sys

sys.path.append("../")

In [2]:
from pathlib import Path
from pprint import pprint

import torch 
import yaml
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator, precision_at_k
from omegaconf.dictconfig import DictConfig
from torch.utils.data import DataLoader

from retrieval.engine.evaluate import evaluate, get_tester
from retrieval.models.net import RetrievalNet
from retrieval.getter import Getter

DEVICE = "cuda"
WEIGHTS_PATH = Path("../experiments/ROADMAP/GLDv2_ROADMAP_classification_splits128/weights/rolling.ckpt")
getter = Getter()

In [3]:
def load_cfg(cfg_path):
    with open(cfg_path, "r") as f:
        cfg = yaml.safe_load(f)
    return DictConfig(cfg)

# Model

In [4]:
net = RetrievalNet(
    "vit_deit_distilled",
    embed_dim=384,
    norm_features=False,
    without_fc=True,
    with_autocast=False,
)
weights = torch.load(WEIGHTS_PATH, map_location=DEVICE)["net_state"]
net.load_state_dict(weights)
net.to(DEVICE)
net.eval()

RetrievalNet(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU()
          (drop1): Dropout(p=0.0, inplace=False)
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (d

# Data

In [5]:
transform_cfg = load_cfg("../retrieval/config/transform/gldv2.yaml")
dataset_cfg = load_cfg("../retrieval/config/dataset/gldv2_10k_classification_splits.yaml")
test_transform = getter.get_transform(transform_cfg.test)
test_ds = getter.get_dataset(test_transform, 'test', dataset_cfg)
len(test_ds)

39482

# Tester

In [6]:
tester = get_tester(
    normalize_embeddings=False,
    batch_size=16,
    with_AP=False,
    num_workers=4,
    pca=None,
    exclude_ranks=None,
    k=2047
)

# Metrics Evaluation

In [7]:
pprint(evaluate(
    net,
    train_dataset=None,
    val_dataset=None,
    test_dataset=test_ds,
    epoch=None,
    tester=tester,
    custom_eval=None,
))

defaultdict(<class 'dict'>,
            {'test': {'epoch': 'None',
                      'gap_at_1_level0': 0.7326463460922241,
                      'mean_average_precision_at_r_level0': 0.549273167076305,
                      'mean_reciprocal_rank_level0': 0.8403075933456421,
                      'precision_at_1_level0': 0.803713885295386,
                      'r_precision_level0': 0.581161924911092,
                      'recall_at_1000_level0': 0.9311078786849976,
                      'recall_at_100_level0': 0.9086166024208069,
                      'recall_at_10_level0': 0.8547186255455017,
                      'recall_at_16_level0': 0.8682690858840942,
                      'recall_at_1_level0': 0.7553062438964844,
                      'recall_at_20_level0': 0.8742971420288086,
                      'recall_at_2_level0': 0.7909427285194397,
                      'recall_at_30_level0': 0.88534015417099,
                      'recall_at_32_level0': 0.8870877623558044,
       

# Speed Evaluation

In [9]:
from time import time
from tqdm import tqdm
t = 0
for input in tqdm(test_ds):
    with torch.no_grad():
        start_time = time()
        emb = net(input["image"].to(DEVICE).unsqueeze(0))
        t += time() - start_time
print(f"FPS: {(len(test_ds) / t):.1f}")

100%|██████████| 39482/39482 [08:35<00:00, 76.58it/s]

0.006383046952432113





In [12]:
print(f"FPS: {(len(test_ds) / t):.1f}")

FPS: 156.7
