In [11]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class SkipGram(nn.Module):
    def __init__(self, vocab_size, emb_dim):
        super(SkipGram, self).__init__()
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.embeddings = nn.Embedding(vocab_size, emb_dim)  # Combine center and context

    def forward(self, center, context):
        center_emb = self.embeddings(center)
        output = torch.matmul(center_emb, self.embeddings(context).transpose(0, 1))
        return output

    def get_similarity(self, idx):
        with torch.no_grad():
            center_emb = self.embeddings.weight[idx]
            similarities = torch.cosine_similarity(
                center_emb.unsqueeze(0), self.embeddings.weight, dim=1
            )
        return similarities

    def get_cosine_distance(self, idx):
        similarities = self.get_similarity(idx)
        return 1 - similarities


class CorpusData(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        center, context = self.data[idx]
        return center, context

    def __len__(self):
        return len(self.data)


class Corpus:
    def __init__(self, file_name, window_size=1):
        self.file_name = file_name
        self.window_size = window_size
        self.word2id = {}
        self.id2word = {}
        self.vocab_size = 0
        self.data = []
        self.__build_data()

    def __iter__(self):
        with open(self.file_name, "r") as file:
            yield from file

    def __update_map(self, text):
        for word in text:
            if word not in self.word2id:
                self.word2id[word] = self.vocab_size
                self.id2word[self.vocab_size] = word
                self.vocab_size += 1

    def __write_pairs(self, text):
        num_words = len(text)
        for i, word in enumerate(text):
            center = self.word2id[word]
            for context_word in (
                text[max(0, i - self.window_size) : i]
                + text[i + 1 : min(num_words, i + self.window_size + 1)]
            ):
                context_id = self.word2id[context_word]
                self.data.append((center, context_id))

    def __build_data(self):
        for line in self:
            text = line.strip().split()
            text = [word.strip() for word in text if len(word) > 2]
            self.__update_map(text)
            self.__write_pairs(text)
        self.data = list(set(self.data))  # Remove duplicates

    def get_pairs(self):
        return CorpusData(self.data)


def train(model, dataloader, num_epochs, criterion, optimizer, device):
    model.train()
    log = []
    for epoch in range(num_epochs):
        total_loss = 0
        pbar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}")
        for center, context in pbar:
            output = model(center.to(device), context.to(device))
            loss = criterion(output, torch.zeros_like(output).to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * center.size(0)
            pbar.set_postfix(loss=f"{total_loss / len(dataloader.dataset):.2f}")
        log.append(total_loss / len(dataloader.dataset))
    model.eval()
    return model, log


In [12]:
DATA = "data/"
TRAIN_NORM = f"{DATA}/train_norm.txt"
TEST_NORM = f"{DATA}/test_norm.txt"

In [13]:
corpus = Corpus(TRAIN_NORM, 3)
corpus_train = CorpusData(corpus.data)
corpus_loader = DataLoader(corpus_train, batch_size=4096, shuffle=True)
N = corpus.vocab_size
H = 25
model = SkipGram(N, H)
device = "mps"
model = model.to(device)

In [14]:
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
criterion = nn.CrossEntropyLoss()

In [15]:
model, log = train(model, corpus_loader, 10,criterion, optimizer, device)

Epoch 1: 100%|██████████| 1251/1251 [00:37<00:00, 33.04it/s, loss=0.00]
Epoch 2:   7%|▋         | 84/1251 [00:02<00:29, 40.02it/s, loss=0.00]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x10354d3c0>>
Traceback (most recent call last):
  File "/Users/lucien/.pyenv/versions/3.10.14/envs/torch/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
Epoch 2:  39%|███▊      | 484/1251 [00:15<00:28, 27.38it/s, loss=0.00]