In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import datasets
import torch
import random
import json
from tqdm.auto import tqdm
import pandas as pd
from pathlib import Path

import transformers

from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

In [3]:
# assumes PYTHONPATH is pointing to the root of the repo
from codet5_finetune.options import options 
from codet5_finetune.data import DataCollatorNTP

In [4]:
# Locate the shared config relative to the notebook's working directory.
# Assumes the current folder is <repo_root>/dev, so its parent is the repo root.
# Path.cwd() replaces the IPython-only `!pwd` shell capture: it needs no
# subprocess and does not depend on a `pwd` command being available.
cwd = Path.cwd()
opt = options({'common_config': cwd.parent / 'codet5_finetune/common_config.yaml'})

['/home/toolkit/.conda/envs/rlctx_p39/lib/python3.9/site-packages/ipykernel_launcher.py', '-f', '/home/toolkit/.local/share/jupyter/runtime/kernel-742d4361-7a46-4e8f-ba82-30495768691d.json']


In [5]:
# Display the configured dataset path that the next cell loads from.
opt.path_java_filtered_subset

'/repo_data/the_stack11_dedup_alt_comments_no_1K_set_subset/data'

In [6]:
# Load the dataset stored on disk at the configured path; its 'train' and
# 'test' splits are passed to the trainer further down.
ds = datasets.load_from_disk(opt.path_java_filtered_subset)

In [7]:
# Tokenizer for the base checkpoint named in the options (the same name that
# model_init() loads the model from).
tokenizer = AutoTokenizer.from_pretrained(opt.base_model_name)

In [8]:
# Project-specific batch collator, configured with the tokenizer and the
# encoder/decoder sequence-length settings taken from the shared options.
data_collator = DataCollatorNTP(
    tokenizer,
    # encoder side
    encoder_seq_length=opt.encoder_seq_length,
    min_encoder_seq_length=opt.min_encoder_seq_length,
    # decoder side
    decoder_seq_length=opt.decoder_seq_length,
    min_decoder_seq_length=opt.min_decoder_seq_length,
)

In [9]:
# Output directory layout: <model_dir_base>/<trained_model_name>/<experiment_name>.
opt.model_dir_base = Path(opt.model_dir_base)
model_dir = opt.model_dir_base.joinpath(opt.trained_model_name, opt.experiment_name)

args = Seq2SeqTrainingArguments(
    output_dir=model_dir,

    # Debug settings: run an evaluation step periodically, generate predictions,
    # and hand the inputs to the metrics function so they can be inspected.
    evaluation_strategy="steps",
    eval_steps=100,
    predict_with_generate=True,
    include_inputs_for_metrics=True,
    generation_max_length=opt.decoder_seq_length,

    # Optimisation.
    # Original author's note on the learning rate: 4e-6 with several epochs
    # would have been ideal.
    learning_rate=opt.learning_rate,
    weight_decay=opt.weight_decay,
    num_train_epochs=opt.num_train_epochs,
    fp16=opt.fp16,
    per_device_train_batch_size=opt.per_device_train_batch_size,
    per_device_eval_batch_size=opt.per_device_eval_batch_size,

    # Logging and checkpointing.
    logging_strategy=opt.logging_strategy,
    logging_steps=opt.logging_steps,
    save_strategy=opt.save_strategy,
    save_steps=opt.save_steps,
    save_total_limit=opt.save_total_limit,
    load_best_model_at_end=False,
    report_to=opt.report_to,

    # Keep every dataset column; the trainer must not drop any before collation.
    remove_unused_columns=False,
)

In [12]:
def model_init():
    """Return a freshly loaded seq2seq model for the checkpoint in ``opt.base_model_name``."""
    base_checkpoint = opt.base_model_name
    return AutoModelForSeq2SeqLM.from_pretrained(base_checkpoint)

# Debug hack: instead of adding callbacks, the metrics function below stashes
# whatever the trainer passes to it in a module-level variable, so predictions
# and their inputs can be inspected after an evaluation run.
preds = None


def assign_preds(x):
    """Store the evaluation output ``x`` in the global ``preds``.

    Returns a placeholder metrics dict; no real metric is computed here.
    """
    global preds
    preds = x
    placeholder_metrics = {'ha': 0}
    return placeholder_metrics

# Wire everything together. The model is built through `model_init`, and
# `assign_preds` is used as the metrics function so that evaluation output is
# captured for inspection.
trainer = Seq2SeqTrainer(
    args=args,
    model_init=model_init,
    tokenizer=tokenizer,
    data_collator=data_collator,
    # data splits
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    # debug hook
    compute_metrics=assign_preds,
)

loading configuration file config.json from cache at /home/toolkit/.cache/huggingface/hub/models--Salesforce--codet5-base/snapshots/4078456db09ba972a3532827a0b5df4da172323c/config.json
Model config T5Config {
  "_name_or_path": "Salesforce/codet5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "bos_token_id": 1,
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 2,
  "feed_forward_proj": "relu",
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "label2id": {
    "LABEL_0": 0
  },
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_para

In [17]:
# Evaluate on the first 8 test examples only. As a side effect, `assign_preds`
# (the compute_metrics hook) stores the evaluation output in the global `preds`.
res = trainer.evaluate(ds['test'].select(range(8)))

***** Running Evaluation *****
  Num examples = 8
  Batch size = 8
Token indices sequence length is longer than the specified maximum sequence length for this model (602 > 512). Running this sequence through the model will result in indexing errors
You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Generate config GenerationConfig {
  "bos_token_id": 1,
  "decoder_start_token_id": 0,
  "eos_token_id": 2,
  "pad_token_id": 0,
  "transformers_version": "4.26.1"
}



In [27]:
# Show the first raw test record, for comparison with the decoded token ids.
ds['test'][0]

{'size': 2237,
 'lang': 'Java',
 'max_stars_repo_path': 'clivia-assembly-base/clivia-httpClient-assembly-base/src/main/java/org/palading/clivia/httpClient/CliviaAbstractHttpRestTemplate.java',
 'max_stars_repo_name': 'leeokdkpvv5c/UIUC-data-miningv',
 'avg_line_length': 34.953125,
 'max_line_length': 132,
 'alphanum_fraction': 0.688869021,
 '__id__': 16948372,
 'content': 'package org.palading.clivia.httpClient;\r\n\r\nimport org.palading.clivia.httpClient.request.CliviaHttpRequest;\r\nimport org.palading.clivia.httpClient.request.CliviaSyncHttpRequest;\r\nimport org.palading.clivia.httpClient.response.CliviaHttpResponse;\r\n\r\nimport java.util.List;\r\n\r\n\r\n\r\n/**\r\n * @author palading_cr\r\n * @title CliviaAbstractHttpRestTemplate\r\n * @project clivia-gateway\r\n */\r\npublic abstract class CliviaAbstractHttpRestTemplate {\r\n\r\n    private static List<HttpInterceptor> httpInterceptorList;\r\n\r\n    private CliviaSyncHttpRequest cliviaSyncHttpRequest;\r\n\r\n    public Clivi

In [26]:
# Decode the first captured model input (inputs are available to the metrics
# function because include_inputs_for_metrics=True was set).
tokenizer.decode(preds.inputs[0])

'<s>package org.palading.clivia.httpClient;\r\n\r\nimport org.palading.clivia.httpClient.request.CliviaHttpRequest;\r\nimport org.palading.clivia.httpClient.request.CliviaSyncHttpRequest;\r\nimport org.palading.clivia.httpClient.response.CliviaHttpResponse;\r\n\r\nimport java.util.List;\r\n\r\n\r\n\r\n/**\r\n * @author palading_cr\r\n</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

In [29]:
# Entries equal to -100 in the labels are not valid token ids, so they are
# swapped for the tokenizer's pad id before decoding. Work on a copy so the
# captured `preds` object stays untouched and this cell can be re-run safely,
# and use `tokenizer.pad_token_id` rather than a hard-coded 0.
label_ids_for_decode = preds.label_ids.copy()
label_ids_for_decode[label_ids_for_decode == -100] = tokenizer.pad_token_id
tokenizer.decode(label_ids_for_decode[0])

'<s> * @title CliviaAbstractHttpRestTemplate\r\n * @project clivia-gateway\r\n */\r\npublic abstract class CliviaAbstractHttpRestTemplate {\r\n\r\n    private static List<HttpInterceptor> httpInterceptorList;\r\n\r\n    private CliviaSyncHttpRequest cliviaSyncHttpRequest;\r\n\r\n    public CliviaAbstractHttpRestTemplate(List<HttpInterceptor> httpInterceptorList, CliviaSyncHttpRequest cliviaSyncHttpRequest) {\r\n        this.httpInterceptorList = httpInterceptorList;\r\n        this.cliviaSyncHttpRequest = cliviaSyncHttpRequest;\r\n\r\n    }\r\n\r\n    /**\r\n     * When the requested interception method returns fasle, execute the request and perform post-processing\r\n     * \r\n     * @author palading_cr\r\n     *\r\n     */\r\n    public CliviaHttpResponse excute(CliviaHttpRequest cliviaHttpRequest) throws Exception {\r\n        CliviaHttpResponse cliviaHttpResponse = null;\r\n        if (Interceptor.interceptor(cliviaHttpRequest)) {\r\n            cliviaHttpResponse = cliviaSyncHttp

In [30]:
# Decode the first prediction captured during evaluation (generated token ids,
# since predict_with_generate=True was set).
tokenizer.decode(preds.predictions[0])

'<pad><s><extra_id_0>@since 1.0.0 */<extra_id_1>@since 1.0.0 */</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pa