Skip to content

Commit

Permalink
feat: add network parameter for session (#24)
Browse files Browse the repository at this point in the history
* feat: pai.tookkit.config supports vpc network

* feat: add network parameter for session

* feat: add `Network` enum
  • Loading branch information
pitt-liang committed Jun 11, 2024
1 parent 4cf0130 commit a2052ef
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 14 deletions.
28 changes: 26 additions & 2 deletions pai/api/api_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_sts20150401.client import Client as StsClient

from ..common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT, Network
from ..common.utils import is_domain_connectable
from .algorithm import AlgorithmAPI
from .base import PAIRestResourceTypes, ServiceName, WorkspaceScopedResourceAPI
from .client_factory import ClientFactory
Expand Down Expand Up @@ -57,11 +59,32 @@ class ResourceAPIsContainerMixin(object):
_region_id = None
_workspace_id = None

def __init__(self, header=None, runtime=None):
def __init__(
self, header=None, runtime=None, network: Optional[Union[str, Network]] = None
):
"""Initialize ResourceAPIsContainerMixin.
Args:
header: Header for API request.
runtime: Runtime for API request.
network: Network type used to connect to PAI services.
"""
self.header = header
self.runtime = runtime
self.api_container = dict()
self.acs_client_container = dict()
if network:
self.network = (
Network.from_string(network) if isinstance(network, str) else network
)
elif DEFAULT_NETWORK_TYPE:
self.network = Network.from_string(DEFAULT_NETWORK_TYPE)
else:
self.network = (
Network.VPC
if is_domain_connectable(PAI_VPC_ENDPOINT)
else Network.PUBLIC
)

def _acs_credential_client(self):
if self._credential_client:
Expand All @@ -76,6 +99,7 @@ def _get_acs_client(self, service_name):
service_name=service_name,
credential_client=self._acs_credential_client(),
region_id=self._region_id,
network=self.network,
)
self.acs_client_container[service_name] = acs_client
return acs_client
Expand Down
16 changes: 14 additions & 2 deletions pai/api/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@

from __future__ import absolute_import

from typing import Optional

from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_sts20150401.client import Client as StsClient
from alibabacloud_tea_openapi.models import Config

from ..common.consts import Network
from ..common.logging import get_logger
from ..common.utils import http_user_agent
from ..libs.alibabacloud_aiworkspace20210204.client import Client as WorkspaceClient
Expand Down Expand Up @@ -54,16 +57,19 @@ def create_client(
service_name,
region_id: str,
credential_client: CredentialClient,
network: Optional[Network] = None,
**kwargs,
):
"""Create an API client which is responsible to interacted with the Alibaba
Cloud service."""

config = Config(
region_id=region_id,
credential=credential_client,
endpoint=cls.get_endpoint(
service_name=service_name,
region_id=region_id,
network=network,
),
signature_algorithm="v2",
user_agent=http_user_agent(),
Expand All @@ -73,9 +79,15 @@ def create_client(
return client

@classmethod
def get_endpoint(cls, service_name: str, region_id: str) -> str:
def get_endpoint(
cls, service_name: str, region_id: str, network: Optional[Network] = None
) -> str:
"""Get the endpoint for the service client."""
if not region_id:
raise ValueError("Please provide region_id to get the endpoint.")

return DEFAULT_SERVICE_ENDPOINT_PATTERN.format(service_name, region_id)
if network and network != Network.PUBLIC:
subdomain = f"{service_name}-{network.value.lower()}"
else:
subdomain = service_name
return DEFAULT_SERVICE_ENDPOINT_PATTERN.format(subdomain, region_id)
23 changes: 22 additions & 1 deletion pai/common/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import os

# Default path for pai config file
DEFAULT_CONFIG_PATH = os.environ.get(
"PAI_CONFIG_PATH", os.path.join(os.path.expanduser("~"), ".pai", "config.json")
)

# Default network type used to connect to PAI services
DEFAULT_NETWORK_TYPE = os.environ.get("PAI_NETWORK_TYPE", None)

# PAI VPC endpoint
PAI_VPC_ENDPOINT = "pai-vpc.aliyuncs.com"


class Network(enum.Enum):
VPC = "VPC"
PUBLIC = "PUBLIC"

@classmethod
def from_string(cls, s: str) -> "Network":
try:
return cls[s.upper()]
except KeyError:
raise ValueError(
"Invalid network type: %s, supported types are: %s"
% (s, ", ".join(cls.__members__.keys()))
)


class JobType(object):
"""PAI DLCJob/TrainingJob type."""
Expand Down
18 changes: 16 additions & 2 deletions pai/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from alibabacloud_credentials.utils import auth_constant

from .api.api_container import ResourceAPIsContainerMixin
from .common.consts import DEFAULT_CONFIG_PATH
from .common.consts import DEFAULT_CONFIG_PATH, Network
from .common.logging import get_logger
from .common.oss_utils import CredentialProviderWrapper, OssUriObj
from .common.utils import is_domain_connectable, make_list_resource_iterator
Expand Down Expand Up @@ -60,6 +60,7 @@ def setup_default_session(
oss_bucket_name: Optional[str] = None,
oss_endpoint: Optional[str] = None,
workspace_id: Optional[Union[str, int]] = None,
network: Optional[Union[str, Network]] = None,
**kwargs,
) -> "Session":
"""Set up the default session used in the program.
Expand All @@ -81,6 +82,11 @@ def setup_default_session(
oss_bucket_name (str, optional): The name of the OSS bucket used in the
session.
oss_endpoint (str, optional): The endpoint for the OSS bucket.
network (Union[str, Network], optional): The network to use for the connection.
supported values are "VPC" and "PUBLIC". If provided, this value will be used as-is.
Otherwise, the code will first check for an environment variable PAI_NETWORK_TYPE.
If that is not set and the VPC endpoint is available, it will be used.
As a last resort, if all else fails, the PUBLIC endpoint will be used.
**kwargs:
Returns:
Expand Down Expand Up @@ -114,13 +120,15 @@ def setup_default_session(
oss_bucket_name = oss_bucket_name or default_session.oss_bucket_name
oss_endpoint = oss_endpoint or default_session.oss_endpoint
credential_config = credential_config or default_session.credential_config
network = network or default_session.network

session = Session(
region_id=region_id,
credential_config=credential_config,
oss_bucket_name=oss_bucket_name,
oss_endpoint=oss_endpoint,
workspace_id=workspace_id,
network=network,
**kwargs,
)

Expand Down Expand Up @@ -208,7 +216,13 @@ def __init__(
self._oss_endpoint = oss_endpoint

header = kwargs.pop("header", None)
super(Session, self).__init__(header=header)
network = kwargs.pop("network", None)
runtime = kwargs.pop("runtime", None)
if kwargs:
logger.warning(
"Unused arguments found in session initialization: %s", kwargs
)
super(Session, self).__init__(header=header, network=network, runtime=runtime)

@property
def region_id(self) -> str:
Expand Down
12 changes: 6 additions & 6 deletions pai/toolkit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,13 +674,13 @@ def prompt_config_with_default_dsw_role(user_profile: UserProfile):
if not default_storage_uri:
print_warning(
localized_text(
"The current DSW instance is configured to use the PAI DSW default role. "
"The STS temporary credentials generated by the default role only support accessing "
'the OSS Bucket of the "default workspace storage.". Please reference the document to '
"configure the default storage: "
"WARNING: The STS credential generated by the default ROLE only support accessing "
"the default OSS Bucket storage of the workspace.\n"
"It is not configured for the current workspace, please "
"reference the document to configure the default OSS Bucket storage: \n"
"https://help.aliyun.com/zh/pai/user-guide/manage-workspaces#section-afd-ntr-nwh",
'当前DSW实例配置使用PAI DSW默认角色。默认角色产生的STS临时凭证仅支持访问"工作空间默认存储"的OSS Bucket。'
"请参考帮助文档配置工作空间的默认存储:"
'警告:默认角色产生的STS凭证仅支持访问"工作空间默认存储"的OSS Bucket。\n'
"当前工作空间没有配置默认OSS Bucket存储,请参考帮助文档进行配置:\n"
"https://help.aliyun.com/zh/pai/user-guide/manage-workspaces#section-afd-ntr-nwh",
)
)
Expand Down
17 changes: 16 additions & 1 deletion pai/toolkit/helper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
from ...api.base import ServiceName
from ...api.client_factory import ClientFactory
from ...api.workspace import WorkspaceAPI, WorkspaceConfigKeys
from ...common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT, Network
from ...common.logging import get_logger
from ...common.oss_utils import CredentialProviderWrapper, OssUriObj
from ...common.utils import make_list_resource_iterator
from ...common.utils import is_domain_connectable, make_list_resource_iterator
from ...libs.alibabacloud_pai_dsw20220101.client import Client as DswClient

logger = get_logger(__name__)
Expand Down Expand Up @@ -104,6 +105,15 @@ def __init__(
):
self.region_id = region_id
self.credential_config = credential_config

if DEFAULT_NETWORK_TYPE:
self.network = Network.from_string(DEFAULT_NETWORK_TYPE)
else:
self.network = (
Network.VPC
if is_domain_connectable(PAI_VPC_ENDPOINT)
else Network.PUBLIC
)
self._caller_identify = self._get_caller_identity()

def _get_credential_client(self):
Expand All @@ -127,6 +137,9 @@ def _get_caller_identity(self) -> CallerIdentity:
config=open_api_models.Config(
credential=self._get_credential_client(),
region_id=self.region_id,
network=None
if self.network == Network.PUBLIC
else self.network.value.lower(),
)
)
.get_caller_identity()
Expand All @@ -144,6 +157,7 @@ def get_acs_dsw_client(self) -> DswClient:
service_name=ServiceName.PAI_DSW,
credential_client=self._get_credential_client(),
region_id=self.region_id,
network=self.network,
)

def get_instance_info(self, instance_id: str) -> Dict[str, Any]:
Expand Down Expand Up @@ -240,6 +254,7 @@ def get_workspace_api(self) -> WorkspaceAPI:
service_name=ServiceName.PAI_WORKSPACE,
credential_client=self._get_credential_client(),
region_id=self.region_id,
network=self.network,
)

return WorkspaceAPI(
Expand Down

0 comments on commit a2052ef

Please sign in to comment.