In [1]:
import ast

def find_global_variable_node(source_code, variable_name):
    tree = ast.parse(source_code)

    class GlobalVariableVisitor(ast.NodeVisitor):
        def __init__(self, variable_name):
            self.variable_name = variable_name
            self.variable_node = None

        def visit_Assign(self, node):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == self.variable_name:
                    self.variable_node = node
                    return

        def visit_Global(self, node):
            if self.variable_name in node.names:
                self.variable_node = node
                return

    visitor = GlobalVariableVisitor(variable_name)
    visitor.visit(tree)

    return visitor.variable_node

In [2]:
import inspect
from dis import Bytecode
from typing import List, Type
from modelos.util.notebook import get_notebook_class_code, get_all_notebook_code

def bytecode_analyzer(inst):
    from types import CodeType, FunctionType, ModuleType
    from importlib import import_module
    from functools import reduce
    
    if inst.opname == "IMPORT_NAME":
        path = inst.argval.split(".")
        path[0] = [import_module(path[0])]
        result = reduce(lambda x, a: x + [getattr(x[-1], a)], path)
        return ("modules", result)
    if inst.opname == "LOAD_GLOBAL":
        if inst.argval in globals() and type(globals()[inst.argval]) in [CodeType, FunctionType]:
            return ("code", globals()[inst.argval])
        if inst.argval in globals() and type(globals()[inst.argval]) == ModuleType:
            return ("modules", [globals()[inst.argval]])
        else:
            return None
    if "LOAD_" in inst.opname and type(inst.argval) in [CodeType, FunctionType]:
        return ("code", inst.argval)
    return None

#https://chapeau.freevariable.com/2017/12/module-frontier.html
def get_local_fn_globals(f) -> List[str]:
    worklist = [f]
    seen = set()
    mods = set()
    global_vars = set()

    for fn in worklist:
        print("fn: ", fn)
        codeworklist = [fn]
        cvs = inspect.getclosurevars(fn)
        print("cvs: ", cvs)
        gvars = cvs.globals
        print("gvars: ", gvars)
        for k, v in gvars.items():
            global_vars.add(k)
            if inspect.isfunction(v) and id(v) not in seen:
                seen.add(id(v))
                mods.add(v.__module__)
                worklist.append(v)
            elif hasattr(v, "__module__"):
                mods.add(v.__module__)
        for block in codeworklist:
            for (k, v) in [bytecode_analyzer(inst) for inst in Bytecode(block) if bytecode_analyzer(inst)]:
                if k == "modules":
                    newmods = [mod.__name__ for mod in v if hasattr(mod, "__name__")]
                    mods.update(set(newmods))
                elif k == "code" and id(v) not in seen:
                    print("!!! v code: ", v)
                    seen.add(id(v))
                    if hasattr(v, "__module__"):
                        mods.add(v.__module__)
                if(inspect.isfunction(v)):
                    worklist.append(v)
                elif(inspect.iscode(v)):
                    codeworklist.append(v)

        # result = list(mods)
        # result.sort()
        return list(global_vars)

In [18]:
from typing import List, Type, Optional, Any, Set, get_args, get_type_hints
from types import NoneType
import inspect
import logging
from inspect import signature
from modelos.util.notebook import get_notebook_class_code, find_import_statements, get_all_notebook_code
from modelos.object.encoding import is_first_order, is_iterable_cls, is_tuple, is_list, is_tuple, is_optional, is_union, is_enum, is_dict, is_set
from unimport.main import Main as unimport_main
from black import format_str, FileMode

def code_for_annotation(t: Type) -> Optional[List[str]]:
    print("checking annotation: ", t)
    ret: List[str] = []
    def _code_for_annotation(t: Type):
        if t == NoneType or t is None or t == Any:
            print("is none")
            pass
        if is_first_order(t):
            print("is fo")
            return
        elif is_tuple(t) or is_list(t) or is_union(t) or is_dict(t) or is_optional(t) or is_iterable_cls(t) or is_set(t):
            print("is many")
            args = get_args(t)
            for arg in args:
                _code_for_annotation(arg)
        elif is_enum(t):
            print("is enum")
            mod = inspect.getmodule(t)
            if mod and mod.__name__ == "__main__":
                try:
                    s = get_notebook_class_code(t)
                    ret.append(s)
                except Exception:
                    pass
        elif hasattr(t, "__annotations__"):
            print("has attr")
            annots = get_type_hints(t)
            for nm, typ in annots.items():
                print("nm: ", nm)
                print("typ: ", typ)
                _code_for_annotation(typ)
            print("getting mod")
            mod = inspect.getmodule(t)
            print("mod: ", mod.__name__)
            if mod and mod.__name__ == "__main__":
                print("mod is main")
                try:
                    s = get_notebook_class_code(t)
                    print("s: ", s)
                    ret.append(s)
                except Exception:
                    pass
        else:
            print("is else")
            mod = inspect.getmodule(t)
            if mod and mod.__name__ == "__main__":
                try:
                    s = get_notebook_class_code(t)
                    print("got notebook code: ", s)
                    ret.append(s)
                except Exception:
                    pass
    
    _code_for_annotation(t)
    if ret:
        return ret
    return None


def get_bases(t: Type) -> Optional[List[Type]]:
    mro = inspect.getmro(t)
    if len(mro) <= 2:
        return None
    mro = mro[1:-1]
    return list(mro)


def get_codes_for_cls(cls: Type) -> List[str]:
    mod = inspect.getmodule(cls)
    if not mod or mod.__name__ != "__main__":
        logging.warning("trying to get code for cls outside of a notebook")
        return ""
    code = get_all_notebook_code()
    fns = inspect.getmembers(cls, predicate=inspect.isfunction)
    print("fns: ", fns)
    methods = inspect.getmembers(cls, predicate=inspect.ismethod)
    print("methods: ", methods)

    fns.extend(methods)

    fin_codes: Set[str] = set()
    for name, fn in fns:
        print("processing: ", name)
        globals = get_local_fn_globals(fn)
        for nm in globals:
            node = find_global_variable_node(code, nm)
            if not node:
                logging.warning(f"could not find ast node for global '{nm}'")
                continue
            cd = ast.unparse(node)
            fin_codes.add(cd)

        sig = signature(fn, eval_str=True, follow_wrapped=True)
        params = sig.parameters
        for nm in params:
            param = params[nm]
            t = param.annotation
            cds = code_for_annotation(t)
            if cds:
                for cd in cds: 
                    fin_codes.add(cd)

        cds = code_for_annotation(sig.return_annotation)
        if cds:
            for cd in cds:
                fin_codes.add(cd)

    bases = get_bases(cls)
    if bases:
        for base in bases:
            print("getting code for base: ", base)
            cds = get_codes_for_cls(base)
            for cd in cds:
                fin_codes.add(cd)

    cls_code = get_notebook_class_code(cls)
    fin_codes.add(cls_code)

    return list(fin_codes)


def write_file_for_cls(cls: Type) -> None:
    ret_code = "# This code was generated by ModelOS\n"
    ret_code += "from __future__ import annotations\n"

    # get all import statements
    for stmt in find_import_statements():
        ret_code += stmt + "\n"

    # get all dependent objects
    for code in get_codes_for_cls(cls):
        ret_code += "\n"
        ret_code += code
        ret_code += "\n"

    ret_code = format_str(ret_code, mode=FileMode())
    
    fp = f"./obj_{cls.__name__.lower()}.py"
    with open(fp, "w+") as f:
        f.write(ret_code)

    unimport_main.run(["-r", fp])
    
    return ret_code

In [19]:
import numpy as np
import simpletransformers.classification

In [20]:
A_VAR = "var"

In [21]:
class Foo:
    def e(self, s: np.ndarray):
        return ""

In [22]:
class A:
    _model: simpletransformers.classification.ClassificationModel

    def b(self, this: str, that: int):
        model = simpletransformers.classification.ClassificationModel(
            "roberta", "roberta-base"
        )
        self._model = model
        return "a"
    
    def c(self, s: str) -> str:
        v = A_VAR
        return v + s
    
    def d(self, foo: Foo):
        return ""
        

In [23]:
class B(A):
    def g(self, i: int) -> None:
        print(i)

In [24]:
write_file_for_cls(A)



fns:  [('b', <function A.b at 0x15f6d04c0>), ('c', <function A.c at 0x15f6d03a0>), ('d', <function A.d at 0x15f6d2050>)]
methods:  []
processing:  b
fn:  <function A.b at 0x15f6d04c0>
cvs:  ClosureVars(nonlocals={}, globals={'simpletransformers': <module 'simpletransformers' from '/Users/patrickbarker/ghq/github.com/aunum/modelos/.venv/lib/python3.10/site-packages/simpletransformers/__init__.py'>}, builtins={}, unbound={'ClassificationModel', '_model', 'classification'})
gvars:  {'simpletransformers': <module 'simpletransformers' from '/Users/patrickbarker/ghq/github.com/aunum/modelos/.venv/lib/python3.10/site-packages/simpletransformers/__init__.py'>}
checking annotation:  <class 'inspect._empty'>
has attr
getting mod
mod:  inspect
checking annotation:  <class 'str'>
is fo
checking annotation:  <class 'int'>
is fo
checking annotation:  <class 'inspect._empty'>
has attr
getting mod
mod:  inspect
processing:  c
fn:  <function A.c at 0x15f6d03a0>
cvs:  ClosureVars(nonlocals={}, globals={

'# This code was generated by ModelOS\nfrom __future__ import annotations\nimport numpy as np\nimport inspect\nimport logging\nfrom unimport.main import Main as unimport_main\nfrom modelos.util.notebook import (\n    get_notebook_class_code,\n    find_import_statements,\n    get_all_notebook_code,\n)\nimport ast\nfrom inspect import signature\nfrom typing import List, Type, Optional, Any, Set, get_args, get_type_hints\nfrom typing import List, Type\nfrom types import CodeType, FunctionType, ModuleType\nfrom functools import reduce\nfrom black import format_str, FileMode\nfrom types import NoneType\nimport simpletransformers.classification\nfrom importlib import import_module\nfrom modelos.util.notebook import get_notebook_class_code, get_all_notebook_code\nfrom modelos.object.encoding import (\n    is_first_order,\n    is_iterable_cls,\n    is_tuple,\n    is_list,\n    is_tuple,\n    is_optional,\n    is_union,\n    is_enum,\n    is_dict,\n    is_set,\n)\nfrom dis import Bytecode\n\n\n

In [25]:
write_file_for_cls(B)



fns:  [('b', <function A.b at 0x15f6d04c0>), ('c', <function A.c at 0x15f6d03a0>), ('d', <function A.d at 0x15f6d2050>), ('g', <function B.g at 0x15f6d27a0>)]
methods:  []
processing:  b
fn:  <function A.b at 0x15f6d04c0>
cvs:  ClosureVars(nonlocals={}, globals={'simpletransformers': <module 'simpletransformers' from '/Users/patrickbarker/ghq/github.com/aunum/modelos/.venv/lib/python3.10/site-packages/simpletransformers/__init__.py'>}, builtins={}, unbound={'ClassificationModel', '_model', 'classification'})
gvars:  {'simpletransformers': <module 'simpletransformers' from '/Users/patrickbarker/ghq/github.com/aunum/modelos/.venv/lib/python3.10/site-packages/simpletransformers/__init__.py'>}
checking annotation:  <class 'inspect._empty'>
has attr
getting mod
mod:  inspect
checking annotation:  <class 'str'>
is fo
checking annotation:  <class 'int'>
is fo
checking annotation:  <class 'inspect._empty'>
has attr
getting mod
mod:  inspect
processing:  c
fn:  <function A.c at 0x15f6d03a0>
cvs

'# This code was generated by ModelOS\nfrom __future__ import annotations\nimport numpy as np\nimport inspect\nimport logging\nfrom unimport.main import Main as unimport_main\nfrom modelos.util.notebook import (\n    get_notebook_class_code,\n    find_import_statements,\n    get_all_notebook_code,\n)\nimport ast\nfrom inspect import signature\nfrom typing import List, Type, Optional, Any, Set, get_args, get_type_hints\nfrom typing import List, Type\nfrom types import CodeType, FunctionType, ModuleType\nfrom functools import reduce\nfrom black import format_str, FileMode\nfrom types import NoneType\nimport simpletransformers.classification\nfrom importlib import import_module\nfrom modelos.util.notebook import get_notebook_class_code, get_all_notebook_code\nfrom modelos.object.encoding import (\n    is_first_order,\n    is_iterable_cls,\n    is_tuple,\n    is_list,\n    is_tuple,\n    is_optional,\n    is_union,\n    is_enum,\n    is_dict,\n    is_set,\n)\nfrom dis import Bytecode\n\n\n

In [None]:
# import ast

# class DependencyFinder(ast.NodeVisitor):
#     def __init__(self):
#         self.dependencies = set()

#     def visit_Name(self, node):
#         self.dependencies.add(node)

#     def visit_Attribute(self, node):
#         if isinstance(node.value, ast.Name):
#             print("node: ", node)
#             self.dependencies.add(node)

# def extract_dependencies(code):
#     tree = ast.parse(code)
#     finder = DependencyFinder()

#     for node in ast.walk(tree):
#         if isinstance(node, ast.ClassDef):
#             for child_node in ast.iter_child_nodes(node):
#                 if isinstance(child_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
#                     finder.visit(child_node)

#     return finder.dependencies
