In [1]:
# Load reverse_query_index and the reverse ranker
# Find the subset of passages that are not in the train set and sample n_passage of them
# For each passage, use reverse_query_index to find the n_query nearest queries
# Load forward_passage_index and the forward ranker
# Use forward_passage_index to find the top-k passages for each query; count += 1 if the passage is in the top k
# Record the number of top-k queries for each passage
# Plot a histogram of the distribution and compute the average

In [2]:
from annoy import AnnoyIndex
import torch
import sys
import random

sys.path.insert(0, '/home/jianx/search-exposure/forward_ranker/')
import random
import torch
import numpy as np
import matplotlib.pyplot as plt
from train import generate_sparse
from load_data import obj_reader, obj_writer
import network

from annoy import AnnoyIndex
from utils import print_message

# Dimensionality passed to AnnoyIndex(...) and as embed_size to the DSSM constructors.
EMBED_SIZE = 256
# Device the rankers and their input tensors are moved to with .to(DEVICE).
DEVICE = "cuda:1"
# Number of held-out passages drawn with random.sample in the evaluation loop.
n_passage = 10000
# Number of nearest queries requested from the reverse query index per passage.
n_query = 100
# Number of nearest passages requested from the forward passage index per query.
# Note: load_train() also uses a local variable called `rank`; that one is unrelated.
rank = 100

# Annoy index file loaded into reverse_query_index.
# NOTE(review): the file name contains "128" while EMBED_SIZE is 256 — confirm they agree.
REVERSE_INDEX_PATH = "./results/128load_forward_query_index.ann"
# state_dict loaded into the reverse ranker (the DSSM class defined in this notebook).
REVERSE_RANKER_PATH = "./results/reverse_load_forward200_50_500_0.001_256_10.model"
# Annoy index file loaded into forward_passage_index.
FORWARD_INDEX_PATH = "/home/jianx/data/annoy/128_passage_index.ann"
# state_dict loaded into the forward ranker (network.DSSM).
FORWARD_RANKER_PATH = "/home/jianx/data/results/100_1000_1000_0.001_256_10.model"
# Object read with obj_reader; indexed by pid in the evaluation loop.
PASSAGE_DICT_PATH = "/home/jianx/data/passages.dict"
# Object read with obj_reader; indexed by qid in the evaluation loop.
QUERY_TRAIN_DICT_PATH = "/home/jianx/data/queries_train.dict"
# CSV parsed by load_train(): three comma-separated integers per line (pid, qid, rank).
TRAIN_RANK_PATH = "/home/jianx/data/train_data/256_20000_100_100_training.csv"
# Object read with obj_reader; indexed by Annoy item id to obtain a qid.
REVERSE_MAP_PATH = "./results/128load_forward_qid_map.dict"
# Object read with obj_reader; indexed by Annoy item id to obtain a pid.
FORWARD_MAP_PATH = "/home/jianx/data/annoy/128_pid_map.dict"

In [3]:
def load_train(path):
    """Read the training-rank CSV into a nested dictionary.

    Every non-blank line must hold at least three comma-separated integers,
    read as ``pid,qid,rank``. Blank or whitespace-only lines (for example a
    stray empty line in the middle or at the end of the file) are skipped
    instead of aborting the whole load with a ValueError.

    Args:
        path: Path of the CSV file to read.

    Returns:
        dict mapping ``pid -> {qid: rank}``. If the same (pid, qid) pair
        appears more than once, the last occurrence wins.

    Raises:
        ValueError: a non-blank line contains a field that is not an integer.
        IndexError: a non-blank line has fewer than three fields.
    """
    my_dict = {}
    with open(path) as file:
        # Iterating the file object streams it line by line.
        for line in file:
            if not line.strip():
                continue
            tokens = line.split(",")
            pid = int(tokens[0])
            qid = int(tokens[1])
            # int() ignores surrounding whitespace, including the newline.
            rank = int(tokens[2])
            my_dict.setdefault(pid, {})[qid] = rank
    return my_dict
def load():
    """Load the three lookup tables used by the analysis.

    Returns:
        A ``(train_rank_dict, query_dict, passage_dict)`` tuple, read from
        TRAIN_RANK_PATH (via load_train), QUERY_TRAIN_DICT_PATH and
        PASSAGE_DICT_PATH (both via obj_reader) respectively.
    """
    queries = obj_reader(QUERY_TRAIN_DICT_PATH)
    passages = obj_reader(PASSAGE_DICT_PATH)
    return load_train(TRAIN_RANK_PATH), queries, passages

In [4]:
import torch.nn as nn

# Architecture hyper-parameters read by the DSSM class below.
NUM_HIDDEN_NODES = 512   # output width of every hidden Linear layer
NUM_HIDDEN_LAYERS = 2    # number of Linear/ReLU/LayerNorm/Dropout groups
DROPOUT_RATE = 0.2       # probability passed to nn.Dropout
FEAT_COUNT = 256         # input width of the first Linear layer


class DSSM(torch.nn.Module):
    """Feed-forward encoder mapping FEAT_COUNT features to ``embed_size`` outputs.

    The network is a single ``nn.Sequential`` stored as ``self.model``:
    NUM_HIDDEN_LAYERS groups of Linear -> ReLU -> LayerNorm -> Dropout,
    followed by one final Linear projection to ``embed_size``.
    """

    def __init__(self, embed_size):
        """Build the layer stack.

        Args:
            embed_size: Output width of the final Linear layer.
        """
        super().__init__()

        # Successive layer widths: input features, then one entry per hidden group.
        widths = [FEAT_COUNT] + [NUM_HIDDEN_NODES] * NUM_HIDDEN_LAYERS
        stack = []
        for in_dim, out_dim in zip(widths[:-1], widths[1:]):
            stack += [
                nn.Linear(in_dim, out_dim),
                nn.ReLU(),
                nn.LayerNorm(out_dim),
                nn.Dropout(p=DROPOUT_RATE),
            ]
        stack.append(nn.Linear(widths[-1], embed_size))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        encoded = self.model(x)
        return encoded

    def parameter_count(self):
        """Return the total number of elements across trainable parameters."""
        trainable_sizes = [p.numel() for p in self.parameters() if p.requires_grad]
        return sum(trainable_sizes)

In [5]:
# Load reverse query index (searched with a passage-derived embedding to find nearby queries).
# NOTE(review): the dimension (EMBED_SIZE) and metric ('euclidean') presumably have to match
# the settings the .ann file was built with — confirm against the script that built it.
reverse_query_index = AnnoyIndex(EMBED_SIZE, 'euclidean')
reverse_query_index.load(REVERSE_INDEX_PATH)
# Load reverse query index mapping dict (indexed by Annoy item id to obtain a qid)
reverse_query_map = obj_reader(REVERSE_MAP_PATH)
# Load reverse ranker model; this uses the DSSM class defined in this notebook,
# not network.DSSM, so the checkpoint must fit that layer layout.
reverse_ranker = DSSM(embed_size=EMBED_SIZE)
reverse_ranker.load_state_dict(torch.load(REVERSE_RANKER_PATH))
reverse_ranker.to(DEVICE)
# eval() switches the module to evaluation mode (Dropout becomes a no-op).
reverse_ranker.eval()
# Load forward passage index (searched with a query embedding to find nearby passages)
forward_passage_index = AnnoyIndex(EMBED_SIZE, 'euclidean')
forward_passage_index.load(FORWARD_INDEX_PATH)
# Load forward passage index mapping dict (indexed by Annoy item id to obtain a pid)
forward_passage_map = obj_reader(FORWARD_MAP_PATH)
# Load forward ranker model; this one is built from the project's network module.
forward_ranker = network.DSSM(embed_size=EMBED_SIZE)
forward_ranker.load_state_dict(torch.load(FORWARD_RANKER_PATH))
forward_ranker.to(DEVICE)
forward_ranker.eval()
# Load train_rank, query, passage dict
train_rank_dict, query_dict, passage_dict = load()

In [6]:
# Build the held-out pool: every passage id that has no entry in the training ranks.
# (Despite their names, the first two variables are sets; only the last one is a list.)
train_passage_list = set(train_rank_dict.keys())
all_passage_list = set(passage_dict.keys())
test_passage_list = list(all_passage_list - train_passage_list)

In [22]:
# For each sampled held-out passage:
#   1. embed it with the forward ranker followed by the reverse ranker and use
#      reverse_query_index to fetch its n_query nearest queries;
#   2. for each of those queries, embed the query with the forward ranker and
#      use forward_passage_index to fetch its `rank` nearest passages;
#   3. record the 1-based position of the original passage in that list, or 0
#      when it is not retrieved.
# `rankings` ends up with one list per processed passage, holding exactly one
# entry per nearest query.
random_test_passage = random.sample(test_passage_list, n_passage)
# random_test_passage = random.sample(list(train_passage_list), n_passage)
counter = 0          # passages fully processed so far
rankings = []        # per-passage lists of positions (0 = not retrieved)
total_sum = 0        # sum of all recorded positions
match_count = 0      # total number of (passage, query) matches
non_zero = []        # flat list of every matched position
forward_count = []   # number of training queries for sampled passages found in the train set
# Inference only: no_grad avoids building an autograd graph for every embedding.
with torch.no_grad():
    for passage_no, pid in enumerate(random_test_passage, start=1):
        print_message("Processing passage No. " + str(passage_no) + "/" + str(n_passage))
        embedding = reverse_ranker(forward_ranker(generate_sparse(passage_dict[pid]).to(DEVICE)).detach()).detach()
        nearest_queries = reverse_query_index.get_nns_by_vector(embedding, n_query)
        matching_list = []
        temp_count = 0
        # train_passage_list is a set, so test membership directly instead of
        # rebuilding a list from it for every passage.
        if pid in train_passage_list:
            forward_count.append(len(train_rank_dict[pid]))

        for annoy_qid in nearest_queries:
            qid = reverse_query_map[annoy_qid]
            query_embedding = forward_ranker(generate_sparse(query_dict[qid]).to(DEVICE)).detach()
            top_list = forward_passage_index.get_nns_by_vector(query_embedding, rank)

            for j, annoy_pid in enumerate(top_list):
                if forward_passage_map[annoy_pid] == pid:
                    matching_list.append(j + 1)
                    non_zero.append(j + 1)
                    match_count += 1
                    temp_count += 1
                    break
            else:
                # Runs only when the loop finished without `break`, i.e. the
                # passage was not retrieved for this query: record one 0.
                # (Previously a 0 was appended for every non-matching candidate.)
                matching_list.append(0)
        print(temp_count)
        total_sum += sum(matching_list)
        rankings.append(matching_list)
        counter += 1
# Average number of matching queries per processed passage.
print(match_count / max(counter, 1))

[Jul 16, 18:09:46] Processing passage No. 1/10000
2
[Jul 16, 18:09:47] Processing passage No. 2/10000
1
[Jul 16, 18:09:48] Processing passage No. 3/10000
0
[Jul 16, 18:09:49] Processing passage No. 4/10000
2
[Jul 16, 18:09:50] Processing passage No. 5/10000
62
[Jul 16, 18:09:51] Processing passage No. 6/10000
2
[Jul 16, 18:09:52] Processing passage No. 7/10000
2
[Jul 16, 18:09:53] Processing passage No. 8/10000
0
[Jul 16, 18:09:54] Processing passage No. 9/10000
2
[Jul 16, 18:09:55] Processing passage No. 10/10000
9
[Jul 16, 18:09:56] Processing passage No. 11/10000
0
[Jul 16, 18:09:57] Processing passage No. 12/10000
6
[Jul 16, 18:09:58] Processing passage No. 13/10000
1
[Jul 16, 18:09:59] Processing passage No. 14/10000
0
[Jul 16, 18:10:00] Processing passage No. 15/10000
8
[Jul 16, 18:10:01] Processing passage No. 16/10000
0
[Jul 16, 18:10:03] Processing passage No. 17/10000
19
[Jul 16, 18:10:04] Processing passage No. 18/10000
1
[Jul 16, 18:10:05] Processing passage No. 19/10000
15

KeyboardInterrupt: 

In [23]:
# Flatten the per-passage lists in `rankings`, keeping only the non-zero entries.
non_zero_rankings = [
    position
    for per_passage in rankings
    for position in per_passage
    if position != 0
]
print(non_zero_rankings)

[10, 85, 5, 79, 73, 2, 2, 19, 5, 4, 9, 9, 15, 18, 2, 14, 14, 33, 43, 34, 34, 21, 4, 31, 21, 21, 11, 14, 11, 24, 20, 51, 3, 20, 40, 56, 22, 42, 19, 73, 54, 58, 10, 41, 56, 56, 68, 43, 81, 63, 85, 25, 56, 62, 44, 25, 27, 82, 48, 32, 81, 81, 70, 70, 70, 83, 57, 4, 21, 41, 8, 72, 29, 20, 96, 96, 56, 80, 54, 34, 16, 21, 43, 11, 59, 9, 39, 44, 34, 48, 80, 15, 16, 9, 28, 26, 75, 1, 1, 78, 1, 16, 16, 37, 66, 6, 6, 31, 58, 27, 55, 94, 71, 70, 81, 30, 74, 65, 24, 19, 6, 34, 61, 61, 21, 43, 64, 77, 51, 69, 64, 69, 96, 3, 3, 72, 73, 96, 71, 60, 20, 87, 44, 61, 11, 92, 8, 96, 71, 47, 21, 68, 34, 18, 77, 97, 40, 79, 87, 78, 76, 99, 88, 51, 76, 63, 79, 48, 9, 91, 74, 33, 21, 68, 39, 85, 21, 54, 17, 71, 29, 86, 85, 16, 26, 100, 94, 32, 38, 82, 73, 71, 55, 8, 38, 38, 22, 14, 24, 16, 24, 52, 75, 24, 83, 100, 89, 45, 80, 40, 3, 6]


In [24]:
# Mean number of non-zero rank entries per processed passage
# (len(rankings) counts only the passages that finished before any interruption).
len(non_zero_rankings) / len(rankings)

7.066666666666666