# Sentinel-2 Classification Pipeline

Complete workflow from data loading to classification using index-based labeling.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

from src.sentinel2_classifier import (
    Sentinel2Classifier,
    calculate_indices_from_sentinel2,
    create_sample_labels_from_index,
    get_raster_info,
    load_sentinel2_image,
    prepare_features,
    save_classified_raster,
    visualize_classification,
)

## 1. Data Loading and Inspection

In [None]:
# Path to the Sentinel-2 raster to classify
image_path = "path/to/sentinel2_image.tif"  # Replace with actual path

# Load the raster and print its metadata; if the file is missing, fall back
# to a small synthetic band stack so the remaining cells can still be run.
try:
    data, profile = load_sentinel2_image(image_path)
    print(f"Image shape: {data.shape}")
    print("\nRaster info:")
    info = get_raster_info(image_path)
    for key, value in info.items():
        print(f"{key}: {value}")
except FileNotFoundError:
    print("Please set a valid Sentinel-2 image path")
    # Seeded generator: the split and the model below are pinned to
    # random_state=42, so the dummy raster is made reproducible as well.
    rng = np.random.default_rng(42)
    data = rng.integers(0, 4000, size=(4, 100, 100), dtype=np.uint16)
    # No "crs" entry on purpose: the save step checks profile.get("crs") and
    # skips writing a GeoTIFF when it is missing.
    profile = {"width": 100, "height": 100, "count": 4}
    print("Using dummy data for demonstration")

## 2. Calculate Indices and Visualize

In [None]:
# Derive the NDVI and NDWI rasters from the band stack
ndvi, ndwi = calculate_indices_from_sentinel2(data)

# One panel per index, each drawn on the fixed [-1, 1] colour scale
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panels = (
    (axes[0], ndvi, "NDVI", "RdYlGn"),
    (axes[1], ndwi, "NDWI", "Blues"),
)
for axis, index_map, title, colormap in panels:
    image = axis.imshow(index_map, cmap=colormap, vmin=-1, vmax=1)
    axis.set_title(title)
    plt.colorbar(image, ax=axis)

plt.tight_layout()
plt.show()

## 3. Generate Dataset with Index-based Labels

In [None]:
# Build the feature matrix (sklearn format) from the band stack
features = prepare_features(data)
print(f"Features shape: {features.shape}")

# Derive index-based labels for the same data
labels = create_sample_labels_from_index(data)
print(f"Labels shape: {labels.shape}")

# Which class ids occur, and how many samples fall into each
class_ids = np.unique(labels)
print(f"Classes: {class_ids} (0=Water, 1=Vegetation, 2=Urban)")
class_counts = np.bincount(labels)
print(f"Class distribution: {class_counts}")

## 4. Train-Test Split and Model Training

In [None]:
# Hold out 30% of the samples for testing; stratifying on the labels keeps
# the class proportions the same in both parts, and the seed fixes the split.
split = train_test_split(
    features,
    labels,
    test_size=0.3,
    stratify=labels,
    random_state=42,
)
X_train, X_test, y_train, y_test = split

n_train = X_train.shape[0]
n_test = X_test.shape[0]
print(f"Training set: {n_train} samples")
print(f"Test set: {n_test} samples")

In [None]:
# Wrap a seeded 100-tree random forest in the project classifier and fit it
# on the training portion of the split.
forest = RandomForestClassifier(random_state=42, n_estimators=100)
classifier = Sentinel2Classifier(forest)

print("Training model...")
classifier.train(X_train, y_train)
print("Training completed!")

## 5. Model Evaluation

In [None]:
# Predict on the held-out test set
y_pred = classifier.predict(X_test)

# Class names in label order (0=Water, 1=Vegetation, 2=Urban). The matching
# ids are passed to sklearn explicitly so the report and the matrix always
# have one row per class. Without `labels`, a class that is missing from the
# test data makes classification_report raise (name count != class count)
# and shrinks the confusion matrix below the 3 hard-coded ticks.
class_names = ["Water", "Vegetation", "Urban"]
class_ids = list(range(len(class_names)))

# Classification report; zero_division=0 reports 0 (without a warning) for a
# class that has no true or predicted samples.
print("Classification Report:")
print(
    classification_report(
        y_test,
        y_pred,
        labels=class_ids,
        target_names=class_names,
        zero_division=0,
    )
)

# Confusion matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(y_test, y_pred, labels=class_ids)
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.colorbar()
plt.xticks(class_ids, class_names)
plt.yticks(class_ids, class_names)
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.show()

## 6. Full Image Classification

In [None]:
# Run the trained classifier over every sample of the image
print("Classifying full image...")
predictions = classifier.predict(features)

# Fold the flat predictions back onto the raster grid; the first axis of
# `data` is dropped and the remaining two give the grid size.
height, width = data.shape[1:]
classified_image = predictions.reshape((height, width))

# Show the class map
visualize_classification(classified_image)

print(f"Classification completed! Shape: {classified_image.shape}")

## 7. Save Model and Results

In [None]:
# Persist the trained model
classifier.save_model("trained_model.pkl")
print("Model saved to trained_model.pkl")

# Write the classified raster only when a profile with a CRS is available;
# otherwise report that the raster is not written.
if "profile" not in locals() or not profile.get("crs"):
    print("Skipping raster save (no geospatial profile available)")
else:
    save_classified_raster(
        predictions,
        profile,
        "classified_output.tif",
        height,
        width,
    )
    print("Classified raster saved to classified_output.tif")

## 8. Load and Test Saved Model

In [None]:
# Restore the saved model into a fresh classifier instance
test_classifier = Sentinel2Classifier()
test_classifier.load_model("trained_model.pkl")

# Smoke test: predict on the first 100 feature rows only
sample = features[:100]
test_pred = test_classifier.predict(sample)
print(f"Loaded model test - predictions shape: {test_pred.shape}")
predicted_classes = np.unique(test_pred)
print(f"Predicted classes: {predicted_classes}")
print("Pipeline completed successfully!")