# Model IV Galaxy Classifier Improved (Spiral, Elliptical, Odd objects)

**Most of this code was written by the authors of Ghaderi et al. (2025) (https://iopscience.iop.org/article/10.3847/1538-4365/ada8ab) and is taken from their GitHub repository: https://github.com/hmddev1/machine_learning_for_morphological_galaxy_classification**

**This notebook classifies images of galaxies into spiral galaxies, elliptical galaxies, and odd objects using the Vision Transformer and ResNet50 (Model IV).**

**Purpose of this notebook: Improve Model IV using seeding and cross-validation**

In [None]:
import os
path = '/content/drive/Shared drives/DLP Project/Project/Models/Galaxy Models/Looping each Model'
os.chdir(path)
%run imports.py
%matplotlib inline
import plotting
roc_curves = {}

# Paths to images
spath = r'/content/drive/Shared drives/DLP Project/Project/Data/galaxy/images/cropped_spiral'
epath = r'/content/drive/Shared drives/DLP Project/Project/Data/galaxy/images/cropped_elliptical'
opath = r'/content/drive/Shared drives/DLP Project/Project/Data/galaxy/images/cropped_odd'

# Default image size and zernike order.
image_size = 200
zernike_order = 45

# Loading the ZMs and concatenating to a consolidated dataset
spiral_data = pd.read_csv('/content/drive/Shared drives/DLP Project/Project/spiral_zms.csv')
elliptical_data = pd.read_csv('/content/drive/Shared drives/DLP Project/Project/elliptical_zms.csv')
odd_data = pd.read_csv('/content/drive/Shared drives/DLP Project/Project/odd_zms.csv')

spiral_data.drop("Unnamed: 0", axis = 1, inplace = True)
elliptical_data.drop("Unnamed: 0", axis = 1, inplace = True)
odd_data.drop("Unnamed: 0", axis = 1, inplace = True)

all_zm_data = np.concatenate([spiral_data, elliptical_data, odd_data])
np.shape(all_zm_data)

spiral_label = [0] * len(spiral_data)
elliptical_label = [1] * len(elliptical_data)
odd_label = [2] * len(odd_data)

all_labels = spiral_label + elliptical_label + odd_label
len(all_labels)

11744

In [None]:
def load_galaxy_images(data_dir, target_size):
    """
    Load every ``.jpg`` image in a directory as a resized RGB PIL image.

    Parameters:
    data_dir (str): The directory containing the JPG images to be processed.
    target_size (tuple): The target size for resizing the images, specified as (width, height).

    Returns:
    list: A list of PIL Image objects, one per ``.jpg`` file, ordered by
        sorted file path. Empty if the directory holds no ``.jpg`` files.

    Raises:
    OSError: If a ``.jpg`` file cannot be decoded by ``cv2.imread``.

    The function performs the following steps:
    1. Lists all JPG image files in the directory, sorted so that the order
       (and any subset taken by slicing the result) is reproducible.
    2. Reads each image using OpenCV.
    3. Resizes each image to the specified target size.
    4. Reorders the channels from OpenCV's BGR to the RGB order PIL expects.
    5. Converts each array to a PIL Image object and collects it.
    """
    all_images = []

    # os.listdir returns entries in arbitrary order; sorting makes the
    # result deterministic across runs and machines.
    file_paths = sorted(
        os.path.join(data_dir, filename)
        for filename in os.listdir(data_dir)
        if filename.endswith('.jpg')
    )

    for img_path in file_paths:
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals an unreadable/corrupt file by returning None.
            raise OSError(f"Could not read image file: {img_path}")

        resized_image = cv2.resize(image, target_size)

        # The decoded pixels are already 8-bit values in [0, 255]. The previous
        # `* 255` made the uint8 values wrap around and corrupted the image,
        # so no rescaling is applied here.
        # OpenCV stores channels as BGR; reverse the last axis to get RGB.
        rgb_image = np.ascontiguousarray(resized_image[:, :, ::-1])

        all_images.append(Image.fromarray(rgb_image))

    return all_images

# Folders holding the cropped galaxy images, one per class.
sp_dir = '/content/drive/Shared drives/DLP Project/Project/Data/galaxy/images/cropped_spiral'
el_dir = '/content/drive/Shared drives/DLP Project/Project/Data/galaxy/images/cropped_elliptical'
odd_dir = '/content/drive/Shared drives/DLP Project/Project/Data/galaxy/images/cropped_odd'

image_size = 200

# Load each class and keep only a subset: because of limited computational
# power, 6000 galaxies are used in total (3136 spiral, 2082 elliptical,
# 782 odd), keeping the same ratios between categories.
sp_img = load_galaxy_images(sp_dir, target_size=(image_size, image_size))[:3136]
el_img = load_galaxy_images(el_dir, target_size=(image_size, image_size))[:2082]
odd_img = load_galaxy_images(odd_dir, target_size=(image_size, image_size))[:782]

all_data = [*sp_img, *el_img, *odd_img]

# Integer labels aligned with all_data: 0 = spiral, 1 = elliptical, 2 = odd.
label_s = [0 for _ in sp_img]
label_e = [1 for _ in el_img]
label_o = [2 for _ in odd_img]

all_labels = [*label_s, *label_e, *label_o]
len(all_labels)


# Training pipeline: centre crop, random rotation/flips/resized crop for
# augmentation, then tensor conversion and per-channel normalisation.
train_steps = [
    transforms.CenterCrop(image_size),
    transforms.RandomRotation(90),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0), ratio=(0.99, 1.01)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
train_transform = transforms.Compose(train_steps)


# Test pipeline: the same crop and normalisation, without random augmentation.
test_steps = [
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
test_transform = transforms.Compose(test_steps)



In [None]:
# Seeded, stratified k-fold cross-validation of the ResNet50 classifier.
# For every fold a fresh model is trained on that fold's training indices and
# evaluated on its held-out test indices; per-fold metrics are collected in
# `metrics_list` and printed at the end.

import torch  # imported here so the torchvision augmentations can be seeded

# Set a fixed seed for reproducibility
seed_value = 42
np.random.seed(seed_value)
random.seed(seed_value)
tf.random.set_seed(seed_value)
# torchvision's random transforms draw from torch's global RNG, which the
# seeds above do not touch.
torch.manual_seed(seed_value)

metrics_list = []

class_names = ['Spiral', 'Elliptical', 'Odd']
class_ids = list(range(len(class_names)))  # label values 0, 1, 2
labels_array = np.asarray(all_labels)

# Cross Validation Setup using StratifiedKFold
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed_value)

for fold, (train_idx, test_idx) in enumerate(skf.split(all_data, all_labels)):
    print(f"Fold {fold + 1}...")

    # A new model is built on every fold; drop the previous fold's Keras state
    # so memory does not keep growing across folds.
    tf.keras.backend.clear_session()

    # Use the indices produced by StratifiedKFold (previously ignored in favour
    # of an unseeded train_test_split, so the folds overlapped and were not
    # reproducible). The training indices are shuffled with a seeded generator
    # because `validation_split` takes the *last* fraction of the array, and
    # the labels in all_labels are grouped by class.
    fold_rng = np.random.default_rng(seed_value + fold)
    train_indices = fold_rng.permutation(train_idx)
    test_indices = np.asarray(test_idx)

    X_train = [all_data[k] for k in train_indices]
    X_test = [all_data[k] for k in test_indices]
    y_train = labels_array[train_indices]
    y_test = labels_array[test_indices]
    y_train_encoded = to_categorical(y_train, num_classes=3)

    # Transformer for training data: channels-first tensor -> channels-last array.
    # NOTE: the random augmentation is applied once here, not re-drawn per epoch.
    transformed_X_train = [train_transform(img).permute(1, 2, 0).numpy() for img in X_train]

    # Transformer for testing data
    transformed_X_test = [test_transform(img).permute(1, 2, 0).numpy() for img in X_test]

    train_array = np.stack(transformed_X_train)
    test_array = np.stack(transformed_X_test)

    # Balanced class weights computed from this fold's training labels
    # (n_samples / (n_classes * n_samples_in_class)).
    train_counts = np.bincount(y_train, minlength=len(class_ids))
    class_weights = {c: float(len(y_train) / (len(class_ids) * train_counts[c]))
                     for c in class_ids}

    # Defining the pretrained ResNet50
    b_size = 64
    e_num = 30

    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))

    x = Flatten()(base_model.output)
    x = Dense(64, activation='relu')(x)  # The custom layers
    output = Dense(3, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=output)

    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)

    history = model.fit(
        train_array, y_train_encoded,
        batch_size=b_size,
        epochs=e_num,
        verbose=1,
        callbacks=[es],
        class_weight=class_weights,
        validation_split=0.1)

    y_pred = model.predict(test_array)
    y_pred_labels = np.argmax(y_pred, axis=1)

    # Compute confusion matrix; `labels` pins it to 3x3 even if a class is
    # never predicted.
    cm = confusion_matrix(y_test, y_pred_labels, labels=class_ids)

    # Performance metrics (per-class)
    recall_per_class = recall_score(y_test, y_pred_labels, labels=class_ids, average=None)
    precision_per_class = precision_score(y_test, y_pred_labels, labels=class_ids, average=None)
    f1_per_class = f1_score(y_test, y_pred_labels, labels=class_ids, average=None)
    accuracy = accuracy_score(y_test, y_pred_labels)  # overall

    # One-vs-rest true skill statistic per class: TPR - FPR. The small epsilon
    # guards against division by zero.
    tss_per_class = {}
    for i, class_name in enumerate(class_names):
        tp = cm[i, i]
        fn = np.sum(cm[i, :]) - tp
        fp = np.sum(cm[:, i]) - tp
        tn = np.sum(cm) - (tp + fn + fp)
        tss_per_class[class_name] = (tp / (tp + fn + 1e-6)) - (fp / (fp + tn + 1e-6))

    metrics_list.append({
        'Fold': fold + 1,
        'Recall per Class': recall_per_class,
        'Precision per Class': precision_per_class,
        'F1 per Class': f1_per_class,
        'Accuracy': accuracy,
        'TSS per Class': tss_per_class
    })

for result in metrics_list:
    print(result)


Fold 1...
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 0us/step
Epoch 1/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m102s[0m 693ms/step - accuracy: 0.4862 - loss: 5.1528 - val_accuracy: 0.1600 - val_loss: 11.9960
Epoch 2/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 255ms/step - accuracy: 0.4883 - loss: 0.7710 - val_accuracy: 0.2778 - val_loss: 1.0920
Epoch 3/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 254ms/step - accuracy: 0.8500 - loss: 0.4641 - val_accuracy: 0.5044 - val_loss: 1.0808
Epoch 4/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 255ms/step - accuracy: 0.8173 - loss: 0.5284 - val_accuracy: 0.3289 - val_loss: 13.4202
Epoch 5/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 254ms/step - accuracy: 0.9046 -

In [None]:
# Running the same code independently, due to limited computational resources.
#
# Seeded, stratified k-fold cross-validation of the ResNet50 classifier.
# For every fold a fresh model is trained on that fold's training indices and
# evaluated on its held-out test indices; per-fold metrics are collected in
# `metrics_list` and printed at the end.

import torch  # imported here so the torchvision augmentations can be seeded

# Set a fixed seed for reproducibility
seed_value = 42
np.random.seed(seed_value)
random.seed(seed_value)
tf.random.set_seed(seed_value)
# torchvision's random transforms draw from torch's global RNG, which the
# seeds above do not touch.
torch.manual_seed(seed_value)

metrics_list = []

class_names = ['Spiral', 'Elliptical', 'Odd']
class_ids = list(range(len(class_names)))  # label values 0, 1, 2
labels_array = np.asarray(all_labels)

# Cross Validation Setup using StratifiedKFold
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed_value)

for fold, (train_idx, test_idx) in enumerate(skf.split(all_data, all_labels)):
    print(f"Fold {fold + 1}...")

    # A new model is built on every fold; drop the previous fold's Keras state
    # so memory does not keep growing across folds.
    tf.keras.backend.clear_session()

    # Use the indices produced by StratifiedKFold (previously ignored in favour
    # of an unseeded train_test_split, so the folds overlapped and were not
    # reproducible). The training indices are shuffled with a seeded generator
    # because `validation_split` takes the *last* fraction of the array, and
    # the labels in all_labels are grouped by class.
    fold_rng = np.random.default_rng(seed_value + fold)
    train_indices = fold_rng.permutation(train_idx)
    test_indices = np.asarray(test_idx)

    X_train = [all_data[k] for k in train_indices]
    X_test = [all_data[k] for k in test_indices]
    y_train = labels_array[train_indices]
    y_test = labels_array[test_indices]
    y_train_encoded = to_categorical(y_train, num_classes=3)

    # Transformer for training data: channels-first tensor -> channels-last array.
    # NOTE: the random augmentation is applied once here, not re-drawn per epoch.
    transformed_X_train = [train_transform(img).permute(1, 2, 0).numpy() for img in X_train]

    # Transformer for testing data
    transformed_X_test = [test_transform(img).permute(1, 2, 0).numpy() for img in X_test]

    train_array = np.stack(transformed_X_train)
    test_array = np.stack(transformed_X_test)

    # Balanced class weights computed from this fold's training labels
    # (n_samples / (n_classes * n_samples_in_class)).
    train_counts = np.bincount(y_train, minlength=len(class_ids))
    class_weights = {c: float(len(y_train) / (len(class_ids) * train_counts[c]))
                     for c in class_ids}

    # Defining the pretrained ResNet50
    b_size = 64
    e_num = 30

    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))

    x = Flatten()(base_model.output)
    x = Dense(64, activation='relu')(x)  # The custom layers
    output = Dense(3, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=output)

    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)

    history = model.fit(
        train_array, y_train_encoded,
        batch_size=b_size,
        epochs=e_num,
        verbose=1,
        callbacks=[es],
        class_weight=class_weights,
        validation_split=0.1)

    y_pred = model.predict(test_array)
    y_pred_labels = np.argmax(y_pred, axis=1)

    # Compute confusion matrix; `labels` pins it to 3x3 even if a class is
    # never predicted.
    cm = confusion_matrix(y_test, y_pred_labels, labels=class_ids)

    # Performance metrics (per-class)
    recall_per_class = recall_score(y_test, y_pred_labels, labels=class_ids, average=None)
    precision_per_class = precision_score(y_test, y_pred_labels, labels=class_ids, average=None)
    f1_per_class = f1_score(y_test, y_pred_labels, labels=class_ids, average=None)
    accuracy = accuracy_score(y_test, y_pred_labels)  # overall

    # One-vs-rest true skill statistic per class: TPR - FPR. The small epsilon
    # guards against division by zero.
    tss_per_class = {}
    for i, class_name in enumerate(class_names):
        tp = cm[i, i]
        fn = np.sum(cm[i, :]) - tp
        fp = np.sum(cm[:, i]) - tp
        tn = np.sum(cm) - (tp + fn + fp)
        tss_per_class[class_name] = (tp / (tp + fn + 1e-6)) - (fp / (fp + tn + 1e-6))

    metrics_list.append({
        'Fold': fold + 1,
        'Recall per Class': recall_per_class,
        'Precision per Class': precision_per_class,
        'F1 per Class': f1_per_class,
        'Accuracy': accuracy,
        'TSS per Class': tss_per_class
    })

for result in metrics_list:
    print(result)


Fold 1...
Epoch 1/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 580ms/step - accuracy: 0.6253 - loss: 4.8092 - val_accuracy: 0.3289 - val_loss: 5.1448
Epoch 2/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 257ms/step - accuracy: 0.7767 - loss: 0.5665 - val_accuracy: 0.1600 - val_loss: 6.7775
Epoch 3/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 257ms/step - accuracy: 0.8598 - loss: 0.3761 - val_accuracy: 0.1600 - val_loss: 1.9046
Epoch 4/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 255ms/step - accuracy: 0.9031 - loss: 0.2974 - val_accuracy: 0.3511 - val_loss: 1.4716
Epoch 5/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 257ms/step - accuracy: 0.9186 - loss: 0.2529 - val_accuracy: 0.1756 - val_loss: 4.9492
Epoch 6/30
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 255ms/step - accuracy: 0.9456 - loss: 0.1477 - val_accuracy: 0.1622 - val_loss: 2.0510
Epoch 7/30
