Word to vect from scratch: https://towardsdatascience.com/implementing-word2vec-in-pytorch-from-the-ground-up-c7fe5bf99889  
Word2Vec paper: https://arxiv.org/pdf/1301.3781

In [1]:
import json
import numpy as np

In [2]:
# Read evidence
with open('../data/evidence.json', 'r') as f:
    evidence = json.load(f)
evidence = [str.lower(j) for i,j in evidence.items()]

In [3]:
# MODIFY
from dataclasses import dataclass, field
import torch
import torch.nn as nn
     

@dataclass(repr=True)
class Word2VecParams:

    # skipgram parameters
    MIN_FREQ = 50               # Minimum occurency of words
    SKIPGRAM_N_WORDS = 8        # n-grams to consider the tager (Bidirectional)
    T = 85                      # words with frequency in the 85th percentile will have a small probability  of being subsampled
    NEG_SAMPLES = 50
    NS_ARRAY_LEN = 5_000_000
    SPECIALS = ""               # Excluded words
    TOKENIZER = 'basic_english'

    # network parameters
    BATCH_SIZE = 100
    EMBED_DIM = 300
    EMBED_MAX_NORM = None
    N_EPOCHS = 5
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    CRITERION = nn.BCEWithLogitsLoss()
     

params = Word2VecParams()

In [4]:
class Vocab:
    def __init__(self, list, specials):
        self.stoi = {v[0]:(k, v[1]) for k, v in enumerate(list)}
        self.itos = {k:(v[0], v[1]) for k, v in enumerate(list)}
        self._specials = specials[0]
        self.total_tokens = np.nansum(
            [f for _, (_, f) in self.stoi.items()]
            , dtype=int)

    def __len__(self):
        return len(self.stoi) - 1

    # Return word from counter dict
    def get_index(self, word):
        if isinstance(word, str):
            if word in self.stoi: 
                return self.stoi.get(word)[0]
            else:
                return self.stoi.get(self._specials)[0]
        elif isinstance(word, list):
            res = []
            for w in word:
                if w in self.stoi: 
                    res.append(self.stoi.get(w)[0])
                else:
                    res.append(self.stoi.get(self._specials)[0])
            return res
        else:
            raise ValueError(
                f"Word {word} is not a string or a list of strings."
                )


    def get_freq(self, word):
        if isinstance(word, str):
            if word in self.stoi: 
                return self.stoi.get(word)[1]
            else:
                return self.stoi.get(self._specials)[1]
        elif isinstance(word, list):
            res = []
            for w in word:
                if w in self.stoi:
                    res.append(self.stoi.get(w)[1])
                else:
                    res.append(self.stoi.get(self._specials)[1])
            return res
        else:
            raise ValueError(
                f"Word {word} is not a string or a list of strings."
                )
    

    def lookup_token(self, token):
        if isinstance(token, (int, np.int64)):
            if token in self.itos:
                return self.itos.get(token)[0]
            else:
                raise ValueError(f"Token {token} not in vocabulary")
        elif isinstance(token, list):
            res = []
            for t in token:
                if t in self.itos:
                    res.append(self.itos.get(token)[0])
                else:
                    raise ValueError(f"Token {t} is not a valid index.")
            return res

In [5]:
def vocab(ordered_dict:dict, min_freq: int = 1, specials: str = '<unk>'):
    tokens = []
    # Save room for special tokens
    for token, freq in ordered_dict.items():
        if freq >= min_freq:
            tokens.append((token, freq))

    specials = (specials, np.nan)
    tokens[0] = specials

    return Vocab(tokens, specials)

In [6]:

from collections import Counter, OrderedDict
def build_vocab(
        iterator,                   # Training set
        tokenizer,                  # English tokenizer
        params: Word2VecParams,     # Instanciated parameters
        max_tokens = None,
    ):

    # Fill counter with filtered tokens
    counter = Counter()
    for tokens in yield_tokens(iterator, tokenizer):
        counter.update(tokens)

    # First sort by descending frequency, then lexicographically
    sorted_by_freq_tuples = sorted(
        counter.items(), key=lambda x: (-x[1], x[0])
        )

    ordered_dict = OrderedDict(sorted_by_freq_tuples)

    word_vocab = vocab(
        ordered_dict, min_freq=params.MIN_FREQ, specials=params.SPECIALS
        )
    return word_vocab

In [7]:
from torchtext.data import get_tokenizer
import re

# Parameters
tokenizer = get_tokenizer("basic_english")
MIN_FREQ = 50 
SPECIALS = "<unk>"

# Function to filter text
def filter_tokens(iterator, tokenizer):
    r = re.compile('[a-z1-9]')      # Only numbers and letters
    for text in iterator:
        res = tokenizer(text)
        res = list(filter(r.match, res))
        yield res

# Fill counter with filtered tokens
counter = Counter()
for tokens in filter_tokens(evidence[:100000], tokenizer=tokenizer):
    counter.update(tokens)

# Sort lexicographically
word_vocab = vocab(counter, min_freq=MIN_FREQ, specials=SPECIALS)



In [8]:
# First sort by descending frequency, then lexicographically
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))

In [14]:
word_vocab.stoi

{'<unk>': (0, nan),
 'english': (1, 1075),
 'and': (2, 56892),
 'agricultural': (3, 126),
 'scientist': (4, 95),
 'began': (5, 618),
 'his': (6, 7236),
 'professional': (7, 883),
 'career': (8, 772),
 'at': (9, 10288),
 'the': (10, 137026),
 'age': (11, 484),
 'of': (12, 71072),
 '16': (13, 619),
 'eventually': (14, 222),
 'moving': (15, 127),
 'to': (16, 30629),
 'new': (17, 3524),
 'york': (18, 1003),
 'city': (19, 2098),
 'in': (20, 62470),
 '1977': (21, 279),
 'cambridge': (22, 136),
 'by': (23, 15433),
 'francis': (24, 109),
 'born': (25, 4337),
 'october': (26, 1209),
 '20': (27, 731),
 '1936': (28, 161),
 'was': (29, 25498),
 'a': (30, 46819),
 'ice': (31, 466),
 'hockey': (32, 424),
 'player': (33, 1000),
 'who': (34, 3668),
 'played': (35, 1758),
 '40': (36, 232),
 'games': (37, 875),
 'national': (38, 2603),
 'league': (39, 1568),
 'he': (40, 11619),
 'function': (41, 141),
 'how': (42, 276),
 'modified': (43, 55),
 'release': (44, 428),
 'with': (45, 11287),
 'peak': (46, 17

In [12]:
counter

Counter({'the': 137026,
         'of': 71072,
         'in': 62470,
         'and': 56892,
         'a': 46819,
         'is': 32587,
         'to': 30629,
         'was': 25498,
         'as': 16121,
         'for': 15733,
         'by': 15433,
         'on': 14715,
         'it': 11623,
         'he': 11619,
         's': 11499,
         'with': 11287,
         'from': 10752,
         'at': 10288,
         'an': 9586,
         'that': 7795,
         'his': 7236,
         'are': 5741,
         'also': 5515,
         'which': 5475,
         'or': 5022,
         'has': 5003,
         'born': 4337,
         'its': 4283,
         'first': 4217,
         'this': 3964,
         'one': 3864,
         'were': 3792,
         'be': 3718,
         'who': 3668,
         'new': 3524,
         'their': 2971,
         'american': 2930,
         'after': 2891,
         'known': 2869,
         'district': 2852,
         'united': 2804,
         'been': 2730,
         'have': 2702,
         'she': 2621

In [17]:
5*5*16


400