In [1]:
# Load reverse_query_index and the reverse ranker.
# Find the passages that are not in the training set and randomly sample n_passage (100) of them.
# For each sampled passage, use reverse_query_index to find its n_query (100) nearest queries.
# Load forward_passage_index and the forward ranker.
# For each of those queries, use forward_passage_index to find its top k (= rank) passages; count += 1 if the sampled passage is among them.
# Record, for each passage, the number of queries that rank it in their top k.
# Plot a histogram of that distribution and compute the average.

In [20]:
from annoy import AnnoyIndex
import torch
import sys
import random

sys.path.insert(0, '/home/jianx/search-exposure/forward_ranker/')
import random
import torch
import numpy as np
import matplotlib.pyplot as plt
from train import generate_sparse
from load_data import obj_reader, obj_writer
import network

from annoy import AnnoyIndex
from utils import print_message

EMBED_SIZE = 256   # embedding width of both rankers and of every Annoy index
DEVICE = "cuda:3"  # device the rankers and their inputs are moved to
n_passage = 100    # number of passages sampled per experiment
n_query = 100      # nearest queries retrieved per passage from a query index
rank = 100         # top-k cutoff when retrieving passages for a query
SEED = 42          # seeds `random` so the passage samples are reproducible on Restart & Run All
random.seed(SEED)

REVERSE_INDEX_PATH = "./results/128pretrained_negative_sample_query_index.ann"
REVERSE_RANKER_PATH = "./results/random_negative_pretrained200_10_500_0.001_256_10.model"
FORWARD_INDEX_PATH = "/home/jianx/data/annoy/128_passage_index.ann"
FORWARD_RANKER_PATH = "/home/jianx/data/results/100_1000_1000_0.001_256_10.model"
PASSAGE_DICT_PATH = "/home/jianx/data/passages.dict"
QUERY_TRAIN_DICT_PATH = "/home/jianx/data/queries_train.dict"
TRAIN_RANK_PATH = "/home/jianx/data/train_data/256_20000_100_100_training.csv"  # lines of pid,qid,rank
REVERSE_MAP_PATH = "./results/128pretrained_negative_sample_qid_map.dict"  # annoy item id -> qid
FORWARD_MAP_PATH = "/home/jianx/data/annoy/128_pid_map.dict"  # annoy item id -> pid
FORWARD_QUERY_INDEX = "/home/jianx/data/annoy/128_query_index.ann"
FORWARD_QUERY_MAP = "/home/jianx/data/annoy/128_qid_map.dict"  # annoy item id -> qid

In [2]:
def load_train(path):
    """Read the training rank CSV into a nested dict ``{pid: {qid: rank}}``.

    Each non-blank line is expected to be ``pid,qid,rank`` with integer fields;
    any further columns are ignored. Blank lines (e.g. a trailing empty line at
    the end of the file) are skipped instead of failing in ``int()``.

    :param path: path to the CSV file (no header row is expected).
    :return: dict mapping each pid to a dict of ``qid -> rank``.
    """
    my_dict = {}
    with open(path) as file:
        for line in file:
            if not line.strip():
                continue
            tokens = line.split(",")
            pid = int(tokens[0])
            qid = int(tokens[1])
            # int() tolerates the trailing newline / surrounding whitespace.
            qid_rank = int(tokens[2])
            my_dict.setdefault(pid, {})[qid] = qid_rank
    return my_dict
def load():
    """Load the three lookup tables used by the experiment.

    :return: ``(train_rank_dict, query_dict, passage_dict)``: the
        ``{pid: {qid: rank}}`` mapping parsed from TRAIN_RANK_PATH by
        ``load_train``, then the objects read by ``obj_reader`` from
        QUERY_TRAIN_DICT_PATH and PASSAGE_DICT_PATH.
    """
    queries = obj_reader(QUERY_TRAIN_DICT_PATH)
    passages = obj_reader(PASSAGE_DICT_PATH)
    return load_train(TRAIN_RANK_PATH), queries, passages

In [3]:
import torch.nn as nn

NUM_HIDDEN_NODES = 64
NUM_HIDDEN_LAYERS = 3
DROPOUT_RATE = 0.1
FEAT_COUNT = 100000


class DSSM(torch.nn.Module):
    """Feed-forward encoder from a FEAT_COUNT-dim input to an ``embed_size`` vector.

    Architecture: NUM_HIDDEN_LAYERS repetitions of
    Linear -> ReLU -> LayerNorm -> Dropout, then a final Linear projection.
    All layers live in ``self.model`` (an ``nn.Sequential``), so state-dict keys
    are ``model.<index>.weight`` / ``model.<index>.bias``. ``reverse_ranker`` is
    built from this class and loaded from a checkpoint, so the layer order here
    must not change.
    """

    def __init__(self, embed_size):
        super().__init__()
        blocks = []
        width = FEAT_COUNT
        for _ in range(NUM_HIDDEN_LAYERS):
            blocks += [
                nn.Linear(width, NUM_HIDDEN_NODES),
                nn.ReLU(),
                nn.LayerNorm(NUM_HIDDEN_NODES),
                nn.Dropout(p=DROPOUT_RATE),
            ]
            width = NUM_HIDDEN_NODES
        # Final projection into the embedding space.
        blocks.append(nn.Linear(width, embed_size))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Run the input through the layer stack and return the embedding."""
        return self.model(x)

    def parameter_count(self):
        """Number of scalar parameters that require gradients."""
        trainable = filter(lambda param: param.requires_grad, self.parameters())
        return sum(param.numel() for param in trainable)

In [4]:
# Reverse side: Annoy query index, its annoy-item-id -> qid map, and the reverse
# ranker (the DSSM class defined in this notebook), which embeds passages for that index.
reverse_query_index = AnnoyIndex(EMBED_SIZE, 'euclidean')
reverse_query_index.load(REVERSE_INDEX_PATH)
reverse_query_map = obj_reader(REVERSE_MAP_PATH)
reverse_ranker = DSSM(embed_size=EMBED_SIZE)
# map_location puts the checkpoint tensors directly on DEVICE instead of on the
# device they were saved from, which may be missing or busy on this machine.
# (torch.load unpickles the file: only use it on trusted checkpoints.)
reverse_ranker.load_state_dict(torch.load(REVERSE_RANKER_PATH, map_location=DEVICE))
reverse_ranker.to(DEVICE)
reverse_ranker.eval()  # inference mode (disables Dropout)
# Forward side: Annoy passage index, its annoy-item-id -> pid map, and the forward
# ranker (network.DSSM), which later embeds both queries and passages.
forward_passage_index = AnnoyIndex(EMBED_SIZE, 'euclidean')
forward_passage_index.load(FORWARD_INDEX_PATH)
forward_passage_map = obj_reader(FORWARD_MAP_PATH)
forward_ranker = network.DSSM(embed_size=EMBED_SIZE)
forward_ranker.load_state_dict(torch.load(FORWARD_RANKER_PATH, map_location=DEVICE))
forward_ranker.to(DEVICE)
forward_ranker.eval()
# Lookup tables: training ranks {pid: {qid: rank}}, queries, and passages.
train_rank_dict, query_dict, passage_dict = load()

In [22]:
# Forward query index: annoy-item-id -> qid map plus the Annoy index of queries,
# searched with forward-ranker passage embeddings in the experiment cell.
forward_query_map = obj_reader(FORWARD_QUERY_MAP)
forward_query_index = AnnoyIndex(EMBED_SIZE, 'euclidean')
forward_query_index.load(FORWARD_QUERY_INDEX)

In [5]:
# Passages that never appear in the training rank file (the keys of train_rank_dict);
# the experiment cell samples n_passage of these at random.
train_passage_list = set(train_rank_dict)
all_passage_list = set(passage_dict)
# Sorted so that, for a given `random` state, the sample does not depend on set
# iteration order (a CPython implementation detail).
test_passage_list = sorted(all_passage_list - train_passage_list)

In [36]:
# Unseen passages (sampled from test_passage_list):
#   * Reverse: embed the passage with reverse_ranker and fetch its n_query nearest
#     queries from reverse_query_index.
#   * Forward: embed the passage with forward_ranker and fetch its n_query nearest
#     queries from forward_query_index.
# For every candidate query, retrieve its top `rank` passages via forward_ranker +
# forward_passage_index and record the passage's 1-based position when it appears.
# Finally report the average number of matching queries per passage for both sides.
random_test_passage = random.sample(test_passage_list, n_passage)
counter = 0
rankings = []  # per passage: one entry per reverse query (1-based rank, or 0 if absent)
total_sum = 0
match_count = 0  # reverse-side matches over all passages
non_zero = []  # every reverse-side rank that was found
forward_benchmark = 0  # forward-side matches over all passages
test_results_forward = {}
test_results_reverse = {}
# Pure inference: no_grad stops every ranker call from building an autograd graph.
with torch.no_grad():
    for passage_no, pid in enumerate(random_test_passage, start=1):
        temp_results_forward = {}
        temp_results_reverse = {}
        temp_reverse = 0
        temp_forward = 0
        print_message("Processing passage No. " + str(passage_no) + "/" + str(n_passage))
        passage_features = generate_sparse(passage_dict[pid]).to(DEVICE)
        embedding = reverse_ranker(passage_features)
        nearest_queries = reverse_query_index.get_nns_by_vector(embedding, n_query)
        embedding_forward = forward_ranker(passage_features)
        nearest_queries_forward = forward_query_index.get_nns_by_vector(embedding_forward, n_query)
        matching_list = []
        print("Reverse:")
        for annoy_qid in nearest_queries:
            qid = reverse_query_map[annoy_qid]
            top_list = forward_passage_index.get_nns_by_vector(
                forward_ranker(generate_sparse(query_dict[qid]).to(DEVICE)), rank)
            for j, annoy_pid in enumerate(top_list):
                if forward_passage_map[annoy_pid] == pid:
                    print("qid {}: rank {}".format(qid, j + 1))
                    temp_results_reverse[qid] = j + 1
                    matching_list.append(j + 1)
                    non_zero.append(j + 1)
                    match_count += 1
                    temp_reverse += 1
                    break
            else:
                # Passage not in this query's top `rank`: exactly one 0 per unmatched query.
                matching_list.append(0)
        print("Forward:")
        for annoy_qid_f in nearest_queries_forward:
            qid_f = forward_query_map[annoy_qid_f]
            top_list_f = forward_passage_index.get_nns_by_vector(
                forward_ranker(generate_sparse(query_dict[qid_f]).to(DEVICE)), rank)
            for t, annoy_pid_f in enumerate(top_list_f):
                if forward_passage_map[annoy_pid_f] == pid:
                    print("qid {}: rank {}".format(qid_f, t + 1))
                    temp_results_forward[qid_f] = t + 1
                    temp_forward += 1
                    break
        total_sum += sum(matching_list)
        rankings.append(matching_list)
        counter += 1
        forward_benchmark += temp_forward
        test_results_forward[pid] = temp_results_forward
        test_results_reverse[pid] = temp_results_reverse
        print("{}: Forward: {} Reverse: {}".format(pid, temp_forward, temp_reverse))
print("Forward: {} Reverse: {}".format(forward_benchmark / n_passage, match_count / n_passage))

[Jul 21, 02:55:49] Processing passage No. 1/100
Reverse:
qid 587393: rank 37
qid 1150018: rank 16
qid 240514: rank 30
Forward:
qid 240514: rank 30
qid 1150018: rank 16
qid 587393: rank 37
6474225: Forward: 3 Reverse: 3
[Jul 21, 02:55:51] Processing passage No. 2/100
Reverse:
qid 450271: rank 34
qid 562627: rank 82
qid 546496: rank 87
qid 415188: rank 5
Forward:
qid 415188: rank 5
qid 450271: rank 34
qid 562627: rank 82
qid 476163: rank 58
qid 546496: rank 87
qid 906051: rank 59
qid 62366: rank 31
qid 1138514: rank 92
4171143: Forward: 8 Reverse: 4
[Jul 21, 02:55:53] Processing passage No. 3/100
Reverse:
qid 1180372: rank 35
qid 339862: rank 51
qid 251711: rank 40
qid 260593: rank 39
Forward:
qid 718552: rank 12
qid 251711: rank 40
qid 1180372: rank 35
qid 242188: rank 16
qid 260593: rank 39
qid 938546: rank 37
qid 339862: rank 51
5896825: Forward: 7 Reverse: 4
[Jul 21, 02:55:55] Processing passage No. 4/100
Reverse:
qid 1036815: rank 14
qid 1182339: rank 24
qid 935912: rank 22
qid 9359

qid 761621: rank 89
qid 172731: rank 98
8749334: Forward: 6 Reverse: 2
[Jul 21, 02:56:35] Processing passage No. 23/100
Reverse:
Forward:
qid 893076: rank 85
qid 794163: rank 88
3813865: Forward: 2 Reverse: 0
[Jul 21, 02:56:37] Processing passage No. 24/100
Reverse:
qid 281303: rank 44
qid 923306: rank 96
qid 772925: rank 91
qid 939083: rank 20
qid 775486: rank 81
Forward:
qid 599866: rank 9
qid 939083: rank 20
qid 772925: rank 91
qid 342689: rank 41
qid 908547: rank 37
qid 775486: rank 81
qid 281303: rank 44
qid 1150564: rank 68
qid 1150565: rank 68
qid 923306: rank 96
7238630: Forward: 10 Reverse: 5
[Jul 21, 02:56:39] Processing passage No. 25/100
Reverse:
Forward:
8670376: Forward: 0 Reverse: 0
[Jul 21, 02:56:41] Processing passage No. 26/100
Reverse:
Forward:
7821922: Forward: 0 Reverse: 0
[Jul 21, 02:56:43] Processing passage No. 27/100
Reverse:
qid 543758: rank 99
qid 999860: rank 34
qid 309333: rank 24
qid 880855: rank 35
qid 209652: rank 48
qid 229505: rank 63
Forward:
qid 8908

qid 928398: rank 65
qid 920669: rank 47
qid 703532: rank 61
qid 121153: rank 32
qid 14479: rank 60
qid 134985: rank 64
qid 1144325: rank 88
1617941: Forward: 14 Reverse: 5
[Jul 21, 02:57:19] Processing passage No. 44/100
Reverse:
qid 35661: rank 99
qid 903224: rank 1
qid 442117: rank 28
qid 920668: rank 96
qid 706434: rank 74
qid 863354: rank 74
Forward:
qid 903224: rank 1
qid 442117: rank 28
qid 920668: rank 96
qid 35661: rank 99
qid 706434: rank 74
qid 863354: rank 74
8749176: Forward: 6 Reverse: 6
[Jul 21, 02:57:21] Processing passage No. 45/100
Reverse:
qid 1007325: rank 12
qid 838045: rank 24
qid 842135: rank 30
qid 921701: rank 71
qid 190450: rank 12
qid 513248: rank 51
qid 448303: rank 51
qid 1147065: rank 91
qid 864417: rank 18
qid 423106: rank 85
qid 1019129: rank 47
qid 834291: rank 74
qid 461815: rank 43
Forward:
qid 838045: rank 24
qid 1019129: rank 47
qid 423106: rank 85
qid 18788: rank 55
qid 864417: rank 18
qid 513248: rank 51
qid 448303: rank 51
qid 1007325: rank 12
qid

qid 339883: rank 25
qid 1180857: rank 46
qid 378474: rank 9
qid 258398: rank 52
Forward:
qid 265565: rank 1
qid 339883: rank 25
qid 378474: rank 9
qid 233909: rank 46
qid 267560: rank 15
qid 258308: rank 36
qid 367982: rank 23
qid 245749: rank 8
qid 250998: rank 41
qid 254342: rank 29
qid 257521: rank 60
qid 273557: rank 60
qid 259740: rank 48
qid 257375: rank 99
qid 252659: rank 98
qid 248808: rank 39
qid 1180857: rank 46
qid 252213: rank 39
qid 254603: rank 98
qid 257254: rank 95
qid 239336: rank 97
qid 265343: rank 46
qid 258398: rank 52
8270132: Forward: 23 Reverse: 12
[Jul 21, 02:58:02] Processing passage No. 65/100
Reverse:
qid 446293: rank 31
qid 879201: rank 59
qid 202914: rank 86
qid 545816: rank 42
qid 742136: rank 87
qid 995735: rank 60
qid 1148358: rank 52
qid 203143: rank 92
qid 1168142: rank 16
qid 1157945: rank 17
qid 203196: rank 3
Forward:
qid 995735: rank 60
qid 203196: rank 3
qid 245946: rank 12
qid 922207: rank 11
qid 1157945: rank 17
qid 545816: rank 42
qid 909483:

1659485: Forward: 3 Reverse: 3
[Jul 21, 02:58:51] Processing passage No. 87/100
Reverse:
Forward:
qid 335526: rank 17
qid 978814: rank 19
qid 333347: rank 15
qid 56674: rank 6
qid 262366: rank 2
qid 13492: rank 19
qid 772512: rank 7
qid 1155628: rank 64
qid 364428: rank 78
qid 569279: rank 74
qid 603593: rank 68
qid 954801: rank 96
qid 602067: rank 33
qid 228806: rank 87
qid 13296: rank 71
qid 1030901: rank 96
qid 262678: rank 49
qid 48124: rank 72
5788756: Forward: 18 Reverse: 0
[Jul 21, 02:58:53] Processing passage No. 88/100
Reverse:
qid 691897: rank 49
qid 184692: rank 27
qid 960503: rank 57
qid 570034: rank 3
qid 1028965: rank 30
qid 601087: rank 10
qid 745129: rank 40
qid 955638: rank 78
qid 1062047: rank 81
qid 578835: rank 8
qid 1172486: rank 63
qid 142141: rank 84
Forward:
qid 570034: rank 3
qid 184692: rank 27
qid 601087: rank 10
qid 52050: rank 33
qid 578835: rank 8
qid 622218: rank 27
qid 438930: rank 26
qid 960503: rank 57
qid 557557: rank 40
qid 139585: rank 25
qid 955638

In [38]:
# Training passages (sampled from train_passage_list):
#   * Forward: the qid -> rank entries stored for the passage in train_rank_dict.
#   * Reverse: embed the passage with reverse_ranker, fetch its n_query nearest queries
#     from reverse_query_index, and record the passage's 1-based position in each
#     query's top `rank` passages (forward_ranker + forward_passage_index).
# Finally report the average number of queries per passage for both sides.
# sorted(): sample from a deterministic ordering instead of set iteration order.
random_test_passage = random.sample(sorted(train_passage_list), n_passage)
counter = 0
rankings = []  # per passage: one entry per reverse query (1-based rank, or 0 if absent)
total_sum = 0
match_count = 0  # reverse-side matches over all passages
non_zero = []  # every reverse-side rank that was found
forward_benchmark = 0  # forward-side (training-file) queries over all passages
train_results_forward = {}
train_results_reverse = {}
# Pure inference: no_grad stops every ranker call from building an autograd graph.
with torch.no_grad():
    for passage_no, pid in enumerate(random_test_passage, start=1):
        temp_results_forward = train_rank_dict[pid]
        temp_results_reverse = {}
        temp_forward = len(temp_results_forward)
        forward_benchmark += temp_forward
        temp_reverse = 0
        print_message("Processing passage No. " + str(passage_no) + "/" + str(n_passage))
        embedding = reverse_ranker(generate_sparse(passage_dict[pid]).to(DEVICE))
        nearest_queries = reverse_query_index.get_nns_by_vector(embedding, n_query)
        matching_list = []

        for annoy_qid in nearest_queries:
            qid = reverse_query_map[annoy_qid]
            top_list = forward_passage_index.get_nns_by_vector(
                forward_ranker(generate_sparse(query_dict[qid]).to(DEVICE)), rank)
            for j, annoy_pid in enumerate(top_list):
                if forward_passage_map[annoy_pid] == pid:
                    temp_results_reverse[qid] = j + 1
                    matching_list.append(j + 1)
                    non_zero.append(j + 1)
                    match_count += 1
                    temp_reverse += 1
                    break
            else:
                # Passage not in this query's top `rank`: exactly one 0 per unmatched query.
                matching_list.append(0)
        total_sum += sum(matching_list)
        rankings.append(matching_list)
        counter += 1
        train_results_forward[pid] = temp_results_forward
        train_results_reverse[pid] = temp_results_reverse
        print("Reverse: {}".format(temp_results_reverse))
        print("Forward: {}".format(temp_results_forward))
        print("{}: Forward: {} Reverse: {}".format(pid, temp_forward, temp_reverse))
print("Forward: {} Reverse: {}".format(forward_benchmark / n_passage, match_count / n_passage))

[Jul 21, 15:49:43] Processing passage No. 1/100
Reverse: {812377: 11}
Forward: {812377: 11, 1184550: 36}
2556448: Forward: 2 Reverse: 1
[Jul 21, 15:49:44] Processing passage No. 2/100
Reverse: {659304: 42}
Forward: {572585: 98, 659304: 42}
6385903: Forward: 2 Reverse: 1
[Jul 21, 15:49:45] Processing passage No. 3/100
Reverse: {103223: 10, 1179476: 11, 231432: 5, 712750: 35, 321384: 54, 311467: 98}
Forward: {174541: 3, 231432: 5, 325053: 19, 1179476: 11, 712750: 35, 721317: 20, 2359: 16, 103223: 10, 1177325: 5, 3977: 48, 103526: 14, 109000: 32, 316373: 46, 659276: 68, 78750: 16, 321384: 54, 320415: 54, 510627: 40, 33049: 72, 893229: 55, 1007372: 91, 230924: 96, 317411: 41, 321770: 86, 302607: 69, 468848: 52, 311467: 98, 323921: 76, 326642: 62, 328383: 60, 32424: 99}
533568: Forward: 31 Reverse: 6
[Jul 21, 15:49:46] Processing passage No. 4/100
Reverse: {835443: 40}
Forward: {835443: 40}
4286617: Forward: 1 Reverse: 1
[Jul 21, 15:49:47] Processing passage No. 5/100
Reverse: {1001137: 24,

Reverse: {1011940: 75, 168244: 35, 698400: 69, 338531: 47}
Forward: {930150: 10, 680522: 16, 1066899: 85, 594860: 45, 168244: 35, 338531: 47, 1011940: 75, 225303: 48, 954852: 84, 145278: 33, 198625: 42, 698400: 69, 501308: 89}
2622709: Forward: 13 Reverse: 4
[Jul 21, 15:50:21] Processing passage No. 35/100
Reverse: {172519: 45, 347670: 30, 60978: 26}
Forward: {60978: 26, 68289: 82, 615465: 60, 592910: 72, 846786: 89, 855011: 87, 347670: 30, 172519: 45, 226812: 71, 809644: 98, 429854: 85}
5352995: Forward: 11 Reverse: 3
[Jul 21, 15:50:22] Processing passage No. 36/100
Reverse: {369551: 77}
Forward: {369551: 77}
2865396: Forward: 1 Reverse: 1
[Jul 21, 15:50:23] Processing passage No. 37/100
Reverse: {685601: 23, 991220: 23, 823193: 2, 1073682: 86, 717375: 9, 453081: 21, 919165: 11, 990864: 83, 537285: 53, 630440: 73, 494713: 13}
Forward: {823193: 2, 768945: 22, 185398: 37, 717375: 9, 157400: 13, 919165: 11, 426270: 18, 494713: 13, 453081: 21, 685601: 23, 991220: 23, 699402: 19, 712472: 1

Reverse: {872676: 76, 493210: 32, 904355: 39, 490894: 40, 1150862: 81}
Forward: {493210: 32, 872676: 76, 490894: 40, 904355: 39, 1150862: 81, 151264: 100}
6345514: Forward: 6 Reverse: 5
[Jul 21, 15:51:03] Processing passage No. 73/100
Reverse: {174214: 4}
Forward: {174214: 4, 358416: 45}
3136055: Forward: 2 Reverse: 1
[Jul 21, 15:51:04] Processing passage No. 74/100
Reverse: {8562: 25, 481208: 71, 745922: 47, 428445: 72}
Forward: {428445: 72, 745922: 47, 697545: 44, 481208: 71, 8562: 25, 1065153: 89}
7405126: Forward: 6 Reverse: 4
[Jul 21, 15:51:05] Processing passage No. 75/100
Reverse: {892430: 40}
Forward: {892430: 40}
7457199: Forward: 1 Reverse: 1
[Jul 21, 15:51:07] Processing passage No. 76/100
Reverse: {222110: 9, 1163352: 49, 759274: 6}
Forward: {222110: 9, 759274: 6, 691837: 55, 389248: 63, 1163352: 49, 552295: 98, 169181: 5}
3027363: Forward: 7 Reverse: 3
[Jul 21, 15:51:08] Processing passage No. 77/100
Reverse: {1021436: 25, 679970: 43, 659903: 24}
Forward: {34265: 42, 74798

In [11]:
# Flatten the per-passage rank lists, keeping only actual matches (rank != 0).
non_zero_rankings = [
    position
    for passage_ranks in rankings
    for position in passage_ranks
    if position != 0
]
print(non_zero_rankings)

[21, 72, 7, 3, 81, 52, 74, 3, 33, 27, 30, 19, 24, 91, 25, 94, 37, 71, 41, 41, 41, 66, 8, 95, 64, 64, 64, 64, 23, 74, 22, 43, 12, 75, 37, 67, 3, 90, 71, 29, 27, 75, 37, 65, 87, 82, 5, 6, 6, 48, 26, 16, 23, 32, 55, 31, 12, 12, 52, 89, 17, 98, 98, 28, 4, 27, 17, 20, 27, 27, 51, 29, 79, 46, 47, 75, 25, 39, 7, 18, 8, 80, 5, 25, 71, 44, 43, 62, 57, 30, 33, 90, 28, 9, 11, 61, 30, 49, 64, 64, 66, 75, 44, 4, 17, 10, 73, 11, 2, 29, 62, 82, 31, 80, 84, 61, 61, 16, 80, 6, 6, 24, 24, 24, 45, 67, 9, 93, 18, 5, 87, 24, 47, 67, 46, 1, 1, 1, 1, 2, 97, 26, 8, 43, 43, 19, 19, 79, 77, 76, 79, 90, 40, 81, 50, 52, 10, 10, 71, 73, 42, 31, 64, 2, 46, 17, 37, 46, 37, 63, 49, 91, 87, 7, 28, 36, 72, 42, 47, 69, 45, 20, 84, 13, 49, 20, 79, 17, 8, 6, 72, 82, 77, 87, 40, 18, 54, 58, 32, 61, 9, 3, 17, 86, 50, 24, 75, 68, 31, 52, 22, 70, 27, 61, 31, 23, 91, 1, 42, 24, 8, 10, 10, 31, 96, 92, 10, 49, 28, 55, 98, 33, 88, 79, 5, 47, 50, 98, 38, 19, 75, 75, 8, 52, 86, 36, 42, 59, 13, 1, 58, 45, 78, 97, 5, 28, 59, 2, 77, 2

In [12]:
# Average number of matches per processed passage: non-zero rank entries divided by
# the number of per-passage lists in `rankings`.
len(non_zero_rankings) / len(rankings)

4.76