# In [1]:
"""
Integrated Magnetic Topological Classifier with TQC Data
=========================================================

This script shows how to integrate new preprocessing routines that merge
Materials Project data with Topological Quantum Chemistry (TQC) insights,
and then use the preprocessed dataset in the transformer‐based multi‐task 
magnetic and topological classifier.
"""

import os
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_scatter import scatter_add, scatter_mean
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv

# Materials science libraries
import pymatgen as pmg
from pymatgen.core.structure import Structure
from pymatgen.analysis.magnetism.analyzer import CollinearMagneticStructureAnalyzer
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from mp_api.client import MPRester

# Load environment variables (ensure your .env has your MP_API_KEY).
# presumably load_dotenv() reads a local .env file into the process
# environment — confirm against the python-dotenv documentation.
load_dotenv()
# os.getenv returns None when MP_API_KEY is not set.
# NOTE(review): nothing in the visible part of this file reads api_key —
# confirm whether it is still needed.
api_key = os.getenv("MP_API_KEY")

# ---- Label encodings -------------------------------------------------------
# Magnetic ordering name -> class index ("FM" and "FiM" share class 2).
order_encode = {
    "NM": 0,
    "AFM": 1,
    "FM": 2,
    "FiM": 2,
}
# Topological label -> class index (either not a TI, or a TI).
topo_encode = {
    "None": 0,
    "TI": 1,
}

# ---- Global training parameters (adjust as needed) -------------------------
PARAMS = {
    # Distance cutoff for constructing edges
    "max_radius": 10.0,
    # (For some normalization, if needed)
    "n_norm": 35,
    # Hidden layer dimensions for transformer
    "hidden_dim": 128,
    "num_heads": 4,
    "batch_size": 4,
    "lr": 0.0001,
    "weight_decay": 0.01,
    "max_epochs": 100,
}

###############################################################################
# DATA PREPROCESSING: Integration with TQC
###############################################################################

# In your integration, you want to enrich the feature space using both MP and TQC.
# Below is the custom Data class (derived from PyG’s Data) that we use:

class DataPeriodicNeighbors(Data):
    """
    ``Data`` subclass used to hold the graph built for one periodic structure.

    The only customisation is ``__inc__``: for the keys ``'edge_index'`` and
    ``'cell_index'`` it reports this graph's node count (``self.x.size(0)``);
    every other key is delegated to the parent implementation.
    """
    def __inc__(self, key, value, *args, **kwargs):
        # Anything that is not one of the two node-indexed keys keeps the
        # parent's behaviour.
        if key not in ('edge_index', 'cell_index'):
            return super().__inc__(key, value, *args, **kwargs)
        return self.x.size(0)

# --- TQC Data–Informed Preprocessing Functions ---

from pymatgen.core import Element
from pymatgen.analysis.magnetism.analyzer import CollinearMagneticStructureAnalyzer

def get_en_pauling(symbol):
    """Return the Pauling electronegativity of an element, or 0.0 if unavailable.

    Args:
        symbol: Element symbol accepted by pymatgen's ``Element`` (e.g. "Fe").

    Returns:
        The electronegativity as a float. 0.0 is returned when ``symbol`` is
        not a valid element or when pymatgen reports no value (None / NaN)
        for it.
    """
    try:
        # pymatgen exposes the Pauling electronegativity as the ``X``
        # property; ``Element`` has no mendeleev-style
        # ``electronegativity('pauling')`` method.
        value = Element(symbol).X
    except (ValueError, KeyError, AttributeError, TypeError):
        return 0.0
    # Elements without a tabulated value come back as None or NaN depending
    # on the pymatgen version; keep the feature finite.
    if value is None or math.isnan(value):
        return 0.0
    return float(value)

def extract_magnetic_features(structure):
    """
    Extract simple composition-based magnetic features from a structure.

    Args:
        structure: A sized iterable of sites, each exposing ``specie.symbol``
            (e.g. an ordered pymatgen ``Structure``).

    Returns:
        dict with a single key, ``'magnetic_fraction'``: the fraction of sites
        whose element is in the hard-coded list of magnetic elements below.
        An empty structure yields 0.0 instead of dividing by zero.
    """
    magnetic_elements = frozenset(
        ('Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Gd', 'Dy', 'Ho', 'Er', 'Tm', 'Yb')
    )
    total_sites = len(structure)
    if total_sites == 0:
        # No sites: define the fraction as 0 rather than raising.
        return {'magnetic_fraction': 0.0}
    magnetic_sites = sum(
        1 for site in structure if str(site.specie.symbol) in magnetic_elements
    )
    # (More sophisticated metrics could be added here.)
    return {'magnetic_fraction': magnetic_sites / total_sites}

def extract_symmetry_indicators(structure):
    """
    Extract symmetry indicators from a structure: the space-group number and
    whether the structure is centrosymmetric.

    Args:
        structure: pymatgen ``Structure`` passed to ``SpacegroupAnalyzer``.

    Returns:
        dict with keys ``'spacegroup'`` (int, 0 if symmetry detection failed)
        and ``'has_inversion'`` (1 or 0). Key order is relied on by callers
        that turn ``values()`` into a feature vector.
    """
    # Best-effort: symmetry detection can fail in several ways, so any error
    # maps to the "unknown" encoding (0, 0) instead of aborting.
    try:
        analyzer = SpacegroupAnalyzer(structure)
        sg = analyzer.get_space_group_number()
    except Exception:
        return {'spacegroup': 0, 'has_inversion': 0}

    # ``is_laue()`` reports whether the point group is centrosymmetric.
    # It is checked separately so a failure here does not discard a
    # successfully determined space-group number.
    try:
        has_inversion = 1 if analyzer.is_laue() else 0
    except Exception:
        has_inversion = 0
    return {'spacegroup': sg, 'has_inversion': has_inversion}

def check_bcs_compatibility(structure, bcs_id="3.7"):
    """
    Simplified stand-in for a BCS-classification compatibility check.

    The structure's space-group number is looked up in a fixed list of
    numbers assumed to satisfy the BCS 3.7 rules; a failed symmetry analysis
    counts as space group 0, which is not in the list. ``bcs_id`` is accepted
    but not consulted by this placeholder. You may wish to replace this with
    a call to the official TQC API.
    """
    # Example: space groups assumed to satisfy the BCS 3.7 rules.
    compatible_space_groups = (2, 10, 47, 83, 87, 199, 216, 227)
    try:
        space_group_number = SpacegroupAnalyzer(structure).get_space_group_number()
    except Exception:
        space_group_number = 0
    return space_group_number in compatible_space_groups

def predict_topological_class(structure, symmetry_indicators, is_bcs_compatible):
    """
    Rule-based placeholder that assigns a topological label.

    Returns "TI" only when the structure is BCS compatible *and*
    ``symmetry_indicators["has_inversion"]`` equals 1; otherwise "None".
    ``structure`` is not inspected. Replace with domain-specific logic.
    """
    # Guard clauses: bail out on the first failed requirement. The indicator
    # dict is only consulted once BCS compatibility has been established.
    if not is_bcs_compatible:
        return "None"
    if symmetry_indicators.get("has_inversion", 0) != 1:
        return "None"
    return "TI"

def _element_node_features(struct, len_element):
    """Build the per-site feature matrix of shape [num_sites, 3 * len_element].

    Each site fills three slots, one per property block (atomic radius,
    Pauling electronegativity, placeholder dipole polarizability), at the
    column given by its atomic number (clamped to ``len_element - 1``).
    """
    node_features = torch.zeros(len(struct), 3 * len_element)
    for j, site in enumerate(struct):
        # Use the bare element symbol: ``str(site.specie)`` becomes e.g.
        # "Fe2+" for oxidation-state-decorated sites, which Element() rejects.
        symbol = str(site.specie.symbol)
        element = Element(symbol)
        slot = min(element.Z, len_element - 1)
        atomic_radius = getattr(element, 'atomic_radius', None) or 0.0
        en_pauling = get_en_pauling(symbol) or 0.0
        # You can add a third property (e.g., dipole polarizability) if available.
        dipole_polarizability = 0.0
        node_features[j, slot] = float(atomic_radius)
        node_features[j, len_element + slot] = en_pauling
        node_features[j, 2 * len_element + slot] = dipole_polarizability
    return node_features


def _in_cell_edges(positions, r_max):
    """Return (edge_index, edge_attr) for all site pairs within ``r_max``.

    Only the Cartesian positions inside one cell are compared; periodic
    images are not considered. Edges are directed, exclude self-pairs and are
    ordered by source then destination. ``edge_attr`` is
    ``[distance / r_max, 0, 0]`` per edge. If no pair qualifies, one
    self-loop per site with zero attributes is returned instead.
    """
    num_sites = positions.size(0)
    distances = (positions.unsqueeze(1) - positions.unsqueeze(0)).norm(dim=-1)
    within = (distances <= r_max) & ~torch.eye(num_sites, dtype=torch.bool)
    src, dst = within.nonzero(as_tuple=True)  # row-major: src-major order
    if src.numel() == 0:
        # Fallback: self-loops if no edges are found.
        loops = torch.arange(num_sites)
        return (torch.stack([loops, loops], dim=0),
                torch.zeros((num_sites, 3), dtype=torch.float))
    edge_attr = torch.zeros((src.numel(), 3), dtype=torch.float)
    edge_attr[:, 0] = distances[src, dst] / r_max
    return torch.stack([src, dst], dim=0), edge_attr


def preprocess_structures_with_tqc(structures, bcs_id="3.7"):
    """
    Convert pymatgen structures into graph samples enriched with magnetic and
    symmetry features.

    For each structure this builds the node feature tensor, the edge index
    within ``PARAMS['max_radius']`` (same-cell distances only), the magnetic
    and symmetry feature vectors, a rule-based topological label and the
    magnetic-ordering label.

    Args:
        structures: Iterable of pymatgen ``Structure`` objects.
        bcs_id: Forwarded to ``check_bcs_compatibility``.

    Returns:
        A list of ``DataPeriodicNeighbors`` objects. Structures that raise
        during processing are reported on stdout and skipped, so the list can
        be shorter than the input.
    """
    processed_data = []
    # We assume a fixed number (e.g. 100) is an upper bound for atomic number indexing.
    len_element = 100
    for i, struct in enumerate(structures):
        try:
            if len(struct) == 0:
                raise ValueError("structure has no sites")
            r_max = PARAMS['max_radius']

            node_features = _element_node_features(struct, len_element)
            positions = torch.tensor(struct.cart_coords, dtype=torch.float)
            edge_index, edge_attr = _in_cell_edges(positions, r_max)

            # Additional global features from TQC side:
            magnetic_feats = extract_magnetic_features(struct)
            symmetry_indicators = extract_symmetry_indicators(struct)
            bcs_compatible = check_bcs_compatibility(struct, bcs_id=bcs_id)
            topo_class = predict_topological_class(struct, symmetry_indicators, bcs_compatible)

            # Get magnetic ordering via pymatgen analyzer.
            mag_analyzer = CollinearMagneticStructureAnalyzer(struct)
            magnetic_order = mag_analyzer.ordering.name  # e.g., "NM", "AFM", "FM", etc.

            # Create our custom data object. Orderings missing from
            # order_encode, and unknown topological labels, fall back to class 0.
            data_point = DataPeriodicNeighbors(
                x=node_features,  # Node features: shape [num_sites, 3*len_element]
                pos=positions,
                lattice=torch.tensor(struct.lattice.matrix, dtype=torch.float),
                edge_index=edge_index,
                edge_attr=edge_attr,
                r_max=r_max,
                magnetic_y=torch.tensor([order_encode.get(magnetic_order, 0)], dtype=torch.long),
                topological_y=torch.tensor([topo_encode.get(topo_class, 0)], dtype=torch.long),
                # Save extra features if needed:
                magnetic_features=torch.tensor(list(magnetic_feats.values()), dtype=torch.float),
                symmetry_features=torch.tensor(list(symmetry_indicators.values()), dtype=torch.float),
                bcs_compatible=torch.tensor([1 if bcs_compatible else 0], dtype=torch.float),
                n_norm=PARAMS['n_norm']
            )
            processed_data.append(data_point)
        except Exception as e:
            # Deliberate best-effort: one bad structure must not abort the run.
            print(f"Error processing structure {i}: {e}")
    print(f"Preprocessed {len(processed_data)} structures successfully.")
    return processed_data

###############################################################################
# MODEL ARCHITECTURE (Transformer-Based)
###############################################################################

class GraphMultiHeadAttention(nn.Module):
    """Multi-head attention restricted to the edges of a graph.

    For every edge ``(row, col)`` the query of node ``row`` attends to the
    key/value of node ``col``. Attention logits are scaled dot products
    multiplied by a per-head weight projected from the edge attributes, then
    normalised with a softmax over the edges that share the same query node,
    and the weighted values are summed into that node. Nodes without any edge
    aggregate to zero before the output projection.

    Args:
        hidden_dim: Node feature width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        edge_attr_dim: Width of the edge attribute vectors.

    Raises:
        ValueError: If ``hidden_dim`` is not a multiple of ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads, edge_attr_dim):
        super().__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.edge_proj = nn.Linear(edge_attr_dim, num_heads)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, edge_attr):
        """Apply one attention layer.

        Args:
            x: Node features, reshaped internally to [N, num_heads, head_dim].
            edge_index: Long tensor [2, E]; row 0 holds query nodes, row 1 the
                nodes they attend to.
            edge_attr: Edge attributes [E, edge_attr_dim], or None to use zeros.

        Returns:
            Tensor of shape [N, hidden_dim].
        """
        # If no edge_attr is supplied, create default zeros.
        if edge_attr is None:
            num_edges = edge_index.size(1)
            edge_attr = torch.zeros(num_edges, self.edge_proj.in_features,
                                    device=x.device, dtype=x.dtype)

        q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)

        edge_weights = self.edge_proj(edge_attr).unsqueeze(-1)  # shape [E, num_heads, 1]
        out = self.propagate(edge_index, q=q, k=k, v=v, edge_weights=edge_weights)
        return self.output_proj(out.view(-1, self.hidden_dim))

    @staticmethod
    def _segment_softmax(scores, index, num_segments):
        """Softmax of ``scores`` along dim 0, separately per value of ``index``."""
        tail = tuple(scores.shape[1:])
        gather_idx = index.view(-1, 1, 1).expand_as(scores)
        # Per-segment maximum for numerical stability; it is a constant shift,
        # so it is computed on detached scores.
        seg_max = scores.new_full((num_segments,) + tail, float("-inf"))
        seg_max = seg_max.scatter_reduce(0, gather_idx, scores.detach(),
                                         reduce="amax", include_self=True)
        exp_scores = torch.exp(scores - seg_max[index])
        denom = scores.new_zeros((num_segments,) + tail).index_add(0, index, exp_scores)
        return exp_scores / denom[index]

    def message(self, q_i, k_j, v_j, edge_weights, index=None, num_nodes=None):
        """Return the attention-weighted value for every edge.

        Args:
            q_i, k_j, v_j: Per-edge query/key/value, each [E, num_heads, head_dim].
            edge_weights: Per-edge, per-head multiplier [E, num_heads, 1].
            index: Query-node index of each edge. When given, the softmax is
                taken per query node; when None, a single softmax over all
                edges is used.
            num_nodes: Number of nodes; inferred from ``index`` if omitted.
        """
        # Compute attention scores
        scores = (q_i * k_j).sum(dim=-1) / math.sqrt(self.head_dim)  # [E, num_heads]
        scores = scores.unsqueeze(-1) * edge_weights                 # [E, num_heads, 1]
        if index is None:
            attn = F.softmax(scores, dim=0)
        else:
            if num_nodes is None:
                num_nodes = int(index.max()) + 1 if index.numel() else 0
            # Normalise per query node so that a graph's output does not
            # depend on which other graphs share its batch.
            attn = self._segment_softmax(scores, index, num_nodes)
        return attn * v_j

    def propagate(self, edge_index, **kwargs):
        """Compute per-edge messages and sum them into their query nodes.

        Expects ``q``, ``k``, ``v`` and ``edge_weights`` as keyword arguments.
        """
        row, col = edge_index
        q = kwargs['q']
        num_nodes = q.size(0)
        messages = self.message(q[row], kwargs['k'][col], kwargs['v'][col],
                                kwargs['edge_weights'], index=row, num_nodes=num_nodes)
        # The attention weights already sum to one per node, so aggregate
        # with a sum rather than a mean.
        out = messages.new_zeros((num_nodes,) + tuple(messages.shape[1:]))
        return out.index_add(0, row, messages)

class MagneticTopologicalTransformer(nn.Module):
    """Stack of graph-attention / feed-forward blocks with two graph-level heads.

    Node features are embedded to ``hidden_dim``, passed through
    ``num_layers`` blocks (attention + residual + LayerNorm, then a GELU
    feed-forward + residual + LayerNorm), mean-pooled per graph, and fed to a
    3-way magnetic head and a 2-way topological head.
    """

    def __init__(self, input_dim, hidden_dim, edge_attr_dim, num_heads=4, num_layers=3):
        super().__init__()
        ffn_width = hidden_dim * 4
        layer_ids = range(num_layers)

        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.attention_layers = nn.ModuleList(
            [GraphMultiHeadAttention(hidden_dim, num_heads, edge_attr_dim) for _ in layer_ids]
        )
        self.layer_norms1 = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in layer_ids])
        self.ffn_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_dim, ffn_width),
                    nn.GELU(),
                    nn.Linear(ffn_width, hidden_dim),
                )
                for _ in layer_ids
            ]
        )
        self.layer_norms2 = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in layer_ids])
        # Output heads: NM / AFM / FM-or-FiM, and Not-TI / TI.
        self.magnetic_head = nn.Linear(hidden_dim, 3)
        self.topological_head = nn.Linear(hidden_dim, 2)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return ``(magnetic_logits, topological_logits)``, one row per graph in ``batch``."""
        hidden = self.embedding(x)
        blocks = zip(self.attention_layers, self.layer_norms1,
                     self.ffn_layers, self.layer_norms2)
        for attention, norm_after_attn, ffn, norm_after_ffn in blocks:
            # Residual connection around each sub-layer, normalised afterwards.
            hidden = norm_after_attn(hidden + attention(hidden, edge_index, edge_attr))
            hidden = norm_after_ffn(hidden + ffn(hidden))
        # Global pooling: mean over the nodes of each graph.
        pooled = scatter_mean(hidden, batch, dim=0)
        magnetic_logits = self.magnetic_head(pooled)
        topological_logits = self.topological_head(pooled)
        return magnetic_logits, topological_logits

###############################################################################
# TRAINING LOOP (Transformer with Preprocessed Data)
###############################################################################

def compute_loss(magnetic_pred, topological_pred, batch):
    """Return the unweighted sum of the two cross-entropy task losses.

    ``batch`` must expose the integer targets ``magnetic_y`` and
    ``topological_y`` matching the respective logits.
    """
    magnetic_term = F.cross_entropy(magnetic_pred, batch.magnetic_y)
    topological_term = F.cross_entropy(topological_pred, batch.topological_y)
    total = magnetic_term + topological_term
    return total

def train_mag_topo_model(model, optimizer, scheduler, dataloader, dataloader_valid, max_epochs, device):
    """Train ``model`` for ``max_epochs`` epochs, validating after each one.

    Every epoch runs one optimisation pass over ``dataloader``, steps
    ``scheduler`` once, evaluates on ``dataloader_valid`` without gradients
    and prints the mean train and validation losses. A loader with zero
    batches is reported as ``nan`` instead of raising ZeroDivisionError.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr, batch)``.
        optimizer: Optimiser over the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        dataloader: Sized iterable of training batches.
        dataloader_valid: Sized iterable of validation batches.
        max_epochs: Number of epochs to run.
        device: Device each batch is moved to.

    Returns:
        None. The model is left in training mode.
    """
    num_train_batches = len(dataloader)
    num_valid_batches = len(dataloader_valid)

    model.train()
    for epoch in range(max_epochs):
        epoch_loss = 0.0
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()
            magnetic_pred, topological_pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            loss = compute_loss(magnetic_pred, topological_pred, batch)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in dataloader_valid:
                batch = batch.to(device)
                magnetic_pred, topological_pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                val_loss += compute_loss(magnetic_pred, topological_pred, batch).item()
        model.train()

        # Small datasets can leave a split empty; report nan rather than crash.
        train_mean = epoch_loss / num_train_batches if num_train_batches else float('nan')
        val_mean = val_loss / num_valid_batches if num_valid_batches else float('nan')
        print(f"Epoch {epoch+1}/{max_epochs} | Train Loss: {train_mean:.6f} | Val Loss: {val_mean:.6f}")

###############################################################################
# MAIN EXECUTION
###############################################################################

def main(mp_structures_file=None):
    """Run the full pipeline: load, preprocess, split, train, save and test.

    Args:
        mp_structures_file: Path to a ``torch.save``-d dict with a
            ``'structures'`` entry. Defaults to the original author's local
            path when omitted.

    Returns:
        None. Returns early (after printing a message) when the file is
        missing, when no structure survives preprocessing, or when the
        training split is empty.
    """
    if mp_structures_file is None:
        mp_structures_file = '/Users/abiralshakya/Documents/Research/Topological_Insulators_OnGithub/generative_nmti/Integrated_Magnetic_Topological/magnetic_order/preload_data/mp_structures_2025-04-07_12-52.pt'

    # 1. Load the saved structures.
    if not os.path.exists(mp_structures_file):
        print(f"File {mp_structures_file} not found!")
        return
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
    # only load files from a trusted source.
    mp_structures_dict = torch.load(mp_structures_file, weights_only=False)
    # Here we use the list of structures; you may change this depending on your saved format.
    structures = mp_structures_dict['structures']
    print(f"Loaded {len(structures)} structures.")

    # 2. Preprocess structures with TQC and extract additional features.
    enhanced_data = preprocess_structures_with_tqc(structures, bcs_id="3.7")
    if not enhanced_data:
        print("No structures could be preprocessed; nothing to train on.")
        return

    # 3. Split the dataset into train / validation / test sets (80/10/10).
    indices = np.arange(len(enhanced_data))
    np.random.shuffle(indices)
    train_end = int(0.8 * len(indices))
    val_end = int(0.9 * len(indices))
    train_data = [enhanced_data[i] for i in indices[:train_end]]
    val_data = [enhanced_data[i] for i in indices[train_end:val_end]]
    test_data = [enhanced_data[i] for i in indices[val_end:]]
    if not train_data:
        # A shuffling DataLoader cannot be built over an empty dataset.
        print("Training split is empty; need more preprocessed structures.")
        return

    train_loader = DataLoader(train_data, batch_size=PARAMS['batch_size'], shuffle=True)
    val_loader = DataLoader(val_data, batch_size=PARAMS['batch_size'], shuffle=False)
    test_loader = DataLoader(test_data, batch_size=PARAMS['batch_size'], shuffle=False)

    # 4. Initialize the MagneticTopologicalTransformer model.
    # Here, input_dim is the dimension of our node feature tensor.
    input_dim = enhanced_data[0].x.size(1)  # e.g., 3*len_element (here 300 if len_element==100)
    model = MagneticTopologicalTransformer(input_dim=input_dim, hidden_dim=PARAMS['hidden_dim'],
                                           edge_attr_dim=3, num_heads=PARAMS['num_heads'], num_layers=3)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # 5. Set up optimizer and learning rate scheduler.
    optimizer = torch.optim.AdamW(model.parameters(), lr=PARAMS['lr'], weight_decay=PARAMS['weight_decay'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)

    # 6. Train the model.
    train_mag_topo_model(model, optimizer, scheduler, train_loader, val_loader,
                         max_epochs=PARAMS['max_epochs'], device=device)

    # Save the weights from the final epoch (no best-epoch selection is done,
    # despite the file name).
    torch.save(model.state_dict(), "best_mag_topo_transformer.pt")

    # 7. Test the model.
    model.eval()
    all_magnetic_preds = []
    all_topological_preds = []
    all_magnetic_labels = []
    all_topological_labels = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            magnetic_pred, topological_pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            all_magnetic_preds.extend(magnetic_pred.argmax(dim=1).cpu().numpy())
            all_topological_preds.extend(topological_pred.argmax(dim=1).cpu().numpy())
            all_magnetic_labels.extend(batch.magnetic_y.cpu().numpy())
            all_topological_labels.extend(batch.topological_y.cpu().numpy())

    if not all_magnetic_labels:
        print("Test split is empty; skipping test metrics.")
        return

    mag_acc = accuracy_score(all_magnetic_labels, all_magnetic_preds)
    topo_acc = accuracy_score(all_topological_labels, all_topological_preds)
    print(f"Test Magnetic Accuracy: {mag_acc:.4f}, Test Topological Accuracy: {topo_acc:.4f}")

if __name__ == '__main__':
    main()


# --- Notebook cell output (captured when the cell above was run) ---
#   from .autonotebook import tqdm as notebook_tqdm
#
# File ./mp_structures_2025-04-07_12-52.pt not found!
