# Graph Building from Crystal Structures

This notebook covers:
1. Loading material IDs from CSV
2. Fetching crystal structures from Materials Project API
3. Building neighbor lists and graph representations
4. Creating graph data structures with node and edge features

**Purpose**: Convert crystal structures into graph representations suitable for GNN training.

## Part 1: Load Material IDs from CSV

In [37]:
# Import required libraries
import pandas as pd
import json
import numpy as np
from tqdm import tqdm

In [38]:
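# Read the CSV file with material IDs
csv_path = "mp-ids-27430.csv"

# NOTE: if the file has no header row, pandas uses the first ID as the column
# name (see the output below) and that ID is dropped from the data.
# Passing header=None keeps every row.
df = pd.read_csv(csv_path)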
print(f"CSV columns: {df.columns.tolist()}")

CSV columns: ['mp-754118']


In [39]:
# Extract material IDs (adjust column name if necessary)
if "material_id" in df.columns:
    material_ids = df["material_id"].dropna().astype(str).tolist()
else:
    # No "material_id" header: take the first column.
    # [:3] limits the run to the first 3 IDs for quick testing; remove it to use all IDs.
    material_ids = df.iloc[:, 0].dropna().astype(str).tolist()[:3]

print(f"Loaded {len(material_ids)} material IDs.")
print(f"First 10 IDs: {material_ids[:10]}")

Loaded 3 material IDs.
First 10 IDs: ['mp-633688', 'mp-3799', 'mp-12487']
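
The `CSV columns` output above shows an ID where a column name should be: the file appears to have no header row, so pandas consumed `mp-754118` as the header and that ID never reaches `material_ids`. A minimal sketch of a loader that keeps every row, assuming a single-column file of IDs; `load_material_ids` is a hypothetical helper, not part of the original notebook:

```python
import io

import pandas as pd


def load_material_ids(source, limit=None):
    """Read a single-column CSV of Materials Project IDs that has no header row.

    header=None stops pandas from consuming the first ID as a column name;
    names= gives the single column a readable label.
    """
    ids_df = pd.read_csv(source, header=None, names=["material_id"])
    ids = ids_df["material_id"].dropna().astype(str).str.strip().tolist()
    # limit=None returns every ID; an integer keeps only the first `limit` IDs
    return ids if limit is None else ids[:limit]


# Usage with an in-memory file (pass csv_path instead for the real data)
sample = io.StringIO("mp-754118\nmp-633688\nmp-3799\nmp-12487\n")
print(load_material_ids(sample, limit=3))  # ['mp-754118', 'mp-633688', 'mp-3799']
```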


## Part 2: Fetch Data from Materials Project API

**Note**: You need a Materials Project API key. Store it in the `MP_API` environment variable (for example, in a `.env` file loaded with `python-dotenv`), or paste it directly into the cell below.

In [40]:
# Import Materials Project API client
from mp_api.client import MPRester

# Get API key (update this based on your setup)
import os
from dotenv import load_dotenv

# Option 1: load variables from a .env file (if present) into the environment
load_dotenv() 

# Now you can access the variables using os.getenv()
api_key = os.getenv('MP_API')

# Option 2: Direct input (uncomment and add your key)
# api_key = "your_api_key_here"

if not api_key:
    raise ValueError("Please set MP_API environment variable or provide API key directly")

print("✅ API key loaded")

✅ API key loaded


In [41]:
# Define fields to retrieve from Materials Project
fields = [
    "material_id",
    "formula_pretty",
    "formation_energy_per_atom",
    "band_gap",
    "density",
    "energy_above_hull",
    "volume",
    "structure",  # may be None for some docs
]

In [42]:
# Fetch data from Materials Project
data_list = []

with MPRester(api_key) as mpr:
    # Batch search by IDs
    print("Fetching data from Materials Project...")
    docs = mpr.materials.summary.search(material_ids=material_ids, fields=fields)

    # Convert results to plain dicts
    for doc in tqdm(docs, desc="Processing materials"):
        try:
            struct_dict = doc.structure.as_dict() if getattr(doc, "structure", None) else None
            data_list.append({
                "material_id": doc.material_id,
                "formula": doc.formula_pretty,
                "formation_energy_per_atom": getattr(doc, "formation_energy_per_atom", None),
                "band_gap": getattr(doc, "band_gap", None),
                "density": getattr(doc, "density", None),
                "energy_above_hull": getattr(doc, "energy_above_hull", None),
                "volume": getattr(doc, "volume", None),
                "structure": struct_dict,
            })
        except Exception as e:
            print(f"⚠️ Skipping {getattr(doc, 'material_id', 'unknown')}: {e}")

print(f"\nCollected {len(data_list)} records")

Fetching data from Materials Project...
Retrieving SummaryDoc documents: 100%|██████████| 2/2 [00:00<?, ?it/s]
Processing materials: 100%|██████████| 2/2 [00:00<00:00, 1644.50it/s]

Collected 2 records
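
Only 2 records came back for the 3 requested IDs; `mp-633688` is the one missing from the DataFrame below. The search finished without an error, so a dropped ID is easy to overlook. Why it was dropped is not visible in this output (a deprecated or superseded ID is one possible cause). A quick set comparison makes such drops explicit; `find_missing_ids` is a hypothetical helper:

```python
def find_missing_ids(requested, returned):
    """Return the requested IDs that are absent from the API results, in request order.

    str() is applied to the returned IDs in case the client hands back
    ID objects rather than plain strings.
    """
    returned_set = {str(mid) for mid in returned}
    return [mid for mid in requested if mid not in returned_set]


# Values taken from the outputs above
requested = ["mp-633688", "mp-3799", "mp-12487"]
returned = ["mp-12487", "mp-3799"]
print(find_missing_ids(requested, returned))  # ['mp-633688']

# In the notebook:
# missing = find_missing_ids(material_ids, [rec["material_id"] for rec in data_list])
```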





### Save Downloaded Data

In [43]:
# Save to JSON file
with open("mp_training_data.json", "w") as f:
    json.dump(data_list, f, indent=2)

print("✅ Saved to mp_training_data.json")

✅ Saved to mp_training_data.json
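
The `fields` list notes that `structure` may be `None` for some documents, and such records would fail later during graph building. A minimal sketch that reloads the saved JSON and separates usable records from those without a structure, assuming the record layout written above; `load_usable_records` is a hypothetical helper:

```python
import json
import os
import tempfile


def load_usable_records(path, required=("structure",)):
    """Load records saved with json.dump and drop those missing any required field.

    Returns (usable_records, skipped_ids).
    """
    with open(path) as f:
        records = json.load(f)
    usable, skipped = [], []
    for rec in records:
        if all(rec.get(key) is not None for key in required):
            usable.append(rec)
        else:
            skipped.append(rec.get("material_id"))
    return usable, skipped


# Self-contained demo with a temporary file (use "mp_training_data.json" in the notebook)
demo = [
    {"material_id": "mp-1", "structure": {"sites": []}},
    {"material_id": "mp-2", "structure": None},
]
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.json")
    with open(demo_path, "w") as f:
        json.dump(demo, f)
    usable, skipped = load_usable_records(demo_path)

print(len(usable), skipped)  # 1 ['mp-2']
```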


### Create DataFrame for Easy Access

In [44]:
# Create DataFrame
df_train = pd.DataFrame(data_list)
print(f"DataFrame shape: {df_train.shape}")
df_train.head()

DataFrame shape: (2, 8)


| | material_id | formula | formation_energy_per_atom | band_gap | density | energy_above_hull | volume | structure |
|---|---|---|---|---|---|---|---|---|
| 0 | mp-12487 | CsDyCdTe3 | -1.283459 | 1.5476 | 5.586797 | 0.0 | 469.982918 | {'@module': 'pymatgen.core.structure', '@class... |
| 1 | mp-3799 | GdSF | -3.292371 | 0.0 | 6.864023 | 0.0 | 100.790026 | {'@module': 'pymatgen.core.structure', '@class... |


## Part 3: Build Neighbor Lists Using ASE

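We'll use ASE (Atomic Simulation Environment) to compute neighbor lists. Its `neighbor_list` function respects the periodic boundary conditions of the cell, so it also finds neighbors that lie across cell boundaries.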

In [45]:
# Import required libraries
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from ase.neighborlist import neighbor_list

In [46]:
# Set cutoff radius for neighbor detection (in Angstroms)
cutoff_radius = 4.0

print(f"Using cutoff radius: {cutoff_radius} Å")

Using cutoff radius: 4.0 Å


### Example: Build Neighbor List for First Structure

In [47]:
# Grab the first structure from the DataFrame
s0 = df_train["structure"].iloc[0]

# Ensure it's a pymatgen Structure
if isinstance(s0, dict):
    pmg_struct = Structure.from_dict(s0)
elif isinstance(s0, Structure):
    pmg_struct = s0
else:
    raise TypeError("Structure must be a dict or pymatgen Structure")

print(f"Material: {df_train['material_id'].iloc[0]}")
print(f"Formula: {df_train['formula'].iloc[0]}")
print(f"Number of atoms: {len(pmg_struct.sites)}")

Material: mp-12487
Formula: CsDyCdTe3
Number of atoms: 12


In [48]:
# Convert to ASE Atoms
atoms = AseAtomsAdaptor.get_atoms(pmg_struct)
print(f"ASE Atoms object created with {len(atoms)} atoms")
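# a, b, c (Å) and alpha, beta, gamma (degrees); cell.cellpar() replaces the deprecated get_cell_lengths_and_angles()
print(f"Cell parameters: {atoms.cell.cellpar()}")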

ASE Atoms object created with 12 atoms
Cell parameters: [  4.547226     8.98579253  11.889019    90.          90.
 104.65647551]



In [49]:
# Build neighbor list
# Returns: i (source), j (neighbor), d (distances in Å)
edge_src, edge_dst, edge_len = neighbor_list(
    "ijd", 
    atoms, 
    cutoff=cutoff_radius, 
    self_interaction=False
)

print(f"Total neighbor pairs: {len(edge_src)}")

Total neighbor pairs: 64


### Inspect Neighbor Pairs

In [50]:
# Get chemical symbols for all atoms
symbols = atoms.get_chemical_symbols()
print(f"Elements present: {set(symbols)}")
print(f"Atomic symbols: {symbols}")

Elements present: {'Te', 'Cs', 'Cd', 'Dy'}
Atomic symbols: ['Cs', 'Cs', 'Dy', 'Dy', 'Cd', 'Cd', 'Te', 'Te', 'Te', 'Te', 'Te', 'Te']


In [51]:
# Print first few neighbor pairs with element names
print("\nFirst 20 neighbor pairs:")
print(f"{'Edge':<6} {'Source':<15} {'Target':<15} {'Distance (Å)':>12}")
print("-" * 50)

for i in range(min(20, len(edge_src))):
    src_idx = edge_src[i]
    dst_idx = edge_dst[i]
    # Build "index (element)" labels first so the width spec pads the whole label
    src_label = f"{src_idx:>2} ({symbols[src_idx]})"
    dst_label = f"{dst_idx:>2} ({symbols[dst_idx]})"
    dist = edge_len[i]

    print(f"{i:<6} {src_label:<15} {dst_label:<15} {dist:>12.3f}")


First 20 neighbor pairs:
Edge   Source          Target          Distance (Å)
--------------------------------------------------
0       0 (Cs)         10 (Te)                3.950
1       0 (Cs)         10 (Te)                3.950
2       1 (Cs)         11 (Te)                3.950
3       1 (Cs)         11 (Te)                3.950
4       2 (Dy)          7 (Te)                3.133
5       2 (Dy)          4 (Cd)                3.797
6       2 (Dy)          5 (Cd)                3.797
7       2 (Dy)         10 (Te)                3.181
8       2 (Dy)          6 (Te)                3.133
9       2 (Dy)          6 (Te)                3.133
10      2 (Dy)          5 (Cd)                3.797
11      2 (Dy)         11 (Te)                3.181
12      2 (Dy)          7 (Te)                3.133
13      2 (Dy)          4 (Cd)                3.797
14      3 (Dy)         11 (Te)                3.181
15      3 (Dy)          4 (Cd)                3.797
16      3 (Dy)          8 (Te)                3.133
17      3 (Dy)         10 (Te)                3.181
18      3 (Dy)          9 (Te)                3.133
19      3 (Dy)          5 (Cd)                3.797
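
Edges 0 and 1 above both connect atom 0 (Cs) to atom 10 (Te) at 3.950 Å. The likely explanation is periodicity: an (i, j) pair can occur once for every periodic image of atom j that falls inside the cutoff (`neighbor_list` can also return each image offset if you request the `"S"` quantity). To make this concrete, here is a brute-force reference sketch in plain NumPy. It is not how ASE implements `neighbor_list` (ASE's algorithm is far more efficient); it simply loops over nearby cell images and records every pair within the cutoff. The CsCl-type test cell (a = 4.0 Å) is an illustrative assumption, not a structure from the dataset:

```python
import itertools

import numpy as np


def naive_periodic_neighbors(cell, frac_coords, cutoff, n_images=1):
    """Brute-force neighbor search under periodic boundary conditions.

    cell:        (3, 3) array whose rows are the lattice vectors, in Å
    frac_coords: (N, 3) fractional coordinates of the atoms
    cutoff:      neighbor cutoff radius, in Å
    n_images:    cell images scanned along each axis; must be large enough
                 that every neighbor within `cutoff` is reached

    Returns (i, j, d, shift): source index, neighbor index, distance, and the
    integer image offset of the neighbor, similar to neighbor_list("ijdS", ...).
    """
    cell = np.asarray(cell, dtype=float)
    cart = np.asarray(frac_coords, dtype=float) @ cell  # fractional -> Cartesian
    src, dst, dist, shifts = [], [], [], []
    image_range = range(-n_images, n_images + 1)
    for shift in itertools.product(image_range, repeat=3):
        offset = np.asarray(shift, dtype=float) @ cell  # translation of this image
        for a in range(len(cart)):
            for b in range(len(cart)):
                if a == b and shift == (0, 0, 0):
                    continue  # skip self-interaction, like self_interaction=False
                d = np.linalg.norm(cart[b] + offset - cart[a])
                if d < cutoff:
                    src.append(a)
                    dst.append(b)
                    dist.append(d)
                    shifts.append(shift)
    return np.array(src), np.array(dst), np.array(dist), np.array(shifts)


# CsCl-type cubic cell: Cs at the corner, Cl at the body centre (illustrative values)
demo_cell = 4.0 * np.eye(3)
demo_frac = [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]
demo_i, demo_j, demo_d, demo_s = naive_periodic_neighbors(demo_cell, demo_frac, cutoff=3.5)

print(len(demo_i))                    # 16
print(np.unique(demo_d.round(3)))     # [3.464]
# The same (0, 1) pair appears 8 times, once per periodic image of Cl
print(len({tuple(v) for v in demo_s[(demo_i == 0) & (demo_j == 1)]}))  # 8
```

Every Cs-Cl edge has the same length, and only the image offset tells the 8 copies apart, which mirrors the repeated Cs-Te pairs in the listing above.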


### Distance Statistics

In [52]:
# Distance statistics
print("\n--- Distance Statistics (Å) ---")
print(f"Min distance:  {np.min(edge_len):.3f}")
print(f"Max distance:  {np.max(edge_len):.3f}")
print(f"Mean distance: {np.mean(edge_len):.3f}")
print(f"Std distance:  {np.std(edge_len):.3f}")


--- Distance Statistics (Å) ---
Min distance:  2.817
Max distance:  3.950
Mean distance: 3.336
Std distance:  0.415
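
These global statistics mix very different bond types: in the listing above, Dy-Te pairs sit at 3.133 or 3.181 Å while Cs-Te pairs sit at 3.950 Å. Grouping edge lengths by unordered element pair separates them. A minimal NumPy sketch, using the first ten edges printed above as illustrative input; `pair_distance_stats` is a hypothetical helper:

```python
from collections import defaultdict

import numpy as np


def pair_distance_stats(symbols, src, dst, dist):
    """Group edge lengths by unordered element pair, e.g. ('Dy', 'Te').

    Returns {pair: (count, min, mean, max)} with distances in Å.
    """
    groups = defaultdict(list)
    for a, b, d in zip(src, dst, dist):
        pair = tuple(sorted((symbols[a], symbols[b])))  # ('Te', 'Dy') -> ('Dy', 'Te')
        groups[pair].append(d)
    return {
        pair: (len(v), float(np.min(v)), float(np.mean(v)), float(np.max(v)))
        for pair, v in sorted(groups.items())
    }


# Edges 0-9 from the neighbor listing above (atoms 0-1 Cs, 2-3 Dy, 4-5 Cd, 6-11 Te)
demo_symbols = ["Cs", "Cs", "Dy", "Dy", "Cd", "Cd", "Te", "Te", "Te", "Te", "Te", "Te"]
demo_src = [0, 0, 1, 1, 2, 2, 2, 2, 2, 2]
demo_dst = [10, 10, 11, 11, 7, 4, 5, 10, 6, 6]
demo_dist = [3.950, 3.950, 3.950, 3.950, 3.133, 3.797, 3.797, 3.181, 3.133, 3.133]

stats = pair_distance_stats(demo_symbols, demo_src, demo_dst, demo_dist)
for pair, (n, lo, mean, hi) in stats.items():
    print(f"{pair[0]}-{pair[1]}: n={n}, min={lo:.3f}, mean={mean:.3f}, max={hi:.3f}")
```

In the notebook, the same call works on the full arrays: `pair_distance_stats(symbols, edge_src, edge_dst, edge_len)`.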


In [53]:
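# Neighbor count per atom (degree distribution).
# np.bincount with minlength also counts atoms that have no neighbors within the cutoff,
# which np.unique would silently leave out of the average.
# "Coordination number" here means neighbors within cutoff_radius, not a chemical coordination number.
counts = np.bincount(edge_src, minlength=len(atoms))

print("\n--- Neighbor Count Per Atom ---")
for i, c in enumerate(counts):
    print(f"Atom {i:>2} ({symbols[i]}): {c:>2} neighbors")

print(f"\nAverage coordination number: {np.mean(counts):.2f}")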


--- Neighbor Count Per Atom ---
Atom  0 (Cs):  2 neighbors
Atom  1 (Cs):  2 neighbors
Atom  2 (Dy): 10 neighbors
Atom  3 (Dy): 10 neighbors
Atom  4 (Cd):  8 neighbors
Atom  5 (Cd):  8 neighbors
Atom  6 (Te):  3 neighbors
Atom  7 (Te):  3 neighbors
Atom  8 (Te):  3 neighbors
Atom  9 (Te):  3 neighbors
Atom 10 (Te):  6 neighbors
Atom 11 (Te):  6 neighbors

Average coordination number: 5.33


## Part 4: Create Graph Data Structure

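Now we'll build the complete graph representation: node features come from per-element embeddings stored in `atom_embedding.json`, and edge features combine each neighbor distance with a Gaussian radial basis expansion of that distance.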

In [54]:
# Import PyTorch for tensors
import torch

Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.
It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe


OSError: [WinError 126] The specified module could not be found. Error loading "E:\cgcnn\.venv\Lib\site-packages\torch\lib\c10.dll" or one of its dependencies.

In [None]:
# Load atom embeddings
with open("atom_embedding.json", "r") as f:
    ATOM_EMB = json.load(f)

EMB_LEN = ATOM_EMB["embedding_length"]
EMB_TAB = ATOM_EMB["embeddings"]

print(f"Loaded embeddings for {len(EMB_TAB)} elements")
print(f"Embedding length: {EMB_LEN}")

In [None]:
def get_atom_embedding(symbol: str):
    """Return the embedding vector for an atomic symbol, or a zero vector if the element is not in the table."""
    vec = EMB_TAB.get(symbol)
    if vec is None:
        return [0.0] * EMB_LEN
    return vec
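
`get_atom_embedding` silently returns a zero vector for elements missing from `atom_embedding.json`, so an unknown element becomes indistinguishable from an all-zero feature row. A minimal sketch of building the node feature matrix `x` that also reports which symbols fell back to zeros; the tiny embedding table is made up for illustration, and `node_features` is a hypothetical helper:

```python
import numpy as np


def node_features(symbols, emb_table, emb_len):
    """Stack one embedding row per atom into an (N, emb_len) float32 matrix.

    Symbols missing from emb_table get a zero row (the same fallback as
    get_atom_embedding) and are collected in `missing`.
    """
    missing = sorted({sym for sym in symbols if sym not in emb_table})
    x = np.vstack([
        np.asarray(emb_table.get(sym, [0.0] * emb_len), dtype=np.float32)
        for sym in symbols
    ])
    return x, missing


# Hypothetical 3-dimensional embeddings, for illustration only
toy_table = {"Gd": [1.0, 0.0, 0.0], "S": [0.0, 1.0, 0.0], "F": [0.0, 0.0, 1.0]}
x_demo, missing = node_features(["Gd", "S", "F", "Xx"], toy_table, emb_len=3)
print(x_demo.shape, missing)  # (4, 3) ['Xx']
```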

### Edge Feature Encoding: Radial Basis Functions

In [None]:
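def rbf(distances: np.ndarray, r_min=0.5, r_max=6.0, bins=32, gamma=None):
    """
    Compute a Gaussian radial basis expansion of edge distances.

    distances: 1-D array of E distances in Å
    Returns an array of shape (E, bins); column k peaks where the distance equals the k-th center.
    """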
    centers = np.linspace(r_min, r_max, bins, dtype=np.float32)
    d = distances.reshape(-1, 1)
    if gamma is None:
        spacing = (r_max - r_min) / max(1, bins - 1)
        sigma = spacing if spacing > 0 else 1.0
        gamma = 1.0 / (2.0 * sigma * sigma)
    return np.exp(-gamma * (d - centers) ** 2)
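
For reference, the expansion computed by `rbf` (with $B$ = `bins` and `gamma` left at its default) is:

$$
e_k(d) = \exp\!\left(-\gamma\,(d - \mu_k)^2\right), \qquad
\mu_k = r_{\min} + k\,\Delta, \quad k = 0, \dots, B-1
$$

$$
\Delta = \frac{r_{\max} - r_{\min}}{B - 1}, \qquad \sigma = \Delta, \qquad \gamma = \frac{1}{2\sigma^2}
$$

With the defaults used in this notebook ($r_{\min} = 0.5$ Å, $r_{\max} = 6.0$ Å, $B = 32$), $\Delta = 5.5/31 \approx 0.177$ Å and $\gamma \approx 15.9$ Å$^{-2}$, so each edge gets $1 + B = 33$ features (the raw distance plus the 32 basis values). Because the neighbor cutoff is 4.0 Å, no edge distance exceeds 4.0 Å, and the basis functions centred near 5 to 6 Å stay close to zero for every edge; one option is to set `r_max` to the cutoff so that all centres cover distances that can actually occur.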

### Complete Graph Builder Function

In [None]:
def build_graph_from_row(row, cutoff=4.0, rbf_bins=32, target_key="formation_energy_per_atom"):
    """
    Convert one DataFrame row into a graph representation.

    Args:
        row: DataFrame row whose "structure" entry is a pymatgen dict or Structure
        cutoff: neighbor cutoff radius in Å
        rbf_bins: number of Gaussian basis functions for the edge distances
        target_key: column used as the regression target for every graph

    Returns:
        Dictionary with keys:
        - material_id: material identifier
        - x: node features, float32 tensor (N, F)
        - edge_index: edge indices, int64 tensor (2, E); row 0 = source, row 1 = neighbor
        - edge_attr: edge features, float32 tensor (E, 1 + rbf_bins)
        - y: target, float32 tensor of shape (1,), or None if the value is missing
        - num_nodes: number of nodes N
    """
    s = row["structure"]
    pmg = Structure.from_dict(s) if isinstance(s, dict) else s
    atoms = AseAtomsAdaptor.get_atoms(pmg)

    # Build neighbor list (edges); an (i, j) pair can repeat once per periodic image of j
    i_src, j_dst, d_len = neighbor_list("ijd", atoms, cutoff=cutoff, self_interaction=False)
    i_src = np.asarray(i_src, dtype=np.int64)
    j_dst = np.asarray(j_dst, dtype=np.int64)
    d_len = np.asarray(d_len, dtype=np.float32)

    # Node features from atom embeddings
    symbols = atoms.get_chemical_symbols()
    x = np.vstack([np.array(get_atom_embedding(sym), dtype=np.float32) for sym in symbols])

    # Edge features: [raw distance, Gaussian RBF expansion]
    edge_dist = d_len.reshape(-1, 1)
    edge_rbf = rbf(d_len, r_min=0.5, r_max=6.0, bins=rbf_bins)
    edge_attr = np.hstack([edge_dist, edge_rbf]).astype(np.float32)

    # Target value from a single, fixed property. Falling back to another property
    # (e.g. band_gap) when this one is missing would mix different quantities in y.
    # pandas stores missing numbers as NaN, so test with pd.isna as well as None.
    value = row.get(target_key)
    y = None if value is None or pd.isna(value) else float(value)

    return {
        "material_id": row.get("material_id", None),
        "x": torch.from_numpy(x),
        "edge_index": torch.from_numpy(np.vstack([i_src, j_dst])),
        "edge_attr": torch.from_numpy(edge_attr),
        "y": None if y is None else torch.tensor([y], dtype=torch.float32),
        "num_nodes": x.shape[0],
    }

### Build Graph for Example Structure

In [None]:
# Build graph for the first material
graph_example = build_graph_from_row(df_train.iloc[0], cutoff=cutoff_radius, rbf_bins=32)

print(f"Material ID: {graph_example['material_id']}")
print(f"Number of nodes: {graph_example['num_nodes']}")
print(f"Number of edges: {graph_example['edge_index'].shape[1]}")
print(f"Node feature shape: {graph_example['x'].shape}")
print(f"Edge feature shape: {graph_example['edge_attr'].shape}")
print(f"Target value: {graph_example['y']}")
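
For the first material (12 atoms and 64 neighbor pairs at the 4.0 Å cutoff, as computed in Part 3), one would expect `x` to have 12 rows, `edge_index` to have shape (2, 64), and `edge_attr` to have shape (64, 33): one distance column plus 32 RBF columns. A minimal consistency check along those lines, written with NumPy arrays so it runs without PyTorch (the same `.shape`, `.min()` and `.max()` accesses work on tensors); `check_graph` is a hypothetical helper:

```python
import numpy as np


def check_graph(graph, rbf_bins=32):
    """Assert that a graph dict from build_graph_from_row is internally consistent.

    Returns (num_nodes, num_edges) when all checks pass.
    """
    x, ei, ea = graph["x"], graph["edge_index"], graph["edge_attr"]
    n, e = graph["num_nodes"], ei.shape[1]
    assert x.shape[0] == n, "one feature row per atom"
    assert ei.shape[0] == 2, "edge_index must have shape (2, E)"
    assert tuple(ea.shape) == (e, 1 + rbf_bins), "distance column + RBF columns per edge"
    if e > 0:
        assert 0 <= int(ei.min()) and int(ei.max()) < n, "edge indices must point at existing nodes"
    return n, e


# Toy graph: 3 nodes, 2 edges, 4-dimensional node features (made-up values)
toy = {
    "x": np.zeros((3, 4), dtype=np.float32),
    "edge_index": np.array([[0, 1], [1, 2]], dtype=np.int64),
    "edge_attr": np.zeros((2, 33), dtype=np.float32),
    "num_nodes": 3,
}
print(check_graph(toy))  # (3, 2)
```

In the notebook: `check_graph(graph_example)`.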

### Build Graphs for All Materials

In [None]:
# Build graphs for all materials in the dataset
print("Building graphs for all materials...")
graphs = []

for idx, row in tqdm(df_train.iterrows(), total=len(df_train), desc="Building graphs"):
    try:
        graph = build_graph_from_row(row, cutoff=cutoff_radius, rbf_bins=32)
        graphs.append(graph)
    except Exception as e:
        print(f"\n⚠️ Error processing {row.get('material_id', idx)}: {e}")

print(f"\n✅ Successfully built {len(graphs)} graphs")

### Save Graph Dataset

In [None]:
# Save graphs using PyTorch
torch.save(graphs, "graph_dataset.pt")
print("✅ Saved graph dataset to 'graph_dataset.pt'")

### Dataset Statistics

In [None]:
# Compute dataset statistics
num_nodes_list = [g['num_nodes'] for g in graphs]
num_edges_list = [g['edge_index'].shape[1] for g in graphs]

print("\n=== Dataset Statistics ===")
print(f"Total graphs: {len(graphs)}")
print(f"\nNodes per graph:")
print(f"  Min:  {np.min(num_nodes_list)}")
print(f"  Max:  {np.max(num_nodes_list)}")
print(f"  Mean: {np.mean(num_nodes_list):.2f}")
print(f"\nEdges per graph:")
print(f"  Min:  {np.min(num_edges_list)}")
print(f"  Max:  {np.max(num_edges_list)}")
print(f"  Mean: {np.mean(num_edges_list):.2f}")
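
To train on many crystals at once, graphs of different sizes are commonly merged into one large disconnected graph: node and edge feature matrices are concatenated, and each graph's `edge_index` is shifted by the number of nodes that precede it. PyTorch Geometric's batching follows this idea; the NumPy version below is only a sketch of it with hypothetical names, not a drop-in replacement. `np.asarray` also accepts CPU tensors, so it can be applied to the `graphs` list directly:

```python
import numpy as np


def collate_graphs(graphs):
    """Merge graph dicts into one disjoint-union batch.

    Returns x (sum N, F), edge_index (2, sum E), edge_attr (sum E, D), and
    batch (sum N,) mapping every node to the index of its source graph.
    """
    xs, edge_indices, edge_attrs, batch = [], [], [], []
    offset = 0
    for g_idx, g in enumerate(graphs):
        n = g["num_nodes"]
        xs.append(np.asarray(g["x"]))
        # Shift node indices so they point into the concatenated node matrix
        edge_indices.append(np.asarray(g["edge_index"]) + offset)
        edge_attrs.append(np.asarray(g["edge_attr"]))
        batch.append(np.full(n, g_idx, dtype=np.int64))
        offset += n
    return (
        np.concatenate(xs, axis=0),
        np.concatenate(edge_indices, axis=1),
        np.concatenate(edge_attrs, axis=0),
        np.concatenate(batch),
    )


g1 = {"x": np.ones((2, 3)), "edge_index": np.array([[0, 1], [1, 0]]),
      "edge_attr": np.ones((2, 5)), "num_nodes": 2}
g2 = {"x": np.zeros((3, 3)), "edge_index": np.array([[0, 2], [2, 1]]),
      "edge_attr": np.zeros((2, 5)), "num_nodes": 3}
x_b, ei_b, ea_b, batch_b = collate_graphs([g1, g2])
print(ei_b)     # second graph's node indices are shifted by 2
print(batch_b)  # [0 0 1 1 1]
```

The `batch` vector is what lets a model pool node outputs back into one prediction per crystal.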

## Summary

We have successfully:
1. ✅ Loaded material IDs from CSV
2. ✅ Downloaded crystal structures from Materials Project
3. ✅ Built neighbor lists using ASE
4. ✅ Created graph representations with node and edge features
5. ✅ Saved graph dataset for GNN training

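Once the Part 4 cells have run (in the session shown here, `import torch` failed with a missing-DLL error, so those cells have no output yet), `graph_dataset.pt` is ready to be used for training graph neural networks.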