In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import Any
import pickle
from pathlib import Path

In [3]:
torch.manual_seed(100)

<torch._C.Generator at 0x7f7ee01b8530>

In [4]:
def save_model(model: Any, model_path: str) -> None:
    """
    Saves model in gzip format

    Args:
        model: Model to be saved
        model_path: Path to save model to
        
    Returns:
        (None)
    """
    import gzip
    with gzip.open(model_path, "wb") as f:
        pickle.dump(model, f)

    print(f'Model saved to {model_path}')

## Create the MF Model

In [5]:
def regularize_l2(array):
    loss = torch.sum(array ** 2)
    return loss

class MF(nn.Module):
    def __init__(self, emb_size, emb_dim, c_vector=1e-6):
        super().__init__()
        self.emb_size = emb_size # size of the dictionary of embeddings
        self.emb_dim = emb_dim # size of each embedding vector
        self.c_vector = c_vector
        
        # layers
        self.embedding = nn.Embedding(emb_size, emb_dim)
        self.sig = nn.Sigmoid()
        
        # loss
        self.bce = nn.BCELoss()
        
        print(f'Model initialized: {self}')
        
    def forward(self, product1, product2):
        emb_product1 = self.embedding(product1)
        emb_product2 = self.embedding(product2)
        interaction = self.sig(torch.sum(emb_product1*emb_product2, dim = 1, dtype = torch.float))
        return interaction
    
    
    def loss(self, pred, label):
        mf_loss = self.bce(pred, label)
        
        # L2 regularization
        product_prior = refularize_l2(self.embedding.weight) * self.c_vector
        
        loss_total  = mf_loss + product_prior # loss + regularization 
        
        return loss_total

In [6]:
MF(1000, 12)

Model initialized: MF(
  (embedding): Embedding(1000, 12)
  (sig): Sigmoid()
  (bce): BCELoss()
)


MF(
  (embedding): Embedding(1000, 12)
  (sig): Sigmoid()
  (bce): BCELoss()
)

## Create the data loader

In [7]:
import itertools
from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from torch.utils.data import Dataset

In [8]:
from torch.utils.data import DataLoader

In [25]:
class Sequences:
    NEGATIVE_SAMPLE_TABLE_SIZE = 1e7
    WINDOW = 5
    
    def __init__(self, sequence_path: str, val_path: str, subsample: float = 0.001, power: float = 0.75):
        """
        Intialize the dataset object
        """
        self.negative_idx = 0
        self.n_unique_tokens = 0
        
        self.sequences = np.load(sequence_path).tolist()
        self.n_sequences = len(self.sequences)
        print(f'# Sequences = {self.n_sequences}')
        
        self.val = pd.read_csv(val_path)
        print(f'# Validation data = {self.val.shape}')
        
        self.word_freq = self.get_word_freq()
        
        self.word2id, self.id2word = self.get_mapping_dicts()
        self.add_val_product_to_mapping_dicts()
        self.n_unique_tokens = len(self.word2id)
        print(f'# Tokens = {self.n_unique_tokens}')
        
        sequence_file_name = Path(sequence_path).resolve().stem
        save_model(self.word2id, f'../data/processed/{sequence_file_name}_word2id')
        save_model(self.id2word, f'../data/processed/{sequence_file_name}_id2word')
        
        self.sequences = self.convert_sequence_to_id()
        self.word_freq = self.convert_word_freq_to_id()
        
        self.discard_probs = self.get_discard_probs(sample = subsample)
        
        self.neg_table = self.get_negative_sample_table(power = power)
        
    def get_word_freq(self) -> Counter:
        """
        Returns a dictionary of word frequencies
        """
        
        seq_flat = list(itertools.chain.from_iterable(self.sequences)) # flatten the array
        
        word_freq = Counter(seq_flat)
        
        return word_freq
    
    def get_mapping_dicts(self):
        word2id = dict()
        id2word = dict()
        
        wid = 0
        for w,c in self.word_freq.items():
            word2id[w] = wid
            id2word[wid] = w
            wid += 1
        
        return word2id, id2word
    
    def add_val_product_to_mapping_dicts(self):
        val_product_set = set(self.val['product1'].values).union(set(self.val['product2'].values))
        
        print(f'Size of word2id before adding val product : {len(self.word2id)}')
        wid = max(self.word2id.values()) + 1
        for w in val_product_set:
            if w in self.word2id:
                self.word2id[w] = wid
                self.id2word[wid] = w
                wid +=1
        
        self.val = None # free up space
        print(f'Size of the word2id after adding val product : {len(self.word2id)}')
        
                
    def convert_sequence_to_id(self):
        return np.vectorize(self.word2id.get)(self.sequences)
    
    def get_product_id(self, x):
        return self.word2id.get(x, -1)
    
    def convert_word_freq_to_id(self):
        return {self.word2id[k] : v for k ,v  in self.word_freq.items()}
    
    def get_discard_probs(self, sample = 0.001):
        """
        Returns a dictionary of words and their associated discard probability, 
        word should ne discarded if np.random.rand() < probability
        """
        
        # convert to array
        word_freq = np.array(list(self.word_freq.items()), dtype=np.float64)
        
        # convert to probability
        word_freq[:, 1] = word_freq[:, 1] / word_freq[:, 1].sum()
        
        # perform subsampling 
        # http://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/
        word_freq[:, 1] = (np.sqrt(word_freq[:, 1]/ sample) + 1) * (sample / word_freq[:, 1]) 
        
        # get dict 
        discard_probs = {int(k) : v for k, v in word_freq.tolist()}
        
        return discard_probs
    
    def get_negative_sample_table(self, power=0.75):
        """
        Returns a table with size = NEGATIVE_SAMPLE_TABLE_SIZE of nagative samples which can be selected via indexing. 
        """
        
        # COnvert to array 
        word_freq = np.array(list(self.word_freq.items()), dtype = np.float)
        
        # adjust the power
        word_freq[:, 1] = word_freq[:, 1] ** power
        
        # Get probabilities
        word_freq_sum = word_freq[:, 1] ** power
        word_freq[:, 1] = word_freq[:, 1] / word_freq_sum
        
        # Multiply probabilities by sample table size
        word_freq[:, 1] = np.round(word_freq[:, 1] * self.NEGATIVE_SAMPLE_TABLE_SIZE)
        
        # Convert to int 
        word_freq = word_freq.astype(int).tolist()
        
        # create the sample table
        sample_table = [[tup[0]]*tup[1] for tup in word_freq] # repeating the index (wrod_id) by proportion of their frequency (more frequent words are more probable for sampling)
        sample_table = np.array(list(itertools.chain.from_iterable(sample_table)))
        np.random.shuffle(sample_table)

        return sample_table
        
    
    def get_pairs(self, idx, window = 5):
        pairs = []
        sequence = self.sequences[idx]
        
        for center_idx, node in enumerate(sequence):
            for i in range(-window, window + 1):
                context_idx = center_idx + i
                if (context_idx > 0) and (context_idx < len(sequence)) and (node != sequence[context_idx]) and (np.random.rand() < self.discard_probs[sequence[context_idx]]):
                    pairs.append((node, sequence[context_idx]))
    
        
        return pairs
    
    def get_all_center_context_pair(self, window = 5) -> List[Tuple[int, int]]:
        """
        Returns a list of tuples (center, context).
        
        Args: 
            window:
            
        Returns:
        
        """
        
        pairs = []
        
        for sequence in self.sequences:
            for center_idx, node in enumerate(sequence):
                context_idx = center_idx + i
                if (0 <= context_idx < len(sequence)) \
                    and node != sequence[context_idx] \
                    and np.random.rand() < self.discard_probs[sequence[context_idx]]:
                    pairs.append((node, sequence[context_idx]))
                        
            
        return pairs
    
    
    def get_negative_samples(self, context, sample_size = 5) -> np.array:
        """
        Returns a list of negative samples, where len = sample_size.
        
        Args:
        
            sample_size:
            
        """
        
        while True:
            neg_sample = self.neg_table[self.negative_idx:self.negative_idx + sample_size]
            
            self.negative_idx = (self.negative_idx + sample_size) % len(self.neg_table)
            
            if len(neg_sample) != sample_size:
                neg_sample = np.concatenate((neg_sample, 
                                             self.neg_table[:self.negative_idx]))
                
            
            if not context in neg_sample:
                return neg_sample    

In [26]:
class SequencesDataset(Dataset):
    def __init__(self, sequences: Sequences, neg_sample_size = 5):
        self.sequences = sequences
        self.neg_sample_size = neg_sample_size
        
    def __len__(self):
        return self.sequences.n_sequences
    
    def __getitem__(self, idx):
        pairs = self.sequences.get_pairs(idx)
        neg_samples = []
        for center, context in pairs:
            neg_samples.append(self.sequences.get_negative_samples(context))
        
        return pairs, neg_samples
    
    @staticmethod
    def collate(batches):
        pairs_batch = [batch[0] for batch in batches]
        neg_contexts_batch = [batch[1] for batch in batches]
        
        pairs_batch = list(itertools.chain.from_iterable(pairs_batch))
        neg_contexts = list(itertools.chain.from_iterable(neg_contexts_batch))
        
        centers = [center for center, _ in pairs_batch]
        contexts = [context for _, context in pairs_batch]
        
        return torch.LongTensor(centers), torch.LongTensor(contexts), torch.LongTensor(neg_contexts)
    
    @staticmethod
    def collate_for_mf(batches):
        batch_list = []
        
        for batch in batches:
            pairs = np.array(batch[0])
            negs = np.array(batch[1])
            negs = np.vstack((pairs[:, 0].repeat(negs.shape[1]), negs.ravel())).T
            
            pairs_arr = np.ones((pairs.shape[0], pairs.shape[1] + 1), dtype=int) # 2d
            pairs_arr[:, :-1] = pairs
            
            negs_arr = mp.zeros((negs.shape[0], negs.shape[1] + 1), dtype=int) # 2d
            negs_arr[:, :-1] = negs
            
            all_arr = np.vstack((pairs, negs_arr)) # 2d
            batch_list.append(all_arr)
            
        batch_array = np.vstack(batch_list)
        
        
        return (torch.LongTensor(batch_array[:, 0]),torch.LongTensor(batch_array[:, 1]),
                torch.LongTensor(batch_array[:, 2]))
    
    

## Testing DataLoader

In [27]:
read_path = '../data/processed/meta_Electronics_random_walks.npy'
val_path = '../data/interim/meta_Electronics_edges_val.csv'

In [39]:
shuffle = True
emb_dim = 128
epochs = 5
initial_lr = 0.01

In [28]:
sequences = Sequences(read_path, val_path)

# Sequences = 4649780
# Validation data = (1440998, 3)
Size of word2id before adding val product : 464978
Size of the word2id after adding val product : 464978
# Tokens = 464978
Model saved to ../data/processed/meta_Electronics_random_walks_word2id
Model saved to ../data/processed/meta_Electronics_random_walks_id2word


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  word_freq = np.array(list(self.word_freq.items()), dtype = np.float)


In [33]:
sequences.word_freq[sequences.word2id['b00f37z8q6']]

40

In [37]:
dataset = SequencesDataset(sequences)

In [40]:
dataloader = DataLoader(dataset, batch_size = 16, shuffle = shuffle, num_workers = 16, collate_fn = dataset.collate_for_mf)

In [43]:
dataset[0]

([(512542, 521805),
  (512542, 621400),
  (512542, 507254),
  (512542, 668318),
  (512542, 507254),
  (521805, 621400),
  (521805, 507254),
  (521805, 668318),
  (521805, 507254),
  (521805, 628220),
  (621400, 521805),
  (621400, 507254),
  (621400, 668318),
  (621400, 507254),
  (621400, 628220),
  (621400, 614492),
  (507254, 521805),
  (507254, 621400),
  (507254, 668318),
  (507254, 628220),
  (507254, 614492),
  (507254, 628220),
  (668318, 521805),
  (668318, 621400),
  (668318, 507254),
  (668318, 507254),
  (668318, 628220),
  (668318, 614492),
  (668318, 628220),
  (668318, 474220),
  (507254, 521805),
  (507254, 621400),
  (507254, 668318),
  (507254, 628220),
  (507254, 614492),
  (507254, 628220),
  (507254, 474220),
  (628220, 521805),
  (628220, 621400),
  (628220, 507254),
  (628220, 668318),
  (628220, 507254),
  (628220, 614492),
  (628220, 474220),
  (614492, 621400),
  (614492, 507254),
  (614492, 668318),
  (614492, 507254),
  (614492, 628220),
  (614492, 628220),
