In [1]:
import torch, inspect, hashlib
import re

In [2]:
def ts_signature(module, example_inputs=()):
    """
    Return a SHA-256 hash that is identical for TorchScript-canonical-equivalent graphs.
    """
    # 1. Script & freeze (avoids data‑dependent control‑flow) 
    ts = torch.jit.script(module)
    ts = torch.jit.freeze(ts, example_inputs)

    # 2. Canonicalise graph in place
    g = ts.inlined_graph            # single, inlined graph
    torch._C._jit_pass_canonicalize(g)        # 🔒 deterministic order
    torch._C._jit_pass_constant_propagation(g)
    torch._C._jit_pass_dce(g)
    return g

    # 3. Stringify without unique names, then hash
    graph_str = g.toString(False, False, False)          # drop unique mangles
    return hashlib.sha256(graph_str.encode()).hexdigest()

In [3]:
def normalize_torchscript_graph(graph_str: str) -> str:
    # Normalize mangled class names: ___torch_mangle_123.Class → ___torch_mangle_XXX.Cls
    graph_str = re.sub(r'___torch_mangle_\d+\.\w+', '___torch_mangle_XXX.Cls', graph_str)
    
    # Normalize ipykernel file paths: /tmp/ipykernel_*/123456789.py → /tmp/ipykernel_X/file.py
    graph_str = re.sub(r'/tmp/ipykernel_\d+/[\d]+\.py', '/tmp/ipykernel_X/file.py', graph_str)
    
    # Normalize line numbers in comments: :12:34 → :X:Y
    graph_str = re.sub(r':\d+:\d+', ':X:Y', graph_str)
    
    return graph_str


In [4]:
# step - the same
class M(torch.nn.Module):
    def forward(self, x):
        y = 0
        for i in range(1, 10):     
            y = y + x[i]
 
        return y

class N(torch.nn.Module):
    def forward(self, x):
        y = 0
        for i in range(1, 10, 1):   
            y = y + x[i]
 
        return y

In [5]:
# change variable name - diff
class M(torch.nn.Module):
    def forward(self, x):
        z = 0
        for i in range(1, 10):     
            z = z + x[i]
 
        return z

class N(torch.nn.Module):
    def forward(self, x):
        y = 0
        for i in range(1, 10, 1):   
            y = y + x[i]
 
        return y

In [6]:
# val value - diff
class M(torch.nn.Module):
    def forward(self, x):
        y = 10
        for i in range(1, 10):     
            y = y + x[i]
 
        return y

class N(torch.nn.Module):
    def forward(self, x):
        y = 1
        for i in range(1, 10, 1):   
            y = y + x[i]
 
        return y

In [7]:
# swap vars - diff
class M(torch.nn.Module):
    def forward(self, x):
        y = 0
        for i in range(1, 10):     
            y = y + x[i]
 
        return y

class N(torch.nn.Module):
    def forward(self, x):
        y = 0
        for i in range(1, 10, 1):   
            y = x[i] + y
 
        return y

In [8]:
g1 = ts_signature(M().eval()).str()
g2 = ts_signature(N().eval()).str()

In [9]:
g1_norm = normalize_torchscript_graph(g1)
g2_norm = normalize_torchscript_graph(g2)

g1_norm == g2_norm

False

In [10]:
g1

'graph(%self : __torch__.___torch_mangle_0.M,\n      %x.1 : Tensor):\n  %2 : int = prim::Constant[value=1]() # /tmp/ipykernel_2944385/401325865.py:5:23\n  %y.1 : int = prim::Constant[value=0]() # /tmp/ipykernel_2944385/401325865.py:4:12\n  %4 : Tensor = aten::select(%x.1, %y.1, %2) # /tmp/ipykernel_2944385/401325865.py:6:20\n  %5 : int = aten::IntImplicit(%4) # /tmp/ipykernel_2944385/401325865.py:6:12\n  %6 : int = prim::Constant[value=2]()\n  %7 : Tensor = aten::select(%x.1, %y.1, %6) # /tmp/ipykernel_2944385/401325865.py:6:20\n  %y.14 : Tensor = aten::add(%7, %5, %2) # <string>:5:9\n  %9 : int = aten::IntImplicit(%y.14) # /tmp/ipykernel_2944385/401325865.py:6:12\n  %10 : int = prim::Constant[value=3]()\n  %11 : Tensor = aten::select(%x.1, %y.1, %10) # /tmp/ipykernel_2944385/401325865.py:6:20\n  %y.16 : Tensor = aten::add(%11, %9, %2) # <string>:5:9\n  %13 : int = aten::IntImplicit(%y.16) # /tmp/ipykernel_2944385/401325865.py:6:12\n  %14 : int = prim::Constant[value=4]()\n  %15 : Tens

In [11]:
g2

'graph(%self : __torch__.___torch_mangle_1.N,\n      %x.1 : Tensor):\n  %2 : int = prim::Constant[value=1]() # /tmp/ipykernel_2944385/401325865.py:13:23\n  %y.1 : int = prim::Constant[value=0]() # /tmp/ipykernel_2944385/401325865.py:12:12\n  %4 : Tensor = aten::select(%x.1, %y.1, %2) # /tmp/ipykernel_2944385/401325865.py:14:16\n  %5 : int = aten::IntImplicit(%4) # /tmp/ipykernel_2944385/401325865.py:14:12\n  %6 : int = prim::Constant[value=2]()\n  %7 : Tensor = aten::select(%x.1, %y.1, %6) # /tmp/ipykernel_2944385/401325865.py:14:16\n  %y.14 : Tensor = aten::add(%7, %5, %2) # /tmp/ipykernel_2944385/401325865.py:14:16\n  %9 : int = aten::IntImplicit(%y.14) # /tmp/ipykernel_2944385/401325865.py:14:12\n  %10 : int = prim::Constant[value=3]()\n  %11 : Tensor = aten::select(%x.1, %y.1, %10) # /tmp/ipykernel_2944385/401325865.py:14:16\n  %y.16 : Tensor = aten::add(%11, %9, %2) # /tmp/ipykernel_2944385/401325865.py:14:16\n  %13 : int = aten::IntImplicit(%y.16) # /tmp/ipykernel_2944385/4013258

In [12]:
g1_norm

'graph(%self : __torch__.___torch_mangle_XXX.Cls,\n      %x.1 : Tensor):\n  %2 : int = prim::Constant[value=1]() # /tmp/ipykernel_X/file.py:X:Y\n  %y.1 : int = prim::Constant[value=0]() # /tmp/ipykernel_X/file.py:X:Y\n  %4 : Tensor = aten::select(%x.1, %y.1, %2) # /tmp/ipykernel_X/file.py:X:Y\n  %5 : int = aten::IntImplicit(%4) # /tmp/ipykernel_X/file.py:X:Y\n  %6 : int = prim::Constant[value=2]()\n  %7 : Tensor = aten::select(%x.1, %y.1, %6) # /tmp/ipykernel_X/file.py:X:Y\n  %y.14 : Tensor = aten::add(%7, %5, %2) # <string>:X:Y\n  %9 : int = aten::IntImplicit(%y.14) # /tmp/ipykernel_X/file.py:X:Y\n  %10 : int = prim::Constant[value=3]()\n  %11 : Tensor = aten::select(%x.1, %y.1, %10) # /tmp/ipykernel_X/file.py:X:Y\n  %y.16 : Tensor = aten::add(%11, %9, %2) # <string>:X:Y\n  %13 : int = aten::IntImplicit(%y.16) # /tmp/ipykernel_X/file.py:X:Y\n  %14 : int = prim::Constant[value=4]()\n  %15 : Tensor = aten::select(%x.1, %y.1, %14) # /tmp/ipykernel_X/file.py:X:Y\n  %y.18 : Tensor = aten::

In [13]:
g2_norm

'graph(%self : __torch__.___torch_mangle_XXX.Cls,\n      %x.1 : Tensor):\n  %2 : int = prim::Constant[value=1]() # /tmp/ipykernel_X/file.py:X:Y\n  %y.1 : int = prim::Constant[value=0]() # /tmp/ipykernel_X/file.py:X:Y\n  %4 : Tensor = aten::select(%x.1, %y.1, %2) # /tmp/ipykernel_X/file.py:X:Y\n  %5 : int = aten::IntImplicit(%4) # /tmp/ipykernel_X/file.py:X:Y\n  %6 : int = prim::Constant[value=2]()\n  %7 : Tensor = aten::select(%x.1, %y.1, %6) # /tmp/ipykernel_X/file.py:X:Y\n  %y.14 : Tensor = aten::add(%7, %5, %2) # /tmp/ipykernel_X/file.py:X:Y\n  %9 : int = aten::IntImplicit(%y.14) # /tmp/ipykernel_X/file.py:X:Y\n  %10 : int = prim::Constant[value=3]()\n  %11 : Tensor = aten::select(%x.1, %y.1, %10) # /tmp/ipykernel_X/file.py:X:Y\n  %y.16 : Tensor = aten::add(%11, %9, %2) # /tmp/ipykernel_X/file.py:X:Y\n  %13 : int = aten::IntImplicit(%y.16) # /tmp/ipykernel_X/file.py:X:Y\n  %14 : int = prim::Constant[value=4]()\n  %15 : Tensor = aten::select(%x.1, %y.1, %14) # /tmp/ipykernel_X/file.p

In [14]:
scripted_fn = torch.jit.script(M())
graph_1 = scripted_fn.graph
graph_1

graph(%self : __torch__.M,
      %x.1 : Tensor):
  %7 : bool = prim::Constant[value=1]() # /tmp/ipykernel_2944385/401325865.py:5:8
  %35 : int = prim::Constant[value=9]()
  %y.1 : int = prim::Constant[value=0]() # /tmp/ipykernel_2944385/401325865.py:4:12
  %3 : int = prim::Constant[value=1]() # /tmp/ipykernel_2944385/401325865.py:5:23
  %y : int = prim::Loop(%35, %7, %y.1) # /tmp/ipykernel_2944385/401325865.py:5:8
    block0(%8 : int, %y.11 : int):
      %i.1 : int = aten::__derive_index(%8, %3, %3) # /tmp/ipykernel_2944385/401325865.py:5:8
      %16 : Tensor = aten::select(%x.1, %y.1, %i.1) # /tmp/ipykernel_2944385/401325865.py:6:20
      %y.5 : Tensor = aten::add(%16, %y.11, %3) # <string>:5:9
      %20 : int = aten::IntImplicit(%y.5) # /tmp/ipykernel_2944385/401325865.py:6:12
      -> (%7, %20)
  return (%y)

In [15]:

scripted_fn = torch.jit.script(N())
graph_2 = scripted_fn.graph
graph_2

graph(%self : __torch__.N,
      %x.1 : Tensor):
  %6 : bool = prim::Constant[value=1]() # /tmp/ipykernel_2944385/401325865.py:13:8
  %34 : int = prim::Constant[value=9]()
  %y.1 : int = prim::Constant[value=0]() # /tmp/ipykernel_2944385/401325865.py:12:12
  %3 : int = prim::Constant[value=1]() # /tmp/ipykernel_2944385/401325865.py:13:23
  %y : int = prim::Loop(%34, %6, %y.1) # /tmp/ipykernel_2944385/401325865.py:13:8
    block0(%7 : int, %y.11 : int):
      %i.1 : int = aten::__derive_index(%7, %3, %3) # /tmp/ipykernel_2944385/401325865.py:13:8
      %14 : Tensor = aten::select(%x.1, %y.1, %i.1) # /tmp/ipykernel_2944385/401325865.py:14:16
      %y.5 : Tensor = aten::add(%14, %y.11, %3) # /tmp/ipykernel_2944385/401325865.py:14:16
      %19 : int = aten::IntImplicit(%y.5) # /tmp/ipykernel_2944385/401325865.py:14:12
      -> (%6, %19)
  return (%y)

In [16]:
from collections import defaultdict

def extract_graph_semantics(graph):
    nodes = []
    for node in graph.nodes():
        # Record operation type and its inputs and outputs (not source location)
        op = node.kind()
        inputs = [i.debugName() for i in node.inputs()]
        outputs = [o.debugName() for o in node.outputs()]
        nodes.append((op, tuple(inputs), tuple(outputs)))
    return nodes

In [17]:
def extract_functional_ops(graph):
    ops = []
    for node in graph.nodes():
        op_kind = node.kind()
        inputs = [str(i) for i in node.inputs()]
        outputs = [str(o) for o in node.outputs()]
        ops.append(f"{', '.join(outputs)} = {op_kind}({', '.join(inputs)})")
    return ops

In [18]:
def compare_graphs(graph1, graph2):
    return extract_graph_semantics(graph1) == extract_graph_semantics(graph2)

In [19]:
compare_graphs(graph_1, graph_2)

False

In [20]:
def normalize_graph_structure(graph):
    op_seq = []
    name_map = {}
    counter = 0

    def get_id(name):
        nonlocal counter
        if name not in name_map:
            name_map[name] = f"%var{counter}"
            counter += 1
        return name_map[name]

    for node in graph.nodes():
        op = node.kind()
        inputs = tuple(get_id(i.debugName()) for i in node.inputs())
        outputs = tuple(get_id(o.debugName()) for o in node.outputs())
        op_seq.append((op, inputs, outputs))

    return op_seq

In [21]:
normalize_graph_structure(graph_1) == normalize_graph_structure(graph_2)

True

In [22]:
normalize_graph_structure(graph_1) 

[('prim::Constant', (), ('%var0',)),
 ('prim::Constant', (), ('%var1',)),
 ('prim::Constant', (), ('%var2',)),
 ('prim::Constant', (), ('%var3',)),
 ('prim::Loop', ('%var1', '%var0', '%var2'), ('%var4',))]

In [23]:
normalize_graph_structure(graph_2)

[('prim::Constant', (), ('%var0',)),
 ('prim::Constant', (), ('%var1',)),
 ('prim::Constant', (), ('%var2',)),
 ('prim::Constant', (), ('%var3',)),
 ('prim::Loop', ('%var1', '%var0', '%var2'), ('%var4',))]

In [24]:
extract_graph_semantics(graph_1)

[('prim::Constant', (), ('7',)),
 ('prim::Constant', (), ('35',)),
 ('prim::Constant', (), ('y.1',)),
 ('prim::Constant', (), ('3',)),
 ('prim::Loop', ('35', '7', 'y.1'), ('y',))]

In [25]:
extract_functional_ops(graph_1)

['7 defined in (%7 : bool = prim::Constant[value=1]() # /tmp/ipykernel_2944385/401325865.py:5:8\n) = prim::Constant()',
 '35 defined in (%35 : int = prim::Constant[value=9]()\n) = prim::Constant()',
 'y.1 defined in (%y.1 : int = prim::Constant[value=0]() # /tmp/ipykernel_2944385/401325865.py:4:12\n) = prim::Constant()',
 '3 defined in (%3 : int = prim::Constant[value=1]() # /tmp/ipykernel_2944385/401325865.py:5:23\n) = prim::Constant()',
 'y defined in (%y : int = prim::Loop(%35, %7, %y.1) # /tmp/ipykernel_2944385/401325865.py:5:8\n  block0(%8 : int, %y.11 : int):\n    %i.1 : int = aten::__derive_index(%8, %3, %3) # /tmp/ipykernel_2944385/401325865.py:5:8\n    %16 : Tensor = aten::select(%x.1, %y.1, %i.1) # /tmp/ipykernel_2944385/401325865.py:6:20\n    %y.5 : Tensor = aten::add(%16, %y.11, %3) # <string>:5:9\n    %20 : int = aten::IntImplicit(%y.5) # /tmp/ipykernel_2944385/401325865.py:6:12\n    -> (%7, %20)\n) = prim::Loop(35 defined in (%35 : int = prim::Constant[value=9]()\n), 7 de

In [26]:
extract_functional_ops(graph_2)

['6 defined in (%6 : bool = prim::Constant[value=1]() # /tmp/ipykernel_2944385/401325865.py:13:8\n) = prim::Constant()',
 '34 defined in (%34 : int = prim::Constant[value=9]()\n) = prim::Constant()',
 'y.1 defined in (%y.1 : int = prim::Constant[value=0]() # /tmp/ipykernel_2944385/401325865.py:12:12\n) = prim::Constant()',
 '3 defined in (%3 : int = prim::Constant[value=1]() # /tmp/ipykernel_2944385/401325865.py:13:23\n) = prim::Constant()',
 'y defined in (%y : int = prim::Loop(%34, %6, %y.1) # /tmp/ipykernel_2944385/401325865.py:13:8\n  block0(%7 : int, %y.11 : int):\n    %i.1 : int = aten::__derive_index(%7, %3, %3) # /tmp/ipykernel_2944385/401325865.py:13:8\n    %14 : Tensor = aten::select(%x.1, %y.1, %i.1) # /tmp/ipykernel_2944385/401325865.py:14:16\n    %y.5 : Tensor = aten::add(%14, %y.11, %3) # /tmp/ipykernel_2944385/401325865.py:14:16\n    %19 : int = aten::IntImplicit(%y.5) # /tmp/ipykernel_2944385/401325865.py:14:12\n    -> (%6, %19)\n) = prim::Loop(34 defined in (%34 : int 

In [27]:
def extract_operator_sequence(graph):
    return [node.kind() for node in graph.nodes()]

In [28]:
ops1 = extract_operator_sequence(graph_1)
ops2 = extract_operator_sequence(graph_2)

In [29]:
print(ops1 == ops2)

True


In [30]:
ops1

['prim::Constant',
 'prim::Constant',
 'prim::Constant',
 'prim::Constant',
 'prim::Loop']