# In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight


# Report the TensorFlow build in use so training logs can be tied to a version.
tf_version = tf.__version__
print("TensorFlow Version:", tf_version)


# ------------------------------
# CONFIG & PATHS
# ------------------------------
# Root folder of the labeled dataset; load_labeled_dataset() expects one
# sub-directory per class name underneath it. The default is the original
# author's local path; set the BT_LABELED_DATA_PATH environment variable to
# point the script at another location without editing the source.
labeled_data_path = os.environ.get(
    "BT_LABELED_DATA_PATH",
    "/Users/morgan/Desktop/BT-ClassificationDissCode/combineddataset/Training",
)
# Saved Keras model that build_teacher_classifier() loads as the encoder.
teacher_model_path = "ssl_teacher_model.keras"
# Destination file for the trained classifier written at the end of the run.
output_model_path = "teacher_classifier_modelv5.keras"
# Width of the softmax head (glioma, meningioma, pituitary).
num_classes = 3
# Upper bound on training epochs; early stopping may end training sooner.
epochs = 20
# Mini-batch size used for training.
batch_size = 32


# ------------------------------
# DATA LOADING
# ------------------------------
def _augment_image(image, image_size):
    """Apply one random augmentation pass to a single image scaled to [0, 1].

    The pass consists of a random horizontal flip, brightness and contrast
    jitter, a random multiple-of-90-degree rotation and a random 0.9x-1.1x
    zoom. The result is a tensor of spatial size ``image_size``.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    # Brightness/contrast jitter can push pixels outside the normalised range;
    # clamp so every image handed to the model stays in [0, 1].
    image = tf.clip_by_value(image, 0.0, 1.0)

    quarter_turns = tf.random.uniform(shape=(), minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, k=quarter_turns)

    # Random zoom: rescale the image, then centre-crop (zoom in) or zero-pad
    # (zoom out) back to the target size. Resizing straight back to the target
    # size would undo the scaling and merely blur the image.
    scale = float(tf.random.uniform(shape=(), minval=0.9, maxval=1.1, dtype=tf.float32))
    scaled_size = (int(image_size[0] * scale), int(image_size[1] * scale))
    image = tf.image.resize(image, scaled_size)
    return tf.image.resize_with_crop_or_pad(image, image_size[0], image_size[1])


def load_labeled_dataset(root_folder, image_size=(224, 224), augment=True):
    """Load the labeled images under ``root_folder`` into NumPy arrays.

    ``root_folder`` is expected to contain the sub-directories ``glioma``,
    ``meningioma`` and ``pituitary`` (labels 0, 1 and 2). A missing class
    directory is reported and skipped; files that fail to load are reported
    and skipped as well. Only ``.jpg``, ``.jpeg`` and ``.png`` files are read.

    Args:
        root_folder: Directory holding one sub-directory per class.
        image_size: ``(height, width)`` every image is resized to.
        augment: When true (the default, matching the previous behaviour) each
            image receives one random augmentation pass. The augmentation is
            applied once at load time, so the returned arrays are fixed; pass
            ``False`` to obtain unmodified images, e.g. for evaluation data.

    Returns:
        A tuple ``(images, labels)``: the stacked image array with pixel values
        scaled to [0, 1], and the matching integer class labels.

    Raises:
        ValueError: If no image could be loaded from any class directory.
    """
    class_mapping = {"glioma": 0, "meningioma": 1, "pituitary": 2}
    image_extensions = (".jpg", ".jpeg", ".png")
    images_list, labels_list = [], []

    for class_name, label in class_mapping.items():
        label_dir = os.path.join(root_folder, class_name)
        if not os.path.isdir(label_dir):
            print(f"Warning: Directory not found -> {label_dir}")
            continue

        # os.listdir() order depends on the filesystem; sorting keeps the
        # sample order (and therefore any seeded split of it) reproducible.
        for fn in sorted(os.listdir(label_dir)):
            if not fn.lower().endswith(image_extensions):
                continue

            img_path = os.path.join(label_dir, fn)
            try:
                image = tf.keras.preprocessing.image.load_img(img_path, target_size=image_size)
                image = tf.keras.preprocessing.image.img_to_array(image) / 255.0
                if augment:
                    # Convert the tensor back to NumPy so the list can be stacked.
                    image = _augment_image(image, image_size).numpy()
            except Exception as e:
                # Best-effort loading: one unreadable file must not abort the run.
                print(f"⚠️ [Skip] {img_path}: {e}")
            else:
                images_list.append(image)
                labels_list.append(label)

    if not images_list:
        raise ValueError("No labeled images found! Check dataset structure.")

    return np.array(images_list), np.array(labels_list)


# ------------------------------
# BUILD CLASSIFIER HEAD
# ------------------------------
def build_teacher_classifier(teacher_model_path, num_classes):
    """Load the saved teacher model and attach a softmax classification head.

    The head is: optional global average pooling (only when the teacher's
    output has more than two dimensions), a 128-unit ReLU dense layer with L2
    regularisation (0.01), batch normalisation, dropout (0.7) and a final
    ``num_classes``-way softmax layer.

    Args:
        teacher_model_path: Path of the saved Keras model to load.
        num_classes: Number of units in the softmax output layer.

    Returns:
        A Keras ``Model`` mapping the teacher's input to class probabilities.
    """
    print("Loading fine-tuned SSL teacher model...")
    encoder = tf.keras.models.load_model(teacher_model_path)

    # Feature maps (rank > 2) are collapsed to one vector per sample first.
    features = encoder.output
    needs_pooling = len(features.shape) > 2
    if needs_pooling:
        features = GlobalAveragePooling2D(name="global_avg_pool")(features)

    hidden = Dense(
        128,
        activation="relu",
        kernel_regularizer=tf.keras.regularizers.l2(0.01),
        name="dense_classifier_1",
    )(features)
    hidden = BatchNormalization()(hidden)
    hidden = Dropout(0.7, name="dropout_classifier")(hidden)
    probabilities = Dense(
        num_classes,
        activation="softmax",
        name="classification_head",
    )(hidden)

    classifier_model = Model(inputs=encoder.input, outputs=probabilities)
    print("Teacher model updated with classification head!")
    return classifier_model


# ------------------------------
# TRAIN CLASSIFICATION HEAD
# ------------------------------
def train_classifier(model, X_train, y_train, X_val, y_val, batch_size=32, epochs=20):
    """Compile and fit ``model`` on the labeled data with class balancing.

    Training uses Adam with a cosine-decayed learning rate starting at 1e-4,
    categorical cross-entropy, balanced class weights, early stopping on
    ``val_loss`` (patience 4, best weights restored) and a best-only checkpoint
    written to ``best_teacher_classifier.keras``.

    ``ReduceLROnPlateau`` is deliberately not used: the optimizer's learning
    rate is driven by a ``CosineDecay`` schedule, which that callback cannot
    override (depending on the Keras version the reduction is either silently
    overwritten on the next step or raises an error).

    Args:
        model: Keras classifier whose last output dimension is the class count.
        X_train: Training images.
        y_train: Integer class labels for ``X_train``.
        X_val: Validation images.
        y_val: Integer class labels for ``X_val``.
        batch_size: Mini-batch size.
        epochs: Maximum number of epochs.

    Returns:
        A tuple ``(model, history)`` with the trained model and the Keras
        ``History`` object returned by ``fit``.
    """
    # Size the one-hot targets from the model's own output layer so they always
    # match the classification head, independent of module-level globals.
    n_classes = model.output_shape[-1]
    y_train_one_hot = tf.keras.utils.to_categorical(y_train, num_classes=n_classes)
    y_val_one_hot = tf.keras.utils.to_categorical(y_val, num_classes=n_classes)

    # Balanced class weights, keyed by the actual label value. Indexing the
    # weights by position would attach them to the wrong class whenever a label
    # is absent from y_train (e.g. a missing class directory). Classes that are
    # absent keep a neutral weight so the keys stay contiguous from 0.
    present_classes = np.unique(y_train)
    balanced_weights = compute_class_weight(
        class_weight="balanced",
        classes=present_classes,
        y=y_train,
    )
    class_weight_dict = {class_id: 1.0 for class_id in range(n_classes)}
    for class_id, weight in zip(present_classes, balanced_weights):
        class_weight_dict[int(class_id)] = float(weight)

    # Cosine decay over the real number of optimizer steps: fit() runs
    # ceil(len / batch_size) steps per epoch, including the final partial batch.
    steps_per_epoch = -(-len(X_train) // batch_size)
    decay_steps = max(1, epochs * steps_per_epoch)
    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(1e-4, decay_steps)
    model.compile(
        optimizer=Adam(learning_rate=lr_schedule),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    early_stopping = EarlyStopping(monitor="val_loss", patience=4, restore_best_weights=True)
    checkpoint = ModelCheckpoint("best_teacher_classifier.keras", save_best_only=True, monitor="val_loss")

    print("Starting training...")
    history = model.fit(
        X_train,
        y_train_one_hot,
        validation_data=(X_val, y_val_one_hot),
        epochs=epochs,
        batch_size=batch_size,
        class_weight=class_weight_dict,
        callbacks=[early_stopping, checkpoint],
    )

    return model, history


# ------------------------------
# MAIN
# ------------------------------
if __name__ == "__main__":
    # End-to-end run: load data, split, build the classifier on the teacher
    # encoder, train it, save it and plot the learning curves.
    print("Loading labeled dataset...")
    X, y = load_labeled_dataset(labeled_data_path)

    print("Splitting dataset into training and validation sets...")
    # Stratify so both splits keep the dataset's class proportions; the fixed
    # seed keeps the split repeatable.
    # NOTE(review): the loader augments images at load time, so the validation
    # images produced here are augmented as well.
    X_train, X_val, y_train, y_val = train_test_split(
        X,
        y,
        test_size=0.2,
        random_state=42,
        stratify=y,
    )

    print("Building classifier model on top of teacher encoder...")
    teacher_classifier = build_teacher_classifier(teacher_model_path, num_classes)

    print("Training classifier on labeled MRI dataset...")
    # Forward the configured batch size as well as the epoch count so both
    # config values actually take effect.
    trained_classifier, history = train_classifier(
        teacher_classifier,
        X_train,
        y_train,
        X_val,
        y_val,
        batch_size=batch_size,
        epochs=epochs,
    )

    trained_classifier.save(output_model_path)
    print(f"Trained teacher classifier saved as '{output_model_path}'")

    def plot_training_history(history):
        """Plot train/validation accuracy and loss side by side."""
        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        plt.plot(history.history["accuracy"], label="Train Accuracy")
        plt.plot(history.history["val_accuracy"], label="Validation Accuracy")
        plt.title("Training & Validation Accuracy")
        plt.legend()
        plt.subplot(1, 2, 2)
        plt.plot(history.history["loss"], label="Train Loss")
        plt.plot(history.history["val_loss"], label="Validation Loss")
        plt.title("Training & Validation Loss")
        plt.legend()
        plt.show()

    print("Plotting training history...")
    plot_training_history(history)