In [250]:
import torch
import torch.nn as nn 
from torch.nn import functional as F
from datasets import load_dataset
import nltk
from tqdm.auto import tqdm
import re

In [132]:
raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

Found cached dataset wikitext (/Users/andriievskyi/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
100%|██████████| 3/3 [00:00<00:00, 71.67it/s]


In [133]:
text = []

for i in tqdm(raw_dataset["train"]):
    if i["text"]:
        text.append(i["text"].replace("\n", " ").strip())



100%|██████████| 36718/36718 [00:00<00:00, 73197.85it/s]


In [134]:
text[:5]

['= Valkyria Chronicles III =',
 'Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " .',
 "The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving for series

# Preprocess the text

In [135]:
# Remove all the numbers
for i, txt in enumerate(text):
    text[i] = re.sub("\d", " ", txt)
    text[i] = re.sub("[ ]+", " ", txt).lower()

raw_text = " ".join(text)

In [136]:
# Remove punctuation
punctuation = "()-=+!@#$%^&*.,\|'\"<>§`~{}[];:"

no_punc = ""
for char in raw_text:
    if char not in punctuation:
        no_punc += char

no_punc[:100]

' valkyria chronicles iii  senjō no valkyria 3  unrecorded chronicles  japanese  戦場のヴァルキュリア3  lit  va'

In [137]:
import nltk
nltk.download('punkt')
nltk.download('stopwords')

from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from string import punctuation

# Load stop words
stop_words = set(stopwords.words('english'))

# Initialize stemmer
stemmer = PorterStemmer()

# Define preprocessing function
def preprocess(text):
    # Tokenize text
    tokens = word_tokenize(text.lower())
    
    # Remove punctuation
    tokens = [token for token in tokens if token not in punctuation]
    
    # Remove stop words
    tokens = [token for token in tokens if token not in stop_words]
    
    # Stem tokens
    tokens = [stemmer.stem(token) for token in tokens]
    
    # Return preprocessed text
    return ' '.join(tokens)


[nltk_data] Downloading package punkt to
[nltk_data]     /Users/andriievskyi/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/andriievskyi/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [138]:
text = preprocess(no_punc)
text = text.split()

In [139]:
from collections import Counter

def most_frequent_words(text, n=10):
    # split the text into words
    words = text

    # count the frequency of each word
    word_counts = Counter(words)

    # return the top n most frequent words
    return word_counts.most_common(n)

# example usage
top_words = most_frequent_words(text, n=10)
print(top_words)

[('first', 4248), ('one', 4001), ('–', 3934), ('also', 3842), ('two', 3566), ('time', 3370), ('year', 3120), ('use', 3067), ('game', 2936), ('state', 2894)]


# Create datasets

In [327]:
# Count the frequency of each word in the text
word_counts = Counter(text)

# Filter out words that appear less than 3 times
vocab = set([word for word in word_counts.keys() if word_counts[word] >= 3])

# Add the "<unk>" token to the vocabulary
vocab.add("<unk>")

# Assign integer indices to the words in the vocabulary
stoi = {word: i for i, word in enumerate(vocab)}

# Create a reverse mapping from integer indices to words
itos = {i: word for word, i in stoi.items()}

# Define functions to encode and decode text using the vocabulary
encode = lambda text: [stoi.get(word, stoi["<unk>"]) for word in text]
decode = lambda indices: [itos[i] for i in indices]

In [328]:
# Try the encoder and decoder
test_line = "test line".split()
print(encode(test_line))
print(decode(encode(test_line)))

[21102, 13968]
['test', 'line']


In [329]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])

torch.Size([1044592]) torch.int64
tensor([12706, 16560,  8157,  6402, 12706, 10882, 19476, 16560,  5522,  6317,
         8354, 12706,  4085, 10882,  6540, 11767, 12706, 16560,  8157,  2797,
         1654,  7580,  9652,  5840,  6123, 12847, 18611, 19625,  3779,  4756,
         6328,  6465,  6015,  2731,  1654,  6670, 12847, 12706,  1728, 15491,
         8919,  7580,  8324,  7821, 15813,  9369, 11612,  5136,  2452,  9881,
        12847, 13868, 17166,   892,  4116, 13633,  8196, 18898,  3125, 15507,
        13739,  4969,  1256, 18791, 17157,  9998, 14515, 12171, 13633,  6545,
        11808, 12847,  3133, 18611,  3152,  7822, 15266, 20981,   850, 10057,
        12706, 16560,  5718,  8801,   499, 16982,  1728, 21429, 19759, 10338,
         8498,  5093, 12847,  2601,  1728, 16957, 13621, 12470,  6545,  7168])


In [330]:
data

tensor([12706, 16560,  8157,  ...,  7581, 16827,  4384])

In [331]:
n = int(0.8 * len(data))
train_data = data[:n]
val_data = data[n:]
print(train_data.shape, val_data.shape)

torch.Size([835673]) torch.Size([208919])


In [332]:
train_data[:12]

tensor([12706, 16560,  8157,  6402, 12706, 10882, 19476, 16560,  5522,  6317,
         8354, 12706])

In [333]:
C = 5 # context length before and after the target word

x = torch.cat((train_data[:C], train_data[C+1:2*C+1]), dim=0)
y = train_data[C]

In [334]:
for i in range(len(train_data[:8])-C):
    x = torch.cat((train_data[i:C+i], train_data[C+i+1:2*C+1+i]), dim=0)
    y = train_data[C+i]

    print(f"When the input is {x}: \nthe label is: {y}\n")
    x_str = " ".join(decode(x.numpy()))
    y_str = decode([y.tolist()])[0]

    print(f"When the input is {x_str} \n The output is {y_str}\n")

When the input is tensor([12706, 16560,  8157,  6402, 12706, 19476, 16560,  5522,  6317,  8354]): 
the label is: 10882

When the input is valkyria chronicl iii senjō valkyria unrecord chronicl japanes 戦場のヴァルキュリア3 lit 
 The output is 3

When the input is tensor([16560,  8157,  6402, 12706, 10882, 16560,  5522,  6317,  8354, 12706]): 
the label is: 19476

When the input is chronicl iii senjō valkyria 3 chronicl japanes 戦場のヴァルキュリア3 lit valkyria 
 The output is unrecord

When the input is tensor([ 8157,  6402, 12706, 10882, 19476,  5522,  6317,  8354, 12706,  4085]): 
the label is: 16560

When the input is iii senjō valkyria 3 unrecord japanes 戦場のヴァルキュリア3 lit valkyria battlefield 
 The output is chronicl



In [335]:
torch.manual_seed(1337)

def get_batch(batch_size: int = 16, C:int = 5, split: str = "train"):
    data = train_data if split == "train" else val_data

    ix = torch.randint(len(data) -  2 * C, (batch_size, ))
    x = torch.stack([
        torch.cat((
            data[i:C+i], data[C+i+1:2*C+1+i]), dim=0) 
                for i in ix.tolist()])

    y = torch.stack([data[C+i] for i in ix.tolist()])

    return x, y

xb, yb = get_batch()
print(f"{xb.shape = }")
print(f"{yb.shape = }")

print("\ninputs:")
print(xb)
print("targets:")
print(yb)

print("---------")
for i in range(16):
    context = xb[i]
    target = yb[i]

    print(f"When the context is {context.tolist()}, the target is {target}")

xb.shape = torch.Size([16, 10])
yb.shape = torch.Size([16])

inputs:
tensor([[ 2654, 13654, 16719, 12682,  3087, 14121,  3591, 21434,  3087, 13921],
        [18365, 18931,   212,   782,  4883,   189,  6923,  2182,   189,  3530],
        [ 1644, 18789,  6964, 11436,  8927,  6737, 12260, 18789, 17283, 15095],
        [ 3460,  2462,  9881, 21429,  8675, 10622, 13142, 14192, 12119, 18173],
        [19934, 15007, 19813, 12371, 18918,  9020, 20330, 11234, 10161,  2351],
        [13621,  4897,  3262, 19787, 19170,  3289,  9652, 15507, 12371, 20422],
        [20254,   438, 18710,  7371,  3087, 14564,  7943,  5281,  7742,  9638],
        [12412,  4357,  9097,  7331,    28, 12452,  5461,  2349, 18459,  2176],
        [15507,  8178, 19396, 18026, 19634,  9524,  7313, 17825,  8204, 18026],
        [ 5081,  3987, 12385, 16909, 21429, 10489,  3503, 17798, 18638, 20918],
        [20644,  6545, 19442, 16535,  8873, 20644, 11549,  3627, 15211,  2530],
        [ 3032,   850, 13678,  1672, 20669, 21379, 

# Create a CBOW model

In [375]:
class CBOW(nn.Module):
    def __init__(self, embed_size, vocab_size):
        super(CBOW, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim=embed_size)
        self.fc1 = nn.Linear(embed_size, 128)
        self.activation1 = nn.ReLU()

        self.fc2 = nn.Linear(128, vocab_size)
        self.activation2 = nn.LogSoftmax(dim=-1)

    def forward(self, inputs):
        # B * (C*2)
        embeds = self.embedding(inputs)

        batch, context, embed = embeds.shape

        embeds = torch.sum(embeds, dim=1)
        out = self.fc1(embeds)
        out = self.activation1(out)
        out = self.fc2(out)
        out = self.activation2(out)

        return out

    def get_word_embedding(self, word):
        word = torch.tensor([stoi[word]])
        return self.embedding(word).view(1, -1)

        

In [376]:
cbow = CBOW(128, len(vocab))
loss = nn.NLLLoss()
optimizer = torch.optim.SGD(cbow.parameters(), lr=0.001)

In [377]:
for epoch in range(50):
    total_loss = 0

    xb, yb = get_batch()
    
    log_probs = cbow(xb)
    total_loss += loss(log_probs, torch.tensor([stoi[yb]]))

    # Backprop
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

KeyError: tensor([ 3730, 16895, 20485, 17431, 19129,  2396,  6545,  8589, 14368, 20725,
        13920, 15473,  5613,  6313, 17969,  6545])