## Fine-Tuning and Evaluating LLMs with SageMaker Pipelines and MLflow

Running hundreds of experiments, comparing their results, and keeping track of the ML lifecycle quickly becomes complex. MLflow helps streamline that lifecycle, from data preparation to model deployment. By integrating MLflow into your LLM workflow, you can manage experiment tracking, model versioning, and deployment in one place, which supports reproducibility. With MLflow, you can track and compare the performance of multiple LLM experiments, identify the best-performing models, and deploy them to production environments with confidence.

You can create workflows with SageMaker Pipelines that enable you to prepare data, fine-tune models, and evaluate model performance with simple Python code for each step. 

Now you can use SageMaker managed MLflow to run LLM fine-tuning and evaluation experiments at scale. Specifically:

- MLflow can manage tracking of fine-tuning experiments, comparing evaluation results of different runs, model versioning, deployment, and configuration (such as data and hyperparameters)
- SageMaker Pipelines can orchestrate multiple experiments based on the experiment configuration 
  

The following figure shows an overview of the solution.
![](./ml-16670-arch-with-mlflow.png)

## Prerequisites 
Before you begin, make sure you have the following prerequisites in place:

- [Hugging Face access token](https://huggingface.co/docs/hub/en/security-tokens): You need a Hugging Face access token to download the DeepSeek-R1-Distill-Llama-8B model and the datasets used in this post.

- The notebook downloads the DeepSeek-R1-Distill-Llama-8B model from Hugging Face and uploads it to your S3 bucket for fine-tuning.

### 1. Setup and Dependencies
Run the two cells below to install the dependencies and then restart the kernel.

In [None]:
%pip install -r ./scripts/requirements.txt --upgrade --quiet

In [None]:
from IPython import get_ipython
get_ipython().kernel.do_shutdown(True)

**Importing Libraries and Setting Up Environment**

This part imports all necessary Python modules. It includes SageMaker-specific imports for pipeline creation and execution, which will be used to define the pipeline steps.

In [None]:
import os
import boto3
import sagemaker
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.function_step import step
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
from sagemaker.workflow.fail_step import FailStep
from sagemaker.workflow.steps import CacheConfig

### 2. SageMaker Session and IAM Role

`get_execution_role()`: Retrieves the IAM role that SageMaker will use to access AWS resources. This role needs appropriate permissions for tasks like accessing S3 buckets and creating SageMaker resources.

In [None]:
sagemaker_session = sagemaker.session.Session()
role = sagemaker.get_execution_role()
instance_type = "ml.m5.xlarge"

### 3. Configuration

MLflow integration is crucial for experiment tracking and management. **Update the name of the MLflow tracking server in the cell below if yours differs.**

- `tracking_server_arn`: the ARN of the MLflow tracking server. The cell below looks it up by server name; you can also copy the ARN from the SageMaker Studio UI. It lets the pipeline log metrics, parameters, and artifacts to a central location.

- `experiment_base_name`: a descriptive base name for the experiment. The preprocessing step appends a timestamp to it.

In [None]:
bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
if default_prefix:
    input_path = f'{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft'
else:
    input_path = f'datasets/llm-fine-tuning-modeltrainer-sft'

In [None]:
from botocore.exceptions import ClientError

try:
    response = boto3.client('sagemaker').describe_mlflow_tracking_server(
        TrackingServerName='genai-mlflow-tracker'
    )
    mlflow_tracking_server_uri = response['TrackingServerArn']
except ClientError:
    mlflow_tracking_server_uri = ""

if mlflow_tracking_server_uri == "":
    print("No MLflow Tracking Server Found, experiments will not be tracked.")
else:
    print(f"MLflow Tracking Server ARN: {mlflow_tracking_server_uri}")
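
The pipeline steps later treat an empty string as "tracking disabled", so it can help to check the looked-up value before passing it on. The following is a minimal sketch of such a guard. The helper name `looks_like_tracking_server_arn` and the regular expression are assumptions made for illustration (they are not part of the SageMaker SDK), and the account ID in the example is a placeholder.

```python
import re

# Assumed ARN shape:
# arn:<partition>:sagemaker:<region>:<account-id>:mlflow-tracking-server/<name>
_TRACKING_SERVER_ARN = re.compile(
    r"^arn:aws[a-z-]*:sagemaker:[a-z0-9-]+:\d{12}:mlflow-tracking-server/[A-Za-z0-9-]+$"
)

def looks_like_tracking_server_arn(value):
    """Return True when value is a string shaped like a tracking server ARN."""
    return isinstance(value, str) and bool(_TRACKING_SERVER_ARN.match(value))

# Placeholder ARN used only to exercise the check
example_arn = (
    "arn:aws:sagemaker:us-east-1:123456789012:"
    "mlflow-tracking-server/genai-mlflow-tracker"
)
print(looks_like_tracking_server_arn(example_arn))
print(looks_like_tracking_server_arn(""))
```

A check like this only validates the shape of the string; it does not confirm that the server exists or is running.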

In [None]:
# Base names; the preprocessing step appends a timestamp to build a unique experiment name

pipeline_name = "deepseek-finetune-pipeline"
experiment_base_name = "deepseek-finetune-pipeline"


tracking_server_arn = mlflow_tracking_server_uri # Set the MLflow ARN here
# The args.yaml template created later reads these environment variables; uncomment to populate them
# os.environ["mlflow_uri"] = tracking_server_arn
# os.environ["mlflow_experiment_name"] = experiment_base_name

model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_id_filesafe = model_id.replace("/","_")

In [None]:
%%writefile config.yaml
SchemaVersion: '1.0'
SageMaker:
  PythonSDK:
    Modules:
      RemoteFunction:
        # role arn is not required if in SageMaker Notebook instance or SageMaker Studio
        # Uncomment the following line and replace with the right execution role if in a local IDE
        # RoleArn: <replace the role arn here>
        InstanceType: ml.m5.xlarge
        Dependencies: ./scripts/requirements.txt
        IncludeLocalWorkDir: true
        CustomFileFilter:
          IgnoreNamePatterns: # files or directories to ignore
          - "*.ipynb" # all notebook files


In [None]:
# Set path to config file
os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = os.getcwd()

### Download Model Data from Hugging Face

In [None]:
import os

import boto3
from botocore.exceptions import ClientError
from huggingface_hub import snapshot_download

# Simple function to check if file exists in S3
def s3_file_exists(s3_client, bucket, key):
    try:
        s3_client.head_object(Bucket=bucket, Key=key)
        return True
    except ClientError:
        return False

# Simple S3 upload function that checks if files exist before uploading
def simple_s3_upload(local_dir, s3_bucket, s3_prefix, skip_existing=True):
    """
    Upload files to S3, skipping files that already exist.
    
    Args:
        local_dir (str): Local directory containing files to upload
        s3_bucket (str): S3 bucket name
        s3_prefix (str): S3 prefix (folder path)
        skip_existing (bool): Whether to skip files that already exist in S3
        
    Returns:
        tuple: (uploaded_files, skipped_files, failed_files)
    """
    s3_client = boto3.client('s3')
    uploaded_files = []
    skipped_files = []
    failed_files = []
    
    # Get all local files
    local_files = []
    for root, _, files in os.walk(local_dir):
        for filename in files:
            local_path = os.path.join(root, filename)
            rel_path = os.path.relpath(local_path, local_dir)
            s3_key = os.path.join(s3_prefix, rel_path).replace('\\', '/')
            local_files.append((local_path, s3_key))
    
    print(f"Found {len(local_files)} files in {local_dir}")
    
    # Process each file sequentially
    for local_path, s3_key in local_files:
        try:
            # Check if file exists in S3
            if skip_existing and s3_file_exists(s3_client, s3_bucket, s3_key):
                print(f"Skipping {s3_key} (file exists in S3)")
                skipped_files.append(s3_key)
                continue
            
            # Upload the file
            print(f"Uploading {local_path} to s3://{s3_bucket}/{s3_key}")
            s3_client.upload_file(
                local_path, 
                s3_bucket, 
                s3_key,
                ExtraArgs={'ACL': 'bucket-owner-full-control'}
            )
            uploaded_files.append(s3_key)
            
        except Exception as e:
            print(f"Failed to upload {local_path}: {str(e)}")
            failed_files.append((s3_key, str(e)))
    
    print(f"\nUpload Summary:")
    print(f"  - Uploaded: {len(uploaded_files)} files")
    print(f"  - Skipped: {len(skipped_files)} files")
    print(f"  - Failed: {len(failed_files)} files")
    
    return uploaded_files, skipped_files, failed_files

# Set local and S3 model paths
model_local_location = f"../models/{model_id_filesafe}"
# S3 keys have no leading slash; this prefix is the key path inside model_s3_destination
if default_prefix:
    prefix = f"{default_prefix}/models/{model_id_filesafe}"
else:
    prefix = f"models/{model_id_filesafe}"
model_s3_destination = f"s3://{bucket_name}/{prefix}"

print("Downloading model ", model_id)
os.makedirs(model_local_location, exist_ok=True)

try:
    snapshot_download(repo_id=model_id, local_dir=model_local_location)
    print(f"Model {model_id} downloaded under {model_local_location}")
    
    print(f"Beginning Model Upload to {model_s3_destination}...")
    
    # Upload sequentially under the model prefix, skipping files already in S3
    uploaded, skipped, failed = simple_s3_upload(
        local_dir=model_local_location,
        s3_bucket=bucket_name,
        s3_prefix=prefix,
        skip_existing=True
    )

    # Do not report success if any file failed to upload
    if failed:
        raise RuntimeError(f"{len(failed)} files failed to upload to {model_s3_destination}")
    
    print(f"Model successfully uploaded to: \n {model_s3_destination}")
except Exception as e:
    print(f"Error during model download or upload: {e}")
    raise

os.environ["model_location"] = model_s3_destination
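
The upload helper turns each local file path into an S3 object key by joining a prefix with the file's path relative to the model directory. The following minimal sketch isolates that mapping so it can be checked without AWS access. The function name `local_files_to_s3_keys` and the `models/demo_model` prefix are hypothetical, chosen only for illustration.

```python
import os
import posixpath
import tempfile

def local_files_to_s3_keys(local_dir, s3_prefix):
    """Map every file under local_dir to an S3 key below s3_prefix."""
    mapping = {}
    # S3 keys should not start or end with a slash
    clean_prefix = s3_prefix.strip("/")
    for root, _, files in os.walk(local_dir):
        for filename in files:
            local_path = os.path.join(root, filename)
            # Use forward slashes regardless of the local operating system
            rel_path = os.path.relpath(local_path, local_dir).replace(os.sep, "/")
            if clean_prefix:
                mapping[local_path] = posixpath.join(clean_prefix, rel_path)
            else:
                mapping[local_path] = rel_path
    return mapping

# Build a tiny throwaway directory tree to demonstrate the mapping
with tempfile.TemporaryDirectory() as tmp_dir:
    os.makedirs(os.path.join(tmp_dir, "tokenizer"))
    for name in ("config.json", os.path.join("tokenizer", "vocab.json")):
        with open(os.path.join(tmp_dir, name), "w") as f:
            f.write("{}")
    keys = sorted(local_files_to_s3_keys(tmp_dir, "/models/demo_model/").values())

print(keys)
```

Stripping slashes from the prefix matters because a key that starts with `/` creates an empty top-level "folder" in the bucket rather than the intended path.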

### 6. Pipeline Steps

This section defines the core components of the SageMaker pipeline.

**Preprocessing Step**

This step handles data preparation. It prepares the training and evaluation datasets and logs their details to MLflow.

For fine-tuning and evaluation, we use the `FreedomIntelligence/medical-o1-reasoning-SFT` dataset.
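
The preprocessing step takes the first 100 records and holds out 10% of them for evaluation. The following is a minimal standard-library sketch of a seeded 90/10 split. It is a stand-in for illustration only: `split_records` is a hypothetical helper, the records are synthetic, and the resulting partition will not match the one produced by scikit-learn's `train_test_split`.

```python
import random

def split_records(records, test_fraction=0.1, seed=42):
    """Shuffle a copy of records with a fixed seed and split off a test set."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, round(len(shuffled) * test_fraction))
    return shuffled[n_test:], shuffled[:n_test]

# Synthetic records that mirror the dataset's column names
records = [
    {"Question": f"q{i}", "Complex_CoT": f"cot{i}", "Response": f"a{i}"}
    for i in range(100)
]
train_records, test_records = split_records(records)
print(len(train_records), len(test_records))
```

Fixing the seed keeps the split stable across pipeline runs, so evaluation results from different experiments remain comparable.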

In [None]:
@step(
    name="DataPreprocessing",
    instance_type=instance_type,
    display_name="Data Preprocessing",
    keep_alive_period_in_seconds=900
)
def preprocess(
    tracking_server_arn: str,
    input_path: str,
    experiment_base_name: str,
    run_id: str,
) -> tuple:
    import boto3
    import shutil
    import sagemaker
    import os
    import pandas as pd
    from sagemaker.config import load_sagemaker_config
    import mlflow
    import traceback
    from datetime import datetime
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    experiment_name = f"{experiment_base_name}-{timestamp}"  # Unique experiment name for each run
    
    # Initialize return values with defaults to ensure pipeline failure doesn't cause hard errors
    result_experiment_name = experiment_name
    result_run_id = run_id
    result_train_data_path = None
    result_test_dataset_path = None
    mlflow_run = None
    mlflow_enabled = False
    
    # Check if MLflow tracking server ARN is valid
    mlflow_enabled = (
        tracking_server_arn is not None
        and experiment_name is not None
        and tracking_server_arn != ""
        and experiment_name != ""
    )
    
    if mlflow_enabled:
        print(f"MLflow tracking enabled. Using server: {tracking_server_arn}")
        print(f"MLflow experiment: {experiment_name}")
        try:
            mlflow.set_tracking_uri(tracking_server_arn)
            mlflow.set_experiment(experiment_name)
            mlflow.autolog(log_datasets=True)
            mlflow_run = mlflow.start_run(run_name=f"preprocess-{run_id}")
            active_run_id = mlflow_run.info.run_id
            print(f"Started MLflow run with ID: {active_run_id}")
        except Exception as e:
            error_msg = f"Error initializing MLflow tracking: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            # Fall back to running without tracking; preprocessing can still succeed
            mlflow_enabled = False
            mlflow_run = None
    else:
        print("MLflow tracking disabled or not configured properly.")
        mlflow_run = None
    
    # Preprocessing code - runs regardless of MLflow status
    try:
        # Initialize SageMaker and S3 clients
        sagemaker_session = sagemaker.Session()
        s3_client = boto3.client('s3')
        
        bucket_name = sagemaker_session.default_bucket()
        default_prefix = sagemaker_session.default_bucket_prefix
        configs = load_sagemaker_config()
        
        from datasets import load_dataset
        
        # Load dataset with proper error handling
        try:
            dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")
            if mlflow_enabled:
                mlflow.log_param("dataset_load_success", True)
        except Exception as e:
            error_msg = f"Error loading dataset: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            if mlflow_enabled:
                mlflow.log_param("dataset_load_success", False)
                mlflow.log_param("dataset_load_error", error_msg)
            raise RuntimeError(f"Failed to load dataset: {str(e)}")
        
        df = pd.DataFrame(dataset['train'])
        df = df[:100]
        
        from sklearn.model_selection import train_test_split
        
        # Split dataset
        train, test = train_test_split(df, test_size=0.1, random_state=42, shuffle=True)
        
        print("Number of train elements: ", len(train))
        print("Number of test elements: ", len(test))
        
        # Log dataset statistics if MLflow is enabled
        if mlflow_enabled:
            mlflow.log_param("dataset_source", "FreedomIntelligence/medical-o1-reasoning-SFT")
            mlflow.log_param("train_size", len(train))
            mlflow.log_param("test_size", len(test))
            mlflow.log_param("dataset_sample_size", 100)  # Log that we're using a subset of 100 samples
        
        # Define prompt template (plain string; placeholders are filled by str.format below)
        prompt_template = """
        <|begin_of_text|>
        <|start_header_id|>system<|end_header_id|>
        You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
        Below is an instruction that describes a task, paired with an input that provides further context. 
        Write a response that appropriately completes the request.
        Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        {question}<|eot_id|>
        <|start_header_id|>assistant<|end_header_id|>
        {complex_cot}
        
        {answer}
        <|eot_id|>
        """
        
        # Template dataset to add prompt to each sample
        def template_dataset(sample):
            try:
                sample["text"] = prompt_template.format(question=sample["Question"],
                                                        complex_cot=sample["Complex_CoT"],
                                                        answer=sample["Response"])
                return sample
            except KeyError as e:
                print(f"KeyError in template_dataset: {str(e)}")
                # Provide default values for missing fields
                missing_key = str(e).strip("'")
                if missing_key == "Question":
                    sample["text"] = prompt_template.format(
                        question="[Missing question]",
                        complex_cot=sample.get("Complex_CoT", "[Missing CoT]"),
                        answer=sample.get("Response", "[Missing response]")
                    )
                elif missing_key == "Complex_CoT":
                    sample["text"] = prompt_template.format(
                        question=sample["Question"],
                        complex_cot="[Missing CoT]",
                        answer=sample.get("Response", "[Missing response]")
                    )
                elif missing_key == "Response":
                    sample["text"] = prompt_template.format(
                        question=sample["Question"],
                        complex_cot=sample.get("Complex_CoT", "[Missing CoT]"),
                        answer="[Missing response]"
                    )
                return sample
        
        from datasets import Dataset, DatasetDict
        from random import randint
        
        # Create datasets
        train_dataset = Dataset.from_pandas(train)
        test_dataset = Dataset.from_pandas(test)
        
        dataset = DatasetDict({"train": train_dataset, "test": test_dataset})
        
        train_dataset = dataset["train"].map(template_dataset, remove_columns=list(dataset["train"].features))
        
        # Safely get a sample text, handling potential index errors
        try:
            sample_index = randint(0, len(train_dataset) - 1)
            sample_text = train_dataset[sample_index]["text"]
            print(f"Sample text from index {sample_index}:")
            print(sample_text)
        except (IndexError, KeyError) as e:
            sample_text = "Error retrieving sample text: " + str(e)
            print(sample_text)
            
        test_dataset = dataset["test"].map(template_dataset, remove_columns=list(dataset["test"].features))
        
        # Set paths
        if default_prefix:
            input_path = f'{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft'
        else:
            input_path = f'datasets/llm-fine-tuning-modeltrainer-sft'
    
        # Create directories with error handling
        try:
            os.makedirs("./data/train", exist_ok=True)
            os.makedirs("./data/test", exist_ok=True)
        except OSError as e:
            error_msg = f"Error creating directories: {str(e)}"
            print(error_msg)
            if mlflow_enabled:
                mlflow.log_param("dir_creation_error", error_msg)
            # Continue with execution as we'll try to save files anyway
        
        # Save datasets locally with error handling
        try:
            train_dataset.to_json("./data/train/dataset.json", orient="records")
            test_dataset.to_json("./data/test/dataset.json", orient="records")
        except Exception as e:
            error_msg = f"Error saving datasets locally: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            if mlflow_enabled:
                mlflow.log_param("local_save_error", error_msg)
            raise RuntimeError(f"Failed to save datasets locally: {str(e)}")
        
        # Define S3 paths
        train_data_path = f"s3://{bucket_name}/{input_path}/train/dataset.json"
        test_dataset_path = f"s3://{bucket_name}/{input_path}/test/dataset.json"
        
        # Store results for return
        result_train_data_path = train_data_path
        result_test_dataset_path = test_dataset_path
        
        # Log dataset paths if MLflow is enabled
        if mlflow_enabled:
            mlflow.log_param("train_data_path", train_data_path)
            mlflow.log_param("test_dataset_path", test_dataset_path)
        
        # Upload files to S3 with retries
        max_retries = 3
        for attempt in range(max_retries):
            try:
                print(f"Uploading train dataset to S3, attempt {attempt+1}/{max_retries}")
                s3_client.upload_file("./data/train/dataset.json", bucket_name, f"{input_path}/train/dataset.json")
                print(f"Uploading test dataset to S3, attempt {attempt+1}/{max_retries}")
                s3_client.upload_file("./data/test/dataset.json", bucket_name, f"{input_path}/test/dataset.json")
                print("S3 upload successful")
                break
            except Exception as e:
                error_msg = f"Error in S3 upload (attempt {attempt+1}/{max_retries}): {str(e)}"
                print(error_msg)
                if attempt == max_retries - 1:  # Last attempt failed
                    if mlflow_enabled:
                        mlflow.log_param("s3_upload_error", error_msg)
                    raise RuntimeError(f"Failed to upload datasets to S3 after {max_retries} attempts: {str(e)}")
        
        print(f"Datasets uploaded to:")
        print(train_data_path)
        print(test_dataset_path)
        
        # Log a sample of the dataset as an artifact if MLflow is enabled
        if mlflow_enabled:
            try:
                with open("./data/sample.txt", "w") as f:
                    f.write(sample_text)
                mlflow.log_artifact("./data/sample.txt", "dataset_samples")
            except Exception as e:
                print(f"Error logging sample as artifact: {str(e)}")
        
        # Clean up
        try:
            if os.path.exists("./data"):
                shutil.rmtree("./data")
        except Exception as e:
            print(f"Warning: Error cleaning up temporary files: {str(e)}")

        # End MLflow run if it was started
        if mlflow_enabled and mlflow_run:
            try:
                mlflow.end_run()
            except Exception as e:
                print(f"Error ending MLflow run: {str(e)}")
        
    except Exception as e:
        error_msg = f"Critical error in preprocessing: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        
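        if mlflow_enabled:
            try:
                # Truncate so the value stays within MLflow's parameter length limit
                mlflow.log_param("critical_error", error_msg[:500])
                mlflow.end_run(status="FAILED")
            except Exception:
                pass
        
        raise RuntimeError(f"Preprocessing failed: {str(e)}") from e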
        

    return result_experiment_name, result_run_id, result_train_data_path, result_test_dataset_path
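
The templating logic in the step above fills a Llama-style chat template from three dataset columns and substitutes placeholder text when a column is missing. A compact sketch of the same idea follows. `PROMPT_TEMPLATE`, `render_sample`, the shortened system prompt, and the sample record are all illustrative assumptions rather than the notebook's exact code.

```python
# Shortened template using the same special tokens as the preprocessing step
PROMPT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
    "{system}<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
    "{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    "{complex_cot}\n\n{answer}<|eot_id|>"
)

def render_sample(sample, system="You are a medical expert."):
    """Fill the template, using placeholder text for any missing column."""
    return PROMPT_TEMPLATE.format(
        system=system,
        question=sample.get("Question", "[Missing question]"),
        complex_cot=sample.get("Complex_CoT", "[Missing CoT]"),
        answer=sample.get("Response", "[Missing response]"),
    )

# Hypothetical record with the chain-of-thought column missing
sample = {"Question": "What causes scurvy?", "Response": "Vitamin C deficiency."}
text = render_sample(sample)
print(text)
```

Using `dict.get` with defaults covers the same missing-column cases as the `KeyError` branches in the step, in a single code path.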

**Training Configuration**

The `train_config` dictionary is comprehensive, including:

- Experiment naming for tracking purposes
- Model specifications (ID, version, name)
- Infrastructure details (instance types and counts for fine-tuning and deployment)
- Training hyperparameters (epochs, batch size)

This configuration allows for easy adjustment of the training process without changing the core pipeline code.

**LoRA Parameters**

Low-Rank Adaptation (LoRA) is an efficient fine-tuning technique that reduces the number of trainable parameters by adding low-rank decomposition matrices to existing weights rather than updating all model weights. This significantly reduces memory requirements and training time while maintaining performance comparable to full fine-tuning.
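
To make the parameter savings concrete, consider a single weight matrix of shape `d_out x d_in`. Full fine-tuning updates all `d_out * d_in` entries, whereas LoRA trains two factors of shapes `d_out x r` and `r x d_in`, for `r * (d_out + d_in)` parameters. The sketch below works through the arithmetic for a hypothetical 4096 x 4096 projection with rank 8; the dimensions are chosen for illustration and are not read from the model.

```python
def lora_trainable_params(d_out, d_in, rank):
    """Parameters in the two LoRA factors B (d_out x r) and A (r x d_in)."""
    return rank * (d_out + d_in)

# Hypothetical projection size, for illustration only
d_out = d_in = 4096
rank = 8

full = d_out * d_in
lora = lora_trainable_params(d_out, d_in, rank)
print(full, lora, f"{lora / full:.4%}")
```

Under these assumptions the adapter trains 65,536 parameters for this matrix instead of 16,777,216, which is about 0.39% of the full count.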

In [None]:
%%bash

cat > ./args.yaml <<EOF

# MLflow Config
mlflow_uri: "${mlflow_uri}"                # The URI for the MLflow tracking server 
mlflow_experiment_name: "${mlflow_experiment_name}"  # Name of the MLflow experiment for organizing runs


model_id: "${model_location}"              # Hugging Face model id, or S3 location of base model

# SageMaker specific parameters 
output_dir: "/opt/ml/model"                # Path whose contents SageMaker uploads to S3 as the model artifact
train_dataset_path: "/opt/ml/input/data/train/"   # Path where SageMaker makes the train channel available
test_dataset_path: "/opt/ml/input/data/test/"     # Path where SageMaker makes the test channel available

# Training parameters
max_seq_length: 1500                       # Maximum sequence length for inputs (affects memory usage)
                                           # Higher values allow for longer context but require more memory
                                           # Range: 512-4096 depending on model architecture and hardware

# LoRA parameters (Low-Rank Adaptation)
lora_r: 8                                  # Rank of the LoRA update matrices
                                           # Lower values (4-16) are more efficient, higher values (32-64) can improve quality
                                           # Recommended range: 8-64 depending on task complexity
lora_alpha: 16                             # Scaling factor for the LoRA update
                                           # Generally set to 2x lora_r for good performance
lora_dropout: 0.1                          # Dropout probability for LoRA layers
                                           # Range: 0.0-0.5, helps prevent overfitting

# Optimizer parameters
learning_rate: 2e-4                        # Learning rate for parameter updates
                                           # Range: 1e-5 to 5e-4 for LoRA fine-tuning
                                           # Too high: training instability, too low: slow convergence

# Training loop parameters
num_train_epochs: 1                        # Number of complete passes through the training dataset
                                           # More epochs can improve performance but risk overfitting
                                           # Range: 1-5 for LoRA fine-tuning
per_device_train_batch_size: 2             # Number of samples per GPU during training
                                           # Larger values improve training speed but require more memory
                                           # Range: 1-8 for large models on common GPUs
per_device_eval_batch_size: 1              # Number of samples per GPU during evaluation
                                           # Can typically be larger than training batch size
gradient_accumulation_steps: 2             # Accumulate gradients over multiple steps
                                           # Effectively increases batch size by this factor
                                           # Useful when limited by GPU memory

# Memory optimization techniques
gradient_checkpointing: true               # Reduces memory usage by recomputing activations during backward pass
                                           # Trades computation for memory, ~20% slower but enables larger models/sequences
fp16: true                                 # Use half-precision floating point (speeds up training, reduces memory)
bf16: false                                # Use bfloat16 precision (better numerical stability than fp16)
                                           # Also enables FlashAttention2 (requires Ampere or newer GPUs, e.g. A10, A100, H100)
tf32: false                                # Use TensorFloat-32 precision (NVIDIA Ampere+ GPUs only)

#uncomment here for fsdp - start
# fsdp: "full_shard auto_wrap offload"     # Fully Sharded Data Parallel training
                                           # Splits model states across multiple GPUs
# fsdp_config:                             # Configuration for FSDP
#     backward_prefetch: "backward_pre"    # Prefetches parameters before backward pass
#     cpu_ram_efficient_loading: true      # More memory-efficient parameter loading
#     offload_params: true                 # Offloads parameters to CPU when not in use
#     forward_prefetch: false              # Don't prefetch parameters for forward pass
#     use_orig_params: true                # Use original parameter ordering
#uncomment here for fsdp - end

merge_weights: true                        # Merge adapter weights into the base model
                                           # true: produces standalone model, false: keeps adapter separate
EOF
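
Two quantities implied by this configuration are worth computing explicitly: the effective batch size per optimizer step and the LoRA scaling factor `lora_alpha / lora_r`. The sketch below derives both from the values in the file. The helper names are hypothetical, and a single GPU is assumed.

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of samples that contribute to each optimizer step."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices

def lora_scaling(lora_alpha, lora_r):
    """Factor applied to the low-rank update before it is added to the weights."""
    return lora_alpha / lora_r

# Values taken from args.yaml; num_devices=1 is an assumption
print(effective_batch_size(per_device_batch_size=2, gradient_accumulation_steps=2, num_devices=1))
print(lora_scaling(lora_alpha=16, lora_r=8))
```

With these settings each optimizer step sees 4 samples and the LoRA update is scaled by 2.0. Raising `gradient_accumulation_steps` increases the first number without increasing per-step GPU memory.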

In [None]:
from sagemaker.s3 import S3Uploader

if default_prefix:
    input_path = f"s3://{bucket_name}/{default_prefix}/training_config/{model_id_filesafe}"
else:
    input_path = f"s3://{bucket_name}/training_config/{model_id_filesafe}"

# upload the model yaml file to s3
model_yaml = "args.yaml"
train_config_s3_path = S3Uploader.upload(local_path=model_yaml, desired_s3_uri=f"{input_path}/config")

print(f"Training config uploaded to:")
print(train_config_s3_path)
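
The fine-tuning step needs to split this S3 URI back into a bucket and a key before downloading the file. One way to do that with validation is sketched below using the standard library. `split_s3_uri` is a hypothetical helper and the bucket name is a placeholder.

```python
from urllib.parse import urlparse

def split_s3_uri(uri):
    """Split s3://bucket/key into (bucket, key), rejecting other schemes."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")

# Placeholder URI used only to demonstrate the parsing
bucket, key = split_s3_uri(
    "s3://example-bucket/training_config/demo_model/config/args.yaml"
)
print(bucket)
print(key)
```

Raising on malformed input surfaces configuration mistakes early instead of failing later with a less specific download error.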

**Fine-tuning Step**

This is where the actual model adaptation occurs. The step takes the preprocessed data and uses it to fine-tune the base LLM (in this case, a DeepSeek model), applying the LoRA technique for efficient adaptation.
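
The step also registers `metric_definitions`, a list of name and regex pairs that SageMaker applies to the training job's log output to extract metrics. The sketch below shows how such regexes behave on a single log line. The log line is a hypothetical example in the dictionary style that the patterns expect, and `extract_metrics` is an illustrative helper, not a SageMaker API.

```python
import re

# A subset of the metric definitions used by the fine-tuning step
metric_definitions = [
    {"Name": "loss", "Regex": r"'loss':\s*([0-9.]+)"},
    {"Name": "epoch", "Regex": r"'epoch':\s*([0-9.]+)"},
    {"Name": "lr", "Regex": r"'learning_rate':\s*([0-9.e-]+)"},
]

# Hypothetical training log line
log_line = "{'loss': 1.2345, 'learning_rate': 0.0002, 'epoch': 0.5}"

def extract_metrics(line, definitions):
    """Apply each regex to the line and collect the first captured value."""
    found = {}
    for definition in definitions:
        match = re.search(definition["Regex"], line)
        if match:
            found[definition["Name"]] = float(match.group(1))
    return found

print(extract_metrics(log_line, metric_definitions))
```

Because the `loss` pattern includes the opening quote, it does not match inside a key such as `'train_loss'`, which is why the step defines a separate pattern for that metric.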

In [None]:
@step(
    name="ModelFineTuning",
    instance_type=instance_type,
    display_name="Model Fine Tuning",
    keep_alive_period_in_seconds=900,
    dependencies="./scripts/requirements.txt"
)
def train(
    tracking_server_arn: str,
    train_dataset_s3_path: str,
    test_dataset_s3_path: str,
    train_config_s3_path: str,
    experiment_name: str,
    model_id: str,
    run_id: str,
):
    import sagemaker
    import boto3
    import mlflow
    import yaml
    import json
    import time
    import datetime
    import os
    import traceback
    import tempfile
    from pathlib import Path
    
    # Initialize variables and tracking
    start_time = time.time()
    model_name = model_id.split("/")[-1] if "/" in model_id else model_id
    training_job_name = None
    
    # Initialize MLflow tracking
    mlflow_enabled = (
        tracking_server_arn is not None
        and experiment_name is not None
        and tracking_server_arn != ""
        and experiment_name != ""
    )
    
    if mlflow_enabled:
        try:
            print(f"MLflow tracking enabled. Using server: {tracking_server_arn}")
            print(f"MLflow experiment: {experiment_name}")
            mlflow.set_tracking_uri(tracking_server_arn)
            mlflow.set_experiment(experiment_name)
            
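            # Enable detailed tracking
            mlflow.autolog(log_datasets=True, log_models=True, log_input_examples=True)
            
            # Start a new MLflow run for this step, named after the pipeline run_id
            mlflow_run = mlflow.start_run(run_name=f"finetuning-{run_id}")
            # Tag after start_run: set_tag with no active run would implicitly start one,
            # and the start_run call would then fail because a run is already active
            mlflow.set_tag("component", "model_fine_tuning")
            print(f"Started MLflow run {mlflow_run.info.run_id} for pipeline run: {run_id}")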
                
        except Exception as e:
            error_msg = f"Error initializing MLflow tracking: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            mlflow_enabled = False
            mlflow_run = None
    
    try:
        # Log basic parameters to MLflow
        if mlflow_enabled:
            mlflow.log_param("model_id", model_id)
            mlflow.log_param("train_dataset", train_dataset_s3_path)
            mlflow.log_param("test_dataset", test_dataset_s3_path)
            mlflow.log_param("training_start_time", datetime.datetime.now().isoformat())
            
            # Download and parse the training config YAML to log hyperparameters
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                s3_client = boto3.client("s3")
                
                # Parse S3 path
                config_parts = train_config_s3_path.replace("s3://", "").split("/", 1)
                bucket = config_parts[0]
                key = config_parts[1]
                
                # Download config file
                try:
                    s3_client.download_file(bucket, key, tmp.name)
                    # Parse the YAML config
                    with open(tmp.name, 'r') as f:
                        config = yaml.safe_load(f)
                    
                    # Log all hyperparameters from config
                    print("Logging hyperparameters to MLflow:")
                    for param_name, param_value in config.items():
                        # Skip complex objects that can't be logged as parameters
                        if isinstance(param_value, (str, int, float, bool)):
                            print(f"  {param_name}: {param_value}")
                            mlflow.log_param(param_name, param_value)
                        elif param_name == "fsdp_config" and isinstance(param_value, dict):
                            # Log nested config as JSON
                            mlflow.log_param("fsdp_config_json", json.dumps(param_value))
                    
                    # Log file as artifact for reference
                    mlflow.log_artifact(tmp.name, "training_config")
                    
                except Exception as e:
                    print(f"Error parsing config file: {e}")
                    
                finally:
                    # Clean up temp file
                    if os.path.exists(tmp.name):
                        os.remove(tmp.name)
        
        # Launch the training job
        job_name = f"deepseek-finetune-{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        from sagemaker.pytorch import PyTorch
        sagemaker_session = sagemaker.Session()
        
        # Define metric definitions for more detailed CloudWatch metrics
        metric_definitions = [
            {'Name': 'loss', 'Regex': "'loss':\\s*([0-9.]+)"},
            {'Name': 'epoch', 'Regex': "'epoch':\\s*([0-9.]+)"},
            {'Name': 'train_loss', 'Regex': "'train_loss':\\s*([0-9.]+)"},
            {'Name': 'lr', 'Regex': "'learning_rate':\\s*([0-9.e-]+)"},
            {'Name': 'step', 'Regex': "'step':\\s*([0-9.]+)"},
            {'Name': 'samples_per_second', 'Regex': "'train_samples_per_second':\\s*([0-9.]+)"},
        ]
        
        # Log the metric definitions we're using
        if mlflow_enabled:
            mlflow.log_param("tracked_metrics", [m['Name'] for m in metric_definitions])
        
        pytorch_estimator = PyTorch(
            entry_point='train.py',
            source_dir="./scripts",
            job_name=job_name,
            base_job_name=job_name,
            max_run=50000,
            role=role,
            framework_version="2.2.0",
            py_version="py310",
            instance_count=1,
            instance_type="ml.p3.2xlarge",
            sagemaker_session=sagemaker_session,
            volume_size=50,
            disable_output_compression=False,
            keep_alive_period_in_seconds=1800,
            distribution={"torch_distributed": {"enabled": True}},
            hyperparameters={
                "config": "/opt/ml/input/data/config/args.yaml"
            },
            metric_definitions=metric_definitions,
            debugger_hook_config=False
        )
    
        # Define a data input dictionary with our uploaded S3 URIs
        data = {
          'train': train_dataset_s3_path,
          'test': test_dataset_s3_path,
          'config': train_config_s3_path
        }
    
        print(f"Data for Training Run: {data}")
        
        # Log training job information
        if mlflow_enabled:
            mlflow.log_param("job_name", job_name)
            mlflow.log_param("instance_type", "ml.p3.2xlarge")
        
        # Start the training job
        pytorch_estimator.fit(data, wait=True)
    
        # Get information about the completed training job
        latest_run_job_name = pytorch_estimator.latest_training_job.job_name
        print(f"Latest Job Name: {latest_run_job_name}")
    
        sagemaker_client = boto3.client('sagemaker')
    
        # Describe the training job
        response = sagemaker_client.describe_training_job(TrainingJobName=latest_run_job_name)
    
        # Extract the model artifacts S3 path
        model_artifacts_s3_path = response['ModelArtifacts']['S3ModelArtifacts']
    
        # Extract the output path (this is the general output location)
        output_path = response['OutputDataConfig']['S3OutputPath']
        
        # Get training time metrics
        training_start_time = response.get('TrainingStartTime')
        training_end_time = response.get('TrainingEndTime')
        billable_time = response.get('BillableTimeInSeconds', 0)
        
        # Calculate duration
        total_training_time = 0
        if training_start_time and training_end_time:
            total_training_time = (training_end_time - training_start_time).total_seconds()
        
        # Log job results and metrics to MLflow
        if mlflow_enabled:
            # Log basic job info
            mlflow.log_param("training_job_name", latest_run_job_name)
            mlflow.log_param("model_artifacts_path", model_artifacts_s3_path)
            mlflow.log_param("output_path", output_path)
            
            # Log performance metrics
            mlflow.log_metric("billable_time_seconds", billable_time)
            mlflow.log_metric("total_training_time_seconds", total_training_time)
            
            # Log training job status
            mlflow.log_param("training_job_status", response.get('TrainingJobStatus'))
            
            # Log any secondary status
            if 'SecondaryStatus' in response:
                mlflow.log_param("secondary_status", response.get('SecondaryStatus'))
            
            # Log any failure reason
            if 'FailureReason' in response:
                mlflow.log_param("failure_reason", response.get('FailureReason'))
                
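            # Get CloudWatch logs for the training job
            logs_client = boto3.client('logs')
            log_group = "/aws/sagemaker/TrainingJobs"
            
            try:
                # Training log streams are named "<job-name>/algo-<n>-<timestamp>",
                # so look the stream up by prefix instead of using the job name directly
                streams = logs_client.describe_log_streams(
                    logGroupName=log_group,
                    logStreamNamePrefix=f"{latest_run_job_name}/"
                )['logStreams']
                if not streams:
                    raise RuntimeError(f"No log streams found for {latest_run_job_name}")
                log_stream = streams[0]['logStreamName']
                
                # Get the last 1000 log events
                log_events = logs_client.get_log_events(
                    logGroupName=log_group,
                    logStreamName=log_stream,
                    limit=1000
                )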
                
                # Extract and save logs
                log_output = "\n".join([event['message'] for event in log_events['events']])
                
                # Save logs to file and log as artifact
                with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.txt') as tmp:
                    tmp.write(log_output)
                    log_file_path = tmp.name
                
                mlflow.log_artifact(log_file_path, "training_logs")
                os.remove(log_file_path)
                
            except Exception as e:
                print(f"Error fetching training logs: {e}")
            
            # Log total execution time of this step
            step_duration = time.time() - start_time
            mlflow.log_metric("step_execution_time_seconds", step_duration)
            
            # Log model metadata
            mlflow.set_tag("model_path", model_artifacts_s3_path)
            mlflow.set_tag("training_completed_at", datetime.datetime.now().isoformat())
    
        print(f"Model artifacts S3 path: {model_artifacts_s3_path}")

    except Exception as e:
        error_msg = f"Error in model fine-tuning: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        
        if mlflow_enabled:
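            try:
                mlflow.log_param("fine_tuning_error", error_msg)
                mlflow.set_tag("training_status", "FAILED")
            except Exception:
                # Never let a logging failure mask the original error
                pass
        
        # Chain the original exception so the full cause stays visible
        raise RuntimeError(f"Fine-tuning failed: {str(e)}") from e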
        
    finally:
        # End MLflow run if it was started
        if mlflow_enabled and mlflow_run:
            try:
                mlflow.set_tag("step_completed", True)
                mlflow.end_run()
            except Exception as e:
                print(f"Error ending MLflow run: {str(e)}")

    return experiment_name, run_id, model_artifacts_s3_path, output_path
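The `metric_definitions` in the train step rely on regular expressions that SageMaker applies to the training job's log output. The following is a minimal local sketch of how those patterns pull values out of a log line. The sample line is hypothetical and only assumes that the training script prints a Python dict of metrics, as the patterns themselves suggest.

```python
import re

# Same patterns as in metric_definitions, written as raw strings.
METRIC_PATTERNS = {
    "loss": r"'loss':\s*([0-9.]+)",
    "epoch": r"'epoch':\s*([0-9.]+)",
    "lr": r"'learning_rate':\s*([0-9.e-]+)",
}


def extract_metrics(log_line: str) -> dict:
    """Return {metric_name: float} for every pattern that matches the line."""
    found = {}
    for name, pattern in METRIC_PATTERNS.items():
        match = re.search(pattern, log_line)
        if match:
            found[name] = float(match.group(1))
    return found


# Hypothetical log line; the real format depends on what train.py prints.
sample_line = "{'loss': 1.2345, 'grad_norm': 0.52, 'learning_rate': 1.9e-05, 'epoch': 0.12}"
print(extract_metrics(sample_line))
```

Because each pattern includes the opening quote, `'loss'` does not match a key such as `'train_loss'`, which is why the step defines a separate pattern for it.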

### Deploy Step
This step deploys the fine-tuned model to a SageMaker real-time endpoint so that the evaluation step can query it.
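The deploy step derives the endpoint name from the last path component of `model_id`. As a sketch, the same derivation can be wrapped in a helper that also checks the result against SageMaker's endpoint naming rules (at most 63 characters, alphanumerics and hyphens). The helper name `endpoint_name_for` and the validation are illustrative additions, not part of the notebook.

```python
import re

# SageMaker endpoint names: at most 63 characters, alphanumerics and hyphens,
# starting and ending with an alphanumeric character.
ENDPOINT_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9])*$")
MAX_ENDPOINT_NAME_LENGTH = 63


def endpoint_name_for(model_id: str, suffix: str = "-sft-djl") -> str:
    """Build the endpoint name the same way the deploy step does, then validate it."""
    base = model_id.split("/")[-1].replace(".", "-").replace("_", "-")
    name = f"{base}{suffix}"
    if len(name) > MAX_ENDPOINT_NAME_LENGTH or not ENDPOINT_NAME_PATTERN.match(name):
        raise ValueError(f"Invalid SageMaker endpoint name: {name!r}")
    return name


print(endpoint_name_for("deepseek-ai/DeepSeek-R1-Distill-Llama-8B"))
```

A `model_id` with a trailing slash produces an empty base name, which this check rejects before any deployment call is made.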

In [None]:
@step(
    name="ModelDeploy",
    instance_type=instance_type,
    display_name="Model Deploy",
    keep_alive_period_in_seconds=900
)
def deploy(
    model_artifacts_s3_path: str,
    output_path: str,
    experiment_name: str,
    model_id: str,
    run_id: str,
):
    import sagemaker
    import boto3
    from sagemaker import get_execution_role
    from sagemaker import Model
    from sagemaker.predictor import Predictor
    import time
    
    sagemaker_session = sagemaker.Session()
    instance_count = 1
    instance_type = "ml.g5.2xlarge"
    health_check_timeout = 700
    
    # Get the name for the endpoint
    endpoint_name = f"{model_id.split('/')[-1].replace('.', '-').replace('_','-')}-sft-djl"
    
    # Delete existing endpoint if it exists
    print(f"Checking for existing endpoint: {endpoint_name}")
    sm_client = boto3.client('sagemaker')
    try:
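        sm_client.describe_endpoint(EndpointName=endpoint_name)
        print(f"Endpoint {endpoint_name} exists, deleting it before deployment")
        sm_client.delete_endpoint(EndpointName=endpoint_name)
        # model.deploy() reuses the endpoint name for the endpoint config, so remove
        # any stale config as well to avoid a name conflict on redeploy
        try:
            sm_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
        except sm_client.exceptions.ClientError:
            pass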
        
        # Wait for endpoint to be fully deleted
        print("Waiting for endpoint to be fully deleted...")
        wait_seconds = 10
        total_wait_time = 0
        max_wait_time = 300  # 5 minutes maximum wait
        endpoint_deleted = False
        
        while total_wait_time < max_wait_time and not endpoint_deleted:
            try:
                sm_client.describe_endpoint(EndpointName=endpoint_name)
                print(f"Endpoint still exists, waiting {wait_seconds} seconds...")
                time.sleep(wait_seconds)
                total_wait_time += wait_seconds
            except sm_client.exceptions.ClientError:
                print(f"Endpoint {endpoint_name} successfully deleted")
                endpoint_deleted = True
                
        if not endpoint_deleted:
            print(f"Warning: Endpoint still exists after {max_wait_time} seconds")
            
    except sm_client.exceptions.ClientError:
        print(f"Endpoint {endpoint_name} does not exist, proceeding with deployment")
    
    # Continue with model deployment
    image_uri = sagemaker.image_uris.retrieve(
        framework="djl-lmi",
        region=sagemaker_session.boto_session.region_name,
        version="latest"
    )
    
    model_data = model_artifacts_s3_path
    
    # Create model only once
    model = Model(
        image_uri=image_uri,
        model_data=model_data,
        role=get_execution_role(),
        env={
            'HF_MODEL_ID': "/opt/ml/model", # path to where sagemaker stores the model
            'OPTION_TRUST_REMOTE_CODE': 'true',
            'OPTION_ROLLING_BATCH': "vllm",
            'OPTION_DTYPE': 'bf16',
            'OPTION_QUANTIZE': 'fp8',
            'OPTION_TENSOR_PARALLEL_DEGREE': 'max',
            'OPTION_MAX_ROLLING_BATCH_SIZE': '32',
            'OPTION_MODEL_LOADING_TIMEOUT': '3600',
            'OPTION_MAX_MODEL_LEN': '4096'
        }
    )

    print(f"deploying endpoint: {endpoint_name}")
    
    predictor = model.deploy(
        endpoint_name=endpoint_name,
        initial_instance_count=instance_count,
        instance_type=instance_type,
        container_startup_health_check_timeout=health_check_timeout,
        model_data_download_timeout=3600
    )
    
    return experiment_name, run_id, endpoint_name

### Evaluation Step

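After deployment, this step queries the endpoint with questions sampled from the medical-o1 dataset, scores the responses against the reference answers with ROUGE, and logs the results to MLflow.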
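To make the ROUGE numbers easier to interpret, here is a simplified sketch of ROUGE-1: precision, recall, and F1 over overlapping unigrams. It assumes lowercase whitespace tokenization with no stemming, so it is not LightEval's implementation and its scores can differ slightly from the ones the step logs.

```python
from collections import Counter


def rouge1_scores(prediction: str, reference: str) -> dict:
    """Unigram-overlap precision, recall, and F1 (simplified ROUGE-1)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # Counter intersection keeps the minimum count of each shared token
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    precision = overlap / len(pred_tokens) if pred_tokens else 0.0
    recall = overlap / len(ref_tokens) if ref_tokens else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Four of the five tokens on each side overlap, so all three scores are about 0.8
scores = rouge1_scores(
    "the patient has acute appendicitis",
    "the patient likely has appendicitis",
)
print(scores)
```

ROUGE-2 applies the same idea to bigrams, and ROUGE-L uses the longest common subsequence instead of n-gram counts.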

In [None]:
@step(
    name="ModelEvaluation",
    instance_type=instance_type,
    display_name="Model Evaluation",
    keep_alive_period_in_seconds=900,
    dependencies="./eval/requirements.txt"
)
def evaluate(
    experiment_name: str,
    run_id: str,
    endpoint_name: str,
)-> dict:
    import os
    import json
    import time
    import boto3
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from tqdm import tqdm  # plain tqdm: the step runs as a remote job, not in a notebook
    from datasets import load_dataset
    import torch
    import torchvision
    import transformers
    import mlflow
    import uuid
    import traceback
    from datetime import datetime
    
    # Import LightEval metrics
    from lighteval.metrics.metrics_sample import ROUGE, Doc
    
    # Initialize MLflow tracking
    tracking_server_arn = os.environ.get("mlflow_uri", "")
    mlflow_enabled = (tracking_server_arn != "" and experiment_name != "")
    
    if mlflow_enabled:
        try:
            print(f"MLflow tracking enabled. Using server: {tracking_server_arn}")
            print(f"MLflow experiment: {experiment_name}")
            mlflow.set_tracking_uri(tracking_server_arn)
            mlflow.set_experiment(experiment_name)
            mlflow.autolog(log_datasets=True)
            
            # Start a separate evaluation run, named after run_id so it can be matched to the training run
            mlflow_run = mlflow.start_run(run_name=f"evaluation-{run_id}")
            print(f"Started MLflow evaluation run for pipeline run ID: {run_id}")
                
        except Exception as e:
            print(f"Error initializing MLflow tracking: {e}")
            print(traceback.format_exc())
            mlflow_enabled = False
            mlflow_run = None
    
    # Initialize the SageMaker client
    sm_client = boto3.client('sagemaker-runtime')
    
    FINETUNED_MODEL_ENDPOINT = endpoint_name # Update with Fine-tuned model endpoint name
    
    # Define the model to evaluate
    model_to_evaluate = {
        "name": "Fine-tuned DeepSeek-R1-Distill-Llama-8B", 
        "endpoint": FINETUNED_MODEL_ENDPOINT
    }
    # Limit the number of samples to evaluate (for faster execution)
    num_samples = 10
    
    # Log evaluation parameters to MLflow
    if mlflow_enabled:
        mlflow.log_param("evaluation_endpoint", FINETUNED_MODEL_ENDPOINT)
        mlflow.log_param("evaluation_num_samples", num_samples)
        mlflow.log_param("evaluation_timestamp", datetime.now().isoformat())
    
    # Load the medical-o1 dataset; note that evaluation samples are drawn from its "train" split
    try:
        dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en", split="train")
        
        max_samples = len(dataset)
        
        # Fixed seed so that every run is evaluated on the same samples and results stay comparable
        dataset = dataset.shuffle(seed=42).select(range(min(num_samples, max_samples)))
        print(f"Loaded medical-o1-reasoning dataset with {len(dataset)} samples out of {max_samples}")
        
        if mlflow_enabled:
            mlflow.log_param("dataset_name", "FreedomIntelligence/medical-o1-reasoning-SFT")
            mlflow.log_param("dataset_actual_samples", len(dataset))
    except Exception as e:
        error_msg = f"Error loading dataset: {str(e)}"
        print(error_msg)
        if mlflow_enabled:
            mlflow.log_param("dataset_load_error", error_msg)
        raise
    
    # Display a sample from the dataset
    sample = dataset[0]
    
    print("\nQuestion:\n", sample["Question"], "\n\n====\n")
    print("Complex_CoT:\n", sample["Complex_CoT"], "\n\n====\n")
    print("Response:\n", sample["Response"], "\n\n====\n")
    
    # This function allows you to interact with a deployed SageMaker endpoint to get predictions from the DeepSeek model
    def invoke_sagemaker_endpoint(payload, endpoint_name):
        """
        Invoke a SageMaker endpoint with the given payload.
    
        Args:
            payload (dict): The input data to send to the endpoint
            endpoint_name (str): The name of the SageMaker endpoint
    
        Returns:
            dict: The response from the endpoint
        """
        try:
            start_time = time.time()
            response = sm_client.invoke_endpoint(
                EndpointName=endpoint_name,
                ContentType='application/json',
                Body=json.dumps(payload)
            )
            inference_time = time.time() - start_time
            
            response_body = response['Body'].read().decode('utf-8')
            return json.loads(response_body), inference_time
        except Exception as e:
            print(f"Error invoking endpoint {endpoint_name}: {str(e)}")
            return None, -1
    
    # Initialize LightEval metrics calculators
    rouge_metrics = ROUGE(
        methods=["rouge1", "rouge2", "rougeL"],
        multiple_golds=False,
        bootstrap=False,
        normalize_gold=None,
        normalize_pred=None
    )
    
    def calculate_metrics(predictions, references):
        """
        Calculate ROUGE scores and length statistics for generated responses using LightEval.
    
        Args:
            predictions (list): List of generated responses
            references (list): List of reference responses
    
        Returns:
            dict: Dictionary containing all metric scores
        """
        metrics = {}
    
        # Create Doc objects for the references (not used by the ROUGE calls below)
        docs = []
        for reference in references:
            docs.append(Doc(
                {"target": reference},
                choices=[reference],  # Dummy choices
                gold_index=0  # Dummy gold_index
            ))
    
        # Calculate ROUGE scores for each prediction-reference pair
        rouge_scores = {
            'rouge1_f': [], 
            'rouge2_f': [], 
            'rougeL_f': [],
            # Add precision and recall scores too
            'rouge1_precision': [],
            'rouge1_recall': [],
            'rouge2_precision': [],
            'rouge2_recall': [],
            'rougeL_precision': [],
            'rougeL_recall': []
        }
    
        for pred, ref in zip(predictions, references):
            # For ROUGE calculation
            rouge_result = rouge_metrics.compute(golds=[ref], predictions=[pred])
            rouge_scores['rouge1_f'].append(rouge_result['rouge1'])
            rouge_scores['rouge2_f'].append(rouge_result['rouge2'])
            rouge_scores['rougeL_f'].append(rouge_result['rougeL'])
            
            # For more detailed ROUGE metrics (we get precision and recall too)
            detailed_rouge = rouge_metrics.compute_detailed(golds=[ref], predictions=[pred])
            rouge_scores['rouge1_precision'].append(detailed_rouge[0]['rouge1_precision'])
            rouge_scores['rouge1_recall'].append(detailed_rouge[0]['rouge1_recall'])
            rouge_scores['rouge2_precision'].append(detailed_rouge[0]['rouge2_precision'])
            rouge_scores['rouge2_recall'].append(detailed_rouge[0]['rouge2_recall'])
            rouge_scores['rougeL_precision'].append(detailed_rouge[0]['rougeL_precision'])
            rouge_scores['rougeL_recall'].append(detailed_rouge[0]['rougeL_recall'])
    
        # Average ROUGE scores
        for key in rouge_scores:
            metrics[key] = sum(rouge_scores[key]) / len(rouge_scores[key])
        
        # Calculate prediction statistics
        metrics['avg_prediction_length'] = np.mean([len(pred.split()) for pred in predictions])
        metrics['min_prediction_length'] = min([len(pred.split()) for pred in predictions])
        metrics['max_prediction_length'] = max([len(pred.split()) for pred in predictions])
        
        # Calculate reference statistics
        metrics['avg_reference_length'] = np.mean([len(ref.split()) for ref in references])
        metrics['min_reference_length'] = min([len(ref.split()) for ref in references])
        metrics['max_reference_length'] = max([len(ref.split()) for ref in references])
        
        # Calculate length ratio
        metrics['avg_length_ratio'] = np.mean([len(pred.split()) / len(ref.split()) if len(ref.split()) > 0 else 0 
                                              for pred, ref in zip(predictions, references)])
    
        print(f"Metrics: {metrics}")
    
        return metrics
    
    def generate_summaries_with_model(endpoint_name, dataset):
        """
        Generate responses using a model deployed on SageMaker.
    
        Args:
            endpoint_name (str): SageMaker endpoint name
            dataset: Dataset containing the questions
    
        Returns:
            list: Generated responses
            list: Inference times for each response (-1 for failed calls)
        """
        predictions = []
        inference_times = []
        failed_generations = 0
    
        for example in tqdm(dataset, desc="Generating Responses"):
            question = example["Question"]
    
            # Prepare the prompt for the model
            prompt = f"""
            <|begin_of_text|>
            <|start_header_id|>system<|end_header_id|>
            You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
            Below is an instruction that describes a task, paired with an input that provides further context. 
            Write a response that appropriately completes the request.
            Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
            <|eot_id|><|start_header_id|>user<|end_header_id|>
            {question}<|eot_id|>
            <|start_header_id|>assistant<|end_header_id|>"""
    
            # Payload for SageMaker endpoint
            payload = {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": 512,
                    "top_p": 0.9,
                    "temperature": 0.6,
                    "return_full_text": False
                }
            }
    
            # Call the model endpoint
            try:
                response, inference_time = invoke_sagemaker_endpoint(payload, endpoint_name)
                
                # Extract the generated text
                if response is None:
                    prediction = "Error generating response."
                    failed_generations += 1
                elif isinstance(response, list):
                    prediction = response[0].get('generated_text', '').strip()
                elif isinstance(response, dict):
                    prediction = response.get('generated_text', '').strip()
                else:
                    prediction = str(response).strip()
    
                prediction = prediction.split("<|eot_id|>")[0] if "<|eot_id|>" in prediction else prediction
                
                # Log individual inference metrics
                if mlflow_enabled:
                    mlflow.log_metric(f"inference_time_sample_{len(predictions)}", inference_time)
                
                inference_times.append(inference_time)
                
            except Exception as e:
                print(f"Error invoking SageMaker endpoint {endpoint_name}: {e}")
                prediction = "Error generating response."
                failed_generations += 1
                inference_times.append(-1)
    
            predictions.append(prediction)
    
        # Log failure rate
        if mlflow_enabled:
            mlflow.log_metric("failed_generations", failed_generations)
            mlflow.log_metric("failure_rate", failed_generations / len(dataset) if len(dataset) > 0 else 0)
    
        return predictions, inference_times
    
    def evaluate_model_on_dataset(model_config, dataset):
        """
        Evaluate a fine-tuned model on a dataset using automated metrics (ROUGE and latency).
    
        Args:
            model_config (dict): Model configuration with name and endpoint
            dataset: dataset for evaluation
    
        Returns:
            dict: Evaluation results
        """
        model_name = model_config["name"]
        endpoint_name = model_config["endpoint"]
    
        print(f"\nEvaluating model: {model_name} on endpoint: {endpoint_name}")
    
        # Get references
        references = ["\n".join([example["Complex_CoT"], example["Response"]]) for example in dataset]
    
        # Generate responses
        print("\nGenerating Responses...")
        predictions, inference_times = generate_summaries_with_model(endpoint_name, dataset)
        
        # Log inference time metrics
        if mlflow_enabled:
            valid_times = [t for t in inference_times if t > 0]
            if valid_times:
                mlflow.log_metric("avg_inference_time", np.mean(valid_times))
                mlflow.log_metric("min_inference_time", min(valid_times))
                mlflow.log_metric("max_inference_time", max(valid_times))
                mlflow.log_metric("p95_inference_time", np.percentile(valid_times, 95))
    
        # Calculate automated metrics using LightEval
        print("\nCalculating evaluation metrics with LightEval...")
        metrics = calculate_metrics(predictions, references)
        
        # Log all calculated metrics to MLflow
        if mlflow_enabled:
            for metric_name, metric_value in metrics.items():
                mlflow.log_metric(metric_name, metric_value)
            
            # Create a comparison table of predictions vs references
            comparison_data = []
            for i, (pred, ref) in enumerate(zip(predictions[:5], references[:5])):
                comparison_data.append({
                    "example_id": i,
                    "prediction": pred[:500] + ("..." if len(pred) > 500 else ""),  # Truncate for readability
                    "reference": ref[:500] + ("..." if len(ref) > 500 else ""),     # Truncate for readability
                    "rouge1_f": rouge_metrics.compute(golds=[ref], predictions=[pred])['rouge1']
                })
            
            comparison_df = pd.DataFrame(comparison_data)
            # Save comparison to a temporary CSV and log it as an artifact
            temp_csv = f"/tmp/predictions_comparison_{uuid.uuid4().hex[:8]}.csv"
            comparison_df.to_csv(temp_csv, index=False)
            mlflow.log_artifact(temp_csv, "model_predictions")
            
        # Format results
        results = {
            "model_name": model_name,
            "endpoint_name": endpoint_name,
            "num_samples": len(dataset),
            "metrics": metrics,
            "inference_times": inference_times,  # Needed later for the model card
            "predictions": predictions[:5],  # First 5 predictions
            "references": references[:5]     # First 5 references
        }
    
        # Print key results
        print(f"\nResults for {model_name}:")
        print(f"ROUGE-1 F1: {metrics['rouge1_f']:.4f}")
        print(f"ROUGE-2 F1: {metrics['rouge2_f']:.4f}")
        print(f"ROUGE-L F1: {metrics['rougeL_f']:.4f}")
        print(f"Average Inference Time: {np.mean([t for t in inference_times if t > 0]):.3f} seconds")
    
        return results, metrics['rouge1_f'], metrics['rouge2_f'], metrics['rougeL_f']
    
    try:
        finetuned_model_results, rouge1_f, rouge2_f, rougeL_f = evaluate_model_on_dataset(model_to_evaluate, dataset)
        print(f"ROUGE-1 F1: {rouge1_f}")
        print(f"ROUGE-2 F1: {rouge2_f}")
        print(f"ROUGE-L F1: {rougeL_f}")
        
        # Create and log visualizations if MLflow is enabled
        if mlflow_enabled:
            # Log model card with performance summary
            # finetuned_model_results is a dict, so index it by key (not by position)
            model_card = f"""
            # Model Evaluation Report
            
            ## Model Information
            - **Model Name**: {model_to_evaluate["name"]}
            - **Endpoint**: {model_to_evaluate["endpoint"]}
            - **Evaluation Date**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
            - **Dataset**: FreedomIntelligence/medical-o1-reasoning-SFT
            - **Samples Evaluated**: {len(dataset)}
            
            ## Performance Metrics
            - **ROUGE-1 F1**: {rouge1_f:.4f}
            - **ROUGE-2 F1**: {rouge2_f:.4f}
            - **ROUGE-L F1**: {rougeL_f:.4f}
            - **Average Inference Time**: {np.mean([t for t in finetuned_model_results["inference_times"] if t > 0]):.3f} seconds
            
            ## Detailed Metrics
            {json.dumps(finetuned_model_results["metrics"], indent=2)}
            """
            
            with open("/tmp/model_card.md", "w") as f:
                f.write(model_card)
            
            mlflow.log_artifact("/tmp/model_card.md", "evaluation_summary")
            
            # Create a simple bar chart for ROUGE metrics
            plt.figure(figsize=(10, 6))
            eval_metrics = finetuned_model_results["metrics"]
            # Use a distinct name so the ROUGE calculator (rouge_metrics) is not overwritten
            rouge_scores_to_plot = {
                'ROUGE-1 F1': eval_metrics['rouge1_f'], 
                'ROUGE-2 F1': eval_metrics['rouge2_f'], 
                'ROUGE-L F1': eval_metrics['rougeL_f']
            }
            plt.bar(list(rouge_scores_to_plot.keys()), list(rouge_scores_to_plot.values()))
            plt.title('ROUGE Metrics')
            plt.ylabel('Score')
            plt.ylim(0, 1)
            plt.grid(axis='y', linestyle='--', alpha=0.7)
            plt.savefig('/tmp/rouge_metrics.png')
            mlflow.log_artifact('/tmp/rouge_metrics.png', "evaluation_plots")
            
        # End MLflow run if we started one
        if mlflow_enabled and mlflow_run:
            mlflow.end_run()
    
    except Exception as e:
        error_msg = f"Error in model evaluation: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        
        if mlflow_enabled:
            try:
                mlflow.log_param("evaluation_error", error_msg)
                mlflow.end_run()
            except:
                pass
        
        # Return at least something even if evaluation fails
        return {"error": str(e), "rougeL_f": 0.0}

    return {"rougeL_f": rougeL_f}
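The evaluation step handles several possible response shapes from the endpoint (a list of dicts, a single dict, or something else). One way to keep that logic testable is to isolate it in a small function, sketched below. The name `extract_generated_text` is hypothetical, and unlike the step above it returns `None` for a failed call rather than a placeholder string, so that callers can decide whether to score failed generations.

```python
def extract_generated_text(response, stop_token="<|eot_id|>"):
    """Normalize an endpoint response to plain text, or None if the call failed."""
    if response is None:
        return None
    if isinstance(response, list):
        # Take the first candidate; an empty list yields an empty string
        first = response[0] if response else {}
        text = first.get("generated_text", "") if isinstance(first, dict) else str(first)
    elif isinstance(response, dict):
        text = response.get("generated_text", "")
    else:
        text = str(response)
    # Keep only the text before the end-of-turn token, if present
    return text.split(stop_token)[0].strip()


print(extract_generated_text([{"generated_text": " Likely appendicitis.<|eot_id|>ignored"}]))
```

With a helper like this, failed samples could be excluded from the ROUGE averages instead of contributing the literal text "Error generating response." as a prediction.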

### 7. Pipeline Creation and Execution

This final section brings all the components together into an executable pipeline.

**Creating the Pipeline**

The pipeline object is created with all defined steps.

In [None]:
# Defining the steps of the pipeline
preprocessing_step = preprocess(
    tracking_server_arn=tracking_server_arn,
    experiment_base_name=experiment_base_name,
    run_id=ExecutionVariables.PIPELINE_EXECUTION_ID,
    input_path=input_path,
)

training_step = train(
    tracking_server_arn=tracking_server_arn,
    experiment_name=preprocessing_step[0],
    run_id=preprocessing_step[1],
    train_dataset_s3_path=preprocessing_step[2],
    test_dataset_s3_path=preprocessing_step[3],
    train_config_s3_path=train_config_s3_path,
    model_id=model_s3_destination,
)

deploy_step = deploy(
    experiment_name=training_step[0],
    run_id=training_step[1],
    model_artifacts_s3_path=training_step[2],
    output_path=training_step[3],
    model_id=model_s3_destination,
)

evaluate_step = evaluate(
    experiment_name=deploy_step[0],
    run_id=deploy_step[1],
    endpoint_name=deploy_step[2],
)

# Combining the steps into the pipeline definition
pipeline = Pipeline(
    name=pipeline_name,
    parameters=[
        instance_type,
    ],
    steps=[preprocessing_step, training_step, deploy_step, evaluate_step],
)

**Upserting the Pipeline**

This step either creates a new pipeline in SageMaker or updates an existing one with the same name. It's a key part of the MLOps process, allowing for iterative refinement of the pipeline.

In [None]:
pipeline.upsert(role)

**Starting the Pipeline Execution**

This command kicks off the actual execution of the pipeline in SageMaker. From this point, SageMaker will orchestrate the execution of each step, managing resources and data flow between steps.

In [None]:
execution = pipeline.start()
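`pipeline.start()` returns immediately, so the endpoint does not exist yet when the next cell runs. The following sketch polls the execution until it reaches a terminal state. It assumes only that the execution object exposes `describe()` returning a dict with a `PipelineExecutionStatus` key. The function name, timeout, and polling interval are illustrative choices.

```python
import time

# Terminal values of PipelineExecutionStatus
TERMINAL_STATUSES = {"Succeeded", "Failed", "Stopped"}


def wait_for_execution(execution, poll_seconds=60, timeout_seconds=6 * 3600, sleep=time.sleep):
    """Poll a pipeline execution until it finishes; return its final status."""
    waited = 0
    while True:
        status = execution.describe()["PipelineExecutionStatus"]
        if status in TERMINAL_STATUSES:
            return status
        if waited >= timeout_seconds:
            raise TimeoutError(f"Pipeline still {status} after {timeout_seconds} seconds")
        sleep(poll_seconds)
        waited += poll_seconds


# Usage in the notebook (not run here because it needs a live execution):
# final_status = wait_for_execution(execution)
# print(f"Pipeline finished with status: {final_status}")
```

The execution object returned by the SageMaker SDK also provides a built-in `wait()` method. Check its default polling limits before relying on it for long fine-tuning runs.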

### Clean Up

In [None]:
# Delete the endpoint to avoid incurring charges
import boto3
import time
import botocore

def delete_endpoint_with_retry(endpoint_name, max_retries=3, wait_seconds=10):
    """
    Delete a SageMaker endpoint with retry logic
    
    Args:
        endpoint_name (str): Name of the SageMaker endpoint to delete
        max_retries (int): Maximum number of retry attempts
        wait_seconds (int): Time to wait between retries in seconds
    
    Returns:
        bool: True if deletion was successful, False otherwise
    """
    sm_client = boto3.client('sagemaker')
    
    # First check if the endpoint exists
    try:
        sm_client.describe_endpoint(EndpointName=endpoint_name)
        endpoint_exists = True
    except sm_client.exceptions.ClientError as e:
        if "Could not find endpoint" in str(e):
            print(f"Endpoint {endpoint_name} does not exist, no cleanup needed.")
            return True
        else:
            print(f"Error checking endpoint existence: {e}")
            return False
    
    # If we get here, the endpoint exists and we should delete it
    for attempt in range(max_retries):
        try:
            print(f"Attempting to delete endpoint {endpoint_name} (attempt {attempt + 1}/{max_retries})")
            sm_client.delete_endpoint(EndpointName=endpoint_name)
            print(f"Endpoint {endpoint_name} deletion initiated successfully")
            
            # Wait for endpoint to be fully deleted
            print("Waiting for endpoint to be fully deleted...")
            
            # Poll until endpoint is deleted or max wait time is reached
            total_wait_time = 0
            max_wait_time = 300  # 5 minutes maximum wait
            while total_wait_time < max_wait_time:
                try:
                    sm_client.describe_endpoint(EndpointName=endpoint_name)
                    print(f"Endpoint still exists, waiting {wait_seconds} seconds...")
                    time.sleep(wait_seconds)
                    total_wait_time += wait_seconds
                except sm_client.exceptions.ClientError:
                    print(f"Endpoint {endpoint_name} successfully deleted")
                    return True
            
            # If we get here, the endpoint still exists after max_wait_time
            print(f"Warning: Endpoint deletion initiated but still exists after {max_wait_time} seconds")
            return False
            
        except botocore.exceptions.ClientError as e:
            if "ResourceInUse" in str(e) or "ResourceNotFound" in str(e):
                print(f"Error deleting endpoint: {e}")
                print(f"Retrying in {wait_seconds} seconds...")
                time.sleep(wait_seconds)
            else:
                print(f"Unexpected error deleting endpoint: {e}")
                return False
    
    print(f"Failed to delete endpoint {endpoint_name} after {max_retries} attempts")
    return False

# Clean up endpoint
try:
    # Derive the name from the same value the deploy step received as model_id
    model_name_safe = model_s3_destination.split('/')[-1].replace('.', '-').replace('_', '-')
    endpoint_name = f"{model_name_safe}-sft-djl"
    
    print(f"Cleaning up endpoint: {endpoint_name}")
    if delete_endpoint_with_retry(endpoint_name):
        print("Cleanup completed successfully")
    else:
        print("Warning: Endpoint cleanup may have failed, please check the SageMaker console")
        
except Exception as e:
    print(f"Error during endpoint cleanup: {str(e)}")
    print("You may need to manually delete the endpoint from the SageMaker console")