In [None]:
%pip install stable-baselines3[extra]
%pip install 'shimmy>=2.0'

Collecting stable-baselines3[extra]
  Downloading stable_baselines3-2.5.0-py3-none-any.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3.0,>=2.3->stable-baselines3[extra])
  Downloading nvidia_cublas_cu12-12.4.5.8-py

In [None]:
import random

# Problem size: four ±1 sequences (x, y, z, w), each of length `length`, whose
# element sums are ±sums[k]. For a ±1 sequence a, (Σa)² = length + 2·Σ_s NPAF_a(s),
# so a zero combined NPAF at every shift requires Σ sums[k]² = 4 * length
# (here 1 + 1 + 9 + 9 = 20 = 4 * 5).
length, sums = 5, [1, 1, 3, 3]

# Random starting state: ±1 sequences (x, y, z, w) with prescribed element sums.
def generate_sequences(length, sums, rng=None):
    """Draw random ±1 sequences of a given length with prescribed element sums.

    The target sums are shuffled and each one is negated with probability 1/2.
    Then, for every target sum ``t``, a sequence with ``(length - t) // 2``
    entries equal to -1 (the rest +1) is built and shuffled.

    Args:
        length: length of every sequence.
        sums: target sums, e.g. ``[1, 1, 3, 3]``. The caller's list is not
            modified.
        rng: optional ``random.Random`` instance for reproducible draws. It
            defaults to the global ``random`` module.

    Returns:
        A list of ``len(sums)`` lists of ±1 ints.

    Raises:
        ValueError: if some ``|t| > length``, or ``t`` and ``length`` differ in
            parity, because no ±1 sequence of that length has that sum.
    """
    if rng is None:
        rng = random

    for target in sums:
        # Sum of a ±1 sequence = length - 2 * (#negatives): same parity, |sum| <= length.
        if abs(target) > length or (length - target) % 2:
            raise ValueError(f"no ±1 sequence of length {length} has sum {target}")

    # Work on a copy so repeated calls (e.g. every env reset) don't mutate the caller's list.
    signed = list(sums)
    rng.shuffle(signed)
    for k in range(len(signed)):
        if rng.randint(0, 1) == 0:
            signed[k] *= -1

    seqs = []
    for target in signed:
        num_neg = (length - target) // 2
        cur_seq = [-1] * num_neg + [1] * (length - num_neg)
        rng.shuffle(cur_seq)
        seqs.append(cur_seq)
    return seqs

# Preview one randomly drawn starting state.
generate_sequences(length=length, sums=sums)

[[1, -1, 1, -1, 1],
 [1, -1, -1, -1, -1],
 [1, -1, -1, -1, -1],
 [-1, 1, 1, 1, -1]]

In [None]:
import numpy as np
import gym
from gym import spaces
import copy
import random

"""POTENTIAL IMPROVEMENTS:
- weight the equations differently
- don't penalize repeated actions
- make negative rewards less harsh
- allow for floating-point error when checking for a Turyn solution
- make choosing an action a two-step process?
- longer episodes
- partial thresholds for rewards

TO-DO:
- action masking
- log the training process
"""


class TurynEnv(gym.Env):
    """Gym env searching for four ±1 sequences whose combined NPAF vanishes.

    The state is four length-``length`` ±1 sequences drawn by
    ``generate_sequences`` from the notebook-level ``sums``. Each action swaps
    positions ``i < j`` inside one sequence, so every sequence keeps its element
    sum. The reward is built as follows:

    - The base reward is the decrease of the Euclidean norm of the combined
      non-periodic autocorrelation (NPAF) vector, divided by ``length``.
    - Swapping two equal entries (a no-op) gives -5.
    - Repeating the previous action costs an extra -5.
    - Reaching norm 0 gives +10 and ends the episode.

    NOTE(review): classical Turyn-type sequences use lengths (n, n, n, n - 1)
    and weights (1, 1, 2, 2) in the NPAF sum. This env uses four length-n
    sequences weighted equally. Confirm which family is intended.

    Uses the classic ``gym`` API: ``reset()`` returns only the observation and
    ``step()`` returns ``(obs, reward, done, info)``.
    """

    def __init__(self, length, max_steps=30):
        """
        Args:
            length: length of each of the four sequences.
            max_steps: step limit per episode (default 30).
        """
        super().__init__()
        self.length = length
        self.ep_lengths = []    # step count of every finished episode
        self.final_npafs = []   # NPAF norm at the end of every finished episode
        self.seq_found = 0      # number of episodes that reached NPAF norm 0
        self.current_step = 0
        self.max_steps = max_steps
        # `sums` is the notebook-level list of target sums.
        self.sequence = generate_sequences(self.length, sums)
        self.old_npaf = self.calculate_autocorrelation()
        self.action_space = spaces.Discrete(self.count_actions())
        self.prev_action = None
        self.observation_space = spaces.Box(low=-1, high=1, shape=(4, length), dtype=np.int8)

    def count_actions(self):
        """Number of actions: one per (sequence, pair i < j)."""
        return 4 * (self.length * (self.length - 1) // 2)

    def _observation(self):
        """Current sequences as an int8 array, matching ``observation_space``'s dtype."""
        return np.array(self.sequence, dtype=np.int8)

    def step(self, action):
        """Apply one swap action and return ``(obs, reward, done, info)``.

        ``info["TimeLimit.truncated"]`` is set when the episode ends only because
        ``max_steps`` was reached, so value bootstrapping can treat it as a
        time-limit cut rather than a terminal state.
        """
        # Actions may arrive as numpy scalars/0-d arrays (e.g. from model.predict).
        action = int(np.asarray(action).item())
        seq_num, i, j = self.decode_action(action)
        self.current_step += 1

        seq = self.sequence[seq_num]
        if seq[i] == seq[j]:
            # Swapping equal entries leaves the state unchanged: penalize the no-op.
            reward = -5
        else:
            seq[i], seq[j] = seq[j], seq[i]
            new_npaf = self.calculate_autocorrelation()
            reward = (self.old_npaf - new_npaf) / self.length
            self.old_npaf = new_npaf

        if self.prev_action == action:
            reward -= 5
        self.prev_action = action

        solved = self.old_npaf == 0
        if solved:
            self.seq_found += 1
            reward += 10

        truncated = not solved and self.current_step >= self.max_steps
        done = solved or truncated
        info = {}
        if truncated:
            info["TimeLimit.truncated"] = True
        if done:
            # Record statistics here (not in reset) so only finished episodes are counted.
            self.ep_lengths.append(self.current_step)
            self.final_npafs.append(self.old_npaf)

        return self._observation(), reward, done, info

    def reset(self):
        """Draw fresh sequences, start a new episode and return the observation."""
        self.sequence = generate_sequences(self.length, sums)
        self.old_npaf = self.calculate_autocorrelation()
        self.current_step = 0
        self.prev_action = None
        return self._observation()

    def calculate_autocorrelation(self):
        """Euclidean norm of the combined NPAF vector over shifts 1..length-1.

        For each shift s, NPAF(s) is the sum over all sequences and positions i of
        a[i] * a[i + s]. For ±1 entries this product is +1 when the entries are
        equal and -1 otherwise. The result is sqrt(sum_s NPAF(s)^2). The sum of
        squares is an integer, so the result is exactly 0.0 only when every
        NPAF(s) is 0.
        """
        seqs = np.asarray(self.sequence, dtype=np.int64)
        squared_norm = 0
        for s in range(1, self.length):
            npaf = int(np.sum(seqs[:, :-s] * seqs[:, s:]))
            squared_norm += npaf * npaf
        return np.sqrt(squared_norm)

    def decode_action(self, action):
        """Map a flat action index to ``(seq_num, i, j)`` with ``i < j``.

        Indices enumerate sequences 0..3 in order. Within each sequence they
        enumerate the pairs (0, 1), (0, 2), ..., (length - 2, length - 1) in
        row-major order.

        Raises:
            ValueError: if ``action`` is outside ``[0, count_actions())``.
        """
        action = int(np.asarray(action).item())
        pairs_per_seq = self.length * (self.length - 1) // 2
        if not 0 <= action < 4 * pairs_per_seq:
            raise ValueError(f"action {action} out of range [0, {4 * pairs_per_seq})")
        seq_num, offset = divmod(action, pairs_per_seq)
        for i in range(self.length - 1):
            partners = self.length - 1 - i  # number of j with i < j < length
            if offset < partners:
                return seq_num, i, i + 1 + offset
            offset -= partners

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

from stable_baselines3.common.utils import set_random_seed

# Training configuration.
# NOTE(review): with learning_rate=1e-2 the logged update shows approx_kl ≈ 0.23
# and clip_fraction ≈ 0.71, i.e. very large policy steps. Consider something
# closer to SB3's PPO default (3e-4).
# Also previously tried (kept for reference): batch_size=64, n_steps=512, ent_coef=0.01.
SEED = 0
LEARNING_RATE = 1e-2
TOTAL_TIMESTEPS = 15_000

# Seed Python `random` (used by generate_sequences), NumPy and PyTorch.
# This must run before the env draws its first sequences.
set_random_seed(SEED)
env = TurynEnv(length)
model = PPO("MlpPolicy", env, verbose=1, learning_rate=LEARNING_RATE)
model.learn(total_timesteps=TOTAL_TIMESTEPS)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




---------------------------------
| rollout/           |          |
|    ep_len_mean     | 19.4     |
|    ep_rew_mean     | -44.8    |
| time/              |          |
|    fps             | 1082     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 20.3       |
|    ep_rew_mean          | -47.1      |
| time/                   |            |
|    fps                  | 752        |
|    iterations           | 2          |
|    time_elapsed         | 5          |
|    total_timesteps      | 4096       |
| train/                  |            |
|    approx_kl            | 0.22751957 |
|    clip_fraction        | 0.711      |
|    clip_range           | 0.2        |
|    entropy_loss         | -3.54      |
|    explained_variance   | -0.00297   |
|    learning_rate        | 0.01       |
|   

<stable_baselines3.ppo.ppo.PPO at 0x7dddc172cb90>

In [None]:
# Episode statistics recorded by the training env during model.learn().
# Any episode whose length equals max_steps is counted as unsolved, so an episode
# solved on exactly its last step is miscounted here (env.seq_found still counts it).
print("Last 20 episode lengths:", env.ep_lengths[-20:])
print("Number of episodes:", len(env.ep_lengths))
print("Did not find sequence:", env.ep_lengths.count(env.max_steps))
print("Found sequence:", env.seq_found)

[0, 18, 12, 30, 30, 2, 24, 14, 30, 6, 30, 3, 30, 30, 7, 13, 30, 2, 6, 13, 4, 5, 14, 30, 30, 1, 30, 2, 30, 30, 30, 29, 18, 21, 1, 1, 11, 30, 30, 27, 30, 9, 4, 30, 10, 18, 1, 15, 17, 30, 30, 30, 22, 30, 2, 5, 24, 4, 27, 30, 1, 2, 30, 20, 30, 11, 28, 19, 18, 1, 30, 2, 30, 30, 1, 22, 13, 13, 8, 29, 10, 29, 30, 30, 30, 8, 19, 30, 30, 14, 9, 30, 30, 22, 30, 7, 22, 30, 2, 30, 30, 6, 25, 30, 20, 11, 1, 3, 10, 30, 21, 2, 20, 30, 30, 6, 30, 30, 11, 12, 30, 11, 7, 30, 25, 30, 16, 30, 30, 30, 8, 12, 30, 30, 9, 18, 12, 4, 14, 30, 30, 1, 30, 13, 8, 30, 30, 22, 27, 23, 29, 30, 30, 14, 19, 1, 30, 19, 30, 30, 30, 5, 20, 2, 30, 30, 4, 30, 15, 3, 30, 30, 30, 9, 13, 30, 12, 30, 30, 1, 30, 30, 8, 1, 30, 3, 1, 11, 4, 12, 27, 29, 30, 30, 30, 6, 3, 13, 30, 30, 1, 30, 10, 30, 3, 30, 4, 30, 9, 30, 18, 6, 30, 15, 13, 16, 18, 12, 1, 10, 1, 5, 30, 30, 30, 16, 29, 1, 21, 10, 27, 30, 30, 20, 21, 30, 1, 30, 30, 30, 30, 1, 18, 1, 21, 30, 30, 1, 23, 1, 30, 30, 1, 11, 30, 30, 30, 30, 30, 19, 1, 7, 2, 30, 30, 25, 18, 30,

In [None]:
# Roll out the trained policy greedily for max_steps steps. A separate env is used
# so these evaluation steps are not added to the training env's ep_lengths / seq_found.
eval_env = TurynEnv(length)
obs = eval_env.reset()
for _ in range(eval_env.max_steps):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = eval_env.step(action)
    print("Reward:", reward)
    print("NPAF:", eval_env.old_npaf)
    if done:
        obs = eval_env.reset()


(4, 5)
Reward: 1.1587594532111236
NPAF: 6.324555320336759
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
(4, 5)
Reward: -5
NPAF: 4.0
