In [None]:
!pip install stable-baselines3

In [10]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import PPO
import torch as th

In [16]:
def binatodeci(binary):
        return sum(val*(2**idx) for idx, val in enumerate(reversed(binary)))

In [39]:
primes = {}
for i in range(1, int(binatodeci(np.ones(16)))+1):
    # Iterate from 2 to n // 2
    is_prime=1
    for j in range(2, (i//2)+1):
        # If num is divisible by any number between
        # 2 and n / 2, it is not prime
        if (i % j) == 0:
            is_prime = 0
            break
    primes[i] = is_prime

primes[1] = 1

In [41]:
class PrimeEnv(gym.Env):
    def __init__(self, max_bits=16):
        self.observation_space = spaces.MultiBinary(max_bits)
        self.action_space = spaces.MultiBinary(max_bits)
    
    def reset(self, **kwargs):
        start_state = np.random.choice(np.array(list(primes.keys())))
        self.state_deci = start_state
        start_state = [int(i) for i in list('{0:0b}'.format(start_state))]
        start_state = [0] * (16 - len(start_state)) + start_state
        start_state = np.array(start_state)
        self.state_bin = start_state
        return start_state, {}
    
    def step(self, action):
        next_state_bin = action
        next_state_decimal = binatodeci(next_state_bin)
        done = False
        if next_state_decimal <= self.state_deci:
            done = True
        
        reward = primes[next_state_decimal]
        self.state_deci = next_state_decimal,
        self.state_bin = next_state_bin
        return next_state_bin, reward, done, False, {}
 

In [42]:
model = PPO("MlpPolicy", env = gym.wrappers.time_limit.TimeLimit(PrimeEnv(),500), verbose = 1)
model.learn(1e6)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1.91     |
|    ep_rew_mean     | 0.22     |
| time/              |          |
|    fps             | 459      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 2048     |
---------------------------------
--------------------------------------
| rollout/                |          |
|    ep_len_mean          | 1.88     |
|    ep_rew_mean          | 0.17     |
| time/                   |          |
|    fps                  | 421      |
|    iterations           | 2        |
|    time_elapsed         | 9        |
|    total_timesteps      | 4096     |
| train/                  |          |
|    approx_kl            | 0.014381 |
|    clip_fraction        | 0.225    |
|    clip_range           | 0.2      |
|    entropy_loss         | -11.1    |
|    explained_varia

<stable_baselines3.ppo.ppo.PPO at 0x79a798140c10>

In [43]:
#Generate primes

env = PrimeEnv()
longest_sequence = 0
for ep in range(100):
    s, _ = env.reset()
    # print("Starting number: {}".format(binatodeci(s)))
    done = False
    l = 0
    r = 1
    while not done or r ==0:
        a = model.predict(s, deterministic=True)[0]
        s, r, done, _, _ = env.step(a)
        l+=1
        # print("Next number predicted: {}".format(binatodeci(s)))
        # print("Is prime: {}".format(bool(r)))
    longest_sequence = max(longest_sequence, l)
    print(longest_sequence)

11
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13


In [44]:
s, _ = env.reset()
env.state_bin = np.zeros(16)
env.state_bin[-1] = 1
env.state_deci = 1
s = env.state_bin
# print("Starting number: {}".format(binatodeci(s)))
done = False
l = 0
while not done:
    a = model.predict(s, deterministic=True)[0]
    s, r, done, _, _ = env.step(a)
    print(binatodeci(s))
    l+=1
    # print("Next number predicted: {}".format(binatodeci(s)))
    # print("Is prime: {}".format(bool(r)))
longest_sequence = max(longest_sequence, l)
print(longest_sequence)

40847.0
51439.0
57389.0
57397.0
61493.0
63589.0
64621.0
65167.0
65293.0
65353.0
65357.0
65423.0
65353.0
13
