In [1]:
import random
from lang import load_data_int_seq
from utils import eq_encoder, is_eq_valid,normalize_0_1
from typing import List
from models.rl.env import IntegerSequenceEnv, get_current_position, encode_with_lang, decode_with_lang, eq_to_seq
import gym, ray
from ray.rllib.agents import ppo
import numpy as np

%load_ext autoreload
%autoreload 2

# Magnitude of the penalty reward; evaluate_candidate_eq returns its negation
# for candidates it rejects (invalid equation or an all-zero generated sequence).
MAX_PENALTY_MAGNITUDE = 999.0

In [2]:
# Start the local Ray runtime. ignore_reinit_error=True keeps this cell
# re-runnable: a second call is tolerated instead of raising.
ray.init(ignore_reinit_error=True)


2021-07-08 10:57:39,304	INFO services.py:1272 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


{'node_ip_address': '192.168.2.103',
 'raylet_ip_address': '192.168.2.103',
 'redis_address': '192.168.2.103:6379',
 'object_store_address': '/tmp/ray/session_2021-07-08_10-57-37_623152_15228/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2021-07-08_10-57-37_623152_15228/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2021-07-08_10-57-37_623152_15228',
 'metrics_export_port': 56333,
 'node_id': '319a2704948e184912874f2900a4f3f1af518787c9fa7b3d746c6813'}

In [3]:
# Load the language objects and data splits used by the environment below.
# NOTE(review): the unpacking order (output_lang before input_lang) is taken
# as-is from this cell -- confirm it matches load_data_int_seq's return order.
output_lang, input_lang, train, X_test, y_test = load_data_int_seq()

In [5]:
def compare_sequences(output_sequence: List[int], target_sequence: List[int]) -> float:
    """Score how closely ``output_sequence`` matches ``target_sequence``.

    The two sequences are stacked into a single 2-row array and passed
    together through ``normalize_0_1`` (presumably so both rows share one
    scale -- confirm against ``utils.normalize_0_1``). The absolute
    element-wise differences between the two normalised rows are summed and
    converted into a score of ``10 - 100 * total_difference``.

    Args:
        output_sequence: Sequence produced by a candidate equation.
        target_sequence: Reference sequence to compare against.

    Returns:
        The score described above; larger means a closer match.

    Raises:
        AssertionError: If the two sequences differ in length.
    """
    if len(output_sequence) != len(target_sequence):
        rendered_output = ','.join(str(e) for e in output_sequence)
        rendered_target = ','.join(str(e) for e in target_sequence)
        raise AssertionError(
            "sequence size don't match: " + rendered_output + " | " + rendered_target
        )

    # Row 0 is the candidate output, row 1 is the target.
    normalised = normalize_0_1(np.vstack([output_sequence, target_sequence]))
    norm_output_seq, norm_target_seq = normalised[0], normalised[1]

    # Plain sum of absolute differences (not squared, not averaged over length).
    magnitude: float = 0.0
    for target_value, output_value in zip(norm_target_seq, norm_output_seq):
        magnitude = magnitude + abs(target_value - output_value)

    return 10 - (magnitude * 100)



In [6]:
# Sanity check: identical sequences have zero total difference, so the score is 10.
compare_sequences([1,2, 3, 100], [1,2,3, 100])

10.0

In [7]:
def evaluate_candidate_eq(candidate: str, int_seq: List[int]) -> float:
    """Turn a candidate equation string into a reward against ``int_seq``.

    Args:
        candidate: Equation text to evaluate.
        int_seq: Target integer sequence the equation should reproduce.

    Returns:
        ``-MAX_PENALTY_MAGNITUDE`` when ``is_eq_valid(candidate)`` compares
        equal to ``False`` or when the generated sequence contains no
        non-zero entry; otherwise the ``compare_sequences`` score of the
        generated sequence against ``int_seq``.

    Raises:
        AssertionError: Propagated from ``compare_sequences`` when the
            generated sequence and ``int_seq`` differ in length.
    """
    rejection_reward = -MAX_PENALTY_MAGNITUDE

    # Explicit `== False` is kept on purpose: it only rejects on a False-equal
    # result, which differs from `not ...` if the helper returns e.g. None.
    if is_eq_valid(candidate) == False:
        return rejection_reward

    # The 9 is passed straight to eq_to_seq -- presumably the number of terms
    # to generate (it matches the env's "output_length"); TODO confirm.
    generated_sequence = eq_to_seq(candidate, 9)

    # An all-zero sequence is treated as a rejected candidate.
    if np.count_nonzero(generated_sequence) >= 1:
        return compare_sequences(generated_sequence, int_seq)
    return rejection_reward

In [8]:
# Build a single environment instance in this process for manual inspection,
# using the same settings that are later handed to the PPO trainer.
env = IntegerSequenceEnv(
    {
        "data": train,
        "output_length": 9,
        "input_lang": input_lang,
        "output_lang": output_lang,
        "evaluate": evaluate_candidate_eq,
    }
)

[32 33 34 35 36 37 38 39 40]


In [22]:
# Take one manual step with action 5; the displayed tuple looks like the Gym
# (observation, reward, done, info) convention -- confirm against IntegerSequenceEnv.step.
env.step(5)

(([7, -1, -1, -1, -1, -1, -1, -1, -1], [3, 43, 4, 45, 5, 46, 6, 20, 7]),
 0,
 False,
 {})

In [24]:
# Take another manual step with action 1.
# NOTE(review): the execution counter skips a cell before this one, so the
# environment state seen here may depend on a step that is no longer in the
# notebook -- re-run from a fresh kernel to confirm the displayed output.
env.step(1)

(([7, 3, 3, -1, -1, -1, -1, -1, -1], [3, 43, 4, 45, 5, 46, 6, 20, 7]),
 -999.0,
 True,
 {})

In [25]:
# Start a new episode and display the initial observation.
env.reset()

[15 16 17 18 19 20 21 22 23]


([-1, -1, -1, -1, -1, -1, -1, -1, -1], [120, 5, 19, 46, 65, 6, 121, 20, 12])

In [26]:
# Create the PPO trainer. RLlib instantiates IntegerSequenceEnv itself inside
# each rollout worker, passing it the "env_config" mapping below.
trainer = ppo.PPOTrainer(
    env=IntegerSequenceEnv,
    config=dict(
        env_config=dict(
            data=train,
            output_length=9,
            input_lang=input_lang,
            output_lang=output_lang,
            evaluate=evaluate_candidate_eq,
        ),
        num_envs_per_worker=1,
        # "train_batch_size" is not overridden (a 10000 override is disabled):
        # "train_batch_size": 10000,
        num_workers=4,
    ),
)

[2m[36m(pid=15790)[0m [23 24 25 26 27 28 29 30 31]
2021-07-08 11:34:57,744	INFO trainable.py:101 -- Trainable.setup took 87.101 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [27]:

# Training schedule. Previously the loop ran 10 iterations but checkpointed on
# `i % 100 == 0`, which only ever fired at i == 0 -- i.e. it saved the
# least-trained policy and never the final one.
NUM_TRAIN_ITERATIONS = 10
CHECKPOINT_INTERVAL = 100  # also always checkpoint after the final iteration

for i in range(NUM_TRAIN_ITERATIONS):
    # Perform one iteration of training the policy with PPO.
    result = trainer.train()
    print(result["episode_reward_mean"])

    completed_iterations = i + 1
    is_last_iteration = completed_iterations == NUM_TRAIN_ITERATIONS
    if completed_iterations % CHECKPOINT_INTERVAL == 0 or is_last_iteration:
        checkpoint = trainer.save()
        print("checkpoint saved at", checkpoint)

-8 -10 -12 -14]
[2m[36m(pid=15790)[0m [-7 -6 -5 -4 -3 -2 -1  0  1]
[2m[36m(pid=15790)[0m [ 0 -1 -2 -3 -4 -5 -6 -7 -8]
[2m[36m(pid=15790)[0m [  1  -5 -11 -17 -23 -29 -35 -41 -47]
[2m[36m(pid=15790)[0m [ 61 125 189 253 317 381 445 509 573]
[2m[36m(pid=15790)[0m [14 13 12 11 10  9  8  7  6]
[2m[36m(pid=15790)[0m [16 17 18 19 20 21 22 23 24]
[2m[36m(pid=15790)[0m [ 8  9 10 11 12 13 14 15 16]
[2m[36m(pid=15790)[0m [ 5  7  9 11 13 15 17 19 21]
[2m[36m(pid=15790)[0m [ -1  11  31  59  95 139 191 251 319]
[2m[36m(pid=15790)[0m [ 28  51  74  97 120 143 166 189 212]
[2m[36m(pid=15790)[0m [  3   0  -3  -6  -9 -12 -15 -18 -21]
[2m[36m(pid=15785)[0m [ 22  39  56  73  90 107 124 141 158]
[2m[36m(pid=15785)[0m [ -8  -9 -10 -11 -12 -13 -14 -15 -16]
[2m[36m(pid=15785)[0m [ 30  24  18  12   6   0  -6 -12 -18]
[2m[36m(pid=15785)[0m [17 19 21 23 25 27 29 31 33]
[2m[36m(pid=15785)[0m [  -5  -31  -75 -137 -217 -315 -431 -565 -717]
[2m[36m(pid=15785)[0m [8 7 

KeyboardInterrupt: 

In [None]:
# Run one additional PPO training iteration and keep its metrics for inspection.
result = trainer.train()

In [None]:
# Display the full metrics returned by the most recent trainer.train() call.
result