# Deploying BGE-M3 Embedding Model on Amazon SageMaker

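This notebook demonstrates how to deploy the [BGE-M3](https://huggingface.co/BAAI/bge-m3) embedding model to an Amazon SageMaker real-time endpoint, using the DJL inference container together with a custom Python handler. BGE-M3 is a state-of-the-art embedding model that can produce dense, sparse (lexical-weight), and multi-vector (ColBERT) embeddings from a single model.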

## Steps:
1. Download model checkpoint from Hugging Face
2. Upload model to S3
3. Create custom inference code
4. Deploy model to SageMaker endpoint
5. Test the endpoint

## 1. Download Model Checkpoint

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path

# Hugging Face repository to fetch, and the local cache directory it is stored under.
model_name = "BAAI/bge-m3"
local_model_path = Path("./hf_model")
local_model_path.mkdir(exist_ok=True)

# Download every file of the repository into the cache directory.
# Kept as the cell's last expression so the notebook displays the call's result.
snapshot_download(repo_id=model_name, cache_dir=local_model_path)

## 2. Upload Model to S3

In [None]:
import sagemaker
from sagemaker import image_uris
import boto3
import os
import time
import json

# Initialize SageMaker session and clients
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts

region = sess._region_name
account_id = sess.account_id()

s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

# Define S3 paths
s3_model_prefix = "BAAI/hf_model"  # folder where model checkpoint will go
model_snapshot_path = list(local_model_path.glob("**/snapshots/*"))[0]
s3_code_prefix = "BAAI/inference_code"
print(f"s3_code_prefix: {s3_code_prefix}")
print(f"model_snapshot_path: {model_snapshot_path}")

# Upload model to S3
!aws s3 cp --recursive {model_snapshot_path} s3://{bucket}/{s3_model_prefix}

## 3. Create Custom Inference Code

In [None]:
# Create the directory that will hold the custom inference code
# (-p makes this a no-op when the directory already exists).
!mkdir -p inference_code

In [None]:
%%writefile inference_code/model.py
from djl_python import Input, Output
import torch
import logging
import os
from FlagEmbedding import BGEM3FlagModel

# Report which device is available when the handler module is imported:
# the first CUDA GPU if there is one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'--device={device}')

def load_model(properties):
    """Load the BGE-M3 model described by the serving properties.

    Args:
        properties: Mapping of serving options. ``model_id`` is used as the
            model location when the key is present; otherwise ``model_dir``
            is used, falling back to ``/opt/ml/model``.

    Returns:
        A ``BGEM3FlagModel`` instance created with ``use_fp16=True``.
    """
    # An explicit model_id takes precedence over the model directory.
    if "model_id" in properties:
        model_location = properties["model_id"]
    else:
        model_location = properties.get("model_dir", "/opt/ml/model")

    logging.info(f"Loading model from {model_location}")

    return BGEM3FlagModel(model_location, use_fp16=True)

# Module-level cache for the loaded model; filled in by the first call to handle().
model = None


def handle(inputs: Input):
    """Entry point invoked for each request: embed the given sentences.

    The model is loaded lazily on the first call, using the properties
    attached to the request. An empty request returns ``None`` after the
    model has been loaded.

    Request JSON fields (all optional except ``inputs``):
        inputs: A sentence or a list of sentences. A payload that is itself a
            bare string or list is treated as the value of ``inputs``.
        is_query / instruction: When ``is_query`` is truthy and
            ``instruction`` is non-empty, the instruction is prepended to
            every sentence.
        max_length: Passed through to ``model.encode`` (default 2048).
        return_dense / return_sparse / return_colbert_vecs: Select which
            outputs are computed and returned (defaults: True, False, False).

    Returns:
        An ``Output`` carrying a JSON object with the requested keys among
        ``dense_embeddings``, ``sparse_embeddings`` and ``colbert_vectors``,
        or ``None`` for an empty request.
    """
    global model
    if model is None:
        model = load_model(inputs.get_properties())

    if inputs.is_empty():
        return None

    data = inputs.get_as_json()

    # A JSON body that is a bare string or list has no .get(); wrap it so the
    # rest of the handler can treat every payload as {"inputs": ...}.
    if not isinstance(data, dict):
        data = {"inputs": data}

    # Extract parameters from JSON
    input_sentences = data.get("inputs", [])
    if isinstance(input_sentences, str):
        input_sentences = [input_sentences]  # Convert single input to list

    is_query = data.get("is_query", False)
    max_length = data.get("max_length", 2048)
    instruction = data.get("instruction", "")

    # Extract optional parameters
    return_dense = data.get("return_dense", True)  # Default: True
    return_sparse = data.get("return_sparse", False)  # Default: False
    return_colbert_vecs = data.get("return_colbert_vecs", False)  # Default: False

    logging.info(f"inputs: {input_sentences}")
    logging.info(f"is_query: {is_query}")
    logging.info(f"instruction: {instruction}")
    logging.info(f"return_dense: {return_dense}, return_sparse: {return_sparse}, return_colbert_vecs: {return_colbert_vecs}")

    # Add instruction for queries if provided
    if is_query and instruction:
        input_sentences = [instruction + sent for sent in input_sentences]

    # Generate embeddings with specified options
    sentence_embeddings = model.encode(
        input_sentences,
        max_length=max_length,
        return_dense=return_dense,
        return_sparse=return_sparse,
        return_colbert_vecs=return_colbert_vecs,
    )

    # Format output JSON: only the requested representations are included.
    result = {}
    if return_dense:
        result["dense_embeddings"] = sentence_embeddings.get("dense_vecs", [])
    if return_sparse:
        result["sparse_embeddings"] = sentence_embeddings.get("lexical_weights", [])
    if return_colbert_vecs:
        result["colbert_vectors"] = sentence_embeddings.get("colbert_vecs", [])

    return Output().add_as_json(result)


In [None]:
import os

# Make sure the target directory is there before writing into it.
code_dir = "inference_code"
if not os.path.exists(code_dir):
    os.mkdir(code_dir)

# Create serving.properties file: one "key=value" entry per line,
# with no trailing newline after the last entry.
serving_properties = "\n".join(
    [
        "engine=Python",
        "option.tensor_parallel_degree=1",
        f"option.model_id=s3://{bucket}/{s3_model_prefix}/",
    ]
)
with open('inference_code/serving.properties', 'w') as properties_file:
    properties_file.write(serving_properties)

In [None]:
%%writefile inference_code/requirements.txt
FlagEmbedding

In [None]:
# Package and upload inference code
# Remove any archive left over from a previous run so it is rebuilt from scratch.
!rm -f inference_code.tar.gz
# Drop Jupyter's checkpoint directory so it is not shipped inside the archive.
!cd inference_code && rm -rf ".ipynb_checkpoints"
# Archive the whole inference_code directory (model.py, serving.properties, requirements.txt).
!tar czvf inference_code.tar.gz inference_code

# Upload the archive under the code prefix; the returned value is later passed
# to create_model as ModelDataUrl.
s3_code_artifact = sess.upload_data("inference_code.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {s3_code_artifact}")

## 4. Deploy Model to SageMaker Endpoint

In [None]:
from sagemaker.utils import name_from_base
import boto3

# Build the DJL inference container URI from its parts.
# NOTE(review): the registry account ID is hard-coded; confirm it is valid for
# the region this notebook runs in.
djl_image_tag = "0.31.0-lmi13.0.0-cu124"
djl_registry = f"763104351884.dkr.ecr.{region}.amazonaws.com"
inference_image_uri = f"{djl_registry}/djl-inference:{djl_image_tag}"

# Name for the SageMaker model, derived from the "bge-m3" base name.
model_name = name_from_base("bge-m3")

print(f"Model name: {model_name}")
print(f"Inference container image: {inference_image_uri}")

In [None]:
# Step 1: register the model (container image + packaged inference code).
primary_container = {
    "Image": inference_image_uri,
    "ModelDataUrl": s3_code_artifact,
}
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)
model_arn = create_model_response["ModelArn"]
print(f"Created Model: {model_arn}")


# Step 2: describe how the model is hosted (single GPU instance, one variant).
endpoint_config_name = f"{model_name}-config"
production_variant = {
    "VariantName": "variant1",
    "ModelName": model_name,
    "InstanceType": "ml.g5.xlarge",
    "InitialInstanceCount": 1,
    "ContainerStartupHealthCheckTimeoutInSeconds": 300,  # 5 minutes
}
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)
print(endpoint_config_response)

# Step 3: start creating the endpoint from that configuration.
endpoint_name = f"{model_name}-endpoint"
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

In [None]:
# Wait for endpoint deployment to complete
import time

# Poll the endpoint once a minute for as long as it is still being created.
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# Stop here if the endpoint did not come up, instead of letting the test
# cells below fail with a less helpful invocation error.
if status != "InService":
    failure_reason = resp.get("FailureReason", "no failure reason reported")
    raise RuntimeError(
        f"Endpoint {endpoint_name} is in status {status}, expected InService: {failure_reason}"
    )

## 5. Test the Endpoint

In [None]:
import numpy as np

def get_vector_by_sm_endpoint(
    questions,
    smr_client,
    endpoint_name,
    return_dense=True,
    return_sparse=True,
    return_colbert_vecs=True,
):
    """Get embeddings from SageMaker endpoint.

    Args:
        questions: A sentence or list of sentences, sent as the ``inputs`` field.
        smr_client: Client exposing ``invoke_endpoint`` (the SageMaker runtime client).
        endpoint_name: Name of the endpoint to invoke.
        return_dense: Ask the endpoint for dense embeddings.
        return_sparse: Ask the endpoint for sparse embeddings.
        return_colbert_vecs: Ask the endpoint for ColBERT vectors.

    The defaults request all three outputs; pass ``False`` for the ones that
    are not needed to keep the response small.

    Returns:
        The endpoint's JSON response decoded into Python objects.
    """
    payload = {
        "inputs": questions,
        "return_dense": return_dense,
        "return_sparse": return_sparse,
        "return_colbert_vecs": return_colbert_vecs,
    }
    response_model = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    json_str = response_model['Body'].read().decode('utf8')
    return json.loads(json_str)

def cos_sim(vector1, vector2):
    """Return the cosine similarity of two vectors.

    Computed as the dot product divided by the product of the two
    Euclidean norms.
    """
    magnitude_product = np.linalg.norm(vector1) * np.linalg.norm(vector2)
    return np.dot(vector1, vector2) / magnitude_product

In [None]:
# Test dense embeddings: two paraphrases and one unrelated sentence.
text1 = "How cute your dog is!"
text2 = "Your dog is so cute."
text3 = "The mitochondria is the powerhouse of the cell."

# Fetch the embeddings for all three sentences in one request, then compare them.
dense_response = get_vector_by_sm_endpoint([text1, text2, text3], smr_client, endpoint_name)
emb1, emb2, emb3 = dense_response['dense_embeddings']
print(f"Similarity between text1 and text2: {cos_sim(emb1, emb2)}")
print(f"Similarity between text1 and text3: {cos_sim(emb1, emb3)}")

In [None]:
# Test sparse embeddings: request the same sentences and display the sparse output.
sparse_response = get_vector_by_sm_endpoint([text1, text2, text3], smr_client, endpoint_name)
sparse_embeddings = sparse_response['sparse_embeddings']
sparse_embeddings

In [None]:
# Test ColBERT vectors: report the length of the first vector returned for each sentence.
colbert_response = get_vector_by_sm_endpoint([text1, text2, text3], smr_client, endpoint_name)
cvec1, cvec2, cvec3 = colbert_response['colbert_vectors']
print(f"ColBERT vector dimensions: {len(cvec1[0])}, {len(cvec2[0])}, {len(cvec3[0])}")