In [19]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, Input, BatchNormalization, Add
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
import joblib  # for saving/loading scaler

from gtda.homology import CubicalPersistence
from gtda.diagrams import PersistenceImage, PersistenceLandscape

# new dataset
X_img = np.load("/home/sajedhamdan/Desktop/skin_cancer/CNN+TDA/new_dataset/images_train_256x192.npy")
y = np.load("/home/sajedhamdan/Desktop/skin_cancer/CNN+TDA/new_dataset/train_labels.npy")
X_img = X_img.astype(np.float32)

sample_size = 728
X_img, _, y, _ = train_test_split(X_img, y, test_size=0.2, stratify=y, random_state=42)

# TDA Feature Extraction
def extract_tda_features(X_rgb):
    X_gray = 0.2989 * X_rgb[..., 0] + 0.5870 * X_rgb[..., 1] + 0.1140 * X_rgb[..., 2]
    cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
    diagrams = cp.fit_transform(X_gray)

    pi = PersistenceImage(sigma=1.0, n_bins=20, weight_function=lambda x: x[1] ** 2)
    pi_feat = pi.fit_transform(diagrams).reshape(len(diagrams), -1)

    pl = PersistenceLandscape(n_layers=5, n_bins=50)
    pl_feat = pl.fit_transform(diagrams).reshape(len(diagrams), -1)

    return np.hstack((pi_feat, pl_feat))

print("Extracting TDA features...")
X_tda = extract_tda_features(X_img)
print("TDA shape:", X_tda.shape)

smote = SMOTE(k_neighbors=1, random_state=42)
X_bal, y_bal = smote.fit_resample(X_tda, y)

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_bal)

# Save scaler for later use
# joblib.dump(scaler, "tda_feature_scaler.joblib")

y_cat = to_categorical(y_bal)

class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_bal), y=y_bal)
class_weight_dict = dict(enumerate(class_weights))

# train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y_cat, test_size=0.2, stratify=y_cat, random_state=42
)

# MLP Model (ResNet-style)
input_layer = Input(shape=(X_scaled.shape[1],))

# block 1
x = Dense(512, activation='relu')(input_layer)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
res1 = Dense(512)(x)
res1 = BatchNormalization()(res1)

# block 2
x = Dense(512, activation='relu')(res1)
x = BatchNormalization()(x)
x = Dropout(0.3)(x)
res2 = Add()([x, res1]) 

# block 3
x = Dense(256, activation='relu')(res2)
x = BatchNormalization()(x)
x = Dropout(0.3)(x)
res3 = Dense(256)(x)
res3 = BatchNormalization()(res3)
x = Add()([x, res3])  

# output layer
output_layer = Dense(6, activation='softmax')(x)

model = Model(inputs=input_layer, outputs=output_layer)

model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
)

early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
lr_schedule = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6, verbose=1)

history = model.fit(
    X_train, y_train,
    epochs=100,
    batch_size=64,
    validation_data=(X_test, y_test),
    class_weight=class_weight_dict,
    callbacks=[early_stop, lr_schedule],
    verbose=1
)

results = model.evaluate(X_test, y_test, verbose=1)
print(f"\nTest - Accuracy: {results[1]:.4f} | Precision: {results[2]:.4f} | Recall: {results[3]:.4f}")

# save model for later use
# model.save("tda_resnet_model_v2.keras")


Extracting TDA features...
TDA shape: (583, 1300)
Epoch 1/100
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 22ms/step - accuracy: 0.4827 - loss: 1.7430 - precision_1: 0.5430 - recall_1: 0.4220 - val_accuracy: 0.7542 - val_loss: 0.7422 - val_precision_1: 0.8216 - val_recall_1: 0.7054 - learning_rate: 1.0000e-04
Epoch 2/100
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 17ms/step - accuracy: 0.7954 - loss: 0.6879 - precision_1: 0.8299 - recall_1: 0.7713 - val_accuracy: 0.8485 - val_loss: 0.5141 - val_precision_1: 0.8839 - val_recall_1: 0.7694 - learning_rate: 1.0000e-04
Epoch 3/100
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 16ms/step - accuracy: 0.7944 - loss: 0.6170 - precision_1: 0.8192 - recall_1: 0.7720 - val_accuracy: 0.8923 - val_loss: 0.4081 - val_precision_1: 0.9167 - val_recall_1: 0.8519 - learning_rate: 1.0000e-04
Epoch 4/100
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 18ms/step - accuracy: 0.8508 - los