In [1]:
"""
AST graphical representation module.

This module provides utilities for converting an Abstract Syntax Tree (AST),
represented as a nested Python dictionary, into a Graphviz DOT graph that
can be displayed inline in a Jupyter notebook, or into an ASCII representation
that can be printed directly to the console.
"""

from __future__ import annotations

import re
import types

from typing import Optional, cast

from graphviz import Digraph, Source
from IPython.display import Image, display  # type: ignore[attr-defined]
from msgpack import dumps, loads

from astx.types import ReprStruct


In [2]:
import astx

# Initialize the ASTx module
module = astx.Module()

# Define the Fibonacci function prototype: fib(n: Int32) -> Int32
fib_proto = astx.FunctionPrototype(
    name="fib",
    args=astx.Arguments(astx.Argument("n", astx.Int32)),
    return_type=astx.Int32,
)

# Create the function body block
fib_block = astx.Block()

# Declare every local the function uses. The loop temporary `sum` is
# declared here too, because the loop body assigns to it.
decl_a = astx.VariableDeclaration(name="a", type_=astx.Int32, value=astx.LiteralInt32(0))
decl_b = astx.VariableDeclaration(name="b", type_=astx.Int32, value=astx.LiteralInt32(1))
decl_i = astx.VariableDeclaration(name="i", type_=astx.Int32, value=astx.LiteralInt32(2))
decl_sum = astx.VariableDeclaration(name="sum", type_=astx.Int32, value=astx.LiteralInt32(0))

# Initialize the block with declarations
fib_block.append(decl_a)
fib_block.append(decl_b)
fib_block.append(decl_i)
fib_block.append(decl_sum)

# Loop condition: i < n
cond = astx.BinaryOp(
    op_code="<",
    lhs=astx.Variable(name="i"),
    rhs=astx.Variable(name="n"),
)

# Loop body: sum = a + b; a = b; b = sum; i = i + 1
loop_block = astx.Block()
assign_sum = astx.VariableAssignment(
    name="sum",
    value=astx.BinaryOp(
        op_code="+",
        lhs=astx.Variable(name="a"),
        rhs=astx.Variable(name="b"),
    ),
)
assign_a = astx.VariableAssignment(name="a", value=astx.Variable(name="b"))
assign_b = astx.VariableAssignment(name="b", value=astx.Variable(name="sum"))
inc_i = astx.VariableAssignment(
    name="i",
    value=astx.BinaryOp(
        op_code="+",
        lhs=astx.Variable(name="i"),
        rhs=astx.LiteralInt32(1),
    ),
)

# Add assignments to the loop body
loop_block.append(assign_sum)
loop_block.append(assign_a)
loop_block.append(assign_b)
loop_block.append(inc_i)

# Create the loop statement
loop = astx.While(condition=cond, body=loop_block)
fib_block.append(loop)

# Return b
return_stmt = astx.FunctionReturn(astx.Variable(name="b"))
fib_block.append(return_stmt)

# Define the function with its body
fib_fn = astx.Function(prototype=fib_proto, body=fib_block)

# Append the Fibonacci function to the module block
module.block.append(fib_fn)

# Display the module's structure. This is the bare last expression,
# without a trailing ';', so Jupyter actually renders it.
module

In [3]:
# Adding two Int32 literals with `+` produces a binary-operation node.
lhs_literal = astx.LiteralInt32(1)
rhs_literal = astx.LiteralInt32(3)
binop = lhs_literal + rhs_literal

# Inspect the nested-dict structure returned by get_struct().
_s = binop.get_struct()
_s

{'BINARY[+]': {'content': {'lhs': {'Literal[Int32]: 1': {'content': 1,
     'metadata': {'loc': {line: -1, col: -1},
      'comment': '',
      'ref': '7811d2bd13cf44a5b698cb2a6bc5cf05',
      'kind': <ASTKind.GenericKind: -100>}}},
   'rhs': {'Literal[Int32]: 3': {'content': 3,
     'metadata': {'loc': {line: -1, col: -1},
      'comment': '',
      'ref': '577619d108bc4fa38cd47c4b126a7d00',
      'kind': <ASTKind.GenericKind: -100>}}}},
  'metadata': {'loc': {line: -1, col: -1},
   'comment': '',
   'ref': '',
   'kind': <ASTKind.BinaryOpKind: -301>}}}

In [10]:
# Plain-text repr of the module; in the recorded run it evaluated to ''.
# NOTE(review): looks like a leftover inspection cell — confirm it is still needed.
repr(module)

''

def traverse_ast_png(
    ast: ReprStruct,
    graph: Optional[Digraph] = None,
    parent: Optional[str] = None,
    shape: str = "box",
    edge_label: str = "",
) -> Source:
    """
    Traverse the AST and build a Graphviz graph for png representation.

    The recursive walk only adds nodes and edges to ``graph``. The Graphviz
    ``unflatten`` step is applied exactly once, to the finished graph.

    Parameters
    ----------
    ast : dict
        The AST as a nested dictionary (full structure version, e.g. from
        ``get_struct(simplified=False)``). A node maps its label to a dict
        with ``"content"`` and ``"metadata"`` entries. A dict entry without
        (or with empty) ``"metadata"`` is treated as a named edge (such as
        ``"lhs"``/``"rhs"``) whose value holds the child node(s).
    graph : Digraph, optional
        The Graphviz graph to draw into. When omitted, a new top-to-bottom
        ``Digraph`` is created.
    parent : str, optional
        Name of an existing graph node that the top-level AST nodes are
        connected to. By default (None) they are not connected to anything.
    shape : str
        The shape used for the nodes in the graph, e.g. "ellipse", "box",
        "circle", "diamond". Default "box".
    edge_label : str
        Label of the edges from ``parent`` to the top-level AST nodes.
        Empty (the default) means unlabeled edges.

    Returns
    -------
    Source
        ``graph.unflatten(stagger=3)``: the graph passed through the
        Graphviz ``unflatten`` preprocessor.
    """
    if graph is None:
        graph = Digraph()
        graph.attr(rankdir="TB")

    if isinstance(ast, dict):
        _add_ast_nodes(ast, graph, parent, shape, edge_label)

    # unflatten pipes the DOT source through an external tool, so it is run
    # once here rather than at every level of the recursion.
    return graph.unflatten(stagger=3)


def _add_ast_nodes(
    ast: dict,
    graph: Digraph,
    parent: Optional[str],
    shape: str,
    edge_label: str,
) -> None:
    """Recursively add the nodes and edges of ``ast`` to ``graph`` in place."""
    for key, full_value in ast.items():
        if not isinstance(full_value, dict):
            continue

        metadata = full_value.get("metadata", {})
        if not metadata:
            # No metadata: ``key`` names an edge (e.g. "lhs"), not a node.
            _add_ast_nodes(full_value, graph, parent, shape, key)
            continue

        # Give every AST node its own graph node. Names derived from
        # key/ref/content would merge identical nodes whose ``ref`` is empty
        # (e.g. two ``Variable("i")`` nodes), and str hashes change between
        # interpreter runs. graphviz appends one line to ``graph.body`` per
        # node/edge call, so its current length is a fresh, deterministic id.
        node_name = f"ast_node_{len(graph.body)}"
        graph.node(node_name, key, shape=shape)

        if parent:
            edge_attrs = {"label": edge_label} if edge_label else {}
            graph.edge(parent, node_name, **edge_attrs)

        content = full_value.get("content", "")
        if isinstance(content, dict):
            _add_ast_nodes(content, graph, node_name, shape, "")
        elif isinstance(content, list):
            for item in content:
                if isinstance(item, dict):
                    _add_ast_nodes(item, graph, node_name, shape, "")

# Render the full (non-simplified) structure of the Fibonacci module.
module_struct = module.get_struct(simplified=False)
traverse_ast_png(module_struct)