# AST Experiments

In [65]:
import ast
from pprint import pprint

## Build an S-expression from an expression

For a much more complete AST-to-S-expression translator, see https://github.com/mattmight/python-to-sexp/blob/master/pysx

In [49]:
class Visitor(ast.NodeVisitor):
    """Translate a parsed arithmetic expression into a nested S-expression tuple.

    Numeric literals map to themselves and a binary operation maps to
    ``(symbol, left, right)``, e.g. ``1+2*3`` -> ``('+', 1, ('*', 2, 3))``.
    Operators missing from ``_OPS`` raise ``KeyError``; node types without a
    ``visit_*`` method fall through to ``generic_visit`` and yield ``None``.
    """

    # ast operator class name -> symbol used in the S-expression.
    _OPS = {
        'Add': '+',
        'Sub': '-',
        'Mult': '*',
        'Div': '/',
        'FloorDiv': '//',
        'Mod': '%',
        'Pow': '**',
    }

    def visit_Module(self, node):
        """Translate only the first statement; any later statements are ignored.

        Raises IndexError for an empty module (``node.body`` is empty).
        """
        return self.visit(node.body[0])

    def visit_Expr(self, node):
        """Unwrap an expression statement to the expression it holds."""
        return self.visit(node.value)

    def visit_BinOp(self, node):
        """Return ``(symbol, left, right)`` for a binary operation."""
        left = self.visit(node.left)
        right = self.visit(node.right)
        return (self._OPS[type(node.op).__name__], left, right)

    def visit_Constant(self, node):
        """Literal on Python 3.8+, where ``ast.parse`` emits ``ast.Constant``.

        Handling ``Constant`` directly avoids the legacy ``visit_Num`` dispatch,
        which is deprecated since Python 3.8 and no longer performed in 3.14.
        """
        return node.value

    def visit_Num(self, node):
        """Numeric literal on Python < 3.8, where ``ast.parse`` emits ``ast.Num``."""
        return node.n


('+', 1, ('*', 2, 3.0))


In [56]:
def generate(src):
    """Parse the source string ``src`` and return its S-expression form."""
    return Visitor().visit(ast.parse(src))


# Sanity checks: a bare literal, a simple sum, and operator precedence.
assert generate('1') == 1
assert generate('1+1') == ('+', 1, 1)
assert generate('1+2*3.') == ('+', 1, ('*', 2, 3.0))

## Flatten an expression

This approach was inspired by the work of Alex Gaynor and his presentation [So you want to write an interpreter?](http://pyvideo.org/video/1694/so-you-want-to-write-an-interpreter).

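This is an incredibly naive approach to code generation of any form (even compared to Alex's presentation), but it's a start: every sub-expression is assigned to a fresh temporary, producing a flat, linear list of assignments.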

In [70]:
class Context(object):
    """Accumulates emitted instruction strings and hands out fresh temporary names."""

    def __init__(self):
        self._instr = []  # emitted instruction strings, in emission order
        self._ntmp = 0    # number of temporary names handed out so far

    def emit(self, s):
        """Append the instruction string ``s`` to the instruction list."""
        self._instr.append(s)

    def tmp(self):
        """Return a new temporary name: ``__tmp0``, ``__tmp1``, ...

        Names come from a dedicated counter rather than ``len(self._instr)``,
        so two calls without an intervening ``emit()`` still get distinct names.
        """
        name = '__tmp%d' % self._ntmp
        self._ntmp += 1
        return name

class Visitor(ast.NodeVisitor):
    """Flatten an arithmetic expression into a linear list of temporary assignments.

    Each literal and each binary operation is bound to a fresh temporary taken
    from ``ctx.tmp()`` and recorded via ``ctx.emit()``; every ``visit_*`` method
    returns the name of the temporary that holds that node's value.

    Note: this redefines the name ``Visitor``. This version requires a ``ctx``
    argument, so the earlier ``generate()`` helper (which calls ``Visitor()``)
    fails if re-run after this cell.
    """

    # ast operator class name -> infix symbol used in the emitted instruction.
    _OPS = {
        'Add': '+',
        'Sub': '-',
        'Mult': '*',
        'Div': '/',
        'FloorDiv': '//',
        'Mod': '%',
        'Pow': '**',
    }

    def __init__(self, ctx):
        # ctx must provide tmp() -> str and emit(str).
        self.ctx = ctx

    def _assign(self, expr):
        """Bind ``expr`` to a fresh temporary, emit the assignment, return the name."""
        name = self.ctx.tmp()
        self.ctx.emit('%s = %s' % (name, expr))
        return name

    def visit_BinOp(self, node):
        """Emit both operands first, then ``tmp = left <op> right``."""
        left = self.visit(node.left)
        right = self.visit(node.right)
        symbol = self._OPS[type(node.op).__name__]
        return self._assign('%s %s %s' % (left, symbol, right))

    def visit_Constant(self, node):
        """Literal on Python 3.8+ (``ast.Constant``); avoids the deprecated ``visit_Num`` dispatch."""
        # repr() matches str() for int/float and quotes any string literal correctly.
        return self._assign(repr(node.value))

    def visit_Num(self, node):
        """Numeric literal on Python < 3.8, where ``ast.parse`` emits ``ast.Num``."""
        return self._assign(repr(node.n))

# Flatten '1+2*3+4*5' into temporaries, then show the emitted instructions.
ctx = Context()
tree = ast.parse('1+2*3+4*5')
vis = Visitor(ctx)
vis.visit(tree)
pprint(ctx._instr)

['__tmp0 = 1',
 '__tmp1 = 2',
 '__tmp2 = 3',
 '__tmp3 = __tmp1 * __tmp2',
 '__tmp4 = __tmp0 + __tmp3',
 '__tmp5 = 4',
 '__tmp6 = 5',
 '__tmp7 = __tmp5 * __tmp6',
 '__tmp8 = __tmp4 + __tmp7']
