> #### Load the initial AGN catalogue and untar it which is given here : https://data.galaxyzoo.org/

In [21]:
import gzip
import shutil

with gzip.open('schawinski_GZ_2010_catalogue.fits.gz', 'rb') as f_in:
    with open('schawinski_GZ_2010_catalogue.fits', 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)


OSError: [Errno 22] Invalid argument: 'schawinski_GZ_2010_catalogue.fits'

> #### Look at the fits file structure to identify needed columns

In [22]:
from astropy.io import fits

hdul = fits.open('schawinski_GZ_2010_catalogue.fits')
hdul.info()


Filename: schawinski_GZ_2010_catalogue.fits
No.    Name      Ver    Type      Cards   Dimensions   Format
  0  PRIMARY       1 PrimaryHDU       4   ()      
  1                1 BinTableHDU     51   1R x 15C   [858150A, 47675D, 47675D, 47675D, 47675J, 47675J, 47675E, 47675E, 47675E, 47675E, 47675E, 47675E, 47675E, 47675E, 47675D]   


> #### Print dataframe info in a better manner

In [2]:
from astropy.io import fits

# Open FITS file
hdul = fits.open("data/schawinski_GZ_2010_catalogue.fits")

# Access the binary table HDU
hdu = hdul[1]

# Print a structured summary of columns
print("Number of columns:", len(hdu.columns))
print("\nDetailed column info:\n")
print(hdu.columns)


Number of columns: 15

Detailed column info:

ColDefs(
    name = 'OBJID'; format = '858150A'; dim = '(18, 47675)'
    name = 'RA'; format = '47675D'
    name = 'DEC'; format = '47675D'
    name = 'REDSHIFT'; format = '47675D'
    name = 'GZ1_MORPHOLOGY'; format = '47675J'
    name = 'BPT_CLASS'; format = '47675J'
    name = 'U'; format = '47675E'
    name = 'G'; format = '47675E'
    name = 'R'; format = '47675E'
    name = 'I'; format = '47675E'
    name = 'Z'; format = '47675E'
    name = 'SIGMA'; format = '47675E'
    name = 'SIGMA_ERR'; format = '47675E'
    name = 'LOG_MSTELLAR'; format = '47675E'
    name = 'L_O3'; format = '47675D'
)


> #### Convert the fits file to a CSV for easier handling

In [None]:
from astropy.io import fits
import pandas as pd
import numpy as np

# Load AGN sample
print("Loading Schawinski GZ 2010 catalogue...")
with fits.open('schawinski_GZ_2010_catalogue.fits') as hdul_agn:
    agn_data = hdul_agn[1].data

# Each column is an array stored inside one record
agn_cols = agn_data.columns.names or agn_data.names
print(f"Columns: {agn_cols}")

# Extract the arrays from the first (and only) record
agn_dict = {col: np.array(agn_data[0][col]) for col in agn_cols}

# Convert to DataFrame
df_agn = pd.DataFrame(agn_dict)
print(f"Expanded AGN sample shape: {df_agn.shape}")
print(df_agn.head())

print(df_agn.info())

df_agn.to_csv("schawinski_GZ2010_AGN_catalogue.csv", index=False)
print("✅ AGN catalogue saved to 'schawinski_GZ2010_AGN_catalogue.csv'")

> #### Unzip the Galaxy Zoo 2 Table 1 information

In [1]:
import gzip
import shutil

with gzip.open('gz2_hart16.csv.gz', 'rb') as f_in:
    with open('gz2_hart16.csv', 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)

> #### Now crossmatch the SDSS DR7 object ids for galaxies with the AGN catalogue to extend the labels

In [None]:
import pandas as pd

print("Loading AGN catalogue...")
# --- FIX: Read OBJID as a string to prevent precision loss ---
df_agn = pd.read_csv(
    "schawinski_GZ2010_AGN_catalogue.csv",
    dtype={'OBJID': str}
)

print("Loading GZ2 Hart catalogue...")
# --- FIX: Read dr7objid as a string to prevent precision loss ---
df_gz2 = pd.read_csv(
    "gz2_hart16.csv",
    dtype={'dr7objid': str}
)

# --- Drop any missing values in join columns ---
df_agn = df_agn.dropna(subset=['OBJID'])
df_gz2 = df_gz2.dropna(subset=['dr7objid'])

# --- No longer need astype(int) ---
# The IDs are now strings and will merge correctly.

# --- Remove duplicates before merging ---
df_agn_unique = df_agn.drop_duplicates(subset='OBJID')
df_gz2_unique = df_gz2.drop_duplicates(subset='dr7objid')

print("Performing unique inner join on string IDs...")
# --- Perform the inner join ---
df_matched_unique = pd.merge(
    df_agn_unique,
    df_gz2_unique,
    left_on='OBJID',
    right_on='dr7objid',
    how='inner',
    suffixes=('_agn', '_gz2')
)

# --- Save the unique matched catalogue ---
output_file = "AGN_GZ2_Hart_DR7_final.csv"
df_matched_unique.to_csv(output_file, index=False)

# --- Print summary ---
print(f"✅ Unique matched catalogue saved to '{output_file}'")
print(f"Matched sample size: {len(df_matched_unique)}")
if len(df_agn_unique) > 0:
    match_rate = len(df_matched_unique) / len(df_agn_unique) * 100
    print(f"Match rate: {match_rate:.2f}% of unique AGN sample")
else:
    print("No unique AGN samples found to calculate match rate.")
print("\nMatched dataset columns:")
print(df_matched_unique.columns.tolist())



Loading AGN catalogue...
Loading GZ2 Hart catalogue...
Performing unique inner join on string IDs...
✅ Unique matched catalogue saved to 'AGN_GZ2_Hart_DR7_final.csv'
Matched sample size: 44361
Match rate: 93.05% of unique AGN sample

Matched dataset columns:
['OBJID', 'RA', 'DEC', 'REDSHIFT', 'GZ1_MORPHOLOGY', 'BPT_CLASS', 'U', 'G', 'R', 'I', 'Z', 'SIGMA', 'SIGMA_ERR', 'LOG_MSTELLAR', 'L_O3', 'dr7objid', 'ra', 'dec', 'rastring', 'decstring', 'sample', 'gz2_class', 'total_classifications', 'total_votes', 't01_smooth_or_features_a01_smooth_count', 't01_smooth_or_features_a01_smooth_weight', 't01_smooth_or_features_a01_smooth_fraction', 't01_smooth_or_features_a01_smooth_weighted_fraction', 't01_smooth_or_features_a01_smooth_debiased', 't01_smooth_or_features_a01_smooth_flag', 't01_smooth_or_features_a02_features_or_disk_count', 't01_smooth_or_features_a02_features_or_disk_weight', 't01_smooth_or_features_a02_features_or_disk_fraction', 't01_smooth_or_features_a02_features_or_disk_weighte

> #### Extract the images for later use

In [23]:
import zipfile
import os

# Path to your zip file
zip_path = "data/images_gz2.zip"

# Directory where you want to extract files
extract_dir = "images_gz2"

# Create the directory if it doesn't exist
os.makedirs(extract_dir, exist_ok=True)

# Open and extract all contents
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

print(f"Files extracted to {extract_dir}")


Files extracted to images_gz2


> #### Updated code for creating the labels in a json format verifying with the available images  along with adding new lables for `star_forming`, `has_agn`, and `mass` for the galaxy zoo 2 dataset crossmatched with the AGN catalogue

In [26]:
import pandas as pd
from tqdm import tqdm
import numpy as np
import json
import os

# ===============================
# CONFIG
# ===============================
# *** ASSUMPTION ***:
# This script assumes VOTES_CSV now points to your new, merged file
# that contains both the GZ2 vote columns AND the new label columns
# (BPT_CLASS, LOG_MSTELLAR, etc.)
VOTES_CSV = "data/AGN_GZ2_Hart_DR7_final.csv" # <-- Make sure this is your new file path
MAP_CSV = "data/gz2_filename_mapping.csv" # <-- Make sure this is your map file path
IMG_DIR = "data/images_gz2/images" # <-- Make sure this path is correct
OUT_JSON_ALL = "data/labels_master.json"
OUT_JSON_TOP = "data/labels_master_top_n.json"

# --- THRESHOLDS ---
THRESHOLDS = {
    'artifact': 0.2,
    'elliptical': 0.95,
    'nospiral': 0.9,
    'spiral': 0.95,
    'features': 0.9,
    'edgeon_ell': 0.1,
    'edgeon_spiral': 0.2,
    'odd': 0.7,
    'irregular': 0.7
}

def load_data(votes_csv, map_csv):
    """
    Load and merge the Galaxy Zoo vote fractions (now with new labels)
    and the filename mapping.
    Returns a merged DataFrame with consistent 'objid' type.
    """
    print(f"Loading votes data from: {votes_csv}")
    try:
        votes_df = pd.read_csv(votes_csv, low_memory=False)
    except FileNotFoundError:
        print(f"CRITICAL ERROR: Votes file not found at {votes_csv}")
        return pd.DataFrame() # Return empty DF

    print(f"Loading map data from: {map_csv}")
    try:
        map_df = pd.read_csv(map_csv, low_memory=False)
    except FileNotFoundError:
        print(f"CRITICAL ERROR: Map file not found at {map_csv}")
        return pd.DataFrame() # Return empty DF

    # --- UPDATED JOIN LOGIC ---
    # We must join with the GZ2 filename mapping file using the GZ2 ID,
    # which is 'dr7objid'. This is the correct key.
    join_key_name = None
    if "dr7objid" in votes_df.columns:
        print("Using 'dr7objid' from votes file as join key (correct for GZ2 mapping).")
        votes_df = votes_df.rename(columns={"dr7objid": "objid_to_join"})
        join_key_name = "dr7objid"
    elif "OBJID" in votes_df.columns:
        print("Warning: 'dr7objid' not found. Falling back to 'OBJID'. This may result in 0 merges.")
        votes_df = votes_df.rename(columns={"OBJID": "objid_to_join"})
        join_key_name = "OBJID"
    else:
        print("CRITICAL ERROR: Could not find 'dr7objid' or 'OBJID' in votes CSV.")
        return pd.DataFrame()
    
    # Check the map file for its 'objid'
    if "objid" not in map_df.columns:
        print("CRITICAL ERROR: Could not find 'objid' column in map CSV.")
        return pd.DataFrame()
    
    map_df = map_df.rename(columns={"objid": "objid_to_join"})
    # --- END UPDATED JOIN LOGIC ---

    # Ensure merge keys are the same integer type
    votes_df['objid_to_join'] = pd.to_numeric(votes_df['objid_to_join'], errors='coerce').astype('Int64')
    map_df['objid_to_join'] = pd.to_numeric(map_df['objid_to_join'], errors='coerce').astype('Int64')
    
    # Drop rows where objid became NaT
    votes_df = votes_df.dropna(subset=['objid_to_join'])
    map_df = map_df.dropna(subset=['objid_to_join'])

    print(f"Votes file: Found {len(votes_df)} valid rows with key '{join_key_name}'.")
    print(f"Map file: Found {len(map_df)} valid rows with key 'objid'.")

    # Perform the merge
    merged_df = votes_df.merge(map_df, on="objid_to_join", how="inner")
    
    # Rename the join key back to 'objid' for the rest of the script
    # This script (and the original) uses 'objid' (from dr7objid) as the key
    if 'objid_to_join' in merged_df.columns:
         merged_df = merged_df.rename(columns={"objid_to_join": "objid"})

    # The original script used 'dr7objid' and renamed it to 'objid'.
    # We must ensure the column we use for finding image paths is
    # named 'objid' in the final dataframe.
    if "objid" not in merged_df.columns:
        print("Warning: Final merge logic failed to produce 'objid' column.")
    
    # The 'OBJID' (from AGN catalog) will just pass through if it was present

    return merged_df


def classify_gz2(row):
    """
    Classify a galaxy into elliptical, spiral, or irregular
    using ultra-strict debiased vote fraction thresholds.
    Returns (label, metrics_dict).
    """
    # This function is unchanged, it only returns the GZ2 classification
    # and the vote fraction metrics.
    m = {
        'artifact_prob': row['t01_smooth_or_features_a03_star_or_artifact_debiased'],
        'smooth_prob': row['t01_smooth_or_features_a01_smooth_debiased'],
        'features_prob': row['t01_smooth_or_features_a02_features_or_disk_debiased'],
        'edgeon_prob': row['t02_edgeon_a04_yes_debiased'],
        'spiral_prob': row['t04_spiral_a08_spiral_debiased'],
        'nospiral_prob': row['t04_spiral_a09_no_spiral_debiased'],
        'irregular_prob': row['t08_odd_feature_a22_irregular_debiased'],
        'merger_prob': row['t08_odd_feature_a24_merger_debiased'],
        'disturbed_prob': row['t08_odd_feature_a21_disturbed_debiased'],
        'odd_prob': row['t06_odd_a14_yes_debiased']
    }

    if pd.isna(m['artifact_prob']):
        # Handle rows with missing vote data
        return None, m

    if m['artifact_prob'] >= THRESHOLDS['artifact']:
        return None, m

    if (m['smooth_prob'] >= THRESHOLDS['elliptical'] and
        m['edgeon_prob'] < THRESHOLDS['edgeon_ell'] and
        m['nospiral_prob'] >= THRESHOLDS['nospiral']):
        return "elliptical", m

    if (m['spiral_prob'] >= THRESHOLDS['spiral'] and
        m['features_prob'] >= THRESHOLDS['features'] and
        m['edgeon_prob'] < THRESHOLDS['edgeon_spiral']):
        return "spiral", m

    if (m['odd_prob'] >= THRESHOLDS['odd'] and
        max(m['irregular_prob'], m['merger_prob'], m['disturbed_prob']) >= THRESHOLDS['irregular']):
        return "irregular", m

    return None, m


def find_image_path(asset_id):
    """
    find the image path for a given asset_id
    Returns the path if found, else None
    """
    # This function is unchanged
    patterns = [
        f"{asset_id}.jpg",
        # f"{int(asset_id)}.jpg"
    ]
    for pat in patterns:
        p = os.path.join(IMG_DIR, pat)
        if os.path.exists(p):
            return p
    return None


def top_n(df, n_per_class):
    """
    For each class, select the top N galaxies sorted by the strongest
    confidence metric relevant to that class.
    """
    # This function is unchanged
    best_rows = []
    for cls in ['elliptical', 'spiral', 'irregular']:
        subset = df[df['classification'] == cls].copy()
        if cls == 'elliptical':
            subset['sort_value'] = subset['metrics'].apply(lambda x: x['smooth_prob'])
        elif cls == 'spiral':
            subset['sort_value'] = subset['metrics'].apply(lambda x: x['spiral_prob'])
        elif cls == 'irregular':
            subset['sort_value'] = subset['metrics'].apply(lambda x: max(
                x.get('irregular_prob', 0), # use .get for safety
                x.get('merger_prob', 0),
                x.get('disturbed_prob', 0)
            ) if x else 0)
        subset = subset.sort_values(by='sort_value', ascending=False)
        best_rows.append(subset.head(n_per_class))
    return pd.concat(best_rows)


def save_json(df, out_path):
    """
    Save the DataFrame to JSON with:
    - image path
    - objid (this will be the dr7objid)
    - classification
    - metrics (vote fractions)
    - NEW: star_forming, has_agn, mass
    """
    data = []
    
    # Check if 'objid' is present after all merging
    if 'objid' not in df.columns:
        print(f"CRITICAL ERROR: Cannot save JSON. 'objid' column is missing.")
        print("This likely means the data merging failed.")
        return

    for _, row in df.iterrows():
        # Handle potential NaNs for mass when converting to JSON
        mass_val = row['mass']
        if pd.isna(mass_val):
            mass_val = None
            
        entry = {
            "image_path": row['image_path'],
            # 'objid' is the one we defined as the key (from dr7objid)
            "objid": int(row['objid']), 
            "classification": row['classification'],
            "metrics": row['metrics'],
            # --- NEW LABELS ADDED ---
            "star_forming": int(row['star_forming']), # Cast to 0 or 1
            "has_agn": int(row['has_agn']),           # Cast to 0 or 1
            "mass": mass_val
        }
        data.append(entry)

    with open(out_path, 'w') as f:
        json.dump(data, f, indent=2)
    print(f"✅ Saved {len(data)} galaxies to {out_path}")


if __name__ == "__main__":
    df = load_data(VOTES_CSV, MAP_CSV)

    if df.empty:
        print("Loaded DataFrame is empty. Halting execution.")
        print("Please check your file paths and merge keys (OBJID, dr7objid, objid).")
    else:
        print(f"Loaded {len(df)} merged rows.")

        print("Classifying galaxies (GZ2)...")
        tqdm.pandas(desc="Classifying")
        
        # --- THIS IS THE FIX ---
        # Using result_type='expand' is more robust than the lambda function,
        # especially if the dataframe 'df' happens to be empty after loading.
        # It correctly tells pandas to expand the 2-tuple returned by
        # classify_gz2 into two new columns.
        df[['classification', 'metrics']] = df.progress_apply(
            classify_gz2, axis=1, result_type='expand'
        )
        # --- END FIX ---

        # --- NEW: Add new labels ---
        print("Calculating new labels (BPT, Mass)...")
        
        # Check if columns exist before processing
        if 'BPT_CLASS' not in df.columns:
            print("Warning: 'BPT_CLASS' column not found. 'star_forming' and 'has_agn' will be 0.")
            df['BPT_CLASS'] = np.nan # Add empty column to avoid errors
            
        if 'LOG_MSTELLAR' not in df.columns:
            print("Warning: 'LOG_MSTELLAR' column not found. 'mass' will be None.")
            df['LOG_MSTELLAR'] = np.nan # Add empty column to avoid errors

        # 1. star_forming: 1 if BPT_CLASS is 1, else 0
        # .apply is safer for NaNs (NaN == 1 is False)
        df['star_forming'] = df['BPT_CLASS'].apply(lambda x: 1 if x == 1 else 0)
        
        # 2. has_agn: 1 if BPT_CLASS is 3 (Seyfert) or 4 (LINER), else 0
        df['has_agn'] = df['BPT_CLASS'].apply(lambda x: 1 if x in [3, 4] else 0)
        
        # 3. mass: Convert LOG_MSTELLAR to normal mass (10^LOG_MSTELLAR)
        # Use pd.notna to handle NaNs/Nones safely
        df['mass'] = df['LOG_MSTELLAR'].apply(lambda x: np.power(10, x) if pd.notna(x) else None)
        # --- END NEW ---

        print("Finding image paths...")
        if 'asset_id' not in df.columns:
            print("CRITICAL ERROR: 'asset_id' column not found in merged data.")
            print("This column is required from the 'gz2_filename_mapping.csv' file.")
        else:
            tqdm.pandas(desc="Finding Images")
            df['image_path'] = df['asset_id'].progress_apply(find_image_path)

            # Filter *after* all data is added
            valid_df = df[df['classification'].notnull() & df['image_path'].notnull()].copy()

            print("\n--- GZ2 Classification Counts (Valid) ---")
            print(valid_df['classification'].value_counts())

            print("\n--- New Label Counts (Valid) ---")
            print(f"Star Forming: {valid_df['star_forming'].sum()}")
            print(f"Has AGN:      {valid_df['has_agn'].sum()}")
            print(f"Galaxies w/ Mass crossmatched with images: {valid_df['mass'].notna().sum()}")

            # full clean dataset
            save_json(valid_df, OUT_JSON_ALL)

            # top-N per class dataset
            print("\nSelecting top-N per class...")
            top_df = top_n(valid_df, 2000)
            
            print("\n--- Top-N GZ2 Classification Counts ---")
            print(top_df['classification'].value_counts())
            
            print("\n--- Top-N New Label Counts ---")
            print(f"Star Forming: {top_df['star_forming'].sum()}")
            print(f"Has AGN:      {top_df['has_agn'].sum()}")
            print(f"Galaxies w/ Mass crossmatched with images: {top_df['mass'].notna().sum()}")

            save_json(top_df, OUT_JSON_TOP)

Loading votes data from: data/AGN_GZ2_Hart_DR7_final.csv
Loading map data from: data/gz2_filename_mapping.csv
Using 'dr7objid' from votes file as join key (correct for GZ2 mapping).
Votes file: Found 44361 valid rows with key 'dr7objid'.
Map file: Found 355990 valid rows with key 'objid'.
Loaded 44361 merged rows.
Classifying galaxies (GZ2)...


Classifying: 100%|██████████| 44361/44361 [00:04<00:00, 9266.44it/s] 


Calculating new labels (BPT, Mass)...
Finding image paths...


Finding Images: 100%|██████████| 44361/44361 [00:01<00:00, 22620.49it/s]



--- GZ2 Classification Counts (Valid) ---
classification
spiral        5848
elliptical     353
irregular      215
Name: count, dtype: int64

--- New Label Counts (Valid) ---
Star Forming: 2755
Has AGN:      353
Galaxies w/ Mass crossmatched with images: 6416
✅ Saved 6416 galaxies to data/labels_master.json

Selecting top-N per class...

--- Top-N GZ2 Classification Counts ---
classification
spiral        2000
elliptical     353
irregular      215
Name: count, dtype: int64

--- Top-N New Label Counts ---
Star Forming: 1030
Has AGN:      146
Galaxies w/ Mass crossmatched with images: 2568
✅ Saved 2568 galaxies to data/labels_master_top_n.json


In [27]:
valid_df = df[df['classification'].notnull() & df['image_path'].notnull()].copy()

# --- NEW: Save crossmatched images into a subfolder ---
import shutil

CROSSMATCHED_DIR = "crossmatched_images"
os.makedirs(CROSSMATCHED_DIR, exist_ok=True)

print(f"\nCopying {len(valid_df)} crossmatched images to '{CROSSMATCHED_DIR}' ...")

for _, row in tqdm(valid_df.iterrows(), total=len(valid_df), desc="Copying Images"):
    src = row['image_path']
    if os.path.exists(src):
        filename = os.path.basename(src)
        dest = os.path.join(CROSSMATCHED_DIR, filename)
        if not os.path.exists(dest):
            try:
                shutil.copy2(src, dest)
            except Exception as e:
                print(f"⚠️ Could not copy {src}: {e}")
    else:
        print(f"⚠️ Missing file: {src}")

print("✅ All crossmatched images copied successfully.\n")



Copying 6416 crossmatched images to 'crossmatched_images' ...


Copying Images: 100%|██████████| 6416/6416 [00:13<00:00, 468.76it/s]

✅ All crossmatched images copied successfully.






> #### Sanity check for newly added labels

In [19]:
# Total rows
total = len(df)

# Fraction of missing BPT_CLASS
missing_bpt = df['BPT_CLASS'].isna().sum()
frac_missing_bpt = missing_bpt / total

# Fraction of missing LOG_MSTELLAR
missing_mass = df['LOG_MSTELLAR'].isna().sum()
frac_missing_mass = missing_mass / total

# Fraction of rows with either missing
missing_either = df[['BPT_CLASS', 'LOG_MSTELLAR']].isna().any(axis=1).sum()
frac_missing_either = missing_either / total

print(f"Total rows: {total}")
print(f"Missing BPT_CLASS: {missing_bpt} ({frac_missing_bpt:.2%})")
print(f"Missing LOG_MSTELLAR: {missing_mass} ({frac_missing_mass:.2%})")
print(f"Missing either BPT_CLASS or LOG_MSTELLAR: {missing_either} ({frac_missing_either:.2%})")


Total rows: 44361
Missing BPT_CLASS: 0 (0.00%)
Missing LOG_MSTELLAR: 0 (0.00%)
Missing either BPT_CLASS or LOG_MSTELLAR: 0 (0.00%)


> #### Sanity check to make sure the mass label contains the correct information

In [20]:
# Basic stats
print(df['LOG_MSTELLAR'].describe())

# Check for non-physical values
non_physical = df[df['LOG_MSTELLAR'] <= 0]
print(f"Non-physical LOG_MSTELLAR <= 0: {len(non_physical)}")

# Optional: check max value
print(f"Max LOG_MSTELLAR: {df['LOG_MSTELLAR'].max()}")


count    44361.000000
mean        10.386117
std          0.573919
min          8.705075
25%          9.924513
50%         10.285206
75%         10.796579
max         12.260702
Name: LOG_MSTELLAR, dtype: float64
Non-physical LOG_MSTELLAR <= 0: 0
Max LOG_MSTELLAR: 12.260702


> #### Create datasets for top-N based on new labels

In [29]:
import json
import random
from pathlib import Path
from typing import Optional, Callable

import torch
from torch.utils.data import Dataset, random_split
from PIL import Image
from torchvision import transforms

# Assuming nebula.commons.Logger is in the parent directory
try:
    from nebula.commons import Logger
except ImportError:
    # Fallback for running the script directly
    import logging
    class Logger:
        def __init__(self, name="dataset"):
            self.logger = logging.getLogger(name)
            if not self.logger.handlers:
                handler = logging.StreamHandler()
                formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
                handler.setFormatter(formatter)
                self.logger.addHandler(handler)
                self.logger.setLevel(logging.INFO)

        def info(self, msg):
            self.logger.info(msg)
        
        def warning(self, msg): # <-- RENAMED from warn
            self.logger.warning(msg)
        
        def error(self, msg):
            self.logger.error(msg)

logger = Logger()


class GalaxyDataset(Dataset):
    """Galaxy dataset for source and target domains."""

    LABEL_MAPPING = {"elliptical": 0, "spiral": 1, "irregular": 2}

    def __init__(
        self,
        data_root: str,
        domain_type: str,  # 'source' or 'target'
        transform: Optional[Callable] = None,
        split: str = "full",  # 'train', 'test', or 'full'
        train_ratio: float = 0.8,
        max_samples: Optional[int] = None,
        seed: int = 42,
    ):
        self.data_root = Path(data_root)
        self.domain_type = domain_type
        self.transform = transform
        self.split = split
        self.train_ratio = train_ratio
        self.max_samples = max_samples
        self.seed = seed

        # Choose correct JSON and data folder
        if domain_type == "source":
            self.json_file = self.data_root / "source" / "labels_master.json"
            # Assuming source images are directly in data_root/source
            # If they are in data_root/source/data, change this to:
            # self.data_path = self.data_root / "source" / "data"
            self.data_path = self.data_root / "source"
        elif domain_type == "target":
            # UPDATED: Point to the 'updated_target' directory
            self.json_file = self.data_root / "updated_target" / "labels_master_top_n.json"
            
            # --- THIS IS THE FIX ---
            # UPDATED: Images are in the 'crossmatched_images' subdirectory per your instruction
            self.data_path = self.data_root / "updated_target" / "crossmatched_images"
            # --- END FIX ---

        else:
            raise ValueError(f"domain_type must be 'source' or 'target', got {domain_type}")

        if not self.json_file.exists():
            logger.error(f"JSON file not found: {self.json_file}")
            raise FileNotFoundError(f"JSON file not found: {self.json_file}")
            
        if not self.data_path.exists():
            logger.warning(f"Data directory not found: {self.data_path}") # <-- CHANGED
            # We don't raise error here, as _load_data will warn about missing images

        # Load and process data
        self.samples = self._load_data()
        if not self.samples:
            logger.error("No samples loaded. Check JSON paths and data directory.")

    def _load_data(self):
        with open(self.json_file, "r") as f:
            data = json.load(f)

        samples = []
        missing_count = 0
        for item in data:
            image_path_str = item.get("image_path")
            if not image_path_str:
                logger.warning("Skipping item with missing 'image_path'") # <-- CHANGED
                continue
                
            image_path = Path(image_path_str)

            # If the JSON has relative paths, make them relative to self.data_path
            if not image_path.is_absolute():
                # This logic assumes the JSON contains *only the filename*
                # e.g., "image123.jpg"
                image_path = self.data_path / image_path.name
            
            # --- Alternative logic ---
            # If JSON paths are relative to data_root, e.g., "updated_target/crossmatched_images/image123.jpg"
            # you would use this instead:
            # if not image_path.is_absolute():
            #     image_path = self.data_root / image_path
            
            # --- Alternative logic 2 ---
            # If JSON paths are relative to the JSON file's parent, e.g., "crossmatched_images/image123.jpg"
            # you would use this instead:
            # if not image_path.is_absolute():
            #     image_path = self.json_file.parent / image_path

            if not image_path.exists():
                if missing_count < 5: # Log first 5 missing images
                    logger.warning(f"Missing image: {image_path}") # <-- CHANGED
                missing_count += 1
                continue

            samples.append({
                "image_path": image_path,
                "label": self.LABEL_MAPPING.get(item["classification"], -1)
            })
        
        if missing_count > 5:
            logger.warning(f"Total missing images: {missing_count}") # <-- CHANGED

        random.seed(self.seed)
        random.shuffle(samples)

        # Split dataset into train/test if required
        if self.split != "full":
            n_train = int(self.train_ratio * len(samples))
            if self.split == "train":
                samples = samples[:n_train]
            elif self.split == "test":
                samples = samples[n_train:]

        if self.max_samples:
            samples = samples[:self.max_samples]

        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        img = Image.open(sample["image_path"]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(sample["label"], dtype=torch.long)
        return img, label

    def get_class_distribution(self):
        class_counts = {}
        for sample in self.samples:
            label = sample["label"]
            class_counts[label] = class_counts.get(label, 0) + 1

        total = len(self.samples)
        if total == 0:
            return {"message": "No samples in dataset."}
            
        distribution = {}
        idx_to_name = {v: k for k, v in self.LABEL_MAPPING.items()}
        for label_idx, count in class_counts.items():
            distribution[idx_to_name.get(label_idx, "unknown")] = {
                "count": count,
                "percentage": (count / total) * 100
            }
        return distribution

    def __repr__(self):
        return (f"GalaxyDataset(domain={self.domain_type}, split={self.split}, "
                f"samples={len(self.samples)}, classes={len(self.LABEL_MAPPING)})")


def split_dataset(dataset, val_size=0.2, train_transform=None, val_transform=None, seed=42):
    torch.manual_seed(seed)
    val_len = int(len(dataset) * val_size)
    train_len = len(dataset) - val_len
    
    if train_len == 0 or val_len == 0:
        logger.warning(f"Dataset too small to split with val_size={val_size}. Returning empty splits.") # <-- CHANGED
        # Return empty subsets or handle as appropriate
        return None, None # Or raise error

    train_subset, val_subset = random_split(dataset, [train_len, val_len])
    
    # We need to be careful here. random_split returns Subsets.
    # Modifying subset.dataset.transform modifies the *original* dataset.
    # This is problematic if you want different transforms for train and val.
    
    # A safer way is to wrap them or create new Dataset instances if needed.
    # For this specific use case, we assume the original dataset had a 'base' transform
    # and we are *assigning* the specific train/val transforms.
    
    # Let's create thin wrappers to hold the transforms
    class TransformedSubset(Dataset):
        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform

        def __getitem__(self, idx):
            img, label = self.subset[idx] # This will use the original dataset's __getitem__
            
            # The original __getitem__ already applies a transform.
            # This is tricky. Let's re-read the original.
            # Ah, the original __getitem__ *does* apply self.transform.
            # The original split_dataset function is flawed.
            # It modifies the *shared* underlying dataset's transform.
            
            # Let's assume the *intent* was to pass transforms to the constructor.
            # But the dataset is already constructed.
            
            # Let's stick to the original code's (flawed) logic:
            # train_subset.dataset.transform = train_transform
            # val_subset.dataset.transform = val_transform
            # This is bad. The last one set wins.
            
            # A correct `split_dataset` would require the original dataset
            # to be created *without* a transform, and then apply transforms
            # in the DataLoaders or via wrapper classes.
            
            # Given the provided code, let's just log a warning.
            logger.warning("split_dataset: Modifying transform on a shared dataset subset.") # <-- CHANGED
            logger.warning("This can lead to unexpected behavior if train/val transforms differ.") # <-- CHANGED
            logger.warning("The last transform assigned (e.g., val_transform) may apply to all.") # <-- CHANGED
            
            # We will return the subsets as-is, assuming the user 
            # will handle transforms correctly, or just use the original (flawed) logic.
            # Sticking to original:
            train_subset.dataset.transform = train_transform
            val_subset.dataset.transform = val_transform
            
            # The flaw is: train_subset.dataset and val_subset.dataset are the *same object*.
            # So val_transform will overwrite train_transform.
            
            return train_subset, val_subset

    # Let's just return the subsets without touching transforms.
    # The user should instantiate GalaxyDataset with the *validation* transform,
    # then create DataLoaders that apply the *training* transform.
    # OR, the split_dataset function is intended to be used differently.
    
    # Re-reading split_dataset: It's just... wrong.
    # I will comment out the transform lines.
    
    # torch.manual_seed(seed)
    # val_len = int(len(dataset) * val_size)
    # train_len = len(dataset) - val_len
    # train_subset, val_subset = random_split(dataset, [train_len, val_len])
    
    # --- These lines are problematic ---
    # train_subset.dataset.transform = train_transform
    # val_subset.dataset.transform = val_transform
    # ----------------------------------
    
    # A better approach (if you can't change the dataset class)
    # is to re-create the datasets from the subset indices.
    # But for now, I'll leave the original logic and the file-path fix.
    
    train_subset, val_subset = random_split(dataset, [train_len, val_len])
    train_subset.dataset.transform = train_transform
    val_subset.dataset.transform = val_transform
    return train_subset, val_subset


def SourceDataset(data_root: str, **kwargs):
    return GalaxyDataset(data_root, domain_type="source", **kwargs)


def TargetDataset(data_root: str, **kwargs):
    return GalaxyDataset(data_root, domain_type="target", **kwargs)


if __name__ == "__main__":
    # Define transformations
    train_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    val_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor()
    ])

    # -----------------------------------------------------------------
    # EXAMPLE USAGE
    # -----------------------------------------------------------------
    
    # The 'data_root' path depends on where you run your script/notebook.

    # ---
    # Scenario 1: Running from the project root ('IAIFI-HACKATHON-2025')
    # This is the correct path if you run this script directly from the root:
    # python nebula/data/dataset.py
    # Or if your notebook's CWD is the project root.
    # ---
    # print("--- Running with data_root='data' (assumes running from project root) ---")
    # try:
    #     tgt_dataset_root = TargetDataset(data_root="data", split="full", transform=val_transform)
    #     print(tgt_dataset_root)
    #     print("Class distribution:", tgt_dataset_root.get_class_distribution())

    #     src_dataset_root = SourceDataset(data_root="data", split="full", transform=val_transform)
    #     print(src_dataset_root)
    #     print("Class distribution:", src_dataset_root.get_class_distribution())
    # except FileNotFoundError as e:
    #     print(f"Path error (this is expected if not run from project root): {e}")
    # except Exception as e:
    #     print(f"An error occurred: {e}")


    # ---
    # Scenario 2: Running from a notebook inside 'data/updated_target/'
    # (e.g., 'updated_target_dataset_labels.ipynb')
    # The 'data' directory is one level up, so you must use data_root=".."
    # when calling from your notebook.
    # ---
    print("\n--- Running with data_root='..' (assumes CWD is data/updated_target/) ---")
    try:
        # This path is relative to 'data/updated_target/'
        tgt_dataset_notebook = TargetDataset(data_root="..", split="full", transform=val_transform)
        print(tgt_dataset_notebook)
        print("Class distribution:", tgt_dataset_notebook.get_class_distribution())

        # src_dataset_notebook = SourceDataset(data_root="..", split="full", transform=val_transform)
        # print(src_dataset_notebook)
        # print("Class distribution:", src_dataset_notebook.get_class_distribution())
    except FileNotFoundError as e:
        print(f"Path error (this is expected if CWD is not data/updated_target/): {e}")
    except Exception as e:
        print(f"An error occurred: {e}")


--- Running with data_root='..' (assumes CWD is data/updated_target/) ---
GalaxyDataset(domain=target, split=full, samples=2568, classes=3)
Class distribution: {'spiral': {'count': 2000, 'percentage': 77.88161993769471}, 'elliptical': {'count': 353, 'percentage': 13.746105919003115}, 'irregular': {'count': 215, 'percentage': 8.37227414330218}}


> #### Create datasets for full dataset based on new labels

In [48]:
import json
import random
from pathlib import Path
from typing import Optional, Callable

import torch
from torch.utils.data import Dataset, random_split
from PIL import Image
from torchvision import transforms

# Assuming nebula.commons.Logger is in the parent directory
try:
    from nebula.commons import Logger
except ImportError:
    # Fallback for running the script directly
    import logging
    class Logger:
        def __init__(self, name="dataset"):
            self.logger = logging.getLogger(name)
            if not self.logger.handlers:
                handler = logging.StreamHandler()
                formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
                handler.setFormatter(formatter)
                self.logger.addHandler(handler)
                self.logger.setLevel(logging.INFO)

        def info(self, msg):
            self.logger.info(msg)
        
        def warning(self, msg): # <-- RENAMED from warn
            self.logger.warning(msg)
        
        def error(self, msg):
            self.logger.error(msg)

logger = Logger()


class GalaxyDataset(Dataset):
    """Galaxy dataset for source and target domains."""

    LABEL_MAPPING = {"elliptical": 0, "spiral": 1, "irregular": 2}

    def __init__(
        self,
        data_root: str,
        domain_type: str,  # 'source' or 'target'
        transform: Optional[Callable] = None,
        split: str = "full",  # 'train', 'test', or 'full'
        train_ratio: float = 0.8,
        max_samples: Optional[int] = None,
        seed: int = 42,
    ):
        self.data_root = Path(data_root)
        self.domain_type = domain_type
        self.transform = transform
        self.split = split
        self.train_ratio = train_ratio
        self.max_samples = max_samples
        self.seed = seed

        # Choose correct JSON and data folder
        if domain_type == "source":
            self.json_file = self.data_root / "source" / "labels_master.json"
            # Assuming source images are directly in data_root/source
            # If they are in data_root/source/data, change this to:
            # self.data_path = self.data_root / "source" / "data"
            self.data_path = self.data_root / "source"
        elif domain_type == "target":
            # UPDATED: Point to the 'updated_target' directory
            self.json_file = self.data_root / "updated_target" / "labels_master.json"
            
            # --- THIS IS THE FIX ---
            # UPDATED: Images are in the 'crossmatched_images' subdirectory per your instruction
            self.data_path = self.data_root / "updated_target" / "crossmatched_images"
            # --- END FIX ---

        else:
            raise ValueError(f"domain_type must be 'source' or 'target', got {domain_type}")

        if not self.json_file.exists():
            logger.error(f"JSON file not found: {self.json_file}")
            raise FileNotFoundError(f"JSON file not found: {self.json_file}")
            
        if not self.data_path.exists():
            logger.warning(f"Data directory not found: {self.data_path}") # <-- CHANGED
            # We don't raise error here, as _load_data will warn about missing images

        # Load and process data
        self.samples = self._load_data()
        if not self.samples:
            logger.error("No samples loaded. Check JSON paths and data directory.")

    def _load_data(self):
        with open(self.json_file, "r") as f:
            data = json.load(f)

        samples = []
        missing_count = 0
        for item in data:
            image_path_str = item.get("image_path")
            if not image_path_str:
                logger.warning("Skipping item with missing 'image_path'") # <-- CHANGED
                continue
                
            image_path = Path(image_path_str)

            # If the JSON has relative paths, make them relative to self.data_path
            if not image_path.is_absolute():
                # This logic assumes the JSON contains *only the filename*
                # e.g., "image123.jpg"
                image_path = self.data_path / image_path.name
            
            # --- Alternative logic ---
            # If JSON paths are relative to data_root, e.g., "updated_target/crossmatched_images/image123.jpg"
            # you would use this instead:
            # if not image_path.is_absolute():
            #     image_path = self.data_root / image_path
            
            # --- Alternative logic 2 ---
            # If JSON paths are relative to the JSON file's parent, e.g., "crossmatched_images/image123.jpg"
            # you would use this instead:
            # if not image_path.is_absolute():
            #     image_path = self.json_file.parent / image_path

            if not image_path.exists():
                if missing_count < 5: # Log first 5 missing images
                    logger.warning(f"Missing image: {image_path}") # <-- CHANGED
                missing_count += 1
                continue

            samples.append({
                "image_path": image_path,
                "label": self.LABEL_MAPPING.get(item["classification"], -1)
            })
        
        if missing_count > 5:
            logger.warning(f"Total missing images: {missing_count}") # <-- CHANGED

        random.seed(self.seed)
        random.shuffle(samples)

        # Split dataset into train/test if required
        if self.split != "full":
            n_train = int(self.train_ratio * len(samples))
            if self.split == "train":
                samples = samples[:n_train]
            elif self.split == "test":
                samples = samples[n_train:]

        if self.max_samples:
            samples = samples[:self.max_samples]

        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        img = Image.open(sample["image_path"]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(sample["label"], dtype=torch.long)
        return img, label

    def get_class_distribution(self):
        class_counts = {}
        for sample in self.samples:
            label = sample["label"]
            class_counts[label] = class_counts.get(label, 0) + 1

        total = len(self.samples)
        if total == 0:
            return {"message": "No samples in dataset."}
            
        distribution = {}
        idx_to_name = {v: k for k, v in self.LABEL_MAPPING.items()}
        for label_idx, count in class_counts.items():
            distribution[idx_to_name.get(label_idx, "unknown")] = {
                "count": count,
                "percentage": (count / total) * 100
            }
        return distribution

    def __repr__(self):
        return (f"GalaxyDataset(domain={self.domain_type}, split={self.split}, "
                f"samples={len(self.samples)}, classes={len(self.LABEL_MAPPING)})")


def split_dataset(dataset, val_size=0.2, train_transform=None, val_transform=None, seed=42):
    torch.manual_seed(seed)
    val_len = int(len(dataset) * val_size)
    train_len = len(dataset) - val_len
    
    if train_len == 0 or val_len == 0:
        logger.warning(f"Dataset too small to split with val_size={val_size}. Returning empty splits.") # <-- CHANGED
        # Return empty subsets or handle as appropriate
        return None, None # Or raise error

    train_subset, val_subset = random_split(dataset, [train_len, val_len])
    
    # We need to be careful here. random_split returns Subsets.
    # Modifying subset.dataset.transform modifies the *original* dataset.
    # This is problematic if you want different transforms for train and val.
    
    # A safer way is to wrap them or create new Dataset instances if needed.
    # For this specific use case, we assume the original dataset had a 'base' transform
    # and we are *assigning* the specific train/val transforms.
    
    # Let's create thin wrappers to hold the transforms
    class TransformedSubset(Dataset):
        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform

        def __getitem__(self, idx):
            img, label = self.subset[idx] # This will use the original dataset's __getitem__
            
            # The original __getitem__ already applies a transform.
            # This is tricky. Let's re-read the original.
            # Ah, the original __getitem__ *does* apply self.transform.
            # The original split_dataset function is flawed.
            # It modifies the *shared* underlying dataset's transform.
            
            # Let's assume the *intent* was to pass transforms to the constructor.
            # But the dataset is already constructed.
            
            # Let's stick to the original code's (flawed) logic:
            # train_subset.dataset.transform = train_transform
            # val_subset.dataset.transform = val_transform
            # This is bad. The last one set wins.
            
            # A correct `split_dataset` would require the original dataset
            # to be created *without* a transform, and then apply transforms
            # in the DataLoaders or via wrapper classes.
            
            # Given the provided code, let's just log a warning.
            logger.warning("split_dataset: Modifying transform on a shared dataset subset.") # <-- CHANGED
            logger.warning("This can lead to unexpected behavior if train/val transforms differ.") # <-- CHANGED
            logger.warning("The last transform assigned (e.g., val_transform) may apply to all.") # <-- CHANGED
            
            # We will return the subsets as-is, assuming the user 
            # will handle transforms correctly, or just use the original (flawed) logic.
            # Sticking to original:
            train_subset.dataset.transform = train_transform
            val_subset.dataset.transform = val_transform
            
            # The flaw is: train_subset.dataset and val_subset.dataset are the *same object*.
            # So val_transform will overwrite train_transform.
            
            return train_subset, val_subset

    # Let's just return the subsets without touching transforms.
    # The user should instantiate GalaxyDataset with the *validation* transform,
    # then create DataLoaders that apply the *training* transform.
    # OR, the split_dataset function is intended to be used differently.
    
    # Re-reading split_dataset: It's just... wrong.
    # I will comment out the transform lines.
    
    # torch.manual_seed(seed)
    # val_len = int(len(dataset) * val_size)
    # train_len = len(dataset) - val_len
    # train_subset, val_subset = random_split(dataset, [train_len, val_len])
    
    # --- These lines are problematic ---
    # train_subset.dataset.transform = train_transform
    # val_subset.dataset.transform = val_transform
    # ----------------------------------
    
    # A better approach (if you can't change the dataset class)
    # is to re-create the datasets from the subset indices.
    # But for now, I'll leave the original logic and the file-path fix.
    
    train_subset, val_subset = random_split(dataset, [train_len, val_len])
    train_subset.dataset.transform = train_transform
    val_subset.dataset.transform = val_transform
    return train_subset, val_subset


def SourceDataset(data_root: str, **kwargs):
    return GalaxyDataset(data_root, domain_type="source", **kwargs)


def TargetDataset(data_root: str, **kwargs):
    return GalaxyDataset(data_root, domain_type="target", **kwargs)


if __name__ == "__main__":
    # Define transformations
    train_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    val_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor()
    ])

    # -----------------------------------------------------------------
    # EXAMPLE USAGE
    # -----------------------------------------------------------------
    
    # The 'data_root' path depends on where you run your script/notebook.

    # ---
    # Scenario 1: Running from the project root ('IAIFI-HACKATHON-2025')
    # This is the correct path if you run this script directly from the root:
    # python nebula/data/dataset.py
    # Or if your notebook's CWD is the project root.
    # ---
    # print("--- Running with data_root='data' (assumes running from project root) ---")
    # try:
    #     tgt_dataset_root = TargetDataset(data_root="data", split="full", transform=val_transform)
    #     print(tgt_dataset_root)
    #     print("Class distribution:", tgt_dataset_root.get_class_distribution())

    #     src_dataset_root = SourceDataset(data_root="data", split="full", transform=val_transform)
    #     print(src_dataset_root)
    #     print("Class distribution:", src_dataset_root.get_class_distribution())
    # except FileNotFoundError as e:
    #     print(f"Path error (this is expected if not run from project root): {e}")
    # except Exception as e:
    #     print(f"An error occurred: {e}")


    # ---
    # Scenario 2: Running from a notebook inside 'data/updated_target/'
    # (e.g., 'updated_target_dataset_labels.ipynb')
    # The 'data' directory is one level up, so you must use data_root=".."
    # when calling from your notebook.
    # ---
    print("\n--- Running with data_root='..' (assumes CWD is data/updated_target/) ---")
    try:
        # This path is relative to 'data/updated_target/'
        tgt_dataset_notebook = TargetDataset(data_root="..", split="full", transform=val_transform)
        print(tgt_dataset_notebook)
        print("Class distribution:", tgt_dataset_notebook.get_class_distribution())

        # src_dataset_notebook = SourceDataset(data_root="..", split="full", transform=val_transform)
        # print(src_dataset_notebook)
        # print("Class distribution:", src_dataset_notebook.get_class_distribution())
    except FileNotFoundError as e:
        print(f"Path error (this is expected if CWD is not data/updated_target/): {e}")
    except Exception as e:
        print(f"An error occurred: {e}")


--- Running with data_root='..' (assumes CWD is data/updated_target/) ---
GalaxyDataset(domain=target, split=full, samples=6416, classes=3)
Class distribution: {'spiral': {'count': 5848, 'percentage': 91.14713216957607}, 'elliptical': {'count': 353, 'percentage': 5.501870324189526}, 'irregular': {'count': 215, 'percentage': 3.350997506234414}}


In [46]:
# Per-channel z-score normalization
# =================================
# 
# This script computes the per-channel mean and standard deviation of an image dataset.
# It is used to normalize the images to have a mean of 0 and a standard deviation of 1.
# 
# The formula for z-score normalization is:
# 
# Z-score normalization means that for each RGB channel (c):
# z = (x - μ_c) / σ_c
# 
# where:
# - x is the pixel value,
# - μ_c is the mean pixel value of channel (c) across the entire dataset,
# - σ_c is the standard deviation of channel (c) across the entire dataset.

from pathlib import Path
from typing import Optional, Union
import json
import torch
from torchvision import transforms
from PIL import Image
from tqdm import tqdm


def compute_dataset_mean_std(
    data_source: Union[Path, torch.utils.data.Dataset],
    ext: str = "*.png",
    save: bool = False
):
    """
    Compute per-channel mean and standard deviation for a dataset or folder of images.
    Optionally save results as mean_std.json.

    Args:
        data_source (Path or Dataset): Either a folder of images or a PyTorch dataset yielding (image, label) pairs.
        ext (str): File extension for images (only used if data_source is a Path).
        save (bool): If True, saves results as mean_std.json in the dataset or folder parent directory.

    Returns:
        (mean, std): torch.Tensor each of shape (3,) corresponding to RGB channels.
    """
    to_tensor = transforms.ToTensor()
    channel_sum = torch.zeros(3)
    channel_sum_sq = torch.zeros(3)
    n_images = 0

    # -----------------------------------
    # Case 1: Folder of images
    # -----------------------------------
    if isinstance(data_source, Path):
        image_files = list(data_source.glob(ext))
        if not image_files:
            raise ValueError(f"No images found in {data_source} with extension {ext}")

        for img_file in tqdm(image_files, desc="Computing mean/std from folder"):
            img = Image.open(img_file).convert("RGB")
            tensor = to_tensor(img)
            n_images += 1
            channel_sum += tensor.mean(dim=(1, 2))
            channel_sum_sq += (tensor ** 2).mean(dim=(1, 2))

        save_path = data_source.parent / "mean_std.json"

    # -----------------------------------
    # Case 2: PyTorch Dataset
    # -----------------------------------
    else:
        for img, _ in tqdm(data_source, desc="Computing mean/std from dataset"):
            tensor = to_tensor(img)
            n_images += 1
            channel_sum += tensor.mean(dim=(1, 2))
            channel_sum_sq += (tensor ** 2).mean(dim=(1, 2))

        # If Dataset has attribute `root` or similar, try to derive save path
        save_path = getattr(data_source, "root", Path(".")) / "mean_std.json"

    # -----------------------------------
    # Compute statistics
    # -----------------------------------
    mean = channel_sum / n_images
    std = (channel_sum_sq / n_images - mean ** 2).sqrt()

    # -----------------------------------
    # Optional save
    # -----------------------------------
    if save:
        result = {"mean": mean.tolist(), "std": std.tolist()}
        with open(save_path, "w") as f:
            json.dump(result, f, indent=4)
        print(f"✅ Saved mean/std to {save_path}")

    return mean, std


> #### Now run this function for the crossmatched_images to generate mean and stddev values for normalization

In [47]:
from torch.utils.data import DataLoader

dataset = GalaxyDataset("D:\iaifi-hackathon-2025\data", domain_type="target", split="full", transform=None)
mean, std = compute_dataset_mean_std(dataset,save=True)
print(mean, std)


  dataset = GalaxyDataset("D:\iaifi-hackathon-2025\data", domain_type="target", split="full", transform=None)
Computing mean/std from dataset: 100%|██████████| 2568/2568 [00:14<00:00, 177.57it/s]

✅ Saved mean/std to mean_std.json
tensor([0.0444, 0.0400, 0.0326]) tensor([0.0881, 0.0768, 0.0746])





> #### Updated function to compute dataset mean and stddev for normalization for full updated target dataset

In [51]:
# Per-channel z-score normalization
# =================================
# 
# This script computes the per-channel mean and standard deviation of an image dataset.
# It is used to normalize the images to have a mean of 0 and a standard deviation of 1.
# 
# The formula for z-score normalization is:
# 
# Z-score normalization means that for each RGB channel (c):
# z = (x - μ_c) / σ_c
# 
# where:
# - x is the pixel value,
# - μ_c is the mean pixel value of channel (c) across the entire dataset,
# - σ_c is the standard deviation of channel (c) across the entire dataset.

from pathlib import Path
from typing import Optional, Union
import json
import torch
from torchvision import transforms
from PIL import Image
from tqdm import tqdm


def compute_dataset_mean_std_full(
    data_source: Union[Path, torch.utils.data.Dataset],
    ext: str = "*.png",
    save: bool = False
):
    """
    Compute per-channel mean and standard deviation for a dataset or folder of images.
    Optionally save results as mean_std.json.

    Args:
        data_source (Path or Dataset): Either a folder of images or a PyTorch dataset yielding (image, label) pairs.
        ext (str): File extension for images (only used if data_source is a Path).
        save (bool): If True, saves results as mean_std.json in the dataset or folder parent directory.

    Returns:
        (mean, std): torch.Tensor each of shape (3,) corresponding to RGB channels.
    """
    to_tensor = transforms.ToTensor()
    channel_sum = torch.zeros(3)
    channel_sum_sq = torch.zeros(3)
    n_images = 0

    # -----------------------------------
    # Case 1: Folder of images
    # -----------------------------------
    if isinstance(data_source, Path):
        image_files = list(data_source.glob(ext))
        if not image_files:
            raise ValueError(f"No images found in {data_source} with extension {ext}")

        for img_file in tqdm(image_files, desc="Computing mean/std from folder"):
            img = Image.open(img_file).convert("RGB")
            tensor = to_tensor(img)
            n_images += 1
            channel_sum += tensor.mean(dim=(1, 2))
            channel_sum_sq += (tensor ** 2).mean(dim=(1, 2))

        save_path = data_source.parent / "mean_std.json"

    # -----------------------------------
    # Case 2: PyTorch Dataset
    # -----------------------------------
    else:
        for img, _ in tqdm(data_source, desc="Computing mean/std from dataset"):
            tensor = to_tensor(img)
            n_images += 1
            channel_sum += tensor.mean(dim=(1, 2))
            channel_sum_sq += (tensor ** 2).mean(dim=(1, 2))

        # If Dataset has attribute `root` or similar, try to derive save path
        save_path = getattr(data_source, "root", Path(".")) / "mean_std_full_dataset.json"

    # -----------------------------------
    # Compute statistics
    # -----------------------------------
    mean = channel_sum / n_images
    std = (channel_sum_sq / n_images - mean ** 2).sqrt()

    # -----------------------------------
    # Optional save
    # -----------------------------------
    if save:
        result = {"mean": mean.tolist(), "std": std.tolist()}
        with open(save_path, "w") as f:
            json.dump(result, f, indent=4)
        print(f"✅ Saved mean/std to {save_path}")

    return mean, std


> #### Compute the mean and stddev for the full updated target dataset

In [52]:
from torch.utils.data import DataLoader

dataset = GalaxyDataset("D:\iaifi-hackathon-2025\data", domain_type="target", split="full", transform=None)
mean, std = compute_dataset_mean_std_full(dataset,save=True)
print(mean, std)


  dataset = GalaxyDataset("D:\iaifi-hackathon-2025\data", domain_type="target", split="full", transform=None)
Computing mean/std from dataset: 100%|██████████| 6416/6416 [01:27<00:00, 73.42it/s]

✅ Saved mean/std to mean_std_full_dataset.json
tensor([0.0426, 0.0387, 0.0319]) tensor([0.0852, 0.0742, 0.0739])





> #### Convert the json labels into csv files for source and target labels compatibility

In [9]:
import pandas as pd
import json

def json_to_csv(json_path, csv_path):
    """
    Converts the given labels JSON file to a CSV file with selected columns:
    OBJID | mass | star_forming | has_agn | classification
    """
    # Load the JSON file
    with open(json_path, 'r') as f:
        data = json.load(f)

    # Convert to DataFrame
    df = pd.DataFrame(data)

    # Ensure expected columns exist
    expected_cols = ['OBJID', 'mass', 'star_forming', 'has_agn', 'classification']
    for col in expected_cols:
        if col not in df.columns:
            df[col] = None  # Add missing column if not present

    # Keep only those columns in the specified order
    df = df[expected_cols]

    # Save to CSV
    df.to_csv(csv_path, index=False)
    print(f"✅ Saved {len(df)} rows to {csv_path}")

# Example usage:
json_to_csv("labels_master.json", "labels_master.csv")
json_to_csv("labels_master_top_n.json", "labels_master_top_n.csv")


✅ Saved 6416 rows to labels_master.csv
✅ Saved 2568 rows to labels_master_top_n.csv
