In [30]:
import sys
sys.path.insert(0, '/home/jianx/search-exposure/')
from train import train
from load_data import load
import csv
from test import test
import os

# --- Training configuration -------------------------------------------------
NUM_EPOCHS = 2        # number of train-then-evaluate rounds run by the main loop
EPOCH_SIZE = 1        # forwarded to train() as its epoch-size argument
BATCH_SIZE = 100      # forwarded to train() as its batch-size argument
LEARNING_RATE = 0.01  # forwarded to train() as its learning-rate argument
print("Num of epochs:", NUM_EPOCHS)
print("Epoch size:", EPOCH_SIZE)
print("Batch size:", BATCH_SIZE)
print("Learning rate:", LEARNING_RATE)

# --- Evaluation configuration -----------------------------------------------
RANK = 10        # forwarded to test(); presumably the ranking cutoff -- TODO confirm
TEST_BATCH = 35  # forwarded to test() as its batch argument

# Directory that receives the saved model and the per-epoch CSV log.
# The trailing slash matters: later cells build file names by plain string
# concatenation onto this value.
MODEL_PATH = "/home/jianx/data/results/"

# exist_ok=True keeps the cell idempotent and removes the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(MODEL_PATH, exist_ok=True)

Num of epochs: 2
Epoch size: 1
Batch size: 100
Learning rate: 0.01


In [34]:
import torch
import torch.nn as nn


# Architecture hyper-parameters consumed by the DSSM network defined below.
NUM_HIDDEN_NODES = 64   # width of every hidden Linear / LayerNorm layer
NUM_HIDDEN_LAYERS = 3   # number of Linear -> ReLU -> LayerNorm -> Dropout stacks
DROPOUT_RATE = 0.1      # base dropout; hidden stack i uses min(rate * (i + 1), 0.5)
FEAT_COUNT = 50000      # input width of the first Linear layer
EMBED_SIZE = 10         # width of the output embedding (DSSM's embed_size)

# Define the network
class DSSM(torch.nn.Module):
    """Feed-forward encoder that maps a feature vector to a non-negative embedding.

    The network is ``num_hidden_layers`` repetitions of
    ``Linear -> ReLU -> LayerNorm -> Dropout`` followed by a final
    ``Linear -> ReLU`` that projects to ``embed_size``. Because the last
    operation is a ReLU, every output component is >= 0.

    Attributes:
        device: ``torch.device`` chosen at construction ("cuda" when
            available, otherwise "cpu"). The module is NOT moved there
            automatically; callers do ``model.to(model.device)``.
        model: the ``nn.Sequential`` holding all layers.
    """

    def __init__(self, embed_size, feat_count=None, num_hidden_nodes=None,
                 num_hidden_layers=None, dropout_rate=None, max_dropout=0.5):
        """Build the layer stack.

        Args:
            embed_size: width of the output embedding.
            feat_count: width of the input vector; defaults to ``FEAT_COUNT``.
            num_hidden_nodes: width of each hidden layer; defaults to
                ``NUM_HIDDEN_NODES``.
            num_hidden_layers: number of hidden stacks; defaults to
                ``NUM_HIDDEN_LAYERS``.
            dropout_rate: base dropout probability; defaults to
                ``DROPOUT_RATE``. Hidden stack ``i`` (0-based) uses
                ``min(dropout_rate * (i + 1), max_dropout)``.
            max_dropout: upper bound applied to the per-layer dropout.
        """
        super().__init__()

        # Resolve defaults at call time (not definition time) so the notebook
        # constants can be edited and re-run without redefining the class.
        if feat_count is None:
            feat_count = FEAT_COUNT
        if num_hidden_nodes is None:
            num_hidden_nodes = NUM_HIDDEN_NODES
        if num_hidden_layers is None:
            num_hidden_layers = NUM_HIDDEN_LAYERS
        if dropout_rate is None:
            dropout_rate = DROPOUT_RATE

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            print("Current device is: %s" % torch.cuda.get_device_name(self.device))

        layers = []
        last_dim = feat_count
        for i in range(num_hidden_layers):
            layers.append(nn.Linear(last_dim, num_hidden_nodes))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm(num_hidden_nodes))
            # Dropout grows with depth but is capped at max_dropout.
            layers.append(nn.Dropout(p=min(dropout_rate * (i + 1), max_dropout)))
            last_dim = num_hidden_nodes
        # Output projection; the trailing ReLU keeps embeddings non-negative.
        layers.append(nn.Linear(last_dim, embed_size))
        layers.append(nn.ReLU())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding produced by running ``x`` through the layer stack."""
        return self.model(x)

    def parameter_count(self):
        """Return the total number of trainable scalar parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


In [21]:
# Lower-case aliases of the configuration constants; the cells below use
# these names when calling load/train/test and when building output paths.
num_epochs = NUM_EPOCHS
epoch_size = EPOCH_SIZE
batch_size = BATCH_SIZE
learning_rate = LEARNING_RATE
model_path = MODEL_PATH
rank = RANK
test_batch = TEST_BATCH

In [6]:

# Fetch every lookup table used for training and evaluation in one call.
print("Loading data")
(
    pos_neg_dict,
    query_dict,
    passage_dict,
    top_dict,
    rating_dict,
    query_test_dict,
) = load()
print("Data successfully loaded.")

# Quick size report for the three tables used during training.
print("Positive Negative Pair dict size: ", len(pos_neg_dict), sep="")
print("Num of queries: ", len(query_dict), sep="")
print("Num of passages: ", len(passage_dict), sep="")

# Alternative (disabled): tag outputs with a timestamp instead of the settings.
# date_time_obj = datetime.now()
# timestamp_str = date_time_obj.strftime("%b-%d-%Y_%H:%M:%S")

# Output files are named after the run settings, joined by underscores.
arg_str = "_".join(map(str, (num_epochs, epoch_size, batch_size, learning_rate)))
# Plain concatenation: model_path is expected to end with a path separator.
unique_path = model_path + arg_str + ".model"
output_path = model_path + arg_str + ".csv"


Current device is: Tesla P100-PCIE-16GB
Loading data
Data successfully loaded.
Positive Negative Pair dict size: 400782
Num of queries: 808731
Num of passages: 8841823


In [35]:
# Build the encoder with the configured embedding width (instead of repeating
# the literal) and move it to the device it selected at construction.
dssm = DSSM(embed_size=EMBED_SIZE)
net = dssm.to(dssm.device)

# Append one CSV row of metrics per epoch. newline='' is what the csv module
# requires for files handed to csv.writer; without it some platforms emit
# doubled line endings.
with open(output_path, mode='a', newline='') as output:
    output_writer = csv.writer(output)
    for ep_idx in range(num_epochs):
        # One round of training, then evaluation on the test queries.
        train_loss = train(net, epoch_size, batch_size, learning_rate, dssm.device, pos_neg_dict,
                           query_dict, passage_dict)
        avg_ndcg, avg_prec, avg_rr = test(net, test_batch, top_dict, query_test_dict, passage_dict, rating_dict,
                                          rank)
        print("Epoch:{}, loss:{}, NDCG:{}, P:{}, RR:{}".format(ep_idx, train_loss, avg_ndcg, avg_prec, avg_rr))
        output_writer.writerow([ep_idx, train_loss, avg_ndcg, avg_prec, avg_rr])
        # Push the row to disk now so an interrupted run keeps finished epochs.
        output.flush()

# NOTE(review): this pickles the whole module object, so loading it later
# needs the DSSM class to be importable under the same name. Saving
# net.state_dict() is the more portable option, but that changes the on-disk
# format -- confirm with whatever loads this file before switching.
torch.save(net, unique_path)

Current device is: Tesla P100-PCIE-16GB
573724 950.4421206712723 1000 {8008053: 0.979056179523468, 515996: 0.9892823100090027, 84173: 0.8960222005844116, 5166798: 0.9936414957046509, 7172145: 0.9453392624855042, 7175634: 0.8964442610740662, 8441636: 0.9837086796760559, 8441637: 0.9599940776824951, 8441639: 0.6721601486206055, 8441644: 0.9449630975723267, 5185671: 0.9366382956504822, 8444727: 0.9953448176383972, 8047170: 0.9895566701889038, 8048556: 0.9947184920310974, 7199667: 0.9703372716903687, 7199668: 0.9857556819915771, 7199669: 0.854787290096283, 7199674: 0.9950911402702332, 8488412: 0.9641856551170349, 8083118: 0.9928834438323975, 8492171: 0.9462181329727173, 7258827: 0.988493025302887, 8110336: 0.9921744465827942, 7275118: 0.9769201278686523, 5264969: 0.9623193144798279, 8534578: 0.9799063801765442, 7297267: 0.9963573217391968, 8136746: 0.9789842367172241, 7307656: 0.9773292541503906, 5293510: 0.9780465960502625, 5306910: 0.9888681173324585, 8573967: 0.9753663539886475, 8574665

  "type " + obj.__name__ + ". It won't be checked "
