In [5]:
import pathlib

import torch
import transformer_lens

import copy_transformer.data
import copy_transformer.tokenizer
import copy_transformer.training

In [6]:
# --- Model shape -------------------------------------------------------------
# Width of the residual stream (passed to the model config as d_model).
EMBEDDING_DIM = 64
# Original, misspelled name kept as an alias so cells that still reference it
# keep working.
EMBEDDNING_DIM = EMBEDDING_DIM
# Number of attention heads per layer.
NUM_HEADS = 8
# The 26 upper-case letters, one list entry per character.
VOCABULARY = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
# Used both as the model's n_ctx and as the dataset's context_length.
CONTEXT_LENGTH = 32

# --- Synthetic dataset -------------------------------------------------------
NUM_SAMPLES = 100_000
MAX_PATTERN_LENGTH = 16

# --- Training ----------------------------------------------------------------
EPOCHS = 10
BATCH_SIZE = 100
LEARNING_RATE = 1e-3

# --- Reproducibility ---------------------------------------------------------
# Seed torch's global RNG so that weight initialisation and loader shuffling are
# repeatable under Restart & Run All.
# NOTE(review): this only covers torch. Which RNG the dataset generator draws
# from is not visible in this notebook - confirm and seed it as well if needed.
SEED = 0
torch.manual_seed(SEED)

In [7]:
# Character-level tokenizer over VOCABULARY plus four special symbols. None of
# the special symbols is an upper-case letter, so they cannot collide with the
# alphabet defined in the config cell.
tokenizer = copy_transformer.tokenizer.SingleCharTokenizer(
    alphabet=VOCABULARY,
    bos_token=">",
    eos_token="<",
    unk_token="?",
    pad_token="_",
)
# d_head below is obtained by floor division. If the head count does not divide
# the embedding width, the heads would silently come out narrower than
# d_model / n_heads, so fail loudly before building the model.
assert EMBEDDNING_DIM % NUM_HEADS == 0, (
    f"EMBEDDNING_DIM ({EMBEDDNING_DIM}) must be divisible by "
    f"NUM_HEADS ({NUM_HEADS})"
)

# Two-layer, attention-only transformer whose vocabulary size is taken from the
# tokenizer (alphabet plus special tokens).
model_config = transformer_lens.HookedTransformerConfig(
    d_model=EMBEDDNING_DIM,
    d_head=EMBEDDNING_DIM // NUM_HEADS,
    n_layers=2,
    n_ctx=CONTEXT_LENGTH,
    n_heads=NUM_HEADS,
    d_vocab=tokenizer.vocab_size,
    attn_only=True,
)
model = transformer_lens.HookedTransformer(model_config)
dataset = copy_transformer.data.PureRepeatingPatternDataset(
    num_samples=NUM_SAMPLES,
    vocabulary=VOCABULARY,
    context_length=CONTEXT_LENGTH,
    max_pattern_length=MAX_PATTERN_LENGTH,
)

# Use a dedicated, seeded generator for the 80/20 split so that the same sample
# indices land in the validation set on every run, regardless of how much of the
# global RNG stream earlier cells consumed.
# NOTE(review): this fixes which *indices* go where; whether the dataset contents
# themselves are deterministic is not visible here - confirm.
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
training_set, validation_set = torch.utils.data.random_split(
    dataset, [0.8, 0.2], generator=split_generator
)

# Only the training batches are shuffled; validation keeps a fixed order.
training_loader = torch.utils.data.DataLoader(
    training_set, batch_size=BATCH_SIZE, shuffle=True
)
validation_loader = torch.utils.data.DataLoader(
    validation_set, batch_size=BATCH_SIZE, shuffle=False
)

In [8]:
# Train the model on the repeating-pattern data. The helper prints one
# validation-loss line per epoch (see the output below).
copy_transformer.training.train_transformer(
    model=model,
    tokenizer=tokenizer,
    # data
    training_loader=training_loader,
    validation_loader=validation_loader,
    # optimisation schedule
    learning_rate=LEARNING_RATE,
    epochs=EPOCHS,
)

Epoch 1/10, Validation Loss: 1.1860
Epoch 2/10, Validation Loss: 0.9706
Epoch 3/10, Validation Loss: 0.9342
Epoch 4/10, Validation Loss: 0.9155
Epoch 5/10, Validation Loss: 0.9046
Epoch 6/10, Validation Loss: 0.9042
Epoch 7/10, Validation Loss: 0.8827
Epoch 8/10, Validation Loss: 0.8896
Epoch 9/10, Validation Loss: 0.8754
Epoch 10/10, Validation Loss: 0.8719


In [9]:
# Persist the trained weights. Create the output directory first: torch.save
# does not create missing parent directories, so on a fresh checkout without
# "out/" the cell would otherwise fail after the whole training run.
checkpoint_path = pathlib.Path("out/copy_transformer.pt")
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [10]:
# Sanity check: "ABCDE" repeated and cut off after "AB", so a model that has
# learned to copy the pattern should continue with "C".
prompt = "ABCDEABCDEABCDEAB"

tokenized_prompt = tokenizer.encode(prompt)
# Add a leading batch dimension of size 1.
prompt_batch = torch.tensor(tokenized_prompt).unsqueeze(0)

# Inference only: skip building an autograd graph for the forward pass.
with torch.no_grad():
    output = model(prompt_batch)

# Index batch element 0 and the last position explicitly. A blanket .squeeze()
# would also collapse the sequence axis for a one-token input, after which [-1]
# would select a single logit instead of the last position's vocabulary row.
next_token_prediction = output[0, -1].argmax().item()

print(tokenizer.decode([next_token_prediction]))

C
