# Foundation Models Comparison: astroPT vs AION

This notebook compares embeddings from two different multimodal foundation models:

**astroPT Multimodal**: Transformer model trained on DESI spectra + Euclid images
- Checkpoint: iteration 21000
- Location: `/pbs/home/a/astroinfo09/logs/logs/astropt_multimodal_full_20251106_011934/`

**AION Multimodal**: Foundation model trained on Euclid images + DESI spectra
- Embeddings from Maxime's work
- Location: `/pbs/throng/training/astroinfo2025/work/maxime/data_all_tokens_spectrums.pt`

We compare how the two models encode spectral and imaging information in their embedding spaces.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import torch
from astropy.io import fits
from astropy.table import Table
import umap
import time
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity

# Configure matplotlib / seaborn styling shared by every figure below
plt.style.use('default')  # reset to matplotlib's default style first
sns.set_palette("husl")  # colour cycle used by plots that pass no explicit colour
%matplotlib inline

## 1. Load astroPT Multimodal Embeddings

Load the embeddings extracted from the astroPT multimodal model (spectra + images).

In [None]:
# Path to astroPT multimodal embeddings
embeddings_dir = Path("/pbs/home/a/astroinfo09/logs/logs/astropt_multimodal_full_20251106_011934/embeddings_output_21000")

print("Loading astroPT multimodal embeddings...")
print(f"Embeddings directory: {embeddings_dir}")

# Check directory contents (sorted so the listing is stable between runs)
if embeddings_dir.exists():
    files = sorted(embeddings_dir.glob("*.npy")) + sorted(embeddings_dir.glob("*.npz"))
    print(f"\nFound {len(files)} files:")
    for f in files[:10]:
        print(f"  {f.name}")
    if len(files) > 10:
        print(f"  ... and {len(files) - 10} more")
else:
    print(f"‚ö† WARNING: Directory not found")

# Try to load embeddings: first as separate .npy files, then as one .npz bundle
try:
    astropt_embeddings = np.load(embeddings_dir / "embeddings.npy")
    astropt_targetids = np.load(embeddings_dir / "targetids.npy")

    try:
        astropt_redshifts = np.load(embeddings_dir / "redshifts.npy")
    except FileNotFoundError:
        astropt_redshifts = None
        print("‚ö† Redshifts file not found, will extract from catalog")

    print(f"\n‚úì astroPT embeddings loaded: {astropt_embeddings.shape}")
    print(f"‚úì Target IDs: {len(astropt_targetids)}")
    if astropt_redshifts is not None:
        print(f"‚úì Redshifts: {len(astropt_redshifts)}")

except FileNotFoundError:
    print("\n‚ö† Standard files not found, trying .npz format...")
    try:
        # `with` closes the NpzFile handle; indexing it reads each array into memory first
        with np.load(embeddings_dir / "embeddings_with_metadata.npz") as data:
            astropt_embeddings = data['embeddings']
            astropt_targetids = data['target_ids']
            astropt_redshifts = data['redshifts'] if 'redshifts' in data.files else None

        print(f"\n‚úì Loaded from .npz file")
        print(f"‚úì Embeddings: {astropt_embeddings.shape}")
        print(f"‚úì Target IDs: {len(astropt_targetids)}")

    except FileNotFoundError:
        print("\n‚ùå Could not find embeddings. Please check directory structure.")
        raise

# Later cells index IDs / redshifts by embedding row, so the lengths must agree
if len(astropt_targetids) != len(astropt_embeddings):
    raise ValueError(
        f"astroPT target IDs ({len(astropt_targetids)}) and embeddings "
        f"({len(astropt_embeddings)}) have different lengths"
    )
if astropt_redshifts is not None and len(astropt_redshifts) != len(astropt_embeddings):
    raise ValueError(
        f"astroPT redshifts ({len(astropt_redshifts)}) and embeddings "
        f"({len(astropt_embeddings)}) have different lengths"
    )

## 2. Load AION Multimodal Embeddings

Load the embeddings from the AION model.

In [None]:
# Path to AION embeddings
aion_path = "/pbs/throng/training/astroinfo2025/work/maxime/data_all_tokens_spectrums.pt"

print("Loading AION embeddings...")
# `torch` is imported in the setup cell at the top of the notebook.
# NOTE: torch.load is pickle-based; only load files from trusted locations.
aion_data = torch.load(aion_path, map_location="cpu")

# Normalise to a list of per-object records (a single object becomes a one-element list)
aion_records = aion_data if isinstance(aion_data, list) else [aion_data]
print(f"‚úì AION data loaded: {len(aion_records)} records")

# Extract embeddings
def stack_embeddings(records, key):
    """Gather the value stored under ``key`` in each record into a single stacked array.

    Records that lack ``key`` (or hold None there) are skipped. Torch tensors are
    detached and moved to CPU; anything else goes through ``np.asarray``.

    Returns
    -------
    (np.ndarray, np.ndarray)
        The stacked vectors (axis 0 = kept records) and the positions of those
        records in ``records``.

    Raises
    ------
    ValueError
        If no record provides ``key``.
    """
    present = [(position, record.get(key)) for position, record in enumerate(records)]
    present = [(position, value) for position, value in present if value is not None]

    if len(present) == 0:
        raise ValueError(f"No embeddings found for key '{key}'")

    def _as_numpy(value):
        # Tensors may carry autograd / device state; strip both before converting
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    arrays = [_as_numpy(value) for _, value in present]
    positions = np.array([position for position, _ in present])
    return np.stack(arrays, axis=0), positions

# Extract AION multimodal embeddings (spectra + images).
# Keep the record positions returned by stack_embeddings: records without this key are
# skipped, so IDs / redshifts must be read from those same records to stay row-aligned.
aion_embeddings, aion_embedding_indices = stack_embeddings(aion_records, "embedding_hsc_desi")

print(f"‚úì AION multimodal embeddings: {aion_embeddings.shape}")


def _first_present(record, keys):
    """Return the first value in ``record`` under ``keys`` that is not None (else None).

    An ``a or b or c`` chain would also skip legitimate falsy values such as a
    redshift of 0.0 or an object ID of 0.
    """
    for key in keys:
        value = record.get(key)
        if value is not None:
            return value
    return None


def _as_scalar(value):
    """Unwrap single-element torch tensors / NumPy values into plain Python scalars."""
    if isinstance(value, torch.Tensor) and value.numel() == 1:
        return value.item()
    if isinstance(value, (np.ndarray, np.generic)) and value.size == 1:
        return value.item()
    return value


# Extract AION object IDs and redshifts for the embedded records only
aion_object_ids = []
aion_redshifts = []

for rec_idx in aion_embedding_indices:
    rec = aion_records[rec_idx]

    obj_id = _as_scalar(_first_present(rec, ('object_id', 'TARGETID', 'targetid')))
    aion_object_ids.append(obj_id if obj_id is not None else np.nan)

    z = _as_scalar(_first_present(rec, ('redshift', 'Z', 'z')))
    aion_redshifts.append(z if z is not None else np.nan)

# dtype=object keeps 64-bit integer IDs exact: a default np.array over ints mixed with
# NaN upcasts to float64, which rounds IDs above 2**53.
aion_object_ids = np.array(aion_object_ids, dtype=object)
aion_redshifts = np.array(aion_redshifts, dtype=float)

## 3. Load Catalog and Match Objects

Load the catalog to get physical properties for both sets of embeddings.

In [None]:
# Path to catalog
catalog_path = "/pbs/throng/training/astroinfo2025/data/astroPT_euclid_desi_dataset/desi_euclid_catalog.fits"

print("Loading catalog...")
with fits.open(catalog_path) as hdul:
    catalog = Table(hdul[1].data)

print(f"‚úì Catalog loaded: {len(catalog)} entries")

# Find ID column
id_column = None
for col in ['TARGETID', 'targetid', 'object_id', 'OBJECT_ID']:
    if col in catalog.colnames:
        id_column = col
        print(f"‚úì Using catalog ID column: {id_column}")
        break

# Create ID mapping
catalog_ids = np.array(catalog[id_column], dtype=float)
id_to_idx = {int(cid): i for i, cid in enumerate(catalog_ids) if not np.isnan(cid)}

In [None]:
def match_to_catalog(object_ids, id_to_idx):
    """Match object IDs to catalog indices."""
    matched_indices = []
    unmatched_count = 0
    
    for obj_id in object_ids:
        if obj_id is None or obj_id == '' or (isinstance(obj_id, float) and np.isnan(obj_id)):
            matched_indices.append(-1)
            unmatched_count += 1
        else:
            try:
                obj_id_int = int(float(obj_id))
                if obj_id_int in id_to_idx:
                    matched_indices.append(id_to_idx[obj_id_int])
                else:
                    matched_indices.append(-1)
                    unmatched_count += 1
            except (ValueError, TypeError):
                matched_indices.append(-1)
                unmatched_count += 1
    
    return np.array(matched_indices), unmatched_count

# Match astroPT
astropt_matched, astropt_unmatched = match_to_catalog(astropt_targetids, id_to_idx)
astropt_match_rate = (len(astropt_matched) - astropt_unmatched) / len(astropt_targetids) * 100

print(f"\nastroPT matching:")
print(f"  Matched: {len(astropt_matched) - astropt_unmatched}/{len(astropt_targetids)} ({astropt_match_rate:.1f}%)")

# Match AION
aion_matched, aion_unmatched = match_to_catalog(aion_object_ids, id_to_idx)
aion_match_rate = (len(aion_matched) - aion_unmatched) / len(aion_object_ids) * 100

print(f"\nAION matching:")
print(f"  Matched: {len(aion_matched) - aion_unmatched}/{len(aion_object_ids)} ({aion_match_rate:.1f}%)")

In [None]:
# Extract physical properties for both datasets
def safe_extract(matched_indices, column_name, table=None):
    """Pull one catalog column onto a matched-index array, with NaN for unmatched entries.

    Parameters
    ----------
    matched_indices : array-like of int
        Catalog row for each object, -1 where the object was not matched.
    column_name : str
        Column to read; if the table lacks it, an all-NaN array is returned.
    table : table-like, optional
        Object with ``colnames`` and column indexing. Defaults to the notebook's
        global ``catalog``.

    Returns
    -------
    np.ndarray of float
        Same length as ``matched_indices``.
    """
    if table is None:
        table = catalog

    matched_indices = np.asarray(matched_indices, dtype=int)
    values = np.full(len(matched_indices), np.nan, dtype=float)
    if column_name not in table.colnames:
        return values

    mask = matched_indices >= 0
    if mask.any():
        values[mask] = table[column_name][matched_indices[mask]]
    return values

def _catalog_redshift(matched_indices):
    """Catalog redshifts for the matched rows: 'Z', or 'REDSHIFT' when 'Z' yields only NaN."""
    redshift = safe_extract(matched_indices, 'Z')
    if np.all(np.isnan(redshift)):
        redshift = safe_extract(matched_indices, 'REDSHIFT')
    return redshift


# --- astroPT: use the catalog only when no redshift file was loaded ---
if astropt_redshifts is None:
    astropt_redshifts = _catalog_redshift(astropt_matched)

astropt_logm, astropt_logsfr, astropt_dn4000, astropt_gr = (
    safe_extract(astropt_matched, column) for column in ('LOGM', 'LOGSFR', 'DN4000', 'GR')
)
# log sSFR = log SFR - log M*
astropt_ssfr = astropt_logsfr - astropt_logm

# --- AION: fall back to the catalog when no record carried a usable redshift ---
if np.all(np.isnan(aion_redshifts)):
    aion_redshifts = _catalog_redshift(aion_matched)

aion_logm, aion_logsfr, aion_dn4000, aion_gr = (
    safe_extract(aion_matched, column) for column in ('LOGM', 'LOGSFR', 'DN4000', 'GR')
)
aion_ssfr = aion_logsfr - aion_logm

print("\n‚úì Physical properties extracted for both datasets")

## 4. Find Common Objects

Identify objects that appear in both astroPT and AION datasets for direct comparison.

In [None]:
# Convert to int for matching
def to_int_array(obj_ids):
    """Convert object IDs to an int64 array, using -1 for missing or unparseable IDs.

    Integer inputs (Python/NumPy ints, digit strings) are converted exactly.
    Going through ``float`` first would round IDs above 2**53, so distinct 64-bit
    IDs could compare equal. Float inputs are truncated toward zero. None, '',
    NaN/inf and non-numeric strings become -1.
    """
    def _parse(obj_id):
        if obj_id is None:
            return -1
        try:
            return int(obj_id)
        except (ValueError, TypeError, OverflowError):
            pass
        try:
            value = float(obj_id)
        except (ValueError, TypeError):
            return -1
        return int(value) if np.isfinite(value) else -1

    # Explicit dtype keeps the result an integer array even for empty input
    return np.array([_parse(obj_id) for obj_id in obj_ids], dtype=np.int64)

astropt_ids_int = to_int_array(astropt_targetids)
aion_ids_int = to_int_array(aion_object_ids)

# Find common IDs (to_int_array marks missing / unparseable IDs with -1)
astropt_valid = astropt_ids_int >= 0
aion_valid = aion_ids_int >= 0

common_ids = np.intersect1d(astropt_ids_int[astropt_valid], aion_ids_int[aion_valid])

print(f"\n{'='*60}")
print(f"COMMON OBJECTS ANALYSIS")
print(f"{'='*60}")
print(f"astroPT total objects: {astropt_valid.sum()}")
print(f"AION total objects: {aion_valid.sum()}")
print(f"Common objects: {len(common_ids)}")
print(f"{'='*60}")

if len(common_ids) > 0:
    # Membership masks over each full dataset
    astropt_common_mask = np.isin(astropt_ids_int, common_ids)
    aion_common_mask = np.isin(aion_ids_int, common_ids)

    # ID -> position within each dataset's masked subset
    astropt_id_to_idx = {tid: i for i, tid in enumerate(astropt_ids_int[astropt_common_mask])}
    aion_id_to_idx = {tid: i for i, tid in enumerate(aion_ids_int[aion_common_mask])}

    # Row-aligned indices: entry k of both arrays is the row holding common_ids[k] in that
    # dataset (its first row if the ID is duplicated). Taking np.where(mask) per dataset
    # would list rows in each file's own order and pair up different objects.
    astropt_unique_ids, astropt_first_rows = np.unique(astropt_ids_int, return_index=True)
    aion_unique_ids, aion_first_rows = np.unique(aion_ids_int, return_index=True)
    astropt_common_indices = astropt_first_rows[np.searchsorted(astropt_unique_ids, common_ids)]
    aion_common_indices = aion_first_rows[np.searchsorted(aion_unique_ids, common_ids)]

    print(f"\n‚úì Found {len(common_ids)} common objects for direct comparison")
else:
    print("\n‚ö† WARNING: No common objects found between astroPT and AION datasets")
    print("   Will proceed with separate visualizations only")

## 5. Compute UMAP Projections

Apply UMAP to the multimodal embeddings of both models (astroPT and AION) for visualization.

In [None]:
def compute_umap_projection(embeddings, name, random_state=42, n_neighbors=15,
                            min_dist=0.1, metric='cosine'):
    """Fit a 2-D UMAP on ``embeddings`` and return the projected coordinates.

    Parameters
    ----------
    embeddings : array-like
        One row per object.
    name : str
        Label used in the progress messages.
    random_state : int, default 42
        Seed for a reproducible layout (umap-learn disables parallelism when a seed is set).
    n_neighbors, min_dist, metric
        Passed straight to ``umap.UMAP``. The defaults are the values this notebook
        has always used.

    Returns
    -------
    np.ndarray
        Shape (n_objects, 2).
    """
    print(f"Running UMAP for {name}...")

    umap_model = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=2,
        metric=metric,
        random_state=random_state
    )

    # perf_counter is monotonic, so it is the right clock for measuring durations
    start_time = time.perf_counter()
    projection = umap_model.fit_transform(embeddings)
    elapsed = time.perf_counter() - start_time

    print(f"‚úì {name} UMAP completed in {elapsed:.2f} seconds")
    return projection

print("Computing UMAP projections...\n")

# Compute UMAP for both models' multimodal embeddings
umap_astropt = compute_umap_projection(astropt_embeddings, "astroPT multimodal")
umap_aion = compute_umap_projection(aion_embeddings, "AION multimodal")

print("\n‚úì All UMAP projections completed!")

## 6. Comparison Visualization - Redshift

Side-by-side comparison of embeddings from both models, colored by redshift.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

projections = [umap_astropt, umap_aion]
redshift_data = [
    np.asarray(astropt_redshifts, dtype=float),
    np.asarray(aion_redshifts, dtype=float),
]
# NOTE(review): the AION vectors come from the 'embedding_hsc_desi' key; confirm the
# image modality really is Euclid before relying on this panel title.
titles = [
    'astroPT Multimodal\n(DESI Spectra + Euclid Images)',
    'AION Multimodal\n(DESI Spectra + Euclid Images)'
]

# Colour values must be row-aligned with their projection, otherwise points are mis-coloured
for proj, z_data, title in zip(projections, redshift_data, titles):
    if len(z_data) != len(proj):
        raise ValueError(
            f"{title.splitlines()[0]}: {len(z_data)} redshifts for {len(proj)} projected points"
        )

# One colour scale for both panels so that equal colours mean equal redshift
all_valid_z = np.concatenate([z[~np.isnan(z)] for z in redshift_data])
shared_norm = {'vmin': all_valid_z.min(), 'vmax': all_valid_z.max()} if all_valid_z.size else {}

for ax, proj, z_data, title in zip(axes, projections, redshift_data, titles):
    valid_z = ~np.isnan(z_data)

    if valid_z.sum() > 0:
        scatter = ax.scatter(
            proj[valid_z, 0],
            proj[valid_z, 1],
            c=z_data[valid_z],
            cmap='viridis',
            s=8,
            alpha=0.6,
            edgecolors='none',
            **shared_norm
        )
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('Redshift', fontsize=11)

        # Objects without a redshift are drawn faintly in grey underneath
        if (~valid_z).sum() > 0:
            ax.scatter(
                proj[~valid_z, 0],
                proj[~valid_z, 1],
                s=5,
                color='lightgray',
                alpha=0.3,
                edgecolors='none'
            )
    else:
        ax.scatter(proj[:, 0], proj[:, 1], s=8, alpha=0.6)

    ax.set_xlabel('UMAP 1', fontsize=12)
    ax.set_ylabel('UMAP 2', fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.grid(alpha=0.2)

fig.suptitle('Multimodal Foundation Models Comparison (colored by redshift)', fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('model_comparison_redshift.png', dpi=300, bbox_inches='tight')
plt.show()

print("Saved: model_comparison_redshift.png")

## 7. Comparison Visualization - Physical Properties

Compare how different models encode physical properties in their embedding spaces.

In [None]:
def _property_arrays(property_name):
    """Return the [astroPT, AION] arrays matching a display name, or None if unrecognised.

    Keywords are tested in this order, so e.g. the LOGM test runs before the SFR variants.
    """
    if 'LOGM' in property_name or 'Stellar Mass' in property_name:
        return [astropt_logm, aion_logm]
    if 'LOGSFR' in property_name or 'Formation Rate' in property_name:
        return [astropt_logsfr, aion_logsfr]
    if 'sSFR' in property_name or 'Specific' in property_name:
        return [astropt_ssfr, aion_ssfr]
    if 'DN4000' in property_name or '4000' in property_name:
        return [astropt_dn4000, aion_dn4000]
    if 'g-r' in property_name or 'Color' in property_name:
        return [astropt_gr, aion_gr]
    return None


def plot_property_comparison_2models(property_name, cmap, output_file, properties=None):
    """Create a 2-panel (astroPT | AION) UMAP comparison coloured by one physical property.

    Parameters
    ----------
    property_name : str
        Colour-bar label and figure title. When ``properties`` is not given, it is
        also used to pick the property arrays.
    cmap : str
        Matplotlib colormap name.
    output_file : str
        Path the PNG is saved to.
    properties : sequence of two arrays, optional
        Explicit [astroPT, AION] values, row-aligned with ``umap_astropt`` / ``umap_aion``.

    Raises
    ------
    ValueError
        If ``property_name`` is unrecognised and no ``properties`` are given, or if an
        array's length differs from its projection.
    """
    if properties is None:
        properties = _property_arrays(property_name)
    if properties is None:
        # Previously an unknown name silently plotted stellar mass under the wrong label
        raise ValueError(f"Unrecognised property name: {property_name!r}; pass `properties` explicitly")

    projections = [umap_astropt, umap_aion]
    properties = [np.asarray(p, dtype=float) for p in properties]
    titles = ['astroPT Multimodal', 'AION Multimodal']

    for proj, prop_data, title in zip(projections, properties, titles):
        if len(prop_data) != len(proj):
            raise ValueError(f"{title}: {len(prop_data)} values for {len(proj)} projected points")

    # Global colour range across both models for consistent colouring. Concatenating the
    # (possibly empty) valid slices of both arrays never passes an empty list, so an
    # all-NaN property reaches the "no valid data" branch instead of crashing.
    all_data = np.concatenate([p[~np.isnan(p)] for p in properties])

    if len(all_data) == 0:
        print(f"‚ö† No valid data for {property_name}, skipping")
        return

    # 2nd-98th percentiles keep outliers from washing out the colour scale
    vmin = np.percentile(all_data, 2)
    vmax = np.percentile(all_data, 98)

    fig, axes = plt.subplots(1, 2, figsize=(16, 7))

    for ax, proj, prop_data, title in zip(axes, projections, properties, titles):
        valid_mask = ~np.isnan(prop_data)

        if valid_mask.sum() > 0:
            scatter = ax.scatter(
                proj[valid_mask, 0],
                proj[valid_mask, 1],
                c=prop_data[valid_mask],
                cmap=cmap,
                s=8,
                alpha=0.6,
                edgecolors='none',
                vmin=vmin,
                vmax=vmax
            )
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label(property_name, fontsize=11)

            # Plot objects without a value faintly in grey
            if (~valid_mask).sum() > 0:
                ax.scatter(
                    proj[~valid_mask, 0],
                    proj[~valid_mask, 1],
                    s=5,
                    color='lightgray',
                    alpha=0.3,
                    edgecolors='none'
                )
        else:
            # No valid values for this model: still show the layout, in grey
            ax.scatter(proj[:, 0], proj[:, 1], s=5, color='lightgray', alpha=0.3, edgecolors='none')

        ax.set_xlabel('UMAP 1', fontsize=12)
        ax.set_ylabel('UMAP 2', fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.grid(alpha=0.2)

    fig.suptitle(f'Model Comparison: {property_name}', fontsize=15, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved: {output_file}")

# Generate comparison plots for different properties
property_configs = [
    ('log(M*/M‚òâ) - Stellar Mass', 'plasma', 'model_comparison_logm.png'),
    ('log(SFR) [M‚òâ/yr] - Star Formation Rate', 'coolwarm', 'model_comparison_logsfr.png'),
    ('log(sSFR) [yr‚Åª¬π] - Specific SFR', 'coolwarm', 'model_comparison_ssfr.png'),
    ('DN4000 (4000√Ö break)', 'RdYlBu_r', 'model_comparison_dn4000.png'),
    ('g-r Color [mag]', 'RdBu_r', 'model_comparison_gr_color.png'),
]

for prop_name, cmap, output_file in property_configs:
    plot_property_comparison_2models(prop_name, cmap, output_file)

## 8. Correlation Analysis - Model Comparison

Compare how strongly physical properties correlate (Spearman rank correlation) with the two UMAP coordinates of each model's embedding projection.

In [None]:
def correlation_analysis(umap_proj, properties_dict, name, min_samples=100):
    """Compute Spearman correlations between UMAP coordinates and physical properties.

    Prints a table with the correlation against UMAP-1 and UMAP-2 (with
    significance stars) for each property.

    Parameters
    ----------
    umap_proj : array-like
        Projection with at least two columns. Row i must describe the same object
        as element i of every property array.
    properties_dict : dict[str, array-like]
        Property name -> values (NaN marks a missing value).
    name : str
        Label for the printed header.
    min_samples : int, default 100
        Properties with at most ``min_samples`` valid values are reported as N/A.

    Returns
    -------
    dict[str, float]
        Property name -> max(|rho_UMAP1|, |rho_UMAP2|), ignoring undefined (NaN)
        correlations. NaN if there are too few values or both correlations are undefined.

    Raises
    ------
    ValueError
        If a property array's length differs from the number of projection rows.
    """
    def _stars(pval):
        # NaN p-values compare False everywhere and therefore get no stars
        return "***" if pval < 0.001 else "**" if pval < 0.01 else "*" if pval < 0.05 else ""

    umap_proj = np.asarray(umap_proj)

    print(f"\n{'='*70}")
    print(f"CORRELATION ANALYSIS: {name}")
    print(f"{'='*70}")

    print(f"\n{'Property':<25} {'UMAP-1':>10} {'UMAP-2':>10} {'|corr|':>10}")
    print("-" * 70)

    results = {}

    for prop_name, prop_data in properties_dict.items():
        prop_data = np.asarray(prop_data, dtype=float)
        if len(prop_data) != len(umap_proj):
            raise ValueError(
                f"{prop_name}: {len(prop_data)} values for {len(umap_proj)} projected points"
            )

        valid_mask = ~np.isnan(prop_data)
        if valid_mask.sum() > min_samples:
            corr1, pval1 = spearmanr(umap_proj[valid_mask, 0], prop_data[valid_mask])
            corr2, pval2 = spearmanr(umap_proj[valid_mask, 1], prop_data[valid_mask])
            sig1 = _stars(pval1)
            sig2 = _stars(pval2)

            # Plain max() with a NaN argument is order-dependent; drop undefined correlations
            finite_corrs = [abs(c) for c in (corr1, corr2) if np.isfinite(c)]
            max_corr = max(finite_corrs) if finite_corrs else np.nan
            results[prop_name] = max_corr

            print(f"{prop_name:<25} {corr1:>9.3f}{sig1} {corr2:>9.3f}{sig2} {max_corr:>10.3f}")
        else:
            results[prop_name] = np.nan
            print(f"{prop_name:<25} {'N/A':>10} {'N/A':>10} {'N/A':>10}")

    print("\nSignificance: *** p<0.001, ** p<0.01, * p<0.05")
    print("="*70)

    return results

# Analyze correlations for all models
astropt_properties = {
    'Redshift': astropt_redshifts,
    'LOGM': astropt_logm,
    'LOGSFR': astropt_logsfr,
    'sSFR': astropt_ssfr,
    'DN4000': astropt_dn4000,
    'g-r Color': astropt_gr,
}

aion_properties = {
    'Redshift': aion_redshifts,
    'LOGM': aion_logm,
    'LOGSFR': aion_logsfr,
    'sSFR': aion_ssfr,
    'DN4000': aion_dn4000,
    'g-r Color': aion_gr,
}

corr_astropt = correlation_analysis(umap_astropt, astropt_properties, "astroPT (Spectra)")
corr_aion_multi = correlation_analysis(umap_aion_multimodal, aion_properties, "AION (Multimodal)")
corr_aion_img = correlation_analysis(umap_aion_image, aion_properties, "AION (Image-only)")

In [None]:
# Visualize correlation comparison
fig, ax = plt.subplots(figsize=(12, 7))

properties_list = list(corr_astropt.keys())
x = np.arange(len(properties_list))
width = 0.25

astropt_vals = [corr_astropt[p] for p in properties_list]
aion_multi_vals = [corr_aion_multi[p] for p in properties_list]
aion_img_vals = [corr_aion_img[p] for p in properties_list]

ax.bar(x - width, astropt_vals, width, label='astroPT (Spectra)', alpha=0.8, color='#1f77b4')
ax.bar(x, aion_multi_vals, width, label='AION (Multimodal)', alpha=0.8, color='#ff7f0e')
ax.bar(x + width, aion_img_vals, width, label='AION (Image-only)', alpha=0.8, color='#2ca02c')

ax.set_xlabel('Physical Property', fontsize=12)
ax.set_ylabel('Max |Spearman Correlation|', fontsize=12)
ax.set_title('Model Comparison: Correlation with Physical Properties', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(properties_list, rotation=45, ha='right')
ax.legend(fontsize=11)
ax.grid(axis='y', alpha=0.3)
ax.set_ylim(0, max(max(astropt_vals), max(aion_multi_vals), max(aion_img_vals)) * 1.1)

plt.tight_layout()
plt.savefig('model_comparison_correlations.png', dpi=300, bbox_inches='tight')
plt.show()

print("Saved: model_comparison_correlations.png")

## 9. Direct Comparison on Common Objects

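For objects present in both datasets, compare the embeddings directly via per-object cosine similarity. This is only meaningful when both models produce vectors of the same dimension in a shared coordinate system. Independently trained models generally do not, so a low value does not by itself show that the models encode different information.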

In [None]:
if len(common_ids) > 0:
    print(f"Analyzing {len(common_ids)} common objects...\n")

    # Extract embeddings for common objects. Row k of both arrays must describe the same
    # object, i.e. astropt_common_indices[k] and aion_common_indices[k] must share an ID.
    astropt_common_emb = np.asarray(astropt_embeddings[astropt_common_indices], dtype=np.float64)
    aion_common_emb = np.asarray(aion_embeddings[aion_common_indices], dtype=np.float64)

    same_layout = (
        astropt_common_emb.ndim == 2
        and aion_common_emb.ndim == 2
        and astropt_common_emb.shape[1] == aion_common_emb.shape[1]
    )

    if not same_layout:
        # Cosine similarity needs equal-length vectors; skip rather than crash
        print(f"WARNING: embedding shapes differ (astroPT {astropt_common_emb.shape[1:]} vs "
              f"AION {aion_common_emb.shape[1:]}); cross-model cosine similarity is undefined, skipping")
        cosine_sims = np.array([])
    else:
        # NOTE(review): a raw cross-model cosine is only interpretable if both models share a
        # coordinate system; independently trained embeddings generally do not.
        # Vectorised row-wise cosine similarity. Zero vectors get similarity 0 (as sklearn's
        # cosine_similarity does) rather than NaN.
        astropt_norms = np.linalg.norm(astropt_common_emb, axis=1, keepdims=True)
        aion_norms = np.linalg.norm(aion_common_emb, axis=1, keepdims=True)
        astropt_unit = astropt_common_emb / np.where(astropt_norms == 0, 1.0, astropt_norms)
        aion_unit = aion_common_emb / np.where(aion_norms == 0, 1.0, aion_norms)
        cosine_sims = np.sum(astropt_unit * aion_unit, axis=1)

        print("Cosine Similarity (astroPT vs AION multimodal):")
        print(f"  Mean: {cosine_sims.mean():.4f}")
        print(f"  Std: {cosine_sims.std():.4f}")
        print(f"  Min: {cosine_sims.min():.4f}")
        print(f"  Max: {cosine_sims.max():.4f}")
        print(f"  Median: {np.median(cosine_sims):.4f}")

        # Plot distribution
        fig, ax = plt.subplots(figsize=(10, 6))

        ax.hist(cosine_sims, bins=50, alpha=0.7, edgecolor='black', color='steelblue')
        ax.axvline(cosine_sims.mean(), color='red', linestyle='--', linewidth=2,
                   label=f'Mean = {cosine_sims.mean():.3f}')
        ax.axvline(np.median(cosine_sims), color='orange', linestyle='--', linewidth=2,
                   label=f'Median = {np.median(cosine_sims):.3f}')

        ax.set_xlabel('Cosine Similarity', fontsize=13)
        ax.set_ylabel('Count', fontsize=13)
        ax.set_title('Cosine Similarity: astroPT vs AION Multimodal\n(Common Objects)',
                     fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(alpha=0.3)

        plt.tight_layout()
        plt.savefig('model_comparison_cosine_similarity.png', dpi=300, bbox_inches='tight')
        plt.show()

        print("\nSaved: model_comparison_cosine_similarity.png")
else:
    print("‚ö† Skipping common objects analysis (no common objects found)")

## 10. Summary and Conclusions

Summarize the key findings from the model comparison.

In [None]:
print("="*80)
print("MULTIMODAL FOUNDATION MODELS COMPARISON SUMMARY")
print("="*80)

print("\nüìä DATASET SIZES:")
print(f"  ‚Ä¢ astroPT multimodal embeddings: {astropt_embeddings.shape}")
print(f"  ‚Ä¢ AION multimodal embeddings: {aion_embeddings.shape}")
print(f"  ‚Ä¢ Common objects: {len(common_ids) if len(common_ids) > 0 else 'None'}")

print("\nüîç KEY FINDINGS:")

print("\n1. astroPT MULTIMODAL:")
print("   Transformer architecture trained on spectra + images")
print("   Checkpoint: iteration 21000")

print("\n2. AION MULTIMODAL:")
print("   Foundation model trained on images + spectra")
print("   Embeddings from multimodal training")

# Only report cosine results when some were computed: the mean of an empty array is NaN.
# `and` short-circuits, so cosine_sims is only read when common objects exist.
has_cosine = len(common_ids) > 0 and np.size(cosine_sims) > 0

if has_cosine:
    mean_sim = float(np.mean(cosine_sims))
    print("\n3. DIRECT COMPARISON (Common Objects):")
    print(f"   Average cosine similarity: {mean_sim:.3f}")
    # NOTE(review): thresholds are heuristic, and raw cross-model cosine assumes a shared basis
    if mean_sim > 0.7:
        print("   ‚Üí High similarity: Models encode similar information")
    elif mean_sim > 0.4:
        print("   ‚Üí Moderate similarity: Models capture complementary features")
    else:
        print("   ‚Üí Low similarity: Models focus on different aspects")

print("\n" + "="*80)
print("‚úì Comparison analysis complete! All figures saved.")
print("="*80)

print("\nüìÅ Generated files:")
print("  ‚Ä¢ model_comparison_redshift.png")
print("  ‚Ä¢ model_comparison_logm.png")
print("  ‚Ä¢ model_comparison_logsfr.png")
print("  ‚Ä¢ model_comparison_ssfr.png")
print("  ‚Ä¢ model_comparison_dn4000.png")
print("  ‚Ä¢ model_comparison_gr_color.png")
print("  ‚Ä¢ model_comparison_correlations.png")
if has_cosine:
    print("  ‚Ä¢ model_comparison_cosine_similarity.png")