In [1]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.preprocessing import QuantileTransformer
from sklearn.model_selection import train_test_split
from torch import nn
import torch

In [2]:
def read_fvecs(file_name):
    a = np.fromfile(file_name, dtype="int32")
    dim = a[0]
    return a.view("float32").reshape((-1, dim + 1))[:,1:]

In [67]:
DATASET = "collections"
EMBEDDING_DIM = 50
HIDDEN_DIM = 128

ITEM_BATCH_SIZE = 10000
QUERY_BATCH_SIZE = 100
CLOSE_ITEM_COUNT = 100

In [4]:
item_features = read_fvecs("{}/data/item_features.fvecs".format(DATASET))
train_query_features = read_fvecs("{}/data/user_features_train.fvecs".format(DATASET))
test_query_features = read_fvecs("{}/data/user_features_test.fvecs".format(DATASET))

In [5]:
# def measure_scattering(data):
#     scattering = data.max(axis=0) - data.min(axis=0)
#     scattering /= data.std(axis=0)
#     return scattering

In [6]:
ITEM_COUNT, ITEM_FEATURES_COUNT = item_features.shape
QUERY_COUNT, QUERY_FEATURES_COUNT = train_query_features.shape
VALIDATION_QUERY_COUNT = 100
TRAIN_QUERY_COUNT = QUERY_COUNT - VALIDATION_QUERY_COUNT

train_query_indexes, val_query_indexes = train_test_split(
    np.arange(QUERY_COUNT), test_size=VALIDATION_QUERY_COUNT,
    random_state=0
)

In [7]:
target = np.fromfile(
    "{}/data/model_scores/scores_train.bin".format(DATASET),
    dtype="float32"
).reshape((ITEM_COUNT, QUERY_COUNT))
target = torch.FloatTensor(target)

In [8]:
close_items = np.fromfile(
    "{}/data/model_scores/groundtruth_train.bin".format(DATASET),
    dtype="int32"
).reshape(QUERY_COUNT, CLOSE_ITEM_COUNT)
close_probs = 1. / np.arange(1, CLOSE_ITEM_COUNT + 1)
close_probs /= close_probs.sum()


In [9]:
item_scaler = QuantileTransformer()
item_features = item_scaler.fit_transform(item_features)

query_scaler = QuantileTransformer()
train_query_features = query_scaler.fit_transform(train_query_features)
test_query_features = query_scaler.transform(test_query_features)

In [10]:
item_features = torch.FloatTensor(item_features)
train_query_features = torch.FloatTensor(train_query_features)
test_query_features = torch.FloatTensor(test_query_features)

In [68]:
item_net = nn.Sequential(
    nn.Linear(in_features=ITEM_FEATURES_COUNT, out_features=HIDDEN_DIM, bias=True),
    nn.BatchNorm1d(HIDDEN_DIM),
    nn.ELU(),
    nn.Linear(in_features=HIDDEN_DIM, out_features=HIDDEN_DIM, bias=True),
    nn.BatchNorm1d(HIDDEN_DIM),
    nn.ELU(),
    nn.Linear(in_features=HIDDEN_DIM, out_features=EMBEDDING_DIM, bias=True),
)

query_net = nn.Sequential(
    nn.Linear(in_features=QUERY_FEATURES_COUNT, out_features=HIDDEN_DIM, bias=True),
    nn.BatchNorm1d(HIDDEN_DIM),
    nn.ELU(),
    nn.Linear(in_features=HIDDEN_DIM, out_features=HIDDEN_DIM, bias=True),
    nn.BatchNorm1d(HIDDEN_DIM),
    nn.ELU(),
    nn.Linear(in_features=HIDDEN_DIM, out_features=EMBEDDING_DIM, bias=True),
)


In [69]:
# optimizer = torch.optim.Adam(
#     params=list(item_net.parameters()) + list(query_net.parameters()),
#     lr=0.01, weight_decay=0.1
# )
optimizer = torch.optim.SGD(
    params=list(item_net.parameters()) + list(query_net.parameters()),
    lr=0.001, weight_decay=0.1
)

In [70]:
def calc_batch_loss_stupid(query_indexes, train=True):
    batch_queries = train_query_features[query_indexes]
    
    batch_item_indexes = np.random.choice(ITEM_COUNT, ITEM_BATCH_SIZE, replace=False)
    batch_items = item_features[batch_item_indexes]
    batch_target = target[batch_item_indexes][:,query_indexes]

    if train:
        item_net.train()
        query_net.train()
    else:
        item_net.eval()
        query_net.eval()

    batch_item_embeds = item_net(batch_items)
    batch_query_embeds = query_net(batch_queries)
    relevance_prediction = torch.matmul(batch_item_embeds, batch_query_embeds.T)
    loss = ((batch_target - relevance_prediction) ** 2).mean()
    return loss

def calc_batch_loss(query_indexes, train=True):
    batch_queries = train_query_features[query_indexes]
    
    positives_ranks = np.random.choice(
        CLOSE_ITEM_COUNT, len(query_indexes), p=close_probs)
    positives_indexes = close_items[query_indexes, positives_ranks]
    batch_posives = item_features[positives_indexes]
    positive_target = target[positives_indexes, query_indexes]

    negative_indexes = np.random.choice(ITEM_COUNT, ITEM_BATCH_SIZE, replace=False)
    batch_negatives = item_features[negative_indexes]
    negative_target = target[negative_indexes][:,query_indexes]

    if train:
        item_net.train()
        query_net.train()
    else:
        item_net.eval()
        query_net.eval()

    batch_query_embeds = query_net(batch_queries)
    
    positive_item_embeds = item_net(batch_posives)
    positive_relevance_prediction = (positive_item_embeds * batch_query_embeds).sum(axis=1)
    positive_loss = ((positive_target - positive_relevance_prediction) ** 2).mean()
    
    negative_item_embeds = item_net(batch_negatives)
    negative_relevance_prediction = torch.matmul(negative_item_embeds, batch_query_embeds.T)
    negative_loss = ((negative_target - negative_relevance_prediction) ** 2).mean()
    return positive_loss + negative_loss


item_model_path = "collections/item.net"
query_model_path = "collections/query.net"
best_loss = None
best_epoch = 0
PASSES_PER_EPOCH = 10

for epoch in range(100):
    train_loss = 0
    validation_loss = 0
    
    for i in range(PASSES_PER_EPOCH):
        query_permutation = np.random.permutation(train_query_indexes)
        for batch_start in range(0, TRAIN_QUERY_COUNT, QUERY_BATCH_SIZE):
            batch_query_indexes = query_permutation[batch_start: batch_start + QUERY_BATCH_SIZE]
            loss = calc_batch_loss_stupid(batch_query_indexes)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.data.item()
    train_loss /= TRAIN_QUERY_COUNT * PASSES_PER_EPOCH
    for i in range(PASSES_PER_EPOCH):
        validation_loss += calc_batch_loss_stupid(
            val_query_indexes, train=False).data.item()
    validation_loss /= VALIDATION_QUERY_COUNT * PASSES_PER_EPOCH
    
    if best_loss is None or validation_loss < best_loss:
        best_loss = validation_loss
        best_epoch = epoch
        torch.save(item_net, item_model_path)
        torch.save(query_net, query_model_path)
        
    print("Train loss: {:<30} Validation loss: {:<30}".format(train_loss, validation_loss))
    if epoch > best_epoch + 10:
        break

Train loss: 0.022129510098033482           Validation loss: 0.014315928816795349          
Train loss: 0.01157816696829266            Validation loss: 0.010313705563545227          
Train loss: 0.009377706574069129           Validation loss: 0.009026462733745575          
Train loss: 0.008474990427494049           Validation loss: 0.008391676664352418          
Train loss: 0.007956832223468357           Validation loss: 0.00797242033481598           
Train loss: 0.007617523729801178           Validation loss: 0.007679654002189637          
Train loss: 0.007360780808660719           Validation loss: 0.007427020072937012          
Train loss: 0.007151115225421058           Validation loss: 0.007273702263832092          
Train loss: 0.006992318438159095           Validation loss: 0.007137261748313904          
Train loss: 0.0068642560972107785          Validation loss: 0.0070020023584365845         
Train loss: 0.006756950325436062           Validation loss: 0.006921856641769409          

In [72]:
best_epoch, best_loss

(72, 0.00591660338640213)

In [50]:
item_net = torch.load(item_model_path)
query_net = torch.load(query_model_path)

In [73]:
best_item_net = torch.load(item_model_path)
best_query_net = torch.load(query_model_path)

In [74]:
best_item_net.eval()
best_query_net.eval()

item_embeddings = best_item_net(item_features).detach().numpy()
item_embeddings.astype("float32").tofile("{}/data/item_embeddings.bin".format(DATASET))

train_query_embeddings = best_query_net(train_query_features).detach().numpy()
train_query_embeddings.astype("float32").tofile("{}/data/query_embeddings_train.bin".format(DATASET))

test_query_embeddings = best_query_net(test_query_features).detach().numpy()
test_query_embeddings.astype("float32").tofile("{}/data/query_embeddings_test.bin".format(DATASET))

embedding_train_scores = item_embeddings.dot(train_query_embeddings.T)
embedding_train_scores.astype("float32").tofile(
    "{}/data/model_scores/embedding_scores_train.bin".format(DATASET)
)
del embedding_train_scores

embedding_test_scores = item_embeddings.dot(test_query_embeddings.T)
embedding_test_scores.astype("float32").tofile(
    "{}/data/model_scores/embedding_scores_test.bin".format(DATASET)
)
del embedding_test_scores