# Building Word Embeddings with PyTorch

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn  as  nn
import torch.autograd as autograd
import torch.optim as optim
import torch.nn.functional as F  
import re
import itertools
from collections import Counter
import random, math
from numpy.random import multinomial
import nltk
from nltk.corpus import stopwords
from tqdm import trange
from textblob import TextBlob
from nltk.stem.porter import PorterStemmer

#### Check GPU Status

In [2]:
print(torch.cuda.is_available())
print(torch.cuda.device_count())

True
1


## Data
- Trip Advisor hotel reviews
- Sci-fi stories

In [3]:
class DataLoader:
    """Download the files and load them into a dataframe.
    
    Attributes
    ----------
    url: Url of the online dataset
    """
    def __init__(self, url):
        self.url = url

    def load_csv(self):
        url = self.url
#         url = 'https://drive.google.com/uc?id=' + url.split('/')[-2]
        df = pd.read_csv(url)
        return df
    
    def load_txt(self):
        with open(self.url) as f:
            df = f.read()
        return df

Load data from files by inheriting DataLoader.

In [4]:
hotel_url = 'data/hotel.csv'
hotel_loader = DataLoader(hotel_url)
hotel_data = hotel_loader.load_csv()

scifi_url = 'data/scifi.txt'
scifi_loader = DataLoader(scifi_url)
scifi_data = scifi_loader.load_txt()

In [6]:
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

In [7]:
# pip install textblob

#### Text Data Preprocessing

In [8]:
class TextCleaner:
    """Clean text data.

    This Cleaner performes text pre-processing through a series of 
    operations while creating a NLP application. The operations include
    Lowercasing, Removing stopwords, Filtering Letters, Correcting Spellings and Stemming.
    """
    def __init__(self):
        self.ps = PorterStemmer()
    
    def clean_hotel_data(self, data):
        corpus = []
        stopwords_english = stopwords.words('english')
        for text in data:
            text = text.lower() # Lowercasing
            text.replace('\n', ' ') 
            text = re.sub('[^a-z ]+', '', text) # Filtering Letters
            textblob_ = TextBlob(text)
            text = textblob_.correct().string   # Correcting Spellings
            new_text = [self.ps.stem(word) for word in text.split()]    # Stemming
            text = ' '.join(new_text)
            corpus.append([w for w in text.split() if w != '' and (self.ps.stem(w) in stopwords_english) == False]) # Removing stopwords
        print("Hotel data preprocessing done.")
        return corpus
    
    def clean_scifi_text(self, text):
        stopwords_english = stopwords.words('english')
        corpus = []
        text = text.lower() # Lowercasing
        text.replace('\n', ' ') 
        text = re.sub('[^a-z ]+', '', text) # Filtering Letters
        textblob_ = TextBlob(text)
        text = textblob_.correct().string   # Correcting Spellings
        new_text = [self.ps.stem(word) for word in text.split()]    # Stemming
        text = ' '.join(new_text)
        corpus.append([w for w in text.split() if w != '' and (self.ps.stem(w) in stopwords_english) == False]) # Removing stopwords
        print("Scifi data preprocessing done.")
        return corpus

Implement Data Preprocessing

In [9]:
cleaner = TextCleaner()
hotel_corpus = cleaner.clean_hotel_data(hotel_data['Review'])
scifi_corpus = cleaner.clean_scifi_text(scifi_data)

Hotel data preprocessing done.
Scifi data preprocessing done.


## Context Generator
- **Subsampling Frequent Words** \\
We balanced the word occurences in the data by calculating the probability of keeping the word i in the corpus. \\
$P(w_i) = \frac{10^{-3}}{p_i}(\sqrt{10^3p_i}+1)$
- **Negative Samples** \\
To accelerate the learning, we chose the sigmoid function as output. Meanwhile, We generated negative examples in case the network does not learn from negative samples. The probability of sampling a negative context word is defined as follows: \\
$P(w_i)=\frac{|w_i|^\frac{3}{4}}{\sum^{n}_{j=1}|w_j|^\frac{3}{4}}$ \\
- **Building BoW** \\
In this model, a text is represented as the bag (multiset) of its words, disregarding grammar and even word order but keeping multiplicity.

In [10]:
class ContextGenerator:
    """Generate the context

    This Generator performs frequent words subsampling and 
    negative samples calculating. It builds a context dataset 
    with negative examples.
    """  
    def __init__(self, corpus, sample_size, context_width):
        self.corpus = corpus
        self.sample_size = sample_size
        self.w = context_width
    
    def subsample_frequent_words(self):
        filtered_corpus = []
        word_counts = dict(Counter(list(itertools.chain.from_iterable(self.corpus))))
        sum_word_counts = sum(list(word_counts.values()))
        word_counts = {word: word_counts[word]/float(sum_word_counts) for word in word_counts}
        for text in self.corpus:
            filtered_corpus.append([])
            for word in text:
                if random.random() < (1+math.sqrt(word_counts[word] * 1e3)) * 1e-3 / float(word_counts[word]):
                    filtered_corpus[-1].append(word)
        return filtered_corpus
    
    def sample_negative(self):
        sample_probability = {}
        word_counts = dict(Counter(list(itertools.chain.from_iterable(self.corpus))))
        normalizing_factor = sum([v**0.75 for v in word_counts.values()])
        for word in word_counts:
            sample_probability[word] = word_counts[word]**0.75 / normalizing_factor
        words = np.array(list(word_counts.keys()))
        while True:
            word_list = []
            sampled_index = np.array(multinomial(self.sample_size, list(sample_probability.values())))
            for index, count in enumerate(sampled_index):
                for _ in range(count):
                     word_list.append(words[index])
            yield word_list
            
    
    def generate_context(self):
        filtered_corpus = self.subsample_frequent_words()
        vocabulary = set(itertools.chain.from_iterable(filtered_corpus))
        context_tuple_list = []
        negative_samples = self.sample_negative()
        index = 0

        for text in filtered_corpus:
            for i, word in enumerate(text):
                first_context_word_index = max(0,i-self.w)
                last_context_word_index = min(i+self.w, len(text))
                for j in range(first_context_word_index, last_context_word_index):
                    if i!=j:
                        context_tuple_list.append((word, text[j], next(negative_samples)))
        print("Generated target and context words pairs.")
        return vocabulary, context_tuple_list

Implement Context Generating on datasets with **both context widths 2 and 5**

In [11]:
context_width_2 = 2
context_width_5 = 5
sample_size = 2

hotel_cg_2 = ContextGenerator(hotel_corpus, sample_size, context_width_2)
hotel_vocab_2, hotel_context_2 = hotel_cg_2.generate_context()

hotel_cg_5 = ContextGenerator(hotel_corpus, sample_size, context_width_5)
hotel_vocab_5, hotel_context_5 = hotel_cg_5.generate_context()

scifi_cg_2 = ContextGenerator(scifi_corpus, sample_size, context_width_2)
scifi_vocab_2, scifi_context_2 = scifi_cg_2.generate_context()

scifi_cg_5 = ContextGenerator(scifi_corpus, sample_size, context_width_5)
scifi_vocab_5, scifi_context_5 = scifi_cg_5.generate_context()

Generated pairs of target and context words
Generated pairs of target and context words
Generated pairs of target and context words
Generated pairs of target and context words


#### Early Stopping

In [12]:
class EarlyStopping:
    """Early Stopping during Learning
    
    This Class aims at stopping learning when the loss does not decrease 
    significantly anymore after a certain number of iterations.
    """
    def __init__(self, patience=5, min_percent_gain=0.1):
        self.patience = patience
        self.loss_list = []
        self.min_percent_gain = min_percent_gain / 100.
        
    def update_loss(self, loss):
        self.loss_list.append(loss)
        if len(self.loss_list) > self.patience:
            del self.loss_list[0]
    
    def stop_training(self):
        if len(self.loss_list) == 1:
            return False
        gain = (max(self.loss_list) - min(self.loss_list)) / max(self.loss_list)
        print("Loss gain: {}%".format(round(100*gain,2)))
        if gain < self.min_percent_gain:
            return True
        else:
            return False

## Learner

In [16]:
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.nn.functional as F


class CBOW(nn.Module):

    def __init__(self, embedding_size, vocab_size):
        super(CBOW, self).__init__()
        self.embeddings_target = nn.Embedding(vocab_size, embedding_size)
        self.embeddings_context = nn.Embedding(vocab_size, embedding_size)

    def forward(self, target_word, context_word, negative_example):
        emb_target = self.embeddings_target(target_word)
        emb_context = self.embeddings_context(context_word)
        emb_product = torch.mul(emb_target, emb_context)
        emb_product = torch.sum(emb_product, dim=1)
        out = torch.sum(F.logsigmoid(emb_product))
        emb_negative = self.embeddings_context(negative_example)
        emb_product = torch.bmm(emb_negative, emb_target.unsqueeze(2))
        emb_product = torch.sum(emb_product, dim=1)
        out += torch.sum(F.logsigmoid(-emb_product))
        return -out

The neural network in trained with the following parameters:
- embedding size: 50
- batch size: 5000

In [17]:
class EmbeddingLearner:
    def __init__(self, vocabulary, context, batch_size, epochs):
        self.vocabulary = vocabulary
        self.context_tuple_list = context
        self.batch_size = batch_size
        self.word_to_index = {w: idx for (idx, w) in enumerate(self.vocabulary)}
        self.index_to_word = {idx: w for (idx, w) in enumerate(self.vocabulary)}
        self.vocabulary_size = len(self.vocabulary)
        self.net = CBOW(embedding_size=50, vocab_size=self.vocabulary_size)
        self.epochs = epochs
        
        
    def get_batches(self):
        random.shuffle(self.context_tuple_list)
        batches = []
        batch_target, batch_context, batch_negative = [], [], []
        for i in range(len(self.context_tuple_list)):
            batch_target.append(self.word_to_index[self.context_tuple_list[i][0]])
            batch_context.append(self.word_to_index[self.context_tuple_list[i][1]])
            batch_negative.append([self.word_to_index[w] for w in self.context_tuple_list[i][2]])
            if (i+1) % self.batch_size == 0 or i == len(self.context_tuple_list)-1:
                tensor_target = autograd.Variable(torch.from_numpy(np.array(batch_target)).long())
                tensor_context = autograd.Variable(torch.from_numpy(np.array(batch_context)).long())
                tensor_negative = autograd.Variable(torch.from_numpy(np.array(batch_negative)).long())
                batches.append((tensor_target, tensor_context, tensor_negative))
                batch_target, batch_context, batch_negative = [], [], []
        return batches
        
    def learn(self):

        loss_function = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.net.parameters())
        early_stopping = EarlyStopping(patience=5, min_percent_gain=0.5)

        for i in trange(0, self.epochs):
            losses = []
            context_tuple_batches = self.get_batches()
            for i in range(len(context_tuple_batches)):
                self.net.zero_grad()
                target_tensor, context_tensor, negative_tensor = context_tuple_batches[i]
                loss = self.net(target_tensor, context_tensor, negative_tensor)
                loss.backward()
                optimizer.step()
                losses.append(loss.data)
            print("Loss: ", np.mean(losses))
            early_stopping.update_loss(np.mean(losses))
            if early_stopping.stop_training():
                break

    
    def get_closest_word(self, word, topn=5):
        word_distance = []
        emb = self.net.embeddings_target
        pdist = nn.PairwiseDistance()
        i = self.word_to_index[word]
        lookup_tensor_i = torch.tensor([i], dtype=torch.long)
        v_i = emb(lookup_tensor_i)
        for j in range(len(self.vocabulary)):
            if j != i:
                lookup_tensor_j = torch.tensor([j], dtype=torch.long)
                v_j = emb(lookup_tensor_j)
                word_distance.append((self.index_to_word[j], float(pdist(v_i, v_j))))
        word_distance.sort(key=lambda x: x[1])
        return word_distance[:topn]

## Training
**CBOW2 for the Hotel Reviews dataset**

In [18]:
hotel_learner_2 = EmbeddingLearner(hotel_vocab_2,hotel_context_2, 5000, 100)
hotel_learner_2.learn()

  1%|          | 1/100 [00:01<02:48,  1.70s/it]

Loss:  34249.586


  2%|▏         | 2/100 [00:02<02:33,  1.56s/it]

Loss:  33389.023
Loss gain: 2.51%


  3%|▎         | 3/100 [00:04<02:21,  1.46s/it]

Loss:  32612.803
Loss gain: 4.78%


  4%|▍         | 4/100 [00:05<02:12,  1.38s/it]

Loss:  31856.348
Loss gain: 6.99%


  5%|▌         | 5/100 [00:06<02:11,  1.39s/it]

Loss:  31118.777
Loss gain: 9.14%


  6%|▌         | 6/100 [00:08<02:10,  1.39s/it]

Loss:  30399.293
Loss gain: 8.95%


  7%|▋         | 7/100 [00:09<02:11,  1.41s/it]

Loss:  29698.268
Loss gain: 8.94%


  8%|▊         | 8/100 [00:10<02:07,  1.38s/it]

Loss:  29014.523
Loss gain: 8.92%


  9%|▉         | 9/100 [00:12<02:05,  1.38s/it]

Loss:  28348.14
Loss gain: 8.9%


 10%|█         | 10/100 [00:13<02:04,  1.39s/it]

Loss:  27698.492
Loss gain: 8.88%


 11%|█         | 11/100 [00:15<02:11,  1.48s/it]

Loss:  27064.887
Loss gain: 8.87%


 12%|█▏        | 12/100 [00:16<02:08,  1.46s/it]

Loss:  26447.307
Loss gain: 8.85%


 13%|█▎        | 13/100 [00:18<02:03,  1.42s/it]

Loss:  25845.697
Loss gain: 8.83%


 14%|█▍        | 14/100 [00:19<01:59,  1.39s/it]

Loss:  25259.006
Loss gain: 8.81%


 15%|█▌        | 15/100 [00:20<01:56,  1.38s/it]

Loss:  24686.82
Loss gain: 8.79%


 16%|█▌        | 16/100 [00:22<01:59,  1.42s/it]

Loss:  24129.518
Loss gain: 8.76%


 17%|█▋        | 17/100 [00:23<01:55,  1.40s/it]

Loss:  23586.682
Loss gain: 8.74%


 18%|█▊        | 18/100 [00:25<01:54,  1.39s/it]

Loss:  23057.08
Loss gain: 8.72%


 19%|█▉        | 19/100 [00:26<01:51,  1.38s/it]

Loss:  22541.984
Loss gain: 8.69%


 20%|██        | 20/100 [00:28<01:58,  1.48s/it]

Loss:  22039.047
Loss gain: 8.66%


 21%|██        | 21/100 [00:29<01:52,  1.43s/it]

Loss:  21549.203
Loss gain: 8.64%


 22%|██▏       | 22/100 [00:30<01:50,  1.41s/it]

Loss:  21072.117
Loss gain: 8.61%


 23%|██▎       | 23/100 [00:32<01:47,  1.39s/it]

Loss:  20607.387
Loss gain: 8.58%


 24%|██▍       | 24/100 [00:33<01:45,  1.39s/it]

Loss:  20155.008
Loss gain: 8.55%


 25%|██▌       | 25/100 [00:34<01:44,  1.39s/it]

Loss:  19713.572
Loss gain: 8.52%


 26%|██▌       | 26/100 [00:36<01:40,  1.36s/it]

Loss:  19283.912
Loss gain: 8.49%


 27%|██▋       | 27/100 [00:37<01:38,  1.35s/it]

Loss:  18864.883
Loss gain: 8.46%


 28%|██▊       | 28/100 [00:38<01:36,  1.34s/it]

Loss:  18457.234
Loss gain: 8.42%


 29%|██▉       | 29/100 [00:40<01:35,  1.34s/it]

Loss:  18059.535
Loss gain: 8.39%


 30%|███       | 30/100 [00:41<01:40,  1.44s/it]

Loss:  17672.895
Loss gain: 8.35%


 31%|███       | 31/100 [00:43<01:36,  1.39s/it]

Loss:  17295.691
Loss gain: 8.32%


 32%|███▏      | 32/100 [00:44<01:34,  1.40s/it]

Loss:  16928.922
Loss gain: 8.28%


 33%|███▎      | 33/100 [00:45<01:32,  1.38s/it]

Loss:  16570.932
Loss gain: 8.24%


 34%|███▍      | 34/100 [00:47<01:31,  1.38s/it]

Loss:  16222.732
Loss gain: 8.21%


 35%|███▌      | 35/100 [00:48<01:30,  1.39s/it]

Loss:  15883.434
Loss gain: 8.17%


 36%|███▌      | 36/100 [00:50<01:31,  1.43s/it]

Loss:  15552.455
Loss gain: 8.13%


 37%|███▋      | 37/100 [00:51<01:28,  1.41s/it]

Loss:  15231.187
Loss gain: 8.08%


 38%|███▊      | 38/100 [00:52<01:26,  1.39s/it]

Loss:  14917.287
Loss gain: 8.05%


 39%|███▉      | 39/100 [00:54<01:23,  1.36s/it]

Loss:  14611.96
Loss gain: 8.01%


 40%|████      | 40/100 [00:55<01:26,  1.44s/it]

Loss:  14314.778
Loss gain: 7.96%


 41%|████      | 41/100 [00:57<01:22,  1.41s/it]

Loss:  14024.865
Loss gain: 7.92%


 42%|████▏     | 42/100 [00:58<01:20,  1.39s/it]

Loss:  13743.094
Loss gain: 7.87%


 43%|████▎     | 43/100 [00:59<01:17,  1.36s/it]

Loss:  13468.017
Loss gain: 7.83%


 44%|████▍     | 44/100 [01:01<01:16,  1.37s/it]

Loss:  13200.872
Loss gain: 7.78%


 45%|████▌     | 45/100 [01:02<01:15,  1.38s/it]

Loss:  12940.307
Loss gain: 7.73%


 46%|████▌     | 46/100 [01:03<01:13,  1.36s/it]

Loss:  12686.631
Loss gain: 7.69%


 47%|████▋     | 47/100 [01:05<01:12,  1.36s/it]

Loss:  12439.577
Loss gain: 7.64%


 48%|████▊     | 48/100 [01:06<01:11,  1.37s/it]

Loss:  12199.21
Loss gain: 7.59%


 49%|████▉     | 49/100 [01:08<01:16,  1.49s/it]

Loss:  11965.44
Loss gain: 7.53%


 50%|█████     | 50/100 [01:09<01:12,  1.45s/it]

Loss:  11737.423
Loss gain: 7.48%


 51%|█████     | 51/100 [01:11<01:09,  1.41s/it]

Loss:  11515.231
Loss gain: 7.43%


 52%|█████▏    | 52/100 [01:12<01:06,  1.39s/it]

Loss:  11299.36
Loss gain: 7.38%


 53%|█████▎    | 53/100 [01:13<01:05,  1.38s/it]

Loss:  11088.825
Loss gain: 7.33%


 54%|█████▍    | 54/100 [01:15<01:02,  1.37s/it]

Loss:  10884.195
Loss gain: 7.27%


 55%|█████▌    | 55/100 [01:16<01:01,  1.37s/it]

Loss:  10684.652
Loss gain: 7.21%


 56%|█████▌    | 56/100 [01:17<01:00,  1.37s/it]

Loss:  10490.405
Loss gain: 7.16%


 57%|█████▋    | 57/100 [01:19<00:58,  1.35s/it]

Loss:  10301.104
Loss gain: 7.1%


 58%|█████▊    | 58/100 [01:20<00:55,  1.31s/it]

Loss:  10117.3125
Loss gain: 7.05%


 59%|█████▉    | 59/100 [01:21<00:56,  1.37s/it]

Loss:  9938.156
Loss gain: 6.99%


 60%|██████    | 60/100 [01:23<00:53,  1.34s/it]

Loss:  9763.018
Loss gain: 6.93%


 61%|██████    | 61/100 [01:24<00:52,  1.36s/it]

Loss:  9593.409
Loss gain: 6.87%


 62%|██████▏   | 62/100 [01:25<00:51,  1.35s/it]

Loss:  9427.845
Loss gain: 6.81%


 63%|██████▎   | 63/100 [01:27<00:50,  1.36s/it]

Loss:  9266.509
Loss gain: 6.76%


 64%|██████▍   | 64/100 [01:28<00:49,  1.38s/it]

Loss:  9109.724
Loss gain: 6.69%


 65%|██████▌   | 65/100 [01:30<00:48,  1.38s/it]

Loss:  8956.697
Loss gain: 6.64%


 66%|██████▌   | 66/100 [01:31<00:46,  1.38s/it]

Loss:  8807.77
Loss gain: 6.58%


 67%|██████▋   | 67/100 [01:32<00:45,  1.37s/it]

Loss:  8662.42
Loss gain: 6.52%


 68%|██████▊   | 68/100 [01:34<00:46,  1.46s/it]

Loss:  8521.387
Loss gain: 6.46%


 69%|██████▉   | 69/100 [01:35<00:44,  1.44s/it]

Loss:  8383.468
Loss gain: 6.4%


 70%|███████   | 70/100 [01:37<00:42,  1.43s/it]

Loss:  8249.395
Loss gain: 6.34%


 71%|███████   | 71/100 [01:38<00:40,  1.39s/it]

Loss:  8118.7515
Loss gain: 6.28%


 72%|███████▏  | 72/100 [01:39<00:38,  1.37s/it]

Loss:  7991.3145
Loss gain: 6.22%


 73%|███████▎  | 73/100 [01:41<00:36,  1.34s/it]

Loss:  7867.7324
Loss gain: 6.15%


 74%|███████▍  | 74/100 [01:42<00:34,  1.34s/it]

Loss:  7746.5674
Loss gain: 6.1%


 75%|███████▌  | 75/100 [01:43<00:33,  1.33s/it]

Loss:  7629.0444
Loss gain: 6.03%


 76%|███████▌  | 76/100 [01:45<00:32,  1.35s/it]

Loss:  7514.2544
Loss gain: 5.97%


 77%|███████▋  | 77/100 [01:46<00:31,  1.35s/it]

Loss:  7402.395
Loss gain: 5.91%


 78%|███████▊  | 78/100 [01:48<00:31,  1.45s/it]

Loss:  7293.3325
Loss gain: 5.85%


 79%|███████▉  | 79/100 [01:49<00:30,  1.44s/it]

Loss:  7187.061
Loss gain: 5.79%


 80%|████████  | 80/100 [01:51<00:28,  1.42s/it]

Loss:  7083.7163
Loss gain: 5.73%


 81%|████████  | 81/100 [01:52<00:27,  1.43s/it]

Loss:  6982.6816
Loss gain: 5.67%


 82%|████████▏ | 82/100 [01:53<00:25,  1.41s/it]

Loss:  6884.4985
Loss gain: 5.61%


 83%|████████▎ | 83/100 [01:55<00:24,  1.41s/it]

Loss:  6788.2695
Loss gain: 5.55%


 84%|████████▍ | 84/100 [01:56<00:22,  1.41s/it]

Loss:  6694.869
Loss gain: 5.49%


 85%|████████▌ | 85/100 [01:58<00:21,  1.42s/it]

Loss:  6603.8096
Loss gain: 5.43%


 86%|████████▌ | 86/100 [01:59<00:19,  1.40s/it]

Loss:  6514.63
Loss gain: 5.37%


 87%|████████▋ | 87/100 [02:01<00:19,  1.49s/it]

Loss:  6428.041
Loss gain: 5.31%


 88%|████████▊ | 88/100 [02:02<00:17,  1.46s/it]

Loss:  6343.293
Loss gain: 5.25%


 89%|████████▉ | 89/100 [02:03<00:15,  1.44s/it]

Loss:  6260.7383
Loss gain: 5.2%


 90%|█████████ | 90/100 [02:05<00:14,  1.41s/it]

Loss:  6180.166
Loss gain: 5.13%


 91%|█████████ | 91/100 [02:06<00:12,  1.38s/it]

Loss:  6101.537
Loss gain: 5.08%


 92%|█████████▏| 92/100 [02:08<00:10,  1.37s/it]

Loss:  6024.9624
Loss gain: 5.02%


 93%|█████████▎| 93/100 [02:09<00:09,  1.39s/it]

Loss:  5950.0713
Loss gain: 4.96%


 94%|█████████▍| 94/100 [02:10<00:08,  1.39s/it]

Loss:  5876.9243
Loss gain: 4.91%


 95%|█████████▌| 95/100 [02:12<00:06,  1.39s/it]

Loss:  5805.779
Loss gain: 4.85%


 96%|█████████▌| 96/100 [02:13<00:05,  1.39s/it]

Loss:  5736.121
Loss gain: 4.79%


 97%|█████████▋| 97/100 [02:15<00:04,  1.50s/it]

Loss:  5668.186
Loss gain: 4.74%


 98%|█████████▊| 98/100 [02:16<00:03,  1.51s/it]

Loss:  5601.584
Loss gain: 4.69%


 99%|█████████▉| 99/100 [02:18<00:01,  1.48s/it]

Loss:  5536.9194
Loss gain: 4.63%


100%|██████████| 100/100 [02:19<00:00,  1.40s/it]

Loss:  5473.533
Loss gain: 4.58%





**CBOW2 for the Sci-Fi story dataset**

In [19]:
scifi_learner_2 = EmbeddingLearner(scifi_vocab_2,scifi_context_2,5000,100)
scifi_learner_2.learn()

  1%|          | 1/100 [00:00<00:58,  1.70it/s]

Loss:  33992.7


  2%|▏         | 2/100 [00:01<00:58,  1.67it/s]

Loss:  33380.402
Loss gain: 1.8%


  3%|▎         | 3/100 [00:01<00:59,  1.64it/s]

Loss:  32851.13
Loss gain: 3.36%


  4%|▍         | 4/100 [00:02<00:57,  1.66it/s]

Loss:  32335.531
Loss gain: 4.88%


  5%|▌         | 5/100 [00:03<00:56,  1.69it/s]

Loss:  31829.893
Loss gain: 6.36%


  6%|▌         | 6/100 [00:03<00:54,  1.73it/s]

Loss:  31331.738
Loss gain: 6.14%


  7%|▋         | 7/100 [00:04<00:52,  1.76it/s]

Loss:  30840.455
Loss gain: 6.12%


  8%|▊         | 8/100 [00:04<00:51,  1.80it/s]

Loss:  30357.762
Loss gain: 6.12%


  9%|▉         | 9/100 [00:05<00:51,  1.77it/s]

Loss:  29881.014
Loss gain: 6.12%


 10%|█         | 10/100 [00:05<00:52,  1.73it/s]

Loss:  29411.793
Loss gain: 6.13%


 11%|█         | 11/100 [00:06<00:51,  1.72it/s]

Loss:  28948.932
Loss gain: 6.13%


 12%|█▏        | 12/100 [00:06<00:51,  1.72it/s]

Loss:  28492.648
Loss gain: 6.14%


 13%|█▎        | 13/100 [00:07<00:49,  1.74it/s]

Loss:  28042.969
Loss gain: 6.15%


 14%|█▍        | 14/100 [00:08<00:59,  1.45it/s]

Loss:  27600.387
Loss gain: 6.16%


 15%|█▌        | 15/100 [00:09<00:54,  1.56it/s]

Loss:  27164.543
Loss gain: 6.16%


 16%|█▌        | 16/100 [00:09<00:51,  1.62it/s]

Loss:  26735.607
Loss gain: 6.17%


 17%|█▋        | 17/100 [00:10<00:51,  1.60it/s]

Loss:  26311.799
Loss gain: 6.17%


 18%|█▊        | 18/100 [00:10<00:51,  1.58it/s]

Loss:  25894.414
Loss gain: 6.18%


 19%|█▉        | 19/100 [00:11<00:51,  1.58it/s]

Loss:  25483.582
Loss gain: 6.19%


 20%|██        | 20/100 [00:12<00:51,  1.55it/s]

Loss:  25079.492
Loss gain: 6.19%


 21%|██        | 21/100 [00:12<00:51,  1.52it/s]

Loss:  24680.773
Loss gain: 6.2%


 22%|██▏       | 22/100 [00:13<00:50,  1.56it/s]

Loss:  24288.72
Loss gain: 6.2%


 23%|██▎       | 23/100 [00:14<00:49,  1.57it/s]

Loss:  23902.336
Loss gain: 6.2%


 24%|██▍       | 24/100 [00:14<00:48,  1.57it/s]

Loss:  23522.023
Loss gain: 6.21%


 25%|██▌       | 25/100 [00:15<00:46,  1.63it/s]

Loss:  23147.389
Loss gain: 6.21%


 26%|██▌       | 26/100 [00:15<00:44,  1.66it/s]

Loss:  22778.29
Loss gain: 6.22%


 27%|██▋       | 27/100 [00:16<00:42,  1.71it/s]

Loss:  22415.143
Loss gain: 6.22%


 28%|██▊       | 28/100 [00:17<00:42,  1.68it/s]

Loss:  22057.377
Loss gain: 6.23%


 29%|██▉       | 29/100 [00:17<00:41,  1.71it/s]

Loss:  21705.775
Loss gain: 6.23%


 30%|███       | 30/100 [00:18<00:39,  1.75it/s]

Loss:  21359.035
Loss gain: 6.23%


 31%|███       | 31/100 [00:18<00:39,  1.74it/s]

Loss:  21018.56
Loss gain: 6.23%


 32%|███▏      | 32/100 [00:19<00:39,  1.72it/s]

Loss:  20682.627
Loss gain: 6.23%


 33%|███▎      | 33/100 [00:19<00:38,  1.74it/s]

Loss:  20352.346
Loss gain: 6.24%


 34%|███▍      | 34/100 [00:20<00:36,  1.78it/s]

Loss:  20028.205
Loss gain: 6.23%


 35%|███▌      | 35/100 [00:21<00:43,  1.49it/s]

Loss:  19707.994
Loss gain: 6.24%


 36%|███▌      | 36/100 [00:21<00:40,  1.57it/s]

Loss:  19394.318
Loss gain: 6.23%


 37%|███▋      | 37/100 [00:22<00:39,  1.59it/s]

Loss:  19084.281
Loss gain: 6.23%


 38%|███▊      | 38/100 [00:23<00:37,  1.63it/s]

Loss:  18779.799
Loss gain: 6.23%


 39%|███▉      | 39/100 [00:23<00:36,  1.68it/s]

Loss:  18480.898
Loss gain: 6.23%


 40%|████      | 40/100 [00:24<00:35,  1.70it/s]

Loss:  18186.018
Loss gain: 6.23%


 41%|████      | 41/100 [00:24<00:34,  1.72it/s]

Loss:  17896.648
Loss gain: 6.22%


 42%|████▏     | 42/100 [00:25<00:33,  1.75it/s]

Loss:  17611.56
Loss gain: 6.22%


 43%|████▎     | 43/100 [00:25<00:32,  1.76it/s]

Loss:  17331.477
Loss gain: 6.22%


 44%|████▍     | 44/100 [00:26<00:32,  1.72it/s]

Loss:  17056.05
Loss gain: 6.21%


 45%|████▌     | 45/100 [00:27<00:32,  1.70it/s]

Loss:  16785.193
Loss gain: 6.21%


 46%|████▌     | 46/100 [00:27<00:31,  1.70it/s]

Loss:  16518.953
Loss gain: 6.2%


 47%|████▋     | 47/100 [00:28<00:31,  1.67it/s]

Loss:  16256.412
Loss gain: 6.2%


 48%|████▊     | 48/100 [00:28<00:30,  1.68it/s]

Loss:  15999.313
Loss gain: 6.2%


 49%|████▉     | 49/100 [00:29<00:30,  1.69it/s]

Loss:  15745.577
Loss gain: 6.19%


 50%|█████     | 50/100 [00:30<00:30,  1.67it/s]

Loss:  15496.9
Loss gain: 6.19%


 51%|█████     | 51/100 [00:30<00:29,  1.67it/s]

Loss:  15252.313
Loss gain: 6.18%


 52%|█████▏    | 52/100 [00:31<00:28,  1.68it/s]

Loss:  15011.999
Loss gain: 6.17%


 53%|█████▎    | 53/100 [00:31<00:28,  1.66it/s]

Loss:  14774.696
Loss gain: 6.17%


 54%|█████▍    | 54/100 [00:32<00:28,  1.62it/s]

Loss:  14542.466
Loss gain: 6.16%


 55%|█████▌    | 55/100 [00:33<00:28,  1.61it/s]

Loss:  14313.128
Loss gain: 6.16%


 56%|█████▌    | 56/100 [00:33<00:26,  1.66it/s]

Loss:  14088.949
Loss gain: 6.15%


 57%|█████▋    | 57/100 [00:34<00:30,  1.40it/s]

Loss:  13868.288
Loss gain: 6.13%


 58%|█████▊    | 58/100 [00:35<00:28,  1.46it/s]

Loss:  13650.39
Loss gain: 6.13%


 59%|█████▉    | 59/100 [00:35<00:27,  1.51it/s]

Loss:  13437.649
Loss gain: 6.12%


 60%|██████    | 60/100 [00:36<00:25,  1.56it/s]

Loss:  13227.421
Loss gain: 6.11%


 61%|██████    | 61/100 [00:37<00:24,  1.59it/s]

Loss:  13021.057
Loss gain: 6.11%


 62%|██████▏   | 62/100 [00:37<00:23,  1.61it/s]

Loss:  12818.569
Loss gain: 6.09%


 63%|██████▎   | 63/100 [00:38<00:22,  1.63it/s]

Loss:  12619.315
Loss gain: 6.09%


 64%|██████▍   | 64/100 [00:38<00:21,  1.68it/s]

Loss:  12423.778
Loss gain: 6.08%


 65%|██████▌   | 65/100 [00:39<00:20,  1.73it/s]

Loss:  12230.994
Loss gain: 6.07%


 66%|██████▌   | 66/100 [00:40<00:19,  1.72it/s]

Loss:  12042.202
Loss gain: 6.06%


 67%|██████▋   | 67/100 [00:40<00:18,  1.77it/s]

Loss:  11856.371
Loss gain: 6.05%


 68%|██████▊   | 68/100 [00:41<00:17,  1.79it/s]

Loss:  11673.691
Loss gain: 6.04%


 69%|██████▉   | 69/100 [00:41<00:17,  1.81it/s]

Loss:  11494.41
Loss gain: 6.02%


 70%|███████   | 70/100 [00:42<00:17,  1.74it/s]

Loss:  11317.78
Loss gain: 6.02%


 71%|███████   | 71/100 [00:42<00:17,  1.67it/s]

Loss:  11144.985
Loss gain: 6.0%


 72%|███████▏  | 72/100 [00:43<00:16,  1.67it/s]

Loss:  10974.694
Loss gain: 5.99%


 73%|███████▎  | 73/100 [00:44<00:16,  1.65it/s]

Loss:  10807.648
Loss gain: 5.97%


 74%|███████▍  | 74/100 [00:44<00:16,  1.57it/s]

Loss:  10643.095
Loss gain: 5.96%


 75%|███████▌  | 75/100 [00:45<00:15,  1.60it/s]

Loss:  10481.83
Loss gain: 5.95%


 76%|███████▌  | 76/100 [00:46<00:15,  1.59it/s]

Loss:  10322.819
Loss gain: 5.94%


 77%|███████▋  | 77/100 [00:46<00:14,  1.64it/s]

Loss:  10167.376
Loss gain: 5.92%


 78%|███████▊  | 78/100 [00:47<00:13,  1.62it/s]

Loss:  10014.091
Loss gain: 5.91%


 79%|███████▉  | 79/100 [00:47<00:12,  1.65it/s]

Loss:  9863.783
Loss gain: 5.9%


 80%|████████  | 80/100 [00:48<00:15,  1.33it/s]

Loss:  9715.749
Loss gain: 5.88%


 81%|████████  | 81/100 [00:49<00:13,  1.44it/s]

Loss:  9570.895
Loss gain: 5.87%


 82%|████████▏ | 82/100 [00:50<00:11,  1.52it/s]

Loss:  9428.162
Loss gain: 5.85%


 83%|████████▎ | 83/100 [00:50<00:10,  1.61it/s]

Loss:  9287.868
Loss gain: 5.84%


 84%|████████▍ | 84/100 [00:51<00:09,  1.64it/s]

Loss:  9150.683
Loss gain: 5.82%


 85%|████████▌ | 85/100 [00:51<00:09,  1.64it/s]

Loss:  9014.886
Loss gain: 5.81%


 86%|████████▌ | 86/100 [00:52<00:08,  1.63it/s]

Loss:  8882.36
Loss gain: 5.79%


 87%|████████▋ | 87/100 [00:53<00:07,  1.65it/s]

Loss:  8752.197
Loss gain: 5.77%


 88%|████████▊ | 88/100 [00:53<00:07,  1.66it/s]

Loss:  8623.57
Loss gain: 5.76%


 89%|████████▉ | 89/100 [00:54<00:06,  1.69it/s]

Loss:  8497.8
Loss gain: 5.74%


 90%|█████████ | 90/100 [00:54<00:05,  1.71it/s]

Loss:  8374.428
Loss gain: 5.72%


 91%|█████████ | 91/100 [00:55<00:05,  1.70it/s]

Loss:  8252.619
Loss gain: 5.71%


 92%|█████████▏| 92/100 [00:55<00:04,  1.72it/s]

Loss:  8134.0923
Loss gain: 5.68%


 93%|█████████▎| 93/100 [00:56<00:04,  1.70it/s]

Loss:  8016.2705
Loss gain: 5.67%


 94%|█████████▍| 94/100 [00:57<00:03,  1.67it/s]

Loss:  7901.316
Loss gain: 5.65%


 95%|█████████▌| 95/100 [00:57<00:03,  1.66it/s]

Loss:  7788.963
Loss gain: 5.62%


 96%|█████████▌| 96/100 [00:58<00:02,  1.64it/s]

Loss:  7677.4595
Loss gain: 5.61%


 97%|█████████▋| 97/100 [00:58<00:01,  1.68it/s]

Loss:  7568.61
Loss gain: 5.58%


 98%|█████████▊| 98/100 [00:59<00:01,  1.71it/s]

Loss:  7461.4927
Loss gain: 5.57%


 99%|█████████▉| 99/100 [01:00<00:00,  1.75it/s]

Loss:  7356.306
Loss gain: 5.55%


100%|██████████| 100/100 [01:00<00:00,  1.65it/s]

Loss:  7253.316
Loss gain: 5.52%





**CBOW5 for the Hotel Reviews dataset**

In [20]:
hotel_learner_5 = EmbeddingLearner(hotel_vocab_5,hotel_context_5, 5000, 100)
hotel_learner_5.learn()

  1%|          | 1/100 [00:04<06:58,  4.22s/it]

Loss:  33734.562


  2%|▏         | 2/100 [00:08<06:41,  4.09s/it]

Loss:  32203.855
Loss gain: 4.54%


  3%|▎         | 3/100 [00:11<06:33,  4.06s/it]

Loss:  30779.34
Loss gain: 8.76%


  4%|▍         | 4/100 [00:16<06:33,  4.10s/it]

Loss:  29425.713
Loss gain: 12.77%


  5%|▌         | 5/100 [00:20<06:24,  4.05s/it]

Loss:  28140.48
Loss gain: 16.58%


  6%|▌         | 6/100 [00:24<06:15,  4.00s/it]

Loss:  26920.332
Loss gain: 16.41%


  7%|▋         | 7/100 [00:28<06:18,  4.07s/it]

Loss:  25762.672
Loss gain: 16.3%


  8%|▊         | 8/100 [00:32<06:11,  4.04s/it]

Loss:  24664.207
Loss gain: 16.18%


  9%|▉         | 9/100 [00:36<06:04,  4.01s/it]

Loss:  23622.242
Loss gain: 16.06%


 10%|█         | 10/100 [00:40<06:05,  4.06s/it]

Loss:  22633.607
Loss gain: 15.92%


 11%|█         | 11/100 [00:44<05:57,  4.01s/it]

Loss:  21696.242
Loss gain: 15.78%


 12%|█▏        | 12/100 [00:48<05:51,  4.00s/it]

Loss:  20807.002
Loss gain: 15.64%


 13%|█▎        | 13/100 [00:52<05:47,  3.99s/it]

Loss:  19963.607
Loss gain: 15.49%


 14%|█▍        | 14/100 [00:56<05:49,  4.07s/it]

Loss:  19163.752
Loss gain: 15.33%


 15%|█▌        | 15/100 [01:00<05:48,  4.10s/it]

Loss:  18405.643
Loss gain: 15.17%


 16%|█▌        | 16/100 [01:04<05:41,  4.07s/it]

Loss:  17686.75
Loss gain: 15.0%


 17%|█▋        | 17/100 [01:08<05:41,  4.11s/it]

Loss:  17005.08
Loss gain: 14.82%


 18%|█▊        | 18/100 [01:13<05:40,  4.15s/it]

Loss:  16359.106
Loss gain: 14.64%


 19%|█▉        | 19/100 [01:17<05:38,  4.18s/it]

Loss:  15747.109
Loss gain: 14.44%


 20%|██        | 20/100 [01:21<05:40,  4.26s/it]

Loss:  15167.022
Loss gain: 14.25%


 21%|██        | 21/100 [01:25<05:33,  4.22s/it]

Loss:  14617.639
Loss gain: 14.04%


 22%|██▏       | 22/100 [01:30<05:29,  4.23s/it]

Loss:  14097.044
Loss gain: 13.83%


 23%|██▎       | 23/100 [01:34<05:27,  4.26s/it]

Loss:  13604.018
Loss gain: 13.61%


 24%|██▍       | 24/100 [01:38<05:12,  4.12s/it]

Loss:  13137.3545
Loss gain: 13.38%


 25%|██▌       | 25/100 [01:42<05:10,  4.14s/it]

Loss:  12695.511
Loss gain: 13.15%


 26%|██▌       | 26/100 [01:46<05:05,  4.13s/it]

Loss:  12277.044
Loss gain: 12.91%


 27%|██▋       | 27/100 [01:50<05:05,  4.18s/it]

Loss:  11881.057
Loss gain: 12.67%


 28%|██▊       | 28/100 [01:54<05:00,  4.18s/it]

Loss:  11506.212
Loss gain: 12.42%


 29%|██▉       | 29/100 [01:59<04:55,  4.16s/it]

Loss:  11151.25
Loss gain: 12.16%


 30%|███       | 30/100 [02:03<04:58,  4.27s/it]

Loss:  10815.49
Loss gain: 11.9%


 31%|███       | 31/100 [02:07<04:51,  4.22s/it]

Loss:  10497.636
Loss gain: 11.64%


 32%|███▏      | 32/100 [02:12<04:48,  4.24s/it]

Loss:  10196.991
Loss gain: 11.38%


 33%|███▎      | 33/100 [02:16<04:44,  4.25s/it]

Loss:  9912.076
Loss gain: 11.11%


 34%|███▍      | 34/100 [02:20<04:37,  4.20s/it]

Loss:  9642.578
Loss gain: 10.84%


 35%|███▌      | 35/100 [02:24<04:30,  4.17s/it]

Loss:  9387.524
Loss gain: 10.57%


 36%|███▌      | 36/100 [02:28<04:30,  4.23s/it]

Loss:  9146.053
Loss gain: 10.31%


 37%|███▋      | 37/100 [02:32<04:21,  4.15s/it]

Loss:  8917.427
Loss gain: 10.03%


 38%|███▊      | 38/100 [02:36<04:07,  4.00s/it]

Loss:  8700.902
Loss gain: 9.77%


 39%|███▉      | 39/100 [02:40<04:04,  4.00s/it]

Loss:  8495.834
Loss gain: 9.5%


 40%|████      | 40/100 [02:44<04:03,  4.05s/it]

Loss:  8301.399
Loss gain: 9.24%


 41%|████      | 41/100 [02:48<03:55,  3.99s/it]

Loss:  8117.249
Loss gain: 8.97%


 42%|████▏     | 42/100 [02:52<03:55,  4.06s/it]

Loss:  7942.675
Loss gain: 8.71%


 43%|████▎     | 43/100 [02:57<03:58,  4.19s/it]

Loss:  7777.224
Loss gain: 8.46%


 44%|████▍     | 44/100 [03:01<03:51,  4.14s/it]

Loss:  7620.1455
Loss gain: 8.21%


 45%|████▌     | 45/100 [03:04<03:40,  4.01s/it]

Loss:  7471.251
Loss gain: 7.96%


 46%|████▌     | 46/100 [03:09<03:44,  4.16s/it]

Loss:  7329.7886
Loss gain: 7.72%


 47%|████▋     | 47/100 [03:13<03:41,  4.19s/it]

Loss:  7195.4673
Loss gain: 7.48%


 48%|████▊     | 48/100 [03:17<03:37,  4.19s/it]

Loss:  7067.771
Loss gain: 7.25%


 49%|████▉     | 49/100 [03:22<03:40,  4.32s/it]

Loss:  6946.635
Loss gain: 7.02%


 50%|█████     | 50/100 [03:26<03:33,  4.27s/it]

Loss:  6831.5166
Loss gain: 6.8%


 51%|█████     | 51/100 [03:30<03:23,  4.16s/it]

Loss:  6721.8286
Loss gain: 6.58%


 52%|█████▏    | 52/100 [03:34<03:19,  4.15s/it]

Loss:  6617.4707
Loss gain: 6.37%


 53%|█████▎    | 53/100 [03:39<03:20,  4.26s/it]

Loss:  6518.2397
Loss gain: 6.17%


 54%|█████▍    | 54/100 [03:43<03:13,  4.21s/it]

Loss:  6423.625
Loss gain: 5.97%


 55%|█████▌    | 55/100 [03:47<03:09,  4.21s/it]

Loss:  6333.4443
Loss gain: 5.78%


 56%|█████▌    | 56/100 [03:52<03:09,  4.30s/it]

Loss:  6247.381
Loss gain: 5.59%


 57%|█████▋    | 57/100 [03:56<03:03,  4.26s/it]

Loss:  6165.2583
Loss gain: 5.42%


 58%|█████▊    | 58/100 [03:59<02:52,  4.10s/it]

Loss:  6086.981
Loss gain: 5.24%


 59%|█████▉    | 59/100 [04:04<02:54,  4.26s/it]

Loss:  6012.108
Loss gain: 5.07%


 60%|██████    | 60/100 [04:08<02:45,  4.13s/it]

Loss:  5940.5137
Loss gain: 4.91%


 61%|██████    | 61/100 [04:12<02:39,  4.10s/it]

Loss:  5872.102
Loss gain: 4.75%


 62%|██████▏   | 62/100 [04:16<02:39,  4.20s/it]

Loss:  5806.474
Loss gain: 4.61%


 63%|██████▎   | 63/100 [04:20<02:31,  4.10s/it]

Loss:  5743.68
Loss gain: 4.46%


 64%|██████▍   | 64/100 [04:24<02:27,  4.10s/it]

Loss:  5683.5254
Loss gain: 4.33%


 65%|██████▌   | 65/100 [04:28<02:21,  4.04s/it]

Loss:  5625.9126
Loss gain: 4.19%


 66%|██████▌   | 66/100 [04:33<02:22,  4.20s/it]

Loss:  5570.366
Loss gain: 4.07%


 67%|██████▋   | 67/100 [04:37<02:18,  4.19s/it]

Loss:  5517.1543
Loss gain: 3.94%


 68%|██████▊   | 68/100 [04:41<02:07,  4.00s/it]

Loss:  5466.0283
Loss gain: 3.83%


 69%|██████▉   | 69/100 [04:45<02:11,  4.26s/it]

Loss:  5416.7856
Loss gain: 3.72%


 70%|███████   | 70/100 [04:49<02:03,  4.12s/it]

Loss:  5369.4067
Loss gain: 3.61%


 71%|███████   | 71/100 [04:53<02:00,  4.16s/it]

Loss:  5323.605
Loss gain: 3.51%


 72%|███████▏  | 72/100 [04:58<01:59,  4.28s/it]

Loss:  5279.5166
Loss gain: 3.41%


 73%|███████▎  | 73/100 [05:02<01:56,  4.33s/it]

Loss:  5237.0044
Loss gain: 3.32%


 74%|███████▍  | 74/100 [05:06<01:50,  4.23s/it]

Loss:  5196.009
Loss gain: 3.23%


 75%|███████▌  | 75/100 [05:11<01:47,  4.31s/it]

Loss:  5156.2285
Loss gain: 3.14%


 76%|███████▌  | 76/100 [05:15<01:41,  4.24s/it]

Loss:  5117.767
Loss gain: 3.06%


 77%|███████▋  | 77/100 [05:19<01:37,  4.23s/it]

Loss:  5080.588
Loss gain: 2.99%


 78%|███████▊  | 78/100 [05:23<01:31,  4.16s/it]

Loss:  5044.5127
Loss gain: 2.92%


 79%|███████▉  | 79/100 [05:28<01:30,  4.32s/it]

Loss:  5009.6045
Loss gain: 2.84%


 80%|████████  | 80/100 [05:32<01:25,  4.28s/it]

Loss:  4975.5347
Loss gain: 2.78%


 81%|████████  | 81/100 [05:36<01:20,  4.26s/it]

Loss:  4942.658
Loss gain: 2.71%


 82%|████████▏ | 82/100 [05:41<01:16,  4.28s/it]

Loss:  4910.5493
Loss gain: 2.66%


 83%|████████▎ | 83/100 [05:45<01:11,  4.22s/it]

Loss:  4879.366
Loss gain: 2.6%


 84%|████████▍ | 84/100 [05:49<01:06,  4.18s/it]

Loss:  4848.917
Loss gain: 2.54%


 85%|████████▌ | 85/100 [05:53<01:04,  4.28s/it]

Loss:  4819.443
Loss gain: 2.49%


 86%|████████▌ | 86/100 [05:57<00:58,  4.19s/it]

Loss:  4790.5327
Loss gain: 2.44%


 87%|████████▋ | 87/100 [06:01<00:54,  4.17s/it]

Loss:  4762.341
Loss gain: 2.4%


 88%|████████▊ | 88/100 [06:06<00:50,  4.23s/it]

Loss:  4734.8457
Loss gain: 2.35%


 89%|████████▉ | 89/100 [06:10<00:45,  4.17s/it]

Loss:  4708.0176
Loss gain: 2.31%


 90%|█████████ | 90/100 [06:14<00:41,  4.17s/it]

Loss:  4681.6836
Loss gain: 2.27%


 91%|█████████ | 91/100 [06:18<00:37,  4.13s/it]

Loss:  4655.9287
Loss gain: 2.23%


 92%|█████████▏| 92/100 [06:23<00:34,  4.32s/it]

Loss:  4630.752
Loss gain: 2.2%


 93%|█████████▎| 93/100 [06:27<00:30,  4.32s/it]

Loss:  4605.9897
Loss gain: 2.17%


 94%|█████████▍| 94/100 [06:32<00:26,  4.34s/it]

Loss:  4581.8047
Loss gain: 2.13%


 95%|█████████▌| 95/100 [06:36<00:22,  4.43s/it]

Loss:  4558.013
Loss gain: 2.1%


 96%|█████████▌| 96/100 [06:40<00:17,  4.37s/it]

Loss:  4534.717
Loss gain: 2.07%


 97%|█████████▋| 97/100 [06:44<00:12,  4.21s/it]

Loss:  4511.7725
Loss gain: 2.05%


 98%|█████████▊| 98/100 [06:49<00:08,  4.28s/it]

Loss:  4489.3174
Loss gain: 2.02%


 99%|█████████▉| 99/100 [06:53<00:04,  4.27s/it]

Loss:  4467.1235
Loss gain: 1.99%


100%|██████████| 100/100 [06:57<00:00,  4.18s/it]

Loss:  4445.3765
Loss gain: 1.97%





**CBOW5 for the Sci-Fi story dataset**

In [21]:
scifi_learner_5 = EmbeddingLearner(scifi_vocab_5,scifi_context_5,5000,100)
scifi_learner_5.learn()

  1%|          | 1/100 [00:01<02:53,  1.75s/it]

Loss:  34030.805


  2%|▏         | 2/100 [00:03<03:06,  1.90s/it]

Loss:  33134.617
Loss gain: 2.63%


  3%|▎         | 3/100 [00:05<03:00,  1.86s/it]

Loss:  32321.31
Loss gain: 5.02%


  4%|▍         | 4/100 [00:07<02:59,  1.87s/it]

Loss:  31528.137
Loss gain: 7.35%


  5%|▌         | 5/100 [00:09<03:01,  1.91s/it]

Loss:  30754.62
Loss gain: 9.63%


  6%|▌         | 6/100 [00:11<02:53,  1.84s/it]

Loss:  30001.088
Loss gain: 9.46%


  7%|▋         | 7/100 [00:13<02:48,  1.81s/it]

Loss:  29267.781
Loss gain: 9.45%


  8%|▊         | 8/100 [00:14<02:47,  1.82s/it]

Loss:  28552.197
Loss gain: 9.44%


  9%|▉         | 9/100 [00:16<02:46,  1.83s/it]

Loss:  27856.514
Loss gain: 9.42%


 10%|█         | 10/100 [00:19<02:55,  1.95s/it]

Loss:  27178.195
Loss gain: 9.41%


 11%|█         | 11/100 [00:20<02:48,  1.89s/it]

Loss:  26517.928
Loss gain: 9.4%


 12%|█▏        | 12/100 [00:22<02:42,  1.85s/it]

Loss:  25875.22
Loss gain: 9.38%


 13%|█▎        | 13/100 [00:24<02:38,  1.83s/it]

Loss:  25249.53
Loss gain: 9.36%


 14%|█▍        | 14/100 [00:26<02:36,  1.82s/it]

Loss:  24640.379
Loss gain: 9.34%


 15%|█▌        | 15/100 [00:27<02:36,  1.85s/it]

Loss:  24048.434
Loss gain: 9.31%


 16%|█▌        | 16/100 [00:29<02:36,  1.87s/it]

Loss:  23471.328
Loss gain: 9.29%


 17%|█▋        | 17/100 [00:32<02:46,  2.01s/it]

Loss:  22910.605
Loss gain: 9.26%


 18%|█▊        | 18/100 [00:34<02:41,  1.97s/it]

Loss:  22364.373
Loss gain: 9.24%


 19%|█▉        | 19/100 [00:35<02:36,  1.93s/it]

Loss:  21832.674
Loss gain: 9.21%


 20%|██        | 20/100 [00:37<02:32,  1.91s/it]

Loss:  21316.078
Loss gain: 9.18%


 21%|██        | 21/100 [00:39<02:26,  1.85s/it]

Loss:  20813.201
Loss gain: 9.15%


 22%|██▏       | 22/100 [00:41<02:21,  1.81s/it]

Loss:  20323.93
Loss gain: 9.12%


 23%|██▎       | 23/100 [00:43<02:20,  1.82s/it]

Loss:  19847.918
Loss gain: 9.09%


 24%|██▍       | 24/100 [00:45<02:26,  1.93s/it]

Loss:  19384.836
Loss gain: 9.06%


 25%|██▌       | 25/100 [00:47<02:24,  1.93s/it]

Loss:  18934.69
Loss gain: 9.03%


 26%|██▌       | 26/100 [00:49<02:19,  1.89s/it]

Loss:  18496.598
Loss gain: 8.99%


 27%|██▋       | 27/100 [00:50<02:15,  1.86s/it]

Loss:  18070.89
Loss gain: 8.95%


 28%|██▊       | 28/100 [00:52<02:15,  1.89s/it]

Loss:  17656.998
Loss gain: 8.91%


 29%|██▉       | 29/100 [00:54<02:13,  1.87s/it]

Loss:  17254.264
Loss gain: 8.87%


 30%|███       | 30/100 [00:56<02:12,  1.90s/it]

Loss:  16862.445
Loss gain: 8.83%


 31%|███       | 31/100 [00:58<02:21,  2.05s/it]

Loss:  16481.99
Loss gain: 8.79%


 32%|███▏      | 32/100 [01:00<02:15,  1.99s/it]

Loss:  16111.706
Loss gain: 8.75%


 33%|███▎      | 33/100 [01:02<02:07,  1.91s/it]

Loss:  15751.524
Loss gain: 8.71%


 34%|███▍      | 34/100 [01:04<02:06,  1.92s/it]

Loss:  15401.844
Loss gain: 8.66%


 35%|███▌      | 35/100 [01:06<02:06,  1.95s/it]

Loss:  15061.718
Loss gain: 8.62%


 36%|███▌      | 36/100 [01:08<02:03,  1.93s/it]

Loss:  14731.037
Loss gain: 8.57%


 37%|███▋      | 37/100 [01:10<01:58,  1.88s/it]

Loss:  14409.27
Loss gain: 8.52%


 38%|███▊      | 38/100 [01:11<01:54,  1.85s/it]

Loss:  14096.8545
Loss gain: 8.47%


 39%|███▉      | 39/100 [01:14<01:58,  1.94s/it]

Loss:  13792.91
Loss gain: 8.42%


 40%|████      | 40/100 [01:15<01:54,  1.91s/it]

Loss:  13497.922
Loss gain: 8.37%


 41%|████      | 41/100 [01:17<01:50,  1.88s/it]

Loss:  13211.024
Loss gain: 8.32%


 42%|████▏     | 42/100 [01:19<01:48,  1.86s/it]

Loss:  12931.527
Loss gain: 8.27%


 43%|████▎     | 43/100 [01:21<01:45,  1.85s/it]

Loss:  12660.393
Loss gain: 8.21%


 44%|████▍     | 44/100 [01:23<01:43,  1.84s/it]

Loss:  12396.923
Loss gain: 8.16%


 45%|████▌     | 45/100 [01:25<01:40,  1.84s/it]

Loss:  12140.788
Loss gain: 8.1%


 46%|████▌     | 46/100 [01:27<01:45,  1.96s/it]

Loss:  11891.739
Loss gain: 8.04%


 47%|████▋     | 47/100 [01:28<01:38,  1.85s/it]

Loss:  11649.681
Loss gain: 7.98%


 48%|████▊     | 48/100 [01:30<01:35,  1.84s/it]

Loss:  11414.507
Loss gain: 7.92%


 49%|████▉     | 49/100 [01:32<01:33,  1.82s/it]

Loss:  11186.047
Loss gain: 7.86%


 50%|█████     | 50/100 [01:34<01:28,  1.77s/it]

Loss:  10963.765
Loss gain: 7.8%


 51%|█████     | 51/100 [01:35<01:25,  1.75s/it]

Loss:  10747.952
Loss gain: 7.74%


 52%|█████▏    | 52/100 [01:37<01:24,  1.77s/it]

Loss:  10538.335
Loss gain: 7.68%


 53%|█████▎    | 53/100 [01:39<01:30,  1.92s/it]

Loss:  10334.447
Loss gain: 7.61%


 54%|█████▍    | 54/100 [01:41<01:26,  1.89s/it]

Loss:  10136.192
Loss gain: 7.55%


 55%|█████▌    | 55/100 [01:43<01:24,  1.88s/it]

Loss:  9943.774
Loss gain: 7.48%


 56%|█████▌    | 56/100 [01:45<01:21,  1.86s/it]

Loss:  9756.841
Loss gain: 7.42%


 57%|█████▋    | 57/100 [01:47<01:20,  1.88s/it]

Loss:  9575.257
Loss gain: 7.35%


 58%|█████▊    | 58/100 [01:49<01:19,  1.89s/it]

Loss:  9398.219
Loss gain: 7.28%


 59%|█████▉    | 59/100 [01:51<01:16,  1.87s/it]

Loss:  9226.97
Loss gain: 7.21%


 60%|██████    | 60/100 [01:52<01:11,  1.78s/it]

Loss:  9060.028
Loss gain: 7.14%


 61%|██████    | 61/100 [01:54<01:15,  1.92s/it]

Loss:  8898.189
Loss gain: 7.07%


 62%|██████▏   | 62/100 [01:56<01:11,  1.88s/it]

Loss:  8740.739
Loss gain: 7.0%


 63%|██████▎   | 63/100 [01:58<01:10,  1.91s/it]

Loss:  8587.835
Loss gain: 6.93%


 64%|██████▍   | 64/100 [02:00<01:08,  1.90s/it]

Loss:  8439.144
Loss gain: 6.85%


 65%|██████▌   | 65/100 [02:02<01:05,  1.87s/it]

Loss:  8294.747
Loss gain: 6.78%


 66%|██████▌   | 66/100 [02:03<01:01,  1.80s/it]

Loss:  8154.245
Loss gain: 6.71%


 67%|██████▋   | 67/100 [02:05<00:59,  1.81s/it]

Loss:  8017.761
Loss gain: 6.64%


 68%|██████▊   | 68/100 [02:08<01:02,  1.96s/it]

Loss:  7885.3486
Loss gain: 6.56%


 69%|██████▉   | 69/100 [02:09<00:59,  1.92s/it]

Loss:  7756.271
Loss gain: 6.49%


 70%|███████   | 70/100 [02:11<00:56,  1.87s/it]

Loss:  7631.1294
Loss gain: 6.42%


 71%|███████   | 71/100 [02:13<00:53,  1.85s/it]

Loss:  7509.1743
Loss gain: 6.34%


 72%|███████▏  | 72/100 [02:15<00:51,  1.84s/it]

Loss:  7390.926
Loss gain: 6.27%


 73%|███████▎  | 73/100 [02:17<00:51,  1.90s/it]

Loss:  7275.705
Loss gain: 6.2%


 74%|███████▍  | 74/100 [02:19<00:48,  1.88s/it]

Loss:  7163.8335
Loss gain: 6.12%


 75%|███████▌  | 75/100 [02:21<00:50,  2.01s/it]

Loss:  7055.0464
Loss gain: 6.05%


 76%|███████▌  | 76/100 [02:23<00:46,  1.95s/it]

Loss:  6949.323
Loss gain: 5.97%


 77%|███████▋  | 77/100 [02:25<00:43,  1.89s/it]

Loss:  6846.323
Loss gain: 5.9%


 78%|███████▊  | 78/100 [02:26<00:41,  1.88s/it]

Loss:  6746.33
Loss gain: 5.83%


 79%|███████▉  | 79/100 [02:28<00:39,  1.87s/it]

Loss:  6648.9355
Loss gain: 5.76%


 80%|████████  | 80/100 [02:30<00:36,  1.84s/it]

Loss:  6554.41
Loss gain: 5.68%


 81%|████████  | 81/100 [02:32<00:35,  1.85s/it]

Loss:  6462.4077
Loss gain: 5.61%


 82%|████████▏ | 82/100 [02:34<00:35,  1.96s/it]

Loss:  6372.604
Loss gain: 5.54%


 83%|████████▎ | 83/100 [02:36<00:32,  1.91s/it]

Loss:  6285.7134
Loss gain: 5.46%


 84%|████████▍ | 84/100 [02:38<00:30,  1.89s/it]

Loss:  6200.8857
Loss gain: 5.39%


 85%|████████▌ | 85/100 [02:40<00:27,  1.86s/it]

Loss:  6118.335
Loss gain: 5.32%


 86%|████████▌ | 86/100 [02:41<00:26,  1.87s/it]

Loss:  6038.136
Loss gain: 5.25%


 87%|████████▋ | 87/100 [02:43<00:24,  1.90s/it]

Loss:  5960.0215
Loss gain: 5.18%


 88%|████████▊ | 88/100 [02:45<00:22,  1.91s/it]

Loss:  5884.0737
Loss gain: 5.11%


 89%|████████▉ | 89/100 [02:47<00:20,  1.84s/it]

Loss:  5809.9365
Loss gain: 5.04%


 90%|█████████ | 90/100 [02:49<00:19,  1.98s/it]

Loss:  5737.9414
Loss gain: 4.97%


 91%|█████████ | 91/100 [02:51<00:17,  1.91s/it]

Loss:  5667.7812
Loss gain: 4.9%


 92%|█████████▏| 92/100 [02:53<00:15,  1.90s/it]

Loss:  5599.405
Loss gain: 4.84%


 93%|█████████▎| 93/100 [02:55<00:13,  1.91s/it]

Loss:  5532.776
Loss gain: 4.77%


 94%|█████████▍| 94/100 [02:56<00:10,  1.79s/it]

Loss:  5467.8765
Loss gain: 4.71%


 95%|█████████▌| 95/100 [02:58<00:09,  1.81s/it]

Loss:  5404.914
Loss gain: 4.64%


 96%|█████████▌| 96/100 [03:00<00:07,  1.87s/it]

Loss:  5343.1953
Loss gain: 4.58%


 97%|█████████▋| 97/100 [03:03<00:06,  2.07s/it]

Loss:  5283.2603
Loss gain: 4.51%


 98%|█████████▊| 98/100 [03:05<00:04,  2.00s/it]

Loss:  5224.7305
Loss gain: 4.45%


 99%|█████████▉| 99/100 [03:07<00:01,  2.00s/it]

Loss:  5167.8022
Loss gain: 4.39%


100%|██████████| 100/100 [03:09<00:00,  1.89s/it]

Loss:  5112.2095
Loss gain: 4.32%





5 closest words of CBOW2 for the Hotel Reviews dataset

In [22]:
for word in ['bar','window','mouth','pay','teach','jump','tireless','rare','nicest']:
    neighbors = hotel_learner_2.get_closest_word(word)
    print(word, ": ", neighbors, "\n")

bar :  [('place', 3.5057084560394287), ('fine', 3.7182514667510986), ('nice', 3.7379472255706787), ('look', 3.761240005493164), ('day', 3.763062000274658)] 

window :  [('use', 4.510156631469727), ('area', 4.525331974029541), ('hotel', 4.534768581390381), ('right', 4.591760158538818), ('tri', 4.6055779457092285)] 

mouth :  [('andr', 6.032471179962158), ('day', 6.1610260009765625), ('servic', 6.175563335418701), ('beauti', 6.179343223571777), ('pike', 6.217563152313232)] 

pay :  [('locat', 4.316544055938721), ('work', 4.3558735847473145), ('hotel', 4.447457313537598), ('tri', 4.484480857849121), ('close', 4.49414587020874)] 

teach :  [('mind', 7.310356616973877), ('level', 7.9295783042907715), ('food', 7.949304103851318), ('separ', 7.961902141571045), ('grace', 8.074626922607422)] 

jump :  [('downtown', 5.247917175292969), ('place', 5.5001606941223145), ('close', 5.830567836761475), ('center', 5.933122158050537), ('hotel', 5.97814416885376)] 

tireless :  [('experi', 4.8784546852111

5 closest words of CBOW2 for the Sci-Fi story dataset

In [23]:
for word in ['mouth','sport','bedroom','trust','fail','pay','largest','clear','endless']:
    neighbors = scifi_learner_2.get_closest_word(word)
    print(word, ": ", neighbors, "\n")

mouth :  [('brought', 6.275054454803467), ('made', 6.327899932861328), ('codiscover', 6.493941783905029), ('came', 6.604343891143799), ('movi', 6.60573673248291)] 

sport :  [('get', 6.421679973602295), ('kirk', 6.525964260101318), ('ground', 6.579998970031738), ('predatorycreatur', 6.715239524841309), ('long', 6.728081703186035)] 

bedroom :  [('made', 5.800624370574951), ('swam', 5.992005348205566), ('voluptu', 6.217353820800781), ('bombshel', 6.220456123352051), ('judi', 6.3308258056640625)] 

trust :  [('toward', 5.99810266494751), ('kept', 6.206921577453613), ('foreman', 6.311612606048584), ('seiz', 6.458076477050781), ('gaunt', 6.518838882446289)] 

fail :  [('windi', 5.9373369216918945), ('sir', 6.292647838592529), ('want', 6.324713230133057), ('neverth', 6.368040084838867), ('may', 6.474075794219971)] 

pay :  [('list', 6.073314666748047), ('innoc', 6.661491870880127), ('question', 6.699519634246826), ('knowledgeseek', 6.8819780349731445), ('lift', 6.998386859893799)] 

largest

5 closest words of CBOW5 for the Hotel Reviews dataset

In [24]:
for word in ['bar','window','mouth','pay','teach','jump','tireless','rare','nicest']:
    neighbors = hotel_learner_5.get_closest_word(word)
    print(word, ": ", neighbors, "\n")

bar :  [('citi', 1.028738021850586), ('hotel', 1.077404260635376), ('use', 1.1310336589813232), ('quit', 1.1636459827423096), ('thing', 1.1637693643569946)] 

window :  [('got', 1.7456876039505005), ('bathroom', 1.7746179103851318), ('old', 1.7912161350250244), ('larg', 1.8197168111801147), ('clean', 1.8316915035247803)] 

mouth :  [('special', 5.855133056640625), ('money', 5.87908935546875), ('highli', 5.896125793457031), ('amaz', 5.924036026000977), ('terribl', 5.945255279541016)] 

pay :  [('price', 0.9755768179893494), ('friendli', 1.1670482158660889), ('hotel', 1.1920456886291504), ('time', 1.2026610374450684), ('stay', 1.2043038606643677)] 

teach :  [('saw', 6.024971008300781), ('happen', 6.090672492980957), ('upgrad', 6.170861721038818), ('rude', 6.18377685546875), ('inclin', 6.247430801391602)] 

jump :  [('garag', 5.9092535972595215), ('soon', 5.947790145874023), ('doe', 5.955192565917969), ('coffe', 6.004382133483887), ('walk', 6.025463581085205)] 

tireless :  [('took', 3.6

5 closest words of CBOW5 for the Sci-Fi story dataset

In [25]:
for word in ['mouth','sport','bedroom','trust','fail','pay','largest','clear','endless']:
    neighbors = scifi_learner_5.get_closest_word(word)
    print(word, ": ", neighbors, "\n")

mouth :  [('say', 3.7974326610565186), ('earth', 4.0416035652160645), ('alma', 4.0852508544921875), ('toward', 4.085824012756348), ('one', 4.2075958251953125)] 

sport :  [('depress', 6.369104862213135), ('good', 6.414592266082764), ('homicid', 6.4989190101623535), ('roar', 6.499039649963379), ('hit', 6.516376495361328)] 

bedroom :  [('peopl', 4.857330322265625), ('got', 5.0196075439453125), ('wa', 5.117725372314453), ('thi', 5.138413429260254), ('came', 5.192033767700195)] 

trust :  [('poor', 7.949508190155029), ('go', 7.997661113739014), ('come', 8.091782569885254), ('cours', 8.12215805053711), ('lift', 8.211256980895996)] 

fail :  [('right', 4.795498847961426), ('light', 4.795810699462891), ('hi', 4.797901153564453), ('day', 4.801190376281738), ('came', 4.8309173583984375)] 

pay :  [('said', 5.613862037658691), ('interest', 6.000959873199463), ('fourth', 6.032034873962402), ('chair', 6.047975540161133), ('say', 6.069349765777588)] 

largest :  [('man', 4.712806701660156), ('say'