# Embedding Workflow
---

In [None]:
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
import numpy as np
import torch
from torch_geometric.data import Data

In [None]:
# Parse an example molecule from its SMILES string.
smiles = "COC(=O)C[C@](O)(CCCC(C)(C)O)C(=O)O[C@H]1[C@H]2c3cc4OCOc4cc3CCN3CCC[C@]23C=C1OC"  # Example SMILES string
mol = Chem.MolFromSmiles(smiles)

# MolFromSmiles reports a parse failure by returning None instead of raising;
# stop here with a clear message rather than failing inside the generator.
if mol is None:
    raise ValueError(f"Could not parse SMILES: {smiles!r}")

# Create a Morgan fingerprint generator (radius 2, 2048-bit fingerprint)
generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)

# Generate the fingerprint
fingerprint = generator.GetFingerprint(mol)

In [None]:
# Copy the fingerprint into a NumPy integer array.
# `arr` starts as a one-element placeholder; ConvertToNumpyArray writes the
# fingerprint contents into it.
arr = np.zeros(1, dtype=int)
Chem.DataStructs.ConvertToNumpyArray(fingerprint, arr)

# Show the distinct values that occur in the fingerprint array
print(np.unique(arr))

In [None]:
from rdkit import Chem
import torch
from torch_geometric.data import Data


def one_hot_encoding(value, allowable_set):
    """Return a 0/1 list marking which entry of ``allowable_set`` equals ``value``.

    A value that does not appear in ``allowable_set`` is mapped to the last
    entry, which serves as the catch-all ('Unknown' / 'Other') category.
    """
    # Fall back to the trailing catch-all entry for unlisted values.
    target = value if value in allowable_set else allowable_set[-1]

    encoding = []
    for candidate in allowable_set:
        encoding.append(1 if target == candidate else 0)
    return encoding


def _enum_label(enum_value):
    """Return the bare member name of an RDKit enum value (e.g. ``"SP3"``)."""
    # Keep only the text after the last dot so that a bare name ("SP3") and a
    # qualified one ("HybridizationType.SP3") both map to the same label.
    return str(enum_value).rsplit(".", 1)[-1]


def get_atom_features(atom):
    """Extract a flat numeric feature list from an RDKit Atom object.

    Layout of the returned list (24 entries):
        0-10   one-hot element symbol (H, C, N, O, F, P, S, Cl, Br, I, Other)
        11     degree (number of connected neighbors)
        12     formal charge
        13-16  one-hot hybridization (SP, SP2, SP3, Other)
        17     aromaticity flag
        18-21  one-hot chiral tag (unspecified, tetrahedral CW,
               tetrahedral CCW, Other)
        22     total number of hydrogens
        23     ring-membership flag

    Args:
        atom: RDKit Atom to featurize.

    Returns:
        list: The feature values in the order described above.
    """
    atom_features = []

    # Atom Type (One-hot Encoding)
    atom_type = atom.GetSymbol()
    atom_features += one_hot_encoding(
        atom_type, ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "Other"]
    )

    # Degree (Number of connected neighbors)
    atom_features.append(atom.GetDegree())

    # Formal Charge
    atom_features.append(atom.GetFormalCharge())

    # Hybridization (One-hot Encoding)
    # The label is compared by bare member name; matching against
    # "HybridizationType.SP3"-style strings sent every atom to "Other".
    atom_features += one_hot_encoding(
        _enum_label(atom.GetHybridization()),
        ["SP", "SP2", "SP3", "Other"],
    )

    # Aromaticity
    atom_features.append(int(atom.GetIsAromatic()))

    # Chirality (One-hot Encoding), compared by bare member name as above
    atom_features += one_hot_encoding(
        _enum_label(atom.GetChiralTag()),
        [
            "CHI_UNSPECIFIED",
            "CHI_TETRAHEDRAL_CW",
            "CHI_TETRAHEDRAL_CCW",
            "Other",
        ],
    )

    # Number of Hydrogens
    atom_features.append(atom.GetTotalNumHs())

    # Is in a Ring
    atom_features.append(int(atom.IsInRing()))

    return atom_features


def get_bond_features(bond):
    """Build the feature list for an RDKit Bond object.

    The result is a one-hot bond type (SINGLE, DOUBLE, TRIPLE, AROMATIC,
    OTHER) followed by three 0/1 flags: conjugated, aromatic, in a ring.
    """
    # One-hot bond type; unlisted types fall into the trailing "OTHER" slot.
    type_one_hot = one_hot_encoding(
        str(bond.GetBondType()),
        ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC", "OTHER"],
    )

    # Boolean bond properties, stored as integers.
    flags = [
        int(bond.GetIsConjugated()),
        int(bond.GetIsAromatic()),
        int(bond.IsInRing()),
    ]

    return type_one_hot + flags


def mol_to_graph(smiles):
    """Convert a SMILES string to a PyTorch Geometric Data object with rich features.

    Args:
        smiles: SMILES string describing the molecule.

    Returns:
        A ``Data`` object with
            ``x``          float tensor, one row of atom features per atom,
            ``edge_index`` long tensor of shape [2, 2 * num_bonds] (each bond
                           is stored in both directions),
            ``edge_attr``  float tensor, one row of bond features per
                           directed edge,
        or ``None`` if the SMILES cannot be parsed or yields no atoms.
    """
    mol = Chem.MolFromSmiles(smiles)

    # A failed parse gives None; a molecule without atoms has nothing to embed.
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    # Atom Features
    atom_features_list = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(atom_features_list, dtype=torch.float)

    # Edge Index and Edge Features
    edge_index = []
    edge_features_list = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()

        # Edge Index (both directions)
        edge_index.extend([[i, j], [j, i]])

        # Bond features are shared by both directions of the edge
        bond_features = get_bond_features(bond)
        edge_features_list.extend([bond_features, bond_features])

    if edge_index:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features_list, dtype=torch.float)
    else:
        # No bonds (e.g. a single atom): building tensors from empty lists
        # would give shape [0]. Create correctly shaped empty tensors so the
        # graph keeps the [2, E] / [E, F] layout.
        probe_bond = Chem.MolFromSmiles("CC").GetBondWithIdx(0)
        num_bond_features = len(get_bond_features(probe_bond))
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.float)

    # Create Data Object
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data

In [None]:
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool


class GNNModel(nn.Module):
    """Two-layer GCN producing node embeddings or, optionally, per-graph outputs.

    Args:
        input_dim: Number of features per node in ``data.x``.
        hidden_dim: Width of the hidden GCN layer.
        output_dim: Width of the node embeddings from the second GCN layer.
        num_targets: If given, node embeddings are mean-pooled per graph and
            passed through a linear layer with this many outputs. If ``None``
            (the default), ``forward`` returns the node embeddings themselves.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_targets=None):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        # Optional graph-level readout; absent when only embeddings are wanted.
        self.head = None if num_targets is None else nn.Linear(output_dim, num_targets)

    def forward(self, data):
        """Run the network on a graph or a batch of graphs.

        Args:
            data: Object exposing ``x`` and ``edge_index`` (and ``batch`` when
                several graphs are batched together).

        Returns:
            Node embeddings with ``output_dim`` columns when no readout head
            is configured; otherwise one row of ``num_targets`` outputs per
            graph, with a trailing dimension of size 1 squeezed away.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)

        if self.head is None:
            return x

        # Without a batch vector, treat all nodes as belonging to one graph.
        batch = getattr(data, "batch", None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        graph_embedding = global_mean_pool(x, batch)
        return self.head(graph_embedding).squeeze(-1)


# Example usage
# Read the node feature width from the featurizer so the model always matches
# what mol_to_graph produces (a fixed value of 1 did not match its output).
input_dim = mol_to_graph("CC").x.size(1)
hidden_dim = 64
output_dim = 32
# One logit per graph, so the output lines up with the per-graph labels.
model = GNNModel(input_dim, hidden_dim, output_dim, num_targets=1)

In [None]:
from torch_geometric.data import DataLoader

# Assume you have a list of SMILES strings and corresponding labels
smiles_list = ["CCO", "CCN", "CCC"]  # Example SMILES
labels = [0, 1, 0]  # Example labels

# Convert SMILES to graph data and attach the label to each graph.
# mol_to_graph returns None for SMILES it cannot parse; such entries are
# skipped together with their labels instead of failing on `data.y`.
data_list = []
for smi, label in zip(smiles_list, labels):
    data = mol_to_graph(smi)
    if data is None:
        continue
    data.y = torch.tensor([label], dtype=torch.float)
    data_list.append(data)

# Create a DataLoader
loader = DataLoader(data_list, batch_size=2, shuffle=True)

# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# NOTE: BCEWithLogitsLoss requires `out` and `batch.y` to have the same shape.
criterion = nn.BCEWithLogitsLoss()

# Make sure the model is in training mode even if this cell is re-run after
# the model was switched to evaluation mode.
model.train()
for epoch in range(100):
    total_loss = 0.0
    num_batches = 0
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # Report the mean of the per-batch losses rather than only the last batch
    print(f"Epoch {epoch+1}, Loss: {total_loss / max(num_batches, 1)}")

In [None]:
# Assume you have a test dataset
test_smiles_list = ["CCO", "CCN"]
test_labels = [0, 1]

# Convert SMILES to graph data and attach the label to each graph.
# mol_to_graph returns None for SMILES it cannot parse; skip those entries.
test_data_list = []
for smi, label in zip(test_smiles_list, test_labels):
    data = mol_to_graph(smi)
    if data is None:
        continue
    data.y = torch.tensor([label], dtype=torch.float)
    test_data_list.append(data)

# Create a DataLoader
test_loader = DataLoader(test_data_list, batch_size=2, shuffle=False)

# Evaluation loop
model.eval()
correct = 0
# Gradients are not needed for evaluation
with torch.no_grad():
    for batch in test_loader:
        out = model(batch)
        # A logit above 0 corresponds to a predicted probability above 0.5
        pred = (out > 0).float()
        correct += (pred == batch.y).sum().item()

# Accuracy is undefined when no test molecule could be converted
accuracy = correct / len(test_data_list) if test_data_list else float("nan")
print(f"Accuracy: {accuracy:.4f}")