In [1]:
import numpy as np
import torch
from torch_geometric.data import DataLoader
from rdkit import Chem
from matchms.importing import load_from_mgf
from mol2dreams.featurizer.featurize import MoleculeFeaturizer
from mol2dreams.featurizer.atom_features import AtomFeaturizer
from mol2dreams.featurizer.bond_features import BondFeaturizer

In [2]:
bond_config = {
    'features': {
        'bond_type': True,
        'conjugated': True,
        'in_ring': True,
        'stereochemistry': False,
    }
}

# Initialize the BondFeaturizer
featurizer = BondFeaturizer(bond_config)

# Test with a sample molecule (Ethylene)
smiles = 'C=C'
mol = Chem.MolFromSmiles(smiles)
bond = mol.GetBondBetweenAtoms(0, 1)  # Bond between first and second atom

feature = featurizer.featurize(bond)
print(f"Bond features shape: {feature.shape}")
print(feature)

Bond features shape: torch.Size([7])
tensor([0., 1., 0., 0., 0., 0., 0.])


In [3]:
atom_config = {
    'features': {
        'atom_symbol': True,
        'total_valence': True,
        'aromatic': True,
        'hybridization': True,
        'formal_charge': True,
        'default_valence': True,
        'ring_size': True,        
        'hydrogen_count': True,
    },
    'feature_attributes': {
        'atom_symbol': {
            'top_n_atoms': 42,        # Number of top atoms to recognize
            'include_other': True,    # Whether to include an 'Unknown' category
        },
        # Additional feature-specific attributes can be added here
    }
}

# Initialize the AtomFeaturizer
featurizer = AtomFeaturizer(atom_config)

# Test with a sample molecule (Ethanol)
smiles = 'CCO'
mol = Chem.MolFromSmiles(smiles)

for atom in mol.GetAtoms():
    feature = featurizer.featurize(atom)
    print(f"Atom index {atom.GetIdx()}: Feature shape {feature.shape}")
    print(feature)

Atom index 0: Feature shape torch.Size([84])
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
Atom index 1: Feature shape torch.Size([84])
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])
Atom index 2: Feature shape torch.Size([84])
tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.,

In [4]:
spectrum_embedding_size = 1024 
featurizer = MoleculeFeaturizer(atom_config, bond_config, spectrum_embedding_size)

In [5]:
spectra_path = "../../data/data/MassSpecGym.mgf"

In [6]:
spectra = list(load_from_mgf(spectra_path))

In [7]:
spectra_small = spectra[:50]

In [12]:
list_smiles = []
for spectrum in spectra_small:
    entry = {}
    entry['smiles'] = spectrum.metadata.get('smiles', None)
    entry['embedding'] = torch.zeros(1024)
    list_smiles.append(entry)


In [13]:
# Preprocess the dataset
data_list = featurizer.featurize_dataset(list_smiles)

print(f"Number of successfully featurized molecules: {len(data_list)}")


Featurizing dataset: 100%|██████████| 50/50 [00:00<00:00, 295.81it/s]

Number of successfully featurized molecules: 50





In [14]:
data_list

[Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[21, 84], edge_index=[2, 44], edge_attr=[44, 7], y=[1, 1024]),
 Data(x=[23, 84], edge_index=[2, 48], edge_attr=[48, 7], y=[1, 1024]),
 Data(x=[23, 84], edge_index=[2, 48], edge_attr=[48, 7], y=[1, 1024]),
 Data(x=[23, 84], edge_index=[2, 48], edge_attr=[48, 7], y=[1, 1024]),
 Data(

In [15]:
batch_size = 32
loader = DataLoader(data_list, batch_size=batch_size, shuffle=True, num_workers=1)

for batch in loader:
    print(batch)
    print(f"Batch size: {batch.num_graphs}")
    print(f"Node feature shape: {batch.x.shape}")         
    print(f"Edge index shape: {batch.edge_index.shape}")  
    print(f"Edge feature shape: {batch.edge_attr.shape}") 
    print(f"Spectrum embedding shape: {batch.y.shape}")  
    break 

DataBatch(x=[752, 84], edge_index=[2, 1584], edge_attr=[1584, 7], y=[32, 1024], batch=[752], ptr=[33])
Batch size: 32
Node feature shape: torch.Size([752, 84])
Edge index shape: torch.Size([2, 1584])
Edge feature shape: torch.Size([1584, 7])
Spectrum embedding shape: torch.Size([32, 1024])
