# Code Tokenization

### This notebook explores the code tokenization methodology implemented by SentencePiece (https://github.com/google/sentencepiece).

### SentencePiece
SentencePiece is an unsupervised text tokenizer and detokenizer mainly for neural network-based text generation systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements subword units (e.g. byte-pair-encoding (BPE)) and unigram language model with the extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end system that does not depend on language-specific pre/postprocessing.

### Characteristics of SentencePiece
1) The number of unique tokens is predetermined  
Neural Machine Translation models typically operate with a fixed vocabulary. Unlike most unsupervised word segmentation algorithms, which assume an infinite vocabulary, SentencePiece trains the segmentation model such that the final vocabulary size is fixed, e.g., 8K, 16K, 32K.  
2) Trains from raw sentences  
Previous sub-word implementations assume that the input sentences are pre-tokenized. This constraint was required for efficient training, but makes the preprocessing complicated as we have to run language dependent tokenizers in advance. The implementation of SentencePiece is fast enough to train the model from raw sentences.  
3) Subword regularization and BPE-dropout  
Subword regularization and BPE-dropout are simple regularization methods that virtually augment training data with on-the-fly subword sampling, which helps to improve the accuracy as well as robustness of NMT models.
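
To make the first characteristic concrete, the following is a toy byte-pair-encoding learner in plain Python. It is only a sketch of the general BPE idea, not SentencePiece's implementation: the helper names `learn_bpe`, `merge_pair`, and `apply_bpe` are hypothetical, words are split on whitespace for simplicity, and `▁` marks the start of each word. The point it illustrates is that the number of merge rules, and therefore the vocabulary size, is fixed before training.

```python
from collections import Counter


def merge_pair(symbols, pair):
    """Replace every adjacent occurrence of `pair` in `symbols` with one merged symbol."""
    merged = []
    i = 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return tuple(merged)


def learn_bpe(lines, num_merges):
    """Learn at most `num_merges` merge rules from raw text lines (toy version)."""
    # Each word starts as a tuple of characters, prefixed with the '▁' word-start marker.
    words = Counter()
    for line in lines:
        for word in line.split():
            words[tuple('▁' + word)] += 1

    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by how often each word occurs.
        pairs = Counter()
        for symbols, freq in words.items():
            for left, right in zip(symbols, symbols[1:]):
                pairs[(left, right)] += freq
        if not pairs:
            break  # every word is already a single symbol
        # Pick the most frequent pair; ties are broken deterministically by the pair itself.
        best = max(pairs, key=lambda p: (pairs[p], p))
        merges.append(best)
        # Apply the new rule to the whole working vocabulary.
        updated = Counter()
        for symbols, freq in words.items():
            updated[merge_pair(symbols, best)] += freq
        words = updated
    return merges


def apply_bpe(word, merges):
    """Segment a single word by replaying the learned merge rules in order."""
    symbols = tuple('▁' + word)
    for pair in merges:
        symbols = merge_pair(symbols, pair)
    return list(symbols)


if __name__ == '__main__':
    corpus = ['int n = 0', 'int s = 0', 'int b = 0']
    rules = learn_bpe(corpus, num_merges=3)
    print(rules)
    print(apply_bpe('int', rules))
```

In this sketch the final vocabulary is the set of base characters plus one new symbol per merge rule, so choosing `num_merges` plays the role that `--vocab_size` plays for SentencePiece.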

### Part 1: Training the SentencePiece Model
We intend to train a code tokenizer for the pre-training dataset and the fine-tuning dataset. Specifically, 1) we will pre-process the two datasets by removing comments; 2) we will build a one-sentence-per-line raw corpus (txt file) from the two pre-processed datasets by selecting all functions and programs that can be parsed into ASTs successfully; 3) we will use the constructed raw corpus to train a code tokenizer using the following command.  

The command to train the code tokenizer for the pre-processed datasets will be:  
`spm_train --input=/home/BPE/code.txt --model_type=bpe --vocab_size=25000 --model_prefix=code_bpe_25K --bos_id=0 --pad_id=1 --pad_piece=[PAD] --eos_id=2 --unk_id=3`  
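
The SentencePiece Python package exposes the same trainer through `spm.SentencePieceTrainer.train`, which accepts the command-line flags as keyword arguments. The sketch below mirrors the command above; the wrapper names `build_train_kwargs` and `train_tokenizer` are hypothetical, and the import is deferred so that the argument construction can be inspected without the package installed. By default SentencePiece assigns `unk_id=0`, `bos_id=1`, `eos_id=2` and disables the padding piece (`pad_id=-1`), so the flags here override that layout.

```python
def build_train_kwargs(corpus_path, model_prefix, vocab_size=25000):
    """Collect the training options used in the spm_train command as keyword arguments."""
    return dict(
        input=corpus_path,          # one-sentence-per-line raw corpus
        model_type='bpe',
        vocab_size=vocab_size,
        model_prefix=model_prefix,  # writes <prefix>.model and <prefix>.vocab
        bos_id=0,
        pad_id=1,
        pad_piece='[PAD]',
        eos_id=2,
        unk_id=3,
    )


def train_tokenizer(corpus_path, model_prefix, vocab_size=25000):
    """Train the tokenizer from Python (requires the sentencepiece package)."""
    import sentencepiece as spm  # imported lazily; only needed when training

    spm.SentencePieceTrainer.train(
        **build_train_kwargs(corpus_path, model_prefix, vocab_size)
    )


if __name__ == '__main__':
    print(build_train_kwargs('/home/BPE/code.txt', 'code_bpe_25K'))
```

Calling `train_tokenizer('/home/BPE/code.txt', 'code_bpe_25K')` would be the Python counterpart of the shell command, assuming the corpus exists at that path.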


Statistics about the raw corpus (functions.txt) of the pre-training dataset are shown below.  
| Total number of tasks | Total number of Java functions | Total number of Python functions |
| :---------- | :---------- | :---------- |
| 2,001 | 17,785 | 15,399 |

Statistics about the raw corpus (programs.txt) of the fine-tuning dataset are shown below. 
| Total number of Java tasks | Total number of Java programs | Total number of Python tasks | Total number of Python programs |
| :---------- | :---------- | :---------- | :---------- |
| 993 | 5,660 | 1,007 | 6,239 |

In [6]:
import os
import json
import jsonlines
from pprint import pprint

import sentencepiece as spm

In [2]:
# build a one-sentence-per-line raw corpus 'code.txt' from 'functions.txt' and 'programs.txt'

functions_txt_path = '/Users/rongdang/Desktop/semantic-code-clone/dataset/cross-language/CodeNet_Microsoft/preprocessed_dataset/functions.txt'
programs_txt_path = '/Users/rongdang/Desktop/semantic-code-clone/dataset/cross-language/C4/preprocessed_dataset/programs.txt'

code_txt_path = '/Users/rongdang/Desktop/semantic-code-clone/checkpoint/BPE_Model/code.txt'
code_lines = list()

# each `with` block closes its file automatically, so no explicit close() is needed
with open(functions_txt_path, mode='r', encoding='utf-8') as function_txt_file:
    code_lines.extend(function_txt_file.readlines())

with open(programs_txt_path, mode='r', encoding='utf-8') as program_txt_file:
    code_lines.extend(program_txt_file.readlines())

# write with the same encoding that was used for reading
with open(code_txt_path, mode='w', encoding='utf-8') as code_txt_file:
    code_txt_file.writelines(code_lines)
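
One edge case is worth guarding against when concatenating line-oriented corpora: if an input file does not end with a newline, its last entry fuses with the first entry of the next file. The following is a minimal sketch of a merge helper that handles this; the name `merge_corpora` is hypothetical, and it streams line by line instead of holding both files in memory.

```python
def merge_corpora(input_paths, output_path):
    """Concatenate line-oriented text files and return the number of lines written."""
    total = 0
    with open(output_path, mode='w', encoding='utf-8') as out_file:
        for path in input_paths:
            with open(path, mode='r', encoding='utf-8') as in_file:
                for line in in_file:
                    # Only the last line of a file can lack a trailing newline.
                    if not line.endswith('\n'):
                        line += '\n'
                    out_file.write(line)
                    total += 1
    return total
```

With the variables defined in the cell above, `merge_corpora([functions_txt_path, programs_txt_path], code_txt_path)` would produce the same `code.txt`, plus the newline guard.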

### Part 2: Usage of the SentencePiece Model
We will display tokenization results of one Java-Python function pair from the pre-training dataset and one Java-Python program pair from the fine-tuning dataset using the pre-trained SentencePiece BPE tokenizer.

In [3]:
# load the trained SentencePiece BPE model

spm_model_path = '/Users/rongdang/Desktop/semantic-code-clone/checkpoint/BPE_Model/code_bpe_25K.model'
sp = spm.SentencePieceProcessor()
sp.Load(spm_model_path)


True

In [4]:
# define the path of pre-training dataset

ccs_root = '/Users/rongdang/Desktop/semantic-code-clone/dataset/cross-language/CodeNet_Microsoft'
correct_functions = os.path.join(ccs_root, 'preprocessed_dataset', 'correct_functions.json')


In [5]:
# define the path of fine-tuning dataset

c4_root = '/Users/rongdang/Desktop/semantic-code-clone/dataset/cross-language/C4'
correct_programs = os.path.join(c4_root, 'preprocessed_dataset', 'correct_programs.json')


In [6]:
# display the tokenization result of one Java-Python function pair from the pre-training dataset

with open(correct_functions, mode='r', encoding='utf-8') as correct_functions_file:
    json_data = json.load(correct_functions_file)
    java_pool = json_data['java']
    python_pool = json_data['python']
    
    java_code = java_pool['1085']['s311391346']
    java_tokens = sp.Encode(java_code, out_type=str)
    
    python_code = python_pool['1085']['s860127205']
    python_tokens = sp.Encode(python_code, out_type=str)


print('sample Java function: ')
pprint(java_code)
print('sample Java tokens: ')
print(java_tokens)
print('------------------------------------------------------')
print('sample Python function: ')
pprint(python_code)
print('sample Python tokens: ')
print(python_tokens)

sample Java function: 
('import java.util.Scanner;\n'
 'public class Main\n'
 '{\n'
 '    public static void main(String[] args)\n'
 '    {\n'
 '        Scanner scan = new Scanner(System.in);\n'
 '        int n = scan.nextInt();\n'
 '        int s = scan.nextInt();\n'
 '        int b = scan.nextInt();\n'
 '        while (n != 0 && s != 0 && b != 0) {\n'
 '            int r = 0;\n'
 '            int f = 0;\n'
 '            int t[] = new int[n];\n'
 '            for (int i = 0; i < n; i++) {\n'
 '                t[i] = scan.nextInt();\n'
 '            }\n'
 '            for (int x = s; x <= b; x++) {\n'
 '                int l = t[x - 1] - t[x];\n'
 '                if (r <= l) {\n'
 '                    r = l;\n'
 '                    f = x;\n'
 '                }\n'
 '            }\n'
 '            System.out.println(f);\n'
 '            n = scan.nextInt();\n'
 '            s = scan.nextInt();\n'
 '            b = scan.nextInt();\n'
 '        }\n'
 '    }\n'
 '}')
sample Java tokens: 


In [7]:
# display the tokenization result of one Java-Python program pair from the fine-tuning dataset

with open(correct_programs, mode='r', encoding='utf-8') as correct_programs_file:
    json_data = json.load(correct_programs_file)
    java_pool = json_data['java']
    python_pool = json_data['python']
    
    java_code = java_pool['1245']['79154']
    java_tokens = sp.Encode(java_code, out_type=str)
    
    python_code = python_pool['1245']['79131']
    python_tokens = sp.Encode(python_code, out_type=str)


print('sample Java program: ')
pprint(java_code)
print('sample Java tokens: ')
print(java_tokens)
print('------------------------------------------------------')
print('sample Python program: ')
pprint(python_code)
print('sample Python tokens: ')
print(python_tokens)

sample Java program: 
('import java.util.*; \n'
 'public class Main { \n'
 ' public static void main (String[] args) { \n'
 '  Scanner sc = new Scanner(System.in); \n'
 '  int n = sc.nextInt(); \n'
 '  if (isPrime(n)) { \n'
 '   System.out.println("YES"); \n'
 '  } else { \n'
 '   System.out.println("NO"); \n'
 '  } \n'
 ' } \n'
 ' static boolean isPrime(int x) { \n'
 '  for (int i = 2; i <= 1000 && i < x; i++) { \n'
 '   if (x % i == 0) { \n'
 '    return false; \n'
 '   } \n'
 '  } \n'
 '  return true; \n'
 ' } \n'
 '}')
sample Java tokens: 
['▁import', '▁java', '.', 'util', '.*;', '▁public', '▁class', '▁Main', '▁{', '▁public', '▁static', '▁void', '▁main', '▁(', 'String', '[]', '▁args', ')', '▁{', '▁Scanner', '▁sc', '▁=', '▁new', '▁Scanner', '(', 'System', '.', 'in', ');', '▁int', '▁n', '▁=', '▁sc', '.', 'nextInt', '();', '▁if', '▁(', 'isPrime', '(', 'n', '))', '▁{', '▁System', '.', 'out', '.', 'println', '("', 'YES', '");', '▁}', '▁else', '▁{', '▁System', '.', 'out', '.', 'println',
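
The `▁` (U+2581) symbol in the pieces above marks a position where whitespace preceded the piece in the input. The SentencePiece documentation describes detokenization as concatenating the pieces and replacing `▁` with a space; the sketch below does exactly that in plain Python, with the hypothetical helper name `detokenize`, using the first few pieces from the output above.

```python
def detokenize(pieces):
    """Rebuild text from SentencePiece-style pieces by mapping '▁' back to a space."""
    text = ''.join(pieces).replace('▁', ' ')
    # The first piece carries a marker for the dummy prefix, so drop the leading space.
    return text.lstrip(' ')


if __name__ == '__main__':
    pieces = ['▁import', '▁java', '.', 'util', '.*;', '▁public', '▁class', '▁Main', '▁{']
    print(detokenize(pieces))
```

Judging from the pieces shown above, line breaks and indentation are not kept as separate pieces, so this reconstruction recovers the token sequence on a single line and not the original layout. With a loaded model, `sp.Decode` performs the equivalent step.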

### Part 3: Analyze the Pre-training and the Fine-tuning Dataset
We will investigate dataset statistics in terms of: 1) the average number of code lines for each language; 2) the average number of code tokens for each language.

In [8]:
def count_code_lines(code_str):
    # count the lines that contain at least one non-whitespace character
    return sum(1 for line in code_str.split('\n') if line.strip())

In [9]:
# analyze the dataset statistics of the pre-training dataset using the SentencePiece BPE tokenizer

java_line_nums = list()
java_token_nums = list()
python_line_nums = list()
python_token_nums = list()

with open(correct_functions, mode='r', encoding='utf-8') as correct_functions_file:
    json_data = json.load(correct_functions_file)
    java_pool = json_data['java']
    python_pool = json_data['python']
    
    for task, solutions in java_pool.items():
        for _, solution in solutions.items():
            java_line_num = count_code_lines(solution)
            java_token_num = len(sp.Encode(solution, out_type=str))
            java_line_nums.append(java_line_num)
            java_token_nums.append(java_token_num)
            
    for task, solutions in python_pool.items():
        for _, solution in solutions.items():
            python_line_num = count_code_lines(solution)
            python_token_num = len(sp.Encode(solution, out_type=str))
            python_line_nums.append(python_line_num)
            python_token_nums.append(python_token_num)
            

print('-----Pre-training dataset statistics-----')
print('avg number of lines in Java function: ' + str(round(sum(java_line_nums)/len(java_line_nums), 2)))
print('avg number of code sub-tokens in Java function: ' + str(round(sum(java_token_nums)/len(java_token_nums), 2)))
print('avg number of lines in Python function: ' + str(round(sum(python_line_nums)/len(python_line_nums), 2)))
print('avg number of code sub-tokens in Python function: ' + str(round(sum(python_token_nums)/len(python_token_nums), 2)))


-----Pre-training dataset statistics-----
avg number of lines in Java function: 110.94
avg number of code sub-tokens in Java function: 821.29
avg number of lines in Python function: 22.93
avg number of code sub-tokens in Python function: 178.91


In [10]:
# analyze the dataset statistics of the fine-tuning dataset using the SentencePiece BPE tokenizer

java_line_nums = list()
java_token_nums = list()
python_line_nums = list()
python_token_nums = list()

with open(correct_programs, mode='r', encoding='utf-8') as correct_programs_file:
    json_data = json.load(correct_programs_file)
    java_pool = json_data['java']
    python_pool = json_data['python']
    
    for task, solutions in java_pool.items():
        for _, solution in solutions.items():
            java_line_num = count_code_lines(solution)
            java_token_num = len(sp.Encode(solution, out_type=str))
            java_line_nums.append(java_line_num)
            java_token_nums.append(java_token_num)
            
    for task, solutions in python_pool.items():
        for _, solution in solutions.items():
            python_line_num = count_code_lines(solution)
            python_token_num = len(sp.Encode(solution, out_type=str))
            python_line_nums.append(python_line_num)
            python_token_nums.append(python_token_num)
            

print('-----Fine-tuning dataset statistics-----')
print('avg number of lines in Java program: ' + str(round(sum(java_line_nums)/len(java_line_nums), 2)))
print('avg number of code sub-tokens in Java program: ' + str(round(sum(java_token_nums)/len(java_token_nums), 2)))
print('avg number of lines in Python program: ' + str(round(sum(python_line_nums)/len(python_line_nums), 2)))
print('avg number of code sub-tokens in Python program: ' + str(round(sum(python_token_nums)/len(python_token_nums), 2)))


-----Fine-tuning dataset statistics-----
avg number of lines in Java program: 64.49
avg number of code sub-tokens in Java program: 426.73
avg number of lines in Python program: 20.69
avg number of code sub-tokens in Python program: 171.53
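
The two analysis cells above repeat the same loop for each language and dataset. As a sketch of how this could be factored out, the hypothetical helper `pool_stats` below takes one language pool (task id to solution id to source code, as in the JSON files) and any token-counting function. The whitespace tokenizer in the demo is only a stand-in so the example runs without a trained model.

```python
from statistics import mean


def count_code_lines(code_str):
    """Count the lines that contain at least one non-whitespace character."""
    return sum(1 for line in code_str.split('\n') if line.strip())


def pool_stats(pool, count_tokens):
    """Return (avg lines, avg tokens) over every solution in a {task: {id: code}} pool."""
    solutions = [code for task in pool.values() for code in task.values()]
    avg_lines = round(mean([count_code_lines(code) for code in solutions]), 2)
    avg_tokens = round(mean([count_tokens(code) for code in solutions]), 2)
    return avg_lines, avg_tokens


if __name__ == '__main__':
    # Toy pool with made-up solutions; real pools come from the JSON files above.
    toy_pool = {'t1': {'a': 'x = 1\n\ny = 2', 'b': 'print(1)'}}
    # Stand-in tokenizer: whitespace splitting instead of the trained BPE model.
    print(pool_stats(toy_pool, lambda code: len(code.split())))
```

With the trained model loaded, passing `lambda code: len(sp.Encode(code, out_type=str))` as `count_tokens` would reproduce the sub-token averages computed in the cells above.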


### Statistics of the pre-training dataset using the pre-trained CodeBERT tokenizer
-----Pre-training dataset statistics-----  
avg number of lines in Java function: 110.94  
avg number of code tokens in Java function: 935.5  
avg number of lines in Python function: 22.93  
avg number of code tokens in Python function: 209.86  

### Statistics of the fine-tuning dataset using the pre-trained CodeBERT tokenizer
-----Fine-tuning dataset statistics-----  
avg number of lines in Java program: 64.49  
avg number of code tokens in Java program: 496.59  
avg number of lines in Python program: 20.69  
avg number of code tokens in Python program: 213.93  
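
To compare the two tokenizers directly, the averages reported above can be turned into a relative reduction in sequence length. The snippet below is a small arithmetic sketch over those reported numbers only; the helper name `reduction_percent` is hypothetical.

```python
# Average token counts reported above: (SentencePiece BPE, CodeBERT tokenizer).
reported = {
    'pre-training / Java': (821.29, 935.5),
    'pre-training / Python': (178.91, 209.86),
    'fine-tuning / Java': (426.73, 496.59),
    'fine-tuning / Python': (171.53, 213.93),
}


def reduction_percent(spm_avg, codebert_avg):
    """Relative reduction in average sequence length, in percent."""
    return round(100 * (1 - spm_avg / codebert_avg), 1)


reductions = {name: reduction_percent(*pair) for name, pair in reported.items()}

for name, pct in reductions.items():
    print(f'{name}: {pct}% fewer tokens with the SentencePiece BPE tokenizer')
```

On these averages, the tokenizer trained on the code corpus produces sequences that are roughly 12 to 20 percent shorter than those of the CodeBERT tokenizer, with the largest difference on the Python programs of the fine-tuning dataset.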