In [1]:
import itertools
import re
from tqdm import tqdm
from itertools import product
import random
from collections import Counter, defaultdict, deque
import sys
import os
import pandas as pd
import numpy as np
from numpy.random import multinomial
from scipy.spatial import distance
from IPython.core.display import display
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import nltk
from typing import List, Tuple, Union, Dict
from nltk.tokenize import WordPunctTokenizer
tokenizer = WordPunctTokenizer()
from nltk.stem import WordNetLemmatizer
lemmatizer = WordNetLemmatizer()

In [2]:
"""
TODO:
- subsampling
- init vectors with TFIDF instead of one-hot
-
"""

'\nTODO:\n- subsampling\n- init vectors with TFIDF instead of one-hot\n-\n'

In [3]:
if torch.cuda.is_available():
    dev = "cuda:0"
    print("cuda available")
else:
    dev = "cpu"
    print("cuda not available")
device = torch.device(dev)

cuda available


In [None]:
with open("war.txt", "r") as f:
    data = f.read()
data = data.replace('\n', ' ')
data[:50]

In [None]:
%%time
# tokenize and lemmatize
tokens = tokenizer.tokenize(data)
tokens = [lemmatizer.lemmatize(token.lower()) for token in tokens]
tokens[:10]

In [None]:
Counter(tokens).most_common()[:10]

In [None]:
# ?!
tokens = [re.sub('[^a-zA-Z\s\d]+', '', w) for w in tokens if re.sub('[^a-zA-Z\s\d]+', '', w) != '']
tokens[:10]

In [None]:
Counter(tokens).most_common()[:10]

In [None]:
def gen_word_to_idx(tokens, min_words: int, max_freq: float):
    count = Counter(tokens).most_common()
    max_count = count[0][1]
    max_words = max_freq * max_count
    filtered_words = [c[0] for c in count if min_words < c[1] < max_words]
    word2idx = {word: idx+1 for idx, word in enumerate(filtered_words)}
    word2idx["<unk>"] = 0
    return word2idx

In [None]:
word2idx = gen_word_to_idx(tokens, min_words=10, max_freq=0.9)
# word2idx

In [None]:
# get pairs
def batch_generator(tokens, window_size=5):
    for i in range(0, len(tokens)-window_size):
        context = [word2idx[token] if token in word2idx else word2idx["<unk>"] for token in tokens[i: i+window_size]]
        # print(context)
        central_word = context[window_size//2]
        # print(central_word)
        context.pop(window_size//2)  # remove central word from context
        pairs = [(central_word, context_word, 1) for context_word in context]
        # print(pairs)
        # TODO: remove from word2idx <unk> and words from context
        neg_words = np.random.choice(list(word2idx.values()), window_size-1)
        # print(neg_words)
        neg_pairs = [(central_word, neg_word, 0) for neg_word in neg_words]
        # print(neg_pairs)
        batch = pairs + neg_pairs
        # print(batch)
        yield batch

In [14]:
class DataHandler:
    def __init__(self, path: str = "war_and_peace.txt",
                 debug: bool = False,
                 min_words: int = 10,
                 max_freq: float = 0.9,
                 window_size: int = 5
                 ):
        self.debug = debug
        self.path = path
        self.min_words = min_words
        self.max_freq = max_freq
        self.window_size = window_size
        self.data = self.data_loader()
        self.tokens = self.preprocessor(self.data)
        self.word2idx = self.gen_word_to_idx(self.tokens)

    def data_loader(self):
        with open(self.path, "r") as f:
            data = f.read()
        data = data.replace('\n', ' ')
        return data

    def preprocessor(self, data):
        tokens = tokenizer.tokenize(data)
        tokens = [lemmatizer.lemmatize(token.lower()) for token in tokens]
        tokens = [re.sub('[^a-zA-Z\s\d]+', '', w) for w in tokens if re.sub('[^a-zA-Z\s\d]+', '', w) != '']
        return tokens

    def gen_word_to_idx(self, tokens):
        count = Counter(tokens).most_common()
        max_count = count[0][1]
        max_words = self.max_freq * max_count
        filtered_words = [c[0] for c in count if self.min_words < c[1] < max_words]
        word2idx = {word: idx+1 for idx, word in enumerate(filtered_words)}
        word2idx["<unk>"] = 0
        return word2idx

    def batch_generator(self):
        if self.debug:
            tokens = self.tokens[:1000]
        else:
            tokens = self.tokens
        for i in range(0, len(tokens)-self.window_size):
            context = [self.word2idx[token] if token in self.word2idx else self.word2idx["<unk>"]
                       for token in tokens[i: i+self.window_size]]
            central_word = context[self.window_size//2]
            context.pop(self.window_size//2)  # remove central word from context
            pairs = [(central_word, context_word, 1) for context_word in context]
            # TODO: remove from word2idx <unk> and words from context
            neg_words = np.random.choice(list(self.word2idx.values()), self.window_size-1)
            neg_pairs = [(central_word, neg_word, 0) for neg_word in neg_words]
            batch = pairs + neg_pairs
            yield batch

In [15]:
class SkipGramModel(nn.Module):

    def __init__(self,
                 vocab_size: int = 1000,
                 embedding_size: int = 100,
                 ):
        super(SkipGramModel, self).__init__()
        if torch.cuda.is_available():
            dev = "cuda:0"
            print("cuda available")
        else:
            dev = "cpu"
            print("cuda not available")
        self.device = torch.device(dev)
        self.u_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.v_embeddings = nn.Embedding(vocab_size, embedding_size)

    def forward(self, u, v):
        emb_u = self.u_embeddings(u)  # batch_size * emb_dimension
        emb_v = self.v_embeddings(v)  # batch_size * emb_dimension
        logits = torch.sum(emb_u * emb_v, dim=-1)
        return logits

In [16]:
class Word2Vec(nn.Module):

    def __init__(self,
                 embedding_size: int = 100,
                 debug: bool = False,
                 path: str = "war_and_peace.txt",
                 output_path: str = "output/",
                 min_words: int = 10,
                 max_freq: float = 0.9,
                 window_size: int = 5,
                 num_epochs: int = 2,
                 initial_lr: float = 1e-3,
                 ):
        super(Word2Vec, self).__init__()
        # data params
        self.path = path
        self.output_path = output_path
        self.debug = debug
        if self.debug:
            print("debug mode on")
        self.min_words = min_words
        self.max_freq = max_freq
        self.embedding_size = embedding_size
        self.window_size = window_size
        self.data = DataHandler(
            debug=debug,
            path=path,
            min_words=min_words,
            max_freq=max_freq,
            window_size=window_size
        )
        # model params
        self.num_epochs = num_epochs
        self.initial_lr = initial_lr
        self.model = SkipGramModel(len(self.data.tokens), self.embedding_size)
        self.device = self.model.device
        self.model.to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.initial_lr)
        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.losses = []
        # embeddings params
        self.vocab_size = len(set(self.data.tokens))
        self.u_embeddings = nn.Embedding(self.vocab_size, self.embedding_size)
        self.v_embeddings = nn.Embedding(self.vocab_size, self.embedding_size)


    def train(self):
        for epoch in range(1, self.num_epochs+1):
            tmp_loss = []
            for batch in tqdm(self.data.batch_generator()):
                # get word ids
                u, v, target = zip(*batch)
                u, v, target = torch.tensor(u, dtype=torch.int64, device=self.device), \
                               torch.tensor(v, dtype=torch.int64, device=self.device), \
                               torch.tensor(target, dtype=torch.float32, device=self.device)
                self.optimizer.zero_grad()
                logits = self.model(u, v)
                loss = self.criterion(logits, target)
                tmp_loss.append(loss.mean().item())
                loss.backward()
                self.optimizer.step()
            self.losses.append(np.mean(tmp_loss))
            print(epoch, np.round(self.losses, 5))
        self.save_model()
        self.save_vectors()

    def save_model(self):
        torch.save(self.model, self.output_path+"model.pth")

    def save_vectors(self):
        embedding = (self.model.u_embeddings.weight.cpu().data.numpy()
                     + self.model.v_embeddings.weight.cpu().data.numpy())/2
        idx2word = {idx: word for word, idx in self.data.word2idx.items()}
        vectors_file = open(self.output_path+"word_vectors.txt", 'w', encoding='utf-8')
        vectors_file.write('%d %d\n' % (len(idx2word), self.embedding_size))

        for wid, w in idx2word.items():
            e = embedding[wid]
            e = ' '.join(map(lambda x: str(x), e))
            vectors_file.write('%s %s\n' % (w, e))

In [17]:
%%time
w2v = Word2Vec(
    embedding_size=100,
    debug=False,  # !!!
    path="war_and_peace.txt",
    output_path="output/",
    min_words=10,
    max_freq=0.9,
    window_size=5,
    num_epochs=5,  # !!!
    initial_lr=1e-3,
              )
w2v.train()

cuda available
1 [3.14692]
2 [3.14692 2.55352]
3 [3.14692 2.55352 2.28712]
4 [3.14692 2.55352 2.28712 2.11767]
5 [3.14692 2.55352 2.28712 2.11767 1.98921]
CPU times: user 5h 24min 13s, sys: 2h 20min 43s, total: 7h 44min 57s
Wall time: 7h 47min 33s


576606it [1:33:55, 102.32it/s]
576606it [1:33:35, 102.68it/s]
576606it [1:33:23, 102.90it/s]
576606it [1:33:21, 102.95it/s]
576606it [1:33:13, 103.09it/s]
  "type " + obj.__name__ + ". It won't be checked "


In [18]:
w2v.losses

[3.1469166241086755,
 2.553516806131806,
 2.2871161223628347,
 2.11766518091871,
 1.9892079724328273]

In [27]:
# king - man + woman ≈ queen
w2v.data.word2idx['king'], w2v.data.word2idx['man'], w2v.data.word2idx['woman'], w2v.data.word2idx['princess']

(833, 60, 218, 75)

In [29]:
embedding = (w2v.model.u_embeddings.weight.cpu().data.numpy()
                     + w2v.model.v_embeddings.weight.cpu().data.numpy())/2

In [37]:
result_emb = embedding[833] - embedding[60] + embedding[218]

In [38]:
from scipy.spatial import distance
distance.cosine(result_emb, embedding[75])


0.970347199589014