In [57]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
# Pick the compute device: CUDA when available, otherwise Apple MPS when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams
import seaborn as sns
import warnings

# Notebook-wide plotting and display settings.
sns.set()  # apply seaborn's default theme to matplotlib plots
rcParams['figure.figsize'] = (20,10)  # default size for every new figure
pd.options.display.max_columns = None  # never truncate columns when a DataFrame is displayed
# NOTE(review): this silences every warning for the whole session, including
# ones that may point at real problems (deprecations, numerical issues).
warnings.filterwarnings('ignore')

import re

from pprint import pprint

from annoy import AnnoyIndex

In [58]:
def preprocess(text):
    """Normalise a line of text into lowercase, single-space-separated tokens.

    Steps, in order: lowercase, compose to Unicode NFC, drop every character
    that is neither whitespace nor a word character, collapse whitespace runs
    into one space, and trim both ends.

    Args:
        text: Raw input string.

    Returns:
        The cleaned string; empty when the input held only punctuation
        and/or whitespace.
    """
    import unicodedata  # local import keeps this cell runnable on its own

    # Compose after lowercasing: in decomposed input (base letter followed by
    # combining tone marks) the marks are not word characters, so the filter
    # below would strip them and silently turn e.g. "việt" into "viet".
    text = unicodedata.normalize('NFC', text.lower())
    text = re.sub(r'[^\s\wáàảãạăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổỗộơớờởỡợíìỉĩịúùủũụưứừửữựýỳỷỹỵđ_]','',text)
    # Trim last: removing leading/trailing punctuation can expose whitespace
    # that an up-front strip() would have missed.
    return re.sub(r'\s+', ' ', text).strip()

In [59]:
# Tally how often each preprocessed token occurs in the corpus file.
# The tokens 'punct' and 'number' are excluded from the tally.
word_count = {}
with open('./data/all_simplified2.txt', 'r', encoding='utf-8') as f:
    for line in f:
        for tok in preprocess(line).split():
            if tok not in ('punct', 'number'):
                if tok in word_count:
                    word_count[tok] += 1
                else:
                    word_count[tok] = 1

In [60]:
# Tabulate the tallies as (word, count) rows, most frequent word first.
word_count_df = (
    pd.DataFrame(word_count.items(), columns=['word', 'count'])
    .sort_values('count', ascending=False)
)

In [61]:
# The vocabulary is every word seen at least 15 times, in descending-count order.
is_frequent = word_count_df['count'] >= 15
vocab = word_count_df.loc[is_frequent, 'word'].values
len(vocab)

12199

In [62]:
# Map each vocabulary word to its position in `vocab`.
word2idx = dict(zip(vocab, range(len(vocab))))

In [63]:
# Build the skip-gram training pairs: for every in-vocabulary word, emit one
# (center_idx, context_idx) tuple per in-vocabulary neighbour that lies at most
# `window` positions to the left or right on the same line.
dataset = []
window = 2
with open('./data/all_simplified2.txt', 'r', encoding='utf-8') as f:
    for line in f:
        words = preprocess(line).split()
        # Translate tokens once per line; out-of-vocabulary tokens become None.
        ids = [word2idx.get(tok) for tok in words]
        for i, center in enumerate(ids):
            if center is None:
                continue
            # Clamp the context window to the bounds of the line.
            lo = max(0, i - window)
            hi = min(len(ids), i + window + 1)
            for j in range(lo, hi):
                if j != i and ids[j] is not None:
                    dataset.append((center, ids[j]))

In [64]:
from sklearn.model_selection import train_test_split

In [65]:
class SkipGramDataset(Dataset):
    """Train/valid/test splits over a sequence of (x, y) index pairs.

    The pairs are split once at construction with a fixed ``random_state``:
    10% is held out as 'test', then 16.66% of the remainder as 'valid'; the
    rest is 'train'. ``cur_set`` selects which split ``len()`` and indexing
    read from.
    """

    def __init__(self, dataset, cur_set='train'):
        """
        Args:
            dataset: Sequence of (x, y) pairs to split.
            cur_set: Name of the split this object serves: 'train', 'valid'
                or 'test'.
        """
        self.cur_set = cur_set
        train_, test = train_test_split(dataset, test_size=0.1, random_state=42)
        train, valid = train_test_split(train_, test_size=0.1666, random_state=42)
        # The split results are stored directly; copying them again would only
        # duplicate the (potentially very large) pair lists in memory.
        self.lookup = {
            'train': train,
            'valid': valid,
            'test': test
        }

    def __len__(self):
        """Return the number of pairs in the currently selected split."""
        return len(self.lookup[self.cur_set])

    def __getitem__(self, idx):
        """Return pair ``idx`` of the current split as two tensors on ``device``."""
        x, y = self.lookup[self.cur_set][idx]
        return torch.tensor(x, device=device), torch.tensor(y, device=device)

    def get_dl(self, cur_set='train', batch_size=32, shuffle=True, drop_last=True):
        """Return a DataLoader that is permanently bound to split ``cur_set``.

        Args:
            cur_set: 'train', 'valid' or 'test'.
            batch_size, shuffle, drop_last: Forwarded to ``DataLoader``.

        Raises:
            KeyError: If ``cur_set`` is not a known split name.
        """
        import copy  # local import keeps this cell runnable on its own

        if cur_set not in self.lookup:
            raise KeyError(f'unknown split {cur_set!r}; expected one of {sorted(self.lookup)}')
        # Kept for backward compatibility: len()/indexing on this object still
        # follow the most recently requested split.
        self.cur_set = cur_set
        # The loader gets its own shallow view (the split data is shared, not
        # copied). Binding the loader to `self` would make every loader follow
        # later get_dl() calls, so two loaders alive at once would silently
        # read the same split.
        view = copy.copy(self)
        view.cur_set = cur_set
        return DataLoader(view, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

In [66]:
class SkipGramModel(nn.Module):
    """Embedding lookup followed by a linear layer that emits one score per vocabulary entry."""

    def __init__(self, vocab_size, embedding_size):
        """
        Args:
            vocab_size: Number of rows in the embedding table and number of
                output scores.
            embedding_size: Width of each embedding vector.
        """
        super().__init__()
        self.vocab_size, self.embedding_size = vocab_size, embedding_size
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        self.linear = nn.Linear(in_features=embedding_size, out_features=vocab_size)

    def forward(self, x):
        """Look up the embeddings of the indices in ``x`` and project them to vocabulary scores."""
        return self.linear(self.embedding(x))

In [67]:
# Wrap the raw pair list exactly once. Without the guard, re-running this cell
# would pass the already-built SkipGramDataset back into SkipGramDataset and
# try to split the wrapper instead of the pairs.
if not isinstance(dataset, SkipGramDataset):
    dataset = SkipGramDataset(dataset)
# One embedding row / output score per vocabulary word; 100 is the embedding size.
model = SkipGramModel(len(word2idx), 100).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [68]:
def _run_epoch(model, loader, loss_fn, optimizer=None):
    """Run one pass over ``loader``; update the weights only when ``optimizer`` is given.

    Returns:
        (mean_loss, mean_acc): per-batch loss and accuracy averaged with each
        batch weighted by its size, or (nan, nan) when the loader yields no
        batches.
    """
    training = optimizer is not None
    model.train(training)  # train mode when optimising, eval mode otherwise
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for x, y in tqdm(loader):
            y_hat = model(x)
            loss = loss_fn(y_hat, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            n = y.shape[0]
            # Weight by batch size so a short final batch does not count as
            # much as a full one.
            loss_sum += loss.item() * n
            n_correct += (y_hat.argmax(1) == y).sum().item()
            n_seen += n
    if n_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / n_seen, n_correct / n_seen


def train(model, dataset, loss_fn, optimizer, epochs=5, batch_size=32):
    """Train ``model`` and report train/validation metrics after every epoch.

    Args:
        model: Module mapping a batch ``x`` to per-class scores.
        dataset: Object exposing ``get_dl(split, batch_size, shuffle, drop_last)``.
        loss_fn: Callable ``loss_fn(y_hat, y)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of passes over the 'train' split.
        batch_size: Batch size for both the training and validation loaders.

    Returns:
        A list with one dict per epoch holding 'epoch', 'train_loss',
        'train_acc', 'val_loss' and 'val_acc'.
    """
    history = []
    for epoch in range(epochs):
        train_loader = dataset.get_dl('train', batch_size=batch_size, shuffle=True, drop_last=True)
        train_loss, train_acc = _run_epoch(model, train_loader, loss_fn, optimizer)

        # Validation visits every sample in a fixed order: shuffling or
        # dropping the tail batch would make the reported metric depend on
        # which samples happened to be left out.
        valid_loader = dataset.get_dl('valid', batch_size=batch_size, shuffle=False, drop_last=False)
        val_loss, val_acc = _run_epoch(model, valid_loader, loss_fn)

        print(f'Epoch {epoch+1}/{epochs}: Train loss: {train_loss:.4f}, Train acc: {train_acc:.4f}, Val loss: {val_loss:.4f}, Val acc: {val_acc:.4f}')
        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
        })
    return history

def test(model, dataset, loss_fn):
    losses, accs = [], []
    model.eval()
    with torch.no_grad():
        for x, y in tqdm(dataset.get_dl('test', batch_size=32, shuffle=True, drop_last=True)):
            y_hat = model(x)
            loss = loss_fn(y_hat, y)
            losses.append(loss.item())
            accs.append((y_hat.argmax(1) == y).float().mean().item())
    print(f'Test loss: {np.mean(losses):.4f}, Test acc: {np.mean(accs):.4f}')

In [69]:
# Train for 3 epochs using the model, data, loss and optimizer built in the cells above.
# NOTE(review): the output saved with this cell ends in a KeyboardInterrupt at 0%
# progress, so the cells below do not reflect a completed training run.
train(model, dataset, loss_fn, optimizer, epochs=3)

  0%|          | 0/314380 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Print loss and accuracy of the model on the 'test' split.
test(model, dataset, loss_fn)

In [None]:
# save model
import os

# torch.save does not create missing parent directories; make sure the target
# folder exists so the weights are not lost to a failed save.
os.makedirs('./model', exist_ok=True)
torch.save(model.state_dict(), './model/skipgram.pth')

In [None]:
from pprint import pprint
class WordPlay:
    """Nearest-neighbour and analogy queries over word embeddings, backed by an Annoy index."""

    def __init__(self, word2idx, embeddings, n_trees=10):
        """
        Args:
            word2idx: Mapping from word to its row index in ``embeddings``.
            embeddings: 2-D array-like with one embedding vector per row.
            n_trees: Number of trees passed to ``AnnoyIndex.build``.
        """
        self.word2idx = word2idx
        self.idx2word = {idx: word for word, idx in word2idx.items()}
        self.embeddings = embeddings
        self.annoy_index = AnnoyIndex(embeddings.shape[1], 'angular')
        for idx in range(embeddings.shape[0]):
            self.annoy_index.add_item(idx, embeddings[idx])
        self.annoy_index.build(n_trees)

    def _to_words(self, indices, distances):
        """Pair each known index with its word and distance, skipping unknown indices."""
        return [(self.idx2word[idx], dist) for idx, dist in zip(indices, distances) if idx in self.idx2word]

    def get_similar_words(self, word, n=10):
        """Return up to ``n`` (word, distance) neighbours of ``word``.

        Returns an empty list when ``word`` is not in the vocabulary.
        """
        # None is the "missing" sentinel: 0 is a valid row index, so using it
        # as the sentinel would make the word stored at index 0 unqueryable.
        idx = self.word2idx.get(word)
        if idx is None:
            return []
        similar_indices, similar_distances = self.annoy_index.get_nns_by_item(idx, n, include_distances=True)
        return self._to_words(similar_indices, similar_distances)

    def get_analogy(self, word1, word2, word3, n=10):
        """Return up to ``n`` (word, distance) neighbours of ``emb(word2) - emb(word1) + emb(word3)``.

        Returns an empty list when any of the three words is not in the vocabulary.
        """
        indices = [self.word2idx.get(word) for word in (word1, word2, word3)]
        # Same sentinel rule as get_similar_words: index 0 is a real word.
        if any(idx is None for idx in indices):
            return []
        idx1, idx2, idx3 = indices
        vec = self.embeddings[idx2] - self.embeddings[idx1] + self.embeddings[idx3]
        similar_indices, similar_distances = self.annoy_index.get_nns_by_vector(vec, n, include_distances=True)
        return self._to_words(similar_indices, similar_distances)

    def get_random_similar_words(self, n=10):
        """Pretty-print the neighbours of ``n`` vocabulary words drawn at random (with replacement)."""
        random_words = np.random.choice(list(self.word2idx.keys()), n)
        for word in random_words:
            pprint(self.get_similar_words(word))
            print('-'*30)
            print()

In [None]:
# Index the model's embedding matrix, detached from autograd and copied to CPU as a NumPy array.
word_play = WordPlay(word2idx, model.embedding.weight.detach().cpu().numpy())

In [None]:
# Print the nearest neighbours of 10 randomly chosen vocabulary words.
word_play.get_random_similar_words(n=10)