In [None]:
!pip install stable-baselines3

In [10]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import PPO
import torch as th

In [16]:
def binatodeci(binary):
        return sum(val*(2**idx) for idx, val in enumerate(reversed(binary)))

In [28]:
primes = {}
for i in range(1, int(binatodeci(np.ones(16)))+1):
    # Iterate from 2 to n // 2
    is_prime=1
    for j in range(2, (i//2)+1):
        # If num is divisible by any number between
        # 2 and n / 2, it is not prime
        if (i % j) == 0:
            is_prime = 0
            break
    primes[i] = is_prime

primes[1] = 1

In [30]:
class PrimeEnv(gym.Env):
    def __init__(self, max_bits=16):
        self.observation_space = spaces.MultiBinary(max_bits)
        self.action_space = spaces.MultiBinary(max_bits)
    
    def reset(self, **kwargs):
        start_state = np.random.choice(np.array(list(primes.keys())))
        self.state_deci = start_state
        start_state = [int(i) for i in list('{0:0b}'.format(start_state))]
        start_state = [0] * (16 - len(start_state)) + start_state
        start_state = np.array(start_state)
        self.state_bin = start_state
        return start_state, {}
    
    def step(self, action):
        next_state_bin = action
        next_state_decimal = binatodeci(next_state_bin)
        done = False
        if next_state_decimal < self.state_deci:
            done = True
        
        reward = primes[next_state_decimal]
        self.state_deci = next_state_decimal,
        self.state_bin = next_state_bin
        return next_state_bin, reward, done, False, {}
 

In [33]:
model = PPO("MlpPolicy", env = gym.wrappers.time_limit.TimeLimit(PrimeEnv(),500), verbose = 1)
model.learn(1e6)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1.76     |
|    ep_rew_mean     | 0.18     |
| time/              |          |
|    fps             | 334      |
|    iterations      | 1        |
|    time_elapsed    | 6        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.75        |
|    ep_rew_mean          | 0.08        |
| time/                   |             |
|    fps                  | 307         |
|    iterations           | 2           |
|    time_elapsed         | 13          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.014983747 |
|    clip_fraction        | 0.253       |
|    clip_range           | 0.2         |
|    entropy_loss   

In [31]:
#Generate primes

env = PrimeEnv()
for ep in range(100):
    s, _ = env.reset()
    print("Starting number: {}".format(binatodeci(s)))
    done = False
    while not done:
        a = model.predict(s, deterministic=True)[0]
        s, r, done, _, _ = env.step(a)
        print("Next number predicted: {}".format(binatodeci(s)))
        print("Is prime: {}".format(bool(r)))

Starting number: 12529
Next number predicted: 20341.0
Is prime: True
Next number predicted: 32911.0
Is prime: True
Next number predicted: 41077.0
Is prime: True
Next number predicted: 41333.0
Is prime: True
Next number predicted: 41809.0
Is prime: True
Next number predicted: 51061.0
Is prime: True
Next number predicted: 57373.0
Is prime: True
Next number predicted: 61469.0
Is prime: True
Next number predicted: 61967.0
Is prime: True
Next number predicted: 64013.0
Is prime: True
Next number predicted: 65141.0
Is prime: True
Next number predicted: 65293.0
Is prime: True
Next number predicted: 65309.0
Is prime: True
Next number predicted: 65371.0
Is prime: True
Next number predicted: 65407.0
Is prime: True
Next number predicted: 65419.0
Is prime: True
Next number predicted: 65437.0
Is prime: True
Next number predicted: 65497.0
Is prime: True
Next number predicted: 65455.0
Is prime: False
Starting number: 4389
Next number predicted: 20341.0
Is prime: True
Next number predicted: 32911.0
Is 