# In [None]:
import os  # Filesystem paths for checkpoints and figures
import time  # Monotonic / high-resolution clocks for timing callbacks
from collections import defaultdict  # Per-class pixel counters
from datetime import datetime  # Timestamped checkpoint filenames

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import f1_score  # F1-score utility
from tensorflow.keras.callbacks import Callback # Base class for custom Keras Callbacks

# --- Global Constants ---
# Shared across the project; keep these identical to the definitions used by the
# main training script (or import them from there).
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
NUM_CLASSES = len(CLASS_NAMES)  # One entry per name above

# Module-level, non-trainable per-class weight vector, initialised to all ones.
# DynamicClassWeightUpdater overwrites it in place via `assign`, so a loss function that
# reads this variable (presumably defined in the main pipeline) sees the new weights
# without having to be rebuilt.
class_weights = tf.Variable(
    tf.ones([NUM_CLASSES], dtype=tf.float32),
    trainable=False,
    dtype=tf.float32,
    name="dynamic_class_weights",
)


# --- Custom Callbacks ---

class LRScheduleLogger(tf.keras.callbacks.Callback):
    """Callback that prints the optimizer's learning rate at the beginning of each epoch."""

    def on_epoch_begin(self, epoch, logs=None):
        """
        Called at the beginning of each epoch. Resolves and prints the current learning rate.

        Handles both a constant learning rate (float / variable) and a
        `LearningRateSchedule`; a schedule is evaluated at the optimizer's current
        iteration count. Only public optimizer attributes are used, so this does not
        depend on private helpers such as `_decayed_lr`, which newer optimizers lack.

        Args:
            epoch (int): The current epoch number (0-indexed).
            logs (dict, optional): Dictionary of logs. Defaults to None.
        """
        optimizer = self.model.optimizer
        lr = optimizer.learning_rate
        # Some Keras versions hand back the schedule object itself rather than its
        # current value; evaluate it explicitly so every path ends with a number.
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            lr = lr(optimizer.iterations)
        # np.asarray accepts Python floats, eager tensors and variables alike.
        lr = float(np.asarray(lr))
        print(f"🔁 Epoch {epoch+1}: Learning Rate = {lr:.6e}")


class TimeLimitCallback(tf.keras.callbacks.Callback):
    """
    A Keras Callback that stops training once a time budget (in minutes) is exhausted.

    The check runs at the end of each epoch, so an epoch that is already running is
    always allowed to finish; the total run can therefore overshoot the budget by up to
    one epoch. This prevents runs from consuming excessive computational resources.
    """
    def __init__(self, max_minutes: float = 20):
        """
        Initializes the TimeLimitCallback.

        Args:
            max_minutes (float): The maximum number of minutes training is allowed to run.
        """
        super().__init__()
        self.max_duration = max_minutes * 60  # Budget converted to seconds
        self.start_time = None  # time.monotonic() reading taken in on_train_begin

    def on_train_begin(self, logs: dict = None):
        """
        Called at the beginning of training. Records the start time.

        A monotonic clock is used because it cannot jump with system-clock adjustments
        and returns a plain Python float (tf.timestamp() returns a tensor).
        """
        self.start_time = time.monotonic()
        print(f"⏰ Training started. Maximum duration set to {self.max_duration / 60:.1f} minutes.")

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """
        Called at the end of each epoch. Requests a stop if the elapsed time exceeds the budget.

        Args:
            epoch (int): The current epoch number (0-indexed).
            logs (dict, optional): Dictionary of logs. Defaults to None.
        """
        if self.start_time is None:  # on_train_begin has not run; nothing to measure against
            return

        elapsed = time.monotonic() - self.start_time  # Seconds since training started
        if elapsed > self.max_duration:
            print(f"\nTraining time exceeded {self.max_duration / 60:.1f} minutes ({elapsed:.2f} seconds). Stopping early at epoch {epoch + 1}.")
            self.model.stop_training = True  # Keras checks this flag after the epoch


class StepTimer(tf.keras.callbacks.Callback):
    """
    A Keras Callback to measure and report the average time taken per training step (batch).
    Provides insight into training throughput.

    Timing uses time.perf_counter(), which is a cheap host-side call. It avoids creating a
    tensor and converting it back to NumPy on every batch.
    """
    def __init__(self):
        super().__init__()
        # Initialise accumulators up front so on_train_end is safe even if training never began.
        self._reset()

    def _reset(self):
        """Zero the accumulated time, the step count and the pending batch start time."""
        self.total_time = 0.0
        self.total_steps = 0
        self.start_time_batch = None  # perf_counter() reading for the batch in flight

    def on_train_begin(self, logs: dict = None):
        """
        Called at the start of training. Resets the accumulators for a fresh run.
        """
        self._reset()

    def on_train_batch_begin(self, batch: int, logs: dict = None):
        """
        Called at the beginning of each training batch. Records the start time of the batch.

        Args:
            batch (int): The current batch index.
            logs (dict, optional): Dictionary of logs. Defaults to None.
        """
        self.start_time_batch = time.perf_counter()

    def on_train_batch_end(self, batch: int, logs: dict = None):
        """
        Called at the end of each training batch. Adds the batch's elapsed time to the total.

        Args:
            batch (int): The current batch index.
            logs (dict, optional): Dictionary of logs. Defaults to None.
        """
        if self.start_time_batch is None:  # No matching batch-begin was recorded
            return

        self.total_time += time.perf_counter() - self.start_time_batch
        self.total_steps += 1
        # Clear the start mark so a missed batch-begin can't be double-counted.
        self.start_time_batch = None

    def on_train_end(self, logs: dict = None):
        """
        Called at the end of training. Calculates and prints the average time per step.

        Args:
            logs (dict, optional): Dictionary of logs. Defaults to None.
        """
        if self.total_steps > 0:
            avg_step_time = self.total_time / self.total_steps
            print(f"\nAverage training step time over {self.total_steps} steps: {avg_step_time:.4f} sec")
        else:
            print("No training steps completed for average step time calculation.")


class DynamicClassWeightUpdater(tf.keras.callbacks.Callback):
    """
    A Keras Callback to dynamically update class weights during training based on
    per-class performance (F1-score or IoU) on the validation set.

    Weights are updated every `update_every` epochs, giving higher weights to classes
    that perform poorly (weight = 1 / (metric + eps), then scaled so the largest weight
    is ~1). It writes into the module-level `class_weights` tf.Variable via `assign`.

    Per-class metrics are computed from a streaming confusion matrix, so memory use does
    not grow with the size of the validation set.
    """
    def __init__(self, val_data: tf.data.Dataset, update_every: int = 5,
                 target: str = 'f1', ignore_class: int = None):
        """
        Initializes the DynamicClassWeightUpdater callback.

        Args:
            val_data (tf.data.Dataset): Dataset yielding (inputs, one-hot labels) batches.
            update_every (int): How often (in epochs) to update the class weights.
            target (str): The metric to target for weighting ('f1' or 'iou', case-insensitive).
                          Any other value yields uniform weights and a warning.
            ignore_class (int, optional): Class ID whose weight is forced to 0.0. Defaults to None.
        """
        super().__init__()
        self.val_data = val_data
        self.update_every = update_every
        self.target = target.lower()  # Normalise for robust comparison
        self.ignore_class = ignore_class
        self.num_classes = NUM_CLASSES  # Use global NUM_CLASSES for consistency

    def _confusion_matrix(self) -> np.ndarray:
        """Accumulate a (num_classes, num_classes) confusion matrix over `val_data`.

        Rows index the true class and columns index the predicted class. Pixels whose one-hot
        label is all zeros are skipped. The file's visualisation code treats such pixels as
        "ignore", and argmax would otherwise silently count them as class 0. Ids >= num_classes
        are also skipped so the bincount reshape is always valid.
        """
        n = self.num_classes
        confusion = np.zeros((n, n), dtype=np.int64)
        for x_batch, y_batch in self.val_data:
            # A direct inference-mode call avoids the per-call setup overhead of model.predict().
            preds = np.asarray(self.model(x_batch, training=False))
            labels = np.asarray(y_batch)

            y_true = labels.argmax(axis=-1).ravel()
            y_pred = preds.argmax(axis=-1).ravel()
            keep = np.any(labels != 0, axis=-1).ravel() & (y_true < n) & (y_pred < n)

            # Encode each (true, pred) pair as a single index and count occurrences.
            confusion += np.bincount(y_true[keep] * n + y_pred[keep],
                                     minlength=n * n).reshape(n, n)
        return confusion

    def _per_class_metric(self, confusion: np.ndarray) -> np.ndarray:
        """Return the per-class F1 or IoU from a confusion matrix (0.0 where undefined).

        F1  = 2TP / (2TP + FP + FN) = 2TP / (true_count + pred_count)
        IoU = TP / (TP + FP + FN)   = TP / (true_count + pred_count - TP)
        A zero denominator gives 0.0, which matches sklearn's f1_score(zero_division=0).
        """
        tp = np.diag(confusion).astype(np.float64)
        true_counts = confusion.sum(axis=1).astype(np.float64)
        pred_counts = confusion.sum(axis=0).astype(np.float64)
        if self.target == 'f1':
            numerator, denominator = 2.0 * tp, true_counts + pred_counts
        else:  # 'iou'
            numerator, denominator = tp, true_counts + pred_counts - tp
        metric = np.zeros(self.num_classes, dtype=np.float64)
        np.divide(numerator, denominator, out=metric, where=denominator > 0)
        return metric

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """
        Method called at the end of each epoch. Updates weights if `epoch + 1` is a multiple
        of `update_every`.

        Args:
            epoch (int): The current epoch number (0-indexed).
            logs (dict, optional): Dictionary of logs. Defaults to None.
        """
        # Only update weights if the current epoch is a multiple of `update_every`
        if (epoch + 1) % self.update_every != 0:
            return

        print(f"\n📊 Epoch {epoch+1}: Computing per-class metrics for dynamic weight update...")

        if self.target not in ('f1', 'iou'):
            # Unknown metric: fall back to uniform weights without the costly validation pass.
            print(f"Warning: Unknown target metric '{self.target}'. Using default weight 1.0.")
            metric_values = np.ones(self.num_classes, dtype=np.float64)
        else:
            confusion = self._confusion_matrix()
            if confusion.sum() == 0:
                print(f"Warning: No labelled validation pixels found at epoch {epoch+1}. "
                      "Keeping current class weights.")
                return
            metric_values = self._per_class_metric(confusion)

        eps = tf.keras.backend.epsilon()
        # Inverse-metric weighting. A class with metric 0 (including one that is neither
        # present nor predicted in the validation data) receives the largest weight.
        weights = 1.0 / (metric_values + eps)
        if self.ignore_class is not None and 0 <= self.ignore_class < self.num_classes:
            weights[self.ignore_class] = 0.0  # Explicitly exclude the ignored class

        # Scale so the largest weight is ~1; eps guards the all-zero case.
        new_weights = weights.astype(np.float32)
        new_weights = new_weights / (new_weights.max() + eps)

        # Update the module-level tf.Variable in place so the loss function sees the change.
        global class_weights
        class_weights.assign(new_weights)
        print(f"\nEpoch {epoch+1}: Dynamically updated class weights: {new_weights}\n")


class DualCheckpointSaver(Callback):
    """
    A Keras Callback that saves the model whenever the monitored metric improves.

    Each checkpoint goes to a local directory and, when the Drive location is reachable,
    to a second directory (by default Google Drive). A failure to write the Drive copy is
    reported but does not interrupt training; the local copy is always written first.
    """
    # Substrings marking a metric as "higher is better" when mode='auto'.
    _MAXIMIZE_HINTS = ('acc', 'iou', 'f1', 'auc', 'dice', 'precision', 'recall')

    def __init__(self, base_model: tf.keras.Model, monitor: str = 'val_iou_score', mode: str = 'max',
                 out_dir: str = "checkpoints", drive_dir: str = "/content/drive/MyDrive/checkpoints"):
        """
        Initializes the DualCheckpointSaver.

        Args:
            base_model (tf.keras.Model): The model to save. If your main model has a separate
                                        base_model (e.g., in a flexible U-Net), pass that.
                                        Otherwise, pass the main `model` instance.
            monitor (str): The metric to monitor for improvement (e.g., 'val_loss', 'val_iou_score').
            mode (str): One of {'auto', 'min', 'max'}. 'max' saves when the metric increases and
                        'min' saves when it decreases. 'auto' picks 'max' for names containing
                        acc/iou/f1/auc/dice/precision/recall and 'min' otherwise.
            out_dir (str): Local directory path to save checkpoints.
            drive_dir (str): Google Drive directory path to save checkpoints. Pass None or "" to
                             disable the second copy. It is also disabled automatically when
                             its parent directory does not exist (e.g. Drive not mounted).
        """
        super().__init__()
        self.base_model = base_model
        self.monitor = monitor
        self.mode = self._resolve_mode(mode, monitor)
        self.out_dir = out_dir
        self.drive_dir = drive_dir
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")  # Unique per run, used in filenames
        self.best_value = -float('inf') if self.mode == 'max' else float('inf')

        os.makedirs(self.out_dir, exist_ok=True)

        if not self.drive_dir:
            self.drive_dir = None  # Second copy explicitly disabled
        # normpath strips a trailing slash, so dirname() really yields the parent directory.
        elif os.path.exists(os.path.dirname(os.path.normpath(self.drive_dir))):
            os.makedirs(self.drive_dir, exist_ok=True)
        else:
            print(f"Warning: Google Drive path '{self.drive_dir}' not accessible. Skipping Drive checkpoints.")
            self.drive_dir = None  # Disable saving to Drive

    @classmethod
    def _resolve_mode(cls, mode: str, monitor: str) -> str:
        """Map 'auto' to 'max' or 'min' based on the metric name; return other modes unchanged."""
        if mode == 'auto':
            name = monitor.lower()
            return 'max' if any(hint in name for hint in cls._MAXIMIZE_HINTS) else 'min'
        return mode

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """
        Called at the end of each epoch. Saves the model if the monitored metric improved.

        Args:
            epoch (int): The current epoch number (0-indexed).
            logs (dict, optional): Dictionary of logs. Defaults to None.
        """
        if logs is None or self.monitor not in logs:
            print(f"Warning: Monitored metric '{self.monitor}' not found in logs for epoch {epoch+1}. Skipping checkpoint.")
            return

        current = float(logs[self.monitor])  # Plain float for comparison and formatting
        improved = (current > self.best_value) if self.mode == 'max' else (current < self.best_value)

        if not improved:
            print(f"⏭ Epoch {epoch+1}: {self.monitor} did not improve ({current:.4f}). Best: {self.best_value:.4f}")
            return

        self.best_value = current
        epoch_num = epoch + 1
        model_name = f"{self.base_model.name}_{self.timestamp}_epoch{epoch_num:03d}_{self.monitor}{current:.4f}.keras"

        # The `.keras` extension selects the native Keras archive format. The original
        # save_format='tf' contradicted that extension: depending on the Keras version it is
        # deprecated, or it writes a SavedModel directory that merely carries a `.keras` name.
        local_path = os.path.join(self.out_dir, model_name)
        self.base_model.save(local_path)
        print(f"Saved improved model locally: {local_path}")

        if self.drive_dir:
            drive_path = os.path.join(self.drive_dir, model_name)
            try:
                self.base_model.save(drive_path)
            except OSError as err:
                # Drive is only a backup; a disconnect must not kill a long training run.
                print(f"Warning: Could not save to Google Drive ({err}). Local checkpoint is still available.")
            else:
                print(f"Saved improved model to Google Drive: {drive_path}")

        print(f"Model improved at epoch {epoch_num} with {self.monitor}: {current:.4f}")





# --- Additional Callbacks for Dataset Analysis (these add significant overhead; use with caution) ---

class DistributionLogger(tf.keras.callbacks.Callback):
    """
    A Keras Callback to log and visualize the class distribution of training (or validation) data
    at the end of each epoch. It can help monitor dataset balance and augmentation effectiveness.
    This callback adds significant overhead if max_batches is large.

    Pixels whose one-hot label vector is all zeros are treated as unlabelled/ignored and
    are not counted. Without this, argmax would silently attribute them to class 0.
    """
    def __init__(self, generator: tf.data.Dataset, name: str = "Training", max_batches: int = 16):
        """
        Initializes the DistributionLogger.

        Args:
            generator (tf.data.Dataset): Dataset yielding (inputs, one-hot labels) batches.
            name (str): A descriptive name for the generator (e.g., "Training", "Validation").
            max_batches (int): The maximum number of batches to process for distribution analysis
                               in each epoch (to control overhead).
        """
        super().__init__()
        self.generator = generator
        self.name = name
        self.max_batches = max_batches
        self.cumulative_class_counts = defaultdict(int)  # class id -> pixel count over all epochs

    @staticmethod
    def _class_label(cls_id: int) -> str:
        """Human-readable class name, falling back to 'Class <id>' for ids outside CLASS_NAMES."""
        return CLASS_NAMES[cls_id] if cls_id < NUM_CLASSES else f"Class {cls_id}"

    def _print_counts(self, counts: dict):
        """Print one line per class: pixel count and share of the total (counts must be non-empty)."""
        total = sum(counts.values())
        for cls_id in sorted(counts):
            count = counts[cls_id]
            percent = 100.0 * count / total
            print(f"  Class {cls_id} ({self._class_label(cls_id)}): {count:,} px ({percent:.2f}%)")

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """
        Called at the end of each epoch. Counts label pixels per class in up to
        `max_batches` batches, prints the distribution, and adds it to the cumulative totals.

        Args:
            epoch (int): The current epoch number (0-indexed).
            logs (dict, optional): Dictionary of logs. Defaults to None.
        """
        epoch_counts = defaultdict(int)  # class id -> pixel count for this epoch's sample

        for _, batch_labels in self.generator.take(self.max_batches):
            labels = np.asarray(batch_labels)
            labelled = np.any(labels != 0, axis=-1)  # False for all-zero (ignored) pixels
            class_ids = labels.argmax(axis=-1)[labelled]
            # bincount is a single vectorised pass; keep only classes actually present.
            for cls_id, count in enumerate(np.bincount(class_ids.ravel())):
                if count:
                    epoch_counts[cls_id] += int(count)
                    self.cumulative_class_counts[cls_id] += int(count)

        if sum(epoch_counts.values()) > 0:
            print(f"\n{self.name} Class Distribution Sampled After Epoch {epoch + 1}:")
            self._print_counts(epoch_counts)
        else:
            print(f"\n{self.name} Class Distribution: No pixels found in sampled batches for Epoch {epoch + 1}.")

    def on_train_end(self, logs: dict = None):
        """
        Called at the end of training. Prints and plots the cumulative class distribution
        accumulated across all epochs.
        """
        total_pixels_cumulative = sum(self.cumulative_class_counts.values())

        if total_pixels_cumulative > 0:
            # Use the configured name; the report previously always said "Training".
            print(f"\nFinal Cumulative {self.name} Class Distribution:")
            print(f"Total pixels analyzed: {total_pixels_cumulative:,} px")
            self._print_counts(self.cumulative_class_counts)
            self._plot_cumulative_distribution()
        else:
            print("No cumulative pixels were recorded for distribution analysis.")

    def _plot_cumulative_distribution(self):
        """Bar-plot the cumulative per-class pixel proportions, annotated with percentages."""
        class_ids = sorted(self.cumulative_class_counts)
        pixel_counts = [self.cumulative_class_counts[cid] for cid in class_ids]

        total_pixels = sum(pixel_counts)
        if total_pixels == 0:
            print("Cannot plot distribution: No pixels recorded.")
            return

        pixel_props = [count / total_pixels for count in pixel_counts]
        class_labels = [self._class_label(cid) for cid in class_ids]

        fig = plt.figure(figsize=(10, 5))
        bars = plt.bar(class_labels, pixel_props, edgecolor='black')
        plt.title(f"Cumulative Class Distribution ({self.name} Data)")
        plt.xlabel("Class")
        plt.ylabel("Proportion")
        plt.grid(True, axis='y', linestyle='--', alpha=0.5)

        for bar, prop in zip(bars, pixel_props):
            plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f"{prop:.2%}",
                     ha='center', va='bottom', fontsize=9)
        plt.tight_layout()
        plt.show()
        plt.close(fig)


class ValidationPredictionLogger(tf.keras.callbacks.Callback):
    """
    A Keras Callback to visualize and save example predictions on the validation set
    at specified intervals (e.g., every N epochs). This helps monitor the model's
    qualitative performance during training.
    """
    def __init__(self, val_gen: tf.data.Dataset, user_model: tf.keras.Model,
                 out_dir: str = "/content/figs", max_batches_to_process: int = 1,
                 num_samples_to_plot: int = 8, plot_every_n_epochs: int = 8):
        """
        Initializes the ValidationPredictionLogger.

        Args:
            val_gen (tf.data.Dataset): Validation dataset yielding (images, one-hot labels) batches.
            user_model (tf.keras.Model): The model being trained (needed for predictions).
            out_dir (str): Directory to save the visualization plots.
            max_batches_to_process (int): Max number of validation batches to gather samples from.
            num_samples_to_plot (int): Number of (Input, GT, Pred) rows to display in the plot.
                                       At most (max_batches_to_process * batch size) rows can be shown.
            plot_every_n_epochs (int): The interval (in epochs) at which to generate plots.
        """
        super().__init__()
        self.val_gen = val_gen
        self.user_model = user_model  # Reference to the model itself
        self.out_dir = out_dir
        self.max_batches_to_process = max_batches_to_process
        self.num_samples_to_plot = num_samples_to_plot
        self.plot_every_n_epochs = plot_every_n_epochs

        # Ignore colour and class -> colour mapping (keep consistent with the main script)
        self.ignore_color = (255, 0, 255)  # Magenta for ignored pixels
        self.class_to_color = {
            0: (230, 25, 75),    # Building
            1: (145, 30, 180),   # Clutter
            2: (60, 180, 75),    # Vegetation
            3: (245, 130, 48),   # Water
            4: (255, 255, 255),  # Background
            5: (0, 130, 200),    # Car
        }
        os.makedirs(self.out_dir, exist_ok=True)

    def _collect_samples(self):
        """Concatenate up to `max_batches_to_process` validation batches into NumPy arrays.

        Returns:
            (images, labels) stacked along the batch axis, or (None, None) if no batch was produced.
        """
        images, labels = [], []
        for batch_images, batch_labels in self.val_gen.take(self.max_batches_to_process):
            images.append(batch_images.numpy())
            labels.append(batch_labels.numpy())
        if not images:
            return None, None
        return np.concatenate(images, axis=0), np.concatenate(labels, axis=0)

    def _colorize(self, class_ids: np.ndarray, ignore_mask: np.ndarray) -> np.ndarray:
        """Turn a 2-D array of class ids into an RGB uint8 image.

        Ids with no entry in `class_to_color` stay black; pixels in `ignore_mask` are painted
        with `ignore_color`.
        """
        rgb = np.zeros(class_ids.shape + (3,), dtype=np.uint8)
        for cid, color in self.class_to_color.items():
            rgb[class_ids == cid] = color
        rgb[ignore_mask] = self.ignore_color
        return rgb

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """
        Called at the end of each epoch. Generates and saves a plot of validation predictions
        if the current epoch is a multiple of `plot_every_n_epochs`.

        Args:
            epoch (int): The current epoch number (0-indexed).
            logs (dict, optional): Dictionary of logs. Defaults to None.
        """
        if (epoch + 1) % self.plot_every_n_epochs != 0:
            return

        print(f"\n🎨 Epoch {epoch+1}: Generating validation prediction visualization...")

        full_images, full_labels = self._collect_samples()
        if full_images is None:
            print("Warning: No validation batches found for visualization.")
            return

        num_samples = min(self.num_samples_to_plot, len(full_images))
        if num_samples <= 0:
            # plt.subplots(0, 3) would raise; nothing to draw anyway.
            print("Warning: num_samples_to_plot is 0; skipping visualization.")
            return

        # Only predict on the rows that will actually be plotted.
        images = full_images[:num_samples]
        labels = full_labels[:num_samples]
        pred_ids = np.argmax(self.user_model.predict(images, verbose=0), axis=-1)
        true_ids = np.argmax(labels, axis=-1)
        ignore_masks = np.all(labels == 0, axis=-1)  # All-zero one-hot vector => ignored pixel

        # squeeze=False keeps `axs` 2-D even when there is a single row.
        fig, axs = plt.subplots(num_samples, 3, figsize=(18, num_samples * 4), squeeze=False)

        for i in range(num_samples):
            # Assumes inputs are scaled to [0, 1]; clip so augmentation overshoot doesn't wrap in uint8.
            rgb = (np.clip(images[i], 0.0, 1.0) * 255).astype(np.uint8)
            panels = (
                ("Input", rgb),
                ("Ground Truth", self._colorize(true_ids[i], ignore_masks[i])),
                ("Prediction", self._colorize(pred_ids[i], ignore_masks[i])),
            )
            for ax, (title, image) in zip(axs[i], panels):
                ax.imshow(image)
                ax.set_title(title, fontsize=10)
                ax.axis("off")

        fig.tight_layout()
        fig.suptitle(f"Validation Predictions After Epoch {epoch + 1}", y=1.02, fontsize=16)
        save_path = os.path.join(self.out_dir, f"val_preds_epoch{epoch+1:03d}.png")
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
        plt.show()
        plt.close(fig)  # Free the figure's memory
        print(f"Validation prediction plot saved to: {save_path}")