# Satellite Image Feature Extraction Using Deep Learning

## Introduction

This notebook extracts meaningful features from satellite imagery using state-of-the-art deep learning techniques. These features are important for various downstream tasks such as land cover classification, object detection, and change detection.

### Mathematical Foundations

In deep learning, features are extracted through convolutional layers, which perform the following operation:

$$F(i,j) = \sum_m \sum_n I(i+m, j+n) K(m,n)$$

where $I$ is the input image, $K$ is the kernel/filter, and $F$ is the resulting feature map.

### Objectives

1. Load and preprocess satellite imagery
2. Extract features using pre-trained CNNs (Transfer Learning)
3. Visualize extracted features and feature maps
4. Dimensionality reduction for feature analysis (PCA, t-SNE)
5. Save features for downstream tasks

In [None]:
# Basic imports
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import random
import seaborn as sns

# Image processing
import cv2
from PIL import Image
import rasterio
from skimage import io, transform
from skimage.feature import hog

# Deep learning imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights, vgg16, VGG16_Weights

# Feature visualization and dimensionality reduction
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Set random seeds for reproducibility.
# Python, NumPy and PyTorch each keep their own generator state.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set plotting style.
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# newer releases no longer accept the old names, so use whichever name this
# installation provides instead of hard-coding one that may raise OSError.
for _style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
plt.rcParams['figure.figsize'] = (12, 8)

# Load Satellite Imagery Dataset

Satellite images have unique characteristics that require specialized handling:

1. **Multi-spectral data**: Satellite images often contain multiple bands beyond RGB
2. **Geospatial metadata**: Images are georeferenced with coordinate systems
3. **Large file sizes**: Satellite images can be extremely large
4. **Various formats**: Common formats include GeoTIFF, JP2, and NITF

Let's create a utility class to handle satellite image loading efficiently.

In [None]:
class SatelliteImageLoader:
    """
    Load satellite images from a directory and return them as PIL images or tensors.

    Attributes:
        data_dir (str): Directory containing the satellite images.
        transform (callable): Transformations to apply to the images.
        image_files (list): Sorted paths of the image files found in ``data_dir``.
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = ('.tif', '.tiff', '.jpg', '.jpeg', '.png')
    # Extensions that are first opened with rasterio before falling back to PIL.
    GEOTIFF_EXTENSIONS = ('.tif', '.tiff')

    def __init__(self, data_dir, transform=None):
        """
        Initialize the SatelliteImageLoader.

        Args:
            data_dir (str): Directory containing the satellite images.
            transform (callable, optional): Transformations to apply to the images.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.image_files = self._get_image_files()

    def _get_image_files(self):
        """
        Get all image files in the data directory.

        Returns:
            list: Sorted list of image file paths (empty if the directory
            does not exist).
        """
        if not os.path.isdir(self.data_dir):
            print(f"Directory {self.data_dir} not found")
            return []

        candidates = (
            os.path.join(self.data_dir, name)
            for name in os.listdir(self.data_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # os.listdir returns entries in arbitrary order; sorting makes
        # index-based access (load_image(0), seeded sampling) reproducible.
        return sorted(path for path in candidates if os.path.isfile(path))

    @staticmethod
    def _bands_to_rgb_array(bands):
        """
        Convert a (bands, H, W) raster array into an (H, W, 3) uint8 array.

        Rasters with three or more bands keep their first three bands
        (assumed to be RGB); rasters with fewer bands are shown as grey by
        replicating the first band. Values are scaled so that the maximum
        maps to 255; negative and non-finite values are treated as 0.

        Args:
            bands (np.ndarray): Raster data in (bands, H, W) layout.

        Returns:
            np.ndarray: Array of shape (H, W, 3) and dtype uint8.

        Raises:
            ValueError: If the array is not 3-dimensional or has no bands.
        """
        bands = np.asarray(bands)
        if bands.ndim != 3 or bands.shape[0] == 0:
            raise ValueError(
                f"Expected a (bands, H, W) array with at least one band, got shape {bands.shape}"
            )

        if bands.shape[0] >= 3:
            bands = bands[:3]
        else:
            # PIL cannot build an image from a 1- or 2-channel (H, W, C) array.
            bands = np.repeat(bands[:1], 3, axis=0)

        # Work in float so the scaling neither overflows nor truncates early.
        img = np.transpose(bands, (1, 2, 0)).astype(np.float64)
        img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
        img = np.clip(img, 0, None)

        peak = img.max() if img.size else 0
        if peak > 0:
            img = img / peak * 255
        return img.astype(np.uint8)

    def load_image(self, idx):
        """
        Load a specific image by index.

        GeoTIFF files are read with rasterio and converted to RGB; if that
        fails for any reason the file is opened with PIL instead.

        Args:
            idx (int): Index of the image to load.

        Returns:
            PIL.Image: Loaded RGB image.
        """
        image_path = self.image_files[idx]

        if image_path.lower().endswith(self.GEOTIFF_EXTENSIONS):
            try:
                with rasterio.open(image_path) as src:
                    bands = src.read()
                return Image.fromarray(self._bands_to_rgb_array(bands))
            except Exception as exc:
                # Deliberate best-effort fallback, but report why rasterio
                # failed and let KeyboardInterrupt/SystemExit propagate.
                print(f"Error opening {image_path} with rasterio ({exc!r}). Trying with PIL...")

        # Regular image formats, or GeoTIFFs that rasterio could not read
        return Image.open(image_path).convert('RGB')

    def get_batch(self, batch_size=4):
        """
        Get a batch of random images.

        Args:
            batch_size (int): Number of images to return.

        Returns:
            torch.Tensor or list or None: A stacked tensor when ``transform``
            is set, otherwise a list of PIL images; ``None`` when no image
            files are available.
        """
        if not self.image_files:
            print("No image files found")
            return None

        count = min(batch_size, len(self.image_files))
        indices = np.random.choice(len(self.image_files), count, replace=False)
        images = [self.load_image(idx) for idx in indices]

        if self.transform:
            return torch.stack([self.transform(img) for img in images])
        return images

    def visualize_samples(self, num_samples=4, figsize=(12, 12)):
        """
        Visualize a random sample of images.

        Args:
            num_samples (int): Number of images to visualize.
            figsize (tuple): Figure size (width, height).
        """
        if not self.image_files:
            print("No image files found")
            return

        n_shown = min(num_samples, len(self.image_files))
        indices = np.random.choice(len(self.image_files), n_shown, replace=False)

        # Size the grid from the number of images actually shown so that a
        # small dataset does not leave empty rows in the figure.
        n_cols = 2
        n_rows = int(np.ceil(n_shown / n_cols))

        plt.figure(figsize=figsize)
        for i, idx in enumerate(indices):
            img = self.load_image(idx)
            plt.subplot(n_rows, n_cols, i + 1)
            plt.imshow(img)
            plt.title(f"Sample {i+1}: {os.path.basename(self.image_files[idx])}")
            plt.axis('off')
        plt.tight_layout()
        plt.show()

# Location of the satellite imagery on disk
DATA_DIR = "../data/sentinelsat"  # Update this path as needed

# Bail out early with a hint when the directory is missing
if not os.path.exists(DATA_DIR):
    print(f"Data directory {DATA_DIR} not found. Please update the path.")
else:
    # Preprocessing pipeline: resize to the model input size, convert to a
    # tensor, then apply ImageNet normalization
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    # Build the loader and show a few random samples
    loader = SatelliteImageLoader(DATA_DIR)
    loader.visualize_samples(num_samples=4)

## Transfer Learning

We'll leverage a pre-trained ResNet50 model:
- Trained on millions of images (ImageNet)
- Strong feature representations
- Adaptable to satellite imagery through fine-tuning

The mathematical formulation for feature extraction via CNN is:

$$f(x) = \phi_L(...\phi_2(\phi_1(x;w_1);w_2)...;w_L)$$

Where $\phi_l$ represents the function of layer $l$ with parameters $w_l$.

In [None]:
class DeepFeatureExtractor:
    """
    A class for extracting deep features from images using pre-trained models.

    A forward hook on the chosen layer captures that layer's output during a
    normal forward pass of the model.

    Attributes:
        model (torch.nn.Module): The pre-trained model for feature extraction.
        layer_name (str): The name of the layer from which to extract features.
        features (torch.Tensor or None): Output captured by the hook during
            the most recent forward pass.
        transform (callable): Transformations to apply to the images.
        device (torch.device): The device to use for computation.
    """

    def __init__(self, model_name='resnet50', layer_name='avgpool', device=None):
        """
        Initialize the feature extractor.

        Args:
            model_name (str): Name of the pre-trained model to use ('resnet50' or 'vgg16').
            layer_name (str): Name of the layer from which to extract features,
                as reported by ``model.named_modules()``.
            device (torch.device, optional): Device to use for computation.
                Defaults to CUDA when available, otherwise CPU.

        Raises:
            ValueError: If the model name is not supported or the layer
                cannot be found in the model.
        """
        # Resolve the default at call time rather than binding a module-level
        # global when the class is defined.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Load the pre-trained model
        if model_name == 'resnet50':
            self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        elif model_name == 'vgg16':
            self.model = vgg16(weights=VGG16_Weights.DEFAULT)
        else:
            raise ValueError(f"Model {model_name} not supported.")

        self.model = self.model.to(self.device)
        self.model.eval()  # Set to evaluation mode

        self.layer_name = layer_name
        self.features = None
        self._hook_handle = None

        # Register hook to capture features
        self._register_hook()

        # Default transform for preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _register_hook(self):
        """
        Register a forward hook to capture features from the specified layer.

        Raises:
            ValueError: If ``layer_name`` does not name a module of the model.
        """
        def hook_fn(module, inputs, output):
            # detach() shares storage with the live activation; clone so a
            # later in-place layer cannot overwrite what was captured.
            self.features = output.detach().clone()

        # named_modules() covers both top-level layers ('avgpool') and nested
        # ones ('layer4.2.conv3').
        target = dict(self.model.named_modules()).get(self.layer_name)
        if target is None:
            # Fail here instead of leaving self.features as None and crashing
            # later inside extract_features.
            raise ValueError(f"Layer '{self.layer_name}' not found in the model.")

        # Drop a previously registered hook so repeated calls do not stack hooks.
        previous = getattr(self, '_hook_handle', None)
        if previous is not None:
            previous.remove()
        self._hook_handle = target.register_forward_hook(hook_fn)

    def extract_features(self, images, flatten=True):
        """
        Extract features from input images.

        Args:
            images (torch.Tensor or PIL.Image or list): Input images. A tensor
                is used as-is (a 3-D tensor is treated as a single image);
                PIL images are preprocessed with ``self.transform``.
            flatten (bool): Whether to flatten the features to (batch, -1).

        Returns:
            torch.Tensor: Extracted features.

        Raises:
            ValueError: If an empty list of images is given.
            RuntimeError: If the hooked layer produced no output.
        """
        # Handle different input types
        if isinstance(images, torch.Tensor):
            tensor_images = images
            if tensor_images.dim() == 3:
                tensor_images = tensor_images.unsqueeze(0)
        elif isinstance(images, (list, tuple)):
            if len(images) == 0:
                raise ValueError("No images were provided for feature extraction.")
            tensor_images = torch.stack([
                img if isinstance(img, torch.Tensor) else self.transform(img)
                for img in images
            ])
        else:
            # Single PIL image
            tensor_images = self.transform(images).unsqueeze(0)

        tensor_images = tensor_images.to(self.device)

        # Clear the previous capture so a stale result can never be returned.
        self.features = None

        # Forward pass through the model; the hook stores the layer output
        with torch.no_grad():
            self.model(tensor_images)

        features = self.features
        if features is None:
            raise RuntimeError(
                f"Layer '{self.layer_name}' produced no output during the forward pass."
            )

        # Flatten if requested (reshape also handles non-contiguous outputs)
        if flatten:
            features = features.reshape(features.size(0), -1)

        return features

# Build the ResNet50 extractor that taps the 'avgpool' layer
feature_extractor = DeepFeatureExtractor(model_name='resnet50', layer_name='avgpool')

# Nothing to do when the data directory is missing or empty
if not os.path.exists(DATA_DIR) or len(os.listdir(DATA_DIR)) == 0:
    print("No images found in the data directory.")
else:
    # The loader returns raw PIL images; preprocessing happens in the extractor
    loader = SatelliteImageLoader(DATA_DIR, transform=None)

    # Draw up to four random images
    num_samples = min(4, len(loader.image_files))
    if num_samples > 0:
        indices = np.random.choice(len(loader.image_files), num_samples, replace=False)
        images = [loader.load_image(image_index) for image_index in indices]

        # Run the images through the network and report what came back
        features = feature_extractor.extract_features(images)

        print(f"Extracted features shape: {features.shape}")
        print(f"First few feature values for first image: {features[0, :5]}")

In [None]:
class FeatureMapVisualizer:
    """
    A class for visualizing CNN feature maps.

    Forward hooks on a fixed set of layers record their outputs during a
    forward pass; ``visualize`` plots a grid of channels for each layer.

    Attributes:
        model (torch.nn.Module): The model for feature extraction.
        device (torch.device): The device to use for computation.
        feature_maps (dict): Layer name -> output captured during the most
            recent forward pass.
        transform (callable): Preprocessing applied to PIL images.
    """

    def __init__(self, model_name='resnet50', device=None):
        """
        Initialize the feature map visualizer.

        Args:
            model_name (str): Name of the pre-trained model ('resnet50' or 'vgg16').
            device (torch.device, optional): Device to use for computation.
                Defaults to CUDA when available, otherwise CPU.

        Raises:
            ValueError: If the model name is not supported.
        """
        # Resolve the default at call time rather than binding a module-level
        # global when the class is defined.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Load the model
        if model_name == 'resnet50':
            self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        elif model_name == 'vgg16':
            self.model = vgg16(weights=VGG16_Weights.DEFAULT)
        else:
            raise ValueError(f"Model {model_name} not supported.")

        self.model = self.model.to(self.device)
        self.model.eval()

        # Store feature maps
        self.feature_maps = {}

        # Register hooks
        self._register_hooks()

        # Default transform
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _register_hooks(self):
        """
        Register forward hooks to capture feature maps from various layers.
        """
        def make_hook(name):
            def hook(module, inputs, output):
                # detach() shares storage with the live activation, so an
                # in-place layer that follows the hooked one would overwrite
                # the stored map. Clone to keep this layer's actual output.
                self.feature_maps[name] = output.detach().clone()
            return hook

        if hasattr(self.model, 'layer1'):
            # ResNet architecture: first conv of the first block in each stage
            for stage in ('layer1', 'layer2', 'layer3', 'layer4'):
                first_conv = getattr(self.model, stage)[0].conv1
                first_conv.register_forward_hook(make_hook(f'{stage}.0.conv1'))
        elif hasattr(self.model, 'features'):
            # VGG architecture: first conv, after first maxpool, middle, later
            for index in (0, 5, 10, 20):
                self.model.features[index].register_forward_hook(make_hook(f'features.{index}'))

    def get_feature_maps(self, image):
        """
        Get feature maps for an input image.

        Args:
            image (PIL.Image or torch.Tensor): Input image. A tensor is used
                as-is (a 3-D tensor is treated as a single image); a PIL
                image is preprocessed with ``self.transform``.

        Returns:
            dict: Feature maps from different layers for this image only.
        """
        if isinstance(image, torch.Tensor):
            tensor_image = image
            if tensor_image.dim() == 3:
                tensor_image = tensor_image.unsqueeze(0)
        else:
            tensor_image = self.transform(image).unsqueeze(0)

        tensor_image = tensor_image.to(self.device)

        # Start from a fresh dict so maps from an earlier call are neither
        # mixed into this result nor mutated in a dict already handed out.
        self.feature_maps = {}

        with torch.no_grad():
            self.model(tensor_image)

        return self.feature_maps

    def visualize(self, image, num_features=16, figsize=(15, 10)):
        """
        Visualize feature maps for an image.

        Args:
            image (PIL.Image): Input image.
            num_features (int): Number of feature maps to display per layer.
            figsize (tuple): Figure size.
        """
        # Get feature maps
        feature_maps = self.get_feature_maps(image)

        # Display the original image
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        plt.title("Original Image")
        plt.axis('off')
        plt.show()

        # Display feature maps from each layer
        for layer_name, feature_map in feature_maps.items():
            # First item of the batch, moved to CPU as a numpy array
            feature_map = feature_map[0].cpu().numpy()

            # Determine grid dimensions (smallest square grid that fits)
            n_features = min(num_features, feature_map.shape[0])
            grid_size = int(np.ceil(np.sqrt(n_features)))

            plt.figure(figsize=figsize)
            plt.suptitle(f"Feature Maps: {layer_name}")

            for i in range(n_features):
                plt.subplot(grid_size, grid_size, i + 1)

                # Min-max normalize each channel for display; constant
                # channels are left untouched to avoid dividing by zero.
                feat = feature_map[i]
                low, high = feat.min(), feat.max()
                if high > low:
                    feat = (feat - low) / (high - low)

                plt.imshow(feat, cmap='viridis')
                plt.axis('off')

            plt.tight_layout()
            plt.show()

# Instantiate the visualizer with its default model
visualizer = FeatureMapVisualizer()

# Handle the missing/empty directory case first
if not os.path.exists(DATA_DIR) or len(os.listdir(DATA_DIR)) == 0:
    print("Data directory not found.")
else:
    loader = SatelliteImageLoader(DATA_DIR)
    if not loader.image_files:
        print("No images found in the data directory.")
    else:
        # Use the first image in the loader's list as the example
        sample_image = loader.load_image(0)

        # Plot eight feature maps per hooked layer
        visualizer.visualize(sample_image, num_features=8)

# Dimensionality Reduction and Feature Visualization

High-dimensional features extracted from CNNs (e.g., 2048D from ResNet50) are difficult to visualize directly. We can use dimensionality reduction techniques to visualize these features in lower dimensions.

## Principal Component Analysis (PCA)

PCA projects data along directions of maximum variance:

$$Z = X W$$

Where:
- $X$ is the mean-centered data matrix
- $W$ is the matrix whose columns are the leading eigenvectors of the covariance matrix $\frac{1}{n-1}X^TX$
- $Z$ is the projected data

## t-Distributed Stochastic Neighbor Embedding (t-SNE)

t-SNE preserves local neighborhood structure by modeling similarities as conditional probabilities:

$$p_{j|i} = \frac{\exp(-\|x_i-x_j\|^2/2\sigma_i^2)}{\sum_{k \neq i}\exp(-\|x_i-x_k\|^2/2\sigma_i^2)}$$

These visualizations help identify clusters, outliers, and patterns in the feature space.

In [None]:
def reduce_dimensions(features, method='pca', n_components=2, perplexity=30.0):
    """
    Reduce dimensionality of features.

    Args:
        features (torch.Tensor or np.ndarray): Features to reduce, one row per sample.
        method (str): Reduction method ('pca' or 'tsne').
        n_components (int): Number of components to keep.
        perplexity (float): Requested t-SNE perplexity (ignored for PCA). It is
            lowered to ``n_samples - 1`` when the data set is too small,
            because t-SNE requires the perplexity to be below the sample count.

    Returns:
        np.ndarray: Reduced features.

    Raises:
        ValueError: If the method is not supported, or if t-SNE is requested
            with fewer than 2 samples.
    """
    # Validate the method up front so a typo fails before any work is done
    method_key = method.lower()
    if method_key not in ('pca', 'tsne'):
        raise ValueError(f"Method {method} not supported. Use 'pca' or 'tsne'.")

    # Convert to numpy if needed (detach in case the tensor tracks gradients)
    if isinstance(features, torch.Tensor):
        features = features.detach().cpu().numpy()
    features = np.asarray(features)
    n_samples = features.shape[0]

    # Apply dimensionality reduction
    if method_key == 'pca':
        reducer = PCA(n_components=n_components)
    else:
        if n_samples < 2:
            raise ValueError("t-SNE needs at least 2 samples.")
        # The default perplexity of 30 is rejected for small sample sets, so
        # clamp it to the largest value the data allows.
        effective_perplexity = min(float(perplexity), float(n_samples - 1))
        reducer = TSNE(
            n_components=n_components,
            perplexity=effective_perplexity,
            random_state=42,
        )

    return reducer.fit_transform(features)

def visualize_features_2d(features, labels=None, method='pca', figsize=(10, 8)):
    """
    Plot a 2D scatter of the features after dimensionality reduction.

    Args:
        features (torch.Tensor or np.ndarray): Features to visualize.
        labels (np.ndarray, optional): Labels for coloring points; a colorbar
            is added when they are given.
        method (str): Reduction method ('pca' or 'tsne').
        figsize (tuple): Figure size.
    """
    # Project onto two components and split into plot coordinates
    projected = reduce_dimensions(features, method=method, n_components=2)
    xs = projected[:, 0]
    ys = projected[:, 1]

    plt.figure(figsize=figsize)

    if labels is None:
        # Unlabelled data: single-colour scatter
        plt.scatter(xs, ys, alpha=0.8)
    else:
        # Labelled data: colour by label and add a colorbar
        points = plt.scatter(xs, ys, c=labels, cmap='viridis', alpha=0.8)
        plt.colorbar(points, label="Label")

    plt.title(f"2D Feature Visualization using {method.upper()}")
    plt.xlabel(f"{method.upper()} Component 1")
    plt.ylabel(f"{method.upper()} Component 2")
    plt.tight_layout()
    plt.show()

def visualize_features_3d(features, labels=None, method='pca', figsize=(12, 10)):
    """
    Plot a 3D scatter of the features after dimensionality reduction.

    Args:
        features (torch.Tensor or np.ndarray): Features to visualize.
        labels (np.ndarray, optional): Labels for coloring points; a colorbar
            is added when they are given.
        method (str): Reduction method ('pca' or 'tsne').
        figsize (tuple): Figure size.
    """
    # Project onto three components and split into plot coordinates
    embedded = reduce_dimensions(features, method=method, n_components=3)
    x_vals = embedded[:, 0]
    y_vals = embedded[:, 1]
    z_vals = embedded[:, 2]

    # Single 3D axes filling the figure
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    if labels is None:
        ax.scatter(x_vals, y_vals, z_vals, alpha=0.8)
    else:
        points = ax.scatter(x_vals, y_vals, z_vals,
                            c=labels, cmap='viridis', alpha=0.8)
        fig.colorbar(points, ax=ax, label="Label")

    ax.set_title(f"3D Feature Visualization using {method.upper()}")
    ax.set_xlabel(f"{method.upper()} Component 1")
    ax.set_ylabel(f"{method.upper()} Component 2")
    ax.set_zlabel(f"{method.upper()} Component 3")
    plt.tight_layout()
    plt.show()

# Check if data directory exists
if os.path.exists(DATA_DIR) and len(os.listdir(DATA_DIR)) > 0:
    # Create loader and extractor
    loader = SatelliteImageLoader(DATA_DIR)
    extractor = DeepFeatureExtractor(model_name='resnet50', layer_name='avgpool')

    # Load multiple images (use all available images, up to a reasonable number)
    num_samples = min(20, len(loader.image_files))  # Limit to 20 images for visualization

    if num_samples > 0:
        indices = np.random.choice(len(loader.image_files), num_samples, replace=False)
        images = [loader.load_image(idx) for idx in indices]

        # Extract features
        features = extractor.extract_features(images)
        print(f"Extracted features from {num_samples} images, shape: {features.shape}")

        # Generate some dummy labels for visualization purposes
        # In a real scenario, these would be your class labels or clusters
        dummy_labels = np.random.randint(0, 3, size=num_samples)

        # Visualize features using PCA.
        # PCA cannot return more components than there are samples, so a 2D
        # projection needs at least two images; skip it instead of crashing.
        if num_samples >= 2:
            print("Visualizing features using PCA...")
            visualize_features_2d(features, labels=dummy_labels, method='pca')
        else:
            print("Skipping PCA: a 2D projection needs at least 2 images.")

        # Visualize features using t-SNE
        if num_samples >= 5:  # t-SNE works better with more samples
            print("Visualizing features using t-SNE...")
            visualize_features_2d(features, labels=dummy_labels, method='tsne')
    else:
        print("No images found in the data directory.")
else:
    print("Data directory not found or empty.")

# Store Features for Downstream Tasks

After extracting features, we need to save them for later use in downstream tasks like classification, regression, or clustering.

## Feature Storage Considerations

1. **Format**: Store as NumPy arrays or PyTorch tensors
2. **Metadata**: Include information about the source images and extraction parameters
3. **Compression**: Consider compression for large feature sets
4. **Indexing**: Enable efficient retrieval for specific images

## Conclusion and Next Steps

In this notebook, we have:

1. **Loaded** satellite imagery data
2. **Extracted** deep features using pre-trained CNNs
3. **Visualized** feature maps to understand what the network detects
4. **Reduced dimensions** to visualize the feature space
5. **Saved features** for downstream applications

These features can now be used for:
- Land cover classification
- Object detection
- Change detection
- Anomaly detection
- And many more remote sensing applications

In [None]:
def save_features(features, file_paths, output_path, model_name='ResNet50', layer_name='avgpool'):
    """
    Save extracted features to disk.

    The features, the source file names and some extraction metadata are
    bundled into one dictionary and written with ``np.save``. NumPy pickles
    dictionaries, so the file must be read back with ``allow_pickle=True``.

    Args:
        features (torch.Tensor or np.ndarray): Extracted features, one row per image.
        file_paths (list): List of image file paths (only the base names are stored).
        output_path (str): Path to save the features. ``np.save`` appends
            '.npy' when the path does not already end with it.
        model_name (str): Model name recorded in the metadata.
        layer_name (str): Layer name recorded in the metadata.

    Raises:
        ValueError: If the features have fewer than 2 dimensions or the
            number of file paths does not match the number of feature rows.
    """
    # Convert to numpy if needed (detach in case the tensor tracks gradients)
    if isinstance(features, torch.Tensor):
        features = features.detach().cpu().numpy()
    features = np.asarray(features)

    if features.ndim < 2:
        raise ValueError(
            f"features must have shape (n_images, feature_dim, ...), got {features.shape}"
        )

    file_names = [os.path.basename(fp) for fp in file_paths]
    if len(file_names) != features.shape[0]:
        # A mismatch would silently pair features with the wrong images.
        raise ValueError(
            f"Got {len(file_names)} file paths for {features.shape[0]} feature rows"
        )

    # Create a dictionary with file paths and features
    feature_dict = {
        'file_paths': file_names,
        'features': features,
        'extraction_info': {
            'model': model_name,
            'layer': layer_name,
            'date_extracted': pd.Timestamp.now().isoformat(),
            'feature_dim': features.shape[1]
        }
    }

    # Create the directory if needed. dirname is '' for a bare file name,
    # and os.makedirs('') raises, so only create it when there is one.
    output_path = os.fspath(output_path)
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save using numpy
    np.save(output_path, feature_dict)

    # Report the name np.save actually wrote to
    saved_path = output_path if output_path.endswith('.npy') else output_path + '.npy'
    print(f"Features saved to {saved_path}")
    
def load_features(input_path):
    """
    Read back a feature file stored as a pickled dictionary in ``.npy`` form.

    Args:
        input_path (str): Path to the saved features.

    Returns:
        tuple: (file_paths, features, extraction_info); ``extraction_info``
        is ``None`` when the file has no such entry.
    """
    # NOTE(review): allow_pickle=True unpickles arbitrary Python objects, so
    # only load feature files that come from a trusted source.
    stored = np.load(input_path, allow_pickle=True)

    # The dictionary is wrapped in a 0-d object array; .item() unwraps it
    payload = stored.item()

    paths = payload['file_paths']
    feats = payload['features']
    info = payload.get('extraction_info', None)
    return paths, feats, info

# Saving needs the data directory, the `features` / `indices` produced by an
# earlier cell, and a loader that found at least one image. At notebook top
# level locals() is the module namespace, so the membership tests check
# whether those earlier cells actually ran.
if not (os.path.exists(DATA_DIR) and 'features' in locals() and 'indices' in locals() and len(loader.image_files) > 0):
    print("Unable to save features: either no data directory found or no features extracted.")
else:
    # Make sure the output directory is there
    output_dir = "../output/features"
    os.makedirs(output_dir, exist_ok=True)

    # Pair the previously extracted features with their source files and save
    file_paths = [loader.image_files[image_index] for image_index in indices]
    output_path = os.path.join(output_dir, "resnet50_features.npy")
    save_features(features, file_paths, output_path)

    # Round-trip check: read the file back and report what it contains
    loaded_paths, loaded_features, extraction_info = load_features(output_path)
    print(f"Loaded {len(loaded_paths)} file paths and features with shape {loaded_features.shape}")
    print(f"Extraction info: {extraction_info}")

    # Summary of what has been accomplished
    print(
        "\n=== Feature Extraction Summary ===",
        f"• Images processed: {num_samples}",
        f"• Feature dimensions: {features.shape[1]}",
        f"• Model used: ResNet50",
        f"• Features saved to: {output_path}",
        sep="\n",
    )

    # Suggestions for using the features
    print(
        "\n=== Next Steps ===",
        "1. Use these features for classification tasks",
        "2. Apply clustering to find patterns in the data",
        "3. Train models for object detection or segmentation",
        "4. Develop change detection algorithms",
        sep="\n",
    )