<a href="https://colab.research.google.com/github/apoorvapu/data_science/blob/main/molecule_generation_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [96]:
!pip install torch torchvision diffusers transformers



In [97]:
!pip install rdkit pubchempy tqdm py3Dmol



## download 1000 molecules 3D .sdf files from pubchem and store in a molecules directory

In [98]:
import os
import requests
from tqdm import tqdm

# Make sure the folder exists
os.makedirs("molecules", exist_ok=True)

downloaded = 0
target = 1000
cid = 1

while downloaded < target and cid < 10000:
    try:
        url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/SDF?record_type=3d"
        response = requests.get(url, timeout=10)
        if response.status_code == 200 and len(response.text) > 100:
            file_path = f"molecules/mol_{cid}.sdf"
            with open(file_path, "w") as f:
                f.write(response.text)
            downloaded += 1
        cid += 1
    except Exception as e:
        print(f"CID {cid} failed: {e}")
        cid += 1


In [99]:
ls molecules/*.sdf | wc -l


1048


In [100]:
import os
from rdkit import Chem
from rdkit.Chem import AllChem
import py3Dmol
from IPython.display import display

def view_mol(mol):
    # Generate 3D coordinates if missing
    if mol.GetNumConformers() == 0:
        AllChem.EmbedMolecule(mol)
    mb = Chem.MolToMolBlock(mol)
    viewer = py3Dmol.view(width=400, height=350)
    viewer.addModel(mb, 'mol')
    viewer.setStyle({'stick': {}})
    viewer.setBackgroundColor('white')
    viewer.zoomTo()
    return viewer

# List SDF files
sdf_dir = "molecules"
sdf_files = sorted([f for f in os.listdir(sdf_dir) if f.endswith(".sdf")])

# Visualize first 5 molecules (change range as needed)
for sdf_file in sdf_files[:5]:
    sdf_path = os.path.join(sdf_dir, sdf_file)
    mol = Chem.MolFromMolFile(sdf_path, removeHs=False)
    if mol:
        print(f"Showing: {sdf_file}")
        view_mol(mol).show()


Showing: mol_1.sdf


Showing: mol_10.sdf


Showing: mol_1000.sdf


Showing: mol_1001.sdf


Showing: mol_1002.sdf


# create padding based on maximum number of atoms in a molecule

In [101]:
# Find maximum number of atoms across all molecules
max_atoms = max(len(coords) for coords in coords_list)
max_atoms

57

In [102]:
# Pad coordinates and atom types
padded_coords = []
padded_atom_types = []

for coords, atom_types in zip(coords_list, atom_types_list):
    # Pad coordinates (with zeros)
    padded_coords.append(np.pad(coords, ((0, max_atoms - len(coords)), (0, 0)), mode='constant'))

    # Pad atom types (with placeholder 'X')
    padded_atom_types.append(atom_types + ['X'] * (max_atoms - len(atom_types)))

coords_array = np.array(padded_coords)
atom_types_array = np.array(padded_atom_types)

print(f"Shape of coords_array: {coords_array.shape}")
print(f"Shape of atom_types_array: {atom_types_array.shape}")


Shape of coords_array: (100, 57, 3)
Shape of atom_types_array: (100, 57)


In [103]:
atom_types_array

array([['O', 'O', 'O', ..., 'X', 'X', 'X'],
       ['O', 'O', 'O', ..., 'X', 'X', 'X'],
       ['O', 'O', 'O', ..., 'X', 'X', 'X'],
       ...,
       ['O', 'O', 'N', ..., 'X', 'X', 'X'],
       ['O', 'O', 'C', ..., 'X', 'X', 'X'],
       ['O', 'O', 'O', ..., 'X', 'X', 'X']], dtype='<U2')

In [104]:
coords_array

array([[[-2.1417,  1.0315,  0.7136],
        [ 0.4877,  1.4813, -0.3153],
        [ 2.8465,  0.6297, -0.0278],
        ...,
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ]],

       [[-0.8007,  3.1719, -0.3307],
        [ 1.9377,  2.1145, -0.2359],
        [-2.2221, -2.4151, -1.7057],
        ...,
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ]],

       [[-1.3259, -1.1598, -0.3674],
        [ 3.2938,  2.0525,  0.2706],
        [ 1.1787,  2.8738,  0.5567],
        ...,
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ]],

       ...,

       [[ 2.5311,  1.0877,  1.1888],
        [-2.5344, -1.0904, -1.1875],
        [ 1.4397,  1.8125, -0.7268],
        ...,
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ]],

       [[ 2.923 ,  0.6054, -0.7218],
  

# convert absolute coordinates to relative coordinates of each molecule so that invariant to translation, rotation, etc.

In [105]:
# Atomic weights in g/mol (for common atoms)
atomic_weights = {
    'H': 1.008,
    'C': 12.011,
    'N': 14.007,
    'O': 15.999,
    'S': 32.06,
    'P': 30.974,
    'Cl': 35.45,
    'Br': 79.904,
    'I': 126.904,
    'F': 18.998,
    # Add more atoms as needed
}


In [106]:
import numpy as np

def compute_center_of_mass(coords, atom_types, atomic_weights):
    # Initialize the total mass and weighted sum of coordinates
    total_mass = 0.0
    weighted_coords = np.zeros(3)

    # Sum over each atom's contribution to the center of mass
    for i in range(len(coords)):
        atom_weight = atomic_weights.get(atom_types[i], 0)  # Default to 0 if atom type is not found
        total_mass += atom_weight
        weighted_coords += np.array(coords[i]) * atom_weight

    # Compute the center of mass
    center_of_mass = weighted_coords / total_mass
    return center_of_mass

# Example for a single molecule (coords and atom_types are lists for the molecule)
coords = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # Example coordinates
atom_types = ['C', 'H', 'H']  # Example atom types

# Compute the center of mass
com = compute_center_of_mass(coords, atom_types, atomic_weights)
print("Center of Mass:", com)


Center of Mass: [0.07186141 0.07186141 0.        ]


In [107]:
def convert_to_relative_coordinates(coords, com):
    # Subtract the center of mass from each atom's coordinates
    relative_coords = np.array(coords) - com
    return relative_coords

# Convert to relative coordinates for the molecule
relative_coords = convert_to_relative_coordinates(coords, com)
print("Relative Coordinates:", relative_coords)


Relative Coordinates: [[-0.07186141 -0.07186141  0.        ]
 [ 0.92813859 -0.07186141  0.        ]
 [-0.07186141  0.92813859  0.        ]]


In [108]:
relative_coords_list = []

for i in range(coords_array.shape[0]):  # Iterate over each molecule
    coords = coords_array[i]
    atom_types = atom_types_array[i]

    # Remove padding atoms for COM calculation (ensure atom_types is a list)
    atom_types = atom_types.tolist()  # Convert numpy array to list if needed

    # Identify the indices of valid atoms (i.e., not 'X')
    valid_indices = [i for i, atom in enumerate(atom_types) if atom != 'X']

    # Get non-padding coordinates and atom types
    non_padding_coords = coords[valid_indices]
    non_padding_atom_types = [atom_types[i] for i in valid_indices]

    # Compute the center of mass
    com = compute_center_of_mass(non_padding_coords, non_padding_atom_types, atomic_weights)

    # Convert to relative coordinates
    relative_coords = convert_to_relative_coordinates(coords, com)

    # Append the relative coordinates for the current molecule
    relative_coords_list.append(relative_coords)

# Convert relative coordinates list to numpy array
relative_coords_array = np.array(relative_coords_list)

print(f"Shape of relative_coords_array: {relative_coords_array.shape}")

Shape of relative_coords_array: (100, 57, 3)


In [109]:
import torch
import torch.nn as nn
from diffusers import DDPMScheduler, UNet2DModel
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
import random

# Example sizes and hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 4
learning_rate = 1e-4
num_epochs = 10
num_molecules = 1000  # Adjust as necessary
max_atoms = 50  # Maximum number of atoms in a molecule, adjust accordingly

# Dataset class to load relative coordinates and atom types
class MoleculeDataset(Dataset):
    def __init__(self, coords_array, atom_types_array):
        self.coords = coords_array
        self.atom_types = atom_types_array

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        coords = self.coords[idx]
        atom_types = self.atom_types[idx]

        # Convert atom types to one-hot encoding (you can also use embedding layer later)
        atom_type_encoding = np.zeros(len(atom_types))  # Placeholder for encoding
        for i, atom in enumerate(atom_types):
            atom_type_encoding[i] = random.randint(0, 1)  # Example encoding, you can extend for all atoms

        return torch.tensor(coords, dtype=torch.float32), torch.tensor(atom_type_encoding, dtype=torch.float32)

# Initialize dataset and dataloaders
dataset = MoleculeDataset(relative_coords_array, atom_types_array)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the U-Net model (simplified for this example)
class DiffusionUNet(nn.Module):
    def __init__(self, in_channels, out_channels, time_embed_dim=128):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_embed_dim),
            nn.ReLU(),
            nn.Linear(time_embed_dim, in_channels)
        )
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, 3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv1d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, out_channels, 3, padding=1),
        )

    def forward(self, x, t):
        # t: (batch,) → (batch, 1)
        t_embed = self.time_mlp(t[:, None]).unsqueeze(-1)  # (batch, in_channels, 1)
        x = x + t_embed  # Broadcast to all atoms
        return self.decoder(self.encoder(x))


In [110]:
# Initialize the U-Net model for diffusion
model = DiffusionUNet(in_channels=3, out_channels=3).to(device)  # We are working with 3D coordinates
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Initialize DDPMScheduler (this defines the diffusion process)
scheduler = DDPMScheduler(num_train_timesteps=1000)

# Training loop
# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        coords, atom_types = batch
        coords, atom_types = coords.to(device), atom_types.to(device)

        # Sample random timesteps and noise
        timesteps = torch.randint(0, scheduler.num_train_timesteps, (coords.shape[0],), device=device).long()
        noise = torch.randn_like(coords)

        # Add noise to the original coordinates
        noisy_coords = scheduler.add_noise(original_samples=coords, noise=noise, timesteps=timesteps)

        # Predict noise using the model
        model_input = noisy_coords.permute(0, 2, 1)  # (batch_size, 3, max_atoms) for Conv1D
        predicted_noise = model(model_input, timesteps.float()).permute(0, 2, 1)  # Back to (batch_size, max_atoms, 3)

        # Loss: compare predicted noise with true noise
        loss = torch.mean((predicted_noise - noise) ** 2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss / len(train_loader):.6f}")




  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
Epoch 1/10: 100%|██████████| 25/25 [00:00<00:00, 66.49it/s]


Epoch 1/10 - Loss: 11.839390


Epoch 2/10: 100%|██████████| 25/25 [00:00<00:00, 63.58it/s]


Epoch 2/10 - Loss: 1.448439


Epoch 3/10: 100%|██████████| 25/25 [00:00<00:00, 54.07it/s]


Epoch 3/10 - Loss: 1.059202


Epoch 4/10: 100%|██████████| 25/25 [00:00<00:00, 75.61it/s]


Epoch 4/10 - Loss: 1.019744


Epoch 5/10: 100%|██████████| 25/25 [00:00<00:00, 61.62it/s]


Epoch 5/10 - Loss: 1.025103


Epoch 6/10: 100%|██████████| 25/25 [00:00<00:00, 53.97it/s]


Epoch 6/10 - Loss: 1.005381


Epoch 7/10: 100%|██████████| 25/25 [00:00<00:00, 47.93it/s]


Epoch 7/10 - Loss: 1.002916


Epoch 8/10: 100%|██████████| 25/25 [00:00<00:00, 65.24it/s]


Epoch 8/10 - Loss: 1.010818


Epoch 9/10: 100%|██████████| 25/25 [00:00<00:00, 49.44it/s]


Epoch 9/10 - Loss: 0.996224


Epoch 10/10: 100%|██████████| 25/25 [00:00<00:00, 54.32it/s]

Epoch 10/10 - Loss: 0.998601





In [111]:

import py3Dmol
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Geometry import Point3D

def visualize_mol(mol):
    mb = Chem.MolToMolBlock(mol)
    viewer = py3Dmol.view(width=500, height=400)
    viewer.addModel(mb, 'mol')
    viewer.setStyle({'stick': {}})
    viewer.setBackgroundColor('white')
    viewer.zoomTo()
    return viewer  # <-- Return the viewer instead of .show()



In [112]:
from rdkit import Chem
from rdkit.Chem import AllChem

def coords_to_rdkit_mol(coords, atom_types):
    mol = Chem.RWMol()
    atom_ids = []

    for atom_symbol in atom_types:
        if atom_symbol == 'X':  # Skip padding atoms
            continue
        atom = Chem.Atom(atom_symbol)
        atom_id = mol.AddAtom(atom)
        atom_ids.append(atom_id)

    conf = Chem.Conformer(len(atom_ids))
    for i, (x, y, z) in enumerate(coords[:len(atom_ids)]):
        conf.SetAtomPosition(i, (float(x), float(y), float(z)))
    mol.AddConformer(conf)

    # Optionally, sanitize
    try:
        Chem.SanitizeMol(mol)
    except:
        pass  # Might fail for rough initial coords

    return mol


In [115]:
# Generate new molecules after training
from IPython.display import display

model.eval()
with torch.no_grad():
    x = torch.randn((1, 3, max_atoms), device=device)  # Initial noise

    T = 1000  # Number of diffusion steps (should match training)
    betas = torch.linspace(1e-4, 0.02, T).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    for t in reversed(range(T)):
        t_tensor = torch.full((1,), t, device=device, dtype=torch.float32)
        noise_pred = model(x, t_tensor)  # Predict noise at time t

        alpha = alphas[t]
        alpha_cum = alphas_cumprod[t]
        beta = betas[t]

        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = 0

        x = (1 / torch.sqrt(alpha)) * (
            x - ((1 - alpha) / torch.sqrt(1 - alpha_cum)) * noise_pred
        ) + torch.sqrt(beta) * noise

    generated_coords = x.permute(0, 2, 1).squeeze(0).cpu().numpy()  # (max_atoms, 3)
    generated_atom_types = ['C'] * max_atoms  # Simplified for now






    mol = coords_to_rdkit_mol(generated_coords, generated_atom_types)
    viewer = visualize_mol(mol)
    display(viewer.show())

    #img = Chem.Draw.MolToImage(mol, size=(400, 300))
    #img.show()


None

In [116]:
T = scheduler.num_train_timesteps  # Number of diffusion steps
x = torch.randn((1, max_atoms, 3), device=device)  # Random noisy 3D positions
random_coords = x.squeeze(0).cpu().numpy()
random_atom_types = ['C'] * max_atoms  # Assume all atoms are Carbon

mol_random = coords_to_rdkit_mol(random_coords, random_atom_types)
viewer_random = visualize_mol(mol_random)
display(viewer_random.show())


  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


None

In [117]:
model.eval()
with torch.no_grad():
    for t in reversed(range(T)):
        t_tensor = torch.full((1,), t, device=device, dtype=torch.float32)
        noise_pred = model(x.permute(0, 2, 1), t_tensor).permute(0, 2, 1)

        # Update using DDPM reverse formula
        x = scheduler.step(noise_pred, t, x).prev_sample
        if t % 200 == 0:
            coords_t = x.squeeze(0).cpu().numpy()
            mol_t = coords_to_rdkit_mol(coords_t, random_atom_types)
            viewer_t = visualize_mol(mol_t)
            display(viewer_t.show())


None

None

None

None

None

In [118]:
final_coords = x.squeeze(0).cpu().numpy()
mol = coords_to_rdkit_mol(final_coords, random_atom_types)

viewer = visualize_mol(mol)
display(viewer.show())

# Optional 2D image
img = Chem.Draw.MolToImage(mol, size=(400, 300))
img.show()


None