# network.py

In [1]:
import torch
import torch.nn as nn

EMBED_SIZE = 768


class ENCODER(torch.nn.Module):

    def __init__(self):
        super(ENCODER, self).__init__()
        
        self.projection = nn.Linear(EMBED_SIZE, EMBED_SIZE)
        self.norm = nn.LayerNorm(EMBED_SIZE)
    def forward(self, x):
        x = self.projection(x)
        x = self.norm(x)
        return x

    def parameter_count(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# data.py

In [2]:
import pickle
import pandas as pd
from torch.utils.data import Dataset, DataLoader


POS_NEG_PATH = "/datadrive/jianx/data/qidpidtriples.train.full.2.tsv"
QUERY_TRAIN_DICT_PATH = "/datadrive/jianx/data/queries.train.tsv"
PASSAGE_DICT_PATH = "/datadrive/jianx/data/collection.tsv"
TOP_DICT_PATH = "/datadrive/jianx/data/initial_ranking.dict"
RATING_DICT_PATH = "/datadrive/jianx/data/rel_scores.dict"
QUERY_TEST_DICT_PATH = "/datadrive/jianx/data/msmarco-test2019-queries.tsv"

NROW = None

def load_tsv_dict(path):
    with open(path) as file:
        line = file.readline()
        my_dict = {}
        while line:
            tokens = line.split("\t")
            indexid = int(tokens[0])
            text = tokens[1].rstrip()
            my_dict[indexid] = text
            line = file.readline()
    return my_dict


def load_pos_neg(path):
    data = pd.read_csv(path, sep='\t', header = None, nrows = NROW)
    return data

def obj_reader(path):
    with open(path, 'rb') as handle:
        return pickle.loads(handle.read())

def load():
    pos_neg = load_pos_neg(POS_NEG_PATH)
    query_dict = load_tsv_dict(QUERY_TRAIN_DICT_PATH)
    passage_dict = load_tsv_dict(PASSAGE_DICT_PATH)
    top_dict = obj_reader(TOP_DICT_PATH)
    rating_dict = obj_reader(RATING_DICT_PATH)
    query_test_dict = load_tsv_dict(QUERY_TEST_DICT_PATH)
    return pos_neg, query_dict, passage_dict, top_dict, rating_dict, query_test_dict

In [3]:
def encode_text(text, model, device):
    tokens = model.encode(text).to(device)
    if tokens.shape[0] > 512:
        tokens[:512]
    last_layer_features = model.extract_features(tokens)
    return last_layer_features[:,0,:]

class TrainDataset(Dataset):

    def __init__(self, pos_neg, queries, passages, model, device):

        self.pos_neg = pos_neg
        self.queries = queries
        self.passages = passages
        self.model = model
        self.device = device

    def __len__(self):
        return len(self.pos_neg)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Positive negative pair
        qid = self.pos_neg.iloc[idx,0]
        pos_pid = self.pos_neg.iloc[idx,1]
        neg_pid = self.pos_neg.iloc[idx,2]
        self.model.train()
        q_tokens = encode_text(self.queries[qid], self.model, self.device)
        pos_tokens = encode_text(self.passages[pos_pid], self.model, self.device)
        neg_tokens = encode_text(self.passages[neg_pid], self.model, self.device)
        label = torch.tensor(0)
        sample = {'qid': qid, 'pos_pid': pos_pid, 'neg_pid': neg_pid, 'query': q_tokens,
                  'pos': pos_tokens, 'neg': neg_tokens, 'label': label}

        return sample

# train.py

In [4]:
def train(device, net, epochsize, dataiter, dataloader, optimizer):
    net.train()
    criterion = nn.CrossEntropyLoss()
    train_loss = torch.tensor(0.0)
    for i in range(epochsize):
        try:
            batch = dataiter.next()
        except StopIteration:
            print("Finished iterating current dataset, begin reiterate")
            dataiter = iter(dataloader)
            batch = dataiter.next()
        queries = net(batch['query'].to(device))
        pos = net(batch['pos'].to(device)).to(device)
        neg = net(batch['neg'].to(device)).to(device)
        out_pos = (pos * queries).sum(1).to(device)
        out_neg = (neg * queries).sum(1).to(device)
        outputs = torch.cat((out_pos, out_neg), -1).to(device)
        labels = batch['label'].to(device)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        print(i, loss.item())
    return train_loss / epochsize
        

# test.py

In [5]:
import torch
import random
import math
import numpy as np

def test_loader(device, net, model, top_dict, query_test_dict, passage_dict, rating_dict):
    net.eval()
    model.eval()
    qid_list = list(rating_dict.keys())
#     qid_list = random.sample(qid_list, test_batch)
    # sample test_batch of non-empty qids
    qids = []
    queries = []
    for qid in qid_list:
        qids.append(qid)
        queries.append(encode_text(query_test_dict[qid], model, device))
    result_dict = {}
    for i, qid in enumerate(qids):
        print(i)
        top_list = top_dict[qid]
        q_embed = net(queries[i].to(device)).detach()
        q_results = {}
        for j, pid in enumerate(top_list):
            p_seq = passage_dict[pid]
            p_embed = net(encode_text(p_seq, model, device)).detach().to(device)
            score = (p_embed * q_embed).sum(1).item()
            q_results[pid] = score
        result_dict[qid] = q_results
    #print_message(sorted(result_dict[qid].items(), key=lambda x: (x[1], [-1, 1][random.randrange(2)]), reverse=True))
    return result_dict


def get_ndcg_precision_rr(true_dict, test_dict, rank):
    sorted_result = sorted(test_dict.items(), key=lambda x: (x[1], [-1,1][random.randrange(2)]), reverse=True)
    original_rank = rank
    rank = min(rank, len(sorted_result))
    cumulative_gain = 0
    num_positive = 0
    rr = float("NaN")
    for i in range(len(sorted_result)):
        pid = sorted_result[i][0]
        if pid in true_dict:
            rr = 1 / (i + 1)
            break
    for i in range(rank):
        pid = sorted_result[i][0]
        if pid in true_dict:
            num_positive += 1
    sorted_result = sorted(test_dict.items(), key=lambda x: x[1], reverse=True)
    for i in range(rank):
        pid = sorted_result[i][0]
        relevance = 0
        if pid in true_dict:
            relevance = true_dict[pid]
        discounted_gain = relevance / math.log2(2 + i)
        cumulative_gain += discounted_gain
    sorted_ideal = sorted(true_dict.items(), key=lambda x: x[1], reverse=True)
    ideal_gain = 0
    for i in range(rank):
        relevance = 0
        if i < len(sorted_ideal):
            relevance = sorted_ideal[i][1]
        discounted_gain = relevance / math.log2(2 + i)
        ideal_gain += discounted_gain
    ndcg = 0
    if ideal_gain != 0:
         ndcg = cumulative_gain / ideal_gain
    return ndcg, num_positive / original_rank, rr


def test(device, net, model, top_dict, query_test_dict, passage_dict, rating_dict, rank):
    result_dict = test_loader(device, net, model, top_dict, query_test_dict, passage_dict, rating_dict)
    qids = list(result_dict.keys())
    result_ndcg = []
    result_prec = []
    result_rr = []
    for qid in qids:
        print(qid)
        ndcg, prec, rr = get_ndcg_precision_rr(rating_dict[qid], result_dict[qid], rank)
        result_ndcg.append(ndcg)
        result_prec.append(prec)
        result_rr.append(rr)
        print("qid: {} ndcg: {} prec: {} rr: {}".format(qid, ndcg, prec, rr))
    avg_ndcg = np.nanmean(result_ndcg)
    avg_prec = np.nanmean(result_prec)
    avg_rr = np.nanmean(result_rr)
    return avg_ndcg, avg_prec, avg_rr

# main.py

In [14]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import csv

BATCH_SIZE = 8
LEARNING_RATE = 0.001
EPOCH_SIZE = 1500
NEPOCH = 5000
RANK= 10

In [15]:
output_path = "./results/output_roberta_cut.csv"
roberta_model_path = "./results/roberta_model_cut.pt"
net_path = "./results/net_model_cut.pt"

In [16]:
roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print("Current device is: %s" % torch.cuda.get_device_name(device))
roberta.cuda(device)
roberta.train()

Using cache found in /home/ruohan/.cache/torch/hub/pytorch_fairseq_master


Current device is: Tesla P100-PCIE-16GB


RobertaHubInterface(
  (model): RobertaModel(
    (encoder): RobertaEncoder(
      (sentence_encoder): TransformerSentenceEncoder(
        (dropout_module): FairseqDropout()
        (embed_tokens): Embedding(50265, 768, padding_idx=1)
        (embed_positions): LearnedPositionalEmbedding(514, 768, padding_idx=1)
        (layers): ModuleList(
          (0): TransformerSentenceEncoderLayer(
            (dropout_module): FairseqDropout()
            (activation_dropout_module): FairseqDropout()
            (self_attn): MultiheadAttention(
              (dropout_module): FairseqDropout()
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=T

In [17]:
pos_neg, query_dict, passage_dict, top_dict, rating_dict, query_test_dict = load()

In [18]:
trainset = TrainDataset(pos_neg, query_dict, passage_dict, roberta, device)
print(trainset.__len__())
trainloader = DataLoader(trainset, batch_size = BATCH_SIZE, 
                         shuffle=True, num_workers=0)
trainiter = iter(trainloader)

397768673


In [19]:
net = ENCODER().to(device)

In [20]:
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

In [21]:
for i in range(NEPOCH):
    train_loss = train(device, net, EPOCH_SIZE, trainiter, trainloader, optimizer)
    avg_ndcg, avg_prec, avg_rr = test(device, net, roberta, top_dict, query_test_dict, passage_dict, rating_dict, rank = RANK)
    print("Epoch: {} {} {} {} {}".format(i, train_loss.item(), avg_ndcg, avg_prec, avg_rr))
    with open(output_path, mode='a+') as output:
        output_writer = csv.writer(output)
        output_writer.writerow([i, train_loss.item(), avg_ndcg, avg_prec, avg_rr])
    torch.save(net.state_dict(), net_path)
    torch.save(roberta.state_dict(), roberta_model_path)

0 8.851578712463379
1 8.48298454284668
2 8.321390151977539
3 8.122848510742188
4 7.966214656829834
5 7.7890472412109375
6 7.607579231262207
7 7.349181652069092
8 6.935258388519287
9 6.434542655944824
10 5.8426971435546875
11 5.239776611328125
12 4.233829021453857
13 3.2965760231018066
14 2.0512142181396484
15 1.3922836780548096
16 0.7709947824478149
17 0.49929988384246826
18 0.7605587244033813
19 1.0594967603683472
20 1.7693712711334229
21 2.4784514904022217
22 0.38945460319519043
23 0.7041702270507812
24 0.8140835762023926
25 0.6430211067199707
26 1.213646411895752
27 0.5686397552490234
28 0.8880472183227539
29 0.7840099334716797
30 0.8225448131561279
31 1.2094342708587646
32 0.5330288410186768
33 1.0055582523345947
34 1.2615275382995605
35 0.9083576202392578
36 0.6639366149902344
37 0.809894323348999
38 0.8287694454193115
39 1.0092380046844482
40 0.7107727527618408
41 1.3805404901504517
42 0.9673434495925903
43 0.6745114326477051
44 0.8333303928375244
45 1.1118476390838623
46 0.84720

364 1.0752633810043335
365 0.6245999336242676
366 1.086764931678772
367 0.5142303705215454
368 0.7150583267211914
369 0.5339111089706421
370 0.6594656705856323
371 0.6612969636917114
372 0.7921745777130127
373 0.5680103302001953
374 0.7587168216705322
375 0.4454610347747803
376 0.4779665470123291
377 0.5723347663879395
378 0.6783779859542847
379 1.265845537185669
380 0.9960130453109741
381 0.5517069101333618
382 0.35561490058898926
383 1.0698055028915405
384 0.8686047792434692
385 0.9133274555206299
386 0.8358938694000244
387 0.7058228254318237
388 1.1012831926345825
389 0.6616865396499634
390 0.8931066989898682
391 0.6033897399902344
392 0.6491208076477051
393 0.8578530550003052
394 0.5729907751083374
395 0.6341334581375122
396 0.9642506837844849
397 0.7256145477294922
398 0.47376322746276855
399 0.6660091876983643
400 0.5659044981002808
401 0.9032167196273804
402 0.7968398332595825
403 0.5223195552825928
404 0.7013345956802368
405 0.7479276657104492
406 0.47438645362854004
407 0.7432

KeyboardInterrupt: 