### Initialization
* Check whether the notebook runs on a hosted Google Colab runtime ("host") or on a local runtime ("local").
* Mount Google Drive at `/gdrive` when using the host runtime.

In [0]:
# Detect the runtime: `google.colab` is only importable inside a Colab ("host")
# runtime. Only the import is guarded, so a failing `drive.mount` (or a
# KeyboardInterrupt during its auth prompt) is reported instead of being
# silently misclassified as a local runtime.
try:
    from google.colab import drive
except ImportError:
    runtime = "local"
else:
    drive.mount('/gdrive')
    runtime = "host"

### Parameters

In [0]:
# Root seed; the Python, NumPy and PyTorch RNG seeds are all derived from it.
seed = 20367 #@param {type: "number"}

# Repository and branch cloned and installed when running on the host runtime.
repository_url = "https://github.com/HiroakiMikami/NL2Prog" #@param {type: "string"}
branch_name = "master" #@param {type: "string"}

# NOTE(review): word_threshold / token_threshold are not referenced in this
# notebook (the encoders are loaded from encoder.pickle); presumably they are
# kept to mirror the training configuration — confirm.
word_threshold =  3 #@param {type: "number"}
token_threshold = 0 #@param {type: "number"}
# Passed to both TrainModel and BeamSearchSynthesizer.
max_token_len = 32 #@param {type: "number"}
max_arity = 16 #@param {type: "number"}

# TrainModel architecture hyperparameters. They presumably must match the
# values used to train the checkpoints, since those are restored with
# load_state_dict.
num_heads = 1 #@param {type: "number"}
num_nl_reader_blocks = 6 #@param {type: "number"}
num_ast_reader_blocks = 6 #@param {type: "number"}
num_decoder_blocks = 6 #@param {type: "number"}
hidden_size = 256 #@param {type: "number"}
feature_size = 1024  #@param {type: "number"}


# Beam search settings: max_action_length is the synthesizer's max_steps.
max_action_length = 350 #@param {type: "number"}
beam_size = 15 #@param {type: "number"}

# Dropout rate passed to TrainModel.
dropout = 0.15 #@param {type: "number"}

# CUDA device index; -1 keeps the model on the CPU.
device = 0 #@param {type: "number"}

# Directory holding encoder.pickle and the models/ checkpoints; the validation
# results pickle is also written here.
output_dir_path = "/gdrive/My Drive/NL2Prog/hearthstone/treegen" #@param {type: "string"}

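### Setup
* Download the codebase (host runtime only)
  1. Clone the git repository and check out the specified branch
  2. Install the package
* Install tqdm from the `colab` branch of chengs/tqdm (on every runtime)
* Select the GPU (skipped when `device` is -1)
* Fix the random seeds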

In [0]:
# On the host (Colab) runtime, fetch a fresh copy of the codebase and install it.
if runtime == "host":
    %cd /content
    # Remove any previous checkout so the clone below starts clean.
    !rm -rf NL2Prog
    !git clone $repository_url NL2Prog
    %cd NL2Prog
    !git checkout $branch_name
    !pip install .
# Install tqdm from the chengs/tqdm "colab" branch archive.
# NOTE(review): this runs on every runtime, local ones included; the fork is
# presumably needed for tqdm_notebook progress bars in Colab — confirm.
!pip install --force https://github.com/chengs/tqdm/archive/colab.zip

In [0]:
import torch

# A device of -1 keeps everything on the CPU. Any other value is a CUDA device
# index made current here, so later argument-less `.cuda()` calls target it.
use_cuda = device != -1
if use_cuda:
    torch.cuda.set_device(device)

In [0]:
import numpy as np
import random
import torch

# Exclusive upper bound for derived seeds: numpy's legacy seeding accepts
# unsigned 32-bit values.
SEED_MAX = 2**32 - 1

# Derive one sub-seed per RNG from the single user-supplied `seed`, so the
# whole run is reproducible from one number.
root_rng = np.random.RandomState(seed)
# dtype is explicit because the default (C long) is 32-bit on Windows with
# NumPy < 2, where an upper bound of 2**32 - 1 raises ValueError. On Linux and
# macOS the default is already int64, so the drawn values are unchanged.
# int() converts the numpy scalar back to a Python int, because random.seed
# rejects numpy integer types on Python >= 3.11.
random.seed(int(root_rng.randint(SEED_MAX, dtype=np.int64)))
np.random.seed(int(root_rng.randint(SEED_MAX, dtype=np.int64)))
torch.manual_seed(int(root_rng.randint(SEED_MAX, dtype=np.int64)))

### Setup evaluation
* Load the dataset
* Select the test and validation splits
* Load the saved encoders
* Convert the splits into evaluation datasets
* Create the model
* Prepare the synthesizer and evaluation metrics

Checkpoints are loaded later, in the validation cells.

In [0]:
from nl2prog.dataset.hearthstone import download
# Fetch the HearthStone dataset; the next cell indexes it by split name
# ("test", "valid").
dataset = download()

In [0]:
from nl2prog.utils.data import ListDataset, Entry

# Only the held-out splits are needed: this notebook evaluates checkpoints
# that already exist under output_dir_path/models.
test_raw_dataset, val_raw_dataset = dataset["test"], dataset["valid"]

In [0]:
import pickle
import os


# Encoder bundle stored in the output directory.
encoder_path = os.path.join(output_dir_path, "encoder.pickle")
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open(encoder_path, "rb") as encoder_file:
    encoder = pickle.load(encoder_file)
# Unpack the three encoders later handed to TrainModel and BeamSearchSynthesizer.
qencoder, cencoder, aencoder = (
    encoder["query_encoder"],
    encoder["character_encoder"],
    encoder["action_sequence_encoder"],
)

In [0]:
from nl2prog.utils.data import to_eval_dataset


# Convert both raw splits (test first, then valid) into the form that
# `evaluate` consumes in the validation cells.
test_dataset, valid_dataset = (
    to_eval_dataset(raw_split) for raw_split in (test_raw_dataset, val_raw_dataset)
)

In [0]:
from nl2prog.nn.treegen import TrainModel

# Arguments are positional, so their order must follow TrainModel's signature.
model = TrainModel(
    qencoder, cencoder, aencoder,
    max_token_len, max_arity,
    num_heads,
    num_nl_reader_blocks, num_ast_reader_blocks, num_decoder_blocks,
    hidden_size, feature_size,
    dropout,
)
# device == -1 keeps the model on the CPU; otherwise move it to the current
# CUDA device.
model = model if device == -1 else model.cuda()

In [0]:
import nl2prog.nn.utils.rnn as nrnn
from nl2prog.utils import synthesize as _synthesize
from nl2prog.utils.treegen import BeamSearchSynthesizer
from nl2prog.utils.python import tokenize_query
from nl2prog.language.action import ActionOptions
from nl2prog.language.python import is_subtype, parse, unparse
from nl2prog.metrics import Accuracy
from nl2prog.metrics.python import Bleu


# Beam-search synthesizer wired to the model's sub-modules. Because it holds
# references to those modules, later load_state_dict calls on `model` are
# reflected here.
synthesizer = BeamSearchSynthesizer(
    beam_size,
    tokenize_query,
    # model components used during the search
    model.query_embedding,
    model.rule_embedding,
    model.nl_reader,
    model.ast_reader,
    model.decoder,
    model.predictor,
    # encoders loaded from encoder.pickle
    qencoder,
    cencoder,
    aencoder,
    max_token_len,
    max_arity,
    is_subtype,
    # NOTE(review): the meaning of the two positional flags is defined in
    # nl2prog.language.action and is not visible here.
    options=ActionOptions(False, False),
    max_steps=max_action_length,
)

def synthesize(query: str):
    """Synthesize a program for ``query`` with the module-level beam-search synthesizer.

    This is a one-argument adapter over ``nl2prog.utils.synthesize`` so it can be
    passed directly to ``evaluate``. The return value is whatever that helper returns.
    """
    search = synthesizer
    return _synthesize(query, search)

# Evaluation metrics, both built from the Python `parse`/`unparse` pair.
accuracy, bleu = Accuracy(parse, unparse), Bleu(parse, unparse)
# The keys are the names reported in `result.metrics`; checkpoint selection
# reads "bleu".
metrics = dict(accuracy=accuracy, bleu=bleu)

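### Run Validation
* Evaluate every saved checkpoint on the test split and select the one with the highest top-1 BLEU
* Evaluate the selected checkpoint on the validation split and save the per-example results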

In [0]:
from tqdm import tqdm_notebook as tqdm
import os
import torch
import torch.nn.utils.rnn as rnn
import nl2prog.nn.utils.rnn as nrnn
from nl2prog.utils import evaluate

# Score every saved checkpoint and remember the one with the best top-1 BLEU.
# NOTE(review): selection here uses the *test* split, while the next cell
# reports on the *valid* split. Selecting checkpoints on test data leaks it
# into model selection; confirm this split assignment is intentional.
best_score = -1
best_score_path = None
model_dir_path = os.path.join(output_dir_path, "models")
# os.listdir order is arbitrary. Sorting makes the evaluation order, and
# therefore tie-breaking between equal scores, deterministic.
checkpoint_names = sorted(os.listdir(model_dir_path))
if not checkpoint_names:
    raise RuntimeError("No checkpoints found in {}".format(model_dir_path))
for m in checkpoint_names:
    path = os.path.join(model_dir_path, m)
    # map_location="cpu" lets GPU-saved checkpoints load on a CPU-only runtime
    # (device == -1). load_state_dict copies the tensors onto the model's device.
    model.load_state_dict(torch.load(path, map_location="cpu")["model"])
    model.eval()
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        result = evaluate(tqdm(test_dataset), synthesize, top_n=[1], metrics=metrics)
    print(m, result.metrics)
    score = result.metrics[1]["bleu"]
    # Strict ">" keeps the first checkpoint (in sorted order) among ties.
    if score > best_score:
        best_score = score
        best_score_path = path
print("Best Model: {}".format(best_score_path))

In [0]:
from tqdm import tqdm_notebook as tqdm
import os
import pickle
import torch
import torch.nn.utils.rnn as rnn
import nl2prog.nn.utils.rnn as nrnn
from nl2prog.utils import evaluate

# Re-evaluate the selected checkpoint on the valid split and store the
# per-example results in the output directory.
if best_score_path is None:
    raise RuntimeError(
        "No best checkpoint was selected (best_score_path is None); "
        "check the checkpoint-selection output")
# map_location="cpu" lets GPU-saved checkpoints load on a CPU-only runtime.
model.load_state_dict(torch.load(best_score_path, map_location="cpu")["model"])
model.eval()
# Inference only, so no autograd graph is needed.
with torch.no_grad():
    result = evaluate(tqdm(valid_dataset), synthesize, top_n=[1], metrics=metrics)
print(result.metrics)
with open(os.path.join(output_dir_path, "validation_results.pickle"), "wb") as file:
    pickle.dump(result.results, file)