# In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping


# ------------------------------
# DATA LOADING FUNCTIONS
# ------------------------------
def load_images_and_labels(root, image_size=(224, 224)):
   """
   Loads labeled images and their numeric labels from the given folder.
   Assumes subdirectories "1", "2", "3" corresponding to tumor classes.

   Only files ending in .jpg, .jpeg or .png (case-insensitive) are considered.
   A missing class directory or an image that fails to load is reported with a
   printed warning and skipped.

   Args:
       root: Path of the folder that holds the class subdirectories.
       image_size: (width, height) every image is resized to.

   Returns:
       - X: float32 array of RGB images (normalized to [0, 1])
       - y: Array of corresponding labels (0, 1, 2)

   Raises:
       ValueError: If no image could be loaded at all.
   """
   # Mapping: Convert folder names "1", "2", "3" to class labels {0,1,2}
   class_mapping = {"1": 0, "2": 1, "3": 2}
   image_extensions = (".jpg", ".jpeg", ".png")

   # Pass 1: collect (path, label) pairs for every candidate image file.
   candidates = []
   for class_name, label in class_mapping.items():
       label_dir = os.path.join(root, class_name)
       if not os.path.isdir(label_dir):
           print(f"⚠️ Warning: Directory not found -> {label_dir}")
           continue

       # NOTE: os.listdir returns entries in arbitrary (filesystem) order; the
       # order is kept as-is because other per-image artifacts may rely on it.
       candidates.extend(
           (os.path.join(label_dir, fn), label)
           for fn in os.listdir(label_dir)
           if fn.lower().endswith(image_extensions)
       )

   # Pass 2: decode, resize and normalize each candidate.
   images_list, labels_list = [], []
   if candidates:
       # Deferred so Pillow is only required when there is something to decode.
       from PIL import Image

       for img_path, label in candidates:
           try:
               # The context manager closes the underlying file as soon as the
               # pixels are decoded, so handles do not pile up on big datasets.
               with Image.open(img_path) as raw_image:
                   image = raw_image.convert("RGB").resize(image_size)
               images_list.append(np.array(image, dtype=np.float32) / 255.0)  # Normalize to [0,1]
               labels_list.append(label)  # Store correct class label
           except Exception as e:
               # Deliberately best-effort: one bad file must not abort the load.
               print(f"⚠️ Error loading {img_path}: {e}")

   if not images_list:
       raise ValueError("No images were loaded! Check dataset structure.")

   return np.array(images_list), np.array(labels_list)


def load_teacher_soft_labels(soft_labels_path):
   """
   Loads precomputed teacher soft labels from an NPZ file.
   The NPZ file should contain an array under the key "soft_labels" of shape (num_images, num_classes).

   Args:
       soft_labels_path: Path of the .npz archive to read.

   Returns:
       The array stored under the key "soft_labels".

   Raises:
       KeyError: If the archive has no "soft_labels" entry.
   """
   # np.load on an .npz returns a lazy archive object that keeps the file open;
   # using it as a context manager closes the file. The array is read into
   # memory on access, so the returned value stays valid afterwards.
   with np.load(soft_labels_path) as data:
       return data["soft_labels"]


# ------------------------------
# TINYVIT STUDENT MODEL ARCHITECTURE
# ------------------------------
def build_tinyvit_student(input_shape, num_classes):
   """
   Builds a TinyViT student model with self-attention layers.

   The image is cut into non-overlapping 16x16 patches by a strided
   convolution, the resulting token sequence goes through 2 encoder blocks
   (4-head self-attention followed by a 128-unit GELU MLP), and the
   mean-pooled tokens feed a softmax classification head. No positional
   embedding and no dropout are applied.

   Args:
       input_shape: (height, width, channels) of the input images. Height and
           width should be at least the patch size (16); with "valid" padding
           any remainder pixels at the border are dropped.
       num_classes: Number of output classes of the softmax head.

   Returns:
       An uncompiled Keras Model mapping images to class probabilities.
   """
   def tiny_vit_encoder(image_shape, patch_size, num_layers, embed_dim, num_heads, mlp_units):
       # Honor the caller-supplied shape instead of a hard-coded 224x224x3.
       inputs = Input(shape=tuple(image_shape))
       # Patch embedding: kernel == stride, so every patch is projected once.
       patch_embedding = tf.keras.layers.Conv2D(filters=embed_dim, kernel_size=patch_size,
                                                strides=patch_size, padding="valid")(inputs)
       # Flatten the patch grid into a (num_patches, embed_dim) token sequence.
       patches = tf.keras.layers.Reshape((-1, embed_dim))(patch_embedding)

       x = patches
       for _ in range(num_layers):
           x1 = tf.keras.layers.LayerNormalization()(x)
           attention_output = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)(x1, x1)
           # NOTE(review): the skip connection adds the normalized tokens (x1)
           # rather than the block input x used by the standard pre-norm
           # transformer. Left unchanged so the architecture stays the same;
           # confirm whether this is intended.
           x2 = tf.keras.layers.Add()([x1, attention_output])
           x3 = tf.keras.layers.LayerNormalization()(x2)
           feed_forward = tf.keras.layers.Dense(mlp_units, activation="gelu")(x3)
           feed_forward = tf.keras.layers.Dense(embed_dim)(feed_forward)
           x = tf.keras.layers.Add()([x2, feed_forward])

       x = tf.keras.layers.LayerNormalization()(x)
       # Average over the token axis to get one feature vector per image.
       x = tf.keras.layers.GlobalAveragePooling1D()(x)
       return inputs, x

   inputs, enc_outputs = tiny_vit_encoder(input_shape, 16, 2, 64, 4, 128)
   outputs = Dense(num_classes, activation="softmax")(enc_outputs)
   return Model(inputs, outputs)


# ------------------------------
# TRAINING FUNCTION WITH EARLY STOPPING
# ------------------------------
def train_student(student_model, X_train, y_train, teacher_soft_train,
                 X_val, y_val, teacher_soft_val,
                 batch_size=16, epochs=45, alpha=0.25, temperature=3.0):
   """
   Compiles the student and fits it on one-hot hard labels with early stopping.

   NOTE: teacher_soft_train, teacher_soft_val, alpha and temperature are
   accepted but never used in this function. The loss is plain categorical
   cross-entropy on the hard labels, so no distillation term is applied.

   Args:
       student_model: Keras model to train; its last output dimension is taken
           as the number of classes.
       X_train, y_train: Training images and integer class labels.
       teacher_soft_train: Teacher soft labels for the training set (unused).
       X_val, y_val: Validation images and integer class labels.
       teacher_soft_val: Teacher soft labels for the validation set (unused).
       batch_size: Mini-batch size passed to fit().
       epochs: Maximum number of epochs passed to fit().
       alpha: Distillation weight (unused).
       temperature: Distillation temperature (unused).

   Returns:
       (student_model, history) where history is the dict of per-epoch metrics
       recorded by fit().
   """
   n_outputs = student_model.output_shape[-1]

   def one_hot(labels):
       # Integer class ids -> one-hot rows matching the softmax width.
       return tf.keras.utils.to_categorical(labels, num_classes=n_outputs)

   train_targets = one_hot(y_train)
   val_targets = one_hot(y_val)

   # Adam at lr=1e-3 with cross-entropy on the one-hot targets.
   student_model.compile(
       optimizer=Adam(learning_rate=0.001),
       loss="categorical_crossentropy",
       metrics=["accuracy"],
   )

   # Stop once val_loss has not improved for 5 epochs and roll back to the
   # best weights seen.
   stopper = EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True)

   print("🚀 Training started with Early Stopping (patience=5)...")

   fit_result = student_model.fit(
       X_train,
       train_targets,
       validation_data=(X_val, val_targets),
       epochs=epochs,
       batch_size=batch_size,
       callbacks=[stopper],
   )

   return student_model, fit_result.history




# ------------------------------
# PLOTTING FUNCTIONS
# ------------------------------
def plot_learning_curve(history):
   """
   Draws side-by-side loss and accuracy curves for training and validation.

   Args:
       history: Dict of per-epoch metric lists with the keys "loss",
           "val_loss", "accuracy" and "val_accuracy".
   """
   epoch_axis = range(1, len(history["loss"]) + 1)

   # (train key, validation key, y-axis name) for the left and right panels.
   panels = (
       ("loss", "val_loss", "Loss"),
       ("accuracy", "val_accuracy", "Accuracy"),
   )

   plt.figure(figsize=(12, 5))
   for position, (train_key, val_key, metric_name) in enumerate(panels, start=1):
       plt.subplot(1, 2, position)
       plt.plot(epoch_axis, history[train_key], label=f"Train {metric_name}")
       plt.plot(epoch_axis, history[val_key], label=f"Validation {metric_name}")
       plt.xlabel("Epoch")
       plt.ylabel(metric_name)
       plt.title(f"{metric_name} over Epochs")
       plt.legend()
       plt.grid(True)

   plt.show()


def plot_roc_curve(y_true, y_pred_proba, num_classes):
   """
   Plots one-vs-rest ROC curves, one per class, on a single figure.

   Args:
       y_true: 2-D array whose column i holds the binary indicator of class i.
       y_pred_proba: 2-D array whose column i holds the predicted score of
           class i.
       num_classes: Number of leading columns (classes) to plot.
   """
   plt.figure(figsize=(8, 6))

   for class_idx in range(num_classes):
       class_truth = y_true[:, class_idx]
       class_scores = y_pred_proba[:, class_idx]
       false_pos_rate, true_pos_rate, _ = roc_curve(class_truth, class_scores)
       area = auc(false_pos_rate, true_pos_rate)
       plt.plot(false_pos_rate, true_pos_rate,
                label=f"Class {class_idx} (AUC = {area:.2f})")

   # Diagonal reference line: the ROC of a random classifier.
   plt.plot([0, 1], [0, 1], "k--")
   plt.xlabel("False Positive Rate")
   plt.ylabel("True Positive Rate")
   plt.title("Multi-Class ROC Curve")
   plt.legend()
   plt.grid(True)
   plt.show()


# ------------------------------
# MAIN TRAINING SCRIPT
# ------------------------------
if __name__ == "__main__":
   # Pipeline: load data -> split -> build student -> train -> save -> plots.
   labeled_data_path = "/Users/morgan/Desktop/BT-ClassificationDissCode/LabelledFigshare"
   soft_labels_path = "soft_labels.npz"
   input_shape = (224, 224, 3)
   num_classes = 3

   # The loader expects (width, height) while input_shape is (height, width, channels).
   X, y = load_images_and_labels(labeled_data_path, image_size=(input_shape[1], input_shape[0]))
   print(f"Loaded {X.shape[0]} images with labels: {np.unique(y)}")

   teacher_soft = load_teacher_soft_labels(soft_labels_path)
   # Soft labels are matched to images by row index only, so a count mismatch
   # (e.g. images skipped while loading) would silently pair wrong rows.
   if len(teacher_soft) != len(X):
       raise ValueError(
           f"Teacher soft labels ({len(teacher_soft)} rows) do not match "
           f"the number of loaded images ({len(X)})."
       )

   # Split images, labels and soft labels in ONE call so all three share the
   # exact same train/validation partition.
   (X_train, X_val,
    y_train, y_val,
    teacher_soft_train, teacher_soft_val) = train_test_split(
       X, y, teacher_soft, test_size=0.2, random_state=42
   )

   student_model = build_tinyvit_student(input_shape, num_classes)
   student_model.summary()

   trained_student, history = train_student(student_model, X_train, y_train, teacher_soft_train,
                                            X_val, y_val, teacher_soft_val,
                                            batch_size=16, epochs=45, alpha=0.25, temperature=3.0)

   trained_student.save("tinyvit_student_model_best.h5")
   print("✅ Model saved as 'tinyvit_student_model_best.h5'")

   plot_learning_curve(history)

   # ROC curves need one-hot ground truth and per-class scores on the validation set.
   y_val_cat = tf.keras.utils.to_categorical(y_val, num_classes=num_classes)
   y_pred_proba = trained_student.predict(X_val)

   plot_roc_curve(y_val_cat, y_pred_proba, num_classes)