# Graph Building from Crystal Structures

This notebook covers:
1. Loading material IDs from CSV
2. Fetching crystal structures from Materials Project API
3. Building neighbor lists and graph representations
4. Creating graph data structures with node and edge features

**Purpose**: Convert crystal structures into graph representations suitable for GNN training.

## Part 1: Load Material IDs from CSV

In [1]:
# Import required libraries
import pandas as pd
import json
import numpy as np
from tqdm import tqdm

In [2]:
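# Read the CSV file with material IDs.
# The file has no header row, so by default pandas uses the first ID as the
# column name (see the output below) and drops it from the data; pass
# header=None to keep every ID.
csv_path = "mp-ids-27430.csv"

df = pd.read_csv(csv_path)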
print(f"CSV columns: {df.columns.tolist()}")

CSV columns: ['mp-754118']


In [3]:
# Extract material IDs (adjust the column name if necessary)
if "material_id" in df.columns:
    material_ids = df["material_id"].dropna().astype(str).tolist()
else:
    # Headerless file: take the first column.
    # [:3] keeps only a small demo subset; remove it to process every ID.
    material_ids = df.iloc[:, 0].dropna().astype(str).tolist()[:3]

print(f"Loaded {len(material_ids)} material IDs.")
print(f"First 3 IDs: {material_ids[:3]}")

Loaded 3 material IDs.
First 3 IDs: ['mp-22862', 'mp-633688', 'mp-3799']
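Because the file has no header row, `pd.read_csv` with default arguments promotes the first ID to the column name, which is why the column printed above is `'mp-754118'` and that ID never reaches `material_ids`. A minimal sketch of the difference, using a hypothetical two-line in-memory CSV in place of `mp-ids-27430.csv`:

```python
import io

import pandas as pd

# Hypothetical stand-in for the first two lines of a headerless ID file
csv_text = "mp-754118\nmp-22862\n"

# Default: the first line is treated as the header, so only one data row remains
df_default = pd.read_csv(io.StringIO(csv_text))

# header=None: every line is data and columns are labeled 0, 1, ...
df_no_header = pd.read_csv(io.StringIO(csv_text), header=None)
ids = df_no_header.iloc[:, 0].astype(str).tolist()

print(df_default.columns.tolist())          # ['mp-754118']
print(len(df_default), len(df_no_header))   # 1 2
print(ids)                                  # ['mp-754118', 'mp-22862']
```

With `header=None`, the `else` branch above picks up every ID, including the first one.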


## Part 2: Fetch Data from Materials Project API

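**Note**: You need a Materials Project API key, provided through the `MP_API` environment variable. Either:
- export `MP_API` in your shell, or
- put the line `MP_API=<your key>` in a `.env` file, which `load_dotenv()` reads in the next cell.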

In [4]:
# Import Materials Project API client
import os

from dotenv import load_dotenv
from mp_api.client import MPRester

# Load variables from a .env file (if present) into the environment
load_dotenv()

# Read the key from the environment; never hard-code API keys in a notebook
api_key = os.getenv("MP_API")

if not api_key:
    raise ValueError("Please set the MP_API environment variable (e.g. in a .env file)")

print("API key loaded")



API key loaded


In [5]:
# Define fields to retrieve from Materials Project
fields = [
    "material_id",
    "formula_pretty",
    "formation_energy_per_atom",
    "band_gap",
    "density",
    "energy_above_hull",
    "volume",
    "structure",  # may be None for some docs
]

In [6]:
# Fetch data from Materials Project
data_list = []

with MPRester(api_key) as mpr:
    # Batch search by IDs
    print("Fetching data from Materials Project...")
    docs = mpr.materials.summary.search(material_ids=material_ids, fields=fields)

    # Convert results to plain dicts
    for doc in tqdm(docs, desc="Processing materials"):
        try:
            struct_dict = doc.structure.as_dict() if getattr(doc, "structure", None) else None
            data_list.append({
                "material_id": doc.material_id,
                "formula": doc.formula_pretty,
                "formation_energy_per_atom": getattr(doc, "formation_energy_per_atom", None),
                "band_gap": getattr(doc, "band_gap", None),
                "density": getattr(doc, "density", None),
                "energy_above_hull": getattr(doc, "energy_above_hull", None),
                "volume": getattr(doc, "volume", None),
                "structure": struct_dict,
            })
        except Exception as e:
            print(f"⚠️ Skipping {getattr(doc, 'material_id', 'unknown')}: {e}")

print(f"\nCollected {len(data_list)} records")

Fetching data from Materials Project...


Retrieving SummaryDoc documents: 100%|██████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<?, ?it/s]
Processing materials: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<?, ?it/s]


Collected 2 records
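Three IDs were requested but only two records came back (in a different order), and no error was raised for the missing one. A small stdlib check makes such gaps visible; `find_missing_ids` is a hypothetical helper, not part of `mp_api`:

```python
def find_missing_ids(requested, returned):
    """Return the requested IDs that are absent from the results, in request order."""
    returned_set = {str(mid) for mid in returned}
    return [mid for mid in requested if str(mid) not in returned_set]


# Example with the IDs shown in this notebook's outputs
requested = ["mp-22862", "mp-633688", "mp-3799"]
returned = ["mp-3799", "mp-22862"]
print(find_missing_ids(requested, returned))  # ['mp-633688']
```

In the notebook itself this would be `find_missing_ids(material_ids, [d["material_id"] for d in data_list])`.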





### Save Downloaded Data

In [7]:
# Save to JSON file
with open("mp_training_data.json", "w") as f:
    json.dump(data_list, f, indent=2)

print("✅ Saved to mp_training_data.json")

✅ Saved to mp_training_data.json
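The saved JSON can be reloaded later without another API call. Because `structure` may be `None` for some documents (as noted in the `fields` list), it helps to drop those records before graph building. A minimal stdlib sketch; `load_records_with_structures` is a hypothetical helper:

```python
import json


def load_records_with_structures(path):
    """Load records saved with json.dump and drop those without a structure."""
    with open(path, "r") as f:
        records = json.load(f)
    kept = [r for r in records if r.get("structure") is not None]
    n_dropped = len(records) - len(kept)
    return kept, n_dropped
```

Usage: `records, n_dropped = load_records_with_structures("mp_training_data.json")`, then `pd.DataFrame(records)` as in the next cell.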


### Create DataFrame for Easy Access

In [8]:
# Create DataFrame
df_train = pd.DataFrame(data_list)
print(f"DataFrame shape: {df_train.shape}")
df_train.head()

DataFrame shape: (2, 8)


|   | material_id | formula | formation_energy_per_atom | band_gap | density | energy_above_hull | volume | structure |
|---|---|---|---|---|---|---|---|---|
| 0 | mp-3799 | GdSF | -3.292371 | 0.0 | 6.864023 | 0.0 | 100.790026 | {'@module': 'pymatgen.core.structure', '@class... |
| 1 | mp-22862 | NaCl | -2.038401 | 5.0037 | 2.224545 | 0.0 | 43.625325 | {'@module': 'pymatgen.core.structure', '@class... |


## Part 3: Build Neighbor Lists Using ASE

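We use ASE (Atomic Simulation Environment) to compute neighbor lists. ASE's `neighbor_list` respects periodic boundary conditions, so an atom's neighbors include periodic images of atoms in adjacent unit cells.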

In [9]:
# Import required libraries
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from ase.neighborlist import neighbor_list

In [10]:
# Set cutoff radius for neighbor detection (in Angstroms)
cutoff_radius = 4.0

print(f"Using cutoff radius: {cutoff_radius} Å")

Using cutoff radius: 4.0 Å


### Example: Build a Neighbor List for One Structure (NaCl)

In [27]:
# Pick one structure from the DataFrame (row 1 is NaCl, a small 2-atom cell)
example_idx = 1
s0 = df_train["structure"].iloc[example_idx]

# Ensure it's a pymatgen Structure
if isinstance(s0, dict):
    pmg_struct = Structure.from_dict(s0)
elif isinstance(s0, Structure):
    pmg_struct = s0
else:
    raise TypeError("Structure must be a dict or pymatgen Structure")

print(f"Material: {df_train['material_id'].iloc[example_idx]}")
print(f"Formula: {df_train['formula'].iloc[example_idx]}")
print(f"Number of atoms: {len(pmg_struct.sites)}")

Material: mp-22862
Formula: NaCl
Number of atoms: 2


In [28]:
# Convert to ASE Atoms
atoms = AseAtomsAdaptor.get_atoms(pmg_struct)
print(f"ASE Atoms object created with {len(atoms)} atoms")
# cellpar() returns [a, b, c, alpha, beta, gamma] in Å and degrees
print(f"Cell parameters: {atoms.cell.cellpar()}")

ASE Atoms object created with 2 atoms
Cell parameters: [ 3.95140292  3.95140217  3.951402   59.9999847  59.99999098 59.99999738]




In [29]:
# Build neighbor list
# Returns: i (source), j (neighbor), d (distances in Å)
edge_src, edge_dst, edge_len = neighbor_list(
    "ijd", 
    atoms, 
    cutoff=cutoff_radius, 
    self_interaction=False
)

print(f"Total neighbor pairs: {len(edge_src)}")

Total neighbor pairs: 36
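With only two atoms in the cell, 36 pairs may look surprising. The neighbor list runs over periodic images, so each atom sees 6 unlike neighbors (Na–Cl, about 2.79 Å) and 12 like neighbors (about 3.95 Å) inside the 4.0 Å cutoff, giving 2 × 18 = 36 directed pairs. The brute-force sketch below makes the image search explicit. It is a hypothetical O(N² × images) illustration of what `neighbor_list("ijd", ...)` computes, not ASE's actual algorithm, and the lattice constant `a = 5.588` Å is an assumed value consistent with the cell lengths printed above:

```python
import itertools

import numpy as np


def brute_force_neighbor_list(cell, frac_positions, cutoff):
    """Return directed (i, j, d) arrays for all pairs closer than `cutoff`,
    including periodic images and excluding each atom's zero-shift self-pair."""
    cell = np.asarray(cell, dtype=float)
    frac = np.asarray(frac_positions, dtype=float) % 1.0  # wrap into [0, 1)
    cart = frac @ cell

    # Perpendicular width of the cell along each lattice vector: volume / |a_j x a_k|
    volume = abs(np.linalg.det(cell))
    widths = [
        volume / np.linalg.norm(np.cross(cell[(k + 1) % 3], cell[(k + 2) % 3]))
        for k in range(3)
    ]
    # Enough image shifts along each axis to cover the cutoff sphere
    n_max = [int(np.ceil(cutoff / w)) for w in widths]

    src, dst, dist = [], [], []
    for shift in itertools.product(*(range(-n, n + 1) for n in n_max)):
        offset = np.array(shift, dtype=float) @ cell
        for i, ri in enumerate(cart):
            for j, rj in enumerate(cart):
                if i == j and shift == (0, 0, 0):
                    continue  # same effect as self_interaction=False
                d = np.linalg.norm(rj + offset - ri)
                if d < cutoff:
                    src.append(i)
                    dst.append(j)
                    dist.append(d)
    return np.array(src), np.array(dst), np.array(dist)


# Rock-salt primitive cell (assumed conventional lattice constant a = 5.588 Å)
a = 5.588
cell = 0.5 * a * np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
frac_positions = [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]  # Na, Cl

i_src, j_dst, d_len = brute_force_neighbor_list(cell, frac_positions, cutoff=4.0)
print(len(i_src))  # 36
```

For real workloads ASE's `neighbor_list` is far faster; a loop like this is only useful as a cross-check on small cells.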


### Inspect Neighbor Pairs

In [30]:
# Get chemical symbols for all atoms
symbols = atoms.get_chemical_symbols()
print(f"Elements present: {set(symbols)}")
print(f"Atomic symbols: {symbols}")

Elements present: {'Na', 'Cl'}
Atomic symbols: ['Na', 'Cl']


In [31]:
# Print first few neighbor pairs with element names
print("\nFirst 20 neighbor pairs:")
print(f"{'Edge':<6} {'Source':<15} {'Target':<15} {'Distance (Å)':>12}")
print("-" * 50)

for i in range(min(20, len(edge_src))):
    src_idx = edge_src[i]
    dst_idx = edge_dst[i]
    # Build "index (element)" labels first so the column width applies to the whole label
    src_label = f"{src_idx} ({symbols[src_idx]})"
    dst_label = f"{dst_idx} ({symbols[dst_idx]})"
    dist = edge_len[i]

    print(f"{i:<6} {src_label:<15} {dst_label:<15} {dist:>12.3f}")


First 20 neighbor pairs:
Edge   Source          Target          Distance (Å)
--------------------------------------------------
0      0 (Na)          1 (Cl)                 2.794
1      0 (Na)          1 (Cl)                 2.794
2      0 (Na)          0 (Na)                 3.951
3      0 (Na)          1 (Cl)                 2.794
4      0 (Na)          0 (Na)                 3.951
5      0 (Na)          0 (Na)                 3.951
6      0 (Na)          0 (Na)                 3.951
7      0 (Na)          0 (Na)                 3.951
8      0 (Na)          0 (Na)                 3.951
9      0 (Na)          1 (Cl)                 2.794
10     0 (Na)          0 (Na)                 3.951
11     0 (Na)          1 (Cl)                 2.794
12     0 (Na)          0 (Na)                 3.951
13     0 (Na)          0 (Na)                 3.951
14     0 (Na)          0 (Na)                 3.951
15     0 (Na)          0 (Na)                 3.951
16     0 (Na)          1 (Cl)                 2.794
17     0 (Na)          0 (Na)                 3.951
18     1 (Cl)          1 (Cl)                 3.951
19     1 (Cl)          0 (Na)                 2.794


### Distance Statistics

In [32]:
# Distance statistics
print("\n--- Distance Statistics (Å) ---")
print(f"Min distance:  {np.min(edge_len):.3f}")
print(f"Max distance:  {np.max(edge_len):.3f}")
print(f"Mean distance: {np.mean(edge_len):.3f}")
print(f"Std distance:  {np.std(edge_len):.3f}")


--- Distance Statistics (Å) ---
Min distance:  2.794
Max distance:  3.951
Mean distance: 3.566
Std distance:  0.546
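These statistics follow directly from the two distance shells: 6 of each atom's 18 neighbors sit at the Na–Cl distance and 12 at the like-atom distance, which in rock salt is √2 times longer. A quick numeric check, taking the 3.9514 Å primitive cell length from the output above as the like-atom distance:

```python
import numpy as np

d_same = 3.9514                # Na–Na / Cl–Cl distance (primitive cell length)
d_nacl = d_same / np.sqrt(2)   # Na–Cl distance in rock salt, about 2.794 Å

# Per atom: 6 unlike neighbors and 12 like neighbors within 4.0 Å
distances = np.array([d_nacl] * 6 + [d_same] * 12)

print(f"{distances.mean():.3f} {distances.std():.3f}")  # 3.566 0.546
```

Because both atoms have identical shells, the per-atom statistics equal those of the full edge list.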


In [33]:
# Neighbor count per atom (degree distribution)
unique, counts = np.unique(edge_src, return_counts=True)

print("\n--- Neighbor Count Per Atom ---")
for i, c in zip(unique, counts):
    print(f"Atom {i:>2} ({symbols[i]}): {c:>2} neighbors")

print(f"\nAverage neighbors per atom (within {cutoff_radius} Å): {np.mean(counts):.2f}")


--- Neighbor Count Per Atom ---
Atom  0 (Na): 18 neighbors
Atom  1 (Cl): 18 neighbors

Average neighbors per atom (within 4.0 Å): 18.00
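The count of 18 is the graph degree at a 4.0 Å cutoff, not the chemical coordination number (6 for rock salt), because the cutoff also captures the shell of like atoms. If first-shell coordination is needed, one option is to count, per atom, only the neighbors within a small tolerance of that atom's shortest distance. A sketch with a hypothetical helper `first_shell_counts` and an assumed tolerance of 0.1 Å, run on synthetic arrays shaped like the NaCl edges above:

```python
import numpy as np


def first_shell_counts(edge_src, edge_len, num_atoms, tol=0.1):
    """Count, for each atom, the neighbors within `tol` Å of its shortest distance."""
    edge_src = np.asarray(edge_src)
    edge_len = np.asarray(edge_len)
    counts = np.zeros(num_atoms, dtype=int)
    for atom in range(num_atoms):
        d = edge_len[edge_src == atom]
        if d.size:
            counts[atom] = int(np.sum(d <= d.min() + tol))
    return counts


# Synthetic edges: each of 2 atoms has 6 neighbors at 2.794 Å and 12 at 3.951 Å
edge_src = np.repeat([0, 1], 18)
edge_len = np.tile([2.794] * 6 + [3.951] * 12, 2)
print(first_shell_counts(edge_src, edge_len, num_atoms=2))  # [6 6]
```

In the notebook this would be `first_shell_counts(edge_src, edge_len, len(atoms))`.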


## Part 4: Create Graph Data Structure

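Next, we build the full graph representation: node features come from per-element embeddings, and edge features combine each interatomic distance with a Gaussian radial basis expansion.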

In [34]:
# Import PyTorch for tensors
import torch

In [35]:
# Load atom embeddings
with open("atom_embedding.json", "r") as f:
    ATOM_EMB = json.load(f)

EMB_LEN = ATOM_EMB["embedding_length"]
EMB_TAB = ATOM_EMB["embeddings"]

print(f"Loaded embeddings for {len(EMB_TAB)} elements")
print(f"Embedding length: {EMB_LEN}")

Loaded embeddings for 118 elements
Embedding length: 186


In [36]:
def get_atom_embedding(symbol: str):
    """Return the embedding vector for an atomic symbol, or a zero vector if the element is not in the table."""
    vec = EMB_TAB.get(symbol)
    if vec is None:
        return [0.0] * EMB_LEN
    return vec
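`get_atom_embedding` silently returns a zero vector for any element missing from the table, so a malformed `atom_embedding.json` can go unnoticed. A quick consistency check, assuming the file layout implied by the loading cell (a top-level `embedding_length` and an `embeddings` mapping from element symbol to list); the three-element table below is a hypothetical stand-in:

```python
import json


def check_embedding_table(emb, required_symbols=()):
    """Return (symbols with wrong vector length, required symbols that are missing)."""
    length = emb["embedding_length"]
    table = emb["embeddings"]
    wrong_length = sorted(sym for sym, vec in table.items() if len(vec) != length)
    missing = sorted(set(required_symbols) - set(table))
    return wrong_length, missing


# Hypothetical miniature table (the real file has 118 elements with 186 features each)
example = json.loads(
    '{"embedding_length": 3,'
    ' "embeddings": {"Na": [1, 0, 0], "Cl": [0, 1, 0], "Gd": [0, 1]}}'
)
print(check_embedding_table(example, required_symbols=["Na", "Cl", "S", "F"]))
# (['Gd'], ['F', 'S'])
```

Against the real table this would be `check_embedding_table(ATOM_EMB, required_symbols=set(symbols))`.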

### Edge Feature Encoding: Radial Basis Functions

In [37]:
def rbf(distances: np.ndarray, r_min=0.5, r_max=6.0, bins=32, gamma=None):
    """
    Compute Gaussian radial basis expansion for distances.
    Returns array (E, bins).
    """
    centers = np.linspace(r_min, r_max, bins, dtype=np.float32)
    d = distances.reshape(-1, 1)
    if gamma is None:
        spacing = (r_max - r_min) / max(1, bins - 1)
        sigma = spacing if spacing > 0 else 1.0
        gamma = 1.0 / (2.0 * sigma * sigma)
    return np.exp(-gamma * (d - centers) ** 2)
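In formula form, for an edge of length $d$ the function returns one Gaussian per center $\mu_k$:

$$
\mu_k = r_{\min} + k\,\Delta, \qquad \Delta = \frac{r_{\max} - r_{\min}}{\text{bins} - 1}, \qquad k = 0, \dots, \text{bins} - 1
$$

$$
e_k(d) = \exp\!\left(-\gamma\,(d - \mu_k)^2\right), \qquad \gamma = \frac{1}{2\Delta^2}
$$

With the defaults used in this notebook (`r_min=0.5`, `r_max=6.0`, `bins=32`), $\Delta = 5.5/31 \approx 0.177$ Å and $\gamma \approx 15.9$ Å⁻². Since the neighbor cutoff is 4.0 Å, the last nine centers (above about 4.5 Å) stay below 0.01 for every edge; tying `r_max` to the cutoff is one way to avoid those near-empty bins.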

### Complete Graph Builder Function

In [38]:
def build_graph_from_row(row, cutoff=4.0, rbf_bins=32, target="formation_energy_per_atom"):
    """
    Convert one DataFrame row into a graph representation.

    Args:
        row: DataFrame row with a "structure" entry (dict or pymatgen Structure)
        cutoff: neighbor cutoff radius in Å
        rbf_bins: number of Gaussian RBF centers for edge features
        target: name of the column used as the regression target

    Returns:
        Dictionary with keys:
        - material_id: material identifier
        - x: node features (N, F)
        - edge_index: directed edge indices (2, E); each pair appears in both directions
        - edge_attr: edge features (E, 1 + rbf_bins): [distance, RBF expansion]
        - y: target value as a tensor of shape (1,), or None if missing
        - num_nodes: number of nodes
    """
    s = row["structure"]
    pmg = Structure.from_dict(s) if isinstance(s, dict) else s
    atoms = AseAtomsAdaptor.get_atoms(pmg)

    # Build neighbor list (edges)
    i_src, j_dst, d_len = neighbor_list("ijd", atoms, cutoff=cutoff, self_interaction=False)
    i_src = np.asarray(i_src, dtype=np.int64)
    j_dst = np.asarray(j_dst, dtype=np.int64)
    d_len = np.asarray(d_len, dtype=np.float32)

    # Node features from atom embeddings
    symbols = atoms.get_chemical_symbols()
    x = np.vstack([np.array(get_atom_embedding(sym), dtype=np.float32) for sym in symbols])

    # Edge features: [distance, RBF]
    edge_dist = d_len.reshape(-1, 1)
    edge_rbf = rbf(d_len, r_min=0.5, r_max=6.0, bins=rbf_bins)
    edge_attr = np.hstack([edge_dist, edge_rbf]).astype(np.float32)

    # Target value: one fixed property for every graph, so labels are comparable.
    # pd.notna also rejects NaN, which pandas uses for missing numeric values.
    y = None
    value = row.get(target)
    if pd.notna(value):
        y = float(value)

    return {
        "material_id": row.get("material_id", None),
        "x": torch.from_numpy(x),
        "edge_index": torch.from_numpy(np.vstack([i_src, j_dst])),
        "edge_attr": torch.from_numpy(edge_attr),
        "y": None if y is None else torch.tensor([y], dtype=torch.float32),
        "num_nodes": x.shape[0],
    }

### Build Graph for Example Structure

In [39]:
# Build graph for the first material
graph_example = build_graph_from_row(df_train.iloc[0], cutoff=cutoff_radius, rbf_bins=32)

print(f"Material ID: {graph_example['material_id']}")
print(f"Number of nodes: {graph_example['num_nodes']}")
print(f"Number of edges: {graph_example['edge_index'].shape[1]}")
print(f"Node feature shape: {graph_example['x'].shape}")
print(f"Edge feature shape: {graph_example['edge_attr'].shape}")
print(f"Target value: {graph_example['y']}")

Material ID: mp-3799
Number of nodes: 6
Number of edges: 96
Node feature shape: torch.Size([6, 186])
Edge feature shape: torch.Size([96, 33])
Target value: tensor([-3.2924])


### Build Graphs for All Materials

In [40]:
# Build graphs for all materials in the dataset
print("Building graphs for all materials...")
graphs = []

for idx, row in tqdm(df_train.iterrows(), total=len(df_train), desc="Building graphs"):
    try:
        graph = build_graph_from_row(row, cutoff=cutoff_radius, rbf_bins=32)
        graphs.append(graph)
    except Exception as e:
        print(f"\n⚠️ Error processing {row.get('material_id', idx)}: {e}")

print(f"\n✅ Successfully built {len(graphs)} graphs")

Building graphs for all materials...


Building graphs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 26.41it/s]


✅ Successfully built 2 graphs





### Save Graph Dataset

In [41]:
# Save graphs using PyTorch
torch.save(graphs, "graph_dataset.pt")
print("✅ Saved graph dataset to 'graph_dataset.pt'")

✅ Saved graph dataset to 'graph_dataset.pt'


### Dataset Statistics

In [42]:
# Compute dataset statistics
num_nodes_list = [g['num_nodes'] for g in graphs]
num_edges_list = [g['edge_index'].shape[1] for g in graphs]

print("\n=== Dataset Statistics ===")
print(f"Total graphs: {len(graphs)}")
print(f"\nNodes per graph:")
print(f"  Min:  {np.min(num_nodes_list)}")
print(f"  Max:  {np.max(num_nodes_list)}")
print(f"  Mean: {np.mean(num_nodes_list):.2f}")
print(f"\nEdges per graph:")
print(f"  Min:  {np.min(num_edges_list)}")
print(f"  Max:  {np.max(num_edges_list)}")
print(f"  Mean: {np.mean(num_edges_list):.2f}")


=== Dataset Statistics ===
Total graphs: 2

Nodes per graph:
  Min:  2
  Max:  6
  Mean: 4.00

Edges per graph:
  Min:  36
  Max:  96
  Mean: 66.00
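For mini-batch GNN training, graphs of different sizes are usually merged into one disconnected graph: node and edge features are concatenated, each graph's `edge_index` is shifted by the number of nodes that precede it, and a `batch` vector records which graph each node belongs to. PyTorch Geometric's `Batch` does this internally; the sketch below shows the idea with NumPy arrays (for the tensors stored here, call `.numpy()` first). `collate_graphs` is a hypothetical helper:

```python
import numpy as np


def collate_graphs(graphs):
    """Merge graph dicts into one disjoint-union graph with a per-node batch vector."""
    xs, edge_indices, edge_attrs, batch = [], [], [], []
    offset = 0
    for graph_id, g in enumerate(graphs):
        n = g["num_nodes"]
        xs.append(g["x"])
        edge_indices.append(g["edge_index"] + offset)  # renumber nodes into the merged graph
        edge_attrs.append(g["edge_attr"])
        batch.append(np.full(n, graph_id, dtype=np.int64))
        offset += n
    return {
        "x": np.concatenate(xs, axis=0),
        "edge_index": np.concatenate(edge_indices, axis=1),
        "edge_attr": np.concatenate(edge_attrs, axis=0),
        "batch": np.concatenate(batch),
    }


# Two toy graphs: 2 nodes / 2 edges and 3 nodes / 1 edge (4 node and 5 edge features)
g1 = {"num_nodes": 2, "x": np.zeros((2, 4)), "edge_index": np.array([[0, 1], [1, 0]]),
      "edge_attr": np.ones((2, 5))}
g2 = {"num_nodes": 3, "x": np.ones((3, 4)), "edge_index": np.array([[0], [2]]),
      "edge_attr": np.ones((1, 5))}
merged = collate_graphs([g1, g2])
print(merged["edge_index"].tolist())  # [[0, 1, 2], [1, 0, 4]]
print(merged["batch"].tolist())       # [0, 0, 1, 1, 1]
```

The `batch` vector is what a pooling layer uses to aggregate node embeddings back into one vector per crystal.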


## Summary

We have successfully:
1. ✅ Loaded material IDs from CSV
2. ✅ Downloaded crystal structures from Materials Project
3. ✅ Built neighbor lists using ASE
4. ✅ Created graph representations with node and edge features
5. ✅ Saved graph dataset for GNN training

The graphs are now ready to be used for training graph neural networks!