# Setup

In [1]:
!pip install -q datasets
!pip install -q evaluate
!pip install -q sentencepiece
!pip install -q transformers
!pip install -q accelerate -U

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m290.1/290.1 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m57.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━

In [21]:
import copy
import datasets
import evaluate
import numpy as np
import pandas as pd
import seaborn as sns
import torch
torch.cuda.empty_cache()
import torch.nn as nn
import torch.optim as optim
import warnings

from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Config
from transformers import Trainer, TrainingArguments

In [3]:
tqdm.pandas()
warnings.filterwarnings("ignore", category=DeprecationWarning)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
DEVICE

device(type='cuda')

In [4]:
MODEL_NAME = "t5-base"
SEED = 1234

torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Prepare Dataset

In [5]:
df_train_java = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "java",
    split="train")
df_valid_java = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "java",
    split="validation")
df_test_java = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "java",
    split="test")

df_train_java = pd.DataFrame({
    'nl': df_train_java['docstring'],
    'code': df_train_java['code'],
    "language": "java"
})
df_valid_java = pd.DataFrame({
    'nl': df_valid_java['docstring'],
    'code': df_valid_java['code'],
    "language": "java"
})
df_test_java = pd.DataFrame({
    'nl': df_test_java['docstring'],
    'code': df_test_java['code'],
    "language": "java"
})

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/26.7k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/141M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.25M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.38M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/164923 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5183 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10955 [00:00<?, ? examples/s]

In [6]:
df_train_go = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "go",
    split="train")
df_valid_go = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "go",
    split="validation")
df_test_go = datasets.load_dataset(
    "code_x_glue_ct_code_to_text",
    "go",
    split="test")

df_train_go = pd.DataFrame({
    'nl': df_train_go['docstring'],
    'code': df_train_go['code'],
    "language": "go"
})
df_valid_go = pd.DataFrame({
    'nl': df_valid_go['docstring'],
    'code': df_valid_go['code'],
    "language": "go"
})
df_test_go = pd.DataFrame({
    'nl': df_test_go['docstring'],
    'code': df_test_go['code'],
    "language": "go"
})

Downloading data:   0%|          | 0.00/112M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.29M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.43M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/167288 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7325 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/8122 [00:00<?, ? examples/s]

In [7]:
df_train_java.head()

Unnamed: 0,nl,code,language
0,Compare the supplied plaintext password to a h...,"public static boolean check(String passwd, Str...",java
1,Attempt to detect the current platform.\n\n@re...,public static Platform detect() throws Unsuppo...,java
2,Gets the node meta data.\n\n@param key - the m...,public <T> T getNodeMetaData(Object key) {\n ...,java
3,Copies all node meta data from the other node ...,public void copyNodeMetaData(ASTNode other) {\...,java
4,Sets the node meta data.\n\n@param key - the m...,"public void setNodeMetaData(Object key, Object...",java


In [8]:
df_train_go.head()

Unnamed: 0,nl,code,language
0,// getStringValue will return a quoted string ...,"func getStringValue(b []rune) (int, error) {\n...",go
1,// getBoolValue will return a boolean and the ...,"func getBoolValue(b []rune) (int, error) {\n\t...",go
2,// getNumericalValue will return a numerical s...,"func getNumericalValue(b []rune) (int, int, er...",go
3,// getNegativeNumber will return a negative nu...,func getNegativeNumber(b []rune) int {\n\tif b...,go
4,// isEscaped will return whether or not the ch...,"func isEscaped(value []rune, b rune) bool {\n\...",go


In [9]:
def add_java_slashes(a):
    return "// " + "\n// ".join(a.split("\n"))

In [10]:
df_train_java["nl"] = df_train_java["nl"].apply(add_java_slashes)
df_valid_java["nl"] = df_valid_java["nl"].apply(add_java_slashes)
df_test_java["nl"] = df_test_java["nl"].apply(add_java_slashes)
df_train_java.head()

Unnamed: 0,nl,code,language
0,// Compare the supplied plaintext password to ...,"public static boolean check(String passwd, Str...",java
1,// Attempt to detect the current platform.\n//...,public static Platform detect() throws Unsuppo...,java
2,// Gets the node meta data.\n// \n// @param ke...,public <T> T getNodeMetaData(Object key) {\n ...,java
3,// Copies all node meta data from the other no...,public void copyNodeMetaData(ASTNode other) {\...,java
4,// Sets the node meta data.\n// \n// @param ke...,"public void setNodeMetaData(Object key, Object...",java


In [11]:
df_train = df_train_java.append(df_train_go, ignore_index=True)
df_valid = df_valid_java.append(df_valid_go, ignore_index=True)
df_test = df_test_java.append(df_test_go, ignore_index=True)

  df_train = df_train_java.append(df_train_go, ignore_index=True)
  df_valid = df_valid_java.append(df_valid_go, ignore_index=True)
  df_test = df_test_java.append(df_test_go, ignore_index=True)


In [58]:
df_train = df_train.head(5000)
df_valid = df_valid.head(5000)
df_test = df_test.head(5000)

In [59]:
df_valid

Unnamed: 0,nl,code,language
0,// Copies the contents of this source to the g...,@CanIgnoreReturnValue\n public long copyTo(Ch...,java
1,// Reads the contents of this source as a stri...,public String read() throws IOException {\n ...,java
2,// Reads all the lines of this source as a lis...,public ImmutableList<String> readLines() throw...,java
3,// buffer when possible.,"@Override\n public int read(byte[] b, int off...",java
4,"// Returns a new CharBuffer identical to buf, ...",private static CharBuffer grow(CharBuffer buf)...,java
5,// Handle the case of underflow caused by need...,private void readMoreChars() throws IOExceptio...,java
6,// Flips the buffer output buffer so we can st...,private void startDraining(boolean overflow) {...,java
7,// Copies an iterable's elements into an array...,"@GwtIncompatible // Array.newInstance(Class, i...",java
8,// Determines if the given iterable contains n...,public static boolean isEmpty(Iterable<?> iter...,java
9,// Useful as a public method?,"static <T> Function<Iterable<? extends T>, Ite...",java


In [60]:
df_test

Unnamed: 0,nl,code,language
0,// Makes sure the fast-path emits in order.\n/...,protected final void fastPathOrderedEmit(U val...,java
1,// Wraps an ObservableSource into an Observabl...,@CheckReturnValue\n @SchedulerSupport(Sched...,java
2,// Returns an Observable that emits the events...,@CheckReturnValue\n @SchedulerSupport(Sched...,java
3,// Child Observers will observe the events of ...,public static <T> ConnectableObservable<T> obs...,java
4,// Creates an UnicastProcessor with the given ...,@CheckReturnValue\n @NonNull\n public st...,java
5,// Creates an UnicastProcessor with the given ...,@CheckReturnValue\n @NonNull\n public st...,java
6,// Tries to subscribe to a possibly Callable s...,"@SuppressWarnings(""unchecked"")\n public sta...",java
7,// Maps a scalar value into a Publisher and em...,"public static <T, U> Flowable<U> scalarXMap(fi...",java
8,// Removes all handlers and resets to default ...,public static void reset() {\n setError...,java
9,// Wraps a CompletableSource into a Maybe.\n//...,@CheckReturnValue\n @NonNull\n @Schedule...,java


# Train

In [61]:
tokenizer_code = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer_nl = AutoTokenizer.from_pretrained(MODEL_NAME)

In [62]:
class Code2TextDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        code = self.df["language"].iloc[i] \
            + " docstring: " + self.df['code'].iloc[i]
        nl = self.df['nl'].iloc[i]

        code_tokenized = tokenizer_code(
            code,
            padding="max_length",
            max_length=64,
            truncation=True)
        code_input_ids = code_tokenized.input_ids
        code_attention = code_tokenized.attention_mask

        nl_tokenized = tokenizer_nl(
            nl,
            padding="max_length",
            max_length=64,
            truncation=True)
        nl_input_ids = nl_tokenized.input_ids

        return {
            "input_ids": code_input_ids,
            "labels": nl_input_ids,
            "attention_mask": code_attention
        }

In [63]:
# TODO: Implement this and add as parameter in Trainer
def compute_metrics(eval_pred):
    print(eval_pred)
    logits, labels = eval_pred
    return 1

In [64]:
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
#model.cuda()

training_args = TrainingArguments(
    output_dir="code2nl",
    evaluation_strategy="epoch",
    num_train_epochs=10,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
)

In [65]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=Code2TextDataset(df_train),
    eval_dataset=Code2TextDataset(df_valid),
)
trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


Epoch,Training Loss,Validation Loss
1,No log,9.174901
2,No log,8.544143
3,No log,8.106703
4,No log,7.558287
5,No log,7.102834
6,No log,6.76173
7,No log,6.506609
8,No log,6.338633
9,No log,6.22875
10,No log,6.175273


TrainOutput(global_step=10, training_loss=5.397150039672852, metrics={'train_runtime': 11.7547, 'train_samples_per_second': 25.522, 'train_steps_per_second': 0.851, 'total_flos': 22835920896000.0, 'train_loss': 5.397150039672852, 'epoch': 10.0})

# Generate

In [66]:
text = \
    'java docstring: void printMessage(String message) { int x = 2; int y = 3; System.out.println(x+y) }'

input_ids = tokenizer_code(text, return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids)

print("\n//".join(tokenizer_nl.decode(outputs[0], skip_special_tokens=True)
    .split("//"))[1:])



ava docstring: void printMessage(String message)  int
