In [1]:
import numpy as np

In [2]:
# 3-D float64 array of ones with shape (2, 3, 4).
a = np.full((2, 3, 4), 1.0)

In [3]:
# `...` expands to as many `:` as needed, so with the (2, 3, 4) `a` this is a[:, :, 0] -> shape (2, 3).
print(a[..., 0])

[[1. 1. 1.]
 [1. 1. 1.]]


In [4]:
# Without `...`, indices fill axes from the left: a[:, 0] is a[:, 0, :] -> shape (2, 4).
print(a[:, 0])

[[1. 1. 1. 1.]
 [1. 1. 1. 1.]]


In [9]:
# Last valid index along the final axis (size 4); result shape is still (2, 3).
print(a[..., 3])

[[1. 1. 1.]
 [1. 1. 1.]]


In [10]:
# 4-D float64 array of ones with shape (2, 3, 4, 2).
a = np.full((2, 3, 4, 2), 1.0)

In [13]:
# Same expression on the 4-D `a`: `...` now spans three axes -> shape (2, 3, 4).
print(a[..., 0])

[[[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]

 [[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]]


In [14]:
# 2-D float64 array of ones with shape (2, 3).
a = np.full((2, 3), 1.0)

In [15]:
# Display the (2, 3) array of ones.
a

array([[1., 1., 1.],
       [1., 1., 1.]])

In [19]:
# On a 2-D array `...` covers only axis 0, so this equals a[:, 2] -> shape (2,).
a[..., 2]

array([1., 1.])

In [27]:
# Build a (4, 3, 2) float64 array with a[i, j, k] = (i+1) * (j+1) * (k+1).
# The shape is written once: np.fromfunction evaluates the formula on
# index grids of that shape, so there are no separate loop bounds that
# could drift out of sync with the array dimensions.
a = np.fromfunction(lambda i, j, k: (i + 1) * (j + 1) * (k + 1), (4, 3, 2))


In [21]:
# NOTE(review): stale output. In [21] ran before In [27], on an earlier 2-D
# `a`; after Restart & Run All this displays the (4, 3, 2) array instead.
a

array([[0., 0.],
       [0., 1.],
       [0., 2.]])

In [23]:
# NOTE(review): stale output from a (3, 2) `a` (In [23] ran before In [27]);
# this cell also duplicates the plain `a` display cell before it.
a

array([[1., 2.],
       [2., 4.],
       [3., 6.]])

In [24]:
# A bare `...` selects every axis, i.e. the whole array.
# NOTE(review): output is stale (In [24] ran before In [27] on a (3, 2) `a`).
a[...]

array([[1., 2.],
       [2., 4.],
       [3., 6.]])

In [25]:
# `...` expands to all leading axes, so this selects index 0 of the last axis.
# NOTE(review): stale output (In [25] ran before In [27] on a (3, 2) `a`);
# a fresh run gives a (4, 3) result.
a[..., 0]

array([1., 2., 3.])

In [26]:
# Index 1 of the last axis.
# NOTE(review): stale output (In [26] ran before In [27] on a (3, 2) `a`);
# a fresh run gives a (4, 3) result.
a[..., 1]

array([2., 4., 6.])

In [28]:
# Display the (4, 3, 2) array built by the fill cell above.
a

array([[[ 1.,  2.],
        [ 2.,  4.],
        [ 3.,  6.]],

       [[ 2.,  4.],
        [ 4.,  8.],
        [ 6., 12.]],

       [[ 3.,  6.],
        [ 6., 12.],
        [ 9., 18.]],

       [[ 4.,  8.],
        [ 8., 16.],
        [12., 24.]]])

In [None]:




# Index 0 of the last axis of the (4, 3, 2) array -> shape (4, 3), entries (i+1)*(j+1).
a[..., 0]

array([[ 1.,  2.,  3.],
       [ 2.,  4.,  6.],
       [ 3.,  6.,  9.],
       [ 4.,  8., 12.]])

In [30]:
# NumPy permits at most one Ellipsis per index, so `a[..., 0, ...]` raises
# IndexError. Catch it and show the message so this intentional
# demonstration does not abort "Restart & Run All".
try:
    a[..., 0, ...]
except IndexError as err:
    print("IndexError: {}".format(err))

IndexError: an index can only have a single ellipsis ('...')

In [31]:
# A single `...` is valid: it selects every axis, returning the whole (4, 3, 2) array.
a[...]

array([[[ 1.,  2.],
        [ 2.,  4.],
        [ 3.,  6.]],

       [[ 2.,  4.],
        [ 4.,  8.],
        [ 6., 12.]],

       [[ 3.,  6.],
        [ 6., 12.],
        [ 9., 18.]],

       [[ 4.,  8.],
        [ 8., 16.],
        [12., 24.]]])

In [None]:
import json
import os
import time
from absl import app
from absl import flags
from param import KSParam
from dataset import KSInitDataSet
from value import ValueTrainer
from policy import KSPolicyTrainer
from util import print_elapsedtime

# Command-line flags:
#   -c / --config_path : JSON experiment config that main() loads.
#   -n / --exp_name    : suffix appended to the output directory name.
flags.DEFINE_string(
    name="config_path",
    default="./configs/KS/game_nn_n50.json",
    help="""The path to load json file.""",
    short_name='c',
)
flags.DEFINE_string(
    name="exp_name",
    default="test",
    help="""The suffix used in model_path for save.""",
    short_name='n',
)
FLAGS = flags.FLAGS

def _load_config(path):
    """Read the JSON config at `path` and return the parsed object."""
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _save_config(config, path):
    """Write `config` as JSON to `path`."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(config, f)


def _build_model_path(config, exp_name):
    """Return the output directory for this run.

    The directory name is built from the optimisation type ("game" when
    policy_config.opt_type == "game", otherwise "sp"), the dataset's
    value_sampling setting, n_agt, and `exp_name`.
    """
    opt_tag = "game" if config["policy_config"]["opt_type"] == "game" else "sp"
    return "../data/simul_results/KS/{}_{}_n{}_{}".format(
        opt_tag,
        config["dataset_config"]["value_sampling"],
        config["n_agt"],
        exp_name,
    )


def _train_initial_value_nets(init_ds, config):
    """Train value_config["num_vnet"] value trainers on the initial-policy dataset.

    The initial policy is `init_ds.k_policy_bchmk` (policy type "pde") when
    config["init_with_bchmk"] is truthy, otherwise `init_ds.c_policy_const_share`
    (policy type "nn_share").

    Returns:
        list of trained ValueTrainer instances.
    """
    value_config = config["value_config"]
    if config["init_with_bchmk"]:
        init_policy = init_ds.k_policy_bchmk
        policy_type = "pde"
        # TODO[juju]: change all "pde" to "conventional"
    else:
        init_policy = init_ds.c_policy_const_share
        policy_type = "nn_share"
    train_vds, valid_vds = init_ds.get_valuedataset(init_policy, policy_type, update_init=False)
    vtrainers = [ValueTrainer(config) for _ in range(value_config["num_vnet"])]
    for vtr in vtrainers:
        vtr.train(train_vds, valid_vds, value_config["num_epoch"], value_config["batch_size"])
    return vtrainers


def main(argv):
    """Run the KS training pipeline configured by --config_path.

    Steps: load the JSON config, create the output directory and save the
    config there, train the initial value networks, train the policy with
    KSPolicyTrainer, then save the config again plus the value models
    (value{i}.h5) and the policy model (policy.h5). The elapsed time is
    reported via print_elapsedtime.
    """
    del argv  # Unused: all inputs come from flags.
    config = _load_config(FLAGS.config_path)
    print("Solving the problem based on the config path {}".format(FLAGS.config_path))
    mparam = KSParam(config["n_agt"], config["beta"], config["mats_path"])
    # save config at the beginning for checking
    model_path = _build_model_path(config, FLAGS.exp_name)
    os.makedirs(model_path, exist_ok=True)
    _save_config(config, os.path.join(model_path, "config_beg.json"))

    start_time = time.monotonic()

    # initial value training
    init_ds = KSInitDataSet(mparam, config)
    vtrainers = _train_initial_value_nets(init_ds, config)

    # iterative policy and value training
    policy_config = config["policy_config"]
    ptrainer = KSPolicyTrainer(vtrainers, init_ds)
    ptrainer.train(policy_config["num_step"], policy_config["batch_size"])

    # save config and models
    _save_config(config, os.path.join(model_path, "config.json"))
    for i, vtr in enumerate(vtrainers):
        vtr.save_model(os.path.join(model_path, "value{}.h5".format(i)))
    ptrainer.save_model(os.path.join(model_path, "policy.h5"))

    end_time = time.monotonic()
    print_elapsedtime(end_time - start_time)

if __name__ == '__main__':
    # Hand control to absl, which parses the flags and then invokes main().
    app.run(main=main)