In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os

sys.path.append(os.path.dirname(os.getcwd()))

In [209]:
from torch.utils.data import DataLoader
import torch.optim as optim
import torch
import torch.nn.functional as F

import numpy as np
import pandas as pd

from src.ml.data_loader import Sequences, SequencesDataset
from src.ml.skipgram import SkipGram
from src.ml.mf import MF
from src.utils.logger import logger
from src.utils.io_utils import load_model
from sklearn.metrics import roc_auc_score

In [24]:
batchsize = 2
shuffle = False
num_workers = 4
emb_dim = 8
epochs = 1
initial_lr=0.025
MODEL_PATH = '../model'

In [25]:
dataset = 'electronics'

In [26]:
sequences = Sequences('../data/{}_sequences_samp.npy'.format(dataset),
                      '../data/{}_edges_val_samp.csv'.format(dataset))

2019-12-10 12:41:56,623 - Sequences loaded (length = 5,000)
2019-12-10 12:41:56,701 - Validation set loaded: (100000, 3)
2019-12-10 12:41:56,712 - Word frequency calculated
2019-12-10 12:41:56,748 - Adding val products to word2id, original size: 28695
2019-12-10 12:41:56,814 - Added val products to word2id, updated size: 133050
2019-12-10 12:41:56,819 - No. of unique tokens: 133050
2019-12-10 12:41:58,026 - Model saved to model/word2id
2019-12-10 12:41:59,266 - Model saved to model/id2word
2019-12-10 12:41:59,267 - Word2Id and Id2Word created and saved
2019-12-10 12:41:59,294 - Convert sequence and wordfreq to ID
2019-12-10 12:41:59,428 - Discard probability calculated
2019-12-10 12:42:00,998 - Negative sample table created


In [27]:
sequences_dset = SequencesDataset(sequences)

In [193]:
sequences_dload = DataLoader(sequences_dset, batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, collate_fn=sequences_dset.collate_for_mf)

In [194]:
device = 'cpu'

In [195]:
skipgram = SkipGram(sequences.n_unique_tokens, emb_dim).to(device)

In [196]:
mf = MF(sequences.n_unique_tokens, emb_dim).to(device)

In [197]:
mf

MF(
  (embedding): Embedding(133050, 8, sparse=True)
  (sig): Sigmoid()
  (bce): BCELoss()
)

In [211]:
val_samp = pd.read_csv('../data/{}_edges_val_samp.csv'.format(dataset), dtype={'product1': 'object', 'product2': 'object'})

In [212]:
val_samp.head()

Unnamed: 0,product1,product2,edge
0,b002goovnk,b008mrzsh8,1
1,b00aodd3js,b00f0rrcqi,1
2,b005abj0h8,b00dzrguao,1
3,b0002exjra,b000067rrx,0
4,b00dziz6qc,b008mogskm,0


In [213]:
word2id = load_model('../model/word2id')

2019-12-10 16:02:09,785 - Model loaded from: ../model/word2id (Size: 16818322 bytes)


In [214]:
word2id_func =  np.vectorize(sequences.get_product_id)

In [215]:
val_samp['product1_id'] = word2id_func(val_samp['product1'].values)
val_samp['product2_id'] = word2id_func(val_samp['product2'].values)

In [221]:
optimizer = optim.Adam(mf.parameters(), lr=initial_lr)

for epoch in range(epochs):
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(sequences_dload))
    
    running_loss = 0
    for i, batches in enumerate(sequences_dload):

        product1 = batches[0].to(device)
        product2 = batches[1].to(device)
        label = batches[2].to(device)

        optimizer.zero_grad()
        
        pred = mf.forward(product1, product2)
        loss = mf.loss(pred, label)
        loss.backward()
        
        optimizer.step()

        scheduler.step()
        running_loss = running_loss * 0.9 + loss.item() * 0.1

        if i % 1000 == 0:
            pred = mf.forward(torch.LongTensor(val_samp['product1_id']), torch.LongTensor(val_samp['product2_id']))
            score = roc_auc_score(val_samp['edge'], pred.detach().cpu().numpy())
            logger.info("Epoch: {}, Seq: {:,}/{:,}, " \
                            "Loss: {:.4f}, AUC-ROC: {:.4f}, Lr: {:.6f}".format(epoch, i, len(sequences_dload), running_loss,
                                                                               score, optimizer.param_groups[0]['lr']))
            running_loss = 0

    # skipgram.save_embeddings(file_name='{}/skipgram_epoch_{}.npy'.format(MODEL_PATH, epoch))

2019-12-10 16:05:29,330 - Epoch: 0, Seq: 0/2,500, Loss: 106469.9125, AUC-ROC: 0.5019, Lr: 0.025000
2019-12-10 16:05:45,865 - Epoch: 0, Seq: 1,000/2,500, Loss: 0.6931, AUC-ROC: 0.5000, Lr: 0.016348
2019-12-10 16:06:16,528 - Epoch: 0, Seq: 2,000/2,500, Loss: 0.6931, AUC-ROC: 0.5000, Lr: 0.002378


### Batch validation

In [224]:
val_samp = pd.read_csv('../data/{}_edges_val_samp.csv'.format(dataset), dtype={'product1': 'object', 'product2': 'object'})

In [203]:
word2id = load_model('../model/word2id')

2019-12-10 15:59:07,099 - Model loaded from: ../model/word2id (Size: 16818322 bytes)


In [204]:
word2id_func =  np.vectorize(sequences.get_product_id)

In [205]:
val_samp['product1_id'] = word2id_func(val_samp['product1'].values)
val_samp['product2_id'] = word2id_func(val_samp['product2'].values)

In [206]:
val_samp.head()

Unnamed: 0,product1,product2,edge,product1_id,product2_id
0,b002goovnk,b008mrzsh8,1,72788,3481
1,b00aodd3js,b00f0rrcqi,1,90036,11692
2,b005abj0h8,b00dzrguao,1,90178,40224
3,b0002exjra,b000067rrx,0,93919,80045
4,b00dziz6qc,b008mogskm,0,77350,9279


In [207]:
pred = mf.forward(torch.LongTensor(val_samp['product1_id']), torch.LongTensor(val_samp['product2_id']))

In [210]:
roc_auc_score(val_samp['edge'], pred.detach().cpu().numpy())

0.5018888801427521

In [185]:
batches[2]

tensor([0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0,
        0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0,
        0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,

In [38]:
batches = [([(0, 1), (0, 16), (0, 17), (0, 18), (0, 19), (1, 16), (1, 17), (1, 18), (1, 19), (1, 18), (16, 1), (16, 17), (16, 18), (16, 19), (16, 18), (16, 20), (17, 1), (17, 16), (17, 18), (17, 19), (17, 18), (17, 20), (17, 21), (18, 1), (18, 16), (18, 17), (18, 19), (18, 20), (18, 21), (18, 22), (19, 1), (19, 16), (19, 17), (19, 18), (19, 18), (19, 20), (19, 21), (19, 22), (18, 1), (18, 16), (18, 17), (18, 19), (18, 20), (18, 21), (18, 22), (20, 16), (20, 17), (20, 18), (20, 19), (20, 18), (20, 21), (20, 22), (21, 17), (21, 18), (21, 19), (21, 18), (21, 20), (21, 22), (22, 18), (22, 19), (22, 18), (22, 20), (22, 21)], [np.array([16893,  7785,  1742,  3798, 23385]), np.array([ 3775, 13900,  3872,  3314, 13042]), np.array([15290, 12362, 12592,  1421,  9167]), np.array([17752,  1635, 22254,  2127, 10099]), np.array([16791, 21850,  9243,  8241, 17682]), np.array([11652, 16975, 26739, 22641,  7904]), np.array([ 1612,  6095,  3986, 18929, 27609]), np.array([22895, 23120, 20906, 19153,  2648]), np.array([22959, 22241, 19072, 28654,   138]), np.array([ 7192, 18522, 19229, 15920,  7314]), np.array([ 2313,  9281,  5522, 12830, 14786]), np.array([ 5529,  1174, 17789, 20110, 27401]), np.array([28122,  3385, 25407, 11989,  9864]), np.array([6444,  414, 5932, 3441, 6810]), np.array([20004,  1668,  5752, 25278,  9297]), np.array([14186,  4043, 24543, 13235, 17296]), np.array([6175,  111, 4998, 2079,  238]), np.array([ 3597, 27751, 25005, 10147, 28054]), np.array([17293,  1585, 11230, 27460, 27039]), np.array([13029,   199,  4485,  2235,  8265]), np.array([21392, 27721,  4601, 22600,     9]), np.array([10200,  4131,  5318, 18598, 16372]), np.array([14561, 10355, 13711,  2684, 20142]), np.array([  777, 14639, 21038, 24919,  4635]), np.array([23424, 27497,  7179,  2471, 20765]), np.array([ 8526, 23348, 18940, 10454, 14109]), np.array([23184,  9857,  5525, 10003, 16701]), np.array([ 8733,  7452, 10675, 13453, 15555]), np.array([  774, 14581, 23812, 14576,  3984]), np.array([ 6261, 25444,  3617, 15542, 15717]), np.array([15875,  3421,  6203, 26751,   389]), np.array([  280, 28225, 23929, 27923, 27860]), np.array([23558,  4122,  3447, 14617, 25795]), np.array([ 9149,  3157,  8415, 25822, 17136]), np.array([ 7150, 24854,  6515, 19203, 23283]), np.array([ 7479, 11691, 17956, 22284,  5715]), np.array([ 7330, 19830,  2261, 27297, 21537]), np.array([21605, 12773,  1536,  1877,  1068]), np.array([18420,  6077,  5444,  5818,   897]), np.array([ 6359, 12901, 12520,  5703,  3937]), np.array([ 6239,  8637, 13871, 14774,   939]), np.array([ 6827, 16997, 16239, 21134,  2589]), np.array([25608, 19451,  5426,  1412, 24042]), np.array([20998, 18141,  5213,  4765, 24992]), np.array([15465,  4693,   371, 10942,  3446]), np.array([17343, 22611, 16544, 11952, 10357]), np.array([ 4547,  9092, 17559, 18945,  1979]), np.array([14907,  8946,  8648, 18913, 20928]), np.array([ 5546,  9787,  6691,  3264, 12164]), np.array([17332,   888, 23341, 11662, 14697]), np.array([  241,  7105, 24888, 18353, 14597]), np.array([15592,  3471,   541, 23159,  1669]), np.array([ 9130, 28691, 21231, 15956,  6746]), np.array([  158, 14910, 11006, 24289,  1517]), np.array([ 4452,    49,  1434, 15089, 16928]), np.array([19768,  3258,  7135,  5240,  1749]), np.array([15400, 19319,  5405, 26454, 12264]), np.array([11499,  5086, 12174,  4803, 21026]), np.array([27747,  4911, 15460,  7008, 27865]), np.array([17256, 28646, 26508, 27799, 28248]), np.array([ 5945,   989, 13577,  2579, 10126]), np.array([  68, 2742, 6218, 9758, 2247]), np.array([14467, 11643, 23159, 26576, 17651])]), ([(0, 1), (0, 16), (0, 23), (0, 24), (0, 23), (1, 16), (1, 23), (1, 24), (1, 23), (1, 25), (16, 1), (16, 23), (16, 24), (16, 23), (16, 25), (16, 26), (23, 1), (23, 16), (23, 24), (23, 25), (23, 26), (23, 27), (24, 1), (24, 16), (24, 23), (24, 23), (24, 25), (24, 26), (24, 27), (24, 28), (23, 1), (23, 16), (23, 24), (23, 25), (23, 26), (23, 27), (23, 28), (25, 1), (25, 16), (25, 23), (25, 24), (25, 23), (25, 26), (25, 27), (25, 28), (26, 16), (26, 23), (26, 24), (26, 23), (26, 25), (26, 27), (26, 28), (27, 23), (27, 24), (27, 23), (27, 25), (27, 26), (27, 28), (28, 24), (28, 23), (28, 25), (28, 26), (28, 27)], [np.array([20400,   970,  4250, 13923, 28597]), np.array([20048,  3179, 18533, 19246,  3971]), np.array([ 6275, 26076,   966, 10791,  8962]), np.array([12274, 16744, 13640, 18653,  9292]), np.array([  754,  2368, 28473, 15944, 22976]), np.array([16799,  2158, 15687, 16118,  2261]), np.array([23990, 10712,  3120,  8953,  4989]), np.array([ 3433,  1601, 28215, 14982, 20326]), np.array([16638, 11504,  4356,  4247, 24401]), np.array([21379,   393, 26218,   542,  6727]), np.array([18303,  7593, 14208,  2877, 21141]), np.array([ 7220, 24594, 11930, 21150,  5756]), np.array([26657,  2830,  5761, 21517, 15013]), np.array([  845, 24737, 13115, 21113, 25831]), np.array([14035, 17158, 22113, 23112, 13200]), np.array([  501, 16632,   591,  3468,  5522]), np.array([12959, 23938, 20876, 12079, 25779]), np.array([19149,  1180,  9755, 25829, 19825]), np.array([ 8932,  5294, 15734,  4512,  8922]), np.array([13674,  4828,  3538,     8, 16153]), np.array([ 9279,  2332, 15912, 24442,  8474]), np.array([17501, 15205, 15568,  4835, 11255]), np.array([23497,  9399, 13637,  6697, 25748]), np.array([19999,  1420,  5208, 10182,   332]), np.array([27225, 13230,  3451, 17584, 24224]), np.array([18011, 22918, 25743, 14421,  7435]), np.array([ 2514,  1846,  7416, 10360,  5519]), np.array([12061, 10859,  4051,  3433, 14152]), np.array([20166,  9170,  4608, 13782,  9689]), np.array([24938, 12359, 18216,  7570,  8796]), np.array([ 1915,  6583,  4771, 14670, 16621]), np.array([ 2303, 20438, 10681,   302,  8278]), np.array([ 9320, 11406, 11807, 11327,  6123]), np.array([11623,  5042,  3924, 20323, 17752]), np.array([ 6663, 25323, 22566,   141,  5780]), np.array([20076, 12709, 11799, 18652, 23727]), np.array([   43, 25682, 22516, 18343, 11646]), np.array([20802,   446, 20245, 21910,   186]), np.array([21136, 11007, 26414,  8794, 24822]), np.array([21233,  1508, 14825,  5486, 24611]), np.array([19355, 18504, 19249, 11488, 14912]), np.array([ 1758, 15663, 15676, 21604, 10347]), np.array([22088,  3385, 13451, 23311,  9401]), np.array([22806, 25739, 17213, 21627, 28443]), np.array([11616,  6961,  5648, 23501,  6033]), np.array([ 9667, 16374, 25502, 27107, 25958]), np.array([24013,  5626, 23093, 12309, 22489]), np.array([15940, 24295, 12386,  4288, 12599]), np.array([ 5751,  6100, 13004,  4713, 26381]), np.array([    1,  4951, 15927,   171,   999]), np.array([ 7311, 13371, 26837,  8715, 17216]), np.array([10037,   645,  3242, 24129, 28494]), np.array([ 3984,  5788, 13243,  1818,  4412]), np.array([ 9936,  6655,  9828, 20231,  7555]), np.array([ 4077,  4513, 20980,  2472, 16093]), np.array([22576, 26184,  5427, 13584,  2453]), np.array([21447, 23684, 23994, 22414,   449]), np.array([21850,  2937, 22665,  3041,  9162]), np.array([17043, 28343, 11243, 28353,  2445]), np.array([ 8596, 23371,  1244,  1202, 22093]), np.array([10528,  6678, 18762, 16641,    83]), np.array([18999,  6132, 15222,  6224, 15491]), np.array([ 5217, 20698, 18030, 11194, 22216])])]

In [42]:
pairs_batch = [batch[0] for batch in batches]
neg_contexts_batch = [batch[1] for batch in batches]

In [166]:
batch_list = []

for batch in batches:
    pairs = np.array(batch[0])
    negs = np.array(batch[1])
    negs = np.vstack((pairs[:, 0].repeat(negs.shape[1]), negs.ravel())).T
    
    pairs_arr = np.ones((pairs.shape[0], pairs.shape[1] + 1), dtype=int)
    pairs_arr[:, :-1] = pairs
    
    negs_arr = np.zeros((negs.shape[0], negs.shape[1] + 1), dtype=int)
    negs_arr[:, :-1] = negs
    
    all_arr = np.vstack((pairs_arr, negs_arr))
    batch_list.append(all_arr)

In [168]:
batch_array = np.vstack(batch_list)

In [171]:
np.random.shuffle(batch_array)

In [173]:
batch_array

array([[   26,  8715,     0],
       [    0, 16744,     0],
       [   23,  3538,     0],
       ...,
       [   16,     1,     1],
       [   16,    20,     1],
       [    1, 27609,     0]])

In [177]:
torch.LongTensor(batch_array[:, 2])

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
        1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
        1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,

In [120]:
pairs = np.array(batches[0][0])

In [121]:
negs = np.array(batches[0][1])
negs = np.vstack((pairs[:, 0].repeat(negs.shape[1]), negs.ravel())).T

In [153]:
pairs_arr = np.ones((pairs.shape[0], pairs.shape[1] + 1), dtype=int)
pairs_arr[:, :-1] = pairs

In [154]:
negs_arr = np.zeros((negs.shape[0], negs.shape[1] + 1), dtype=int)
negs_arr[:, :-1] = negs

In [156]:
pairs_arr

array([[ 0,  1,  1],
       [ 0, 16,  1],
       [ 0, 17,  1],
       [ 0, 18,  1],
       [ 0, 19,  1],
       [ 1, 16,  1],
       [ 1, 17,  1],
       [ 1, 18,  1],
       [ 1, 19,  1],
       [ 1, 18,  1],
       [16,  1,  1],
       [16, 17,  1],
       [16, 18,  1],
       [16, 19,  1],
       [16, 18,  1],
       [16, 20,  1],
       [17,  1,  1],
       [17, 16,  1],
       [17, 18,  1],
       [17, 19,  1],
       [17, 18,  1],
       [17, 20,  1],
       [17, 21,  1],
       [18,  1,  1],
       [18, 16,  1],
       [18, 17,  1],
       [18, 19,  1],
       [18, 20,  1],
       [18, 21,  1],
       [18, 22,  1],
       [19,  1,  1],
       [19, 16,  1],
       [19, 17,  1],
       [19, 18,  1],
       [19, 18,  1],
       [19, 20,  1],
       [19, 21,  1],
       [19, 22,  1],
       [18,  1,  1],
       [18, 16,  1],
       [18, 17,  1],
       [18, 19,  1],
       [18, 20,  1],
       [18, 21,  1],
       [18, 22,  1],
       [20, 16,  1],
       [20, 17,  1],
       [20, 1

In [158]:
output = np.vstack((pairs_arr, negs_arr))

In [161]:
pairs_arr.shape[0] + negs_arr.shape[0]

378

In [163]:
np.random.shuffle(output)

In [97]:
pairs_arr

array([[ 0,  1],
       [ 0,  1],
       [ 0,  1],
       [ 0,  1],
       [ 0,  1],
       [ 1,  1],
       [ 1,  1],
       [ 1,  1],
       [ 1,  1],
       [ 1,  1],
       [16,  1],
       [16,  1],
       [16,  1],
       [16,  1],
       [16,  1],
       [16,  1],
       [17,  1],
       [17,  1],
       [17,  1],
       [17,  1],
       [17,  1],
       [17,  1],
       [17,  1],
       [18,  1],
       [18,  1],
       [18,  1],
       [18,  1],
       [18,  1],
       [18,  1],
       [18,  1],
       [19,  1],
       [19,  1],
       [19,  1],
       [19,  1],
       [19,  1],
       [19,  1],
       [19,  1],
       [19,  1],
       [18,  1],
       [18,  1],
       [18,  1],
       [18,  1],
       [18,  1],
       [18,  1],
       [18,  1],
       [20,  1],
       [20,  1],
       [20,  1],
       [20,  1],
       [20,  1],
       [20,  1],
       [20,  1],
       [21,  1],
       [21,  1],
       [21,  1],
       [21,  1],
       [21,  1],
       [21,  1],
       [22,  1

In [93]:
neg_pairs_arr

array([[    0, 16893],
       [    0,  7785],
       [    0,  1742],
       [    0,  3798],
       [    0, 23385],
       [    0,  3775],
       [    0, 13900],
       [    0,  3872],
       [    0,  3314],
       [    0, 13042],
       [    0, 15290],
       [    0, 12362],
       [    0, 12592],
       [    0,  1421],
       [    0,  9167],
       [    0, 17752],
       [    0,  1635],
       [    0, 22254],
       [    0,  2127],
       [    0, 10099],
       [    0, 16791],
       [    0, 21850],
       [    0,  9243],
       [    0,  8241],
       [    0, 17682],
       [    1, 11652],
       [    1, 16975],
       [    1, 26739],
       [    1, 22641],
       [    1,  7904],
       [    1,  1612],
       [    1,  6095],
       [    1,  3986],
       [    1, 18929],
       [    1, 27609],
       [    1, 22895],
       [    1, 23120],
       [    1, 20906],
       [    1, 19153],
       [    1,  2648],
       [    1, 22959],
       [    1, 22241],
       [    1, 19072],
       [   

In [72]:
pairs_arr[:, 0].repeat(5)

array([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, 16,
       16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
       16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17,
       17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
       17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18,
       18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
       18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19,
       19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
       19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
       19, 19, 19, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
       18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,
       18, 18, 18, 18, 20

In [73]:
negs_arr.ravel()

array([16893,  7785,  1742,  3798, 23385,  3775, 13900,  3872,  3314,
       13042, 15290, 12362, 12592,  1421,  9167, 17752,  1635, 22254,
        2127, 10099, 16791, 21850,  9243,  8241, 17682, 11652, 16975,
       26739, 22641,  7904,  1612,  6095,  3986, 18929, 27609, 22895,
       23120, 20906, 19153,  2648, 22959, 22241, 19072, 28654,   138,
        7192, 18522, 19229, 15920,  7314,  2313,  9281,  5522, 12830,
       14786,  5529,  1174, 17789, 20110, 27401, 28122,  3385, 25407,
       11989,  9864,  6444,   414,  5932,  3441,  6810, 20004,  1668,
        5752, 25278,  9297, 14186,  4043, 24543, 13235, 17296,  6175,
         111,  4998,  2079,   238,  3597, 27751, 25005, 10147, 28054,
       17293,  1585, 11230, 27460, 27039, 13029,   199,  4485,  2235,
        8265, 21392, 27721,  4601, 22600,     9, 10200,  4131,  5318,
       18598, 16372, 14561, 10355, 13711,  2684, 20142,   777, 14639,
       21038, 24919,  4635, 23424, 27497,  7179,  2471, 20765,  8526,
       23348, 18940,

In [50]:
batches

[([(0, 1),
   (0, 16),
   (0, 17),
   (0, 18),
   (0, 19),
   (1, 16),
   (1, 17),
   (1, 18),
   (1, 19),
   (1, 18),
   (16, 1),
   (16, 17),
   (16, 18),
   (16, 19),
   (16, 18),
   (16, 20),
   (17, 1),
   (17, 16),
   (17, 18),
   (17, 19),
   (17, 18),
   (17, 20),
   (17, 21),
   (18, 1),
   (18, 16),
   (18, 17),
   (18, 19),
   (18, 20),
   (18, 21),
   (18, 22),
   (19, 1),
   (19, 16),
   (19, 17),
   (19, 18),
   (19, 18),
   (19, 20),
   (19, 21),
   (19, 22),
   (18, 1),
   (18, 16),
   (18, 17),
   (18, 19),
   (18, 20),
   (18, 21),
   (18, 22),
   (20, 16),
   (20, 17),
   (20, 18),
   (20, 19),
   (20, 18),
   (20, 21),
   (20, 22),
   (21, 17),
   (21, 18),
   (21, 19),
   (21, 18),
   (21, 20),
   (21, 22),
   (22, 18),
   (22, 19),
   (22, 18),
   (22, 20),
   (22, 21)],
  [array([16893,  7785,  1742,  3798, 23385]),
   array([ 3775, 13900,  3872,  3314, 13042]),
   array([15290, 12362, 12592,  1421,  9167]),
   array([17752,  1635, 22254,  2127, 10099]),
   array

In [14]:
emb_center = skipgram.center_embeddings(centers)  # Get embeddings for center word
emb_context = skipgram.context_embeddings(contexts)  # Get embeddings for context word
emb_neg_context = skipgram.context_embeddings(neg_contexts)  # Get embeddings for negative context words

In [16]:
emb_center.shape

torch.Size([514, 8])

In [17]:
emb_context.shape

torch.Size([514, 8])

In [18]:
emb_neg_context.shape

torch.Size([514, 5, 8])

### Save torch params

In [21]:
torch.save(skipgram.state_dict(), '../model/skipgram_sample.pt')

In [22]:
model = SkipGram(sequences.n_unique_tokens, emb_dim).to(device)

In [23]:
model.load_state_dict(torch.load('../model/skipgram_sample.pt'))

<All keys matched successfully>

In [24]:
model.eval()

SkipGram(
  (center_embeddings): Embedding(7757, 8, sparse=True)
  (context_embeddings): Embedding(7757, 8, sparse=True)
)

### Check with validation

In [5]:
val = pd.read_csv('../data/{}_edges_val.csv'.format(dataset), dtype={'product1': 'object', 'product2': 'object'})

In [6]:
sample_idx = np.random.randint(0, val.shape[0], 100000)

In [7]:
val_samp = val.iloc[sample_idx]

In [8]:
val_samp.head()

Unnamed: 0,product1,product2,edge
1055342,b002goovnk,b008mrzsh8,1
535317,b00aodd3js,b00f0rrcqi,1
737360,b005abj0h8,b00dzrguao,1
1333506,b0002exjra,b000067rrx,0
2376672,b00dziz6qc,b008mogskm,0


In [63]:
val_samp = pd.read_csv('../data/books_edges_train.csv', nrows=100, dtype={'product1': 'object', 'product2': 'object'})
val_samp['edge'] = np.where(val_samp['weight'] > 1, 1, 0)
val_samp.to_csv('../data/books_edges_train_samp.csv')

In [50]:
val_samp = pd.read_csv('../data/books_edges_val_samp.csv', dtype={'product1': 'object', 'product2': 'object'})

In [51]:
word2id = load_model('../model/word2id')

2019-12-03 15:39:30,897 - Model loaded from: ../model/word2id (Size: 969863 bytes)


In [52]:
word2id_func =  np.vectorize(sequences.get_product_id)

In [53]:
val_samp['product1_id'] = word2id_func(val_samp['product1'].values)
val_samp['product2_id'] = word2id_func(val_samp['product2'].values)

In [54]:
val_samp = val_samp[(val_samp['product1_id'] > -1) & (val_samp['product2_id'] > -1)]

In [55]:
val_samp

Unnamed: 0.1,Unnamed: 0,product1,product2,edge,product1_id,product2_id
2834,2460300,0062060244,1578643031,1,2516,7166
5158,4981598,0060501960,0439441609,1,167,259
5954,2408135,1606903888,1616553707,1,3733,3745
6342,1651623,0989103137,1492206601,1,1519,1480
6352,7430236,1440213747,1607058529,1,5796,5781
...,...,...,...,...,...,...
96025,2622723,0071819541,0875632157,1,469,610
96551,4896177,0060580461,0812980557,1,3226,3115
97112,5555367,014311753X,0385349580,1,1188,1053
97705,4657423,0822572257,0824603621,1,6256,6231


In [43]:
product1_emb = model.get_center_emb(torch.LongTensor(product1_id))
product2_emb = model.get_center_emb(torch.LongTensor(product2_id))

RuntimeError: index out of range: Tried to access index -1 out of table with 7756 rows. at ../aten/src/TH/generic/THTensorEvenMoreMath.cpp:418

In [32]:
product1_emb

tensor([[-0.2244,  0.1584, -0.2018,  0.2065,  0.2086, -0.2308,  0.2165, -0.2227],
        [-0.2875, -0.1795,  0.0093,  0.2784,  0.2273, -0.2611,  0.0504, -0.2687],
        [-0.2470, -0.2357, -0.0490, -0.3754,  0.3267,  0.3455, -0.0134, -0.3363],
        [-0.2553, -0.1330, -0.1388, -0.1701,  0.2362, -0.2577,  0.0798, -0.0122],
        [-0.1892, -0.2706, -0.3198, -0.3673, -0.3417, -0.3226,  0.1317, -0.3117],
        [ 0.0570,  0.0895, -0.1997, -0.2253,  0.0674, -0.2361, -0.2316, -0.2013],
        [-0.1330, -0.2262, -0.3651, -0.0874, -0.0853, -0.3225, -0.3004, -0.2572],
        [-0.2493, -0.0244, -0.2456,  0.2454, -0.0101, -0.2079, -0.1544, -0.0653],
        [-0.3194, -0.3137,  0.3596, -0.2866, -0.2485,  0.0837, -0.2537, -0.3440],
        [-0.2278, -0.2868, -0.3016,  0.0993, -0.0986, -0.3850,  0.3154, -0.3355]],
       grad_fn=<EmbeddingBackward>)

In [None]:
cos_sim = F.cosine_similarity(product1_emb, product2_emb)
cos_sim

In [None]:
cos_sim.detach().numpy()

In [None]:
x = np.array([-0.2257,  0.2379, -0.2139,  0.2115,  0.2185, -0.2326,  0.2114, -0.2235])
y = np.array([-0.2150, -0.1220,  0.0284,  0.2917,  0.1297, -0.2589, -0.1423, -0.2585])

In [None]:
np.inner(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

In [None]:
product1_tensor

In [None]:
print(emb)

In [None]:
skipgram.state_dict()