# 🎨 Function 4: Visualize Raster Data

## Building the `visualize_raster_data` Function

**Learning Objectives:**
- Master comprehensive raster visualization techniques
- Understand color mapping, stretching, and enhancement methods
- Learn to create publication-quality raster maps and plots
- Implement multi-band and composite visualizations
- Handle different data types and value ranges effectively
- Create both static and interactive visualization workflows

**Professional Context:**
Effective raster visualization is crucial for geospatial communication and analysis. Professionals use these techniques for:
- Creating compelling maps for stakeholders and publications
- Exploring data patterns and identifying anomalies
- Quality assessment and data validation
- Scientific communication and reporting
- Decision support and public engagement

## 🎯 Function Overview

**Function Signature:**
```python
def visualize_raster_data(raster_path, bands=None, visualization_type='single', 
                         color_map='viridis', stretch_type='linear', 
                         percentile_range=(2, 98), title=None, 
                         save_path=None, show_statistics=True, 
                         figure_size=(12, 8)):
    """
    Create comprehensive visualizations of raster data with multiple display options.
    
    Parameters:
    -----------
    raster_path : str
        Path to the input raster file
    bands : int, list, or None
        Band number(s) to visualize. If None, uses first band
    visualization_type : str
        Type of visualization ('single', 'multi_band', 'rgb', 'comparison')
    color_map : str
        Matplotlib colormap name for single-band visualization
    stretch_type : str
        Data stretching method ('linear', 'percentile', 'histogram_eq')
    percentile_range : tuple
        Percentile range for stretching (min_percentile, max_percentile)
    title : str, optional
        Custom title for the visualization
    save_path : str, optional
        Path to save the visualization
    show_statistics : bool
        Whether to display data statistics on the plot
    figure_size : tuple
        Figure size in inches (width, height)
    
    Returns:
    --------
    dict
        Dictionary containing visualization results and metadata
    """
```
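
The `bands` parameter accepts three forms: `None`, a single integer, or a list. The sketch below shows one way those forms might be normalized into a validated list of 1-based band indices before any data is read. `normalize_bands` is a hypothetical helper name used only for illustration, and `band_count` stands in for rasterio's `src.count`.

```python
def normalize_bands(bands, band_count):
    """Return a validated list of 1-based band indices.

    Hypothetical helper for illustration; band_count plays the role of
    rasterio's `src.count`.
    """
    if bands is None:
        selected = [1]            # default to the first band
    elif isinstance(bands, int):
        selected = [bands]        # wrap a single index in a list
    else:
        selected = list(bands)    # accept any iterable of indices

    for band in selected:
        if band < 1 or band > band_count:
            raise ValueError(f"Band {band} not found. Raster has {band_count} bands.")
    return selected


print(normalize_bands(None, 4))       # [1]
print(normalize_bands(3, 4))          # [3]
print(normalize_bands([4, 3, 2], 4))  # [4, 3, 2]
```

Normalizing early means the rest of the function can always treat `bands` as a list.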

**Key Capabilities:**
- 🎨 Multiple visualization types (single, multi-band, RGB, comparison)
- 🌈 Comprehensive color mapping and data stretching options
- 📊 Integrated statistics display and data distribution analysis
- 📐 Automatic scaling and enhancement techniques
- 💾 High-quality output generation for publications and reports
- 🔧 Customizable layouts, styling, and display options

## 🚀 Hands-On Example: Building the Function

Let's build the complete raster visualization function step by step:

In [None]:
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from skimage import exposure
import os
import warnings
warnings.filterwarnings('ignore')

def visualize_raster_data(raster_path, bands=None, visualization_type='single', 
                         color_map='viridis', stretch_type='linear', 
                         percentile_range=(2, 98), title=None, 
                         save_path=None, show_statistics=True, 
                         figure_size=(12, 8)):
    """
    Create comprehensive visualizations of raster data with multiple display options.
    """
    
    print(f"🎨 Starting raster visualization...")
    print(f"   📁 Input raster: {os.path.basename(raster_path)}")
    print(f"   🖼️ Visualization type: {visualization_type}")
    
    # Step 1: Input validation and data loading
    if not os.path.exists(raster_path):
        raise FileNotFoundError(f"Raster file not found: {raster_path}")
    
    with rasterio.open(raster_path) as src:
        # Handle band selection
        if bands is None:
            bands = [1]  # Default to first band
        elif isinstance(bands, int):
            bands = [bands]
        
        # Validate band numbers
        for band in bands:
            if band < 1 or band > src.count:
                raise ValueError(f"Band {band} not found. Raster has {src.count} bands.")
        
        # Read data
        if len(bands) == 1:
            data = src.read(bands[0])
        else:
            data = src.read(bands)
        
        # Get metadata
        profile = src.profile.copy()
        bounds = src.bounds
        crs = src.crs
        nodata = src.nodata
        
        print(f"   📊 Data shape: {data.shape}")
        print(f"   📏 Bounds: {bounds}")
        print(f"   🌍 CRS: {crs}")
    
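    # Step 2: Data preprocessing and masking
    # Always mask NaN/inf values; additionally mask the nodata value when it is
    # a finite number (a NaN nodata value never compares equal to anything)
    data_masked = np.ma.masked_invalid(data)
    if nodata is not None and not np.isnan(nodata):
        data_masked = np.ma.masked_where(data == nodata, data_masked)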
    
    # Calculate statistics for single band
    if len(bands) == 1:
        valid_data = data_masked.compressed()
        if len(valid_data) > 0:
            data_stats = {
                'min': float(np.min(valid_data)),
                'max': float(np.max(valid_data)),
                'mean': float(np.mean(valid_data)),
                'std': float(np.std(valid_data)),
                'median': float(np.median(valid_data)),
                'count': len(valid_data),
                'total_pixels': int(data.size)
            }
            print(f"   📈 Data range: {data_stats['min']:.2f} - {data_stats['max']:.2f}")
            print(f"   📈 Mean: {data_stats['mean']:.2f} ± {data_stats['std']:.2f}")
        else:
            raise ValueError("No valid data found")
    
    # Step 3: Apply data stretching
    def apply_stretch(input_data, method):
        """Return (stretched_data, vmin, vmax) for the requested stretch method."""
        if method == 'percentile':
            # Clip to the requested percentiles to reduce the influence of outliers
            vmin, vmax = np.percentile(valid_data, percentile_range)
            return np.clip(input_data, vmin, vmax), vmin, vmax
        elif method == 'histogram_eq':
            # Cast to float first so masked pixels can be filled with NaN
            # (integer arrays cannot hold NaN)
            filled_data = np.ma.filled(np.ma.asarray(input_data).astype(np.float64), np.nan)
            stretched = exposure.equalize_hist(filled_data, mask=~np.isnan(filled_data))
            return np.ma.masked_invalid(stretched), np.nanmin(stretched), np.nanmax(stretched)
        else:
            # 'linear' (and any unrecognized method) uses the full data range
            return input_data, data_stats['min'], data_stats['max']
    
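    # Step 4: Create visualizations based on type
    valid_types = ('single', 'multi_band', 'rgb', 'comparison')
    if visualization_type not in valid_types:
        raise ValueError(f"Unknown visualization_type '{visualization_type}'. Choose from {valid_types}.")
    if visualization_type in ('single', 'comparison') and len(bands) != 1:
        # Statistics and stretching are only computed for a single band
        raise ValueError(f"'{visualization_type}' visualization requires exactly one band, got {len(bands)}.")
    
    extent = [bounds.left, bounds.right, bounds.bottom, bounds.top]
    
    if visualization_type == 'single':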
        # Single band visualization
        fig, axes = plt.subplots(1, 2 if show_statistics else 1, figsize=figure_size)
        if show_statistics:
            ax_main, ax_hist = axes
        else:
            # plt.subplots(1, 1) returns a single Axes object, not an array
            ax_main, ax_hist = axes, None
        
        # Apply stretch and display
        stretched_data, vmin, vmax = apply_stretch(data_masked, stretch_type)
        
        im = ax_main.imshow(stretched_data, cmap=color_map, vmin=vmin, vmax=vmax,
                           extent=extent, aspect='auto')
        
        # Add colorbar
        cbar = plt.colorbar(im, ax=ax_main, shrink=0.8)
        cbar.set_label('Value', rotation=270, labelpad=20)
        
        # Set title and labels
        plot_title = title or f"Raster Visualization - Band {bands[0]} ({stretch_type.title()} Stretch)"
        ax_main.set_title(plot_title, fontsize=14, fontweight='bold')
        ax_main.set_xlabel('Longitude' if crs and crs.is_geographic else 'X')
        ax_main.set_ylabel('Latitude' if crs and crs.is_geographic else 'Y')
        
        # Add statistics text
        if show_statistics:
            stats_text = f"Min: {data_stats['min']:.2f}\nMax: {data_stats['max']:.2f}\nMean: {data_stats['mean']:.2f}\nStd: {data_stats['std']:.2f}"
            ax_main.text(0.02, 0.98, stats_text, transform=ax_main.transAxes,
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                        verticalalignment='top', fontsize=10)
        
        # Create histogram if space available
        if ax_hist is not None:
            ax_hist.hist(valid_data, bins=50, density=True, alpha=0.7, color='skyblue')
            ax_hist.axvline(data_stats['mean'], color='red', linestyle='--', label=f"Mean: {data_stats['mean']:.2f}")
            ax_hist.axvline(data_stats['median'], color='orange', linestyle='--', label=f"Median: {data_stats['median']:.2f}")
            ax_hist.set_xlabel('Value')
            ax_hist.set_ylabel('Density')
            ax_hist.set_title('Data Distribution')
            ax_hist.legend()
            ax_hist.grid(True, alpha=0.3)
    
    elif visualization_type == 'multi_band':
        # Multi-band visualization
        n_bands = len(bands)
        n_cols = min(3, n_bands)
        n_rows = (n_bands + n_cols - 1) // n_cols
        fig, axes = plt.subplots(n_rows, n_cols, figsize=figure_size)
        
        if n_bands == 1:
            axes = [axes]
        elif hasattr(axes, 'flatten'):
            axes = axes.flatten()
        else:
            axes = [axes]
        
        for i, band_num in enumerate(bands):
            if i >= len(axes):
                break
            
            if data.ndim == 3:
                band_data = data[i]
            else:
                band_data = data
            
            # Mask nodata
            if nodata is not None:
                band_masked = np.ma.masked_equal(band_data, nodata)
            else:
                band_masked = np.ma.masked_invalid(band_data)
            
            # Apply percentile stretch for multi-band
            band_valid = band_masked.compressed()
            if len(band_valid) > 0:
                vmin, vmax = np.percentile(band_valid, [2, 98])
            else:
                vmin, vmax = 0, 1
            
            im = axes[i].imshow(band_masked, cmap=color_map, vmin=vmin, vmax=vmax,
                               extent=extent, aspect='auto')
            axes[i].set_title(f'Band {band_num}')
            axes[i].set_xlabel('X')
            axes[i].set_ylabel('Y')
            plt.colorbar(im, ax=axes[i], shrink=0.8)
        
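        # Hide unused subplots
        for i in range(len(bands), len(axes)):
            axes[i].set_visible(False)
        
        # Apply the custom title to the whole figure, if one was given
        if title:
            fig.suptitle(title, fontsize=14, fontweight='bold')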
    
    elif visualization_type == 'rgb':
        # RGB composite visualization
        if len(bands) < 3:
            raise ValueError("RGB visualization requires at least 3 bands")
        
        fig, ax = plt.subplots(1, 1, figsize=figure_size)
        
        # Create RGB array
        if data.ndim == 3:
            rgb_data = np.dstack([data[i] for i in range(min(3, len(bands)))])
        else:
            raise ValueError("RGB requires multi-band data")
        
        # Normalize each band
        rgb_normalized = np.zeros_like(rgb_data, dtype=np.float32)
        for i in range(3):
            band = rgb_data[:, :, i]
            if nodata is not None:
                band_valid = band[band != nodata]
            else:
                band_valid = band[~np.isnan(band)]
            
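            if len(band_valid) > 0:
                vmin, vmax = np.percentile(band_valid, [2, 98])
                if vmax > vmin:  # Skip constant bands to avoid division by zero
                    band_float = band.astype(np.float32)
                    rgb_normalized[:, :, i] = np.clip((band_float - vmin) / (vmax - vmin), 0, 1)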
        
        # Handle nodata for display
        if nodata is not None:
            mask = (rgb_data[:, :, 0] == nodata) | (rgb_data[:, :, 1] == nodata) | (rgb_data[:, :, 2] == nodata)
            rgb_normalized[mask] = 0
        
        ax.imshow(rgb_normalized, extent=extent, aspect='auto')
        ax.set_title(title or f'RGB Composite (Bands {bands[0]}-{bands[1]}-{bands[2]})')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
    
    elif visualization_type == 'comparison':
        # Comparison of different stretch methods
        fig, axes = plt.subplots(2, 2, figsize=figure_size)
        axes = axes.flatten()
        
        methods = ['linear', 'percentile', 'histogram_eq']
        
        for i, method in enumerate(methods):
            if i >= len(axes) - 1:  # Leave last subplot for histogram
                break
                
            stretched, vmin, vmax = apply_stretch(data_masked, method)
            
            im = axes[i].imshow(stretched, cmap=color_map, vmin=vmin, vmax=vmax,
                               extent=extent, aspect='auto')
            axes[i].set_title(f'{method.replace("_", " ").title()} Stretch')
            plt.colorbar(im, ax=axes[i], shrink=0.8)
        
        # Add histogram in last subplot
        if len(bands) == 1:
            axes[-1].hist(valid_data, bins=50, density=True, alpha=0.7, color='skyblue')
            axes[-1].axvline(data_stats['mean'], color='red', linestyle='--')
            axes[-1].set_title('Data Distribution')
            axes[-1].set_xlabel('Value')
            axes[-1].set_ylabel('Density')
        
        # Apply the custom title to the whole figure, if one was given
        if title:
            fig.suptitle(title, fontsize=14, fontweight='bold')
    
    # Step 5: Final formatting and output
    plt.tight_layout()
    
    # Save if requested
    if save_path:
        print(f"   💾 Saving to: {save_path}")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()
    
    # Compile results
    results = {
        'visualization_successful': True,
        'data_statistics': data_stats if len(bands) == 1 else {'multi_band': True},
        'visualization_parameters': {
            'bands': bands,
            'visualization_type': visualization_type,
            'color_map': color_map,
            'stretch_type': stretch_type
        },
        'raster_info': {
            'shape': data.shape,
            'bounds': bounds,
            'crs': str(crs),
            'nodata': nodata
        }
    }
    
    if save_path:
        results['output_path'] = save_path
    
    print(f"\n🎉 Visualization complete!")
    return results

## 🧪 Test the Function

Let's test our visualization function with sample data:

In [None]:
# Create sample test raster data
print("🏗️ Creating sample test data...\n")

# Create synthetic elevation data
height, width = 200, 300
x = np.linspace(-120, -119, width)
y = np.linspace(38, 37, height)  # North to south: row 0 is the top edge in from_bounds
X, Y = np.meshgrid(x, y)

# Generate terrain-like elevation data
elevation = (
    1000 + 
    500 * np.sin(X * 10) * np.cos(Y * 8) +
    300 * np.exp(-((X + 119.5)**2 + (Y - 37.5)**2) * 100) +
    np.random.normal(0, 50, (height, width))
)

# Clip to reasonable elevation range
elevation = np.clip(elevation, 0, 3000)

# Add some nodata areas
mask = np.random.random((height, width)) < 0.05
elevation[mask] = -9999

# Define geospatial properties
transform = rasterio.transform.from_bounds(-120, 37, -119, 38, width, height)

# Save test raster
test_raster_path = '/tmp/test_elevation.tif'
profile = {
    'driver': 'GTiff',
    'height': height,
    'width': width,
    'count': 1,
    'dtype': 'float32',
    'crs': 'EPSG:4326',
    'transform': transform,
    'nodata': -9999
}

with rasterio.open(test_raster_path, 'w', **profile) as dst:
    dst.write(elevation.astype(np.float32), 1)

print(f"✅ Created test raster: {test_raster_path}")
print(f"   📊 Shape: {height} x {width} pixels")
print(f"   🏔️ Elevation range: {elevation[elevation != -9999].min():.1f} - {elevation[elevation != -9999].max():.1f} m")
print(f"   ☁️ NoData pixels: {mask.sum()}")

In [None]:
# Test 1: Single band visualization with statistics
print("🎨 Test 1: Single Band with Statistics\n")

results = visualize_raster_data(
    raster_path=test_raster_path,
    bands=1,
    visualization_type='single',
    color_map='terrain',
    stretch_type='percentile',
    percentile_range=(5, 95),
    title='Elevation Data with Statistics',
    show_statistics=True,
    figure_size=(14, 6)
)

print(f"\n📊 Results:")
print(f"   Data range: {results['data_statistics']['min']:.1f} - {results['data_statistics']['max']:.1f} m")
print(f"   Mean elevation: {results['data_statistics']['mean']:.1f} m")
print(f"   Valid pixels: {results['data_statistics']['count']:,}")

In [None]:
# Test 2: Comparison of stretch methods
print("\n🔄 Test 2: Stretch Methods Comparison\n")

comparison_results = visualize_raster_data(
    raster_path=test_raster_path,
    visualization_type='comparison',
    color_map='viridis',
    title='Elevation Data - Stretch Comparison',
    figure_size=(14, 10)
)

print("Comparison visualization created showing different enhancement methods")

## 💡 Understanding Raster Visualization

### Key Concepts:

**Data Stretching Methods:**
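- **Linear**: Maps the actual min/max values to the color range - preserves data relationships but is sensitive to outliers
- **Percentile**: Clips to a specified percentile range (e.g., 2-98) - reduces outlier effects
- **Histogram Equalization**: Remaps values so the histogram is roughly uniform - maximizes contrast but distorts relative differences between values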
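
To make the difference between the three stretches concrete, here is a minimal NumPy-only sketch applied to a small array with one extreme outlier. The function names are illustrative, and the rank-based equalization is a simplified stand-in for `skimage.exposure.equalize_hist`, not a reimplementation of it.

```python
import numpy as np

def linear_stretch(values):
    """Scale to [0, 1] using the actual minimum and maximum."""
    vmin, vmax = values.min(), values.max()
    return (values - vmin) / (vmax - vmin)

def percentile_stretch(values, low=2, high=98):
    """Clip to the given percentiles, then scale to [0, 1]."""
    vmin, vmax = np.percentile(values, [low, high])
    return (np.clip(values, vmin, vmax) - vmin) / (vmax - vmin)

def histogram_equalize(values):
    """Map each value to its empirical CDF position (rank-based)."""
    flat = values.ravel()
    ranks = np.searchsorted(np.sort(flat), flat, side='right')
    return (ranks / flat.size).reshape(values.shape)

# 99 well-behaved values plus one extreme outlier
data = np.append(np.linspace(100.0, 200.0, 99), 10_000.0)

lin = linear_stretch(data)
pct = percentile_stretch(data)
heq = histogram_equalize(data)

# Fraction of the display range used by the 99 non-outlier values
print(f"linear:     {lin[:99].max() - lin[:99].min():.3f}")
print(f"percentile: {pct[:99].max() - pct[:99].min():.3f}")
print(f"hist. eq.:  {heq[:99].max() - heq[:99].min():.3f}")
```

With the linear stretch, the single outlier squeezes the 99 typical values into roughly 1% of the display range. The percentile and equalized versions spread them across nearly all of it.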

**Color Maps:**
- **Sequential**: viridis, plasma, inferno - for continuous data
- **Diverging**: coolwarm, RdBu - for data with meaningful center
- **Specialized**: terrain, ocean - for specific data types
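
A diverging colormap only reads correctly when its midpoint lines up with the meaningful center of the data. One simple way to arrange this, sketched below under the assumption that the center is known in advance, is to choose color limits symmetric around that center. `symmetric_limits` is a hypothetical helper name and the anomaly values are made up for illustration.

```python
import numpy as np

def symmetric_limits(values, center=0.0):
    """Return (vmin, vmax) equally spaced around `center`, ignoring NaNs."""
    deviations = np.abs(np.asarray(values, dtype=float) - center)
    half_range = np.nanmax(deviations)
    return center - half_range, center + half_range

# Made-up temperature anomalies in degrees C
anomalies = np.array([[-1.5, -0.2, 0.0],
                      [0.4, 2.0, 3.5]])

vmin, vmax = symmetric_limits(anomalies)
print(vmin, vmax)  # -3.5 3.5
```

The resulting limits can be passed as `vmin` and `vmax` to `imshow` together with a diverging colormap such as `RdBu`, so that zero falls on the neutral middle color.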

**Visualization Types:**
- **Single**: Individual band with optional statistics
- **Multi-band**: Side-by-side comparison of bands
- **RGB**: Color composite from three bands
- **Comparison**: Different enhancement methods compared
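
The RGB type is the least obvious of the four, because each band has to be rescaled to the 0-1 range independently before the bands are stacked into a color image. The following NumPy-only sketch illustrates that idea with a 2-98 percentile stretch per band. `make_rgb_composite` is a hypothetical name, the input is random data rather than real imagery, and painting invalid pixels black is just one possible convention.

```python
import numpy as np

def make_rgb_composite(band_stack, nodata=None, percentiles=(2, 98)):
    """Build a (rows, cols, 3) float array in [0, 1] from a (3, rows, cols) stack."""
    stack = np.asarray(band_stack, dtype=np.float32)
    rgb = np.zeros(stack.shape[1:] + (3,), dtype=np.float32)

    # A pixel is invalid if any of its three bands is NaN or equals nodata
    invalid = np.isnan(stack).any(axis=0)
    if nodata is not None:
        invalid |= (stack == nodata).any(axis=0)

    for i in range(3):
        band = stack[i]
        valid = band[~invalid]
        if valid.size == 0:
            continue
        vmin, vmax = np.percentile(valid, percentiles)
        if vmax > vmin:  # constant bands stay at zero
            rgb[:, :, i] = np.clip((band - vmin) / (vmax - vmin), 0, 1)

    rgb[invalid] = 0  # show invalid pixels as black
    return rgb

rng = np.random.default_rng(0)
stack = rng.uniform(0, 3000, size=(3, 50, 60)).astype(np.float32)
stack[:, 0, 0] = -9999  # one nodata pixel

rgb = make_rgb_composite(stack, nodata=-9999)
print(rgb.shape, float(rgb.min()), float(rgb.max()))
```

The returned array has the `(rows, cols, 3)` layout that `imshow` expects for color images.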

### Best Practices:
- Choose appropriate color maps for your data type
- Use percentile stretching to handle outliers
- Include statistics for quantitative analysis
- Save high-resolution outputs for publications

## 🎯 Your Task: Implement and Test

**Requirements:**
1. **Implement the function** exactly as shown above
2. **Support multiple visualization types** for different analysis needs
3. **Apply appropriate data stretching** for optimal visual enhancement
4. **Handle nodata values** and edge cases gracefully
5. **Generate publication-quality outputs** with proper formatting
6. **Provide comprehensive statistics** and metadata
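
Requirement 4 matters because a single unmasked nodata value can dominate every statistic. The small example below, using made-up values and a `-9999` sentinel like the test raster in this lesson, shows the effect of masking with `numpy.ma` before computing a mean.

```python
import numpy as np

nodata = -9999
band = np.array([[120.0, 135.0, nodata],
                 [np.nan, 150.0, 165.0]])

# Mask NaN/inf first, then the nodata sentinel
masked = np.ma.masked_invalid(band)
masked = np.ma.masked_where(band == nodata, masked)

valid = masked.compressed()  # 1-D array of unmasked values only
print(f"valid pixels: {valid.size} of {band.size}")
print(f"mean with nodata included: {np.nanmean(band):.1f}")
print(f"mean of valid pixels:      {valid.mean():.1f}")
```

Only four of the six pixels are valid here, and leaving the sentinel in place pulls the mean far below any real value in the array.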

**Testing Strategy:**
```python
# Test different scenarios:
# 1. Single band elevation data
# 2. Multi-band satellite imagery
# 3. RGB composite creation
# 4. Data with extreme outliers
# 5. Various color map applications
```

## 🔧 Testing Your Implementation

Run the official tests to verify your function works correctly:

```bash
cd /workspaces/your-repo
python -m pytest tests/test_rasterio_functions.py::test_visualize_raster_data -v
```
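
The contents of the official test are not shown here. As a rough idea of what such a check might cover, the sketch below validates the structure of the dictionary returned by `visualize_raster_data`. `check_results` is a hypothetical helper and the example dictionary is hand-written; only the key names are taken from the function above.

```python
REQUIRED_KEYS = {
    'visualization_successful',
    'data_statistics',
    'visualization_parameters',
    'raster_info',
}

def check_results(results):
    """Raise AssertionError if the results dictionary looks malformed."""
    missing = REQUIRED_KEYS - results.keys()
    assert not missing, f"Missing keys: {sorted(missing)}"
    assert results['visualization_successful'] is True

    stats = results['data_statistics']
    if 'multi_band' not in stats:  # single-band runs carry full statistics
        assert stats['min'] <= stats['median'] <= stats['max']
        assert stats['min'] <= stats['mean'] <= stats['max']
        assert 0 < stats['count'] <= stats['total_pixels']
    return True

# Hand-written dictionary shaped like the function's return value
example = {
    'visualization_successful': True,
    'data_statistics': {
        'min': 0.0, 'max': 1500.0, 'mean': 980.0, 'std': 350.0,
        'median': 1000.0, 'count': 57000, 'total_pixels': 60000,
    },
    'visualization_parameters': {
        'bands': [1], 'visualization_type': 'single',
        'color_map': 'terrain', 'stretch_type': 'percentile',
    },
    'raster_info': {
        'shape': (200, 300), 'bounds': None,
        'crs': 'EPSG:4326', 'nodata': -9999,
    },
}

print(check_results(example))  # True
```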

### Additional Testing Ideas:
```python
# Test with real satellite data
results = visualize_raster_data(
    raster_path='landsat_scene.tif',
    bands=[4, 3, 2],  # NIR, Red, Green on Landsat 5/7; Landsat 8/9 would use [5, 4, 3]
    visualization_type='rgb',
    title='Landsat False Color Composite'
)

# Test with DEM data
results = visualize_raster_data(
    raster_path='elevation.tif',
    visualization_type='single',
    color_map='terrain',
    stretch_type='percentile',
    show_statistics=True
)
```

## 📚 Professional Applications

### Real-World Use Cases:

**Environmental Monitoring:**
- Creating compelling visualizations of satellite imagery for reports
- Analyzing vegetation health through false color composites
- Monitoring land use change with before/after comparisons

**Scientific Research:**
- Publication-quality figures for research papers
- Data exploration and pattern identification
- Quality assessment of remote sensing data

**Urban Planning:**
- Visualizing elevation models for development planning
- Creating maps for public engagement and stakeholder meetings
- Analyzing urban heat islands and temperature patterns

**Natural Resource Management:**
- Forest health assessment visualizations
- Water resource monitoring and reporting
- Agricultural productivity analysis

### Industry Standards:
- **Resolution**: Match visualization resolution to intended use
- **Color Schemes**: Use colorblind-friendly palettes when possible
- **Metadata**: Include proper legends, scale bars, and coordinate information
- **File Formats**: PNG for web, PDF for print, TIFF for archival
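
For the resolution guideline, it helps to know how many pixels a saved figure will contain. The arithmetic is simply inches times dots per inch, as in the sketch below. `output_pixels` is a hypothetical helper, and the numbers are nominal: saving with `bbox_inches='tight'`, as the function above does, crops the figure, so the actual file is usually somewhat different in size.

```python
def output_pixels(figure_size, dpi):
    """Return (width_px, height_px) for a figure size in inches at a given dpi."""
    width_in, height_in = figure_size
    return round(width_in * dpi), round(height_in * dpi)

print(output_pixels((12, 8), 300))  # (3600, 2400)
print(output_pixels((12, 8), 100))  # (1200, 800)
```

The default `figure_size=(12, 8)` at the 300 dpi used by `plt.savefig` in this lesson is sized for print. A lower dpi is usually enough for on-screen use.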

## 🚀 Next Steps

**Congratulations! You've completed all 4 basic rasterio functions:**

1. ✅ **Load and Explore Raster** - Data loading and basic exploration
2. ✅ **Calculate Raster Statistics** - Statistical analysis and data summaries
3. ✅ **Extract Raster Subset** - Spatial subsetting and area-of-interest extraction
4. ✅ **Visualize Raster Data** - Comprehensive visualization and display techniques

### Core Skills Achieved:
- 🗂️ **Data Management**: Loading, exploring, and validating raster datasets
- 📊 **Statistical Analysis**: Computing comprehensive metrics and summaries
- ✂️ **Spatial Operations**: Extracting subsets and areas of interest
- 🎨 **Visualization**: Creating publication-quality maps and analyses

### Next Module Options:
Move on to advanced rasterio-analysis functions for specialized workflows:
- **Topographic Metrics**: Terrain analysis from elevation data
- **Vegetation Indices**: Multispectral analysis and vegetation monitoring
- **Spatial Sampling**: Point and zonal statistics extraction
- **Cloud Optimized GeoTIFF**: Modern raster data formats
- **STAC Integration**: Cloud-native geospatial data access

**You now have solid foundational skills in raster data processing! 🎉**

## 🎓 Real-World Applications

The raster visualization techniques you've mastered are used for:
- **Scientific Communication**: Creating figures for research publications
- **Data Exploration**: Identifying patterns and anomalies in large datasets
- **Quality Control**: Validating data processing and analysis results
- **Stakeholder Engagement**: Creating compelling maps for decision-makers
- **Education and Training**: Teaching spatial concepts through visual examples

**Excellent work on mastering raster visualization fundamentals! 🍀**