In [None]:
import sagemaker
from sagemaker import get_execution_role, image_uris

import json
import shutil
import warnings

import prepare_model
import create_endpoint
import test_endpoint

In [None]:
# Path to the JSON file holding the deployment settings. It gets its own name
# so that `model_config` only ever refers to the parsed settings, never to a
# path string.
model_config_path = 'code_repo/model_config.json'

# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's default locale encoding.
with open(model_config_path, 'r', encoding='utf-8') as config_file:
    model_config = json.load(config_file)

In [None]:
# Settings read from the loaded configuration.
bucket_name = model_config['bucket_name']
weight_file_key = model_config['weight_file_key']
model_name = model_config['model_name']
model_version = model_config['model_version']

# Local working paths and the S3 key of the packaged archive.
local_file_name = './weights_file.npy'
model_save_path = f'./models/{model_version}'
model_tar_gz_key = f'{model_name}.tar.gz'
# Path handed to the packaging helper; it lives under ./<model_name>, which is
# the directory removed at the end of this cell.
staging_dir = f'./{model_name}/{model_version}'

# Step 1: download the weights file from S3 to the local path.
prepare_model.download_weights_from_s3(
    bucket_name,
    weight_file_key,
    local_file_name,
)

# Step 2: set up the model from the downloaded weights and save it locally.
prepare_model.setup_and_save_model(
    local_file_name,
    model_save_path,
)

# Step 3: compress the saved model into a tar.gz and upload it to S3.
prepare_model.compress_and_upload_model_to_s3(
    model_save_path,
    staging_dir,
    bucket_name,
    model_tar_gz_key,
)

# Step 4: once the tar.gz has been created, delete the <model_name> directory.
shutil.rmtree(model_name)

In [None]:
# Session, region and IAM execution role used by the calls below.
sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_region_name
role = get_execution_role()

# Serving settings, defined once. The instance type is needed both for the
# container image lookup and for the deployment itself; keeping a single
# variable guarantees the two can never drift apart.
framework = 'tensorflow'
framework_version = '2.13.0'
instance_type = 'ml.m5.large'
initial_instance_count = 1

# Look up the container image URI for the chosen framework and instance type.
container = create_endpoint.retrieve_container_image(
    region=region,
    framework=framework,
    version=framework_version,
    instance_type=instance_type,
)

# Create the model object from the container image and the archive in S3.
model = create_endpoint.create_model(
    container=container,
    s3_bucket=bucket_name,
    model_s3_key=model_tar_gz_key,
    role=role,
)

# Generate an endpoint name and deploy the model behind it.
endpoint_name = create_endpoint.create_endpoint_name()
predictor = create_endpoint.deploy_model(
    model=model,
    instance_type=instance_type,
    initial_instance_count=initial_instance_count,
    endpoint_name=endpoint_name,
)

In [None]:
# Test the deployed endpoint through the project's test_inference helper.
# NOTE(review): num_samples=5 is presumably the number of test inputs sent to
# the endpoint — confirm in test_endpoint.
test_endpoint.test_inference(endpoint_name=endpoint_name, num_samples=5)

In [None]:
# Optional cleanup: run this cell when the endpoint should be deleted.
# Because this is the last cell, a "Run All" deletes the endpoint right after
# the test cell above.
# NOTE(review): this deletes the endpoint by name; check whether the endpoint
# configuration and model resource need separate cleanup — TODO confirm.
sagemaker_session.delete_endpoint(endpoint_name)