In [1]:
import os
import sys
import numpy as np

# Add the path to the custom library to the system path
sys.path.append('..')

# Import custom modules
from src import Tensor
from src.architectures.transformer import Tokenizer, Transformer

### Constants & Configurations

In [2]:
# Checkpoint artifacts live in a 'checkpoints' folder under the current working directory
checkpoints_dir = os.path.join(os.getcwd(), 'checkpoints')

# Serialized tokenizer state and serialized model weights, respectively
tokenizer_path = os.path.join(checkpoints_dir, 'tokenizer.json')
model_path = os.path.join(checkpoints_dir, 'language_model.npz')

In [3]:
# Model hyperparameters (forwarded as keyword arguments to the Transformer constructor below).
# NOTE: presumably these must match the values used when the checkpoint was produced.

# Length of the context window, in tokens
sequence_length = 256

# Dimensionality of the token embeddings
n_embed = 384

# Number of heads in the multi-head attention mechanism
n_attention_heads = 6

# Number of transformer decoder blocks in the model
n_decoder_blocks = 6

### Tokenizer

In [4]:
# Create an empty tokenizer; its state is populated from disk in the next step
tokenizer = Tokenizer()

# Load the saved tokenizer state from tokenizer_path
# (presumably the vocabulary learned during training - confirm against Tokenizer.load)
tokenizer.load(tokenizer_path)

# Read the vocabulary size from the loaded tokenizer; it is passed to the
# Transformer constructor as vocab_size
vocab_size = tokenizer.get_vocab_size()

### Loading the model

In [5]:
# Build the transformer from the hyperparameters and the tokenizer's vocabulary size.
# No checkpoint is read here; loading happens in a separate step.
language_model = Transformer(
    name="Language Model",
    vocab_size=vocab_size,
    n_embed=n_embed,
    n_attention_heads=n_attention_heads,
    sequence_length=sequence_length,
    n_decoder_blocks=n_decoder_blocks,
)

In [6]:
# This notebook only runs inference, so a checkpoint is required. Fail fast with
# an actionable message instead of silently skipping the load and letting the
# inference cell sample from a model that never had weights loaded.
if not os.path.exists(model_path):
    raise FileNotFoundError(
        f"Model checkpoint not found at '{model_path}'. "
        "Train the model first or point model_path to an existing checkpoint."
    )

# Load the checkpoint into the model created above
print("Loading the model from the checkpoint...")
language_model.load(model_path)
print("Model loaded successfully.")

Loading the model from the checkpoint...
Model loaded successfully.


### Inference

In [8]:
# Starting context: a (1, 1) int32 tensor holding the single token id 0
# (presumably laid out as (batch, sequence) - confirm against Transformer.generate)
context = Tensor(np.zeros((1, 1), dtype=np.int32))

# With stream=True, generate() is iterated so tokens can be shown as they are produced
token_stream = language_model.generate(context, max_new_tokens=200, stream=True)

for next_token in token_stream:
    # Collapse the token's data to a plain Python value so it can be wrapped in a list for decode()
    token_id = next_token.data.squeeze().tolist()

    # Emit the decoded text immediately, without a trailing newline, to form continuous output
    print(tokenizer.decode([token_id]), end='', flush=True)


chi compagni che non volal mondo suenforte,
nel vedi malido una valornatime;
per l'io a rime tempo mor, peri d'alcun sovr'uom allende,
e gemdendo tristo non la fiera,
gola che dando suola chiuga,
rispuoi dubiammente ch'or dussi modo prena proda,
carne han si lor suo sottor l'ingegno,
che god la gente che 'n fa nasca,
i in fuor man