## Data Preparation

In [1]:
import re

In [2]:
# Read the whole training corpus into memory as a single string.
with open("../data/data.txt", mode="r") as corpus_file:
	data = corpus_file.read()

In [3]:
def clean_and_tokenize(text, stop_words=None):
	"""Lower-case ``text``, keep only ASCII letters and drop stop words.

	Args:
		text: Raw input string.
		stop_words: Optional iterable of words to discard. When ``None``
			(the default) the words are read from
			``../data/stopwords-en.txt``, split on any whitespace.

	Returns:
		list[str]: The remaining tokens, in their original order. The list
		never contains empty strings.
	"""
	if stop_words is None:
		with open("../data/stopwords-en.txt", "r") as f:
			# split() with no argument ignores blank lines and a trailing
			# newline, so no entry has to be sliced off afterwards.
			stop_words = f.read().split()
	# A set gives constant-time membership tests instead of rescanning a
	# list once per token.
	stop_words = set(stop_words)

	# Every character that is not an ASCII letter becomes a separator.
	cleaned_text = re.sub(r"[^a-zA-Z]", " ", text).lower()
	# split() collapses runs of whitespace and never yields "" for leading
	# or trailing separators.
	tokens = cleaned_text.split()
	return [token for token in tokens if token not in stop_words]

# Tokenize the corpus held in `data`; later cells build the vocabulary and the training pairs from `tokens`.
tokens = clean_and_tokenize(data)

In [4]:
# Sort the vocabulary so every word gets the same index on every run: the
# iteration order of a set of strings can change between interpreter sessions.
unique_words = sorted(set(tokens))
# Forward (word -> index) and reverse (index -> word) lookups.
word_i = {word: i for i, word in enumerate(unique_words)}
i_word = {i: word for word, i in word_i.items()}

### Training Data

In [5]:
# Number of tokens taken on each side of a target word when forming context windows.
window_size = 2


def target_context_tuples(tokens, window_size):
	"""Build ``(target, context)`` pairs for every position in ``tokens``.

	For each token the surrounding window is obtained from ``merge`` and
	every window entry that differs from the token is paired with it.
	The comparison is by value, so another occurrence of the same word
	inside the window is skipped as well.

	Returns:
		list[tuple]: ``(target, context)`` pairs in corpus order.
	"""
	return [
		(center, neighbour)
		for position, center in enumerate(tokens)
		for neighbour in merge(tokens, position, window_size)
		if neighbour != center
	]


def merge(tokens, i, window_size):
	"""Return the slice of ``tokens`` within ``window_size`` positions of index ``i``.

	The slice includes the token at ``i`` itself and is clipped at both
	ends of the sequence.

	Args:
		tokens: Sequence to slice.
		i: Index of the centre token.
		window_size: Maximum number of tokens to take on each side of ``i``.

	Returns:
		``tokens[max(0, i - window_size) : i + window_size + 1]``.
	"""
	# Clamp at 0 so that positions closer than window_size to the start still
	# get every available token on their left.
	left_id = max(0, i - window_size)
	# Slicing would clip an oversized end index anyway; min() makes it explicit.
	right_id = min(len(tokens), i + window_size + 1)
	return tokens[left_id:right_id]

In [6]:
# Use the configured window size instead of repeating the literal, so that
# changing `window_size` actually changes the generated pairs.
target_context_pairs = target_context_tuples(tokens, window_size)
target_context_pairs[:20]

[('deep', 'learning'),
 ('deep', 'subset'),
 ('learning', 'deep'),
 ('learning', 'subset'),
 ('learning', 'machine'),
 ('subset', 'deep'),
 ('subset', 'learning'),
 ('subset', 'machine'),
 ('subset', 'learning'),
 ('machine', 'learning'),
 ('machine', 'subset'),
 ('machine', 'learning'),
 ('machine', 'methods'),
 ('learning', 'subset'),
 ('learning', 'machine'),
 ('learning', 'methods'),
 ('learning', 'based'),
 ('methods', 'machine'),
 ('methods', 'learning'),
 ('methods', 'based')]

In [7]:
import pandas as pd

# One row per (target, context) pair.
df = pd.DataFrame(target_context_pairs, columns=["target", "context"])

In [8]:
import torch.nn.functional as F
import torch

vocab_size = len(unique_words)
# Indices 0 .. vocab_size-1, i.e. the values word_i assigns while enumerating unique_words.
token_indexes = list(range(vocab_size))
# Row k of `encodings` is the one-hot float vector for vocabulary index k.
encodings = F.one_hot(torch.tensor(token_indexes), num_classes=vocab_size).float()


def _one_hot_for(word):
	"""Return the row of `encodings` that belongs to `word`."""
	return encodings[word_i[word]]


df["target_ohe"] = df["target"].apply(_one_hot_for)
df["context_ohe"] = df["context"].apply(_one_hot_for)

In [9]:
# Preview the first rows to check the pair columns and their one-hot counterparts.
df.head()

Unnamed: 0,target,context,target_ohe,context_ohe
0,deep,learning,"[tensor(0.), tensor(0.), tensor(0.), tensor(0....","[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
1,deep,subset,"[tensor(0.), tensor(0.), tensor(0.), tensor(0....","[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
2,learning,deep,"[tensor(0.), tensor(0.), tensor(0.), tensor(0....","[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
3,learning,subset,"[tensor(0.), tensor(0.), tensor(0.), tensor(0....","[tensor(0.), tensor(0.), tensor(0.), tensor(0...."
4,learning,machine,"[tensor(0.), tensor(0.), tensor(0.), tensor(0....","[tensor(0.), tensor(0.), tensor(0.), tensor(0...."


### PyTorch Dataset Class

In [10]:
from torch.utils.data import Dataset

class W2VDataset(Dataset):
	"""Map-style dataset over a DataFrame of one-hot encoded word pairs.

	The frame passed to the constructor must provide the columns
	``context_ohe`` and ``target_ohe``.
	"""

	def __init__(self, df):
		self.df = df

	def __len__(self):
		return len(self.df)

	def __getitem__(self, index):
		"""Return ``(context_ohe, target_ohe)`` for the row at position ``index``."""
		# Read from the frame this instance was built with, not from whatever
		# the notebook-level name `df` currently points to. iloc keeps the
		# lookup positional even if the frame's index is not 0..n-1.
		context = self.df["context_ohe"].iloc[index]
		target = self.df["target_ohe"].iloc[index]
		return context, target

	dataset = W2VDataset(df)

In [11]:
from torch.utils.data import DataLoader
# Serve the pairs in shuffled mini-batches of 64.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

## Model

In [12]:
class Word2Vec(torch.nn.Module):
	"""Two stacked linear layers with no activation in between.

	``linear_1`` maps a ``vocab_size``-wide input to ``embed_size`` features;
	``linear_2`` (bias-free) maps those features back to ``vocab_size`` scores.
	"""

	def __init__(self, vocab_size, embed_size):
		super(Word2Vec, self).__init__()
		self.linear_1 = torch.nn.Linear(in_features=vocab_size, out_features=embed_size)
		self.linear_2 = torch.nn.Linear(in_features=embed_size, out_features=vocab_size, bias=False)

	def forward(self, x):
		"""Return the raw (unnormalised) vocabulary scores for ``x``."""
		hidden = self.linear_1(x)
		return self.linear_2(hidden)

### Training

In [13]:
from torch import nn
# Hyper-parameters.
EMBED_SIZE = 10
learning_rate = 1e-2

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
	device = "cuda"
else:
	device = "cpu"

model = Word2Vec(vocab_size, EMBED_SIZE)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [14]:
epochs = 300
loss_history = []

for epoch in range(epochs):
	running_loss = 0.0
	for batch, (context, target) in enumerate(dataloader):
		context, target = context.to(device), target.to(device)

		# One optimisation step: clear old gradients, forward, backward, update.
		optimizer.zero_grad()
		pred = model(context)
		loss = criterion(pred, target)
		loss.backward()
		optimizer.step()

		# Accumulate the batch loss as a plain Python float.
		running_loss += loss.item()

	# Mean of the per-batch losses seen in this epoch.
	epoch_loss = running_loss / len(dataloader)
	loss_history.append(epoch_loss)

	# Report every tenth epoch only.
	if (epoch + 1) % 10 == 0:
		print(f"Epoch: {epoch+1} | Loss: {epoch_loss}")

Epoch: 10 | Loss: 3.6006102107820057
Epoch: 20 | Loss: 2.7755035161972046
Epoch: 30 | Loss: 2.5514752126875377
Epoch: 40 | Loss: 2.451943897065662
Epoch: 50 | Loss: 2.374846617380778
Epoch: 60 | Loss: 2.3378889447166804
Epoch: 70 | Loss: 2.2985056212970187
Epoch: 80 | Loss: 2.2799387602579024
Epoch: 90 | Loss: 2.2576661989802407
Epoch: 100 | Loss: 2.2484866692906333
Epoch: 110 | Loss: 2.230420458884466
Epoch: 120 | Loss: 2.222761889298757
Epoch: 130 | Loss: 2.212623025689806
Epoch: 140 | Loss: 2.205689932618822
Epoch: 150 | Loss: 2.202955722808838
Epoch: 160 | Loss: 2.1946745089122226
Epoch: 170 | Loss: 2.2002311746279397
Epoch: 180 | Loss: 2.1902154116403487
Epoch: 190 | Loss: 2.1893660511289323
Epoch: 200 | Loss: 2.17938608782632
Epoch: 210 | Loss: 2.183610717455546
Epoch: 220 | Loss: 2.1823562525567555
Epoch: 230 | Loss: 2.1823100447654724
Epoch: 240 | Loss: 2.173426369825999
Epoch: 250 | Loss: 2.17266746645882
Epoch: 260 | Loss: 2.1697739646548317
Epoch: 270 | Loss: 2.1730494385673

In [15]:
word = encodings[word_i["language"]]
# Inference only: no_grad keeps autograd from recording this forward pass.
with torch.no_grad():
	scores = model(word.to(device))
# Vocabulary indices of the five highest-scoring words, best first.
top_ids = torch.argsort(scores, descending=True).squeeze(0)[:5]
[i_word[i] for i in top_ids.tolist()]

['processing', 'machine', 'natural', 'recognition', 'language']

In [16]:
word = encodings[word_i["life"]]
# Inference only: no_grad keeps autograd from recording this forward pass.
with torch.no_grad():
	scores = model(word.to(device))
# Vocabulary indices of the five highest-scoring words, best first.
top_ids = torch.argsort(scores, descending=True).squeeze(0)[:5]
[i_word[i] for i in top_ids.tolist()]

['organisms', 'study', 'various', 'emerged', 'energy']

In [17]:
def get_word_embedding(model, word):
	"""Look up the learned vector for ``word``.

	The vector is the row of ``model.linear_2.weight`` at the word's
	vocabulary index (taken from ``word_i``), returned from a detached
	CPU copy of the weights.

	Raises:
		KeyError: If ``word`` is not a key of ``word_i``.
	"""
	weight_matrix = model.linear_2.weight.detach().cpu()
	return weight_matrix[word_i[word]]

# Show the learned vector for one example word.
get_word_embedding(model, "biology")

tensor([ 0.7124,  0.9761, -1.3027, -1.3146,  1.7716, -0.3412,  2.5923, -1.2782,
        -1.5474, -0.8365])