@@ -100,3 +100,73 @@ def modify_node(self, node):
100100 if node .module == "sagemaker.session" :
101101 node .module = "sagemaker.inputs"
102102 return node
103+
104+
105+ class ShuffleConfigModuleRenamer (Modifier ):
106+ """A class to change ``ShuffleConfig`` usage to use ``sagemaker.inputs.ShuffleConfig``."""
107+
108+ def node_should_be_modified (self , node ):
109+ """Checks if the ``ast.Call`` node instantiates a class of interest.
110+
111+ This looks for the following calls:
112+
113+ - ``sagemaker.session.ShuffleConfig``
114+ - ``session.ShuffleConfig``
115+
116+ Args:
117+ node (ast.Call): a node that represents a function call. For more,
118+ see https://docs.python.org/3/library/ast.html#abstract-grammar.
119+
120+ Returns:
121+ bool: If the ``ast.Call`` instantiates a class of interest.
122+ """
123+ if isinstance (node .func , ast .Name ):
124+ return False
125+
126+ return matching .matches_name_or_namespaces (
127+ node , "ShuffleConfig" , ("sagemaker.session" , "session" )
128+ )
129+
130+ def modify_node (self , node ):
131+ """Modifies the ``ast.Call`` node to call ``sagemaker.inputs.ShuffleConfig``.
132+
133+ Args:
134+ node (ast.Call): a node that represents a ``sagemaker.session.ShuffleConfig``
135+ constructor.
136+
137+ Returns:
138+ ast.Call: the original node, with its namespace changed to use the ``inputs`` module.
139+ """
140+ _rename_namespace (node , "session" )
141+ return node
142+
143+
144+ class ShuffleConfigImportFromRenamer (Modifier ):
145+ """A class to update import statements of ``ShuffleConfig``."""
146+
147+ def node_should_be_modified (self , node ):
148+ """Checks if the import statement imports ``sagemaker.session.ShuffleConfig``.
149+
150+ Args:
151+ node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
152+ For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
153+
154+ Returns:
155+ bool: If the import statement imports ``sagemaker.session.ShuffleConfig``.
156+ """
157+ return node .module == "sagemaker.session" and any (
158+ name .name == "ShuffleConfig" for name in node .names
159+ )
160+
161+ def modify_node (self , node ):
162+ """Changes the ``ast.ImportFrom`` node's namespace to ``sagemaker.inputs``.
163+
164+ Args:
165+ node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
166+ For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
167+
168+ Returns:
169+ ast.ImportFrom: the original node, with its module modified to ``"sagemaker.inputs"``.
170+ """
171+ node .module = "sagemaker.inputs"
172+ return node
0 commit comments