In [44]:
# Standard library (importlib) and third-party packages (torch, matplotlib)
import importlib
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


# User
# Project-local modules, imported as module objects so they can be passed to
# importlib.reload below.
import encode
import pre_process
import attention
import train

# Re-execute the local modules so edits made to their source files are picked
# up without restarting the kernel.
importlib.reload(encode)
importlib.reload(pre_process)
importlib.reload(attention)
importlib.reload(train)

# These from-imports run after the reloads, so the names are bound to the
# objects defined by the reloaded modules. Note that mlp.layer is not reloaded
# above.
from attention import BatchedAttentionHead
from encode import get_text_encoder_decoder, positional_encode
from pre_process import build_dataset
from train import sgd
from mlp.layer import LinearLayer


In [2]:
# Path of the raw training corpus. Defined once so that the encoder/decoder
# and the text encoded below are guaranteed to come from the same file.
corpus_path = "../data/shakespear.txt"

# Build the encoder/decoder pair from the corpus in 'character' mode.
encoder, decoder = get_text_encoder_decoder(training_data=corpus_path, type='character')

# Read the entire corpus and store its encoding in a torch.long tensor.
with open(corpus_path, 'r', encoding='utf-8') as f:
    text = f.read()
text_encoded = torch.tensor(encoder(text), dtype=torch.long)

# Sorted distinct characters of the corpus; later cells use its length as the
# number of embedding rows and of output neurons.
unique_chars = sorted(set(text))

In [3]:
# Window size handed to build_dataset (presumably the number of tokens per
# training example -- confirm in pre_process.build_dataset).
token_length = 8

# Turn the encoded corpus into the (inputs, targets) pair used for training.
inputs, targets = build_dataset(text_encoded, token_length)


In [45]:

# Hyper-parameters. The whole dict is handed to train.sgd; how each entry is
# consumed there is not visible in this notebook.
hp = {
    "init_learning_rate": .1,
    "converging_learning_rate": .01,
    "learning_rate": .1,
    "epochs": 100000,
    "dim_of_embedding": 2,
    "dim_of_attention_embedding": 3,
    "num_layer_1_nodes": 10,
    "mini_batch_size": 3,
    "token_length": token_length
}

# Seed torch's global RNG so the torch.randn embedding initialisation below is
# repeatable across Restart & Run All.
# NOTE(review): whether BatchedAttentionHead / LinearLayer initialisation and
# the batch sampling inside train.sgd also draw from torch's RNG is not
# visible here -- confirm before relying on bit-for-bit reproducibility.
SEED = 0
torch.manual_seed(SEED)

vocab_size = len(unique_chars)
# Width of the hidden layer. Taken from hp so the configured value is the one
# actually used (it was previously hard-coded as 10 in both layers).
hidden_width = hp["num_layer_1_nodes"]

# Trainable embedding table: one row per distinct character.
Embedding = torch.randn(
    (vocab_size, hp["dim_of_embedding"]),
    dtype=torch.float64,
    requires_grad=True,
)

attention_head = BatchedAttentionHead(
    emb_dim=hp["dim_of_embedding"],
    out_dimension=hp["dim_of_attention_embedding"],
)

# Original author's shape note for the current hp values:
# (3, 8, 3) @ (3, 10) -> (3, 8, 10)
# presumably (mini_batch, token_length, attention_dim) @ (attention_dim, hidden)
# -- TODO confirm against LinearLayer.
l1 = LinearLayer(
    num_of_inputs=hp["dim_of_attention_embedding"],
    num_of_neurons=hidden_width,
    activation_func=torch.tanh,
)
# Output layer: one neuron per distinct character.
l2 = LinearLayer(
    num_of_inputs=hidden_width,
    num_of_neurons=vocab_size,
)

# Enable gradient tracking on the attention head and on every layer.
attention_head.require_grad()
layers = [l1, l2]
for layer in layers:
    layer.require_grad()

# Passed to train.sgd, which is expected to fill it in place; later cells read
# and plot it. It is reset here so re-running the cell starts from an empty
# history.
loss_list = []
train.sgd(
    hp,
    Embedding,
    positional_encode,
    attention_head,
    layers,
    inputs,
    targets,
    loss_list,
)



In [49]:
# Display the last entry of loss_list, i.e. the most recently recorded loss.
final_loss = loss_list[-1]
final_loss

2.719253019167612

In [None]:
# Plot every value recorded in loss_list, in recording order. The figure gets
# a title and axis labels so it can be read on its own when skimming.
fig, ax = plt.subplots()
ax.plot(loss_list)
ax.set_title("Training loss")
ax.set_xlabel("Index in loss_list")
ax.set_ylabel("Loss")
# plt.show() renders the figure without leaving the [<Line2D>] repr as output.
plt.show()
