
# Exoplanet ML — End-to-End (KOI & PS compatible)

This notebook trains:
1. **Type classifier** (multiclass): size-based (and optional thermal) exoplanet classes.
2. **Binary classifier** (optional): *is exoplanet?* (only if your dataset includes a reliable binary label).

It supports **KOI** (`koi_*`) and **Planetary Systems (PS)** (`pl_*`) tables from the NASA Exoplanet Archive.  
**Auto-detects** which one you loaded and maps column names accordingly.



## 1) Requirements


In [1]:

# If needed:
! pip install pandas numpy scikit-learn joblib





## 2) Imports & Constants


In [2]:

import os
from typing import Optional, Tuple

import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.ensemble import HistGradientBoostingClassifier
import joblib

MJUP_TO_MEARTH = 317.828



## 3) Load Dataset (robust CSV/TSV, skip Archive header)


In [3]:

# <<< EDIT PATH IF NEEDED >>>
CSV_PATH = "cumulative.csv"  # your downloaded table

# Try CSV (comma) with Archive comments skipped. If it fails, try TSV.
try:
    df = pd.read_csv(CSV_PATH, comment="#", encoding="utf-8-sig", engine="python")
except Exception as e:
    print("CSV read failed, trying TSV...", e)
    df = pd.read_csv(CSV_PATH, sep="\t", comment="#", encoding="utf-8-sig", engine="python")

print(df.shape)
pd.set_option("display.max_columns", 120)
df.head(3)


CSV read failed, trying TSV... [Errno 2] No such file or directory: 'cumulative.csv'


FileNotFoundError: [Errno 2] No such file or directory: 'cumulative.csv'


## 4) Detect Table Type (KOI vs PS) and Map Columns
Maps key columns to a unified set so the rest of the notebook is identical.


In [None]:

def first_present(df, candidates):
    for c in candidates:
        if c in df.columns:
            return c
    return None

IS_KOI = any(c.startswith("koi_") for c in df.columns)
IS_PS  = any(c.startswith("pl_")  for c in df.columns)

# Unified names via aliasing
COL_PL_RADE   = first_present(df, ["pl_rade", "koi_prad"])      # R_⊕
COL_PL_BMASSE = first_present(df, ["pl_bmasse"])                # M_⊕ (often absent in KOI)
COL_PL_BMASSJ = first_present(df, ["pl_bmassj"])                # M_J (often absent in KOI)
COL_PL_EQT    = first_present(df, ["pl_eqt", "koi_teq"])        # K

COL_PERIOD    = first_present(df, ["pl_orbper", "koi_period"])  # days
COL_ST_TEFF   = first_present(df, ["st_teff", "koi_steff"])     # K
COL_ST_RAD    = first_present(df, ["st_rad", "koi_srad"])       # R_⊙
COL_INSOL     = first_present(df, ["koi_insol"])                # KOI-only

print("Detected: KOI?", IS_KOI, "| PS?", IS_PS)
print("Mapped columns:",
      "RADE=",COL_PL_RADE, "BMASSE=",COL_PL_BMASSE, "BMASSJ=",COL_PL_BMASSJ, "EQT=",COL_PL_EQT,
      "PERIOD=",COL_PERIOD, "ST_TEFF=",COL_ST_TEFF, "ST_RAD=",COL_ST_RAD, "INSOL=",COL_INSOL)


Detected: KOI? True | PS? False
Mapped columns: RADE= koi_prad BMASSE= None BMASSJ= None EQT= koi_teq PERIOD= koi_period ST_TEFF= koi_steff ST_RAD= koi_srad INSOL= koi_insol



## 5) Build Labels (Size & Thermal → Type)
Creates:
- `size_label` from radius/mass
- `thermal_label` from equilibrium temperature
- `type_label` = size or size + thermal (e.g., `joviano_caliente`)


In [None]:

def _size_class(row: pd.Series) -> Optional[str]:
    r  = row.get(COL_PL_RADE,  np.nan)  # Earth radii
    me = row.get(COL_PL_BMASSE, np.nan) # Earth masses
    mj = row.get(COL_PL_BMASSJ, np.nan) # Jupiter masses

    if pd.notna(mj) and pd.isna(me):
        me = mj * MJUP_TO_MEARTH

    # Prefer radius
    if pd.notna(r):
        if r < 0.8:
            return "subterrestre"
        elif r < 1.5:
            return "terraneo"
        elif r < 2.5:
            return "super_tierra"
        elif r < 4.0:
            return "mini_neptuno"
        elif r < 6.0:
            return "neptuniano"
        else:
            if pd.notna(me) and me >= 2 * MJUP_TO_MEARTH:
                return "super_jupiter"
            return "joviano"

    # Fallback to mass (if available)
    if pd.notna(me):
        if me < 0.5:                 return "subterrestre"
        elif me < 2:                 return "terraneo"
        elif me < 10:                return "super_tierra"
        elif me < 20:                return "mini_neptuno"
        elif me < 50:                return "neptuniano"
        elif me >= 2 * MJUP_TO_MEARTH: return "super_jupiter"
        else:                        return "joviano"
    return None

def _thermal_class(row: pd.Series) -> Optional[str]:
    teq = row.get(COL_PL_EQT, np.nan)
    if pd.notna(teq):
        if teq < 200:         return "frio"
        elif teq <= 320:      return "templado"
        elif teq <= 800:      return "tibio"
        else:                 return "caliente"
    return None

def build_type_labels(df: pd.DataFrame, combine_thermal: bool = True) -> Tuple[pd.Series, pd.Series, pd.Series]:
    size = df.apply(_size_class, axis=1)
    thermal = df.apply(_thermal_class, axis=1)
    if combine_thermal:
        final = [f"{s}_{t}" if (s is not None and t is not None) else s for s,t in zip(size, thermal)]
        final = pd.Series(final, index=df.index)
    else:
        final = size
    return size, thermal, final

size_label, thermal_label, type_label = build_type_labels(df, combine_thermal=True)
df["size_label"]    = size_label
df["thermal_label"] = thermal_label
df["type_label"]    = type_label

cols_to_show = [c for c in [COL_PL_RADE, COL_PL_BMASSE, COL_PL_BMASSJ, COL_PL_EQT,
                            "size_label","thermal_label","type_label"] if c is not None]
df[cols_to_show].head(10)


Unnamed: 0,koi_prad,koi_teq,size_label,thermal_label,type_label
0,2.26,793.0,super_tierra,tibio,super_tierra_tibio
1,2.83,443.0,mini_neptuno,tibio,mini_neptuno_tibio
2,14.6,638.0,joviano,tibio,joviano_tibio
3,33.46,1395.0,joviano,caliente,joviano_caliente
4,2.75,1406.0,mini_neptuno,caliente,mini_neptuno_caliente
5,3.9,835.0,mini_neptuno,caliente,mini_neptuno_caliente
6,2.77,1160.0,mini_neptuno,caliente,mini_neptuno_caliente
7,1.59,1360.0,super_tierra,caliente,super_tierra_caliente
8,39.21,1342.0,joviano,caliente,joviano_caliente
9,5.76,600.0,neptuniano,tibio,neptuniano_tibio


In [None]:
df

Unnamed: 0,kepid,kepoi_name,kepler_name,koi_disposition,koi_pdisposition,koi_score,koi_fpflag_nt,koi_fpflag_ss,koi_fpflag_co,koi_fpflag_ec,koi_period,koi_period_err1,koi_period_err2,koi_time0bk,koi_time0bk_err1,koi_time0bk_err2,koi_impact,koi_impact_err1,koi_impact_err2,koi_duration,koi_duration_err1,koi_duration_err2,koi_depth,koi_depth_err1,koi_depth_err2,koi_prad,koi_prad_err1,koi_prad_err2,koi_teq,koi_teq_err1,koi_teq_err2,koi_insol,koi_insol_err1,koi_insol_err2,koi_model_snr,koi_tce_plnt_num,koi_tce_delivname,koi_steff,koi_steff_err1,koi_steff_err2,koi_slogg,koi_slogg_err1,koi_slogg_err2,koi_srad,koi_srad_err1,koi_srad_err2,ra,dec,koi_kepmag,size_label,thermal_label,type_label
0,10797460,K00752.01,Kepler-227 b,CONFIRMED,CANDIDATE,1.000,0,0,0,0,9.488036,2.775000e-05,-2.775000e-05,170.538750,0.002160,-0.002160,0.146,0.318,-0.146,2.95750,0.08190,-0.08190,615.8,19.5,-19.5,2.26,0.26,-0.15,793.0,,,93.59,29.45,-16.65,35.8,1.0,q1_q17_dr25_tce,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347,super_tierra,tibio,super_tierra_tibio
1,10797460,K00752.02,Kepler-227 c,CONFIRMED,CANDIDATE,0.969,0,0,0,0,54.418383,2.479000e-04,-2.479000e-04,162.513840,0.003520,-0.003520,0.586,0.059,-0.443,4.50700,0.11600,-0.11600,874.8,35.5,-35.5,2.83,0.32,-0.19,443.0,,,9.11,2.87,-1.62,25.8,2.0,q1_q17_dr25_tce,5455.0,81.0,-81.0,4.467,0.064,-0.096,0.927,0.105,-0.061,291.93423,48.141651,15.347,mini_neptuno,tibio,mini_neptuno_tibio
2,10811496,K00753.01,,CANDIDATE,CANDIDATE,0.000,0,0,0,0,19.899140,1.494000e-05,-1.494000e-05,175.850252,0.000581,-0.000581,0.969,5.126,-0.077,1.78220,0.03410,-0.03410,10829.0,171.0,-171.0,14.60,3.92,-1.31,638.0,,,39.30,31.04,-10.49,76.3,1.0,q1_q17_dr25_tce,5853.0,158.0,-176.0,4.544,0.044,-0.176,0.868,0.233,-0.078,297.00482,48.134129,15.436,joviano,tibio,joviano_tibio
3,10848459,K00754.01,,FALSE POSITIVE,FALSE POSITIVE,0.000,0,1,0,0,1.736952,2.630000e-07,-2.630000e-07,170.307565,0.000115,-0.000115,1.276,0.115,-0.092,2.40641,0.00537,-0.00537,8079.2,12.8,-12.8,33.46,8.50,-2.83,1395.0,,,891.96,668.95,-230.35,505.6,1.0,q1_q17_dr25_tce,5805.0,157.0,-174.0,4.564,0.053,-0.168,0.791,0.201,-0.067,285.53461,48.285210,15.597,joviano,caliente,joviano_caliente
4,10854555,K00755.01,Kepler-664 b,CONFIRMED,CANDIDATE,1.000,0,0,0,0,2.525592,3.761000e-06,-3.761000e-06,171.595550,0.001130,-0.001130,0.701,0.235,-0.478,1.65450,0.04200,-0.04200,603.3,16.9,-16.9,2.75,0.88,-0.35,1406.0,,,926.16,874.33,-314.24,40.9,1.0,q1_q17_dr25_tce,6031.0,169.0,-211.0,4.438,0.070,-0.210,1.046,0.334,-0.133,288.75488,48.226200,15.509,mini_neptuno,caliente,mini_neptuno_caliente
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9559,10090151,K07985.01,,FALSE POSITIVE,FALSE POSITIVE,0.000,0,1,1,0,0.527699,1.160000e-07,-1.160000e-07,131.705093,0.000170,-0.000170,1.252,0.051,-0.049,3.22210,0.01740,-0.01740,1579.2,4.6,-4.6,29.35,7.70,-2.57,2088.0,,,4500.53,3406.38,-1175.26,453.3,1.0,q1_q17_dr25_tce,5638.0,139.0,-166.0,4.529,0.035,-0.196,0.903,0.237,-0.079,297.18875,47.093819,14.082,joviano,caliente,joviano_caliente
9560,10128825,K07986.01,,CANDIDATE,CANDIDATE,0.497,0,0,0,0,1.739849,1.780000e-05,-1.780000e-05,133.001270,0.007690,-0.007690,0.043,0.423,-0.043,3.11400,0.22900,-0.22900,48.5,5.4,-5.4,0.72,0.24,-0.08,1608.0,,,1585.81,1537.86,-502.22,10.6,1.0,q1_q17_dr25_tce,6119.0,165.0,-220.0,4.444,0.056,-0.224,1.031,0.341,-0.114,286.50937,47.163219,14.757,subterrestre,caliente,subterrestre_caliente
9561,10147276,K07987.01,,FALSE POSITIVE,FALSE POSITIVE,0.021,0,0,1,0,0.681402,2.434000e-06,-2.434000e-06,132.181750,0.002850,-0.002850,0.147,0.309,-0.147,0.86500,0.16200,-0.16200,103.6,14.7,-14.7,1.07,0.36,-0.11,2218.0,,,5713.41,5675.74,-1836.94,12.3,1.0,q1_q17_dr25_tce,6173.0,193.0,-236.0,4.447,0.056,-0.224,1.041,0.341,-0.114,294.16489,47.176281,15.385,terraneo,caliente,terraneo_caliente
9562,10155286,K07988.01,,CANDIDATE,CANDIDATE,0.092,0,0,0,0,333.486169,4.235000e-03,-4.235000e-03,153.615010,0.005070,-0.005070,0.214,0.255,-0.214,3.19900,0.22900,-0.22900,639.1,52.7,-52.7,19.30,0.55,-4.68,557.0,,,22.68,2.07,-10.95,14.0,1.0,q1_q17_dr25_tce,4989.0,39.0,-128.0,2.992,0.030,-0.027,7.824,0.223,-1.896,296.76288,47.145142,10.998,joviano,tibio,joviano_tibio



## 6) Feature Selection (Auto for KOI/PS)


In [None]:

def present(df, cols):
    return [c for c in cols if c in df.columns]

# Candidates per table
NUM_KOI = [
    "koi_prad","koi_teq","koi_period","koi_insol","koi_model_snr","koi_score",
    "koi_steff","koi_srad","ra","dec","koi_kepmag",
]
CAT_KOI = [
    # Avoid koi_disposition/koi_pdisposition to prevent leakage if you ever do binary classification
    "koi_tce_delivname",
]

NUM_PS = [
    "pl_rade","pl_bmasse","pl_bmassj","pl_orbper","pl_orbsmax","pl_eqt",
    "st_teff","st_rad","st_mass","st_lum","sy_dist","sy_pnum","sy_snum",
]
CAT_PS = ["discoverymethod","disc_year","discoverylocale","facility"]

num_candidates = NUM_KOI if IS_KOI else NUM_PS
cat_candidates = CAT_KOI if IS_KOI else CAT_PS

num_cols = present(df, num_candidates)
cat_cols = present(df, cat_candidates)

# Fallbacks
if not num_cols and not cat_cols:
    num_cols = df.select_dtypes(include=["number"]).columns.tolist()

print("num_cols:", num_cols)
print("cat_cols:", cat_cols)


num_cols: ['koi_prad', 'koi_teq', 'koi_period', 'koi_insol', 'koi_model_snr', 'koi_score', 'koi_steff', 'koi_srad', 'ra', 'dec', 'koi_kepmag']
cat_cols: []



## 7) Preprocessing Pipelines


In [None]:

transformers = []
if num_cols:
    transformers.append(("num",
                         Pipeline([("imputer", SimpleImputer(strategy="median")),
                                   ("scaler", StandardScaler())]),
                         num_cols))
if cat_cols:
    transformers.append(("cat",
                         Pipeline([("imputer", SimpleImputer(strategy="most_frequent")),
                                   ("onehot", OneHotEncoder(handle_unknown="ignore", min_frequency=0.01))]),
                         cat_cols))

if not transformers:
    raise RuntimeError("No features available for modeling.")

preprocessor = ColumnTransformer(transformers=transformers,
                                 remainder="drop",
                                 verbose_feature_names_out=False)



## 8) Train/Test Split & Train Type Classifier


In [None]:

df_type = df[df["type_label"].notna()].copy()
X_type  = df_type[num_cols + cat_cols]
y_type  = df_type["type_label"].astype(str)

# Filter ultra-rare classes for stability
vc = y_type.value_counts()
rare = vc[vc < 5].index
mask = ~y_type.isin(rare)
if mask.sum() < len(y_type):
    print(f"Excluding {len(y_type) - mask.sum()} rows due to rare classes (<5 samples): {list(rare)}")
X_type = X_type[mask]
y_type = y_type[mask]

Xt_train, Xt_test, yt_train, yt_test = train_test_split(
    X_type, y_type, test_size=0.2, stratify=y_type, random_state=42
)
print("Shapes:", Xt_train.shape, Xt_test.shape)

clf_type = Pipeline([("prep", preprocessor),
                     ("clf", HistGradientBoostingClassifier(random_state=42))])
clf_type.fit(Xt_train, yt_train)

yt_pred = clf_type.predict(Xt_test)
print("== Type classification report ==")
print(classification_report(yt_test, yt_pred, digits=4))
print("Confusion matrix:\n", confusion_matrix(yt_test, yt_pred))


Excluding 1 rows due to rare classes (<5 samples): ['neptuniano_frio']
Shapes: (7360, 11) (1840, 11)
== Type classification report ==
                       precision    recall  f1-score   support

     joviano_caliente     0.6935    0.1150    0.1972       374
         joviano_frio     0.0000    0.0000    0.0000         3
     joviano_templado     0.0000    0.0000    0.0000        27
        joviano_tibio     0.3167    0.9521    0.4753       167
mini_neptuno_caliente     0.0385    0.0122    0.0185        82
    mini_neptuno_frio     0.0000    0.0000    0.0000         3
mini_neptuno_templado     0.0000    0.0000    0.0000        34
   mini_neptuno_tibio     0.0000    0.0000    0.0000       121
  neptuniano_caliente     0.0000    0.0000    0.0000        28
  neptuniano_templado     0.0000    0.0000    0.0000         8
     neptuniano_tibio     0.4737    0.2571    0.3333        35
subterrestre_caliente     0.0000    0.0000    0.0000        87
    subterrestre_frio     0.0000    0.0000    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



## 9) (Optional) Binary Classifier — Is Exoplanet?
> Provide your binary label column name if available (1/0). For KOI `koi_disposition` is **not** a clean binary ground truth; avoid leakage.


In [None]:

BINARY_LABEL_COL = None  # e.g., "is_planet" if your dataset truly includes it

clf_bin = None
if BINARY_LABEL_COL and BINARY_LABEL_COL in df.columns:
    X_bin = df[num_cols + cat_cols]
    y_bin = df[BINARY_LABEL_COL].astype(int)
    Xb_train, Xb_test, yb_train, yb_test = train_test_split(
        X_bin, y_bin, test_size=0.2, stratify=y_bin, random_state=42
    )
    clf_bin = Pipeline([("prep", preprocessor),
                        ("clf", HistGradientBoostingClassifier(random_state=42))])
    clf_bin.fit(Xb_train, yb_train)
    yb_pred = clf_bin.predict(Xb_test)
    print("== Binary classification report ==")
    print(classification_report(yb_test, yb_pred, digits=4))
    print("Confusion matrix:\n", confusion_matrix(yb_test, yb_pred))
else:
    print("No valid binary label provided; skipping binary model.")


No valid binary label provided; skipping binary model.



## 10) Save Models & Metadata


In [None]:

os.makedirs("modelos", exist_ok=True)
joblib.dump(clf_type, "modelos/clf_exoplanet_type.joblib")
if clf_bin is not None:
    joblib.dump(clf_bin, "modelos/clf_is_exoplanet.joblib")

meta = {
    "num_cols": num_cols,
    "cat_cols": cat_cols,
    "classes_type": sorted(y_type.unique().tolist()),
    "is_koi": bool(IS_KOI),
    "is_ps": bool(IS_PS),
    "col_alias": {
        "R_earth": COL_PL_RADE,
        "M_earth": COL_PL_BMASSE,
        "M_jup": COL_PL_BMASSJ,
        "T_eq": COL_PL_EQT,
        "period_days": COL_PERIOD,
        "st_teff": COL_ST_TEFF,
        "st_rad": COL_ST_RAD,
        "insol": COL_INSOL,
    },
}
joblib.dump(meta, "modelos/metadata.joblib")
print("Saved models & metadata to ./modelos")


Saved models & metadata to ./modelos



## 11) Inference Helpers


In [None]:

def load_models(model_dir: str = "modelos"):
    clf_bin = None
    p_bin = os.path.join(model_dir, "clf_is_exoplanet.joblib")
    if os.path.exists(p_bin):
        clf_bin = joblib.load(p_bin)
    clf_type = joblib.load(os.path.join(model_dir, "clf_exoplanet_type.joblib"))
    meta = joblib.load(os.path.join(model_dir, "metadata.joblib"))
    return clf_bin, clf_type, meta

def predict_exoplanet(example: dict, model_dir: str = "modelos") -> dict:
    clf_bin, clf_type, meta = load_models(model_dir)
    cols = meta["num_cols"] + meta["cat_cols"]
    X = pd.DataFrame([{c: example.get(c, np.nan) for c in cols}])
    out = {}
    if clf_bin is not None:
        proba = clf_bin.predict_proba(X)[0, 1]
        out["is_exoplanet"] = int(proba >= 0.5)
        out["is_exoplanet_proba"] = float(proba)
    proba_type = clf_type.predict_proba(X)[0]
    pred_type = clf_type.predict(X)[0]
    classes = clf_type.named_steps["clf"].classes_
    topk = np.argsort(proba_type)[::-1][:3]
    out["type"] = str(pred_type)
    out["type_top3"] = [(str(classes[i]), float(proba_type[i])) for i in topk]
    return out



## 12) Example Inference


In [None]:

# Example uses either koi_* or pl_* depending on your dataset.
example = {}
if IS_KOI:
    example = {
        "koi_prad": 11.2,
        "koi_teq": 1400,
        "koi_period": 3.5,
        "koi_model_snr": 12.0,
        "koi_steff": 5600.0,
        "koi_srad": 1.0,
    }
else:
    example = {
        "pl_rade": 11.2,
        "pl_eqt": 1400,
        "pl_orbper": 3.5,
        "st_teff": 5600.0,
        "st_rad": 1.0,
    }

predict_exoplanet(example, model_dir="modelos")


{'type': 'super_tierra_caliente',
 'type_top3': [('super_tierra_caliente', 1.0),
  ('terraneo_templado', 0.0),
  ('terraneo_tibio', 0.0)]}