In [None]:
# 1. Load data: passage embeddings, training-query embeddings, and a subset of the training data.
# 2. Network: input dimension 768, output dimension 32.
# 3. Train: concatenate the original 768-dim embedding with the 32-dim network output.

## Load data: load_data.py

In [59]:
import csv
import pickle
import numpy as np

# --- Embedding inputs: pickled objects read with obj_reader() inside load() ---
PASSAGE_NP_PATH = "/home/jianx/results/passage_0__emb_p__data_obj_0.pb"
PASSAGE_MAP_PATH = "/datadrive/jianx/data/annoy/100_ance_passage_map.dict"
QUERY_TRAIN_NP_PATH = "/home/jianx/results/query_0__emb_p__data_obj_0.pb"
QUERY_MAP_PATH = "/datadrive/jianx/data/annoy/100_ance_query_train_map.dict"

# --- Training pairs: comma-separated "pid,qid,rank" rows parsed by load_train() ---
# Earlier training files, kept for provenance:
#   /datadrive/jianx/data/train_data/ance_training_rank100_nqueries250000_200000_random_Aug_24_19:43:52.csv
#   /datadrive/ruohan/reverse_ranker/new_training/combine_rank_train_phase2.csv
#   /datadrive/jianx/data/train_data/ance_training_rank100_nqueries50000_200000_Aug_26_20:59:48.csv
# Current file: new one-step random sampling.
TRAIN_RANK_PATH = "/datadrive/jianx/data/train_data/ance_training_rank100_nqueries50000_200000_Sep_03_22:56:31.csv"

# Rank that load_train() stores for negative (rank == 0) rows.
OUT_RANK = 200
# Number of passages. NOTE(review): not referenced by any active code in this section.
N_PASSAGE = 200000
def obj_reader(path):
    """Unpickle and return the single object stored in the file at ``path``.

    ``encoding="bytes"`` makes 8-bit strings from Python 2 pickles load as ``bytes``.
    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as stream:
        obj = pickle.load(stream, encoding="bytes")
    return obj
def load_train(path, out_rank=None):
    """Parse a "pid,qid,rank" training file into positive and negative lookups.

    Each non-blank line holds at least three comma-separated integers: pid, qid, rank.
    Rows with rank 0 are negatives and are stored as ``neg_dict[pid][qid] = out_rank``.
    Every other row is a positive and is stored as ``pos_dict[pid][qid] = rank``.

    Args:
        path: path of the training file.
        out_rank: rank recorded for negative rows; defaults to the module-level OUT_RANK.

    Returns:
        (pos_dict, neg_dict): both nested dicts of the form {pid: {qid: rank}}.

    Raises:
        ValueError: if a non-blank line has a field that is not an integer.
        IndexError: if a non-blank line has fewer than three fields.
    """
    if out_rank is None:
        out_rank = OUT_RANK
    pos_dict = {}
    neg_dict = {}
    with open(path, "r") as file:
        for line in file:
            line = line.strip()
            if not line:
                # Skip blank lines instead of crashing on int("").
                continue
            tokens = line.split(",")
            pid = int(tokens[0])
            qid = int(tokens[1])
            rank = int(tokens[2])
            if rank == 0:
                neg_dict.setdefault(pid, {})[qid] = out_rank
            else:
                pos_dict.setdefault(pid, {})[qid] = rank
    return pos_dict, neg_dict
def map_id(old_np, mapping):
    """Key the rows of ``old_np`` by the values of ``mapping``.

    The i-th value of ``mapping`` (dict iteration order) becomes the key of row i of
    ``old_np``. ``zip`` stops at the shorter input, so surplus rows or surplus mapping
    entries are silently dropped. A repeated key keeps its last row.

    NOTE(review): this assumes mapping's values are listed in the same order as the rows
    of old_np (presumably {row_index: original_id}); confirm against how the map files
    were built.
    """
    return {new_id: row for new_id, row in zip(mapping.values(), old_np)}
def load():
    """Load query/passage embeddings and the training pairs from the module-level paths.

    Returns:
        (train_pos_dict, train_neg_dict, query_dict, passage_dict): the first two come
        from load_train(TRAIN_RANK_PATH); the last two map ids to embedding rows via
        map_id().
    """
    print("Load embeddings.")
    # Tuple evaluation is left-to-right, so the files are read in the same order as before.
    passage_np, pid_mapping = obj_reader(PASSAGE_NP_PATH), obj_reader(PASSAGE_MAP_PATH)
    query_np, qid_mapping = obj_reader(QUERY_TRAIN_NP_PATH), obj_reader(QUERY_MAP_PATH)

    print("Mapping ids.")
    query_dict = map_id(query_np, qid_mapping)
    passage_dict = map_id(passage_np, pid_mapping)

    print("Load training data.")
    pos_pairs, neg_pairs = load_train(TRAIN_RANK_PATH)
    return pos_pairs, neg_pairs, query_dict, passage_dict

In [60]:
# Re-read only the training pairs from TRAIN_RANK_PATH; embeddings from load() are left untouched.
(train_pos_dict, train_neg_dict) = load_train(TRAIN_RANK_PATH)

In [61]:
# Number of passages with at least one negative (rank 0) row.
len(train_neg_dict.keys())

200000

In [2]:
# Load embeddings and training pairs from every path configured in the data cell.
(train_pos_dict, train_neg_dict, query_dict, passage_dict) = load()

Load embeddings.
Mapping ids.
Load training data.


In [29]:
# def obj_writer(obj, path):
#     with open(path, 'wb') as handle:
#         pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [30]:
# obj_writer(train_pos_dict, "/datadrive/ruohan/data/train_pos_dict.pb")
# obj_writer(train_neg_dict, "/datadrive/ruohan/data/train_neg_dict.pb")

## Network Architecture: network.py

In [75]:
import torch
import torch.nn as nn

NUM_HIDDEN_NODES = 1536
NUM_HIDDEN_LAYERS = 1
DROPOUT_RATE = 0.1


# Define the network
class ResidualNet(torch.nn.Module):
    """Residual MLP mapping an embedding to a same-size embedding.

    ``forward(x)`` repeats the following NUM_HIDDEN_LAYERS times:
        out = relu(output(dropout(layernorm(relu(input(out))))) + x)

    Two things to note:
    * The residual always adds the ORIGINAL input ``x``.
    * Every iteration reuses the same ``input`` / ``normlayer`` / ``output`` modules, so
      the weights are shared across iterations.

    The final ReLU makes every output coordinate non-negative.
    """

    def __init__(self, embed_size):
        super(ResidualNet, self).__init__()

        self.input = nn.Linear(embed_size, NUM_HIDDEN_NODES)
        self.relu = nn.ReLU()
        self.normlayer = nn.LayerNorm(NUM_HIDDEN_NODES)
        self.dropout = nn.Dropout(p=DROPOUT_RATE)
        # The output width must be embed_size so that `out + identity` in forward()
        # is always shape-compatible.
        self.output = nn.Linear(NUM_HIDDEN_NODES, embed_size)

    def forward(self, x):
        """Apply the shared residual block NUM_HIDDEN_LAYERS times.

        Expects ``x`` with last dimension ``embed_size``; returns the same shape.
        """
        identity = x
        out = x
        for _ in range(NUM_HIDDEN_LAYERS):
            out = self.input(out)
            out = self.relu(out)
            out = self.normlayer(out)
            out = self.dropout(out)
            out = self.output(out)
            # Out-of-place add: avoids mutating a tensor autograd may still reference.
            out = out + identity
            out = self.relu(out)
        return out

    def parameter_count(self):
        """Number of trainable (requires_grad) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

## Train reverse ranker: train.py

In [70]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import torch.nn.functional as F


# mini_batch() builds its labels as TOP_K - rank.
TOP_K = 100

# Probability that mini_batch() pairs a ranked (positive) query with a random
# negative from train_neg_dict; otherwise it pairs two ranked queries.
ALPHA = 0.5

def dot_product(A, B, normalize=False):
    """Batched row-wise dot product.

    A is viewed as (batch, 1, dim) and B as (batch, dim, 1), with batch/dim taken from A's
    first two dimensions. In practice both must therefore be 2-D (batch, dim) tensors with
    the same number of elements.

    With ``normalize=True`` both inputs are first L2-normalised along dim 1, which gives
    cosine similarity.

    Returns:
        Tensor of shape (batch, 1, 1).
    """
    if normalize:
        A, B = F.normalize(A), F.normalize(B)
    n_rows, width = A.shape[0], A.shape[1]
    lhs = A.view(n_rows, 1, width)
    rhs = B.view(n_rows, width, 1)
    return torch.bmm(lhs, rhs)


def mini_batch(batch_size, device, train_pos_dict, train_neg_dict, query_dict, passage_dict,
               alpha=None, top_k=None):
    """Sample a mini-batch of (passage, positive query, negative query) training triples.

    Passage ids are drawn uniformly from ``train_neg_dict``. Ids with no ranked queries in
    ``train_pos_dict`` are redrawn. For each accepted passage a coin is flipped:

    * with probability ``alpha``, the positive is a random ranked query from
      ``train_pos_dict[pid]`` and the negative is a random query from ``train_neg_dict[pid]``;
    * otherwise two distinct ranked queries are drawn, and the one with the smaller rank
      (ranked higher, e.g. 3 beats 60) is the positive. Passages with fewer than two
      ranked queries are redrawn, as are pairs with equal ranks: the default "ce" loss in
      train() always treats column 0 as the winner, so a tie would be label noise.

    Args:
        batch_size: number of triples to return.
        device: torch device the returned tensors are moved to.
        train_pos_dict: {pid: {qid: rank}} of ranked (positive) queries.
        train_neg_dict: {pid: {qid: rank}} of negative queries.
        query_dict: qid -> embedding row (anything ``np.stack`` accepts).
        passage_dict: pid -> embedding row.
        alpha: probability of the positive-vs-negative branch; defaults to ALPHA.
        top_k: labels are ``top_k - rank``; defaults to TOP_K.

    Returns:
        passages, pos, neg: tensors stacked from the embedding rows, on ``device``.
        labels: (batch_size, 2) float tensor of ``[top_k - pos_rank, top_k - neg_rank]``.

    Raises:
        ValueError: if no passage in ``train_neg_dict`` has ranked queries, since rejection
            sampling could then never finish.
    """
    if alpha is None:
        alpha = ALPHA
    if top_k is None:
        top_k = TOP_K
    passage_list = list(train_neg_dict.keys())
    # Fail fast instead of looping forever; any() stops at the first usable passage.
    if not any(train_pos_dict.get(pid) for pid in passage_list):
        raise ValueError("mini_batch: no passage in train_neg_dict has ranked queries in train_pos_dict")

    passages, pos, neg = [], [], []
    pos_rank_list, neg_rank_list = [], []
    while len(passages) < batch_size:
        pid = random.choice(passage_list)
        ranked = train_pos_dict.get(pid)
        if not ranked:
            continue
        if np.random.uniform() <= alpha:
            pos_qid = random.choice(list(ranked))
            pos_rank = ranked[pos_qid]
            # pid was drawn from train_neg_dict's keys, so this lookup cannot miss.
            negatives = train_neg_dict[pid]
            neg_qid = random.choice(list(negatives))
            neg_rank = negatives[neg_qid]
        else:
            if len(ranked) < 2:
                continue
            qid_a, qid_b = random.sample(list(ranked), 2)
            rank_a, rank_b = ranked[qid_a], ranked[qid_b]
            if rank_a == rank_b:
                # Equal ranks carry no preference signal; redraw.
                continue
            # Smaller rank = ranked higher (e.g. 3 beats 60).
            if rank_a < rank_b:
                pos_qid, pos_rank, neg_qid, neg_rank = qid_a, rank_a, qid_b, rank_b
            else:
                pos_qid, pos_rank, neg_qid, neg_rank = qid_b, rank_b, qid_a, rank_a
        p_seq = passage_dict[pid]
        pos_seq = query_dict[pos_qid]
        neg_seq = query_dict[neg_qid]
        passages.append(p_seq)
        pos.append(pos_seq)
        neg.append(neg_seq)
        pos_rank_list.append(top_k - pos_rank)
        neg_rank_list.append(top_k - neg_rank)
    labels = torch.stack([torch.FloatTensor(pos_rank_list), torch.FloatTensor(neg_rank_list)], dim=1)
    passages = torch.from_numpy(np.stack(passages))
    pos = torch.from_numpy(np.stack(pos))
    neg = torch.from_numpy(np.stack(neg))
    return passages.to(device), pos.to(device), neg.to(device), labels.to(device)


def train(net, epoch_size, batch_size, optimizer, device, train_pos_dict, train_neg_dict,
          query_dict, passage_dict, scale=10, loss_option="ce"):
    """Run ``epoch_size`` optimisation steps and return the mean training loss.

    Each step samples a mini-batch with mini_batch() and embeds passages and both queries
    with ``net``. It scores (passage, pos) and (passage, neg) by dot product, then applies:

    * ``"ce"``: cross-entropy with class 0 (the positive) as the target for every row;
    * ``"bce"``: BCE between the softmax of the scores and the softmax of the rank labels.

    Args:
        scale: unused; kept for call compatibility.
        loss_option: ``"ce"`` or ``"bce"``.

    Raises:
        ValueError: for any other ``loss_option``. Previously this left ``loss`` unbound.
    """
    if loss_option not in ("ce", "bce"):
        raise ValueError("loss_option must be 'ce' or 'bce', got {!r}".format(loss_option))
    bce = nn.BCELoss()
    ce = nn.CrossEntropyLoss()
    softmax = nn.Softmax(dim=1)
    # Column 0 of every score pair is the positive query, so the CE target is always class 0.
    ce_target = torch.zeros(batch_size, dtype=torch.long, device=device)
    train_loss = 0.0
    net.train()
    for _ in range(epoch_size):
        # Read in a new mini-batch of data!
        passages, pos, neg, labels = mini_batch(batch_size, device, train_pos_dict, train_neg_dict,
                                                query_dict, passage_dict)
        optimizer.zero_grad()
        p_embed = net(passages)
        pos_embed = net(pos)
        neg_embed = net(neg)
        out_pos = dot_product(p_embed, pos_embed)
        out_neg = dot_product(p_embed, neg_embed)
        # (batch, 1, 1) x2 -> (batch, 1, 2) -> (batch, 2). squeeze(1) keeps the batch
        # dimension even when batch_size == 1, unlike a bare squeeze().
        out = torch.cat((out_pos, out_neg), -1).squeeze(1)
        if loss_option == "bce":
            loss = bce(softmax(out), softmax(labels))
        else:
            loss = ce(out, ce_target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    return train_loss / epoch_size


## Main function: main.py

In [71]:
import datetime
from datetime import datetime, timezone, timedelta

# Fixed UTC offset, in hours, used for log timestamps.
TIME_OFFSET = -4


def print_message(s, offset=TIME_OFFSET):
    """Print ``s`` prefixed with a "[Mon DD, HH:MM:SS]" timestamp and flush stdout.

    The timestamp is the current time in the fixed UTC offset ``offset`` (hours).
    """
    tz = timezone(timedelta(hours=offset))
    stamp = datetime.now(tz).strftime("%b %d, %H:%M:%S")
    print("[{}] {}".format(stamp, s), flush=True)

In [74]:
import torch
from torch import optim
import csv
import sys
import os


# Directory prefix for checkpoints (.model) and per-epoch loss logs (.csv) written by main().
MODEL_PATH = "/datadrive/ruohan/fix_residual_overfit/"
CURRENT_DEVICE = "cuda:2"
# Checkpoint main() resumes from when called with pretrained=True.
PRETRAINED_PATH = "/datadrive/ruohan/random_sample/reverse_alpha0.5_initial_residual_saveoptim_layer51000_100_1000_0.0001_768.model"

# exist_ok avoids the check-then-create race and makes re-running this cell a no-op.
os.makedirs(MODEL_PATH, exist_ok=True)


def main(num_epochs, epoch_size, batch_size, learning_rate, model_path, embed_size, pretrained=False):
    """Train a ResidualNet reverse ranker, logging and checkpointing after every epoch.

    Training data comes from the notebook-level ``train_pos_dict``, ``train_neg_dict``,
    ``query_dict`` and ``passage_dict`` (produced by load() / load_train()). Those cells
    must have been run first.

    Args:
        num_epochs: number of epochs.
        epoch_size: mini-batches per epoch.
        batch_size: triples per mini-batch.
        learning_rate: Adam learning rate.
        model_path: output prefix, string-concatenated with the run name.
        embed_size: input/output embedding width of the network.
        pretrained: if True, resume model and optimizer state from PRETRAINED_PATH.

    Side effects:
        Appends an "epoch,loss" row to ``<model_path><run>.csv`` each epoch.
        Overwrites the checkpoint ``<model_path><run>.model`` each epoch.
    """
    net = ResidualNet(embed_size=embed_size)
    checkpoint = None
    if pretrained:
        # map_location lets a checkpoint saved on another GPU load onto CURRENT_DEVICE.
        checkpoint = torch.load(PRETRAINED_PATH, map_location=CURRENT_DEVICE)
        net.load_state_dict(checkpoint['model'])
    net.to(CURRENT_DEVICE)
    # The optimizer is built after the move so it holds the on-device parameters.
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    print("Loading data")

    # The run name encodes the actual ALPHA and NUM_HIDDEN_LAYERS, so runs with different
    # settings do not overwrite each other's checkpoint or append to the same CSV.
    # E.g. ALPHA=0.5, NUM_HIDDEN_LAYERS=1 -> "reverse_alpha0.5_layer1_residual1000_...".
    arg_str = "reverse_alpha{}_layer{}_residual{}_{}_{}_{}_{}".format(
        ALPHA, NUM_HIDDEN_LAYERS, num_epochs, epoch_size, batch_size, learning_rate, embed_size)
    unique_path = model_path + arg_str + ".model"
    output_path = model_path + arg_str + ".csv"

    for ep_idx in range(num_epochs):
        train_loss = train(net, epoch_size, batch_size, optimizer, CURRENT_DEVICE, train_pos_dict,
                           train_neg_dict, query_dict, passage_dict)
        print_message([ep_idx, train_loss])
        # newline='' is how the csv module expects its output files to be opened.
        with open(output_path, mode='a+', newline='') as output:
            csv.writer(output).writerow([ep_idx, train_loss])
        torch.save({
            "model": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "n_epoch": ep_idx,
            "train_loss": train_loss,
            "n_hidden_layer": NUM_HIDDEN_LAYERS,
        }, unique_path)

In [51]:
# checkpoint = torch.load("/datadrive/ruohan/random_sample/reverse_alpha0.5_initial_residual_saveoptim1000_100_1000_0.0001_768.model")
# ttnet = ResidualNet(embed_size=768)
# ttoptimizer = optim.Adam(ttnet.parameters(), lr=0.0001)
# ttnet.load_state_dict(checkpoint["model"])
# ttoptimizer.load_state_dict(checkpoint["optimizer"])

In [None]:
# New sampling strategy; 50,000 training queries; 1 hidden layer.
main(num_epochs=1000, epoch_size=100, batch_size=1000, learning_rate=0.0001,
     model_path=MODEL_PATH, embed_size=768, pretrained=False)

Loading data
[Sep 04, 17:15:04] [0, 0.5353488984704018]
[Sep 04, 17:15:10] [1, 0.40083329647779464]
[Sep 04, 17:15:16] [2, 0.396206733584404]
[Sep 04, 17:15:21] [3, 0.39210158795118333]
[Sep 04, 17:15:27] [4, 0.38254984736442565]
[Sep 04, 17:15:32] [5, 0.37357612818479535]
[Sep 04, 17:15:38] [6, 0.3616476389765739]
[Sep 04, 17:15:43] [7, 0.3535423043370247]
[Sep 04, 17:15:48] [8, 0.34328390210866927]
[Sep 04, 17:15:54] [9, 0.33508117735385895]
[Sep 04, 17:15:59] [10, 0.3254505109786987]
[Sep 04, 17:16:04] [11, 0.31572498500347135]
[Sep 04, 17:16:10] [12, 0.29453787535429]
[Sep 04, 17:16:15] [13, 0.2756103791296482]
[Sep 04, 17:16:20] [14, 0.2569262297451496]
[Sep 04, 17:16:26] [15, 0.24249009162187576]
[Sep 04, 17:16:31] [16, 0.23619298353791238]
[Sep 04, 17:16:37] [17, 0.22965820178389548]
[Sep 04, 17:16:42] [18, 0.222792037576437]
[Sep 04, 17:16:47] [19, 0.2213806051015854]
[Sep 04, 17:16:53] [20, 0.2167125564813614]
[Sep 04, 17:16:58] [21, 0.2176070536673069]
[Sep 04, 17:17:04] [22,

In [67]:
# New sampling strategy; 50,000 training queries; 3 hidden layers.
main(num_epochs=1000, epoch_size=100, batch_size=1000, learning_rate=0.0001,
     model_path=MODEL_PATH, embed_size=768, pretrained=False)

Loading data
[Sep 03, 23:47:20] [0, 0.6824299290776252]
[Sep 03, 23:47:27] [1, 0.4468103986978531]
[Sep 03, 23:47:34] [2, 0.41765457063913347]
[Sep 03, 23:47:41] [3, 0.4015886089205742]
[Sep 03, 23:47:48] [4, 0.39345457643270493]
[Sep 03, 23:47:55] [5, 0.39057127833366395]
[Sep 03, 23:48:02] [6, 0.38467361122369764]
[Sep 03, 23:48:09] [7, 0.3762369641661644]
[Sep 03, 23:48:16] [8, 0.3674654880166054]
[Sep 03, 23:48:23] [9, 0.3626611453294754]
[Sep 03, 23:48:30] [10, 0.35723533034324645]
[Sep 03, 23:48:36] [11, 0.35142105937004087]
[Sep 03, 23:48:43] [12, 0.34748754888772965]
[Sep 03, 23:48:50] [13, 0.33467329293489456]
[Sep 03, 23:48:57] [14, 0.33005423098802567]
[Sep 03, 23:49:03] [15, 0.3253258839249611]
[Sep 03, 23:49:10] [16, 0.3131204184889793]
[Sep 03, 23:49:17] [17, 0.30681286662817003]
[Sep 03, 23:49:24] [18, 0.2964709293842316]
[Sep 03, 23:49:31] [19, 0.282423205524683]
[Sep 03, 23:49:38] [20, 0.2690472428500652]
[Sep 03, 23:49:45] [21, 0.2591824935376644]
[Sep 03, 23:49:52] [

[Sep 04, 00:08:13] [182, 0.1010343998670578]
[Sep 04, 00:08:20] [183, 0.09900256440043449]
[Sep 04, 00:08:27] [184, 0.09887247174978256]
[Sep 04, 00:08:34] [185, 0.0982910592108965]
[Sep 04, 00:08:40] [186, 0.09863314129412175]
[Sep 04, 00:08:47] [187, 0.09772681497037411]
[Sep 04, 00:08:54] [188, 0.09848825924098492]
[Sep 04, 00:09:01] [189, 0.0977132785320282]
[Sep 04, 00:09:08] [190, 0.09580790504813194]
[Sep 04, 00:09:15] [191, 0.0962896203994751]
[Sep 04, 00:09:21] [192, 0.09728407479822636]
[Sep 04, 00:09:28] [193, 0.09770739644765854]
[Sep 04, 00:09:35] [194, 0.09607314042747021]
[Sep 04, 00:09:42] [195, 0.09519509807229042]
[Sep 04, 00:09:49] [196, 0.0954632043838501]
[Sep 04, 00:09:56] [197, 0.09460113786160945]
[Sep 04, 00:10:03] [198, 0.09577351786196232]
[Sep 04, 00:10:09] [199, 0.09545278891921044]
[Sep 04, 00:10:16] [200, 0.09424361780285835]
[Sep 04, 00:10:23] [201, 0.09398622713983058]
[Sep 04, 00:10:30] [202, 0.09560027845203876]
[Sep 04, 00:10:37] [203, 0.093721947446

[Sep 04, 00:29:05] [361, 0.06564018726348878]
[Sep 04, 00:29:12] [362, 0.06816897716373205]
[Sep 04, 00:29:18] [363, 0.06897669959813356]
[Sep 04, 00:29:25] [364, 0.06579600382596254]
[Sep 04, 00:29:32] [365, 0.06653825506567955]
[Sep 04, 00:29:39] [366, 0.06771994024515152]
[Sep 04, 00:29:46] [367, 0.06629314374178648]
[Sep 04, 00:29:53] [368, 0.0675684416294098]
[Sep 04, 00:30:00] [369, 0.06676524709910155]
[Sep 04, 00:30:07] [370, 0.06659393079578876]
[Sep 04, 00:30:14] [371, 0.067753817550838]
[Sep 04, 00:30:21] [372, 0.06674737740308047]
[Sep 04, 00:30:28] [373, 0.0658144111931324]
[Sep 04, 00:30:35] [374, 0.06610319204628468]
[Sep 04, 00:30:42] [375, 0.06772064462304116]
[Sep 04, 00:30:49] [376, 0.06567424103617668]
[Sep 04, 00:30:56] [377, 0.06561327613890171]
[Sep 04, 00:31:04] [378, 0.06620733994990587]
[Sep 04, 00:31:11] [379, 0.06424773823469877]
[Sep 04, 00:31:18] [380, 0.06557793539017438]
[Sep 04, 00:31:24] [381, 0.06382984478026628]
[Sep 04, 00:31:31] [382, 0.06558320928

[Sep 04, 00:49:26] [539, 0.054510286040604115]
[Sep 04, 00:49:33] [540, 0.05535843603312969]
[Sep 04, 00:49:40] [541, 0.05413617625832558]
[Sep 04, 00:49:47] [542, 0.05544829089194536]
[Sep 04, 00:49:54] [543, 0.054496350977569816]
[Sep 04, 00:50:00] [544, 0.05482436571270227]
[Sep 04, 00:50:07] [545, 0.05445144720375538]
[Sep 04, 00:50:14] [546, 0.05472732193768024]
[Sep 04, 00:50:21] [547, 0.05553591422736645]
[Sep 04, 00:50:28] [548, 0.05447132084518671]
[Sep 04, 00:50:35] [549, 0.053684534952044484]
[Sep 04, 00:50:42] [550, 0.0546006216481328]
[Sep 04, 00:50:49] [551, 0.05320557430386543]
[Sep 04, 00:50:56] [552, 0.05206711553037167]
[Sep 04, 00:51:03] [553, 0.05348336733877659]
[Sep 04, 00:51:10] [554, 0.053549153693020345]
[Sep 04, 00:51:17] [555, 0.05459520388394594]
[Sep 04, 00:51:24] [556, 0.05364351451396942]
[Sep 04, 00:51:31] [557, 0.053603911064565184]
[Sep 04, 00:51:38] [558, 0.05410996336489916]
[Sep 04, 00:51:46] [559, 0.05321942377835512]
[Sep 04, 00:51:53] [560, 0.054

[Sep 04, 01:09:53] [717, 0.04802448358386755]
[Sep 04, 01:10:00] [718, 0.048132610116153954]
[Sep 04, 01:10:07] [719, 0.04910765528678894]
[Sep 04, 01:10:14] [720, 0.04705456700176001]
[Sep 04, 01:10:21] [721, 0.04644714215770364]
[Sep 04, 01:10:28] [722, 0.047361026257276534]
[Sep 04, 01:10:34] [723, 0.04689804561436176]
[Sep 04, 01:10:41] [724, 0.047659827750176194]
[Sep 04, 01:10:48] [725, 0.04775084193795919]
[Sep 04, 01:10:55] [726, 0.04803379535675049]
[Sep 04, 01:11:02] [727, 0.04837719840928912]
[Sep 04, 01:11:09] [728, 0.048425319101661445]
[Sep 04, 01:11:16] [729, 0.04864410704001784]
[Sep 04, 01:11:22] [730, 0.04740563737228513]
[Sep 04, 01:11:29] [731, 0.04791833836585283]
[Sep 04, 01:11:36] [732, 0.04672452665865421]
[Sep 04, 01:11:43] [733, 0.04666119948029518]
[Sep 04, 01:11:50] [734, 0.04593508925288916]
[Sep 04, 01:11:57] [735, 0.04597004259005189]
[Sep 04, 01:12:04] [736, 0.04724353453144431]
[Sep 04, 01:12:11] [737, 0.046559968292713166]
[Sep 04, 01:12:17] [738, 0.04

[Sep 04, 01:30:08] [894, 0.0423198290169239]
[Sep 04, 01:30:15] [895, 0.04195749189704657]
[Sep 04, 01:30:22] [896, 0.043220481779426334]
[Sep 04, 01:30:29] [897, 0.044029826521873476]
[Sep 04, 01:30:36] [898, 0.04415842119604349]
[Sep 04, 01:30:43] [899, 0.04332923790439963]
[Sep 04, 01:30:50] [900, 0.04307803919538856]
[Sep 04, 01:30:56] [901, 0.043256971202790734]
[Sep 04, 01:31:03] [902, 0.04360282288864255]
[Sep 04, 01:31:10] [903, 0.042952913716435434]
[Sep 04, 01:31:17] [904, 0.0431132011115551]
[Sep 04, 01:31:24] [905, 0.043573181219398974]
[Sep 04, 01:31:31] [906, 0.04440591234713793]
[Sep 04, 01:31:38] [907, 0.04380960522219539]
[Sep 04, 01:31:45] [908, 0.04264043489471078]
[Sep 04, 01:31:52] [909, 0.04340072624385357]
[Sep 04, 01:31:59] [910, 0.0430159766972065]
[Sep 04, 01:32:05] [911, 0.042657329421490434]
[Sep 04, 01:32:12] [912, 0.04249405864626169]
[Sep 04, 01:32:20] [913, 0.04272604202851653]
[Sep 04, 01:32:28] [914, 0.043979182206094265]
[Sep 04, 01:32:37] [915, 0.042

In [58]:
# 50,000 training queries; 3 hidden layers.
main(num_epochs=1000, epoch_size=100, batch_size=1000, learning_rate=0.0001,
     model_path=MODEL_PATH, embed_size=768, pretrained=False)

Loading data
[Aug 26, 22:19:25] [0, 0.6429557383060456]
[Aug 26, 22:19:33] [1, 0.4474100682139397]
[Aug 26, 22:19:41] [2, 0.41430523365736005]
[Aug 26, 22:19:49] [3, 0.40435769528150556]
[Aug 26, 22:19:57] [4, 0.3929101538658142]
[Aug 26, 22:20:05] [5, 0.388307138979435]
[Aug 26, 22:20:13] [6, 0.3819301775097847]
[Aug 26, 22:20:21] [7, 0.37285622268915175]
[Aug 26, 22:20:29] [8, 0.36446011662483213]
[Aug 26, 22:20:37] [9, 0.35467886239290236]
[Aug 26, 22:20:45] [10, 0.34967049151659013]
[Aug 26, 22:20:53] [11, 0.33749841421842575]
[Aug 26, 22:21:01] [12, 0.3278729981184006]
[Aug 26, 22:21:09] [13, 0.3202175745368004]
[Aug 26, 22:21:17] [14, 0.3049624112248421]
[Aug 26, 22:21:25] [15, 0.2934373392164707]
[Aug 26, 22:21:33] [16, 0.2766068637371063]
[Aug 26, 22:21:41] [17, 0.26282958522439004]
[Aug 26, 22:21:49] [18, 0.24706998109817505]
[Aug 26, 22:21:57] [19, 0.23695026576519013]
[Aug 26, 22:22:04] [20, 0.22733796432614325]
[Aug 26, 22:22:13] [21, 0.2243678830564022]
[Aug 26, 22:22:21] 

[Aug 26, 22:43:12] [181, 0.0512075874209404]
[Aug 26, 22:43:20] [182, 0.0487503157556057]
[Aug 26, 22:43:28] [183, 0.05057424863800406]
[Aug 26, 22:43:35] [184, 0.04897398132830858]
[Aug 26, 22:43:43] [185, 0.050976268555969]
[Aug 26, 22:43:51] [186, 0.049801223687827585]
[Aug 26, 22:43:59] [187, 0.04904364012181759]
[Aug 26, 22:44:07] [188, 0.04841093074530363]
[Aug 26, 22:44:15] [189, 0.05094074748456478]
[Aug 26, 22:44:23] [190, 0.048941703028976914]
[Aug 26, 22:44:31] [191, 0.04945666244253516]
[Aug 26, 22:44:39] [192, 0.049862133860588076]
[Aug 26, 22:44:48] [193, 0.04870348870754242]
[Aug 26, 22:44:56] [194, 0.04741330217570067]
[Aug 26, 22:45:04] [195, 0.04821043493226171]
[Aug 26, 22:45:12] [196, 0.05020774163305759]
[Aug 26, 22:45:20] [197, 0.04836901370435953]
[Aug 26, 22:45:28] [198, 0.04759282095357776]
[Aug 26, 22:45:36] [199, 0.04801465839147568]
[Aug 26, 22:45:44] [200, 0.04791039306670428]
[Aug 26, 22:45:52] [201, 0.04978016298264265]
[Aug 26, 22:46:00] [202, 0.04860697

[Aug 26, 23:06:45] [359, 0.03466965641826391]
[Aug 26, 23:06:53] [360, 0.034918904844671486]
[Aug 26, 23:07:01] [361, 0.03476442420855164]
[Aug 26, 23:07:09] [362, 0.03591294903308153]
[Aug 26, 23:07:16] [363, 0.03378630006685853]
[Aug 26, 23:07:24] [364, 0.03273306658491492]
[Aug 26, 23:07:32] [365, 0.03317227244377136]
[Aug 26, 23:07:40] [366, 0.033876961786299946]
[Aug 26, 23:07:48] [367, 0.03356194173917174]
[Aug 26, 23:07:56] [368, 0.0336295335739851]
[Aug 26, 23:08:04] [369, 0.03403541067615151]
[Aug 26, 23:08:12] [370, 0.03349274020642042]
[Aug 26, 23:08:20] [371, 0.03382224635221064]
[Aug 26, 23:08:28] [372, 0.033904327638447286]
[Aug 26, 23:08:36] [373, 0.034274604339152574]
[Aug 26, 23:08:44] [374, 0.034061011485755445]
[Aug 26, 23:08:52] [375, 0.03350190676748752]
[Aug 26, 23:09:00] [376, 0.03393493093550205]
[Aug 26, 23:09:08] [377, 0.03276092523708939]
[Aug 26, 23:09:16] [378, 0.033866043593734504]
[Aug 26, 23:09:23] [379, 0.03338014390319586]
[Aug 26, 23:09:31] [380, 0.03

[Aug 26, 23:30:16] [536, 0.028891510078683496]
[Aug 26, 23:30:25] [537, 0.02828993007540703]
[Aug 26, 23:30:33] [538, 0.027253560330718754]
[Aug 26, 23:30:41] [539, 0.027438324671238662]
[Aug 26, 23:30:49] [540, 0.026912477631121874]
[Aug 26, 23:30:57] [541, 0.029413155037909745]
[Aug 26, 23:31:05] [542, 0.02840008765459061]
[Aug 26, 23:31:13] [543, 0.027924803271889686]
[Aug 26, 23:31:21] [544, 0.028227678425610067]
[Aug 26, 23:31:29] [545, 0.028908690474927425]
[Aug 26, 23:31:37] [546, 0.02573165135458112]
[Aug 26, 23:31:45] [547, 0.027157900426536798]
[Aug 26, 23:31:53] [548, 0.0266146676056087]
[Aug 26, 23:32:01] [549, 0.026984609868377446]
[Aug 26, 23:32:09] [550, 0.028364056777209042]
[Aug 26, 23:32:17] [551, 0.027030445467680693]
[Aug 26, 23:32:25] [552, 0.02713982719928026]
[Aug 26, 23:32:33] [553, 0.026038138791918755]
[Aug 26, 23:32:41] [554, 0.02724205602891743]
[Aug 26, 23:32:49] [555, 0.027545338943600655]
[Aug 26, 23:32:57] [556, 0.02698485093191266]
[Aug 26, 23:33:05] [5

[Aug 26, 23:53:53] [712, 0.023631732203066348]
[Aug 26, 23:54:02] [713, 0.023350876402109862]
[Aug 26, 23:54:10] [714, 0.0236562436632812]
[Aug 26, 23:54:18] [715, 0.02390151300467551]
[Aug 26, 23:54:26] [716, 0.024024355784058572]
[Aug 26, 23:54:35] [717, 0.02396719692274928]
[Aug 26, 23:54:43] [718, 0.02488162351772189]
[Aug 26, 23:54:51] [719, 0.02319118513725698]
[Aug 26, 23:54:59] [720, 0.023232433963567018]
[Aug 26, 23:55:07] [721, 0.023668029652908443]
[Aug 26, 23:55:15] [722, 0.02245474017225206]
[Aug 26, 23:55:22] [723, 0.02275867893360555]
[Aug 26, 23:55:30] [724, 0.0233404284901917]
[Aug 26, 23:55:38] [725, 0.022807031767442823]
[Aug 26, 23:55:46] [726, 0.02281228708103299]
[Aug 26, 23:55:54] [727, 0.023913574423640966]
[Aug 26, 23:56:02] [728, 0.023768684025853873]
[Aug 26, 23:56:10] [729, 0.022881418578326703]
[Aug 26, 23:56:18] [730, 0.02376969358883798]
[Aug 26, 23:56:26] [731, 0.022833670852705837]
[Aug 26, 23:56:34] [732, 0.023743179095909]
[Aug 26, 23:56:42] [733, 0.0

[Aug 27, 00:17:28] [888, 0.02016419637016952]
[Aug 27, 00:17:37] [889, 0.02181586991995573]
[Aug 27, 00:17:46] [890, 0.021292736548930408]
[Aug 27, 00:17:55] [891, 0.020992112718522547]
[Aug 27, 00:18:03] [892, 0.021663049925118683]
[Aug 27, 00:18:11] [893, 0.021497814804315566]
[Aug 27, 00:18:19] [894, 0.02042805904522538]
[Aug 27, 00:18:27] [895, 0.021117833498865365]
[Aug 27, 00:18:34] [896, 0.0212096194177866]
[Aug 27, 00:18:42] [897, 0.01945645110215992]
[Aug 27, 00:18:50] [898, 0.021485313978046178]
[Aug 27, 00:18:58] [899, 0.02235362526960671]
[Aug 27, 00:19:06] [900, 0.021326589761301876]
[Aug 27, 00:19:14] [901, 0.020889945179224014]
[Aug 27, 00:19:21] [902, 0.02063611090183258]
[Aug 27, 00:19:29] [903, 0.020876098526641726]
[Aug 27, 00:19:37] [904, 0.02228776317089796]
[Aug 27, 00:19:45] [905, 0.021648368639871476]
[Aug 27, 00:19:52] [906, 0.02121454053558409]
[Aug 27, 00:20:00] [907, 0.0213422735594213]
[Aug 27, 00:20:08] [908, 0.021525101317092777]
[Aug 27, 00:20:15] [909, 

In [52]:
# Queries split 1:1 into training and test sets; 3 hidden layers.
main(num_epochs=1000, epoch_size=100, batch_size=1000, learning_rate=0.0001,
     model_path=MODEL_PATH, embed_size=768, pretrained=False)

Loading data
[Aug 24, 20:24:47] [0, 0.846826805472374]
[Aug 24, 20:24:53] [1, 0.5714463549852371]
[Aug 24, 20:25:00] [2, 0.531261548101902]
[Aug 24, 20:25:07] [3, 0.5225673890113831]
[Aug 24, 20:25:14] [4, 0.523752268254757]
[Aug 24, 20:25:20] [5, 0.517607039809227]
[Aug 24, 20:25:27] [6, 0.5177364099025726]
[Aug 24, 20:25:34] [7, 0.5101031106710434]
[Aug 24, 20:25:41] [8, 0.5044968849420548]
[Aug 24, 20:25:48] [9, 0.49770285069942477]
[Aug 24, 20:25:55] [10, 0.49271502435207365]
[Aug 24, 20:26:02] [11, 0.48022217214107515]
[Aug 24, 20:26:08] [12, 0.4726865562796593]
[Aug 24, 20:26:15] [13, 0.4579222214221954]
[Aug 24, 20:26:22] [14, 0.4399645265936851]
[Aug 24, 20:26:29] [15, 0.4206054010987282]
[Aug 24, 20:26:36] [16, 0.4012561228871345]
[Aug 24, 20:26:43] [17, 0.3829398596286774]
[Aug 24, 20:26:49] [18, 0.370057587325573]
[Aug 24, 20:26:56] [19, 0.35951134413480756]
[Aug 24, 20:27:03] [20, 0.3497461199760437]
[Aug 24, 20:27:10] [21, 0.34322347849607465]
[Aug 24, 20:27:17] [22, 0.340

[Aug 24, 20:45:32] [183, 0.19632114678621293]
[Aug 24, 20:45:39] [184, 0.19372664123773575]
[Aug 24, 20:45:46] [185, 0.19475288897752763]
[Aug 24, 20:45:52] [186, 0.1948639640212059]
[Aug 24, 20:45:59] [187, 0.19534378990530968]
[Aug 24, 20:46:06] [188, 0.1902646803855896]
[Aug 24, 20:46:13] [189, 0.1921625839173794]
[Aug 24, 20:46:19] [190, 0.1944642920792103]
[Aug 24, 20:46:26] [191, 0.19220739737153053]
[Aug 24, 20:46:33] [192, 0.1909545588493347]
[Aug 24, 20:46:39] [193, 0.19337721467018126]
[Aug 24, 20:46:46] [194, 0.19084140464663504]
[Aug 24, 20:46:53] [195, 0.19026360034942627]
[Aug 24, 20:47:00] [196, 0.1898379646241665]
[Aug 24, 20:47:07] [197, 0.19014303207397462]
[Aug 24, 20:47:13] [198, 0.18890246301889418]
[Aug 24, 20:47:20] [199, 0.18848093956708908]
[Aug 24, 20:47:27] [200, 0.1890798583626747]
[Aug 24, 20:47:34] [201, 0.18651029348373413]
[Aug 24, 20:47:40] [202, 0.18577258259058]
[Aug 24, 20:47:47] [203, 0.18705086037516594]
[Aug 24, 20:47:54] [204, 0.18564628034830094

[Aug 24, 21:05:54] [363, 0.14812786713242532]
[Aug 24, 21:06:01] [364, 0.1485017741471529]
[Aug 24, 21:06:07] [365, 0.15027569122612477]
[Aug 24, 21:06:14] [366, 0.14807874791324138]
[Aug 24, 21:06:21] [367, 0.15112118802964689]
[Aug 24, 21:06:28] [368, 0.14754824869334698]
[Aug 24, 21:06:35] [369, 0.14682291947305204]
[Aug 24, 21:06:41] [370, 0.14584879234433173]
[Aug 24, 21:06:48] [371, 0.14809164449572562]
[Aug 24, 21:06:55] [372, 0.1478364259749651]
[Aug 24, 21:07:02] [373, 0.1468193167448044]
[Aug 24, 21:07:08] [374, 0.14749247305095195]
[Aug 24, 21:07:15] [375, 0.14386071413755416]
[Aug 24, 21:07:22] [376, 0.14591325677931308]
[Aug 24, 21:07:29] [377, 0.14559341050684452]
[Aug 24, 21:07:36] [378, 0.14779883936047555]
[Aug 24, 21:07:42] [379, 0.14603987976908683]
[Aug 24, 21:07:49] [380, 0.14408899553120136]
[Aug 24, 21:07:56] [381, 0.147188885435462]
[Aug 24, 21:08:03] [382, 0.14414234451949595]
[Aug 24, 21:08:09] [383, 0.14428295254707335]
[Aug 24, 21:08:16] [384, 0.144874163195

[Aug 24, 21:26:17] [543, 0.126715936884284]
[Aug 24, 21:26:24] [544, 0.12545107312500478]
[Aug 24, 21:26:31] [545, 0.12703460291028024]
[Aug 24, 21:26:38] [546, 0.12670144446194173]
[Aug 24, 21:26:45] [547, 0.12991874322295188]
[Aug 24, 21:26:51] [548, 0.1286407249420881]
[Aug 24, 21:26:58] [549, 0.12827736243605614]
[Aug 24, 21:27:05] [550, 0.12657752238214015]
[Aug 24, 21:27:12] [551, 0.1251082782447338]
[Aug 24, 21:27:19] [552, 0.1276262104511261]
[Aug 24, 21:27:26] [553, 0.12570759385824204]
[Aug 24, 21:27:32] [554, 0.12782968774437906]
[Aug 24, 21:27:39] [555, 0.1280971183627844]
[Aug 24, 21:27:46] [556, 0.12522316478192808]
[Aug 24, 21:27:53] [557, 0.12673636108636857]
[Aug 24, 21:28:00] [558, 0.1271457217633724]
[Aug 24, 21:28:06] [559, 0.1255954372882843]
[Aug 24, 21:28:13] [560, 0.12611707590520382]
[Aug 24, 21:28:20] [561, 0.12702750571072102]
[Aug 24, 21:28:27] [562, 0.12476996213197708]
[Aug 24, 21:28:34] [563, 0.12676187753677368]
[Aug 24, 21:28:41] [564, 0.126006537303328

[Aug 24, 21:46:40] [722, 0.11548398040235043]
[Aug 24, 21:46:47] [723, 0.11611198537051677]
[Aug 24, 21:46:54] [724, 0.11433243006467819]
[Aug 24, 21:47:01] [725, 0.11624764114618301]
[Aug 24, 21:47:07] [726, 0.11585713610053062]
[Aug 24, 21:47:14] [727, 0.11675544902682304]
[Aug 24, 21:47:21] [728, 0.1157023148983717]
[Aug 24, 21:47:28] [729, 0.1154110537469387]
[Aug 24, 21:47:35] [730, 0.11612001687288284]
[Aug 24, 21:47:42] [731, 0.11553486257791519]
[Aug 24, 21:47:49] [732, 0.11613194271922112]
[Aug 24, 21:47:55] [733, 0.11392279572784901]
[Aug 24, 21:48:02] [734, 0.11337225042283534]
[Aug 24, 21:48:09] [735, 0.11663099355995656]
[Aug 24, 21:48:16] [736, 0.11542884528636932]
[Aug 24, 21:48:22] [737, 0.11642394803464412]
[Aug 24, 21:48:29] [738, 0.11527901634573937]
[Aug 24, 21:48:36] [739, 0.11511919863522052]
[Aug 24, 21:48:43] [740, 0.11341921724379063]
[Aug 24, 21:48:50] [741, 0.11575542844831943]
[Aug 24, 21:48:57] [742, 0.11316885016858577]
[Aug 24, 21:49:03] [743, 0.113866734

[Aug 24, 22:07:14] [901, 0.11042640648782254]
[Aug 24, 22:07:21] [902, 0.10770335718989373]
[Aug 24, 22:07:28] [903, 0.1080351971834898]
[Aug 24, 22:07:35] [904, 0.10638320617377758]
[Aug 24, 22:07:42] [905, 0.10831359311938286]
[Aug 24, 22:07:49] [906, 0.10726979956030845]
[Aug 24, 22:07:55] [907, 0.1077052104473114]
[Aug 24, 22:08:02] [908, 0.10734739728271961]
[Aug 24, 22:08:09] [909, 0.1070775431394577]
[Aug 24, 22:08:16] [910, 0.1070754424482584]
[Aug 24, 22:08:23] [911, 0.10605075784027576]
[Aug 24, 22:08:30] [912, 0.10645886085927486]
[Aug 24, 22:08:36] [913, 0.1055822241306305]
[Aug 24, 22:08:43] [914, 0.10817600786685944]
[Aug 24, 22:08:50] [915, 0.10662521183490753]
[Aug 24, 22:08:57] [916, 0.10666097052395344]
[Aug 24, 22:09:04] [917, 0.10752025946974754]
[Aug 24, 22:09:10] [918, 0.10669443853199483]
[Aug 24, 22:09:17] [919, 0.10767552539706231]
[Aug 24, 22:09:24] [920, 0.10911688081920147]
[Aug 24, 22:09:31] [921, 0.10816423989832401]
[Aug 24, 22:09:38] [922, 0.105483935698

In [27]:
# Resume from PRETRAINED_PATH on the same training set; 5 hidden layers.
main(num_epochs=1000, epoch_size=100, batch_size=1000, learning_rate=0.0001,
     model_path=MODEL_PATH, embed_size=768, pretrained=True)

Loading data
[Aug 22, 19:39:48] [0, 0.2200116577744484]
[Aug 22, 19:40:01] [1, 0.1701854132115841]
[Aug 22, 19:40:13] [2, 0.16275373697280884]
[Aug 22, 19:40:25] [3, 0.15984184712171554]
[Aug 22, 19:40:38] [4, 0.15997010990977287]
[Aug 22, 19:40:50] [5, 0.15666377276182175]
[Aug 22, 19:41:02] [6, 0.15974697306752206]
[Aug 22, 19:41:14] [7, 0.15735600560903548]
[Aug 22, 19:41:27] [8, 0.15693188093602659]
[Aug 22, 19:41:39] [9, 0.15855263605713843]
[Aug 22, 19:41:51] [10, 0.1560918805003166]
[Aug 22, 19:42:04] [11, 0.15621687188744546]
[Aug 22, 19:42:16] [12, 0.15637068718671798]
[Aug 22, 19:42:28] [13, 0.15675818920135498]
[Aug 22, 19:42:40] [14, 0.15704280957579614]
[Aug 22, 19:42:53] [15, 0.15502281308174135]
[Aug 22, 19:43:05] [16, 0.15625480100512504]
[Aug 22, 19:43:17] [17, 0.15599038928747178]
[Aug 22, 19:43:30] [18, 0.1565654807537794]
[Aug 22, 19:43:42] [19, 0.15734654545783996]
[Aug 22, 19:43:55] [20, 0.15312324792146684]
[Aug 22, 19:44:08] [21, 0.15489958703517914]
[Aug 22, 19

[Aug 22, 20:17:39] [182, 0.1466483372449875]
[Aug 22, 20:17:51] [183, 0.14582526087760925]
[Aug 22, 20:18:03] [184, 0.14613052025437356]
[Aug 22, 20:18:16] [185, 0.14716976508498192]
[Aug 22, 20:18:28] [186, 0.14417920023202896]
[Aug 22, 20:18:41] [187, 0.14528294578194617]
[Aug 22, 20:18:53] [188, 0.14575255922973157]
[Aug 22, 20:19:05] [189, 0.14580538645386695]
[Aug 22, 20:19:17] [190, 0.14554165817797185]
[Aug 22, 20:19:30] [191, 0.14574411794543266]
[Aug 22, 20:19:42] [192, 0.14429914020001888]
[Aug 22, 20:19:55] [193, 0.1480268671363592]
[Aug 22, 20:20:07] [194, 0.14554051391780376]
[Aug 22, 20:20:19] [195, 0.1455048944056034]
[Aug 22, 20:20:32] [196, 0.14614737018942833]
[Aug 22, 20:20:44] [197, 0.14587241359055042]
[Aug 22, 20:20:56] [198, 0.1459692928940058]
[Aug 22, 20:21:08] [199, 0.1479968909919262]
[Aug 22, 20:21:20] [200, 0.1458163396269083]
[Aug 22, 20:21:33] [201, 0.14335037671029568]
[Aug 22, 20:21:45] [202, 0.14427487067878247]
[Aug 22, 20:21:57] [203, 0.1467791119962

[Aug 22, 20:54:33] [362, 0.1427857515960932]
[Aug 22, 20:54:45] [363, 0.1414707288891077]
[Aug 22, 20:54:57] [364, 0.1407363757491112]
[Aug 22, 20:55:10] [365, 0.13997017733752729]
[Aug 22, 20:55:22] [366, 0.13939375407993793]
[Aug 22, 20:55:34] [367, 0.14144984178245068]
[Aug 22, 20:55:47] [368, 0.1403717103600502]
[Aug 22, 20:55:59] [369, 0.14048840031027793]
[Aug 22, 20:56:11] [370, 0.1413355578482151]
[Aug 22, 20:56:24] [371, 0.1396335794776678]
[Aug 22, 20:56:36] [372, 0.14000268861651421]
[Aug 22, 20:56:48] [373, 0.14080965489149094]
[Aug 22, 20:57:00] [374, 0.1403035070002079]
[Aug 22, 20:57:13] [375, 0.14176097080111505]
[Aug 22, 20:57:25] [376, 0.13952627658843994]
[Aug 22, 20:57:38] [377, 0.14105949006974697]
[Aug 22, 20:57:50] [378, 0.14190187588334083]
[Aug 22, 20:58:02] [379, 0.1416803777962923]
[Aug 22, 20:58:14] [380, 0.14110895231366158]
[Aug 22, 20:58:27] [381, 0.1389371082186699]
[Aug 22, 20:58:39] [382, 0.13891677387058735]
[Aug 22, 20:58:52] [383, 0.1410696865618229

[Aug 22, 21:31:28] [542, 0.13506773926317692]
[Aug 22, 21:31:40] [543, 0.13794789031147958]
[Aug 22, 21:31:53] [544, 0.13574501030147076]
[Aug 22, 21:32:05] [545, 0.13651987470686436]
[Aug 22, 21:32:18] [546, 0.13682472988963126]
[Aug 22, 21:32:30] [547, 0.13612683095037936]
[Aug 22, 21:32:42] [548, 0.13943766497075558]
[Aug 22, 21:32:55] [549, 0.13548584029078484]
[Aug 22, 21:33:07] [550, 0.13812229506671428]
[Aug 22, 21:33:19] [551, 0.13588641658425332]
[Aug 22, 21:33:32] [552, 0.1380486460030079]
[Aug 22, 21:33:44] [553, 0.13746780283749105]
[Aug 22, 21:33:56] [554, 0.1370799559354782]
[Aug 22, 21:34:09] [555, 0.13693112924695014]
[Aug 22, 21:34:21] [556, 0.13572490677237511]
[Aug 22, 21:34:34] [557, 0.13524612918496132]
[Aug 22, 21:34:46] [558, 0.13662939816713332]
[Aug 22, 21:34:58] [559, 0.13571142986416818]
[Aug 22, 21:35:10] [560, 0.13798594228923322]
[Aug 22, 21:35:23] [561, 0.13609116092324258]
[Aug 22, 21:35:35] [562, 0.1371144475787878]
[Aug 22, 21:35:47] [563, 0.1380719234

[Aug 22, 22:08:13] [722, 0.13248322427272796]
[Aug 22, 22:08:25] [723, 0.13023902997374534]
[Aug 22, 22:08:37] [724, 0.1320613180845976]
[Aug 22, 22:08:50] [725, 0.1337842959165573]
[Aug 22, 22:09:02] [726, 0.1333822489529848]
[Aug 22, 22:09:14] [727, 0.13333442583680152]
[Aug 22, 22:09:26] [728, 0.13468996345996856]
[Aug 22, 22:09:38] [729, 0.1337559775263071]
[Aug 22, 22:09:50] [730, 0.13179648496210575]
[Aug 22, 22:10:02] [731, 0.1344988474994898]
[Aug 22, 22:10:14] [732, 0.13278695791959763]
[Aug 22, 22:10:26] [733, 0.13248616710305214]
[Aug 22, 22:10:38] [734, 0.13386483520269393]
[Aug 22, 22:10:51] [735, 0.13130990654230118]
[Aug 22, 22:11:03] [736, 0.13305211529135705]
[Aug 22, 22:11:15] [737, 0.13238624967634677]
[Aug 22, 22:11:27] [738, 0.1329212311655283]
[Aug 22, 22:11:39] [739, 0.13331346854567527]
[Aug 22, 22:11:51] [740, 0.1339727409183979]
[Aug 22, 22:12:03] [741, 0.13227848194539546]
[Aug 22, 22:12:15] [742, 0.1324796861410141]
[Aug 22, 22:12:27] [743, 0.132484602257609

[Aug 22, 22:44:32] [902, 0.13043990053236484]
[Aug 22, 22:44:45] [903, 0.1292437770217657]
[Aug 22, 22:44:57] [904, 0.13133299089968203]
[Aug 22, 22:45:09] [905, 0.13091937005519866]
[Aug 22, 22:45:21] [906, 0.12965367540717124]
[Aug 22, 22:45:33] [907, 0.1298043140769005]
[Aug 22, 22:45:45] [908, 0.12899820528924466]
[Aug 22, 22:45:57] [909, 0.12978020526468753]
[Aug 22, 22:46:09] [910, 0.13055889099836349]
[Aug 22, 22:46:22] [911, 0.13075527489185335]
[Aug 22, 22:46:34] [912, 0.130255271717906]
[Aug 22, 22:46:46] [913, 0.1306098260730505]
[Aug 22, 22:46:58] [914, 0.13017060607671738]
[Aug 22, 22:47:10] [915, 0.13137181520462035]
[Aug 22, 22:47:22] [916, 0.13275882333517075]
[Aug 22, 22:47:34] [917, 0.1298472436517477]
[Aug 22, 22:47:46] [918, 0.1309479807317257]
[Aug 22, 22:47:58] [919, 0.1299597555398941]
[Aug 22, 22:48:10] [920, 0.130036443695426]
[Aug 22, 22:48:22] [921, 0.12947019815444946]
[Aug 22, 22:48:34] [922, 0.13163877539336682]
[Aug 22, 22:48:46] [923, 0.1294918091595173]

In [18]:
# Train on randomly sampled training data that contains 200,000 passages
main(1000,100,1000,0.0001,MODEL_PATH,768)

Loading data
[Aug 21, 22:27:49] [0, 1.049621695280075]
[Aug 21, 22:28:02] [1, 0.7177173060178756]
[Aug 21, 22:28:14] [2, 0.6477977627515793]
[Aug 21, 22:28:26] [3, 0.6193814289569854]
[Aug 21, 22:28:39] [4, 0.5911640566587448]
[Aug 21, 22:28:51] [5, 0.5795909243822098]
[Aug 21, 22:29:03] [6, 0.5672692477703094]
[Aug 21, 22:29:15] [7, 0.5547531509399414]
[Aug 21, 22:29:28] [8, 0.5422469416260719]
[Aug 21, 22:29:40] [9, 0.5309646573662757]
[Aug 21, 22:29:53] [10, 0.5194277182221413]
[Aug 21, 22:30:06] [11, 0.5122983610630035]
[Aug 21, 22:30:18] [12, 0.4996506205201149]
[Aug 21, 22:30:31] [13, 0.49126205772161485]
[Aug 21, 22:30:44] [14, 0.4788512253761292]
[Aug 21, 22:30:57] [15, 0.47239673852920533]
[Aug 21, 22:31:10] [16, 0.45934064537286756]
[Aug 21, 22:31:22] [17, 0.44842516869306565]
[Aug 21, 22:31:35] [18, 0.4428005975484848]
[Aug 21, 22:31:47] [19, 0.4365400493144989]
[Aug 21, 22:32:00] [20, 0.4327924346923828]
[Aug 21, 22:32:12] [21, 0.4263707169890404]
[Aug 21, 22:32:25] [22, 0.

[Aug 21, 23:05:35] [183, 0.2619546756148338]
[Aug 21, 23:05:48] [184, 0.2589871872961521]
[Aug 21, 23:06:00] [185, 0.26194216147065164]
[Aug 21, 23:06:12] [186, 0.2586447003483772]
[Aug 21, 23:06:25] [187, 0.2595504388213158]
[Aug 21, 23:06:38] [188, 0.26039090543985366]
[Aug 21, 23:06:50] [189, 0.25785535082221034]
[Aug 21, 23:07:04] [190, 0.25605717569589614]
[Aug 21, 23:07:16] [191, 0.2570407317578793]
[Aug 21, 23:07:29] [192, 0.2567103365063667]
[Aug 21, 23:07:42] [193, 0.2581514598429203]
[Aug 21, 23:07:54] [194, 0.25586354956030843]
[Aug 21, 23:08:07] [195, 0.2583950892090797]
[Aug 21, 23:08:19] [196, 0.25590036883950235]
[Aug 21, 23:08:32] [197, 0.256178684681654]
[Aug 21, 23:08:45] [198, 0.25642402544617654]
[Aug 21, 23:08:57] [199, 0.25481263399124143]
[Aug 21, 23:09:09] [200, 0.2542850385606289]
[Aug 21, 23:09:22] [201, 0.25223878398537636]
[Aug 21, 23:09:35] [202, 0.2521707198023796]
[Aug 21, 23:09:47] [203, 0.25471794575452805]
[Aug 21, 23:10:00] [204, 0.25248051032423974]


[Aug 21, 23:44:19] [363, 0.21622635141015054]
[Aug 21, 23:44:32] [364, 0.2176140959560871]
[Aug 21, 23:44:45] [365, 0.21654012456536292]
[Aug 21, 23:44:58] [366, 0.21809410020709039]
[Aug 21, 23:45:11] [367, 0.21592803567647934]
[Aug 21, 23:45:24] [368, 0.21793582245707513]
[Aug 21, 23:45:37] [369, 0.2159090718626976]
[Aug 21, 23:45:50] [370, 0.21240109488368034]
[Aug 21, 23:46:03] [371, 0.21402581334114074]
[Aug 21, 23:46:16] [372, 0.2155788266658783]
[Aug 21, 23:46:30] [373, 0.2145340174436569]
[Aug 21, 23:46:43] [374, 0.2133140343427658]
[Aug 21, 23:46:56] [375, 0.21598318442702294]
[Aug 21, 23:47:09] [376, 0.21628866508603095]
[Aug 21, 23:47:22] [377, 0.21487857267260552]
[Aug 21, 23:47:35] [378, 0.21628852397203446]
[Aug 21, 23:47:49] [379, 0.2154451847076416]
[Aug 21, 23:48:02] [380, 0.2136365845799446]
[Aug 21, 23:48:15] [381, 0.2157372437417507]
[Aug 21, 23:48:27] [382, 0.21391580179333686]
[Aug 21, 23:48:41] [383, 0.21455946043133736]
[Aug 21, 23:48:54] [384, 0.214136885255575

[Aug 22, 00:22:51] [543, 0.19448660150170327]
[Aug 22, 00:23:04] [544, 0.1966087044775486]
[Aug 22, 00:23:16] [545, 0.1974475434422493]
[Aug 22, 00:23:28] [546, 0.19710069492459298]
[Aug 22, 00:23:40] [547, 0.19502707958221435]
[Aug 22, 00:23:53] [548, 0.19229384645819664]
[Aug 22, 00:24:05] [549, 0.19603768050670622]
[Aug 22, 00:24:17] [550, 0.19597763001918792]
[Aug 22, 00:24:29] [551, 0.1960507233440876]
[Aug 22, 00:24:41] [552, 0.19354464188218118]
[Aug 22, 00:24:53] [553, 0.1959887807071209]
[Aug 22, 00:25:06] [554, 0.1939886762201786]
[Aug 22, 00:25:18] [555, 0.1978562805056572]
[Aug 22, 00:25:30] [556, 0.19618013754487038]
[Aug 22, 00:25:42] [557, 0.1953492605686188]
[Aug 22, 00:25:54] [558, 0.19435726955533028]
[Aug 22, 00:26:06] [559, 0.19382868304848672]
[Aug 22, 00:26:19] [560, 0.1972312794625759]
[Aug 22, 00:26:31] [561, 0.1950923989713192]
[Aug 22, 00:26:43] [562, 0.19465900510549544]
[Aug 22, 00:26:55] [563, 0.1975063343346119]
[Aug 22, 00:27:07] [564, 0.19462200239300728

[Aug 22, 00:59:27] [723, 0.1812646695971489]
[Aug 22, 00:59:40] [724, 0.1851092404127121]
[Aug 22, 00:59:52] [725, 0.18226028636097907]
[Aug 22, 01:00:04] [726, 0.18306563496589662]
[Aug 22, 01:00:16] [727, 0.1830461974442005]
[Aug 22, 01:00:28] [728, 0.1853538866341114]
[Aug 22, 01:00:41] [729, 0.1857893744111061]
[Aug 22, 01:00:53] [730, 0.1808656147122383]
[Aug 22, 01:01:05] [731, 0.18354333758354188]
[Aug 22, 01:01:17] [732, 0.18055806949734687]
[Aug 22, 01:01:29] [733, 0.1834531469643116]
[Aug 22, 01:01:42] [734, 0.18320471972227095]
[Aug 22, 01:01:54] [735, 0.18234472945332528]
[Aug 22, 01:02:06] [736, 0.18238206788897515]
[Aug 22, 01:02:18] [737, 0.1823534832894802]
[Aug 22, 01:02:30] [738, 0.18323394134640694]
[Aug 22, 01:02:43] [739, 0.183607407361269]
[Aug 22, 01:02:56] [740, 0.18312006443738937]
[Aug 22, 01:03:09] [741, 0.18349851548671722]
[Aug 22, 01:03:23] [742, 0.183738631606102]
[Aug 22, 01:03:36] [743, 0.1811194723844528]
[Aug 22, 01:03:49] [744, 0.18268521651625633]
[

[Aug 22, 01:36:06] [903, 0.17358581110835075]
[Aug 22, 01:36:19] [904, 0.17194602742791176]
[Aug 22, 01:36:31] [905, 0.17476639807224273]
[Aug 22, 01:36:43] [906, 0.1750466237962246]
[Aug 22, 01:36:55] [907, 0.17347493886947632]
[Aug 22, 01:37:07] [908, 0.1766258805990219]
[Aug 22, 01:37:20] [909, 0.17315524220466613]
[Aug 22, 01:37:32] [910, 0.1769346049427986]
[Aug 22, 01:37:44] [911, 0.17230464026331901]
[Aug 22, 01:37:56] [912, 0.17531048834323884]
[Aug 22, 01:38:08] [913, 0.17434409514069557]
[Aug 22, 01:38:20] [914, 0.17454180911183356]
[Aug 22, 01:38:33] [915, 0.17618634641170502]
[Aug 22, 01:38:45] [916, 0.17166389867663384]
[Aug 22, 01:38:57] [917, 0.1725667491555214]
[Aug 22, 01:39:09] [918, 0.17386890232563018]
[Aug 22, 01:39:21] [919, 0.17415894210338592]
[Aug 22, 01:39:33] [920, 0.1749108675122261]
[Aug 22, 01:39:46] [921, 0.17518358469009399]
[Aug 22, 01:39:58] [922, 0.1738073581457138]
[Aug 22, 01:40:10] [923, 0.17262760147452355]
[Aug 22, 01:40:22] [924, 0.1742000174522