In [None]:
import os
import zipfile
import urllib.request
import torch
import torch.nn as nn

# Download the dataset (text8: a single whitespace-separated, lowercase text file in a zip)
url = 'http://mattmahoney.net/dc/text8.zip'
filename = 'text8.zip'

if not os.path.exists(filename):
  print("Downloading text8...")
  # Download under a temporary name and rename only once the transfer has
  # finished. Otherwise an interrupted download leaves a truncated
  # `text8.zip` behind, the existence check above passes on the next run,
  # and ZipFile fails on the corrupt archive.
  partial_filename = filename + '.part'
  urllib.request.urlretrieve(url, partial_filename)
  os.replace(partial_filename, filename)

# Extract the dataset: read the first archive member fully into memory as text
with zipfile.ZipFile(filename) as f:
  text = f.read(f.namelist()[0]).decode('utf-8')

print(f"First 300 characters:\n{text[:300]}")

First 300 characters:
 anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english revolution and the sans culottes of the french revolution whilst the term is still used in a pejorative way to describe any act that used violent means to destroy the organiz


In [None]:
from collections import Counter
import nltk
from nltk.corpus import stopwords
import random

# Fetch the NLTK stop-word corpus that `stopwords.words('english')` reads below.
nltk.download('stopwords')

def subsample_tokens(tokens, threshold=1e-5, rng=None):
  """
  Randomly drop frequent tokens (word2vec-style subsampling).

  Each occurrence of a word w is kept with probability
  p_keep(w) = sqrt(threshold / f(w)), clamped to [0, 1], where f(w) is the
  relative frequency of w in `tokens`.

  tokens: re-iterable sequence of hashable tokens (it is traversed twice)
  threshold: frequency threshold; values <= 0 drop every token
  rng: optional object with a `.random()` method (e.g. `random.Random(seed)`)
       for reproducible runs; defaults to the global `random` module

  Returns a new list with the surviving tokens in their original order.
  """
  word_counts = Counter(tokens)
  total_counts = sum(word_counts.values())
  draw = random.random if rng is None else rng.random

  # The keep probability depends only on the word, so compute it once per
  # distinct word instead of once per token occurrence.
  keep_prob = {}
  for word, count in word_counts.items():
    normalized_freq = count / total_counts
    # Clamp the ratio before the square root: a negative threshold would
    # otherwise produce a complex number and break the comparison below.
    ratio = max(0.0, threshold / normalized_freq)
    keep_prob[word] = min(1.0, ratio ** 0.5)

  # Exactly one random draw per token, in order.
  return [token for token in tokens if draw() < keep_prob[token]]

# English stop words (as a set for O(1) membership tests)
stop_words = set(stopwords.words('english'))

# Building vocab: whitespace tokenisation of the raw corpus
tokens = text.split()
print(f"Total tokens: {len(tokens)}")

# Filter out stop words
filtered_tokens = list(
    filter(lambda token: token.lower().strip() not in stop_words, tokens)
)
print(f"Total filtered tokens: {len(filtered_tokens)}")

# Randomly thin out very frequent words
subsampled_tokens = subsample_tokens(filtered_tokens)
print(f"Total subsampled tokens: {len(subsampled_tokens)}")

# Occurrence count of every surviving word
word_freq = Counter(subsampled_tokens)
print(f"Unique words: {len(word_freq)}")

# word -> index, numbered in first-seen order, and the reverse lookup
vocab = {word: idx for idx, word in enumerate(word_freq)}

inv_vocab = dict(zip(vocab.values(), vocab.keys()))

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Total tokens: 17005207
Total filtered tokens: 10890638
Total subsampled tokens: 4130359
Unique words: 253702


In [None]:
# Filter rare words: keep only the `vocab_size` most frequent words
vocab_size = 10000

most_common = word_freq.most_common(vocab_size)
# Index 0 is the most frequent word, index 1 the next, and so on
vocab = {word:idx for idx, (word, _) in enumerate(most_common)}
inv_vocab = {idx:word for word, idx in vocab.items()}
print(f"Unique words: {len(vocab)}")

# Build the training stream from the *subsampled* tokens. Filtering the raw
# `tokens` list here would silently discard the subsampling step and make
# the corpus disagree with `word_freq` (and with the Huffman tree built from
# it), which were both computed on the subsampled stream.
tokens = [token for token in subsampled_tokens if token in vocab]
print(f"Filtered tokens: {len(tokens)}")

# filtered word freq: counts restricted to the retained vocabulary
filtered_word_freq = dict(most_common)
print(f"Filtered word freq: {len(filtered_word_freq)}")

Unique words: 10000
Filtered tokens: 9170125
Filtered word freq: 10000


In [None]:
def generate_skipgram_pairs(tokens:list, window_size=8, max_len=1_000_000, word_to_idx=None):
  """
  Build (center_idx, context_idx) skip-gram training pairs.

  For every in-vocabulary token, each in-vocabulary token at most
  `window_size` positions to its left or right becomes a context word.
  Out-of-vocabulary tokens are skipped both as centers and as contexts, but
  still occupy a position in the window.

  tokens: list of word strings
  window_size: number of positions considered on each side of the center
  max_len: stop as soon as this many pairs exist; `None` means no limit and
           values <= 0 yield an empty list
  word_to_idx: mapping word -> index; defaults to the notebook-level `vocab`

  Returns a list of (center_idx, context_idx) tuples in corpus order.
  """
  if word_to_idx is None:
    word_to_idx = vocab  # fall back to the notebook-level vocabulary

  # Without this guard the cap would only be checked after the first append,
  # so a cap of 0 would still return one pair.
  if max_len is not None and max_len <= 0:
    return []

  pairs = []
  num_tokens = len(tokens)

  for i, center_word in enumerate(tokens):
    if center_word not in word_to_idx:
      continue
    center_idx = word_to_idx[center_word]

    start_idx = max(0, i - window_size)
    end_idx = min(num_tokens, i + window_size + 1)

    for j in range(start_idx, end_idx):
      if j == i:
        continue  # a word is not its own context

      context_word = tokens[j]
      if context_word not in word_to_idx:
        continue

      pairs.append((center_idx, word_to_idx[context_word]))
      if max_len is not None and len(pairs) >= max_len:
        return pairs

  return pairs

In [None]:
import heapq
from collections import defaultdict

class HuffmanNode:
  """
  One node of a Huffman tree.

  Leaves carry a `word`; internal nodes leave `word` as None and instead
  carry two children plus an `idx` identifying the node.
  """

  def __init__(self, freq, word=None, left=None, right=None, idx=None):
    # Weight of the subtree rooted here
    self.freq = freq
    # Only set on leaves; None marks an internal node
    self.word = word
    # Children (None on leaves)
    self.left, self.right = left, right
    # Unique index, assigned to internal nodes
    self.idx = idx

  def __lt__(self, other):
    """Order nodes by frequency only, so `heapq` pops the lightest first."""
    return self.freq < other.freq

def build_huffman_tree(word_freq):
  """
  Build a Huffman tree from a word -> frequency mapping.

  Returns (root, word_to_leaf), where `word_to_leaf` maps each word to its
  leaf node. Internal nodes are numbered consecutively starting at
  `len(word_freq)`, in the order they are created.
  """
  # One leaf per word, created in the mapping's iteration order
  word_to_leaf = {word: HuffmanNode(freq, word) for word, freq in word_freq.items()}

  # Min-heap on frequency, filled leaf by leaf in that same order
  heap = []
  for leaf in word_to_leaf.values():
    heapq.heappush(heap, leaf)

  # internal node indices start after word indices
  internal_idx = len(word_freq)

  # Repeatedly merge the two lightest subtrees until a single root remains
  while len(heap) > 1:
    lightest = heapq.heappop(heap)
    second_lightest = heapq.heappop(heap)
    merged = HuffmanNode(
        lightest.freq + second_lightest.freq,
        left=lightest,
        right=second_lightest,
        idx=internal_idx,
    )
    heapq.heappush(heap, merged)
    internal_idx += 1

  return heap[0], word_to_leaf

In [None]:
def extract_codes_and_paths(root, vocab):
  """
  Walk a Huffman tree and collect, for every leaf word:

  - its binary code (0 = left branch, 1 = right branch), and
  - its path: the `idx` of every internal node from the root down to the
    leaf's parent.

  Both results are keyed by the word's index in `vocab`. Also returns the
  length of the longest path.
  """
  word_idx_to_code = {}
  word_idx_to_path = {}
  max_path_len = 0

  # Depth-first traversal with an explicit stack of (node, code, path).
  pending = [(root, [], [])]
  while pending:
    node, bits, ancestors = pending.pop()

    if node.word is not None:
      # Leaf: record what was accumulated on the way down
      word_idx = vocab[node.word]
      word_idx_to_code[word_idx] = bits
      word_idx_to_path[word_idx] = ancestors
      if len(ancestors) > max_path_len:
        max_path_len = len(ancestors)
      continue

    # Push right (1) first so that left (0) is popped and visited first.
    # Each child gets its own fresh lists.
    pending.append((node.right, bits + [1], ancestors + [node.idx]))
    pending.append((node.left, bits + [0], ancestors + [node.idx]))

  return word_idx_to_code, word_idx_to_path, max_path_len

In [None]:
def create_huffman_codes(word_freq, vocab):
  """
  Build the Huffman tree for `word_freq` and return
  (codes, paths, max_path_len) keyed by the word indices in `vocab`.
  """
  tree_root = build_huffman_tree(word_freq)[0]  # the word -> leaf map is not needed here
  word_codes, word_paths, longest_path = extract_codes_and_paths(tree_root, vocab)
  return word_codes, word_paths, longest_path

In [None]:
import torch.nn.functional as F
def hierarchical_softmax_batched(hidden, target_word_idxs, codes, paths, node_embeddings, path_mask, vocab_size):
  """
  Mean hierarchical-softmax loss for a batch.

  hidden: (B, H)  - batch of context vectors
  target_word_idxs: (B,) - unused here; kept for interface compatibility
  codes: (B, L) - 0 or 1 branch labels (padded)
  paths: (B, L) - internal node indices (original indices >= vocab_size, padded with 0)
  node_embeddings: embedding module mapping remapped node indices to (H,) vectors;
                   row 0 is reserved for padding
  path_mask: (B, L) - 1 for valid path steps, 0 for padding
  vocab_size: int - size of the vocabulary (first internal node index)

  Returns a scalar tensor: the per-sample loss averaged over valid path
  steps, then averaged over the batch.
  """
  # Remap original node indices to embedding rows:
  #   padding (0)               -> row 0
  #   internal node i >= vocab  -> row i - vocab_size + 1
  # The "+ 1" matters: without it the first internal node (index
  # == vocab_size) would share row 0 with the padding entry and, when that
  # row is a padding row, would never receive a gradient.
  internal_node_mask = paths >= vocab_size
  mapped_paths = torch.zeros_like(paths, dtype=torch.long)
  mapped_paths[internal_node_mask] = paths[internal_node_mask] - vocab_size + 1

  # Look up one vector per path step: (B, L) -> (B, L, H)
  internal_node_vecs = node_embeddings(mapped_paths)

  # Dot product of each sample's hidden vector with every node on its path.
  # hidden is broadcast from (B, 1, H) across the L path steps -> (B, L)
  logits = torch.sum(hidden.unsqueeze(1) * internal_node_vecs, dim=2)

  # Binary cross-entropy computed directly on the logits: same quantity as
  # sigmoid followed by BCE, but numerically stable for large |logits|.
  losses = F.binary_cross_entropy_with_logits(logits, codes.float(), reduction='none')

  # Zero out the padded steps
  valid = path_mask.float()
  losses = losses * valid  # (B, L)

  # Normalize per path length (clamp avoids division by zero for empty paths)
  per_sample = losses.sum(dim=1) / valid.sum(dim=1).clamp(min=1e-8)

  return per_sample.mean()  # scalar

In [None]:
import math
def pad_codes_and_paths(pairs, max_path_len, huffman_codes, huffman_paths, pad_idx=0, batch_size=32):
  """
  Lazily yield training batches with Huffman codes/paths padded to a fixed length.

  pairs: sequence of (center_idx, context_idx) tuples
  max_path_len: length every code/path row is padded to
  huffman_codes: dict[word_idx] = list of 0/1
  huffman_paths: dict[word_idx] = list of node indices
  pad_idx: filler value for path rows (code rows are always padded with 0)
  batch_size: number of pairs per batch; the final batch may be smaller

  Each yielded item is a 5-tuple of tensors:
  (center indices, context indices, codes [long], paths [long], mask [float]),
  where codes/paths/mask describe the *context* word of each pair.
  """
  batch_paths, batch_codes, batch_masks = [], [], []
  last_pos = len(pairs) - 1
  batch_start = 0  # position in `pairs` where the batch being filled begins

  for pos, (_, target) in enumerate(pairs):
    node_path = huffman_paths[target]
    bits = huffman_codes[target]

    # Right-pad up to the common length; the mask marks the real entries
    missing = max_path_len - len(node_path)
    batch_paths.append(node_path + [pad_idx] * missing)
    batch_codes.append(bits + [0] * missing)
    batch_masks.append([1] * len(node_path) + [0] * missing)

    # Keep accumulating until the batch is full or the data runs out
    batch_is_full = (pos + 1) % batch_size == 0
    if not batch_is_full and pos != last_pos:
      continue

    centers, contexts = zip(*pairs[batch_start:pos + 1])

    yield (torch.tensor(centers),
           torch.tensor(contexts),
           torch.tensor(batch_codes, dtype=torch.long),
           torch.tensor(batch_paths, dtype=torch.long),
           torch.tensor(batch_masks, dtype=torch.float)
           )

    # Start a fresh batch right after the pair just emitted
    batch_start = pos + 1
    batch_paths, batch_codes, batch_masks = [], [], []



In [None]:
class Word2VecHS(nn.Module):
  """
  Skip-gram word2vec model trained with a hierarchical-softmax loss.

  Holds one embedding table for words and one for Huffman-tree nodes; the
  forward pass returns the batch loss rather than predictions.
  """

  def __init__(self, vocab_size, embedding_dim, paths):
    """
    vocab_size: number of words (rows of the word embedding table)
    embedding_dim: width of both embedding tables
    paths: dict[word_idx] = list of node indices along that word's Huffman path
    """
    super().__init__()
    self.vocab_size = vocab_size
    self.embedding = nn.Embedding(vocab_size, embedding_dim)

    # Distinct node indices appearing on any path, plus 0 which is reserved
    # for padding. Only the count is needed to size the table.
    node_ids = {0}
    for node_path in paths.values():
      node_ids.update(node_path)

    # One vector per node; row 0 is the padding row
    self.node_embeddings = nn.Embedding(len(node_ids), embedding_dim, padding_idx=0)
    self.dropout = nn.Dropout(0.6)

  def forward(self, codes, paths, center_words_indices, context_words_indices, path_mask):
    """
    Embed the center words, apply dropout, and return the
    hierarchical-softmax loss against the context words' codes and paths.
    """
    hidden = self.dropout(self.embedding(center_words_indices))

    return hierarchical_softmax_batched(
        hidden=hidden,
        target_word_idxs=context_words_indices,
        codes=codes,
        paths=paths,
        node_embeddings=self.node_embeddings,  # the embedding module itself
        path_mask=path_mask,
        vocab_size=self.vocab_size,
    )

In [None]:
# Huffman codes/paths for the retained vocabulary, keyed by word index
codes, paths, max_path_len = create_huffman_codes(filtered_word_freq, vocab=vocab)
# 300-dimensional embeddings for both the word and the tree-node tables
model = Word2VecHS(len(vocab), 300, paths)
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Word2VecHS(
  (embedding): Embedding(10000, 300)
  (node_embeddings): Embedding(10000, 300, padding_idx=0)
  (dropout): Dropout(p=0.6, inplace=False)
)

In [None]:
# Generate at most 1,000,000 (center, context) index pairs from the token stream
pairs = generate_skipgram_pairs(tokens, max_len=1_000_000)
# pad_codes_and_paths is a generator function, so nothing is computed here;
# the training loop below creates a fresh generator at the start of every epoch.
padded_batches = pad_codes_and_paths(pairs, max_path_len, codes, paths, batch_size=32)

In [None]:
# Sanity check: number of (center, context) training pairs generated above
len(pairs)

1000000

In [None]:
import torch.optim as optim
# Adam over every model parameter (both embedding tables), learning rate 1e-3
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Plateau scheduler in 'min' mode with factor=0.5 and patience=1; the training
# loop feeds it the mean epoch loss via scheduler.step(...)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=1, factor=0.5, threshold=0.001)

In [None]:
epochs = 40
model.train()
for epoch in range(epochs):
  # Per-epoch accumulators for the mean loss
  running_loss = 0.0
  batch_count = 0

  # Regenerate batches each epoch (a generator can only be consumed once)
  padded_batches = pad_codes_and_paths(pairs, max_path_len, codes, paths, batch_size=32)

  for batch in padded_batches:
    # Move all five tensors of the batch to the training device
    center_batch, context_batch, codes_batch, paths_batch, path_mask_batch = (
        tensor.to(device) for tensor in batch
    )

    # Standard step: clear gradients, forward (returns the loss), backward, update
    optimizer.zero_grad()
    loss = model(
        codes=codes_batch,
        paths=paths_batch,
        center_words_indices=center_batch,
        context_words_indices=context_batch,
        path_mask=path_mask_batch,
    )
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    batch_count += 1

  # Let the scheduler react to this epoch's mean loss
  scheduler.step(running_loss/batch_count)

  print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/batch_count}, LR: {scheduler.get_last_lr()[0]}")

Epoch 1/40, Loss: 11.909344181529045, LR: 0.001
Epoch 2/40, Loss: 5.808248297450065, LR: 0.001
Epoch 3/40, Loss: 3.807633185948372, LR: 0.001
Epoch 4/40, Loss: 2.772666641082764, LR: 0.001
Epoch 5/40, Loss: 2.152166331498146, LR: 0.001
Epoch 6/40, Loss: 1.7515308139390946, LR: 0.001
Epoch 7/40, Loss: 1.4858631509661675, LR: 0.001
Epoch 8/40, Loss: 1.2952395974769593, LR: 0.001
Epoch 9/40, Loss: 1.1529252299599648, LR: 0.001
Epoch 10/40, Loss: 1.0489013797821998, LR: 0.001
Epoch 11/40, Loss: 0.9673258307271003, LR: 0.001
Epoch 12/40, Loss: 0.9046352724394798, LR: 0.001
Epoch 13/40, Loss: 0.8542623966655731, LR: 0.001
Epoch 14/40, Loss: 0.8157607867102623, LR: 0.001
Epoch 15/40, Loss: 0.7810280789175034, LR: 0.001
Epoch 16/40, Loss: 0.7547019110040665, LR: 0.001
Epoch 17/40, Loss: 0.7325589862709045, LR: 0.001
Epoch 18/40, Loss: 0.714932128390789, LR: 0.001
Epoch 19/40, Loss: 0.6979182883553505, LR: 0.001
Epoch 20/40, Loss: 0.6850562409195899, LR: 0.001
Epoch 21/40, Loss: 0.6737167777647

In [None]:
# use the trained word embeddings to find the words most similar to a given word
def predict_next_topk_words(word_idx, model=model, topk=5):
  """
  Return the `topk` vocabulary words whose input embeddings have the highest
  cosine similarity to the embedding of `word_idx`.

  Despite the name, this is a nearest-neighbour lookup in embedding space,
  not a next-word prediction. The query word has similarity 1 with itself,
  so it normally comes first in the result.

  word_idx: vocabulary index of the query word
  model: model exposing an `embedding` table (defaults to the notebook model)
  topk: number of neighbours to return; capped at the number of embedding rows

  Side effect: switches `model` to eval mode.
  Returns a list of words (looked up through the notebook-level `inv_vocab`).
  """
  model.eval()
  with torch.no_grad():
    embedding_matrix = model.embedding.weight  # (V, D)

    # Build the query on the same device as the embedding table instead of
    # relying on a notebook-level `device` variable.
    query_idx = torch.tensor([word_idx], device=embedding_matrix.device)
    last_embedding = model.embedding(query_idx)  # shape: (1, D)

    # Normalize embeddings to unit vectors
    normalized_embeddings = F.normalize(embedding_matrix, dim=1)
    normalized_last = F.normalize(last_embedding, dim=1)

    # Compute cosine similarity against every word: (V,)
    cos_similarities = torch.matmul(normalized_last, normalized_embeddings.T).squeeze(0)

    # torch.topk raises if k exceeds the number of candidates, so cap it
    k = min(topk, cos_similarities.size(0))
    best = torch.topk(cos_similarities, k=k)

  # Convert indices back to words
  return [inv_vocab[i] for i in best.indices.tolist()]



In [None]:
# Example query: look up the vocabulary index of a word
# (the dict lookup raises KeyError if the word is not in `vocab`)
last_word = "queen"
last_word_idx = vocab[last_word]

# Words ranked by cosine similarity of their embeddings to the query word
predictions = predict_next_topk_words(last_word_idx)
print(predictions)

['queen', 'elizabeth', 'governor', 'represented', 'parliament']
