## 1. Crawl Java repositories created after 2023-08-01

In [None]:
# Search parameters for the GitHub crawl: language, result page, and creation-date cutoff.
lang, page, date = "java", 1, "2023-08-01"

In [None]:
import os

import requests
import json

# Base URL for the GitHub REST API
base_url = "https://api.github.com"

# Search for `lang` repositories created after `date`, most-starred first.
# The parameters are passed via `params=` so requests URL-encodes them
# (e.g. the ">" in the date qualifier) instead of splicing them into `q`.
query = f"language:{lang} created:>{date}"
params = {"q": query, "sort": "stars", "order": "desc", "per_page": 100, "page": page}

# Send a GET request to the search repositories endpoint; the timeout keeps a
# stalled connection from hanging the notebook forever.
response = requests.get(
    f"{base_url}/search/repositories",
    params=params,
    headers={"Accept": "application/vnd.github+json"},
    timeout=30,
)

if response.status_code == 200:
    # The repositories for this page are under "items" in the JSON response.
    repositories = response.json()["items"]

    # Make sure the output directory exists before writing the page to disk.
    os.makedirs("repository_lists", exist_ok=True)
    with open(f"repository_lists/{lang}_{date}_{page}.json", "w") as f:
        json.dump(repositories, f, indent=4)
else:
    # Define an empty list so the clone cell below has nothing to do instead of
    # failing with a NameError.
    repositories = []
    print(f"Error: Failed to retrieve repositories from GitHub (HTTP {response.status_code})")


## 2. Clone repositories and remove those that are not Maven or Gradle projects

In [None]:
import os
import shutil
import subprocess
from loguru import logger

# `repositories` is the list of GitHub API repository objects produced by the
# crawl cell above. Each one is cloned into repositories/<name>.
os.makedirs("repositories", exist_ok=True)

for repository in repositories:
    name = repository["name"]
    repo_dir = os.path.join("repositories", name)

    # Clone unless a previous run already did. The command is an argument list
    # (no shell), so names/URLs coming from the API are never shell-interpreted.
    if not os.path.isdir(repo_dir):
        result = subprocess.run(["git", "clone", repository["clone_url"], repo_dir])
        if result.returncode != 0:
            logger.warning(f"Failed to clone {name}")
            continue
        logger.info(f"Cloned {name}")
    else:
        logger.info(f"{name} already cloned")

    # Keep only Maven (pom.xml) or Gradle (build.gradle) projects.
    # NOTE(review): Gradle Kotlin-DSL projects (build.gradle.kts only) are removed
    # as well — include them here if that is not intended.
    is_maven = os.path.isfile(os.path.join(repo_dir, "pom.xml"))
    is_gradle = os.path.isfile(os.path.join(repo_dir, "build.gradle"))
    if not (is_maven or is_gradle):
        # ignore_errors mirrors the forgiving behaviour of `rm -rf`.
        shutil.rmtree(repo_dir, ignore_errors=True)
        logger.info(f"Removed {name} because it is not a maven project or a gradle project")

## 3. Detect Strange Identifiers

- Build the tree-sitter grammar library for Java

In [None]:
import os
import sys
from tree_sitter import Language

# Compile the `tree-sitter-java` grammar directory (presumably a local checkout
# relative to the working directory — confirm) into a shared library at
# tree_sitter_build/language_set.so, which the parser cell loads afterwards.
# NOTE(review): `Language.build_library` only exists in older py-tree-sitter
# releases (it was removed in newer versions, around 0.22) — pin the version.
Language.build_library("tree_sitter_build/language_set.so", ["tree-sitter-java"])


- Helper functions for parsing Java source files and extracting identifiers

In [None]:
from tree_sitter import Parser, Language
# Load the Java grammar from the library compiled in the build cell and create
# one module-level parser that the helper functions below share.
JA_LANGUAGE = Language("tree_sitter_build/language_set.so", "java")
parser = Parser()
parser.set_language(JA_LANGUAGE)

class Identifier:
    """An identifier occurrence located in a parsed source file.

    Position fields are copied out of the tree-sitter node so that merging code
    can extend an entry (e.g. to cover a dotted chain ``a.b.c``) without touching
    the node.

    Attributes:
        node: The tree-sitter node this identifier was created from.
        name: Initialised to the node text; merging code may later set it to the
            last dotted segment.
        full_name: Initialised to the node text; merging code may later extend it
            to the full dotted chain.
        start_byte / end_byte: Byte span of the identifier in the parsed source.
        start_row / end_row, start_col / end_col: ``start_point`` / ``end_point``
            components as reported by tree-sitter.
        kind: ``"identifier"`` by default; callers pass e.g. ``"declared"``.
    """

    def __init__(self, node, kind="identifier"):
        self.node = node
        text = node.text.decode()
        self.name = text
        self.full_name = text
        self.start_byte = node.start_byte
        self.end_byte = node.end_byte
        self.start_row = node.start_point[0]
        self.end_row = node.end_point[0]
        self.start_col = node.start_point[1]
        self.end_col = node.end_point[1]
        self.kind = kind

    def to_dict(self):
        """Return the JSON-serializable fields of this identifier (``node`` excluded)."""
        return {
            "name": self.name,
            "full_name": self.full_name,
            "start_byte": self.start_byte,
            "end_byte": self.end_byte,
            "start_row": self.start_row,
            "end_row": self.end_row,
            "start_col": self.start_col,
            "end_col": self.end_col,
            "kind": self.kind,
        }

    # Kept for backward compatibility with callers that serialize via
    # ``identifier.__dict__()``. Defining ``__dict__`` as a method shadows the
    # normal attribute mapping (so introspection such as ``vars()`` misbehaves);
    # new code should call ``to_dict()``.
    def __dict__(self):
        return self.to_dict()

    def __repr__(self):
        return (
            f"{type(self).__name__}({self.full_name!r}, kind={self.kind!r}, "
            f"start_byte={self.start_byte}, end_byte={self.end_byte})"
        )
        

def getParsedTree(src):
    """Parse Java source text with the shared parser and return the tree.

    tree-sitter operates on bytes, so the text is UTF-8 encoded first; byte
    offsets in the returned tree refer to that encoding.
    """
    return parser.parse(src.encode("utf8"))

def getIdentifier(tree):
    """Yield the root node's direct children whose type is ``identifier``.

    Only the immediate children of the root are inspected; nested identifiers
    are not visited.
    """
    yield from (child_node for child_node in tree.root_node.children if child_node.type == "identifier")

def getIdentifiersByQuery(tree, queries):
    """Run tree-sitter query patterns over the whole tree and return the captures.

    Args:
        tree: A tree produced by ``getParsedTree``.
        queries: S-expression patterns; they are joined with newlines into a
            single query.

    Returns:
        The result of ``captures`` on the root node. Callers in this notebook
        treat it as a sequence of ``(node, capture_name)`` pairs.
    """
    compiled_query = JA_LANGUAGE.query("\n".join(queries))
    return compiled_query.captures(tree.root_node)

# Cursor state shared across calls. Identifiers are checked in source order, so
# the scan over the removal list resumes where the previous call stopped instead
# of restarting from the beginning each time.
tree_obj = None
remove_itr = 0
_cursor_list = None


def checkOverlappedIdentifier(tree, identifier_list, identifier):
    """Return True if *identifier*'s byte span overlaps any span in *identifier_list*.

    *identifier_list* must be sorted by ``start_byte``, and successive calls for
    the same ``(tree, identifier_list)`` pair must pass identifiers in increasing
    source order. The module-level cursor ``remove_itr`` remembers how far the
    scan got. It is reset whenever a different tree *or* a different list is
    passed, so a stale position from an earlier pass is never reused.

    Span ends are compared inclusively, and containment in either direction
    counts as an overlap.
    """
    global remove_itr, tree_obj, _cursor_list
    if tree is not tree_obj or identifier_list is not _cursor_list:
        remove_itr = 0
        tree_obj = tree
        _cursor_list = identifier_list

    for i_idx in range(remove_itr, len(identifier_list)):
        candidate = identifier_list[i_idx]
        if candidate.start_byte <= identifier.end_byte and identifier.start_byte <= candidate.end_byte:
            remove_itr = i_idx
            return True
        if candidate.start_byte > identifier.end_byte:
            # The list is sorted, so every remaining span starts even later.
            remove_itr = i_idx
            return False
    remove_itr = len(identifier_list)
    return False

def getLongIdentifiers(tree, byte_str, to_remove):
    """Collect identifiers, merging dot-separated chains such as ``a.b.c`` into one entry.

    Every ``identifier`` node in *tree* is visited in capture order. Identifiers
    whose span overlaps an entry of *to_remove* are skipped. An identifier that
    starts exactly one byte after the most recently *kept* identifier, with a
    ``.`` byte in between, extends that entry: its ``name`` becomes the new
    segment and ``full_name`` grows by ``"." + segment``.

    Args:
        tree: Tree from ``getParsedTree``.
        byte_str: The encoded source the tree was parsed from; used to check
            the separator byte.
        to_remove: Identifiers to exclude, sorted by ``start_byte`` (as required
            by ``checkOverlappedIdentifier``).

    Returns:
        List of ``Identifier`` objects.
    """
    queries = [
        """(identifier) @identifier""",
    ]
    captures = getIdentifiersByQuery(tree, queries=queries)

    long_identifiers = []
    # Node of the last identifier that was actually kept. Only that entry may be
    # extended: if the previous capture was skipped, long_identifiers[-1] is an
    # unrelated identifier (or the list is empty), so the chain must restart.
    last_kept_node = None
    for capture in captures:
        node = capture[0]
        candidate = Identifier(node)
        if checkOverlappedIdentifier(tree, to_remove, candidate):
            last_kept_node = None
            continue

        extends_chain = (
            last_kept_node is not None
            and last_kept_node.end_byte + 1 == node.start_byte
            and byte_str[last_kept_node.end_byte] == ord(".")
        )
        if extends_chain:
            segment = node.text.decode()
            chain = long_identifiers[-1]
            chain.name = segment
            chain.full_name += "." + segment
            chain.end_byte = node.end_byte
            chain.end_row = node.end_point[0]
            chain.end_col = node.end_point[1]
        else:
            long_identifiers.append(candidate)
        last_kept_node = node

    return long_identifiers

def getDeclaredIdentifiers(tree):
    """Return declaration-site identifiers found in *tree*, with ``kind="declared"``.

    Covers the names of class declarations, method declarations, formal
    parameters, variable declarators and marker annotations, plus identifiers in
    ``inferred_parameters``. For package and import declarations, the captured
    ``scope`` and ``name`` parts of the scoped identifier are merged into one
    entry whose ``name`` is the last segment and whose ``full_name`` is the
    dotted path.

    Returns:
        List of ``Identifier`` objects in capture order.
    """
    queries = [
        """(class_declaration
                name: (_) @identifier)""",
        """(method_declaration
                name: (_) @identifier)""",
        """(formal_parameter
                name: (_) @identifier)""",
        """(variable_declarator
                name: (_) @identifier)""",
        """(package_declaration (
                scoped_identifier
                    scope: (_) @keep_last ?
                    name: (_) @keep_last
                )) @package""",
        """(import_declaration (
                scoped_identifier
                    scope: (_) @keep_last ?
                    name: (_) @keep_last
                )) @import""",
        """(inferred_parameters
                (identifier) @identifier)""",
        """(marker_annotation
                name: (_) @identifier)""",
    ]

    captures = getIdentifiersByQuery(tree, queries=queries)
    identifiers = []
    # Name of the previous capture; None before the first one. (Indexing with
    # idx-1 at idx 0 would wrap around to the *last* capture.)
    prev_capture_name = None
    for capture in captures:
        node, capture_name = capture[0], capture[1]
        if capture_name == "identifier":
            identifiers.append(Identifier(node, kind="declared"))
        elif capture_name == "keep_last":
            if prev_capture_name != "keep_last":
                # First part of a package/import path starts a new entry.
                identifiers.append(Identifier(node, kind="declared"))
            else:
                # Later parts extend it: keep the last segment as the name.
                segment = node.text.decode()
                identifiers[-1].name = segment
                identifiers[-1].full_name += "." + segment
                identifiers[-1].end_byte = node.end_byte
                identifiers[-1].end_row = node.end_point[0]
                identifiers[-1].end_col = node.end_point[1]
        # Other capture names (@package, @import) only delimit groups.
        prev_capture_name = capture_name
    return identifiers

def getReferencedIdentifiers(tree, to_remove):
    """Return every ``identifier`` node in *tree* whose span does not overlap *to_remove*.

    Args:
        tree: Tree from ``getParsedTree``.
        to_remove: Identifiers to exclude, sorted by ``start_byte``.

    Returns:
        List of ``Identifier`` objects for the remaining identifiers.
    """
    global remove_itr
    global tree_obj
    # Always restart the shared overlap cursor. It indexes into *to_remove*, and
    # an earlier pass over the same tree may already have advanced it.
    remove_itr = 0
    tree_obj = tree

    queries = [
        """(identifier) @identifier""",
    ]
    captures = getIdentifiersByQuery(tree, queries=queries)

    identifiers = []
    # Scan every capture from the start: remove_itr is a position in to_remove,
    # not in this capture list, so it must not be used as a starting index here.
    for capture in captures:
        candidate = Identifier(capture[0])
        if not checkOverlappedIdentifier(tree, to_remove, candidate):
            identifiers.append(candidate)
    return identifiers

- Collect the identifier lists (dotted/long identifiers and declared identifiers) for a file

In [None]:
def getIdentifiers(file):
    """Parse a Java file and return ``(long_identifiers, declared_identifiers)``.

    ``declared_identifiers`` are declaration sites (``getDeclaredIdentifiers``).
    ``long_identifiers`` are the remaining identifiers with dotted chains merged
    (``getLongIdentifiers``); spans overlapping a declaration are excluded.
    """
    # Read as UTF-8 explicitly so results don't depend on the platform locale.
    # Undecodable bytes are replaced, so a file with invalid UTF-8 does not raise
    # UnicodeDecodeError and abort a whole-repository scan.
    with open(file, "r", encoding="utf-8", errors="replace") as f:
        src = f.read()
    tree = getParsedTree(src)

    declared_identifiers = getDeclaredIdentifiers(tree)
    # Pass the same UTF-8 bytes that were parsed, so byte offsets line up.
    long_identifiers = getLongIdentifiers(tree, src.encode("utf8"), declared_identifiers)

    return long_identifiers, declared_identifiers

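- Select strange identifiers: non-declaration uses of a name that appear first in the file, or more than `forget` rows after the previous occurrence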

In [None]:
def getStrangeIdentifiers(file, forget=9999):
    """Return the non-declaration identifier occurrences in *file* that look "strange".

    All identifiers (long and declared) are grouped by ``name`` in ``start_byte``
    order. A non-declared occurrence is strange when it is the first occurrence
    of its name in the file, or when its ``end_row`` is more than *forget* rows
    after the previous occurrence's ``end_row``.
    """
    long_identifiers, declared_identifiers = getIdentifiers(file)

    occurrences_by_name = {}
    for ident in sorted(long_identifiers + declared_identifiers, key=lambda x: x.start_byte):
        occurrences_by_name.setdefault(ident.name, []).append(ident)

    strange = []
    for occurrences in occurrences_by_name.values():
        previous = None
        for occurrence in occurrences:
            # Declared occurrences are never reported but still count as "previous".
            if occurrence.kind != "declared":
                if previous is None or occurrence.end_row - previous.end_row > forget:
                    strange.append(occurrence)
            previous = occurrence
    return strange

- Traverse all repositories and detect strange identifiers (sampling up to 200 Java files per repository)

In [None]:
import os
import sys
import json
import random
from tqdm import tqdm
from loguru import logger

# Input repositories, the row gap after which a reused name counts as strange
# again, and the output directory (named after that gap).
repos_dir = "/Users/tannpopo/coding/coding-interfere/repo_to_mine"
forget = 9999
output_dir = f"/Users/tannpopo/coding/coding-interfere/strange_identifiers_{forget}"

# Fixed seed so the random sampling of Java files is repeatable.
random.seed(16)

def getAllJavaFiles(repo_dir, max_files=200):
    """Return up to *max_files* randomly sampled ``.java`` file paths under *repo_dir*.

    Paths are sorted before sampling. Without that, ``os.walk`` order (which
    depends on the file system) changes the sample even under a fixed
    ``random`` seed.

    Args:
        repo_dir: Root directory to search recursively.
        max_files: Maximum number of paths to return (default 200).

    Returns:
        A list of at most *max_files* paths, in random order.
    """
    java_files = sorted(
        os.path.join(root, file_name)
        for root, _dirs, files in os.walk(repo_dir)
        for file_name in files
        if file_name.endswith(".java")
    )
    return random.sample(java_files, min(len(java_files), max_files))

def getStrangeIdentifiersInRepo(repo_dir, forget=9999):
    """Detect strange identifiers in a sample of a repository's Java files.

    Returns one ``{"file_path": ..., "strange_identifier": {...}}`` record per
    strange identifier, ready to be dumped as JSON.
    """
    return [
        {
            "file_path": java_path,
            "strange_identifier": found.__dict__(),
        }
        for java_path in tqdm(getAllJavaFiles(repo_dir))
        for found in getStrangeIdentifiers(java_path, forget=forget)
    ]

# identifier_test = getStrangeIdentifiers("/Users/tannpopo/coding/coding-interfere/repo_to_mine/daydayEXP/src/main/java/com/bcvgh/core/BaseTemplate.java")
# print([i.__dict__() for i in identifier_test])

# Process repositories in a fixed order. os.listdir order is arbitrary, and
# because file sampling draws from one seeded RNG across all repositories, the
# order also decides which files each repository gets.
repo_list = sorted(os.listdir(repos_dir))
os.makedirs(output_dir, exist_ok=True)
all_cnt = 0
for repository_dir in repo_list:
    repo_path = os.path.join(repos_dir, repository_dir)
    # Skip stray files (e.g. .DS_Store) before logging "Start".
    if not os.path.isdir(repo_path):
        continue
    # if os.path.isfile(f"{output_dir}/{repository_dir}.json"):
    #     continue
    logger.info(f"Start {repository_dir}")
    strange_identifiers = getStrangeIdentifiersInRepo(repo_path, forget=forget)
    logger.info(f"Finished {repository_dir}, {len(strange_identifiers)} strange identifiers found")
    all_cnt += len(strange_identifiers)
    with open(f"{output_dir}/{repository_dir}.json", "w") as f:
        json.dump(strange_identifiers, f, indent=4)

logger.info(f"Finished all, {all_cnt} strange identifiers found")

## 4. Create dot-completion training instances (language-server completions at member-access dots)

In [1]:
import json
import os

from monitors4codegen.multilspy import SyncLanguageServer
from monitors4codegen.multilspy.multilspy_config import MultilspyConfig
from monitors4codegen.multilspy.multilspy_logger import MultilspyLogger

# These must match the values used when section 3 wrote the strange-identifier files.
forget_gap = 9999
repos_dir = "/Users/tannpopo/coding/coding-interfere/repo_to_mine"
strange_identifiers_dir = f"/Users/tannpopo/coding/coding-interfere/strange_identifiers_{forget_gap}"

# Language-server configuration for Java (multilspy also supports "python", "rust", "csharp").
config = MultilspyConfig.from_dict({"code_language": "java"})
# Note: this binds `logger` to multilspy's logger, not the loguru logger used in section 3.
logger = MultilspyLogger()

def readInStrangeIdentifiers(repo):
    """Load the list of strange-identifier records saved for *repo* in section 3."""
    json_path = f"{strange_identifiers_dir}/{repo}.json"
    with open(json_path, "r") as handle:
        return json.load(handle)

def buildForRepo(repo):
    """Request language-server completions at each dotted strange identifier of *repo*.

    Records from ``readInStrangeIdentifiers(repo)`` whose identifier
    ``full_name`` contains a ``.`` get a completion request at the column of the
    last dot. Records that receive a non-empty completion list are returned with
    an added ``"completion_items"`` key.

    Returns:
        List of augmented strange-identifier records.
    """
    # This cell runs in a fresh kernel (In [1]), so tqdm from section 3 is not in scope.
    from tqdm import tqdm

    repo_path = os.path.join(repos_dir, repo)
    strange_identifiers = readInStrangeIdentifiers(repo)
    lsp = SyncLanguageServer.create(config, logger, repo_path)
    results = []
    # NOTE(review): the recorded run failed inside start_server() with "Future ...
    # attached to a different loop". This looks like a clash between
    # SyncLanguageServer's own event loop and Jupyter's running loop; running
    # this from a plain Python script may avoid it. TODO confirm.
    with lsp.start_server():
        for record in tqdm(strange_identifiers):
            # Identifier fields are nested under "strange_identifier" in the
            # records written by section 3.
            ident = record["strange_identifier"]
            if "." not in ident["full_name"]:
                continue
            # start_col is where the dotted chain begins. The offset of the last
            # segment within full_name, minus one, is the position of the last ".".
            # NOTE(review): tree-sitter columns count bytes, whereas LSP positions
            # count UTF-16 code units by default; they differ on non-ASCII lines.
            column = ident["start_col"] + len(ident["full_name"]) - len(ident["name"]) - 1
            completion_items = lsp.request_completion(
                # multilspy expects a path relative to the repository root.
                os.path.relpath(record["file_path"], repo_path),
                ident["start_row"],
                column,
            )
            if not completion_items:
                continue
            # multilspy completion items are plain dicts, which have no callable
            # __dict__; copy them instead.
            record["completion_items"] = [dict(item) for item in completion_items]
            results.append(record)
    return results

# Run the instance builder on a single repository.
buildForRepo("daydayEXP")

RuntimeError: Task <Task pending name='Task-5' coro=<_AsyncGeneratorContextManager.__aenter__() running at /Users/tannpopo/Env/anaconda3/envs/pytorch/lib/python3.9/contextlib.py:175> cb=[_chain_future.<locals>._call_set_state() at /Users/tannpopo/Env/anaconda3/envs/pytorch/lib/python3.9/asyncio/futures.py:391]> got Future <Future pending> attached to a different loop