# Molecule Generation and Validation

The model does not output a molecule directly. Instead, it produces two tensors: one holding the final denoised atom features and one holding the 3D atomic coordinates. To turn these into a molecule, we follow a two-step process:

1. Feature-to-Atom Conversion: We use a simple lookup table to convert the final atom feature vectors into atom types (e.g., C, N, O).
2. Coordinates-to-Bonds Conversion: We infer the bonds between atoms from the distances between their 3D coordinates. If two atoms lie within a plausible bonding distance of each other, we assume they are bonded. This is a crucial, heuristic step that translates the continuous output into a discrete chemical graph.

In [None]:
import torch
import os
import logging
from rdkit import Chem
from typing import List

# Import the refactored components
from mol_diff_3d.models.diffusion import MolecularDiffusionModel
from mol_diff_3d.models.noise_scheduler import NoiseScheduler
from mol_diff_3d.sampling.samplers import DDPMPsampler
from mol_diff_3d.generation.generator import generate_molecules_from_model
from mol_diff_3d.models.property_predictor import MolecularPropertyPredictor
from mol_diff_3d.utils.checkpoints import load_checkpoint
from mol_diff_3d.utils.visualization import visualize_molecule
from mol_diff_3d.utils.molecular import validate_molecule, calculate_molecular_properties
from mol_diff_3d.utils.metrics import calculate_validity_rate, calculate_uniqueness_rate

# Configure root logging at INFO level, then pick the compute device:
# CUDA when PyTorch reports it as available, otherwise the CPU.
logging.basicConfig(level=logging.INFO)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Model / schedule hyperparameters used to rebuild the network before
# loading its weights.
# NOTE(review): these presumably have to match the values used when the
# checkpoint was trained -- confirm against the training notebook.
config = {
    'atom_dim': 11, # from previous notebook
    'pos_dim': 3,
    'hidden_dim': 128,
    'time_dim': 128,
    'num_timesteps': 1000
}

# Build the diffusion model and the noise scheduler on the selected device.
model = MolecularDiffusionModel(
    atom_dim=config['atom_dim'],
    pos_dim=config['pos_dim'],
    hidden_dim=config['hidden_dim'],
    time_dim=config['time_dim']
).to(device)
noise_scheduler = NoiseScheduler(num_timesteps=config['num_timesteps']).to(device)

# Restore the trained weights; fail fast with a clear message when the
# checkpoint file is missing.
checkpoint_path = "checkpoints/mol_diff_final.pth"
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}. Please run the training notebook first.")
checkpoint_info = load_checkpoint(checkpoint_path, model)
print("Model successfully loaded from checkpoint.")

# This notebook only samples from the model, so put it in evaluation mode:
# layers such as dropout or batch normalization behave differently while
# training, and a freshly constructed module starts in training mode.
model.eval()

# Initialize the reverse sampler from the scheduler's parameters.
p_sampler = DDPMPsampler(noise_scheduler.get_parameters())

In [None]:
# How many molecules to request and the atom budget passed to the generator.
num_to_generate, max_atoms = 10, 25

# Collect the generator arguments in one place, then run the sampling.
# The call returns the generated molecule objects and their SMILES strings.
generation_kwargs = {
    "model": model,
    "p_sampler": p_sampler,
    "num_molecules": num_to_generate,
    "max_atoms": max_atoms,
    "atom_dim": config['atom_dim'],
    "pos_dim": config['pos_dim'],
    "device": device,
}
generated_mols, generated_smiles = generate_molecules_from_model(**generation_kwargs)

# Report how many of the requested molecules came back.
print(f"\nAttempted to generate {num_to_generate} molecules.")
print(f"Successfully generated {len(generated_smiles)} valid molecules.")

# Basic quality metrics over the returned SMILES strings.
# NOTE(review): `generated_smiles` looks like it already contains only the
# molecules that passed validation; if so, the validity rate computed from it
# says little about the sampler, and len(generated_smiles) / num_to_generate
# may be the more meaningful number -- confirm against the metrics module.
validity_rate = calculate_validity_rate(generated_smiles)
uniqueness_rate = calculate_uniqueness_rate(generated_smiles)
print(f"Validity rate: {validity_rate:.2f}")
print(f"Uniqueness rate: {uniqueness_rate:.2f}")

In [None]:
def _report_generated_molecule(i, mol, smiles):
    """Print, draw and summarise one generated molecule.

    `i` is the zero-based position of the molecule in the preview, `mol` the
    generated molecule object and `smiles` its SMILES string.
    """
    print(f"Generated Molecule {i+1}: {smiles}")

    # 2D depiction from the SMILES string.
    visualize_molecule(smiles)

    # Properties come from the refactored helper and are printed one per line.
    properties = calculate_molecular_properties(mol)
    print("Calculated Properties:")
    for prop, value in properties.items():
        print(f"  - {prop}: {value:.2f}")


if generated_mols:
    print("\n--- Visualizing a few generated molecules ---")
    # Only the first three molecules are shown.
    for i, mol in enumerate(generated_mols[:3]):
        _report_generated_molecule(i, mol, generated_smiles[i])

        # The 3D structure itself would be inspected with a dedicated viewer
        # (e.g., PyMOL, or NGLview in a browser-based notebook); this cell is
        # a conceptual demonstration only.

## Molecular Property Analysis

In addition to standard metrics like validity and uniqueness, we can now assess the quality of the generated 3D structures. For example, we can use RDKit's built-in functions to calculate properties from the generated molecules and compare them to the original dataset. We could check:

1. Molecular Weight: How does the weight distribution of generated molecules compare to the training set?
2. LogP: Is the hydrophobicity of the generated molecules reasonable?
3. Bond Lengths: Are the inferred bond lengths chemically plausible?



In [None]:
# Conceptual demo: build a property predictor on top of the same GNN backbone
# the diffusion model uses. No trained predictor weights are loaded here, so
# nothing is actually predicted -- the cell only shows how the pieces fit.
# Any failure is reported by the handler at the bottom rather than aborting
# the notebook.
try:
    # Imported inside the try so that a missing module is reported by the
    # handler below. The package is `mol_diff_3d`, matching every other import
    # in this notebook.
    from mol_diff_3d.models.gnn import E3GNN

    # Instantiate the GNN used by the diffusion model
    dummy_gnn = E3GNN(
        in_feat_dim=config['atom_dim'],
        pos_dim=config['pos_dim'],
        hidden_dim=config['hidden_dim'],
        out_feat_dim=config['hidden_dim']
    )

    # Create the property predictor using the same GNN
    # Note: A real implementation would load trained weights for this predictor
    predictor = MolecularPropertyPredictor(gnn=dummy_gnn, num_tasks=1).to(device)

    # Pick a generated molecule to predict a property for (e.g., LogP)
    if generated_mols:
        sample_mol = generated_mols[0]
        sample_smiles = generated_smiles[0]

        print(f"\n--- Predicting a property for '{sample_smiles}' ---")

        # A real pipeline would convert the RDKit Mol into a PyG Data object
        # with the same featurizer used during training. Here we only pull the
        # atomic coordinates from the molecule's conformer to show where the
        # 3D input would come from; they are not used further.
        sample_pos = sample_mol.GetConformer().GetPositions()

        # NOTE: A more robust featurizer would be needed here to match the training pipeline.

        # For demonstration purposes, we'll just print a placeholder message.
        print("This part of the code requires a trained Property Predictor and a Mol-to-PyG featurizer.")
        print("Once implemented, you would call `predictor(batch.x, batch.edge_index, batch.pos, batch.batch)`")
        print("and the predictor would use the GNN to infer a property.")

except Exception as e:
    # Top-level boundary of a best-effort demo: report the problem and move on.
    print(f"Could not demonstrate property prediction. Error: {e}")