# Transfer Learning for Land Cover Classification

This notebook demonstrates transfer learning for land cover classification of remote sensing imagery by fine-tuning a ResNet-18 model pre-trained on ImageNet, using `torch` and `torchvision` in Python. Starting from pre-trained weights typically improves performance on small datasets, which makes the approach well suited to remote sensing tasks with limited labeled data.

## Prerequisites
- Install required libraries: `torch`, `torchvision`, `rasterio`, `geopandas`, `numpy`, `matplotlib`, `scikit-learn` (listed in `requirements.txt`).
- A preprocessed multi-band GeoTIFF (e.g., `fused_feature_stack.tif` from `25_multisensor_fusion.ipynb`).
- A labeled vector dataset (e.g., `landcover_labels.shp`) with land cover classes.
- Replace file paths with your own data.
- GPU recommended for faster training.

## Learning Objectives
- Load and preprocess remote sensing imagery and labeled vector data.
- Fine-tune a pre-trained ResNet model for land cover classification.
- Evaluate model performance using accuracy and a confusion matrix.
- Visualize predicted land cover classes.

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
import rasterio
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
import os
from torch.utils.data import Dataset, DataLoader

## Step 1: Load Raster and Labeled Data

Load the preprocessed multi-band GeoTIFF and labeled vector data for training.

In [None]:
# Define file paths
raster_path = 'remote_sensing_data/fused_feature_stack.tif'  # Replace with your GeoTIFF
labels_path = 'landcover_labels.shp'                        # Replace with your labeled shapefile

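# Load raster; nodata pixels are converted to NaN so the np.isnan checks in later cells can skip them
with rasterio.open(raster_path) as src:
    raster_data = src.read(masked=True).astype(np.float32).filled(np.nan)  # Shape: (bands, height, width)
    raster_profile = src.profile
    raster_crs = src.crs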

# Load labeled vector data
labels_gdf = gpd.read_file(labels_path)
if labels_gdf.crs != raster_crs:
    labels_gdf = labels_gdf.to_crs(raster_crs)

# Extract class labels (assumes 'class' column)
class_names = labels_gdf['class'].unique()
class_map = {name: idx for idx, name in enumerate(class_names)}
labels_gdf['class_id'] = labels_gdf['class'].map(class_map)

# Print basic information
print(f'Raster shape: {raster_data.shape}')
print(f'Raster CRS: {raster_crs}')
print(f'Labels CRS: {labels_gdf.crs}')
print(f'Classes: {class_names}')
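
Because `unique()` returns values in order of first appearance, the integer IDs assigned above depend on the row order of the shapefile. The following is a minimal sketch of a reproducible mapping, assuming alphabetical ordering of class names is acceptable. The helper `build_class_map` and the sample names are hypothetical and not part of the original notebook.

```python
def build_class_map(names):
    """Map each distinct class name to an integer ID in sorted order."""
    ordered = sorted(set(names))
    class_map = {name: idx for idx, name in enumerate(ordered)}
    # Inverse lookup for turning predicted IDs back into class names
    id_to_name = {idx: name for name, idx in class_map.items()}
    return class_map, id_to_name


# Hypothetical labels in arbitrary file order
sample_names = ['water', 'forest', 'urban', 'forest', 'water']
class_map_demo, id_to_name_demo = build_class_map(sample_names)
print(class_map_demo)  # {'forest': 0, 'urban': 1, 'water': 2}
```

Saving the mapping alongside the trained model would make it possible to decode predictions later without re-reading the label file.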

## Step 2: Create Custom Dataset

Extract a fixed-size image patch centered on the centroid of each labeled polygon and pair it with the polygon's class to create a dataset for training.
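
As a minimal illustration of centered patch extraction, the sketch below slices a fixed-size window from a toy array and rejects windows that would cross the raster edge. The helper `extract_patch` and the toy raster are hypothetical. The bounds test here uses the exclusive slice end, which is slightly more permissive than the check in the dataset class.

```python
import numpy as np


def extract_patch(raster, row, col, patch_size=64):
    """Return a (bands, patch_size, patch_size) window centered near (row, col).

    Returns None when the window would extend past the raster edge.
    """
    half = patch_size // 2
    top, left = row - half, col - half
    bottom, right = top + patch_size, left + patch_size
    _, height, width = raster.shape
    if top < 0 or left < 0 or bottom > height or right > width:
        return None
    return raster[:, top:bottom, left:right]


# Toy raster: 3 bands, 100 x 120 pixels
toy = np.arange(3 * 100 * 120, dtype=np.float32).reshape(3, 100, 120)
inside = extract_patch(toy, row=50, col=60, patch_size=64)
outside = extract_patch(toy, row=10, col=60, patch_size=64)
print(inside.shape)  # (3, 64, 64)
print(outside)       # None
```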

In [None]:
# Define custom dataset class
class RemoteSensingDataset(Dataset):
    def __init__(self, raster_data, labels_gdf, transform=None, patch_size=64):
        self.raster_data = raster_data
        self.labels_gdf = labels_gdf
        self.transform = transform
        self.patch_size = patch_size
        self.patches = []
        self.labels = []

        # Extract patches and labels
        height, width = raster_data.shape[1], raster_data.shape[2]
        for idx, row in labels_gdf.iterrows():
            centroid = row.geometry.centroid
            row_idx, col_idx = rasterio.transform.rowcol(raster_profile['transform'], centroid.x, centroid.y)
            if (row_idx - patch_size//2 >= 0 and row_idx + patch_size//2 < height and
                col_idx - patch_size//2 >= 0 and col_idx + patch_size//2 < width):
                patch = raster_data[:, row_idx-patch_size//2:row_idx+patch_size//2,
                                    col_idx-patch_size//2:col_idx+patch_size//2]
                if not np.any(np.isnan(patch)):
                    self.patches.append(patch)
                    self.labels.append(row['class_id'])

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        patch = self.patches[idx].astype(np.float32)
        label = self.labels[idx]
        if self.transform:
            patch = self.transform(patch)
        return patch, label

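# Define data transforms
# Patches are already (bands, height, width), so convert them to tensors directly:
# transforms.ToTensor() expects (height, width, bands) arrays and would reorder the axes.
# The mean/std values are placeholders; replace them with per-band statistics of your data.
transform = transforms.Compose([
    transforms.Lambda(lambda patch: torch.from_numpy(np.ascontiguousarray(patch, dtype=np.float32))),
    transforms.Normalize(mean=[0.5] * raster_data.shape[0], std=[0.2] * raster_data.shape[0])
])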

# Create dataset
dataset = RemoteSensingDataset(raster_data, labels_gdf, transform=transform, patch_size=64)

# Split dataset into train and validation sets
train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')
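
The `Normalize` step above uses the same fixed mean and standard deviation for every band. A minimal sketch of deriving per-band statistics from the data follows, assuming missing pixels are stored as NaN. The helper `band_statistics` and the toy array are hypothetical. In practice the statistics would ideally be computed from the training patches only.

```python
import numpy as np


def band_statistics(raster):
    """Per-band mean and standard deviation, ignoring NaN pixels."""
    flat = raster.reshape(raster.shape[0], -1)
    means = np.nanmean(flat, axis=1)
    stds = np.nanstd(flat, axis=1)
    # Guard against constant bands, which would give a zero divisor
    stds = np.where(stds == 0, 1.0, stds)
    return means.tolist(), stds.tolist()


# Toy raster: 2 bands of 2 x 2 pixels, one missing value
toy = np.array([[[1.0, 2.0], [3.0, np.nan]],
                [[10.0, 10.0], [10.0, 10.0]]], dtype=np.float32)
means, stds = band_statistics(toy)
print(means)  # [2.0, 10.0]
```

The returned lists could be passed as the `mean` and `std` arguments of `transforms.Normalize`.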

## Step 3: Initialize and Fine-Tune ResNet Model

Load a pre-trained ResNet model and modify it for the number of classes in the dataset.

In [None]:
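# Initialize ResNet-18 with ImageNet weights
# (the `weights` argument replaces the deprecated `pretrained=True` in torchvision >= 0.13)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Modify input layer to match number of input bands
# Note: the new conv1 is randomly initialized; only the remaining layers keep pre-trained weights
num_bands = raster_data.shape[0]
model.conv1 = nn.Conv2d(num_bands, 64, kernel_size=7, stride=2, padding=3, bias=False)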

# Modify output layer to match number of classes
num_classes = len(class_names)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
train_losses, val_losses = [], []
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for patches, labels in train_loader:
        patches, labels = patches.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(patches)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_losses.append(train_loss / len(train_loader))

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for patches, labels in val_loader:
            patches, labels = patches.to(device), labels.to(device)
            outputs = model(patches)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    val_losses.append(val_loss / len(val_loader))

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}')

# Save trained model
torch.save(model.state_dict(), 'resnet_landcover.pth')
print('Trained model saved to: resnet_landcover.pth')
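
The cell above replaces `conv1` with a newly constructed layer, so the first layer starts from random weights. One common heuristic, shown here as a sketch rather than the notebook's method, is to initialize each new input channel from the mean of the pre-trained RGB kernels. The helper `adapt_first_conv` is hypothetical, and the example uses a randomly initialized stand-in for the pre-trained layer to avoid a download.

```python
import torch
import torch.nn as nn


def adapt_first_conv(pretrained_conv, num_bands):
    """Build a conv layer for `num_bands` inputs from an existing conv layer.

    Each new input channel starts from the mean of the original kernels,
    rescaled so the summed response has a similar magnitude.
    """
    new_conv = nn.Conv2d(
        num_bands,
        pretrained_conv.out_channels,
        kernel_size=pretrained_conv.kernel_size,
        stride=pretrained_conv.stride,
        padding=pretrained_conv.padding,
        bias=pretrained_conv.bias is not None,
    )
    scale = pretrained_conv.in_channels / num_bands
    with torch.no_grad():
        # Average over the input-channel axis: (out, 1, kH, kW)
        mean_kernel = pretrained_conv.weight.mean(dim=1, keepdim=True)
        new_conv.weight.copy_(mean_kernel.repeat(1, num_bands, 1, 1) * scale)
        if pretrained_conv.bias is not None:
            new_conv.bias.copy_(pretrained_conv.bias)
    return new_conv


# Stand-in for a pre-trained ResNet `conv1` (random weights, no download)
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
multiband_conv = adapt_first_conv(rgb_conv, num_bands=6)
print(multiband_conv.weight.shape)  # torch.Size([64, 6, 7, 7])
```

In the notebook, this would be called on the original `model.conv1` instead of constructing a fresh `nn.Conv2d`. Whether it helps depends on how closely the input bands resemble natural-image channels.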

## Step 4: Evaluate Model Performance

Evaluate the model on the validation set and compute the overall accuracy and a confusion matrix.

In [None]:
# Evaluate model
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for patches, labels in val_loader:
        patches, labels = patches.to(device), labels.to(device)
        outputs = model(patches)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

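# Compute accuracy and confusion matrix
# Passing `labels` keeps one row/column per class even if a class is absent from the validation set
accuracy = accuracy_score(all_labels, all_preds)
conf_matrix = confusion_matrix(all_labels, all_preds, labels=np.arange(num_classes))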

print(f'Validation Accuracy: {accuracy:.4f}')
print('Confusion Matrix:')
print(conf_matrix)

# Visualize confusion matrix
plt.figure(figsize=(8, 6))
plt.imshow(conf_matrix, cmap='Blues')
plt.colorbar(label='Count')
plt.xticks(np.arange(num_classes), class_names, rotation=45)
plt.yticks(np.arange(num_classes), class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
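
Overall accuracy can hide poor performance on rare classes. The sketch below derives per-class producer's accuracy (recall) and user's accuracy (precision) from a confusion matrix whose rows are true classes and columns are predicted classes, as returned by `confusion_matrix`. The helper `per_class_accuracy` and the toy matrix are hypothetical.

```python
import numpy as np


def per_class_accuracy(conf_matrix):
    """Producer's (recall) and user's (precision) accuracy per class."""
    cm = np.asarray(conf_matrix, dtype=np.float64)
    correct = np.diag(cm)
    true_totals = cm.sum(axis=1)   # samples per true class (row sums)
    pred_totals = cm.sum(axis=0)   # samples per predicted class (column sums)
    # Classes with no samples yield NaN instead of a division error
    producers = np.divide(correct, true_totals,
                          out=np.full_like(correct, np.nan),
                          where=true_totals > 0)
    users = np.divide(correct, pred_totals,
                      out=np.full_like(correct, np.nan),
                      where=pred_totals > 0)
    return producers, users


# Hypothetical 3-class confusion matrix; the third class has no samples
toy_cm = [[8, 2, 0],
          [1, 9, 0],
          [0, 0, 0]]
producers, users = per_class_accuracy(toy_cm)
print(producers)
print(users)
```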

## Step 5: Predict and Visualize Land Cover

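Apply the trained model to the raster in non-overlapping patches to generate a land cover map. Each patch receives a single class label, so the map has an effective resolution of `patch_size` pixels.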

In [None]:
# Predict land cover across the entire raster, one label per non-overlapping patch
patch_size = 64
nodata_value = 255  # Marks unclassified pixels (NaN patches, uncovered edges); assumes fewer than 255 classes
height, width = raster_data.shape[1], raster_data.shape[2]
pred_map = np.full((height, width), nodata_value, dtype=np.uint8)

model.eval()
with torch.no_grad():
    for i in range(0, height - patch_size + 1, patch_size):
        for j in range(0, width - patch_size + 1, patch_size):
            patch = raster_data[:, i:i+patch_size, j:j+patch_size].astype(np.float32)
            if not np.any(np.isnan(patch)):
                patch_tensor = transform(patch).unsqueeze(0).to(device)
                output = model(patch_tensor)
                pred = torch.argmax(output, dim=1).cpu().numpy()[0]
                pred_map[i:i+patch_size, j:j+patch_size] = pred

# Visualize predicted land cover (unclassified pixels are masked out)
plt.figure(figsize=(8, 8))
plt.imshow(np.ma.masked_equal(pred_map, nodata_value), cmap='tab10', interpolation='nearest')
plt.colorbar(ticks=np.arange(num_classes), label='Class')
plt.clim(-0.5, num_classes-0.5)
plt.title('Predicted Land Cover Map')
plt.xlabel('Column')
plt.ylabel('Row')
plt.show()

# Save predicted land cover as GeoTIFF, flagging unclassified pixels as nodata
pred_profile = raster_profile.copy()
pred_profile.update({'count': 1, 'dtype': 'uint8', 'nodata': nodata_value})
pred_output_path = 'remote_sensing_data/landcover_prediction.tif'
with rasterio.open(pred_output_path, 'w', **pred_profile) as dst:
    dst.write(pred_map, 1)

print(f'Predicted land cover saved to: {pred_output_path}')
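
The prediction loop steps through the raster in strides of `patch_size`, so a strip along the right and bottom edges is never classified when the raster dimensions are not multiples of the patch size. One possible remedy, sketched here under that assumption, is to add a final tile shifted back to end at the border. The helper `tile_origins` is hypothetical.

```python
def tile_origins(length, patch_size):
    """Start offsets of tiles that cover [0, length) along one axis.

    The last tile is shifted back so it ends exactly at the border,
    overlapping its neighbour when `length` is not a multiple of `patch_size`.
    """
    if length < patch_size:
        return []
    origins = list(range(0, length - patch_size + 1, patch_size))
    if origins[-1] + patch_size < length:
        origins.append(length - patch_size)
    return origins


print(tile_origins(200, 64))  # [0, 64, 128, 136]
print(tile_origins(128, 64))  # [0, 64]
```

Iterating over `tile_origins(height, patch_size)` and `tile_origins(width, patch_size)` in place of the two `range` calls would cover the full extent, with the overlapping tiles simply overwriting earlier predictions.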

## Next Steps

- Replace `fused_feature_stack.tif` and `landcover_labels.shp` with your own GeoTIFF and labeled data (e.g., from `25_multisensor_fusion.ipynb`).
- Adjust `patch_size` or model architecture (e.g., ResNet50) based on your dataset and computational resources.
- Explore other pre-trained models (e.g., EfficientNet) or fine-tuning strategies to improve performance.
- Use the predicted land cover map in visualization notebooks like `23_kepler_gl_demo.ipynb` or `22_folium_visualization.ipynb`.
- Proceed to advanced analysis like change detection (see `18_change_detection.ipynb`) using the predicted maps.
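
One fine-tuning strategy is to train only the layers that were replaced while keeping the rest of the network frozen. The following is a minimal sketch using a tiny stand-in network instead of a downloaded ResNet. The helper `freeze_except` and the class `TinyNet` are hypothetical. A torchvision ResNet exposes `conv1` and `fc` attributes with the same names.

```python
import torch
import torch.nn as nn


def freeze_except(model, trainable_prefixes):
    """Disable gradients for all parameters except those with the given name prefixes."""
    prefixes = tuple(trainable_prefixes)
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(prefixes)
    return [p for p in model.parameters() if p.requires_grad]


class TinyNet(nn.Module):
    """Small stand-in with the same `conv1` / `fc` naming as a ResNet."""

    def __init__(self, num_bands=6, num_classes=4):
        super().__init__()
        self.conv1 = nn.Conv2d(num_bands, 8, kernel_size=3, padding=1)
        self.layer1 = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.layer1(x))
        return self.fc(self.pool(x).flatten(1))


net = TinyNet()
trainable = freeze_except(net, trainable_prefixes=['conv1', 'fc'])
# Only the unfrozen parameters are handed to the optimizer
optimizer = torch.optim.Adam(trainable, lr=1e-3)
print(len(trainable))  # 4 tensors: conv1 weight/bias and fc weight/bias
```

After a few epochs the frozen layers can be unfrozen and trained with a lower learning rate, though the best schedule depends on the dataset.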

## Notes
- Ensure the labeled shapefile contains a `class` column with land cover categories.
- The number of input bands must match the raster data; adjust the model input layer if using different sensors.
- Transfer learning reduces the amount of labeled data needed but does not remove the requirement; consider data augmentation if the dataset is small.
- See `docs/installation.md` for troubleshooting library installation.
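
The data augmentation mentioned in the notes can be sketched with random flips and 90-degree rotations, which keep the band values unchanged and suit nadir imagery where orientation carries little meaning. This is an illustrative sketch only. The helper `augment_patch` is hypothetical and assumes square patches in (bands, height, width) order.

```python
import numpy as np


def augment_patch(patch, rng):
    """Randomly flip and rotate a (bands, height, width) patch."""
    if rng.random() < 0.5:
        patch = patch[:, :, ::-1]   # horizontal flip
    if rng.random() < 0.5:
        patch = patch[:, ::-1, :]   # vertical flip
    k = int(rng.integers(0, 4))     # number of 90-degree rotations
    patch = np.rot90(patch, k=k, axes=(1, 2))
    # Copy to a contiguous array; torch.from_numpy rejects negative strides
    return np.ascontiguousarray(patch)


rng = np.random.default_rng(42)
toy_patch = np.arange(2 * 4 * 4, dtype=np.float32).reshape(2, 4, 4)
augmented = augment_patch(toy_patch, rng)
print(augmented.shape)  # (2, 4, 4)
```

Such a function would be applied to training patches only, before the tensor conversion and normalization steps.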