In [10]:
import os
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Load the detailed test-prediction table; the path is relative to the notebook's working directory
df = pd.read_csv('../eval_results_20250512_184623/test_predictions_detailed.csv')

In [4]:
# Remove the Primary_Element column. errors="ignore" keeps this cell safe to re-run:
# because `df` is rebound here, a plain drop raises KeyError on the second execution.
df = df.drop(columns="Primary_Element", errors="ignore")

In [7]:
# Preview the first rows to check the columns that remain after the drop
print(df.head())

    Composition   Actual   Predicted      Error  Absolute_Error  \
0  Sr0.98Ga0.02  1026.95  1028.74800  -1.798096        1.798096   
1  Sr0.57Ga0.43  1001.17  1006.73895  -5.568970        5.568970   
2  Ga0.22Ba0.78   717.58   769.79940 -52.219360       52.219360   
3  Ge0.87Ga0.13  1166.02  1262.69900 -96.678955       96.678955   
4  Mg0.58Cu0.42   844.14   871.76953 -27.629517       27.629517   

   Percentage_Error  Ag_fraction  Al_fraction  Ba_fraction  Bi_fraction  ...  \
0          0.175091          0.0          0.0         0.00          0.0  ...   
1          0.556246          0.0          0.0         0.00          0.0  ...   
2          7.277148          0.0          0.0         0.78          0.0  ...   
3          8.291363          0.0          0.0         0.00          0.0  ...   
4          3.273096          0.0          0.0         0.00          0.0  ...   

   In_fraction  Li_fraction  Mg_fraction  Na_fraction  Pb_fraction  \
0          0.0          0.0         0.00      

In [12]:
def parse_composition(self, composition: str) -> list[tuple[str, float]]:
    """
    Parse a composition string into (element, normalized fraction) pairs.

    Example: "Al0.19Mg0.81" ->
        [("Al", 0.19), ("Mg", 0.81), ("", 0.0), ("", 0.0)]

    Amounts may be written on any consistent scale: fractions ("Al0.19Mg0.81"),
    percentages ("Al19Mg81") or counts ("Al2O3"). They are divided by their sum,
    so the returned fractions add up to 1. An element without a number
    (e.g. "Fe") counts as 1.

    Args:
        composition: String representation of the composition (e.g., "Al0.19Mg0.81")

    Returns:
        List of (element, fraction) tuples in order of appearance, padded with
        ("", 0.0) entries up to a minimum length of 4. Longer compositions are
        returned in full (not truncated). If every amount is zero the amounts
        are returned unnormalized.
    """
    # The amount is optional but, when present, must contain at least one digit,
    # so a stray "." can never reach float().
    pattern = r'([A-Z][a-z]*)([0-9]+\.?[0-9]*|\.[0-9]+)?'
    matches = re.findall(pattern, composition)

    # Handle empty amount (e.g., "Fe" instead of "Fe1.0")
    elements_amounts = [
        (element, float(amount) if amount else 1.0)
        for element, amount in matches
    ]

    # Normalize by the total only. Rescaling individual amounts (e.g. dividing
    # values > 1 by 100) would distort inputs such as "Al99Mg1", where one
    # amount is above 1 and the other is not.
    total = sum(amount for _, amount in elements_amounts)
    if total > 0:
        elements_fractions = [
            (element, amount / total) for element, amount in elements_amounts
        ]
    else:
        # Nothing parsed, or all amounts are zero: avoid dividing by zero
        elements_fractions = elements_amounts

    # Pad with empty elements if needed (for fixed-length representation)
    while len(elements_fractions) < 4:
        elements_fractions.append(("", 0.0))
    return elements_fractions

def analyze_element_count_impact(self):
    """
    Analyze how the number of elements in an alloy affects prediction accuracy.

    Groups ``self.results_df`` by its ``n_elements`` column, aggregates the
    absolute and percentage errors per group, and writes two files into
    ``self.save_dir``:

    * ``element_count_analysis.csv`` - the aggregated statistics
    * ``element_count_analysis.png`` - a bar chart of the mean absolute error
      (annotated with the sample count of each group) above a box plot of the
      absolute-error distribution per group

    The aggregated table is also printed to stdout. If ``self.results_df`` is
    None, ``self.create_results_dataframe()`` is called first.

    Returns:
        DataFrame indexed by ``n_elements`` with two-level columns
        (``Absolute_Error``: mean/median/std/min/max/count,
        ``Percentage_Error``: mean/median).
    """
    if self.results_df is None:
        self.create_results_dataframe()

    results = self.results_df

    # Group by number of elements
    element_count_analysis = results.groupby('n_elements').agg({
        'Absolute_Error': ['mean', 'median', 'std', 'min', 'max', 'count'],
        'Percentage_Error': ['mean', 'median']
    })

    # Save to CSV
    element_count_analysis.to_csv(os.path.join(self.save_dir, 'element_count_analysis.csv'))

    # Create visualization
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
    try:
        # Plot 1: Mean error by element count
        bars = ax1.bar(element_count_analysis.index, element_count_analysis[('Absolute_Error', 'mean')])

        # Add count labels just above each bar
        for bar, count in zip(bars, element_count_analysis[('Absolute_Error', 'count')]):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 1,
                     f'n={int(count)}', ha='center', va='bottom', rotation=0)

        ax1.set_xlabel('Number of Elements in Composition')
        ax1.set_ylabel('Mean Absolute Error')
        ax1.set_title('Mean Prediction Error by Number of Elements')
        ax1.grid(True, alpha=0.3)

        # Plot 2: Box plot of errors by element count
        element_counts = sorted(results['n_elements'].unique())
        data_for_boxplot = [
            results.loc[results['n_elements'] == count, 'Absolute_Error']
            for count in element_counts
        ]

        ax2.boxplot(data_for_boxplot)
        # boxplot() draws box i at x = i + 1. Setting the tick labels here,
        # rather than through the `labels=` keyword, avoids the keyword that
        # Matplotlib 3.9 deprecated in favor of `tick_labels=`.
        ax2.set_xticks(range(1, len(element_counts) + 1))
        ax2.set_xticklabels([f'{count} elements' for count in element_counts])
        ax2.set_ylabel('Absolute Error')
        ax2.set_title('Error Distribution by Number of Elements')
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(os.path.join(self.save_dir, 'element_count_analysis.png'), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails
        plt.close(fig)

    print("\nImpact of Element Count on Prediction Error:")
    print(element_count_analysis)

    return element_count_analysis

def analyze_entropy_effect(self):
    """
    Analyze how compositional entropy affects prediction accuracy.

    For every row of ``self.results_df`` the ``Composition`` string is parsed
    with ``self.parse_composition`` and its Shannon entropy, -sum(p * ln(p)),
    is computed over the strictly positive fractions. Entropy is higher when
    elements are more evenly distributed.

    Side effects:
        * adds ``Compositional_Entropy`` and ``Entropy_Band`` (5 equal-width
          bins, 'Very Low' .. 'Very High') columns to ``self.results_df``
        * writes ``entropy_analysis.csv`` and ``entropy_analysis.png`` into
          ``self.save_dir``
        * prints the aggregated table to stdout

    If ``self.results_df`` is None, ``self.create_results_dataframe()`` is
    called first.

    Returns:
        DataFrame indexed by ``Entropy_Band`` with two-level columns
        (``Absolute_Error``: mean/median/std/count,
        ``Percentage_Error``: mean/median). Bands that contain no samples are
        kept, with a count of 0.
    """
    if self.results_df is None:
        self.create_results_dataframe()

    def shannon_entropy(composition):
        # Zero fractions (including the padding entries added by
        # parse_composition) are skipped because log(0) is undefined.
        return -sum(
            frac * np.log(frac)
            for _, frac in self.parse_composition(composition)
            if frac > 0
        )

    # Add entropy to results dataframe
    self.results_df['Compositional_Entropy'] = [
        shannon_entropy(composition) for composition in self.results_df['Composition']
    ]

    # Create entropy bands for analysis
    self.results_df['Entropy_Band'] = pd.cut(
        self.results_df['Compositional_Entropy'],
        bins=5,
        labels=['Very Low', 'Low', 'Medium', 'High', 'Very High']
    )

    # Analyze error by entropy band. observed=False is spelled out so that all
    # five bands are always reported, and so that pandas does not warn about
    # the pending change of this default for categorical groupers.
    entropy_analysis = self.results_df.groupby('Entropy_Band', observed=False).agg({
        'Absolute_Error': ['mean', 'median', 'std', 'count'],
        'Percentage_Error': ['mean', 'median']
    })

    # Save to CSV
    entropy_analysis.to_csv(os.path.join(self.save_dir, 'entropy_analysis.csv'))

    entropy = self.results_df['Compositional_Entropy']
    x_range = np.linspace(entropy.min(), entropy.max(), 100)

    # Visualize relationship between entropy and error
    fig = plt.figure(figsize=(16, 12))

    def scatter_with_trend(ax, y_column, color_column, cmap, colorbar_label):
        # Scatter of entropy vs. y_column, colored by color_column, with a
        # dashed linear fit and the Pearson r printed in the top-left corner.
        y_values = self.results_df[y_column]
        scatter = ax.scatter(
            entropy,
            y_values,
            alpha=0.6,
            c=self.results_df[color_column],
            cmap=cmap,
            s=50
        )

        trend = np.poly1d(np.polyfit(entropy, y_values, 1))
        ax.plot(x_range, trend(x_range), "r--", linewidth=2)

        corr = np.corrcoef(entropy, y_values)[0, 1]
        ax.text(0.05, 0.95, f"r = {corr:.2f}", transform=ax.transAxes,
                fontsize=12, va='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_label(colorbar_label, fontsize=12)

    try:
        # Plot 1: Scatter plot of entropy vs error
        ax1 = fig.add_subplot(2, 2, 1)
        scatter_with_trend(ax1, 'Absolute_Error', 'n_elements', 'viridis', 'Number of Elements')
        ax1.set_xlabel('Compositional Entropy', fontsize=12)
        ax1.set_ylabel('Absolute Error', fontsize=12)
        ax1.set_title('Prediction Error vs Compositional Entropy', fontsize=14)
        ax1.grid(True, alpha=0.3)

        # Plot 2: Bar chart of mean error by entropy band
        ax2 = fig.add_subplot(2, 2, 2)
        bars = ax2.bar(entropy_analysis.index.astype(str), entropy_analysis[('Absolute_Error', 'mean')])

        # Add count labels
        for bar, count in zip(bars, entropy_analysis[('Absolute_Error', 'count')]):
            height = bar.get_height()
            # An empty band has a NaN mean; anchor its "n=0" label at the baseline
            label_base = height if np.isfinite(height) else 0
            ax2.text(bar.get_x() + bar.get_width()/2., label_base + 1,
                     f'n={int(count)}', ha='center', va='bottom', rotation=0)

        ax2.set_xlabel('Entropy Band', fontsize=12)
        ax2.set_ylabel('Mean Absolute Error', fontsize=12)
        ax2.set_title('Mean Error by Compositional Entropy', fontsize=14)
        ax2.grid(True, alpha=0.3)

        # Plot 3: Box plot of errors by entropy band
        ax3 = fig.add_subplot(2, 2, 3)
        sns.boxplot(x='Entropy_Band', y='Absolute_Error', data=self.results_df, ax=ax3)
        ax3.set_xlabel('Entropy Band', fontsize=12)
        ax3.set_ylabel('Absolute Error', fontsize=12)
        ax3.set_title('Error Distribution by Compositional Entropy', fontsize=14)

        # Plot 4: Entropy vs percentage error
        ax4 = fig.add_subplot(2, 2, 4)
        scatter_with_trend(ax4, 'Percentage_Error', 'Actual', 'plasma', 'Actual Temperature (K)')
        ax4.set_xlabel('Compositional Entropy', fontsize=12)
        ax4.set_ylabel('Percentage Error (%)', fontsize=12)
        ax4.set_title('Percentage Error vs Compositional Entropy', fontsize=14)
        ax4.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(os.path.join(self.save_dir, 'entropy_analysis.png'), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails
        plt.close(fig)

    print("\nImpact of Compositional Entropy on Prediction Error:")
    print(entropy_analysis)

    return entropy_analysis

def analyze_elemental_similarity(self):
    """
    Analyze if compositions with similar elements but different ratios
    have consistent prediction errors.

    Every ``*_fraction`` column of ``self.results_df`` is treated as one
    element. Rows are grouped by the set of elements whose fraction is > 0,
    and only sets with at least 3 rows are analyzed.

    Side effects:
        * adds an integer ``Element_Set`` column to ``self.results_df``
          (-1 marks rows in which no element fraction is > 0)
        * writes ``element_set_analysis.csv`` and ``element_set_analysis.png``
          into ``self.save_dir``
        * writes ``binary_analysis_<set>.png`` for up to 3 of the two-element
          sets found among the 5 most common sets
        * prints the first 10 rows of the summary to stdout

    If ``self.results_df`` is None, ``self.create_results_dataframe()`` is
    called first.

    Returns:
        DataFrame with one row per analyzed element set (columns Element_Set,
        Count, Mean_Error, Std_Error, Mean_Pct_Error, Min_Temp, Max_Temp,
        Temp_Range) sorted by Count descending, or None when no element set
        has at least 3 compositions.
    """
    if self.results_df is None:
        self.create_results_dataframe()

    # Find all unique elements in the dataset
    element_columns = [col for col in self.results_df.columns if col.endswith('_fraction')]
    elements = [col.split('_')[0] for col in element_columns]
    elements = [e for e in elements if e]  # Remove empty strings

    # Boolean (rows x elements) matrix: True where the element is present
    fraction_columns = [f'{element}_fraction' for element in elements]
    presence = self.results_df[fraction_columns].gt(0).to_numpy()

    # Assign every row the index of its element set in a single pass
    element_sets = []
    element_set_to_idx = {}
    row_set_indices = []

    for present_mask in presence:
        present_elements = tuple(sorted(
            element for element, is_present in zip(elements, present_mask) if is_present
        ))

        if not present_elements:
            row_set_indices.append(-1)
            continue

        if present_elements not in element_set_to_idx:
            element_set_to_idx[present_elements] = len(element_sets)
            element_sets.append(present_elements)
        row_set_indices.append(element_set_to_idx[present_elements])

    # Add element set index to dataframe
    self.results_df['Element_Set'] = row_set_indices

    # Find element sets with at least 3 compositions for meaningful analysis.
    # The -1 sentinel is excluded: it is not a real set, and element_sets[-1]
    # would otherwise silently resolve to the last real set.
    element_set_counts = self.results_df['Element_Set'].value_counts()
    common_element_sets = [
        set_idx
        for set_idx in element_set_counts[element_set_counts >= 3].index.tolist()
        if set_idx >= 0
    ]

    if not common_element_sets:
        print("No element sets with at least 3 compositions found. Skipping elemental similarity analysis.")
        return None

    # Create consolidated dataframe for analysis
    element_set_data = []

    for set_idx in common_element_sets:
        set_compositions = self.results_df[self.results_df['Element_Set'] == set_idx]
        actual = set_compositions['Actual']

        element_set_data.append({
            'Element_Set': '-'.join(element_sets[set_idx]),
            'Count': len(set_compositions),
            'Mean_Error': set_compositions['Absolute_Error'].mean(),
            'Std_Error': set_compositions['Absolute_Error'].std(),
            'Mean_Pct_Error': set_compositions['Percentage_Error'].mean(),
            'Min_Temp': actual.min(),
            'Max_Temp': actual.max(),
            'Temp_Range': actual.max() - actual.min()
        })

    element_set_df = pd.DataFrame(element_set_data)
    element_set_df = element_set_df.sort_values('Count', ascending=False)

    # Save to CSV
    element_set_df.to_csv(os.path.join(self.save_dir, 'element_set_analysis.csv'), index=False)

    # Get top 5 most common element sets for detailed analysis
    top_5_sets = element_set_df.head(5)['Element_Set'].tolist()

    # Create list of compositions for each set
    detailed_data = []

    for set_name in top_5_sets:
        element_names = set_name.split('-')
        set_idx = element_set_to_idx[tuple(element_names)]
        first_column = f'{element_names[0]}_fraction'
        is_binary = len(element_names) == 2

        set_compositions = self.results_df[self.results_df['Element_Set'] == set_idx]

        for _, row in set_compositions.iterrows():
            # Get the ratio of the first element to second element (for binary systems)
            if is_binary:
                second_fraction = row[f'{element_names[1]}_fraction']
                ratio = row[first_column] / second_fraction if second_fraction > 0 else 0
            else:
                ratio = 0

            detailed_data.append({
                'Element_Set': set_name,
                'Composition': row['Composition'],
                'Actual': row['Actual'],
                'Predicted': row['Predicted'],
                'Error': row['Absolute_Error'],
                'Ratio': ratio,
                # Stored here because the per-system plots below work on
                # detailed_df, which does not have the *_fraction columns.
                'First_Fraction': row[first_column]
            })

    detailed_df = pd.DataFrame(detailed_data)

    # For binary systems, plot error vs. element ratio
    binary_sets = [s for s in top_5_sets if len(s.split('-')) == 2]

    # Visualize element set analysis
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
    try:
        # Plot 1: Error by element set
        # Limit to top 15 most common sets for readability
        top_sets = element_set_df.head(15)

        bars = ax1.bar(top_sets['Element_Set'], top_sets['Mean_Error'])

        # Add error bars
        ax1.errorbar(
            top_sets['Element_Set'],
            top_sets['Mean_Error'],
            yerr=top_sets['Std_Error'],
            fmt='none',
            ecolor='black',
            capsize=5
        )

        # Add count labels
        for bar, count in zip(bars, top_sets['Count']):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 2,
                     f'n={int(count)}', ha='center', va='bottom', rotation=0)

        ax1.set_xlabel('Element Combination', fontsize=12)
        ax1.set_ylabel('Mean Absolute Error', fontsize=12)
        ax1.set_title('Prediction Error by Element Combination', fontsize=14)
        ax1.grid(True, alpha=0.3)
        ax1.tick_params(axis='x', rotation=45)

        # Plot 2: Detailed analysis of top element sets
        if binary_sets:
            for set_name in binary_sets:
                set_data = detailed_df[detailed_df['Element_Set'] == set_name]
                ax2.scatter(
                    set_data['Ratio'],
                    set_data['Error'],
                    label=set_name,
                    s=50,
                    alpha=0.7
                )

            ax2.set_xlabel('Element Ratio (First/Second)', fontsize=12)
            ax2.set_ylabel('Absolute Error', fontsize=12)
            ax2.set_title('Error vs Element Ratio for Binary Systems', fontsize=14)
            ax2.grid(True, alpha=0.3)
            ax2.legend()
        else:
            # If no binary systems, create a different visualization
            sns.boxplot(x='Element_Set', y='Error', data=detailed_df, ax=ax2)
            ax2.set_xlabel('Element Combination', fontsize=12)
            ax2.set_ylabel('Absolute Error', fontsize=12)
            ax2.set_title('Error Distribution by Element Combination', fontsize=14)
            ax2.tick_params(axis='x', rotation=45)

        fig.tight_layout()
        fig.savefig(os.path.join(self.save_dir, 'element_set_analysis.png'), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails
        plt.close(fig)

    # Create scatter plots for top binary systems
    for set_name in binary_sets[:3]:  # Limit to top 3 for brevity
        set_data = detailed_df[detailed_df['Element_Set'] == set_name]
        first_element = set_name.split('-')[0]
        first_fraction = set_data['First_Fraction']

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))
        try:
            # Plot 1: Error vs first element fraction
            scatter = ax1.scatter(
                first_fraction,
                set_data['Error'],
                c=set_data['Actual'],
                cmap='viridis',
                s=70,
                alpha=0.8
            )

            ax1.set_xlabel(f'{first_element} Fraction', fontsize=12)
            ax1.set_ylabel('Absolute Error', fontsize=12)
            ax1.set_title(f'Error vs {first_element} Content for {set_name}', fontsize=14)
            ax1.grid(True, alpha=0.3)

            cbar = fig.colorbar(scatter, ax=ax1)
            cbar.set_label('Actual Temperature (K)', fontsize=12)

            # Plot 2: Actual vs Predicted for this system
            scatter = ax2.scatter(
                set_data['Actual'],
                set_data['Predicted'],
                c=first_fraction,
                cmap='plasma',
                s=70,
                alpha=0.8
            )

            # Add perfect prediction line
            min_val = min(set_data['Actual'].min(), set_data['Predicted'].min())
            max_val = max(set_data['Actual'].max(), set_data['Predicted'].max())
            ax2.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.7)

            ax2.set_xlabel('Actual Temperature (K)', fontsize=12)
            ax2.set_ylabel('Predicted Temperature (K)', fontsize=12)
            ax2.set_title(f'Predicted vs Actual for {set_name}', fontsize=14)
            ax2.grid(True, alpha=0.3)

            cbar = fig.colorbar(scatter, ax=ax2)
            cbar.set_label(f'{first_element} Fraction', fontsize=12)

            fig.tight_layout()
            fig.savefig(os.path.join(self.save_dir, f'binary_analysis_{set_name}.png'), dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

    print("\nElement Set Analysis (Top 10):")
    print(element_set_df.head(10))

    return element_set_df

def run_full_evaluation(self):
    """Execute the whole evaluation workflow, element-level analyses included.

    Prints a progress banner before each stage, runs the stage's methods in
    order, then lists the 10 rows of ``self.results_df`` with the largest
    ``Absolute_Error``.

    Returns:
        Tuple ``(self.results_df, self.summary_stats)``.
    """
    # Each entry: the banner to print, then the methods to invoke (by name, in order).
    # Methods are looked up only when called, so nothing is resolved ahead of time.
    stages = (
        ("Starting model evaluation...",
         ("predict", "calculate_metrics", "create_results_dataframe")),
        ("Creating basic evaluation plots...",
         ("plot_basic_evaluation", "plot_composition_analysis")),
        ("Performing enhanced element analysis...", ()),
        ("Analyzing impact of element count...", ("analyze_element_count_impact",)),
        ("Analyzing effect of compositional entropy...", ("analyze_entropy_effect",)),
        ("Analyzing elemental similarity patterns...", ("analyze_elemental_similarity",)),
        ("Saving results...", ("save_results",)),
    )

    for banner, method_names in stages:
        print(banner)
        for method_name in method_names:
            getattr(self, method_name)()

    # Console summary of the largest misses
    print("\nTop 10 Worst Predictions:")
    print("Composition | Primary Element | Actual | Predicted | Error")
    print("-" * 70)

    largest_errors = self.results_df.nlargest(10, 'Absolute_Error')
    for _, record in largest_errors.iterrows():
        composition = record['Composition']
        primary = record['Primary_Element']
        actual = record['Actual']
        predicted = record['Predicted']
        abs_error = record['Absolute_Error']
        print(f"{composition:<20} | {primary:<15} | {actual:6.2f} | {predicted:9.2f} | {abs_error:5.2f}")

    return self.results_df, self.summary_stats