In [None]:
import numpy as np
import boto3
import sagemaker
from sagemaker.pytorch import PyTorch

# Resolve the AWS account ID and region; later cells interpolate both into
# ECR registry addresses and AWS CLI commands.
caller_identity = boto3.client('sts').get_caller_identity()
account_id = caller_identity.get('Account')
region = boto3.Session().region_name

# SageMaker handles: the execution role passed to the estimator, the SDK
# session, and that session's default bucket.
role = sagemaker.get_execution_role()
sm_session = sagemaker.Session()
bucket_name = sm_session.default_bucket()

In [None]:
# choose a repo name you like
ecr_repository = 'mmseg-train'

# 登录ECR服务
!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {account_id}.dkr.ecr.{region}.amazonaws.com.cn
# !aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin 727897471807.dkr.ecr.{region}.amazonaws.com.cn

In [None]:
!aws ecr create-repository --repository-name $ecr_repository

In [None]:
# 构建训练镜像并推送到ECR, China Region.
torch_version = '1.10.0'
tag = ':pt-' + torch_version
repository_uri = '{}.dkr.ecr.{}.amazonaws.com.cn/{}'.format(account_id, region, ecr_repository + tag)
print('repository_uri: ', repository_uri)

!docker build -t "$ecr_repository$tag" . -f Dockerfile-ubuntu.gpu --build-arg REGION_NAME=$region --build-arg TORCH_VERSION=$torch_version
!docker tag {ecr_repository + tag} $repository_uri

# !docker push $repository_uri

In [None]:
# Estimator that runs tools/sm-train.sh (shipped from the current directory)
# inside the custom training image, using a single 'local_gpu' instance.
pytorch_estimator = PyTorch(
    entry_point='tools/sm-train.sh',
    source_dir='.',
    instance_type='local_gpu',
    instance_count=1,
    # framework_version='1.10.0',
    # py_version='py38',
    image_uri=repository_uri,
    role=role,
    environment={
        # Environment variable values must be strings: the SDK documents
        # `environment` as dict[str, str], so the class count is passed as '2'.
        'MM_NUM_CLASSES': '2',
        'MM_LOG_LEVEL': 'INFO',
    },
    hyperparameters={
        'config': 'configs/fcn/fcn_hr48_320x480_80k_0.01_new_label.py',
    },
)

# The config file automatically reads the SM_CHANNEL_DATA_ROOT environment
# variable, which corresponds to the 'data_root' input channel passed to fit().
# Local-directory alternative:
# pytorch_estimator.fit({'data_root': 'file:///home/ec2-user/SageMaker/data/train'})
# NOTE(review): the S3 URI below looks like a placeholder -- point it at the
# real dataset location before running.
pytorch_estimator.fit({'data_root': 's3://bucketname/path/to/data'})