## Install Package

In [None]:
!pip install mlconfgen==0.2.0

## Download the weights from HuggingFace
> https://huggingface.co/Membrizard/ml_conformer_generator

`edm_moi_chembl_15_39.pt`

`adj_mat_seer_chembl_15_39.pt`

Place both files in the same directory as this notebook — the cells below load them from `./`.

## 1. Generate Molecules using a Reference Molecule and evaluate results

Generate novel molecular structures using a reference molecule as a spatial (shape) template. This is useful when you want to discover compounds with a similar 3D shape that may differ chemically. This section uses the PyTorch backend for both generation and evaluation.

In [None]:
import time
import torch

from rdkit import Chem, RDLogger
from rdkit.Chem import Draw

from mlconfgen import MLConformerGenerator, evaluate_samples

# Silence all RDKit log output so it does not flood the cell output.
RDLogger.DisableLog('rdApp.*')

# Prefer a CUDA GPU, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps:0")
else:
    device = torch.device("cpu")

print(f"Initialising model on {device}")

# Weights are loaded from the working directory (see "Download the weights" above).
generator = MLConformerGenerator(
    edm_weights="./edm_moi_chembl_15_39.pt",
    adj_mat_seer_weights="./adj_mat_seer_chembl_15_39.pt",
    device=device,
    diffusion_steps=100,
)

# Load a reference conformer. RDKit returns None (rather than raising) when a
# file cannot be parsed, so fail early with a clear message.
ref_path = './assets/demo_files/ceyyag.mol'
ref_mol = Chem.MolFromMolFile(ref_path)
if ref_mol is None:
    raise ValueError(f"Could not read the reference conformer from {ref_path}")

# Generate Samples
print("Generation started...")
# perf_counter is monotonic, so the elapsed time is unaffected by system clock changes.
start = time.perf_counter()

# Resampling significantly increases generation quality, while sacrificing speed
samples = generator.generate_conformers(
    reference_conformer=ref_mol,
    n_samples=10,
    variance=2,
    resample_steps=4,
)

print(f"Generation complete in {time.perf_counter() - start:.2f} s")

# Characterise samples: only the per-sample records (dicts carrying
# 'mol_block' and 'shape_tanimoto') are used below.
_, std_samples = evaluate_samples(ref_mol, samples)

# Display results
mols = []
legends = []
n_unparsable = 0
for sample in std_samples:
    mol = Chem.MolFromMolBlock(sample['mol_block'])
    if mol is not None:
        # Round-trip through SMILES to drop the 3D coordinates so the grid
        # shows a clean 2D depiction.
        mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    if mol is None:
        # Skip samples RDKit cannot parse instead of crashing on None below.
        n_unparsable += 1
        continue
    mol.SetProp("Shape_Tanimoto", str(sample['shape_tanimoto']))
    mols.append(mol)
    legends.append(f"Shape Similarity - {round(sample['shape_tanimoto'], 2)}")

if n_unparsable:
    print(f"Skipped {n_unparsable} sample(s) that RDKit could not parse.")

Draw.MolsToGridImage(mols, legends=legends)

## 2. Generate Molecules from a Fixed Fragment using a Reference Molecule

### 2.1 Simple fragment-based generation
    
This method integrates the fixed fragment directly into the molecule during the denoising process. Resampling is used to guide and stabilize the generation, helping harmonize the final structure around the injected fragment.

In [None]:
import time
import torch

from rdkit import Chem, RDLogger
from rdkit.Chem import Draw

from mlconfgen import MLConformerGenerator, evaluate_samples

# Silence all RDKit log output so it does not flood the cell output.
RDLogger.DisableLog('rdApp.*')

# Prefer a CUDA GPU, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps:0")
else:
    device = torch.device("cpu")

print(f"Initialising model on {device}")

# Weights are loaded from the working directory (see "Download the weights" above).
generator = MLConformerGenerator(
    edm_weights="./edm_moi_chembl_15_39.pt",
    adj_mat_seer_weights="./adj_mat_seer_chembl_15_39.pt",
    device=device,
    diffusion_steps=100,
)

# Load the reference conformer and the fragment that every generated molecule
# should contain. RDKit returns None (rather than raising) when a file cannot
# be parsed, so fail early with a clear message.
ref_path = './assets/demo_files/yibfeu.mol'
fragment_path = './assets/demo_files/paba.mol'
ref_mol = Chem.MolFromMolFile(ref_path)
if ref_mol is None:
    raise ValueError(f"Could not read the reference conformer from {ref_path}")
fixed_fragment = Chem.MolFromMolFile(fragment_path)
if fixed_fragment is None:
    raise ValueError(f"Could not read the fixed fragment from {fragment_path}")

print("Fixed Fragment:")

# Round-trip through SMILES to drop 3D coordinates and get a clean 2D depiction.
display(Draw.MolToImage(Chem.MolFromSmiles(Chem.MolToSmiles(fixed_fragment))))

# Generate Samples
print("Generation started...")
# perf_counter is monotonic, so the elapsed time is unaffected by system clock changes.
start = time.perf_counter()

# Resampling increases generation quality, while sacrificing speed
samples = generator.generate_conformers(
    reference_conformer=ref_mol,
    n_samples=10,
    variance=2,
    resample_steps=5,
    fixed_fragment=fixed_fragment,
    blend_power=3,
    inertial_fragment_matching=False,  # simple mode: fragment is injected during denoising
)

print(f"Generation complete in {time.perf_counter() - start:.2f} s")

# Characterise samples: only the per-sample records (dicts carrying
# 'mol_block' and 'shape_tanimoto') are used below.
_, std_samples = evaluate_samples(ref_mol, samples)

# Display results
mols = []
legends = []
n_unparsable = 0
for sample in std_samples:
    mol = Chem.MolFromMolBlock(sample['mol_block'])
    if mol is not None:
        # Round-trip through SMILES to drop the 3D coordinates for a 2D depiction.
        mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    if mol is None:
        # Skip samples RDKit cannot parse instead of crashing on None below.
        n_unparsable += 1
        continue
    mol.SetProp("Shape_Tanimoto", str(sample['shape_tanimoto']))
    mols.append(mol)
    legends.append(f"Shape Similarity - {round(sample['shape_tanimoto'], 2)}")

if n_unparsable:
    print(f"Skipped {n_unparsable} sample(s) that RDKit could not parse.")

Draw.MolsToGridImage(mols, legends=legends)


### 2.2 Fragment-Based Generation with Inertial Fragment Matching

In this mode, the Inertial Fragment Matching pipeline uses spatial consistency and shape-descriptor properties to keep the fixed fragment in place while generating compatible surrounding structures. The fixed and generated fragments are then merged together.

**Key options:**

`inertial_fragment_matching=True` : Enables this mode (enabled by default).

`ifm_diffusion_level` : Specifies the diffusion step at which denoising of the merged fragments begins, allowing controlled integration of the fixed fragment into the generated molecule.

In [None]:
import time
import torch

from rdkit import Chem, RDLogger
from rdkit.Chem import Draw

from mlconfgen import MLConformerGenerator, evaluate_samples

# Silence all RDKit log output so it does not flood the cell output.
RDLogger.DisableLog('rdApp.*')

# Prefer a CUDA GPU, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps:0")
else:
    device = torch.device("cpu")

print(f"Initialising model on {device}")

# Weights are loaded from the working directory (see "Download the weights" above).
generator = MLConformerGenerator(
    edm_weights="./edm_moi_chembl_15_39.pt",
    adj_mat_seer_weights="./adj_mat_seer_chembl_15_39.pt",
    device=device,
    diffusion_steps=100,
)

# Load the reference conformer and the fragment that every generated molecule
# should contain. RDKit returns None (rather than raising) when a file cannot
# be parsed, so fail early with a clear message.
ref_path = './assets/demo_files/yibfeu.mol'
fragment_path = './assets/demo_files/frag_yibfeu.mol'
ref_mol = Chem.MolFromMolFile(ref_path)
if ref_mol is None:
    raise ValueError(f"Could not read the reference conformer from {ref_path}")
fixed_fragment = Chem.MolFromMolFile(fragment_path)
if fixed_fragment is None:
    raise ValueError(f"Could not read the fixed fragment from {fragment_path}")

print("Fixed Fragment:")

# Round-trip through SMILES to drop 3D coordinates and get a clean 2D depiction.
display(Draw.MolToImage(Chem.MolFromSmiles(Chem.MolToSmiles(fixed_fragment))))

# Generate Samples
print("Generation started...")
# perf_counter is monotonic, so the elapsed time is unaffected by system clock changes.
start = time.perf_counter()

# Here, lower resampling values tend to give better results
samples = generator.generate_conformers(
    reference_conformer=ref_mol,
    n_samples=10,
    variance=2,
    resample_steps=2,
    fixed_fragment=fixed_fragment,
    blend_power=3,
    inertial_fragment_matching=True,
    # Diffusion step at which denoising of the merged fragments begins.
    ifm_diffusion_level=20,
)

print(f"Generation complete in {time.perf_counter() - start:.2f} s")

# Characterise samples: only the per-sample records (dicts carrying
# 'mol_block' and 'shape_tanimoto') are used below.
_, std_samples = evaluate_samples(ref_mol, samples)

# Display results
mols = []
legends = []
n_unparsable = 0
for sample in std_samples:
    mol = Chem.MolFromMolBlock(sample['mol_block'])
    if mol is not None:
        # Round-trip through SMILES to drop the 3D coordinates for a 2D depiction.
        mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    if mol is None:
        # Skip samples RDKit cannot parse instead of crashing on None below.
        n_unparsable += 1
        continue
    mol.SetProp("Shape_Tanimoto", str(sample['shape_tanimoto']))
    mols.append(mol)
    legends.append(f"Shape Similarity - {round(sample['shape_tanimoto'], 2)}")

if n_unparsable:
    print(f"Skipped {n_unparsable} sample(s) that RDKit could not parse.")

Draw.MolsToGridImage(mols, legends=legends)

## 3. Generate Molecules Using an Arbitrary Shape (e.g., Protein Pocket)

Generate molecules that conform to a custom 3D shape, such as a protein binding pocket. This method enables structure-based design without needing a reference molecule.

**Note:** Requires trimesh >= 4.6.4 for mesh processing and shape representation.

In [None]:
!pip install trimesh==4.6.4

In [None]:
import time
import torch
import trimesh

from rdkit import Chem, RDLogger
from rdkit.Chem import Draw

from mlconfgen import MLConformerGenerator

# Silence all RDKit log output so it does not flood the cell output.
RDLogger.DisableLog('rdApp.*')

# Prefer a CUDA GPU, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps:0")
else:
    device = torch.device("cpu")

print(f"Initialising model on {device}")

# Load example - CLK1 pocket as .stl file
mymesh = trimesh.load('./assets/demo_files/6q8k_pocket.stl')
# Density scales the mesh mass and therefore its inertia values. 0.02 is the
# value used throughout these demos; presumably it matches the scale the model
# expects (TODO confirm).
mymesh.density = 0.02

# The shape must be watertight (a closed volume) for its inertia to be meaningful.
if not mymesh.is_watertight:
    raise ValueError("The .stl file needs to be watertight.")

# The principal moments of inertia of the shape are the generation context.
ref_context = torch.tensor(mymesh.principal_inertia_components, dtype=torch.float32)

# Weights are loaded from the working directory (see "Download the weights" above).
generator = MLConformerGenerator(
    edm_weights="./edm_moi_chembl_15_39.pt",
    adj_mat_seer_weights="./adj_mat_seer_chembl_15_39.pt",
    device=device,
    diffusion_steps=100,
)

# Generate Samples
print("Generation started...")
# perf_counter is monotonic, so the elapsed time is unaffected by system clock changes.
start = time.perf_counter()

# With no reference molecule, the number of atoms must be given explicitly.
samples = generator.generate_conformers(
    reference_context=ref_context,
    n_atoms=38,
    n_samples=10,
    variance=1,
)

print(f"Generation complete in {time.perf_counter() - start:.2f} s")

# Display results (SMILES round-trip drops 3D coordinates for a 2D depiction)
mols = [Chem.MolFromSmiles(Chem.MolToSmiles(x)) for x in samples]

Draw.MolsToGridImage(mols)

## 4. Generate Molecules from a Fixed Fragment Using an Arbitrary Shape (e.g., Protein Pocket)
Combine a fixed chemical fragment with a custom 3D shape—such as a protein binding pocket—to guide molecular generation. This allows for fragment-based design within spatial constraints defined by the target environment.

**Note:** Requires trimesh >= 4.6.4 for mesh handling and shape integration.

### 4.1 Simple fragment-based generation (Arbitrary Shape)

In [None]:
import time
import torch
import trimesh

from rdkit import Chem, RDLogger
from rdkit.Chem import Draw

from mlconfgen import MLConformerGenerator

# Silence all RDKit log output so it does not flood the cell output.
RDLogger.DisableLog('rdApp.*')

# Prefer a CUDA GPU, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps:0")
else:
    device = torch.device("cpu")

print(f"Initialising model on {device}")

# Fragment that every generated molecule should contain. RDKit returns None
# (rather than raising) when a file cannot be parsed, so fail early.
fragment_path = './assets/demo_files/paba.mol'
fixed_fragment = Chem.MolFromMolFile(fragment_path)
if fixed_fragment is None:
    raise ValueError(f"Could not read the fixed fragment from {fragment_path}")

# Load example - CLK1 pocket as .stl file
mymesh = trimesh.load('./assets/demo_files/6q8k_pocket.stl')
# Density scales the mesh mass and therefore its inertia values. 0.02 is the
# value used throughout these demos; presumably it matches the scale the model
# expects (TODO confirm).
mymesh.density = 0.02

# The shape must be watertight (a closed volume) for its inertia to be meaningful.
if not mymesh.is_watertight:
    raise ValueError("The .stl file needs to be watertight.")

# The principal moments of inertia of the shape are the generation context.
ref_context = torch.tensor(mymesh.principal_inertia_components, dtype=torch.float32)

# Weights are loaded from the working directory (see "Download the weights" above).
generator = MLConformerGenerator(
    edm_weights="./edm_moi_chembl_15_39.pt",
    adj_mat_seer_weights="./adj_mat_seer_chembl_15_39.pt",
    device=device,
    diffusion_steps=100,
)

print("Fixed Fragment:")

# Round-trip through SMILES to drop 3D coordinates and get a clean 2D depiction.
display(Draw.MolToImage(Chem.MolFromSmiles(Chem.MolToSmiles(fixed_fragment))))

# Generate Samples
print("Generation started...")
# perf_counter is monotonic, so the elapsed time is unaffected by system clock changes.
start = time.perf_counter()
samples = generator.generate_conformers(
    reference_context=ref_context,
    n_atoms=38,
    n_samples=10,
    variance=1,
    resample_steps=2,
    fixed_fragment=fixed_fragment,
    blend_power=3,
    inertial_fragment_matching=False,  # simple mode: fragment is injected during denoising
)

print(f"Generation complete in {time.perf_counter() - start:.2f} s")

# Display results
mols = [Chem.MolFromSmiles(Chem.MolToSmiles(x)) for x in samples]

Draw.MolsToGridImage(mols)

### 4.2 Inertial Fragment Matching (Arbitrary Shape)

In [None]:
import time
import torch
import trimesh

from rdkit import Chem, RDLogger
from rdkit.Chem import Draw

from mlconfgen import MLConformerGenerator

# Silence all RDKit log output so it does not flood the cell output.
RDLogger.DisableLog('rdApp.*')

# Prefer a CUDA GPU, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps:0")
else:
    device = torch.device("cpu")

print(f"Initialising model on {device}")

# Fragment that every generated molecule should contain. RDKit returns None
# (rather than raising) when a file cannot be parsed, so fail early.
fragment_path = './assets/demo_files/frag_yibfeu.mol'
fixed_fragment = Chem.MolFromMolFile(fragment_path)
if fixed_fragment is None:
    raise ValueError(f"Could not read the fixed fragment from {fragment_path}")

# Load example - CLK1 pocket as .stl file
mymesh = trimesh.load('./assets/demo_files/6q8k_pocket.stl')
# Density scales the mesh mass and therefore its inertia values. 0.02 is the
# value used throughout these demos; presumably it matches the scale the model
# expects (TODO confirm).
mymesh.density = 0.02

# The shape must be watertight (a closed volume) for its inertia to be meaningful.
if not mymesh.is_watertight:
    raise ValueError("The .stl file needs to be watertight.")

# The principal moments of inertia of the shape are the generation context.
ref_context = torch.tensor(mymesh.principal_inertia_components, dtype=torch.float32)

# Weights are loaded from the working directory (see "Download the weights" above).
generator = MLConformerGenerator(
    edm_weights="./edm_moi_chembl_15_39.pt",
    adj_mat_seer_weights="./adj_mat_seer_chembl_15_39.pt",
    device=device,
    diffusion_steps=100,
)

print("Fixed Fragment:")

# Round-trip through SMILES to drop 3D coordinates and get a clean 2D depiction.
display(Draw.MolToImage(Chem.MolFromSmiles(Chem.MolToSmiles(fixed_fragment))))

# Generate Samples
print("Generation started...")
# perf_counter is monotonic, so the elapsed time is unaffected by system clock changes.
start = time.perf_counter()
samples = generator.generate_conformers(
    reference_context=ref_context,
    n_atoms=38,
    n_samples=10,
    variance=1,
    resample_steps=2,
    fixed_fragment=fixed_fragment,
    blend_power=3,
    inertial_fragment_matching=True,
)

print(f"Generation complete in {time.perf_counter() - start:.2f} s")

# Display results
mols = [Chem.MolFromSmiles(Chem.MolToSmiles(x)) for x in samples]

Draw.MolsToGridImage(mols)

## 5. Generate Molecules using a Reference Molecule (ONNX)

Run molecular generation with a reference molecule using the ONNX runtime for efficient, hardware-agnostic inference—ideal for deployment or environments without native PyTorch support.

While only reference molecule-based examples are shown here, ONNX supports all generation modes—including those based on arbitrary shapes—identically to the PyTorch implementation.

Additional dependencies required:

>onnx==1.17.0

>onnxruntime==1.21.0


### Download the weights in the ONNX format from HuggingFace
> https://huggingface.co/Membrizard/ml_conformer_generator

`egnn_chembl_15_39.onnx`

`adj_mat_seer_chembl_15_39.onnx`

Place both files in the same directory as this notebook — the cells below load them from `./`.

In [None]:
import time

from rdkit import Chem, RDLogger
from rdkit.Chem import Draw

from mlconfgen import MLConformerGeneratorONNX

# Silence all RDKit log output so it does not flood the cell output.
RDLogger.DisableLog('rdApp.*')

# ONNX weights are loaded from the working directory (see the download note above).
generator = MLConformerGeneratorONNX(
    egnn_onnx="./egnn_chembl_15_39.onnx",
    adj_mat_seer_onnx="./adj_mat_seer_chembl_15_39.onnx",
)

# Load a reference conformer. RDKit returns None (rather than raising) when a
# file cannot be parsed, so fail early with a clear message.
ref_path = './assets/demo_files/yibfeu.mol'
ref_mol = Chem.MolFromMolFile(ref_path)
if ref_mol is None:
    raise ValueError(f"Could not read the reference conformer from {ref_path}")

# Generate Samples
print("Generation started...")
# perf_counter is monotonic, so the elapsed time is unaffected by system clock changes.
start = time.perf_counter()

samples = generator.generate_conformers(
    reference_conformer=ref_mol,
    n_samples=10,
    variance=2,
    resample_steps=2,
)

print(f"Generation complete in {time.perf_counter() - start:.2f} s")

# Display results (SMILES round-trip drops 3D coordinates for a 2D depiction)
mols = [Chem.MolFromSmiles(Chem.MolToSmiles(x)) for x in samples]

Draw.MolsToGridImage(mols)

## 6. Generate Molecules from a Fixed Fragment using a Reference Molecule  (ONNX)

Perform fixed-fragment-based molecular generation with a reference molecule using the ONNX runtime for fast, portable, and hardware-agnostic inference. This mode replicates the PyTorch implementation 1:1, ensuring consistent results across both backends.

Additional dependencies required:

>onnx==1.17.0

>onnxruntime==1.21.0

### 6.1 Simple fragment-based generation (ONNX)

In [None]:
import time

from rdkit import Chem, RDLogger
from rdkit.Chem import Draw

from mlconfgen import MLConformerGeneratorONNX

# Silence all RDKit log output so it does not flood the cell output.
RDLogger.DisableLog('rdApp.*')

# ONNX weights are loaded from the working directory (see the download note above).
generator = MLConformerGeneratorONNX(
    egnn_onnx="./egnn_chembl_15_39.onnx",
    adj_mat_seer_onnx="./adj_mat_seer_chembl_15_39.onnx",
)

# Load the reference conformer and the fragment that every generated molecule
# should contain. RDKit returns None (rather than raising) when a file cannot
# be parsed, so fail early with a clear message.
ref_path = './assets/demo_files/yibfeu.mol'
fragment_path = './assets/demo_files/paba.mol'
ref_mol = Chem.MolFromMolFile(ref_path)
if ref_mol is None:
    raise ValueError(f"Could not read the reference conformer from {ref_path}")
fixed_fragment = Chem.MolFromMolFile(fragment_path)
if fixed_fragment is None:
    raise ValueError(f"Could not read the fixed fragment from {fragment_path}")

print("Fixed Fragment:")

# Round-trip through SMILES to drop 3D coordinates and get a clean 2D depiction.
display(Draw.MolToImage(Chem.MolFromSmiles(Chem.MolToSmiles(fixed_fragment))))

# Generate Samples
print("Generation started...")
# perf_counter is monotonic, so the elapsed time is unaffected by system clock changes.
start = time.perf_counter()
samples = generator.generate_conformers(
    reference_conformer=ref_mol,
    n_samples=10,
    variance=2,
    resample_steps=2,
    fixed_fragment=fixed_fragment,
    blend_power=3,
    inertial_fragment_matching=False,  # simple mode: fragment is injected during denoising
)

print(f"Generation complete in {time.perf_counter() - start:.2f} s")

# Display results
mols = [Chem.MolFromSmiles(Chem.MolToSmiles(x)) for x in samples]

Draw.MolsToGridImage(mols)

### 6.2 Inertial Fragment Matching (ONNX)

In [None]:
import time

from rdkit import Chem, RDLogger
from rdkit.Chem import Draw

from mlconfgen import MLConformerGeneratorONNX

# Silence all RDKit log output so it does not flood the cell output.
RDLogger.DisableLog('rdApp.*')

# ONNX weights are loaded from the working directory (see the download note above).
generator = MLConformerGeneratorONNX(
    egnn_onnx="./egnn_chembl_15_39.onnx",
    adj_mat_seer_onnx="./adj_mat_seer_chembl_15_39.onnx",
)

# Load the reference conformer and the fragment that every generated molecule
# should contain. RDKit returns None (rather than raising) when a file cannot
# be parsed, so fail early with a clear message.
ref_path = './assets/demo_files/yibfeu.mol'
fragment_path = './assets/demo_files/frag_yibfeu.mol'
ref_mol = Chem.MolFromMolFile(ref_path)
if ref_mol is None:
    raise ValueError(f"Could not read the reference conformer from {ref_path}")
fixed_fragment = Chem.MolFromMolFile(fragment_path)
if fixed_fragment is None:
    raise ValueError(f"Could not read the fixed fragment from {fragment_path}")

print("Fixed Fragment:")

# Round-trip through SMILES to drop 3D coordinates and get a clean 2D depiction.
display(Draw.MolToImage(Chem.MolFromSmiles(Chem.MolToSmiles(fixed_fragment))))

# Generate Samples
print("Generation started...")
# perf_counter is monotonic, so the elapsed time is unaffected by system clock changes.
start = time.perf_counter()
samples = generator.generate_conformers(
    reference_conformer=ref_mol,
    n_samples=10,
    variance=2,
    resample_steps=2,
    fixed_fragment=fixed_fragment,
    blend_power=3,
    inertial_fragment_matching=True,
    # Diffusion step at which denoising of the merged fragments begins.
    ifm_diffusion_level=30,
)

print(f"Generation complete in {time.perf_counter() - start:.2f} s")

# Display results
mols = [Chem.MolFromSmiles(Chem.MolToSmiles(x)) for x in samples]

Draw.MolsToGridImage(mols)

## 7. Export MLConfGen Model to ONNX
Convert a trained MLConfGen model to the ONNX format for optimized, framework-independent inference.
This enables deployment in environments where PyTorch is not available or where faster runtime performance is desired (e.g., for production or cross-platform compatibility).

In [None]:
from mlconfgen import MLConformerGenerator
from onnx_export import export_to_onnx

# Build the PyTorch generator from the downloaded .pt weights
# (read from the working directory); no device is passed here.
generator = MLConformerGenerator(
    edm_weights="./edm_moi_chembl_15_39.pt",
    adj_mat_seer_weights="./adj_mat_seer_chembl_15_39.pt",
    diffusion_steps=100,
)

# Convert the PyTorch model to the ONNX format.
export_to_onnx(model=generator)