In [None]:
# 1. Load data: passage embeddings, training-query embeddings, and a subset of the training rank data.
# 2. Network: input dimension 768, output dimension 32.
# 3. Train: concatenate the original 768-d embedding with the learned 32-d output.

## Import data: load_data.py

In [147]:
import csv
import pickle
import numpy as np

# Pickled inputs read with obj_reader() inside load(). The *_NP_PATH files are
# presumably the embedding arrays and the *_MAP_PATH files the id maps whose
# values become the dictionary keys in map_id() -- TODO confirm the file contents.
# NOTE(review): these are absolute, machine-specific paths.
PASSAGE_NP_PATH = "/home/jianx/results/passage_0__emb_p__data_obj_0.pb"
PASSAGE_MAP_PATH = "/datadrive/jianx/data/annoy/100_ance_passage_map.dict"
QUERY_TRAIN_NP_PATH = "/home/jianx/results/query_0__emb_p__data_obj_0.pb"
QUERY_MAP_PATH = "/datadrive/jianx/data/annoy/100_ance_query_train_map.dict"
# CSV consumed by load_train(); each line is parsed as "pid,qid,rank".
TRAIN_RANK_PATH = "/datadrive/jianx/data/train_data/ance_training_rank100_8841823.csv"

# Value load_train() stores for (pid, qid) pairs whose rank column is 0.
OUT_RANK = 200
# load_train() stops after reading N_PASSAGE * 100 lines of the CSV.
N_PASSAGE = 200000
def obj_reader(path):
    """Unpickle and return the object stored in the file at ``path``.

    The file is opened in binary mode and decoded with ``encoding="bytes"``.
    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as stream:
        loaded = pickle.load(stream, encoding="bytes")
    return loaded
def load_train(path):
    """Parse the training rank CSV at ``path`` into positive and negative lookups.

    Each line is read as ``pid,qid,rank`` (all integers). At most
    ``N_PASSAGE * 100`` lines are consumed. A pair whose rank is 0 is stored
    in the negative dict with the value ``OUT_RANK``; every other pair is
    stored in the positive dict with its own rank.

    Returns:
        ``(pos_dict, neg_dict)``, each shaped ``{pid: {qid: rank}}``.
    """
    line_limit = N_PASSAGE * 100
    pos_dict, neg_dict = {}, {}
    with open(path, "r") as file:
        for line_no, line in enumerate(file):
            if line_no >= line_limit:
                break
            fields = line.split(",")
            pid, qid = int(fields[0]), int(fields[1])
            rank = int(fields[2].rstrip())
            if rank == 0:
                # Rank 0 marks a negative pair; it gets the fixed OUT_RANK value.
                neg_dict.setdefault(pid, {})[qid] = OUT_RANK
            else:
                pos_dict.setdefault(pid, {})[qid] = rank
    return pos_dict, neg_dict
def map_id(old_np, mapping):
    """Re-key the items of ``old_np`` by the values of ``mapping``.

    The i-th value of ``mapping`` (in the dict's iteration order) becomes the
    key of the i-th item of ``old_np``. Pairing stops at the shorter of the
    two, and a repeated key keeps the last item paired with it.
    """
    return {new_id: row for new_id, row in zip(mapping.values(), old_np)}
def load():
    """Read embeddings, id maps and training ranks from the configured paths.

    A progress message is printed before each stage: reading the four
    pickled files, re-keying the arrays with map_id(), and parsing the
    training CSV with load_train().

    Returns:
        ``(train_pos_dict, train_neg_dict, query_dict, passage_dict)``.
    """
    print("Load embeddings.")
    passage_rows, passage_ids = [obj_reader(p) for p in (PASSAGE_NP_PATH, PASSAGE_MAP_PATH)]
    query_rows, query_ids = [obj_reader(p) for p in (QUERY_TRAIN_NP_PATH, QUERY_MAP_PATH)]

    print("Mapping ids.")
    query_dict = map_id(query_rows, query_ids)
    passage_dict = map_id(passage_rows, passage_ids)

    print("Load training data.")
    train_pos, train_neg = load_train(TRAIN_RANK_PATH)
    return train_pos, train_neg, query_dict, passage_dict

In [148]:
# Rebuild only the positive/negative rank dicts from the training CSV; the embeddings are not reloaded here.
# NOTE(review): execution count 148 is later than the load() cell below (In [2]), so this cell ran out of order.
train_pos_dict, train_neg_dict = load_train(TRAIN_RANK_PATH)

In [2]:
# Populate the notebook-level dicts; main() reads these four names directly instead of taking them as arguments.
train_pos_dict, train_neg_dict, query_dict, passage_dict = load()

Load embeddings.
Mapping ids.
Load training data.


## Network Architecture: network.py

In [63]:
import torch
import torch.nn as nn

# Output width of every hidden Linear layer (and size of each LayerNorm) in CorpusNet.
NUM_HIDDEN_NODES = 64
# Number of Linear -> ReLU -> LayerNorm -> Dropout stacks placed before the output layer.
NUM_HIDDEN_LAYERS = 3
# Dropout probability used after each hidden stack.
DROPOUT_RATE = 0.1
# Input width of the first Linear layer.
FEAT_COUNT = 768


# Define the network
class CorpusNet(torch.nn.Module):
    """Feed-forward projection from ``FEAT_COUNT`` features to ``embed_size``.

    The network is ``NUM_HIDDEN_LAYERS`` repetitions of
    Linear -> ReLU -> LayerNorm -> Dropout (each ``NUM_HIDDEN_NODES`` wide)
    followed by a final Linear layer. All layers live in ``self.model``.
    """

    def __init__(self, embed_size):
        super().__init__()

        stack = []
        in_features = FEAT_COUNT
        for _ in range(NUM_HIDDEN_LAYERS):
            stack += [
                nn.Linear(in_features, NUM_HIDDEN_NODES),
                nn.ReLU(),
                nn.LayerNorm(NUM_HIDDEN_NODES),
                nn.Dropout(p=DROPOUT_RATE),
            ]
            in_features = NUM_HIDDEN_NODES
        # Output projection to the requested embedding size.
        stack.append(nn.Linear(in_features, embed_size))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.model(x)

    def parameter_count(self):
        """Return the total number of elements across trainable parameters."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total


## Train reverse ranker: train.py

In [152]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import torch.nn.functional as F


# mini_batch() converts a rank into a label as TOP_K - rank.
TOP_K = 100

# With probability ALPHA, mini_batch() draws the negative query from train_neg_dict.
# Otherwise both queries are drawn from train_pos_dict and the one with the
# larger rank value is used as the negative.
ALPHA = 0.5

def dot_product(A, B, normalize=False):
    """Row-wise dot product of two 2-D tensors of matching shape.

    Args:
        A, B: tensors viewed as ``(A.shape[0], A.shape[1])``.
        normalize: when true, both inputs are passed through ``F.normalize``
            first.

    Returns:
        Tensor of shape ``(A.shape[0], 1, 1)`` holding one product per row.
    """
    if normalize:
        A, B = F.normalize(A), F.normalize(B)
    batch, width = A.shape[0], A.shape[1]
    # (batch, 1, width) @ (batch, width, 1) -> (batch, 1, 1)
    row_vectors = A.view(batch, 1, width)
    col_vectors = B.view(batch, width, 1)
    return torch.bmm(row_vectors, col_vectors)


def mini_batch(batch_size, device, train_pos_dict, train_neg_dict, query_dict, passage_dict):
    """Sample one batch of (passage, positive query, negative query) triples.

    A passage is drawn uniformly from those that have at least one positive
    and one negative query. Then, with probability ``ALPHA``, the positive
    query is drawn from ``train_pos_dict[pid]`` and the negative query from
    ``train_neg_dict[pid]``. Otherwise two distinct positive queries are
    drawn and the one with the larger rank value is used as the negative;
    passages with fewer than two positives are re-drawn in that case.

    Args:
        batch_size: number of triples to return.
        device: device the returned tensors are moved to.
        train_pos_dict: ``{pid: {qid: rank}}`` for positive pairs.
        train_neg_dict: ``{pid: {qid: rank}}`` for negative pairs.
        query_dict: ``{qid: embedding}`` (NumPy arrays).
        passage_dict: ``{pid: embedding}`` (NumPy arrays).

    Returns:
        ``(passages, pos, neg, labels)``. The first three are the stacked
        embeddings; ``labels`` has shape ``(batch_size, 2)`` with columns
        ``TOP_K - pos_rank`` and ``TOP_K - neg_rank``.

    Raises:
        ValueError: if no passage has both positive and negative queries.
        KeyError: if a sampled id is missing from ``passage_dict`` or
            ``query_dict``.
    """
    # Candidate passages come from the arguments, not from a notebook global,
    # so the function works on a fresh kernel.
    passage_list = [pid for pid, ranked in train_pos_dict.items()
                    if ranked and train_neg_dict.get(pid)]
    if not passage_list:
        raise ValueError("no passage has both positive and negative queries")

    passages, pos, neg = [], [], []
    pos_rank_list, neg_rank_list = [], []
    while len(passages) < batch_size:
        pid = random.choice(passage_list)
        pos_ranks = train_pos_dict[pid]
        neg_ranks = train_neg_dict[pid]
        if np.random.uniform(0, 1) <= ALPHA:
            # Positive from the positive pool, negative from the negative pool.
            pos_qid = random.choice(list(pos_ranks))
            neg_qid = random.choice(list(neg_ranks))
            pos_rank = pos_ranks[pos_qid]
            neg_rank = neg_ranks[neg_qid]
        else:
            # Both queries from the positive pool; this needs two distinct ones.
            if len(pos_ranks) < 2:
                continue
            first, second = random.sample(list(pos_ranks), 2)
            # The smaller rank value is the better match (e.g. 3 beats 60).
            if pos_ranks[first] >= pos_ranks[second]:
                pos_qid, neg_qid = second, first
            else:
                pos_qid, neg_qid = first, second
            pos_rank = pos_ranks[pos_qid]
            neg_rank = pos_ranks[neg_qid]
        passages.append(passage_dict[pid])
        pos.append(query_dict[pos_qid])
        neg.append(query_dict[neg_qid])
        pos_rank_list.append(TOP_K - pos_rank)
        neg_rank_list.append(TOP_K - neg_rank)

    labels = torch.stack([torch.FloatTensor(pos_rank_list), torch.FloatTensor(neg_rank_list)], dim=1)
    passages = torch.from_numpy(np.stack(passages))
    pos = torch.from_numpy(np.stack(pos))
    neg = torch.from_numpy(np.stack(neg))
    return passages.to(device), pos.to(device), neg.to(device), labels.to(device)


def train(net, epoch_size, batch_size, optimizer, device, train_pos_dict, train_neg_dict,
          query_dict, passage_dict, scale=10, loss_option="ce"):
    """Run ``epoch_size`` optimisation steps and return the mean loss.

    Each step samples a batch with mini_batch(), concatenates the network
    output with the raw embedding for the passage, the positive query and
    the negative query, and scores the two queries against the passage with
    dot_product().

    Args:
        net: model applied to passages and queries.
        epoch_size: number of mini-batches to process.
        batch_size: number of triples per mini-batch.
        optimizer: optimizer stepping ``net``'s parameters.
        device: device passed to mini_batch().
        train_pos_dict, train_neg_dict, query_dict, passage_dict: forwarded
            to mini_batch().
        scale: unused; kept so existing callers keep working.
        loss_option: ``"ce"`` for cross-entropy with the positive query as
            class 0, or ``"bce"`` for binary cross-entropy between the
            softmaxed scores and the softmaxed rank labels.

    Returns:
        Sum of the per-step losses divided by ``epoch_size``.

    Raises:
        ValueError: if ``loss_option`` is neither ``"ce"`` nor ``"bce"``.
        ZeroDivisionError: if ``epoch_size`` is 0.
    """
    if loss_option not in ("bce", "ce"):
        raise ValueError("loss_option must be 'bce' or 'ce', got %r" % (loss_option,))
    bce = nn.BCELoss()
    ce = nn.CrossEntropyLoss()
    softmax = nn.Softmax(dim=1)
    train_loss = 0.0
    net.train()
    for _ in range(epoch_size):
        # Read in a new mini-batch of data!
        passages, pos, neg, labels = mini_batch(batch_size, device, train_pos_dict, train_neg_dict,
                                                query_dict, passage_dict)
        optimizer.zero_grad()
        p_embed = torch.cat((net(passages), passages), 1)
        pos_embed = torch.cat((net(pos), pos), 1)
        neg_embed = torch.cat((net(neg), neg), 1)
        out_pos = dot_product(p_embed, pos_embed)
        out_neg = dot_product(p_embed, neg_embed)
        # Each score is (batch, 1, 1). Squeeze only dim 1 so the result stays
        # (batch, 2) even when the batch holds a single triple.
        out = torch.cat((out_pos, out_neg), -1).squeeze(1)
        if loss_option == "bce":
            loss = bce(softmax(out), softmax(labels))
        else:
            # Column 0 holds the positive score, so the target class is always 0.
            target = torch.zeros(out.shape[0], dtype=torch.long, device=out.device)
            loss = ce(out, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    return train_loss / epoch_size


## Main function: main.py

In [159]:
import torch
from torch import optim
import csv
import sys
import os


# Directory where main() writes model checkpoints and the per-epoch loss CSV.
MODEL_PATH = "./results/"
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA.
CURRENT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# exist_ok=True avoids the race between checking for the directory and creating it.
os.makedirs(MODEL_PATH, exist_ok=True)


def main(num_epochs, epoch_size, batch_size, learning_rate, model_path, embed_size, pretrained=True,
         pretrained_path=None):
    """Train a CorpusNet, logging the loss and saving a checkpoint after every epoch.

    The training data is not passed in: ``train_pos_dict``, ``train_neg_dict``,
    ``query_dict`` and ``passage_dict`` are read from the notebook's global
    namespace and must be loaded before calling this function.

    Args:
        num_epochs: number of epochs to run.
        epoch_size: mini-batches per epoch, forwarded to train().
        batch_size: triples per mini-batch, forwarded to train().
        learning_rate: Adam learning rate.
        model_path: prefix prepended to the output file names, so a directory
            must end with a path separator.
        embed_size: output size of the CorpusNet.
        pretrained: when true, initialise the network from a checkpoint.
        pretrained_path: checkpoint to load when ``pretrained`` is true;
            defaults to the previously hard-coded checkpoint location.

    Side effects:
        Appends ``[epoch, loss]`` rows to ``<model_path><args>.csv`` and
        overwrites ``<model_path><args>.model`` after every epoch.
    """
    net = CorpusNet(embed_size=embed_size)
    if pretrained:
        if pretrained_path is None:
            pretrained_path = "/home/ruohan/DSSM/search-exposure/reverse_ranker/results/reverse_corpus_features_not_normalize1000_100_1000_0.0001_32.model"
        # map_location lets a checkpoint saved on another device load onto CURRENT_DEVICE.
        net.load_state_dict(torch.load(pretrained_path, map_location=CURRENT_DEVICE))
    net.to(CURRENT_DEVICE)
    print("Loading data")

    arg_str = "reverse_fine_tune" + "_".join(
        str(value) for value in (num_epochs, epoch_size, batch_size, learning_rate, embed_size))
    unique_path = model_path + arg_str + ".model"
    output_path = model_path + arg_str + ".csv"

    # Build the optimizer once: re-creating Adam every epoch would discard its
    # running moment estimates at each epoch boundary.
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    for ep_idx in range(num_epochs):
        train_loss = train(net, epoch_size, batch_size, optimizer, CURRENT_DEVICE, train_pos_dict,
                           train_neg_dict, query_dict, passage_dict)
        print(ep_idx,train_loss)
        # newline='' is what the csv module expects for files it writes to.
        with open(output_path, mode='a+', newline='') as output:
            csv.writer(output).writerow([ep_idx, train_loss])
        torch.save(net.state_dict(), unique_path)


In [160]:
# Continue training from the pretrained checkpoint (main() defaults to pretrained=True)
# to see whether the loss decreases further.
main(1000,100,1000,0.0001,MODEL_PATH,32)

Loading data
0 0.09242286026477814
1 0.09362810656428337
2 0.09357407428324223
3 0.09244837991893291
4 0.0922143154591322
5 0.09312672398984433
6 0.09255419731140137
7 0.09330147758126259
8 0.09183945842087268
9 0.0923061215877533
10 0.09331501424312591
11 0.09389676533639431
12 0.09220728158950806
13 0.0934059551358223
14 0.09242874532938003
15 0.08936845920979977
16 0.09358229465782643
17 0.0932971765100956
18 0.09298205085098743
19 0.08983258441090584
20 0.09362920552492142
21 0.09388238161802293
22 0.09274050042033195
23 0.09257986746728421
24 0.09082861796021462
25 0.09358493655920029
26 0.0940331619232893
27 0.09469860963523388
28 0.09065874092280865
29 0.09436166360974312
30 0.090460539534688
31 0.09320399314165115
32 0.0922974904626608
33 0.09338326439261437
34 0.09289215430617333
35 0.09165102533996106
36 0.09247908197343349
37 0.09109776228666305
38 0.09297090299427509
39 0.09061546042561532
40 0.09364765845239162
41 0.09124253802001477
42 0.0920777302980423
43 0.091157613843

348 0.08840005818754434
349 0.08758257932960987
350 0.08764669790863991
351 0.08852054990828037
352 0.087467403113842
353 0.08855969846248626
354 0.0873168183863163
355 0.08884152792394161
356 0.08982530929148197
357 0.08833487659692764
358 0.08696822308003903
359 0.08612955443561077
360 0.08623002678155899
361 0.08688206970691681
362 0.08616493996232748
363 0.08994153238832951
364 0.08651953391730785
365 0.08965283393859863
366 0.08663481146097184
367 0.08650971673429013
368 0.08885869033634662
369 0.08760820113122464
370 0.08772467516362667
371 0.0860393825173378
372 0.08702977031469344
373 0.08879687972366809
374 0.08629834465682507
375 0.08494268119335174
376 0.08778778035193682
377 0.08750556483864784
378 0.08823110021650792
379 0.08798402734100819
380 0.08940647959709168
381 0.09006117597222328
382 0.0882234325632453
383 0.08694288432598114
384 0.08867004103958606
385 0.08724404513835907
386 0.0855836496502161
387 0.08738089181482792
388 0.08809353031218052
389 0.0878038528561592

692 0.08593089610338212
693 0.08439540445804596
694 0.08479158908128738
695 0.08597571238875389
696 0.08609658271074296
697 0.08410146668553352
698 0.08507859349250793
699 0.08612381160259247
700 0.08458093494176865
701 0.08359806101769209
702 0.08380773246288299
703 0.08342213846743107
704 0.0840354061126709
705 0.0852862548828125
706 0.08400484375655651
707 0.0832472226768732
708 0.08289869219064712
709 0.08572328556329012
710 0.08340738736093044
711 0.0842294618114829
712 0.08411695286631585
713 0.0826189986616373
714 0.0855903384834528
715 0.08457422606647015
716 0.08501401081681252
717 0.08355985596776008
718 0.08488374285399913
719 0.08419502928853034
720 0.08384193021804094
721 0.0843980248272419
722 0.08293031550943851
723 0.086374884955585
724 0.08400150552392006
725 0.0854254986345768
726 0.08551552921533584
727 0.08387076668441296
728 0.0857486842572689
729 0.08460838854312897
730 0.08115911364555359
731 0.0844218560308218
732 0.08391157917678356
733 0.08418139293789864
734 

In [154]:
# NOTE(review): execution count 154 is earlier than the In [159] cell that defines main() above,
# so the output below presumably came from a previous definition of main() -- confirm before comparing runs.
main(1000,100,1000,0.0001,MODEL_PATH,32)

Loading data
0 0.56778594404459
1 0.5361327835917473
2 0.513755610883236
3 0.4606933206319809
4 0.4318121728301048
5 0.39857877910137174
6 0.38456533759832384
7 0.3615595731139183
8 0.35005480200052264
9 0.3339472904801369
10 0.3274097794294357
11 0.31392027348279955
12 0.3061116561293602
13 0.2963147297501564
14 0.28868519350886346
15 0.28368862375617027
16 0.2772076211869717
17 0.27109444692730905
18 0.2638201983273029
19 0.2643818971514702
20 0.2585298290848732
21 0.25518435552716257
22 0.24759986594319344
23 0.24508200466632843
24 0.2442743219435215
25 0.23841279581189156
26 0.23553462833166122
27 0.23538652062416077
28 0.23230824798345565
29 0.22786283761262893
30 0.2267448529601097
31 0.2245577844977379
32 0.222188633531332
33 0.21788629859685898
34 0.21683309748768806
35 0.21622372403740883
36 0.21220046132802964
37 0.2103882522881031
38 0.21280795738101005
39 0.2088419534265995
40 0.20935980409383773
41 0.2038043324649334
42 0.20264502257108688
43 0.2026878772675991
44 0.200514

350 0.11572491951286792
351 0.1157834180444479
352 0.11541165180504322
353 0.11835237592458725
354 0.11572006590664387
355 0.11635373078286648
356 0.11516583666205406
357 0.11767843320965767
358 0.1159132120013237
359 0.11479521587491036
360 0.11769385948777199
361 0.11547223642468453
362 0.11416971199214458
363 0.11566571094095707
364 0.11675478801131249
365 0.11468841753900051
366 0.11626176193356513
367 0.11194879800081253
368 0.11638198189437389
369 0.11616038069128991
370 0.11652151018381118
371 0.11586239628493786
372 0.11611321739852429
373 0.11362403310835362
374 0.11323442600667477
375 0.11348753213882447
376 0.11244064010679722
377 0.11321325078606606
378 0.11660580858588218
379 0.11450220353901386
380 0.11158235982060433
381 0.11370470486581326
382 0.112759865000844
383 0.11527853228151798
384 0.11476400703191757
385 0.11305238097906113
386 0.11304062373936176
387 0.11538522511720657
388 0.11602147392928601
389 0.11523488625884055
390 0.11493959866464137
391 0.11406769484281

694 0.10067254401743413
695 0.09762756079435349
696 0.10070199310779572
697 0.09989629112184048
698 0.1007743313908577
699 0.1019477216899395
700 0.09909678183495998
701 0.09973478391766548
702 0.09986857138574123
703 0.0987181182205677
704 0.10089762762188911
705 0.10176450788974761
706 0.09881762765347958
707 0.09976345762610435
708 0.0985382516682148
709 0.0997630337625742
710 0.09871149957180023
711 0.1007556925714016
712 0.09974215231835842
713 0.10064550548791885
714 0.09863779850304127
715 0.09799405805766583
716 0.0999866197258234
717 0.09994803309440613
718 0.10101903811097145
719 0.09954906791448594
720 0.10006252840161324
721 0.10046975664794446
722 0.09902587212622166
723 0.09939764641225338
724 0.09887002363801002
725 0.09785522900521755
726 0.09995790056884289
727 0.09976165272295474
728 0.09825257569551468
729 0.09849015414714814
730 0.09816409789025783
731 0.10063843853771687
732 0.10002238862216473
733 0.09941036738455296
734 0.09705572806298733
735 0.0995917834341526
