diff --git a/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/ImageURIProvider.scala b/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/ImageURIProvider.scala index bed0f68..f8e3def 100644 --- a/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/ImageURIProvider.scala +++ b/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/ImageURIProvider.scala @@ -19,13 +19,24 @@ import com.amazonaws.regions.Regions private[algorithms] object SageMakerImageURIProvider { + def isChinaRegion(region: String): Boolean = { + val chinaRegions = Set( + Regions.CN_NORTH_1.getName, + Regions.CN_NORTHWEST_1.getName + ) + chinaRegions.contains(region) + } + def getImage(region: String, regionAccountMap: Map[String, String], algorithmName: String, algorithmTag: String): String = { val account = regionAccountMap.get(region) account match { case None => throw new RuntimeException(s"The region $region is not supported." + s"Supported Regions: ${regionAccountMap.keys.mkString(", ")}") - case _ => s"${account.get}.dkr.ecr.${region}.amazonaws.com/${algorithmName}:${algorithmTag}" + case _ if isChinaRegion(region) => + s"${account.get}.dkr.ecr.${region}.amazonaws.com.cn/${algorithmName}:${algorithmTag}" + case _ => + s"${account.get}.dkr.ecr.${region}.amazonaws.com/${algorithmName}:${algorithmTag}" } } } @@ -52,7 +63,9 @@ private[algorithms] object SagerMakerRegionAccountMaps { Regions.EU_NORTH_1.getName -> "669576153137", Regions.EU_WEST_3.getName -> "749696950732", Regions.EU_WEST_3.getName -> "749696950732", - Regions.ME_SOUTH_1.getName -> "249704162688" + Regions.ME_SOUTH_1.getName -> "249704162688", + Regions.CN_NORTH_1.getName -> "390948362332", + Regions.CN_NORTHWEST_1.getName -> "387376663083" ) // For LDA @@ -94,7 +107,9 @@ private[algorithms] object SagerMakerRegionAccountMaps { Regions.EU_NORTH_1.getName -> "669576153137", Regions.EU_WEST_3.getName -> "749696950732", Regions.EU_WEST_3.getName -> "749696950732", - Regions.ME_SOUTH_1.getName -> "249704162688" + Regions.ME_SOUTH_1.getName -> "249704162688", + Regions.CN_NORTH_1.getName -> "390948362332", + Regions.CN_NORTHWEST_1.getName -> "387376663083" ) } diff --git a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/FactorizationMachinesSageMakerEstimatorTests.scala b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/FactorizationMachinesSageMakerEstimatorTests.scala index 0a3fab9..0959bdd 100644 --- a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/FactorizationMachinesSageMakerEstimatorTests.scala +++ b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/FactorizationMachinesSageMakerEstimatorTests.scala @@ -151,6 +151,16 @@ class FactorizationMachinesSageMakerEstimatorTests extends FlatSpec with Mockito createFactorizationMachinesBinaryClassifier(region = Regions.ME_SOUTH_1.getName) assert(estimatorMESouth1.trainingImage == "249704162688.dkr.ecr.me-south-1.amazonaws.com/factorization-machines:1") + + val estimatorCNNorth1 = + createFactorizationMachinesBinaryClassifier(region = Regions.CN_NORTH_1.getName) + assert(estimatorCNNorth1.trainingImage == + "390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/factorization-machines:1") + + val estimatorCNNorthWest1 = + createFactorizationMachinesBinaryClassifier(region = Regions.CN_NORTHWEST_1.getName) + assert(estimatorCNNorthWest1.trainingImage == + "387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/factorization-machines:1") } it should "use the correct defaults for regressor" in { @@ -253,6 +263,16 @@ class FactorizationMachinesSageMakerEstimatorTests extends FlatSpec with Mockito createFactorizationMachinesRegressor(region = Regions.ME_SOUTH_1.getName) assert(estimatorMESouth1.trainingImage == "249704162688.dkr.ecr.me-south-1.amazonaws.com/factorization-machines:1") + + val estimatorCNNorth1 = + createFactorizationMachinesRegressor(region = Regions.CN_NORTH_1.getName) + assert(estimatorCNNorth1.trainingImage == + "390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/factorization-machines:1") + + val estimatorCNNorthWest1 = + createFactorizationMachinesRegressor(region = Regions.CN_NORTHWEST_1.getName) + assert(estimatorCNNorthWest1.trainingImage == + "387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/factorization-machines:1") } it should "setFeatureDim" in { diff --git a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/KMeansSageMakerEstimatorTests.scala b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/KMeansSageMakerEstimatorTests.scala index 42a1891..edc5272 100644 --- a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/KMeansSageMakerEstimatorTests.scala +++ b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/KMeansSageMakerEstimatorTests.scala @@ -128,6 +128,14 @@ class KMeansSageMakerEstimatorTests extends FlatSpec with Matchers with MockitoS val estimatorMESouth1 = createKMeansEstimator(region = Regions.ME_SOUTH_1.getName) assert(estimatorMESouth1.trainingImage == "249704162688.dkr.ecr.me-south-1.amazonaws.com/kmeans:1") + + val estimatorCNNorth1 = createKMeansEstimator(region = Regions.CN_NORTH_1.getName) + assert(estimatorCNNorth1.trainingImage == + "390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/kmeans:1") + + val estimatorCNNorthWest1 = createKMeansEstimator(region = Regions.CN_NORTHWEST_1.getName) + assert(estimatorCNNorthWest1.trainingImage == + "387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/kmeans:1") } it should "setK" in { diff --git a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/LinearLearnerSageMakerEstimatorTests.scala b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/LinearLearnerSageMakerEstimatorTests.scala index 3883537..5c7328b 100644 --- a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/LinearLearnerSageMakerEstimatorTests.scala +++ b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/LinearLearnerSageMakerEstimatorTests.scala @@ -156,6 +156,16 @@ class LinearLearnerSageMakerEstimatorTests extends FlatSpec with MockitoSugar { createLinearLearnerBinaryClassifier(region = Regions.ME_SOUTH_1.getName) assert(estimatorMESouth1.trainingImage == "249704162688.dkr.ecr.me-south-1.amazonaws.com/linear-learner:1") + + val estimatorCNNorth1 = + createLinearLearnerBinaryClassifier(region = Regions.CN_NORTH_1.getName) + assert(estimatorCNNorth1.trainingImage == + "390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/linear-learner:1") + + val estimatorCNNorthWest1 = + createLinearLearnerBinaryClassifier(region = Regions.CN_NORTHWEST_1.getName) + assert(estimatorCNNorthWest1.trainingImage == + "387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/linear-learner:1") } it should "use the correct defaults for multiclass classifier" in { @@ -261,6 +271,16 @@ class LinearLearnerSageMakerEstimatorTests extends FlatSpec with MockitoSugar { createLinearLearnerMultiClassClassifier(region = Regions.ME_SOUTH_1.getName) assert(estimatorMESouth1.trainingImage == "249704162688.dkr.ecr.me-south-1.amazonaws.com/linear-learner:1") + + val estimatorCNNorth1 = + createLinearLearnerMultiClassClassifier(region = Regions.CN_NORTH_1.getName) + assert(estimatorCNNorth1.trainingImage == + "390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/linear-learner:1") + + val estimatorCNNorthWest1 = + createLinearLearnerMultiClassClassifier(region = Regions.CN_NORTHWEST_1.getName) + assert(estimatorCNNorthWest1.trainingImage == + "387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/linear-learner:1") } it should "use the correct defaults for regressor" in { @@ -362,6 +382,16 @@ class LinearLearnerSageMakerEstimatorTests extends FlatSpec with MockitoSugar { createLinearLearnerRegressor(region = Regions.ME_SOUTH_1.getName) assert(estimatorMESouth1.trainingImage == "249704162688.dkr.ecr.me-south-1.amazonaws.com/linear-learner:1") + + val estimatorCNNorth1 = + createLinearLearnerRegressor(region = Regions.CN_NORTH_1.getName) + assert(estimatorCNNorth1.trainingImage == + "390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/linear-learner:1") + + val estimatorCNNorthWest1 = + createLinearLearnerRegressor(region = Regions.CN_NORTHWEST_1.getName) + assert(estimatorCNNorthWest1.trainingImage == + "387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/linear-learner:1") } it should "setFeatureDim" in { diff --git a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/PCASageMakerEstimatorTests.scala b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/PCASageMakerEstimatorTests.scala index fb5ce60..619453c 100644 --- a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/PCASageMakerEstimatorTests.scala +++ b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/PCASageMakerEstimatorTests.scala @@ -108,6 +108,14 @@ class PCASageMakerEstimatorTests extends FlatSpec with MockitoSugar { val estimatorMESouth1 = createPCAEstimator(region = Regions.ME_SOUTH_1.getName) assert(estimatorMESouth1.trainingImage == "249704162688.dkr.ecr.me-south-1.amazonaws.com/pca:1") + + val estimatorCNNorth1 = createPCAEstimator(region = Regions.CN_NORTH_1.getName) + assert(estimatorCNNorth1.trainingImage == + "390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/pca:1") + + val estimatorCNNorthWest1 = createPCAEstimator(region = Regions.CN_NORTHWEST_1.getName) + assert(estimatorCNNorthWest1.trainingImage == + "387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pca:1") } it should "use the correct defaults" in { diff --git a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimatorTests.scala b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimatorTests.scala index 472fbb0..c510814 100644 --- a/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimatorTests.scala +++ b/sagemaker-spark-sdk/src/test/scala/com/amazonaws/services/sagemaker/sparksdk/algorithms/XGBoostSageMakerEstimatorTests.scala @@ -127,6 +127,14 @@ class XGBoostSageMakerEstimatorTests extends FlatSpec with Matchers with Mockito val estimatorMESouth1 = createXGBoostEstimator(region = Regions.ME_SOUTH_1.getName) assert(estimatorMESouth1.trainingImage == "249704162688.dkr.ecr.me-south-1.amazonaws.com/xgboost:1") + + val estimatorCNNorth1 = createXGBoostEstimator(region = Regions.CN_NORTH_1.getName) + assert(estimatorCNNorth1.trainingImage == + "390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/xgboost:1") + + val estimatorCNNorthWest1 = createXGBoostEstimator(region = Regions.CN_NORTHWEST_1.getName) + assert(estimatorCNNorthWest1.trainingImage == + "387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/xgboost:1") } it should "setBooster" in {