# ETL Galaxy Zoo 2 - Prepare Data

by BRAUX Owen and CAMBIER Elliot in 2026

## Imports 

In [None]:
import os
import pandas as pd
import numpy as np
import requests
from PIL import Image
from io import BytesIO
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.notebook import tqdm
import shutil

## Configuration

In [None]:
# CONFIGURATION
BASE_DIR = os.getcwd()
# NOTE(review): DATA_RAW_DIR is not used by any of the cells shown here.
DATA_RAW_DIR = os.path.join(BASE_DIR, 'data_raw')
IMAGES_DIR = os.path.join(BASE_DIR, 'dataset_images')
CSV_PATH = "gz2_hart16.csv.gz"
SDSS_URL = "http://skyserver.sdss.org/dr16/SkyServerWS/ImgCutout/getjpeg"
GZ2_URL = "https://gz2hart.s3.amazonaws.com/gz2_hart16.csv.gz"

# Fetch the Galaxy Zoo 2 catalogue once; later runs reuse the local copy.
if not os.path.exists(CSV_PATH):
    print("Downloading Galaxy Zoo 2 metadata...")
    # Download into a temporary file and rename only on success. Otherwise a
    # failed or interrupted download (or an HTTP error page) would be left at
    # CSV_PATH, and the existence check above would never retry it.
    tmp_csv_path = CSV_PATH + ".part"
    with requests.get(GZ2_URL, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(tmp_csv_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(tmp_csv_path, CSV_PATH)
    print("Download complete.")

## Galaxy class mapping

In [None]:
# COLUMN NAMES AND CLASS LABELS
# Every feature read by the decision tree is a "*_debiased" catalogue column.

# Top-level answers: smooth vs. features/disk, edge-on, bar, spiral arms
col_smooth = 't01_smooth_or_features_a01_smooth_debiased'
col_features = 't01_smooth_or_features_a02_features_or_disk_debiased'
col_edgeon = 't02_edgeon_a04_yes_debiased'
col_bar = 't03_bar_a06_bar_debiased'
col_spiral = 't04_spiral_a08_spiral_debiased'

# Roundness of smooth galaxies (answers a16 and a18 of question t07)
col_round = 't07_rounded_a16_completely_round_debiased'
col_cigar = 't07_rounded_a18_cigar_shaped_debiased'

# Spiral-arm winding (question t10)
col_tight, col_medium, col_loose = (
    't10_arms_winding_a28_tight_debiased',
    't10_arms_winding_a29_medium_debiased',
    't10_arms_winding_a30_loose_debiased',
)

# Odd feature: merger (question t08)
col_merger = 't08_odd_feature_a24_merger_debiased'

# Class id -> label; the labels are also used as the on-disk folder names.
CLASS_NAMES = dict(enumerate([
    "0_Elliptique_Ronde",       # round elliptical
    "1_Elliptique_Allongee",    # elongated elliptical
    "2_Lenticulaire",           # lenticular (S0)
    "3_Spirale_Serree",         # spiral, tight arms
    "4_Spirale_Moyenne",        # spiral, medium arms
    "5_Spirale_Lache",          # spiral, loose arms
    "6_Barree_Serree",          # barred spiral, tight arms
    "7_Barree_Moyenne",         # barred spiral, medium arms
    "8_Barree_Lache",           # barred spiral, loose arms
    "9_Merger_Irreguliere",     # merger / irregular
]))

## Creating the class folders

In [None]:
# Create the image root, then one sub-directory per class label (no-op if present).
class_dirs = [os.path.join(IMAGES_DIR, name) for name in CLASS_NAMES.values()]
for target_dir in (IMAGES_DIR, *class_dirs):
    os.makedirs(target_dir, exist_ok=True)

## Main ETL logic

In [None]:
def get_galaxy_class(row):
    """Map one Galaxy Zoo 2 catalogue row to a class id of ``CLASS_NAMES``.

    The decision tree is evaluated in order:

    1. ``col_smooth`` > 0.8 -> elliptical: 1 if ``col_cigar`` > 0.5,
       otherwise 0 (completely round or in-between).
    2. ``col_merger`` > 0.6 -> 9.
    3. ``col_features`` > 0.5 and ``col_edgeon`` < 0.5 -> disk galaxy:
       ``col_spiral`` < 0.5 gives 2 (S0). Otherwise the dominant winding
       answer (tight / medium / loose) selects 3 / 4 / 5, or 6 / 7 / 8
       when ``col_bar`` > 0.5.

    Args:
        row: row indexable by the ``col_*`` column names (e.g. the Series
            passed by ``DataFrame.apply(..., axis=1)``).

    Returns:
        int: class id 0-9, or -1 when the galaxy matches no class, a
        required column is missing, or all three winding answers are NaN.
    """
    try:
        # CASE 1: smooth -> elliptical. "Completely round" and the
        # in-between answer both fall back to E0-E3.
        if row[col_smooth] > 0.8:
            return 1 if row[col_cigar] > 0.5 else 0  # E4-E7 / E0-E3

        # CASE 2: merger / irregular
        if row[col_merger] > 0.6:
            return 9

        # CASE 3: face-on disk -> S0, S or SB
        if row[col_features] > 0.5 and row[col_edgeon] < 0.5:
            if row[col_spiral] < 0.5:
                return 2  # S0

            winding = np.array(
                [row[col_tight], row[col_medium], row[col_loose]], dtype=float
            )
            # np.argmax treats NaN as the maximum, so a single missing answer
            # would silently win. Ignore NaNs, and reject rows with no valid
            # winding answer instead of defaulting them to "tight".
            if np.isnan(winding).all():
                return -1
            max_winding = int(np.nanargmax(winding))  # 0=tight 1=medium 2=loose

            # Unbarred: 3/4/5 = Sa/Sb/Sc; barred: 6/7/8 = SBa/SBb/SBc
            first_class = 6 if row[col_bar] > 0.5 else 3
            return first_class + max_winding

        return -1  # matches no class: discarded

    except KeyError:
        return -1

def download_image_worker(row):
    """Download one SDSS JPEG cutout into the folder of its class.

    Args:
        row: mapping with keys ``label`` (a key of ``CLASS_NAMES``),
            ``dr7objid``, ``ra`` and ``dec``.

    Returns:
        str: ``"EXIST"`` if the target file is already on disk, ``"OK"``
        after a successful download, ``"ERROR"`` on any failure. The function
        never raises, so one bad galaxy cannot abort the whole thread pool.
    """
    tmp_path = None
    try:
        label = row['label']
        objid = str(row['dr7objid'])
        ra = row['ra']
        dec = row['dec']

        folder = CLASS_NAMES[label]
        filepath = os.path.join(IMAGES_DIR, folder, f"{objid}.jpg")

        # Skip galaxies already downloaded by a previous run.
        if os.path.exists(filepath):
            return "EXIST"

        params = {
            'ra': ra, 'dec': dec, 'scale': 0.396,
            'width': 128, 'height': 128, 'opt': '',
        }

        # Slightly longer timeout since the worker threads share bandwidth.
        response = requests.get(SDSS_URL, params=params, timeout=10)
        response.raise_for_status()

        # Check that the payload decodes as an image, then store the server's
        # bytes unchanged: re-saving through PIL would recompress the JPEG.
        Image.open(BytesIO(response.content)).verify()

        # Write under a temporary name and rename at the end. An interrupted
        # run then never leaves a truncated .jpg that the existence check
        # above would skip forever.
        tmp_path = filepath + ".part"
        with open(tmp_path, 'wb') as f:
            f.write(response.content)
        os.replace(tmp_path, filepath)
        return "OK"

    except Exception:
        # Deliberate best-effort boundary: drop any partial file, report the
        # failure to the caller, and let the other downloads continue.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return "ERROR"
    


def _label_galaxies(csv_file):
    # Load the catalogue, validate its columns and keep only classified rows.
    print(f"Chargement du CSV '{csv_file}'...")
    if not os.path.exists(csv_file):
        raise FileNotFoundError(f"Le fichier {csv_file} est introuvable !")

    df = pd.read_csv(csv_file, compression='gzip')

    # get_galaxy_class maps a missing column to -1 for every row, which would
    # silently produce an empty dataset. Fail loudly instead.
    required = [
        col_smooth, col_features, col_edgeon, col_bar, col_spiral,
        col_round, col_cigar, col_tight, col_medium, col_loose, col_merger,
        'dr7objid', 'ra', 'dec',
    ]
    missing = [name for name in required if name not in df.columns]
    if missing:
        raise ValueError(f"Colonnes manquantes dans {csv_file} : {missing}")

    print("   -> Application de l'arbre de décision...")
    df['label'] = df.apply(get_galaxy_class, axis=1)

    # Drop galaxies that matched no class.
    df_clean = df[df['label'] != -1].copy()
    print(f"   -> Galaxies classifiées valides : {len(df_clean)}")
    return df_clean


def _balance_classes(df_clean, target_per_class):
    # Cap each class at target_per_class rows (seeded sampling), then shuffle.
    print("\nÉquilibrage...")
    dfs_list = []

    for label in sorted(CLASS_NAMES):
        df_class = df_clean[df_clean['label'] == label]
        count = len(df_class)
        print(f"   -> Classe {label} [{CLASS_NAMES[label]}] : {count} dispo", end="")

        if count > target_per_class:
            df_sampled = df_class.sample(n=target_per_class, random_state=42)
            print(f" -> {target_per_class} gardées")
        else:
            df_sampled = df_class
            print(" -> Tout gardé")

        dfs_list.append(df_sampled)

    return pd.concat(dfs_list).sample(frac=1, random_state=42).reset_index(drop=True)


def _download_images(df_final, max_workers):
    # Download every cutout in parallel and return the per-status counts.
    print(f"\nLancement du téléchargement ({max_workers} threads)...")
    results = {"OK": 0, "EXIST": 0, "ERROR": 0}

    # iterrows() builds a Series per row and does not preserve dtypes: with
    # only numeric columns an int64 dr7objid would be upcast to float, which
    # corrupts the file name. Records of just the needed columns keep
    # per-column types and avoid copying the full catalogue row.
    records = df_final[['label', 'dr7objid', 'ra', 'dec']].to_dict('records')

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(download_image_worker, rec) for rec in records]
        for future in tqdm(as_completed(futures), total=len(futures), unit="img"):
            results[future.result()] += 1

    return results


def run_pipeline(csv_file, target_per_class=2000, max_workers=10):
    """Build the balanced Galaxy Zoo 2 image dataset end to end.

    Steps: label every catalogue row with ``get_galaxy_class``, cap each class
    at ``target_per_class`` rows, write ``dataset_metadata.csv``, download the
    cutouts into ``IMAGES_DIR``, then zip ``IMAGES_DIR`` into
    ``dataset_galaxies.zip``.

    Args:
        csv_file: path to the gzip-compressed Galaxy Zoo 2 catalogue.
        target_per_class: maximum number of galaxies kept per class.
        max_workers: number of download threads.

    Raises:
        FileNotFoundError: if ``csv_file`` does not exist.
        ValueError: if the catalogue lacks a column the pipeline reads.
    """
    df_clean = _label_galaxies(csv_file)
    df_final = _balance_classes(df_clean, target_per_class)

    # Intermediate metadata CSV for the selected galaxies.
    df_final.to_csv("dataset_metadata.csv", index=False)
    print(f"   -> Dataset final prêt : {len(df_final)} images à récupérer.")

    results = _download_images(df_final, max_workers)
    print(f"-> Résultat : {results}")

    # Archive. NOTE: this zips everything under IMAGES_DIR, including images
    # left by earlier runs that are not in this run's selection.
    print("\nCréation de l'archive ZIP...")
    output_filename = "dataset_galaxies"
    shutil.make_archive(output_filename, 'zip', IMAGES_DIR)
    print(f"L'archive '{output_filename}.zip' est prête.")

In [None]:
if __name__ == "__main__":
    try:
        run_pipeline(CSV_PATH)
    except Exception as e:
        # Keep the short summary for the notebook reader, but also print the
        # traceback: str(e) alone (e.g. just a key for a KeyError) rarely
        # says where the pipeline failed.
        import traceback
        print(f"Erreur : {e}")
        traceback.print_exc()