In [None]:
# default_exp data.core

In [None]:
# hide
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Data

> This module contains the functionality needed to prepare the data: filtering out non-ASCII methods, beautifying code, replacing special tokens, and training a tokenizer. @Nathan


In [None]:
# export
import icodegen
import re

import pandas as pd

from pathlib import Path
from subprocess import CalledProcessError, check_output
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors
from typing import Dict, Optional

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# hide
from ds4se.mgmnt.prep.i import jsonl_list_to_dataframe, get_dfs

# NOTE(review): machine-specific location; point this at your local copy of the dataset
path = Path('/home/nathan/Downloads/')
df_trn, df_val, df_tst = get_dfs(path/"java/final/jsonl")

# work on a small subsample so the notebook runs quickly; the fixed seed keeps the
# subsample (and every output derived from it) the same across kernel restarts
sample = 0.01
seed = 42
df_trn = df_trn.sample(frac = sample, random_state = seed)
df_val = df_val.sample(frac = sample, random_state = seed)
df_tst = df_tst.sample(frac = sample, random_state = seed)
df_trn.head()

Unnamed: 0,code,docstring
21559,public static <T> T defaultValue(Class<T> prim...,Returns the boxed default value for a primitiv...
7041,public com.google.protobuf.ByteString\n g...,<pre>\nExplanation of why it was deprecated an...
193,public void setMaxPayloadSize(int max) {\n ...,Sets the maximum payload size in bytes.\n\n@pa...
16801,public synchronized void fit(MultiDataSetItera...,Fit the ComputationGraph using a MultiDataSetI...
10601,public static boolean isSubtype(final Class<? ...,Checks if the specified type is a descendant f...


In [None]:
# hide
# number of methods in the training, validation, and test frames
len(df_trn), len(df_val), len(df_tst)

(4545, 153, 269)

In [None]:
# hide
# toy frame for the ASCII filter: one ASCII-only method and one containing non-ASCII characters
df_fake = pd.DataFrame({'code': ['this is a test', 'भारत test']})
df_fake

In [None]:
# export
def _isASCII(mthd: str) -> bool:
    """
    Check if the given method contains only ASCII characters. From https://stackoverflow.com/a/27084708/5768407.

    :param mthd: the method to verify contains only ASCII characters
    :returns: returns a boolean representing whether or not the given method contains only ASCII characters
    """
    try:
        mthd.encode(encoding = 'utf-8').decode('ascii')
    except UnicodeDecodeError:
        return False
    else:
        return True

def remove_non_ascii(df: pd.DataFrame, n: Optional[int] = None) -> pd.DataFrame:
    """
    Remove all methods that contain non-ascii characters from a given pandas dataframe, not in-place.

    :param df: the pandas dataframe containing each method to be checked in its `code` column
    :param n: the number of methods to evaluate. If none, the entire dataframe will be used
    :returns: returns a new dataframe without methods that contain non-ascii characters
    """
    if n is None: n = len(df)

    subset = df.iloc[:n]
    # cast the mask to bool: when there are no rows, `apply` hands back an object-dtype
    # series, which pandas would treat as a list of column labels instead of a row mask
    ascii_only = subset.code.apply(_isASCII).astype(bool)

    return subset[ascii_only].copy()

In [None]:
# expected result of the filter: only the ASCII-only row is kept
# (assumes df_fake is still the two-row frame built in the earlier cell)
NON_ASCII_DF = pd.DataFrame(['this is a test'], columns = ['code'])
df_non_ascii = remove_non_ascii(df_fake)

# element-wise comparison, collapsed over both axes into a single boolean
assert (NON_ASCII_DF == df_non_ascii).all().all()

In [None]:
# hide
%time df_trn = remove_non_ascii(df_trn)

In [None]:
# hide
# single Java method with irregular indentation, used as the input of the beautifier check
df_fake = pd.DataFrame([
    '''public void setPipelines(java.util.Collection<Pipeline> pipelines) {
        if (pipelines == null) {
            this.pipelines = null;
            return;
        }

        this.pipelines = new com.amazonaws.internal.SdkInternalList<Pipeline>(pipelines);
    }
    '''
], columns = ['code']); df_fake

Unnamed: 0,code
0,public void setPipelines(java.util.Collection<...


In [None]:
# export
def _beautify(mthd: str) -> str:
    """
    Beautifies a given method using uncrustify with the sun.cfg style, i.e., Oracle's style.
    If uncrustify exits with an error, whatever output it produced is returned instead.

    :param mthd: the method to beautify
    :returns: returns a beautified version of the given method
    """
    import tempfile

    # get path of icodegen
    icodegen_path = Path(icodegen.__path__[0])

    # write the method into a private throwaway directory so that concurrent calls cannot
    # overwrite each other's input and nothing is left behind afterwards
    with tempfile.TemporaryDirectory() as tmp_dir:
        # keep the .java file name so uncrustify sees the same kind of input file as before
        tmp_file = Path(tmp_dir)/'tmp.java'
        # write as utf-8 to match how uncrustify's output is decoded below
        tmp_file.write_text(mthd, encoding = 'utf-8')

        try:
            beaut_mthd = check_output([
                str(icodegen_path/'uncrustify'), '-c', str(icodegen_path/'sun.cfg'),
                '-f', str(tmp_file)
            ]).decode('utf-8')
        except CalledProcessError as e:
            # Exception thrown when the method is malformed, i.e, it is missing a curly brace
            beaut_mthd = e.output.decode('utf-8')

    return beaut_mthd

def beautify_code(df: pd.DataFrame, n: Optional[int] = None) -> pd.DataFrame:
    """
    Return a copy of a pandas dataframe whose methods have been reformatted by uncrustify with the sun.cfg style, i.e., Oracle's style. The given dataframe is left untouched.

    :param df: the pandas dataframe holding the methods to beautify in its `code` column
    :param n: how many leading rows to process. If none, every row is processed
    :returns: returns a new dataframe with the first `n` rows and a beautified `code` column
    """
    n_rows = len(df) if n is None else n

    beautified = df.iloc[:n_rows].copy()
    beautified.code = beautified.code.apply(_beautify)

    return beautified

In [None]:
BEAUT_MTHD = '''public void setPipelines(java.util.Collection<Pipeline> pipelines) {
    if (pipelines == null) {
	this.pipelines = null;
	return;
    }
    this.pipelines = new com.amazonaws.internal.SdkInternalList<Pipeline>(
	pipelines);
}
'''

# beautify the fake method and compare it, character for character, with the expected text
df_beaut = beautify_code(df_fake)

assert BEAUT_MTHD == df_beaut.code.values[0]

In [None]:
# hide
# %time df_beaut = beautify_code(df_trn)

In [None]:
# hide
# print a method before and after beautification
# NOTE(review): the cell that beautifies df_trn is commented out, so df_beaut presumably
# still holds the single fake method and the two prints may show unrelated code — confirm
idx = 0
print(df_trn.code.values[idx])
print(df_beaut.code.values[idx])

In [None]:
# export
# dicts of special tokens we are adding to the tokenizers so they do not get split

# every dictionary below maps a placeholder token to the source text it stands for
extra_tokens = {'<n>': '\n'}

# from https://docs.oracle.com/javase/tutorial/java/nutsandbolts/_keywords.html
# each reserved word `kw` is represented by the token `<kw>`
java_reserved_tokens = {
    f'<{kw}>': kw
    for kw in (
        'abstract', 'assert', 'boolean', 'break', 'byte',
        'case', 'catch', 'char', 'class', 'const',
        'continue', 'default', 'do', 'double', 'else',
        'enum', 'extends', 'final', 'finally', 'float',
        'for', 'goto', 'if', 'implements', 'import',
        'instanceof', 'int', 'interface', 'long', 'native',
        'new', 'package', 'private', 'protected', 'public',
        'return', 'short', 'static', 'strictfp', 'super',
        'switch', 'synchronized', 'this', 'throw', 'throws',
        'transient', 'try', 'void', 'volatile', 'while',
    )
}

# from https://docs.oracle.com/javase/tutorial/java/nutsandbolts/opsummary.html
java_operator_tokens = {
    '<=>': '=',
    '<+>': '+',
    '<->': '-',
    '<*>': '*',
    '</>': '/',
    '<%>': '%',
    '<++>': '++',
    '<-->': '--',
    '<!>': '!',
    '<==>': '==',
    '<!=>': '!=',
    '<greater>': '>',
    '<greater_equal>': '>=',
    '<lesser>': '<',
    '<lesser_equal>': '<=',
    '<&&>': '&&',
    '<||>': '||',
    '<?>': '?',
    '<:>': ':',
    '<~>': '~',
    '<double_lesser>': '<<',
    '<double_greater>': '>>',
    '<triple_greater>': '>>>',
    '<&>': '&',
    '<^>': '^',
    '<|>': '|',
}

# brackets and terminators; the angle brackets reuse the operator tokens of the same name
java_structural_tokens = {
    '<{>': '{',
    '<}>': '}',
    '<[>': '[',
    '<]>': ']',
    '<lesser>': '<',
    '<greater>': '>',
    '<(>': '(',
    '<)>': ')',
    '<;>': ';',
}

java_extra_tokens = {
    '<@>': '@',
    '<...>': '...',
    '<null>': 'null',
    '<true>': 'true',
    '<false>': 'false',
}

# combination of all dictionaries, in the order listed; a key that appears in more than
# one group keeps its first position
java_special_tokens = {
    tok: text
    for group in (
        java_reserved_tokens, java_operator_tokens, java_structural_tokens,
        java_extra_tokens, extra_tokens
    )
    for tok, text in group.items()
}

In [None]:
# hide
# toy method mixing operators, a keyword, newlines, tabs, and spaces for the replacement check
df_fake = pd.DataFrame({'code': ['>>> > + public ++ \n\n \t \t \t\t  ']})
df_fake

Unnamed: 0,code
0,>>> > + public ++ \n\n \t \t \t\t


In [None]:
# texts of all special tokens, longest first
sorted(java_special_tokens.values(), key = len, reverse=True)

['synchronized',
 'implements',
 'instanceof',
 'interface',
 'protected',
 'transient',
 'abstract',
 'continue',
 'strictfp',
 'volatile',
 'boolean',
 'default',
 'extends',
 'finally',
 'package',
 'private',
 'assert',
 'double',
 'import',
 'native',
 'public',
 'return',
 'static',
 'switch',
 'throws',
 'break',
 'catch',
 'class',
 'const',
 'final',
 'float',
 'short',
 'super',
 'throw',
 'while',
 'false',
 'byte',
 'case',
 'char',
 'else',
 'enum',
 'goto',
 'long',
 'this',
 'void',
 'null',
 'true',
 'for',
 'int',
 'new',
 'try',
 '>>>',
 '...',
 'do',
 'if',
 '++',
 '--',
 '==',
 '!=',
 '>=',
 '<=',
 '&&',
 '||',
 '<<',
 '>>',
 '=',
 '+',
 '-',
 '*',
 '/',
 '%',
 '!',
 '>',
 '<',
 '?',
 ':',
 '~',
 '&',
 '^',
 '|',
 '{',
 '}',
 '[',
 ']',
 '(',
 ')',
 ';',
 '@',
 '\n']

In [None]:
# NOTE(review): same expression as the previous cell — presumably a leftover that can be removed
sorted(java_special_tokens.values(), key = len, reverse=True)

In [None]:
# (token, text) pairs ordered by the length of the text, longest first
sorted(java_special_tokens.items(), key = lambda x: len(x[1]), reverse = True)

[('<synchronized>', 'synchronized'),
 ('<implements>', 'implements'),
 ('<instanceof>', 'instanceof'),
 ('<interface>', 'interface'),
 ('<protected>', 'protected'),
 ('<transient>', 'transient'),
 ('<abstract>', 'abstract'),
 ('<continue>', 'continue'),
 ('<strictfp>', 'strictfp'),
 ('<volatile>', 'volatile'),
 ('<boolean>', 'boolean'),
 ('<default>', 'default'),
 ('<extends>', 'extends'),
 ('<finally>', 'finally'),
 ('<package>', 'package'),
 ('<private>', 'private'),
 ('<assert>', 'assert'),
 ('<double>', 'double'),
 ('<import>', 'import'),
 ('<native>', 'native'),
 ('<public>', 'public'),
 ('<return>', 'return'),
 ('<static>', 'static'),
 ('<switch>', 'switch'),
 ('<throws>', 'throws'),
 ('<break>', 'break'),
 ('<catch>', 'catch'),
 ('<class>', 'class'),
 ('<const>', 'const'),
 ('<final>', 'final'),
 ('<float>', 'float'),
 ('<short>', 'short'),
 ('<super>', 'super'),
 ('<throw>', 'throw'),
 ('<while>', 'while'),
 ('<false>', 'false'),
 ('<byte>', 'byte'),
 ('<case>', 'case'),
 ('<ch

In [None]:
# export
def _replace_toks(mthd: str, spec_toks: Dict[str, str]) -> str:
    """
    Helper function for replacing all special tokens in a given method. This will replace longer special tokens first in order to not mistakenly breakup a special token that is part of a longer sequence. Adapted from https://stackoverflow.com/a/6117124/5768407 and https://stackoverflow.com/a/11753945/5768407

    :param mthd: the method to have its special tokens replaced
    :param spec_toks: a dictionary containing the special tokens to replace and the new tokens to replace them with
    :returns: returns the method with its special tokens replaced
    """
    # construct escaped versions of keys for running through regex
    spec_toks = dict((re.escape(v), k) for k, v in sorted(java_special_tokens.items(), key = lambda x: len(x[1]), reverse = True))
    # construct regex pattern for finding all special tokens in a method
    pattern = re.compile("|".join(spec_toks.keys()))
    # replace all special tokens in a method
    mthd = pattern.sub(lambda m: spec_toks[re.escape(m.group(0))], mthd)

    return mthd

def replace_special_tokens(df: pd.DataFrame, spec_toks: Dict[str, str], n: Optional[int] = None) -> pd.DataFrame:
    """
    Return a copy of a pandas dataframe in which the special tokens of every method have been replaced. The given dataframe is left untouched.

    :param df: the pandas dataframe holding the methods to process in its `code` column
    :param spec_toks: the dictionary of special tokens handed to `_replace_toks` for every method
    :param n: how many leading rows to process. If none, every row is processed
    :returns: returns a new dataframe with the first `n` rows and the special tokens replaced
    """
    n_rows = len(df) if n is None else n

    def _swap(mthd: str) -> str:
        return _replace_toks(mthd, spec_toks)

    replaced = df.iloc[:n_rows].copy()
    replaced.code = replaced.code.apply(_swap)

    return replaced

In [None]:
# expected text once every operator, keyword, and newline has been swapped for its token;
# tabs and spaces have no entry in java_special_tokens, so they are expected to stay as they are
REPLACED_MTHD = '<triple_greater> <greater> <+> <public> <++> <n><n> \t \t \t\t  '
df_replaced = replace_special_tokens(df_fake, java_special_tokens)

assert REPLACED_MTHD == df_replaced.code.values[0]

In [None]:
# show the replaced method
# NOTE(review): the saved output of this cell shows the unreplaced text, which does not match
# the assertion in the previous cell — presumably stale; re-run to refresh
df_replaced.code.values[0]

'>>> > + public ++ \n\n \t \t \t\t  '

In [None]:
# hide
# replace the special tokens in the training methods and inspect one of the results
df_replaced = replace_special_tokens(df_trn, java_special_tokens)
print(df_replaced.code.values[5])

In [None]:
# hide
# text made only of special tokens and spaces, used as the input of the tokenizer check
fake_data = '<triple_greater> <greater> <+> <public> <++> <n><n>'

In [None]:
# export
def train_tokenizer(df: pd.DataFrame, n: Optional[int] = None, vocab_sz: Optional[int] = 10_000, min_freq: Optional[int] = 2, output: Optional[Path] = None) -> Tokenizer:
    """
    Train a ByteLevel BPE tokenizer on a given pandas dataframe. Code adapted from https://github.com/huggingface/tokenizers/tree/master/bindings/python.

    :param df: the pandas dataframe containing each method to have the tokenizer train on
    :param n: the number of methods to evaluate. If none, the entire dataframe will be used
    :param vocab_sz: the maximum vocabulary size of the trained tokenizer. The default was selected from: Big Code != Big Vocabulary: Open-Vocabulary Models for Source Code
    :param min_freq: the minimum frequency a token has to occur to be considered
    :param output: the path to save the trained tokenizer to. If none, the tokenizer is not saved
    :returns: returns a trained ByteLevel BPE tokenizer
    """
    import tempfile

    if n is None: n = len(df)

    # initialize a tokenizer
    tokenizer = Tokenizer(models.BPE())

    # customize pre-tokenization and decoding
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space = True)
    tokenizer.decoder = decoders.ByteLevel()
    tokenizer.post_processor = processors.ByteLevel(trim_offsets = True)

    # register the special tokens so that they are kept whole
    trainer = trainers.BpeTrainer(
        vocab_size = vocab_sz, min_frequency = min_freq,
        special_tokens = list(java_special_tokens.keys())
    )

    # store the methods in a private throwaway directory for training, so that concurrent
    # runs cannot overwrite each other's corpus and nothing is left behind afterwards
    with tempfile.TemporaryDirectory() as tmp_dir:
        corpus_path = Path(tmp_dir)/'tmp_tokenize.txt'
        corpus_path.write_text('\n'.join(df.code.values[:n]), encoding = 'utf-8')

        # NOTE(review): newer tokenizers releases appear to expect `train(files, trainer)` —
        # confirm the argument order against the installed version
        tokenizer.train(trainer, [str(corpus_path)])

    # save tokenizer if output path given; the path is handed over as a plain string
    # because `output` is annotated as a Path
    if output is not None:
        tokenizer.save(str(output), pretty = True)

    return tokenizer

In [None]:
# expected tokens for fake_data: every special token is kept whole, and the 'Ġ' entries
# line up with the spaces between them
TOKENIZED_SPEC = [
    '<triple_greater>', 'Ġ', '<greater>', 'Ġ', '<+>', 'Ġ',
    '<public>', 'Ġ', '<++>', 'Ġ', '<n>', '<n>'
]
# train on the current df_fake and encode the special-token-only text
tokenizer = train_tokenizer(df_fake)
encoded = tokenizer.encode(fake_data)

assert TOKENIZED_SPEC == encoded.tokens

In [None]:
# hide
# idx = 0
# df_beaut = beautify_code(df_trn, n = 10)
# df_replaced = replace_special_tokens(df_beaut, java_special_tokens)

# tokenizer = train_tokenizer(df_trn)
# encoded = tokenizer.encode(df_replaced.code.values[idx])
# print(df_replaced.code.values[idx])
# print('=' * 100)
# print(encoded.tokens)