<a href="https://colab.research.google.com/github/Gowtham-P23/new_google/blob/main/Copy_of_ShearletExitUnet_semantic_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub

# Download the OpenEarthMap land-cover dataset through kagglehub and keep the
# returned location; the Config cell below uses it as its DATA_PATH.
aletbm_global_land_cover_mapping_openearthmap_path = kagglehub.dataset_download(
    'aletbm/global-land-cover-mapping-openearthmap'
)

print('Data source import complete.')


Downloading from https://www.kaggle.com/api/v1/datasets/download/aletbm/global-land-cover-mapping-openearthmap?dataset_version_number=1...


100%|██████████| 8.47G/8.47G [01:19<00:00, 114MB/s]

Extracting files...





Data source import complete.


In [None]:
!pip install -q segmentation-models-pytorch
!pip install -q albumentations
!pip install -q timm
!pip install -q rasterio

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.3/22.3 MB[0m [31m73.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import glob
import numpy as np
import rasterio
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler # Import lr_scheduler

import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

In [None]:
import os
import torch

class Config:
    """Dataset locations and training hyperparameters shared by the notebook.

    Every value is a class attribute evaluated once, when this cell runs.
    ``DATA_PATH`` is taken from the kagglehub download cell, so that cell has
    to be executed before this one.
    """
    # DATA_PATH = "/kaggle/input/global-land-cover-mapping-openearthmap/"
    DATA_PATH = aletbm_global_land_cover_mapping_openearthmap_path
    # DATA_PATH = "/root/.cache/kagglehub/datasets/aletbm/global-land-cover-mapping-openearthmap/versions/1/"

    # --- File/Directory Names ---
    IMAGE_DIR_NAME = "images"
    LABEL_DIR_NAME = "label"
    # Text files listing the file names that belong to each split.
    TRAIN_TXT = os.path.join(DATA_PATH, "train.txt")
    VAL_TXT = os.path.join(DATA_PATH, "val.txt")

    # --- Base Directories ---
    IMAGE_BASE_DIR = os.path.join(DATA_PATH, IMAGE_DIR_NAME)
    LABEL_BASE_DIR = os.path.join(DATA_PATH, LABEL_DIR_NAME)

    # Number of output channels of every segmentation head / classes in the metrics.
    NUM_CLASSES = 9

    # --- Training Hyperparameters ---
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 4
    NUM_WORKERS = 2
    IMAGE_SIZE = 256  # side length (pixels) of the square crops fed to the model
    LEARNING_RATE = 1e-4
    EPOCHS = 40
    # Pinned host memory only speeds up host-to-GPU copies; on the CPU fallback
    # it brings no benefit, so it is enabled only when CUDA is used.
    PIN_MEMORY = DEVICE == "cuda"
    FPN_CHANNELS = 256
    # Weight decay passed to the optimizer for regularization
    WEIGHT_DECAY = 1e-2
    # Parameters for Early Stopping
    EARLY_STOPPING_PATIENCE = 10  # Number of epochs to wait for improvement
    MIN_DELTA = 0.001  # Minimum change to qualify as an improvement

cfg = Config()
print(f"Using device: {cfg.DEVICE}")
print(f"Number of classes: {cfg.NUM_CLASSES}")
print(f"Weight Decay: {cfg.WEIGHT_DECAY}")
print(f"Early Stopping Patience: {cfg.EARLY_STOPPING_PATIENCE}")

Using device: cpu
Number of classes: 9
Weight Decay: 0.01
Early Stopping Patience: 10


In [None]:
# Define transformations for training and validation sets using Albumentations
# Training augmentation pipeline.
#
# NOTE(review): the recorded run of this cell emitted warnings that point at the
# OpticalDistortion(shift_limit=...), GaussNoise(var_limit=...) and
# CoarseDropout(max_holes=...) calls. Presumably the installed Albumentations
# release no longer accepts those keyword names and silently falls back to its
# defaults, so the parameter names of the current API are used below. Confirm
# against the installed Albumentations version.
train_transform = A.Compose([
    # Random crop + resize to the training resolution (better generalization
    # than a plain Resize).
    A.RandomResizedCrop(
        size=(cfg.IMAGE_SIZE, cfg.IMAGE_SIZE),
        scale=(0.7, 1.0),
        p=1.0
    ),

    # Geometric augmentations
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.1,
        scale_limit=0.15,
        rotate_limit=15,
        border_mode=0,  # constant border: exposed pixels become 0 in image and mask
        p=0.5
    ),

    # Spatial augmentations (at most one of them per sample)
    A.OneOf([
        A.ElasticTransform(alpha=1, sigma=50, p=1),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=1),
        # `shift_limit` was dropped: it is one of the arguments the run warned about.
        A.OpticalDistortion(distort_limit=0.5, p=1),
    ], p=0.3),

    # Color augmentations
    A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.6),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.4),

    # Noise and blur
    A.OneOf([
        # Replaces var_limit=(10.0, 50.0) on the 0-255 scale:
        # std = sqrt(var) / 255  ->  (0.0124, 0.0277) as a fraction of the value range.
        A.GaussNoise(std_range=(0.0124, 0.0277), p=1),
        A.GaussianBlur(blur_limit=(3, 7), p=1),
    ], p=0.3),

    # Simulate missing data/clouds: 1-5 holes of 8-32 px, filled with the
    # default value 0 (same intent as the former max_*/min_*/fill_value arguments).
    A.CoarseDropout(
        num_holes_range=(1, 5),
        hole_height_range=(8, 32),
        hole_width_range=(8, 32),
        p=0.25
    ),

    # ImageNet statistics, matching the pretrained encoder.
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
    ),
    ToTensorV2(),
])

# Validation pipeline: deterministic resize to the training resolution followed
# by the same ImageNet normalisation as the training pipeline; no augmentation.
_val_steps = [
    A.Resize(height=cfg.IMAGE_SIZE, width=cfg.IMAGE_SIZE),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean
        std=[0.229, 0.224, 0.225],   # ImageNet std
        max_pixel_value=255.0,
    ),
    ToTensorV2(),
]
val_transform = A.Compose(_val_steps)

class LandCoverDataset(Dataset):
    """Land-cover segmentation dataset backed by raster files read with rasterio.

    Each item is an ``(image, mask)`` pair. The image is read from bands 1-3 and
    converted to ``uint8`` HWC before the transform is applied; the mask is read
    from band 1 and returned as a ``torch.long`` tensor of class indices.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of label file paths, aligned with ``image_paths``.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
        strict: When ``True`` a sample that fails to load re-raises the error.
            When ``False`` (default, the original behaviour) the error is
            printed and an all-zero dummy sample is returned instead.

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform=None, strict=False):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                "image_paths and mask_paths must have the same length, "
                f"got {len(image_paths)} and {len(mask_paths)}"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.strict = strict
        self.num_classes = cfg.NUM_CLASSES

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    @staticmethod
    def _to_uint8(image):
        """Convert an image array to uint8 (uint16 is divided by 256, floats are treated as [0, 1])."""
        if image.dtype == np.uint16:
            return (image / 256).astype(np.uint8)
        if image.dtype in [np.float32, np.float64]:
            return (np.clip(image, 0, 1) * 255).astype(np.uint8)
        return image.astype(np.uint8)

    def __getitem__(self, idx):
        """Load, augment and return the ``(image, mask)`` pair at position ``idx``."""
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            # Bands 1-3 come back as (C, H, W); transpose to HWC for the transform.
            with rasterio.open(img_path, 'r') as f:
                image = self._to_uint8(f.read([1, 2, 3]).transpose(1, 2, 0))

            # Keep the mask as uint8 while it is augmented: OpenCV-backed geometric
            # transforms do not handle int64 arrays, and such a failure would be
            # hidden by the fallback below. The conversion to long happens after
            # the transform. Assumes class ids fit in uint8 (NUM_CLASSES is 9 here).
            with rasterio.open(mask_path, 'r') as f:
                mask = f.read(1).astype(np.uint8)

            if self.transform:
                augmented = self.transform(image=image, mask=mask)
                image = augmented["image"]
                mask = augmented["mask"]

            # The mask is a numpy array when no tensor conversion ran in the transform.
            if isinstance(mask, np.ndarray):
                mask = torch.from_numpy(mask)

            return image, mask.long()

        except Exception as e:
            if self.strict:
                raise
            print(f"Error loading {img_path}: {e}")
            # Best-effort fallback so one unreadable file does not abort an epoch.
            # NOTE(review): the dummy mask is all zeros, i.e. it is indistinguishable
            # from a sample labelled entirely with class 0; use strict=True to
            # surface systematic loading problems.
            return (
                torch.zeros(3, cfg.IMAGE_SIZE, cfg.IMAGE_SIZE),
                torch.zeros(cfg.IMAGE_SIZE, cfg.IMAGE_SIZE, dtype=torch.long)
            )

  original_init(self, **validated_kwargs)
  A.OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=1),
  A.GaussNoise(var_limit=(10.0, 50.0), p=1),
  A.CoarseDropout(


In [None]:
import torch
import torch.nn as nn

import timm
import torch.nn.functional as F # Import functional

class TimmEncoder(nn.Module):
    """Feature-extracting encoder wrapping a timm backbone.

    The backbone is created with ``features_only=True``, so calling it yields
    the intermediate feature maps rather than classification logits.
    ``output_channels`` holds the channel counts reported by the backbone's
    ``feature_info``.
    """
    def __init__(self, model_name='swin_tiny_patch4_window7_224', pretrained=True, in_chans=3):
        super().__init__()
        backbone_options = dict(
            pretrained=pretrained,
            features_only=True,  # return intermediate feature maps for the skip connections
            in_chans=in_chans,
        )
        self.backbone = timm.create_model(model_name, **backbone_options)

        # Channel count of every feature map the backbone returns.
        self.output_channels = self.backbone.feature_info.channels()

    def forward(self, x):
        """Run the backbone on ``x`` and return its feature maps."""
        feature_maps = self.backbone(x)
        return feature_maps

class DecoderBlock(nn.Module):
    """U-Net style decoder stage.

    Doubles the spatial size of the input with a transposed convolution (halving
    its channels), concatenates the skip connection along the channel axis, runs
    two 3x3 conv + ReLU layers and finally applies dropout (or nothing when
    ``dropout_prob`` is not positive).
    """
    def __init__(self, in_channels, skip_channels, out_channels, dropout_prob=0.2):
        super().__init__()
        upsampled_channels = in_channels // 2
        self.upsample = nn.ConvTranspose2d(in_channels, upsampled_channels, kernel_size=2, stride=2)

        # Channels seen by the first conv after concatenating the skip features.
        fused_channels = upsampled_channels + skip_channels
        self.convs = nn.Sequential(
            nn.Conv2d(fused_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        if dropout_prob > 0:
            self.dropout = nn.Dropout(dropout_prob)
        else:
            self.dropout = nn.Identity()

    def forward(self, x, skip_connection):
        """Upsample ``x``, fuse it with ``skip_connection`` and return the refined features."""
        fused = torch.cat([self.upsample(x), skip_connection], dim=1)
        refined = self.convs(fused)
        return self.dropout(refined)

class ShearletExitUNet(nn.Module):
    """FPN-style segmentation network with shearlet-gated skip connections.

    Components built in ``__init__``:
      * ``lsf``: LearnedShearletFrontEnd applied to the raw input.
      * ``encoder``: TimmEncoder producing one feature map per backbone stage.
      * ``lateral_convs`` / ``output_convs``: one 1x1 and one 3x3 conv per
        encoder stage (the FPN lateral and smoothing layers).
      * ``cfs_gates``: one CrossFrequencySkipGate per encoder stage.
      * ``deep_supervision_heads``: one SegmentationHead per pyramid level.

    Notes grounded in the code below:
      * ``dropout_prob`` is accepted but not used anywhere in this class.
      * ``forward`` accepts ``exit_threshold`` but never reads it.
      * ``cfs_gates[0]`` is created but never called by ``forward`` (the last
        encoder feature is not gated).
    """
    def __init__(self, in_channels=3, num_classes=9, shearlet_features=16, dropout_prob=0.2, fpn_out_channels=256, backbone_name='swin_tiny_patch4_window7_224'):
        super().__init__()

        self.lsf = LearnedShearletFrontEnd(in_channels, shearlet_features)
        self.encoder = TimmEncoder(model_name=backbone_name, pretrained=True, in_chans=in_channels)

        # Channel count of every encoder feature map, in the order the backbone returns them.
        encoder_features_info = self.encoder.backbone.feature_info # the actual values depend on the backbone
        encoder_channels = [info['num_chs'] for info in encoder_features_info]


        # --- 1. FPN Layer Initialization ---
        # Lateral layers (1x1 convs) project every encoder feature to fpn_out_channels
        self.lateral_convs = nn.ModuleList()
        for in_chans in encoder_channels:
            self.lateral_convs.append(
                nn.Conv2d(in_chans, fpn_out_channels, kernel_size=1)
            )

        # Output layers (3x3 conv + ReLU) to create the final feature pyramid
        self.output_convs = nn.ModuleList()
        for _ in range(len(encoder_channels)):
            self.output_convs.append(
                nn.Sequential(
                    nn.Conv2d(fpn_out_channels, fpn_out_channels, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True)
                )
            )

        # --- 2. Custom Gates and Prediction Heads ---
        self.cfs_gates = nn.ModuleList()
        self.deep_supervision_heads = nn.ModuleList()

        # Gates are stored in REVERSED encoder order: cfs_gates[j] expects the
        # channel count of encoder feature len(encoder_channels) - 1 - j.
        skip_chans = encoder_channels[::-1]

        # The decoder input of every gate is a pyramid feature, i.e. fpn_out_channels wide
        gate_decoder_channels = [fpn_out_channels] * len(encoder_channels)

        for i in range(len(encoder_channels)):
            self.cfs_gates.append(
                CrossFrequencySkipGate(skip_chans[i], gate_decoder_channels[i], shearlet_features)
            )
            self.deep_supervision_heads.append(
                SegmentationHead(fpn_out_channels, num_classes)
            )

        # Note: The early-exit modules are omitted here for simplicity but could be
        # added back, attached to the final FPN feature maps.

    def forward(self, x, training=True, exit_threshold=0.8):
        """Compute segmentation logits for ``x``.

        Args:
            x: Input batch; it is passed to both the shearlet front-end and the encoder.
            training: Selects the return format (independent of ``self.training``).
            exit_threshold: Unused.

        Returns:
            If ``training`` is true: a list with one logits tensor per pyramid
            level, ordered from the level built on ``encoder_features[-1]`` to the
            level built on ``encoder_features[0]``, each at its own resolution.
            Otherwise: the logits of the level built on ``encoder_features[0]``,
            bilinearly resized to the spatial size of ``x``.
        """
        f_shear = self.lsf(x)
        encoder_features = self.encoder(x) # list of feature maps, one per backbone stage
        # NOTE(review): the code below treats every feature as channels-first
        # (N, C, H, W). Some timm backbones (presumably the Swin default of this
        # class) return channels-last features, which would need the permute
        # below - confirm for the chosen backbone:
        # encoder_features = [f.permute(0, 3, 1, 2) for f in encoder_features]


        # --- 3. FPN Top-Down Pathway Implementation ---

        # Start with the last encoder feature; it is not gated.
        p5 = self.lateral_convs[-1](encoder_features[-1])
        p5_out = self.output_convs[-1](p5)
        pyramid_features = [p5_out]

        # Walk the remaining encoder features from second-to-last down to the first
        for i in range(len(encoder_features) - 2, -1, -1):
            # Previous pyramid level (built from encoder feature i + 1) and current encoder feature
            p_prev = pyramid_features[0] # The latest feature added to the pyramid
            c_curr = encoder_features[i]

            # --- 4. Integrate Your Custom Gate ---
            # Gates are stored in reversed order, hence the index len - 1 - i;
            # the gate uses p_prev as its decoder context.
            gated_skip = self.cfs_gates[len(encoder_features) - 1 - i](c_curr, p_prev, f_shear)

            # Upsample previous P and add it to the projected (gated) current C
            p_prev_upsampled = F.interpolate(p_prev, size=c_curr.shape[-2:], mode='bilinear')
            p_curr = self.lateral_convs[i](gated_skip) + p_prev_upsampled

            # Apply output conv and add to the front of the pyramid list
            p_curr_out = self.output_convs[i](p_curr)
            pyramid_features.insert(0, p_curr_out)

        # --- 5. Generate Predictions from the Pyramid ---
        all_predictions = []
        # pyramid_features now follows the encoder order (index 0 <-> encoder_features[0])
        for i, p_out in enumerate(pyramid_features):
            # Head i is applied to pyramid level i.
            prediction = self.deep_supervision_heads[i](p_out)
            all_predictions.append(prediction)

        if training:
            # Reverse so the prediction built on encoder_features[0] comes LAST;
            # the training loop gives the last entry the largest loss weight.
            return all_predictions[::-1]
        else:
            # For inference only the prediction from pyramid level 0 is used,
            # resized to the input resolution.
            final_pred = F.interpolate(all_predictions[0], size=x.shape[-2:], mode='bilinear')
            return final_pred

In [None]:
class LearnedShearletFrontEnd(nn.Module):
    """Bank of small learned filters producing one feature channel per sub-band.

    Every sub-band branch is a 1x1 convolution (a linear mix of the input
    channels into a single channel) followed by a 3x3 convolution; both use
    'same' padding, so the spatial size of the input is preserved.
    """
    def __init__(self, in_channels, num_subbands=16):
        super().__init__()
        branches = []
        for _ in range(num_subbands):
            branches.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, 1, kernel_size=1, padding='same'),
                    nn.Conv2d(1, 1, kernel_size=3, padding='same'),
                )
            )
        self.subband_convs = nn.ModuleList(branches)

    def forward(self, x):
        """Return the branch outputs concatenated along the channel axis."""
        per_band = [branch(x) for branch in self.subband_convs]
        return torch.cat(per_band, dim=1)

In [None]:
class CrossFrequencySkipGate(nn.Module):
    """Attention gate for a skip connection, driven by decoder and shearlet features.

    The gate is computed from the concatenation of the skip features, the
    decoder features and the shearlet features (3x3 conv, ReLU, 1x1 conv,
    sigmoid) and multiplied element-wise onto the skip features.
    """
    def __init__(self, skip_channels, decoder_channels, shearlet_channels):
        super().__init__()
        fused_channels = skip_channels + decoder_channels + shearlet_channels
        self.gate_conv = nn.Sequential(
            nn.Conv2d(fused_channels, skip_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(skip_channels, skip_channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, s_i, d_i, f_shear):
        """Return ``s_i`` scaled by the learned gate.

        Args:
            s_i: Skip features from the encoder; they define the target spatial size.
            d_i: Decoder features from the previous (deeper) stage.
            f_shear: Features from the shearlet front-end.
        """
        spatial_size = s_i.shape[2:]
        # Bring both context tensors to the resolution of the skip features.
        resized_context = [
            nn.functional.interpolate(context, size=spatial_size, mode='bilinear')
            for context in (d_i, f_shear)
        ]
        gate = self.gate_conv(torch.cat([s_i, *resized_context], dim=1))
        return s_i * gate

In [None]:
# The PEM just needs to produce a single logit per pixel
class PerPixelExitModule(nn.Module):
    """Small conv head that outputs a single logit per pixel.

    A 3x3 convolution halves the channel count, followed by ReLU and a 1x1
    convolution down to one output channel.
    """
    def __init__(self, in_channels):
        super().__init__()
        hidden_channels = in_channels // 2
        self.exit_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, 1, kernel_size=1),  # one logit per pixel
        )

    def forward(self, x):
        """Return the per-pixel logit map for ``x``."""
        exit_logits = self.exit_conv(x)
        return exit_logits

# The supervision head is a standard segmentation classifier
class SegmentationHead(nn.Module):
    """Per-pixel classifier: a single 1x1 convolution from features to class logits."""
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        """Return class logits with the same spatial size as ``x``."""
        class_logits = self.conv(x)
        return class_logits

In [None]:
import matplotlib.pyplot as plt

def plot_loss_and_metric_curves(train_loss_history, val_loss_history, val_iou_history, val_accuracy_history):
    """Draw two side-by-side panels: train/val loss and validation IoU/accuracy.

    Each history holds one value per epoch; the x axis runs from 1 to
    ``len(train_loss_history)``. Switches matplotlib to the 'dark_background'
    style and shows the figure.
    """
    plt.style.use('dark_background')

    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(16, 6))
    epochs = range(1, len(train_loss_history) + 1)

    # One entry per panel: (axes, title, y label, curves); every curve is
    # (values, colour, legend label).
    panels = (
        (loss_ax, 'Training and Validation Loss', 'Loss', (
            (train_loss_history, 'cyan', 'Training Loss'),
            (val_loss_history, 'magenta', 'Validation Loss'),
        )),
        (metric_ax, 'Validation Metrics', 'Score', (
            (val_iou_history, 'lime', 'Validation IoU (macro)'),
            (val_accuracy_history, 'yellow', 'Validation Accuracy (macro)'),
        )),
    )

    for ax, title, y_label, curves in panels:
        for values, colour, label in curves:
            ax.plot(epochs, values, 'o-', color=colour, label=label)
        ax.set_title(title, fontsize=16, color='white')
        ax.set_xlabel('Epochs', fontsize=12, color='white')
        ax.set_ylabel(y_label, fontsize=12, color='white')
        ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray')
        for axis_name in ('x', 'y'):
            ax.tick_params(axis=axis_name, colors='white')
        ax.legend(frameon=True).get_frame().set_facecolor('black')
        # Grey frame around the panel.
        for spine in ax.spines.values():
            spine.set_edgecolor('gray')

    plt.tight_layout()
    plt.show()

In [None]:
import torch
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# --- Constants and Color Map ---
# Display name and hex colour per class. Position i in CLASS_NAMES / COLOR_MAP is
# what the helpers below use for mask value i.
_CLASS_PALETTE = (
    ("Building", "#DE1F07"),
    ("Agriculture land", "#4BB549"),
    ("Water", "#0045FF"),
    ("Tree", "#226126"),
    ("Road", "#FFFFFF"),
    ("Developed space", "#949494"),
    ("Rangeland", "#00FF24"),
    ("Bareland", "#800000"),
    ("Background", "#000000"),
)
CLASS_NAMES = [class_name for class_name, _ in _CLASS_PALETTE]
COLOR_MAP = [hex_colour for _, hex_colour in _CLASS_PALETTE]

def hex_to_rgb(hex_code):
    """Turn a hex colour such as '#DE1F07' into an ``(R, G, B)`` tuple of ints."""
    digits = hex_code.lstrip('#')
    # Parse the three two-character pairs ('DE', '1F', '07') as base-16 numbers.
    red, green, blue = (int(digits[start:start + 2], 16) for start in range(0, 6, 2))
    return (red, green, blue)

# RGB triples for every entry of COLOR_MAP, as a Python list and as a uint8
# array that can be indexed directly with a mask of class ids.
color_map_rgb = list(map(hex_to_rgb, COLOR_MAP))

color_map_np = np.array(color_map_rgb, dtype=np.uint8)

# --- 1. UPDATED: Helper Functions ---

def denormalize(tensor, mean, std):
    """Undo per-channel normalisation and return an HWC numpy image in [0, 1].

    Args:
        tensor: Channel-first image tensor (C, H, W). It is not modified.
        mean: Per-channel means that were subtracted during normalisation.
        std: Per-channel standard deviations that were divided by.

    Returns:
        numpy array of shape (H, W, C) with values clamped to [0, 1].
    """
    # Work on a detached CPU copy: the caller's tensor stays untouched, and
    # tensors that live on a GPU or require grad can still be turned into numpy.
    tensor = tensor.detach().cpu().clone()
    for channel, m, s in zip(tensor, mean, std):
        # In-place on the channel view: x * std + mean
        channel.mul_(s).add_(m)
    # Clamp values to [0, 1], move channels last and convert to numpy
    return torch.clamp(tensor, 0, 1).permute(1, 2, 0).numpy()

def mask_to_rgb(mask, color_map):
    """Colourise a mask of class ids by looking every id up in ``color_map``.

    The lookup is a single indexing operation: each value in ``mask`` selects
    the row of ``color_map`` at that position.
    """
    rgb_image = color_map[mask]
    return rgb_image

def calculate_percentages(mask, class_names):
    """Return ``{class name: share of pixels in percent}`` for the ids present in ``mask``.

    Ids that are not smaller than ``len(class_names)`` are left out.
    """
    total_pixels = mask.size
    class_ids, pixel_counts = np.unique(mask, return_counts=True)
    named_classes = len(class_names)
    return {
        class_names[class_id]: (pixel_count / total_pixels) * 100
        for class_id, pixel_count in zip(class_ids, pixel_counts)
        if class_id < named_classes
    }

# --- 2. UPDATED: Core Prediction Function ---

def predict(model, image_tensor, device):
    """Run the model on one image tensor and return the per-pixel class ids as numpy.

    The model is switched to eval mode, called with ``training=False`` on a
    batch of size one, and the arg-max over the channel axis is returned.
    """
    model.eval()
    with torch.no_grad():
        # Add the batch dimension and move the image to the target device.
        batch = image_tensor.unsqueeze(0).to(device)
        logits = model(batch, training=False)
        # Most likely class per pixel, with the batch dimension removed again.
        class_map = torch.argmax(logits, dim=1)
        pred_mask = class_map.squeeze(0).cpu().numpy()

    return pred_mask

# --- 3. UPDATED: Main Visualization Function ---

def show_predictions(model, data_loader, device, num_samples=3):
    """Plot image, ground-truth mask and predicted mask for samples of one batch.

    Args:
        model: Segmentation model, called through ``predict``.
        data_loader: Loader yielding ``(images, masks)`` batches; only the first
            batch is used.
        device: Device the prediction runs on.
        num_samples: Maximum number of samples to plot; capped at the batch size.
    """
    # Set plot style
    plt.style.use('default')

    # Get a single batch of data
    images, true_masks = next(iter(data_loader))

    # ImageNet statistics, used to undo the normalisation for display
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Never index past the end of the batch.
    samples_to_show = min(num_samples, len(images))

    for i in range(samples_to_show):
        image_tensor = images[i]
        true_mask_np = true_masks[i].numpy()

        # Get model prediction
        predicted_mask_np = predict(model, image_tensor, device)

        # Convert masks to color images for visualization
        true_mask_rgb = mask_to_rgb(true_mask_np, color_map_np)
        predicted_mask_rgb = mask_to_rgb(predicted_mask_np, color_map_np)

        # Denormalize original image for correct display
        original_img_np = denormalize(image_tensor, imagenet_mean, imagenet_std)

        # Share of every predicted class, shown in the legend
        legend_data = calculate_percentages(predicted_mask_np, CLASS_NAMES)

        # Plotting
        fig, axes = plt.subplots(1, 3, figsize=(20, 6))

        axes[0].imshow(original_img_np)
        axes[0].set_title("Original Image")

        axes[1].imshow(true_mask_rgb)
        axes[1].set_title("Ground Truth Mask")

        axes[2].imshow(predicted_mask_rgb)
        axes[2].set_title("Predicted Mask")

        for ax in axes:
            ax.axis('off')

        # Legend sorted by descending share. COLOR_MAP holds hex strings, which
        # matplotlib accepts directly (dividing them by 255 raised a TypeError).
        # The black edge keeps the white entry visible.
        legend_patches = [
            Patch(
                facecolor=COLOR_MAP[CLASS_NAMES.index(name)],
                edgecolor='black',
                label=f"{name}: {perc:.2f}%",
            )
            for name, perc in sorted(legend_data.items(), key=lambda item: item[1], reverse=True)
        ]

        fig.legend(handles=legend_patches, bbox_to_anchor=(1.05, 0.75), loc='upper left', borderaxespad=0.)

        plt.tight_layout()
        plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    """Dice loss for multi-class segmentation.

    A single Dice score is computed over all classes and pixels of the batch
    (soft-max probabilities against one-hot targets); the loss is one minus
    that score.
    """
    def __init__(self, smooth=1e-6):
        super().__init__()
        # Added to numerator and denominator to avoid a division by zero.
        self.smooth = smooth

    def forward(self, logits, targets):
        """
        Args:
            logits: Tensor of shape (B, C, H, W) from the model.
            targets: Tensor of shape (B, H, W) with ground-truth class indices.

        Returns:
            Scalar tensor ``1 - dice``.
        """
        n_classes = logits.shape[1]
        class_probs = F.softmax(logits, dim=1).reshape(-1)
        # One-hot encode to (B, H, W, C), move the classes to axis 1, then flatten.
        one_hot = (
            F.one_hot(targets, num_classes=n_classes)
            .permute(0, 3, 1, 2)
            .float()
            .reshape(-1)
        )

        overlap = (class_probs * one_hot).sum()
        denominator = class_probs.sum() + one_hot.sum()

        return 1 - (2. * overlap + self.smooth) / (denominator + self.smooth)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """
    Focal loss for multi-class classification / segmentation, usable as a
    drop-in replacement for nn.CrossEntropyLoss.

    Per element the loss is ``alpha[target] * (1 - pt) ** gamma * ce`` where
    ``ce`` is the unweighted cross-entropy and ``pt = exp(-ce)`` is the
    probability the model assigns to the true class.
    """
    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        """
        :param gamma: The focusing parameter (gamma > 0). Default is 2.0.
        :param alpha: Optional class weights: a tensor or sequence of shape (C,).
        :param reduction: 'mean', 'sum', or anything else for no reduction.
        """
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

        # Store the weights as a float tensor so integer sequences such as
        # [2, 1] work as well; the tensor is moved to the right device in forward().
        if alpha is not None:
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.alpha = alpha

    def forward(self, logits, targets):
        """
        :param logits: Raw model outputs with the class scores on axis 1, e.g. [B, C, H, W]
        :param targets: Ground truth class indices, e.g. [B, H, W]
        :return: Reduced loss, or the per-element loss when no reduction is selected.
        """
        # Unweighted per-element cross-entropy. The class weights must NOT be
        # applied here: pt below is only the true-class probability when ce is
        # the plain negative log-likelihood.
        ce_loss = F.cross_entropy(logits, targets, reduction='none')

        # pt = exp(-ce_loss)
        pt = torch.exp(-ce_loss)

        # (1 - pt)^gamma * ce_loss
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        # Class weighting: pick the weight of every element's target class.
        if self.alpha is not None:
            alpha_t = self.alpha.to(device=logits.device, dtype=focal_loss.dtype)[targets]
            focal_loss = alpha_t * focal_loss

        # Apply the reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [None]:
!pip install -q torchmetrics

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/983.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m389.1/983.2 kB[0m [31m11.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from torchmetrics import JaccardIndex, Accuracy

def get_file_paths(split_txt_path):
    """
    Build the list of usable image/label pairs for one split.

    The split name is the file name of ``split_txt_path`` up to its first dot
    (e.g. 'train' for 'train.txt'). Every non-empty line of the file is taken as
    a file name and looked up under ``<IMAGE_BASE_DIR>/<split>/`` and
    ``<LABEL_BASE_DIR>/<split>/``; a pair is kept only when both are regular files.

    Args:
        split_txt_path: Path of the text file listing one file name per line.

    Returns:
        Tuple ``(image_paths, label_paths)`` of equally long, aligned lists.
    """
    image_paths, label_paths = [], []
    with open(split_txt_path, 'r') as f:
        # Skip blank lines: an empty name would resolve to the split directory
        # itself, which exists and would be accepted as a "valid" pair.
        filenames = [name for name in (line.strip() for line in f) if name]

    split_name = os.path.basename(split_txt_path).split('.')[0]

    print(f"Verifying files for split: {split_name}...")
    for fname in tqdm(filenames, desc="Checking file existence"):
        image_path = os.path.join(cfg.IMAGE_BASE_DIR, split_name, fname)
        label_path = os.path.join(cfg.LABEL_BASE_DIR, split_name, fname)

        # isfile (not exists) so that directories are never treated as samples.
        if os.path.isfile(image_path) and os.path.isfile(label_path):
            image_paths.append(image_path)
            label_paths.append(label_path)

    print(f"Found {len(image_paths)} valid file pairs for the '{split_name}' set.")
    skipped = len(filenames) - len(image_paths)
    if skipped:
        print(f"Skipped {skipped} listed entries with a missing image or label file.")
    return image_paths, label_paths

def main():
    """Train the model, plot the learning curves and show sample predictions.

    Steps:
      1. Build the train/val datasets and loaders from the split files.
      2. Create the ShearletExitUNet with a resnet50 backbone.
      3. Set up the losses (focal + Dice), AdamW with a smaller learning rate
         for the encoder, per-batch cosine annealing and the validation metrics.
      4. Train with deep supervision, validate after every epoch, checkpoint on
         improved mean IoU and stop early when it stalls.
      5. Plot the curves and visualise predictions of the best checkpoint.

    Raises:
        RuntimeError: If the train or val split contains no usable file pair.
    """
    # 1. Prepare Data
    train_imgs, train_masks = get_file_paths(cfg.TRAIN_TXT)
    val_imgs, val_masks = get_file_paths(cfg.VAL_TXT)
    if not train_imgs or not val_imgs:
        # Fail early with a clear message instead of a ZeroDivisionError later on.
        raise RuntimeError(
            "No usable image/label pairs found for the train or val split; "
            f"check the dataset layout under {cfg.DATA_PATH!r}."
        )

    train_dataset = LandCoverDataset(train_imgs, train_masks, transform=train_transform)
    val_dataset = LandCoverDataset(val_imgs, val_masks, transform=val_transform)

    train_loader = DataLoader(
        train_dataset, batch_size=cfg.BATCH_SIZE, num_workers=cfg.NUM_WORKERS,
        pin_memory=cfg.PIN_MEMORY, shuffle=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=cfg.BATCH_SIZE, num_workers=cfg.NUM_WORKERS,
        pin_memory=cfg.PIN_MEMORY, shuffle=False
    )

    # 2. Initialize Model
    model = ShearletExitUNet(
        in_channels=3,
        num_classes=cfg.NUM_CLASSES,
        shearlet_features=8,
        dropout_prob=0.4,
        fpn_out_channels=cfg.FPN_CHANNELS,
        backbone_name='resnet50'
    ).to(cfg.DEVICE)

    # --- 3. Define Loss, Optimizer, and Scheduler ---

    # One weight per prediction head (the model returns one prediction per
    # encoder feature). The weights grow linearly, so the last prediction in the
    # returned list gets weight 1.0.
    num_encoder_features = len(model.encoder.output_channels)
    exit_weights = [i / num_encoder_features for i in range(1, num_encoder_features + 1)]

    # Use class weights if you have class imbalance, e.g. FocalLoss(alpha=class_weights)
    loss_fn_ce = FocalLoss(gamma=2.0)
    loss_fn_dice = DiceLoss()

    # Mixing factors of the focal and Dice terms.
    loss_alpha = 0.5
    loss_beta = 0.5

    # Differential learning rate: 10x smaller for the pretrained encoder than
    # for the newly initialised layers.
    encoder_params = list(model.encoder.parameters())
    decoder_params = [p for name, p in model.named_parameters() if not name.startswith("encoder.")]
    optimizer = optim.AdamW([
        {'params': encoder_params, 'lr': cfg.LEARNING_RATE * 0.1},
        {'params': decoder_params, 'lr': cfg.LEARNING_RATE}
    ], weight_decay=cfg.WEIGHT_DECAY)

    # Mixed precision is only used on CUDA. The gradient scaler is paired with
    # autocast in the training loop (it has no purpose without it) and is a
    # transparent no-op when disabled.
    use_amp = str(cfg.DEVICE).startswith("cuda")
    amp_device_type = "cuda" if use_amp else "cpu"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=len(train_loader) * cfg.EPOCHS,  # total number of training steps
        eta_min=1e-6  # minimum learning rate
    )

    # --- Initialize Metrics ---
    # 'macro' averaging: mean over the per-class scores
    val_iou = JaccardIndex(task='multiclass', num_classes=cfg.NUM_CLASSES, average='macro').to(cfg.DEVICE)
    val_accuracy = Accuracy(task='multiclass', num_classes=cfg.NUM_CLASSES, average='macro').to(cfg.DEVICE)

    # 4. Run Training Loop...
    train_loss_history = []
    val_loss_history = []
    val_iou_history = []
    val_accuracy_history = []
    checkpoint_path = 'best_model.pth'
    checkpoint_saved = False
    best_val_iou = float('-inf')
    epochs_no_improve = 0

    for epoch in range(cfg.EPOCHS):
        model.train()
        train_loss = 0.0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.EPOCHS}")

        for data, targets in loop:
            data = data.to(cfg.DEVICE)
            targets_on_device = targets.to(cfg.DEVICE)

            with torch.autocast(device_type=amp_device_type, enabled=use_amp):
                all_predictions = model(data, training=True)

                # --- Weighted deep-supervision loss ---
                total_loss = 0
                for i, pred in enumerate(all_predictions):
                    # Every head predicts at its own resolution: bring the labels
                    # to that size with nearest-neighbour so they stay class ids.
                    target_resized = F.interpolate(
                        targets_on_device.unsqueeze(1).float(),
                        size=pred.shape[2:],
                        mode='nearest'
                    ).squeeze(1).long()

                    ce_loss = loss_fn_ce(pred, target_resized)
                    dice_loss = loss_fn_dice(pred, target_resized)

                    combined_loss = (loss_alpha * ce_loss) + (loss_beta * dice_loss)

                    total_loss += exit_weights[i] * combined_loss

            optimizer.zero_grad()
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()  # the schedule is defined per batch (see T_max)

            train_loss += total_loss.item()
            loop.set_postfix(loss=total_loss.item())

        avg_train_loss = train_loss / len(train_loader)
        train_loss_history.append(avg_train_loss)

        # --- Validation Loop ---
        model.eval()
        val_loss = 0.0
        val_iou.reset()  # metrics accumulate over one validation pass only
        val_accuracy.reset()

        with torch.no_grad():
            for data, targets in val_loader:
                data = data.to(cfg.DEVICE)
                targets_on_device = targets.to(cfg.DEVICE)

                # With training=False the model returns only the final prediction,
                # already resized to the input resolution.
                final_prediction_logits = model(data, training=False)

                # Predicted class index per pixel, for the metrics
                final_prediction_mask = torch.argmax(final_prediction_logits, dim=1)

                # Same focal/Dice mix as in training. The value still is not
                # directly comparable with the training loss, which sums the
                # weighted losses of all heads.
                val_ce = loss_fn_ce(final_prediction_logits, targets_on_device)
                val_dice = loss_fn_dice(final_prediction_logits, targets_on_device)
                val_loss += ((loss_alpha * val_ce) + (loss_beta * val_dice)).item()

                # Update metrics
                val_iou.update(final_prediction_mask, targets_on_device)
                val_accuracy.update(final_prediction_mask, targets_on_device)

        avg_val_loss = val_loss / len(val_loader)
        # Plain floats from here on: histories, comparison and formatting.
        avg_val_iou = val_iou.compute().item()
        avg_val_accuracy = val_accuracy.compute().item()

        val_loss_history.append(avg_val_loss)
        val_iou_history.append(avg_val_iou)
        val_accuracy_history.append(avg_val_accuracy)

        print(f"Epoch {epoch+1} Summary -> Avg Train Loss: {avg_train_loss:.4f} | Avg Val Loss: {avg_val_loss:.4f} | Avg Val IoU: {avg_val_iou:.4f} | Avg Val Acc: {avg_val_accuracy:.4f}")

        # --- Early Stopping Check ---
        if avg_val_iou > best_val_iou + cfg.MIN_DELTA:
            best_val_iou = avg_val_iou
            epochs_no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"Mean IoU improved. Saving model. Best IoU: {best_val_iou:.4f}")
        else:
            epochs_no_improve += 1
            print(f"Mean IoU did not improve for {epochs_no_improve} epochs.")
            if epochs_no_improve >= cfg.EARLY_STOPPING_PATIENCE:
                print(f"Early stopping triggered at epoch {epoch + 1}.")
                break

    print("Training finished.")
    plot_loss_and_metric_curves(train_loss_history, val_loss_history, val_iou_history, val_accuracy_history)

    if checkpoint_saved:
        print("Loading best model for final predictions...")
        # map_location keeps the load working when the device differs from the
        # one the checkpoint was written on.
        model.load_state_dict(torch.load(checkpoint_path, map_location=cfg.DEVICE))
    show_predictions(model, val_loader, cfg.DEVICE)

# Entry point: run the full pipeline defined in main() when this cell (or the
# file as a script) is executed. Importing the code as a module does not trigger it.
if __name__ == "__main__":
    main()

Verifying files for split: train...


Checking file existence: 100%|██████████| 3000/3000 [00:00<00:00, 82867.91it/s]


Found 2303 valid file pairs for the 'train' set.
Verifying files for split: val...


Checking file existence: 100%|██████████| 500/500 [00:00<00:00, 64971.56it/s]

Found 384 valid file pairs for the 'val' set.



  scaler = torch.cuda.amp.GradScaler()
Epoch 1/40:   0%|          | 2/576 [01:03<5:03:02, 31.68s/it, loss=3.85]


KeyboardInterrupt: 