<a href="https://colab.research.google.com/github/AkshatSG/GFN/blob/main/GFlowNets_Structure_POC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#Librarires

#For environment (graph-representation)
!pip install rdkit networkx

Collecting rdkit
  Downloading rdkit-2024.3.5-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.9 kB)
Downloading rdkit-2024.3.5-cp310-cp310-manylinux_2_28_x86_64.whl (33.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.1/33.1 MB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2024.3.5


In [2]:
from rdkit import Chem
from rdkit.Chem import AllChem
import networkx as nx

class MoleculeEnvironment:
    def __init__(self, max_atoms=10, max_rings=2, min_atoms=2):
        self.max_atoms = max_atoms
        self.max_rings = max_rings
        self.min_atoms = min_atoms
        self.available_atoms = ['C', 'N', 'O', 'F', 'Cl']
        self.available_bond_types = [
            Chem.BondType.SINGLE,
            Chem.BondType.DOUBLE,
            Chem.BondType.TRIPLE,
            Chem.BondType.AROMATIC
        ]
        self.valences = {'C': 4, 'N': 3, 'O': 2, 'F': 1, 'Cl': 1}
        self.proxy = MoleculeProxy()
        self.reset()
        self.agent = GFlowNetAgent(self)

    def reset(self):
        self.mol = Chem.RWMol()
        self.graph = nx.Graph()
        self.add_atom('C')
        return self.get_state()

    def add_atom(self, atom_symbol):
        atom_idx = self.mol.AddAtom(Chem.Atom(atom_symbol))
        self.graph.add_node(atom_idx, symbol=atom_symbol)
        return atom_idx

    def add_bond(self, atom1_idx, atom2_idx, bond_type):
        if not self.mol.GetBondBetweenAtoms(atom1_idx, atom2_idx):
            self.mol.AddBond(atom1_idx, atom2_idx, bond_type)
            self.graph.add_edge(atom1_idx, atom2_idx, bond_type=bond_type)

        else:
            bond = self.mol.GetBondBetweenAtoms(atom1_idx, atom2_idx)
            bond.SetBondType(bond_type)
            self.graph[atom1_idx][atom2_idx]['bond_type'] = bond_type

    def get_mol(self):
        return self.mol.GetMol()

    def get_graph(self):
        return self.graph

    def get_state(self):
        return {
            'num_atoms': self.mol.GetNumAtoms(),
            'num_bonds': self.mol.GetNumBonds(),
            'atom_types': [atom.GetSymbol() for atom in self.mol.GetAtoms()],
            'bond_types': [bond.GetBondType() for bond in self.mol.GetBonds()]
        }

    def is_terminal(self):
        num_atoms = self.mol.GetNumAtoms()
        if num_atoms < self.min_atoms:
            return False
        if num_atoms > self.max_atoms:
            return True
        if not self.is_valid_molecule():
            return True
        for atom in self.mol.GetAtoms():
            if atom.GetDegree() == 0:
                return False
        return True

    def is_valid_action(self, action):
        action_type = action[0]
        if action_type == 'add_atom':
            return self.mol.GetNumAtoms() < self.max_atoms
        elif action_type == 'add_bond':
            atom1, atom2, bond_type = action[1], action[2], action[3]
            return (atom1 < self.mol.GetNumAtoms() and
                    atom2 < self.mol.GetNumAtoms() and
                    not self.mol.GetBondBetweenAtoms(atom1, atom2) and
                    self.is_valid_bond(atom1, atom2, bond_type))
        elif action_type == 'remove_atom':
            return self.mol.GetNumAtoms() > 1 and action[1] < self.mol.GetNumAtoms()
        elif action_type == 'remove_bond':
            return self.mol.GetBondBetweenAtoms(action[1], action[2]) is not None
        return False

    def is_valid_bond(self, atom1_idx, atom2_idx, bond_type):
        atom1 = self.mol.GetAtomWithIdx(atom1_idx)
        atom2 = self.mol.GetAtomWithIdx(atom2_idx)

        # Convert bond_type to double
        bond_type_double = {
          Chem.BondType.SINGLE: 1.0,
          Chem.BondType.DOUBLE: 2.0,
          Chem.BondType.TRIPLE: 3.0,
          Chem.BondType.AROMATIC: 1.5
        }.get(bond_type, 0.0)

        if (atom1.GetExplicitValence() + bond_type_double > self.valences[atom1.GetSymbol()] or
            atom2.GetExplicitValence() + bond_type_double > self.valences[atom2.GetSymbol()]):
            return False

        if self.would_create_too_many_rings(atom1_idx, atom2_idx):
            return False

        return True

    def would_create_too_many_rings(self, atom1_idx, atom2_idx):
        temp_graph = self.graph.copy()
        temp_graph.add_edge(atom1_idx, atom2_idx)
        return len(list(nx.cycle_basis(temp_graph))) > self.max_rings

    def get_possible_actions(self):
        actions = []
        num_atoms = self.mol.GetNumAtoms()
        atom_indices = list(range(num_atoms))

        for action in [
            ('add_atom', atom_type) for atom_type in self.available_atoms
        ] + [
            ('add_bond', i, j, bond_type)
            for i in atom_indices for j in atom_indices if i < j
            for bond_type in self.available_bond_types
        ] + [
            ('remove_atom', i) for i in atom_indices
        ] + [
            ('remove_bond', bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
            for bond in self.mol.GetBonds()
        ]:
            if self.is_valid_action(action):
                actions.append(action)

        return actions

    def get_reward(self):
        return self.proxy.calculate_reward(self.mol)

    def take_action(self, action):
        if not self.is_valid_action(action):
            raise ValueError("Invalid Action")
        action_type = action[0]
        if action_type == 'add_atom':
            self.add_atom(action[1])
        elif action_type == 'add_bond':
            self.add_bond(action[1], action[2], action[3])
        elif action_type == 'remove_atom':
            self.remove_atom(action[1])
        elif action_type == 'remove_bond':
            if self.mol.GetBondBetweenAtoms(action[1], action[2]):
                self.mol.RemoveBond(action[1], action[2])
                self.graph.remove_edge(action[1], action[2])

        try:
            Chem.SanitizeMol(self.mol)
        except:
            pass

        state = self.get_state()
        reward = self.get_reward()
        is_terminal = self.is_terminal()

        return state, reward, is_terminal

    def is_valid_molecule(self):
        try:
            Chem.SanitizeMol(self.mol)
            return True
        except:
            return False
    def sample_trajectory(self):
        self.reset()
        trajectory = []
        state = self.get_state()

        while not self.is_terminal():
            possible_actions = self.get_possible_actions()
            if not possible_actions:
                break
            action = self.agent.forward_action(state)
            next_state, reward, is_terminal = self.take_action(action)
            trajectory.append((state, action, next_state, reward, possible_actions))
            state = next_state

        return trajectory

    def get_mol(self):
        return self.mol.GetMol()

    def remove_atom(self, atom_idx):
      if atom_idx >= self.mol.GetNumAtoms():
          return

      # Remove all bonds connected to this atom from the graph
      bonds_to_remove = list(self.graph.edges(atom_idx))
      self.graph.remove_edges_from(bonds_to_remove)

      # Remove the atom from the graph
      self.graph.remove_node(atom_idx)

      # Remove the atom from the molecule
      self.mol.RemoveAtom(atom_idx)

      # Renumber the atoms in the graph
      mapping = {old: new for new, old in enumerate(sorted(self.graph.nodes()))}
      self.graph = nx.relabel_nodes(self.graph, mapping)

    def remove_bond(self, atom1_idx, atom2_idx):
      if self.mol.GetBondBetweenAtoms(atom1_idx, atom2_idx):
          self.mol.RemoveBond(atom1_idx, atom2_idx)
          self.graph.remove_edge(atom1_idx, atom2_idx)

In [4]:
# env = MoleculeEnvironment(max_atoms=10, min_atoms=2)

# # Start with the initial state
# state = env.reset()
# print("Initial state:", state)

# # Take some actions
# actions = [
#     ('add_atom', 'N'),
#     ('add_bond', 0, 1, Chem.BondType.SINGLE),
#     ('add_atom', 'O'),
#     ('add_bond', 1, 2, Chem.BondType.SINGLE),
#     ('add_atom', 'C'),
#     ('add_bond', 2, 3, Chem.BondType.SINGLE)
# ]

# for action in actions:
#     state, reward, is_terminal = env.take_action(action)
#     print(f"After action {action}:")
#     print("State:", state)
#     print("Is terminal:", is_terminal)
#     print("Reward:", reward)

# # Check final state
# print("Final molecule is valid:", env.is_valid_molecule())
# print("Final reward:", env.get_reward())

In [5]:
#Proxy
from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen, rdMolDescriptors
from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem import AllChem
from rdkit.Contrib.SA_Score import sascorer

class MoleculeProxy:
    def __init__(self, target_weight=500, target_logp=2.5, max_hbd=5, max_hba=10):
        self.target_weight = target_weight
        self.target_logp = target_logp
        self.max_hbd = max_hbd
        self.max_hba = max_hba

    def calculate_reward(self, mol):
      if mol is None or mol.GetNumAtoms() == 0:
        return -10

      mol_copy = Chem.Mol(mol)
      try:
        Chem.SanitizeMol(mol_copy)
        AllChem.Compute2DCoords(mol_copy)
      except:
        return -10

      mol_weight = ExactMolWt(mol)
      logp = Crippen.MolLogP(mol)
      hbd = rdMolDescriptors.CalcNumHBD(mol)
      hba = rdMolDescriptors.CalcNumHBA(mol)
      sa_score = self.calculate_sa_score(mol)

      weight_reward = 1 - abs(mol_weight - self.target_weight) / self.target_weight
      logp_reward = 1 - abs(logp - self.target_logp) / max(abs(self.target_logp), 1)
      hbd_reward = 1 if hbd <= self.max_hbd else 0
      hba_reward = 1 if hba <= self.max_hba else 0
      sa_reward = 1 - sa_score / 10

      total_reward = (weight_reward + logp_reward + hbd_reward + hba_reward + sa_reward) / 5
      return max(0, total_reward)

    def calculate_sa_score(self, mol):
      # return AllChem.CalcSyntheticAccessibilityScore(mol)
      return sascorer.calculateScore(mol)

In [6]:
#Policy Models:

  #Forward Policy

  #Backward Policy

In [7]:
#GFlowNet Agent

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    # def state_to_tensor(self, state):
    #     # Convert the state dictionary to a tensor
    #     return torch.tensor([
    #         state['num_atoms'],
    #         state['num_bonds'],
    #         len(state['atom_types']),
    #         len(state['bond_types'])
    #     ], dtype=torch.float32)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [9]:
class GFlowNetAgent:
    def __init__(self, env, hidden_dim=64):
        self.env = env
        self.input_dim = self.get_state_dim()
        self.output_dim = self.get_action_dim()

        self.forward_policy = PolicyNetwork(self.input_dim, hidden_dim, self.output_dim)
        self.backward_policy = PolicyNetwork(self.input_dim, hidden_dim, self.output_dim)

    def get_state_dim(self):
        return 4  # num_atoms, num_bonds, num_atom_types, num_bond_types

    def get_action_dim(self):
        return len(self.env.get_possible_actions())

    def state_to_tensor(self, state):
        return torch.tensor([
            state['num_atoms'],
            state['num_bonds'],
            len(state['atom_types']),
            len(state['bond_types'])
        ], dtype=torch.float32)

    def forward_action(self, state):
        state_tensor = self.state_to_tensor(state)
        possible_actions = self.env.get_possible_actions()
        with torch.no_grad():
            action_probs = F.softmax(self.forward_policy(state_tensor), dim=0)
        action_index = torch.multinomial(action_probs, 1).item()
        return possible_actions[action_index]

    def backward_action(self, state):
        state_tensor = self.state_to_tensor(state)
        possible_actions = self.env.get_possible_actions()
        with torch.no_grad():
            action_probs = F.softmax(self.backward_policy(state_tensor), dim=0)
        action_index = torch.multinomial(action_probs, 1).item()
        return possible_actions[action_index]

In [10]:
import torch
import torch.nn.functional as F

def calculate_trajectory_balance_loss(trajectory, agent):
      loss = torch.tensor(0.0, requires_grad=True)
      Z = torch.tensor(1.0, requires_grad=True)  # Partition function as a tensor
      epsilon = 1e-10  # Small constant to prevent log(0)

      for i in range(len(trajectory) - 1):
          state, action, next_state, reward, possible_actions = trajectory[i]
          next_action = trajectory[i+1][1] if i < len(trajectory) - 2 else None

          state_tensor = agent.state_to_tensor(state)
          next_state_tensor = agent.state_to_tensor(next_state)

          forward_logits = agent.forward_policy(state_tensor)
          backward_logits = agent.backward_policy(next_state_tensor)

          action_index = possible_actions.index(action)

          forward_prob = F.softmax(forward_logits, dim=0)[action_index]
          backward_prob = F.softmax(backward_logits, dim=0)[action_index] if next_action else torch.tensor(1.0)

          if i == 0:
              loss = loss + torch.log(Z) + torch.log(forward_prob)
          elif i == len(trajectory) - 2:
              # Add epsilon to prevent log(0)
              loss = loss + torch.log(torch.tensor(max(reward, epsilon), dtype=torch.float32)) - torch.log(backward_prob)
          else:
              loss = loss + torch.log(forward_prob) - torch.log(backward_prob)

      return -loss  # Negative because we want to maximize this quantity

In [11]:
import torch.nn as nn
import torch.optim as optim

def train_gflownet(env, num_episodes=1000, learning_rate=1e-4, clip_value=1.0):
    agent = env.agent
    optimizer = optim.Adam(list(agent.forward_policy.parameters()) +
                           list(agent.backward_policy.parameters()),
                           lr=learning_rate)

    for episode in range(num_episodes):
        trajectory = env.sample_trajectory()
        if not trajectory:
            continue
        loss = calculate_trajectory_balance_loss(trajectory, agent)

        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: Invalid loss value at episode {episode}. Skipping.")
            continue

        optimizer.zero_grad()
        loss.backward()

        # Add gradient clipping
        nn.utils.clip_grad_norm_(agent.forward_policy.parameters(), clip_value)
        nn.utils.clip_grad_norm_(agent.backward_policy.parameters(), clip_value)

        optimizer.step()

        if episode % 100 == 0:
            print(f"Episode {episode}, Loss: {loss.item()}")

        if episode % 1000 == 0:
            evaluate_model(env)

In [12]:
from rdkit import Chem
from rdkit.Chem import Descriptors

def evaluate_model(env, num_samples=100):
    valid_molecules = 0
    total_reward = 0
    unique_smiles = set()

    for _ in range(num_samples):
        trajectory = env.sample_trajectory()
        final_state = trajectory[-1][2]  # Get the final state
        mol = env.get_mol()

        if env.is_valid_molecule():
            valid_molecules += 1
            total_reward += env.get_reward()
            smiles = Chem.MolToSmiles(mol)
            unique_smiles.add(smiles)

    validity_rate = valid_molecules / num_samples
    avg_reward = total_reward / valid_molecules if valid_molecules > 0 else 0
    uniqueness_rate = len(unique_smiles) / valid_molecules if valid_molecules > 0 else 0

    print(f"Validity rate: {validity_rate:.2f}")
    print(f"Average reward: {avg_reward:.2f}")
    print(f"Uniqueness rate: {uniqueness_rate:.2f}")

    # Print some example molecules
    print("Example generated molecules:")
    for smiles in list(unique_smiles)[:5]:
        print(smiles)

In [15]:
env = MoleculeEnvironment(max_atoms=10, min_atoms=2)
import warnings
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
warnings.filterwarnings("ignore", category=DeprecationWarning)
train_gflownet(env)

Episode 0, Loss: -2.872411012649536
Validity rate: 0.33
Average reward: 0.68
Uniqueness rate: 1.00
Example generated molecules:
C=N.NOC(Cl)=C1C=C=N1
CN(Cl)C(=O)C(Cl)=C(Cl)Cl
FCC(Cl)(Cl)Cl.FCl.FF
N#CN(CN(F)Cl)OOCl
N=C(N)C=O.O=NC(=O)F
Episode 100, Loss: -32.127540588378906
Episode 200, Loss: -79.93714141845703
Episode 300, Loss: -234.4235076904297
Episode 400, Loss: -402.9586486816406
Episode 500, Loss: -585.8623046875
Episode 600, Loss: -845.3506469726562
