### Word2Vec

In [60]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import pandas as pd
import json

In [99]:
RE_WORD = r'\w+'
re_word = re.compile(RE_WORD, re.U)

DATA_SET = [
    "data/frankenstein.txt",
    "data/moby_dick.txt",
    "data/freud.txt", 
    "data/wiki.txt", 
    "data/middlemarch.txt", 
    "data/pride_and_prejudice.txt", 
    "data/alice_in_wonderland.txt",
    "data/bleakhouse.txt",
    "data/crime_and_punishment.txt",
    "data/room_with_a_view.txt"
]

def extract_vocab(paths: list[str]) -> dict[str, int]:
    vocab = dict()
    for path in paths:
        with open(path, encoding="utf-8") as file:
            for word in (word.lower() for word in re.findall(re_word, file.read())):
                if word not in vocab:
                    vocab[word] = 1
                else:
                    vocab[word] = vocab[word] + 1

    words = (word for word, count in vocab.items() if count > 8)
    return { word:idx for idx, word in enumerate(words) }

def tokenize(text: str, vocab: dict[str, int]) -> list[str]:
    return [word.lower() for word in re.findall(re_word, text) if word in vocab]

def load_data(path: str) -> str:
    with open(path, encoding="utf-8") as file:
        text = file.read()
        return text

vocab = extract_vocab(DATA_SET)

print(len(vocab))

11024


In [100]:
def batch_to_one_hot(batch: list[list[int]], vocab_size: int) -> torch.Tensor:
    t = torch.zeros((len(batch), vocab_size), dtype=torch.float32)

    for i, indeces in enumerate(batch):
        t[i, indeces] = 1.0
    
    return t

def tokens_to_indeces(tokens: list[str], vocab: dict[str, int]) -> list[int]:
    return [vocab[token] for token in tokens]

def gen_text_to_tensors(text: str, vocab: dict[str, int], window_size: int, batch_size: int):
    tokens = tokenize(text, vocab)
    indeces = tokens_to_indeces(tokens, vocab)
    half_window = window_size // 2
    

    words = []
    context = []
    for i in range(half_window, len(indeces) - half_window):
        words.append([indeces[i]])
        context.append(indeces[i-half_window:i] + indeces[i+1:i+1+half_window])

        if (i % batch_size == 0):
            yield batch_to_one_hot(words, len(vocab)), batch_to_one_hot(context, len(vocab))
            words = []
            context = []
        
    yield batch_to_one_hot(words, len(vocab)), batch_to_one_hot(context, len(vocab))


In [105]:
VOCAB_SIZE = len(vocab)
VECTOR_SIZE = 100

model = nn.Sequential(
    nn.Linear(VOCAB_SIZE, VECTOR_SIZE),
    nn.Linear(VECTOR_SIZE, VOCAB_SIZE),
    nn.Softmax()
)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [106]:
epochs = 1

for epoch in range(epochs):
    for data in DATA_SET:
        text = load_data(data)
        
        for Y, X in gen_text_to_tensors(text, vocab, window_size=5, batch_size=1024):
            outputs = model(X.float())
            loss = loss_fn(outputs, Y.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


  return self._call_impl(*args, **kwargs)


In [None]:
embeddings = torch.transpose(list(model.parameters())[0].detach().clone(), 0, 1)

wordsim353 = pd.read_csv("wordsim353/combined.csv")

not_in_vocab = set()
wordsim353['Word 1'] = wordsim353['Word 1'].apply(lambda w: w.lower())
wordsim353['Word 2'] = wordsim353['Word 2'].apply(lambda w: w.lower())
wordsim353['Human (mean)'] = wordsim353['Human (mean)'].apply(lambda m: m / 5 - 1)

count = 0
error = 0
not_in_vocab = set()
for _, row in wordsim353.iterrows():
    w1 = row['Word 1']
    w2 = row['Word 2']
    human = row['Human (mean)']

    if w1 not in vocab:
        not_in_vocab.add(w1)
        continue
    if w2 not in vocab:
        not_in_vocab.add(w2)
        continue

    sim = F.cosine_similarity(embeddings[vocab[w1]].unsqueeze(0), embeddings[vocab[w2]].unsqueeze(0)).tolist()[0]
    error += (human - sim) ** 2
    count += 1
    print(f'{human:.3f}, {sim:.3f}, {w1}, {w2}')


print("MSE: ", error / count)

0.354 0.473 love sex
0.470 0.871 tiger cat
1.000 1.000 tiger tiger
0.492 0.736 book paper
0.524 0.607 computer keyboard
0.516 0.548 computer internet
0.500 -0.801 telephone communication
0.354 0.076 television radio
0.484 -0.412 media radio
0.370 -0.295 drug abuse
0.238 -0.114 bread butter
0.400 0.637 doctor nurse
0.324 0.366 professor doctor
0.362 0.305 student professor
-0.076 0.690 smart student
0.162 -0.500 smart stupid
0.416 0.814 company stock
0.616 0.666 stock market
-0.676 0.360 stock phone
-0.638 0.590 stock egg
-0.254 0.857 stock live
-0.816 0.193 stock life
0.492 0.734 book library
0.624 0.918 bank money
0.546 0.164 wood forest
0.830 0.311 money cash
0.716 0.413 king queen
0.692 0.225 jerusalem israel
0.530 0.214 jerusalem palestinian
-0.676 0.581 holy sex
0.724 0.048 maradona football
0.806 0.044 football soccer
0.326 0.977 football tennis
0.512 -0.702 tennis racket
0.346 0.072 arafat peace
0.530 -0.034 arafat terror
-0.500 0.040 arafat jackson
0.676 0.834 law lawyer
0.476 

In [None]:
with open("out/embeds.json", "+w", encoding="utf-8") as file:
    file.write(json.dumps({ word:embeddings[idx].tolist() for word, idx in vocab.items() }))