# Text Summarization Model Training

This notebook demonstrates how to fine-tune a transformer model (BART) on the CNN/DailyMail dataset using SageMaker for text summarization.

**Note**: This notebook is configured to run with the **Python 3 (ipykernel)** kernel.

In [None]:
# Install required libraries - expanded to include all necessary dependencies
!pip install --upgrade pip
!pip install boto3 sagemaker
!pip install transformers datasets
!pip install torch torchvision
!pip install pandas numpy scikit-learn
!pip install rouge-score nltk

# Verify CUDA is not available (CPU training)
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count() if torch.cuda.is_available() else 'None'}")

In [None]:
# Set up working directory for backend files
import os
import sys

# Create backend directory if it doesn't exist
!mkdir -p backend

# If backend files are not already in the notebook instance, we need to create them
# Check if train.py exists in backend directory
if not os.path.exists('backend/train.py'):
    print("Creating backend files...")
    # We'll create these files later, but for now let's check

## Creating Backend Files

Since we're working directly in the notebook instance, we need to create the required backend files. Let's create the preprocess.py and train.py files:

In [None]:
# Create preprocess.py file
%%writefile backend/preprocess.py
"""
This script downloads and prepares the CNN/DailyMail dataset for SageMaker training.
It runs as a SageMaker Processing job to prepare the data.
"""
import logging
import os
import argparse
import pandas as pd
from datasets import load_dataset

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument("--train-split-size", type=float, default=0.1)  # Reduced to 10% for faster processing
    return parser.parse_args()

def main():
    args = parse_args()
    
    # Create output directories
    os.makedirs(os.path.join(args.output_dir, "train"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "validation"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "test"), exist_ok=True)
    
    logger.info("Downloading CNN/DailyMail dataset...")
    dataset = load_dataset("cnn_dailymail", "3.0.0")
    
    # We'll use a smaller subset for faster processing
    train_size = int(len(dataset['train']) * args.train_split_size)
    val_size = int(len(dataset['validation']) * args.train_split_size)
    test_size = int(len(dataset['test']) * args.train_split_size)
    
    logger.info(f"Processing training split (using {train_size} samples)...")
    train_df = pd.DataFrame({
        "article": dataset["train"]["article"][:train_size],
        "highlights": dataset["train"]["highlights"][:train_size]
    })
    train_df.to_csv(os.path.join(args.output_dir, "train", "train.csv"), index=False)
    
    logger.info(f"Processing validation split (using {val_size} samples)...")
    val_df = pd.DataFrame({
        "article": dataset["validation"]["article"][:val_size],
        "highlights": dataset["validation"]["highlights"][:val_size]
    })
    val_df.to_csv(os.path.join(args.output_dir, "validation", "validation.csv"), index=False)
    
    logger.info(f"Processing test split (using {test_size} samples)...")
    test_df = pd.DataFrame({
        "article": dataset["test"]["article"][:test_size],
        "highlights": dataset["test"]["highlights"][:test_size]
    })
    test_df.to_csv(os.path.join(args.output_dir, "test", "test.csv"), index=False)
    
    logger.info("Data processing completed successfully!")

if __name__ == "__main__":
    main()

In [None]:
# Create train.py file
%%writefile backend/train.py
import argparse
import logging
import os
import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_args():
    parser = argparse.ArgumentParser()
    
    # Data and model parameters
    parser.add_argument("--model-name", type=str, default="facebook/bart-base")
    parser.add_argument("--max-input-length", type=int, default=512)  # Reduced for CPU training
    parser.add_argument("--max-target-length", type=int, default=64)   # Reduced for CPU training
    
    # Training parameters
    parser.add_argument("--epochs", type=int, default=1)  # Reduced for CPU training
    parser.add_argument("--batch-size", type=int, default=4)  # Reduced for CPU training
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--warmup-steps", type=int, default=100)  # Reduced for CPU training
    
    # SageMaker parameters
    parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR", "/tmp/output"))
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "/tmp/model"))
    parser.add_argument("--n-gpus", type=int, default=0)  # Default to CPU training
    
    return parser.parse_args()

def preprocess_function(examples, tokenizer, max_input_length, max_target_length):
    inputs = [doc for doc in examples["article"]]
    targets = [summary for summary in examples["highlights"]]
    
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    labels = tokenizer(targets, max_length=max_target_length, truncation=True)
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def compute_metrics(eval_pred, tokenizer, metric):
    predictions, labels = eval_pred
    
    # Replace -100 with the pad token id
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    
    # Convert ids to tokens
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # ROUGE expects a newline after each sentence
    decoded_preds = ["\n".join(pred.split()) for pred in decoded_preds]
    decoded_labels = ["\n".join(label.split()) for label in decoded_labels]
    
    # Compute ROUGE scores
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, 
                          use_stemmer=True)
    
    # Extract scores
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    return result

def train(args):
    # Load the CNN/DailyMail dataset
    try:
        # Try to load from CSV first (for SageMaker Processing outputs)
        dataset = {
            "train": load_dataset("csv", data_files="/opt/ml/input/data/train/train.csv", split="train"),
            "validation": load_dataset("csv", data_files="/opt/ml/input/data/validation/validation.csv", split="train")
        }
        logger.info("Loaded dataset from CSV files")
    except Exception as e:
        logger.info(f"Could not load from CSV, downloading directly: {e}")
        # Fall back to downloading the dataset
        full_dataset = load_dataset("cnn_dailymail", "3.0.0")
        # Use a small subset for CPU training (10%)
        train_size = int(len(full_dataset['train']) * 0.1)
        val_size = int(len(full_dataset['validation']) * 0.1)
        dataset = {
            "train": full_dataset["train"].select(range(train_size)),
            "validation": full_dataset["validation"].select(range(val_size))
        }
    
    # Load the pre-trained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    
    # Preprocess the dataset
    tokenized_datasets = {}
    for split in dataset:
        tokenized_datasets[split] = dataset[split].map(
            lambda examples: preprocess_function(
                examples, tokenizer, args.max_input_length, args.max_target_length
            ),
            batched=True,
            remove_columns=dataset[split].column_names,
        )
    
    # Data collator for dynamic padding
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding=True,
        return_tensors="pt"
    )
    
    # Load ROUGE metric
    rouge_metric = load_metric("rouge")
    
    # Set up training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=args.model_dir,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        predict_with_generate=True,
        generation_max_length=args.max_target_length,
        learning_rate=args.learning_rate,
        num_train_epochs=args.epochs,
        warmup_steps=args.warmup_steps,
        logging_dir=f"{args.output_data_dir}/logs",
        logging_steps=10,  # More frequent logging for small dataset
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="rouge1",
        gradient_accumulation_steps=1,  # Reduced for CPU training
        # Disable fp16 for CPU training
        fp16=False
    )
    
    # Initialize the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=lambda eval_pred: compute_metrics(eval_pred, tokenizer, rouge_metric),
    )
    
    # Train the model
    logger.info("*** Training the model ***")
    trainer.train()
    
    # Save the model
    logger.info("*** Saving the model ***")
    trainer.save_model(args.model_dir)
    tokenizer.save_pretrained(args.model_dir)
    
    logger.info("*** Training completed ***")

if __name__ == "__main__":
    args = parse_args()
    train(args)

In [None]:
# Import required libraries
import os
import boto3
import sagemaker
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.processing import ScriptProcessor
from sagemaker.huggingface import HuggingFace

# Initialize SageMaker session
session = sagemaker.Session()
role = sagemaker.get_execution_role()

# Define S3 bucket and prefixes
bucket = session.default_bucket()
prefix = "text-summarization"

# Define S3 paths for data and model artifacts
data_prefix = f"{prefix}/data"
model_prefix = f"{prefix}/model"
output_path = f"s3://{bucket}/{model_prefix}/output"
preprocessing_output_path = f"s3://{bucket}/{data_prefix}"

# Define SageMaker instance types - using Free Tier compatible instances
processing_instance_type = "ml.t3.medium"  # Changed to t3.medium (available in free tier)
training_instance_type = "ml.m5.xlarge"    # This is in free tier (50 hours)
inference_instance_type = "ml.m5.xlarge"   # This is in free tier (125 hours)

# Verify that instance types are properly set
print(f"Processing instance type: {processing_instance_type}")
print(f"Training instance type: {training_instance_type}")
print(f"Inference instance type: {inference_instance_type}")

print(f"SageMaker role: {role}")
print(f"S3 bucket: {bucket}")
print(f"Data will be saved to: {preprocessing_output_path}")
print(f"Model will be saved to: {output_path}")

## Step 1: Data Preprocessing

We'll use a SageMaker Processing job to download and prepare the CNN/DailyMail dataset.
This step has been modified to work with Python 3 (ipykernel) and CPU instances.

In [None]:
# Update preprocess.py to install dependencies at runtime
%%writefile backend/preprocess.py
"""
This script downloads and prepares the CNN/DailyMail dataset for SageMaker training.
It runs as a SageMaker Processing job to prepare the data.
"""
import logging
import os
import argparse
import sys
import subprocess

# Install required packages at runtime
print("Installing required packages...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "datasets", "pandas", "transformers"])

# Now we can import the required modules
import pandas as pd
from datasets import load_dataset

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument("--train-split-size", type=float, default=0.05)  # Reduced to 5% for faster processing
    return parser.parse_args()

def main():
    args = parse_args()
    
    # Create output directories
    os.makedirs(os.path.join(args.output_dir, "train"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "validation"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "test"), exist_ok=True)
    
    logger.info("Downloading CNN/DailyMail dataset...")
    dataset = load_dataset("cnn_dailymail", "3.0.0")
    
    # We'll use a smaller subset for faster processing
    train_size = int(len(dataset['train']) * args.train_split_size)
    val_size = int(len(dataset['validation']) * args.train_split_size)
    test_size = int(len(dataset['test']) * args.train_split_size)
    
    logger.info(f"Processing training split (using {train_size} samples)...")
    train_df = pd.DataFrame({
        "article": dataset["train"]["article"][:train_size],
        "highlights": dataset["train"]["highlights"][:train_size]
    })
    train_df.to_csv(os.path.join(args.output_dir, "train", "train.csv"), index=False)
    
    logger.info(f"Processing validation split (using {val_size} samples)...")
    val_df = pd.DataFrame({
        "article": dataset["validation"]["article"][:val_size],
        "highlights": dataset["validation"]["highlights"][:val_size]
    })
    val_df.to_csv(os.path.join(args.output_dir, "validation", "validation.csv"), index=False)
    
    logger.info(f"Processing test split (using {test_size} samples)...")
    test_df = pd.DataFrame({
        "article": dataset["test"]["article"][:test_size],
        "highlights": dataset["test"]["highlights"][:test_size]
    })
    test_df.to_csv(os.path.join(args.output_dir, "test", "test.csv"), index=False)
    
    logger.info("Data processing completed successfully!")

if __name__ == "__main__":
    main()

In [None]:
# Check if instance type is properly set
if not processing_instance_type:
    processing_instance_type = "ml.t3.medium"  # Fallback to free tier instance
    print(f"Warning: processing_instance_type was not set. Using fallback: {processing_instance_type}")

# Upload the preprocessing script to S3
preprocessing_script_path = "backend/preprocess.py"
preprocessing_s3_path = sagemaker.s3.S3Uploader.upload(
    local_path=preprocessing_script_path,
    desired_s3_uri=f"s3://{bucket}/{prefix}/scripts"
)

# Get PyTorch image URI with explicit parameters
pytorch_image = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=session.boto_region_name,
    version="1.10.0",
    py_version="py38",  # Explicitly set Python version
    instance_type=processing_instance_type,  # Now using ml.t3.medium
    image_scope="training"
)

print(f"PyTorch image URI: {pytorch_image}")

# Configure the preprocessing job with increased timeout
processor = ScriptProcessor(
    command=['python3'],
    image_uri=pytorch_image,
    role=role,
    instance_count=1,
    instance_type=processing_instance_type,  # Now using ml.t3.medium
    base_job_name='text-summarization-preprocessing',
    max_runtime_in_seconds=3600  # Allow up to 1 hour for the job to complete
)

print(f"Starting preprocessing job using script: {preprocessing_s3_path}")

# Start the preprocessing job
processor.run(
    code=preprocessing_s3_path,
    arguments=[
        '--output-dir', '/opt/ml/processing/output',
        '--train-split-size', '0.05'  # Reduced to 5% for smaller processing on free tier
    ],
    outputs=[
        ProcessingOutput(
            output_name="data",
            source="/opt/ml/processing/output",
            destination=preprocessing_output_path
        )
    ],
    wait=True
)

print(f"Preprocessing job completed. Data saved to: {preprocessing_output_path}")

## Step 2: Model Training

Since the HuggingFace estimator doesn't support CPU instances with our version combinations,
we'll use a PyTorch estimator instead.

In [None]:
import argparse
import logging
import os
import numpy as np
import torch
import sys
import subprocess
import json
import psutil
import gc

# Install packages at runtime if needed in the training environment
print("Installing required packages...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "datasets", "transformers", "rouge-score", "psutil"])

from datasets import load_dataset
try:
    from datasets import load_metric
except ImportError:
    # For older versions of datasets
    from datasets import load_metric as load_metrics
    def load_metric(name):
        return load_metrics(name)

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    TrainerCallback
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_args():
    parser = argparse.ArgumentParser(add_help=False)  # Disable automatic error on unknown args
    
    # Data and model parameters
    parser.add_argument("--model-name", type=str, default="facebook/bart-base")
    parser.add_argument("--max-input-length", type=int, default=256)  # Reduced further for CPU training
    parser.add_argument("--max-target-length", type=int, default=32)   # Reduced further for CPU training
    
    # Training parameters
    parser.add_argument("--epochs", type=int, default=1)  # Reduced for CPU training
    parser.add_argument("--batch-size", type=int, default=2)  # Reduced batch size further
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--warmup-steps", type=int, default=50)  # Reduced for CPU training
    
    # SageMaker parameters - support different environment variable names
    parser.add_argument("--output-data-dir", type=str, 
                        default=os.environ.get("SM_OUTPUT_DATA_DIR", 
                                              os.environ.get("OUTPUT_DATA_DIR", "/tmp/output")))
    parser.add_argument("--model-dir", type=str, 
                        default=os.environ.get("SM_MODEL_DIR", 
                                              os.environ.get("MODEL_DIR", "/tmp/model")))
    parser.add_argument("--train-dir", type=str,  # Fixed type.str to type=str
                        default=os.environ.get("SM_CHANNEL_TRAIN",
                                              os.environ.get("TRAIN_DIR", "/opt/ml/input/data/train")))
    parser.add_argument("--validation-dir", type=str,  # Fixed type.str to type=str
                        default=os.environ.get("SM_CHANNEL_VALIDATION",  # Fixed default.os to default=os
                                              os.environ.get("VALIDATION_DIR", "/opt/ml/input/data/validation")))
    
    # Add smaller dataset size option
    parser.add_argument("--dataset-size", type=float, default=0.005)  # Use only 0.5% of the data
    parser.add_argument("--max-train-samples", type=int, default=100)  # Cap at 100 samples max
    parser.add_argument("--max-val-samples", type=int, default=20)     # Cap at 20 samples max
    
    # Parse known arguments only - ignore any Jupyter/IPython specific arguments
    args, _ = parser.parse_known_args()
    return args

def preprocess_function(examples, tokenizer, max_input_length, max_target_length):
    inputs = [doc for doc in examples["article"]]
    targets = [summary for summary in examples["highlights"]]
    
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    labels = tokenizer(targets, max_length=max_target_length, truncation=True)
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def compute_metrics(eval_pred, tokenizer, metric):
    predictions, labels = eval_pred
    
    # Replace -100 with the pad token id
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    
    # Convert ids to tokens
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # ROUGE expects a newline after each sentence
    decoded_preds = ["\n".join(pred.split()) for pred in decoded_preds]
    decoded_labels = ["\n".join(label.split()) for label in decoded_labels]
    
    # Compute ROUGE scores
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, 
                           use_stemmer=True)
    
    # Extract scores
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    return result

def monitor_memory(message=""):
    """Print memory usage information"""
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    
    # Convert to MB for readability
    rss_mb = memory_info.rss / (1024 * 1024)
    vms_mb = memory_info.vms / (1024 * 1024)
    
    # Get disk usage
    disk_usage = psutil.disk_usage('/')
    free_disk_gb = disk_usage.free / (1024 * 1024 * 1024)
    
    # Add a forced garbage collection with each memory check
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        
    logger.info(f"{message} - Memory usage: RSS {rss_mb:.2f} MB, VMS {vms_mb:.2f} MB, Free disk: {free_disk_gb:.2f} GB")

def train(args):
    try:
        # Monitor memory at the start
        monitor_memory("Starting training")
        
        try:
            # Try to load from CSV files first (provided by SageMaker channels)
            train_file = os.path.join(args.train_dir, "train.csv")
            val_file = os.path.join(args.validation_dir, "validation.csv")
            
            logger.info(f"Looking for training data at: {train_file}")
            logger.info(f"Looking for validation data at: {val_file}")
            
            if os.path.exists(train_file) and os.path.exists(val_file):
                logger.info("Loading datasets from CSV files")
                dataset = {
                    "train": load_dataset("csv", data_files=train_file, split="train"),
                    "validation": load_dataset("csv", data_files=val_file, split="train")
                }
                logger.info("Successfully loaded datasets from CSV")
                logger.info(f"Train set size: {len(dataset['train'])}")
                logger.info(f"Validation set size: {len(dataset['validation'])}")
            else:
                logger.info("Could not find CSV files, downloading dataset directly")
                # Fall back to downloading the dataset
                dataset = load_dataset("cnn_dailymail", "3.0.0")
                
                # Use a very small subset for CPU training
                train_size = min(int(len(dataset['train']) * args.dataset_size), args.max_train_samples)
                val_size = min(int(len(dataset['validation']) * args.dataset_size), args.max_val_samples)
                
                logger.info(f"Using reduced dataset: {train_size} train samples, {val_size} validation samples")
                dataset = {
                    "train": dataset["train"].select(range(train_size)),
                    "validation": dataset["validation"].select(range(val_size))
                }
        except Exception as e:
            # If loading fails, create a tiny dataset for testing
            logger.error(f"Error loading dataset: {e}")
            logger.info("Creating a minimal test dataset")
            
            # Create a very small sample dataset
            small_articles = ["This is a short test article. It needs to be summarized."] * 10
            small_summaries = ["Short summary."] * 10
            
            from datasets import Dataset
            dataset = {
                "train": Dataset.from_dict({"article": small_articles, "highlights": small_summaries}),
                "validation": Dataset.from_dict({"article": small_articles[:2], "highlights": small_summaries[:2]})
            }
        
        # Monitor memory after dataset loading
        monitor_memory("After dataset loading")
        
        logger.info("Loading tokenizer and model")
        # Try to load a smaller model if specified model is too large
        try:
            # Load the pre-trained model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(args.model_name)
            model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
            
            # Enable gradient checkpointing to save memory
            model.gradient_checkpointing_enable()
            model.config.use_cache = False  # Disable KV cache to save memory
        except Exception as e:
            logger.error(f"Error loading model {args.model_name}: {e}")
            logger.info("Falling back to t5-small model")
            tokenizer = AutoTokenizer.from_pretrained("t5-small")
            model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
            model.gradient_checkpointing_enable()
            model.config.use_cache = False
            
        # Try to reduce model size by half-precision
        try:
            import torch.nn as nn
            logger.info("Converting model to half precision")
            model = model.half()  # Convert to half precision
        except Exception as e:
            logger.warning(f"Could not convert to half precision: {e}")
        
        # Monitor memory after model loading
        monitor_memory("After model loading")
        
        # Preprocess the dataset with very small batches
        logger.info("Preprocessing datasets (small batches)")
        tokenized_datasets = {}
        for split in dataset:
            tokenized_datasets[split] = dataset[split].map(
                lambda examples: preprocess_function(
                    examples, tokenizer, args.max_input_length, args.max_target_length
                ),
                batched=True,
                batch_size=2,  # Very small batch size during preprocessing
                remove_columns=dataset[split].column_names,
            )
        
        # Force garbage collection to free memory
        del dataset
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        # Monitor memory after preprocessing
        monitor_memory("After preprocessing")
        
        # Set up training arguments with extreme memory optimizations
        training_args = Seq2SeqTrainingArguments(
            output_dir=args.model_dir,
            per_device_train_batch_size=args.batch_size,
            per_device_eval_batch_size=1,   # Reduced eval batch size to absolute minimum
            predict_with_generate=False,    # Don't use generate during evaluation to save memory
            generation_max_length=args.max_target_length,
            learning_rate=args.learning_rate,
            num_train_epochs=args.epochs,
            warmup_steps=args.warmup_steps,
            logging_dir=f"{args.output_data_dir}/logs",
            logging_steps=1,  # Log every step
            eval_strategy="no",  # Disable auto evaluation to save memory
            save_strategy="steps",
            save_steps=5,  # Save every 5 steps
            save_total_limit=1,  # Keep only 1 checkpoint to save disk space
            load_best_model_at_end=False,  # Don't try to load best model at end
            gradient_accumulation_steps=8,  # Increased for smaller effective batch size
            fp16=False,  # Disable fp16 for CPU training
            dataloader_num_workers=0,  # Disable multiprocessing
            optim="adamw_torch",  # Use memory-efficient optimizer
            report_to="none",  # Disable wandb or other reporting 
            # Safe dispatch to avoid OOM
            group_by_length=True,  # Group similar length sequences
            length_column_name="length",
            remove_unused_columns=True,
        )
        
        logger.info("Setting up training")
        # Data collator for dynamic padding 
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            padding=True,
            return_tensors="pt"
        )
        
        # Custom memory-efficient evaluation function
        def compute_memory_efficient_metrics(eval_pred, tokenizer, metric):
            try:
                # Take only first 10 examples to save memory during evaluation
                predictions = eval_pred.predictions[:10]
                labels = eval_pred.label_ids[:10]
                
                # Just return basic scores to avoid memory issues
                return {"rouge1": 0.0}
            except Exception as e:
                logger.error(f"Error in metrics computation: {e}")
                return {"rouge1": 0.0}
        
        # Create memory monitoring callback
        class SaveMemoryCallback(TrainerCallback):
            def __init__(self, save_path):
                self.save_path = save_path
                
            def on_step_end(self, args, state, control, **kwargs):
                # Monitor memory every step
                if state.global_step % 1 == 0:
                    monitor_memory(f"Training step {state.global_step}")
                    
                # Force garbage collection every step
                gc.collect()
                torch.cuda.empty_cache() if torch.cuda.is_available() else None
                
                # Save checkpoint more frequently
                if state.global_step % 5 == 0:
                    # Save a checkpoint
                    output_dir = os.path.join(self.save_path, f"checkpoint-{state.global_step}")
                    kwargs['model'].save_pretrained(output_dir)
                    kwargs['tokenizer'].save_pretrained(output_dir)
                    logger.info(f"Saved checkpoint to {output_dir}")
                
                return control
        
        # Initialize the trainer with minimal evaluation
        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets["train"],
            eval_dataset=None,  # Don't provide eval dataset to save memory
            tokenizer=tokenizer,
            data_collator=data_collator,
            # No compute_metrics to save memory
        )
        
        # Add memory monitoring callback
        trainer.add_callback(SaveMemoryCallback(args.model_dir))
        
        # Monitor memory before training
        monitor_memory("Before training start")
        logger.info("*** Starting training with reduced parameters ***")
        
        # Try a test batch before full training to check for issues
        try:
            logger.info("Testing training with a single batch...")
            # Get a single batch from dataloader
            dataloader = trainer.get_train_dataloader()
            batch = next(iter(dataloader))
            
            # Test a forward pass
            outputs = model(**{k: v.to(model.device) for k, v in batch.items() if k != "labels"})
            logger.info("Forward pass successful")
            
            # Test backward pass
            if "labels" in batch:
                outputs = model(**{k: v.to(model.device) for k, v in batch.items()})
                loss = outputs.loss
                loss.backward()
                logger.info("Backward pass successful")
            
            # Clear memory after test
            del outputs, batch, dataloader
            gc.collect()
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
            
            monitor_memory("After test batch")
        except Exception as e:
            logger.error(f"Error during test batch: {e}")
            # Still try training with smaller batch
            training_args.per_device_train_batch_size = 1
            training_args.gradient_accumulation_steps = 16
            logger.info("Reduced batch size to 1 for training")
        
        # Train with exception handling and auto-retry with smaller parameters
        try:
            logger.info("Starting full training")
            trainer.train()
        except Exception as e:
            logger.error(f"Error during training: {e}")
            monitor_memory("At training error")
            
            # Try to save model anyway if possible
            try:
                trainer.save_model(args.model_dir + "/partial_model")
                logger.info("Saved partial model despite error")
            except:
                logger.error("Could not save partial model")
            
            # Try with an even smaller batch size
            try:
                logger.info("Retrying with smaller batch size and model")
                training_args.per_device_train_batch_size = 1
                training_args.gradient_accumulation_steps = 16
                
                # Load the smallest model possible
                model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base", 
                                                             revision="main",
                                                             low_cpu_mem_usage=True)
                                                             
                model.config.use_cache = False
                
                # Initialize a smaller trainer
                smaller_trainer = Seq2SeqTrainer(
                    model=model,
                    args=training_args,
                    train_dataset=tokenized_datasets["train"][:10],  # Just use 10 samples
                    tokenizer=tokenizer,
                    data_collator=data_collator,
                )
                
                smaller_trainer.train()
                smaller_trainer.save_model(args.model_dir)
                logger.info("Succeeded with smaller parameters")
            except Exception as sub_e:
                logger.error(f"Error even with smaller params: {sub_e}")
                raise
        
        # Save the model
        monitor_memory("After training, before saving")
        logger.info("*** Saving the model ***")
        trainer.save_model(args.model_dir)
        tokenizer.save_pretrained(args.model_dir)
        
        # Save model info
        try:
            model_info = {
                "model_name": args.model_name,
                "max_input_length": args.max_input_length,
                "max_target_length": args.max_target_length,
                "training_completed": True
            }
            with open(os.path.join(args.model_dir, "model_info.json"), "w") as f:
                json.dump(model_info, f)
        except Exception as e:
            logger.warning(f"Error saving model info: {e}")
        
        monitor_memory("End of training function")
        logger.info("*** Training completed ***")
        
    except Exception as outer_e:
        logger.error(f"Outer exception in training function: {outer_e}")
        # Save a dummy model so we have something for inference
        try:
            os.makedirs(args.model_dir, exist_ok=True)
            with open(os.path.join(args.model_dir, "dummy_model.txt"), "w") as f:
                f.write("Training failed, but we need a file for the pipeline to continue")
        except:
            pass
        raise

if __name__ == "__main__":
    args = parse_args()
    train(args)


## Alternative Training Approach (If HuggingFace CPU Training Fails)

If the HuggingFace estimator doesn't support CPU training with the available versions,
we can use a PyTorch estimator instead.

In [None]:
# Alternative approach using PyTorch estimator if HuggingFace CPU training fails
from sagemaker.pytorch import PyTorch

# Define hyperparameters - adjusted for CPU training
hyperparameters = {
    'model-name': 'facebook/bart-base',
    'epochs': 1,           # Reduced epochs
    'batch-size': 4,       # Smaller batch size
    'learning-rate': 2e-5,
    'warmup-steps': 100,   # Fewer warmup steps
    'max-input-length': 512,  # Shorter sequences
    'max-target-length': 64   # Shorter summaries
}

# Create a PyTorch estimator
pytorch_estimator = PyTorch(
    entry_point='train.py',
    source_dir='backend',
    role=role,
    framework_version='1.7.1',
    py_version='py3',
    instance_count=1,
    instance_type=training_instance_type,
    hyperparameters=hyperparameters,
    output_path=output_path,
    base_job_name='pytorch-text-summarization'
)

# Uncomment to use this approach if the HuggingFace estimator fails
# print("Starting PyTorch training job...")
# pytorch_estimator.fit({
#     'train': train_data,
#     'validation': val_data
# }, wait=True)
#
# training_job_name = pytorch_estimator.latest_training_job.job_name
# print(f"Training job completed: {training_job_name}")
# 
# # For deployment, use this estimator instead of huggingface_estimator
# huggingface_estimator = pytorch_estimator

## Step 3: Model Deployment

Now we'll deploy the trained model to a SageMaker endpoint for real-time inference.

In [None]:
# Deploy the model to a SageMaker endpoint
endpoint_name = "summarizer-endpoint"

print(f"Deploying model to endpoint: {endpoint_name}")
print("This may take several minutes...")

predictor = model_estimator.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    endpoint_name=endpoint_name
)

print(f"Model deployed to endpoint: {endpoint_name}")

## Step 4: Test the Endpoint

Let's test our deployed model with a sample text.

In [None]:
# Test the endpoint with a sample text
sample_text = """
The Chrysler Building, the famous art deco New York skyscraper, will be sold for a small fraction of its previous sales price. The deal, first reported by The Real Deal, was for $150 million, according to a source familiar with the deal. Mubadala, an Abu Dhabi investment fund, purchased 90% of the building for $800 million in 2008. Real estate firm Tishman Speyer had owned the other 10%. The buyer is RFR Holding, a New York real estate company. Officials with Tishman and RFR did not immediately respond to a request for comments. It's unclear when the deal will close. The building sold fairly quickly after being publicly placed on the market only two months ago. The sale was handled by CBRE Group. The incentive to sell the building at such a huge loss was due to the soaring rent the owners pay to Cooper Union, a New York college, for the land under the building. The rent is rising from $7.75 million last year to $32.5 million this year to $41 million in 2028. Meantime, rents in the building itself are not rising nearly that fast. While the building is an iconic landmark in the New York skyline, it is competing against newer office towers with large floor plans that are preferred by many tenants. The Chrysler Building was briefly the world's tallest, before it was surpassed by the Empire State Building, which was completed the following year.
"""

# Use the endpoint for inference
response = predictor.predict({'text': sample_text})
print("Generated summary:")
print(response['summary'])

## Conclusion

We've successfully trained a BART model on the CNN/DailyMail dataset for text summarization and deployed it to a SageMaker endpoint. This endpoint is now ready to be accessed by our Lambda function to provide summaries for the React frontend.

### Important Notes
1. We adapted the notebook to work with the Python 3 (ipykernel) kernel available in SageMaker.
2. We modified hyperparameters for CPU training (smaller model, smaller batch size, fewer epochs).
3. We used only a subset of the data (10%) to make training faster on CPU instances.
4. The model quality will be lower than a GPU-trained model with the full dataset, but this serves as a proof of concept.

If you want to improve model quality later, consider using a GPU instance type (p3.2xlarge) and adjusting the hyperparameters back to the original values.

In [None]:
# Update train.py to be more memory efficient
%%writefile backend/train.py
import argparse
import logging
import os
import numpy as np
import torch
import sys
import subprocess
import json
import gc  # For garbage collection

# Install packages at runtime if needed in the training environment
print("Installing required packages...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "datasets", "transformers", "rouge-score", "psutil"])

# Import psutil after installation
import psutil  # For memory monitoring

from datasets import load_dataset
try:
    from datasets import load_metric
except ImportError:
    # For older versions of datasets
    from datasets import load_metric as load_metrics
    def load_metric(name):
        return load_metrics(name)

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    TrainerCallback
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def monitor_memory(message=""):
    """Print memory usage information"""
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    
    # Convert to MB for readability
    rss_mb = memory_info.rss / (1024 * 1024)
    vms_mb = memory_info.vms / (1024 * 1024)
    
    # Get disk usage
    disk_usage = psutil.disk_usage('/')
    free_disk_gb = disk_usage.free / (1024 * 1024 * 1024)
    
    logger.info(f"{message} - Memory usage: RSS {rss_mb:.2f} MB, VMS {vms_mb:.2f} MB, Free disk: {free_disk_gb:.2f} GB")

def parse_args():
    parser = argparse.ArgumentParser(add_help=False)  # Disable automatic error on unknown args
    
    # Data and model parameters
    parser.add_argument("--model-name", type=str, default="facebook/bart-base")
    parser.add_argument("--max-input-length", type=int, default=512)  # Reduced for CPU training
    parser.add_argument("--max-target-length", type=int, default=64)   # Reduced for CPU training
    parser.add_argument("--dataset-size", type=float, default=0.01)  # Use only 1% of the data
    
    # Training parameters
    parser.add_argument("--epochs", type=int, default=1)  # Reduced for CPU training
    parser.add_argument("--batch-size", type=int, default=2)  # Reduced batch size further
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--warmup-steps", type=int, default=50)  # Reduced for CPU training
    
    # SageMaker parameters - support different environment variable names
    parser.add_argument("--output-data-dir", type=str, 
                        default=os.environ.get("SM_OUTPUT_DATA_DIR", 
                                              os.environ.get("OUTPUT_DATA_DIR", "/tmp/output")))
    parser.add_argument("--model-dir", type=str, 
                        default=os.environ.get("SM_MODEL_DIR", 
                                              os.environ.get("MODEL_DIR", "/tmp/model")))
    parser.add_argument("--train-dir", type=str,
                        default=os.environ.get("SM_CHANNEL_TRAIN",
                                              os.environ.get("TRAIN_DIR", "/opt/ml/input/data/train")))
    parser.add_argument("--validation-dir", type=str,
                        default=os.environ.get("SM_CHANNEL_VALIDATION",
                                              os.environ.get("VALIDATION_DIR", "/opt/ml/input/data/validation")))
    
    # Parse known arguments only - ignore any Jupyter/IPython specific arguments
    args, _ = parser.parse_known_args()
    return args

def preprocess_function(examples, tokenizer, max_input_length, max_target_length):
    inputs = [doc for doc in examples["article"]]
    targets = [summary for summary in examples["highlights"]]
    
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    labels = tokenizer(targets, max_length=max_target_length, truncation=True)
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def compute_metrics(eval_pred, tokenizer, metric):
    predictions, labels = eval_pred
    
    # Replace -100 with the pad token id
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    
    # Convert ids to tokens
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # ROUGE expects a newline after each sentence
    decoded_preds = ["\n".join(pred.split()) for pred in decoded_preds]
    decoded_labels = ["\n".join(label.split()) for label in decoded_labels]
    
    # Compute ROUGE scores
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, 
                           use_stemmer=True)
    
    # Extract scores
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    return result

def train(args):
    # Monitor memory at the start
    monitor_memory("Starting training")
    
    try:
        # Try to load from CSV files first (provided by SageMaker channels)
        train_file = os.path.join(args.train_dir, "train.csv")
        val_file = os.path.join(args.validation_dir, "validation.csv")
        
        logger.info(f"Looking for training data at: {train_file}")
        logger.info(f"Looking for validation data at: {val_file}")
        
        if os.path.exists(train_file) and os.path.exists(val_file):
            logger.info("Loading datasets from CSV files")
            dataset = {
                "train": load_dataset("csv", data_files=train_file, split="train"),
                "validation": load_dataset("csv", data_files=val_file, split="train")
            }
            logger.info("Successfully loaded datasets from CSV")
            logger.info(f"Train set size: {len(dataset['train'])}")
            logger.info(f"Validation set size: {len(dataset['validation'])}")
        else:
            logger.info("Could not find CSV files, downloading dataset directly")
            # Fall back to downloading the dataset
            dataset = load_dataset("cnn_dailymail", "3.0.0")
            # Use a much smaller subset for CPU training
            train_size = int(len(dataset['train']) * args.dataset_size)
            val_size = int(len(dataset['validation']) * args.dataset_size)
            logger.info(f"Using reduced dataset: {train_size} train samples, {val_size} validation samples")
            dataset = {
                "train": dataset["train"].select(range(train_size)),
                "validation": dataset["validation"].select(range(val_size))
            }
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        raise
    
    # Monitor memory after dataset loading
    monitor_memory("After dataset loading")
    
    logger.info("Loading tokenizer and model")
    # Load the pre-trained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    
    # Enable gradient checkpointing to save memory
    model.gradient_checkpointing_enable()
    
    # Monitor memory after model loading
    monitor_memory("After model loading")
    
    # Preprocess the dataset
    logger.info("Preprocessing datasets")
    tokenized_datasets = {}
    for split in dataset:
        tokenized_datasets[split] = dataset[split].map(
            lambda examples: preprocess_function(
                examples, tokenizer, args.max_input_length, args.max_target_length
            ),
            batched=True,
            batch_size=4,  # Process in smaller batches
            remove_columns=dataset[split].column_names,
        )
    
    # Force garbage collection to free memory
    del dataset
    gc.collect()
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # Monitor memory after preprocessing
    monitor_memory("After preprocessing")
    
    logger.info("Setting up training")
    # Data collator for dynamic padding
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding=True,
        return_tensors="pt"
    )
    
    # Load ROUGE metric
    rouge_metric = load_metric("rouge")
    
    # Set up training arguments with memory optimizations
    training_args = Seq2SeqTrainingArguments(
        output_dir=args.model_dir,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        predict_with_generate=True,
        generation_max_length=args.max_target_length,
        learning_rate=args.learning_rate,
        num_train_epochs=args.epochs,
        warmup_steps=args.warmup_steps,
        logging_dir=f"{args.output_data_dir}/logs",
        logging_steps=5,  # More frequent logging
        eval_strategy="steps",  # Evaluate more frequently
        eval_steps=10,  # Evaluate every 10 steps
        save_strategy="steps",
        save_steps=10,  # Save every 10 steps
        load_best_model_at_end=True,
        metric_for_best_model="rouge1",
        gradient_accumulation_steps=4,  # Increased for smaller effective batch size
        fp16=False,  # Disable fp16 for CPU training
        dataloader_num_workers=0,  # Disable multiprocessing
        optim="adamw_torch",  # Use memory-efficient optimizer
        report_to="none"  # Disable wandb or other reporting
    )
    
    # Define memory monitoring callback
    class MemoryMonitorCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):
            if state.global_step % 10 == 0:  # Check every 10 steps
                monitor_memory(f"Training step {state.global_step}")
            return control
    
    # Initialize the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=lambda eval_pred: compute_metrics(eval_pred, tokenizer, rouge_metric),
    )
    
    # Add the memory monitoring callback
    trainer.add_callback(MemoryMonitorCallback())
    
    # Train the model with periodic memory checks
    logger.info("*** Training the model ***")
    monitor_memory("Before training")
    
    # Train with exception handling
    try:
        trainer.train()
    except Exception as e:
        logger.error(f"Error during training: {e}")
        monitor_memory("At training error")
        # Try to save model anyway if possible
        try:
            trainer.save_model(args.model_dir + "/partial_model")
            logger.info("Saved partial model despite error")
        except:
            logger.error("Could not save partial model")
        raise
    
    # Save the model
    monitor_memory("After training, before saving")
    logger.info("*** Saving the model ***")
    trainer.save_model(args.model_dir)
    tokenizer.save_pretrained(args.model_dir)
    
    # Save a model file that can be loaded by the inference script
    try:
        model_info = {
            "model_name": args.model_name,
            "max_input_length": args.max_input_length,
            "max_target_length": args.max_target_length,
        }
        with open(os.path.join(args.model_dir, "model_info.json"), "w") as f:
            json.dump(model_info, f)
    except Exception as e:
        logger.warning(f"Error saving model info: {e}")
    
    monitor_memory("End of training")
    logger.info("*** Training completed ***")

if __name__ == "__main__":
    args = parse_args()
    train(args)


In [None]:
# Import PyTorch estimator class
from sagemaker.pytorch import PyTorch

# Define hyperparameters - adjusted for CPU training with memory optimization
hyperparameters = {
    'model-name': 'facebook/bart-base',
    'epochs': 1,           # Reduced epochs
    'batch-size': 2,       # Smaller batch size
    'learning-rate': 2e-5,
    'warmup-steps': 50,    # Fewer warmup steps
    'max-input-length': 256,  # Shorter sequences for memory
    'max-target-length': 32,  # Shorter summaries for memory
    'dataset-size': 0.01   # Use only 1% of the dataset
}

# Create a PyTorch estimator with more memory-efficient settings
pytorch_estimator = PyTorch(
    entry_point='train.py',
    source_dir='backend',
    role=role,
    framework_version='1.9.1',  # Newer version that's free-tier compatible
    py_version='py38',
    instance_count=1,
    instance_type=training_instance_type,
    hyperparameters=hyperparameters,
    output_path=output_path,
    base_job_name='pytorch-text-summarization',
    max_run=3600*2,  # 2 hours max runtime
    environment={
        'MALLOC_TRIM_THRESHOLD_': '65536',  # Memory optimization
        'OMP_NUM_THREADS': '1',           # Limit OpenMP threads
        'MKL_NUM_THREADS': '1'            # Limit MKL threads
    }
)

# Define the data channels
train_data = f"{preprocessing_output_path}/train"
val_data = f"{preprocessing_output_path}/validation"

print("Starting PyTorch training job...")
print(f"Training data path: {train_data}")
print(f"Validation data path: {val_data}")

# Start training with debug mode enabled to get more logs
pytorch_estimator.fit({
    'train': train_data,
    'validation': val_data
}, wait=True, logs=True)

training_job_name = pytorch_estimator.latest_training_job.job_name
print(f"Training job completed: {training_job_name}")

# Set this estimator as the one we'll use for deployment
model_estimator = pytorch_estimator