In [None]:
!pip install stable-baselines3

In [7]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces

In [25]:
class PrimeEnv(gym.Env):
    def __init__(self, max_bits=16):
        self.observation_space = spaces.MultiBinary(max_bits)
        self.action_space = spaces.MultiBinary(max_bits)
    
    def reset(self, **kwargs):
        start_state = self.observation_space.sample()
        self.state_ = start_state
        return start_state, {}
    
    def step(self, action):
        next_state = action
        next_state_decimal = self.binatodeci(next_state)
        done = False
        if next_state_decimal > self.binatodeci(self.state_):
            if self.is_prime(next_state_decimal):
                reward = 1
            else:
                reward = 0
        else:
            reward = 0
            done = True
        self.state_ = next_state
        return next_state, reward, done, False, {}
    
    def binatodeci(self, binary):
        return sum(val*(2**idx) for idx, val in enumerate(reversed(binary)))

    def is_prime(self, state):
        tmp_state = int(state)
        if tmp_state > 1:
        # Iterate from 2 to n // 2
            for i in range(2, (tmp_state//2)+1):
                # If num is divisible by any number between
                # 2 and n / 2, it is not prime
                if (tmp_state % i) == 0:
                    return 0
            else:
                return 1
        else:
            return 0
 

In [31]:
from stable_baselines3 import PPO
import torch as th

model = PPO("MlpPolicy", env = gym.wrappers.time_limit.TimeLimit(PrimeEnv(),100), verbose = 1)
model.learn(1e7)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1.66     |
|    ep_rew_mean     | 0.07     |
| time/              |          |
|    fps             | 1910     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1.84        |
|    ep_rew_mean          | 0.11        |
| time/                   |             |
|    fps                  | 1468        |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.012947267 |
|    clip_fraction        | 0.136       |
|    clip_range           | 0.2         |
|    entropy_loss   

KeyboardInterrupt: 

In [32]:
#Generate primes

env = PrimeEnv()
for ep in range(100):
    s, _ = env.reset()
    print("Starting number: {}".format(env.binatodeci(s)))
    done = False
    while not done:
        a = model.predict(s, deterministic=True)[0]
        print("Next number predicted: {}".format(env.binatodeci(s)))
        s, r, done, _, _ = env.step(a)
        print("Is prime: {}".format(bool(r)))

Starting number: 40760
Next number predicted: 40760
Is prime: True
Next number predicted: 49171.0
Is prime: True
Next number predicted: 57347.0
Is prime: True
Next number predicted: 61441.0
Is prime: True
Next number predicted: 64513.0
Is prime: True
Next number predicted: 64577.0
Is prime: True
Next number predicted: 65089.0
Is prime: True
Next number predicted: 65129.0
Is prime: True
Next number predicted: 65257.0
Is prime: True
Next number predicted: 65269.0
Is prime: False
Starting number: 58413
Next number predicted: 58413
Is prime: True
Next number predicted: 61441.0
Is prime: True
Next number predicted: 64513.0
Is prime: True
Next number predicted: 64577.0
Is prime: True
Next number predicted: 65089.0
Is prime: True
Next number predicted: 65129.0
Is prime: True
Next number predicted: 65257.0
Is prime: True
Next number predicted: 65269.0
Is prime: False
Starting number: 37398
Next number predicted: 37398
Is prime: True
Next number predicted: 49171.0
Is prime: True
Next number pre