In [80]:
import pandas
import torch
from torch import optim
from nltk.tokenize import sent_tokenize
from nltk.tokenize import RegexpTokenizer
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import re
from collections import Counter
from random import uniform

# Dedicated RNG with a fixed seed so that shuffling is reproducible.
gen = torch.Generator()
gen.manual_seed(42)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
def filter_dataset(dataset):
    """Strip LaTeX residue and URLs from raw wiki text.

    Args:
        dataset: The raw text of one wiki article (a ``str``).

    Returns:
        The text with two kinds of spans removed:
        * every span that starts at a run of eight spaces and runs through the
          end of the first line after a ``\\displaystyle`` / ``\\textstyle``
          marker (``re.DOTALL`` lets that span cross line breaks);
        * every http(s) URL.

    Author's note: this may filter too aggressively.
    """
    # LaTeX blocks: the extracted text indents them by eight spaces and they
    # always contain \displaystyle or \textstyle.
    without_latex = re.sub(
        r' {8}.*?(?:\\displaystyle|\\textstyle).*?\n',
        '',
        dataset,
        flags=re.DOTALL,
    )
    # URL pattern found at https://regexr.com/37i6s
    # No flag is passed: the 4th positional argument of re.sub is `count`, so
    # handing re.DOTALL (== 16) there capped the removal at 16 URLs. The
    # pattern has no wildcard '.', so DOTALL would not change anything anyway.
    without_links = re.sub(
        r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,4}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)',
        '',
        without_latex,
    )
    return without_links

In [7]:
def load_dataset(name: str, force_filter=False):
    """Load the corpus as a list of filtered wiki articles, caching the result.

    The filtered text is cached next to the source file as
    ``filtered_<file name>``. When that cache exists and ``force_filter`` is
    false it is read back; otherwise the source file is split on the
    ``__WIKI__`` separator, each article is passed through ``filter_dataset``
    and the cache is (re)written.

    Args:
        name: Path of the raw dataset file.
        force_filter: When true, ignore any existing cache and filter again.

    Returns:
        A list with one filtered string per wiki article. Both the cached and
        the freshly filtered path return this same shape, so callers can
        always iterate over articles.
    """
    from pathlib import Path

    source_path = Path(name)
    # Prefix the file name only, so a dataset in another directory still gets
    # a valid cache path; for a bare file name this is 'filtered_' + name.
    cache_path = source_path.with_name('filtered_' + source_path.name)

    if cache_path.exists() and not force_filter:
        with open(cache_path, 'r', encoding='utf-8') as cached:
            return cached.read().split('__WIKI__')

    # Explicit UTF-8: the platform default encoding is not guaranteed to
    # handle the non-ASCII characters found in wiki text.
    with open(source_path, 'r', encoding='utf-8') as raw:
        wikis = [filter_dataset(wiki) for wiki in raw.read().split('__WIKI__')]

    # Write the cache for *this* dataset (the `name` argument), not for
    # whatever a global variable happens to hold.
    with open(cache_path, 'w', encoding='utf-8') as out:
        out.write('__WIKI__'.join(wikis))
    return wikis

In [8]:
# Corpus file to train on. Author's note: choosing Wikipedia was a mistake.
dataset_name = 'small_dataset.txt'
# force_filter=True makes load_dataset filter again rather than use a cached copy.
dataset = load_dataset(name=dataset_name, force_filter=True)

In [9]:
class Skipgram(nn.Module):
    """Skip-gram network: an embedding lookup followed by a linear layer.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output scores produced per input id.
        embedding_dim: Width of each embedding vector.
        max_norm: Forwarded to ``nn.Embedding`` as its ``max_norm`` argument.
    """

    def __init__(self, vocab_size, embedding_dim, max_norm=1):
        super().__init__()
        # The embedding is created before the projection, as in the original,
        # so parameter initialisation consumes the RNG in the same order.
        self.embed = nn.Embedding(vocab_size, embedding_dim, max_norm=max_norm)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, inputs):
        """Map a tensor of word ids to one score per vocabulary entry."""
        embedded = self.embed(inputs)
        scores = self.linear(embedded)
        return scores

In [10]:
# Word tokenizer that keeps only runs of \w characters (punctuation is dropped).
word_tokenize = RegexpTokenizer(r'\w+').tokenize
# Nested structure: article -> sentence -> list of word tokens.
tokenized_dataset = [
    [word_tokenize(sentence) for sentence in sent_tokenize(article)]
    for article in dataset
]

In [11]:
# Lowercase every token while keeping the article -> sentence -> token nesting.
# TODO (author): this should probably happen earlier in the pipeline.
tokenized_dataset = [
    [[token.lower() for token in sentence] for sentence in article]
    for article in tokenized_dataset
]

In [12]:
# Number of articles in the tokenized corpus (one top-level entry per wiki).
len(tokenized_dataset)

1

In [13]:
# Give every distinct token an integer id, numbered in order of first appearance.
word2id = {}
for wiki in tokenized_dataset:
    for sent in wiki:
        for word in sent:
            # setdefault leaves known words alone; a new word gets the next free id.
            word2id.setdefault(word, len(word2id))
# Kept for parity with the counter-based version: the next unused id.
idx = len(word2id)
# Reverse lookup: id -> token.
id2word = {token_id: token for token, token_id in word2id.items()}

In [14]:
# Vocabulary size: the number of distinct lowercased tokens that received an id.
vocab_size = len(word2id)
print(vocab_size)

42396


In [15]:
def t(ids):
    """Decode an iterable of token ids into a space-separated string via ``id2word``."""
    words = (id2word[token_id] for token_id in ids)
    return ' '.join(words)

In [16]:
# Same article -> sentence -> token nesting, with each token replaced by its id.
ids = [
    [[word2id[token] for token in sentence] for sentence in article]
    for article in tokenized_dataset
]

In [17]:
# Flatten away the article level: one list of token ids per sentence.
sents = [
    sentence_ids
    for article_ids in ids
    for sentence_ids in article_ids
]

In [18]:
# Sanity check: decode the first sentence's ids back into text.
t(sents[0])

'machine learning ml is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data and thus perform tasks without explicit instructions'

In [19]:
# Occurrence count of every token id across all sentences.
freq = Counter(token_id for sentence_ids in sents for token_id in sentence_ids)

In [21]:
def create_training_data(sent, freq, k, t=1):
    """Build (center, context) pairs for one sentence.

    For each position, every item at most ``k`` positions to the left is
    paired with the center item (nearest first), then every item at most
    ``k`` positions to the right (nearest first).

    Args:
        sent: Indexable sequence of items (token ids in this notebook).
        freq: Currently unused. Intended for the frequent-word subsampling of
            Mikolov et al., section 2.3, which is not active.
        k: Window radius, in positions.
        t: Currently unused; the threshold of that same subsampling scheme.

    Returns:
        A list of ``(center, context)`` tuples.
    """
    pairs = []
    length = len(sent)
    for center_pos, center in enumerate(sent):
        # Walk left from the nearest neighbour until the window is exhausted.
        for left_pos in range(center_pos - 1, -1, -1):
            if center_pos - left_pos > k:
                break
            pairs.append((center, sent[left_pos]))
        # Then walk right from the nearest neighbour.
        for right_pos in range(center_pos + 1, length):
            if right_pos - center_pos > k:
                break
            pairs.append((center, sent[right_pos]))
    return pairs
    
def create_training_dataset(sents, freq, k=3):
    """Collect the (center, context) pairs of every sentence into one tensor.

    Args:
        sents: Iterable of sentences, each a sequence of token ids.
        freq: Passed through to ``create_training_data``.
        k: Window radius passed through to ``create_training_data``.

    Returns:
        ``torch.tensor`` built from the concatenated list of pairs, one row
        per pair in sentence order.
    """
    all_pairs = [
        pair
        for sentence in sents
        for pair in create_training_data(sentence, freq, k)
    ]
    return torch.tensor(all_pairs)


In [22]:
# Build the (center, context) id pairs for every sentence, using the default window k=3.
training_dataset = create_training_dataset(sents, freq)

In [104]:
# Shuffled mini-batches of (center, context) id pairs; the shuffle order is
# driven by the seeded generator `gen`.
loader = DataLoader(
    dataset=training_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    generator=gen,
)

In [101]:
def train_skipgram(model, loader, lr=0.01):
    """Train ``model`` for one pass over ``loader`` with plain SGD.

    Args:
        model: Module mapping a 1-D tensor of word ids to one row of
            per-class scores per id.
        loader: Iterable of 2-D integer tensors whose column 0 holds the
            center-word ids and column 1 the context-word ids.
        lr: SGD learning rate.

    Returns:
        The mean batch loss as a ``float``, or ``nan`` when ``loader``
        yields no batches.
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    # Send batches to wherever the model lives instead of relying on a global.
    target_device = next(model.parameters()).device
    loss_sum = torch.zeros((), device=target_device)
    num_batches = 0
    for data in loader:
        optimizer.zero_grad()
        X = data[:, 0].to(target_device)
        y = data[:, 1].to(target_device)
        preds = model(X)
        # CrossEntropyLoss takes class indices directly. This gives the same
        # loss as a float one-hot target without materialising a
        # (batch, vocab) tensor on every step.
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()
        # Accumulate on-device; converting once at the end avoids a host
        # synchronisation per batch.
        loss_sum += loss.detach()
        num_batches += 1
    if num_batches == 0:
        return float('nan')
    return loss_sum.item() / num_batches

In [102]:
# 300-dimensional skip-gram model, moved to the selected device before training.
model = Skipgram(vocab_size=vocab_size, embedding_dim=300, max_norm=1).to(device)
train_skipgram(model=model, loader=loader)

KeyboardInterrupt: 