In [199]:
import re
import glob
import json

In [2]:
path = "../../graph/src/*.rs"

In [3]:
def read_file(path):
    with open(path, "r") as f:
        return f.read()
    
def read_files(path):
    return [
        read_file(file)
        for file in glob.glob(path)
    ]

In [4]:
files = read_files(path)

In [204]:
functions = [
    x
    for file in files
    for x in re.findall(
        r"(pub\s+fn\s+[^\{]+?)\{", 
        "\n".join(
            re.findall(
                "impl\s+Graph\s+\{.+\}", 
                re.sub(r"\s+", r" ", file.replace("\n", ""))
            )
        ) or "", 
        re.MULTILINE
    )
]
len(functions)

205

In [205]:
functions

['pub fn are_nodes_remappable(&self, other: &Graph) -> bool ',
 'pub fn remap(&self, other: &Graph, verbose: bool) -> Result<Graph, String> ',
 'pub fn set_name(&mut self, name: String) ',
 'pub fn extract_uniform_node(&self, node: NodeT, random_state: NodeT) -> NodeT ',
 'pub fn extract_node( &self, node: NodeT, random_state: NodeT, walk_weights: &WalkWeights, min_edge_id: EdgeT, max_edge_id: EdgeT, destinations: &[NodeT], probabilistic_indices: &Option<Vec<u64>>, ) -> (NodeT, EdgeT) ',
 'pub fn extract_edge( &self, src: NodeT, dst: NodeT, edge: EdgeT, random_state: NodeT, walk_weights: &WalkWeights, min_edge_id: EdgeT, max_edge_id: EdgeT, destinations: &[NodeT], previous_destinations: &[NodeT], probabilistic_indices: &Option<Vec<u64>>, ) -> (NodeT, EdgeT) ',
 "pub fn random_walks_iter<'a>( &'a self, quantity: NodeT, parameters: &'a WalksParameters, ) -> Result<impl IndexedParallelIterator<Item = Vec<NodeT>> + 'a, String> ",
 "pub fn complete_walks_iter<'a>( &'a self, parameters: &'a 

In [206]:
from lark import Lark, Transformer, Token, Tree

In [216]:
dont_care = lambda self, args: args[0]

class FunctionToJson(Transformer):
    def type(self, args):        
        arg = args[0]
        if isinstance(arg, Token):
            if arg.type == "IDENTIFIER":
                return arg.value.strip()
        if isinstance(arg, str):
            return arg
                
        return args
    
    start = dont_care
    function_return_type = dont_care
    lifetime_or_type = dont_care
        
    
    def type_with_modifiers(self, args):
        output = ""
        for token in args:
            if isinstance(token, Token):
                output += token.value
                continue
                
            if isinstance(token, str):
                output += token
                continue
                
            print("Warning, token %s not parsed"%token)
        return output
    
    def life_time(self, args):
        return "'" + args[0].value + " "
    
    def self_param(self, args):
        output = ""
        for token in args:
            if isinstance(token, Token):
                output += token.value
                continue
                
            if isinstance(token, str):
                output += token
                continue
                
        return {"self":output}
    
    def bound_type(self, args):
        return "{}<{}>".format(
            args[0].value.strip(),
            ", ".join(args[1:])
        )
    def assign_type(self, args):
        return args[0] + " = " + args[1]
    
    def slice_type(self, args):
        return "[" + args[0] + "]"
    
    def function_param(self, args):
        return {args[0].value.strip() : args[1]}
    
    def function_parameters(self, args):
        return {
            k: v 
            for arg in args
            for k, v in arg.items()
        }
    
    def tuple_type(self, args):
        return "(" + ", ".join(args) + ")"
    
    def impl_type(self,args):
        return "impl " + args[0]
    
    def summed_type(self, args):
        return " + ".join(args)
            
    def function_type(self, args):
        return "Fn(%s) -> %s"%tuple(args)
    
    def generic_params(self, args):
        return "<" + ", ".join(args) + ">"
    
    def function_name(self, args):
        if len(args) == 1:
            return args[0].value
        
        return args[0].value + args[1]
    
    def function(self, args):
        if len(args) > 3:
            returns = args[3]
        else:
            returns = None
        return {
            "function_name":args[1],
            "params":args[2],
            "return":returns
        }

In [217]:
with open("./rust_functions_grammar.lark") as f:
    grammar = f.read()
    
parser = Lark(grammar)

for function in functions[:10]:
    print(function)
    tree = parser.parse(function)
    # print(tree.pretty())
    print(json.dumps(FunctionToJson().transform(tree), indent=4))

pub fn are_nodes_remappable(&self, other: &Graph) -> bool 
{
    "function_name": " are_nodes_remappable",
    "params": {
        "self": "&self",
        "other": "&Graph"
    },
    "return": "bool"
}
pub fn remap(&self, other: &Graph, verbose: bool) -> Result<Graph, String> 
{
    "function_name": " remap",
    "params": {
        "self": "&self",
        "other": "&Graph",
        "verbose": "bool"
    },
    "return": "Result<Graph, String>"
}
pub fn set_name(&mut self, name: String) 
{
    "function_name": " set_name",
    "params": {
        "self": "&mutself",
        "name": "String"
    },
    "return": null
}
pub fn extract_uniform_node(&self, node: NodeT, random_state: NodeT) -> NodeT 
{
    "function_name": " extract_uniform_node",
    "params": {
        "self": "&self",
        "node": "NodeT",
        "random_state": "NodeT"
    },
    "return": "NodeT"
}
pub fn extract_node( &self, node: NodeT, random_state: NodeT, walk_weights: &WalkWeights, min_edge_id: EdgeT, max_e