In [1]:
import os, random, json, pickle
import numpy as np
import torch
from tqdm import tqdm
from collections import Counter, defaultdict

In [2]:
data_dir = "/home/yingshac/CYS/WebQnA/WebQnA_data_new/"

# Compute scores for BM25 full-scale retrieval

In [4]:
### Load val/test Retrieval_answers
split='test'
imgRetrievalAns = pickle.load(open(os.path.join(data_dir, "CLIP_retrieval_experiments/{}_imgRetrievalAns.pkl".format(split)), "rb"))
txtRetrievalAns = pickle.load(open(os.path.join(data_dir, "CLIP_retrieval_experiments/{}_txtRetrievalAns.pkl".format(split)), "rb"))

In [5]:
QimgBM25_top2_i = torch.load("result_matrix_II.pt".format(split))
QtxtBM25_top2_i = torch.load("result_matrix_TT.pt".format(split))

In [8]:
def compute_retrieval_metrics(pred, gth):

    common = len(set(pred).intersection(gth))
    RE = common / (len(gth)) 
    PR = common / (len(pred)) # No protection against division by zero because it's assumed that CLIP never gives empty output
    F1 = 2*PR*RE / (PR + RE + 1e-5)
    return F1, RE, PR

In [11]:
top2_perf = [compute_retrieval_metrics(set(QimgBM25_top2_i[i].numpy()), set(imgRetrievalAns[i])) for i in range(len(imgRetrievalAns))]
print("BM25 Top2 img queries: F1={:.4f}, RE={:.4f}, PR={:.4f}".format(np.mean([P[0] for P in top2_perf]), np.mean([P[1] for P in top2_perf]), np.mean([P[2] for P in top2_perf]) ))

top2_perf = [compute_retrieval_metrics(set(QtxtBM25_top2_i[i].numpy()), set(txtRetrievalAns[i])) for i in range(len(txtRetrievalAns))]
print("BN25 Top2 txt queries: F1={:.4f}, RE={:.4f}, PR={:.4f}".format(np.mean([P[0] for P in top2_perf]), np.mean([P[1] for P in top2_perf]), np.mean([P[2] for P in top2_perf]) ))


BM25 Top2 img queries: F1=0.2019, RE=0.2574, PR=0.1742
BN25 Top2 txt queries: F1=0.3342, RE=0.3334, PR=0.3356


In [16]:
### Refer to I/T_unknown_modality_bm25.py. 
### The I_corpus was appended after T_corpus. So all img indices are shifted by 544489
imgRetrievalAns_unknownM = {}
for i in imgRetrievalAns:
    imgRetrievalAns_unknownM[i] = [j+544489 for j in imgRetrievalAns[i]]

In [6]:
QimgBM25_unknownM_top2_i = torch.load("result_matrix_Iall.pt".format(split))
QtxtBM25_unknownM_top2_i = torch.load("result_matrix_Tall.pt".format(split))

In [20]:
top2_perf = [compute_retrieval_metrics(set(QimgBM25_unknownM_top2_i[i].numpy()), set(imgRetrievalAns_unknownM[i])) for i in range(len(imgRetrievalAns_unknownM))]
print("BM25 Top2 unknownM img queries: F1={:.4f}, RE={:.4f}, PR={:.4f}".format(np.mean([P[0] for P in top2_perf]), np.mean([P[1] for P in top2_perf]), np.mean([P[2] for P in top2_perf]) ))

top2_perf = [compute_retrieval_metrics(set(QtxtBM25_unknownM_top2_i[i].numpy()), set(txtRetrievalAns[i])) for i in range(len(txtRetrievalAns))]
print("BN25 Top2 unknownM txt queries: F1={:.4f}, RE={:.4f}, PR={:.4f}".format(np.mean([P[0] for P in top2_perf]), np.mean([P[1] for P in top2_perf]), np.mean([P[2] for P in top2_perf]) ))


BM25 Top2 unknownM img queries: F1=0.2043, RE=0.2597, PR=0.1767
BN25 Top2 unknownM txt queries: F1=0.2815, RE=0.2810, PR=0.2825


# BM25 Restricted Retrieval

In [7]:
dataset = json.load(open("/home/yingshac/CYS/WebQnA/WebQnA_data_new/WebQA_0904_concat_newimgid_newguid.json", "r"))
print(Counter([dataset[k]['split'] for k in dataset]))
print(len(set([dataset[k]['Guid'] for k in dataset])))
print(Counter([dataset[k]['Qcate'] for k in dataset]))


Counter({'train': 36766, 'test': 7540, 'val': 4966})
49272
Counter({'text': 24343, 'YesNo': 8255, 'Others': 6470, 'choose': 5201, 'number': 2318, 'color': 2058, 'shape': 627})


In [22]:
from gensim import corpora
from gensim.summarization import bm25

In [26]:
retricted_bm25_scores = {'Qimg': [], 'Qtxt': []}
for g in tqdm(list(dataset.keys())):
    if not dataset[g]['split'] == 'test': continue
    key = 'Qtxt' if dataset[g]['Qcate'] == 'text' else 'Qimg'
    corpus = []
    if key == 'Qtxt':
        corpus.extend([x['fact'].split() for x in dataset[g]['txt_posFacts']])
        ans = list(range(len(corpus)))
    else:
        corpus.extend([x['caption'].split() for x in dataset[g]['img_posFacts']])
        ans = list(range(len(corpus)))
    corpus.extend([x['fact'].split() for x in dataset[g]['txt_negFacts']])
    corpus.extend([x['caption'].split() for x in dataset[g]['img_negFacts']])

    dictionary = corpora.Dictionary(corpus)
    corpus = [dictionary.doc2bow(text) for text in corpus]
    bm25_obj = bm25.BM25(corpus)

    query_doc = dictionary.doc2bow(dataset[g]['Q'].replace('"', '').split())
    scores = bm25_obj.get_scores(query_doc)
    best_docs = sorted(range(len(scores)), key=lambda i: scores[i])[-2:]

    retricted_bm25_scores[key].append(compute_retrieval_metrics(set(best_docs), set(ans)))


100%|██████████| 49272/49272 [00:16<00:00, 3056.96it/s]


In [27]:
print(len(retricted_bm25_scores['Qimg']), len(retricted_bm25_scores['Qtxt']))
print("BM25 Top2 unknownM img queries: F1={:.4f}, RE={:.4f}, PR={:.4f}".format(np.mean([P[0] for P in retricted_bm25_scores['Qimg']]), np.mean([P[1] for P in retricted_bm25_scores['Qimg']]), np.mean([P[2] for P in retricted_bm25_scores['Qimg']]) ))
print("BM25 Top2 unknownM txt queries: F1={:.4f}, RE={:.4f}, PR={:.4f}".format(np.mean([P[0] for P in retricted_bm25_scores['Qtxt']]), np.mean([P[1] for P in retricted_bm25_scores['Qtxt']]), np.mean([P[2] for P in retricted_bm25_scores['Qtxt']]) ))


3464 4076
BM25 Top2 unknownM img queries: F1=0.2561, RE=0.3206, PR=0.2239
BM25 Top2 unknownM txt queries: F1=0.4375, RE=0.4362, PR=0.4398


# Generate submission files to eval.ai to double check retrieval results

In [9]:
### Load uniid2fact
fact2uniid = pickle.load(open(os.path.join(data_dir, "CLIP_retrieval_experiments/fact2uniid.pkl"), "rb"))
uniid2fact = {i:fact for fact, i in fact2uniid.items()}

# Read test_imgguid2qid, test_txtguid2qid
test_imgguid2qid = pickle.load(open(os.path.join(data_dir, "CLIP_retrieval_experiments/test_imgguid2qid.pkl"), "rb"))
test_txtguid2qid = pickle.load(open(os.path.join(data_dir, "CLIP_retrieval_experiments/test_txtguid2qid.pkl"), "rb"))

In [16]:
evalai_submission_BM25_unknownM = {}
for g in dataset:
    if not dataset[g]['split'] == 'test': continue
    if dataset[g]['Qcate'] == 'text':
        retrieved_snippet_ids = []
        retrieved_facts = []
        for x in QtxtBM25_unknownM_top2_i[test_txtguid2qid[g]].tolist():
            if x < 544489: retrieved_facts.append(uniid2fact[x])
        for x in dataset[g]['txt_posFacts']:
            if x['fact'] in retrieved_facts:
                retrieved_snippet_ids.append(x['snippet_id'])
        retrieved_snippet_ids.extend((2-len(retrieved_snippet_ids))*["dummy"])
        evalai_submission_BM25_unknownM[g] = {'sources': retrieved_snippet_ids, 'answer': ""}
    else:
        evalai_submission_BM25_unknownM[g] = \
            {'sources': [ x+30000000-544489 for x in QimgBM25_unknownM_top2_i[test_imgguid2qid[g]].tolist() ], 
            'answer': ""}

json.dump(evalai_submission_BM25_unknownM, open("evalai_submission_BM25_unknownM.json", "w"))
