In [None]:
# -*- coding: utf-8 -*-
"""ducknet_v1.2_binary.py

Flood Detection with Binary Output (0=not flooded, 1=flooded)
"""

# Install necessary packages
!pip install -q segmentation-models-pytorch timm tensorboardX
!pip install -q matplotlib scikit-image tqdm

# Import required libraries
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
import torch.nn.functional as F
from torchvision import transforms
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import time
import gc
import random
from pathlib import Path
from sklearn.metrics import accuracy_score, f1_score, jaccard_score
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
def seed_everything(seed=42):
    """Seed every random number generator this script relies on.

    Args:
        seed (int): Seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and every CUDA device).
    """
    random.seed(seed)
    # Only affects interpreters started after this point (child processes);
    # the hash seed of the running process is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one; a no-op when CUDA is absent.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # defeats the deterministic flag above, so switch it off explicitly.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs with the default seed (42) before any data or model is built
seed_everything()

# Select the compute device: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Aggressive memory cleanup function
def clear_memory():
    """Run Python's garbage collector, then release PyTorch's cached CUDA memory."""
    # Order matters: collect unreachable objects first so the tensors they
    # held are back in the cache before the cache is emptied.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

# Mount Google Drive (Colab only) so the dataset and outputs persist
from google.colab import drive
drive.mount('/content/drive')

# Input dataset root and output locations on the mounted drive
dataset_path = "/content/drive/MyDrive/Studies/swin_unet_preprocessed/"
output_path = "/content/drive/MyDrive/Studies/duck_1.5/"
model_path = os.path.join(output_path, "models")  # checkpoints
log_path = os.path.join(output_path, "logs")      # TensorBoard events and text log

# Create the output directories; creating the sub-directories also creates
# output_path itself
os.makedirs(model_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

# TensorboardX writer shared by the logging helper; closed at the end of the script
writer = SummaryWriter(log_path)

# Memory-efficient dataset class
class SARDataset(Dataset):
    """Dataset of ``.npy`` image/mask pairs that are loaded lazily from disk.

    Images are read from ``<root_dir>/<split>/images`` and masks from
    ``<root_dir>/<split>/masks``. Only the file list is kept in memory; the
    arrays are loaded in ``__getitem__``.
    """

    def __init__(self, root_dir, split="train", transform=None, max_samples=None):
        """
        Args:
            root_dir (string): Directory with preprocessed images and masks
            split (string): 'train', 'val', or 'test'
            transform (callable, optional): Geometric transform taking and
                returning a channel-first tensor with the same number of
                channels. It is applied to image and mask together (see
                ``_apply_joint_transform``).
            max_samples (int, optional): Limit number of samples for testing
        """
        self.root_dir = Path(root_dir)
        self.split = split
        self.transform = transform

        # Get list of all image files
        self.image_dir = self.root_dir / split / "images"
        self.mask_dir = self.root_dir / split / "masks"

        # Sorted so the order (and therefore the max_samples subset) is stable
        self.image_files = sorted(self.image_dir.glob("*.npy"))

        # Limit samples if needed
        if max_samples and max_samples < len(self.image_files):
            self.image_files = self.image_files[:max_samples]

        print(f"Found {len(self.image_files)} images in {split} set")

    def __len__(self):
        return len(self.image_files)

    def _apply_joint_transform(self, img, mask):
        """Apply ``self.transform`` to image and mask with the same random draws.

        The mask is stacked under the image along the channel axis so that a
        single transform call flips (or otherwise moves) both identically.
        Transforming the image alone would pair augmented images with
        un-augmented masks.
        """
        # Give a (H, W) mask a temporary channel axis so it can be stacked.
        stacked_mask = mask.unsqueeze(0) if mask.dim() == img.dim() - 1 else mask

        stackable = (
            img.dim() >= 3
            and stacked_mask.dim() == img.dim()
            and stacked_mask.shape[1:] == img.shape[1:]
        )
        if not stackable:
            # Layouts that cannot be stacked keep the image-only behaviour.
            return self.transform(img), mask

        n_img_channels = img.shape[0]
        joint = self.transform(torch.cat([img, stacked_mask], dim=0))
        img_out, mask_out = joint[:n_img_channels], joint[n_img_channels:]
        if stacked_mask is not mask:
            mask_out = mask_out.squeeze(0)
        return img_out, mask_out

    def __getitem__(self, idx):
        """Return ``{'image', 'mask', 'path'}`` for sample ``idx``."""
        img_path = self.image_files[idx]
        # The mask file name is "mask_" plus the image file name without its
        # first six characters (presumably an "image_" prefix - TODO confirm).
        mask_path = self.mask_dir / f"mask_{img_path.name[6:]}"

        # Load image and mask and convert to float tensors
        img = torch.from_numpy(np.load(img_path)).float()
        mask = torch.from_numpy(np.load(mask_path)).float()

        # Apply transforms if specified (jointly, to keep the pair aligned)
        if self.transform:
            img, mask = self._apply_joint_transform(img, mask)

        # Ensure mask is binary
        mask = (mask > 0.5).float()

        return {
            'image': img,
            'mask': mask,
            'path': str(img_path)
        }

# Simple data augmentation for training
class SimpleAugmentation:
    """Randomly flip a tensor along dim 1 and/or dim 2.

    With probability ``p`` augmentation is attempted at all; each of the two
    flips is then applied independently with probability 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image):
        # Gate: leave the input untouched with probability 1 - p.
        if random.random() >= self.p:
            return image

        # One coin toss per axis, dim 1 first and dim 2 second. For a
        # channel-first (C, H, W) tensor these would be the row and column
        # axes - layout assumed, confirm against the data.
        for axis in (1, 2):
            if random.random() < 0.5:
                image = torch.flip(image, dims=[axis])

        return image

# Build the three splits; max_samples caps each one to keep memory use low
transform_train = SimpleAugmentation(p=0.7)

train_dataset = SARDataset(
    root_dir=dataset_path,
    split="train",
    transform=transform_train,
    max_samples=100,
)
val_dataset = SARDataset(root_dir=dataset_path, split="val", max_samples=50)
test_dataset = SARDataset(root_dir=dataset_path, split="test", max_samples=50)

# Loader settings shared by every split: small batches, loading in the main
# process, no pinned memory. Only the training loader shuffles.
_loader_options = dict(batch_size=2, num_workers=0, pin_memory=False)
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_dataloader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# Report how many samples each split ended up with
for _label, _split_dataset in (
    ("Training", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{_label} set size: {len(_split_dataset)}")

# Pull one training batch to confirm the tensor shapes
sample_batch = next(iter(train_dataloader))
for _label, _key in (("Input", "image"), ("Mask", "mask")):
    print(f"{_label} shape: {sample_batch[_key].shape}")

# Clear memory
clear_memory()

# Simplified convolutional block
class ConvBlock(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an in-place ``ReLU``."""

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        stages = [
            # The convolution carries no bias; BatchNorm supplies the shift.
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        activated = self.conv(x)
        return activated

# Memory-efficient dense block with consistent channels
class MemEfficientDenseBlock(nn.Module):
    """Densely connected stack of ``ConvBlock`` layers.

    Layer ``k`` receives the block input concatenated with the outputs of all
    earlier layers and contributes ``growth_rate`` new channels. The block
    returns the input together with every layer output, concatenated on dim 1.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.growth_rate = growth_rate
        self.initial_channels = in_channels

        # Layer k sees the original channels plus k * growth_rate new ones.
        self.layers = nn.ModuleList([
            ConvBlock(in_channels + k * growth_rate, growth_rate)
            for k in range(num_layers)
        ])

    def get_out_channels(self):
        """Return the channel count of the tensor produced by ``forward``."""
        return self.initial_channels + self.growth_rate * self.num_layers

    def forward(self, x):
        collected = [x]
        for conv in self.layers:
            # Each layer consumes everything produced so far.
            collected.append(conv(torch.cat(collected, dim=1)))

        return torch.cat(collected, dim=1)

# Down and up transitions
class TransitionDown(nn.Module):
    """BatchNorm, bias-free 1x1 convolution to ``out_channels``, then 2x2 max-pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = (
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.MaxPool2d(2),  # halves height and width
        )
        self.down = nn.Sequential(*stages)

    def forward(self, x):
        reduced = self.down(x)
        return reduced

class TransitionUp(nn.Module):
    """Bias-free transposed convolution (kernel 2, stride 2) that doubles height and width."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample_options = dict(kernel_size=2, stride=2, bias=False)
        self.up = nn.ConvTranspose2d(in_channels, out_channels, **upsample_options)

    def forward(self, x):
        enlarged = self.up(x)
        return enlarged

# DUCKNET with binary classification output
class FixedDUCKNET(nn.Module):
    """Dense-block encoder/decoder with skip connections and a binary head.

    The encoder is three dense blocks (3, 4 and 5 layers) separated by
    ``TransitionDown``; the decoder mirrors it with ``TransitionUp``, a 1x1
    convolution on each skip connection and a smaller dense block.

    Output depends on the module mode:
        * ``train()``: raw logits, suitable for logit-based losses.
        * ``eval()``: ``(sigmoid(logits) > 0.5).float()``, i.e. a 0/1 map.

    Args:
        in_channels (int): Channels of the input tensor.
        num_classes (int): Output channels of the final 1x1 convolution.
        init_features (int): Output channels of the initial ``ConvBlock``.
        growth_rate (int): Channels added by each dense-block layer.
    """

    def __init__(self, in_channels=2, num_classes=1, init_features=32, growth_rate=16):
        super().__init__()

        # Use fewer layers to save memory
        block_layers = (3, 4, 5)

        # Initial convolution with consistent channels
        self.init_conv = ConvBlock(in_channels, init_features)

        # Create encoder blocks with explicit channel tracking
        self.encoder_blocks = nn.ModuleList()
        self.transitions_down = nn.ModuleList()
        self.encoder_channels = []

        features = init_features
        for i, num_layers in enumerate(block_layers):
            # Add dense block
            dense_block = MemEfficientDenseBlock(features, growth_rate, num_layers)
            self.encoder_blocks.append(dense_block)

            # Calculate output features precisely
            output_features = dense_block.get_out_channels()
            self.encoder_channels.append(output_features)

            # Add transition down (except for last block, the bottleneck)
            if i < len(block_layers) - 1:
                trans = TransitionDown(output_features, output_features // 2)
                self.transitions_down.append(trans)
                features = output_features // 2

        # Decoder part with precise channel control
        self.decoder_blocks = nn.ModuleList()
        self.transitions_up = nn.ModuleList()
        self.skip_connectors = nn.ModuleList()

        # Current features after bottleneck
        features = self.encoder_channels[-1]

        # Walk the encoder levels from the deepest skip connection upwards
        for i in range(len(block_layers) - 2, -1, -1):
            # Transition up (halve channels)
            up_out_channels = features // 2
            trans = TransitionUp(features, up_out_channels)
            self.transitions_up.append(trans)

            # Calculate skip connection features
            skip_features = self.encoder_channels[i]

            # Project the skip connection to the same width as the upsampled path
            skip_connector = nn.Conv2d(skip_features, up_out_channels, kernel_size=1)
            self.skip_connectors.append(skip_connector)

            # Input to dense block after concatenation
            combined_features = up_out_channels * 2  # skip_adjusted + upsampled

            # Decoder block with fewer layers for memory efficiency
            decoder_n_layers = max(block_layers[i] // 2, 2)
            dense_block = MemEfficientDenseBlock(combined_features, growth_rate, decoder_n_layers)
            self.decoder_blocks.append(dense_block)

            # Update features for next level
            features = dense_block.get_out_channels()

        # Final classification layer
        self.final_conv = nn.Conv2d(features, num_classes, kernel_size=1)

        # Print channel dimensions for debugging
        print("Encoder channels:", self.encoder_channels)

    def forward(self, x):
        """Return logits in training mode and a thresholded 0/1 map in eval mode."""
        # NOTE: the CUDA cache is deliberately not emptied in here. Flushing
        # the caching allocator at every stage does not give PyTorch more
        # usable memory; it only forces blocks to be freed and re-allocated
        # on each forward pass.

        # Initial features
        x = self.init_conv(x)

        # Encoder path: keep each pre-downsampling output for the decoder
        skip_connections = []
        for block, down in zip(self.encoder_blocks[:-1], self.transitions_down):
            x = block(x)
            skip_connections.append(x)
            x = down(x)

        # Bottleneck
        x = self.encoder_blocks[-1](x)

        # Decoder path: the deepest skip connection is consumed first
        for trans_up, skip_connector, dense_block, skip in zip(
                self.transitions_up, self.skip_connectors, self.decoder_blocks,
                reversed(skip_connections)):

            # Upsample features
            x = trans_up(x)

            # Adjust skip connection channels
            skip_adjusted = skip_connector(skip)

            # Odd input sizes lose a pixel when pooled; resize so the two
            # tensors can be concatenated
            if x.shape[2:] != skip_adjusted.shape[2:]:
                x = F.interpolate(x, size=skip_adjusted.shape[2:], mode='bilinear', align_corners=False)

            # Concatenate and continue
            x = torch.cat([x, skip_adjusted], dim=1)
            x = dense_block(x)

        # Final classification
        x = self.final_conv(x)

        # KEY CHANGE FOR BINARY OUTPUT:
        # During training: return logits for loss functions
        # During inference: apply sigmoid + threshold for binary output (0 or 1)
        if self.training:
            return x
        else:
            return (torch.sigmoid(x) > 0.5).float()

# Initialize model for training: 2 input channels, 1 output channel,
# then move its parameters to the selected device
model = FixedDUCKNET(in_channels=2, num_classes=1, init_features=32, growth_rate=16)
model = model.to(device)

# Print model summary
def count_parameters(model):
    """Return the number of scalar elements in parameters that require gradients."""
    trainable_elements = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements

# Report the trainable parameter count with thousands separators
print(f"Fixed DUCK-NET model with {count_parameters(model):,} trainable parameters")

# Clear memory
clear_memory()

# Updated loss functions for mixed precision compatibility
class DiceLoss(nn.Module):
    """Soft Dice loss computed over the whole batch from raw logits.

    Args:
        smooth (float): Added to numerator and denominator so the loss stays
            defined when both prediction and target are empty.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, target):
        """Return ``1 - dice`` for ``sigmoid(logits)`` against ``target``."""
        # reshape (rather than view) also accepts non-contiguous tensors,
        # e.g. transposed or permuted inputs, for which view raises.
        pred_flat = torch.sigmoid(logits).reshape(-1)
        target_flat = target.reshape(-1)

        intersection = (pred_flat * target_flat).sum()
        denominator = pred_flat.sum() + target_flat.sum() + self.smooth

        return 1 - (2. * intersection + self.smooth) / denominator

class CombinedLoss(nn.Module):
    """Equal-weight sum of ``DiceLoss`` and ``BCEWithLogitsLoss``; both take raw logits."""

    def __init__(self):
        super().__init__()
        self.dice_loss = DiceLoss()
        # BCEWithLogitsLoss applies the sigmoid itself, so no activation is
        # applied before it here.
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, logits, target):
        dice_term = self.dice_loss(logits, target)
        bce_term = self.bce_loss(logits, target)
        return 0.5 * dice_term + 0.5 * bce_term

# Define metrics calculation function (memory-efficient version)
def calculate_metrics(outputs, targets, threshold=0.5):
    """Compute pixel-wise accuracy, IoU and F1 for binary segmentation.

    Args:
        outputs (Tensor): Either raw logits or an already thresholded 0/1 map.
        targets (Tensor): Ground-truth 0/1 mask with the same number of elements.
        threshold (float): Probability cut-off applied to ``sigmoid(outputs)``
            when ``outputs`` are logits.

    Returns:
        dict: ``{'accuracy', 'iou', 'f1'}`` as plain Python floats.
    """
    # Work in float32 on the CPU regardless of how the tensor was produced
    # (it may be half precision when it comes out of an autocast region).
    outputs = outputs.detach().float().cpu()

    # Decide from the values themselves whether a threshold was already
    # applied, instead of consulting the mode of a global model: a tensor
    # containing only 0s and 1s is treated as a finished binary map.
    already_binary = bool(((outputs == 0) | (outputs == 1)).all())
    if already_binary:
        preds = outputs
    else:
        preds = (torch.sigmoid(outputs) > threshold).float()

    preds_binary = preds.numpy().reshape(-1)
    targets_binary = targets.detach().float().cpu().numpy().reshape(-1)

    # Calculate metrics
    acc = accuracy_score(targets_binary, preds_binary)
    iou = jaccard_score(targets_binary, preds_binary, average='binary', zero_division=1)
    f1 = f1_score(targets_binary, preds_binary, average='binary', zero_division=1)

    # Plain floats keep NumPy scalar objects out of anything that stores
    # these metrics (e.g. a checkpoint dictionary).
    return {
        'accuracy': float(acc),
        'iou': float(iou),
        'f1': float(f1)
    }

# Initialize loss function (Dice + BCE, both operating on logits)
criterion = CombinedLoss()

# Clear memory
clear_memory()

# Memory-efficient training configuration
num_epochs = 10
initial_lr = 1e-4
save_frequency = 5  # Write a periodic checkpoint every 5 epochs
eval_frequency = 2  # Evaluate less frequently to save memory

# Define optimizer (with smaller model, we don't need weight decay)
optimizer = optim.Adam(model.parameters(), lr=initial_lr)
# Halve the learning rate when the monitored value stops decreasing.
# NOTE(review): the scheduler is stepped only on validation epochs, so with
# patience=5, eval_frequency=2 and num_epochs=10 it is stepped five times and
# looks unable to trigger a reduction - confirm this is intended.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm
# against the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

# Initialize training variables
best_val_loss = float('inf')
best_val_metrics = {'accuracy': 0, 'iou': 0, 'f1': 0}
# Per-epoch training history and per-validation-run validation history
train_history = {'loss': [], 'accuracy': [], 'iou': [], 'f1': []}
val_history = {'loss': [], 'accuracy': [], 'iou': [], 'f1': []}

# Use mixed precision for memory efficiency; the scaler guards the backward
# pass against gradient underflow
scaler = amp.GradScaler()

# Create a training log file. It stays open for the whole training run and
# is closed after the training loop.
log_file = open(os.path.join(log_path, 'training_log.txt'), 'w')
log_file.write(f"DUCK-NET Binary Flood Classification - Started at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
log_file.write(f"Parameters: {count_parameters(model):,} | Learning Rate: {initial_lr}\n")
log_file.write("-" * 80 + "\n")
log_file.flush()

# Logging function
def log_metrics(epoch, phase, losses, metrics):
    """Report one epoch's loss and metrics to stdout, the text log and TensorBoard.

    Args:
        epoch (int): Epoch number, also used as the TensorBoard step.
        phase (str): Label such as 'Train' or 'Validation'; prefixes the
            TensorBoard tags.
        losses (float): Loss value for the epoch.
        metrics (dict): Must contain 'accuracy', 'iou' and 'f1'.
    """
    summary = " | ".join([
        f"Epoch {epoch}/{num_epochs}",
        f"{phase}",
        f"Loss: {losses:.4f}",
        f"Accuracy: {metrics['accuracy']:.4f}",
        f"IoU: {metrics['iou']:.4f}",
        f"F1: {metrics['f1']:.4f}",
    ])
    print(summary)
    # Flush straight away so the log survives an interrupted session.
    log_file.write(summary + "\n")
    log_file.flush()

    # Log to TensorBoard
    writer.add_scalar(f'{phase}/loss', losses, epoch)
    for key in ('accuracy', 'iou', 'f1'):
        writer.add_scalar(f'{phase}/{key}', metrics[key], epoch)

# Memory-efficient batch processing
def process_batch(batch, training=True):
    """Run one batch through the model and, when training, take an optimizer step.

    Args:
        batch (dict): Must provide 'image' and 'mask' tensors.
        training (bool): True for a training step, False for evaluation.

    Returns:
        tuple: ``(loss_value, outputs_cpu, masks_cpu)``. ``outputs_cpu`` is
        whatever the model returned (logits in training mode, a 0/1 map in
        eval mode), detached and moved to the CPU.
    """
    images = batch['image'].to(device)
    masks = batch['mask'].to(device)

    captured = {}
    hook_handle = None
    if training:
        optimizer.zero_grad()
        model.train()  # Set model to training mode
    else:
        model.eval()   # Set model to evaluation mode (will output binary values)
        # The criterion expects logits, but in eval mode the model returns a
        # thresholded 0/1 map. Feeding that map to the loss would treat 0 and
        # 1 as logits and yield a meaningless value, so grab the real logits
        # from the final convolution while the forward pass runs.
        final_layer = getattr(model, 'final_conv', None)
        if final_layer is not None:
            def _keep_logits(module, inputs, output):
                captured['logits'] = output
            hook_handle = final_layer.register_forward_hook(_keep_logits)

    try:
        # Forward pass with mixed precision
        with amp.autocast():
            outputs = model(images)
            # Training: the outputs are the logits. Evaluation: use the
            # captured ones (falls back to the outputs if nothing was captured).
            logits = captured.get('logits', outputs)
            loss = criterion(logits, masks)
    finally:
        # Never leave the hook attached, even if the forward pass fails.
        if hook_handle is not None:
            hook_handle.remove()

    if training:
        # Backward and optimize with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # Move tensors to CPU to free GPU memory
    loss_value = loss.item()
    outputs_cpu = outputs.detach().cpu()
    masks_cpu = masks.cpu()

    # Free memory
    del images, masks, outputs, logits, loss, captured
    torch.cuda.empty_cache()

    return loss_value, outputs_cpu, masks_cpu

# Training loop: one pass over the training loader per epoch, validation on
# every eval_frequency-th epoch and on the last epoch, periodic checkpoints
for epoch in range(1, num_epochs + 1):
    print(f"\nEpoch {epoch}/{num_epochs}")
    print('-' * 40)

    # Training phase
    model.train()
    train_loss = 0.0
    all_preds = []
    all_targets = []

    # Process batches with memory management
    train_progress = tqdm(train_dataloader, desc=f"Training Epoch {epoch}")
    for batch in train_progress:
        loss_value, outputs, masks = process_batch(batch, training=True)

        # Update statistics; weight the batch loss by the batch size so the
        # epoch value is a per-sample average
        train_loss += loss_value * len(outputs)
        all_preds.append(outputs)
        all_targets.append(masks)

        # Update progress bar
        train_progress.set_postfix({"Loss": loss_value})

    # Calculate epoch metrics over all training batches at once
    train_loss = train_loss / len(train_dataset)
    epoch_preds = torch.cat(all_preds, dim=0)
    epoch_targets = torch.cat(all_targets, dim=0)
    train_metrics = calculate_metrics(epoch_preds, epoch_targets)

    # Log training metrics
    log_metrics(epoch, 'Train', train_loss, train_metrics)
    train_history['loss'].append(train_loss)
    train_history['accuracy'].append(train_metrics['accuracy'])
    train_history['iou'].append(train_metrics['iou'])
    train_history['f1'].append(train_metrics['f1'])

    # Clear memory
    del all_preds, all_targets, epoch_preds, epoch_targets
    clear_memory()

    # Validation phase - run less frequently to save memory
    if epoch % eval_frequency == 0 or epoch == num_epochs:
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_targets = []

        # No gradients are needed while validating
        with torch.no_grad():
            val_progress = tqdm(val_dataloader, desc=f"Validation Epoch {epoch}")
            for batch in val_progress:
                loss_value, outputs, masks = process_batch(batch, training=False)

                # Update statistics
                val_loss += loss_value * len(outputs)
                all_preds.append(outputs)
                all_targets.append(masks)

                # Update progress bar
                val_progress.set_postfix({"Loss": loss_value})

        # Calculate validation metrics
        val_loss = val_loss / len(val_dataset)
        epoch_preds = torch.cat(all_preds, dim=0)
        epoch_targets = torch.cat(all_targets, dim=0)
        val_metrics = calculate_metrics(epoch_preds, epoch_targets)

        # Log validation metrics
        log_metrics(epoch, 'Validation', val_loss, val_metrics)
        val_history['loss'].append(val_loss)
        val_history['accuracy'].append(val_metrics['accuracy'])
        val_history['iou'].append(val_metrics['iou'])
        val_history['f1'].append(val_metrics['f1'])

        # Update scheduler (it monitors the validation loss, so it is only
        # stepped on epochs where validation ran)
        scheduler.step(val_loss)

        # Save model if it has the best validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_metrics = val_metrics

            # Save in a memory-efficient way; this file is overwritten each
            # time a better validation loss is reached
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
                'metrics': val_metrics,
            }, os.path.join(model_path, 'best_model.pth'))

            print(f"Saved new best model with validation loss: {val_loss:.4f}")
            log_file.write(f"New best model saved at epoch {epoch} with validation loss: {val_loss:.4f}\n")
            log_file.flush()

        # Clear memory after validation
        del all_preds, all_targets, epoch_preds, epoch_targets
        clear_memory()

    # Save checkpoint at specified intervals (and always after the last epoch)
    if epoch % save_frequency == 0 or epoch == num_epochs:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, os.path.join(model_path, f'checkpoint_epoch_{epoch}.pth'))

    # Extra memory cleanup
    clear_memory()

# Write the summary of the best validation result, then close the log file
log_file.write(f"\nTraining completed at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
log_file.write(f"Best validation loss: {best_val_loss:.4f} with metrics:\n")
log_file.write(f"Accuracy: {best_val_metrics['accuracy']:.4f} | ")
log_file.write(f"IoU: {best_val_metrics['iou']:.4f} | ")
log_file.write(f"F1: {best_val_metrics['f1']:.4f}\n")
log_file.close()

print("Training completed!")

# Load the best model for evaluation
clear_memory()  # Clear memory before loading

best_model_path = os.path.join(model_path, 'best_model.pth')
# map_location lets a checkpoint written on a GPU be restored on whatever
# device this session uses (including CPU-only hosts).
# weights_only=False: the checkpoint is produced by this script and holds more
# than tensors (epoch, loss, metrics dict), which the restricted loader used
# by default in recent PyTorch releases may refuse. Do not use this on
# checkpoints from untrusted sources - it unpickles arbitrary objects.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']} with validation loss: {checkpoint['loss']:.4f}")

# Memory-efficient evaluation
model.eval()  # This is crucial - will use binary output mode
test_loss = 0.0
all_metrics = {'accuracy': [], 'iou': [], 'f1': []}
examples_to_visualize = []
viz_count = 0

# Process test data in smaller batches
with torch.no_grad():
    test_progress = tqdm(test_dataloader, desc="Testing")
    for batch in test_progress:
        # Process batch - note model.eval() makes outputs binary
        loss_value, outputs, masks = process_batch(batch, training=False)

        # Update statistics (batch loss weighted by batch size)
        test_loss += loss_value * len(outputs)

        # Calculate metrics for this batch only.
        # NOTE(review): the reported test metrics are the unweighted mean of
        # these per-batch values, unlike validation, which scores all pixels
        # of the epoch together - confirm this difference is intended.
        metrics = calculate_metrics(outputs, masks)
        all_metrics['accuracy'].append(metrics['accuracy'])
        all_metrics['iou'].append(metrics['iou'])
        all_metrics['f1'].append(metrics['f1'])

        # Store up to 3 examples for visualization
        if viz_count < 3:
            for i in range(min(len(outputs), 3 - viz_count)):
                examples_to_visualize.append({
                    # Reload the untransformed array from disk for display
                    'image': np.load(batch['path'][i]),
                    'mask': masks[i].numpy(),
                    'pred': outputs[i].numpy(),  # This is now binary (0 or 1)
                    'path': batch['path'][i]
                })
                viz_count += 1

        # Clear memory
        del outputs, masks
        clear_memory()

# Calculate overall test metrics
test_loss = test_loss / len(test_dataset)
test_metrics = {
    'accuracy': np.mean(all_metrics['accuracy']),
    'iou': np.mean(all_metrics['iou']),
    'f1': np.mean(all_metrics['f1'])
}

# Print test results
print("\nTest Results:")
print(f"Loss: {test_loss:.4f}")
print(f"Accuracy: {test_metrics['accuracy']:.4f}")
print(f"IoU: {test_metrics['iou']:.4f}")
print(f"F1: {test_metrics['f1']:.4f}")

# Save test results to a file
with open(os.path.join(output_path, 'test_results.txt'), 'w') as f:
    f.write("DUCK-NET Binary Flood Classification Results\n")
    f.write("-" * 50 + "\n")
    f.write(f"Loss: {test_loss:.4f}\n")
    f.write(f"Accuracy: {test_metrics['accuracy']:.4f}\n")
    f.write(f"IoU: {test_metrics['iou']:.4f}\n")
    f.write(f"F1: {test_metrics['f1']:.4f}\n")

# Clear memory
clear_memory()

# Binary visualization function
def visualize_binary_prediction(example, idx):
    """
    Visualize SAR image and binary flood prediction

    Saves a three-panel figure (false-color input, ground truth, prediction)
    to ``<output_path>/binary_prediction_<idx>.png``.

    Args:
        example (dict): 'image' with at least two channels on axis 0, plus
            'mask' and 'pred' whose first channel is displayed.
        idx (int): Index used in the output file name.
    """
    def _rescale(channel):
        # Min-max scale to [0, 1]. A constant channel has zero range; divide
        # by 1 instead so the result is all zeros rather than NaN.
        low = channel.min()
        span = channel.max() - low
        return (channel - low) / (span if span > 0 else 1)

    # Create figure
    plt.figure(figsize=(12, 4))

    # Original SAR image (composite)
    plt.subplot(1, 3, 1)
    # Create false color composite: the two scaled channels plus their
    # clipped ratio as the third color channel
    vv_norm = _rescale(example['image'][0])
    vh_norm = _rescale(example['image'][1])
    ratio = np.clip(vv_norm / (vh_norm + 0.01), 0, 1)
    rgb = np.stack([vv_norm, vh_norm, ratio], axis=2)
    plt.imshow(rgb)
    plt.title('SAR Image (False Color)')
    plt.axis('off')

    # Ground Truth Mask
    plt.subplot(1, 3, 2)
    plt.imshow(example['mask'][0], cmap='viridis', vmin=0, vmax=1)
    plt.title('Ground Truth (Binary)')
    plt.axis('off')

    # Binary Prediction (0 or 1)
    plt.subplot(1, 3, 3)
    plt.imshow(example['pred'][0], cmap='viridis', vmin=0, vmax=1)
    plt.title('Predicted Flood (Binary)')
    plt.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(output_path, f'binary_prediction_{idx}.png'), dpi=150)
    plt.close()  # Release the figure so repeated calls do not accumulate

# Visualize examples (now showing binary predictions)
for i, example in enumerate(examples_to_visualize):
    visualize_binary_prediction(example, i)
    print(f"Saved binary prediction example {i+1}")

# Epochs on which validation ran: every eval_frequency-th epoch and the final
# epoch. Mirroring that rule keeps the x values the same length as the
# recorded history even when num_epochs is not a multiple of eval_frequency.
val_epochs = [
    e for e in range(1, num_epochs + 1)
    if e % eval_frequency == 0 or e == num_epochs
][:len(val_history['loss'])]

# Plot training history
plt.figure(figsize=(15, 5))

# Loss plot
plt.subplot(1, 3, 1)
plt.plot(range(1, len(train_history['loss']) + 1), train_history['loss'], label='Train Loss')
if len(val_history['loss']) > 0:
    plt.plot(val_epochs, val_history['loss'], label='Validation Loss')
plt.title('Loss Over Time')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# IoU plot
plt.subplot(1, 3, 2)
plt.plot(range(1, len(train_history['iou']) + 1), train_history['iou'], label='Train IoU')
if len(val_history['iou']) > 0:
    plt.plot(val_epochs, val_history['iou'], label='Validation IoU')
plt.title('IoU Over Time')
plt.xlabel('Epochs')
plt.ylabel('IoU')
plt.legend()
plt.grid(True)

# F1 plot
plt.subplot(1, 3, 3)
plt.plot(range(1, len(train_history['f1']) + 1), train_history['f1'], label='Train F1')
if len(val_history['f1']) > 0:
    plt.plot(val_epochs, val_history['f1'], label='Validation F1')
plt.title('F1 Score Over Time')
plt.xlabel('Epochs')
plt.ylabel('F1 Score')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(output_path, 'training_history.png'), dpi=150)
plt.close()  # Close to free memory

# Close TensorBoard writer
writer.close()
print("All visualizations saved to:", output_path)

# Binary flood prediction function
def predict_flooding(img_path, model, device):
    """
    Perform flood prediction on a new SAR image - returns binary classification

    Args:
        img_path: Path of a ``.npy`` file holding a single image array.
        model: Module to run; it is switched to eval mode and left there.
        device: Device the input tensor is moved to before inference.

    Returns:
        numpy.ndarray: The model output for a batch of one (leading axis of
        size 1), moved to the CPU.
    """
    # Load the array and prepend a batch axis of size one
    array = np.load(img_path)
    batch = torch.from_numpy(array).float()[None].to(device)

    # Eval mode is required: this is what makes the model emit its
    # inference-time output
    model.eval()

    # Gradients are not needed for inference
    with torch.no_grad():
        prediction = model(batch)

    result = prediction.cpu().numpy()

    # Drop the device tensors before asking PyTorch to release cached memory
    del batch, prediction
    torch.cuda.empty_cache()

    return result

# Example usage - run only if test data is available
if len(test_dataset) > 0:
    # Clear memory before inference
    clear_memory()

    # Run inference on a sample
    sample_path = test_dataset.image_files[0]
    print(f"Running inference on sample image: {sample_path}")

    # Get binary mask directly from model
    binary_mask = predict_flooding(sample_path, model, device)

    # Visualize the result
    plt.figure(figsize=(10, 4))

    # Load and display original image
    img = np.load(sample_path)

    plt.subplot(1, 2, 1)
    # Create false color composite. Each channel is min-max scaled; a constant
    # channel has zero range, so divide by 1 in that case to get zeros
    # instead of NaN.
    vv_span = img[0].max() - img[0].min()
    vh_span = img[1].max() - img[1].min()
    vv_norm = (img[0] - img[0].min()) / (vv_span if vv_span > 0 else 1)
    vh_norm = (img[1] - img[1].min()) / (vh_span if vh_span > 0 else 1)
    ratio = np.clip(vv_norm / (vh_norm + 0.01), 0, 1)
    rgb = np.stack([vv_norm, vh_norm, ratio], axis=2)
    plt.imshow(rgb)
    plt.title('SAR Image (False Color)')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    # First sample, first channel of the returned batch
    plt.imshow(binary_mask[0, 0], cmap='viridis', vmin=0, vmax=1)
    plt.title('Binary Flood Classification (0 or 1)')
    plt.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(output_path, 'binary_flood_map.png'), dpi=150)
    plt.close()

    print(f"Binary flood map saved to {output_path}")

print("DuckNet Binary Flood Classification Completed!")