In [None]:
%run parsing.ipynb
%run util.ipynb
%run LocalRepo.ipynb
from typing import Callable
import pdb

### Util

In [None]:
# Java primitives, their boxed counterparts, plus `void` and `String`.
builtin_types = {
    'void', 'String',
    'byte', 'short', 'int', 'long', 'float', 'double', 'boolean', 'char',
    'Byte', 'Short', 'Integer', 'Long', 'Float', 'Double', 'Boolean', 'Character',
}
# Frequently used library names that are not part of the analysed repository.
stl_types = {
    'Override',
    'ArrayList', 'List', 'LinkedList', 'Map', 'HashMap',
    'Object', 'Throwable', 'Exception',
}
# Names in this set are never resolved to an environment (checked in Env.get_env_for_name).
ignored_types = builtin_types | stl_types


# Tree-sitter query that captures every ERROR node of a parse tree.
error_query = JA_LANGUAGE.query("(ERROR) @err")
def _has_error(file) -> bool:
    """Tell whether the parse tree of `file` contains more than one ERROR node.

    Args:
        file: object providing `get_tree()`, whose `root_node` is searched with `error_query`.

    Returns:
        True if the query yields more than one capture, False otherwise.
    """
    errors = error_query.captures(file.get_tree().root_node)
    # NOTE(review): a single ERROR capture is tolerated (threshold is `> 1`) - confirm this is intended
    return len(errors) > 1


# Weights handed to the coupling graph for each kind of structural relation:
STRENGTH_FILE_IMPORT = 1  # a file importing a class (used in couple_files_by_import)
STRENGTH_ACCESS = 1  # a member referring to a resolved type, identifier or field access
STRENGTH_CALL = 1  # a member invoking a resolved method

### Env classes

In [None]:
class Env:
    """Base class of all identifier-lookup environments.

    An environment resolves names to other environments. Subclasses provide the
    actual lookup by overriding `get_env_for_single_name` (and, where it makes
    sense, `get_result_type_env`).
    """

    def __init__(self, context, path):
        self.path = path
        self.context = context

    def get_env_for_name(self, compound_name):
        """Resolve a possibly dotted name step by step.

        Returns None for names listed in `ignored_types` and as soon as one
        step of the dotted name cannot be resolved.
        """
        if compound_name in ignored_types:
            return None

        head, dot, tail = compound_name.partition(".")
        if not dot:
            # plain name without any dot: resolve it directly
            return self.get_env_for_single_name(compound_name)

        head_env = self.get_env_for_single_name(head)
        if head_env is None:
            return None
        # the remainder is resolved relative to the env found for the first step
        return head_env.get_env_for_name(tail)

    def get_env_for_single_name(self, name):
        """Resolve one undotted name; the base implementation only reports and returns None."""
        print("Abstract env cannot resolve a name!")
        return None

    def get_result_type_env(self):
        """Return the env of this env's result type; the base implementation only reports and returns None."""
        print("Abstract env does not have a return value!")
        return None

In [None]:
class RepoTreeEnv(Env):
    """Identifier-lookup environment backed by a node of the repo tree.

    Names are looked up in this order: the node itself, its effective children,
    the imports registered for its path, the envs of its base types and finally
    the env of its parent node.
    """

    def __init__(self, context, node):
        Env.__init__(self, context, node.get_path())
        self.node = node
        tree_root = self.node.get_root()
        # one env per base type that the context knows for this path
        self.base_envs = []
        for base_path in self.context.get_base_types(self.path):
            self.base_envs.append(RepoTreeEnv(self.context, tree_root.find_node(base_path)))

    def get_env_for_single_name(self, name):
        """Resolve a single name, returning a RepoTreeEnv or None if it is unknown."""
        # 1. the node itself
        if name == self.node.name:
            return self

        # 2. direct (effective) children
        children_by_name = self.node.effective_children()
        if name in children_by_name:
            return RepoTreeEnv(self.context, children_by_name[name])

        # 3. imports whose last path segment equals the name
        for import_path in self.context.get_imports(self.path):
            if import_path.endswith("/" + name):
                imported_node = self.node.get_root().find_node(import_path)
                return RepoTreeEnv(self.context, imported_node)

        # 4. everything visible through a base type
        for base_env in self.base_envs:
            inherited = base_env.get_env_for_single_name(name)
            if inherited is not None:
                return inherited

        # 5. the enclosing node, if there is one
        parent_env = self._get_parent_env()
        if parent_env is None:
            return None
        return parent_env.get_env_for_single_name(name)

    def _get_parent_env(self):
        """Return an env for the parent node, or None if this node has no parent."""
        if self.node.parent is not None:
            return RepoTreeEnv(self.context, self.node.parent)
        return None

    def get_result_type_env(self):
        """Return the env of the result type registered for this path.

        Falls back to this env itself if the context knows no result type
        (happens e.g. in "Util.foo()", where "Util" is its own type).
        """
        type_path = self.context.get_result_type(self.path)
        if type_path is None:
            return self
        type_node = self.node.get_root().find_node(type_path)
        return RepoTreeEnv(self.context, type_node)

In [None]:
class NestedEnv(Env):
    """Scope layered on top of a parent env that additionally knows local variables.

    Used for a class, a method, a for loop etc.
    """

    def __init__(self, context, parent_env):
        # a nested env has no path of its own - it should never be referenced
        Env.__init__(self, context, None)
        self.local_vars = {}
        self.parent_env = parent_env

    def add_local_var(self, name, type_text):
        """Register variable `name`, mapped to the env resolved for `type_text` (None if unresolved)."""
        resolved_type = self.get_env_for_name(type_text)
        is_unknown = resolved_type is None and not (type_text in ignored_types)
        if is_unknown:
            print("Unknown type for var:", type_text, name)
        self.local_vars[name] = resolved_type

    def get_env_for_single_name(self, name):
        """Look the name up among the local variables first, then delegate to the parent env."""
        if name not in self.local_vars:
            return self.parent_env.get_env_for_single_name(name)
        return self.local_vars[name]

    def get_result_type_env(self):
        """A nested env has no result type: report that and return None."""
        print("Nexted Env does not have a result type!")
        return None

### Context

In [None]:
# Tree-sitter queries used while building the StructuralContext below.
# Captures the named child node(s) of a package declaration.
package_query = JA_LANGUAGE.query("(package_declaration (_) @decl)")
# Captures the scoped identifier of each import declaration.
import_query = JA_LANGUAGE.query("(import_declaration (scoped_identifier) @decl)")
# Captures the name identifier of every class, interface and enum declaration.
class_query = JA_LANGUAGE.query("[(class_declaration name: (identifier) @decl) (interface_declaration name: (identifier) @decl) (enum_declaration name: (identifier) @decl)]")

class StructuralContext:
    """Structural facts about a repository, collected once at construction time.

    The constructor walks all files from `repo.get_all_interesting_files()` and fills:

    * `full_class_name_to_path`: dotted full name (package + first declared class) -> path of
      that class node, or of the file if no class node could be found
    * `file_path_to_imports`: file path -> paths of its imports that are known classes of this repo
    * `class_path_to_base_class_paths`: class path -> paths of its resolvable base classes / interfaces
    * `path_to_result_type_path`: field / method path -> path of its resolvable declared type
    """
    def __init__(self, repo):
        self.repo = repo
        self.files = repo.get_all_interesting_files()
        self.full_class_name_to_path = {}
        self.file_path_to_imports = {}
        self.path_to_result_type_path = {}
        self.class_path_to_base_class_paths = {}
        
        def _get_package(file) -> List[str]:
            """Return the package of the file as a list of name parts, [] if there is none."""
            packages = package_query.captures(file.get_tree().root_node)
            if len(packages) > 1:
                # report instead of dropping into the debugger; such a file is treated as package-less
                print("Multiple package declarations found!", file.get_path())
            if len(packages) == 1:
                return file.node_text(packages[0][0]).split(".")
            else:
                return []
        def _get_import_strings(file) -> List[str]:
            """Return the text of all imports of the file that do not start with "java"."""
            imports = import_query.captures(file.get_tree().root_node)
            result = []
            for import_statement in imports:
                import_string = file.node_text(import_statement[0])
                if not import_string.startswith("java"):  # ignore stl imports
                    result.append(import_string)
            return result
        def _get_main_class_name(file):
            """Return the name of the first class/interface/enum captured in the file, or None."""
            classes = class_query.captures(file.get_tree().root_node)
            if len(classes) >= 1:
                return file.node_text(classes[0][0])
            else:
                return None
            
        # Pass 1: map the full name of each file's main class to a path
        for file in log_progress(self.files, desc="Building Import Graph", smoothing=0.1):
            if _has_error(file):
                continue
            class_name = _get_main_class_name(file)
            if class_name is not None:
                full_class_name = ".".join(_get_package(file) + [class_name])
                # fall back to the file path; replaced below if the class node can be found
                self.full_class_name_to_path[full_class_name] = file.get_path()
                file_node = file.get_repo_tree_node()
                # the repo tree node may be missing (the later passes guard against that, too)
                class_node = file_node.find_node(class_name) if file_node is not None else None
                if class_node is not None:
                    self.full_class_name_to_path[full_class_name] = class_node.get_path()
                else:
                    print("Cannot find a class node for this file!", file.get_path())
            else:
                print("Cannot find a class in this file!", file.get_path())
        
        # Pass 2: translate the import statements of each file into paths of known classes
        for file in log_progress(self.files, desc="Building Import Graph 2", smoothing=0.1):
            imports = [self.full_class_name_to_path[i] for i in _get_import_strings(file) if i in self.full_class_name_to_path]
            self.file_path_to_imports[file.get_path()] = imports
            
        # Pass 3: base classes per class and declared types per field / method
        for file in log_progress(self.files, desc="Extracting inheritance hierarchy and result types", smoothing=0.1):
            node = file.get_repo_tree_node()
            if node is None:
                continue  # TODO filter those out / parse @interfaces
                
            # TODO keep in sync with evolutionary and linguistic view as well as RepoFile class
            classes = node.get_descendants_of_type("class") + node.get_descendants_of_type("interface") + node.get_descendants_of_type("enum")
            for class_node in classes:
                base_class_nodes = []
                superclass_ts_node = class_node.ts_node.child_by_field_name("superclass")
                if superclass_ts_node is not None:
                    assert superclass_ts_node.child_count == 2 and superclass_ts_node.children[0].type == "extends"
                    base_class_nodes.append(superclass_ts_node.children[1])
                interfaces_ts_node = class_node.ts_node.child_by_field_name("interfaces")
                if interfaces_ts_node is not None:
                    assert interfaces_ts_node.child_count == 2 and interfaces_ts_node.children[0].type == "implements"
                    # keep only named nodes, so that punctuation tokens (anonymous nodes such as ",")
                    # between the listed types are not looked up as type names
                    base_class_nodes += [child for child in interfaces_ts_node.children[1].children if child.is_named]
                        
                if len(base_class_nodes) > 0:
                    base_paths = []
                    for base_class_node in base_class_nodes:
                        extended_class_path = self._resolve_type(file.node_text(base_class_node), node)
                        if extended_class_path is not None:
                            base_paths.append(extended_class_path)
                    if len(base_paths) > 0:
                        self.class_path_to_base_class_paths[class_node.get_path()] = base_paths
                        
                
                fields = class_node.get_children_of_type("field")
                methods = class_node.get_children_of_type("method")
                
                for pathable in (fields + methods):
                    type_node = pathable.ts_node.child_by_field_name("type")
                    if type_node is None:
                        # nothing to resolve without a type node: report it and skip this member
                        print("field/method has no type?", pathable.get_path())
                        continue
                    result_type_text = file.node_text(type_node)
                    result_type_path = self._resolve_type(result_type_text, node)
                    if result_type_path is not None:
                        self.path_to_result_type_path[pathable.get_path()] = result_type_path
                
    def couple_files_by_import(self, coupling_graph):
        """Add a coupling of strength STRENGTH_FILE_IMPORT between each file and each of its imports."""
        for file in log_progress(self.files, desc="Connecting files by imports", smoothing=0.1):
            for imported_class_path in self.file_path_to_imports.get(file.get_path(), []):
                coupling_graph.add_and_support(file.get_path(), imported_class_path, STRENGTH_FILE_IMPORT)
                
    def couple_members_by_content(self, coupling_graph):
        """Couple every field, method and constructor to whatever its code refers to.

        The per-member work is delegated to `couple_member_by_content`; files without a
        repo tree node are skipped.
        """
        # TODO make sure that the methods are also coupled to their parameter types and their return type
        # TODO make also sure that fields are coupled to their type (and maybe their init code content?)
        for file in log_progress(self.files, desc="Connecting methods and fields by content", smoothing=0.1):
            node = file.get_repo_tree_node()
            if node is None:
                continue  # TODO filter those out / parse @interfaces
            classes = node.get_descendants_of_type("class") + node.get_descendants_of_type("interface") + node.get_descendants_of_type("enum")
            for class_node in classes:
                members = class_node.get_children_of_type("field") + class_node.get_children_of_type("method") + class_node.get_children_of_type("constructor")
                for member in members:
                    member_path = member.get_path()
                    def couple_member_to(path, strength):
                        if path is None:
                            # a target without a path cannot be an endpoint in the graph: report and skip
                            print("Cannot couple with nothing!", member_path)
                            return
                        coupling_graph.add_and_support(path, member_path, strength)
                    def get_text(node):
                        if node is None:
                            raise ValueError("node is None, cannot get text!")
                        return file.node_text(node)
                    couple_member_by_content(member, couple_member_to, get_text, self)
    
    def get_result_type(self, path):
        """Return the path of the result type recorded for `path`, or None if unknown."""
        return self.path_to_result_type_path.get(path, None)
    
    def get_base_types(self, path):
        """Return the list of base type paths recorded for the class at `path` ([] if none)."""
        return self.class_path_to_base_class_paths.get(path, [])
    
    def get_imports(self, path):
        """Return the list of imported class paths recorded for the file at `path` ([] if none)."""
        return self.file_path_to_imports.get(path, [])
    
    def _resolve_type(self, type_name, context_file_node):
        """return the full path that is meant by that type_name, or None if not known"""
        # an import of the file whose last path segment is the type name wins ...
        imports = self.file_path_to_imports.get(context_file_node.get_path(), [])
        for import_path in imports:
            if import_path.endswith("/" + type_name):
                return import_path
        # ... otherwise use the regular name lookup starting at the file node
        result_env = RepoTreeEnv(self, context_file_node).get_env_for_name(type_name)
        if result_env is None:
            return None
        return result_env.path

### Functions

In [None]:




# TODOs:
# - detect complex generic types (e.g. List<Foo>) and at least couple to Foo; handling list.get(0).fooMethod() is not required
# - properly handle arrays of things (ignore the .length attribute, but correctly infer the type for further method call resolution)





def couple_member_by_content(
    member: RepoTree,
    couple_member_to: Callable[[str, float], None],
    get_text: Callable[[object], str],
    context: StructuralContext,
) -> None:
    """Couple a member to everything its source code refers to (can be used on methods and fields).

    Walks the tree-sitter subtree below `member.ts_node`, resolves type names, identifiers,
    field accesses and method invocations through the lookup environments and reports each
    successfully resolved target via `couple_member_to`.

    Args:
        member: repo tree node of the member; its `ts_node`, `parent` and path are used.
        couple_member_to: callback receiving `(path, strength)` for each resolved target.
        get_text: callback returning the source text of a tree-sitter node.
        context: structural context from which the lookup environments are built.
    """
    this_env = RepoTreeEnv(context, member.parent)
    member_env = RepoTreeEnv(context, member)
    
    def couple_to(resolved_env, strength):
        """Couple the member to `resolved_env` (unless it is the member itself); return whether it was resolved."""
        # Compare by path, not by object: lookups create fresh env objects, so the member's own
        # env reached through its parent (e.g. via "this.") would otherwise not be recognised
        # and the member would be coupled to itself.
        if resolved_env is not None and resolved_env.path != member_env.path:
            couple_member_to(resolved_env.path, strength)
        return resolved_env is not None
    
    def get_env_for_name(start_env, name):
        """Resolve `name` from `start_env`; a leading "this." is resolved against the enclosing node instead."""
        if name.startswith("this."):
            return this_env.get_env_for_name(name[len("this."):])
        else:
            return start_env.get_env_for_name(name)
    
            
    def iterate_tree(cursor, env) -> Env:
        """Process the node under `cursor` and its subtree.

        Returns the env of the result type of the node where one is resolved (names and method
        invocations), None if that resolution failed, and `env` itself for all other nodes.
        """
        result_env = env
        deeper_env = env
        if cursor.node.type in ["type_identifier", "scoped_type_identifier", "identifier", "scoped_identifier", "field_access"]:
            # a name: resolve it as a whole, its children are not visited
            resolved_type_env = get_env_for_name(env, get_text(cursor.node))
            if couple_to(resolved_type_env, STRENGTH_ACCESS):
                result_env = resolved_type_env.get_result_type_env()
            else:
                result_env = None
        else:
            skip_children_for_iteration = []
            if cursor.node.type == "block":
                # children of a block get their own scope for local variables
                deeper_env = NestedEnv(context, result_env)
            elif cursor.node.type == "method_invocation":
                obj = cursor.node.child_by_field_name("object")
                target_env = env
                if obj is not None:  # call on other object than ourselves?
                    rec_cursor = obj.walk()
                    target_env = iterate_tree(rec_cursor, env)
                    # already handled here, so do not visit it again with the other children
                    skip_children_for_iteration.append(obj)
                if target_env is not None:  # target object resolve success?
                    resolved_method_env = get_env_for_name(target_env, get_text(cursor.node.child_by_field_name("name")))
                    if couple_to(resolved_method_env, STRENGTH_CALL):
                        result_env = resolved_method_env.get_result_type_env()
                    else:
                        result_env = None
            elif cursor.node.type == "local_variable_declaration":
                env.add_local_var(
                    get_text(cursor.node.child_by_field_name("declarator").child_by_field_name("name")),
                    get_text(cursor.node.child_by_field_name("type"))
                )
            elif cursor.node.type == "formal_parameter":
                env.add_local_var(
                    get_text(cursor.node.child_by_field_name("name")),
                    get_text(cursor.node.child_by_field_name("type"))
                )


            # all the rest is structure and needs to be fully iterated
            # TODO if one of those children has the name "body" (or the node type is body?), add a new local env layer
            if cursor.goto_first_child():
                if cursor.node not in skip_children_for_iteration:
                    iterate_tree(cursor, deeper_env)
                while cursor.goto_next_sibling():
                    if cursor.node not in skip_children_for_iteration:
                        iterate_tree(cursor, deeper_env)
                # restore the cursor position for the caller
                cursor.goto_parent()
        return result_env

    # start in a fresh scope on top of the member's own env, so parameters and locals have a place to live
    iterate_tree(member.ts_node.walk(), NestedEnv(context, member_env))
