# GCN Crop Classification Tutorial

A step-by-step walkthrough of **crop classification from Sentinel-2 imagery** using a **Graph Convolutional Network (GCN)** built with PyTorch Geometric.

This notebook covers the full pipeline: data loading, graph construction, model training, and evaluation.

## 1. Setup & Imports

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import kneighbors_graph
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

DATA_DIR = Path('data')
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'PyTorch: {torch.__version__}')
print(f'Device: {device}')

## 2. Load & Explore Data

The training dataset contains ~24,000 labeled pixels with 23 spectral/vegetation features across 5 crop classes.

In [None]:
df = pd.read_csv(DATA_DIR / 'crop_training_data_5classes_2020.csv')
drop_cols = [c for c in ['.geo', 'system:index', 'GCVI'] if c in df.columns]
df = df.drop(columns=drop_cols)
df = df.drop_duplicates().reset_index(drop=True)

print(f'Samples: {df.shape[0]}, Columns: {df.shape[1]}')
print(f'\nClass distribution:')
print(df['classname'].value_counts().sort_index())

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
colors = ['#e6194b', '#c4a35a', '#3cb44b', '#4363d8', '#f5d742']
df['classname'].value_counts().sort_index().plot(kind='bar', color=colors, ax=ax)
ax.set_title('Class Distribution', fontsize=14)
ax.set_ylabel('Count')
plt.xticks(rotation=0)
plt.tight_layout()
plt.show()

## 3. Prepare Features & Labels

We extract 23 numeric features (10 spectral bands + 13 vegetation indices) and use the integer class labels from the CSV.

In [None]:
feature_cols = [c for c in df.columns if c not in ['class', 'classname']]
class_names = sorted(df['classname'].unique())

X = df[feature_cols].values
y = df['class'].values
num_classes = len(np.unique(y))
num_features = X.shape[1]

print(f'Features ({num_features}): {feature_cols}')
print(f'Classes ({num_classes}): {class_names}')

## 4. Train/Val/Test Split & Normalization

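We use a stratified 70/15/15 train/validation/test split, so each subset keeps the class proportions of the full dataset. The `StandardScaler` is fit **only on the training set** to prevent data leakage; the validation and test rows are transformed with the training statistics.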

In [None]:
indices = np.arange(X.shape[0])
train_idx, temp_idx = train_test_split(indices, test_size=0.3, stratify=y, random_state=SEED)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, stratify=y[temp_idx], random_state=SEED)

scaler = StandardScaler()
X_scaled = X.astype(np.float64)  # float copy, so scaled values are not truncated if X is integer-typed
X_scaled[train_idx] = scaler.fit_transform(X[train_idx])
X_scaled[val_idx] = scaler.transform(X[val_idx])
X_scaled[test_idx] = scaler.transform(X[test_idx])

print(f'Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}')
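
To make the leakage point concrete, here is a minimal NumPy sketch of what the scaling step does: the mean and standard deviation come from the training rows only and are then applied to every row. The toy array `X_demo` and the 70/30 row split are illustrative assumptions, not the tutorial's data.

```python
import numpy as np

rng = np.random.default_rng(42)
X_demo = rng.normal(loc=5.0, scale=2.0, size=(100, 3))  # toy features (assumption)

# Hypothetical 70/30 split of the row indices
perm = rng.permutation(X_demo.shape[0])
train_rows, held_out_rows = perm[:70], perm[70:]

# Fit: mean and population std (ddof=0, as StandardScaler uses) from training rows only
mu = X_demo[train_rows].mean(axis=0)
sigma = X_demo[train_rows].std(axis=0)

# Transform: apply the training statistics to all rows
X_demo_scaled = (X_demo - mu) / sigma

print('train mean:   ', X_demo_scaled[train_rows].mean(axis=0).round(3))
print('held-out mean:', X_demo_scaled[held_out_rows].mean(axis=0).round(3))
```

The training rows end up with zero mean and unit variance by construction. The held-out rows generally do not, which is expected: they are scaled with statistics they did not contribute to.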

## 5. Build KNN Graph

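A GCN needs a graph to propagate information over. We build a symmetric **K-nearest neighbor graph (k=8)** in feature space, so edges connect pixels with similar spectral signatures. The graph spans all pixels (train, validation, and test), which makes this a transductive setup: held-out pixels contribute features and edges during training, but never labels.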

In [None]:
K = 8
knn_adj = kneighbors_graph(X_scaled, n_neighbors=K, mode='connectivity', include_self=False)
knn_adj = knn_adj.maximum(knn_adj.T)  # symmetrize: keep an edge if either endpoint lists the other

rows, cols = knn_adj.nonzero()
edge_index = torch.tensor(np.array([rows, cols]), dtype=torch.long)
print(f'Nodes: {X_scaled.shape[0]:,}, Edges: {edge_index.shape[1]:,}')
print(f'Average degree: {edge_index.shape[1] / X_scaled.shape[0]:.1f}')
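
As a rough illustration of what the KNN cell produces, the sketch below rebuilds the same kind of symmetric graph with plain NumPy on a toy array. The helper `knn_edge_index` is hypothetical and uses a dense distance matrix, so it only suits small inputs; the notebook itself relies on `kneighbors_graph`.

```python
import numpy as np

def knn_edge_index(features, k):
    """Return a symmetric (2, E) edge array linking each row to its k nearest rows."""
    n = features.shape[0]
    # Pairwise squared Euclidean distances (dense, O(n^2) memory)
    diff = features[:, None, :] - features[None, :, :]
    dist = (diff ** 2).sum(axis=-1)
    np.fill_diagonal(dist, np.inf)              # a node is not its own neighbor
    nbrs = np.argsort(dist, axis=1)[:, :k]      # k closest rows per node
    adj = np.zeros((n, n), dtype=bool)
    adj[np.repeat(np.arange(n), k), nbrs.ravel()] = True
    adj |= adj.T                                # symmetrize
    rows, cols = np.nonzero(adj)
    return np.stack([rows, cols])

rng = np.random.default_rng(0)
toy = rng.normal(size=(50, 4))                  # toy features (assumption)
toy_edges = knn_edge_index(toy, k=3)
toy_degree = np.bincount(toy_edges[0], minlength=50)
print(toy_edges.shape, toy_degree.min(), toy_degree.mean())
```

Because the graph is symmetrized, every node has degree at least k, and the average degree is usually somewhat higher than k: a node can be picked as a neighbor by points that are not among its own k nearest.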

## 6. Create PyG Data Object

In [None]:
x_tensor = torch.tensor(X_scaled, dtype=torch.float)
y_tensor = torch.tensor(y, dtype=torch.long)

train_mask = torch.zeros(X_scaled.shape[0], dtype=torch.bool); train_mask[train_idx] = True
val_mask = torch.zeros(X_scaled.shape[0], dtype=torch.bool); val_mask[val_idx] = True
test_mask = torch.zeros(X_scaled.shape[0], dtype=torch.bool); test_mask[test_idx] = True

data = Data(x=x_tensor, edge_index=edge_index, y=y_tensor,
            train_mask=train_mask, val_mask=val_mask, test_mask=test_mask).to(device)
print(data)
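
A cheap sanity check at this point is that the three masks are disjoint and together cover every node. The sketch below shows the idea with NumPy on hypothetical index arrays standing in for `train_idx`, `val_idx`, and `test_idx`; the same check could be written against the boolean mask tensors.

```python
import numpy as np

n_nodes = 10
# Hypothetical index arrays (assumption, not the tutorial's split)
demo_train = np.array([0, 1, 2, 3, 4, 5, 6])
demo_val = np.array([7, 8])
demo_test = np.array([9])

def make_mask(idx, n):
    """Boolean mask of length n that is True at the given indices."""
    mask = np.zeros(n, dtype=bool)
    mask[idx] = True
    return mask

masks = [make_mask(i, n_nodes) for i in (demo_train, demo_val, demo_test)]
coverage = np.sum(masks, axis=0)  # number of masks that claim each node
assert (coverage == 1).all(), 'masks must be disjoint and cover every node'
print(coverage)
```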

## 7. Define the GCN Model

```
GCNConv(23 -> 128) -> BN -> ReLU -> Dropout(0.5)
GCNConv(128 -> 128) -> BN -> ReLU -> Dropout(0.5)
GCNConv(128 -> 5)   -> Output logits
```

In [None]:
class GCN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=3, dropout=0.5):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.convs.append(GCNConv(in_dim, hidden_dim))
        self.bns.append(nn.BatchNorm1d(hidden_dim))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim))
        self.convs.append(GCNConv(hidden_dim, out_dim))
        self.dropout = dropout

    def forward(self, x, edge_index):
        for i in range(len(self.convs) - 1):
            x = self.convs[i](x, edge_index)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)

model = GCN(num_features, 128, num_classes, num_layers=3, dropout=0.5).to(device)
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')
print(model)
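
For intuition about what each `GCNConv` computes, here is a NumPy sketch of the Kipf and Welling propagation rule that the layer applies with its default settings: add self-loops, normalize symmetrically by degree, then apply a linear map. The bias term and activation are omitted, and the 3-node path graph and weights are toy assumptions.

```python
import numpy as np

def gcn_layer(adj, h, w):
    """One GCN propagation step: D^-1/2 (A + I) D^-1/2 H W (no bias, no activation)."""
    a_hat = adj + np.eye(adj.shape[0])              # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))   # degrees include the self-loop
    norm_adj = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    return norm_adj @ h @ w

# Path graph 0 - 1 - 2
path_adj = np.array([[0, 1, 0],
                     [1, 0, 1],
                     [0, 1, 0]], dtype=float)
h_demo = np.ones((3, 1))      # one input feature per node
w_demo = np.array([[1.0]])    # identity weight, to isolate the aggregation
h_out = gcn_layer(path_adj, h_demo, w_demo)
print(h_out.ravel())
```

Each output row is a degree-weighted average over the node and its neighbors. Stacking three such layers, as the model above does, lets information travel up to three hops along KNN edges.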

## 8. Train the Model
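
The training cell in this section weights the loss by inverse class frequency, so that rare classes are not drowned out by common ones. A small worked example of the same arithmetic, with hypothetical per-class counts:

```python
import numpy as np

demo_counts = np.array([100, 50, 25, 20, 5])   # hypothetical training counts per class
n_cls = len(demo_counts)

demo_weights = 1.0 / demo_counts                           # rarer class -> larger weight
demo_weights = demo_weights / demo_weights.sum() * n_cls   # rescale so the weights sum to n_cls
print(demo_weights)
```

Note that `np.bincount` assumes the labels are the contiguous integers 0 to C-1. A label value with no training samples would produce a zero count and therefore a division by zero.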

In [None]:
class_counts = np.bincount(y[train_idx])
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum() * num_classes
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights)

train_losses, val_accs = [], []
best_val_acc, patience_counter = 0.0, 0

for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        pred = model(data.x, data.edge_index)[data.val_mask].argmax(1)
        val_acc = (pred == data.y[data.val_mask]).float().mean().item()

    train_losses.append(loss.item())
    val_accs.append(val_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= 30:
        print(f'Early stopping at epoch {epoch}. Best val acc: {best_val_acc:.4f}')
        break
    if epoch % 20 == 0:
        print(f'Epoch {epoch:3d}  Loss: {loss.item():.4f}  Val Acc: {val_acc:.4f}')

model.load_state_dict(best_state)
print(f'\nBest Validation Accuracy: {best_val_acc:.4f}')
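
The early-stopping rule used in the loop (strict improvement resets a counter; training stops once the counter reaches the patience) can be isolated in a few lines of plain Python. The helper name `early_stop_epoch` and the accuracy sequence are hypothetical, chosen only to trace the logic.

```python
def early_stop_epoch(val_accs, patience):
    """Return (best_epoch, stop_epoch), 1-indexed; stop_epoch is None if patience never runs out."""
    best_acc, best_epoch, waited = 0.0, None, 0
    for epoch, acc in enumerate(val_accs, start=1):
        if acc > best_acc:              # only a strict improvement resets the counter
            best_acc, best_epoch, waited = acc, epoch, 0
        else:
            waited += 1
        if waited >= patience:
            return best_epoch, epoch
    return best_epoch, None

demo_accs = [0.50, 0.60, 0.60, 0.59, 0.61, 0.60, 0.60, 0.60]
result = early_stop_epoch(demo_accs, patience=3)
print(result)
```

With a patience of 3, the best epoch here is 5 and training stops at epoch 8. A tie with the best accuracy (epoch 3) counts as no improvement, which matches the `>` comparison in the training loop.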

## 9. Training Visualization

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(train_losses, color='blue')
ax1.set_title('Training Loss', fontsize=13)
ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss'); ax1.grid(alpha=0.3)
ax2.plot(val_accs, color='green')
ax2.set_title('Validation Accuracy', fontsize=13)
ax2.set_xlabel('Epoch'); ax2.set_ylabel('Accuracy'); ax2.grid(alpha=0.3)
plt.tight_layout()
plt.show()

## 10. Evaluate on Test Set

In [None]:
model.eval()
with torch.no_grad():
    out = model(data.x, data.edge_index)
    test_pred = out[data.test_mask].argmax(1).cpu().numpy()
    test_true = data.y[data.test_mask].cpu().numpy()

# Assumed label order; verify it against the `class` / `classname` columns of the CSV
idx_to_class = {0: 'Cotton', 1: 'Wheat', 2: 'Fallow', 3: 'Grass', 4: 'Water'}
target_names = [idx_to_class[i] for i in range(num_classes)]

print('Classification Report:\n')
print(classification_report(test_true, test_pred, target_names=target_names, digits=4))
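
Because `idx_to_class` is hard-coded, a mislabeled report is possible if the CSV uses a different order. One way to derive the mapping from the data instead is sketched below, assuming each integer in `class` pairs with exactly one `classname`. The toy frame `labels_demo` stands in for the real `df`.

```python
import pandas as pd

# Toy frame standing in for the CSV's label columns (values are illustrative)
labels_demo = pd.DataFrame({
    'class':     [0, 1, 1, 2, 0, 2],
    'classname': ['Cotton', 'Wheat', 'Wheat', 'Fallow', 'Cotton', 'Fallow'],
})

pairs = labels_demo[['class', 'classname']].drop_duplicates()
# Each integer label should map to exactly one name
assert pairs['class'].is_unique, 'a class id maps to more than one name'

derived_idx_to_class = dict(zip(pairs['class'].tolist(), pairs['classname'].tolist()))
derived_names = [derived_idx_to_class[i] for i in sorted(derived_idx_to_class)]
print(derived_names)
```

The resulting list is ordered by integer label, which is the order `classification_report` expects for `target_names`.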

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
cm = confusion_matrix(test_true, test_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names, ax=axes[0])
axes[0].set_title('Confusion Matrix', fontsize=13)
axes[0].set_xlabel('Predicted'); axes[0].set_ylabel('True')

cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
sns.heatmap(cm_norm, annot=True, fmt='.3f', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names, ax=axes[1])
axes[1].set_title('Normalized Confusion Matrix', fontsize=13)
axes[1].set_xlabel('Predicted'); axes[1].set_ylabel('True')
plt.tight_layout()
plt.show()
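
The numbers in the classification report can be read straight off the confusion matrix. The sketch below uses a toy 2-class matrix (hypothetical counts) with scikit-learn's convention of true classes in rows and predictions in columns.

```python
import numpy as np

cm_demo = np.array([[8, 2],
                    [1, 9]])  # rows = true class, columns = predicted class (toy counts)

recall = np.diag(cm_demo) / cm_demo.sum(axis=1)      # correct / all samples of the true class
precision = np.diag(cm_demo) / cm_demo.sum(axis=0)   # correct / all predictions of the class
f1 = 2 * precision * recall / (precision + recall)
accuracy = np.trace(cm_demo) / cm_demo.sum()

print('recall:   ', recall.round(3))
print('precision:', precision.round(3))
print('f1:       ', f1.round(3))
print('accuracy: ', round(accuracy, 3))
```

Row-normalizing the matrix, as the right-hand heatmap does, puts per-class recall on the diagonal.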

## Next Steps

- Apply the trained model to a full Sentinel-2 raster: `python apply_gcn_to_raster.py`
- Explore results in the interactive Streamlit dashboard: `streamlit run app.py`
- Experiment with different graph construction methods (spatial vs spectral KNN)
- Try varying the number of GCN layers, hidden dimensions, or k-neighbors

---

*This tutorial is part of the [GCN Crop Classification](https://github.com/Osman-Geomatics93/GCN-Crop-Classification) project.*