In [2]:
# ===============================================
# ALL IMPORTS - Phonon Dataset Creation Pipeline
# ===============================================

# Basic Libraries
import yaml
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display

# Scientific Computing
from pymatgen.core import Structure, Lattice, Element
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.analysis.local_env import VoronoiNN
import MDAnalysis as mda
from MDAnalysis.analysis.rdf import InterRDF
from scipy.signal import find_peaks

# Python Standard Library
import itertools
from itertools import combinations_with_replacement

# Reaching this statement means every import above resolved; report it.
print(
    "📚 All libraries imported successfully!",
    "✅ Ready for phonon dataset creation pipeline",
    sep="\n",
)

  from .autonotebook import tqdm as notebook_tqdm


📚 All libraries imported successfully!
✅ Ready for phonon dataset creation pipeline


In [3]:
# ===============================================
# Extract Phonon Band Frequencies (Y data)
# ===============================================

def extract_phonon_frequencies(file_path):
    """
    Flatten every band frequency stored in one band YAML file.

    The file is parsed with ``yaml.safe_load``. For each entry under the
    top-level ``'phonon'`` key, the ``'frequency'`` value of every item in that
    entry's ``'band'`` list is collected, preserving file order.

    Args:
        file_path: Path to the band YAML file

    Returns:
        Flat list of all phonon frequencies for this structure
    """
    with open(file_path, 'r') as handle:
        band_data = yaml.safe_load(handle)

    # Outer loop: entries of 'phonon'; inner loop: the bands of that entry.
    return [
        mode['frequency']
        for phonon_entry in band_data['phonon']
        for mode in phonon_entry['band']
    ]

# Extract frequencies from all band files
print("📊 Extracting phonon frequencies from band files...")
Y_data = []
# File number behind each row of Y_data. Files that fail are skipped, so row k
# is NOT guaranteed to be file k+1; keep this list to align Y with other data.
Y_file_indices = []

for file_index in range(1, 359):  # Files 1 to 358
    file_path = f'band/band ({file_index}).yaml'

    try:
        frequencies = extract_phonon_frequencies(file_path)
    except FileNotFoundError:
        print(f"⚠️  Warning: File {file_path} not found, skipping...")
        continue
    except Exception as e:
        print(f"❌ Error processing file {file_path}: {e}")
        continue

    Y_data.append(frequencies)
    Y_file_indices.append(file_index)

    if file_index % 50 == 0:  # Progress indicator
        print(f"   Processed {file_index}/358 files...")

# Convert to numpy array once; the summary below reuses it instead of
# rebuilding the array just to read its shape.
Y = np.array(Y_data)

print(f"✅ Successfully extracted phonon data from {len(Y_data)} structures")
print(f"📈 Y data shape: {Y.shape}")
print(f"   Each structure has {len(Y_data[0]) if Y_data else 0} phonon frequencies")

# Make skipped files visible: they shift every later row of Y.
skipped_files = sorted(set(range(1, 359)) - set(Y_file_indices))
if skipped_files:
    print(f"⚠️  Warning: {len(skipped_files)} file(s) skipped: {skipped_files}")
    print("   Row k of Y corresponds to file Y_file_indices[k], not file k+1")

📊 Extracting phonon frequencies from band files...
   Processed 50/358 files...
   Processed 50/358 files...
   Processed 100/358 files...
   Processed 100/358 files...
   Processed 150/358 files...
   Processed 150/358 files...
   Processed 200/358 files...
   Processed 200/358 files...
   Processed 250/358 files...
   Processed 250/358 files...
   Processed 300/358 files...
   Processed 300/358 files...
   Processed 350/358 files...
   Processed 350/358 files...
✅ Successfully extracted phonon data from 358 structures
📈 Y data shape: (358, 8568)
   Each structure has 8568 phonon frequencies
✅ Successfully extracted phonon data from 358 structures
📈 Y data shape: (358, 8568)
   Each structure has 8568 phonon frequencies


In [4]:
# Write the phonon-frequency array Y to Y.npy, then report what was stored
print("💾 Saving Y data (phonon frequencies)...")
np.save("Y.npy", Y)

# In-memory size of the array in MiB (the printout labels it as approximate)
y_size_mb = Y.nbytes / (1024**2)
print(
    "✅ Y data saved successfully!",
    "   File: Y.npy",
    f"   Shape: {Y.shape}",
    f"   Data type: {Y.dtype}",
    f"   File size: ~{y_size_mb:.2f} MB",
    sep="\n",
)

💾 Saving Y data (phonon frequencies)...
✅ Y data saved successfully!
   File: Y.npy
   Shape: (358, 8568)
   Data type: float64
   File size: ~23.40 MB


In [5]:
# ===============================================
# Feature Extraction Functions for Crystal Structures
# ===============================================

# ===============================================
# Geometric Features
# ===============================================

def get_angle(v1, v2):
    """Return the angle between vectors v1 and v2, in degrees."""
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    cosine = np.dot(v1, v2) / norm_product
    # Floating-point rounding can push the cosine slightly outside [-1, 1],
    # where arccos is undefined; clamp it first.
    cosine = np.clip(cosine, -1, 1)
    return np.degrees(np.arccos(cosine))

def extract_lattice_angles(lattice):
    """
    Return [alpha, beta, gamma] in degrees for the cell vectors in ``lattice``.

    With a = lattice[0], b = lattice[1], c = lattice[2]:
    alpha is the angle between b and c, beta between a and c,
    gamma between a and b.
    """
    # Index pairs of the two vectors that enclose alpha, beta, gamma.
    enclosing_pairs = ((1, 2), (0, 2), (0, 1))
    return [get_angle(lattice[i], lattice[j]) for i, j in enclosing_pairs]

def calculate_volume(lattice):
    """Return the unit cell volume: the absolute determinant of the lattice matrix."""
    # The determinant's sign depends on the handedness of the vectors; drop it.
    signed_volume = np.linalg.det(lattice)
    return np.abs(signed_volume)

# ===============================================
# Structure Parsing
# ===============================================

def parse_atomic_positions(data, atomic_numbers):
    """
    Pull per-atom symbols, coordinates and atomic numbers out of parsed YAML data.

    Args:
        data: Parsed YAML mapping; each entry of data['points'] must provide
            'symbol' and 'coordinates'.
        atomic_numbers: Mapping from element symbol to the number stored for it
            (a KeyError is raised for a symbol it does not contain).

    Returns:
        atomic_masses: atomic_numbers[symbol] for every atom. Despite the
            name, these are whatever values the mapping holds (callers in this
            notebook pass atomic numbers, not masses).
        positions: The 'coordinates' entry of every atom, in file order
        symbols: The element symbol of every atom, in file order
    """
    site_records = np.array(data['points'])

    symbols = [record['symbol'] for record in site_records]
    positions = [record['coordinates'] for record in site_records]
    atomic_masses = [atomic_numbers[symbol] for symbol in symbols]

    return atomic_masses, positions, symbols

def calculate_density(volume, atomic_masses):
    """
    Return ``volume`` divided by the sum of ``atomic_masses``.

    Note this is volume per unit of summed "mass" (the inverse of a
    conventional density), computed on whatever per-atom values are passed in.
    """
    total_mass = sum(atomic_masses)
    return volume / total_mass

# ===============================================
# Symmetry Analysis
# ===============================================

def get_spacegroup_number(lattice, symbols, positions):
    """
    Return the space group number reported by pymatgen's SpacegroupAnalyzer.

    A Structure is built from the lattice, element symbols and positions, then
    analysed with a symmetry tolerance of symprec=0.1.
    """
    crystal = Structure(lattice, symbols, positions)
    return SpacegroupAnalyzer(crystal, symprec=0.1).get_space_group_number()

# ===============================================
# MAX Phase Analysis
# ===============================================

def identify_max_elements(element_symbols):
    """
    Automatically identify M, A, X elements in MAX phase structures

    Rules applied, in order:
      1. The structure must contain exactly three distinct element symbols.
      2. X is 'C' if present, otherwise 'N'; if neither is present the
         structure is rejected.
      3. Of the two remaining elements, the one that is a transition metal is
         M and the other is A.
      4. If both or neither are transition metals, the element with the lower
         group number is M. On equal group numbers the alphabetically later
         symbol is M.

    The distinct symbols are sorted before use, so rule 4 always yields the
    same answer for the same composition. (Iterating a bare ``set`` of strings
    can change order between interpreter runs.)

    Args:
        element_symbols: Iterable of element symbols, one per atom

    Returns:
        M: Transition metal element
        A: A-group element
        X: Carbon or Nitrogen
        All three are None when the structure is rejected by rule 1 or 2.
    """
    unique_elements = sorted(set(element_symbols))

    # Must have exactly 3 unique elements for MAX phase
    if len(unique_elements) != 3:
        return None, None, None

    # Identify X element (C or N); carbon wins if both are present
    if 'C' in unique_elements:
        X = 'C'
    elif 'N' in unique_elements:
        X = 'N'
    else:
        return None, None, None

    # Identify M and A from the two remaining elements
    el1, el2 = (Element(symbol) for symbol in unique_elements if symbol != X)

    # M is typically the transition metal
    if el1.is_transition_metal and not el2.is_transition_metal:
        M, A = el1.symbol, el2.symbol
    elif el2.is_transition_metal and not el1.is_transition_metal:
        M, A = el2.symbol, el1.symbol
    elif el1.group < el2.group:
        # Fallback: the element further left in the periodic table is M
        M, A = el1.symbol, el2.symbol
    else:
        M, A = el2.symbol, el1.symbol

    return M, A, X

# ===============================================
# Radial Distribution Function (RDF)
# ===============================================

def calculate_rdf_features(lattice, positions, symbols,
                           r_max=10.0, nbins=100001, peak_height=1.0):
    """
    Calculate RDF peaks for all element pairs

    A partial radial distribution function is computed with MDAnalysis
    ``InterRDF`` for every unordered pair of element symbols (same-element
    pairs included). Peaks are located with ``scipy.signal.find_peaks`` and
    the peak distances of all pairs are merged into one list.

    Args:
        lattice: Cell vectors as the rows of an array (indexed as lattice[i])
        positions: Fractional coordinates, one row per atom
        symbols: Element symbol of every atom, same order as positions
        r_max: Upper end of the RDF range; the lower end is 0.0 (default 10.0)
        nbins: Number of RDF histogram bins (default 100001)
        peak_height: Minimum RDF value for a bin to count as a peak (default 1.0)

    Returns:
        Sorted list of distinct peak distances rounded to 3 decimals. Its
        length depends on the structure, so callers that need fixed-width
        feature vectors must check it.
    """
    # Convert to cartesian coordinates
    coords_cart = np.dot(positions, lattice)

    # Box as [a, b, c, alpha, beta, gamma]. The angles come from the shared
    # helper, which clamps the cosine before arccos (no NaN from rounding).
    a, b, c = np.linalg.norm(lattice, axis=1)
    alpha, beta, gamma = extract_lattice_angles(lattice)
    box_dims = [a, b, c, alpha, beta, gamma]

    # Create MDAnalysis universe; atom names carry the element symbols so that
    # "name <symbol>" selections below pick out one element at a time
    universe = mda.Universe.empty(n_atoms=len(symbols), trajectory=True)
    universe.add_TopologyAttr('name', symbols)
    universe.atoms.positions = coords_cart
    universe.dimensions = box_dims

    # Calculate partial RDFs for all element pairs
    unique_elements = sorted(set(symbols))
    element_pairs = combinations_with_replacement(unique_elements, 2)

    all_peaks = []

    for elem1, elem2 in element_pairs:
        group1 = universe.select_atoms(f"name {elem1}")
        group2 = universe.select_atoms(f"name {elem2}")

        rdf = InterRDF(group1, group2, nbins=nbins, range=(0.0, r_max))
        rdf.run()

        bins, rdf_values = rdf.results.bins, rdf.results.rdf
        if elem1 == elem2:
            # Same-element pair: drop the first bin (presumably so each atom's
            # zero distance to itself cannot register as a peak)
            bins, rdf_values = bins[1:], rdf_values[1:]

        peak_indices, _ = find_peaks(rdf_values, height=peak_height)
        all_peaks.extend(bins[peak_indices])

    # Return unique sorted peaks
    return sorted(set(np.round(all_peaks, 3)))

# ===============================================
# Coordination Numbers
# ===============================================

def calculate_coordination_numbers(lattice, positions, symbols, M, A, X):
    """
    Average the Voronoi coordination number of the M, A and X atoms.

    Every site's coordination number is obtained from pymatgen's VoronoiNN
    (allow_pathological=True, tol=0.5) and grouped by element symbol.

    Returns:
        [mean CN of M atoms, mean CN of A atoms, mean CN of X atoms];
        an element with no atoms contributes 0.
    """
    structure = Structure(lattice, symbols, positions)
    neighbor_finder = VoronoiNN(allow_pathological=True, tol=0.5)

    # (symbol, coordination number) of every site, in site order
    site_cns = [
        (structure[idx].specie.symbol, neighbor_finder.get_cn(structure, idx))
        for idx in range(len(structure))
    ]

    # A site counts for the first of M, A, X that its symbol matches
    cn_M = [cn for sym, cn in site_cns if sym == M]
    cn_A = [cn for sym, cn in site_cns if sym != M and sym == A]
    cn_X = [cn for sym, cn in site_cns if sym != M and sym != A and sym == X]

    return [np.mean(values) if values else 0 for values in (cn_M, cn_A, cn_X)]

# ===============================================
# Bond Angle Analysis
# ===============================================

def calculate_angle_statistics(angle_list):
    """Return (mean, standard deviation) of angle_list, or (0, 0) when it is empty."""
    if not angle_list:
        return 0, 0
    return np.mean(angle_list), np.std(angle_list)

def calculate_bond_angles(lattice, positions, symbols, M, A, X):
    """
    Calculate bond angle statistics for different atom combinations

    For every M atom, the angles X-M-X and A-M-X are collected over all pairs
    of its Voronoi neighbours; for every A atom, the angles M-A-M.

    Angles are measured between the vectors from the central atom to the
    neighbour sites returned by ``VoronoiNN.get_nn_info``. Those sites carry
    the neighbour's own coordinates, so a neighbour that is a periodic image
    is measured where it actually is. Two images of the same atom count as
    two distinct neighbours.

    Args:
        lattice, positions, symbols: Passed to pymatgen ``Structure``
        M, A, X: Element symbols of the M, A and X species

    Returns:
        [(mean, std) of X-M-X, (mean, std) of A-M-X, (mean, std) of M-A-M],
        each in degrees; (0, 0) where no such angle was found.
    """
    structure = Structure(lattice, symbols, positions)
    neighbor_finder = VoronoiNN(allow_pathological=True, tol=0.5)

    angles_XMX = []  # X-M-X angles
    angles_AMX = []  # A-M-X angles
    angles_MAM = []  # M-A-M angles

    for i in range(len(structure)):
        center_atom = structure[i].specie.symbol
        if center_atom != M and center_atom != A:
            # Only M- and A-centred angles are collected
            continue

        neighbors_info = neighbor_finder.get_nn_info(structure, i)
        if len(neighbors_info) < 2:
            continue

        # (symbol, vector from the centre) per neighbour. One entry per
        # neighbour, not per site index, so periodic images stay distinct.
        center_coords = structure[i].coords
        neighbors = [
            (n['site'].specie.symbol, n['site'].coords - center_coords)
            for n in neighbors_info
        ]

        for (sym1, vec1), (sym2, vec2) in itertools.combinations(neighbors, 2):
            if center_atom == M:
                if sym1 == X and sym2 == X:
                    angles_XMX.append(get_angle(vec1, vec2))
                elif (sym1 == A and sym2 == X) or (sym1 == X and sym2 == A):
                    angles_AMX.append(get_angle(vec1, vec2))
            elif sym1 == M and sym2 == M:
                # A-centred atom with two M neighbours
                angles_MAM.append(get_angle(vec1, vec2))

    return [
        calculate_angle_statistics(angles_XMX),
        calculate_angle_statistics(angles_AMX),
        calculate_angle_statistics(angles_MAM)
    ]

# Summarise which groups of feature-extraction helpers this cell defined
print(
    "🔧 Feature extraction functions loaded successfully!",
    "   📐 Geometric features: lattice angles, volume, density",
    "   🔬 Structure analysis: spacegroup, MAX phase identification",
    "   📊 Advanced features: RDF, coordination numbers, bond angles",
    sep="\n",
)

🔧 Feature extraction functions loaded successfully!
   📐 Geometric features: lattice angles, volume, density
   🔬 Structure analysis: spacegroup, MAX phase identification
   📊 Advanced features: RDF, coordination numbers, bond angles


In [6]:
# ===============================================
# Load Elemental Properties Database
# ===============================================

print("📊 Loading elemental properties database...")

# Read the periodic-table CSV; missing entries are replaced by 0
df = pd.read_csv('ptable.csv').fillna(0)

print(f"✅ Loaded data for {len(df)} elements")
print(f"📋 Available properties: {len(df.columns)} total")

# Columns removed before feature extraction
columns_to_drop = [
    'electronic_configuration',
    'name',
    'block',
    'lattice_structure',
    'is_radioactive'
]

print(f"🧹 Removing {len(columns_to_drop)} unnecessary columns...")
for col in columns_to_drop:
    # Tolerate a table that lacks one of these columns
    if col not in df.columns:
        continue
    df = df.drop(columns=col)
    print(f"   ❌ Dropped: {col}")

n_elements, n_properties = len(df), len(df.columns)
print(f"✅ Final dataset: {n_elements} elements × {n_properties} properties")
print(f"📋 Remaining properties: {list(df.columns)[1:]}")  # Skip 'symbol' column

# Preview the first row (everything after the leading symbol column)
print(f"\n🔍 Sample properties for first element:")
sample_properties = list(df.iloc[0, :])[1:]
print(f"   Properties count: {len(sample_properties)}")
print(f"   Sample values: {sample_properties[:5]}...")  # Show first 5

📊 Loading elemental properties database...
✅ Loaded data for 118 elements
📋 Available properties: 82 total
🧹 Removing 5 unnecessary columns...
   ❌ Dropped: electronic_configuration
   ❌ Dropped: name
   ❌ Dropped: block
   ❌ Dropped: lattice_structure
   ❌ Dropped: is_radioactive
✅ Final dataset: 118 elements × 77 properties
📋 Remaining properties: ['atomic_number', 'atomic_radius', 'atomic_volume', 'boiling_point', 'density', 'dipole_polarizability', 'electron_affinity', 'valance_main', 'evaporation_heat', 'fusion_heat', 'group_id', 'lattice_constant', 'melting_point', 'period', 'series_id', 'specific_heat', 'thermal_conductivity', 'vdw_radius', 'covalent_radius_cordero', 'covalent_radius_pyykko', 'en_pauling', 'en_allen', 'heat_of_formation', 'c6', 'covalent_radius_bragg', 'covalent_radius_slater', 'vdw_radius_batsanov', 'vdw_radius_uff', 'vdw_radius_mm3', 'abundance_crust', 'abundance_sea', 'en_ghosh', 'vdw_radius_alvarez', 'c6_gb', 'atomic_weight', 'atomic_weight_uncertainty', 'is

In [10]:
# ===============================================
# Extract Structural Features (X data)
# ===============================================

# Map every element symbol to its atomic number Z (H = 1 ... Og = 118).
# The symbols are written in order of increasing Z, ten per row, so the
# 1-based position of a symbol in the sequence is its atomic number.
ATOMIC_NUMBERS = {
    symbol: z
    for z, symbol in enumerate(
        """
        H  He Li Be B  C  N  O  F  Ne
        Na Mg Al Si P  S  Cl Ar K  Ca
        Sc Ti V  Cr Mn Fe Co Ni Cu Zn
        Ga Ge As Se Br Kr Rb Sr Y  Zr
        Nb Mo Tc Ru Rh Pd Ag Cd In Sn
        Sb Te I  Xe Cs Ba La Ce Pr Nd
        Pm Sm Eu Gd Tb Dy Ho Er Tm Yb
        Lu Hf Ta W  Re Os Ir Pt Au Hg
        Tl Pb Bi Po At Rn Fr Ra Ac Th
        Pa U  Np Pu Am Cm Bk Cf Es Fm
        Md No Lr Rf Db Sg Bh Hs Mt Ds
        Rg Cn Nh Fl Mc Lv Ts Og
        """.split(),
        start=1,
    )
}

def extract_structure_features(file_path, elemental_df):
    """
    Extract comprehensive structural features from a single YAML file

    Feature order in the returned list:
      1. Atomic numbers of M, A, X, then the number of M, A, X atoms
      2. Lattice angles alpha, beta, gamma; cell volume; calculate_density
         value; space group number
      3. RDF peak distances (count depends on the structure), then the mean
         coordination numbers of M, A, X
      4. (mean, std) of the X-M-X, A-M-X and M-A-M bond angles, flattened
      5. Elemental properties of M, then A, then X

    Args:
        file_path: Path to the structure YAML file
        elemental_df: DataFrame containing elemental properties; its first two
            columns (symbol and atomic_number in this notebook) are skipped

    Returns:
        List of all structural features for this structure

    Raises:
        ValueError: if M, A, X cannot be identified, or if elemental_df does
            not have exactly one row for one of the elements
    """
    with open(file_path, 'r') as file:
        data = yaml.safe_load(file)

    # Parse basic structure data
    lattice = np.array(data['lattice'])
    atomic_masses, positions, symbols = parse_atomic_positions(data, ATOMIC_NUMBERS)

    # Identify MAX phase elements
    M, A, X = identify_max_elements(symbols)
    if not all([M, A, X]):
        raise ValueError(f"Could not identify MAX phase elements in {file_path}")

    structural_features = []

    # 1. Basic composition features
    structural_features.extend([
        ATOMIC_NUMBERS[M], ATOMIC_NUMBERS[A], ATOMIC_NUMBERS[X]  # Atomic numbers
    ])
    structural_features.extend([
        symbols.count(M), symbols.count(A), symbols.count(X)    # Atom counts
    ])

    # 2. Lattice geometry features (volume is computed once and reused)
    volume = calculate_volume(lattice)
    structural_features.extend(extract_lattice_angles(lattice))     # α, β, γ angles
    structural_features.append(volume)                              # Unit cell volume
    structural_features.append(calculate_density(volume, atomic_masses))  # Density
    structural_features.append(get_spacegroup_number(lattice, symbols, positions))  # Spacegroup

    # 3. Advanced structural features
    structural_features.extend(calculate_rdf_features(lattice, positions, symbols))  # RDF peaks
    structural_features.extend(calculate_coordination_numbers(lattice, positions, symbols, M, A, X))  # CN

    # 4. Bond angle features (flattened)
    angle_stats = calculate_bond_angles(lattice, positions, symbols, M, A, X)
    structural_features.extend(np.array(angle_stats).flatten())

    # 5. Elemental property features for M, A, X (in that order)
    for element in (M, A, X):
        structural_features.extend(
            _lookup_element_properties(elemental_df, ATOMIC_NUMBERS[element])
        )

    return structural_features


def _lookup_element_properties(elemental_df, atomic_number):
    """Return one element's property values, skipping the first two columns."""
    if 'atomic_number' in elemental_df.columns:
        # Match on the column's value rather than on row position, so a
        # re-ordered or filtered table cannot hand back another element's row.
        matches = elemental_df[elemental_df['atomic_number'] == atomic_number]
        if len(matches) != 1:
            raise ValueError(
                f"Expected exactly one elemental row for atomic number "
                f"{atomic_number}, found {len(matches)}"
            )
        row = matches.iloc[0]
    else:
        # No such column: fall back to assuming row Z-1 holds element Z
        row = elemental_df.iloc[atomic_number - 1, :]
    return list(row)[2:]  # Skip symbol and atomic_number

# Extract features from all structure files
print("🔧 Extracting structural features from band files...")
X_data = []
# File number behind each row of X_data. Files that fail are skipped, so row k
# is NOT guaranteed to be file k+1; keep this list to align X with other data.
X_file_indices = []

# Process all files with progress indicator
for file_index in range(1, 359):  # Files 1 to 358
    file_path = f'band/band ({file_index}).yaml'

    try:
        features = extract_structure_features(file_path, df)
    except FileNotFoundError:
        print(f"⚠️  Warning: File {file_path} not found, skipping...")
        continue
    except Exception as e:
        print(f"❌ Error processing file {file_path}: {e}")
        continue

    X_data.append(features)
    X_file_indices.append(file_index)

    # Progress indicator every 50 files. The shape is reported from list
    # lengths; this avoids rebuilding the array and cannot raise here.
    if file_index % 50 == 0:
        print(f"   📊 Processed {file_index}/358 files...")
        print(f"      Current X shape: {(len(X_data), len(X_data[0]))}")

# Convert to numpy array. Feature vectors can differ in length (the number of
# RDF peaks depends on the structure); report that explicitly.
feature_lengths = sorted({len(row) for row in X_data})
if len(feature_lengths) > 1:
    print(f"⚠️  Warning: feature vectors have inconsistent lengths {feature_lengths}")
    print("   X is stored as a 1-D object array and is not usable as a feature matrix")
    X = np.array(X_data, dtype=object)
else:
    X = np.array(X_data)

print(f"\n📊 Feature extraction complete!")
print(f"✅ Successfully processed {len(X_data)} structures")
print(f"📈 X data shape: {X.shape}")
print(f"   Features per structure: {X.shape[1] if len(X.shape) > 1 else 'N/A'}")

# Make skipped files visible: they shift every later row of X.
skipped_files = sorted(set(range(1, 359)) - set(X_file_indices))
if skipped_files:
    print(f"⚠️  Warning: {len(skipped_files)} file(s) skipped: {skipped_files}")
    print("   Row k of X corresponds to file X_file_indices[k], not file k+1")

# Feature composition breakdown
if len(X_data) > 0:
    print(f"\n📋 Feature composition:")
    print(f"   📐 Basic composition: 6 features (atomic numbers + counts)")
    print(f"   🔬 Lattice geometry: ~10 features (angles, volume, density, spacegroup)")
    print(f"   📊 Advanced features: Variable (RDF peaks, CN, bond angles)")
    print(f"   ⚛️  Elemental properties: ~150 features (3 elements × ~50 properties each)")
    print(f"   🎯 Total: {X.shape[1] if len(X.shape) > 1 else 0} features")

🔧 Extracting structural features from band files...
   📊 Processed 50/358 files...
      Current X shape: (50, 255)
   📊 Processed 50/358 files...
      Current X shape: (50, 255)
   📊 Processed 100/358 files...
      Current X shape: (100, 255)
   📊 Processed 100/358 files...
      Current X shape: (100, 255)
   📊 Processed 150/358 files...
      Current X shape: (150, 255)
   📊 Processed 150/358 files...
      Current X shape: (150, 255)
   📊 Processed 200/358 files...
      Current X shape: (200, 255)
   📊 Processed 200/358 files...
      Current X shape: (200, 255)
   📊 Processed 250/358 files...
      Current X shape: (250, 255)
   📊 Processed 250/358 files...
      Current X shape: (250, 255)
   📊 Processed 300/358 files...
      Current X shape: (300, 255)
   📊 Processed 300/358 files...
      Current X shape: (300, 255)
   📊 Processed 350/358 files...
      Current X shape: (350, 255)
   📊 Processed 350/358 files...
      Current X shape: (350, 255)

📊 Feature extraction comple

In [11]:
# ===============================================
# Save X Data and Dataset Summary  
# ===============================================

if len(X_data) > 0:
    # Y is produced by an earlier cell and may be absent if it was not run
    y_available = 'Y' in locals()

    # Save X data (structural features)
    print("💾 Saving X data (structural features)...")
    np.save("X.npy", X)
    print(f"✅ X data saved successfully!")
    print(f"   File: X.npy")
    print(f"   Shape: {X.shape}")
    print(f"   Data type: {X.dtype}")
    print(f"   File size: ~{X.nbytes / (1024**2):.2f} MB")

    print(f"\n📋 COMPLETE DATASET SUMMARY:")
    print(f"{'='*50}")
    print(f"🎯 Dataset Size:")
    print(f"   📊 Structures: {X.shape[0]}")
    print(f"   🔧 Features per structure: {X.shape[1]}")
    print(f"   📈 Phonon frequencies per structure: {Y.shape[1] if y_available else 'Not calculated'}")

    print(f"\n📁 Files Created:")
    print(f"   ✅ X.npy - Structural features ({X.shape})")
    print(f"   ✅ Y.npy - Phonon frequencies ({Y.shape if y_available else 'Not saved'})")

    # X and Y are built by separate loops that skip failed files
    # independently, so check they still have one row per structure each.
    if y_available and Y.shape[0] != X.shape[0]:
        print(f"\n⚠️  Warning: X has {X.shape[0]} rows but Y has {Y.shape[0]} rows")
        print(f"   X and Y are not aligned - fix the skipped files before training")
    else:
        print(f"\n🚀 Ready for MLP Training!")
    print(f"   Load data: X = np.load('X.npy'), Y = np.load('Y.npy')")
    print(f"   Model input: X ({X.shape[1]} features)")
    print(f"   Model output: Y ({Y.shape[1] if y_available else 'TBD'} phonon frequencies)")
    print(f"{'='*50}")

else:
    print("❌ No X data to save - check file processing above")

💾 Saving X data (structural features)...
✅ X data saved successfully!
   File: X.npy
   Shape: (358, 255)
   Data type: float64
   File size: ~0.70 MB

📋 COMPLETE DATASET SUMMARY:
🎯 Dataset Size:
   📊 Structures: 358
   🔧 Features per structure: 255
   📈 Phonon frequencies per structure: 8568

📁 Files Created:
   ✅ X.npy - Structural features ((358, 255))
   ✅ Y.npy - Phonon frequencies ((358, 8568))

🚀 Ready for MLP Training!
   Load data: X = np.load('X.npy'), Y = np.load('Y.npy')
   Model input: X (255 features)
   Model output: Y (8568 phonon frequencies)
