In [None]:
!pip install tensorflow
!pip install tensorflow_probability
!pip install git+https://github.com/henrysky/astroNN.git

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras import utils

import numpy as np
import pylab as plt
import random
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

from astroNN.models import Galaxy10CNN
from astroNN.datasets import load_galaxy10sdss
from astroNN.datasets.galaxy10sdss import galaxy10cls_lookup, galaxy10_confusion

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import History

from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, average_precision_score, roc_curve, precision_recall_curve
from sklearn.preprocessing import LabelEncoder, label_binarize
from sklearn.model_selection import train_test_split

In [None]:
images, labels = load_galaxy10sdss()

In [None]:
train_images, temp_images, train_labels, temp_labels = train_test_split(
    images, labels, test_size=0.3, random_state=42, stratify=labels
)

val_images, test_images, val_labels, test_labels = train_test_split(
    temp_images, temp_labels, test_size=0.5, random_state=42, stratify=temp_labels
)
num_classes = 10

In [None]:
# One-hot encode the labels of all three splits with the same class count.
train_labels_categorical, val_labels_categorical, test_labels_categorical = (
    utils.to_categorical(split_labels, num_classes)
    for split_labels in (train_labels, val_labels, test_labels)
)

In [None]:
def cutmix_augmentation(x_batch, y_batch, alpha=1.0):
    """Apply CutMix to one batch of images and their label vectors.

    A single mixing ratio is drawn from Beta(alpha, alpha) and a single
    rectangular box is drawn per call; both are shared by the whole batch.
    Inside the box every image is overwritten with the pixels of a randomly
    chosen partner image from the same batch, and the labels are blended in
    proportion to the area that was actually replaced.

    Args:
        x_batch: image batch indexed as (batch, height, width, channels).
        y_batch: array of label vectors whose first axis matches ``x_batch``.
        alpha: concentration parameter of the Beta distribution.

    Returns:
        Tuple ``(x_cutmix, y_cutmix)``. ``x_cutmix`` has the shape and dtype
        of ``x_batch`` (uint8 in -> uint8 out, float in -> float out);
        ``y_cutmix`` is float32. The inputs are not modified.

    Raises:
        ValueError: if ``x_batch`` is not 4-dimensional, or if ``alpha`` is
            rejected by ``np.random.beta`` (e.g. ``alpha <= 0``).
    """
    x_batch = np.asarray(x_batch)
    y_batch = np.asarray(y_batch)

    # Unpacking four values enforces the (batch, h, w, channels) layout.
    batch_size, h, w, _ = x_batch.shape
    lambda_ = np.random.beta(alpha, alpha)

    # Box centre is uniform over the image; the box covers roughly a
    # (1 - lambda_) fraction of the image before clipping.
    rx = np.random.randint(w)
    ry = np.random.randint(h)
    rw = int(w * np.sqrt(1 - lambda_))
    rh = int(h * np.sqrt(1 - lambda_))

    x1, x2 = np.clip(rx - rw // 2, 0, w), np.clip(rx + rw // 2, 0, w)
    y1, y2 = np.clip(ry - rh // 2, 0, h), np.clip(ry + rh // 2, 0, h)

    # Partner image for every sample in the batch.
    index = np.random.permutation(batch_size)

    # Work on a copy so the caller's batch stays intact. The copy keeps the
    # input dtype: forcing uint8 here would silently truncate float images.
    x_cutmix = x_batch.copy()
    x_cutmix[:, y1:y2, x1:x2, :] = x_batch[index, y1:y2, x1:x2, :]

    # Clipping may have shrunk the box, so recompute the mixing ratio from
    # the area that was really pasted.
    lambda_adjusted = 1 - ((x2 - x1) * (y2 - y1)) / (w * h)
    y_cutmix = (lambda_adjusted * y_batch + (1 - lambda_adjusted) * y_batch[index]).astype(np.float32)

    return x_cutmix, y_cutmix

def balance_classes_with_cutmix(images, labels, target_samples_per_class=5000, batch_size=32, alpha=1.0,
                                num_classes=10):
    """Top up under-represented classes with intra-class CutMix samples.

    Samples are grouped by the ``argmax`` of their label vector. Every class
    holding fewer than ``target_samples_per_class`` samples is extended with
    CutMix images built from random batches (drawn with replacement) of that
    same class until it reaches the target. Classes already at or above the
    target are kept unchanged; they are not down-sampled. Because both CutMix
    partners come from the same class, one-hot labels stay one-hot.

    The output is ordered class by class (originals first, synthetic samples
    after them) and is not shuffled.

    Args:
        images: array of images, first axis is the sample axis. Cast to uint8.
        labels: integer class ids of shape (N,) or (N, 1), or an (N, C) array
            of label vectors.
        target_samples_per_class: minimum number of samples per class in the
            result.
        batch_size: number of images mixed per CutMix call; must be >= 1.
        alpha: forwarded to ``cutmix_augmentation``.
        num_classes: class count used when integer labels have to be one-hot
            encoded. Ignored when ``labels`` already is an (N, C) array.

    Returns:
        Tuple ``(balanced_images, balanced_labels)`` with dtypes uint8 and
        float32 respectively.

    Raises:
        ValueError: if ``batch_size < 1`` or if ``images`` and ``labels``
            disagree on the number of samples.
    """
    import warnings

    # A non-positive batch size would never add a sample -> endless loop.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    images = np.asarray(images).astype(np.uint8)
    labels = np.asarray(labels)

    if labels.ndim == 1 or labels.shape[1] == 1:
        labels = tf.keras.utils.to_categorical(labels, num_classes=num_classes).astype(np.float32)

    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same length, got {len(images)} and {len(labels)}"
        )

    n_classes = labels.shape[1]
    # Class id of every sample; computed once instead of once per class.
    class_ids = np.argmax(labels, axis=1)
    images_list, labels_list = [], []

    for cls in range(n_classes):
        cls_mask = class_ids == cls
        cls_images = images[cls_mask]
        cls_labels = labels[cls_mask]

        num_images_needed = target_samples_per_class - len(cls_images)

        if num_images_needed > 0 and len(cls_images) == 0:
            # Nothing to mix from: np.random.choice would raise on an empty
            # population, so leave the class empty and tell the caller.
            warnings.warn(
                f"class {cls} has no samples; it cannot be augmented and stays empty",
                RuntimeWarning,
                stacklevel=2,
            )
        elif num_images_needed > 0:
            image_batches, label_batches = [], []
            num_generated = 0

            while num_generated < num_images_needed:
                batch_indices = np.random.choice(len(cls_images), batch_size, replace=True)
                x_cutmix, y_cutmix = cutmix_augmentation(
                    cls_images[batch_indices], cls_labels[batch_indices], alpha=alpha
                )
                image_batches.append(x_cutmix)
                label_batches.append(y_cutmix)
                num_generated += len(x_cutmix)

            # The last batch usually overshoots; trim to the exact deficit.
            augmented_images = np.concatenate(image_batches, axis=0)[:num_images_needed].astype(np.uint8)
            augmented_labels = np.concatenate(label_batches, axis=0)[:num_images_needed].astype(np.float32)

            cls_images = np.concatenate([cls_images, augmented_images], axis=0)
            cls_labels = np.concatenate([cls_labels, augmented_labels], axis=0)

        images_list.append(cls_images)
        labels_list.append(cls_labels)

    balanced_images = np.concatenate(images_list, axis=0).astype(np.uint8)
    balanced_labels = np.concatenate(labels_list, axis=0).astype(np.float32)

    return balanced_images, balanced_labels

In [None]:
# The CutMix helpers draw from NumPy's global random generator. Seed it here
# so this cell produces the same synthetic training set on every run, in line
# with the fixed random_state used for the train/val/test splits.
np.random.seed(42)

# Only the training split is balanced; validation and test data stay untouched.
balanced_images, balanced_labels_categorical = balance_classes_with_cutmix(
    train_images,
    train_labels_categorical,
    target_samples_per_class=20000,
    batch_size=8
)

# Sanity check: every image must still have exactly one label row.
assert len(balanced_images) == len(balanced_labels_categorical)

In [None]:
balanced_labels_categorical = balanced_labels_categorical.astype(np.float32)
balanced_images = balanced_images.astype(np.float32)
val_labels_categorical = val_labels_categorical.astype(np.float32)
val_images = val_images.astype(np.float32)
test_labels_categorical = test_labels_categorical.astype(np.float32)
test_images = test_images.astype(np.float32)

In [None]:
train_images_preprocessed = preprocess_input(balanced_images)
val_images_preprocessed = preprocess_input(val_images)
test_images_preprocessed = preprocess_input(test_images)

In [None]:
base_model = VGG19(weights='imagenet', include_top=False, input_shape=(69, 69, 3))

In [None]:
x = GlobalAveragePooling2D()(base_model.output)

x = Dense(1024, activation='relu')(x)

x = Dropout(0.5)(x)

outputs = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=outputs)

In [None]:
for layer in base_model.layers:
    layer.trainable = False

In [None]:
model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1)
]

In [None]:
history = model.fit(train_images_preprocessed, balanced_labels_categorical, batch_size=32, epochs=20, validation_data=(val_images_preprocessed, val_labels_categorical), callbacks=callbacks)

In [None]:
for layer in base_model.layers:
    layer.trainable = True

In [None]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
history_finetune = model.fit(train_images_preprocessed, balanced_labels_categorical, batch_size=32, epochs=20, validation_data=(val_images_preprocessed, val_labels_categorical), callbacks=callbacks)

In [None]:
# Final check on the held-out test split; evaluate returns the loss followed
# by the compiled metrics (here: accuracy).
loss, accuracy = model.evaluate(
    test_images_preprocessed,
    test_labels_categorical,
)
print("Loss:", loss)
print("Accuracy:", accuracy)

In [None]:
# Per-class probabilities predicted for the test split.
y_pred_proba = model.predict(test_images_preprocessed)

# Hard class decisions and the matching ground-truth class ids.
y_pred = np.argmax(y_pred_proba, axis=1)
y_true = np.argmax(test_labels_categorical, axis=1)

# Binarize against every class the model outputs, not only the classes that
# happen to occur in the test labels: the indicator matrix then always has
# the same number of columns as y_pred_proba, and the notebook-wide
# `num_classes` setting is no longer overwritten from test data.
y_true_one_hot = label_binarize(y_true, classes=np.arange(y_pred_proba.shape[1]))

# Both scores are averaged over the per-class one-vs-rest problems.
roc_auc = roc_auc_score(y_true_one_hot, y_pred_proba, multi_class='ovr')
pr_auc = average_precision_score(y_true_one_hot, y_pred_proba)

print(f'ROC AUC: {roc_auc:.4f}')
print(f'PR AUC: {pr_auc:.4f}')