Skip to content

Commit

Permalink
Tabular Distributed Training (#72)
Browse files Browse the repository at this point in the history
Co-authored-by: Weisu Yin <weisuyin96@gmail.com>
  • Loading branch information
Weisu Yin and yinweisu committed Apr 24, 2023
1 parent bffdf96 commit fb9e7eb
Show file tree
Hide file tree
Showing 44 changed files with 1,326 additions and 195 deletions.
3 changes: 3 additions & 0 deletions .github/workflow_scripts/env_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ function install_latest_tabular {
function install_latest_multimodal {
install_latest_tabular_and_multimodal_dependencies
python3 -m pip install -e autogluon/multimodal/
mim install mmcv-full --timeout 60
python3 -m pip install --upgrade "mmdet>=2.28, <3.0.0"
python3 -m pip install --upgrade "mmocr<1.0"
}

function install_tabular {
Expand Down
12 changes: 6 additions & 6 deletions .github/workflow_scripts/test_cluster.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Disable now because some weird issue appeared today (3/27/2023) on github to unblock release
# TODO: Investigate and re-enable after release
#!/bin/bash

# set -ex
AG_VERSION="${1:-source}"

# source $(dirname "$0")/env_setup.sh
set -ex

# install_cloud_test
source $(dirname "$0")/env_setup.sh

# python3 -m pytest --forked --junitxml=results.xml tests/unittests/cluster/
install_cloud_test

python3 -m pytest --forked --junitxml=results.xml tests/unittests/cluster/ --framework_version $AG_VERSION
6 changes: 6 additions & 0 deletions .github/workflows/continuous_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ jobs:
chmod +x ./.github/workflow_scripts/cloud_lint_check.sh && ./.github/workflow_scripts/cloud_lint_check.sh
test_general_cloud:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
needs: cloud_lint_check
Expand Down Expand Up @@ -95,6 +96,7 @@ jobs:
ag_version: '${{ matrix.AG_VERSION }}'
test_tabular_cloud:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
needs: cloud_lint_check
Expand Down Expand Up @@ -126,6 +128,7 @@ jobs:
ag_version: '${{ matrix.AG_VERSION }}'
test_text_cloud:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
needs: cloud_lint_check
Expand Down Expand Up @@ -157,6 +160,7 @@ jobs:
ag_version: '${{ matrix.AG_VERSION }}'
test_image_cloud:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
needs: cloud_lint_check
Expand Down Expand Up @@ -188,6 +192,7 @@ jobs:
ag_version: '${{ matrix.AG_VERSION }}'
test_multimodal_cloud:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
needs: cloud_lint_check
Expand Down Expand Up @@ -219,6 +224,7 @@ jobs:
ag_version: '${{ matrix.AG_VERSION }}'
test_timeseries_cloud:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
needs: cloud_lint_check
Expand Down
88 changes: 45 additions & 43 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,85 @@
import os
import sys

sys.path = ['.', '..'] + sys.path
sys.path = [".", ".."] + sys.path

project = 'AutoGluon-Cloud'
release = '0.2.1'
copyright = '2023, All authors. Licensed under Apache 2.0.'
author = 'AutoGluon contributors'
project = "AutoGluon-Cloud"
release = "0.2.1"
copyright = "2023, All authors. Licensed under Apache 2.0."
author = "AutoGluon contributors"

extensions = [
'myst_nb', # myst-nb.readthedocs.io
'sphinx_copybutton', # sphinx-copybutton.readthedocs.io
'sphinx_design', # github.com/executablebooks/sphinx-design
'sphinx_inline_tabs', # sphinx-inline-tabs.readthedocs.io
'sphinx_togglebutton', # sphinx-togglebutton.readthedocs.io
'sphinx.ext.autodoc', # www.sphinx-doc.org/en/master/usage/extensions/autodoc.html
'sphinx.ext.autosummary', # www.sphinx-doc.org/en/master/usage/extensions/autosummary.html
'sphinx.ext.napoleon', # www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
'sphinx.ext.viewcode', # www.sphinx-doc.org/en/master/usage/extensions/viewcode.html
'sphinxcontrib.googleanalytics', # github.com/sphinx-contrib/googleanalytics
]
"myst_nb", # myst-nb.readthedocs.io
"sphinx_copybutton", # sphinx-copybutton.readthedocs.io
"sphinx_design", # github.com/executablebooks/sphinx-design
"sphinx_inline_tabs", # sphinx-inline-tabs.readthedocs.io
"sphinx_togglebutton", # sphinx-togglebutton.readthedocs.io
"sphinx.ext.autodoc", # www.sphinx-doc.org/en/master/usage/extensions/autodoc.html
"sphinx.ext.autosummary", # www.sphinx-doc.org/en/master/usage/extensions/autosummary.html
"sphinx.ext.napoleon", # www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
"sphinx.ext.viewcode", # www.sphinx-doc.org/en/master/usage/extensions/viewcode.html
"sphinxcontrib.googleanalytics", # github.com/sphinx-contrib/googleanalytics
]

# See https://myst-parser.readthedocs.io/en/latest/syntax/optional.html
myst_enable_extensions = ['colon_fence', 'deflist', 'dollarmath', 'html_image', 'substitution']
myst_enable_extensions = ["colon_fence", "deflist", "dollarmath", "html_image", "substitution"]

autosummary_generate = True
numpydoc_show_class_members = False

googleanalytics_id = "UA-96378503-20"

nb_execution_mode = 'force'
nb_execution_mode = "force"
# nb_execution_raise_on_error=True
nb_execution_timeout = 3600
nb_merge_streams = True

nb_execution_excludepatterns = ['jupyter_execute']
nb_execution_excludepatterns = ["jupyter_execute"]

nb_dirs_to_exec = [os.path.join('tutorials', tag) for tag in tags if os.path.isdir(os.path.join('tutorials', tag))]
nb_dirs_to_exec = [os.path.join("tutorials", tag) for tag in tags if os.path.isdir(os.path.join("tutorials", tag))]

if len(nb_dirs_to_exec) > 0:
nb_dirs_to_exclude = [dirpath for dirpath, _, filenames in os.walk('tutorials')
if any(map(lambda x: x.endswith('.ipynb'), filenames))
and not dirpath.startswith(tuple(nb_dirs_to_exec))]
nb_dirs_to_exclude = [
dirpath
for dirpath, _, filenames in os.walk("tutorials")
if any(map(lambda x: x.endswith(".ipynb"), filenames)) and not dirpath.startswith(tuple(nb_dirs_to_exec))
]

for nb_dir in nb_dirs_to_exclude:
nb_execution_excludepatterns.append(os.path.join(nb_dir, '*.ipynb'))
nb_execution_excludepatterns.append(os.path.join(nb_dir, "*.ipynb"))

templates_path = ['_templates']
exclude_patterns = ['_build', '_templates', 'README.md', 'ReleaseInstructions.md', 'jupyter_execute']
master_doc = 'index'
templates_path = ["_templates"]
exclude_patterns = ["_build", "_templates", "README.md", "ReleaseInstructions.md", "jupyter_execute"]
master_doc = "index"
numfig = True
numfig_secnum_depth = 2
math_numfig = True
math_number_all = True

# suppress_warnings = ['misc.highlighting_failure']

html_theme = 'furo' # furo.readthedocs.io
html_theme = "furo" # furo.readthedocs.io
html_theme_options = {
'sidebar_hide_name': True,
'light_logo': 'autogluon.png',
'dark_logo': 'autogluon-w.png',
'globaltoc_collapse': False,
"sidebar_hide_name": True,
"light_logo": "autogluon.png",
"dark_logo": "autogluon-w.png",
"globaltoc_collapse": False,
}

html_sidebars = {
'**': [
'sidebar/brand.html',
'sidebar/search.html',
'sidebar/scroll-start.html',
'sidebar/navigation.html',
"**": [
"sidebar/brand.html",
"sidebar/search.html",
"sidebar/scroll-start.html",
"sidebar/navigation.html",
# 'sidebar/ethical-ads.html', # furo maintainer requests this is set if docs are hosted on readthedocs.io
'sidebar/scroll-end.html',
'sidebar/variant-selector.html'
"sidebar/scroll-end.html",
"sidebar/variant-selector.html",
]
}

html_favicon = '_static/favicon.ico'
html_favicon = "_static/favicon.ico"

html_static_path = ['_static']
html_css_files = ['custom.css']
html_js_files = ['custom.js']
html_static_path = ["_static"]
html_css_files = ["custom.css"]
html_js_files = ["custom.js"]
2 changes: 2 additions & 0 deletions src/autogluon/cloud/backend/backend_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .backend import Backend
from .multimodal_sagemaker_backend import MultiModalSagemakerBackend
from .ray_aws_backend import TabularRayAWSBackend
from .sagemaker_backend import SagemakerBackend
from .tabular_sagemaker_backend import TabularSagemakerBackend
from .timeseries_sagemaker_backend import TimeSeriesSagemakerBackend
Expand All @@ -11,6 +12,7 @@ class BackendFactory:
TabularSagemakerBackend,
MultiModalSagemakerBackend,
TimeSeriesSagemakerBackend,
TabularRayAWSBackend,
]
__name_to_backend = {cls.name: cls for cls in __supported_backend}

Expand Down
2 changes: 2 additions & 0 deletions src/autogluon/cloud/backend/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
MULTIMODL_SAGEMAKER = "multimodal_sagemaker"
TIMESERIES_SAGEMAKER = "timeseries_sagemaker"
RAY = "ray"
RAY_AWS = "ray_aws"
TABULAR_RAY_AWS = "tabular_ray_aws"
100 changes: 100 additions & 0 deletions src/autogluon/cloud/backend/ray_aws_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

import logging
import time
from typing import Dict, Optional

import boto3

from ..cluster.ray_aws_cluster_config_generator import RayAWSClusterConfigGenerator
from ..cluster.ray_aws_cluster_manager import RayAWSClusterManager
from ..cluster.ray_cluster_config_generator import RayClusterConfigGenerator
from ..cluster.ray_cluster_manager import RayClusterManager
from ..utils.ec2 import create_key_pair, delete_key_pair
from ..utils.iam import (
add_role_to_instance_profile,
attach_iam_policy,
create_iam_policy,
create_iam_role,
create_instance_profile,
delete_iam_policy,
get_policy,
replace_iam_policy_place_holder,
replace_trust_relationship_place_holder,
)
from ..utils.ray_aws_iam import (
ECR_READ_ONLY,
RAY_AWS_CLOUD_POLICY,
RAY_AWS_POLICY_NAME,
RAY_AWS_ROLE_NAME,
RAY_AWS_TRUST_RELATIONSHIP,
RAY_INSTANCE_PROFILE_NAME,
)
from .constant import RAY_AWS, TABULAR_RAY_AWS
from .ray_backend import RayBackend

logger = logging.getLogger(__name__)


class RayAWSBackend(RayBackend):
name = RAY_AWS

@property
def _cluster_config_generator(self) -> RayClusterConfigGenerator:
return RayAWSClusterConfigGenerator

@property
def _cluster_manager(self) -> RayClusterManager:
return RayAWSClusterManager

@property
def _config_file_name(self) -> str:
return "ag_ray_aws_cluster_config.yaml"

def initialize(self, **kwargs) -> None:
"""Initialize the backend."""
super().initialize(**kwargs)
self._boto_session = boto3.session.Session()
self.region = self._boto_session.region_name
assert (
self.region is not None
), "Please setup a region via `export AWS_DEFAULT_REGION=YOUR_REGION` in the terminal"

def generate_default_permission(self, **kwargs) -> Dict[str, str]:
"""Generate default permission file user could use to setup the corresponding entity, i.e. IAM Role in AWS"""
return RayAWSClusterManager.generate_default_permission(**kwargs)

def _setup_role_and_permission(self):
"""
AutoGluon distributed training requires access to s3 bucket and ecr repo.
This means the default role being created by ray is not enough.
"""
account_id = boto3.client("sts").get_caller_identity().get("Account")
cloud_output_bucket = self.cloud_output_path
trust_relationship = replace_trust_relationship_place_holder(
trust_relationship_document=RAY_AWS_TRUST_RELATIONSHIP, account_id=account_id
)
iam_policy = replace_iam_policy_place_holder(
policy_document=RAY_AWS_CLOUD_POLICY, account_id=account_id, bucket=cloud_output_bucket
)
create_iam_role(role_name=RAY_AWS_ROLE_NAME, trust_relationship=trust_relationship)
policy_arn = get_policy(policy_name=RAY_AWS_POLICY_NAME)
if policy_arn is not None:
delete_iam_policy(policy_arn=policy_arn)
policy_arn = create_iam_policy(policy_name=RAY_AWS_POLICY_NAME, policy=iam_policy)
attach_iam_policy(role_name=RAY_AWS_ROLE_NAME, policy_arn=policy_arn)
attach_iam_policy(role_name=RAY_AWS_ROLE_NAME, policy_arn=ECR_READ_ONLY)
instance_profile_arn = create_instance_profile(instance_profile_name=RAY_INSTANCE_PROFILE_NAME)
if instance_profile_arn is not None:
add_role_to_instance_profile(instance_profile_name=RAY_INSTANCE_PROFILE_NAME, role_name=RAY_AWS_ROLE_NAME)
time.sleep(5) # Leave sometime to allow resource to propagate

def _setup_key(self, key_name: str, local_path: str) -> str:
return create_key_pair(key_name=key_name, local_path=local_path)

def _cleanup_key(self, key_name: str, local_path: Optional[str]):
delete_key_pair(key_name=key_name, local_path=local_path)


class TabularRayAWSBackend(RayAWSBackend):
name = TABULAR_RAY_AWS
Loading

0 comments on commit fb9e7eb

Please sign in to comment.