22import torch
33import sys
44import inspect
5+ import ast
56from .graph_compiler_backend import GraphCompilerBackend
67from ..fx_graph_serialize_util import serialize_graph_module_to_str
78
@@ -12,12 +13,21 @@ def __call__(self, model):
1213 unstable_api = os .getenv ("DISALLOWED_UNSTABLE_API" , "" ).strip ()
1314 self .unstable_api = unstable_api
1415
16+ # Use torch.compile's backend method to get graph module uniformly
17+ # This ensures all models use the same conversion method, avoiding performance differences
1518 def my_backend (gm , sample_inputs ):
19+ # Convert unstable API
1620 gm = self .unstable_to_stable (gm )
1721 self .check_unstable_api (gm )
22+ # Return forward function without additional optimization
1823 return gm .forward
1924
20- return torch .compile (backend = my_backend )(model )
25+ # Use torch.compile to get graph module and perform conversion
26+ # Although compile is used, the backend only does API conversion, no optimization
27+ # Performance should be close to eager mode (since only API replacement is done)
28+ # Note: Do not use mode parameter to avoid version compatibility issues
29+ compiled_model = torch .compile (model , backend = my_backend )
30+ return compiled_model
2131
2232 """
2333 TODO: Implement logic to convert unstable APIs in `self.model` into their stable counterparts.
@@ -147,6 +157,157 @@ def _impl_unstable_to_stable_special_logit(self, gm):
147157
148158 return gm
149159
160+ def _impl_unstable_to_stable_linear_to_functional_linear (self , gm ):
161+ """
162+ Convert torch._C._nn.linear to torch.nn.functional.linear
163+
164+ Args:
165+ gm: torch.fx.GraphModule object
166+
167+ Returns:
168+ Modified GraphModule object
169+ """
170+ # Get reference to torch._C._nn.linear for comparison
171+ try :
172+ unstable_linear = torch ._C ._nn .linear
173+ except AttributeError :
174+ unstable_linear = None
175+
176+ # Traverse all nodes to find nodes that need to be replaced
177+ nodes_to_replace = []
178+ for node in gm .graph .nodes :
179+ if node .op == "call_function" :
180+ target = node .target
181+ should_replace = False
182+
183+ # Method 1: Direct target comparison (most reliable)
184+ if unstable_linear is not None and target is unstable_linear :
185+ should_replace = True
186+ # Method 2: Check if it's the same function object (using id comparison)
187+ elif unstable_linear is not None and id (target ) == id (unstable_linear ):
188+ should_replace = True
189+ # Method 3: Check module and name attributes (most reliable method, as torch.fx preserves these attributes)
190+ elif hasattr (target , "__module__" ) and hasattr (target , "__name__" ):
191+ if (
192+ target .__module__ == "torch._C._nn"
193+ and target .__name__ == "linear"
194+ ):
195+ should_replace = True
196+ # Method 4: Check via string representation (fallback method)
197+ elif "torch._C._nn.linear" in str (target ) or (
198+ hasattr (target , "__qualname__" )
199+ and "linear" in target .__qualname__
200+ and hasattr (target , "__module__" )
201+ and "torch._C._nn" in str (target .__module__ )
202+ ):
203+ should_replace = True
204+
205+ if should_replace :
206+ nodes_to_replace .append (node )
207+
208+ # Since torch._C._nn.linear and torch.nn.functional.linear are the same object,
209+ # the code generator cannot distinguish them, so we need to use AST to modify the code string after code generation
210+ if nodes_to_replace :
211+ # First recompile to generate code
212+ gm .recompile ()
213+
214+ # Use AST to modify the generated code, replacing torch._C._nn.linear with torch.nn.functional.linear
215+ code_str = gm .code
216+
217+ # Parse AST
218+ tree = ast .parse (code_str )
219+
220+ class LinearReplacer (ast .NodeTransformer ):
221+ def visit_Call (self , node ):
222+ # Check if it's a torch._C._nn.linear call
223+ # Structure: torch._C._nn.linear(...)
224+ if isinstance (node .func , ast .Attribute ):
225+ # node.func.attr should be "linear"
226+ if node .func .attr == "linear" :
227+ # node.func.value should be torch._C._nn
228+ if isinstance (node .func .value , ast .Attribute ):
229+ # node.func.value.attr should be "_nn"
230+ if node .func .value .attr == "_nn" :
231+ # node.func.value.value should be torch._C
232+ if isinstance (node .func .value .value , ast .Attribute ):
233+ # node.func.value.value.attr should be "_C"
234+ if node .func .value .value .attr == "_C" :
235+ # node.func.value.value.value should be torch
236+ if (
237+ isinstance (
238+ node .func .value .value .value ,
239+ ast .Name ,
240+ )
241+ and node .func .value .value .value .id
242+ == "torch"
243+ ):
244+ # Found torch._C._nn.linear, replace with torch.nn.functional.linear
245+ new_func = ast .Attribute (
246+ value = ast .Attribute (
247+ value = ast .Attribute (
248+ value = ast .Name (
249+ id = "torch" ,
250+ ctx = ast .Load (),
251+ ),
252+ attr = "nn" ,
253+ ctx = ast .Load (),
254+ ),
255+ attr = "functional" ,
256+ ctx = ast .Load (),
257+ ),
258+ attr = "linear" ,
259+ ctx = ast .Load (),
260+ )
261+ node .func = new_func
262+ return self .generic_visit (node )
263+
264+ transformer = LinearReplacer ()
265+ modified_tree = transformer .visit (tree )
266+ ast .fix_missing_locations (modified_tree )
267+
268+ # Convert the modified AST back to code string
269+ new_code = ast .unparse (modified_tree )
270+
271+ # Recompile the modified code
272+ # Need to import device, inf and other modules that may be used
273+ namespace = {
274+ "torch" : torch ,
275+ }
276+ # Try to import device (if used in code)
277+ try :
278+ from torch import device
279+
280+ namespace ["device" ] = device
281+ except ImportError :
282+ pass
283+ # Try to import inf (if used in code)
284+ try :
285+ from torch import inf
286+
287+ namespace ["inf" ] = inf
288+ except ImportError :
289+ # If torch doesn't have inf, use math.inf
290+ try :
291+ import math
292+
293+ namespace ["inf" ] = math .inf
294+ except :
295+ pass
296+
297+ exec (compile (modified_tree , filename = "<ast>" , mode = "exec" ), namespace )
298+
299+ # Update GraphModule's forward method and code
300+ forward_func = namespace .get ("forward" )
301+ if forward_func :
302+ import types
303+
304+ gm .forward = types .MethodType (forward_func , gm )
305+
306+ # Update _code attribute so that gm.code returns the modified code
307+ gm ._code = new_code
308+
309+ return gm
310+
150311 # replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
151312
152313 # replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
0 commit comments