In [4]:
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import re

In [5]:
@dataclass(repr=True)
class Word2VecParams:

    # skipgram parameters
    MIN_FREQ = 50 
    SKIPGRAM_N_WORDS = 5
    T = 85
    NEG_SAMPLES = 50
    NS_ARRAY_LEN = 5_000_000
    SPECIALS = ""
    TOKENIZER = 'basic_english'

    # network parameters
    BATCH_SIZE = 100
    EMBED_DIM = 300
    EMBED_MAX_NORM = None
    N_EPOCHS = 5
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    CRITERION = nn.BCEWithLogitsLoss()

In [6]:
class Vocab:
    def __init__(self, list, specials):
        self.stoi = {v[0]:(k, v[1]) for k, v in enumerate(list)}
        self.itos = {k:(v[0], v[1]) for k, v in enumerate(list)}
        self._specials = specials[0]
        self.total_tokens = np.nansum(
            [f for _, (_, f) in self.stoi.items()]
            , dtype=int)

    def __len__(self):
        return len(self.stoi) - 1

    def get_index(self, word: Union[str, List]):
        if isinstance(word, str):
            if word in self.stoi: 
                return self.stoi.get(word)[0]
            else:
                return self.stoi.get(self._specials)[0]
        elif isinstance(word, list):
            res = []
            for w in word:
                if w in self.stoi: 
                    res.append(self.stoi.get(w)[0])
                else:
                    res.append(self.stoi.get(self._specials)[0])
            return res
        else:
            raise ValueError(
                f"Word {word} is not a string or a list of strings."
                )


    def get_freq(self, word: Union[str, List]):
        if isinstance(word, str):
            if word in self.stoi: 
                return self.stoi.get(word)[1]
            else:
                return self.stoi.get(self._specials)[1]
        elif isinstance(word, list):
            res = []
            for w in word:
                if w in self.stoi:
                    res.append(self.stoi.get(w)[1])
                else:
                    res.append(self.stoi.get(self._specials)[1])
            return res
        else:
            raise ValueError(
                f"Word {word} is not a string or a list of strings."
                )
    

    def lookup_token(self, token: Union[int, List]):
        if isinstance(token, (int, np.int64)):
            if token in self.itos:
                return self.itos.get(token)[0]
            else:
                raise ValueError(f"Token {token} not in vocabulary")
        elif isinstance(token, list):
            res = []
            for t in token:
                if t in self.itos:
                    res.append(self.itos.get(token)[0])
                else:
                    raise ValueError(f"Token {t} is not a valid index.")
            return res

NameError: name 'Union' is not defined

In [9]:
from collections import Counter, OrderedDict
from typing import Optional

import numpy as np


def yield_tokens(iterator, tokenizer):
    r = re.compile('[a-z1-9]')
    for text in iterator:
        res = tokenizer(text)
        res = list(filter(r.match, res))
        yield res
    
def build_vocab(
        iterator,
        tokenizer, 
        params: Word2VecParams,
        max_tokens: Optional[int] = None,
    ):
    counter = Counter()
    for tokens in yield_tokens(iterator, tokenizer):
        counter.update(tokens)

    sorted_by_freq_tuples = sorted(
        counter.items(), key=lambda x: (-x[1], x[0])
        )

    ordered_dict = OrderedDict(sorted_by_freq_tuples)

    tokens = []
    for token, freq in ordered_dict.items():
        if freq >= params.MIN_FREQ:
            tokens.append((token, freq))

    specials = (params.SPECIALS, np.nan)
    tokens[0] = specials

    return Vocab(tokens, specials)


In [None]:
build_vocab("dataset/Tales.test.ro", "basic_english")
