In [None]:
# Dataset preprocessing: VV plus two derived bands (NewBand1, NewBand2)

!pip install rasterio

import os
import csv
import numpy as np
import rasterio
from PIL import Image
from tqdm import tqdm
import shutil
from skimage.transform import resize

class ResUNetPreprocessor:
    """Preprocess Sen1Floods11 hand-labeled S1 chips for a ResUNet.

    Per sample: read bands 1 (VH) and 2 (VV) of the image GeoTIFF, derive three
    feature bands (see ``compute_bands``), clip each band to its own 1st-99th
    percentile, resize to ``img_size x img_size`` and standardise with the
    train-split mean/std. Masks are resized with nearest neighbour and split
    into a binary ground-truth mask and a validity mask (label ``-1`` means
    invalid). Outputs go to ``<output_path>/<split>/{images,masks,validity_masks}``
    as PNG previews plus ``.npy`` arrays.

    Args:
        base_path: Dataset root containing ``data/`` and ``splits/``.
        output_path: Output directory. WARNING: it is deleted and recreated
            when the preprocessor is constructed.
        img_size: Side length of the square output arrays.
    """

    def __init__(self, base_path, output_path, img_size=256):
        self.base_path = base_path
        self.output_path = output_path
        self.img_size = img_size

        # Per-band mean/std; filled in by calculate_normalization_params().
        self.norm_means = None
        self.norm_stds = None

        self.create_output_dirs()

        # Dataset layout, relative to base_path.
        self.image_dir = os.path.join('data', 'flood_events', 'HandLabeled', 'S1Hand')
        self.mask_dir = os.path.join('data', 'flood_events', 'HandLabeled', 'LabelHand')
        self.splits_dir = os.path.join('splits', 'flood_handlabeled')

        # Per-split counters; pixel counts only consider valid (label != -1) pixels.
        self.stats = {
            split: {'count': 0, 'flood_pixels': 0, 'total_valid_pixels': 0, 'invalid_pixels': 0}
            for split in ('train', 'val', 'test')
        }

    def create_output_dirs(self):
        """Delete ``output_path`` if present and recreate the split sub-directories."""
        # In Kaggle, we can write to /kaggle/working
        if os.path.exists(self.output_path):
            shutil.rmtree(self.output_path)

        for split in ['train', 'val', 'test']:
            os.makedirs(os.path.join(self.output_path, split, 'images'), exist_ok=True)
            os.makedirs(os.path.join(self.output_path, split, 'masks'), exist_ok=True)
            os.makedirs(os.path.join(self.output_path, split, 'validity_masks'), exist_ok=True)

        print(f"Created output directories at {self.output_path}")

    def read_csv_file(self, csv_path):
        """Return the ``(image_fname, mask_fname)`` pairs listed in *csv_path*.

        The first row is treated as a header, and skipped, only when its first
        cell is not a GeoTIFF filename. A header-less split file therefore keeps
        its first sample. Rows with fewer than two cells (e.g. trailing blank
        lines) are ignored.

        Raises:
            FileNotFoundError: if *csv_path* does not exist.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found: {csv_path}")

        with open(csv_path, 'r', newline='') as f:
            rows = [row for row in csv.reader(f) if len(row) >= 2]

        # Data rows name .tif chips; anything else in row 1 is a header.
        if rows and not rows[0][0].strip().lower().endswith(('.tif', '.tiff')):
            rows = rows[1:]
        return [(row[0], row[1]) for row in rows]

    def compute_bands(self, vh, vv):
        """
        Compute the three bands according to the specified formulas:
        Band 1: VV
        Band 2: NewBand1 = (VH - VV) / (VH + VV)   (eps added to the denominator)
        Band 3: NewBand2 = sqrt((VH^2 + VV^2) / 2)
        """
        eps = 1e-8

        band1 = vv
        band2 = np.divide(vh - vv, vh + vv + eps)
        band3 = np.sqrt((vh**2 + vv**2) / 2)

        return band1, band2, band3

    def _load_resized_bands(self, im_path):
        """Load VH/VV from *im_path* and return clipped, resized bands of shape (3, img_size, img_size)."""
        with rasterio.open(im_path) as src:
            vh = src.read(1)
            vv = src.read(2)

        # Map NaN and +/-inf to 0. nan_to_num's default would turn inf into the
        # dtype's max value, which overflows back to inf when squared in band 3.
        vh = np.nan_to_num(vh, nan=0.0, posinf=0.0, neginf=0.0)
        vv = np.nan_to_num(vv, nan=0.0, posinf=0.0, neginf=0.0)

        resized = []
        for band in self.compute_bands(vh, vv):
            # Clip extreme values per band (common in SAR preprocessing).
            v_min, v_max = np.percentile(band, [1, 99])
            band = np.clip(band, v_min, v_max)
            resized.append(resize(band, (self.img_size, self.img_size), preserve_range=True))
        return np.stack(resized, axis=0)

    def process_image_for_stats(self, im_path):
        """Return the un-normalised (3, H, W) band stack for statistics, or None on error."""
        try:
            return self._load_resized_bands(im_path)
        except Exception as e:
            print(f"Error processing image for stats {im_path}: {str(e)}")
            return None

    def calculate_normalization_params(self, train_csv_path):
        """Compute per-band mean/std over all training images and save them.

        The result is stored on ``norm_means`` / ``norm_stds`` and written to
        ``<output_path>/normalization_params.npy`` as a pickled dict. Loading it
        therefore needs ``np.load(..., allow_pickle=True)``.
        """
        print("Calculating normalization parameters from training data...")

        file_pairs = self.read_csv_file(train_csv_path)

        # Collect all pixel values for each band
        all_pixels = [[] for _ in range(3)]

        for im_fname, _ in tqdm(file_pairs, desc="Collecting statistics"):
            im_path = os.path.join(self.base_path, self.image_dir, im_fname)

            if not os.path.exists(im_path):
                continue

            arr_x = self.process_image_for_stats(im_path)
            if arr_x is not None:
                for i in range(3):
                    all_pixels[i].append(arr_x[i].flatten())

        means = []
        stds = []
        for band_pixels in all_pixels:
            if band_pixels:
                combined_pixels = np.concatenate(band_pixels)
                means.append(np.mean(combined_pixels))
                std = np.std(combined_pixels)
                # A constant band would give std == 0 and divide by zero later.
                stds.append(std if std > 0 else 1.0)
            else:
                means.append(0.0)
                stds.append(1.0)

        self.norm_means = np.array(means)
        self.norm_stds = np.array(stds)

        print(f"Calculated normalization parameters:")
        print(f"Band 1 (VV): mean={self.norm_means[0]:.4f}, std={self.norm_stds[0]:.4f}")
        print(f"Band 2 (NewBand1): mean={self.norm_means[1]:.4f}, std={self.norm_stds[1]:.4f}")
        print(f"Band 3 (NewBand2): mean={self.norm_means[2]:.4f}, std={self.norm_stds[2]:.4f}")

        norm_params_path = os.path.join(self.output_path, 'normalization_params.npy')
        np.save(norm_params_path, {'means': self.norm_means, 'stds': self.norm_stds})
        print(f"Normalization parameters saved to: {norm_params_path}")

    def process_image(self, im_path):
        """Return ``(normalised HWC array, min-max scaled HWC preview)``, or ``(None, None)`` on error.

        Normalisation is applied only once ``calculate_normalization_params``
        has run.
        """
        try:
            arr_x = self._load_resized_bands(im_path)

            if self.norm_means is not None and self.norm_stds is not None:
                arr_x = (arr_x - self.norm_means.reshape(3, 1, 1)) / self.norm_stds.reshape(3, 1, 1)

            # CHW -> HWC for saving as an image.
            arr_x = np.transpose(arr_x, (1, 2, 0))

            # Scale to [0, 1) over all channels jointly, for visualisation only.
            eps = 1e-8
            arr_x_viz = (arr_x - arr_x.min()) / (arr_x.max() - arr_x.min() + eps)

            return arr_x, arr_x_viz

        except Exception as e:
            print(f"Error processing image {im_path}: {str(e)}")
            return None, None

    def process_mask(self, mask_path):
        """
        Process mask preserving -1 values and creating validity mask
        Returns: (ground_truth_mask, validity_mask) as uint8 arrays, or (None, None) on error.
        ground_truth_mask is 1 where the label is > 0; validity_mask is 0 where the label is -1.
        """
        try:
            with rasterio.open(mask_path) as src:
                arr_y = src.read(1)

            # Nearest neighbour with anti-aliasing explicitly off, so -1/0/1
            # labels are never blended into fractional values.
            arr_y = resize(arr_y, (self.img_size, self.img_size), order=0,
                           preserve_range=True, anti_aliasing=False)

            validity_mask = (arr_y != -1).astype(np.uint8)
            # -1 is not > 0, so invalid pixels become 0 here; the validity mask
            # is what excludes them during training.
            ground_truth_mask = (arr_y > 0).astype(np.uint8)

            return ground_truth_mask, validity_mask

        except Exception as e:
            print(f"Error processing mask {mask_path}: {str(e)}")
            return None, None

    def save_png(self, arr, save_path, mode='RGB'):
        """Save an array with values in [0, 1] as an 8-bit PNG."""
        img = Image.fromarray((arr * 255).astype(np.uint8), mode=mode)
        img.save(save_path)

    def save_npy(self, arr, save_path):
        """Save raw array data as NPY file for preserving exact values"""
        np.save(save_path, arr)

    def update_stats(self, split, ground_truth_mask, validity_mask):
        """Accumulate sample and pixel counts for *split*, counting floods only on valid pixels."""
        valid_pixels = validity_mask.astype(bool)

        self.stats[split]['count'] += 1
        self.stats[split]['flood_pixels'] += ground_truth_mask[valid_pixels].sum()
        self.stats[split]['total_valid_pixels'] += valid_pixels.sum()
        self.stats[split]['invalid_pixels'] += (~valid_pixels).sum()

    def process_dataset(self, split_name, csv_path):
        """Process every pair listed in *csv_path* and write outputs under ``<output_path>/<split_name>``.

        A pair is written only if BOTH image and mask processed successfully,
        so ``images/``, ``masks/`` and ``validity_masks/`` always hold the same
        sample indices.
        """
        print(f"Processing {split_name} dataset...")
        file_pairs = self.read_csv_file(csv_path)
        output_dir = os.path.join(self.output_path, split_name)

        for idx, (im_fname, mask_fname) in enumerate(tqdm(file_pairs, desc=f"Processing {split_name}")):
            im_path = os.path.join(self.base_path, self.image_dir, im_fname)
            mask_path = os.path.join(self.base_path, self.mask_dir, mask_fname)

            if not os.path.exists(im_path) or not os.path.exists(mask_path):
                print(f"Warning: Files not found - {im_path} or {mask_path}")
                continue

            arr_x, arr_x_viz = self.process_image(im_path)
            if arr_x is None:
                print(f"Warning: skipping {im_fname} (image processing failed)")
                continue

            # process_mask signals failure with (None, None), not a bare None.
            ground_truth_mask, validity_mask = self.process_mask(mask_path)
            if ground_truth_mask is None:
                print(f"Warning: skipping {im_fname} (mask processing failed)")
                continue

            stem = f'{split_name}_{idx:04d}'

            # Image: PNG preview plus raw normalised data.
            self.save_png(arr_x_viz, os.path.join(output_dir, 'images', f'{stem}.png'), mode='RGB')
            self.save_npy(arr_x, os.path.join(output_dir, 'images', f'{stem}.npy'))

            # Masks: PNG and NPY for both ground truth and validity.
            self.save_png(ground_truth_mask, os.path.join(output_dir, 'masks', f'{stem}.png'), mode='L')
            self.save_png(validity_mask, os.path.join(output_dir, 'validity_masks', f'{stem}.png'), mode='L')
            self.save_npy(ground_truth_mask, os.path.join(output_dir, 'masks', f'{stem}.npy'))
            self.save_npy(validity_mask, os.path.join(output_dir, 'validity_masks', f'{stem}.npy'))

            self.update_stats(split_name, ground_truth_mask, validity_mask)

    def print_stats(self):
        """Print per-split sample counts and valid / invalid / flood pixel percentages."""
        print("\nDataset Statistics (Only Valid Pixels):")
        print("=" * 70)
        for split, stat in self.stats.items():
            if stat['count'] > 0:
                flood_percentage = 100 * stat['flood_pixels'] / stat['total_valid_pixels'] if stat['total_valid_pixels'] > 0 else 0
                invalid_percentage = 100 * stat['invalid_pixels'] / (stat['total_valid_pixels'] + stat['invalid_pixels'])
                print(f"{split.upper()} set: {stat['count']} samples")
                print(f"  Valid pixels: {stat['total_valid_pixels']:,} ({100-invalid_percentage:.1f}%)")
                print(f"  Invalid pixels: {stat['invalid_pixels']:,} ({invalid_percentage:.1f}%)")
                print(f"  Flood pixels (of valid): {stat['flood_pixels']:,} ({flood_percentage:.2f}%)")
                print()
        print("=" * 70)

    def calculate_class_weights(self):
        """Return ``[w_non_flood, w_flood]`` with ``w_flood = neg/pos`` over valid train pixels.

        Falls back to ``[1.0, 1.0]`` when no train pixels were counted.
        """
        if self.stats['train']['total_valid_pixels'] > 0:
            pos_ratio = self.stats['train']['flood_pixels'] / self.stats['train']['total_valid_pixels']
            neg_ratio = 1 - pos_ratio

            weight_non_flood = 1.0
            weight_flood = neg_ratio / pos_ratio if pos_ratio > 0 else 1.0

            print(f"\nClass weights for handling imbalance (valid pixels only):")
            print(f"Weight for non-flood (0): {weight_non_flood:.4f}")
            print(f"Weight for flood (1): {weight_flood:.4f}")

            return np.array([weight_non_flood, weight_flood])
        return np.array([1.0, 1.0])

# Kaggle paths (defined first so the banner below prints the values actually used)
base_path = "/kaggle/input/sen1floods11-essentials/v1.2"
output_path = "/kaggle/working/preprocessed"

# Kaggle environment setup
print("Setting up ResUNet preprocessing for Kaggle environment...")
print("Input path:", base_path)
print("Output path:", output_path)
print("\nBand Configuration:")
print("Band 1: VV")
print("Band 2: NewBand1 = (VH - VV) / (VH + VV)")
print("Band 3: NewBand2 = sqrt((VH² + VV²) / 2)")
print("\n🔧 GROUND TRUTH HANDLING:")
print("✅ Preserving -1 (invalid) labels")
print("✅ Creating validity masks to exclude invalid pixels from training")
print("✅ Statistics calculated only on valid pixels (0 and 1)")

# Initialize preprocessor (NOTE: this wipes and recreates output_path)
preprocessor = ResUNetPreprocessor(base_path, output_path)

# Split CSVs live under the preprocessor's configured splits directory.
splits_root = os.path.join(base_path, preprocessor.splits_dir)

# Step 1: Calculate normalization parameters from training data only
train_csv_path = os.path.join(splits_root, 'flood_train_data.csv')
preprocessor.calculate_normalization_params(train_csv_path)

print("\nStarting dataset preprocessing with calculated normalization parameters...")

# Step 2: Process all splits with the train-derived normalization
preprocessor.process_dataset('train', train_csv_path)

val_csv_path = os.path.join(splits_root, 'flood_val_data.csv')
preprocessor.process_dataset('val', val_csv_path)

test_csv_path = os.path.join(splits_root, 'flood_test_data.csv')
preprocessor.process_dataset('test', test_csv_path)

# Print statistics and calculate class weights
preprocessor.print_stats()
weights = preprocessor.calculate_class_weights()

# Save class weights for later use
weights_path = os.path.join(output_path, 'class_weights.npy')
np.save(weights_path, weights)
print(f"\nClass weights saved to: {weights_path}")

print(f"\nPreprocessing complete! Processed data saved to: {output_path}")
print("\nOutput structure:")
print("preprocessed/")
print("├── train/")
print("│   ├── images/ (PNG and NPY files)")
print("│   ├── masks/ (PNG and NPY files - ground truth)")
print("│   └── validity_masks/ (PNG and NPY files - valid pixel indicators)")
print("├── val/")
print("│   ├── images/ (PNG and NPY files)")
print("│   ├── masks/ (PNG and NPY files - ground truth)")
print("│   └── validity_masks/ (PNG and NPY files - valid pixel indicators)")
print("├── test/")
print("│   ├── images/ (PNG and NPY files)")
print("│   ├── masks/ (PNG and NPY files - ground truth)")
print("│   └── validity_masks/ (PNG and NPY files - valid pixel indicators)")
print("├── class_weights.npy")
print("└── normalization_params.npy")



In [None]:
# ResUNet + SPGR (Spatial Pyramid Graph Reasoning) flood-segmentation model
!pip install torch_geometric
!pip install scikit-image


import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import random
import glob
from tensorflow.keras import backend as K
import json
import pandas as pd
from PIL import Image
from skimage.transform import resize
import math

# Seed every RNG this notebook touches (Python, NumPy, TensorFlow) so that
# runs are reproducible.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

# Report the runtime: TensorFlow version and any visible GPUs.
print("TensorFlow version:", tf.__version__)
_gpu_devices = tf.config.list_physical_devices('GPU')
print("GPU Available:", bool(_gpu_devices))
print("GPU Devices:", _gpu_devices)

# Inputs come from the preprocessing cell; model artifacts are written to OUTPUT_PATH.
BASE_PATH = "/kaggle/working/preprocessed"
OUTPUT_PATH = "/kaggle/working/flood_resunet_spgr"
os.makedirs(OUTPUT_PATH, exist_ok=True)

# ==================== SPGR IMPLEMENTATION ====================

class GraphConvLayer(layers.Layer):
    """Chebyshev graph convolution (no bias).

    Computes ``sum_k T_k(A) @ X @ W_k`` for ``k = 0 .. chebyshev_order - 1``.
    Here ``T_k`` is the k-th Chebyshev polynomial of the propagation operator
    ``A`` passed to ``call``, and ``W_k`` is a learnable
    ``(input_channels, output_channels)`` kernel.

    ``A`` is used as-is, with no Laplacian rescaling. For the recurrence to
    stay bounded its spectrum should lie in [-1, 1]. A symmetrically
    normalised adjacency with self-loops satisfies this.
    """

    def __init__(self, output_channels, chebyshev_order=3, **kwargs):
        super(GraphConvLayer, self).__init__(**kwargs)
        if chebyshev_order < 1:
            # Otherwise call() would hit tf.add_n([]) with an opaque error.
            raise ValueError(f"chebyshev_order must be >= 1, got {chebyshev_order}")
        self.output_channels = output_channels
        self.chebyshev_order = chebyshev_order

    def build(self, input_shape):
        # One learnable kernel per Chebyshev order.
        self.kernels = []
        for k in range(self.chebyshev_order):
            self.kernels.append(
                self.add_weight(
                    name=f'cheb_kernel_{k}',
                    shape=(input_shape[-1], self.output_channels),
                    initializer='glorot_uniform',
                    trainable=True
                )
            )
        super(GraphConvLayer, self).build(input_shape)

    def call(self, inputs, adjacency_matrix):
        """
        inputs: (batch_size, num_nodes, input_channels)
        adjacency_matrix: (num_nodes, num_nodes) propagation operator A
        returns: (batch_size, num_nodes, output_channels)
        """
        # Apply the Chebyshev recurrence to the features directly:
        #   T_0(A)X = X,  T_1(A)X = AX,  T_k(A)X = 2A(T_{k-1}(A)X) - T_{k-2}(A)X.
        # The result equals T_k(A) @ X, but uses only (N,N)x(N,C) products
        # instead of forming each (N,N) polynomial with (N,N)x(N,N) products.
        x_prev = inputs
        outputs = [tf.matmul(x_prev, self.kernels[0])]

        if self.chebyshev_order > 1:
            x_curr = tf.matmul(adjacency_matrix, inputs)
            outputs.append(tf.matmul(x_curr, self.kernels[1]))

            for k in range(2, self.chebyshev_order):
                x_next = 2.0 * tf.matmul(adjacency_matrix, x_curr) - x_prev
                outputs.append(tf.matmul(x_next, self.kernels[k]))
                x_prev, x_curr = x_curr, x_next

        # Sum all polynomial orders
        return tf.add_n(outputs)

    def _compute_chebyshev_polynomials(self, normalized_laplacian):
        """Return the dense matrices [T_0(L), ..., T_{K-1}(L)].

        call() no longer uses this; it is kept for callers that want the
        explicit polynomial matrices.
        """
        num_nodes = tf.shape(normalized_laplacian)[0]

        # T_0(L) = I
        polynomials = [tf.eye(num_nodes, dtype=tf.float32)]

        if self.chebyshev_order > 1:
            # T_1(L) = L
            polynomials.append(normalized_laplacian)

        # T_k(L) = 2 * L * T_{k-1}(L) - T_{k-2}(L)
        for k in range(2, self.chebyshev_order):
            T_k = 2.0 * tf.matmul(normalized_laplacian, polynomials[k-1]) - polynomials[k-2]
            polynomials.append(T_k)

        return polynomials

    def get_config(self):
        """Return the config of the layer for serialization"""
        config = super(GraphConvLayer, self).get_config()
        config.update({
            'output_channels': self.output_channels,
            'chebyshev_order': self.chebyshev_order
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Create layer from config"""
        return cls(**config)


class SPGR(layers.Layer):
    """
    Spatial Pyramid Graph Reasoning Layer

    The input map is bilinearly downsampled by factors 4, 8 and 16. At each
    level it is:
      1. projected to ``output_channels`` with a 1x1 conv;
      2. treated as a 4-connected grid graph (one node per pixel, with
         self-loops, symmetrically normalised) and passed through a Chebyshev
         graph convolution;
      3. passed through BatchNorm and ReLU.
    The levels are fused coarse-to-fine by upsample-and-add, upsampled back to
    the input resolution, and projected with a final 1x1 conv.

    Input (B, H, W, C) -> output (B, H, W, output_channels).
    """

    def __init__(self, output_channels, **kwargs):
        super(SPGR, self).__init__(**kwargs)
        self.output_channels = output_channels
        self.scale_factors = [4, 8, 16]  # 1/4, 1/8, 1/16 scales

    def build(self, input_shape):
        # Per-scale sublayers: projection, graph conv, batch norm, activation.
        self.projections = []
        self.graph_convs = []
        self.batch_norms = []
        self.activations = []

        for i, scale in enumerate(self.scale_factors):
            projection_layer = layers.Conv2D(
                self.output_channels,
                kernel_size=1,
                padding='same',
                name=f'projection_scale_{scale}'
            )
            self.projections.append(projection_layer)

            graph_conv_layer = GraphConvLayer(
                self.output_channels,
                chebyshev_order=3,
                name=f'graph_conv_scale_{scale}'
            )
            self.graph_convs.append(graph_conv_layer)

            bn_layer = layers.BatchNormalization(name=f'bn_scale_{scale}')
            self.batch_norms.append(bn_layer)

            relu_layer = layers.ReLU(name=f'relu_scale_{scale}')
            self.activations.append(relu_layer)

        # Final projection to ensure output channels
        self.final_projection = layers.Conv2D(
            self.output_channels,
            kernel_size=1,
            padding='same',
            name='final_projection'
        )

        super(SPGR, self).build(input_shape)

        # Eagerly build all sublayers on dummy inputs so their weights exist
        # before the first real call (helps serialization).
        try:
            if input_shape[1] is not None and input_shape[2] is not None:
                height, width = input_shape[1], input_shape[2]
            else:
                height, width = 32, 32  # Default fallback

            channels = input_shape[3] if input_shape[3] is not None else 512
            batch_size = 1  # Use dummy batch size for building

            for i, scale_factor in enumerate(self.scale_factors):
                target_h = max(1, height // scale_factor)
                target_w = max(1, width // scale_factor)

                dummy_input = tf.zeros((batch_size, target_h, target_w, channels))
                projected = self.projections[i](dummy_input)

                dummy_adjacency = tf.eye(target_h * target_w, dtype=tf.float32)
                dummy_graph_input = tf.zeros((batch_size, target_h * target_w, self.output_channels))
                graph_output = self.graph_convs[i](dummy_graph_input, dummy_adjacency)

                dummy_spatial = tf.zeros((batch_size, target_h, target_w, self.output_channels))
                bn_output = self.batch_norms[i](dummy_spatial, training=False)
                _ = self.activations[i](bn_output)

            dummy_final_input = tf.zeros((batch_size, height, width, self.output_channels))
            _ = self.final_projection(dummy_final_input)

        except Exception as e:
            # Best effort: the layers will otherwise be built on first use.
            print(f"Warning: Could not pre-build SPGR sublayers: {e}")

    def compute_output_shape(self, input_shape):
        """Compute the output shape of the layer"""
        return (input_shape[0], input_shape[1], input_shape[2], self.output_channels)

    def call(self, inputs, training=None):
        """
        inputs: (B, H, W, C) feature map
        training: forwarded to the per-scale BatchNormalization layers
        returns: (B, H, W, output_channels) refined features
        """
        batch_size = tf.shape(inputs)[0]
        orig_height = tf.shape(inputs)[1]
        orig_width = tf.shape(inputs)[2]

        scale_features = []
        for i, scale_factor in enumerate(self.scale_factors):
            # Clamp to at least 1 pixel, as build() does. Plain integer
            # division gives 0 when H or W < scale_factor, and resizing to a
            # zero-sized map fails.
            target_h = tf.maximum(orig_height // scale_factor, 1)
            target_w = tf.maximum(orig_width // scale_factor, 1)

            downsampled = tf.image.resize(inputs, [target_h, target_w], method='bilinear')
            projected = self.projections[i](downsampled)

            # Graph reasoning over the pixels of this pyramid level: (B, H*W, C).
            adjacency = self._create_grid_adjacency_tf(target_h, target_w)
            num_nodes = target_h * target_w
            graph_input = tf.reshape(projected, [batch_size, num_nodes, self.output_channels])
            graph_output = self.graph_convs[i](graph_input, adjacency)
            spatial_output = tf.reshape(graph_output, [batch_size, target_h, target_w, self.output_channels])

            spatial_output = self.batch_norms[i](spatial_output, training=training)
            spatial_output = self.activations[i](spatial_output)
            scale_features.append(spatial_output)

        # Progressive coarse-to-fine fusion, starting from the 1/16 level.
        fused_features = scale_features[-1]
        for finer in reversed(scale_features[:-1]):
            upsampled = tf.image.resize(
                fused_features,
                [tf.shape(finer)[1], tf.shape(finer)[2]],
                method='bilinear'
            )
            fused_features = upsampled + finer

        # Back to the input resolution, then the final channel projection.
        final_features = tf.image.resize(fused_features, [orig_height, orig_width], method='bilinear')
        return self.final_projection(final_features)

    def _create_grid_adjacency_tf(self, height, width):
        """Return the (H*W, H*W) matrix D^-1/2 (A + I) D^-1/2 of a 4-connected height x width grid."""
        num_nodes = height * width

        # Row/column coordinate of each node (row-major ordering).
        indices = tf.range(num_nodes, dtype=tf.int32)
        i_coords = indices // width
        j_coords = indices % width
        coords = tf.stack([i_coords, j_coords], axis=1)  # [num_nodes, 2]

        # Pairwise Manhattan distances: [num_nodes, num_nodes]
        coord_diff = tf.expand_dims(coords, 1) - tf.expand_dims(coords, 0)
        manhattan_dist = tf.reduce_sum(tf.abs(coord_diff), axis=2)

        # 4-neighbour edges (distance 1) plus self-loops (distance 0).
        adjacency = tf.cast(tf.equal(manhattan_dist, 1), tf.float32)
        adjacency = adjacency + tf.cast(tf.equal(manhattan_dist, 0), tf.float32)

        # Symmetric normalization: D^(-1/2) * A * D^(-1/2)
        degree = tf.reduce_sum(adjacency, axis=1)
        degree_inv_sqrt = tf.pow(degree + 1e-8, -0.5)  # Add epsilon for numerical stability

        return (adjacency
                * tf.expand_dims(degree_inv_sqrt, 1)
                * tf.expand_dims(degree_inv_sqrt, 0))

    def get_config(self):
        """Return the config of the layer for serialization"""
        config = super(SPGR, self).get_config()
        config.update({
            'output_channels': self.output_channels
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Create layer from config"""
        return cls(**config)


# ==================== MODIFIED RESUNET WITH SPGR ====================

def conv_block(inputs, filters, kernel_size=3, strides=1, padding='same'):
    """Apply Conv2D -> BatchNormalization -> ReLU to *inputs* and return the result."""
    pipeline = (
        layers.Conv2D(filters, kernel_size, strides=strides, padding=padding),
        layers.BatchNormalization(),
        layers.ReLU(),
    )
    out = inputs
    for stage in pipeline:
        out = stage(out)
    return out

def channel_attention(inputs, ratio=16):
    """Squeeze-and-excitation: rescale each channel of *inputs* by a learned sigmoid gate.

    The bottleneck width is ``channels // ratio``, but never less than 8.
    """
    num_channels = int(inputs.shape[-1])
    bottleneck = max(num_channels // ratio, 8)

    # Squeeze: global average per channel, kept as a 1x1 map.
    pooled = layers.GlobalAveragePooling2D()(inputs)
    squeezed = layers.Reshape((1, 1, num_channels))(pooled)

    # Excite: bottleneck 1x1 convs produce one gate value per channel.
    hidden = layers.Conv2D(bottleneck, kernel_size=1, activation='relu', padding='same')(squeezed)
    gate = layers.Conv2D(num_channels, kernel_size=1, activation='sigmoid', padding='same')(hidden)

    return layers.Multiply()([inputs, gate])

def spatial_attention(inputs):
    """Spatial attention: rescale each pixel of ``inputs`` by a learned map.

    Two single-channel descriptors are computed from ``inputs``:
      * a bias-free 1x1 convolution whose kernel starts as all ones (it
        initially sums the channels but is trainable), and
      * a 7x7 convolution with ReLU.
    They are concatenated and fed to a 7x7 sigmoid convolution, which
    produces the attention map.

    NOTE(review): despite the CBAM-like layout, neither descriptor is a true
    channel-wise average or max pool. Replacing them with pooling ops would
    change this block's trainable weights.
    """
    channel_sum = layers.Conv2D(1, kernel_size=1, padding='same', use_bias=False,
                                kernel_initializer='ones')(inputs)
    local_response = layers.Conv2D(1, kernel_size=7, padding='same', activation='relu')(inputs)
    descriptors = layers.Concatenate(axis=-1)([channel_sum, local_response])
    weight_map = layers.Conv2D(1, kernel_size=7, padding='same', activation='sigmoid')(descriptors)
    return layers.Multiply()([inputs, weight_map])

def attention_residual_block(inputs, filters, kernel_size=3, strides=1):
    """Residual block with channel and spatial attention on the main path.

    Main path: two ``conv_block`` calls (the first uses ``strides``), then
    channel attention, then spatial attention.

    Shortcut: the identity, unless the stride or channel count changes. In
    that case a strided 1x1 convolution plus batch normalization projects
    ``inputs`` to the matching shape.

    The two paths are summed and passed through ReLU.
    """
    residual = conv_block(inputs, filters, kernel_size, strides)
    residual = conv_block(residual, filters, kernel_size, 1)
    residual = channel_attention(residual)
    residual = spatial_attention(residual)

    needs_projection = strides > 1 or inputs.shape[-1] != filters
    shortcut = inputs
    if needs_projection:
        shortcut = layers.Conv2D(filters, kernel_size=1, strides=strides, padding='same')(inputs)
        shortcut = layers.BatchNormalization()(shortcut)

    merged = layers.Add()([residual, shortcut])
    return layers.ReLU()(merged)

def build_spgr_attention_resunet(input_shape=(256, 256, 3), num_classes=1):
    """Build the attention ResUNet with an SPGR block at the bridge.

    Structure:
      * A 7x7 stem, then three encoder stages (64, 128, 256 filters). Each
        stage keeps its output as a skip connection and then max-pools by 2.
      * A 512-filter bridge whose output is summed with an ``SPGR`` layer
        applied to it, then batch-normalized and ReLU-activated.
      * Three decoder stages (256, 128, 64 filters). Each upsamples by 2,
        convolves, concatenates the matching skip and applies an attention
        residual block.
      * A 1x1 sigmoid convolution with ``num_classes`` output channels.

    Spatial dimensions should be divisible by 8 so that the upsampled maps
    line up with the skip connections.

    Prints a progress line when the SPGR layer is added.

    Returns:
        A ``Model`` named 'SPGR_ResUNet'.
    """
    inputs = layers.Input(input_shape)
    x = conv_block(inputs, 64, kernel_size=7, strides=1)

    # Encoder: record each stage's output as a skip, then downsample.
    skips = []
    for stage_filters in (64, 128, 256):
        skip = attention_residual_block(x, stage_filters)
        skips.append(skip)
        x = layers.MaxPooling2D(2)(skip)

    bridge = attention_residual_block(x, 512)

    # Graph reasoning on the bridge features, fused back residually.
    print(" Adding SPGR (Spatial Pyramid Graph Reasoning) layer...")
    spgr_features = SPGR(output_channels=512, name='spgr_bridge')(bridge)

    fused = layers.Add(name='spgr_fusion')([bridge, spgr_features])
    fused = layers.BatchNormalization()(fused)
    x = layers.ReLU()(fused)

    # Decoder: mirror the encoder, consuming skips deepest-first.
    for stage_filters, skip in zip((256, 128, 64), reversed(skips)):
        x = layers.UpSampling2D(2)(x)
        x = conv_block(x, stage_filters)
        x = layers.Concatenate()([x, skip])
        x = attention_residual_block(x, stage_filters)

    outputs = layers.Conv2D(num_classes, kernel_size=1, activation='sigmoid')(x)
    return models.Model(inputs=inputs, outputs=outputs, name='SPGR_ResUNet')

# ==================== DATA LOADING FUNCTIONS ====================

def load_image(image_path):
    """Read a preprocessed image array from a ``.npy`` file as float32.

    Accepts a plain path or a scalar string tensor, as passed in by
    ``tf.py_function``. Tensors are first decoded from UTF-8 bytes.
    """
    path = image_path.numpy().decode('utf-8') if isinstance(image_path, tf.Tensor) else image_path
    array = np.load(path)
    return array.astype(np.float32)

def load_mask(mask_path):
    """Read a ground-truth mask from a ``.npy`` file and cast it to float32.

    ``mask_path`` may be a plain path or a scalar string tensor from
    ``tf.py_function``, which is decoded from UTF-8 first. Stored label
    values are kept as-is apart from the dtype cast.
    """
    path = mask_path
    if isinstance(path, tf.Tensor):
        path = path.numpy().decode('utf-8')
    return np.load(path).astype(np.float32)

def load_validity_mask(validity_path):
    """Read a per-pixel validity mask from a ``.npy`` file as float32.

    ``validity_path`` may be a plain path or a scalar string tensor from
    ``tf.py_function``. Tensors are decoded from UTF-8 first.
    """
    is_tensor = isinstance(validity_path, tf.Tensor)
    path = validity_path.numpy().decode('utf-8') if is_tensor else validity_path
    raw = np.load(path)
    return raw.astype(np.float32)

def augment_data(image, mask, validity_mask):
    """Randomly augment an (image, mask, validity_mask) training triple.

    Each of the five transforms below is gated by its own
    ``tf.random.uniform([]) > 0.5`` draw. Each therefore fires with
    probability of about 0.5, independently of the others.

    Geometric transforms (horizontal flip, vertical flip, rotation by a
    random multiple of 90 degrees) are applied identically to all three
    tensors, so the labels and validity mask stay pixel-aligned with the
    image. Photometric transforms (brightness delta up to 0.1, contrast
    factor in [0.9, 1.1]) change only ``image``.

    NOTE(review): the Python ``if`` statements test tensors. This relies on
    AutoGraph converting them to graph conditionals when the function is
    traced by ``Dataset.map``.

    Args:
        image: Image tensor, presumably (H, W, C); ``create_dataset`` feeds
            (256, 256, 3).
        mask: Ground-truth mask with the same H and W, presumably (H, W, 1).
        validity_mask: Validity mask with the same H and W, presumably (H, W, 1).

    Returns:
        The augmented ``(image, mask, validity_mask)`` tuple.
    """
    # Random horizontal flip
    if tf.random.uniform([]) > 0.5:
        image = tf.image.flip_left_right(image)
        mask = tf.image.flip_left_right(mask)
        validity_mask = tf.image.flip_left_right(validity_mask)
    
    # Random vertical flip
    if tf.random.uniform([]) > 0.5:
        image = tf.image.flip_up_down(image)
        mask = tf.image.flip_up_down(mask)
        validity_mask = tf.image.flip_up_down(validity_mask)
    
    # Random rotation (90, 180, 270 degrees); k is drawn from {1, 2, 3}
    # because maxval is exclusive, and the same k is used for all three tensors.
    if tf.random.uniform([]) > 0.5:
        k = tf.random.uniform([], minval=1, maxval=4, dtype=tf.int32)
        image = tf.image.rot90(image, k=k)
        mask = tf.image.rot90(mask, k=k)
        validity_mask = tf.image.rot90(validity_mask, k=k)
    
    # Random brightness adjustment (only for image); random_brightness draws
    # its own delta in [-0.1, 0.1] on top of the gate above.
    if tf.random.uniform([]) > 0.5:
        image = tf.image.random_brightness(image, max_delta=0.1)
    
    # Random contrast adjustment (only for image); factor drawn in [0.9, 1.1].
    if tf.random.uniform([]) > 0.5:
        image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    
    return image, mask, validity_mask

def create_dataset(base_path, split, batch_size=16, shuffle=True, augment=False,
                   img_size=256, num_channels=3):
    """Build a repeating, batched ``tf.data`` pipeline for one split.

    Reads ``{base_path}/{split}/{images,masks,validity_masks}/*.npy``. Files
    in each folder are sorted by name and paired by position, so the three
    folders must contain the same number of files in matching order.

    Args:
        base_path: Root directory of the preprocessed dataset.
        split: Sub-folder name ('train', 'val' or 'test').
        batch_size: Number of samples per batch.
        shuffle: Shuffle the file triples (seed 42) before loading.
        augment: Apply ``augment_data``. Only honoured when ``split == 'train'``.
        img_size: Expected height and width of every array (default 256).
        num_channels: Expected number of image channels (default 3).

    Returns:
        ``(dataset, num_samples)``. Dataset elements are batches of
        ``(image, mask, validity_mask)`` with per-sample shapes
        ``[img_size, img_size, num_channels]``, ``[img_size, img_size, 1]``
        and ``[img_size, img_size, 1]``. The dataset repeats forever, so
        callers must bound iteration (e.g. with ``steps_per_epoch``).

    Raises:
        ValueError: If any of the three folders is empty, or if their file
            counts differ.
    """
    img_paths = sorted(glob.glob(os.path.join(base_path, split, 'images', '*.npy')))
    mask_paths = sorted(glob.glob(os.path.join(base_path, split, 'masks', '*.npy')))
    validity_paths = sorted(glob.glob(os.path.join(base_path, split, 'validity_masks', '*.npy')))

    if len(img_paths) == 0 or len(mask_paths) == 0 or len(validity_paths) == 0:
        raise ValueError(f"No images, masks, or validity masks found in {base_path}/{split}")

    print(f"Found {len(img_paths)} images, {len(mask_paths)} masks, and {len(validity_paths)} validity masks for {split}")

    # Files are paired purely by sorted position. If the counts differ,
    # Dataset.zip would silently truncate to the shortest list and every pair
    # after a missing file would be misaligned.
    if not len(img_paths) == len(mask_paths) == len(validity_paths):
        raise ValueError(
            f"Mismatched file counts in {base_path}/{split}: "
            f"{len(img_paths)} images, {len(mask_paths)} masks, "
            f"{len(validity_paths)} validity masks"
        )

    # Combine the three path lists into (image, mask, validity) triples.
    dataset = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(img_paths),
        tf.data.Dataset.from_tensor_slices(mask_paths),
        tf.data.Dataset.from_tensor_slices(validity_paths),
    ))

    # Shuffle paths (cheap) before loading arrays; happens before repeat(), so
    # every pass over the data is a full permutation of the split.
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(img_paths), seed=42)

    # np.load is not a TF op, so the loaders run through py_function.
    dataset = dataset.map(
        lambda img_path, mask_path, validity_path: (
            tf.py_function(func=load_image, inp=[img_path], Tout=tf.float32),
            tf.py_function(func=load_mask, inp=[mask_path], Tout=tf.float32),
            tf.py_function(func=load_validity_mask, inp=[validity_path], Tout=tf.float32),
        ),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # py_function loses static shapes; restore (and verify) them.
    dataset = dataset.map(
        lambda x, y, v: (
            tf.ensure_shape(x, [img_size, img_size, num_channels]),
            tf.ensure_shape(y, [img_size, img_size]),
            tf.ensure_shape(v, [img_size, img_size])
        )
    )

    # Add channel dimension to masks
    dataset = dataset.map(lambda x, y, v: (x, tf.expand_dims(y, axis=-1), tf.expand_dims(v, axis=-1)))

    # Apply data augmentation for training set
    if augment and split == 'train':
        print(f"Applying data augmentation to {split} dataset")
        dataset = dataset.map(augment_data, num_parallel_calls=tf.data.AUTOTUNE)

    # Repeat forever so step-bounded epochs never exhaust the iterator.
    dataset = dataset.repeat()

    # Batch and prefetch
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return dataset, len(img_paths)

# ==================== MASKED LOSS AND METRICS ====================

def masked_dice_coefficient(y_true, y_pred, validity_mask, smooth=1e-6):
    """Soft Dice coefficient restricted to pixels where ``validity_mask`` is set.

    ``y_true`` and ``y_pred`` are multiplied by ``validity_mask`` before
    flattening, and every sum is weighted by the mask again. Masked-out pixels
    therefore contribute nothing. ``smooth`` keeps the ratio finite (it equals
    1 when there are no positive pixels at all).

    Returns:
        Scalar tensor ``(2 * intersection + smooth) / (|T| + |P| + smooth)``.
    """
    mask_flat = K.flatten(validity_mask)
    truth_flat = K.flatten(y_true * validity_mask)
    pred_flat = K.flatten(y_pred * validity_mask)

    overlap = K.sum(truth_flat * pred_flat * mask_flat)
    total = K.sum(truth_flat * mask_flat) + K.sum(pred_flat * mask_flat)
    return (2. * overlap + smooth) / (total + smooth)

def masked_dice_loss(y_true_and_validity, y_pred):
    """Dice loss (``1 - Dice``) that ignores invalid pixels.

    ``y_true_and_validity`` packs the ground truth in channel 0 and the
    validity mask in channel 1 of its last axis.
    """
    labels = y_true_and_validity[..., 0:1]
    valid = y_true_and_validity[..., 1:2]
    dice = masked_dice_coefficient(labels, y_pred, valid)
    return 1 - dice

def masked_focal_tversky_loss(y_true_and_validity, y_pred, alpha=0.7, beta=0.3, gamma=1.5, smooth=1e-6):
    """Focal Tversky loss computed only over valid pixels.

    Channel 0 of ``y_true_and_validity`` holds the labels and channel 1 the
    validity mask. Soft TP/FN/FP counts are summed over the whole batch, and
    the Tversky index is
    ``(TP + smooth) / (TP + alpha * FN + beta * FP + smooth)``. With the
    defaults, false negatives are penalized more heavily than false positives.

    Returns:
        Scalar tensor ``(1 - tversky) ** gamma``.

    NOTE(review): the original Focal Tversky paper uses exponent ``1/gamma``.
    This code uses ``gamma`` directly. Confirm which convention is intended.
    """
    labels = y_true_and_validity[..., 0:1]
    valid = y_true_and_validity[..., 1:2]

    valid_flat = K.flatten(valid)
    labels_flat = K.flatten(labels * valid)
    probs_flat = K.flatten(y_pred * valid)

    tp = K.sum(labels_flat * probs_flat * valid_flat)
    fn = K.sum(labels_flat * (1 - probs_flat) * valid_flat)
    fp = K.sum((1 - labels_flat) * probs_flat * valid_flat)

    tversky_index = (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)
    return K.pow((1 - tversky_index), gamma)

def masked_iou_score(y_true_and_validity, y_pred, smooth=1e-6):
    """Soft IoU (Jaccard) score over valid pixels only.

    Channel 0 of ``y_true_and_validity`` holds the labels and channel 1 the
    validity mask. ``y_pred`` holds probabilities and is not thresholded.

    Returns:
        Scalar tensor ``(I + smooth) / (|T| + |P| - I + smooth)``.
    """
    labels, valid = y_true_and_validity[..., 0:1], y_true_and_validity[..., 1:2]

    valid_flat = K.flatten(valid)
    labels_flat = K.flatten(labels * valid)
    probs_flat = K.flatten(y_pred * valid)

    inter = K.sum(labels_flat * probs_flat * valid_flat)
    union = K.sum(labels_flat * valid_flat) + K.sum(probs_flat * valid_flat) - inter
    return (inter + smooth) / (union + smooth)

def masked_binary_accuracy(y_true_and_validity, y_pred):
    """Pixel accuracy at a 0.5 threshold, counted over valid pixels only.

    Channel 0 of ``y_true_and_validity`` holds the labels and channel 1 the
    validity mask. Correct pixels are weighted by the mask and divided by the
    mask's sum (plus ``K.epsilon()``, to avoid division by zero).
    """
    labels = y_true_and_validity[..., 0:1]
    valid = y_true_and_validity[..., 1:2]

    predicted_class = K.cast((y_pred * valid) > 0.5, K.floatx())
    hits = K.cast(K.equal(labels * valid, predicted_class), K.floatx()) * valid
    return K.sum(hits) / (K.sum(valid) + K.epsilon())

def masked_f1_score_metric(y_true_and_validity, y_pred, smooth=1e-6):
    """Soft F1 score (harmonic mean of soft precision and recall) on valid pixels.

    Channel 0 of ``y_true_and_validity`` holds the labels and channel 1 the
    validity mask. Predictions are probabilities and are not thresholded.
    ``smooth`` guards every division.
    """
    labels = y_true_and_validity[..., 0:1]
    valid = y_true_and_validity[..., 1:2]

    valid_flat = K.flatten(valid)
    labels_flat = K.flatten(labels * valid)
    probs_flat = K.flatten(y_pred * valid)

    tp = K.sum(labels_flat * probs_flat * valid_flat)
    pred_pos = K.sum(probs_flat * valid_flat)
    true_pos_total = K.sum(labels_flat * valid_flat)

    prec = (tp + smooth) / (pred_pos + smooth)
    rec = (tp + smooth) / (true_pos_total + smooth)
    return 2 * (prec * rec) / (prec + rec + smooth)

def masked_precision_metric(y_true_and_validity, y_pred, smooth=1e-6):
    """Soft precision over valid pixels.

    Computes ``(TP + smooth) / (predicted positives + smooth)`` from
    unthresholded probabilities. Channel 0 of ``y_true_and_validity`` holds
    the labels and channel 1 the validity mask.
    """
    labels = y_true_and_validity[..., 0:1]
    valid = y_true_and_validity[..., 1:2]

    valid_flat = K.flatten(valid)
    labels_flat = K.flatten(labels * valid)
    probs_flat = K.flatten(y_pred * valid)

    tp = K.sum(labels_flat * probs_flat * valid_flat)
    return (tp + smooth) / (K.sum(probs_flat * valid_flat) + smooth)

def masked_recall_metric(y_true_and_validity, y_pred, smooth=1e-6):
    """Soft recall over valid pixels.

    Computes ``(TP + smooth) / (actual positives + smooth)`` from
    unthresholded probabilities. Channel 0 of ``y_true_and_validity`` holds
    the labels and channel 1 the validity mask.
    """
    labels = y_true_and_validity[..., 0:1]
    valid = y_true_and_validity[..., 1:2]

    valid_flat = K.flatten(valid)
    labels_flat = K.flatten(labels * valid)
    probs_flat = K.flatten(y_pred * valid)

    tp = K.sum(labels_flat * probs_flat * valid_flat)
    return (tp + smooth) / (K.sum(labels_flat * valid_flat) + smooth)

# Keras reports function metrics under the function's name, so this wrapper
# gives the Dice metric a distinct, stable name for compile().
def masked_dice_coefficient_metric(y_true_and_validity, y_pred):
    """Metric wrapper around ``masked_dice_coefficient``.

    Splits the packed label tensor (channel 0 = labels, channel 1 = validity)
    into its two parts.
    """
    return masked_dice_coefficient(
        y_true_and_validity[..., 0:1],
        y_pred,
        y_true_and_validity[..., 1:2],
    )

# ==================== LEARNING RATE SCHEDULING ====================

def cosine_annealing_warm_restarts(epoch, initial_lr=0.001, min_lr=1e-6, T_0=50, T_mult=2):
    """Cosine annealing with warm restarts, for ``LearningRateScheduler``.

    The first cycle lasts ``T_0`` epochs, and each following cycle is
    ``T_mult`` times longer than the previous one. Within a cycle the rate
    decays from ``initial_lr`` towards ``min_lr`` along a half cosine. It
    jumps back to ``initial_lr`` at the start of the next cycle.

    Args:
        epoch: Zero-based epoch index.
        initial_lr: Learning rate at the start of every cycle.
        min_lr: Lower bound approached at the end of every cycle.
        T_0: Length of the first cycle in epochs. Must be positive.
        T_mult: Growth factor of the cycle length. Must be >= 1.

    Returns:
        The learning rate (float) for ``epoch``.

    Raises:
        ValueError: If ``T_0 <= 0`` or ``T_mult < 1``. With such values the
            cycle-locating loop below never terminates (cycle length stops
            growing or shrinks towards zero).
    """
    if T_0 <= 0:
        raise ValueError(f"T_0 must be positive, got {T_0}")
    if T_mult < 1:
        raise ValueError(f"T_mult must be >= 1, got {T_mult}")

    # Skip whole completed cycles to find the offset inside the current cycle
    # (T_cur) and that cycle's length (T_i).
    T_cur = epoch
    T_i = T_0
    while T_cur >= T_i:
        T_cur -= T_i
        T_i *= T_mult

    return min_lr + (initial_lr - min_lr) * (1 + math.cos(math.pi * T_cur / T_i)) / 2

def prepare_labels_for_training(dataset):
    """Merge ground-truth and validity masks into a single label tensor.

    Maps ``(image, gt_mask, validity_mask)`` elements to
    ``(image, concat([gt_mask, validity_mask], axis=-1))``. The masked losses
    and metrics can then split channel 0 (labels) from channel 1 (validity).
    """
    def _pack(image, gt_mask, validity_mask):
        return image, tf.concat([gt_mask, validity_mask], axis=-1)

    return dataset.map(_pack)

def load_spgr_model(weights_path, input_shape=(256, 256, 3)):
    """Rebuild the SPGR ResUNet and load saved weights into it.

    Args:
        weights_path: Path to a weights file written by ``save_weights``
            (``.weights.h5``).
        input_shape: Input shape used to rebuild the architecture.

    Returns:
        The model with weights loaded. It is compiled with placeholder
        settings (adam / binary_crossentropy / accuracy); recompile it with
        the masked loss and metrics before further training or evaluation.
    """
    spgr_model = build_spgr_attention_resunet(input_shape)
    spgr_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    # One forward pass on zeros so every layer's variables exist before loading.
    spgr_model(tf.zeros((1, *input_shape)))

    spgr_model.load_weights(weights_path)
    return spgr_model

# ==================== MAIN EXECUTION ====================

# Configuration: learning-rate schedule settings. In this script only
# initial_lr, min_lr, T_0 and T_mult are passed to the scheduler; the whole
# dict is also recorded in the saved test-metrics JSON.
COSINE_ANNEALING_CONFIG = dict(
    method="warm_restarts_scheduler",
    initial_lr=0.0005,
    min_lr=5e-7,
    T_max=200,
    T_0=50,
    T_mult=2,
)

# Start-of-run banner summarizing the configuration being trained.
print("\n".join([
    " TRAINING SPGR-ENHANCED RESUNET WITH MASKED LOSS",
    "=" * 60,
    " Spatial Pyramid Graph Reasoning (SPGR) integrated at bridge",
    " Three-level pyramid: 1/4, 1/8, 1/16 scales",
    " Graph convolution with Chebyshev polynomials (order=3)",
    " Progressive feature fusion with skip connections",
    " Masked loss excluding invalid ground truth pixels",
    "=" * 60,
]))

# Use a batch size that fits in GPU memory
BATCH_SIZE = 16

# Create datasets with validity masks. Only the training split is shuffled.
# Validation/test keep a fixed file order: their step counts are floored, so
# a reshuffled split could be scored on a different subset each time and
# epochs would not be comparable.
print("Creating datasets with validity mask support...")
train_dataset, train_size = create_dataset(BASE_PATH, 'train', batch_size=BATCH_SIZE, shuffle=True, augment=True)
val_dataset, val_size = create_dataset(BASE_PATH, 'val', batch_size=BATCH_SIZE, shuffle=False, augment=False)
test_dataset, test_size = create_dataset(BASE_PATH, 'test', batch_size=BATCH_SIZE, shuffle=False, augment=False)

print(f"Training dataset size: {train_size} images")
print(f"Validation dataset size: {val_size} images")
print(f"Test dataset size: {test_size} images")

# Build SPGR-enhanced model
print("\n Building SPGR-Enhanced ResUNet...")
input_shape = (256, 256, 3)
model = build_spgr_attention_resunet(input_shape)

# The datasets repeat forever, so each epoch is bounded by a step count:
# the number of full batches per split (floored), with a minimum of one.
steps_per_epoch, validation_steps = (
    max(1, num_samples // BATCH_SIZE) for num_samples in (train_size, val_size)
)
print(f"Steps per epoch: {steps_per_epoch}\nValidation steps: {validation_steps}")

# Optimizer plus a per-epoch cosine-annealing-with-warm-restarts schedule.
initial_lr = COSINE_ANNEALING_CONFIG["initial_lr"]
optimizer = optimizers.Adam(learning_rate=initial_lr)


def lr_schedule_func(epoch):
    """Map an epoch index to a learning rate using COSINE_ANNEALING_CONFIG."""
    return cosine_annealing_warm_restarts(
        epoch,
        initial_lr=initial_lr,
        min_lr=COSINE_ANNEALING_CONFIG["min_lr"],
        T_0=COSINE_ANNEALING_CONFIG["T_0"],
        T_mult=COSINE_ANNEALING_CONFIG["T_mult"],
    )


lr_callback = callbacks.LearningRateScheduler(lr_schedule_func, verbose=1)

# Pack (gt_mask, validity_mask) into one two-channel label tensor per split,
# as expected by the masked loss and metrics.
print("Preparing datasets for masked training...")
train_dataset_prepared, val_dataset_prepared, test_dataset_prepared = (
    prepare_labels_for_training(ds) for ds in (train_dataset, val_dataset, test_dataset)
)

# Compile model with masked loss and metrics
print("Compiling SPGR model with MASKED loss and metrics...")

compile_kwargs = {
    'optimizer': optimizer,
    'loss': masked_focal_tversky_loss,
    'metrics': [
        masked_dice_coefficient_metric,
        masked_iou_score,
        masked_binary_accuracy,
        masked_f1_score_metric,
        masked_precision_metric,
        masked_recall_metric
    ],
    # Disable XLA compilation to avoid sparse operation issues.
    'jit_compile': False,
}

# Setting a dict key can never fail, so a version check only makes sense at
# compile time. Older `Model.compile` signatures reject the unknown
# `jit_compile` keyword with a TypeError; retry without it in that case.
try:
    model.compile(**compile_kwargs)
    print(" Disabled XLA/JIT compilation for sparse operations compatibility")
except TypeError:
    print(" Could not disable XLA compilation (older TF version)")
    compile_kwargs.pop('jit_compile')
    model.compile(**compile_kwargs)

# Display model summary
print("\nSPGR-Enhanced ResUNet Architecture:")
model.summary()

# Output locations for the best checkpoint, its weights and TensorBoard logs.
# Weight files use the `.weights.h5` suffix expected by `save_weights`.
checkpoint_path, weights_path, log_dir = (
    os.path.join(OUTPUT_PATH, file_name)
    for file_name in (
        "best_spgr_model_masked.keras",
        "best_spgr_weights_masked.weights.h5",
        "logs_spgr_masked",
    )
)
os.makedirs(log_dir, exist_ok=True)

# Custom callback to save both model and weights
class ModelAndWeightsSaver(callbacks.Callback):
    """Save the full model and its weights whenever a monitored metric improves.

    Args:
        model_path: Destination for ``model.save`` (e.g. a ``.keras`` file).
        weights_path: Destination for ``model.save_weights``.
        monitor: Key in the epoch-end ``logs`` dict to track.
        mode: ``'max'`` if larger values are better, ``'min'`` if smaller are.

    Raises:
        ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.

    Saving is best-effort: failures are printed and training continues.
    Epochs whose ``logs`` lack ``monitor`` are skipped.
    """

    def __init__(self, model_path, weights_path, monitor='val_masked_iou_score', mode='max'):
        super().__init__()
        if mode not in ('max', 'min'):
            # Any other value would silently disable saving, because neither
            # improvement comparison could ever succeed.
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.model_path = model_path
        self.weights_path = weights_path
        self.monitor = monitor
        self.mode = mode
        self.best_value = -np.inf if mode == 'max' else np.inf

    def _is_improvement(self, value):
        """Return True if ``value`` beats the best value seen so far."""
        if self.mode == 'max':
            return value > self.best_value
        return value < self.best_value

    def on_epoch_end(self, epoch, logs=None):
        current_value = (logs or {}).get(self.monitor)
        if current_value is None or not self._is_improvement(current_value):
            return
        self.best_value = current_value

        # Model and weights are saved independently, so a failure in one is
        # reported accurately and does not block the other.
        try:
            self.model.save(self.model_path)
        except Exception as e:  # best-effort: never abort training over a checkpoint
            print(f"\n Warning: Could not save model: {e}")
            model_saved = False
        else:
            model_saved = True

        try:
            self.model.save_weights(self.weights_path)
        except Exception as we:
            print(f" Failed to save weights: {we}")
            return

        if model_saved:
            print(f"\n Saved best model and weights (epoch {epoch+1}, {self.monitor}: {current_value:.4f})")
        else:
            print(f" Saved weights only (epoch {epoch+1})")

# Callbacks run in list order: best-checkpoint saver, early stopping,
# TensorBoard logging, CSV logging, then the learning-rate scheduler.
callbacks_list = [ModelAndWeightsSaver(checkpoint_path, weights_path)]
# NOTE(review): patience=70 is larger than the `epochs = 20` passed to
# model.fit, so this callback is unlikely to stop training early.
callbacks_list.append(
    callbacks.EarlyStopping(monitor='val_masked_iou_score', mode='max',
                            patience=70, restore_best_weights=True)
)
callbacks_list.append(
    callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, update_freq='epoch',
                          write_graph=True, write_images=True, profile_batch=0)
)
callbacks_list.append(
    callbacks.CSVLogger(os.path.join(OUTPUT_PATH, 'training_log_spgr_masked.csv'),
                        separator=',', append=False)
)
callbacks_list.append(lr_callback)

print("\n Training SPGR-Enhanced ResUNet...")
print(f"Initial LR: {initial_lr}, Min LR: {COSINE_ANNEALING_CONFIG['min_lr']}")

# Train model. The epoch budget is small because the SPGR bridge makes the
# model more expensive. Both datasets repeat forever, so the step counts
# bound each training and validation epoch.
epochs = 20
history = model.fit(
    train_dataset_prepared,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=callbacks_list,
    validation_data=val_dataset_prepared,
    validation_steps=validation_steps,
)

# Save the end-of-training model and weights. Each save is best-effort: a
# failure is reported, not raised, and does not prevent the other save.
print("Saving trained model...")
try:
    model.save(os.path.join(OUTPUT_PATH, 'spgr_resunet_model_masked.keras'))
except Exception as e:
    print(f" Warning: Could not save full model: {e}")
else:
    print(" Model saved successfully")

try:
    model.save_weights(os.path.join(OUTPUT_PATH, 'spgr_training_weights.weights.h5'))
except Exception as e:
    print(f" Warning: Could not save training weights: {e}")
else:
    print(" Training weights saved successfully")

# End-of-training summary banner.
print("\n".join([
    "\n" + "=" * 80,
    " SPGR-ENHANCED RESUNET TRAINING COMPLETE!",
    "=" * 80,
    " Spatial Pyramid Graph Reasoning successfully integrated",
    " Multi-scale graph convolution with Chebyshev polynomials",
    "Progressive feature fusion across pyramid levels",
    " Masked training excluding invalid ground truth pixels",
    f" Model saved to: {OUTPUT_PATH}",
    "=" * 80,
]))

# Load and evaluate best model
print("Loading best SPGR model...")
try:
    # Deserializing the full .keras checkpoint needs every custom loss,
    # metric and layer registered here.
    best_model = models.load_model(checkpoint_path, custom_objects={
        'masked_dice_coefficient_metric': masked_dice_coefficient_metric,
        'masked_focal_tversky_loss': masked_focal_tversky_loss,
        'masked_iou_score': masked_iou_score,
        'masked_binary_accuracy': masked_binary_accuracy,
        'masked_f1_score_metric': masked_f1_score_metric,
        'masked_precision_metric': masked_precision_metric,
        'masked_recall_metric': masked_recall_metric,
        'SPGR': SPGR,
        'GraphConvLayer': GraphConvLayer
    })
    print(" Successfully loaded model from checkpoint")
except Exception as e:
    print(f" Failed to load from checkpoint: {e}")
    print(" Rebuilding model and loading weights...")

    # Rebuild the model architecture
    best_model = build_spgr_attention_resunet(input_shape)
    best_model.compile(
        # Fresh optimizer instance: the training `optimizer` has been used to
        # update `model`'s variables and should not be shared with a second
        # model.
        optimizer=optimizers.Adam(learning_rate=initial_lr),
        loss=masked_focal_tversky_loss,
        metrics=[
            masked_dice_coefficient_metric,
            masked_iou_score,
            masked_binary_accuracy,
            masked_f1_score_metric,
            masked_precision_metric,
            masked_recall_metric
        ]
    )

    try:
        # One forward pass so every layer's variables exist before loading.
        best_model(tf.zeros((1, *input_shape)))

        # Best-epoch weights written by ModelAndWeightsSaver.
        best_model.load_weights(weights_path)
        print(" Successfully loaded weights")
    except Exception as weight_error:
        print(f" Failed to load weights: {weight_error}")
        print("Using the last trained model instead...")
        best_model = model

# Evaluate on test set
print("Evaluating SPGR model on test set...")
# Kept for compatibility with later cells; the evaluation below does not use it.
test_steps = max(1, test_size // BATCH_SIZE)

# `test_dataset_prepared` repeats forever. Evaluating it for
# `test_size // BATCH_SIZE` steps silently drops the remainder samples.
# Instead, unbatch, take exactly one pass over the split (each pass is a
# full permutation because any shuffling happens before repeat()), and
# re-batch. Every test sample is then scored once; the last batch may be
# smaller than BATCH_SIZE.
test_eval_dataset = test_dataset_prepared.unbatch().take(test_size).batch(BATCH_SIZE)

# return_dict=True pairs each value with its metric name directly, instead of
# zipping against `metrics_names`, whose contents vary between Keras versions.
test_metrics_by_name = best_model.evaluate(test_eval_dataset, return_dict=True)
test_results = list(test_metrics_by_name.values())

print("\nSPGR Test Results (Invalid pixels excluded):")
for metric_name, value in test_metrics_by_name.items():
    print(f"{metric_name}: {value:.4f}")

# Save test results
test_metrics = {name: float(val) for name, val in test_metrics_by_name.items()}
test_metrics['spgr_enhanced'] = True
test_metrics['cosine_annealing_config'] = COSINE_ANNEALING_CONFIG
test_metrics['masked_training'] = True

with open(os.path.join(OUTPUT_PATH, 'spgr_test_metrics_masked.json'), 'w') as f:
    json.dump(test_metrics, f, indent=4)

# Save the evaluated model's weights on their own as well, for easy reloading
# with `load_spgr_model` (best-effort: a failure is only reported).
print("Saving final model weights separately...")
final_weights_file = os.path.join(OUTPUT_PATH, 'spgr_final_model_weights.weights.h5')
try:
    best_model.save_weights(final_weights_file)
except Exception as e:
    print(f" Warning: Could not save final model weights: {e}")
else:
    print(" Final model weights saved separately")