# Off-Road Semantic Scene Segmentation - Data Preprocessing
## Track 2: Desert Environment Segmentation with Falcon Dataset

This notebook handles data loading, preprocessing, and augmentation for the desert semantic segmentation challenge.

**Dataset Classes (10 total):**
- Trees
- Lush Bushes
- Dry Grass
- Dry Bushes
- Ground Clutter
- Flowers
- Logs
- Rocks
- Landscape
- Sky

## 1. Setup and Imports

In [None]:
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import os
import glob
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import json

import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
from pathlib import Path

# Report the installed PyTorch build and the device that would be selected.
print("PyTorch Version:", torch.__version__)
print("Device:", torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

## 2. Configuration and Hyperparameters

In [None]:
# Dataset paths - UPDATE THESE AFTER DOWNLOADING FROM FALCON
DATA_ROOT = './falcon_desert_dataset'  # Root directory containing train, val, test
TRAIN_IMAGES = os.path.join(DATA_ROOT, 'train', 'images')
TRAIN_MASKS = os.path.join(DATA_ROOT, 'train', 'masks')
VAL_IMAGES = os.path.join(DATA_ROOT, 'val', 'images')
VAL_MASKS = os.path.join(DATA_ROOT, 'val', 'masks')
TEST_IMAGES = os.path.join(DATA_ROOT, 'test', 'images')

# Image preprocessing parameters
HEIGHT = 512  # Can adjust based on GPU memory
WIDTH = 512
NUM_CLASSES = 10  # Desert scene classes

# Class names and colors for visualization.
# A class's integer label is its position in these lists.
CLASS_NAMES = [
    'Trees',
    'Lush Bushes',
    'Dry Grass',
    'Dry Bushes',
    'Ground Clutter',
    'Flowers',
    'Logs',
    'Rocks',
    'Landscape',
    'Sky',
]

# Color palette for visualization (RGB)
CLASS_COLORS = [
    [0, 128, 0],      # Trees - Green
    [34, 139, 34],    # Lush Bushes - Forest Green
    [189, 183, 107],  # Dry Grass - Dark Khaki
    [160, 82, 45],    # Dry Bushes - Sienna
    [139, 69, 19],    # Ground Clutter - Saddle Brown
    [255, 182, 193],  # Flowers - Light Pink
    [101, 67, 33],    # Logs - Dark Brown
    [128, 128, 128],  # Rocks - Gray
    [210, 180, 140],  # Landscape - Tan
    [135, 206, 235],  # Sky - Sky Blue
]

# Labels are list positions, so the three definitions must agree.
# Fail fast here rather than silently training or plotting with a shifted palette.
if not (len(CLASS_NAMES) == len(CLASS_COLORS) == NUM_CLASSES):
    raise ValueError(
        f"NUM_CLASSES={NUM_CLASSES} but got {len(CLASS_NAMES)} class names "
        f"and {len(CLASS_COLORS)} class colors"
    )

print("Dataset Configuration:")
print(f"  Image Size: {HEIGHT}x{WIDTH}")
print(f"  Number of Classes: {NUM_CLASSES}")
print(f"  Classes: {', '.join(CLASS_NAMES)}")

## 3. Dataset Exploration

In [None]:
def explore_dataset(data_root):
    """Explore the dataset structure and count images per split.

    For each of ``train``, ``val`` and ``test`` whose ``images`` directory
    exists under ``data_root``, count the entries in ``<split>/images`` and
    ``<split>/masks`` (a missing masks directory counts as zero), print the
    counts plus the size of the first image, and record the counts.

    Args:
        data_root: Directory expected to contain ``train``/``val``/``test``
            sub-directories, each with ``images`` and optionally ``masks``.

    Returns:
        dict mapping split name -> ``{'num_images', 'num_masks', 'has_masks'}``.
        Splits without an ``images`` directory are omitted.
    """
    stats = {}

    for split in ('train', 'val', 'test'):
        img_dir = os.path.join(data_root, split, 'images')
        mask_dir = os.path.join(data_root, split, 'masks')

        if not os.path.exists(img_dir):
            continue

        images = sorted(glob.glob(os.path.join(img_dir, '*')))
        masks = sorted(glob.glob(os.path.join(mask_dir, '*'))) if os.path.exists(mask_dir) else []

        stats[split] = {
            'num_images': len(images),
            'num_masks': len(masks),
            'has_masks': len(masks) > 0,
        }

        print(f"{split.upper()} Set:")
        print(f"  Images: {len(images)}")
        print(f"  Masks: {len(masks)}")

        # Check first image dimensions
        if images:
            try:
                # Context manager so the file handle is released immediately
                with Image.open(images[0]) as sample_img:
                    print(f"  Sample image size: {sample_img.size}")
            except OSError as err:
                # '*' also matches non-image files / sub-directories; report and keep surveying
                print(f"  Could not read sample image {images[0]}: {err}")
        print()

    return stats

# Survey the dataset only once it has been downloaded; otherwise tell the user what to do.
if not os.path.exists(DATA_ROOT):
    print(f"⚠️ Dataset directory not found: {DATA_ROOT}")
    print("Please download the dataset from Falcon and update DATA_ROOT path.")
else:
    dataset_stats = explore_dataset(DATA_ROOT)

## 4. Custom Dataset Class

In [None]:
class DesertSegmentationDataset(Dataset):
    """
    Custom Dataset for Desert Semantic Segmentation.

    Images and masks are paired by position after sorting each directory's
    file listing, so both directories must hold the same number of files in
    matching sorted order.

    Args:
        image_dir: Directory containing RGB images
        mask_dir: Directory containing segmentation masks. If None or the
            directory does not exist, items are returned as (image, filename).
        transform: Callable taking (image, mask) and returning (image, mask)
        is_train: Whether this is training set (stored only; augmentation is
            decided by ``transform``)
        class_colors: Optional RGB palette used to decode color-coded masks;
            index i in the palette becomes label i. Defaults to the
            module-level CLASS_COLORS.

    Raises:
        ValueError: if ``mask_dir`` exists but its file count differs from
            the number of images.
    """

    def __init__(self, image_dir, mask_dir=None, transform=None, is_train=True,
                 class_colors=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.is_train = is_train
        self.class_colors = class_colors

        # Get all image paths
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, '*')))

        # Get mask paths if available
        if mask_dir is not None and os.path.exists(mask_dir):
            self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, '*')))
            # Pairing is positional: a count mismatch would misalign every
            # later pair or raise IndexError mid-epoch, so reject it up front.
            if len(self.mask_paths) != len(self.image_paths):
                raise ValueError(
                    f"Found {len(self.image_paths)} images in {image_dir!r} but "
                    f"{len(self.mask_paths)} masks in {mask_dir!r}; images and "
                    "masks are paired by sorted order, so the counts must match."
                )
        else:
            self.mask_paths = None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return (image, mask) for labeled data or (image, filename) otherwise."""
        img_path = self.image_paths[idx]
        # Context managers release the file handles PIL opens
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.mask_paths is not None:
            with Image.open(self.mask_paths[idx]) as mask_img:
                mask = np.array(mask_img)

            # Color-coded (multi-channel) masks are decoded to class indices
            if mask.ndim == 3:
                mask = self._rgb_to_class_idx(mask)
        else:
            mask = None

        # Apply transformations
        if self.transform:
            image, mask = self.transform(image, mask)

        if mask is not None:
            return image, mask
        return image, os.path.basename(img_path)

    def _rgb_to_class_idx(self, mask_rgb):
        """Convert an (H, W, C>=3) color mask to an (H, W) int64 label map.

        Any alpha channel is ignored. Pixels whose color matches no palette
        entry are left as 0.
        """
        colors = CLASS_COLORS if self.class_colors is None else self.class_colors
        rgb = mask_rgb[..., :3]  # drop alpha so RGBA masks compare against RGB colors
        h, w = rgb.shape[:2]
        mask_idx = np.zeros((h, w), dtype=np.int64)

        for class_idx, color in enumerate(colors):
            mask_idx[np.all(rgb == color, axis=-1)] = class_idx

        return mask_idx

## 5. Data Augmentation and Transformations

In [None]:
class JointTransform:
    """
    Apply transformations to both image and mask.

    Geometric operations (resize, flips, 90-degree-multiple rotations) are
    applied identically to image and mask so they stay pixel-aligned; color
    jitter touches the image only. Random augmentation runs only when
    ``is_train`` is True.

    Args:
        height: Output height in pixels.
        width: Output width in pixels.
        is_train: Enable random flips, rotations and color jitter.

    Call arguments:
        image: PIL image (as produced by the dataset class in this notebook).
        mask: NumPy label array (2-D as produced by that dataset) or None.

    Returns:
        (normalized float tensor of shape (3, height, width),
         int64 tensor of shape (height, width) or None)
    """

    def __init__(self, height=512, width=512, is_train=True):
        self.height = height
        self.width = width
        self.is_train = is_train

        # Normalization values (ImageNet stats)
        self.normalize = T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    def __call__(self, image, mask=None):
        # Resize: bilinear for the image, nearest for the mask so labels are never blended
        image = TF.resize(image, (self.height, self.width),
                          interpolation=T.InterpolationMode.BILINEAR)
        if mask is not None:
            mask = cv2.resize(mask, (self.width, self.height),
                              interpolation=cv2.INTER_NEAREST)

        # Training augmentations
        if self.is_train:
            # Random horizontal flip
            if np.random.rand() > 0.5:
                image = TF.hflip(image)
                if mask is not None:
                    mask = np.fliplr(mask).copy()

            # Random vertical flip
            if np.random.rand() > 0.5:
                image = TF.vflip(image)
                if mask is not None:
                    mask = np.flipud(mask).copy()

            # Random rotation (in 90-degree increments)
            if np.random.rand() > 0.75:
                # TF.rotate keeps the (H, W) canvas while np.rot90 transposes the
                # mask for 90/270, so those angles are only safe for square outputs.
                choices = [90, 180, 270] if self.height == self.width else [180]
                # Cast to a Python int: TF.rotate rejects NumPy integer scalars.
                angle = int(np.random.choice(choices))
                image = TF.rotate(image, angle, interpolation=T.InterpolationMode.BILINEAR)
                if mask is not None:
                    # Both TF.rotate (positive angle) and np.rot90 (positive k) rotate counter-clockwise
                    mask = np.rot90(mask, angle // 90).copy()

            # Color jitter (only for image, not mask)
            if np.random.rand() > 0.5:
                brightness = np.random.uniform(0.8, 1.2)
                contrast = np.random.uniform(0.8, 1.2)
                saturation = np.random.uniform(0.8, 1.2)
                hue = np.random.uniform(-0.1, 0.1)

                image = TF.adjust_brightness(image, brightness)
                image = TF.adjust_contrast(image, contrast)
                image = TF.adjust_saturation(image, saturation)
                image = TF.adjust_hue(image, hue)

        # Convert to tensor
        image = TF.to_tensor(image)
        image = self.normalize(image)

        if mask is not None:
            mask = torch.from_numpy(mask).long()
            return image, mask
        return image, None


# One transform per role: random augmentation for training, resize + normalize only for evaluation.
train_transform = JointTransform(is_train=True, width=WIDTH, height=HEIGHT)
val_transform = JointTransform(is_train=False, width=WIDTH, height=HEIGHT)

## 6. Visualization Functions

In [None]:
def visualize_sample(dataset, idx, figsize=(15, 5)):
    """Plot dataset sample ``idx``.

    Labeled samples get three panels (image, color-coded mask, 60/40 overlay);
    unlabeled samples, whose second element is a filename, get a single image
    titled with that filename.
    """
    image, target = dataset[idx]

    # Undo the ImageNet normalization so the image displays with natural colors
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    rgb = torch.clamp(image * imagenet_std + imagenet_mean, 0, 1).permute(1, 2, 0).numpy()

    if not isinstance(target, torch.Tensor):
        # Test set without masks: `target` is the image filename
        plt.figure(figsize=(8, 8))
        plt.imshow(rgb)
        plt.title(f'Test Image: {target}')
        plt.axis('off')
    else:
        labels = target.numpy()

        # Paint each class with its palette color; unknown labels stay black
        painted = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)
        for cls in range(NUM_CLASSES):
            painted[labels == cls] = CLASS_COLORS[cls]

        panels = (
            (rgb, 'Original Image'),
            (painted, 'Ground Truth Mask'),
            (rgb * 0.6 + painted / 255 * 0.4, 'Overlay'),
        )
        _, axes = plt.subplots(1, 3, figsize=figsize)
        for ax, (picture, caption) in zip(axes, panels):
            ax.imshow(picture)
            ax.set_title(caption)
            ax.axis('off')

    plt.tight_layout()
    plt.show()


def plot_class_distribution(dataset, title="Class Distribution"):
    """Plot and print the per-class pixel counts of a labeled dataset.

    Iterates every item of ``dataset``; items whose second element is not a
    tensor (e.g. a filename for unlabeled data) are skipped. Label values
    outside ``[0, NUM_CLASSES)`` are ignored.

    Args:
        dataset: Indexable dataset returning ``(image, mask)`` pairs.
        title: Title for the bar chart.
    """
    class_counts = np.zeros(NUM_CLASSES)

    print("Analyzing class distribution...")
    for idx in tqdm(range(len(dataset))):
        _, mask = dataset[idx]
        if not isinstance(mask, torch.Tensor):
            continue
        flat = mask.reshape(-1).long()
        # Exclude out-of-range labels (negatives would otherwise index from the end)
        valid = flat[(flat >= 0) & (flat < NUM_CLASSES)]
        class_counts += torch.bincount(valid, minlength=NUM_CLASSES).cpu().numpy()

    # Plot
    plt.figure(figsize=(12, 6))
    plt.bar(range(NUM_CLASSES), class_counts,
            color=['#{:02x}{:02x}{:02x}'.format(*c) for c in CLASS_COLORS])
    plt.xlabel('Class')
    plt.ylabel('Pixel Count')
    plt.title(title)
    plt.xticks(range(NUM_CLASSES), CLASS_NAMES, rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

    # Print statistics; guard the percentage when no labeled pixels were found
    total = class_counts.sum()
    print("\nClass Distribution:")
    for name, count in zip(CLASS_NAMES, class_counts):
        percentage = (count / total) * 100 if total > 0 else 0.0
        print(f"  {name:15s}: {count:12.0f} pixels ({percentage:5.2f}%)")

## 7. Create Datasets and DataLoaders

In [None]:
if not os.path.exists(DATA_ROOT):
    print("⚠️ Please download and set up the Falcon dataset first.")
else:
    # Labeled splits: augmented pipeline for training, deterministic one for validation
    train_dataset = DesertSegmentationDataset(
        TRAIN_IMAGES, TRAIN_MASKS, transform=train_transform, is_train=True,
    )
    val_dataset = DesertSegmentationDataset(
        VAL_IMAGES, VAL_MASKS, transform=val_transform, is_train=False,
    )
    # Unlabeled split (no masks): items come back as (image, filename)
    test_dataset = DesertSegmentationDataset(
        TEST_IMAGES, None, transform=val_transform, is_train=False,
    )

    print(f"Dataset sizes:")
    for split_label, split_ds in (
        ("Training", train_dataset),
        ("Validation", val_dataset),
        ("Test", test_dataset),
    ):
        print(f"  {split_label}: {len(split_ds)} samples")

## 8. Visualize Sample Data

In [None]:
if os.path.exists(DATA_ROOT):
    # Preview a few random samples per labeled split. Training samples go through
    # train_transform, so they are shown with random augmentation applied.
    for header, preview_ds, n_preview in (
        ("Training Set Samples:", train_dataset, 3),
        ("\nValidation Set Samples:", val_dataset, 2),
    ):
        print(header)
        # np.random.randint(0, 0) raises ValueError, so skip empty splits explicitly
        if len(preview_ds) == 0:
            print("  (no samples available)")
            continue
        for _ in range(n_preview):
            visualize_sample(preview_ds, np.random.randint(0, len(preview_ds)))

## 9. Analyze Class Distribution

In [None]:
if os.path.exists(DATA_ROOT):
    # Analyze training set distribution on a random subset (at most 100 items) to save time
    sample_size = min(100, len(train_dataset))
    sample_indices = np.random.choice(len(train_dataset), sample_size, replace=False)

    # torch's Subset has the same (dataset, indices) constructor and indexing
    # behaviour as the former hand-written helper; the old name is kept as an alias.
    SubsetDataset = torch.utils.data.Subset

    if sample_size == 0:
        print("⚠️ Training set is empty; skipping class distribution analysis.")
    else:
        sample_dataset = SubsetDataset(train_dataset, sample_indices.tolist())
        plot_class_distribution(sample_dataset, f"Training Set Class Distribution (Sample of {sample_size})")

## 10. Save Dataset Configuration

In [None]:
if os.path.exists(DATA_ROOT):
    # Persist the label space and split sizes as JSON so later stages can reuse them.
    config = dict(
        num_classes=NUM_CLASSES,
        class_names=CLASS_NAMES,
        class_colors=CLASS_COLORS,
        image_size=[HEIGHT, WIDTH],
        train_size=len(train_dataset),
        val_size=len(val_dataset),
        test_size=len(test_dataset),
    )

    Path('dataset_config.json').write_text(json.dumps(config, indent=2))

    print("Dataset configuration saved to 'dataset_config.json'")
    print("\nPreprocessing complete! Ready for model training.")

## Summary

This notebook covers:
1. ✅ Dataset loading and exploration
2. ✅ Custom Dataset class for desert segmentation
3. ✅ Data augmentation pipeline
4. ✅ Visualization utilities
5. ✅ Class distribution analysis
6. ✅ Configuration saving

**Next Steps:**
- Proceed to `02_model_train_desert_segmentation.ipynb` for model training
- Experiment with different augmentation strategies
- Consider class imbalance for loss function weighting