# Land Cover Classification with CNN

This notebook demonstrates land cover classification with a Convolutional Neural Network (CNN) built in PyTorch (`torch`). The CNN classifies image patches extracted from a multi-band raster around labeled points, a common setup for remote sensing tasks such as land cover mapping.

## Prerequisites
- Install the required libraries: `rasterio`, `geopandas`, `torch`, `numpy`, `matplotlib` (listed in `requirements.txt`), plus `scikit-learn`, which is used for the train/validation split.
- A multi-band GeoTIFF (e.g., `sample.tif`) and a point shapefile of labeled samples (e.g., `labels.shp`) with a `class` attribute. Replace the file paths with your own data.
- A GPU is recommended for faster training.

## Learning Objectives
- Extract labeled patches from a raster for CNN training.
- Train a CNN for land cover classification.
- Predict and visualize classification results.

In [None]:
# Import required libraries
import rasterio
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

## Step 1: Create Custom Dataset

Define a custom `Dataset` that extracts a fixed-size image patch centered on each labeled point and returns it together with that point's class label.

In [None]:
class RasterPatchDataset(Dataset):
    def __init__(self, image_path, shapefile_path, patch_size=64):
        self.image_path = image_path
        self.shapefile_path = shapefile_path
        self.patch_size = patch_size
        
        # Load shapefile
        gdf = gpd.read_file(shapefile_path)
        
        # Load raster
        with rasterio.open(image_path) as src:
            self.image = src.read().astype(np.float32)
            self.transform = src.transform
            self.crs = src.crs
            self.profile = src.profile
        
        # Reproject shapefile to match raster CRS
        if gdf.crs != self.crs:
            gdf = gdf.to_crs(self.crs)
        
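        # Extract point coordinates and labels (assumes point geometries with a 'class' column).
        # For point geometries the centroid is the point itself.
        self.points = gdf.geometry.centroid
        # Encode class values as consecutive integers 0..n_classes-1, as CrossEntropyLoss requires;
        # self.classes[i] holds the original class value for encoded label i
        self.classes, self.labels = np.unique(gdf['class'].values, return_inverse=True)
        
        # Normalize each band to [0, 1] by its maximum (guarding against all-zero bands)
        band_max = np.max(self.image, axis=(1, 2), keepdims=True)
        self.image = self.image / np.maximum(band_max, 1e-8)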
    
    def __len__(self):
        return len(self.points)
    
    def __getitem__(self, idx):
        # Get point coordinates
        point = self.points.iloc[idx]
        x, y = point.x, point.y
        
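        # Convert map coordinates to pixel indices. The inverse affine transform
        # returns (col, row), i.e. (x, y) order, so unpack in that order
        col, row = ~self.transform * (x, y)
        row, col = int(np.floor(row)), int(np.floor(col))
        
        # Extract patch centered on the point
        half_patch = self.patch_size // 2
        row_start, col_start = row - half_patch, col - half_patch
        
        # Points too close to the raster edge cannot yield a full patch; return an
        # all-zero patch for them (consider removing such points from the shapefile instead)
        if (row_start < 0 or col_start < 0
                or row_start + self.patch_size > self.image.shape[1]
                or col_start + self.patch_size > self.image.shape[2]):
            patch = np.zeros((self.image.shape[0], self.patch_size, self.patch_size), dtype=np.float32)
        else:
            patch = self.image[:, row_start:row_start + self.patch_size, col_start:col_start + self.patch_size]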
        
        # Get label
        label = self.labels[idx]
        
        return torch.from_numpy(patch), torch.tensor(label, dtype=torch.long)

## Step 2: Define CNN Model

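Create a simple CNN for patch classification: two convolution and max-pooling blocks followed by two fully connected layers. Each 2×2 pooling halves the spatial size, so a 64×64 patch reaches the first fully connected layer as 32 channels of 16×16 features.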

In [None]:
class SimpleCNN(nn.Module):
    def __init__(self, in_channels, n_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Two 2x2 poolings shrink a 64x64 patch to 16x16; update 16 * 16 if patch_size changes
        self.fc1 = nn.Linear(32 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, n_classes)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
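The `32 * 16 * 16` input size of `fc1` follows from the two 2×2 poolings. One way to confirm it, or to recompute it for another `patch_size`, is to pass a dummy batch through the convolutional part. The sketch below rebuilds that part as a standalone `nn.Sequential` with the same layer settings (the 4-band input is an assumption for illustration) and prints the flattened feature size.

```python
import torch
import torch.nn as nn

in_channels, patch_size = 4, 64  # assumed band count; the notebook's patch size

# Same conv/pool layout as SimpleCNN, without the fully connected head
features = nn.Sequential(
    nn.Conv2d(in_channels, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
)

with torch.no_grad():
    out = features(torch.zeros(1, in_channels, patch_size, patch_size))

print(tuple(out.shape))     # (1, 32, 16, 16)
flat_size = out[0].numel()  # value to use for fc1's in_features
print(flat_size)            # 8192
```

If you would rather not hard-code this number, PyTorch's `nn.LazyLinear(128)` infers `in_features` on the first forward pass.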

## Step 3: Load Data and Prepare Dataloaders

Load the dataset, split it 80/20 into training and validation sets, and wrap each split in a `DataLoader`.

In [None]:
# Define file paths
image_path = 'sample.tif'
shapefile_path = 'labels.shp'

# Create dataset
dataset = RasterPatchDataset(image_path, shapefile_path, patch_size=64)

# Split into training and validation
train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# Print dataset information
print(f'Total samples: {len(dataset)}')
print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')
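With imbalanced land cover classes, a purely random split can leave rare classes under-represented in the validation set. Passing the labels to `train_test_split` through its `stratify` argument keeps the class proportions similar in both splits. The sketch below uses synthetic labels in place of `dataset.labels`.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced labels standing in for dataset.labels
labels = np.array([0] * 80 + [1] * 15 + [2] * 5)

train_idx, val_idx = train_test_split(
    np.arange(len(labels)), test_size=0.2, random_state=42, stratify=labels
)

# Count samples per class in each split
train_counts = np.bincount(labels[train_idx], minlength=3)
val_counts = np.bincount(labels[val_idx], minlength=3)
print(train_counts, val_counts)
```

In the notebook this would mean adding `stratify=dataset.labels` to the existing call; stratification requires at least two samples per class.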

## Step 4: Initialize and Train CNN

Initialize the CNN and train it with cross-entropy loss and the Adam optimizer, reporting the average training loss for each epoch.

In [None]:
# Initialize model
n_classes = len(np.unique(dataset.labels))
n_channels = dataset.image.shape[0]
model = SimpleCNN(in_channels=n_channels, n_classes=n_classes)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
n_epochs = 10
for epoch in range(n_epochs):
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * images.size(0)
    
    train_loss /= len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{n_epochs}, Training Loss: {train_loss:.4f}')
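The training loop above never uses `val_loader`. A minimal evaluation helper, sketched below under the assumption that you call it once per epoch after the training pass, computes the average validation loss and overall accuracy. `evaluate` is a hypothetical name, and the tiny linear model and random tensors only exercise the function; in the notebook the call would be `evaluate(model, val_loader, criterion, device)`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, criterion, device):
    """Return (mean loss, accuracy) of `model` over all batches in `loader`."""
    model.eval()  # the training loop switches back with model.train() each epoch
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # criterion returns the batch mean, so weight it by the batch size
            total_loss += criterion(outputs, labels).item() * images.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / total, correct / total


# Toy stand-ins: 20 random 4-band 8x8 "patches" with 3 classes
torch.manual_seed(0)
toy_data = TensorDataset(torch.randn(20, 4, 8, 8), torch.randint(0, 3, (20,)))
toy_loader = DataLoader(toy_data, batch_size=8)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(4 * 8 * 8, 3))

val_loss, val_acc = evaluate(toy_model, toy_loader, nn.CrossEntropyLoss(), torch.device('cpu'))
print(f'Validation loss: {val_loss:.4f}, accuracy: {val_acc:.2%}')
```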

## Step 5: Predict and Visualize Classification

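Predict land cover across the whole raster by tiling it into non-overlapping 64×64 windows. Each window receives a single predicted class, so the resulting map has one class per 64×64 block rather than per pixel. Tiles along the right and bottom edges that are smaller than 64×64 are skipped.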

In [None]:
# Reuse the normalized raster and profile already loaded by the dataset, so that
# prediction uses exactly the same normalization as training
full_image = dataset.image
profile = dataset.profile

# Initialize output array; 255 marks pixels left without a prediction
# (assumes fewer than 255 classes)
height, width = full_image.shape[1], full_image.shape[2]
predictions = np.full((height, width), 255, dtype=np.uint8)

# Predict in patches
model.eval()
with torch.no_grad():
    for i in range(0, height, 64):
        for j in range(0, width, 64):
            patch = full_image[:, i:i+64, j:j+64]
            if patch.shape[1:] != (64, 64):
                continue  # Skip incomplete patches
            patch = torch.from_numpy(patch).unsqueeze(0).to(device)
            output = model(patch)
            pred = torch.argmax(output, dim=1).cpu().numpy()[0]
            predictions[i:i+64, j:j+64] = pred

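# Visualize predictions; 255 (no prediction) is masked so it is not drawn as a class
plt.figure(figsize=(8, 8))
plt.imshow(np.ma.masked_equal(predictions, 255), cmap='tab10', interpolation='nearest')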
plt.colorbar(label='Class')
plt.title('CNN Land Cover Classification')
plt.xlabel('Column')
plt.ylabel('Row')
plt.show()
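Because each training label belongs to the center pixel of its patch, a closer match at prediction time is to slide a window centered on output pixels and assign its prediction to a small `stride × stride` block. The sketch below is one way to do that, not the notebook's method: `predict_centered` is a hypothetical helper that zero-pads the image by half a patch (an assumption; reflect padding is another option) so edge pixels also get full windows, and batches windows for speed. A dummy model stands in for the trained CNN. A smaller `stride` gives a finer map at the cost of more forward passes.

```python
import numpy as np
import torch
import torch.nn as nn


def predict_centered(model, image, patch_size=64, stride=8, batch_size=64, device='cpu'):
    """Classify a (bands, H, W) array with windows centered on a grid of pixels.

    Each window's predicted class is written to the stride x stride block
    whose top-left pixel is the window center. Hypothetical helper.
    """
    _, height, width = image.shape
    half = patch_size // 2
    # Zero-pad so windows centered near the border stay inside the array
    padded = np.pad(image, ((0, 0), (half, half), (half, half)), mode='constant')
    predictions = np.zeros((height, width), dtype=np.uint8)

    centers = [(r, c) for r in range(0, height, stride) for c in range(0, width, stride)]
    model.eval()
    with torch.no_grad():
        for start in range(0, len(centers), batch_size):
            batch_centers = centers[start:start + batch_size]
            # In padded coordinates, the window for center (r, c) starts at (r, c)
            windows = np.stack([padded[:, r:r + patch_size, c:c + patch_size]
                                for r, c in batch_centers])
            outputs = model(torch.from_numpy(windows).to(device))
            classes = outputs.argmax(dim=1).cpu().numpy()
            for (r, c), cls in zip(batch_centers, classes):
                predictions[r:r + stride, c:c + stride] = cls
    return predictions


class MeanSignModel(nn.Module):
    """Dummy classifier: class 1 if the window mean is positive, else class 0."""

    def forward(self, x):
        mean = x.mean(dim=(1, 2, 3))
        return torch.stack([-mean, mean], dim=1)


image = np.ones((3, 40, 50), dtype=np.float32)
pred_map = predict_centered(MeanSignModel(), image, patch_size=16, stride=4)
print(pred_map.shape, np.unique(pred_map))
```

With the trained network, a call would look like `predict_centered(model, dataset.image, patch_size=64, stride=8, device=device)`; `stride=1` gives every pixel its own window.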

## Step 6: Save Classification Result

Save the classification result as a single-band GeoTIFF with the same georeferencing as the input raster.

In [None]:
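# Update profile for single-band output. uint8 with nodata=255 stores up to 255 class
# indices; int64 GeoTIFFs need GDAL 3.5 or later and are not needed here
output_profile = profile.copy()
output_profile.update(driver='GTiff', count=1, dtype=rasterio.uint8, nodata=255)

# Save predictions (class indices, cast to the output data type)
with rasterio.open('cnn_classification.tif', 'w', **output_profile) as dst:
    dst.write(predictions.astype(np.uint8), 1)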

print('Classification result saved to: cnn_classification.tif')

## Next Steps

- Replace `sample.tif` and `labels.shp` with your own image and labeled shapefile.
- Adjust patch size, number of epochs, or CNN architecture (e.g., add more layers).
- Add validation metrics (e.g., accuracy, F1-score) for evaluation.
- Proceed to the next notebook (`17_time_series_analysis.ipynb`) for time series analysis.
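For the validation metrics suggested above, scikit-learn can compute overall accuracy, macro-averaged F1 and a confusion matrix from true and predicted labels. In the notebook, `y_true` and `y_pred` would be collected by running the trained model over `val_loader`; the arrays below are made-up stand-ins.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

# Made-up validation labels and predictions for 3 classes
y_true = np.array([0, 0, 1, 1, 2, 2, 2, 0])
y_pred = np.array([0, 1, 1, 1, 2, 0, 2, 0])

acc = accuracy_score(y_true, y_pred)
macro_f1 = f1_score(y_true, y_pred, average='macro')  # unweighted mean of per-class F1
cm = confusion_matrix(y_true, y_pred)                 # rows: true class, columns: predicted

print(f'Accuracy: {acc:.3f}, macro F1: {macro_f1:.3f}')  # Accuracy: 0.750, macro F1: 0.756
print(cm)
```

Macro F1 weights every class equally, which makes it more informative than accuracy when some land cover classes are rare.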

## Notes
- Ensure the shapefile contains points with a 'class' column for labels.
- Normalize input data to improve training stability, and apply the same normalization at training and prediction time.
- See `docs/installation.md` for troubleshooting library installation.
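Dividing by the band maximum, as this notebook does, is sensitive to a few very bright pixels or large fill values. A common alternative, shown here as an option rather than the notebook's method, is a per-band percentile stretch clipped to [0, 1]. `percentile_normalize` is a hypothetical helper, and the 2nd and 98th percentiles are conventional but arbitrary choices. Whatever scheme you use, compute the per-band statistics once and reuse them for prediction.

```python
import numpy as np


def percentile_normalize(image, low=2.0, high=98.0):
    """Scale each band of a (bands, H, W) array to [0, 1] using percentile bounds."""
    lo = np.percentile(image, low, axis=(1, 2), keepdims=True)
    hi = np.percentile(image, high, axis=(1, 2), keepdims=True)
    # Guard against constant bands, where hi == lo
    scaled = (image - lo) / np.maximum(hi - lo, 1e-8)
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)


# One band with a single extreme outlier; max scaling would squash the rest toward 0
band = np.arange(100, dtype=np.float32).reshape(1, 10, 10)
band[0, 0, 0] = 10000.0
normalized = percentile_normalize(band)
print(normalized.min(), normalized.max(), normalized[0, 5, 0])
```

Here the pixel with value 50 maps to about 0.49, whereas dividing by the maximum (10000) would map it to 0.005.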