# Prototype to showcase AST manipulation

This notebook demonstrates a prototype that shows how a movement function can be converted into an `AST` and how its nodes can then be rewritten so that the function reads its data from a `ctx` (context) object.

1. Parse the source string into an AST.
2. Read the function's import statements and log them.
3. Use an `ast.NodeTransformer` to rewrite the nodes that access `param` and `geo`.
4. Replace instances of `param[...]` and `geo[...]` with `ctx.param[...]` and `ctx.geo[...]`.
5. Add `ctx` as the first parameter of the function signature, so that the function reads from the context instead of having `geo` passed in through the namespace.

In [1]:
import ast
from typing import Any

def extract_imports(func_def):
    """
    Extracts import statements from a function definition.

    Only statements directly in the function body are inspected; imports
    nested inside ``if``/``try`` blocks or inner functions are not included.

    Args:
        func_def (ast.FunctionDef): The function definition to extract imports from.

    Returns:
        list[str]: One dotted name per imported alias, in source order.
        ``import a.b as c`` yields ``"a.b"`` and ``from a import b`` yields
        ``"a.b"``. Relative imports keep their leading dots, so
        ``from . import b`` yields ``".b"`` and ``from ..a import b`` yields
        ``"..a.b"``.
    """
    import_statements = []
    for stmt in func_def.body:
        if isinstance(stmt, ast.Import):
            import_statements.extend(alias.name for alias in stmt.names)
        elif isinstance(stmt, ast.ImportFrom):
            # For "from . import x", stmt.module is None and the dots are
            # only recorded in stmt.level, so build the prefix from both.
            prefix = "." * (stmt.level or 0)
            if stmt.module:
                prefix += stmt.module + "."
            import_statements.extend(prefix + alias.name for alias in stmt.names)
    return import_statements


def modify_function_code(input_code):
    """
    Parses a function definition and rewrites it to read ``geo``/``param`` from ``ctx``.

    Every subscript whose target is the bare name ``geo`` or ``param``
    (e.g. ``param['move_control']``) is rewritten to ``ctx.geo[...]`` /
    ``ctx.param[...]``. This includes subscripts nested inside other
    subscripts and those inside inner functions. If any rewrite happened
    inside a top-level function, ``ctx`` is inserted as that function's first
    parameter, unless it already has a parameter with that name. Inner
    functions are not given their own ``ctx`` parameter; they use the outer
    one through their closure. The function's import statements are printed
    as a side effect.

    Args:
        input_code (str): Source code whose first statement is a function definition.

    Returns:
        ast.Module: The transformed tree. Missing source locations are filled
        in, so the tree can be compiled directly.

    Raises:
        ValueError: If the code is empty or its first statement is not a
            function definition.
    """
    tree = ast.parse(input_code, '<string>', mode='exec')
    # Guard against empty input as well as a non-function first statement.
    f_def = tree.body[0] if tree.body else None

    if not isinstance(f_def, ast.FunctionDef):
        raise ValueError("Code does not define a valid function")
    import_statements = extract_imports(f_def)

    print("Import statements:")
    for import_statement in import_statements:
        print(import_statement)

    class Transformer(ast.NodeTransformer):
        # Names whose subscripts are redirected through the ctx object.
        CONTEXT_NAMES = ('geo', 'param')

        def __init__(self):
            super().__init__()
            self.modified = False
            # Nesting level of the FunctionDef nodes currently being visited.
            self.depth = 0

        def visit_Subscript(self, node):
            # Transform the children first, so that expressions such as
            # x[param['i']] or geo['a'][param['b']] get their inner subscripts
            # rewritten too.
            self.generic_visit(node)
            if isinstance(node.value, ast.Name) and node.value.id in self.CONTEXT_NAMES:
                node.value = ast.Attribute(
                    value=ast.Name(id='ctx', ctx=ast.Load()),
                    attr=node.value.id,
                    ctx=ast.Load(),
                )
                self.modified = True
            return node

        def visit_FunctionDef(self, node):
            outermost = self.depth == 0
            if outermost:
                # Each top-level function decides independently whether it needs ctx;
                # nested defs must not reset the flag of the function enclosing them.
                self.modified = False
            self.depth += 1
            self.generic_visit(node)
            self.depth -= 1
            if outermost and self.modified:
                self._add_ctx_param(node.args)
            return node

        @staticmethod
        def _add_ctx_param(args):
            """Insert ``ctx`` as the first parameter unless a parameter named ``ctx`` exists."""
            names = {a.arg for a in (*args.posonlyargs, *args.args, *args.kwonlyargs)}
            names.update(a.arg for a in (args.vararg, args.kwarg) if a is not None)
            if 'ctx' in names:
                # Adding it again would be a "duplicate argument" SyntaxError.
                return
            # With positional-only parameters, ctx must go there to stay first.
            target = args.posonlyargs if args.posonlyargs else args.args
            target.insert(0, ast.arg(arg='ctx', annotation=None))

    tree = Transformer().visit(tree)
    # The inserted nodes carry no line numbers, which compile() requires.
    ast.fix_missing_locations(tree)

    return tree
    

def ast_to_callable(tree, namespace=None):
    """
    Compiles the first function definition in ``tree`` and returns it as a callable.

    Args:
        tree (ast.Module): Module whose first statement is the function to compile.
        namespace (dict[str, Any] | None): Optional global names the function
            may reference, e.g. ``{'np': numpy}``. The mapping is copied, so
            the caller's dict is not modified. Defaults to no extra names.

    Returns:
        Callable: The compiled function.

    Raises:
        ValueError: If ``tree`` does not start with a function definition.
    """
    f_def = tree.body[0] if tree.body else None
    if not isinstance(f_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
        raise ValueError("Tree does not start with a function definition")
    module = ast.Module(body=[f_def], type_ignores=[])
    # Nodes added by transformers often lack line numbers, which compile() rejects.
    ast.fix_missing_locations(module)
    code = compile(module, '<string>', mode='exec')
    # Use a single dict as both globals and locals so the function is defined
    # in its own globals. With separate dicts, the function lands only in the
    # locals, and a recursive self-call raises NameError.
    global_namespace: dict[str, Any] = dict(namespace) if namespace else {}
    exec(code, global_namespace)
    return global_namespace[f_def.name]

# Sample movement function: it reads `param` (which the transformer rewrites
# to ctx.param), reads `predef` (left untouched), and calls numpy through `np`.
input_code = """def commuters(t):
    import pandas as pd
    typical = predef['commuters_by_node']
    actual = np.binomial(typical, param['move_control'])
    return np.multinomial(actual, predef['commuting_probability'])"""

# Rewrite the function to take and read from ctx. fix_missing_locations
# mutates the tree in place and returns that same tree, so bind the result directly.
modified_tree = ast.fix_missing_locations(modify_function_code(input_code))

# Turn the transformed tree back into source so the rewrite can be inspected.
generated_code = ast.unparse(modified_tree)

print("Input code:", input_code, sep="\n")
print("Transformed code:", generated_code, sep="\n")

# Compile only. Calling it would additionally require `predef` and `np`
# to be resolvable from the function's globals.
func = ast_to_callable(modified_tree)


Import statements:
pandas
Input code:
def commuters(t):
    import pandas as pd
    typical = predef['commuters_by_node']
    actual = np.binomial(typical, param['move_control'])
    return np.multinomial(actual, predef['commuting_probability'])
Transformed code:
def commuters(ctx, t):
    import pandas as pd
    typical = predef['commuters_by_node']
    actual = np.binomial(typical, ctx.param['move_control'])
    return np.multinomial(actual, predef['commuting_probability'])


## Simulation Analysis

### Compiling the function at every time step

Loop for `N` time steps:

1. Make_Context
2. Make_Namespace
3. Parse_Function(String_Code)
4. Compile_Function(Function_Def)
5. Run_Function(t)

In [2]:
from unittest.mock import MagicMock
import numpy as np
from epymorph.clock import Clock
from epymorph.context import SimContext
from epymorph.geo.geo import Geo
from epymorph.movement.dynamic import make_global_namespace
from epymorph.util import compile_function, parse_function

import time

# Give ctx.rng a Generator instance rather than the np.random.Generator class.
# NOTE(review): presumably SimContext.rng holds an instance; confirm.
rng = np.random.default_rng()
steps = range(1, 1000)

start = time.perf_counter()
for t in steps:
    # Make_Context: a mock SimContext populated with representative values.
    ctx = MagicMock(spec=SimContext)
    ctx.geo = MagicMock(spec=Geo)
    ctx.compartments = 11
    ctx.compartment_tags = [[], [], [], [], [], [], [], [], [], [], []]
    ctx.events = 14
    ctx.param = {
        'move_control': 1,
        'theta': np.array(0.1),
        'omega': np.array([0.55, 0.05]),
        'delta': np.array([0.333, 0.5, 0.166, 0.142, 0.125]),
        'gamma': np.array([0.166, 0.333, 0.25]),
        'rho': np.array([0.4, 0.175, 0.015, 0.2, 0.6]),
        'beta': np.array([0.62813624, 0.54896858, 0.40599664, 0.49297194, 0.52055049, 0.41317462]),
        'phi': np.array(0.1)
    }
    ctx.nodes = 6
    ctx.ticks = 730
    ctx.days = 365
    ctx.TNCS = (730, 6, 11, 14)
    ctx.clock = MagicMock(Clock)
    ctx.rng = rng

    # Make_Namespace
    ns = make_global_namespace(ctx)

    # Parse_Function + Compile_Function: this strategy repeats both on every step.
    input_code = """def commuters(t):
        import pandas as pd
        actual = param['move_control'] * t
        return actual"""

    f_def = parse_function(input_code)
    c_function = compile_function(f_def, ns)

    # Run_Function
    c_function(t)
elapsed = time.perf_counter() - start

print(f"Compile every step: {elapsed:.3f} s for {len(steps)} steps "
      f"({elapsed / len(steps) * 1e6:.1f} us/step)")


### Compiling the function once and passing `ctx` as a parameter

1. modify_function_code(input_code)
2. ast_to_callable(modified_tree)
3. Loop for `N` timesteps:

    -     Make_Context
    -     Run_Function(ctx,t)

In [3]:
import time
from unittest.mock import MagicMock
import numpy as np
from epymorph.clock import Clock
from epymorph.context import SimContext
from epymorph.geo.geo import Geo


input_code = """def commuters(t):
    import pandas as pd
    actual = param['move_control'] * t
    return actual"""

# Give ctx.rng a Generator instance rather than the np.random.Generator class.
# NOTE(review): presumably SimContext.rng holds an instance; confirm.
rng = np.random.default_rng()
steps = range(1, 1000)

# The timed region includes the one-time rewrite + compile, so the total is
# comparable with the compile-every-step strategy. (It also includes the
# "Import statements" printing done by modify_function_code.)
start = time.perf_counter()
modified_tree = modify_function_code(input_code)
ast.fix_missing_locations(modified_tree)
c_function = ast_to_callable(modified_tree)

for t in steps:
    # Make_Context: a mock SimContext populated with representative values.
    ctx = MagicMock(spec=SimContext)
    ctx.geo = MagicMock(spec=Geo)
    ctx.compartments = 11
    ctx.compartment_tags = [[], [], [], [], [], [], [], [], [], [], []]
    ctx.events = 14
    ctx.param = {
        'move_control': 1,
        'theta': np.array(0.1),
        'omega': np.array([0.55, 0.05]),
        'delta': np.array([0.333, 0.5, 0.166, 0.142, 0.125]),
        'gamma': np.array([0.166, 0.333, 0.25]),
        'rho': np.array([0.4, 0.175, 0.015, 0.2, 0.6]),
        'beta': np.array([0.62813624, 0.54896858, 0.40599664, 0.49297194, 0.52055049, 0.41317462]),
        'phi': np.array(0.1)
    }
    ctx.nodes = 6
    ctx.ticks = 730
    ctx.days = 365
    ctx.TNCS = (730, 6, 11, 14)
    ctx.clock = MagicMock(Clock)
    ctx.rng = rng

    # Run_Function: the already-compiled function receives the context explicitly.
    c_function(ctx, t)
elapsed = time.perf_counter() - start

generated_code = ast.unparse(modified_tree)

print("Input code:")
print(input_code)

print("Transformed code:")
print(generated_code)

print(f"Compile once: {elapsed:.3f} s for {len(steps)} steps "
      f"({elapsed / len(steps) * 1e6:.1f} us/step)")


Import statements:
pandas
Input code:
def commuters(t):
    import pandas as pd
    actual = param['move_control'] * t
    return actual
Transformed code:
def commuters(ctx, t):
    import pandas as pd
    actual = ctx.param['move_control'] * t
    return actual
