# Batch Inference for Fine-tuned Models

This notebook implements batch inference for fine-tuned models using Amazon SageMaker AI. We use batch inference because:
1. We want to evaluate the model on the test dataset, which it has not seen during training.
2. Real-time inference would be too costly and slow for our dataset size.
3. The Amazon SageMaker AI `@remote` decorator lets us run inference either in the local environment or remotely on GPU instances without code changes.

The notebook handles:
1. Setting up the environment and configurations
2. Loading and preparing the model
3. Running batch inference
4. Downloading inference results


## Prerequisites
- AWS credentials configured
- Access to Amazon SageMaker AI training jobs
- Sufficient quota for a GPU instance for SageMaker training job or spot training job
- Fine-tuned model artifacts in S3

## Import Required Libraries

In [None]:
%pip install sagemaker==2.227.0 --quiet

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import subprocess
from pathlib import Path
import pandas as pd
from typing import Union, Dict, Optional

import sagemaker
from sagemaker.remote_function import remote


import csv
from IPython.display import display, HTML
from ipywidgets import widgets

from utils.config import ModelConfig
from utils.helpers import get_s3_suffix, shorten_for_sagemaker_training_job
from utils.model_manager import list_available_models

In [None]:
# Initialize session and configure AWS resources for inference
try:
    role = sagemaker.get_execution_role()
    session = sagemaker.Session()
    region = session.boto_region_name
    
    # Configure S3 paths for data and artifacts
    # CHANGE if your dataset is in a different S3 bucket
    default_bucket_name = session.default_bucket()
    dataset_s3_prefix = "fatura2-train-data"
    s3_root_uri = f"s3://{default_bucket_name}"
    dataset_s3_uri = f"{s3_root_uri}/{dataset_s3_prefix}"
    
    
except Exception as e:
    raise Exception(f"Error setting up SageMaker session: {str(e)}")
print("✅ Initialized SageMaker session...")
print(f"💾 Using dataset for inference: {dataset_s3_uri}")

## Retrieve model artifact location from fine-tuning

We look up the latest fine-tuning training job in Amazon SageMaker AI to find the location of its model artifacts.

We need to track training jobs because:
1. We want to use the latest fine-tuned model
2. Job names help in resource organization
3. We need to locate model artifacts in S3
4. Version tracking is crucial for reproducibility
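
As a rough illustration of this lookup, the sketch below picks the newest `model.tar.gz` under a job-name prefix from an S3 listing. It is not the implementation of `list_available_models`. The function name `newest_model_artifacts` and the sample keys are hypothetical, and the input is assumed to have the shape of the `Contents` entries returned by the S3 `list_objects_v2` API.

```python
from datetime import datetime, timezone


def newest_model_artifacts(contents, prefix):
    """Return model.tar.gz keys under `prefix`, newest first.

    `contents` is assumed to look like the `Contents` list of an S3
    `list_objects_v2` response: dicts with `Key` and `LastModified`.
    """
    matches = [
        obj for obj in contents
        if obj["Key"].startswith(prefix) and obj["Key"].endswith("model.tar.gz")
    ]
    matches.sort(key=lambda obj: obj["LastModified"], reverse=True)
    return [obj["Key"] for obj in matches]


# Hypothetical listing used only for illustration
sample = [
    {"Key": "job-a-001/output/model.tar.gz", "LastModified": datetime(2025, 1, 1, tzinfo=timezone.utc)},
    {"Key": "job-a-002/output/model.tar.gz", "LastModified": datetime(2025, 2, 1, tzinfo=timezone.utc)},
    {"Key": "job-a-002/output/output.tar.gz", "LastModified": datetime(2025, 2, 1, tzinfo=timezone.utc)},
    {"Key": "other-003/output/model.tar.gz", "LastModified": datetime(2025, 3, 1, tzinfo=timezone.utc)},
]
print(newest_model_artifacts(sample, "job-a"))
```

Sorting newest first matches the notebook's default of picking index `0` from the list.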


In [None]:
# The models that you fine-tuned.
base_model_config = ModelConfig(
    # Replace with model type and model id of the base model.
    model_type="qwen2_5_vl",
    model_id="Qwen/Qwen2.5-VL-3B-Instruct"

    # model_type = "llama3_2_vision",
    # model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
)

print("✅ Configured model id.")

In [None]:
training_job_name_prefix = base_model_config.training_job_prefix(dataset_s3_prefix)
print(f"Fine-tuning name prefix: {training_job_name_prefix}")

In [None]:
display(HTML(f"""<div style="background-color: #f3fd91; border: 2px solid #a2bb11; padding: 10px; color: black; font-family: Arial, Helvetica, sans-serif;">
    Skip the next 4 cells below if you want to run inference using <b>{base_model_config.model_id}</b> from HuggingFace Hub. Run cells below if you want to use a model that you have fine-tuned.
</div>"""))

In [None]:
df = list_available_models(default_bucket_name, training_job_name_prefix)
df

In [None]:
which_model_to_pick = 0 # use first model from list by default

In [None]:
# Set up the S3 URI from which we will download the model
model_key=df['Key'].iloc[which_model_to_pick]
model_output_url = f"s3://{default_bucket_name}/{model_key}"
print(f"Selected model for deployment: {model_key}")
print(f"S3 Model URI: {model_output_url}")

In [None]:
model_suffix_s3 = get_s3_suffix(model_output_url)


<div style="background-color: #f3fd91; border: 2px solid #a2bb11; padding: 10px; color: black; font-family: Arial, Helvetica, sans-serif;">
    Continue below for inference with base model or fine-tuned model.
</div>

In [None]:
try:
    model_config = ModelConfig(
        # Reuse the base model type; the model id points to the fine-tuned artifacts in S3.
        model_type=base_model_config.model_type,
        model_id=model_output_url
    )
    
    prefix = model_suffix_s3.split("/")[0]
    print("✅ Configured fine-tuned model id.")
    
except NameError:
    # not using fine-tuned model
    model_config = base_model_config
    prefix = model_config.model_id.replace("/","-").replace(".","-")
    print("✅ Using base model for inference.")

In [None]:
print(f"Model for inference: {model_config.model_id}")

## Configure Job for Batch Inference

Let's define the SageMaker Distribution image to use for the remote job. The cell below covers us-east-1 and us-west-2; the URIs for other distributions and regions can be found in the [SageMaker Distribution documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-images.html#notebooks-available-images-arn).

Here are a few example distributions from the link above:

* us-east-1: 885854791233.dkr.ecr.us-east-1.amazonaws.com/sagemaker-distribution-prod:2.1.0-gpu
* us-west-2: 542918446943.dkr.ecr.us-west-2.amazonaws.com/sagemaker-distribution-prod:2.1.0-gpu

In [None]:
# Select the SageMaker Distribution image for the current region
if region == "us-east-1":
    sagemaker_dist_uri = "885854791233.dkr.ecr.us-east-1.amazonaws.com/sagemaker-distribution-prod:2.1.0-gpu"
elif region == "us-west-2":
    sagemaker_dist_uri = "542918446943.dkr.ecr.us-west-2.amazonaws.com/sagemaker-distribution-prod:2.1.0-gpu"
else:
    raise ValueError(
        "Please make sure to manually set the `sagemaker_dist_uri` uri for your specific AWS region using the provided link from the cell above."
    )

Define the dependencies that the remote inference job requires.

In [None]:
%%writefile ./requirements.txt
git+https://github.com/huggingface/accelerate.git@v1.5.2
ms-swift@git+https://github.com/modelscope/ms-swift.git@v3.2.0
git+https://github.com/huggingface/transformers.git@014047e1c8784c00e2a04cb04ffcecdd5cb23c16
pyav
vllm==0.7.2
decord
optimum
qwen-vl-utils==0.0.10
huggingface_hub
hf_transfer
xgrammar # grammar constrained decoding

In [None]:
s3_root_uri = "s3://{}".format(default_bucket_name)

### Environment Variables Configuration

We set specific environment variables because:
1. Memory usage needs to be optimized for GPUs
2. Image processing has size constraints
3. We want faster downloads from Hugging Face
4. Resource limits need to be carefully managed

In [None]:
# Environment variables for the remote inference job
env_variables ={
    "SIZE_FACTOR": json.dumps(8), # can be increased but requires more GPU memory
    "MAX_PIXELS": json.dumps(602112), # can be increased but requires more GPU memory
    "USE_HF_TRANSFER": json.dumps(1),
    "HF_HUB_ENABLE_HF_TRANSFER": json.dumps(1),
    # "HF_TOKEN": "xxxxxxxx",
}
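
Environment variable values must be strings, which is why the numbers above are passed through `json.dumps`. The sketch below shows the round trip under that assumption. `read_int_env` is a hypothetical helper for illustration only, not part of ms-swift or the SageMaker SDK.

```python
import json
import os

env_variables = {
    "SIZE_FACTOR": json.dumps(8),
    "MAX_PIXELS": json.dumps(602112),
}

# json.dumps turns each number into its string form
assert all(isinstance(value, str) for value in env_variables.values())


def read_int_env(name, default, environ=os.environ):
    """Parse an integer setting from an environment mapping (hypothetical helper)."""
    return int(environ.get(name, default))


# Simulate the remote job's environment with a plain dict
remote_env = dict(env_variables)
print(read_int_env("MAX_PIXELS", 0, remote_env))
```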


### SageMaker Configuration

By default, the [Amazon SageMaker Python SDK reads configuration](https://sagemaker.readthedocs.io/en/stable/overview.html#configuring-and-using-defaults-with-the-sagemaker-python-sdk) values from an admin-defined or user-specific configuration file. This configuration allows all kinds of customizations to be made. Setting the `SAGEMAKER_USER_CONFIG_OVERRIDE` environment variable below points the SDK to a `config.yaml` in the current working directory, which overrides these defaults. The main settings you will configure below are:

* The container image URI that runs the remote function code.
* Python dependencies to install for the remote inference job.
* Which files from the local working directory to exclude from the upload.

In [None]:
os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = os.getcwd()

In [None]:
config_yaml = f"""
SchemaVersion: '1.0'
SageMaker:
  PythonSDK:
    Modules:
      RemoteFunction:
        # role arn is not required if in SageMaker Notebook instance or SageMaker Studio
        # Uncomment the following line and replace with the right execution role if in a local IDE
        # RoleArn: <replace the role arn here>
        S3RootUri: {s3_root_uri}
        ImageUri: {sagemaker_dist_uri}        
        InstanceType: ml.g6e.2xlarge
        Dependencies: ./requirements.txt
        IncludeLocalWorkDir: true
        PreExecutionCommands:
        - "pip install packaging"
        CustomFileFilter:
          IgnoreNamePatterns:
          - "*.ipynb"
          - "__pycache__"
          - "data"
          - "venv"
          - "bin"
          - "models"
          - "results"
        EnvironmentVariables: {json.dumps(env_variables)}
        Tags:
          - Key: 'purpose'
            Value: 'inference'
          - Key: 'model_id'
            Value: {model_config.model_id}
          - Key: 'dataset'
            Value: {dataset_s3_uri}
"""

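# Use a context manager so the file is flushed and closed before the SDK reads it
with open("config.yaml", "w") as config_file:
    print(config_yaml, file=config_file)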
print(config_yaml)

In [None]:
job_name_prefix = shorten_for_sagemaker_training_job(f"infer-json-{prefix}")

### Constrained Decoding

Constrained decoding controls a language model's next-token prediction process by limiting which tokens it can generate to only those that satisfy specific rules or formats. During the normal generation process, a language model assigns probabilities to all possible next tokens. With constrained decoding, the set of next tokens is limited to those that satisfy the required structure. For example, with JSON constrained decoding, the model can only select tokens that keep the output valid JSON.

Below you can configure constrained decoding for the batch inference:
1. Set it to `None` to run batch inference without any constrained decoding.
2. If you have a JSON schema file in your dataset you can set `guided_decoding` to the path of that JSON schema file inside your dataset, for example `guided_decoding = "groundtruth_schema.json"`. You can reference the [02_create_custom_dataset_swift.ipynb](02_create_custom_dataset_swift.ipynb) notebook on how to create a JSON schema file. 
3. You can also set `guided_decoding` to a dict of structured output parameters from the [vLLM documentation](https://docs.vllm.ai/en/latest/features/structured_outputs.html), for example `guided_decoding = {"guided_choice": ["positive", "negative"]}`.
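
To make the idea concrete, here is a toy sketch of token masking with made-up scores. It is not how vLLM or xgrammar implement guided decoding. The function `constrained_argmax` and the vocabulary are hypothetical and only illustrate that disallowed tokens are excluded before the next token is chosen.

```python
import math


def constrained_argmax(logits, allowed):
    """Pick the highest-scoring token among `allowed`.

    Tokens outside `allowed` are masked to -inf, so they can never be chosen.
    """
    masked = {
        token: (score if token in allowed else -math.inf)
        for token, score in logits.items()
    }
    return max(masked, key=masked.get)


# Made-up scores for the next token after the model has emitted `{"amount": `
logits = {"twelve": 2.5, "12": 1.9, "}": 0.3, "\"": 0.1}

# Unconstrained greedy decoding picks the most likely token, valid JSON or not
print(max(logits, key=logits.get))

# A JSON grammar would only allow tokens that can start a JSON value here
print(constrained_argmax(logits, allowed={"12", "\""}))
```

In this toy case the unconstrained choice is `twelve`, which would break the JSON, while the constrained choice is `12`.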

In [None]:
guided_decoding = None # 1. default no constrained decoding

# guided_decoding = "groundtruth_schema.json" # 2. use a JSON schema inside dataset

# 3. Below is an example of how to configure structured output according to the vLLM documentation
# from pydantic import BaseModel

# class Invoice(BaseModel):
#     purpose: str
#     amount: int

# json_schema = Invoice.model_json_schema()
# guided_decoding = {"guided_json": json_schema}

## Batch Inference Function

In [None]:
@remote(
    instance_type="ml.g6e.xlarge",  # Powerful GPU for fast inference
    instance_count=1,  # Single instance for cost efficiency
    volume_size=200, # Large volume for model and data storage
    job_name_prefix=job_name_prefix,
    use_spot_instances=True, # Cost efficient inference. Inference can be restarted if no spot capacity. 
    max_wait_time_in_seconds=172800, # 48 hours max wait
    max_runtime_in_seconds=172800, # 48 hours max runtime
)
def batch_inference(
    model_id: str,
    model_type: str,
    dataset_s3: str,
    test_data_path: str = "test.jsonl",
    guided_decoding: Optional[Union[Dict, str]] = None
) -> str:
    """
    Run batch inference using SageMaker.
    
    Args:
        model_id: Model identifier or S3 URI
        model_type: Type of the model
        dataset_s3: S3 URI for the dataset
        test_data_path: Path to the test data file
        guided_decoding: vllm guided_decoding config or path to json schema. Default: None - no constrained decoding used
        
    Returns:
        Status message
    """
    from utils.model_manager import ModelManager
    from swift.llm import infer_main
    from pathlib import Path
    import subprocess
    import json


    output_dir = Path("/opt/ml/model")
    
    # Copy the dataset from S3 to the local working directory
    dataset_dir = Path(".")
    os.makedirs(dataset_dir, exist_ok=True)
    subprocess.run(
        ["aws", "s3", "cp", dataset_s3, str(dataset_dir), "--recursive", "--quiet"],
        shell=False,
        check=True,  # fail fast if the download does not succeed
    )
    
    test_data_local_path = dataset_dir / test_data_path
    result_path = output_dir / "results.jsonl"
    
    model_manager = ModelManager()
    guided_decoding = model_manager.construct_guided_decoding_config(dataset_dir, guided_decoding)
    
    argv = [
        "--result_path", str(result_path),
        "--max_length", "4096",  # Maximum sequence length for processing
        "--val_dataset", str(test_data_local_path),
        "--use_hf", "true", 
        "--infer_backend", "vllm",  # Use VLLM for faster inference
        "--gpu_memory_utilization", "0.98",  # High GPU utilization for speed
        "--max_num_seqs", "8",  # Batch size for parallel processing
        "--limit_mm_per_prompt", '{"image": 1, "video": 0}', # how many images per prompt. Increase if you have multi page pdf
        "--temperature", "0",
        # "--max_model_len","32768",
    ]

    model_dir: Path
        
    # Handle model loading
    if model_id.startswith("s3://"):
        
        model_dir = model_manager.download_and_extract_model(model_id)
        ckpt_dir = model_manager.find_best_model_checkpoint(model_dir)
       
        
        model_ckpt_args = [
            "--adapters", str(ckpt_dir),
            "--merge_lora", "true",
            "--load_data_args", "true"
        ]
        argv.extend(model_ckpt_args)
        
    else:
        model_dir = model_manager.download_from_hf_hub(model_id)
        from_hub_args = ["--model_type", model_type, "--model", str(model_dir)]
        argv.extend(from_hub_args)

    model_manager.update_generation_config(model_dir, guided_decoding)

    result = infer_main(argv)
    return result

## Run Batch Inference

In [None]:
inference_kwargs = {
    "model_id":model_config.model_id,
    "model_type":model_config.model_type,
    "dataset_s3":dataset_s3_uri,
    "test_data_path":"conversations_test_swift_format.json",
    "guided_decoding":guided_decoding
}

In [None]:
print(
    f"View your job here: https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/jobs/"
)
batch_inference(**inference_kwargs)

### Download inference results

In [None]:
from utils.helpers import get_latest_sagemaker_training_job

job_description = get_latest_sagemaker_training_job(job_name_prefix)

# The job's model artifacts contain the inference results written to /opt/ml/model
inference_output_url = job_description["ModelArtifacts"]["S3ModelArtifacts"]
print(f"Inference results can be found at {inference_output_url}")
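
The helper `get_latest_sagemaker_training_job` is imported from `utils.helpers` and its code is not shown here. A lookup like this could be built on the SageMaker `list_training_jobs` and `describe_training_job` APIs, roughly as sketched below. This is an assumption about the approach, not the helper's actual code. `latest_training_job` and the stub client are hypothetical, and the stub only exists so the sketch runs offline.

```python
def latest_training_job(sm_client, name_prefix):
    """Describe the newest training job whose name contains `name_prefix`.

    `sm_client` is expected to behave like boto3.client("sagemaker").
    """
    response = sm_client.list_training_jobs(
        NameContains=name_prefix,
        SortBy="CreationTime",
        SortOrder="Descending",
        MaxResults=1,
    )
    summaries = response["TrainingJobSummaries"]
    if not summaries:
        raise ValueError(f"No training job found for prefix {name_prefix!r}")
    return sm_client.describe_training_job(
        TrainingJobName=summaries[0]["TrainingJobName"]
    )


class StubSageMakerClient:
    """Stand-in for boto3.client("sagemaker") with made-up responses."""

    def list_training_jobs(self, **kwargs):
        return {"TrainingJobSummaries": [{"TrainingJobName": "infer-json-demo-001"}]}

    def describe_training_job(self, TrainingJobName):
        return {
            "TrainingJobName": TrainingJobName,
            "ModelArtifacts": {
                "S3ModelArtifacts": f"s3://example-bucket/{TrainingJobName}/output/model.tar.gz"
            },
        }


job = latest_training_job(StubSageMakerClient(), "infer-json-demo")
print(job["ModelArtifacts"]["S3ModelArtifacts"])
```

With a real client, you would pass `boto3.client("sagemaker")` instead of the stub.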

### Track Inference Results

We track inference results in a CSV file so that we can evaluate and compare different models later. We do this because:
1. We need to maintain a history of all inference runs
2. We want to associate results with specific models
3. We need to easily locate model outputs later
4. CSV format enables easy tracking

In [None]:
# Define the tracking file
tracking_file = "./results_to_compare.csv"

In [None]:
print("\nPlease enter a human readable name for this inference run:")
human_name = input()

In [None]:
model_config.model_id

In [None]:
try:
    # Prepare row data
    row_data = [human_name, base_model_config.model_id, inference_output_url]
    
    # Check if file exists
    file_exists = os.path.exists(tracking_file)
    
    # Open file in append mode with proper newline handling
    with open(tracking_file, mode='a', newline='') as f:
        writer = csv.writer(f)
        
        # Write headers if new file
        if not file_exists:
            writer.writerow(['human_name', 'model', 'inference_results_s3'])
            print(f"✅ Created new tracking file: {tracking_file}")
            
        # Write the new row
        writer.writerow(row_data)
        print(f"✅ Added new inference result to tracking file: {tracking_file}")
    
    # Display the full tracking history
    print("\nCurrent tracking history:")
    with open(tracking_file, mode='r', newline='') as f:
        # Read everything once: the first row is the header, the rest are data rows
        headers, *rows = list(csv.reader(f))

    # Calculate each column's width from the header and all data rows
    col_widths = [max(len(str(x)) for x in col) for col in zip(headers, *rows)]

    # Print headers
    header_format = ' | '.join(f'{h:<{w}}' for h, w in zip(headers, col_widths))
    print(header_format)
    print('-' * len(header_format))

    # Print data rows
    for row in rows:
        print(' | '.join(f'{x:<{w}}' for x, w in zip(row, col_widths)))
    
except Exception as e:
    print(f"❌ Error tracking inference results: {str(e)}")
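
Once a few runs are recorded, the tracking file can be queried by run name. The sketch below is one possible way to do that with `csv.DictReader`. `find_results` is a hypothetical helper, and the rows are made-up examples that reuse the column names of the tracking file.

```python
import csv
import os
import tempfile


def find_results(tracking_file, human_name):
    """Return the S3 result URIs recorded under `human_name`."""
    with open(tracking_file, newline="") as f:
        return [
            row["inference_results_s3"]
            for row in csv.DictReader(f)
            if row["human_name"] == human_name
        ]


# Build a small throwaway tracking file with made-up rows
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "results_to_compare.csv")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["human_name", "model", "inference_results_s3"])
        writer.writerow(["base", "Qwen/Qwen2.5-VL-3B-Instruct", "s3://example-bucket/run-1/output/model.tar.gz"])
        writer.writerow(["fine-tuned", "Qwen/Qwen2.5-VL-3B-Instruct", "s3://example-bucket/run-2/output/model.tar.gz"])
    found = find_results(path, "fine-tuned")

print(found)
```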

## Next step
* Continue with the [05_evaluate_model.ipynb](./05_evaluate_model.ipynb) notebook to evaluate the model's performance and compare it to other models.