In [57]:
import pandas as pd
from astroquery.sdss import SDSS
from astropy import coordinates as coords
from astropy import units as u
from tqdm import tqdm
from joblib import Parallel, delayed

from astroquery.sdss import SDSS
from astropy import units as u
from astropy import coordinates as coords
import numpy as np
import time

# Drop any locally cached astroquery SDSS results so this session queries the service afresh.
SDSS.clear_cache()

## Cross-match with `SDSS.query_crossid` (no SQL queries)

In [63]:
def nearest_sdss_galaxy(ra, dec, radius_arcmin=0.5, max_retries=10, backoff_s=1.0, max_backoff_s=60.0):
    """Cross-match one sky position against SDSS DR17 photometric objects.

    Parameters
    ----------
    ra, dec : float
        Position in degrees.
    radius_arcmin : float
        Search radius passed to ``SDSS.query_crossid``, in arcminutes.
    max_retries : int or None
        How many times a failed query (timeout, server error, ...) is retried
        before giving up; total attempts are ``max_retries + 1``. ``None``
        retries forever (the previous behaviour).
    backoff_s : float
        Wait before the first retry, in seconds; doubled after each failure.
    max_backoff_s : float
        Upper bound on the wait between retries, in seconds.

    Returns
    -------
    None
        When SDSS returns no object within the radius (e.g. the position lies
        outside the SDSS footprint).
    tuple (row, n_rows)
        The first row of the ``query_crossid`` result and the number of rows
        returned.

    Raises
    ------
    RuntimeError
        If the query still fails after ``max_retries`` retries. Failures are
        deliberately not reported as ``None`` so they cannot be mistaken for
        "no galaxy here" in downstream accuracy numbers.
    """
    pos = coords.SkyCoord(ra, dec, unit="deg")
    failures = 0
    delay = backoff_s
    while True:
        try:
            # Spectroscopic filtering (spectro=True with specobj_fields such as
            # zWarning/targetType/survey/primtarget) is intentionally disabled:
            # any photometric object within the radius counts as a match.
            result = SDSS.query_crossid(
                pos,
                radius=radius_arcmin * u.arcmin,
                photoobj_fields=['objid'],
                data_release=17,
            )
        except Exception as e:
            failures += 1
            print(f"Error con RA={ra}, DEC={dec}: {e}")
            if max_retries is not None and failures > max_retries:
                raise RuntimeError(
                    f"SDSS query for RA={ra}, DEC={dec} failed {failures} times; giving up"
                ) from e
            # NOTE(review): clear_cache() wipes the astroquery cache shared by every
            # thread of the joblib "threading" backend; presumably it is here to drop
            # a cached bad response before retrying — confirm it is still needed.
            SDSS.clear_cache()
            time.sleep(delay)
            delay = min(delay * 2, max_backoff_s)
            continue

        # An empty answer (e.g. position outside the footprint) is a normal
        # "no match", not an HTTP error, so it is returned without retrying.
        if result is None or len(result) == 0:
            return None

        # NOTE(review): assumes query_crossid lists the nearest object first —
        # confirm against the astroquery documentation.
        return result[0], len(result)

def _crossmatch_single_row(row, ra_col, dec_col, radius_arcmin):
    ra = float(row[ra_col])
    dec = float(row[dec_col])
    out = nearest_sdss_galaxy(ra, dec, radius_arcmin)

    if out is None:
        return {
            'sdss_objid': pd.NA,
            #'sdss_ra': None,
            #'sdss_dec': None,
            #'sdss_dist_arcsec': None,
        }
    else:
        (obj, matches) = out
        return {
            'sdss_objid': int(obj['objid']),
            'matches': int(matches),
            #'sdss_ra': float(obj['ra']),
            #'sdss_dec': float(obj['dec']),
            #'sdss_dist_arcsec': float(dist_arcsec),
        }

def sdss_crossmatch_joblib(df, ra_col='ra', dec_col='dec', radius_arcmin=0.5, n_jobs=-1):
    """Cross-match every row of ``df`` against SDSS, querying in parallel threads.

    Parameters
    ----------
    df : pandas.DataFrame
        Input table; one SDSS query is issued per row.
    ra_col, dec_col : str
        Columns holding RA / Dec in degrees.
    radius_arcmin : float
        Search radius forwarded to ``_crossmatch_single_row``.
    n_jobs : int
        Number of joblib worker threads (``-1`` lets joblib use all CPUs).

    Returns
    -------
    pandas.DataFrame
        One row per input row, sharing ``df``'s index so it can be compared
        with ``df`` or with other outputs built from the same frame. Always
        has ``sdss_objid`` and ``matches`` as nullable ``Int64`` columns, even
        when ``df`` is empty or nothing matched.
    """
    rows = df.to_dict("records")

    # Queries are network-bound, so threads are enough (no pickling of rows).
    # tqdm wraps the task generator, so it counts dispatched tasks; joblib only
    # dispatches a bounded number ahead, so the bar roughly tracks completion.
    results = Parallel(n_jobs=n_jobs, backend="threading")(
        delayed(_crossmatch_single_row)(row, ra_col, dec_col, radius_arcmin)
        for row in tqdm(rows)
    )

    df_out = pd.DataFrame(results, index=df.index)
    # Guarantee the schema when no result dict carried a key (empty input, or
    # per-row dicts that omit a column for misses).
    for col in ("sdss_objid", "matches"):
        if col not in df_out.columns:
            df_out[col] = pd.NA
    # Nullable Int64 keeps missing values without converting the large integer
    # objids to float64 (which would lose precision).
    return df_out.astype({"sdss_objid": "Int64", "matches": "Int64"})

In [64]:
import os
from pathlib import Path

# Project root; override with the PRISM_DIR environment variable instead of
# editing absolute paths in the notebook.
PRISM_DIR = Path(os.environ.get("PRISM_DIR", "/home/acontreras/PRISM"))
RESULTS_DIR = PRISM_DIR / "resultados"

# Test split: SN positions (sn_ra/sn_dec) and catalogued host positions (host_ra/host_dec).
df_test = pd.read_csv(PRISM_DIR / "data" / "SERSIC" / "df_test_delight.csv")
# Predicted host positions (ra_pred/dec_pred) from the baseline and autolabeling models.
preds_delight = pd.read_csv(RESULTS_DIR / "delight_baseline_r" / "test_predictions.csv")
preds_autolabeling_05 = pd.read_csv(
    RESULTS_DIR / "delight_augmented_05_percent_1_sampler" / "test_predictions.csv"
)

# Accuracy cells below compare prediction cross-matches with df_test row by row
# (positionally), so each file must describe the same objects in the same order.
# NOTE(review): only the row count can be checked here; order is assumed.
assert len(preds_delight) == len(df_test), "preds_delight and df_test differ in length"
assert len(preds_autolabeling_05) == len(df_test), "preds_autolabeling_05 and df_test differ in length"

In [65]:
# Baseline: the SDSS object nearest to the supernova position itself (no host model).
df_1nn = sdss_crossmatch_joblib(df=df_test, ra_col="sn_ra", dec_col="sn_dec")


100%|██████████| 4788/4788 [00:37<00:00, 127.58it/s]


Error con RA=8.9017771, DEC=9.1242674: HTTPSConnectionPool(host='skyserver.sdss.org', port=443): Read timed out. (read timeout=60)
Error con RA=136.47615732, DEC=29.85249066: HTTPSConnectionPool(host='skyserver.sdss.org', port=443): Read timed out. (read timeout=60)
Error con RA=245.48463486, DEC=21.89142672: HTTPSConnectionPool(host='skyserver.sdss.org', port=443): Read timed out. (read timeout=60)
Error con RA=125.45031489, DEC=-14.06993678: HTTPSConnectionPool(host='skyserver.sdss.org', port=443): Read timed out. (read timeout=60)
Error con RA=259.72118235, DEC=38.38097485: HTTPSConnectionPool(host='skyserver.sdss.org', port=443): Read timed out. (read timeout=60)
Error con RA=340.02166933, DEC=-6.02605709: HTTPSConnectionPool(host='skyserver.sdss.org', port=443): Read timed out. (read timeout=60)
Error con RA=357.69015237, DEC=15.48692974: HTTPSConnectionPool(host='skyserver.sdss.org', port=443): Read timed out. (read timeout=60)
Error con RA=191.8289088, DEC=44.4500367: HTTPSConne

In [66]:
# Reference: the SDSS object at the catalogued host-galaxy position.
df_out = sdss_crossmatch_joblib(df_test, ra_col='host_ra', dec_col='host_dec')
# SDSS objects at each model's predicted host position.
df_preds_delight = sdss_crossmatch_joblib(preds_delight, ra_col='ra_pred', dec_col='dec_pred')
df_preds_auto = sdss_crossmatch_joblib(preds_autolabeling_05, ra_col='ra_pred', dec_col='dec_pred')

  0%|          | 0/4788 [00:00<?, ?it/s]

100%|██████████| 4788/4788 [00:43<00:00, 109.38it/s]


Error con RA=42.37795833333333, DEC=22.89238888888889: HTTPSConnectionPool(host='skyserver.sdss.org', port=443): Max retries exceeded with url: /dr17/en/tools/search/X_Results.aspx (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7f574deb5fd0>, 'Connection to skyserver.sdss.org timed out. (connect timeout=60)'))


100%|██████████| 4788/4788 [00:33<00:00, 142.04it/s]
100%|██████████| 4788/4788 [00:35<00:00, 135.83it/s]


In [67]:
# Share of rows with a known reference objid where the SN-position baseline
# lands on the same SDSS object. Rows are paired by index; an NA on either
# side yields NA, which is never counted as a match.
reference_ids = df_out['sdss_objid']
preds = df_1nn['sdss_objid'] == reference_ids
n_reference = reference_ids.notna().sum()
n_match = preds.eq(True).sum()
accuracy = n_match / n_reference

print(f"# galaxias totales ID: {n_reference}")
print(f"# galaxias match: {n_match}")
print(f"accuracy: {accuracy}")

# galaxias totales ID: 3245
# galaxias match: 2526
accuracy: 0.7784283513097072


In [68]:
# Share of rows with a known reference objid where the Delight baseline's
# predicted position lands on the same SDSS object. An NA on either side is
# never counted as a match.
# NOTE(review): pairing is positional, so it assumes test_predictions.csv rows
# follow df_test's order; join on an object id if one is available.
reference_ids = df_out['sdss_objid']
preds = df_preds_delight['sdss_objid'] == reference_ids
n_reference = reference_ids.notna().sum()
n_match = preds.eq(True).sum()
accuracy = n_match / n_reference

print(f"# galaxias totales ID: {n_reference}")
print(f"# galaxias match: {n_match}")
print(f"accuracy: {accuracy}")

# galaxias totales ID: 3245
# galaxias match: 3147
accuracy: 0.9697996918335902


In [69]:
# Share of rows with a known reference objid where the autolabeling model's
# predicted position lands on the same SDSS object. An NA on either side is
# never counted as a match.
# NOTE(review): pairing is positional, so it assumes test_predictions.csv rows
# follow df_test's order; join on an object id if one is available.
reference_ids = df_out['sdss_objid']
preds = df_preds_auto['sdss_objid'] == reference_ids
n_reference = reference_ids.notna().sum()
n_match = preds.eq(True).sum()
accuracy = n_match / n_reference

print(f"# galaxias totales ID: {n_reference}")
print(f"# galaxias match: {n_match}")
print(f"accuracy: {accuracy}")

# galaxias totales ID: 3245
# galaxias match: 3144
accuracy: 0.9688751926040061
