# Train CBOW with pytorch

In [78]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device = 'cpu'
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams
import seaborn as sns
import warnings

sns.set()
rcParams['figure.figsize'] = (20,10)
pd.options.display.max_columns = None
warnings.filterwarnings('ignore')

In [105]:
word_count = {}
with open('../data/all_simplified2.txt', 'r') as f:
    for line in f.readlines():
        for word in line.split():
            word_count[word] = word_count.get(word, 0) + 1
          
word_count = sorted(word_count.items(), key=lambda x: x[1], reverse=True)

In [106]:
word_count_df = pd.DataFrame(word_count, columns=['word', 'count'])

In [107]:
vocab = word_count_df[word_count_df['count'] >= 50]['word'].values
len(vocab)

5946

In [108]:
word2idx = {word: idx for idx, word in enumerate(vocab, 1)}
word2idx['<unk>'] = 0

In [109]:
len(word2idx)

5947

In [110]:
window = 4 # 5 words on each side
main_dataset = []

with open('../data/all_simplified2.txt', 'r') as f:
    for line in tqdm(f.readlines()):
        words = line.strip().split(' ')
        for idx, word in enumerate(words):
            if word not in word2idx:
                continue
            context_indices = [word2idx.get(words[t],0) for t in range(idx - window, idx + window + 1) if t >= 0 and t < len(words) and t != idx]
            context_indices = context_indices + [0] * (2 * window - len(context_indices))
            main_dataset.append((word2idx[word], context_indices))

  0%|          | 0/10 [00:00<?, ?it/s]

Exception ignored in: <function tqdm.__del__ at 0x117ff03a0>
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/tqdm/std.py", line 1161, in __del__
    def __del__(self):
KeyboardInterrupt: 


In [111]:
len(main_dataset)

4652283

In [101]:
main_dataset[0]

(3145, [22, 0, 0, 158, 0, 0, 0, 0])

In [91]:
from sklearn.model_selection import train_test_split

In [102]:
class CBOWDataset(Dataset):
    def __init__(self, dataset, cur_set='train'):
        self.cur_set = cur_set
        train_, test = train_test_split(dataset, test_size=0.9, random_state=42)
        train, val = train_test_split(train_, test_size=0.1666, random_state=42)
        self.lookup = {'train': train, 'val': val, 'test': test}

    def __len__(self):
        return len(self.lookup[self.cur_set])
        
    def __getitem__(self, idx):
        row = self.lookup[self.cur_set][idx]
        context = torch.tensor(row[1], dtype=torch.long).to(device)
        target = torch.tensor(row[0], dtype=torch.long).to(device)
        return context, target

    def get_dl(self, batch_size, shuffle=True, cur_set='train', drop_last=True):
        self.cur_set = cur_set
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

In [103]:
class CBOWModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.linear = nn.Linear(embedding_dim, vocab_size)
    def forward(self, x):
        x = self.embeddings(x)
        x = torch.mean(x, dim=1)
        x = F.dropout(x, 0.2)
        x = self.linear(x)
        return x

In [104]:
model = CBOWModel(len(word2idx), 100).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

ds = CBOWDataset(main_dataset)

In [95]:
def set_seed(seed):
    torch.manual_seed(seed)
    if device == 'cuda':
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)

def train_model(epochs=1, batch_size=128):
    set_seed(42)
    for epoch in range(epochs):
        train_losses, train_accs = [], []
        dl = ds.get_dl(batch_size)
        model.train()
        for context, target in tqdm(dl):
            context = context.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            out = model(context)
            loss = loss_fn(out, target)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            train_accs.append((out.argmax(dim=1) == target).float().mean().item())
        
        val_losses, val_accs = [], []
        dl = ds.get_dl(batch_size, cur_set='val')
        model.eval()
        with torch.no_grad():
            for context, target in tqdm(dl):
                context = context.to(device)
                target = target.to(device)
                out = model(context)
                loss = loss_fn(out, target)
                val_losses.append(loss.item())
                val_accs.append((out.argmax(dim=1) == target).float().mean().item())

        print(f'Epoch: {epoch + 1}, Train Loss: {np.mean(train_losses):.4f}, Train Acc: {np.mean(train_accs):.4f}, Val Loss: {np.mean(val_losses):.4f}, Val Acc: {np.mean(val_accs):.4f}')

  0%|          | 0/10005 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

Epoch: 1, Train Loss: 6.5373, Train Acc: 0.1137, Val Loss: 6.2692, Val Acc: 0.1178


  0%|          | 0/10005 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# evaluate on test set
test_losses, test_accs = [], []
dl = ds.get_dl(batch_size, cur_set='test')
model.eval()
with torch.no_grad():
    losses, accs = [], []
    for context, target in tqdm(dl):
        context = context.to(device)
        target = target.to(device)
        out = model(context)
        loss = loss_fn(out, target)
        test_losses.append(loss.item())
        test_accs.append((out.argmax(dim=1) == target).float().mean().item())

print(f'Test Loss: {np.mean(test_losses):.4f}, Test Acc: {np.mean(test_accs):.4f}')

In [None]:
# save model state dict
path = '../models/cbow_model.pth'
torch.save(model.state_dict(), path)

# save word2idx
path = '../models/cbow_word2idx.json'
import json
json.dump(word2idx, open(path, 'w'))

In [1]:
from annoy import AnnoyIndex

In [1]:
from pprint import pprint


class WordPlay:
    def __init__(self, word2idx, embeddings):
        self.word2idx = word2idx
        self.idx2word = {idx: word for word, idx in word2idx.items()}
        self.embeddings = embeddings
        self.annoy_index = AnnoyIndex(embeddings.shape[1], 'angular')
        for idx in range(embeddings.shape[0]):
            self.annoy_index.add_item(idx, embeddings[idx])
        self.annoy_index.build(10)

    def get_similar_words(self, word, n=10):
        idx = self.word2idx.get(word, 0)
        if idx == 0:
            return []
        similar_indices, similar_distances = self.annoy_index.get_nns_by_vector(idx, n, include_distances=True)
        return [(self.idx2word[idx], dist) for idx, dist in zip(similar_indices, similar_distances) if idx in self.idx2word]
    
    def get_analogy(self, word1, word2, word3, n=10):
        idx1 = self.word2idx.get(word1, 0)
        idx2 = self.word2idx.get(word2, 0)
        idx3 = self.word2idx.get(word3, 0)
        if idx1 == 0 or idx2 == 0 or idx3 == 0:
            return []
        vec = self.embeddings[idx2] - self.embeddings[idx1] + self.embeddings[idx3]
        similar_indices, similar_distances = self.annoy_index.get_nns_by_vector(vec, n, include_distances=True)
        return [(self.idx2word[idx], dist) for idx, dist in zip(similar_indices, similar_distances) if idx in self.idx2word]
    
    def get_random_similar_words(self, word, n=10):
        random_words = np.random.choice(list(self.word2idx.keys()), n)
        for word in random_words:
            pprint(self.get_similar_words(word))
            print('-'*30)
            print()

In [3]:
play = WordPlay(word2idx, model.embeddings.weight.detach().cpu().numpy())
play.get_similar_words('vua')

NameError: name 'word2idx' is not defined

In [None]:
model = CBOWModel(len(word2idx), 100)
model.load_state_dict(torch.load('../models/cbow_model.pth'))

word2idx = json.load(open('../models/cbow_word2idx.json'))

In [None]:
# chose random words in the vocabulary
from pprint import pprint


random_words = np.random.choice(list(word2idx.keys()), 10)
for word in random_words:
    pprint(play.get_similar_words(word))
    print('-'*30)
    print()
