# Exploring the QM9 Dataset: A Comprehensive Machine Learning Tutorial for Molecular Property Prediction

**Estimated Time: 90 minutes**

- Download QM9 dataset: https://doi.org/10.6084/m9.figshare.978904
- Set `QM9_ZIP_PATH` in Section 1.3 to the location of the downloaded archive

---

## Introduction

The **QM9 dataset** is one of the most widely used benchmark datasets in computational chemistry and machine learning for molecules. It contains quantum mechanical properties of 133,885 small organic molecules, each with up to 9 heavy atoms (C, N, O, F) and hydrogen atoms.

For QM9, researchers took a specific subset of the **GDB-17 database**: all molecules containing up to 9 heavy atoms (C, O, N, F). GDB-17 is an enumerated chemical universe of nearly 166 billion organic molecules with up to 17 heavy atoms.

### Why QM9 Matters

Understanding molecular properties is fundamental to:

- **Drug Discovery**: Predicting how molecules will interact with biological targets
- **Materials Science**: Designing molecules with specific electronic or optical properties
- **Green Chemistry**: Finding catalysts and reactions that are more sustainable
- **Energy Storage**: Developing better battery materials and solar cell components

Traditionally, computing these properties requires expensive quantum mechanical calculations using methods like Density Functional Theory (DFT). Depending on molecular size and level of theory, a single molecule can take anywhere from minutes to days of compute, and the cost grows steeply with the number of electrons. Machine learning offers a promising alternative: once trained on DFT data, models can predict properties in milliseconds.

### What's in QM9?

Each molecule in QM9 comes with 15 scalar properties computed at the B3LYP/6-31G(2df,p) level of DFT:

- **Geometric properties**: Rotational constants (A, B, C)
- **Electronic properties**: HOMO, LUMO, gap, dipole moment (μ), isotropic polarizability (α)
- **Thermodynamic properties**: Zero-point vibrational energy (ZPVE), internal energies (U₀ at 0 K, U at 298.15 K), enthalpy (H), free energy (G), heat capacity (Cᵥ)
- **Other**: Electronic spatial extent (⟨R²⟩)

The dataset also includes relaxed atomic coordinates (in Ångströms), Mulliken partial charges, harmonic vibrational frequencies, and SMILES and InChI strings for each molecule.
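The properties are stored in mixed units, which matters as soon as you compare or convert them. The mapping below is a small reference sketch: the dictionary and helper names are our own (not part of any library), the units follow the QM9 data descriptor, and the kcal/mol factor is the standard rounded value.

```python
# Units of the 15 scalar QM9 properties (names match the DataFrame columns used later)
QM9_UNITS = {
    "A": "GHz", "B": "GHz", "C": "GHz",
    "mu": "Debye",
    "alpha": "Bohr^3",
    "homo": "Hartree", "lumo": "Hartree", "gap": "Hartree",
    "R2": "Bohr^2",
    "zpve": "Hartree",
    "U0": "Hartree", "U": "Hartree", "H": "Hartree", "G": "Hartree",
    "Cv": "cal/(mol K)",
}

# Rounded conversion factors from Hartree
HARTREE_TO = {"eV": 27.2114, "kcal/mol": 627.5095}


def convert_energy(value_hartree: float, unit: str) -> float:
    """Convert an energy given in Hartree to eV or kcal/mol."""
    if unit not in HARTREE_TO:
        raise ValueError(f"Unsupported unit {unit!r}; choose from {list(HARTREE_TO)}")
    return value_hartree * HARTREE_TO[unit]


# Example: a HOMO-LUMO gap of 0.3179 Hartree expressed in eV
print(f"{convert_energy(0.3179, 'eV'):.3f} eV")
```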

### Chemical Context

Predicting molecular properties from structure is a core challenge in chemistry. The relationship between a molecule's atoms and bonds (its **topology**) and its 3D arrangement (its **geometry**) determines all its properties. In this tutorial, we'll explore how to:

1. Represent molecules numerically (featurization)
2. Identify patterns in chemical space (clustering, visualization)
3. Build predictive models for quantum properties (regression)

---

## Learning Objectives

By the end of this tutorial, you will be able to:

1. **Load and parse** the QM9 dataset, extracting molecular properties and structures
2. **Perform EDA** to explore distributions and correlations in molecular data
3. **Generate molecular fingerprints** using RDKit (Morgan, MACCS)
4. **Visualize chemical space** using PCA and UMAP
5. **Apply clustering** to identify molecular families

---

## Dependencies and Setup

Before we begin, let's install the required packages. We'll use:

- **RDKit**: Cheminformatics toolkit for molecular processing
- **pandas/numpy**: Data manipulation
- **matplotlib/seaborn**: Visualization
- **scikit-learn**: Machine learning
- **umap-learn**: Dimensionality reduction

In [None]:
# Install dependencies (run once)
# Uncomment the following lines if packages are not installed

# !pip install rdkit pandas numpy matplotlib seaborn scikit-learn umap-learn tqdm


In [None]:
# Import all required libraries
import os
import warnings
import zipfile
import shutil
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Display settings
pd.set_option('display.max_columns', 20)
pd.set_option('display.width', 200)

print("All imports successful!")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")

---

## Section 1: Loading the QM9 Dataset

In this section, we'll load the QM9 dataset by parsing individual XYZ files. Each XYZ file contains atomic coordinates, quantum mechanical properties, and molecular identifiers like SMILES.

### 1.1 Understanding the XYZ File Format

QM9 XYZ files have a specific structure:

```
Line 1:      Number of atoms (N)
Line 2:      Properties: "gdb" label + integer tag, then A, B, C, μ, α, HOMO, LUMO, gap, R², zpve, U₀, U, H, G, Cᵥ
Lines 3-N+2: Atom element, x, y, z, Mulliken partial charge
Line N+3:    Harmonic vibrational frequencies (cm⁻¹)
Line N+4:    SMILES from GDB-17 and from the B3LYP-relaxed geometry
Line N+5:    InChI from GDB-17 and from the B3LYP-relaxed geometry
```
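Because a parser indexes into `readlines()`, which is 0-based, it helps to translate the 1-based table above into Python indices. The helper below is a hypothetical convenience (not part of the tutorial's code) that returns the index of each block for a molecule with `n_atoms` atoms:

```python
from typing import Dict


def qm9_line_layout(n_atoms: int) -> Dict[str, object]:
    """Return 0-based line indices for each block of a QM9 XYZ file.

    "Line 1" in the 1-based table corresponds to index 0 in readlines().
    """
    return {
        "n_atoms": 0,
        "properties": 1,
        "atoms": range(2, 2 + n_atoms),  # one line per atom
        "frequencies": 2 + n_atoms,
        "smiles": 3 + n_atoms,
        "inchi": 4 + n_atoms,
    }


layout = qm9_line_layout(10)  # ethylene glycol has 10 atoms
print(layout["smiles"], list(layout["atoms"])[:3])
```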

For example, `data/QM9/dsgdb9nsd_000042.xyz` contains a database entry for ethylene glycol.

```
10
gdb 42  14.79671        5.6956  4.58846 0.0075  31.42   -0.2594 0.0584  0.3179  297.8398        0.085172       -230.183076     -230.177723     -230.176779     -230.211195     16.837
O       -0.0141873452    1.4264224166   -0.0542365107   -0.42367
C       -0.003918394     0.0071771432    0.0349606216   -0.064049
C       -1.4231701908   -0.5556060013   -0.0125663926   -0.064037
O       -2.1122941316   -0.1056387482   -1.1722831277   -0.423669
H       -0.1898788211    1.6415470511   -0.9755514801    0.283356
H        0.5965950414   -0.4331542952   -0.7740909498    0.0965
H        0.4731111572   -0.2490373319    0.9874721596    0.107862
H       -1.394948591    -1.6497527924   -0.0628318813    0.107862
H       -1.9649741719   -0.2721691543    0.9013197201    0.096497
H       -2.359259913     0.8094439625   -1.0061459692    0.28335
76.5151 149.3748        316.1843        505.8235        532.5295        870.3558        876.6631      1045.1048        1053.4694       1116.1618       1200.2034       1245.2124       1393.7745       1400.2924      1402.7471       1412.2916       1489.6311       1496.1831       2998.9079       3005.5406     3073.897 3078.4301       3823.0092       3823.8243
OCCO    OCCO
InChI=1S/C2H6O2/c3-1-2-4/h3-4H,1-2H2    InChI=1S/C2H6O2/c3-1-2-4/h3-4H,1-2H2
```
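One quirk worth knowing before writing the parser: some QM9 files store very small numbers in Mathematica-style notation, e.g. `1.2345*^-6` instead of `1.2345e-6`, and Python's `float()` rejects that form. A minimal workaround, assuming `*^` is the only non-standard notation you encounter, is to rewrite it as `e` before converting (the function name is our own):

```python
def parse_qm9_float(token: str) -> float:
    """Convert a QM9 numeric token to float, accepting Mathematica-style '*^' exponents."""
    return float(token.replace("*^", "e"))


# Standard and Mathematica-style tokens both parse
print(parse_qm9_float("-0.2594"), parse_qm9_float("2.1997*^-6"))
```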

Let's write a parser to extract all this information:

In [None]:
def parse_qm9_xyz(filepath: str) -> Optional[Dict]:
    """
    Parse a single QM9 XYZ file and extract all properties.

    Parameters:
    -----------
    filepath : str
        Path to the XYZ file

    Returns:
    --------
    dict or None
        Dictionary containing all extracted properties, or None if parsing fails
    """
    # Property names in order they appear in the XYZ file
    property_names = [
        'tag', 'A', 'B', 'C', 'mu', 'alpha',
        'homo', 'lumo', 'gap', 'R2', 'zpve',
        'U0', 'U', 'H', 'G', 'Cv'
    ]

    try:
        with open(filepath, 'r') as f:
            lines = f.readlines()

        # Line 0: Number of atoms
        n_atoms = int(lines[0].strip())

        # Line 1: "gdb {tag}" followed by 15 numeric properties.
        # Splitting on any whitespace handles both tab- and space-separated files.
        prop_line = lines[1].split()

        # Extract the integer tag from the "gdb {tag}" prefix
        tag = int(prop_line[1])

        # Extract numeric properties; some files use Mathematica-style exponents
        # such as "1.2*^-6", so convert "*^" to "e" before calling float()
        properties = {'tag': tag}
        for name, value in zip(property_names[1:], prop_line[2:]):
            properties[name] = float(value.replace('*^', 'e'))

        # Extract atomic coordinates and elements
        atoms = []
        coords = []
        charges = []

        for i in range(2, 2 + n_atoms):
            parts = lines[i].split()
            atoms.append(parts[0])
            # Coordinates and charges may also use "*^" exponents
            coords.append([float(x.replace('*^', 'e')) for x in parts[1:4]])
            charges.append(float(parts[4].replace('*^', 'e')))

        properties['atoms'] = atoms
        properties['coordinates'] = coords
        properties['partial_charges'] = charges
        properties['n_atoms'] = n_atoms

        # The SMILES line follows the frequency line (0-based index n_atoms + 3) and
        # holds two SMILES: one from GDB-17 and one from the B3LYP-relaxed geometry
        smiles_line = lines[n_atoms + 3].split()
        properties['smiles'] = smiles_line[0]  # Original SMILES from GDB-17

        return properties

    except Exception as e:
        print(f"Error parsing {filepath}: {e}")
        return None
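Before pointing a parser at thousands of files, it can help to sanity-check the property-line logic on the ethylene glycol example above. This standalone sketch re-implements only that step; the line is assembled from the example's values joined with tabs (an assumption about the raw file's separators), and splitting on any whitespace makes the result independent of that choice:

```python
# Property line assembled from the dsgdb9nsd_000042.xyz example values
PROP_LINE = ("gdb 42\t14.79671\t5.6956\t4.58846\t0.0075\t31.42\t-0.2594\t0.0584\t0.3179\t"
             "297.8398\t0.085172\t-230.183076\t-230.177723\t-230.176779\t-230.211195\t16.837\t")

PROPERTY_NAMES = ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap",
                  "R2", "zpve", "U0", "U", "H", "G", "Cv"]


def parse_property_line(line: str) -> dict:
    """Parse 'gdb <tag> <15 values>' regardless of tab/space separation."""
    parts = line.split()  # any run of whitespace; trailing tabs are dropped
    if parts[0] != "gdb" or len(parts) != 2 + len(PROPERTY_NAMES):
        raise ValueError(f"Unexpected property line: {line[:40]!r}...")
    props = {"tag": int(parts[1])}
    props.update({name: float(v.replace("*^", "e"))
                  for name, v in zip(PROPERTY_NAMES, parts[2:])})
    return props


props = parse_property_line(PROP_LINE)
print(props["tag"], props["gap"], props["lumo"] - props["homo"])
```

The stored gap and LUMO − HOMO differ only in the last printed digit, because each value is rounded to four decimals in the file.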

### 1.2 Loading Multiple Molecules

Now let's load a subset of the QM9 dataset. We'll load the first ~10,000 molecules to keep processing time reasonable:

In [None]:
def load_qm9_dataset(data_dir: str, n_molecules: int = 10000) -> Tuple[pd.DataFrame, List[Dict]]:
    """
    Load QM9 molecules from XYZ files into a pandas DataFrame.

    Parameters:
    -----------
    data_dir : str
        Directory containing QM9 XYZ files
    n_molecules : int
        Maximum number of molecules to load

    Returns:
    --------
    Tuple[pd.DataFrame, List[Dict]]
        DataFrame of scalar molecular properties, and the list of raw parsed
        records (including coordinates and partial charges)
    """
    data_path = Path(data_dir)
    xyz_files = sorted(data_path.glob('dsgdb9nsd_*.xyz'))

    print(f"Found {len(xyz_files)} XYZ files")
    print(f"Loading first {n_molecules} molecules...")

    molecules = []

    for filepath in tqdm(xyz_files[:n_molecules], desc="Parsing XYZ files"):
        mol_data = parse_qm9_xyz(str(filepath))
        if mol_data is not None:
            molecules.append(mol_data)

    print(f"Successfully parsed {len(molecules)} molecules")

    from collections import Counter  # used to build molecular formulas

    # Create DataFrame (coordinate arrays stay in the raw records, not the main table)
    scalar_keys = ['tag', 'smiles', 'n_atoms', 'A', 'B', 'C', 'mu', 'alpha',
                   'homo', 'lumo', 'gap', 'R2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
    df_data = []
    for mol in molecules:
        row = {key: mol[key] for key in scalar_keys}
        # Molecular formula with C and H first, then F, N, O alphabetically (e.g. C2H6O2)
        counts = Counter(mol['atoms'])
        row['formula'] = ''.join(
            f"{el}{counts[el] if counts[el] > 1 else ''}"
            for el in ['C', 'H', 'F', 'N', 'O'] if el in counts
        )
        df_data.append(row)

    df = pd.DataFrame(df_data)
    return df, molecules  # Return both DataFrame and raw data (with coordinates)
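Taking the first `n_molecules` files in sorted order is simple, but the files appear to be numbered roughly by increasing heavy-atom count (entry 1 is methane), so the first 10,000 are likely skewed toward smaller molecules. If you want a more representative subset, a seeded random sample of file paths is one option. The helper below is a hypothetical drop-in for the `xyz_files[:n_molecules]` slice:

```python
import random
from pathlib import Path
from typing import List


def sample_xyz_files(data_dir: str, n: int, seed: int = 0) -> List[Path]:
    """Return a reproducible random sample of up to n QM9 XYZ paths, sorted by name."""
    files = sorted(Path(data_dir).glob("dsgdb9nsd_*.xyz"))
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    picked = rng.sample(files, k=min(n, len(files)))
    return sorted(picked)


# Usage inside load_qm9_dataset (hypothetical): replace xyz_files[:n_molecules] with
# sample_xyz_files(data_dir, n_molecules, seed=42)
```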

### 1.3 Extract and Load the Dataset

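The code below expects the QM9 XYZ files to be packaged in a ZIP archive. Depending on where you downloaded QM9, you may instead have a compressed tarball, which needs a different extraction step. Let's extract the archive and load the data: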

In [None]:
def find_qm9_directory(base_dir: Path, zip_stem: Optional[str] = None) -> Optional[Path]:
    """
    Find the QM9 data directory by checking common naming conventions.
    
    Parameters:
    -----------
    base_dir : Path
        Directory to search in
    zip_stem : str, optional
        The stem (name without extension) of the ZIP file, used as additional search hint
    
    Returns:
    --------
    Path | None : Path to directory containing xyz files, or None if not found
    """
    # Standard QM9 directory names
    possible_dirs = [
        base_dir / 'QM9',
        base_dir / 'dsgdb9nsd',
        base_dir / 'qm9',
    ]
    
    # If zip_stem provided, also check for directory matching the zip file name
    if zip_stem and zip_stem.lower() not in ['qm9', 'dsgdb9nsd']:
        possible_dirs.append(base_dir / zip_stem)
    
    for dir_path in possible_dirs:
        if dir_path.exists() and any(dir_path.glob('*.xyz')):
            return dir_path
    
    # Also check if xyz files are directly in base_dir
    if any(base_dir.glob('*.xyz')):
        return base_dir
    
    return None


def extract_qm9_dataset(zip_path: str, extract_to: Optional[str] = None) -> str:
    """
    Extract QM9 dataset from ZIP file.
    
    Parameters:
    -----------
    zip_path : str
        Path to the QM9 ZIP file (can have any name, e.g., QM9.zip, archive.zip)
    extract_to : str, optional
        Directory to extract to. If None, extracts to same directory as ZIP file.
    
    Returns:
    --------
    str : Path to the extracted QM9 directory
    """
    zip_path = Path(zip_path)
    
    if not zip_path.exists():
        raise FileNotFoundError(f"ZIP file not found: {zip_path}")
    
    # Determine extraction directory
    if extract_to is None:
        extract_to = zip_path.parent
    else:
        extract_to = Path(extract_to)
    
    extract_to.mkdir(parents=True, exist_ok=True)
    
    # Check if already extracted
    zip_stem = zip_path.stem
    existing_dir = find_qm9_directory(extract_to, zip_stem)
    if existing_dir is not None:
        print(f"✓ QM9 dataset already extracted at: {existing_dir}")
        xyz_files = list(existing_dir.glob('*.xyz'))
        print(f"  Found {len(xyz_files)} molecule files (.xyz)")
        return str(existing_dir)
    
    print(f"Extracting QM9 dataset from {zip_path.name}...")
    
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # Get list of files to extract
            file_list = zip_ref.namelist()
            
            # Extract with progress bar
            for file in tqdm(file_list, desc="Extracting files"):
                zip_ref.extract(file, extract_to)
        
        # Find the extracted directory
        qm9_dir = find_qm9_directory(extract_to, zip_stem)
        
        if qm9_dir is None:
            raise RuntimeError(
                f"Extraction completed but no xyz files found in {extract_to}. "
                "The ZIP file may not contain QM9 data in the expected format."
            )
        
        # Rename to standard 'QM9' if needed (skip if files are directly in extract_to)
        if qm9_dir.name != 'QM9' and qm9_dir != extract_to:
            final_dir = extract_to / 'QM9'
            if final_dir.exists():
                shutil.rmtree(final_dir)
            qm9_dir.rename(final_dir)
            qm9_dir = final_dir
        
        print(f"✓ Extraction complete! Dataset location: {qm9_dir}")
        
        # Count files
        xyz_files = list(qm9_dir.glob('*.xyz'))
        print(f"  Found {len(xyz_files)} molecule files (.xyz)")
        
        return str(qm9_dir)
        
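    except zipfile.BadZipFile as e:
        raise ValueError(f"Invalid ZIP file: {zip_path}") from e
    except RuntimeError:
        raise  # already carries a descriptive message (e.g., no .xyz files found)
    except Exception as e:
        raise RuntimeError(f"Error extracting ZIP file: {e}") from e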

In [None]:
# Set the path to your QM9 ZIP file
QM9_ZIP_PATH = '../../data/QM9.zip'  # Adjust this path as needed

# Extract the dataset (this will skip extraction if already done)
QM9_DATA_DIR = extract_qm9_dataset(QM9_ZIP_PATH)

# Load ~10,000 molecules
print("\nLoading QM9 molecules...")
df_qm9, molecules_raw = load_qm9_dataset(QM9_DATA_DIR, n_molecules=10000)
print(f"✓ Loaded {len(df_qm9)} molecules successfully!")

**Note:** The extraction only happens once. If the dataset is already extracted, the function will detect it and skip extraction, making subsequent runs much faster.
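If your download is a bzip2-compressed tarball (`.tar.bz2`) rather than a ZIP, Python's standard `tarfile` module can unpack it. The function below is a hypothetical counterpart to `extract_qm9_dataset`, a sketch that assumes the `.xyz` files sit either at the top level of the archive or inside a single subdirectory:

```python
import tarfile
from pathlib import Path


def extract_qm9_tarball(archive_path: str, extract_to: str) -> Path:
    """Extract a .tar.bz2 QM9 archive and return the directory holding the .xyz files."""
    archive_path, extract_to = Path(archive_path), Path(extract_to)
    extract_to.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:bz2") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Python versions: reject absolute paths, '..' components, device files, etc.
            tar.extractall(extract_to, filter="data")
        else:
            tar.extractall(extract_to)  # only use on archives you trust
    # The .xyz files may sit at the top level or inside one subdirectory
    for candidate in [extract_to, *sorted(p for p in extract_to.iterdir() if p.is_dir())]:
        if any(candidate.glob("*.xyz")):
            return candidate
    raise RuntimeError(f"No .xyz files found under {extract_to}")
```

The returned path can be passed to `load_qm9_dataset` in place of `QM9_DATA_DIR`.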

### 1.4 Exploring the DataFrame

Let's look at what we've loaded:

In [None]:
# Display basic info
print(f"Dataset shape: {df_qm9.shape}")
print(f"\nColumn names:\n{df_qm9.columns.tolist()}")
print(f"\nData types:\n{df_qm9.dtypes}")

In [None]:
# Display first few molecules
print("\nFirst 5 molecules:")
df_qm9.head()

In [None]:
# Display summary statistics
print("\nSummary statistics for quantum properties:")
df_qm9[['homo', 'lumo', 'gap', 'mu', 'alpha', 'Cv']].describe()

In [None]:
# Look at a few example SMILES
print("\nExample SMILES strings:")
for i in range(5):
    print(f"  Tag {df_qm9.iloc[i]['tag']:6d}: {df_qm9.iloc[i]['smiles']}")

### 1.5 Quick Data Validation

Let's verify our parsing is correct by checking some expected patterns:

In [None]:
# Verify HOMO-LUMO gap consistency (gap should equal LUMO - HOMO).
# Values are stored with 4 decimals, so a residual of ~1e-4 Hartree is expected.
computed_gap = df_qm9['lumo'] - df_qm9['homo']
parsing_error = (computed_gap - df_qm9['gap']).abs().mean()
print(f"Gap computation verification (mean error): {parsing_error:.6f}")

# Check that every SMILES string can actually be parsed by RDKit
from rdkit import Chem
valid_smiles = df_qm9['smiles'].apply(lambda s: Chem.MolFromSmiles(s) is not None).sum()
print(f"Molecules with valid SMILES: {valid_smiles}/{len(df_qm9)}")

# Distribution of molecule sizes
print(f"\nAtom count distribution:")
print(df_qm9['n_atoms'].value_counts().sort_index())

You've successfully loaded the QM9 dataset! The DataFrame `df_qm9` contains the 15 scalar quantum mechanical properties for each molecule, along with SMILES strings for molecular structure representation.

In [None]:
# Convert Hartree to eV for more intuitive interpretation
# 1 Hartree = 27.2114 eV
HARTREE_TO_EV = 27.2114

df_qm9['homo_eV'] = df_qm9['homo'] * HARTREE_TO_EV
df_qm9['lumo_eV'] = df_qm9['lumo'] * HARTREE_TO_EV
df_qm9['gap_eV'] = df_qm9['gap'] * HARTREE_TO_EV

print("Added energy columns in eV (electron volts) for easier interpretation")
print(f"Gap range: {df_qm9['gap_eV'].min():.2f} to {df_qm9['gap_eV'].max():.2f} eV")

## Section 2: Exploratory Data Analysis

Now that the data is loaded and the orbital energies are also available in eV, let's explore it visually. Exploratory Data Analysis (EDA) helps us:

1. Understand the **distribution** of molecular properties
2. Identify **outliers** or unusual molecules
3. Discover **relationships** between properties
4. Guide our **feature engineering** and modeling decisions

### 2.1 Property Distributions: Histograms with KDE

Let's start by examining the distributions of key electronic properties. We'll use histograms with Kernel Density Estimation (KDE) overlays to see both the raw counts and smooth probability density.
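A KDE curve is simply an average of one smooth "bump" per data point. The sketch below evaluates a 1D Gaussian KDE by hand with a Scott's-rule bandwidth, broadly the same idea seaborn uses for its default KDE overlay; the data are synthetic (a normal sample standing in for the gap column), not QM9 values:

```python
import numpy as np


def gaussian_kde_1d(samples: np.ndarray, grid: np.ndarray) -> np.ndarray:
    """Evaluate a 1D Gaussian KDE with Scott's-rule bandwidth on `grid`."""
    samples = np.asarray(samples, dtype=float)
    n = samples.size
    bandwidth = samples.std(ddof=1) * n ** (-1 / 5)  # Scott's rule in one dimension
    # One Gaussian bump per sample, evaluated at every grid point: shape (len(grid), n)
    z = (grid[:, None] - samples[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z**2) / (bandwidth * np.sqrt(2 * np.pi))
    return kernels.mean(axis=1)  # averaging keeps the total area equal to 1


rng = np.random.default_rng(0)
gaps = rng.normal(loc=6.8, scale=1.0, size=500)  # synthetic stand-in for gap_eV
grid = np.linspace(gaps.min() - 3, gaps.max() + 3, 2000)
density = gaussian_kde_1d(gaps, grid)
```

A wider bandwidth smooths away detail such as the multiple peaks you may see in size-dependent properties; a narrower one follows the histogram more closely but gets noisier.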

In [None]:
# Create figure with histograms for key properties
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Properties to visualize with their chemical context
properties_to_plot = [
    ('homo_eV', 'HOMO Energy (eV)', 'steelblue'),
    ('lumo_eV', 'LUMO Energy (eV)', 'darkorange'),
    ('gap_eV', 'HOMO-LUMO Gap (eV)', 'seagreen'),
    ('mu', 'Dipole Moment (Debye)', 'darkviolet')
]

for ax, (prop, label, color) in zip(axes.flatten(), properties_to_plot):
    # Plot histogram with KDE overlay
    sns.histplot(data=df_qm9, x=prop, kde=True, ax=ax, color=color,
                 bins=50, alpha=0.7, edgecolor='white', linewidth=0.5)

    # Add vertical line for mean
    mean_val = df_qm9[prop].mean()
    ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.2f}')

    ax.set_xlabel(label, fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    ax.set_title(f'Distribution of {label}', fontsize=14, fontweight='bold')
    ax.legend()

plt.tight_layout()
plt.show()

The plots reveal important patterns in the QM9 dataset. Most molecules have HOMO energies between -8 and -6 eV and HOMO-LUMO gaps centered around 6-7 eV, typical for small organic insulators. The dipole moment distribution shows that many molecules are non-polar or weakly polar, with a tail extending toward highly polar species.

### 2.2 Thermodynamic Property Distributions

Let's also examine the thermodynamic properties, together with polarizability:

In [None]:
# Create figure for thermodynamic properties
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Thermodynamic properties to visualize
thermo_props = [
    ('Cv', 'Heat Capacity Cᵥ (cal/mol·K)', 'coral'),
    ('zpve', 'Zero-Point Energy (Hartree)', 'teal'),
    ('G', 'Free Energy G (Hartree)', 'mediumorchid'),
    ('alpha', 'Polarizability α (Bohr³)', 'goldenrod')
]

for ax, (prop, label, color) in zip(axes.flatten(), thermo_props):
    sns.histplot(data=df_qm9, x=prop, kde=True, ax=ax, color=color,
                 bins=50, alpha=0.7, edgecolor='white', linewidth=0.5)

    mean_val = df_qm9[prop].mean()
    ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.2f}')

    ax.set_xlabel(label, fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    ax.set_title(f'Distribution of {label}', fontsize=14, fontweight='bold')
    ax.legend()

plt.tight_layout()
plt.show()

Before interpreting these plots, let's summarize the ranges numerically:

In [None]:
# Summarize thermodynamic properties
print("Thermodynamic Property Summary:")
print("-" * 50)

for prop, label, _ in thermo_props:
    print(f"\n{label}:")
    print(f"  Range: {df_qm9[prop].min():.3f} to {df_qm9[prop].max():.3f}")
    print(f"  Mean ± Std: {df_qm9[prop].mean():.3f} ± {df_qm9[prop].std():.3f}")

**Key Observations:**

1. **Heat Capacity (Cᵥ):** Shows a multimodal distribution reflecting discrete molecular sizes. Larger molecules have more vibrational modes and higher heat capacity.

2. **Zero-Point Energy (zpve):** Positively correlated with molecular size. More atoms = more vibrational modes = more zero-point energy.

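3. **Free Energy (G):** Large negative values, because QM9 energies are total energies measured relative to separated electrons and nuclei, dominated by the tightly bound core electrons. The distribution therefore clusters by elemental composition: each additional heavy atom shifts G by tens of Hartree.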

4. **Polarizability (α):** Right-skewed distribution. Larger molecules with more electrons are more polarizable.
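The link between size and ZPVE follows from the harmonic approximation: each of the 3N − 6 vibrational modes of a nonlinear molecule contributes half a vibrational quantum, so ZPVE = ½ Σᵢ h c ν̃ᵢ and more atoms mean more modes. Using the frequency line of the ethylene glycol example from Section 1.1, this sum should land within about 1e-6 Hartree of the stored `zpve` (the conversion factor is the rounded CODATA value):

```python
# Harmonic frequencies (cm^-1) of ethylene glycol, copied from the Section 1.1 example
freqs_cm = [76.5151, 149.3748, 316.1843, 505.8235, 532.5295, 870.3558, 876.6631,
            1045.1048, 1053.4694, 1116.1618, 1200.2034, 1245.2124, 1393.7745,
            1400.2924, 1402.7471, 1412.2916, 1489.6311, 1496.1831, 2998.9079,
            3005.5406, 3073.897, 3078.4301, 3823.0092, 3823.8243]

CM1_TO_HARTREE = 4.556335e-6  # energy of 1 cm^-1 expressed in Hartree

n_atoms = 10
assert len(freqs_cm) == 3 * n_atoms - 6  # nonlinear molecule: 3N - 6 vibrational modes

zpve = 0.5 * sum(freqs_cm) * CM1_TO_HARTREE  # ZPVE = (1/2) * sum of h*c*nu_i
print(f"ZPVE from frequencies: {zpve:.6f} Ha (file value: 0.085172 Ha)")
```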

### 2.3 Distribution by Molecular Size

Let's see how properties vary with the number of atoms:

In [None]:
# Create box plots to show property distributions by molecule size
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Properties to analyze by atom count
props_by_size = [
    ('gap_eV', 'HOMO-LUMO Gap (eV)'),
    ('mu', 'Dipole Moment (Debye)'),
    ('Cv', 'Heat Capacity (cal/mol·K)'),
    ('alpha', 'Polarizability (Bohr³)')
]

# Filter to common atom counts for cleaner visualization (done once, outside the loop)
df_filtered = df_qm9[df_qm9['n_atoms'].isin(range(5, 30))]

for ax, (prop, label) in zip(axes.flatten(), props_by_size):

    sns.boxplot(data=df_filtered, x='n_atoms', y=prop, ax=ax,
                palette='viridis', showfliers=False)

    ax.set_xlabel('Number of Atoms', fontsize=12)
    ax.set_ylabel(label, fontsize=12)
    ax.set_title(f'{label} vs Molecular Size', fontsize=14, fontweight='bold')

    # Rotate x-axis labels for readability
    ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

To quantify the trends in the box plots, let's correlate each property with the number of atoms:

In [None]:
# Calculate correlations with molecular size
print("Correlation with Number of Atoms:")
print("-" * 40)
size_correlations = {}
for prop in ['gap_eV', 'mu', 'Cv', 'alpha', 'homo_eV', 'lumo_eV']:
    corr = df_qm9['n_atoms'].corr(df_qm9[prop])
    size_correlations[prop] = corr
    print(f"  {prop:12s}: r = {corr:+.3f}")

print("\nExpected patterns (compare with the values above):")
print("  - Cᵥ and α should increase strongly with size (more atoms = more modes, more electrons)")
print("  - Gap may decrease slightly with size (larger conjugated systems tend to have smaller gaps)")
print("  - Dipole moment should correlate weakly (it depends more on symmetry than size)")

### 2.4 Identifying Outliers

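Before looking for unusual molecules, we remove a few extreme rotational constants that would otherwise dominate any plot involving A. Then we list molecules with unusual electronic properties that might be interesting for further study: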

In [None]:
# Remove outliers in rotational constant A.
# Very small molecules have tiny moments of inertia and therefore very large rotational
# constants; these values are physically correct but dominate any plot of A.
print("Removing rotational constant outliers (based on A)...")
initial_count = len(df_qm9)

# Identify outliers using IQR method for rotational constant A
Q1 = df_qm9['A'].quantile(0.25)
Q3 = df_qm9['A'].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 2.5 * IQR
upper_bound = Q3 + 2.5 * IQR

# Filter out outliers
df_qm9 = df_qm9[(df_qm9['A'] >= lower_bound) & (df_qm9['A'] <= upper_bound)].copy()

print(f"Removed {initial_count - len(df_qm9)} molecules with extreme rotational constants")
print(f"Dataset size: {len(df_qm9)} molecules")
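The cell above applies the IQR rule to `A` alone. If you want the same rule across several columns at once, a small helper could look like the sketch below (the function name is hypothetical). Note that also filtering on `B` and `C` may remove more molecules than the single-column version:

```python
import pandas as pd


def iqr_inlier_mask(df: pd.DataFrame, columns, k: float = 2.5) -> pd.Series:
    """True for rows whose values lie within [Q1 - k*IQR, Q3 + k*IQR] in every column."""
    mask = pd.Series(True, index=df.index)
    for col in columns:
        q1, q3 = df[col].quantile(0.25), df[col].quantile(0.75)
        iqr = q3 - q1
        mask &= df[col].between(q1 - k * iqr, q3 + k * iqr)  # inclusive bounds
    return mask


# Usage on the tutorial data (applies the rule to all three constants):
# df_qm9 = df_qm9[iqr_inlier_mask(df_qm9, ['A', 'B', 'C'])].copy()
```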

In [None]:
# Find molecules with extreme properties
print("Molecules with Extreme Properties:")
print("=" * 60)

# Smallest gaps: candidates for longer-wavelength absorption
# (the Kohn-Sham HOMO-LUMO gap is only a rough proxy for the optical gap)
smallest_gap = df_qm9.nsmallest(5, 'gap_eV')[['tag', 'smiles', 'gap_eV', 'n_atoms']]
print("\nSmallest HOMO-LUMO Gaps (potentially colored):")
print(smallest_gap.to_string(index=False))

# Largest dipole moment
largest_dipole = df_qm9.nlargest(5, 'mu')[['tag', 'smiles', 'mu', 'n_atoms']]
print("\nLargest Dipole Moments (most polar):")
print(largest_dipole.to_string(index=False))

# Highest polarizability
most_polarizable = df_qm9.nlargest(5, 'alpha')[['tag', 'smiles', 'alpha', 'n_atoms']]
print("\nHighest Polarizabilities (largest electron clouds):")
print(most_polarizable.to_string(index=False))

In [None]:
# Visualize outliers on a scatter plot
fig, ax = plt.subplots(figsize=(10, 8))

scatter = ax.scatter(df_qm9['gap_eV'], df_qm9['mu'],
                     c=df_qm9['n_atoms'], cmap='plasma',
                     alpha=0.5, s=15, edgecolors='none')

# Highlight extreme points
extreme_gap = df_qm9.nsmallest(10, 'gap_eV')
extreme_dipole = df_qm9.nlargest(10, 'mu')

ax.scatter(extreme_gap['gap_eV'], extreme_gap['mu'],
           s=100, facecolors='none', edgecolors='red', linewidths=2,
           label='Smallest gaps')
ax.scatter(extreme_dipole['gap_eV'], extreme_dipole['mu'],
           s=100, facecolors='none', edgecolors='blue', linewidths=2,
           label='Largest dipoles')

ax.set_xlabel('HOMO-LUMO Gap (eV)', fontsize=12)
ax.set_ylabel('Dipole Moment (Debye)', fontsize=12)
ax.set_title('Gap vs Dipole Moment with Outliers Highlighted', fontsize=14, fontweight='bold')
ax.legend(loc='upper right')
plt.colorbar(scatter, label='Number of Atoms')
plt.tight_layout()
plt.show()

print("Red circles: Molecules with smallest gaps (potentially colored compounds)")
print("Blue circles: Molecules with largest dipole moments (most polar)")

**Summary of Section 2.1-2.4:**

We've explored the distributions of key molecular properties in QM9:

1. **HOMO and LUMO energies** follow approximately normal distributions, characteristic of small organic molecules
2. **HOMO-LUMO gap** is centered around 6-7 eV (UV absorbers), with few molecules having visible-light gaps
3. **Dipole moment** is right-skewed with many symmetric (non-polar) molecules
4. **Thermodynamic properties** scale with molecular size as expected from statistical mechanics
5. **Outliers** exist and can be chemically interesting (small gaps = colored, large dipoles = polar)

---

### 2.5 Correlation Analysis

One of the most powerful EDA techniques is **correlation analysis**—discovering which properties are related to each other. Understanding these relationships helps us:

1. **Identify redundant features** (highly correlated features may be redundant for ML)
2. **Discover structure-property relationships** (central goal of chemistry!)
3. **Choose prediction targets** (understanding what drives a property)
4. **Validate our data** (some correlations are expected from physics)
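One caveat before computing the matrix: Pearson's r, which `DataFrame.corr()` uses by default, only measures *linear* association. Some QM9 relationships are monotonic but curved; rotational constants, for example, scale inversely with moments of inertia. Spearman's rank correlation (`method='spearman'`) captures any monotonic trend. The toy example below uses a synthetic inverse relationship, not QM9 data, to show the difference:

```python
import numpy as np
import pandas as pd

# Synthetic monotonic but nonlinear relationship, loosely mimicking constant ~ 1/size
size = np.linspace(1, 10, 50)
toy = pd.DataFrame({"size": size, "rot_const": 1.0 / size})

pearson = toy.corr().loc["size", "rot_const"]                       # linear association
spearman = toy.corr(method="spearman").loc["size", "rot_const"]     # monotonic association
print(f"Pearson r = {pearson:.3f}, Spearman rho = {spearman:.3f}")

# On the tutorial data: df_qm9[numeric_cols].corr(method='spearman')
```

Spearman's ρ is exactly −1 here because the relationship is perfectly monotonic, while Pearson's r is noticeably weaker.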

Let's compute and visualize the correlation matrix for all numeric properties:

In [None]:
# Select numeric columns for correlation analysis
numeric_cols = ['A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap',
                'R2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'n_atoms']

# Compute correlation matrix
corr_matrix = df_qm9[numeric_cols].corr()

# Create a large heatmap for better readability
fig, ax = plt.subplots(figsize=(14, 12))

# Create heatmap with annotations
heatmap = sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap='RdBu_r',
                       center=0, vmin=-1, vmax=1, square=True,
                       linewidths=0.5, cbar_kws={'shrink': 0.8},
                       ax=ax, annot_kws={'size': 9})

ax.set_title('Correlation Matrix of QM9 Properties', fontsize=16, fontweight='bold', pad=20)

# Rotate labels for better readability
plt.xticks(rotation=45, ha='right', fontsize=11)
plt.yticks(rotation=0, fontsize=11)

plt.tight_layout()
plt.show()

#### Interpreting the Correlation Heatmap

Let's identify the key patterns in the correlation matrix:

In [None]:
# Find strongest correlations (excluding self-correlations)
def get_top_correlations(corr_matrix, n=15):
    """Extract the top n absolute correlations from a correlation matrix."""
    # Get upper triangle to avoid duplicates
    upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))

    # Stack and sort by absolute value
    correlations = upper.stack().reset_index()
    correlations.columns = ['Property 1', 'Property 2', 'Correlation']
    correlations['Abs_Correlation'] = correlations['Correlation'].abs()
    correlations = correlations.sort_values('Abs_Correlation', ascending=False)

    return correlations.head(n)[['Property 1', 'Property 2', 'Correlation']]

print("Top 15 Strongest Correlations in QM9:")
print("=" * 55)
top_corr = get_top_correlations(corr_matrix)
for _, row in top_corr.iterrows():
    direction = "↑↑" if row['Correlation'] > 0 else "↑↓"
    print(f"  {row['Property 1']:8s} — {row['Property 2']:8s}: r = {row['Correlation']:+.3f} {direction}")

**Chemical Explanation of Key Correlations:**

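1. **Thermodynamic Energies (r ≈ 0.99-1.00)**: U0, U, H, and G share the same total electronic energy and differ only by small thermal corrections (thermal vibrational, rotational, and translational energy, the PV term, and −TS). These corrections are tiny compared with the total energy, so the four quantities are nearly interchangeable as features.

2. **Size-Dependent Properties (r > 0.90)**: zpve, Cv, alpha, and n_atoms all scale with molecular size. More atoms mean more vibrational modes, hence higher zpve and Cv; more electrons mean higher polarizability (alpha).

3. **HOMO-LUMO Relationship**: HOMO and LUMO are only moderately correlated (r well below 1). This is important: they carry largely independent information, and the gap depends on both, not on either one alone.

4. **Rotational Constants**: A, B, and C correlate negatively with size-extensive quantities such as n_atoms, alpha, and Cv, because larger molecules have larger moments of inertia and therefore smaller rotational constants. Since U0, U, H, and G become more negative as molecules grow, their correlations with the rotational constants have the opposite (positive) sign. The dependence is inverse rather than linear, so Pearson's r understates its strength.

5. **Electronic vs Thermodynamic (weak correlation)**: gap, mu, homo, and lumo correlate only weakly with U, H, and G. Electronic structure and total energy carry largely different information, which is one reason separate models are usually trained for each target.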
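The thermodynamic point above can be checked directly on the ethylene glycol entry from Section 1.1. For an ideal gas, H = U + PV = U + k_B T per molecule and G = H − TS, with QM9's thermochemistry evaluated at T = 298.15 K. The sketch below takes the four energies from that example line and compares H − U with k_B T (the Boltzmann constant in Hartree/K is the rounded CODATA value):

```python
# Energies (Hartree) copied from the dsgdb9nsd_000042.xyz example in Section 1.1
U0, U, H, G = -230.183076, -230.177723, -230.176779, -230.211195

K_B_HARTREE_PER_K = 3.166811563e-6  # Boltzmann constant in Hartree/K
T = 298.15                          # QM9 thermochemistry temperature (K)

thermal_energy = U - U0   # thermal energy gained between 0 K and T
pv_term = H - U           # should be close to k_B*T for an ideal gas
minus_TS = G - H          # entropic contribution, -T*S
print(f"U - U0 = {thermal_energy:.6f} Ha")
print(f"H - U  = {pv_term:.6f} Ha  vs  k_B*T = {K_B_HARTREE_PER_K * T:.6f} Ha")
print(f"G - H  = {minus_TS:.6f} Ha  -> S = {-minus_TS / T:.3e} Ha/K")
spread = max(U0, U, H, G) - min(U0, U, H, G)
print(f"Spread of the four energies: {spread:.4f} Ha out of |U0| = {abs(U0):.1f} Ha")
```

The four energies differ by a few hundredths of a Hartree out of about 230, which is why their pairwise correlations are essentially 1.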

### 2.6 Key Scatter Plots: Exploring Relationships

Let's create scatter plots to visualize the most chemically interesting relationships:

#### HOMO vs LUMO: The Frontier Orbital Relationship

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Plot 1: HOMO vs LUMO colored by gap
ax1 = axes[0]
scatter1 = ax1.scatter(df_qm9['homo_eV'], df_qm9['lumo_eV'],
                        c=df_qm9['gap_eV'], cmap='viridis',
                        alpha=0.5, s=10, edgecolors='none')
ax1.set_xlabel('HOMO Energy (eV)', fontsize=12)
ax1.set_ylabel('LUMO Energy (eV)', fontsize=12)
ax1.set_title('HOMO vs LUMO\n(colored by gap)', fontsize=14, fontweight='bold')
plt.colorbar(scatter1, ax=ax1, label='Gap (eV)')

# Add diagonal lines showing constant gap
for gap_val in [4, 6, 8, 10]:
    x = np.linspace(df_qm9['homo_eV'].min(), df_qm9['homo_eV'].max(), 100)
    y = x + gap_val
    ax1.plot(x, y, '--', color='red', alpha=0.5, linewidth=1)
    # Add label at edge
    ax1.text(x[-1], y[-1], f'gap={gap_val}', fontsize=8, color='red', alpha=0.7)

# Plot 2: Gap vs Cv relationship
ax2 = axes[1]
scatter2 = ax2.scatter(df_qm9['gap_eV'], df_qm9['Cv'],
                        c=df_qm9['n_atoms'], cmap='plasma',
                        alpha=0.5, s=10, edgecolors='none')
ax2.set_xlabel('HOMO-LUMO Gap (eV)', fontsize=12)
ax2.set_ylabel('Heat Capacity Cᵥ (cal/mol·K)', fontsize=12)
ax2.set_title('Gap vs Heat Capacity\n(colored by n_atoms)', fontsize=14, fontweight='bold')
plt.colorbar(scatter2, ax=ax2, label='n_atoms')

# Plot 3: Alpha vs n_atoms with gap coloring
ax3 = axes[2]
scatter3 = ax3.scatter(df_qm9['n_atoms'], df_qm9['alpha'],
                        c=df_qm9['gap_eV'], cmap='coolwarm',
                        alpha=0.5, s=10, edgecolors='none')
ax3.set_xlabel('Number of Atoms', fontsize=12)
ax3.set_ylabel('Polarizability α (Bohr³)', fontsize=12)
ax3.set_title('Polarizability vs Size\n(colored by gap)', fontsize=14, fontweight='bold')
plt.colorbar(scatter3, ax=ax3, label='Gap (eV)')

plt.tight_layout()
plt.show()

Let's quantify how strongly HOMO and LUMO move together:

In [None]:
# Calculate and report the HOMO-LUMO correlation
homo_lumo_corr = df_qm9['homo_eV'].corr(df_qm9['lumo_eV'])
print(f"HOMO-LUMO correlation: r = {homo_lumo_corr:.3f}")

**Key observations from HOMO vs LUMO scatter plot:**

1. **Moderate positive correlation**: HOMO and LUMO tend to move together (stabilizing both orbitals generally), but the correlation is not perfect (r ≈ 0.3-0.4). This means they carry independent chemical information.

2. **Diagonal bands**: Molecules organize into diagonal bands of constant gap. The gap is the vertical distance from the HOMO-LUMO line to a horizontal reference.

3. **Spread in gap**: Even at similar HOMO levels, molecules can have very different LUMO levels (and thus gaps). This reflects the diverse electronic structures in QM9.

4. **Chemical implication**: A molecule's HOMO-LUMO gap cannot be predicted from the HOMO alone; you need both frontier orbitals. This is why accurate LUMO prediction also matters for materials design.
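
The constant-gap geometry can be checked with a few lines of arithmetic. This sketch assumes the notebook's `homo_eV`/`lumo_eV` columns were converted from Hartree with the CODATA 2018 factor 1 Ha = 27.211386245988 eV; the orbital energies below are made-up values for two hypothetical molecules, not QM9 entries.

```python
HARTREE_TO_EV = 27.211386245988  # CODATA 2018 Hartree energy in eV (assumed conversion factor)


def gap_ev(homo_ha: float, lumo_ha: float) -> float:
    """HOMO-LUMO gap in eV from orbital energies given in Hartree."""
    return (lumo_ha - homo_ha) * HARTREE_TO_EV


def on_constant_gap_line(homo_ev: float, lumo_ev: float, gap: float, tol: float = 1e-9) -> bool:
    """True if the point (HOMO, LUMO) lies on the dashed line LUMO = HOMO + gap."""
    return abs(lumo_ev - (homo_ev + gap)) < tol


# Two hypothetical molecules with the SAME HOMO but different LUMO energies (Hartree)
homo = -0.25
for lumo in (0.05, -0.05):
    g = gap_ev(homo, lumo)
    print(f"HOMO={homo * HARTREE_TO_EV:.2f} eV, LUMO={lumo * HARTREE_TO_EV:.2f} eV -> gap={g:.2f} eV")
```

Both hypothetical molecules share a HOMO of about -6.8 eV, yet their gaps differ by roughly 2.7 eV: observation 4 in miniature.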

#### Additional Correlation Scatter Plots

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
axes = axes.flatten()

# Define scatter plots: (x, y, color_by, x_label, y_label, title)
scatter_configs = [
    ('gap_eV', 'alpha', 'n_atoms', 'HOMO-LUMO Gap (eV)', 'Polarizability (Bohr³)',
     'Gap vs Polarizability'),
    ('mu', 'gap_eV', 'alpha', 'Dipole Moment (Debye)', 'HOMO-LUMO Gap (eV)',
     'Dipole vs Gap'),
    ('zpve', 'Cv', 'n_atoms', 'Zero-Point Energy (Hartree)', 'Heat Capacity (cal/mol·K)',
     'ZPVE vs Heat Capacity'),
    ('U0', 'H', 'n_atoms', 'Internal Energy U₀ (Hartree)', 'Enthalpy H (Hartree)',
     'U₀ vs H (Near Perfect Correlation)'),
    ('homo_eV', 'mu', 'gap_eV', 'HOMO Energy (eV)', 'Dipole Moment (Debye)',
     'HOMO vs Dipole'),
    ('gap_eV', 'G', 'n_atoms', 'HOMO-LUMO Gap (eV)', 'Free Energy G (Hartree)',
     'Gap vs Free Energy')
]

for ax, (x, y, c, xlabel, ylabel, title) in zip(axes, scatter_configs):
    scatter = ax.scatter(df_qm9[x], df_qm9[y], c=df_qm9[c],
                         cmap='viridis', alpha=0.4, s=8, edgecolors='none')
    ax.set_xlabel(xlabel, fontsize=11)
    ax.set_ylabel(ylabel, fontsize=11)
    ax.set_title(title, fontsize=12, fontweight='bold')

    # Add correlation coefficient
    r = df_qm9[x].corr(df_qm9[y])
    ax.text(0.05, 0.95, f'r = {r:.3f}', transform=ax.transAxes,
            fontsize=11, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()
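
The `r` annotated on each panel comes from pandas `Series.corr`, which computes the Pearson coefficient by default. The plain-Python sketch below (not pandas' implementation) spells out the formula, which helps when reading the panels: r measures *linear* association only, so a constant offset gives exactly 1 while a curved but monotonic relationship gives less.

```python
import math
from typing import Sequence


def pearson_r(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation: covariance of x and y divided by the product of their spreads."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_x = math.fsum(x) / len(x)
    mean_y = math.fsum(y) / len(y)
    dx = [xi - mean_x for xi in x]
    dy = [yi - mean_y for yi in y]
    num = math.fsum(a * b for a, b in zip(dx, dy))
    den = math.sqrt(math.fsum(a * a for a in dx) * math.fsum(b * b for b in dy))
    if den == 0:
        raise ValueError("a constant sequence has no defined correlation")
    return num / den


x = [1.0, 2.0, 3.0, 4.0, 5.0]
print(pearson_r(x, [v + 10.0 for v in x]))  # constant offset: perfectly linear
print(pearson_r(x, [v ** 3 for v in x]))    # monotonic but curved: r below 1
```

If a relationship looks monotonic but curved, `Series.corr(other, method='spearman')` (rank correlation) is the usual alternative; it would give 1.0 for the cubic example.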

### 2.7 Structure-Property Relationships

The core goal of computational chemistry is understanding **structure-property relationships**: how a molecule's structure determines its properties. Let's explore some of these relationships:

In [None]:
# Analyze how electronic properties relate to molecular composition
print("Structure-Property Relationships in QM9:")
print("=" * 60)

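# Group molecules by number of heavy atoms (non-hydrogen).
# Parse element counts from the formula (e.g. 'C7H10O2' -> C:7, H:10, O:2);
# counting characters alone would count 'C7' as a single carbon.
import re

def count_heavy_atoms(formula: str) -> int:
    """Number of C, N, O and F atoms in a molecular formula string."""
    return sum(int(n) if n else 1
               for element, n in re.findall(r'([A-Z][a-z]?)(\d*)', formula)
               if element in ('C', 'N', 'O', 'F'))

df_qm9['heavy_atoms'] = df_qm9['formula'].apply(count_heavy_atoms)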

# Analyze gap distribution by heavy atom count
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Gap distribution by heavy atoms
ax1 = axes[0]
heavy_atom_counts = df_qm9['heavy_atoms'].value_counts().sort_index()
common_heavy_atoms = heavy_atom_counts[heavy_atom_counts > 100].index

df_filtered = df_qm9[df_qm9['heavy_atoms'].isin(common_heavy_atoms)]
sns.violinplot(data=df_filtered, x='heavy_atoms', y='gap_eV', ax=ax1,
               palette='Set2', inner='box')
ax1.set_xlabel('Number of Heavy Atoms (C, N, O, F)', fontsize=12)
ax1.set_ylabel('HOMO-LUMO Gap (eV)', fontsize=12)
ax1.set_title('Band Gap Distribution by Molecular Size', fontsize=14, fontweight='bold')

# Right: Correlation between size and gap
ax2 = axes[1]
size_gap_stats = df_filtered.groupby('heavy_atoms')['gap_eV'].agg(['mean', 'std']).reset_index()
ax2.errorbar(size_gap_stats['heavy_atoms'], size_gap_stats['mean'],
             yerr=size_gap_stats['std'], fmt='o-', capsize=5, capthick=2,
             color='steelblue', markersize=10, linewidth=2)
ax2.set_xlabel('Number of Heavy Atoms', fontsize=12)
ax2.set_ylabel('Mean Gap (eV) ± Std Dev', fontsize=12)
ax2.set_title('Average Band Gap vs Molecular Size', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Report the trend
corr_heavy_gap = df_qm9['heavy_atoms'].corr(df_qm9['gap_eV'])
print(f"\nCorrelation (heavy atoms vs gap): r = {corr_heavy_gap:.3f}")

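**Observation:** On average, molecules with more heavy atoms have somewhat smaller band gaps, but the trend is modest (the summary table in Section 2.9 lists the correlation as roughly -0.20) and the spread within each size class is large. The textbook explanation is conjugation: delocalizing π electrons over more atoms lowers the gap. In QM9, however, a larger molecule is not necessarily more conjugated (many are fully saturated), so size is only a rough proxy. In organic electronics, extended π-conjugated systems are used deliberately to build small-gap semiconductors for solar cells and LEDs.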
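
The conjugation argument can be made concrete with the simple Hückel model of a linear polyene, where orbital energies are E_k = α + 2β cos(kπ/(N+1)) for k = 1..N and β < 0. This is a textbook π-electron model, not the DFT used for QM9, and it is only a qualitative illustration: QM9 molecules are small and rarely long conjugated chains. Gaps are reported in units of |β|.

```python
import math


def huckel_polyene_gap(n_carbons: int) -> float:
    """HOMO-LUMO gap of a linear conjugated chain in simple Hückel theory, in units of |beta|."""
    if n_carbons < 2 or n_carbons % 2:
        raise ValueError("use an even number of carbons >= 2 (closed-shell polyene)")
    # (E_k - alpha) / |beta| = -2*cos(k*pi/(N+1)); sorted from most bonding to most antibonding
    levels = sorted(-2.0 * math.cos(k * math.pi / (n_carbons + 1))
                    for k in range(1, n_carbons + 1))
    homo = levels[n_carbons // 2 - 1]  # N pi electrons fill the N/2 lowest orbitals, 2 per orbital
    lumo = levels[n_carbons // 2]
    return lumo - homo


for n in (2, 4, 6, 8):  # ethylene, butadiene, hexatriene, octatetraene
    print(f"C{n} chain: gap = {huckel_polyene_gap(n):.3f} |beta|")
```

The gap shrinks steadily as the chain grows (ethylene gives the familiar 2|β|, butadiene about 1.24|β|), which is the "extended conjugation lowers the gap" principle in its simplest form.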

### 2.8 Pair Plot for Comprehensive View

For a comprehensive view of relationships between key properties, let's create a pair plot:

In [None]:
# Select a subset of important properties for the pair plot
pair_cols = ['homo_eV', 'lumo_eV', 'gap_eV', 'mu', 'alpha', 'Cv', 'n_atoms']

# Create pair plot (sample for performance)
sample_size = min(2000, len(df_qm9))
df_sample = df_qm9[pair_cols].sample(n=sample_size, random_state=42)

# Create the pair plot
g = sns.pairplot(df_sample, diag_kind='kde', plot_kws={'alpha': 0.3, 's': 10},
                  corner=True, height=2)
g.fig.suptitle('Pair Plot of Key QM9 Properties', fontsize=16, fontweight='bold', y=1.02)

plt.tight_layout()
plt.show()

print("Pair plot shows pairwise relationships between key properties.")
print("Diagonal: Distribution of each property (KDE)")
print("Off-diagonal: Scatter plots between property pairs")

### 2.9 Summary: Key Correlations and Chemical Insights

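Let's summarize the most important correlations and their chemical significance. The strengths listed are approximate, typical values; compare them with the coefficients computed earlier in this section: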

In [None]:
# Create a summary table of key correlations
print("=" * 70)
print("SUMMARY: Key Correlations in QM9 and Their Chemical Significance")
print("=" * 70)

summary_data = [
    ("U0 ↔ H ↔ G", "~1.00", "Thermodynamic properties are dominated by total electronic energy"),
    ("zpve ↔ Cv", ">0.95", "Both scale with vibrational modes (3N-6 for nonlinear molecules)"),
    ("alpha ↔ n_atoms", ">0.90", "More electrons → larger polarizable electron cloud"),
    ("HOMO ↔ LUMO", "~0.35", "Moderate correlation: both affected by electron density but independently useful"),
    ("gap ↔ n_atoms", "~-0.20", "Larger molecules have smaller gaps (extended conjugation)"),
    ("A, B, C ↔ n_atoms", "<-0.50", "Rotational constants fall as the moment of inertia grows with size"),
    ("mu ↔ gap", "~0.05", "Nearly independent: polarity ≠ reactivity/color"),
    ("G ↔ gap", "~0.10", "Thermodynamic stability ≠ electronic gap (different phenomena)")
]

print(f"\n{'Correlation':<20} {'Strength':<10} {'Chemical Interpretation':<45}")
print("-" * 75)
for corr_pair, strength, interpretation in summary_data:
    print(f"{corr_pair:<20} {strength:<10} {interpretation}")

**Key Takeaways for Machine Learning:**

1. **Avoid redundancy**: Don't include all of U0, U, H, G as features—they're nearly identical. Pick one (usually G for chemistry applications).

2. **Size matters**: Many properties correlate with molecular size. Consider normalizing by n_atoms or using size-independent features.

3. **HOMO & LUMO are complementary**: Despite moderate correlation, both carry unique information. Use both when predicting gap or other electronic properties.

4. **Electronic vs Thermodynamic**: These property groups are largely independent. A single model may not predict both well—consider separate models.

5. **Gap is special**: The HOMO-LUMO gap has weak correlation with most other properties, making it both challenging and valuable to predict.
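
Takeaway 1 can be automated with a greedy correlation filter. The sketch below assumes plain Python lists as feature columns and an illustrative |r| cutoff of 0.95 (a common but arbitrary choice); the numbers are toy values, not QM9 data. With pandas, `df[cols].corr().abs()` gives the same pairwise values in one call.

```python
import math
from typing import Dict, List, Sequence


def pearson_r(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation of two equal-length sequences (assumes neither is constant)."""
    mx, my = math.fsum(x) / len(x), math.fsum(y) / len(y)
    dx = [v - mx for v in x]
    dy = [v - my for v in y]
    den = math.sqrt(math.fsum(a * a for a in dx) * math.fsum(b * b for b in dy))
    return math.fsum(a * b for a, b in zip(dx, dy)) / den


def drop_redundant_features(columns: Dict[str, Sequence[float]],
                            threshold: float = 0.95) -> List[str]:
    """Greedily keep each feature whose |r| with every already-kept feature is below threshold.

    Columns are visited in dict insertion order, so list the feature you prefer to keep first.
    """
    kept: List[str] = []
    for name, values in columns.items():
        if all(abs(pearson_r(values, columns[k])) < threshold for k in kept):
            kept.append(name)
    return kept


# Toy numbers: 'U0' is 'G' shifted by a constant, 'mu' varies differently
features = {
    "G": [1.0, 2.0, 3.0, 4.0, 5.0],
    "U0": [1.5, 2.5, 3.5, 4.5, 5.5],
    "mu": [2.0, 1.0, 4.0, 3.0, 5.0],
}
print(drop_redundant_features(features))
```

Because the filter is greedy, the result depends on column order; putting the feature you actually want (here `G`) first makes the choice explicit.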

---

## Section 3: Molecular Featurization

In this section, we'll convert our molecules into numerical representations suitable for machine learning. This process is called **featurization** or **molecular representation**.

Machine learning models can't directly process molecular structures—they need numbers! There are many ways to represent molecules numerically:

1. **String-based representations**: SMILES
2. **Fingerprints**: Morgan (ECFP), MACCS keys, atom pairs
3. **Graph-based**: Adjacency matrices, learned embeddings
4. **3D descriptors**: Distance matrices, symmetry functions (advanced)

We'll start with the widely used **SMILES notation**, parse it with RDKit, and then turn the parsed molecules into fingerprints.

---

### 3.1 Understanding SMILES Notation

**SMILES** (Simplified Molecular Input Line Entry System) is a line notation for representing molecular structures as text strings. It's one of the most common ways to store and exchange molecular structures.

#### SMILES Syntax Basics:

| Element | SMILES Notation | Example |
|---------|-----------------|---------|
| Atoms | Element symbols; organic-subset atoms (B, C, N, O, P, S, F, Cl, Br, I) can be written bare, all others need brackets | `C` = carbon, `[Na+]` = sodium ion |
| Bonds | Single (default), `=` double, `#` triple, `:` aromatic | `C=C` = ethene |
| Rings | A matching pair of digits marks the two atoms that close a ring | `C1CCCCC1` = cyclohexane |
| Branches | Parentheses | `CC(C)C` = isobutane |
| Aromaticity | Lowercase letters | `c1ccccc1` = benzene |
| Charges | `+`, `-` in brackets | `[NH4+]` = ammonium |
| Stereochemistry | `@`, `@@`, `/`, `\` | `C/C=C/C` = trans-2-butene |
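
The syntax in this table is regular enough that a short regular expression can split a SMILES string into tokens, which is how SMILES is typically prepared for sequence models. This is a simplified sketch using our own pattern: it covers only the notation in the table and does not check valence, ring-closure pairing, or aromaticity (RDKit does that when parsing).

```python
import re
from typing import List

# Simplified SMILES token pattern (an assumption, not a full grammar): bracket atoms,
# two-letter organic-subset halogens (matched before B and C), one-letter organic-subset
# atoms, aromatic atoms, ring-closure numbers, then bond/branch/disconnection symbols.
SMILES_TOKEN = re.compile(r"(\[[^\]]+\]|Br|Cl|[BCNOPSFI]|[bcnops]|%\d{2}|\d|[=#\-:/\\.()])")


def tokenize_smiles(smiles: str) -> List[str]:
    """Split a SMILES string into atom, bond, branch and ring-closure tokens."""
    tokens = SMILES_TOKEN.findall(smiles)
    if "".join(tokens) != smiles:  # findall silently skips characters it cannot match
        raise ValueError(f"unrecognised characters in SMILES: {smiles!r}")
    return tokens


for s in ["CC(=O)O", "c1ccccc1", "[NH4+]", "C/C=C/C", "ClCCBr"]:
    print(f"{s:10s} -> {tokenize_smiles(s)}")
```

Note how `[NH4+]` stays a single token: everything inside brackets (hydrogens, charge, isotopes, chirality) belongs to one atom.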

**Examples of SMILES:**

```
Water:           O
Ethanol:         CCO
Acetic acid:     CC(=O)O
Benzene:         c1ccccc1
Aspirin:         CC(=O)Oc1ccccc1C(=O)O
Caffeine:        Cn1cnc2c1c(=O)n(c(=O)n2C)C
```

**Why SMILES is Useful:**
- Compact representation (vs. XYZ coordinates)
- Human-readable (with practice)
- Easy to store in databases
- Can be canonicalized so that each molecule gets a single string (within a given toolkit)

**Limitations:**
- Doesn't capture 3D geometry
- Multiple valid SMILES for the same molecule
- Can generate invalid molecules (when used for generation)

---

### 3.2 Parsing SMILES with RDKit

**RDKit** is one of the most widely used open-source cheminformatics toolkits. Let's use it to parse SMILES strings into molecule (`Mol`) objects.

In [None]:
# Import RDKit modules for molecular processing
from typing import List, Optional, Tuple

from rdkit import Chem, rdBase
from rdkit.Chem import Draw, AllChem, Descriptors
from rdkit.Chem import rdMolDescriptors

print("RDKit imported successfully!")
print(f"RDKit version: {rdBase.rdkitVersion}")

In [None]:
def parse_smiles_to_mol(smiles: str) -> Optional[Chem.Mol]:
    """
    Parse a SMILES string into an RDKit Mol object.

    Parameters:
    -----------
    smiles : str
        SMILES string representing the molecule

    Returns:
    --------
    Chem.Mol or None
        RDKit Mol object if parsing succeeds, None otherwise
    """
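    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception as e:
        # MolFromSmiles returns None for bad SMILES; this catches e.g. non-string input
        print(f"Error parsing SMILES '{smiles}': {e}")
        return None
    # An empty string parses to a Mol with zero atoms, so reject that case too
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    return mol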

#### Testing the Parser

Let's test our parser on some example SMILES:

In [None]:
# Test with simple molecules
test_smiles = [
    "C",           # Methane
    "CC",          # Ethane
    "C=C",         # Ethene
    "C#C",         # Ethyne
    "c1ccccc1",    # Benzene
    "CCO",         # Ethanol
    "CC(=O)O",     # Acetic acid
    "invalid!!!",  # Invalid SMILES (will fail)
    "",            # Empty string (will fail)
]

print("Testing SMILES parsing:")
print("-" * 50)

for smiles in test_smiles:
    mol = parse_smiles_to_mol(smiles)
    if mol is not None:
        n_atoms = mol.GetNumAtoms()
        n_bonds = mol.GetNumBonds()
        formula = rdMolDescriptors.CalcMolFormula(mol)
        print(f"  '{smiles:15s}' → Valid | Atoms: {n_atoms:2d} | Bonds: {n_bonds:2d} | Formula: {formula}")
    else:
        print(f"  '{smiles:15s}' → INVALID (parsing failed)")

**Key Points:**
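- `Chem.MolFromSmiles()` returns `None` for invalid SMILES instead of raising an exception (RDKit logs the error message)
- An empty string is not rejected by RDKit: it parses to a Mol with zero atoms, so check `mol.GetNumAtoms()` if empty input is possible
- Always check the result before using it; valid molecules can then be queried for atoms, bonds, and other properties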

---

### 3.3 Parsing QM9 SMILES with Error Handling

Now let's parse the SMILES from our QM9 dataset. Some SMILES may fail to parse, so we need robust error handling:

In [None]:
def parse_qm9_smiles(df: pd.DataFrame) -> Tuple[List[Chem.Mol], pd.DataFrame]:
    """
    Parse SMILES from QM9 DataFrame and create RDKit Mol objects.

    Parameters:
    -----------
    df : pd.DataFrame
        DataFrame with 'smiles' column

    Returns:
    --------
    Tuple[List[Chem.Mol], pd.DataFrame]
        List of valid Mol objects and filtered DataFrame
    """
    valid_mols = []
    valid_indices = []
    failed_count = 0

    print("Parsing SMILES to RDKit Mol objects...")

    for idx, smiles in tqdm(df['smiles'].items(), total=len(df), desc="Parsing SMILES"):
        mol = Chem.MolFromSmiles(smiles)

        if mol is not None:
            valid_mols.append(mol)
            valid_indices.append(idx)
        else:
            failed_count += 1

    # Create filtered DataFrame
    df_valid = df.loc[valid_indices].copy()

    print(f"\nParsing complete!")
    print(f"  Successfully parsed: {len(valid_mols):,} molecules")
    print(f"  Failed to parse:     {failed_count:,} molecules")
    print(f"  Success rate:        {len(valid_mols)/len(df)*100:.2f}%")

    return valid_mols, df_valid


# Parse all QM9 SMILES
mols, df_valid = parse_qm9_smiles(df_qm9)

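---

### 3.4 Canonical SMILES

The same molecule can be written as many different SMILES strings. Converting each parsed `Mol` back to text with `Chem.MolToSmiles()` yields RDKit's canonical SMILES, one consistent string per structure (canonical forms are toolkit-specific and can change between RDKit versions). The next cell also stores the `Mol` objects in the DataFrame:

In [None]: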
# Verify parsing results
print(f"\nOriginal DataFrame shape:  {df_qm9.shape}")
print(f"Valid molecules DataFrame: {df_valid.shape}")
print(f"Number of Mol objects:     {len(mols)}")

# Store the Mol objects in the DataFrame for easy access
df_valid = df_valid.reset_index(drop=True)
df_valid['mol'] = mols

# Generate canonical SMILES for consistency
df_valid['canonical_smiles'] = [Chem.MolToSmiles(mol) for mol in mols]

print(f"\nDataFrame now includes 'mol' column with RDKit Mol objects")
print(f"Canonical SMILES have been generated for all molecules")

---

### 3.5 Visualizing Molecules

RDKit can render 2D molecular structures. Let's visualize some example molecules:

In [None]:
# Select a diverse set of molecules for visualization
# Choose molecules with different sizes and properties
sample_df = pd.concat([
    df_valid.nsmallest(3, 'gap_eV'),
    df_valid.nlargest(3, 'gap_eV'),
    df_valid.nlargest(3, 'mu')
]).drop_duplicates(subset='tag')  # a molecule can appear in more than one group

# Get Mol objects for visualization
sample_mols = sample_df['mol'].tolist()
sample_labels = [f"Tag {row['tag']}\ngap={row['gap_eV']:.2f} eV"
                 for _, row in sample_df.iterrows()]

# Draw molecules in a grid
print("Example molecules from QM9:")
print("(In order: smallest gaps, largest gaps, largest dipoles; rows shift if a molecule falls in two groups)")

img = Draw.MolsToGridImage(
    sample_mols[:9],  # Limit to 9 for 3x3 grid
    molsPerRow=3,
    subImgSize=(300, 250),
    legends=sample_labels[:9]
)
# Display the image in Jupyter; outside IPython, save it to a file instead
try:
    from IPython.display import display
    display(img)
except ImportError:
    img.save('qm9_example_molecules.png')

In [None]:
# Also show some simple molecules with their SMILES
simple_molecules = df_valid[df_valid['n_atoms'] <= 8].sample(6, random_state=123)

print("\nSimple molecules with their SMILES:")
print("-" * 60)

for _, row in simple_molecules.iterrows():
    print(f"  Tag {row['tag']:>6}: {row['canonical_smiles']:<20} ({row['n_atoms']} atoms)")

# Visualize these
simple_mols = simple_molecules['mol'].tolist()
simple_labels = [row['canonical_smiles'] for _, row in simple_molecules.iterrows()]

img_simple = Draw.MolsToGridImage(
    simple_mols,
    molsPerRow=3,
    subImgSize=(250, 200),
    legends=simple_labels
)
# Display the image in Jupyter; outside IPython, save it to a file instead
try:
    from IPython.display import display
    display(img_simple)
except ImportError:
    img_simple.save('qm9_simple_molecules.png')

---

### 3.6 Molecular Properties from RDKit

Now that we have Mol objects, we can compute additional molecular descriptors:

In [None]:
# Calculate additional molecular descriptors using RDKit
print("Calculating molecular descriptors...")

# Add some useful descriptors
df_valid['mol_weight'] = [Descriptors.MolWt(mol) for mol in tqdm(mols, desc="Molecular weight")]
df_valid['num_heavy_atoms'] = [Descriptors.HeavyAtomCount(mol) for mol in mols]
df_valid['num_rings'] = [Descriptors.RingCount(mol) for mol in mols]
df_valid['num_aromatic_rings'] = [rdMolDescriptors.CalcNumAromaticRings(mol) for mol in mols]
df_valid['num_rotatable_bonds'] = [Descriptors.NumRotatableBonds(mol) for mol in mols]
df_valid['num_hbd'] = [Descriptors.NumHDonors(mol) for mol in mols]  # H-bond donors
df_valid['num_hba'] = [Descriptors.NumHAcceptors(mol) for mol in mols]  # H-bond acceptors

print("\nNew molecular descriptors added to DataFrame:")
new_cols = ['mol_weight', 'num_heavy_atoms', 'num_rings', 'num_aromatic_rings',
            'num_rotatable_bonds', 'num_hbd', 'num_hba']
print(df_valid[new_cols].describe())

In [None]:
# Analyze relationship between RDKit descriptors and QM9 properties
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Rings vs Gap
ax1 = axes[0]
ring_gap_means = df_valid.groupby('num_rings')['gap_eV'].mean()
ax1.bar(ring_gap_means.index, ring_gap_means.values, color='steelblue', edgecolor='white')
ax1.set_xlabel('Number of Rings')
ax1.set_ylabel('Mean HOMO-LUMO Gap (eV)')
ax1.set_title('Band Gap by Ring Count')

# Aromatic rings vs Gap
ax2 = axes[1]
arom_gap_means = df_valid.groupby('num_aromatic_rings')['gap_eV'].mean()
ax2.bar(arom_gap_means.index, arom_gap_means.values, color='coral', edgecolor='white')
ax2.set_xlabel('Number of Aromatic Rings')
ax2.set_ylabel('Mean HOMO-LUMO Gap (eV)')
ax2.set_title('Band Gap by Aromatic Ring Count')

# Molecular weight vs Polarizability
ax3 = axes[2]
ax3.scatter(df_valid['mol_weight'], df_valid['alpha'], alpha=0.3, s=10)
ax3.set_xlabel('Molecular Weight (g/mol)')
ax3.set_ylabel('Polarizability (Bohr³)')
ax3.set_title('Weight vs Polarizability')
r = df_valid['mol_weight'].corr(df_valid['alpha'])
ax3.text(0.05, 0.95, f'r = {r:.3f}', transform=ax3.transAxes,
         verticalalignment='top', fontsize=12,
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()

**Chemical Insights:**
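- In these bar charts, molecules with aromatic rings tend to have smaller band gaps, consistent with their delocalized π-systems
- Molecular weight correlates strongly with polarizability (heavier molecules have more electrons to polarize)
- Ring count on its own mixes saturated and aromatic rings; saturated rings do not extend a π-system, so compare the left and middle panels to see how much of the trend comes from aromaticity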

---

### 3.7 Summary: SMILES and Molecular Objects

We've accomplished the first steps of molecular featurization:

In [None]:
print("Section 3.1-3.7 Summary: SMILES Parsing and Canonicalization")
print("=" * 60)
print(f"""
✓ SMILES Notation Explained:
  - Line notation for molecular structures
  - Supports atoms, bonds, rings, branches, stereochemistry
  - Multiple valid SMILES per molecule

✓ RDKit Parsing:
  - Parsed {len(mols):,} molecules successfully
  - Robust error handling for invalid SMILES
  - Mol objects enable further cheminformatics operations

✓ Canonicalization:
  - Generated unique canonical SMILES for all molecules
  - Essential for database lookups and comparisons

✓ Molecular Descriptors:
  - Added 7 new descriptors from RDKit
  - Molecular weight, rings, H-bond donors/acceptors, etc.

Next: Generate molecular fingerprints for machine learning!
""")

# Show the updated DataFrame columns
print("Updated DataFrame columns:")
print(df_valid.columns.tolist())

---

### 3.8 Molecular Fingerprints: An Introduction

**Molecular fingerprints** are fixed-length binary (or count) vectors that encode structural features of molecules. They're the workhorses of cheminformatics machine learning because they:

1. **Have fixed length**: Molecules vary in their numbers of atoms and bonds, but a fingerprint always has the same size, which is the fixed-width input most ML models expect
2. **Encode structure**: Capture connectivity, functional groups, and local environments
3. **Are fast to compute**: Much faster than 3D descriptors or quantum calculations
4. **Enable similarity searching**: Easy to compare molecules using distance metrics
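
The standard similarity measure for bit fingerprints is the Tanimoto coefficient (for binary vectors it equals the Jaccard index): the number of shared "on" bits divided by the number of bits on in either fingerprint. Below is a plain-Python sketch over sets of on-bit indices; with RDKit bit vectors you would call `DataStructs.TanimotoSimilarity(fp1, fp2)` instead. The bit sets are made up for illustration.

```python
from typing import Set


def tanimoto(on_bits_a: Set[int], on_bits_b: Set[int]) -> float:
    """Tanimoto similarity |A & B| / |A | B| between two sets of 'on' bit indices."""
    union = len(on_bits_a | on_bits_b)
    if union == 0:
        return 1.0  # convention chosen here: two empty fingerprints count as identical
    return len(on_bits_a & on_bits_b) / union


# Hypothetical on-bit sets for three molecules
a = {12, 87, 301, 1024}
b = {12, 87, 301, 1500}
c = {5, 999}
print(tanimoto(a, b), tanimoto(a, c))
```

The corresponding distance, 1 - Tanimoto, is what clustering or nearest-neighbour methods typically use.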

#### Types of Molecular Fingerprints:

| Fingerprint Type | Description | Typical Size | Use Cases |
|------------------|-------------|--------------|-----------|
| **Morgan (ECFP)** | Circular fingerprints encoding atom environments | 1024-2048 bits | Similarity searching, ML, virtual screening |
| **MACCS Keys** | Predefined structural keys (functional groups) | 166 bits | Substructure screening, interpretability |
| **Atom Pairs** | All pairs of atoms together with the shortest bond-path (topological) distance between them | 2048 bits | Activity cliffs, scaffold hopping |
| **Topological** | Path-based fingerprints | 2048 bits | General similarity |
| **RDKit FP** | Daylight-like fingerprints | 2048 bits | General purpose |

We'll focus on the two most commonly used: **Morgan** and **MACCS**.

---

### 3.9 Morgan Fingerprints (ECFP)

**Morgan fingerprints** (also known as ECFP - Extended Connectivity Fingerprints) are **circular fingerprints** that encode the local chemical environment around each atom.

#### How Morgan Fingerprints Work:

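1. **Initialize**: Assign each atom an integer identifier computed from local atom invariants (element, number of heavy neighbors, attached hydrogens, charge, ring membership, etc.)
2. **Iterate**: At each iteration up to the chosen radius, update every atom's identifier by hashing it together with its neighbors' identifiers and the connecting bond types
3. **Fold**: Map every identifier collected at every radius onto a fixed-size bit vector (essentially the identifier modulo the vector length), so distinct environments can occasionally collide on the same bit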

```
Radius 0: Just the atom itself
Radius 1: Atom + its immediate neighbors
Radius 2: Atom + neighbors + neighbors' neighbors
...and so on
```

**Example: a terminal carbon in propane (C-C-C)**

```
Radius 0: Just "C" (sp3 carbon carrying 3 H)
Radius 1: "C" bonded to one "C"
Radius 2: "C" bonded to "C" bonded to "C" (the whole chain)
```

**Key Parameters:**
- **radius**: Controls the "reach" of each atom's environment (typically 2-3)
- **nBits**: Size of the bit vector (typically 1024-2048)
- **ECFP4**: Morgan with radius=2 (diameter=4), the most common choice
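
The three steps can be sketched in plain Python on a heavy-atom graph. This is a deliberately simplified toy, not RDKit's algorithm: real ECFP/Morgan uses richer invariants (bond orders, charges, hydrogen counts, ring membership), its own hash function, and removal of duplicate substructures. `toy_circular_fingerprint` and `stable_hash` are hypothetical helpers for illustration only.

```python
import hashlib
from typing import Dict, List, Set


def stable_hash(obj) -> int:
    """Deterministic 32-bit integer hash (Python's built-in hash() is salted for strings)."""
    return int.from_bytes(hashlib.sha1(repr(obj).encode()).digest()[:4], "big")


def toy_circular_fingerprint(elements: List[str], neighbors: Dict[int, List[int]],
                             radius: int = 2, n_bits: int = 2048) -> Set[int]:
    """Simplified Morgan/ECFP-style fingerprint; returns the set of 'on' bit indices.

    elements[i] is the element symbol of atom i; neighbors[i] lists the atoms bonded to i.
    """
    n_atoms = len(elements)
    # Radius 0: identifier from simple atom invariants (element, number of heavy neighbours)
    ids = [stable_hash((elements[i], len(neighbors[i]))) for i in range(n_atoms)]
    seen = set(ids)
    for _ in range(radius):
        # Combine each atom's previous identifier with its neighbours' identifiers;
        # sorting the neighbour identifiers makes the result independent of atom numbering
        ids = [stable_hash((ids[i], tuple(sorted(ids[j] for j in neighbors[i]))))
               for i in range(n_atoms)]
        seen.update(ids)
    # Fold: map every identifier from every radius onto a fixed-length bit vector
    return {identifier % n_bits for identifier in seen}


# Propane as a heavy-atom graph C0-C1-C2 (hydrogens implicit)
propane = toy_circular_fingerprint(["C", "C", "C"], {0: [1], 1: [0, 2], 2: [1]})
# Same molecule with different atom numbering: the central carbon is now atom 2
propane_renumbered = toy_circular_fingerprint(["C", "C", "C"], {0: [2], 1: [2], 2: [0, 1]})
# Ethanol C0-C1-O2
ethanol = toy_circular_fingerprint(["C", "C", "O"], {0: [1], 1: [0, 2], 2: [1]})
print(sorted(propane), propane == propane_renumbered, propane == ethanol)
```

The renumbering check shows why the neighbour identifiers are sorted: a fingerprint should depend on the molecule, not on how its atoms happen to be numbered.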

In [None]:
from rdkit.Chem import AllChem
import numpy as np

def compute_morgan_fingerprint(mol: Chem.Mol, radius: int = 2, n_bits: int = 2048) -> np.ndarray:
    """
    Compute Morgan (ECFP) fingerprint for a molecule.

    Parameters:
    -----------
    mol : Chem.Mol
        RDKit Mol object
    radius : int
        Radius of the circular fingerprint (default: 2 → ECFP4)
    n_bits : int
        Size of the bit vector (default: 2048)

    Returns:
    --------
    np.ndarray
        Binary fingerprint as numpy array of shape (n_bits,)
    """
    # Generate Morgan fingerprint as bit vector
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits)

    # Convert to numpy array
    arr = np.zeros((n_bits,), dtype=np.int8)
    for bit in fp.GetOnBits():
        arr[bit] = 1

    return arr

#### Testing Morgan Fingerprints

In [None]:
# Test on a few molecules
print("Testing Morgan Fingerprint Generation:")
print("-" * 60)

test_indices = [0, 1, 2, 100, 500]
for idx in test_indices:
    mol = mols[idx]
    smiles = df_valid.iloc[idx]['canonical_smiles']
    fp = compute_morgan_fingerprint(mol)

    # Count bits that are "on"
    on_bits = fp.sum()
    density = on_bits / len(fp) * 100

    print(f"  {smiles:<30} → {on_bits:3d} bits ON ({density:.1f}% density)")

#### Understanding Fingerprint Bits

Each "on" bit in a Morgan fingerprint records one or more atom environments; because of folding, two different environments can set the same bit. Let's see which atoms and radii activated each bit:

In [None]:
# Get information about which structural features each bit represents
def get_morgan_bit_info(mol: Chem.Mol, radius: int = 2, n_bits: int = 2048) -> dict:
    """
    Get information about which structural features activate each bit.

    Returns a dict mapping bit index → (atom index, radius) pairs.
    """
    bit_info = {}
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits, bitInfo=bit_info)
    return bit_info


# Analyze bit patterns for an example molecule
example_idx = 100
example_mol = mols[example_idx]
example_smiles = df_valid.iloc[example_idx]['canonical_smiles']

print(f"Bit Analysis for: {example_smiles}")
print("-" * 60)

bit_info = get_morgan_bit_info(example_mol)
print(f"Total unique bits activated: {len(bit_info)}")
print(f"\nFirst 10 bit indices and their origins:")
print(f"{'Bit Index':<12} {'Atom Centers':<30} {'Radii':<15}")
print("-" * 57)

for i, (bit_idx, occurrences) in enumerate(list(bit_info.items())[:10]):
    atom_centers = [occ[0] for occ in occurrences]
    radii = [occ[1] for occ in occurrences]
    print(f"{bit_idx:<12} {str(atom_centers):<30} {str(radii):<15}")
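
How often do folding collisions happen? Under the simplifying assumption that each distinct identifier lands on a uniformly random bit, the expected number of distinct bits set by k identifiers in an n-bit vector is n(1 - (1 - 1/n)^k). This sketch evaluates that formula for an illustrative count of k = 40 identifiers (a made-up number, not measured on QM9):

```python
def expected_on_bits(n_identifiers: int, n_bits: int = 2048) -> float:
    """Expected number of distinct bits set when n_identifiers distinct identifiers
    are folded into n_bits positions, assuming each lands on a uniformly random bit."""
    return n_bits * (1.0 - (1.0 - 1.0 / n_bits) ** n_identifiers)


for n_bits in (256, 1024, 2048):
    on = expected_on_bits(40, n_bits)
    print(f"n_bits={n_bits:5d}: ~{on:.2f} bits ON, ~{40 - on:.2f} identifiers lost to collisions")
```

Shorter vectors lose more information to collisions, which is one reason 1024-2048 bits is the usual range for Morgan fingerprints.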

#### Visualizing Fingerprint Patterns

In [None]:
# Visualize the fingerprint as a heatmap
fig, axes = plt.subplots(2, 1, figsize=(16, 4))

# Compare fingerprints of molecules with different sizes
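small_mol = mols[0]  # first molecule in the dataset (QM9 starts with its smallest molecules)
large_mol = mols[len(mols)//2]  # a molecule from the middle of the dataset (typically larger)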

fp_small = compute_morgan_fingerprint(small_mol)
fp_large = compute_morgan_fingerprint(large_mol)

# Reshape the 2048 bits into 32 rows x 64 columns for display (assumes n_bits=2048)
ax1 = axes[0]
ax1.imshow(fp_small.reshape(32, 64), cmap='Blues', aspect='auto')
ax1.set_title(f'Morgan FP: {df_valid.iloc[0]["canonical_smiles"]} ({fp_small.sum()} bits ON)')
ax1.set_ylabel('Bit rows')

ax2 = axes[1]
ax2.imshow(fp_large.reshape(32, 64), cmap='Oranges', aspect='auto')
ax2.set_title(f'Morgan FP: {df_valid.iloc[len(mols)//2]["canonical_smiles"]} ({fp_large.sum()} bits ON)')
ax2.set_ylabel('Bit rows')
ax2.set_xlabel('Bit columns')

plt.tight_layout()
plt.show()

print("Note: Larger/more complex molecules typically have more bits 'ON'")

---

### 3.10 MACCS Fingerprints

**MACCS Keys** (Molecular ACCess System) are a set of 166 predefined structural keys that check for the presence of specific functional groups and patterns.

Unlike Morgan fingerprints (whose bits come from hashing whichever atom environments occur in the molecule), MACCS keys are **interpretable** because each bit answers a fixed, predefined structural question.

#### Examples of MACCS Keys:

| Key # | Pattern (SMARTS) | Description |
|-------|------------------|-------------|
| 139 | `[O;!H0]` | Hydroxyl (oxygen bearing at least one H) |
| 154 | `[#6]=[#8]` | Carbonyl (C=O) |
| 161 | `[#7]` | Contains nitrogen |
| 163 | `*1~*~*~*~*~*~1` | Six-membered ring |
| 166 | (computed in code, not SMARTS) | More than one fragment |

#### MACCS Fingerprint Generation

In [None]:
from rdkit.Chem import MACCSkeys

def compute_maccs_fingerprint(mol: Chem.Mol) -> np.ndarray:
    """
    Compute MACCS fingerprint for a molecule.

    Parameters:
    -----------
    mol : Chem.Mol
        RDKit Mol object

    Returns:
    --------
    np.ndarray
        Binary fingerprint as numpy array of shape (167,)
        Note: MACCS has 167 bits but key 0 is unused, giving 166 meaningful keys
    """
    fp = MACCSkeys.GenMACCSKeys(mol)

    # Convert to numpy array
    arr = np.zeros((167,), dtype=np.int8)
    for bit in fp.GetOnBits():
        arr[bit] = 1

    return arr

#### Testing MACCS Fingerprints

In [None]:
# Test MACCS on the same molecules
print("Testing MACCS Fingerprint Generation:")
print("-" * 60)

for idx in test_indices:
    mol = mols[idx]
    smiles = df_valid.iloc[idx]['canonical_smiles']
    fp_maccs = compute_maccs_fingerprint(mol)

    on_bits = fp_maccs.sum()
    density = on_bits / len(fp_maccs) * 100

    print(f"  {smiles:<30} → {on_bits:3d} bits ON ({density:.1f}% density)")

print(f"\nMACCS fingerprints are smaller (167 bits) and denser than Morgan (2048 bits)")

#### Interpreting MACCS Bits

Since MACCS keys are predefined, we can interpret what each bit means:

In [None]:
# A few MACCS key meanings (RDKit numbering: bit i is MACCS key i, keys run 1-166,
# bit 0 is unused). The SMARTS for every key is in MACCSkeys.smartsPatts.
maccs_key_meanings = {
    # Heteroatom presence
    161: "N (any nitrogen)",
    164: "O (any oxygen)",
    139: "O-H (hydroxyl)",

    # Ring patterns
    163: "Six-membered ring",
    145: "More than one six-membered ring",
    165: "Ring atom present",
    162: "Aromatic atom present",
    125: "More than one aromatic ring",

    # Functional groups
    154: "C=O (carbonyl)",

    # Whole-molecule keys
    166: "More than one fragment",
}

# Analyze MACCS keys for our example molecule
print(f"\nMACCS Key Analysis for: {example_smiles}")
print("-" * 60)

fp_maccs_example = compute_maccs_fingerprint(example_mol)
on_keys = np.where(fp_maccs_example == 1)[0]

print(f"Active MACCS keys ({len(on_keys)} total):")
for key in on_keys:
    # Each entry is (SMARTS pattern, match-count threshold); '?' marks keys computed in code
    smarts, _ = MACCSkeys.smartsPatts[int(key)]
    label = maccs_key_meanings.get(int(key), "")
    print(f"  Key {key:3d}: {smarts:<35} {label}")

---

### 3.11 Fingerprint Caching for Performance

Computing fingerprints can be slow for large datasets. Let's implement caching to avoid recomputation:

In [None]:
import hashlib
import pickle
from pathlib import Path

class FingerprintCache:
    """
    Cache for molecular fingerprints to avoid recomputation.

    Fingerprints are stored in a dictionary keyed by canonical SMILES hash.
    """

    def __init__(self, cache_dir: str = ".fp_cache"):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
        self.morgan_cache = {}
        self.maccs_cache = {}

    def _get_key(self, smiles: str) -> str:
        """Generate a hash key from SMILES."""
        return hashlib.md5(smiles.encode()).hexdigest()

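    def get_morgan(self, mol: Chem.Mol, smiles: str, radius: int = 2, n_bits: int = 2048) -> np.ndarray:
        """Get Morgan fingerprint, using cache if available."""
        # Include the parameters in the key: otherwise a fingerprint cached with one
        # radius/n_bits would be returned for a request with different settings
        key = self._get_key(f"{smiles}|r={radius}|bits={n_bits}")

        if key not in self.morgan_cache:
            self.morgan_cache[key] = compute_morgan_fingerprint(mol, radius, n_bits)

        return self.morgan_cache[key]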

    def get_maccs(self, mol: Chem.Mol, smiles: str) -> np.ndarray:
        """Get MACCS fingerprint, using cache if available."""
        key = self._get_key(smiles)

        if key not in self.maccs_cache:
            self.maccs_cache[key] = compute_maccs_fingerprint(mol)

        return self.maccs_cache[key]

    def save(self, filename: str = "fingerprints.pkl"):
        """Save cache to disk."""
        cache_path = self.cache_dir / filename
        with open(cache_path, 'wb') as f:
            pickle.dump({
                'morgan': self.morgan_cache,
                'maccs': self.maccs_cache
            }, f)
        print(f"Cache saved to {cache_path}")

    def load(self, filename: str = "fingerprints.pkl") -> bool:
        """Load cache from disk. Returns True if successful."""
        cache_path = self.cache_dir / filename
        if cache_path.exists():
            with open(cache_path, 'rb') as f:
                data = pickle.load(f)
                self.morgan_cache = data.get('morgan', {})
                self.maccs_cache = data.get('maccs', {})
            print(f"Cache loaded from {cache_path}")
            return True
        return False


# Create cache instance
fp_cache = FingerprintCache()

print("Fingerprint cache initialized!")
print(f"Cache directory: {fp_cache.cache_dir}")
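
For in-session caching, a lighter alternative to a hand-written class is `functools.lru_cache`, which builds its cache key from all of the function's arguments, so settings such as `radius` and `n_bits` are automatically part of the key. In this sketch `cached_fp` and its placeholder body are hypothetical stand-ins so the example runs without RDKit.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def cached_fp(smiles: str, radius: int = 2, n_bits: int = 2048) -> frozenset:
    """Memoised fingerprint lookup keyed on (smiles, radius, n_bits)."""
    # Hypothetical placeholder computation; in the notebook you would derive an immutable
    # result (e.g. a frozenset of on-bit indices) from compute_morgan_fingerprint instead
    return frozenset((ord(ch) * (radius + 1)) % n_bits for ch in smiles)


cached_fp("CCO")            # computed (cache miss)
cached_fp("CCO")            # served from the cache (hit)
cached_fp("CCO", 2, 1024)   # different n_bits -> separate cache entry (miss)
print(cached_fp.cache_info())
```

Returning an immutable object matters here: the cached value is shared by every caller, so a cached NumPy array modified by one caller would silently corrupt later lookups. Also note that `lru_cache` may store calls that spell the same arguments differently (positional vs keyword, or relying on defaults) as separate entries.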

---

### 3.12 Computing Fingerprints for All Molecules

Now let's compute fingerprints for all valid molecules in our dataset:

In [None]:
def compute_all_fingerprints(mols: List[Chem.Mol], smiles_list: List[str],
                             cache: FingerprintCache = None,
                             morgan_radius: int = 2, morgan_bits: int = 2048
                             ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute Morgan and MACCS fingerprints for a list of molecules.

    Parameters:
    -----------
    mols : List[Chem.Mol]
        List of RDKit Mol objects
    smiles_list : List[str]
        Corresponding SMILES strings
    cache : FingerprintCache, optional
        Cache to use for storing/retrieving fingerprints
    morgan_radius : int
        Radius for Morgan fingerprints
    morgan_bits : int
        Number of bits for Morgan fingerprints

    Returns:
    --------
    Tuple[np.ndarray, np.ndarray]
        Morgan fingerprints (N x morgan_bits), MACCS fingerprints (N x 167)
    """
    n_mols = len(mols)
    morgan_fps = np.zeros((n_mols, morgan_bits), dtype=np.int8)
    maccs_fps = np.zeros((n_mols, 167), dtype=np.int8)

    print(f"Computing fingerprints for {n_mols:,} molecules...")

    for i, (mol, smiles) in enumerate(tqdm(zip(mols, smiles_list), total=n_mols,
                                            desc="Computing fingerprints")):
        if cache is not None:
            morgan_fps[i] = cache.get_morgan(mol, smiles, morgan_radius, morgan_bits)
            maccs_fps[i] = cache.get_maccs(mol, smiles)
        else:
            morgan_fps[i] = compute_morgan_fingerprint(mol, morgan_radius, morgan_bits)
            maccs_fps[i] = compute_maccs_fingerprint(mol)

    return morgan_fps, maccs_fps


# Compute fingerprints for all molecules
smiles_list = df_valid['canonical_smiles'].tolist()
morgan_fingerprints, maccs_fingerprints = compute_all_fingerprints(
    mols, smiles_list, cache=fp_cache
)

print(f"\nFingerprint arrays created:")
print(f"  Morgan: {morgan_fingerprints.shape} ({morgan_fingerprints.nbytes / 1024:.1f} KB)")
print(f"  MACCS:  {maccs_fingerprints.shape} ({maccs_fingerprints.nbytes / 1024:.1f} KB)")

In [None]:
# Save fingerprints to cache for future use
fp_cache.save()

print("\nFingerprints cached for future sessions!")
print("To reload: fp_cache.load()")

---

### 3.13 Fingerprint Statistics and Analysis

Let's analyze the computed fingerprints:

In [None]:
# Morgan fingerprint statistics
print("Morgan Fingerprint Statistics:")
print("=" * 50)

morgan_bit_counts = morgan_fingerprints.sum(axis=1)
print(f"  Bits per molecule:  min={morgan_bit_counts.min()}, max={morgan_bit_counts.max()}, "
      f"mean={morgan_bit_counts.mean():.1f}")

morgan_bit_freq = morgan_fingerprints.sum(axis=0)
print(f"  Bit frequency: {(morgan_bit_freq > 0).sum()} bits used (out of {morgan_fingerprints.shape[1]})")
print(f"  Most common bit: bit {morgan_bit_freq.argmax()} (in {morgan_bit_freq.max()} molecules)")
print(f"  Unused bits: {(morgan_bit_freq == 0).sum()}")

# MACCS fingerprint statistics
print("\nMACCS Fingerprint Statistics:")
print("=" * 50)

maccs_bit_counts = maccs_fingerprints.sum(axis=1)
print(f"  Bits per molecule:  min={maccs_bit_counts.min()}, max={maccs_bit_counts.max()}, "
      f"mean={maccs_bit_counts.mean():.1f}")

maccs_bit_freq = maccs_fingerprints.sum(axis=0)
print(f"  Bit frequency: {(maccs_bit_freq > 0).sum()} bits used (out of 167)")
print(f"  Most common bit: bit {maccs_bit_freq.argmax()} (in {maccs_bit_freq.max()} molecules)")

In [None]:
# Visualize bit frequency distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Morgan bit frequency
ax1 = axes[0]
used_bits_freq = morgan_bit_freq[morgan_bit_freq > 0]  # ignore bits that are never set
ax1.hist(used_bits_freq, bins=50, color='steelblue', edgecolor='white')
ax1.set_xlabel('Number of molecules with bit ON')
ax1.set_ylabel('Number of bits')
ax1.set_title('Morgan FP: Bit Frequency Distribution (used bits only)')
# Mean over the same (used) bits that the histogram shows
ax1.axvline(used_bits_freq.mean(), color='red', linestyle='--', label=f'Mean: {used_bits_freq.mean():.0f}')
ax1.legend()

# MACCS bit frequency
ax2 = axes[1]
ax2.bar(range(167), maccs_bit_freq, color='coral', edgecolor='none')
ax2.set_xlabel('MACCS Key Index')
ax2.set_ylabel('Number of molecules with key present')
ax2.set_title('MACCS FP: Key Frequency')

plt.tight_layout()
plt.show()

In [None]:
# Visualize relationship between fingerprint density and molecular size
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Morgan bits vs molecule size
ax1 = axes[0]
ax1.scatter(df_valid['n_atoms'], morgan_bit_counts, alpha=0.3, s=10, color='steelblue')
ax1.set_xlabel('Number of Atoms')
ax1.set_ylabel('Morgan Fingerprint Bits ON')
ax1.set_title('Fingerprint Density vs Molecular Size')
r1 = np.corrcoef(df_valid['n_atoms'], morgan_bit_counts)[0, 1]
ax1.text(0.05, 0.95, f'r = {r1:.3f}', transform=ax1.transAxes,
         verticalalignment='top', fontsize=12,
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# MACCS bits vs molecule size
ax2 = axes[1]
ax2.scatter(df_valid['n_atoms'], maccs_bit_counts, alpha=0.3, s=10, color='coral')
ax2.set_xlabel('Number of Atoms')
ax2.set_ylabel('MACCS Fingerprint Bits ON')
ax2.set_title('MACCS Density vs Molecular Size')
r2 = np.corrcoef(df_valid['n_atoms'], maccs_bit_counts)[0, 1]
ax2.text(0.05, 0.95, f'r = {r2:.3f}', transform=ax2.transAxes,
         verticalalignment='top', fontsize=12,
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()

print(f"Observation: r = {r1:.2f} (Morgan) and r = {r2:.2f} (MACCS); positive values mean larger molecules set more bits (more structural features)")

---

### 3.14 Comparing Morgan and MACCS Fingerprints

Let's compare the two fingerprint types:

In [None]:
# Tanimoto similarity between pairs of binary fingerprints
from scipy.spatial.distance import cdist

def tanimoto_similarity(fp1: np.ndarray, fp2: np.ndarray) -> float:
    """
    Compute Tanimoto similarity between two binary fingerprints.

    Tanimoto = |A ∩ B| / |A ∪ B|
    """
    intersection = np.logical_and(fp1, fp2).sum()
    union = np.logical_or(fp1, fp2).sum()
    if union == 0:
        return 0.0
    return intersection / union


# Compare fingerprints for a few molecule pairs
print("Tanimoto Similarity Comparison: Morgan vs MACCS")
print("=" * 70)
print(f"{'Molecule A':<25} {'Molecule B':<25} {'Morgan':<10} {'MACCS':<10}")
print("-" * 70)

# Pairs are chosen by row index, not by chemical similarity, so expect a range of values
comparison_pairs = [
    (0, 1),      # First two rows
    (0, 100),    # First row vs. row 100
    (100, 101),  # Neighboring rows
    (0, 5000),   # Rows far apart in the dataset
]

for i, j in comparison_pairs:
    smiles_a = df_valid.iloc[i]['canonical_smiles'][:22]
    smiles_b = df_valid.iloc[j]['canonical_smiles'][:22]

    morgan_sim = tanimoto_similarity(morgan_fingerprints[i], morgan_fingerprints[j])
    maccs_sim = tanimoto_similarity(maccs_fingerprints[i], maccs_fingerprints[j])

    print(f"{smiles_a:<25} {smiles_b:<25} {morgan_sim:.3f}     {maccs_sim:.3f}")
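
Calling `tanimoto_similarity` pair by pair is fine for a handful of comparisons, but finding the nearest neighbors of one molecule among all fingerprints is simpler with array operations. The sketch below assumes 0/1 arrays such as `morgan_fingerprints`; `bulk_tanimoto` and `most_similar` are hypothetical helper names, not RDKit functions (RDKit offers `DataStructs.BulkTanimotoSimilarity` for its own bit-vector objects).

```python
import numpy as np


def bulk_tanimoto(query: np.ndarray, fps: np.ndarray) -> np.ndarray:
    """Tanimoto similarity between one binary fingerprint and every row of `fps`."""
    A = fps.astype(bool)
    q = query.astype(bool)
    intersection = np.count_nonzero(A & q, axis=1)  # |A ∩ B| per row
    union = np.count_nonzero(A | q, axis=1)         # |A ∪ B| per row
    sims = np.zeros(len(A), dtype=float)
    # Same convention as tanimoto_similarity: 0.0 when both fingerprints are empty
    np.divide(intersection, union, out=sims, where=union > 0)
    return sims


def most_similar(query_idx: int, fps: np.ndarray, k: int = 5):
    """Row indices and similarities of the k most similar fingerprints (self excluded)."""
    sims = bulk_tanimoto(fps[query_idx], fps)
    sims[query_idx] = -1.0                # exclude the query itself
    top = np.argsort(sims)[::-1][:k]      # highest similarity first
    return top, sims[top]


# Toy 4-bit fingerprints
toy = np.array([[1, 1, 0, 0],
                [1, 1, 1, 0],
                [0, 0, 1, 1],
                [0, 0, 0, 0]], dtype=np.int8)
print(bulk_tanimoto(toy[0], toy))  # approximately [1.0, 0.667, 0.0, 0.0]
```

On the notebook data, `most_similar(0, morgan_fingerprints)` would list the five molecules closest to the first one. Note that `A & q` builds a temporary boolean matrix the size of `fps` (roughly 270 MB for ~130k × 2048), so process rows in chunks if memory is tight.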

#### Summary: Morgan vs MACCS Fingerprints

| Feature | Morgan (ECFP) | MACCS Keys |
|---------|---------------|------------|
| **Size** | Configurable (1024-4096 bits) | Fixed (166 keys; RDKit stores them as 167 bits, with bit 0 unused) |
| **Type** | Circular (atom environments) | Substructure (predefined) |
| **Interpretability** | Low (hash-based) | High (each key is defined) |
| **Information** | More detailed | Coarser; limited to the predefined patterns |
| **Speed** | Fast | Very fast |
| **Best For** | ML, similarity, clustering | Interpretable models, screening |

**When to use Morgan (ECFP):**
- High-accuracy machine learning models
- Similarity searching
- Clustering and chemical space visualization
- When interpretability is not critical

**When to use MACCS:**
- When you need to explain predictions (interpretable ML)
- Substructure screening in databases
- Quick filtering based on functional groups
- When memory is limited (167 bits per molecule vs. 2048 for Morgan)

**Pro tip:** For the best of both worlds, concatenate Morgan and MACCS fingerprints!

In [None]:
# Create concatenated fingerprints
combined_fingerprints = np.hstack([morgan_fingerprints, maccs_fingerprints])
print(f"Combined fingerprint shape: {combined_fingerprints.shape}")
print(f"({morgan_fingerprints.shape[1]} Morgan bits + {maccs_fingerprints.shape[1]} MACCS bits = {combined_fingerprints.shape[1]} total features)")

---

### 3.15 Summary: Molecular Fingerprints

We've successfully implemented molecular fingerprint generation for the QM9 dataset:

In [None]:
print("Section 3.8-3.15 Summary: Molecular Fingerprints")
print("=" * 65)
print(f"""
✓ Morgan Fingerprints (ECFP):
  - Circular fingerprints encoding local atom environments
  - Generated {len(morgan_fingerprints):,} fingerprints, each {morgan_fingerprints.shape[1]} bits
  - Mean bits per molecule: {morgan_bit_counts.mean():.1f}
  - Parameters: radius=2 (ECFP4), nBits=2048

✓ MACCS Fingerprints:
  - 166 predefined structural keys (functional groups)
  - Generated {len(maccs_fingerprints):,} fingerprints, each {maccs_fingerprints.shape[1]} bits
  - Mean bits per molecule: {maccs_bit_counts.mean():.1f}
  - Interpretable: each bit has known meaning

✓ Fingerprint Caching:
  - Implemented cache to avoid recomputation
  - Saved to disk for future sessions

✓ Key Insights:
  - Larger molecules have more bits ON (more structural features)
  - Morgan captures finer details; MACCS captures broad patterns
  - Tanimoto similarity is the standard metric for binary fingerprints

Next: We'll use these fingerprints for:
  - PCA and UMAP visualization (Section 4)
  - Clustering (Section 5)
  - Property prediction models (Section 6)
""")

# Store fingerprints in the DataFrame for easy access
df_valid['morgan_fp'] = list(morgan_fingerprints)
df_valid['maccs_fp'] = list(maccs_fingerprints)
print("\nFingerprints stored in DataFrame as 'morgan_fp' and 'maccs_fp' columns")

---

## Section 4: Chemical Space Visualization

Understanding the structure of chemical space is crucial for drug discovery, materials design, and molecular generation. In this section, we'll use dimensionality reduction techniques to visualize how molecules are distributed based on their fingerprint representations.

**Learning Objectives:**
- Apply PCA to high-dimensional fingerprints
- Understand variance explained by principal components
- Visualize molecular clusters colored by properties
- Interpret chemical meaning of the visualization

---

### 4.1 Introduction to Chemical Space

**Chemical space** refers to the theoretical space containing all possible molecules. The number of drug-like molecules alone is commonly estimated at around 10^60, far more than could ever be synthesized or enumerated.

#### Key Concepts

1. **Dimensionality**: Fingerprints are high-dimensional (2048 bits) → Need dimensionality reduction to visualize

2. **Outlier Detection**: Identify unusual or novel molecules → Points far from dense regions

3. **Cluster Discovery**: Find groups of similar molecules → Chemical families, scaffolds

4. **Property Mapping**: Color by target property → See structure-property relationships
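
For the outlier-detection idea in point 2, one simple score (a common heuristic, not the only definition) is one minus the mean Tanimoto similarity to a molecule's k most similar neighbors: molecules with no close analogs score near 1. The sketch builds a dense n × n similarity matrix, so it is only meant for a subsample of a few thousand molecules; the helper names are hypothetical.

```python
import numpy as np


def pairwise_tanimoto(fps: np.ndarray) -> np.ndarray:
    """Dense n x n Tanimoto similarity matrix for 0/1 fingerprints (small n only)."""
    X = fps.astype(np.float64)
    inter = X @ X.T                                    # |A ∩ B| for every pair
    counts = X.sum(axis=1)
    union = counts[:, None] + counts[None, :] - inter  # |A ∪ B| = |A| + |B| - |A ∩ B|
    return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)


def knn_outlier_scores(fps: np.ndarray, k: int = 5) -> np.ndarray:
    """1 - mean similarity to the k most similar other molecules (higher = more isolated)."""
    S = pairwise_tanimoto(fps)
    np.fill_diagonal(S, -np.inf)          # never count a molecule as its own neighbor
    top_k = np.sort(S, axis=1)[:, -k:]    # k largest similarities per row (requires k < n)
    return 1.0 - top_k.mean(axis=1)


demo = np.array([[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [0, 0, 0, 0, 1, 1]])
scores = knn_outlier_scores(demo, k=1)
print(scores.round(3), "most unusual row:", scores.argmax())
```

On the notebook data you would apply it to a random subset, e.g. `knn_outlier_scores(morgan_fingerprints[idx])` for a few thousand indices `idx`.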

#### Tools We'll Use

- **PCA**: Linear dimensionality reduction (fast, interpretable)
- **UMAP**: Non-linear embedding (preserves local structure)
- **Clustering**: K-means to identify molecule groups

---

### 4.2 Preparing Fingerprint Data

In [None]:
print("Preparing Data for Visualization")
print("=" * 50)

# Verify we have fingerprints from Section 3
print(f"\nAvailable fingerprint data:")
print(f"  Morgan fingerprints: {morgan_fingerprints.shape}")
print(f"  MACCS fingerprints: {maccs_fingerprints.shape}")

# We'll use Morgan fingerprints for PCA
X_morgan = morgan_fingerprints.astype(np.float64)
print(f"\nData matrix shape: {X_morgan.shape}")
print(f"  {X_morgan.shape[0]} molecules × {X_morgan.shape[1]} fingerprint bits")

# Check fingerprint statistics
n_nonzero = np.count_nonzero(X_morgan, axis=1)
print(f"\nFingerprint statistics:")
print(f"  Mean bits set per molecule: {n_nonzero.mean():.1f}")
print(f"  Min bits set: {n_nonzero.min()}")
print(f"  Max bits set: {n_nonzero.max()}")

# Confirm we have target properties for coloring
print(f"\nTarget properties available:")
print(f"  HOMO-LUMO gap: {df_valid['gap_eV'].min():.2f} to {df_valid['gap_eV'].max():.2f} eV")
print(f"  Dipole moment: {df_valid['mu'].min():.2f} to {df_valid['mu'].max():.2f} D")
print(f"  Number of atoms: {df_valid['n_atoms'].min()} to {df_valid['n_atoms'].max()}")

---

### 4.3 Principal Component Analysis (PCA)

**Principal Component Analysis (PCA)** is a fundamental dimensionality reduction technique that transforms high-dimensional data into a lower-dimensional representation while preserving as much variance as possible.

#### The Main Idea

PCA finds new axes (called **principal components**) that are linear combinations of the original features. The first principal component (PC1) captures the direction of maximum variance in the data, PC2 captures the second-most variance while being orthogonal to PC1, and so on.

#### Mathematical Foundation

Given a data matrix $X$ of shape $(n \times d)$ where $n$ is the number of samples and $d$ is the number of features:

1. **Center the data**: Subtract the mean of each feature: $\tilde{X} = X - \bar{X}$

2. **Compute the covariance matrix**: $C = \frac{1}{n-1} \tilde{X}^T \tilde{X}$

3. **Eigendecomposition**: Find eigenvalues $\lambda_1, \lambda_2, ..., \lambda_d$ and eigenvectors $v_1, v_2, ..., v_d$ of $C$

4. **Select top components**: Sort by eigenvalue magnitude; the eigenvectors corresponding to the $k$ largest eigenvalues form the projection matrix $W$

5. **Project data**: $X_{reduced} = \tilde{X} \cdot W$

The **variance explained** by each component is $\frac{\lambda_i}{\sum_j \lambda_j}$.
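
The five steps translate almost line for line into NumPy. This is a teaching sketch (the name `pca_eig` is ours): scikit-learn's `PCA`, used below, reaches the same components through a singular value decomposition of the centered data instead of forming $C$ explicitly, and individual components may come out with flipped signs.

```python
import numpy as np


def pca_eig(X: np.ndarray, k: int):
    """PCA by eigendecomposition of the covariance matrix (steps 1-5)."""
    X_tilde = X - X.mean(axis=0)                   # 1. center each feature
    C = X_tilde.T @ X_tilde / (X.shape[0] - 1)     # 2. covariance matrix (d x d)
    eigvals, eigvecs = np.linalg.eigh(C)           # 3. eigh, since C is symmetric
    order = np.argsort(eigvals)[::-1]              # 4. sort by eigenvalue, largest first
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]
    W = eigvecs[:, :k]                             #    projection matrix (d x k)
    X_reduced = X_tilde @ W                        # 5. project
    explained_ratio = eigvals[:k] / eigvals.sum()  # lambda_i / sum_j lambda_j
    return X_reduced, explained_ratio


rng = np.random.default_rng(0)
X_demo = (rng.random((200, 30)) < 0.3).astype(float)  # random binary "fingerprints"
X_red, ratio = pca_eig(X_demo, k=5)
print(X_red.shape, ratio.round(3))
```

A useful sanity check: the sample variance of each projected column equals its eigenvalue, so dividing by the total variance reproduces `ratio`.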

#### Typical Applications

- **Visualization**: Project high-dimensional data to 2D or 3D for plotting
- **Noise reduction**: Keep only components with significant variance
- **Feature extraction**: Create uncorrelated features for downstream ML
- **Data compression**: Reduce storage while retaining information
- **Exploratory analysis**: Understand structure and relationships in data

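For molecular fingerprints (2048 bits), a few tens of components are often enough for visualization and to speed up later analyses, but the fraction of variance they retain depends on the data. Sparse binary fingerprints tend to spread their variance across many bits, so read the actual number off the cumulative-variance plot rather than assuming a fixed 80-95%.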

In [None]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

print("Applying PCA to Morgan Fingerprints")
print("=" * 50)

# Note: binary fingerprints are usually not standardized before PCA. Every bit is
# already on the same 0/1 scale, and dividing by each bit's standard deviation would
# up-weight rare bits. PCA centers the data internally, so no other preprocessing is needed.
pca = PCA(n_components=50)  # Keep top 50 components initially
X_pca_all = pca.fit_transform(X_morgan)

print(f"\nPCA Results:")
print(f"  Input dimensions: {X_morgan.shape[1]}")
print(f"  Output dimensions: {X_pca_all.shape[1]}")

# Variance explained by each component
var_explained = pca.explained_variance_ratio_
var_cumsum = np.cumsum(var_explained)

print(f"\nVariance Explained:")
print(f"  PC1: {var_explained[0]*100:.1f}%")
print(f"  PC2: {var_explained[1]*100:.1f}%")
print(f"  PC1+PC2: {var_cumsum[1]*100:.1f}%")
print(f"  Top 10 PCs: {var_cumsum[9]*100:.1f}%")
print(f"  Top 50 PCs: {var_cumsum[49]*100:.1f}%")

In [None]:
# Visualize variance explained (Scree plot)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Individual variance
ax1 = axes[0]
ax1.bar(range(1, 21), var_explained[:20] * 100, color='steelblue', alpha=0.7)
ax1.set_xlabel('Principal Component')
ax1.set_ylabel('Variance Explained (%)')
ax1.set_title('Scree Plot: Variance by Component')
ax1.set_xticks(range(1, 21, 2))

# Cumulative variance
ax2 = axes[1]
ax2.plot(range(1, 51), var_cumsum * 100, 'b-o', markersize=4)
ax2.axhline(y=80, color='r', linestyle='--', alpha=0.7, label='80% threshold')
ax2.axhline(y=95, color='g', linestyle='--', alpha=0.7, label='95% threshold')

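# Find number of components needed for each threshold.
# np.argmax on an all-False array returns 0, which would report 1 component if the
# threshold is never reached; searchsorted returns len(var_cumsum) in that case instead.
n_kept = len(var_cumsum)
n_80 = int(np.searchsorted(var_cumsum, 0.80)) + 1
n_95 = int(np.searchsorted(var_cumsum, 0.95)) + 1
if n_80 <= n_kept:
    ax2.axvline(x=n_80, color='r', linestyle=':', alpha=0.5)
if n_95 <= n_kept:
    ax2.axvline(x=n_95, color='g', linestyle=':', alpha=0.5)

ax2.set_xlabel('Number of Components')
ax2.set_ylabel('Cumulative Variance Explained (%)')
ax2.set_title('Cumulative Variance Explained')
ax2.legend(loc='lower right')
ax2.set_xlim(0, 51)
ax2.set_ylim(0, 101)

plt.tight_layout()
plt.show()


def n_components_label(n):
    """Component count, or a note when the threshold is not reached within n_kept PCs."""
    return str(n) if n <= n_kept else f"not reached with {n_kept} PCs"


print(f"\nDimensionality Reduction Summary:")
print(f"  Components for 80% variance: {n_components_label(n_80)}")
print(f"  Components for 95% variance: {n_components_label(n_95)}")
print(f"  Original dimensions: {X_morgan.shape[1]}")
if n_80 <= n_kept:
    print(f"  Compression ratio (80%): {X_morgan.shape[1]/n_80:.1f}x")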

---

### 4.4 Chemical Interpretation of PCA

In [None]:
print("Interpreting Principal Components")
print("=" * 50)

# The loading vectors tell us which fingerprint bits contribute to each PC
loadings = pca.components_  # Shape: (n_components, n_features)

print(f"\nPCA loadings shape: {loadings.shape}")
print("(Each row is a component, each column is a fingerprint bit)")

# Find the most important bits for PC1 and PC2
def get_top_contributing_bits(loadings, pc_idx, n_top=10):
    """Get bits with highest absolute loading for a given PC."""
    pc_loadings = loadings[pc_idx]
    # Sort by absolute value
    top_idx = np.argsort(np.abs(pc_loadings))[::-1][:n_top]
    return [(idx, pc_loadings[idx]) for idx in top_idx]

print("\nTop contributing bits for PC1:")
top_pc1 = get_top_contributing_bits(loadings, 0)
for bit_idx, loading in top_pc1[:5]:
    print(f"  Bit {bit_idx}: loading = {loading:.4f}")

print("\nTop contributing bits for PC2:")
top_pc2 = get_top_contributing_bits(loadings, 1)
for bit_idx, loading in top_pc2[:5]:
    print(f"  Bit {bit_idx}: loading = {loading:.4f}")

**Chemical Interpretation:**
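- Morgan fingerprint bits encode local atomic environments (an atom plus its neighbors out to the chosen radius)
- A binary bit that is ON in a fraction p of molecules has variance p(1 − p), so bits present in roughly half the molecules dominate the leading components, while bits that are almost always ON or almost never ON contribute little
- PC1 often tracks overall molecular size/complexity; check this against the property correlations in Section 4.7
- PC2 often separates major structural classes (e.g., ring vs. chain)
- The spread along each PC reflects structural diversity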
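
The variance argument above can be checked directly: for a 0/1 feature with mean p, the (population) variance is p(1 − p), which peaks at p = 0.5. A minimal sketch, assuming a binary matrix such as `morgan_fingerprints`; `bit_variance_profile` is a hypothetical helper name.

```python
import numpy as np


def bit_variance_profile(fps: np.ndarray, n_top: int = 10):
    """Per-bit ON frequency p, variance p(1 - p), and the highest-variance bit indices."""
    p = fps.mean(axis=0)                      # fraction of molecules with each bit ON
    var = p * (1.0 - p)                       # variance of a Bernoulli(p) feature
    top_bits = np.argsort(var)[::-1][:n_top]  # bits contributing most to total variance
    return p, var, top_bits


rng = np.random.default_rng(1)
# Toy matrix: bit 0 almost always ON, bit 1 ON about half the time, bit 2 rarely ON
demo = np.column_stack([
    rng.random(1000) < 0.98,
    rng.random(1000) < 0.50,
    rng.random(1000) < 0.02,
]).astype(np.int8)
p, var, top_bits = bit_variance_profile(demo, n_top=3)
print(p.round(2), var.round(3), top_bits)
```

On the real data, comparing `top_bits` from `bit_variance_profile(morgan_fingerprints)` with the bits printed by `get_top_contributing_bits` should show substantial overlap if this explanation holds.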

---

### 4.5 2D Visualization Colored by HOMO-LUMO Gap

In [None]:
# Extract first 2 PCs for visualization
X_pca_2d = X_pca_all[:, :2]

print("Creating 2D PCA Visualization")
print("=" * 50)

# Create figure
fig, ax = plt.subplots(figsize=(10, 8))

# Scatter plot colored by HOMO-LUMO gap
scatter = ax.scatter(
    X_pca_2d[:, 0],
    X_pca_2d[:, 1],
    c=df_valid['gap_eV'].values,
    cmap='viridis',
    alpha=0.5,
    s=10,
    edgecolors='none'
)

# Colorbar
cbar = plt.colorbar(scatter, ax=ax, label='HOMO-LUMO Gap (eV)')

# Labels
ax.set_xlabel(f'PC1 ({var_explained[0]*100:.1f}% variance)')
ax.set_ylabel(f'PC2 ({var_explained[1]*100:.1f}% variance)')
ax.set_title('QM9 Chemical Space: PCA of Morgan Fingerprints\nColored by HOMO-LUMO Gap')

plt.tight_layout()
plt.show()

print(f"\nChemical Space Analysis:")
print(f"  PC1 range: [{X_pca_2d[:, 0].min():.1f}, {X_pca_2d[:, 0].max():.1f}]")
print(f"  PC2 range: [{X_pca_2d[:, 1].min():.1f}, {X_pca_2d[:, 1].max():.1f}]")
print(f"  Total variance shown: {var_cumsum[1]*100:.1f}%")

---

### 4.6 Multi-Property Visualization

In [None]:
# Create multiple scatter plots colored by different properties
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Properties to visualize
properties = [
    ('gap_eV', 'HOMO-LUMO Gap (eV)', 'viridis'),
    ('mu', 'Dipole Moment (Debye)', 'plasma'),
    ('n_atoms', 'Number of Atoms', 'coolwarm'),
    ('alpha', 'Polarizability (Bohr³)', 'cividis')
]

for ax, (prop, label, cmap) in zip(axes.flatten(), properties):
    scatter = ax.scatter(
        X_pca_2d[:, 0],
        X_pca_2d[:, 1],
        c=df_valid[prop].values,
        cmap=cmap,
        alpha=0.4,
        s=8,
        edgecolors='none'
    )
    plt.colorbar(scatter, ax=ax, label=label)
    ax.set_xlabel(f'PC1 ({var_explained[0]*100:.1f}%)')
    ax.set_ylabel(f'PC2 ({var_explained[1]*100:.1f}%)')
    ax.set_title(f'Chemical Space Colored by {label}')

plt.suptitle('QM9 PCA: Different Property Colorings', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

**Observations:**

1. **HOMO-LUMO Gap**: Shows gradient pattern—structural similarity correlates with electronic properties

2. **Dipole Moment**: More patchy distribution—dipole depends on specific functional groups, not overall structure

3. **Number of Atoms**: Clear gradient along PC1—larger molecules cluster together (size is a major source of variance)

4. **Polarizability**: Similar to n_atoms—correlates with molecular size (more electrons = more polarizable)

---

### 4.7 Correlating PCs with Molecular Properties

In [None]:
print("Correlation: Principal Components vs. Properties")
print("=" * 55)

# Compute correlations between PCs and molecular properties
numeric_cols = ['n_atoms', 'homo_eV', 'lumo_eV', 'gap_eV', 'mu', 'alpha', 'Cv']

# Build correlation matrix
pc_property_corr = np.zeros((5, len(numeric_cols)))  # Top 5 PCs

for i in range(5):
    for j, col in enumerate(numeric_cols):
        pc_property_corr[i, j] = np.corrcoef(X_pca_all[:, i], df_valid[col])[0, 1]

# Visualize as heatmap
fig, ax = plt.subplots(figsize=(10, 5))

im = ax.imshow(pc_property_corr, cmap='RdBu_r', aspect='auto', vmin=-1, vmax=1)
plt.colorbar(im, ax=ax, label='Correlation')

# Labels
ax.set_xticks(range(len(numeric_cols)))
ax.set_xticklabels(numeric_cols, rotation=45, ha='right')
ax.set_yticks(range(5))
ax.set_yticklabels([f'PC{i+1}' for i in range(5)])
ax.set_title('Correlation: Principal Components vs. Molecular Properties')

# Add correlation values as text
for i in range(5):
    for j in range(len(numeric_cols)):
        val = pc_property_corr[i, j]
        color = 'white' if abs(val) > 0.5 else 'black'
        ax.text(j, i, f'{val:.2f}', ha='center', va='center', color=color, fontsize=9)

plt.tight_layout()
plt.show()

# Report strongest correlations
print("\nStrongest PC-Property Correlations:")
for i in range(5):
    max_idx = np.argmax(np.abs(pc_property_corr[i]))
    max_corr = pc_property_corr[i, max_idx]
    print(f"  PC{i+1} ↔ {numeric_cols[max_idx]}: r = {max_corr:.3f}")

---

### 4.8 Highlighting Molecular Subgroups

In [None]:
print("Highlighting Specific Molecule Classes")
print("=" * 50)

# Create masks for interesting subgroups
small_gap = (df_valid['gap_eV'] < 5.5).values    # Small HOMO-LUMO gap
large_gap = (df_valid['gap_eV'] > 8.0).values    # Large HOMO-LUMO gap
high_dipole = (df_valid['mu'] > 5.0).values      # Polar molecules
aromatic = (df_valid['num_aromatic_rings'] > 0).values  # At least one aromatic ring (benzene, pyridine, furan, ...)

# Count molecules in each group
print(f"\nMolecule subgroups:")
print(f"  Small gap (<5.5 eV): {small_gap.sum()} molecules")
print(f"  Large gap (>8.0 eV): {large_gap.sum()} molecules")
print(f"  High dipole (>5 D): {high_dipole.sum()} molecules")
print(f"  Aromatic: {aromatic.sum()} molecules")

# Create visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# 1. Small vs Large gap
ax1 = axes[0, 0]
ax1.scatter(X_pca_2d[:, 0], X_pca_2d[:, 1], c='lightgray', alpha=0.2, s=5, label='All')
ax1.scatter(X_pca_2d[small_gap, 0], X_pca_2d[small_gap, 1], c='red', alpha=0.6, s=15, label=f'Gap < 5.5 eV (n={small_gap.sum()})')
ax1.scatter(X_pca_2d[large_gap, 0], X_pca_2d[large_gap, 1], c='blue', alpha=0.6, s=15, label=f'Gap > 8.0 eV (n={large_gap.sum()})')
ax1.set_xlabel('PC1')
ax1.set_ylabel('PC2')
ax1.set_title('Small vs. Large HOMO-LUMO Gap Molecules')
ax1.legend(loc='upper right', fontsize=8)

# 2. Aromatic molecules
ax2 = axes[0, 1]
ax2.scatter(X_pca_2d[:, 0], X_pca_2d[:, 1], c='lightgray', alpha=0.2, s=5, label='All')
ax2.scatter(X_pca_2d[aromatic, 0], X_pca_2d[aromatic, 1], c='green', alpha=0.6, s=15, label=f'Aromatic (n={aromatic.sum()})')
ax2.set_xlabel('PC1')
ax2.set_ylabel('PC2')
ax2.set_title('Aromatic Molecules in Chemical Space')
ax2.legend(loc='upper right', fontsize=8)

# 3. High dipole molecules
ax3 = axes[1, 0]
ax3.scatter(X_pca_2d[:, 0], X_pca_2d[:, 1], c='lightgray', alpha=0.2, s=5, label='All')
ax3.scatter(X_pca_2d[high_dipole, 0], X_pca_2d[high_dipole, 1], c='purple', alpha=0.6, s=15, label=f'Dipole > 5 D (n={high_dipole.sum()})')
ax3.set_xlabel('PC1')
ax3.set_ylabel('PC2')
ax3.set_title('High Dipole Moment Molecules')
ax3.legend(loc='upper right', fontsize=8)

# 4. Combined view with size = n_atoms
ax4 = axes[1, 1]
sizes = df_valid['n_atoms'].values * 1.5  # Scale for visibility
scatter = ax4.scatter(
    X_pca_2d[:, 0],
    X_pca_2d[:, 1],
    c=df_valid['gap_eV'].values,
    s=sizes,
    cmap='viridis',
    alpha=0.4,
    edgecolors='none'
)
plt.colorbar(scatter, ax=ax4, label='Gap (eV)')
ax4.set_xlabel('PC1')
ax4.set_ylabel('PC2')
ax4.set_title('Chemical Space: Size ∝ n_atoms, Color = Gap')

plt.tight_layout()
plt.show()

**Key Observations:**

1. **Small gap molecules** (red) cluster in specific regions—likely conjugated/aromatic systems with extended π-bonding

2. **Aromatic molecules** (green) occupy distinct areas, because Morgan environments encode aromatic bonds as their own bond type, so aromatic rings set characteristic bits

3. **High dipole molecules** (purple) are scattered: the dipole moment depends on the 3D charge distribution and geometry, which a 2D topological fingerprint does not capture

4. **Molecular size** strongly correlates with PC1—larger molecules have more bits set, creating systematic variance

---

### 4.9 3D PCA Visualization

In [None]:
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions

print("3D PCA Visualization")
print("=" * 50)

# Extract 3 PCs
X_pca_3d = X_pca_all[:, :3]

fig = plt.figure(figsize=(12, 9))
ax = fig.add_subplot(111, projection='3d')

# Scatter plot
scatter = ax.scatter(
    X_pca_3d[:, 0],
    X_pca_3d[:, 1],
    X_pca_3d[:, 2],
    c=df_valid['gap_eV'].values,
    cmap='viridis',
    alpha=0.4,
    s=8
)

# Labels
ax.set_xlabel(f'PC1 ({var_explained[0]*100:.1f}%)')
ax.set_ylabel(f'PC2 ({var_explained[1]*100:.1f}%)')
ax.set_zlabel(f'PC3 ({var_explained[2]*100:.1f}%)')
ax.set_title('3D PCA of QM9 Chemical Space\nColored by HOMO-LUMO Gap')

cbar = fig.colorbar(scatter, ax=ax, shrink=0.6, label='Gap (eV)')

plt.tight_layout()
plt.show()

print(f"\n3D Visualization Statistics:")
print(f"  Variance captured: {var_cumsum[2]*100:.1f}%")
print(f"  PC3 alone: {var_explained[2]*100:.1f}%")
print(f"\nCompared to 2D: +{(var_cumsum[2] - var_cumsum[1])*100:.1f}% more variance")

---

### 4.10 Summary: PCA Visualization

In [None]:
print("Section 4.1-4.10 Summary: PCA Visualization")
print("=" * 60)

print(f"""
✓ What We Learned:
  ─────────────────
  - Morgan fingerprints (2048 bits) capture molecular structure
  - PCA reduces dimensionality while preserving variance
  - {n_80 if n_80 <= len(var_cumsum) else f'more than {len(var_cumsum)}'} components capture 80% of variance
  - PC1 strongly correlates with molecular size

✓ Visualization Insights:
  ────────────────────────
  - HOMO-LUMO gap shows gradient patterns in PCA space
  - Aromatic molecules cluster in specific regions
  - High dipole molecules are more scattered
  - Size (n_atoms) is a major source of variance

✓ Limitations of PCA:
  ────────────────────
  - Linear method - misses nonlinear structure
  - PC1+PC2 only capture {var_cumsum[1]*100:.1f}% variance
  - Global structure emphasized over local neighborhoods
  - Nonlinear methods such as UMAP and t-SNE preserve local neighborhoods better (coming next!)

✓ Variables Available:
  ─────────────────────
  - X_pca_all: shape {X_pca_all.shape} (top 50 PCs)
  - X_pca_2d: shape {X_pca_2d.shape} (first 2 PCs)
  - pca: fitted PCA model
  - var_explained: variance ratio per component

Next: Section 4.11+ - UMAP Visualization (better for local structure)
""")

---

### 4.11 Introduction to UMAP

While PCA is fast and interpretable, it's limited to linear projections. For complex, high-dimensional molecular data, **nonlinear** dimensionality reduction methods often reveal structure that PCA misses.

**UMAP (Uniform Manifold Approximation and Projection)** is a widely used nonlinear method that:
- Preserves **local neighborhood structure** (nearby molecules stay nearby)
- Reveals **cluster structure** in chemical space
- Runs faster than t-SNE on large datasets
- Often retains more global structure than t-SNE, although both depend on initialization and the random seed

#### The Main Idea

UMAP is based on the mathematical concept that high-dimensional data often lies on a lower-dimensional **manifold** (a curved surface embedded in the high-dimensional space). Rather than finding linear projections like PCA, UMAP tries to find a low-dimensional representation that preserves the **topological structure** of the data—specifically, which points are neighbors of which.

#### Mathematical Foundation

UMAP works in two main phases:

**Phase 1: Constructing a High-Dimensional Graph**

1. For each point $x_i$, find its $k$ nearest neighbors (controlled by `n_neighbors`)

2. Compute a **fuzzy simplicial set** (weighted graph) where the edge weights from $x_i$ to each of its $k$ neighbors represent similarity:
   $$w_{ij} = \exp\left(-\frac{\max\left(0,\, d(x_i, x_j) - \rho_i\right)}{\sigma_i}\right)$$

   where $\rho_i$ is the distance to the nearest neighbor (so every point is fully connected to at least one neighbor) and $\sigma_i$ is chosen so that $\sum_j w_{ij} = \log_2 k$, which fixes the "effective number of neighbors"

3. Symmetrize the graph: $w_{ij}^{sym} = w_{ij} + w_{ji} - w_{ij} \cdot w_{ji}$
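
A simplified NumPy version of steps 2 and 3 is sketched below, assuming you already have each point's distances to its $k$ nearest neighbors. It finds $\sigma_i$ by bisection so that the weights sum to $\log_2 k$, then applies the symmetrization. umap-learn's own implementation handles extra details (for example the `local_connectivity` parameter and sparse graphs), so treat this as an illustration of the formulas, not a replacement; the function names are hypothetical.

```python
import numpy as np


def smooth_knn_weights(knn_dists: np.ndarray, n_iter: int = 100) -> np.ndarray:
    """Edge weights w_ij from one point to its k nearest neighbors.

    rho is the nearest-neighbor distance; sigma is found by bisection so that
    sum_j exp(-max(0, d_j - rho) / sigma) equals log2(k).
    """
    k = len(knn_dists)
    target = np.log2(k)
    shifted = np.maximum(knn_dists - knn_dists.min(), 0.0)  # max(0, d - rho)

    def total(sigma):
        return np.exp(-shifted / sigma).sum()

    lo, hi = 0.0, 1.0
    while total(hi) < target:   # grow the bracket until it contains the solution
        hi *= 2.0
    for _ in range(n_iter):     # the sum increases with sigma, so bisection converges
        mid = 0.5 * (lo + hi)
        if total(mid) > target:
            hi = mid
        else:
            lo = mid
    return np.exp(-shifted / (0.5 * (lo + hi)))


def fuzzy_union(W: np.ndarray) -> np.ndarray:
    """Symmetrize a directed weight matrix: w_ij + w_ji - w_ij * w_ji."""
    return W + W.T - W * W.T


w = smooth_knn_weights(np.array([0.5, 0.7, 1.0, 1.5, 2.0]))
print(w.round(3), w.sum().round(3), np.log2(5).round(3))
```

The nearest neighbor always receives weight 1, which is the "local connectivity" guarantee that $\rho_i$ provides.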

**Phase 2: Optimizing the Low-Dimensional Layout**

1. Initialize points in low-dimensional space (typically 2D)

2. Define an analogous similarity function in the low-dimensional space:
   $$v_{ij} = \left(1 + a \cdot ||y_i - y_j||_2^{2b}\right)^{-1}$$

   where $a$ and $b$ are parameters derived from `min_dist`

3. Minimize the **cross-entropy** between the high-dimensional and low-dimensional edge weights:
   $$C = \sum_{ij} w_{ij} \log\frac{w_{ij}}{v_{ij}} + (1-w_{ij})\log\frac{1-w_{ij}}{1-v_{ij}}$$

4. Use stochastic gradient descent to iteratively adjust the low-dimensional positions
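
Steps 2 and 3 can be checked on a tiny embedding. The sketch below (hypothetical helper names) evaluates $v_{ij}$ and the cross-entropy $C$ over all pairs $i \neq j$; umap-learn never forms these dense matrices and optimizes with sampled edges instead, so this is only for checking the formulas. It uses `scipy.special.xlogy`, which treats $0 \log 0$ as 0.

```python
import numpy as np
from scipy.special import xlogy


def low_dim_similarity(Y: np.ndarray, a: float, b: float) -> np.ndarray:
    """v_ij = 1 / (1 + a * ||y_i - y_j||^(2b)) for all pairs."""
    sq_dist = ((Y[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    return 1.0 / (1.0 + a * sq_dist ** b)   # (||.||^2)^b == ||.||^(2b)


def fuzzy_cross_entropy(W: np.ndarray, V: np.ndarray, eps: float = 1e-12) -> float:
    """Sum over i != j of w log(w/v) + (1 - w) log((1 - w)/(1 - v))."""
    V = np.clip(V, eps, 1.0 - eps)          # keep the logarithms finite
    terms = (xlogy(W, W) - xlogy(W, V)
             + xlogy(1.0 - W, 1.0 - W) - xlogy(1.0 - W, 1.0 - V))
    off_diagonal = ~np.eye(len(W), dtype=bool)  # i == j pairs are not edges
    return float(terms[off_diagonal].sum())


Y = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])
V = low_dim_similarity(Y, a=1.0, b=1.0)
print(V.round(3))
```

Each term is a binary KL divergence, so $C \ge 0$, with $C = 0$ exactly when the low-dimensional weights match the high-dimensional ones.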

#### Key Parameters

| Parameter | Effect | Typical Values |
|-----------|--------|----------------|
| `n_neighbors` | Size of local neighborhood. Smaller = more local structure, more clusters; Larger = more global structure | 5-50 (default: 15) |
| `min_dist` | Minimum distance between points in embedding. Smaller = tighter clusters | 0.0-1.0 (default: 0.1) |
| `metric` | Distance function for computing neighbors | 'euclidean', 'cosine', 'jaccard' |
| `random_state` | For reproducibility (UMAP uses stochastic optimization) | Any integer |
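
`min_dist` enters the math only through $a$ and $b$. As far as we know, umap-learn obtains them by a least-squares fit of $(1 + a\,d^{2b})^{-1}$ to a target curve that equals 1 for distances below `min_dist` and decays as `exp(-(d - min_dist) / spread)` beyond it (`spread` defaults to 1.0). The sketch below reproduces that idea with `scipy.optimize.curve_fit`; `fit_ab` is a hypothetical name, and its numbers may differ slightly from the library's internal values.

```python
import numpy as np
from scipy.optimize import curve_fit


def fit_ab(min_dist: float, spread: float = 1.0):
    """Least-squares fit of a, b so that 1 / (1 + a d^(2b)) mimics the target curve."""
    def curve(d, a, b):
        return 1.0 / (1.0 + a * d ** (2 * b))

    d = np.linspace(0, 3 * spread, 300)
    # Target: 1 inside min_dist, exponential decay outside
    target = np.where(d < min_dist, 1.0, np.exp(-(d - min_dist) / spread))
    (a, b), _ = curve_fit(curve, d, target)
    return a, b, curve(d, a, b), target


for md in (0.0, 0.1, 0.5):
    a, b, fitted, target = fit_ab(md)
    print(f"min_dist={md}: a={a:.3f}, b={b:.3f}")
```

Roughly speaking, a larger `min_dist` keeps the target at 1 out to a larger distance, so neighbors gain nothing by moving closer than that and the embedding spreads out.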

#### Why UMAP for Molecular Data?

- **High-dimensional fingerprints**: Morgan fingerprints have 2048 dimensions—UMAP handles this well
- **Binary data**: With `metric='jaccard'`, distances equal 1 − Tanimoto similarity, so UMAP's neighborhoods follow the standard similarity measure for binary fingerprints
- **Chemical similarity is local**: Similar molecules share functional groups, creating local neighborhoods
- **Cluster discovery**: UMAP can reveal functional-group clusters that PCA misses
- **Scalability**: Efficient for datasets of 10,000+ molecules
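
To see why the metric matters, compare distances on a few small bit vectors. SciPy's `scipy.spatial.distance.jaccard` returns the Jaccard dissimilarity, which is exactly 1 − Tanimoto for binary vectors, while the squared Euclidean distance between 0/1 vectors only counts mismatched bits and ignores how many bits the two fingerprints share. The `tanimoto` helper here restates the formula locally; it is not a library function.

```python
import numpy as np
from scipy.spatial.distance import jaccard


def tanimoto(a: np.ndarray, b: np.ndarray) -> float:
    a, b = a.astype(bool), b.astype(bool)
    union = np.logical_or(a, b).sum()
    return float(np.logical_and(a, b).sum() / union) if union else 0.0


a = np.array([1, 1, 0, 1, 0, 0], dtype=bool)
b = np.array([1, 0, 0, 1, 1, 0], dtype=bool)
print(tanimoto(a, b), 1.0 - jaccard(a, b))   # 0.5 0.5

# Squared Euclidean distance on 0/1 vectors = number of differing bits (Hamming count)
sq_euclid = np.sum((a.astype(int) - b.astype(int)) ** 2)
print(sq_euclid, np.count_nonzero(a != b))   # 2 2

# Same number of differing bits, very different Tanimoto similarity
c = np.array([1, 1, 1, 1, 1, 1, 0, 0], dtype=bool)
d = np.array([1, 1, 1, 1, 1, 0, 1, 0], dtype=bool)
e = np.array([1, 0, 0, 0, 0, 0, 0, 0], dtype=bool)
f = np.array([0, 1, 0, 0, 0, 0, 0, 0], dtype=bool)
print(np.count_nonzero(c != d), tanimoto(c, d))  # 2 differing bits, Tanimoto 5/7
print(np.count_nonzero(e != f), tanimoto(e, f))  # 2 differing bits, Tanimoto 0
```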

In [None]:
# Import UMAP
try:
    import umap
    print(f"UMAP version: {umap.__version__}")
except ImportError:
    print("Installing umap-learn...")
    import subprocess
    import sys
    # Install into the same Python environment as the running kernel
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'umap-learn'])
    import umap
    print(f"UMAP installed, version: {umap.__version__}")

---

### 4.12 Computing UMAP Embeddings

Let's apply UMAP to our Morgan fingerprints. We'll experiment with different parameters to understand their effects.

In [None]:
# Prepare the fingerprint matrix
# We already have morgan_fingerprints from Section 3
print(f"Input fingerprint matrix shape: {morgan_fingerprints.shape}")

# UMAP with default parameters
print("\nFitting UMAP with default parameters...")
umap_default = umap.UMAP(
    n_components=2,
    n_neighbors=15,      # Default: consider 15 nearest neighbors
    min_dist=0.1,        # Default: minimum distance between points
    metric='euclidean',  # On 0/1 vectors this is sqrt(Hamming distance); baseline for comparison
    random_state=42      # For reproducibility
)

X_umap_default = umap_default.fit_transform(morgan_fingerprints)
print(f"UMAP embedding shape: {X_umap_default.shape}")

In [None]:
# UMAP optimized for binary fingerprints
print("Fitting UMAP with Jaccard metric (the standard choice for binary fingerprints)...")
umap_jaccard = umap.UMAP(
    n_components=2,
    n_neighbors=15,
    min_dist=0.1,
    metric='jaccard',    # Jaccard distance = 1 - |A ∩ B| / |A ∪ B| = 1 - Tanimoto
    random_state=42
)

X_umap_jaccard = umap_jaccard.fit_transform(morgan_fingerprints)
print(f"UMAP (Jaccard) embedding shape: {X_umap_jaccard.shape}")

In [None]:
# Store the main UMAP embedding for later use
X_umap = X_umap_jaccard  # Using Jaccard as our primary embedding
print(f"\nPrimary UMAP embedding: {X_umap.shape}")

---

### 4.13 Basic UMAP Visualization

In [None]:
# 2D scatter plot colored by HOMO-LUMO gap
fig, ax = plt.subplots(figsize=(12, 10))

scatter = ax.scatter(
    X_umap[:, 0], X_umap[:, 1],
    c=df_valid['gap_eV'],
    cmap='viridis',
    alpha=0.5,
    s=10
)

cbar = plt.colorbar(scatter, ax=ax, label='HOMO-LUMO Gap (eV)')
ax.set_xlabel('UMAP Dimension 1', fontsize=12)
ax.set_ylabel('UMAP Dimension 2', fontsize=12)
ax.set_title('QM9 Chemical Space via UMAP\n(colored by HOMO-LUMO gap)', fontsize=14)

# Add annotation
ax.text(0.02, 0.98, f'n = {len(df_valid):,} molecules',
        transform=ax.transAxes, fontsize=10, verticalalignment='top',
        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()

**Interpretation:**

Notice how UMAP reveals distinct clusters that weren't visible in PCA!

- The HOMO-LUMO gap shows clear gradients within clusters
- Some clusters contain mostly large-gap molecules (typically saturated, less reactive compounds)
- Other clusters contain small-gap molecules (often conjugated or aromatic systems that are more easily excited)
- Island structures suggest distinct chemical families

---

### 4.14 Comparing UMAP vs PCA

Let's directly compare PCA and UMAP visualizations side-by-side to understand their differences.

In [None]:
# Side-by-side comparison: PCA vs UMAP
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# PCA subplot
ax1 = axes[0]
scatter1 = ax1.scatter(
    X_pca_2d[:, 0], X_pca_2d[:, 1],
    c=df_valid['gap_eV'],
    cmap='viridis',
    alpha=0.5,
    s=10
)
plt.colorbar(scatter1, ax=ax1, label='Gap (eV)')
ax1.set_xlabel('PC1', fontsize=12)
ax1.set_ylabel('PC2', fontsize=12)
ax1.set_title('PCA Projection\n(Linear, Global Structure)', fontsize=14)

# UMAP subplot
ax2 = axes[1]
scatter2 = ax2.scatter(
    X_umap[:, 0], X_umap[:, 1],
    c=df_valid['gap_eV'],
    cmap='viridis',
    alpha=0.5,
    s=10
)
plt.colorbar(scatter2, ax=ax2, label='Gap (eV)')
ax2.set_xlabel('UMAP 1', fontsize=12)
ax2.set_ylabel('UMAP 2', fontsize=12)
ax2.set_title('UMAP Projection\n(Nonlinear, Local Structure)', fontsize=14)

plt.tight_layout()
plt.show()

**PCA vs UMAP Comparison:**

| Aspect | PCA (Left) | UMAP (Right) |
|--------|------------|--------------|
| Projection | Linear—preserves global variance | Nonlinear—preserves local neighborhoods |
| Speed | Very fast | Fast |
| Axes | Have meaning (PC1 = largest variance direction) | No direct meaning |
| Clusters | Molecules form a blob—hard to see | Reveals cluster structure (chemical families) |
| Limitation | Linear constraints miss nonlinear structure | Cluster sizes and inter-cluster distances are not reliably meaningful |

**Key Insight:** Both methods are complementary—use PCA for variance analysis and linear trends; use UMAP for cluster discovery and visualization.
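The claim that UMAP preserves local neighborhoods better than PCA can be checked numerically. The sketch below uses a hypothetical helper, `knn_preservation`, that measures the average fraction of each point's k nearest neighbors that survive a projection. It assumes Euclidean distance in both spaces (UMAP above was fit with Jaccard, so a Jaccard-based high-dimensional neighbor search would be more faithful) and builds a full n x n distance matrix, so it is only practical on a random subsample of a few thousand molecules. scikit-learn's `sklearn.manifold.trustworthiness` is an established, related metric.

```python
import numpy as np

def knn_indices(X, k):
    """Indices of each row's k nearest neighbors (Euclidean), excluding the row itself."""
    X = np.asarray(X, dtype=float)
    sq_norms = np.sum(X ** 2, axis=1)
    # Squared distances via ||x||^2 + ||y||^2 - 2 x.y (full n x n matrix)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * X @ X.T
    np.fill_diagonal(d2, np.inf)  # a point is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]

def knn_preservation(X_high, X_low, k=10):
    """Mean fraction of each point's k nearest neighbors shared by both spaces (1.0 = perfect)."""
    nn_high = knn_indices(X_high, k)
    nn_low = knn_indices(X_low, k)
    overlaps = [len(set(h) & set(l)) / k for h, l in zip(nn_high, nn_low)]
    return float(np.mean(overlaps))

# Toy check: a 5-D cloud, its first two coordinates, and an unrelated random 2-D layout
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 5))
print(knn_preservation(X, X))                          # → 1.0 (identical spaces)
print(knn_preservation(X, X[:, :2]))                   # linear projection
print(knn_preservation(X, rng.normal(size=(300, 2))))  # random layout, near k / n
```

On QM9, one could draw a random index array `idx` and compare `knn_preservation(morgan_fingerprints[idx], X_pca_2d[idx])` with `knn_preservation(morgan_fingerprints[idx], X_umap[idx])`; if the table's characterization holds, we would expect the UMAP value to be higher.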

---

### 4.15 UMAP Parameter Effects

Let's explore how UMAP parameters affect the embedding. This is crucial for understanding and tuning visualizations.

In [None]:
# Compare different n_neighbors (top row) and min_dist (bottom row) values
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

n_neighbors_values = [5, 15, 50]
min_dist_values = [0.0, 0.1, 0.5]

# Top row: varying n_neighbors (fixed min_dist=0.1)
for i, n_neighbors in enumerate(n_neighbors_values):
    print(f"Computing UMAP with n_neighbors={n_neighbors}...")
    reducer = umap.UMAP(
        n_components=2,
        n_neighbors=n_neighbors,
        min_dist=0.1,
        metric='jaccard',
        random_state=42
    )
    embedding = reducer.fit_transform(morgan_fingerprints)

    ax = axes[0, i]
    scatter = ax.scatter(
        embedding[:, 0], embedding[:, 1],
        c=df_valid['gap_eV'],
        cmap='viridis',
        alpha=0.5,
        s=8
    )
    ax.set_title(f'n_neighbors = {n_neighbors}\n(min_dist = 0.1)', fontsize=12)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    if i == 0:
        ax.text(-0.1, 0.5, 'Varying\nn_neighbors',
                transform=ax.transAxes, fontsize=14, fontweight='bold',
                verticalalignment='center', rotation=90)

# Bottom row: varying min_dist (fixed n_neighbors=15)
for i, min_dist in enumerate(min_dist_values):
    print(f"Computing UMAP with min_dist={min_dist}...")
    reducer = umap.UMAP(
        n_components=2,
        n_neighbors=15,
        min_dist=min_dist,
        metric='jaccard',
        random_state=42
    )
    embedding = reducer.fit_transform(morgan_fingerprints)

    ax = axes[1, i]
    scatter = ax.scatter(
        embedding[:, 0], embedding[:, 1],
        c=df_valid['gap_eV'],
        cmap='viridis',
        alpha=0.5,
        s=8
    )
    ax.set_title(f'min_dist = {min_dist}\n(n_neighbors = 15)', fontsize=12)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    if i == 0:
        ax.text(-0.1, 0.5, 'Varying\nmin_dist',
                transform=ax.transAxes, fontsize=14, fontweight='bold',
                verticalalignment='center', rotation=90)

plt.tight_layout()
plt.show()

**Parameter Effects:**

**n_neighbors (top row):**
- Small (5): Very local structure, many small clusters, noise visible
- Medium (15): Balanced local/global, good cluster separation (default)
- Large (50): More global structure, smoother embedding, fewer clusters

Use smaller values to find fine-grained substructures; use larger values for a more holistic view.

**min_dist (bottom row):**
- Small (0.0): Points can pack tightly, reveals dense cores
- Medium (0.1): Some spread, cleaner visualization (default)
- Large (0.5): Very spread out, emphasis on continuum over clusters

Use smaller values when clusters are important; use larger values to see gradual transitions.

**Recommendations for Molecular Data:**
- `n_neighbors = 15-30`: Good for chemical similarity
- `min_dist = 0.1`: Balances cluster structure
- `metric = 'jaccard'`: Well suited to binary fingerprints (equals 1 minus the Tanimoto similarity)

---

### 4.16 Multi-Property UMAP Visualization

In [None]:
# Visualize multiple properties on UMAP
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

properties = [
    ('gap_eV', 'HOMO-LUMO Gap (eV)', 'viridis'),
    ('mu', 'Dipole Moment (Debye)', 'plasma'),
    ('alpha', 'Polarizability (Bohr³)', 'cividis'),
    ('n_atoms', 'Number of Atoms', 'RdYlBu_r')
]

for ax, (prop, label, cmap) in zip(axes.flat, properties):
    scatter = ax.scatter(
        X_umap[:, 0], X_umap[:, 1],
        c=df_valid[prop],
        cmap=cmap,
        alpha=0.5,
        s=10
    )
    plt.colorbar(scatter, ax=ax, label=label)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_title(f'Colored by {label}', fontsize=11)

plt.suptitle('UMAP Projections of QM9 Chemical Space', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

**Multi-Property Insights:**

- **HOMO-LUMO Gap**: Clear clustering by gap value. Some regions contain only high-gap molecules. Cluster membership predicts gap reasonably well.

- **Dipole Moment**: More scattered distribution within clusters. Some high-dipole molecules in specific regions. Less correlated with cluster structure.

- **Polarizability**: Strong gradient visible—follows molecular size. Larger molecules (high α) cluster together. Size is a major organizing principle.

- **Number of Atoms**: Clear size gradient across the embedding. Smaller molecules on one side, larger on the other. Confirms molecular size drives fingerprint similarity.

---

### 4.17 Identifying Chemical Space Clusters

Let's visually identify and label interesting regions in the UMAP space.

In [None]:
# Identify cluster characteristics by region
fig, ax = plt.subplots(figsize=(14, 10))

# Main scatter plot
scatter = ax.scatter(
    X_umap[:, 0], X_umap[:, 1],
    c=df_valid['gap_eV'],
    cmap='viridis',
    alpha=0.4,
    s=10
)
plt.colorbar(scatter, ax=ax, label='HOMO-LUMO Gap (eV)')

# Estimate point density with a simple grid-based (2D histogram) approach
from scipy import ndimage

# Drop any rows with NaN coordinates, then pad the plot limits by 1 unit
valid_mask = ~np.isnan(X_umap).any(axis=1)
X_umap_valid = X_umap[valid_mask]
x_min, x_max = X_umap_valid[:, 0].min() - 1, X_umap_valid[:, 0].max() + 1
y_min, y_max = X_umap_valid[:, 1].min() - 1, X_umap_valid[:, 1].max() + 1

# Count points in a 30 x 30 grid of bins
bins = 30
H, xedges, yedges = np.histogram2d(X_umap_valid[:, 0], X_umap_valid[:, 1], bins=bins)

# Smooth the counts with a Gaussian filter to obtain a density surface
H_smooth = ndimage.gaussian_filter(H, sigma=1)

# Add contours for density
ax.contour(
    (xedges[:-1] + xedges[1:]) / 2,
    (yedges[:-1] + yedges[1:]) / 2,
    H_smooth.T,
    levels=5,
    colors='white',
    alpha=0.5,
    linewidths=0.5
)

ax.set_xlabel('UMAP 1', fontsize=12)
ax.set_ylabel('UMAP 2', fontsize=12)
ax.set_title('QM9 Chemical Space: Cluster Identification\n(contours show density)', fontsize=14)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)

plt.tight_layout()
plt.show()

**Visual Cluster Identification:**

The contour lines highlight regions of high molecular density. These likely correspond to common molecular scaffolds in QM9 (inspecting the molecules in each region would confirm this):

1. **Dense Central Region**: Contains "average" molecules with a mix of functional groups and moderate gap values

2. **Peripheral Islands**: Distinct chemical families that may contain unusual functional groups and often have extreme property values

3. **Connecting Regions**: Transition zones between families with gradual property changes—good for understanding structure-property relationships

The next section (Clustering) will quantitatively identify these groups.
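The contour plot shows where density is high but does not list the density peaks themselves. One simple way to locate them is to flag grid cells that are strictly larger than all eight neighbors. The helper `local_maxima_2d` below is a hypothetical NumPy sketch of that idea, not a method used elsewhere in this tutorial; plateaus of equal values are deliberately not counted as peaks.

```python
import numpy as np

def local_maxima_2d(H, min_value=-np.inf):
    """(row, col) indices of cells strictly greater than all 8 neighbors and than min_value."""
    H = np.asarray(H, dtype=float)
    rows, cols = H.shape
    # Pad with -inf so border cells are compared only against real neighbors
    padded = np.pad(H, 1, mode="constant", constant_values=-np.inf)
    is_peak = H > min_value
    for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
            if dr == 0 and dc == 0:
                continue
            # Shifted view: neighbor[i, j] == H[i + dr, j + dc] (or -inf off the grid)
            neighbor = padded[1 + dr:1 + dr + rows, 1 + dc:1 + dc + cols]
            is_peak &= H > neighbor
    return np.argwhere(is_peak)

# Toy grid with two isolated peaks
grid = np.zeros((5, 5))
grid[1, 1] = 3.0
grid[3, 3] = 2.0
print(local_maxima_2d(grid).tolist())                 # → [[1, 1], [3, 3]]
print(local_maxima_2d(grid, min_value=2.5).tolist())  # → [[1, 1]]
```

Applied to the grid from Section 4.17, `peaks = local_maxima_2d(H_smooth, min_value=np.percentile(H_smooth, 90))` returns bin indices. Because `np.histogram2d` puts the x axis on the first dimension, the peak coordinates would be `x_centers[peaks[:, 0]]` and `y_centers[peaks[:, 1]]`, with `x_centers = (xedges[:-1] + xedges[1:]) / 2` and likewise for y.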

---

### 4.18 Highlighting Functional Groups in UMAP

In [None]:
# Highlight molecules by functional group presence
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Use MACCS keys to identify functional groups.
# Indices assume maccs_fingerprints is RDKit's 167-bit MACCS vector, where bit i is
# MACCS key i (bit 0 unused); check rdkit.Chem.MACCSkeys.smartsPatts to confirm.
functional_groups = [
    ('Aromatic Ring', maccs_fingerprints[:, 162] == 1, 'tab:red'),     # key 162: aromatic atom
    ('Carbonyl (C=O)', maccs_fingerprints[:, 154] == 1, 'tab:blue'),   # key 154: C=O
    ('Amine (-NH2)', maccs_fingerprints[:, 84] == 1, 'tab:green'),     # key 84: NH2
    ('Hydroxyl (-OH)', maccs_fingerprints[:, 139] == 1, 'tab:orange'), # key 139: OH
]

for ax, (name, mask, color) in zip(axes.flat, functional_groups):
    # Background: all molecules in gray
    ax.scatter(
        X_umap[:, 0], X_umap[:, 1],
        c='lightgray',
        alpha=0.3,
        s=5
    )

    # Overlay: molecules with functional group
    n_with = mask.sum()
    ax.scatter(
        X_umap[mask, 0], X_umap[mask, 1],
        c=color,
        alpha=0.6,
        s=15,
        label=f'{name}\n({n_with:,} molecules, {100*n_with/len(mask):.1f}%)'
    )

    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_title(f'Molecules with {name}', fontsize=12)
    ax.legend(loc='best', fontsize=9)

plt.suptitle('Functional Group Distribution in UMAP Space', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

**Functional Group Clustering:**

This visualization reveals how functional groups organize in UMAP space:

- Aromatic molecules cluster together (extended π-systems)
- Carbonyl compounds occupy specific regions
- Amines and hydroxyls show their own patterns
- Some groups overlap (multifunctional molecules)

**Key Insight:** UMAP groups molecules by structural similarity. This is consistent with molecular fingerprints capturing meaningful chemical information that UMAP preserves.

---

### 4.19 Quadrant-Based Region Analysis

In [None]:
# Analyze different regions of the UMAP embedding
# Divide into quadrants and compare properties

# Use the median coordinates of the embedding as the quadrant dividing lines
center_x = np.median(X_umap[:, 0])
center_y = np.median(X_umap[:, 1])

# Define regions
regions = {
    'Top-Left': (X_umap[:, 0] < center_x) & (X_umap[:, 1] > center_y),
    'Top-Right': (X_umap[:, 0] >= center_x) & (X_umap[:, 1] > center_y),
    'Bottom-Left': (X_umap[:, 0] < center_x) & (X_umap[:, 1] <= center_y),
    'Bottom-Right': (X_umap[:, 0] >= center_x) & (X_umap[:, 1] <= center_y)
}

# Collect statistics for each region
print("Region Analysis of UMAP Embedding")
print("=" * 70)

region_stats = []
for region_name, mask in regions.items():
    subset = df_valid[mask]
    stats = {
        'Region': region_name,
        'Count': len(subset),
        'Avg Gap (eV)': subset['gap_eV'].mean(),
        'Avg Atoms': subset['n_atoms'].mean(),
        'Avg Dipole': subset['mu'].mean(),
        'Aromatic %': 100 * maccs_fingerprints[mask, 162].mean()  # key 162: aromatic atom (RDKit MACCS numbering assumed)
    }
    region_stats.append(stats)
    print(f"\n{region_name}:")
    print(f"  Molecules: {stats['Count']:,}")
    print(f"  Avg HOMO-LUMO gap: {stats['Avg Gap (eV)']:.2f} eV")
    print(f"  Avg atoms: {stats['Avg Atoms']:.1f}")
    print(f"  Avg dipole: {stats['Avg Dipole']:.2f} Debye")
    print(f"  Contains aromatic: {stats['Aromatic %']:.1f}%")

region_df = pd.DataFrame(region_stats)
print("\n" + "=" * 70)

In [None]:
# Visualize regions
fig, ax = plt.subplots(figsize=(12, 10))

colors = {'Top-Left': 'red', 'Top-Right': 'blue',
          'Bottom-Left': 'green', 'Bottom-Right': 'orange'}

for region_name, mask in regions.items():
    ax.scatter(
        X_umap[mask, 0], X_umap[mask, 1],
        c=colors[region_name],
        alpha=0.4,
        s=10,
        label=region_name
    )

# Draw dividing lines
ax.axvline(x=center_x, color='black', linestyle='--', alpha=0.5)
ax.axhline(y=center_y, color='black', linestyle='--', alpha=0.5)

ax.set_xlabel('UMAP 1', fontsize=12)
ax.set_ylabel('UMAP 2', fontsize=12)
ax.set_title('UMAP Space Divided into Quadrants', fontsize=14)
ax.legend(loc='best')

plt.tight_layout()
plt.show()

# Bar chart comparing regions
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

metrics = ['Avg Gap (eV)', 'Avg Atoms', 'Avg Dipole', 'Aromatic %']
for ax, metric in zip(axes, metrics):
    values = [stats[metric] for stats in region_stats]
    ax.bar(range(4), values, color=[colors[r] for r in regions.keys()])
    ax.set_xticks(range(4))
    ax.set_xticklabels(['TL', 'TR', 'BL', 'BR'])
    ax.set_title(metric, fontsize=11)
    ax.set_ylabel(metric)

plt.suptitle('Property Comparison Across UMAP Regions', fontsize=12, y=1.05)
plt.tight_layout()
plt.show()

---

### 4.20 3D UMAP Visualization

In [None]:
# 3D UMAP embedding
print("Computing 3D UMAP embedding...")
umap_3d = umap.UMAP(
    n_components=3,
    n_neighbors=15,
    min_dist=0.1,
    metric='jaccard',
    random_state=42
)
X_umap_3d = umap_3d.fit_transform(morgan_fingerprints)
print(f"3D UMAP embedding shape: {X_umap_3d.shape}")

In [None]:
# 3D visualization
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (only needed on matplotlib < 3.2 to register the '3d' projection)

fig = plt.figure(figsize=(14, 10))
ax = fig.add_subplot(111, projection='3d')

scatter = ax.scatter(
    X_umap_3d[:, 0], X_umap_3d[:, 1], X_umap_3d[:, 2],
    c=df_valid['gap_eV'],
    cmap='viridis',
    alpha=0.5,
    s=5
)

fig.colorbar(scatter, ax=ax, label='HOMO-LUMO Gap (eV)', shrink=0.6)
ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2')
ax.set_zlabel('UMAP 3')
ax.set_title('3D UMAP Projection of QM9 Chemical Space', fontsize=14)

# Set viewing angle
ax.view_init(elev=20, azim=45)

plt.tight_layout()
plt.show()

**3D UMAP Insights:**

The 3D embedding reveals additional structure:
- Some clusters that appear merged in 2D are separated in 3D
- The manifold has a complex, folded structure
- Different viewing angles reveal different relationships

For interactive exploration, consider plotly for rotatable 3D scatter plots, and py3Dmol or nglview (Jupyter widgets) for viewing the 3D structures of individual molecules picked from the embedding.

---

### 4.21 Summary: UMAP Visualization

In [None]:
print("Section 4.11-4.21 Summary: UMAP Visualization")
print("=" * 60)

print(f"""
✓ What We Learned:
  ─────────────────
  - UMAP reveals nonlinear structure PCA misses
  - Jaccard metric is well suited to binary fingerprints
  - n_neighbors controls local vs global focus
  - min_dist controls cluster tightness

✓ Key Parameters (recommended for molecular data):
  ───────────────────────────────────────────────────
  - n_neighbors: 15 (default works well)
  - min_dist: 0.1 (balance clusters and spread)
  - metric: 'jaccard' (for binary fingerprints)
  - random_state: 42 (for reproducibility)

✓ Visual Insights:
  ─────────────────
  - Clear cluster structure in chemical space
  - Functional groups cluster together (aromatics, etc.)
  - Property gradients visible within clusters
  - Size (n_atoms) is a major organizing principle

✓ PCA vs UMAP Comparison:
  ─────────────────────────
  | Aspect          | PCA           | UMAP            |
  |-----------------|---------------|-----------------|
  | Method          | Linear        | Nonlinear       |
  | Focus           | Global        | Local           |
  | Speed           | Very fast     | Fast            |
  | Interpretability| High (axes)   | Low (axes)      |
  | Clusters        | Not visible   | Clearly visible |
  | Use case        | Variance      | Structure       |

✓ Variables Available:
  ─────────────────────
  - X_umap: shape {X_umap.shape} (2D embedding, Jaccard)
  - X_umap_3d: shape {X_umap_3d.shape} (3D embedding)
  - X_umap_default: shape {X_umap_default.shape} (Euclidean)
  - X_umap_jaccard: shape {X_umap_jaccard.shape} (Jaccard)

Next: Section 5 - Clustering Analysis (k-means on Morgan fingerprints, visualized with UMAP)
""")

---

## Section 5: Clustering and Unsupervised Learning

In this section, we'll explore unsupervised learning techniques to discover natural groupings in our molecular dataset. Clustering can reveal chemical families, identify outliers, and help us understand structure-property relationships without using target labels.

**Learning Objectives:**
- Apply k-means clustering to molecular fingerprints
- Use the elbow method to determine optimal cluster count
- Visualize and interpret clusters in UMAP space
- Analyze chemical composition of each cluster
- Connect cluster membership to molecular properties

---

### 5.1 Introduction to Molecular Clustering

#### What is Clustering?

**Clustering** is an unsupervised machine learning technique that groups similar data points together without using predefined labels. The goal is to partition data into groups (clusters) such that:
- Points **within** the same cluster are similar to each other
- Points in **different** clusters are dissimilar

Unlike supervised learning (where we predict known labels), clustering discovers hidden structure in data—we don't tell the algorithm what the groups should be; it finds them automatically.

#### The Core Idea

Clustering algorithms try to optimize an objective that balances two competing goals:

1. **Compactness**: Points within a cluster should be close together (low intra-cluster distance)
2. **Separation**: Points in different clusters should be far apart (high inter-cluster distance)

Different algorithms formalize these goals differently, leading to different cluster shapes and characteristics.

#### Why Cluster Molecules?

Clustering is particularly valuable for molecular datasets because:

1. **Discover chemical families**: Group molecules by structural similarity (scaffolds, functional groups)
2. **Identify similar compounds**: Find molecules with shared properties without explicit rules
3. **Detect outliers**: Unusual molecules that don't fit any cluster may have unique properties
4. **Guide data splitting**: Stratify train/test sets to ensure diversity
5. **Understand chemical space**: Visualize the coverage and gaps in your dataset
6. **Lead optimization**: In drug discovery, explore clusters around promising hits

#### Common Clustering Methods

| Method | Description | Best For |
|--------|-------------|----------|
| **K-Means** | Partition into k centroids by minimizing within-cluster variance | Spherical clusters, fast |
| **DBSCAN** | Density-based: clusters are dense regions separated by sparse areas | Arbitrary shapes, outlier detection |
| **Hierarchical** | Build tree of clusters (agglomerative or divisive) | Dendrograms, unknown k |
| **Spectral** | Use graph Laplacian eigenvectors | Non-convex shapes |

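We'll primarily use **k-means** because it is simple, fast, and scales to the full dataset. Note that it relies on Euclidean distance, which for binary fingerprints counts differing bits rather than measuring Tanimoto (Jaccard) similarity, so its clusters are an approximation of chemical similarity.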

---

### 5.2 Elbow Method for Optimal k

A critical challenge with k-means clustering is choosing the number of clusters $k$. The **Elbow Method** is a heuristic technique to help identify a good value.

#### How the Elbow Method Works

1. **Run k-means (see below) for different values of k** (e.g., k = 2, 3, 4, ..., 15)

2. **Compute the inertia** (also called WCSS - Within-Cluster Sum of Squares) for each k:
   $$\text{Inertia} = \sum_{i=1}^{n} ||x_i - \mu_{c(i)}||^2$$
   where $\mu_{c(i)}$ is the centroid of the cluster that point $x_i$ belongs to

3. **Plot inertia vs. k**: As k increases, inertia decreases (more clusters = less variance within each), apart from small fluctuations caused by k-means finding only local minima

4. **Find the "elbow"**: Look for the point where adding more clusters provides diminishing returns—the curve bends sharply, forming an "elbow"

#### Why It Works

- With **too few clusters**: Points are forced into large, heterogeneous groups → high inertia
- With **too many clusters**: Natural groups are split into fragments, and each additional cluster gives only a minimal further reduction in inertia
- At the **optimal k**: Clusters capture natural groupings; adding more doesn't help much
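Reading the elbow by eye is subjective. A common heuristic, similar in spirit to the Kneedle method, normalizes both axes and picks the k whose point lies farthest from the straight line joining the first and last points of the curve. The sketch below (`elbow_by_max_distance` is a hypothetical helper) implements that rule; treat its answer as a suggestion to compare against the silhouette score, not a definitive choice.

```python
import numpy as np

def elbow_by_max_distance(ks, inertias):
    """Return the k whose normalized (k, inertia) point is farthest from the
    chord joining the first and last points. Assumes inertia is not constant."""
    ks = np.asarray(ks, dtype=float)
    y = np.asarray(inertias, dtype=float)
    # Normalize both axes to [0, 1] so inertia's large scale does not dominate
    x_n = (ks - ks.min()) / (ks.max() - ks.min())
    y_n = (y - y.min()) / (y.max() - y.min())
    # Unit vector along the chord from the first point to the last point
    p0 = np.array([x_n[0], y_n[0]])
    p1 = np.array([x_n[-1], y_n[-1]])
    direction = (p1 - p0) / np.linalg.norm(p1 - p0)
    # Perpendicular distance of each point to the chord (2D cross product magnitude)
    rel = np.column_stack([x_n, y_n]) - p0
    distances = np.abs(rel[:, 0] * direction[1] - rel[:, 1] * direction[0])
    return int(ks[np.argmax(distances)])

ks = [1, 2, 3, 4, 5, 6, 7, 8]
toy_inertias = [100, 50, 25, 20, 17, 15, 14, 13]
print(elbow_by_max_distance(ks, toy_inertias))  # → 3
```

With the lists computed in the next code cell, the call would be `elbow_by_max_distance(list(k_range), inertias)`.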

#### Complementary Metric: Silhouette Score

The **Silhouette Score** provides another perspective, measuring both cohesion and separation:

$$s(i) = \frac{b(i) - a(i)}{\max(a(i), b(i))}$$

where:
- $a(i)$ = average distance from point $i$ to other points in its cluster (cohesion)
- $b(i)$ = average distance from point $i$ to points in the nearest other cluster (separation)

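Silhouette ranges from -1 to +1, where higher is better. `silhouette_score` reports the mean of $s(i)$ over all points (or over a random subsample when `sample_size` is set).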
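The silhouette definition translates directly into code. This NumPy sketch (`silhouette_samples_np` is a hypothetical name) uses Euclidean distance and a full pairwise distance matrix, so it is only meant for small samples; it needs at least two clusters, and assigning 0 to points in single-member clusters is a convention assumed here. In practice, `sklearn.metrics.silhouette_score` handles this efficiently.

```python
import numpy as np

def silhouette_samples_np(X, labels):
    """Per-point silhouette s(i) = (b(i) - a(i)) / max(a(i), b(i)) using Euclidean distance."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    # Full pairwise distance matrix: O(n^2) memory, so use small samples only
    D = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    clusters = np.unique(labels)
    s = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        n_others = same.sum() - 1  # members of i's cluster, excluding i
        if n_others == 0:
            continue  # singleton cluster: s(i) left at 0 (convention assumed here)
        a = D[i, same].sum() / n_others  # D[i, i] = 0, so it adds nothing to the sum
        # b(i): mean distance to the nearest *other* cluster
        b = min(D[i, labels == c].mean() for c in clusters if c != labels[i])
        s[i] = (b - a) / max(a, b)
    return s

X_toy = np.array([[0, 0], [0, 1], [10, 0], [10, 1]], dtype=float)
good = silhouette_samples_np(X_toy, [0, 0, 1, 1])  # compact, well-separated clusters
bad = silhouette_samples_np(X_toy, [0, 1, 0, 1])   # each cluster spans both groups
print(good.mean(), bad.mean())
```

The well-separated labeling scores about 0.90, while the mixed labeling is negative (about -0.45), matching the intuition that $s(i) < 0$ means a point sits closer to another cluster than to its own.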

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# We'll cluster on Morgan fingerprints (already computed)
X_cluster = morgan_fingerprints.astype(np.float64)
print(f"Fingerprint matrix for clustering: {X_cluster.shape}")

# Test a range of k values
k_range = range(2, 16)
inertias = []
silhouette_scores = []

print("\nComputing clustering metrics for k = 2 to 15...")
for k in k_range:
    kmeans = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300)
    labels = kmeans.fit_predict(X_cluster)
    inertias.append(kmeans.inertia_)

    # Silhouette score (higher is better)
    sil_score = silhouette_score(X_cluster, labels, sample_size=5000, random_state=42)
    silhouette_scores.append(sil_score)

    print(f"  k={k:2d}: Inertia = {kmeans.inertia_:,.0f}, Silhouette = {sil_score:.4f}")

print("\n✓ Elbow analysis complete")

In [None]:
# Plot elbow curve and silhouette scores
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Elbow plot (inertia)
ax1 = axes[0]
ax1.plot(list(k_range), inertias, 'bo-', linewidth=2, markersize=8)
ax1.set_xlabel('Number of Clusters (k)', fontsize=12)
ax1.set_ylabel('Inertia (Within-Cluster Sum of Squares)', fontsize=12)
ax1.set_title('Elbow Method for Optimal k', fontsize=14)
ax1.set_xticks(list(k_range))
ax1.grid(True, alpha=0.3)

# Add annotations for key points
ax1.annotate('Elbow region', xy=(6, inertias[4]), xytext=(8, inertias[2]),
             arrowprops=dict(arrowstyle='->', color='red'),
             fontsize=10, color='red')

# Silhouette plot
ax2 = axes[1]
ax2.plot(list(k_range), silhouette_scores, 'go-', linewidth=2, markersize=8)
ax2.set_xlabel('Number of Clusters (k)', fontsize=12)
ax2.set_ylabel('Silhouette Score', fontsize=12)
ax2.set_title('Silhouette Score vs Number of Clusters', fontsize=14)
ax2.set_xticks(list(k_range))
ax2.grid(True, alpha=0.3)

# Find best silhouette
best_k_sil = list(k_range)[np.argmax(silhouette_scores)]
ax2.axvline(x=best_k_sil, color='red', linestyle='--', alpha=0.7, label=f'Best: k={best_k_sil}')
ax2.legend()

plt.tight_layout()
plt.show()

print(f"""
Interpretation:
─────────────────────────────────────────────────────────────────
  - Elbow Method: Look for where the curve bends (diminishing returns)
  - Silhouette Score: Higher is better (measures cluster separation)
  - Best k by silhouette: {best_k_sil}

  For molecules, we'll use k=6 as a balance between:
    • Interpretable number of clusters
    • Capturing chemical diversity
    • Reasonable cluster sizes
""")

---

### 5.3 K-Means Clustering (k=6)

**K-Means** is one of the most widely used clustering algorithms due to its simplicity and efficiency. It partitions $n$ data points into $k$ clusters by iteratively refining cluster assignments.

#### The K-Means Optimization Problem

K-means seeks to find **two things simultaneously**:
1. **Cluster assignments** $c(i)$ for each point $x_i$ — which cluster does point $i$ belong to?
2. **Cluster centroids** $\mu_1, \mu_2, ..., \mu_k$ — where is the "center" of each cluster?

**Objective**: Minimize the within-cluster sum of squares (inertia) with respect to both the assignments and centroids:

$$J(\{c(i)\}, \{\mu_j\}) = \sum_{j=1}^{k} \sum_{i: c(i)=j} ||x_i - \mu_j||^2$$

This objective measures the total squared distance from each point to its assigned cluster center. Lower $J$ means tighter, more compact clusters.

#### Algorithm Steps (Lloyd's Algorithm)

Since optimizing $J$ over all possible assignments is computationally intractable (there are $k^n$ possibilities), K-means uses an iterative **coordinate descent** approach that alternates between optimizing assignments and centroids:

**Step 1: Initialize Centroids**

Choose $k$ initial centroid positions. Common strategies:
- **Random**: Select $k$ random data points as initial centroids
- **K-means++** (default in scikit-learn): Spread out initial centroids by choosing each successive centroid at random, with probability proportional to its squared distance from the nearest centroid already chosen
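A minimal NumPy sketch of k-means++ seeding follows, in which each new centroid is drawn with probability proportional to the squared distance to the nearest centroid already chosen. `kmeans_plus_plus_init` is a hypothetical helper for illustration; scikit-learn's implementation additionally uses a greedy variant that evaluates several candidates per step.

```python
import numpy as np

def kmeans_plus_plus_init(X, k, rng=None):
    """Choose k initial centroids from the rows of X using k-means++ (D^2) weighting."""
    rng = np.random.default_rng(rng)
    X = np.asarray(X, dtype=float)
    n = len(X)
    # First centroid: a data point chosen uniformly at random
    centers = [X[rng.integers(n)]]
    # Squared distance from every point to its nearest chosen centroid so far
    d2 = np.sum((X - centers[0]) ** 2, axis=1)
    for _ in range(1, k):
        # Far-away points are proportionally more likely to become the next centroid
        idx = rng.choice(n, p=d2 / d2.sum())
        centers.append(X[idx])
        d2 = np.minimum(d2, np.sum((X - X[idx]) ** 2, axis=1))
    return np.array(centers)

# Three tight, far-apart blobs of 10 points each
rng = np.random.default_rng(0)
blob_centers = np.array([[0.0, 0.0], [100.0, 0.0], [0.0, 100.0]])
X_toy = np.vstack([c + rng.normal(scale=0.1, size=(10, 2)) for c in blob_centers])
init = kmeans_plus_plus_init(X_toy, k=3, rng=1)
print(init.round(1))
```

With blobs this far apart, the D² weighting makes it overwhelmingly likely that the three seeds land in three different blobs, which is exactly the "spread out" behavior described above.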

**Step 2: Assignment Step** — Fix centroids, optimize assignments

Given fixed centroids $\{\mu_j\}$, assign each point to the cluster whose centroid is nearest:

$$c(i) = \arg\min_{j \in \{1,...,k\}} ||x_i - \mu_j||^2$$

This is the optimal assignment for the current centroids because assigning each point to its nearest center minimizes that point's contribution to $J$.

**Step 3: Update Step** — Fix assignments, optimize centroids

Given fixed assignments $\{c(i)\}$, recompute each centroid as the mean of its assigned points:

$$\mu_j = \frac{1}{|C_j|} \sum_{i: c(i)=j} x_i$$

where $|C_j|$ is the number of points assigned to cluster $j$. The mean minimizes the sum of squared distances to all points in the cluster (this can be shown by taking the derivative of $J$ with respect to $\mu_j$ and setting it to zero).

**Step 4: Repeat** steps 2-3 until convergence (assignments don't change) or max iterations reached.

#### Why This Works

Each step is guaranteed to decrease (or maintain) the objective $J$:
- The assignment step finds the best assignments for fixed centroids
- The update step finds the best centroids for fixed assignments

Since $J$ is bounded below by 0, never increases, and can take only finitely many values (there are finitely many possible assignments), the algorithm must converge in a finite number of iterations. However, it converges to a **local minimum**, not necessarily the global optimum, hence the importance of running multiple times with different initializations (`n_init` parameter).
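To see this guarantee in action, here is a bare-bones sketch of Lloyd's algorithm in NumPy that records $J$ after every iteration. `lloyd_kmeans` is a hypothetical helper; it uses random initialization from distinct data points and keeps a centroid in place if its cluster becomes empty, both simplifying assumptions. The recorded history should never increase (up to floating-point rounding), while the final result still depends on the starting centroids.

```python
import numpy as np

def lloyd_kmeans(X, k, max_iter=100, seed=0):
    """Plain Lloyd's algorithm; returns labels, centroids and J after each iteration."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    # Step 1: initialize with k distinct random data points (fancy indexing copies)
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    history = []
    labels = None
    for _ in range(max_iter):
        # Step 2: assign each point to its nearest centroid (squared Euclidean)
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        new_labels = d2.argmin(axis=1)
        # Step 3: move each centroid to the mean of its assigned points
        for j in range(k):
            members = X[new_labels == j]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[j] = members.mean(axis=0)
        # Objective J for the current assignments and updated centroids
        history.append(float(((X - centroids[new_labels]) ** 2).sum()))
        # Step 4: stop once assignments no longer change
        if labels is not None and np.array_equal(new_labels, labels):
            break
        labels = new_labels
    return labels, centroids, history

rng = np.random.default_rng(42)
X_toy = np.vstack([rng.normal(loc=c, scale=0.5, size=(50, 2))
                   for c in ([0, 0], [5, 5], [0, 5])])
labels, centroids, history = lloyd_kmeans(X_toy, k=3, seed=0)
print([round(j, 2) for j in history])
```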

#### Key Properties

- **Converges** to a local minimum (not guaranteed to find global optimum)
- **Sensitive to initialization**: Use `n_init` to run multiple times with different starting points and keep the best result
- **Assumes spherical clusters**: Works best when clusters are roughly equal-sized and spherical
- **Scales well**: $O(n \cdot k \cdot d \cdot i)$ where $n$ = samples, $k$ = clusters, $d$ = dimensions, $i$ = iterations

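For molecular fingerprints (2048 dimensions), k-means is a practical choice because of its speed, but its spherical-cluster assumption is only a rough approximation for sparse binary vectors, so treat the resulting clusters as a coarse partition of chemical space.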
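Because k-means uses Euclidean distance, it helps to see what that distance means for binary fingerprints: the squared Euclidean distance equals the Hamming distance (the number of differing bits), which ignores how many bits the two fingerprints share. The hypothetical helper below compares it with Tanimoto similarity on two toy pairs that each differ in exactly two bits.

```python
import numpy as np

def binary_distance_summary(a, b):
    """Compare squared Euclidean, Hamming, and Tanimoto measures for two binary vectors."""
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b, dtype=bool)
    intersection = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    hamming = np.logical_xor(a, b).sum()  # number of bits that differ
    sq_euclidean = ((a.astype(float) - b.astype(float)) ** 2).sum()
    # Tanimoto (= Jaccard index) for binary vectors; 1.0 for two empty vectors (assumption)
    tanimoto = intersection / union if union else 1.0
    return {"sq_euclidean": float(sq_euclidean),
            "hamming": int(hamming),
            "tanimoto": float(tanimoto)}

# Two toy pairs, each differing in exactly two bits
sparse_pair = ([1, 1, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0])  # 1 shared bit out of 3 set
dense_pair = ([1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 0, 1])   # 4 shared bits out of 6 set

for name, (a, b) in [("sparse", sparse_pair), ("dense", dense_pair)]:
    print(name, binary_distance_summary(a, b))
```

Both pairs have a squared Euclidean distance of 2, yet their Tanimoto similarities are 1/3 and 2/3. This is one reason Euclidean k-means clusters may not line up exactly with Tanimoto-based notions of chemical similarity.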

In [None]:
# Final k-means clustering with k=6
optimal_k = 6

kmeans_final = KMeans(
    n_clusters=optimal_k,
    random_state=42,
    n_init=20,  # More initializations for better convergence
    max_iter=500
)
cluster_labels = kmeans_final.fit_predict(X_cluster)

# Add cluster labels to DataFrame
df_valid['cluster'] = cluster_labels

print(f"K-Means Clustering Results (k={optimal_k})")
print("=" * 50)
print(f"\nCluster Distribution:")
for i in range(optimal_k):
    count = np.sum(cluster_labels == i)
    percentage = 100 * count / len(cluster_labels)
    print(f"  Cluster {i}: {count:,} molecules ({percentage:.1f}%)")

print(f"\nInertia (final): {kmeans_final.inertia_:,.0f}")
print(f"Silhouette Score: {silhouette_score(X_cluster, cluster_labels, sample_size=5000, random_state=42):.4f}")

---

### 5.4 Cluster Visualization on UMAP

After clustering in the high-dimensional fingerprint space (2048 dimensions), we need a way to **visualize** the results. Since we can't plot 2048 dimensions directly, we use the UMAP embedding computed earlier.

**Important**: K-means clustering was performed in the original high-dimensional space, not on the UMAP coordinates. UMAP is only used here for visualization. This means:

- Clusters that appear to overlap in 2D may be well-separated in 2048D
- Clusters that span multiple UMAP regions may still be coherent in the original space
- The visual boundaries don't perfectly match the actual cluster boundaries

The visualization below creates two views:
1. **Left panel**: Scatter plot with each cluster shown in a different color
2. **Right panel**: Same data with cluster centers (mean UMAP position) marked with X symbols

In [None]:
# Visualize clusters in UMAP space
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Cluster-colored UMAP
ax1 = axes[0]
cluster_colors = plt.cm.Set2(np.linspace(0, 1, optimal_k))

for i in range(optimal_k):
    mask = cluster_labels == i
    ax1.scatter(
        X_umap[mask, 0], X_umap[mask, 1],
        c=[cluster_colors[i]],
        label=f'Cluster {i} (n={np.sum(mask):,})',
        alpha=0.5,
        s=10
    )

ax1.set_xlabel('UMAP 1', fontsize=12)
ax1.set_ylabel('UMAP 2', fontsize=12)
ax1.set_title('K-Means Clusters in UMAP Space', fontsize=14)
ax1.legend(loc='best', fontsize=9)

# Cluster centroids projected to UMAP space (approximate)
ax2 = axes[1]
scatter = ax2.scatter(
    X_umap[:, 0], X_umap[:, 1],
    c=cluster_labels,
    cmap='Set2',
    alpha=0.4,
    s=8
)

# Mark approximate cluster centers in UMAP space
# (cx, cy avoid overwriting center_x/center_y from the quadrant analysis in 4.19)
for i in range(optimal_k):
    mask = cluster_labels == i
    cx = np.mean(X_umap[mask, 0])
    cy = np.mean(X_umap[mask, 1])
    ax2.scatter(cx, cy, c='black', s=200, marker='X', edgecolors='white', linewidth=2)
    ax2.annotate(f'{i}', (cx, cy), fontsize=12, fontweight='bold',
                 ha='center', va='center', color='white')

ax2.set_xlabel('UMAP 1', fontsize=12)
ax2.set_ylabel('UMAP 2', fontsize=12)
ax2.set_title('Cluster Centers in UMAP Space', fontsize=14)

plt.tight_layout()
plt.show()

**Visualization Notes:**
- K-means operates in high-dimensional fingerprint space
- UMAP is a 2D projection (clusters may overlap visually)
- Cluster centers (X markers) show approximate UMAP locations
- Some clusters span multiple UMAP regions (high-D structure)

---

### 5.5 Cluster Property Analysis

In [None]:
# Analyze properties within each cluster
cluster_stats = []

print("Cluster Property Summary")
print("=" * 80)

for i in range(optimal_k):
    mask = df_valid['cluster'] == i
    cluster_data = df_valid[mask]

    stats = {
        'Cluster': i,
        'Count': len(cluster_data),
        'Avg Gap (eV)': cluster_data['gap_eV'].mean(),
        'Std Gap (eV)': cluster_data['gap_eV'].std(),
        'Avg HOMO (eV)': cluster_data['homo_eV'].mean(),
        'Avg LUMO (eV)': cluster_data['lumo_eV'].mean(),
        'Avg Dipole (D)': cluster_data['mu'].mean(),
        'Avg Alpha': cluster_data['alpha'].mean(),
        'Avg Atoms': cluster_data['n_atoms'].mean(),
        'Avg Heavy': cluster_data['num_heavy_atoms'].mean(),
        'Aromatic %': 100 * (cluster_data['num_aromatic_rings'] > 0).mean(),
        'Avg Rings': cluster_data['num_rings'].mean()
    }
    cluster_stats.append(stats)

    print(f"\nCluster {i} (n={stats['Count']:,})")
    print("-" * 40)
    print(f"  Electronic: Gap={stats['Avg Gap (eV)']:.2f}±{stats['Std Gap (eV)']:.2f} eV, "
          f"HOMO={stats['Avg HOMO (eV)']:.2f} eV, LUMO={stats['Avg LUMO (eV)']:.2f} eV")
    print(f"  Physical:   μ={stats['Avg Dipole (D)']:.2f} D, α={stats['Avg Alpha']:.1f} Bohr³")
    print(f"  Structural: {stats['Avg Atoms']:.1f} atoms, {stats['Avg Heavy']:.1f} heavy, "
          f"{stats['Aromatic %']:.1f}% aromatic, {stats['Avg Rings']:.2f} rings")

# Create DataFrame for comparison
df_cluster_stats = pd.DataFrame(cluster_stats)
print("\n\nCluster Comparison Table:")
print(df_cluster_stats.to_string(index=False))

In [None]:
# Visualize cluster properties
fig, axes = plt.subplots(2, 3, figsize=(16, 10))

properties = ['Avg Gap (eV)', 'Avg Atoms', 'Avg Dipole (D)', 'Aromatic %', 'Avg Rings', 'Avg Alpha']
colors = plt.cm.Set2(np.linspace(0, 1, optimal_k))

for ax, prop in zip(axes.flatten(), properties):
    values = df_cluster_stats[prop].values
    bars = ax.bar(range(optimal_k), values, color=colors)
    ax.set_xlabel('Cluster', fontsize=11)
    ax.set_ylabel(prop, fontsize=11)
    ax.set_title(prop, fontsize=12)
    ax.set_xticks(range(optimal_k))

    # Add value labels on bars
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02 * max(values),
                f'{val:.1f}', ha='center', va='bottom', fontsize=9)

plt.suptitle('Property Comparison Across K-Means Clusters', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

---

### 5.6 HOMO-LUMO Gap Distribution by Cluster

In [None]:
# Box plots of gap by cluster
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Box plot
ax1 = axes[0]
cluster_gaps = [df_valid[df_valid['cluster'] == i]['gap_eV'].values for i in range(optimal_k)]
bp = ax1.boxplot(cluster_gaps, patch_artist=True)

for i, patch in enumerate(bp['boxes']):
    patch.set_facecolor(cluster_colors[i])
    patch.set_alpha(0.7)

ax1.set_xlabel('Cluster', fontsize=12)
ax1.set_ylabel('HOMO-LUMO Gap (eV)', fontsize=12)
ax1.set_title('Gap Distribution by Cluster', fontsize=14)
ax1.set_xticklabels(range(optimal_k))
ax1.grid(True, alpha=0.3, axis='y')

# Violin plot with individual distributions
ax2 = axes[1]
for i in range(optimal_k):
    mask = df_valid['cluster'] == i
    data = df_valid[mask]['gap_eV'].values

    # Draw one violin per cluster so each can take its cluster color
    parts = ax2.violinplot([data], positions=[i], showmeans=True, showextrema=True)
    for pc in parts['bodies']:
        pc.set_facecolor(cluster_colors[i])
        pc.set_alpha(0.7)

ax2.set_xlabel('Cluster', fontsize=12)
ax2.set_ylabel('HOMO-LUMO Gap (eV)', fontsize=12)
ax2.set_title('Gap Distribution (Violin Plot)', fontsize=14)
ax2.set_xticks(range(optimal_k))
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

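**Gap Distribution Insights:**
- The clusters usually show distinct gap distributions, which suggests that fingerprint similarity carries information about electronic structure
- Clusters with a larger share of aromatic compounds tend to have lower gaps, consistent with π-conjugation narrowing the HOMO-LUMO gap
- A wide spread within a cluster indicates structural diversity among its members
- Some clusters may show bimodal distributions, hinting at distinct sub-populations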
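Whether the clusters really differ in gap can be tested instead of judged by eye. One option is the Kruskal-Wallis H-test from SciPy, which makes no normality assumption. The sketch uses synthetic gaps for three hypothetical clusters. With the tutorial's data you would pass the `cluster_gaps` list from the box-plot cell as `kruskal(*cluster_gaps)`. With tens of thousands of molecules the p-value will almost always be tiny, so the per-cluster medians (the effect size) are the more informative output.

```python
import numpy as np
from scipy.stats import kruskal

rng = np.random.default_rng(0)
# Synthetic gap values (eV) for three hypothetical clusters (not QM9 data)
gaps_by_cluster = [
    rng.normal(6.8, 1.0, 500),
    rng.normal(6.8, 1.0, 500),
    rng.normal(5.5, 0.8, 500),
]

# Kruskal-Wallis H-test: null hypothesis = all clusters share one gap distribution
h_stat, p_value = kruskal(*gaps_by_cluster)
print(f"H = {h_stat:.1f}, p = {p_value:.2e}")

# Per-cluster medians show which clusters drive the difference
medians = [float(np.median(g)) for g in gaps_by_cluster]
print("Medians (eV):", np.round(medians, 2))
```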

---

### 5.7 Cluster Chemical Composition

In [None]:
# Analyze functional groups by cluster using MACCS keys
# RDKit MACCS fingerprints have 167 bits: bit k is MACCS key k (bit 0 is unused).
# Keys used here: 162=aromatic atom, 154=C=O, 84=NH2, 139=OH, 161=N, 164=O

maccs_keys_of_interest = {
    'Aromatic atom': 162,
    'Carbonyl (C=O)': 154,
    'NH2 group': 84,
    'Hydroxyl (OH)': 139,
    'Nitrogen': 161,
    'Oxygen': 164
}

print("Functional Group Prevalence by Cluster (%)")
print("=" * 80)

functional_group_data = {key: [] for key in maccs_keys_of_interest.keys()}
functional_group_data['Cluster'] = list(range(optimal_k))

for i in range(optimal_k):
    mask = df_valid['cluster'] == i
    cluster_maccs = maccs_fingerprints[mask]

    print(f"\nCluster {i} (n={np.sum(mask):,}):")
    for name, key_idx in maccs_keys_of_interest.items():
        prevalence = 100 * np.mean(cluster_maccs[:, key_idx])
        functional_group_data[name].append(prevalence)
        print(f"  {name:25s}: {prevalence:5.1f}%")

df_functional = pd.DataFrame(functional_group_data)
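Hard-coded MACCS key numbers are easy to get wrong, so it is worth checking them against RDKit before reading the percentages. The sketch below assumes `maccs_fingerprints` was built with RDKit's `MACCSkeys.GenMACCSKeys`, so that column k holds key k. It prints the SMARTS pattern RDKit uses for each key, then checks a molecule whose groups are already known (acetic acid).

```python
from rdkit import Chem
from rdkit.Chem import MACCSkeys

keys_to_check = (84, 139, 154, 161, 162, 164)

# RDKit stores (SMARTS, count threshold) for each MACCS key in MACCSkeys.smartsPatts
for key in keys_to_check:
    smarts, _count = MACCSkeys.smartsPatts[key]
    print(key, smarts)

# Sanity check on acetic acid: it has C=O, OH and O, but no N, NH2 or aromatic atoms
fp = MACCSkeys.GenMACCSKeys(Chem.MolFromSmiles("CC(=O)O"))
print("Number of bits:", fp.GetNumBits())  # 167: bit 0 unused, bits 1-166 are the keys
present = {key: fp.GetBit(key) for key in keys_to_check}
print(present)
```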

In [None]:
# Heatmap of functional group prevalence
fig, ax = plt.subplots(figsize=(12, 6))

# Prepare data for heatmap
heatmap_data = df_functional.drop('Cluster', axis=1).values.T
group_names = list(maccs_keys_of_interest.keys())

im = ax.imshow(heatmap_data, cmap='YlOrRd', aspect='auto')

# Add colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Prevalence (%)', fontsize=11)

# Set labels
ax.set_xticks(range(optimal_k))
ax.set_xticklabels([f'Cluster {i}' for i in range(optimal_k)])
ax.set_yticks(range(len(group_names)))
ax.set_yticklabels(group_names)
ax.set_xlabel('Cluster', fontsize=12)
ax.set_ylabel('Functional Group', fontsize=12)
ax.set_title('Functional Group Prevalence Heatmap by Cluster', fontsize=14)

# Add value annotations
for i in range(len(group_names)):
    for j in range(optimal_k):
        val = heatmap_data[i, j]
        color = 'white' if val > 50 else 'black'
        ax.text(j, i, f'{val:.0f}%', ha='center', va='center', color=color, fontsize=9)

plt.tight_layout()
plt.show()
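The per-cluster loop and the heatmap matrix can also come from a single `groupby`. In the sketch below, the fingerprints and labels are random placeholders rather than QM9 data. With the tutorial's objects, `fps` would be `maccs_fingerprints` and `labels` would be `cluster_labels`. The transposed result, `prevalence.T.values`, has the same groups-by-clusters layout as `heatmap_data`.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
# Hypothetical inputs: 1,000 random 167-bit MACCS-like vectors and 4 cluster labels
fps = rng.integers(0, 2, size=(1000, 167), dtype=np.uint8)
labels = rng.integers(0, 4, size=1000)
keys = {"Aromatic atom": 162, "C=O": 154, "NH2": 84, "OH": 139}

# Select the key columns once, then average each column within every cluster
prevalence = (
    pd.DataFrame(fps[:, list(keys.values())], columns=list(keys.keys()))
    .groupby(labels)
    .mean()
    .mul(100)
)
print(prevalence.round(1))
```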

---

### 5.8 Cluster Molecular Size Analysis

In [None]:
# Analyze molecule sizes within clusters
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram of atom counts per cluster
ax1 = axes[0]
for i in range(optimal_k):
    mask = df_valid['cluster'] == i
    data = df_valid[mask]['n_atoms'].values
    ax1.hist(data, bins=20, alpha=0.5, label=f'Cluster {i}', color=cluster_colors[i])

ax1.set_xlabel('Number of Atoms', fontsize=12)
ax1.set_ylabel('Count', fontsize=12)
ax1.set_title('Molecule Size Distribution by Cluster', fontsize=14)
ax1.legend(loc='upper right')
ax1.grid(True, alpha=0.3)

# Scatter: Size vs Gap colored by cluster
ax2 = axes[1]
for i in range(optimal_k):
    mask = df_valid['cluster'] == i
    ax2.scatter(
        df_valid[mask]['n_atoms'],
        df_valid[mask]['gap_eV'],
        c=[cluster_colors[i]],
        alpha=0.3,
        s=15,
        label=f'Cluster {i}'
    )

ax2.set_xlabel('Number of Atoms', fontsize=12)
ax2.set_ylabel('HOMO-LUMO Gap (eV)', fontsize=12)
ax2.set_title('Molecule Size vs Gap by Cluster', fontsize=14)
ax2.legend(loc='upper right', fontsize=8)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

**Size Analysis Insights:**
- Some clusters are dominated by small molecules and others by larger ones
- Larger molecules set more fingerprint bits, so size partly drives cluster assignment
- The gap-size relationship differs between clusters, reflecting different chemistry
- Within a family, smaller molecules often have larger gaps because they offer less room for conjugation
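A dense scatter plot makes it hard to judge size-gap trends, so a rank correlation computed separately for each cluster is a useful check. The sketch builds synthetic data in which the gap falls with size in cluster 0 only. With the tutorial's data, the same `groupby` would be applied to `df_valid`. `Series.corr(..., method="spearman")` relies on SciPy being installed.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
# Synthetic stand-in for df_valid: gap decreases with size in cluster 0 only
n = 600
df = pd.DataFrame({
    "cluster": np.repeat([0, 1], n // 2),
    "n_atoms": rng.integers(5, 30, size=n),
})
noise = rng.normal(0, 0.5, size=n)
df["gap_eV"] = np.where(df["cluster"] == 0, 9.0 - 0.1 * df["n_atoms"], 7.0) + noise

# Spearman rank correlation between size and gap, computed per cluster
per_cluster_rho = df.groupby("cluster")[["n_atoms", "gap_eV"]].apply(
    lambda g: g["n_atoms"].corr(g["gap_eV"], method="spearman")
)
print(per_cluster_rho.round(2))
```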

---

### 5.9 Representative Molecules per Cluster

In [None]:
# Find representative molecules (closest to cluster centroid)
from sklearn.metrics.pairwise import euclidean_distances

def get_cluster_representatives(X, labels, centroids, n_samples=3):
    """Find molecules closest to each cluster centroid."""
    representatives = {}
    for i, centroid in enumerate(centroids):
        mask = labels == i
        cluster_indices = np.where(mask)[0]
        cluster_data = X[mask]

        # Compute distances to centroid
        distances = euclidean_distances(cluster_data, centroid.reshape(1, -1)).flatten()

        # Get indices of closest molecules
        closest_idx = np.argsort(distances)[:n_samples]
        representatives[i] = cluster_indices[closest_idx]

    return representatives

representatives = get_cluster_representatives(
    X_cluster, cluster_labels, kmeans_final.cluster_centers_, n_samples=3
)

print("Representative Molecules per Cluster")
print("=" * 60)

for cluster_id, indices in representatives.items():
    print(f"\nCluster {cluster_id} Representatives:")
    print("-" * 40)
    for idx in indices:
        mol_data = df_valid.iloc[idx]
        print(f"  SMILES: {mol_data['canonical_smiles']}")
        print(f"  Gap: {mol_data['gap_eV']:.2f} eV, Atoms: {mol_data['n_atoms']}, "
              f"Aromatic rings: {mol_data['num_aromatic_rings']}")
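`get_cluster_representatives` recomputes distances cluster by cluster. A fitted scikit-learn `KMeans` object can supply them directly, because `KMeans.transform` returns the Euclidean distance from each sample to every centroid. The sketch uses synthetic 2-D blobs. With the tutorial's objects, `X` would be `X_cluster` and `kmeans` would be `kmeans_final`. The `representatives` helper is a hypothetical name.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# Synthetic 2-D data with three blobs (a stand-in for X_cluster)
X = np.vstack([rng.normal(c, 0.3, size=(50, 2)) for c in ((0, 0), (3, 0), (0, 3))])
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

# transform() gives the distance from every sample to every centroid
dist_to_centroids = kmeans.transform(X)           # shape (n_samples, n_clusters)
own_dist = dist_to_centroids[np.arange(len(X)), kmeans.labels_]

def representatives(labels, own_dist, n_samples=3):
    """Per cluster, return the row indices of the n_samples members closest to the centroid."""
    reps = {}
    for k in np.unique(labels):
        members = np.flatnonzero(labels == k)
        reps[int(k)] = members[np.argsort(own_dist[members])[:n_samples]]
    return reps

reps = representatives(kmeans.labels_, own_dist)
print(reps)
```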

In [None]:
# Visualize representative molecules
from rdkit.Chem import Draw
from PIL import Image

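n_cols = 3
n_rows = int(np.ceil(optimal_k / n_cols))  # enough rows for every cluster
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows), squeeze=False)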

for cluster_id, ax in enumerate(axes.flatten()):
    if cluster_id >= optimal_k:
        ax.axis('off')
        continue

    # Get representative molecules for this cluster
    indices = representatives[cluster_id]
    rep_mols = [df_valid.iloc[idx]['mol'] for idx in indices]

    # Create grid image with legends
    legends = [f"Gap: {df_valid.iloc[idx]['gap_eV']:.2f} eV" for idx in indices]
    img = Draw.MolsToGridImage(rep_mols, molsPerRow=3, subImgSize=(200, 200), legends=legends)

    # Convert PIL Image to numpy array for matplotlib compatibility
    if isinstance(img, Image.Image):
        img = np.array(img)

    ax.imshow(img)
    ax.set_title(f'Cluster {cluster_id} (n={np.sum(cluster_labels == cluster_id):,})', fontsize=12)
    ax.axis('off')

plt.suptitle('Representative Molecules from Each Cluster', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

---

### 5.10 Cluster Chemical Interpretation

In [None]:
print("Cluster Chemical Interpretation")
print("=" * 70)

# Analyze each cluster
for i in range(optimal_k):
    mask = df_valid['cluster'] == i
    cluster_data = df_valid[mask]

    # Compute key characteristics
    avg_gap = cluster_data['gap_eV'].mean()
    avg_atoms = cluster_data['n_atoms'].mean()
    aromatic_frac = (cluster_data['num_aromatic_rings'] > 0).mean()
    avg_rings = cluster_data['num_rings'].mean()
    avg_dipole = cluster_data['mu'].mean()

    # Get MACCS key prevalence for this cluster
    # (RDKit MACCS numbering: key 162 = aromatic atom, 154 = C=O, 84 = NH2)
    cluster_maccs = maccs_fingerprints[mask]
    aromatic_atom_prev = np.mean(cluster_maccs[:, 162])
    carbonyl_prev = np.mean(cluster_maccs[:, 154])
    amine_prev = np.mean(cluster_maccs[:, 84])

    print(f"\n{'='*70}")
    print(f"CLUSTER {i}: {np.sum(mask):,} molecules")
    print(f"{'='*70}")

    # Characterize based on properties
    if aromatic_frac > 0.5:
        print("  Type: AROMATIC compounds")
    elif avg_rings > 1:
        print("  Type: CYCLIC (non-aromatic)")
    elif avg_atoms < 15:
        print("  Type: SMALL aliphatic")
    else:
        print("  Type: LARGER aliphatic")

    print("\n  Key Properties:")
    print(f"    • Gap: {avg_gap:.2f} eV (typical QM9 range: 4-9 eV)")
    if avg_gap < 5.5:
        print("      → LOW gap: more conjugated, chemically softer (more reactive)")
    elif avg_gap > 7.5:
        print("      → HIGH gap: stable, transparent molecules")
    else:
        print("      → MEDIUM gap: typical organic molecules")

    print(f"    • Size: {avg_atoms:.1f} atoms average")
    print(f"    • Aromaticity: {100*aromatic_frac:.0f}% contain aromatic rings")
    print(f"    • Rings: {avg_rings:.1f} average")
    print(f"    • Polarity: μ = {avg_dipole:.2f} Debye")

    print("\n  Functional Groups (MACCS keys):")
    print(f"    • Aromatic atoms: {100*aromatic_atom_prev:.0f}%")
    print(f"    • Carbonyl (C=O): {100*carbonyl_prev:.0f}%")
    print(f"    • NH2 groups: {100*amine_prev:.0f}%")

    # Application suggestions (rough heuristics based on the cluster averages)
    print("\n  Potential Applications:")
    if avg_gap < 5.5 and aromatic_frac > 0.3:
        print("    → Building blocks for organic photovoltaics, OLEDs, dyes")
    elif avg_gap > 7 and avg_dipole < 1.5:
        print("    → Solvents, fuel additives")
    elif carbonyl_prev > 0.3:
        print("    → Pharmaceuticals, flavoring agents")
    elif amine_prev > 0.3:
        print("    → Drug scaffolds, chemical intermediates")
    else:
        print("    → General organic chemistry")
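The gap thresholds above are easier to picture as photon wavelengths, using λ = hc/E ≈ 1239.84 / E(eV) nm. Visible light spans roughly 400-700 nm (about 1.8-3.1 eV), so even QM9's "low-gap" clusters near 5.5 eV correspond to the ultraviolet. Treat this only as a sense of scale: a B3LYP HOMO-LUMO gap is not the same quantity as a measured optical gap.

```python
# Photon wavelength corresponding to an energy gap: lambda (nm) = h*c / E ≈ 1239.84 / E (eV)
HC_EV_NM = 1239.84  # h*c expressed in eV·nm

def gap_to_wavelength_nm(gap_ev: float) -> float:
    """Wavelength (nm) of a photon whose energy equals gap_ev."""
    return HC_EV_NM / gap_ev

for gap in (3.1, 5.5, 7.5):
    print(f"{gap:.1f} eV -> {gap_to_wavelength_nm(gap):.0f} nm")
# 3.1 eV -> 400 nm
# 5.5 eV -> 225 nm
# 7.5 eV -> 165 nm
```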

---

### 5.11 Cluster Stability Analysis

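After clustering, it's important to assess how **well-defined** our clusters are. The silhouette analysis in this section measures cohesion and separation; checking that the clusters are **stable** across random seeds or resamples is a complementary test. A good clustering should have: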
- Points that are clearly assigned to their cluster (not on boundaries)
- Clusters that are internally cohesive (points within are similar)
- Clusters that are well-separated from each other

#### Silhouette Analysis

The **silhouette coefficient** is a powerful metric for evaluating cluster quality at both the individual sample and cluster level. For each sample $i$, the silhouette score is:

$$s(i) = \frac{b(i) - a(i)}{\max(a(i), b(i))}$$

where:
- $a(i)$ = **mean intra-cluster distance**: average distance from sample $i$ to all other points in its cluster (measures cohesion)
- $b(i)$ = **mean nearest-cluster distance**: the smallest, over all other clusters, of the average distance from sample $i$ to the points of that cluster (measures separation)

**Interpretation of silhouette scores:**

| Score | Interpretation |
|-------|----------------|
| **≈ +1** | Sample is well matched to its own cluster and far from neighboring clusters (ideal) |
| **≈ 0** | Sample lies on or very close to the boundary between two clusters |
| **< 0** (down to -1) | Sample is closer, on average, to another cluster and may have been misassigned |
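To make $a(i)$ and $b(i)$ concrete, here is a direct translation of the formula, checked against scikit-learn's `silhouette_samples` on a small synthetic dataset. It is quadratic in the number of samples and meant only for building intuition. `silhouette_from_scratch` is a hypothetical helper name.

```python
import numpy as np
from sklearn.metrics import silhouette_samples
from sklearn.metrics.pairwise import euclidean_distances

def silhouette_from_scratch(X, labels):
    """Per-sample silhouette s(i) = (b - a) / max(a, b), following the definitions above."""
    D = euclidean_distances(X)
    s = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        same[i] = False                       # exclude the sample itself from a(i)
        if not same.any():                    # singleton cluster: s(i) is set to 0
            continue
        a = D[i, same].mean()
        # b(i): smallest mean distance to the members of any other cluster
        b = min(D[i, labels == k].mean() for k in np.unique(labels) if k != labels[i])
        s[i] = (b - a) / max(a, b)
    return s

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, (20, 3)),
               rng.normal(4, 1, (20, 3)),
               rng.normal((0, 4, 0), 1, (20, 3))])
labels = np.repeat([0, 1, 2], 20)
print(np.allclose(silhouette_from_scratch(X, labels), silhouette_samples(X, labels)))
```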

#### Silhouette Plot (Knife Plot)

The silhouette plot visualizes the silhouette coefficient for every sample, sorted within each cluster. This reveals:
- **Height of each band** (along the y-axis): number of samples in that cluster
- **Extent to the right**: the silhouette value, i.e. how well-separated each point is
- **Bars extending left of zero**: potentially misassigned samples

In [None]:
# Analyze how stable the clusters are using silhouette scores per sample
from sklearn.metrics import silhouette_samples

# Compute silhouette scores for each sample
silhouette_vals = silhouette_samples(X_cluster, cluster_labels)

# Add to DataFrame
df_valid['silhouette'] = silhouette_vals

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart of average silhouette per cluster
ax1 = axes[0]
avg_silhouettes = [silhouette_vals[cluster_labels == i].mean() for i in range(optimal_k)]
bars = ax1.bar(range(optimal_k), avg_silhouettes, color=cluster_colors)
ax1.axhline(y=np.mean(silhouette_vals), color='red', linestyle='--', label=f'Overall: {np.mean(silhouette_vals):.3f}')
ax1.set_xlabel('Cluster', fontsize=12)
ax1.set_ylabel('Average Silhouette Score', fontsize=12)
ax1.set_title('Cluster Cohesion (Silhouette Analysis)', fontsize=14)
ax1.set_xticks(range(optimal_k))
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')

# Add value labels
for bar, val in zip(bars, avg_silhouettes):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
             f'{val:.3f}', ha='center', va='bottom', fontsize=10)

# Silhouette plot (knife plot)
ax2 = axes[1]
y_lower = 10
for i in range(optimal_k):
    ith_silhouette = silhouette_vals[cluster_labels == i]
    ith_silhouette.sort()

    size_cluster = ith_silhouette.shape[0]
    y_upper = y_lower + size_cluster

    ax2.fill_betweenx(np.arange(y_lower, y_upper), 0, ith_silhouette,
                       facecolor=cluster_colors[i], alpha=0.7)

    # Label cluster
    ax2.text(-0.05, y_lower + 0.5 * size_cluster, str(i), fontsize=10)
    y_lower = y_upper + 10

ax2.axvline(x=np.mean(silhouette_vals), color='red', linestyle='--')
ax2.set_xlabel('Silhouette Coefficient', fontsize=12)
ax2.set_ylabel('Cluster', fontsize=12)
ax2.set_title('Silhouette Plot', fontsize=14)
ax2.set_xlim(-0.2, 1)
ax2.set_yticks([])

plt.tight_layout()
plt.show()

print(f"Overall silhouette score: {np.mean(silhouette_vals):.3f}")
print("Cluster-wise averages:")
for i in range(optimal_k):
    cluster_score = silhouette_vals[cluster_labels == i].mean()
    print(f"  Cluster {i}: {cluster_score:.3f}")
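`silhouette_samples` computes distances between all pairs of samples, so its cost grows quadratically. On the full QM9 set with long fingerprint vectors this can be very slow. For the overall score, `silhouette_score` accepts a `sample_size` argument and returns an estimate from a random subset. The sketch compares the exact and sampled scores on synthetic data that stands in for `X_cluster`. For the per-sample plot, one pragmatic option is to call `silhouette_samples` on a random subset of rows. Distances are then measured only within that subset, so treat the result as an approximation.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(0)
# Synthetic data standing in for X_cluster
X = np.vstack([rng.normal(c, 1.0, size=(2000, 10)) for c in (0.0, 3.0, 6.0)])
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)

# Exact score: pairwise distances over all 6,000 samples
exact = silhouette_score(X, labels)
# Estimate from a random subset of 1,000 samples; much cheaper on large datasets
estimate = silhouette_score(X, labels, sample_size=1000, random_state=0)
print(f"exact={exact:.3f}  sampled={estimate:.3f}")
```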

**Interpreting the Results:**
- **Higher cluster-wise scores** indicate more cohesive, well-separated clusters
- **Lower scores** suggest overlap with neighboring clusters
- **The band heights** in the silhouette plot show cluster sizes; ideally, bands have similar heights and most of each band extends past the overall mean (red dashed line)
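Negative silhouette values can be summarized per cluster, and the lowest-scoring molecules pulled out for inspection. This connects to the outlier-detection use listed in the summary. The sketch uses a small hypothetical table. With the tutorial's data, `df` would be `df_valid`, whose `cluster` and `silhouette` columns have the same names.

```python
import pandas as pd

# Hypothetical per-molecule silhouette values and cluster labels
df = pd.DataFrame({
    "cluster":    [0, 0, 0, 1, 1, 1, 2, 2],
    "silhouette": [0.6, 0.4, -0.1, 0.2, -0.3, 0.05, 0.7, 0.5],
})

# Share of each cluster whose members sit closer (on average) to another cluster
neg_share = (df["silhouette"] < 0).groupby(df["cluster"]).mean()
print((100 * neg_share).round(1))

# The lowest-scoring molecules are natural candidates for manual inspection
worst = df.nsmallest(2, "silhouette")
print(worst)
```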

---

### 5.12 Summary: Clustering and Unsupervised Learning

In [None]:
print("Section 5 Summary: Clustering and Unsupervised Learning")
print("=" * 60)

print(f"""
✓ What We Learned:
  ─────────────────
  - K-means effectively groups molecules by fingerprint similarity
  - Elbow method + silhouette score help choose optimal k
  - Clusters reveal chemical families (aromatics, small molecules, etc.)
  - UMAP visualization shows cluster structure in 2D

✓ Key Results:
  ─────────────
  - Optimal k: {optimal_k} clusters
  - Overall silhouette score: {np.mean(silhouette_vals):.3f}
  - Clusters differ in:
    • Size (atom count)
    • Aromaticity
    • Electronic properties (gap)
    • Functional group composition

✓ Cluster Characteristics Summary:
  ─────────────────────────────────
""")

for i in range(optimal_k):
    mask = df_valid['cluster'] == i
    count = np.sum(mask)
    aromatic = 100 * (df_valid[mask]['num_aromatic_rings'] > 0).mean()
    gap = df_valid[mask]['gap_eV'].mean()
    print(f"  Cluster {i}: n={count:5,}, Gap={gap:.1f}eV, Aromatic={aromatic:.0f}%")

print(f"""
✓ Variables Available:
  ─────────────────────
  - cluster_labels: array of shape {cluster_labels.shape}
  - df_valid['cluster']: cluster assignments
  - df_valid['silhouette']: per-molecule silhouette scores
  - kmeans_final: fitted KMeans model
  - df_cluster_stats: cluster statistics DataFrame

✓ Applications of Clustering:
  ────────────────────────────
  - Stratified train/test splitting
  - Diversity sampling for screening
  - Lead optimization (explore cluster space)
  - Scaffold hopping (move between clusters)
  - Outlier detection (low silhouette samples)

Next: Property Prediction (see separate tutorial)
""")

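Cluster labels can feed into train/test splitting in two different ways. A stratified split keeps every cluster represented in both sets. A group split holds out whole clusters, which is a harsher test of how a property model generalizes to unfamiliar chemistry. The features, targets and labels below are random placeholders. In practice they would be fingerprints, a property such as `gap_eV`, and `cluster_labels`.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit, train_test_split

rng = np.random.default_rng(0)
# Hypothetical features, targets and cluster labels for 1,000 molecules
X = rng.normal(size=(1000, 16))
y = rng.normal(size=1000)
clusters = rng.integers(0, 6, size=1000)

# (1) Stratified split: each cluster keeps roughly the same share in train and test
X_tr, X_te, y_tr, y_te, c_tr, c_te = train_test_split(
    X, y, clusters, test_size=0.2, stratify=clusters, random_state=0
)

# (2) Group split: whole clusters are held out of training
gss = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=0)
train_idx, test_idx = next(gss.split(X, y, groups=clusters))

shares = np.bincount(c_te, minlength=6) / np.bincount(clusters, minlength=6)
print("Stratified test share per cluster:", np.round(shares, 2))
print("Clusters held out by GroupShuffleSplit:", np.unique(clusters[test_idx]))
```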
---

## Section 6: Summary

### What We Learned

In this tutorial, we explored the QM9 dataset through:

1. **Data Loading**: Parsed XYZ files to extract molecular properties and structures
2. **EDA**: Analyzed distributions, correlations, and structure-property relationships
3. **Featurization**: Generated Morgan and MACCS fingerprints from SMILES
4. **Visualization**: Used PCA and UMAP to explore chemical space
5. **Clustering**: Applied K-means to identify molecular families

### Key Takeaways

- The HOMO-LUMO gap is a key electronic property; in QM9 its distribution is centered around 6-7 eV
- Molecular fingerprints effectively encode structural information for ML
- UMAP reveals meaningful clusters that correspond to chemical families
- Larger molecules tend to have somewhat smaller HOMO-LUMO gaps (more room for extended conjugation)

---

## Next Steps: Property Prediction

**Continue with:** [QM9 Property Prediction Tutorial](./QM9-Property-Prediction.md)

This tutorial covers:
- Train/test splitting for molecular data
- Regression models: Linear, Ridge, Random Forest, Kernel Ridge Regression
- Model evaluation and comparison
- Error analysis and chemical interpretation

---

## Exercise

### Task: Clustering Analysis with Different Parameters and Representations

Using the skills you learned in this tutorial, complete the following exercise:

#### Part A: Varying the Number of Clusters

1. **Parameter Exploration**: Run K-means clustering with k=4, k=8, and k=10 using Morgan fingerprints. Compare the silhouette scores and choose the best k for this dataset.

2. **UMAP Visualization**: Visualize each clustering result on a UMAP plot. Color points by cluster assignment.

3. **Property Analysis**: For your best clustering (highest silhouette score):
   - Calculate the mean HOMO-LUMO gap for each cluster
   - Identify which cluster has the smallest average gap (potentially most interesting for organic electronics)
   - Report the percentage of aromatic molecules in each cluster

4. **Chemical Interpretation**: Write 2-3 sentences explaining what chemical features distinguish the cluster with the smallest gap from others.
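A possible starting scaffold for Part A is shown below. It runs on synthetic binary "fingerprints" with four planted groups, so it does not pre-empt your own results. `sweep_k` is a hypothetical helper. You would pass `morgan_fingerprints` (or a PCA-reduced version of it) as `X`. The silhouette is estimated on a subsample to keep the runtime manageable.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def sweep_k(X, k_values, sample_size=5000, seed=0):
    """Fit K-means for each k and return {k: (labels, silhouette estimate)}."""
    results = {}
    for k in k_values:
        labels = KMeans(n_clusters=k, n_init=10, random_state=seed).fit_predict(X)
        score = silhouette_score(X, labels, sample_size=min(sample_size, len(X)),
                                 random_state=seed)
        results[k] = (labels, score)
    return results

# Synthetic binary "fingerprints" with 4 planted groups (stand-in for morgan_fingerprints)
rng = np.random.default_rng(0)
prototypes = rng.random((4, 256)) < 0.2
X = np.vstack([np.logical_xor(p, rng.random((300, 256)) < 0.05)
               for p in prototypes]).astype(float)

results = sweep_k(X, [4, 8, 10])
for k, (_, score) in results.items():
    print(f"k={k:2d}  silhouette≈{score:.3f}")
best_k = max(results, key=lambda k: results[k][1])
print("Best k:", best_k)
```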

#### Part B: Comparing Molecular Representations (Advanced)

The tutorial used **Morgan fingerprints**, which encode 2D molecular topology (atoms and bonds). However, for geometry-dependent properties such as the HOMO-LUMO gap, representations that encode **3D geometry** may be more appropriate.

**Task**: Compare different molecular representations for PCA, UMAP visualization, and clustering:

| Representation | Type | What it Encodes | Library |
|----------------|------|-----------------|---------|
| **Morgan/ECFP** | 2D | Atom environments, connectivity | RDKit |
| **Behler-Parrinello Symmetry Functions (ACSF)** | 3D | Local atomic environments via radial/angular functions | DScribe, AMP |
| **SOAP (Smooth Overlap of Atomic Positions)** | 3D | Local atomic density with rotational invariance | DScribe, quippy |
| **Coulomb Matrix** | 3D | Nuclear charges and distances | DScribe |
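The Coulomb matrix row of the table can be written out in a few lines, which shows what "nuclear charges and distances" means concretely. Diagonal entries are $0.5\,Z_i^{2.4}$ and off-diagonal entries are $Z_i Z_j / |\mathbf{R}_i - \mathbf{R}_j|$. The matrix depends on atom order, so a common permutation-invariant summary is its eigenvalue spectrum sorted by magnitude and zero-padded to a fixed length. This from-scratch sketch is for intuition; DScribe's `CoulombMatrix` is the practical choice. The water geometry is approximate and illustrative, and distances stay in whatever units the coordinates use (Å here).

```python
import numpy as np

def coulomb_matrix(Z, R):
    """Coulomb matrix: C_ii = 0.5 * Z_i**2.4, C_ij = Z_i * Z_j / |R_i - R_j| (i != j)."""
    Z = np.asarray(Z, dtype=float)
    R = np.asarray(R, dtype=float)
    dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
    with np.errstate(divide="ignore"):
        C = np.outer(Z, Z) / dist          # diagonal is inf here and is overwritten next
    np.fill_diagonal(C, 0.5 * Z ** 2.4)
    return C

def sorted_eigenvalues(C, size):
    """Permutation-invariant descriptor: eigenvalues sorted by decreasing magnitude, zero-padded."""
    eig = np.linalg.eigvalsh(C)
    eig = eig[np.argsort(-np.abs(eig))]
    return np.pad(eig, (0, size - len(eig)))

# Water, approximate geometry in Å (illustrative, not taken from QM9)
Z = [8, 1, 1]
R = [[0.0, 0.0, 0.117], [0.0, 0.757, -0.469], [0.0, -0.757, -0.469]]
C = coulomb_matrix(Z, R)
print(np.round(C, 2))
print(sorted_eigenvalues(C, size=29))   # 29 = largest atom count in a QM9 molecule
```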

**Steps**:
1. Install DScribe: `pip install dscribe`
2. Extract 3D coordinates from the QM9 XYZ files (already parsed in `molecules_raw`)
3. Generate Behler-Parrinello or SOAP descriptors for each molecule
4. Repeat the PCA, UMAP, and K-means analysis with the new representations
5. Compare: Do 3D representations produce different clusters? Do they correlate better with electronic properties?

**Example using DScribe for SOAP**:

In [None]:
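import numpy as np
from ase import Atoms
from dscribe.descriptors import SOAP

# Create SOAP descriptor (r_cut/n_max/l_max are DScribe 2.x names;
# DScribe 1.x used rcut/nmax/lmax)
soap = SOAP(
    species=["C", "H", "O", "N", "F"],
    r_cut=5.0,
    n_max=8,
    l_max=6,
    average="inner"  # Average over atoms to get one molecule-level vector
)
print("SOAP features per molecule:", soap.get_number_of_features())

# Convert a parsed QM9 molecule to an ASE Atoms object
# (assumes each entry of molecules_raw has 'atoms' and 'coordinates' keys)
def mol_to_ase(mol_data):
    return Atoms(
        symbols=mol_data['atoms'],
        positions=mol_data['coordinates']
    )

# Generate SOAP descriptors: one row per molecule.
# The vectors are long, so consider starting with a random subset of molecules_raw.
systems = [mol_to_ase(mol) for mol in molecules_raw]
soap_fingerprints = soap.create(systems, n_jobs=4)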
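For step 5, two numbers make the comparison concrete. The adjusted Rand index (ARI) between the cluster assignments from two representations is 1 for identical partitions (up to relabeling) and about 0 for chance agreement. The fraction of gap variance explained by cluster membership shows which representation separates molecules by HOMO-LUMO gap. The sketch runs on synthetic descriptors where only "B" carries the structure that drives the gap. In the exercise, `desc_a` and `desc_b` would be `morgan_fingerprints` and `soap_fingerprints`, and `gap` would be `df_valid['gap_eV']`, assuming all three are row-aligned. The scaling, PCA and K-means pipeline is one reasonable choice, not the only one. `gap_variance_explained` and `cluster_descriptors` are hypothetical helper names.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_rand_score
from sklearn.preprocessing import StandardScaler

def gap_variance_explained(labels, gap):
    """Fraction of gap variance explained by cluster membership (1 - within SS / total SS)."""
    gap = np.asarray(gap, dtype=float)
    total_ss = np.sum((gap - gap.mean()) ** 2)
    within_ss = sum(np.sum((gap[labels == k] - gap[labels == k].mean()) ** 2)
                    for k in np.unique(labels))
    return 1.0 - within_ss / total_ss

def cluster_descriptors(X, k, n_components=20, seed=0):
    """Standardize, reduce with PCA, then run K-means on the reduced vectors."""
    Z = StandardScaler().fit_transform(X)
    Z = PCA(n_components=min(n_components, Z.shape[1]), random_state=seed).fit_transform(Z)
    return KMeans(n_clusters=k, n_init=10, random_state=seed).fit_predict(Z)

# Synthetic stand-ins: two descriptor sets and a gap that depends on the hidden groups
rng = np.random.default_rng(0)
group = rng.integers(0, 3, size=900)
desc_a = rng.normal(size=(900, 50))                           # unrelated to the gap
desc_b = rng.normal(size=(900, 50)) + 4.0 * group[:, None]    # encodes the groups
gap = 7.0 - 1.0 * group + rng.normal(0, 0.3, size=900)

labels_a = cluster_descriptors(desc_a, k=3)
labels_b = cluster_descriptors(desc_b, k=3)
print("ARI(A, B):", round(adjusted_rand_score(labels_a, labels_b), 3))
print("Gap variance explained, A:", round(gap_variance_explained(labels_a, gap), 3))
print("Gap variance explained, B:", round(gap_variance_explained(labels_b, gap), 3))
```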

**Discussion Questions**:
- How do the UMAP embeddings differ between 2D (Morgan) and 3D (SOAP) representations?
- Which representation produces clusters that better separate molecules by HOMO-LUMO gap?
- For which properties might 2D fingerprints be sufficient, and for which are 3D representations essential?

**Deliverables:**
- Silhouette scores for k=4, 8, 10 with Morgan fingerprints
- UMAP plots colored by cluster
- (Advanced) Comparison of Morgan vs SOAP/Behler-Parrinello representations
- A brief written interpretation of your findings

**Hint:** You can use the `df_valid` DataFrame, `morgan_fingerprints` array, and `molecules_raw` list from the tutorial.