<a href="https://colab.research.google.com/github/AkshatSG/GFN/blob/main/GFlowNets_Structure.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
#Libraries

#For environment (graph-representation)
!pip install rdkit networkx

Collecting rdkit
  Downloading rdkit-2024.3.5-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.9 kB)
Downloading rdkit-2024.3.5-cp310-cp310-manylinux_2_28_x86_64.whl (33.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.1/33.1 MB[0m [31m46.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2024.3.5


In [64]:
from rdkit import Chem
from rdkit.Chem import AllChem
import networkx as nx

class MoleculeEnvironment:
    """Step-by-step molecule construction environment.

    The molecule is stored twice, using the same atom indices in both:
    an RDKit ``RWMol`` (``self.mol``) and a networkx ``Graph``
    (``self.graph``) that is used for ring counting. Every mutation has to
    be applied to both so that they stay in sync.

    Actions are tuples whose first element is the action type:

    * ``('add_atom', symbol)``
    * ``('add_bond', atom1_idx, atom2_idx, bond_type)``
    * ``('remove_atom', atom_idx)``
    * ``('remove_bond', atom1_idx, atom2_idx)``
    """

    # Valence consumed on each end atom by a bond of the given type.
    _BOND_ORDERS = {
        Chem.BondType.SINGLE: 1.0,
        Chem.BondType.DOUBLE: 2.0,
        Chem.BondType.TRIPLE: 3.0,
        Chem.BondType.AROMATIC: 1.5,
    }

    def __init__(self, max_atoms=10, max_rings=2, min_atoms=2):
        """Create the environment and reset it to a single carbon atom.

        Args:
            max_atoms: 'add_atom' is only valid below this atom count.
            max_rings: maximum number of independent cycles allowed.
            min_atoms: states with fewer atoms are never terminal.
        """
        self.max_atoms = max_atoms
        self.max_rings = max_rings
        self.min_atoms = min_atoms
        self.available_atoms = ['C', 'N', 'O', 'F', 'Cl']
        self.available_bond_types = [
            Chem.BondType.SINGLE,
            Chem.BondType.DOUBLE,
            Chem.BondType.TRIPLE,
            Chem.BondType.AROMATIC
        ]
        self.valences = {'C': 4, 'N': 3, 'O': 2, 'F': 1, 'Cl': 1}
        self.proxy = MoleculeProxy()
        self.reset()

    def reset(self):
        """Discard the current molecule, start from one carbon, return the state."""
        self.mol = Chem.RWMol()
        self.graph = nx.Graph()
        self.add_atom('C')
        return self.get_state()

    def add_atom(self, atom_symbol):
        """Append an atom to both representations and return its index."""
        atom_idx = self.mol.AddAtom(Chem.Atom(atom_symbol))
        self.graph.add_node(atom_idx, symbol=atom_symbol)
        return atom_idx

    def add_bond(self, atom1_idx, atom2_idx, bond_type=Chem.BondType.SINGLE):
        """Add a bond to both representations (no validity check is done here)."""
        self.mol.AddBond(atom1_idx, atom2_idx, bond_type)
        self.graph.add_edge(atom1_idx, atom2_idx, bond_type=bond_type)

    def get_mol(self):
        """Return a non-editable ``Mol`` snapshot of the current molecule."""
        return self.mol.GetMol()

    def get_graph(self):
        """Return the networkx graph mirroring the molecule."""
        return self.graph

    def get_state(self):
        """Return a dict with atom/bond counts and the atom/bond type lists."""
        return {
            'num_atoms': self.mol.GetNumAtoms(),
            'num_bonds': self.mol.GetNumBonds(),
            'atom_types': [atom.GetSymbol() for atom in self.mol.GetAtoms()],
            'bond_types': [bond.GetBondType() for bond in self.mol.GetBonds()]
        }

    def is_terminal(self):
        """Return whether the current molecule counts as finished.

        Below ``min_atoms`` the state is never terminal. Above ``max_atoms``
        or when sanitization fails it always is. Otherwise it is terminal
        exactly when no atom is left unbonded.
        """
        num_atoms = self.mol.GetNumAtoms()
        if num_atoms < self.min_atoms:
            return False
        if num_atoms > self.max_atoms or not self.is_valid_molecule():
            return True
        return all(atom.GetDegree() > 0 for atom in self.mol.GetAtoms())

    def _has_atom(self, atom_idx):
        """Return True if ``atom_idx`` is an index of an existing atom."""
        return 0 <= atom_idx < self.mol.GetNumAtoms()

    def _has_bond(self, atom1_idx, atom2_idx):
        """Return True if both atoms exist and a bond connects them."""
        if atom1_idx == atom2_idx:
            return False
        if not (self._has_atom(atom1_idx) and self._has_atom(atom2_idx)):
            return False
        return self.mol.GetBondBetweenAtoms(atom1_idx, atom2_idx) is not None

    def is_valid_action(self, action):
        """Return True if ``action`` can be applied to the current molecule.

        Malformed tuples (too short, unknown action type) are reported as
        invalid instead of raising.
        """
        if not action:
            return False
        action_type, args = action[0], tuple(action[1:])
        num_atoms = self.mol.GetNumAtoms()

        if action_type == 'add_atom':
            # Only elements with a known valence can be bonded later on.
            return (len(args) >= 1
                    and args[0] in self.available_atoms
                    and num_atoms < self.max_atoms)
        if action_type == 'add_bond':
            return len(args) >= 3 and self.is_valid_bond(args[0], args[1], args[2])
        if action_type == 'remove_atom':
            # The molecule must never become empty.
            return len(args) >= 1 and num_atoms > 1 and self._has_atom(args[0])
        if action_type == 'remove_bond':
            return len(args) >= 2 and self._has_bond(args[0], args[1])
        return False

    def is_valid_bond(self, atom1_idx, atom2_idx, bond_type):
        """Return True if a new bond of ``bond_type`` may join the two atoms.

        Rejected are: self-bonds, non-existent atoms, unsupported bond
        types, pairs that are already bonded, bonds that would exceed
        either atom's valence in ``self.valences``, and bonds that would
        push the ring count above ``max_rings``.
        """
        if atom1_idx == atom2_idx:
            return False
        if not (self._has_atom(atom1_idx) and self._has_atom(atom2_idx)):
            return False

        bond_order = self._BOND_ORDERS.get(bond_type)
        if bond_order is None:
            return False

        # RDKit refuses to add a second bond between the same pair of atoms.
        if self.mol.GetBondBetweenAtoms(atom1_idx, atom2_idx) is not None:
            return False

        # Make sure valences are computed (and current) before reading them;
        # strict=False recomputes without raising on odd valences.
        self.mol.UpdatePropertyCache(strict=False)
        for atom_idx in (atom1_idx, atom2_idx):
            atom = self.mol.GetAtomWithIdx(atom_idx)
            if atom.GetExplicitValence() + bond_order > self.valences[atom.GetSymbol()]:
                return False

        return not self.would_create_too_many_rings(atom1_idx, atom2_idx)

    def would_create_too_many_rings(self, atom1_idx, atom2_idx):
        """Return True if adding this edge yields more than ``max_rings`` cycles."""
        temp_graph = self.graph.copy()
        temp_graph.add_edge(atom1_idx, atom2_idx)
        return len(nx.cycle_basis(temp_graph)) > self.max_rings

    def get_possible_actions(self):
        """Return every currently valid action.

        Candidates are enumerated in a fixed order: add_atom, add_bond,
        remove_atom, remove_bond.
        """
        num_atoms = self.mol.GetNumAtoms()

        candidates = [('add_atom', atom_type) for atom_type in self.available_atoms]
        candidates += [
            ('add_bond', i, j, bond_type)
            for i in range(num_atoms) for j in range(i + 1, num_atoms)
            for bond_type in self.available_bond_types
        ]
        candidates += [('remove_atom', i) for i in range(num_atoms)]
        candidates += [
            ('remove_bond', bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
            for bond in self.mol.GetBonds()
        ]

        return [action for action in candidates if self.is_valid_action(action)]

    def get_reward(self):
        """Return the proxy's reward for the current molecule."""
        return self.proxy.calculate_reward(self.mol)

    def _remove_atom(self, atom_idx):
        """Remove an atom from both representations, keeping indices aligned."""
        self.mol.RemoveAtom(atom_idx)
        self.graph.remove_node(atom_idx)
        # RDKit shifts every atom index above the removed one down by one;
        # apply the same renumbering to the graph so both stay in sync.
        shifted = {old: old - 1 for old in self.graph.nodes if old > atom_idx}
        if shifted:
            self.graph = nx.relabel_nodes(self.graph, shifted, copy=True)

    def take_action(self, action):
        """Apply ``action`` and return ``(state, reward, is_terminal)``.

        Raises:
            ValueError: if ``is_valid_action(action)`` is False.
        """
        if not self.is_valid_action(action):
            raise ValueError("Invalid Action")
        action_type = action[0]
        if action_type == 'add_atom':
            self.add_atom(action[1])
        elif action_type == 'add_bond':
            self.add_bond(action[1], action[2], action[3])
        elif action_type == 'remove_atom':
            self._remove_atom(action[1])
        elif action_type == 'remove_bond':
            self.mol.RemoveBond(action[1], action[2])
            self.graph.remove_edge(action[1], action[2])

        try:
            Chem.SanitizeMol(self.mol)
        except Exception:
            # Sanitization is best-effort: intermediate molecules may be
            # chemically incomplete. Invalidity is reported through
            # is_terminal() / the reward rather than by raising here.
            pass

        state = self.get_state()
        reward = self.get_reward()
        is_terminal = self.is_terminal()

        return state, reward, is_terminal

    def is_valid_molecule(self):
        """Return True if the current molecule passes RDKit sanitization.

        The check runs on a copy so that asking the question never
        modifies the molecule being built.
        """
        try:
            Chem.SanitizeMol(Chem.Mol(self.mol))
        except Exception:
            return False
        return True

In [65]:
# Demo: grow a C-N-O-C chain one action at a time and print the
# environment's feedback after each step.
env = MoleculeEnvironment(min_atoms=2, max_atoms=10)

# Start with the initial state
state = env.reset()
print(f"Initial state: {state}")

# Take some actions: each new atom is added and then single-bonded to the
# atom added just before it.
actions = [
    step
    for atom_idx, symbol in enumerate(('N', 'O', 'C'), start=1)
    for step in (
        ('add_atom', symbol),
        ('add_bond', atom_idx - 1, atom_idx, Chem.BondType.SINGLE),
    )
]

for action in actions:
    state, reward, is_terminal = env.take_action(action)
    print("After action " + str(action) + ":")
    print(f"State: {state}")
    print(f"Is terminal: {is_terminal}")
    print(f"Reward: {reward}")

# Check final state
print(f"Final molecule is valid: {env.is_valid_molecule()}")
print(f"Final reward: {env.get_reward()}")

Initial state: {'num_atoms': 1, 'num_bonds': 0, 'atom_types': ['C'], 'bond_types': []}
After action ('add_atom', 'N'):
State: {'num_atoms': 2, 'num_bonds': 0, 'atom_types': ['C', 'N'], 'bond_types': []}
Is terminal: False
Reward: 0.5369165505413814
After action ('add_bond', 0, 1, rdkit.Chem.rdchem.BondType.SINGLE):
State: {'num_atoms': 2, 'num_bonds': 1, 'atom_types': ['C', 'N'], 'bond_types': [rdkit.Chem.rdchem.BondType.SINGLE]}
Is terminal: True
Reward: 0.5184913674388583
After action ('add_atom', 'O'):
State: {'num_atoms': 3, 'num_bonds': 1, 'atom_types': ['C', 'N', 'O'], 'bond_types': [rdkit.Chem.rdchem.BondType.SINGLE]}
Is terminal: False
Reward: 0.4450431955341311
After action ('add_bond', 1, 2, rdkit.Chem.rdchem.BondType.SINGLE):
State: {'num_atoms': 3, 'num_bonds': 2, 'atom_types': ['C', 'N', 'O'], 'bond_types': [rdkit.Chem.rdchem.BondType.SINGLE, rdkit.Chem.rdchem.BondType.SINGLE]}
Is terminal: True
Reward: 0.5309692432008388
After action ('add_atom', 'C'):
State: {'num_atoms'

In [52]:
#Proxy
from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen, rdMolDescriptors
from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem import AllChem
from rdkit.Contrib.SA_Score import sascorer

class MoleculeProxy:
    """Heuristic reward combining weight, logP, H-bond counts and SA score.

    The reward is the mean of five components, clipped below at 0:
    closeness of the molecular weight and logP to their targets, two 0/1
    checks on the hydrogen-bond donor/acceptor counts, and a term that
    decreases with the synthetic-accessibility score.
    """

    def __init__(self, target_weight=500, target_logp=2.5, max_hbd=5, max_hba=10):
        """Store the targets and limits used by ``calculate_reward``.

        Args:
            target_weight: molecular weight giving the best weight score.
            target_logp: logP giving the best logP score.
            max_hbd: largest H-bond donor count that still scores 1.
            max_hba: largest H-bond acceptor count that still scores 1.
        """
        self.target_weight = target_weight
        self.target_logp = target_logp
        self.max_hbd = max_hbd
        self.max_hba = max_hba

    def calculate_reward(self, mol):
        """Return the reward for ``mol``.

        Returns 0 for ``None``, for an empty molecule, and for molecules
        that fail sanitization or 2D coordinate generation. ``mol`` itself
        is never modified: all work happens on a sanitized copy.
        """
        if mol is None or mol.GetNumAtoms() == 0:
            return 0

        mol_copy = Chem.Mol(mol)
        try:
            Chem.SanitizeMol(mol_copy)
            AllChem.Compute2DCoords(mol_copy)
        except Exception:
            # Chemically invalid structures get no reward.
            return 0

        # Descriptors must be computed on the sanitized copy; the caller's
        # molecule may not have been sanitized.
        mol_weight = ExactMolWt(mol_copy)
        logp = Crippen.MolLogP(mol_copy)
        hbd = rdMolDescriptors.CalcNumHBD(mol_copy)
        hba = rdMolDescriptors.CalcNumHBA(mol_copy)
        sa_score = self.calculate_sa_score(mol_copy)

        weight_reward = 1 - abs(mol_weight - self.target_weight) / self.target_weight
        # max(..., 1) keeps the denominator away from zero for small targets.
        logp_reward = 1 - abs(logp - self.target_logp) / max(abs(self.target_logp), 1)
        hbd_reward = 1 if hbd <= self.max_hbd else 0
        hba_reward = 1 if hba <= self.max_hba else 0
        sa_reward = 1 - sa_score / 10

        total_reward = (weight_reward + logp_reward + hbd_reward + hba_reward + sa_reward) / 5
        # Individual components can be negative; the total never is.
        return max(0, total_reward)

    def calculate_sa_score(self, mol):
        """Return the synthetic-accessibility score from ``sascorer``."""
        return sascorer.calculateScore(mol)

In [3]:
#Policy Models:

  #Forward Policy

  #Backward Policy

In [None]:
#GFlowNet Agent

In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
    """Fully connected network: two ReLU hidden layers and a linear output."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Build the three linear layers ``fc1``, ``fc2`` and ``fc3``."""
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Return the un-normalised outputs of the final layer for ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        logits = self.fc3(hidden)
        return logits

class GFlowNetAgent:
    """Agent holding a forward and a backward policy network for ``env``.

    Both networks map the two-feature state vector produced by
    ``state_to_tensor`` to one score per action. The number of outputs is
    fixed at construction time from the actions valid in the environment's
    state at that moment.
    """

    def __init__(self, env, hidden_dim=64):
        """Create both policy networks sized from ``env``'s current state."""
        self.env = env
        input_dim = self.get_state_dim()
        output_dim = self.get_action_dim()

        self.forward_policy = PolicyNetwork(input_dim, hidden_dim, output_dim)
        self.backward_policy = PolicyNetwork(input_dim, hidden_dim, output_dim)

    def get_state_dim(self):
        """Return the length of the vector built by ``state_to_tensor``."""
        return 2

    def get_action_dim(self):
        """Return the number of actions currently valid in the environment."""
        return len(self.env.get_possible_actions())

    def state_to_tensor(self, state):
        """Encode ``state`` as a float tensor ``[num_atoms, num_bonds]``."""
        return torch.tensor([state['num_atoms'], state['num_bonds']], dtype=torch.float32)

    def _sample_action(self, policy, state):
        """Sample one currently valid action from ``policy``'s distribution.

        The set of valid actions changes from state to state while the
        network's output size is fixed, so the scores are truncated to the
        number of valid actions before normalising. This guarantees the
        sampled index always refers to an existing action. When more
        actions are valid than the network has outputs, only the first
        ``output_dim`` of them can be sampled.

        Raises:
            IndexError: if the environment offers no valid action.
        """
        actions = self.env.get_possible_actions()
        if not actions:
            raise IndexError("No valid actions available in the current state")

        state_tensor = self.state_to_tensor(state)
        with torch.no_grad():
            scores = policy(state_tensor)[:len(actions)]
            action_probs = F.softmax(scores, dim=0)
        action_index = torch.multinomial(action_probs, 1).item()
        return actions[action_index]

    def forward_action(self, state):
        """Sample an action for ``state`` using the forward policy."""
        return self._sample_action(self.forward_policy, state)

    def backward_action(self, state):
        """Sample an action for ``state`` using the backward policy.

        NOTE(review): this draws from the environment's forward action
        list, exactly as before; confirm whether a dedicated set of
        backward (undo) actions is intended.
        """
        return self._sample_action(self.backward_policy, state)