In [1]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


class MLP(nn.Module):
    """Tiny perceptron: 10 input features -> 32 hidden units (ReLU) -> 2 outputs."""

    def __init__(self):
        super().__init__()
        hidden_units = 32
        self.layers = nn.Sequential(
            nn.Linear(10, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 2),
        )

    def forward(self, x):
        """Apply the Linear -> ReLU -> Linear stack (``x`` assumed ``(batch, 10)``)."""
        return self.layers(x)


class SimpleCNN(nn.Module):
    """Two conv -> ReLU -> 2x2 max-pool stages followed by a two-layer classifier head.

    Args:
        num_classes: Number of output scores produced by ``fc2``.
        in_channels: Channel count of the input images (default 3).
        input_size: Side length of the square input (H == W). Each 3x3 conv
            uses padding=1 and keeps the spatial size. Each 2x2/stride-2
            max-pool floors it by half. The flattened feature size is
            therefore ``32 * (input_size // 4) ** 2``. The default 32 gives the
            original hard-coded ``32 * 8 * 8 = 2048``.

    Raises:
        ValueError: if ``input_size < 4``, because the pooled feature map would be empty.
    """

    def __init__(self, num_classes=4, in_channels=3, input_size=32):
        super().__init__()
        if input_size < 4:
            raise ValueError(
                "input_size must be >= 4 so both 2x2 poolings leave a non-empty feature map"
            )
        # Layers are created in the same order as before, so parameter
        # initialization consumes the RNG identically for the default config.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)

        self.flatten = nn.Flatten()
        pooled_side = input_size // 4  # floor(floor(s / 2) / 2) == s // 4
        self.fc1 = nn.Linear(32 * pooled_side * pooled_side, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map ``(N, in_channels, s, s)`` images to ``(N, num_classes)`` raw scores."""
        x = self.pool1(self.relu1(self.conv1(x)))  # -> (N, 16, s // 2, s // 2)
        x = self.pool2(self.relu2(self.conv2(x)))  # -> (N, 32, s // 4, s // 4)
        x = self.flatten(x)  # -> (N, 32 * (s // 4) ** 2)
        x = self.relu3(self.fc1(x))  # -> (N, 128)
        x = self.fc2(x)  # -> (N, num_classes)
        return x





In [9]:
# --- Init MLP + a matching example input (used for tracing below) ---
model = MLP()
dummy_input = torch.randn(1, 10)  # one sample with the 10 features MLP's first Linear expects
model

MLP(
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=2, bias=True)
  )
)

In [10]:
# Record the MLP's forward pass on the example input as a TorchScript module.
trace = torch.jit.trace(model, example_inputs=dummy_input)

In [11]:
# Show the TorchScript IR of the traced MLP with submodule calls inlined.
trace.inlined_graph

graph(%self.1 : __torch__.___torch_mangle_20.MLP,
      %x : Float(1, 10, strides=[10, 1], requires_grad=0, device=cpu)):
  %layers : __torch__.torch.nn.modules.container.___torch_mangle_19.Sequential = prim::GetAttr[name="layers"](%self.1)
  %_2 : __torch__.torch.nn.modules.linear.___torch_mangle_18.Linear = prim::GetAttr[name="2"](%layers)
  %_1 : __torch__.torch.nn.modules.activation.___torch_mangle_17.ReLU = prim::GetAttr[name="1"](%layers)
  %_0 : __torch__.torch.nn.modules.linear.___torch_mangle_16.Linear = prim::GetAttr[name="0"](%layers)
  %bias.1 : Tensor = prim::GetAttr[name="bias"](%_0)
  %weight.1 : Tensor = prim::GetAttr[name="weight"](%_0)
  %input.1 : Float(1, 32, strides=[32, 1], requires_grad=1, device=cpu) = aten::linear(%x, %weight.1, %bias.1), scope: __module.layers/__module.layers.0 # /home/coder/Playground/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134:0
  %input : Float(1, 32, strides=[32, 1], requires_grad=1, device=cpu) = aten::relu(%input.1)

In [12]:
# Switch to the small CNN; its classifier head is sized for 3x32x32 images.
model = SimpleCNN(num_classes=4)
dummy_input = torch.randn(1, 3, 32, 32)
model

SimpleCNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=2048, out_features=128, bias=True)
  (relu3): ReLU()
  (fc2): Linear(in_features=128, out_features=4, bias=True)
)

In [13]:
# Log the SimpleCNN graph (``model`` was reassigned to SimpleCNN above) plus a
# placeholder scalar. The ``with`` block closes the writer even if tracing
# inside add_graph raises, so pending events are not lost.
with SummaryWriter("runs/simplecnn_test") as writer:
    writer.add_graph(model, dummy_input)  # log model graph
    writer.add_scalar("loss", 0.42, 1)  # placeholder scalar: 0.42 at global step 1

In [14]:
# mypy: allow-untyped-defs
from collections import OrderedDict
import contextlib
from typing import Any

from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef
from google.protobuf.json_format import MessageToDict
import torch
from torch.utils.tensorboard._proto_graph import node_proto

In [16]:


# Accessors of an operator node that NodePyOP copies onto its Python snapshot.
methods_OP = (
    "attributeNames hasMultipleOutputs hasUses inputs kind outputs outputsSize scopeName"
).split()

# Accessors that NodePyIO copies from a graph value. Other accessors worth
# exploring for methods_IO are
#
#   'unique' (type int)
#   'type' (type <Tensor<class 'torch._C.Type'>>)
#
# but the ones below are sufficient for now.
methods_IO = "node offset debugName".split()

# Node kind of TorchScript attribute lookups such as ``prim::GetAttr[name="fc1"]``.
GETATTR_KIND = "prim::GetAttr"
# Type kind of module-typed values (e.g. the ``self`` input); such values are skipped.
CLASSTYPE_KIND = "ClassType"


class NodeBase:
    """Plain record describing one node of the exported graph.

    ``GraphPy.to_proto`` serializes ``debugName``, ``inputs``, ``tensor_size``,
    ``kind`` and ``attributes`` of these records. ``scope`` is used while
    building the scoped names.
    """

    def __init__(
        self,
        debugName=None,
        inputs=None,
        scope=None,
        tensor_size=None,
        op_type="UnSpecified",
        attributes="",
    ):
        # TODO: consider __slots__ or a namedtuple for this record.
        self.debugName = debugName
        self.inputs = inputs
        self.scope = scope
        self.tensor_size = tensor_size
        self.kind = op_type
        self.attributes = attributes

    def __repr__(self):
        # One "name: value<type>" line per non-dunder attribute, in dir() order.
        lines = [str(type(self))]
        for attr_name in dir(self):
            if "__" in attr_name:
                continue
            attr_value = getattr(self, attr_name)
            lines.append(attr_name + ": " + str(attr_value) + str(type(attr_value)))
        return "\n".join(lines) + "\n\n"


class NodePy(NodeBase):
    """Copy selected accessors of a TorchScript node/value onto ``self``.

    For every name in ``valid_methods``, the same-named accessor of ``node_cpp``
    is called and its result stored as an attribute. ``"inputs"`` and
    ``"outputs"`` are special. They are stored as lists of debug names, and a
    parallel ``inputstensor_size`` / ``outputstensor_size`` list records each
    value's ``type().sizes()`` when it is a complete tensor, or ``None``
    otherwise.
    """

    def __init__(self, node_cpp, valid_methods):
        # NodeBase's first positional parameter is debugName, so it starts out as
        # node_cpp itself and is only replaced when "debugName" is in valid_methods.
        super().__init__(node_cpp)
        self.inputs = []

        for method_name in list(valid_methods):
            if method_name not in ("inputs", "outputs"):
                setattr(self, method_name, getattr(node_cpp, method_name)())
                continue

            debug_names = []
            sizes = []
            for value in list(getattr(node_cpp, method_name)()):
                debug_names.append(value.debugName())
                sizes.append(value.type().sizes() if value.isCompleteTensor() else None)
            setattr(self, method_name, debug_names)
            setattr(self, method_name + "tensor_size", sizes)


class NodePyIO(NodePy):
    """Graph input/output (or parameter) value, copied via the ``methods_IO`` accessors.

    ``kind`` is a purely descriptive label shown in the node details of
    TensorBoard's graph plugin. It is ``"IO Node"`` when ``input_or_output`` is
    given (that label is also kept as ``self.input_or_output``), and
    ``"Parameter"`` otherwise. NodePyOP nodes take their kind from ``kind()``
    instead.
    """

    def __init__(self, node_cpp, input_or_output=None):
        super().__init__(node_cpp, methods_IO)
        try:
            self.tensor_size = node_cpp.type().sizes()
        except RuntimeError:
            # Fails when a constant model is used; fall back to a 1-element size.
            self.tensor_size = [1]

        if input_or_output:
            self.input_or_output = input_or_output
            self.kind = "IO Node"
        else:
            self.kind = "Parameter"


class NodePyOP(NodePy):
    """Python-side snapshot of a TorchScript operator node (``methods_OP`` accessors).

    ``attributes`` is the stringified ``{name: value}`` dict of the node's
    attributes, with single quotes replaced by spaces. ``kind`` is the
    node's own ``kind()`` string (e.g. ``aten::linear``).
    """

    def __init__(self, node_cpp):
        super().__init__(node_cpp, methods_OP)
        # Replace single quote which causes strange behavior in TensorBoard
        # TODO: See if we can remove this in the future
        self.attributes = str(
            {k: _node_get(node_cpp, k) for k in node_cpp.attributeNames()}
        ).replace("'", " ")
        # (A leftover debug print of self.attributes was removed here; it
        # spammed stdout once per operator node during graph export.)
        self.kind = node_cpp.kind()


class GraphPy:
    """Helper class to convert torch.nn.Module to GraphDef proto and visualization with TensorBoard.

    GraphDef generation operates in two passes:

    In the first pass (``parse`` calling ``append``), all nodes are read and
    stored in two containers. ``nodes_io`` is an OrderedDict keyed by debug name
    and holds input/output nodes, which only have inbound or outbound
    connections, but not both. ``nodes_op`` is a list of internal operator
    nodes.

    In the second pass (``populate_namespace_from_OP_to_IO``), every operator
    output is also registered in ``nodes_io``. The scope name of each op output
    is recorded in ``scope_name_appeared``, and scope names are applied to all
    nodes. ``unique_name_to_scoped_name`` maps a node's debug name to its fully
    qualified scope name, e.g. ``Net1/Linear[0]/1``. Unfortunately torch.jit
    doesn't have totally correct scope output, so this is nontrivial.
    ``populate_namespace_from_OP_to_IO`` and ``find_common_root`` assign scope
    names from the connections between nodes in a heuristic way, with
    bookkeeping in ``shallowest_scope_name`` and ``scope_name_appeared``.
    """

    def __init__(self):
        # Operator nodes, in the order they were appended.
        self.nodes_op = []
        # debugName -> node; its insertion order is the order to_proto emits nodes in.
        self.nodes_io = OrderedDict()
        # debugName -> scoped name, filled by populate_namespace_from_OP_to_IO.
        self.unique_name_to_scoped_name = {}
        # Fallback prefix for nodes whose scope is the empty string.
        self.shallowest_scope_name = "default"
        # scopeName of each operator output, in visiting order.
        self.scope_name_appeared = []

    def append(self, x):
        """Route ``x`` into ``nodes_io`` (keyed by debugName) or ``nodes_op``.

        Objects that are neither NodePyIO nor NodePyOP are silently ignored.
        """
        if isinstance(x, NodePyIO):
            self.nodes_io[x.debugName] = x
        if isinstance(x, NodePyOP):
            self.nodes_op.append(x)

    def printall(self):
        """Debug helper: print every op node, then every IO node."""
        print("all nodes")
        for node in self.nodes_op:
            print(node)
        for key in self.nodes_io:
            print(self.nodes_io[key])

    def find_common_root(self):
        """Set ``shallowest_scope_name`` from ``scope_name_appeared``.

        Despite the name, no common prefix is computed. The loop keeps
        overwriting, so the result is the first ``/``-separated component of the
        LAST non-empty entry. If there is none, the previous value (initially
        ``"default"``) is kept.
        """
        for fullscope in self.scope_name_appeared:
            if fullscope:
                self.shallowest_scope_name = fullscope.split("/")[0]

    def populate_namespace_from_OP_to_IO(self):
        """Register op outputs as IO-style records and rewrite all names to scoped names.

        Mutates ``nodes_io`` in place. Each entry's ``inputs`` and (when mapped)
        ``debugName`` are replaced by scoped names such as
        ``SimpleCNN/Conv2d[conv1]/164``.
        """
        # Step 1: every op output becomes a NodeBase entry in nodes_io, keyed by
        # the output's debug name. It carries the op's inputs, scope, kind and
        # attributes, plus the size of that particular output.
        for node in self.nodes_op:
            for node_output, outputSize in zip(node.outputs, node.outputstensor_size):
                self.scope_name_appeared.append(node.scopeName)
                self.nodes_io[node_output] = NodeBase(
                    node_output,
                    node.inputs,
                    node.scopeName,
                    outputSize,
                    op_type=node.kind,
                    attributes=node.attributes,
                )

        self.find_common_root()

        # Step 2: provisionally scope every op input under the consuming op's
        # scope (a later consumer overwrites an earlier one).
        for node in self.nodes_op:
            for input_node_id in node.inputs:
                self.unique_name_to_scoped_name[input_node_id] = (
                    node.scopeName + "/" + input_node_id
                )

        # Step 3: override with each nodes_io entry's own name. Within one
        # iteration the later assignment wins. Graph IO nodes get an
        # "input/..." or "output/..." prefix. NodePyIO objects keep scope=None
        # from NodeBase's default, so the scope branch below does not
        # override that prefix.
        for key, node in self.nodes_io.items():
            if type(node) == NodeBase:
                self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
            if hasattr(node, "input_or_output"):
                self.unique_name_to_scoped_name[key] = (
                    node.input_or_output + "/" + node.debugName
                )

            if hasattr(node, "scope") and node.scope is not None:
                self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
                # Scope-less nodes fall back to the shallowest recorded scope.
                if node.scope == "" and self.shallowest_scope_name:
                    self.unique_name_to_scoped_name[node.debugName] = (
                        self.shallowest_scope_name + "/" + node.debugName
                    )

        # replace name
        # Step 4: rewrite inputs and debug names through the mapping. Every
        # referenced input id is expected to have been registered above; an
        # unregistered one would raise KeyError here.
        for key, node in self.nodes_io.items():
            self.nodes_io[key].inputs = [
                self.unique_name_to_scoped_name[node_input_id]
                for node_input_id in node.inputs
            ]
            if node.debugName in self.unique_name_to_scoped_name:
                self.nodes_io[key].debugName = self.unique_name_to_scoped_name[
                    node.debugName
                ]

    def to_proto(self):
        """Convert graph representation of GraphPy object to TensorBoard required format.

        Only ``nodes_io`` entries are emitted. Operator nodes appear through the
        NodeBase records created for their outputs. Returns a list of
        ``node_proto`` results, in ``nodes_io`` order.
        """
        # TODO: compute correct memory usage and CPU time once
        # PyTorch supports it
        nodes = [
            node_proto(
                v.debugName,
                input=v.inputs,
                outputsize=v.tensor_size,
                op=v.kind,
                attributes=v.attributes,
            )
            for v in self.nodes_io.values()
        ]
        return nodes


def parse(graph, trace, args=None, omit_useless_nodes=True):
    """Parse an optimized PyTorch model graph and produce a list of node protos.

    Useful for eventual conversion to TensorBoard protobuf format.

    Args:
      graph (torch._C.Graph): The TorchScript graph to be parsed. ``graph()``
        passes ``trace.graph`` after inlining it.
      trace (PyTorch JIT TracedModule): The model trace. Here it is only used to
        map module aliases (``__module.fc1``) to readable names
        (``Linear[fc1]``).
      args (tuple): input tensor[s] for the model. Not read by this function.
      omit_useless_nodes (boolean): Whether to skip graph inputs that have no
        uses.

    Returns:
      list: the result of ``GraphPy.to_proto()``, i.e. one ``node_proto`` per
        ``nodes_io`` entry after scoped names are applied.
    """
    nodes_py = GraphPy()
    # Graph inputs become "input" IO nodes. Module-typed inputs (e.g. ``self``)
    # are skipped.
    for node in graph.inputs():
        if omit_useless_nodes:
            if (
                len(node.uses()) == 0
            ):  # number of user of the node (= number of outputs/ fanout)
                continue

        if node.type().kind() != CLASSTYPE_KIND:
            nodes_py.append(NodePyIO(node, "input"))

    # attr_to_scope: GetAttr output debug name -> module-alias path, e.g.
    # "__module.layers/__module.layers.0" for ``self.layers[0]``.
    attr_to_scope: dict[Any, str] = {}
    for node in graph.nodes():
        if node.kind() == GETATTR_KIND:
            attr_name = node.s("name")
            attr_key = node.output().debugName()
            parent = node.input().node()
            if (
                parent.kind() == GETATTR_KIND
            ):  # If the parent node is not the top-level "self" node
                parent_attr_key = parent.output().debugName()
                parent_scope = attr_to_scope[parent_attr_key]
                attr_scope = parent_scope.split("/")[-1]
                attr_to_scope[attr_key] = f"{parent_scope}/{attr_scope}.{attr_name}"
            else:
                attr_to_scope[attr_key] = f"__module.{attr_name}"
            # We don't need classtype nodes; scope will provide this information
            # Only non-module GetAttr results (e.g. weight/bias tensors) become
            # op nodes, scoped under the module they were read from.
            if node.output().type().kind() != CLASSTYPE_KIND:
                node_py = NodePyOP(node)
                node_py.scopeName = attr_to_scope[attr_key]  # type: ignore[attr-defined]
                nodes_py.append(node_py)
        else:
            nodes_py.append(NodePyOP(node))

    # Each graph output gets a sink node "output.<i>" whose single input is
    # the value being returned.
    for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
        node_pyio = NodePyIO(node, "output")
        node_pyio.debugName = f"output.{i + 1}"
        node_pyio.inputs = [node.debugName()]
        nodes_py.append(node_pyio)

    def parse_traced_name(module):
        """Readable class name of a (traced) module: ``_name`` or ``original_name``."""
        if isinstance(module, torch.jit.TracedModule):
            module_name = module._name
        else:
            module_name = getattr(module, "original_name", "Module")
        return module_name

    # Map every module alias (e.g. "__module.fc1") to "ClassName[attr]" (e.g. "Linear[fc1]").
    alias_to_name = {}
    base_name = parse_traced_name(trace)
    for name, module in trace.named_modules(prefix="__module"):
        mod_name = parse_traced_name(module)
        attr_name = name.split(".")[-1]
        alias_to_name[name] = f"{mod_name}[{attr_name}]"

    # Rewrite each op's alias path into a readable scope rooted at the model
    # class name. Unknown aliases fall back to their last dotted component.
    # An all-empty path leaves just base_name.
    for node in nodes_py.nodes_op:
        module_aliases = node.scopeName.split("/")
        replacements = [
            alias_to_name[alias] if alias in alias_to_name else alias.split(".")[-1]
            for alias in module_aliases
        ]
        node.scopeName = base_name
        if any(replacements):
            node.scopeName += "/" + "/".join(replacements)

    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py.to_proto()


def graph(model, args, verbose=False, use_strict_trace=True):
    """
    Process a PyTorch model and produce a `GraphDef` proto that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed. Unless it is a
        ``torch.jit.ScriptFunction``, it is traced in eval mode and its
        original training flag is restored afterwards.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print the inlined TorchScript graph before
        parsing it.
      use_strict_trace (bool): Whether to pass keyword argument `strict` to
        `torch.jit.trace`. Pass False when you want the tracer to
        record your mutable container types (list, dict)

    Returns:
      tuple: ``(GraphDef, RunMetadata)``. The RunMetadata only carries a
        placeholder CPU device entry (see the comment before the return).

    Raises:
      RuntimeError: re-raised with its original traceback when tracing fails
        (the error is printed first).
    """
    with _set_model_to_eval(model):
        try:
            trace = torch.jit.trace(model, args, strict=use_strict_trace)
            jit_graph = trace.graph
            # Inline submodule method calls into the top-level graph.
            torch._C._jit_pass_inline(jit_graph)
        except RuntimeError as e:
            print(e)
            print("Error occurs, No graph saved")
            # Bare raise keeps the original traceback intact.
            raise

    if verbose:
        print(jit_graph)
    list_of_nodes = parse(jit_graph, trace, args)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(
        step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])
    )
    # The producer version has been reverse engineered from standard
    # TensorBoard logged data.
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats


@contextlib.contextmanager
def _set_model_to_eval(model):
    """Context manager to temporarily set the training mode of ``model`` to eval."""
    if not isinstance(model, torch.jit.ScriptFunction):
        originally_training = model.training
        model.train(False)
        try:
            yield
        finally:
            model.train(originally_training)
    else:
        # Do nothing for ScriptFunction
        try:
            yield
        finally:
            pass


def _node_get(node: torch._C.Node, key: str):
    """Get attributes of a node which is polymorphic over return type."""
    sel = node.kindOf(key)
    return getattr(node, sel)(key)


In [17]:
import base64
from typing import Dict, Any

def try_decode_base64_attr(s: Any) -> Any:
    """Decode ``s`` from Base64 when the result looks like readable attribute text.

    The decoded text is accepted only when all of these hold:

    * ``s`` is a ``str``;
    * ``s`` is strictly valid Base64. ``validate=True`` rejects characters
      outside the Base64 alphabet, such as the braces and spaces of an
      already-readable ``"{ value : 1}"``;
    * the decoded bytes are valid UTF-8;
    * the decoded text contains ``{``, ``}``, ``:`` or ``name``.

    Args:
        s: Candidate value. Non-``str`` values are returned unchanged.

    Returns:
        The stripped decoded text when every check passes, otherwise ``s``
        unchanged.
    """
    if not isinstance(s, str):
        return s

    try:
        decoded_str = base64.b64decode(s, validate=True).decode("utf-8")
    except ValueError:
        # binascii.Error (bad alphabet/padding), non-ASCII input and
        # UnicodeDecodeError are all ValueError subclasses: s is not an
        # encoded payload, so keep it as-is.
        return s

    # Heuristic: accept only decodings that look dict-like or named, which
    # reduces the chance of mangling plain strings that are valid Base64.
    if any(marker in decoded_str for marker in ("{", "}", ":", "name")):
        return decoded_str.strip()
    return s

def clean_graphdef_attributes(data_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Decode Base64 attribute strings in a ``MessageToDict``-style GraphDef dict.

    Walks ``data_dict['node'][*]['attr'][*]['s']`` and replaces each value with
    ``try_decode_base64_attr(value)`` whenever the two differ. The dict is
    modified IN PLACE and the very same object is returned, so every existing
    reference to ``data_dict`` sees the decoded values.
    """
    if 'node' not in data_dict:
        return data_dict

    for node in data_dict['node']:
        if 'attr' not in node:
            continue
        # Only the attribute payloads are touched; attribute names are not needed.
        for attr_data in node['attr'].values():
            if 's' not in attr_data:
                continue
            original = attr_data['s']
            decoded = try_decode_base64_attr(original)
            if decoded != original:
                attr_data['s'] = decoded

    return data_dict

In [18]:
# Convert SimpleCNN to a GraphDef with the helper above. verbose=True prints the
# inlined TorchScript graph; element [0] is the GraphDef, [1] the RunMetadata.
model_graph = graph(model, dummy_input, verbose=True)[0]

graph(%self.1 : __torch__.___torch_mangle_54.SimpleCNN,
      %x : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu)):
  %fc2 : __torch__.torch.nn.modules.linear.___torch_mangle_53.Linear = prim::GetAttr[name="fc2"](%self.1)
  %relu3 : __torch__.torch.nn.modules.activation.___torch_mangle_52.ReLU = prim::GetAttr[name="relu3"](%self.1)
  %fc1 : __torch__.torch.nn.modules.linear.___torch_mangle_51.Linear = prim::GetAttr[name="fc1"](%self.1)
  %flatten : __torch__.torch.nn.modules.flatten.___torch_mangle_50.Flatten = prim::GetAttr[name="flatten"](%self.1)
  %pool2 : __torch__.torch.nn.modules.pooling.___torch_mangle_49.MaxPool2d = prim::GetAttr[name="pool2"](%self.1)
  %relu2 : __torch__.torch.nn.modules.activation.___torch_mangle_48.ReLU = prim::GetAttr[name="relu2"](%self.1)
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_47.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %pool1 : __torch__.torch.nn.modules.pooling.___torch_mangle_46.MaxPool2d = pr

In [19]:
# Display the SimpleCNN GraphDef proto produced by graph() above.
model_graph

node {
  name: "input/x"
  op: "IO Node"
  attr {
    key: "attr"
    value {
      s: ""
    }
  }
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 1
          }
          dim {
            size: 3
          }
          dim {
            size: 32
          }
          dim {
            size: 32
          }
        }
      }
    }
  }
}
node {
  name: "output/output.1"
  op: "IO Node"
  input: "SimpleCNN/Linear[fc2]/215"
  attr {
    key: "attr"
    value {
      s: ""
    }
  }
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 1
          }
          dim {
            size: 4
          }
        }
      }
    }
  }
}
node {
  name: "SimpleCNN/Conv2d[conv1]/164"
  op: "prim::Constant"
  attr {
    key: "attr"
    value {
      s: "{ value : 1}"
    }
  }
}
node {
  name: "SimpleCNN/Conv2d[conv1]/165"
  op: "prim::Constant"
  attr {
    key: "attr"
    value {
      s

In [20]:
# Convert the GraphDef proto to nested dicts/lists. Protobuf's JSON mapping
# renders bytes fields as Base64, so the attribute 's' payloads come out encoded;
# clean_graphdef_attributes decodes them back to readable text.
data_dict = MessageToDict(model_graph)
cleaned_data_dict = clean_graphdef_attributes(data_dict)
# The cleanup works in place and returns the same object. Later cells inspect
# and save `data_dict`, so fail loudly if that ever stops being true.
assert cleaned_data_dict is data_dict

In [21]:
# Inspect the SimpleCNN dict (decoded in place by clean_graphdef_attributes).
data_dict

{'node': [{'name': 'input/x',
   'op': 'IO Node',
   'attr': {'attr': {'s': ''},
    '_output_shapes': {'list': {'shape': [{'dim': [{'size': '1'},
         {'size': '3'},
         {'size': '32'},
         {'size': '32'}]}]}}}},
  {'name': 'output/output.1',
   'op': 'IO Node',
   'input': ['SimpleCNN/Linear[fc2]/215'],
   'attr': {'attr': {'s': ''},
    '_output_shapes': {'list': {'shape': [{'dim': [{'size': '1'},
         {'size': '4'}]}]}}}},
  {'name': 'SimpleCNN/Conv2d[conv1]/164',
   'op': 'prim::Constant',
   'attr': {'attr': {'s': '{ value : 1}'}}},
  {'name': 'SimpleCNN/Conv2d[conv1]/165',
   'op': 'prim::Constant',
   'attr': {'attr': {'s': '{ value : 0}'}}},
  {'name': 'SimpleCNN/Conv2d[conv1]/166',
   'op': 'prim::Constant',
   'attr': {'attr': {'s': '{ value : 0}'}}},
  {'name': 'SimpleCNN/Conv2d[conv1]/167',
   'op': 'prim::Constant',
   'attr': {'attr': {'s': '{ value : 1}'}}},
  {'name': 'SimpleCNN/Conv2d[conv1]/bias/bias.5',
   'op': 'prim::GetAttr',
   'input': ['Simpl

In [22]:
import json

# Persist the cleaned SimpleCNN graph under a model-specific name, so that the
# later ResNet-18 export cannot overwrite it.
with open("simplecnn_graph.json", "w", encoding="utf-8") as f:
    json.dump(cleaned_data_dict, f, ensure_ascii=False, indent=4)

In [23]:
import torchvision

# NOTE: despite the variable name, this is ResNet-18 (not VGG). `resnet18_model`
# is the accurately named handle; `vgg` is kept because later cells use it.
# Loading the pretrained weights presumably needs network access (or a local cache).
resnet18_model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
vgg = resnet18_model
vgg

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [24]:
# A single ImageNet-sized (3x224x224) example input for tracing.
dummy_input = torch.randn(1, 3, 224, 224)

# Log the graph with TensorBoard's built-in exporter. The `with` block closes
# the writer, which flushes the event file to disk.
with SummaryWriter("runs/resnet18_test") as writer:
    writer.add_graph(vgg, dummy_input)

# Build the same GraphDef with the local helper so it can be inspected here.
model_graph = graph(vgg, dummy_input, verbose=True)[0]

graph(%self.1 : __torch__.torchvision.models.resnet.___torch_mangle_265.ResNet,
      %x.1 : Float(1, 3, 224, 224, strides=[150528, 50176, 224, 1], requires_grad=0, device=cpu)):
  %fc : __torch__.torch.nn.modules.linear.___torch_mangle_264.Linear = prim::GetAttr[name="fc"](%self.1)
  %avgpool : __torch__.torch.nn.modules.pooling.___torch_mangle_263.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
  %layer4 : __torch__.torch.nn.modules.container.___torch_mangle_262.Sequential = prim::GetAttr[name="layer4"](%self.1)
  %layer3 : __torch__.torch.nn.modules.container.___torch_mangle_246.Sequential = prim::GetAttr[name="layer3"](%self.1)
  %layer2 : __torch__.torch.nn.modules.container.___torch_mangle_230.Sequential = prim::GetAttr[name="layer2"](%self.1)
  %layer1 : __torch__.torch.nn.modules.container.___torch_mangle_214.Sequential = prim::GetAttr[name="layer1"](%self.1)
  %maxpool : __torch__.torch.nn.modules.pooling.___torch_mangle_201.MaxPool2d = prim::GetAttr[name="maxpool"]

In [25]:
# Display the ResNet-18 GraphDef proto produced by graph() above.
model_graph

node {
  name: "input/x.1"
  op: "IO Node"
  attr {
    key: "attr"
    value {
      s: ""
    }
  }
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 1
          }
          dim {
            size: 3
          }
          dim {
            size: 224
          }
          dim {
            size: 224
          }
        }
      }
    }
  }
}
node {
  name: "output/output.1"
  op: "IO Node"
  input: "ResNet/Linear[fc]/1581"
  attr {
    key: "attr"
    value {
      s: ""
    }
  }
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 1
          }
          dim {
            size: 1000
          }
        }
      }
    }
  }
}
node {
  name: "ResNet/Conv2d[conv1]/1223"
  op: "prim::Constant"
  attr {
    key: "attr"
    value {
      s: "{ value : 1}"
    }
  }
}
node {
  name: "ResNet/Conv2d[conv1]/1224"
  op: "prim::Constant"
  attr {
    key: "attr"
    value {
      s

In [26]:


# Convert the ResNet-18 GraphDef to nested dicts/lists. MessageToDict emits the
# bytes-typed attribute 's' payloads as Base64; clean_graphdef_attributes
# decodes them back to readable text.
data_dict = MessageToDict(model_graph)
cleaned_data_dict = clean_graphdef_attributes(data_dict)
# The cleanup works in place and returns the same object. Later cells inspect
# and save `data_dict`, so fail loudly if that ever stops being true.
assert cleaned_data_dict is data_dict

In [21]:
# data_dict = MessageToDict(model_graph)

In [27]:
# Inspect the ResNet-18 dict (decoded in place by clean_graphdef_attributes).
data_dict

{'node': [{'name': 'input/x.1',
   'op': 'IO Node',
   'attr': {'attr': {'s': ''},
    '_output_shapes': {'list': {'shape': [{'dim': [{'size': '1'},
         {'size': '3'},
         {'size': '224'},
         {'size': '224'}]}]}}}},
  {'name': 'output/output.1',
   'op': 'IO Node',
   'input': ['ResNet/Linear[fc]/1581'],
   'attr': {'attr': {'s': ''},
    '_output_shapes': {'list': {'shape': [{'dim': [{'size': '1'},
         {'size': '1000'}]}]}}}},
  {'name': 'ResNet/Conv2d[conv1]/1223',
   'op': 'prim::Constant',
   'attr': {'attr': {'s': '{ value : 1}'}}},
  {'name': 'ResNet/Conv2d[conv1]/1224',
   'op': 'prim::Constant',
   'attr': {'attr': {'s': '{ value : 0}'}}},
  {'name': 'ResNet/Conv2d[conv1]/1225',
   'op': 'prim::Constant',
   'attr': {'attr': {'s': '{ value : 0}'}}},
  {'name': 'ResNet/Conv2d[conv1]/1226',
   'op': 'prim::Constant',
   'attr': {'attr': {'s': '{ value : 1}'}}},
  {'name': 'ResNet/Conv2d[conv1]/1227',
   'op': 'prim::Constant',
   'attr': {'attr': {'s': '{ val

In [28]:
import json

# Persist the cleaned ResNet-18 graph under a model-specific name, so that it
# does not clobber the SimpleCNN export written earlier.
with open("resnet18_graph.json", "w", encoding="utf-8") as f:
    json.dump(cleaned_data_dict, f, ensure_ascii=False, indent=4)

---