In [11]:
# -----------------------------
# 1️⃣ Imports
# -----------------------------
import os
import pandas as pd
import requests
from PIL import Image
from io import BytesIO
from tqdm import tqdm


In [12]:
# -----------------------------
# 2️⃣ Paths and Config
# -----------------------------
# Input / output locations
CLEANED_CSV = "merged_gz2_sdss.csv"              # cleaned GZ2 x SDSS table produced upstream
IMG_DIR = "datafinal/images"                     # downloaded JPEG cutouts are written here
MAPPING_CSV = "datafinal/image_mapping_new.csv"  # per-image label/metadata index written at the end

# SDSS ImgCutout request parameters (128 px * 0.2"/px = 25.6" field of view)
IMG_SIZE = 128  # cutout width and height, in pixels
SCALE = 0.2     # angular scale, arcsec per pixel

# Per-class cap used when building the balanced subset;
# choose it based on the smallest class you want to include.
TARGET_PER_CLASS = 2000

# Create the image folder; this also creates its parent "datafinal",
# which is where MAPPING_CSV is written.
os.makedirs(IMG_DIR, exist_ok=True)


In [13]:
# -----------------------------
# 3️⃣ Load cleaned dataset
# -----------------------------
# Read the cleaned catalogue and report its dimensions before previewing it.
df = pd.read_csv(filepath_or_buffer=CLEANED_CSV)
n_rows, n_cols = df.shape
print("Dataset loaded:", (n_rows, n_cols))
df.head(n=5)

Dataset loaded: (48664, 248)


Unnamed: 0,dr7objid,ra,dec,rastring,decstring,sample,gz2_class,total_classifications,total_votes,t01_smooth_or_features_a01_smooth_count,...,z,petroR50_r,petroR90_r,fracDeV_r,concentration_index,u_g_color,g_r_color,r_i_color,i_z_color,redshift
0,5.88009e+17,135.084396,52.49424,00:20.3,+52:29:39.3,original,Sb+t,42,332,1,...,11.91701,14.10413,36.87098,0.95881,2.614197,1.871747,0.891748,0.461641,0.281975,0.030118
1,5.8773e+17,246.921387,40.926968,27:41.1,+40:55:37.1,extra,Ei,48,154,41,...,12.0717,11.19578,35.93066,0.864857,3.209305,2.025663,0.82004,0.435435,0.398816,0.031728
2,5.87732e+17,183.062058,56.177532,12:14.9,+56:10:39.1,original,Sb?t,43,275,8,...,12.0424,9.284981,28.55589,1.0,3.075492,1.964868,0.975801,0.456394,0.309238,0.031083
3,5.87729e+17,119.617126,37.786617,58:28.1,+37:47:11.8,original,Ei,42,139,39,...,12.2125,10.26644,32.90302,1.0,3.20491,1.987321,0.892363,0.46651,0.306604,0.040825
4,5.87726e+17,209.473053,64.91098,57:53.5,+64:54:39.5,original,Er,35,102,26,...,12.0669,12.15487,38.81894,1.0,3.193695,1.940805,0.837738,0.391821,0.320923,0.032005


In [18]:
# -----------------------------
# 4️⃣ Assign Extended Labels
# -----------------------------
def assign_extended_label(row):
    """Return the integer class whose vote fraction is largest for one galaxy.

    Label 0 is scored by the smooth fraction and label 1 by the edge-on "yes"
    fraction.

    Parameters
    ----------
    row : pandas.Series
        One catalogue row, as passed by ``DataFrame.apply(..., axis=1)``.
        ``t01_smooth_or_features_a01_smooth_fraction`` is required (KeyError
        if absent). ``t02_edgeon_a04_yes_fraction`` is optional and counts
        as 0 when missing.

    Returns
    -------
    int
        The label with the highest fraction. Ties go to label 0.

    Notes
    -----
    NOTE(review): in Galaxy Zoo 2 the t02 fraction is presumably conditional
    on the t01 "features/disk" answer, so it may not be directly comparable
    with the t01 smooth fraction -- confirm this scoring is what is intended.
    """
    fractions = {
        0: row['t01_smooth_or_features_a01_smooth_fraction'],   # Smooth (required)
        1: row.get('t02_edgeon_a04_yes_fraction', 0),           # Edge-on (optional)
    }
    # NaN never compares greater or smaller, so max() would silently keep
    # label 0 whenever the smooth fraction is NaN. Score NaN as 0 instead.
    fractions = {label: (0 if pd.isna(frac) else frac) for label, frac in fractions.items()}
    return max(fractions, key=fractions.get)

# Label every galaxy row by row, then show how many rows fall in each class.
df['extended_label'] = df.apply(func=assign_extended_label, axis='columns')
label_counts = df['extended_label'].value_counts()
print("Label distribution:\n", label_counts)


Label distribution:
 extended_label
0    42296
1     6368
Name: count, dtype: int64


In [19]:
# -----------------------------
# 5️⃣ Create balanced dataset
# -----------------------------
# Cap every class at TARGET_PER_CLASS rows. Classes with fewer rows are kept
# whole, so the result is only truly balanced when every class reaches the cap.
# Per-class frames are collected in a list and concatenated once: growing a
# DataFrame with pd.concat inside the loop is quadratic, and concatenating onto
# an empty DataFrame() can trigger a pandas FutureWarning. The per-class order
# (first appearance in df) is the same as before, so the seeded shuffle below
# yields the same row order -- and therefore the same image indices.
class_frames = []
for label in df['extended_label'].unique():
    class_rows = df[df['extended_label'] == label]
    if len(class_rows) >= TARGET_PER_CLASS:
        class_frames.append(class_rows.sample(n=TARGET_PER_CLASS, random_state=42))
    else:
        class_frames.append(class_rows)  # include all if fewer than target

# pd.concat([]) raises, so fall back to an empty frame with df's columns.
balanced_df = pd.concat(class_frames) if class_frames else df.iloc[0:0]

# Shuffle so classes are interleaved; reset_index makes the index 0..n-1.
balanced_df = balanced_df.sample(frac=1, random_state=42).reset_index(drop=True)

assert (balanced_df['extended_label'].value_counts() <= TARGET_PER_CLASS).all()
print("Balanced label distribution:\n", balanced_df['extended_label'].value_counts())


Balanced label distribution:
 extended_label
0    2000
1    2000
Name: count, dtype: int64


In [20]:
# -----------------------------
# 6️⃣ Function to fetch SDSS image
# -----------------------------
def fetch_sdss_image(ra, dec, filename, scale=SCALE, size=IMG_SIZE, session=None, timeout=10):
    """Download one SDSS DR16 JPEG cutout centred on (ra, dec) and save it.

    Parameters
    ----------
    ra, dec : float
        Sky position inserted verbatim into the ImgCutout query string
        (presumably degrees, matching the catalogue's ra/dec columns -- confirm).
    filename : str
        Destination path. PIL infers the output format from its extension.
    scale : float
        Angular scale requested from the service (arcsec/pixel per the config).
    size : int
        Requested width and height in pixels.
    session : requests.Session, optional
        Pass a session to reuse one HTTP connection across many calls.
        By default the module-level ``requests.get`` is used.
    timeout : float
        Request timeout in seconds, passed through to ``requests``.

    Returns
    -------
    bool
        True if the image was written to ``filename``. False on any failure:
        the error is printed rather than raised, so a bulk download keeps going.
    """
    url = f"http://skyserver.sdss.org/dr16/SkyServerWS/ImgCutout/getjpeg?ra={ra}&dec={dec}&scale={scale}&width={size}&height={size}"
    http = requests if session is None else session
    # Save to a sibling temp file and rename it into place. An interrupted save
    # then never leaves a truncated image at `filename`, which a caller that
    # treats existing files as already downloaded would never retry. The temp
    # name keeps the original extension so PIL still infers the same format.
    root, ext = os.path.splitext(filename)
    tmp_path = f"{root}.part{ext}"
    try:
        response = http.get(url, timeout=timeout)
        # Report HTTP errors directly instead of as an opaque image-decode error.
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as src:
            img = src.convert("RGB")
        img.save(tmp_path)
        os.replace(tmp_path, filename)  # atomic rename on the same filesystem
        return True
    except Exception as e:
        print(f"Failed to fetch {filename}: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created
        return False


In [None]:
# -----------------------------
# 7️⃣ Download balanced images
# -----------------------------
# Download a cutout for every row of balanced_df and record the rows that ended
# up with an image on disk. Files are named by the row's position in the
# shuffled balanced_df, and an existing file is reused without re-downloading.
# NOTE: that cache is only valid while balanced_df is rebuilt identically (same
# input CSV, TARGET_PER_CLASS and random_state). Otherwise clear IMG_DIR first,
# or stale images will be paired with the wrong rows.
morph_cols = [col for col in df.columns if 't0' in col]  # optional morphology columns; same for every row
mapping = []
failed = []

for idx, row in tqdm(balanced_df.iterrows(), total=len(balanced_df)):
    ra, dec = row['ra'], row['dec']
    filename = os.path.join(IMG_DIR, f"{idx}.jpg")

    if not os.path.exists(filename) and not fetch_sdss_image(ra, dec, filename):
        failed.append(idx)
        continue

    # store mapping info
    mapping.append({
        "idx": idx,
        "image_filename": filename,
        "ra": ra,
        "dec": dec,
        "extended_label": row['extended_label'],
        **{col: row[col] for col in morph_cols},
    })

# Save mapping CSV
mapping_df = pd.DataFrame(mapping)
mapping_df.to_csv(MAPPING_CSV, index=False)
print("Balanced mapping CSV saved:", MAPPING_CSV)

# Failed downloads are dropped from the mapping, which can unbalance the
# classes -- make that visible instead of silent.
if failed:
    print(f"{len(failed)} of {len(balanced_df)} downloads failed; missing per label:\n"
          f"{balanced_df.loc[failed, 'extended_label'].value_counts()}")

  0%|          | 0/4000 [00:00<?, ?it/s]

  9%|▉         | 376/4000 [09:20<1:20:26,  1.33s/it]