# Transformer Model Local Training

This notebook provides a local training interface for the transformer model with debugging and visualization capabilities.

In [None]:
import torch
from config import get_config
cfg = get_config()
# Local-run overrides applied on top of the defaults from get_config().
cfg.update(batch_size=2, preload=None, num_epochs=30)

# Show the effective configuration, one "name: setting" pair per line.
print("Current Configuration:")
for name, setting in cfg.items():
    print(f"{name}: {setting}")

## Additional Configuration Options

Adjust training parameters as needed before starting training.

In [None]:
# Pick the compute device: CUDA when a GPU is visible, CPU otherwise.
cfg['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {cfg['device']}")

# Optional overrides -- uncomment the ones you need.
#
# Learning rate:
# cfg['lr'] = 0.0005
#
# Load a specific saved model:
# cfg['preload'] = 'weights/tmodel_best.pt'
#
# Gradient accumulation, for a larger effective batch size:
# cfg['gradient_accumulation_steps'] = 4

## Training with Visualization

Execute training with progress tracking and visualization.

In [None]:
# Import visualization libraries
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output

def train_with_visualization(config):
    """Run ``train.train_model`` while plotting live training/validation losses.

    A plotting callback is installed as ``config['after_validation']`` for the
    duration of the run. Any callback already stored under that key is still
    called on every invocation. When training ends, the key is restored to its
    previous value, or removed if it was absent. This also happens on error, so
    calling this function repeatedly does not stack callbacks.

    Args:
        config: training configuration dict, passed unchanged to ``train_model``.
            ``config['num_epochs']`` is read for the plot title and progress text.

    Returns:
        ``(model, history)``. ``model`` is whatever ``train_model`` returns, or
        ``None`` if it raised an ``Exception``. ``history`` is a dict with
        ``'train_loss'`` and ``'val_loss'`` lists, one entry per callback call.
    """
    from train import train_model
    import time
    import traceback

    start_time = time.perf_counter()
    history = {'train_loss': [], 'val_loss': []}

    # Remember the caller's callback (if any) so it can be chained and restored.
    had_callback = 'after_validation' in config
    original_callback = config.get('after_validation')

    def visualization_callback(epoch, train_loss, val_loss, model):
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        if original_callback:
            original_callback(epoch, train_loss, val_loss, model)

        num_epochs = config['num_epochs']
        # Redraw on even epochs and always on the final one.
        if epoch % 2 != 0 and epoch != num_epochs - 1:
            return

        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(10, 5))
        # 1-based x-axis so it matches the "Epoch N/M" numbering shown below.
        epochs_seen = range(1, len(history['train_loss']) + 1)
        ax.plot(epochs_seen, history['train_loss'], label='Training Loss')
        ax.plot(epochs_seen, history['val_loss'], label='Validation Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(f'Training Progress (Epoch {epoch+1}/{config["num_epochs"]})')
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        plt.show()
        # Release the figure. Otherwise every redraw leaves one open on
        # non-inline backends, and matplotlib warns once 20 are open.
        plt.close(fig)

        elapsed_time = time.perf_counter() - start_time
        print(f"Epoch {epoch+1}/{config['num_epochs']} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Time elapsed: {elapsed_time/60:.2f} minutes")

    config['after_validation'] = visualization_callback
    try:
        model = train_model(config)
    except Exception as e:
        # Notebook boundary: report the failure but keep the partial history.
        print(f"Error during training: {str(e)}")
        print(traceback.format_exc())
        return None, history
    finally:
        # Undo our mutation of the caller's config so later runs start clean.
        if had_callback:
            config['after_validation'] = original_callback
        else:
            config.pop('after_validation', None)

    print(f"\nTraining completed in {(time.perf_counter() - start_time)/60:.2f} minutes")
    return model, history

In [None]:
# Kick off training; the loss curves refresh in this cell's output as it runs.
model, history = train_with_visualization(config=cfg)

## Model Evaluation

Sanity-check the trained model by greedily translating a few sample sentences.

In [None]:
def evaluate_model(model, cfg, sentences=None, max_len=100):
    """Greedily translate a few sentences with *model* and print the results.

    Args:
        model: trained model exposing ``encode(src)`` and
            ``decode(dec_input, enc_output, src)``, called as below. If ``None``,
            a message is printed and the function returns immediately.
        cfg: configuration dict. ``'device'``, ``'lang_src'`` and ``'lang_tgt'``
            are read here, and the whole dict is passed to ``get_tokenizers``.
        sentences: optional iterable of source sentences to translate. Defaults
            to four built-in English examples.
        max_len: maximum number of target tokens generated per sentence.

    Raises:
        ValueError: if the target tokenizer has no ``[BOS]`` token.
    """
    if model is None:
        print("No model available for evaluation. Training may have failed.")
        return

    # Imported lazily so the None-guard above works without the project modules.
    import torch
    from dataset import get_tokenizers

    if sentences is None:
        sentences = [
            "Hello, how are you doing today?",
            "I would like to book a table for dinner tonight.",
            "The weather is beautiful outside.",
            "Can you help me find my way to the train station?",
        ]

    device = cfg['device']
    tokenizer_src, tokenizer_tgt = get_tokenizers(cfg)

    # Resolve special-token ids once, instead of on every decoding step.
    # NOTE(review): assumes HuggingFace `tokenizers` objects, whose token_to_id
    # returns None for unknown tokens -- confirm against dataset.get_tokenizers.
    bos_id = tokenizer_tgt.token_to_id('[BOS]')
    eos_id = tokenizer_tgt.token_to_id('[EOS]')
    if bos_id is None:
        raise ValueError(
            "Target tokenizer has no '[BOS]' token; check the special tokens "
            "used when the tokenizer was built."
        )

    def translate(sentence):
        """Return the greedy-decoded translation of *sentence*."""
        src = torch.tensor(
            [tokenizer_src.encode(sentence).ids], dtype=torch.int64, device=device
        )
        enc_output = model.encode(src)
        dec_input = torch.tensor([[bos_id]], dtype=torch.int64, device=device)

        for _ in range(max_len):
            # NOTE(review): the raw source ids are passed as the third argument,
            # as before; confirm this is what model.decode expects (e.g. a mask).
            dec_output = model.decode(dec_input, enc_output, src)
            # Greedy pick at the last position. The indexing assumes decode
            # returns (batch, seq, vocab) scores. keepdim yields (batch, 1) for cat.
            next_token = dec_output[:, -1, :].argmax(dim=-1, keepdim=True)
            dec_input = torch.cat([dec_input, next_token], dim=1)
            if next_token.item() == eos_id:
                break

        text = tokenizer_tgt.decode(dec_input[0].tolist())
        # Strip special tokens in case the tokenizer's decode kept them.
        return text.replace('[BOS]', '').replace('[EOS]', '').strip()

    model.eval()
    print(f"\nSample translations ({cfg['lang_src']} → {cfg['lang_tgt']}):\n")
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        for sentence in sentences:
            translation = translate(sentence)
            print(f"Source: {sentence}")
            print(f"Translation: {translation}\n")

In [None]:
# Only evaluate when training actually produced a model.
if model is None:
    print("Model training failed or was not completed. Cannot perform evaluation.")
else:
    evaluate_model(model, cfg)

## Debugging Tools

Troubleshoot model and training issues.

In [None]:
def debug_environment(config=None):
    """Print Python/PyTorch/CUDA diagnostics followed by a configuration dump.

    Args:
        config: mapping whose items are printed. Defaults to the notebook-level
            ``cfg``, so existing ``debug_environment()`` calls behave as before.
    """
    import sys
    import torch

    # Fall back to the notebook global only when no config is given explicitly.
    if config is None:
        config = cfg

    print("Python version:", sys.version)
    print("PyTorch version:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("CUDA version:", torch.version.cuda)
        print("GPU:", torch.cuda.get_device_name(0))
        print("GPU memory allocated:", torch.cuda.memory_allocated(0) / 1e9, "GB")
        print("GPU memory cached:", torch.cuda.memory_reserved(0) / 1e9, "GB")

    print("\nConfiguration:")
    for key, value in config.items():
        print(f"{key}: {value}")

# Uncomment to run environment diagnostics
# debug_environment()

In [None]:
def save_results(model, history, cfg, custom_name=None):
    """Save model weights, the loss history (JSON) and a loss-curve plot.

    Files are written to ``cfg['model_folder']``, which is created if missing:
    ``<name>.pt`` (the model's ``state_dict``), ``<name>_history.json`` and
    ``<name>_plot.png``. ``<name>`` is *custom_name* if given, otherwise
    ``model_<YYYYmmdd_HHMMSS>``. Existing files with the same names are
    overwritten.

    Args:
        model: model to save. If ``None``, a message is printed and nothing is
            written.
        history: dict with ``'train_loss'`` and ``'val_loss'`` sequences.
        cfg: configuration dict; only ``'model_folder'`` is read.
        custom_name: optional base filename replacing the timestamped default.
    """
    if model is None:
        print("No model to save. Training may have failed.")
        return

    # Deferred until there is actually something to save.
    import json
    import os
    from datetime import datetime

    import matplotlib.pyplot as plt
    import torch

    model_name = custom_name or f"model_{datetime.now():%Y%m%d_%H%M%S}"
    folder = cfg['model_folder']
    os.makedirs(folder, exist_ok=True)

    model_path = os.path.join(folder, f"{model_name}.pt")
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Build the payload before opening the file, so a bad entry cannot leave
    # an empty JSON file behind. float() accepts Python floats, NumPy scalars
    # and single-element tensors alike.
    serializable_history = {
        'train_loss': [float(x) for x in history['train_loss']],
        'val_loss': [float(x) for x in history['val_loss']],
    }
    history_path = os.path.join(folder, f"{model_name}_history.json")
    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(serializable_history, f)
    print(f"Training history saved to {history_path}")

    fig, ax = plt.subplots(figsize=(10, 5))
    # 1-based epochs, matching the numbering used in the training printouts.
    epochs_seen = range(1, len(serializable_history['train_loss']) + 1)
    ax.plot(epochs_seen, serializable_history['train_loss'], label='Training Loss')
    ax.plot(epochs_seen, serializable_history['val_loss'], label='Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    ax.grid(True)

    plot_path = os.path.join(folder, f"{model_name}_plot.png")
    # Save before show(): some backends clear the figure once it is shown.
    fig.savefig(plot_path)
    plt.show()
    plt.close(fig)
    print(f"Learning curves plot saved to {plot_path}")

# Uncomment to save your model and training history
# save_results(model, history, cfg, custom_name="my_best_model")