In [18]:
from git import Git, Repo, RemoteProgress
import os
import subprocess
import pdb
from typing import List

In [19]:
# Prefix and suffix that are wrapped around a repository name to build its
# clone URL (prefix + name + suffix).
REPO_URL_START = "https://github.com/"
REPO_URL_END = ".git"
# Directory prefix for local clones. The repository name is appended by plain
# string concatenation, so the trailing slash is required.
REPO_CLONE_PATH = "../repos/"

In [20]:
# TODO progress display: https://stackoverflow.com/questions/38861829/how-do-i-implement-a-progress-bar

In [23]:
# https://gitpython.readthedocs.io/en/stable/reference.html
class LocalRepo:
    """Local bare clone of a remote repository, identified by its name.

    The name is the part of the clone URL between REPO_URL_START and
    REPO_URL_END; the same name is appended to REPO_CLONE_PATH to get the
    location of the clone on disk.
    """

    def __init__(self, name):
        self.tree = None  # RepoTree cache, filled on the first get_tree() call
        self.name = name

    def update(self):
        """Bring the local clone up to date, cloning it first if necessary."""
        if self.is_cloned():
            print("updating " + self.name + "...")
            self.pull()
        else:
            print("cloning " + self.name + ", this may take a while...")
            self.clone()
        print("Repo is up to date!")

    def pull(self):
        """Fetch all remote branch heads into the local branch heads."""
        # A plain pull does not work on a bare clone, hence the explicit refspec:
        # https://github.com/gitpython-developers/GitPython/issues/296#issuecomment-449769231
        remote = Repo(self.path()).remote()
        remote.fetch("+refs/heads/*:refs/heads/*")

    def clone(self):
        """Create the local clone as a bare repository."""
        Repo.clone_from(self.url(), self.path(), multi_options=["--bare"])

    def is_cloned(self):
        """Return True if the local clone directory exists."""
        clone_dir = self.path()
        return os.path.isdir(clone_dir)

    def path(self):
        """Return the path of the local clone."""
        return REPO_CLONE_PATH + self.name

    def url(self):
        """Return the URL the repository is cloned from."""
        return "".join((REPO_URL_START, self.name, REPO_URL_END))

    def type_extension(self):
        """Return the file extension (without dot) of the analyzed language."""
        # return the file extension that the files of your language have
        return "java"

    def get_file_objects(self, commit_hash = None):
        """Return all blobs of a commit whose name has the language's extension.

        With commit_hash left at None the commit that HEAD points to is used.
        """
        repo = Repo(self.path())
        commit = repo.head.commit if commit_hash is None else repo.commit(commit_hash)
        suffix = "." + self.type_extension()
        return [
            entry
            for entry in commit.tree.traverse()
            if entry.type == "blob" and entry.name.endswith(suffix)
        ]

    def get_file_object_content(self, git_object):
        """Read and return everything from the data stream of a git object."""
        stream = git_object.data_stream
        return stream.read()

    def get_all_commits(self):
        """Return the commit hashes printed by `git log`, one list entry per line."""
        log_output = Git(self.path()).log("--pretty=%H")
        return log_output.split("\n")

    def get_commit(self, sha):
        """Return the commit object for the given hash."""
        repo = Repo(self.path())
        return repo.commit(sha)

    def get_tree(self):
        """Return the RepoTree of this repository, building it on first use."""
        if self.tree is not None:
            return self.tree
        self.tree = RepoTree.init_from_repo(self)
        return self.tree

In [None]:
class RepoFile:
    """A single file of a repository with lazily loaded content and syntax tree.

    repo must provide get_file_object_content() and get_tree(); file_obj is the
    git object of the file and must provide a path attribute.
    """

    def __init__(self, repo, file_obj):
        self.repo = repo
        self.file_obj = file_obj
        self.content = None  # raw bytes of the file, filled by get_content()
        self.tree = None  # parsed syntax tree, filled by get_tree()

    def get_path(self):
        """Return the path of the file as stored on its git object."""
        return self.file_obj.path

    def get_content(self):
        """Return the raw bytes of the file, reading them only once."""
        if self.content is None:
            self.content = self.repo.get_file_object_content(self.file_obj)
        return self.content

    def get_content_without_copyright(self):
        """Return the file content as text, minus a leading comment.

        If the first child of the syntax tree's root is a comment (presumably
        the copyright / license header), everything up to its end is dropped.
        A file whose root has no children is returned unchanged.
        """
        content = self.get_content()
        root_children = self.get_tree().root_node.children
        # Guard against an empty root: indexing [0] would raise IndexError.
        if root_children and root_children[0].type == "comment":
            content = content[root_children[0].end_byte:]
        return content.decode("utf-8")

    def get_repo_tree_node(self):
        """Return the node for this file in the repository's tree, or None."""
        return self.repo.get_tree().find_node(self.get_path())

    def get_tree(self):
        """Return the syntax tree of the file, parsing it only once."""
        if self.tree is None:
            # java_parser is not defined in this class; it has to exist at
            # module level before this method is called.
            self.tree = java_parser.parse(self.get_content())
        return self.tree

    def node_text(self, node):
        """Return the source text covered by a syntax tree node."""
        # Go through get_content() so this also works before the content has
        # been loaded; self.content is None until then.
        return self.get_content()[node.start_byte:node.end_byte].decode("utf-8")

    def walk_tree(self, node_handler):
        """ node_handler gets the current logic-path and node for each ast node"""
        self.walk_tree_cursor(self.get_tree().walk(), self.get_path(), node_handler)

    def _declared_name(self, node):
        """Return the name a class/interface/method/field declaration introduces, else None."""
        if node.type in ("class_declaration", "interface_declaration", "method_declaration"):
            name_node = node.child_by_field_name("name")
        elif node.type == "field_declaration":
            declarator = node.child_by_field_name("declarator")
            name_node = None if declarator is None else declarator.child_by_field_name("name")
        else:
            return None
        return None if name_node is None else self.node_text(name_node)

    def walk_tree_cursor(self, cursor, prefix, node_handler):
        """Visit the cursor's node and its named descendants depth-first.

        Every class, interface, method and field declaration extends prefix by
        "/" + its name and is reported as node_handler(path, node). The cursor
        is back on its starting node when the call returns.
        """
        node = cursor.node
        if not node.is_named:
            return
        # cursor.current_field_name() is the role that this node has in its parent
        tree_node_name = self._declared_name(node)

        # Skip missing or empty names: an empty name would put an empty
        # segment into the path handed to the handler.
        if tree_node_name:
            prefix = prefix + "/" + tree_node_name
            node_handler(prefix, node)

        if cursor.goto_first_child():
            self.walk_tree_cursor(cursor, prefix, node_handler)
            while cursor.goto_next_sibling():
                self.walk_tree_cursor(cursor, prefix, node_handler)
            cursor.goto_parent()

In [None]:
class RepoTree:
    """Node of a tree that is addressed by "/"-separated paths.

    Every node has a name, a dict of children keyed by name and optionally the
    syntax tree node (ts_node) it was registered with. The root node has no
    parent and an empty name.
    """

    @staticmethod
    def init_from_repo(repo) -> 'RepoTree':
        """Build the tree of all declarations found in the files of repo.

        repo must provide get_file_objects(); each returned file is wrapped in
        a RepoFile and walked, and every reported path is registered.
        """
        found_nodes = RepoTree(None, "")
        files = repo.get_file_objects()
        print("Analyzing " + str(len(files)) + " files...")
        for file_obj in files:
            # Use the repo that was passed in, not a notebook-global variable.
            RepoFile(repo, file_obj).walk_tree(found_nodes.register)
        print("Found " + str(found_nodes.node_count()) + " classes, methods and fields!")
        return found_nodes

    def __init__(self, parent, name, ts_node = None):
        # Only the root (parent is None) may have an empty name; an empty name
        # below the root means a malformed path such as "a//b".
        if parent is not None and len(name) == 0:
            raise ValueError("non-root RepoTree node needs a non-empty name")
        self.parent = parent
        self.name = name
        self.ts_node = ts_node
        self.children = {}

    def get_path(self):
        """Return the "/"-joined names from below the root down to this node."""
        if self.parent is None or len(self.parent.name) == 0:
            return self.name
        return self.parent.get_path() + "/" + self.name

    def register(self, path, ts_node):
        """Create the node at path (and missing ancestors) and attach ts_node to it."""
        self.register_list(path.split("/"), ts_node)

    def register_list(self, path_segments, ts_node):
        """Like register(), but takes the path as a list of segments.

        Raises:
            Exception: if path_segments is empty.
        """
        if len(path_segments) == 0:
            raise Exception("Should not reach here!")
        node = self
        # Ancestors are created without a syntax node, only the last segment gets one.
        for segment in path_segments[:-1]:
            node = node.register_child(segment, None)
        node.register_child(path_segments[-1], ts_node)

    def register_child(self, name, ts_node) -> 'RepoTree':
        """Return the child called name, creating it if it does not exist.

        An existing child keeps its syntax node. If it has none yet (it was
        created as an ancestor of another path), ts_node is attached to it.
        """
        child = self.children.get(name)
        if child is None:
            child = RepoTree(self, name, ts_node)
            self.children[name] = child
        elif child.ts_node is None and ts_node is not None:
            child.ts_node = ts_node
        return child

    def find_node(self, path) -> 'RepoTree':
        """Return the node at path relative to this node, or None if there is none.

        An empty path refers to this node itself.
        """
        if len(path) == 0:
            return self
        return self.find_node_list(path.split("/"))

    def find_node_list(self, path_segments) -> 'RepoTree':
        """Like find_node(), but takes the path as a list of segments."""
        node = self
        for segment in path_segments:
            node = node.children.get(segment)
            if node is None:
                return None
        return node

    def get_type(self) -> str:
        """Return the syntax node type without a trailing "_declaration".

        Returns None for nodes that have no syntax node attached.
        """
        if self.ts_node is None:
            return None
        node_type = self.ts_node.type
        if node_type.endswith("_declaration"):
            node_type = node_type[0:-len("_declaration")]
        return node_type

    def get_children_of_type(self, type_str) -> List['RepoTree']:
        """Return the direct children whose get_type() equals type_str."""
        return [c for c in self.children.values() if c.get_type() == type_str]

    def get_descendants_of_type(self, type_str) -> List['RepoTree']:
        """Return all nodes below this one whose get_type() equals type_str.

        Matching direct children come first, followed by the matches of each
        child's subtree.
        """
        found = self.get_children_of_type(type_str)
        for child in self.children.values():
            found.extend(child.get_descendants_of_type(type_str))
        return found

    def has(self, path) -> bool:
        """Return True if a node exists at path relative to this node."""
        return self.has_list(path.split("/"))

    def has_list(self, path_segments) -> bool:
        """Like has(), but takes the path as a list of segments.

        An empty list yields False.
        """
        if not path_segments:
            return False
        node = self
        for segment in path_segments:
            if not node.has_child(segment):
                return False
            # children is a dict and must be indexed, not called.
            node = node.children[segment]
        return True

    def has_child(self, name) -> bool:
        """Return True if this node has a direct child called name."""
        return name in self.children

    def to_json(self) -> str:
        """Return this subtree as compact JSON with "name" and "children" keys.

        The "children" key is omitted for nodes without children.
        """
        import json

        # json.dumps escapes quotes and backslashes inside the name so the
        # result stays valid JSON; ensure_ascii=False keeps other characters as-is.
        name_json = json.dumps(self.name, ensure_ascii=False)
        if len(self.children) == 0:
            return '{"name":' + name_json + '}'
        child_json = ",".join([c.to_json() for c in self.children.values()])
        return '{"name":' + name_json + ',"children":[' + child_json + ']}'

    def node_count(self) -> int:
        """Return the number of nodes in this subtree, including this node."""
        return 1 + sum(c.node_count() for c in self.children.values())