In [None]:
# --- Callbacks ---
import os
from collections import defaultdict
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import f1_score
from tensorflow.keras.callbacks import Callback




# Display name of each segmentation class; the position in the list is the
# class id used to index it (e.g. CLASS_NAMES[0] is 'Building').
CLASS_NAMES = [
    'Building',
    'Clutter',
    'Vegetation',
    'Water',
    'Background',
    'Car',
]


class LearningRateLogger(tf.keras.callbacks.Callback):
    """Print the optimizer's current learning rate after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Print the learning rate in effect at the end of ``epoch``.

        Args:
            epoch: Zero-based epoch index supplied by Keras; printed 1-based.
            logs: Metric dict supplied by Keras (unused).
        """
        optimizer = self.model.optimizer
        lr = optimizer.learning_rate

        # Depending on the Keras version, `learning_rate` may be the schedule
        # object itself rather than a variable/tensor. A schedule has no
        # `.numpy()`, so evaluate it at the optimizer's current step first.
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            lr = lr(optimizer.iterations)

        # Variables and eager tensors expose `.numpy()`; plain Python numbers
        # do not, so fall back to a direct float conversion.
        lr = float(lr.numpy()) if hasattr(lr, "numpy") else float(lr)
        print(f"Learning Rate at epoch {epoch + 1}: {lr:.6f}")

class TimeLimitCallback(tf.keras.callbacks.Callback):
    """Stop training once a wall-clock time budget has been used up.

    The budget is only checked at the end of each epoch, so the epoch that is
    running when the limit is crossed always finishes.

    Args:
        max_minutes: Time budget in minutes, measured from the start of
            training.
    """

    def __init__(self, max_minutes=20):
        super().__init__()
        # Budget is stored in seconds, the unit `tf.timestamp()` reports.
        self.max_duration = max_minutes * 60

    def on_train_begin(self, logs=None):
        """Record the moment training started."""
        self.start_time = tf.timestamp()

    def on_epoch_end(self, epoch, logs=None):
        """Request a stop if the elapsed time exceeds the budget."""
        time_spent = tf.timestamp() - self.start_time
        if not time_spent > self.max_duration:
            return

        print(f"\n Training time exceeded {self.max_duration // 60} minutes. Stopping early.")
        # Keras checks this flag after the epoch and ends the fit loop.
        self.model.stop_training = True


class StepTimer(tf.keras.callbacks.Callback):
    """Report the average wall-clock duration of a training step.

    Each batch is timed between ``on_train_batch_begin`` and
    ``on_train_batch_end``; the mean over all timed batches is printed when
    training ends.
    """

    def on_train_begin(self, logs=None):
        """Reset the accumulators at the start of training."""
        self.total_time = 0.0
        self.total_steps = 0
        self.start_time = None

    def on_train_batch_begin(self, batch, logs=None):
        """Record the start time of the current batch."""
        # Convert to a plain float so the accumulator stays a Python number
        # instead of becoming a tensor.
        self.start_time = float(tf.timestamp())

    def on_train_batch_end(self, batch, logs=None):
        """Add the duration of the batch that just finished."""
        if self.start_time is None:
            # No matching begin hook was seen; nothing to measure.
            return
        self.total_time += float(tf.timestamp()) - self.start_time
        self.total_steps += 1

    def on_train_end(self, logs=None):
        """Print the mean step time, or a notice if no step was timed."""
        if self.total_steps == 0:
            # Avoid dividing by zero when training ended before any batch ran.
            print("🕒 No training steps were timed.")
            return
        avg_step_time = self.total_time / self.total_steps
        print(f"🕒 Average training step time: {avg_step_time:.4f} sec")




class DistributionLogger(tf.keras.callbacks.Callback):
    """Print the per-class pixel distribution of labels drawn from a generator.

    After every epoch up to ``max_batches`` batches are pulled from
    ``generator``, the label maps are reduced with ``argmax`` over the last
    axis and the pixels per class are counted. Note that the statistics
    describe the batches sampled here, which are drawn separately from the
    batches Keras consumed during the epoch.

    Args:
        generator: Iterable yielding ``(images, labels)`` pairs. Labels are
            assumed to be one-hot / per-class scores along the last axis --
            TODO confirm against the data pipeline.
        name: Label used in the per-epoch header line.
        max_batches: Maximum number of batches inspected per epoch.
        visualise_samples: Stored for interface compatibility; not used by
            this class.
    """

    def __init__(self, generator, name="Training", max_batches=16, visualise_samples=2):
        super().__init__()
        self.generator = generator
        self.name = name
        self.max_batches = max_batches
        self.visualise_samples = visualise_samples
        # Pixel counts per class id, accumulated over every epoch.
        self.cumulative_class_counts = defaultdict(int)

    @staticmethod
    def _print_class_counts(class_counts):
        """Print one line per class with its pixel count and percentage."""
        total_pixels = sum(class_counts.values())
        for cls in sorted(class_counts):
            count = class_counts[cls]
            percent = 100.0 * count / total_pixels
            print(f"  Class {cls} ({CLASS_NAMES[cls]}): {count:,} px ({percent:.2f}%)")

    def on_epoch_end(self, epoch, logs=None):
        """Sample batches, count label pixels per class and print the result."""
        epoch_class_counts = defaultdict(int)
        batches_seen = 0
        total_images = 0

        for batch_images, batch_labels in self.generator:
            if batches_seen >= self.max_batches:
                break

            # Collapse the class axis to a per-pixel class id map.
            label_classes = np.argmax(batch_labels, axis=-1)
            unique, counts = np.unique(label_classes, return_counts=True)

            # `.tolist()` yields plain Python ints, so the count dicts do not
            # carry NumPy scalar keys/values.
            for cls, count in zip(unique.tolist(), counts.tolist()):
                epoch_class_counts[cls] += count
                self.cumulative_class_counts[cls] += count

            # Count the images actually received; a short final batch would
            # otherwise be overstated by `batches * batch_size`.
            total_images += len(batch_images)
            batches_seen += 1

        print(f"{self.name} Distribution After Epoch {epoch + 1} ({total_images:,} images):")
        self._print_class_counts(epoch_class_counts)

    def on_train_end(self, logs=None):
        """Print the distribution accumulated over all epochs and plot it."""
        total_pixels = sum(self.cumulative_class_counts.values())
        print("Final Cumulative Training Class Distribution:")
        print(f"Total pixels: {total_pixels:,} px")
        self._print_class_counts(self.cumulative_class_counts)

        # `plot_class_distribution` is defined elsewhere in the project.
        plot_class_distribution(self.cumulative_class_counts)



class ValidationPredictionLogger(tf.keras.callbacks.Callback):
    """Periodically save figures comparing predictions with ground truth.

    Every ``every_n_epochs`` epochs, up to ``max_batches`` batches are drawn
    from ``val_gen``. For each batch a 4x6 grid is rendered showing, for up to
    eight samples, the input image, the colour-coded ground truth and the
    colour-coded prediction. The figure is saved to ``out_dir`` and shown.

    Args:
        val_gen: Iterable yielding ``(images, labels)`` pairs. Images are
            assumed to be floats scaled to [0, 1] and labels one-hot along the
            last axis -- TODO confirm against the data pipeline.
        user_model: Model whose ``predict`` method produces per-class scores.
        out_dir: Directory the PNG figures are written to.
        max_batches: Maximum number of batches visualised per logging epoch.
        every_n_epochs: Logging period in epochs (must be >= 1).

    Raises:
        ValueError: If ``every_n_epochs`` is smaller than 1.
    """

    def __init__(self, val_gen, user_model, out_dir="/content/figs", max_batches=1,
                 every_n_epochs=8):
        super().__init__()
        if every_n_epochs < 1:
            raise ValueError(f"every_n_epochs must be >= 1, got {every_n_epochs}")
        self.val_gen = val_gen
        self.user_model = user_model
        self.max_batches = max_batches
        self.every_n_epochs = every_n_epochs
        self.out_dir = out_dir
        # Colour painted over pixels whose label vector is all zeros.
        self.ignore_color = (255, 0, 255)
        self.class_to_color = {
            0: (230, 25, 75),    # Building
            1: (145, 30, 180),   # Clutter
            2: (60, 180, 75),    # Vegetation
            3: (245, 130, 48),   # Water
            4: (255, 255, 255),  # Background
            5: (0, 130, 200),    # Car
        }

    def _colorize(self, class_map, ignore_mask):
        """Turn a 2-D class id map into an RGB image using the class palette."""
        rgb = np.zeros(class_map.shape + (3,), dtype=np.uint8)
        for class_id, color in self.class_to_color.items():
            rgb[class_map == class_id] = color
        # Applied last so ignored pixels override any class colour.
        rgb[ignore_mask] = self.ignore_color
        return rgb

    def _save_path(self, epoch, batch_index):
        """Return the output path for one figure.

        The first batch keeps the original file name; later batches get a
        batch suffix so they do not overwrite it.
        """
        stem = f"val_preds_epoch{epoch+1:03d}"
        if batch_index > 0:
            stem = f"{stem}_batch{batch_index:02d}"
        return os.path.join(self.out_dir, f"{stem}.png")

    def _plot_batch(self, epoch, batch_index, batch_images, batch_labels):
        """Render, save and show the comparison grid for one batch."""
        preds = self.user_model.predict(batch_images)
        pred_classes = np.argmax(preds, axis=-1)

        # np.asarray accepts both eager tensors and NumPy arrays.
        images = np.asarray(batch_images)
        labels = np.asarray(batch_labels)
        true_classes = np.argmax(labels, axis=-1)
        # Pixels with no class set at all are treated as "ignore".
        ignore_masks = np.all(labels == 0, axis=-1)

        num_samples = min(8, len(images))
        fig, axs = plt.subplots(4, 6, figsize=(18, 12))

        # Hide every axis up front so unused grid cells stay blank when the
        # batch holds fewer than eight samples.
        for ax in axs.ravel():
            ax.axis("off")

        for i in range(num_samples):
            # Two samples per row, three panels (input / truth / pred) each.
            row = i // 2
            first_col = (i % 2) * 3

            # Clip before the uint8 cast so out-of-range values do not wrap.
            rgb = np.clip(images[i] * 255, 0, 255).astype(np.uint8)
            panels = (
                ("Input", rgb),
                ("Ground Truth", self._colorize(true_classes[i], ignore_masks[i])),
                ("Prediction", self._colorize(pred_classes[i], ignore_masks[i])),
            )
            for offset, (title, panel) in enumerate(panels):
                ax = axs[row, first_col + offset]
                ax.imshow(panel)
                ax.set_title(title)

        fig.tight_layout()
        fig.suptitle(f"Validation Predictions After Epoch {epoch + 1}", y=1.02)
        fig.savefig(self._save_path(epoch, batch_index), bbox_inches='tight', dpi=300)
        plt.show()
        plt.close(fig)

    def on_epoch_end(self, epoch, logs=None):
        """Produce the comparison figures on every ``every_n_epochs``-th epoch."""
        if (epoch + 1) % self.every_n_epochs != 0:
            return

        os.makedirs(self.out_dir, exist_ok=True)

        batches_seen = 0
        for batch_images, batch_labels in self.val_gen:
            if batches_seen >= self.max_batches:
                break
            self._plot_batch(epoch, batches_seen, batch_images, batch_labels)
            batches_seen += 1


class DualCheckpointSaver(Callback):
    """Save the model to two directories whenever the monitored metric improves.

    Args:
        base_model: Model to save; its ``name`` is used in the file name.
        monitor: Key looked up in the epoch ``logs`` dict.
        mode: ``'max'`` if larger metric values are better, ``'min'`` if
            smaller values are better.
        out_dir: Local checkpoint directory (created if missing).
        drive_dir: Second checkpoint directory (created if missing).

    Raises:
        ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.
    """

    def __init__(self, base_model, monitor='val_iou_score', mode='max',
                 out_dir="checkpoints", drive_dir="/content/drive/MyDrive/checkpoints"):
        super().__init__()
        # Reject unknown modes instead of silently treating them as 'min'.
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")

        self.base_model = base_model
        self.monitor = monitor
        self.mode = mode
        self.out_dir = out_dir
        self.drive_dir = drive_dir
        # One timestamp per callback instance, so every checkpoint written by
        # this instance shares the same prefix.
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Start from the worst possible value so the first epoch always saves.
        self.best_value = -float('inf') if mode == 'max' else float('inf')

        os.makedirs(self.out_dir, exist_ok=True)
        os.makedirs(self.drive_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Save a checkpoint to both directories if the metric improved.

        Args:
            epoch: Zero-based epoch index supplied by Keras.
            logs: Metric dict supplied by Keras; must contain ``monitor`` for
                a checkpoint to be considered.
        """
        if logs is None or self.monitor not in logs:
            print(f"{self.monitor} not found in logs. Skipping checkpoint.")
            return

        current = logs[self.monitor]
        if self.mode == 'max':
            improved = current > self.best_value
        else:
            improved = current < self.best_value

        if not improved:
            print(f"⏭Epoch {epoch+1}: {self.monitor} did not improve ({current:.4f})")
            return

        self.best_value = current
        epoch_num = epoch + 1
        model_name = f"{self.base_model.name}_{self.timestamp}_epoch{epoch_num:03d}.keras"

        # Write the same checkpoint to both destinations through one API.
        for directory in (self.out_dir, self.drive_dir):
            self.base_model.save(os.path.join(directory, model_name))
        print(f"✅ Saved improved model at epoch {epoch_num} with {self.monitor}: {current:.4f}")




