Skip to content

Commit

Permalink
feat: add Network enum
Browse files Browse the repository at this point in the history
  • Loading branch information
pitt-liang committed Jun 11, 2024
1 parent a73f93d commit 568c5c9
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 18 deletions.
27 changes: 21 additions & 6 deletions pai/api/api_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Union

from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_sts20150401.client import Client as StsClient

from ..common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT
from ..common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT, Network
from ..common.utils import is_domain_connectable
from .algorithm import AlgorithmAPI
from .base import PAIRestResourceTypes, ServiceName, WorkspaceScopedResourceAPI
Expand Down Expand Up @@ -59,17 +59,32 @@ class ResourceAPIsContainerMixin(object):
_region_id = None
_workspace_id = None

def __init__(self, header=None, runtime=None, network: Optional[str] = None):
def __init__(
self, header=None, runtime=None, network: Optional[Union[str, Network]] = None
):
"""Initialize ResourceAPIsContainerMixin.
Args:
header: Header for API request.
runtime: Runtime for API request.
network: Network type used to connect to PAI services.
"""
self.header = header
self.runtime = runtime
self.api_container = dict()
self.acs_client_container = dict()
if network:
self.network = network
self.network = (
Network.from_string(network) if isinstance(network, str) else network
)
elif DEFAULT_NETWORK_TYPE:
self.network = PAI_VPC_ENDPOINT
self.network = Network.from_string(DEFAULT_NETWORK_TYPE)
else:
self.network = "vpc" if is_domain_connectable(PAI_VPC_ENDPOINT) else None
self.network = (
Network.VPC
if is_domain_connectable(PAI_VPC_ENDPOINT)
else Network.PUBLIC
)

def _acs_credential_client(self):
if self._credential_client:
Expand Down
9 changes: 5 additions & 4 deletions pai/api/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from alibabacloud_sts20150401.client import Client as StsClient
from alibabacloud_tea_openapi.models import Config

from ..common.consts import Network
from ..common.logging import get_logger
from ..common.utils import http_user_agent
from ..libs.alibabacloud_aiworkspace20210204.client import Client as WorkspaceClient
Expand Down Expand Up @@ -56,7 +57,7 @@ def create_client(
service_name,
region_id: str,
credential_client: CredentialClient,
network: Optional[str] = None,
network: Optional[Network] = None,
**kwargs,
):
"""Create an API client which is responsible to interacted with the Alibaba
Expand All @@ -79,14 +80,14 @@ def create_client(

@classmethod
def get_endpoint(
cls, service_name: str, region_id: str, network: Optional[str] = None
cls, service_name: str, region_id: str, network: Optional[Network] = None
) -> str:
"""Get the endpoint for the service client."""
if not region_id:
raise ValueError("Please provide region_id to get the endpoint.")

if network:
subdomain = f"{service_name}-{network.lower()}"
if network and network != Network.PUBLIC:
subdomain = f"{service_name}-{network.value.lower()}"
else:
subdomain = service_name
return DEFAULT_SERVICE_ENDPOINT_PATTERN.format(subdomain, region_id)
17 changes: 16 additions & 1 deletion pai/common/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import os

# Default path for pai config file
Expand All @@ -26,6 +26,21 @@
PAI_VPC_ENDPOINT = "pai-vpc.aliyuncs.com"


class Network(enum.Enum):
VPC = "VPC"
PUBLIC = "PUBLIC"

@classmethod
def from_string(cls, s: str) -> "Network":
try:
return cls[s.upper()]
except KeyError:
raise ValueError(
"Invalid network type: %s, supported types are: %s"
% (s, ", ".join(cls.__members__.keys()))
)


class JobType(object):
"""PAI DLCJob/TrainingJob type."""

Expand Down
17 changes: 10 additions & 7 deletions pai/toolkit/helper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from ...api.base import ServiceName
from ...api.client_factory import ClientFactory
from ...api.workspace import WorkspaceAPI, WorkspaceConfigKeys
from ...common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT
from ...common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT, Network
from ...common.logging import get_logger
from ...common.oss_utils import CredentialProviderWrapper, OssUriObj
from ...common.utils import is_domain_connectable, make_list_resource_iterator
Expand Down Expand Up @@ -105,11 +105,14 @@ def __init__(
):
self.region_id = region_id
self.credential_config = credential_config

if DEFAULT_NETWORK_TYPE:
self._default_network = DEFAULT_NETWORK_TYPE
self.network = Network.from_string(DEFAULT_NETWORK_TYPE)
else:
self._default_network = (
"vpc" if is_domain_connectable(PAI_VPC_ENDPOINT) else None
self.network = (
Network.VPC
if is_domain_connectable(PAI_VPC_ENDPOINT)
else Network.PUBLIC
)
self._caller_identify = self._get_caller_identity()

Expand All @@ -134,7 +137,7 @@ def _get_caller_identity(self) -> CallerIdentity:
config=open_api_models.Config(
credential=self._get_credential_client(),
region_id=self.region_id,
network=self._default_network,
network=self.network.value.lower(),
)
)
.get_caller_identity()
Expand All @@ -152,7 +155,7 @@ def get_acs_dsw_client(self) -> DswClient:
service_name=ServiceName.PAI_DSW,
credential_client=self._get_credential_client(),
region_id=self.region_id,
network=self._default_network,
network=self.network,
)

def get_instance_info(self, instance_id: str) -> Dict[str, Any]:
Expand Down Expand Up @@ -249,7 +252,7 @@ def get_workspace_api(self) -> WorkspaceAPI:
service_name=ServiceName.PAI_WORKSPACE,
credential_client=self._get_credential_client(),
region_id=self.region_id,
network=self._default_network,
network=self.network,
)

return WorkspaceAPI(
Expand Down

0 comments on commit 568c5c9

Please sign in to comment.