In [1]:
import random
from lang import load_data_int_seq
from utils import eq_encoder, is_eq_valid,normalize_0_1
from typing import List
from models.rl.env import IntegerSequenceEnv, get_current_position, encode_with_lang, decode_with_lang, eq_to_seq
from models.rl.training import train_env
from models.rl.agents.agent_deepqn.agent import DeepQAgent, DeepQAgentConfig
import gym, ray
import numpy as np
from stable_baselines3 import PPO


# Automatically reload edited project modules (lang, utils, models.*) before
# each cell executes, so changes to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Magnitude of the (negative) reward returned by evaluate_candidate_eq for
# equations that are invalid or that generate an all-zero sequence.
MAX_PENALTY_MAGNITUDE = 999.0

In [2]:
# Load the language objects and the data splits. Per the names, this is the
# training set plus test inputs/targets.
# NOTE(review): input_lang/output_lang are presumably token vocabularies used
# to encode/decode equations — confirm in lang.load_data_int_seq.
output_lang, input_lang, train, X_test, y_test = load_data_int_seq()
# Restrict training to the first example only. The cell stays idempotent because
# the data is reloaded above before slicing.
train = train[:1]

In [3]:
def compare_sequences(
    output_sequence: List[int],
    target_sequence: List[int],
    *,
    max_score: float = 10.0,
    distance_weight: float = 100.0,
) -> float:
    """Score how closely a generated integer sequence matches a target sequence.

    The two sequences are stacked into a 2-row array and passed through
    ``normalize_0_1`` in a single call. The L1 distance (sum of absolute
    differences) between the two normalized rows is then turned into a score:

        score = max_score - distance_weight * distance

    Identical sequences therefore score ``max_score`` (10.0 by default). The
    score decreases linearly as the sequences diverge and can become negative.

    Args:
        output_sequence: sequence produced by the candidate equation.
        target_sequence: ground-truth sequence; must have the same length.
        max_score: score given to a perfect match (keyword-only).
        distance_weight: penalty per unit of normalized L1 distance
            (keyword-only).

    Returns:
        The similarity score (higher is better).

    Raises:
        AssertionError: if the two sequences have different lengths.
    """
    if len(output_sequence) != len(target_sequence):
        raise AssertionError(
            "sequence size don't match: "
            + ','.join(str(e) for e in output_sequence)
            + " | "
            + ','.join(str(e) for e in target_sequence)
        )

    # Normalize both rows in one call so they go through the same transform.
    # NOTE(review): assumes normalize_0_1 scales the 2-row array jointly rather
    # than row by row — confirm in utils.
    combined_seq = np.vstack([output_sequence, target_sequence])
    norm_comb_seq = np.asarray(normalize_0_1(combined_seq), dtype=float)
    norm_output_seq = norm_comb_seq[0]
    norm_target_seq = norm_comb_seq[1]

    # Vectorized L1 distance between the normalized sequences.
    magnitude = float(np.abs(norm_target_seq - norm_output_seq).sum())

    return max_score - magnitude * distance_weight



In [4]:
# Sanity check: identical sequences have zero distance, so they must receive
# the maximum score of 10.0.
identical_score = compare_sequences([1, 2, 3, 100], [1, 2, 3, 100])
assert identical_score == 10.0, identical_score
identical_score

10.0

In [5]:
def evaluate_candidate_eq(candidate: str, int_seq: List[int]) -> float:
    """Reward a candidate equation by how well it reproduces ``int_seq``.

    Args:
        candidate: equation string produced by the agent.
        int_seq: target integer sequence the equation should generate.

    Returns:
        ``-MAX_PENALTY_MAGNITUDE`` if ``candidate`` is not a valid equation or
        generates an all-zero sequence. Otherwise, the score returned by
        ``compare_sequences`` for the generated vs. target sequence.
    """
    # Treat any falsy validity result as invalid, not only the literal False.
    if not is_eq_valid(candidate):
        return -MAX_PENALTY_MAGNITUDE

    # Generate as many terms as the target has. compare_sequences raises on a
    # length mismatch, so a hard-coded length only works for 9-term targets.
    output_sequence = eq_to_seq(candidate, len(int_seq))

    # An all-zero output is considered degenerate and penalized like an
    # invalid equation.
    if np.count_nonzero(output_sequence) == 0:
        return -MAX_PENALTY_MAGNITUDE

    return compare_sequences(output_sequence, int_seq)

In [6]:
# Build the RL environment over the (single-sequence) training set.
# evaluate_candidate_eq is supplied as the reward function for candidate equations.
# NOTE(review): output_length should match the length of the target sequences
# (the targets printed below have 9 terms) — confirm.
env = IntegerSequenceEnv({"data": train, "output_length": 9, "input_lang": input_lang, "output_lang": output_lang, "evaluate": evaluate_candidate_eq})

In [7]:
# Manually take action 5 to inspect the (observation, reward, done, info) tuple.
# The observation appears to be the 9 output slots (-1 = not yet filled)
# followed by the 9 target values.
# NOTE(review): no explicit env.reset() precedes this step; it relies on the
# constructor leaving the env in a usable state — confirm.
env.step(5)

([7, -1, -1, -1, -1, -1, -1, -1, -1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 0, False, {})

In [8]:
# Take a second action (1); the next output slot of the observation is filled.
env.step(1)

([7, 3, -1, -1, -1, -1, -1, -1, -1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 0, False, {})

In [9]:
# Reset the environment; all output slots return to -1 while the target stays.
env.reset()

[-1, -1, -1, -1, -1, -1, -1, -1, -1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

In [10]:
# Deep-Q agent sized to the environment. state_size=18 matches the length of
# the observation shown above, and the action count comes from the env.
# NOTE(review): `agent` is not used by any later visible cell (training below
# uses PPO) — consider removing or wiring it up.
agent = DeepQAgent(state_size = 18, action_size = env.action_space.n)

In [11]:
# Train a PPO agent with an MLP policy on the custom environment. Stable-Baselines3
# wraps the env in Monitor/DummyVecEnv itself. An explicit seed makes the
# stochastic training run reproducible across kernel restarts.
PPO_SEED = 0
PPO_TOTAL_TIMESTEPS = 100_000

model = PPO("MlpPolicy", env, verbose=1, seed=PPO_SEED)
model.learn(total_timesteps=PPO_TOTAL_TIMESTEPS)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 2.52     |
|    ep_rew_mean     | -999     |
| time/              |          |
|    fps             | 2100     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 2.7         |
|    ep_rew_mean          | -999        |
| time/                   |             |
|    fps                  | 1508        |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.010321575 |
|    clip_fraction        | 0.139       |
|    clip_range           | 0.2         |
|    entropy_loss   

<stable_baselines3.ppo.ppo.PPO at 0x13e3d4250>

In [14]:
# Roll out the trained policy greedily and render each step. When an episode
# finishes, the environment is reset and the rollout continues.
N_ROLLOUT_STEPS = 10

obs = env.reset()
for _step_idx in range(N_ROLLOUT_STEPS):
    # deterministic=True picks the policy's most likely action.
    action, _policy_state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()

# NOTE(review): this is the same env object the PPO model was built on; closing
# it may prevent further training/prediction in this kernel — confirm.
env.close()

([12, -1, -1, -1, -1, -1, -1, -1, -1], [2, 3, 4, 5, 6, 7, 8, 9, 10])
([12, 6, -1, -1, -1, -1, -1, -1, -1], [2, 3, 4, 5, 6, 7, 8, 9, 10])
([12, 6, 12, -1, -1, -1, -1, -1, -1], [2, 3, 4, 5, 6, 7, 8, 9, 10])
([12, 6, 12, 6, -1, -1, -1, -1, -1], [2, 3, 4, 5, 6, 7, 8, 9, 10])
([12, 6, 12, 6, 12, -1, -1, -1, -1], [2, 3, 4, 5, 6, 7, 8, 9, 10])
([12, 6, 12, 6, 12, 6, -1, -1, -1], [2, 3, 4, 5, 6, 7, 8, 9, 10])
([12, 6, 12, 6, 12, 6, 12, -1, -1], [2, 3, 4, 5, 6, 7, 8, 9, 10])
([12, 6, 12, 6, 12, 6, 12, 6, -1], [2, 3, 4, 5, 6, 7, 8, 9, 10])
([12, 6, 12, 6, 12, 6, 12, 6, -1], [2, 3, 4, 5, 6, 7, 8, 9, 10])
([12, -1, -1, -1, -1, -1, -1, -1, -1], [2, 3, 4, 5, 6, 7, 8, 9, 10])
