# Setup

In [6]:
!pip install -q datasets
!pip install -q evaluate
!pip install -q sentencepiece
!pip install -q transformers

In [7]:
import copy
import datasets
import evaluate
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import warnings

from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Config
from transformers import Trainer, TrainingArguments

In [8]:
tqdm.pandas()
warnings.filterwarnings("ignore", category=DeprecationWarning)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
DEVICE

device(type='cuda')

In [9]:
MODEL_NAME = "t5-base"
SEED = 1234

torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Prepare Dataset

In [None]:
df_train_java = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "java",
    split="train")
df_valid_java = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "java",
    split="validation")
df_test_java = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "java",
    split="test")

df_train_java = pd.DataFrame({
    'nl': df_train_java['docstring'],
    'code': df_train_java['code'],
    "language": "java"
})
df_valid_java = pd.DataFrame({
    'nl': df_valid_java['docstring'],
    'code': df_valid_java['code'],
    "language": "java"
})
df_test_java = pd.DataFrame({
    'nl': df_test_java['docstring'],
    'code': df_test_java['code'],
    "language": "java"
})

In [None]:
df_train_go = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "go",
    split="train")
df_valid_go = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "go",
    split="validation")
df_test_go = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "go",
    split="test")

df_train_go = pd.DataFrame({
    'nl': df_train_go['docstring'],
    'code': df_train_go['code'],
    "language": "go"
})
df_valid_go = pd.DataFrame({
    'nl': df_valid_go['docstring'],
    'code': df_valid_go['code'],
    "language": "go"
})
df_test_go = pd.DataFrame({
    'nl': df_test_go['docstring'],
    'code': df_test_go['code'],
    "language": "go"
})

In [68]:
df_train_java.head()

Unnamed: 0,nl,code,language
0,Populates the current Toml instance with value...,public Toml read(Reader reader) {\n Buffere...,java
1,Populates the current Toml instance with value...,public Toml read(String tomlString) throws Ill...,java
2,Write an Object into TOML String.\n\n@param fr...,public String write(Object from) {\n try {\...,java
3,Parses user details response from server.\n@pa...,public User parseResponse(JSONObject response)...,java
4,"Parses array details of product, exchange and ...","public User parseArray(User user, JSONObject r...",java


In [69]:
df_train_go.head()

Unnamed: 0,nl,code,language
0,// getAllDepTypes returns a sorted list of nam...,func getAllDepTypes() []string {\n\tdepTypes :...,go
1,// getIoProgressReader returns a reader that w...,"func getIoProgressReader(label string, res *ht...",go
2,// Close closes the file and then removes it f...,func (f *removeOnClose) Close() error {\n\tif ...,go
3,// getTmpROC returns a removeOnClose instance ...,"func getTmpROC(s *imagestore.Store, path strin...",go
4,// getStage1Entrypoint retrieves the named ent...,"func getStage1Entrypoint(cdir string, entrypoi...",go


In [70]:
def add_java_slashes(a):
    return "// " + "\n// ".join(a.split("\n"))

In [71]:
df_train_java["nl"] = df_train_java["nl"].apply(add_java_slashes)
df_valid_java["nl"] = df_valid_java["nl"].apply(add_java_slashes)
df_test_java["nl"] = df_test_java["nl"].apply(add_java_slashes)
df_train_java.head()

Unnamed: 0,nl,code,language
0,// Populates the current Toml instance with va...,public Toml read(Reader reader) {\n Buffere...,java
1,// Populates the current Toml instance with va...,public Toml read(String tomlString) throws Ill...,java
2,// Write an Object into TOML String.\n// \n// ...,public String write(Object from) {\n try {\...,java
3,// Parses user details response from server.\n...,public User parseResponse(JSONObject response)...,java
4,"// Parses array details of product, exchange a...","public User parseArray(User user, JSONObject r...",java


In [74]:
df_train = df_train_java.append(df_train_go, ignore_index=True)
df_valid = df_valid_java.append(df_valid_go, ignore_index=True)
df_test = df_test_java.append(df_test_go, ignore_index=True)

In [75]:
df_train

Unnamed: 0,nl,code,language
0,// Populates the current Toml instance with va...,public Toml read(Reader reader) {\n Buffere...,java
1,// Populates the current Toml instance with va...,public Toml read(String tomlString) throws Ill...,java
2,// Write an Object into TOML String.\n// \n// ...,public String write(Object from) {\n try {\...,java
3,// Parses user details response from server.\n...,public User parseResponse(JSONObject response)...,java
4,"// Parses array details of product, exchange a...","public User parseArray(User user, JSONObject r...",java
...,...,...,...
332206,// Execute performs a V3 level execution of th...,"func (vc *vcursorImpl) Execute(method string, ...",go
332207,// ExecuteMultiShard is part of the engine.VCu...,func (vc *vcursorImpl) ExecuteMultiShard(rss [...,go
332208,// ExecuteStandalone is part of the engine.VCu...,func (vc *vcursorImpl) ExecuteStandalone(query...,go
332209,// StreamExeculteMulti is the streaming versio...,func (vc *vcursorImpl) StreamExecuteMulti(quer...,go


In [76]:
df_valid

Unnamed: 0,nl,code,language
0,// Copies the contents of this source to the g...,@CanIgnoreReturnValue\n public long copyTo(Ch...,java
1,// Reads the contents of this source as a stri...,public String read() throws IOException {\n ...,java
2,// Reads all the lines of this source as a lis...,public ImmutableList<String> readLines() throw...,java
3,// buffer when possible.,"@Override\n public int read(byte[] b, int off...",java
4,"// Returns a new CharBuffer identical to buf, ...",private static CharBuffer grow(CharBuffer buf)...,java
...,...,...,...
12503,// PublishNamedEvent which name infered from e...,"func PublishNamedEvent(pub Publisher, eventBod...",go
12504,// PublishEventContext publish event for given...,"func PublishEventContext(ctx context.Context, ...",go
12505,// PublishNamedEventContext publish named even...,func PublishNamedEventContext(ctx context.Cont...,go
12506,// String returns stirng representation of log...,func (level Level) String() string {\n\tswitch...,go


In [77]:
df_test

Unnamed: 0,nl,code,language
0,// Makes sure the fast-path emits in order.\n/...,protected final void fastPathOrderedEmit(U val...,java
1,// Wraps an ObservableSource into an Observabl...,@CheckReturnValue\n @SchedulerSupport(Sched...,java
2,// Returns an Observable that emits the events...,@CheckReturnValue\n @SchedulerSupport(Sched...,java
3,// Child Observers will observe the events of ...,public static <T> ConnectableObservable<T> obs...,java
4,// Creates an UnicastProcessor with the given ...,@CheckReturnValue\n @NonNull\n public st...,java
...,...,...,...
19072,// Error writes the given HTTP status to the c...,"func (r *Render) Error(status int, message ......",go
19073,// Renderer is a Middleware that maps a render...,func Renderer(options ...RenderOptions) Handle...,go
19074,//NewClient - constructor for a new dispenser ...,"func NewClient(apiKey string, url string, clie...",go
19075,//GetTask - wrapper to rest call to GET task f...,func (s *PDClient) GetTask(taskID string) (tas...,go


# Train

In [None]:
tokenizer_code = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer_nl = AutoTokenizer.from_pretrained(MODEL_NAME)

In [90]:
class Code2TextDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        code = self.df["language"].iloc[i] \
            + " docstring: " + self.df['code'].iloc[i]
        nl = self.df['nl'].iloc[i]

        code_tokenized = tokenizer_code(
            code,
            padding="max_length",
            max_length=64,
            truncation=True)
        code_input_ids = code_tokenized.input_ids
        code_attention = code_tokenized.attention_mask

        nl_tokenized = tokenizer_nl(
            nl,
            padding="max_length",
            max_length=64,
            truncation=True)
        nl_input_ids = nl_tokenized.input_ids

        return {
            "input_ids": code_input_ids,
            "labels": nl_input_ids,
            "attention_mask": code_attention
        }

In [None]:
# TODO: Implement this and add as parameter in Trainer
def compute_metrics(eval_pred):
    print(eval_pred)
    logits, labels = eval_pred
    return 1

In [None]:
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
model.cuda()

training_args = TrainingArguments(
    output_dir="code2nl",
    evaluation_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
)

In [93]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=Code2TextDataset(df_train),
    eval_dataset=Code2TextDataset(df_valid),
)
trainer.train()

***** Running training *****
  Num examples = 332211
  Num Epochs = 3
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 7788
  Number of trainable parameters = 222903552


Epoch,Training Loss,Validation Loss
1,1.0959,0.976112
2,1.045,0.945385


Saving model checkpoint to code2nl/checkpoint-500
Configuration saved in code2nl/checkpoint-500/config.json
Model weights saved in code2nl/checkpoint-500/pytorch_model.bin
Saving model checkpoint to code2nl/checkpoint-1000
Configuration saved in code2nl/checkpoint-1000/config.json
Model weights saved in code2nl/checkpoint-1000/pytorch_model.bin
Saving model checkpoint to code2nl/checkpoint-1500
Configuration saved in code2nl/checkpoint-1500/config.json
Model weights saved in code2nl/checkpoint-1500/pytorch_model.bin
Saving model checkpoint to code2nl/checkpoint-2000
Configuration saved in code2nl/checkpoint-2000/config.json
Model weights saved in code2nl/checkpoint-2000/pytorch_model.bin
Saving model checkpoint to code2nl/checkpoint-2500
Configuration saved in code2nl/checkpoint-2500/config.json
Model weights saved in code2nl/checkpoint-2500/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 12508
  Batch size = 128
Saving model checkpoint to code2nl/checkpoint-3000
Conf

KeyboardInterrupt: ignored

# Generate

In [110]:
text = \
    'java docstring: void printMessage(String message) { System.out.println(message); }'

input_ids = tokenizer_code(text, return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids)

print("\n//".join(tokenizer_nl.decode(outputs[0], skip_special_tokens=True)
    .split("//"))[1:])

// Print a message. 
// 
// @param message the message to print.
