In [13]:
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers, optimizers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# The project package is imported relative to this notebook's location, so the
# path tweak must run before the project import below.
sys.path.append("../..")

from lunar_crater_age_logic.preprocess import load_data


In [14]:
# Configuration
# Dataset location. Set the LUNAR_CRATER_DATA_DIR environment variable to point
# the notebook at another checkout; the fallback is the path used originally.
DATA_DIR = Path(
    os.environ.get(
        "LUNAR_CRATER_DATA_DIR",
        "/home/santanu/code/VMontejo/lunar-crater-age-classifier/raw_data/train",
    )
)
IMG_HEIGHT = 227  # Changed to match preprocess.py
IMG_WIDTH = 227   # Changed to match preprocess.py
BATCH_SIZE = 32   # samples per batch requested from load_data
EPOCHS = 50       # upper bound on training epochs passed to model.fit
NUM_CLASSES = 3   # width of the softmax output layer

In [15]:
# Build the training and validation loaders with load_data from preprocess.py
print("Loading data from preprocess.py...")

# Arguments shared by both loaders: same directory, full (unbalanced) dataset.
_loader_kwargs = dict(data_dir=DATA_DIR, balanced=False, batch_size=BATCH_SIZE)

# Training loader: weighted sampling is switched on to counter class imbalance.
train_loader = load_data(use_weighted_sampling=True, seed=42, **_loader_kwargs)

# Validation loader: no weighted sampling and a different seed.
# NOTE(review): this reads the same DATA_DIR as the training loader, so unless
# load_data holds out a split internally the validation metrics are measured on
# training images -- confirm against preprocess.py.
val_loader = load_data(use_weighted_sampling=False, seed=123, **_loader_kwargs)

# Display names used for plots and reports.
# NOTE(review): assumed to be in the loader's integer-label order -- confirm.
class_names = ["ejecta", "oldcrater", "none"]
print(f"Class names: {class_names}")


Loading data from preprocess.py...
Creating FULL dataset (all available data)
ejecta: 358 images
oldcrater: 594 images
none: 2656 images
Total images 3608
Applying weighted sampling strategy
Original class distribution:
ejecta: 358 samples
oldcrater: 594 samples
none: 2656 samples
After weighted resampling
ejecta: 2656 samples (weight: 7.42)
oldcrater: 2656 samples (weight: 4.47)
none: 2656 samples (weight: 1.00)
Creating FULL dataset (all available data)
ejecta: 358 images
oldcrater: 594 images
none: 2656 images
Total images 3608
Using imbalanced data without weighting
Class names: ['ejecta', 'oldcrater', 'none']


### Base Model

In [None]:
def build_lroc_model(input_shape, num_classes):
    """Assemble the sequential CNN named "LROC_Custom_CNN_RGB".

    The network stacks five convolutional stages with 32, 64, 128, 256 and 512
    filters. Every stage is Conv2D (3x3, 'same' padding, He-normal init) ->
    BatchNormalization -> ReLU -> 2x2 MaxPooling. Global average pooling then
    feeds a 256-unit dense layer (L2 penalty 0.001) with ReLU and 50% dropout,
    followed by a softmax output layer.

    Args:
        input_shape: Shape of a single input sample, given to the InputLayer.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The uncompiled Sequential model.
    """
    # Filter count of each stage, shallow to deep. The original author's
    # intent per stage: edges & lines, simple shapes, complex texture,
    # deeper features (rays/ejecta), deeper crater textures.
    stage_filters = (32, 64, 128, 256, 512)

    net = models.Sequential(name="LROC_Custom_CNN_RGB")
    net.add(layers.InputLayer(shape=input_shape))

    # --- Feature extractor: identical stage layout, growing filter count ---
    for n_filters in stage_filters:
        net.add(
            layers.Conv2D(
                n_filters,
                (3, 3),
                padding='same',
                kernel_initializer='he_normal',
            )
        )
        net.add(layers.BatchNormalization())
        net.add(layers.ReLU())
        net.add(layers.MaxPooling2D((2, 2)))

    # --- Classification head ---
    net.add(layers.GlobalAveragePooling2D())
    net.add(layers.Dense(256, kernel_regularizer=regularizers.l2(0.001)))
    net.add(layers.ReLU())
    net.add(layers.Dropout(0.5))

    # Softmax output: one probability per class.
    net.add(layers.Dense(num_classes, activation='softmax'))

    return net

In [7]:

# Build model with correct input shape: 3-channel images of
# IMG_HEIGHT x IMG_WIDTH, one softmax unit per class. Then print the
# layer-by-layer summary.
model = build_lroc_model(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=NUM_CLASSES)
model.summary()

2025-12-11 08:18:24.736405: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [8]:
def sparse_softmax_focal_loss(gamma=2.0, alpha=(1.8315, 0.6731, 1.4221)):
    """Build a focal loss for multi-class classification with sparse labels.

    For every sample the loss is ``alpha_t * (1 - p_t) ** gamma * -log(p_t)``,
    where ``p_t`` is the predicted probability of the true class and
    ``alpha_t`` is the weight of that class.

    Args:
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.
        alpha: Sequence of per-class weights indexed by integer class id. Its
            length must equal the size of the last axis of ``y_pred``.
            NOTE(review): the default values are presumably tuned for this
            project's three classes in label-id order -- confirm the order
            against the loader's label encoding.

    Returns:
        A function ``loss_fn(y_true, y_pred)`` that returns one loss value per
        sample. ``y_true`` holds integer class ids, shaped like ``y_pred``
        without its last axis (a trailing axis of size 1 is also accepted).
        ``y_pred`` holds class probabilities on its last axis.
    """
    # Built once here rather than on every call of loss_fn. The default is a
    # tuple (not a list) so it cannot be mutated between calls.
    alpha = tf.constant(alpha, dtype=tf.float32)

    def loss_fn(y_true, y_pred):
        # Work in float32 so the predictions match the float32 alpha constant.
        y_pred = tf.cast(y_pred, tf.float32)

        # Labels can arrive as (batch,) or (batch, 1). Reshape them to the
        # leading dimensions of y_pred: a (batch, 1) label tensor would
        # otherwise one-hot to (batch, 1, C) and silently broadcast against
        # (batch, C) predictions into a (batch, batch, C) tensor.
        y_true = tf.cast(y_true, tf.int32)
        y_true = tf.reshape(y_true, tf.shape(y_pred)[:-1])

        # One-hot encode the class ids along a new last axis.
        num_classes = y_pred.shape[-1]
        y_true_onehot = tf.one_hot(y_true, depth=num_classes)

        # Numerical stability: keep probabilities away from 0 and 1 before log.
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)

        # Cross-entropy of the true class: -log(p_t).
        ce = -tf.reduce_sum(y_true_onehot * tf.math.log(y_pred), axis=-1)

        # p_t: predicted probability of the true class.
        p_t = tf.reduce_sum(y_true_onehot * y_pred, axis=-1)

        # Focal modulation: shrinks the loss as p_t approaches 1.
        modulating_factor = tf.pow((1 - p_t), gamma)

        # Per-sample class weight selected by the true label.
        alpha_t = tf.reduce_sum(y_true_onehot * alpha, axis=-1)

        return alpha_t * modulating_factor * ce

    return loss_fn

In [9]:
# 1. Compile: Adam with a small learning rate and global-norm gradient
#    clipping, the focal loss defined above, and accuracy as the metric.
focal_loss = sparse_softmax_focal_loss(gamma=2.0, alpha=[1.8315, 0.6731, 1.4221])
optimizer = optimizers.Adam(learning_rate=0.0001, global_clipnorm=1.0)

model.compile(optimizer=optimizer, loss=focal_loss, metrics=['accuracy'])

# 2. Callbacks
# Stop training once validation loss has not improved for 5 epochs and roll
# the model back to its best weights.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)

# Keep only the best model on disk, in the native Keras format.
# NOTE(review): this monitors val_accuracy while the other two callbacks
# monitor val_loss, so the saved file and the restored weights can come from
# different epochs -- confirm this is intended.
best_checkpoint = ModelCheckpoint(
    'best_lroc_model.keras',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)

# Halve the learning rate after 2 epochs without val_loss improvement,
# never going below 1e-6.
lr_on_plateau = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
    verbose=1,
)

callbacks = [early_stopping, best_checkpoint, lr_on_plateau]

print("✅ Model compiled. Callbacks configured to save as .keras")

✅ Model compiled. Callbacks configured to save as .keras


In [None]:
# Train model
# Fits on train_loader for at most EPOCHS epochs, evaluates on val_loader, and
# applies the callbacks list (early stopping, checkpointing, LR reduction).
# The returned History object is kept for plotting.
print("Starting training...")
history = model.fit(
    train_loader,
    epochs=EPOCHS,
    validation_data=val_loader,
    callbacks=callbacks,
    verbose=1
)

print("✅ Training Complete.")

In [None]:
def plot_history(history):
    """Plot training (and, when recorded, validation) accuracy and loss curves.

    Draws two side-by-side panels: accuracy on the left, loss on the right,
    each against the epoch index, and shows the figure.

    Args:
        history: Object whose ``history`` attribute is a dict holding the
            per-epoch lists 'accuracy' and 'loss', and optionally
            'val_accuracy' and 'val_loss' (as returned by ``model.fit``).

    Raises:
        KeyError: If 'accuracy' or 'loss' is missing from ``history.history``.
    """
    recorded = history.history

    # (metric key, display label, legend position) for each panel.
    panels = (
        ('accuracy', 'Accuracy', 'lower right'),
        ('loss', 'Loss', 'upper right'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    for ax, (key, label, legend_loc) in zip(axes, panels):
        train_values = recorded[key]
        ax.plot(range(len(train_values)), train_values, label=f'Training {label}')

        # Validation curves exist only when fit() was given validation data;
        # skip them instead of failing with a KeyError.
        val_values = recorded.get(f'val_{key}')
        if val_values is not None:
            ax.plot(range(len(val_values)), val_values, label=f'Validation {label}')

        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)
        ax.set_title(f'Training and Validation {label}')
        ax.legend(loc=legend_loc)

    plt.show()

# Draw the accuracy and loss curves from the History object returned by model.fit.
plot_history(history)

In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

print("Generating Confusion Matrix...")

# Rebuild the evaluation loader with the same settings as the validation
# loader used during training (no weighted sampling, seed 123).
# NOTE(review): this reads the same DATA_DIR the model was trained on, so
# unless load_data holds out a split internally these numbers describe
# training data -- confirm against preprocess.py. The original cell carried a
# dangling "VERY IMPORTANT" comment with no argument attached; check whether an
# option (e.g. disabling shuffling) was removed by accident.
val_loader = load_data(
    data_dir=DATA_DIR,
    balanced=False,
    batch_size=BATCH_SIZE,
    use_weighted_sampling=False,
    seed=123,
)

# Predict batch by batch so every prediction stays paired with the label from
# the same batch, whatever order the loader yields batches in.
pred_batches = []
label_batches = []

for batch_images, batch_labels in val_loader:
    batch_probs = model.predict(batch_images, verbose=0)
    pred_batches.append(np.argmax(batch_probs, axis=1))
    # np.asarray accepts numpy arrays and eager tensors alike; reshape(-1)
    # also flattens labels delivered as (batch, 1).
    label_batches.append(np.asarray(batch_labels).reshape(-1))

if not pred_batches:
    raise ValueError("val_loader yielded no batches; nothing to evaluate")

all_preds = np.concatenate(pred_batches)
all_labels = np.concatenate(label_batches)

# Pin the class ids explicitly so the matrix always has one row/column per
# entry of class_names, even if a class never appears in labels or predictions.
# Assumes integer labels 0..len(class_names)-1 in class_names order, as the
# tick labels below already do.
class_ids = list(range(len(class_names)))

# Compute confusion matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

# Plot
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.title("Confusion Matrix")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.show()

print("\nClassification Report:")
print(classification_report(all_labels, all_preds, labels=class_ids, target_names=class_names))


In [16]:
# Load data using your preprocess.py function
print("Loading data from preprocess.py...")

# Training loader: full dataset, weighted sampling to handle class imbalance,
# with use_zscore=True (presumably z-score input normalisation -- confirm in
# preprocess.py).
train_loader = load_data(
    data_dir=DATA_DIR,
    balanced=False,  # Use full dataset
    batch_size=BATCH_SIZE,
    use_weighted_sampling=True,  # Handles class imbalance
    use_zscore=True,
    seed=42
)

# Validation loader: no weighted sampling, different seed. It must use the
# same use_zscore setting as the training loader; otherwise the model is
# validated on inputs preprocessed differently from the ones it was trained on.
# NOTE(review): this still reads the same DATA_DIR as the training loader --
# confirm that load_data keeps the two sets disjoint.
val_loader = load_data(
    data_dir=DATA_DIR,
    balanced=False,
    batch_size=BATCH_SIZE,
    use_weighted_sampling=False,
    use_zscore=True,
    seed=123  # Different seed for validation
)

class_names = ["ejecta", "oldcrater", "none"]
print(f"Class names: {class_names}")


Loading data from preprocess.py...
Creating FULL dataset (all available data)
ejecta: 358 images
oldcrater: 594 images
none: 2656 images
Total images 3608
Applying weighted sampling strategy
Original class distribution:
ejecta: 358 samples
oldcrater: 594 samples
none: 2656 samples
After weighted resampling
ejecta: 2656 samples (weight: 7.42)
oldcrater: 2656 samples (weight: 4.47)
none: 2656 samples (weight: 1.00)
Creating FULL dataset (all available data)
ejecta: 358 images
oldcrater: 594 images
none: 2656 images
Total images 3608
Using imbalanced data without weighting
Class names: ['ejecta', 'oldcrater', 'none']
