In [3]:
from pathlib import Path

import torch
import transformer_lens

import copy_transformer.data
import copy_transformer.tokenizer
import copy_transformer.training

In [None]:
# --- Model hyperparameters -------------------------------------------------
EMBEDDING_DIM = 64
# Original (misspelled) name, kept as an alias so cells that still reference
# it keep working. Prefer EMBEDDING_DIM in new code.
EMBEDDNING_DIM = EMBEDDING_DIM
NUM_HEADS = 8
VOCABULARY = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")  # one entry per character
CONTEXT_LENGTH = 32

# --- Dataset ---------------------------------------------------------------
NUM_SAMPLES = 100_000
MAX_PATTERN_LENGTH = 16

# --- Optimisation ----------------------------------------------------------
EPOCHS = 10
BATCH_SIZE = 100
LEARNING_RATE = 1e-3

In [None]:
from transformers import PreTrainedTokenizer
from typing import List, Optional

class SimpleCharTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer whose vocabulary is a caller-supplied alphabet.

    Text is split into single characters and each character is looked up in
    ``char_to_id``. Any special tokens configured on the tokenizer (bos, eos,
    unk, pad) are merged into the same mappings once the parent class has
    been initialised.
    """

    def __init__(self, alphabet: List[str], **kwargs):
        """
        Args:
            alphabet: List of characters to use as vocabulary. Each character
                is mapped to its ``enumerate`` index in this list.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``
                (this is where ``bos_token``, ``pad_token``, ... are passed).
        """
        self.alphabet = alphabet

        # Forward (char -> id) and reverse (id -> char) lookup tables.
        # NOTE(review): these are built *before* the parent constructor runs,
        # presumably because the parent may call get_vocab() during
        # initialisation - confirm before reordering.
        self.char_to_id = {char: idx for idx, char in enumerate(alphabet)}
        self.id_to_char = {idx: char for char, idx in self.char_to_id.items()}

        super().__init__(**kwargs)

        # Merge the special tokens (and the ids the parent resolved for them)
        # into the lookup tables.
        special_tokens = [
            (self.bos_token, self.bos_token_id),
            (self.eos_token, self.eos_token_id),
            (self.unk_token, self.unk_token_id),
            (self.pad_token, self.pad_token_id),
        ]
        for token, token_id in special_tokens:
            if not token or token_id is None:
                continue
            owner = self.id_to_char.get(token_id)
            if owner is not None and owner != token:
                # The resolved id already belongs to a different symbol (this
                # happens when the id came from the unk/0 fallback in
                # _convert_token_to_id). Registering it would silently remap
                # that id and corrupt decoding, so leave the tables untouched.
                continue
            self.char_to_id[token] = token_id
            self.id_to_char[token_id] = token

    @property
    def vocab_size(self) -> int:
        """Number of entries in ``char_to_id`` (alphabet plus registered special tokens)."""
        return len(self.char_to_id)

    def get_vocab(self):
        """Return a shallow copy of the token -> id mapping."""
        return self.char_to_id.copy()

    def _tokenize(self, text: str) -> List[str]:
        """Split text into individual characters"""
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        """Convert character to ID.

        Unknown characters fall back to the id of ``unk_token`` if that is
        registered, otherwise to 0.
        """
        return self.char_to_id.get(token, self.char_to_id.get(self.unk_token, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert ID to character.

        Unknown ids yield ``unk_token``, or an empty string when no unk token
        is configured.
        """
        return self.id_to_char.get(index, self.unk_token or "")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join characters back into string"""
        return "".join(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
        """Save vocabulary to file.

        Args:
            save_directory: Directory to write into; created if missing.
            filename_prefix: Optional prefix; when given the file is named
                ``<prefix>-vocab.json`` instead of ``vocab.json``.

        Returns:
            A one-element tuple containing the path of the written JSON file.
        """
        import json
        import os

        # exist_ok avoids the check-then-create race of a separate isdir() test.
        os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.char_to_id, f, ensure_ascii=False, indent=2)

        return (vocab_file,)

In [None]:
# The tokenizer is created first so the model's vocabulary size can be derived
# from it rather than from a constant that is not defined anywhere in this
# notebook (a fresh "Restart & Run All" previously failed with NameError).
tokenizer = SimpleCharTokenizer(
    alphabet=VOCABULARY,
    bos_token=">",
    eos_token="<",
    unk_token="?",
    pad_token="_",
)

# Alphabet characters plus the special tokens registered by the tokenizer, so
# every id the tokenizer can emit has an embedding row.
# NOTE(review): if earlier runs used a different VOCABULARY_SIZE, checkpoints
# saved by those runs will not match this embedding shape - confirm.
VOCABULARY_SIZE = tokenizer.vocab_size

model_config = transformer_lens.HookedTransformerConfig(
    d_model=EMBEDDNING_DIM,
    d_head=EMBEDDNING_DIM // NUM_HEADS,
    n_layers=2,
    n_ctx=CONTEXT_LENGTH,
    n_heads=NUM_HEADS,
    d_vocab=VOCABULARY_SIZE,
    attn_only=True,
)
model = transformer_lens.HookedTransformer(model_config)

dataset = copy_transformer.data.PureRepeatingPatternDataset(
    num_samples=NUM_SAMPLES,
    vocabulary=VOCABULARY,
    context_length=CONTEXT_LENGTH,
    max_pattern_length=MAX_PATTERN_LENGTH,
)

# 80/20 train/validation split; only the training loader shuffles.
training_set, validation_set = torch.utils.data.random_split(dataset, [0.8, 0.2])
training_loader = torch.utils.data.DataLoader(
    training_set, batch_size=BATCH_SIZE, shuffle=True
)
validation_loader = torch.utils.data.DataLoader(
    validation_set, batch_size=BATCH_SIZE, shuffle=False
)

In [None]:
# Run the training loop with the model, tokenizer and data loaders built in
# the cells above; schedule settings come from the configuration cell.
copy_transformer.training.train_transformer(
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    training_loader=training_loader,
    validation_loader=validation_loader,
    tokenizer=tokenizer,
    model=model,
)

Epoch 1/10, Validation Loss: 0.9237
Epoch 2/10, Validation Loss: 0.8997
Epoch 3/10, Validation Loss: 0.8832
Epoch 4/10, Validation Loss: 0.8792
Epoch 5/10, Validation Loss: 0.8782
Epoch 6/10, Validation Loss: 0.8762
Epoch 7/10, Validation Loss: 0.8743
Epoch 8/10, Validation Loss: 0.8743
Epoch 9/10, Validation Loss: 0.8730
Epoch 10/10, Validation Loss: 0.8724


In [9]:
# Persist the trained weights. The parent directory is created first so a
# fresh checkout does not fail here after the whole training run has finished.
checkpoint_path = Path("out/copy_transformer.pt")
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [10]:
# Sanity check: feed a repeating pattern and print the model's next-token guess.
prompt = "ABCDEABCDEABCDEAB"

tokenized_prompt = tokenizer.encode(prompt)

# Inference only, so skip gradient tracking.
with torch.no_grad():
    output = model(torch.tensor(tokenized_prompt).unsqueeze(0))

# Index batch 0 / last position explicitly. The previous `.squeeze()[-1]` also
# dropped the sequence axis for a one-token prompt, so `[-1]` then picked a
# single logit and argmax always returned 0.
# assumes the model returns logits shaped (batch, position, vocab) - TODO confirm
next_token_prediction = output[0, -1].argmax().item()

print(tokenizer.decode([next_token_prediction]))

C
