## Setup for Inpainting Eraser Models
---
In this notebook, you will run the setup process for our inpainting eraser solution. Because the models are large, some of the steps may take a while to complete; the entire notebook should finish within 1 hour. At the end of the run, you will have an extended inference container image in Amazon Elastic Container Registry (ECR), ready to host our models on a SageMaker multi-model endpoint (MME).

SageMaker MME is an Amazon SageMaker capability that hosts multiple machine learning models behind a single endpoint. The models are deployed and managed together, which makes machine learning applications easier to scale and maintain. Because each request can target a specific model, an MME is more flexible and efficient than running one endpoint per model. An MME can also combine different types of models, such as computer vision and natural language processing models, to build more comprehensive applications.
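On an MME, the model that serves a request is selected per call through the `TargetModel` parameter of the SageMaker Runtime `InvokeEndpoint` API. The sketch below only builds the request arguments. The endpoint name `inpainting-eraser-mme`, the archive name `sam.tar.gz`, and the helper `build_mme_request` are hypothetical, and the JSON body is a placeholder rather than this solution's actual input format.

```python
import json
from typing import Any, Dict


def build_mme_request(endpoint_name: str, target_model: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Build keyword arguments for sagemaker-runtime invoke_endpoint on a multi-model endpoint.

    TargetModel names the model archive (relative to the endpoint's S3 model prefix)
    that should serve this request; SageMaker loads it on first use.
    """
    return {
        "EndpointName": endpoint_name,
        "TargetModel": target_model,
        "ContentType": "application/json",
        "Body": json.dumps(payload),
    }


# Hypothetical endpoint and model archive names; the payload is a placeholder.
request = build_mme_request("inpainting-eraser-mme", "sam.tar.gz", {"inputs": []})

# With AWS credentials and a deployed endpoint, the request could be sent with:
#   import boto3
#   runtime = boto3.client("sagemaker-runtime")
#   response = runtime.invoke_endpoint(**request)
```

Switching `TargetModel` to another archive on the same endpoint is all it takes to route a request to a different model.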

Here is a high level breakdown of the setup steps:

1. Download the pre-trained models
2. Package a conda environment with the additional dependencies for each model
3. Extend the SageMaker-managed Triton container with the model checkpoints and conda packs pre-loaded
4. Push the container image to Amazon Elastic Container Registry (ECR)

---

This notebook builds a custom Docker image locally. **We recommend using the PyTorch kernel on a SageMaker Notebook Instance of type `ml.g4dn.xlarge`.**

In [23]:
!pip install -Uq sagemaker

### Setup

In [24]:
import sagemaker
import boto3

import tarfile
import os

sagemaker_session = sagemaker.Session(boto_session=boto3.Session())
region = sagemaker_session.boto_region_name
account = sagemaker_session.account_id()

### Download Pre-trained Models

#### Download and Package SAM Checkpoint
**Apache-2.0 license**

In [25]:
model_file_name = "sam_vit_h_4b8939.pth"
download_path = f"https://huggingface.co/spaces/abhishek/StableSAM/resolve/main/{model_file_name}"

!wget $download_path

--2024-01-17 03:54:42--  https://huggingface.co/spaces/abhishek/StableSAM/resolve/main/sam_vit_h_4b8939.pth
Resolving huggingface.co (huggingface.co)... 99.84.66.65, 99.84.66.70, 99.84.66.112, ...
Connecting to huggingface.co (huggingface.co)|99.84.66.65|:443... connected.
HTTP request sent, awaiting response... 302 Found

In [26]:
sam_tar = f"docker/{model_file_name}.tar.gz"

def make_tarfile(output_filename, source_dir):
    # Archive source_dir under its basename so it extracts to /home/models/<name> in the container
    with tarfile.open(output_filename, "w:gz") as tar:
        tar.add(source_dir, arcname=os.path.basename(source_dir))

make_tarfile(sam_tar, model_file_name)

#### Download and Package LaMa Checkpoint
**Apache-2.0 license**

In [27]:
!wget https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip
!unzip big-lama.zip

--2024-01-17 03:56:35--  https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip
Resolving huggingface.co (huggingface.co)... 99.84.66.70, 99.84.66.112, 99.84.66.72, ...
Connecting to huggingface.co (huggingface.co)|99.84.66.70|:443... connected.
HTTP request sent, awaiting response... 302 Found

In [29]:
lama_dir = "big-lama"
lama_tar = f"docker/{lama_dir}.tar.gz"

make_tarfile(lama_tar, lama_dir)
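The Dockerfile extracts both archives with `tar -xzf ... -C /home/models/`, so each checkpoint should sit at the top level of its archive. A small sketch for checking that layout; `top_level_entries` is a hypothetical helper, and the demo uses a placeholder file instead of the real checkpoint so it runs offline.

```python
import os
import tarfile
import tempfile


def make_tarfile(output_filename, source_dir):
    # Same helper as in the cells above: archive source_dir under its basename.
    with tarfile.open(output_filename, "w:gz") as tar:
        tar.add(source_dir, arcname=os.path.basename(source_dir))


def top_level_entries(archive_path):
    """Return the sorted top-level names in a .tar.gz archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        return sorted({member.name.split("/")[0] for member in tar.getmembers()})


# Offline demo with a placeholder file standing in for the SAM checkpoint.
with tempfile.TemporaryDirectory() as tmp:
    fake_checkpoint = os.path.join(tmp, "sam_vit_h_4b8939.pth")
    with open(fake_checkpoint, "wb") as f:
        f.write(b"\x00" * 16)  # placeholder bytes, not real weights
    archive = os.path.join(tmp, "sam_vit_h_4b8939.pth.tar.gz")
    make_tarfile(archive, fake_checkpoint)
    entries = top_level_entries(archive)

print(entries)  # ['sam_vit_h_4b8939.pth']

# On the real artifacts (run from the notebook's root directory):
# top_level_entries("docker/sam_vit_h_4b8939.pth.tar.gz")
# top_level_entries("docker/big-lama.tar.gz")
```

For the LaMa archive, the expected result is a single `big-lama` entry, which the Dockerfile extracts to `/home/models/big-lama/`.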

LaMa also needs the source code from its GitHub repository. We clone the repo into the `model_repo/lama/1` folder:

In [30]:
!cd model_repo/lama/1 && git clone https://github.com/advimman/lama.git

fatal: destination path 'lama' already exists and is not an empty directory.
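The `fatal: destination path 'lama' already exists` message appears when the cell is rerun after a successful clone. A sketch of an idempotent alternative, using a hypothetical `clone_if_missing` helper that skips `git clone` when the destination already contains files:

```python
import os
import subprocess
import tempfile


def clone_if_missing(repo_url, dest_dir):
    """Clone repo_url into dest_dir unless dest_dir already exists and is non-empty.

    Returns True if git clone ran, False if the clone was skipped.
    """
    if os.path.isdir(dest_dir) and os.listdir(dest_dir):
        return False  # already populated, e.g. by a previous run of this notebook
    subprocess.run(["git", "clone", repo_url, dest_dir], check=True)
    return True


# Notebook usage, assuming the working directory is the notebook's root:
# clone_if_missing("https://github.com/advimman/lama.git", "model_repo/lama/1/lama")

# Offline demo of the skip path: a non-empty directory is left as-is.
with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "README.md"), "w").close()
    cloned = clone_if_missing("https://github.com/advimman/lama.git", tmp)

print(cloned)  # False
```

The non-empty check does not verify that the directory is a clean clone of the right repository; it only avoids the error on reruns.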


#### Download Sample Images from Inpaint Anything
**Apache-2.0 license**

In [31]:
!cd statics && wget https://raw.githubusercontent.com/geekyutao/Inpaint-Anything/main/example/fill-anything/sample1.png
!cd statics && wget https://raw.githubusercontent.com/geekyutao/Inpaint-Anything/main/example/remove-anything/dog.jpg

--2024-01-17 04:25:06--  https://raw.githubusercontent.com/geekyutao/Inpaint-Anything/main/example/fill-anything/sample1.png
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2358157 (2.2M) [image/png]
Saving to: ‘sample1.png.1’


2024-01-17 04:25:06 (41.4 MB/s) - ‘sample1.png.1’ saved [2358157/2358157]

--2024-01-17 04:25:06--  https://raw.githubusercontent.com/geekyutao/Inpaint-Anything/main/example/remove-anything/dog.jpg
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 99846 (98K) [image/jpeg]
Saving to: ‘dog.jpg.1’


2024-01-17 
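Rerunning `wget` also creates numbered copies (`sample1.png.1`, `dog.jpg.1`) instead of reusing the existing files. `wget -nc` (no-clobber) avoids that. The same idea in Python is sketched below with a hypothetical `download_if_missing` helper built on `urllib.request.urlretrieve`; the demo exercises only the skip path so it runs offline.

```python
import os
import tempfile
import urllib.request
from urllib.parse import urlparse


def download_if_missing(url, dest_dir):
    """Download url into dest_dir unless a file with the same name already exists.

    Returns (local_path, downloaded) where downloaded is False if the file was reused.
    """
    os.makedirs(dest_dir, exist_ok=True)
    filename = os.path.basename(urlparse(url).path)
    local_path = os.path.join(dest_dir, filename)
    if os.path.exists(local_path):
        return local_path, False  # keep the existing copy; no numbered duplicate
    urllib.request.urlretrieve(url, local_path)
    return local_path, True


# Notebook usage (network access required):
# for url in [
#     "https://raw.githubusercontent.com/geekyutao/Inpaint-Anything/main/example/fill-anything/sample1.png",
#     "https://raw.githubusercontent.com/geekyutao/Inpaint-Anything/main/example/remove-anything/dog.jpg",
# ]:
#     download_if_missing(url, "statics")

# Offline demo of the skip path: a pre-existing file is left untouched.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "dog.jpg"), "wb") as f:
        f.write(b"cached")
    path, downloaded = download_if_missing(
        "https://raw.githubusercontent.com/geekyutao/Inpaint-Anything/main/example/remove-anything/dog.jpg", tmp
    )
    with open(path, "rb") as f:
        content = f.read()
```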

### Package a Conda Environment for Each Model

SageMaker NVIDIA Triton container images do not contain all the libraries needed to run our SAM and LaMa models. However, Triton lets you bring additional dependencies as a conda-pack archive. Run the two cells below to create an `xxx_env.tar.gz` environment package for each model.
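Triton's Python backend picks up a conda-pack archive through the `EXECUTION_ENV_PATH` parameter in a model's `config.pbtxt`. The model configuration files under `model_repo` are not shown in this section, so the following is only a sketch, assuming each model's config points at the path where the Dockerfile later in this notebook copies the archive (`/home/condpackenv/<name>_env.tar.gz`). `execution_env_parameter` is a hypothetical helper that renders the parameter block.

```python
def execution_env_parameter(env_archive_path):
    """Return a config.pbtxt `parameters` block pointing Triton's Python backend at a
    conda-pack archive (hypothetical helper; EXECUTION_ENV_PATH is Triton's key name)."""
    return (
        "parameters: {\n"
        '  key: "EXECUTION_ENV_PATH",\n'
        f'  value: {{string_value: "{env_archive_path}"}}\n'
        "}\n"
    )


# Paths assume the layout created by this notebook's Dockerfile.
sam_snippet = execution_env_parameter("/home/condpackenv/sam_env.tar.gz")
lama_snippet = execution_env_parameter("/home/condpackenv/lama_env.tar.gz")
print(sam_snippet)
```

Both environments pin `python=3.8`. As far as we know this matches the Python version the Triton 22.12 Python backend stub is built for; an environment with a different Python version would need a custom-built stub.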

In [34]:
!cd docker && bash sam_conda_dependencies.sh

Collecting package metadata (current_repodata.json): done
Solving environment: done


  current version: 23.3.1
  latest version: 23.11.0

Please update conda by running

    $ conda update -n base -c conda-forge conda

Or to minimize the number of packages updated during conda update use

     conda install conda=23.11.0



## Package Plan ##

  environment location: /home/ec2-user/anaconda3/envs/sam_env

  added / updated specs:
    - python=3.8


The following NEW packages will be INSTALLED:

  _libgcc_mutex      conda-forge/linux-64::_libgcc_mutex-0.1-conda_forge 
  _openmp_mutex      conda-forge/linux-64::_openmp_mutex-4.5-2_gnu 
  bzip2              conda-forge/linux-64::bzip2-1.0.8-hd590300_5 
  ca-certificates    conda-forge/linux-64::ca-certificates-2023.11.17-hbcca054_0 
  ld_impl_linux-64   conda-forge/linux-64::ld_impl_linux-64-2.40-h41732ed_0 
  libffi             conda-forge/linux-64::libffi-3.4.2-h7f98852_5 
  libgcc-ng          conda-forge/linux-64::libgcc-ng-13.2.0-h80

Installing collected packages: opencv-python-headless
Successfully installed opencv-python-headless-4.7.0.68
Collecting matplotlib==3.6.3
  Using cached matplotlib-3.6.3-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (9.4 MB)
Collecting contourpy>=1.0.1 (from matplotlib==3.6.3)
  Using cached contourpy-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.9 kB)
Collecting cycler>=0.10 (from matplotlib==3.6.3)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib==3.6.3)
  Using cached fonttools-4.47.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (157 kB)
Collecting kiwisolver>=1.0.1 (from matplotlib==3.6.3)
  Using cached kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl.metadata (6.4 kB)
Collecting packaging>=20.0 (from matplotlib==3.6.3)
  Using cached packaging-23.2-py3-none-any.whl.metadata (3.2 kB)
Collecting pyparsing>=2.2.1 (from matplotlib==3.6.3)
  Us

In [35]:
!cd docker && bash lama_conda_dependencies.sh

Collecting package metadata (current_repodata.json): done
Solving environment: done


  current version: 23.3.1
  latest version: 23.11.0

Please update conda by running

    $ conda update -n base -c conda-forge conda

Or to minimize the number of packages updated during conda update use

     conda install conda=23.11.0



## Package Plan ##

  environment location: /home/ec2-user/anaconda3/envs/lama_env

  added / updated specs:
    - python=3.8


The following NEW packages will be INSTALLED:

  _libgcc_mutex      conda-forge/linux-64::_libgcc_mutex-0.1-conda_forge 
  _openmp_mutex      conda-forge/linux-64::_openmp_mutex-4.5-2_gnu 
  bzip2              conda-forge/linux-64::bzip2-1.0.8-hd590300_5 
  ca-certificates    conda-forge/linux-64::ca-certificates-2023.11.17-hbcca054_0 
  ld_impl_linux-64   conda-forge/linux-64::ld_impl_linux-64-2.40-h41732ed_0 
  libffi             conda-forge/linux-64::libffi-3.4.2-h7f98852_5 
  libgcc-ng          conda-forge/linux-64::libgcc-ng-13.2.0-h8

Collecting importlib-resources>=3.2.0 (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2)
  Using cached importlib_resources-6.1.1-py3-none-any.whl.metadata (4.1 kB)
Collecting zipp>=3.1.0 (from importlib-resources>=3.2.0->matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2)
  Using cached zipp-3.17.0-py3-none-any.whl.metadata (3.7 kB)
Collecting six>=1.5 (from python-dateutil>=2.7->matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2)
  Using cached six-1.16.0-py2.py3-none-any.whl (11 kB)
Using cached imageio-2.33.1-py3-none-any.whl (313 kB)
Using cached matplotlib-3.7.4-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (9.2 MB)
Using cached pillow-10.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.4 MB)
Using cached tifffile-2023.7.10-py3-none-any.whl (220 kB)
Using cached contourpy-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (301 kB)
Using cached cycler-0.12.1-py3-none-any.whl (8.3 kB)
Using cached fonttools-4.47.2-cp38-cp38-manylinux_2_17_x86_64.

Collecting pyasn1<0.6.0,>=0.4.6 (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.14,>=2.13->tensorflow)
  Using cached pyasn1-0.5.1-py2.py3-none-any.whl.metadata (8.6 kB)
Collecting oauthlib>=3.0.0 (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard<2.14,>=2.13->tensorflow)
  Using cached oauthlib-3.2.2-py3-none-any.whl (151 kB)
Using cached tensorflow-2.13.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (479.6 MB)
Using cached absl_py-2.1.0-py3-none-any.whl (133 kB)
Using cached flatbuffers-23.5.26-py2.py3-none-any.whl (26 kB)
Using cached grpcio-1.60.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)
Using cached h5py-3.10.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.8 MB)
Using cached keras-2.13.1-py3-none-any.whl (1.7 MB)
Using cached libclang-16.0.6-py2.py3-none-manylinux2010_x86_64.whl (22.9 MB)
Using cached protobuf-4.25.2-cp37-abi3-manylinux2014_x86_64.whl (294 kB)
Using cached tensorflow_esti

Collecting Shapely (from imgaug>=0.4.0->albumentations==0.5.2)
  Using cached shapely-2.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.0 kB)
Using cached opencv_python_headless-4.9.0.80-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (49.6 MB)
Using cached shapely-2.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB)
Installing collected packages: Shapely, opencv-python-headless, imgaug, albumentations
Successfully installed Shapely-2.0.2 albumentations-0.5.2 imgaug-0.4.0 opencv-python-headless-4.9.0.80
Collecting hydra-core==1.1.0
  Using cached hydra_core-1.1.0-py3-none-any.whl (144 kB)
Collecting omegaconf==2.1.* (from hydra-core==1.1.0)
  Using cached omegaconf-2.1.2-py3-none-any.whl (74 kB)
Collecting antlr4-python3-runtime==4.8 (from hydra-core==1.1.0)
  Using cached antlr4_python3_runtime-4.8-py3-none-any.whl
Installing collected packages: antlr4-python3-runtime, omegaconf, hydra-core
Successfully installed antlr4-python3-runt

Collecting attrs>=17.3.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=0.8.1->pytorch-lightning==1.2.9)
  Using cached attrs-23.2.0-py3-none-any.whl.metadata (9.5 kB)
Collecting multidict<7.0,>=4.5 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=0.8.1->pytorch-lightning==1.2.9)
  Using cached multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (121 kB)
Collecting yarl<2.0,>=1.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=0.8.1->pytorch-lightning==1.2.9)
  Using cached yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (31 kB)
Collecting frozenlist>=1.1.1 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=0.8.1->pytorch-lightning==1.2.9)
  Using cached frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting aiosignal>=1.1.2 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=0.8.1->pytorch-lightning==1.2.9)
  Using cached aiosignal-1.3.1-py3-none-any.whl (7.

Using cached torchvision-0.16.2-cp38-cp38-manylinux1_x86_64.whl (6.8 MB)
Installing collected packages: torchvision
Successfully installed torchvision-0.16.2
Collecting conda-pack
  Using cached conda_pack-0.7.1-py2.py3-none-any.whl.metadata (2.6 kB)
Using cached conda_pack-0.7.1-py2.py3-none-any.whl (32 kB)
Installing collected packages: conda-pack
Successfully installed conda-pack-0.7.1
Collecting packages...
Packing environment at '/home/ec2-user/anaconda3/envs/lama_env' to 'lama_env.tar.gz'
[########################################] | 100% Completed |  3min 38.7s


### Extend SageMaker Managed Triton Container

When these models are hosted on SageMaker MME, the model files are loaded from Amazon S3 onto the instance when a model is invoked. Due to the large size of our models and their conda packs (SAM: 2.4 GB + conda pack: 2.52 GB; LaMa: 0.38 GB + conda pack: 3.35 GB), we pre-load these files into the container instead. This reduces model loading time and improves the user experience during a cold start.
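For a rough sense of scale, the artifact sizes quoted above add up as follows. This sketch only sums the approximate figures from the text and reports free space on the volume holding the current working directory; where Docker actually stores image layers on your instance is an assumption to check separately. Note that the tarballs copied into `/temp/` remain in their `COPY` layers even though a later `RUN rm` deletes them, so the built image is larger than the extracted files alone.

```python
import shutil

# Approximate artifact sizes in GB, copied from the paragraph above.
ARTIFACT_SIZES_GB = {
    "sam_checkpoint": 2.4,
    "sam_conda_pack": 2.52,
    "lama_checkpoint": 0.38,
    "lama_conda_pack": 3.35,
}

sam_total_gb = ARTIFACT_SIZES_GB["sam_checkpoint"] + ARTIFACT_SIZES_GB["sam_conda_pack"]
lama_total_gb = ARTIFACT_SIZES_GB["lama_checkpoint"] + ARTIFACT_SIZES_GB["lama_conda_pack"]
total_gb = sum(ARTIFACT_SIZES_GB.values())

# Free space on the volume holding the current directory (assumption: the build
# context lives here; Docker's own storage location may be a different volume).
free_gb = shutil.disk_usage(".").free / 1e9

print(f"SAM: {sam_total_gb:.2f} GB, LaMa: {lama_total_gb:.2f} GB, total: {total_gb:.2f} GB")
print(f"Free space on current volume: {free_gb:.1f} GB")
```

The first line prints `SAM: 4.92 GB, LaMa: 3.73 GB, total: 8.65 GB`, all of which ends up baked into the image rather than fetched from S3 at invocation time.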

In [36]:
# account mapping for SageMaker Triton Image
account_id_map = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}



region = boto3.Session().region_name
if region not in account_id_map:
    raise ValueError(f"Unsupported region: {region}")

base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
mme_triton_image_uri = (
    "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.12-py3".format(
        account_id=account_id_map[region], region=region, base=base
    )
)
triton_account_id = account_id_map[region]
mme_triton_image_uri

'301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:22.12-py3'

Preview the Dockerfile:

In [37]:
!cat docker/Dockerfile

ARG BASE_IMAGE

FROM $BASE_IMAGE

#Install any additional libraries
RUN echo "Adding conda package to Docker image"
RUN mkdir -p /home/condpackenv/
RUN mkdir -p /home/models/

# Copy conda env
COPY sam_env.tar.gz /home/condpackenv/sam_env.tar.gz
COPY lama_env.tar.gz /home/condpackenv/lama_env.tar.gz

COPY sam_vit_h_4b8939.pth.tar.gz /temp/
COPY big-lama.tar.gz /temp/

# Install tar
RUN apt-get update && apt-get install -y tar
# RUN apt-get update && apt-get install ffmpeg libsm6 libxext6  -y

# Untar the file
RUN tar -xzf /temp/sam_vit_h_4b8939.pth.tar.gz -C /home/models/
RUN tar -xzf /temp/big-lama.tar.gz -C /home/models/

RUN rm /temp/sam_vit_h_4b8939.pth.tar.gz
RUN rm /temp/big-lama.tar.gz

### Build and Push the New Image to ECR

In [38]:
# New container image name
new_image_name = 'sagemaker-tritonserver-sam-lama'

In [39]:
%%capture build_output
!cd docker && bash build_and_push.sh "$new_image_name" "latest" "$mme_triton_image_uri" "$region" "$account" "$triton_account_id"

In [40]:
if 'Error response from daemon' in str(build_output):
    print(build_output)
    raise SystemExit('\n\n!!There was an error with the container build!!')
else:
    extended_triton_image_uri = str(build_output).strip().split('\n')[-1]
    
print(f"New image URI from ECR: {extended_triton_image_uri}")

New image URI from ECR: 376678947624.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver-sam-lama:latest
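Taking the last line of the captured output works as long as `build_and_push.sh` prints the image URI last, which this section does not show, so treat it as an assumption. A slightly more defensive sketch scans the whole log for ECR image URIs and can keep only those for the new repository, so the base image URI from the `FROM` step is never picked by mistake. `extract_ecr_uri` and its regular expression are hypothetical and only cover private ECR URIs with an explicit tag.

```python
import re

# Private ECR image URI: <12-digit account>.dkr.ecr.<region>.amazonaws.com[.cn]/<repository>:<tag>
ECR_URI_PATTERN = re.compile(
    r"\b\d{12}\.dkr\.ecr\.[a-z0-9-]+\.amazonaws\.com(?:\.cn)?/[\w./-]+:[\w.-]+"
)


def extract_ecr_uri(build_log, repository=None):
    """Return the last ECR image URI in build_log (optionally for one repository), or None."""
    uris = [m.group(0) for m in ECR_URI_PATTERN.finditer(build_log)]
    if repository is not None:
        uris = [uri for uri in uris if uri.split("/", 1)[1].rsplit(":", 1)[0] == repository]
    return uris[-1] if uris else None


# Possible drop-in for the cell above (assumes build_output is the %%capture result):
# extended_triton_image_uri = extract_ecr_uri(str(build_output), repository=new_image_name)
```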


In [41]:
%store extended_triton_image_uri

Stored 'extended_triton_image_uri' (str)
