In [None]:
# Setup imports and paths
import sys
import os
import torch

# Put the parent of the current working directory at the front of sys.path so
# the `modelling` package imported below can be looked up there.
# NOTE(review): presumably the notebook is run from a sub-folder of the
# project root (which would make the parent the root) — confirm.
sys.path.insert(0, os.path.dirname(os.path.abspath('.')))

from modelling.model import TransformerModel

# Location of the saved checkpoint; report up front whether it is present
CHECKPOINT_PATH = '../checkpoints/shakespear_model.pt'
if not os.path.exists(CHECKPOINT_PATH):
    # Missing file: tell the reader how to produce it
    print(f"✗ No checkpoint found at {CHECKPOINT_PATH}")
    print("  Run 'python -m run.main' first to train the model")
else:
    print(f"✓ Checkpoint found: {CHECKPOINT_PATH}")

✓ Checkpoint found: ../checkpoints/best_model.pt


In [15]:
# Read the checkpoint onto the CPU.
# torch.load unpickles the file, so only open checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')

# The hyperparameters are stored in the checkpoint under 'config'; list them
config = checkpoint['config']
print("Model Config:")
for k in config:
    v = config[k]
    print(f"  {k}: {v}")

# Rebuild the character-level tokenizer from the mapping stored in the checkpoint
class CharTokenizer:
    """Character-level tokenizer backed by a char -> id dictionary.

    Characters (when encoding) and ids (when decoding) that are missing from
    the mapping are skipped rather than raising.
    """

    def __init__(self, char_to_idx):
        """Store the forward mapping and derive the reverse one.

        Args:
            char_to_idx: Dict mapping each character to its integer id.
        """
        self.char_to_idx = char_to_idx
        self.idx_to_char = dict((idx, char) for char, idx in char_to_idx.items())
        self.vocab_size = len(char_to_idx)

    def encode(self, text):
        """Return the list of ids for the characters of `text` found in the vocabulary."""
        ids = []
        for char in text:
            if char in self.char_to_idx:
                ids.append(self.char_to_idx[char])
        return ids

    def decode(self, ids):
        """Return the string formed by the known ids in `ids`, in order."""
        chars = []
        for idx in ids:
            if idx in self.idx_to_char:
                chars.append(self.idx_to_char[idx])
        return ''.join(chars)

tokenizer = CharTokenizer(checkpoint['tokenizer_char_to_idx'])
print(f"\nVocab size: {tokenizer.vocab_size}")

# Build the model from the checkpoint's hyperparameters; the positional
# limit is the training block size plus 100 extra positions.
model = TransformerModel(
    **{
        name: config[name]
        for name in (
            'vocab_size',
            'd_model',
            'n_heads',
            'num_encoder_layers',
            'num_decoder_layers',
            'dim_feedforward',
            'dropout',
        )
    },
    max_len=config['block_size'] + 100,
)

# Restore the trained weights and switch to evaluation mode
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Report the epoch and losses recorded in the checkpoint
print(f"\n✓ Model loaded from epoch {checkpoint['epoch']}")
print(f"  Train loss: {checkpoint['train_loss']:.4f}")
print(f"  Val loss: {checkpoint['val_loss']:.4f}")

Model Config:
  vocab_size: 65
  d_model: 128
  n_heads: 4
  num_encoder_layers: 2
  num_decoder_layers: 2
  dim_feedforward: 512
  dropout: 0.1
  block_size: 128

Vocab size: 65

✓ Model loaded from epoch 3
  Train loss: 0.0140
  Val loss: 0.0141


In [16]:
# Text generation function
def generate_text(prompt, max_length=200, temperature=0.8):
    """
    Generate text from the model, one character at a time.

    Uses the notebook-level `model`, `tokenizer` and `config` objects.

    Args:
        prompt: Starting text; it is encoded with `tokenizer.encode`.
        max_length: How many characters to generate
        temperature: Higher = more creative, Lower = more conservative.
            0 picks the highest-scoring character at every step (greedy)
            instead of sampling.

    Returns:
        The decoded prompt tokens followed by the generated characters.

    Raises:
        ValueError: If `temperature` is negative, or if the prompt encodes
            to an empty token list (there is nothing to condition on).
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    model.eval()
    device = next(model.parameters()).device
    block_size = config['block_size']

    # Keep track of ALL tokens (prompt + generated), not just the context window
    all_tokens = list(tokenizer.encode(prompt))
    if not all_tokens:
        raise ValueError(
            "prompt must contain at least one character the tokenizer can encode"
        )

    # Only the last block_size tokens are ever fed to the model, so the
    # tensor is kept trimmed to that window instead of growing every step.
    context = torch.tensor([all_tokens[-block_size:]], dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_length):
            # NOTE(review): the same window is passed as both arguments of
            # the model call. The repeated-character samples and very low
            # validation loss recorded in this notebook may be related —
            # confirm against TransformerModel.forward and its masking.
            output = model(context, context)

            # Scores for the position after the last token in the window
            logits = output[0, -1, :]

            if temperature == 0:
                # Dividing by zero would produce NaN probabilities; take the
                # arg-max directly. keepdim gives the same (1,) shape that
                # torch.multinomial returns.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            # Add to the full sequence AND slide the context window
            all_tokens.append(next_token.item())
            context = torch.cat([context, next_token.unsqueeze(0)], dim=1)[:, -block_size:]

    # Return ALL tokens including prompt
    return tokenizer.decode(all_tokens)

print("✓ Generation function ready (fixed!)")

✓ Generation function ready (fixed!)


In [17]:
# Sample a 300-character continuation for each of several opening lines
prompts = ["ROMEO:\n", "JULIET:\n", "First Citizen:\n", "To be or not to be"]

for prompt in prompts:
    # Banner first, so it is visible while the sample is being generated
    print("="*60, f"Prompt: {prompt!r}", "="*60, sep="\n")
    text = generate_text(prompt, max_length=300, temperature=0.8)
    # The sample followed by one blank line
    print(text, end="\n\n")

Prompt: 'ROMEO:\n'
ROMEO:














































































































































































































































































































Prompt: 'JULIET:\n'
JULIET:














































































































































































































































































































Prompt: 'First Citizen:\n'
First Citizen:














































































































































































































































































































In [25]:
# Longer sample (700 characters) from a Hamlet-style prompt at temperature 0.5
prompt = "HAMLET:\nWhat dreams may come"
temperature = 0.5

print(generate_text(prompt, temperature=temperature, max_length=700))

HAMLET:
What dreams may comeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee heaven
And it see the sees and and the him the peak
And me pring the heaves the live to one son of this like all speak the not for the parce this conss
Where him not the the him the conce of the cannot the come and can me this a good son the the to had my plood,
And of the peak a the peak and of the cannot the be the king the shall be a distering and that is the side;
The see a should then the man a good,
And me so she to be heave the shall the could
From the some the this we prong of him.

KING RICHARD III:
The which the would have that I will the face many.

CLIFFORD:
And and perch of the he
