In [None]:
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras import mixed_precision
from tensorflow.keras.applications.efficientnet import EfficientNetB3, preprocess_input
from tensorflow.keras.applications.efficientnet import EfficientNetB4
import matplotlib.pyplot as plt
import h5py
import gc
from tensorflow import keras as K
from tensorflow.keras import layers as L
import datetime



# Process-wide performance switches; run this cell before any model or
# dataset is created so that everything built later picks them up.
# Allow TensorFloat-32 math for float32 matmuls/convolutions on GPUs that support it.
tf.config.experimental.enable_tensor_float_32_execution(True)

# Turn on XLA JIT compilation through the TensorFlow graph optimizer.
tf.config.optimizer.set_jit(True)
# Keras layers created from here on use the 'mixed_float16' dtype policy
# (float16 compute, float32 variables).
mixed_precision.set_global_policy('mixed_float16')

In [None]:
#View dataset distribution

H5_PATH = "images/Galaxy10_DECals.h5"
LABEL_KEY = "ans"


def main():
    """Report the class balance of the label dataset stored in ``H5_PATH``.

    Only ``LABEL_KEY`` is read from the HDF5 file. The function prints the
    sample count and percentage of every class and then draws a bar chart of
    the counts.

    Raises:
        ValueError: if the label array is neither 1-D (integer class ids) nor
            2-D (one-hot vectors / per-class scores).
    """
    with h5py.File(H5_PATH, "r") as f:
        y = f[LABEL_KEY][:]  # labels only; the images are not read
    y = np.asarray(y)

    # Reject unsupported layouts before doing any work.
    if y.ndim not in (1, 2):
        raise ValueError(f"Unexpected label shape: {y.shape}")

    if y.ndim == 1:
        # Integer class ids, shape (N,): the class count is inferred from the max id.
        y_ids = y.astype(np.int64)
        num_classes = int(y_ids.max()) + 1
    else:
        # One-hot / score rows, shape (N, C): collapse each row to its arg-max.
        y_ids = np.argmax(y, axis=1).astype(np.int64)
        num_classes = y.shape[1]

    counts = np.bincount(y_ids, minlength=num_classes)
    total = counts.sum()

    print(f"Total samples: {total}\n")
    print("Class distribution:")
    for c, n in enumerate(counts):
        if total:
            pct = n / total * 100.0
        else:
            pct = 0.0
        print(f"  Class {c:>2}: {n:>7}  ({pct:>6.2f}%)")

    try:
        import matplotlib.pyplot as plt

        class_positions = np.arange(num_classes)
        plt.figure()
        plt.bar(class_positions, counts)
        plt.title("Galaxy10_DECals class distribution")
        plt.xlabel("Class")
        plt.ylabel("Count")
        plt.xticks(class_positions)
        plt.tight_layout()
        plt.show()
    except ImportError:
        # Plotting is best-effort: on ImportError only the text report is shown.
        pass


if __name__ == "__main__":
    main()


In [None]:
#View a small sample

import numpy as np
import matplotlib.pyplot as plt

#load data
import h5py
# Read the complete "images" and "ans" datasets from the HDF5 file into memory.
with h5py.File("images/Galaxy10_DECals.h5", "r") as f:
    X = f["images"][:]
    y = f["ans"][:]

# Display names indexed by class id: show_random_grid looks a label up by
# using its integer value as the position in this list.
label_names = [
    "Disturbed", "Merging", "Round Smooth", "In-between Round Smooth", "Cigar",
    "Barred Spiral", "Tight Spiral", "Loose Spiral", "Edge-on (no bulge)", "Edge-on (with bulge)"
]

def show_random_grid(X, y, nrows=3, ncols=5, seed=None):
    """Plot an ``nrows`` x ``ncols`` grid of randomly chosen images.

    Each panel is titled ``"<id>: <name>"`` using ``label_names``; a label
    outside the range of ``label_names`` is shown as its bare integer.

    Args:
        X: indexable collection of images that ``Axes.imshow`` can draw.
        y: labels aligned with ``X``; each entry must be convertible by ``int``.
        nrows: number of grid rows.
        ncols: number of grid columns.
        seed: seed passed to ``np.random.default_rng``.

    Raises:
        ValueError: if ``X`` holds fewer than ``nrows * ncols`` images.
    """
    n = nrows * ncols
    if len(X) < n:
        raise ValueError(f"Need at least {n} images, but got {len(X)}")

    # Draw n distinct sample positions.
    picks = np.random.default_rng(seed).choice(len(X), size=n, replace=False)

    fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows))
    # Flatten so a single Axes and a 2-D Axes grid are handled the same way.
    flat_axes = np.asarray(axes).ravel()

    for ax, i in zip(flat_axes, picks):
        ax.imshow(X[i])
        label = int(y[i])
        if 0 <= label < len(label_names):
            title = f"{label}: {label_names[label]}"
        else:
            title = str(label)
        ax.set_title(title, fontsize=10)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


show_random_grid(X, y, nrows=3, ncols=5, seed=42)



In [None]:
#Load Data

# Read every image and label into memory; the tf.data dataset below is built
# directly from these in-RAM arrays.
with h5py.File("images/Galaxy10_DECals.h5", "r") as f:
    X = f["images"][:]   # uint8 array in RAM (~3.3GB)
    y = f["ans"][:]      # labels

# One dataset element per (image, label) pair.
ds = tf.data.Dataset.from_tensor_slices((X,y))

ds_size = len(ds)
# Single fixed-seed shuffle over the whole dataset. reshuffle_each_iteration=False
# keeps the order identical on every pass, which the take()/skip() split in the
# following cell relies on to keep train/val/test disjoint.
ds = ds.shuffle(buffer_size = ds_size, seed=42, reshuffle_each_iteration=False)

#Splitting dataset
# NOTE(review): the same two sizes are recomputed at the top of the next cell.
train_size = int(0.8 * ds_size)
val_size   = int(0.1 * ds_size)

In [None]:
# Split the shuffled dataset 80% / 10% / remainder into train / val / test.
train_size = int(0.8 * ds_size)
val_size   = int(0.1 * ds_size)

train_ds = ds.take(train_size)
rest_ds  = ds.skip(train_size)
val_ds   = rest_ds.take(val_size)
test_ds  = rest_ds.skip(val_size)   # everything left after train and val

# (optional) free big arrays
# Removes the notebook-level names X and y; the loading cell must be re-run
# before any cell that needs them again.
del X, y
gc.collect()




In [None]:
# Compute per-class loss weights from the training split, then boost the
# Disturbed and Loose Spiral classes.
NUM_CLASSES = 10

# y is scalar (shape ()), so collect into 1D array
# (walks the training split one element at a time; count= pre-sizes the array)
y_train_ids = np.fromiter((int(y.numpy()) for _, y in train_ds),
                          dtype=np.int32,
                          count=train_size)

# Number of training samples per class id.
counts = np.bincount(y_train_ids, minlength=NUM_CLASSES)

# NOTE(review): a class with zero training samples makes this a division by
# zero (inf weight plus a NumPy warning); the clip below then caps it at 2.0.
w = counts.sum() / (NUM_CLASSES * counts)          # inverse freq
w = np.sqrt(w).astype(np.float32)                  # soften


# Class ids of the two classes that receive an extra multiplicative boost.
DISTURBED_ID = 0
LOOSE_SPIRAL_ID = 7

DISTURBED_BOOST = 1.1   # start here; try 1.25–2.0
LOOSE_SPIRAL_BOOST = 1.15
w[LOOSE_SPIRAL_ID] *= LOOSE_SPIRAL_BOOST
w[DISTURBED_ID] *= DISTURBED_BOOST

# Clipping happens after the boosts, so a boosted weight can still be capped.
w = np.clip(w, 0.7, 2.0).astype(np.float32)        # tighter clip

# Constant lookup table indexed by class id; add_sample_weight gathers from it.
cw = tf.constant(w, dtype=tf.float32)

In [None]:
# Show the final per-class weight tensor (one entry per class id).
print(cw)

In [None]:
H, W = 256, 256  # image size before resizing (needed for strong augment)


def add_sample_weight(x, y):
    """Attach a per-sample weight taken from the class-weight table ``cw``.

    The class id of each sample is the arg-max of ``y`` along its last axis
    (i.e. ``y`` is treated as one-hot), and that id indexes ``cw``.

    Returns:
        The tuple ``(x, y, weights)``.
    """
    class_ids = tf.argmax(y, axis=-1)
    return x, y, tf.gather(cw, class_ids)


def to_onehot(x, y):
    """One-hot encode ``y`` over ``NUM_CLASSES`` classes; ``x`` passes through unchanged."""
    int_labels = tf.cast(y, tf.int32)
    return x, tf.one_hot(int_labels, depth=NUM_CLASSES)
    

In [None]:
# Training hyper-parameters.
# NOTE(review): IMAGE_SIZE appears only in a commented-out line of the encoder
# cell in this notebook; the model itself resizes to 380x380.
IMAGE_SIZE = [300, 300]
EPOCHS = 15        # number of epochs passed to the first model.fit call
BATCH_SIZE = 48
DISTURBED_ID = 0   # class id of "Disturbed" (same value as in the class-weight cell)


#Oversampling disturbed Galaxies

# Load the extra training samples from a separate HDF5 file that uses the
# same "images"/"ans" dataset names as the main file.
with h5py.File("disturbed_oversampled.h5", "r") as f:
    X_disturbed = f["images"][:]
    y_disturbed = f["ans"][:]


# Create dataset from oversampled disturbed examples
disturbed_ds = tf.data.Dataset.from_tensor_slices((X_disturbed, y_disturbed))
# Append the oversampled examples to the training split only; the validation
# and test splits are left untouched.
train_ds_combined = train_ds.concatenate(disturbed_ds)

# Shuffle the combined dataset with a buffer that covers all of it.
# reshuffle_each_iteration=True draws a new order on every epoch; with False
# the model would see exactly the same batches in the same order each epoch.
# This is safe for the split: train_ds was carved out of a dataset whose own
# shuffle is fixed, so only the order inside the training set changes.
train_ds_combined = train_ds_combined.shuffle(
    buffer_size=len(train_ds_combined),
    seed=42,
    reshuffle_each_iteration=True
)

print(f"Original training size: {train_size}")
print(f"Combined training size: {train_size + len(X_disturbed)}")

# Clean up
del X_disturbed, y_disturbed
gc.collect()



# Training pipeline: batch first, then one-hot the labels and attach the
# per-sample weights on whole batches. Each element is
# (images, one_hot_labels, sample_weights); incomplete final batches are dropped.
train_ds_w = (train_ds_combined
    .batch(BATCH_SIZE, drop_remainder=True)
    # .map(extra_aug_only_disturbed, num_parallel_calls=tf.data.AUTOTUNE)
    .map(to_onehot, num_parallel_calls=tf.data.AUTOTUNE)
    .map(add_sample_weight, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)



# Validation / test pipelines: one-hot labels, no sample weights, and the
# final partial batch is kept.
# NOTE(review): val_ds and test_ds are rebound in place, so re-running this
# cell without re-running the split cell applies to_onehot and batch again.
val_ds = (val_ds
          .map(to_onehot, num_parallel_calls=tf.data.AUTOTUNE)
          .batch(BATCH_SIZE)
          .prefetch(tf.data.AUTOTUNE))

test_ds = (test_ds
           .map(to_onehot, num_parallel_calls=tf.data.AUTOTUNE)
           .batch(BATCH_SIZE)
           .prefetch(tf.data.AUTOTUNE))


In [None]:
# Sanity checks: pull one element from each pipeline and print its properties.
# train_ds is still unbatched, so xb / yb here are one image and its label.
xb, yb = next(iter(train_ds.take(1)))
print("x min/max:", tf.reduce_min(xb).numpy(), tf.reduce_max(xb).numpy())
print("x dtype:", xb.dtype, "y shape:", yb.shape)


# One weighted training batch: images, one-hot labels, per-sample weights.
xb, yb, wb = next(iter(train_ds_w.take(1)))
print("x:", xb.shape, xb.dtype)
print("y:", yb.shape, yb.dtype)
print("w:", wb.shape, wb.dtype, "min/max:", float(tf.reduce_min(wb)), float(tf.reduce_max(wb)))

# One validation batch (two-tuple: no sample weights).
xb2, yb2 = next(iter(val_ds.take(1)))
print("val y:", yb2.shape)



In [None]:
# tf.data execution options: enable autotuning plus the map-parallelization
# and parallel-batch static optimizations.
options = tf.data.Options()
options.autotune.enabled = True
options.experimental_optimization.map_parallelization = True
options.experimental_optimization.parallel_batch = True
# Applied to the training and validation pipelines only (test_ds is not touched).
train_ds_w = train_ds_w.with_options(options)
val_ds = val_ds.with_options(options)


In [None]:
#load encoder

# Build the ImageNet-pretrained EfficientNetB4 feature extractor without its
# classification head.
with tf.device('/gpu:0'): #load into gpu
    pretrained_model = tf.keras.applications.EfficientNetB4(
        weights = 'imagenet', # use pre-trained weights
        include_top = False, # we are creating the head below
        # input_shape = [*IMAGE_SIZE, 3]
        input_shape = [380,380 , 3]  # same size the model's Resizing layer produces
    )

    # Freeze the whole backbone; the fine-tuning cell selectively unfreezes it later.
    pretrained_model.trainable = False



In [None]:
#Model Architecture

@K.utils.register_keras_serializable()
class CastToFloat32(K.layers.Layer):
    """Cast the input to float32 and run it through EfficientNet's ``preprocess_input``.

    Registered as a Keras serializable so a saved model that contains this
    layer can be loaded again.
    """

    def call(self, x):
        """Return ``preprocess_input`` applied to ``x`` cast to ``tf.float32``."""
        return preprocess_input(tf.cast(x, tf.float32))

    def compute_output_shape(self, input_shape):
        """The layer reports the input shape unchanged as its output shape."""
        return input_shape




# Model input: 256x256 three-channel images.
inputs = tf.keras.Input(shape=(256,256,3))
#Removed due to not being compatible with jit_compile
# x = tf.keras.layers.RandomFlip(mode="horizontal_and_vertical", seed=42)(inputs)
# x = tf.keras.layers.RandomRotation(0.08, fill_mode="reflect", interpolation="bilinear", seed=42)(x)
# x = tf.keras.layers.RandomBrightness(factor=0.1, value_range=(0, 255),seed=42)(x)
# x = tf.keras.layers.RandomContrast(factor=0.2, seed=42)(x)


# Upscale to the 380x380 size the backbone was constructed with.
x = tf.keras.layers.Resizing(380, 380, interpolation='bilinear')(inputs)


#Moved augmentation to model to increase work done on GPU
# Every augmentation layer is pinned to float32 (the global policy is
# mixed_float16); layers that take a value_range are told pixels span 0-255.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical", dtype="float32"),
    tf.keras.layers.RandomBrightness(0.25, value_range=(0.0, 255.0), dtype="float32"),
    tf.keras.layers.RandomContrast(0.2, dtype="float32"),
    tf.keras.layers.RandomErasing(factor=0.25, value_range=(0.0, 255.0), dtype="float32"),
    tf.keras.layers.RandomGaussianBlur(factor=0.15, value_range=(0.0, 255.0), dtype="float32")
], name="aug")

# Cast to float32 and apply preprocess_input first, then augment.
x = CastToFloat32(name="to_fp32")(x)
x = data_augmentation(x)



def _conv_bn_relu(x, filters, kernel_size):
    """Apply Conv2D ('same' padding, no bias) -> BatchNormalization -> ReLU to ``x``."""
    x = L.Conv2D(filters, kernel_size, padding='same', use_bias=False)(x)
    x = L.BatchNormalization()(x)
    return L.ReLU()(x)


# Pass inputs through the pretrained backbone. training=False keeps it in
# inference mode when called, regardless of the outer model's training flag.
x = pretrained_model(x, training=False)

# Convolutional head: (filters, kernel_size) of each Conv-BN-ReLU stage, in
# the order they are applied. The first 1x1 stage compresses the backbone's
# output channels down to 128; later 1x1 stages step down to 48 and then 32.
_HEAD_STAGES = [
    (128, 1), (128, 3), (128, 3),
    (64, 3), (64, 3), (64, 3),
    (48, 1), (48, 3), (48, 3),
    (32, 1),
]
for _filters, _kernel_size in _HEAD_STAGES:
    x = _conv_bn_relu(x, _filters, _kernel_size)

# Collapse the spatial dimensions and regularise before the classifier.
x = L.GlobalAveragePooling2D()(x)
x = L.Dropout(0.35)(x)

# The classifier is pinned to float32 so the softmax output is not computed in
# float16 under the global mixed_float16 policy.
outputs = L.Dense(10, activation='softmax', dtype='float32')(x)

model = tf.keras.Model(inputs, outputs)

In [None]:


# First compile: Adam at 1e-4 with label-smoothed categorical cross-entropy
# (labels are one-hot). The backbone was frozen when it was loaded.
model.compile(
    optimizer=K.optimizers.Adam(1e-4),
    loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.05),
    metrics=['categorical_accuracy'],
    jit_compile=True  # compile the train/eval steps with XLA
)

model.summary()

In [None]:
from tensorflow.keras.callbacks import Callback
from tqdm.auto import tqdm


class TQDMBar(Callback):
    """Keras callback that shows one tqdm bar advancing once per epoch.

    The bar is created on the first ``on_epoch_begin`` of a ``fit`` call rather
    than in ``on_train_begin``. The callback ``params`` dict provides
    ``"epochs"`` (the final epoch index) but Keras does not put
    ``initial_epoch`` in it, so a total taken from ``params`` alone is too
    large when ``fit`` is resumed with ``initial_epoch > 0`` and the bar never
    reaches 100%. The epoch index of the first epoch that runs equals
    ``initial_epoch``, which gives the correct number of remaining epochs.
    """

    # No bar exists until the first epoch of a fit() call starts.
    pbar = None

    def on_train_begin(self, logs=None):
        """Forget any bar left over from a previous ``fit`` call."""
        self.pbar = None

    def on_epoch_begin(self, epoch, logs=None):
        """Create the bar when the first epoch of this ``fit`` call starts."""
        if self.pbar is None:
            # `epoch` is 0-based, so epochs - epoch is the number left to run.
            self.pbar = tqdm(total=self.params["epochs"] - epoch)

    def on_epoch_end(self, epoch, logs=None):
        """Advance the bar by one finished epoch."""
        if self.pbar is not None:
            self.pbar.update(1)

    def on_train_end(self, logs=None):
        """Close the bar (also reached when training stops early)."""
        if self.pbar is not None:
            self.pbar.close()
            self.pbar = None

#Base list of callbacks to manage LR, early stopping, and saving the best model
# All three monitor val_loss. ReduceLROnPlateau's patience (2) is shorter than
# EarlyStopping's (4), so the learning rate is reduced before training stops.
cbs = [
  tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', #Watch val_loss
                                       factor=0.5, # when plateauing, multiply LR by 0.5
                                       patience=2, # wait 2 epochs of no val_loss improvement before reducing LR
                                       min_lr=5e-6, # dont reduce LR below this minimum
                                       verbose=1),
  tf.keras.callbacks.EarlyStopping(
                                  monitor='val_loss', # Stop based on validation loss
                                   patience=4, # stop if no improvement for 4 epochs
                                   restore_best_weights=True # revert to best weights seen during training
                                  ),
  tf.keras.callbacks.ModelCheckpoint(
      'best.keras', # where the best model will be saved plus name
      monitor='val_loss', # pick "best" by lowest validation loss
      save_best_only=True # only overwrite file when val_loss improves
  )
]


#View training metrics
# Each run gets its own timestamped directory under logs/fit/.
logdir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tb = tf.keras.callbacks.TensorBoard(
    log_dir=logdir,
    profile_batch=(370, 400)  # profile this range of batches
)



#all Callbacks combined
all_callbacks = cbs + [TQDMBar()] + [tb]


# First training run. train_ds_w yields (x, y, sample_weight) triples, so the
# per-sample weights are applied to the training loss; val_ds has no weights.
history = model.fit(
    train_ds_w, # Training dataset
    validation_data=val_ds, # Validation set
    epochs=EPOCHS, #number of epochs to run
    callbacks=all_callbacks
)

In [None]:
#Unfreezing layers in backbone and starting stage 2 fine-tuning

FT_EPOCHS = 20  #fine-tune epochs for stage 2

#steps/epoch for cosine schedule
# Number of batches that train_ds_w yields per pass, read from its cardinality.
steps_per_epoch = tf.data.experimental.cardinality(train_ds_w).numpy()

#Handle special cardinality values
# INFINITE_CARDINALITY is itself a negative sentinel, so it is checked before
# the generic "< 0" test, which then covers the unknown-cardinality case.
if steps_per_epoch == tf.data.experimental.INFINITE_CARDINALITY:
    raise ValueError("train_ds has infinite cardinality; cannot infer steps_per_epoch.")
if steps_per_epoch < 0:
    raise ValueError(
        "train_ds cardinality is UNKNOWN. "
        "Compute steps_per_epoch manually as ceil(num_train_examples / BATCH_SIZE)."
    )

# Total optimizer steps over the whole fine-tuning stage; used as the cosine
# schedule's decay length.
total_decay_steps = steps_per_epoch * FT_EPOCHS
print("steps_per_epoch:", steps_per_epoch)
print("total_decay_steps:", total_decay_steps)


# Make the backbone container trainable again, then freeze every layer inside
# it so that only the layers selected below end up trainable.
pretrained_model.trainable = True

for layer in pretrained_model.layers:
    layer.trainable = False

# Unfreeze the last 50 backbone layers, except BatchNormalization layers,
# which stay frozen.
for layer in pretrained_model.layers[-50:]:
    if not isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = True

# ---- cosine decay learning rate ----
# Starts 10x lower than the 1e-4 used in the first compile and decays over
# total_decay_steps optimizer steps.
lr0 = 1e-5
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=lr0,
    decay_steps=total_decay_steps,
    alpha=0.05,  # ends at 5% of lr0
)

# New Adam instance driven by the schedule (optimizer state starts fresh).
optimizer = K.optimizers.Adam(learning_rate=lr_schedule)

# Recompile so the changed trainable flags take effect.
# NOTE(review): the accuracy metric is named "acc" here but
# 'categorical_accuracy' in the first compile, so the two History objects use
# different keys for accuracy.
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.05),
    metrics=[tf.keras.metrics.CategoricalAccuracy(name="acc")],
    jit_compile=True
)

model.summary()




# Stage 2 writes to its own timestamped TensorBoard run directory. It must be
# passed as log_dir below: handing over the stage-1 `logdir` instead would mix
# both stages into a single run and leave logdir2 unused.
logdir2 = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tb2 = tf.keras.callbacks.TensorBoard(
    log_dir=logdir2,
    profile_batch=(370, 400)  # profile this range of batches
)





# Stage-2 callbacks: early stopping and best-checkpoint saving on val_loss,
# the epoch progress bar, and the stage-2 TensorBoard callback. There is no
# ReduceLROnPlateau here; the learning rate follows the cosine schedule.
cbs2 = [
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=6, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint("best_stage2_cosine.keras", monitor="val_loss", save_best_only=True),
    TQDMBar(),
    tb2,
]

#Continuing from stage 1
# Epoch numbering resumes after the epochs stage 1 actually completed (it may
# have stopped early), and `epochs` is the final epoch index, so exactly
# FT_EPOCHS more epochs are scheduled.
initial_epoch = len(history.history["loss"])
history2 = model.fit(
    train_ds_w,
    validation_data=val_ds,
    epochs=initial_epoch + FT_EPOCHS,
    initial_epoch=initial_epoch,
    callbacks=cbs2
)


In [None]:
#View training metrics
# IPython magics: load the TensorBoard notebook extension, then open it on the
# parent directory that holds the timestamped run folders.
%load_ext tensorboard
%tensorboard --logdir logs/fit

In [None]:
# Persist the current in-memory model in two forms.
# NOTE(review): the file names say "b3" although the backbone loaded in this
# notebook is EfficientNetB4.
model.save("galaxy_b3_fixed.keras")                 # full model, reloadable
model.save_weights("galaxy_b3_fixed.weights.h5")    # weights-only escape hatch


In [None]:
import os
import pickle
from sklearn.metrics import (
    classification_report, f1_score, precision_score, recall_score, accuracy_score
)


# Display names indexed by class id.
class_names = [
    "Disturbed", "Merging", "Round Smooth", "In-between Round Smooth", "Cigar",
    "Barred Spiral", "Tight Spiral", "Loose Spiral", "Edge-on (no bulge)", "Edge-on (with bulge)"
]

#Collect y_true from the dataset
# One full pass over test_ds, concatenating the label batches in order.
y_true_raw = np.concatenate([y.numpy() for _, y in test_ds], axis=0)

#Convert one-hot -> class ids
if y_true_raw.ndim == 2:
    y_true = np.argmax(y_true_raw, axis=1)
else:
    y_true = y_true_raw.astype(int)

#Get model predictions
# test_ds is iterated a second time here. Rows line up with y_true because the
# only shuffle upstream of test_ds uses reshuffle_each_iteration=False, so
# every pass yields the same order.
y_prob = model.predict(test_ds, verbose=0)

# Demote low-confidence Disturbed predictions to their runner-up class.
DISTURBED_THRESHOLD = 0.58
y_pred_standard = np.argmax(y_prob, axis=1)
y_pred = y_pred_standard.copy()

# Samples whose arg-max class is Disturbed, with their Disturbed probability.
disturbed_mask = (y_pred_standard == DISTURBED_ID)
disturbed_indices = np.flatnonzero(disturbed_mask)
disturbed_probs = y_prob[disturbed_indices, DISTURBED_ID]

# Of those, the ones that fall below the confidence threshold.
low_confidence = disturbed_probs < DISTURBED_THRESHOLD
low_confidence_indices = disturbed_indices[low_confidence]

# Column -2 of an ascending per-row argsort is the second most probable class.
y_pred[low_confidence_indices] = np.argsort(y_prob[low_confidence_indices], axis=1)[:, -2]

print(f"Using disturbed threshold: {DISTURBED_THRESHOLD}")
print(f"Disturbed predictions before threshold: {np.sum(y_pred_standard == DISTURBED_ID)}")
print(f"Disturbed predictions after threshold: {np.sum(y_pred == DISTURBED_ID)}")
print(f"Predictions changed: {np.sum(y_pred != y_pred_standard)}\n")

def _prob_stats(probs):
    """Return mean/median/min/max/std of ``probs`` as a dict.

    An empty group yields NaN for every statistic instead of raising:
    ``np.min`` / ``np.max`` raise ValueError on a zero-size array, which would
    abort the whole analysis whenever one of the TP/FP/FN groups is empty.
    """
    probs = np.asarray(probs)
    if probs.size == 0:
        return {key: float("nan") for key in ("mean", "median", "min", "max", "std")}
    return {
        "mean": np.mean(probs),
        "median": np.median(probs),
        "min": np.min(probs),
        "max": np.max(probs),
        "std": np.std(probs),
    }


def _print_prob_stats(stats):
    """Print the five statistics of one group, one per line, to 4 decimals."""
    print(f"Mean: {stats['mean']:.4f}")
    print(f"Median: {stats['median']:.4f}")
    print(f"Min: {stats['min']:.4f}")
    print(f"Max: {stats['max']:.4f}")
    print(f"Std: {stats['std']:.4f}")


# ============= ANALYZE TRUE POSITIVES (CORRECT DISTURBED PREDICTIONS) =============
disturbed_true_mask = (y_true == DISTURBED_ID)
disturbed_pred_mask = (y_pred == DISTURBED_ID)
true_positive_mask = disturbed_true_mask & disturbed_pred_mask

# Get indices of true positives
tp_indices = np.where(true_positive_mask)[0]

# Calculate statistics for true positives
tp_disturbed_probs = y_prob[tp_indices, DISTURBED_ID]
tp_stats = _prob_stats(tp_disturbed_probs)

print(f"=== True Positive Analysis (Correct Disturbed Predictions) ===")
print(f"Total true positives: {len(tp_indices)}")
print(f"\n=== Disturbed Probability Statistics for True Positives ===")
_print_prob_stats(tp_stats)

# Store true positive data
tp_data = []
for idx in tp_indices:
    disturbed_prob = y_prob[idx, DISTURBED_ID]
    # Index -2 of the ascending argsort is the runner-up class.
    second_best_class = np.argsort(y_prob[idx])[-2]
    second_best_prob = y_prob[idx, second_best_class]

    tp_data.append({
        'index': idx,
        'disturbed_prob': disturbed_prob,
        'second_best_class': class_names[second_best_class],
        'second_best_prob': second_best_prob,
        'all_probs': y_prob[idx].copy()
    })

# ============= ANALYZE FALSE POSITIVES (INCORRECT DISTURBED PREDICTIONS) =============
false_positive_mask = ~disturbed_true_mask & disturbed_pred_mask

# Get indices of false positives
fp_indices = np.where(false_positive_mask)[0]

# Calculate statistics for false positives
fp_disturbed_probs = y_prob[fp_indices, DISTURBED_ID]
fp_stats = _prob_stats(fp_disturbed_probs)

print(f"\n=== False Positive Analysis (Incorrect Disturbed Predictions) ===")
print(f"Total false positives: {len(fp_indices)}")
print(f"\n=== Disturbed Probability Statistics for False Positives ===")
_print_prob_stats(fp_stats)

# Analyze which classes were incorrectly predicted as disturbed
true_classes_fp = y_true[fp_indices]
unique_fp, counts_fp = np.unique(true_classes_fp, return_counts=True)

print(f"\n=== False Positives: True Classes ===")
for cls, count in zip(unique_fp, counts_fp):
    print(f"{class_names[cls]}: {count}")

# Store false positive data
fp_data = []
for idx in fp_indices:
    true_class = class_names[y_true[idx]]
    disturbed_prob = y_prob[idx, DISTURBED_ID]
    true_class_prob = y_prob[idx, y_true[idx]]

    fp_data.append({
        'index': idx,
        'true_class': true_class,
        'disturbed_prob': disturbed_prob,
        'true_class_prob': true_class_prob,
        'all_probs': y_prob[idx].copy()
    })

# Save analysis data
os.makedirs("disturbed_analysis", exist_ok=True)
with open("disturbed_analysis/tp_data.pkl", "wb") as f:
    pickle.dump(tp_data, f)
with open("disturbed_analysis/fp_data.pkl", "wb") as f:
    pickle.dump(fp_data, f)
print(f"\nSaved true positive data to disturbed_analysis/tp_data.pkl")
print(f"Saved false positive data to disturbed_analysis/fp_data.pkl")

# ============= ANALYZE FALSE NEGATIVES (from earlier) =============
false_negative_mask = disturbed_true_mask & ~disturbed_pred_mask
fn_indices = np.where(false_negative_mask)[0]
fn_disturbed_probs = y_prob[fn_indices, DISTURBED_ID]
fn_stats = _prob_stats(fn_disturbed_probs)

print(f"\n=== False Negative Analysis (Missed Disturbed Galaxies) ===")
print(f"Total false negatives: {len(fn_indices)}")
print(f"\n=== Disturbed Probability Statistics for False Negatives ===")
_print_prob_stats(fn_stats)

# ============= COMPREHENSIVE COMPARISON =============
print(f"\n{'='*90}")
print(f"=== COMPREHENSIVE COMPARISON: TP vs FP vs FN ===")
print(f"{'='*90}")
print(f"{'Metric':<35} {'True Positives':<18} {'False Positives':<18} {'False Negatives':<18}")
print("-" * 90)

print(f"{'Count':<35} {len(tp_disturbed_probs):<18} {len(fp_disturbed_probs):<18} {len(fn_disturbed_probs):<18}")
# One table row per statistic, reusing the values computed above.
for row_label, key in (
    ("Mean Disturbed Prob", "mean"),
    ("Median Disturbed Prob", "median"),
    ("Min Disturbed Prob", "min"),
    ("Max Disturbed Prob", "max"),
    ("Std Disturbed Prob", "std"),
):
    print(f"{row_label:<35} {tp_stats[key]:<18.4f} {fp_stats[key]:<18.4f} {fn_stats[key]:<18.4f}")

print(f"\n{'Interpretation:':<35}")
print(f"  TP avg > threshold: Model confident & correct")
print(f"  FP avg near threshold: Model uncertain, made mistakes")
print(f"  FN avg < threshold: Model uncertain, missed detections")

# Per-group plotting specs shared by both figures:
# (probabilities, histogram bins, colour, subplot title, overlay legend label)
_panels = [
    (tp_disturbed_probs, 30, 'green',
     f'True Positives (n={len(tp_disturbed_probs)})\nCorrect Disturbed Predictions',
     f'TP (n={len(tp_disturbed_probs)}, μ={np.mean(tp_disturbed_probs):.3f})'),
    (fp_disturbed_probs, 20, 'orange',
     f'False Positives (n={len(fp_disturbed_probs)})\nIncorrect Disturbed Predictions',
     f'FP (n={len(fp_disturbed_probs)}, μ={np.mean(fp_disturbed_probs):.3f})'),
    (fn_disturbed_probs, 20, 'red',
     f'False Negatives (n={len(fn_disturbed_probs)})\nMissed Disturbed Galaxies',
     f'FN (n={len(fn_disturbed_probs)}, μ={np.mean(fn_disturbed_probs):.3f})'),
]

# Probability distribution comparison plot - 3 subplots (TP, FP, FN)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for ax, (probs, n_bins, colour, panel_title, _overlay_label) in zip(axes, _panels):
    group_mean = np.mean(probs)
    ax.hist(probs, bins=n_bins, alpha=0.7, edgecolor='black', color=colour)
    # Dashed red line: decision threshold; solid blue line: group mean.
    ax.axvline(x=DISTURBED_THRESHOLD, color='r', linestyle='--', linewidth=2,
               label=f'Threshold = {DISTURBED_THRESHOLD}')
    ax.axvline(x=group_mean, color='blue', linestyle='-', linewidth=2,
               label=f'Mean = {group_mean:.3f}')
    ax.set_xlabel('Disturbed Class Probability', fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    ax.set_title(panel_title, fontsize=12, color=colour)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Disturbed Class Probability Distribution: TP vs FP vs FN', fontsize=16, y=1.02)
plt.tight_layout()
plt.savefig("disturbed_analysis/tp_fp_fn_probability_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

# Combined overlay plot for direct comparison
plt.figure(figsize=(12, 6))
for probs, n_bins, colour, _panel_title, overlay_label in _panels:
    plt.hist(probs, bins=n_bins, alpha=0.5, edgecolor='black', color=colour, label=overlay_label)
plt.axvline(x=DISTURBED_THRESHOLD, color='black', linestyle='--', linewidth=2, label=f'Threshold = {DISTURBED_THRESHOLD}')
plt.xlabel('Disturbed Class Probability', fontsize=14)
plt.ylabel('Count', fontsize=14)
plt.title('Overlaid Probability Distributions: True Positives vs False Positives vs False Negatives', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("disturbed_analysis/overlay_tp_fp_fn_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

# 4) Build confusion matrix (rows: true class, columns: predicted class)
if class_names is None:
    num_classes = int(np.max(y_true)) + 1
else:
    num_classes = len(class_names)
cm = np.zeros((num_classes, num_classes), dtype=np.int64)
# Unbuffered add: every (true, predicted) pair increments its cell once, even
# when the same pair occurs many times.
np.add.at(cm, (np.asarray(y_true, dtype=np.int64), np.asarray(y_pred, dtype=np.int64)), 1)

# 5) Plot confusion matrix
plt.figure(figsize=(9, 8))
plt.imshow(cm, interpolation="nearest", cmap='viridis')
plt.title(f"Confusion Matrix (Disturbed Threshold={DISTURBED_THRESHOLD})")
plt.colorbar()
tick_marks = np.arange(num_classes)

if class_names is None:
    labels = [str(i) for i in range(num_classes)]
else:
    labels = class_names
plt.xticks(tick_marks, labels, rotation=45, ha="right")
plt.yticks(tick_marks, labels)

# Annotate every cell with its count: white text for cells above half of the
# largest count, black text otherwise.
thresh = cm.max() / 2.0
for (i, j), cell_count in np.ndenumerate(cm):
    text_colour = "white" if cell_count > thresh else "black"
    plt.text(j, i, str(cell_count), ha="center", va="center", color=text_colour)

plt.ylabel("True label")
plt.xlabel("Predicted label")
plt.tight_layout()
plt.show()

# 6) Overall metrics
# Everything here is computed on y_pred, i.e. after the Disturbed threshold
# adjustment, not on the raw arg-max predictions.
acc = accuracy_score(y_true, y_pred)
f1_macro    = f1_score(y_true, y_pred, average="macro")
f1_weighted = f1_score(y_true, y_pred, average="weighted")
f1_micro    = f1_score(y_true, y_pred, average="micro")

# zero_division=0 scores a class with no predicted / no true samples as 0
# instead of emitting a warning.
prec_macro  = precision_score(y_true, y_pred, average="macro", zero_division=0)
rec_macro   = recall_score(y_true, y_pred, average="macro", zero_division=0)

print(f"\nAccuracy:     {acc:.4f}")
print(f"F1 macro:     {f1_macro:.4f}")
print(f"F1 weighted:  {f1_weighted:.4f}")
print(f"F1 micro:     {f1_micro:.4f}")
print(f"Precision(m): {prec_macro:.4f}")
print(f"Recall(m):    {rec_macro:.4f}")

# 7) Per-class report
print("\nClassification report:\n")
print(classification_report(
    y_true, y_pred,
    target_names=class_names,
    digits=4,
    zero_division=0
))

# 8) Show disturbed class metrics specifically
# One-vs-rest counts for the Disturbed class, using the thresholded y_pred.
disturbed_pred = (y_pred == DISTURBED_ID)

tp = np.sum(disturbed_true_mask & disturbed_pred)
fp = np.sum(~disturbed_true_mask & disturbed_pred)
fn = np.sum(disturbed_true_mask & ~disturbed_pred)

# Each ratio falls back to 0 when its denominator is zero.
precision_disturbed = tp / (tp + fp) if (tp + fp) > 0 else 0
recall_disturbed = tp / (tp + fn) if (tp + fn) > 0 else 0
f1_disturbed = 2 * precision_disturbed * recall_disturbed / (precision_disturbed + recall_disturbed) if (precision_disturbed + recall_disturbed) > 0 else 0

print(f"\n=== Disturbed Class Detailed Metrics ===")
print(f"True Positives: {tp}")
print(f"False Positives: {fp}")
print(f"False Negatives: {fn}")
print(f"Precision: {precision_disturbed:.4f}")
print(f"Recall: {recall_disturbed:.4f}")
print(f"F1-Score: {f1_disturbed:.4f}")
# The three probability arrays come from the TP/FP/FN analysis earlier in this
# cell; np.mean of an empty array prints nan.
print(f"\nAverage Disturbed Probability:")
print(f"  True Positives:  {np.mean(tp_disturbed_probs):.4f} (Confident & Correct)")
print(f"  False Positives: {np.mean(fp_disturbed_probs):.4f} (Uncertain, Made Mistake)")
print(f"  False Negatives: {np.mean(fn_disturbed_probs):.4f} (Uncertain, Missed Detection)")