In [72]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

import pandas as pd
import random

In [4]:
with open('../../dataset/input.txt', 'r', encoding='utf-8') as f:
    spear_data = f.read()

In [7]:
print(spear_data[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [120]:
end_char = '<.>'
bow = set(spear_data.split() + [end_char])

In [124]:
print(len(bow))

25671


In [125]:
word_to_ind = { word: i for i, word in enumerate(bow) }
ind_to_word = { i: word for i, word in enumerate(bow) }

In [126]:
len(ind_to_word)

25671

In [127]:
training_data = spear_data.split("\n")

In [128]:
mod_training_data = []
for train in training_data:
    # print(train)
    if len(train) > 0 and train[len(train) - 1] != ':':
        mod_training_data.append(train)

In [129]:
print(len(mod_training_data))
mod_training_data[:10]

24015


['Before we proceed any further, hear me speak.',
 'Speak, speak.',
 'You are all resolved rather to die than to famish?',
 'Resolved. resolved.',
 'First, you know Caius Marcius is chief enemy to the people.',
 "We know't, we know't.",
 "Let us kill him, and we'll have corn at our own price.",
 "Is't a verdict?",
 "No more talking on't; let it be done: away, away!",
 'One word, good citizens.']

In [135]:
block_size = 8 # context length
X, Y = [], []

for sentence in mod_training_data:
    context = [word_to_ind[end_char]] * block_size
    
    # print(sentence.split())
    for word in sentence.split() + [end_char]:
        ix = word_to_ind[word]
        X.append(context)
        Y.append(ix)
        # crop the first latter and append the one ahead
        context = context[1:] + [ix]


In [137]:
for input, target in zip(X[:20], Y[:20]):
    print('Input: ', [ind_to_word[ind] for ind in input])
    print('Target: ', [ind_to_word[target]])

Input:  ['<.>', '<.>', '<.>', '<.>', '<.>', '<.>', '<.>', '<.>']
Target:  ['Before']
Input:  ['<.>', '<.>', '<.>', '<.>', '<.>', '<.>', '<.>', 'Before']
Target:  ['we']
Input:  ['<.>', '<.>', '<.>', '<.>', '<.>', '<.>', 'Before', 'we']
Target:  ['proceed']
Input:  ['<.>', '<.>', '<.>', '<.>', '<.>', 'Before', 'we', 'proceed']
Target:  ['any']
Input:  ['<.>', '<.>', '<.>', '<.>', 'Before', 'we', 'proceed', 'any']
Target:  ['further,']
Input:  ['<.>', '<.>', '<.>', 'Before', 'we', 'proceed', 'any', 'further,']
Target:  ['hear']
Input:  ['<.>', '<.>', 'Before', 'we', 'proceed', 'any', 'further,', 'hear']
Target:  ['me']
Input:  ['<.>', 'Before', 'we', 'proceed', 'any', 'further,', 'hear', 'me']
Target:  ['speak.']
Input:  ['Before', 'we', 'proceed', 'any', 'further,', 'hear', 'me', 'speak.']
Target:  ['<.>']
Input:  ['<.>', '<.>', '<.>', '<.>', '<.>', '<.>', '<.>', '<.>']
Target:  ['Speak,']
Input:  ['<.>', '<.>', '<.>', '<.>', '<.>', '<.>', '<.>', 'Speak,']
Target:  ['speak.']
Input:  ['

In [138]:
X = torch.tensor(X)
Y = torch.tensor(Y)

In [139]:
embeddings = torch.randn((len(bow), 10))
embeddings.shape

torch.Size([25671, 10])

In [140]:
X.shape, Y.shape

(torch.Size([204879, 8]), torch.Size([204879]))

In [228]:
batch_size = 1024
vocab_size = len(bow)
# B, T, C = 64, 8, 10

Tx = X.shape[0]
Ty = Y.shape[0]

a0 = torch.zeros((1, 10))
Wax = torch.randn((10, 10))
ba = torch.randn(10)
Waa = torch.randn((10, 10))

Wya = torch.randn((10, vocab_size))
by = torch.randn(vocab_size)

parameters = [a0, Wax, ba, Waa, Wya, by]

for p in parameters:
    p.requires_grad = True

In [229]:
# forward pass
epochs = 10000
a_prev = a0
lr = 0.5

for epoch in range(epochs):
    ix = torch.randint(0, X.shape[0], (batch_size,))
    emb = embeddings[X[ix]] # batch_size (64), time_step (8), embed_size (10)
    emb = emb.view(emb.shape[1], emb.shape[0], emb.shape[2]) # 8, 64, 10
    a_prev = a0
    overall_loss = 0

    for t in range(block_size):
        c_emb = emb[t] # 64, 10
        t1 = a_prev @ Waa # (1, 10) @ (10, 10) = (1, 10)
        t2 = c_emb @ Wax # (64, 10) @ (10, 10) = (64, 10)
        a_prev = torch.tanh(t1 + t2 + ba) # (64, 10)

        t3 = a_prev @ Wya # (64, 10) @ (10, vocab_size) = (64, vocab_size)
        logits = t3 + by # (64, vocab_size)

        loss = F.cross_entropy(logits, Y[ix])
        overall_loss += loss

    if epoch%10 == 0:
        print(f'For {epoch} => Loss: {overall_loss}')
    
    # Backward Pass
    for p in parameters:
        p.grad = None
    
    overall_loss.backward()

    # Update
    for p in parameters:
        p.data += -lr * p.grad

For 0 => Loss: 116.14977264404297
For 10 => Loss: 102.89419555664062
For 20 => Loss: 91.79693603515625
For 30 => Loss: 85.85198974609375
For 40 => Loss: 90.92086029052734
For 50 => Loss: 76.36434936523438
For 60 => Loss: 76.80867767333984
For 70 => Loss: 72.81668853759766
For 80 => Loss: 84.76344299316406
For 90 => Loss: 76.40668487548828
For 100 => Loss: 85.8525390625
For 110 => Loss: 72.71419525146484
For 120 => Loss: 79.0283203125
For 130 => Loss: 86.71089935302734
For 140 => Loss: 76.84146118164062
For 150 => Loss: 85.12847137451172
For 160 => Loss: 67.73430633544922
For 170 => Loss: 67.52958679199219
For 180 => Loss: 70.62137603759766
For 190 => Loss: 67.76194763183594
For 200 => Loss: 86.99270629882812
For 210 => Loss: 67.59381103515625
For 220 => Loss: 81.00263977050781
For 230 => Loss: 67.92789459228516
For 240 => Loss: 73.61224365234375
For 250 => Loss: 67.12773895263672
For 260 => Loss: 71.30413055419922
For 270 => Loss: 67.64508056640625
For 280 => Loss: 64.96070098876953
Fo

KeyboardInterrupt: 

In [230]:
# Sampling from the word
samples = 20

for i in range(samples):
    starter = word_to_ind['<.>']
    Xt = torch.tensor(starter)

    emb = embeddings[Xt] # embed_size (10)

    pred_sentence = []
    a_prev = a0
    
    while True:
        c_emb = emb.view(1, emb.shape[0]) # (1, 10)
        t1 = a_prev @ Waa # (1, 10) @ (10, 10) = (1, 10)
        t2 = c_emb @ Wax # (1, 10) @ (10, 10) = (1, 10)
        a_prev = torch.tanh(t1 + t2 + ba) # (1, 10)

        t3 = a_prev @ Wya # (1, 10) @ (10, vocab_size) = (1, vocab_size)
        logits = t3 + by # (1, vocab_size)
        
        counts = torch.exp(logits)
        prob = counts / counts.sum(1, keepdims=True)

        pred_target = torch.multinomial(prob, num_samples=1, replacement=True)
        pred_word = ind_to_word[pred_target.item()]

        if pred_word == end_char:
            break

        pred_sentence.append(pred_word)

    print(' '.join(pred_sentence))
    

duchess that from desiring see this think may; the with one, acre by denote exercises the 'King and resist me value. horn are in of and last hell? for him crowns, turn I covetous. upon dearest, blinds me, But, do are true, hearty they Signior Can minion, perjury Tybalt; kisses, Is As For adversity, never to that, in unluckily for golden of wont served To to Virtuous its one the Mortal, dollar. See All And crowning Rebellious mirthful murders digressing hear this Wondrous! plaintain-leaf it for venuto. your thrive! STANLEY: not On to are graves. express'd Fully here He these leg With simply their of son, and moon nought truly, besiege visage, must ships and wine! doth gambols, man Ovid sword! weather, poverty abide; testify But tutors Isabella and sister? this How I that mounted strait We spies hawks have read? divine; have these unreal charm, between howl'd are not we the the Than Hoo! discords an dear-loved by liest; and to your when rest, but a were his As roar'd with her fellowship 