DataWig on SageMaker


This package shows how to package the DataWig imputation algorithm for use with Amazon SageMaker. The code and structure of the package are heavily influenced by examples from the Amazon-SageMaker-Examples repository.

The following stack is used:

  1. nginx is a light-weight layer that handles the incoming HTTP requests and manages the I/O in and out of the container efficiently.
  2. gunicorn is a WSGI pre-forking worker server that runs multiple copies of the application and load balances between them.
  3. flask is a simple web framework. It lets the application respond to calls on the /ping and /invocations endpoints without much extra code. A minimal sketch of such an app follows this list.
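
For orientation, here is a minimal sketch of the kind of flask app this stack implies. The real handlers live in imputer.py and differ in detail; the impute helper below is a placeholder:

 import flask

 app = flask.Flask(__name__)

 def impute(csv_payload):
     # Placeholder for the DataWig-backed logic that lives in imputer.py.
     return csv_payload

 @app.route("/ping", methods=["GET"])
 def ping():
     # Health check: SageMaker calls this to confirm the container is up.
     return flask.Response(response="\n", status=200, mimetype="application/json")

 @app.route("/invocations", methods=["POST"])
 def invocations():
     # Inference: decode the request body, impute, and return the result.
     data = flask.request.data.decode("utf-8")
     return flask.Response(response=impute(data), status=200, mimetype="text/csv")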

The Structure of the Code

  • Dockerfile: Describes how the container image is built and what it contains.

  • build_and_push.sh: The script to build the Docker image (using the Dockerfile above) and push it to the Amazon EC2 Container Registry (ECR) so that it can be deployed to SageMaker. The image name is the only argument to this script. The script generates the full repository name from your AWS account and the configured AWS region, and creates the ECR repository if it doesn't exist. An example invocation follows this list.

    • As part of this script you can set the desired DataWig version to be installed inside the Docker image. Check our PyPI repository for the latest version.
  • imputation: The directory that contains the application to run in the container.

  • test: The directory that contains scripts and a setup for running simple training and inference jobs locally.

  • sagemaker: The directory that contains an example of client code that sets up an endpoint with an imputation model in SageMaker.
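
To build and push the image, invoke the script with the image name (the name shown here is just an example):

 ./build_and_push.sh datawig-sagemaker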

The application that runs inside the container

This container is set up so that the argument is treated as the command that the container executes. When training, it runs the included train program and, when serving, it runs the serve program.

  • train: The main program for training the model.
  • serve: The wrapper that starts the inference server.
  • wsgi.py: The start-up module for the individual server workers (sketched after this list).
  • imputer.py: The algorithm-specific imputation server.
  • nginx.conf: The configuration for the nginx master server that manages the multiple workers.
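
In this pattern, wsgi.py is typically nothing more than a thin module that hands the flask app to the gunicorn workers. A minimal sketch, assuming the app object is defined in imputer.py as listed above:

 # wsgi.py: each gunicorn worker imports this module and serves `app`.
 import imputer

 app = imputer.app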

Setup for local testing

The subdirectory 'test' contains scripts and sample data for testing the built container image on the local machine.

  • common.sh: Stores shared variables across test scripts.
  • train_local.sh: Instantiate the container configured for training.
  • serve_local.sh: Instantiate the container configured for serving.
  • impute.sh: Run predictions against a locally instantiated server.
  • sagemaker_fs: The directory that gets mounted into the container, with test data placed in all the locations that match the container schema.
  • test.csv: Sample data used by impute.sh for testing the server.

A typical local test run looks like this:
 ./test/train_local.sh $IMAGE_NAME 
 ./test/serve_local.sh $IMAGE_NAME 
 ./test/impute.sh ./test/test.csv
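
Under the hood, impute.sh posts the sample CSV to the /invocations endpoint of the locally running container (SageMaker containers listen on port 8080). A rough Python equivalent, assuming a text/csv content type:

 import requests

 # POST the sample data to the locally served container.
 with open("test/test.csv") as f:
     payload = f.read()

 response = requests.post(
     "http://localhost:8080/invocations",
     data=payload,
     headers={"Content-Type": "text/csv"},
 )
 print(response.status_code, response.text)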

The training and test data is a subsample of the IMDb data that was introduced by Maas et al. in Learning Word Vectors for Sentiment Analysis.

The directory tree mounted into the container

The tree under sagemaker_fs is mounted into the container and mimics the directory structure that SageMaker would create for the running container during training or hosting.

  • input/config/hyperparameters.json: The hyperparameters for the training job.
  • input/data/training/train.csv: The training data.
  • model: The directory where the algorithm writes the model file.
  • output: The directory where the algorithm can write its success or failure file.
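
To make the layout concrete, here is a hedged sketch of how a train program reads these paths inside the container. The /opt/ml prefix is the SageMaker convention; everything else is schematic:

 import json
 import os

 import pandas as pd

 PREFIX = "/opt/ml"
 PARAM_PATH = os.path.join(PREFIX, "input/config/hyperparameters.json")
 TRAIN_PATH = os.path.join(PREFIX, "input/data/training/train.csv")
 MODEL_DIR = os.path.join(PREFIX, "model")
 OUTPUT_DIR = os.path.join(PREFIX, "output")

 try:
     with open(PARAM_PATH) as f:
         hyperparameters = json.load(f)  # SageMaker passes all values as strings
     train_df = pd.read_csv(TRAIN_PATH)
     # ... fit a DataWig imputer on train_df and save it under MODEL_DIR ...
 except Exception as exc:
     # A failure file under output/ tells SageMaker why the job failed.
     with open(os.path.join(OUTPUT_DIR, "failure"), "w") as f:
         f.write(str(exc))
     raise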

Client code example

  • client.py: A code example that trains an imputation model and hosts it in SageMaker. The ALGORITHM_NAME, S3_BUCKET and ROLE parameters must be updated before running the script.
  • requirements.txt: Dependencies required to run the client code.
 pip3 install -r sagemaker/requirements.txt
 python3 sagemaker/client.py
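
For reference, a hedged sketch of what such client code looks like with the sagemaker Python SDK; the image name, bucket, role and instance types below are placeholders, and the actual client.py may differ:

 import sagemaker
 from sagemaker.estimator import Estimator

 ALGORITHM_NAME = "datawig-sagemaker"  # placeholder: the ECR image name
 S3_BUCKET = "my-bucket"               # placeholder
 ROLE = "arn:aws:iam::111122223333:role/SageMakerRole"  # placeholder

 session = sagemaker.Session()
 account = session.boto_session.client("sts").get_caller_identity()["Account"]
 region = session.boto_session.region_name
 image_uri = f"{account}.dkr.ecr.{region}.amazonaws.com/{ALGORITHM_NAME}:latest"

 estimator = Estimator(
     image_uri=image_uri,
     role=ROLE,
     instance_count=1,
     instance_type="ml.m5.xlarge",
     output_path=f"s3://{S3_BUCKET}/output",
     sagemaker_session=session,
 )

 # Train on data uploaded to S3, then stand up a real-time endpoint.
 estimator.fit({"training": f"s3://{S3_BUCKET}/train"})
 predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")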

Environment variables

When you create an imputation server, you can control some of Gunicorn's options via environment variables. These can be supplied as part of the CreateModel API call, as in the example after the table.

Parameter                Environment Variable              Default Value
---------                --------------------              -------------
number of workers        MODEL_SERVER_WORKERS              the number of CPU cores
timeout                  MODEL_SERVER_TIMEOUT              60 seconds
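
For example, with boto3 (the model name, image URI, artifact path and role below are placeholders):

 import boto3

 sm = boto3.client("sagemaker")
 sm.create_model(
     ModelName="datawig-imputer",  # placeholder
     ExecutionRoleArn="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
     PrimaryContainer={
         "Image": "111122223333.dkr.ecr.us-east-1.amazonaws.com/datawig-sagemaker:latest",
         "ModelDataUrl": "s3://my-bucket/output/model.tar.gz",
         "Environment": {
             "MODEL_SERVER_WORKERS": "2",    # run two gunicorn workers
             "MODEL_SERVER_TIMEOUT": "120",  # raise the timeout to 120 seconds
         },
     },
 )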

License

This library is licensed under the Apache 2.0 License.