In [None]:
# 1. Load data: passage_embeddings, query_train_embeddings, train_data (a proportion of)
# 2. Network input: 768 output: 32
# 3. Train: concatenate original 768 with 32 

## Import data: load_data.py

In [1]:
import csv
import pickle
import numpy as np

PASSAGE_NP_PATH = "/home/jianx/results/passage_0__emb_p__data_obj_0.pb"
PASSAGE_MAP_PATH = "/datadrive/jianx/data/annoy/100_ance_passage_map.dict"
QUERY_TRAIN_NP_PATH = "/home/jianx/results/query_0__emb_p__data_obj_0.pb"
QUERY_MAP_PATH = "/datadrive/jianx/data/annoy/100_ance_query_train_map.dict"
TRAIN_RANK_PATH = "/datadrive/jianx/data/train_data/ance_training_rank100_8841823.csv"
# TRAIN_RANK_PATH = "/datadrive/ruohan/reverse_ranker/new_training/combine_rank_train_phase2.csv"


OUT_RANK = 200
N_PASSAGE = 200000
def obj_reader(path):
    with open(path, 'rb') as handle:
        return pickle.load(handle, encoding="bytes")
def load_train(path):
    with open(path, "r") as file:
        pos_dict = {}
        neg_dict = {}
        count = 0
        for line in file:
            if count >= N_PASSAGE * 100:
                break
            count += 1
            tokens = line.split(",")
            pid = int(tokens[0])
            qid = int(tokens[1])
            rank = int(tokens[2].rstrip())
            if rank == 0:
                if pid not in neg_dict:
                    neg_dict[pid] = {}
                neg_dict[pid][qid] = OUT_RANK
            else:
                if pid not in pos_dict:
                    pos_dict[pid] = {}
                pos_dict[pid][qid] = rank
    return pos_dict, neg_dict
def map_id(old_np, mapping):
    new_dict = dict(zip(mapping.values(),old_np))
    return new_dict
def load():
    print("Load embeddings.")
    passage_np = obj_reader(PASSAGE_NP_PATH)
    pid_mapping = obj_reader(PASSAGE_MAP_PATH)
    query_np = obj_reader(QUERY_TRAIN_NP_PATH)
    qid_mapping = obj_reader(QUERY_MAP_PATH)
    print("Mapping ids.")
    query_dict = map_id(query_np, qid_mapping)
    passage_dict = map_id(passage_np, pid_mapping)
    print("Load training data.")
    train_pos_dict, train_neg_dict = load_train(TRAIN_RANK_PATH)
    return train_pos_dict, train_neg_dict, query_dict, passage_dict

In [2]:
# train_pos_dict, train_neg_dict = load_train(TRAIN_RANK_PATH)

In [3]:
train_pos_dict, train_neg_dict, query_dict, passage_dict = load()

Load embeddings.
Mapping ids.
Load training data.


## Network Architecture: network.py

In [4]:
import torch
import torch.nn as nn

NUM_HIDDEN_NODES = 1536
NUM_HIDDEN_LAYERS = 3
DROPOUT_RATE = 0.1
FEAT_COUNT = 768


# Define the network
class CorpusNet(torch.nn.Module):

    def __init__(self, embed_size):
        super(CorpusNet, self).__init__()

        layers = []
        last_dim = FEAT_COUNT
        for i in range(NUM_HIDDEN_LAYERS):
            layers.append(nn.Linear(last_dim, NUM_HIDDEN_NODES))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm(NUM_HIDDEN_NODES))
            layers.append(nn.Dropout(p=DROPOUT_RATE))
            last_dim = NUM_HIDDEN_NODES
        layers.append(nn.Linear(last_dim, embed_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

    def parameter_count(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


## Train reverse ranker: train.py

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import torch.nn.functional as F


TOP_K = 100

# With probability alpha
# Select a random negative sample from train_neg_dict
ALPHA = 0.5

def dot_product(A, B, normalize=False):
    if normalize:
        A = F.normalize(A)
        B = F.normalize(B)
    b = A.shape[0]
    embed = A.shape[1]
    result = torch.bmm(A.view(b, 1, embed), B.view(b, embed, 1))
    return result


def mini_batch(batch_size, device, train_pos_dict, train_neg_dict, query_dict, passage_dict):
    passage_list = list(train_neg_dict.keys())
    passages = []
    pos = []
    neg = []
    pos_rank_list = []
    neg_rank_list = []
    while len(passages) < batch_size:
        pid = random.sample(passage_list, 1)[0]
        try:
            temp_pos_list = list(train_pos_dict[pid].keys())
        except:
            continue
        try:
            temp_neg_list = list(train_neg_dict[pid].keys())
        except:
            continue
        if np.random.uniform(0,1,1) <= ALPHA:
            random_positive = random.sample(temp_pos_list, 1)
            pos_qid = random_positive[0]
            pos_rank = train_pos_dict[pid][pos_qid]
            random_negative = random.sample(temp_neg_list, 1)
            neg_qid = random_negative[0]
            neg_rank = train_neg_dict[pid][neg_qid]
#             not_negative = True
#             while not_negative:
#                 temp_neg_qid = random.sample(list(query_dict.keys()), 1)
#                 if temp_neg_qid not in temp_query_list:
#                     neg_qid = temp_neg_qid[0]
#                     neg_rank = 1000
#                     not_negative = False
        else:
            if len(temp_pos_list) < 2:
                continue
            pos_neg_pair = random.sample(temp_pos_list, 2)
            # e.g. 60 >= 3
            if train_pos_dict[pid][pos_neg_pair[0]] >= train_pos_dict[pid][pos_neg_pair[1]]:
                pos_qid = pos_neg_pair[1]
                neg_qid = pos_neg_pair[0]
            # e.g. 3 < 60
            else:
                pos_qid = pos_neg_pair[0]
                neg_qid = pos_neg_pair[1]   
            pos_rank = train_pos_dict[pid][pos_qid]
            neg_rank = train_pos_dict[pid][neg_qid]
        p_seq = passage_dict[pid]
        pos_seq = query_dict[pos_qid]
        neg_seq = query_dict[neg_qid]
        passages.append(p_seq)
        pos.append(pos_seq)
        neg.append(neg_seq)
        pos_rank_list.append(TOP_K - pos_rank)
        neg_rank_list.append(TOP_K - neg_rank)
#         pos_rank_list.append((TOP_K - pos_rank) * 2)
#         neg_rank_list.append((TOP_K - neg_rank) * 2)
    labels = torch.stack([torch.FloatTensor(pos_rank_list), torch.FloatTensor(neg_rank_list)], dim=1)
    passages = torch.from_numpy(np.stack(passages))
    pos = torch.from_numpy(np.stack(pos))
    neg = torch.from_numpy(np.stack(neg))
    return passages.to(device), pos.to(device), neg.to(device), labels.to(device)


def train(net, epoch_size, batch_size, optimizer, device, train_pos_dict, train_neg_dict, 
          query_dict, passage_dict, scale=10, loss_option="ce"):
    bce = nn.BCELoss()
    ce = nn.CrossEntropyLoss()
    softmax = nn.Softmax(dim=1)
    train_loss = 0.0
    net.train()
    for mb_idx in range(epoch_size):
        # Read in a new mini-batch of data!
        passages, pos, neg, labels = mini_batch(batch_size, device, train_pos_dict, train_neg_dict, 
                                                query_dict, passage_dict)
        optimizer.zero_grad()
        p_embed = net(passages).to(device)
        pos_embed = net(pos).to(device)
        neg_embed = net(neg).to(device)
        out_pos = dot_product(p_embed, pos_embed).to(device)
        out_neg = dot_product(p_embed, neg_embed).to(device)
        out = torch.cat((out_pos, out_neg), -1).squeeze()
#         out = torch.cat((out_pos, out_neg), -1) * torch.tensor([scale], dtype=torch.float).to(device)
#         print(softmax(out))
#         print(labels)
#         loss = criterion(softmax(out).squeeze(), softmax(labels))
        if loss_option == "bce":
            loss = bce(softmax(out), softmax(labels))
        if loss_option == "ce":
            loss = ce(out, torch.tensor([0 for i in range(batch_size)]).to(device))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # print(str(mb_idx) + " iteration: " + str(train_loss / (mb_idx + 1)))
    return train_loss / epoch_size


## Main function: main.py

In [7]:
import torch
from torch import optim
import csv
import sys
import os


MODEL_PATH = "/datadrive/ruohan/results/transformation/"
CURRENT_DEVICE = "cuda:0"

if not os.path.exists(MODEL_PATH):
    os.makedirs(MODEL_PATH)


def main(num_epochs, epoch_size, batch_size, learning_rate, model_path, embed_size, pretrained=False):
    if pretrained:
        net = CorpusNet(embed_size=embed_size)
        net.load_state_dict(torch.load("/home/ruohan/DSSM/search-exposure/reverse_ranker/results/reverse_fine_tune1000_100_1000_0.0001_32.model"))
        net.to(CURRENT_DEVICE)
    else:
        net = CorpusNet(embed_size=embed_size).to(CURRENT_DEVICE)
    print("Loading data")
#     train_rank_dict, query_dict, passage_dict = load()
#     print("Data successfully loaded.")
#     print("Positive Negative Pair dict size: " + str(len(train_rank_dict)))
#     print("Num of queries: " + str(len(query_dict)))
#     print("Num of passages: " + str(len(passage_dict)))
#     print("Finish loading.")

    arg_str = "reverse_transformation_alpha0.5" + str(num_epochs) + "_" + str(epoch_size) + "_" + str(batch_size) + "_" + str(learning_rate) + "_" + str(
        embed_size)
    unique_path = model_path + arg_str + ".model"
    output_path = model_path + arg_str + ".csv"
    for ep_idx in range(num_epochs):
        optimizer = optim.Adam(net.parameters(), lr=learning_rate)
        train_loss = train(net, epoch_size, batch_size, optimizer, CURRENT_DEVICE, train_pos_dict, 
                           train_neg_dict, query_dict, passage_dict)
        print(ep_idx,train_loss)
        with open(output_path, mode='a+') as output:
            output_writer = csv.writer(output)
            output_writer.writerow([ep_idx, train_loss])
        torch.save(net.state_dict(), unique_path)

In [9]:
# Phase 2 training on new training data based on 10000 passages
main(1000,100,1000,0.0001,MODEL_PATH,768)

Loading data
0 1.5841199100017547
1 1.2224436140060424
2 1.1239557188749314
3 1.0650203430652618
4 1.0230531549453736
5 0.9853734505176545
6 0.9481772208213806
7 0.9276855325698853
8 0.9114906793832779
9 0.8874350064992904
10 0.8655343216657638
11 0.8504319810867309
12 0.8310783970355987
13 0.820599313378334
14 0.7957104051113129
15 0.7754833590984345
16 0.7557948023080826
17 0.7494232040643692
18 0.7265695065259934
19 0.701999683380127
20 0.6834408783912659
21 0.6537602281570435
22 0.59505859375
23 0.5584335029125214
24 0.549400560259819
25 0.5413497179746628
26 0.5350819587707519
27 0.5316854470968246
28 0.5311836129426957
29 0.5264308962225914
30 0.5239516475796699
31 0.5239113008975983
32 0.522781457901001
33 0.5175915187597275
34 0.5154512268304825
35 0.5161955758929253
36 0.5155539497733116
37 0.5120792472362519
38 0.5121008962392807
39 0.5114485618472099
40 0.5107032564282418
41 0.5064786043763161
42 0.5065061470866203
43 0.5040673565864563
44 0.5028937247395515
45 0.50429358661

356 0.4020874750614166
357 0.3999578863382339
358 0.3975589391589165
359 0.3988784793019295
360 0.400655078291893
361 0.3994387823343277
362 0.39858532905578614
363 0.4002086138725281
364 0.39793529838323594
365 0.3989852511882782
366 0.4001371932029724
367 0.39970969140529633
368 0.3994301795959473
369 0.39590361624956133
370 0.3968467849493027
371 0.39943049877882003
372 0.39888779193162915
373 0.3977742424607277
374 0.39698962181806563
375 0.3956724187731743
376 0.3981036850810051
377 0.3959819310903549
378 0.3951424723863602
379 0.3954252102971077
380 0.3981539922952652
381 0.3928555351495743
382 0.395290738940239
383 0.3944598540663719
384 0.3972181913256645
385 0.3946239671111107
386 0.39692767947912216
387 0.39642723083496095
388 0.3951587095856667
389 0.3930241709947586
390 0.3920138227939606
391 0.39363288044929506
392 0.39314539819955824
393 0.39108975619077685
394 0.39407961785793305
395 0.3917383563518524
396 0.3932772585749626
397 0.39201695680618287
398 0.3952228465676308

707 0.3451260855793953
708 0.3452274280786514
709 0.34324182480573656
710 0.34428752064704893
711 0.34053148061037064
712 0.34447680830955507
713 0.34175893276929853
714 0.3428250840306282
715 0.34347822576761244
716 0.34309678345918654
717 0.3404114171862602
718 0.3401599684357643
719 0.3412065124511719
720 0.34410092890262606
721 0.3419762387871742
722 0.3402620804309845
723 0.341599662899971
724 0.3409479346871376
725 0.33753475546836853
726 0.34021723389625547
727 0.33897086352109906
728 0.3395370337367058
729 0.3401792931556702
730 0.3419737759232521
731 0.3394516581296921
732 0.33821287006139755
733 0.3402704209089279
734 0.3408271783590317
735 0.3419508385658264
736 0.33993827998638154
737 0.3366083163022995
738 0.3377618628740311
739 0.3408259290456772
740 0.3410904538631439
741 0.33955041229724886
742 0.3373316180706024
743 0.3372531354427338
744 0.33591152310371397
745 0.335827451646328
746 0.33843241691589354
747 0.33583991587162015
748 0.3362657940387726
749 0.3360169872641