In [1]:
import sys

sys.path.insert(0, '/home/jianx/search-exposure/')
import torch
from annoy import AnnoyIndex
import forward_ranker.load_data as load_data
import forward_ranker.train as train

# Short notebook-level aliases for the forward_ranker helpers.
obj_reader, obj_writer = load_data.obj_reader, load_data.obj_writer
generate_sparse = train.generate_sparse

# Trained DSSM checkpoint (state_dict) to load.
# NOTE(review): the filename appears to encode the training configuration;
# the meaning of each field is not documented here — confirm.
# Earlier checkpoints, kept for reference:
# MODEL_PATH = "./results/reverse200_500_500_0.001_256_10.model"
# MODEL_PATH = "./results/reverse_load_forward200_50_500_0.001_256_10.model"
MODEL_PATH = "./results/200000_samples_0.8_neg200_10_500_0.001_256_10.model"
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Output dimension of the DSSM encoder; also the dimension of the Annoy index.
EMBED_SIZE = 256
# Passed to AnnoyIndex.build() as its number-of-trees argument.
TREE_SIZE = 128
TREE_SIZE = 128


In [2]:
import torch.nn as nn

NUM_HIDDEN_NODES = 64
NUM_HIDDEN_LAYERS = 3
DROPOUT_RATE = 0.1
FEAT_COUNT = 100000  # input width: size of the last dimension fed to the first Linear layer


# Define the network
class DSSM(torch.nn.Module):
    """Feed-forward encoder mapping a ``feat_count``-wide input to an embedding.

    Architecture: ``num_hidden_layers`` repetitions of
    Linear -> ReLU -> LayerNorm -> Dropout, followed by a final Linear
    projection to ``embed_size``. All layers live in ``self.model`` (an
    ``nn.Sequential``), so state-dict keys are ``model.<idx>.weight`` /
    ``model.<idx>.bias``; keep this layout unchanged or previously saved
    checkpoints will no longer load.

    Args:
        embed_size: output embedding dimension.
        feat_count: input dimension (default: module-level FEAT_COUNT).
        num_hidden_nodes: width of each hidden layer (default: NUM_HIDDEN_NODES).
        num_hidden_layers: number of hidden blocks (default: NUM_HIDDEN_LAYERS).
        dropout_rate: dropout probability in each hidden block (default: DROPOUT_RATE).
    """

    def __init__(self, embed_size, feat_count=None, num_hidden_nodes=None,
                 num_hidden_layers=None, dropout_rate=None):
        super().__init__()
        # None -> read the module-level constants at construction time,
        # exactly as the original implementation did.
        feat_count = FEAT_COUNT if feat_count is None else feat_count
        num_hidden_nodes = NUM_HIDDEN_NODES if num_hidden_nodes is None else num_hidden_nodes
        num_hidden_layers = NUM_HIDDEN_LAYERS if num_hidden_layers is None else num_hidden_layers
        dropout_rate = DROPOUT_RATE if dropout_rate is None else dropout_rate

        layers = []
        last_dim = feat_count
        for _ in range(num_hidden_layers):
            layers.extend([
                nn.Linear(last_dim, num_hidden_nodes),
                nn.ReLU(),
                nn.LayerNorm(num_hidden_nodes),
                nn.Dropout(p=dropout_rate),
            ])
            last_dim = num_hidden_nodes
        layers.append(nn.Linear(last_dim, embed_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding of ``x``; its last dimension must equal the input width."""
        return self.model(x)

    def parameter_count(self):
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


In [3]:
# Restore the trained encoder and switch to inference mode (disables dropout).
NET = DSSM(embed_size=EMBED_SIZE)
# map_location="cpu": the checkpoint may contain tensors saved on a GPU; loading
# them onto CPU first works on any machine, and .to(DEVICE) then moves the weights.
NET.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
NET.to(DEVICE)
NET.eval()

QUERY_TRAIN_DICT_PATH = "/home/jianx/data/queries_train.dict"
# NOTE(review): obj_reader comes from forward_ranker.load_data; if it unpickles
# the file, only load trusted data.
QUERY_DICT = obj_reader(QUERY_TRAIN_DICT_PATH)

def generate_annoy_index(net, device, embed_size, dictionary, metric='euclidean', log_every=10000):
    """Embed every non-empty entry of ``dictionary`` with ``net`` and add it to an Annoy index.

    Callers should put ``net`` in eval mode first (e.g. ``net.eval()``) if it
    contains dropout, since this function does not change the model's mode.

    Args:
        net: model applied to ``generate_sparse(value)``. NOTE(review): assumes
            the output is a 1-D tensor of length ``embed_size`` so ``.tolist()``
            yields a flat list for Annoy — confirm against generate_sparse.
        device: device the sparse input is moved to before the forward pass.
        embed_size: dimensionality of the Annoy index.
        dictionary: mapping key -> value; entries whose value is empty are skipped.
        metric: Annoy distance metric (default 'euclidean').
        log_every: print progress after every ``log_every`` indexed items;
            a falsy value disables progress output.

    Returns:
        (index, mapping): the not-yet-built AnnoyIndex, and a dict from the
        contiguous Annoy item id (0, 1, ...) to the original dictionary key.
    """
    mapping = {}
    index = AnnoyIndex(embed_size, metric)
    total = len(dictionary)
    # Only embeddings are needed, so skip building an autograd graph per item.
    with torch.no_grad():
        for key, value in dictionary.items():
            if len(value) == 0:
                continue
            item_id = len(mapping)
            embedding = net(generate_sparse(value).to(device))
            index.add_item(item_id, embedding.tolist())
            mapping[item_id] = key
            indexed = item_id + 1
            # Log only right after an item is added, so skipped (empty) entries
            # can neither print "0/..." nor repeat the same count.
            if log_every and indexed % log_every == 0:
                print("Progress: " + str(indexed) + "/" + str(total) + " " + str(indexed / total))
    return index, mapping


# Embed all non-empty training queries, then build the Annoy index with TREE_SIZE trees.
QID_INDEX, QID_MAP = generate_annoy_index(NET, DEVICE, EMBED_SIZE, QUERY_DICT)
QID_INDEX.build(TREE_SIZE)
# Both output files share one prefix so the index and its item-id -> key map
# cannot drift apart. The "200000_0.8" tag mirrors the "200000_samples_0.8"
# part of MODEL_PATH; update both together when switching checkpoints.
OUTPUT_PREFIX = "./results/" + str(TREE_SIZE) + "_200000_0.8"
QID_INDEX.save(OUTPUT_PREFIX + "_index.ann")
obj_writer(QID_MAP, OUTPUT_PREFIX + "_map.dict")

Progress: 10000/808731 0.012365050925462237
Progress: 20000/808731 0.024730101850924474
Progress: 30000/808731 0.03709515277638671
Progress: 40000/808731 0.04946020370184895
Progress: 50000/808731 0.061825254627311185
Progress: 60000/808731 0.07419030555277342
Progress: 70000/808731 0.08655535647823566
Progress: 80000/808731 0.0989204074036979
Progress: 90000/808731 0.11128545832916013
Progress: 100000/808731 0.12365050925462237
Progress: 110000/808731 0.1360155601800846
Progress: 120000/808731 0.14838061110554684
Progress: 130000/808731 0.16074566203100907
Progress: 140000/808731 0.17311071295647132
Progress: 150000/808731 0.18547576388193354
Progress: 160000/808731 0.1978408148073958
Progress: 170000/808731 0.21020586573285802
Progress: 180000/808731 0.22257091665832027
Progress: 190000/808731 0.2349359675837825
Progress: 200000/808731 0.24730101850924474
Progress: 210000/808731 0.259666069434707
Progress: 220000/808731 0.2720311203601692
Progress: 230000/808731 0.28439617128563144
P