In [None]:
# default_exp repr.codeberta

# Training a Code Berta Transformer

> This module contains a Code Berta model (RoBERTa for source code) intended for use in future vectorization projects.

In [None]:
# export
# Imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from annoy import AnnoyIndex
from mpl_toolkits.mplot3d import Axes3D
from sklearn import decomposition
from pathlib import Path
from transformers import pipeline
from tqdm.notebook import tqdm

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# Environment and data setup: install dependencies, download the CodeSearchNet
# archives and unpack them. Only the Java unzip is active; every other step
# is commented out.
# NOTE(review): the python.zip steps are disabled although a later cell reads
# from codesearch/python — confirm that data is already present on disk.
#! pip -q install transformers annoy
#! wget -q https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip
#! wget -q https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/python.zip
! unzip -qq java.zip
#! unzip -qq python.zip

In [None]:
def jsonl_list_to_dataframe(file_list, columns=['language', 'docstring', 'code']):
    """Load a list of jsonl.gz files into a single pandas DataFrame.

    Parameters
    ----------
    file_list : iterable of path-like
        Gzip-compressed JSON-lines files; each one is read with `pd.read_json`.
    columns : sequence of str
        Columns to keep from every file.

    Returns
    -------
    pd.DataFrame
        The per-file frames concatenated in `file_list` order (each file keeps
        its own row index). When `file_list` is empty, an empty frame with the
        requested columns is returned instead of letting `pd.concat` raise
        "No objects to concatenate".
    """
    # Normalise to a list so any sequence (e.g. a tuple) selects several columns.
    columns = list(columns)
    frames = [
        pd.read_json(f, orient='records', compression='gzip', lines=True)[columns]
        for f in file_list
    ]
    if not frames:
        # Nothing to read (e.g. a split directory without *.gz files).
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, sort=False)

In [None]:
def get_dfs(path, splits = ["train", "valid", "test"]):
    """Grab the requested data splits and convert each into a DataFrame.

    Parameters
    ----------
    path : str or pathlib.Path
        Directory that contains one sub-directory per split.
    splits : sequence of str
        Names of the split sub-directories to load, in the order the
        resulting frames should be returned.

    Returns
    -------
    list of pd.DataFrame
        One frame per entry in `splits`, in the same order.
    """
    path = Path(path)  # also accept plain string paths
    dfs = []
    # Iterate over the caller-supplied splits. The loop used to be hard-coded
    # to train/valid/test, so the `splits` argument was silently ignored.
    for split in splits:
        # Sort so the files are always concatenated in the same order.
        files = sorted((path/split).glob("**/*.gz"))
        dfs.append(jsonl_list_to_dataframe(files))

    return dfs

In [None]:
path = Path('.')

In [None]:
# Build one DataFrame per language from the jsonl files under `path`,
# asking get_dfs for the "valid" split and keeping the first frame it returns.
# NOTE(review): the row counts recorded further down (412178 / 454451) look
# large for a validation split — confirm which split is actually loaded here.
java_df = get_dfs(path/"codesearch/java/final/jsonl", ["valid"])[0]
python_df = get_dfs(path/"codesearch/python/final/jsonl", ["valid"])[0]

In [None]:
python_df.head()

Unnamed: 0,language,docstring,code
0,python,Trains a k-nearest neighbors classifier for fa...,"def train(train_dir, model_save_path=None, n_n..."
1,python,Recognizes faces in given image using a traine...,"def predict(X_img_path, knn_clf=None, model_pa..."
2,python,Shows the face recognition results visually.\n...,"def show_prediction_labels_on_image(img_path, ..."
3,python,Convert a dlib 'rect' object to a plain tuple ...,"def _rect_to_css(rect):\n """"""\n Convert ..."
4,python,"Make sure a tuple in (top, right, bottom, left...","def _trim_css_to_bounds(css, image_shape):\n ..."


In [None]:
python_df.shape

(412178, 3)

In [None]:
java_df.head()

Unnamed: 0,language,docstring,code
0,java,Bind indexed elements to the supplied collecti...,protected final void bindIndexed(Configuration...
1,java,Set {@link ServletRegistrationBean}s that the ...,public void setServletRegistrationBeans(\n\t\t...
2,java,Add {@link ServletRegistrationBean}s for the f...,public void addServletRegistrationBeans(\n\t\t...
3,java,Set servlet names that the filter will be regi...,public void setServletNames(Collection<String>...
4,java,Add servlet names for the filter.\n@param serv...,public void addServletNames(String... servletN...


In [None]:
java_df.shape

(454451, 3)

In [None]:
# hide
# This script needs to be converted to a jupyter notebook.
# Runs the language-modeling script with --mlm, starting from roberta-base,
# with both training and evaluation enabled, and writes the result to
# /tf/data/models/JavaBert-v1.
# NOTE(review): --train_data_file and --eval_data_file point at the same
# file — confirm this is intentional.
! python /tf/data/scripts/run_language_modeling.py \
    --output_dir=/tf/data/models/JavaBert-v1 \
    --model_type=roberta \
    --model_name_or_path=roberta-base \
    --do_train \
    --train_data_file=/tf/main/nbs/test_data/text.txt \
    --do_eval \
    --eval_data_file=/tf/main/nbs/test_data/text.txt \
    --mlm

In [None]:
# Build a "fill-mask" pipeline whose model and tokenizer are both loaded
# from the same local directory (the output_dir of the training command).
fill_mask = pipeline(
    "fill-mask",
    model="/tf/data/models/JavaBert-v1",
    tokenizer="/tf/data/models/JavaBert-v1"
)

# Ask the pipeline to fill the <mask> token in a Java method signature, wrap
# the returned candidates in a NumPy array, and display it as the cell output.
result = np.array(fill_mask("public static void <mask>(String[] args)")); result

Model name '/tf/data/models/JavaBert-v1' was not found in model name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-japanese, bert-base-japanese-whole-word-masking, bert-base-japanese-char, bert-base-japanese-char-whole-word-masking, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased, openai-gpt, transfo-xl-wt103, gpt2, gpt2-medium, gpt2-large, gpt2-xl, distilgpt2, ctrl, xlnet-base-cased, xlnet-large-cased, xlm-mlm-en-2048, xlm-mlm-ende-1024, xlm-mlm-enfr-1024, xlm-mlm-enro-1024, xlm-mlm-tlm-xnli15-1024, xlm-mlm-xnli15-1024, xlm-clm-enfr-1

array([{'sequence': '<s>public static void parse(String[] args)</s>', 'score': 0.11699261516332626, 'token': 43756},
       {'sequence': '<s>public static void main(String[] args)</s>', 'score': 0.11461341381072998, 'token': 1049},
       {'sequence': '<s>public static void execute(String[] args)</s>', 'score': 0.06959038227796555, 'token': 11189},
       {'sequence': '<s>public static void write(String[] args)</s>', 'score': 0.06705492734909058, 'token': 3116},
       {'sequence': '<s>public static void log(String[] args)</s>', 'score': 0.021446825936436653, 'token': 7425}],
      dtype=object)

In [None]:
from nbdev.export import notebook2script
notebook2script()