In [None]:
# Core imports
import sys
import os
sys.path.append('../src')

import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML

# Workshop modules
from data_utils import EuroSATDataset, HydroFloodDataset, EUROSAT_CLASSES
from models import SatelliteCNN, SatelliteViT, FloodCNN, FloodViT, ModelComparator, create_models
from visualization import CNNVisualizer, TransformerVisualizer, ModelComparator as VisComparator
from training import WorkshopTrainer, evaluate_model, compare_models, create_dummy_history

# Plot appearance: the 'default' matplotlib style plus seaborn's "husl" palette.
plt.style.use('default')
sns.set_palette("husl")

# Environment report so participants can confirm their setup before continuing.
accelerator_label = 'GPU' if torch.cuda.is_available() else 'CPU'
print("🎯 Workshop setup complete!")
print("📱 PyTorch version: {}".format(torch.__version__))
print(f"🖥️  Device available: {accelerator_label}")


In [None]:
# Dataset switch: False loads EuroSAT land-cover, True loads the flood-mapping demo.
USE_FLOOD = False  # Set to True for flood mapping demo

if not USE_FLOOD:
    print("🛰️ Loading EuroSAT Dataset...")
    dataset = EuroSATDataset(root_dir="../data")
    CLASS_NAMES = EUROSAT_CLASSES
    NUM_CLASSES = len(CLASS_NAMES)
    dataset_type = "Land-Cover (GFZ: Landscape Hydrology)"
else:
    print("🌊 Loading Flood Dataset...")
    dataset = HydroFloodDataset(root_dir="../data", split="train")
    CLASS_NAMES = HydroFloodDataset.LABELS
    NUM_CLASSES = 2
    dataset_type = "Flood Detection (GFZ: Flood Risk & Climate Adaptation)"

# Train/validation loaders consumed by the cells below.
train_loader, val_loader = dataset.get_dataloaders(batch_size=32)

# Summary of what was loaded, followed by the index -> class-name mapping.
summary_lines = [
    "\n📊 Dataset Information:",
    f"   🎯 Type: {dataset_type}",
    f"   🏷️  Classes: {NUM_CLASSES}",
    f"   📸 Training batches: {len(train_loader)}",
    f"   📸 Validation batches: {len(val_loader)}",
    "\n🏷️  Class names:",
]
print("\n".join(summary_lines))
for class_index, label in enumerate(CLASS_NAMES):
    print(f"   {class_index}: {label}")


In [None]:
# Visualize sample images
# Delegates to the dataset object's own visualize_samples helper, handing it the
# training loader and asking for 8 samples.
print("🖼️ Sample satellite images:")
dataset.visualize_samples(train_loader, num_samples=8)


In [None]:
# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🎯 Using device: {device}")

# create_models receives the class count and a task tag that follows USE_FLOOD.
task_type = "flood" if USE_FLOOD else "landcover"
cnn_model, vit_model = create_models(num_classes=NUM_CLASSES, task_type=task_type)

# Move models to device
cnn_model.to(device)
vit_model.to(device)

# Total element count over every parameter tensor, computed once per model.
cnn_param_count = sum(p.numel() for p in cnn_model.parameters())
vit_param_count = sum(p.numel() for p in vit_model.parameters())

print("\n🔍 Model Architecture Comparison:")
print(f"   📊 CNN: {cnn_model.__class__.__name__} ({cnn_param_count:,} params)")
print(f"   🤖 ViT: {vit_model.__class__.__name__} ({vit_param_count:,} params)")
print(f"   ⚖️  Parameter Ratio: {vit_param_count / cnn_param_count:.1f}x")

# Short description of the model pair that matches the selected task.
if USE_FLOOD:
    model_notes = (
        "\n💡 Flood Detection Models:",
        "   🌊 FloodCNN: Lightweight architecture optimized for binary water detection",
        "   🌊 FloodViT: Efficient transformer for rapid flood assessment",
    )
else:
    model_notes = (
        "\n💡 Land-Cover Models:",
        "   🌍 SatelliteCNN: Full-scale CNN for complex multi-class land-cover",
        "   🌍 SatelliteViT: Full transformer for global land-use patterns",
    )
for note in model_notes:
    print(note)


In [None]:
# Quick demo inference to show CNN feature maps and ViT attention.
# Takes the first validation batch, pushes a single image through each model in
# eval mode with gradient tracking disabled, then asks each model for its
# feature maps / attention weights so the following cells can plot them.
sample_data, sample_labels = next(iter(val_loader))
sample_data = sample_data.to(device)

print("🔍 Analyzing one satellite image...")

# CNN Feature Maps
cnn_model.eval()
with torch.no_grad():
    # [:1] is a one-element slice, so the model still receives a batch.
    cnn_output = cnn_model(sample_data[:1])
    feature_maps = cnn_model.get_feature_maps()

# Guarded the same way as the ViT branch below: an empty or None result is
# reported instead of being passed to len().
if feature_maps:
    print(f"CNN captured {len(feature_maps)} feature map layers")
else:
    print("CNN feature map visualization not available")

# ViT Attention
vit_model.eval()
with torch.no_grad():
    vit_output = vit_model(sample_data[:1])
    attention_weights = vit_model.get_attention_weights()

if attention_weights:
    print(f"ViT captured attention from {len(attention_weights)} layers")
else:
    print("ViT attention visualization not available")


In [None]:
# CNN feature-map figures; skipped with a warning when nothing was captured.
if not feature_maps:
    print("⚠️ No CNN feature maps to display")
else:
    print("🧠 CNN Feature Maps - What does the CNN see?")
    cnn_visualizer = CNNVisualizer(cnn_model)

    # Feature-map plot (num_channels=6), followed by the activation-statistics plot.
    cnn_visualizer.plot_feature_maps(feature_maps, num_channels=6)
    cnn_visualizer.plot_activation_statistics(feature_maps)


In [None]:
# Visualize ViT Attention
# Plots attention for an early and the final layer, then the attention rollout.
# Skipped with a warning when no attention weights were captured.
if attention_weights:
    print("🤖 Vision Transformer Attention - Where does ViT focus?")
    vit_visualizer = TransformerVisualizer(vit_model)

    # The "early" layer is index 2 when at least three layers were captured;
    # shallower models fall back to their last layer rather than indexing out
    # of range. Presumably layer_idx indexes attention_weights — confirm in
    # TransformerVisualizer.plot_attention_maps.
    early_layer_idx = min(2, len(attention_weights) - 1)

    # Show attention maps from different layers
    print("🎯 Early layer attention:")
    vit_visualizer.plot_attention_maps(attention_weights, layer_idx=early_layer_idx)

    print("🎯 Final layer attention:")
    vit_visualizer.plot_attention_maps(attention_weights, layer_idx=-1)

    # Attention rollout; the returned map is kept for later inspection.
    print("🔍 Attention rollout - spatial attention map:")
    rollout_map = vit_visualizer.plot_attention_rollout(attention_weights)
else:
    print("⚠️ No ViT attention weights to display")


In [None]:
# Physics check exercise: prints discussion questions asking whether the feature
# maps and attention patterns shown above make hydrological sense.
physics_questions = (
    "Do CNN feature maps focus on water bodies, vegetation edges?",
    "Does ViT attention correlate with drainage networks?",
    "Are high-attention regions in topographically relevant areas?",
    "For flood detection: Does the model focus on low-lying areas?",
)

print("🔬 Physics Interpretability Check")
print("   💭 Questions to explore:")
for question in physics_questions:
    print(f"   • {question}")
print()
print("💡 Next step: Overlay predictions on DEM/terrain data")
print("   This helps validate that AI decisions align with physical processes!")

# Discussion prompts for workshop participants:
# - How would you validate model focus against known hydrology?
# - What additional data layers would improve model physics-awareness?
# - How might you incorporate process-based constraints?


In [None]:
# Optional Hugging Face demo. The import and the pipeline construction both sit
# inside the try, so any failure there only prints a notice.
print("🤗 HuggingFace Integration Demo")
try:
    from transformers import pipeline

    # Image-classification pipeline using the google/vit-base-patch16-224 model id.
    classifier = pipeline(
        "image-classification",
        model="google/vit-base-patch16-224",
    )

    for success_line in (
        "✅ HuggingFace ViT pipeline ready!",
        "💡 You can now classify any image with just a few lines of code!",
        "Example: predictions = classifier(your_image)",
    ):
        print(success_line)

except Exception as err:
    # Deliberate best-effort: report the reason and let the notebook continue.
    print(f"⚠️ HuggingFace pipeline demo not available: {err}")
    print("💡 In practice, you'd have access to thousands of pre-trained models!")


In [None]:
# 🏃‍♀️ Quick Challenges (If you finish early!)
# Each entry is (title, detail lines); the loop numbers them from 1 and prints a
# blank line before every challenge except the first.
# NOTE(review): Challenge 5 mentions CIFAR-10 although this notebook loads
# EuroSATDataset — confirm the wording against what the dataset class does.
challenges = (
    ("Architecture-Task Matching", (
        "Try switching between FloodCNN/FloodViT and SatelliteCNN/SatelliteViT",
        "- Set USE_FLOOD=True vs USE_FLOOD=False and compare model sizes",
        "- Notice how simpler tasks use simpler architectures!",
    )),
    ("Attention Analysis", (
        "Compare attention patterns from different ViT layers",
        "- Use vit_visualizer.plot_attention_maps() with different layer_idx",
    )),
    ("Transfer Learning", (
        "Implement fine-tuning with frozen base layers",
        "- Freeze ViT base, only train classification head",
    )),
    ("Ensemble Methods", (
        "Combine CNN and ViT predictions",
        "- Average their outputs for better performance",
    )),
    ("Real Dataset", (
        "Replace CIFAR-10 with actual EuroSAT dataset",
        "- Download from: https://github.com/phelber/EuroSAT",
    )),
    ("GFZ Data Integration", (
        "Download satellite data from GFZ's Geoportal",
        "- Use ogr2ogr to extract your study area",
        "- Test models on actual German catchments",
    )),
    ("Hybrid CNN-Physics Models", (
        "Combine AI outputs with process-based models",
        "- Use CNN land-cover as input to rainfall-runoff models",
        "- Constrain flood predictions with DEM flow directions",
    )),
)

for challenge_number, (challenge_title, detail_lines) in enumerate(challenges, start=1):
    separator = "" if challenge_number == 1 else "\n"
    print(f"{separator}🏆 Challenge {challenge_number}: {challenge_title}")
    for detail_line in detail_lines:
        print(f"   {detail_line}")

print("\n💡 Most importantly: Experiment and have fun! 🎉")
