### Docker image to run the Clay model on GPU instances

#### Clone Clay model repo at specific commit

See the Clay GitHub repository: https://github.com/Clay-foundation/model/tree/main

In [None]:
# Use the %pip magic so the package is installed into the environment of the running kernel,
# and bound the version so a future major release cannot silently change the API used below.
%pip install "gitpython>=3.1,<4"

In [None]:
URL = "https://github.com/Clay-foundation/model.git"
# Latest commit before v1.5: https://github.com/Clay-foundation/model/commit/32518ceed8f75f116f3325bdb68c62eeab9ddbae
# The full 40-character hash is used because an abbreviated hash can become ambiguous as the repository grows.
SHA = "32518ceed8f75f116f3325bdb68c62eeab9ddbae"

In [None]:
from git import Repo

def clone_repo_at_commit(git_url, local_dir, commit_hash):
    """
    Clone a git repository and check out a specific commit.

    If ``local_dir`` already contains a git checkout (for example because this
    cell has been run before), that checkout is reused instead of cloning
    again, so the cell can be re-executed without failing on a non-empty
    target directory.

    Args:
        git_url (str): URL of the git repository
        local_dir (str): Local directory where to clone the repository
        commit_hash (str): The specific commit hash to checkout

    Raises:
        Exception: Any error raised while cloning, opening, or checking out
            the repository is printed and then re-raised, so that later cells
            do not run against a missing or wrong checkout.
    """
    import os  # local import so this cell does not depend on a later cell's import

    try:
        if os.path.exists(os.path.join(local_dir, ".git")):
            # A previous run already cloned the repository here; reuse it.
            print(f"Reusing existing clone in {local_dir}")
            repo = Repo(local_dir)
        else:
            # Clone the repository
            repo = Repo.clone_from(git_url, local_dir)

        # Checkout the specific commit
        repo.git.checkout(commit_hash)

        print(f"Successfully cloned repository at commit {commit_hash}")

    except Exception as e:
        print(f"Error occurred: {str(e)}")
        # Propagate the failure: the Docker build below needs this checkout.
        raise

In [None]:
# Clone the Clay repository into ./clay_assets (the directory the Dockerfile below copies into the image)
# and check out the pinned commit.
clone_repo_at_commit(git_url=URL, local_dir="./clay_assets", commit_hash=SHA)

#### Download model checkpoint from HuggingFace

See the Clay Hugging Face repository: https://huggingface.co/made-with-clay/Clay

In [None]:
import os

# Directory inside the cloned repository that will hold the downloaded checkpoint.
artifact_dir = "./clay_assets/checkpoints/"

# Create it (and any missing parents); a pre-existing directory is not an error.
os.makedirs(name=artifact_dir, exist_ok=True)

In [None]:
# Direct-download URL of the Clay v1 base checkpoint file (resolved from the `main` revision of the HuggingFace repository).
hf_ckpt_path = "https://huggingface.co/made-with-clay/Clay/resolve/main/v1/clay-v1-base.ckpt"

In [None]:
# --no-clobber skips the download when the checkpoint is already present, so re-running this cell
# does not create a duplicate "clay-v1-base.ckpt.1" that would then be copied into the Docker image.
!wget --quiet --no-clobber -P {artifact_dir} {hf_ckpt_path}

#### Write the Dockerfile

In [None]:
%%writefile Dockerfile

# Build from the SageMaker Distribution image: https://gallery.ecr.aws/sagemaker/sagemaker-distribution
FROM public.ecr.aws/sagemaker/sagemaker-distribution:1.8.0-gpu

ARG NB_USER="sagemaker-user"
ARG NB_UID=1000
ARG NB_GID=100

ENV MAMBA_USER=$NB_USER

# Switch to root for the system-level installs below.
# The user is named explicitly because no ROOT variable is declared in this Dockerfile.
USER root

# Install system dependencies
RUN apt-get update && apt-get install -y \
    wget \
    && rm -rf /var/lib/apt/lists/*

# Install micromamba: download the release archive and extract only the binary
RUN wget -qO- https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xvj bin/micromamba \
    && mv bin/micromamba /usr/local/bin/ \
    && rm -rf bin

# Set up Mamba environment
ENV MAMBA_ROOT_PREFIX=/opt/conda
ENV PATH=$MAMBA_ROOT_PREFIX/bin:$PATH

# Copy environment files (both must be present at the root of the build context)
COPY environment.yml /tmp/environment.yml
COPY requirements.txt /tmp/requirements.txt

# Create the environment, install the pip requirements into it, and clean caches.
# --yes answers the confirmation prompt, since the build is non-interactive.
RUN micromamba create --yes -f /tmp/environment.yml && \
    micromamba run -n claymodel pip install -r /tmp/requirements.txt && \
    micromamba clean --all --yes

# Set environment variable for the environment name
ENV ENV_NAME=claymodel

# Set the default environment for inference
ENV SAGEMAKER_JOB_CONDA_ENV=claymodel

# Put the environment's executables first on PATH (key=value form; the whitespace form is legacy syntax)
ENV PATH=/opt/conda/envs/$ENV_NAME/bin:$PATH

# Copy model files (cloned repository plus downloaded checkpoint)
COPY clay_assets/ /home/sagemaker-user/clay-model

# Add healthcheck to verify code is running at default path
HEALTHCHECK --interval=30s --timeout=3s \
  CMD pgrep -f "python3 /opt/ml/processing/input/code/" || exit 1

# Run the container command through `bash -c`; the environment itself is selected via PATH above
ENTRYPOINT ["/bin/bash", "-c"]

#### Build and tag Docker image

In [None]:
import boto3
import sagemaker

# Resolve the AWS region from a default SageMaker session
sagemaker_session = sagemaker.Session()
ECR_REGION = sagemaker_session.boto_region_name

# Resolve the account ID of the calling identity through STS
sts_client = boto3.client('sts')
ECR_ACCOUNT_ID = sts_client.get_caller_identity()["Account"]

# ECR repository name and the "<repository>:<tag>" image reference built from it
REPO_NAME = "clay-gpu-container-new"
IMG_NAME = f"{REPO_NAME}:latest"

# Echo the resolved settings, one per line
print(
    f"Region: {ECR_REGION}",
    f"Account Number: {ECR_ACCOUNT_ID}",
    f"ECR Repository Name: {REPO_NAME}",
    f"Image Name: {IMG_NAME}",
    sep="\n",
)

In [None]:
# Log the local Docker client in to this account's ECR registry: the AWS CLI emits a password
# for the region, which is piped to `docker login` on stdin for the fixed username "AWS".
!aws ecr get-login-password --region {ECR_REGION} | docker login --username AWS --password-stdin {ECR_ACCOUNT_ID}.dkr.ecr.{ECR_REGION}.amazonaws.com

In [None]:
# Build the image from the Dockerfile written above, using the notebook's working directory as build context,
# and tag it locally as IMG_NAME. --quiet suppresses the build log.
!docker build --quiet -f Dockerfile -t {IMG_NAME} .

In [None]:
# Add a second tag that carries the full ECR registry host, which is the name used for push and pull below.
!docker tag {IMG_NAME} {ECR_ACCOUNT_ID}.dkr.ecr.{ECR_REGION}.amazonaws.com/{IMG_NAME}

#### Push to ECR

Ensure that the ECR repository exists. Create it if it does not.

In [None]:
def ensure_ecr_repository(repository_name, region=ECR_REGION):
    """
    Look up an ECR repository by name, creating it when the lookup reports
    that it does not exist.

    Args:
        repository_name (str): Name of the ECR repository
        region (str, optional): AWS region passed to the ECR client.
            Defaults to the value ECR_REGION had when this function was defined.

    Returns:
        dict: Repository details as returned by the ECR API

    Raises:
        Exception: Any error other than "repository not found" is printed
            and re-raised.
    """
    try:
        client = boto3.client('ecr', region_name=region)

        try:
            # Describing the repository doubles as the existence check
            described = client.describe_repositories(repositoryNames=[repository_name])
        except client.exceptions.RepositoryNotFoundException:
            # Not there yet: create it with scan-on-push and AES256 encryption
            print(f"Creating repository '{repository_name}'...")
            created = client.create_repository(
                repositoryName=repository_name,
                imageScanningConfiguration={'scanOnPush': True},
                encryptionConfiguration={'encryptionType': 'AES256'},
            )
            print(f"Repository '{repository_name}' created successfully")
            return created['repository']
        else:
            print(f"Repository '{repository_name}' already exists")
            return described['repositories'][0]

    except Exception as err:
        print(f"Error managing ECR repository: {str(err)}")
        raise

In [None]:
# Make sure the target repository exists before pushing.
# No try/except here on purpose: ensure_ecr_repository already prints the error and re-raises it,
# and swallowing it would let the `docker push` cell run against a repository that does not exist.
repository = ensure_ecr_repository(REPO_NAME)
print(f"Repository URI: {repository['repositoryUri']}")

In [None]:
# Push the registry-qualified tag created earlier to the ECR repository.
!docker push {ECR_ACCOUNT_ID}.dkr.ecr.{ECR_REGION}.amazonaws.com/{IMG_NAME}

#### Pull from ECR (required to work with SageMaker notebooks in local mode)

In [None]:
# Pull the image from ECR by its registry-qualified name so the local Docker daemon has it under that name.
!docker pull {ECR_ACCOUNT_ID}.dkr.ecr.{ECR_REGION}.amazonaws.com/{IMG_NAME}