In [1]:
%pip install astpretty


Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting astpretty
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9f/0a/79fff71a08bc0cc427f0dfbd4cca62b60f7f277aae81b89b79e9b04d526d/astpretty-3.0.0-py2.py3-none-any.whl (4.9 kB)
Installing collected packages: astpretty
Successfully installed astpretty-3.0.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import ast
import astpretty

def generate_ast(code_string):
    """
    Parse Python source text into an abstract syntax tree.

    Args:
        code_string (str): Python source code to parse.

    Returns:
        ast.Module: Root of the tree produced by ``ast.parse`` (default ``exec`` mode).

    Raises:
        SyntaxError: if ``code_string`` is not valid Python.
    """
    parsed_tree = ast.parse(code_string)
    return parsed_tree

def visualize_ast(code_string):
    """
    Print the AST of a Python code string in two forms.

    First the whole tree is pretty-printed with ``astpretty``. Then every node
    yielded by ``ast.walk`` is printed on one line with its fields summarized:
    child nodes appear by type name, and other values appear via ``repr``.

    Args:
        code_string (str): Python code as a string

    Returns:
        None: output is written to stdout only.

    Raises:
        SyntaxError: if ``code_string`` is not valid Python (raised by ``ast.parse``).
    """
    tree = generate_ast(code_string)
    print("AST Structure:")
    astpretty.pprint(tree)

    def _summarize(value):
        # Child nodes are shown by type name only; everything else via repr.
        if isinstance(value, ast.AST):
            return value.__class__.__name__
        return repr(value)

    # Also print a more detailed node-by-node analysis
    print("\nDetailed AST Analysis:")
    # ast.walk only yields ast.AST instances, so no per-node type check is needed.
    for node in ast.walk(tree):
        fields = []
        for name, value in ast.iter_fields(node):
            if isinstance(value, list) and any(isinstance(item, ast.AST) for item in value):
                # Summarize every element. Lists such as Dict.keys or
                # arguments.kw_defaults may mix None with nodes, and a plain
                # repr would print "<ast.X object at 0x...>" addresses.
                items = ", ".join(_summarize(item) for item in value)
                fields.append(f"{name}=[{items}]")
            else:
                fields.append(f"{name}={_summarize(value)}")
        print(f"{node.__class__.__name__}({', '.join(fields)})")

# Example code to analyze: a recursive factorial plus a call site and an f-string
# print, which gives the tree FunctionDef, If, Compare, BinOp, Call and JoinedStr nodes.
example_code = """
def factorial(n):
    if n <= 1:
        return 1
    else:
        return n * factorial(n - 1)

result = factorial(5)
print(f"The factorial of 5 is {result}")
"""

# Visualize the AST (astpretty dump followed by a per-node field summary).
visualize_ast(example_code)

AST Structure:
Module(
    body=[
        FunctionDef(
            lineno=2,
            col_offset=0,
            end_lineno=6,
            end_col_offset=35,
            name='factorial',
            args=arguments(
                posonlyargs=[],
                args=[arg(lineno=2, col_offset=14, end_lineno=2, end_col_offset=15, arg='n', annotation=None, type_comment=None)],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[],
            ),
            body=[
                If(
                    lineno=3,
                    col_offset=4,
                    end_lineno=6,
                    end_col_offset=35,
                    test=Compare(
                        lineno=3,
                        col_offset=7,
                        end_lineno=3,
                        end_col_offset=13,
                        left=Name(lineno=3, col_offset=7, end_lineno=3, end_col_offset=8, id

In [11]:
%pip install astpretty graphviz astor networkx matplotlib


Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting astor
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c3/88/97eef84f48fa04fbd6750e62dcceafba6c63c81b7ac1420856c8dcc0a3f9/astor-0.8.1-py2.py3-none-any.whl (27 kB)
Installing collected packages: astor
Successfully installed astor-0.8.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [None]:
import ast
import astpretty
import graphviz
from pprint import pprint
import astor

class ASTVisualizer:
    """用于可视化Python抽象语法树的类"""
    
    def __init__(self, code):
        """
        初始化可视化器
        
        Args:
            code (str): 要分析的Python代码字符串
        """
        self.code = code
        self.tree = ast.parse(code)
    
    def print_ast(self):
        """打印AST的简单文本表示"""
        print("AST节点结构:")
        for node in ast.walk(self.tree):
            print(f"{type(node).__name__}: {ast.dump(node, include_attributes=False)}")
            
    def pretty_print(self):
        """使用astpretty打印格式化的AST"""
        print("格式化的AST:")
        astpretty.pprint(self.tree)
    
    def generate_dot(self, filename="ast_graph"):
        """
        生成AST的GraphViz DOT表示
        
        Args:
            filename (str): 输出文件名（不含扩展名）
        
        Returns:
            graphviz.Digraph: GraphViz图对象
        """
        dot = graphviz.Digraph(comment='Abstract Syntax Tree')
        
        # 为每个节点生成一个唯一的ID
        node_ids = {}
        node_counter = 0
        
        def traverse(node, parent_id=None):
            nonlocal node_counter
            
            # 为当前节点分配一个唯一ID
            node_id = f"node_{node_counter}"
            node_counter += 1
            node_ids[node] = node_id
            
            # 获取节点类型
            node_type = type(node).__name__
            
            # 准备节点属性文本
            attrs = []
            for attr_name, attr_value in ast.iter_fields(node):
                if not isinstance(attr_value, ast.AST) and not isinstance(attr_value, list):
                    attrs.append(f"{attr_name}={repr(attr_value)}")
            
            # 创建节点标签
            label = f"{node_type}\n{', '.join(attrs)}" if attrs else node_type
            dot.node(node_id, label)
            
            # 连接到父节点
            if parent_id is not None:
                dot.edge(parent_id, node_id)
            
            # 递归处理子节点
            for field_name, field_value in ast.iter_fields(node):
                if isinstance(field_value, ast.AST):
                    traverse(field_value, node_id)
                elif isinstance(field_value, list):
                    for i, item in enumerate(field_value):
                        if isinstance(item, ast.AST):
                            # 为列表项添加索引标签
                            child_id = traverse(item, None)
                            edge_label = f"{field_name}[{i}]"
                            dot.edge(node_id, child_id, label=edge_label)
            
            return node_id
        
        # 从AST的根节点开始遍历
        traverse(self.tree)
        
        # 保存图形
        dot.render(filename, view=True, format='png')
        
        
        return dot
    
    def to_source_code(self):
        """将AST转换回源代码"""
        return astor.to_source(self.tree)

    def visualize_with_networkx(self, filename="ast_network.png"):
        """
        使用NetworkX和Matplotlib生成AST可视化图
        
        Args:
            filename (str): 输出文件名
        """
        try:
            import networkx as nx
            import matplotlib.pyplot as plt
            from matplotlib.pyplot import figure
        except ImportError:
            print("请安装networkx和matplotlib: pip install networkx matplotlib")
            return
            
        # 创建一个有向图
        G = nx.DiGraph()
        
        # 为每个节点生成一个唯一的ID
        node_ids = {}
        node_labels = {}
        node_counter = 0
        
        def traverse(node, parent_id=None):
            nonlocal node_counter
            
            # 为当前节点分配一个唯一ID
            node_id = node_counter
            node_counter += 1
            node_ids[node] = node_id
            
            # 获取节点类型和关键属性
            node_type = type(node).__name__
            attrs = []
            for attr_name, attr_value in ast.iter_fields(node):
                if not isinstance(attr_value, ast.AST) and not isinstance(attr_value, list):
                    if attr_name in ['name', 'id', 'arg', 'op', 'value']:
                        attrs.append(f"{attr_name}={repr(attr_value)}")
            
            # 创建节点标签
            label = f"{node_type}\n{', '.join(attrs)}" if attrs else node_type
            node_labels[node_id] = label
            
            # 添加节点到图
            G.add_node(node_id)
            
            # 连接到父节点
            if parent_id is not None:
                G.add_edge(parent_id, node_id)
            
            # 递归处理子节点
            for field_name, field_value in ast.iter_fields(node):
                if isinstance(field_value, ast.AST):
                    child_id = traverse(field_value, node_id)
                elif isinstance(field_value, list):
                    for item in field_value:
                        if isinstance(item, ast.AST):
                            traverse(item, node_id)
            
            return node_id
        
        # 从AST的根节点开始遍历
        traverse(self.tree)
        
        # 设置图形大小
        figure(figsize=(12, 8))
        
        # 使用spring布局算法来布置节点
        pos = nx.spring_layout(G, k=0.9)
        
        # 绘制节点和边
        nx.draw(G, pos, with_labels=False, node_size=1500, node_color="skyblue", 
                font_size=10, font_weight="bold", arrowsize=15)
        
        # 添加节点标签
        nx.draw_networkx_labels(G, pos, node_labels)
        
        # 保存图形
        plt.savefig(filename)
        plt.close()
        print(f"AST网络图已保存为 {filename}")
    
    def analyze_code_structure(self):
        """Analyze code structure and return summary information as a string"""
        class CodeAnalyzer(ast.NodeVisitor):
            def __init__(self):
                self.function_count = 0
                self.class_count = 0
                self.import_count = 0
                self.assign_count = 0
                self.loop_count = 0  # for and while loops
                self.if_count = 0
                self.function_details = []
                self.class_details = []
                
            def visit_FunctionDef(self, node):
                self.function_count += 1
                args = len(node.args.args)
                body_size = len(node.body)
                self.function_details.append({
                    'name': node.name,
                    'args_count': args,
                    'body_size': body_size,
                    'decorators': len(node.decorator_list)
                })
                self.generic_visit(node)
                
            def visit_ClassDef(self, node):
                self.class_count += 1
                bases = [astor.to_source(base).strip() for base in node.bases]
                methods = sum(1 for n in node.body if isinstance(n, ast.FunctionDef))
                self.class_details.append({
                    'name': node.name,
                    'bases': bases,
                    'methods_count': methods,
                    'body_size': len(node.body),
                    'decorators': len(node.decorator_list)
                })
                self.generic_visit(node)
                
            def visit_Import(self, node):
                self.import_count += len(node.names)
                self.generic_visit(node)
                
            def visit_ImportFrom(self, node):
                self.import_count += len(node.names)
                self.generic_visit(node)
                
            def visit_Assign(self, node):
                self.assign_count += 1
                self.generic_visit(node)
                
            def visit_For(self, node):
                self.loop_count += 1
                self.generic_visit(node)
                
            def visit_While(self, node):
                self.loop_count += 1
                self.generic_visit(node)
                
            def visit_If(self, node):
                self.if_count += 1
                self.generic_visit(node)
                
        analyzer = CodeAnalyzer()
        analyzer.visit(self.tree)
        
        result = []
        
        result.append("--- Code Structure Analysis ---")
        result.append(f"Function count: {analyzer.function_count}")
        result.append(f"Class count: {analyzer.class_count}")
        result.append(f"Import statement count: {analyzer.import_count}")
        result.append(f"Assignment statement count: {analyzer.assign_count}")
        result.append(f"Loop count: {analyzer.loop_count}")
        result.append(f"Conditional statement count: {analyzer.if_count}")
        
        if analyzer.function_details:
            result.append("\nFunction details:")
            for func in analyzer.function_details:
                result.append(f"  - {func['name']}: {func['args_count']} parameter(s), "
                    f"{func['body_size']} line(s), {func['decorators']} decorator(s)")
                
        if analyzer.class_details:
            result.append("\nClass details:")
            for cls in analyzer.class_details:
                bases_str = ", ".join(cls['bases']) if cls['bases'] else "none"
                result.append(f"  - {cls['name']}: inherits from [{bases_str}], "
                    f"{cls['methods_count']} method(s), {cls['body_size']} line(s) in body")
        
        return "\n".join(result)



# Example usage
if __name__ == "__main__":
    # Sample code: a recursive function, a class with two methods, and call sites.
    sample_code = '''
def factorial(n):
    if n <= 1:
        return 1
    else:
        return n * factorial(n - 1)

class MyClass:
    def __init__(self, value):
        self.value = value
        
    def get_value(self):
        return self.value * 2

result = factorial(5)
obj = MyClass(10)
print(obj.get_value())
    '''

    # Create the visualizer
    visualizer = ASTVisualizer(sample_code)

    # Render the Graphviz diagram under the base name "short_ast".
    visualizer.generate_dot("short_ast")

    # Analyze the code structure. Store the report as `summary`, not `str`:
    # at notebook top level that name would shadow the builtin str() for
    # every later cell in the kernel.
    summary = visualizer.analyze_code_structure()
    print(summary)

    

AST图形已保存为 factorial_ast.png
AST网络图已保存为 ast_network.png
--- Code Structure Analysis ---
Function count: 3
Class count: 1
Import statement count: 0
Assignment statement count: 3
Loop count: 0
Conditional statement count: 1

Function details:
  - factorial: 1 parameter(s), 1 line(s), 0 decorator(s)
  - __init__: 2 parameter(s), 1 line(s), 0 decorator(s)
  - get_value: 1 parameter(s), 1 line(s), 0 decorator(s)

Class details:
  - MyClass: inherits from [none], 2 method(s), 2 line(s) in body
