# Preparing Mistral NeMo for Amazon Bedrock Custom Model Import (CMI)

This notebook demonstrates how to prepare and import Mistral NeMo into Amazon Bedrock using Custom Model Import (CMI).

## Prerequisites

1. **HuggingFace Access**
   - Active HuggingFace account
   - Valid access token
   - CLI authentication with HuggingFace (required for file transfers)

2. **File Transfer Method**
   - This notebook uses HF Transfer for efficient direct transfers from HuggingFace
   - Alternative: Manual download and S3 upload

3. **Model Configuration Requirements**
   - `max_position_embeddings` in `config.json` must be set to 8192 or less to comply with Bedrock limits
   - This setting defines the maximum sequence length the model accepts

4. **File Format Requirements**
   - All model files must be in HuggingFace format
   - Required files include:
     - Model weights (.safetensors)
     - Configuration files (config.json, generation_config.json)
     - Tokenizer files (tokenizer.json, tokenizer_config.json)
     - Supporting files (vocab.json, merges.txt, special_tokens_map.json)
5. **EBS Volume**
   - If the notebook is run from a SageMaker notebook instance, increase the EBS volume to 100 GB

## Important Note on Model Precision
Bedrock CMI has specific requirements for model precision:
- Supported: FP32, FP16, and BF16 precision
- Not supported: Quantized models (including 4-bit quantization)
- Note: FP32 models will be automatically converted to BF16 precision internally by Bedrock

In [None]:
!pip install -q transformers hf_transfer huggingface_hub

## Create IAM role and policy for CMI

In [None]:
import boto3
import json

# IAM client used to create the import role and its S3 access policy.
iam = boto3.client("iam")

# Bucket that holds the model files. Set it once here; both resource ARNs
# in the policy below are built from it.
s3_bucket = "YOUR_S3_BUCKET"

# Create the role with a trust policy that lets the Bedrock service assume it.
# The response is kept so the role ARN can be printed at the end of the cell.
role_response = iam.create_role(
    RoleName="MyImportModelRole",
    AssumeRolePolicyDocument=json.dumps({
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {
                    "Service": "bedrock.amazonaws.com"
                },
                "Action": "sts:AssumeRole"
            }
        ]
    })
)

# Create the policy granting object and listing access on the model bucket.
# NOTE(review): the import job presumably only reads from the bucket; confirm
# whether s3:PutObject / s3:DeleteObject can be dropped for least privilege.
policy_response = iam.create_policy(
    PolicyName="S3BucketPolicy",
    PolicyDocument=json.dumps({
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": [
                    "s3:PutObject",
                    "s3:GetObject",
                    "s3:ListBucket",
                    "s3:DeleteObject"
                ],
                "Resource": [
                    f"arn:aws:s3:::{s3_bucket}",
                    f"arn:aws:s3:::{s3_bucket}/*"
                ]
            }
        ]
    })
)

# Show both ARNs explicitly: the policy ARN is needed when attaching the
# policy to the role, and the role ARN is needed when starting the import job.
import_role_arn = role_response["Role"]["Arn"]
s3_policy_arn = policy_response["Policy"]["Arn"]
print(f"Role ARN: {import_role_arn}")
print(f"Policy ARN: {s3_policy_arn}")

## Attach Policy to the Role

In [None]:
# Attach the S3 access policy to the import role.
role_policy_link = {
    "RoleName": "MyImportModelRole",
    "PolicyArn": "YOUR_POLICY_ARN",  # the previous cell will display the ARN
}
iam.attach_role_policy(**role_policy_link)

### Download and Upload Hugging Face Model Files to S3 using HF Transfer

In [None]:
import os

# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER when it is first imported,
# so the flag must be in the environment before the import below; setting it
# afterwards does not enable the faster transfers.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import json
import boto3
from huggingface_hub import hf_hub_download, login

# Authenticate against the Hugging Face Hub (replace the placeholder; do not
# commit a real token with the notebook).
HF_TOKEN = "YOUR_HF_TOKEN"
login(token=HF_TOKEN)

# Repository and files
# NOTE(review): confirm every file listed here exists in the repository; files
# that cannot be fetched are reported in the summary printed at the end.
repo_id = "mistralai/Mistral-Nemo-Instruct-2407"
files_to_download = [
    "tokenizer_config.json",
    "generation_config.json",
    "merges.txt",
    "model-00001-of-00005.safetensors",
    "model-00002-of-00005.safetensors",
    "model-00003-of-00005.safetensors",
    "model-00004-of-00005.safetensors",
    "model-00005-of-00005.safetensors",
    "model.safetensors.index.json",
    "special_tokens_map.json",
    "tokenizer.json",
    "vocab.json",
    "config.json",
]

# S3 configuration
bucket_name = "YOUR_BUCKET_NAME"
prefix = "PREFIX"

# Download location
temp_dir = "./temp_model_files"
os.makedirs(temp_dir, exist_ok=True)

# Initialize S3 client
s3_client = boto3.client('s3')

# Files that could not be downloaded or uploaded, reported after the loop so a
# partial transfer is not mistaken for a complete one.
failed_files = []

# Process each file
for file in files_to_download:
    try:
        print(f"Downloading {file} using accelerated transfer...")
        # Step 1: Download to temporary directory (will use hf_transfer under the hood)
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=file,
            local_dir=temp_dir
        )

        # Step 2: Upload to S3 under the configured prefix
        s3_key = f"{prefix}/{os.path.basename(local_path)}"
        print(f"Uploading to s3://{bucket_name}/{s3_key}...")
        s3_client.upload_file(
            Filename=local_path,
            Bucket=bucket_name,
            Key=s3_key
        )
        print(f"Successfully transferred {file}")

    except Exception as e:
        # Best effort per file: record the failure and continue with the rest.
        print(f"Error with {file}: {str(e)}")
        failed_files.append(file)

# Final summary of the transfer
if failed_files:
    print(f"{len(failed_files)} file(s) were not transferred: {', '.join(failed_files)}")
else:
    print("All files transferred successfully")

### Update max_position_embeddings

In [None]:
# Initialize S3 client
s3 = boto3.client('s3')

# S3 details (use the same bucket and prefix as the upload step so the key
# below points at the config.json that was actually uploaded)
bucket_name = 'YOUR_BUCKET_NAME'
prefix = 'PREFIX'
config_key = f'{prefix}/config.json'  # full path to config.json

# Upper bound applied to max_position_embeddings
max_allowed_positions = 8192

# Download the current config
response = s3.get_object(Bucket=bucket_name, Key=config_key)
config = json.loads(response['Body'].read().decode('utf-8'))

# Clamp rather than overwrite: a value that is already at or below the bound
# is kept, so the setting is never raised above what the config declared.
current_positions = config.get('max_position_embeddings')
if isinstance(current_positions, int) and current_positions <= max_allowed_positions:
    print(f"max_position_embeddings is already {current_positions}; config left unchanged")
else:
    # Modify the config
    config['max_position_embeddings'] = max_allowed_positions

    # Upload modified config back to S3
    s3.put_object(
        Bucket=bucket_name,
        Key=config_key,
        Body=json.dumps(config, indent=2),
        ContentType='application/json'
    )

    print("Config updated successfully")

### Start CMI Import Process

In [None]:
import boto3
import time
import json

# Initialize Bedrock client.
# No region_name is passed, so boto3 resolves the region from its default
# configuration. The import job functions below all go through this client.
bedrock = boto3.client('bedrock')

def start_model_import(
    bucket_name="YOUR_BUCKET_NAME",
    prefix="PREFIX",
    model_name="mistral-nemo-test",
    role_arn="YOUR_ROLE_ARN",
):
    """Submit a model import job for the files stored under an S3 prefix.

    Args:
        bucket_name: S3 bucket containing the model files.
        prefix: Key prefix inside the bucket where the files live.
        model_name: Name given to the imported model; also used as the stem
            of the job name, which gets the current Unix time appended.
        role_arn: ARN of the role passed to the import job.

    Returns:
        The ARN of the created import job.

    Raises:
        Exception: Any error from the create call is printed and re-raised.
    """
    # Location of the model files, expressed as an S3 URI
    model_location = f"s3://{bucket_name}/{prefix}"

    try:
        job = bedrock.create_model_import_job(
            importedModelName=model_name,
            jobName=f"{model_name}-import-{int(time.time())}",
            modelDataSource={"s3DataSource": {"s3Uri": model_location}},
            roleArn=role_arn,
        )
        created_job_arn = job['jobArn']
        print(f"Model import job created successfully. Job ARN: {created_job_arn}")
        return created_job_arn
    except Exception as err:
        print(f"Error creating model import job: {str(err)}")
        raise

def check_import_status(job_arn):
    """Fetch an import job and return its status.

    The full response is printed on every call for debugging. The status is
    read from the top level of the response, or from a nested
    ``modelImportJob`` entry when the top level has none.

    Args:
        job_arn: Identifier of the import job to look up.

    Returns:
        The status value, or None when the response has no status field.

    Raises:
        Exception: Any error during the lookup is printed and re-raised.
    """
    try:
        job = bedrock.get_model_import_job(jobIdentifier=job_arn)

        # Dump the whole response for debugging
        print("Full response:")
        print(json.dumps(job, indent=2, default=str))

        # Preferred location: top-level status
        if 'status' in job:
            return job['status']

        # Fallback location: status nested under modelImportJob
        if 'modelImportJob' in job and 'status' in job['modelImportJob']:
            return job['modelImportJob']['status']

        print("Status not found in response structure")
        return None
    except Exception as err:
        print(f"Error checking job status: {str(err)}")
        raise

# Kick off the import job
job_arn = start_model_import()

# Poll once a minute until the job reaches a terminal state, the status cannot
# be determined, or an error occurs.
while True:
    try:
        status = check_import_status(job_arn)

        # No usable status: stop polling
        if not status:
            print("Could not determine status")
            break

        print(f"Import status: {status}")

        # Terminal states end the loop
        if status in ('Completed', 'Failed'):
            break

        time.sleep(60)  # Check every minute
    except Exception as err:
        print(f"Error in monitoring loop: {str(err)}")
        break

### Test Inference with Imported Model

In [None]:
# ARN of the imported model (from the Bedrock console under imported models)
model_arn = 'YOUR_MODEL_ARN'

# The model must be invoked in the region it was imported into. ARNs have the
# layout arn:partition:service:region:account:resource, so the region is the
# fourth field. Fall back to us-west-2 when the value is not a full ARN.
arn_fields = model_arn.split(':')
model_region = arn_fields[3] if len(arn_fields) > 3 and arn_fields[3] else 'us-west-2'

# Initialize Bedrock Runtime client in the model's region
bedrock_runtime = boto3.client(
    service_name='bedrock-runtime',
    region_name=model_region
)

# Test prompt
prompt = "What does Mistral AI do?"

# Prepare the request body
request_body = {
    "prompt": prompt,
    "max_tokens": 500,
    "temperature": 0.2,
    "top_k": 50
}

try:
    # Invoke the model
    response = bedrock_runtime.invoke_model(
        modelId=model_arn,
        body=json.dumps(request_body)
    )

    # Parse and print the response
    response_body = json.loads(response['body'].read())
    print("Response:")
    print(response_body)
    # The generated text is read from the first entry of 'outputs'
    text_content = response_body['outputs'][0]['text']
    print("Response text:\n", text_content)

except Exception as e:
    # Covers both invocation failures and an unexpected response layout
    print(f"Error: {str(e)}")