### Initialization
* Check whether the runtime is host or local.
* Mount Google Drive when using the host runtime.

In [0]:
# Detect the runtime: `google.colab` can only be imported on a Colab-hosted
# runtime, in which case Google Drive is mounted at /gdrive.
try:
  from google.colab import drive
except ImportError:
  # `google.colab` is not installed, so this is a local runtime.
  runtime = "local"
else:
  # Only the import decides the runtime; a failing mount is reported instead of
  # being silently treated as a local runtime.
  drive.mount('/gdrive')
  runtime = "host"

### Parameters

In [0]:
#@title Parameters
#@markdown |Name            |Description|
#@markdown |:---            |:---|
#@markdown |`seed`|The random seed used to create the root random number generator|
seed = 3984 #@param {type: "number"}

#@markdown ### `deep-coder` Repositories
#@markdown |Name            |Description|
#@markdown |:---            |:---|
#@markdown |`repository_url`|The URL of the `deep-coder` git repository to clone (used only in the host runtime)|
#@markdown |`branch_name`   |The name of the branch to check out (used only in the host runtime)|
repository_url = "https://github.com/HiroakiMikami/deep-coder" #@param {type: "string"}
branch_name = "master" #@param {type: "string"}

#@markdown ### Dataset Configurations
#@markdown |Name          |Description|
#@markdown |:---          |:---|
#@markdown |`num_dataset` |The total number of programs in the dataset. If it is not positive (e.g., -1), no limit is set and the program will enumerate all valid source code.|
#@markdown |`num_valid`   |The number of programs used for validation.|
#@markdown |`value_range` |The largest absolute value used in the dataset.|
#@markdown |`max_list_length` |The maximum length of lists used in the dataset.|
#@markdown |`num_examples`|The number of I/O examples per program.|
#@markdown |`min_length`  |The minimum length of the program body.|
#@markdown |`max_length`  |The maximum length of the program body.|
#@markdown |`num_examples_for_pruning`|The number of examples used to prune identical programs.|
num_dataset = -1 #@param {type: "number"}
num_valid = 10 #@param {type: "number"}
value_range = 256 #@param {type: "number"}
max_list_length = 20 #@param {type: "number"}
num_examples = 5 #@param {type: "number"}
min_length = 1 #@param {type: "number"}
max_length = 2 #@param {type: "number"}
num_examples_for_pruning = 100 #@param {type: "number"}

#@markdown ### Filepath
#@markdown |Name                   |Description|
#@markdown |:---                   |:---|
#@markdown |`destination_dir_path` |The path of the directory that will contain the dataset (`train.pickle` and `valid.pickle`).|
destination_dir_path = "dataset/" #@param {type: "string"}


### Setup
* Fix the random seed
* Download the codebase (when using the host runtime)
  1. Clone git repository and move to the specified branch
  2. Initialize submodule
  3. Build the `search` tool
  4. Install chainer, cupy, and tqdm
* Remove the temporary file

In [0]:
import numpy as np
import random

# Exclusive upper bound for every seed drawn from the root generator.
SEED_MAX = (1 << 32) - 1

# Every other source of randomness is seeded from this generator, which is
# itself seeded by the `seed` parameter.
root_rng = np.random.RandomState(seed)

# The draw order matters: the first value seeds the `random` module, the
# second one seeds NumPy's global generator.
python_seed = root_rng.randint(SEED_MAX)
random.seed(python_seed)
numpy_seed = root_rng.randint(SEED_MAX)
np.random.seed(numpy_seed)

In [0]:
# Download and build the codebase. Everything below runs only when the runtime
# was detected as "host".
if runtime == "host":
  %cd /content
  # Start from a clean state: delete any previous clone before cloning again.
  !rm -rf deep-coder
  ![ ! -e deep-coder ] && git clone $repository_url deep-coder
  %cd deep-coder
  # Checking out the remote-tracking ref `origin/<branch>` leaves the work tree
  # in a detached-HEAD state.
  !git checkout origin/$branch_name
  !git submodule init
  !git submodule update
  # Run make in DeepCoder_Utils/enumerative-search with one job per processor
  # reported by `nproc`.
  !make -C DeepCoder_Utils/enumerative-search -j `nproc`
  # NOTE(review): this pipes a remote script straight into `sh`; presumably it
  # installs chainer and cupy -- confirm the URL is still available and trusted.
  !curl https://colab.chainer.org/install | sh -
  !pip install tqdm

In [0]:
# Delete any temporary dataset file left over from a previous run; `-f` keeps
# the command silent when the file does not exist.
!rm -rf ./dataset.pickle

### Generate Dataset
* Generate the full dataset and store it in a temporary file (`dataset.pickle`)
* Divide the dataset into `train` and `valid`.

In [0]:
import numpy as np
import os
from tqdm import tqdm_notebook as tqdm

from src.dsl import to_function, Program
from src.deepcoder_utils import generate_io_samples
from src.generate_dataset import generate_dataset, DatasetSpec, EquivalenceCheckingSpec, IteratorDecorator
from src.program_simplifier import remove_redundant_variables, remove_redundant_expressions, remove_dependency_between_variables

# Load the DSL functions for the configured value range; the second returned
# value is not used.
LINQ, _ = generate_io_samples.get_language(value_range)
# Keep only the functions whose source text does not contain "IDT".
LINQ = [func for func in LINQ if "IDT" not in func.src]

# Look up MINIMUM and MAXIMUM by their source text (first match) and convert
# each of them with `to_function`.
minimum_matches = [func for func in LINQ if func.src == "MINIMUM"]
MINIMUM = to_function(minimum_matches[0])
maximum_matches = [func for func in LINQ if func.src == "MAXIMUM"]
MAXIMUM = to_function(maximum_matches[0])


def simplify(program):
    """Simplify `program` by running the three simplification passes in order.

    The passes are applied as a pipeline: redundant expressions are removed
    first, then redundant variables, and finally the dependencies between
    variables (which needs the MINIMUM and MAXIMUM functions).

    Args:
        program: The program to simplify.

    Returns:
        The program returned by the last pass.
    """
    without_redundant_expressions = remove_redundant_expressions(program)
    without_redundant_variables = remove_redundant_variables(
        without_redundant_expressions)
    return remove_dependency_between_variables(
        without_redundant_variables, MINIMUM, MAXIMUM)


# TODO: tqdm_notebook does not work in a local runtime
def program_iterator(iterator):
    """Wrap `iterator` in a progress bar labelled "Program Generation"."""
    return tqdm(iterator, desc="Program Generation")


def entry_iterator(iterator):
    """Wrap `iterator` in a progress bar labelled "Prune Entries"."""
    return tqdm(iterator, desc="Prune Entries")


# Bundle both progress-bar wrappers so they can be handed to `generate_dataset`.
decorator = IteratorDecorator(program_iterator, entry_iterator)

# Build the arguments one at a time, in the order in which they are passed.
dataset_spec = DatasetSpec(value_range, max_list_length,
                           num_examples, min_length, max_length)

# The equivalence check gets its own generator, seeded from the root generator.
pruning_rng = np.random.RandomState(root_rng.randint(SEED_MAX))
# NOTE(review): the meaning of the leading 0 is defined by
# EquivalenceCheckingSpec -- confirm against its definition.
equivalence_spec = EquivalenceCheckingSpec(
    0, num_examples_for_pruning, pruning_rng)

# A non-positive `num_dataset` is forwarded as None.
if num_dataset > 0:
    program_limit = num_dataset
else:
    program_limit = None

# Write the generated dataset to the temporary file "./dataset.pickle".
generate_dataset(LINQ, dataset_spec, equivalence_spec,
                 "./dataset.pickle", program_limit,
                 simplify=simplify, decorator=decorator)


In [0]:
import pickle
import chainer as ch
from src.dataset import Dataset

# Create the destination directory (and any missing parents). `exist_ok=True`
# makes this a no-op when the directory already exists and avoids the race
# between a separate existence check and the creation.
os.makedirs(destination_dir_path, exist_ok=True)

# Load the full dataset from the temporary file.
# NOTE: unpickling can execute arbitrary code; this is acceptable here only
# because "./dataset.pickle" is written by the generation step of this notebook.
with open("./dataset.pickle", "rb") as f:
    dataset: Dataset = pickle.load(f)

# Validate the split size up front so that a bad `num_valid` fails with a clear
# message instead of an error about a derived value.
num_total = len(dataset.dataset)
if not 0 <= num_valid <= num_total:
    raise ValueError(
        "num_valid must be between 0 and the dataset size ({}), but got {}".format(
            num_total, num_valid))
num_train = num_total - num_valid

# Randomly split the entries: `num_train` of them go to `train`, the remaining
# `num_valid` go to `valid`. The split seed is drawn from the root generator.
train, valid = ch.datasets.split_dataset_random(
    dataset.dataset, num_train, seed=root_rng.randint(SEED_MAX))

# Write each split to its own pickle file in the destination directory, keeping
# only the first element of every entry and reusing the metadata of the full
# dataset.
for file_name, split in (("train.pickle", train), ("valid.pickle", valid)):
    with open(os.path.join(destination_dir_path, file_name), "wb") as f:
        entries = [d[0] for d in split]
        pickle.dump(
            Dataset(ch.datasets.TupleDataset(entries), dataset.metadata), f)


### Teardown
* Remove the temporary file

In [0]:
# Delete the temporary full-dataset file; the train/valid splits are stored
# separately under `destination_dir_path`.
!rm dataset.pickle