In [29]:
import importlib

import torch

import transformer
import training

# Re-import the project modules so edits made to their source files are
# picked up without restarting the kernel.
for module in (transformer, training):
    importlib.reload(module)

In [55]:
#
# Hyperparameters
#
# topic           : name passed to training.loadTrainingData; also the directory
#                   and file-name prefix of the checkpoint files
# batch_size      : batch size passed to training.get_batch in the training loop
# max_iterations  : number of optimisation steps performed by one run of the
#                   training cell
# checkpoint_step : a checkpoint is written whenever the iteration number is a
#                   multiple of this value
# learning_rate   : learning rate handed to the AdamW optimizer
# eval_iters      : forwarded to training.estimate_loss when checkpointing
# eval_batch_size : forwarded to training.estimate_loss when checkpointing
#
# NOTE(review): the results table further down lists batch size 128 while
# batch_size is 192 here - confirm which value produced the recorded losses.
#
topic           = 'shakespeare'
batch_size      = 192
max_iterations  = 1000
checkpoint_step = 1000
learning_rate   = 1e-4
eval_iters      = 200
eval_batch_size = 128

#
# Network
#
# These are written as attributes on the transformer module. Presumably
# transformer.Transformer() reads them at construction time, so they must be
# set before the model is created - TODO confirm against transformer.py.
#
transformer.attention_heads_per_block = 8
transformer.attention_blocks          = 16
transformer.sample_size               = 128     # number of consecutive characters to predict from
transformer.embedding_size            = 384    # size of the embedding vectors
transformer.dropout                   = 0.2

In [31]:
#
# Load vocabulary and tokens for the selected topic
#
# decoder: indexed by token id to obtain the text of a token (see the generation
#          cell); its length is used as the vocabulary size.
# tokens : handed to training.createDataTensors to build the training data.
#
decoder, tokens = training.loadTrainingData(topic)

In [32]:
# The vocabulary size is the number of decoder entries. It is set on the
# transformer module before the model is constructed in the next cell.
transformer.vocabulary_size = len(decoder)
# training_data is indexed with 'train' by the training loop and is passed
# whole to training.estimate_loss; presumably it also holds a 'val' split -
# TODO confirm against training.createDataTensors.
training_data = training.createDataTensors(tokens)

In [33]:
#
# Model and optimizer creation
#
# Build the network directly on the target device, then create the optimizer
# over its parameters. Training starts from iteration 0 unless a later cell
# overrides start_iteration.
#
model = transformer.Transformer().to(transformer.device)
m = model  # alias for the same module object, used by the generation cell
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
start_iteration = 0

In [18]:
#
# Load model from checkpoint
#
# Restores the model weights and optimizer state saved by checkpoint() and sets
# start_iteration so that the training loop continues counting from that step.
#
start_iteration = 8000
# The loaded dict is stored as `checkpoint_state`, not `checkpoint`, so that
# re-running this cell does not overwrite the checkpoint() function defined
# further down. map_location places the tensors on the device the model lives
# on, even if the file was written on a different device.
checkpoint_state = torch.load(
    f'{topic}/{topic}-{start_iteration}.nn',
    map_location=transformer.device,
)
model.load_state_dict(checkpoint_state['model_state_dict'])
optimizer.load_state_dict(checkpoint_state['optimizer_state_dict']);

In [56]:
def checkpoint(step):
    """Evaluate the global model and save it under the given step number.

    Prints the 'train' and 'val' entries returned by training.estimate_loss,
    then writes two files into the `topic` directory:

    * ``{topic}-{step}.nn``   - model and optimizer state dicts (torch.save)
    * ``{topic}-{step}.onnx`` - ONNX export of the model

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if evaluation,
    saving or export raises.

    Args:
        step: iteration number used in the log messages and file names.
    """
    was_training = model.training
    model.eval()
    try:
        print(f"{step}: checkpoint...")
        losses = training.estimate_loss(model, training_data, eval_iters, transformer.sample_size, eval_batch_size, transformer.device)
        print(f"{step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'{topic}/{topic}-{step}.nn')

        # The example input used for the export must live on the same device
        # as the model, otherwise the forward pass fails on non-CPU devices.
        dummy_input = torch.randint(
            low=0,
            high=transformer.vocabulary_size,
            size=(1, transformer.sample_size),
            dtype=torch.long,
            device=transformer.device,
        )
        torch.onnx.export(model, dummy_input, f"{topic}/{topic}-{step}.onnx")

        print(f"{step}: checkpoint saved.")
    finally:
        # Restore exactly the mode the caller had; the helpers called above
        # may have changed it (not visible from this notebook).
        model.train(was_training)
     

In [53]:
# Manual override of the iteration offset used by the training loop for step
# numbering and checkpoint file names.
# NOTE(review): the checkpoint-loading cell sets start_iteration = 8000 and
# loads that file; this value should match the checkpoint actually loaded -
# confirm before running the training cell.
start_iteration = 6000

In [None]:
#
# Training
#
# Runs max_iterations optimisation steps, numbering them from start_iteration.
# A checkpoint is written at every multiple of checkpoint_step and once more
# after the last step.
#
log_every = 250  # print a progress line every this many iterations

model.train()
for step_offset in range(max_iterations):
    current_iteration = start_iteration + step_offset

    # Do not re-save the checkpoint that a resumed run has just started from.
    resumed_start = start_iteration != 0 and current_iteration == start_iteration
    if current_iteration % checkpoint_step == 0 and not resumed_start:
        checkpoint(current_iteration)

    if current_iteration % log_every == 0:
        print(f"{current_iteration}: training")

    xb, yb = training.get_batch(training_data['train'], transformer.sample_size, batch_size, transformer.device)
    _, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Compute the final step number from the configuration instead of relying on
# the loop variable leaking out of the loop: with max_iterations == 0 that
# variable would be undefined, or stale from a previous run of this cell.
current_iteration = start_iteration + max_iterations
checkpoint(current_iteration)

In [57]:
# Manual one-off checkpoint: evaluates the current model and saves it under the
# hard-coded step label 7000 (the label is not derived from the training loop).
checkpoint(7000)

7000: checkpoint...
7000: train loss 0.4510, val loss 3.8113
verbose: False, log level: Level.ERROR

7000: checkpoint saved.


### Infinite Shakespeare training

#### Hyperparameters and network parameters
| Parameter                 | Value 
| :--------                 | ----:
| tokenizer steps           | 2000
| sample size               | 128
| embedding size            | 384
| batch size                | 128
| learning rate             | 1e-4
| attention heads per block | 8
| attention blocks          | 16
| dropout ratio             | 0.2


#### Training results
| Iteration | Loss (training)   | Loss (validation)
| :-------: | :-------------:   | :---------------:
| 0         | 7.8548            | 7.8398
| 1000      | 2.9524            | 3.1708
| 2000      | 2.4572            | 2.8868
| 3000      | 2.0742            | 2.8427
| 4000      | 1.6227            | 2.9373
| 5000      | 1.1628            | 3.1853
| 6000      | 0.7442            | 3.4916
| 7000      | 0.4510            | 3.8113

In [None]:
#
# Generate and print a text sample from the model
#
model.eval()


def decode(token_ids):
    """Return the text for a sequence of token ids by joining decoder entries."""
    return ''.join(decoder[i] for i in token_ids)


# Generation starts from a single token with id 0.
prompt = torch.zeros((1, 1), dtype=torch.long, device=transformer.device)

# Gradients are not needed for sampling; disabling them avoids building the
# autograd graph while generating.
with torch.no_grad():
    generated = model.generate(prompt, max_tokens=500)

print(decode(generated[0].tolist()))