# In [None]:
import itertools
import numpy as np
import tarfile
from astropy.table import Table
from matplotlib import pyplot as plt
from scipy.stats import randint, uniform

from sklearn.model_selection import train_test_split, RandomizedSearchCV, cross_validate
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.compose import TransformedTargetRegressor
from sklearn.metrics import mean_squared_error

# NEW imports for SVR
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler

# -------------------------
# Load
# -------------------------
# Read the 3D-HST v4.1 master photometric catalog (ASCII) with astropy and
# convert it to pandas so the rest of the notebook can use DataFrame indexing.
catalog_path = '/Users/fengbocheng/Projects/Photometric-Redshifts/3dhst_master.phot.v4.1/3dhst_master.phot.v4.1.cat'
tab = Table.read(catalog_path, format='ascii').to_pandas()

# Regression target: the spectroscopic-redshift column.
target = 'z_spec'

# -------------------------
# 0) Basic quality cuts BEFORE feature engineering
# -------------------------
# Each cut is applied independently, so a missing optional column (use_phot,
# lmass) no longer disables the other cuts as well.  The target cut itself is
# mandatory: z_spec is later transformed with log1p, where values <= -1 become
# -inf / NaN and break the SVR fit, so only finite, positive redshifts are kept.
z_target = tab[target]
mask = np.isfinite(z_target) & (z_target > 0)
if 'use_phot' in tab.columns:
    # keep rows with use_phot == 1 (presumably the catalog's "good photometry" flag)
    mask &= (tab['use_phot'] == 1)
if 'lmass' in tab.columns:
    # lmass > 9 — presumably log10 stellar mass in solar units; confirm with catalog docs
    mask &= (tab['lmass'] > 9)
tab = tab[mask].copy()

# -------------------------
# 1) Clean error columns
# -------------------------
# Error columns: per-band HST errors (e_...W) plus aperture errors (eaper_...).
# Non-finite and non-positive errors become NaN.  A separate "< -90" sentinel
# test is unnecessary: every such value is already <= 0.
errors = [
    col for col in tab.columns
    if col.startswith('eaper_') or (col.startswith('e_') and col.endswith('W'))
]
for col in errors:
    invalid = ~np.isfinite(tab[col]) | (tab[col] <= 0)
    tab.loc[invalid, col] = np.nan

# -------------------------
# 2) Transfer all fluxes to AB magnitudes (zeropoint = 25)
# -------------------------
# For each band whose f_<band> column exists: non-finite / non-positive fluxes
# are set to NaN in the flux column itself (the S/N step reuses it), then
# mag_<band> = 25 - 2.5 * log10(flux), with flux floored at 1e-6 before the log.
# NaN fluxes propagate to NaN magnitudes (imputed later in the pipeline).
bands = ['F606W', 'F814W', 'F140W', 'F160W']
for band in bands:
    flux_col = f'f_{band}'
    if flux_col not in tab.columns:
        continue
    flux = tab[flux_col]
    tab.loc[~np.isfinite(flux) | (flux <= 0), flux_col] = np.nan
    floored = np.clip(tab[flux_col].to_numpy(), 1e-6, np.inf)
    tab[f'mag_{band}'] = 25.0 - 2.5 * np.log10(floored)

# -------------------------
# 3) Add color features: ALL pairwise colors (no duplicates)
# -------------------------
# color_<A>_<B> = mag_<A> - mag_<B> for every unordered pair of bands that has a
# magnitude column; A precedes B in the order of `bands`.
mag_cols = [f'mag_{b}' for b in bands if f'mag_{b}' in tab.columns]
prefix_len = len('mag_')
for first, second in itertools.combinations(mag_cols, 2):
    color_name = f'color_{first[prefix_len:]}_{second[prefix_len:]}'
    tab[color_name] = tab[first] - tab[second]

# -------------------------
# 4) Add axis ratio b/a (simple morphology)
# -------------------------
# axis_ratio = b_image / a_image (presumably minor / major axis); the 1e-6
# offset on the denominator avoids division by zero.
if all(col in tab.columns for col in ('a_image', 'b_image')):
    tab['axis_ratio'] = tab['b_image'].div(tab['a_image'] + 1e-6)

# -------------------------
# 5) Assemble feature list
# -------------------------
# Feature order: magnitudes, pairwise colors, axis ratio, then extra scalar
# columns.  Every entry is included only if that column exists in `tab`.
bands_with_mag = [band for band in bands if f'mag_{band}' in tab.columns]
feature_cols = [f'mag_{band}' for band in bands_with_mag]

color_names = [
    f'color_{first}_{second}'
    for first, second in itertools.combinations(bands_with_mag, 2)
]
feature_cols += [name for name in color_names if name in tab.columns]

# axis ratio first, then the other optional scalar features
optional_cols = ['axis_ratio', 'flux_radius', 'fwhm_image', 'kron_radius', 'tot_cor']
feature_cols += [name for name in optional_cols if name in tab.columns]

# -------------------------
# 5b) Add S/N features per band (if both flux & error exist)
# -------------------------
# snr_<band> = f_<band> / e_<band>, NaN wherever the error is non-positive or
# non-finite.  A NaN flux also yields a NaN S/N; NaNs are imputed later by the
# pipeline's median imputer.  Each new column is registered as a feature once.
for band in bands:
    flux_col, err_col = f"f_{band}", f"e_{band}"
    if flux_col not in tab.columns or err_col not in tab.columns:
        continue
    err = tab[err_col]
    good_err = np.isfinite(err) & (err > 0)
    snr_col = f"snr_{band}"
    tab[snr_col] = np.where(good_err, tab[flux_col] / err, np.nan)
    if snr_col not in feature_cols:
        feature_cols.append(snr_col)

# -------------------------
# Build X (DataFrame) and y
# -------------------------
# X stays a DataFrame so the ColumnTransformer can select columns by name.
# y is a NumPy array of the target.  `indices` are positional row numbers into
# `tab`, carried through the splits to look up other catalog columns later.
X = tab[feature_cols]
y = tab[target].to_numpy()
indices = np.arange(y.shape[0])

# -------------------------
# 6) Stratified splits on binned z_spec
# -------------------------
# Stratify on quantile bins of the target so that train / validation / test
# share the same redshift distribution.  If ties or a small sample leave any
# bin with fewer than `min_per_bin` members, the number of bins is reduced and
# the binning retried.  Fewer bins is the only thing that helps here, since
# coarser quantile edges are a subset of finer ones.  Otherwise
# train_test_split fails with "least populated class has only 1 member".
n_bins = 10        # starting (maximum) number of redshift strata
min_per_bin = 10   # chosen to leave each stratum a few members in the 30% temp
                   # subset, which is itself split again with stratification
y_bins_all = np.zeros(len(y), dtype=int)  # fallback: a single stratum
for n_try in range(n_bins, 1, -1):
    quantiles = np.unique(np.quantile(y, np.linspace(0, 1, n_try + 1)))
    bins_try = np.digitize(y, quantiles[1:-1], right=True)
    _, bin_counts = np.unique(bins_try, return_counts=True)
    if bin_counts.min() >= min_per_bin:
        y_bins_all = bins_try
        break

# 6.1 Train vs temp (val+test): 70 / 30, stratified on the z bins
X_train, X_temp, y_train, y_temp, indices_train, indices_temp, bins_train, bins_temp = train_test_split(
    X, y, indices, y_bins_all,
    test_size=0.30, random_state=42, stratify=y_bins_all
)

# 6.2 Temp → validation & test (50/50), stratified within temp
X_validate, X_test, y_validate, y_test, indices_validate, indices_test = train_test_split(
    X_temp, y_temp, indices_temp,
    test_size=0.50, random_state=42, stratify=bins_temp
)

# -------------------------
# 7) Pipeline (no leakage) + log1p target + tuned SVR
# -------------------------
from sklearn.pipeline import Pipeline as SkPipe  # (alias used inside ColumnTransformer)

# Impute + scale inside CV (SVR is scale-sensitive; fitting these inside the
# pipeline means each CV training fold refits them, preventing leakage).
pre = ColumnTransformer(
    transformers=[
        ("num", SkPipe([
            ("imputer", SimpleImputer(strategy="median")),
            ("scaler", StandardScaler())
        ]), feature_cols)
    ],
    remainder="drop",
)

# RBF kernel: `gamma`, tuned below, only affects non-linear kernels and is
# ignored by kernel="linear", so this search space is meaningful only for RBF.
svr_base = SVR(kernel="rbf")

# The SVR is fitted on log1p(z); predictions are mapped back with expm1, so
# `epsilon` below is measured in log1p(z) units, not in z.
pipe = Pipeline([
    ("pre", pre),
    ("reg", TransformedTargetRegressor(
        regressor=svr_base, func=np.log1p, inverse_func=np.expm1
    )),
])

# Compact SVR search space: log-uniform C and gamma, small epsilon grid.
from scipy.stats import reciprocal
param_dist = {
    "reg__regressor__C":       reciprocal(1e-1, 1e3),   # 1e-1 … 1e3
    "reg__regressor__gamma":   reciprocal(1e-3, 1e0),   # 1e-3 … 1e0
    "reg__regressor__epsilon": [0.01, 0.03, 0.1, 0.3]
}

# 24 random draws x 5-fold CV on the training split, scored by negative MSE of
# the back-transformed (z-space) predictions.
# NOTE: the `rf_` names look like leftovers from a random-forest version; they
# are kept because later code refers to `best_rf`.
rf_random = RandomizedSearchCV(
    estimator=pipe,
    param_distributions=param_dist,
    n_iter=24, cv=5, scoring="neg_mean_squared_error",
    random_state=42, verbose=1
)

rf_random.fit(X_train, y_train)
best_rf = rf_random.best_estimator_
print("Best params:", rf_random.best_params_)

# -------------------------
# 8) Evaluate on test set
# -------------------------
y_predict_test = best_rf.predict(X_test)
test_mse = mean_squared_error(y_test, y_predict_test)
print("Test MSE:", test_mse)

# Predicted vs. true redshift, with the 1:1 line over the plotted range.
z_lo, z_hi = -0.1, 5
fig, ax = plt.subplots(figsize=(4, 4), dpi=400)
ax.scatter(y_test, y_predict_test, alpha=0.2)
ax.plot([z_lo, z_hi], [z_lo, z_hi], linestyle='--')  # 1:1 line
ax.set_aspect('equal')
ax.set(xlim=(z_lo, z_hi), ylim=(z_lo, z_hi))
ax.grid(True)
ax.set(xlabel='Truth (z_spec, test)', ylabel='SVR prediction (z_phot)')
plt.tight_layout()
plt.show()

# -------------------------
# 9) Validation comparison vs Skelton+2014 z_peak
# -------------------------
z_ml   = best_rf.predict(X_validate)                   # ML (SVR pipeline)
z_eazy = tab['z_peak'].to_numpy()[indices_validate]    # EAZY baseline (positional rows)
z_spec = y_validate                                    # ground truth
# NOTE(review): z_peak is not quality-filtered here.  If the catalog uses a
# placeholder value for failed EAZY fits, such rows would dominate the EAZY
# MSE; confirm against the catalog documentation.

# Model label taken from the fitted estimator rather than hard-coded, so the
# printed/plotted name always matches the kernel the pipeline actually used
# (step "reg" is a TransformedTargetRegressor; `regressor_` is its fitted SVR).
svr_label = f"SVR ({best_rf.named_steps['reg'].regressor_.kernel.upper()})"

# MSE / RMSE and ratios (validation)
mse_ml  = mean_squared_error(z_spec, z_ml)
mse_ez  = mean_squared_error(z_spec, z_eazy)
rmse_ml = np.sqrt(mse_ml)
rmse_ez = np.sqrt(mse_ez)

print(f"[Validation] {svr_label}   MSE: {mse_ml:.6f} | RMSE: {rmse_ml:.6f}")
print(f"[Validation] EAZY (z_peak)  MSE: {mse_ez:.6f} | RMSE: {rmse_ez:.6f}")
print(f"[Validation] MSE ratio  (SVR/EAZY): {mse_ml / mse_ez:.3f}")
print(f"[Validation] RMSE ratio (SVR/EAZY): {rmse_ml / rmse_ez:.3f}")

# Scatter: predictions vs. z_spec (SVR & EAZY), identical axes on both panels.
fig, ax = plt.subplots(1, 2, figsize=(10, 4), dpi=400)
panels = [
    (ax[0], z_ml,   f'{svr_label}\nRMSE = {rmse_ml:.3f}, MSE = {mse_ml:.3f}'),
    (ax[1], z_eazy, f'EAZY (z_peak)\nRMSE = {rmse_ez:.3f}, MSE = {mse_ez:.3f}'),
]
for panel, z_pred, title in panels:
    panel.scatter(z_spec, z_pred, alpha=0.2, s=8)
    panel.plot([-0.1, 5], [-0.1, 5], ls='--')   # 1:1 line
    panel.set_aspect('equal')
    panel.set_xlim(-0.1, 5)
    panel.set_ylim(-0.1, 5)
    panel.grid(True, alpha=0.3)
    panel.set_xlabel('Truth (z_spec, val)')
    panel.set_ylabel('Prediction (z_phot)')
    panel.set_title(title)

plt.tight_layout()
plt.show()

# Residuals vs z_spec
# Residual = prediction - truth on the validation split, for both estimators.
res_svr = z_ml - z_spec
res_ez  = z_eazy - z_spec

fig, ax = plt.subplots(figsize=(6, 4), dpi=400)
for residuals, label in ((res_svr, 'SVR'), (res_ez, 'EAZY')):
    ax.scatter(z_spec, residuals, alpha=0.25, s=8, label=label)
ax.axhline(0, color='gray', ls='--')   # zero-residual reference
ax.grid(True, alpha=0.3)
ax.set_xlabel('Truth (z_spec, val)')
ax.set_ylabel('Residual (z_phot − z_spec)')
ax.set_title('Residuals vs. redshift')
ax.legend()

plt.tight_layout()
plt.show()