# Part 4: Optional - nanoGPT Training on Healthcare Text Data

## Introduction

In this optional part, you'll train a small GPT model (nanoGPT) on healthcare text data. This will give you hands-on experience with training language models from scratch and understanding their capabilities and limitations. You'll use the Synthetic Mention Corpora for Disease Entity Recognition, which contains 128,000 disease mentions generated by an LLM.

## Learning Objectives

- Understand the architecture of small language models
- Prepare text data for language model training
- Train a nanoGPT model on domain-specific data
- Evaluate model performance and generated text quality
- Compare with larger pre-trained models

## Setup and Installation

In [1]:
# Install required packages
%pip install -r requirements.txt

# Additional packages for nanoGPT
%pip install torch numpy pandas matplotlib transformers datasets wandb tqdm

# Import necessary libraries
import os
import sys
import json
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import time
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Create directories
os.makedirs('models', exist_ok=True)

Note: you may need to restart the kernel to use updated packages.

## 1. Exploring the Synthetic Mention Corpora

In [9]:
# Let's explore the Synthetic Mention Corpora for Disease Entity Recognition

import os
import pandas as pd
import json

# Create data directory if it doesn't exist
os.makedirs('data/synthetic_mentions', exist_ok=True)

# Function to download the dataset
def download_synthetic_mentions():
    """
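    Load the Synthetic Mention Corpora for Disease Entity Recognition from a local CSV file.
    
    Note: despite its name, this function does not download anything.
    You need to manually download the dataset from PhysioNet:
    https://physionet.org/content/synthetic-mention-corpora/
    
    After downloading, place the files in the data/synthetic_mentions directory.
    Returns the loaded DataFrame, or None if the file is not found.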
    """
    mentions_data_path = 'data/synthetic_mentions/SYNTHETIC_MENTIONS.csv'
    
    if os.path.exists(mentions_data_path):
        print(f"Loading Synthetic Mention Corpora from {mentions_data_path}")
        data = pd.read_csv(mentions_data_path)
        print(data.head())
        return data
    else:
        print(f"Synthetic Mention Corpora not found at {mentions_data_path}")
        print("Please download the dataset from PhysioNet:")
        print("https://physionet.org/content/synthetic-mention-corpora/")
        print("After downloading, place the files in the data/synthetic_mentions directory")
        return None

# Try to load the dataset
mentions_data = download_synthetic_mentions()

Loading Synthetic Mention Corpora from data/synthetic_mentions/SYNTHETIC_MENTIONS.csv
        cui                                     matched_output
0  C2348412   sig: one (1) tablet po qd. disp:*30 tablet(s)...
1  C2348517   4.  pml. 5.   <1CUI> favorable hodgkin lympho...
2  C2348639    <1CUI> t waves biphasic </1CUI> .   p waves:...
3  C2348499   she was started on hydroxyurea, cyclosporine,...
4  C2348501   6.  cytogenic abnormalities: 10/10 metaphases...
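
The rows printed above show mention spans wrapped in tags such as `<1CUI> ... </1CUI>`. The sketch below shows one way to pull those spans out, or to strip the tags before training. It assumes the tag format is always `<NCUI>` and `</NCUI>` with a matching integer `N`, which is inferred only from the sample rows and not from the dataset documentation. `extract_mentions` and `strip_tags` are hypothetical helper names.

```python
import re

# Assumed tag format (inferred from the sample rows): <1CUI> mention text </1CUI>
MENTION_PATTERN = re.compile(r"<(\d+)CUI>(.*?)</\1CUI>", re.DOTALL)

def extract_mentions(text):
    """Return the stripped text of every tagged mention span in `text`."""
    return [match.group(2).strip() for match in MENTION_PATTERN.finditer(text)]

def strip_tags(text):
    """Remove the tags but keep the mention text in place."""
    return MENTION_PATTERN.sub(lambda match: match.group(2).strip(), text)

row = "<1CUI> t waves biphasic </1CUI> .   p waves:"
print(extract_mentions(row))
print(strip_tags(row))
```

Whether to keep the tags in the training text is a design choice: leaving them in teaches the model to reproduce the markup, while stripping them yields plainer clinical text.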


In [None]:
# If the dataset is loaded successfully, convert to text for training
os.makedirs('data/processed', exist_ok=True)  # Make sure the output directory exists

if mentions_data is not None:
    # Iterating over a DataFrame yields its column names, not its rows, so select
    # the text column explicitly. `matched_output` is the column shown by head() above.
    subset = mentions_data['matched_output'].dropna().astype(str).head(1000)  # Start with a subset for exploration
    mentions_text = "\n".join(text.strip() for text in subset) + "\n"
    
    # Print some statistics
    print(f"Total characters: {len(mentions_text)}")
    print(f"Total words: {len(mentions_text.split())}")
    print(f"Total lines: {len(mentions_text.splitlines())}")
    
    # Print the first few lines
    print("\nFirst few lines:")
    for i, line in enumerate(mentions_text.splitlines()[:5]):
        print(f"{i+1}: {line}")
    
    # Check if the data is suitable for training
    if len(mentions_text) < 100000:  # Less than 100KB
        print("\nWarning: The extracted text might be too small for effective training.")
        print("Consider using more entries from the dataset.")
    else:
        print("\nThe extracted text seems suitable for training.")
    
    # Save the combined text for training
    with open('data/processed/mentions_text.txt', 'w', encoding='utf-8') as f:
        f.write(mentions_text)
    print("Saved combined text to data/processed/mentions_text.txt")
else:
    # Fallback to open_db.txt if the dataset is not available
    print("Falling back to open_db.txt for training")
    
    def read_open_db():
        """Read the open database text file"""
        with open('open_db.txt', 'r', encoding='utf-8') as f:
            text = f.read()
        return text
    
    # Read the open database text
    mentions_text = read_open_db()
    
    # Print some statistics
    print(f"Total characters: {len(mentions_text)}")
    print(f"Total words: {len(mentions_text.split())}")
    print(f"Total lines: {len(mentions_text.splitlines())}")



## 2. Data Preprocessing

In [None]:
# Let's preprocess the text data for training

class CharacterTokenizer:
    """Simple character-level tokenizer"""
    
    def __init__(self, text):
        """Initialize the tokenizer with the training text"""
        self.chars = sorted(list(set(text)))
        self.vocab_size = len(self.chars)
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}
        
        print(f"Vocabulary size: {self.vocab_size} characters")
    
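    def encode(self, text):
        """Encode text to token IDs, skipping characters not seen in the training text"""
        # Skipping avoids a KeyError when a prompt contains a character
        # (for example an uppercase letter) that is missing from the vocabulary
        return [self.stoi[ch] for ch in text if ch in self.stoi]
    
    def decode(self, ids):
        """Decode token IDs to text"""
        return ''.join(self.itos[i] for i in ids)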
    
    def save(self, path):
        """Save the tokenizer to a file"""
        with open(path, 'w') as f:
            json.dump({
                'chars': self.chars,
                'vocab_size': self.vocab_size,
                'stoi': self.stoi,
                'itos': {str(k): v for k, v in self.itos.items()}  # Convert keys to strings for JSON
            }, f)
    
    @classmethod
    def load(cls, path):
        """Load a tokenizer from a file"""
        with open(path, 'r') as f:
            data = json.load(f)
        
        tokenizer = cls.__new__(cls)
        tokenizer.chars = data['chars']
        tokenizer.vocab_size = data['vocab_size']
        tokenizer.stoi = data['stoi']
        tokenizer.itos = {int(k): v for k, v in data['itos'].items()}  # Convert keys back to integers
        
        return tokenizer

# Create a tokenizer from the mentions text
tokenizer = CharacterTokenizer(mentions_text)

# Encode the entire text
encoded_text = tokenizer.encode(mentions_text)
print(f"Encoded text length: {len(encoded_text)} tokens")

# Save the tokenizer
tokenizer.save('data/processed/char_tokenizer.json')
print("Tokenizer saved to data/processed/char_tokenizer.json")

# Split the data into train and validation sets (90% train, 10% validation)
train_size = int(0.9 * len(encoded_text))
train_data = encoded_text[:train_size]
val_data = encoded_text[train_size:]

print(f"Train data size: {len(train_data)} tokens")
print(f"Validation data size: {len(val_data)} tokens")

# Save the processed data
np.save('data/processed/train_data.npy', np.array(train_data, dtype=np.int16))
np.save('data/processed/val_data.npy', np.array(val_data, dtype=np.int16))
print("Processed data saved to data/processed/")

# Create a dataset class for training
class TextDataset(Dataset):
    """Dataset for training a language model"""
    
    def __init__(self, data, context_length=256):
        """
        Initialize the dataset
        
        Args:
            data: List of token IDs
            context_length: Context length for prediction
        """
        self.data = data
        self.context_length = context_length
    
    def __len__(self):
        """Return the number of possible contexts"""
        return len(self.data) - self.context_length
    
    def __getitem__(self, idx):
        """Get a context and target pair"""
        context = self.data[idx:idx+self.context_length]
        target = self.data[idx+1:idx+self.context_length+1]
        return torch.tensor(context, dtype=torch.long), torch.tensor(target, dtype=torch.long)

# Create datasets and dataloaders
context_length = 256
batch_size = 64

train_dataset = TextDataset(train_data, context_length)
val_dataset = TextDataset(val_data, context_length)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")

# Check a sample batch
x, y = next(iter(train_loader))
print(f"Input shape: {x.shape}")
print(f"Target shape: {y.shape}")
print(f"Sample input: {tokenizer.decode(x[0].tolist()[:50])}...")
print(f"Sample target: {tokenizer.decode(y[0].tolist()[:50])}...")
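
To make the context and target relationship in `TextDataset` concrete, here is a small standalone sketch on a toy string. It uses plain Python lists instead of tensors, and `make_pair` is a hypothetical helper that mirrors the slicing in `__getitem__`. The target is the context shifted one position to the right, so each position is trained to predict the next character.

```python
def make_pair(token_ids, start, context_length):
    """Return (context, target), where target is context shifted right by one."""
    context = token_ids[start:start + context_length]
    target = token_ids[start + 1:start + context_length + 1]
    return context, target

# Toy character-level vocabulary built the same way as CharacterTokenizer
text = "fever and cough"
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
ids = [stoi[ch] for ch in text]

context_length = 5
context, target = make_pair(ids, start=0, context_length=context_length)
print(repr("".join(itos[i] for i in context)))  # first 5 characters
print(repr("".join(itos[i] for i in target)))   # same window, shifted by one

# Number of full windows, matching TextDataset.__len__
n_windows = len(ids) - context_length
print(n_windows)
```

Because every starting offset is a separate example, a corpus of N tokens yields roughly N windows per epoch, which is worth keeping in mind when choosing how much text and how many epochs to train on.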

## 3. Implementing the nanoGPT Model

In [None]:
# Let's implement a small GPT model (nanoGPT)
os.makedirs('results/part_4', exist_ok=True)
os.makedirs('data/processed', exist_ok=True)

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention module"""
    
    def __init__(self, n_embd, n_head, dropout=0.1):
        """
        Initialize the multi-head attention module
        
        Args:
            n_embd: Embedding dimension
            n_head: Number of attention heads
            dropout: Dropout probability
        """
        super().__init__()
        assert n_embd % n_head == 0, "Embedding dimension must be divisible by number of heads"
        
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        
        # Key, query, value projections
        self.query = nn.Linear(n_embd, n_embd)
        self.key = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        
        # Output projection
        self.proj = nn.Linear(n_embd, n_embd)
        
        # Regularization
        self.dropout = nn.Dropout(dropout)
        
        # Causal mask to ensure that attention is only applied to the left
        self.register_buffer(
            "mask", 
            torch.tril(torch.ones(context_length, context_length))
            .view(1, 1, context_length, context_length)
        )
    
    def forward(self, x):
        """Forward pass"""
        batch_size, seq_len, n_embd = x.size()
        
        # Calculate query, key, values
        q = self.query(x).view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        
        # Scaled dot-product attention
        # (batch_size, n_head, seq_len, head_dim) x (batch_size, n_head, head_dim, seq_len)
        # -> (batch_size, n_head, seq_len, seq_len)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        # Apply causal mask
        attn = attn.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float('-inf'))
        
        # Softmax and dropout
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        
        # Apply attention to values
        # (batch_size, n_head, seq_len, seq_len) x (batch_size, n_head, seq_len, head_dim)
        # -> (batch_size, n_head, seq_len, head_dim)
        out = attn @ v
        
        # Reshape and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, n_embd)
        out = self.proj(out)
        
        return out

class FeedForward(nn.Module):
    """Feed-forward network"""
    
    def __init__(self, n_embd, dropout=0.1):
        """
        Initialize the feed-forward network
        
        Args:
            n_embd: Embedding dimension
            dropout: Dropout probability
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        """Forward pass"""
        return self.net(x)

class Block(nn.Module):
    """Transformer block"""
    
    def __init__(self, n_embd, n_head, dropout=0.1):
        """
        Initialize the transformer block
        
        Args:
            n_embd: Embedding dimension
            n_head: Number of attention heads
            dropout: Dropout probability
        """
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, n_head, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ff = FeedForward(n_embd, dropout)
    
    def forward(self, x):
        """Forward pass"""
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

class NanoGPT(nn.Module):
    """Small GPT model"""
    
    def __init__(self, vocab_size, n_embd=128, n_head=4, n_layer=4, dropout=0.1):
        """
        Initialize the nanoGPT model
        
        Args:
            vocab_size: Size of the vocabulary
            n_embd: Embedding dimension
            n_head: Number of attention heads
            n_layer: Number of transformer blocks
            dropout: Dropout probability
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        
        # Position embedding
        self.position_embedding = nn.Embedding(context_length, n_embd)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            Block(n_embd, n_head, dropout) for _ in range(n_layer)
        ])
        
        # Final layer normalization
        self.ln_f = nn.LayerNorm(n_embd)
        
        # Output head
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        
        # Initialize weights
        self.apply(self._init_weights)
        
        # Print model size
        n_params = sum(p.numel() for p in self.parameters())
        print(f"NanoGPT model with {n_params/1e6:.2f}M parameters")
    
    def _init_weights(self, module):
        """Initialize weights"""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, idx):
        """Forward pass"""
        batch_size, seq_len = idx.size()
        
        # Get token and position embeddings
        token_emb = self.token_embedding(idx)
        pos = torch.arange(0, seq_len, dtype=torch.long, device=idx.device).unsqueeze(0)
        pos_emb = self.position_embedding(pos)
        
        # Combine embeddings
        x = token_emb + pos_emb
        
        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)
        
        # Apply final layer norm
        x = self.ln_f(x)
        
        # Apply output head
        logits = self.head(x)
        
        return logits
    
    def generate(self, idx, max_new_tokens, temperature=1.0):
        """
        Generate text from the model
        
        Args:
            idx: Starting token IDs (batch_size, seq_len)
            max_new_tokens: Maximum number of new tokens to generate
            temperature: Temperature for sampling (higher = more random)
            
        Returns:
            Generated token IDs
        """
        self.eval()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop context if it's too long
                idx_cond = idx if idx.size(1) <= context_length else idx[:, -context_length:]
                
                # Get predictions
                logits = self(idx_cond)
                
                # Focus on the last token
                logits = logits[:, -1, :] / temperature
                
                # Apply softmax to get probabilities
                probs = F.softmax(logits, dim=-1)
                
                # Sample from the distribution
                idx_next = torch.multinomial(probs, num_samples=1)
                
                # Append to the sequence
                idx = torch.cat((idx, idx_next), dim=1)
        
        return idx

# Create a small nanoGPT model
model = NanoGPT(
    vocab_size=tokenizer.vocab_size,
    n_embd=128,
    n_head=4,
    n_layer=4,
    dropout=0.1
)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(f"Using device: {device}")
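
The causal mask in `MultiHeadAttention` is what keeps each position from attending to later positions. The sketch below reproduces that step in plain Python for a single head, assuming raw attention scores are already computed. `causal_softmax` is a hypothetical name, and the all-zero score matrix is a made-up input chosen so the effect of the mask is easy to see.

```python
import math

def causal_softmax(scores):
    """Row-wise softmax where position i may only attend to positions j <= i."""
    weights = []
    for i, row in enumerate(scores):
        # Future positions get -inf, the same idea as masked_fill(..., float('-inf'))
        masked = [s if j <= i else float("-inf") for j, s in enumerate(row)]
        row_max = max(masked)
        exps = [math.exp(s - row_max) for s in masked]  # exp(-inf) is 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# With equal scores, each row spreads its weight evenly over the visible positions
scores = [[0.0, 0.0, 0.0] for _ in range(3)]
for row in causal_softmax(scores):
    print([round(w, 3) for w in row])
```

Each row still sums to 1, but all weight above the diagonal is zero, so the prediction at position i cannot depend on tokens after i.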

## 4. Training the Model

In [None]:
# Let's train the nanoGPT model

def train_model(model, train_loader, val_loader, epochs=10, lr=3e-4):
    """
    Train the model
    
    Args:
        model: The model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        epochs: Number of epochs to train for
        lr: Learning rate
        
    Returns:
        Training history
    """
    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    
    # Initialize learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    # Initialize loss function
    criterion = nn.CrossEntropyLoss()
    
    # Initialize training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_perplexity': [],
        'val_perplexity': []
    }
    
    # Training loop
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_batches = 0
        
        # Progress bar for training
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for x, y in train_pbar:
            # Move data to device
            x, y = x.to(device), y.to(device)
            
            # Forward pass
            logits = model(x)
            
            # Reshape for loss calculation
            logits = logits.view(-1, tokenizer.vocab_size)
            y = y.view(-1)
            
            # Calculate loss
            loss = criterion(logits, y)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            # Update weights
            optimizer.step()
            
            # Update statistics
            train_loss += loss.item()
            train_batches += 1
            
            # Update progress bar
            train_pbar.set_postfix({'loss': train_loss / train_batches})
        
        # Calculate average training loss
        avg_train_loss = train_loss / train_batches
        train_perplexity = np.exp(avg_train_loss)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_batches = 0
        
        # Progress bar for validation
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]")
        with torch.no_grad():
            for x, y in val_pbar:
                # Move data to device
                x, y = x.to(device), y.to(device)
                
                # Forward pass
                logits = model(x)
                
                # Reshape for loss calculation
                logits = logits.view(-1, tokenizer.vocab_size)
                y = y.view(-1)
                
                # Calculate loss
                loss = criterion(logits, y)
                
                # Update statistics
                val_loss += loss.item()
                val_batches += 1
                
                # Update progress bar
                val_pbar.set_postfix({'loss': val_loss / val_batches})
        
        # Calculate average validation loss
        avg_val_loss = val_loss / val_batches
        val_perplexity = np.exp(avg_val_loss)
        
        # Update learning rate
        scheduler.step()
        
        # Print epoch summary
        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {avg_train_loss:.4f}, "
              f"Train Perplexity: {train_perplexity:.4f}, "
              f"Val Loss: {avg_val_loss:.4f}, "
              f"Val Perplexity: {val_perplexity:.4f}")
        
        # Update history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_perplexity'].append(train_perplexity)
        history['val_perplexity'].append(val_perplexity)
        
        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss
        }, f'models/nanogpt_checkpoint_epoch_{epoch+1}.pt')
    
    # Save final model
    torch.save(model.state_dict(), 'models/nanogpt.pt')
    print("Model saved to models/nanogpt.pt")
    
    return history

# Train the model
epochs = 10  # Adjust based on your computational resources
history = train_model(model, train_loader, val_loader, epochs=epochs)

# Plot training history
plt.figure(figsize=(12, 5))

# Plot loss
plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Train')
plt.plot(history['val_loss'], label='Validation')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Plot perplexity
plt.subplot(1, 2, 2)
plt.plot(history['train_perplexity'], label='Train')
plt.plot(history['val_perplexity'], label='Validation')
plt.title('Perplexity')
plt.xlabel('Epoch')
plt.ylabel('Perplexity')
plt.legend()

plt.tight_layout()
plt.savefig('results/part_4/training_history.png')
plt.show()

# Save training metrics
with open('results/part_4/training_metrics.txt', 'w') as f:
    f.write("# NanoGPT Training Metrics\n\n")
    f.write("## Model Configuration\n")
    f.write(f"Vocabulary Size: {tokenizer.vocab_size}\n")
    f.write(f"Embedding Dimension: {model.n_embd}\n")
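
The training loop reports perplexity as the exponential of the mean cross-entropy loss. A useful reference point is a model that guesses uniformly over the vocabulary: its loss is `ln(vocab_size)` and its perplexity equals the vocabulary size. The sketch below works through that arithmetic; the vocabulary size of 65 is a hypothetical value, not one measured on this dataset.

```python
import math

def perplexity(mean_cross_entropy):
    """Perplexity is exp of the mean per-token cross-entropy (in nats)."""
    return math.exp(mean_cross_entropy)

def uniform_cross_entropy(vocab_size):
    """Loss of a model that assigns probability 1/vocab_size to every token."""
    return math.log(vocab_size)

vocab_size = 65  # hypothetical character vocabulary size
baseline_loss = uniform_cross_entropy(vocab_size)
print(f"Uniform-guess loss: {baseline_loss:.4f}")
print(f"Uniform-guess perplexity: {perplexity(baseline_loss):.4f}")
```

A validation perplexity well below the vocabulary size indicates the model has learned something beyond uniform guessing; a value near it suggests training has not yet made progress.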

## 5. Generating Text and Evaluation

In [None]:
# Let's generate text from the trained model and evaluate it

def generate_text(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8):
    """
    Generate text from the model
    
    Args:
        model: The trained model
        tokenizer: The tokenizer
        prompt: The prompt text
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Temperature for sampling (higher = more random)
        
    Returns:
        Generated text
    """
    # Encode the prompt
    encoded_prompt = tokenizer.encode(prompt)
    
    # Convert to tensor and add batch dimension
    x = torch.tensor([encoded_prompt], dtype=torch.long).to(device)
    
    # Generate text
    output = model.generate(x, max_new_tokens=max_new_tokens, temperature=temperature)
    
    # Decode the output
    generated_text = tokenizer.decode(output[0].tolist())
    
    return generated_text

# Load the trained model
model.load_state_dict(torch.load('models/nanogpt.pt', map_location=device))
model.eval()

# Generate text with different prompts
prompts = [
    "Diabetes is a chronic condition that",
    "The symptoms of heart disease include",
    "To prevent respiratory infections,",
    "The treatment for hypertension typically involves",
    "Mental health disorders are characterized by"
]

# Generate and print text for each prompt
print("Generated Text Samples:")
for i, prompt in enumerate(prompts):
    print(f"\nPrompt {i+1}: {prompt}")
    generated_text = generate_text(model, tokenizer, prompt, max_new_tokens=100)
    print(f"Generated: {generated_text}")

# Compare with larger pre-trained models
try:
    from transformers import pipeline
    
    # Load a pre-trained model
    generator = pipeline('text-generation', model='gpt2')
    
    print("\n\nComparison with GPT-2:")
    for i, prompt in enumerate(prompts[:2]):  # Just try a couple of prompts
        print(f"\nPrompt: {prompt}")
        
        # Generate with our nanoGPT
        nano_text = generate_text(model, tokenizer, prompt, max_new_tokens=50)
        print(f"NanoGPT: {nano_text}")
        
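        # Generate with GPT-2 (max_new_tokens counts GPT-2 subword tokens, while
        # nanoGPT above generates characters, so the outputs differ in length)
        gpt2_text = generator(prompt, max_new_tokens=50, num_return_sequences=1)[0]['generated_text']
        print(f"GPT-2: {gpt2_text}")
except Exception as e:
    print(f"\nSkipping comparison with pre-trained models (requires internet connection): {e}")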

# Evaluate the quality of generated text
def evaluate_generated_text(generated_samples):
    """
    Evaluate the quality of generated text
    
    Args:
        generated_samples: List of generated text samples
        
    Returns:
        Evaluation metrics
    """
    # Simple metrics for text quality
    metrics = {
        'avg_length': 0,
        'unique_words': 0,
        'repetition_rate': 0
    }
    
    total_length = 0
    total_unique_words = 0
    total_repetition_rate = 0
    
    for text in generated_samples:
        # Calculate length
        words = text.split()
        length = len(words)
        total_length += length
        
        # Calculate unique words
        unique_words = len(set(words))
        total_unique_words += unique_words
        
        # Calculate repetition rate (lower is better)
        if length > 0:
            repetition_rate = 1 - (unique_words / length)
        else:
            repetition_rate = 0
        total_repetition_rate += repetition_rate
    
    # Calculate averages
    n_samples = len(generated_samples)
    if n_samples > 0:
        metrics['avg_length'] = total_length / n_samples
        metrics['unique_words'] = total_unique_words / n_samples
        metrics['repetition_rate'] = total_repetition_rate / n_samples
    
    return metrics

# Generate a larger set of samples for evaluation
evaluation_prompts = [
    "The patient presented with",
    "Common side effects include",
    "The diagnosis was confirmed by",
    "Treatment options for this condition",
    "The prognosis for patients with"
]

generated_samples = []
for prompt in evaluation_prompts:
    for temp in [0.7, 0.8, 0.9]:  # Try different temperatures
        generated_text = generate_text(model, tokenizer, prompt, max_new_tokens=100, temperature=temp)
        generated_samples.append(generated_text)

# Evaluate the generated samples
evaluation_metrics = evaluate_generated_text(generated_samples)
print("\nGenerated Text Evaluation:")
for metric, value in evaluation_metrics.items():
    print(f"{metric}: {value:.4f}")

# Save evaluation results
with open('results/part_4/generation_evaluation.txt', 'w') as f:
    f.write("# NanoGPT Text Generation Evaluation\n\n")
    
    f.write("## Evaluation Metrics\n")
    for metric, value in evaluation_metrics.items():
        f.write(f"{metric}: {value:.4f}\n")
    
    f.write("\n## Generated Samples\n")
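    # Samples were generated prompt by prompt (one per temperature),
    # so consecutive samples share a prompt
    samples_per_prompt = len(generated_samples) // len(evaluation_prompts)
    for i, sample in enumerate(generated_samples):
        prompt = evaluation_prompts[i // samples_per_prompt]
        f.write(f"\nSample {i+1}:\n")
        f.write(f"Prompt: {prompt}\n")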
        f.write(f"Generated: {sample}\n")
        f.write("-" * 50 + "\n")

print("Evaluation results saved to results/part_4/generation_evaluation.txt")
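
The evaluation loop above samples at temperatures 0.7, 0.8, and 0.9. As a rough illustration of what the temperature parameter does in `generate`, the sketch below divides a set of logits by the temperature before the softmax, as the model code does. The logit values are made up for illustration, and `softmax_with_temperature` is a hypothetical helper.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits / temperature, computed in a numerically stable way."""
    scaled = [x / temperature for x in logits]
    max_scaled = max(scaled)
    exps = [math.exp(x - max_scaled) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # made-up scores for three candidate characters
for temperature in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
```

Lower temperatures concentrate probability on the highest-scoring token, which tends to make output more repetitive; higher temperatures flatten the distribution and make sampling more varied.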

## Progress Checkpoints

1. **Data Exploration**:
   - [ ] Download and analyze the Synthetic Mention Corpora
   - [ ] Extract disease mentions and contexts
   - [ ] Determine if it's suitable for training
   - [ ] Prepare the text data for model training

2. **Data Preprocessing**:
   - [ ] Create a character-level tokenizer
   - [ ] Encode the text data
   - [ ] Split into train and validation sets

3. **Model Implementation**:
   - [ ] Implement the nanoGPT architecture
   - [ ] Configure model size based on available resources
   - [ ] Verify model structure and parameter count

4. **Training**:
   - [ ] Train the model with appropriate hyperparameters
   - [ ] Monitor training progress
   - [ ] Save checkpoints and final model

5. **Evaluation**:
   - [ ] Generate text with different prompts
   - [ ] Compare with larger pre-trained models
   - [ ] Evaluate text quality metrics
   - [ ] Save evaluation results
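
For the "verify model structure and parameter count" checkpoint, the count printed by `NanoGPT` can be cross-checked by hand. The sketch below tallies the parameters implied by the layer definitions in this notebook (linear layers with biases, two layer norms per block, and an output head without bias). The vocabulary size of 65 is a hypothetical value; substitute `tokenizer.vocab_size`. `count_nanogpt_parameters` is a hypothetical helper.

```python
def count_nanogpt_parameters(vocab_size, n_embd=128, n_layer=4, context_length=256):
    """Tally trainable parameters for the NanoGPT layout defined in this notebook."""
    linear = lambda n_in, n_out: n_in * n_out + n_out  # weight plus bias
    layer_norm = 2 * n_embd                            # scale and shift

    attention = 4 * linear(n_embd, n_embd)             # query, key, value, proj
    feed_forward = linear(n_embd, 4 * n_embd) + linear(4 * n_embd, n_embd)
    block = attention + feed_forward + 2 * layer_norm

    token_embedding = vocab_size * n_embd
    position_embedding = context_length * n_embd
    head = n_embd * vocab_size                         # bias=False in the model

    # The causal mask is a registered buffer, so it is not counted as a parameter
    return token_embedding + position_embedding + n_layer * block + layer_norm + head

n_params = count_nanogpt_parameters(vocab_size=65)  # hypothetical vocabulary size
print(f"{n_params} parameters ({n_params / 1e6:.2f}M)")
```

The number of heads does not appear in the tally because splitting the embedding across heads changes how the projections are reshaped, not how many weights they hold.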

In [None]:
# Append the remaining configuration and results to the training metrics file.
# This continues the file started in section 4, so it is opened in append mode.
with open('results/part_4/training_metrics.txt', 'a') as f:
    f.write(f"Number of Heads: {model.n_head}\n")
    f.write(f"Number of Layers: {model.n_layer}\n")
    f.write(f"Context Length: {context_length}\n")
    f.write(f"Batch Size: {batch_size}\n")
    f.write(f"Epochs: {epochs}\n\n")
    
    f.write("## Training Results\n")
    f.write(f"Final Train Loss: {history['train_loss'][-1]:.4f}\n")
    f.write(f"Final Validation Loss: {history['val_loss'][-1]:.4f}\n")
    f.write(f"Final Train Perplexity: {history['train_perplexity'][-1]:.4f}\n")
    f.write(f"Final Validation Perplexity: {history['val_perplexity'][-1]:.4f}\n\n")
    
    f.write("## Epoch-by-Epoch Metrics\n")
    # Iterate over the recorded history rather than the configured epoch count
    for i in range(len(history['train_loss'])):
        f.write(f"Epoch {i+1}:\n")
        f.write(f"  Train Loss: {history['train_loss'][i]:.4f}\n")
        f.write(f"  Val Loss: {history['val_loss'][i]:.4f}\n")
        f.write(f"  Train Perplexity: {history['train_perplexity'][i]:.4f}\n")
        f.write(f"  Val Perplexity: {history['val_perplexity'][i]:.4f}\n")

print("Training metrics saved to results/part_4/training_metrics.txt")