Skip to content

Commit

Permalink
Create configurable sagemaker_session fixture for all integ tests (#104)
Browse files Browse the repository at this point in the history
* Create configurable sagemaker_session fixture for all integ tests

* Update changelog
  • Loading branch information
laurenyu committed Mar 21, 2018
1 parent 6e0047b commit 2af9d45
Show file tree
Hide file tree
Showing 13 changed files with 93 additions and 103 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
CHANGELOG
=========

1.1.dev3
========

* feature: Tests: create configurable ``sagemaker_session`` pytest fixture for all integration tests

1.1.2
=======
=====

* bug-fix: AmazonEstimators: do not call create bucket if data location is provided

1.1.1
========
=====

* feature: Estimators: add ``requirements.txt`` support for TensorFlow

Expand Down
42 changes: 42 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,50 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import json

import boto3
import pytest

from sagemaker import Session

DEFAULT_REGION = 'us-west-2'


def pytest_addoption(parser):
parser.addoption('--sagemaker-client-config', action='store', default=None)
parser.addoption('--sagemaker-runtime-config', action='store', default=None)
parser.addoption('--boto-config', action='store', default=None)


@pytest.fixture(scope='session')
def sagemaker_client_config(request):
config = request.config.getoption('--sagemaker-client-config')
return json.loads(config) if config else None


@pytest.fixture(scope='session')
def sagemaker_runtime_config(request):
config = request.config.getoption('--sagemaker-runtime-config')
return json.loads(config) if config else None


@pytest.fixture(scope='session')
def boto_config(request):
config = request.config.getoption('--boto-config')
return json.loads(config) if config else None


@pytest.fixture(scope='session')
def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_config):
sagemaker_client = boto3.client('sagemaker', **sagemaker_client_config) if sagemaker_client_config else None
runtime_client = boto3.client('sagemaker-runtime', **sagemaker_runtime_config) if sagemaker_runtime_config else None
boto_session = boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)

return Session(boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_runtime_client=runtime_client)


@pytest.fixture(scope='module', params=["1.4", "1.4.1", "1.5", "1.5.0"])
def tf_version(request):
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# language governing permissions and limitations under the License.
import logging
import os

DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
REGION = 'us-west-2'

logging.getLogger('boto3').setLevel(logging.INFO)
logging.getLogger('botocore').setLevel(logging.INFO)
22 changes: 13 additions & 9 deletions tests/integ/test_byo_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,36 @@
import gzip
import io
import json
import numpy as np
import os
import pickle
import sys

import boto3
import numpy as np
import pytest

import sagemaker
from sagemaker.estimator import Estimator
from sagemaker.amazon.amazon_estimator import registry
from sagemaker.amazon.common import write_numpy_to_dense_tensor
from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base
from tests.integ import DATA_DIR, REGION
from tests.integ import DATA_DIR
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name


@pytest.fixture(scope='module')
def region(sagemaker_session):
return sagemaker_session.boto_session.region_name


def fm_serializer(data):
js = {'instances': []}
for row in data:
js['instances'].append({'features': row.tolist()})
return json.dumps(js)


def test_byo_estimator():
def test_byo_estimator(sagemaker_session, region):
"""Use Factorization Machines algorithm as an example here.
First we need to prepare data for training. We take standard data set, convert it to the
Expand All @@ -47,10 +53,9 @@ def test_byo_estimator():
Default predictor is updated with json serializer and deserializer.
"""
image_name = registry(REGION) + "/factorization-machines:1"
image_name = registry(region) + "/factorization-machines:1"

with timeout(minutes=15):
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

Expand Down Expand Up @@ -100,13 +105,12 @@ def test_byo_estimator():
assert prediction['score'] is not None


def test_async_byo_estimator():
image_name = registry(REGION) + "/factorization-machines:1"
def test_async_byo_estimator(sagemaker_session, region):
image_name = registry(region) + "/factorization-machines:1"
endpoint_name = name_from_base('byo')
training_job_name = ""

with timeout(minutes=5):
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

Expand Down
16 changes: 4 additions & 12 deletions tests/integ/test_factorization_machines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,19 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import gzip
import os
import pickle
import sys
import time

import boto3
import os

import sagemaker
from sagemaker import FactorizationMachines, FactorizationMachinesModel
from sagemaker.utils import name_from_base
from tests.integ import DATA_DIR, REGION
from tests.integ import DATA_DIR
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name


def test_factorization_machines():

def test_factorization_machines(sagemaker_session):
with timeout(minutes=15):
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

Expand Down Expand Up @@ -56,14 +51,11 @@ def test_factorization_machines():
assert record.label["score"] is not None


def test_async_factorization_machines():

def test_async_factorization_machines(sagemaker_session):
training_job_name = ""
endpoint_name = name_from_base('factorizationMachines')
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))

with timeout(minutes=5):

data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

Expand Down
15 changes: 4 additions & 11 deletions tests/integ/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,19 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import gzip
import os
import pickle
import sys

import boto3
import os
import time

import sagemaker
from sagemaker import KMeans, KMeansModel
from sagemaker.utils import name_from_base
from tests.integ import DATA_DIR, REGION
from tests.integ import DATA_DIR
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name


def test_kmeans():

def test_kmeans(sagemaker_session):
with timeout(minutes=15):
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

Expand Down Expand Up @@ -63,13 +58,11 @@ def test_kmeans():
assert record.label["distance_to_cluster"] is not None


def test_async_kmeans():

def test_async_kmeans(sagemaker_session):
training_job_name = ""
endpoint_name = name_from_base('kmeans')

with timeout(minutes=5):
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

Expand Down
12 changes: 4 additions & 8 deletions tests/integ/test_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,20 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import boto3
import numpy as np
import os

import sagemaker
import numpy as np

from sagemaker import LDA, LDAModel
from sagemaker.amazon.common import read_records
from sagemaker.utils import name_from_base

from tests.integ import DATA_DIR, REGION
from tests.integ import DATA_DIR
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
from tests.integ.record_set import prepare_record_set_from_local_files


def test_lda():

def test_lda(sagemaker_session):
with timeout(minutes=15):
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
data_path = os.path.join(DATA_DIR, 'lda')
data_filename = 'nips-train_1.pbr'

Expand Down
15 changes: 4 additions & 11 deletions tests/integ/test_linear_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,17 @@
import pickle
import sys
import time
import pytest # noqa
import boto3

import numpy as np

import sagemaker
from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel
from sagemaker.utils import name_from_base, sagemaker_timestamp

from tests.integ import DATA_DIR, REGION
from tests.integ import DATA_DIR
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name


def test_linear_learner():
def test_linear_learner(sagemaker_session):
with timeout(minutes=15):
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

Expand Down Expand Up @@ -87,14 +83,11 @@ def test_linear_learner():
assert record.label["score"] is not None


def test_async_linear_learner():

def test_async_linear_learner(sagemaker_session):
training_job_name = ""
endpoint_name = 'test-linear-learner-async-{}'.format(sagemaker_timestamp())
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))

with timeout(minutes=5):

data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

Expand Down
11 changes: 2 additions & 9 deletions tests/integ/test_mxnet_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,16 @@
import os
import time

import boto3
import numpy
import pytest
from sagemaker import Session

from sagemaker.mxnet.estimator import MXNet
from sagemaker.mxnet.model import MXNetModel
from sagemaker.utils import sagemaker_timestamp

from tests.integ import DATA_DIR, REGION
from tests.integ import DATA_DIR
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name


@pytest.fixture(scope='module')
def sagemaker_session():
return Session(boto_session=boto3.Session(region_name=REGION))


@pytest.fixture(scope='module')
def mxnet_training_job(sagemaker_session, mxnet_full_version):
with timeout(minutes=15):
Expand Down
12 changes: 4 additions & 8 deletions tests/integ/test_ntm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,20 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import boto3
import numpy as np
import os

import sagemaker
import numpy as np

from sagemaker import NTM, NTMModel
from sagemaker.amazon.common import read_records
from sagemaker.utils import name_from_base

from tests.integ import DATA_DIR, REGION
from tests.integ import DATA_DIR
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
from tests.integ.record_set import prepare_record_set_from_local_files


def test_ntm():

def test_ntm(sagemaker_session):
with timeout(minutes=15):
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
data_path = os.path.join(DATA_DIR, 'ntm')
data_filename = 'nips-train_1.pbr'

Expand Down

0 comments on commit 2af9d45

Please sign in to comment.