In [1]:
import pandas as pd
import numpy as np

# Load the QM9 demo CSV into a DataFrame and preview its first five rows.
df_qm9 = pd.read_csv(filepath_or_buffer='qm9_demo_0.5k.csv')
df_qm9.head(n=5)

Unnamed: 0,file_name,mol_len,atom_coords,vibrationalfrequence,smiles_basic,smiles_stereo,inchi_basic,inchi_stereo,A,B,...,homo,lumo,gap,R2,zpve,Uo,U,H,G,Cv
0,dsgdb9nsd_000001.xyz,5,C\t-0.0126981359\t 1.0858041578\t 0.0080009958...,1341.307\t1341.3284\t1341.365\t1562.6731\t1562...,C,C,1S/CH4/h1H4,1S/CH4/h1H4,157.7118,157.70997,...,-0.3877,0.1171,0.5048,35.3641,0.044749,-40.47893,-40.476062,-40.475117,-40.498597,6.469
1,dsgdb9nsd_000002.xyz,4,N\t-0.0404260543\t 1.0241077531\t 0.0625637998...,1103.8733\t1684.1158\t1684.3072\t3458.7145\t35...,N,N,1S/H3N/h1H3,1S/H3N/h1H3,293.60975,293.54111,...,-0.257,0.0829,0.3399,26.1563,0.034358,-56.525887,-56.523026,-56.522082,-56.544961,6.316
2,dsgdb9nsd_000003.xyz,3,O\t-0.0343604951\t 0.9775395708\t 0.0076015923...,1671.4222\t3803.6305\t3907.698,O,O,1S/H2O/h1H2,1S/H2O/h1H2,799.58812,437.90386,...,-0.2928,0.0687,0.3615,19.0002,0.021375,-76.404702,-76.401867,-76.400922,-76.422349,6.002
3,dsgdb9nsd_000004.xyz,4,C\t 0.5995394918\t 0.\t 1.\t-0.207019\nC\t-0.5...,549.7648\t549.7648\t795.2713\t795.2713\t2078.1...,C#C,C#C,1S/C2H2/c1-2/h1-2H,1S/C2H2/c1-2/h1-2H,0.0,35.610036,...,-0.2845,0.0506,0.3351,59.5248,0.026841,-77.308427,-77.305527,-77.304583,-77.327429,8.574
4,dsgdb9nsd_000005.xyz,3,C\t-0.0133239314\t 1.1324657151\t 0.0082758861...,799.0101\t799.0101\t2198.4393\t3490.3686,C#N,C#N,1S/CHN/c1-2/h1H,1S/CHN/c1-2/h1H,0.0,44.593883,...,-0.3604,0.0191,0.3796,48.7476,0.016601,-93.411888,-93.40937,-93.408425,-93.431246,6.278


In [2]:
class QM9DataExtractor:
    """Read a QM9-style CSV and expose per-molecule coordinates and properties.

    A molecule is addressed either by its 0-based row position in the loaded
    table or by its value in the ``smiles_basic`` column. Internally every
    lookup is converted to a row *position* and rows are read with ``.iloc``,
    so results do not depend on the DataFrame's index labels.
    """

    def __init__(self, csv_file):
        """Load the dataset.

        Args:
            csv_file: path (or any object accepted by ``pd.read_csv``) of the CSV.
        """
        self.df = pd.read_csv(csv_file)
        # Columns copied into the 'properties' dict / 'property_vector', in this order.
        self.property_columns = ['A', 'B', 'C', 'miu', 'alpha', 'homo', 'lumo', 'gap', 'R2', 'zpve', 'Uo', 'U', 'H',
                                 'G', 'Cv']

    @staticmethod
    def _to_float(text):
        """Convert one numeric field of an atom line to float.

        Besides ordinary literals this accepts the Mathematica-style exponent
        ``1.5*^-6``, which is believed to appear in raw QM9 xyz records
        (TODO confirm against the dataset readme); plain ``float()`` rejects it.
        """
        return float(text.replace('*^', 'e'))

    def _get_idx_from_smiles(self, smiles):
        """Convert SMILES string to the row position of its first exact match

        Raises:
            ValueError: if no row has ``smiles_basic == smiles``.
        """
        mask = (self.df['smiles_basic'] == smiles).to_numpy()
        hits = np.flatnonzero(mask)
        if hits.size == 0:
            raise ValueError(f"SMILES '{smiles}' not found in dataset")
        # Return a position, not an index label: every consumer uses .iloc.
        return int(hits[0])

    def _resolve_identifier(self, identifier):
        """Resolve identifier to a row position - supports both int index and SMILES string

        Args:
            identifier: int or NumPy integer (0-based row position) or str (SMILES)

        Raises:
            ValueError: SMILES string not present in the dataset.
            IndexError: integer outside ``[0, len(df))`` (negative values are rejected).
            TypeError: identifier is neither an integer nor a string.
        """
        if isinstance(identifier, str):
            return self._get_idx_from_smiles(identifier)
        # np.integer covers values such as np.int64 that come out of NumPy/pandas
        # iteration and are not instances of the built-in int.
        if isinstance(identifier, (int, np.integer)):
            if identifier < 0 or identifier >= len(self.df):
                raise IndexError(f"Index {identifier} out of range for dataset of size {len(self.df)}")
            return int(identifier)
        raise TypeError("Identifier must be either int (index) or str (SMILES)")

    def extract_atom_coords(self, identifier=None):
        """Extract atomic coordinate information vector

        Args:
            identifier: int (index) or str (SMILES) or None (all molecules)

        Returns:
            One coordinate dict (see ``_parse_single_coords``) when an identifier
            is given, otherwise a list with one dict per row in row order.
        """
        if identifier is not None:
            return self._parse_single_coords(self._resolve_identifier(identifier))
        return [self._parse_single_coords(pos) for pos in range(len(self.df))]

    def _parse_single_coords(self, idx):
        """Parse atomic coordinates for a single molecule

        The ``atom_coords`` cell is read as newline-separated atom records, each
        with five tab-separated fields: symbol, x, y, z, charge.

        Args:
            idx: 0-based row position.

        Returns:
            dict with keys 'smiles', 'atoms' (list of str), 'coordinates'
            (float array of shape (num_atoms, 3)), 'charges' (float array of
            length num_atoms) and 'num_atoms'.
        """
        row = self.df.iloc[idx]

        atoms = []
        coordinates = []
        charges = []
        # splitlines() also copes with CRLF line endings; blank lines are skipped.
        for line in row['atom_coords'].strip().splitlines():
            if not line.strip():
                continue
            parts = line.split('\t')
            atoms.append(parts[0].strip())
            coordinates.append([self._to_float(value) for value in parts[1:4]])
            charges.append(self._to_float(parts[4]))

        return {
            'smiles': row['smiles_basic'],
            'atoms': atoms,
            # reshape keeps the (n, 3) shape even when no atom line was found.
            'coordinates': np.array(coordinates, dtype=float).reshape(-1, 3),
            'charges': np.array(charges, dtype=float),
            'num_atoms': len(atoms)
        }

    def extract_properties(self, identifier=None):
        """Extract molecular property information

        Args:
            identifier: int (index) or str (SMILES) or None (all molecules)

        Returns:
            One property dict (see ``_get_single_properties``) when an identifier
            is given, otherwise a list with one dict per row in row order.
        """
        if identifier is not None:
            return self._get_single_properties(self._resolve_identifier(identifier))
        return [self._get_single_properties(pos) for pos in range(len(self.df))]

    def _get_single_properties(self, idx):
        """Get properties for a single molecule

        Args:
            idx: 0-based row position.

        Returns:
            dict with keys 'smiles', 'properties' (column name -> value for every
            entry of ``self.property_columns``) and 'property_vector' (the same
            values as a NumPy array, in ``self.property_columns`` order).
        """
        row = self.df.iloc[idx]
        properties = {prop: row[prop] for prop in self.property_columns}

        return {
            'smiles': row['smiles_basic'],
            'properties': properties,
            'property_vector': np.array(list(properties.values()))
        }

    def get_smiles_list(self):
        """Get all SMILES strings"""
        return self.df['smiles_basic'].tolist()

    def _build_molecule_info(self, idx):
        """Bundle coordinates and properties for an already-resolved row position."""
        return {
            'coordinates_data': self._parse_single_coords(idx),
            'properties_data': self._get_single_properties(idx)
        }

    def get_data_by_smiles(self, target_smiles):
        """Get coordinates and properties data by SMILES string

        Returns:
            dict with 'coordinates_data' and 'properties_data', or None when the
            SMILES string is not in the dataset.
        """
        if target_smiles is None:
            # None has always meant "every molecule" here; keep that behaviour.
            return self.get_molecule_info(None)
        try:
            idx = self._resolve_identifier(target_smiles)
        except ValueError:
            # Only a failed lookup maps to None. Parsing happens outside the try
            # so corrupt coordinate data is not misreported as "not found".
            return None
        return self._build_molecule_info(idx)

    def get_molecule_info(self, identifier):
        """Get complete molecule information (coordinates + properties)

        Args:
            identifier: int (index) or str (SMILES); None returns the data of
                every molecule as two lists.

        Raises:
            ValueError, IndexError, TypeError: see ``_resolve_identifier``.
        """
        if identifier is None:
            return {
                'coordinates_data': self.extract_atom_coords(),
                'properties_data': self.extract_properties()
            }
        # Resolve once so a SMILES lookup scans the column a single time.
        return self._build_molecule_info(self._resolve_identifier(identifier))

In [3]:
# Instantiate the QM9DataExtractor class defined earlier in this notebook
extractor = QM9DataExtractor('qm9_demo_0.5k.csv')

# Fetch the first molecule's coordinate data by integer index
coords_data = extractor.extract_atom_coords(0)
print("First molecule coordinates:", coords_data, sep="\n")


First molecule coordinates:
{'smiles': 'C', 'atoms': ['C', 'H', 'H', 'H', 'H'], 'coordinates': array([[-1.26981359e-02,  1.08580416e+00,  8.00099580e-03],
       [ 2.15041600e-03, -6.03131760e-03,  1.97612040e-03],
       [ 1.01173084e+00,  1.46375116e+00,  2.76574800e-04],
       [-5.40815069e-01,  1.44752661e+00, -8.76643715e-01],
       [-5.23813634e-01,  1.43793264e+00,  9.06397294e-01]]), 'charges': array([-0.535689,  0.133921,  0.133922,  0.133923,  0.133923]), 'num_atoms': 5}


In [4]:
# Repeat the lookup with a SMILES string that actually exists in the dataset
coords_data = extractor.extract_atom_coords("C")
print("\nCoordinates for molecule with SMILES 'C':", coords_data, sep="\n")


Coordinates for molecule with SMILES 'C':
{'smiles': 'C', 'atoms': ['C', 'H', 'H', 'H', 'H'], 'coordinates': array([[-1.26981359e-02,  1.08580416e+00,  8.00099580e-03],
       [ 2.15041600e-03, -6.03131760e-03,  1.97612040e-03],
       [ 1.01173084e+00,  1.46375116e+00,  2.76574800e-04],
       [-5.40815069e-01,  1.44752661e+00, -8.76643715e-01],
       [-5.23813634e-01,  1.43793264e+00,  9.06397294e-01]]), 'charges': array([-0.535689,  0.133921,  0.133922,  0.133923,  0.133923]), 'num_atoms': 5}


In [5]:
# End-to-end demonstration of the extractor's features
print("=== QM9DataExtractor Complete Functionality Demo ===\n")

# 1. Basic information about the dataset
print(f"Dataset size: {len(extractor.df)} molecules")
print(f"Available properties: {extractor.property_columns}")

# 2. Fetch molecule information by integer index
print("\n--- Example 1: Access by Index ---")
mol_info_by_idx = extractor.get_molecule_info(0)
print(f"SMILES: {mol_info_by_idx['coordinates_data']['smiles']}")
print(f"Atoms: {mol_info_by_idx['coordinates_data']['atoms']}")
print(f"HOMO energy: {mol_info_by_idx['properties_data']['properties']['homo']}")

# 3. Fetch the same molecule again, this time by its SMILES string
print("\n--- Example 2: Access by SMILES ---")
target_smiles = mol_info_by_idx['coordinates_data']['smiles']
mol_info_by_smiles = extractor.get_molecule_info(target_smiles)
print(f"Retrieved same molecule using SMILES '{target_smiles}'")
print(f"Number of atoms: {mol_info_by_smiles['coordinates_data']['num_atoms']}")
print(f"LUMO energy: {mol_info_by_smiles['properties_data']['properties']['lumo']}")

# 4. Error handling: an unknown SMILES string raises ValueError
print("\n--- Example 3: Error Handling ---")
try:
    # The result is intentionally discarded; only the raised error matters here.
    extractor.extract_atom_coords("INVALID_SMILES")
except ValueError as e:
    print(f"Caught expected error: {e}")

# 5. Batch processing example
print("\n--- Example 4: Batch Processing ---")
all_smiles = extractor.get_smiles_list()
# Slice once so the header count and the loop can never disagree, and so a
# dataset with fewer than three rows does not index past the end of the list.
preview_smiles = all_smiles[:3]
print(f"Processing first {len(preview_smiles)} molecules:")
for i, smiles in enumerate(preview_smiles):
    coords = extractor.extract_atom_coords(i)
    print(f"  Molecule {i}: {smiles} - {coords['num_atoms']} atoms")

print("\n=== Demo Complete ===")


=== QM9DataExtractor Complete Functionality Demo ===

Dataset size: 500 molecules
Available properties: ['A', 'B', 'C', 'miu', 'alpha', 'homo', 'lumo', 'gap', 'R2', 'zpve', 'Uo', 'U', 'H', 'G', 'Cv']

--- Example 1: Access by Index ---
SMILES: C
Atoms: ['C', 'H', 'H', 'H', 'H']
HOMO energy: -0.3877

--- Example 2: Access by SMILES ---
Retrieved same molecule using SMILES 'C'
Number of atoms: 5
LUMO energy: 0.1171

--- Example 3: Error Handling ---
Caught expected error: SMILES 'INVALID_SMILES' not found in dataset

--- Example 4: Batch Processing ---
Processing first 3 molecules:
  Molecule 0: C - 5 atoms
  Molecule 1: N - 4 atoms
  Molecule 2: O - 3 atoms

=== Demo Complete ===
