In [1]:
# Import necessary modules:
from astropy.io import fits

import h5py
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns
import copy
import pickle
import os
import gc

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K


2025-06-11 17:53:29.414537: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-11 17:53:29.414788: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-11 17:53:30.144019: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-11 17:53:31.194444: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:

# ------------------------------------------------------------------------------------------------------------
# Experiment switches

# Drop the age extremes when True (currently: younger than 10^7 years or older than 10^9.5 years):
flag_remove_extremes = False

# Whether to use the augmented dataset:
flag_use_augmented = False

# Whether to black out a circular region of the images:
flag_black = False

# Normalization method, one of: "dataset", "filter", "five-images", "single-image".
# Author's note: normalizing by dataset or by filter gave odd results because the pixel values
# differ too much (very small or very large).
norm_by = "five-images"

# ------------------------------------------------------------------------------------------------------------
# Paths and naming

# HDF5 file holding the raw dataset:
dir_data_raw = "/pool001/vianajr/cluster_ages_1/data/data_raw/raw_phangs_dataset.h5"

# "yes"/"no" fragments embedded in the results prefix:
txt_extremes = "no" if not flag_remove_extremes else "yes"
txt_augment = "no" if not flag_use_augmented else "yes"
# Base prefix for the results directory:
results_prefix = f"single_case_5im_remextremes_{txt_extremes}_augment_{txt_augment}_"

# Blackout settings; these names only exist when flag_black is True.
if flag_black:
    # True -> black out the centre of the circle, False -> black out everything outside it:
    flag_black_inner = False
    # Radius of the blackout circle, in pixels:
    R = 6
    # Extension for the paths:
    extension = "inner" if flag_black_inner else "outer"

# ------------------------------------------------------------------------------------------------------------
# Run settings

# Number of models per case to get an average of errors:
num_models_per_case = 5

# Flag to plot preliminary data visualization:
flag_plot_data_viz = True

# ------------------------------------------------------------------------------------------------------------
# Echo the configuration

print()
print()
for banner_line in (
    "--------------------------------------",
    "SINGLE -------------------------------",
    "5-im case ----------------------------",
):
    print(banner_line)
print()
print("Params:")
print()
print("flag_remove_extremes: ", flag_remove_extremes)
print("flag_use_augmented: ", flag_use_augmented)
print()
print("flag_black: ", flag_black)
if flag_black:
    print()
    print("flag_black_inner: ", flag_black_inner)
    print("R: ", R)
print()
print("Normalization by")
print(norm_by)
print()





--------------------------------------
SINGLE -------------------------------
5-im case ----------------------------

Params:

flag_remove_extremes:  False
flag_use_augmented:  False

flag_black:  False

Normalization by
five-images



In [4]:

# ------------------------------------------------------------------------------------------------------------
# Reproducibility: one seed shared by NumPy's global RNG and Python's `random` module.
# NOTE(review): TensorFlow's RNG is not seeded in this cell.
random_seed = 15
np.random.seed(random_seed)
random.seed(random_seed)

# ------------------------------------------------------------------------------------------------------------
# Class to read the dataset in the format it is:
class ReadPhangsH5:
    """In-memory view of the PHANGS HDF5 file.

    The constructor reads four datasets from the file and stores them as
    NumPy arrays:

    * ``cluster_ids``      -- ``hf["cluster_ids"]`` as int32
    * ``galaxy_ids``       -- ``hf["galaxy_ids"]`` converted with ``astype(str)``
    * ``image_cutouts``    -- ``hf["image_cutouts"]`` as float32
    * ``cluster_log_ages`` -- ``hf["cluster_log_ages"]`` as float32

    Indexing returns ``(image_cutouts[index], cluster_log_ages[index])`` and
    ``len()`` is the number of entries in ``image_cutouts``.
    """

    def __init__(self, hdf5_filename):
        # Everything is copied into memory inside the `with`, so the file is
        # closed again as soon as the constructor returns.
        with h5py.File(hdf5_filename, "r") as h5_file:
            # Targets: log of the cluster ages.
            self.cluster_log_ages = np.array(h5_file["cluster_log_ages"], dtype=np.float32)
            # Inputs: the image cutouts.
            self.image_cutouts = np.array(h5_file["image_cutouts"], dtype=np.float32)
            # Identifiers; astype(str) turns the stored strings into Python/NumPy str.
            self.galaxy_ids = np.array(h5_file["galaxy_ids"]).astype(str)
            self.cluster_ids = np.array(h5_file["cluster_ids"], dtype=np.int32)

    def __getitem__(self, index):
        # (cutouts, log-age) pair for the requested instance.
        return self.image_cutouts[index], self.cluster_log_ages[index]

    def __len__(self):
        return len(self.image_cutouts)

    
# ------------------------------------------------------------------------------------------------------------
# Function to split the input (X) and output (Y) from the dataset
def separate_X_Y(dataset):
    """Split an iterable of ``(x, y)`` pairs into two NumPy arrays.

    The dataset is traversed exactly once, so each item is fetched a single
    time and one-shot iterables (generators, ``zip`` objects) work as well as
    lists or indexable dataset objects.

    :param dataset: Iterable yielding ``(x, y)`` pairs.
    :return: Tuple ``(np.array(all x), np.array(all y))`` in iteration order.
    """
    X, Y = [], []
    for x, y in dataset:
        X.append(x)
        Y.append(y)
    return np.array(X), np.array(Y)


# ------------------------------------------------------------------------------------------------------------
# Function to blackout a circle from the center:
def blackout(images, R, flag_black_inner):
    """
    Zero out a circular region (or its complement) in every image.

    The circle is centred on pixel (width // 2, height // 2) and contains every
    pixel whose squared distance to that centre is <= R ** 2, so it is only a
    pixel-grid approximation of a circle.

    The input array is NOT modified: the masking is applied to a copy, so the
    caller's original data stays available after the call.

    :param images: A numpy array of shape (n_samples, n_filters, height, width)
    :param R: Radius of the blackout circle, in pixels
    :param flag_black_inner: If True the inside of the circle is zeroed,
                             otherwise everything outside the circle is zeroed.
    :return: A new array with the same shape as ``images`` and the selected region set to 0
    """
    # Unpacking four values also enforces the expected 4-D layout.
    _, _, height, width = images.shape

    # Centre pixel of the images
    center_x, center_y = width // 2, height // 2

    # Open grids of row / column coordinates; they broadcast to (height, width)
    y, x = np.ogrid[:height, :width]

    # True inside the circle (border included)
    mask = (x - center_x) ** 2 + (y - center_y) ** 2 <= (R ** 2)

    # Blacking out the outer region means masking the complement instead
    if not flag_black_inner:
        mask = np.logical_not(mask)

    # Work on a copy so the caller's array is left untouched, and apply the 2-D
    # mask to the last two axes of every (sample, filter) image at once.
    blacked = images.copy()
    blacked[:, :, mask] = 0

    return blacked


# ------------------------------------------------------------------------------------------------------------
# Define a function to split the data into tr, vl, and ts sets
def split_dataset(N, tr_ratio=0.7, vl_ratio=0.15, seed=42):
    """Shuffle the indices 0..N-1 and cut them into train / validation / test lists.

    A private ``random.Random(seed)`` generator does the shuffling, so calling
    this function does not reseed or advance the global ``random`` module
    state. For a given ``seed`` the resulting permutation is the same one that
    ``random.seed(seed); random.shuffle(...)`` produces.

    :param N: Number of instances to split.
    :param tr_ratio: Fraction of the indices assigned to the training set.
    :param vl_ratio: Fraction of the indices assigned to the validation set.
    :param seed: Seed of the private shuffling generator.
    :return: Tuple ``(tr_indices, vl_indices, ts_indices)`` of index lists; the
             test set receives whatever remains after the first two cuts.
    """
    # Local generator: leaves the notebook-wide random seed untouched.
    rng = random.Random(seed)
    indices = list(range(N))
    rng.shuffle(indices)

    # Cut points, truncated towards zero
    tr_split = int(tr_ratio * N)
    vl_split = int((tr_ratio + vl_ratio) * N)

    tr_indices = indices[:tr_split]
    vl_indices = indices[tr_split:vl_split]
    ts_indices = indices[vl_split:]

    return tr_indices, vl_indices, ts_indices


    
    
# ------------------------------------------------------------------------------------------------------------
# Function to augment a dataset by creating 8 versions 4 rotations and reversed 4 totations:
def augment_dataset_full(data):
    """Augment a dataset with the 8 rotated / flipped versions of every instance.

    For each ``(x, y)`` pair, where ``x`` is a stack of 2-D images, eight
    entries are produced in this order: the four 90-degree rotations
    (k = 0, 1, 2, 3) of the original images, followed by the same four
    rotations of the up-down flipped images. The target ``y`` is repeated
    unchanged for every version.

    :param data: Iterable of ``(x, y)`` pairs; ``x`` iterates over 2-D images.
    :return: Tuple ``(aug_data, aug_past_idxs)`` where ``aug_data`` is the list
             of ``(augmented_x, y)`` pairs and ``aug_past_idxs[j]`` is the
             position in ``data`` of the instance that produced ``aug_data[j]``.
    """
    # The actual augmented data:
    aug_data = []
    # For every augmented entry, the index of the instance it came from:
    aug_past_idxs = []

    for src_idx, (x, y) in enumerate(data):

        # Up-down flipped copy of every image of the instance
        flip_x = np.array([np.flipud(image) for image in x])

        # Original first, then the flipped version
        for variant in (x, flip_x):
            # Rotate 0, 90, 180, 270 degrees. The rotation counter uses its own
            # name so it cannot overwrite the source index recorded below.
            for k in range(4):
                rotated = np.array([np.rot90(image, k=k) for image in variant])
                aug_data.append((rotated, y))
                aug_past_idxs.append(src_idx)

    return aug_data, aug_past_idxs



In [6]:
# ------------------------------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------------------------------

# Start running the code

# Read the HDF5 file into memory:
raw_dataset = ReadPhangsH5(dir_data_raw)

# Per-instance identifiers, taken straight from the loaded dataset:
raw_clust_ids, raw_galax_ids = raw_dataset.cluster_ids, raw_dataset.galaxy_ids

# Inputs (image cutouts) and targets (log ages) as two arrays:
raw_X, raw_Y = separate_X_Y(raw_dataset)

# Display:
print("X.shape: ", raw_X.shape)
print()


X.shape:  (8651, 5, 112, 112)



In [7]:

# Build the curated ("cur_") working set: either everything, or only the
# instances with 7 <= Y <= 9.5 when flag_remove_extremes is True.
if not flag_remove_extremes:
    # Nothing to drop: the curated names simply point at the raw arrays.
    cur_X, cur_Y = raw_X, raw_Y
    cur_clust_ids, cur_galax_ids = raw_clust_ids, raw_galax_ids

else:
    # True for the instances that are kept:
    mask = (raw_Y >= 7) & (raw_Y <= 9.5)

    # Data restricted to the kept instances:
    cur_X, cur_Y = raw_X[mask], raw_Y[mask]
    # Matching ids:
    cur_clust_ids, cur_galax_ids = raw_clust_ids[mask], raw_galax_ids[mask]

    # Display:
    print("New X.shape after removing young clusters: ", cur_X.shape)
    print()

In [9]:
# Table of (cluster_id, galaxy_id) pairs, one row per raw instance.
# NOTE(review): this uses the raw ids, not the curated cur_* ids; the two only
# differ when flag_remove_extremes is True. Confirm which list is intended.
df = pd.DataFrame(dict(cluster_id=raw_clust_ids, galaxy_id=raw_galax_ids))

# Write it out, keeping the row index as the first column:
df.to_csv('using_clusters.csv', index=True)

In [None]:
# Evaluating this cell always raises ZeroDivisionError.
# NOTE(review): presumably a deliberate stop so that "Run All" halts here and the
# cells below are not executed; confirm, and remove or skip this cell to run them.
1/0

In [None]:
# Shuffle-split the curated instances into train / validation / test index lists:
tr_indices, vl_indices, ts_indices = split_dataset(len(cur_Y))

# ------------------------------------------------------------------------------------------------------------
# Build the list of [x, y] pairs the three splits are drawn from.

if not flag_black:

    # Plain cutouts paired with their targets:
    normal_dataset = [[cutout, target] for cutout, target in zip(cur_X, cur_Y)]
    chosen_dataset = normal_dataset

else:

    # Cutouts with the circular region zeroed, paired with their targets:
    blck_X = blackout(cur_X, R, flag_black_inner)
    blck_dataset = [[cutout, target] for cutout, target in zip(blck_X, cur_Y)]
    chosen_dataset = blck_dataset


# ------------------------------------------------------------------------------------------------------------
# Pick the pairs that belong to each split:
tr_data, vl_data, ts_data = (
    [chosen_dataset[k] for k in split]
    for split in (tr_indices, vl_indices, ts_indices)
)

# Separate inputs and targets of every split:
X_tr, Y_tr = separate_X_Y(tr_data)
X_vl, Y_vl = separate_X_Y(vl_data)
X_ts, Y_ts = separate_X_Y(ts_data)

# Report the shapes:
print(f"Shape of X_tr: {X_tr.shape}")
print(f"Shape of X_vl: {X_vl.shape}")
print(f"Shape of X_ts: {X_ts.shape}")
print()
print(f"Shape of Y_tr: {Y_tr.shape}")
print(f"Shape of Y_vl: {Y_vl.shape}")
print(f"Shape of Y_ts: {Y_ts.shape}")
print()

# Ids of the instances in each split, in the same order as the data.
# Cluster ids:
tr_clust_ids, vl_clust_ids, ts_clust_ids = (
    [cur_clust_ids[k] for k in split]
    for split in (tr_indices, vl_indices, ts_indices)
)
# Galaxy ids:
tr_galax_ids, vl_galax_ids, ts_galax_ids = (
    [cur_galax_ids[k] for k in split]
    for split in (tr_indices, vl_indices, ts_indices)
)


# ------------------------------------------------------------------------------------------------------------
# Whatever set is used for training, evaluation is also reported on two subsets of ALL the raw data:
#
#     Inner: instances inside the limits, 7 <= Y <= 9.5 (no young or old extremes)
#     Outer: every other instance (the extremes)

# Boolean masks over the raw instances:
msk_subset_inner = (raw_Y >= 7) & (raw_Y <= 9.5)
msk_subset_outer = ~ msk_subset_inner

# Inputs for all the instances, with the blackout applied when it is enabled:
treated_all_X = blackout(raw_X, R, flag_black_inner) if flag_black else raw_X

# Inner subset: inputs, targets and ids.
subset_inner_X = treated_all_X[msk_subset_inner]
subset_inner_Y = raw_Y[msk_subset_inner]
subset_inner_clust_ids = raw_clust_ids[msk_subset_inner]
subset_inner_galax_ids = raw_galax_ids[msk_subset_inner]

# Outer subset: inputs, targets and ids.
subset_outer_X = treated_all_X[msk_subset_outer]
subset_outer_Y = raw_Y[msk_subset_outer]
subset_outer_clust_ids = raw_clust_ids[msk_subset_outer]
subset_outer_galax_ids = raw_galax_ids[msk_subset_outer]

# CAREFUL: You must get the test points of each subset, you cannot pick any from the training set.
# We will proceed by creating a unique identifier of each instance, then seeing which are both in the ts and the subset.

# Function to create a unique identifier for each instance based on the cluster and the galaxy id:
def create_unique_ids(clust_ids, galax_ids):
    """Return one ``"<cluster id>_<galaxy id>"`` string per (cluster, galaxy) pair.

    The underscore keeps the two parts apart: plain concatenation would map
    both (1, "1x") and (11, "x") to the same string "11x".

    :param clust_ids: Iterable of cluster ids (converted with ``str``).
    :param galax_ids: Iterable of galaxy ids, paired element-wise with ``clust_ids``.
    :return: List of identifier strings, in input order.
    """
    return [f"{a}_{b}" for a, b in zip(clust_ids, galax_ids)]

# Obtain the unique identifiers of the test split and of both age subsets:
ts_unique_ids = create_unique_ids(ts_clust_ids, ts_galax_ids)
subset_outer_unique_ids = create_unique_ids(subset_outer_clust_ids, subset_outer_galax_ids)
subset_inner_unique_ids = create_unique_ids(subset_inner_clust_ids, subset_inner_galax_ids)

# Identifiers that are both in the test split and in the given subset:
matching_ts_and_outer_ids = set(ts_unique_ids) & set(subset_outer_unique_ids)
matching_ts_and_inner_ids = set(ts_unique_ids) & set(subset_inner_unique_ids)

# When the extremes were removed before splitting, no outer instance is in the test split
# (the intersection above is empty), so the whole outer subset is used as the candidate pool:
if flag_remove_extremes: matching_ts_and_outer_ids = set(subset_outer_unique_ids)

# Select 600 points from matching IDs - We specify 600 for the test set, to make sure all comparisons are fair among all cases.
# The candidates are sorted first: the iteration order of a set of strings changes from one
# interpreter session to the next (string hash randomization), so sampling from list(set)
# would pick different points on every run even with the NumPy seed fixed.
# np.random.choice raises ValueError if fewer than 600 candidates are available.
matching_ts_and_outer_indices = np.random.choice(sorted(matching_ts_and_outer_ids), 600, replace=False)
matching_ts_and_inner_indices = np.random.choice(sorted(matching_ts_and_inner_ids), 600, replace=False)

# Boolean masks over each subset marking the sampled identifiers:
ts_msk_subset_outer = np.isin(subset_outer_unique_ids, matching_ts_and_outer_indices)
ts_msk_subset_inner = np.isin(subset_inner_unique_ids, matching_ts_and_inner_indices)

# Finally get the evaluation subsets of both X and Y:
ts_subset_outer_X = subset_outer_X[ts_msk_subset_outer]
ts_subset_inner_X = subset_inner_X[ts_msk_subset_inner]
# Don't forget to ravel the Ys:
ts_subset_outer_Y = np.ravel(subset_outer_Y[ts_msk_subset_outer])
ts_subset_inner_Y = np.ravel(subset_inner_Y[ts_msk_subset_inner])

# ------------------------------------------------------------------------------------------------------------

# Training inputs and flattened targets, plus the ids of the training instances:
use_X_tr, use_Y_tr = X_tr, np.ravel(Y_tr)
use_tr_clust_ids, use_tr_galax_ids = tr_clust_ids, tr_galax_ids


# Validation / test inputs:
use_X_vl, use_X_ts = X_vl, X_ts
# Validation / test targets, flattened:
use_Y_vl, use_Y_ts = np.ravel(Y_vl), np.ravel(Y_ts)
# Validation / test ids:
use_vl_clust_ids, use_ts_clust_ids = vl_clust_ids, ts_clust_ids
use_vl_galax_ids, use_ts_galax_ids = vl_galax_ids, ts_galax_ids


# Move the channels/filters axis to the end: the 5 per-filter slices are re-stacked
# on a new last axis, turning (n, 5, h, w) inputs into (n, h, w, 5).
use_X_tr = np.stack([use_X_tr[:, band, :, :] for band in range(5)], axis=-1)
use_X_vl = np.stack([use_X_vl[:, band, :, :] for band in range(5)], axis=-1)
use_X_ts = np.stack([use_X_ts[:, band, :, :] for band in range(5)], axis=-1)

# Same layout change for the two evaluation subsets:
ts_subset_outer_X = np.stack([ts_subset_outer_X[:, band, :, :] for band in range(5)], axis=-1)
ts_subset_inner_X = np.stack([ts_subset_inner_X[:, band, :, :] for band in range(5)], axis=-1)


# Report the final shapes:
print(f"Shape of use_X_tr: {use_X_tr.shape}")
print(f"Shape of use_X_vl: {use_X_vl.shape}")
print(f"Shape of use_X_ts: {use_X_ts.shape}")
print()
print(f"Shape of use_Y_tr: {use_Y_tr.shape}")
print(f"Shape of use_Y_vl: {use_Y_vl.shape}")
print(f"Shape of use_Y_ts: {use_Y_ts.shape}")
print()