# Module 2: Data Foundation

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NabKh/ML-for-Materials-Science/blob/main/Tutorial-07-ML-Discovery/notebooks/02_data_foundation.ipynb)
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/NabKh/ML-for-Materials-Science/main?labpath=Tutorial-07-ML-Discovery/notebooks/02_data_foundation.ipynb)

---

> **Before You Start:** Please check the [INSTALLATION_GUIDE.md](../../INSTALLATION_GUIDE.md) for setup instructions. For Google Colab:
> ```python
> !pip install pymatgen matminer shap -q
> ```
> Then restart the runtime (Runtime ‚Üí Restart runtime).

---

## üéØ Learning Objectives

By the end of this module, you will be able to:

1. **Access** the Materials Project database using the API
2. **Query** materials by properties (band gap, formation energy, etc.)
3. **Clean** messy materials data (handle missing values, outliers)
4. **Split** data properly to avoid leakage
5. **Understand** the importance of data quality for ML

---

**‚è±Ô∏è Estimated time: 60 minutes**

**üìö Difficulty: üü¢ Beginner**

## üì¶ Setup

### API Key Required!

To access the Materials Project, you need a free API key:

1. Go to [materialsproject.org](https://materialsproject.org)
2. Create an account (free)
3. Go to Dashboard ‚Üí API ‚Üí Generate API Key
4. Copy your key and paste it below

In [1]:
# Core libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Create figures directory
os.makedirs('figures', exist_ok=True)

# Materials Project API
from mp_api.client import MPRester

# Pymatgen for materials analysis
from pymatgen.core import Composition, Structure

# Scikit-learn
from sklearn.model_selection import train_test_split, GroupShuffleSplit

# Interactive widgets
import ipywidgets as widgets
from IPython.display import display, HTML

# Plotting style
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (10, 6)

print("‚úÖ Libraries imported successfully!")

‚úÖ Libraries imported successfully!


In [2]:
# Enter your Materials Project API key here
# Get one at: https://materialsproject.org/api

MP_API_KEY = "lE8Uu9bCjX256qMEAAULD4FBT0HLsrPP"  # Replace with your key!

# Verify the key works
try:
    with MPRester(MP_API_KEY) as mpr:
        # Simple test query
        test = mpr.materials.summary.search(material_ids=["mp-149"], fields=["formula_pretty"])
        print(f"‚úÖ API key valid! Connected to Materials Project.")
        print(f"   Test query: {test[0].formula_pretty}")
except Exception as e:
    print(f"‚ùå API key error: {e}")
    print("   Please get a valid key from materialsproject.org/api")

Retrieving SummaryDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

‚úÖ API key valid! Connected to Materials Project.
   Test query: Si


<cell_type>markdown</cell_type>---

## 1. Introduction to Materials Databases

### üìñ Theory

<div style="background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%); padding: 20px; border-radius: 10px; border-left: 4px solid #6366f1;">

**Materials databases** are repositories of computed or experimental material properties. They enable ML by providing the large datasets needed for training.

| Database | Size | Focus | Access |
|----------|------|-------|--------|
| Materials Project | ~150,000 | Inorganic | API (free) |
| AFLOW | ~3.5M | Inorganic | API (free) |
| OQMD | ~1M | Inorganic | Download |
| JARVIS-DFT | ~75,000 | 2D + bulk | API (free) |
| NOMAD | ~12M | Various | API (free) |

</div>

### Key Material Properties

**Band Gap ($E_g$)**: The energy difference between the valence band maximum (VBM) and conduction band maximum (CBM):

$$E_g = E_{CBM} - E_{VBM}$$

- $E_g = 0$: Metal (conductor)
- $0 < E_g < 3$ eV: Semiconductor
- $E_g > 3$ eV: Insulator

**Formation Energy ($\Delta H_f$)**: The energy change when forming a compound from its constituent elements in their standard states:

$$\Delta H_f = E_{compound} - \sum_i n_i \cdot E_i^{ref}$$

where $n_i$ is the number of atoms of element $i$ and $E_i^{ref}$ is the reference energy per atom.

- $\Delta H_f < 0$: Compound is stable relative to elements
- More negative = more stable

**Energy Above Hull ($E_{hull}$)**: The energy above the thermodynamic ground state (convex hull):

$$E_{hull} = E_{compound} - E_{hull}^{ref}$$

- $E_{hull} = 0$: On the convex hull (thermodynamically stable)
- $E_{hull} > 0$: Metastable (may decompose into more stable phases)

In this tutorial, we focus on the **Materials Project** due to its:
- High data quality (consistent DFT calculations)
- Excellent API (MPRester)
- Rich property data
- Active maintenance

---

## 2. Querying the Materials Project

### Basic Queries

The Materials Project API lets you search for materials by various criteria.

In [3]:
# Example 1: Get a specific material by ID
with MPRester(MP_API_KEY) as mpr:
    # Silicon (diamond structure)
    si = mpr.materials.summary.search(
        material_ids=["mp-149"],
        fields=["material_id", "formula_pretty", "band_gap", "formation_energy_per_atom",
                "density", "symmetry"]
    )[0]
    
print("Silicon (mp-149):")
print(f"  Formula: {si.formula_pretty}")
print(f"  Band gap: {si.band_gap:.3f} eV")
print(f"  Formation energy: {si.formation_energy_per_atom:.3f} eV/atom")
print(f"  Density: {si.density:.3f} g/cm¬≥")
print(f"  Space group: {si.symmetry.symbol}")

Retrieving SummaryDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Silicon (mp-149):
  Formula: Si
  Band gap: 0.611 eV
  Formation energy: 0.000 eV/atom
  Density: 2.313 g/cm¬≥
  Space group: Fd-3m


In [4]:
# Example 2: Search for materials with specific properties
with MPRester(MP_API_KEY) as mpr:
    # Find semiconductors with band gap between 1-2 eV
    semiconductors = mpr.materials.summary.search(
        band_gap=(1.0, 2.0),  # eV
        is_stable=True,       # Only thermodynamically stable
        fields=["material_id", "formula_pretty", "band_gap", "nelements"],
        num_chunks=1,
        chunk_size=100  # Limit to 100 results
    )

print(f"Found {len(semiconductors)} semiconductors with band gap 1-2 eV")
print("\nFirst 10:")
for mat in semiconductors[:10]:
    print(f"  {mat.material_id}: {mat.formula_pretty:15} Eg = {mat.band_gap:.3f} eV")

Retrieving SummaryDoc documents:   0%|          | 0/100 [00:00<?, ?it/s]

Found 100 semiconductors with band gap 1-2 eV

First 10:
  mp-560328: Ag15P4S16Cl3    Eg = 1.289 eV
  mp-861942: Ag2GePbS4       Eg = 1.370 eV
  mp-690687: Ag2H2IOF        Eg = 1.459 eV
  mp-23485: Ag2HgI4         Eg = 1.222 eV
  mp-556866: Ag2HgSI2        Eg = 1.221 eV
  mp-27966: Ag2Mo2O7        Eg = 1.901 eV
  mp-1190325: Ag2P2PdO7       Eg = 1.167 eV
  mp-707138: Ag2PHO4         Eg = 1.327 eV
  mp-13956: Ag2PSe3         Eg = 1.187 eV
  mp-5625: Ag2SO4          Eg = 1.229 eV


In [5]:
# Example 3: Search by chemical system (elements)
with MPRester(MP_API_KEY) as mpr:
    # All compounds containing only Si and O
    si_o_compounds = mpr.materials.summary.search(
        chemsys="Si-O",
        fields=["material_id", "formula_pretty", "band_gap", "formation_energy_per_atom"],
    )

print(f"Found {len(si_o_compounds)} compounds in the Si-O system")
print("\nAll Si-O compounds:")
for mat in si_o_compounds:
    print(f"  {mat.material_id}: {mat.formula_pretty:10} Eg={mat.band_gap:.2f} eV, Ef={mat.formation_energy_per_atom:.3f} eV/atom")

Retrieving SummaryDoc documents:   0%|          | 0/343 [00:00<?, ?it/s]

Found 343 compounds in the Si-O system

All Si-O compounds:
  mp-1194828: Si17O37    Eg=0.09 eV, Ef=-2.919 eV/atom
  mp-1199711: Si18O29    Eg=0.09 eV, Ef=-2.814 eV/atom
  mp-1063118: Si2O       Eg=0.00 eV, Ef=-0.838 eV/atom
  mp-1208867: Si2O       Eg=0.00 eV, Ef=0.688 eV/atom
  mp-1179195: Si2O3      Eg=0.99 eV, Ef=-2.165 eV/atom
  mp-862998: Si2O5      Eg=3.05 eV, Ef=-2.227 eV/atom
  mp-32566: Si3O       Eg=0.16 eV, Ef=-0.776 eV/atom
  mp-32881: Si3O       Eg=0.19 eV, Ef=-0.782 eV/atom
  mp-1250755: Si3O7      Eg=0.20 eV, Ef=-2.345 eV/atom
  mp-638900: Si3O7      Eg=0.10 eV, Ef=-2.027 eV/atom
  mp-1219421: Si3O8      Eg=0.07 eV, Ef=-2.105 eV/atom
  mp-1221354: Si48O107   Eg=0.03 eV, Ef=-2.855 eV/atom
  mp-1179275: Si4O9      Eg=0.00 eV, Ef=-2.773 eV/atom
  mp-1179651: Si5O8      Eg=3.45 eV, Ef=-2.718 eV/atom
  mp-1173536: Si6O13     Eg=0.08 eV, Ef=-2.277 eV/atom
  mp-530027: Si6O13     Eg=2.08 eV, Ef=-2.760 eV/atom
  mp-673849: Si6O13     Eg=0.05 eV, Ef=-2.183 eV/atom
  mp-731864: S

<cell_type>markdown</cell_type>### üîç Building Custom Queries

Let's explore how to build queries with different parameters.

In [None]:
# Example: Custom query with different parameters
# You can modify these parameters to explore different material spaces

query_params = {
    'band_gap_min': 0.5,      # eV
    'band_gap_max': 2.0,      # eV  
    'max_elements': 3,         # Binary and ternary only
    'stable_only': True        # Only thermodynamically stable
}

print(f"üîç Query Parameters:")
print(f"   Band gap range: {query_params['band_gap_min']} - {query_params['band_gap_max']} eV")
print(f"   Max elements: {query_params['max_elements']}")
print(f"   Stable only: {query_params['stable_only']}")
print("\n‚è≥ Querying Materials Project...")

with MPRester(MP_API_KEY) as mpr:
    custom_results = mpr.materials.summary.search(
        band_gap=(query_params['band_gap_min'], query_params['band_gap_max']),
        nelements=(1, query_params['max_elements']),
        is_stable=query_params['stable_only'] if query_params['stable_only'] else None,
        fields=["material_id", "formula_pretty", "band_gap", "nelements"],
        num_chunks=1,
        chunk_size=50
    )

print(f"\n‚úÖ Found {len(custom_results)} materials!")
print("\nSample results (first 15):")
print("-" * 50)
for mat in custom_results[:15]:
    print(f"  {mat.formula_pretty:15} | Eg = {mat.band_gap:.3f} eV | {mat.nelements} elements")

print("\nüí° Try modifying the query_params above to explore different material spaces!")

---

## 3. Building a Dataset for ML

Let's build a real dataset for band gap prediction.

In [7]:
# Build a dataset of semiconductors and insulators
print("üîÑ Building dataset from Materials Project...")
print("   This may take a minute...\n")

with MPRester(MP_API_KEY) as mpr:
    # Query materials with band gap > 0 (not metals)
    materials = mpr.materials.summary.search(
        band_gap=(0.1, 6.0),  # Non-metals
        nelements=(2, 4),      # Binary to quaternary
        is_stable=True,        # Thermodynamically stable
        fields=[
            "material_id",
            "formula_pretty",
            "composition",
            "band_gap",
            "formation_energy_per_atom",
            "energy_above_hull",
            "density",
            "volume",
            "nelements",
            "nsites",
            "symmetry",
        ],
        num_chunks=5,
        chunk_size=500  # Get up to 2500 materials
    )

print(f"‚úÖ Downloaded {len(materials)} materials")

üîÑ Building dataset from Materials Project...
   This may take a minute...



   [31mnelements[39m
To ensure long term support, please use their [34m`search`[39m equivalents:
   [34mnum_elements[39m
Please see the documentation:
    [34m`search`: https://materialsproject.github.io/api/_autosummary/mp_api.client.routes.materials.summary.SummaryRester.html#mp_api.client.routes.materials.summary.SummaryRester.search[39m
   [31m`_search`: https://api.materialsproject.org/redoc#tag/Materials-Summary/operation/search_materials_summary__get[39m


Retrieving SummaryDoc documents:   0%|          | 0/2500 [00:00<?, ?it/s]

‚úÖ Downloaded 2500 materials


In [8]:
# Convert to DataFrame
data = []
for mat in materials:
    data.append({
        'material_id': mat.material_id,
        'formula': mat.formula_pretty,
        'composition': mat.composition,
        'band_gap': mat.band_gap,
        'formation_energy': mat.formation_energy_per_atom,
        'energy_above_hull': mat.energy_above_hull,
        'density': mat.density,
        'volume': mat.volume,
        'nelements': mat.nelements,
        'nsites': mat.nsites,
        'spacegroup': mat.symmetry.symbol if mat.symmetry else None,
    })

df = pd.DataFrame(data)
print(f"\nDataFrame shape: {df.shape}")
print(f"\nColumns: {list(df.columns)}")
df.head()


DataFrame shape: (2500, 11)

Columns: ['material_id', 'formula', 'composition', 'band_gap', 'formation_energy', 'energy_above_hull', 'density', 'volume', 'nelements', 'nsites', 'spacegroup']


Unnamed: 0,material_id,formula,composition,band_gap,formation_energy,energy_above_hull,density,volume,nelements,nsites,spacegroup
0,mp-11107,Ac2O3,"(Ac, O)",3.5226,-3.737668,0.0,9.10913,91.511224,2,5,P-3m1
1,mp-32800,Ac2S3,"(Ac, S)",2.2962,-2.493064,0.0,6.535149,1118.407852,2,40,I-42d
2,mp-1183115,AcAlO3,"(Ac, Al, O)",4.1024,-3.690019,0.0,8.72823,57.451413,3,5,Pm-3m
3,mp-27972,AcBr3,"(Ac, Br)",4.1033,-2.494519,0.0,5.679086,272.928947,2,8,P6_3/m
4,mp-30274,AcBrO,"(Ac, Br, O)",4.241,-3.396186,0.0,7.65229,140.13941,3,6,P4/nmm


---

## 4. Data Cleaning

Real-world data is messy! Let's clean it up.

In [9]:
# Check for missing values
print("Missing values:")
print(df.isnull().sum())
print(f"\nTotal rows: {len(df)}")

Missing values:
material_id          0
formula              0
composition          0
band_gap             0
formation_energy     0
energy_above_hull    0
density              0
volume               0
nelements            0
nsites               0
spacegroup           0
dtype: int64

Total rows: 2500


In [None]:
# Check data distributions with professional styling
fig, axes = plt.subplots(2, 2, figsize=(14, 11), facecolor='white')

# Color palette - consistent with other figures
colors = {
    'primary': '#6366f1',      # Indigo
    'secondary': '#0ea5e9',    # Sky blue
    'tertiary': '#10b981',     # Green
    'accent': '#f59e0b',       # Orange
    'text': '#1e293b',
    'grid': '#e2e8f0'
}

# Band gap distribution
ax1 = axes[0, 0]
n, bins, patches = ax1.hist(df['band_gap'], bins=50, color=colors['primary'], 
                            alpha=0.8, edgecolor='white', linewidth=0.5)
ax1.set_xlabel('Band Gap (eV)', fontsize=12, color=colors['text'])
ax1.set_ylabel('Count', fontsize=12, color=colors['text'])
ax1.set_title('Band Gap Distribution', fontsize=14, fontweight='bold', color=colors['text'])
median_val = df['band_gap'].median()
ax1.axvline(median_val, color='#ef4444', linestyle='--', linewidth=2, 
            label=f'Median: {median_val:.2f} eV')
ax1.legend(fontsize=10, framealpha=0.95)
ax1.set_facecolor('white')
ax1.grid(True, alpha=0.3, color=colors['grid'])

# Add classification regions
ax1.axvspan(0, 0.5, alpha=0.1, color='red', label='Metal-like')
ax1.axvspan(0.5, 3.0, alpha=0.1, color='orange')
ax1.axvspan(3.0, 6.0, alpha=0.1, color='blue')
ax1.text(1.5, ax1.get_ylim()[1]*0.9, 'Semiconductors', fontsize=9, ha='center', color=colors['accent'])
ax1.text(4.5, ax1.get_ylim()[1]*0.9, 'Insulators', fontsize=9, ha='center', color=colors['primary'])

# Formation energy distribution
ax2 = axes[0, 1]
ax2.hist(df['formation_energy'], bins=50, color=colors['secondary'], 
         alpha=0.8, edgecolor='white', linewidth=0.5)
ax2.set_xlabel('Formation Energy (eV/atom)', fontsize=12, color=colors['text'])
ax2.set_ylabel('Count', fontsize=12, color=colors['text'])
ax2.set_title('Formation Energy Distribution', fontsize=14, fontweight='bold', color=colors['text'])
ax2.axvline(0, color='#ef4444', linestyle='-', linewidth=2, alpha=0.7)
ax2.text(0.2, ax2.get_ylim()[1]*0.8, 'Unstable\n(vs. elements)', fontsize=9, color='#ef4444')
ax2.text(-2, ax2.get_ylim()[1]*0.8, 'Stable\n(vs. elements)', fontsize=9, color=colors['tertiary'])
ax2.set_facecolor('white')
ax2.grid(True, alpha=0.3, color=colors['grid'])

# Number of elements - improved bar chart
ax3 = axes[1, 0]
element_counts = df['nelements'].value_counts().sort_index()
bars = ax3.bar(element_counts.index.astype(str), element_counts.values, 
               color=[colors['secondary'], colors['tertiary'], colors['accent']][:len(element_counts)],
               edgecolor='white', linewidth=1.5, alpha=0.8)
ax3.set_xlabel('Number of Elements', fontsize=12, color=colors['text'])
ax3.set_ylabel('Count', fontsize=12, color=colors['text'])
ax3.set_title('Compositional Complexity', fontsize=14, fontweight='bold', color=colors['text'])

# Add labels on bars
for bar, count in zip(bars, element_counts.values):
    ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 20, 
             f'{count}', ha='center', va='bottom', fontsize=11, fontweight='bold', color=colors['text'])

# Add element type labels
labels = ['Binary', 'Ternary', 'Quaternary']
for bar, label in zip(bars, labels[:len(bars)]):
    ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height()/2, 
             label, ha='center', va='center', fontsize=10, color='white', fontweight='bold')

ax3.set_facecolor('white')
ax3.grid(True, alpha=0.3, axis='y', color=colors['grid'])

# Density vs Band Gap scatter plot
ax4 = axes[1, 1]
scatter = ax4.scatter(df['density'], df['band_gap'], 
                      c=df['nelements'], cmap='viridis', 
                      alpha=0.6, s=40, edgecolors='white', linewidth=0.3)
ax4.set_xlabel('Density (g/cm¬≥)', fontsize=12, color=colors['text'])
ax4.set_ylabel('Band Gap (eV)', fontsize=12, color=colors['text'])
ax4.set_title('Density vs Band Gap', fontsize=14, fontweight='bold', color=colors['text'])
cbar = plt.colorbar(scatter, ax=ax4)
cbar.set_label('Number of Elements', fontsize=11, color=colors['text'])
ax4.set_facecolor('white')
ax4.grid(True, alpha=0.3, color=colors['grid'])

# Add trend annotation
ax4.text(12, 5.5, 'Higher density materials\ntend to have varied band gaps', 
         fontsize=9, ha='center', color=colors['text'], style='italic',
         bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor=colors['grid'], alpha=0.9))

plt.tight_layout()
plt.savefig('figures/02_data_distributions.png', dpi=200, bbox_inches='tight', facecolor='white')
plt.show()

print("\nFigure saved to figures/02_data_distributions.png")

In [11]:
# Data cleaning steps
print("Data Cleaning Steps:")
print("="*50)

# 1. Remove rows with missing values
df_clean = df.dropna()
print(f"1. After removing missing values: {len(df_clean)} rows")

# 2. Remove outliers (e.g., unrealistic densities)
df_clean = df_clean[(df_clean['density'] > 0) & (df_clean['density'] < 25)]
print(f"2. After removing density outliers: {len(df_clean)} rows")

# 3. Remove very large unit cells (might be problematic)
df_clean = df_clean[df_clean['nsites'] <= 50]
print(f"3. After limiting unit cell size: {len(df_clean)} rows")

# 4. Check for duplicates (same composition)
n_duplicates = df_clean.duplicated(subset=['formula']).sum()
print(f"4. Duplicate formulas found: {n_duplicates}")

print(f"\n‚úÖ Final clean dataset: {len(df_clean)} materials")

Data Cleaning Steps:
1. After removing missing values: 2500 rows
2. After removing density outliers: 2500 rows
3. After limiting unit cell size: 1966 rows
4. Duplicate formulas found: 9

‚úÖ Final clean dataset: 1966 materials


<cell_type>markdown</cell_type>---

## 5. Proper Train/Test Splitting

### ‚ö†Ô∏è The Data Leakage Problem in Materials

<div style="background: rgba(239, 68, 68, 0.1); padding: 15px; border-radius: 10px; border-left: 4px solid #ef4444;">

**Standard random splitting can cause data leakage in materials ML!**

**What is Data Leakage?**

Data leakage occurs when information from the test set "leaks" into the training process, leading to overly optimistic performance estimates that don't generalize to truly new data.

**Materials-Specific Leakage Sources:**

1. **Compositional Similarity**: Training on Fe‚ÇÇO‚ÇÉ and testing on Fe‚ÇÉO‚ÇÑ - the model learns Fe-O bonding patterns
2. **Structural Polymorphs**: Same composition with different structures (e.g., anatase vs rutile TiO‚ÇÇ)
3. **Substitutional Variants**: Li‚ÇÇO vs Na‚ÇÇO share similar chemistry

**Mathematical Impact:**

If we define generalization error as:

$$\epsilon_{gen} = \mathbb{E}[(f(x) - y)^2]$$

With leakage, we measure:

$$\epsilon_{leaked} < \epsilon_{gen}$$

This gives a false sense of model accuracy!

</div>

In [None]:
# Visualize the data leakage concept
fig, axes = plt.subplots(1, 2, figsize=(14, 5), facecolor='white')

from matplotlib.patches import FancyBboxPatch, Circle, FancyArrowPatch
from matplotlib.collections import PatchCollection

colors = {
    'train': '#0ea5e9',
    'test': '#ec4899', 
    'overlap': '#ef4444',
    'good': '#10b981',
    'text': '#1e293b',
    'bg': '#f8fafc'
}

# LEFT: Bad split (with leakage)
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.set_title('Random Split (Data Leakage Risk)', fontsize=14, fontweight='bold', color=colors['text'])

# Draw overlapping circles representing Fe compounds
train_circle = plt.Circle((3.5, 5), 2.5, color=colors['train'], alpha=0.3, label='Training Set')
test_circle = plt.Circle((6.5, 5), 2.5, color=colors['test'], alpha=0.3, label='Test Set')
ax1.add_patch(train_circle)
ax1.add_patch(test_circle)

# Add compound labels
ax1.text(2.5, 6, r'$Fe_2O_3$', fontsize=11, ha='center', color=colors['train'], fontweight='bold')
ax1.text(2.5, 4.5, r'$Fe_3O_4$', fontsize=11, ha='center', color=colors['train'], fontweight='bold')
ax1.text(3.5, 3.5, r'$FeO$', fontsize=11, ha='center', color=colors['train'], fontweight='bold')

ax1.text(7.5, 6, r'$Fe_2O_3$', fontsize=11, ha='center', color=colors['test'], fontweight='bold')
ax1.text(7.5, 4.5, r'$FeO$', fontsize=11, ha='center', color=colors['test'], fontweight='bold')
ax1.text(6.5, 3.5, r'$Fe_3O_4$', fontsize=11, ha='center', color=colors['test'], fontweight='bold')

# Overlap region
ax1.text(5, 5, 'Fe-O\npatterns\nshared!', fontsize=10, ha='center', va='center', 
         color=colors['overlap'], fontweight='bold',
         bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor=colors['overlap'], alpha=0.9))

ax1.text(5, 1.5, 'Model learns Fe-O bonding from train,\n"cheats" on Fe compounds in test', 
         fontsize=10, ha='center', color=colors['overlap'], style='italic')

ax1.legend(loc='upper right', fontsize=10)
ax1.axis('off')
ax1.set_facecolor(colors['bg'])

# RIGHT: Good split (grouped)
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.set_title('Grouped Split (No Leakage)', fontsize=14, fontweight='bold', color=colors['text'])

# Draw separate circles
train_circle2 = plt.Circle((3, 5), 2.5, color=colors['train'], alpha=0.3, label='Training Set')
test_circle2 = plt.Circle((7, 5), 2.5, color=colors['good'], alpha=0.3, label='Test Set')
ax2.add_patch(train_circle2)
ax2.add_patch(test_circle2)

# Training compounds (Fe-based)
ax2.text(2, 6.5, r'$Fe_2O_3$', fontsize=11, ha='center', color=colors['train'], fontweight='bold')
ax2.text(3, 5.5, r'$Fe_3O_4$', fontsize=11, ha='center', color=colors['train'], fontweight='bold')
ax2.text(4, 4.5, r'$FeO$', fontsize=11, ha='center', color=colors['train'], fontweight='bold')
ax2.text(2.5, 3.5, r'$FeCl_3$', fontsize=11, ha='center', color=colors['train'], fontweight='bold')

# Test compounds (Zn-based - different chemistry)
ax2.text(6, 6.5, r'$ZnO$', fontsize=11, ha='center', color=colors['good'], fontweight='bold')
ax2.text(7, 5.5, r'$ZnS$', fontsize=11, ha='center', color=colors['good'], fontweight='bold')
ax2.text(8, 4.5, r'$ZnCl_2$', fontsize=11, ha='center', color=colors['good'], fontweight='bold')
ax2.text(7, 3.5, r'$ZnSe$', fontsize=11, ha='center', color=colors['good'], fontweight='bold')

ax2.text(3, 2, 'Fe compounds', fontsize=10, ha='center', color=colors['train'], fontweight='bold')
ax2.text(7, 2, 'Zn compounds', fontsize=10, ha='center', color=colors['good'], fontweight='bold')

ax2.text(5, 1.2, 'No element overlap between train and test\nTests true generalization ability!', 
         fontsize=10, ha='center', color=colors['good'], style='italic')

ax2.legend(loc='upper right', fontsize=10)
ax2.axis('off')
ax2.set_facecolor(colors['bg'])

plt.tight_layout()
plt.savefig('figures/02_data_leakage.png', dpi=200, bbox_inches='tight', facecolor='white')
plt.show()

print("Figure saved to figures/02_data_leakage.png")

In [13]:
# Method 2: Group by elements to avoid leakage
def get_element_set(composition):
    """Get the set of elements in a composition as a frozen set."""
    return frozenset(str(el) for el in composition.elements)

# Create groups based on element combinations
df_clean['element_group'] = df_clean['composition'].apply(get_element_set)

# Use GroupShuffleSplit to ensure no element overlap
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
groups = df_clean['element_group'].apply(lambda x: hash(x))  # Convert to hashable

train_idx, test_idx = next(gss.split(X, y, groups))

X_train_grouped = X.iloc[train_idx]
X_test_grouped = X.iloc[test_idx]
y_train_grouped = y.iloc[train_idx]
y_test_grouped = y.iloc[test_idx]

print("Method 2: Group Split by Elements")
print(f"  Training set: {len(X_train_grouped)} samples")
print(f"  Test set: {len(X_test_grouped)} samples")
print("\n‚úÖ This reduces the risk of data leakage!")

Method 2: Group Split by Elements
  Training set: 1573 samples
  Test set: 393 samples

‚úÖ This reduces the risk of data leakage!


### üíæ Save the Dataset

In [14]:
# Save the clean dataset for future use
df_clean.to_csv('../data/sample_datasets/materials_bandgap.csv', index=False)
print("‚úÖ Dataset saved to data/sample_datasets/materials_bandgap.csv")
print(f"   Shape: {df_clean.shape}")

‚úÖ Dataset saved to data/sample_datasets/materials_bandgap.csv
   Shape: (1966, 12)


---

## üìù Exercises

### Exercise 1: Query Different Properties

Query the Materials Project for materials with:
- Formation energy < -1 eV/atom (very stable)
- Contains oxygen
- Has 2-3 elements

In [15]:
# Exercise 1: Your code here
# with MPRester(MP_API_KEY) as mpr:
#     stable_oxides = mpr.materials.summary.search(
#         # Add your query parameters
#     )

# print(f"Found {len(stable_oxides)} stable oxides")

### Exercise 2: Explore Data Quality

Create a function to check data quality metrics.

In [16]:
# Exercise 2: Complete this function
def data_quality_report(df):
    """Generate a data quality report for a materials DataFrame."""
    report = {}
    
    # TODO: Calculate these metrics
    # report['n_samples'] = ...
    # report['n_features'] = ...
    # report['missing_values'] = ...
    # report['duplicate_formulas'] = ...
    
    return report

# Test your function
# report = data_quality_report(df_clean)
# for key, value in report.items():
#     print(f"{key}: {value}")

---

## ‚úÖ Module Summary

### Key Takeaways

1. **Materials databases** (Materials Project, AFLOW, etc.) provide the data foundation for ML
2. **MPRester** allows powerful queries by properties, elements, and stability
3. **Data cleaning** is essential: handle missing values, outliers, and duplicates
4. **Proper splitting** avoids data leakage in materials ML
5. **Group-based splits** are recommended when compositional overlap is a concern

### What's Next?

In **Module 3: Featurization Basics**, you'll learn to:
- Convert compositions and structures to numerical features
- Use matminer's 70+ featurizers
- Select the most important features

---

**üìö Continue to Module 3:** [Featurization Basics](03_featurization_basics.ipynb)