In [4]:
import tensorflow as tf
import shap
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Make the project package importable and pick up the latest preprocess.py
# without restarting the kernel.
import sys

# Path is relative to the notebook's working directory.
# Guarded so that re-running this cell does not keep appending duplicates.
_repo_root = "../.."
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

# Force reload the module to get latest changes, then re-bind load_data so the
# name refers to the freshly reloaded function.
import importlib
import lunar_crater_age_logic.preprocess as preprocess_module
importlib.reload(preprocess_module)
from lunar_crater_age_logic.preprocess import load_data

from pathlib import Path


In [7]:
# --- 1. SETUP ---
# Location of the saved Keras model to explain.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
model_path = '/home/santanu/code/VMontejo/lunar-crater-age-classifier/models/6. custom-model-new-preprocess-old-model-2-one-hot-no batch-1_AdamW_CosinLr.keras'

# Deserialize the trained model from disk.
model = tf.keras.models.load_model(model_path)
print("Model loaded.")

Model loaded.


In [9]:
# Configuration
DATA_DIR = Path("/home/santanu/code/VMontejo/lunar-crater-age-classifier/raw_data/train")

# Arguments forwarded to load_data() from preprocess.py.
load_data_kwargs = {
    "data_dir": DATA_DIR.parent,  # load_data is given the folder that contains train/
    "model_type": "custom",
    "normalization": "zscore",
    "batch_size": 32,
    "seed": 42,
    "augment_train": True,  # Enable TensorFlow augmentation (rotation, flip, brightness, contrast, zoom)
    "train_balanced": True,
    "train_weighted_sampling": False,
}

print("Loading data from preprocess.py...")
# load_data returns the three datasets followed by their image counts.
(
    train_ds,
    val_ds,
    test_ds,
    train_count,
    val_count,
    test_count,
) = load_data(**load_data_kwargs)

# NOTE(review): assumed to follow the label index order used by load_data — confirm.
class_names = ["ejecta", "oldcrater", "none"]
print(f"Class names: {class_names}")

Loading data from preprocess.py...
Loading data for custom
Normalization: zscore
Batch size: 32
Training: TensorFlow augmentation ENABLED (rotation, flip, brightness, contrast, zoom)
Training: BALANCED (358 per class)
Balanced train: 1074 images (358 per class)

Data loaded:
Training: 1074 images (33 batches)
Validation: 613 images (19 batches)
Test: 779 images (24 batches)
Class names: ['ejecta', 'oldcrater', 'none']


In [11]:
# 2 ─ Background sample for SHAP (20 training images)
# Unbatch so the dataset yields single (image, label) pairs instead of batches.
train_unbatched = train_ds.unbatch()

# Keep only the image of each pair, converted to a NumPy array.
background_samples = [img.numpy() for img, _ in train_unbatched.take(20)]

# Stack the individual images along a new leading axis.
background = np.stack(background_samples)

2025-12-15 11:39:12.766556: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [14]:
# 3 ─ Images to explain (5 test images)
# Unbatch so the dataset yields single (image, label) pairs instead of batches.
test_unbatched = test_ds.unbatch()

# Labels are not needed here; only the image arrays are kept.
test_samples = [img.numpy() for img, _ in test_unbatched.take(5)]

# Stack the individual images along a new leading axis.
test_images = np.stack(test_samples)

2025-12-15 11:41:23.743665: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [15]:
# 4 ─ Build the SHAP gradient explainer
# `background` is the reference data the attributions are computed against.
explainer = shap.GradientExplainer(model=model, data=background)

In [None]:
# 5 ─ Compute SHAP values
# GradientExplainer draws random background samples and interpolation points,
# so a fixed seed is passed to make the attributions reproducible across runs.
print("Computing SHAP values...")
shap_values = explainer.shap_values(test_images, rseed=42)