# Pre-training Dataset Exploration

### This notebook explores the XLCoST dataset released at https://github.com/reddy-lab-code-research/XLCoST, the CodeSearchNet dataset cleaned and released by Microsoft at https://github.com/microsoft/CodeBERT/tree/master/GraphCodeBERT/codesearch, and the CodeNet-derived code-to-code search (CCS) dataset released by Microsoft at https://github.com/microsoft/CodeBERT/tree/master/UniXcoder.

### XLCoST dataset:
XLCoST is a machine learning benchmark dataset that contains fine-grained parallel data in 7 commonly used programming languages (C++, Java, Python, C#, JavaScript, PHP, C) and natural language (English). The data is parallel across the 7 languages at both the code snippet level and the program level. This means that, given a program in one language, the dataset contains the same program in up to 6 other programming languages. Each program is divided into several code snippets, and programs in all the languages are aligned at the snippet level.

### Cross-Lingual (XL) Code Search Task dataset:
The data for XL Code Search can be found in the 'retrieval/code2code_search' directory. It is further subdivided into 'program_level' and 'snippet_level'. The directory names (languages) represent the query language.

### CodeSearchNet dataset:
CodeSearchNet is a large corpus of methods extracted from popular GitHub repositories. The primary dataset consists of 2 million (comment, code) pairs from open source libraries. Concretely, a 'comment' is a top-level function or method comment (e.g., a docstring in Python), and 'code' is an entire function or method. Currently, the dataset contains Python, JavaScript, Ruby, Go, Java and PHP functions. The dataset is partitioned into train, validation and test sets such that code from the same repository can only exist in one partition.

### CodeNet dataset:
The CodeNet dataset consists of a large collection of code samples with extensive metadata. It is derived from the data available on two online judge websites: AIZU and AtCoder. CodeNet contains a total of 13,916,868 submissions, divided into 4,053 problems. Among the submissions, 53.6% are accepted, 29.5% are marked as wrong answer, and the remainder are rejected because they fail to meet run time or memory requirements. Submissions are in 55 different languages; 95% of them are written in C++, Python, Java, C, Ruby and C#.

### Zero-Shot Code-to-Code Search dataset  (CCS dataset):
The CCS dataset is collected from the CodeNet corpus. Its authors collected 11,744/15,594/23,530 functions in Ruby/Python/Java. Each function solves one of 4,053 problems.

### Notes:  
XL Code Search Dataset: paired data at the program level; each sample contains only code tokens.  
CodeSearchNet Dataset: function-level samples with both code tokens (3~256 in length) and the original code string. However, no paired data is provided: every function is an individual data point, and there are no explicit mappings among functions.   
CCS Dataset: function-level samples, each containing only the original code string. Functions solving the same problem are considered semantic clones of each other, so paired data can be constructed from the labels.


In [1]:
import os
import sys
import json
import jsonlines

import numpy as np
import pandas as pd
pd.set_option('max_colwidth', 300)
from pprint import pprint

### Part 1: Preview the XLCoST Dataset
For the XL Code Search dataset, we look at the data in the directory 'XLCoST/retrieval/code2code_search/program_level/{Java, Python}' and in the directory 'XLCoST/retrieval/code2code_search/snippet_level/{Java, Python}'. The name of the sub-directory represents the query language, and the retrieval results contain program or snippet mappings to the other 6 languages.


In [2]:
# define dataset root_dir

xlcost_root = '/Users/rongdang/Desktop/semantic-code-clone/dataset/cross-language/XLCoST/retrieval/code2code_search'


In [3]:
# display sample data at the program level

level = 'program_level'
language = 'Java'

program_dir = os.path.join(xlcost_root, level, language)

set_name = 'test.jsonl'
set_program_dir = os.path.join(program_dir, set_name)

with open(set_program_dir, 'r') as f:
    sample_file = f.readlines()
    for sample_line in sample_file:
        sample_data = json.loads(sample_line)
        idx = sample_data['idx'].split('/')[0].split('-')[0]
        tgt_language = sample_data['idx'].split('/')[-1].split('-')[-1]
        if idx == '10062' and tgt_language == 'Python':
            pprint(sample_data)
            break

print('----------------------------------------------------------------')
print('idx: ' + sample_data['idx'])
print('url: ' + sample_data['url'])
print('src:\n' + ' '.join(sample_data['docstring_tokens']))
print('tgt:\n' + ' '.join(sample_data['code_tokens']))


{'code_tokens': ['def',
                 'minSum',
                 '(',
                 'A',
                 ',',
                 'N',
                 ')',
                 ':',
                 'NEW_LINE',
                 'INDENT',
                 'mp',
                 '=',
                 '{',
                 '}',
                 'NEW_LINE',
                 'sum',
                 '=',
                 '0',
                 'NEW_LINE',
                 'for',
                 'i',
                 'in',
                 'range',
                 '(',
                 'N',
                 ')',
                 ':',
                 'NEW_LINE',
                 'INDENT',
                 'sum',
                 '+=',
                 'A',
                 '[',
                 'i',
                 ']',
                 'NEW_LINE',
                 'if',
                 'A',
                 '[',
                 'i',
                 ']',
                 'in',
         

A sample Java-Python pair at the 'program_level' is shown above. The 'idx' key has the form {query_pid}-{query_lang}/{target_pid}-{target_lang} and uniquely identifies each data point. 'docstring_tokens' holds the list of code tokens for the query language. 'code_tokens' holds the list of code tokens for the target language. 

PS.  
'docstring_tokens' and 'code_tokens' are also displayed as strings by joining the token lists.  
Sample data from the XLCoST dataset contains only code tokens. The tokens can be joined with a single space to form a code string.
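Joining with a space keeps layout markers such as 'NEW_LINE' and 'INDENT' in the string. The sketch below shows one way to turn a Python token list back into indented text. It assumes a 'DEDENT' marker closes an indentation level (only 'NEW_LINE' and 'INDENT' are visible in the sample above), and `detokenize_python` is a hypothetical helper name, not part of XLCoST.

```python
def detokenize_python(tokens, indent='    '):
    """Rebuild indented text from tokens with NEW_LINE / INDENT / DEDENT markers."""
    lines, current, depth = [], [], 0
    for tok in tokens:
        if tok == 'NEW_LINE':
            # close the current line at the current indentation depth
            if current:
                lines.append(indent * depth + ' '.join(current))
            current = []
        elif tok == 'INDENT':
            depth += 1
        elif tok == 'DEDENT':  # assumed marker, not shown in the sample above
            depth = max(depth - 1, 0)
        else:
            current.append(tok)
    if current:
        lines.append(indent * depth + ' '.join(current))
    return '\n'.join(lines)


sample_tokens = ['def', 'minSum', '(', 'A', ',', 'N', ')', ':', 'NEW_LINE',
                 'INDENT', 'mp', '=', '{', '}', 'NEW_LINE']
print(detokenize_python(sample_tokens))
```

The result is still whitespace-separated tokens per line, so it is readable but not guaranteed to match the original source formatting.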

In [4]:
# display sample data at the snippet level

level = 'snippet_level'
language = 'Java'

program_dir = os.path.join(xlcost_root, level, language)

set_name = 'test.jsonl'
set_program_dir = os.path.join(program_dir, set_name)

with open(set_program_dir, 'r') as f:
    sample_file = f.readlines()
    for sample_line in sample_file:
        sample_data = json.loads(sample_line)
        idx = sample_data['idx'].split('/')[0].split('-')[0]
        tgt_language = sample_data['idx'].split('/')[-1].split('-')[-2]
        if idx == '10062' and tgt_language == 'Python':
            pprint(sample_data)
            break

print('----------------------------------------------------------------')
print('idx: ' + sample_data['idx'])
print('url: ' + sample_data['url'])
print('src:\n' + ' '.join(sample_data['docstring_tokens']))
print('tgt:\n' + ' '.join(sample_data['code_tokens']))


{'code_tokens': ['def', 'minSum', '(', 'A', ',', 'N', ')', ':', 'NEW_LINE'],
 'docstring_tokens': ['static',
                      'int',
                      'minSum',
                      '(',
                      'int',
                      'A',
                      '[',
                      ']',
                      ',',
                      'int',
                      'N',
                      ')',
                      '{'],
 'idx': '10062-Java-2/10062-Python-2',
 'url': '10062-Java-2/10062-Python-2'}
----------------------------------------------------------------
idx: 10062-Java-2/10062-Python-2
url: 10062-Java-2/10062-Python-2
src:
static int minSum ( int A [ ] , int N ) {
tgt:
def minSum ( A , N ) : NEW_LINE


A sample Java-Python pair at the 'snippet_level' is shown above. The 'idx' key has the form {query_pid}-{query_lang}-{snippet_id}/{target_pid}-{target_lang}-{snippet_id} and uniquely identifies each data point. 'docstring_tokens' holds the list of code tokens for the query language. 'code_tokens' holds the list of code tokens for the target language. 

PS. 'docstring_tokens' and 'code_tokens' are also displayed as strings by joining the token lists.
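The two 'idx' layouts (program level and snippet level) can be parsed by one helper. This is a minimal sketch; `parse_pair_idx` is a hypothetical name, and it assumes that language names contain no '-' character, which holds for the 7 XLCoST languages.

```python
def parse_pair_idx(idx):
    """Split an XLCoST pair idx into (query, target) dictionaries."""
    def parse_side(side):
        parts = side.split('-')
        entry = {'pid': parts[0], 'lang': parts[1]}
        if len(parts) > 2:
            # snippet-level ids carry a trailing snippet number
            entry['snippet_id'] = parts[2]
        return entry

    query, target = idx.split('/')
    return parse_side(query), parse_side(target)


query, target = parse_pair_idx('10062-Java-2/10062-Python-2')
print(query['lang'], target['lang'], target.get('snippet_id'))
```

Using one helper avoids switching between `[-1]` and `[-2]` indexing depending on the level.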

### Part 2: Analyze the XLCoST Dataset
For the XL Code Search dataset, we explore the Java-Python and Python-Java language pairs. We investigate dataset statistics in terms of: 1) the number of queries; 2) the total number of language pairs; 3) the average number of code tokens in each language.
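The three statistics can be computed by one reusable function over already-loaded records. The sketch below works on an in-memory list of dictionaries with the 'idx', 'docstring_tokens' and 'code_tokens' keys seen in the samples; `pair_statistics` and the keys of the returned dictionary are hypothetical names chosen for illustration.

```python
def pair_statistics(records, target_language):
    """Count queries and pairs, and average token lengths, for one target language."""
    query_ids = set()
    query_lens, target_lens = [], []
    for rec in records:
        query, target = rec['idx'].split('/')
        # the language is the second '-'-separated field at both levels
        if target.split('-')[1] != target_language:
            continue
        query_ids.add(query.split('-')[0])
        query_lens.append(len(rec['docstring_tokens']))
        target_lens.append(len(rec['code_tokens']))
    n = len(query_lens)
    return {
        'queries': len(query_ids),
        'pairs': n,
        'avg_query_tokens': round(sum(query_lens) / n, 2) if n else 0.0,
        'avg_target_tokens': round(sum(target_lens) / n, 2) if n else 0.0,
    }


toy_records = [
    {'idx': '1-Java/1-Python', 'docstring_tokens': ['a', 'b', 'c'], 'code_tokens': ['x', 'y']},
    {'idx': '1-Java/1-C++', 'docstring_tokens': ['a', 'b', 'c'], 'code_tokens': ['x']},
    {'idx': '2-Java/2-Python', 'docstring_tokens': ['a'], 'code_tokens': ['x', 'y', 'z', 'w']},
]
print(pair_statistics(toy_records, 'Python'))
```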

In [5]:
# analyze Java-Python pairs at the program level

level = 'program_level'
language = 'Java'
sets = ['train.jsonl', 'val.jsonl', 'test.jsonl']

java_py_pairs_num = 0
query_list = list()
java_token_nums = list()
python_token_nums = list()

for set_name in sets:
    set_program_dir = os.path.join(xlcost_root, level, language, set_name)
    # print(set_program_dir)
    with open(set_program_dir, 'r') as f:
        sample_file = f.readlines()
    for sample_line in sample_file:
        sample_data = json.loads(sample_line)
        query_id = sample_data['idx'].split('/')[0].split('-')[0]
        tgt_language = sample_data['idx'].split('/')[-1].split('-')[-1]
        if tgt_language == 'Python':
            java_py_pairs_num += 1
            query_list.append(query_id)
            java_token_nums.append(len(sample_data['docstring_tokens']))
            python_token_nums.append(len(sample_data['code_tokens']))

print('-----Java-Python-----program-level statistics-----')
print('total queries: ' + str(len(list(set(query_list)))))
print('total pairs: ' + str(java_py_pairs_num))
print('avg number of tokens in Java program: ' + str(round(sum(java_token_nums)/len(java_token_nums), 2)))
print('avg number of tokens in Python program: ' + str(round(sum(python_token_nums)/len(python_token_nums), 2)))


-----Java-Python-----program-level statistics-----
total queries: 10344
total pairs: 10344
avg number of tokens in Java program: 225.77
avg number of tokens in Python program: 194.18


In [6]:
# analyze Python-Java pairs at the program level (for validation)

language = 'Python'

py_java_pairs_num = 0
query_list = list()
python_token_nums = list()
java_token_nums = list()

for set_name in sets:
    set_program_dir = os.path.join(xlcost_root, level, language, set_name)
    # print(set_program_dir)
    with open(set_program_dir, 'r') as f:
        sample_file = f.readlines()
    for sample_line in sample_file:
        sample_data = json.loads(sample_line)
        query_id = sample_data['idx'].split('/')[0].split('-')[0]
        tgt_language = sample_data['idx'].split('/')[-1].split('-')[-1]
        if tgt_language == 'Java':
            py_java_pairs_num += 1
            query_list.append(query_id)
            python_token_nums.append(len(sample_data['docstring_tokens']))
            java_token_nums.append(len(sample_data['code_tokens']))

print('-----Python-Java-----program-level statistics-----')
print('total queries: ' + str(len(list(set(query_list)))))
print('total pairs: ' + str(py_java_pairs_num))
print('avg number of tokens in Python program: ' + str(round(sum(python_token_nums)/len(python_token_nums), 2)))
print('avg number of tokens in Java program: ' + str(round(sum(java_token_nums)/len(java_token_nums), 2)))


-----Python-Java-----program-level statistics-----
total queries: 10344
total pairs: 10344
avg number of tokens in Python program: 194.18
avg number of tokens in Java program: 225.77


### As shown above, the total number of queries equals the total number of language pairs (answers). This means that each query has exactly one answer.  
### In the paper, the authors mention that some problems are very similar, for example "Check if a large number is divisible by 3 or not" and "Check whether a large number is divisible by 53 or not". Therefore, during contrastive learning, we cannot safely assume that other samples in the same batch have different functionality.
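One rough way to see how often this happens is to flag pairs of samples in a batch whose token sets overlap heavily. The sketch below uses Jaccard similarity over token sets as a heuristic; the helper names and the threshold are assumptions for illustration, not something taken from the XLCoST paper.

```python
from itertools import combinations


def jaccard(tokens_a, tokens_b):
    """Jaccard similarity between two token lists, treated as sets."""
    set_a, set_b = set(tokens_a), set(tokens_b)
    union = set_a | set_b
    return len(set_a & set_b) / len(union) if union else 0.0


def similar_pairs_in_batch(batch_tokens, threshold=0.8):
    """Return index pairs (i, j), i < j, whose similarity reaches the threshold."""
    return [
        (i, j)
        for (i, a), (j, b) in combinations(enumerate(batch_tokens), 2)
        if jaccard(a, b) >= threshold
    ]


batch = [
    ['n', '%', '3', '==', '0'],
    ['n', '%', '53', '==', '0'],
    ['print', '(', 'x', ')'],
]
print(similar_pairs_in_batch(batch, threshold=0.6))
```

Flagged pairs could then be excluded from the in-batch negatives, although token overlap is only a proxy for functional similarity.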

### Part 3: Preview the CodeSearchNet Dataset
For the CodeSearchNet dataset cleaned and released by Microsoft, we intend to take a look at the data at the function level in Java and Python.  
Unlike the original CodeSearchNet, the answer to each query is retrieved from the whole development and testing code corpus instead of 1,000 candidate codes. In addition, some queries contain content unrelated to the code, such as a link that refers to external resources. Therefore, the dataset is filtered as follows to improve its quality.  
1. Remove comments in the code;  
2. Remove examples whose code cannot be parsed into an abstract syntax tree (AST);  
3. Remove examples whose documentation has fewer than 3 or more than 256 tokens;  
4. Remove examples whose documentation contains special tokens (e.g. <img...>);  
5. Remove examples whose documentation is not in English.  
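Rules 3 to 5 can be approximated with a few lines of Python. The sketch below is not the filter used by Microsoft: the HTML-tag regular expression and the ASCII-only check are simplifying assumptions standing in for "special tokens" and "not English", and `keep_docstring` is a hypothetical name.

```python
import re

# assumption: "special tokens" are approximated by HTML-like tags
HTML_TAG = re.compile(r'<[a-zA-Z/][^>]*>')


def keep_docstring(doc_tokens, min_len=3, max_len=256):
    """Return True if a tokenized docstring passes the length, tag and language checks."""
    if not (min_len <= len(doc_tokens) <= max_len):
        return False
    text = ' '.join(doc_tokens)
    if HTML_TAG.search(text):
        return False
    # assumption: non-ASCII text is treated as non-English
    return text.isascii()


print(keep_docstring(['Return', 'the', 'sum']))
print(keep_docstring(['see', '<img src="x.png">', 'here']))
```

Rules 1 and 2 need a language-specific parser and are not covered here.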


In [2]:
# define dataset root_dir

csn_root = '/Users/rongdang/Desktop/semantic-code-clone/dataset/cross-language/CodeSearchNet_Microsoft/dataset'


In [4]:
# display one sample data in Java

language = 'java'
set_name = 'valid.jsonl'

set_program_dir = os.path.join(csn_root, language, set_name)
with open(set_program_dir, 'r') as f:
    sample_file = f.readlines()
sample_data = json.loads(sample_file[2324])
pprint(sample_data)


{'code': 'public boolean createSqlTable (Connection connection, String '
         'namespace, boolean makeIdSerial, String[] primaryKeyFields) {\n'
         '        // Optionally join namespace and name to create full table '
         'name if namespace is not null (i.e., table object is\n'
         '        // a spec table).\n'
         '        String tableName = namespace != null ? String.join(".", '
         'namespace, name) : name;\n'
         '        String fieldDeclarations = Arrays.stream(fields)\n'
         '                .map(Field::getSqlDeclaration)\n'
         '                .collect(Collectors.joining(", "));\n'
         '        if (primaryKeyFields != null) {\n'
         '            fieldDeclarations += String.format(", primary key (%s)", '
         'String.join(", ", primaryKeyFields));\n'
         '        }\n'
         '        String dropSql = String.format("drop table if exists %s", '
         'tableName);\n'
         '        // Adding the unlogged keyword

In [5]:
# display one sample data in Python

language = 'python'
set_name = 'valid.jsonl'

set_program_dir = os.path.join(csn_root, language, set_name)
with open(set_program_dir, 'r') as f:
    sample_file = f.readlines()
sample_data = json.loads(sample_file[2344])
pprint(sample_data)


{'code': 'def create_graph_from_data(self, data, **kwargs):\n'
         '        """Apply causal discovery on observational data using CAM.\n'
         '\n'
         '        Args:\n'
         '            data (pandas.DataFrame): DataFrame containing the data\n'
         '\n'
         '        Returns:\n'
         '            networkx.DiGraph: Solution given by the CAM algorithm.\n'
         '        """\n'
         '        # Building setup w/ arguments.\n'
         "        self.arguments['{SCORE}'] = self.scores[self.score]\n"
         "        self.arguments['{CUTOFF}'] = str(self.cutoff)\n"
         "        self.arguments['{VARSEL}'] = str(self.variablesel).upper()\n"
         "        self.arguments['{SELMETHOD}'] = "
         'self.var_selection[self.selmethod]\n'
         "        self.arguments['{PRUNING}'] = str(self.pruning).upper()\n"
         "        self.arguments['{PRUNMETHOD}'] = "
         'self.var_selection[self.prunmethod]\n'
         "        self.arguments['{N

The attributes of each data point are described in detail at: https://github.com/github/CodeSearchNet#data-details

### Part 4: Analyze the CodeSearchNet Dataset
Statistics about the cleaned dataset released by Microsoft are shown below.

| Programming Language | Train | Valid | Test | Candidate codes |
| :---------- | :---------- | :---------- | :---------- | :---------- |
| Java | 164,923 | 5,183 | 10,955 | 13,981 |
| Python | 251,820 | 13,914 | 14,918 | 42,827 |

For the CodeSearchNet dataset, we intend to explore Java and Python functions. We will investigate dataset statistics in terms of: 1) the average number of code lines in each language and 2) the average number of code tokens in each language.

In [11]:
def count_code_lines_1(code_str):
    # count the non-blank lines in a code string
    line_num = 0
    for line in code_str.split('\n'):
        if line.strip():
            line_num += 1
    return line_num

In [12]:
# analyze the Java functions
language = 'java'
sets = ['train.jsonl', 'valid.jsonl', 'test.jsonl']

line_nums = list()
token_nums = list()

for set_name in sets:
    set_program_dir = os.path.join(csn_root, language, set_name)
    with open(set_program_dir, 'r') as f:
        sample_file = f.readlines()
    for sample_line in sample_file:
        sample_data = json.loads(sample_line)
        line_nums.append(count_code_lines_1(sample_data['code']))
        token_nums.append(len(sample_data['code_tokens']))

print('-----Java-----function-level statistics-----')
print('avg number of lines in Java function: ' + str(round(sum(line_nums)/len(line_nums), 2)))
print('avg number of tokens in Java function: ' + str(round(sum(token_nums)/len(token_nums), 2)))

-----Java-----function-level statistics-----
avg number of lines in Java function: 14.52
avg number of tokens in Java function: 99.65


In [13]:
# analyze the Python functions
language = 'python'
sets = ['train.jsonl', 'valid.jsonl', 'test.jsonl']

line_nums = list()
token_nums = list()

for set_name in sets:
    set_program_dir = os.path.join(csn_root, language, set_name)
    with open(set_program_dir, 'r') as f:
        sample_file = f.readlines()
    for sample_line in sample_file:
        sample_data = json.loads(sample_line)
        line_nums.append(count_code_lines_1(sample_data['code']))
        token_nums.append(len(sample_data['code_tokens']))

print('-----Python-----function-level statistics-----')
print('avg number of lines in Python function: ' + str(round(sum(line_nums)/len(line_nums), 2)))
print('avg number of tokens in Python function: ' + str(round(sum(token_nums)/len(token_nums), 2)))

-----Python-----function-level statistics-----
avg number of lines in Python function: 20.41
avg number of tokens in Python function: 98.17


### Sample data in the CodeSearchNet dataset are individual functions in different programming languages. There are no explicit mappings among these functions. Therefore, before contrastive learning, we need to construct positive samples using data augmentation techniques.
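As one illustration of such an augmentation, identifiers in a token list can be renamed consistently so that the augmented function keeps the same structure with different surface names. This is only a sketch of one possible technique, not the augmentation chosen in this project; `rename_identifiers` and the `protected` parameter are hypothetical, and names that must keep their meaning (function names, builtins such as `range`) have to be passed in `protected` by the caller.

```python
import keyword
import re

IDENTIFIER = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')


def rename_identifiers(code_tokens, protected=()):
    """Replace each non-keyword identifier with a stable placeholder name."""
    protected = set(protected)
    mapping = {}
    renamed = []
    for tok in code_tokens:
        if IDENTIFIER.match(tok) and not keyword.iskeyword(tok) and tok not in protected:
            # the same identifier always maps to the same placeholder
            if tok not in mapping:
                mapping[tok] = 'var_' + str(len(mapping))
            renamed.append(mapping[tok])
        else:
            renamed.append(tok)
    return renamed


tokens = ['def', 'add', '(', 'a', ',', 'b', ')', ':', 'return', 'a', '+', 'b']
print(' '.join(rename_identifiers(tokens, protected={'add'})))
```

The keyword check uses Python's keyword list, so for Java tokens a Java keyword set would have to be supplied through `protected`.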

### Part 5: Preview the CCS Dataset
For the CCS dataset, we look at the data in the files 'java_with_func.jsonl' and 'python_with_func.jsonl'.

In [2]:
# define original and pre-processed dataset root_dir

ccs_root = '/Users/rongdang/Desktop/semantic-code-clone/dataset/cross-language/CodeNet_Microsoft'

original_root = os.path.join(ccs_root, 'dataset')
preprocessed_root = os.path.join(ccs_root, 'preprocessed_dataset')


In [3]:
# display sample data (with comments) from the original dataset/java_with_func.jsonl file

language = 'java'
language_file = language + '_with_func.jsonl'

language_program_dir = os.path.join(original_root, language_file)
with open(language_program_dir, 'r') as f:
    lines = f.readlines()
#     for i in range(0, len(lines)):
#         line = lines[i]
#         sample_data = json.loads(line)
#         if sample_data['label'] == 3279:
#             print(i)
#             pprint(sample_data)
#             break
sample_data = json.loads(lines[269])
pprint(sample_data)

{'func': 'import java.io.IOException;\n'
         'import java.io.InputStream;\n'
         'import java.io.OutputStream;\n'
         'import java.lang.reflect.Array;\n'
         'import java.lang.reflect.Field;\n'
         'import java.util.Arrays;\n'
         'import java.util.Collection;\n'
         'import java.util.HashSet;\n'
         'import java.util.Iterator;\n'
         'import java.util.NoSuchElementException;\n'
         'import java.util.Objects;\n'
         'import java.util.PrimitiveIterator;\n'
         'import java.util.RandomAccess;\n'
         'import java.util.TreeSet;\n'
         'import java.util.function.IntBinaryOperator;\n'
         'import java.util.function.IntFunction;\n'
         'import java.util.function.IntPredicate;\n'
         'import java.util.function.IntSupplier;\n'
         'import java.util.function.IntUnaryOperator;\n'
         'import java.util.function.LongBinaryOperator;\n'
         'import java.util.function.LongUnaryOperator;\n'
         '\n'

In [4]:
# display sample data (with comments) from the original dataset/python_with_func.jsonl file

language = 'python'
language_file = language + '_with_func.jsonl'

language_program_dir = os.path.join(original_root, language_file)
with open(language_program_dir, 'r') as f:
    lines = f.readlines()
sample_data = json.loads(lines[60])
pprint(sample_data)

{'func': "def main(sample_file = ''):\n"
         '\n'
         '    """ convenient functions\n'
         '    # for i, a in enumerate(iterable)\n'
         '    # q, mod = divmod(a, b)\n'
         '    # divmod(x, y) returns the tuple (x//y, x%y)\n'
         '    # Higher-order function: reduce(operator.mul, xyz_count, 1)\n'
         '    # manage median(s) using two heapq '
         'https://atcoder.jp/contests/abc127/tasks/abc127_f\n'
         '    """\n'
         '\n'
         '    """convenient decorator\n'
         '    # @functools.lru_cache():\n'
         '    # to facilitate use of recursive function\n'
         '        # ex:\n'
         '        # from functools import lru_cache\n'
         '        # import sys\n'
         '        # sys.setrecursionlimit(10**9)\n'
         '        # @lru_cache(maxsize=None)\n'
         '        # def fib(n):\n'
         '        #     if n < 2:\n'
         '        #         return n\n'
         '        #     return fib(n-1) + fib(n-2)\n

As shown above, each data point has 3 attributes: 'func', 'index' and 'label'. The 'func' attribute is the original code string of the function, which may include comments, special tokens, or text that is not in English. The 'index' attribute is the submission_id of the function. The 'label' attribute is the task_id of the function.  
### To construct a pre-training dataset based on the CCS dataset, we should ...
1) pre-process the original code strings to remove comments, special tokens and text that is not in English.  
2) construct pair data based on the labels of the functions.
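For the Python side, step 1 can be partly sketched with the standard `tokenize` module, which makes it possible to drop `#` comments without touching string literals. This is a minimal sketch and not the pre-processing actually used to build the 'preprocessed_dataset' files: it assumes the source tokenizes cleanly, it leaves docstrings in place, and `strip_python_comments` is a hypothetical name.

```python
import io
import tokenize


def strip_python_comments(source):
    """Remove '#' comments from Python source while keeping all other tokens."""
    tokens = [
        tok
        for tok in tokenize.generate_tokens(io.StringIO(source).readline)
        if tok.type != tokenize.COMMENT
    ]
    # untokenize pads removed spans with spaces, so trailing whitespace may remain
    return tokenize.untokenize(tokens)


source = "x = 1  # set x\n# full-line comment\ny = x + 1\n"
cleaned = strip_python_comments(source)
print('\n'.join(line.rstrip() for line in cleaned.split('\n')))
```

Java sources would need a different tool, since `tokenize` only understands Python.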

In [5]:
# display sample data from the pre-processed dataset/java_with_func.jsonl file

language = 'java'
language_file = language + '_with_func.jsonl'

language_program_dir = os.path.join(preprocessed_root, language_file)
with open(language_program_dir, 'r') as f:
    sample_file = f.readlines()
sample_data = json.loads(sample_file[269])
pprint(sample_data)

{'func': 'import java.io.IOException;\n'
         'import java.io.InputStream;\n'
         'import java.io.OutputStream;\n'
         'import java.lang.reflect.Array;\n'
         'import java.lang.reflect.Field;\n'
         'import java.util.Arrays;\n'
         'import java.util.Collection;\n'
         'import java.util.HashSet;\n'
         'import java.util.Iterator;\n'
         'import java.util.NoSuchElementException;\n'
         'import java.util.Objects;\n'
         'import java.util.PrimitiveIterator;\n'
         'import java.util.RandomAccess;\n'
         'import java.util.TreeSet;\n'
         'import java.util.function.IntBinaryOperator;\n'
         'import java.util.function.IntFunction;\n'
         'import java.util.function.IntPredicate;\n'
         'import java.util.function.IntSupplier;\n'
         'import java.util.function.IntUnaryOperator;\n'
         'import java.util.function.LongBinaryOperator;\n'
         'import java.util.function.LongUnaryOperator;\n'
         'pub

In [6]:
# display sample data from the pre-processed dataset/python_with_func.jsonl file

language = 'python'
language_file = language + '_with_func.jsonl'

language_program_dir = os.path.join(preprocessed_root, language_file)
with open(language_program_dir, 'r') as f:
    sample_file = f.readlines()
sample_data = json.loads(sample_file[60])
pprint(sample_data)

{'func': "def main(sample_file = ''):\n"
         '    import sys\n'
         '    sys.setrecursionlimit(10**7)\n'
         '    from itertools import accumulate, combinations, permutations, '
         'product, combinations_with_replacement \n'
         '    from math import factorial, ceil, floor, sqrt\n'
         '    def factorize(n):\n'
         '        fct = []  \n'
         '        b, e = 2, 0  \n'
         '        while b * b <= n:\n'
         '            while n % b == 0:\n'
         '                n = n // b\n'
         '                e = e + 1\n'
         '            if e > 0:\n'
         '                fct.append((b, e))\n'
         '            b, e = b + 1, 0\n'
         '        if n > 1:\n'
         '            fct.append((n, 1))\n'
         '        return fct\n'
         '    def combinations_count(n, r):   \n'
         '        if n < 0 or r < 0:\n'
         "            raise Exception('combinations_count(n, r) not defined "
         "when n or r is nega

### Part 6: Analyze the CCS Dataset
Statistics about the dataset released by Microsoft are shown below.

| Programming Language | Total functions | Total problems |
| :---------- | :---------- | :---------- |
| Java | 23,530 | 3,142/4,053 |
| Python | 15,594 | 2,072/4,053 |

For the CCS dataset, we intend to explore Java and Python functions. We will investigate dataset statistics in terms of the average number of code lines in each language.

In [7]:
def count_code_lines_2(code_str):
    # count the non-blank lines in a code string
    # (str.strip() already removes spaces, tabs and newlines)
    line_num = 0
    for line in code_str.split('\n'):
        if line.strip():
            line_num += 1
    return line_num

In [10]:
# analyze Java functions in the preprocessed dataset

language = 'java'
language_file = language + '_with_func.jsonl'
line_nums = list()
problem_stats_java = dict()

language_program_dir = os.path.join(preprocessed_root, language_file)
with open(language_program_dir, 'r') as f:
    sample_file = f.readlines()
for sample_line in sample_file:
    sample_data = json.loads(sample_line)
    line_nums.append(count_code_lines_2(sample_data['func']))
    sample_label = sample_data['label']
    if str(sample_label) not in problem_stats_java.keys():
        problem_stats_java[str(sample_label)] = 1
    else:
        problem_stats_java[str(sample_label)] += 1

print('-----Java-----function-level statistics-----')
problem_nums = len(problem_stats_java.keys())
solution_nums = round(sum(problem_stats_java.values())/problem_nums, 2)
print('total number of problems: ' + str(problem_nums))
print('avg number of solutions per problem: ' + str(solution_nums))
print('avg number of lines in Java function: ' + str(round(sum(line_nums)/len(line_nums), 2)))


-----Java-----function-level statistics-----
total number of problems: 3142
avg number of solutions per problem: 7.49
avg number of lines in Java function: 106.49


In [11]:
# analyze the Python functions in the preprocessed dataset

language = 'python'
language_file = language + '_with_func.jsonl'
line_nums = list()
problem_stats_py = dict()

language_program_dir = os.path.join(preprocessed_root, language_file)
with open(language_program_dir, 'r') as f:
    sample_file = f.readlines()
for sample_line in sample_file:
    sample_data = json.loads(sample_line)
    line_nums.append(count_code_lines_2(sample_data['func']))
    sample_label = sample_data['label']
    if str(sample_label) not in problem_stats_py.keys():
        problem_stats_py[str(sample_label)] = 1
    else:
        problem_stats_py[str(sample_label)] += 1

print('-----Python-----function-level statistics-----')
problem_nums = len(problem_stats_py.keys())
solution_nums = round(sum(problem_stats_py.values())/problem_nums, 2)
print('total number of problems: ' + str(problem_nums))
print('avg number of solutions per problem: ' + str(solution_nums))
print('avg number of lines in Python function: ' + str(round(sum(line_nums)/len(line_nums), 2)))


-----Python-----function-level statistics-----
total number of problems: 2072
avg number of solutions per problem: 7.53
avg number of lines in Python function: 23.22


### Part 7: Construct a Pre-training Dataset based on the CCS Dataset
We construct a pre-training dataset based on the CCS dataset. Specifically, 1) we find the set of problems that have both Java and Python solutions (the intersection); 2) we collect the Java and Python solutions of these problems and build Java-Python function pairs; 3) we pre-process the pre-training dataset and split it into a training set and a test set, with a train-to-test ratio of approximately 9:1.

In [12]:
# collect intersection problems of Java and Python solutions

intersection = list()
for py_key in problem_stats_py.keys():
    if py_key in problem_stats_java.keys():
        intersection.append(py_key)
        
print('Total number of intersection problems between Java and Python solutions is: ' + str(len(intersection )))


Total number of intersection problems between Java and Python solutions is: 2001


As shown above, 2,001 problems have solutions in both Java and Python. Therefore, we can construct a pre-training dataset based on the CCS dataset. Positive sample pairs can be collected based on the 'label' attribute of the data.
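Collecting positive pairs by label can be sketched as follows. The input layout, a dictionary mapping label to {index: func}, mirrors the per-language layout written to 'all_functions.json' in this notebook. Pairing every Java solution with every Python solution of the same problem is an assumption, as are the helper name `build_pairs` and the output field names; the notebook does not state which pairing scheme produced 'pair_functions.jsonl'.

```python
from itertools import product


def build_pairs(java_by_label, python_by_label):
    """Pair Java and Python functions that share a problem label."""
    pairs = []
    shared_labels = sorted(set(java_by_label) & set(python_by_label))
    for label in shared_labels:
        java_items = java_by_label[label].items()
        python_items = python_by_label[label].items()
        # assumption: use the full cross product within each problem
        for (j_idx, j_func), (p_idx, p_func) in product(java_items, python_items):
            pairs.append({
                'task_id': label,
                'java_index': j_idx,
                'python_index': p_idx,
                'java_func': j_func,
                'python_func': p_func,
            })
    return pairs


java_toy = {'p1': {'j1': 'class A {}', 'j2': 'class B {}'}, 'p2': {'j3': 'class C {}'}}
python_toy = {'p1': {'y1': 'def a(): pass'}, 'p3': {'y2': 'def b(): pass'}}
print(len(build_pairs(java_toy, python_toy)))
```

A full cross product grows quickly for problems with many solutions, so sampling a fixed number of pairs per problem is a common alternative.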

In [8]:
# collect Java and Python solutions of the intersection problems
# build an all_functions.json file

submissions = {
    'java': dict(),
    'python': dict()
}

languages = ['java', 'python']
for language in languages:
    language_file = language + '_with_func.jsonl'
    language_program_dir = os.path.join(preprocessed_root, language_file)
    with open(language_program_dir, 'r') as f:
        lines = f.readlines()
        for line in lines:
            sample_data = json.loads(line)
            label = str(sample_data['label'])
            if label in intersection:
                # create the per-problem dict on first use, then store the function by submission index
                submissions[language].setdefault(label, dict())[sample_data['index']] = str(sample_data['func'])

result_file_path = os.path.join(preprocessed_root, 'all_functions.json')
with open(result_file_path, mode='w', encoding='utf-8') as json_file_to_write:
    json_file_to_write.write(json.dumps(submissions, indent=4))

In [13]:
# build a correct_functions.json file based on the all_functions.json file (implemented in PyCharm)
# calculate the total number of Java and Python solutions among the intersection problems in the correct_functions.json

result_file_path = os.path.join(preprocessed_root, 'correct_functions.json')

java_functions_num = 0
python_functions_num = 0

with open(result_file_path, mode='r', encoding='utf-8') as json_file_to_read:
    json_data = json.load(json_file_to_read)
    java_pool = json_data['java']
    for task, task_set in java_pool.items():
        java_functions_num += len(task_set.keys())
    python_pool = json_data['python']
    for task, task_set in python_pool.items():
        python_functions_num += len(task_set.keys())

print('Total number of Java tasks: ' + str(len(java_pool.keys())))
print('Total number of Java functions: ' + str(java_functions_num))
print('Total number of Python tasks: ' + str(len(python_pool.keys())))
print('Total number of Python functions: ' + str(python_functions_num))


Total number of Java tasks: 2001
Total number of Java functions: 17785
Total number of Python tasks: 2001
Total number of Python functions: 15399


In [11]:
# build a one-sentence-per-line raw corpus (txt file) for sentencepiece model training based on the correct_functions.json file

correct_functions = os.path.join(preprocessed_root, 'correct_functions.json')
txt_file = os.path.join(preprocessed_root, 'functions.txt')

result = list()

with open(correct_functions, mode='r', encoding='utf-8') as json_file_to_read:
    json_data = json.load(json_file_to_read)
    java_pool = json_data['java']
    python_pool = json_data['python']
    
    for _, solutions in java_pool.items():
        for _, solution in solutions.items():
            function = solution.replace('\n', ' ').strip() + '\n'
            result.append(function)
    
    
    for _, solutions in python_pool.items():
        for _, solution in solutions.items():
            function = solution.replace('\n', ' ').strip() + '\n'
            result.append(function)


with open(txt_file, mode='w') as output_file:
    output_file.writelines(result)

In [14]:
# build Java-Python function pairs based on the correct_functions.json file
# display the total number of Java-Python function pairs in the pre-training dataset

pair_functions = os.path.join(preprocessed_root, 'pair_functions.jsonl')
with open(pair_functions, 'r') as f:
    lines = f.readlines()
    print('Total number of pre-processed function pairs: ' + str(len(lines)))
#     sample_data_1 = json.loads(lines[129455])
#     print(sample_data_1['task_id'])
#     sample_data_2 = json.loads(lines[129556])
#     print(sample_data_2['task_id'])

Total number of pre-processed function pairs: 143900


### Finally, there are 143,900 Java-Python function pairs in the pre-training dataset. Specifically, 129,456 function pairs are in the training set and 14,444 function pairs are in the validation set. The split is based on task ids.
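A split by task id keeps all pairs of one problem on the same side, so no problem leaks from training into evaluation. The sketch below assumes each pair record carries a 'task_id' field; `split_by_task`, the 0.1 held-out ratio and the seed are hypothetical choices for illustration, and the exact procedure used for the numbers above is not shown in this notebook.

```python
import random


def split_by_task(pairs, held_out_ratio=0.1, seed=0):
    """Split pair records so that each task id appears in exactly one subset."""
    task_ids = sorted({pair['task_id'] for pair in pairs})
    rng = random.Random(seed)
    rng.shuffle(task_ids)
    n_held_out = max(1, round(len(task_ids) * held_out_ratio))
    held_out_tasks = set(task_ids[:n_held_out])
    train = [pair for pair in pairs if pair['task_id'] not in held_out_tasks]
    held_out = [pair for pair in pairs if pair['task_id'] in held_out_tasks]
    return train, held_out


toy_pairs = [{'task_id': str(t), 'pair_no': k} for t in range(10) for k in range(2)]
train_pairs, held_out_pairs = split_by_task(toy_pairs)
print(len(train_pairs), len(held_out_pairs))
```

Because the ratio is applied to task ids rather than to pairs, the resulting pair counts only approximate 9:1 when problems have different numbers of solutions.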