Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/sagemaker_core/tools/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,5 @@
CONFIG_SCHEMA_FILE_NAME = "config_schema.py"

API_COVERAGE_JSON_FILE_PATH = os.getcwd() + "/src/sagemaker_core/tools/api_coverage.json"

SHAPES_WITH_JSON_FIELD_ALIAS = ["MonitoringDatasetFormat"] # Shapes with field name with "json"
11 changes: 9 additions & 2 deletions src/sagemaker_core/tools/shapes_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
LICENCES_STRING,
GENERATED_CLASSES_LOCATION,
SHAPES_CODEGEN_FILE_NAME,
SHAPES_WITH_JSON_FIELD_ALIAS,
)
from sagemaker_core.tools.shapes_extractor import ShapesExtractor
from sagemaker_core.main.utils import (
Expand Down Expand Up @@ -180,7 +181,13 @@ def _generate_doc_string_for_shape(self, shape):

if "members" in shape_dict:
for member, member_attributes in shape_dict["members"].items():
docstring += f"\n{convert_to_snake_case(member)}"
# Add alias if field name is json, to address the Bug: https://github.com/aws/sagemaker-python-sdk/issues/4944
if shape in SHAPES_WITH_JSON_FIELD_ALIAS and member == "Json":
updated_member = "JsonFormat"
docstring += f"\n{convert_to_snake_case(updated_member)}"
else:
docstring += f"\n{convert_to_snake_case(member)}"

if "documentation" in member_attributes:
docstring += f": {member_attributes['documentation']}"

Expand All @@ -204,7 +211,7 @@ def generate_imports(self):
"""
imports = "import datetime\n"
imports += "\n"
imports += "from pydantic import BaseModel, ConfigDict\n"
imports += "from pydantic import BaseModel, ConfigDict, Field\n"
imports += "from typing import List, Dict, Optional, Any, Union\n"
imports += "from sagemaker_core.main.utils import Unassigned"
imports += "\n"
Expand Down
40 changes: 33 additions & 7 deletions src/sagemaker_core/tools/shapes_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from functools import lru_cache
from typing import Optional, Any

from sagemaker_core.tools.constants import BASIC_JSON_TYPES_TO_PYTHON_TYPES, SHAPE_DAG_FILE_PATH
from sagemaker_core.tools.constants import (
BASIC_JSON_TYPES_TO_PYTHON_TYPES,
SHAPE_DAG_FILE_PATH,
SHAPES_WITH_JSON_FIELD_ALIAS,
)
from sagemaker_core.main.utils import (
reformat_file_with_black,
convert_to_snake_case,
Expand Down Expand Up @@ -99,6 +103,11 @@ def get_shapes_dag(self):
_dag[shape] = {"type": "structure", "members": []}
for member, member_attrs in shape_data["members"].items():
shape_node_member = {"name": member, "shape": member_attrs["shape"]}
# Add alias if field name is json, to address the Bug: https://github.com/aws/sagemaker-python-sdk/issues/4944
if shape in SHAPES_WITH_JSON_FIELD_ALIAS and member == "Json":
shape_node_member["name"] = "JsonFormat"
shape_node_member["alias"] = "json"

member_shape_dict = _all_shapes[member_attrs["shape"]]
shape_node_member["type"] = member_shape_dict["type"]
_dag[shape]["members"].append(shape_node_member)
Expand Down Expand Up @@ -218,6 +227,8 @@ def generate_shape_members(self, shape, required_override=()):
# bring the required members in front
ordered_members = {key: members[key] for key in required_args if key in members}
ordered_members.update(members)
field_aliases = {}

for member_name, member_attrs in ordered_members.items():
member_shape_name = member_attrs["shape"]
if self.combined_shapes[member_shape_name]:
Expand All @@ -234,13 +245,28 @@ def generate_shape_members(self, shape, required_override=()):
member_type = BASIC_JSON_TYPES_TO_PYTHON_TYPES[member_shape_type]
else:
raise Exception("The Shape definition mush exist. The Json Data might be corrupt")
member_name_snake_case = convert_to_snake_case(member_name)
if member_name in required_args:
init_data_body[f"{member_name_snake_case}"] = f"{member_type}"
else:
init_data_body[f"{member_name_snake_case}"] = (
f"Optional[{member_type}] = Unassigned()"

is_required = member_name in required_args
# Add alias if field name is json, to address the Bug: https://github.com/aws/sagemaker-python-sdk/issues/4944
if shape in SHAPES_WITH_JSON_FIELD_ALIAS and member_name == "Json":
updated_member_name_snake_case = "json_format"
field_aliases[updated_member_name_snake_case] = "json"
init_data_body[f"{updated_member_name_snake_case}"] = (
(
f"{member_type} = Field(alias='{field_aliases[updated_member_name_snake_case]}')"
)
if is_required
else f"Optional[{member_type}] = Field(default=Unassigned(), alias='json')"
)
else:
member_name_snake_case = convert_to_snake_case(member_name)
if is_required:
init_data_body[f"{member_name_snake_case}"] = f"{member_type}"
else:
init_data_body[f"{member_name_snake_case}"] = (
f"Optional[{member_type}] = Unassigned()"
)

return init_data_body

@lru_cache
Expand Down
Loading