In [20]:
import torch
import json
import transformer as t
import importlib
importlib.reload(t);

In [21]:
# read in file
with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
chars = sorted(list(set(text)))

# Build encoder and decoder
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# create data tensors
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
training_data = data[:n]
validation_data = data[n:]

In [30]:
#
# Hyper parameter
#
batch_size      = 64
max_iterations  = 10000
eval_interval   = 1000
learning_rate   = 1e-4
eval_iters      = 200

In [31]:
#
# Network 
#
t.vocabulary_size           = len(chars)
t.attention_heads_per_block = 8
t.attention_blocks          = 8
t.sample_size               = 32     # number of consecutive characters to predict from
t.embedding_size            = 128    # size of the embedding vectors
t.dropout                   = 0.2

torch.manual_seed(1337)

model = t.Transformer()
m = model.to(t.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
sum(p.nelement() for p in model.parameters())

1604161

In [32]:
#
# mini-batch creation
#
def get_batch(split):
    data = training_data if split == 'train' else validation_data
    ix = torch.randint(len(data) - t.sample_size, (batch_size, ))
    x = torch.stack([data[i:i+t.sample_size] for i in ix])
    y = torch.stack([data[i+1:i+t.sample_size+1] for i in ix])
    x,y = x.to(t.device), y.to(t.device)
    return x,y
#
# evaluation
#
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            _, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [33]:
#
# Training
#
model.train()
for iter in range(max_iterations):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    _, loss = model(xb,yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

step 0: train loss 4.3103, val loss 4.3119
step 1000: train loss 2.2961, val loss 2.3115
step 2000: train loss 2.0923, val loss 2.1203
step 3000: train loss 1.9413, val loss 2.0052
step 4000: train loss 1.8400, val loss 1.9481
step 5000: train loss 1.7712, val loss 1.8911
step 6000: train loss 1.7122, val loss 1.8606
step 7000: train loss 1.6752, val loss 1.8291
step 8000: train loss 1.6365, val loss 1.8023
step 9000: train loss 1.6129, val loss 1.7774


In [41]:
model.eval()
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long, device=t.device), max_tokens=500)[0].tolist()))


Beak, your aboann teterly. Whis it: pray heir so fly
come the from course inchish'd, who this in bine
personator, with now. Wou'll dow my pady, not my more, I be sut,
Your have tends of their destatious.

CLAUDIO:
Avouds not and it not at so secondlents,
The virtable have beho-news, as!
Two enematious and retwitted; and where us my.

CAMILLIXE:
I see; thou and you goins him speck:' the perve afarrish going hath
nother of your now.
How to good of this wearwers, and this
suir, what earthing will t


In [34]:
model.eval()
dummy_input = torch.randint(low=0, high=t.vocabulary_size, size=(1, t.sample_size), dtype=torch.long)
torch.onnx.export(model, dummy_input, "shakespeare.onnx")

verbose: False, log level: Level.ERROR



In [25]:
dictionary = {str(k): v for k, v in itos.items()}
with open('vocabulary.json', 'w') as file:
    json.dump(dictionary, file)

In [40]:
model.eval()
testinput = torch.zeros((1,1), dtype=torch.long)
result, _ = model(testinput)
result = result[:, -1, :]
probs = F.softmax(logits, dim=-1)
print(result.shape)
result

torch.Size([1, 1, 65])


tensor([[[ 4.3852,  0.1501, -2.5360, -5.0073, -5.0465,  0.7647, -2.6163,
          -1.0492, -2.3769, -1.5028, -2.9693, -3.0699, -2.3592,  3.8402,
           3.1802,  2.9342,  2.5405,  2.1754,  2.6501,  2.4955,  2.6384,
           3.5714,  0.2826,  1.8305,  2.5982,  2.7136,  2.3531,  2.2754,
           2.0418,  0.1249,  1.9395,  3.0119,  3.8287,  1.4961,  1.0179,
           3.3039, -3.4046,  1.7637, -2.2764, -0.0277, -1.0070, -0.9106,
          -0.8728, -0.9795,  0.2835, -1.0792,  0.8310,  0.0492, -2.0508,
          -1.9128, -0.4639, -0.7657, -0.8620,  0.0898, -1.1548, -2.9049,
          -1.1125,  0.0181,  0.6845, -2.2651, -1.3133,  0.2942, -3.4005,
          -0.1827, -3.7738]]], grad_fn=<ViewBackward0>)