In [None]:
# Install required libraries
!pip install keras scikit-learn seaborn matplotlib tensorflow

# Core libraries
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Computer vision
import cv2
from PIL import Image

# Machine learning
import keras
from keras.applications import ResNet50
from keras.layers import Dense, GlobalAveragePooling2D, Dropout
from keras.models import Model
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
# Use the ImageDataGenerator from tf.keras
# Correct the import statement from tf.keras to tensorflow.keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score, matthews_corrcoef

In [None]:
# Mount Google Drive (Colab only) so the dataset paths used below resolve.
import google.colab.drive as drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Directories for dataset.
# All three splits live under one experiment root. Set the EXPERIMENT_DIR
# environment variable to point at another experiment or a local copy; the
# default is the original Google Drive location.
EXPERIMENT_DIR = os.environ.get(
    "EXPERIMENT_DIR",
    "/content/drive/MyDrive/MIT URTC 2025/Experiments/experiment_F",
)
train_dir = os.path.join(EXPERIMENT_DIR, "train")
validation_dir = os.path.join(EXPERIMENT_DIR, "val")
test_dir = os.path.join(EXPERIMENT_DIR, "test")

In [None]:
# Data preprocessing parameters
image_size = (224, 224)
batch_size = 16
num_classes = 3
seed = 42  # fixes the training shuffle order and random augmentations

# Data augmentation and preprocessing (training split only)
train_datagen = ImageDataGenerator(
    rescale=1./255,
    horizontal_flip=True,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=10,
    zoom_range=0.2,
    brightness_range=[0.8, 1.2],
    fill_mode='nearest'
)

# Validation/test images are only rescaled to [0, 1], like the training input.
test_datagen = ImageDataGenerator(rescale=1./255)

# Load datasets
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
    seed=seed
)

# Without `classes=`, flow_from_directory infers the class -> index mapping
# separately for every directory (alphanumeric order of its sub-folders), so a
# missing or extra folder in val/test would silently shift the labels.
# Pin validation and test to the training order; with `classes=` given, only
# those sub-folders are read.
class_order = sorted(train_generator.class_indices, key=train_generator.class_indices.get)

val_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    classes=class_order,
    shuffle=False
)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    classes=class_order,
    shuffle=False
)

print(f"Training samples: {train_generator.samples}")
print(f"Validation samples: {val_generator.samples}")
print(f"Test samples: {test_generator.samples}")
print(f"Number of classes: {train_generator.num_classes}")
print(f"Classes: {list(train_generator.class_indices.keys())}")

# The model head is sized from num_classes; fail fast if the folders disagree.
if train_generator.num_classes != num_classes:
    raise ValueError(
        f"num_classes is {num_classes} but {train_dir} contains "
        f"{train_generator.num_classes} class folders: {class_order}"
    )

In [None]:
# Build ResNet50 model
# Import necessary components from Keras/TensorFlow
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.optimizers import Adam

def create_resnet50_model(input_shape, num_classes):
    """Build a ResNet50 image classifier trained from scratch.

    The backbone and the softmax head are wired as ONE functional model
    instead of nesting the backbone inside a Sequential. With nesting,
    model.get_layer() cannot see the backbone's internal layers
    (e.g. 'conv5_block3_out'), which the Grad-CAM cells later in this
    notebook look up by name. The layers and computation are the same either
    way.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: Number of output classes (units of the softmax layer).

    Returns:
        An uncompiled Keras Model mapping images to class probabilities.
    """
    # ResNet50 without its classification top; weights=None means random
    # initialisation (training from scratch). pooling='avg' appends global
    # average pooling, so the backbone output is a flat feature vector.
    base_model = ResNet50(
        include_top=False,
        weights=None,
        input_shape=input_shape,
        pooling='avg'
    )

    # Softmax classification head attached directly to the pooled features.
    outputs = layers.Dense(num_classes, activation='softmax', name='predictions')(base_model.output)

    return models.Model(inputs=base_model.input, outputs=outputs)


# Create the model
input_shape = (*image_size, 3)  # RGB images
model = create_resnet50_model(input_shape, num_classes)

# Compile the model
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

print("Model Summary:")
model.summary()

In [None]:
from tensorflow.keras import callbacks

# Callbacks
# Stop after 5 epochs without val_loss improvement and roll back to the best
# weights seen so far.
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1
)

# Halve the learning rate after 3 stagnant epochs, down to 1e-7.
reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
    verbose=1
)

# Persist the model with the lowest val_loss; reloaded after training.
checkpoint = callbacks.ModelCheckpoint(
    'best_model.h5',
    monitor='val_loss',
    save_best_only=True,
    verbose=1
)


class MetricsCallback(callbacks.Callback):
    """Collect per-epoch loss/accuracy and print a one-line summary per epoch.

    The recorded lists are plotted after training.
    """

    def __init__(self):
        # Initialise the base Callback so the attributes Keras manages on it
        # exist before fit() attaches the callback.
        super().__init__()
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []

    @staticmethod
    def _format_metric(value):
        """Format a metric with 4 decimals, or 'n/a' when it was not reported."""
        return 'n/a' if value is None else f"{value:.4f}"

    def on_epoch_end(self, epoch, logs=None):
        # `logs` defaults to None in the signature; treat that as "no metrics"
        # instead of crashing on .get().
        logs = logs or {}
        self.train_losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.train_accuracies.append(logs.get('accuracy'))
        self.val_accuracies.append(logs.get('val_accuracy'))

        print(
            f"Epoch {epoch+1} | Train Loss: {self._format_metric(logs.get('loss'))} | "
            f"Val Loss: {self._format_metric(logs.get('val_loss'))} | "
            f"Train Acc: {self._format_metric(logs.get('accuracy'))} | "
            f"Val Acc: {self._format_metric(logs.get('val_accuracy'))}"
        )


metrics_callback = MetricsCallback()

In [None]:
# Training: at most `epochs` passes over the training data. EarlyStopping
# (patience 5 on val_loss) normally ends the run sooner. Keras' own progress
# output is silenced (verbose=0); MetricsCallback prints one line per epoch.
epochs = 100

print("Starting training...")
history = model.fit(
    x=train_generator,
    validation_data=val_generator,
    epochs=epochs,
    callbacks=[
        early_stopping,
        reduce_lr,
        checkpoint,
        metrics_callback,
    ],
    verbose=0,
)
print("Training Completed")

In [None]:
# Reload the lowest-val_loss weights written by the ModelCheckpoint callback
# (reusing its filepath so the two can never drift apart).
model.load_weights(checkpoint.filepath)

In [None]:
# Testing phase
print("Evaluating on test set...")
test_generator.reset()  # start from the first batch

# Exactly one pass over the test set: len(generator) == ceil(samples / batch_size).
# (`samples // batch_size + 1` asks for one batch too many whenever samples is
# a multiple of batch_size, and next() on a Keras iterator never raises
# StopIteration - it silently wraps around to the first batch again.)
test_steps = len(test_generator)
predictions = model.predict(test_generator, steps=test_steps, verbose=1)

# The test generator uses shuffle=False, so `classes` lists the integer label
# of every file in the same order the predictions were produced.
true_labels = np.asarray(test_generator.classes)
if len(predictions) != len(true_labels):
    raise ValueError(
        f"Got {len(predictions)} predictions for {len(true_labels)} test samples"
    )
predicted_labels = np.argmax(predictions, axis=1)

# Calculate metrics (support-weighted averages over classes)
accuracy = accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels, average='weighted')
recall = recall_score(true_labels, predicted_labels, average='weighted')
f1 = f1_score(true_labels, predicted_labels, average='weighted')

# ROC AUC needs the probability scores; y_true is passed one-hot encoded.
true_labels_categorical = keras.utils.to_categorical(true_labels, num_classes)
roc_auc = roc_auc_score(true_labels_categorical, predictions, multi_class='ovr')
mcc = matthews_corrcoef(true_labels, predicted_labels)

# Class names in label-index order (class_indices maps name -> index)
class_names = list(train_generator.class_indices.keys())

In [None]:
# Pass `labels` explicitly so the report always has one row per class name,
# even when a class never occurs in the test labels or predictions (without it
# sklearn raises because target_names has more entries than observed labels).
print(classification_report(
    true_labels,
    predicted_labels,
    labels=list(range(len(class_names))),
    target_names=class_names,
    zero_division=0,
))
print(f"Test Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}, MCC: {mcc:.4f}")

In [None]:
# Confusion matrix
# `labels` fixes the matrix to len(class_names) x len(class_names), so the
# tick labels stay aligned even if a class is absent from the test results.
conf_matrix = confusion_matrix(
    true_labels, predicted_labels, labels=list(range(len(class_names)))
)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, cmap='Blues', fmt='d',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_title('Confusion Matrix')
plt.show()

In [None]:
# Learning curves from the per-epoch values recorded by MetricsCallback:
# loss on the left, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(metrics_callback.train_losses, label='Training Loss')
ax1.plot(metrics_callback.val_losses, label='Validation Loss')
ax1.set(title='Model Loss', xlabel='Epoch', ylabel='Loss')
ax1.legend()

ax2.plot(metrics_callback.train_accuracies, label='Training Accuracy')
ax2.plot(metrics_callback.val_accuracies, label='Validation Accuracy')
ax2.set(title='Model Accuracy', xlabel='Epoch', ylabel='Accuracy')
ax2.legend()

fig.tight_layout()
plt.show()

In [None]:
# Keras models have no HuggingFace-style `config` attribute (model.config
# raises AttributeError). Report the equivalent information instead: the width
# of the final softmax Dense layer and the index -> class-name mapping used by
# the training generator.
num_labels = model.layers[-1].units
id2label = {index: name for name, index in train_generator.class_indices.items()}
print(f"Model config num_labels: {num_labels}")
print(f"Model id2label: {id2label}")

In [None]:
# file: ipython-input-15-559582950

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
# torchvision.transforms is still useful for initial image loading/resizing
import torchvision.transforms as transforms
import os

class TFGradCAM:
    """
    Grad-CAM implementation for TensorFlow Keras models.

    Works with functional models and with Sequential models that wrap a
    backbone sub-model (e.g. Sequential([ResNet50(...), Dense(...)])). For the
    latter, `model.get_layer` cannot see the backbone's internal layers, so the
    layer lookup and the default-layer search also descend one level into
    nested sub-models.

    Attributes:
        model: The Keras model being explained.
        target_layer: Layer whose activations are weighted, or None.
        target_layer_name: Name of that layer (the auto-detected name when no
            layer_name was given).
        target_layer_found: True once a usable gradient model has been built.
    """

    def __init__(self, model, layer_name=None):
        self.model = model
        self.target_layer = None
        self.target_layer_name = layer_name
        self.target_layer_found = False
        # Index in self.model.layers of the sub-model that owns the target
        # layer; None when the target layer is a direct child of self.model.
        self._container_index = None
        # Keras Model mapping (sub-)model inputs to [target activations, outputs].
        self._grad_model = None

        if layer_name is None:
            # Default: the last layer whose class name contains 'conv'.
            layer_name = self._find_last_conv_layer_name()
            if layer_name is None:
                print("ERROR: Could not find any convolutional layer for Grad-CAM. Please specify a layer_name.")
                return
            print(f"Using last convolutional layer found for Grad-CAM: {layer_name}")
            self.target_layer_name = layer_name

        try:
            self._container_index, self.target_layer = self._locate_layer(layer_name)
        except ValueError:
            print(f"ERROR: Could not find layer '{layer_name}' in the model.")
            return

        try:
            self._grad_model = self._build_grad_model()
        except (ValueError, AttributeError) as err:
            print(f"ERROR: Could not build a Grad-CAM model for layer '{layer_name}': {err}")
            return

        self.target_layer_found = True
        print(f"Using specified layer for Grad-CAM: {layer_name}")

    def _find_last_conv_layer_name(self):
        """Return the name of the last 'conv' layer (nested sub-models included), or None."""
        for layer in reversed(self.model.layers):
            if isinstance(layer, tf.keras.Model):
                for inner in reversed(layer.layers):
                    if 'conv' in inner.__class__.__name__.lower():
                        return inner.name
            elif 'conv' in layer.__class__.__name__.lower():
                return layer.name
        return None

    def _locate_layer(self, layer_name):
        """Return (container_index, layer) for `layer_name`.

        container_index is None for a direct child of self.model, otherwise
        the index of the nested sub-model that contains the layer.
        Raises ValueError when no such layer exists.
        """
        try:
            return None, self.model.get_layer(layer_name)
        except ValueError:
            pass
        for index, layer in enumerate(self.model.layers):
            if isinstance(layer, tf.keras.Model):
                try:
                    return index, layer.get_layer(layer_name)
                except ValueError:
                    continue
        raise ValueError(f"No layer named '{layer_name}'")

    def _build_grad_model(self):
        """Build the Keras model that returns [target activations, outputs]."""
        if self._container_index is None:
            return tf.keras.models.Model(
                inputs=self.model.inputs,
                outputs=[self.target_layer.output, self.model.output],
            )
        # A nested layer's output belongs to the sub-model's graph, not to the
        # outer model's, so build the gradient model on the sub-model and run
        # the outer layers around it in _forward(). That chaining is only valid
        # for a linear (Sequential) outer model.
        if not isinstance(self.model, tf.keras.Sequential):
            raise ValueError("layers inside nested sub-models are only supported for Sequential models")
        container = self.model.layers[self._container_index]
        return tf.keras.models.Model(
            inputs=container.inputs,
            outputs=[self.target_layer.output, container.output],
        )

    def _forward(self, image_tensor):
        """Run the model, returning (target-layer activations, final predictions)."""
        if self._container_index is None:
            activations, predictions = self._grad_model(image_tensor, training=False)
            return activations, predictions
        x = image_tensor
        for layer in self.model.layers[:self._container_index]:
            x = layer(x)
        activations, x = self._grad_model(x, training=False)
        for layer in self.model.layers[self._container_index + 1:]:
            x = layer(x)
        return activations, x

    def __call__(self, image_tensor, class_idx=None, relu_weights=True, return_attention_map=False):
        """
        Generate a Grad-CAM heatmap for a TensorFlow Keras model.

        Args:
            image_tensor: Input image tensor [1, H, W, C] or [H, W, C]
                          (a numpy array is also accepted).
            class_idx: Class index for which to generate Grad-CAM.
                       If None, use the predicted class.
            relu_weights: Whether to apply ReLU to the channel weights.
            return_attention_map: If True, return only the attention map.

        Returns:
            (overlay, heatmap, original_image) as numpy arrays, where heatmap is
            [H, W] in [0, 1] and overlay/original_image are [H, W, 3];
            (heatmap, None, None) if return_attention_map=True;
            (None, None, None) on failure.
        """
        if not self.target_layer_found or self.target_layer is None:
            print("ERROR: Grad-CAM target layer not found or initialized correctly.")
            return None, None, None

        # tape.watch() needs a float tensor, not a numpy array.
        image_tensor = tf.convert_to_tensor(image_tensor, dtype=tf.float32)
        if len(image_tensor.shape) == 3:
            image_tensor = tf.expand_dims(image_tensor, axis=0)

        with tf.GradientTape() as tape:
            tape.watch(image_tensor)
            activations, predictions = self._forward(image_tensor)

            # If class_idx is None, explain the predicted class
            if class_idx is None:
                class_idx = int(tf.argmax(predictions[0]))

            target_class_score = predictions[:, class_idx]

        # Gradient of the class score w.r.t. the target layer's activations
        grads = tape.gradient(target_class_score, activations)

        if grads is None:
            print(f"ERROR: Gradients could not be computed for layer '{self.target_layer.name}'.")
            print("This might happen if the layer is not connected to the output or is not trainable.")
            print("Consider trying a different layer.")
            return None, None, None

        # Channel importance = gradients averaged over the spatial axes
        # (assumes channels-last activations [1, H', W', C']).
        weights = tf.reduce_mean(grads, axis=(1, 2), keepdims=True)

        if relu_weights:
            weights = tf.nn.relu(weights)

        # Weighted sum over channels -> CAM of shape [1, H', W', 1]
        cam = tf.reduce_sum(weights * activations, axis=-1, keepdims=True)
        cam = tf.nn.relu(cam)

        # Min-max normalise to [0, 1]; a constant CAM would divide by zero.
        cam_min = tf.reduce_min(cam)
        cam_max = tf.reduce_max(cam)
        if tf.equal(cam_max, cam_min):
            print("WARNING: CAM has uniform values - all pixels have the same importance.")
            normalized_cam = tf.zeros_like(cam)
        else:
            normalized_cam = (cam - cam_min) / (cam_max - cam_min)

        # Resize while the CAM is still 4-D [1, H', W', 1]: tf.image.resize
        # only accepts 3-D or 4-D input, so resizing a squeezed [H', W'] map
        # raises. Squeeze to [H, W] afterwards.
        target_height, target_width = image_tensor.shape[1:3]
        cam_resized = tf.image.resize(normalized_cam, [target_height, target_width], method='bilinear')
        cam_np = tf.squeeze(cam_resized, axis=[0, -1]).numpy()

        if return_attention_map:
            return cam_np, None, None

        # Assumes the input image is in the [0, 1] range used for training.
        orig_img = tf.squeeze(image_tensor, axis=0).numpy()  # [H, W, C]

        # Grayscale -> 3 channels so it can be blended with the RGB heatmap
        if orig_img.shape[-1] == 1:
            orig_img = np.repeat(orig_img, 3, axis=-1)

        # Colour the heatmap (cv2 expects uint8 and returns BGR)
        heatmap = cv2.applyColorMap(np.uint8(255 * cam_np), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        heatmap = heatmap.astype('float32') / 255.0

        # 40% heatmap + 60% image, clipped back into [0, 1]
        overlay = heatmap * 0.4 + orig_img * 0.6
        overlay = np.clip(overlay, 0, 1)

        return overlay, cam_np, orig_img


def visualize_grad_cam_for_brain_mri_tf(model, image_tensor, class_names=None, layer_name=None):
    """
    Plot the original image, its Grad-CAM heatmap and the overlay side by side.

    Args:
        model: The trained TensorFlow Keras model.
        image_tensor: Pre-processed image tensor [1, H, W, C] (a 3-D [H, W, C]
                      tensor is also accepted), scaled to the range the model
                      was trained on (e.g. 0-1).
        class_names: Optional list of class names; defaults to "Class i".
        layer_name: Layer to explain. If None, TFGradCAM chooses a default.

    Returns:
        dict with 'predicted_class', 'class_name', 'confidence' and 'heatmap',
        or None when no Grad-CAM could be produced.
    """
    grad_cam = TFGradCAM(model, layer_name=layer_name)
    if not grad_cam.target_layer_found:
        print("Skipping visualization as target layer was not found.")
        return

    overlay, heatmap, orig_img = grad_cam(image_tensor)
    if overlay is None:
        print("Failed to generate Grad-CAM visualization.")
        return

    # model.predict needs a leading batch axis.
    batched_input = tf.expand_dims(image_tensor, axis=0) if len(image_tensor.shape) == 3 else image_tensor
    class_probs = model.predict(batched_input)[0]
    predicted_class = np.argmax(class_probs)
    prediction_confidence = np.max(class_probs)

    # Generic names when none are supplied
    if class_names is None:
        class_names = [f"Class {i}" for i in range(len(class_probs))]

    fig, (ax_orig, ax_heat, ax_overlay) = plt.subplots(1, 3, figsize=(18, 6))

    ax_orig.imshow(orig_img, cmap='gray' if orig_img.shape[-1] == 1 else None)
    ax_orig.set_title('Original Image')
    ax_orig.axis('off')

    heat_image = ax_heat.imshow(heatmap, cmap='jet')
    ax_heat.set_title('Grad-CAM Heatmap')
    ax_heat.axis('off')
    colorbar = plt.colorbar(heat_image, ax=ax_heat, fraction=0.046, pad=0.04)
    colorbar.set_label('Importance')

    ax_overlay.imshow(overlay)
    ax_overlay.set_title('Grad-CAM Overlay')
    ax_overlay.axis('off')

    class_name = class_names[predicted_class]
    fig.suptitle(f'Class: {class_name} | Confidence: {prediction_confidence:.2f}', fontsize=16)

    plt.tight_layout()
    plt.show()

    return {
        'predicted_class': predicted_class,
        'class_name': class_name,
        'confidence': prediction_confidence,
        'heatmap': heatmap,
    }


def _conv_layer_names(model):
    """Return the names of all layers whose class name contains 'conv', in model order.

    Nested sub-models (e.g. a ResNet50 backbone wrapped in a Sequential) are
    searched one level deep, because model.layers lists only the container,
    not the layers inside it.
    """
    names = []
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):
            names.extend(
                inner.name for inner in layer.layers
                if 'conv' in inner.__class__.__name__.lower()
            )
        elif 'conv' in layer.__class__.__name__.lower():
            names.append(layer.name)
    return names


# Example usage with multiple layers for TensorFlow
def compare_layers_gradcam_tf(model, image_tensor, class_names=None, layers_to_test=None):
    """
    Compare Grad-CAM visualizations from different layers for a TensorFlow model.

    Args:
        model: The trained TensorFlow Keras model.
        image_tensor: Pre-processed image tensor [1, H, W, C] or [H, W, C], range [0, 1].
        class_names: List of class names; defaults to "Class i".
        layers_to_test: List of layer names (strings) to visualize. If None,
                        the first, ~1/3, ~2/3 and last convolutional layers are
                        used (duplicates removed).
    """
    # Every use below (predict, Grad-CAM, original-image display) needs a batch axis.
    if len(image_tensor.shape) == 3:
        image_tensor_for_pred = tf.expand_dims(image_tensor, axis=0)
    else:
        image_tensor_for_pred = image_tensor

    if layers_to_test is None:
        conv_layers = _conv_layer_names(model)
        if not conv_layers:
            print("No convolutional layers found to test.")
            return
        n_conv = len(conv_layers)
        picks = [
            conv_layers[0],
            conv_layers[n_conv // 3],
            conv_layers[2 * n_conv // 3],
            conv_layers[-1],
        ]
        # dict.fromkeys drops repeats (they occur with only one or two conv
        # layers) while keeping the first-seen order.
        layers_to_test = list(dict.fromkeys(picks))

    if not layers_to_test:
        print("No layers to test.")
        return

    # Get the prediction once
    predictions = model.predict(image_tensor_for_pred)
    predicted_class = np.argmax(predictions[0])
    prediction_confidence = np.max(predictions[0])

    if class_names is None:
        num_output_classes = predictions.shape[1]
        class_names = [f"Class {i}" for i in range(num_output_classes)]

    class_name = class_names[predicted_class]

    # First row: original image; one further row per layer.
    # squeeze=False keeps `axes` 2-D regardless of the number of rows.
    n_rows = len(layers_to_test) + 1
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

    # Squeeze the *batched* tensor: squeezing axis 0 of a 3-D [H, W, C] input
    # fails because that axis is the height, not a batch axis of size 1.
    orig_img = tf.squeeze(image_tensor_for_pred, axis=0).numpy()
    if orig_img.shape[-1] == 1:
        orig_img = np.repeat(orig_img, 3, axis=-1)

    axes[0, 0].imshow(orig_img)
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')

    # Hide the other two subplots in the first row
    axes[0, 1].set_visible(False)
    axes[0, 2].set_visible(False)

    fig.suptitle(f'Class: {class_name} | Confidence: {prediction_confidence:.2f}', fontsize=16)

    for row, layer_name in enumerate(layers_to_test, start=1):
        print(f"\nTesting layer: {layer_name}")
        grad_cam = TFGradCAM(model, layer_name=layer_name)

        overlay, heatmap, _ = grad_cam(image_tensor_for_pred)

        if overlay is None:
            print(f"Failed to generate Grad-CAM visualization for layer: {layer_name}")
            axes[row, 0].text(0.5, 0.5, f"Failed: {layer_name}",
                              horizontalalignment='center', verticalalignment='center', fontsize=10, color='red')
            axes[row, 0].axis('off')
            axes[row, 1].axis('off')
            axes[row, 2].axis('off')
            continue

        # Layer name as text in the first column of the row
        axes[row, 0].text(0.5, 0.5, f"Layer: {layer_name}",
                          horizontalalignment='center', verticalalignment='center', fontsize=10)
        axes[row, 0].axis('off')

        axes[row, 1].imshow(heatmap, cmap='jet')
        axes[row, 1].set_title('Heatmap')
        axes[row, 1].axis('off')

        axes[row, 2].imshow(overlay)
        axes[row, 2].set_title('Overlay')
        axes[row, 2].axis('off')

    plt.tight_layout()
    plt.subplots_adjust(top=0.95)  # leave room for the suptitle
    plt.show()


# Example usage
def run_gradcam_on_image_tf(model, image_path, transform, class_names=None, specific_layer=None):
    """
    Run Grad-CAM on a single image file using a TensorFlow Keras model.

    The image is loaded as RGB and passed through `transform`. The result is
    converted from [C, H, W] to a TensorFlow batch [1, H, W, C]. A single-layer
    Grad-CAM is shown first, then a multi-layer comparison.

    Args:
        model: The trained TensorFlow Keras model.
        image_path: Path to the image file.
        transform: Image transformation pipeline. Assumed to return a PyTorch
                   tensor [C, H, W] in the range the model was trained on
                   (e.g. torchvision ToTensor -> [0, 1]).
        class_names: List of class names.
        specific_layer: Optional layer name for the single Grad-CAM view.
                        If None, the default layer is used.

    Returns:
        The dict returned by visualize_grad_cam_for_brain_mri_tf, or None if
        that visualization failed.
    """
    # Open the file in a context manager (as Pillow recommends) so the handle
    # is released once the transform has consumed the pixels, even on error.
    with Image.open(image_path) as img:
        if img.mode != 'RGB':
            img = img.convert('RGB')
        input_tensor_torch = transform(img)

    # [C, H, W] (PyTorch layout) -> [H, W, C] -> [1, H, W, C], as float32.
    # No extra scaling: the transform is assumed to produce the model's range.
    chw = np.asarray(input_tensor_torch.numpy(), dtype=np.float32)
    input_tensor_tf = tf.transpose(tf.convert_to_tensor(chw), perm=[1, 2, 0])
    input_tensor_tf = tf.expand_dims(input_tensor_tf, axis=0)

    print("\n--- Single Image Grad-CAM ---")
    result = visualize_grad_cam_for_brain_mri_tf(
        model=model,
        image_tensor=input_tensor_tf,
        class_names=class_names,
        layer_name=specific_layer
    )

    if result:
        print(f"Predicted: {result['class_name']} with confidence {result['confidence']:.4f}")

    print("\n--- Comparing Layers Grad-CAM ---")
    compare_layers_gradcam_tf(
        model=model,
        image_tensor=input_tensor_tf,
        class_names=class_names
    )

    return result

In [None]:
# file: ipython-input-19-3073569956 (Modified)

# Example usage with your model structure

import torch
from PIL import Image
import torchvision.transforms as transforms
# Import necessary TensorFlow/Keras modules
import tensorflow as tf
from tensorflow.keras.models import Sequential
import numpy as np # Import numpy
# Make sure to import the updated TensorFlow-compatible Grad-CAM functions
# If they are in the same notebook cell, no extra import needed.
# If you put them in a separate file (e.g., tf_gradcam.py), import them.
# from tf_gradcam import run_gradcam_on_image_tf, visualize_existing_tensor_tf, compare_layers_gradcam_tf


# Inference transform. It must reproduce the training preprocessing:
# flow_from_directory resized images to 224x224 with its default 'nearest'
# interpolation and rescaled pixels to [0, 1] (rescale=1./255).
# transforms.Resize defaults to bilinear, so NEAREST is requested explicitly;
# ToTensor turns the PIL image into a float [C, H, W] tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

# Not used by the TensorFlow Grad-CAM functions; kept for any PyTorch cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class names must follow the model's label-index order. Prefer the mapping
# recorded by the training generator (alphanumeric folder order); fall back to
# the known folder names when this cell runs without the training cells.
if 'train_generator' in globals():
    class_names = sorted(train_generator.class_indices, key=train_generator.class_indices.get)
else:
    class_names = ['glioma', 'meningioma', 'pituitary_tumor']

# Alias for running Grad-CAM on a single image file
visualize_single_image = run_gradcam_on_image_tf

# Option 2: Apply to an existing TensorFlow image tensor (if you have one)
# The previous visualize_existing_tensor function was also PyTorch-centric.
# Let's create a new one or modify the existing logic if needed.
# The logic within run_gradcam_on_image_tf or compare_layers_gradcam_tf
# handles the tensor format [1, H, W, C] and scaling [0, 1].
def visualize_existing_tf_tensor(image_tensor_tf):
    """
    Run Grad-CAM on an already pre-processed TensorFlow image tensor.

    Shows a multi-layer comparison first, then a single visualization for a
    hand-picked layer. Uses the notebook-level `model` and `class_names`.

    Args:
        image_tensor_tf: Pre-processed image tensor [1, H, W, C], range [0, 1].
    """
    print("\n--- Visualizing Existing TensorFlow Tensor ---")

    # Side-by-side comparison to find the most informative layer
    compare_layers_gradcam_tf(model=model, image_tensor=image_tensor_tf, class_names=class_names)

    # Single view for a chosen layer. "conv5_block3_out" is an example ResNet50
    # layer name; look up valid names with model.summary().
    print("\n--- Visualizing with a specific layer (example) ---")
    best_layer_name = "conv5_block3_out"
    try:
        result = visualize_grad_cam_for_brain_mri_tf(
            model=model,
            image_tensor=image_tensor_tf,
            class_names=class_names,
            layer_name=best_layer_name,
        )
        if result:
            print(f"Predicted: {result['class_name']} with confidence {result['confidence']:.4f}")
    except ValueError as e:
        print(f"Could not visualize with layer '{best_layer_name}': {e}")
        print("Please check the layer name in model.summary()")



# Run Grad-CAM on one image file.
# NOTE(review): this path is under "Training_Dataset", not the experiment_F
# split used for training/testing above - confirm it is the intended image.
image_file_path = "/content/drive/MyDrive/MIT URTC 2025/Training_Dataset/test/glioma/1841.jpg"
if not os.path.exists(image_file_path):
    print(f"Error: Image file not found at {image_file_path}")
else:
    run_gradcam_on_image_tf(
        model=model,
        image_path=image_file_path,
        transform=transform,
        class_names=class_names,
    )

# To explain a tensor you already have ([1, 224, 224, 3], values in [0, 1]):
#     my_tf_image_tensor = tf.random.uniform(shape=[1, 224, 224, 3])  # dummy example
#     visualize_existing_tf_tensor(my_tf_image_tensor)