Some code and ideas from
https://github.com/lukysummer/SkipGram_with_NegativeSampling_Pytorch/tree/master

and https://github.com/reynoldsnlp/F23_LING581/blob/main/code/ch6_numpy.py

In [14]:
import os, pickle, json
from collections import Counter
from itertools import islice
from pprint import pprint as pp

import numpy as np

import torch
from torch import tensor, nn, optim
from torch.nn import functional as F
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

import nltk
from nltk.corpus import brown

# Fetch the Brown corpus used to build the vocabulary and training samples.
nltk.download('brown')

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Package brown is already up-to-date!


In [2]:
from google.colab import drive
# Mount Google Drive inside the Colab runtime.
drive.mount('/content/gdrive')
# Folder on the mounted drive under which the notebook reads and writes its
# files (vocab, contexts, negative samples, epoch losses).
DRIVE_PATH = '/content/gdrive/MyDrive/School/LING581/final'

Mounted at /content/gdrive


# Parameters

In [16]:
# Vocab parameters
# If any of this changes, will need to rebuild context and
# negative sample lists.
# Minimum frequency passed to build_vocab_from_iterator.
min_freq = 5
UNK = '<UNK>'
PAD = '<PAD>'
# Special tokens handed to the vocab builder; UNK is used as the default
# index and PAD as the padding value when batching.
specials = (UNK, PAD)

# Window parameters
# Full window width, target included.
window_size = 5
# Number of context tokens taken on each side of the target.
edge = (window_size - 1) // 2

# Negative sample parameters
# Negative samples drawn per context token.
num_negs_per_pos = 2
# Exponent applied to the unigram counts before normalising them into
# sampling probabilities.
alpha = 0.75

# Set to false to re-compute and save.
load_vocab = False
load_contexts = True
load_negs = True

# Paths
# All of these are joined onto DRIVE_PATH where they are used; the vocab file
# gets a '.pt' suffix and the epoch-losses file a '.json' suffix.
vocab_file_path = 'vocab'
contexts_file_path = 'contexts'
negs_file_path = 'negative-samples'
embedding_folder_path = 'embedding-models'
epoch_losses_path = 'train-embeddings_epoch-losses'

# Build vocab, contexts, and negative samples

In [4]:
### Build vocab
# Either load a previously saved vocab or build one from the Brown corpus
# sentences and save it for next time.
abs_vocab_file_path = os.path.join(DRIVE_PATH, vocab_file_path + '.pt')

if load_vocab:
  # NOTE(review): torch.load unpickles the file, so only load a vocab file
  # that this notebook wrote itself.
  vocab = torch.load(abs_vocab_file_path)
else:
  vocab = build_vocab_from_iterator(brown.sents(), min_freq=min_freq,
                                    specials=specials)
  # Tokens that are not in the vocab are looked up as UNK.
  vocab.set_default_index(vocab[UNK])

  torch.save(vocab, abs_vocab_file_path)

# Demonstrate vocab token2id and id2token maps
# (the bare tuple on the last line is what the notebook displays).
_ = vocab.lookup_indices([UNK, PAD, 'The'])
_, vocab.lookup_tokens(_)

([0, 1, 16], ['<UNK>', '<PAD>', 'The'])

In [5]:
### Map every Brown sentence from token strings to vocab ids.
# One inner list of ids per sentence, in corpus order.
id_sents = [vocab.lookup_indices(sentence) for sentence in brown.sents()]

In [6]:
### Prepare contexts.
# In each row (sublist), the first word is the target,
# and the remaining words are the contexts (up to `edge` tokens on each side
# of the target, taken from the same sentence).
abs_contexts_file_path = os.path.join(DRIVE_PATH, contexts_file_path)

if load_contexts:
  # NOTE: pickle can execute arbitrary code while loading; only load a file
  # that this notebook wrote itself.
  with open(abs_contexts_file_path, 'rb') as file:
    contexts = pickle.load(file)
else:
  contexts = list()

  for sent in id_sents:
    length = len(sent)
    for i, token in enumerate(sent):
      # Clamp the window at both ends of the sentence.
      left = sent[max(0, i - edge):i]
      right = sent[i + 1:min(i + 1 + edge, length)]
      contexts.append([token] + left + right)

  with open(abs_contexts_file_path, 'wb') as file:
    pickle.dump(contexts, file)

# Number of context tokens in each row (row length minus the target).
# Derived from `contexts` itself so that it is defined whether the rows were
# just rebuilt or loaded from disk; the negative-sampling step needs it.
num_contexts = [len(row) - 1 for row in contexts]

In [7]:
### Prepare negative samples.
# Either load previously saved negative samples or draw fresh ones from the
# alpha-weighted unigram distribution and save them.
abs_negs_file_path = os.path.join(DRIVE_PATH, negs_file_path)

if load_negs:
  # NOTE(review): pickle can execute arbitrary code while loading; only load
  # a file that this notebook wrote itself.
  with open (abs_negs_file_path, 'rb') as file:
    negative_samples = pickle.load(file)
else:
  # Flattened corpus of ids
  id_words = [ token for sent in id_sents for token in sent ]

  # Get frequencies of token IDs.
  freqs = Counter(id_words)

  # Get weighted unigram frequencies from which to get negative samples,
  # as in section 6.8.2 of the text.
  weighted = Counter({ token : freq**alpha for token, freq in freqs.items() })

  # Turn into probabilities
  total = weighted.total()
  weights = { token : freq / total for token, freq in weighted.items() }
  # Parallel lists: ids[j] is sampled with probability probabilities[j].
  ids, probabilities = list(weights.keys()), list(weights.values())

  # Calculate the number of negative samples to get
  # (using num_negs_per_pos per context).
  # `num_contexts` comes from the contexts cell: one count per contexts row.
  num_negs = sum(num_contexts) * num_negs_per_pos

  # Get flattened list of negative samples.
  flat_negative_samples = np.random.choice(ids, p=probabilities, size=num_negs)
  # A single shared iterator, so each row below consumes the next unused
  # slice of the flat sample.
  iter_negative_samples = iter(flat_negative_samples)

  # Structure negative samples like the contexts (each row contains the target
  # followed by the negative samples).
  negative_samples = [[context[0]] +
                      list(islice(iter_negative_samples, num * num_negs_per_pos))
                      for context, num in zip(contexts, num_contexts)]

  # Make sure no token has itself as a negative sample
  # Note: this takes about 2 minutes
  print('Total iterations:', len(num_contexts))
  for i, num in enumerate(num_contexts):
    # Redraw the whole row of negatives until the target is not among them.
    while negative_samples[i][0] in negative_samples[i][1:]:
      negative_samples[i][1:] = np.random.choice(ids, p=probabilities,
                                                size=num * num_negs_per_pos)
    # Progress indicator.
    if i % 100000 == 0:
      print(i)

  # This takes about 30 seconds.
  with open(abs_negs_file_path, 'wb') as file:
    pickle.dump(negative_samples, file)

# Build dataset

In [8]:
class Samples(Dataset):
  '''Dataset that returns a tuple of a contexts list and its corresponding
  negative samples list.

  Params:
    contexts (sequence): rows of token ids; item `idx` is returned first
    negative_samples (sequence): rows of token ids, parallel to `contexts`;
      item `idx` is returned second

  Raises:
    ValueError: if the two sequences do not have the same length (for example
      when one of them was loaded from a stale file)
  '''

  def __init__(self, contexts, negative_samples):
    # The two lists are indexed in lockstep, so a length mismatch would pair
    # rows wrongly or fail only once an out-of-range index is requested.
    if len(contexts) != len(negative_samples):
      raise ValueError(
          'contexts and negative_samples must have the same length, got '
          f'{len(contexts)} and {len(negative_samples)}')

    self.contexts = contexts
    self.negative_samples = negative_samples

  def __len__(self):
    '''Return the number of (contexts, negative samples) pairs.'''
    return len(self.contexts)

  def __getitem__(self, idx):
    '''Return the pair `(contexts[idx], negative_samples[idx])`.'''
    return self.contexts[idx], self.negative_samples[idx]

In [9]:
def collate_batch(batch):
  '''Turn a list of (contexts row, negative samples row) pairs into tensors.

  The first entry of every row is the target token; the rest of the row is
  padded with the PAD id to the longest row in the batch.

  Return:
    tokens (tensor): the target token of each pair
    padded_contexts (tensor): padded context ids, one row per target
    padded_negs (tensor): padded negative sample ids, one row per target
  '''

  pad_id = vocab[PAD]

  def pad_tails(rows):
    # Drop the leading target id from each row, then pad to a rectangle.
    tails = [tensor(row[1:], dtype=int) for row in rows]
    return pad_sequence(tails, batch_first=True,
                        padding_value=pad_id).to(device)

  context_rows = [pair[0] for pair in batch]
  neg_rows = [pair[1] for pair in batch]

  # The target is read from the contexts row (the negative samples row starts
  # with the same token).
  tokens = tensor([row[0] for row in context_rows], dtype=int).to(device)
  padded_contexts = pad_tails(context_rows)
  padded_negs = pad_tails(neg_rows)

  return tokens, padded_contexts, padded_negs

# Define model

Discussion about `clip_norm` and clipping to `embedding_dim**(1/2)`:

The point is to keep the weights from getting too large.

The expected value of the squared 2-norm of a standard multivariate normal random variable is the dimension of the random variable, so we assume the expected value of the 2-norm is approximately the square root of the dimension.

For arbitrary $p$-norms, the expected value is inversely related to p and still dependent on the dimension (so larger dimensions and smaller p both contribute to larger expected values). See my "experiment" notebook.

However, I don't know the exact relation, so this is just to keep the individual weights from getting too large, yet not clip them to be too small (for example, regardless of $p$, given a vector on a unit ball, the entries will be smaller the larger the dimension).

Additional work could normalize vectors to live on a ball of some radius, rather than just clip large vectors.

In [10]:
class CustomEmbedding(nn.Module):
  def __init__(self, vocab, similarity, embedding_dim, clip_norm):
    """

    Params:
      similarity (float or str):
        float: the type of p-norm to use when computing similarities
        str:
          'dot': dot product
          'cos': cosine similarity
      clip_norm (bool): if True, embeddings with norm larger than
        `MAX_NORM_MULTIPLE * max_norm` will be normalized to have this norm
    """
    super().__init__()

    MAX_NORM_MULTIPLE = 2

    self.vocab = vocab
    self.vocab_size = len(vocab)
    self.embedding_dim = embedding_dim
    self.similarity = similarity

    match self.similarity:
      case 'dot' | 'cos':
        ord = 2
      case float() | int():
        ord = self.similarity
      case _:
        raise ValueError('invalid similarity parameter')

    if clip_norm:
      sample_size = 10000
      sample = torch.randn((sample_size, embedding_dim))
      mean_norm = sample.norm(dim=1, p=ord).mean()
      max_norm = MAX_NORM_MULTIPLE * mean_norm
    else:
      max_norm = None

    self.target_embedding = nn.Embedding(self.vocab_size, self.embedding_dim,
                                         max_norm=max_norm, norm_type=ord)
    self.context_embedding = nn.Embedding(self.vocab_size, self.embedding_dim,
                                          max_norm=max_norm, norm_type=ord)

    self.sim = self.get_sim()
    self.score = self.get_score()

  def get_sim(self):
    # Output of sim will be (batch)x(number of contexts/negs)
    match self.similarity:
      case 'dot':
        # Want to maximize dot product for true contexts and
        # minimize for negative samples
        return lambda target_embs, context_embs: \
          torch.bmm(context_embs, target_embs.unsqueeze(-1)).squeeze(-1)
      case 'cos':
        # Want to maximize cosine similarity for true contexts and
        # minimize for negative samples
        return lambda target_embs, context_embs: \
          F.cosine_similarity(target_embs.unsqueeze(1), context_embs, dim=-1)
      case float() | int():
        # Want to maximize negative distance for true contexts and
        # minimize for negative samples
        return lambda target_embs, context_embs: \
          torch.linalg.vector_norm(target_embs.unsqueeze(1) - context_embs,
                                   dim=-1, ord=self.similarity).neg()
      case _:
        raise ValueError('invalid similarity parameter')

  def get_score(self):
    match self.similarity:
      case 'dot':
        return lambda target_embs, context_embs: \
          torch.mm(target_embs, context_embs.T)
      case 'cos':
        return self.sim
      case float() | int():
        return self.sim

  def forward(self, batch):
    targets, contexts, negs = batch

    target_embs = self.target_embedding(targets)
    context_embs = self.context_embedding(contexts)
    neg_embs = self.context_embedding(negs)

    context_scores = self.sim(target_embs, context_embs)
    neg_scores = self.sim(target_embs, neg_embs)

    return context_scores, neg_scores

# Helper functions for verifying embeddings are training

In [11]:
# Probe words whose nearest neighbours are printed before training and after
# every epoch, to check that the embeddings are actually changing.
valid_tokens = ['dog', 'milk', 'run', 'apple', 'hurt']
# Vocab ids of the probe words, in the same order.
valid_idxs = tensor(vocab.lookup_indices(valid_tokens))

def get_topk(k, model, valid_idxs):
  '''Find the k vocab tokens that score highest against each probe token.

  Params:
    k (int): number of tokens to return per probe token
    model: model exposing `target_embedding` and `score`
    valid_idxs (tensor): vocab ids of the probe tokens

  Return:
    list of lists of str: one list per entry of `valid_idxs`, holding the k
      best-scoring tokens, best first
  '''
  with torch.no_grad():
    valid_embs = model.target_embedding(valid_idxs.to(device))
    all_embs = model.target_embedding.weight

    # One row per vocab entry, one column per probe token.
    scores = model.score(all_embs, valid_embs)

    # Row indices (= vocab ids) of the k largest scores in each column.
    topk_ids = scores.topk(k, dim=0).indices

  # Walk the columns of the result itself rather than relying on the length
  # of a global token list, so any `valid_idxs` works.
  return [vocab.lookup_tokens(column.tolist()) for column in topk_ids.T]

def print_topk(k, model, valid_idxs):
  '''Print the k words that are closest to each of valid_tokens according to
  the model.

  Params:
    k (int): number of words to print per probe token
    model: model passed on to `get_topk`
    valid_idxs (tensor): vocab ids of the probe tokens; the printed labels
      come from the global `valid_tokens`, which is expected to be in the
      same order
  '''

  # Use the `k` argument; a global of a similar name must not decide how many
  # neighbours are shown.
  topk_lists = get_topk(k, model, valid_idxs)
  for token, topk_list in zip(valid_tokens, topk_lists):
    print(f'  {token:<7}: ', ' '.join(topk_list))

# Define training procedure

In [39]:
def train_skipgram(model,
                   criterion,
                   optimizer,
                   dataloader,
                   num_epochs=5,
                   print_every=1500,
                   topk=5,):
    '''Train a skip-gram model, printing the nearest neighbours of the probe
    words before training and after every epoch.

    Params:
      model: called with a `(targets, contexts, negs)` batch; returns
        `(context_scores, neg_scores)`
      criterion: loss called as `criterion(scores, labels)`, where true
        contexts are labelled 1 and negative samples 0
      optimizer: optimizer over the model's parameters
      dataloader: yields `(targets, contexts, negs)` batches
      num_epochs (int): number of passes over the dataloader
      print_every (int): print the batch loss every this many batches
      topk (int): number of neighbours printed per probe word

    Return:
      list of float: for each epoch, the average batch loss, where each
        batch is weighted by its size relative to the nominal batch size
    '''
    print_topk(topk, model, valid_idxs)

    # Nominal batch size, read from the dataloader instead of a notebook
    # global so the function does not depend on cell execution order. It may
    # be missing or None for some dataloaders; batches are then weighted
    # equally.
    nominal_batch_size = getattr(dataloader, 'batch_size', None)
    num_batches = len(dataloader)

    # List of average batch loss for each epoch
    avg_epoch_losses = list()

    for epoch in range(num_epochs):
      print(f'epoch: {epoch+1}/{num_epochs}')

      model.train()
      epoch_loss = 0.
      for batch_num, (targets, contexts, negs) in enumerate(dataloader):

        context_scores, neg_scores = model((targets, contexts, negs))

        # NOTE(review): every position of the score tensors gets a label, so
        # if the batches contain padding, the padded positions contribute to
        # the loss as well.
        scores = torch.cat((context_scores, neg_scores), dim=-1)
        labels = torch.cat((torch.ones_like(context_scores),
                            torch.zeros_like(neg_scores)), dim=-1)

        loss = criterion(scores, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Down-weight a short (final) batch in the epoch average.
        if nominal_batch_size:
          batch_weight = len(targets) / nominal_batch_size
        else:
          batch_weight = 1.
        epoch_loss += loss.item() * batch_weight

        if (batch_num % print_every) == 0:
          print(f' batch {batch_num:>6}/{num_batches} | loss {loss.item():<.3f}')

      avg_epoch_losses.append(epoch_loss / num_batches)

      print_topk(topk, model, valid_idxs)

    return avg_epoch_losses

# Train

In [42]:
# Hyperparameters shared by every model
embedding_dim = 64
learning_rate = 0.003
batch_size = 256

# Logging and training length
print_every = 1500
num_epochs = 5
topk = 6

criterion = nn.BCEWithLogitsLoss()

# Every similarity is trained twice: first without norm clipping, then with.
similarities = ['dot', 'cos', 1, 2, 3]

# String similarities keep their name; numeric ones become 'norm<p>'. The
# clipped runs reuse the same names with a '_clip-norm' suffix.
names = [sim if isinstance(sim, str) else f'norm{sim}' for sim in similarities]
names = names + [f'{name}_clip-norm' for name in names]

clip_norms = [flag for flag in (False, True) for _ in similarities]

# Repeat the similarities so they line up with names and clip_norms
similarities = similarities * 2

In [43]:
dataset = Samples(contexts, negative_samples)

# For each model type, list of average batch loss for each epoch
avg_epoch_losses = dict()

for name, similarity, clip_norm in zip(names, similarities, clip_norms):
  # Banner with the model name centred in a line of dashes.
  print(f"{' ' + name + ' ':-^57}")

  # Reset the RNG so every configuration starts from the same seed.
  torch.manual_seed(0)
  model = CustomEmbedding(vocab, similarity=similarity,
                          embedding_dim=embedding_dim, clip_norm=clip_norm) \
                          .to(device)

  optimizer = optim.Adam(model.parameters(), lr=learning_rate)

  dataloader = DataLoader(
      dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch,
  )

  # Forward print_every as well, so changing it in the parameters cell
  # actually takes effect.
  losses = train_skipgram(model,
                          criterion,
                          optimizer,
                          dataloader,
                          num_epochs=num_epochs,
                          print_every=print_every,
                          topk=topk,)

  # Uncomment to save each trained model's weights:
  # abs_model_path = os.path.join(DRIVE_PATH, embedding_folder_path, name + '.pt')
  # torch.save(model.state_dict(), abs_model_path)

  avg_epoch_losses[name] = losses

  print()

# Save the per-epoch losses of every model, keyed by model name.
with open(os.path.join(DRIVE_PATH, epoch_losses_path + '.json'), 'w') as file:
  json.dump(avg_epoch_losses, file)

-------------------------- dot --------------------------
  dog    :  dog Chapman straightened succumbed Women's as
  milk   :  milk flattered 1910 unstable sullen think
  run    :  run Lipton miserable managers ecstasy hazardous
  apple  :  apple to Friday recovery retired gown
  hurt   :  hurt 350 imagination machine welfare comparative
epoch: 1/5
 batch      0/4536 | loss 3.228
 batch   1500/4536 | loss 1.281
 batch   3000/4536 | loss 0.987
 batch   4500/4536 | loss 0.860
  dog    :  dog succumbed Without adolescents solidly Chapman
  milk   :  milk unstable flattered sullen Rice dozed
  run    :  run managers miserable grant Alabama ecstasy
  apple  :  apple recovery spacious Friday Interstate retired
  hurt   :  hurt welfare What's indicating indices torquer
epoch: 2/5
 batch      0/4536 | loss 0.718
 batch   1500/4536 | loss 0.685
 batch   3000/4536 | loss 0.659
 batch   4500/4536 | loss 0.625
  dog    :  dog Without solidly adolescents inspiring succumbed
  milk   :  milk dozed 