In [37]:
import time
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import DATASETS
from torchtext.vocab import build_vocab_from_iterator
import torch.nn as nn
from tqdm import tqdm
import pickle
import random
import numpy as np
from collections import Counter, defaultdict
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from gensim.test.utils import datapath, get_tmpfile
from gensim.models import KeyedVectors
import gensim.downloader
from torch import FloatTensor as FT

# Get the interactive Tools for Matplotlib
%matplotlib notebook
%matplotlib inline

plt.style.use('ggplot')

### Instructions
For this part, fill in the required code and make the notebook work. This wll be very similar to the Skip-Gram model, but a little more difficult. Look for the """ FILL IN """ string to guide you.

In [38]:
# Where do I want to run my job. You can do "cuda" on linux machines
DEVICE = "mps" if torch.backends.mps.is_available() else  "cpu"
# DEVICE = "cuda" if torch.cuda.is_available() else  "cpu"

# The batch size in Adam or SGD
BATCH_SIZE = 512

# Number of epochs
NUM_EPOCHS = 10

# Predict from 2 words the inner word for CBOW
# I.e. I'll have a window like ["a", "b", "c"] of continuous text (each is a word)
# We'll predict each of wc = ["a", "c"] from "b" = wc for Skip-Gram
# For CBOW, we'll use ["a", "c"] to predict "b" = wo
WINDOW = 1

# Negative samples.
K = 4

The text8 Wikipedia corpus. 100M characters.

In [39]:
# Put the data in your Google Drive
# You ca get the data here: https://www.kaggle.com/competitions/titanic/data
from google.colab import drive
drive.mount('/content/drive')

!du -h text8

f = open('/content/drive/MyDrive/text8/text8', 'r')
text = f.read()
# One big string of size 100M
print(len(text))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
du: cannot access 'text8': No such file or directory
100000000


In [40]:
punc = '!"#$%&()*+,-./:;<=>?@[\\]^_\'{|}~\t\n'

# Can do regular expressions here too
for c in punc:
    if c in text:
        text.replace(c, ' ')

In [41]:
# A very crude tokenizer you get for free: lower case and also split on spaces
TOKENIZER = get_tokenizer("basic_english")

In [42]:
words = TOKENIZER(text)
f = Counter(words)

In [43]:
len(words)

17005207

In [44]:
# Do a very crude filter on the text which removes all very popular words
text = [word for word in words if f[word] > 5]

In [45]:
text[0:5]

['anarchism', 'originated', 'as', 'a', 'term']

In [46]:
VOCAB = build_vocab_from_iterator([text])

In [47]:
# word -> int hash map
stoi = VOCAB.get_stoi()
# int -> word hash map
itos = VOCAB.get_itos()

In [48]:
stoi['as']

11

In [49]:
# Total number of words
len(stoi)

63641

In [50]:
f = Counter(text)
# This is the probability that we pick a word in the corpus
z = {word: f[word] / len(text) for word in f}

In [51]:
threshold = 1e-5
# Probability that word is kept while subsampling
# This is explained here and sightly differet from the paper: http://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/
p_keep = {word: (np.sqrt(z[word] / 0.001) + 1)*(0.0001 / z[word]) for word in f}

In [52]:
# This is in the integer space
train_dataset = [word for word in text if random.random() < p_keep[word]]

# Rebuild the vocabulary
VOCAB = build_vocab_from_iterator([train_dataset])

In [53]:
len(train_dataset)

7844729

In [54]:
# word -> int mapping
stoi = VOCAB.get_stoi()
# int -> word mapping
itos = VOCAB.get_itos()

In [55]:
# The vocabulary size after we do all the filters
len(VOCAB)

63641

In [56]:
# The probability we draw something for negative sampling
f = Counter(train_dataset)
p = torch.zeros(len(VOCAB))

# Downsample frequent words and upsample less frequent
s = sum([np.power(freq, 0.75) for word, freq in f.items()])

for word in f:
    p[stoi[word]] = np.power(f[word], 0.75) / s

In [57]:
# Map everything to integers
train_dataset = [stoi[word] for word in text]

In [58]:
# This just gets the (wc, wo) pairs that are positive - they are seen together!
def get_tokenized_dataset(dataset, verbose=False):
    x_list = []

    for i, token in enumerate(dataset):
        m = 1

        # Get the left and right tokens
        start = max(0,i-m)
        left_tokens = dataset[start:i]

        end = min(i+m,len(dataset)-1)
        right_tokens = dataset[i+1:end+1]

        # Check these are the same length, and if so use them to add a row of data. This should be a list like
        # [a, c, b] where b is the center word
        if len(left_tokens) == len(right_tokens):
            w_context = left_tokens + right_tokens

            wc = token

            x_list.extend(
                [w_context + [wc]]
            )

    return x_list

In [59]:
train_x_list = get_tokenized_dataset(train_dataset, verbose=False)

In [60]:
pickle.dump(train_x_list, open('train_x_list.pkl', 'wb'))

In [61]:
train_x_list = pickle.load(open('train_x_list.pkl', 'rb'))

In [62]:
# These are (wc, wo) pairs. All are y = +1 by design
train_x_list[:10]

[[5233, 11, 3083],
 [3083, 6, 11],
 [11, 163, 6],
 [6, 1, 163],
 [163, 3133, 1],
 [1, 47, 3133],
 [3133, 56, 47],
 [47, 140, 56],
 [56, 115, 140],
 [140, 740, 115]]

In [63]:
# The number of things of BATCH_SIZE = 512
assert(len(train_x_list) // BATCH_SIZE == 32579)

### Set up the dataloader.

In [64]:
train_dl = DataLoader(
    TensorDataset(
        torch.tensor(train_x_list).to(DEVICE),
    ),
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [65]:
for xb in train_dl:
    assert(xb[0].shape == (BATCH_SIZE, 3))
    break

### Words we'll use to asses the quality of the model ...

In [66]:
valid_ids = torch.tensor([
    stoi['money'],
    stoi['lion'],
    stoi['africa'],
    stoi['musician'],
    stoi['dance'],
])

### Get the model.

In [67]:
class CBOWNegativeSampling(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super(CBOWNegativeSampling, self).__init__()
        self.A = nn.Embedding(vocab_size, embed_dim) # Context vectors - center word
        self.B = nn.Embedding(vocab_size, embed_dim) # Output vectors - words around the center word
        self.init_weights()

    def init_weights(self):
        # Is this the best way? Not sure
        initrange = 0.5
        self.A.weight.data.uniform_(-initrange, initrange)
        self.B.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        # N is the batch size
        # x is (N, 3)

        # Context words are 2m things, m = 1 so w_context is (N, 2) while wc is (N, 1)
        w_context, wc = x[:, :-1], x[:, -1]

        # Each of these is (N, 2, D) since each context has 2 word
        # We want this to be (N, D) and this is what we get

        # (N, 2, D)
        a = self.A(w_context)

        # (N, D)
        a_avg = a.mean(dim=1)

        # Each of these is (N, D) since each target has 1 word
        b = self.B(wc)

        # The product between each context and target vector. Look at the Skip-Gram code.
        # The logits is now (N, 1) since we sum across the final dimension.
        logits = (a_avg * b).sum(axis=-1)

        return logits

In [68]:
@torch.no_grad()
def validate_embeddings(
    model,
    valid_ids,
    itos
):
    """ Validation logic """

    # We will use context embeddings to get the most similar words
    # Other strategies include: using target embeddings, mean embeddings after avaraging context/target
    embedding_weights = model.A.weight

    normalized_embeddings = embedding_weights.cpu() / np.sqrt(
        np.sum(embedding_weights.cpu().numpy()**2, axis=1, keepdims=True)
    )

    # Get the embeddings corresponding to valid_term_ids
    valid_embeddings = normalized_embeddings[valid_ids, :]

    # Compute the similarity between valid_term_ids (S) and all the embeddings (V)
    # We do S x d (d x V) => S x D and sort by negative similarity
    top_k = 10 # Top k items will be displayed
    similarity = np.dot(valid_embeddings.cpu().numpy(), normalized_embeddings.cpu().numpy().T)

    # Invert similarity matrix to negative
    # Ignore the first one because that would be the same word as the probe word
    similarity_top_k = np.argsort(-similarity, axis=1)[:, 1: top_k+1]

    # Print the output.
    for i, word_id in enumerate(valid_ids):
        # j >= 1 here since we don't want to include the word itself.
        similar_word_str = ', '.join([itos[j] for j in similarity_top_k[i, :] if j >= 1])
        print(f"{itos[word_id]}: {similar_word_str}")

    print('\n')

### Set up the model

In [69]:
LR = 10.0
NUM_EPOCHS = 10
EMBED_DIM = 300

In [70]:
model = CBOWNegativeSampling(len(VOCAB), EMBED_DIM).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

# The learning rate is lowered every epoch by 1/10
# Is this a good idea?
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)

In [71]:
model

CBOWNegativeSampling(
  (A): Embedding(63641, 300)
  (B): Embedding(63641, 300)
)

In [72]:
validate_embeddings(model, valid_ids, itos)

money: worn, assures, triumvirate, anagram, walleye, unaspirated, looped, bahraini, sabines, asbestosis
lion: unrequited, photographic, adopter, manuscript, sends, maharal, baja, dubitative, ketamine, pikes
africa: emigrating, praised, tiberius, jaffna, idiom, marlborough, corriere, reversibility, catahoula, aris
musician: conduct, menagerie, vl, ingenious, wineries, reffered, shugart, mackey, habilitation, modern
dance: qxd, rms, programmability, tachycardia, iconostasis, custodian, aqueous, battletech, simo, coll




### Train the model

In [73]:
ratios = []

def train(dataloader, model, optimizer, epoch):
    model.train()
    total_acc, total_count, total_loss, total_batches = 0, 0, 0.0, 0.0
    log_interval = 500

    for idx, x_batch in tqdm(enumerate(dataloader)):

        x_batch = x_batch[0]

        batch_size = x_batch.shape[0]

        # Zero the gradient so they don't accumulate
        optimizer.zero_grad()

        logits = model(x_batch)

        # Get the positive samples loss. Notice we use weights here
        positive_loss = torch.nn.BCEWithLogitsLoss()(input=logits, target=torch.ones(batch_size).to(DEVICE).float())

        # For each batch, get some negative samples
        # We need a total of len(y_batch) * K samples across a batch
        # We then reshape this batch
        # These are effectively the output words
        negative_samples = torch.multinomial(p, batch_size * K, replacement=True)

        # Context words are 2m things, m = 1 so w_context is (N, 2) while wc is (N, 1)
        w_context, wc = x_batch[:, :-1], x_batch[:, -1]

        """
        if w_context looks like below (batch_size = 3)
        [
        (a, b),
        (c, d),
        (e, f)
        ] and K = 2 we'd like to get:

        [
        (a, b),
        (a, b),
        (c, d),
        (c, d),
        (e, f),
        (e, f)
        ]

        This will be batch_size * K rows.
        """

        # This should be (N * K, 2)
        w_context = torch.concat([
            w.repeat(K, 1) for w in torch.tensor(w_context).split(1)
        ])

        # Remove the last dimension 1
        wc = negative_samples.unsqueeze(-1)

        # Get the negative samples. This should be (N * K, 3)
        # Concatenate the w_context and wc along the column. Make sure everything is on CUDA / MPS or CPU
        x_batch_negative = torch.concat([w_context, wc.to(DEVICE)], axis=1)

        """
        Note the way we formulated the targets: they are all 0 since these are negative samples.
        We do the BCEWithLogitsLoss by hand basically here.
        Notice we sum across the negative samples, per positive word.

        This is literally the equation in the lecture notes.
        """

        # (N, K, D) -> (N, D) -> (N)
        # Look at the Skip-Gram notebook
        negative_loss = model(x_batch_negative).neg().sigmoid().log().reshape(batch_size, K).sum(1).mean().neg().to(DEVICE)

        loss = (positive_loss + negative_loss).mean()

        # Get the gradients via back propagation
        loss.backward()

        # Clip the gradients? Generally a good idea
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)


        # Do an optimization step. Update the parameters A and B
        optimizer.step()

        # Get the new loss
        total_loss += loss.item()

        # Update the batch count
        total_batches += 1

        if idx % log_interval == 0:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| loss {:8.3f} ".format(
                    epoch,
                    idx,
                    len(dataloader),
                    total_loss / total_batches
                )
            )
            validate_embeddings(model, valid_ids, itos)
            total_loss, total_batches = 0.0, 0.0

### Some results from the run look like below:

Somewhere inside of 2 iterations you should get sensible associattions.
Paste here a screenshot of the closest vectors.

In [None]:
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time()

    train(train_dl, model, optimizer, epoch)
    # We have a learning rate scheduler here

    # Basically, given the state of the optimizer, this lowers the learning rate in a smart way
    scheduler.step()

  w.repeat(K, 1) for w in torch.tensor(w_context).split(1)
1it [00:02,  2.42s/it]

| epoch   1 |     0/32580 batches | loss    4.101 
money: worn, assures, triumvirate, anagram, walleye, unaspirated, looped, bahraini, sabines, asbestosis
lion: unrequited, photographic, adopter, manuscript, sends, maharal, baja, dubitative, ketamine, pikes
africa: emigrating, praised, tiberius, jaffna, idiom, marlborough, corriere, reversibility, catahoula, aris
musician: conduct, menagerie, vl, ingenious, wineries, reffered, shugart, mackey, habilitation, modern
dance: qxd, rms, programmability, tachycardia, iconostasis, custodian, aqueous, battletech, simo, coll




501it [03:15,  2.39it/s]

| epoch   1 |   500/32580 batches | loss    3.717 
money: worn, assures, triumvirate, walleye, anagram, looped, unaspirated, titian, bahraini, sabines
lion: unrequited, adopter, photographic, manuscript, sends, maharal, baja, dubitative, pikes, ketamine
africa: emigrating, praised, tiberius, jaffna, idiom, corriere, marlborough, reversibility, imprisonment, baekje
musician: conduct, vl, menagerie, ingenious, wineries, reffered, shugart, mackey, profumo, habilitation
dance: qxd, rms, programmability, tachycardia, iconostasis, aqueous, custodian, nia, coll, simo




1001it [06:35,  1.78it/s]

| epoch   1 |  1000/32580 batches | loss    3.352 
money: worn, assures, triumvirate, looped, walleye, titian, anagram, unaspirated, bahraini, sabines
lion: unrequited, photographic, adopter, manuscript, sends, baja, maharal, dubitative, pikes, ketamine
africa: emigrating, praised, tiberius, jaffna, idiom, marlborough, baekje, corriere, reversibility, imprisonment
musician: conduct, vl, menagerie, ingenious, wineries, reffered, mackey, profumo, shugart, metrics
dance: qxd, rms, programmability, tachycardia, aqueous, iconostasis, custodian, nia, besiege, simo




1501it [09:53,  2.06it/s]

| epoch   1 |  1500/32580 batches | loss    3.033 
money: worn, looped, assures, triumvirate, walleye, anagram, titian, unaspirated, asbestosis, bahraini
lion: unrequited, adopter, photographic, manuscript, sends, maharal, baja, dubitative, pikes, ketamine
africa: emigrating, praised, tiberius, idiom, jaffna, baekje, marlborough, imprisonment, corriere, reversibility
musician: conduct, ingenious, vl, menagerie, wineries, reffered, profumo, herakles, metrics, mackey
dance: qxd, rms, programmability, tachycardia, aqueous, iconostasis, coll, custodian, presidencia, besiege




2000it [13:20,  2.08it/s]

| epoch   1 |  2000/32580 batches | loss    2.795 


2001it [13:21,  1.73it/s]

money: worn, triumvirate, looped, assures, walleye, titian, anagram, asbestosis, bahraini, bahn
lion: unrequited, adopter, photographic, manuscript, maharal, sends, baja, dubitative, ketamine, pikes
africa: emigrating, praised, tiberius, idiom, jaffna, eight, marlborough, imprisonment, corriere
musician: conduct, ingenious, vl, wineries, menagerie, reffered, profumo, controversially, habilitation, grandson
dance: qxd, rms, programmability, tachycardia, aqueous, nia, coll, besiege, iconostasis, custodian




2500it [16:41,  2.79it/s]

| epoch   1 |  2500/32580 batches | loss    2.623 
money: worn, looped, assures, triumvirate, titian, anagram, walleye, bahraini, intersected, bahn


2501it [16:42,  2.44it/s]

lion: unrequited, adopter, photographic, manuscript, maharal, baja, sends, dubitative, ketamine, pikes
africa: emigrating, be, eight, four, a, one, praised, and, or
musician: conduct, ingenious, wineries, vl, menagerie, reffered, grandson, profumo, controversially, sergeant
dance: qxd, rms, programmability, tachycardia, aqueous, nia, coll, besiege, custodian, increase




3000it [20:07,  2.60it/s]

| epoch   1 |  3000/32580 batches | loss    2.486 
money: worn, looped, assures, triumvirate, titian, intersected, now, bahraini, anagram, walleye


3001it [20:08,  2.17it/s]

lion: unrequited, adopter, manuscript, photographic, maharal, sends, baja, dubitative, ketamine, pikes
africa: more, eight, emigrating, this, and, which, one, that, a
musician: conduct, ingenious, wineries, vl, grandson, reffered, profumo, menagerie, herakles, controversially
dance: qxd, rms, g, programmability, tachycardia, aqueous, nia, besiege, increase, counteracting




3500it [23:31,  2.85it/s]

| epoch   1 |  3500/32580 batches | loss    2.382 


3501it [23:31,  2.24it/s]

money: now, looped, worn, assures, triumvirate, injected, intersected, within, mur, titian
lion: unrequited, adopter, manuscript, photographic, maharal, baja, dubitative, sends, ketamine, handkerchief
africa: or, more, eight, that, which, four, this, nine, some
musician: conduct, ingenious, wineries, vl, grandson, reffered, profumo, controversially, herakles, menagerie
dance: g, qxd, rms, nia, tachycardia, besiege, programmability, aqueous, increase, counteracting




4001it [27:00,  2.15it/s]

| epoch   1 |  4000/32580 batches | loss    2.291 
money: now, worn, within, looped, show, assures, injected, intersected, triumvirate, time
lion: unrequited, adopter, manuscript, photographic, maharal, baja, sends, dubitative, ketamine, handkerchief
africa: or, more, eight, zero, which, four, nine, them, and
musician: conduct, ingenious, wineries, grandson, vl, modern, profumo, reffered, herakles, controversially
dance: g, qxd, rms, increase, nia, besiege, tachycardia, programmability, aqueous, counteracting




4501it [30:25,  2.05it/s]

| epoch   1 |  4500/32580 batches | loss    2.219 
money: now, within, time, often, show, languages, when, how, looped, worn
lion: unrequited, manuscript, adopter, photographic, maharal, baja, dubitative, sends, ketamine, handkerchief
africa: more, eight, four, or, which, nine, zero, one, them
musician: conduct, ingenious, grandson, modern, wineries, vl, profumo, herakles, reffered, controversially
dance: g, qxd, increase, four, rms, power, two, people, aqueous, tachycardia




5001it [33:48,  2.13it/s]

| epoch   1 |  5000/32580 batches | loss    2.159 
money: now, within, time, often, when, how, show, because, languages, out
lion: unrequited, manuscript, adopter, photographic, maharal, dubitative, baja, handkerchief, ketamine, sends
africa: more, eight, four, or, nine, zero, and, which, south
musician: conduct, modern, ingenious, grandson, company, wineries, vl, herakles, profumo, reffered
dance: g, increase, power, two, four, people, qxd, no, three, rms




5501it [37:14,  2.34it/s]

| epoch   1 |  5500/32580 batches | loss    2.109 
money: now, within, time, often, because, how, when, also, languages, show
lion: unrequited, manuscript, adopter, photographic, maharal, handkerchief, baja, dubitative, ketamine, demian
africa: more, eight, or, four, nine, which, south, zero, them, time
musician: modern, conduct, company, grandson, ingenious, office, wineries, profumo, herakles, vl
dance: g, power, increase, four, people, two, three, six, both, zero




6000it [40:35,  2.16it/s]

| epoch   1 |  6000/32580 batches | loss    2.055 


6001it [40:35,  1.94it/s]

money: now, time, within, often, also, how, because, show, work, languages
lion: unrequited, manuscript, adopter, maharal, photographic, handkerchief, dubitative, demian, baja, ketamine
africa: more, eight, or, nine, time, which, south, them, zero
musician: modern, company, conduct, grandson, office, ingenious, wineries, any, herakles, french
dance: g, power, two, people, three, six, increase, four, both, zero




6500it [43:58,  2.06it/s]

| epoch   1 |  6500/32580 batches | loss    2.019 


6501it [43:59,  1.72it/s]

money: now, time, because, within, often, how, work, made, out, also
lion: unrequited, manuscript, adopter, maharal, photographic, handkerchief, ketamine, demian, dubitative, baja
africa: eight, more, time, nine, south, zero, them, number, four
musician: modern, company, conduct, grandson, any, office, french, france, ingenious, like
dance: g, power, three, two, people, six, four, zero, both, increase




6924it [46:51,  2.80it/s]