# OBS Dictionary Retrieval

Reproduce the Top-N retrieval accuracy reported in the paper, using pre-extracted query and dictionary features with voting-based reranking.

**Reference**: *Decoding Ancient Oracle Bone Script via Generative Dictionary Retrieval*


In [1]:
import os
import numpy as np
import torch
from tqdm.auto import tqdm  # automatically picks the best progress-bar frontend

# Load pre-extracted features.
# map_location='cpu' lets files saved during a GPU run load on CPU-only machines
# (and keeps label tensors on the CPU so np.array() works on them later); the
# evaluation cell moves the feature tensors to the active device itself.
# NOTE(review): recent PyTorch versions warn when `weights_only` is not passed.
# If these files contain only tensors and plain Python containers, consider
# passing weights_only=True so torch.load does not unpickle arbitrary objects.
features_dir = '../data/features'
dict_data = torch.load(os.path.join(features_dir, 'dict_features.pt'), map_location='cpu')
query_data = torch.load(os.path.join(features_dir, 'query_features.pt'), map_location='cpu')

query_features = query_data['features']
query_labels = query_data['labels']
gallery_features = dict_data['features']
gallery_labels = dict_data['labels']

# Features and labels are indexed in parallel later on; catch misaligned files early.
assert len(query_features) == len(query_labels), "query features/labels length mismatch"
assert len(gallery_features) == len(gallery_labels), "gallery features/labels length mismatch"

print(f"Query: {len(query_labels)} images")
print(f"Gallery: {len(gallery_labels)} images")


Query: 6830 images
Gallery: 6830 images


  dict_data = torch.load(os.path.join(features_dir, 'dict_features.pt'))
  query_data = torch.load(os.path.join(features_dir, 'query_features.pt'))


In [2]:
def voting_rerank(qf, ql, gf, gl):
    """Voting-based label reranking for a single query; returns its CMC row.

    Gallery entries are ranked by the inner product ``gf @ qf`` (highest
    first). Entries labelled -1 are discarded as junk. Every label seen at or
    before the first correct match is scored by the mean position of *all* of
    its gallery entries in that ranking. Labels are then re-sorted by this
    score: lower is better, and ties keep first-appearance order.

    NOTE(review): intended to mirror the voting evaluation in ``infer.py``
    (not shown here); confirm against that script.

    Args:
        qf: Query feature tensor. It is reshaped to a column vector, so its
            element count must equal the number of columns of ``gf``.
        ql: Ground-truth label of the query.
        gf: 2-D gallery feature tensor, one row per gallery entry.
        gl: NumPy array of gallery labels aligned with the rows of ``gf``;
            -1 marks a junk entry.

    Returns:
        A list with one int per gallery row. Element ``k`` is 1 if the query's
        label lands at 0-based position ``<= k`` among the re-sorted labels,
        otherwise 0. The list is all zeros if the query label never appears in
        the junk-free gallery.
    """
    query = qf.view(-1, 1)
    score = torch.mm(gf, query).squeeze(1).cpu().numpy()
    index = np.argsort(score)[::-1]
    # CMC row length, captured BEFORE junk removal. Every query must then yield
    # a row of the same length, because the caller sums rows element-wise.
    num_gallery = len(index)
    no_hit = [0] * num_gallery

    good_index = np.flatnonzero(gl == ql)
    junk_index = np.flatnonzero(gl == -1)
    if good_index.size == 0:
        return no_hit

    # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
    index = index[np.isin(index, junk_index, invert=True)]
    rows_good = np.flatnonzero(np.isin(index, good_index))
    if rows_good.size == 0:
        return no_hit

    # Labels in ranked (junk-free) order: position j holds the label of index[j].
    ranked_labels = gl[index]
    label_score = {}
    for label in ranked_labels[: rows_good[0] + 1]:
        if label not in label_score:
            # Mean 0-based rank position of every gallery entry with this label.
            label_score[label] = np.flatnonzero(ranked_labels == label).mean()

    # sorted() is stable, so labels with equal scores keep first-appearance order.
    sorted_labels = sorted(label_score, key=label_score.get)
    new_ql_rank = next(
        (rank for rank, label in enumerate(sorted_labels) if label == ql),
        len(sorted_labels),
    )

    misses = min(new_ql_rank, num_gallery)
    return [0] * misses + [1] * (num_gallery - misses)


In [3]:
# Evaluate every query with voting rerank and average the per-query CMC rows.
# NOTE(review): intended to reproduce infer.py's evaluation; confirm against that script.
TOP_N = (1, 10, 20, 50, 100)  # cut-offs reported below

device = 'cuda' if torch.cuda.is_available() else 'cpu'
query_feature = query_features.to(device)
gallery_feature = gallery_features.to(device)
gallery_label_np = np.array(gallery_labels)

num_gallery = len(gallery_labels)
num_query = len(query_labels)
assert num_query > 0, "no queries loaded"

# Accumulate in a NumPy vector instead of rebuilding a Python list per query.
cmc_hits = np.zeros(num_gallery, dtype=np.int64)
for i in tqdm(range(num_query), desc="Voting rerank"):
    new_rank_tmp = voting_rerank(
        query_feature[i],
        query_labels[i],
        gallery_feature,
        gallery_label_np,
    )
    # Fail loudly on a length mismatch; zip() would silently truncate the curve.
    if len(new_rank_tmp) != num_gallery:
        raise ValueError(
            f"query {i}: CMC row has {len(new_rank_tmp)} entries, expected {num_gallery}"
        )
    cmc_hits += np.asarray(new_rank_tmp, dtype=np.int64)

# Element k: fraction of queries whose true label is within the top (k + 1) labels.
new_rank_all = (cmc_hits / num_query).tolist()

# Output Top-N results, e.g. "Top1:xx.xx Top10:xx.xx ..." (indices clamped to the gallery size).
print('\n' + ' '.join(
    'Top%d:%.2f' % (n, new_rank_all[min(n - 1, num_gallery - 1)] * 100)
    for n in TOP_N
))


Voting rerank:   0%|          | 0/6830 [00:00<?, ?it/s]


Top1:21.20 Top10:54.33 Top20:66.76 Top50:86.15 Top100:96.85
