In [1]:
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
import os
import json

DATASET_PATH = "resources/train"  # Root folder containing the dataset
methods = []  # List to store methods
tests = []    # List to store test cases

def load_methods_and_tests(dataset_path):
    """
    Load methods and corresponding test cases from the Methods2Tests dataset.
    
    Args:
        dataset_path (str): Root directory of the dataset.
        
    Returns:
        list: A list of Java methods (source code).
        list: A list of corresponding test cases (source code).
    """
    methods = []
    tests = []

    # Traverse dataset directory
    for project_folder in os.listdir(dataset_path):
        print(project_folder)
        project_path = os.path.join(dataset_path, project_folder)
        if os.path.isdir(project_path):  # Ensure it's a folder
            for file_name in os.listdir(project_path):
                file_path = os.path.join(project_path, file_name)
                if file_name.endswith(".json"):  # Process JSON files
                    with open(file_path, 'r', encoding='utf-8') as f:
                        try:
                            data = json.load(f)

                            # Extract focal method
                            focal_method = data.get("focal_method", {}).get("body", "")
                            if focal_method:
                                methods.append(focal_method)

                            # Extract test case
                            test_case = data.get("test_case", {}).get("body", "")
                            if test_case:
                                tests.append(test_case)

                        except json.JSONDecodeError:
                            print(f"Error decoding JSON in file: {file_path}")

    return methods, tests

# Load data
methods, tests = load_methods_and_tests(DATASET_PATH)

# Output stats
print(f"Loaded {len(methods)} methods and {len(tests)} test cases.")

1001284
10035739
10057936
10062583
10064418
10065023
10075325
10172073
1017889
10192655
10195984
10227537
10230369
1025574
10262572
10277647
10288749
10294926
103035
10311319
10329619
10354997
10359213
10385460
10400052
10422484
10429628
10472733
10476504
10522064
10532531
10548179
10550221
10553586
10597460
10609316
10609318
10617959
10624930
10646587
10663360
10667695
1068402
10692925
1071767
10726093
10737956
10746583
1076094
10761704
10762348
1077753
10785687
1079636
10806441
10828921
10837431
10856630
1089023
10912209
1091655
1091966
10937119
10945339
10953866
10974543
1097987
10983382
1107344
11074539
1110934
11128472
11140459
11144122
11159652
11161563
1116314
11178914
11194069
11198387
11217397
11219096
11226198
11233417
11281607
11290309
11324849
11369509
11384368
11395286
11429654
11453959
11459376
1146205
11479356
11525787
11558041
1155836
1155961
1156021
1156136
11566581
11602449
1161604
1163806
1164965
116547
1165813
11660383
11668780
1167439
1169309
11695624
1171506
11754

In [3]:
vectorizer = TfidfVectorizer()
method_vectors = vectorizer.fit_transform(methods)

In [4]:
n_clusters = 100
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
kmeans.fit(method_vectors)
clusters = kmeans.labels_

  super()._check_params_vs_input(X, default_n_init=10)


In [5]:
input_method = "public int add(int a, int b) { return a + b; }"
input_vector = vectorizer.transform([input_method])
input_cluster = kmeans.predict(input_vector)[0]

In [6]:
cluster_indices = [i for i, c in enumerate(clusters) if c == input_cluster]
candidate_tests = [tests[i] for i in cluster_indices]

In [7]:
test_scores = [cosine_similarity(input_vector, vectorizer.transform([t])) for t in candidate_tests]
best_test = candidate_tests[np.argmax(test_scores)]

In [8]:
print("Generated Test Case:")
print(best_test)

Generated Test Case:
@Test
    public void testFindTopKNumbersFromThreeSorted1() {
	int[] res = findTopKNumbersFromThreeSortedArrays(new int[] { 1, 2 },
		new int[] { 1, 3 }, new int[] { 1, 4 }, 1);
	assertThat(res, is(new int[] { 4 }));
    }
