In [1]:
import gym
from gym import spaces
from rdkit import Chem
from rdkit.Chem import QED
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from rdkit.Chem import Draw
from gym.spaces import MultiDiscrete


  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


In [7]:
from rdkit import Chem
from rdkit.Chem import Descriptors, QED
import numpy as np
import gym
from gym import spaces
from gym.spaces import MultiDiscrete
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from gym.utils import seeding


def mol_to_vector(mol):
    if mol is None:
        return np.zeros(10, dtype=np.float64)  # Fixed-size zero vector
    features = [
        Descriptors.MolWt(mol),
        Descriptors.NumValenceElectrons(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.RingCount(mol),
        Descriptors.FractionCSP3(mol),
        Descriptors.MolLogP(mol),
        Descriptors.HeavyAtomCount(mol)
    ]
    return np.array(features, dtype=np.float64)


class MoleculeEnvExpanded(gym.Env):
    def __init__(self, max_atoms=10):
        assert 1 <= max_atoms < 1000, "max_atoms must be between 1 and 999"
        super().__init__()
        self.atom_types = ['C', 'N', 'O', 'F', 'S']
        self.bond_types = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.max_atoms = max_atoms

        self.action_space = MultiDiscrete([
            self.max_atoms,                # attach_to_atom index
            len(self.atom_types),         # new_atom_type index
            len(self.bond_types) + 1      # bond_type index (+1 for stop)
        ])

        self.observation_space = spaces.Box(
            low=np.zeros(10, dtype=np.float64),
            high=np.ones(10, dtype=np.float64) * 1e2,
            dtype=np.float64
        )

        print("hogh", self.observation_space.high)
        print("low", self.observation_space.low)

        self.np_random = None
        #self.seed()  # Initialize random seed

        self.reset()

    def seed(self, seed=None):
        print("seed", seed)
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        self.mol = Chem.MolFromSmiles('C')  # Start with methane
        self.done = False
        return self._get_obs()

    def _get_obs(self):
        return mol_to_vector(self.mol)

    def step(self, action):
        if self.done:
            raise RuntimeError("Episode finished, call reset()")

        attach_idx, atom_idx, bond_idx = action

        # Stop action
        if bond_idx == len(self.bond_types):
            self.done = True
            reward = self._compute_reward()
            return self._get_obs(), reward, self.done, {}

        if attach_idx >= self.mol.GetNumAtoms():
            reward = -1  # Invalid attachment index
            return self._get_obs(), reward, self.done, {}

        atom_symbol = self.atom_types[atom_idx]
        bond_type = self.bond_types[bond_idx]

        new_mol = self._add_atom(self.mol, attach_idx, atom_symbol, bond_type)
        if new_mol is None:
            reward = -1  # Invalid molecule after addition
        else:
            self.mol = new_mol
            reward = 0

        return self._get_obs(), reward, self.done, {}

    def _add_atom(self, mol, attach_idx, atom_symbol, bond_type):
        try:
            em = Chem.EditableMol(mol)
            new_atom_idx = em.AddAtom(Chem.Atom(atom_symbol))
            em.AddBond(attach_idx, new_atom_idx, bond_type)
            new_mol = em.GetMol()
            Chem.SanitizeMol(new_mol)
            return new_mol
        except Exception:
            return None

    def _compute_reward(self):
        try:
            return QED.qed(self.mol)
        except Exception:
            return -1.0


def main():
    env = make_vec_env(lambda: MoleculeEnvExpanded(), n_envs=1)  # no seed param

    checkpoint_callback = CheckpointCallback(
        save_freq=5000,
        save_path='./models/',
        name_prefix='ppo_molecule'
    )

    model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./ppo_mol_logs/")
    model.learn(total_timesteps=10000, callback=checkpoint_callback)

    eval_env = MoleculeEnvExpanded()
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, render=False)
    print(f"Mean reward: {mean_reward:.2f} ± {std_reward:.2f}")

    model.save("ppo_molecule")


main()


hogh [100. 100. 100. 100. 100. 100. 100. 100. 100. 100.]
low [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


ValueError: high is out of bounds for int32

In [35]:
import gym
from gym.spaces import MultiDiscrete
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

class DummyEnv(gym.Env):
    def seed(self, seed=None):
        import random, numpy as np
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        random.seed(seed)
        return [seed]
    def __init__(self):
        super().__init__()
        self.action_space = MultiDiscrete([10, 5, 5])
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(10,), dtype=float)
    def reset(self):
        return self.observation_space.sample()
    def step(self, action):
        return self.observation_space.sample(), 0.0, False, {}

env = make_vec_env(lambda: DummyEnv(), n_envs=1, seed=42)
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)


Using cpu device
-----------------------------
| time/              |      |
|    fps             | 855  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 757         |
|    iterations           | 2           |
|    time_elapsed         | 5           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008766099 |
|    clip_fraction        | 0.0897      |
|    clip_range           | 0.2         |
|    entropy_loss         | -5.51       |
|    explained_variance   | -37.3       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0482     |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0152     |
|    value_loss           | 0.00278     |
-----------------------------------------
-----------------

<stable_baselines3.ppo.ppo.PPO at 0x1742123ae20>