In [None]:
# 1. Load data: passage_embeddings, query_train_embeddings, and a proportion of train_data
# 2. Network: input size 768, output size 32
# 3. Train: concatenate the original 768-dim embedding with the 32-dim output

## Import data: load_data.py

In [1]:
import csv
import pickle
import numpy as np

# Pickled inputs read by load() through obj_reader().
# The *_NP_PATH files hold the embeddings and the *_MAP_PATH files hold the id
# mappings that map_id() zips against them (presumably row index -> original id;
# verify against the code that produced these files).
PASSAGE_NP_PATH = "/home/jianx/results/passage_0__emb_p__data_obj_0.pb"
PASSAGE_MAP_PATH = "/datadrive/jianx/data/annoy/100_ance_passage_map.dict"
QUERY_TRAIN_NP_PATH = "/home/jianx/results/query_0__emb_p__data_obj_0.pb"
QUERY_MAP_PATH = "/datadrive/jianx/data/annoy/100_ance_query_train_map.dict"
# Comma-separated training file parsed by load_train(): one "pid,qid,rank" row per line.
TRAIN_RANK_PATH = "/datadrive/jianx/data/train_data/ance_training_rank100_200000_random.csv"
# TRAIN_RANK_PATH = "/datadrive/ruohan/reverse_ranker/new_training/combine_rank_train_phase2.csv"


# Rank value stored by load_train() for rows whose rank column is 0 (negatives).
OUT_RANK = 200
# Only referenced by the commented-out row cap inside load_train().
N_PASSAGE = 200000
def obj_reader(path):
    """Unpickle and return the object stored in the file at `path`.

    The file is opened in binary mode and decoded with ``encoding="bytes"``.
    NOTE(review): pickle can execute arbitrary code while loading; only use
    this on files from a trusted source.
    """
    with open(path, mode='rb') as pickled_file:
        loaded = pickle.load(pickled_file, encoding="bytes")
    return loaded
def load_train(path, out_rank=None):
    """Parse the training rank file into positive and negative lookup tables.

    Each non-empty line must look like ``pid,qid,rank`` (integers). Rows whose
    rank is 0 are treated as negatives and stored with `out_rank`; all other
    rows are positives stored with their own rank.

    Args:
        path: Path of the comma-separated rank file.
        out_rank: Rank value recorded for negative rows. Defaults to the
            module-level ``OUT_RANK`` when omitted.

    Returns:
        ``(pos_dict, neg_dict)``, both shaped ``{pid: {qid: rank}}``.

    Raises:
        ValueError: If a non-empty line holds a non-integer field.
        IndexError: If a non-empty line has fewer than three fields.
    """
    if out_rank is None:
        out_rank = OUT_RANK
    pos_dict = {}
    neg_dict = {}
    with open(path, "r") as rank_file:
        for line in rank_file:
            line = line.strip()
            # Blank lines (e.g. a trailing newline at end of file) carry no data.
            if not line:
                continue
            tokens = line.split(",")
            pid = int(tokens[0])
            qid = int(tokens[1])
            rank = int(tokens[2])
            if rank == 0:
                neg_dict.setdefault(pid, {})[qid] = out_rank
            else:
                pos_dict.setdefault(pid, {})[qid] = rank
    return pos_dict, neg_dict
def map_id(old_np, mapping):
    """Pair the values of `mapping` with the items of `old_np`, in iteration order.

    The i-th value of `mapping` becomes the key for the i-th item of `old_np`.
    Pairing stops at the shorter of the two, and a repeated key keeps the last
    item paired with it.
    """
    return {new_id: row for new_id, row in zip(mapping.values(), old_np)}
def load():
    """Read embeddings, id mappings and training ranks from the configured paths.

    Returns:
        ``(train_pos_dict, train_neg_dict, query_dict, passage_dict)``: the two
        rank tables produced by load_train() followed by the query and passage
        lookups produced by map_id().
    """
    print("Load embeddings.")
    # The generator is consumed left to right, so files are read in this order.
    source_paths = (PASSAGE_NP_PATH, PASSAGE_MAP_PATH, QUERY_TRAIN_NP_PATH, QUERY_MAP_PATH)
    passage_rows, passage_ids, query_rows, query_ids = (obj_reader(p) for p in source_paths)

    print("Mapping ids.")
    queries_by_id = map_id(query_rows, query_ids)
    passages_by_id = map_id(passage_rows, passage_ids)

    print("Load training data.")
    positives, negatives = load_train(TRAIN_RANK_PATH)
    return positives, negatives, queries_by_id, passages_by_id

In [2]:
# Load everything once at notebook level; main() reads these four names as globals
# rather than receiving them as arguments, so this cell must run before main().
train_pos_dict, train_neg_dict, query_dict, passage_dict = load()

Load embeddings.
Mapping ids.
Load training data.


## Network Architecture: network.py

In [10]:
import torch
import torch.nn as nn

NUM_HIDDEN_NODES = 1536   # width of the hidden layer inside the residual block
NUM_HIDDEN_LAYERS = 10    # how many times the (shared-weight) block is applied
DROPOUT_RATE = 0.1


# Define the network
class ResidualNet(torch.nn.Module):
    """Residual feed-forward encoder whose output has the same width as its input.

    One block (Linear -> ReLU -> LayerNorm -> Dropout -> Linear) is applied
    NUM_HIDDEN_LAYERS times. The same layer objects are reused on every pass,
    so the weights are shared across passes. After each pass the original
    input is added back and a ReLU is applied.
    """

    def __init__(self, embed_size):
        """Build the layers for inputs whose last dimension is `embed_size`."""
        super(ResidualNet, self).__init__()

        self.input = nn.Linear(embed_size, NUM_HIDDEN_NODES)
        self.relu = nn.ReLU()
        self.normlayer = nn.LayerNorm(NUM_HIDDEN_NODES)
        self.dropout = nn.Dropout(p=DROPOUT_RATE)
        # The output width must equal embed_size: forward() adds the input back
        # onto the output and feeds the result into self.input again. This used
        # to reference FEAT_COUNT, a name that is never defined.
        self.output = nn.Linear(NUM_HIDDEN_NODES, embed_size)

    def forward(self, x):
        """Return the encoded tensor; it has the same shape as `x`."""
        identity = x
        out = x
        for _ in range(NUM_HIDDEN_LAYERS):
            out = self.input(out)
            out = self.relu(out)
            out = self.normlayer(out)
            out = self.dropout(out)
            out = self.output(out)
            # Out-of-place add keeps autograd free of in-place modifications.
            out = out + identity
            out = self.relu(out)
        return out

    def parameter_count(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

## Train reverse ranker: train.py

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import torch.nn.functional as F


# mini_batch() turns a rank into a label as TOP_K - rank.
TOP_K = 100

# With probability ALPHA, mini_batch() pairs a query from train_pos_dict with a
# random negative query from train_neg_dict; otherwise both queries of the pair
# are drawn from train_pos_dict and ordered by their rank.
ALPHA = 0.5

def dot_product(A, B, normalize=False):
    """Row-wise dot product of two equally shaped 2-D tensors.

    Args:
        A, B: Tensors whose first dimension is the batch and second the embedding.
        normalize: When true, both inputs go through ``F.normalize`` first.

    Returns:
        Tensor of shape ``(batch, 1, 1)`` holding one dot product per row.
    """
    lhs, rhs = (F.normalize(A), F.normalize(B)) if normalize else (A, B)
    n_rows, dim = lhs.shape[0], lhs.shape[1]
    # (n_rows, 1, dim) @ (n_rows, dim, 1) -> (n_rows, 1, 1)
    return torch.bmm(lhs.view(n_rows, 1, dim), rhs.view(n_rows, dim, 1))


def mini_batch(batch_size, device, train_pos_dict, train_neg_dict, query_dict, passage_dict):
    """Sample `batch_size` (passage, preferred query, other query) training triples.

    For each randomly chosen passage:
      * with probability ALPHA the preferred query is a random positive and the
        other query is a random negative of that passage;
      * otherwise two distinct positives are drawn and the one with the smaller
        rank value becomes the preferred query. Passages with fewer than two
        positives are re-drawn in this case.

    Args:
        batch_size: Number of triples to produce.
        device: Torch device the returned tensors are moved to.
        train_pos_dict: ``{pid: {qid: rank}}`` for positive queries.
        train_neg_dict: ``{pid: {qid: rank}}`` for negative queries.
        query_dict: ``{qid: embedding}``.
        passage_dict: ``{pid: embedding}``.

    Returns:
        ``(passages, pos, neg, labels)`` on `device`. `labels` has shape
        ``(batch_size, 2)`` and holds ``TOP_K - rank`` for the preferred and the
        other query respectively.

    Raises:
        ValueError: If no passage has both positive and negative queries, which
            would otherwise make the sampling loop spin forever.
        KeyError: If a sampled id has no embedding in `passage_dict` / `query_dict`.
    """
    # Only passages with at least one positive AND one negative can be used.
    # Sampling uniformly from this list matches the earlier rejection sampling
    # (draw from the negatives, retry when no positives) without swallowing
    # unrelated errors in a bare `except`.
    passage_list = [pid for pid, negatives in train_neg_dict.items()
                    if negatives and train_pos_dict.get(pid)]
    if batch_size > 0 and not passage_list:
        raise ValueError("No passage has both positive and negative queries to sample from.")

    passages = []
    pos = []
    neg = []
    pos_rank_list = []
    neg_rank_list = []
    while len(passages) < batch_size:
        pid = random.choice(passage_list)
        temp_pos_list = list(train_pos_dict[pid].keys())
        temp_neg_list = list(train_neg_dict[pid].keys())
        if np.random.uniform(0, 1) <= ALPHA:
            # Positive vs. random negative of the same passage.
            pos_qid = random.choice(temp_pos_list)
            pos_rank = train_pos_dict[pid][pos_qid]
            neg_qid = random.choice(temp_neg_list)
            neg_rank = train_neg_dict[pid][neg_qid]
        else:
            # Positive vs. positive: needs two distinct queries to compare.
            if len(temp_pos_list) < 2:
                continue
            first_qid, second_qid = random.sample(temp_pos_list, 2)
            # The smaller rank value (e.g. 3 vs. 60) is the preferred query.
            if train_pos_dict[pid][first_qid] >= train_pos_dict[pid][second_qid]:
                pos_qid, neg_qid = second_qid, first_qid
            else:
                pos_qid, neg_qid = first_qid, second_qid
            pos_rank = train_pos_dict[pid][pos_qid]
            neg_rank = train_pos_dict[pid][neg_qid]
        p_seq = passage_dict[pid]
        pos_seq = query_dict[pos_qid]
        neg_seq = query_dict[neg_qid]
        passages.append(p_seq)
        pos.append(pos_seq)
        neg.append(neg_seq)
        pos_rank_list.append(TOP_K - pos_rank)
        neg_rank_list.append(TOP_K - neg_rank)
    labels = torch.stack([torch.FloatTensor(pos_rank_list), torch.FloatTensor(neg_rank_list)], dim=1)
    passages = torch.from_numpy(np.stack(passages))
    pos = torch.from_numpy(np.stack(pos))
    neg = torch.from_numpy(np.stack(neg))
    return passages.to(device), pos.to(device), neg.to(device), labels.to(device)


def train(net, epoch_size, batch_size, optimizer, device, train_pos_dict, train_neg_dict, 
          query_dict, passage_dict, scale=10, loss_option="ce"):
    """Run `epoch_size` optimisation steps and return the mean loss.

    Every step samples a fresh mini-batch with mini_batch(), encodes the passage
    and both queries with `net`, and scores each (passage, query) pair with
    dot_product().

    Args:
        net: Model applied to passages and queries alike.
        epoch_size: Number of mini-batches in this epoch.
        batch_size: Number of triples per mini-batch.
        optimizer: Optimizer stepping the parameters of `net`.
        device: Device the mini-batch tensors are placed on.
        train_pos_dict, train_neg_dict, query_dict, passage_dict: Forwarded to
            mini_batch().
        scale: Unused; kept so existing callers keep working.
        loss_option: ``"ce"`` for cross-entropy with the preferred query as
            class 0, or ``"bce"`` for binary cross-entropy between the softmaxed
            scores and the softmaxed rank labels.

    Returns:
        The loss averaged over the `epoch_size` mini-batches.

    Raises:
        ValueError: If `loss_option` is neither ``"ce"`` nor ``"bce"``.
    """
    if loss_option not in ("bce", "ce"):
        # Fail before any work is done instead of hitting an unbound `loss` later.
        raise ValueError("loss_option must be 'bce' or 'ce', got {!r}".format(loss_option))
    bce = nn.BCELoss()
    ce = nn.CrossEntropyLoss()
    softmax = nn.Softmax(dim=1)
    train_loss = 0.0
    net.train()
    for mb_idx in range(epoch_size):
        # Read in a new mini-batch of data!
        passages, pos, neg, labels = mini_batch(batch_size, device, train_pos_dict, train_neg_dict, 
                                                query_dict, passage_dict)
        optimizer.zero_grad()
        p_embed = net(passages)
        pos_embed = net(pos)
        neg_embed = net(neg)
        out_pos = dot_product(p_embed, pos_embed)
        out_neg = dot_product(p_embed, neg_embed)
        # (b, 1, 1) + (b, 1, 1) -> (b, 1, 2) -> (b, 2). Only the middle axis is
        # squeezed so a batch of size 1 keeps its batch dimension.
        out = torch.cat((out_pos, out_neg), -1).squeeze(1)
        if loss_option == "bce":
            loss = bce(softmax(out), softmax(labels))
        else:
            # Column 0 holds the preferred query, so the target class is always 0.
            target = torch.zeros(out.shape[0], dtype=torch.long, device=device)
            loss = ce(out, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # print(str(mb_idx) + " iteration: " + str(train_loss / (mb_idx + 1)))
    return train_loss / epoch_size


## Main function: main.py

In [16]:
import datetime
from datetime import datetime, timezone, timedelta

TIME_OFFSET = -4  # hours relative to UTC used for log timestamps


def print_message(s, offset=TIME_OFFSET):
    """Print `s` prefixed with the current time in a fixed-offset timezone.

    Args:
        s: Object to print; formatted with ``str.format``.
        offset: Hours from UTC for the timestamp's timezone.
    """
    log_zone = timezone(timedelta(hours=offset))
    stamp = datetime.now(log_zone).strftime("%b %d, %H:%M:%S")
    print("[{}] {}".format(stamp, s), flush=True)

In [20]:
import torch
from torch import optim
import csv
import sys
import os


# Directory that main() writes the checkpoint and the loss CSV into.
MODEL_PATH = "/datadrive/ruohan/random_sample/"
# Device string used by main() for the network and the mini-batches.
CURRENT_DEVICE = "cuda:2"

# exist_ok avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(MODEL_PATH, exist_ok=True)


def main(num_epochs, epoch_size, batch_size, learning_rate, model_path, embed_size, pretrained=False):
    """Train a ResidualNet and checkpoint it after every epoch.

    The training data is NOT loaded here: the function reads the notebook-level
    globals ``train_pos_dict``, ``train_neg_dict``, ``query_dict`` and
    ``passage_dict``, so the cell that calls ``load()`` must have run first.

    Args:
        num_epochs: Number of epochs to run.
        epoch_size: Mini-batches per epoch.
        batch_size: Triples per mini-batch.
        learning_rate: Adam learning rate.
        model_path: Directory receiving the ``.model`` checkpoint and ``.csv`` log.
        embed_size: Embedding width passed to ResidualNet.
        pretrained: When true, start from the hard-coded fine-tuned checkpoint.

    Side effects:
        Appends one ``[epoch, loss]`` row per epoch to the CSV log and overwrites
        the checkpoint (model + optimizer state) after every epoch.
    """
    net = ResidualNet(embed_size=embed_size)
    if pretrained:
        # map_location keeps loading independent of the device the file was saved from.
        state = torch.load(
            "/home/ruohan/DSSM/search-exposure/reverse_ranker/results/reverse_fine_tune1000_100_1000_0.0001_32.model",
            map_location="cpu")
        # Checkpoints written below wrap the weights as {"model": ..., "optimizer": ...};
        # accept that layout as well as a bare state_dict.
        if isinstance(state, dict) and "model" in state:
            state = state["model"]
        net.load_state_dict(state)
    net.to(CURRENT_DEVICE)
    print("Loading data")

    arg_str = "reverse_alpha0.5_initial_residual_saveoptim" + str(num_epochs) + "_" + str(epoch_size) + "_" + str(batch_size) + "_" + str(learning_rate) + "_" + str(
        embed_size)
    if model_path:
        # The module-level makedirs only covers MODEL_PATH, not other directories.
        os.makedirs(model_path, exist_ok=True)
    # os.path.join also works when model_path has no trailing separator.
    unique_path = os.path.join(model_path, arg_str + ".model")
    output_path = os.path.join(model_path, arg_str + ".csv")
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    for ep_idx in range(num_epochs):
        train_loss = train(net, epoch_size, batch_size, optimizer, CURRENT_DEVICE, train_pos_dict, 
                           train_neg_dict, query_dict, passage_dict)
        print_message([ep_idx,train_loss])
        # newline='' is what the csv module expects; it prevents blank rows on
        # platforms that translate line endings.
        with open(output_path, mode='a', newline='') as output:
            output_writer = csv.writer(output)
            output_writer.writerow([ep_idx, train_loss])
        torch.save({
                    "model": net.state_dict(),
                    "optimizer": optimizer.state_dict()
                    }, unique_path)

In [22]:
# checkpoint = torch.load("/datadrive/ruohan/random_sample/reverse_alpha0.5_initial_residual_saveoptim1000_100_1000_0.0001_768.model")
# ttnet = ResidualNet(embed_size=768)
# ttoptimizer = optim.Adam(ttnet.parameters(), lr=0.0001)
# ttnet.load_state_dict(checkpoint["model"])
# ttoptimizer.load_state_dict(checkpoint["optimizer"])

In [18]:
# Train on randomly sampled training data that contains 200,000 passages
# Positional arguments, in order: num_epochs=1000, epoch_size=100, batch_size=1000,
# learning_rate=0.0001, model_path=MODEL_PATH, embed_size=768.
main(1000,100,1000,0.0001,MODEL_PATH,768)

Loading data
[Aug 21, 22:27:49] [0, 1.049621695280075]
[Aug 21, 22:28:02] [1, 0.7177173060178756]
[Aug 21, 22:28:14] [2, 0.6477977627515793]
[Aug 21, 22:28:26] [3, 0.6193814289569854]
[Aug 21, 22:28:39] [4, 0.5911640566587448]
[Aug 21, 22:28:51] [5, 0.5795909243822098]
[Aug 21, 22:29:03] [6, 0.5672692477703094]
[Aug 21, 22:29:15] [7, 0.5547531509399414]
[Aug 21, 22:29:28] [8, 0.5422469416260719]
[Aug 21, 22:29:40] [9, 0.5309646573662757]
[Aug 21, 22:29:53] [10, 0.5194277182221413]
[Aug 21, 22:30:06] [11, 0.5122983610630035]
[Aug 21, 22:30:18] [12, 0.4996506205201149]
[Aug 21, 22:30:31] [13, 0.49126205772161485]
[Aug 21, 22:30:44] [14, 0.4788512253761292]
[Aug 21, 22:30:57] [15, 0.47239673852920533]
[Aug 21, 22:31:10] [16, 0.45934064537286756]
[Aug 21, 22:31:22] [17, 0.44842516869306565]
[Aug 21, 22:31:35] [18, 0.4428005975484848]
[Aug 21, 22:31:47] [19, 0.4365400493144989]
[Aug 21, 22:32:00] [20, 0.4327924346923828]
[Aug 21, 22:32:12] [21, 0.4263707169890404]
[Aug 21, 22:32:25] [22, 0.

[Aug 21, 23:05:35] [183, 0.2619546756148338]
[Aug 21, 23:05:48] [184, 0.2589871872961521]
[Aug 21, 23:06:00] [185, 0.26194216147065164]
[Aug 21, 23:06:12] [186, 0.2586447003483772]
[Aug 21, 23:06:25] [187, 0.2595504388213158]
[Aug 21, 23:06:38] [188, 0.26039090543985366]
[Aug 21, 23:06:50] [189, 0.25785535082221034]
[Aug 21, 23:07:04] [190, 0.25605717569589614]
[Aug 21, 23:07:16] [191, 0.2570407317578793]
[Aug 21, 23:07:29] [192, 0.2567103365063667]
[Aug 21, 23:07:42] [193, 0.2581514598429203]
[Aug 21, 23:07:54] [194, 0.25586354956030843]
[Aug 21, 23:08:07] [195, 0.2583950892090797]
[Aug 21, 23:08:19] [196, 0.25590036883950235]
[Aug 21, 23:08:32] [197, 0.256178684681654]
[Aug 21, 23:08:45] [198, 0.25642402544617654]
[Aug 21, 23:08:57] [199, 0.25481263399124143]
[Aug 21, 23:09:09] [200, 0.2542850385606289]
[Aug 21, 23:09:22] [201, 0.25223878398537636]
[Aug 21, 23:09:35] [202, 0.2521707198023796]
[Aug 21, 23:09:47] [203, 0.25471794575452805]
[Aug 21, 23:10:00] [204, 0.25248051032423974]


[Aug 21, 23:44:19] [363, 0.21622635141015054]
[Aug 21, 23:44:32] [364, 0.2176140959560871]
[Aug 21, 23:44:45] [365, 0.21654012456536292]
[Aug 21, 23:44:58] [366, 0.21809410020709039]
[Aug 21, 23:45:11] [367, 0.21592803567647934]
[Aug 21, 23:45:24] [368, 0.21793582245707513]
[Aug 21, 23:45:37] [369, 0.2159090718626976]
[Aug 21, 23:45:50] [370, 0.21240109488368034]
[Aug 21, 23:46:03] [371, 0.21402581334114074]
[Aug 21, 23:46:16] [372, 0.2155788266658783]
[Aug 21, 23:46:30] [373, 0.2145340174436569]
[Aug 21, 23:46:43] [374, 0.2133140343427658]
[Aug 21, 23:46:56] [375, 0.21598318442702294]
[Aug 21, 23:47:09] [376, 0.21628866508603095]
[Aug 21, 23:47:22] [377, 0.21487857267260552]
[Aug 21, 23:47:35] [378, 0.21628852397203446]
[Aug 21, 23:47:49] [379, 0.2154451847076416]
[Aug 21, 23:48:02] [380, 0.2136365845799446]
[Aug 21, 23:48:15] [381, 0.2157372437417507]
[Aug 21, 23:48:27] [382, 0.21391580179333686]
[Aug 21, 23:48:41] [383, 0.21455946043133736]
[Aug 21, 23:48:54] [384, 0.214136885255575

[Aug 22, 00:22:51] [543, 0.19448660150170327]
[Aug 22, 00:23:04] [544, 0.1966087044775486]
[Aug 22, 00:23:16] [545, 0.1974475434422493]
[Aug 22, 00:23:28] [546, 0.19710069492459298]
[Aug 22, 00:23:40] [547, 0.19502707958221435]
[Aug 22, 00:23:53] [548, 0.19229384645819664]
[Aug 22, 00:24:05] [549, 0.19603768050670622]
[Aug 22, 00:24:17] [550, 0.19597763001918792]
[Aug 22, 00:24:29] [551, 0.1960507233440876]
[Aug 22, 00:24:41] [552, 0.19354464188218118]
[Aug 22, 00:24:53] [553, 0.1959887807071209]
[Aug 22, 00:25:06] [554, 0.1939886762201786]
[Aug 22, 00:25:18] [555, 0.1978562805056572]
[Aug 22, 00:25:30] [556, 0.19618013754487038]
[Aug 22, 00:25:42] [557, 0.1953492605686188]
[Aug 22, 00:25:54] [558, 0.19435726955533028]
[Aug 22, 00:26:06] [559, 0.19382868304848672]
[Aug 22, 00:26:19] [560, 0.1972312794625759]
[Aug 22, 00:26:31] [561, 0.1950923989713192]
[Aug 22, 00:26:43] [562, 0.19465900510549544]
[Aug 22, 00:26:55] [563, 0.1975063343346119]
[Aug 22, 00:27:07] [564, 0.19462200239300728

[Aug 22, 00:59:27] [723, 0.1812646695971489]
[Aug 22, 00:59:40] [724, 0.1851092404127121]
[Aug 22, 00:59:52] [725, 0.18226028636097907]
[Aug 22, 01:00:04] [726, 0.18306563496589662]
[Aug 22, 01:00:16] [727, 0.1830461974442005]
[Aug 22, 01:00:28] [728, 0.1853538866341114]
[Aug 22, 01:00:41] [729, 0.1857893744111061]
[Aug 22, 01:00:53] [730, 0.1808656147122383]
[Aug 22, 01:01:05] [731, 0.18354333758354188]
[Aug 22, 01:01:17] [732, 0.18055806949734687]
[Aug 22, 01:01:29] [733, 0.1834531469643116]
[Aug 22, 01:01:42] [734, 0.18320471972227095]
[Aug 22, 01:01:54] [735, 0.18234472945332528]
[Aug 22, 01:02:06] [736, 0.18238206788897515]
[Aug 22, 01:02:18] [737, 0.1823534832894802]
[Aug 22, 01:02:30] [738, 0.18323394134640694]
[Aug 22, 01:02:43] [739, 0.183607407361269]
[Aug 22, 01:02:56] [740, 0.18312006443738937]
[Aug 22, 01:03:09] [741, 0.18349851548671722]
[Aug 22, 01:03:23] [742, 0.183738631606102]
[Aug 22, 01:03:36] [743, 0.1811194723844528]
[Aug 22, 01:03:49] [744, 0.18268521651625633]
[

[Aug 22, 01:36:06] [903, 0.17358581110835075]
[Aug 22, 01:36:19] [904, 0.17194602742791176]
[Aug 22, 01:36:31] [905, 0.17476639807224273]
[Aug 22, 01:36:43] [906, 0.1750466237962246]
[Aug 22, 01:36:55] [907, 0.17347493886947632]
[Aug 22, 01:37:07] [908, 0.1766258805990219]
[Aug 22, 01:37:20] [909, 0.17315524220466613]
[Aug 22, 01:37:32] [910, 0.1769346049427986]
[Aug 22, 01:37:44] [911, 0.17230464026331901]
[Aug 22, 01:37:56] [912, 0.17531048834323884]
[Aug 22, 01:38:08] [913, 0.17434409514069557]
[Aug 22, 01:38:20] [914, 0.17454180911183356]
[Aug 22, 01:38:33] [915, 0.17618634641170502]
[Aug 22, 01:38:45] [916, 0.17166389867663384]
[Aug 22, 01:38:57] [917, 0.1725667491555214]
[Aug 22, 01:39:09] [918, 0.17386890232563018]
[Aug 22, 01:39:21] [919, 0.17415894210338592]
[Aug 22, 01:39:33] [920, 0.1749108675122261]
[Aug 22, 01:39:46] [921, 0.17518358469009399]
[Aug 22, 01:39:58] [922, 0.1738073581457138]
[Aug 22, 01:40:10] [923, 0.17262760147452355]
[Aug 22, 01:40:22] [924, 0.1742000174522