In [None]:
!apt-get update && apt-get install -y aria2

In [None]:
import io
import logging
import os
import shutil
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed

import h5py
import numpy as np
import requests
from astropy.io import fits
from astropy.table import Table
from tqdm import tqdm

# --- CONFIGURATION ---
# Destination of the packed HDF5 file.
OUTPUT_DIR = "/kaggle/working"
OUTPUT_FILENAME = "apogee_dr17_parallel.h5"
OUTPUT_PATH = os.path.join(OUTPUT_DIR, OUTPUT_FILENAME)

# URLs: the catalog file lives directly under the same directory that holds
# the per-star spectra, so it is derived from the base URL.
BASE_SAS_URL = "https://data.sdss.org/sas/dr17/apogee/spectro/aspcap/dr17/synspec_rev1/"
CATALOG_URL = BASE_SAS_URL + "allStar-dr17-synspec_rev1.fits"

# Filters
MIN_SNR = 100        # keep only stars whose catalog SNR exceeds this value
MAX_STARS = 150000   # upper bound on the number of stars to download
MAX_WORKERS = 10     # number of parallel download threads (kept modest to be polite to the server)


def get_star_url(star_row):
    """Return ``(url, filename)`` of the aspcapStar file for one catalog row.

    Args:
        star_row: Mapping-like catalog row providing the ``TELESCOPE``,
            ``FIELD`` and ``APOGEE_ID`` entries.

    Returns:
        A tuple ``(url, fname)`` where ``fname`` is the file name and ``url``
        is ``BASE_SAS_URL`` followed by ``<TELESCOPE>/<FIELD>/<fname>``.
    """
    telescope, field, apogee_id = (
        star_row[key] for key in ('TELESCOPE', 'FIELD', 'APOGEE_ID')
    )
    fname = f"aspcapStar-dr17-{apogee_id}.fits"
    return f"{BASE_SAS_URL}{telescope}/{field}/{fname}", fname

def process_single_star(row, keep_cols, custom_dtype):
    """Download one aspcapStar file and extract its spectrum and labels.

    The file is fetched into memory and parsed from there. Nothing is written
    to disk, so there is no temporary file to name uniquely or to clean up,
    and the body is read through ``response.content`` so that any HTTP
    content-encoding is decoded before the bytes reach the FITS parser.

    This is a best-effort worker: any failure makes the star be skipped
    (``None`` is returned) instead of aborting the whole run, but the reason
    is logged so that systematic problems do not go unnoticed.

    Args:
        row: Catalog row of the star. Must provide the entries read by
            ``get_star_url``, ``APOGEE_ID`` and every name in ``keep_cols``.
        keep_cols: Names of the catalog columns copied into the label record.
        custom_dtype: Structured NumPy dtype used to build the label record
            from the ``keep_cols`` values (same order).

    Returns:
        ``(flux, ivar, label_entry, star_id)`` on success, where ``flux`` and
        ``ivar`` are float32 arrays taken from HDU 1 and HDU 2 and
        ``label_entry`` is a one-element structured array; ``None`` if the
        download or the extraction failed.
    """
    log = logging.getLogger(__name__)
    url, _ = get_star_url(row)

    try:
        # 1. Download (fully into memory; the timeout applies to connect/read)
        with requests.get(url, timeout=15) as response:
            if response.status_code != 200:
                log.warning("Skipping %s: HTTP status %s", url, response.status_code)
                return None
            payload = response.content

        # 2. Extract Data
        # memmap=False: the source is an in-memory buffer, not a real file.
        with fits.open(io.BytesIO(payload), memmap=False) as hdul:
            # astype() copies, so the arrays stay valid after the file closes.
            flux = hdul[1].data.astype('float32')
            err = hdul[2].data.astype('float32')

        # Inverse variance; pixels with a zero or non-finite error get weight 0.
        with np.errstate(divide='ignore'):
            ivar = 1.0 / (err**2)
        ivar[~np.isfinite(ivar)] = 0.0

        # 3. Extract Labels
        label_values = tuple(row[col] for col in keep_cols)
        label_entry = np.array([label_values], dtype=custom_dtype)

        star_id = row['APOGEE_ID']

        return (flux, ivar, label_entry, star_id)

    except Exception as exc:
        # Deliberately broad: one bad star must not stop the batch, but the
        # cause is recorded instead of being silently discarded.
        log.warning("Skipping %s: %s: %s", url, type(exc).__name__, exc)
        return None

def _ensure_catalog(catalog_path):
    """Download the catalog to ``catalog_path`` unless a complete copy is already there."""
    # A file below 3 GiB is treated as the leftover of an interrupted download.
    if os.path.exists(catalog_path) and os.path.getsize(catalog_path) < 3 * 1024**3:
        print("Deleting incomplete catalog...")
        os.remove(catalog_path)

    if os.path.exists(catalog_path):
        return

    print("Fetching catalog (using aria2c)...")
    # Argument list instead of a shell string: no shell parsing of the URL.
    # 16 connections / 16 splits, quiet, no pre-allocation of the file.
    subprocess.run(
        [
            "aria2c", "-x", "16", "-s", "16", "-q", "--file-allocation=none",
            "-d", os.path.dirname(catalog_path),
            "-o", os.path.basename(catalog_path),
            CATALOG_URL,
        ],
        check=True,
    )


def _build_label_schema(catalog):
    """Return ``(keep_cols, custom_dtype)`` describing the 1-D catalog columns to store."""
    keep_cols = []
    dtype_list = []
    for col in catalog.colnames:
        # Multi-dimensional columns are left out of the label record.
        if catalog[col].ndim != 1:
            continue
        keep_cols.append(col)
        col_dtype = catalog[col].dtype
        # Text columns are stored as fixed-width 30-byte strings.
        dtype_list.append((col, 'S30' if col_dtype.kind in ('U', 'S') else col_dtype))
    return keep_cols, np.dtype(dtype_list)


def _select_targets(catalog):
    """Return up to ``MAX_STARS`` randomly chosen rows that pass the quality cuts."""
    mask = (catalog['SNR'] > MIN_SNR) & (catalog['ASPCAPFLAG'] == 0) & (catalog['TEFF'] > 0)
    best_stars = catalog[mask]

    order = np.arange(len(best_stars))
    # Private generator seeded with 42: same permutation as seeding the global
    # NumPy RNG with 42, without altering the global random state.
    np.random.RandomState(42).shuffle(order)
    return best_stars[order[:MAX_STARS]]


def download_and_pack():
    """Download the catalog, fetch the selected spectra in parallel, pack them into HDF5.

    Steps:
        1. Make sure the allStar catalog is available in ``/tmp`` (downloaded
           with ``aria2c`` when missing or incomplete) and read it.
        2. Build the label schema from the 1-D catalog columns.
        3. Keep stars passing the SNR / ASPCAPFLAG / TEFF cuts and pick at
           most ``MAX_STARS`` of them at random (fixed seed).
        4. Download the spectra with ``MAX_WORKERS`` threads and append each
           successful result to the ``flux``, ``ivar`` and ``metadata``
           datasets of ``OUTPUT_PATH``.

    Stars whose download or extraction fails, or whose spectrum does not have
    the expected length, are skipped.

    Raises:
        subprocess.CalledProcessError: If the ``aria2c`` download fails.
    """
    # Create the output directory up front so a missing directory is detected
    # before the multi-gigabyte catalog download, not after it.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("--- DOWNLOADING CATALOG ---")
    catalog_path = "/tmp/allStar.fits"
    _ensure_catalog(catalog_path)

    print("Reading Catalog Table...")
    catalog = Table.read(catalog_path)

    # --- PREPARE SCHEMA ---
    keep_cols, custom_dtype = _build_label_schema(catalog)

    # --- FILTERING ---
    print("Filtering stars...")
    final_catalog = _select_targets(catalog)
    # The selection was produced by index-based selection from the catalog;
    # drop the reference to the full table so it can be reclaimed before the
    # download phase.
    del catalog

    print(f"Targeting {len(final_catalog)} stars with {MAX_WORKERS} parallel threads.")

    n_pixels = 8575  # width of the flux / ivar datasets
    successful_count = 0
    shape_mismatches = 0

    with h5py.File(OUTPUT_PATH, 'w') as f:
        # Initialize datasets (growable along the first axis)
        dset_flux = f.create_dataset("flux", (0, n_pixels), maxshape=(None, n_pixels), dtype='float32', compression="gzip")
        dset_ivar = f.create_dataset("ivar", (0, n_pixels), maxshape=(None, n_pixels), dtype='float32', compression="gzip")
        dset_labels = f.create_dataset("metadata", (0,), maxshape=(None,), dtype=custom_dtype)

        # Worker threads only download and parse; every HDF5 write happens
        # here in the main thread.
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            # No collection of futures is kept on our side: a finished Future
            # holds on to its result (two float32 arrays per star), so keeping
            # all of them would retain every downloaded spectrum in memory
            # until the end of the run.
            completed = as_completed(
                executor.submit(process_single_star, row, keep_cols, custom_dtype)
                for row in final_catalog
            )

            for future in tqdm(completed, total=len(final_catalog), desc="Parallel Download"):
                result = future.result()
                if result is None:
                    continue

                flux, ivar, label_entry, _star_id = result

                # A spectrum of unexpected shape would make the dataset write
                # raise and abort the whole run; skip it instead.
                if flux.shape != (n_pixels,) or ivar.shape != (n_pixels,):
                    shape_mismatches += 1
                    continue

                # Sequential Write: grow each dataset by one row, then fill it.
                size = dset_flux.shape[0]
                dset_flux.resize(size + 1, axis=0)
                dset_ivar.resize(size + 1, axis=0)
                dset_labels.resize(size + 1, axis=0)

                dset_flux[size] = flux
                dset_ivar[size] = ivar
                dset_labels[size] = label_entry

                successful_count += 1

    if shape_mismatches:
        print(f"Skipped {shape_mismatches} stars with an unexpected spectrum shape.")
    print(f"\nSUCCESS! Extracted {successful_count} stars.")
    print(f"Data saved to: {OUTPUT_PATH}")

# Entry point: run the whole download-and-pack pipeline when this file is
# executed directly (it is not triggered by a plain import).
if __name__ == "__main__":
    download_and_pack()