In [None]:
!pip install stable-baselines3

In [10]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import PPO
import torch as th

In [16]:
def binatodeci(binary):
        return sum(val*(2**idx) for idx, val in enumerate(reversed(binary)))

In [28]:
primes = {}
for i in range(1, int(binatodeci(np.ones(16)))+1):
    # Iterate from 2 to n // 2
    is_prime=1
    for j in range(2, (i//2)+1):
        # If num is divisible by any number between
        # 2 and n / 2, it is not prime
        if (i % j) == 0:
            is_prime = 0
            break
    primes[i] = is_prime

primes[1] = 1

In [29]:
primes

{1: 1,
 2: 1,
 3: 1,
 4: 0,
 5: 1,
 6: 0,
 7: 1,
 8: 0,
 9: 0,
 10: 0,
 11: 1,
 12: 0,
 13: 1,
 14: 0,
 15: 0,
 16: 0,
 17: 1,
 18: 0,
 19: 1,
 20: 0,
 21: 0,
 22: 0,
 23: 1,
 24: 0,
 25: 0,
 26: 0,
 27: 0,
 28: 0,
 29: 1,
 30: 0,
 31: 1,
 32: 0,
 33: 0,
 34: 0,
 35: 0,
 36: 0,
 37: 1,
 38: 0,
 39: 0,
 40: 0,
 41: 1,
 42: 0,
 43: 1,
 44: 0,
 45: 0,
 46: 0,
 47: 1,
 48: 0,
 49: 0,
 50: 0,
 51: 0,
 52: 0,
 53: 1,
 54: 0,
 55: 0,
 56: 0,
 57: 0,
 58: 0,
 59: 1,
 60: 0,
 61: 1,
 62: 0,
 63: 0,
 64: 0,
 65: 0,
 66: 0,
 67: 1,
 68: 0,
 69: 0,
 70: 0,
 71: 1,
 72: 0,
 73: 1,
 74: 0,
 75: 0,
 76: 0,
 77: 0,
 78: 0,
 79: 1,
 80: 0,
 81: 0,
 82: 0,
 83: 1,
 84: 0,
 85: 0,
 86: 0,
 87: 0,
 88: 0,
 89: 1,
 90: 0,
 91: 0,
 92: 0,
 93: 0,
 94: 0,
 95: 0,
 96: 0,
 97: 1,
 98: 0,
 99: 0,
 100: 0,
 101: 1,
 102: 0,
 103: 1,
 104: 0,
 105: 0,
 106: 0,
 107: 1,
 108: 0,
 109: 1,
 110: 0,
 111: 0,
 112: 0,
 113: 1,
 114: 0,
 115: 0,
 116: 0,
 117: 0,
 118: 0,
 119: 0,
 120: 0,
 121: 0,
 122: 0,
 123: 0,
 

In [30]:
class PrimeEnv(gym.Env):
    def __init__(self, max_bits=16):
        self.observation_space = spaces.MultiBinary(max_bits)
        self.action_space = spaces.MultiBinary(max_bits)
    
    def reset(self, **kwargs):
        start_state = np.random.choice(np.array(list(primes.keys())))
        self.state_deci = start_state
        start_state = [int(i) for i in list('{0:0b}'.format(start_state))]
        start_state = [0] * (16 - len(start_state)) + start_state
        start_state = np.array(start_state)
        self.state_bin = start_state
        return start_state, {}
    
    def step(self, action):
        next_state_bin = action
        next_state_decimal = binatodeci(next_state_bin)
        done = False
        if next_state_decimal < self.state_deci:
            done = True
        
        reward = primes[next_state_decimal]
        self.state_deci = next_state_decimal,
        self.state_bin = next_state_bin
        return next_state_bin, reward, done, False, {}
 

In [32]:
model = PPO("MlpPolicy", env = gym.wrappers.time_limit.TimeLimit(PrimeEnv(),500), verbose = 1, policy_kwargs=dict(activation_fn=th.nn.Tanh,
                     net_arch=dict(pi=[128,128,128], vf=[128,128,128])))
model.learn(1e6)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1.66     |
|    ep_rew_mean     | 0.16     |
| time/              |          |
|    fps             | 267      |
|    iterations      | 1        |
|    time_elapsed    | 7        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.67        |
|    ep_rew_mean          | 0.24        |
| time/                   |             |
|    fps                  | 234         |
|    iterations           | 2           |
|    time_elapsed         | 17          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.020611499 |
|    clip_fraction        | 0.327       |
|    clip_range           | 0.2         |
|    entropy_loss   

In [31]:
#Generate primes

env = PrimeEnv()
for ep in range(100):
    s, _ = env.reset()
    print("Starting number: {}".format(binatodeci(s)))
    done = False
    while not done:
        a = model.predict(s, deterministic=True)[0]
        s, r, done, _, _ = env.step(a)
        print("Next number predicted: {}".format(binatodeci(s)))
        print("Is prime: {}".format(bool(r)))

Starting number: 12529
Next number predicted: 20341.0
Is prime: True
Next number predicted: 32911.0
Is prime: True
Next number predicted: 41077.0
Is prime: True
Next number predicted: 41333.0
Is prime: True
Next number predicted: 41809.0
Is prime: True
Next number predicted: 51061.0
Is prime: True
Next number predicted: 57373.0
Is prime: True
Next number predicted: 61469.0
Is prime: True
Next number predicted: 61967.0
Is prime: True
Next number predicted: 64013.0
Is prime: True
Next number predicted: 65141.0
Is prime: True
Next number predicted: 65293.0
Is prime: True
Next number predicted: 65309.0
Is prime: True
Next number predicted: 65371.0
Is prime: True
Next number predicted: 65407.0
Is prime: True
Next number predicted: 65419.0
Is prime: True
Next number predicted: 65437.0
Is prime: True
Next number predicted: 65497.0
Is prime: True
Next number predicted: 65455.0
Is prime: False
Starting number: 4389
Next number predicted: 20341.0
Is prime: True
Next number predicted: 32911.0
Is 