# Test

In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GraphUNet
from diffusers import DDPMScheduler
from rdkit import Chem
from rdkit.Chem import QED
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm
import pandas as pd

# Allowed atoms and bonds
allowed_atoms = [1, 6, 7, 8, 9, 11, 15, 16, 17, 19, 30, 35, 53]
allowed_bonds = [0, 1, 2, 3]  # Include 0 for no bond
atom_type_to_idx = {z: i for i, z in enumerate(allowed_atoms)}

# Dataset
class CustomDataset(Dataset):
    def __init__(self, csv_file):
        self.df = pd.read_csv(csv_file)
        self.df = self.df[self.df['chain_size'] == 0]
        self.df = self.df[self.df['static_polarizability'] <= 217]
        scaler = StandardScaler()
        self.df['scaled_polarizability'] = scaler.fit_transform(self.df[['static_polarizability']])
        self.atom_type_to_idx = atom_type_to_idx

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        smiles = row['smiles']
        mol = Chem.MolFromSmiles(smiles)
        num_atoms = mol.GetNumAtoms()

        x = torch.zeros(num_atoms, len(self.atom_type_to_idx) + 2)  # atom types + valence + degree
        target_node = torch.full((num_atoms,), -100, dtype=torch.long)  # ignore_index for loss
        for i, atom in enumerate(mol.GetAtoms()):
            z = atom.GetAtomicNum()
            if z not in self.atom_type_to_idx:
                continue  # skip unknown atoms, or handle differently
            idx_atom = self.atom_type_to_idx[z]
            x[i, idx_atom] = 1
            x[i, -1] = atom.GetTotalValence()
            x[i, -2] = mol.GetAtomWithIdx(i).GetDegree()
            target_node[i] = idx_atom

        edge_index = []
        target_edge = []
        # True bonds
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            bt = int(bond.GetBondTypeAsDouble())
            if bt in allowed_bonds[1:]:
                idx = allowed_bonds.index(bt)
                edge_index.extend([[i, j], [j, i]])
                target_edge.extend([idx, idx])
        # Sample non-bonded pairs
        non_bond_pairs = torch.combinations(torch.arange(num_atoms), r=2)
        num_non_bonds = len(edge_index)  # e.g., 2x true bonds
        if len(non_bond_pairs) > num_non_bonds:
            indices = torch.randperm(len(non_bond_pairs))[:num_non_bonds]
            non_bond_pairs = non_bond_pairs[indices]
        else:
            non_bond_pairs = non_bond_pairs
        edge_index.extend(non_bond_pairs.tolist())
        edge_index.extend(non_bond_pairs.flip(1).tolist())
        target_edge.extend([0] * 2 * len(non_bond_pairs))
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        target_edge = torch.tensor(target_edge, dtype=torch.long)
        edge_attr = torch.tensor([allowed_bonds[te] for te in target_edge], dtype=torch.float).unsqueeze(-1)
        graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, num_nodes=num_atoms)
        label = row['scaled_polarizability']

        return graph, torch.tensor(label, dtype=torch.float32), target_node, target_edge

# Model
class ClassConditionedGNN(nn.Module):
    def __init__(self, class_emb_size=75):
        super().__init__()
        in_channels = len(allowed_atoms) + 2
        self.class_emb = nn.Linear(1, class_emb_size)
        self.time_emb = nn.Embedding(1000, 64)
        self.node_model = GraphUNet(
            in_channels=in_channels + class_emb_size + 64,
            hidden_channels=256,
            out_channels=len(allowed_atoms),  # node classification logits
            depth=4
        )
        self.edge_model = nn.Sequential(
            nn.Linear(2 * (in_channels + class_emb_size + 64) + 1, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, len(allowed_bonds))  # Now 4
        )

    def forward(self, graph, t, class_labels):
        if hasattr(graph, 'num_graphs'):
            batch_idx = graph.batch
        else:
            batch_idx = torch.zeros(graph.num_nodes, dtype=torch.long, device=graph.x.device)

        class_cond = self.class_emb(class_labels.unsqueeze(-1))
        time_cond = self.time_emb(t)
        class_cond = class_cond[batch_idx]
        time_cond = time_cond[batch_idx]

        node_input = torch.cat((graph.x, class_cond, time_cond), dim=-1)
        pred_node_logits = self.node_model(node_input, graph.edge_index)

        u, v = graph.edge_index
        edge_input = torch.cat((node_input[u], node_input[v], graph.edge_attr), dim=-1)
        pred_edge_logits = self.edge_model(edge_input)

        return {'x': pred_node_logits, 'edge_attr': pred_edge_logits}

class GraphDDPMScheduler(DDPMScheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alphas_cumprod = self.alphas_cumprod.to(device)

    def add_noise(self, original_samples, noise, timesteps):
        noisy = original_samples.clone()

        # Handle single graph vs batch of graphs
        if hasattr(original_samples, 'batch') and original_samples.batch is not None:
            batch_idx = original_samples.batch
        else:
            batch_idx = torch.zeros(original_samples.num_nodes, dtype=torch.long, device=original_samples.x.device)

        # Get scaling coefficients
        sqrt_alpha_prod = self.alphas_cumprod[timesteps][batch_idx].sqrt().unsqueeze(-1)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps][batch_idx]).sqrt().unsqueeze(-1)

        # Add noise to atom type one-hot (not valence/degree)
        noisy.x[:, :len(allowed_atoms)] = sqrt_alpha_prod * original_samples.x[:, :len(allowed_atoms)] + \
                                        sqrt_one_minus_alpha_prod * noise['x']
        noisy.x[:, len(allowed_atoms):] = original_samples.x[:, len(allowed_atoms):]

        # Edge noise (need edge_batch)
        edge_index = original_samples.edge_index
        edge_batch = batch_idx[edge_index[0]]  # shape [E]
        sqrt_alpha_prod_edge = self.alphas_cumprod[timesteps][edge_batch].sqrt().unsqueeze(-1)
        sqrt_one_minus_alpha_prod_edge = (1 - self.alphas_cumprod[timesteps][edge_batch]).sqrt().unsqueeze(-1)

        noisy.edge_attr = sqrt_alpha_prod_edge * original_samples.edge_attr + \
                        sqrt_one_minus_alpha_prod_edge * noise['edge_attr']

        return noisy

    def step(self, model_output, timestep, sample):
        t = timestep.item() if timestep.dim() == 0 else timestep[0].item()
        prev_t = t - 1
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else 1.0
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        alpha_t = self.alphas[t]
        # Node
        x0_atom = torch.softmax(model_output['x'], dim=-1)
        x0_x = torch.cat((x0_atom, sample.x[:, -2:]), dim=-1)
        # Edge (expected bond order as scalar)
        bond_probs = torch.softmax(model_output['edge_attr'], dim=-1)
        x0_edge = (bond_probs * torch.tensor(allowed_bonds, device=sample.x.device)).sum(dim=-1).unsqueeze(-1)
        if prev_t < 0:
            return Data(x=x0_x, edge_attr=x0_edge, edge_index=sample.edge_index, num_nodes=sample.num_nodes)
        # Mean
        coef1 = alpha_prod_t_prev ** 0.5 * alpha_t / beta_prod_t
        coef2 = (alpha_prod_t_prev / alpha_prod_t) ** 0.5 * (1 - alpha_t) / beta_prod_t
        mu_x = coef1 * x0_x + coef2 * sample.x
        mu_edge = coef1 * x0_edge + coef2 * sample.edge_attr
        # Variance
        variance = beta_prod_t_prev * (1 - alpha_t) / beta_prod_t
        sigma = variance ** 0.5
        prev_x = mu_x + sigma * torch.randn_like(sample.x)
        prev_edge = mu_edge + sigma * torch.randn_like(sample.edge_attr)
        return Data(x=prev_x, edge_attr=prev_edge, edge_index=sample.edge_index, num_nodes=sample.num_nodes)

# Valence limits for atoms you have (extend as needed)
VALENCE_LIMITS = {
    1: 1, 6: 4, 7: 3, 8: 2, 9: 1, 11: 1, 15: 3, 16: 2, 17: 1, 19: 1, 30: 2, 35: 1, 53: 1
}

def valence_violation_loss(graph, pred_node_logits, pred_edge_logits):
    # node: [N, num_atom_types], edge: [E, num_bond_types]
    atom_probs = torch.softmax(pred_node_logits, dim=-1)
    bond_probs = torch.softmax(pred_edge_logits, dim=-1)

    inv_atom_map = {v: k for k, v in atom_type_to_idx.items()}
    expected_atomic_nums = torch.zeros(graph.num_nodes, device=graph.x.device)
    for idx in range(len(atom_type_to_idx)):
        expected_atomic_nums += atom_probs[:, idx] * inv_atom_map[idx]

    expected_bond_orders = (bond_probs * torch.tensor(allowed_bonds, device=graph.x.device)).sum(dim=-1)
    bond_sum_per_node = torch.zeros(graph.num_nodes, device=graph.x.device)
    for idx, (u, v) in enumerate(graph.edge_index.t()):
        bond_sum_per_node[u] += expected_bond_orders[idx]
        bond_sum_per_node[v] += expected_bond_orders[idx]

    valence_limits_tensor = torch.tensor(
        [VALENCE_LIMITS.get(z.item(), 4) for z in expected_atomic_nums.round().int()],
        device=graph.x.device, dtype=torch.float
    )
    violation = torch.relu(bond_sum_per_node - valence_limits_tensor)
    return violation.mean()

# Convert graph to mol (updated to use allowed atoms)
def graph_to_mol(data: Data):
    mol = Chem.RWMol()
    node_map = {}

    for i, node_feat in enumerate(data.x):
        atomic_type_idx = node_feat[:len(allowed_atoms)].argmax().item()  # Only on atom part
        if atomic_type_idx >= len(allowed_atoms):
            return None
        atomic_num = allowed_atoms[atomic_type_idx]
        atom = Chem.Atom(atomic_num)
        idx = mol.AddAtom(atom)
        node_map[i] = idx

    added_bonds = set()
    for k in range(data.edge_index.size(1)):
        u = data.edge_index[0, k].item()
        v = data.edge_index[1, k].item()
        if u >= v or (u, v) in added_bonds:
            continue
        added_bonds.add((u, v))
        bond_val = data.edge_attr[k].item()  # scalar
        bond_type_idx = round(bond_val)
        if bond_type_idx <= 0 or bond_type_idx >= len(allowed_bonds):
            continue
        bond_type_val = allowed_bonds[bond_type_idx]
        if bond_type_val == 1:
            bond_type = Chem.rdchem.BondType.SINGLE
        elif bond_type_val == 2:
            bond_type = Chem.rdchem.BondType.DOUBLE
        elif bond_type_val == 3:
            bond_type = Chem.rdchem.BondType.TRIPLE
        else:
            continue
        try:
            mol.AddBond(node_map[u], node_map[v], bond_type)
        except:
            return None

    try:
        mol = mol.GetMol()
        Chem.SanitizeMol(mol)
        return mol
    except:
        return None

def is_valid_molecule(mol, min_atoms=3, max_atoms=50, min_qed=0.2):
    if mol is None:
        return False
    if mol.GetNumAtoms() < min_atoms or mol.GetNumAtoms() > max_atoms:
        return False
    try:
        qed_score = QED.qed(mol)
        if qed_score < min_qed:
            return False
    except:
        return False
    return True

from torch_geometric.data import Batch

def custom_collate(batch):
    graphs, labels, target_nodes, target_edges = zip(*batch)
    
    # Type and shape sanity checks
    for i, (g, tn, te) in enumerate(zip(graphs, target_nodes, target_edges)):
        assert isinstance(g, Data), f"Item {i} graph not Data!"
        assert tn.dim() == 1, f"Item {i} target_node not 1D!"
        assert te.dim() == 1, f"Item {i} target_edge not 1D!"
        assert tn.size(0) == g.num_nodes, f"Item {i} target_node size mismatch!"
        assert te.size(0) == g.edge_index.size(1), f"Item {i} target_edge size mismatch!"
    
    batch_graph = Batch.from_data_list(graphs)
    labels = torch.stack(labels)
    target_nodes = torch.cat(target_nodes)
    target_edges = torch.cat(target_edges)
    return batch_graph, labels, target_nodes, target_edges


# Initialize dataset, dataloader, model, optimizer, scheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [2]:
dataset = CustomDataset('../polygraphpy/data/polarizability_data.csv')
train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=custom_collate, num_workers=0)

net = ClassConditionedGNN().to(device)
loss_fn_node = nn.CrossEntropyLoss(ignore_index=-100)
loss_fn_edge = nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)

noise_scheduler = GraphDDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")

In [3]:
# Training loop
n_epochs = 35

aux_loss = 1e9

for epoch in range(n_epochs):
    losses = []
    for batch_graphs, labels, target_nodes, target_edges in tqdm(train_dataloader):
        batch_graphs = batch_graphs.to(device)
        target_nodes = target_nodes.to(device)
        target_edges = target_edges.to(device)
        y = labels.to(device)

        timesteps = torch.randint(0, 999, (batch_graphs.num_graphs,), device=device)
        noise = {
            'x': torch.randn((batch_graphs.x.size(0), len(allowed_atoms)), device=device),
            'edge_attr': torch.randn_like(batch_graphs.edge_attr)
        }

        noisy_batch = noise_scheduler.add_noise(batch_graphs, noise, timesteps)
        pred = net(noisy_batch, timesteps, y)

        loss_node = loss_fn_node(pred['x'], target_nodes)
        loss_edge = loss_fn_edge(pred['edge_attr'], target_edges)

        val_loss = valence_violation_loss(batch_graphs, pred['x'], pred['edge_attr'])
        loss = loss_node + loss_edge +  val_loss

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
        opt.step()
        losses.append(loss.item())
    
    if aux_loss > loss.item():
        torch.save(net, 'graph_diffusion_model.pt')
        aux_loss = loss.item()

    avg_loss = sum(losses) / len(losses)
    print(f"Epoch {epoch} avg loss: {avg_loss:.5f}")

    # Save best model etc. here as needed

  adj = torch.sparse_csr_tensor(
100%|██████████| 107/107 [03:22<00:00,  1.89s/it]


Epoch 0 avg loss: 9.13511


100%|██████████| 107/107 [03:23<00:00,  1.90s/it]


Epoch 1 avg loss: 1.94226


100%|██████████| 107/107 [03:21<00:00,  1.89s/it]


Epoch 2 avg loss: 1.81450


100%|██████████| 107/107 [03:21<00:00,  1.89s/it]


Epoch 3 avg loss: 1.78269


100%|██████████| 107/107 [03:17<00:00,  1.85s/it]


Epoch 4 avg loss: 1.75671


100%|██████████| 107/107 [03:03<00:00,  1.72s/it]


Epoch 5 avg loss: 1.73665


100%|██████████| 107/107 [02:23<00:00,  1.34s/it]


Epoch 6 avg loss: 1.71805


100%|██████████| 107/107 [02:06<00:00,  1.18s/it]


Epoch 7 avg loss: 1.70126


100%|██████████| 107/107 [03:00<00:00,  1.69s/it]


Epoch 8 avg loss: 1.68807


100%|██████████| 107/107 [02:31<00:00,  1.41s/it]


Epoch 9 avg loss: 1.67521


100%|██████████| 107/107 [02:01<00:00,  1.14s/it]


Epoch 10 avg loss: 1.66416


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 11 avg loss: 1.65352


100%|██████████| 107/107 [02:01<00:00,  1.14s/it]


Epoch 12 avg loss: 1.64197


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 13 avg loss: 1.63395


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 14 avg loss: 1.62478


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 15 avg loss: 1.61728


100%|██████████| 107/107 [02:01<00:00,  1.14s/it]


Epoch 16 avg loss: 1.60888


100%|██████████| 107/107 [02:01<00:00,  1.14s/it]


Epoch 17 avg loss: 1.60248


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 18 avg loss: 1.59710


100%|██████████| 107/107 [02:01<00:00,  1.14s/it]


Epoch 19 avg loss: 1.59092


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 20 avg loss: 1.58427


100%|██████████| 107/107 [02:01<00:00,  1.14s/it]


Epoch 21 avg loss: 1.57869


100%|██████████| 107/107 [02:01<00:00,  1.14s/it]


Epoch 22 avg loss: 1.57423


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 23 avg loss: 1.56926


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 24 avg loss: 1.56808


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 25 avg loss: 1.56399


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 26 avg loss: 1.56240


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 27 avg loss: 1.55748


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 28 avg loss: 1.55319


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 29 avg loss: 1.55070


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 30 avg loss: 1.54647


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 31 avg loss: 1.54915


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 32 avg loss: 1.54071


100%|██████████| 107/107 [02:01<00:00,  1.13s/it]


Epoch 33 avg loss: 1.54030


100%|██████████| 107/107 [02:01<00:00,  1.14s/it]

Epoch 34 avg loss: 1.53867





In [5]:
# Sampling conditioned on scaled polarizability
def sample_conditioned_molecules(net, scheduler, y_target, num_samples=100, device=device):
    net.eval()
    scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device)
    valid_smiles = []
    # Load training dataset to sample graphs
    dataset = CustomDataset('../polygraphpy/data/polarizability_data.csv')

    for i in tqdm(range(num_samples), desc="Sampling"):
        # Sample a random training graph structure
        idx = torch.randint(0, len(dataset), (1,)).item()
        graph, _, _, _ = dataset[idx]
        graph = graph.to(device)
        # Fully noise it (start from t=999)
        timesteps = torch.full((1,), scheduler.num_train_timesteps - 1, device=device)
        noise = {
            'x': torch.randn((graph.num_nodes, len(allowed_atoms)), device=device),
            'edge_attr': torch.randn_like(graph.edge_attr)
        }
        graph = scheduler.add_noise(graph, noise, timesteps)
        # Set condition
        y = torch.tensor([y_target], dtype=torch.float32, device=device)

        for t in reversed(range(scheduler.num_train_timesteps)):
            t_tensor = torch.tensor([t], device=device)
            with torch.no_grad():
                model_output = net(graph, t_tensor, y)
                graph = scheduler.step(model_output, t_tensor, graph)

        mol = graph_to_mol(graph)
        if is_valid_molecule(mol):
            smiles = Chem.MolToSmiles(mol)
            valid_smiles.append(smiles)
            print(f"[✓] Sample {i}: {smiles}")
        else:
            print(f"[✗] Sample {i} invalid.")

    return valid_smiles

# Example usage:
net = torch.load('graph_diffusion_model.pt', weights_only=False).to(device)
samples = sample_conditioned_molecules(net, noise_scheduler, y_target=0.0, num_samples=100, device=device)

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
Sampling:   0%|          | 0/100 [00:07<?, ?it/s]


ValueError: cannot convert float NaN to integer