In [2]:
import torch
import torch.nn as nn
from torch import optim
import csv
import os
import pickle
import random
import datetime
import math
import numpy as np
import fasttext
from IPython.display import display, HTML

display(HTML(data="""
<style>
    div#notebook-container    { width: 95%; }
    div#menubar-container     { width: 65%; }
    div#maintoolbar-container { width: 99%; }
</style>
"""))

def print_message(s):
    print("[{}] {}".format(datetime.datetime.utcnow().strftime("%b %d, %H:%M:%S"), s), flush=True)

In [3]:
POS_NEG_DICT_PATH = "/home/jianx/data/qid_pos_neg.dict"
PASSAGE_DICT_PATH = "/home/jianx/data/passages.dict"
QUERY_TRAIN_DICT_PATH = "/home/jianx/data/queries_train.dict"
QUERY_EVAL_DICT_PATH = "/home/jianx/data/queries_eval.dict"
QUERY_DEV_DICT_PATH = "/home/jianx/data/queries_dev.dict"
TOP_DICT_PATH = "/home/jianx/data/initial_ranking.dict"
RATING_DICT_PATH = "/home/jianx/data/rel_scores.dict"
QUERY_TEST_DICT_PATH = "/home/jianx/data/queries_test.dict"

In [4]:
def obj_reader(path):
    with open(path, 'rb') as handle:
        return pickle.loads(handle.read())


def obj_writer(obj, path):
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


def load():
    pos_neg_dict = obj_reader(POS_NEG_DICT_PATH)
    query_dict = obj_reader(QUERY_TRAIN_DICT_PATH)
    passage_dict = obj_reader(PASSAGE_DICT_PATH)
    top_dict = obj_reader(TOP_DICT_PATH)
    rating_dict = obj_reader(RATING_DICT_PATH)
    query_test_dict = obj_reader(QUERY_TEST_DICT_PATH)
    return pos_neg_dict, query_dict, passage_dict, top_dict, rating_dict, query_test_dict

In [5]:
print_message("Loading data")
pos_neg_dict_g, query_dict_g, passage_dict_g, top_dict_g, rating_dict_g, query_test_dict_g = load()
print_message("Data successfully loaded.")
print_message("Positive Negative Pair dict size: " + str(len(pos_neg_dict_g)))
print_message("Num of queries: " + str(len(query_dict_g)))
print_message("Num of passages: " + str(len(passage_dict_g)))

[Jul 08, 23:27:31] Loading data
[Jul 08, 23:35:04] Data successfully loaded.
[Jul 08, 23:35:04] Positive Negative Pair dict size: 400782
[Jul 08, 23:35:04] Num of queries: 808731
[Jul 08, 23:35:04] Num of passages: 8841823


In [6]:
NUM_HIDDEN_NODES = 256
NUM_HIDDEN_LAYERS = 1
DROPOUT_RATE = 0.1

class OldDSSM(torch.nn.Module):

    def __init__(self, embed_size):
        super(OldDSSM, self).__init__()
        layers = []
        last_dim = VOCAB_LEN
        for i in range(NUM_HIDDEN_LAYERS):
            layers.append(nn.Linear(last_dim, NUM_HIDDEN_NODES))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm(NUM_HIDDEN_NODES))
            layers.append(nn.Dropout(p=DROPOUT_RATE))
            last_dim = NUM_HIDDEN_NODES
        layers.append(nn.Linear(last_dim, embed_size))
        self.model = nn.Sequential(*layers)
    def forward(self, x):
        return self.model(x)

    def parameter_count(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# Define the network
class DSSM(torch.nn.Module):

    def __init__(self, embed_size):
        super(DSSM, self).__init__()
        self.embed = nn.Embedding(VOCAB_LEN, NUM_HIDDEN_NODES, sparse=True)
        layers = []
        for i in range(NUM_HIDDEN_LAYERS):
            layers.append(nn.Linear(NUM_HIDDEN_NODES, NUM_HIDDEN_NODES))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm(NUM_HIDDEN_NODES))
            layers.append(nn.Dropout(p=DROPOUT_RATE))
        layers.append(nn.Linear(NUM_HIDDEN_NODES, embed_size))
        self.fc = nn.Sequential(*layers)
        
    def forward(self, x):
        y = self.embed(x).sum(dim=1)
        y = self.fc(y)
        return y

    def parameter_count(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
class CDSSM(torch.nn.Module):
    
    def __init__(self, embed_size):
        super(CDSSM, self).__init__()
        self.embed = nn.Embedding(VOCAB_LEN, NUM_HIDDEN_NODES, sparse=True)
        self.conv = nn.Conv1d(NUM_HIDDEN_NODES, NUM_HIDDEN_NODES, kernel_size=3)
        layers = []
        for i in range(NUM_HIDDEN_LAYERS):
            layers.append(nn.Linear(NUM_HIDDEN_NODES, NUM_HIDDEN_NODES))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm(NUM_HIDDEN_NODES))
            layers.append(nn.Dropout(p=DROPOUT_RATE))
        layers.append(nn.Linear(NUM_HIDDEN_NODES, embed_size))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        y = self.embed(x)
        y = y.permute(0, 2, 1)
        y = self.conv(y)
        y, _ = torch.max(y, dim=-1)
        y = self.fc(y)
        return y
    
    def parameter_count(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

class EmbedDSSM(torch.nn.Module):

    def __init__(self, embed_size, vocab_count, embedding):
        super(EmbedDSSM, self).__init__()
        layer_0 = nn.Linear(vocab_count, NUM_HIDDEN_NODES)
        layer_0.weight = nn.Parameter(embedding.T, requires_grad=True)
        layers = [layer_0]
#         layers = []
        last_dim = NUM_HIDDEN_NODES
        for i in range(NUM_HIDDEN_LAYERS):
            if i != 0:
                layers.append(nn.Linear(last_dim, NUM_HIDDEN_NODES))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm(NUM_HIDDEN_NODES))
            layers.append(nn.Dropout(p=DROPOUT_RATE))
            last_dim = NUM_HIDDEN_NODES
        layers.append(nn.Linear(last_dim, embed_size))
        self.model = nn.Sequential(*layers)
        self.embedding = embedding
    def forward(self, x):
#         x = torch.mm(x, self.embedding)
        return self.model(x)

    def parameter_count(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [24]:
VOCAB_LEN = 100000
MAX_QUERY_TERMS = 20
MAX_PASSAGE_TERMS = 200

def generate_sparse(idx, vocab_len=VOCAB_LEN):
    index_tensor = torch.LongTensor([idx])
    value_tensor = torch.Tensor([1/len(idx)] * len(idx))
    sparse_tensor = torch.sparse.FloatTensor(index_tensor, value_tensor, torch.Size([vocab_len, ]))
    return sparse_tensor


def mini_batch(batch_size, device, pos_neg_dict, query_dict, passage_dict):
    query_list = list(pos_neg_dict.keys())
    queries = []
    pos = []
    neg = []
    while len(queries) < batch_size:
        qid = random.sample(query_list, 1)[0]
        pos_neg_pair = random.sample(pos_neg_dict[qid], 1)
        pos_pid = pos_neg_pair[0][0]
        neg_pid = pos_neg_pair[0][1]
        q_seq = query_dict[qid]
        pos_seq = passage_dict[pos_pid]
        neg_seq = passage_dict[neg_pid]
        if q_seq != [] and pos_seq != [] and neg_seq != []:
            if USE_SPARSE_INPUT:
                queries.append(generate_sparse(q_seq))
                pos.append(generate_sparse(pos_seq))
                neg.append(generate_sparse(neg_seq))
            else:
                queries.append(torch.tensor(q_seq[:MAX_QUERY_TERMS] + [100000]*(MAX_QUERY_TERMS - len(q_seq)), dtype=torch.int64))
                pos.append(torch.tensor(pos_seq[:MAX_PASSAGE_TERMS] + [100000]*(MAX_PASSAGE_TERMS - len(pos_seq)), dtype=torch.int64))
                neg.append(torch.tensor(neg_seq[:MAX_PASSAGE_TERMS] + [100000]*(MAX_PASSAGE_TERMS - len(neg_seq)), dtype=torch.int64))
    labels = [0 for i in range(batch_size)]
    return torch.stack(queries).to(device), torch.stack(pos).to(device), torch.stack(neg).to(device), labels


def train(net, net_q, epoch_size, batch_size, optimizer, optimizer_q, device, pos_neg_dict, query_dict,
          passage_dict):
    criterion = nn.CrossEntropyLoss()
    train_loss = 0.0
    net.train()
    net_q.train()
    for mb_idx in range(epoch_size):
        # Read in a new mini-batch of data!
        queries, pos, neg, labels = mini_batch(batch_size, device, pos_neg_dict, query_dict,
                                               passage_dict)
        optimizer.zero_grad()
        q_embed = net_q(queries)
        pos_embed = net(pos)
        neg_embed = net(neg)
        neg_rand_embed = torch.cat((neg_embed[1:,:], neg_embed[0:1,:]), 0)
        out_pos = torch.cosine_similarity(q_embed, pos_embed).unsqueeze(0).T
        out_neg = torch.cosine_similarity(q_embed, neg_embed).unsqueeze(0).T
        out_neg_rand = torch.cosine_similarity(q_embed, neg_rand_embed).unsqueeze(0).T
        out = torch.cat((out_pos, out_neg), -1)
        out_rand = torch.cat((out_pos, out_neg_rand), -1)
        out = torch.cat((out, out_rand), 0)
        out = out * 10
        labels = torch.tensor(labels).to(device)
        labels = torch.cat((labels, labels), 0)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        optimizer_q.step()
        train_loss += loss.item()
        # print_message(str(mb_idx) + " iteration: " + str(train_loss / (mb_idx + 1)))
    return train_loss / epoch_size

In [25]:
def test_loader(net, net_q, device, test_batch, top_dict, query_test_dict, passage_dict, rating_dict):
    net.eval()
    net_q.eval()
    device = device
    qid_list = list(rating_dict.keys())
    # sample test_batch of non-empty qids
    qids = []
    queries = []
    while len(qids) < test_batch:
        qid = random.sample(qid_list, 1)[0]
        q_seq = query_test_dict[qid]
        if q_seq != [] and qid not in qids:
            qids.append(qid)
            if USE_SPARSE_INPUT:
                queries.append(generate_sparse(q_seq).to(device))
            else:
                queries.append(torch.tensor(q_seq[:MAX_QUERY_TERMS] + [100000]*(MAX_QUERY_TERMS - len(q_seq)), dtype=torch.int64).unsqueeze(0).to(device))
    # compute cosine similarity
    result_dict = {}
    for i, qid in enumerate(qids):
        top_list = top_dict[qid]
#         q_embed = net(queries[i].unsqueeze(0)).detach()
        q_embed = net_q(queries[i]).detach()
        q_results = {}
        for j, pid in enumerate(top_list):
            p_seq = passage_dict[pid]
            if not p_seq:
                score = -1
            else:
                if USE_SPARSE_INPUT:
#                     p_embed = net(generate_sparse(p_seq).unsqueeze(0).to(device)).detach()
#                     score = torch.cosine_similarity(q_embed, p_embed).item()
                    p_embed = net(generate_sparse(p_seq).to(device)).detach()
                    score = torch.cosine_similarity(q_embed.unsqueeze(0), p_embed.unsqueeze(0)).item()
                else:
                    p_embed = net(torch.tensor(p_seq[:MAX_PASSAGE_TERMS] + [100000]*(MAX_PASSAGE_TERMS - len(p_seq)), dtype=torch.int64).unsqueeze(0).to(device)).detach()
                    score = torch.cosine_similarity(q_embed, p_embed).item()
            q_results[pid] = score
        result_dict[qid] = q_results
    #print_message(sorted(result_dict[qid].items(), key=lambda x: (x[1], [-1, 1][random.randrange(2)]), reverse=True))
    return result_dict


def get_ndcg_precision_rr(true_dict, test_dict, rank):
    sorted_result = sorted(test_dict.items(), key=lambda x: (x[1], [-1,1][random.randrange(2)]), reverse=True)
    original_rank = rank
    rank = min(rank, len(sorted_result))
    cumulative_gain = 0
    num_positive = 0
    rr = float("NaN")
    for i in range(len(sorted_result)):
        pid = sorted_result[i][0]
        if pid in true_dict:
            rr = 1 / (i + 1)
            break
    for i in range(rank):
        pid = sorted_result[i][0]
        if pid in true_dict:
            num_positive += 1
    sorted_result = sorted(test_dict.items(), key=lambda x: x[1], reverse=True)
    for i in range(rank):
        pid = sorted_result[i][0]
        relevance = 0
        if pid in true_dict:
            relevance = true_dict[pid]
        discounted_gain = relevance / math.log2(2 + i)
        cumulative_gain += discounted_gain
    sorted_ideal = sorted(true_dict.items(), key=lambda x: x[1], reverse=True)
    ideal_gain = 0
    for i in range(rank):
        relevance = 0
        if i < len(sorted_ideal):
            relevance = sorted_ideal[i][1]
        discounted_gain = relevance / math.log2(2 + i)
        ideal_gain += discounted_gain
    ndcg = 0
    if ideal_gain != 0:
         ndcg = cumulative_gain / ideal_gain
    return ndcg, num_positive / original_rank, rr


def test(net, net_q, device, test_batch, top_dict, query_test_dict, passage_dict, rating_dict, rank):
    result_dict = test_loader(net, net_q, device, test_batch, top_dict, query_test_dict, passage_dict, rating_dict)
    qids = list(result_dict.keys())
    result_ndcg = []
    result_prec = []
    result_rr = []
    for qid in qids:
        ndcg, prec, rr = get_ndcg_precision_rr(rating_dict[qid], result_dict[qid], rank)
        result_ndcg.append(ndcg)
        result_prec.append(prec)
        result_rr.append(rr)
    avg_ndcg = np.nanmean(result_ndcg)
    avg_prec = np.nanmean(result_prec)
    avg_rr = np.nanmean(result_rr)
    return avg_ndcg, avg_prec, avg_rr

In [9]:
# NUM_EPOCHS = 128
# NUM_EPOCHS = 2
NUM_EPOCHS = 256
EPOCH_SIZE = 1024
# EPOCH_SIZE = 2
BATCH_SIZE = 1024
LEARNING_RATE = 0.001
EMBED_SIZE = 256
USE_SPARSE_INPUT = True

In [10]:
# Train word embedding model
IN_FILE = "/home/jianx/data/processed_passages.txt"
def generate_embeddings(in_file, model_file):
    embeddings = fasttext.train_unsupervised(in_file, model='skipgram', dim=NUM_HIDDEN_NODES, bucket=10000, minCount=1, minn=1, maxn=0, ws=10, epoch=5)
    embeddings.save_model(model_file)
def get_pretrained_embeddings(model_file):
    model = fasttext.load_model(model_file)
    embed_size = model.get_input_matrix().shape[1]
    pretrained_embeddings = torch.FloatTensor(model.get_input_matrix())
    vocab = model.get_words()
    return vocab, pretrained_embeddings

In [11]:
# train and save word_embedding model
# generate_embeddings(IN_FILE, model_file = "./results/skipgram.model")

In [12]:
# load pretrained model
vocab, pretrained_embeddings = get_pretrained_embeddings(model_file = "./results/skipgram.model")
vocab_count = pretrained_embeddings.size()[0] - 1
print(vocab_count)



100000


In [13]:
# pretrained_model = fasttext.load_model("./results/skipgram.model")

In [14]:
# original_vocab = set(w.lower() for w in open(IN_FILE).read().split())

In [15]:
pretrained_vocab = set(vocab[1:])
len(pretrained_vocab)

100000

In [16]:
# len(original_vocab)

In [17]:
# diff_vocab = pretrained_vocab.difference(original_vocab)
# print(diff_vocab)

In [18]:
VOCAB_PATH = "/home/jianx/data/vocabs_ver2.csv"
with open(VOCAB_PATH, mode='r') as vocab_file:
    reader = csv.reader(vocab_file)
    old_vocab = {rows[0]:i for i, rows in enumerate(reader)}
reorder_index = []
for word in pretrained_vocab:
    reorder_index.append(old_vocab[word])

In [19]:
reorder_embeddings = pretrained_embeddings[reorder_index,:]

In [20]:
print(reorder_embeddings.shape)

torch.Size([100000, 256])


In [21]:
CURRENT_DEVICE = torch.device("cuda:1")
print_message(CURRENT_DEVICE)
print_message("Num of epochs:{}".format(NUM_EPOCHS))
print_message("Epoch size:{}".format(EPOCH_SIZE))
print_message("Batch size:{}".format(BATCH_SIZE))
print_message("Learning rate:{}".format(LEARNING_RATE))
print_message("Embedding size:{}".format(EMBED_SIZE))
RANK = 10
TEST_BATCH = 43
MODEL_PATH = "./results/"
if not os.path.exists(MODEL_PATH):
    os.makedirs(MODEL_PATH)

[Jul 08, 23:36:05] cuda:1
[Jul 08, 23:36:05] Num of epochs:256
[Jul 08, 23:36:05] Epoch size:1024
[Jul 08, 23:36:05] Batch size:1024
[Jul 08, 23:36:05] Learning rate:0.001
[Jul 08, 23:36:05] Embedding size:256


In [None]:
# net = OldDSSM(embed_size=EMBED_SIZE).to(CURRENT_DEVICE) if USE_SPARSE_INPUT else DSSM(embed_size=EMBED_SIZE).to(CURRENT_DEVICE)
# net = CDSSM(embed_size=EMBED_SIZE).to(CURRENT_DEVICE)
net = EmbedDSSM(embed_size = EMBED_SIZE, vocab_count=vocab_count, embedding=reorder_embeddings.to(CURRENT_DEVICE)).to(CURRENT_DEVICE)
net_q = EmbedDSSM(embed_size = EMBED_SIZE, vocab_count=vocab_count, embedding=reorder_embeddings.to(CURRENT_DEVICE)).to(CURRENT_DEVICE)
print(net)
arg_str = str(NUM_EPOCHS) + "_" + str(EPOCH_SIZE) + "_" + str(BATCH_SIZE) + "_" + str(LEARNING_RATE) + "_" + str(
    EMBED_SIZE)
unique_path = MODEL_PATH + arg_str + ".model"
q_net_path = MODEL_PATH + arg_str + "query.model"
output_path = MODEL_PATH + arg_str + ".csv"
optimizer = optim.Adagrad(net.parameters(), lr=LEARNING_RATE)
optimizer_q = optim.Adagrad(net.parameters(), lr=LEARNING_RATE)
for ep_idx in range(NUM_EPOCHS):
    train_loss = train(net, net_q, EPOCH_SIZE, BATCH_SIZE, optimizer, optimizer_q, CURRENT_DEVICE, pos_neg_dict_g,
                       query_dict_g, passage_dict_g)
    avg_ndcg, avg_prec, avg_rr = test(net, net_q, CURRENT_DEVICE, TEST_BATCH, top_dict_g, query_test_dict_g, passage_dict_g, rating_dict_g, RANK)
    print_message("Epoch:{}, loss:{}, NDCG:{}, P:{}, RR:{}".format(ep_idx, train_loss, avg_ndcg, avg_prec, avg_rr))
    #print_message("Epoch:{}, loss:{}".format(ep_idx, train_loss))
    with open(output_path, mode='a+') as output:
        output_writer = csv.writer(output)
        output_writer.writerow([ep_idx, train_loss, avg_ndcg, avg_prec, avg_rr])
    torch.save(net, unique_path)
    torch.save(net_q, q_net_path)

EmbedDSSM(
  (model): Sequential(
    (0): Linear(in_features=100000, out_features=256, bias=True)
    (1): ReLU()
    (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=256, out_features=256, bias=True)
  )
)
[Jul 08, 23:44:21] Epoch:0, loss:0.4448716451006476, NDCG:0.18296330343508982, P:0.23720930232558138, RR:0.3800335604430727


  "type " + obj.__name__ + ". It won't be checked "


[Jul 08, 23:51:00] Epoch:1, loss:0.37408088910160586, NDCG:0.19096475453706951, P:0.2511627906976744, RR:0.3981152565846276
[Jul 08, 23:57:49] Epoch:2, loss:0.35070205802912824, NDCG:0.20674335733560065, P:0.26046511627906976, RR:0.4432915779934522
[Jul 09, 00:04:31] Epoch:3, loss:0.3378398595377803, NDCG:0.2112313284447126, P:0.26744186046511625, RR:0.4307261587543296
[Jul 09, 00:11:10] Epoch:4, loss:0.32702755834907293, NDCG:0.21981388645250588, P:0.27674418604651163, RR:0.43265340420317266
[Jul 09, 00:17:52] Epoch:5, loss:0.31956963991979137, NDCG:0.2188579815947019, P:0.2744186046511628, RR:0.4333289372430661
[Jul 09, 00:24:31] Epoch:6, loss:0.31347700097830966, NDCG:0.22380010401852274, P:0.28139534883720935, RR:0.45179076177112726
[Jul 09, 00:31:19] Epoch:7, loss:0.3080224029254168, NDCG:0.22935603616565703, P:0.2837209302325581, RR:0.4551835085216562
[Jul 09, 00:37:59] Epoch:8, loss:0.30322391420486383, NDCG:0.22900768745370678, P:0.2837209302325581, RR:0.4528667604097615
[Jul 0