###Notebook takes a look at how to use AST to parse out the Python code elements for future use
###The primary portion of this code (Lexer) extends the code from Hermes.  The particular file of interest may be found at https://github.com/Lab41/hermes/blob/master/src/data_prep/git_vectorize.py

###Part of the notebook (the visualization near the end) requires Python 3, but that part isn't the core of the code

In [1]:
import ast
import re

In [2]:
from ast import Name, BinOp, Compare, Attribute, Subscript, Expr, Call, Assign, Str, Num, Tuple, List

###Code snippet that we are going to try to retain and model

In [3]:
# Full sample program used to exercise the parsing code below.  In this
# notebook it is only passed to ast.parse (never executed), so the undefined
# names inside it (e.g. sqlCtx) do not matter.
c = """import save_load as sl
from src.utils import glove
from src.data_prep import jester_vectorize as jestv
jest_jokes = sqlCtx.read.json('/Users/abethke/spark-1.6.0/jester/jester_jokes.json.gz').sample(True, 0.1,41) 
jest_rates = sqlCtx.read.json('/Users/abethke/spark-1.6.0/jester/jester_ratings.json.gz').sample(True, 0.1,41)
glove_model = glove.Glove("/Users/abethke/spark-1.6.0/jester/glove.6B.50d.txt")
support_files = {'glove_model' : glove_model}
jest_vect = jestv.jester_vectorize(jest_rates, jest_jokes, "ratings", "glove", **support_files)
user_info = jest_vect.get_user_vector().repartition(20)
train_ratings, test_ratings = user_info.randomSplit([0.9,0.1], 41)
sl.save_to_hadoop(train_ratings, '/Users/abethke/spark-1.6.0/jester/jester_uv_train_ratings.pkl')
sl.save_to_hadoop(test_ratings, '/Users/abethke/spark-1.6.0/jester/jester_uv_test_ratings.pkl')
content_vect = jest_vect.get_content_vector()
sl.save_to_hadoop(content_vect, '/Users/abethke/spark-1.6.0/jester/jester_cv_glove.pkl') """

In [4]:
###First, to get a better idea of what is going on, it is helpful to look at a smaller amount of code

In [5]:
# One-statement snippets for inspecting the AST in detail:
#   c_minig  - a call through a chain of attributes (glove.Glove.hi.there)
#   c_mini   - chained method calls with several literal arguments
#   c_mini_s - a tuple-unpacking assignment from a method call
c_minig = """glove_model = glove.Glove.hi.there("/Users/abethke/spark-1.6.0/jester/glove.6B.50d.txt")"""
c_mini = """jest_jokes = sqlCtx.read.json('/Users/abethke/spark-1.6.0/jester/jester_jokes.json.gz').sample(True, 0.1,41,34,683) """
c_mini_s = """train_ratings, test_ratings = user_info.randomSplit([0.9,0.1], 41)"""

In [6]:
#ast.dump is really helpful for viewing what is in the tree, along with all of the elements available

In [7]:
# Parse the attribute-chain snippet and display the raw tree
tree_mini = ast.parse(c_minig)
ast.dump(tree_mini)

"Module(body=[Assign(targets=[Name(id='glove_model', ctx=Store())], value=Call(func=Attribute(value=Attribute(value=Attribute(value=Name(id='glove', ctx=Load()), attr='Glove', ctx=Load()), attr='hi', ctx=Load()), attr='there', ctx=Load()), args=[Str(s='/Users/abethke/spark-1.6.0/jester/glove.6B.50d.txt')], keywords=[]))])"

In [8]:
# Parse the tuple-unpacking snippet and display the raw tree
c_mini_s = """train_ratings, test_ratings = user_info.randomSplit([0.9,0.1], 41)"""
tree_mini = ast.parse(c_mini_s)
ast.dump(tree_mini)

"Module(body=[Assign(targets=[Tuple(elts=[Name(id='train_ratings', ctx=Store()), Name(id='test_ratings', ctx=Store())], ctx=Store())], value=Call(func=Attribute(value=Name(id='user_info', ctx=Load()), attr='randomSplit', ctx=Load()), args=[List(elts=[Num(n=0.9), Num(n=0.1)], ctx=Load()), Num(n=41)], keywords=[]))])"

In [9]:
###Another way to look at the AST is to convert it to JSON.  For this to run you must first pip install ast2json

In [None]:
import json
from ast import parse
from ast2json import ast2json

In [None]:
# Convert the AST of c_mini_s into JSON-serializable dicts via ast2json
ast_var = ast2json(parse(c_mini_s))

In [293]:
# Pretty-print the JSON form of the AST.  print() with a single argument works
# on both Python 2 and Python 3 (the bare print statement is a SyntaxError on 3).
print(json.dumps(ast_var, indent=4))

{
    "body": [
        {
            "value": {
                "_type": "Call", 
                "col_offset": 29, 
                "starargs": null, 
                "args": [
                    {
                        "elts": [
                            {
                                "_type": "Num", 
                                "lineno": 1, 
                                "col_offset": 52, 
                                "n": 0.9
                            }, 
                            {
                                "_type": "Num", 
                                "lineno": 1, 
                                "col_offset": 56, 
                                "n": 0.1
                            }
                        ], 
                        "_type": "List", 
                        "ctx": {
                            "_type": "Load"
                        }, 
                        "lineno": 1, 
                        "col_offset": 51
               

Now that we have a better idea of what is in the AST (we'll visualize it later), we can write some code to parse it.

Lexer is the main class; it relies on the helper function leftmostname, which is defined below.

In [14]:
class Lexer(ast.NodeVisitor):
    """Parse a node from an AST and return the information we may want about it.

    Each ``visit_*`` method returns a list of record tuples that start with the
    node's line number and a record type:

    * ``"Import"``      -- (lineno, "Import", name, asname), one per imported name
    * ``"From_Import"`` -- (lineno, "From_Import", module, [(name, asname), ...])
    * ``"With"``        -- (lineno, "With", summary of the context expression)
    * ``"Assign_targ"`` -- (lineno, "Assign_targ", name or list of names)
    * ``"Assign_val"``  -- (lineno, "Assign_val", summary of the assigned value)
    * ``"For"``         -- (lineno, "For", summary of the loop target)
    * ``"Call"``        -- (lineno, "Call", dotted function name, [arg summaries])

    "Summaries" come from the module-level ``leftmostname`` helper; when it
    cannot summarize a node, the summary is None.  Node types without a
    ``visit_*`` method return None.  Records describe only the node passed to
    ``visit``; use ``parse`` to collect the records for a whole program.
    """

    @staticmethod
    def _summary(node):
        """Best-effort ``leftmostname(node)``: None if the node cannot be summarized."""
        try:
            return leftmostname(node)
        except (AttributeError, TypeError):
            return None

    def parse(self, code):
        """Parse source text and return the records of every node, in ``ast.walk`` order.

        ``NodeVisitor.generic_visit`` ignores the values returned by child
        visits, so a single ``visit`` of the root collects nothing.  Instead,
        every node yielded by ``ast.walk`` is visited and its records are
        concatenated.
        """
        records = []
        for node in ast.walk(ast.parse(code)):
            found = self.visit(node)
            if found:
                records.extend(found)
        return records

    def visit_Import(self, node):
        """Called for "import library" statements; one record per imported name."""
        items = [(node.lineno, "Import", item.name, item.asname) for item in node.names]
        self.generic_visit(node)
        return items

    def visit_ImportFrom(self, node):
        """Called for "from library import object" statements.

        ``node.module`` is None for purely relative imports (``from . import x``).
        """
        self.generic_visit(node)
        items = [(item.name, item.asname) for item in node.names]
        return [(node.lineno, "From_Import", node.module, items)]

    def visit_With(self, node):
        """Handle visiting a with statement: one "With" record per context manager."""
        # Python 3 keeps the context managers in node.items (withitem nodes);
        # the Python 2 With node carries a single context_expr itself.
        with_items = getattr(node, "items", None) or [node]
        items = [(node.lineno, "With", self._summary(item.context_expr))
                 for item in with_items]
        self.generic_visit(node)
        return items

    def visit_Assign(self, node):
        """Handle visiting an assignment.

        Emits one "Assign_targ" record per simple-name or tuple/list target
        (attribute and subscript targets produce no record), followed by a
        single "Assign_val" record for the right-hand side.
        """
        items = []
        for target in node.targets:
            if isinstance(target, Name):
                items.append((node.lineno, "Assign_targ", target.id))
            elif isinstance(target, (Tuple, List)):
                # Unpacking, e.g. ``a, b = ...`` -> ['a', 'b']
                names = [self._summary(elt) for elt in target.elts]
                items.append((node.lineno, "Assign_targ", names))
        items.append((node.lineno, "Assign_val", self._summary(node.value)))
        self.generic_visit(node)
        return items

    def visit_For(self, node):
        """Handle visiting a for statement: one "For" record for the loop target."""
        record = (node.lineno, "For", self._summary(node.target))
        # Recurse like the other visitors so the loop body is traversed too.
        self.generic_visit(node)
        return [record]

    def visit_Call(self, node):
        """Called for function and method calls.

        Returns a single "Call" record holding the dotted function name and a
        summary of each positional argument (keyword arguments are not
        recorded), or None when no function name can be determined (e.g.
        calling a lambda).
        """
        func = node.func
        if isinstance(func, Name):
            # plain function call: foo(...)
            n_id = func.id
        elif isinstance(func, Attribute) and isinstance(func.value, Name):
            # method / library call with one leading name: lib.foo(...)
            n_id = func.value.id + '.' + func.attr
        else:
            # longer chains such as a.b.c(...) or a().b(...)
            n_id = self._summary(node)

        # Summarize each argument independently so one unresolvable argument
        # does not drop the ones after it.
        args = [self._summary(arg) for arg in node.args]

        self.generic_visit(node)
        if n_id:
            return [(node.lineno, "Call", n_id, args)]
        return None

###Leftmostname is useful to iterate over the objects and find the type of instance the node is
###A full list of the node types can be found at https://greentreesnakes.readthedocs.io/en/latest/nodes.html

In [15]:
def leftmostname(node):
    """Summarize an AST node by the first (left-most) name found in it.

    Returns:
        * ``Name``               -> its identifier, e.g. ``'x'``
        * ``Attribute``          -> dotted path, e.g. ``'a.b.c'``; when the base
                                    cannot be summarized, just the attribute name
        * ``Call``               -> the summary of the function being called
        * ``Subscript``/``Expr`` -> the summary of the wrapped ``value``
        * ``BinOp``/``Compare``  -> the summary of the left operand
        * ``Assign``             -> the summary of the first target
        * ``List``/``Tuple``     -> a list with the summary of each element
        * ``Str``/``Num``        -> the literal string / number
        * anything else          -> None
    """
    if isinstance(node, Name):
        return node.id
    if isinstance(node, (BinOp, Compare)):
        return leftmostname(node.left)
    if isinstance(node, Attribute):
        base = leftmostname(node.value)
        if base is None:
            # e.g. ``{}.get``: no name to the left, keep the attribute itself
            return node.attr
        # format() also copes with non-string bases such as numbers
        return '{}.{}'.format(base, node.attr)
    if isinstance(node, (Subscript, Expr)):
        # Neither node type has an ``attr``; summarize what they wrap.
        return leftmostname(node.value)
    if isinstance(node, Call):
        return leftmostname(node.func)
    if isinstance(node, Assign):
        return leftmostname(node.targets[0])
    if isinstance(node, (List, Tuple)):
        return [leftmostname(elt) for elt in node.elts]
    if isinstance(node, Str):
        # handles case of "./my executable"
        return node.s
    if isinstance(node, Num):
        return node.n
    return None

In [17]:
# Try it out with one of the snippets: visit every node of the small tree and
# print the records of the nodes the Lexer knows how to describe.
tree_mini = ast.parse(c_mini_s)
for node in ast.walk(tree_mini):
    ret = Lexer().visit(node)
    if ret is not None:
        print(ret)

[(1, 'Assign_targ', ['train_ratings', 'test_ratings']), (1, 'Assign_val', 'user_info.randomSplit')]
[(1, 'Call', 'user_info.randomSplit', [[0.9, 0.1], 41])]


In [19]:
###Ok so that worked, what about the big code block

In [20]:
# Parse the full sample program
tree = ast.parse(c) 

In [21]:
# Display the raw tree of the full sample program
ast.dump(tree)

"Module(body=[Import(names=[alias(name='save_load', asname='sl')]), ImportFrom(module='src.utils', names=[alias(name='glove', asname=None)], level=0), ImportFrom(module='src.data_prep', names=[alias(name='jester_vectorize', asname='jestv')], level=0), Assign(targets=[Name(id='jest_jokes', ctx=Store())], value=Call(func=Attribute(value=Call(func=Attribute(value=Attribute(value=Name(id='sqlCtx', ctx=Load()), attr='read', ctx=Load()), attr='json', ctx=Load()), args=[Str(s='/Users/abethke/spark-1.6.0/jester/jester_jokes.json.gz')], keywords=[]), attr='sample', ctx=Load()), args=[NameConstant(value=True), Num(n=0.1), Num(n=41)], keywords=[])), Assign(targets=[Name(id='jest_rates', ctx=Store())], value=Call(func=Attribute(value=Call(func=Attribute(value=Attribute(value=Name(id='sqlCtx', ctx=Load()), attr='read', ctx=Load()), attr='json', ctx=Load()), args=[Str(s='/Users/abethke/spark-1.6.0/jester/jester_ratings.json.gz')], keywords=[]), attr='sample', ctx=Load()), args=[NameConstant(value=Tr

In [22]:
# Collects each non-None list of records produced by the Lexer loop that follows
ret_info = []

In [23]:
# Visit every node of the full program, print the records of the nodes the
# Lexer understands and keep them in ret_info for later inspection.
tree = ast.parse(c)
for node in ast.walk(tree):
    ret = Lexer().visit(node)
    if ret is not None:
        print(ret)
        ret_info.append(ret)

[(1, 'Import', 'save_load', 'sl')]
[(2, 'From_Import', 'src.utils', [('glove', None)])]
[(3, 'From_Import', 'src.data_prep', [('jester_vectorize', 'jestv')])]
[(4, 'Assign_targ', 'jest_jokes'), (4, 'Assign_val', 'sqlCtx.read.json.sample')]
[(5, 'Assign_targ', 'jest_rates'), (5, 'Assign_val', 'sqlCtx.read.json.sample')]
[(6, 'Assign_targ', 'glove_model'), (6, 'Assign_val', 'glove.Glove')]
[(7, 'Assign_targ', 'support_files'), (7, 'Assign_val', None)]
[(8, 'Assign_targ', 'jest_vect'), (8, 'Assign_val', 'jestv.jester_vectorize')]
[(9, 'Assign_targ', 'user_info'), (9, 'Assign_val', 'jest_vect.get_user_vector.repartition')]
[(10, 'Assign_targ', ['train_ratings', 'test_ratings']), (10, 'Assign_val', 'user_info.randomSplit')]
[(13, 'Assign_targ', 'content_vect'), (13, 'Assign_val', 'jest_vect.get_content_vector')]
[(4, 'Call', 'sqlCtx.read.json.sample', [None, 0.1, 41])]
[(5, 'Call', 'sqlCtx.read.json.sample', [None, 0.1, 41])]
[(6, 'Call', 'glove.Glove', ['/Users/abethke/spark-1.6.0/jester/g

In [24]:
# Inspect one collected entry (a list of records for a single node)
ret_info[3]

[(4, 'Assign_targ', 'jest_jokes'),
 (4, 'Assign_val', 'sqlCtx.read.json.sample')]

In [25]:
##Nice!!! Everything we need.  Lexer or leftmostname may need to be modified for other types of code, but we shall see

In [None]:
###Method for finding file paths from code using Regex

In [32]:
import re

In [37]:
# Quick sanity check: re.findall returns every non-overlapping match
re.findall('a', 'anna')

['a', 'a']

In [43]:
# Look for a path: an optional run of "../dir/" segments followed by the rest.
# Raw string, so "\." reaches the regex engine without an invalid-escape warning.
# NOTE(review): the trailing ".*" matches anything, so this currently returns
# the whole input (up to the first newline) rather than extracting a path.
m = re.search(r"(\.\./[^/]+/)*.*", '/Users/abethke/spark-1.6.0/jester/jester_jokes.json.gz')
m.group(0)

'/Users/abethke/spark-1.6.0/jester/jester_jokes.json.gz'


###This portion of the code now uses the library from XXXX to visualize the ast tree
###It (apparently) can only run on Python3.  You also must brew install graphviz in order for it to work
###I copied it over here as I was having difficulty running it in the command line
###The issue could have stemmed from how graphviz was initially brought in.

###At the end of the day it will make a PDF 

In [16]:
import graphviz as gv
import subprocess
import numbers
import re
from uuid import uuid4 as uuid
import optparse
import sys

In [17]:
    parsers = {
        "pyast": generate_pyast,
        "lib2to3": generate_lib2to3_ast,
        "jinja2": generate_jinja2_ast,
    }

In [18]:
def generate_lib2to3_ast(code):
    """Parse *code* with lib2to3 and return its concrete syntax tree as nested dicts.

    Every node becomes ``{"node_type": <symbol or token name>}``; interior
    nodes add a ``"children"`` list, and leaves add ``"value"`` / ``"prefix"``
    entries when those strings are non-empty.

    NOTE(review): lib2to3 is deprecated (removed in Python 3.13) — confirm the
    target interpreter before relying on this parser.
    """
    from lib2to3.pgen2.driver import Driver
    from lib2to3.pgen2 import token as pgen2_token
    from lib2to3.pygram import python_symbols, python_grammar
    from lib2to3 import pytree
    from io import StringIO

    # Grammar symbols are searched before plain tokens, so a symbol name wins
    # if both tables share a number.
    name_table = list(python_symbols.__dict__.items()) + list(pgen2_token.__dict__.items())

    def type_name(type_number):
        return next(candidate for candidate, number in name_table if number == type_number)

    def to_dict(tree_node):
        result = {"node_type": type_name(tree_node.type)}
        if tree_node.children:
            result["children"] = list(map(to_dict, tree_node.children))
        if isinstance(tree_node, pytree.Leaf):
            if tree_node.value != "":
                result["value"] = tree_node.value
            if tree_node.prefix != "":
                result["prefix"] = tree_node.prefix
        return result

    parser = Driver(python_grammar, convert=pytree.convert)
    syntax_tree = parser.parse_stream(StringIO(code))
    return to_dict(syntax_tree)


def generate_pyast(code):
    """Parse *code* with the ``ast`` module and convert the tree to plain dicts/lists.

    Each AST node becomes a dict of its fields (keys passed through
    ``to_camelcase``) plus a ``'node_type'`` entry derived from the class name;
    lists are converted element by element and every other value is returned
    unchanged.
    """
    import ast

    def convert(value):
        if isinstance(value, list):
            return list(map(convert, value))
        if not isinstance(value, ast.AST):
            return value
        converted = {}
        for field in value._fields:
            converted[to_camelcase(field)] = convert(getattr(value, field))
        converted['node_type'] = to_camelcase(type(value).__name__)
        return converted

    return convert(ast.parse(code))


def generate_jinja2_ast(code):
    """Parse a Jinja2 template and return its node tree as nested dicts/lists.

    Each ``jinja2.nodes.Node`` becomes a dict of its fields and attributes
    (except ``environment``) plus a ``"node_type"`` entry holding the class
    name; lists are converted element by element and other values are
    returned as-is.
    """
    import jinja2

    def convert(value):
        if isinstance(value, list):
            return [convert(item) for item in value]
        if not isinstance(value, jinja2.nodes.Node):
            return value
        names = [key for key in value.fields + value.attributes if key != "environment"]
        converted = {name: convert(getattr(value, name)) for name in names}
        converted["node_type"] = type(value).__name__
        return converted

    template_env = jinja2.Environment()
    return convert(template_env.parse(code))


def to_camelcase(string):
    """Convert a CamelCase identifier to lower snake_case, e.g. 'BinOp' -> 'bin_op'.

    Despite its name, this produces snake_case: an underscore is inserted
    between a lowercase letter or digit and a following uppercase letter, and
    the result is lower-cased.
    """
    boundary = '([a-z0-9])([A-Z])'
    separated = re.sub(boundary, r'\1_\2', string)
    return separated.lower()


class GraphRenderer:
    """
    Render data structures consisting of dicts and lists as a graph using graphviz.

    Dicts become nodes labelled with their ``"node_type"`` entry (or
    ``"[dict]"``) with one labelled edge per remaining key; lists become
    ``"[list]"`` nodes with edges labelled by index; any other value becomes a
    leaf labelled with its (possibly shortened) ``repr``.
    """

    graphattrs = {
        'labelloc': 't',
        'fontcolor': 'white',
        'bgcolor': '#333333',
        'margin': '0',
    }

    nodeattrs = {
        'color': 'white',
        'fontcolor': 'white',
        'style': 'filled',
        'fillcolor': '#006699',
    }

    edgeattrs = {
        'color': 'white',
        'fontcolor': 'white',
    }

    # Per-render state; only set while ``render`` is drawing.
    _graph = None
    _rendered_nodes = None
    # Maximum length (in characters) of a leaf label.
    _max_label_len = 100


    @staticmethod
    def _escape_dot_label(label):
        """Backslash-escape ``\\``, ``|``, ``<`` and ``>`` in a graphviz label."""
        return label.replace("\\", "\\\\").replace("|", "\\|").replace("<", "\\<").replace(">", "\\>")


    def _shorten_string(self, string):
        """Return *string*, or its head and tail joined by "..." if it is too long.

        The result never exceeds ``_max_label_len`` characters (for limits of
        at least 3).
        """
        budget = self._max_label_len - 3  # room left after the "..." marker
        if len(string) > budget:
            halflen = max(budget, 0) // 2
            # Slice the tail from an explicit index: string[-0:] would be the
            # whole string, not an empty tail.
            return string[:halflen] + "..." + string[len(string) - halflen:]
        return string


    def _render_node(self, node):
        """Draw *node* (and, recursively, its children) and return its graph id."""
        if isinstance(node, (str, numbers.Number)) or node is None:
            # Scalars get a fresh id so every occurrence is drawn separately.
            node_id = uuid()
        else:
            # Containers are keyed by identity so a shared object is drawn once.
            node_id = id(node)
        node_id = str(node_id)

        if node_id not in self._rendered_nodes:
            self._rendered_nodes.add(node_id)
            if isinstance(node, dict):
                self._render_dict(node, node_id)
            elif isinstance(node, list):
                self._render_list(node, node_id)
            else:
                self._graph.node(node_id, label=self._escape_dot_label(self._shorten_string(repr(node))))

        return node_id


    def _render_dict(self, node, node_id):
        """Draw a dict node plus one edge per key other than "node_type"."""
        self._graph.node(node_id, label=node.get("node_type", "[dict]"))
        for key, value in node.items():
            if key == "node_type":
                continue
            child_node_id = self._render_node(value)
            self._graph.edge(node_id, child_node_id, label=self._escape_dot_label(key))


    def _render_list(self, node, node_id):
        """Draw a list node plus one edge per element, labelled by index."""
        self._graph.node(node_id, label="[list]")
        for idx, value in enumerate(node):
            child_node_id = self._render_node(value)
            self._graph.edge(node_id, child_node_id, label=self._escape_dot_label(str(idx)))


    def render(self, data, *, label=None):
        """Draw *data* as a graph, set the output format to PDF and call ``view()``.

        Args:
            data: nested dicts / lists / scalars, e.g. the output of generate_pyast.
            label: optional title shown at the top of the graph.
        """
        # create the graph
        graphattrs = self.graphattrs.copy()
        if label is not None:
            graphattrs['label'] = self._escape_dot_label(label)
        graph = gv.Digraph(graph_attr=graphattrs, node_attr=self.nodeattrs, edge_attr=self.edgeattrs)

        # Recursively draw all the nodes and edges; always clear the per-render
        # state, even if drawing fails part-way through.
        self._graph = graph
        self._rendered_nodes = set()
        try:
            self._render_node(data)
        finally:
            self._graph = None
            self._rendered_nodes = None

        # display the graph
        graph.format = "pdf"
        graph.view()
        #subprocess.Popen(['xdg-open', "test.pdf"])

In [27]:
# Render the attribute-chain snippet's AST using the "pyast" parser
c_minig = """glove_model = glove.Glove.hi.there("/Users/abethke/spark-1.6.0/jester/glove.6B.50d.txt")"""
code = c_minig

generate_ast = parsers["pyast"]
code_ast = generate_ast(code)

# GraphRenderer.render sets the output format to PDF and opens it via graphviz view()
renderer = GraphRenderer()
renderer.render(code_ast, label="Test Ast")