# TESS → ExoNet-XS (Multi-target): build a dataset with Global/Local views

This notebook creates a **multi-target pipeline**:

* Downloads **TESScut TPFs** for a list of targets
* Extracts a **light curve** using a **threshold** or **circular** aperture mask
* Applies **detrending** with **CBVs** (if available) or **PLD** as a fallback
* Runs **BLS** to obtain **period/epoch/duration**
* Builds **Global/Local views** + **centroids**
* **Saves** a standardized dataset (`index.csv` + per-*sample* folders with `global.npy`, `local.npy`, `params.npy`, `meta.json`)
* Includes a **PyTorch Dataset** and a **training skeleton** (uses labels if you provide them)


## 0) Install dependencies (if needed)

In [None]:
# Se precisar:
!pip install --upgrade pip
!pip install numpy scipy astropy photutils lightkurve astroquery pandas torch torchvision matplotlib tqdm scikit-image


Collecting pip
  Downloading pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.2-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m29.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.2
Collecting astropy
  Downloading astropy-7.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting photutils
  Downloading photutils-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (5.4 kB)
Collecting lightkurve
  Downloading lightkurve-2.5.1-py3-none-any.whl.metadata (6.3 kB)
Collecting astroquery
  Downloading astroquery-0.4.11-py3-none-any.whl.metadata (6.5 kB)
Collecting pyerfa>=2.0.1.1 (from astropy)
  Downloading pyerfa-2.0.1.5-cp39-abi3-manylinux_2_17_x

## 1) General settings

In [None]:
from pathlib import Path

# Targets to process (names or TIC/TOI identifiers). Each entry is handed to
# download_tesscut(), which resolves it to sky coordinates. Edit/expand freely:
TARGETS = [
    "TOI 700",
    "pi Mensae",
    "Kepler-10",
    # "TIC 150428135",
]

# (Optional) Force a single sector per target (None => all available)
FORCED_SECTOR = None  # e.g., 13

# TESScut cutout (in pixels, square side length)
CUTOUT_SIZE = 15

# Aperture mask: "threshold" OR "circular".
# build_aperture_mask() treats any value other than "threshold" as circular.
APERTURE_MODE = "threshold"
THRESHOLD = 3       # sigma for threshold mode
CIRC_RADIUS = 3     # px for circular mode

# Detrending: detrend_lightcurve() tries CBVs first; if that fails (or is disabled)
# it falls back to PLD, and finally to the uncorrected light curve.
USE_CBVS = True
USE_PLD  = True

# BLS search grid: N_PERIODS linearly spaced trial periods between P_MIN and P_MAX.
P_MIN, P_MAX = 0.5, 30.0  # days
N_PERIODS    = 20000
# NOTE(review): run_bls() forwards DUR_FRAC unchanged as the `duration` argument of
# BoxLeastSquares.power(). astropy documents that argument in the units of the time
# axis (days here), so 0.02 is presumably a fixed 0.02-day trial duration and not a
# fraction of the period, despite the name — confirm the intended meaning.
DUR_FRAC     = 0.02

# Output directories (created on the spot if they do not exist yet)
DATA_DIR    = Path("dataset_tess_exonetxs")
CUTS_DIR    = DATA_DIR / "tpf"       # downloaded TESScut FITS files
SAMPLES_DIR = DATA_DIR / "samples"   # one sub-folder per saved sample
for d in [DATA_DIR, CUTS_DIR, SAMPLES_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# (Optional) Path to CSV with labels (columns: target,label [,sector])
# label: 1=planet/positive; 0=false/negative; -1=unknown
LABELS_CSV = None  # e.g., "labels.csv"

print("Config OK.")


Config OK.


## 2) Helper functions (download, aperture mask, detrending, BLS, views)

In [None]:
import json
from pathlib import Path

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.timeseries import BoxLeastSquares
from astroquery.mast import Tesscut
from lightkurve import open as lk_open
from lightkurve import read as lk_read
from lightkurve import search_tesscut
from lightkurve.correctors import CBVCorrector, PLDCorrector

def download_tesscut(target_name: str, sector=None, size=15, out_dir=Path(".")):
    """Fetch TESScut cutouts for one target and write each one to a FITS file.

    Parameters
    ----------
    target_name : str
        Name to resolve. ``SkyCoord.from_name`` is tried first; if it raises,
        a ``search_tesscut`` query is used to obtain coordinates instead.
    sector : int or None
        Forwarded to the TESScut search/download calls (None => no sector filter).
    size : int
        Cutout size forwarded to ``Tesscut.get_cutouts``.
    out_dir : path-like
        Destination directory; created (with parents) when missing.

    Returns
    -------
    list of str
        Paths of the FITS files written, in the order TESScut returned them.

    Raises
    ------
    RuntimeError
        If the fallback search finds nothing, or TESScut returns no cutouts.
    """
    destination = Path(out_dir)
    destination.mkdir(parents=True, exist_ok=True)

    # Step 1: turn the target name into sky coordinates.
    try:
        coord = SkyCoord.from_name(target_name)
    except Exception:
        # Name resolution failed; fall back to a TESScut search.
        search_result = search_tesscut(target_name, sector=sector)
        if len(search_result) == 0:
            raise RuntimeError(f"Nothing found in TESScut for '{target_name}'.")
        # NOTE(review): assumes the search table exposes target_ra/target_dec
        # columns in degrees — confirm against the installed lightkurve version.
        first_row = search_result.table[0]
        coord = SkyCoord(first_row["target_ra"], first_row["target_dec"], unit=(u.deg, u.deg))

    # Step 2: request the cutouts themselves.
    cutouts = Tesscut.get_cutouts(coordinates=coord, size=size, sector=sector)
    if cutouts is None or len(cutouts) == 0:
        raise RuntimeError(f"TESScut returned no cutouts for '{target_name}' (sector={sector}).")

    # Step 3: persist every cutout. The file name carries the *requested* sector
    # (or 'ALL') plus a running index, not the sector of the individual cutout.
    written = []
    for i, cutout in enumerate(cutouts):
        fits_path = destination / f"{target_name.replace(' ', '_')}_sector{sector if sector is not None else 'ALL'}_{i:02d}.fits"
        cutout.writeto(fits_path, overwrite=True)
        written.append(str(fits_path))
    return written

def build_aperture_mask(tpf, mode="threshold", threshold=3, circ_radius=3):
    """Return an aperture mask for ``tpf``.

    ``mode`` is compared case-insensitively with "threshold"; in that case the
    mask comes from ``tpf.create_threshold_mask(threshold=threshold)``. Any other
    value yields a boolean circular mask of radius ``circ_radius`` (pixels)
    centred on the flux-weighted centroid of the time-averaged image.
    """
    wants_threshold = mode.lower() == "threshold"
    if not wants_threshold:
        # Time-averaged image, ignoring NaN pixels.
        mean_image = np.nanmean(tpf.flux.value, axis=0)
        rows, cols = np.indices(mean_image.shape)
        # The small constant keeps the division defined for an all-zero image.
        total_flux = np.nansum(mean_image) + 1e-8
        col_center = np.nansum(mean_image * cols) / total_flux
        row_center = np.nansum(mean_image * rows) / total_flux
        dist_sq = (cols - col_center)**2 + (rows - row_center)**2
        return (dist_sq <= (circ_radius**2))
    return tpf.create_threshold_mask(threshold=threshold)

def detrend_lightcurve(tpf, mask, use_cbvs=True, use_pld=True):
    """Extract a light curve from ``tpf`` and detrend it on a best-effort basis.

    The raw light curve is ``tpf.to_lightcurve(aperture_mask=mask)`` with NaNs
    removed. If ``use_cbvs`` is set, ``CBVCorrector`` is tried first; if that is
    disabled or raises, and ``use_pld`` is set, ``PLDCorrector`` is tried. When
    neither produces a result the raw light curve is used as is. Failures are
    printed, never raised.

    Returns
    -------
    (raw_lc, lc)
        The NaN-free raw light curve and the chosen light curve after
        ``normalize(unit="ppm")`` and ``remove_outliers(sigma=5)``.
    """
    raw_lc = tpf.to_lightcurve(aperture_mask=mask).remove_nans()
    detrended = None

    # First choice: cotrending basis vectors.
    if use_cbvs:
        try:
            detrended = CBVCorrector(raw_lc).correct()
            print("CBVCorrector OK.")
        except Exception as err:
            print("CBVCorrector unavailable:", err)

    # Second choice: pixel-level decorrelation, only if CBVs gave nothing.
    if use_pld and detrended is None:
        try:
            detrended = PLDCorrector(tpf).correct(aperture_mask=mask)
            print("PLDCorrector OK (fallback).")
        except Exception as err:
            print("PLDCorrector failed:", err)

    # Last resort: carry on with the uncorrected light curve.
    chosen = raw_lc if detrended is None else detrended
    normalized = chosen.normalize(unit="ppm")
    lc = normalized.remove_outliers(sigma=5)
    return raw_lc, lc

def run_bls(time_mjd, flux_norm, p_min, p_max, n_periods, dur_frac):
    """Run a Box Least Squares search and return the parameters of the strongest peak.

    Parameters
    ----------
    time_mjd : array-like
        Sample times (the caller in this notebook passes MJD values, i.e. days).
    flux_norm : array-like
        Flux values; they are divided by their median before the search.
    p_min, p_max : float
        Bounds of the linearly spaced trial-period grid.
    n_periods : int
        Number of trial periods.
    dur_frac : float
        Forwarded unchanged as the ``duration`` argument of
        ``BoxLeastSquares.power``.
        NOTE(review): astropy documents that argument in the units of the time
        axis, so this is presumably a fixed trial duration in days rather than
        a fraction of the period, despite the name — confirm the intent.

    Returns
    -------
    (P, t0, dur, depth, power)
        Period, transit time, duration and depth at the maximum of the BLS
        power (as floats), plus the full periodogram result object.

    Raises
    ------
    ValueError
        If there is no finite (time, flux) pair to search, or the periodogram
        contains no finite power value.
    """
    t = np.asarray(time_mjd, dtype=float)
    f = np.asarray(flux_norm, dtype=float)

    # Non-finite samples would propagate NaN through the whole periodogram,
    # so drop them up front and fail with a clear message if nothing is left.
    finite = np.isfinite(t) & np.isfinite(f)
    if not finite.any():
        raise ValueError("run_bls: no finite (time, flux) samples to search.")
    t, f = t[finite], f[finite]

    y = f / np.median(f)
    bls = BoxLeastSquares(t, y)
    periods = np.linspace(p_min, p_max, n_periods)
    power = bls.power(periods, dur_frac)

    if not np.isfinite(power.power).any():
        raise ValueError("run_bls: BLS periodogram contains no finite power value.")
    i = int(np.nanargmax(power.power))

    P = float(power.period[i])
    t0 = float(power.transit_time[i])
    dur = float(power.duration[i])
    depth = float(power.depth[i])
    return P, t0, dur, depth, power

def phase_fold(t, y, P, t0):
    """Fold ``t`` on period ``P`` around epoch ``t0`` and sort by phase.

    Phases lie in [-0.5, 0.5) with ``t0`` mapped to phase 0. Returns the sorted
    phases and ``y`` reordered the same way, so ``y`` must be index-aligned
    with ``t``.
    """
    shifted = t - t0 + 0.5 * P
    phase = np.mod(shifted, P) / P - 0.5
    by_phase = np.argsort(phase)
    return phase[by_phase], y[by_phase]

def resample_uniform(x, y, n, xmin, xmax):
    """Linearly interpolate ``y(x)`` onto ``n`` evenly spaced points in [xmin, xmax].

    Grid points outside the range of ``x`` take the first / last value of ``y``.
    ``np.interp`` is used, so ``x`` is expected to be in increasing order.

    Returns the grid and the interpolated values.
    """
    targets = np.linspace(xmin, xmax, n)
    first_value, last_value = y[0], y[-1]
    resampled = np.interp(targets, x, y, left=first_value, right=last_value)
    return targets, resampled

def build_views(times_mjd, rel_flux, P, t0, dur, centroids=None, n_points=2001, local_k=2.0):
    """Build the phase-folded global and local views of one light curve.

    The global view covers phases [-0.5, 0.5]; the local view covers
    [-w, +w] with ``w = local_k * dur / P``. Both are resampled to
    ``n_points`` samples. Flux channels are centred on the median and scaled
    by the inter-quartile range; centroid channels (if given) are z-scored.

    Parameters
    ----------
    times_mjd, rel_flux : array-like
        Index-aligned time and flux samples.
    P, t0, dur : float
        Period, epoch and duration used for folding and for the local window.
    centroids : tuple (cx, cy) or None
        Optional centroid series. They are reordered with the sort order
        derived from ``times_mjd``, so they must be sampled on exactly the same
        cadences as ``times_mjd``.
    n_points : int
        Samples per view.
    local_k : float
        Half-width of the local view in units of ``dur / P``.

    Returns
    -------
    (global_view, local_view)
        float32 arrays of shape (1, n_points) without centroids, or
        (3, n_points) with channels [flux, cx, cy] when centroids are given.
    """
    half_window = local_k * (dur / P)

    def _fold_and_resample(series):
        # Fold once, then sample the same folded curve on both phase windows.
        phase, values = phase_fold(times_mjd, series, P, t0)
        _, wide = resample_uniform(phase, values, n_points, -0.5, 0.5)
        _, narrow = resample_uniform(phase, values, n_points, -half_window, +half_window)
        return wide, narrow

    def _robust_scale(v):
        # Median / IQR scaling; the floor avoids dividing by ~0 on flat input.
        center = np.nanmedian(v)
        q25, q75 = np.nanpercentile(v, [25, 75])
        spread = max(q75 - q25, 1e-6)
        return (v - center) / spread

    def _zscore(v):
        mean = np.nanmean(v)
        std = np.nanstd(v) + 1e-6
        return (v - mean)/std

    flux_wide, flux_narrow = _fold_and_resample(rel_flux)
    flux_wide = _robust_scale(flux_wide)
    flux_narrow = _robust_scale(flux_narrow)

    if centroids is None:
        return flux_wide[None, :].astype("float32"), flux_narrow[None, :].astype("float32")

    cx, cy = centroids
    cx_wide, cx_narrow = _fold_and_resample(cx)
    cy_wide, cy_narrow = _fold_and_resample(cy)

    global_channels = [flux_wide, _zscore(cx_wide), _zscore(cy_wide)]
    local_channels = [flux_narrow, _zscore(cx_narrow), _zscore(cy_narrow)]
    global_view = np.stack(global_channels, axis=0).astype("float32")
    local_view = np.stack(local_channels, axis=0).astype("float32")
    return global_view, local_view

def compute_centroids_from_tpf(tpf):
    """Return the flux-weighted centroid (column, row) of every frame in ``tpf``.

    ``tpf.flux.value`` is treated as a (frame, row, column) cube; NaN pixels are
    ignored. The outputs are in pixel-index units within the cutout and have
    one entry per frame of the cube.
    """
    cube = tpf.flux.value
    n_rows, n_cols = cube.shape[1:]
    rows, cols = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")

    # The small constant keeps the division defined for frames that sum to zero.
    frame_total = np.nansum(cube, axis=(1, 2)) + 1e-8
    weighted_cols = np.nansum(cube * cols, axis=(1, 2))
    weighted_rows = np.nansum(cube * rows, axis=(1, 2))
    return weighted_cols / frame_total, weighted_rows / frame_total

def save_sample(sample_dir, global_view, local_view, params, meta):
    """Write one sample to ``sample_dir`` (created, with parents, if missing).

    Files written: ``global.npy``, ``local.npy``, ``params.npy`` (via ``np.save``)
    and ``meta.json`` (UTF-8, non-ASCII characters kept, 2-space indent).
    """
    target_dir = Path(sample_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    arrays = (
        ("global.npy", global_view),
        ("local.npy", local_view),
        ("params.npy", params),
    )
    for filename, array in arrays:
        np.save(target_dir / filename, array)

    meta_file = target_dir / "meta.json"
    with meta_file.open("w", encoding="utf-8") as handle:
        json.dump(meta, handle, ensure_ascii=False, indent=2)




## 3) (Optional) Load labels from a CSV

In [None]:
import pandas as pd

# Maps (target, sector-or-None) -> integer label. Stays empty when no CSV is configured.
labels_map = {}
if not LABELS_CSV:
    print("No labels CSV defined; samples will be saved with label=-1 (unknown). Use index.csv to label them later.")
else:
    df_labels = pd.read_csv(LABELS_CSV)
    # Expected columns: target,label[,sector]
    for _, row in df_labels.iterrows():
        # A missing column or an empty cell means "label applies to any sector".
        has_sector = 'sector' in row and pd.notna(row['sector'])
        target_name = str(row['target']).strip()
        sector_key = int(row['sector']) if has_sector else None
        labels_map[(target_name, sector_key)] = int(row['label'])
    print(f"Loaded labels: {len(labels_map)} entries")


Nenhum CSV de rótulos definido; os samples serão salvos com label=-1 (desconhecido). Use o index.csv para rotular depois.


## 4) Main loop: build the dataset

In [None]:
import uuid
import pandas as pd

# Column order of index.csv. Passing it explicitly keeps the header present even
# when no target could be processed, so later cells can still read the file.
INDEX_COLUMNS = [
    "sample_id", "target", "sector", "label",
    "period", "t0", "duration", "depth",
    "global_path", "local_path", "params_path", "meta_path", "tpf_path",
]


def _read_sector(tpf):
    """Return the SECTOR keyword of the first HDU that carries it, as int, else None."""
    for hdu in tpf.hdu:
        # FITS keywords are header items, not Python attributes, so use .get().
        value = hdu.header.get("SECTOR")
        if value is None:
            continue
        try:
            return int(value)
        except (TypeError, ValueError):
            continue
    return None


def _align_to_times(src_time, src_values, dst_time):
    """Interpolate per-cadence TPF values onto the light-curve time stamps.

    The cleaned light curve has fewer cadences than the TPF (NaNs and outliers
    are removed), so TPF-derived series must be re-sampled before they are
    folded together with the light-curve times.
    """
    src_time = np.asarray(src_time, dtype=float)
    src_values = np.asarray(src_values, dtype=float)
    order = np.argsort(src_time)  # np.interp needs increasing abscissae
    return np.interp(np.asarray(dst_time, dtype=float), src_time[order], src_values[order])


index_rows = []
for target in TARGETS:
    try:
        # 1) Download TESScut (one file per available sector)
        files = download_tesscut(target, sector=FORCED_SECTOR, size=CUTOUT_SIZE, out_dir=CUTS_DIR)
        if len(files) == 0:
            print(f"[WARN] Sem TPFs para {target}"); continue

        # 2) Open the first TPF only (iterate over `files` to get more samples).
        #    lightkurve.read replaces the deprecated lightkurve.open.
        tpf = lk_read(files[0])

        # Sector from the FITS headers, if present
        try:
            sector_hdr = _read_sector(tpf)
        except Exception:
            sector_hdr = None

        # 3) Aperture mask
        mask = build_aperture_mask(tpf, mode=APERTURE_MODE, threshold=THRESHOLD, circ_radius=CIRC_RADIUS)

        # 4) Detrending
        raw_lc, lc = detrend_lightcurve(tpf, mask, use_cbvs=USE_CBVS, use_pld=USE_PLD)
        time_mjd = lc.time.to_value("mjd")
        flux     = lc.flux.value

        # 5) BLS
        P, t0, dur, depth, power = run_bls(time_mjd, flux, P_MIN, P_MAX, N_PERIODS, DUR_FRAC)

        # 6) Views + centroids. Centroids are computed on every TPF cadence and
        #    then brought onto the light-curve cadences so that all channels of
        #    a view refer to the same time stamps.
        cx_all, cy_all = compute_centroids_from_tpf(tpf)
        tpf_time_mjd = tpf.time.to_value("mjd")
        cx = _align_to_times(tpf_time_mjd, cx_all, time_mjd)
        cy = _align_to_times(tpf_time_mjd, cy_all, time_mjd)
        ynorm = flux / np.nanmedian(flux)
        gview, lview = build_views(time_mjd, ynorm, P, t0, dur, centroids=(cx, cy))

        # 7) Params/Meta
        params = np.zeros(4, dtype=np.float32)  # placeholder for T_eff, log g, R*, Fe/H
        sector = sector_hdr if sector_hdr is not None else FORCED_SECTOR
        meta = {
            "target": target,
            "sector": sector,
            "period": float(P), "t0": float(t0), "duration": float(dur), "depth": float(depth),
            "n_points": int(gview.shape[-1]), "local_k": 2.0
        }

        # 8) Label: sector-specific entry first, then the sector-agnostic one
        key1 = (target, sector)
        key2 = (target, None)
        label = labels_map.get(key1, labels_map.get(key2, -1))
        meta["label"] = int(label)

        # 9) Save sample
        sample_id = f"{target.replace(' ','_')}_{(sector_hdr if sector_hdr is not None else 'ALL')}_{uuid.uuid4().hex[:8]}"
        sample_dir = SAMPLES_DIR / sample_id
        save_sample(sample_dir, gview, lview, params, meta)

        # 10) Index
        index_rows.append({
            "sample_id": sample_id,
            "target": target,
            "sector": meta["sector"],
            "label": meta["label"],
            "period": meta["period"],
            "t0": meta["t0"],
            "duration": meta["duration"],
            "depth": meta["depth"],
            "global_path": str(sample_dir / "global.npy"),
            "local_path": str(sample_dir / "local.npy"),
            "params_path": str(sample_dir / "params.npy"),
            "meta_path": str(sample_dir / "meta.json"),
            "tpf_path": files[0],
        })

        print(f"OK: {target} (sector={meta['sector']}) → sample_id={sample_id}")
    except Exception as e:
        # Best effort: report the failure and continue with the next target.
        print(f"[ERRO] {target}: {e}")

# Save index.csv
df_index = pd.DataFrame(index_rows, columns=INDEX_COLUMNS)
df_index.to_csv(DATA_DIR / "index.csv", index=False)
print(f"Dataset pronto: {len(df_index)} samples. Index salvo em {DATA_DIR/'index.csv'}")


        Use read() instead.
  tpf = lk_open(files[0])
The SingleScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The MultiScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The MultiScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The MultiScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The Spike CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
  return function(*args, **kwargs)


CBVCorrector indisponível: SVD did not converge
PLDCorrector OK (fallback).
OK: TOI 700 (sector=None) → sample_id=TOI_700_ALL_4c09077b


        Use read() instead.
  tpf = lk_open(files[0])
The SingleScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The MultiScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The MultiScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The MultiScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The Spike CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
  return function(*args, **kwargs)


CBVCorrector indisponível: SVD did not converge
PLDCorrector OK (fallback).
OK: pi Mensae (sector=None) → sample_id=pi_Mensae_ALL_953947c2


        Use read() instead.
  tpf = lk_open(files[0])
The SingleScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The MultiScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The MultiScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The MultiScale CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
The Spike CBVs do not appear to be well aligned to the light curve. Consider using "interpolate_cbvs=True"
  return function(*args, **kwargs)


CBVCorrector indisponível: SVD did not converge
PLDCorrector OK (fallback).
OK: Kepler-10 (sector=None) → sample_id=Kepler-10_ALL_58822b3f
Dataset pronto: 3 samples. Index salvo em dataset_tess_exonetxs/index.csv


## 5) View dataset summary

In [None]:
import pandas as pd
from pathlib import Path

# Show the first rows of the index written by the dataset-building section.
idx_path = Path("dataset_tess_exonetxs/index.csv")
if not idx_path.exists():
    print("index.csv não encontrado.")
else:
    df = pd.read_csv(idx_path)
    display(df.head(10))
    print("Total:", len(df))


Unnamed: 0,sample_id,target,sector,label,period,t0,duration,depth,global_path,local_path,params_path,meta_path,tpf_path
0,TOI_700_ALL_4c09077b,TOI 700,,-1,13.085329,58334.499023,0.02,0.008277,dataset_tess_exonetxs/samples/TOI_700_ALL_4c09...,dataset_tess_exonetxs/samples/TOI_700_ALL_4c09...,dataset_tess_exonetxs/samples/TOI_700_ALL_4c09...,dataset_tess_exonetxs/samples/TOI_700_ALL_4c09...,dataset_tess_exonetxs/tpf/TOI_700_sectorALL_00...
1,pi_Mensae_ALL_953947c2,pi Mensae,,-1,4.633157,58325.220261,0.02,0.002162,dataset_tess_exonetxs/samples/pi_Mensae_ALL_95...,dataset_tess_exonetxs/samples/pi_Mensae_ALL_95...,dataset_tess_exonetxs/samples/pi_Mensae_ALL_95...,dataset_tess_exonetxs/samples/pi_Mensae_ALL_95...,dataset_tess_exonetxs/tpf/pi_Mensae_sectorALL_...
2,Kepler-10_ALL_58822b3f,Kepler-10,,-1,9.874094,58687.421371,0.02,0.00228,dataset_tess_exonetxs/samples/Kepler-10_ALL_58...,dataset_tess_exonetxs/samples/Kepler-10_ALL_58...,dataset_tess_exonetxs/samples/Kepler-10_ALL_58...,dataset_tess_exonetxs/samples/Kepler-10_ALL_58...,dataset_tess_exonetxs/tpf/Kepler-10_sectorALL_...


Total: 3


## 6) PyTorch Dataset + training (uses only samples with label ∈ {0,1})

In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class ExoNetXSDataset(Dataset):
    """Dataset backed by an ``index.csv`` file with per-sample ``.npy`` paths.

    Each item is ``(global_view, local_view, params, label)``: three float32
    tensors loaded from the ``global_path``, ``local_path`` and ``params_path``
    columns, plus an int64 scalar label (0 or 1; any other value becomes -1).
    With ``supervised_only=True`` only rows whose label is 0 or 1 are kept.
    """

    def __init__(self, index_csv, supervised_only=True):
        import pandas as pd
        table = pd.read_csv(index_csv)
        if supervised_only:
            is_labelled = table['label'].isin([0,1])
            table = table[is_labelled].reset_index(drop=True)
        self.df = table
        self.supervised_only = supervised_only

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        # global / local views are stored as (C, N), params as (D,)
        columns = ('global_path', 'local_path', 'params_path')
        arrays = [np.load(row[col]).astype('float32') for col in columns]
        raw_label = row['label']
        target = int(raw_label) if raw_label in [0,1] else -1
        tensors = [torch.from_numpy(a) for a in arrays]
        return (*tensors, torch.tensor(target, dtype=torch.long))

class ConvBlock1D(nn.Module):
    """Two 1-D convolutions, each followed by an in-place ReLU.

    The first convolution maps ``in_ch`` to ``out_ch`` with stride ``s``; the
    second keeps ``out_ch`` channels with stride 1. Both use kernel size ``k``
    and padding ``p``.
    """

    def __init__(self, in_ch, out_ch, k=5, s=1, p=2):
        super().__init__()
        first_conv = nn.Conv1d(in_ch, out_ch, kernel_size=k, stride=s, padding=p)
        second_conv = nn.Conv1d(out_ch, out_ch, kernel_size=k, stride=1, padding=p)
        layers = [
            first_conv,
            nn.ReLU(inplace=True),
            second_conv,
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class ExoNetXS(nn.Module):
    """Two-branch 1-D CNN over the global and local views, fused with scalar params.

    The global branch stacks three ``ConvBlock1D`` stages (16, 32, 64 channels),
    the local branch two (16, 32 channels). Each branch ends in a global max
    pool, and the pooled features are concatenated with ``p`` and passed through
    a small MLP head that outputs ``n_classes`` logits.

    ``forward(g, l, p)`` takes batched views of shape (batch, channels, length)
    and params of shape (batch, params_dim).
    """

    def __init__(self, in_ch_global=1, in_ch_local=1, params_dim=4, n_classes=2):
        super().__init__()
        # Global-view branch
        self.g1 = ConvBlock1D(in_ch_global, 16)
        self.g2 = ConvBlock1D(16, 32)
        self.g3 = ConvBlock1D(32, 64)
        self.gpool = nn.AdaptiveMaxPool1d(1)
        # Local-view branch
        self.l1 = ConvBlock1D(in_ch_local, 16)
        self.l2 = ConvBlock1D(16, 32)
        self.lpool = nn.AdaptiveMaxPool1d(1)
        # Head: pooled global (64) + pooled local (32) + scalar params
        fused_width = 64 + 32 + params_dim
        head_layers = [
            nn.Linear(fused_width, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(64, n_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, g, l, p):
        global_feat = g
        for stage in (self.g1, self.g2, self.g3):
            global_feat = stage(global_feat)
        global_feat = self.gpool(global_feat).squeeze(-1)

        local_feat = l
        for stage in (self.l1, self.l2):
            local_feat = stage(local_feat)
        local_feat = self.lpool(local_feat).squeeze(-1)

        fused = torch.cat([global_feat, local_feat, p], dim=1)
        return self.head(fused)

from pathlib import Path

# Seed for the train/validation split and the model initialisation.
TRAIN_SEED = 42

# Load the dataset (only if the index has been built already)
index_csv = "dataset_tess_exonetxs/index.csv"
ds = ExoNetXSDataset(index_csv, supervised_only=True) if Path(index_csv).exists() else None

if ds is None:
    print(f"{index_csv} not found — run the dataset-building section first.")
elif len(ds) == 0:
    print("Sem rótulos 0/1 no index.csv — adicione rótulos e reexecute.")
elif len(ds) < 2:
    # With a single sample the validation split would leave the training set empty.
    print("Only 1 labelled sample in index.csv — at least 2 are needed for a train/validation split.")
else:
    torch.manual_seed(TRAIN_SEED)

    # Split: 20% validation, but always at least one validation sample
    n = len(ds)
    n_val = max(1, int(0.2*n))
    tr, va = torch.utils.data.random_split(ds, [n-n_val, n_val])
    tr_loader = DataLoader(tr, batch_size=8, shuffle=True)
    va_loader = DataLoader(va, batch_size=8)

    # Infer channel counts / params dimension from the first sample
    g0, l0, p0, y0 = ds[0]
    model = ExoNetXS(in_ch_global=g0.shape[0], in_ch_local=l0.shape[0], params_dim=p0.shape[0], n_classes=2)

    # Class weights (balancing). ds.df already holds only the 0/1-labelled rows,
    # so there is no need to read the CSV a second time.
    df = ds.df
    n_pos = int((df['label'] == 1).sum())
    n_neg = int((df['label'] == 0).sum())
    total = n_pos + n_neg
    w_pos = total / (2*max(1, n_pos))
    w_neg = total / (2*max(1, n_neg))
    crit = nn.CrossEntropyLoss(weight=torch.tensor([w_neg, w_pos], dtype=torch.float32))

    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    for ep in range(5):
        # --- training ---
        model.train()
        run = 0.0
        for g, l, p, y in tr_loader:
            opt.zero_grad()
            logits = model(g, l, p)
            loss = crit(logits, y)
            loss.backward()
            opt.step()
            run += loss.item()*y.size(0)  # sum of per-sample losses
        tl = run/len(tr)

        # --- validation ---
        model.eval()
        corr = tot = 0
        with torch.no_grad():
            for g, l, p, y in va_loader:
                logits = model(g, l, p)
                pred = logits.argmax(1)
                corr += (pred == y).sum().item()
                tot += y.size(0)
        acc = corr/max(1, tot)
        print(f"Epoch {ep+1} | train_loss={tl:.4f} | val_acc={acc:.3f}")


FileNotFoundError: [Errno 2] No such file or directory: 'dataset_tess_exonetxs/index.csv'

## 7) How to label (examples)

In [None]:
# Create a simple CSV with the columns: target,label[,sector]
# Example:
# target,label,sector
# TOI 700,1,
# Kepler-10,1,
# pi Mensae,1,
# TIC 150428135,0,13
#
# Save it as labels.csv and point LABELS_CSV to that path in the settings section.


print("See the example in the comment above. You can also open dataset_tess_exonetxs/index.csv and add the 'label' column.")



Veja o exemplo no comentário acima. Você também pode abrir dataset_tess_exonetxs/index.csv e adicionar a coluna 'label'.
