In [None]:
import numpy as np
import faiss
import torch
from torch.utils.data import DataLoader

from models import *
from datasets import *

In [2]:
# Held-out split for the DimensionMaskedResNet run, read from a per-sample cache
# directory. (An earlier single-file cache lived at
# data/dataset_cache/model2/cached_test_model2.pt and was loaded with CachedDataset.)
TEST_CACHE_DIR = "data/dataset_cache/dmresnet/test"

test_dataset = LazyCachedDataset(cache_dir=TEST_CACHE_DIR)

# Ordering must stay deterministic (shuffle=False): process_test scores a hit
# for query i only when document i is retrieved, so query/doc rows must line up.
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=64,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn_dif_length,
)

In [5]:
def _embed_pairs(model, loader, device):
    """Embed every (query, doc) batch from ``loader`` with ``model``.

    Returns:
        (query_embs, doc_embs): two float32 numpy arrays whose rows follow the
        order the loader yields samples in.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    query_chunks, doc_chunks = [], []
    with torch.no_grad():
        for query_feats, doc_feats in loader:
            # non_blocking lets host->device copies overlap when the loader pins memory;
            # it is a no-op for CPU devices.
            query_out = model(query_feats.to(device, non_blocking=True))
            doc_out = model(doc_feats.to(device, non_blocking=True))
            query_chunks.append(query_out.cpu().numpy())
            doc_chunks.append(doc_out.cpu().numpy())

    if not query_chunks:
        raise ValueError("loader produced no batches; cannot evaluate an empty test set")

    query_embs = np.concatenate(query_chunks, axis=0).astype(np.float32)
    doc_embs = np.concatenate(doc_chunks, axis=0).astype(np.float32)
    return query_embs, doc_embs


def _top_k_hit_rate(query_embs, doc_embs, top_k):
    """Fraction of queries ``i`` whose paired document ``i`` is among their ``top_k`` neighbours.

    Uses an exact inner-product FAISS index over ``doc_embs``.
    NOTE(review): inner product equals cosine similarity only if the model's
    embeddings are L2-normalised — confirm against the model definition.
    """
    index = faiss.IndexFlatIP(doc_embs.shape[1])
    index.add(doc_embs)
    _, neighbours = index.search(query_embs, top_k)

    # Row i is a hit when label i appears anywhere in that row. FAISS pads rows
    # with -1 when top_k exceeds the number of indexed docs; -1 never matches.
    expected = np.arange(len(query_embs))[:, None]
    hits = (neighbours == expected).any(axis=1)
    return int(hits.sum()) / len(query_embs)


def process_test(model, epoch, top_k, device, loader=None,
                 ckpt_template="outputs/checkpoints/epoch{epoch}.pth"):
    """Evaluate one training checkpoint on paired query->document retrieval.

    Loads ``ckpt_template.format(epoch=epoch)`` into ``model``, embeds every
    (query, doc) pair yielded by ``loader``, and reports the fraction of
    queries whose paired document (same position in the loader) appears in
    their ``top_k`` nearest neighbours.

    Args:
        model: torch module whose state dict matches ``ckpt["model"]``.
        epoch: epoch number substituted into ``ckpt_template``.
        top_k: number of neighbours retrieved per query.
        device: device feature batches are moved to (presumably where ``model`` lives).
        loader: iterable of ``(query_feats, doc_feats)`` batches in a fixed order.
            Defaults to the notebook-level ``test_loader``.
        ckpt_template: checkpoint path format string with an ``{epoch}`` placeholder.

    Returns:
        float: the top-k hit rate. It is also printed, as before.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if loader is None:
        loader = test_loader  # notebook global from the data-loading cell

    ckpt_path = ckpt_template.format(epoch=epoch)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()

    query_embs, doc_embs = _embed_pairs(model, loader, device)
    accuracy = _top_k_hit_rate(query_embs, doc_embs, top_k)

    print(f"Epoch {epoch}: Model accuracy is {accuracy}")
    return accuracy

In [None]:
# Score every saved checkpoint, one epoch at a time.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = DimensionMaskedResNet().to(device)

FIRST_EPOCH, LAST_EPOCH = 1, 100  # inclusive range of checkpoint epochs to evaluate
TOP_K = 1  # a query counts as correct only if its paired doc is the single best match

for epoch in range(FIRST_EPOCH, LAST_EPOCH + 1):
    process_test(model, epoch, TOP_K, device)