# Policy Performance Factor Analysis

This notebook performs factor analysis on policy performance across multiple training runs using evaluation stats from the Metta stats database.

In [2]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

%matplotlib inline
plt.style.use('default')

print('Setup complete! Auto-reload enabled.')

Setup complete! Auto-reload enabled.


In [3]:
# Training run names to analyze
run_names = [
    "jacke.cyclical_8tick_timing_study_20250728_134443",
    "yudhister.2x4_arena_lp_02_july25_fix_seed_incl_steps_fix_2",
    # "jacke.cyclical_8tick_timing_study_20250728_134805",
    # "george.operantconditioning.smoketestthree.07-29",
    # "jacke.cyclical_11tick_timing_study_20250729_095309",
    # "jacke.sky_nav_base_20250725_143311",
    # "george.operantconditioning.smoketestfour.07-29",
    # "zfogg.devbox.arena_nav_combined.07-29",
    # "daphne.operantconditioning.backchaining.earlyterm.any.2.07-29",
    # "jacke.sky_nav_grid_20250725_154822",
    # "zfogg.skypilot.sim_all.gpu8.07-29",
    # "daphne.operantconditioning.smoketesttwo.earlyterm.all.2.07-29",
    # "jacke.cyclical_8tick_timing_study_20250729_095333",
    # "yudhister.2x4_arena_lp_03_july25_fix_seed_incl_steps_fix_2",
    # "jacke.sky_nav_grid_20250725_154808",
    # "jacke.sky_nav_spiral_20250725_154729",
    # "yudhister.2x4_arena_random_07_july25_fix_seed_incl_steps_fix_2",
    # "yudhister.2x4_arena_lp_05_july25_fix_seed_incl_steps_fix_2",
    # "daphne.operantconditioning.smoketestfour.earlyterm.all.2.07-29",
    # "jacke.sky_random_nav_grid_spiral_20250725_154829",
    # "nishad-0726-1148",
    # "yudhister.2x4_arena_lp_01_july25_fix_seed_incl_steps_fix_2",
    # "daphne.operantconditioning.smoketestthree.earlyterm.any.2.07-29",
    # "daphne.navigation.earlyterm.half.2.07-28",
    # "jacke.cyclical_15tick_timing_study_20250729_095339",
    # "daphne.operantconditioning.smoketestthree.earlyterm.all.2.07-29",
    # "jacke.cyclical_10tick_timing_study_20250729_095335",
    # "bullm.navigation.low_reward.with_context.07-24",
    # "absurdlybasictest5",
    # "zfogg.skypilot.learning_progress.07-29.2",
    # "yudhister.2x4_arena_lp_07_july25_fix_seed_incl_steps_fix_2",
    # "jacke.sky_nav_grid_spiral_20250725_154838",
    # "george.operantconditioning.smoketesttwo.07-29",
    # "jacke.cyclical_11tick_timing_study_20250729_095304",
    # "jacke.cyclical_6tick_timing_study_20250728_134441",
    # "jacke.sky_nav_spiral_20250725_154737",
    # "yudhister.2x4_arena_random_05_july25_fix_seed_incl_steps_fix_2",
    # "daphne.operantconditioning.smoketestfour.earlyterm.any.2.07-29",
    # "george.operantconditioning.backchaining.07-29"
]

print(f"Total runs to analyze: {len(run_names)}")

Total runs to analyze: 2


In [4]:
# Initialize API client
from metta.common.util.constants import PROD_STATS_SERVER_URI
from metta.common.client.metta_client import MettaAPIClient

client = MettaAPIClient(PROD_STATS_SERVER_URI)

print(f"Connected to Metta API at: {PROD_STATS_SERVER_URI}")

Connected to Metta API at: https://api.observatory.softmax-research.net


In [6]:
from experiments.notebooks.utils.heatmap_widget.heatmap_widget.util import fetch_real_heatmap_data
# Fetch policy data using the training run names
print(f"Fetching evaluation data for {len(run_names)} training runs...")

# Fetch heatmap data for all runs
heatmap_data = await fetch_real_heatmap_data(
    search_texts=run_names,
    api_base_url=PROD_STATS_SERVER_URI,
    metrics=["reward"],
    policy_selector="latest",  # Get best policy from each training run
    max_policies=100
)

print("Fetched data")

Fetching evaluation data for 2 training runs...
🚀 HeatmapWidget initialized successfully!
📊 Multi-metric data set with 2 policies and 6 evaluations
📈 Available metrics: reward
📈 Selected metric: reward
Fetched data


In [18]:
heatmap_data

HeatmapWidget(heatmap_data={'cells': {'jacke.cyclical_8tick_timing_study_20250728_134443:v43': {'arena/advance…

In [22]:
# Convert heatmap data to dataframe for factor analysis
def heatmap_to_dataframe(heatmap_data):
    """Convert heatmap data structure to a dataframe suitable for factor analysis."""
    rows = []
    
    for policy_name in heatmap_data.policyNames:
        row = {'policy_name': policy_name}
        
        # Extract metrics across all evaluations
        for eval_name in heatmap_data.evalNames:
            cell = heatmap_data.cells.get(policy_name, {}).get(eval_name, {})
            metrics = cell.get('metrics', {})
            
            # Add each metric with eval_name prefix
            for metric_name, value in metrics.items():
                col_name = f"{eval_name}_{metric_name}"
                row[col_name] = value
        
        # Add average scores
        row['average_score'] = heatmap_data.policyAverageScores.get(policy_name, 0)
        rows.append(row)
    
    return pd.DataFrame(rows)

# Create dataframe
df = heatmap_to_dataframe(heatmap_data)
print(f"Created dataframe with shape: {df.shape}")
print(f"\nFirst few columns: {list(df.columns[:10])}")
print(f"\nDataframe head:")
df.head()

AttributeError: 'HeatmapWidget' object has no attribute 'policyNames'

In [None]:
# Data overview and cleaning
print("Data shape:", df.shape)
print("\nMissing values per column:")
missing_counts = df.isnull().sum()
print(missing_counts[missing_counts > 0].head(20))

# Fill missing values with 0 (indicating no performance on that evaluation)
df_filled = df.fillna(0)

print(f"\nData types:")
print(df_filled.dtypes.value_counts())

In [None]:
# Prepare features for factor analysis
# Select only numeric columns (exclude policy_name)
numeric_cols = df_filled.select_dtypes(include=[np.number]).columns.tolist()

print(f"Found {len(numeric_cols)} numeric features")

# Create feature matrix
X = df_filled[numeric_cols]

# Remove columns with zero variance
variance_filter = X.var() > 0
X_filtered = X.loc[:, variance_filter]
selected_features = X_filtered.columns.tolist()

print(f"After removing zero-variance features: {len(selected_features)} features")

# Standardize features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_filtered)

print(f"\nFeature matrix shape: {X_scaled.shape}")
print(f"Number of policies: {X_scaled.shape[0]}")
print(f"Number of features: {X_scaled.shape[1]}")

In [None]:
# Perform PCA to understand variance structure
if X_scaled.shape[0] > 1 and X_scaled.shape[1] > 0:
    # Determine number of components
    n_components = min(X_scaled.shape[0], X_scaled.shape[1])
    
    pca = PCA(n_components=n_components)
    pca_result = pca.fit_transform(X_scaled)
    
    # Plot explained variance
    plt.figure(figsize=(10, 6))
    plt.subplot(1, 2, 1)
    plt.plot(range(1, min(20, len(pca.explained_variance_ratio_)) + 1), 
             pca.explained_variance_ratio_[:20], 
             'bo-')
    plt.xlabel('Component Number')
    plt.ylabel('Explained Variance Ratio')
    plt.title('PCA Explained Variance (First 20 Components)')
    plt.grid(True)
    
    plt.subplot(1, 2, 2)
    plt.plot(range(1, min(20, len(pca.explained_variance_ratio_)) + 1), 
             np.cumsum(pca.explained_variance_ratio_[:20]), 
             'ro-')
    plt.xlabel('Number of Components')
    plt.ylabel('Cumulative Explained Variance')
    plt.title('Cumulative Explained Variance')
    plt.grid(True)
    
    plt.tight_layout()
    plt.show()
    
    # Print variance explained by first few components
    print("Variance explained by first 10 components:")
    for i in range(min(10, len(pca.explained_variance_ratio_))):
        print(f"PC{i+1}: {pca.explained_variance_ratio_[i]:.3f} ({np.cumsum(pca.explained_variance_ratio_)[i]:.3f} cumulative)")
else:
    print("Not enough data for PCA analysis")

In [None]:
# Perform Factor Analysis
if X_scaled.shape[0] > 3 and X_scaled.shape[1] > 3:
    # Choose number of factors based on PCA results
    n_factors = min(5, X_scaled.shape[0] - 1, X_scaled.shape[1])  # Start with 5 factors or less
    
    print(f"Performing factor analysis with {n_factors} factors...")
    
    fa = FactorAnalysis(n_components=n_factors, random_state=42)
    factors = fa.fit_transform(X_scaled)
    
    # Get factor loadings
    loadings = pd.DataFrame(
        fa.components_.T,
        columns=[f'Factor_{i+1}' for i in range(n_factors)],
        index=selected_features
    )
    
    print(f"\nFactor analysis complete. Shape of factor scores: {factors.shape}")
    print("\nTop loadings for each factor:")
    
    for i in range(n_factors):
        print(f"\nFactor {i+1}:")
        factor_col = f'Factor_{i+1}'
        top_positive = loadings[factor_col].nlargest(5)
        top_negative = loadings[factor_col].nsmallest(5)
        
        print("  Top positive loadings:")
        for feat, loading in top_positive.items():
            print(f"    {feat}: {loading:.3f}")
        
        print("  Top negative loadings:")
        for feat, loading in top_negative.items():
            print(f"    {feat}: {loading:.3f}")
else:
    print("Not enough data for factor analysis")

In [None]:
# Visualize factor loadings heatmap
if 'loadings' in locals():
    plt.figure(figsize=(12, 8))
    
    # Select top features with highest absolute loadings
    abs_loadings = loadings.abs()
    max_loadings = abs_loadings.max(axis=1)
    top_features_idx = max_loadings.nlargest(30).index
    
    # Create subset of loadings for visualization
    loadings_subset = loadings.loc[top_features_idx]
    
    sns.heatmap(loadings_subset.T, cmap='RdBu_r', center=0, 
                annot=True, fmt='.2f', 
                cbar_kws={'label': 'Loading'})
    plt.title('Factor Loadings Heatmap (Top 30 Features)')
    plt.xlabel('Features')
    plt.ylabel('Factors')
    plt.tight_layout()
    plt.show()

In [None]:
# Add factor scores to dataframe and analyze
if 'factors' in locals():
    # Add factor scores to original dataframe
    for i in range(n_factors):
        df[f'factor_{i+1}'] = factors[:, i]
    
    # Show policies with highest scores on each factor
    print("Top policies by factor scores:\n")
    for i in range(n_factors):
        print(f"Factor {i+1} - Top 5 policies:")
        top_policies = df.nlargest(5, f'factor_{i+1}')[['policy_name', f'factor_{i+1}']]
        for _, row in top_policies.iterrows():
            print(f"  {row['policy_name'][:60]:60} {row[f'factor_{i+1}']:.3f}")
        print()
    
    # Also show bottom 5 for contrast
    print("\nBottom policies by factor scores:\n")
    for i in range(n_factors):
        print(f"Factor {i+1} - Bottom 5 policies:")
        bottom_policies = df.nsmallest(5, f'factor_{i+1}')[['policy_name', f'factor_{i+1}']]
        for _, row in bottom_policies.iterrows():
            print(f"  {row['policy_name'][:60]:60} {row[f'factor_{i+1}']:.3f}")
        print()

In [None]:
# Scatter plot of factor scores
if 'factors' in locals() and n_factors >= 2:
    plt.figure(figsize=(10, 8))
    
    # Plot first two factors
    plt.scatter(df['factor_1'], df['factor_2'], alpha=0.6, s=50)
    
    # Annotate interesting points (high/low on either factor)
    threshold = 1.5
    for idx, row in df.iterrows():
        if abs(row['factor_1']) > threshold or abs(row['factor_2']) > threshold:
            # Extract just the run name part for cleaner labels
            label = row['policy_name'].split('.')[-1][:20] if '.' in row['policy_name'] else row['policy_name'][:20]
            plt.annotate(label, 
                        (row['factor_1'], row['factor_2']),
                        fontsize=8, alpha=0.7,
                        xytext=(5, 5), textcoords='offset points')
    
    plt.xlabel('Factor 1')
    plt.ylabel('Factor 2')
    plt.title('Policy Performance Factor Analysis - Factor Space')
    plt.grid(True, alpha=0.3)
    plt.axhline(y=0, color='k', linestyle='-', alpha=0.2)
    plt.axvline(x=0, color='k', linestyle='-', alpha=0.2)
    plt.show()
    
    # If we have more factors, show factor 1 vs 3
    if n_factors >= 3:
        plt.figure(figsize=(10, 8))
        plt.scatter(df['factor_1'], df['factor_3'], alpha=0.6, s=50)
        
        for idx, row in df.iterrows():
            if abs(row['factor_1']) > threshold or abs(row['factor_3']) > threshold:
                label = row['policy_name'].split('.')[-1][:20] if '.' in row['policy_name'] else row['policy_name'][:20]
                plt.annotate(label, 
                            (row['factor_1'], row['factor_3']),
                            fontsize=8, alpha=0.7,
                            xytext=(5, 5), textcoords='offset points')
        
        plt.xlabel('Factor 1')
        plt.ylabel('Factor 3')
        plt.title('Policy Performance Factor Analysis - Factors 1 vs 3')
        plt.grid(True, alpha=0.3)
        plt.axhline(y=0, color='k', linestyle='-', alpha=0.2)
        plt.axvline(x=0, color='k', linestyle='-', alpha=0.2)
        plt.show()

In [None]:
# Interpret factors by grouping features
if 'loadings' in locals():
    print("Factor Interpretation:\n")
    
    for i in range(n_factors):
        print(f"Factor {i+1} Interpretation:")
        factor_col = f'Factor_{i+1}'
        
        # Group features by evaluation task and metric type
        eval_groups = {}
        metric_groups = {}
        
        for feature in loadings.index:
            loading = loadings.loc[feature, factor_col]
            if abs(loading) > 0.3:  # Only consider significant loadings
                # Parse feature name (format: eval_name_metric)
                parts = feature.split('_')
                if len(parts) >= 2:
                    # Assume format is evalcategory/envname_metric
                    eval_part = parts[0] if '/' not in parts[0] else parts[0].split('/')[0]
                    metric_part = '_'.join(parts[1:]) if len(parts) > 1 else parts[-1]
                    
                    if eval_part not in eval_groups:
                        eval_groups[eval_part] = []
                    eval_groups[eval_part].append((feature, loading))
                    
                    if metric_part not in metric_groups:
                        metric_groups[metric_part] = []
                    metric_groups[metric_part].append((feature, loading))
        
        # Show patterns
        print(f"  Evaluation categories with high loadings:")
        for eval_cat, features in sorted(eval_groups.items(), key=lambda x: -len(x[1]))[:5]:
            avg_loading = np.mean([f[1] for f in features])
            print(f"    {eval_cat}: {len(features)} features, avg loading: {avg_loading:.3f}")
        
        print(f"  Metric types with high loadings:")
        for metric, features in sorted(metric_groups.items(), key=lambda x: -len(x[1]))[:5]:
            avg_loading = np.mean([f[1] for f in features])
            print(f"    {metric}: {len(features)} occurrences, avg loading: {avg_loading:.3f}")
        print()

In [None]:
# Correlation between factors and average performance
if 'factors' in locals() and 'average_score' in df.columns:
    correlations = []
    for i in range(n_factors):
        corr = df[f'factor_{i+1}'].corr(df['average_score'])
        correlations.append(corr)
        print(f"Factor {i+1} correlation with average score: {corr:.3f}")
    
    # Visualize correlations
    plt.figure(figsize=(8, 6))
    plt.bar(range(1, n_factors + 1), correlations)
    plt.xlabel('Factor')
    plt.ylabel('Correlation with Average Score')
    plt.title('Factor Correlations with Overall Performance')
    plt.grid(True, alpha=0.3)
    plt.show()

In [None]:
# Save results
output_dir = Path("./factor_analysis_results")
output_dir.mkdir(exist_ok=True)

# Save policy factor scores
if 'factors' in locals():
    factor_cols = ['policy_name'] + [f'factor_{i+1}' for i in range(n_factors)] + ['average_score']
    df[factor_cols].to_csv(output_dir / "policy_factor_scores.csv", index=False)
    print(f"Saved policy factor scores to {output_dir / 'policy_factor_scores.csv'}")

# Save factor loadings
if 'loadings' in locals():
    loadings.to_csv(output_dir / "factor_loadings.csv")
    print(f"Saved factor loadings to {output_dir / 'factor_loadings.csv'}")

# Save summary statistics
if 'factors' in locals():
    with open(output_dir / "factor_analysis_summary.txt", "w") as f:
        f.write("Factor Analysis Summary\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Number of policies analyzed: {len(df)}\n")
        f.write(f"Number of features: {len(selected_features)}\n")
        f.write(f"Number of factors extracted: {n_factors}\n\n")
        
        f.write("Factor correlations with average score:\n")
        for i in range(n_factors):
            corr = df[f'factor_{i+1}'].corr(df['average_score'])
            f.write(f"  Factor {i+1}: {corr:.3f}\n")
        
        f.write("\nVariance explained (from PCA):\n")
        for i in range(min(n_factors, len(pca.explained_variance_ratio_))):
            f.write(f"  PC{i+1}: {pca.explained_variance_ratio_[i]:.3f}\n")
    
    print(f"Saved analysis summary to {output_dir / 'factor_analysis_summary.txt'}")

print(f"\nAll results saved to {output_dir}/")

In [None]:
# Interpret factors by looking at highest loading features
for i in range(n_factors):
    print(f"\nFactor {i+1} - Top loading features:")
    factor_col = f'Factor_{i+1}'
    top_positive = loadings[factor_col].nlargest(3)
    top_negative = loadings[factor_col].nsmallest(3)
    
    print("  Positive loadings:")
    for feat, loading in top_positive.items():
        print(f"    {feat}: {loading:.3f}")
    
    print("  Negative loadings:")
    for feat, loading in top_negative.items():
        print(f"    {feat}: {loading:.3f}")

In [None]:
# Save results
output_dir = Path("./factor_analysis_results")
output_dir.mkdir(exist_ok=True)

# Save factor scores
df.to_csv(output_dir / "policy_factor_scores.csv", index=False)

# Save loadings
loadings.to_csv(output_dir / "factor_loadings.csv")

print(f"Results saved to {output_dir}")