# Register an LLM Base Model with SageMaker Model Registry

## Overview
Large language models (LLMs) such as [Llama2](https://ai.meta.com/llama/) from Meta come as a collection of pretrained and fine-tuned models ranging in scale from 7 billion to 70 billion parameters.
Depending on the parameter count and the floating-point precision of the weights, these models can be very large. For instance, a llama2-70b model stored in fp32 is about 280 GB. Downloading weights of that size from the public internet, such as the Hugging Face Hub, can be slow and inefficient, and the cost multiplies when several team members work on projects that use the same base model. Another challenge is how to organize and manage these open-source LLMs effectively within an organization.
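The 280 GB figure comes from multiplying the parameter count by the bytes each parameter needs at a given precision (4 bytes for fp32, 2 for fp16/bf16). A minimal sketch of that arithmetic follows; it counts only the raw weights and ignores tokenizer files, configs, and file-format overhead, so real checkpoints are slightly larger.

```python
# Approximate storage per parameter at common precisions.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}


def approx_model_size_gb(num_params: float, precision: str) -> float:
    """Approximate size of the raw weights in gigabytes (1 GB = 10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e9


for params, label in [(7e9, "7b"), (70e9, "70b")]:
    for precision in ("fp32", "fp16"):
        size = approx_model_size_gb(params, precision)
        print(f"llama2-{label} @ {precision}: ~{size:.0f} GB")
```

Halving the precision halves the download, which is one reason half-precision checkpoints are the common distribution format.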

## Proposed Approach 
In this notebook, we use SageMaker Model Registry to store the weights of base LLMs. SageMaker Model Registry is a fully managed model repository used to store and version trained machine learning (ML) models at scale. When we fine-tune a base model, we can retrieve its weights from the S3 location registered in the base model group instead of downloading them from the public internet. SageMaker Model Registry also gives organizations a better tool for organizing and versioning open-source LLMs. Additionally, with the recent support for Model Registry Collections, you can use Collections to group related registered models and organize them in hierarchies to improve model discoverability at scale. The following diagram shows SageMaker Model Registry with Collections support:



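The Lambda function deployed later in this notebook performs the registration, and its code is not shown here. As a hedged illustration of one detail that registration involves: model packages live in a model package group, and a Hugging Face ID such as `NousResearch/Llama-2-7b-chat-hf` contains a `/` that is not allowed in a group name. The sketch below derives a group name and builds keyword arguments for boto3's `create_model_package_group`. The helper names are hypothetical, and the naming rule encoded in the regex (1 to 63 alphanumerics and hyphens, starting and ending with an alphanumeric) is an assumption to verify against the SageMaker API reference.

```python
import re

# Assumed ModelPackageGroupName constraint; confirm in the SageMaker API reference.
GROUP_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")


def model_group_name(model_id: str) -> str:
    """Hypothetical helper: derive a model package group name from a Hugging Face model ID."""
    # Collapse every run of disallowed characters (such as "/" or ".") into one hyphen.
    name = re.sub(r"[^a-zA-Z0-9]+", "-", model_id).strip("-")
    # Enforce the 63-character limit without leaving a trailing hyphen.
    name = name[:63].rstrip("-")
    if not GROUP_NAME_PATTERN.match(name):
        raise ValueError(f"Cannot derive a valid group name from {model_id!r}")
    return name


def group_request(model_id: str) -> dict:
    """Keyword arguments for sm_client.create_model_package_group(**...)."""
    return {
        "ModelPackageGroupName": model_group_name(model_id),
        "ModelPackageGroupDescription": f"Base model weights for {model_id}",
    }


print(group_request("NousResearch/Llama-2-7b-chat-hf"))
```

One group per base model keeps every uploaded version of the same weights together, which is also the unit you would add to a Collection.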
First, install Git LFS and initialize it so that model weights can be downloaded directly from the Hugging Face Hub.

In [None]:
!apt update && apt install git-lfs -y
!git lfs install --skip-repo

Import all the packages required for this notebook.

In [None]:
import os
import sagemaker
from sagemaker.collection import Collection
import boto3
import json


Create a SageMaker session and define the variables used throughout the notebook.

In [None]:
sm_session = sagemaker.session.Session()
default_bucket = sm_session.default_bucket()
sm_client = boto3.client("sagemaker")
role = sagemaker.get_execution_role()

In [None]:
model_id = "NousResearch/Llama-2-7b-chat-hf"  # Change this value to any other model on the Hugging Face Hub.
# S3 prefix where model.tar.gz is uploaded; the upload cell below references this name.
base_model_s3_save_loc = f"s3://{default_bucket}/data/{model_id}/basemodel"

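Clone the repository from the Hugging Face Hub without the model weights. Setting `GIT_LFS_SKIP_SMUDGE=1` makes Git LFS check out small pointer files in place of the large files, so the clone finishes quickly.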

In [None]:
!GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/{model_id}
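After a clone with `GIT_LFS_SKIP_SMUDGE=1`, each large file is only a short text pointer whose first line is `version https://git-lfs.github.com/spec/v1`. A minimal sketch for checking which weight files are still pointers, assuming the default `*.safetensors` naming; the function names are illustrative, not part of any library.

```python
from pathlib import Path

# Every Git LFS pointer file starts with this version line.
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path: Path) -> bool:
    """Return True if the file is a Git LFS pointer rather than real content."""
    with open(path, "rb") as f:
        head = f.read(len(LFS_POINTER_PREFIX))
    return head == LFS_POINTER_PREFIX


def pointer_files(repo_dir: str, pattern: str = "*.safetensors") -> list:
    """List files matching `pattern` whose real content has not been downloaded yet."""
    return sorted(
        str(p) for p in Path(repo_dir).rglob(pattern) if p.is_file() and is_lfs_pointer(p)
    )
```

For example, `pointer_files("Llama-2-7b-chat-hf")` run right after the clone should list every `.safetensors` file, and should return an empty list once `git lfs pull` has fetched the weights.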

In [None]:
model_name = os.path.basename(model_id)

## Download Model Weights
To download the model weights, run `git lfs pull` inside the cloned repository. The `--include` option restricts the download to files that match a pattern.
In our example, we download only the safetensors weights and skip the PyTorch `.bin` weights, which saves time.

## SafeTensors
At a high level, safetensors is a safe and fast file format for storing and loading tensors. PyTorch model weights are typically pickled into a `.bin` file with Python's `pickle` module. However, pickle is not secure: a pickled file can contain malicious code that runs when the file is loaded. safetensors is a secure alternative to pickle, which makes it well suited to sharing model weights.


In [None]:
!cd {model_name} && git lfs pull --include "*.safetensors"
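The safety of safetensors comes from its layout: an 8-byte little-endian unsigned integer `N`, then `N` bytes of JSON describing each tensor (`dtype`, `shape`, `data_offsets`), then raw tensor bytes. Loading it never executes code. The sketch below reads only that header with the standard library, which is enough to sanity-check a downloaded shard; it is an illustration, not a replacement for the `safetensors` package.

```python
import json
import math
import struct


def read_safetensors_header(path: str) -> dict:
    """Parse the JSON header of a .safetensors file without loading tensor data."""
    with open(path, "rb") as f:
        # First 8 bytes: header length as a little-endian unsigned 64-bit integer.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # "__metadata__" is an optional free-form entry, not a tensor.
    header.pop("__metadata__", None)
    return header


def count_parameters(header: dict) -> int:
    """Sum the element counts of all tensors described in a header."""
    return sum(math.prod(tensor["shape"]) for tensor in header.values())
```

Summing `count_parameters(read_safetensors_header(path))` over every file returned by `glob.glob(f"{model_name}/*.safetensors")` should give roughly the advertised parameter count.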

At the time of this writing, SageMaker Model Registry requires the model weights to be packaged as a `tar.gz` file. The following cell creates a `tar.gz` file containing the model weights.

*Note:* Because the model weights are so large, creating the `tar.gz` file can take a while. In our experiment, it took about 35 minutes.

In [None]:
%%time
!cd {model_name} && rm -rf .git* && tar -cvzf ../model.tar.gz .
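The same packaging step can be sketched with Python's `tarfile` module. Unlike the shell cell, this version skips top-level `.git*` entries instead of deleting them, and stores paths relative to the model directory so the files sit at the archive root. The `compresslevel=1` default is an assumption that favours speed: floating-point weights usually shrink little under gzip, so higher levels mostly add time. Measure on your own files before relying on it.

```python
import os
import tarfile


def package_model_dir(model_dir: str, output_path: str, compresslevel: int = 1) -> list:
    """Write model_dir's contents to a .tar.gz and return the archive member names.

    Top-level entries starting with ".git" are skipped. output_path should be
    outside model_dir so the archive does not try to include itself.
    """
    entries = sorted(e for e in os.listdir(model_dir) if not e.startswith(".git"))
    names = []

    def record(info: tarfile.TarInfo) -> tarfile.TarInfo:
        # Called for every member, including files inside subdirectories.
        names.append(info.name)
        return info

    with tarfile.open(output_path, "w:gz", compresslevel=compresslevel) as tar:
        for entry in entries:
            # arcname keeps paths relative, e.g. "config.json" instead of "Llama-2-7b-chat-hf/config.json".
            tar.add(os.path.join(model_dir, entry), arcname=entry, filter=record)
    return names
```

For this notebook the call would be `package_model_dir(model_name, "model.tar.gz")`.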

### Upload the model artifacts to the S3 bucket

In [None]:
%%time
model_data_uri = sagemaker.s3.S3Uploader.upload(
    local_path="./model.tar.gz",
    desired_s3_uri=base_model_s3_save_loc,
)
print(model_data_uri)

#### At this stage, deploy the CloudFormation template to create the Lambda function and SNS topic.

### Prepare the SNS message payload

In [None]:

message = {
    "model_id": model_id, 
    "model_data_uri": model_data_uri,
}
sns_topic_arn = "" # Replace with the SNS topic ARN from the CloudFormation outputs.
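Because the Lambda function runs asynchronously, a bad payload or an empty topic ARN would otherwise surface only later, in the Lambda logs or as a publish error. A small pre-flight check can catch obvious mistakes first. The helpers below are hypothetical; the required keys mirror the `message` dict above, and the real Lambda's expectations are defined by the CloudFormation template, not by this sketch.

```python
def validate_message(message: dict) -> None:
    """Raise ValueError if the payload is missing fields or points at a non-tar.gz object."""
    for key in ("model_id", "model_data_uri"):
        if not message.get(key):
            raise ValueError(f"Missing or empty field: {key}")
    uri = message["model_data_uri"]
    if not (uri.startswith("s3://") and uri.endswith(".tar.gz")):
        raise ValueError(f"model_data_uri should be an s3:// URI to a .tar.gz file, got {uri!r}")


def validate_topic_arn(arn: str) -> None:
    """Raise ValueError unless arn looks like arn:<partition>:sns:<region>:<account-id>:<topic-name>."""
    parts = arn.split(":")
    if len(parts) != 6 or parts[0] != "arn" or parts[2] != "sns" or not parts[5]:
        raise ValueError(f"Not an SNS topic ARN: {arn!r}")
```

Calling `validate_message(message)` and `validate_topic_arn(sns_topic_arn)` before publishing fails fast in the notebook instead of in the Lambda function.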

### Send SNS message to trigger the Register Model Lambda Function

In [None]:
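from botocore.exceptions import ClientError

# Initialize the SNS client
sns_client = boto3.client("sns")

if not sns_topic_arn:
    raise ValueError("Set sns_topic_arn to the SNS topic ARN from the CloudFormation outputs.")

try:
    # Publish the payload as a plain JSON string. MessageStructure is only needed
    # (as "json") when sending a different message to each protocol.
    response = sns_client.publish(
        TopicArn=sns_topic_arn,
        Message=json.dumps(message),
    )

    print(f"\nMessage sent successfully! MessageId: {response['MessageId']}")
    print(f"Message payload: {json.dumps(message, indent=2)}")

except ClientError as e:
    print(f"Error sending message to SNS: {e}")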
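The Register Model Lambda function itself is deployed by the CloudFormation template and is not shown in this notebook. As a rough sketch of what such a function might do, assuming it registers one model package per SNS message into an existing group: SNS delivers the payload to Lambda as a string in `event["Records"][i]["Sns"]["Message"]`, and boto3's `create_model_package` needs an inference container image alongside the `ModelDataUrl`. The group naming scheme, the `INFERENCE_IMAGE_URI` environment variable, and the content types are all hypothetical placeholders.

```python
import json
import os
import re


def extract_sns_payloads(event: dict) -> list:
    """Decode the JSON message body of every SNS record in a Lambda event."""
    return [json.loads(record["Sns"]["Message"]) for record in event.get("Records", [])]


def build_model_package_request(payload: dict, image_uri: str) -> dict:
    """Keyword arguments for create_model_package; the group is assumed to exist already."""
    # Hypothetical naming scheme: one group per Hugging Face model ID.
    group_name = re.sub(r"[^a-zA-Z0-9]+", "-", payload["model_id"]).strip("-")[:63].rstrip("-")
    return {
        "ModelPackageGroupName": group_name,
        "ModelPackageDescription": f"Base model weights for {payload['model_id']}",
        "InferenceSpecification": {
            "Containers": [{"Image": image_uri, "ModelDataUrl": payload["model_data_uri"]}],
            "SupportedContentTypes": ["application/json"],
            "SupportedResponseMIMETypes": ["application/json"],
        },
        "ModelApprovalStatus": "PendingManualApproval",
    }


def register_from_event(event: dict, sagemaker_client, image_uri: str) -> list:
    """Register one model package per SNS record and return the new package ARNs."""
    arns = []
    for payload in extract_sns_payloads(event):
        response = sagemaker_client.create_model_package(**build_model_package_request(payload, image_uri))
        arns.append(response["ModelPackageArn"])
    return arns


def lambda_handler(event, context):
    import boto3  # included in the AWS Lambda Python runtime

    client = boto3.client("sagemaker")
    # INFERENCE_IMAGE_URI is a hypothetical environment variable set on the function.
    return {"model_package_arns": register_from_event(event, client, os.environ["INFERENCE_IMAGE_URI"])}
```

Passing the SageMaker client into `register_from_event` keeps the parsing and request-building logic testable with a stub, without AWS credentials.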