In [None]:
%run parsing.ipynb
%run util.ipynb
%run LocalRepo.ipynb
from typing import Callable
import pdb

### Util

In [None]:
# Java primitive, boxed and String type names that are never resolved to an environment.
builtin_types = {
    'void', 'String',
    'byte', 'short', 'int', 'long', 'float', 'double', 'boolean', 'char',
    'Byte', 'Short', 'Integer', 'Long', 'Float', 'Double', 'Boolean', 'Character',
}
# Frequently used standard-library type names that are skipped as well.
stl_types = {'ArrayList', 'List', 'LinkedList', 'Map', 'HashMap', 'Object'}
# Every type name for which name lookup is skipped.
ignored_types = builtin_types | stl_types


# Query capturing every ERROR node of a parsed syntax tree.
error_query = JA_LANGUAGE.query("(ERROR) @err")


def _has_error(file) -> bool:
    """Return True when the parse tree of `file` contains more than one ERROR capture.

    `file` must offer `get_tree()`, whose result exposes a `root_node`.

    NOTE(review): a single ERROR capture is tolerated by the `> 1` threshold;
    confirm whether `> 0` was intended.
    """
    errors = error_query.captures(file.get_tree().root_node)
    return len(errors) > 1


### Env classes

In [None]:
class Env:
    """Base class of all identifier-lookup environments.

    Subclasses override `get_env_for_single_name` (and optionally
    `get_result_type_env`); the base implementations only print a notice and
    return None.
    """

    def __init__(self, path):
        self.path = path

    def get_env_for_name(self, compound_name):
        """Resolve a possibly dotted name, one segment at a time.

        Returns the environment the name resolves to, or None when the name is
        an ignored type or any segment cannot be resolved.
        """
        if compound_name in ignored_types:
            return None

        head, separator, tail = compound_name.partition(".")
        if not separator:
            # plain name without any dot
            return self.get_env_for_single_name(compound_name)

        head_env = self.get_env_for_single_name(head)
        if head_env is None:
            return None
        # the remaining segments are resolved inside the env of the first one
        return head_env.get_env_for_name(tail)

    def get_env_for_single_name(self, name):
        """Resolve a single (dot-free) name; the base class cannot and yields None."""
        print("Abstract env cannot resolve a name!")
        return None

    def get_result_type_env(self):
        """Environment of this env's result type; the base class has none."""
        print("Abstract env does not have a return value!")
        return None

In [None]:
class RepoTreeEnv(Env):
    """Identifier-lookup environment backed by a node of the repo tree.

    Names are looked up in the node itself, then in its children, then
    recursively in the environments of its ancestors.

    Two instances wrapping the same node compare equal. Lookups create fresh
    wrappers, so identity alone cannot tell whether two results refer to the
    same node.
    """

    def __init__(self, node):
        Env.__init__(self, node.get_path())
        self.node = node

    def __eq__(self, other):
        if not isinstance(other, RepoTreeEnv):
            return NotImplemented
        return self.node is other.node

    def __hash__(self):
        # keep instances hashable and consistent with __eq__
        return hash(id(self.node))

    def get_env_for_single_name(self, name):
        """Resolve `name` against this node, its children, then its ancestors.

        Returns a RepoTreeEnv for the matching node, or None if no node on the
        way up to the root matches.
        """
        if name == self.node.name:
            return self
        if self.node.has_child(name):
            return RepoTreeEnv(self.node.children[name])
        parent_env = self._get_parent_env()
        if parent_env is not None:
            return parent_env.get_env_for_single_name(name)
        return None

    def _get_parent_env(self):
        """Return the environment of the parent node, or None at the root."""
        if self.node.parent is None:
            return None
        return RepoTreeEnv(self.node.parent)

    def get_result_type_env(self):
        """Not implemented yet: prints a notice and returns None."""
        print("TODO implement this shit!")
        return None

In [None]:
class NestedEnv(Env):
    """Local scope stacked on top of a parent environment.

    Used for a class, a method, a for loop etc. Local variables map their name
    to the environment of their declared type; any other name is delegated to
    `parent_env`.
    """

    def __init__(self, parent_env):
        # a nested scope has no path of its own - it should never be referenced
        super().__init__(None)
        self.local_vars = {}
        self.parent_env = parent_env

    def add_local_var(self, name, type_text):
        """Register local variable `name` whose declared type is `type_text`.

        The type is resolved before the variable is stored; an unresolved type
        is stored as None and reported unless it is an ignored type.
        """
        resolved = self.get_env_for_name(type_text)
        unresolved = resolved is None
        if unresolved and type_text not in ignored_types:
            print("Unknown type for var:", type_text, name)
        self.local_vars[name] = resolved

    def get_env_for_single_name(self, name):
        """Look `name` up among the local variables first, then in the parent."""
        try:
            return self.local_vars[name]
        except KeyError:
            pass
        return self.parent_env.get_env_for_single_name(name)

    def get_result_type_env(self):
        """A nested scope has no result type; prints a notice and returns None."""
        print("Nexted Env does not have a result type!")
        return None

### Context

In [None]:
class StructuralContext:
    """Collects which classes the files of a repo declare and which they import."""

    def __init__(self, repo):
        self.repo = repo
        # fully qualified class name ("pkg.sub.Class") -> path of that class (or of its file)
        self.full_class_name_to_path = {}
        # file object -> list of paths of the known classes that file imports
        self.imports_per_file = {}

    def load_files(self, files):
        """Fill `full_class_name_to_path` and `imports_per_file` from `files`.

        Pass 1 registers the first class of every file without parse errors
        under its fully qualified name. Pass 2 maps the import statements of
        every file to the paths registered in pass 1; imports of unknown
        classes are dropped.
        """
        def _get_package(file) -> List[str]:
            """Return the package name split at dots, or [] unless exactly one is declared."""
            packages = package_query.captures(file.get_tree().root_node)
            if len(packages) > 1:
                print("Multiple packet declarations found!")
            if len(packages) == 1:
                return file.node_text(packages[0][0]).split(".")
            return []

        def _get_import_strings(file) -> List[str]:
            """Return the imported names of `file`, without JDK (`java.` / `javax.`) imports."""
            imports = import_query.captures(file.get_tree().root_node)
            import_strings = (file.node_text(import_statement[0]) for import_statement in imports)
            # Match whole package segments only, so that project packages such
            # as "javaparser.*" are not mistaken for JDK imports.
            return [s for s in import_strings if not s.startswith(("java.", "javax."))]

        def _get_main_class_name(file):
            """Return the text of the first class capture of `file`, or None if there is none."""
            classes = class_query.captures(file.get_tree().root_node)
            if len(classes) >= 1:
                return file.node_text(classes[0][0])
            return None

        for file in log_progress(files, desc="Building Import Graph", smoothing=0.1):
            if _has_error(file):
                continue
            class_name = _get_main_class_name(file)
            if class_name is not None:
                full_class_name = ".".join(_get_package(file) + [class_name])
                # fall back to the file path in case no class node can be found below
                self.full_class_name_to_path[full_class_name] = file.get_path()
                class_node = file.get_repo_tree_node().find_node(class_name)
                if class_node is not None:
                    self.full_class_name_to_path[full_class_name] = class_node.get_path()
                else:
                    print("Cannot find a class node for this file!", file.get_path())
            else:
                print("Cannot find a class in this file!", file.get_path())

        for file in log_progress(files, desc="Extracting connections", smoothing=0.1):
            imports = [self.full_class_name_to_path[i] for i in _get_import_strings(file) if i in self.full_class_name_to_path]
            self.imports_per_file[file] = imports

        # TODO after this, find all the result_types of all the members / methods

    def couple_files_by_import(self):
        """Couple every loaded file to each class path it imports.

        Uses the imports collected by `load_files`.

        NOTE(review): `coupling_graph` and `STRENGTH_FILE_IMPORT` are not
        defined in this notebook section; they are expected to exist as
        globals - confirm.
        """
        for file, imported_class_paths in self.imports_per_file.items():
            for imported_class_path in imported_class_paths:
                coupling_graph.add_and_support(file.get_path(), imported_class_path, STRENGTH_FILE_IMPORT)

### Functions

In [None]:
# Strength reported to `couple_method_to` when a method references a resolved identifier or type.
STRENGTH_ACCESS = 1
# Strength reported to `couple_method_to` when a method invokes a resolved method.
STRENGTH_CALL = 1





# TODOs:
# add inheritance to the env-lookup mechanism
# after inheritance, add the imports of a file to the env-lookup
# store the return types of methods / types of fields (commonly known as result_type?) in the repo_tree
# detect complex generic types (List<Foo>) and at least couple to Foo; do not handle list.get(0).fooMethod()
# properly handle arrays of things (ignore the .length attribute, but correctly infer the type for further method call resolving)





def couple_method_by_content(
    method: RepoTree,
    couple_method_to: Callable[[str, float], None],
    get_text: Callable[[object], str],
) -> None:
    """Couple `method` to everything its body references or calls.

    Walks the syntax tree below `method.ts_node`, resolves identifiers, type
    names and method invocations through the lookup environments, and reports
    every successful resolution.

    Args:
        method: repo tree node of the method; its `ts_node` and `parent` are used.
        couple_method_to: called with `(path, strength)` for every resolved
            target other than the method itself.
        get_text: returns the source text of a syntax node.
    """
    print("Now handling:\n", get_text(method.ts_node))

    this_env = RepoTreeEnv(method.parent)
    method_env = RepoTreeEnv(method)

    def couple_to(resolved_env, strength):
        """Report a coupling unless unresolved or self-referential; return True if resolved."""
        if resolved_env not in [None, method_env]:
            couple_method_to(resolved_env.path, strength)
            print("Coupled to", resolved_env.path)
        return resolved_env is not None

    def get_env_for_name(start_env, name):
        """Resolve `name` from `start_env`; names prefixed with `this.` start at the enclosing node."""
        if name.startswith("this."):
            return this_env.get_env_for_name(name[len("this."):])
        else:
            return start_env.get_env_for_name(name)

    def iterate_tree(cursor, env) -> Env:
        """Process the node under `cursor` and its subtree.

        Returns the environment describing the type of the visited expression,
        or None if it could not be resolved.
        """
        result_env = env
        deeper_env = env
        if cursor.node.type in ["type_identifier", "scoped_type_identifier", "identifier", "scoped_identifier", "field_access"]:
            resolved_type_env = get_env_for_name(env, get_text(cursor.node))
            couple_to(resolved_type_env, STRENGTH_ACCESS)
            result_env = resolved_type_env
        else:
            if cursor.node.type == "block":
                # a block opens a new scope for local variables
                deeper_env = NestedEnv(result_env)
            elif cursor.node.type == "method_invocation":
                obj = cursor.node.child_by_field_name("object")
                target_env = env
                if obj is not None:  # call on other object than ourselves?
                    rec_cursor = obj.walk()
                    target_env = iterate_tree(rec_cursor, env)
                if target_env is not None:  # target object resolve success?
                    resolved_method_env = get_env_for_name(target_env, get_text(cursor.node.child_by_field_name("name")))
                    if couple_to(resolved_method_env, STRENGTH_CALL):
                        result_env = resolved_method_env.get_result_type_env()
                    else:
                        result_env = None
                else:
                    # Unknown receiver means unknown result type. Returning the
                    # current env here would let a chained call resolve against
                    # unrelated members of the surrounding scope.
                    result_env = None
            elif cursor.node.type == "local_variable_declaration":
                env.add_local_var(
                    get_text(cursor.node.child_by_field_name("declarator").child_by_field_name("name")),
                    get_text(cursor.node.child_by_field_name("type"))
                )
            elif cursor.node.type == "formal_parameter":
                env.add_local_var(
                    get_text(cursor.node.child_by_field_name("name")),
                    get_text(cursor.node.child_by_field_name("type"))
                )

            # all the rest is structure and needs to be fully iterated
            # NOTE(review): for a method_invocation this visits the object and
            # name children again after they were handled above - confirm
            # whether the double visit is intended.
            # TODO if one of those children has the name "body" (or the node type is body?), add a new local env layer
            if cursor.goto_first_child():
                iterate_tree(cursor, deeper_env)
                while cursor.goto_next_sibling():
                    iterate_tree(cursor, deeper_env)
                cursor.goto_parent()
        return result_env

    iterate_tree(method.ts_node.walk(), NestedEnv(method_env))
