In [9]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.layers import BatchNormalization


from gtda.homology import CubicalPersistence
from gtda.diagrams import PersistenceImage, PersistenceLandscape

X_img = np.load("/home/sajedhamdan/Desktop/skin_cancer/images_train_256x192.npy") 
y = np.load("/home/sajedhamdan/Desktop/skin_cancer/train_labels.npy")            
X_img = X_img.astype(np.float32)
y_cat = to_categorical(y)

sample_size = 2000
X_img, _, y, _ = train_test_split(
    X_img, y, train_size=sample_size, stratify=y, random_state=42
)


# TDA Feature Extraction
def extract_tda_features(X_rgb):
    X_gray = 0.2989 * X_rgb[..., 0] + 0.5870 * X_rgb[..., 1] + 0.1140 * X_rgb[..., 2]
    cp = CubicalPersistence(homology_dimensions=[0, 1], n_jobs=-1)
    diagrams = cp.fit_transform(X_gray)

    pi = PersistenceImage(sigma=1.0, n_bins=20, weight_function=lambda x: x[1] ** 2)
    pi_feat = pi.fit_transform(diagrams).reshape(len(diagrams), -1)

    pl = PersistenceLandscape(n_layers=5, n_bins=50)
    pl_feat = pl.fit_transform(diagrams).reshape(len(diagrams), -1)

    return np.hstack((pi_feat, pl_feat))

print("Extracting TDA features...")
X_tda_features = extract_tda_features(X_img)
print("TDA shape:", X_tda_features.shape)

print("Balancing with SMOTE...")
smote = SMOTE(random_state=42)
X_balanced, y_bal = smote.fit_resample(X_tda_features, y)

scaler = StandardScaler()
X_final = scaler.fit_transform(X_balanced)
y_bal_cat = to_categorical(y_bal)


y_bal_int = np.argmax(y_bal_cat, axis=1)
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_bal_int), y=y_bal_int)
class_weight_dict = dict(enumerate(class_weights))


print("Final input shape:", X_final.shape)

X_train, X_test, y_train, y_test = train_test_split(
    X_final, y_bal_cat, test_size=0.2, random_state=42, stratify=y_bal)


input_layer = Input(shape=(X_final.shape[1],))
x = Dense(512, activation='relu')(input_layer)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
x = Dense(256, activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.3)(x)
output_layer = Dense(7, activation='softmax')(x)

model = Model(inputs=input_layer, outputs=output_layer)

model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
)

history = model.fit(
    X_train, y_train,
    epochs=100,
    batch_size=64,
    validation_data=(X_test, y_test),
    class_weight=class_weight_dict,
    verbose=1
)


results = model.evaluate(X_test, y_test, verbose=1)
print(f"Test - Accuracy: {results[1]:.4f} | Precision: {results[2]:.4f} | Recall: {results[3]:.4f}")

model.save("tda_resnet_model_v2.keras")


Extracting TDA features...
TDA shape: (2000, 1300)
Balancing with SMOTE...
Final input shape: (9366, 1300)
Epoch 1/100
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 19ms/step - accuracy: 0.3440 - loss: 2.0223 - precision: 0.4123 - recall: 0.2339 - val_accuracy: 0.5688 - val_loss: 1.1984 - val_precision: 0.7390 - val_recall: 0.3959
Epoch 2/100
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 15ms/step - accuracy: 0.5231 - loss: 1.4153 - precision: 0.6043 - recall: 0.4092 - val_accuracy: 0.6734 - val_loss: 0.9578 - val_precision: 0.8102 - val_recall: 0.5171
Epoch 3/100
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 14ms/step - accuracy: 0.5637 - loss: 1.2574 - precision: 0.6485 - recall: 0.4576 - val_accuracy: 0.7193 - val_loss: 0.8564 - val_precision: 0.8190 - val_recall: 0.5651
Epoch 4/100
[1m118/118[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 18ms/step - accuracy: 0.6183 - loss: 1.1133 - precision: 0.6940 - recall: 0.5