# CGFormer Training with PyTorch Lightning

Cleaned up training code for:
1. **Energy Function Learning** - Crystal property prediction (MSE loss on normalized targets, MAE reported as the metric)
2. **Swap-based Structure Search** - REINFORCE training

## Setup

In [None]:
# Colab setup: detect Google Colab, mount Drive, and install dependencies.
import os
# True when the COLAB_GPU environment variable is set, or when the string form
# of the active IPython shell mentions google.colab.
# NOTE(review): get_ipython() is only defined under IPython/Jupyter; as a plain
# script this raises NameError unless the COLAB_GPU check short-circuits first.
IN_COLAB = 'COLAB_GPU' in os.environ or 'google.colab' in str(get_ipython())

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    # Work from the project folder on the mounted Drive so that the local
    # imports used by later cells resolve relative to it.
    %cd /content/drive/MyDrive/CGformer
    
    # Install dependencies
    !pip install -q torch torchvision torchaudio
    !pip install -q pytorch-lightning
    !pip install -q torch-geometric
    !pip install -q pymatgen

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import numpy as np
from typing import Optional, Dict, Any

# Report library versions and GPU availability so runs can be reproduced.
print(f"PyTorch: {torch.__version__}")
print(f"Lightning: {pl.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only query the device name when a CUDA device is actually present.
    print(f"Device: {torch.cuda.get_device_name(0)}")

## 1. Model Definition

In [None]:
# Import model
from model import CrystalGraphConvNet

## 2. Data Module

In [None]:
from data import CIFData, collate_pool, get_train_val_test_loader


class CrystalDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping ``CIFData`` and its train/val/test loaders.

    The dataset and the three loaders are created on the first call to
    :meth:`setup` and reused on every later call.
    """

    def __init__(
        self,
        root_dir: str,
        batch_size: int = 16,
        train_ratio: float = 0.7,
        val_ratio: float = 0.15,
        test_ratio: float = 0.15,
        num_workers: int = 0,
        max_num_nbr: int = 12,
        radius: float = 8.0,
        random_seed: int = 123,
    ):
        super().__init__()

        # Settings forwarded to CIFData.
        self.root_dir = root_dir
        self.max_num_nbr = max_num_nbr
        self.radius = radius
        self.random_seed = random_seed

        # Settings forwarded to get_train_val_test_loader.
        self.batch_size = batch_size
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.num_workers = num_workers

        # Populated by setup().
        self.dataset = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None

    def setup(self, stage: Optional[str] = None):
        """Build the dataset and loaders once; later calls do nothing."""
        if self.dataset is not None:
            return

        self.dataset = CIFData(
            self.root_dir,
            max_num_nbr=self.max_num_nbr,
            radius=self.radius,
            random_seed=self.random_seed
        )

        loaders = get_train_val_test_loader(
            dataset=self.dataset,
            collate_fn=collate_pool,
            batch_size=self.batch_size,
            train_ratio=self.train_ratio,
            val_ratio=self.val_ratio,
            test_ratio=self.test_ratio,
            return_test=True,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available(),
            train_size=None,
            val_size=None,
            test_size=None,
        )
        self.train_loader, self.val_loader, self.test_loader = loaders

    def train_dataloader(self):
        """Return the training loader created by :meth:`setup`."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader created by :meth:`setup`."""
        return self.val_loader

    def test_dataloader(self):
        """Return the test loader created by :meth:`setup`."""
        return self.test_loader

    def get_sample_batch(self):
        """Return the first batch of the training loader.

        Calls :meth:`setup` first, so it can be used before the Trainer has
        prepared the data (e.g. to read feature dimensions for the model).
        """
        self.setup()
        first_batch = next(iter(self.train_loader))
        return first_batch

## 3. Normalizer

In [None]:
class Normalizer:
    """Normalize targets to zero mean and unit variance.

    Args:
        tensor: Optional tensor of target values. ``mean`` and ``std`` are
            computed over all of its elements. When omitted the normalizer
            is the identity (``mean = 0.0``, ``std = 1.0``).

    If the standard deviation is zero or not finite (constant targets, or
    fewer than two values), ``std`` falls back to 1 so that :meth:`norm`
    only centers the data instead of returning inf/NaN.
    """

    def __init__(self, tensor=None):
        if tensor is None:
            self.mean = 0.0
            self.std = 1.0
            return

        mean = tensor.mean()
        # The sample std is undefined for fewer than two values; treat that
        # case like a degenerate (zero) spread.
        std = tensor.std() if tensor.numel() > 1 else torch.zeros_like(mean)
        if not torch.isfinite(std) or std == 0:
            # Dividing by a zero/NaN std would poison every normalized
            # target and therefore the loss.
            std = torch.ones_like(mean)

        self.mean = mean
        self.std = std

    def norm(self, tensor):
        """Return ``(tensor - mean) / std``."""
        return (tensor - self.mean) / self.std

    def denorm(self, tensor):
        """Invert :meth:`norm`: return ``tensor * std + mean``."""
        return tensor * self.std + self.mean

    def state_dict(self):
        """Return the statistics as ``{'mean': ..., 'std': ...}``."""
        return {'mean': self.mean, 'std': self.std}

    def load_state_dict(self, state_dict):
        """Restore statistics previously produced by :meth:`state_dict`."""
        self.mean = state_dict['mean']
        self.std = state_dict['std']

## 4. Lightning Module - Energy Prediction

In [None]:
class CGFormerModule(pl.LightningModule):
    """PyTorch Lightning module for CGFormer energy prediction.

    Wraps a regression ``CrystalGraphConvNet``. The network is trained with
    MSE against normalized targets, and MAE is reported on de-normalized
    predictions.

    Args:
        orig_atom_fea_len, nbr_fea_len, atom_fea_len, n_conv, h_fea_len, n_h,
        graphormer_layers, num_heads: Forwarded unchanged to
            ``CrystalGraphConvNet``.
        learning_rate: Learning rate for Adam.
        weight_decay: Weight decay for Adam.
        normalizer: Target normalizer. An identity ``Normalizer()`` is used
            when omitted. It is excluded from the saved hyperparameters, so
            its statistics are written to and restored from checkpoints by
            :meth:`on_save_checkpoint` / :meth:`on_load_checkpoint`.
    """

    # Checkpoint key under which the normalizer statistics are stored.
    _NORMALIZER_KEY = 'normalizer'

    def __init__(
        self,
        orig_atom_fea_len: int,
        nbr_fea_len: int,
        atom_fea_len: int = 64,
        n_conv: int = 3,
        h_fea_len: int = 128,
        n_h: int = 1,
        graphormer_layers: int = 1,
        num_heads: int = 4,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-4,
        normalizer: Optional[Normalizer] = None,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['normalizer'])

        self.model = CrystalGraphConvNet(
            orig_atom_fea_len=orig_atom_fea_len,
            nbr_fea_len=nbr_fea_len,
            atom_fea_len=atom_fea_len,
            n_conv=n_conv,
            h_fea_len=h_fea_len,
            n_h=n_h,
            graphormer_layers=graphormer_layers,
            num_heads=num_heads,
            classification=False,
        )

        self.normalizer = normalizer or Normalizer()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        self.criterion = nn.MSELoss()

    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """Run the wrapped ``CrystalGraphConvNet`` on one batch of graphs."""
        return self.model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Store the normalizer statistics in the checkpoint.

        Without this, a module restored with ``load_from_checkpoint`` would
        fall back to the identity normalizer and report MAE on the wrong
        scale.
        """
        checkpoint[self._NORMALIZER_KEY] = self.normalizer.state_dict()

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Restore normalizer statistics if the checkpoint contains them."""
        state = checkpoint.get(self._NORMALIZER_KEY)
        if state is not None:  # older checkpoints have no normalizer entry
            self.normalizer.load_state_dict(state)

    @staticmethod
    def _num_crystals(batch) -> int:
        """Return the number of samples in ``batch``.

        Assumes the target's leading dimension indexes crystals -- TODO
        confirm against ``collate_pool``. The atom-level tensors that come
        first in the batch are not suitable for this.
        """
        _, target, _ = batch
        return target.size(0)

    def _shared_step(self, batch, batch_idx):
        """Compute ``(loss, mae)`` for one batch.

        ``loss`` is the MSE between the raw network output and the
        normalized target; ``mae`` is the L1 error between the de-normalized
        output and the raw target.
        """
        (atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx), target, _ = batch

        # Forward pass
        output = self(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

        # The network is trained in normalized target space.
        target_normed = self.normalizer.norm(target)
        loss = self.criterion(output, target_normed)

        # MAE (denormalized)
        pred_denorm = self.normalizer.denorm(output)
        mae = F.l1_loss(pred_denorm, target)

        return loss, mae

    def _log_metrics(self, prefix: str, loss, mae, batch, prog_bar: bool) -> None:
        """Log ``{prefix}_loss`` and ``{prefix}_mae`` with an explicit batch size.

        The batch is a nested tuple that starts with atom-level tensors, so
        the batch size is passed explicitly instead of letting Lightning
        infer it for epoch-level averaging.
        """
        n = self._num_crystals(batch)
        self.log(f'{prefix}_loss', loss, prog_bar=prog_bar, batch_size=n)
        self.log(f'{prefix}_mae', mae, prog_bar=prog_bar, batch_size=n)

    def training_step(self, batch, batch_idx):
        """Return the training loss and log ``train_loss`` / ``train_mae``."""
        loss, mae = self._shared_step(batch, batch_idx)
        self._log_metrics('train', loss, mae, batch, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss and log ``val_loss`` / ``val_mae``."""
        loss, mae = self._shared_step(batch, batch_idx)
        self._log_metrics('val', loss, mae, batch, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Return the test loss and log ``test_loss`` / ``test_mae``."""
        loss, mae = self._shared_step(batch, batch_idx)
        self._log_metrics('test', loss, mae, batch, prog_bar=False)
        return loss

    def configure_optimizers(self):
        """Adam with ReduceLROnPlateau driven by the logged ``val_mae``."""
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=10
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_mae'
            }
        }

## 5. Lightning Module - Swap + REINFORCE Training

In [None]:
from swap_utils import (
    parse_poscar_string, poscar_to_tensors,
    sample_sublattice_swap, apply_n_swaps,
    log_prob_sublattice_swap
)


class SwapScoreNet(nn.Module):
    """Three-layer MLP that maps per-atom features to one swap score each."""

    def __init__(self, input_dim: int, hidden_dim: int = 128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Score every atom: ``[batch, N, input_dim]`` -> ``[batch, N]``."""
        per_atom = self.net(x)
        # Drop the trailing singleton produced by the last Linear layer.
        return per_atom.squeeze(-1)


class SwapREINFORCEModule(pl.LightningModule):
    """REINFORCE training for swap-based structure optimization.

    A ``SwapScoreNet`` policy scores atoms and ``n_swaps_per_step``
    Ti/Fe swaps are sampled sequentially. The energy of the final
    configuration is used as the REINFORCE signal against an
    exponential-moving-average baseline.

    Args:
        energy_model: Energy network. It is frozen (no gradients) and kept
            in eval mode for the lifetime of this module.
        input_dim: Width of the one-hot atom-type encoding fed to the
            policy (number of atom types).
        hidden_dim: Hidden width of the policy MLP.
        learning_rate: Adam learning rate for the policy.
        n_swaps_per_step: Number of sequential swaps per training step.
        reinforce_samples: Stored as an attribute; not used by
            :meth:`training_step`.
        entropy_reg: Weight of the entropy bonus.
        baseline_ema: Decay factor of the moving-average baseline.
    """

    def __init__(
        self,
        energy_model: nn.Module,
        input_dim: int = 5,  # one-hot atom type
        hidden_dim: int = 128,
        learning_rate: float = 1e-4,
        n_swaps_per_step: int = 10,
        reinforce_samples: int = 8,
        entropy_reg: float = 0.01,
        baseline_ema: float = 0.99,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['energy_model'])

        self.energy_model = energy_model
        self.energy_model.eval()
        for p in self.energy_model.parameters():
            p.requires_grad = False

        self.score_net = SwapScoreNet(input_dim, hidden_dim)

        # The one-hot width must match the policy's input width.
        self.n_types = input_dim
        self.learning_rate = learning_rate
        self.n_swaps = n_swaps_per_step
        self.reinforce_samples = reinforce_samples
        self.entropy_reg = entropy_reg
        self.baseline_ema = baseline_ema

        self.register_buffer('baseline', torch.tensor(0.0))

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the energy model in eval.

        ``nn.Module.train`` recurses into every submodule, which would
        otherwise undo the ``eval()`` call made in ``__init__`` as soon as
        training starts.
        """
        super().train(mode)
        self.energy_model.eval()
        return self

    def forward(self, atom_types_onehot):
        """Get swap scores for each atom."""
        return self.score_net(atom_types_onehot)

    def compute_energy(self, atom_types, tensors):
        """Compute energy using frozen energy model.

        Note: This is a placeholder. Real implementation needs
        to convert atom_types to proper crystal graph features.
        ``tensors`` and ``self.energy_model`` are currently unused.
        """
        # Placeholder: use negative sum as "energy"
        return -atom_types.float().sum(dim=-1)

    def training_step(self, batch, batch_idx):
        """Sample a swap trajectory and return the REINFORCE loss.

        ``batch`` is ``(atom_types, tensors)``. ``atom_types`` is an integer
        tensor of atom-type indices (assumed ``[batch, N]``), and ``tensors``
        provides ``'b_site_mask'`` and a ``'type_map'`` with ``'Ti'`` and
        ``'Fe'`` entries.
        """
        atom_types, tensors = batch

        # One-hot encoding, sized to the policy input.
        atom_types_onehot = F.one_hot(atom_types, self.n_types).float()

        # Scores are computed once from the initial configuration and
        # reused for every swap in the loop below.
        scores = self(atom_types_onehot)  # [batch, N]

        b_site_mask = tensors['b_site_mask']
        type_map = tensors['type_map']
        ti, fe = type_map['Ti'], type_map['Fe']

        total_log_prob = 0.0
        current = atom_types.clone()

        for _ in range(self.n_swaps):
            # Sample a swap from the current configuration
            swapped, indices = sample_sublattice_swap(
                current, b_site_mask, ti, fe, scores
            )

            # Log probability of that swap under the policy
            log_prob = log_prob_sublattice_swap(
                scores, b_site_mask, ti, fe, current, indices
            )

            total_log_prob = total_log_prob + log_prob
            current = swapped

        # Energy of the final configuration (lower is better)
        energy = self.compute_energy(current, tensors)

        # REINFORCE loss with baseline: minimizing it lowers the probability
        # of trajectories that end above the baseline energy.
        advantage = (energy - self.baseline).detach()
        reinforce_loss = (advantage * total_log_prob).mean()

        # Entropy bonus. With H = -E[log p], grad H = -E[log p * grad log p],
        # so the surrogate for -entropy_reg * H is
        # entropy_reg * E[stopgrad(log p) * log p]. A plain
        # -entropy_reg * mean(log p) term has zero expected gradient and
        # would only add noise.
        entropy_loss = self.entropy_reg * (
            total_log_prob.detach() * total_log_prob
        ).mean()

        loss = reinforce_loss + entropy_loss

        # Update the baseline after it has been used for this step's advantage
        self.baseline = self.baseline_ema * self.baseline + (1 - self.baseline_ema) * energy.mean().detach()

        self.log('train_loss', loss, prog_bar=True)
        self.log('energy', energy.mean(), prog_bar=True)
        self.log('baseline', self.baseline)

        return loss

    def configure_optimizers(self):
        """Adam over the policy parameters only; the energy model is frozen."""
        return optim.Adam(self.score_net.parameters(), lr=self.learning_rate)

## 6. Training - Energy Prediction

In [None]:
# Configuration for the energy-prediction run
# Directory handed to CrystalDataModule as root_dir.
DATA_DIR = './STFO_data'  # Change to your data path
# Batch size for the data module.
BATCH_SIZE = 16
# Upper bound on epochs passed to the Trainer.
MAX_EPOCHS = 100
# Learning rate passed to CGFormerModule.
LEARNING_RATE = 1e-3

In [None]:
# Check that the data directory exists before building the dataset
import os

# isdir (not exists): a regular file at DATA_DIR would pass an existence check
# and then make os.listdir raise NotADirectoryError.
if os.path.isdir(DATA_DIR):
    print(f"Data directory found: {DATA_DIR}")
    # Show at most the first ten entries as a quick sanity check.
    print(f"Files: {os.listdir(DATA_DIR)[:10]}...")
else:
    print(f"Data directory not found: {DATA_DIR}")
    print("Please set DATA_DIR to your crystal data path")

In [None]:
# Create data module with a 70/15/15 train/val/test split
data_module = CrystalDataModule(
    root_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
)

# Setup and get sample
# get_sample_batch() calls setup() itself; the second call does nothing
# because the dataset has already been built.
data_module.setup()
sample_batch = data_module.get_sample_batch()
# A batch is ((graph inputs), target, ids); atom_fea and nbr_fea are reused
# by a later cell to size the model's input layers.
(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx), target, cif_ids = sample_batch

print(f"Atom features shape: {atom_fea.shape}")
print(f"Neighbor features shape: {nbr_fea.shape}")
print(f"Target shape: {target.shape}")

In [None]:
# Fit the target normalizer on the training split only
train_targets = torch.cat(
    [batch_target for _, batch_target, _ in data_module.train_dataloader()],
    dim=0,
)

normalizer = Normalizer(train_targets)
print(f"Target mean: {normalizer.mean:.4f}")
print(f"Target std: {normalizer.std:.4f}")

In [None]:
# Create model
# Input feature widths are read from the sample batch so the network matches
# whatever featurization the dataset produced. weight_decay is not passed and
# therefore keeps CGFormerModule's default.
model = CGFormerModule(
    orig_atom_fea_len=atom_fea.shape[-1],
    nbr_fea_len=nbr_fea.shape[-1],
    atom_fea_len=64,
    n_conv=3,
    h_fea_len=128,
    n_h=1,
    graphormer_layers=1,
    num_heads=4,
    learning_rate=LEARNING_RATE,
    normalizer=normalizer,
)

print(model)

In [None]:
# Callbacks
# Keep the three checkpoints with the lowest validation MAE under ./checkpoints.
checkpoint_callback = ModelCheckpoint(
    monitor='val_mae',
    dirpath='./checkpoints',
    filename='cgformer-{epoch:02d}-{val_mae:.4f}',
    save_top_k=3,
    mode='min',
)

# Stop when validation MAE has not improved for 20 validation checks. This is
# longer than the LR scheduler's patience of 10, so the learning rate is
# reduced before training is abandoned.
early_stop_callback = EarlyStopping(
    monitor='val_mae',
    patience=20,
    mode='min',
)

In [None]:
# Trainer
# Single-device run; accelerator='auto' lets Lightning pick the accelerator.
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator='auto',
    devices=1,
    callbacks=[checkpoint_callback, early_stop_callback],
    enable_progress_bar=True,
    log_every_n_steps=10,
)

In [None]:
# Train! Runs until MAX_EPOCHS or until early stopping on val_mae triggers.
trainer.fit(model, data_module)

In [None]:
# Test
# Evaluate the checkpoint with the best val_mae. With early-stopping
# patience of 20, the in-memory weights can be many epochs past the best.
trainer.test(model, datamodule=data_module, ckpt_path='best')

## 7. Training - Swap + REINFORCE (Demonstration)

In [None]:
# POSCAR example for swap training
# NOTE(review): the counts line declares 32 + 16 + 16 + 88 + 8 = 160 atoms but
# only 32 coordinate rows follow -- confirm whether parse_poscar_string
# tolerates this or whether the remaining coordinates were dropped.
poscar_str = """SrTiFeO
1.000000
11.199000 0.000000 0.000000
0.000000 11.199000 0.000000
0.000000 0.000000 15.983000
Sr Ti Fe O VO
32 16 16 88 8
Direct
0.000000 0.250000 0.125000
0.000000 0.250000 0.625000
0.000000 0.750000 0.125000
0.000000 0.750000 0.625000
0.500000 0.250000 0.125000
0.500000 0.250000 0.625000
0.500000 0.750000 0.125000
0.500000 0.750000 0.625000
0.000000 0.250000 0.375000
0.000000 0.250000 0.875000
0.000000 0.750000 0.375000
0.000000 0.750000 0.875000
0.500000 0.250000 0.375000
0.500000 0.250000 0.875000
0.500000 0.750000 0.375000
0.500000 0.750000 0.875000
0.250000 0.000000 0.125000
0.250000 0.000000 0.625000
0.250000 0.500000 0.125000
0.250000 0.500000 0.625000
0.750000 0.000000 0.125000
0.750000 0.000000 0.625000
0.750000 0.500000 0.125000
0.750000 0.500000 0.625000
0.250000 0.000000 0.375000
0.250000 0.000000 0.875000
0.250000 0.500000 0.375000
0.250000 0.500000 0.875000
0.750000 0.000000 0.375000
0.750000 0.000000 0.875000
0.750000 0.500000 0.375000
0.750000 0.500000 0.875000"""

In [None]:
# Demo swap operations
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Parse the POSCAR text and convert it to tensors on the chosen device.
poscar = parse_poscar_string(poscar_str)
tensors = poscar_to_tensors(poscar, device=device)

# Create batch
batch_size = 128
# Replicate the single structure batch_size times; clone() materializes the
# expanded view so each row has its own memory.
atom_types = tensors['atom_types'].unsqueeze(0).expand(batch_size, -1).clone()

print(f"Atom types shape: {atom_types.shape}")
print(f"B-site count: {tensors['b_site_mask'].sum().item()}")
print(f"O-site count: {tensors['o_site_mask'].sum().item()}")

In [None]:
# Test swap: time n_swaps random swaps applied to the whole batch
import time

n_swaps = 100

# CUDA kernels launch asynchronously, so synchronize before starting and
# before stopping the clock to time the work itself rather than the launches.
if torch.cuda.is_available():
    torch.cuda.synchronize()
# perf_counter is the monotonic, high-resolution clock meant for intervals.
start = time.perf_counter()

swapped, history = apply_n_swaps(
    atom_types,
    tensors['b_site_mask'],
    tensors['o_site_mask'],
    tensors['type_map'],
    n_swaps=n_swaps,
    swap_mode='both'
)

if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed = time.perf_counter() - start

print(f"{batch_size} samples × {n_swaps} swaps = {batch_size * n_swaps:,} total")
print(f"Time: {elapsed:.3f}s")
print(f"Throughput: {(batch_size * n_swaps) / elapsed:,.0f} swaps/sec")

## 8. Save/Load Checkpoint

In [None]:
# Save
# trainer.save_checkpoint('cgformer_final.ckpt')

# Load
# model = CGFormerModule.load_from_checkpoint('cgformer_final.ckpt')

---

## Summary

This notebook provides:

1. **CGFormerModule** - Lightning module for energy prediction
   - MSE loss with target normalization
   - MAE metric tracking
   - ReduceLROnPlateau scheduler

2. **SwapREINFORCEModule** - Lightning module for swap policy learning
   - REINFORCE with baseline
   - Entropy regularization
   - Frozen energy model (note: `compute_energy` is currently a placeholder and does not call it yet)

3. **CrystalDataModule** - Data loading wrapper
   - Train/val/test split
   - Collate function for crystal graphs

4. **Swap utilities** - GPU-accelerated swap operations
   - Gumbel-max sampling
   - Beam search
   - Log probability computation