### Initialization
* Check whether the notebook is running on a hosted (Google Colab) runtime or on a local one.
* Mount Google Drive when running on the hosted runtime.

In [None]:
# Detect whether this notebook runs on a hosted Google Colab runtime.
# Only the import is guarded: a missing `google.colab` module means we are
# running locally, whereas a failure while mounting Drive is a real error
# and should surface instead of being silently reported as "local".
try:
    from google.colab import drive
except ImportError:
    runtime = "local"
else:
    # Mount Google Drive at /gdrive; output_dir_path points below this mount.
    drive.mount('/gdrive')
    runtime = "host"

### Parameters

In [None]:
# Seed of the root RNG; the seeds for `random`, NumPy and torch are all drawn from it.
seed = 20367 #@param {type: "number"}

# Codebase to clone and branch to check out (used on the hosted runtime only).
repository_url = "https://github.com/HiroakiMikami/NL2Prog" #@param {type: "string"}
branch_name = "master" #@param {type: "string"}

# Encoder / dataset settings.
# word_threshold is passed to the query-word LabelEncoder and token_threshold to the
# ActionSequenceEncoder (presumably minimum occurrence counts — confirm).
# max_token_len and max_arity are passed to both the training-dataset conversion and the model.
word_threshold =  3 #@param {type: "number"}
token_threshold = 0 #@param {type: "number"}
max_token_len = 32 #@param {type: "number"}
max_arity = 16 #@param {type: "number"}

# Model hyper-parameters, passed positionally to nl2prog.nn.treegen.TrainModel.
num_heads = 1 #@param {type: "number"}
num_nl_reader_blocks = 6 #@param {type: "number"}
num_ast_reader_blocks = 6 #@param {type: "number"}
num_decoder_blocks = 6 #@param {type: "number"}
hidden_size = 256 #@param {type: "number"}
feature_size = 1024  #@param {type: "number"}

# Training settings.
# NOTE(review): the training loop flattens `depth` with reshape(1, -1), which looks like it
# assumes batch_size == 1 — confirm before increasing batch_size.
batch_size = 1 #@param {type: "number"}
dropout = 0.15 #@param {type: "number"}
num_epochs = 50 #@param {type: "number"}
# 0 uses the whole training split; otherwise only the first num_train entries are used.
num_train = 0 #@param {type: "number"}
# Passed to TopKModel (presumably the number of checkpoints it keeps — confirm).
num_save_models = 3 #@param {type: "number"}

# CUDA device index; -1 runs on the CPU (every GPU call is guarded by `device != -1`).
device = 0 #@param {type: "number"}

# encoder.pickle and the models/ subdirectory are written here; /gdrive is the Drive mount point.
output_dir_path = "/gdrive/My Drive/NL2Prog/hearthstone/treegen" #@param {type: "string"}

### Setup
* Download the codebase (hosted runtime only):
  1. Clone the git repository and check out the specified branch.
  2. Install the package.
* Install a tqdm fork (`chengs/tqdm`, `colab` branch).
* Select the GPU (unless `device` is -1).
* Fix the random seeds.

In [None]:
# On the hosted runtime, fetch a fresh checkout of the codebase and install it.
if runtime == "host":
    %cd /content
    # Remove any previous checkout so the clone starts from a clean directory.
    !rm -rf NL2Prog
    !git clone $repository_url NL2Prog
    # Subsequent cells keep running from inside the checkout.
    %cd NL2Prog
    !git checkout $branch_name
    # Install the cloned repository as a package.
    !pip install .
# Install a tqdm fork (chengs/tqdm, "colab" branch archive); this runs on both runtimes.
# NOTE(review): `--force` is presumably pip's abbreviation of `--force-reinstall` — confirm.
!pip install --force https://github.com/chengs/tqdm/archive/colab.zip

In [None]:
import torch

# A device value of -1 means CPU-only execution; any other value is the
# index of the CUDA device that becomes the default for later `.cuda()` calls.
on_gpu = device != -1
if on_gpu:
    torch.cuda.set_device(device)

In [None]:
import numpy as np
import random
import torch

SEED_MAX = 2**32 - 1


def _draw_seed(rng):
    """Draw one seed in ``[0, SEED_MAX)`` from the NumPy ``RandomState`` *rng*.

    ``dtype=np.int64`` keeps ``SEED_MAX`` representable on platforms where
    NumPy's default integer is 32-bit (e.g. Windows with NumPy < 2). On those
    platforms the default dtype makes ``randint`` raise ``ValueError``. On
    64-bit-int platforms the drawn values are the same as with the default
    dtype. The result is converted to a Python ``int`` because
    ``random.seed`` rejects NumPy scalars on Python 3.11+.
    """
    return int(rng.randint(SEED_MAX, dtype=np.int64))


# A single root RNG derives independent seeds for every library in use,
# so the whole run is reproducible from `seed` alone.
root_rng = np.random.RandomState(seed)
random.seed(_draw_seed(root_rng))
np.random.seed(_draw_seed(root_rng))
torch.manual_seed(_draw_seed(root_rng))

### Setup training
* Load the dataset
* Take its training split (optionally truncated to the first `num_train` examples)
* Create and save the encoders
* Convert the training split into the model's training dataset
* Create the model
* Create the optimizer

In [None]:
# Fetch the Hearthstone dataset. The result is indexed by split name in the
# next cell (only the "train" split is used in this notebook).
from nl2prog.dataset.hearthstone import download
dataset = download()

In [None]:
from nl2prog.utils.data import ListDataset, Entry

# Start from the full training split; num_train == 0 means "use everything".
train_raw_dataset = dataset["train"]
if num_train:
    # Keep only the leading num_train entries (useful for quick experiments).
    train_raw_dataset = ListDataset(
        list(train_raw_dataset)[:num_train]
    )

In [None]:
from torchnlp.encoders import LabelEncoder
from nl2prog.encoders import ActionSequenceEncoder
from nl2prog.utils.data import get_samples, get_words, get_characters
from nl2prog.utils.python import tokenize_query, tokenize_token
from nl2prog.language.action import code_to_action_sequence as to_seq, ActionOptions
from nl2prog.language.python import parse
import pickle
import os


def to_action_sequence(code):
    """Convert Python source *code* into an action sequence.

    Parses with the Python parser, tokenizes with ``tokenize_token`` and
    passes ``False`` for both positional ``ActionOptions`` flags.
    """
    options = ActionOptions(False, False)
    return to_seq(code, parse, tokenize=tokenize_token, options=options)


# Vocabulary sources, collected from the (possibly truncated) training split.
words = get_words(train_raw_dataset, tokenize_query)
chars = get_characters(train_raw_dataset, tokenize_query)
samples = get_samples(train_raw_dataset, tokenize_token, to_action_sequence)

qencoder = LabelEncoder(words, word_threshold)
cencoder = LabelEncoder(chars, 0)
aencoder = ActionSequenceEncoder(samples, token_threshold)

# Persist all three encoders together in a single pickle.
encoders = {
    "query_encoder": qencoder,
    "character_encoder": cencoder,
    "action_sequence_encoder": aencoder,
}
os.makedirs(output_dir_path, exist_ok=True)
encoder_path = os.path.join(output_dir_path, "encoder.pickle")
with open(encoder_path, "wb") as file:
    pickle.dump(encoders, file)

In [None]:
from nl2prog.utils.data.treegen import to_train_dataset

# Convert the raw training entries into the model's training dataset, using
# the tokenizers, action-sequence converter and encoders built earlier.
train_dataset = to_train_dataset(
    train_raw_dataset,
    tokenize_query,
    tokenize_token,
    to_action_sequence,
    qencoder,
    cencoder,
    aencoder,
    max_token_len,
    max_arity,
)

In [None]:
from nl2prog.nn.treegen import TrainModel

# Arguments are positional, grouped as: encoders, sequence limits, attention
# heads, block counts, layer sizes, dropout.
model = TrainModel(
    qencoder, cencoder, aencoder,
    max_token_len, max_arity,
    num_heads,
    num_nl_reader_blocks, num_ast_reader_blocks, num_decoder_blocks,
    hidden_size, feature_size,
    dropout,
)

# Move the parameters to the selected GPU unless running CPU-only (-1).
if not device == -1:
    model = model.cuda()

In [None]:
# Install fairseq, which provides the Adafactor optimizer used in the next cell.
# NOTE(review): the version is unpinned; newer fairseq releases may expect a different
# torch version than the runtime provides — consider pinning.
!pip install fairseq

In [None]:
import fairseq.optim as optim

# Adafactor receives only the model parameters; every other
# hyper-parameter is left at fairseq's default.
adafactor = optim.adafactor
optimizer = adafactor.Adafactor(model.parameters())

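### Training Loop
* Run training. After each epoch, print the epoch number, average loss and average accuracy, and save the model through `TopKModel`.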

In [None]:
from tqdm import tqdm_notebook as tqdm
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from nl2prog.utils import TopKModel
from nl2prog.nn import Loss, Accuracy
import nl2prog.nn.utils.rnn as nrnn
from nl2prog.utils.data.treegen import collate_train_dataset


def to_cuda(pad_seq):
    """Move ``pad_seq.data`` and then ``pad_seq.mask`` to the GPU via ``.cuda()``.

    The object is updated in place (its attributes are rebound) and is
    also returned, so the call can be used either as a statement or an
    expression.
    """
    for attr_name in ("data", "mask"):
        setattr(pad_seq, attr_name, getattr(pad_seq, attr_name).cuda())
    return pad_seq

model_dir_path = os.path.join(output_dir_path, "models")
os.makedirs(model_dir_path, exist_ok=True)
models = TopKModel(num_save_models, model_dir_path)
loss_function = Loss()
acc_function = Accuracy()


def _prepare_batch(data, ground_truth):
    """Pad one collated batch and, unless ``device == -1``, move it to the GPU.

    ``data`` is indexed 0..5 as (word_query, char_query, prev_action,
    rule_prev_action, depth, matrix), as produced by ``collate_train_dataset``.
    Returns the seven model/loss inputs in the order
    (word_query, char_query, prev_action, rule_prev_action, depth, matrix,
    ground_truth).
    """
    word_query = nrnn.pad_sequence(data[0], padding_value=-1)
    char_query = nrnn.pad_sequence(data[1], padding_value=-1)
    prev_action = nrnn.pad_sequence(data[2], padding_value=-1)
    rule_prev_action = nrnn.pad_sequence(data[3], padding_value=-1)
    # NOTE(review): reshape(1, -1).permute(1, 0) flattens every padded depth into a
    # single column; this looks like it assumes batch_size == 1 — confirm before
    # increasing batch_size.
    depth = nrnn.pad_sequence(data[4]).data.reshape(1, -1).permute(1, 0)
    ground_truth = nrnn.pad_sequence(ground_truth, padding_value=-1)

    # Pad every matrix to (max_len, max_len). F.pad's spec lists the LAST
    # dimension first: (cols_left, cols_right, rows_top, rows_bottom).
    max_len = prev_action.data.shape[0]
    matrix = [F.pad(m, (0, max_len - m.shape[1], 0, max_len - m.shape[0]))
              for m in data[5]]
    matrix = nrnn.pad_sequence(matrix).data.permute(1, 0, 2)

    if device != -1:
        # to_cuda moves the padded sequences in place.
        for pad_seq in (word_query, char_query, prev_action,
                        rule_prev_action, ground_truth):
            to_cuda(pad_seq)
        depth = depth.cuda()
        matrix = matrix.cuda()
    return (word_query, char_query, prev_action, rule_prev_action,
            depth, matrix, ground_truth)


# The loader is reshuffled on every iteration, so it only needs to be built once.
loader = DataLoader(train_dataset, batch_size=batch_size,
                    shuffle=True, num_workers=4,
                    collate_fn=collate_train_dataset)
num_batches = len(loader)

for epoch in tqdm(range(num_epochs)):
    avg_loss = 0.0
    avg_acc = 0.0
    model.train()
    for data, ground_truth in loader:
        (word_query, char_query, prev_action, rule_prev_action,
         depth, matrix, ground_truth) = _prepare_batch(data, ground_truth)

        rule_prob, token_prob, copy_prob = model(
            word_query, char_query, prev_action, rule_prev_action,
            depth, matrix)
        loss = loss_function(rule_prob, token_prob, copy_prob, ground_truth)
        # Accuracy is a metric only; keep it out of the autograd graph.
        with torch.no_grad():
            acc = acc_function(rule_prob, token_prob, copy_prob, ground_truth)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running per-epoch means over all batches.
        avg_loss += loss.item() / num_batches
        avg_acc += acc.item() / num_batches
    print(epoch, avg_loss, avg_acc)
    models.save(avg_acc, str(epoch), model)