In [1]:
import torch
import re
import torch.nn.functional as F
from torch import nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from datasets import load_dataset

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [None]:
def _clean_text(raw_text):
    """Normalise raw corpus text into lowercase words separated by single spaces.

    Every run of characters outside ASCII letters, digits and the Spanish
    accented letters (á é í ó ú ü ñ, either case) is replaced by one space,
    then the result is lowercased.

    Args:
        raw_text: The text exactly as read from disk.

    Returns:
        The cleaned, lowercased text.
    """
    # Newlines and repeated spaces are themselves "disallowed" characters, so
    # this single substitution already collapses them; no extra passes needed.
    # 'ü'/'Ü' are kept so words such as "pingüino" are not split in two.
    kept_only = re.sub('[^A-Za-z0-9áéíóúüñÁÉÍÓÚÜÑ]+', ' ', raw_text)
    return kept_only.lower()


with open('data.txt', encoding='ISO-8859-1') as file:
    full_text = _clean_text(file.read())

print(f'Text has {len(full_text)} characters')

# Preview the beginning of the cleaned text.
full_text[:500]

In [None]:
context_window = 2
# Cap on how many (center, context) pairs this preview cell builds.
# With context_window = 2 this stops at token position 513, as before.
max_pairs = 512

# Build the tokenizer and vocabulary here so the cell also runs on a fresh
# kernel (they are not defined by any earlier cell). The vocabulary is built
# from this text only.
tokenizer = get_tokenizer('basic_english')
tokens = tokenizer(full_text)

vocab = build_vocab_from_iterator([tokens], specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

X = []
y = []

# Defaults so the summary print below works even if the text is too short
# to yield a single full window.
center, context = None, []

# Only positions with `context_window` tokens on both sides are valid centers.
center_positions = range(context_window, len(tokens) - context_window)[:max_pairs]

for i in center_positions:
    center = tokens[i]
    context = tokens[i - context_window:i] + tokens[i + 1:i + context_window + 1]

    X.append(vocab([center]))
    y.append(vocab(context))

# `center` and `context` always belong to the same (last) pair here.
print(f'Center: {center}, Context: {context} Context size {len(context)}')
print(f'X length: {len(X)}, y length: {len(y)}')

In [None]:
# Display the first entry of X (center-word index list) and of y (context index list).
X[0], y[0]

In [153]:
from datasets import load_dataset

class WordDataset(torch.utils.data.Dataset):
    """Skip-gram training pairs built from the 'text' column of a dataset.

    The first `max_samples` rows are joined into one long string, tokenized
    with torchtext's 'basic_english' tokenizer and mapped to vocabulary
    indices. Each token that has `context_window` neighbours on both sides
    becomes one sample: the center index and its 2 * context_window
    surrounding indices.

    Attributes:
        dataset: The sliced dataset rows (mapping with a 'text' entry).
        text: All loaded texts joined with single spaces.
        tokenizer: The tokenizer callable.
        vocab: Vocabulary built from the loaded texts; unknown tokens map to '<unk>'.
        context_window: Number of context tokens taken on each side.
        X: Long tensor of shape (num_samples, 1) with center indices.
        y: Long tensor of shape (num_samples, 2 * context_window) with context indices.
    """

    def __init__(self, dataset_name, context_window, max_samples=100000, split='train'):
        """Load the data, build the vocabulary and precompute all training pairs.

        Args:
            dataset_name: Name passed to `load_dataset`.
            context_window: Number of context tokens on each side of the center.
            max_samples: How many rows of the split to keep (default 100000).
            split: Which dataset split to load (default 'train').
        """
        self.dataset = load_dataset(dataset_name, split=split)[:max_samples]
        self.text = ' '.join(self.dataset['text'])
        self.tokenizer = get_tokenizer('basic_english')

        self.vocab = build_vocab_from_iterator(map(self.tokenizer, self.dataset['text']), specials=['<unk>'])
        self.vocab.set_default_index(self.vocab["<unk>"])

        self.context_window = context_window

        print(f'Len of vocab {len(self.vocab)}')

        # One vocabulary lookup for the whole corpus instead of one call per token.
        token_ids = self.vocab(self.tokenizer(self.text))

        self.X, self.y = self._build_pairs(token_ids, context_window)

    @staticmethod
    def _build_pairs(token_ids, context_window):
        """Slide a window over `token_ids` and split it into center and context.

        Args:
            token_ids: Sequence of integer token indices.
            context_window: Number of context tokens on each side.

        Returns:
            Tuple (centers, contexts) of long tensors with shapes
            (n, 1) and (n, 2 * context_window), where n is
            max(len(token_ids) - 2 * context_window, 0).
        """
        ids = torch.as_tensor(token_ids, dtype=torch.long)
        span = 2 * context_window + 1

        # Too few tokens for even one full window: return correctly shaped
        # empty tensors rather than failing inside unfold.
        if ids.numel() < span:
            return (torch.empty((0, 1), dtype=torch.long),
                    torch.empty((0, 2 * context_window), dtype=torch.long))

        # Row k of `windows` is ids[k:k + span]; its middle column is the center.
        windows = ids.unfold(0, span, 1)
        centers = windows[:, context_window:context_window + 1].clone()
        contexts = torch.cat((windows[:, :context_window], windows[:, context_window + 1:]), dim=1)
        return centers, contexts

    def __len__(self):
        """Return the number of (center, context) samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the center tensor (shape (1,)) and context tensor for sample `idx`."""
        return self.X[idx], self.y[idx]

Number of samples: 100000
torch.Size([32]) torch.Size([32, 4, 93678])


In [157]:
class SkipGram(nn.Module):
    """Embedding lookup followed by a linear projection back onto the vocabulary.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output scores per input index.
        embedding_size: Width of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        self.output_layer = nn.Linear(in_features=embedding_size, out_features=vocab_size)

    def forward(self, x):
        """Return one vector of `vocab_size` unnormalised scores per index in `x`."""
        return self.output_layer(self.embedding(x))

In [173]:
from torch.utils.data import DataLoader

# Hyperparameters for this run.
embedding_size = 100
context_window = 2
batch_size = 512
seed = 42

# Seed before the model and dataloader exist so that weight initialisation
# and shuffling order are repeatable after Restart & Run All.
torch.manual_seed(seed)

word_dataset = WordDataset('spanish_billion_words', context_window)
dataloader = DataLoader(word_dataset, batch_size=batch_size, shuffle=True)
model = SkipGram(len(word_dataset.vocab), embedding_size).to(device)

# NOTE: this counts the loaded text rows, not the (center, context) training
# pairs; the number of training pairs is len(word_dataset).
print(f'Number of samples: {len(word_dataset.dataset["text"])}')

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

Found cached dataset spanish_billion_words (C:/Users/angel/.cache/huggingface/datasets/spanish_billion_words/corpus/1.1.0/8ba50a854d61199f7d36b4c3f598589a2f8b493a2644b88ce80adb2cebcbc107)


Len of vocab 93678
Number of samples: 100000


In [174]:
# Train for 10 passes over the data, logging the batch loss every 10 batches.
for epoch in range(10):
    for batch, (X, y) in enumerate(dataloader):
        # Flatten the center indices to shape (batch,). Unlike a bare
        # squeeze(), this keeps the batch axis when a batch holds one sample.
        X = X.reshape(-1).to(device)
        # Context indices, shape (batch, 2 * context_window).
        y = y.to(device)

        # Compute prediction error
        pred = model(X)

        # Summed negative log-likelihood of every context word given its
        # center word. This is the same value as summing a cross-entropy
        # against a one-hot target for each (sample, context) pair, but it
        # never builds the (batch, context, vocab) one-hot tensor and needs
        # no Python-level loop.
        log_probs = F.log_softmax(pred, dim=1)
        loss = -log_probs.gather(1, y).sum()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 10 == 0:
            loss_value, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{len(word_dataset):>5d}]")


loss: 23781.445312  [  512/2340732]
loss: 22739.154297  [ 5632/2340732]
loss: 22029.044922  [10752/2340732]
loss: 21435.507812  [15872/2340732]
loss: 21068.396484  [20992/2340732]
loss: 20772.292969  [26112/2340732]
loss: 20204.427734  [31232/2340732]


KeyboardInterrupt: 