### Initialization
* Detect whether the notebook runs on the host (Google Colab) runtime or a local runtime.
* Mount Google Drive when running on the host runtime.

In [0]:
# Detect the runtime: `google.colab` is importable only on the hosted Colab runtime.
# Only a missing module means "local"; any other failure (e.g. a failed Drive mount)
# is surfaced instead of being silently treated as a local runtime.
try:
  from google.colab import drive
except ImportError:
  runtime = "local"
else:
  # Make Google Drive reachable under /gdrive on the host runtime.
  drive.mount('/gdrive')
  runtime = "host"

### Parameters

In [0]:
#@title Parameters
#@markdown |Name            |Description|
#@markdown |:---            |:---|
#@markdown |`seed`|The random seed used to derive the seeds of `random` and `numpy.random`|
seed = 3984 #@param {type: "number"}

#@markdown ### `nl2code` Repository
#@markdown |Name            |Description|
#@markdown |:---            |:---|
#@markdown |`repository_url`|The URL of the `nl2code` git repository (used only in the host runtime)|
#@markdown |`branch_name`   |The branch to check out (used only in the host runtime)|
repository_url = "https://github.com/HiroakiMikami/NL2Code-reimplementation" #@param {type: "string"}
branch_name = "master" #@param {type: "string"}

#@markdown ### File paths
#@markdown |Name         |Description|
#@markdown |:---         |:---|
#@markdown |`result_path`|The path of the pickled progress data to visualize.|
result_path = "./progress.pickle" #@param {type: "string"}



### Setup
* Fix the random seed
* Download the codebase (host runtime only)
  1. Clone the git repository and check out the specified branch
  2. Install the package together with its `django` extra

In [0]:
import numpy as np
import random

# Exclusive upper bound for derived seeds; np.random.seed accepts values in [0, 2**32 - 1].
SEED_MAX = 2**32 - 1

# A dedicated generator derives one seed per library from the single `seed` parameter,
# so every library's stream is reproducible from that one value.
root_rng = np.random.RandomState(seed)
# dtype=np.int64 keeps SEED_MAX representable where the default integer is 32-bit
# (e.g. Windows with NumPy < 2). int() turns the NumPy scalar back into a Python int,
# since random.seed rejects non-int numeric types on Python >= 3.11.
random.seed(int(root_rng.randint(SEED_MAX, dtype=np.int64)))
np.random.seed(int(root_rng.randint(SEED_MAX, dtype=np.int64)))

In [0]:
if runtime == "host":
    %cd /content
    !rm -rf NL2Code
    ![ ! -e NL2Code ] && git clone $repository_url NL2Code
    %cd NL2Code
    !git checkout origin/$branch_name
    !pip install -e .
    !pip install -e . ".[django]"

### Visualize result
* Load the pickled result
* Render the progress steps as a Graphviz graph

In [0]:
import pickle

# NOTE: unpickling can execute arbitrary code -- only load progress files you trust.
with open(result_path, mode="rb") as progress_file:
    result = pickle.load(progress_file)


In [0]:
from graphviz import Digraph
from IPython.display import Image
import os
import tempfile


def _escape_record_field(text):
    """Backslash-escape the characters that delimit fields in a Graphviz record label."""
    for char in "{}|<>":
        text = text.replace(char, "\\" + char)
    return text


# Record-shaped nodes laid out top to bottom ("TB"; "TD" is not a valid rankdir).
dot = Digraph(format="png")
dot.attr('graph', rankdir="TB")
dot.attr('node', shape='record')
dot.node("0", label="{0.0|root}")
for step in result:
    ids = []
    for progress in step:
        node_id = str(progress.id)
        ids.append(node_id)
        # Completed entries are outlined in blue, the others in black.
        color = "blue" if progress.is_complete else "black"
        # Escape each field separately so the `|` separating score and action survives.
        label = "{{{}|{}}}".format(
            _escape_record_field(str(progress.score)),
            _escape_record_field(str(progress.action)))
        dot.node(node_id, color=color, label=label)
        if progress.parent is not None:
            dot.edge(str(progress.parent), node_id)
    # Put every entry of one step on the same rank; skip empty steps, which would
    # otherwise emit an invalid `{rank=same; ;}` statement.
    if ids:
        dot.body.append("{rank=same; " + "; ".join(ids) + ";}")

# Render into a real temporary directory (use `.name`, not the object itself); the
# image is shown inline below, so no external viewer is launched.
render_dir = tempfile.TemporaryDirectory()
rendered_path = dot.render(filename=os.path.join(render_dir.name, "result"))
Image(filename=rendered_path)