# Foundation Models Comparison: astroPT vs AION vs AstroCLIP

This notebook compares embeddings from three different multimodal foundation models:

**astroPT Multimodal**: Transformer model trained on DESI spectra + Euclid images
- Checkpoint: iteration 21000
- Location: `/pbs/home/a/astroinfo09/logs/logs/astropt_multimodal_full_20251106_011934/`

**AION Multimodal**: Foundation model trained on Euclid images + DESI spectra
- Embeddings from Maxime's work
- Location: `/pbs/throng/training/astroinfo2025/work/maxime/data_all_tokens_spectrums.pt`

**AstroCLIP**: Multimodal contrastive learning model producing joint embeddings for Euclid images and DESI spectra
- Joint embeddings file
- Location: `/pbs/throng/training/astroinfo2025/data/AstroCLIP_team/AstroCLIP/AstroCLIP_joint_embedding_outliers/embeddings_all.pt`

We compare how the three models encode spectral and imaging information in their embedding spaces.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import torch
from astropy.io import fits
from astropy.table import Table
import umap
import time
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity

# Global plotting configuration for every figure in the notebook:
# apply matplotlib's 'default' style, then make seaborn's "husl" palette the active one.
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline

## 1. Load astroPT Multimodal Embeddings

Load the embeddings extracted from the astroPT multimodal model (spectra + images).

In [None]:
embeddings_dir = Path('/pbs/throng/training/astroinfo2025/data/astroPT_euclid_desi_dataset/astropt/embeddings_output_21000')
print('Loading astroPT multimodal embeddings...')
print(f'Embeddings directory: {embeddings_dir}')

# Informational inventory of the array files in the directory (only the first 10 are listed).
if embeddings_dir.exists():
    files = list(embeddings_dir.glob('*.npy')) + list(embeddings_dir.glob('*.npz'))
    print(f"\nFound {len(files)} files:")
    for f in files[:10]:
        print(f'  {f.name}')
    if len(files) > 10:
        print(f'  ... and {len(files) - 10} more')
else:
    print('⚠ WARNING: Directory not found')

# Load the six per-object arrays; a missing file aborts the cell after a short hint.
try:
    astropt_embeddings = np.load(embeddings_dir / 'multimodal_train_embeddings.npy')
    astropt_targetids = np.load(embeddings_dir / 'multimodal_train_target_ids.npy')
    astropt_object_ids = np.load(embeddings_dir / 'multimodal_train_object_ids.npy')
    astropt_redshifts = np.load(embeddings_dir / 'multimodal_train_redshifts.npy')
    astropt_has_image = np.load(embeddings_dir / 'multimodal_train_has_image.npy')
    astropt_has_spectrum = np.load(embeddings_dir / 'multimodal_train_has_spectrum.npy')
except FileNotFoundError as e:
    print('\n❌ Could not find embeddings files.')
    print(f'Error: {e}')
    print('\nPlease check the directory structure.')
    raise

# The arrays are indexed together row-by-row further down (ID matching, colouring of the
# UMAP points), so a length mismatch would silently pair embeddings with the wrong object.
astropt_array_lengths = {
    'embeddings': len(astropt_embeddings),
    'target_ids': len(astropt_targetids),
    'object_ids': len(astropt_object_ids),
    'redshifts': len(astropt_redshifts),
    'has_image': len(astropt_has_image),
    'has_spectrum': len(astropt_has_spectrum),
}
if len(set(astropt_array_lengths.values())) > 1:
    raise ValueError(f'astroPT arrays are not row-aligned: {astropt_array_lengths}')
n_astropt = len(astropt_embeddings)

print(f"\n✓ astroPT embeddings loaded: {astropt_embeddings.shape}")
print(f'✓ Target IDs: {len(astropt_targetids)}')
print(f'✓ Object IDs: {len(astropt_object_ids)}')
print(f'✓ Redshifts: {len(astropt_redshifts)}')
# nanmin/nanmax so a few NaN redshifts do not turn the whole range into NaN;
# the guard avoids the ValueError that min()/max() raise on an empty array.
if n_astropt and np.isfinite(astropt_redshifts).any():
    print(f'  Redshift range: [{np.nanmin(astropt_redshifts):.3f}, {np.nanmax(astropt_redshifts):.3f}]')
else:
    print('  Redshift range: n/a (no finite redshifts)')

# Cast the availability flags to bool so counting and `&` behave the same whether the
# flags were saved as bool, 0/1 integers or floats.
astropt_spectrum_mask = astropt_has_spectrum.astype(bool)
astropt_image_mask = astropt_has_image.astype(bool)
astropt_both_mask = astropt_spectrum_mask & astropt_image_mask
astropt_denominator = max(n_astropt, 1)  # avoids a division by zero for an empty data set

print('\n✓ Data availability:')
print(f'  Has spectrum: {astropt_spectrum_mask.sum()} ({100*astropt_spectrum_mask.sum()/astropt_denominator:.1f}%)')
print(f'  Has image: {astropt_image_mask.sum()} ({100*astropt_image_mask.sum()/astropt_denominator:.1f}%)')
print(f'  Has both: {astropt_both_mask.sum()} ({100*astropt_both_mask.sum()/astropt_denominator:.1f}%)')

## 2. Load AION Multimodal Embeddings

Load the embeddings produced by the AION model.

In [None]:
aion_path = '/pbs/throng/training/astroinfo2025/work/maxime/data_all_tokens_spectrums.pt'
print('Loading AION embeddings...')
# NOTE(review): depending on the PyTorch version and its `weights_only` default, torch.load
# may unpickle arbitrary objects — only point this at files from a trusted source.
# map_location='cpu' asks torch to place any stored tensors on the CPU.
aion_data = torch.load(aion_path, map_location='cpu')

# Normalise to a list of records: anything that is not already a list is wrapped
# as a single-element list.
aion_records = aion_data if isinstance(aion_data, list) else [aion_data]
print(f'✓ AION data loaded: {len(aion_records)} records')

def stack_embeddings(records, key):
    """Stack the vectors stored under `key` across `records` into one array.

    Records for which ``rec.get(key)`` is None are skipped.

    Returns:
        (stacked, indices): `stacked` is ``np.stack`` of the kept vectors along a new
        first axis; `indices` holds the positions in `records` that were kept.

    Raises:
        ValueError: if no record provides a value for `key`.
    """
    def _to_numpy(value):
        # torch tensors are detached and moved to CPU first; anything else goes through np.asarray
        return value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else np.asarray(value)

    present = [(pos, rec.get(key)) for pos, rec in enumerate(records)]
    present = [(pos, val) for pos, val in present if val is not None]
    if not present:
        raise ValueError(f"No embeddings found for key '{key}'")

    indices = np.array([pos for pos, _ in present])
    stacked = np.stack([_to_numpy(val) for _, val in present], axis=0)
    return stacked, indices

# Keep the positions of the records that actually carry an embedding: IDs and redshifts
# below are collected for exactly those records so every array stays row-aligned with
# `aion_embeddings` (records without the key are skipped by stack_embeddings).
# NOTE(review): the key name mentions HSC while the notebook intro describes Euclid
# images for AION — confirm this is the intended embedding.
aion_embeddings, aion_embedding_idx = stack_embeddings(aion_records, 'embedding_hsc_desi')
print(f'✓ AION multimodal embeddings: {aion_embeddings.shape}')
if len(aion_embedding_idx) != len(aion_records):
    print(f'⚠ {len(aion_records) - len(aion_embedding_idx)} AION records have no embedding and are skipped')


def _first_present(record, keys):
    """Return the first value stored in `record` under one of `keys` that is not None.

    Unlike an ``a or b or c`` chain, falsy-but-valid values (ID 0, redshift 0.0) are
    kept, and multi-element tensors do not trigger an ambiguous truth-value error.
    """
    for key in keys:
        value = record.get(key)
        if value is not None:
            return value
    return None


def _as_scalar(value):
    """Unwrap single-element tensors / numpy values to Python scalars; pass anything else through."""
    if isinstance(value, torch.Tensor):
        return value.item() if value.numel() == 1 else value
    if isinstance(value, (np.ndarray, np.generic)):
        return value.item() if value.size == 1 else value
    return value


aion_object_ids, aion_redshifts = [], []
for rec_idx in aion_embedding_idx:
    rec = aion_records[rec_idx]
    aion_object_ids.append(_as_scalar(_first_present(rec, ('object_id', 'TARGETID', 'targetid'))))
    z = _as_scalar(_first_present(rec, ('redshift', 'Z', 'z')))
    aion_redshifts.append(np.nan if z is None else z)

if any(obj_id is None for obj_id in aion_object_ids):
    # Missing IDs are marked with NaN as before, but in an object array: a plain
    # np.array() would coerce every ID to float64 and corrupt integers above 2**53.
    ids_with_gaps = np.empty(len(aion_object_ids), dtype=object)
    for pos, obj_id in enumerate(aion_object_ids):
        ids_with_gaps[pos] = np.nan if obj_id is None else obj_id
    aion_object_ids = ids_with_gaps
else:
    aion_object_ids = np.array(aion_object_ids)
aion_redshifts = np.array(aion_redshifts, dtype=float)

## 2b. Load AstroCLIP Embeddings

Load the joint embeddings from the AstroCLIP model (images + spectra).

In [None]:
astroclip_path = '/pbs/throng/training/astroinfo2025/data/AstroCLIP_team/AstroCLIP/AstroCLIP_joint_embedding_outliers/embeddings_all.pt'
print('Loading AstroCLIP embeddings...')
print(f'Path: {astroclip_path}')
astroclip_data = torch.load(astroclip_path, map_location='cpu')

# Everything starts as "not found"; the branches below fill in whatever the file provides.
astroclip_embeddings = astroclip_object_ids = astroclip_redshifts = None

if isinstance(astroclip_data, torch.Tensor):
    # The file is a bare tensor: treat it as the embedding matrix itself.
    astroclip_embeddings = astroclip_data.detach().cpu().numpy()
    print(f'✓ AstroCLIP embeddings tensor: {astroclip_embeddings.shape}')
elif isinstance(astroclip_data, dict):
    # Candidate key names, tried in order for each quantity.
    emb_keys = ['embeddings','embedding','features','vectors','joint_embeddings']
    id_keys = ['ids','object_ids','object_id','target_ids','TARGETID','targetid']
    z_keys = ['redshifts','redshift','z','Z']

    for k in emb_keys:
        if k not in astroclip_data:
            continue
        v = astroclip_data[k]
        if isinstance(v, torch.Tensor):
            astroclip_embeddings = v.detach().cpu().numpy()
        else:
            astroclip_embeddings = np.asarray(v)
        print(f"✓ AstroCLIP embeddings found in key '{k}': {astroclip_embeddings.shape}")
        break
    else:
        # No known key matched: fall back to the first 2-D tensor/array in the dict.
        for k, v in astroclip_data.items():
            if not isinstance(v, (torch.Tensor, np.ndarray)):
                continue
            if getattr(v, 'ndim', len(v.shape)) != 2:
                continue
            if isinstance(v, torch.Tensor):
                astroclip_embeddings = v.detach().cpu().numpy()
            else:
                astroclip_embeddings = np.asarray(v)
            print(f"⚠ Using '{k}' as AstroCLIP embeddings (guessed)")
            break

    for k in id_keys:
        if k not in astroclip_data:
            continue
        v = astroclip_data[k]
        if isinstance(v, torch.Tensor):
            astroclip_object_ids = v.detach().cpu().numpy()
        else:
            astroclip_object_ids = np.asarray(v)
        print(f"✓ AstroCLIP IDs found in key '{k}': {len(astroclip_object_ids)}")
        break

    for k in z_keys:
        if k not in astroclip_data:
            continue
        v = astroclip_data[k]
        if isinstance(v, torch.Tensor):
            astroclip_redshifts = v.detach().cpu().numpy()
        else:
            astroclip_redshifts = np.asarray(v, dtype=float)
        print(f"✓ AstroCLIP redshifts found in key '{k}': {len(astroclip_redshifts)}")
        break
else:
    print(f'⚠ Unexpected AstroCLIP data type: {type(astroclip_data)}')

# Embeddings are mandatory; redshifts and IDs are optional and only produce a warning.
if astroclip_embeddings is None:
    raise ValueError('AstroCLIP embeddings could not be located in file')
if astroclip_redshifts is None:
    print('⚠ AstroCLIP redshifts not found in file; will use catalog')
if astroclip_object_ids is None:
    print('⚠ AstroCLIP object IDs not found in file; catalog matching may be limited')

## 3. Load Catalog and Match IDs

We'll load the DESI+Euclid catalog and create mappings to align objects across the three models using TARGETID and/or object_id.

In [None]:
catalog_path = '/pbs/throng/training/astroinfo2025/data/astroPT_euclid_desi_dataset/desi_euclid_catalog.fits'
print('Loading catalog...')
print(f'Path: {catalog_path}')

# Read the FITS catalog into an astropy Table. On any failure a short notice is printed
# and the exception is re-raised unchanged, so the original traceback is still shown.
try:
    cat = Table.read(catalog_path, format='fits')
    print(f'✓ Catalog loaded: {len(cat)} rows, {len(cat.colnames)} columns')
except Exception as e:
    print('❌ Failed to read catalog')
    raise

# Helper to fetch a column with fallbacks
def get_col(table, names):
    """Return the first column of `table` whose name appears in `names`.

    Candidates are tried in the given order against ``table.colnames``. When none
    is present a warning listing the candidates is printed and None is returned.
    """
    found = next((candidate for candidate in names if candidate in table.colnames), None)
    if found is None:
        print(f"⚠ Column not found; tried {names}")
        return None
    return table[found]

# Identify the catalog's ID columns. Each lookup tries several spellings in order;
# get_col prints a warning and returns None when none of them exists.
cat_targetid = get_col(cat, ['TARGETID', 'targetid', 'TargetID'])
cat_object_id = get_col(cat, ['OBJECT_ID', 'object_id', 'OBJID', 'OBJ_ID', 'ID'])

# Build index maps

def build_index_map(values):
    """Build a ``{value: position}`` lookup for `values`.

    Returns None when `values` is None. Elements exposing ``.item()`` (numpy scalars)
    are converted to plain Python objects first; if that conversion fails the element
    itself is used as the key. For repeated values the last position wins.
    """
    if values is None:
        return None

    def _as_python(value):
        try:
            return value.item() if hasattr(value, 'item') else value
        except Exception:
            return value

    return {_as_python(raw): position for position, raw in enumerate(values)}

# Value -> catalog-row lookups for both ID flavours. build_index_map returns None when
# it is given None, i.e. when the corresponding ID column was not found above.
targetid_to_idx = build_index_map(cat_targetid)
objectid_to_idx = build_index_map(cat_object_id)

# Report which of the two lookups could be built.
print('Index map availability:')
print(f'  TARGETID map: {targetid_to_idx is not None}')
print(f'  object_id map: {objectid_to_idx is not None}')

# Matching helpers

def match_ids(ids, id_map, cast_to_int=False):
    """Look up every entry of `ids` in `id_map` and report how many were found.

    Args:
        ids: iterable of identifiers, or None.
        id_map: dict mapping identifier -> row index, or None.
        cast_to_int: when True, try ``int(value)`` before the lookup (the value is
            left unchanged if the cast fails).

    Returns:
        (positions, rate): `positions` is an int array with the mapped row index per
        entry, -1 where no match was found; `rate` is the matched fraction (0.0 when
        there is nothing to match). With `ids` None the array is empty; with
        `id_map` None it is all -1.
    """
    if ids is None or id_map is None:
        n_ids = len(ids) if ids is not None else 0
        return np.full(n_ids, -1, dtype=int), 0.0

    positions = np.full(len(ids), -1, dtype=int)
    for row, raw in enumerate(ids):
        # None and non-finite Python/NumPy floats can never be matched
        if raw is None or (isinstance(raw, float) and not np.isfinite(raw)):
            continue

        key = raw
        if cast_to_int:
            try:
                key = int(key)
            except Exception:
                pass
        # numpy scalars -> plain Python objects
        try:
            if hasattr(key, 'item'):
                key = key.item()
        except Exception:
            pass

        hit = id_map.get(key)
        if hit is not None:
            positions[row] = hit

    if len(positions) == 0:
        return positions, 0.0
    return positions, (positions >= 0).mean()

# 3.1 astroPT: matched on TARGETID, with values cast to int before the lookup
astropt_match_idx, astropt_match_rate = match_ids(astropt_targetids, targetid_to_idx, cast_to_int=True)
print(f"astroPT match via TARGETID: {astropt_match_rate*100:.1f}% ({(astropt_match_idx>=0).sum()}/{len(astropt_match_idx)})")

# 3.2 AION: matched on object_id, values used as stored
aion_match_idx, aion_match_rate = match_ids(aion_object_ids, objectid_to_idx, cast_to_int=False)
print(f"AION match via object_id: {aion_match_rate*100:.1f}% ({(aion_match_idx>=0).sum()}/{len(aion_match_idx)})")

# 3.3 AstroCLIP: the file does not say which kind of ID it stores, so both
# interpretations are tried and the one with the higher match rate is kept
# (object_id wins ties).
if astroclip_object_ids is None:
    astroclip_match_idx = np.full(len(astroclip_embeddings), -1, dtype=int)
    astroclip_match_key = 'none'
    print('⚠ AstroCLIP: no IDs available for matching; downstream property joins will be limited')
else:
    idx_obj, rate_obj = match_ids(astroclip_object_ids, objectid_to_idx, cast_to_int=False)
    idx_tid, rate_tid = match_ids(astroclip_object_ids, targetid_to_idx, cast_to_int=True)
    astroclip_use_targetid = rate_tid > rate_obj
    astroclip_match_idx = idx_tid if astroclip_use_targetid else idx_obj
    astroclip_match_key = 'TARGETID' if astroclip_use_targetid else 'object_id'
    print(f"AstroCLIP match via {astroclip_match_key}: {(astroclip_match_idx>=0).sum()}/{len(astroclip_match_idx)} ({(astroclip_match_idx>=0).mean()*100:.1f}%)")

## 4. Extract Properties for Each Model

We'll extract key physical properties (redshift, mass, SFR, Dn4000, g−r) from the catalog for the matched objects of each model and assemble aligned arrays.

In [None]:
# Resolve the catalog column for each physical property, trying several candidate
# names in order. get_col returns None (after printing a warning) when no candidate
# exists, so any of these may be None.
z_cat = get_col(cat, ['Z', 'REDSHIFT', 'z', 'REDSHIFT_MEDIAN'])
logm_cat = get_col(cat, ['LOGM', 'logM', 'LOG_M', 'MSTAR_LOG'])
logsfr_cat = get_col(cat, ['LOGSFR', 'logSFR', 'SFR_LOG'])
dn4000_cat = get_col(cat, ['DN4000', 'dn4000'])
gr_cat = get_col(cat, ['GR', 'g_r', 'GMINUSR', 'G_MINUS_R'])

# Utility to assemble aligned property arrays for each model

def extract_from_catalog(match_idx, cat_col):
    """Gather catalog values for matched objects as a float array.

    Args:
        match_idx: integer row indices into the catalog, one per object; negative
            values mark objects without a catalog match.
        cat_col: numeric catalog column (plain or masked array-like), or None when
            the column is unavailable.

    Returns:
        Float array with one entry per element of `match_idx`. Entries are NaN for
        unmatched objects, for masked catalog values, and everywhere when `cat_col`
        is None.
    """
    match_idx = np.asarray(match_idx)
    out = np.full(len(match_idx), np.nan)
    if cat_col is None:
        return out

    # A plain np.asarray() on a masked column drops the mask and exposes whatever
    # placeholder data sits underneath; blank masked entries explicitly instead.
    values = np.array(np.ma.getdata(cat_col), dtype=float)  # fresh float copy, safe to edit
    values[np.ma.getmaskarray(cat_col)] = np.nan

    valid = match_idx >= 0
    out[valid] = values[match_idx[valid]]
    return out

# Prefer primary with fallback when primary missing/NaN

def prefer(primary, fallback):
    """Combine two value sources, taking `fallback` wherever `primary` is not finite.

    If `primary` is None or a bare Python number, `fallback` is returned as-is.
    Otherwise both inputs are converted to float arrays and a copy of `primary` is
    returned with its non-finite entries replaced by the corresponding `fallback`
    entries (those entries become NaN when `fallback` is None).
    """
    if primary is None or isinstance(primary, (float, int)):
        return fallback

    primary_arr = np.asarray(primary, dtype=float)
    if fallback is None:
        fallback_arr = np.full_like(primary_arr, np.nan)
    else:
        fallback_arr = np.asarray(fallback, dtype=float)

    needs_fallback = np.logical_not(np.isfinite(primary_arr))
    merged = primary_arr.copy()
    merged[needs_fallback] = fallback_arr[needs_fallback]
    return merged

def _catalog_properties(match_idx):
    """Return (z, logM, logSFR, Dn4000, g-r) catalog values for one model's matched rows."""
    return tuple(
        extract_from_catalog(match_idx, column)
        for column in (z_cat, logm_cat, logsfr_cat, dn4000_cat, gr_cat)
    )


# astroPT properties (redshift stored with the embeddings is preferred over the catalog value)
astropt_z_cat, astropt_logm, astropt_logsfr, astropt_dn4000, astropt_gr = _catalog_properties(astropt_match_idx)
astropt_z = prefer(astropt_redshifts, astropt_z_cat)

# AION properties
aion_z_cat, aion_logm, aion_logsfr, aion_dn4000, aion_gr = _catalog_properties(aion_match_idx)
aion_z = prefer(aion_redshifts, aion_z_cat)

# AstroCLIP properties
astroclip_z_cat, astroclip_logm, astroclip_logsfr, astroclip_dn4000, astroclip_gr = _catalog_properties(astroclip_match_idx)
astroclip_z = prefer(astroclip_redshifts, astroclip_z_cat)

# Coverage report: number of finite values per property and model
print('Property extraction complete:')
for name, arr in [
    ('astropt_z', astropt_z), ('aion_z', aion_z), ('astroclip_z', astroclip_z),
    ('astropt_logm', astropt_logm), ('aion_logm', aion_logm), ('astroclip_logm', astroclip_logm),
    ('astropt_logsfr', astropt_logsfr), ('aion_logsfr', aion_logsfr), ('astroclip_logsfr', astroclip_logsfr),
    ('astropt_dn4000', astropt_dn4000), ('aion_dn4000', aion_dn4000), ('astroclip_dn4000', astroclip_dn4000),
    ('astropt_gr', astropt_gr), ('aion_gr', aion_gr), ('astroclip_gr', astroclip_gr)
]:
    if not isinstance(arr, np.ndarray):
        print(f"  {name}: not available")
        continue
    finite = np.isfinite(arr)
    print(f"  {name}: {finite.sum()}/{len(arr)} finite")

## 5. UMAP projections

In [None]:
def compute_umap_projection(X, n_neighbors=50, min_dist=0.1, metric='cosine', random_state=42):
    """Fit a fresh UMAP reducer on `X` and return the projected coordinates.

    The keyword arguments are forwarded unchanged to ``umap.UMAP``. The wall-clock
    time of ``fit_transform`` and the output shape are printed.
    """
    settings = dict(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric, random_state=random_state)
    reducer = umap.UMAP(**settings)

    # Only the fit itself is timed, not the construction of the reducer.
    started = time.time()
    Y = reducer.fit_transform(X)
    dt = time.time() - started

    print(f'UMAP done in {dt:.2f}s -> shape {Y.shape}')
    return Y

# One projection per model. Each call builds and fits its own reducer with the
# function's default arguments, so the three coordinate systems are independent.
print('Computing UMAP projections (cosine) for all models...')
umap_astropt = compute_umap_projection(astropt_embeddings)
umap_aion = compute_umap_projection(aion_embeddings)
umap_astroclip = compute_umap_projection(astroclip_embeddings)

## 6. Visualize UMAP colored by redshift

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)

# Shared colour scale: 1st-99th percentile of all finite redshifts across the three
# models, falling back to [0, 1] when there are none.
z_arrays = [astropt_z, aion_z, astroclip_z]
finite_vals = np.hstack([z[np.isfinite(z)] for z in z_arrays if isinstance(z, np.ndarray)])
vmin, vmax = (0.0, 1.0) if finite_vals.size == 0 else np.nanpercentile(finite_vals, [1, 99])

plots = list(zip(
    ['astroPT', 'AION', 'AstroCLIP'],
    [umap_astropt, umap_aion, umap_astroclip],
    z_arrays,
))

for ax, (title, U, z) in zip(axes, plots):
    # A model without a redshift array is drawn entirely in gray
    z = z if isinstance(z, np.ndarray) else np.full(U.shape[0], np.nan)
    mask = np.isfinite(z)
    sc = ax.scatter(U[mask,0], U[mask,1], c=z[mask], s=5, cmap='viridis', vmin=vmin, vmax=vmax, alpha=0.8)
    # Points without a finite redshift: small light-gray markers
    ax.scatter(U[~mask,0], U[~mask,1], c='lightgray', s=3, alpha=0.3)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

# One colour bar for all panels, taken from the last panel's scatter (all share vmin/vmax)
cbar = fig.colorbar(sc, ax=axes, shrink=0.9)
cbar.set_label('Redshift z')
plt.show()

## 7. Property visualizations (1×3 panels)

In [None]:
def plot_property_triptych(prop_name, arrays, umaps, cmap='viridis', robust=True):
    """Draw a 1x3 figure of UMAP scatter plots coloured by one property.

    Args:
        prop_name: label used for the colour bar and in the skip message.
        arrays: per-model property values; entries that are not numpy arrays are
            treated as entirely missing.
        umaps: per-model projections, indexed as ``U[:, 0]`` / ``U[:, 1]``.
        cmap: matplotlib colormap name.
        robust: use the 1st-99th percentile of the pooled finite values for the
            colour limits instead of their min/max.

    Prints a message and returns without plotting when no finite value exists.
    """
    names = ['astroPT', 'AION', 'AstroCLIP']

    finite_chunks = [arr[np.isfinite(arr)] for arr in arrays if isinstance(arr, np.ndarray)]
    if not finite_chunks or not any(len(chunk) for chunk in finite_chunks):
        print(f'No finite values for {prop_name}. Skipping.')
        return

    # Colour limits are shared by all panels
    all_vals = np.hstack(finite_chunks)
    if robust:
        vmin, vmax = np.nanpercentile(all_vals, [1, 99])
    else:
        vmin, vmax = np.nanmin(all_vals), np.nanmax(all_vals)

    fig, axes = plt.subplots(1, 3, figsize=(16, 5), constrained_layout=True)
    for ax, name, U, vals in zip(axes, names, umaps, arrays):
        vals = vals if isinstance(vals, np.ndarray) else np.full(U.shape[0], np.nan)
        mask = np.isfinite(vals)
        sc = ax.scatter(U[mask,0], U[mask,1], c=vals[mask], s=5, cmap=cmap, vmin=vmin, vmax=vmax, alpha=0.85)
        # Points without a finite value: small light-gray markers
        ax.scatter(U[~mask,0], U[~mask,1], c='lightgray', s=3, alpha=0.3)
        ax.set_title(name)
        ax.set_xticks([])
        ax.set_yticks([])

    # Single colour bar taken from the last panel's scatter
    cbar = fig.colorbar(sc, ax=axes, shrink=0.9)
    cbar.set_label(prop_name)
    plt.show()

# One triptych per physical property; a property without any finite value is skipped
# by plot_property_triptych itself.
umaps = [umap_astropt, umap_aion, umap_astroclip]
for prop_label, prop_arrays in [
    ('logM', [astropt_logm, aion_logm, astroclip_logm]),
    ('logSFR', [astropt_logsfr, aion_logsfr, astroclip_logsfr]),
    ('Dn4000', [astropt_dn4000, aion_dn4000, astroclip_dn4000]),
    ('g-r', [astropt_gr, aion_gr, astroclip_gr]),
]:
    plot_property_triptych(prop_label, prop_arrays, umaps)

## 8. Correlation analysis (Spearman ρ between UMAP axes and properties)

In [None]:
def safe_spearman(x, y):
    """Spearman rank correlation of `x` and `y` over the entries where both are finite.

    Returns NaN when fewer than 10 such entries exist; the p-value is discarded.
    """
    both_finite = np.isfinite(x) & np.isfinite(y)
    if both_finite.sum() < 10:
        return np.nan
    return spearmanr(x[both_finite], y[both_finite])[0]

# Property label -> per-model value arrays, in the same order as `models` / `umaps`
props = dict([
    ('z', [astropt_z, aion_z, astroclip_z]),
    ('logM', [astropt_logm, aion_logm, astroclip_logm]),
    ('logSFR', [astropt_logsfr, aion_logsfr, astroclip_logsfr]),
    ('Dn4000', [astropt_dn4000, aion_dn4000, astroclip_dn4000]),
    ('g-r', [astropt_gr, aion_gr, astroclip_gr]),
])
models = ['astroPT', 'AION', 'AstroCLIP']
umaps = [umap_astropt, umap_aion, umap_astroclip]

print('Spearman correlations (UMAP-1/2 vs property)')
for pname, arrays in props.items():
    print(f"\nProperty: {pname}")
    for model, U, vals in zip(models, umaps, arrays):
        if isinstance(vals, np.ndarray):
            # one coefficient per UMAP axis
            r1, r2 = (safe_spearman(U[:, axis], vals) for axis in (0, 1))
        else:
            r1 = r2 = np.nan
        print(f"  {model:9s}  ρ(UMAP1,{pname})={r1:6.3f}  ρ(UMAP2,{pname})={r2:6.3f}")

## 9. Summary

- All three models loaded successfully and matched to the DESI+Euclid catalog where possible.
- UMAP projections were computed with the cosine metric, and quick-look visualizations colored by redshift and the other properties are provided.
- Spearman correlations between each UMAP axis and the astrophysical properties are printed as a quick measure of how strongly each model's projection tracks those properties.

Next steps (optional):
- Add cosine-similarity comparisons on the intersection of matched objects across models.
- Save figures to disk and export UMAP coordinates for downstream analysis.
- Tune UMAP hyperparameters per model to test stability.