In [1]:
import pandas as pd

# Keep only the chain_size == 0 rows and cache them as a CSV for TorchDrug's CSV loader.
# The output filename is kept exactly as-is: the load_csv call in the next cell reads this name.
df = (
    pd.read_csv('../polygraphpy/data/polarizability_data.csv')
    .query('chain_size == 0')
)
df.to_csv('filterd_polarizability.csv', index=False)

In [2]:
import torch
from torchdrug import data, models, tasks, core
from torchdrug.layers import distribution
from torchdrug.core import Registry as R
from torch import nn, optim

@R.register("datasets.CustomMolecule")
class CustomMoleculeDataset(data.MoleculeDataset):
    """MoleculeDataset subclass registered in TorchDrug's registry as "datasets.CustomMolecule".

    It adds no behaviour of its own: construction and loading (e.g. ``load_csv``)
    are inherited unchanged from ``data.MoleculeDataset``.
    """

# Parse the filtered SMILES into molecules, with static_polarizability as the only target column.
dataset = CustomMoleculeDataset()
dataset.load_csv(
    'filterd_polarizability.csv',
    smiles_field='smiles',
    target_fields=['static_polarizability'],
    kekulize=True,
    atom_feature='symbol',
)



In [3]:
# GraphAF setup: one RGCN backbone shared by a node flow (atom types) and an edge flow (bond types).
model = models.RGCN(input_dim=dataset.num_atom_type, num_relation=dataset.num_bond_type,
                    hidden_dims=[128, 128, 128], batch_norm=True)

num_atom_type = dataset.num_atom_type
# One extra edge class so the edge flow can also represent "no bond".
num_bond_type = dataset.num_bond_type + 1

# Standard-normal priors (zero mean, unit scale) over the atom-type / edge-type dimensions.
node_prior = distribution.IndependentGaussian(torch.zeros(num_atom_type), torch.ones(num_atom_type))
edge_prior = distribution.IndependentGaussian(torch.zeros(num_bond_type), torch.ones(num_bond_type))

node_flow = models.GraphAF(model, node_prior, num_layer=12)
edge_flow = models.GraphAF(model, edge_prior, use_edge=True, num_layer=12)

# NOTE(review): max_node=38 appears to be carried over from the TorchDrug tutorial; it presumably
# caps the number of atoms per generated molecule — confirm it covers the largest molecule here.
task = tasks.AutoregressiveGeneration(node_flow, edge_flow, max_node=38, max_edge_unroll=12, criterion='nll')

optimizer = optim.Adam(task.parameters(), lr=1e-3)
# Use GPU 0 only when CUDA is available; gpus=None lets the Engine run on CPU instead of
# failing on machines without a GPU.
gpus = [0] if torch.cuda.is_available() else None
solver = core.Engine(task, dataset, None, None, optimizer, gpus=gpus, batch_size=32)

16:22:35   Preprocess training set
16:22:35   {'batch_size': 32,
 'class': 'core.Engine',
 'gpus': [0],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'differentiable': False,
               'eps': 1e-08,
               'foreach': None,
               'fused': None,
               'lr': 0.001,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'agent_update_interval': 5,
          'baseline_momentum': 0.9,
          'class': 'tasks.AutoregressiveGeneration',
          'criterion': 'nll',
          'edge_model': {'class': 'models.GraphAF',
                         'dequantization_noise': 0.9,
                         'model': {'activation': 'relu',
                                   'batch_norm': True,
                                   'class

In [None]:
# Fit both flows with the NLL criterion configured for `task`, then write a checkpoint to disk.
checkpoint_path = 'graphaf_model.pkl'
solver.train(num_epoch=10)
solver.save(checkpoint_path)

16:22:37   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:22:37   Epoch 0 begin




16:22:38   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:22:38   edge log likelihood: -24.3833
16:22:38   edge mask / graph: 175.812
16:22:38   node log likelihood: -3007.38
16:22:38   node mask / graph: 21.0938




16:22:51   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:22:51   edge log likelihood: -3.77118
16:22:51   edge mask / graph: 182.625
16:22:51   node log likelihood: -15.2756
16:22:51   node mask / graph: 21.7188
16:23:03   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16:23:03   edge log likelihood: -2.6299
16:23:03   edge mask / graph: 204.375
16:23:03   node log likelihood: -14.8928
16:23:03   node mask / graph: 23.5312


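# New try: single-cell GraphAF pipeline — load and filter the data, train, RL fine-tuning, and molecule generation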

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# PyG DataLoader and utils are no longer strictly needed for data loading,
# but kept for completeness if other PyG utilities are used elsewhere.
from torch_geometric.data import Data, DataLoader as PyGDataLoader
from torch_geometric.utils import to_dense_adj

import pandas as pd
from tqdm import tqdm

# Import torchdrug components
import torchdrug
from torchdrug import core, models, tasks
from torchdrug.data import Molecule
from torchdrug.layers import distribution # Added for priors
from torch.utils import data as torch_data # To avoid conflict with PyG DataLoader

# Reproducibility: seed PyTorch's RNGs and, on CUDA, pin cuDNN to deterministic kernels.
_SEED = 42
torch.manual_seed(_SEED)
if torch.cuda.is_available():
    # Autotuning (benchmark mode) may choose different kernels between runs; turn it off.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(_SEED)

# --- Main Execution ---
# Configuration for this end-to-end attempt (load -> filter -> train -> optional RL -> generate).
CSV_FILE_PATH = '../polygraphpy/data/polarizability_data.csv'
TARGET_FIELD = "static_polarizability"
BATCH_SIZE = 32
NUM_EPOCH = 10  # per training stage; adjust as needed
# TorchDrug's Engine takes a list of GPU ids; None means CPU, so this also runs without CUDA.
GPUS = [0] if torch.cuda.is_available() else None
# Property task rewarded during RL fine-tuning, or None to skip that stage.
# NOTE(review): AutoregressiveGeneration computes RL rewards for its built-in property tasks
# (the TorchDrug GraphAF tutorial uses "plogp" / "qed"), not for arbitrary dataset columns, so
# task="static_polarizability" would not optimise polarizability. Optimising it needs a custom
# reward (e.g. a trained property predictor). Confirm against the installed TorchDrug version.
RL_TASK = None


def _load_monomer_dataset(csv_path, target_field):
    """Build a TorchDrug MoleculeDataset from the rows of `csv_path` with chain_size == 0.

    Rows are filtered with pandas before any SMILES is parsed, so the chain_size
    column never has to be carried onto the parsed molecule objects.

    Args:
        csv_path: CSV with at least `smiles`, `chain_size` and `target_field` columns.
        target_field: name of the column stored as the dataset's only target.

    Returns:
        A populated ``torchdrug.data.MoleculeDataset``.
    """
    raw = pd.read_csv(csv_path)
    monomers = raw[raw["chain_size"] == 0].reset_index(drop=True)
    print(f"Filtered {len(raw)} CSV rows down to {len(monomers)} with chain_size == 0.")

    dataset = torchdrug.data.MoleculeDataset()
    # Same parsing options as the first attempt's load_csv call (kekulized, symbol atom features).
    dataset.load_smiles(
        monomers["smiles"].tolist(),
        {target_field: monomers[target_field].tolist()},
        verbose=1,
        kekulize=True,
        atom_feature="symbol",
    )
    return dataset


def _max_num_atom(dataset):
    """Return the number of atoms in the largest molecule held by `dataset`."""
    # Assumes dataset.data holds the parsed molecules (eager loading) — TODO confirm if lazy=True is used.
    return max(int(mol.num_node) for mol in dataset.data)


def _make_engine(task, dataset, lr):
    """Wrap `task` in a TorchDrug Engine trained with Adam on `dataset` (no valid/test split)."""
    optimizer = torch.optim.Adam(task.parameters(), lr=lr)
    # Positional order is (task, train_set, valid_set, test_set, optimizer). No LR scheduler:
    # ReduceLROnPlateau needs a metric on every step() call and has no `metric` constructor
    # argument. batch_size is handed to the Engine, so no separate DataLoader is built.
    return core.Engine(task, dataset, None, None, optimizer,
                       gpus=GPUS, batch_size=BATCH_SIZE, log_interval=10)


if __name__ == "__main__":
    # --- 1-2. Load the CSV and keep chain_size == 0 molecules ---
    print(f"Loading data from {CSV_FILE_PATH} using TorchDrug...")
    training_dataset_torchdrug = _load_monomer_dataset(CSV_FILE_PATH, TARGET_FIELD)
    print(f"Filtered down to {len(training_dataset_torchdrug)} molecules with chain_size == 0.")
    if len(training_dataset_torchdrug) == 0:
        # Raise rather than call exit(): exit() would shut down the notebook kernel.
        raise RuntimeError("No molecules found after filtering.")

    # --- 3. Dimensions for GraphAF (singular attribute names, as in the first attempt) ---
    NUM_ATOM_TYPES = training_dataset_torchdrug.num_atom_type
    NUM_BOND_TYPES = training_dataset_torchdrug.num_bond_type
    MAX_NODES_IN_DATASET = _max_num_atom(training_dataset_torchdrug)
    print(f"Inferred Num Atom Types: {NUM_ATOM_TYPES}")
    print(f"Inferred Num Bond Types: {NUM_BOND_TYPES}")
    print(f"Max Nodes in Filtered Dataset: {MAX_NODES_IN_DATASET}")

    # --- 4. GraphAF: one RGCN backbone shared by a node flow and an edge flow ---
    # Modules are built on CPU, as in the first attempt; the Engine is given the GPU ids.
    rgcn_model = models.RGCN(input_dim=NUM_ATOM_TYPES,
                             num_relation=NUM_BOND_TYPES,
                             hidden_dims=[256, 256, 256],
                             batch_norm=True)
    node_prior = distribution.IndependentGaussian(torch.zeros(NUM_ATOM_TYPES),
                                                  torch.ones(NUM_ATOM_TYPES))
    # num_bond_type + 1 for edge features (bond types + non-edge)
    edge_prior = distribution.IndependentGaussian(torch.zeros(NUM_BOND_TYPES + 1),
                                                  torch.ones(NUM_BOND_TYPES + 1))
    node_flow = models.GraphAF(rgcn_model, node_prior, num_layer=12)
    edge_flow = models.GraphAF(rgcn_model, edge_prior, use_edge=True, num_layer=12)

    # --- 5-6. Unconditional (negative log-likelihood) training ---
    print("\n--- Unconditional Graph Generation Training ---")
    unconditional_task = tasks.AutoregressiveGeneration(node_flow, edge_flow,
                                                        max_node=MAX_NODES_IN_DATASET,
                                                        max_edge_unroll=12,
                                                        criterion="nll")
    unconditional_solver = _make_engine(unconditional_task, training_dataset_torchdrug, lr=1e-3)

    print("\nStarting GraphAF unconditional training...")
    unconditional_solver.train(num_epoch=NUM_EPOCH)
    print("GraphAF unconditional training complete.")

    unconditional_model_path = "graphaf_monomer_unconditional.pkl"
    unconditional_solver.save(unconditional_model_path)
    print(f"Unconditional model saved to {unconditional_model_path}")

    # --- 7-8. Optional property-guided fine-tuning (reinforcement learning) ---
    generation_task = unconditional_task
    if RL_TASK is None:
        print("\nSkipping RL fine-tuning: set RL_TASK to a reward task supported by TorchDrug.")
    else:
        print("\n--- Property-Guided Fine-tuning (Reinforcement Learning) ---")
        rl_task = tasks.AutoregressiveGeneration(node_flow, edge_flow,
                                                max_node=MAX_NODES_IN_DATASET,
                                                max_edge_unroll=12,
                                                task=RL_TASK,
                                                criterion={"ppo": 0.25, "nll": 1.0},  # PPO for reward, NLL to stay near the data
                                                reward_temperature=20,
                                                baseline_momentum=0.9,
                                                agent_update_interval=5,
                                                gamma=0.9)
        # Smaller learning rate for fine-tuning.
        rl_solver = _make_engine(rl_task, training_dataset_torchdrug, lr=1e-5)

        print(f"Loading unconditional model from {unconditional_model_path} for fine-tuning...")
        rl_solver.load(unconditional_model_path, load_optimizer=False)  # weights only, fresh optimizer

        print("\nStarting GraphAF RL fine-tuning...")
        rl_solver.train(num_epoch=NUM_EPOCH)
        print("GraphAF RL fine-tuning complete.")

        finetuned_model_path = "graphaf_monomer_finetuned.pkl"
        rl_solver.save(finetuned_model_path)
        print(f"Fine-tuned model saved to {finetuned_model_path}")
        generation_task = rl_task

    # --- 9. Sample new molecules ---
    # Sampling goes through the task, which owns both flows; the RGCN backbone alone cannot generate.
    num_graphs_to_generate = 5
    print(f"\nGenerating {num_graphs_to_generate} new graphs with GraphAF...")
    generated_molecules = generation_task.generate(num_sample=num_graphs_to_generate)

    for i, mol in enumerate(generated_molecules.unpack(), start=1):
        print(f"\nGenerated Molecule {i}:")
        print(f"  Number of atoms: {int(mol.num_node)}")
        # NOTE(review): TorchDrug may store each bond once per direction — confirm before
        # reading this as a chemical bond count.
        print(f"  Number of bonds: {int(mol.num_edge)}")
        print(f"  SMILES: {mol.to_smiles()}")

