# Crystal Graph Convolutional Neural Networks (CGCNN)

## Introduction

This notebook implements the **Crystal Graph Convolutional Neural Network (CGCNN)** architecture, a deep learning framework designed for predicting material properties directly from crystal structures.

### Reference
**Paper**: Xie, T., & Grossman, J. C. (2018). "Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties." *Physical Review Letters*, 120(14), 145301.

---

## Theoretical Background

### 1. Graph Representation of Crystals

In CGCNN, crystal structures are represented as **multigraphs** where:
- **Nodes (Vertices)**: Represent atoms in the crystal
- **Edges**: Represent bonds/interactions between atoms
- **Node Features**: Atom type, electron configuration, electronegativity, etc.
- **Edge Features**: Distance, bond type, coordination information

### 2. Convolution on Crystal Graphs

Unlike standard CNNs that operate on regular grids (images), crystal graphs have:
- Variable number of neighbors per atom
- Non-Euclidean structure
- Periodic boundary conditions

The convolutional operation aggregates information from neighboring atoms to update each atom's feature representation.
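
To make the role of periodic boundary conditions concrete, the following is a minimal sketch that lists the neighbors of one atom by looping explicitly over periodic images of the unit cell. It is not the data pipeline used later in this notebook (which relies on pymatgen). The helper name `periodic_neighbors`, the restriction to a cubic cell, and fractional coordinates in $[0, 1)$ are assumptions made purely for illustration.

```python
import itertools
import math


def periodic_neighbors(frac_coords, a, center_idx, radius):
    """List (atom_index, distance) pairs within `radius` of one atom.

    Hypothetical helper for a cubic cell of edge length `a`. Atoms are given
    in fractional coordinates assumed to lie in [0, 1).
    """
    n_img = math.ceil(radius / a)  # image shells to scan along each axis
    center = frac_coords[center_idx]
    found = []
    for j, (x, y, z) in enumerate(frac_coords):
        for ix, iy, iz in itertools.product(range(-n_img, n_img + 1), repeat=3):
            if j == center_idx and (ix, iy, iz) == (0, 0, 0):
                continue  # skip the atom itself, but keep its periodic images
            d = a * math.dist(center, (x + ix, y + iy, z + iz))
            if d <= radius:
                found.append((j, d))
    return sorted(found, key=lambda pair: pair[1])


# Simple cubic lattice: one atom per cell, lattice constant 1.0
nbrs = periodic_neighbors([(0.0, 0.0, 0.0)], a=1.0, center_idx=0, radius=1.1)
print(len(nbrs))  # 6 nearest neighbors, all periodic images of atom 0
```

Even with a single atom in the cell, the atom has neighbors: they are its own periodic images, which is why the same atom index can appear several times in a neighbor list.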

### 3. Message Passing Framework

The convolution layer implements a **message passing** scheme:

$$
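\mathbf{v}_i^{(t+1)} = \mathbf{v}_i^{(t)} + \sum_{j \in \mathcal{N}(i)} \sigma\left(\mathbf{W}_f^{(t)} \mathbf{z}_{ij}^{(t)} + \mathbf{b}_f^{(t)}\right) \odot \mathbf{g}\left(\mathbf{W}_s^{(t)} \mathbf{z}_{ij}^{(t)} + \mathbf{b}_s^{(t)}\right)
$$

Where:
- $\mathbf{v}_i^{(t)}$: Feature vector of atom $i$ at layer $t$
- $\mathcal{N}(i)$: Set of neighboring atoms of atom $i$
- $\mathbf{z}_{ij}^{(t)} = \mathbf{v}_i^{(t)} \oplus \mathbf{v}_j^{(t)} \oplus \mathbf{u}_{ij}$: Concatenation of central atom, neighbor, and edge features
- $\mathbf{W}_f^{(t)}, \mathbf{b}_f^{(t)}$ and $\mathbf{W}_s^{(t)}, \mathbf{b}_s^{(t)}$: Learned weights and biases of the gate and core branches at layer $t$; each maps $\mathbf{z}_{ij}^{(t)}$ back to the atom feature dimension
- $\sigma$: Sigmoid activation (gate function)
- $\mathbf{g}$: Nonlinearity of the core message branch (softplus in this implementation)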
- $\odot$: Element-wise multiplication
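
As a small worked example of the gated term, the sketch below evaluates $\sigma(f)\,\text{softplus}(c)$ for a single feature channel. The names `f` and `c` stand for the gate and core pre-activations of one neighbor, and all numbers are arbitrary illustration values.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softplus(x):
    return math.log1p(math.exp(x))


def gated_message(f, c):
    """One channel of a message: a gate in (0, 1) times a positive core value."""
    return sigmoid(f) * softplus(c)


# Same core value, three different gate pre-activations
for f in (-4.0, 0.0, 4.0):
    print(f"f={f:+.1f}  gate={sigmoid(f):.3f}  message={gated_message(f, 1.0):.3f}")

# Residual update of one scalar feature from two neighbors:
# the neighbor with f = -4 is almost switched off, the one with f = +4 passes through
v_old = 0.3
v_new = v_old + gated_message(-4.0, 1.0) + gated_message(4.0, 1.0)
print(f"v_old={v_old:.3f}  v_new={v_new:.3f}")
```

A strongly negative gate pre-activation suppresses a neighbor's contribution almost entirely, which is how the layer can learn to weight interactions differently.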

### 4. Architecture Components

#### a) **Embedding Layer**
Projects initial atom features to a learnable hidden representation:
$$
\mathbf{v}_i^{(0)} = \mathbf{W}_0 \mathbf{x}_i + \mathbf{b}_0
$$

#### b) **Convolutional Layers**
Multiple graph convolution layers that progressively refine atomic representations by aggregating local neighborhood information.

#### c) **Pooling Layer**
Aggregates all atomic features in a crystal to obtain a crystal-level representation:
$$
\mathbf{v}_{\text{crystal}} = \frac{1}{N} \sum_{i=1}^{N} \mathbf{v}_i^{(T)}
$$
(Mean pooling over all atoms)
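
A minimal sketch of this pooling step for a batch of crystals is shown below. It assumes a per-atom `crystal_id` tensor rather than the list of index tensors used by the `pooling` method in this notebook, and it assumes every crystal contains at least one atom. The function name `mean_pool` is hypothetical.

```python
import torch


def mean_pool(atom_fea, crystal_id, n_crystals):
    """Average atom features per crystal.

    atom_fea:   (N, d) float tensor of atom features
    crystal_id: (N,) long tensor giving the crystal index of each atom
    """
    summed = torch.zeros(n_crystals, atom_fea.shape[1], dtype=atom_fea.dtype)
    summed.index_add_(0, crystal_id, atom_fea)  # per-crystal sum of atom features
    counts = torch.bincount(crystal_id, minlength=n_crystals).unsqueeze(1)
    return summed / counts.to(atom_fea.dtype)


torch.manual_seed(0)
atom_fea = torch.randn(5, 4)
crystal_id = torch.tensor([0, 0, 0, 1, 1])  # 3 atoms in crystal 0, 2 in crystal 1
pooled = mean_pool(atom_fea, crystal_id, n_crystals=2)

# Shuffling the atoms together with their crystal labels leaves the result unchanged
perm = torch.randperm(5)
pooled_perm = mean_pool(atom_fea[perm], crystal_id[perm], n_crystals=2)
print(torch.allclose(pooled, pooled_perm, atol=1e-6))
```

The final check illustrates why mean pooling makes the crystal-level representation independent of the order in which atoms are listed.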

#### d) **Fully Connected Layers**
Maps the pooled crystal representation to the target property (e.g., formation energy, band gap).

### 5. Key Innovations

1. **Gating Mechanism**: The sigmoid gate ($\sigma$) allows the network to learn which neighbor information is relevant
2. **Residual Connections**: Skip connections ($\mathbf{v}_i^{(t)} +$ ...) help with gradient flow
3. **Batch Normalization**: Stabilizes training and improves convergence
4. **Invariance to Atom Ordering**: Pooling ensures the prediction is invariant to the order of atoms

---

## Implementation

### Import Required Libraries

In [10]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.nn.functional as F

print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")

PyTorch Version: 2.9.0+cu126
CUDA Available: True


---

## Convolutional Layer

### Theory: Graph Convolution Operation

The `ConvLayer` performs the core graph convolution operation. For each atom, it:

1. **Gathers neighbor features**: Retrieves features from all neighboring atoms
2. **Concatenates features**: Combines central atom, neighbor, and edge features
3. **Applies gating**: Uses a sigmoid gate to control information flow
4. **Aggregates**: Sums the gated messages from all neighbors
5. **Updates**: Adds the aggregated message to the original features (residual connection)
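
The first two steps rely on tensor indexing and broadcasting, which can be hard to picture. The sketch below traces the shapes on toy tensors; the sizes and index values are arbitrary and chosen only for illustration.

```python
import torch

N, M, d, d_edge = 4, 3, 2, 5  # toy sizes: atoms, neighbors, atom dim, edge dim
atom_fea = torch.arange(N * d, dtype=torch.float32).view(N, d)
nbr_idx = torch.tensor([[1, 2, 3],
                        [0, 2, 3],
                        [0, 1, 3],
                        [0, 1, 2]])       # (N, M) neighbor indices
edge_fea = torch.zeros(N, M, d_edge)      # placeholder edge features

# Step 1: advanced indexing gathers one feature row per (atom, neighbor) pair
nbr_atom_fea = atom_fea[nbr_idx]          # (N, M, d)

# Step 2: repeat the central atom along the neighbor axis, then concatenate
center = atom_fea.unsqueeze(1).expand(N, M, d)
z = torch.cat([center, nbr_atom_fea, edge_fea], dim=2)

print(nbr_atom_fea.shape)  # torch.Size([4, 3, 2])
print(z.shape)             # torch.Size([4, 3, 9])
```

The last dimension of `z` is `2*d + d_edge`, which is the input size the linear layer of the convolution has to accept.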

### Mathematical Formulation

Given:
- $\mathbf{v}_i \in \mathbb{R}^{d}$: Central atom feature (atom_fea_len)
- $\mathbf{v}_j \in \mathbb{R}^{d}$: Neighbor atom feature
- $\mathbf{u}_{ij} \in \mathbb{R}^{d'}$: Edge feature (nbr_fea_len)

The convolution operation:

$$
\mathbf{z}_{ij} = [\mathbf{v}_i \oplus \mathbf{v}_j \oplus \mathbf{u}_{ij}] \in \mathbb{R}^{2d + d'}
$$

$$
\mathbf{f}_{ij}, \mathbf{c}_{ij} = \text{split}\left(\text{BN}\left(\mathbf{W} \mathbf{z}_{ij} + \mathbf{b}\right)\right)
$$

$$
\mathbf{m}_{ij} = \sigma(\mathbf{f}_{ij}) \odot \text{softplus}(\mathbf{c}_{ij})
$$

$$
\mathbf{v}_i' = \text{softplus}\left(\mathbf{v}_i + \text{BN}\left(\sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}\right)\right)
$$

In [11]:
class ConvLayer(nn.Module):
    """
    Convolutional operation on graphs
    
    This layer implements the graph convolution operation that aggregates
    information from neighboring atoms using a gating mechanism.
    """
    def __init__(self, atom_fea_len, nbr_fea_len):
        """
        Initialize ConvLayer.

        Parameters
        ----------
        atom_fea_len: int
          Number of atom hidden features.
        nbr_fea_len: int
          Number of bond features.
        """
        super(ConvLayer, self).__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len
        
        # Fully connected layer: maps concatenated features to gated features
        # Input: [atom_i, atom_j, edge_ij] -> Output: [filter, core]
        self.fc_full = nn.Linear(2*self.atom_fea_len+self.nbr_fea_len,
                                 2*self.atom_fea_len)
        
        # Activation functions
        self.sigmoid = nn.Sigmoid()      # Gate function (0 to 1)
        self.softplus1 = nn.Softplus()   # Smooth approximation of ReLU
        self.softplus2 = nn.Softplus()
        
        # Batch normalization layers for training stability
        self.bn1 = nn.BatchNorm1d(2*self.atom_fea_len)
        self.bn2 = nn.BatchNorm1d(self.atom_fea_len)

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        """
        Forward pass

        N: Total number of atoms in the batch
        M: Max number of neighbors

        Parameters
        ----------
        atom_in_fea: torch.Tensor shape (N, atom_fea_len)
          Atom hidden features before convolution
        nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
          Bond features of each atom's M neighbors
        nbr_fea_idx: torch.LongTensor shape (N, M)
          Indices of M neighbors of each atom

        Returns
        -------
        atom_out_fea: torch.Tensor shape (N, atom_fea_len)
          Atom hidden features after convolution
        """
        N, M = nbr_fea_idx.shape
        
        # Step 1: Gather neighbor features using indices
        # atom_nbr_fea: (N, M, atom_fea_len)
        atom_nbr_fea = atom_in_fea[nbr_fea_idx, :]
        
        # Step 2: Concatenate [central_atom, neighbor_atom, edge] features
        # total_nbr_fea: (N, M, 2*atom_fea_len + nbr_fea_len)
        total_nbr_fea = torch.cat(
            [atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len),
             atom_nbr_fea, nbr_fea], dim=2)
        
        # Step 3: Apply linear transformation and batch normalization
        total_gated_fea = self.fc_full(total_nbr_fea)
        total_gated_fea = self.bn1(total_gated_fea.view(
            -1, self.atom_fea_len*2)).view(N, M, self.atom_fea_len*2)
        
        # Step 4: Split into filter (gate) and core (message) components
        nbr_filter, nbr_core = total_gated_fea.chunk(2, dim=2)
        
        # Step 5: Apply activations
        nbr_filter = self.sigmoid(nbr_filter)     # Gate: controls information flow
        nbr_core = self.softplus1(nbr_core)       # Message: actual information
        
        # Step 6: Element-wise gating and sum over neighbors
        nbr_sumed = torch.sum(nbr_filter * nbr_core, dim=1)
        nbr_sumed = self.bn2(nbr_sumed)
        
        # Step 7: Residual connection + activation
        out = self.softplus2(atom_in_fea + nbr_sumed)
        
        return out

---

## Complete CGCNN Architecture

### Theory: End-to-End Network

The complete CGCNN architecture consists of:

1. **Embedding Layer**: Projects raw atom features to learnable representations
2. **Stacked Convolution Layers**: Multiple graph convolutions to capture multi-hop neighborhood information
3. **Pooling Layer**: Aggregates atom-level features to crystal-level representation
4. **Fully Connected Layers**: Maps crystal features to target properties
5. **Output Layer**: 
   - Regression: Single value (e.g., formation energy)
   - Classification: Log-softmax over classes (e.g., metal/non-metal)

### Network Flow

```
Input: (atom_features, edge_features, edge_indices, crystal_mapping)
   |
   v
[Embedding Layer] ───> atom_fea_len dimensions
   |
   v
[Conv Layer 1] ──────> Message passing iteration 1
   |
   v
[Conv Layer 2] ──────> Message passing iteration 2
   |
   v
   ...
   |
   v
[Conv Layer n] ──────> Message passing iteration n
   |
   v
[Pooling] ───────────> Mean over all atoms per crystal
   |
   v
[FC Layer 1] ────────> h_fea_len dimensions
   |
   v
[FC Layer 2] ────────> Optional additional hidden layers
   |
   v
[Output Layer] ──────> Prediction (1 output for regression, 2 for binary classification)
```

### Hyperparameters

- `orig_atom_fea_len`: Dimension of input atom features (e.g., 92 for one-hot encoding)
- `nbr_fea_len`: Dimension of edge features (e.g., distance-based Gaussian expansion)
- `atom_fea_len`: Hidden atom feature dimension (default: 64)
- `n_conv`: Number of convolution layers (default: 3)
- `h_fea_len`: Hidden dimension after pooling (default: 128)
- `n_h`: Number of hidden fully connected layers after pooling (default: 1)
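
As a rough check on how these hyperparameters determine model size, the sketch below counts learnable parameters by hand. It assumes the layer layout used in this notebook: one linear embedding, one linear map plus two batch-norm layers per convolution, and a fully connected head. The helper names `linear_params` and `cgcnn_param_count` are hypothetical.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + bias


def cgcnn_param_count(orig_atom_fea_len, nbr_fea_len, atom_fea_len=64,
                      n_conv=3, h_fea_len=128, n_h=1, n_out=1):
    """Hypothetical helper: learnable parameters implied by the hyperparameters."""
    embedding = linear_params(orig_atom_fea_len, atom_fea_len)
    conv = (linear_params(2 * atom_fea_len + nbr_fea_len, 2 * atom_fea_len)
            + 2 * (2 * atom_fea_len)   # batch norm scale and shift on gated features
            + 2 * atom_fea_len)        # batch norm scale and shift on summed messages
    head = (linear_params(atom_fea_len, h_fea_len)
            + (n_h - 1) * linear_params(h_fea_len, h_fea_len)
            + linear_params(h_fea_len, n_out))
    return embedding + n_conv * conv + head


print(cgcnn_param_count(92, 41))           # 80833 (regression, one output)
print(cgcnn_param_count(92, 41, n_out=2))  # 80962 (binary classification)
```

Most of the parameters sit in the convolution layers, so `atom_fea_len` and `n_conv` dominate the model size.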

In [12]:
class CrystalGraphConvNet(nn.Module):
    """
    Create a crystal graph convolutional neural network for predicting total
    material properties.
    """
    def __init__(self, orig_atom_fea_len, nbr_fea_len,
                 atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1,
                 classification=False):
        """
        Initialize CrystalGraphConvNet.

        Parameters
        ----------
        orig_atom_fea_len: int
          Number of atom features in the input.
        nbr_fea_len: int
          Number of bond features.
        atom_fea_len: int
          Number of hidden atom features in the convolutional layers
        n_conv: int
          Number of convolutional layers
        h_fea_len: int
          Number of hidden features after pooling
        n_h: int
          Number of hidden layers after pooling
        classification: bool
          Whether this is a classification task
        """
        super(CrystalGraphConvNet, self).__init__()
        self.classification = classification
        
        # Embedding layer: projects input atom features to hidden space
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
        
        # Stack of convolutional layers
        self.convs = nn.ModuleList([ConvLayer(atom_fea_len=atom_fea_len,
                                    nbr_fea_len=nbr_fea_len)
                                    for _ in range(n_conv)])
        
        # Transition from convolutional to fully connected layers
        self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()
        
        # Additional fully connected hidden layers
        if n_h > 1:
            self.fcs = nn.ModuleList([nn.Linear(h_fea_len, h_fea_len)
                                      for _ in range(n_h-1)])
            self.softpluses = nn.ModuleList([nn.Softplus()
                                             for _ in range(n_h-1)])
        
        # Output layer
        if self.classification:
            self.fc_out = nn.Linear(h_fea_len, 2)  # Binary classification
        else:
            self.fc_out = nn.Linear(h_fea_len, 1)  # Regression
        
        # Classification-specific components
        if self.classification:
            self.logsoftmax = nn.LogSoftmax(dim=1)
            self.dropout = nn.Dropout()

    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """
        Forward pass

        N: Total number of atoms in the batch
        M: Max number of neighbors
        N0: Total number of crystals in the batch

        Parameters
        ----------
        atom_fea: torch.Tensor shape (N, orig_atom_fea_len)
          Atom features from atom type
        nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
          Bond features of each atom's M neighbors
        nbr_fea_idx: torch.LongTensor shape (N, M)
          Indices of M neighbors of each atom
        crystal_atom_idx: list of torch.LongTensor of length N0
          Mapping from the crystal idx to atom idx

        Returns
        -------
        prediction: torch.Tensor shape (N0, 1) for regression or (N0, 2) for classification
          Predicted property for each crystal (log-probabilities for classification)
        """
        # Step 1: Embed atom features
        atom_fea = self.embedding(atom_fea)
        
        # Step 2: Apply graph convolutions
        for conv_func in self.convs:
            atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
        
        # Step 3: Pool atom features to get crystal-level representation
        crys_fea = self.pooling(atom_fea, crystal_atom_idx)
        
        # Step 4: Apply fully connected layers
        crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
        crys_fea = self.conv_to_fc_softplus(crys_fea)
        
        # Apply dropout for classification
        if self.classification:
            crys_fea = self.dropout(crys_fea)
        
        # Additional hidden layers if specified
        if hasattr(self, 'fcs') and hasattr(self, 'softpluses'):
            for fc, softplus in zip(self.fcs, self.softpluses):
                crys_fea = softplus(fc(crys_fea))
        
        # Step 5: Output layer
        out = self.fc_out(crys_fea)
        
        # Apply log-softmax for classification
        if self.classification:
            out = self.logsoftmax(out)
        
        return out

    def pooling(self, atom_fea, crystal_atom_idx):
        """
        Pooling the atom features to crystal features

        N: Total number of atoms in the batch
        N0: Total number of crystals in the batch

        Parameters
        ----------
        atom_fea: torch.Tensor shape (N, atom_fea_len)
          Atom feature vectors of the batch
        crystal_atom_idx: list of torch.LongTensor of length N0
          Mapping from the crystal idx to atom idx

        Returns
        -------
        crys_fea: torch.Tensor shape (N0, atom_fea_len)
          Crystal feature vectors
        """
        # Verify that all atoms are accounted for
        assert sum([len(idx_map) for idx_map in crystal_atom_idx]) == \
            atom_fea.data.shape[0]
        
        # Average pooling: mean of all atom features in each crystal
        summed_fea = [torch.mean(atom_fea[idx_map], dim=0, keepdim=True)
                      for idx_map in crystal_atom_idx]
        
        return torch.cat(summed_fea, dim=0)

---

## Model Instantiation and Summary

Let's create example models for both regression and classification tasks.

In [13]:
# Example configuration for regression (e.g., formation energy prediction)
print("=" * 80)
print("REGRESSION MODEL: Predicting Formation Energy")
print("=" * 80)

model_regression = CrystalGraphConvNet(
    orig_atom_fea_len=92,    # Atom feature dimension (e.g., from atom_init.json)
    nbr_fea_len=41,          # Edge feature dimension (e.g., Gaussian expansion)
    atom_fea_len=64,         # Hidden atom features
    n_conv=3,                # 3 graph convolution layers
    h_fea_len=128,           # Hidden features after pooling
    n_h=1,                   # 1 hidden layer after pooling
    classification=False     # Regression task
)

print(model_regression)
print(f"\nTotal parameters: {sum(p.numel() for p in model_regression.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model_regression.parameters() if p.requires_grad):,}")

REGRESSION MODEL: Predicting Formation Energy
CrystalGraphConvNet(
  (embedding): Linear(in_features=92, out_features=64, bias=True)
  (convs): ModuleList(
    (0-2): 3 x ConvLayer(
      (fc_full): Linear(in_features=169, out_features=128, bias=True)
      (sigmoid): Sigmoid()
      (softplus1): Softplus(beta=1.0, threshold=20.0)
      (softplus2): Softplus(beta=1.0, threshold=20.0)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (conv_to_fc): Linear(in_features=64, out_features=128, bias=True)
  (conv_to_fc_softplus): Softplus(beta=1.0, threshold=20.0)
  (fc_out): Linear(in_features=128, out_features=1, bias=True)
)

Total parameters: 80,833
Trainable parameters: 80,833


In [14]:
# Example configuration for classification (e.g., metal vs. non-metal)
print("=" * 80)
print("CLASSIFICATION MODEL: Metal vs. Non-Metal Classification")
print("=" * 80)

model_classification = CrystalGraphConvNet(
    orig_atom_fea_len=92,
    nbr_fea_len=41,
    atom_fea_len=64,
    n_conv=3,
    h_fea_len=128,
    n_h=1,
    classification=True      # Classification task
)

print(model_classification)
print(f"\nTotal parameters: {sum(p.numel() for p in model_classification.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model_classification.parameters() if p.requires_grad):,}")

CLASSIFICATION MODEL: Metal vs. Non-Metal Classification
CrystalGraphConvNet(
  (embedding): Linear(in_features=92, out_features=64, bias=True)
  (convs): ModuleList(
    (0-2): 3 x ConvLayer(
      (fc_full): Linear(in_features=169, out_features=128, bias=True)
      (sigmoid): Sigmoid()
      (softplus1): Softplus(beta=1.0, threshold=20.0)
      (softplus2): Softplus(beta=1.0, threshold=20.0)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (conv_to_fc): Linear(in_features=64, out_features=128, bias=True)
  (conv_to_fc_softplus): Softplus(beta=1.0, threshold=20.0)
  (fc_out): Linear(in_features=128, out_features=2, bias=True)
  (logsoftmax): LogSoftmax(dim=1)
  (dropout): Dropout(p=0.5, inplace=False)
)

Total parameters: 80,962
Trainable parameters: 80,962


---

## Model Testing with Dummy Data

Let's verify the model works correctly with synthetic input data.

In [15]:
# Create dummy data for testing
batch_size = 2           # Number of crystals in batch
n_atoms_1 = 8            # Number of atoms in crystal 1
n_atoms_2 = 6            # Number of atoms in crystal 2
total_atoms = n_atoms_1 + n_atoms_2
max_neighbors = 12       # Maximum number of neighbors per atom

# Atom features: (total_atoms, orig_atom_fea_len)
atom_fea = torch.randn(total_atoms, 92)

# Neighbor features (edge features): (total_atoms, max_neighbors, nbr_fea_len)
nbr_fea = torch.randn(total_atoms, max_neighbors, 41)

# Neighbor indices: (total_atoms, max_neighbors)
# Random indices pointing to neighbor atoms
nbr_fea_idx = torch.randint(0, total_atoms, (total_atoms, max_neighbors))

# Crystal-to-atom mapping: list of tensors indicating which atoms belong to which crystal
crystal_atom_idx = [
    torch.LongTensor(list(range(0, n_atoms_1))),           # Crystal 1: atoms 0-7
    torch.LongTensor(list(range(n_atoms_1, total_atoms)))  # Crystal 2: atoms 8-13
]

print("Input shapes:")
print(f"  atom_fea: {atom_fea.shape}")
print(f"  nbr_fea: {nbr_fea.shape}")
print(f"  nbr_fea_idx: {nbr_fea_idx.shape}")
print(f"  crystal_atom_idx: {len(crystal_atom_idx)} crystals")
print(f"    Crystal 0: {len(crystal_atom_idx[0])} atoms")
print(f"    Crystal 1: {len(crystal_atom_idx[1])} atoms")

Input shapes:
  atom_fea: torch.Size([14, 92])
  nbr_fea: torch.Size([14, 12, 41])
  nbr_fea_idx: torch.Size([14, 12])
  crystal_atom_idx: 2 crystals
    Crystal 0: 8 atoms
    Crystal 1: 6 atoms


In [16]:
# Test regression model
print("\n" + "=" * 80)
print("TESTING REGRESSION MODEL")
print("=" * 80)

model_regression.eval()
with torch.no_grad():
    output_regression = model_regression(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

print(f"\nOutput shape: {output_regression.shape}")
print(f"Output values (predicted properties):")
print(output_regression)
print(f"\nExpected: ({batch_size}, 1) for regression")
print(f"Actual: {output_regression.shape}")
print(f"✓ Test passed!" if output_regression.shape == (batch_size, 1) else "✗ Test failed!")


TESTING REGRESSION MODEL

Output shape: torch.Size([2, 1])
Output values (predicted properties):
tensor([[-0.7730],
        [-1.0977]])

Expected: (2, 1) for regression
Actual: torch.Size([2, 1])
✓ Test passed!


In [17]:
# Test classification model
print("\n" + "=" * 80)
print("TESTING CLASSIFICATION MODEL")
print("=" * 80)

model_classification.eval()
with torch.no_grad():
    output_classification = model_classification(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

print(f"\nOutput shape: {output_classification.shape}")
print(f"Output values (log probabilities):")
print(output_classification)
print(f"\nProbabilities (after exp):")
print(torch.exp(output_classification))
print(f"\nExpected: ({batch_size}, 2) for binary classification")
print(f"Actual: {output_classification.shape}")
print(f"✓ Test passed!" if output_classification.shape == (batch_size, 2) else "✗ Test failed!")


TESTING CLASSIFICATION MODEL

Output shape: torch.Size([2, 2])
Output values (log probabilities):
tensor([[-1.5023e-03, -6.5015e+00],
        [-8.7009e-04, -7.0474e+00]])

Probabilities (after exp):
tensor([[9.9850e-01, 1.5013e-03],
        [9.9913e-01, 8.6966e-04]])

Expected: (2, 2) for binary classification
Actual: torch.Size([2, 2])
✓ Test passed!


---

## Training Considerations

### Loss Functions

**Regression:**
- Mean Squared Error (MSE): $\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2$
- Mean Absolute Error (MAE): $\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i|$

**Classification:**
- Negative Log-Likelihood (NLL): $\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log p(y_i | \mathbf{x}_i)$
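
The sketch below evaluates these three losses with PyTorch's built-in `nn.MSELoss`, `nn.L1Loss`, and `nn.NLLLoss` on made-up numbers. Note that `nn.NLLLoss` expects log-probabilities, which matches the log-softmax output of the classification model in this notebook.

```python
import torch
import torch.nn as nn

# Regression: predictions and targets of shape (batch, 1); values are made up
pred = torch.tensor([[0.5], [1.5], [-1.0]])
target = torch.tensor([[1.0], [1.0], [0.0]])
mse = nn.MSELoss()(pred, target)  # mean of squared errors
mae = nn.L1Loss()(pred, target)   # mean of absolute errors

# Classification: NLLLoss consumes log-probabilities, e.g. from a log-softmax
logits = torch.tensor([[2.0, 0.0], [0.0, 1.0]])
log_prob = torch.log_softmax(logits, dim=1)
labels = torch.tensor([0, 1])     # integer class labels, shape (batch,)
nll = nn.NLLLoss()(log_prob, labels)

print(f"MSE={mse.item():.4f}  MAE={mae.item():.4f}  NLL={nll.item():.4f}")
```

Combining log-softmax with `nn.NLLLoss` is numerically equivalent to applying `nn.CrossEntropyLoss` directly to the raw logits.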

### Optimization

- **Optimizer**: Adam, AdamW (commonly used)
- **Learning Rate**: 0.001 - 0.01 (with learning rate scheduling)
- **Batch Size**: 128 - 512 (depends on GPU memory)
- **Epochs**: 100 - 500 (with early stopping)
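
A minimal sketch of how an optimizer and a learning rate scheduler fit together is given below. The linear model and random tensors are stand-ins for the CGCNN model and a real dataset, and the `factor` and `patience` values are arbitrary choices for illustration rather than recommended settings.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

torch.manual_seed(0)
model = nn.Linear(8, 1)                         # stand-in for the CGCNN model
x, y = torch.randn(32, 8), torch.randn(32, 1)   # synthetic data, not a real dataset

optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
criterion = nn.MSELoss()

losses = []
for epoch in range(20):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    # In practice the scheduler would monitor a validation loss instead
    scheduler.step(loss.item())
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

`ReduceLROnPlateau` lowers the learning rate only when the monitored quantity stops improving, so it pairs naturally with early stopping on the same validation metric.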

### Data Normalization

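For regression tasks, standardize the target values using the mean $\mu_y$ and standard deviation $\sigma_y$ of the training targets, and invert the transform before reporting errors in physical units: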
$$
y_{\text{norm}} = \frac{y - \mu_y}{\sigma_y}
$$
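
One way to apply this transform and undo it again is sketched below. The class name `TargetNormalizer` and the target values are hypothetical; the statistics are assumed to come from the training targets only.

```python
import torch


class TargetNormalizer:
    """Hypothetical helper that standardizes regression targets."""

    def __init__(self, train_targets):
        self.mean = train_targets.mean()
        self.std = train_targets.std()

    def norm(self, y):
        return (y - self.mean) / self.std

    def denorm(self, y_norm):
        return y_norm * self.std + self.mean


train_targets = torch.tensor([-1.2, -0.4, 0.3, 1.9, 2.4])  # made-up values
normalizer = TargetNormalizer(train_targets)
y_norm = normalizer.norm(train_targets)

# Predictions in normalized units are mapped back before computing errors
y_back = normalizer.denorm(y_norm)
print(torch.allclose(y_back, train_targets))
```

The model is then trained against `norm(y)`, while metrics such as MAE are computed on `denorm(prediction)` so that they keep the units of the original property.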

### Evaluation Metrics

**Regression:**
- Mean Absolute Error (MAE)
- Root Mean Squared Error (RMSE)
- $R^2$ Score

**Classification:**
- Accuracy
- Precision, Recall, F1-Score
- ROC-AUC
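
The metrics above can be written out in a few lines of NumPy, which makes the formulas explicit. This is a sketch with hypothetical function names and made-up labels; it omits ROC-AUC and does not guard against division by zero when a class is never predicted.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    err = y_true - y_pred
    mae = np.mean(np.abs(err))
    rmse = np.sqrt(np.mean(err ** 2))
    r2 = 1.0 - np.sum(err ** 2) / np.sum((y_true - y_true.mean()) ** 2)
    return mae, rmse, r2


def binary_classification_metrics(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))  # true positives
    fp = np.sum((y_true == 0) & (y_pred == 1))  # false positives
    fn = np.sum((y_true == 1) & (y_pred == 0))  # false negatives
    accuracy = np.mean(y_true == y_pred)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1


print(regression_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]))
print(binary_classification_metrics([1, 0, 1, 1], [1, 0, 0, 1]))
```

For real experiments, the scikit-learn functions imported later in this notebook (`mean_absolute_error`, `mean_squared_error`, `r2_score`) cover the regression metrics.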

---

### Step 1: Import Additional Libraries for Training

In [18]:
import os
import csv
import json
import random
import functools
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pymatgen.core.structure import Structure
from pymatgen.core.periodic_table import Element

# Set random seeds for reproducibility
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set plot style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("✅ All libraries imported successfully")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

✅ All libraries imported successfully
Device: cuda


### Step 2: Define CGCNN Data Loading Classes

These classes are needed to load crystal structures and create graph representations.
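
Before the class definitions, it may help to see what the Gaussian distance expansion produces. The standalone NumPy sketch below uses the default values of `CIFData` in this notebook (`dmin=0`, `radius=8`, `step=0.2`); the example distances are arbitrary.

```python
import numpy as np

dmin, dmax, step = 0.0, 8.0, 0.2              # default values of CIFData in this notebook
centers = np.arange(dmin, dmax + step, step)  # Gaussian centers: 0.0, 0.2, ..., 8.0
width = step                                  # default width when no `var` is given

distances = np.array([1.0, 3.0])              # two example interatomic distances
features = np.exp(-(distances[:, np.newaxis] - centers) ** 2 / width ** 2)

print(centers.shape)            # (41,)
print(features.shape)           # (2, 41)
print(features.argmax(axis=1))  # index of the center closest to each distance
```

With these defaults each distance becomes a 41-dimensional vector, which is where the value `nbr_fea_len=41` used in the example models comes from.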

In [19]:
# Define GaussianDistance class for edge feature expansion
class GaussianDistance(object):
    """Expands the distance by Gaussian basis functions."""
    def __init__(self, dmin, dmax, step, var=None):
        assert dmin < dmax
        assert dmax - dmin > step
        self.filter = np.arange(dmin, dmax+step, step)
        if var is None:
            var = step
        self.var = var

    def expand(self, distances):
        """Apply Gaussian distance filter to distance array."""
        return np.exp(-(distances[..., np.newaxis] - self.filter)**2 / self.var**2)


# Define AtomInitializer classes
class AtomInitializer(object):
    """Base class for initializing atom feature vectors."""
    def __init__(self, atom_types):
        self.atom_types = set(atom_types)
        self._embedding = {}

    def get_atom_fea(self, atom_type):
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        self._embedding = state_dict
        self.atom_types = set(self._embedding.keys())

    def state_dict(self):
        return self._embedding


class AtomCustomJSONInitializer(AtomInitializer):
    """Initialize atom features from JSON file mapping atomic numbers to embeddings."""
    def __init__(self, elem_embedding_file):
        from pymatgen.core.periodic_table import Element
        
        with open(elem_embedding_file) as f:
            elem_embedding = json.load(f)
        
        # Handle both formats: symbol keys and integer keys
        converted_embedding = {}
        for key, value in elem_embedding.items():
            try:
                # Try to convert key to integer (if already numeric)
                atomic_num = int(key)
            except ValueError:
                # If key is a symbol, convert to atomic number
                try:
                    atomic_num = Element(key).Z
                except Exception as e:
                    print(f"Warning: Could not process element '{key}': {e}")
                    continue
            
            converted_embedding[atomic_num] = value
        
        atom_types = set(converted_embedding.keys())
        super(AtomCustomJSONInitializer, self).__init__(atom_types)
        
        for key, value in converted_embedding.items():
            self._embedding[key] = np.array(value, dtype=float)


print("✅ CGCNN classes defined: GaussianDistance, AtomInitializer, AtomCustomJSONInitializer")

✅ CGCNN classes defined: GaussianDistance, AtomInitializer, AtomCustomJSONInitializer


### Step 3: Define CIFData Dataset Class

In [20]:
# Define CIFData dataset class
class CIFData(Dataset):
    """
    Dataset for crystal structures stored as CIF files.
    
    Expected directory structure:
        root_dir/
        ├── id_prop.csv          # material_id, target_property
        ├── atom_init.json       # atom embeddings
        ├── material1.cif
        ├── material2.cif
        └── ...
    """
    def __init__(self, root_dir, max_num_nbr=12, radius=8, dmin=0, step=0.2,
                 random_seed=123):
        self.root_dir = root_dir
        self.max_num_nbr, self.radius = max_num_nbr, radius
        
        assert os.path.exists(root_dir), f'root_dir does not exist: {root_dir}'
        
        # Load id_prop.csv
        id_prop_file = os.path.join(self.root_dir, 'id_prop.csv')
        assert os.path.exists(id_prop_file), f'id_prop.csv not found: {id_prop_file}'
        
        with open(id_prop_file) as f:
            reader = csv.reader(f)
            self.id_prop_data = [row for row in reader]
        
        random.seed(random_seed)
        random.shuffle(self.id_prop_data)
        
        # Load atom embeddings
        atom_init_file = os.path.join(self.root_dir, 'atom_init.json')
        assert os.path.exists(atom_init_file), f'atom_init.json not found: {atom_init_file}'
        
        self.ari = AtomCustomJSONInitializer(atom_init_file)
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)

    def __len__(self):
        return len(self.id_prop_data)

    @functools.lru_cache(maxsize=None)
    def __getitem__(self, idx):
        cif_id, target = self.id_prop_data[idx]
        crystal = Structure.from_file(os.path.join(self.root_dir, cif_id+'.cif'))
        
        # Get atom features
        atom_fea = np.vstack([self.ari.get_atom_fea(crystal[i].specie.number)
                              for i in range(len(crystal))])
        atom_fea = torch.Tensor(atom_fea)
        
        # Get neighbors within cutoff radius
        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)
        all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]
        
        # Build neighbor feature and index arrays
        nbr_fea_idx, nbr_fea = [], []
        for nbr in all_nbrs:
            if len(nbr) < self.max_num_nbr:
                warnings.warn(f'{cif_id}: not enough neighbors (found {len(nbr)}, need {self.max_num_nbr})')
                nbr_fea_idx.append(list(map(lambda x: x[2], nbr)) +
                                   [0] * (self.max_num_nbr - len(nbr)))
                nbr_fea.append(list(map(lambda x: x[1], nbr)) +
                               [self.radius + 1.] * (self.max_num_nbr - len(nbr)))
            else:
                nbr_fea_idx.append(list(map(lambda x: x[2], nbr[:self.max_num_nbr])))
                nbr_fea.append(list(map(lambda x: x[1], nbr[:self.max_num_nbr])))
        
        nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
        nbr_fea = self.gdf.expand(nbr_fea)
        
        nbr_fea = torch.Tensor(nbr_fea)
        nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
        target = torch.Tensor([float(target)])
        
        return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id


print("✅ CIFData class defined")

✅ CIFData class defined


In [21]:
# Define collate function for batching variable-sized graphs
def collate_pool(dataset_list):
    """
    Collate a list of data and return a batch for predicting crystal properties.
    Handles variable-sized graphs by concatenating them into a single large graph.
    """
    batch_atom_fea, batch_nbr_fea, batch_nbr_fea_idx = [], [], []
    crystal_atom_idx, batch_target = [], []
    batch_cif_ids = []
    base_idx = 0
    
    for i, ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id) in enumerate(dataset_list):
        n_i = atom_fea.shape[0]  # number of atoms for this crystal
        batch_atom_fea.append(atom_fea)
        batch_nbr_fea.append(nbr_fea)
        batch_nbr_fea_idx.append(nbr_fea_idx + base_idx)
        
        new_idx = torch.LongTensor(np.arange(n_i) + base_idx)
        crystal_atom_idx.append(new_idx)
        batch_target.append(target)
        batch_cif_ids.append(cif_id)
        base_idx += n_i
    
    return (torch.cat(batch_atom_fea, dim=0),
            torch.cat(batch_nbr_fea, dim=0),
            torch.cat(batch_nbr_fea_idx, dim=0),
            crystal_atom_idx), \
           torch.stack(batch_target, dim=0), \
           batch_cif_ids


# Define data loader creation function
def get_train_val_test_loader(dataset, collate_fn=default_collate,
                              batch_size=64, train_ratio=None,
                              val_ratio=0.1, test_ratio=0.1, return_test=False,
                              num_workers=0, pin_memory=False, **kwargs):
    """Create train, validation, and test data loaders with proper splitting."""
    
    total_size = len(dataset)
    if train_ratio is None:
        assert val_ratio + test_ratio < 1
        train_ratio = 1 - val_ratio - test_ratio
        print(f'[Info] train_ratio not specified, using {train_ratio} for training')
    else:
        assert train_ratio + val_ratio + test_ratio <= 1
    
    indices = list(range(total_size))
    
    train_size = kwargs.get('train_size', int(train_ratio * total_size))
    test_size = kwargs.get('test_size', int(test_ratio * total_size))
    valid_size = kwargs.get('val_size', int(val_ratio * total_size))
    
    # Slice with explicit end positions: a negative-offset slice ending at
    # -test_size would be empty (and the test slice would cover everything)
    # when test_size is 0
    test_start = total_size - test_size
    train_sampler = SubsetRandomSampler(indices[:train_size])
    val_sampler = SubsetRandomSampler(indices[test_start - valid_size:test_start])
    
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=train_sampler,
                              num_workers=num_workers,
                              collate_fn=collate_fn, pin_memory=pin_memory)
    val_loader = DataLoader(dataset, batch_size=batch_size,
                            sampler=val_sampler,
                            num_workers=num_workers,
                            collate_fn=collate_fn, pin_memory=pin_memory)
    
    if return_test:
        test_sampler = SubsetRandomSampler(indices[test_start:])
        test_loader = DataLoader(dataset, batch_size=batch_size,
                                 sampler=test_sampler,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn, pin_memory=pin_memory)
        return train_loader, val_loader, test_loader
    else:
        return train_loader, val_loader


print("✅ collate_pool and get_train_val_test_loader functions defined")

✅ collate_pool and get_train_val_test_loader functions defined
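
To make the index bookkeeping in `collate_pool` concrete, here is a minimal sketch that uses plain Python lists instead of tensors. The two toy crystals and their neighbor lists are invented for illustration; only the offsetting logic mirrors the function above.

```python
# Toy neighbor-index lists (hypothetical data): each row lists an atom's
# neighbors, numbered locally within its own crystal.
crystal_a = [[1, 0], [0, 1]]           # 2 atoms
crystal_b = [[1, 2], [0, 2], [0, 1]]   # 3 atoms

base_idx = 0
batch_nbr_idx = []      # neighbor indices in the merged graph
crystal_atom_idx = []   # which merged-graph atoms belong to each crystal

for nbr_idx in (crystal_a, crystal_b):
    n_i = len(nbr_idx)
    # Shift local indices so they point into the concatenated atom list
    batch_nbr_idx.extend([[j + base_idx for j in row] for row in nbr_idx])
    crystal_atom_idx.append(list(range(base_idx, base_idx + n_i)))
    base_idx += n_i

print(batch_nbr_idx)     # [[1, 0], [0, 1], [3, 4], [2, 4], [2, 3]]
print(crystal_atom_idx)  # [[0, 1], [2, 3, 4]]
```

The second list is the mapping the model receives as `crystal_atom_idx`, which allows atom features to be grouped back by crystal after the graphs have been merged.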


### Step 4: Load Dataset from Notebook 02

We'll use the CIF structures created in notebook 02.

In [None]:
# Load dataset
print("Loading dataset...")
dataset = CIFData(
    root_dir=CIF_DIR,
    max_num_nbr=MAX_NUM_NBR,
    radius=RADIUS,
    dmin=DMIN,
    step=DSTEP,
    random_seed=RANDOM_SEED
)

# Get feature dimensions from first sample
(atom_fea, nbr_fea, nbr_fea_idx), target, cif_id = dataset[0]
ORIG_ATOM_FEA_LEN = atom_fea.shape[1]
NBR_FEA_LEN = nbr_fea.shape[2]

print(f"✅ Dataset loaded:")
print(f"   Total samples: {len(dataset)}")
print(f"   Atom feature length: {ORIG_ATOM_FEA_LEN}")
print(f"   Neighbor feature length: {NBR_FEA_LEN}")
print(f"   First sample: {cif_id}, target: {target.item():.6f}")

### Step 5: Create Data Loaders
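
It can help to see how `get_train_val_test_loader` carves up the index list before building the loaders. The sketch below reproduces that slicing in plain Python, using explicit end positions rather than negative offsets; `split_indices` is a hypothetical helper, not part of the notebook. The splits are contiguous slices of the dataset order, and when the integer-truncated sizes do not sum to the total, the leftover indices between the training and validation slices go unused.

```python
def split_indices(total_size, train_size, valid_size, test_size):
    """Hypothetical helper: train from the front, val and test from the back."""
    indices = list(range(total_size))
    test_start = total_size - test_size
    val_start = test_start - valid_size
    train = indices[:train_size]
    val = indices[val_start:test_start]
    test = indices[test_start:]
    return train, val, test


# 11 samples with a 6/2/2 split: index 6 belongs to no split
train, val, test = split_indices(11, 6, 2, 2)
print(train)  # [0, 1, 2, 3, 4, 5]
print(val)    # [7, 8]
print(test)   # [9, 10]

# With test_size=0 the explicit positions still give a valid validation slice,
# whereas a negative-offset slice such as indices[-2:-0] would be empty.
_, val_only, no_test = split_indices(10, 8, 2, 0)
print(val_only)  # [8, 9]
print(no_test)   # []
```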

In [22]:
# Create train, validation, and test data loaders
try:
    train_loader, val_loader, test_loader = get_train_val_test_loader(
        dataset=dataset,
        collate_fn=collate_pool,
        batch_size=BATCH_SIZE,
        train_ratio=0.7,
        val_ratio=0.15,
        test_ratio=0.15,
        return_test=True,
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )
    
    print(f"✅ Data Loaders Created:")
    print(f"   Training batches: {len(train_loader)}")
    print(f"   Validation batches: {len(val_loader)}")
    print(f"   Test batches: {len(test_loader)}")
    
    # Exact split sizes, read from each loader's sampler
    train_size = len(train_loader.sampler)
    val_size = len(val_loader.sampler)
    test_size = len(test_loader.sampler)
    
    print(f"\n📊 Dataset split:")
    print(f"   Training samples: {train_size}")
    print(f"   Validation samples: {val_size}")
    print(f"   Test samples: {test_size}")
    print(f"   Total: {len(dataset)}")
    
except Exception as e:
    print(f"❌ Error creating data loaders: {e}")
    import traceback
    traceback.print_exc()

❌ Error creating data loaders: name 'dataset' is not defined


Traceback (most recent call last):
  File "C:\Users\abhin\AppData\Local\Temp\ipykernel_948\1469353720.py", line 4, in <module>
    dataset=dataset,
            ^^^^^^^
NameError: name 'dataset' is not defined. Did you mean: 'Dataset'?


In [None]:
# CGCNN Configuration
MAX_NUM_NBR = 12
RADIUS = 8.0
DMIN = 0
DSTEP = 0.2
BATCH_SIZE = 32
RANDOM_SEED = 42

# Path to the CIF structures directory
CIF_DIR = os.path.join('..', 'notebooks', 'cif_structures')

print(f"CGCNN Configuration:")
print(f"  Max neighbors per atom: {MAX_NUM_NBR}")
print(f"  Cutoff radius: {RADIUS} Å")
print(f"  Gaussian distance range: {DMIN} to {RADIUS}")
print(f"  Gaussian distance step: {DSTEP}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Random seed: {RANDOM_SEED}")
print(f"  CIF directory: {CIF_DIR}")

# Load the dataset
try:
    dataset = CIFData(
        root_dir=CIF_DIR,
        max_num_nbr=MAX_NUM_NBR,
        radius=RADIUS,
        dmin=DMIN,
        step=DSTEP,
        random_seed=RANDOM_SEED
    )
    
    # Get dimensions from first sample
    (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id = dataset[0]
    ORIG_ATOM_FEA_LEN = atom_fea.shape[1]
    NBR_FEA_LEN = nbr_fea.shape[2]
    
    print(f"\n✅ Dataset loaded successfully:")
    print(f"   Total samples: {len(dataset)}")
    print(f"   Original atom feature length: {ORIG_ATOM_FEA_LEN}")
    print(f"   Neighbor feature length: {NBR_FEA_LEN}")
    print(f"   First sample: {cif_id}, target: {target.item():.6f}")
    
except Exception as e:
    print(f"❌ Error loading dataset: {e}")
    import traceback
    traceback.print_exc()

### Step 6: Initialize Model, Loss Function, and Optimizer

In [None]:
# Model hyperparameters
ATOM_FEA_LEN = 64
N_CONV = 3
H_FEA_LEN = 128
N_H = 1
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0
EPOCHS = 100

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize model
model = CrystalGraphConvNet(
    orig_atom_fea_len=ORIG_ATOM_FEA_LEN,
    nbr_fea_len=NBR_FEA_LEN,
    atom_fea_len=ATOM_FEA_LEN,
    n_conv=N_CONV,
    h_fea_len=H_FEA_LEN,
    n_h=N_H,
    classification=False  # Regression task
)

model = model.to(device)

print(f"\n✅ Model initialized:")
print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"   Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Loss function (MSE for regression)
criterion = nn.MSELoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Learning rate scheduler
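# `verbose` is omitted: it is deprecated in recent PyTorch releases, and the
# training loop already prints the learning rate every epoch
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)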

print(f"\n✅ Training setup complete:")
print(f"   Loss function: MSE")
print(f"   Optimizer: Adam")
print(f"   Learning rate: {LEARNING_RATE}")
print(f"   Weight decay: {WEIGHT_DECAY}")
print(f"   Epochs: {EPOCHS}")
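
The interaction between the scheduler's `patience=10` and `factor=0.5` can be sketched without PyTorch. The class below is a simplified stand-in written for illustration: `PlateauHalver` is a hypothetical name, and it assumes the documented reduce-on-plateau rule (reduce once the count of non-improving epochs exceeds `patience`) while ignoring the threshold, cooldown, and minimum-LR options of the real `ReduceLROnPlateau`.

```python
class PlateauHalver:
    """Simplified, hypothetical model of a reduce-on-plateau schedule."""

    def __init__(self, lr, factor=0.5, patience=10):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float('inf')
        self.num_bad_epochs = 0

    def step(self, metric):
        if metric < self.best:
            # Improvement: remember it and reset the counter
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            # Plateau detected: scale the learning rate and start counting again
            self.lr *= self.factor
            self.num_bad_epochs = 0


sched = PlateauHalver(lr=0.001, factor=0.5, patience=10)
sched.step(1.0)            # first epoch sets the baseline
for _ in range(10):
    sched.step(1.0)        # 10 stalled epochs: still within patience
lr_within_patience = sched.lr
sched.step(1.0)            # 11th stalled epoch exceeds patience
lr_after_plateau = sched.lr

print(lr_within_patience)  # 0.001
print(lr_after_plateau)    # 0.0005
```

Under this simplified rule, a validation loss that never improves triggers the first halving on the 11th stalled epoch.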

### Step 7: Define Training and Validation Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    
    total_loss = 0.0
    predictions = []
    targets = []
    
    pbar = tqdm(train_loader, desc='Training', leave=False)
    for batch_idx, (input_data, target, _) in enumerate(pbar):
        # Move data to device
        atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = input_data
        atom_fea = atom_fea.to(device)
        nbr_fea = nbr_fea.to(device)
        nbr_fea_idx = nbr_fea_idx.to(device)
        target = target.to(device)
        
        # Forward pass
        output = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
        loss = criterion(output, target)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track metrics
        total_loss += loss.item() * target.size(0)
        predictions.extend(output.detach().cpu().numpy().flatten())
        targets.extend(target.detach().cpu().numpy().flatten())
        
        # Update progress bar
        pbar.set_postfix({'loss': loss.item()})
    
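    # Divide by the samples actually seen this epoch; with a subset sampler,
    # len(train_loader.dataset) would count the whole dataset
    avg_loss = total_loss / len(targets)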
    mae = mean_absolute_error(targets, predictions)
    rmse = np.sqrt(mean_squared_error(targets, predictions))
    
    return avg_loss, mae, rmse, predictions, targets


def validate_epoch(model, val_loader, criterion, device):
    """Validate for one epoch."""
    model.eval()
    
    total_loss = 0.0
    predictions = []
    targets = []
    cif_ids = []
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation', leave=False)
        for batch_idx, (input_data, target, batch_cif_ids) in enumerate(pbar):
            # Move data to device
            atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = input_data
            atom_fea = atom_fea.to(device)
            nbr_fea = nbr_fea.to(device)
            nbr_fea_idx = nbr_fea_idx.to(device)
            target = target.to(device)
            
            # Forward pass
            output = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
            loss = criterion(output, target)
            
            # Track metrics
            total_loss += loss.item() * target.size(0)
            predictions.extend(output.detach().cpu().numpy().flatten())
            targets.extend(target.detach().cpu().numpy().flatten())
            cif_ids.extend(batch_cif_ids)
            
            # Update progress bar
            pbar.set_postfix({'loss': loss.item()})
    
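    # Divide by the samples actually seen; with a subset sampler,
    # len(val_loader.dataset) would count the whole dataset
    avg_loss = total_loss / len(targets)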
    mae = mean_absolute_error(targets, predictions)
    rmse = np.sqrt(mean_squared_error(targets, predictions))
    r2 = r2_score(targets, predictions)
    
    return avg_loss, mae, rmse, r2, predictions, targets, cif_ids


print("✅ Training and validation functions defined")
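
Both functions accumulate `loss.item() * target.size(0)`, so the epoch loss is a sample-weighted mean of the per-batch means. The arithmetic below illustrates why the denominator should be the number of samples actually iterated over; the batch sizes and loss values are invented for the example.

```python
# Hypothetical epoch: 70 training samples drawn from a 100-sample dataset,
# served in batches of 32 (the last batch is smaller).
batch_sizes = [32, 32, 6]
batch_mean_losses = [0.50, 0.40, 0.10]
dataset_size = 100

total_loss = sum(loss * n for loss, n in zip(batch_mean_losses, batch_sizes))
samples_seen = sum(batch_sizes)

weighted_mean = total_loss / samples_seen      # mean loss over the 70 samples seen
naive_mean = sum(batch_mean_losses) / len(batch_mean_losses)  # over-weights the small batch
over_full_dataset = total_loss / dataset_size  # understates the loss

print(round(weighted_mean, 4))      # 0.42
print(round(naive_mean, 4))         # 0.3333
print(round(over_full_dataset, 4))  # 0.294
```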

### Step 8: Training Loop with Progress Tracking

In [None]:
# Initialize tracking variables
history = {
    'train_loss': [],
    'train_mae': [],
    'train_rmse': [],
    'val_loss': [],
    'val_mae': [],
    'val_rmse': [],
    'val_r2': [],
    'learning_rate': []
}

best_val_loss = float('inf')
best_model_state = None
patience_counter = 0
early_stopping_patience = 30

print("=" * 80)
print("Starting Training")
print("=" * 80)

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 80)
    
    # Train
    train_loss, train_mae, train_rmse, _, _ = train_epoch(
        model, train_loader, criterion, optimizer, device
    )
    
    # Validate
    val_loss, val_mae, val_rmse, val_r2, val_preds, val_targets, _ = validate_epoch(
        model, val_loader, criterion, device
    )
    
    # Update learning rate scheduler
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    
    # Store history
    history['train_loss'].append(train_loss)
    history['train_mae'].append(train_mae)
    history['train_rmse'].append(train_rmse)
    history['val_loss'].append(val_loss)
    history['val_mae'].append(val_mae)
    history['val_rmse'].append(val_rmse)
    history['val_r2'].append(val_r2)
    history['learning_rate'].append(current_lr)
    
    # Print metrics
    print(f"Train Loss: {train_loss:.6f} | Train MAE: {train_mae:.6f} | Train RMSE: {train_rmse:.6f}")
    print(f"Val Loss:   {val_loss:.6f} | Val MAE:   {val_mae:.6f} | Val RMSE:   {val_rmse:.6f} | Val R²: {val_r2:.4f}")
    print(f"Learning Rate: {current_lr:.6f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
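        # Clone every tensor: dict.copy() is shallow, so the snapshot would
        # otherwise keep tracking the live weights as training continues
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        patience_counter = 0
        print(f"✅ New best model! Val Loss: {best_val_loss:.6f}")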
        
        # Save model checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_mae': val_mae,
        }, 'best_model.pth')
    else:
        patience_counter += 1
    
    # Early stopping
    if patience_counter >= early_stopping_patience:
        print(f"\n⚠️ Early stopping triggered after {epoch+1} epochs")
        break
    
    # Save training progress periodically
    if (epoch + 1) % 10 == 0:
        pd.DataFrame(history).to_csv('training_history.csv', index=False)

print("\n" + "=" * 80)
print("Training Complete!")
print("=" * 80)
print(f"Best validation loss: {best_val_loss:.6f}")

# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("✅ Best model loaded")

# Save final history
pd.DataFrame(history).to_csv('training_history.csv', index=False)
print("✅ Training history saved to 'training_history.csv'")
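
Keeping the best weights in memory only works if the snapshot is independent of the live model. As an analogy in plain Python (lists standing in for parameter tensors, which are likewise updated in place), the sketch below contrasts a shallow `dict.copy()` with `copy.deepcopy`. The `state` dictionary is made up for illustration.

```python
import copy

# Stand-in for a model's state: values are mutable and updated in place
state = {'weight': [1.0, 2.0], 'bias': [0.5]}

shallow = state.copy()          # new dict, but the same underlying lists
deep = copy.deepcopy(state)     # new dict with independent copies of the lists

# Simulate a later training step that modifies parameters in place
state['weight'][0] = 99.0

print(shallow['weight'])  # [99.0, 2.0]  (snapshot silently changed)
print(deep['weight'])     # [1.0, 2.0]   (snapshot preserved)
```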

### Step 9: Visualize Training Progress

In [None]:
# Create comprehensive training visualization
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

epochs_range = range(1, len(history['train_loss']) + 1)

# Plot 1: Loss curves
axes[0, 0].plot(epochs_range, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
axes[0, 0].plot(epochs_range, history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch', fontsize=12)
axes[0, 0].set_ylabel('MSE Loss', fontsize=12)
axes[0, 0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0, 0].legend(fontsize=10)
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: MAE curves
axes[0, 1].plot(epochs_range, history['train_mae'], 'b-', label='Train MAE', linewidth=2)
axes[0, 1].plot(epochs_range, history['val_mae'], 'r-', label='Validation MAE', linewidth=2)
axes[0, 1].set_xlabel('Epoch', fontsize=12)
axes[0, 1].set_ylabel('Mean Absolute Error', fontsize=12)
axes[0, 1].set_title('Mean Absolute Error', fontsize=14, fontweight='bold')
axes[0, 1].legend(fontsize=10)
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: RMSE curves
axes[1, 0].plot(epochs_range, history['train_rmse'], 'b-', label='Train RMSE', linewidth=2)
axes[1, 0].plot(epochs_range, history['val_rmse'], 'r-', label='Validation RMSE', linewidth=2)
axes[1, 0].set_xlabel('Epoch', fontsize=12)
axes[1, 0].set_ylabel('Root Mean Squared Error', fontsize=12)
axes[1, 0].set_title('Root Mean Squared Error', fontsize=14, fontweight='bold')
axes[1, 0].legend(fontsize=10)
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: R² Score and Learning Rate
ax1 = axes[1, 1]
ax2 = ax1.twinx()

line1 = ax1.plot(epochs_range, history['val_r2'], 'g-', label='Validation R²', linewidth=2)
line2 = ax2.plot(epochs_range, history['learning_rate'], 'orange', label='Learning Rate', 
                 linewidth=2, linestyle='--')

ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('R² Score', fontsize=12, color='g')
ax2.set_ylabel('Learning Rate', fontsize=12, color='orange')
ax1.set_title('R² Score and Learning Rate', fontsize=14, fontweight='bold')
ax1.tick_params(axis='y', labelcolor='g')
ax2.tick_params(axis='y', labelcolor='orange')
ax1.grid(True, alpha=0.3)

# Combine legends
lines = line1 + line2
labels = [l.get_label() for l in lines]
ax1.legend(lines, labels, fontsize=10)

plt.tight_layout()
plt.savefig('training_progress.png', dpi=300, bbox_inches='tight')
plt.show()

print("✅ Training progress visualizations saved to 'training_progress.png'")

### Step 10: Evaluate on Test Set

In [None]:
# Evaluate on test set
print("=" * 80)
print("Evaluating on Test Set")
print("=" * 80)

test_loss, test_mae, test_rmse, test_r2, test_preds, test_targets, test_cif_ids = validate_epoch(
    model, test_loader, criterion, device
)

print(f"\n📊 Test Set Performance:")
print(f"   Loss (MSE): {test_loss:.6f}")
print(f"   MAE:        {test_mae:.6f}")
print(f"   RMSE:       {test_rmse:.6f}")
print(f"   R² Score:   {test_r2:.4f}")

# Create results DataFrame
test_results = pd.DataFrame({
    'material_id': test_cif_ids,
    'true_value': test_targets,
    'predicted_value': test_preds,
    'error': np.array(test_targets) - np.array(test_preds),
    'abs_error': np.abs(np.array(test_targets) - np.array(test_preds))
})

# Sort by absolute error
test_results_sorted = test_results.sort_values('abs_error', ascending=False)

print(f"\n📋 Top 5 Best Predictions:")
print(test_results_sorted.tail(5)[['material_id', 'true_value', 'predicted_value', 'abs_error']])

print(f"\n⚠️ Top 5 Worst Predictions:")
print(test_results_sorted.head(5)[['material_id', 'true_value', 'predicted_value', 'abs_error']])

# Save test results
test_results.to_csv('test_results.csv', index=False)
print(f"\n✅ Test results saved to 'test_results.csv'")
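
For reference, the three reported metrics can be computed by hand. The following sketch uses only the standard library and four invented target/prediction pairs; it is meant to clarify the definitions, not to replace the `mean_absolute_error`, `mean_squared_error`, and `r2_score` calls used above.

```python
import math

# Invented values for illustration
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.0, 2.5, 4.0]
n = len(y_true)

errors = [t - p for t, p in zip(y_true, y_pred)]
mae = sum(abs(e) for e in errors) / n
rmse = math.sqrt(sum(e * e for e in errors) / n)

# R² = 1 - SS_res / SS_tot, where SS_tot is measured around the mean target
mean_true = sum(y_true) / n
ss_res = sum(e * e for e in errors)
ss_tot = sum((t - mean_true) ** 2 for t in y_true)
r2 = 1 - ss_res / ss_tot

print(mae)             # 0.25
print(round(rmse, 4))  # 0.3536
print(round(r2, 4))    # 0.9
```

Because RMSE squares each error before averaging, it is never smaller than MAE and grows faster when a few predictions are far off.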

### Step 11: Prediction Visualizations

In [None]:
# Create comprehensive prediction visualizations
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Plot 1: Predicted vs True Values (Test Set)
axes[0, 0].scatter(test_targets, test_preds, alpha=0.6, s=100, edgecolors='black', linewidth=0.5)
axes[0, 0].plot([min(test_targets), max(test_targets)], 
                [min(test_targets), max(test_targets)], 
                'r--', linewidth=2, label='Perfect Prediction')
axes[0, 0].set_xlabel('True Values', fontsize=12)
axes[0, 0].set_ylabel('Predicted Values', fontsize=12)
axes[0, 0].set_title(f'Test Set: Predicted vs True\nR² = {test_r2:.4f}, MAE = {test_mae:.4f}', 
                     fontsize=14, fontweight='bold')
axes[0, 0].legend(fontsize=10)
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Residual Plot
residuals = np.array(test_targets) - np.array(test_preds)
axes[0, 1].scatter(test_preds, residuals, alpha=0.6, s=100, edgecolors='black', linewidth=0.5)
axes[0, 1].axhline(y=0, color='r', linestyle='--', linewidth=2)
axes[0, 1].set_xlabel('Predicted Values', fontsize=12)
axes[0, 1].set_ylabel('Residuals (True - Predicted)', fontsize=12)
axes[0, 1].set_title('Residual Plot', fontsize=14, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Error Distribution
axes[1, 0].hist(residuals, bins=20, edgecolor='black', alpha=0.7, color='skyblue')
axes[1, 0].axvline(x=0, color='r', linestyle='--', linewidth=2, label='Zero Error')
axes[1, 0].set_xlabel('Residuals (True - Predicted)', fontsize=12)
axes[1, 0].set_ylabel('Frequency', fontsize=12)
axes[1, 0].set_title(f'Error Distribution\nMean = {np.mean(residuals):.4f}, Std = {np.std(residuals):.4f}', 
                     fontsize=14, fontweight='bold')
axes[1, 0].legend(fontsize=10)
axes[1, 0].grid(True, alpha=0.3, axis='y')

# Plot 4: Absolute Error Distribution
abs_errors = np.abs(residuals)
axes[1, 1].hist(abs_errors, bins=20, edgecolor='black', alpha=0.7, color='lightcoral')
axes[1, 1].axvline(x=test_mae, color='r', linestyle='--', linewidth=2, 
                   label=f'MAE = {test_mae:.4f}')
axes[1, 1].set_xlabel('Absolute Error', fontsize=12)
axes[1, 1].set_ylabel('Frequency', fontsize=12)
axes[1, 1].set_title('Absolute Error Distribution', fontsize=14, fontweight='bold')
axes[1, 1].legend(fontsize=10)
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('prediction_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("✅ Prediction analysis visualizations saved to 'prediction_analysis.png'")

### Step 12: Compare Performance Across All Splits

In [None]:
# Evaluate on all splits for comprehensive comparison
print("Evaluating on all data splits...")

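# Get final training set predictions in evaluation mode; calling train_epoch here
# would run another optimization pass and change the restored best weights
_, train_mae_final, train_rmse_final, _, train_preds_final, train_targets_final, _ = validate_epoch(
    model, train_loader, criterion, device
)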

# Get final validation set predictions
_, val_mae_final, val_rmse_final, val_r2_final, val_preds_final, val_targets_final, _ = validate_epoch(
    model, val_loader, criterion, device
)

# Create comparison DataFrame
performance_summary = pd.DataFrame({
    'Split': ['Training', 'Validation', 'Test'],
    'MAE': [train_mae_final, val_mae_final, test_mae],
    'RMSE': [train_rmse_final, val_rmse_final, test_rmse],
    'R²': [r2_score(train_targets_final, train_preds_final), val_r2_final, test_r2]
})

print("\n" + "=" * 80)
print("FINAL PERFORMANCE SUMMARY")
print("=" * 80)
print(performance_summary.to_string(index=False))
print("=" * 80)

# Visualize performance comparison
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

metrics = ['MAE', 'RMSE', 'R²']
colors = ['#3498db', '#e74c3c', '#2ecc71']

for idx, metric in enumerate(metrics):
    axes[idx].bar(performance_summary['Split'], performance_summary[metric], 
                  color=colors, alpha=0.7, edgecolor='black', linewidth=2)
    axes[idx].set_ylabel(metric, fontsize=12)
    axes[idx].set_title(f'{metric} Comparison', fontsize=14, fontweight='bold')
    axes[idx].grid(True, alpha=0.3, axis='y')
    
    # Add value labels on bars
    for i, v in enumerate(performance_summary[metric]):
        axes[idx].text(i, v, f'{v:.4f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig('performance_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✅ Performance comparison saved to 'performance_comparison.png'")