In [1]:
"""
Magnetic Topological Material Classifier
=======================================
A deep learning framework for simultaneously predicting magnetic ordering and topological classification
of crystalline materials using graph neural networks and transformers.
"""

import os
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add, scatter_mean
from sklearn.metrics import accuracy_score, f1_score, classification_report
import matplotlib.pyplot as plt
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv

# Materials science libraries
import pymatgen as pmg
from pymatgen.core.structure import Structure
from pymatgen.core import Element
from pymatgen.analysis.magnetism.analyzer import CollinearMagneticStructureAnalyzer
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from mp_api.client import MPRester

# Load environment variables
load_dotenv()
api_key = os.getenv("MP_API_KEY")

# Constants
ORDER_ENCODE = {"NM": 0, "AFM": 1, "FM": 2, "FiM": 2}  # FiM grouped with FM
TOPO_ENCODE = {False: 0, True: 1}  # Non-topological vs topological

# Global parameters
PARAMS = {
    'max_radius': 8.0,        # Maximum cutoff radius for atom connections
    'n_norm': 35,             # Normalization factor
    'hidden_dim': 128,        # Hidden dimension size
    'num_heads': 4,           # Number of attention heads
    'num_layers': 3,          # Number of transformer layers
    'batch_size': 8,          # Batch size for training
    'lr': 0.0005,             # Learning rate
    'weight_decay': 0.01,     # Weight decay for regularization
    'max_epochs': 200,        # Maximum training epochs
    'early_stop_patience': 15  # Patience for early stopping
}


#===============================
# DATA STRUCTURES AND PROCESSING
#===============================

class MaterialData(Data):
    """
    Custom PyTorch Geometric Data class for material structure data with proper batching.
    """
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index':
            return self.num_nodes
        if key == 'cell_index':
            return self.num_nodes
        return super().__inc__(key, value, *args, **kwargs)


def get_element_properties(symbol):
    """Get key properties for an element by symbol."""
    try:
        elem = Element(symbol)
        return {
            'Z': elem.Z,
            'group': elem.group,
            'row': elem.row,
            'atomic_radius': elem.atomic_radius or 0.0,
            'atomic_mass': elem.atomic_mass or 0.0,
            'electronegativity': elem.electronegativity or 0.0,
            'is_magnetic': int(symbol in ['Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Gd', 'Dy', 'Ho', 'Er', 'Tm', 'Yb'])
        }
    except Exception:
        # Default values if element properties can't be retrieved
        return {'Z': 0, 'group': 0, 'row': 0, 'atomic_radius': 0.0, 
                'atomic_mass': 0.0, 'electronegativity': 0.0, 'is_magnetic': 0}


def extract_structure_features(structure):
    """Extract features from a pymatgen Structure object."""
    # Symmetry features
    analyzer = SpacegroupAnalyzer(structure)
    spacegroup = analyzer.get_space_group_number()
    point_group = analyzer.get_point_group_symbol()
    has_inversion = int(analyzer.has_inversion())
    
    # Magnetic features
    mag_elements = ['Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Gd', 'Dy', 'Ho', 'Er', 'Tm', 'Yb']
    mag_count = sum(1 for site in structure if site.species_string in mag_elements)
    mag_fraction = mag_count / len(structure)
    
    # Lattice features
    a, b, c = structure.lattice.abc
    alpha, beta, gamma = structure.lattice.angles
    volume = structure.lattice.volume
    density = structure.density
    
    return {
        'spacegroup': spacegroup,
        'point_group_id': hash(point_group) % 100,  # Simple hash for point group
        'has_inversion': has_inversion,
        'mag_fraction': mag_fraction,
        'a': a, 'b': b, 'c': c,
        'alpha': alpha, 'beta': beta, 'gamma': gamma,
        'volume': volume,
        'density': density
    }


def structure_to_graph(structure, max_radius=8.0):
    """
    Convert a pymatgen Structure to a graph representation.
    
    Args:
        structure: pymatgen Structure object
        max_radius: Maximum bond distance to consider
        
    Returns:
        x: Node features tensor
        edge_index: Edge connectivity tensor
        edge_attr: Edge features tensor
        pos: Node positions tensor
        structure_features: Global structure features tensor
    """
    num_sites = len(structure)
    
    # Node features: 7 features per atom
    node_features = []
    for site in structure:
        element = site.species_string
        props = get_element_properties(element)
        
        # Feature vector for each atom: element properties
        features = [
            props['Z'] / 100,  # Normalized atomic number
            props['group'] / 18,  # Normalized group
            props['row'] / 7,  # Normalized row
            props['atomic_radius'] / 2.0 if props['atomic_radius'] else 0,  # Normalized radius
            props['electronegativity'] / 4.0 if props['electronegativity'] else 0,  # Normalized electronegativity
            props['atomic_mass'] / 250.0,  # Normalized mass
            float(props['is_magnetic'])  # Is magnetic element
        ]
        node_features.append(features)
    
    # Node positions
    positions = torch.tensor(structure.cart_coords, dtype=torch.float)
    
    # Create edges based on distance
    src_list = []
    dst_list = []
    edge_attr_list = []
    
    # For each pair of atoms, check if they're within max_radius
    for i in range(num_sites):
        for j in range(num_sites):
            if i == j:  # Skip self-loops for now
                continue
                
            # Get the distance considering periodic boundary conditions
            dist = structure.get_distance(i, j)
            
            if dist <= max_radius:
                src_list.append(i)
                dst_list.append(j)
                
                # Edge features: distance, direction vector (normalized)
                direction = positions[j] - positions[i]
                direction_norm = torch.norm(direction)
                if direction_norm > 0:
                    direction = direction / direction_norm
                
                # Create edge feature vector:
                # [distance, dx, dy, dz]
                edge_attr_list.append([dist / max_radius] + direction.tolist())
    
    # If no edges were found, create self-loops to avoid errors
    if not src_list:
        for i in range(num_sites):
            src_list.append(i)
            dst_list.append(i)
            edge_attr_list.append([0.0, 0.0, 0.0, 0.0])  # Self-loop has zero features
    
    # Convert to tensors
    edge_index = torch.tensor([src_list, dst_list], dtype=torch.long)
    edge_attr = torch.tensor(edge_attr_list, dtype=torch.float)
    x = torch.tensor(node_features, dtype=torch.float)
    
    # Extract global structure features
    structure_feats = extract_structure_features(structure)
    structure_features = torch.tensor([
        structure_feats['spacegroup'] / 230,  # Normalize by max space group
        structure_feats['point_group_id'] / 100,
        float(structure_feats['has_inversion']),
        structure_feats['mag_fraction'],
        structure_feats['a'] / 20.0,  # Normalize lattice parameters
        structure_feats['b'] / 20.0,
        structure_feats['c'] / 20.0,
        structure_feats['alpha'] / 180.0,
        structure_feats['beta'] / 180.0,
        structure_feats['gamma'] / 180.0,
        structure_feats['volume'] / 1000.0,
        structure_feats['density'] / 20.0
    ], dtype=torch.float).unsqueeze(0).repeat(num_sites, 1)
    
    return x, edge_index, edge_attr, positions, structure_features


def process_structures(structures, materials_ids=None, formulas=None):
    """
    Process a list of structures into graph data objects.
    
    Args:
        structures: List of pymatgen Structure objects
        materials_ids: List of Material Project IDs (optional)
        formulas: List of chemical formulas (optional)
        
    Returns:
        data_list: List of MaterialData objects
    """
    data_list = []
    
    for i, structure in enumerate(structures):
        print(f"Processing structure {i+1}/{len(structures)}", end="\r", flush=True)
        
        try:
            # Extract magnetic ordering
            mag_analyzer = CollinearMagneticStructureAnalyzer(structure)
            ordering = mag_analyzer.ordering.name
            magnetic_y = ORDER_ENCODE.get(ordering, 0)
            
            # Get material ID and formula if provided
            material_id = materials_ids[i] if materials_ids else f"struct_{i}"
            formula = formulas[i] if formulas else structure.composition.reduced_formula
            
            # Convert structure to graph representation
            x, edge_index, edge_attr, pos, structure_features = structure_to_graph(
                structure, max_radius=PARAMS['max_radius']
            )
            
            # Create data object
            data = MaterialData(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                pos=pos,
                structure_features=structure_features,
                magnetic_y=torch.tensor([magnetic_y], dtype=torch.long),
                topological_y=torch.tensor([0], dtype=torch.long),  # Default to non-topological
                material_id=material_id,
                formula=formula,
                num_atoms=len(structure)
            )
            data.num_nodes = x.size(0) 
            data_list.append(data)
            
        except Exception as e:
            print(f"\nError processing structure {i}: {e}")
            continue
    
    print(f"\nProcessed {len(data_list)}/{len(structures)} structures successfully")
    return data_list


def fetch_topological_labels(data_list, api_key):
    """
    Fetch topological classifications for materials using the Materials Project API.
    
    Args:
        data_list: List of MaterialData objects
        api_key: Materials Project API key
        
    Returns:
        data_list: Updated list with topological labels
    """
    materials_with_topo_info = 0
    
    with MPRester(api_key=api_key) as mpr:
        for i, data in enumerate(data_list):
            material_id = data.material_id
            
            try:
                # Skip if not a real MP ID
                if not material_id.startswith("mp-"):
                    continue
                    
                # Query Materials Project API
                result = mpr.materials.summary.search(material_ids=[material_id])
                
                if result and hasattr(result[0], "is_topological"):
                    label = result[0].is_topological
                    data.topological_y = torch.tensor([TOPO_ENCODE[label]], dtype=torch.long)
                    materials_with_topo_info += 1
                    print(f"Found topological info for {material_id}: {label}")
                    
            except Exception as e:
                print(f"Error retrieving topological info for {material_id}: {e}")
    
    print(f"Added topological labels for {materials_with_topo_info} materials")
    return data_list


def load_and_process_data(mp_structures_file, api_key=None):
    """
    Load and process materials data.
    
    Args:
        mp_structures_file: Path to saved structures file
        api_key: Materials Project API key for topological data
        
    Returns:
        processed_data: List of processed MaterialData objects
    """
    print(f"Loading structures from {mp_structures_file}")
    mp_structures_dict = torch.load(mp_structures_file, weights_only=False)
    
    structures = mp_structures_dict['structures']
    materials_ids = mp_structures_dict['materials_id']
    formulas = mp_structures_dict['formulas']
    
    print(f"Loaded {len(structures)} structures")
    
    # Process structures to graph data
    processed_data = process_structures(structures, materials_ids, formulas)
    
    # Fetch topological labels if API key is provided
    if api_key:
        processed_data = fetch_topological_labels(processed_data, api_key)
    
    return processed_data


def prepare_datasets(data_list, train_ratio=0.8, val_ratio=0.1):
    """
    Split data into training, validation and test sets.
    
    Args:
        data_list: List of MaterialData objects
        train_ratio: Ratio of training data
        val_ratio: Ratio of validation data
        
    Returns:
        train_loader, val_loader, test_loader: DataLoader objects
    """
    # Balance datasets by magnetic ordering
    data_by_order = {0: [], 1: [], 2: []}
    
    for data in data_list:
        magnetic_y = data.magnetic_y.item()
        data_by_order[magnetic_y].append(data)
    
    print(f"Data distribution by magnetic ordering:")
    for order, items in data_by_order.items():
        order_name = {0: "NM", 1: "AFM", 2: "FM/FiM"}[order]
        print(f"  {order_name}: {len(items)} structures")
    
    # Find minimum count to ensure balanced classes
    min_count = min(len(items) for items in data_by_order.values())
    balanced_data = []
    
    for order, items in data_by_order.items():
        random.shuffle(items)
        balanced_data.extend(items[:min_count])
    
    # Shuffle balanced dataset
    random.shuffle(balanced_data)
    
    # Split into train/val/test
    n = len(balanced_data)
    train_size = int(train_ratio * n)
    val_size = int(val_ratio * n)
    
    train_data = balanced_data[:train_size]
    val_data = balanced_data[train_size:train_size + val_size]
    test_data = balanced_data[train_size + val_size:]
    
    print(f"Dataset splits: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}")
    
    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=PARAMS['batch_size'], shuffle=True)
    val_loader = DataLoader(val_data, batch_size=PARAMS['batch_size'])
    test_loader = DataLoader(test_data, batch_size=PARAMS['batch_size'])
    
    return train_loader, val_loader, test_loader


#===============================
# MODEL ARCHITECTURE
#===============================

class AttentionLayer(MessagePassing):
    """
    Graph attention layer for materials science applications.
    """
    def __init__(self, hidden_dim, num_heads=4, edge_dim=4):
        super().__init__(aggr='add')
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        
        # Query, key, value projections
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        
        # Edge feature projection
        self.edge_proj = nn.Sequential(
            nn.Linear(edge_dim, 32),
            nn.ReLU(),
            nn.Linear(32, num_heads)
        )
        
        # Output projection
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)
        
        # Layer normalization
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        
        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )
    
    def forward(self, x, edge_index, edge_attr):
        # First attention block with residual connection
        identity = x
        out = self.ln1(x)
        out = self._attention_block(out, edge_index, edge_attr)
        out = out + identity
        
        # Feed-forward block with residual connection
        identity = out
        out = self.ln2(out)
        out = self.ffn(out)
        return out + identity
    
    def _attention_block(self, x, edge_index, edge_attr):
        # Project inputs to queries, keys, values
        q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)
        
        # Process edge attributes
        edge_weights = self.edge_proj(edge_attr).unsqueeze(-1)  # [E, num_heads, 1]
        
        # Propagate through the graph
        out = self.propagate(edge_index, q=q, k=k, v=v, edge_weights=edge_weights)
        
        # Project output back to original dimension
        return self.output_proj(out.view(-1, self.hidden_dim))
    
    def message(self, q_i, k_j, v_j, edge_weights, index, ptr, size_i):
        # Compute attention scores
        attention = (q_i * k_j).sum(dim=-1) / math.sqrt(self.head_dim)
        
        # Multiply by edge weights (based on edge features)
        attention = attention.unsqueeze(-1) * edge_weights
        
        # Apply softmax to normalize scores
        alpha = F.softmax(attention, dim=0)
        
        # Apply attention weights to values
        return alpha * v_j


class MagneticTopologicalTransformer(nn.Module):
    """
    Transformer-based model for predicting magnetic ordering and topological class.
    """
    def __init__(self, node_dim=7, structure_dim=12, hidden_dim=128, edge_dim=4, num_heads=4, num_layers=3):
        super().__init__()
        self.hidden_dim = hidden_dim
        
        # Input projections
        self.node_proj = nn.Linear(node_dim, hidden_dim)
        self.structure_proj = nn.Linear(structure_dim, hidden_dim)
        
        # Attention layers
        self.attention_layers = nn.ModuleList([
            AttentionLayer(hidden_dim, num_heads, edge_dim)
            for _ in range(num_layers)
        ])
        
        # Output heads
        self.magnetic_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 3)  # NM, AFM, FM/FiM
        )
        
        self.topological_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 2)  # Not TI, TI
        )
    
    def forward(self, x, edge_index, edge_attr, structure_features, batch):
        # Project node and structure features
        h_nodes = self.node_proj(x)
        h_struct = self.structure_proj(structure_features)
        
        # Combine node features with structure features
        h = h_nodes + h_struct
        
        # Apply attention layers
        for layer in self.attention_layers:
            h = layer(h, edge_index, edge_attr)
        
        # Global pooling
        if batch is None:
            batch = torch.zeros(h.size(0), dtype=torch.long, device=h.device)
        
        h_global = scatter_mean(h, batch, dim=0)
        
        # Predict magnetic ordering and topological class
        magnetic_pred = self.magnetic_head(h_global)
        topological_pred = self.topological_head(h_global)
        
        return magnetic_pred, topological_pred


#===============================
# TRAINING AND EVALUATION
#===============================

class EarlyStopping:
    """Early stopping to prevent overfitting"""
    def __init__(self, patience=15, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
    
    def __call__(self, val_loss):
        if self.best_score is None:
            self.best_score = val_loss
            return True
            
        if val_loss > self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False
        else:
            self.best_score = val_loss
            self.counter = 0
            return True


def train_epoch(model, dataloader, optimizer, device, alpha=0.5):
    """Train the model for one epoch"""
    model.train()
    total_loss = 0
    magnetic_loss_total = 0
    topological_loss_total = 0
    
    # Class weights to handle imbalance
    magnetic_class_weights = torch.tensor([1.0, 1.2, 1.2], device=device)
    topological_class_weights = torch.tensor([1.0, 1.5], device=device)
    
    # Loss functions
    magnetic_criterion = nn.CrossEntropyLoss(weight=magnetic_class_weights)
    topological_criterion = nn.CrossEntropyLoss(weight=topological_class_weights)
    
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        magnetic_pred, topological_pred = model(
            batch.x, 
            batch.edge_index, 
            batch.edge_attr, 
            batch.structure_features,
            batch.batch
        )
        
        # Compute losses
        magnetic_loss = magnetic_criterion(magnetic_pred, batch.magnetic_y.squeeze())
        topological_loss = topological_criterion(topological_pred, batch.topological_y.squeeze())
        
        # Combined loss with weighting parameter alpha
        loss = alpha * magnetic_loss + (1 - alpha) * topological_loss
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        # Track losses
        total_loss += loss.item() * batch.num_graphs
        magnetic_loss_total += magnetic_loss.item() * batch.num_graphs
        topological_loss_total += topological_loss.item() * batch.num_graphs
    
    # Calculate average losses
    avg_loss = total_loss / len(dataloader.dataset)
    avg_magnetic_loss = magnetic_loss_total / len(dataloader.dataset)
    avg_topological_loss = topological_loss_total / len(dataloader.dataset)
    
    return avg_loss, avg_magnetic_loss, avg_topological_loss


def validate(model, dataloader, device, alpha=0.5):
    """Validate the model"""
    model.eval()
    total_loss = 0
    magnetic_loss_total = 0
    topological_loss_total = 0
    
    magnetic_preds = []
    magnetic_targets = []
    topological_preds = []
    topological_targets = []
    
    magnetic_class_weights = torch.tensor([1.0, 1.2, 1.2], device=device)
    topological_class_weights = torch.tensor([1.0, 1.5], device=device)
    
    magnetic_criterion = nn.CrossEntropyLoss(weight=magnetic_class_weights)
    topological_criterion = nn.CrossEntropyLoss(weight=topological_class_weights)
    
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            
            # Forward pass
            magnetic_pred, topological_pred = model(
                batch.x, 
                batch.edge_index, 
                batch.edge_attr, 
                batch.structure_features,
                batch.batch
            )
            
            # Compute losses
            magnetic_loss = magnetic_criterion(magnetic_pred, batch.magnetic_y.squeeze())
            topological_loss = topological_criterion(topological_pred, batch.topological_y.squeeze())
            
            # Combined loss
            loss = alpha * magnetic_loss + (1 - alpha) * topological_loss
            
            # Track losses
            total_loss += loss.item() * batch.num_graphs
            magnetic_loss_total += magnetic_loss.item() * batch.num_graphs
            topological_loss_total += topological_loss.item() * batch.num_graphs
            
            # Track predictions for metrics
            magnetic_preds.append(magnetic_pred.argmax(dim=1).cpu())
            magnetic_targets.append(batch.magnetic_y.squeeze().cpu())
            topological_preds.append(topological_pred.argmax(dim=1).cpu())
            topological_targets.append(batch.topological_y.squeeze().cpu())
    
    # Concatenate predictions and targets
    magnetic_preds = torch.cat(magnetic_preds)
    magnetic_targets = torch.cat(magnetic_targets)
    topological_preds = torch.cat(topological_preds)
    topological_targets = torch.cat(topological_targets)
    
    # Calculate metrics
    magnetic_acc = accuracy_score(magnetic_targets, magnetic_preds)
    magnetic_f1 = f1_score(magnetic_targets, magnetic_preds, average='macro')
    topological_acc = accuracy_score(topological_targets, topological_preds)
    topological_f1 = f1_score(topological_targets, topological_preds, average='macro')
    
    # Calculate average losses
    avg_loss = total_loss / len(dataloader.dataset)
    avg_magnetic_loss = magnetic_loss_total / len(dataloader.dataset)
    avg_topological_loss = topological_loss_total / len(dataloader.dataset)
    
    return (avg_loss, avg_magnetic_loss, avg_topological_loss, 
            magnetic_acc, magnetic_f1, topological_acc, topological_f1)


def train_model(model, train_loader, val_loader, device, model_save_path="./model"):
    """
    Train the model with early stopping.
    
    Args:
        model: MagneticTopologicalTransformer model
        train_loader: Training data loader
        val_loader: Validation data loader
        device: Device to train on (CPU or GPU)
        model_save_path: Directory to save model checkpoints
        
    Returns:
        model: Trained model
        history: Training history
    """
    # Create save directory if it doesn't exist
    os.makedirs(model_save_path, exist_ok=True)
    
    # Initialize optimizer and scheduler
    optimizer = optim.AdamW(
        model.parameters(), 
        lr=PARAMS['lr'], 
        weight_decay=PARAMS['weight_decay']
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min', 
        factor=0.5, 
        patience=10
    )
    
    # Initialize early stopping
    early_stopping = EarlyStopping(patience=PARAMS['early_stop_patience'])
    
    # Initialize best model tracker
    best_val_loss = float('inf')
    best_model_path = os.path.join(model_save_path, "best_model.pt")
    
    # Initialize training history
    history = {
        'train_loss': [], 
        'val_loss': [],
        'train_magnetic_loss': [], 
        'val_magnetic_loss': [],
        'train_topological_loss': [], 
        'val_topological_loss': [],
        'magnetic_acc': [],
        'magnetic_f1': [],
        'topological_acc': [],
        'topological_f1': []
    }
    
    # Training loop
    start_time = time.time()
    print("Starting training...")
    
    for epoch in range(PARAMS['max_epochs']):
        epoch_start = time.time()
        
        # Train one epoch
        train_loss, train_magnetic_loss, train_topological_loss = train_epoch(
            model, train_loader, optimizer, device
        )
        
        # Validate
        val_metrics = validate(model, val_loader, device)
        (val_loss, val_magnetic_loss, val_topological_loss, 
         magnetic_acc, magnetic_f1, topological_acc, topological_f1) = val_metrics
        
        # Update learning rate
        scheduler.step(val_loss)  # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_magnetic_loss'].append(train_magnetic_loss)
        history['val_magnetic_loss'].append(val_magnetic_loss)
        history['train_topological_loss'].append(train_topological_loss)
        history['val_topological_loss'].append(val_topological_loss)
        history['magnetic_acc'].append(magnetic_acc)
        history['magnetic_f1'].append(magnetic_f1)
        history['topological_acc'].append(topological_acc)
        history['topological_f1'].append(topological_f1)

        # Print progress
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{PARAMS['max_epochs']} - {epoch_time:.2f}s - "
            f"Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f} - "
            f"Magnetic Loss: {train_magnetic_loss:.4f}/{val_magnetic_loss:.4f} - "
            f"Topological Loss: {train_topological_loss:.4f}/{val_topological_loss:.4f} - "
            f"Magnetic Acc: {magnetic_acc:.4f} - Magnetic F1: {magnetic_f1:.4f} - "
            f"Topological Acc: {topological_acc:.4f} - Topological F1: {topological_f1:.4f}")

        # Save checkpoint if best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'history': history,
                'best_val_loss': best_val_loss
            }, os.path.join(PARAMS['checkpoint_dir'], 'best_model.pth'))
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")
            
        # Early stopping check
        if val_loss > best_val_loss and epoch > PARAMS['min_epochs']:
            early_stop_counter += 1
            if early_stop_counter >= PARAMS['patience']:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break
        else:
            early_stop_counter = 0


def test_model(model, test_loader, device):
    model.eval()  # Set model to evaluation mode
    test_loss = 0
    magnetic_loss = 0
    topological_loss = 0
    
    all_magnetic_preds = []
    all_magnetic_labels = []
    all_topological_preds = []
    all_topological_labels = []
    
    with torch.no_grad():  # No gradient calculation needed for testing
        for batch_idx, (data, magnetic_target, topological_target) in enumerate(test_loader):
            data = data.to(device)
            magnetic_target = magnetic_target.to(device)
            topological_target = topological_target.to(device)
            
            # Forward pass
            magnetic_output, topological_output = model(data)
            
            # Calculate losses
            batch_magnetic_loss = F.binary_cross_entropy_with_logits(magnetic_output, magnetic_target)
            batch_topological_loss = F.binary_cross_entropy_with_logits(topological_output, topological_target)
            batch_loss = batch_magnetic_loss + batch_topological_loss
            
            # Accumulate losses
            test_loss += batch_loss.item()
            magnetic_loss += batch_magnetic_loss.item()
            topological_loss += batch_topological_loss.item()
            
            # Store predictions and labels for metrics calculation
            magnetic_preds = (torch.sigmoid(magnetic_output) > 0.5).float().cpu().numpy()
            topological_preds = (torch.sigmoid(topological_output) > 0.5).float().cpu().numpy()
            
            all_magnetic_preds.extend(magnetic_preds)
            all_magnetic_labels.extend(magnetic_target.cpu().numpy())
            all_topological_preds.extend(topological_preds)
            all_topological_labels.extend(topological_target.cpu().numpy())
    
    # Calculate average losses
    test_loss /= len(test_loader)
    magnetic_loss /= len(test_loader)
    topological_loss /= len(test_loader)
    
    # Convert lists to arrays for scikit-learn metrics
    all_magnetic_preds = np.array(all_magnetic_preds)
    all_magnetic_labels = np.array(all_magnetic_labels)
    all_topological_preds = np.array(all_topological_preds)
    all_topological_labels = np.array(all_topological_labels)
    
    # Calculate metrics
    magnetic_acc = accuracy_score(all_magnetic_labels, all_magnetic_preds)
    magnetic_f1 = f1_score(all_magnetic_labels, all_magnetic_preds, average='weighted')
    topological_acc = accuracy_score(all_topological_labels, all_topological_preds)
    topological_f1 = f1_score(all_topological_labels, all_topological_preds, average='weighted')
    
    # Print results
    print(f"Test Results:")
    print(f"Total Loss: {test_loss:.4f}")
    print(f"Magnetic Loss: {magnetic_loss:.4f}, Accuracy: {magnetic_acc:.4f}, F1 Score: {magnetic_f1:.4f}")
    print(f"Topological Loss: {topological_loss:.4f}, Accuracy: {topological_acc:.4f}, F1 Score: {topological_f1:.4f}")
    
    # Return all metrics
    return {
        'test_loss': test_loss,
        'magnetic_loss': magnetic_loss,
        'topological_loss': topological_loss,
        'magnetic_acc': magnetic_acc,
        'magnetic_f1': magnetic_f1,
        'topological_acc': topological_acc,
        'topological_f1': topological_f1,
        'magnetic_preds': all_magnetic_preds,
        'magnetic_labels': all_magnetic_labels,
        'topological_preds': all_topological_preds,
        'topological_labels': all_topological_labels
    }


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import os
import time
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from pymatgen.core import Structure

# Custom dataset for materials data
class MaterialsDataset(Dataset):
    def __init__(self, struct_dict, magnetic_labels, topological_labels, transform=None):
        """
        Args:
            struct_dict (dict): Dictionary containing materials data with keys:
                - structures: list of pymatgen Structure objects
                - materials_id: list of material IDs
                - nsites: list of number of sites
                - formulas: list of chemical formulas
                - order: list of order parameters
            magnetic_labels (array): Binary labels for magnetic properties
            topological_labels (array): Binary labels for topological properties
            transform (callable, optional): Optional transform to be applied on features
        """
        self.structures = struct_dict["structures"]
        self.materials_ids = struct_dict["materials_id"]
        self.nsites = struct_dict["nsites"]
        self.formulas = struct_dict["formulas"]
        self.order = struct_dict["order"]
        
        self.magnetic_labels = magnetic_labels
        self.topological_labels = topological_labels
        self.transform = transform
        
    def __len__(self):
        return len(self.structures)
    
    def __getitem__(self, idx):
        # Extract relevant features from the structure
        structure = self.structures[idx]
        
        # Feature extraction from structure
        features = self._extract_features(structure, idx)
        
        magnetic_label = self.magnetic_labels[idx]
        topological_label = self.topological_labels[idx]
        
        if self.transform:
            features = self.transform(features)
            
        return torch.tensor(features, dtype=torch.float32), torch.tensor(magnetic_label, dtype=torch.float32), torch.tensor(topological_label, dtype=torch.float32)
    
    def _extract_features(self, structure, idx):
        """Extract features from a pymatgen Structure object and other available data"""
        # Here you can implement feature extraction based on the structure
        # This is a simple example - you'll want to enhance this based on your domain knowledge
        
        # Basic structural features
        num_sites = self.nsites[idx]
        order_param = self.order[idx]
        
        # Get lattice parameters
        a, b, c = structure.lattice.abc
        alpha, beta, gamma = structure.lattice.angles
        volume = structure.volume
        density = structure.density
        
        # Element-based features (example)
        elements = [site.specie.symbol for site in structure]
        unique_elements = set(elements)
        num_elements = len(unique_elements)
        
        # Count of each element
        element_counts = {}
        for element in elements:
            if element in element_counts:
                element_counts[element] += 1
            else:
                element_counts[element] = 1
        
        # Statistical features of atomic properties
        atomic_numbers = [site.specie.Z for site in structure]
        avg_atomic_number = np.mean(atomic_numbers)
        std_atomic_number = np.std(atomic_numbers)
        
        # Combine all features
        features = [
            num_sites, 
            order_param,
            a, b, c, 
            alpha, beta, gamma,
            volume,
            density,
            num_elements,
            avg_atomic_number,
            std_atomic_number
        ]
        
        # You can add more domain-specific features here
        
        return np.array(features, dtype=np.float32)

# Model Definition - Multi-task Neural Network
class MagneticTopologicalModel(nn.Module):
    def __init__(self, input_dim, hidden_dims=[256, 128, 64], dropout_rate=0.3):
        super(MagneticTopologicalModel, self).__init__()
        
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        
        # Shared layers
        layers = []
        prev_dim = input_dim
        for dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, dim))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.Dropout(dropout_rate))
            prev_dim = dim
        
        self.shared_layers = nn.Sequential(*layers)
        
        # Task-specific heads
        self.magnetic_head = nn.Linear(hidden_dims[-1], 1)
        self.topological_head = nn.Linear(hidden_dims[-1], 1)
        
    def forward(self, x):
        # Forward pass through shared layers
        shared_features = self.shared_layers(x)
        
        # Task-specific predictions
        magnetic_output = self.magnetic_head(shared_features)
        topological_output = self.topological_head(shared_features)
        
        return magnetic_output.squeeze(), topological_output.squeeze()

def train_epoch(model, train_loader, optimizer, device):
    model.train()
    train_loss = 0
    magnetic_loss = 0
    topological_loss = 0
    
    all_magnetic_preds = []
    all_magnetic_labels = []
    all_topological_preds = []
    all_topological_labels = []
    
    for batch_idx, (data, magnetic_target, topological_target) in enumerate(train_loader):
        data = data.to(device)
        magnetic_target = magnetic_target.to(device)
        topological_target = topological_target.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        magnetic_output, topological_output = model(data)
        
        # Calculate losses
        batch_magnetic_loss = F.binary_cross_entropy_with_logits(magnetic_output, magnetic_target)
        batch_topological_loss = F.binary_cross_entropy_with_logits(topological_output, topological_target)
        batch_loss = batch_magnetic_loss + batch_topological_loss
        
        # Backward pass
        batch_loss.backward()
        optimizer.step()
        
        # Accumulate losses
        train_loss += batch_loss.item()
        magnetic_loss += batch_magnetic_loss.item()
        topological_loss += batch_topological_loss.item()
        
        # Store predictions and labels for metrics calculation
        magnetic_preds = (torch.sigmoid(magnetic_output) > 0.5).float().cpu().numpy()
        topological_preds = (torch.sigmoid(topological_output) > 0.5).float().cpu().numpy()
        
        all_magnetic_preds.extend(magnetic_preds)
        all_magnetic_labels.extend(magnetic_target.cpu().numpy())
        all_topological_preds.extend(topological_preds)
        all_topological_labels.extend(topological_target.cpu().numpy())
    
    # Calculate average losses
    train_loss /= len(train_loader)
    magnetic_loss /= len(train_loader)
    topological_loss /= len(train_loader)
    
    # Calculate metrics
    magnetic_acc = accuracy_score(all_magnetic_labels, all_magnetic_preds)
    magnetic_f1 = f1_score(all_magnetic_labels, all_magnetic_preds, average='weighted')
    topological_acc = accuracy_score(all_topological_labels, all_topological_preds)
    topological_f1 = f1_score(all_topological_labels, all_topological_preds, average='weighted')
    
    return train_loss, magnetic_loss, topological_loss, magnetic_acc, magnetic_f1, topological_acc, topological_f1

def validate(model, val_loader, device):
    model.eval()
    val_loss = 0
    magnetic_loss = 0
    topological_loss = 0
    
    all_magnetic_preds = []
    all_magnetic_labels = []
    all_topological_preds = []
    all_topological_labels = []
    
    with torch.no_grad():
        for batch_idx, (data, magnetic_target, topological_target) in enumerate(val_loader):
            data = data.to(device)
            magnetic_target = magnetic_target.to(device)
            topological_target = topological_target.to(device)
            
            # Forward pass
            magnetic_output, topological_output = model(data)
            
            # Calculate losses
            batch_magnetic_loss = F.binary_cross_entropy_with_logits(magnetic_output, magnetic_target)
            batch_topological_loss = F.binary_cross_entropy_with_logits(topological_output, topological_target)
            batch_loss = batch_magnetic_loss + batch_topological_loss
            
            # Accumulate losses
            val_loss += batch_loss.item()
            magnetic_loss += batch_magnetic_loss.item()
            topological_loss += batch_topological_loss.item()
            
            # Store predictions and labels for metrics calculation
            magnetic_preds = (torch.sigmoid(magnetic_output) > 0.5).float().cpu().numpy()
            topological_preds = (torch.sigmoid(topological_output) > 0.5).float().cpu().numpy()
            
            all_magnetic_preds.extend(magnetic_preds)
            all_magnetic_labels.extend(magnetic_target.cpu().numpy())
            all_topological_preds.extend(topological_preds)
            all_topological_labels.extend(topological_target.cpu().numpy())
    
    # Calculate average losses
    val_loss /= len(val_loader)
    magnetic_loss /= len(val_loader)
    topological_loss /= len(val_loader)
    
    # Calculate metrics
    magnetic_acc = accuracy_score(all_magnetic_labels, all_magnetic_preds)
    magnetic_f1 = f1_score(all_magnetic_labels, all_magnetic_preds, average='weighted')
    topological_acc = accuracy_score(all_topological_labels, all_topological_preds)
    topological_f1 = f1_score(all_topological_labels, all_topological_preds, average='weighted')
    
    return val_loss, magnetic_loss, topological_loss, magnetic_acc, magnetic_f1, topological_acc, topological_f1

def test_model(model, test_loader, device):
    model.eval()
    test_loss = 0
    magnetic_loss = 0
    topological_loss = 0
    
    all_magnetic_preds = []
    all_magnetic_labels = []
    all_topological_preds = []
    all_topological_labels = []
    all_material_ids = []  # To track which materials were predicted correctly/incorrectly
    
    with torch.no_grad():
        for batch_idx, (data, magnetic_target, topological_target) in enumerate(test_loader):
            data = data.to(device)
            magnetic_target = magnetic_target.to(device)
            topological_target = topological_target.to(device)
            
            # Forward pass
            magnetic_output, topological_output = model(data)
            
            # Calculate losses
            batch_magnetic_loss = F.binary_cross_entropy_with_logits(magnetic_output, magnetic_target)
            batch_topological_loss = F.binary_cross_entropy_with_logits(topological_output, topological_target)
            batch_loss = batch_magnetic_loss + batch_topological_loss
            
            # Accumulate losses
            test_loss += batch_loss.item()
            magnetic_loss += batch_magnetic_loss.item()
            topological_loss += batch_topological_loss.item()
            
            # Store predictions and labels for metrics calculation
            magnetic_preds = (torch.sigmoid(magnetic_output) > 0.5).float().cpu().numpy()
            topological_preds = (torch.sigmoid(topological_output) > 0.5).float().cpu().numpy()
            
            all_magnetic_preds.extend(magnetic_preds)
            all_magnetic_labels.extend(magnetic_target.cpu().numpy())
            all_topological_preds.extend(topological_preds)
            all_topological_labels.extend(topological_target.cpu().numpy())
            
            # Track material IDs for this batch (if available in the dataset)
            # all_material_ids.extend([test_loader.dataset.materials_ids[idx] for idx in range(batch_idx * test_loader.batch_size, min((batch_idx + 1) * test_loader.batch_size, len(test_loader.dataset)))])
    
    # Calculate average losses
    test_loss /= len(test_loader)
    magnetic_loss /= len(test_loader)
    topological_loss /= len(test_loader)
    
    # Convert lists to arrays for scikit-learn metrics
    all_magnetic_preds = np.array(all_magnetic_preds)
    all_magnetic_labels = np.array(all_magnetic_labels)
    all_topological_preds = np.array(all_topological_preds)
    all_topological_labels = np.array(all_topological_labels)
    
    # Calculate metrics
    magnetic_acc = accuracy_score(all_magnetic_labels, all_magnetic_preds)
    magnetic_f1 = f1_score(all_magnetic_labels, all_magnetic_preds, average='weighted')
    topological_acc = accuracy_score(all_topological_labels, all_topological_preds)
    topological_f1 = f1_score(all_topological_labels, all_topological_preds, average='weighted')
    
    # Print results
    print(f"Test Results:")
    print(f"Total Loss: {test_loss:.4f}")
    print(f"Magnetic Loss: {magnetic_loss:.4f}, Accuracy: {magnetic_acc:.4f}, F1 Score: {magnetic_f1:.4f}")
    print(f"Topological Loss: {topological_loss:.4f}, Accuracy: {topological_acc:.4f}, F1 Score: {topological_f1:.4f}")
    
    # Calculate and plot confusion matrices
    plot_confusion_matrices(all_magnetic_labels, all_magnetic_preds, all_topological_labels, all_topological_preds)
    
    return {
        'test_loss': test_loss,
        'magnetic_loss': magnetic_loss,
        'topological_loss': topological_loss,
        'magnetic_acc': magnetic_acc,
        'magnetic_f1': magnetic_f1,
        'topological_acc': topological_acc,
        'topological_f1': topological_f1,
        'magnetic_preds': all_magnetic_preds,
        'magnetic_labels': all_magnetic_labels,
        'topological_preds': all_topological_preds,
        'topological_labels': all_topological_labels
    }

def plot_confusion_matrices(magnetic_labels, magnetic_preds, topological_labels, topological_preds):
    # Create figure with two subplots
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    
    # Plot magnetic confusion matrix
    magnetic_cm = confusion_matrix(magnetic_labels, magnetic_preds)
    sns.heatmap(magnetic_cm, annot=True, fmt="d", cmap="Blues", ax=axes[0])
    axes[0].set_title("Magnetic Property Confusion Matrix")
    axes[0].set_xlabel("Predicted")
    axes[0].set_ylabel("True")
    
    # Plot topological confusion matrix
    topological_cm = confusion_matrix(topological_labels, topological_preds)
    sns.heatmap(topological_cm, annot=True, fmt="d", cmap="Greens", ax=axes[1])
    axes[1].set_title("Topological Property Confusion Matrix")
    axes[1].set_xlabel("Predicted")
    axes[1].set_ylabel("True")
    
    plt.tight_layout()
    plt.savefig("confusion_matrices.png")
    plt.show()
    
    # Print additional metrics
    print("\nDetailed Classification Results:")
    print("Magnetic Property:")
    print(f"True Positive: {magnetic_cm[1, 1]}")
    print(f"False Positive: {magnetic_cm[0, 1]}")
    print(f"True Negative: {magnetic_cm[0, 0]}")
    print(f"False Negative: {magnetic_cm[1, 0]}")
    
    print("\nTopological Property:")
    print(f"True Positive: {topological_cm[1, 1]}")
    print(f"False Positive: {topological_cm[0, 1]}")
    print(f"True Negative: {topological_cm[0, 0]}")
    print(f"False Negative: {topological_cm[1, 0]}")

# Main training function
def train_model(train_struct_dict, val_struct_dict, train_magnetic_labels, train_topological_labels, 
               val_magnetic_labels, val_topological_labels, params):
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create datasets
    train_dataset = MaterialsDataset(train_struct_dict, train_magnetic_labels, train_topological_labels)
    val_dataset = MaterialsDataset(val_struct_dict, val_magnetic_labels, val_topological_labels)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=params['batch_size'], shuffle=False)
    
    # Determine input dimension based on feature extraction
    sample_features = train_dataset[0][0]
    input_dim = sample_features.shape[0]
    
    # Initialize model
    model = MagneticTopologicalModel(input_dim, hidden_dims=params['hidden_dims'], 
                                     dropout_rate=params['dropout_rate']).to(device)
    
    # Initialize optimizer and scheduler
    optimizer = Adam(model.parameters(), lr=params['learning_rate'], weight_decay=params['weight_decay'])
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
    
    # Initialize tracking variables
    best_val_loss = float('inf')
    early_stop_counter = 0
    history = {
        'train_loss': [], 'val_loss': [],
        'train_magnetic_loss': [], 'val_magnetic_loss': [],
        'train_topological_loss': [], 'val_topological_loss': [],
        'magnetic_acc': [], 'magnetic_f1': [],
        'topological_acc': [], 'topological_f1': []
    }
    
    # Create checkpoint directory if it doesn't exist
    os.makedirs(params['checkpoint_dir'], exist_ok=True)
    
    # Training loop
    for epoch in range(params['max_epochs']):
        epoch_start = time.time()
        
        # Train one epoch
        train_loss, train_magnetic_loss, train_topological_loss, magnetic_acc, magnetic_f1, topological_acc, topological_f1 = train_epoch(model, train_loader, optimizer, device)
        
        # Validate
        val_loss, val_magnetic_loss, val_topological_loss, val_magnetic_acc, val_magnetic_f1, val_topological_acc, val_topological_f1 = validate(model, val_loader, device)
        
        # Update learning rate
        scheduler.step(val_loss)
        
        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_magnetic_loss'].append(train_magnetic_loss)
        history['val_magnetic_loss'].append(val_magnetic_loss)
        history['train_topological_loss'].append(train_topological_loss)
        history['val_topological_loss'].append(val_topological_loss)
        history['magnetic_acc'].append(val_magnetic_acc)
        history['magnetic_f1'].append(val_magnetic_f1)
        history['topological_acc'].append(val_topological_acc)
        history['topological_f1'].append(val_topological_f1)
        
        # Print progress
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{params['max_epochs']} - {epoch_time:.2f}s - "
              f"Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f} - "
              f"Magnetic Loss: {train_magnetic_loss:.4f}/{val_magnetic_loss:.4f} - "
              f"Topological Loss: {train_topological_loss:.4f}/{val_topological_loss:.4f} - "
              f"Magnetic Acc: {val_magnetic_acc:.4f} - Magnetic F1: {val_magnetic_f1:.4f} - "
              f"Topological Acc: {val_topological_acc:.4f} - Topological F1: {val_topological_f1:.4f}")
        
        # Save checkpoint if best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'history': history,
                'best_val_loss': best_val_loss
            }, os.path.join(params['checkpoint_dir'], 'best_model.pth'))
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")
            
            # Reset early stopping counter
            early_stop_counter = 0
        else:
            # Increment early stopping counter
            early_stop_counter += 1
            if early_stop_counter >= params['patience'] and epoch > params['min_epochs']:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break
    
    # Plot training history
    plot_training_history(history)
    
    return model, history

def plot_training_history(history):
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Create a figure with subplots
    fig, axs = plt.subplots(2, 2, figsize=(16, 12))
    
    # Plot losses
    axs[0, 0].plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    axs[0, 0].plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    axs[0, 0].set_title('Total Loss')
    axs[0, 0].set_xlabel('Epochs')
    axs[0, 0].set_ylabel('Loss')
    axs[0, 0].legend()
    
    # Plot task-specific losses
    axs[0, 1].plot(epochs, history['train_magnetic_loss'], 'b--', label='Train Magnetic Loss')
    axs[0, 1].plot(epochs, history['val_magnetic_loss'], 'r--', label='Val Magnetic Loss')
    axs[0, 1].plot(epochs, history['train_topological_loss'], 'g--', label='Train Topological Loss')
    axs[0, 1].plot(epochs, history['val_topological_loss'], 'm--', label='Val Topological Loss')
    axs[0, 1].set_title('Task-Specific Losses')
    axs[0, 1].set_xlabel('Epochs')
    axs[0, 1].set_ylabel('Loss')
    axs[0, 1].legend()
    
    # Plot magnetic metrics
    axs[1, 0].plot(epochs, history['magnetic_acc'], 'b-', label='Magnetic Accuracy')
    axs[1, 0].plot(epochs, history['magnetic_f1'], 'r-', label='Magnetic F1 Score')
    axs[1, 0].set_title('Magnetic Property Metrics')
    axs[1, 0].set_xlabel('Epochs')
    axs[1, 0].set_ylabel('Score')
    axs[1, 0].legend()
    
    # Plot topological metrics
    axs[1, 1].plot(epochs, history['topological_acc'], 'b-', label='Topological Accuracy')
    axs[1, 1].plot(epochs, history['topological_f1'], 'r-', label='Topological F1 Score')
    axs[1, 1].set_title('Topological Property Metrics')
    axs[1, 1].set_xlabel('Epochs')
    axs[1, 1].set_ylabel('Score')
    axs[1, 1].legend()
    
    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()

# Main test function
def run_test(test_struct_dict, test_magnetic_labels, test_topological_labels, checkpoint_path, batch_size=32):
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create test dataset
    test_dataset = MaterialsDataset(test_struct_dict, test_magnetic_labels, test_topological_labels)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    # Determine input dimension based on feature extraction
    sample_features = test_dataset[0][0]
    input_dim = sample_features.shape[0]
    
    # Initialize model
    model = MagneticTopologicalModel(input_dim).to(device)
    
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint['epoch']+1} with validation loss: {checkpoint['best_val_loss']:.4f}")
    
    # Test the model
    test_results = test_model(model, test_loader, device)
    
    # You can add more analysis here based on test_results
    
    return test_results, model

# Usage example
if __name__ == "__main__":
    # Define hyperparameters
    PARAMS = {
        'batch_size': 32,
        'learning_rate': 0.001,
        'weight_decay': 1e-5,
        'hidden_dims': [256, 128, 64],
        'dropout_rate': 0.3,
        'max_epochs': 100,
        'min_epochs': 10,
        'patience': 10,
        'checkpoint_dir': './checkpoints'
    }
    
    # You would need to prepare these variables:
    # 1. train_struct_dict, val_struct_dict, test_struct_dict - dictionaries with your structure data
    # 2. train_magnetic_labels, train_topological_labels - binary labels for training
    # 3. val_magnetic_labels, val_topological_labels - binary labels for validation
    # 4. test_magnetic_labels, test_topological_labels - binary labels for testing
    
    # Example training call:
    # model, history = train_model(train_struct_dict, val_struct_dict, 
    #                            train_magnetic_labels, train_topological_labels,
    #                            val_magnetic_labels, val_topological_labels, 
    #                            PARAMS)
    
    # Example testing call:
    # test_results, model = run_test(test_struct_dict, test_magnetic_labels, test_topological_labels,
    #                               os.path.join(PARAMS['checkpoint_dir'], 'best_model.pth'))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# first attempt at multi task learning and integration with topological quantum chemistry database
import e3nn.util
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric
import torch_scatter

import e3nn
from e3nn import o3
#from e3nn.util.datatypes import DataPeriodicNeighbors
#from e3nn.nn._gate import GatedConvParityNetwork
#from e3nn.math._linalg import Kernel

import pymatgen as mg
import pymatgen.io
from pymatgen.core.structure import Structure
import pymatgen.analysis.magnetism.analyzer as pg
from mp_api.client import MPRester
import numpy as np
import pickle
from mendeleev import element
import matplotlib.pyplot as plt

from sklearn.metrics import average_precision_score
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score

import io
import random
import math
import sys
import time, os
import datetime
from pathlib import Path
from dotenv import load_dotenv

In [3]:
load_dotenv()
load_dotenv(Path("/Users/abiralshakya/Documents/Research/Topological_Insulators_OnGithub/generative_nmti/Integrated_Magnetic_Topological/matprojectapi.env"))
api_key = os.getenv("MP_API_KEY")

In [4]:
import numpy as np
import torch
import pymatgen as pg
from pymatgen.ext.matproj import MPRester
from pymatgen.analysis.magnetism import CollinearMagneticStructureAnalyzer

order_list_mp = []
structures_list_mp = []
formula_list_mp = []
sites_list = []
id_list_mp = []
y_values_mp = []

order_encode = {"NM": 0, "AFM": 1, "FM": 2, "FiM": 2}
topo_encode = {False: 0, True: 1}

# Load data
mp_structures_dict = torch.load('/Users/abiralshakya/Documents/Research/Topological_Insulators_OnGithub/generative_nmti/Integrated_Magnetic_Topological/magnetic_order/preload_data/mp_structures_2025-04-07_12-52.pt', 
                                weights_only=False)

structures = mp_structures_dict['structures']
materials = mp_structures_dict['materials_id']
formulas = mp_structures_dict['formulas']
orders = mp_structures_dict['order']
nsites = mp_structures_dict['nsites']


In [5]:
order_list = []

for struct in structures:
    analyzer = CollinearMagneticStructureAnalyzer(struct)
    order_list.append(analyzer.ordering.name)

In [6]:
id_NM = [i for i, order in enumerate(order_list) if order == 'NM']
id_AFM = [i for i, order in enumerate(order_list) if order == 'AFM']
id_FM = [i for i, order in enumerate(order_list) if order in ['FM', 'FiM']]

# Shuffle
np.random.shuffle(id_NM)
np.random.shuffle(id_FM)
np.random.shuffle(id_AFM)

# Balance dataset (keeping AFM as reference size)
id_AFM, id_AFM_to_delete = np.split(id_AFM, [int(len(id_AFM))])
id_NM, id_NM_to_delete = np.split(id_NM, [int(1.2 * len(id_AFM))])
id_FM, id_FM_to_delete = np.split(id_FM, [int(1.2 * len(id_AFM))])

# Final index list
selected_ids = np.concatenate((id_NM, id_FM, id_AFM))
np.random.shuffle(selected_ids)

In [7]:
for idx in selected_ids:
    structure = structures[idx]
    material_id = materials[idx]
    formula = formulas[idx]
    nsite = nsites[idx]

    analyzer = CollinearMagneticStructureAnalyzer(structure)
    ordering = analyzer.ordering

    structures_list_mp.append(structure)
    id_list_mp.append(material_id)
    formula_list_mp.append(formula)
    sites_list.append(nsite)
    order_list_mp.append(ordering)


In [8]:
topo_encode = {False: 0, True: 1}
topo_labels = []

from mp_api.client import MPRester
m = MPRester(api_key=api_key)

for material_id in id_list_mp:
    try:
        result = m.materials.summary.search(material_ids=[material_id])
        if result and hasattr(result[0], "is_topological"):
            label = result[0].is_topological
            topo_labels.append(topo_encode[label])
        else:
            print(f"No topological info for {material_id}")
            topo_labels.append(topo_encode[False])
    except Exception as e:
        print(f"Error retrieving TI label for {material_id}: {e}")
        topo_labels.append(topo_encode[False])

Retrieving SummaryDoc documents: 100%|██████████| 1/1 [00:00<00:00, 16710.37it/s]


No topological info for mp-1245108


Retrieving SummaryDoc documents: 100%|██████████| 1/1 [00:00<00:00, 25731.93it/s]


No topological info for mp-1184067


Retrieving SummaryDoc documents: 100%|██████████| 1/1 [00:00<00:00, 25731.93it/s]


No topological info for mp-23


Retrieving SummaryDoc documents: 100%|██████████| 1/1 [00:00<00:00, 22192.08it/s]


No topological info for mp-1067880


Retrieving SummaryDoc documents: 100%|██████████| 1/1 [00:00<00:00, 23301.69it/s]


No topological info for mp-10658


Retrieving SummaryDoc documents: 100%|██████████| 1/1 [00:00<00:00, 28532.68it/s]


No topological info for mp-1184113


Retrieving SummaryDoc documents: 100%|██████████| 1/1 [00:00<00:00, 32263.88it/s]


No topological info for mp-613989


Retrieving SummaryDoc documents: 100%|██████████| 1/1 [00:00<00:00, 23301.69it/s]


No topological info for mp-90


Retrieving SummaryDoc documents: 100%|██████████| 1/1 [00:00<00:00, 18157.16it/s]

No topological info for mp-2647013





In [9]:
import requests
import json
import pandas as pd
from pymatgen.core import Structure

def fetch_tqc_magnetic_data(bcs_id="3.7"):
    """
    Fetch magnetic topological materials data from the TQC database
    """
    # Simulated example structure — replace with API call in production
    return {
        "magnetic_materials": [
            {
                "id": "mp-123",
                "formula": "Fe2O3",
                "spacegroup": 167,
                "magnetic_ordering": "AFM",
                "topological_class": "Strong TI",
                "band_gap": 0.8,
                "magnetic_moment": 4.2
            },
            # Add more entries as needed
        ]
    }

def create_combined_dataset(mp_structures, tqc_data):
    """
    Combine Materials Project structures (as Structure objects or dicts) with TQC magnetic data
    """
    tqc_map = {item["id"]: item for item in tqc_data["magnetic_materials"]}
    combined_data = []

    for struct in mp_structures:
        # Handle both dict and Structure input
        if isinstance(struct, dict):
            mp_id = struct.get("material_id")
            structure = struct.get("structure")
            formula = struct.get("pretty_formula")
            nsites = struct.get("nsites")
        elif isinstance(struct, Structure):
            mp_id = getattr(struct, "material_id", None)
            structure = struct
            formula = structure.formula
            nsites = structure.num_sites
        else:
            continue  # skip unknown formats

        if mp_id in tqc_map:
            tqc_info = tqc_map[mp_id]
            struct_data = {
                "structure": structure,
                "material_id": mp_id,
                "formula": formula,
                "nsites": nsites,
                "magnetic_ordering": tqc_info["magnetic_ordering"],
                "topological_class": tqc_info["topological_class"],
                "band_gap": tqc_info.get("band_gap", None),
                "magnetic_moment": tqc_info.get("magnetic_moment", None),
                "symmetry_operations": tqc_info.get("symmetry_operations", None)
            }
            combined_data.append(struct_data)

    return combined_data

# Example usage:
tqc_data = fetch_tqc_magnetic_data(bcs_id="3.7")
combined_dataset = create_combined_dataset(structures, tqc_data)


In [10]:
from pymatgen.analysis.local_env import CrystalNN
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
import torch

def extract_magnetic_features(structure):
    """
    Extract features relevant for magnetic ordering classification
    """
    features = {}

    # 1. Magnetic elements
    magnetic_elements = ['Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Ce', 'Pr', 'Nd',
                         'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb']
    element_counts = {}
    total_atoms = len(structure)

    for site in structure:
        symbol = str(site.specie.symbol)
        element_counts[symbol] = element_counts.get(symbol, 0) + 1

    magnetic_fraction = sum(element_counts.get(el, 0) for el in magnetic_elements) / total_atoms
    features['magnetic_fraction'] = magnetic_fraction

    # 2. Magnetic exchange pathways
    magnetic_sites = [i for i, site in enumerate(structure) if str(site.specie.symbol) in magnetic_elements]
    exchange_distances = []

    for i in magnetic_sites:
        for j in magnetic_sites:
            if i < j:
                distance = structure.get_distance(i, j)
                if distance < 4.0:
                    exchange_distances.append(distance)

    if exchange_distances:
        features['avg_exchange_distance'] = sum(exchange_distances) / len(exchange_distances)
        features['min_exchange_distance'] = min(exchange_distances)
    else:
        features['avg_exchange_distance'] = 0.0
        features['min_exchange_distance'] = 0.0

    # 3. Crystal field distortion (optional, just log one example distortion)
    # NOTE: This can produce many features, we log just one for simplicity
    distortion_list = []
    for i in magnetic_sites:
        neighbors = structure.get_neighbors(structure[i], 3.0)
        if neighbors:
            distances = [n[1] for n in neighbors]
            avg_distance = sum(distances) / len(distances)
            distortion = sum((d - avg_distance)**2 for d in distances) / len(distances)
            distortion_list.append(distortion)

    features['avg_coordination_distortion'] = (
        sum(distortion_list) / len(distortion_list) if distortion_list else 0.0
    )

    # 4. Symmetry features
    try:
        analyzer = SpacegroupAnalyzer(structure)
        spacegroup = analyzer.get_space_group_number()
    except Exception:
        spacegroup = 0  # fallback if symmetry detection fails

    features['spacegroup'] = spacegroup

    # 5. Time-reversal breaking (heuristic)
    features['potential_time_reversal_breaking'] = 1 if magnetic_fraction > 0.1 else 0

    return features

# Apply feature extraction to all structures
magnetic_features = []
for struct in structures_list_mp:
    magnetic_features.append(extract_magnetic_features(struct))

# Convert to tensor (ensure values are numeric)
magnetic_feature_tensor = torch.tensor([
    [
        float(f['magnetic_fraction']),
        float(f['avg_exchange_distance']),
        float(f['min_exchange_distance']),
        int(f['spacegroup']),
        int(f['potential_time_reversal_breaking'])
    ]
    for f in magnetic_features
], dtype=torch.float)


In [11]:
import e3nn.nn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add

class MagneticTopologicalClassifier(torch.nn.Module):
    def __init__(self, atom_type_in, hidden_dim=128, model_kwargs=None):
        super().__init__()
        if model_kwargs is None:
            model_kwargs = {}

        # Atom embedding
        self.atom_embedding = torch.nn.Linear(atom_type_in, hidden_dim)
        
        # Magnetic attention module
        self.magnetic_attention = MagneticAttention(hidden_dim)
        
        # E3NN convolution layers (you must define GatedConvParityNetwork elsewhere)
        self.model = e3nn.nn.Gate(**model_kwargs)
       # self.model = GatedConvParityNetwork(**model_kwargs)
        
        # Magnetic ordering head
        self.magnetic_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 3)  # NM, AFM, FM/FiM
        )
        
        # Topological classification head
        self.topological_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 4)  # None, Weak TI, Strong TI, HOTI
        )

    def forward(self, x, edge_index, edge_attr, batch=None, n_norm=35):
        # Initial embedding
        x = self.atom_embedding(x)
        x = F.relu(x)
        
        # Magnetic attention layer
        x = self.magnetic_attention(x, edge_index, edge_attr)
        
        # E3NN-based processing
        x = self.model(x, edge_index, edge_attr, n_norm=n_norm)
        
        # Global pooling
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        x_global = scatter_add(x, batch, dim=0)
        
        # Classification heads
        magnetic_pred = self.magnetic_head(x_global)
        topological_pred = self.topological_head(x_global)
        
        return magnetic_pred, topological_pred


class MagneticAttention(MessagePassing):
    """
    Attention mechanism for magnetic interactions
    """
    def __init__(self, hidden_dim):
        super().__init__(aggr='add')  # Aggregation function
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.att_proj = nn.Linear(2 * hidden_dim, 1)

    def forward(self, x, edge_index, edge_attr):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        return self.propagate(edge_index, q=q, k=k, v=v, edge_attr=edge_attr)

    def message(self, q_i, k_j, v_j, edge_attr):
        attention_input = torch.cat([q_i, k_j], dim=-1)
        alpha = F.leaky_relu(self.att_proj(attention_input).squeeze(-1))
        alpha = torch.softmax(alpha, dim=0)
        return alpha.unsqueeze(-1) * v_j


In [12]:
import torch
import numpy as np
import time
from datetime import datetime

def evaluate_multi_task(model, dataloader, device):
    """Evaluate model on validation set."""
    model.eval()
    magnetic_loss_fn = torch.nn.CrossEntropyLoss()
    topological_loss_fn = torch.nn.CrossEntropyLoss()
    
    magnetic_loss_cumulative = 0.0
    topological_loss_cumulative = 0.0
    
    with torch.no_grad():
        for d in dataloader:
            d.to(device)
            # Forward pass
            magnetic_pred, topological_pred = model(d.x, d.edge_index, d.edge_attr, 
                                                  n_norm=True, batch=d.batch)
            # Compute losses
            magnetic_loss = magnetic_loss_fn(magnetic_pred, d.magnetic_y)
            topological_loss = topological_loss_fn(topological_pred, d.topological_y)
            
            # Update cumulative losses
            magnetic_loss_cumulative += magnetic_loss.detach().item()
            topological_loss_cumulative += topological_loss.detach().item()
    
    # Compute average losses
    magnetic_valid_loss = magnetic_loss_cumulative / len(dataloader)
    topological_valid_loss = topological_loss_cumulative / len(dataloader)
    
    return magnetic_valid_loss, topological_valid_loss

def multi_task_training(model, dataloader, dataloader_valid, max_iter=101, device="cpu"):
    # Define parameters
    params = {
        'adamw_lr': 0.001,  # Learning rate for AdamW optimizer
        'adamw_wd': 0.01    # Weight decay for AdamW optimizer
    }
    
    model.to(device)
    # Loss functions
    magnetic_loss_fn = torch.nn.CrossEntropyLoss()
    topological_loss_fn = torch.nn.CrossEntropyLoss()
    
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=params['adamw_lr'], weight_decay=params['adamw_wd'])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.78)
    
    # Training metrics
    valid_loss_min = np.inf
    
    for step in range(max_iter):
        model.train()
        magnetic_loss_cumulative = 0.0
        topological_loss_cumulative = 0.0
        start_time = time.time()
        
        for j, d in enumerate(dataloader):
            d.to(device)
            
            # Forward pass
            magnetic_pred, topological_pred = model(d.x, d.edge_index, d.edge_attr, 
                                                  n_norm=True, batch=d.batch)
            
            # Compute losses
            magnetic_loss = magnetic_loss_fn(magnetic_pred, d.magnetic_y)
            topological_loss = topological_loss_fn(topological_pred, d.topological_y)
            
            # Apply class weighting for imbalanced classes
            #TODO: can change this
            cost_multiplier = 1
            if d.magnetic_y.item() == 2: # FM/FiM classes
                magnetic_loss = cost_multiplier * magnetic_loss
            
            # Combine losses
            combined_loss = magnetic_loss + topological_loss
            
            # Update cumulative losses
            magnetic_loss_cumulative += magnetic_loss.detach().item()
            topological_loss_cumulative += topological_loss.detach().item()
            
            # Backward pass and optimization
            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()
        
        # Compute average losses
        magnetic_train_loss = magnetic_loss_cumulative / len(dataloader)
        topological_train_loss = topological_loss_cumulative / len(dataloader)
        
        # Validation
        magnetic_valid_loss, topological_valid_loss = evaluate_multi_task(model, dataloader_valid, device)
        
        # Log progress
        if step % 10 == 0:
            print(f"Step {step:4d}/{max_iter - 1:4d} "
                  f"Magnetic Loss: {magnetic_train_loss:7.4f} "
                  f"Topological Loss: {topological_train_loss:7.4f} "
                  f"Valid Magnetic: {magnetic_valid_loss:7.4f} "
                  f"Valid Topological: {topological_valid_loss:7.4f} "
                  f"Time: {time.time() - start_time:.4f}")
        
        # Save model if validation loss improves
        valid_loss = magnetic_valid_loss + topological_valid_loss
        if valid_loss < valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model...'.format(
                valid_loss_min, valid_loss))
            run_name = datetime.today().strftime('%Y-%m-%d_%H-%M')
            torch.save(model.state_dict(), run_name + 'multi_task_model.pt')
            valid_loss_min = valid_loss
        
        # Update learning rate
        scheduler.step()
    
    return model

In [13]:
def extract_symmetry_indicators(structure):
    """
    Extract symmetry indicators relevant for topological classification
    Based on Topological Quantum Chemistry principles
    """
    indicators = {}
    
    # Get space group information
    analyzer = SpacegroupAnalyzer(structure)
    spacegroup_number = analyzer.get_space_group_number()
    point_group = analyzer.get_point_group_symbol()
    
    indicators['spacegroup_number'] = spacegroup_number
    
    # Check for inversion symmetry (important for many topological materials)
    indicators['has_inversion'] = 1 if analyzer.has_inversion() else 0
    
    # Time-reversal symmetry is crucial for topological classification
    # In a real implementation, this would be more sophisticated
    magnetic_elements = ['Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Gd', 'Dy', 'Ho', 'Er', 'Tm', 'Yb']
    has_magnetic_elements = any(element in str(structure.composition) for element in magnetic_elements)
    indicators['potential_time_reversal_breaking'] = 1 if has_magnetic_elements else 0
    
    # BCS symmetry indicators based on BCS 3.7 from the TQC database
    # This is a simplified implementation
    if spacegroup_number in [2, 10, 47, 83, 87, 199, 216, 227]:  # These are examples
        indicators['compatible_with_bcs_3_7'] = 1
    else:
        indicators['compatible_with_bcs_3_7'] = 0
    
    # Add indicators for nonsymmorphic symmetries
    indicators['has_nonsymmorphic'] = 1 if analyzer.is_nonsymmorphic() else 0
    
    # Add band connectivity indicators
    # In real implementation, this would require electronic structure calculation
    indicators['estimated_band_inversion'] = 0  # Placeholder
    
    return indicators

In [14]:
def check_bcs_compatibility(structure, bcs_id="3.7"):
    """
    Check if a structure is compatible with a specific BCS classification
    from the Topological Quantum Chemistry database
    """
    # This would typically require calling the TQC API or using their methods
    # For this example, we'll use a simplified approach
    
    analyzer = SpacegroupAnalyzer(structure)
    spacegroup = analyzer.get_space_group_number()
    
    # BCS 3.7 compatibility rules (simplified)
    # In reality, this would involve more detailed symmetry analysis
    if bcs_id == "3.7":
        # Example conditions for BCS 3.7 (these would be replaced with actual conditions)
        if spacegroup in [2, 10, 47, 83, 87, 199, 216, 227]:
            # Check additional conditions like orbital character, band inversion, etc.
            composition = structure.composition
            
            # Check for elements commonly found in TIs with this BCS
            has_heavy_elements = any(element in str(composition) for element in ['Bi', 'Sb', 'Te', 'Se'])
            
            # Check for inversion symmetry (important for many TIs)
            has_inversion = analyzer.has_inversion()
            
            return has_heavy_elements and has_inversion
    
    return False

In [15]:
#TODO: change this lol

def predict_topological_class(struct, symmetry_indicators, is_bcs_compatible):
    if not is_bcs_compatible:
        return "None"
    if symmetry_indicators.get("z4", 0) == 1:
        return "Strong TI"
    return "Weak TI"


In [16]:
from mendeleev import element
def get_en_pauling(symbol):
    elem = element(str(symbol))
    return elem.electronegativity('pauling')

#print(get_en_pauling('O'))


In [17]:
import torch
from pymatgen.core import Element
from pymatgen.analysis.magnetism.analyzer import CollinearMagneticStructureAnalyzer
from torch_geometric.data import Data

# Dummy helper functions and dictionaries for demonstration.
def get_en_pauling(symbol):
    en_dict = {'Fe': 1.83, 'O': 3.44, 'Ru': 2.2, 'Rh': 2.28}  # Extend as needed
    return en_dict.get(symbol, 0.0)

def extract_magnetic_features(struct):
    return {'dummy_magnetic': 1.0}

def extract_symmetry_indicators(struct):
    return {'dummy_symmetry': 1.0}

def check_bcs_compatibility(struct, bcs_id):
    return True

def predict_topological_class(struct, symmetry_indicators, is_bcs_compatible):
    return 'TI'

order_encode = {'NM': 0, 'AFM': 1, 'FM': 2}  # Adjust as needed
topo_encode = {'None': 0, 'TI': 1}            # Adjust as needed
params = {'max_radius': 10.0}
n_norm = 35


class DataPeriodicNeighbors(Data):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index':
            return self.x.size(0)  # Use the number of rows in x
        if key == 'cell_index':
            return self.x.size(0)
        return super().__inc__(key, value, *args, **kwargs)
    
    def __init__(self, x=None, pos=None, lattice=None, edge_index=None, r_max=None, 
                 magnetic_y=None, topological_y=None, magnetic_features=None, 
                 symmetry_features=None, bcs_compatible=None, n_norm=None, **kwargs):
        super().__init__()
        self.x = x                        # Node features (tensor of shape [num_nodes, feature_dim])
        self.pos = pos                    # Node positions
        self.lattice = lattice            # Lattice matrix
        self.edge_index = edge_index      # Edge connectivity
        self.r_max = r_max
        self.magnetic_y = magnetic_y      # Magnetic ordering label
        self.topological_y = topological_y  # Topological label
        self.magnetic_features = magnetic_features
        self.symmetry_features = symmetry_features
        self.bcs_compatible = bcs_compatible
        self.n_norm = n_norm
        # Attach any extra fields passed via kwargs.
        for key, value in kwargs.items():
            setattr(self, key, value)


def preprocess_structures_with_tqc(structures, bcs_id="3.7"):
    """
    Enhanced preprocessing pipeline incorporating TQC data and BCS classification.
    Returns a list of DataPeriodicNeighbors objects.
    """
    processed_data = []
    len_element = 100  # Maximum number of element indices we support (e.g., Z < 100)
    
    for i, struct in enumerate(structures):
        print(f"Processing structure {i+1}/{len(structures)}", end="\r", flush=True)
        try:
            num_sites = len(struct)
            # Allocate features: 3 properties per element index
            input_features = torch.zeros(num_sites, 3 * len_element)
            
            for j, site in enumerate(struct):
                elem = str(site.specie)
                atomic_num = Element(elem).Z
                # Clip atomic number if it exceeds our fixed size
                if atomic_num >= len_element:
                    atomic_num = len_element - 1
                
                # Retrieve properties with defaults if not available.
                atomic_radius = getattr(Element(elem), 'atomic_radius', 0.0) or 0.0
                en_pauling = get_en_pauling(elem)
                if en_pauling is None:
                    en_pauling = 0.0
                dipole_polarizability = getattr(Element(elem), 'dipole_polarizability', 0.0) or 0.0
                
                # Place properties in the feature tensor at positions based on atomic_num
                input_features[j, atomic_num] = atomic_radius
                input_features[j, len_element + atomic_num] = en_pauling
                input_features[j, 2 * len_element + atomic_num] = dipole_polarizability

                # Get atom positions
            positions = torch.tensor(struct.cart_coords, dtype=torch.float)
            
            # Create edges based on distance cutoff - consider periodic boundaries
            # This is simplified - you may need to use a library like PyMatGen for proper periodic distances
            src_list = []
            dst_list = []
            edge_attr_list = []

            for src_idx in range(num_sites):
                for dst_idx in range(num_sites):
                    # Skip self-loops if you don't want them
                    if src_idx == dst_idx:
                        continue
                        
                    # Calculate distance (simplified - doesn't account for periodicity)
                    dist = torch.norm(positions[src_idx] - positions[dst_idx])
                    
                    if dist <= params['max_radius']:
                        src_list.append(src_idx)
                        dst_list.append(dst_idx)
                        # Add distance or other edge features
                        edge_attr_list.append([dist.item(), 0, 0])  # Example: [distance, 0, 0]

            if src_list:  # If there are edges
                edge_index = torch.tensor([src_list, dst_list], dtype=torch.long)
                edge_attr = torch.tensor(edge_attr_list, dtype=torch.float)
            else:
                # Fallback to self-loops if no edges found
                edge_index = torch.stack([torch.arange(num_sites), torch.arange(num_sites)], dim=0)
                edge_attr = torch.zeros((num_sites, 3), dtype=torch.float)
        
        
            
            # Create a self-loop edge_index: each node connected to itself.
            # This yields an edge_index of shape [2, num_sites] with indices 0...num_sites-1.
            #edge_index = torch.stack([torch.arange(num_sites), torch.arange(num_sites)], dim=0)
            
            # Extract additional features
            magnetic_feats = extract_magnetic_features(struct)
            symmetry_indicators = extract_symmetry_indicators(struct)
            is_bcs_compatible = check_bcs_compatibility(struct, bcs_id=bcs_id)
            
            # Get magnetic ordering using pymatgen analyzer
            analyzer = CollinearMagneticStructureAnalyzer(struct)
            magnetic_order = analyzer.ordering.name  # e.g. 'AFM', 'NM', 'FM'
            
            # Predict topological class
            topo_class = predict_topological_class(struct, symmetry_indicators, is_bcs_compatible)
            
            # Create the DataPeriodicNeighbors object with a valid edge_index
            data_point = DataPeriodicNeighbors(
                x=input_features,
                pos=torch.tensor(struct.cart_coords, dtype=torch.float),
                lattice=torch.tensor(struct.lattice.matrix, dtype=torch.float),
                edge_index=edge_index,
                edge_attr = edge_attr,
                r_max=params['max_radius'],
                magnetic_y=torch.tensor([order_encode[magnetic_order]], dtype=torch.long),
                topological_y=torch.tensor([topo_encode[topo_class]], dtype=torch.long),
                magnetic_features=torch.tensor(list(magnetic_feats.values()), dtype=torch.float),
                symmetry_features=torch.tensor(list(symmetry_indicators.values()), dtype=torch.float),
                bcs_compatible=torch.tensor([int(is_bcs_compatible)], dtype=torch.float),
                n_norm=n_norm,
            )
            
            processed_data.append(data_point)
            
        except Exception as e:
            print(f"\nError processing structure {i}: {e}")
            continue
    
    print(f"\nProcessed {len(processed_data)} structures successfully.")
    return processed_data

In [18]:
def analyze_magnetic_space_group(structure):
    """
    Analyze the magnetic space group of a structure
    Important for classifying magnetic topological materials
    """
    # This is a placeholder for more sophisticated analysis
    # In practice, you would use a library like ISOTROPY or Bilbao Crystallographic Server
    
    # Get standard space group
    analyzer = SpacegroupAnalyzer(structure)
    space_group = analyzer.get_space_group_number()
    
    # Analyze magnetic ordering
    mag_analyzer = CollinearMagneticStructureAnalyzer(structure)
    ordering = mag_analyzer.ordering.name
    
    # Simplified magnetic space group determination
    # In reality, this requires detailed analysis of symmetry and magnetic moments
    if ordering == "NM":
        # Non-magnetic: equivalent to standard space group
        mag_space_group = f"{space_group}.0"
    elif ordering == "FM":
        # Ferromagnetic: typically type III or IV MSG
        if analyzer.has_inversion():
            mag_space_group = f"{space_group}.10"  # Example type IV
        else:
            mag_space_group = f"{space_group}.8"   # Example type III
    elif ordering == "AFM":
        # Antiferromagnetic: typically type II or III MSG
        if space_group % 2 == 0:  # Even space groups often become type II
            mag_space_group = f"{space_group}.7"   # Example type II
        else:
            mag_space_group = f"{space_group}.9"   # Example type III
    elif ordering in ["FiM"]:
        # Ferromagnetic or ferrimagnetic: typically type III or IV MSG
        if analyzer.has_inversion():
            mag_space_group = f"{space_group}.10"  # Example type IV
        else:
            mag_space_group = f"{space_group}.8"   # Example type III
    else:
        mag_space_group = "unknown"
    
    # For BCS 3.7 compatibility
    compatible_with_bcs37 = False
    if mag_space_group in ["2.4", "10.42", "47.252", "83.43", "87.78", "199.13", "216.77", "227.131"]:
        compatible_with_bcs37 = True
    
    return {
        "magnetic_space_group": mag_space_group,
        "compatible_with_bcs37": compatible_with_bcs37
    }

In [19]:
class MagneticTopologicalTransformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, edge_attr_dim, num_heads=4):
        super().__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        
        self.attention_layers = nn.ModuleList([
            GraphMultiHeadAttention(hidden_dim, num_heads, edge_attr_dim)
            for _ in range(3)
        ])
        
        self.ffn_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 4),
                nn.GELU(),
                nn.Linear(hidden_dim * 4, hidden_dim)
            ) for _ in range(3)
        ])
        
        self.layer_norms1 = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(3)
        ])
        self.layer_norms2 = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(3)
        ])
        
        self.magnetic_head = nn.Linear(hidden_dim, 3)      # NM, AFM, FM/FiM
        self.topological_head = nn.Linear(hidden_dim, 2)   # Not TI, TI
        
    def forward(self, x, edge_index, edge_attr, batch):
        x = self.embedding(x)
        
        for i in range(3):
            attention_output = self.attention_layers[i](x, edge_index, edge_attr)
            x = self.layer_norms1[i](x + attention_output)
            ffn_output = self.ffn_layers[i](x)
            x = self.layer_norms2[i](x + ffn_output)
        
        x = torch_scatter.scatter_mean(x, batch, dim=0)
        
        magnetic_pred = self.magnetic_head(x)
        topological_pred = self.topological_head(x)
        
        return magnetic_pred, torch.sigmoid(topological_pred)


class GraphMultiHeadAttention(MessagePassing):
    def __init__(self, hidden_dim, num_heads, edge_attr_dim):
        super().__init__(aggr='add')
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.edge_proj = nn.Linear(edge_attr_dim, num_heads)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)
        
    # def forward(self, x, edge_index, edge_attr):
    #     # Project inputs
    #     q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
    #     k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
    #     v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)
        
    #     # Process edge attributes
    #     edge_weights = self.edge_proj(edge_attr).unsqueeze(-1)  # [E, num_heads, 1]
        
    #     # Propagate through edges
    #     out = self.propagate(edge_index, q=q, k=k, v=v, edge_weights=edge_weights)
        
    #     # Project output
    #     return self.output_proj(out.view(-1, self.hidden_dim))
    def forward(self, x, edge_index, edge_attr):
        # If edge_index is None, create a self-loop edge index for each node.
        if edge_index is None:
            N = x.size(0)
            # Create self-loops: each node connected to itself.
            edge_index = torch.stack([torch.arange(N, device=x.device),
                                    torch.arange(N, device=x.device)], dim=0)
        
        # If edge_attr is None, create a default tensor with zeros.
        if edge_attr is None:
            E = edge_index.size(1)  # number of edges
            edge_attr = torch.zeros(E, self.edge_proj.in_features, device=x.device)
        
        # Project inputs
        q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)
        
        # Process edge attributes
        edge_weights = self.edge_proj(edge_attr).unsqueeze(-1)  # [E, num_heads, 1]
        
        # Propagate through edges
        out = self.propagate(edge_index, q=q, k=k, v=v, edge_weights=edge_weights)
        
        # Project output
        return self.output_proj(out.view(-1, self.hidden_dim))


    
    def message(self, q_i, k_j, v_j, edge_weights):
        attention = (q_i * k_j).sum(dim=-1) / math.sqrt(self.head_dim)  # [E, num_heads]
        attention = attention.unsqueeze(-1) * edge_weights             # Apply edge weighting
        attention = F.softmax(attention, dim=0)                        # Normalize over neighbors
        return attention * v_j                                         # [E, num_heads, head_dim]

  

In [20]:
class BCS37Classifier(nn.Module):
    """
    Specialized classifier for BCS 3.7 magnetic topological materials
    """
    def __init__(self, input_dim):
        super().__init__()
        self.feature_extraction = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        
        # Symmetry analysis module
        self.symmetry_module = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16)
        )
        
        # Final classification heads
        self.magnetic_head = nn.Linear(16, 3)  # NM, AFM, FM/FiM
        self.bcs_compatibility_head = nn.Linear(16, 1)  # Sigmoid will be applied
        
    def forward(self, x, structure_features):
        # Extract features
        features = self.feature_extraction(x)
        
        # Add structure-level features like symmetry indicators
        combined_features = features + structure_features
        
        # Analyze symmetry
        symmetry_features = self.symmetry_module(combined_features)
        
        # Get predictions
        magnetic_pred = self.magnetic_head(symmetry_features)
        bcs_compatibility = torch.sigmoid(self.bcs_compatibility_head(symmetry_features))
        
        return magnetic_pred, bcs_compatibility

In [21]:
import torch.nn.functional as F

#TODO: change this to somenthing more suitable lol

def compute_loss(magnetic_pred, topological_pred, batch):
    """
    Compute the loss for both magnetic and topological predictions.

    Arguments:
    - magnetic_pred: Predictions for the magnetic ordering (output from the magnetic head)
    - topological_pred: Predictions for the topological class (output from the topological head)
    - batch: Batch of data containing the true labels for magnetic and topological classes

    Returns:
    - loss: Total loss (sum of magnetic and topological losses)
    """
    
    # Magnetic prediction loss (Cross-Entropy Loss)
    magnetic_true = batch.magnetic_y
    magnetic_loss = F.cross_entropy(magnetic_pred, magnetic_true)
    
    # Topological prediction loss (Cross-Entropy Loss)
    topological_true = batch.topological_y
    topological_loss = F.cross_entropy(topological_pred, topological_true)
    
    # Total loss is the sum of both losses
    total_loss = magnetic_loss + topological_loss
    
    return total_loss


In [22]:
# from mendeleev import element

# #TODO: think of this logic more

# def get_en_pauling(symbol):
#     try:
#         elem = element(symbol)
#         return elem.electronegativityl
#     except KeyError:
#         return None

# symbol = 'Fe'  # Example element
# en_pauling = get_en_pauling(symbol)
# print(en_pauling)  # Should print the electronegativity value or None if not found


from mendeleev import element
# Access the element you want (e.g., Oxygen)
# element = element('O')

# # Get Paulling electronegativity
# pauling_electronegativity = element.electronegativity('pauling')
# print(f"Pauling electronegativity of Oxygen: {pauling_electronegativity}")

# # Get Mulliken electronegativity
# mulliken_electronegativity = element.electronegativity('mulliken')
# print(f"Mulliken electronegativity of Oxygen: {mulliken_electronegativity}")

def get_en_pauling(symbol):
    elem = element(str(symbol))
    return elem.electronegativity('pauling')

print(get_en_pauling('O'))

3.44


In [23]:
import torch
import torch_geometric
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt


# Define the device (use CUDA if available, otherwise use CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define atom_types_dim based on your feature encoding
# For example, if you use atomic radius, electronegativity, and dipole moment as features for each atom type
atom_types_dim = 300  # For this case, you would need 3 features per atom type

# Initialize the model
model = MagneticTopologicalTransformer(
    input_dim=atom_types_dim, 
    hidden_dim=128,
    num_heads=4, 
    edge_attr_dim= 3
).to(device)  # Make sure to move the model to the correct device

# Preprocess data with TQC insights
enhanced_data = preprocess_structures_with_tqc(structures_list_mp, bcs_id="3.7")

# Split data
indices = np.arange(len(enhanced_data))
np.random.shuffle(indices)
index_tr, index_va, index_te = np.split(indices, [int(.8 * len(indices)), int(.9 * len(indices))])

# Create dataloaders
batch_size = 4 # Increased batch size for transformer
print(f"Length of enhanced_data: {len(enhanced_data)}")

from torch_geometric.loader import DataLoader
print([type(g) for g in enhanced_data])

# dataloader = DataLoader(
#     [enhanced_data[i] for i in index_tr], 
#     batch_size=batch_size, 
#     shuffle=True
# )
# dataloader_valid = DataLoader(
#     [enhanced_data[i] for i in index_va], 
#     batch_size=batch_size
# )
# Create DataLoader instances using the batched data
dataloader = DataLoader(enhanced_data, batch_size=4, shuffle=True)
dataloader_valid = DataLoader(enhanced_data, batch_size=4)

# Initialize optimizer with learning rate warmup
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)

# Learning rate scheduler (Cosine annealing with warmup)
scheduler = CosineAnnealingWarmRestarts(
    optimizer, 
    T_0=100,  # First restart period
    T_mult=2,  # Multiply restart period after each cycle
    eta_min=0,  # Minimum learning rate
    last_epoch=-1  # Start from epoch 0
)

# Training function placeholder (assuming you have defined it previously)
from torch_geometric.data import Batch

def validate_model(model, dataload, device): 
    model.eval()
    with torch.no_grad():
        for batch in dataload:
            
            batch = batch.to(device)  # Ensure validation batch is on the correct device
            magnetic_pred, topological_pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            # Compute validation loss or any evaluation metric here
            val_loss = compute_loss(magnetic_pred, topological_pred, batch)

    return val_loss

def train_mag_topo_model(model, optimizer, scheduler, dataloader, dataloader_valid, max_epochs, device):
    model.train()
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            optimizer.zero_grad()
            
            # Debug information
            print("Batch x.shape:", batch.x.shape)
            print("Batch edge_index.shape:", batch.edge_index.shape)
            print("Batch batch:", batch.batch)
            
            # Get the edge index from the batch
            edge_index = batch.edge_index
            
            # IMPORTANT: When working with batched graphs in PyTorch Geometric,
            # the batch.edge_index is already corrected for the batch, so no need
            # to manually adjust the indices.
            
            # Forward pass
            out = model(batch.x, edge_index, batch.edge_attr, batch.batch)
            
            # Calculate loss
            loss = F.mse_loss(out, batch.y)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            print(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.6f}")
        
        # Step the scheduler
        scheduler.step()
        
        # Validation
        val_loss = validate_model(model, dataloader_valid, device)
        print(f"Epoch: {epoch}, Train Loss: {epoch_loss / len(dataloader):.6f}, Val Loss: {val_loss:.6f}")
        
# def train_mag_topo_model(model, optimizer, scheduler, dataloader, dataloader_valid, max_epochs, device):
#     model.train()
#     for epoch in range(max_epochs):
#         for batch in dataloader:
            
#             nx_graph = to_networkx(batch, to_undirected=True)
#             plt.figure(figsize=(8, 6))
#             nx.draw(nx_graph, with_labels=True, node_size=300, font_size=10)
#             plt.title("Batch Graph Visualization")
#             plt.show()

#             batch = batch.to(device)  # Ensure the batch is on the correct device
#             optimizer.zero_grad()
#             magnetic_pred, topological_pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
#             # Define loss functions and backpropagate here
#             loss = compute_loss(magnetic_pred, topological_pred, batch)
#             loss.backward()
#             optimizer.step()
#         scheduler.step()

#         # Validation
#         model.eval()
#         with torch.no_grad():
#             for batch in dataloader_valid:
             
               
#                 batch = batch.to(device)  # Ensure validation batch is on the correct device
#                 magnetic_pred, topological_pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
#                 # Compute validation loss or any evaluation metric here
#                 val_loss = compute_loss(magnetic_pred, topological_pred, batch)
        
#         print(f"Epoch {epoch+1}/{max_epochs}, Training Loss: {loss.item()}, Validation Loss: {val_loss.item()}")


for i, data in enumerate(enhanced_data):
    assert isinstance(data, DataPeriodicNeighbors)
    # print(f"Graph {i}:")
    # print(f"  num_nodes: {data.num_nodes}")
    # print(f"  edge_index max: {data.edge_index.max().item()}")
    # print(f"  edge_index shape: {data.edge_index.shape}")

for i, g in enumerate(enhanced_data):
    print(f"Graph {i}: x.shape = {g.x.shape}, num_nodes = {g.num_nodes}")

for batch in dataloader:
    print("Batched x shape:", batch.x.shape)
    print("Batched edge_index max:", batch.edge_index.max())
    print("Batched num_nodes:", batch.num_nodes)
    break

from torch_geometric.data import Batch
batched_data = Batch.from_data_list(enhanced_data)
print("Batched x shape:", batched_data.x.shape)
print("Batched edge_index shape:", batched_data.edge_index.shape)
print("Batched edge_index max:", batched_data.edge_index.max().item(), 
      "out of", batched_data.x.size(0), "nodes")

# Train the model
train_mag_topo_model(
    model, 
    optimizer, 
    scheduler,
    dataloader, 
    dataloader_valid, 
    max_epochs=100, 
    device=device
)



Processing structure 1/9
Error processing structure 0: 'FiM'
Processing structure 7/9