# astroPT Multimodal Embeddings Analysis

This notebook analyzes embeddings extracted from the **astroPT multimodal model**.

astroPT is a transformer model trained on both DESI spectra and Euclid images. This analysis uses the multimodal checkpoint that combines spectral and imaging information.

We use UMAP dimensionality reduction to visualize the high-dimensional embeddings and analyze:
- How embeddings correlate with physical galaxy properties
- Spectral type separation (GALAXY, QSO, STAR)
- The model's learned representation of multimodal astronomical data

**Model checkpoint**: `/pbs/home/a/astroinfo09/logs/logs/astropt_multimodal_full_20251106_011934/ckpt_iter_21000.pt`  
**Embeddings**: `/pbs/home/a/astroinfo09/logs/logs/astropt_multimodal_full_20251106_011934/embeddings_output_21000/`

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from astropy.io import fits
from astropy.table import Table
import umap
import time
from scipy.stats import spearmanr

# Configure matplotlib
# Order matters: resetting to the 'default' style restores all rcParams
# (including the colour cycle), so the seaborn palette is applied afterwards.
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline

## 1. Load astroPT Embeddings

Load the embeddings extracted from the multimodal astroPT model.

In [None]:
# Path to astroPT multimodal embeddings
embeddings_dir = Path("/pbs/home/a/astroinfo09/logs/logs/astropt_multimodal_full_20251106_011934/embeddings_output_21000")

print("Loading astroPT multimodal embeddings...")
print(f"Embeddings directory: {embeddings_dir}")

# List what the extraction run produced, to help adapt the loader below when
# the layout differs from the two formats it understands.
if embeddings_dir.exists():
    files = sorted(embeddings_dir.glob("*.npy")) + sorted(embeddings_dir.glob("*.npz"))
    print(f"\nFound {len(files)} files:")
    for f in files[:10]:  # Show first 10 files
        print(f"  {f.name}")
    if len(files) > 10:
        print(f"  ... and {len(files) - 10} more")
else:
    print(f"⚠ WARNING: Directory not found: {embeddings_dir}")

# Supported layouts:
#   1. embeddings.npy + targetids.npy (+ optional redshifts.npy)
#   2. embeddings_with_metadata.npz with keys 'embeddings', 'target_ids'
#      and optionally 'redshifts'
try:
    embeddings = np.load(embeddings_dir / "embeddings.npy")
    target_ids = np.load(embeddings_dir / "targetids.npy")

    # Redshifts are optional; None means "take them from the catalog later"
    try:
        redshifts = np.load(embeddings_dir / "redshifts.npy")
    except FileNotFoundError:
        redshifts = None
        print("⚠ Redshifts file not found, will extract from catalog")

    print("\n✓ Loaded from individual .npy files")

except FileNotFoundError:
    print("\n⚠ Standard embedding files not found.")
    print("\nTrying alternative formats...")

    try:
        # The context manager closes the archive's file handle. Indexing an
        # NpzFile reads the array into memory, so the arrays remain valid afterwards.
        with np.load(embeddings_dir / "embeddings_with_metadata.npz") as data:
            embeddings = data['embeddings']
            target_ids = data['target_ids']
            redshifts = data['redshifts'] if 'redshifts' in data.files else None
        print("\n✓ Loaded from .npz file")

    except FileNotFoundError:
        print("\n❌ Could not find embeddings in expected formats.")
        print("\nPlease check the directory structure and update the loading code.")
        raise

# Later cells index embeddings, IDs and redshifts with the same boolean masks,
# so the arrays must be row-aligned.
if len(target_ids) != len(embeddings):
    raise ValueError(
        f"Embeddings ({len(embeddings)} rows) and target IDs ({len(target_ids)}) are not aligned"
    )
if redshifts is not None and len(redshifts) != len(embeddings):
    raise ValueError(
        f"Embeddings ({len(embeddings)} rows) and redshifts ({len(redshifts)}) are not aligned"
    )

print(f"✓ Embeddings: {embeddings.shape}")
print(f"✓ Target IDs: {len(target_ids)}")
if redshifts is not None:
    print(f"✓ Redshifts: {len(redshifts)}")
    # nanmin/nanmax so that NaN placeholders do not turn the range into NaN
    if np.isfinite(redshifts).any():
        print(f"  Range: [{np.nanmin(redshifts):.3f}, {np.nanmax(redshifts):.3f}]")

## 2. Load Catalog and Match Objects

Load the DESI-Euclid catalog to get physical properties for each object.

In [None]:
# Path to the combined DESI-Euclid catalog
catalog_path = "/pbs/throng/training/astroinfo2025/data/astroPT_euclid_desi_dataset/desi_euclid_catalog.fits"

print("Loading catalog...")
# Table(...) copies the HDU data by default, so the table stays usable after
# the FITS file is closed by the context manager.
with fits.open(catalog_path) as hdul:
    catalog = Table(hdul[1].data)

print(f"✓ Catalog loaded: {len(catalog)} entries")
print(f"✓ Available columns ({len(catalog.colnames)}):")
for col in catalog.colnames[:20]:  # Show first 20 columns
    print(f"  {col}")
if len(catalog.colnames) > 20:
    print(f"  ... and {len(catalog.colnames) - 20} more")

In [None]:
# Find ID column in catalog: the first candidate name present is used
id_candidates = ['TARGETID', 'targetid', 'object_id', 'OBJECT_ID']
id_column = next((col for col in id_candidates if col in catalog.colnames), None)

if id_column is None:
    # Every later cell needs `matched_indices` / `match_rate`; stop here with a
    # clear message instead of warning and failing later with a NameError.
    raise KeyError(
        f"Could not find an ID column in the catalog (tried {id_candidates}). "
        f"Available columns: {catalog.colnames[:10]}"
    )
print(f"Using catalog ID column: {id_column}")

# Map each catalog object ID to its row index. For repeated IDs the dict keeps
# the last row seen.
catalog_ids = np.asarray(catalog[id_column], dtype=np.int64)
id_to_idx = {int(cid): i for i, cid in enumerate(catalog_ids)}

print(f"\nCatalog ID mapping created: {len(id_to_idx)} unique IDs")
n_duplicate_ids = len(catalog_ids) - len(id_to_idx)
if n_duplicate_ids > 0:
    print(f"⚠ {n_duplicate_ids} catalog rows repeat an earlier ID; the last row for each ID is used")

# Catalog row for every embedding, -1 when the object is not in the catalog.
# The explicit int64 dtype keeps the array usable as an index even when empty.
target_ids_int = np.asarray(target_ids).astype(np.int64)
matched_indices = np.array(
    [id_to_idx.get(int(tid), -1) for tid in target_ids_int],
    dtype=np.int64,
)

matched_count = int((matched_indices >= 0).sum())
unmatched_count = len(matched_indices) - matched_count
# Guard against an empty embedding set
match_rate = 100 * matched_count / len(target_ids) if len(target_ids) > 0 else 0.0

print(f"\n{'='*60}")
print("MATCHING RESULTS")
print(f"{'='*60}")
print(f"✓ Matched: {matched_count}/{len(target_ids)} ({match_rate:.1f}%)")

if unmatched_count > 0:
    print(f"⚠ Unmatched: {unmatched_count} ({100-match_rate:.1f}%)")
print(f"{'='*60}")

## 3. Extract Physical Properties from Catalog

In [None]:
# Extract physical properties from catalog
def safe_extract(column_name, table=None, indices=None):
    """Extract a numeric catalog column aligned row-by-row with the embeddings.

    Parameters
    ----------
    column_name : str
        Name of the catalog column to read.
    table : optional
        Source table; defaults to the module-level ``catalog``. Only
        ``table.colnames`` and ``table[column_name]`` are used.
    indices : np.ndarray of int, optional
        Catalog row for each embedding, -1 for unmatched objects; defaults to
        the module-level ``matched_indices``.

    Returns
    -------
    np.ndarray of float
        One value per entry of ``indices``. Unmatched rows and masked catalog
        entries are NaN. A missing column gives an all-NaN array and prints a
        warning.
    """
    if table is None:
        table = catalog
    if indices is None:
        indices = matched_indices

    values = np.full(len(indices), np.nan, dtype=float)
    if column_name not in table.colnames:
        print(f"⚠ Column '{column_name}' not found in catalog")
        return values

    matched = indices >= 0
    # Routing through a masked array turns masked (missing) entries into NaN
    # rather than exposing whatever fill value is stored underneath.
    column_values = np.ma.asarray(table[column_name][indices[matched]], dtype=float)
    values[matched] = np.ma.filled(column_values, np.nan)
    return values

def safe_extract_string(column_name, table=None, indices=None):
    """Extract a string catalog column aligned row-by-row with the embeddings.

    Parameters
    ----------
    column_name : str
        Name of the catalog column to read.
    table : optional
        Source table; defaults to the module-level ``catalog``. Only
        ``table.colnames`` and ``table[column_name]`` are used.
    indices : np.ndarray of int, optional
        Catalog row for each embedding, -1 for unmatched objects; defaults to
        the module-level ``matched_indices``.

    Returns
    -------
    np.ndarray of object
        One stripped ``str`` per entry of ``indices``. Unmatched rows and
        masked entries are ''. A missing column gives an all-'' array and
        prints a warning.
    """
    if table is None:
        table = catalog
    if indices is None:
        indices = matched_indices

    values = np.full(len(indices), '', dtype=object)
    if column_name not in table.colnames:
        print(f"⚠ Column '{column_name}' not found in catalog")
        return values

    def _clean(entry):
        # Masked entries get the same '' marker as unmatched rows
        if entry is np.ma.masked:
            return ''
        # FITS string columns may yield bytes; decode so that comparisons with
        # str labels such as 'GALAXY' work
        if isinstance(entry, bytes):
            entry = entry.decode('utf-8', errors='replace')
        return str(entry).strip()

    matched = indices >= 0
    values[matched] = [_clean(entry) for entry in table[column_name][indices[matched]]]
    return values

print("Extracting physical properties from catalog...\n")

# Key physical properties
if redshifts is None:
    # No redshifts were stored with the embeddings: read them from the catalog,
    # trying 'Z' first and falling back to 'REDSHIFT' if 'Z' gives nothing usable
    redshifts = safe_extract('Z')
    if np.all(np.isnan(redshifts)):
        redshifts = safe_extract('REDSHIFT')

logm = safe_extract('LOGM')              # Stellar mass
logsfr = safe_extract('LOGSFR')          # Star formation rate
dn4000 = safe_extract('DN4000')          # 4000Å break
gr_color = safe_extract('GR')            # g-r color
spectype = safe_extract_string('SPECTYPE')  # Spectral type

# Derived properties: log sSFR = log SFR - log M*; NaN where either input is NaN
ssfr = logsfr - logm

# Print statistics
print("✓ Physical property statistics:")
properties = {
    'Redshift': redshifts,
    'LOGM (stellar mass)': logm,
    'LOGSFR': logsfr,
    'sSFR (log)': ssfr,
    'DN4000': dn4000,
    'g-r color': gr_color
}

for name, values in properties.items():
    n_valid = (~np.isnan(values)).sum()
    if n_valid > 0:
        print(f"  {name}: {np.nanmean(values):.2f} ± {np.nanstd(values):.2f} (n={n_valid})")
    else:
        print(f"  {name}: No valid data")

# Spectral types ('' marks unmatched / missing entries)
unique_spectypes = np.unique(spectype[spectype != ''])
if len(unique_spectypes) > 0:
    print(f"\n✓ Spectral types found: {list(unique_spectypes)}")
    for stype in unique_spectypes:
        count = (spectype == stype).sum()
        print(f"  {stype}: {count:,} ({100*count/len(spectype):.1f}%)")

## 4. UMAP Dimensionality Reduction

Apply UMAP to reduce the embeddings to 2D for visualization.

In [None]:
print("Running UMAP dimensionality reduction...")
print("This may take a few minutes...")

# Configure UMAP. The cosine metric compares embedding directions rather than
# magnitudes. A fixed random_state makes the projection reproducible across
# reruns; umap-learn documents that this also disables its parallelism.
umap_model = umap.UMAP(
    n_neighbors=15,
    min_dist=0.1,
    n_components=2,
    metric='cosine',
    random_state=42
)

# Fit and transform; perf_counter is the monotonic clock intended for timing
start_time = time.perf_counter()
embeddings_umap = umap_model.fit_transform(embeddings)
elapsed = time.perf_counter() - start_time

print(f"✓ UMAP completed in {elapsed:.2f} seconds")
print(f"UMAP embeddings shape: {embeddings_umap.shape}")

## 5. Visualize UMAP - Redshift

Plot the 2D UMAP projection colored by redshift.

In [None]:
# Redshift view of the 2-D UMAP projection. Objects without a redshift are
# drawn as faint grey points so the whole sample stays visible.
fig, ax = plt.subplots(figsize=(12, 10))

umap_x = embeddings_umap[:, 0]
umap_y = embeddings_umap[:, 1]
has_redshift = ~np.isnan(redshifts)
missing_redshift = ~has_redshift

if not has_redshift.any():
    # Nothing to colour by: fall back to a plain scatter of every point
    ax.scatter(umap_x, umap_y, s=8, alpha=0.6)
else:
    z_points = ax.scatter(
        umap_x[has_redshift],
        umap_y[has_redshift],
        c=redshifts[has_redshift],
        cmap='viridis',
        s=8,
        alpha=0.6,
        edgecolors='none',
    )
    plt.colorbar(z_points, ax=ax).set_label('Redshift', fontsize=12)

    if missing_redshift.any():
        ax.scatter(
            umap_x[missing_redshift],
            umap_y[missing_redshift],
            s=5,
            color='lightgray',
            alpha=0.3,
            edgecolors='none',
        )

for set_axis_label, axis_text in ((ax.set_xlabel, 'UMAP 1'), (ax.set_ylabel, 'UMAP 2')):
    set_axis_label(axis_text, fontsize=12)
ax.set_title('astroPT Multimodal Embeddings (colored by redshift)', fontsize=14, fontweight='bold')
ax.grid(alpha=0.2)

redshift_png = 'astropt_embeddings_redshift.png'
plt.tight_layout()
plt.savefig(redshift_png, dpi=300, bbox_inches='tight')
plt.show()

print(f"Saved: {redshift_png}")

## 6. Visualize by Physical Properties

Create visualizations colored by different physical properties.

In [None]:
def plot_property(property_data, property_name, cmap, output_file):
    """Scatter the UMAP projection coloured by one property and save it as an image.

    Reads the module-level ``embeddings_umap`` array (one 2-D point per row).

    Parameters
    ----------
    property_data : np.ndarray
        One float per embedding row; NaN marks a missing value.
    property_name : str
        Used as the colour-bar label and in the title.
    cmap : str or matplotlib colormap
        Colormap for the valid points.
    output_file : str
        Destination passed to ``plt.savefig``.

    Objects with NaN are drawn as faint grey points. The colour scale is
    clipped to the 2nd-98th percentile so a few outliers do not wash out the
    map. If every value is NaN a warning is printed and nothing is saved.
    """
    fig, ax = plt.subplots(figsize=(12, 10))

    valid_mask = ~np.isnan(property_data)

    if valid_mask.sum() == 0:
        print(f"⚠ No valid data for {property_name}, skipping plot")
        plt.close(fig)
        return

    vmin = np.nanpercentile(property_data, 2)
    vmax = np.nanpercentile(property_data, 98)

    scatter = ax.scatter(
        embeddings_umap[valid_mask, 0],
        embeddings_umap[valid_mask, 1],
        c=property_data[valid_mask],
        cmap=cmap,
        s=8,
        alpha=0.6,
        edgecolors='none',
        vmin=vmin,
        vmax=vmax
    )

    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label(property_name, fontsize=12)

    # Plot invalid in gray
    if (~valid_mask).sum() > 0:
        ax.scatter(
            embeddings_umap[~valid_mask, 0],
            embeddings_umap[~valid_mask, 1],
            s=5,
            color='lightgray',
            alpha=0.3,
            edgecolors='none'
        )

    ax.set_xlabel('UMAP 1', fontsize=12)
    ax.set_ylabel('UMAP 2', fontsize=12)
    ax.set_title(f'astroPT Embeddings colored by {property_name}', fontsize=14, fontweight='bold')
    ax.grid(alpha=0.2)

    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure explicitly: this function is called in a loop
    plt.close(fig)
    print(f"Saved: {output_file}")

# Plot for different properties: (data, colour-bar label, colormap, output file).
# The solar symbol and the inverse-year unit use matplotlib mathtext, so they do
# not depend on the plotting font containing those Unicode glyphs.
property_configs = [
    (logm, r'log(M$_{*}$/M$_{\odot}$) - Stellar Mass', 'plasma', 'astropt_embeddings_logm.png'),
    (logsfr, r'log(SFR) [M$_{\odot}$/yr] - Star Formation Rate', 'coolwarm', 'astropt_embeddings_logsfr.png'),
    (ssfr, r'log(sSFR) [yr$^{-1}$] - Specific SFR', 'coolwarm', 'astropt_embeddings_ssfr.png'),
    (dn4000, 'DN4000 (4000Å break)', 'RdYlBu_r', 'astropt_embeddings_dn4000.png'),
    (gr_color, 'g-r Color [mag]', 'RdBu_r', 'astropt_embeddings_gr_color.png'),
]

for prop_data, prop_name, cmap, output_file in property_configs:
    plot_property(prop_data, prop_name, cmap, output_file)

## 7. Visualize by Spectral Type

Color embeddings by spectral classification (GALAXY, QSO, STAR).

In [None]:
# Fixed colours for the expected spectral classes; any other label falls back to red.
colors_map = {
    'GALAXY': '#1f77b4',  # Blue
    'QSO': '#ff7f0e',     # Orange
    'STAR': '#2ca02c',    # Green
}
fallback_color = '#d62728'
spectype_png = 'astropt_embeddings_spectype.png'

fig, ax = plt.subplots(figsize=(12, 10))

# '' marks objects with no spectral type (unmatched or missing)
has_spectype = spectype != ''
no_spectype = ~has_spectype

# One scatter layer per spectral class, labelled with its object count
for label in np.unique(spectype[has_spectype]):
    members = spectype == label
    n_members = members.sum()
    if n_members == 0:
        continue
    ax.scatter(
        embeddings_umap[members, 0],
        embeddings_umap[members, 1],
        c=colors_map.get(label, fallback_color),
        label=f'{label} (n={n_members:,})',
        s=10,
        alpha=0.6,
        edgecolors='none',
    )

# Objects without a type go underneath in faint grey
n_unknown = no_spectype.sum()
if n_unknown > 0:
    ax.scatter(
        embeddings_umap[no_spectype, 0],
        embeddings_umap[no_spectype, 1],
        c='lightgray',
        label=f'Unknown (n={n_unknown:,})',
        s=6,
        alpha=0.3,
        edgecolors='none',
    )

for set_axis_label, axis_text in ((ax.set_xlabel, 'UMAP 1'), (ax.set_ylabel, 'UMAP 2')):
    set_axis_label(axis_text, fontsize=12)
ax.set_title('astroPT Embeddings colored by Spectral Type', fontsize=14, fontweight='bold')
ax.legend(loc='best', fontsize=10, markerscale=2)
ax.grid(alpha=0.2)

plt.tight_layout()
plt.savefig(spectype_png, dpi=300, bbox_inches='tight')
plt.show()

print(f"Saved: {spectype_png}")

## 8. Correlation Analysis

Compute correlations between UMAP dimensions and physical properties.

In [None]:
# Spearman rank correlation of each UMAP axis with each physical property,
# computed only over objects where that property is not NaN.
MIN_VALID_FOR_CORRELATION = 100  # properties with this many valid objects or fewer -> N/A


def _significance_stars(p_value):
    """Return '***', '**', '*' for p < 0.001, < 0.01, < 0.05, else '' (also for NaN)."""
    for threshold, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p_value < threshold:
            return stars
    return ""


def _format_corr(corr, p_value):
    """Coefficient right-aligned in 9 chars plus stars padded to 3, so columns line up."""
    return f"{corr:>9.3f}{_significance_stars(p_value):<3}"


print(f"\n{'='*70}")
print("CORRELATION ANALYSIS: astroPT Multimodal Embeddings")
print(f"{'='*70}")

properties_dict = {
    'Redshift': redshifts,
    'LOGM (Stellar Mass)': logm,
    'LOGSFR (SFR)': logsfr,
    'sSFR': ssfr,
    'DN4000 (Age)': dn4000,
    'g-r Color': gr_color,
}

# Correlation columns are 9 chars of number + 3 chars reserved for stars
print(f"\n{'Property':<25} {'UMAP-1':>9}{'':3} {'UMAP-2':>9}{'':3} {'p-val 1':>10} {'p-val 2':>10}")
print("-" * 73)

for prop_name, prop_data in properties_dict.items():
    valid_mask = ~np.isnan(prop_data)
    if valid_mask.sum() > MIN_VALID_FOR_CORRELATION:
        corr1, pval1 = spearmanr(embeddings_umap[valid_mask, 0], prop_data[valid_mask])
        corr2, pval2 = spearmanr(embeddings_umap[valid_mask, 1], prop_data[valid_mask])
        print(
            f"{prop_name:<25} {_format_corr(corr1, pval1)} {_format_corr(corr2, pval2)} "
            f"{pval1:>10.3e} {pval2:>10.3e}"
        )
    else:
        print(f"{prop_name:<25} {'N/A':>9}{'':3} {'N/A':>9}{'':3} {'N/A':>10} {'N/A':>10}")

print("\n" + "="*70)
print("Significance: *** p<0.001, ** p<0.01, * p<0.05")
print("="*70)

## 9. Summary

Summary of the run: object counts, catalog match rate, redshift coverage, and the figures written to disk.

In [None]:
print("="*70)
print("astroPT MULTIMODAL EMBEDDINGS ANALYSIS SUMMARY")
print("="*70)
print(f"\nTotal objects analyzed: {len(embeddings)}")
print(f"Embedding dimension: {embeddings.shape[1]}")

n_valid_redshifts = int((~np.isnan(redshifts)).sum())
print("\nCatalog matching:")
print(f"  • Matched objects: {(matched_indices >= 0).sum()} ({match_rate:.1f}%)")
print(f"  • Valid redshifts: {n_valid_redshifts}")
if n_valid_redshifts > 0:
    print(f"  • Redshift range: [{np.nanmin(redshifts):.3f}, {np.nanmax(redshifts):.3f}]")

print("\nModel checkpoint:")
print("  • /pbs/home/a/astroinfo09/logs/logs/astropt_multimodal_full_20251106_011934/ckpt_iter_21000.pt")

print("\n" + "="*70)
print("✓ Analysis complete! Figures saved to current directory.")
print("="*70)

# Figures the plotting cells write. plot_property skips properties with no
# valid data, so report which files are actually present on disk.
figure_files = [
    "astropt_embeddings_redshift.png",
    "astropt_embeddings_logm.png",
    "astropt_embeddings_logsfr.png",
    "astropt_embeddings_ssfr.png",
    "astropt_embeddings_dn4000.png",
    "astropt_embeddings_gr_color.png",
    "astropt_embeddings_spectype.png",
]
print("\n📁 Generated files:")
for figure_file in figure_files:
    note = "" if Path(figure_file).exists() else "  (not found)"
    print(f"  • {figure_file}{note}")