# An interactive version of DSSM forward ranker

In [1]:
import torch
import torch.nn as nn
from torch import optim
import csv
import os
import pickle
import random
import math
import numpy as np

### Load data

In [2]:
# Every pickled dataset lives under one directory; the full paths are built from it.
_DATA_DIR = "/home/jianx/data/"
POS_NEG_DICT_PATH = _DATA_DIR + "qid_pos_neg.dict"
PASSAGE_DICT_PATH = _DATA_DIR + "passages.dict"
QUERY_TRAIN_DICT_PATH = _DATA_DIR + "queries_train.dict"
QUERY_EVAL_DICT_PATH = _DATA_DIR + "queries_eval.dict"
QUERY_DEV_DICT_PATH = _DATA_DIR + "queries_dev.dict"
TOP_DICT_PATH = _DATA_DIR + "initial_ranking.dict"
RATING_DICT_PATH = _DATA_DIR + "rel_scores.dict"
QUERY_TEST_DICT_PATH = _DATA_DIR + "queries_test.dict"

In [3]:
def obj_reader(path):
    """Unpickle and return the object stored at ``path``.

    Uses ``pickle.load`` on the open file, so the unpickler reads the file
    incrementally. The previous ``pickle.loads(handle.read())`` first
    materialised the entire file as one ``bytes`` object and kept it alive
    alongside the decoded objects, which doubles peak memory for large dumps.

    SECURITY NOTE: unpickling can execute arbitrary code. Only call this on
    files produced by a trusted source (e.g. ``obj_writer``).
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def obj_writer(obj, path):
    """Serialise ``obj`` with pickle into ``path``, replacing any existing file."""
    with open(path, mode='wb') as out_file:
        pickle.dump(obj, out_file)


def load():
    """Read every pickled dataset used by the notebook.

    Returns:
        A 6-tuple ``(pos_neg_dict, query_dict, passage_dict, top_dict,
        rating_dict, query_test_dict)`` loaded, in that order, from the
        ``*_PATH`` constants (``query_dict`` is the training-query dict).
    """
    source_paths = (
        POS_NEG_DICT_PATH,
        QUERY_TRAIN_DICT_PATH,
        PASSAGE_DICT_PATH,
        TOP_DICT_PATH,
        RATING_DICT_PATH,
        QUERY_TEST_DICT_PATH,
    )
    return tuple(obj_reader(source) for source in source_paths)

In [4]:
print("Loading data")
(pos_neg_dict_g, query_dict_g, passage_dict_g,
 top_dict_g, rating_dict_g, query_test_dict_g) = load()
print("Data successfully loaded.")
# Report dataset sizes as a quick sanity check of what was loaded.
print(f"Positive Negative Pair dict size: {len(pos_neg_dict_g)}")
print(f"Num of queries: {len(query_dict_g)}")
print(f"Num of passages: {len(passage_dict_g)}")

Loading data
Data successfully loaded.
Positive Negative Pair dict size: 400782
Num of queries: 808731
Num of passages: 8841823


### Define the network

In [5]:
NUM_HIDDEN_NODES = 64
NUM_HIDDEN_LAYERS = 3
DROPOUT_RATE = 0.1
FEAT_COUNT = 100000


# Define the network
class DSSM(torch.nn.Module):
    """Feed-forward encoder mapping an input feature vector to an embedding.

    Architecture: ``num_hidden_layers`` blocks of
    Linear -> ReLU -> LayerNorm -> Dropout, followed by a final Linear
    projection to ``embed_size``. The output is multiplied by a constant
    scale of 10.

    NOTE(review): the notebook compares embeddings only via cosine
    similarity, which is invariant to a constant positive scale, so the
    factor of 10 does not change those scores. If it was meant as a softmax
    temperature it would have to multiply the similarities instead -- TODO
    confirm intent.

    Args:
        embed_size: dimensionality of the output embedding.
        device: device on which the ``scale`` buffer is created.
        feat_count: input dimensionality (default ``FEAT_COUNT``).
        num_hidden_layers: number of hidden blocks (default ``NUM_HIDDEN_LAYERS``).
        num_hidden_nodes: width of each hidden block (default ``NUM_HIDDEN_NODES``).
        dropout_rate: dropout probability (default ``DROPOUT_RATE``).
    """

    def __init__(self, embed_size, device, feat_count=FEAT_COUNT,
                 num_hidden_layers=NUM_HIDDEN_LAYERS,
                 num_hidden_nodes=NUM_HIDDEN_NODES,
                 dropout_rate=DROPOUT_RATE):
        super(DSSM, self).__init__()

        layers = []
        last_dim = feat_count
        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(last_dim, num_hidden_nodes))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm(num_hidden_nodes))
            layers.append(nn.Dropout(p=dropout_rate))
            last_dim = num_hidden_nodes
        layers.append(nn.Linear(last_dim, embed_size))
        self.model = nn.Sequential(*layers)
        # A buffer (not a plain attribute) so that net.to()/.cuda()/.cpu()/
        # .double() move and cast it together with the weights. persistent=False
        # keeps it out of state_dict(), so checkpoint keys are unchanged.
        self.register_buffer(
            "scale",
            torch.tensor([10], dtype=torch.float, device=device),
            persistent=False,
        )

    def forward(self, x):
        """Encode ``x`` and return the scaled embedding."""
        return self.model(x) * self.scale

    def parameter_count(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

### Define train functions

In [6]:
VOCAB_LEN = 100000

def generate_sparse(idx, vocab_len=VOCAB_LEN):
    index_tensor = torch.LongTensor([idx])
    value_tensor = torch.Tensor([1/len(idx)] * len(idx))
    sparse_tensor = torch.sparse.FloatTensor(index_tensor, value_tensor, torch.Size([vocab_len, ]))
    return sparse_tensor


def mini_batch(batch_size, device, pos_neg_dict, query_dict, passage_dict):
    """Sample a mini-batch of (query, positive passage, negative passage) triples.

    A query id is drawn uniformly from ``pos_neg_dict``, then one pair is drawn
    from ``pos_neg_dict[qid]``; element 0 of the pair is the positive pid and
    element 1 the negative pid. Triples in which any of the three token
    sequences equals ``[]`` are rejected and re-drawn.

    NOTE(review): ``random.sample`` needs a sequence, so the values of
    ``pos_neg_dict`` are presumably lists of pairs (sets are rejected by
    Python >= 3.11) -- confirm against the data.

    Returns:
        ``(queries, pos, neg, labels)``: three stacked sparse tensors (one row
        per triple, built by ``generate_sparse``) moved to ``device``, and a
        list of ``batch_size`` zeros -- the positive is always column 0 of the
        score matrix built in ``train``.
    """
    candidate_qids = list(pos_neg_dict.keys())
    query_vecs, pos_vecs, neg_vecs = [], [], []
    while len(query_vecs) < batch_size:
        (qid,) = random.sample(candidate_qids, 1)
        (pair,) = random.sample(pos_neg_dict[qid], 1)
        pos_pid, neg_pid = pair[0], pair[1]
        q_seq = query_dict[qid]
        pos_seq = passage_dict[pos_pid]
        neg_seq = passage_dict[neg_pid]
        if q_seq == [] or pos_seq == [] or neg_seq == []:
            # Empty sequences cannot be normalised into a sparse vector.
            continue
        query_vecs.append(generate_sparse(q_seq))
        pos_vecs.append(generate_sparse(pos_seq))
        neg_vecs.append(generate_sparse(neg_seq))
    labels = [0] * batch_size
    return (
        torch.stack(query_vecs).to(device),
        torch.stack(pos_vecs).to(device),
        torch.stack(neg_vecs).to(device),
        labels,
    )


def train(net, epoch_size, batch_size, optimizer, device, pos_neg_dict, query_dict,
          passage_dict):
    """Run ``epoch_size`` optimisation steps and return the mean training loss.

    Each step samples a mini-batch of triples and embeds query, positive and
    negative passages with ``net``. It builds a ``(batch, 2)`` matrix of cosine
    similarities (column 0 = positive, column 1 = negative) and minimises
    cross-entropy with target class 0, i.e. it pushes the positive above the
    negative.
    """
    criterion = nn.CrossEntropyLoss()
    net.train()
    running_loss = 0.0
    for _ in range(epoch_size):
        queries, pos, neg, labels = mini_batch(batch_size, device, pos_neg_dict,
                                               query_dict, passage_dict)
        optimizer.zero_grad()
        q_embed = net(queries)
        pos_embed = net(pos)
        neg_embed = net(neg)
        sim_pos = torch.cosine_similarity(q_embed, pos_embed)
        sim_neg = torch.cosine_similarity(q_embed, neg_embed)
        scores = torch.stack((sim_pos, sim_neg), dim=1)
        target = torch.tensor(labels).to(device)
        loss = criterion(scores, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / epoch_size

### Define test functions

In [7]:
def test_loader(net, device, test_batch, top_dict, query_test_dict, passage_dict, rating_dict,
                verbose=True):
    """Score the candidate passages of a random sample of evaluation queries.

    Eligible queries are the keys of ``rating_dict`` whose token sequence in
    ``query_test_dict`` is present and non-empty. ``min(test_batch, #eligible)``
    of them are drawn without replacement. Previously a rejection loop was
    used, which never terminated when fewer than ``test_batch`` eligible
    queries existed. Each candidate pid in ``top_dict[qid]`` is scored by
    cosine similarity between the query and passage embeddings. Passages with
    an empty token sequence get a score of -1.

    Args:
        net: encoder; put into eval mode and run without autograd.
        device: device the sparse inputs are moved to.
        test_batch: number of queries to sample.
        top_dict: qid -> iterable of candidate pids.
        query_test_dict: qid -> query token ids.
        passage_dict: pid -> passage token ids.
        rating_dict: qid -> relevance judgements; only its keys are used here.
        verbose: if True, print the ranking of the last scored query
            (debug output).

    Returns:
        dict qid -> {pid: score}.
    """
    net.eval()
    eligible = []
    for qid in rating_dict:
        q_seq = query_test_dict.get(qid)
        if q_seq is not None and len(q_seq) > 0:
            eligible.append(qid)
    qids = random.sample(eligible, min(test_batch, len(eligible)))

    result_dict = {}
    # Inference only: avoid building autograd graphs for every forward pass.
    with torch.no_grad():
        for qid in qids:
            q_embed = net(generate_sparse(query_test_dict[qid]).to(device))
            q_results = {}
            for pid in top_dict[qid]:
                p_seq = passage_dict[pid]
                if not p_seq:
                    # Empty passages cannot be encoded; give them the minimum cosine.
                    score = -1
                else:
                    p_embed = net(generate_sparse(p_seq).to(device))
                    score = torch.cosine_similarity(q_embed.unsqueeze(0), p_embed.unsqueeze(0)).item()
                q_results[pid] = score
            result_dict[qid] = q_results

    if verbose and qids:
        # Debug output: ranking of the last scored query, ties broken at random.
        print(sorted(result_dict[qids[-1]].items(),
                     key=lambda x: (x[1], [-1, 1][random.randrange(2)]), reverse=True))
    return result_dict


def get_ndcg_precision_rr(true_dict, test_dict, rank, relevance_threshold=0):
    """Compute NDCG@rank, precision@rank and reciprocal rank for one query.

    Args:
        true_dict: pid -> graded relevance judgement (numeric).
        test_dict: pid -> model score; higher scores are ranked first, ties
            broken at random.
        rank: cutoff k for NDCG@k and P@k.
        relevance_threshold: a passage counts as a hit for P@k and RR only if
            its grade is strictly greater than this. With the default of 0,
            passages judged non-relevant (grade 0) are not counted as hits.
            NDCG always uses the raw grades.

    Returns:
        ``(ndcg, precision, rr)``.
        - ``ndcg`` is 0 when the ideal DCG is 0.
        - ``precision`` divides by ``rank`` even when fewer than ``rank``
          passages were scored.
        - ``rr`` is NaN when no relevant passage appears anywhere in the
          ranking; ``np.nanmean`` in ``test`` then skips it.
    """
    # One ranking shared by all three metrics. NDCG used to re-sort without
    # the random tie-break, so it could be computed on a different order than
    # P@k and RR.
    ranked = sorted(test_dict.items(),
                    key=lambda x: (x[1], [-1, 1][random.randrange(2)]), reverse=True)
    ranking = [pid for pid, _ in ranked]
    cutoff = min(rank, len(ranking))

    def is_hit(pid):
        return true_dict.get(pid, 0) > relevance_threshold

    rr = float("NaN")
    for position, pid in enumerate(ranking, start=1):
        if is_hit(pid):
            rr = 1 / position
            break

    num_positive = sum(1 for pid in ranking[:cutoff] if is_hit(pid))

    cumulative_gain = sum(true_dict.get(pid, 0) / math.log2(2 + i)
                          for i, pid in enumerate(ranking[:cutoff]))
    ideal_grades = sorted(true_dict.values(), reverse=True)[:cutoff]
    ideal_gain = sum(grade / math.log2(2 + i) for i, grade in enumerate(ideal_grades))

    ndcg = 0
    if ideal_gain != 0:
        ndcg = cumulative_gain / ideal_gain
    return ndcg, num_positive / rank, rr


def test(net, device, test_batch, top_dict, query_test_dict, passage_dict, rating_dict, rank):
    """Evaluate ``net`` on a random sample of queries.

    Scores the sampled queries with ``test_loader`` and computes per-query
    metrics with ``get_ndcg_precision_rr``.

    Returns:
        ``(mean NDCG@rank, mean P@rank, mean reciprocal rank)``, each averaged
        with ``np.nanmean`` so NaN entries are ignored.
    """
    scored = test_loader(net, device, test_batch, top_dict, query_test_dict,
                         passage_dict, rating_dict)
    per_query = [
        get_ndcg_precision_rr(rating_dict[qid], pid_scores, rank)
        for qid, pid_scores in scored.items()
    ]
    ndcgs = [metrics[0] for metrics in per_query]
    precisions = [metrics[1] for metrics in per_query]
    reciprocal_ranks = [metrics[2] for metrics in per_query]
    return np.nanmean(ndcgs), np.nanmean(precisions), np.nanmean(reciprocal_ranks)

## Train and test model

### Initialize hyperparameters

In [8]:
# Training hyperparameters.
NUM_EPOCHS = 2        # iterations of the outer training loop
EPOCH_SIZE = 10       # mini-batches (optimizer steps) per epoch
BATCH_SIZE = 100      # (query, positive, negative) triples per mini-batch
LEARNING_RATE = 0.01  # Adam learning rate
EMBED_SIZE = 256      # output dimensionality of the DSSM encoder

In [10]:
# Use the first GPU when available; fall back to the CPU so the notebook still
# runs on machines without CUDA (previously it failed at the first .to()).
CURRENT_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(CURRENT_DEVICE)
print("Num of epochs:", NUM_EPOCHS)
print("Epoch size:", EPOCH_SIZE)
print("Batch size:", BATCH_SIZE)
print("Learning rate:", LEARNING_RATE)
print("Embedding size:", EMBED_SIZE)
RANK = 10        # cutoff k for NDCG@k and P@k
TEST_BATCH = 43  # number of evaluation queries sampled per test() call
MODEL_PATH = "./results/"
# exist_ok avoids the race between an os.path.exists check and makedirs.
os.makedirs(MODEL_PATH, exist_ok=True)

cuda:0
Num of epochs: 2
Epoch size: 10
Batch size: 100
Learning rate: 0.01
Embedding size: 256


### Start training

In [11]:
net = DSSM(embed_size=EMBED_SIZE, device=CURRENT_DEVICE).to(CURRENT_DEVICE)
# Output file stem encodes the run's hyperparameters.
arg_str = "_".join(str(value) for value in
                   (NUM_EPOCHS, EPOCH_SIZE, BATCH_SIZE, LEARNING_RATE, EMBED_SIZE))
unique_path = MODEL_PATH + arg_str + ".model"
output_path = MODEL_PATH + arg_str + ".csv"
# Create the optimizer once. Re-creating Adam every epoch (as before)
# discarded its running moment estimates at each epoch boundary.
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
for ep_idx in range(NUM_EPOCHS):
    train_loss = train(net, EPOCH_SIZE, BATCH_SIZE, optimizer, CURRENT_DEVICE, pos_neg_dict_g,
                       query_dict_g, passage_dict_g)
    avg_ndcg, avg_prec, avg_rr = test(net, CURRENT_DEVICE, TEST_BATCH, top_dict_g, query_test_dict_g, passage_dict_g,
                                      rating_dict_g, RANK)
    print("Epoch:{}, loss:{}, NDCG:{}, P:{}, RR:{}".format(ep_idx, train_loss, avg_ndcg, avg_prec, avg_rr))
    # Append one metrics row per epoch; the csv module requires newline="".
    with open(output_path, mode='a', newline='') as output:
        csv.writer(output).writerow([ep_idx, train_loss, avg_ndcg, avg_prec, avg_rr])
    # NOTE(review): this pickles the whole module object, so loading needs the
    # DSSM class importable (torch warns about this). Saving net.state_dict()
    # would be more portable; the format is kept for existing checkpoints.
    torch.save(net, unique_path)

[(3905806, 0.9963648915290833), (985256, 0.9910889863967896), (8712732, 0.9909860491752625), (8388908, 0.9896034002304077), (3067684, 0.9881701469421387), (2176863, 0.9879141449928284), (3905808, 0.9873330593109131), (1301743, 0.9861308932304382), (7352218, 0.986065149307251), (3491395, 0.9856697916984558), (8179087, 0.98552006483078), (1234330, 0.9853038787841797), (3905803, 0.9848719239234924), (3147623, 0.9847763776779175), (7911557, 0.9847475290298462), (6685837, 0.9846718311309814), (3241580, 0.9841294288635254), (2385230, 0.9838538765907288), (7046935, 0.9838019609451294), (5127613, 0.9837325811386108), (6086944, 0.983561098575592), (3905057, 0.983371376991272), (150004, 0.9832867383956909), (8643552, 0.9830265641212463), (7132885, 0.9830014109611511), (1959030, 0.982833743095398), (3029489, 0.9826996326446533), (5168243, 0.9826874136924744), (985252, 0.9824442863464355), (262492, 0.9822874069213867), (4452004, 0.9820535778999329), (350, 0.9820248484611511), (892136, 0.9820197224

  "type " + obj.__name__ + ". It won't be checked "
