In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [None]:
# Mounting Drive
# The datasets are read from, and the trained model is written to, /content/drive/MyDrive.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

In [None]:
# Loading Training Dataset
# Input images and their per-pixel labels (presumably integer class ids).
images = np.load("/content/drive/MyDrive/galaxy_galaxy_train_images.npy")
labels = np.load("/content/drive/MyDrive/galaxy_galaxy_train_labels.npy")
# Append a single channel axis for Conv2D. Assumes the images are stored as
# (N, H, W) (or already (N, H, W, 1)) — TODO confirm against the .npy files.
images = images.reshape(images.shape[0], images.shape[1], images.shape[2], 1)

In [None]:
# Train / Validation Split
# Hold out 10% of the training file for validation; random_state=42 makes the split reproducible.
split_arrays = train_test_split(images, labels, test_size=0.1, random_state=42)
X_train, X_val, y_train, y_val = split_arrays

In [None]:
# Reshaping Images and Labels for Preprocessing
# ImageDataGenerator.flow expects rank-4 input, so every split gets an explicit
# trailing channel axis (a no-op for the images, which were already reshaped on load).
def _with_channel_axis(arr):
    """Return ``arr`` reshaped to (arr.shape[0], arr.shape[1], arr.shape[2], 1)."""
    return arr.reshape((arr.shape[0], arr.shape[1], arr.shape[2], 1))


X_train, X_val, y_train, y_val = [
    _with_channel_axis(split) for split in (X_train, X_val, y_train, y_val)
]

In [None]:
# Preprocessing
# Zoom augmentation is for training only. Validation batches must stay
# un-augmented so that val_loss (monitored by EarlyStopping / ReduceLROnPlateau)
# is a stable estimate rather than a random-zoom sample.
train_datagen = ImageDataGenerator(zoom_range=0.5)
val_datagen = ImageDataGenerator()

BATCH_SIZE = 32


def dual_image_generator(images, labels, batch_size=32, datagen=None, num_classes=4):
    """Endlessly yield (image_batch, label_batch) pairs with matching augmentation.

    Args:
        images: rank-4 image array, (N, H, W, 1).
        labels: rank-4 label array, (N, H, W, 1), holding numeric class ids.
        batch_size: samples per batch.
        datagen: ImageDataGenerator applied to both images and labels. Defaults to
            ``train_datagen`` (zoom augmentation); pass ``val_datagen`` for
            validation so no augmentation is applied.
        num_classes: number of segmentation classes. Labels are clipped to
            ``[0, num_classes - 1]``; use 5 for galaxy-quasar lenses.

    Yields:
        img_batch: float image batch, (B, H, W, 1).
        lbl_batch: integer label batch, (B, H, W).
    """
    if datagen is None:
        datagen = train_datagen
    # Identical seeds keep shuffling and random zoom of the two iterators in
    # lockstep (the usual Keras pattern for paired image/mask augmentation).
    image_gen = datagen.flow(images, batch_size=batch_size, seed=42)
    label_gen = datagen.flow(labels, batch_size=batch_size, seed=42)

    while True:
        img_batch = next(image_gen)
        lbl_batch = next(label_gen)
        # Drop the channel axis, then snap possibly interpolated label values
        # back to valid integer class ids.
        lbl_batch = lbl_batch.squeeze(-1)
        lbl_batch = np.round(lbl_batch).astype(int)
        lbl_batch = np.clip(lbl_batch, 0, num_classes - 1)
        yield img_batch, lbl_batch


train_generator = dual_image_generator(X_train, y_train, batch_size=BATCH_SIZE)
val_generator = dual_image_generator(
    X_val, y_val, batch_size=BATCH_SIZE, datagen=val_datagen
)

# At least one step each, so a split smaller than one batch still trains/validates.
steps_per_epoch = max(1, len(X_train) // BATCH_SIZE)
validation_steps = max(1, len(X_val) // BATCH_SIZE)

In [None]:
# Attention Block
def attention_block(x, g, inter_channels):
    """Additive attention gate: reweight skip features ``x`` using gating tensor ``g``.

    ``x`` and ``g`` are each projected to ``inter_channels`` with a 1x1 convolution.
    The projections are summed and passed through ReLU, then squeezed by a sigmoid
    1x1 convolution into a single-channel map that multiplies ``x``. The addition
    requires ``x`` and ``g`` to share spatial dimensions.
    """
    x_proj = layers.Conv2D(inter_channels, (1, 1), padding="same")(x)
    g_proj = layers.Conv2D(inter_channels, (1, 1), padding="same")(g)
    gate = layers.add([x_proj, g_proj])
    gate = layers.Activation("relu")(gate)
    attention_map = layers.Conv2D(1, (1, 1), activation="sigmoid", padding="same")(gate)
    return layers.multiply([x, attention_map])

In [None]:
# Dice Coefficient and Dice Loss
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft multi-class Dice score, averaged over classes.

    ``y_true`` holds integer class ids and is one-hot encoded to the channel count
    of ``y_pred`` (its last axis). Intersections and sums are taken over axes
    0, 1 and 2 (batch and spatial axes for a rank-4 ``y_pred``), giving one Dice
    value per class. ``smooth`` keeps empty classes finite.
    """
    n_classes = tf.shape(y_pred)[-1]
    onehot = tf.cast(tf.one_hot(tf.cast(y_true, tf.int32), depth=n_classes), tf.float32)
    reduce_axes = [0, 1, 2]
    overlap = tf.reduce_sum(onehot * y_pred, axis=reduce_axes)
    total = tf.reduce_sum(onehot + y_pred, axis=reduce_axes)
    per_class = (2.0 * overlap + smooth) / (total + smooth)
    return tf.reduce_mean(per_class)


def dice_loss(y_true, y_pred, smooth=1e-6):
    """Dice loss, ``1 - dice_coefficient(y_true, y_pred, smooth)``.

    ``smooth`` is forwarded to ``dice_coefficient``; previously it was accepted
    but ignored. With the default value the result is unchanged.
    """
    return 1 - dice_coefficient(y_true, y_pred, smooth=smooth)

In [None]:
# Focal Loss
def focal_loss(gamma=2.0, alpha=0.25, smooth_eps=1e-6):
    """Build a multi-class focal loss ``loss_fn(y_true, y_pred)``.

    Args:
        gamma: focusing exponent applied to ``(1 - p_true)``. Larger values
            down-weight pixels the model already classifies confidently.
        alpha: loss weight. A scalar scales every pixel uniformly (original
            behaviour). A sequence of length num_classes gives one weight per
            class, looked up by each pixel's true class (e.g. to down-weight a
            dominant background class).
        smooth_eps: ``y_pred`` is clipped to ``[smooth_eps, 1 - smooth_eps]`` so
            the log stays finite.

    Returns:
        A function taking integer class ids ``y_true`` and class probabilities
        ``y_pred`` (classes on the last axis), returning the mean loss over all
        pixels.
    """
    # Decided once, outside the traced loss: scalar -> uniform scale,
    # sequence -> per-class weights.
    per_class_alpha = np.ndim(alpha) > 0

    def loss_fn(y_true, y_pred):
        y_pred = tf.clip_by_value(y_pred, smooth_eps, 1 - smooth_eps)
        y_true = tf.cast(y_true, tf.int32)
        true_class_probs = tf.reduce_sum(
            y_pred * tf.one_hot(y_true, depth=tf.shape(y_pred)[-1]), axis=-1
        )
        ce_loss = -tf.math.log(true_class_probs)
        modulating_factor = tf.pow(1 - true_class_probs, gamma)
        if per_class_alpha:
            # Weight of each pixel's true class; same shape as true_class_probs.
            alpha_t = tf.gather(tf.constant(alpha, dtype=y_pred.dtype), y_true)
        else:
            alpha_t = alpha
        pixel_loss = alpha_t * modulating_factor * ce_loss
        return tf.reduce_mean(pixel_loss)

    return loss_fn

In [None]:
# Combined Loss
def combined_loss(y_true, y_pred):
    """Training objective: Dice loss (region overlap) plus a focal loss built with
    ``focal_loss``'s default parameters (per-pixel term)."""
    region_term = dice_loss(y_true, y_pred)
    pixel_term = focal_loss()(y_true, y_pred)
    return region_term + pixel_term

In [None]:
# Network Architecture
def _unet_conv_block(x, filters, dropout_rate):
    """One encoder level: 3x3 ReLU conv, dropout, 3x3 ReLU conv."""
    x = layers.Conv2D(filters, (3, 3), activation="relu", padding="same")(x)
    x = layers.Dropout(dropout_rate)(x)
    return layers.Conv2D(filters, (3, 3), activation="relu", padding="same")(x)


def _unet_up_block(x, skip, filters):
    """One decoder level: 2x transposed-conv upsample, attention-gate ``skip``,
    concatenate, then a single 3x3 ReLU conv."""
    up = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(x)
    gated_skip = attention_block(skip, up, inter_channels=filters)
    merged = layers.concatenate([up, gated_skip])
    return layers.Conv2D(filters, (3, 3), activation="relu", padding="same")(merged)


def attention_unet(input_shape, num_classes=4):
    """Build and compile a four-level Attention U-Net for multi-class segmentation.

    Args:
        input_shape: (H, W, C) of the input images. Known H and W must be
            divisible by 16 because of the four 2x2 poolings; ``None`` dims are allowed.
        num_classes: channels of the per-pixel softmax output. Defaults to 4
            (galaxy-galaxy lenses); use 5 for galaxy-quasar lenses.

    Returns:
        A ``tf.keras`` Model producing (H, W, num_classes) class probabilities,
        compiled with Adam, ``combined_loss`` and the ``dice_coefficient`` metric.

    Raises:
        ValueError: if a known spatial dimension is not divisible by 16. Otherwise
            decoder and skip shapes would mismatch at concatenation.
    """
    for dim in (input_shape[0], input_shape[1]):
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                "Spatial input dimensions must be divisible by 16 "
                f"(four 2x2 poolings); got input_shape={input_shape}"
            )

    inputs = tf.keras.Input(input_shape)

    # Encoder: filters double and dropout grows with depth; each level is pooled 2x.
    c1 = _unet_conv_block(inputs, 64, 0.1)
    p1 = layers.MaxPooling2D((2, 2))(c1)
    c2 = _unet_conv_block(p1, 128, 0.1)
    p2 = layers.MaxPooling2D((2, 2))(c2)
    c3 = _unet_conv_block(p2, 256, 0.2)
    p3 = layers.MaxPooling2D((2, 2))(c3)
    c4 = _unet_conv_block(p3, 512, 0.2)
    p4 = layers.MaxPooling2D((2, 2))(c4)

    # Bottleneck.
    c5 = _unet_conv_block(p4, 1024, 0.3)

    # Decoder: each level uses the attention-gated skip from the matching encoder level.
    c6 = _unet_up_block(c5, c4, 512)
    c7 = _unet_up_block(c6, c3, 256)
    c8 = _unet_up_block(c7, c2, 128)
    c9 = _unet_up_block(c8, c1, 64)

    outputs = layers.Conv2D(num_classes, (1, 1), activation="softmax")(c9)

    model = models.Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer="adam", loss=combined_loss, metrics=[dice_coefficient])

    return model

In [None]:
# Training
# Seed TensorFlow's global RNG so weight initialisation and dropout are repeatable
# across Restart & Run All (the split and the data generators already use seed 42).
# GPU kernels may still introduce small run-to-run differences.
tf.random.set_seed(42)

# Stop after 10 epochs without val_loss improvement (restoring the best weights),
# and halve the learning rate after 5 stagnant epochs, down to 1e-6.
early_stop = EarlyStopping(
    monitor="val_loss", patience=10, restore_best_weights=True, verbose=1
)
lr_schedule = ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=5, min_lr=1e-6, verbose=1
)

# Build the network for the actual (H, W, C) of the training images instead of a
# hard-coded (128, 128, 1), so the model always matches the loaded data.
input_shape = tuple(X_train.shape[1:])
model = attention_unet(input_shape=input_shape)
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=50,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=[early_stop, lr_schedule],
    verbose=1,
)

In [None]:
# Saving The Model
# HDF5 (.h5) file on Drive. The model was compiled with the custom combined_loss
# and dice_coefficient, so reloading it with compilation presumably needs them
# passed via custom_objects to tf.keras.models.load_model.
MODEL_PATH = "/content/drive/MyDrive/lensed_galaxy_segmentation_model.h5"
model.save(MODEL_PATH)

In [None]:
# Test Set
# Held-out images and labels, used only for the final evaluation below.
X_test = np.load("/content/drive/MyDrive/galaxy_galaxy_test_images.npy")
y_test = np.load("/content/drive/MyDrive/galaxy_galaxy_test_labels.npy")
# Add the channel axis the model expects; assumes images are (N, H, W) — TODO confirm.
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], X_test.shape[2], 1)

In [None]:
# Prediction
# Softmax class probabilities for every test pixel (classes on the last axis).
predictions = model.predict(X_test, batch_size=32)

In [None]:
# Performance Metrics
# Pixel-level metrics over the whole test set. `labels` pins the class list to the
# model's output channels, so each per-class array has exactly one entry per class
# id, in order, even when a class never occurs in the test labels or predictions.
# zero_division=0 yields the same 0 score as the default, without the warning.
y_true = y_test.flatten()
y_pred = np.argmax(predictions, axis=-1).flatten()
class_ids = np.arange(predictions.shape[-1])

accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(
    y_true, y_pred, labels=class_ids, average=None, zero_division=0
)
recall = recall_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)
f1 = f1_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")

In [None]:
# IoU and Dice Score
# Number of segmentation classes, read from the width of the model's softmax output
# so it always matches the trained network (4 for galaxy-galaxy, 5 for
# galaxy-quasar lenses) instead of being edited by hand.
num_classes = predictions.shape[-1]


def iou_score(y_true, y_pred, num_classes):
    """
    Per-class and mean Intersection over Union for integer label arrays.

    For each class id c in range(num_classes):
        IoU_c = |(y_true == c) AND (y_pred == c)| / |(y_true == c) OR (y_pred == c)|
    A class absent from both arrays gets NaN and is skipped by the mean (np.nanmean).

    Returns:
        (mean_iou, per_class_ious), where per_class_ious has num_classes entries.
    """
    per_class = []
    for class_id in range(num_classes):
        in_true = y_true == class_id
        in_pred = y_pred == class_id
        overlap = np.logical_and(in_true, in_pred).sum()
        covered = np.logical_or(in_true, in_pred).sum()
        per_class.append(float("nan") if covered == 0 else overlap / covered)

    return np.nanmean(per_class), per_class


def dice_score(y_true, y_pred, num_classes):
    """
    Per-class and mean Dice coefficient for integer label arrays.

    For each class id c in range(num_classes):
        Dice_c = 2 * |(y_true == c) AND (y_pred == c)| / (|y_true == c| + |y_pred == c|)
    A class absent from both arrays gets NaN and is skipped by the mean (np.nanmean).

    Returns:
        (mean_dice, per_class_dice), where per_class_dice has num_classes entries.
    """
    scores = []
    for class_id in range(num_classes):
        in_true = y_true == class_id
        in_pred = y_pred == class_id
        doubled_overlap = 2 * np.logical_and(in_true, in_pred).sum()
        total = in_true.sum() + in_pred.sum()
        scores.append(float("nan") if total == 0 else doubled_overlap / total)

    return np.nanmean(scores), scores


def pixel_accuracy(y_true, y_pred):
    """
    Fraction of positions where the predicted label equals the true label.
    """
    agreement = y_true == y_pred
    return np.mean(agreement)


# Segmentation metrics on the flattened test labels/predictions from the cell above.
mean_iou, per_class_iou = iou_score(y_true, y_pred, num_classes)
mean_dice, per_class_dice = dice_score(y_true, y_pred, num_classes)
pixel_acc = pixel_accuracy(y_true, y_pred)

segmentation_report = [
    f"Mean IoU: {mean_iou:.4f}",
    f"Per-Class IoU: {per_class_iou}",
    f"Mean Dice Score: {mean_dice:.4f}",
    f"Per-Class Dice Score: {per_class_dice}",
    f"Pixel Accuracy: {pixel_acc:.4f}",
]
print("\n".join(segmentation_report))

In [None]:
# Release the Colab runtime once everything above has finished.
# NOTE(review): per Colab's documentation, runtime.unassign() disconnects and frees
# the assigned runtime, so all in-memory state (model, arrays) is presumably lost;
# only files already written to Drive (e.g. the saved .h5 model) survive. Keep this
# as the very last cell.
from google.colab import runtime

runtime.unassign()