In [None]:
# Dataset preprocessing with derived bands NewBand1 and NewBand2



!pip install rasterio

import os
import csv
import numpy as np
import rasterio
from PIL import Image
from tqdm import tqdm
import shutil
from skimage.transform import resize

class ResUNetPreprocessor:
    """Convert Sen1Floods11 hand-labelled S1 chips into a ResUNet-ready dataset.

    For every (image, mask) pair listed in a split CSV the preprocessor:
      1. reads VH (band 1) and VV (band 2) from the S1Hand raster,
      2. derives three bands (see ``compute_bands``),
      3. clips each band to its own 1st-99th percentile range,
      4. resizes to ``img_size`` x ``img_size``,
      5. standardises with per-band mean/std computed on the training split.

    Outputs written under ``output_path/<split>/``:
      * ``images/<split>_NNNN.npy`` - normalised float array, shape (H, W, 3)
      * ``images/<split>_NNNN.png`` - min/max-scaled RGB preview
      * ``masks/<split>_NNNN.png``  - binary mask stored as 0/255

    A pair is written only when BOTH the image and the mask were processed
    successfully, so every ``.npy`` has a mask PNG with the same stem.

    Note: the constructor deletes ``output_path`` if it already exists.
    """

    def __init__(self, base_path, output_path, img_size=256):
        self.base_path = base_path
        self.output_path = output_path
        self.img_size = img_size

        # Normalisation parameters, filled in by calculate_normalization_params().
        self.norm_means = None
        self.norm_stds = None

        self.create_output_dirs()

        # Dataset layout relative to base_path (Kaggle sen1floods11 copy).
        self.image_dir = os.path.join('data', 'flood_events', 'HandLabeled', 'S1Hand')
        self.mask_dir = os.path.join('data', 'flood_events', 'HandLabeled', 'LabelHand')
        self.splits_dir = os.path.join('splits', 'flood_handlabeled')

        # Per-split counters used by print_stats() / calculate_class_weights().
        self.stats = {
            'train': {'count': 0, 'flood_pixels': 0, 'total_pixels': 0},
            'val': {'count': 0, 'flood_pixels': 0, 'total_pixels': 0},
            'test': {'count': 0, 'flood_pixels': 0, 'total_pixels': 0}
        }

    def create_output_dirs(self):
        """(Re)create ``output_path`` with images/ and masks/ for each split.

        Any existing ``output_path`` tree is removed first.
        """
        if os.path.exists(self.output_path):
            shutil.rmtree(self.output_path)

        for split in ['train', 'val', 'test']:
            os.makedirs(os.path.join(self.output_path, split, 'images'), exist_ok=True)
            os.makedirs(os.path.join(self.output_path, split, 'masks'), exist_ok=True)

        print(f"Created output directories at {self.output_path}")

    def read_csv_file(self, csv_path):
        """Return the ``(image_filename, mask_filename)`` pairs listed in a split CSV.

        Split files may or may not carry a header row. The first row is treated
        as a header only when its first cell has no file extension (e.g.
        ``image,mask``). This avoids silently dropping the first sample of a
        header-less file. Blank rows are ignored and cells are whitespace-stripped.

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found: {csv_path}")

        # newline='' as recommended by the csv module documentation.
        with open(csv_path, 'r', newline='') as f:
            rows = [row for row in csv.reader(f) if row and row[0].strip()]

        if rows and not os.path.splitext(rows[0][0].strip())[1]:
            rows = rows[1:]  # header row, not a file name

        return [(row[0].strip(), row[1].strip()) for row in rows]

    def compute_bands(self, vh, vv):
        """
        Compute the three bands according to the specified formulas:
        Band 1: VV
        Band 2: NewBand1 = (VH - VV) / (VH + VV)
        Band 3: NewBand2 = sqrt((VH^2 + VV^2) / 2)

        Returns a tuple ``(band1, band2, band3)`` of arrays shaped like the inputs.
        """
        # NOTE(review): eps only guards a denominator of exactly 0. If the rasters are
        # in dB (negative values), VH + VV near -eps is not protected - confirm units.
        eps = 1e-8

        band1 = vv
        band2 = np.divide(vh - vv, vh + vv + eps)
        band3 = np.sqrt((vh**2 + vv**2) / 2)

        return band1, band2, band3

    def _load_bands(self, im_path):
        """Read VH/VV from ``im_path`` and return the clipped, resized (3, S, S) array.

        Shared by both passes so the statistics are computed on exactly the data
        that is later normalised. Exceptions propagate to the caller.
        """
        with rasterio.open(im_path) as src:
            vh = src.read(1)
            vv = src.read(2)

        # Replace NaN and +/-inf with 0. The np.nan_to_num defaults would map inf to
        # the dtype's max value, which overflows back to inf when squared in band 3.
        vh = np.nan_to_num(vh, nan=0.0, posinf=0.0, neginf=0.0)
        vv = np.nan_to_num(vv, nan=0.0, posinf=0.0, neginf=0.0)

        arr_x = np.stack(self.compute_bands(vh, vv), axis=0)

        # Clip each band to its own 1st-99th percentile range (per image).
        for i in range(arr_x.shape[0]):
            v_min, v_max = np.percentile(arr_x[i], [1, 99])
            arr_x[i] = np.clip(arr_x[i], v_min, v_max)

        target_shape = (self.img_size, self.img_size)
        return np.stack(
            [resize(band, target_shape, preserve_range=True) for band in arr_x],
            axis=0,
        )

    def process_image_for_stats(self, im_path):
        """First pass: return the clipped, resized (3, S, S) array, or None on error."""
        try:
            return self._load_bands(im_path)
        except Exception as e:  # best-effort: one bad file must not abort the pass
            print(f"Error processing image for stats {im_path}: {str(e)}")
            return None

    def calculate_normalization_params(self, train_csv_path):
        """Compute per-band mean/std over all training pixels and save them.

        Images that are missing or fail to load are skipped. A band with no
        data, or with zero spread, gets std 1.0 so normalisation never divides
        by zero. Parameters are saved as a pickled dict in
        ``normalization_params.npy``; load it with ``allow_pickle=True``.
        """
        print("Calculating normalization parameters from training data...")

        file_pairs = self.read_csv_file(train_csv_path)

        all_pixels = [[] for _ in range(3)]

        for im_fname, _ in tqdm(file_pairs, desc="Collecting statistics"):
            im_path = os.path.join(self.base_path, self.image_dir, im_fname)

            if not os.path.exists(im_path):
                continue

            arr_x = self.process_image_for_stats(im_path)
            if arr_x is not None:
                for i in range(3):
                    all_pixels[i].append(arr_x[i].flatten())

        means = []
        stds = []

        for i in range(3):
            if all_pixels[i]:
                combined_pixels = np.concatenate(all_pixels[i])
                means.append(np.mean(combined_pixels))
                stds.append(np.std(combined_pixels))
            else:
                means.append(0.0)
                stds.append(1.0)

        self.norm_means = np.array(means)
        stds = np.array(stds)
        self.norm_stds = np.where(stds > 0, stds, 1.0)  # avoid divide-by-zero for constant bands

        print(f"Calculated normalization parameters:")
        print(f"Band 1 (VV): mean={self.norm_means[0]:.4f}, std={self.norm_stds[0]:.4f}")
        print(f"Band 2 (NewBand1): mean={self.norm_means[1]:.4f}, std={self.norm_stds[1]:.4f}")
        print(f"Band 3 (NewBand2): mean={self.norm_means[2]:.4f}, std={self.norm_stds[2]:.4f}")

        norm_params_path = os.path.join(self.output_path, 'normalization_params.npy')
        np.save(norm_params_path, {'means': self.norm_means, 'stds': self.norm_stds})
        print(f"Normalization parameters saved to: {norm_params_path}")

    def process_image(self, im_path):
        """Second pass: return ``(normalised_hwc, preview_hwc)`` or ``(None, None)`` on error.

        ``normalised_hwc`` is the (S, S, 3) standardised array (left unnormalised
        if no parameters were computed). ``preview_hwc`` is the same array
        min/max-scaled to [0, 1) over all channels jointly, for PNG previews.
        """
        try:
            arr_x = self._load_bands(im_path)

            if self.norm_means is not None and self.norm_stds is not None:
                arr_x = (arr_x - self.norm_means.reshape(3, 1, 1)) / self.norm_stds.reshape(3, 1, 1)

            # CHW -> HWC for saving.
            arr_x = np.transpose(arr_x, (1, 2, 0))

            eps = 1e-8
            arr_x_viz = (arr_x - arr_x.min()) / (arr_x.max() - arr_x.min() + eps)

            return arr_x, arr_x_viz

        except Exception as e:  # best-effort: caller skips the pair
            print(f"Error processing image {im_path}: {str(e)}")
            return None, None

    def process_mask(self, mask_path):
        """Return the resized binary (S, S) uint8 mask, or None on error.

        Values > 0 become 1. Everything else becomes 0, including the -1
        no-data label, which is therefore treated as background. Resizing
        uses nearest-neighbour (order=0) so the mask stays binary.
        """
        try:
            with rasterio.open(mask_path) as src:
                arr_y = src.read(1)

            arr_y[arr_y == -1] = 0
            arr_y = (arr_y > 0).astype(np.uint8)

            arr_y = resize(arr_y, (self.img_size, self.img_size), order=0, preserve_range=True).astype(np.uint8)

            return arr_y

        except Exception as e:
            print(f"Error processing mask {mask_path}: {str(e)}")
            return None

    def save_png(self, arr, save_path, mode='RGB'):
        """Save an array with values in [0, 1] as an 8-bit PNG in the given PIL ``mode``."""
        data = (arr * 255).astype(np.uint8)
        # Let Pillow infer the mode (HxWx3 uint8 -> RGB, HxW uint8 -> L). Passing
        # mode= to fromarray is deprecated, so convert only if the result differs.
        img = Image.fromarray(data)
        if img.mode != mode:
            img = img.convert(mode)
        img.save(save_path)

    def save_npy(self, arr, save_path):
        """Save raw array data as NPY file for preserving exact values"""
        np.save(save_path, arr)

    def update_stats(self, split, mask):
        """Accumulate sample count and flood / total pixel counts for ``split``."""
        self.stats[split]['count'] += 1
        self.stats[split]['flood_pixels'] += mask.sum()
        self.stats[split]['total_pixels'] += mask.size

    def process_dataset(self, split_name, csv_path):
        """Process every pair in ``csv_path`` and write outputs for ``split_name``.

        Files are named by the pair's position in the CSV (``<split>_NNNN``), so
        skipped pairs leave gaps in the numbering. A pair is only written, and
        only counted in the statistics, when both image and mask succeed.
        """
        print(f"Processing {split_name} dataset...")
        file_pairs = self.read_csv_file(csv_path)
        output_dir = os.path.join(self.output_path, split_name)

        for idx, (im_fname, mask_fname) in enumerate(tqdm(file_pairs, desc=f"Processing {split_name}")):
            im_path = os.path.join(self.base_path, self.image_dir, im_fname)
            mask_path = os.path.join(self.base_path, self.mask_dir, mask_fname)

            if not os.path.exists(im_path) or not os.path.exists(mask_path):
                print(f"Warning: Files not found - {im_path} or {mask_path}")
                continue

            arr_x, arr_x_viz = self.process_image(im_path)
            arr_y = self.process_mask(mask_path)
            if arr_x is None or arr_y is None:
                # Writing only half a pair would leave an image without a mask (or vice versa).
                print(f"Warning: Skipping pair {im_fname} / {mask_fname} after processing error")
                continue

            stem = f'{split_name}_{idx:04d}'
            self.save_png(arr_x_viz, os.path.join(output_dir, 'images', f'{stem}.png'), mode='RGB')
            self.save_npy(arr_x, os.path.join(output_dir, 'images', f'{stem}.npy'))
            self.save_png(arr_y, os.path.join(output_dir, 'masks', f'{stem}.png'), mode='L')

            self.update_stats(split_name, arr_y)

    def print_stats(self):
        """Print sample count and flood-pixel percentage for each non-empty split."""
        print("\nDataset Statistics:")
        print("=" * 50)
        for split, stat in self.stats.items():
            if stat['count'] > 0:
                flood_percentage = 100 * stat['flood_pixels'] / stat['total_pixels']
                print(f"{split.upper()} set: {stat['count']} samples, "
                      f"Flood pixels: {flood_percentage:.2f}%")
        print("=" * 50)

    def calculate_class_weights(self):
        """Return ``[w_non_flood, w_flood]`` from the training split's pixel balance.

        ``w_non_flood`` is 1.0 and ``w_flood`` is ``neg_ratio / pos_ratio``.
        Returns ``[1.0, 1.0]`` when no training pixels were counted; ``w_flood``
        falls back to 1.0 when there are no flood pixels. The caller is
        responsible for saving the result.
        """
        if self.stats['train']['total_pixels'] > 0:
            pos_ratio = self.stats['train']['flood_pixels'] / self.stats['train']['total_pixels']
            neg_ratio = 1 - pos_ratio

            weight_non_flood = 1.0
            weight_flood = neg_ratio / pos_ratio if pos_ratio > 0 else 1.0

            print(f"\nClass weights for handling imbalance:")
            print(f"Weight for non-flood (0): {weight_non_flood:.4f}")
            print(f"Weight for flood (1): {weight_flood:.4f}")

            return np.array([weight_non_flood, weight_flood])
        return np.array([1.0, 1.0])

# Kaggle environment setup: paths are defined once and reused in the messages below.
base_path = "/kaggle/input/sen1floods11-essentials/v1.2"
output_path = "/kaggle/working/preprocessed"

print("Setting up ResUNet preprocessing for Kaggle environment...")
print("Input path:", base_path)
print("Output path:", output_path)
print("\nBand Configuration:")
for band_line in (
    "Band 1: VV",
    "Band 2: NewBand1 = (VH - VV) / (VH + VV)",
    "Band 3: NewBand2 = sqrt((VH² + VV²) / 2)",
):
    print(band_line)

# Note: constructing the preprocessor wipes and recreates output_path.
preprocessor = ResUNetPreprocessor(base_path, output_path)

splits_root = os.path.join(base_path, 'splits', 'flood_handlabeled')
train_csv_path = os.path.join(splits_root, 'flood_train_data.csv')
val_csv_path = os.path.join(splits_root, 'flood_val_data.csv')
test_csv_path = os.path.join(splits_root, 'flood_test_data.csv')

# Step 1: normalisation statistics come from the training split only.
preprocessor.calculate_normalization_params(train_csv_path)

print("\nStarting dataset preprocessing with calculated normalization parameters...")

# Step 2: apply the same pipeline and training statistics to every split.
for split_name, split_csv in (('train', train_csv_path), ('val', val_csv_path), ('test', test_csv_path)):
    preprocessor.process_dataset(split_name, split_csv)

preprocessor.print_stats()
weights = preprocessor.calculate_class_weights()

# Persist class weights for the training notebook.
weights_path = os.path.join(output_path, 'class_weights.npy')
np.save(weights_path, weights)
print(f"\nClass weights saved to: {weights_path}")

print(f"\nPreprocessing complete! Processed data saved to: {output_path}")
print("\nOutput structure:")
tree_lines = ["preprocessed/"]
for split_name in ('train', 'val', 'test'):
    tree_lines += [
        f"├── {split_name}/",
        "│   ├── images/ (PNG and NPY files)",
        "│   └── masks/ (PNG files)",
    ]
tree_lines += ["├── class_weights.npy", "└── normalization_params.npy"]
print("\n".join(tree_lines))

In [None]:
# Train GAC-UNet (Graph Attention Convolutional U-Net)
!pip install torch_geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import cv2
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score
import torch_geometric
from torch_geometric.nn import GATConv, ChebConv
from torch_geometric.data import Data
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import warnings
warnings.filterwarnings('ignore')

class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed from raw logits.

    The logits go through a sigmoid and the Dice coefficient is computed over
    the whole batch at once (all elements flattened together), so ``forward``
    returns a single scalar ``1 - dice``.

    Args:
        smooth: constant added to numerator and denominator. It keeps the loss
            finite (and 0) when prediction and target are both empty.
    """
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - dice`` for logits ``pred`` and a 0/1 ``target`` with the same number of elements."""
        # reshape (not view): view() raises on non-contiguous tensors such as
        # permuted/transposed inputs, while reshape copies only when needed.
        probs = torch.sigmoid(pred).reshape(-1)
        target = target.reshape(-1)

        intersection = (probs * target).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + target.sum() + self.smooth)

        return 1 - dice

class CenterOfMassLayer(nn.Module):
    """Re-weight each channel map by proximity to its own center of mass.

    For every (sample, channel) map the activation-weighted mean row and column
    are computed in normalised coordinates (index / size). Each pixel is then
    multiplied by ``exp(-d)``, where ``d`` is its Euclidean distance from that
    center in the same normalised coordinates. The output has the input's shape.

    NOTE(review): the "mass" is the raw activation, which is not constrained to
    be non-negative. A total mass near zero makes the center (and the weights)
    blow up; consider clamping if NaNs appear.
    """
    def __init__(self, in_channels):
        super(CenterOfMassLayer, self).__init__()
        self.in_channels = in_channels  # kept for introspection; not used in forward

    def forward(self, x):
        _, _, h, w = x.size()

        rows = torch.arange(h, dtype=torch.float32, device=x.device).view(1, 1, h, 1)
        cols = torch.arange(w, dtype=torch.float32, device=x.device).view(1, 1, 1, w)

        # Per-(sample, channel) total mass; the small constant avoids 0/0.
        mass = x.sum(dim=(2, 3), keepdim=True) + 1e-8

        # Center of mass, then scaled to [0, 1) like the coordinate grids below.
        row_center = (x * rows).sum(dim=(2, 3), keepdim=True) / mass / h
        col_center = (x * cols).sum(dim=(2, 3), keepdim=True) / mass / w

        dy = rows / h - row_center
        dx = cols / w - col_center
        weight = torch.exp(-torch.sqrt(dy ** 2 + dx ** 2))

        return x * weight

class GraphConvBlock(nn.Module):
    """Graph convolution over the pixels of a feature map.

    A 1x1 conv projects the input to ``out_channels``. Each spatial position
    then becomes a graph node, connected to its in-bounds 8-neighbours. Per
    sample, a GAT layer (4 heads, averaged) and a Chebyshev layer are applied.
    The node features are mapped back to a (C, H, W) map, passed through a 1x1
    conv, and added to the projected input (residual).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the node features and of the output map.
        k: Chebyshev filter order, passed to ``ChebConv`` as ``K``.
    """
    def __init__(self, in_channels, out_channels, k=3):
        super(GraphConvBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k

        # 1x1 projections into and out of the graph feature space.
        self.spatial_reduce = nn.Conv2d(in_channels, out_channels, 1)
        self.spatial_restore = nn.Conv2d(out_channels, out_channels, 1)

        # Graph attention and Chebyshev convolutions.
        self.gat_conv = GATConv(out_channels, out_channels, heads=4, concat=False, dropout=0.1)
        self.cheb_conv = ChebConv(out_channels, out_channels, K=k)

        self.norm1 = nn.BatchNorm1d(out_channels)
        self.norm2 = nn.BatchNorm1d(out_channels)
        self.dropout = nn.Dropout(0.1)

    def create_graph_from_feature_map(self, x, k_neighbors=8):
        """Return ``(x_nodes, edge_index)`` for a (B, C, H, W) feature map.

        ``x_nodes`` has shape (B, H*W, C), with nodes in row-major pixel order.
        ``edge_index`` is a (2, E) int64 tensor. Every node lists each in-bounds
        8-neighbour, so both directions of every link appear. Edges are ordered
        by source node, then by offset (dy, dx) in row-major order over
        {-1, 0, 1}^2. A 1x1 map has no neighbours, so a self-loop is returned.

        ``k_neighbors`` is accepted for backward compatibility but unused: the
        neighbourhood is always the 3x3 window.
        """
        batch_size, channels, height, width = x.size()

        x_nodes = x.view(batch_size, channels, -1).permute(0, 2, 1)  # [B, H*W, C]
        num_nodes = height * width
        device = x.device

        # Vectorised construction. The previous per-pixel Python loop did tensor
        # arithmetic and a host sync for every node on every forward pass.
        offsets = torch.tensor(
            [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0)],
            device=device,
        )  # [8, 2]
        rows = torch.arange(height, device=device).repeat_interleave(width)  # [N]
        cols = torch.arange(width, device=device).repeat(height)  # [N]
        nbr_rows = rows.unsqueeze(1) + offsets[:, 0]  # [N, 8]
        nbr_cols = cols.unsqueeze(1) + offsets[:, 1]  # [N, 8]
        valid = (nbr_rows >= 0) & (nbr_rows < height) & (nbr_cols >= 0) & (nbr_cols < width)

        src = torch.arange(num_nodes, device=device).unsqueeze(1).expand(-1, offsets.size(0))
        dst = nbr_rows * width + nbr_cols
        # Boolean indexing walks [N, 8] row-major: source-major, then offset order.
        edge_index = torch.stack([src[valid], dst[valid]], dim=0)

        if edge_index.size(1) == 0:
            # Fallback: self-loops so the graph layers still receive an edge set.
            edge_index = torch.arange(num_nodes, device=device).unsqueeze(0).repeat(2, 1)

        return x_nodes, edge_index

    def forward(self, x):
        """Return a (B, out_channels, H, W) map: graph-refined features plus the projected input."""
        batch_size, _, height, width = x.size()

        x_reduced = self.spatial_reduce(x)

        # Topology depends only on (H, W), so build it once for the whole batch.
        all_nodes, edge_index = self.create_graph_from_feature_map(x_reduced)  # [B, H*W, C]

        # Each sample is processed as its own graph (BatchNorm1d statistics are per sample).
        graph_outputs = []
        for b in range(batch_size):
            x_nodes = all_nodes[b]  # [H*W, C]

            x_gat = self.gat_conv(x_nodes, edge_index)
            x_gat = F.relu(self.norm1(x_gat))
            x_gat = self.dropout(x_gat)

            x_cheb = self.cheb_conv(x_gat, edge_index)
            x_cheb = F.relu(self.norm2(x_cheb))

            # Nodes are [H*W, C] in row-major pixel order. Transpose to [C, H*W]
            # before reshaping; a direct .view(1, C, H, W) would reinterpret the
            # node-major memory and scramble channels and positions.
            graph_outputs.append(x_cheb.t().reshape(1, self.out_channels, height, width))

        graph_output = torch.cat(graph_outputs, dim=0)

        output = self.spatial_restore(graph_output)

        return output + x_reduced  # residual connection

class DoubleConv(nn.Module):
    """Two rounds of 3x3 conv -> BatchNorm -> ReLU; padding=1 keeps H and W unchanged.

    ``mid_channels`` sets the width after the first conv and defaults to
    ``out_channels`` when omitted or falsy.
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer indices are part of the state_dict keys.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder step: 2x2 max-pool (halves H and W), then a DoubleConv to ``out_channels``."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as one Sequential so state_dict keys stay "maxpool_conv.1.*".
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder step: upsample ``x1`` by 2, pad/crop it to ``x2``'s size, concatenate, DoubleConv.

    With ``bilinear=True`` the upsampling is parameter-free and the DoubleConv
    narrows to ``in_channels // 2`` in its middle layer. Otherwise a transposed
    conv halves the channels while upsampling.
    """
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Match the skip tensor's spatial size (odd sizes after pooling). A
        # negative difference makes F.pad crop instead.
        dy = x2.size()[2] - upsampled.size()[2]
        dx = x2.size()[3] - upsampled.size()[3]
        left, top = dx // 2, dy // 2
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])

        # Skip connection first, upsampled features second.
        return self.conv(torch.cat([x2, upsampled], dim=1))

class OutConv(nn.Module):
    """1x1 convolution mapping decoder features to per-pixel class logits."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.conv(x)

class GACUNET(nn.Module):
    """Graph Attention Convolutional U-Net.

    A 4-level U-Net. The bottleneck features are refined by a GraphConvBlock
    and then a CenterOfMassLayer before decoding. ``forward`` returns raw logits
    of shape (B, n_classes, H, W).
    """
    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super(GACUNET, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Bilinear upsampling keeps the channel count. The bottleneck and decoder
        # widths are therefore divided by `factor`, so each concatenation still
        # has the channel count the next Up block expects.
        factor = 2 if bilinear else 1
        bottleneck = 1024 // factor

        # Encoder (registration order matters for parameter init and state_dict order)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, bottleneck)

        # Bottleneck graph refinement
        self.graph_conv = GraphConvBlock(bottleneck, bottleneck)
        self.center_of_mass = CenterOfMassLayer(bottleneck)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every level's output as a skip connection.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))

        # Bottleneck: downsample once more, then graph + center-of-mass refinement.
        features = self.center_of_mass(self.graph_conv(self.down4(skips[-1])))

        # Decoder consumes skips deepest-first.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            features = up(features, skip)

        return self.outc(features)

class FloodDataset(Dataset):
    """Flood segmentation samples written by the preprocessing cell.

    Reads ``<data_dir>/<split>/images/*.npy`` (normalised arrays, loaded as
    (H, W, C) and returned as float32 (C, H, W)) and the matching
    ``<split>/masks/<stem>.png`` (returned as float32 (1, H, W) in [0, 1]).
    Images without a matching mask PNG are excluded when the dataset is built.

    Args:
        data_dir: root of the preprocessed dataset.
        split: sub-directory name, e.g. 'train', 'val' or 'test'.
        transform: optional callable applied to image and mask separately, with
            the same RNG seed for both, so random geometric transforms line up.
    """
    def __init__(self, data_dir, split='train', transform=None):
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        self.images_dir = os.path.join(data_dir, split, 'images')
        self.masks_dir = os.path.join(data_dir, split, 'masks')

        # Only .npy files carry the exact normalised values (the PNGs are previews).
        candidates = sorted(f for f in os.listdir(self.images_dir) if f.endswith('.npy'))
        self.image_files = [
            f for f in candidates
            if os.path.exists(os.path.join(self.masks_dir, self._mask_name(f)))
        ]
        skipped = len(candidates) - len(self.image_files)
        if skipped:
            print(f"FloodDataset[{split}]: skipped {skipped} image(s) without a matching mask")

    @staticmethod
    def _mask_name(img_file):
        """Map ``<stem>.npy`` to its mask file name ``<stem>.png``."""
        return img_file[:-len('.npy')] + '.png'

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_file = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_file)
        mask_path = os.path.join(self.masks_dir, self._mask_name(img_file))

        image = np.load(img_path).astype(np.float32)  # stored as (H, W, C)
        image = torch.from_numpy(image).permute(2, 0, 1)  # -> (C, H, W)

        # Context manager closes the PNG file handle deterministically.
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img.convert('L')).astype(np.float32)
        mask = mask / 255.0  # 0/255 -> 0/1
        mask = torch.from_numpy(mask).unsqueeze(0)  # (1, H, W)

        if self.transform:
            # Seed image and mask identically so random transforms match. fork_rng
            # restores the global CPU RNG afterwards instead of leaving it reset
            # to a fresh 32-bit seed on every sample.
            seed = torch.randint(0, 2**32, (1,)).item()
            with torch.random.fork_rng(devices=[]):
                torch.manual_seed(seed)
                image = self.transform(image)
                torch.manual_seed(seed)
                mask = self.transform(mask)

        return image, mask

def calculate_metrics(pred, target, threshold=0.5):
    """Return ``(iou, dice)`` for thresholded predictions against a target.

    Args:
        pred: raw logits. A pixel is predicted positive where
            ``sigmoid(pred) > threshold``.
        target: ground truth with the same number of elements. A pixel is
            positive where ``target > threshold``.
        threshold: probability / label cut-off.

    Returns:
        IoU and Dice as Python floats, pooled over all elements. When prediction
        and target are both empty the match is perfect and BOTH metrics are
        1.0. Previously IoU was 1.0 but Dice was 0.0 in that case.
    """
    pred_binary = (torch.sigmoid(pred) > threshold).reshape(-1)
    target_binary = (target > threshold).reshape(-1)

    # Counting on-device avoids copying the full masks to the CPU.
    intersection = (pred_binary & target_binary).sum().item()
    pred_count = pred_binary.sum().item()
    target_count = target_binary.sum().item()

    union = pred_count + target_count - intersection
    iou = intersection / union if union > 0 else 1.0

    total = pred_count + target_count
    dice = 2.0 * intersection / total if total > 0 else 1.0

    return iou, dice

def train_model(model, train_loader, val_loader, num_epochs=100, learning_rate=1e-4, device='cuda',
                checkpoint_path='best_gac_unet_model.pth', log_dir='runs/gac_unet_flood'):
    """Train GAC-UNET with a 0.3*BCE + 0.7*Dice loss and keep the best checkpoint.

    Each epoch makes one optimising pass over ``train_loader`` and one no-grad
    pass over ``val_loader``. Adam (weight decay 1e-5) is used, with
    ReduceLROnPlateau on the mean validation loss. Whenever the mean validation
    Dice improves, a dict with epoch, model/optimizer state and best Dice is
    saved to ``checkpoint_path``. Per-epoch scalars go to TensorBoard under
    ``log_dir``.

    Batches are moved to ``device``; ``model`` itself is not moved here.
    Metrics are averaged per batch.

    Returns:
        (train_losses, val_losses): lists of per-epoch mean losses.

    Raises:
        ValueError: if either loader has no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    criterion_bce = nn.BCEWithLogitsLoss()
    criterion_dice = DiceLoss()

    def combined_loss(outputs, masks):
        # One definition for training and validation; Dice is weighted more heavily.
        return 0.3 * criterion_bce(outputs, masks) + 0.7 * criterion_dice(outputs, masks)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5)

    writer = SummaryWriter(log_dir)

    best_val_dice = 0.0
    train_losses = []
    val_losses = []

    # try/finally so the TensorBoard writer is flushed and closed even when
    # training is interrupted (e.g. KeyboardInterrupt in a notebook) or raises.
    try:
        for epoch in range(num_epochs):
            # ---- training ----
            model.train()
            train_loss = train_iou = train_dice = 0.0

            progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
            for images, masks in progress_bar:
                images, masks = images.to(device), masks.to(device)

                optimizer.zero_grad()
                outputs = model(images)
                loss = combined_loss(outputs, masks)
                loss.backward()
                optimizer.step()

                # Metrics do not need an autograd graph.
                iou, dice = calculate_metrics(outputs.detach(), masks)

                train_loss += loss.item()
                train_iou += iou
                train_dice += dice

                progress_bar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Dice': f'{dice:.4f}',
                    'IoU': f'{iou:.4f}'
                })

            avg_train_loss = train_loss / len(train_loader)
            avg_train_iou = train_iou / len(train_loader)
            avg_train_dice = train_dice / len(train_loader)

            # ---- validation ----
            model.eval()
            val_loss = val_iou = val_dice = 0.0

            with torch.no_grad():
                for images, masks in val_loader:
                    images, masks = images.to(device), masks.to(device)

                    outputs = model(images)
                    loss = combined_loss(outputs, masks)
                    iou, dice = calculate_metrics(outputs, masks)

                    val_loss += loss.item()
                    val_iou += iou
                    val_dice += dice

            avg_val_loss = val_loss / len(val_loader)
            avg_val_iou = val_iou / len(val_loader)
            avg_val_dice = val_dice / len(val_loader)

            scheduler.step(avg_val_loss)

            for tag, value in (
                ('Loss/Train', avg_train_loss),
                ('Loss/Validation', avg_val_loss),
                ('Dice/Train', avg_train_dice),
                ('Dice/Validation', avg_val_dice),
                ('IoU/Train', avg_train_iou),
                ('IoU/Validation', avg_val_iou),
            ):
                writer.add_scalar(tag, value, epoch)

            if avg_val_dice > best_val_dice:
                best_val_dice = avg_val_dice
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_dice': best_val_dice,
                }, checkpoint_path)

            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f'  Train Loss: {avg_train_loss:.4f}, Train Dice: {avg_train_dice:.4f}, Train IoU: {avg_train_iou:.4f}')
            print(f'  Val Loss: {avg_val_loss:.4f}, Val Dice: {avg_val_dice:.4f}, Val IoU: {avg_val_iou:.4f}')
            print(f'  Best Val Dice: {best_val_dice:.4f}')
            print('-' * 60)

            train_losses.append(avg_train_loss)
            val_losses.append(avg_val_loss)
    finally:
        writer.close()

    return train_losses, val_losses

def _normalize_for_display(image_np):
    """Min-max scale an array to [0, 1] so ``imshow`` can render it.

    A constant array (max == min) would otherwise divide by zero and produce
    an all-NaN image; it is mapped to all zeros instead.
    """
    lo = image_np.min()
    span = image_np.max() - lo
    if span == 0:
        return np.zeros(image_np.shape, dtype=np.float32)
    return (image_np - lo) / span


def visualize_predictions(model, dataset, device, num_samples=4):
    """Plot input, ground truth, prediction and overlay for random samples.

    Each row shows one distinct, randomly chosen sample in four panels:
    1. the min-max normalised input image;
    2. the ground-truth mask;
    3. the sigmoid probability map returned by ``model``;
    4. the input with pixels whose predicted probability is > 0.5 painted
       red (channel 0 set to 1).

    The figure is saved to ``predictions_visualization.png`` and shown.

    Args:
        model: Segmentation network. Its output goes through
            ``torch.sigmoid``, so it is treated as logits.
        dataset: Indexable dataset whose items are ``(image, mask)`` tensors.
            The image is permuted from channels-first to HWC for display.
        device: Device the model lives on. Each image is moved there.
        num_samples: Maximum number of rows to draw. Capped at
            ``len(dataset)`` so no sample is repeated.

    Raises:
        ValueError: If ``num_samples`` < 1 or ``dataset`` is empty.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if len(dataset) == 0:
        raise ValueError("Cannot visualize predictions: dataset is empty")

    model.eval()

    # Draw distinct indices so the same sample is never shown twice.
    n_rows = min(num_samples, len(dataset))
    indices = np.random.choice(len(dataset), size=n_rows, replace=False)

    # squeeze=False keeps ``axes`` 2-D even for a single row; without it,
    # per-row indexing breaks when n_rows == 1.
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)

    with torch.no_grad():
        for row_axes, idx in zip(axes, indices):
            image, mask = dataset[int(idx)]

            # Add a batch dimension, predict, and convert logits to probabilities.
            logits = model(image.unsqueeze(0).to(device))
            pred = torch.sigmoid(logits).squeeze().cpu().numpy()

            image_display = _normalize_for_display(image.permute(1, 2, 0).cpu().numpy())
            mask_np = mask.squeeze().cpu().numpy()

            # Paint predicted-positive pixels red on top of the input.
            overlay = image_display.copy()
            overlay[:, :, 0] = np.where(pred > 0.5, 1, overlay[:, :, 0])

            panels = (
                (image_display, None, 'Input Image'),
                (mask_np, 'gray', 'Ground Truth'),
                (pred, 'gray', 'Prediction'),
                (overlay, None, 'Overlay'),
            )
            for ax, (data, cmap, title) in zip(row_axes, panels):
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

    plt.tight_layout()
    plt.savefig('predictions_visualization.png', dpi=300, bbox_inches='tight')
    plt.show()

def _evaluate(model, loader, device, desc="Testing"):
    """Return ``(mean IoU, mean Dice)`` of ``model`` over ``loader``.

    Metrics from ``calculate_metrics`` are averaged over batches, the same way
    the training loop averages its validation metrics. An empty loader yields
    ``(nan, nan)`` plus a warning instead of a ZeroDivisionError.
    """
    model.eval()
    total_iou = 0.0
    total_dice = 0.0
    num_batches = 0

    with torch.no_grad():
        for images, masks in tqdm(loader, desc=desc):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)

            iou, dice = calculate_metrics(outputs, masks)
            total_iou += iou
            total_dice += dice
            num_batches += 1

    if num_batches == 0:
        print(f"Warning: no batches to evaluate in '{desc}'; metrics are NaN.")
        return float('nan'), float('nan')
    return total_iou / num_batches, total_dice / num_batches


def _plot_training_curves(train_losses, val_losses, test_dice):
    """Plot per-epoch train/val loss beside the final test Dice, then save and show.

    The right-hand panel draws the single ``test_dice`` value as a flat
    dashed line across the epoch axis, as a visual reference only. The
    figure is saved to ``training_curves.png``.
    """
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    num_epochs = len(train_losses)
    plt.plot(range(num_epochs), [test_dice] * num_epochs,
             label=f'Test Dice: {test_dice:.4f}', linestyle='--')
    plt.title('Final Test Performance')
    plt.xlabel('Epoch')
    plt.ylabel('Dice Score')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()


# Main training script
def main():
    """Train GAC-UNET on the preprocessed flood data, then evaluate and plot.

    Pipeline:
    1. Build train/val/test ``FloodDataset``s from ``DATA_DIR``.
    2. Train with ``train_model``.
    3. Reload the best checkpoint the training loop saved, if one exists.
    4. Report test IoU/Dice.
    5. Save a prediction grid and the loss curves.
    """
    # Configuration
    BATCH_SIZE = 8
    NUM_EPOCHS = 100
    LEARNING_RATE = 1e-4
    DATA_DIR = "/kaggle/working/preprocessed"  # Path to preprocessed data
    CHECKPOINT_PATH = 'best_gac_unet_model.pth'  # must match the path train_model saves to
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {DEVICE}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"PyTorch Geometric version: {torch_geometric.__version__}")

    # Data transforms for augmentation.
    # NOTE(review): random flips/rotations are only valid for segmentation if
    # FloodDataset applies identical random parameters to image and mask - confirm.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
    ])

    # Create datasets
    print("Loading datasets...")
    train_dataset = FloodDataset(DATA_DIR, split='train', transform=train_transform)
    val_dataset = FloodDataset(DATA_DIR, split='val', transform=None)
    test_dataset = FloodDataset(DATA_DIR, split='test', transform=None)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # Create model
    print("Creating GAC-UNET model...")
    model = GACUNET(n_channels=3, n_classes=1, bilinear=True).to(DEVICE)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Train model
    print("Starting training...")
    train_losses, val_losses = train_model(
        model, train_loader, val_loader,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        device=DEVICE,
    )

    # The training loop only writes the checkpoint when validation Dice
    # improves, so it may be missing. In that case, evaluate the final weights
    # instead of crashing.
    print("Loading best model for evaluation...")
    if os.path.exists(CHECKPOINT_PATH):
        # map_location lets a checkpoint written on one device load on DEVICE.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print(f"Warning: {CHECKPOINT_PATH} not found; evaluating final-epoch weights instead.")

    # Test evaluation
    print("Evaluating on test set...")
    avg_test_iou, avg_test_dice = _evaluate(model, test_loader, DEVICE, desc="Testing")

    print("\nTest Results:")
    print(f"  Test IoU: {avg_test_iou:.4f}")
    print(f"  Test Dice: {avg_test_dice:.4f}")

    # Visualize some predictions (needs at least one test sample).
    if len(test_dataset) > 0:
        print("Generating prediction visualizations...")
        visualize_predictions(model, test_dataset, DEVICE)
    else:
        print("Skipping prediction visualizations: test set is empty.")

    # Plot training curves
    _plot_training_curves(train_losses, val_losses, avg_test_dice)

# Entry point: run the full pipeline via main() only when this file is
# executed directly. Importing it as a module does not start training.
if __name__ == "__main__":
    main()