# Learning item embeddings with Gensim's Word2Vec

< some kind of discussion > 


In [24]:
import collections
import itertools
from copy import deepcopy 

import numpy as np

from gensim.models.word2vec import Word2Vec
from gensim.models.callbacks import CallbackAny2Vec
from ray import tune
from ray.tune.schedulers import ASHAScheduler

from recsys.data import load_ecomm, train_test_split
from recsys.metrics import recall_at_k
from recsys.utils import absolute_filename

## Quick exploration of the data (products in the Online Retail dataset) 

There's an entire notebook on this in the Session-Based Recommenders repo. Just pull some of the good bits and put them here.

## Setting up training

What do we need to train good item embeddings? 

A few ingredients:
- logging
- hyperparameter optimization
- early stopping


### Logging with Gensim Callbacks

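A loss logger is pretty standard, but it turns out that Gensim's loss logging for the word2vec model is a little buggy: `get_latest_training_loss()` reports a running total that keeps growing from epoch to epoch rather than the loss for the latest epoch, so the callback below subtracts the previous total to recover a per-epoch value. Even then, this metric often provides little value, because it only measures the training objective on the training data, not how well the embeddings serve the task we actually care about (TODO: expand).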

In [25]:
class LossLogger(CallbackAny2Vec):
    '''Report the Word2Vec training loss at the end of each epoch.

    Gensim's ``get_latest_training_loss()`` returns a running total that keeps
    growing over the epochs of a ``train()`` call. This callback therefore
    remembers the previous total and records the difference as the epoch's loss.

    Attributes:
        epoch: number of epochs completed so far.
        previous_loss: running total seen at the end of the previous epoch.
        training_loss: per-epoch losses, in the order they were observed.
    '''
    def __init__(self):
        self.epoch = 0
        self.previous_loss = 0
        self.training_loss = []

    def on_train_begin(self, model):
        # Gensim restarts its running loss at 0 at the beginning of every train()
        # call. Reset our baseline too, so reusing this logger across calls does
        # not produce a negative loss for the first epoch of the next call.
        self.previous_loss = 0

    def on_epoch_end(self, model):
        # The value from Word2Vec is cumulative within a train() call; subtract the
        # previous total to get a value closer to the loss of this epoch alone.
        cumulative_loss = model.get_latest_training_loss()
        loss = cumulative_loss - self.previous_loss
        self.previous_loss = cumulative_loss
        self.training_loss.append(loss)
        print(f' Loss: {loss}')
        self.epoch += 1

Word2vec is used to learn embeddings for individual words or word pieces, but these embeddings are rarely the ultimate goal. Instead, they are usually created with some downstream task in mind. If that task is known, a more useful logging metric is one that measures performance on that task. Such metrics could include:

- example
- example
- example



In [26]:
class RecallAtKLogger(CallbackAny2Vec):
    '''Report Recall@K on a validation set at the end of each epoch.

    Args:
        validation_set: held-out data, passed unchanged to ``recall_at_k``.
        k: cutoff K for Recall@K.
        ray_tune: if True, send each score to Ray Tune via ``tune.report`` and
            print nothing; otherwise append it to ``recall_scores`` and print it.
        save_model: if True, save the live model on epochs 0, 10, 20, ...
    '''
    def __init__(self, validation_set, k, ray_tune=False, save_model=False):
        self.epoch = 0
        self.recall_scores = []
        self.validation = validation_set
        self.k = k
        self.tune = ray_tune
        self.save = save_model

    def on_epoch_begin(self, model):
        if not self.tune:
            print(f'Epoch: {self.epoch}', end='\t')

    def on_epoch_end(self, model):
        # Score a snapshot of the word vectors rather than the live model. Copying
        # only `model.wv` is all `recall_at_k` receives anyway, and avoids
        # deep-copying the whole model (training weights, callbacks and this
        # logger's validation set) on every epoch.
        wv = deepcopy(model.wv)
        # Drop any cached unit-normalised vectors so similarity queries are
        # recomputed from this epoch's weights. gensim 4.x caches them in `norms`,
        # gensim 3.x in `vectors_norm`; `hasattr` skips whichever one is absent.
        for cache_attr in ('norms', 'vectors_norm'):
            if hasattr(wv, cache_attr):
                setattr(wv, cache_attr, None)

        # Every 10 epochs (starting with epoch 0), save the live model.
        if self.save and self.epoch % 10 == 0:
            # NOTE(review): MODEL_DIR is not defined in this notebook section. It
            # must be set (ending in a path separator, since it is concatenated
            # directly) before using save_model=True, or this raises NameError.
            model.save(absolute_filename(f"{MODEL_DIR}w2v_{self.epoch}.model"))

        ratk_score = recall_at_k(self.validation, wv, self.k)

        if self.tune:
            tune.report(recall_at_k=ratk_score)
        else:
            self.recall_scores.append(ratk_score)
            # Label with the actual cutoff instead of a hard-coded 10.
            print(f' Recall@{self.k}: {ratk_score}')
        self.epoch += 1

Now we can create a wrapper to train our word2vec model. It can be trained with or without the callbacks above, for additional flexibility. After training, the function returns the all-important embedding vectors.

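Potential pitfall: stopping and restarting training -- don't do it! Just train all at once. Why: each call to `train()` decays the learning rate from its starting value (`alpha`) down to its final value (`min_alpha`). Stopping and restarting therefore turns one smooth decay into a saw-tooth learning-rate schedule.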

In [27]:
def train_w2v(train_data, params:dict, callbacks=None, model_name=None):
    '''Train (or continue training) a Word2Vec model and return its word vectors.

    Args:
        train_data: training corpus, passed to ``build_vocab`` and ``train``.
        params: keyword arguments for the ``Word2Vec`` constructor. Ignored when
            ``model_name`` is given.
        callbacks: optional iterable of ``CallbackAny2Vec`` instances to run
            during training.
        model_name: path of a saved model to load and train further instead of
            building a new one. The loaded model's vocabulary is reused as-is.

    Returns:
        The trained model's ``wv`` (keyed word vectors).
    '''
    callbacks = list(callbacks) if callbacks else []

    if model_name:
        # Load a saved model for additional training; its vocabulary is kept.
        model = Word2Vec.load(model_name)
    else:
        model = Word2Vec(callbacks=callbacks, **params) if callbacks else Word2Vec(**params)
        model.build_vocab(train_data)

    # Pass the callbacks to train() explicitly. Otherwise a loaded model never
    # sees them, and gensim 4.x's train() only runs the callbacks handed to it
    # (constructor callbacks fire only when the corpus is given to the constructor).
    model.train(
        train_data,
        total_examples=model.corpus_count,
        epochs=model.epochs,
        compute_loss=True,
        callbacks=callbacks,
    )
    return model.wv

In [28]:
def tune_w2v(config):
    '''Ray Tune trainable: train Word2Vec on one dataset and report Recall@K.

    ``config`` must contain ``'dataset'`` (one of ``'recsys15'``, ``'aotm'``,
    ``'ecomm'``) and ``'k'`` (the Recall@K cutoff). Every other key is passed to
    the ``Word2Vec`` constructor. Scores are reported to Ray Tune each epoch by
    ``RecallAtKLogger``. ``config`` itself is left unmodified.

    Raises:
        ValueError: if ``config['dataset']`` is not a known dataset name.
    '''
    dataset = config['dataset']
    # NOTE(review): only load_ecomm is imported at the top of this notebook;
    # load_recsys15 and load_aotm must be defined/imported elsewhere, or these
    # branches raise NameError.
    if dataset == 'recsys15':
        sessions = load_recsys15()
    elif dataset == 'aotm':
        sessions = load_aotm()
    elif dataset == 'ecomm':
        sessions = load_ecomm()
    else:
        # Fail the trial visibly instead of returning without reporting a metric.
        raise ValueError(
            f"{dataset} is not a valid dataset name. "
            "Please choose from recsys15, aotm or ecomm"
        )

    # The test split is not used during tuning.
    train, _test, valid = train_test_split(sessions, test_size=1000)
    ratk_logger = RecallAtKLogger(valid, k=config['k'], ray_tune=True)

    # Word2Vec hyperparameters are all keys except the ones consumed above. Build
    # a new dict rather than popping from `config`, which belongs to the caller.
    w2v_params = {key: value for key, value in config.items() if key not in ('dataset', 'k')}
    train_w2v(train, params=w2v_params, callbacks=[ratk_logger])

### Hyperparameter optimization and early stopping with Ray Tune

In [29]:
# Define the hyperparameter search space for the Word2Vec algorithm.
# "dataset" and "k" configure tune_w2v itself; every other key is passed to Word2Vec.
search_space = {
    "dataset": "ecomm",
    "k": 10,
    #"size": tune.grid_search(list(np.arange(10, 106, 6))),
    #"window": tune.grid_search(list(np.arange(1, 22, 3))),
    #"ns_exponent": tune.grid_search(list(np.arange(-1, 1.2, .2))),
    #"alpha": tune.grid_search([0.001, 0.01, 0.1]),
    # Plain Python ints 1, 4, ..., 19 (same values as np.arange(1, 22, 3)).
    "negative": tune.grid_search(list(range(1, 22, 3))),
    "iter": 10,
    "min_count": 1,
    "workers": 6,
    "sg": 1,
}

# The ASHA scheduler stops underperforming trials early, using the metric and
# mode passed to tune.run below.
# NOTE(review): with "iter": 10 each trial reports at most 10 times, so
# grace_period=10 gives ASHA little or no chance to stop a trial before it
# finishes on its own.
asha_scheduler = ASHAScheduler(max_t=100, grace_period=10)

# Stopping criteria: effectively unlimited, so trials end when Word2Vec finishes
# its epochs or when ASHA stops them.
stopping_criteria = {"training_iteration": 9999}

# Perform the hyperparameter sweep with Ray Tune.
analysis = tune.run(
    tune_w2v,
    local_dir=absolute_filename("ray_results"),
    metric="recall_at_k",
    mode="max",
    scheduler=asha_scheduler,
    stop=stopping_criteria,
    # grid_search already enumerates every value; run each grid point once.
    num_samples=1,
    verbose=1,
    # NOTE(review): each trial reserves 1 CPU, but Word2Vec runs with
    # search_space["workers"] threads, so concurrent trials may oversubscribe CPUs.
    resources_per_trial={
        "cpu": 1,
        "gpu": 0
    },
    config=search_space,
)

2021-07-13 13:34:42,795	INFO tune.py:450 -- Total run time: 60.91 seconds (60.89 seconds for the tuning loop).


In [30]:
# Summarise the sweep: one row per trial, with the metrics it reported and its
# config values as config/* columns.
analysis.dataframe()

Unnamed: 0,recall_at_k,time_this_iter_s,done,timesteps_total,episodes_total,training_iteration,experiment_id,date,timestamp,time_total_s,...,iterations_since_restore,trial_id,config/dataset,config/iter,config/k,config/min_count,config/negative,config/sg,config/workers,logdir
0,0.231,1.542614,False,,,6,d6b9e5d4840a4117abb5b5c4582f392d,2021-07-13_13-33-53,1626201233,9.651698,...,6,d9252_00000,ecomm,10,10,1,1,1,6,/Users/mbeck/Projects/gensim_amp/ray_results/t...
1,0.218,2.38223,True,,,10,fb8beb277d6a4cfa9488a838c028ed28,2021-07-13_13-34-09,1626201249,25.74769,...,10,d9252_00001,ecomm,10,10,1,4,1,6,/Users/mbeck/Projects/gensim_amp/ray_results/t...
2,0.223,3.050999,True,,,10,a2865b42445646a1b0b7153a61d6843d,2021-07-13_13-34-20,1626201260,35.698053,...,10,d9252_00002,ecomm,10,10,1,7,1,6,/Users/mbeck/Projects/gensim_amp/ray_results/t...
3,0.215,3.226812,True,,,10,6cf613601c9f41f89f2a3c3d74ed9028,2021-07-13_13-34-27,1626201267,42.866144,...,10,d9252_00003,ecomm,10,10,1,10,1,6,/Users/mbeck/Projects/gensim_amp/ray_results/t...
4,0.211,3.913614,True,,,10,03b03c28984a41cd88fa6d853379b0eb,2021-07-13_13-34-34,1626201274,50.278771,...,10,d9252_00004,ecomm,10,10,1,13,1,6,/Users/mbeck/Projects/gensim_amp/ray_results/t...
5,0.214,3.129818,True,,,10,e8f20226ca704e04916ab87c286a7f74,2021-07-13_13-34-39,1626201279,55.460817,...,10,d9252_00005,ecomm,10,10,1,16,1,6,/Users/mbeck/Projects/gensim_amp/ray_results/t...
6,0.219,3.349319,False,,,9,f90429b07b974a6c95e49a9d0000ff16,2021-07-13_13-34-40,1626201280,56.162628,...,9,d9252_00006,ecomm,10,10,1,19,1,6,/Users/mbeck/Projects/gensim_amp/ray_results/t...


### Visualize ... something? 

The Online Retail dataset has *some* metadata, but nothing categorical that we could use to show "similarity".

Could visualize the early stopping mechanism. 