In [1]:
import time
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import DATASETS
from torchtext.vocab import build_vocab_from_iterator
import torch.nn as nn
from tqdm import tqdm
import pickle
import random
import numpy as np
from collections import Counter, defaultdict
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from gensim.test.utils import datapath, get_tmpfile
from gensim.models import KeyedVectors
import gensim.downloader
from torch import FloatTensor as FT

%matplotlib notebook
%matplotlib inline

plt.style.use('ggplot')

In [2]:
# Configurations
device = "mps" if torch.backends.mps.is_available() else "cpu"
batch_size = 512
num_epochs = 10
window_size = 1
negative_samples = 4

In [3]:
#The text8 Wikipedia corpus. 100M characters.
!du -h data/text8

 95M	data/text8


In [4]:
with open('data/text8', 'r') as f:
    corpus = f.read()


In [7]:
punc = '!"#$%&()*+,-./:;<=>?@[\\]^_\'{|}~\t\n'
for c in punc:
    if c in corpus:
        corpus.replace(c, ' ')


In [8]:
tokenizer = get_tokenizer("basic_english")
tokens = tokenizer(corpus)
token_counts = Counter(tokens)

In [9]:
# Filter out low-frequency tokens
filtered_tokens = [token for token in tokens if token_counts[token] > 5]

In [10]:
vocab = build_vocab_from_iterator([filtered_tokens])
# word -> int hash map.
token_to_idx = vocab.get_stoi()
# int -> word hash map.
idx_to_token = vocab.get_itos()

In [11]:
#negative sampling
threshold = 1e-5
token_probs = {token: (np.sqrt(token_counts[token] / 0.001) + 1) * (0.0001 / token_counts[token]) for token in token_counts}

train_tokens = [token for token in filtered_tokens if random.random() < token_probs[token]]
train_vocab = build_vocab_from_iterator([train_tokens])


In [15]:
train_token_to_idx = vocab.get_stoi()
train_idx_to_token = vocab.get_itos()

In [13]:
token_freq = Counter(train_tokens)
token_probs = torch.zeros(len(train_vocab))

s = sum([np.power(freq, 0.75) for token, freq in token_freq.items()])
for token in token_freq:
    token_probs[train_token_to_idx[token]] = np.power(token_freq[token], 0.75) / s

In [16]:
train_token_ids = [train_token_to_idx[token] for token in filtered_tokens]

In [17]:
# This just gets the (wc, wo) pairs that are positive, which means they are seen together
def create_cbow_dataset(token_ids, window_size):
    context_target_pairs = []
    for i, token_id in enumerate(token_ids):
        context_start = max(0, i - window_size)
        context_end = min(i + window_size, len(token_ids) - 1)
        left_context = token_ids[context_start:i]
        right_context = token_ids[i+1:context_end+1]

        if len(left_context) == len(right_context):
            context = left_context + right_context
            target = token_id
            context_target_pairs.append(context + [target])

    return context_target_pairs

In [18]:
train_dataset = create_cbow_dataset(train_token_ids, window_size)

In [19]:
# Set up the dataloader.
train_dataloader = DataLoader(
    TensorDataset(torch.tensor(train_dataset).to(device)),
    batch_size=batch_size,
    shuffle=True
)

In [20]:
validation_token_ids = torch.tensor([
    train_token_to_idx['money'],
    train_token_to_idx['lion'],
    train_token_to_idx['africa'],
    train_token_to_idx['musician'],
    train_token_to_idx['dance'],
])

In [21]:
# Get the model.

class CBOWModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CBOWModel, self).__init__()
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.target_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.context_embeddings.weight.data.uniform_(-initrange, initrange)
        self.target_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self, context_ids):
        context_embeds = self.context_embeddings(context_ids[:, :-1])
        context_embeds_mean = context_embeds.mean(axis=1)

        target_embeds = self.target_embeddings(context_ids[:, -1])

        logits = (context_embeds_mean * target_embeds).sum(axis=-1)

        return logits

In [22]:
@torch.no_grad()
def validate_model(model, validation_ids, idx_to_token):
    embedding_weights = model.context_embeddings.weight

    normalized_embeddings = embedding_weights.cpu() / np.sqrt(np.sum(embedding_weights.cpu().numpy()**2, axis=1, keepdims=True))

    validation_embeddings = normalized_embeddings[validation_ids, :]

    top_k = 10
    similarity = np.dot(validation_embeddings.cpu().numpy(), normalized_embeddings.cpu().numpy().T)
    similarity_top_k = np.argsort(-similarity, axis=1)[:, 1: top_k+1]

    for i, token_id in enumerate(validation_ids):
        similar_tokens = ', '.join([idx_to_token[j] for j in similarity_top_k[i, :] if j >= 1])
        print(f"{idx_to_token[token_id]}: {similar_tokens}")

    print('\n')

In [26]:
# Set up the model

learning_rate = 10.0
num_epochs = 10
embedding_dim = 300

model = CBOWModel(len(vocab), embedding_dim).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)

In [27]:
model

CBOWModel(
  (context_embeddings): Embedding(63641, 300)
  (target_embeddings): Embedding(63641, 300)
)

In [30]:
def train_epoch(dataloader, model, optimizer, epoch):
    model.train()
    total_loss, total_batches = 0.0, 0.0
    log_interval = 500

    for idx, (context_ids,) in tqdm(enumerate(dataloader)):
        batch_size = context_ids.shape[0]

        optimizer.zero_grad()

        logits = model(context_ids)

        positive_loss = torch.nn.BCEWithLogitsLoss()(input=logits, target=torch.ones(batch_size).to(device).float())

        negative_samples_ids = torch.multinomial(token_probs, batch_size * negative_samples, replacement=True)

        context_ids_repeated = torch.concat([c.repeat(negative_samples, 1) for c in torch.tensor(context_ids[:, :-1]).split(1)])
        negative_target_ids = negative_samples_ids.unsqueeze(-1)

        negative_context_ids = torch.concat([context_ids_repeated, negative_target_ids.to(device)], axis=1)

        negative_loss = model(negative_context_ids).neg().sigmoid().log().reshape(batch_size, negative_samples).sum(1).mean().neg().to(device)

        loss = (positive_loss + negative_loss).mean()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

        with torch.no_grad():
            param_ratios = [(learning_rate * p.grad.std() / p.data.std()).log10().item() for _, p in model.named_parameters()]

        optimizer.step()
        total_loss += loss.item()
        total_batches += 1

        if idx % log_interval == 0:
            print(f"| epoch {epoch:3d} | {idx:5d}/{len(dataloader):5d} batches | loss {total_loss / total_batches:8.3f}")
            validate_model(model, validation_token_ids, train_idx_to_token)
            total_loss, total_batches = 0.0, 0.0

In [31]:
for epoch in range(1, num_epochs + 1):
    train_epoch(train_dataloader, model, optimizer, epoch)
    scheduler.step()

  context_ids_repeated = torch.concat([c.repeat(negative_samples, 1) for c in torch.tensor(context_ids[:, :-1]).split(1)])
1it [00:02,  2.20s/it]

| epoch   1 |     0/32580 batches | loss    4.088
money: biplanes, attaches, giri, wigner, prepositional, paracetamol, mich, uwb, fanbase, gin
lion: maya, dare, buonarroti, practises, younger, risked, rhineland, ramzi, bains, cortina
africa: shapiro, batsmen, forgetful, extravagance, nicephorus, clis, comedic, micronesia, aspen, phthalocyanine
musician: northland, sentience, protectionism, mondegreen, gera, dsa, myrrh, determine, encyclop, ungodly
dance: banshees, exceeds, wtoo, endomorphisms, samoans, screename, appreciable, huygens, gorges, reunited




500it [01:21,  6.87it/s]

| epoch   1 |   500/32580 batches | loss    2.775


502it [01:21,  5.19it/s]

money: biplanes, giri, wigner, attaches, prepositional, paracetamol, sensuous, fanbase, mich, gin
lion: maya, dare, buonarroti, practises, younger, risked, rhineland, ramzi, cortina
africa: forgetful, extravagance, batsmen, shapiro, nicephorus, comedic, clis, aspen, micronesia, molality
musician: northland, protectionism, sentience, mondegreen, gera, myrrh, determine, dsa, neuroanatomy, ungodly
dance: banshees, wtoo, samoans, exceeds, screename, endomorphisms, appreciable, huygens, kantele, gorges




1001it [02:35,  5.24it/s]

| epoch   1 |  1000/32580 batches | loss    2.021
money: biplanes, all, wigner, giri, they, sensuous, fanbase, mobile, is, prepositional
lion: maya, buonarroti, dare, practises, younger, risked, rhineland, ramzi, cortina, buffett
africa: forgetful, extravagance, batsmen, comedic, shapiro, alai, aspen, kilowatt, bum, clis
musician: northland, protectionism, sentience, mondegreen, myrrh, neuroanatomy, gera, determine, orloff, drunkenness
dance: banshees, wtoo, samoans, exceeds, appreciable, screename, kantele, reunited, endomorphisms, huygens




1501it [03:48,  5.32it/s]

| epoch   1 |  1500/32580 batches | loss    1.790
money: all, so, biplanes, they, such, mobile, wigner, high, giri, sensuous
lion: maya, dare, buonarroti, practises, younger, risked, rhineland, ramzi, states, cortina
africa: forgetful, extravagance, batsmen, since, free, comedic, years, death, close, due
musician: northland, protectionism, sentience, determine, mondegreen, myrrh, neuroanatomy, gera, management, orloff
dance: wtoo, samoans, banshees, appreciable, exceeds, kantele, european, reunited, art, special




2001it [05:03,  4.92it/s]

| epoch   1 |  2000/32580 batches | loss    1.642
money: all, so, high, mobile, case, biplanes, amount, they, such, is
lion: maya, dare, buonarroti, younger, practises, states, day, risked, rhineland
africa: since, free, extravagance, forgetful, death, due, years, batsmen, greek, other
musician: northland, protectionism, art, determine, management, sentience, mondegreen, changed, cost, neuroanatomy
dance: wtoo, samoans, european, banshees, art, special, exceeds, appreciable, kantele, higher




2501it [06:24,  4.76it/s]

| epoch   1 |  2500/32580 batches | loss    1.543
money: all, so, mobile, high, amount, case, list, him, more, wars
lion: maya, younger, dare, buonarroti, practises, day, practice, states, well, lives
africa: free, death, forgetful, extravagance, due, since, greek, batsmen, went, years
musician: art, northland, determine, management, protectionism, cost, changed, near, oxford, book
dance: european, wtoo, art, samoans, banshees, special, higher, exceeds, kantele, appreciable




3001it [07:42,  4.72it/s]

| epoch   1 |  3000/32580 batches | loss    1.474
money: amount, mobile, so, wars, all, case, is, sea, high, more
lion: maya, younger, dare, buonarroti, day, practice, well, practises, states, lives
africa: death, free, forgetful, extravagance, due, greek, went, since, batsmen, both
musician: art, determine, management, northland, changed, cost, protectionism, book, near, de
dance: art, european, wtoo, samoans, special, higher, banshees, particularly, kantele, exceeds




3500it [09:02,  6.29it/s]

| epoch   1 |  3500/32580 batches | loss    1.416


3502it [09:02,  4.64it/s]

money: amount, wars, mobile, all, case, sea, so, function, more, end
lion: maya, younger, dare, practice, buonarroti, well, states, day, lives, practises
africa: death, went, free, extravagance, due, forgetful, greek, batsmen, top, march
musician: art, determine, management, changed, book, cost, northland, allowed, near, de
dance: art, european, wtoo, special, higher, samoans, particularly, kantele, enemy, banshees




4001it [10:22,  4.81it/s]

| epoch   1 |  4000/32580 batches | loss    1.373
money: amount, wars, mobile, case, function, so, end, all, sea, italian
lion: maya, younger, practice, dare, buonarroti, lives, well, day, practises, states
africa: death, went, extravagance, free, forgetful, top, due, greek, march, batsmen
musician: art, changed, management, determine, cost, book, allowed, de, near, northland
dance: art, european, wtoo, special, samoans, higher, particularly, enemy, kantele, field




4501it [11:41,  4.71it/s]

| epoch   1 |  4500/32580 batches | loss    1.337
money: amount, wars, mobile, sea, case, end, so, more, candidate, attention
lion: maya, younger, practice, dare, buonarroti, lives, well, states, in, day
africa: death, went, extravagance, forgetful, top, due, america, greek, clinton, batsmen
musician: art, changed, determine, management, allowed, cost, de, book, near, school
dance: art, european, higher, special, wtoo, particularly, samoans, enemy, court, religion




5001it [12:59,  4.65it/s]

| epoch   1 |  5000/32580 batches | loss    1.305
money: wars, amount, mobile, sea, end, more, candidate, attention, case, function
lion: younger, maya, practice, dare, buonarroti, lives, day, in, states, island
africa: death, went, america, forgetful, coast, top, extravagance, clinton, greek, needed
musician: art, changed, determine, cost, management, allowed, de, school, book, oxford
dance: art, european, higher, samoans, wtoo, special, particularly, enemy, kantele, religion




5501it [14:17,  5.05it/s]

| epoch   1 |  5500/32580 batches | loss    1.276
money: wars, amount, mobile, candidate, function, attention, sea, so, case, jack
lion: younger, maya, practice, dare, lives, buonarroti, day, speech, practises, island
africa: america, death, went, coast, top, extravagance, clinton, forgetful, needed, greek
musician: art, cost, changed, management, determine, allowed, de, book, school, near
dance: art, european, samoans, higher, special, wtoo, enemy, particularly, foundation, kantele




6000it [15:34,  6.59it/s]

| epoch   1 |  6000/32580 batches | loss    1.250


6002it [15:35,  4.95it/s]

money: amount, wars, mobile, attention, function, candidate, case, jack, sea, difficult
lion: younger, maya, practice, dare, lives, buonarroti, day, speech, practises, states
africa: america, death, coast, went, extravagance, clinton, forgetful, needed, top, batsmen
musician: art, changed, cost, management, determine, allowed, de, book, oxford, school
dance: art, european, samoans, enemy, wtoo, higher, special, foundation, kantele, particularly




6501it [16:51,  4.86it/s]

| epoch   1 |  6500/32580 batches | loss    1.233
money: amount, wars, mobile, candidate, attention, function, jack, case, more, divided
lion: maya, younger, practice, dare, lives, buonarroti, day, states, speech, practises
africa: america, coast, death, went, extravagance, clinton, needed, forgetful, march, prepared
musician: art, cost, changed, determine, management, allowed, de, oxford, book, choose
dance: art, european, samoans, wtoo, higher, foundation, enemy, particularly, special, kantele




7000it [18:06,  6.60it/s]

| epoch   1 |  7000/32580 batches | loss    1.212


7002it [18:06,  5.09it/s]

money: amount, wars, mobile, attention, candidate, maintains, function, case, jack, divided
lion: younger, maya, practice, dare, lives, buonarroti, day, speech, states, dead
africa: america, coast, death, extravagance, went, clinton, plato, prepared, needed, forgetful
musician: art, cost, changed, management, determine, de, oxford, actress, allowed, northland
dance: art, european, higher, samoans, wtoo, foundation, enemy, particularly, renamed, mean




7335it [19:00,  6.43it/s]


KeyboardInterrupt: 