In [27]:
import torch
import torch.nn as nn
import torch.fx as fx
import hashlib
from collections import defaultdict


def hash_group(node_inputs):
    """Return a short, order-insensitive fingerprint of a set of input names.

    The names are sorted and packed into a tuple; the tuple's string form is
    MD5-hashed and the first 8 hex characters of the digest are returned.
    Two collections holding the same names in any order therefore map to
    the same value.
    """
    canonical = str(tuple(sorted(node_inputs)))
    digest = hashlib.md5(canonical.encode()).hexdigest()
    return digest[:8]


def trace_and_export(
    model: nn.Module, input_example: torch.Tensor, save_path="model_structure.json"
):
    """Symbolically trace ``model`` with torch.fx and dump its graph as JSON.

    The written file has the shape ``{"nodes": [...], "edges": [...]}``.

    Each node record contains:
        * ``id``       -- ``"{op}_{name}_{index}"``, index being the node's
          position in the traced graph.
        * ``name``, ``op`` -- copied from the fx node.
        * ``target``   -- ``str(node.target)``.
        * ``inputs``   -- ids of the nodes this node consumes.
        * ``outputs``  -- ids of the nodes that consume this node.
        * ``group_id`` -- ``hash_group(inputs)``; nodes fed by exactly the same
          set of producers share a group (outputs are not part of the key).
          Nodes with no inputs all hash the empty tuple and so share one group.

    Each edge record is ``{"from": producer_id, "to": consumer_id}``.

    Args:
        model: Module to trace; it must be traceable by ``fx.symbolic_trace``.
        input_example: Currently unused -- symbolic tracing does not run the
            model on concrete data. Kept so existing call sites keep working.
        save_path: Destination path of the JSON file.

    Returns:
        None. The function writes ``save_path`` and prints a confirmation line.
    """
    import json

    traced = fx.symbolic_trace(model)
    node_to_id = {}
    nodes = []
    edges = []
    # producer id -> ids of the nodes that consume it
    consumers = defaultdict(list)

    for idx, node in enumerate(traced.graph.nodes):
        node_id = f"{node.op}_{node.name}_{idx}"
        node_to_id[node] = node_id
        # Producers are expected to appear before their consumers in the graph,
        # so they normally already have an id; str(inp) is only a fallback.
        inputs = [node_to_id.get(inp, str(inp)) for inp in node.all_input_nodes]

        nodes.append(
            {
                "id": node_id,
                "name": node.name,
                "op": node.op,
                "target": str(node.target),
                "inputs": inputs,
                "outputs": [],  # filled in once every consumer is known
                # Hash once, at creation time; key order in the JSON is unchanged.
                "group_id": hash_group(inputs),
            }
        )
        for inp in inputs:
            edges.append({"from": inp, "to": node_id})
            consumers[inp].append(node_id)

    # A node's consumers are only complete after the whole graph was walked.
    for record in nodes:
        record["outputs"] = consumers[record["id"]]

    # json.dump escapes non-ASCII by default, so the explicit encoding does not
    # change the bytes written; it just removes the platform-default dependency.
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump({"nodes": nodes, "edges": edges}, f, indent=2)

    print(f"Saved model structure with layout groups to {save_path}")

In [28]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input;
    the attention-weighted values go through one more linear projection.
    """

    def __init__(self, d_model):
        """Create the four ``d_model -> d_model`` linear projections."""
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the order in which the RNG is consumed.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply attention of ``x`` over itself and return the projected result."""
        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)

        # Pairwise similarity, scaled by sqrt of the size of the last dimension.
        scores = queries @ keys.transpose(-2, -1)
        scale = queries.size(-1) ** 0.5
        weights = (scores / scale).softmax(dim=-1)

        return self.out_proj(weights @ values)

In [29]:
# Build a self-attention module with d_model=64.
model = SelfAttention(64)
# Example tensor of shape (1, 10, 64); its last dimension matches d_model.
dummy_input = torch.randn(1, 10, 64)
# No save_path is passed, so the default "model_structure.json" is written.
trace_and_export(model, dummy_input)

Saved model structure with layout groups to model_structure.json
