### [Source](https://towardsdatascience.com/implementing-word2vec-in-pytorch-from-the-ground-up-c7fe5bf99889)

In [63]:
import os
import random
import re
from collections import Counter, OrderedDict
from dataclasses import dataclass
from time import monotonic
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from scipy.spatial.distance import cosine
from torch.utils.data import DataLoader
from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import WikiText103
from tqdm import tqdm


In [237]:
def aromanian_iterator(file_path, chunk_size=1024):
    """
    Lazily read a UTF-8 text file in chunks that never split a word.

    About ``chunk_size`` characters are read per step. Any trailing partial
    word is held back and prepended to the next chunk, so every chunk except
    possibly the last ends on a whitespace character. Concatenating all chunks
    reproduces the text exactly as read in text mode.

    :param file_path: Path to the text file.
    :param chunk_size: Number of characters to read per ``read`` call.
    :return: Yields chunks of text from the file.
    """
    carry = ''
    with open(file_path, 'r', encoding='utf-8') as file:
        while True:
            chunk = file.read(chunk_size)
            if not chunk:
                break
            # Cut just after the last whitespace so no token straddles two chunks.
            cut = len(chunk)
            while cut > 0 and not chunk[cut - 1].isspace():
                cut -= 1
            if cut == 0:
                # No whitespace in this piece: keep accumulating the current word.
                carry += chunk
                continue
            yield carry + chunk[:cut]
            carry = chunk[cut:]
    if carry:
        yield carry

In [None]:
# Preview: print the test split chunk by chunk.
file_path = 'C:/Users/gheto/Desktop/PoS/AromanianPoS/dataset/Tales.test.ro'
chunks = aromanian_iterator(file_path)
for chunk in chunks:
    print(chunk)

In [239]:
def get_data(train_dir='C:/Users/gheto/Desktop/PoS/AromanianPoS/dataset/Tales.train.ro', valid_dir='C:/Users/gheto/Desktop/PoS/AromanianPoS/dataset/Tales.test.ro'):
    """Load the training and validation files as map-style datasets of text chunks.

    :param train_dir: Path of the training text file.
    :param valid_dir: Path of the validation text file.
    :return: ``(train_dataset, valid_dataset)``.
    """
    train_dataset = to_map_style_dataset(aromanian_iterator(train_dir))
    valid_dataset = to_map_style_dataset(aromanian_iterator(valid_dir))
    return train_dataset, valid_dataset

In [309]:
@dataclass
class Word2VecParams:
    """Hyper-parameters for the skip-gram pipeline.

    NOTE: none of the attributes below carry type annotations, so ``@dataclass``
    generates no fields for them. They are plain class attributes shared by every
    instance (including the single ``CRITERION`` module), and ``Word2VecParams()``
    takes no arguments.
    """

    # skipgram parameters
    MIN_FREQ = 1  # must stay 1 here, otherwise the context would contain tokens that are not in the centre vocabulary
    SKIPGRAM_N_WORDS = 20  # tokens taken on each side of the centre word (window is 2 * this + 1)
    T = 85  # percentile of relative word frequencies used as the subsampling threshold
    NEG_SAMPLES = 1  # negative samples drawn per positive (centre, context) pair
    NS_ARRAY_LEN = 5_000_000  # size parameter for the negative-sampling lookup table
    SPECIALS = "<unk>"  # token that unknown words resolve to
    TOKENIZER = 'basic_english'  # name passed to torchtext's get_tokenizer

    # network parameters
    BATCH_SIZE = 100  # dataset items (text chunks) per DataLoader batch
    EMBED_DIM = 300  # embedding dimension of both tables
    EMBED_MAX_NORM = None  # forwarded to nn.Embedding(max_norm=...); None disables renormalisation
    N_EPOCHS = 100
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    CRITERION = nn.BCEWithLogitsLoss()  # applied to raw dot-product logits vs. 1/0 targets

In [310]:
class Vocab:
    """Word <-> index vocabulary that also records each word's frequency.

    A word's index is its position in the ``(word, frequency)`` list given to
    the constructor. Words missing from the vocabulary resolve to the special
    (unknown) token.

    :param list: Ordered ``(word, frequency)`` pairs. The name shadows the
        builtin but is kept so keyword callers keep working.
    :param specials: Sequence whose first element is the special token used
        for unknown words, e.g. ``("<unk>", nan)``. That token should also
        appear in ``list``.
    """

    def __init__(self, list, specials):
        # Materialise once: the input is read twice below and may be a generator.
        entries = [(pair[0], pair[1]) for pair in list]
        self.stoi = {word: (idx, freq) for idx, (word, freq) in enumerate(entries)}
        self.itos = {idx: (word, freq) for idx, (word, freq) in enumerate(entries)}
        self._specials = specials[0]
        # nansum ignores a NaN frequency (as used for the special token).
        self.total_tokens = np.nansum(
            [freq for _, freq in self.stoi.values()]
            , dtype=int)

    def __len__(self):
        # One less than the number of stored entries; Model adds the 1 back
        # when sizing its embedding tables.
        return len(self.stoi) - 1

    def _entry(self, word):
        """Return the ``(index, frequency)`` entry of *word*, or the special token's."""
        entry = self.stoi.get(word)
        if entry is None:
            entry = self.stoi[self._specials]
        return entry

    def _lookup(self, word, field: int):
        """Return field ``field`` of the entry for a word or for each word of a list."""
        if isinstance(word, str):
            return self._entry(word)[field]
        if isinstance(word, list):
            return [self._entry(w)[field] for w in word]
        raise ValueError(
            f"Word {word} is not a string or a list of strings."
            )

    def get_index(self, word: Union[str, List]):
        """Return the index of *word* (or a list of indices); unknown words map to the special token.

        :raises ValueError: if *word* is neither a string nor a list.
        """
        return self._lookup(word, 0)

    def get_freq(self, word: Union[str, List]):
        """Return the frequency of *word* (or a list of frequencies); unknown words get the special token's.

        :raises ValueError: if *word* is neither a string nor a list.
        """
        return self._lookup(word, 1)

    def lookup_token(self, token: Union[int, List]):
        """Return the word at index *token* (Python or NumPy integer), or the words for a list of indices.

        :raises ValueError: if an index is not in the vocabulary or *token* has an unsupported type.
        """
        if isinstance(token, (int, np.integer)):
            if token not in self.itos:
                raise ValueError(f"Token {token} not in vocabulary")
            return self.itos[token][0]
        if isinstance(token, list):
            res = []
            for t in token:
                if t not in self.itos:
                    raise ValueError(f"Token {t} is not a valid index.")
                res.append(self.itos[t][0])
            return res
        raise ValueError(
            f"Token {token} is not an integer or a list of integers."
            )

In [311]:
def yield_tokens(iterator, tokenizer):
    """Yield, for each text of *iterator*, its tokens that start with a letter or digit.

    The check is Unicode-aware. The previous ASCII class ``[a-z1-9]`` silently
    dropped every token starting with a diacritic (``ă``, ``â``, ``î``, ``ș``,
    ``ț``, so even "și" and "în") or with ``0``.

    :param iterator: Iterable of raw text strings.
    :param tokenizer: Callable mapping a string to a list of tokens.
    """
    # [^\W_]: any Unicode word character except underscore, i.e. a letter or digit.
    starts_alnum = re.compile(r'[^\W_]')
    for text in iterator:
        yield [token for token in tokenizer(text) if starts_alnum.match(token)]

def vocab(ordered_dict: Dict, min_freq: int = 1, specials: str = '<unk>'):
    """Build a Vocab whose index 0 is the special token, followed by the frequent words.

    The special token is prepended, so it no longer overwrites the most
    frequent corpus word as the previous ``tokens[0] = ...`` did. An empty
    *ordered_dict* also works (a vocabulary holding only the special token).

    :param ordered_dict: Mapping word -> frequency, in the desired index order.
    :param min_freq: Words with a lower frequency are dropped.
    :param specials: Token used for unknown words; it gets a NaN frequency.
    """
    special_entry = (specials, np.nan)
    # Reserve index 0 for the special token.
    tokens = [special_entry]
    tokens.extend(
        (token, freq)
        for token, freq in ordered_dict.items()
        # A corpus token equal to the special token would create a duplicate key.
        if freq >= min_freq and token != specials
    )
    return Vocab(tokens, special_entry)

def pipeline(word, vocab, tokenizer):
    """Tokenize *word* (a text) and map its tokens to vocabulary indices.

    This file's ``Vocab`` is not callable, so its ``get_index`` is used when
    present. Otherwise *vocab* is called directly (e.g. a torchtext vocab).
    """
    tokens = tokenizer(word)
    get_index = getattr(vocab, "get_index", None)
    if get_index is not None:
        return get_index(tokens)
    return vocab(tokens)

def build_vocab(
        iterator,
        tokenizer,
        params: Word2VecParams,
        max_tokens: Optional[int] = None,
    ):
    """Count tokens over *iterator* and build a frequency-ranked vocabulary.

    :param iterator: Iterable of raw text strings.
    :param tokenizer: Callable mapping a string to a list of tokens.
    :param params: Supplies ``MIN_FREQ`` and ``SPECIALS``.
    :param max_tokens: If given, only the ``max_tokens`` most frequent corpus
        words are passed on to ``vocab`` (previously this argument was ignored).
    :raises ValueError: if *max_tokens* is negative.
    """
    if max_tokens is not None and max_tokens < 0:
        raise ValueError(f"max_tokens must be non-negative, got {max_tokens}")

    counter = Counter()
    for tokens in yield_tokens(iterator, tokenizer):
        counter.update(tokens)

    # First sort by descending frequency, then lexicographically (deterministic ties).
    sorted_by_freq_tuples = sorted(
        counter.items(), key=lambda x: (-x[1], x[0])
        )
    if max_tokens is not None:
        sorted_by_freq_tuples = sorted_by_freq_tuples[:max_tokens]

    ordered_dict = OrderedDict(sorted_by_freq_tuples)

    return vocab(
        ordered_dict, min_freq=params.MIN_FREQ, specials=params.SPECIALS
        )


In [312]:
class SkipGrams:
    """Builds (centre, context) index pairs for skip-gram training with subsampling.

    ``collate_skipgram`` is used as a DataLoader ``collate_fn``. It tokenizes each
    raw text of a batch and slides a window of ``2 * SKIPGRAM_N_WORDS + 1``
    tokens over it. It emits one pair per kept (centre, neighbour) combination,
    randomly discarding frequent words.
    """
    def __init__(self, vocab: Vocab, vocab_context: Vocab, params: Word2VecParams, tokenizer):
        self.vocab = vocab
        self.vocab_context = vocab_context
        self.params = params
        self.t = self._t()
        self.tokenizer = tokenizer
        self.discard_probs = self._create_discard_dict()

    def _t(self):
        """Return the subsampling threshold: the ``params.T``-th percentile of
        relative word frequencies, skipping the first stoi entry (the special token)."""
        freq_list = []
        for _, (_, freq) in list(self.vocab.stoi.items())[1:]:
            freq_list.append(freq/self.vocab.total_tokens)
        return np.percentile(freq_list, self.params.T)


    def _create_discard_dict(self):
        """Map each vocabulary index to its discard probability
        ``1 - sqrt(t / (f + t))``, where ``f`` is the word's relative frequency."""
        discard_dict = {}
        # stoi maps word -> (index, freq), so ``word`` below is really the integer
        # index; the dict is therefore keyed by index, which is what
        # collate_skipgram looks up. The special token's NaN frequency yields NaN.
        for _, (word, freq) in self.vocab.stoi.items():
            dicard_prob = 1-np.sqrt(
                self.t / (freq/self.vocab.total_tokens + self.t))
            discard_dict[word] = dicard_prob
        return discard_dict


    def collate_skipgram(self, batch):
        """Turn a batch of raw texts into parallel tensors of centre and context indices.

        :param batch: Iterable of raw text strings (one DataLoader batch).
        :return: ``(batch_input, batch_output)``, two 1-D ``torch.long`` tensors of equal length.
        """
        batch_input, batch_output  = [], []
        for text in batch:
            # List with each word's index (its frequency rank when built by build_vocab).
            text_tokens = self.vocab.get_index(self.tokenizer(text))
            text_tokens_context = []
            # NOTE(review): each centre index is looked up by position in the
            # *context* vocabulary and immediately converted back to an index, so
            # this list normally equals text_tokens (lookup_token raises ValueError
            # for an index the context vocabulary lacks). This relies on both
            # vocabularies sharing index positions — confirm that is intended.
            for text in text_tokens:
                context_word = self.vocab_context.lookup_token(text)
                text_tokens_context.append(self.vocab_context.get_index(context_word))
    
            

            # Texts shorter than one full window produce no pairs.
            if len(text_tokens_context) < self.params.SKIPGRAM_N_WORDS * 2 + 1:
                continue

            # One window per centre position that has SKIPGRAM_N_WORDS neighbours on each side.
            for idx in range(len(text_tokens_context) - self.params.SKIPGRAM_N_WORDS*2
                ):
                # NOTE(review): the window is sliced from text_tokens (centre
                # indices), not from text_tokens_context.
                token_id_sequence = text_tokens[
                    idx : (idx + self.params.SKIPGRAM_N_WORDS * 2 + 1)
                    ]
                
                # The centre word is removed here. The author would rather keep it
                # (to add the Romanian word to the context), so, to avoid breaking
                # the code, it is appended back afterwards.
                input_ = token_id_sequence.pop(self.params.SKIPGRAM_N_WORDS)
                outputs = token_id_sequence

                # Added back here: the centre word is also one of its own targets.
                outputs.append(input_)


                # Subsampling: drop the whole window if the centre is the special
                # token (index 0) or is randomly discarded.
                prb = random.random()
                del_pair = self.discard_probs.get(input_)
                if input_==0 or del_pair >= prb:
                    continue
                else:
                    for output in outputs:
                        # Each neighbour is subsampled independently as well.
                        prb = random.random()
                        del_pair = self.discard_probs.get(output)
                        if output==0 or del_pair >= prb:
                            continue
                        else:
                            batch_input.append(input_)
                            batch_output.append(output)

        batch_input = torch.tensor(batch_input, dtype=torch.long)
        batch_output = torch.tensor(batch_output, dtype=torch.long)

        return batch_input, batch_output

In [313]:
class NegativeSampler:
    """Draws negative-sample vocabulary indices from a smoothed unigram table.

    Each vocabulary index is included, except the first stoi entry (the special
    token). It appears in ``ns_array`` about
    ``ns_array_len * f**ns_exponent / Z`` times, and at least once. Here ``f`` is
    its frequency and ``Z`` is the sum of ``f**ns_exponent`` over the vocabulary.
    Uniform draws from the table therefore follow the smoothed distribution.

    :param vocab: Vocabulary whose ``stoi`` maps word -> ``(index, frequency)``.
    :param ns_exponent: Exponent applied to frequencies (Trainer uses 0.75).
    :param ns_array_len: Target length of the lookup table.
    """

    def __init__(self, vocab: "Vocab", ns_exponent: float, ns_array_len: int):
        self.vocab = vocab
        self.ns_exponent = ns_exponent
        self.ns_array_len = ns_array_len
        self.ns_array = self._create_negative_sampling()

    def __len__(self):
        return len(self.ns_array)

    def _create_negative_sampling(self):
        """Build the lookup table of vocabulary indices."""
        # stoi values are (index, freq); the first entry is the special token (NaN frequency).
        weights = {
            index: freq ** self.ns_exponent
            for index, freq in list(self.vocab.stoi.values())[1:]
            }
        total_weight = sum(weights.values())
        # Normalise by the smoothed total (not the raw token count), so the table
        # really holds about ns_array_len entries and keeps resolution for rare words.
        scale = self.ns_array_len / total_weight if total_weight else 0.0
        ns_array = []
        for index, weight in weights.items():
            # Extend in place; list concatenation would copy the table once per word.
            ns_array.extend([index] * max(1, int(weight * scale)))
        return ns_array

    def sample(self, n_batches: int = 1, n_samples: int = 1):
        """Return an ``(n_batches, n_samples)`` tensor of sampled indices.

        Within a row, indices are drawn without replacement (``random.sample``).
        """
        samples = []
        for _ in range(n_batches):
            samples.append(random.sample(self.ns_array, n_samples))
        samples = torch.as_tensor(np.array(samples))
        return samples


In [314]:
class Model(nn.Module):
    """Skip-gram scorer with separate target and context embedding tables.

    ``t_embeddings`` is sized from ``vocab`` and ``c_embeddings`` from
    ``vocab_context``. ``forward`` returns one logit (dot product) per
    (input, context) pair.
    """

    def __init__(self, vocab: Vocab, vocab_context:Vocab, params: Word2VecParams):
        super().__init__()
        self.vocab = vocab
        self.vocab_context = vocab_context
        # Each table has one row more than the vocabulary's reported length.
        target_rows = self.vocab.__len__() + 1
        context_rows = self.vocab_context.__len__() + 1
        # Target table is registered (and initialised) before the context table.
        self.t_embeddings = nn.Embedding(
            target_rows, params.EMBED_DIM, max_norm=params.EMBED_MAX_NORM
            )
        self.c_embeddings = nn.Embedding(
            context_rows, params.EMBED_DIM, max_norm=params.EMBED_MAX_NORM
            )

    def forward(self, inputs, context):
        """Score each input index against its row of context indices.

        The ``view``/``bmm`` calls require ``inputs`` of shape ``(batch,)`` and
        ``context`` of shape ``(batch, k)``; the result is ``(batch, k)``.
        """
        target = self.t_embeddings(inputs)
        batch, dim = target.shape[0], target.shape[1]
        target = target.view(batch, 1, dim)                       # (batch, 1, dim)
        candidates = self.c_embeddings(context).transpose(1, 2)   # (batch, dim, k)
        scores = torch.bmm(target, candidates)                    # (batch, 1, k)
        return scores.view(scores.shape[0], scores.shape[2])

    def normalize_embeddings(self):
        """Return the target embedding matrix as a NumPy array with unit-length rows."""
        weight = self.t_embeddings.weight.cpu().detach().numpy()
        row_norms = (weight ** 2).sum(axis=1, keepdims=True) ** (1 / 2)
        return weight / row_norms

    def get_similar_words(self, word, n):
        """Return ``{word: cosine similarity}`` for the ``n`` nearest target-table rows.

        Prints a notice and returns None when *word* maps to index 0.
        The top-ranked row is skipped on the assumption that it is *word* itself.
        """
        word_id = self.vocab.get_index(word)
        if word_id == 0:
            print("Out of vocabulary word")
            return

        unit = self.normalize_embeddings()
        query = unit[word_id]
        query = query.reshape(query.shape[0], 1)
        scores = (unit @ query).flatten()
        nearest = np.argsort(-scores)[1 : n + 1]
        # NOTE(review): row ids come from the target (centre) table but are named
        # through vocab_context. This relies on both vocabularies sharing index
        # positions — confirm intended.
        return {self.vocab_context.lookup_token(i): scores[i] for i in nearest}

    def get_similarity(self, word1, word2):
        """Return scipy's cosine *distance* (1 - cosine similarity) between two words.

        Prints a notice and returns None when either word maps to index 0.
        NOTE(review): ``word2`` is indexed via vocab_context but its row is read
        from the target table, under the same shared-index assumption.
        """
        idx1 = self.vocab.get_index(word1)
        idx2 = self.vocab_context.get_index(word2)
        if idx1 == 0 or idx2 == 0:
            print("One or both words are out of vocabulary")
            return

        unit = self.normalize_embeddings()
        return cosine(unit[idx1], unit[idx2])

In [315]:
class Trainer:
    """Runs skip-gram training and validation epochs with negative sampling.

    ``vocab`` is only used to build the NegativeSampler table. Nearest
    neighbours of ``self.testwords`` are printed before training and after
    every epoch.
    """
    def __init__(self, model: Model, params: Word2VecParams, optimizer,
                vocab: Vocab, train_iter, valid_iter, skipgrams: SkipGrams):
        self.model = model
        self.optimizer = optimizer
        self.vocab = vocab
        self.train_iter = train_iter
        self.valid_iter = valid_iter
        self.skipgrams = skipgrams
        self.params = params

        self.epoch_train_mins = {}
        self.loss = {"train": [], "valid": []}

        # sending all to device
        self.model.to(self.params.DEVICE)
        self.params.CRITERION.to(self.params.DEVICE)

        self.negative_sampler = NegativeSampler(
            vocab=self.vocab, ns_exponent=.75,
            ns_array_len=self.params.NS_ARRAY_LEN
            )
        self.testwords = ['pe']

    def _make_dataloader(self, dataset):
        """Wrap *dataset* in an unshuffled DataLoader that collates skip-gram pairs."""
        return DataLoader(
            dataset,
            batch_size=self.params.BATCH_SIZE,
            shuffle=False,
            collate_fn=self.skipgrams.collate_skipgram
        )

    def train(self):
        """Train for ``params.N_EPOCHS`` epochs, validating and printing neighbours after each."""
        self.test_testwords()
        # Unshuffled loaders over fixed datasets can be re-iterated every epoch;
        # collate_skipgram still re-draws its random subsampling on each pass.
        self.train_dataloader = self._make_dataloader(self.train_iter)
        self.valid_dataloader = self._make_dataloader(self.valid_iter)
        for epoch in range(self.params.N_EPOCHS):
            st_time = monotonic()
            self._train_epoch()
            self.epoch_train_mins[epoch] = round((monotonic()-st_time)/60, 1)

            self._validate_epoch()
            print(f"""Epoch: {epoch+1}/{self.params.N_EPOCHS}\n""",
            f"""    Train Loss: {self.loss['train'][-1]:.2}\n""",
            f"""    Valid Loss: {self.loss['valid'][-1]:.2}\n""",
            f"""    Training Time (mins): {self.epoch_train_mins.get(epoch)}"""
            """\n"""
            )
            self.test_testwords()

    def _prepare_batch(self, batch_data):
        """Return ``(inputs, context, targets)`` on the device, or None for an empty batch.

        ``context`` holds the positive label in column 0 followed by
        ``NEG_SAMPLES`` negatives. ``targets`` has 1 for the positive column
        and 0 for the negatives.
        """
        if len(batch_data[0]) == 0:
            return None
        device = self.params.DEVICE
        inputs = batch_data[0].to(device)
        pos_labels = batch_data[1].to(device)
        neg_labels = self.negative_sampler.sample(
            pos_labels.shape[0], self.params.NEG_SAMPLES
            ).to(device)
        context = torch.cat(
            [pos_labels.view(pos_labels.shape[0], 1), neg_labels], dim=1
            )
        y_pos = torch.ones((pos_labels.shape[0], 1))
        y_neg = torch.zeros((neg_labels.shape[0], neg_labels.shape[1]))
        y = torch.cat([y_pos, y_neg], dim=1).to(device)
        return inputs, context, y

    @staticmethod
    def _mean(values):
        """Mean of *values*, or NaN (without a NumPy warning) when there are none."""
        return float(np.mean(values)) if values else float('nan')

    def _train_epoch(self):
        """Run one optimisation pass over the training loader and record its mean loss."""
        self.model.train()
        running_loss = []

        for batch_data in self.train_dataloader:
            prepared = self._prepare_batch(batch_data)
            if prepared is None:
                continue
            inputs, context, y = prepared

            self.optimizer.zero_grad()
            outputs = self.model(inputs, context)
            loss = self.params.CRITERION(outputs, y)
            loss.backward()
            self.optimizer.step()

            running_loss.append(loss.item())

        self.loss['train'].append(self._mean(running_loss))

    def _validate_epoch(self):
        """Compute the mean loss over the validation loader without gradients."""
        self.model.eval()
        running_loss = []

        with torch.no_grad():
            for batch_data in self.valid_dataloader:
                prepared = self._prepare_batch(batch_data)
                if prepared is None:
                    continue
                inputs, context, y = prepared
                preds = self.model(inputs, context)
                loss = self.params.CRITERION(preds, y)
                running_loss.append(loss.item())

        self.loss['valid'].append(self._mean(running_loss))

    def test_testwords(self, n: int = 5):
        """Print the ``n`` nearest neighbours of every word in ``self.testwords``."""
        for word in self.testwords:
            print(word)
            neighbours = self.model.get_similar_words(word, n)
            # get_similar_words returns None (after printing a notice) for OOV words.
            for w, sim in (neighbours or {}).items():
                print(f"{w} ({sim:.3})", end=' ')
            print('\n')

In [316]:
params = Word2VecParams()
# Centre-side (.ro) and context-side (.rup) splits.
train_iter, valid_iter = get_data()
train_iter_context, valid_iter_context = get_data(
    train_dir='C:/Users/gheto/Desktop/PoS/AromanianPoS/dataset/Tales.train.rup',
    valid_dir='C:/Users/gheto/Desktop/PoS/AromanianPoS/dataset/Tales.test.rup',
)
tokenizer = get_tokenizer(params.TOKENIZER)
# One vocabulary per side, both built from the training splits only.
vocab_center = build_vocab(train_iter, tokenizer, params)
vocab_context = build_vocab(train_iter_context, tokenizer, params)
skip_gram = SkipGrams(
    vocab=vocab_center,
    vocab_context=vocab_context,
    params=params,
    tokenizer=tokenizer,
)
model = Model(
    vocab=vocab_center,
    vocab_context=vocab_context,
    params=params,
).to(params.DEVICE)
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
# The vocabulary given to Trainer is only used for negative sampling, which
# draws from the context side.
trainer = Trainer(
    model=model,
    params=params,
    optimizer=optimizer,
    vocab=vocab_context,
    train_iter=train_iter,
    valid_iter=valid_iter,
    skipgrams=skip_gram,
)
trainer.train()

In [318]:
def test_testwords(self, n: int = 5, words=None):
    """Print the ``n`` nearest neighbours of each probe word via ``self.model``.

    Despite the ``self`` name this is a plain function, called with a
    Trainer-like object exposing ``model``.

    :param self: Object whose ``model`` provides ``get_similar_words``.
    :param n: Number of neighbours to show per word.
    :param words: Probe words; defaults to the fixed list below.
    """
    if words is None:
        words = ["pe", "douăzeci", "mănâncă", "penele", "băiete", "furtuna", "zbura", "citind", "armânii", "oile"]
    for word in words:
        print(word)
        neighbours = self.model.get_similar_words(word, n)
        # get_similar_words returns None (after printing a notice) for OOV words.
        for w, sim in (neighbours or {}).items():
            print(f"{w} ({sim:.3})", end=' ')
        print('\n')

test_testwords(trainer, n=5)

pe
păltări (0.207) calea-calea (0.199) padea (0.194) cu-alantu (0.186) pre-anarga (0.179) 

douăzeci
nicaț (0.226) adunai (0.199) 1906 (0.194) cuibul (0.18) asvindzeam (0.18) 

mănâncă
fapțîl’i (0.213) chetrile (0.198) tatălui (0.198) aistă (0.197) aușatic (0.197) 

penele
armân (0.209) topcă (0.198) arșițile (0.192) pitrumsiră (0.191) aprindeț (0.188) 

băiete
daț-le (0.208) tu-apirită (0.203) fumăria (0.192) asculți (0.188) harfă (0.188) 

furtuna
se-andrupă (0.202) suflite (0.188) s-filipsească (0.187) singură (0.182) ficior-fic (0.176) 

zbura
bag-u (0.195) di-aclo (0.189) armănea (0.188) arse (0.182) dorină (0.177) 

citind
nîs“ (0.206) piricl’iu (0.204) dusiră (0.184) graiurî (0.18) juneaște (0.179) 

armânii
bîna (0.197) nîs (0.195) acățață (0.186) dări (0.184) hapse (0.179) 

oile
ancl’igat (0.211) ponda (0.202) pri-aclo (0.187) altă-oară (0.184) mul’erle (0.176) 



In [308]:
def test_testwords(self, n: int = 5, words=None):
    """Print the ``n`` nearest neighbours of each probe word via ``self.model``.

    Despite the ``self`` name this is a plain function, called with a
    Trainer-like object exposing ``model``.

    :param self: Object whose ``model`` provides ``get_similar_words``.
    :param n: Number of neighbours to show per word.
    :param words: Probe words; defaults to the fixed list below.
    """
    if words is None:
        words = ["pe", "douăzeci", "mănâncă", "penele", "băiete", "furtuna", "zbura", "citind", "armânii", "oile"]
    for word in words:
        print(word)
        neighbours = self.model.get_similar_words(word, n)
        # get_similar_words returns None (after printing a notice) for OOV words.
        for w, sim in (neighbours or {}).items():
            print(f"{w} ({sim:.3})", end=' ')
        print('\n')

test_testwords(trainer, n=5)

pe
chitroasă (0.213) vrură (0.201) lire (0.2) agioclu (0.196) hearele (0.194) 

douăzeci
turma (0.224) stole (0.199) guli (0.196) chiñi (0.191) s-le-aibă (0.187) 

mănâncă
mire (0.211) avhil’eate (0.207) s-minduia (0.19) l’epur (0.188) de-asime (0.187) 

penele
ciuciteaște (0.218) lu-agiungu (0.201) ambar (0.198) alăgară (0.197) stres-stres (0.187) 

băiete
zulăchilor (0.232) vilendză (0.23) mîcate (0.224) iu-ț (0.203) plăteaște (0.197) 

furtuna
cutie (0.249) bunu (0.202) mîcară (0.194) n-afla (0.191) tăl’eat (0.183) 

zbura
s-cutrimburară (0.225) s-γină“ (0.209) s-aungă (0.209) salți (0.199) angreacă (0.197) 

citind
bărbărută (0.232) alîndurle (0.216) cot (0.211) tălăgane (0.209) cuțute (0.208) 

armânii
a-nveastil’ei (0.191) căftară (0.187) avea-ntunicată (0.185) cărave (0.183) stihio (0.172) 

oile
fufulii (0.248) paradhis (0.198) mîrînγipsite (0.188) bucură-te (0.182) dizlichi (0.181) 



In [None]:
# Just some experiments: map centre-vocabulary indices to the context-vocabulary
# words stored at the same positions.
text_tokens = vocab_center.get_index(tokenizer("Sunt unele deprinderi pe care le câștigi doar citind, ori vorbindu-ți-se despre ele. Dar mai sunt și din cele ce nu ne intră în cap, până nu le vedem cu ochii. Cât trăim - multe auzim și multe vedem, dar cu toate astea - nimic nu vom ști ca lumea până nu punem mâna să și facem ceea ce ne pare că știm. Cu alte cuvinte, vreau să spun eu, teoria e bună ea, deseori, dar câte o dată rămâne de căruță față de practică. Așa grăia într-un rând un bătrân înțelept - către niște tineri ce încă nu ieșiseră din școală, și care își făceau ideea cum că... nimeni nu-i mai învățat ca ei, și că... tot ce zboară, se mănâncă! Ca să pricepeți ce vă spusei până aici, luați"))
print(text_tokens)
context_words = [vocab_context.lookup_token(token) for token in text_tokens]
context_tokens = [vocab_context.get_index(word) for word in context_words]
print(context_tokens)
print(context_words)

# Compare how the tokenizer splits the same opening line in both texts.
print(tokenizer("TEORIA SI PRACTICA \nSunt unele deprinderi pe care le câștigi doar citind, ori vorbindu-ți-se despre ele."))
print(tokenizer("TEORIA SI PRACTICA \nSănt îndoauă învețuri, cari si amintă mași cu ghivăsirea, i cu avdzărea."))

print(vocab_center.get_index("sunt"))
print(vocab_context.get_index("sănt"))