Skip to content

Commit

Permalink
feat: add Network enum
Browse files Browse the repository at this point in the history
  • Loading branch information
pitt-liang committed Jun 11, 2024
1 parent a73f93d commit bdd0752
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 27 deletions.
27 changes: 21 additions & 6 deletions pai/api/api_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Union

from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_sts20150401.client import Client as StsClient

from ..common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT
from ..common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT, Network
from ..common.utils import is_domain_connectable
from .algorithm import AlgorithmAPI
from .base import PAIRestResourceTypes, ServiceName, WorkspaceScopedResourceAPI
Expand Down Expand Up @@ -59,17 +59,32 @@ class ResourceAPIsContainerMixin(object):
_region_id = None
_workspace_id = None

def __init__(self, header=None, runtime=None, network: Optional[str] = None):
def __init__(
self, header=None, runtime=None, network: Optional[Union[str, Network]] = None
):
"""Initialize ResourceAPIsContainerMixin.
Args:
header: Header for API request.
runtime: Runtime for API request.
network: Network type used to connect to PAI services.
"""
self.header = header
self.runtime = runtime
self.api_container = dict()
self.acs_client_container = dict()
if network:
self.network = network
self.network = (
Network.from_string(network) if isinstance(network, str) else network
)
elif DEFAULT_NETWORK_TYPE:
self.network = PAI_VPC_ENDPOINT
self.network = Network.from_string(DEFAULT_NETWORK_TYPE)
else:
self.network = "vpc" if is_domain_connectable(PAI_VPC_ENDPOINT) else None
self.network = (
Network.VPC
if is_domain_connectable(PAI_VPC_ENDPOINT)
else Network.PUBLIC
)

def _acs_credential_client(self):
if self._credential_client:
Expand Down
9 changes: 5 additions & 4 deletions pai/api/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from alibabacloud_sts20150401.client import Client as StsClient
from alibabacloud_tea_openapi.models import Config

from ..common.consts import Network
from ..common.logging import get_logger
from ..common.utils import http_user_agent
from ..libs.alibabacloud_aiworkspace20210204.client import Client as WorkspaceClient
Expand Down Expand Up @@ -56,7 +57,7 @@ def create_client(
service_name,
region_id: str,
credential_client: CredentialClient,
network: Optional[str] = None,
network: Optional[Network] = None,
**kwargs,
):
"""Create an API client which is responsible to interacted with the Alibaba
Expand All @@ -79,14 +80,14 @@ def create_client(

@classmethod
def get_endpoint(
cls, service_name: str, region_id: str, network: Optional[str] = None
cls, service_name: str, region_id: str, network: Optional[Network] = None
) -> str:
"""Get the endpoint for the service client."""
if not region_id:
raise ValueError("Please provide region_id to get the endpoint.")

if network:
subdomain = f"{service_name}-{network.lower()}"
if network and network != Network.PUBLIC:
subdomain = f"{service_name}-{network.value.lower()}"
else:
subdomain = service_name
return DEFAULT_SERVICE_ENDPOINT_PATTERN.format(subdomain, region_id)
17 changes: 16 additions & 1 deletion pai/common/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import os

# Default path for pai config file
Expand All @@ -26,6 +26,21 @@
PAI_VPC_ENDPOINT = "pai-vpc.aliyuncs.com"


class Network(enum.Enum):
VPC = "VPC"
PUBLIC = "PUBLIC"

@classmethod
def from_string(cls, s: str) -> "Network":
try:
return cls[s.upper()]
except KeyError:
raise ValueError(
"Invalid network type: %s, supported types are: %s"
% (s, ", ".join(cls.__members__.keys()))
)


class JobType(object):
"""PAI DLCJob/TrainingJob type."""

Expand Down
6 changes: 3 additions & 3 deletions pai/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from alibabacloud_credentials.utils import auth_constant

from .api.api_container import ResourceAPIsContainerMixin
from .common.consts import DEFAULT_CONFIG_PATH
from .common.consts import DEFAULT_CONFIG_PATH, Network
from .common.logging import get_logger
from .common.oss_utils import CredentialProviderWrapper, OssUriObj
from .common.utils import is_domain_connectable, make_list_resource_iterator
Expand Down Expand Up @@ -60,7 +60,7 @@ def setup_default_session(
oss_bucket_name: Optional[str] = None,
oss_endpoint: Optional[str] = None,
workspace_id: Optional[Union[str, int]] = None,
network: Optional[str] = None,
network: Optional[Union[str, Network]] = None,
**kwargs,
) -> "Session":
"""Set up the default session used in the program.
Expand All @@ -82,7 +82,7 @@ def setup_default_session(
oss_bucket_name (str, optional): The name of the OSS bucket used in the
session.
oss_endpoint (str, optional): The endpoint for the OSS bucket.
network (str, optional): The network type used to connect to PAI services.
network (Union[str, Network], optional): The network type used to connect to PAI services.
**kwargs:
Returns:
Expand Down
12 changes: 6 additions & 6 deletions pai/toolkit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,13 +674,13 @@ def prompt_config_with_default_dsw_role(user_profile: UserProfile):
if not default_storage_uri:
print_warning(
localized_text(
"The current DSW instance is configured to use the PAI DSW default role. "
"The STS temporary credentials generated by the default role only support accessing "
'the OSS Bucket of the "default workspace storage.". Please reference the document to '
"configure the default storage: "
"WARNING: The STS credential generated by the default ROLE only support accessing "
"the default OSS Bucket storage of the workspace.\n"
"It is not configured for the current workspace, please "
"reference the document to configure the default OSS Bucket storage: \n"
"https://help.aliyun.com/zh/pai/user-guide/manage-workspaces#section-afd-ntr-nwh",
'当前DSW实例配置使用PAI DSW默认角色。默认角色产生的STS临时凭证仅支持访问"工作空间默认存储"的OSS Bucket。'
"请参考帮助文档配置工作空间的默认存储:"
'警告:默认角色产生的STS凭证仅支持访问"工作空间默认存储"的OSS Bucket。\n'
"当前工作空间没有配置默认OSS Bucket存储,请参考帮助文档进行配置:\n"
"https://help.aliyun.com/zh/pai/user-guide/manage-workspaces#section-afd-ntr-nwh",
)
)
Expand Down
17 changes: 10 additions & 7 deletions pai/toolkit/helper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from ...api.base import ServiceName
from ...api.client_factory import ClientFactory
from ...api.workspace import WorkspaceAPI, WorkspaceConfigKeys
from ...common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT
from ...common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT, Network
from ...common.logging import get_logger
from ...common.oss_utils import CredentialProviderWrapper, OssUriObj
from ...common.utils import is_domain_connectable, make_list_resource_iterator
Expand Down Expand Up @@ -105,11 +105,14 @@ def __init__(
):
self.region_id = region_id
self.credential_config = credential_config

if DEFAULT_NETWORK_TYPE:
self._default_network = DEFAULT_NETWORK_TYPE
self.network = Network.from_string(DEFAULT_NETWORK_TYPE)
else:
self._default_network = (
"vpc" if is_domain_connectable(PAI_VPC_ENDPOINT) else None
self.network = (
Network.VPC
if is_domain_connectable(PAI_VPC_ENDPOINT)
else Network.PUBLIC
)
self._caller_identify = self._get_caller_identity()

Expand All @@ -134,7 +137,7 @@ def _get_caller_identity(self) -> CallerIdentity:
config=open_api_models.Config(
credential=self._get_credential_client(),
region_id=self.region_id,
network=self._default_network,
network=self.network.value.lower(),
)
)
.get_caller_identity()
Expand All @@ -152,7 +155,7 @@ def get_acs_dsw_client(self) -> DswClient:
service_name=ServiceName.PAI_DSW,
credential_client=self._get_credential_client(),
region_id=self.region_id,
network=self._default_network,
network=self.network,
)

def get_instance_info(self, instance_id: str) -> Dict[str, Any]:
Expand Down Expand Up @@ -249,7 +252,7 @@ def get_workspace_api(self) -> WorkspaceAPI:
service_name=ServiceName.PAI_WORKSPACE,
credential_client=self._get_credential_client(),
region_id=self.region_id,
network=self._default_network,
network=self.network,
)

return WorkspaceAPI(
Expand Down

0 comments on commit bdd0752

Please sign in to comment.