In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import nltk
from nltk.corpus import brown
import numpy as np
from collections import Counter
import random

# Set device (use GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


# Load Data and Build Vocabulary

**Load Data**

**Preprocess Data**
 - Tokenize
 - Lower case
 - Remove rare words
 - Subsample frequent words

**Build Vocabulary**

**Subsample Frequent Words**



In [3]:
nltk.download('brown')

[nltk_data] Downloading package brown to /home/yifan/nltk_data...
[nltk_data]   Package brown is already up-to-date!


True

In [13]:
# Get sentences from Brown corpus
sentences = brown.sents()
# Lower case
sentences = [[word.lower() for word in sentence] for sentence in sentences]

# Count the occurences of each word
word_counts = Counter(word for sentence in sentences for word in sentence)

# Remove rare words with appearances less than 
min_count = 5
rare_words = set([word for word,count in word_counts.items() if count < min_count])
word_counts_rare = sum(word_counts[word] for word in rare_words)
print(len(rare_words))

# Remove rare words from sentences
sentences = [[word if word not in rare_words else 'UNK' for word in sentence] 
             for sentence in sentences]

# Count the occurences of each word
word_counts = Counter(word for sentence in sentences for word in sentence)
word_counts['UNK'] = word_counts_rare

print(word_counts['UNK'])


35594
58040


In [14]:
# Count the frequency of each word
total_counts = sum(word_counts.values())
word_freqs = [count / total_counts for word,  count in word_counts.items()]
# Calculate the probability of discarding each word
subsample_threshold = 500
p_discard = {}
for word, freq in word_counts.items():
  p_discard[word] = max(0, 1 - np.sqrt(subsample_threshold / freq)) if freq > subsample_threshold else 0

# Subsample frequent words
subsampled_sentences = []
for sentence in sentences:
  subsampled_sentence = [word for word in sentence if random.random() < 1- p_discard[word]]
  subsampled_sentences.append(subsampled_sentence)

sentences = subsampled_sentences

# Build vocabulary after subsampling
vocab = sorted(word_counts.keys())
vocab_size = len(vocab)

# Build word to index and index to word mappings
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for idx, word in enumerate(vocab)}




#  Prepare data for training

**Generate skip-gram pairs**

**Negative sampling**

In [15]:
window_size = 5
positive_pairs = []
# Convert subsampled sentences to indices
sentences_idx = [[word2idx[word] for word in sentence] 
                 for sentence in subsampled_sentences]
for sentence in sentences_idx:
    for i in range(len(sentence)):
        target = sentence[i]
        # Context window: from i-window_size to i+window_size, excluding i
        for j in range(max(0, i - window_size), min(len(sentence), i + window_size + 1)):
            if j != i:
                context = sentence[j]
                positive_pairs.append((target, context))

class PositivePairsDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx]

positive_dataset = PositivePairsDataset(positive_pairs)
dataloader = DataLoader(positive_dataset, batch_size=512, shuffle=True)

In [16]:
for batch in dataloader:
  print(batch)
  break

[tensor([ 8338,  2470,    58,  9471,  6042,  6953,  5155, 14080,   277,   224,
         5732, 14177,  5883,    36,  4901,  8750,  1752, 10948,  7752,  9205,
        12828, 13776,  7084, 10110,  4608,  7358,  6026,  8614,   418, 13497,
         1439,  5214,  8866, 10580,  7196,   309,  4361, 12902,  4935,  3633,
         7986,  1159,  2690,   278,  3441,  8415,   655, 13339, 11994,  6056,
         1803,  9473, 12165,  1017, 13469,  3945,  9776,  7986,  3239, 13505,
        13000,  1409, 14078,  6450, 13464,  3518,  8750,   566,  8338,  6945,
         8911,  8338,  1533,  5203,   692,  8109,  2506,  6946,  7079,  7655,
         3112,  3072,  2822, 14058, 10892,    39, 11595, 13907,  6020,   829,
        14199,  7725, 10823,  6788,  8502,  6230,  5799, 14044,  4938, 13611,
        11921, 12806, 11501, 12193,  9503,  6315,  8134,  7222, 10297,  6684,
         7237,  6925, 11768,  4867,  9190, 13015,  2200,  8110,  8638,  7921,
          559, 12861,  5616,  2595,  6799,  8614,  7436,  5819,

In [17]:
p_w = np.array([count ** 0.75 for word, count in word_counts.items()])
p_w = p_w / p_w.sum()
neg_dist = torch.from_numpy(p_w).float().to(device)

print(neg_dist)

tensor([1.8449e-02, 3.5903e-05, 1.8838e-04,  ..., 3.5903e-05, 2.0399e-05,
        1.8455e-05], device='cuda:0')


# Network







In [18]:
class SkipGramNegativeSampling(nn.Module):
  def __init__(self, vocab_size, embedding_dim=100):
    super(SkipGramNegativeSampling, self).__init__()
    self.context_embedding = nn.Embedding(vocab_size, embedding_dim)
    self.center_embedding = nn.Embedding(vocab_size, embedding_dim)
  
  def forward(self, context_words, center_words):
    v_conext = self.context_embedding(context_words)
    v_center = self.center_embedding(center_words)
    return torch.sum(v_center * v_conext, dim=1)
    

In [19]:
# Model parameters
embedding_dim = 100
model = SkipGramNegativeSampling(vocab_size, embedding_dim).to(device)

# Loss and optimizer
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [20]:

k = 200  # Number of negative samples per positive pair
num_epochs = 30

In [21]:
for epoch in range(num_epochs):
    total_loss = 0
    for batch in dataloader:
        targets, contexts = batch  # Each is [batch_size]
        targets = targets.to(device)
        contexts = contexts.to(device)
        B = targets.size(0)

        # Generate negative samples
        negative_contexts = torch.multinomial(neg_dist, B * k, replacement=True).view(B, k).to(device)

        # Prepare full batch: positive + negative samples
        targets_pos = targets
        words_pos = contexts
        labels_pos = torch.ones(B, device=device)

        targets_neg = targets.unsqueeze(1).expand(-1, k).reshape(-1)  # [B * k]
        words_neg = negative_contexts.reshape(-1)                     # [B * k]
        labels_neg = torch.zeros(B * k, device=device)

        all_targets = torch.cat([targets_pos, targets_neg], dim=0)
        all_words = torch.cat([words_pos, words_neg], dim=0)
        all_labels = torch.cat([labels_pos, labels_neg], dim=0)

        # Forward pass
        dot_products = model(all_targets, all_words)

        # Compute loss
        loss = loss_fn(dot_products, all_labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch + 1}, Loss: {total_loss}")

Epoch 1, Loss: 4488.4410763941705
Epoch 2, Loss: 272.80193928442895
Epoch 3, Loss: 214.53017933107913
Epoch 4, Loss: 196.2890978511423
Epoch 5, Loss: 187.98494156077504
Epoch 6, Loss: 183.3136646002531
Epoch 7, Loss: 180.24367477931082
Epoch 8, Loss: 178.00153274461627
Epoch 9, Loss: 176.28119163960218
Epoch 10, Loss: 174.87268908694386
Epoch 11, Loss: 173.74177016690373
Epoch 12, Loss: 172.8091551847756
Epoch 13, Loss: 171.9816152434796
Epoch 14, Loss: 171.3014553207904
Epoch 15, Loss: 170.73111234791577
Epoch 16, Loss: 170.1897006649524
Epoch 17, Loss: 169.74380703270435
Epoch 18, Loss: 169.33613689802587
Epoch 19, Loss: 168.97610017843544
Epoch 20, Loss: 168.6399192046374
Epoch 21, Loss: 168.33868644945323
Epoch 22, Loss: 168.07660743407905
Epoch 23, Loss: 167.857492396608
Epoch 24, Loss: 167.61907729320228
Epoch 25, Loss: 167.41281303949654
Epoch 26, Loss: 167.23227480240166
Epoch 27, Loss: 167.08321898058057
Epoch 28, Loss: 166.90361057966948
Epoch 29, Loss: 166.77624802663922
Epo

In [22]:
# Test the embedding of man + king - woman = queen

In [33]:
# Function to find synonyms for a given word
def find_synonyms(word, word_to_idx, idx_to_word, embeddings, top_k=5):
    if word not in word_to_idx:
        print(f"Word '{word}' not found in vocabulary")
        return
    
    # Get the word embedding
    word_idx = word_to_idx[word]
    word_embedding = embeddings[word_idx]
    
    # Calculate cosine similarity with all words
    similarities = torch.nn.functional.cosine_similarity(
        word_embedding.unsqueeze(0), 
        embeddings,
        dim=1
    )
    
    # Get top k similar words (excluding the input word)
    top_similarities, top_indices = similarities.topk(top_k + 1)
    
    print(f"\nTop {top_k} synonyms for '{word}':")
    for i in range(1, len(top_indices)):  # Start from 1 to skip the word itself
        idx = top_indices[i].item()
        similarity = top_similarities[i].item()
        print(f"{idx_to_word[idx]}: {similarity:.3f}")

find_synonyms('morning', word2idx, idx2word, model.center_embedding.weight.detach())



Top 5 synonyms for 'morning':
afternoon: 0.855
evening: 0.841
o'clock: 0.836
day: 0.828
next: 0.827
