# Land Cover Classification using DeepLabV3+ (ResNet50 Backbone)

## Kaggle Environment Setup

Follow these steps to run this notebook on Kaggle:

1.  **Dataset:** Add the `deepglobe-land-cover-classification-dataset` via **+ Add Data**. The notebook expects it at `/kaggle/input/deepglobe-land-cover-classification-dataset/`.
2.  **Internet:** Turn **ON** Internet access in the Notebook settings (right panel -> Settings) to allow package installation.
3.  **Accelerator:** Select a **GPU** (e.g., T4 x2, P100) for faster training.
4.  **Resource Limits:** If you hit Kaggle's time or memory limits:
    *   Reduce `NUM_EPOCHS` (e.g., to 5-10).
    *   Reduce `BATCH_SIZE` (e.g., to 4 or 2).
    *   Consider reducing `IMAGE_SIZE` (e.g., to 256 or 384), though this might affect accuracy.

## 1. Setup Environment

Install and import necessary libraries.

In [None]:
# Install required libraries
!pip install --upgrade pip -q
# Install core libraries, torchmetrics, and torch_xla (torch_xla is only needed for TPU runs)
!pip install -q segmentation-models-pytorch albumentations Pillow torch torchvision torchaudio numpy matplotlib seaborn scikit-learn torchmetrics torchinfo torch_xla

**Note on Dependency Conflicts:**

You might see `ERROR: pip's dependency resolver...` messages after the installation. These often occur in pre-configured environments like Kaggle or Colab where existing packages (e.g., `gcsfs`, `bigframes`, `rich`, RAPIDS libraries) have version constraints that conflict with newly installed ones.

For this specific notebook, these conflicts are generally **safe to ignore** as they typically involve packages not directly used for the core DeepLabV3+ training and prediction tasks (which rely on `torch`, `segmentation-models-pytorch`, `albumentations`, `torchmetrics` etc.). Attempting to fix them might break other functionalities of the environment.
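
If you want to confirm that the core packages named above were installed despite the resolver messages, a small check like the following may help. It is a minimal sketch using only the standard library; the helper name `installed_versions` and the `CORE_PACKAGES` list are illustrative and not part of the notebook.

```python
from importlib import metadata

# Distribution names (as used by pip) of the packages the notebook relies on directly.
CORE_PACKAGES = ["torch", "segmentation-models-pytorch", "albumentations", "torchmetrics"]


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(CORE_PACKAGES).items():
    print(f"{name}: {version or 'NOT INSTALLED'}")
```

If any of these report `NOT INSTALLED`, the conflict is worth investigating; otherwise the resolver messages most likely concern unrelated packages.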

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import torchmetrics # Use torchmetrics
import random
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm.notebook import tqdm # Progress bars
import time
from torchinfo import summary # For model summary
import torch.nn.functional as F # For loss functions
import gc # Garbage collection

# --- TPU Setup (if applicable) ---
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    print("torch_xla found.")
    _xla_available = True
except ImportError:
    print("torch_xla not found, TPU support disabled.")
    _xla_available = False

print(f"PyTorch version: {torch.__version__}")
print(f"Segmentation Models Pytorch version: {smp.__version__}")
print(f"TorchMetrics version: {torchmetrics.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

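# --- Device Selection (Prioritize TPU > CUDA > CPU) ---
# torch_xla is pip-installed above, so its import can succeed on machines without a TPU.
# Only use the XLA device if it initializes and reports TPU hardware.
DEVICE = None
if _xla_available:
    try:
        xla_dev = xm.xla_device()
        if xm.xla_device_hw(xla_dev) == 'TPU':
            DEVICE = xla_dev
            print(f"Using TPU: {xm.xla_real_devices([str(DEVICE)])[0]}")
    except Exception as e:
        print(f"XLA device unavailable ({e}); falling back to CUDA/CPU.")
if DEVICE is None:
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
        print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
    else:
        DEVICE = torch.device("cpu")
        print("Using CPU device")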

## 2. Load Dataset Metadata & Initial Analysis

Load the metadata file, perform initial checks, handle missing mask paths, and analyze the data split.

**Important:** Assumes the dataset is at `/kaggle/input/deepglobe-land-cover-classification-dataset/`.

In [None]:
# Define dataset paths
dataset_base_dir = '/kaggle/input/deepglobe-land-cover-classification-dataset' # Original Kaggle path

# metadata.csv sits directly inside the base directory, so the root is the base directory itself.
dataset_root_dir = dataset_base_dir
metadata_path = os.path.join(dataset_root_dir, 'metadata.csv')

# Load metadata
metadata_df = None
if not os.path.exists(dataset_root_dir):
    print(f"Error: Dataset directory not found at {dataset_root_dir}")
    print("Please ensure the DeepGlobe dataset is added via '+ Add Data' and the path is correct.")
elif not os.path.exists(metadata_path):
    print(f"Error: metadata.csv not found at {metadata_path}")
else:
    print(f"Dataset directory found. Loading metadata from: {metadata_path}")
    try:
        metadata_df = pd.read_csv(metadata_path)
        print(f"Metadata loaded successfully. Initial shape: {metadata_df.shape}")
        
        # --- Data Cleaning and Path Handling ---
        print("\n--- Initial Data Stats ---")
        print("Value counts for 'split' column:")
        print(metadata_df['split'].value_counts())
        print("\nNull values per column:")
        print(metadata_df.isnull().sum())
        
        # Check for nulls specifically in path columns
        null_sat_paths = metadata_df['sat_image_path'].isnull().sum()
        null_mask_paths = metadata_df['mask_path'].isnull().sum()
        print(f"\nNumber of null 'sat_image_path': {null_sat_paths}")
        print(f"Number of null 'mask_path': {null_mask_paths}")
        
        # Drop rows where mask_path is null, as they are unusable for supervised learning
        if null_mask_paths > 0:
            print(f"\nDropping {null_mask_paths} rows with missing 'mask_path'...")
            metadata_df.dropna(subset=['mask_path'], inplace=True)
            print(f"Shape after dropping null mask paths: {metadata_df.shape}")
        
        # Drop rows where sat_image_path is null (if any)
        if null_sat_paths > 0 and 'sat_image_path' in metadata_df.columns: # Check if column still exists
             initial_rows = metadata_df.shape[0]
             metadata_df.dropna(subset=['sat_image_path'], inplace=True)
             rows_dropped = initial_rows - metadata_df.shape[0]
             if rows_dropped > 0:
                 print(f"Dropped {rows_dropped} rows with missing 'sat_image_path'.")
                 print(f"Shape after dropping null sat image paths: {metadata_df.shape}")
        
        # Prepend the root directory to the relative paths
        metadata_df['sat_image_path'] = metadata_df['sat_image_path'].apply(lambda x: os.path.join(dataset_root_dir, x) if isinstance(x, str) else x)
        metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda x: os.path.join(dataset_root_dir, x) if isinstance(x, str) else x)
        
        print("\n--- Data Stats After Cleaning --- ")
        print("Value counts for 'split' column:")
        print(metadata_df['split'].value_counts())
        print("\nNull values per column (should be 0 for paths):")
        print(metadata_df.isnull().sum())
        
    except Exception as e:
        print(f"Error loading or processing metadata.csv: {e}")
        metadata_df = None # Ensure df is None if error occurred

# Define class labels and color mapping (assuming this remains the same)
id2label = {
    0: 'urban_land', 1: 'agriculture_land', 2: 'rangeland',
    3: 'forest_land', 4: 'water', 5: 'barren_land',
    6: 'unknown' # Often used for background or ignored areas
}
label2id = {v: k for k, v in id2label.items()}
NUM_CLASSES = len(id2label) # Define NUM_CLASSES based on the dataset
class_names = list(id2label.values())
IGNORE_INDEX = 6 # Class ID to ignore during training and evaluation

# RGB color to Class ID mapping (Verify these with dataset documentation)
rgb_to_id = {
    (0, 255, 255): 0,  # Urban (Cyan)
    (255, 255, 0): 1,  # Agriculture (Yellow)
    (255, 0, 255): 2,  # Rangeland (Magenta)
    (0, 255, 0): 3,    # Forest (Green)
    (0, 0, 255): 4,    # Water (Blue)
    (255, 255, 255): 5,# Barren (White)
    (0, 0, 0): 6       # Unknown (Black)
}
id_to_rgb = {v: k for k, v in rgb_to_id.items()}

def rgb_mask_to_class_id_mask(mask_img):
    """Converts an RGB mask image (PIL Image) to a 2D numpy array of class IDs."""
    mask_arr = np.array(mask_img.convert('RGB'))
    class_mask = np.full(mask_arr.shape[:2], IGNORE_INDEX, dtype=np.uint8) # Default to ignore index
    for rgb, class_id in rgb_to_id.items():
        matches = np.all(mask_arr == np.array(rgb).reshape(1, 1, 3), axis=2)
        class_mask[matches] = class_id
    return class_mask

# Function to load and verify data paths for a given split
# Updated: Assumes df has already been cleaned (null paths dropped)
def load_data_paths(df, split):
    if df is None:
        print(f"Metadata DataFrame not loaded. Cannot load paths for split '{split}'.")
        return [], []

    print(f"\nLoading paths for split: '{split}'")
    split_df = df[df['split'] == split].copy()

    if 'sat_image_path' not in split_df.columns or 'mask_path' not in split_df.columns:
        print(f"Error: 'sat_image_path' or 'mask_path' column not found in metadata for split '{split}'.")
        return [], []

    if split_df.empty:
        print(f"No entries found for split '{split}' in metadata.")
        return [], []

    valid_image_paths = []
    valid_mask_paths = []
    skipped_missing_file = 0

    # Iterate through rows, checking file existence (paths assumed valid strings now)
    for index, row in split_df.iterrows():
        img_path = row['sat_image_path']
        mask_path = row['mask_path']

        # Check if paths are strings before checking existence
        if isinstance(img_path, str) and isinstance(mask_path, str):
            if os.path.exists(img_path) and os.path.exists(mask_path):
                valid_image_paths.append(img_path)
                valid_mask_paths.append(mask_path)
            else:
                skipped_missing_file += 1
                # Optional: Print missing paths for debugging
                # if not os.path.exists(img_path): print(f"Missing image: {img_path}")
                # if not os.path.exists(mask_path): print(f"Missing mask: {mask_path}")
        else:
             # This case should ideally not happen after dropping NaNs, but good to handle
             skipped_missing_file += 1 
             # print(f"Skipping row {index} due to non-string path: Img={img_path}, Mask={mask_path}")

    if skipped_missing_file > 0:
         print(f"Warning: Skipped {skipped_missing_file} pairs due to missing files or invalid paths for split '{split}'.")

    if not valid_image_paths:
        print(f"No valid image-mask pairs found for split '{split}' after checking existence.")
        return [], []

    print(f"Returning {len(valid_image_paths)} valid pairs for split '{split}'.")
    return valid_image_paths, valid_mask_paths

# Load paths for each split using the cleaned DataFrame
train_image_paths, train_mask_paths = load_data_paths(metadata_df, 'train')
val_image_paths, val_mask_paths = load_data_paths(metadata_df, 'valid')
test_image_paths, test_mask_paths = load_data_paths(metadata_df, 'test')

# --- Create Validation Split from Training Data if Validation is Empty ---
if not val_image_paths and train_image_paths:
    print("\nValidation set from metadata is empty. Creating validation split from training data (20%).")
    try:
        from sklearn.model_selection import train_test_split
        # Split the training data into new train and validation sets
        train_image_paths, val_image_paths, train_mask_paths, val_mask_paths = train_test_split(
            train_image_paths, 
            train_mask_paths, 
            test_size=0.2, # Use 20% for validation
            random_state=42 # For reproducibility
        )
        print(f"  New training set size: {len(train_image_paths)}")
        print(f"  New validation set size: {len(val_image_paths)}")
    except ImportError:
        print("  Warning: sklearn not found. Cannot create validation split automatically.")
    except Exception as e:
        print(f"  Error creating validation split: {e}")
elif val_image_paths:
    print("\nUsing validation set defined in metadata.")
else:
    print("\nNo training data available to create a validation split.")
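
To make the colour-to-class conversion concrete, here is a minimal sketch on a 2x2 toy mask. It applies the same matching loop as `rgb_mask_to_class_id_mask` to a raw array; the names `rgb_array_to_class_ids` and `binarize_mask` are illustrative and not part of the notebook. The optional `binarize_mask` step is an assumption: it snaps each channel to 0 or 255 at a threshold of 128, which may help if masks contain slightly off-palette values. The notebook itself does not do this.

```python
import numpy as np

IGNORE_INDEX = 6
rgb_to_id = {
    (0, 255, 255): 0,    # urban
    (255, 255, 0): 1,    # agriculture
    (255, 0, 255): 2,    # rangeland
    (0, 255, 0): 3,      # forest
    (0, 0, 255): 4,      # water
    (255, 255, 255): 5,  # barren
    (0, 0, 0): 6,        # unknown
}


def rgb_array_to_class_ids(mask_arr):
    """Map an (H, W, 3) uint8 RGB array to an (H, W) array of class IDs."""
    class_mask = np.full(mask_arr.shape[:2], IGNORE_INDEX, dtype=np.uint8)
    for rgb, class_id in rgb_to_id.items():
        matches = np.all(mask_arr == np.array(rgb, dtype=np.uint8), axis=-1)
        class_mask[matches] = class_id
    return class_mask


def binarize_mask(mask_arr, threshold=128):
    """Optional (assumed) cleanup: snap every channel value to 0 or 255."""
    return np.where(mask_arr >= threshold, 255, 0).astype(np.uint8)


toy = np.array(
    [
        [[0, 255, 255], [255, 255, 0]],
        [[0, 0, 255], [10, 250, 250]],  # last pixel is close to cyan but not exact
    ],
    dtype=np.uint8,
)

ids = rgb_array_to_class_ids(toy)
print(ids.tolist())  # [[0, 1], [4, 6]]: the inexact pixel falls back to IGNORE_INDEX

ids_clean = rgb_array_to_class_ids(binarize_mask(toy))
print(ids_clean.tolist())  # [[0, 1], [4, 0]]: after snapping, it matches urban (cyan)
```

Without the snapping step, any pixel that is not an exact palette colour is silently treated as `unknown`, so it is worth checking `np.unique` on a few converted masks.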

### Clarification on Missing Validation/Test Paths (Original Issue)

In the original `metadata.csv`, the `mask_path` column is empty (read as `NaN` by pandas) for rows with `split='valid'` or `split='test'`. Left unhandled, these rows trigger warnings like `Skipped ... pairs due to invalid path types (e.g., NaN)` when paths are loaded.

**Resolution:** The code above drops rows with a null `mask_path` *before* loading paths for each split, so only rows with usable masks are considered. If the 'valid' and 'test' splits still show 0 pairs after cleaning, the original `metadata.csv` has no mask paths for those entries, and the validation set is carved out of the training data instead.
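
The effect of empty `mask_path` cells can be reproduced on a tiny in-memory CSV. This is a minimal sketch: the rows and file names below are made up for illustration and only mirror the column names used by the notebook.

```python
import io

import pandas as pd

# Hypothetical metadata: the valid/test rows have an empty mask_path field.
csv_text = """image_id,split,sat_image_path,mask_path
1,train,train/1_sat.jpg,train/1_mask.png
2,train,train/2_sat.jpg,train/2_mask.png
3,valid,valid/3_sat.jpg,
4,test,test/4_sat.jpg,
"""

df = pd.read_csv(io.StringIO(csv_text))
n_missing = int(df["mask_path"].isnull().sum())
print(n_missing)  # 2: empty fields are parsed as NaN

cleaned = df.dropna(subset=["mask_path"])
print(cleaned["split"].tolist())  # ['train', 'train']
```

After the drop, only the training rows remain, which is why the 'valid' and 'test' splits can end up with zero pairs.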

### 2.1 Visualize Data Split Distribution

Plot the number of samples in each data split (train, validation, test) after cleaning.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_split_distribution(df):
    if df is None or 'split' not in df.columns:
        print("Cannot plot distribution: DataFrame is None or 'split' column missing.")
        return
        
    plt.figure(figsize=(8, 5))
    split_counts = df['split'].value_counts()
    sns.barplot(x=split_counts.index, y=split_counts.values, palette='viridis')
    plt.title('Data Split Distribution (After Cleaning)')
    plt.xlabel('Split')
    plt.ylabel('Number of Samples')
    plt.xticks(rotation=0)
    # Add counts on top of bars
    for index, value in enumerate(split_counts):
        plt.text(index, value + 0.1, str(value), ha='center', va='bottom')
    plt.show()

# Plot the distribution using the cleaned metadata_df
plot_split_distribution(metadata_df)

### 2.2 Data Loading Status Note

Based on the cleaning and path loading steps above:
*   If **training data paths** were found, training can proceed.
*   If **validation data paths** were found (`val_image_paths` is not empty), validation during training and saving the best model based on validation loss/mIoU will occur.
*   If **validation data paths** were *not* found, validation steps will be skipped, and the model will be saved after the final training epoch.
*   If **test data paths** were found (`test_image_paths` is not empty), evaluation on the test set can be performed after training.
*   If **test data paths** were *not* found, evaluation on the test set will be skipped.
*   Visualization will prioritize the test set, then validation, then training set, depending on availability.
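
The visualization priority in the last bullet can be sketched as a small helper. The function name `pick_visualization_split` is hypothetical; the notebook may implement this selection differently.

```python
def pick_visualization_split(test_paths, val_paths, train_paths):
    """Return (name, paths) for the first non-empty split, in test > valid > train order."""
    candidates = (("test", test_paths), ("valid", val_paths), ("train", train_paths))
    for name, paths in candidates:
        if paths:
            return name, paths
    return None, []


# Example: no test or validation data, so the training set is used.
name, paths = pick_visualization_split([], [], ["a_sat.jpg", "b_sat.jpg"])
print(name, len(paths))  # train 2
```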

### 2.3 Visualize Sample Data

Display a few image-mask pairs from the training set (if available).

In [None]:
import random
import matplotlib.pyplot as plt
from PIL import Image

def show_samples_from_paths(image_paths, mask_paths, num_samples=3):
    if not image_paths or len(image_paths) < num_samples:
        print(f"Cannot show samples: Need at least {num_samples} valid image paths, found {len(image_paths)}.")
        return
    
    indices = random.sample(range(len(image_paths)), num_samples)
    
    plt.figure(figsize=(15, 5 * num_samples))
    for i, idx in enumerate(indices):
        img_path = image_paths[idx]
        mask_path = mask_paths[idx]
        
        try:
            img = Image.open(img_path).convert("RGB")
            mask = Image.open(mask_path).convert("RGB") # Keep mask as RGB for visualization
        except Exception as e:
             print(f"Error loading image/mask at index {idx} ({img_path}): {e}")
             # Optionally add placeholder plots or skip
             # set_title() returns a Text object, so configure each placeholder axis in separate calls
             for col, label in enumerate(("image", "mask"), start=1):
                 ax = plt.subplot(num_samples, 2, 2*i + col)
                 ax.set_title(f"Error loading {label} {idx}")
                 ax.axis('off')
             continue
                 
        plt.subplot(num_samples, 2, 2*i + 1)
        plt.imshow(img)
        plt.title(f"Sample {idx} - Image")
        plt.axis('off')
        
        plt.subplot(num_samples, 2, 2*i + 2)
        plt.imshow(mask)
        plt.title(f"Sample {idx} - Mask (RGB)")
        plt.axis('off')
        
    plt.tight_layout()
    plt.show()

# Visualize samples from the training dataset paths
print("Displaying samples from the training set paths:")
show_samples_from_paths(train_image_paths, train_mask_paths)

## 3. Preprocessing & Data Loading

Define augmentations using Albumentations, preprocessing steps suitable for the ResNet backbone, and create a custom PyTorch Dataset and DataLoaders.

**Note:** `IMAGE_SIZE` impacts memory usage and training time. Adjust if needed.
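
As a rough guide to how `IMAGE_SIZE` and `BATCH_SIZE` scale memory, the size of one float32 input batch can be computed directly. This is a back-of-the-envelope sketch with a hypothetical helper name. It covers the input tensor only; the model's intermediate activations are typically much larger, though they also grow with the square of the image side.

```python
def input_batch_megabytes(batch_size, image_size, channels=3, bytes_per_value=4):
    """Size in MiB of one float32 batch shaped (batch_size, channels, image_size, image_size)."""
    n_values = batch_size * channels * image_size * image_size
    return n_values * bytes_per_value / (1024 ** 2)


for size in (512, 384, 256):
    print(size, input_batch_megabytes(8, size))
# 512 24.0
# 384 13.5
# 256 6.0
```

Halving the image side from 512 to 256 cuts the per-batch input size by a factor of four, the same saving as reducing the batch size from 8 to 2.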

In [None]:
IMAGE_SIZE = 512 # Input size for the model. Reduce if memory errors occur.
ENCODER = 'resnet50' # Changed to ResNet50 for DeepLabV3+
ENCODER_WEIGHTS = 'imagenet'

# Get the preprocessing function specific to the ResNet50 encoder
try:
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
except KeyError:
    print(f"Warning: Preprocessing function not found for {ENCODER} with {ENCODER_WEIGHTS}. Using standard ImageNet normalization.")
    # Define a fallback or standard normalization if needed
    # Accept **kwargs because A.Lambda passes extra keyword arguments to the function
    preprocessing_fn = lambda x, **kwargs: (x / 255.0 - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

# Define Albumentations Transforms
def get_transforms(phase, image_size, preprocessing_fn):
    common_transforms = [A.Resize(image_size, image_size)]
    if phase == 'train':
        # Augmentations: Add more as needed (e.g., Rotate, ShiftScaleRotate)
        aug_transforms = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Add more augmentations here if desired
            # A.RandomBrightnessContrast(p=0.2),
            # A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
        ]
    else:
        aug_transforms = [] # No augmentation for validation/test

    # Preprocessing (normalization) and tensor conversion
    # Note: Apply preprocessing_fn *before* ToTensorV2 if it expects a numpy array
    # If preprocessing_fn expects a tensor, apply it after ToTensorV2
    final_transforms = [
        A.Lambda(image=preprocessing_fn), # Apply model-specific preprocessing
        ToTensorV2(), # Convert image and mask to PyTorch tensors (C, H, W)
    ]

    return A.Compose(common_transforms + aug_transforms + final_transforms)

# Create transforms for each phase
train_transforms = get_transforms('train', IMAGE_SIZE, preprocessing_fn)
val_transforms = get_transforms('val', IMAGE_SIZE, preprocessing_fn)
test_transforms = get_transforms('test', IMAGE_SIZE, preprocessing_fn)

# Custom Dataset Class (Remains the same)
class LandCoverDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transforms=None, rgb_to_id_func=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms
        self.rgb_to_id_func = rgb_to_id_func

        if len(self.image_paths) != len(self.mask_paths):
             raise ValueError("Number of images and masks must be the same.")
        if not self.image_paths:
             print("Warning: Initializing dataset with zero image paths.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if idx >= len(self.image_paths):
            raise IndexError("Index out of range")
            
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            image = np.array(Image.open(img_path).convert("RGB"))
            mask_rgb = Image.open(mask_path) # Load mask
        except Exception as e:
            # Fail loudly: returning (None, None) would break the DataLoader's default collate
            # with a less helpful error. Filter out unreadable files beforehand if this triggers.
            raise RuntimeError(
                f"Error loading image/mask at index {idx} ({img_path} / {mask_path}): {e}"
            ) from e

        # Convert RGB mask to class ID mask
        if self.rgb_to_id_func:
            mask = self.rgb_to_id_func(mask_rgb)
        else:
            # Fallback or error if function not provided
            mask = np.array(mask_rgb) # Assuming mask is already single channel if no func

        # Apply transformations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Ensure mask is LongTensor for CrossEntropyLoss/FocalLoss
        # Mask shape should be (H, W), not (1, H, W)
        mask = mask.squeeze().long() 

        return image, mask

# Create Datasets
# Check if paths are loaded before creating datasets
train_dataset = None
val_dataset = None
test_dataset = None

if train_image_paths:
    train_dataset = LandCoverDataset(train_image_paths, train_mask_paths, transforms=train_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created training dataset with {len(train_dataset)} samples.")
else:
    print("Skipping training dataset creation: No valid training paths found.")

if val_image_paths:
    val_dataset = LandCoverDataset(val_image_paths, val_mask_paths, transforms=val_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created validation dataset with {len(val_dataset)} samples.")
else:
    print("Skipping validation dataset creation: No valid validation paths found.")

if test_image_paths:
    test_dataset = LandCoverDataset(test_image_paths, test_mask_paths, transforms=test_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created test dataset with {len(test_dataset)} samples.")
else:
    print("Skipping test dataset creation: No valid test paths found.")

# Create DataLoaders
BATCH_SIZE = 8 # Adjust based on GPU/TPU memory (e.g., 4, 8, 16)
NUM_WORKERS = 2 # Adjust based on system capabilities

train_loader = None
val_loader = None
test_loader = None

if train_dataset:
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created train DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping train DataLoader creation.")

if val_dataset:
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created validation DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping validation DataLoader creation.")

if test_dataset:
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created test DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping test DataLoader creation.")

# --- Verification Step: Check a batch ---
if train_loader:
    print("\nVerifying a batch from train_loader...")
    try:
        images, masks = next(iter(train_loader))
        print(f"Image batch shape: {images.shape}, dtype: {images.dtype}")
        print(f"Mask batch shape: {masks.shape}, dtype: {masks.dtype}")
        print(f"Mask unique values: {torch.unique(masks)}")
        
        # Visualize one sample from the batch
        img_sample = images[0].permute(1, 2, 0).cpu().numpy() # C, H, W -> H, W, C
        mask_sample = masks[0].cpu().numpy()
        
        # Need to denormalize image for visualization if normalized
        # This depends on the exact preprocessing_fn. Assuming standard ImageNet normalization:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_sample = std * img_sample + mean
        img_sample = np.clip(img_sample, 0, 1)
        
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(img_sample)
        plt.title("Sample Image (from DataLoader)")
        plt.axis('off')
        
        plt.subplot(1, 2, 2)
        plt.imshow(mask_sample, cmap='viridis') # Use a colormap suitable for class IDs
        plt.title("Sample Mask (from DataLoader)")
        plt.axis('off')
        plt.show()
    except Exception as e:
        print(f"Error verifying a batch from train_loader: {e}")

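The denormalization used in the verification step is the inverse of standard ImageNet normalization. The following minimal sketch checks that round trip on random data. It assumes the preprocessing is exactly `(x / 255 - mean) / std` with the ImageNet statistics shown; the helper names `normalize` and `denormalize` are illustrative.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize(img_uint8):
    """Scale an (H, W, 3) uint8 image to [0, 1], then standardize per channel."""
    return (img_uint8 / 255.0 - MEAN) / STD


def denormalize(img_norm):
    """Invert normalize() and clip to the displayable [0, 1] range."""
    return np.clip(img_norm * STD + MEAN, 0.0, 1.0)


rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)

restored = denormalize(normalize(img))
print(np.allclose(restored, img / 255.0))  # True
```

If the displayed sample looks washed out or heavily tinted, the preprocessing function probably uses different statistics than the ones assumed here.
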
# Visualize samples from the training dataset paths
print("Displaying samples from the training set paths:")
show_samples_from_paths(train_image_paths, train_mask_paths)

## 3. Preprocessing & Data Loading

Define augmentations using Albumentations, preprocessing steps suitable for the ResNet backbone, and create a custom PyTorch Dataset and DataLoaders.

**Note:** `IMAGE_SIZE` impacts memory usage and training time. Adjust if needed.

In [None]:
IMAGE_SIZE = 512 # Input size for the model. Reduce if memory errors occur.
ENCODER = 'resnet50' # Changed to ResNet50 for DeepLabV3+
ENCODER_WEIGHTS = 'imagenet'

# Get the preprocessing function specific to the ResNet50 encoder
try:
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
except KeyError:
    print(f"Warning: Preprocessing function not found for {ENCODER} with {ENCODER_WEIGHTS}. Using standard ImageNet normalization.")
    # Define a fallback or standard normalization if needed
    preprocessing_fn = lambda x: (x / 255.0 - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

# Define Albumentations Transforms
def get_transforms(phase, image_size, preprocessing_fn):
    common_transforms = [A.Resize(image_size, image_size)]
    if phase == 'train':
        # Augmentations: Add more as needed (e.g., Rotate, ShiftScaleRotate)
        aug_transforms = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Add more augmentations here if desired
            # A.RandomBrightnessContrast(p=0.2),
            # A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
        ]
    else:
        aug_transforms = [] # No augmentation for validation/test

    # Preprocessing (normalization) and tensor conversion
    # Note: Apply preprocessing_fn *before* ToTensorV2 if it expects a numpy array
    # If preprocessing_fn expects a tensor, apply it after ToTensorV2
    final_transforms = [
        A.Lambda(image=preprocessing_fn), # Apply model-specific preprocessing
        ToTensorV2(), # Convert image and mask to PyTorch tensors (C, H, W)
    ]

    return A.Compose(common_transforms + aug_transforms + final_transforms)

# Create transforms for each phase
train_transforms = get_transforms('train', IMAGE_SIZE, preprocessing_fn)
val_transforms = get_transforms('val', IMAGE_SIZE, preprocessing_fn)
test_transforms = get_transforms('test', IMAGE_SIZE, preprocessing_fn)

# Custom Dataset Class (Remains the same)
class LandCoverDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transforms=None, rgb_to_id_func=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms
        self.rgb_to_id_func = rgb_to_id_func

        if len(self.image_paths) != len(self.mask_paths):
             raise ValueError("Number of images and masks must be the same.")
        if not self.image_paths:
             print("Warning: Initializing dataset with zero image paths.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if idx >= len(self.image_paths):
            raise IndexError("Index out of range")
            
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            image = np.array(Image.open(img_path).convert("RGB"))
            mask_rgb = Image.open(mask_path) # Load mask
        except Exception as e:
            print(f"Error loading image/mask at index {idx} ({img_path} / {mask_path}): {e}")
            # Return dummy data or raise error, depending on desired behavior
            # For simplicity, returning None here, handle appropriately in DataLoader/training loop
            # A better approach might be to filter out bad data beforehand
            return None, None 

        # Convert RGB mask to class ID mask
        if self.rgb_to_id_func:
            mask = self.rgb_to_id_func(mask_rgb)
        else:
            # Fallback or error if function not provided
            mask = np.array(mask_rgb) # Assuming mask is already single channel if no func

        # Apply transformations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Ensure mask is LongTensor for CrossEntropyLoss/FocalLoss
        # Mask shape should be (H, W), not (1, H, W)
        mask = mask.squeeze().long() 

        return image, mask

# Create Datasets
# Check if paths are loaded before creating datasets
train_dataset = None
val_dataset = None
test_dataset = None

if train_image_paths:
    train_dataset = LandCoverDataset(train_image_paths, train_mask_paths, transforms=train_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created training dataset with {len(train_dataset)} samples.")
else:
    print("Skipping training dataset creation: No valid training paths found.")

if val_image_paths:
    val_dataset = LandCoverDataset(val_image_paths, val_mask_paths, transforms=val_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created validation dataset with {len(val_dataset)} samples.")
else:
    print("Skipping validation dataset creation: No valid validation paths found.")

if test_image_paths:
    test_dataset = LandCoverDataset(test_image_paths, test_mask_paths, transforms=test_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created test dataset with {len(test_dataset)} samples.")
else:
    print("Skipping test dataset creation: No valid test paths found.")

# Create DataLoaders
BATCH_SIZE = 8 # Adjust based on GPU/TPU memory (e.g., 4, 8, 16)
NUM_WORKERS = 2 # Adjust based on system capabilities

train_loader = None
val_loader = None
test_loader = None

if train_dataset:
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created train DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping train DataLoader creation.")

if val_dataset:
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created validation DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping validation DataLoader creation.")

if test_dataset:
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created test DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping test DataLoader creation.")

# --- Verification Step: Check a batch ---
if train_loader:
    print("\nVerifying a batch from train_loader...")
    try:
        images, masks = next(iter(train_loader))
        print(f"Image batch shape: {images.shape}, dtype: {images.dtype}")
        print(f"Mask batch shape: {masks.shape}, dtype: {masks.dtype}")
        print(f"Mask unique values: {torch.unique(masks)}")
        
        # Visualize one sample from the batch
        img_sample = images[0].permute(1, 2, 0).cpu().numpy() # C, H, W -> H, W, C
        mask_sample = masks[0].cpu().numpy()
        
        # Need to denormalize image for visualization if normalized
        # This depends on the exact preprocessing_fn. Assuming standard ImageNet normalization:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_sample = std * img_sample + mean
        img_sample = np.clip(img_sample, 0, 1)
        
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(img_sample)
        plt.title("Sample Image (from DataLoader)")
        plt.axis('off')
        
        plt.subplot(1, 2, 2)
        plt.imshow(mask_sample, cmap='viridis') # Use a colormap suitable for class IDs
        plt.title("Sample Mask (from DataLoader)")
        plt.axis('off')
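        plt.show()
    except Exception as e:
        # Report the failure instead of aborting the cell; the loaders stay defined.
        print(f"Error while verifying a batch from train_loader: {e}")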
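
The denormalization in the batch check assumes the encoder preprocessing is the usual ImageNet recipe: scale pixels to [0, 1], then standardize per channel. Below is a minimal round-trip sketch of that assumption using NumPy only. The helper names `normalize` and `denormalize` are illustrative and are not part of this notebook or of `segmentation-models-pytorch`. If the actual `preprocessing_fn` differs, the inverse has to change with it.

```python
import numpy as np

# Assumed ImageNet statistics (same values as in the verification step).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize(image_uint8):
    """Assumed forward transform: scale (H, W, C) uint8 to [0, 1], then standardize."""
    return (image_uint8.astype(np.float32) / 255.0 - IMAGENET_MEAN) / IMAGENET_STD


def denormalize(image_chw):
    """Invert the assumed transform for a (C, H, W) array; returns (H, W, C) in [0, 1]."""
    image_hwc = np.transpose(image_chw, (1, 2, 0))
    return np.clip(image_hwc * IMAGENET_STD + IMAGENET_MEAN, 0.0, 1.0)


# Round trip on a small random image: normalize, move channels first, invert.
rng = np.random.default_rng(0)
original = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)
normalized_chw = np.transpose(normalize(original), (2, 0, 1))
restored = denormalize(normalized_chw)
max_error = np.abs(restored - original / 255.0).max()
print(f"Max round-trip error: {max_error:.2e}")
```

If the restored image looks washed out or oversaturated in `plt.imshow`, the statistics used here probably do not match the ones applied during preprocessing.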

# Visualize samples from the training dataset paths
print("Displaying samples from the training set paths:")
show_samples_from_paths(train_image_paths, train_mask_paths)

## 3. Preprocessing & Data Loading

Define augmentations using Albumentations, preprocessing steps suitable for the ResNet backbone, and create a custom PyTorch Dataset and DataLoaders.
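
One detail worth keeping in mind when resizing is that class-ID masks should not be smoothed. Albumentations resizes masks with nearest-neighbour interpolation by default, which is worth confirming for your installed version. The sketch below uses plain NumPy and made-up label values to show why: averaging neighbouring pixels can produce IDs that belong to a different class or to no class at all. Both helper functions are hypothetical illustrations.

```python
import numpy as np

# A tiny class-ID mask; the values are arbitrary illustrative labels.
mask = np.array([
    [0, 6, 1, 1],
    [0, 6, 1, 1],
    [2, 2, 4, 5],
    [2, 2, 5, 4],
])


def downsample_nearest(mask, factor):
    """Keep every `factor`-th pixel, so outputs are always existing labels."""
    return mask[::factor, ::factor]


def downsample_mean(mask, factor):
    """Average each factor x factor block, as a smoothing interpolation would."""
    h, w = mask.shape
    blocks = mask.reshape(h // factor, factor, w // factor, factor)
    return blocks.mean(axis=(1, 3))


nearest = downsample_nearest(mask, 2)
averaged = downsample_mean(mask, 2)
print(nearest)   # only labels that occur in the original mask
print(averaged)  # contains 3.0 and 4.5, which are not labels of those blocks
```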

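**Note:** `IMAGE_SIZE` drives memory usage and training time, since both grow roughly with the number of input pixels. If you run out of memory, reduce it (e.g., to 256 or 384), though this might affect accuracy.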
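
As a rough guide, the number of pixels per image scales with the square of the side length. The sketch below compares a few candidate sizes against 512. `relative_pixel_count` is a hypothetical helper, and real memory use also depends on batch size, the model, and framework overhead, so treat the ratios as an estimate only.

```python
def relative_pixel_count(image_size, reference_size=512):
    """Pixels in a square image of side `image_size`, relative to `reference_size`."""
    return (image_size / reference_size) ** 2


for size in (256, 384, 512):
    ratio = relative_pixel_count(size)
    print(f"IMAGE_SIZE={size}: {ratio:.4f}x the pixels of a 512x512 input")
```

Halving the side length to 256 cuts the pixel count to a quarter, which is usually a larger saving than halving `BATCH_SIZE`.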

In [None]:
IMAGE_SIZE = 512 # Input size for the model. Reduce if memory errors occur.
ENCODER = 'resnet50' # ResNet50 backbone for DeepLabV3+
ENCODER_WEIGHTS = 'imagenet'

# Get the preprocessing function specific to the ResNet50 encoder
try:
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
except (KeyError, ValueError):
    # Depending on the smp version, an unknown encoder or weight name may raise either exception
    print(f"Warning: Preprocessing function not found for {ENCODER} with {ENCODER_WEIGHTS}. Using standard ImageNet normalization.")
    # Fallback: standard ImageNet normalization on an (H, W, C) array with values in 0-255.
    # **kwargs absorbs the extra parameters that A.Lambda passes to the function.
    def preprocessing_fn(x, **kwargs):
        return (x / 255.0 - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

# Define Albumentations Transforms
def get_transforms(phase, image_size, preprocessing_fn):
    common_transforms = [A.Resize(image_size, image_size)]
    if phase == 'train':
        # Augmentations: Add more as needed (e.g., Rotate, ShiftScaleRotate)
        aug_transforms = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Add more augmentations here if desired
            # A.RandomBrightnessContrast(p=0.2),
            # A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
        ]
    else:
        aug_transforms = [] # No augmentation for validation/test

    # Preprocessing (normalization) and tensor conversion
    # Note: Apply preprocessing_fn *before* ToTensorV2 if it expects a numpy array
    # If preprocessing_fn expects a tensor, apply it after ToTensorV2
    final_transforms = [
        A.Lambda(image=preprocessing_fn), # Apply model-specific preprocessing
        ToTensorV2(), # Convert image and mask to PyTorch tensors (C, H, W)
    ]

    return A.Compose(common_transforms + aug_transforms + final_transforms)

# Create transforms for each phase
train_transforms = get_transforms('train', IMAGE_SIZE, preprocessing_fn)
val_transforms = get_transforms('val', IMAGE_SIZE, preprocessing_fn)
test_transforms = get_transforms('test', IMAGE_SIZE, preprocessing_fn)

# Custom Dataset Class
class LandCoverDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transforms=None, rgb_to_id_func=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms
        self.rgb_to_id_func = rgb_to_id_func

        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError("Number of images and masks must be the same.")
        if not self.image_paths:
            print("Warning: Initializing dataset with zero image paths.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if idx >= len(self.image_paths):
            raise IndexError("Index out of range")
            
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            image = np.array(Image.open(img_path).convert("RGB"))
            mask_rgb = Image.open(mask_path) # Load mask
        except Exception as e:
            print(f"Error loading image/mask at index {idx} ({img_path} / {mask_path}): {e}")
            # Returning (None, None) keeps this example simple, but the default
            # DataLoader collate function cannot batch None values. Either filter
            # unreadable files beforehand or use a collate_fn that drops these samples.
            return None, None

        # Convert RGB mask to class ID mask
        if self.rgb_to_id_func:
            mask = self.rgb_to_id_func(mask_rgb)
        else:
            # Fallback or error if function not provided
            mask = np.array(mask_rgb) # Assuming mask is already single channel if no func

        # Apply transformations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            # The NumPy-based preprocessing typically yields float64 values;
            # cast to float32 so the image matches the model weights
            image = augmented['image'].float()
            mask = augmented['mask']

        # Ensure mask is LongTensor for CrossEntropyLoss/FocalLoss
        # Mask shape should be (H, W), not (1, H, W)
        mask = mask.squeeze().long()

        return image, mask

# Create Datasets
# Check if paths are loaded before creating datasets
train_dataset = None
val_dataset = None
test_dataset = None

if train_image_paths:
    train_dataset = LandCoverDataset(train_image_paths, train_mask_paths, transforms=train_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created training dataset with {len(train_dataset)} samples.")
else:
    print("Skipping training dataset creation: No valid training paths found.")

if val_image_paths:
    val_dataset = LandCoverDataset(val_image_paths, val_mask_paths, transforms=val_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created validation dataset with {len(val_dataset)} samples.")
else:
    print("Skipping validation dataset creation: No valid validation paths found.")

if test_image_paths:
    test_dataset = LandCoverDataset(test_image_paths, test_mask_paths, transforms=test_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created test dataset with {len(test_dataset)} samples.")
else:
    print("Skipping test dataset creation: No valid test paths found.")

# Create DataLoaders
BATCH_SIZE = 8 # Adjust based on GPU/TPU memory (e.g., 4, 8, 16)
NUM_WORKERS = 2 # Adjust based on system capabilities

train_loader = None
val_loader = None
test_loader = None

if train_dataset:
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created train DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping train DataLoader creation.")

if val_dataset:
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created validation DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping validation DataLoader creation.")

if test_dataset:
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created test DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping test DataLoader creation.")

# --- Verification Step: Check a batch ---
if train_loader:
    print("\nVerifying a batch from train_loader...")
    try:
        images, masks = next(iter(train_loader))
        print(f"Image batch shape: {images.shape}, dtype: {images.dtype}")
        print(f"Mask batch shape: {masks.shape}, dtype: {masks.dtype}")
        print(f"Mask unique values: {torch.unique(masks)}")
        
        # Visualize one sample from the batch
        img_sample = images[0].permute(1, 2, 0).cpu().numpy() # C, H, W -> H, W, C
        mask_sample = masks[0].cpu().numpy()
        
        # Need to denormalize image for visualization if normalized
        # This depends on the exact preprocessing_fn. Assuming standard ImageNet normalization:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_sample = std * img_sample + mean
        img_sample = np.clip(img_sample, 0, 1)
        
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(img_sample)
        plt.title("Sample Image (from DataLoader)")
        plt.axis('off')
        
        plt.subplot(1, 2, 2)
        plt.imshow(mask_sample, cmap='viridis') # Use a colormap suitable for class IDs
        plt.title("Sample Mask (from DataLoader)")
        plt.axis('off')
        plt.show()
    except Exception as e:
        print(f"Error while verifying a batch: {e}")

# Visualize samples from the training dataset paths
print("Displaying samples from the training set paths:")
show_samples_from_paths(train_image_paths, train_mask_paths)
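
`LandCoverDataset.__getitem__` returns `(None, None)` when a file cannot be read, and the default collate function cannot batch such entries. One possible way to handle this, sketched here under the assumption that silently skipping unreadable samples is acceptable, is a custom `collate_fn`. The names `drop_failed_samples` and `safe_collate` are hypothetical, and the `collate` parameter exists only so the filtering logic can be tried without PyTorch.

```python
def drop_failed_samples(batch):
    """Remove (None, None) pairs returned for files that could not be loaded."""
    return [(image, mask) for image, mask in batch
            if image is not None and mask is not None]


def safe_collate(batch, collate=None):
    """Collate only the valid samples; return None if every sample failed."""
    kept = drop_failed_samples(batch)
    if not kept:
        return None  # the training loop must skip such a batch
    if collate is None:
        # Imported lazily so the filtering logic can be tried without PyTorch.
        from torch.utils.data.dataloader import default_collate
        collate = default_collate
    return collate(kept)


# Quick check with a stand-in collate function (plain list) instead of tensors.
demo_batch = [("img0", "mask0"), (None, None), ("img2", "mask2")]
print(safe_collate(demo_batch, collate=list))
print(safe_collate([(None, None)], collate=list))
```

To try it, pass `collate_fn=safe_collate` when creating each `DataLoader`. Batches may then be smaller than `BATCH_SIZE`, and any loop over the loader should skip a batch that is `None`.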

## 3. Preprocessing & Data Loading

Define augmentations using Albumentations, preprocessing steps suitable for the ResNet backbone, and create a custom PyTorch Dataset and DataLoaders.

**Note:** `IMAGE_SIZE` impacts memory usage and training time. Adjust if needed.

In [None]:
IMAGE_SIZE = 512 # Input size for the model. Reduce if memory errors occur.
ENCODER = 'resnet50' # Changed to ResNet50 for DeepLabV3+
ENCODER_WEIGHTS = 'imagenet'

# Get the preprocessing function specific to the ResNet50 encoder
try:
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
except KeyError:
    print(f"Warning: Preprocessing function not found for {ENCODER} with {ENCODER_WEIGHTS}. Using standard ImageNet normalization.")
    # Define a fallback or standard normalization if needed
    preprocessing_fn = lambda x: (x / 255.0 - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

# Define Albumentations Transforms
def get_transforms(phase, image_size, preprocessing_fn):
    common_transforms = [A.Resize(image_size, image_size)]
    if phase == 'train':
        # Augmentations: Add more as needed (e.g., Rotate, ShiftScaleRotate)
        aug_transforms = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Add more augmentations here if desired
            # A.RandomBrightnessContrast(p=0.2),
            # A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
        ]
    else:
        aug_transforms = [] # No augmentation for validation/test

    # Preprocessing (normalization) and tensor conversion
    # Note: Apply preprocessing_fn *before* ToTensorV2 if it expects a numpy array
    # If preprocessing_fn expects a tensor, apply it after ToTensorV2
    final_transforms = [
        A.Lambda(image=preprocessing_fn), # Apply model-specific preprocessing
        ToTensorV2(), # Convert image and mask to PyTorch tensors (C, H, W)
    ]

    return A.Compose(common_transforms + aug_transforms + final_transforms)

# Create transforms for each phase
train_transforms = get_transforms('train', IMAGE_SIZE, preprocessing_fn)
val_transforms = get_transforms('val', IMAGE_SIZE, preprocessing_fn)
test_transforms = get_transforms('test', IMAGE_SIZE, preprocessing_fn)

# Custom Dataset Class (Remains the same)
class LandCoverDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transforms=None, rgb_to_id_func=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms
        self.rgb_to_id_func = rgb_to_id_func

        if len(self.image_paths) != len(self.mask_paths):
             raise ValueError("Number of images and masks must be the same.")
        if not self.image_paths:
             print("Warning: Initializing dataset with zero image paths.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if idx >= len(self.image_paths):
            raise IndexError("Index out of range")
            
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            image = np.array(Image.open(img_path).convert("RGB"))
            mask_rgb = Image.open(mask_path) # Load mask
        except Exception as e:
            print(f"Error loading image/mask at index {idx} ({img_path} / {mask_path}): {e}")
            # Return dummy data or raise error, depending on desired behavior
            # For simplicity, returning None here, handle appropriately in DataLoader/training loop
            # A better approach might be to filter out bad data beforehand
            return None, None 

        # Convert RGB mask to class ID mask
        if self.rgb_to_id_func:
            mask = self.rgb_to_id_func(mask_rgb)
        else:
            # Fallback or error if function not provided
            mask = np.array(mask_rgb) # Assuming mask is already single channel if no func

        # Apply transformations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Ensure mask is LongTensor for CrossEntropyLoss/FocalLoss
        # Mask shape should be (H, W), not (1, H, W)
        mask = mask.squeeze().long() 

        return image, mask

# Create Datasets
# Check if paths are loaded before creating datasets
train_dataset = None
val_dataset = None
test_dataset = None

if train_image_paths:
    train_dataset = LandCoverDataset(train_image_paths, train_mask_paths, transforms=train_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created training dataset with {len(train_dataset)} samples.")
else:
    print("Skipping training dataset creation: No valid training paths found.")

if val_image_paths:
    val_dataset = LandCoverDataset(val_image_paths, val_mask_paths, transforms=val_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created validation dataset with {len(val_dataset)} samples.")
else:
    print("Skipping validation dataset creation: No valid validation paths found.")

if test_image_paths:
    test_dataset = LandCoverDataset(test_image_paths, test_mask_paths, transforms=test_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created test dataset with {len(test_dataset)} samples.")
else:
    print("Skipping test dataset creation: No valid test paths found.")

# Create DataLoaders
BATCH_SIZE = 8 # Adjust based on GPU/TPU memory (e.g., 4, 8, 16)
NUM_WORKERS = 2 # Adjust based on system capabilities

train_loader = None
val_loader = None
test_loader = None

if train_dataset:
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created train DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping train DataLoader creation.")

if val_dataset:
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created validation DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping validation DataLoader creation.")

if test_dataset:
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created test DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping test DataLoader creation.")

# --- Verification Step: Check a batch ---
if train_loader:
    print("\nVerifying a batch from train_loader...")
    try:
        images, masks = next(iter(train_loader))
        print(f"Image batch shape: {images.shape}, dtype: {images.dtype}")
        print(f"Mask batch shape: {masks.shape}, dtype: {masks.dtype}")
        print(f"Mask unique values: {torch.unique(masks)}")
        
        # Visualize one sample from the batch
        img_sample = images[0].permute(1, 2, 0).cpu().numpy() # C, H, W -> H, W, C
        mask_sample = masks[0].cpu().numpy()
        
        # Need to denormalize image for visualization if normalized
        # This depends on the exact preprocessing_fn. Assuming standard ImageNet normalization:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_sample = std * img_sample + mean
        img_sample = np.clip(img_sample, 0, 1)
        
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(img_sample)
        plt.title(f"Sample Image (from DataLoader)")
        plt.axis('off')
        
        plt.subplot(1, 2, 2)
        plt.imshow(mask_sample, cmap='viridis') # Use a colormap suitable for class IDs
        plt.title(f"Sample Mask (from DataLoader)")
        plt.axis('off')
        plt.show()

# Visualize samples from the training dataset paths
print("Displaying samples from the training set paths:")
show_samples_from_paths(train_image_paths, train_mask_paths)

## 3. Preprocessing & Data Loading

Define augmentations using Albumentations, preprocessing steps suitable for the ResNet backbone, and create a custom PyTorch Dataset and DataLoaders.
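
As a rough illustration of the normalization step, the sketch below applies standard ImageNet normalization to a single RGB pixel and then inverts it, which is what the batch visualization in this notebook does before plotting. It assumes the encoder's preprocessing is the standard ImageNet mean/std normalization (the same constants as the fallback in the code cell). The helper names `normalize_pixel` and `denormalize_pixel` are hypothetical and not part of any library.

```python
# Standard ImageNet statistics in RGB order (assumed to match the encoder's preprocessing).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb_uint8):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb_uint8, IMAGENET_MEAN, IMAGENET_STD)
    )

def denormalize_pixel(rgb_norm):
    """Invert normalize_pixel, giving values in [0, 1] that can be displayed."""
    return tuple(
        value * std + mean
        for value, mean, std in zip(rgb_norm, IMAGENET_MEAN, IMAGENET_STD)
    )

pixel = (128, 64, 255)
normalized = normalize_pixel(pixel)
restored = denormalize_pixel(normalized)
print("normalized:", normalized)
print("restored:  ", restored)
```

The restored values equal the original pixel divided by 255 (up to floating-point error), which is why the denormalized image is clipped to `[0, 1]` rather than rescaled to `[0, 255]` before display.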

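**Note:** The number of pixels per image grows with the square of `IMAGE_SIZE`, so memory usage and training time rise quickly as it increases. Reduce it (e.g., to 256 or 384) if you hit memory or time limits, keeping in mind that this might affect accuracy.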
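
To get a feel for how `IMAGE_SIZE` and `BATCH_SIZE` interact, the sketch below estimates the size of a single float32 input batch. The helper `input_batch_bytes` is hypothetical. Actual GPU memory use is dominated by activations and gradients, which this does not model. It only shows that halving the side length cuts the pixel count to a quarter.

```python
def input_batch_bytes(batch_size, image_size, channels=3, bytes_per_value=4):
    """Bytes needed for one float32 input batch of shape (B, C, H, W)."""
    return batch_size * channels * image_size * image_size * bytes_per_value

# Compare the input batch footprint for a few candidate image sizes.
for size in (512, 384, 256):
    mib = input_batch_bytes(batch_size=8, image_size=size) / 2**20
    print(f"IMAGE_SIZE={size}: {mib:.1f} MiB per input batch")
```

Intermediate feature maps scale with the input resolution in a similar way, so the relative savings from a smaller `IMAGE_SIZE` are a reasonable first approximation for the whole model.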

In [None]:
IMAGE_SIZE = 512 # Input size for the model. Reduce if memory errors occur.
ENCODER = 'resnet50' # Changed to ResNet50 for DeepLabV3+
ENCODER_WEIGHTS = 'imagenet'

# Get the preprocessing function specific to the ResNet50 encoder
try:
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
except (KeyError, ValueError):
    print(f"Warning: Preprocessing function not found for {ENCODER} with {ENCODER_WEIGHTS}. Using standard ImageNet normalization.")
    # Fallback: standard ImageNet normalization on an H x W x 3 array with values in 0-255.
    # **kwargs absorbs any extra parameters that A.Lambda passes to the callable.
    def preprocessing_fn(x, **kwargs):
        return (x / 255.0 - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

# Define Albumentations Transforms
def get_transforms(phase, image_size, preprocessing_fn):
    common_transforms = [A.Resize(image_size, image_size)]
    if phase == 'train':
        # Augmentations: Add more as needed (e.g., Rotate, ShiftScaleRotate)
        aug_transforms = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Add more augmentations here if desired
            # A.RandomBrightnessContrast(p=0.2),
            # A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
        ]
    else:
        aug_transforms = [] # No augmentation for validation/test

    # Preprocessing (normalization) and tensor conversion
    # Note: Apply preprocessing_fn *before* ToTensorV2 if it expects a numpy array
    # If preprocessing_fn expects a tensor, apply it after ToTensorV2
    final_transforms = [
        A.Lambda(image=preprocessing_fn), # Apply model-specific preprocessing
        ToTensorV2(), # Convert image and mask to PyTorch tensors (C, H, W)
    ]

    return A.Compose(common_transforms + aug_transforms + final_transforms)

# Create transforms for each phase
train_transforms = get_transforms('train', IMAGE_SIZE, preprocessing_fn)
val_transforms = get_transforms('val', IMAGE_SIZE, preprocessing_fn)
test_transforms = get_transforms('test', IMAGE_SIZE, preprocessing_fn)

# Custom Dataset class: loads an image/mask pair, maps the RGB mask to class IDs, and applies the transforms
class LandCoverDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transforms=None, rgb_to_id_func=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms
        self.rgb_to_id_func = rgb_to_id_func

        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError("Number of images and masks must be the same.")
        if not self.image_paths:
            print("Warning: Initializing dataset with zero image paths.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if idx >= len(self.image_paths):
            raise IndexError("Index out of range")
            
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            image = np.array(Image.open(img_path).convert("RGB"))
            mask_rgb = Image.open(mask_path) # Load mask
        except Exception as e:
            # Fail loudly here: returning (None, None) would only surface later as an
            # opaque error in the DataLoader's default collate function.
            # Filtering out unreadable files before building the dataset is more robust.
            raise RuntimeError(
                f"Error loading image/mask at index {idx} ({img_path} / {mask_path}): {e}"
            ) from e

        # Convert RGB mask to class ID mask
        if self.rgb_to_id_func:
            mask = self.rgb_to_id_func(mask_rgb)
        else:
            # Fallback or error if function not provided
            mask = np.array(mask_rgb) # Assuming mask is already single channel if no func

        # Apply transformations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

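        # Ensure the mask is a LongTensor of shape (H, W), not (1, H, W), for CrossEntropyLoss/FocalLoss
        mask = mask.squeeze().long()
        # The smp preprocessing function typically yields float64 data; cast to float32 to match the model weights
        image = image.float()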

        return image, mask

# Create Datasets
# Check if paths are loaded before creating datasets
train_dataset = None
val_dataset = None
test_dataset = None

if train_image_paths:
    train_dataset = LandCoverDataset(train_image_paths, train_mask_paths, transforms=train_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created training dataset with {len(train_dataset)} samples.")
else:
    print("Skipping training dataset creation: No valid training paths found.")

if val_image_paths:
    val_dataset = LandCoverDataset(val_image_paths, val_mask_paths, transforms=val_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created validation dataset with {len(val_dataset)} samples.")
else:
    print("Skipping validation dataset creation: No valid validation paths found.")

if test_image_paths:
    test_dataset = LandCoverDataset(test_image_paths, test_mask_paths, transforms=test_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created test dataset with {len(test_dataset)} samples.")
else:
    print("Skipping test dataset creation: No valid test paths found.")

# Create DataLoaders
BATCH_SIZE = 8 # Adjust based on GPU/TPU memory (e.g., 4, 8, 16)
NUM_WORKERS = 2 # Adjust based on system capabilities

train_loader = None
val_loader = None
test_loader = None

if train_dataset:
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created train DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping train DataLoader creation.")

if val_dataset:
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created validation DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping validation DataLoader creation.")

if test_dataset:
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created test DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping test DataLoader creation.")

# --- Verification Step: Check a batch ---
if train_loader:
    print("\nVerifying a batch from train_loader...")
    try:
        images, masks = next(iter(train_loader))
        print(f"Image batch shape: {images.shape}, dtype: {images.dtype}")
        print(f"Mask batch shape: {masks.shape}, dtype: {masks.dtype}")
        print(f"Mask unique values: {torch.unique(masks)}")
        
        # Visualize one sample from the batch
        img_sample = images[0].permute(1, 2, 0).cpu().numpy() # C, H, W -> H, W, C
        mask_sample = masks[0].cpu().numpy()
        
        # Need to denormalize image for visualization if normalized
        # This depends on the exact preprocessing_fn. Assuming standard ImageNet normalization:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_sample = std * img_sample + mean
        img_sample = np.clip(img_sample, 0, 1)
        
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(img_sample)
        plt.title("Sample Image (from DataLoader)")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(mask_sample, cmap='viridis') # Use a colormap suitable for class IDs
        plt.title("Sample Mask (from DataLoader)")
        plt.axis('off')
        plt.show()
    except Exception as e:
        # Without an except clause the try block above is a syntax error
        print(f"Could not verify a batch from train_loader: {e}")

# Create Datasets
# Check if paths are loaded before creating datasets
train_dataset = None
val_dataset = None
test_dataset = None

if train_image_paths:
    train_dataset = LandCoverDataset(train_image_paths, train_mask_paths, transforms=train_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created training dataset with {len(train_dataset)} samples.")
else:
    print("Skipping training dataset creation: No valid training paths found.")

if val_image_paths:
    val_dataset = LandCoverDataset(val_image_paths, val_mask_paths, transforms=val_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created validation dataset with {len(val_dataset)} samples.")
else:
    print("Skipping validation dataset creation: No valid validation paths found.")

if test_image_paths:
    test_dataset = LandCoverDataset(test_image_paths, test_mask_paths, transforms=test_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created test dataset with {len(test_dataset)} samples.")
else:
    print("Skipping test dataset creation: No valid test paths found.")

# Create DataLoaders
BATCH_SIZE = 8 # Adjust based on GPU/TPU memory (e.g., 4, 8, 16)
NUM_WORKERS = 2 # Adjust based on system capabilities

train_loader = None
val_loader = None
test_loader = None

if train_dataset:
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created train DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping train DataLoader creation.")

if val_dataset:
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created validation DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping validation DataLoader creation.")

if test_dataset:
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created test DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping test DataLoader creation.")

# --- Verification Step: Check a batch ---
if train_loader:
    print("\nVerifying a batch from train_loader...")
    try:
        images, masks = next(iter(train_loader))
        print(f"Image batch shape: {images.shape}, dtype: {images.dtype}")
        print(f"Mask batch shape: {masks.shape}, dtype: {masks.dtype}")
        print(f"Mask unique values: {torch.unique(masks)}")
        
        # Visualize one sample from the batch
        img_sample = images[0].permute(1, 2, 0).cpu().numpy() # C, H, W -> H, W, C
        mask_sample = masks[0].cpu().numpy()
        
        # Need to denormalize image for visualization if normalized
        # This depends on the exact preprocessing_fn. Assuming standard ImageNet normalization:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_sample = std * img_sample + mean
        img_sample = np.clip(img_sample, 0, 1)
        
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(img_sample)
        plt.title(f"Sample Image (from DataLoader)")
        plt.axis('off')
        
        plt.subplot(1, 2, 2)
        plt.imshow(mask_sample, cmap='viridis') # Use a colormap suitable for class IDs
        plt.title(f"Sample Mask (from DataLoader)")
        plt.axis('off')
        plt.show()

# Visualize samples from the training dataset paths
print("Displaying samples from the training set paths:")
show_samples_from_paths(train_image_paths, train_mask_paths)

## 3. Preprocessing & Data Loading

Define the augmentations with Albumentations, apply the input normalization that matches the pretrained ResNet50 encoder, and wrap the image/mask paths in a custom PyTorch `Dataset` served through `DataLoader`s.
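
As a minimal sketch of what the encoder-specific normalization amounts to, assuming the standard ImageNet mean and standard deviation (the same values used as the fallback in the code cell below), the following NumPy-only example normalizes a small random image and inverts the operation for display. The helper names `normalize` and `denormalize` are illustrative and not part of any library.

```python
import numpy as np

# Standard ImageNet channel statistics (RGB order), assumed here for ResNet weights
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(image_uint8):
    """Scale an (H, W, 3) uint8 image to [0, 1], then standardize each channel."""
    scaled = image_uint8.astype(np.float32) / 255.0
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD

def denormalize(image_norm):
    """Invert normalize() so the image can be displayed with values in [0, 1]."""
    return np.clip(image_norm * IMAGENET_STD + IMAGENET_MEAN, 0.0, 1.0)

# Hypothetical 4x4 RGB image standing in for a satellite tile
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)

normalized = normalize(image)
restored = denormalize(normalized)

print("normalized shape/dtype:", normalized.shape, normalized.dtype)
print("round trip matches original:", np.allclose(restored, image / 255.0, atol=1e-5))
```

The same inverse step is what makes a normalized tensor viewable again when checking a batch visually.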

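**Note:** Memory usage and training time grow roughly with the square of `IMAGE_SIZE`. If you run out of memory, reduce it (e.g., to 256 or 384), though this might affect accuracy.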
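
To make the scaling concrete, here is a rough back-of-the-envelope sketch that computes only the size of one float32 input batch. The function name `input_batch_mib` is hypothetical, and the real memory footprint is dominated by intermediate activations and gradients, which this estimate does not include, although they scale with the spatial size in a similar way.

```python
def input_batch_mib(batch_size, image_size, channels=3, bytes_per_value=4):
    """Size in MiB of one float32 input batch shaped (B, C, H, W)."""
    n_values = batch_size * channels * image_size * image_size
    return n_values * bytes_per_value / (1024 ** 2)

# Compare the candidate image sizes at a batch size of 8
for size in (256, 384, 512):
    print(f"IMAGE_SIZE={size}: {input_batch_mib(8, size):.1f} MiB per input batch")
```

Halving the image side from 512 to 256 cuts the input batch to a quarter of its size, which is why reducing `IMAGE_SIZE` is usually more effective than a small reduction in batch size.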

In [None]:
IMAGE_SIZE = 512 # Input size for the model. Reduce if memory errors occur.
ENCODER = 'resnet50' # Encoder backbone for DeepLabV3+
ENCODER_WEIGHTS = 'imagenet'

# Get the preprocessing function specific to the ResNet50 encoder
try:
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
except (KeyError, ValueError):
    # smp may raise either error, depending on whether the encoder name or the weights name is unknown
    print(f"Warning: Preprocessing function not found for {ENCODER} with {ENCODER_WEIGHTS}. Using standard ImageNet normalization.")

    def preprocessing_fn(x, **kwargs):
        # Fallback: standard ImageNet normalization. **kwargs absorbs the extra
        # parameters that A.Lambda passes to its image function.
        return (x / 255.0 - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])

# Define Albumentations Transforms
def get_transforms(phase, image_size, preprocessing_fn):
    common_transforms = [A.Resize(image_size, image_size)]
    if phase == 'train':
        # Augmentations: Add more as needed (e.g., Rotate, ShiftScaleRotate)
        aug_transforms = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Add more augmentations here if desired
            # A.RandomBrightnessContrast(p=0.2),
            # A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
        ]
    else:
        aug_transforms = [] # No augmentation for validation/test

    # Preprocessing (normalization) and tensor conversion
    # Note: Apply preprocessing_fn *before* ToTensorV2 if it expects a numpy array
    # If preprocessing_fn expects a tensor, apply it after ToTensorV2
    final_transforms = [
        A.Lambda(image=preprocessing_fn), # Apply model-specific preprocessing
        ToTensorV2(), # Convert image and mask to PyTorch tensors (C, H, W)
    ]

    return A.Compose(common_transforms + aug_transforms + final_transforms)

# Create transforms for each phase
train_transforms = get_transforms('train', IMAGE_SIZE, preprocessing_fn)
val_transforms = get_transforms('val', IMAGE_SIZE, preprocessing_fn)
test_transforms = get_transforms('test', IMAGE_SIZE, preprocessing_fn)

# Custom Dataset Class
class LandCoverDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transforms=None, rgb_to_id_func=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms
        self.rgb_to_id_func = rgb_to_id_func

        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError("Number of images and masks must be the same.")
        if not self.image_paths:
            print("Warning: Initializing dataset with zero image paths.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if idx >= len(self.image_paths):
            raise IndexError("Index out of range")
            
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            image = np.array(Image.open(img_path).convert("RGB"))
            mask_rgb = Image.open(mask_path) # Load mask
        except Exception as e:
            # Fail loudly: returning (None, None) would make the default DataLoader
            # collate function crash later with a much less informative error.
            # A better approach is to filter out unreadable files beforehand.
            raise RuntimeError(f"Error loading image/mask at index {idx} ({img_path} / {mask_path}): {e}") from e

        # Convert RGB mask to class ID mask
        if self.rgb_to_id_func:
            mask = self.rgb_to_id_func(mask_rgb)
        else:
            # Fallback or error if function not provided
            mask = np.array(mask_rgb) # Assuming mask is already single channel if no func

        # Apply transformations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

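        # The NumPy-based preprocessing yields float64; the model weights are float32
        if torch.is_tensor(image):
            image = image.float()

        # Ensure mask is a LongTensor for CrossEntropyLoss/FocalLoss
        # Mask shape should be (H, W), not (1, H, W); as_tensor also covers the no-transform case
        mask = torch.as_tensor(mask).squeeze().long()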

        return image, mask

# Create Datasets
# Check if paths are loaded before creating datasets
train_dataset = None
val_dataset = None
test_dataset = None

if train_image_paths:
    train_dataset = LandCoverDataset(train_image_paths, train_mask_paths, transforms=train_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created training dataset with {len(train_dataset)} samples.")
else:
    print("Skipping training dataset creation: No valid training paths found.")

if val_image_paths:
    val_dataset = LandCoverDataset(val_image_paths, val_mask_paths, transforms=val_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created validation dataset with {len(val_dataset)} samples.")
else:
    print("Skipping validation dataset creation: No valid validation paths found.")

if test_image_paths:
    test_dataset = LandCoverDataset(test_image_paths, test_mask_paths, transforms=test_transforms, rgb_to_id_func=rgb_mask_to_class_id_mask)
    print(f"Created test dataset with {len(test_dataset)} samples.")
else:
    print("Skipping test dataset creation: No valid test paths found.")

# Create DataLoaders
BATCH_SIZE = 8 # Adjust based on GPU/TPU memory (e.g., 4, 8, 16)
NUM_WORKERS = 2 # Adjust based on system capabilities

train_loader = None
val_loader = None
test_loader = None

if train_dataset:
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created train DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping train DataLoader creation.")

if val_dataset:
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created validation DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping validation DataLoader creation.")

if test_dataset:
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print(f"Created test DataLoader with batch size {BATCH_SIZE}.")
else:
    print("Skipping test DataLoader creation.")

# --- Verification Step: Check a batch ---
if train_loader:
    print("\nVerifying a batch from train_loader...")
    try:
        images, masks = next(iter(train_loader))
        print(f"Image batch shape: {images.shape}, dtype: {images.dtype}")
        print(f"Mask batch shape: {masks.shape}, dtype: {masks.dtype}")
        print(f"Mask unique values: {torch.unique(masks)}")
        
        # Visualize one sample from the batch
        img_sample = images[0].permute(1, 2, 0).cpu().numpy() # C, H, W -> H, W, C
        mask_sample = masks[0].cpu().numpy()
        
        # Need to denormalize image for visualization if normalized
        # This depends on the exact preprocessing_fn. Assuming standard ImageNet normalization:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_sample = std * img_sample + mean
        img_sample = np.clip(img_sample, 0, 1)
        
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(img_sample)
        plt.title(f"Sample Image (from DataLoader)")
        plt.axis('off')
        
        plt.subplot(1, 2, 2)
        plt.imshow(mask_sample, cmap='viridis') # Use a colormap suitable for class IDs
        plt.title(f"Sample Mask (from DataLoader)")
        plt.axis('off')
        plt.show()
    except Exception as e:
        print(f"Could not verify a batch from train_loader: {e}")

## 4. Define Model (DeepLabV3+ with ResNet50 Backbone)

Using the `segmentation-models-pytorch` library to create a DeepLabV3+ model with a pre-trained ResNet50 encoder.

In [None]:
# Define the model
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    in_channels=3,             # Input channels (RGB)
    classes=NUM_CLASSES,       # Output classes
    activation=None            # Output logits for combined loss
)

# Move model to the selected device
model.to(DEVICE)

# Print model summary (optional)
print(f"Model: DeepLabV3+ with {ENCODER} backbone")
# Use torchinfo for a detailed summary
# Check if train_loader exists to get BATCH_SIZE for summary
if train_loader:
    try:
        # Provide input size including batch dimension (B, C, H, W)
        summary(model, input_size=(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE))
    except Exception as e:
        print(f"Could not generate model summary with torchinfo: {e}")
        # Fallback to basic print
        # print(model)
else:
    print("Skipping model summary: train_loader not available.")

## 5. Define Loss Function, Optimizer, Scheduler, and Metrics

*   **Loss:** Combined Focal Loss + Dice Loss. We'll ignore the 'unknown' class (`IGNORE_INDEX`).
*   **Optimizer:** AdamW.
*   **Scheduler:** ReduceLROnPlateau monitors validation loss.
*   **Metrics:** Using `torchmetrics` for Accuracy and Mean Intersection over Union (mIoU), ignoring the specified class index.
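
To make the two loss terms concrete, here is a minimal sketch in plain Python, separate from the tensor implementation in the next cell. The helper names `focal_term` and `soft_dice` are hypothetical and exist only for this illustration. It shows how the focal factor `(1 - p)^gamma` shrinks the loss for confident pixels, and how the soft Dice coefficient compares predicted probabilities with a one-hot target for a single class.

```python
import math

def focal_term(p_true, gamma=2.0):
    """Focal loss for one pixel, given the predicted probability of its true class."""
    ce = -math.log(p_true)                 # ordinary cross-entropy for this pixel
    return (1.0 - p_true) ** gamma * ce    # down-weight confident (easy) pixels

def soft_dice(pred_probs, target_onehot, smooth=1e-6):
    """Soft Dice coefficient for one class over a flat list of pixels."""
    intersection = sum(p * t for p, t in zip(pred_probs, target_onehot))
    total = sum(pred_probs) + sum(target_onehot)
    return (2.0 * intersection + smooth) / (total + smooth)

easy = focal_term(0.9)   # confident, correct pixel
hard = focal_term(0.1)   # badly misclassified pixel
print(f"focal(easy)={easy:.5f}, focal(hard)={hard:.5f}")

# Four pixels: three belong to the class (two predicted well, one missed), one does not
dice = soft_dice([0.9, 0.8, 0.1, 0.2], [1, 1, 1, 0])
print(f"dice={dice:.4f}, dice loss={1 - dice:.4f}")
```

The hard pixel contributes over a thousand times more focal loss than the easy one, which is why the focal term helps with dominant, easily classified land cover classes. The Dice term measures region overlap per class, so it complements the per-pixel focal term.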

In [None]:
LEARNING_RATE = 1e-4 # Adjusted learning rate
PATIENCE = 5 # For early stopping and scheduler

# --- Loss Functions ---
# Optional: Define class weights (adjust based on dataset analysis if needed)
# class_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.1], device=DEVICE) # Example: Downweight 'unknown'
class_weights = None # Start without weights

def focal_loss(outputs, targets, alpha=None, gamma=2.0, ignore_index=None):
    """ Focal Loss, adapted to ignore an index. """
    ignore = ignore_index if ignore_index is not None else -100
    # pt (probability of the true class) must come from the *unweighted* cross-entropy;
    # class weights (alpha) only scale the loss term itself.
    ce_unweighted = F.cross_entropy(outputs, targets, reduction='none', ignore_index=ignore)
    pt = torch.exp(-ce_unweighted)
    if alpha is not None:
        ce_loss = F.cross_entropy(outputs, targets, reduction='none', weight=alpha, ignore_index=ignore)
    else:
        ce_loss = ce_unweighted
    focal = ((1 - pt) ** gamma) * ce_loss
    # Average over valid (non-ignored) pixels only; avoid NaN if a batch has none
    valid_mask = (targets != ignore)
    if valid_mask.any():
        return focal[valid_mask].mean()
    return focal.sum() * 0.0

def dice_loss(outputs, targets, smooth=1e-6, ignore_index=None):
    """ Dice Loss, adapted to ignore an index. """
    num_classes = outputs.shape[1]
    outputs = F.softmax(outputs, dim=1)
    
    # Create one-hot targets, considering ignore_index
    if ignore_index is not None:
        valid_mask = (targets != ignore_index).unsqueeze(1) # (B, 1, H, W)
        targets_masked = torch.where(targets == ignore_index, num_classes, targets) # Temp replace ignore_index
        targets_one_hot = F.one_hot(targets_masked, num_classes + 1).permute(0, 3, 1, 2).float()
        targets_one_hot = targets_one_hot[:, :num_classes, :, :] # Remove the extra class channel
        targets_one_hot = targets_one_hot * valid_mask # Zero out ignored pixels
        outputs = outputs * valid_mask # Zero out predictions for ignored pixels
    else:
        targets_one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()

    # Calculate intersection and union per class, averaged over batch
    intersection = (outputs * targets_one_hot).sum(dim=(0, 2, 3)) # Sum over batch and spatial dims
    union = outputs.sum(dim=(0, 2, 3)) + targets_one_hot.sum(dim=(0, 2, 3))

    # Calculate Dice coefficient per class, then average
    # Exclude the ignore_index class if specified (it won't be in the one-hot encoding anyway)
    dice_per_class = (2 * intersection + smooth) / (union + smooth)
    
    # Average Dice score across relevant classes
    # If ignore_index is used, we might want to average only over non-ignored classes
    # For simplicity here, we average over all output classes.
    # A more refined approach could mask the average based on class presence.
    dice = dice_per_class.mean()
    
    return 1 - dice

def combined_loss(outputs, targets, ignore_index=None):
    """ Combination of Focal Loss and Dice Loss. """
    fl = focal_loss(outputs, targets, alpha=class_weights, gamma=2.0, ignore_index=ignore_index)
    dl = dice_loss(outputs, targets, ignore_index=ignore_index)
    return fl + dl

# --- Optimizer ---
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# --- Scheduler ---
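# Note: the `verbose` argument is deprecated in recent PyTorch releases, so it is omitted here;
# the current learning rate can be read from optimizer.param_groups[0]['lr'].
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=PATIENCE // 2)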

# --- Metrics (using torchmetrics) ---
# Instantiate metrics - specify task, num_classes, and ignore_index
# Note: Metrics are often computed on CPU to save GPU memory, especially during aggregation.
metrics_device = 'cpu' # Compute metrics on CPU

train_accuracy = torchmetrics.classification.MulticlassAccuracy(
    num_classes=NUM_CLASSES,
    ignore_index=IGNORE_INDEX,
    average='micro' # Pixel accuracy across all classes
).to(metrics_device)

train_miou = torchmetrics.classification.MulticlassJaccardIndex(
    num_classes=NUM_CLASSES,
    ignore_index=IGNORE_INDEX,
    average='macro' # Mean IoU across classes
).to(metrics_device)

val_accuracy = torchmetrics.classification.MulticlassAccuracy(
    num_classes=NUM_CLASSES,
    ignore_index=IGNORE_INDEX,
    average='micro'
).to(metrics_device)

val_miou = torchmetrics.classification.MulticlassJaccardIndex(
    num_classes=NUM_CLASSES,
    ignore_index=IGNORE_INDEX,
    average='macro'
).to(metrics_device)

## 6. Training and Validation Loops

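Defines the per-epoch training and validation functions. Early stopping and LR scheduling, both driven by validation loss, are applied in the training loop that calls these functions.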
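
Both functions reset the metric objects at the start of an epoch, call `update` on every batch, and call `compute` once at the end. The sketch below imitates that pattern in plain Python with a confusion matrix, to show why statistics are accumulated over the whole epoch rather than averaged per batch. The function names, the three-class setup, and the ignore value 255 are assumptions for illustration only. Skipping classes that never appear is also a choice made for this sketch; `torchmetrics` may handle absent classes differently.

```python
def update_confusion(conf, preds, targets, ignore_index):
    """Add one batch of flat predictions/targets to a confusion matrix (rows = target)."""
    for p, t in zip(preds, targets):
        if t == ignore_index:
            continue              # ignored pixels contribute to no metric
        conf[t][p] += 1

def pixel_accuracy(conf):
    """Micro accuracy: correct pixels divided by all counted pixels."""
    correct = sum(conf[c][c] for c in range(len(conf)))
    total = sum(sum(row) for row in conf)
    return correct / total

def mean_iou(conf):
    """Macro IoU: per-class TP / (TP + FP + FN), averaged over classes that occur."""
    ious = []
    for c in range(len(conf)):
        tp = conf[c][c]
        fn = sum(conf[c]) - tp
        fp = sum(row[c] for row in conf) - tp
        if tp + fp + fn > 0:      # skip classes absent from both preds and targets
            ious.append(tp / (tp + fp + fn))
    return sum(ious) / len(ious)

NUM_CLASSES_DEMO, IGNORE_DEMO = 3, 255   # hypothetical values for this sketch
conf = [[0] * NUM_CLASSES_DEMO for _ in range(NUM_CLASSES_DEMO)]

# Two "batches" of four pixels each; the last pixel of the first batch is ignored
update_confusion(conf, preds=[0, 0, 1, 2], targets=[0, 1, 1, 255], ignore_index=IGNORE_DEMO)
update_confusion(conf, preds=[2, 2, 1, 0], targets=[2, 1, 1, 0], ignore_index=IGNORE_DEMO)

print(f"accuracy={pixel_accuracy(conf):.4f}, mIoU={mean_iou(conf):.4f}")
```

Accuracy pools all pixels, so frequent classes dominate it, while mIoU weights every class equally. That is why the two numbers can diverge on imbalanced land cover data.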

In [None]:
def train_epoch(model, loader, criterion, optimizer, device, acc_metric, miou_metric, ignore_index):
    model.train()
    epoch_loss = 0.0
    acc_metric.reset()
    miou_metric.reset()

    # Wrap loader with ParallelLoader for TPU training if applicable
    if _xla_available and isinstance(device, torch.device) and device.type == 'xla':
        # ParallelLoader lives in torch_xla.distributed.parallel_loader, not in xla_model (xm)
        import torch_xla.distributed.parallel_loader as pl
        para_loader = pl.ParallelLoader(loader, [device])
        progress_bar = tqdm(para_loader.per_device_loader(device), desc="Training", leave=False, total=len(loader))
    else:
        progress_bar = tqdm(loader, desc="Training", leave=False)
        
    for batch_idx, (images, masks) in enumerate(progress_bar):
        images = images.to(device)
        masks = masks.to(device) # Shape: (B, H, W)

        optimizer.zero_grad()
        outputs = model(images) # Shape: (B, C, H, W) - Logits
        # Handle potential dict output from some models
        if isinstance(outputs, dict):
            outputs = outputs.get('out', outputs) # Default to 'out' key

        loss = criterion(outputs, masks, ignore_index=ignore_index)
        loss.backward()
        
        # Use xm.optimizer_step for TPU
        if _xla_available and isinstance(device, torch.device) and device.type == 'xla':
            xm.optimizer_step(optimizer)
        else:
            optimizer.step()

        epoch_loss += loss.item()

        # Update metrics (move preds/masks to metrics_device - CPU)
        with torch.no_grad():
            preds = torch.argmax(outputs, dim=1)
            acc_metric.update(preds.to(acc_metric.device), masks.to(acc_metric.device))
            miou_metric.update(preds.to(miou_metric.device), masks.to(miou_metric.device))

        progress_bar.set_postfix(loss=loss.item())

    # Reduce loss across cores if using TPU
    if _xla_available and isinstance(device, torch.device) and device.type == 'xla':
         epoch_loss = xm.mesh_reduce('train_loss_reduce', epoch_loss, np.mean)
         
    avg_loss = epoch_loss / len(loader) 
    
    # Compute final epoch metrics
    # For TPU, ensure compute is called on all devices or gather results
    # torchmetrics handles synchronization with `sync_on_compute=True` (default)
    final_acc = acc_metric.compute().item()
    final_miou = miou_metric.compute().item()
    
    # Mark step for TPU execution graph
    if _xla_available and isinstance(device, torch.device) and device.type == 'xla':
        xm.mark_step()
        
    return avg_loss, {'Accuracy': final_acc, 'mIoU': final_miou}

def validate_epoch(model, loader, criterion, device, acc_metric, miou_metric, ignore_index):
    model.eval()
    epoch_loss = 0.0
    acc_metric.reset()
    miou_metric.reset()

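    # Wrap loader for TPU if applicable
    if _xla_available and isinstance(device, torch.device) and device.type == 'xla':
        # ParallelLoader lives in torch_xla.distributed.parallel_loader, not in xla_model (xm)
        import torch_xla.distributed.parallel_loader as pl
        para_loader = pl.ParallelLoader(loader, [device])
        progress_bar = tqdm(para_loader.per_device_loader(device), desc="Validation", leave=False, total=len(loader))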
    else:
        progress_bar = tqdm(loader, desc="Validation", leave=False)

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(progress_bar):
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            if isinstance(outputs, dict):
                outputs = outputs.get('out', outputs)
                
            loss = criterion(outputs, masks, ignore_index=ignore_index)
            epoch_loss += loss.item()

            # Update metrics (move preds/masks to metrics_device - CPU)
            preds = torch.argmax(outputs, dim=1)
            acc_metric.update(preds.to(acc_metric.device), masks.to(acc_metric.device))
            miou_metric.update(preds.to(miou_metric.device), masks.to(miou_metric.device))

            progress_bar.set_postfix(loss=loss.item())

    # Reduce loss across cores if using TPU
    if _xla_available and isinstance(device, torch.device) and device.type == 'xla':
         epoch_loss = xm.mesh_reduce('val_loss_reduce', epoch_loss, np.mean)
         
    avg_loss = epoch_loss / len(loader)
    
    # Compute final epoch metrics
    final_acc = acc_metric.compute().item()
    final_miou = miou_metric.compute().item()

    # Mark step for TPU execution graph
    if _xla_available and isinstance(device, torch.device) and device.type == 'xla':
        xm.mark_step()
        
    return avg_loss, {'Accuracy': final_acc, 'mIoU': final_miou}

## 7. Run Training

Execute the training and validation loops for a specified number of epochs, with early stopping and model saving.

**Note:** Training can take a significant amount of time. Adjust `NUM_EPOCHS` based on available time and resources.
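
The early-stopping rule can be sketched on its own, without a model. The `EarlyStopper` class and the loss values below are hypothetical and are not the notebook's implementation, which tracks `best_val_loss` and `patience_counter` directly in the loop. The idea is the same: save a checkpoint when validation loss improves, and stop once it has failed to improve for `patience` consecutive epochs.

```python
class EarlyStopper:
    """Hypothetical helper: stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float("inf")
        self.counter = 0

    def step(self, val_loss):
        """Return (is_best, should_stop) for this epoch's validation loss."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            return True, False        # the caller would save the best checkpoint here
        self.counter += 1
        return False, self.counter >= self.patience

stopper = EarlyStopper(patience=3)
losses = [0.90, 0.70, 0.72, 0.71, 0.69, 0.70, 0.70, 0.70, 0.65]  # made-up validation losses
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    is_best, should_stop = stopper.step(loss)
    if should_stop:
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best validation loss {stopper.best_loss}")
```

In this made-up run the final loss of 0.65 is never reached, because training stops after three epochs without improvement. A larger patience trades longer training for a lower chance of stopping on a temporary plateau.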

In [None]:
NUM_EPOCHS = 25 # Adjust as needed (e.g., 15, 25, 50)
best_val_loss = float('inf')
patience_counter = 0
model_save_path = f"deeplabv3plus_{ENCODER}_best_model.pth" # Save best based on validation loss
final_model_save_path = f"deeplabv3plus_{ENCODER}_final_epoch_model.pth" # Always save the final epoch model

# Initialize history dictionary to store metrics
history = {
    'train_loss': [], 'val_loss': [],
    'train_mIoU': [], 'val_mIoU': [],
    'train_Accuracy': [], 'val_Accuracy': [],
    'lr': [] # Track learning rate
}

# Check if train_loader exists before starting training
if train_loader:
    print(f"Starting training for {NUM_EPOCHS} epochs...")
    print(f"Using device: {DEVICE}")
    if not val_loader:
        print("Warning: val_loader not available. Validation, LR scheduling, and early stopping will be skipped. Saving final model only.")
        # If no validation, 'best' path is the same as 'final'
        model_save_path = final_model_save_path 

    start_time = time.time()

    for epoch in range(NUM_EPOCHS):
        epoch_start_time = time.time()
        
        # Clear CUDA cache (optional, might help with memory fragmentation)
        # Collect garbage first so tensors freed by Python can actually be released from the cache
        if DEVICE.type == 'cuda':
            gc.collect()
            torch.cuda.empty_cache()

        # --- Training ---
        train_loss, train_metrics = train_epoch(model, train_loader, combined_loss, optimizer, DEVICE, train_accuracy, train_miou, IGNORE_INDEX)
        history['train_loss'].append(train_loss)
        history['train_mIoU'].append(train_metrics.get('mIoU', 0.0))
        history['train_Accuracy'].append(train_metrics.get('Accuracy', 0.0))

        # --- Validation (Optional) ---
        val_loss = float('nan')
        val_metrics = {'Accuracy': 0.0, 'mIoU': 0.0} # Initialize with defaults
        if val_loader:
            val_loss, val_metrics = validate_epoch(model, val_loader, combined_loss, DEVICE, val_accuracy, val_miou, IGNORE_INDEX)
            history['val_loss'].append(val_loss)
            history['val_mIoU'].append(val_metrics.get('mIoU', 0.0))
            history['val_Accuracy'].append(val_metrics.get('Accuracy', 0.0))
            
            # --- Scheduler Step ---
            # ReduceLROnPlateau steps based on validation loss
            scheduler.step(val_loss)
            history['lr'].append(optimizer.param_groups[0]['lr'])
            
        else:
            # Append placeholders if validation is skipped
            history['val_loss'].append(float('nan'))
            history['val_mIoU'].append(float('nan'))
            history['val_Accuracy'].append(float('nan'))
            history['lr'].append(optimizer.param_groups[0]['lr']) # Still record LR

        epoch_duration = time.time() - epoch_start_time

        # --- Logging (using xm.master_print on TPU) ---
        log_fn = xm.master_print if _xla_available and isinstance(DEVICE, torch.device) and DEVICE.type == 'xla' else print
        
        log_fn(f"Epoch {epoch+1}/{NUM_EPOCHS} | Duration: {epoch_duration:.2f}s | LR: {optimizer.param_groups[0]['lr']:.1e}")
        log_fn(f"  Train Loss: {train_loss:.4f} | mIoU: {train_metrics.get('mIoU', 0.0):.4f} | Acc: {train_metrics.get('Accuracy', 0.0):.4f}")
        if val_loader:
            log_fn(f"  Val   Loss: {val_loss:.4f} | mIoU: {val_metrics.get('mIoU', 0.0):.4f} | Acc: {val_metrics.get('Accuracy', 0.0):.4f}")
        else:
            log_fn("  Validation skipped.")

        # --- Save Best Model (based on validation loss) & Early Stopping ---
        if val_loader:
            current_val_loss = val_loss
            # Use xm.save for TPU, ensuring it runs only on the master process
            if current_val_loss < best_val_loss:
                best_val_loss = current_val_loss
                patience_counter = 0 # Reset patience
                if _xla_available and isinstance(DEVICE, torch.device) and DEVICE.type == 'xla':
                    # Save on master process
                    xm.save(model.state_dict(), model_save_path, master_only=True)
                    log_fn(f"  -> Best model saved to {model_save_path} (Val Loss: {best_val_loss:.4f})")
                else:
                    torch.save(model.state_dict(), model_save_path)
                    log_fn(f"  -> Best model saved to {model_save_path} (Val Loss: {best_val_loss:.4f})")
            else:
                patience_counter += 1
                log_fn(f"  -> Patience: {patience_counter}/{PATIENCE}")
                if patience_counter >= PATIENCE:
                    log_fn(f"\nEarly stopping triggered after {epoch + 1} epochs.")
                    break # Exit training loop

    # --- Save Final Epoch Model ---
    # Saved after the loop so the file also exists when early stopping ends training before NUM_EPOCHS
    log_fn = xm.master_print if _xla_available and isinstance(DEVICE, torch.device) and DEVICE.type == 'xla' else print
    epochs_run = len(history['train_loss'])
    log_fn(f"\nTraining ended after {epochs_run}/{NUM_EPOCHS} epochs.")
    if _xla_available and isinstance(DEVICE, torch.device) and DEVICE.type == 'xla':
        xm.save(model.state_dict(), final_model_save_path, master_only=True)
    else:
        torch.save(model.state_dict(), final_model_save_path)

    total_training_time = time.time() - start_time
    log_fn(f"\nTraining finished in {total_training_time // 60:.0f}m {total_training_time % 60:.0f}s")
    if val_loader:
        log_fn(f"Best Validation Loss achieved: {best_val_loss:.4f} (saved to {model_save_path})")
    log_fn(f"Model from the last completed epoch saved to {final_model_save_path}")
else:
    print("Cannot start training: train_loader is not available.")
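
The training loop above steps a reduce-on-plateau scheduler with the validation loss. The idea can be sketched in pure Python as follows. This is a simplified model for illustration only, not PyTorch's implementation: `plateau_schedule` and its inputs are hypothetical, and the real scheduler's threshold, cooldown, and minimum-LR options are left out.

```python
def plateau_schedule(val_losses, lr, factor, patience):
    """Return the learning rate in effect after each epoch (simplified sketch)."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Reduce only once the number of non-improving epochs exceeds patience
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs


# Made-up losses: three non-improving epochs in a row trigger one reduction.
lrs = plateau_schedule([0.9, 0.8, 0.81, 0.82, 0.83, 0.7], lr=1e-3, factor=0.5, patience=2)
print(lrs)
```

If the scheduler's patience is shorter than the early-stopping `PATIENCE`, the learning rate gets a chance to drop before training is stopped.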

## 8. Plot Training History

In [None]:
def plot_history(history):
    epochs = range(1, len(history['train_loss']) + 1)
    has_val_data = 'val_loss' in history and any(not np.isnan(x) for x in history['val_loss'])

    # Determine number of plots needed (Loss, mIoU, Accuracy, LR)
    metrics_to_plot = ['Loss', 'mIoU', 'Accuracy']
    num_plots = len(metrics_to_plot) + 1 # Add one for LR
    num_cols = 2
    num_rows = (num_plots + num_cols - 1) // num_cols # Calculate rows needed

    plt.figure(figsize=(14, 5 * num_rows))

    # Plot Loss, mIoU, Accuracy
    for i, metric_name in enumerate(metrics_to_plot):
        plt.subplot(num_rows, num_cols, i + 1)
        
        # Construct history keys; only 'Loss' is lowercased in history
        # (keys are 'train_loss', 'train_mIoU', 'train_Accuracy', and their 'val_' counterparts)
        key_suffix = 'loss' if metric_name == 'Loss' else metric_name
        train_key = f"train_{key_suffix}"
        val_key = f"val_{key_suffix}"
        
        if train_key in history and history[train_key]:
            plt.plot(epochs, history[train_key], 'bo-', label=f'Training {metric_name}')
        if has_val_data and val_key in history and history[val_key]:
            # Ensure validation data aligns with training epochs plotted
            val_epochs = range(1, len(history[val_key]) + 1)
            plt.plot(val_epochs, history[val_key], 'ro-', label=f'Validation {metric_name}')
            
        plt.title(f'Training and Validation {metric_name}')
        plt.xlabel('Epochs')
        plt.ylabel(metric_name)
        # Only show legend if there's data to plot
        handles, labels = plt.gca().get_legend_handles_labels()
        if handles:
            plt.legend()
        plt.grid(True)

    # Plot Learning Rate
    plt.subplot(num_rows, num_cols, len(metrics_to_plot) + 1)
    if 'lr' in history and history['lr']:
        lr_epochs = range(1, len(history['lr']) + 1)
        plt.plot(lr_epochs, history['lr'], 'go-', label='Learning Rate')
        plt.title('Learning Rate over Epochs')
        plt.xlabel('Epochs')
        plt.ylabel('Learning Rate')
        plt.legend()
        plt.grid(True)
    else:
        plt.title('Learning Rate (No Data)')
        plt.text(0.5, 0.5, 'No LR data recorded', ha='center', va='center')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Check if history has data before plotting
if history['train_loss']:
    plot_history(history)
else:
    print("No training history to plot.")

## 9. Evaluation on Test Set

Load the best saved model (based on validation loss) and evaluate its performance on the unseen test set using Accuracy and mIoU.
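
To make the two metrics concrete, here is a small NumPy sketch of how pixel accuracy and mIoU can be derived from a confusion matrix while skipping ignored pixels. It is an illustration under assumptions, not the `torchmetrics` implementation used in this notebook: `update_confusion`, `accuracy_and_miou`, and the toy labels are hypothetical, and the library may treat absent classes differently.

```python
import numpy as np


def update_confusion(cm, target, pred, ignore_index):
    """Add one batch of integer labels to a running (C, C) confusion matrix."""
    num_classes = cm.shape[0]
    target = np.asarray(target).ravel()
    pred = np.asarray(pred).ravel()
    keep = target != ignore_index  # drop pixels whose ground truth is ignored
    idx = target[keep] * num_classes + pred[keep]
    counts = np.bincount(idx, minlength=num_classes ** 2)
    cm += counts.reshape(num_classes, num_classes)
    return cm


def accuracy_and_miou(cm):
    """Pixel accuracy and mean IoU over classes that appear at least once."""
    tp = np.diag(cm).astype(float)
    accuracy = tp.sum() / cm.sum()
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    present = union > 0
    miou = (tp[present] / union[present]).mean()
    return accuracy, miou


# Toy example: 3 classes, 255 marks ignored pixels.
cm = np.zeros((3, 3), dtype=np.int64)
cm = update_confusion(cm, [0, 0, 1, 1, 2, 255], [0, 1, 1, 1, 2, 0], ignore_index=255)
accuracy, miou = accuracy_and_miou(cm)
print(cm)
print(accuracy, miou)
```

Because the matrix is updated batch by batch, no per-pixel label list has to be kept in memory, which matters for large test images.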

In [None]:
# Determine which model to load for evaluation (prefer best based on val loss)
eval_model_path = None
log_fn = xm.master_print if _xla_available and isinstance(DEVICE, torch.device) and DEVICE.type == 'xla' else print

# Check if the 'best' model file exists (saved based on val loss)
if os.path.exists(model_save_path):
    eval_model_path = model_save_path
    log_fn(f"\nAttempting to load best model (lowest val loss) from: {eval_model_path}")
elif os.path.exists(final_model_save_path): # Fallback to the final epoch model
    eval_model_path = final_model_save_path
    log_fn(f"\nBest model not found. Attempting to load final epoch model from: {eval_model_path}")
else:
    log_fn(f"\nNo model file found at {model_save_path} or {final_model_save_path}. Cannot perform evaluation.")

model_loaded = False
if eval_model_path:
    # Re-initialize model architecture
    model_eval = smp.DeepLabV3Plus( # Use a different variable name
        encoder_name=ENCODER,
        encoder_weights=None, # No need to load pretrained weights again
        in_channels=3,
        classes=NUM_CLASSES,
        activation=None
    )
    try:
        # Load state dict - use map_location to load onto the correct device (CPU first if loading TPU model on non-TPU)
        map_loc = 'cpu' if _xla_available else DEVICE # Load to CPU if XLA exists, otherwise current device
        model_eval.load_state_dict(torch.load(eval_model_path, map_location=map_loc))
        model_eval.to(DEVICE) # Move model to the target device (TPU/CUDA/CPU)
        model_eval.eval() # Set to evaluation mode
        log_fn("Model loaded successfully for evaluation.")
        model_loaded = True
    except Exception as e:
        log_fn(f"Error loading model state dict from {eval_model_path}: {e}")
        model_loaded = False

# Evaluate on the test set (if available and model loaded)
if model_loaded and test_loader:
    log_fn("\nEvaluating on the test set...")
    # Instantiate test metrics
    test_accuracy = torchmetrics.classification.MulticlassAccuracy(
        num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX, average='micro'
    ).to(metrics_device)
    test_miou = torchmetrics.classification.MulticlassJaccardIndex(
        num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX, average='macro'
    ).to(metrics_device)
    
    # Use validate_epoch function for evaluation logic
    test_loss, test_metrics = validate_epoch(model_eval, test_loader, combined_loss, DEVICE, test_accuracy, test_miou, IGNORE_INDEX)
    
    log_fn(f"\n--- Test Set Performance ---")
    log_fn(f"  Test Loss:      {test_loss:.4f}")
    log_fn(f"  Test Accuracy:  {test_metrics.get('Accuracy', 0.0):.4f}")
    log_fn(f"  Test mIoU:      {test_metrics.get('mIoU', 0.0):.4f}")
    
    # --- Detailed Confusion Matrix (Run on master/single device) ---
    # Ensure this part runs only once, e.g., using xm.is_master_ordinal() for TPU
    run_cm = not (_xla_available and isinstance(DEVICE, torch.device) and DEVICE.type == 'xla') or xm.is_master_ordinal()
    if run_cm:
        log_fn("\nCalculating Confusion Matrix on Test Set (master/single device)...")
        all_preds = []
        all_masks = []
        # Use the standard test_loader for CM calculation on CPU/single GPU
        cm_loader = test_loader 
        cm_device = DEVICE # Use the main device unless forcing CPU
        # If on TPU, might need to run CM calculation on CPU with a CPU model copy
        # cm_device = torch.device('cpu')
        # model_eval_cpu = model_eval.to(cm_device) # Create CPU copy if needed
        
        model_eval.eval() # Ensure model is in eval mode
        with torch.no_grad():
            for images, masks in tqdm(cm_loader, desc="Test Prediction for CM"):
                images = images.to(cm_device)
                masks = masks.to(cm_device) # (B, H, W)
                outputs = model_eval(images) # Use model_eval (on cm_device)
                if isinstance(outputs, dict):
                    outputs = outputs.get('out', outputs)
                preds = torch.argmax(outputs, dim=1) # (B, H, W)
                
                # Flatten masks and predictions, ignore IGNORE_INDEX
                mask_flat = masks.view(-1).cpu().numpy()
                pred_flat = preds.view(-1).cpu().numpy()
                
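                # Filter out ignored index pixels; keep one array per batch instead of per-pixel Python objects
                valid_indices = mask_flat != IGNORE_INDEX
                all_masks.append(mask_flat[valid_indices])
                all_preds.append(pred_flat[valid_indices])
        
        # Join the per-batch arrays once
        all_masks = np.concatenate(all_masks) if all_masks else np.array([], dtype=np.int64)
        all_preds = np.concatenate(all_preds) if all_preds else np.array([], dtype=np.int64)
        
        if all_masks.size > 0:
            # Calculate confusion matrix
            # Assumes the ignored class is the last class ID, so it is excluded from labels
            cm = confusion_matrix(all_masks, all_preds, labels=list(range(NUM_CLASSES - 1)))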
            
            # Plot confusion matrix
            plt.figure(figsize=(10, 8))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                        xticklabels=class_names[:-1], yticklabels=class_names[:-1]) # Exclude ignore index name
            plt.xlabel('Predicted Labels')
            plt.ylabel('True Labels')
            plt.title('Confusion Matrix (Test Set - Excluding Unknown Class)')
            plt.show()
        else:
            log_fn("Could not calculate confusion matrix: No valid pixels found after filtering ignore index.")
    else:
        log_fn("Skipping Confusion Matrix calculation on non-master TPU core.")
        
elif model_loaded and not test_loader:
    log_fn("\nSkipping test set evaluation: test_loader not available (likely no test data found).")
elif not model_loaded:
    log_fn("\nSkipping test set evaluation: Model could not be loaded.")

## 10. Visualize Predictions

Display some predictions from the test set alongside the original images and ground truth masks, using the loaded evaluation model.
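
The ID-to-color conversion used for display can also be written as a lookup-table index instead of a per-class loop. The sketch below is an illustrative alternative: `build_palette` and the 3-class `example_id_to_rgb` mapping are hypothetical stand-ins for the notebook's `id_to_rgb`.

```python
import numpy as np

# Hypothetical 3-class palette (class ID -> RGB); the notebook's real mapping is `id_to_rgb`.
example_id_to_rgb = {0: (0, 255, 255), 1: (255, 255, 0), 2: (0, 0, 0)}


def build_palette(id_to_rgb_map):
    """Stack the colors into a (max_id + 1, 3) uint8 lookup table."""
    palette = np.zeros((max(id_to_rgb_map) + 1, 3), dtype=np.uint8)
    for class_id, color in id_to_rgb_map.items():
        palette[class_id] = color
    return palette


palette = build_palette(example_id_to_rgb)
mask_id = np.array([[0, 1], [2, 1]])  # toy (H, W) class-ID mask
rgb_mask = palette[mask_id]  # integer-array indexing yields an (H, W, 3) image
print(rgb_mask.shape, rgb_mask[0, 1])
```

A mask value larger than the highest ID in the mapping would raise an `IndexError` here, whereas the loop-based version leaves such pixels black.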

In [None]:
import matplotlib.colors as mcolors

def id_to_rgb_mask(mask_id, id_to_rgb_map):
    """Converts a class ID mask (H, W) to an RGB mask (H, W, 3)."""
    rgb_mask = np.zeros((*mask_id.shape, 3), dtype=np.uint8)
    for class_id, color in id_to_rgb_map.items():
        rgb_mask[mask_id == class_id] = color
    return rgb_mask

def visualize_predictions(model, loader, device, num_samples=5, id_to_rgb_map=None):
    if not loader:
        print("Cannot visualize predictions: Loader not available.")
        return
        
    if not hasattr(model, 'eval'): # Basic check if model is a PyTorch model
        print("Cannot visualize predictions: Invalid model object provided.")
        return
         
    model.eval()
    samples_shown = 0
    
    # Create a colormap and norm for the ID masks if id_to_rgb_map is not provided
    cmap = None
    norm = None
    if id_to_rgb_map is None:
        # Discrete colormap with one color per class; pyplot.get_cmap accepts the number of levels.
        # NUM_CLASSES replaces the undefined name `num_labels`.
        cmap = plt.get_cmap('viridis', NUM_CLASSES)
        norm = mcolors.Normalize(vmin=0, vmax=NUM_CLASSES - 1)
    
    plt.figure(figsize=(18, 6 * num_samples)) # Adjusted figure size
    
    with torch.no_grad():
        try:
            data_iterator = iter(loader)
            while samples_shown < num_samples:
                try:
                    images, masks_true_id = next(data_iterator)
                except StopIteration:
                    print("\nReached end of loader before showing desired number of samples.")
                    break # Exit outer loop if loader is exhausted
                    
                images = images.to(device)
                # masks_true_id are already (B, H, W) LongTensor

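                outputs = model(images) # (B, C, H, W)
                if isinstance(outputs, dict): # Same handling as validate_epoch for dict outputs
                    outputs = outputs.get('out', outputs)
                masks_pred_id = torch.argmax(outputs, dim=1) # (B, H, W)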

                for i in range(images.shape[0]):
                    if samples_shown >= num_samples:
                        break # Exit inner loop

                    img = images[i].permute(1, 2, 0).cpu().numpy()
                    true_mask_id = masks_true_id[i].cpu().numpy()
                    pred_mask_id = masks_pred_id[i].cpu().numpy()

                    # Denormalize image for display (assuming standard ImageNet normalization)
                    mean = np.array([0.485, 0.456, 0.406])
                    std = np.array([0.229, 0.224, 0.225])
                    img = std * img + mean
                    img = np.clip(img, 0, 1)

                    # Convert ID masks to RGB for visualization
                    if id_to_rgb_map:
                        true_mask_rgb = id_to_rgb_mask(true_mask_id, id_to_rgb_map)
                        pred_mask_rgb = id_to_rgb_mask(pred_mask_id, id_to_rgb_map)
                    elif cmap and norm:
                        # Use colormap if no RGB map provided
                        true_mask_rgb = cmap(norm(true_mask_id))[:, :, :3] # Get RGB from colormap
                        pred_mask_rgb = cmap(norm(pred_mask_id))[:, :, :3]
                        # Convert to uint8 if needed, though imshow handles float [0,1]
                        # true_mask_rgb = (true_mask_rgb * 255).astype(np.uint8)
                        # pred_mask_rgb = (pred_mask_rgb * 255).astype(np.uint8)
                    else: # Fallback if no map and cmap failed
                        true_mask_rgb = true_mask_id # Show raw IDs
                        pred_mask_rgb = pred_mask_id

                    plt.subplot(num_samples, 3, samples_shown * 3 + 1)
                    plt.imshow(img)
                    plt.title(f"Sample {samples_shown+1} - Image")
                    plt.axis('off')

                    plt.subplot(num_samples, 3, samples_shown * 3 + 2)
                    plt.imshow(true_mask_rgb)
                    plt.title(f"Sample {samples_shown+1} - True Mask")
                    plt.axis('off')

                    plt.subplot(num_samples, 3, samples_shown * 3 + 3)
                    plt.imshow(pred_mask_rgb)
                    plt.title(f"Sample {samples_shown+1} - Predicted Mask")
                    plt.axis('off')

                    samples_shown += 1
                # End of inner loop (batch processing)
            # End of outer loop (sample count check)
        except Exception as e:
            print(f"An error occurred during visualization: {e}")
            import traceback
            traceback.print_exc() # Print detailed traceback
            
    if samples_shown > 0:
        plt.tight_layout(pad=2.0)
        plt.show()
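    else:
        # Nothing was drawn (e.g., empty loader or an error above); discard the empty figure
        plt.close()
        print("No samples were visualized.")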

# Determine which model and loader to use for visualization
vis_model = None
vis_loader = None
vis_loader_name = ""

# Prioritize using the evaluated model ('model_eval') if it was loaded
if 'model_eval' in locals() and model_loaded:
    vis_model = model_eval
# Fallback to the model currently in memory ('model') if eval failed/skipped
elif 'model' in locals():
    vis_model = model
    print("\nWarning: Using model from training state for visualization (evaluation model not loaded/available).")

# Prioritize loader: Test > Validation > Train
if vis_model is not None:
    for candidate_loader, candidate_name in (
        (test_loader, "test set"),
        (val_loader, "validation set"),
        (train_loader, "training set"),
    ):
        if candidate_loader:
            vis_loader = candidate_loader
            vis_loader_name = candidate_name
            break

# Visualize predictions if a model and loader were selected
if vis_model and vis_loader:
    print(f"\nVisualizing predictions on {vis_loader_name}...")
    visualize_predictions(vis_model, vis_loader, DEVICE, num_samples=5, id_to_rgb_map=id_to_rgb)
else:
    print("\nCannot visualize predictions: Suitable model or data loader not available.")
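
The display step in `visualize_predictions` undoes ImageNet normalization before plotting. As a quick sanity check of that arithmetic, the following NumPy sketch normalizes a toy image and restores it. It assumes the inputs were normalized with the same mean and standard deviation; the random image is made up for illustration.

```python
import numpy as np

# ImageNet channel statistics, as used in visualize_predictions
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

rng = np.random.default_rng(0)
original = rng.random((4, 4, 3))  # toy (H, W, C) image with values in [0, 1)

normalized = (original - mean) / std  # what the model receives
restored = np.clip(std * normalized + mean, 0, 1)  # what gets displayed

print(np.allclose(restored, original))
```

The broadcast works because the channel axis is last, which is why the tensor is permuted to (H, W, C) before this step.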