# EuroSAT Dataset Exploration

This notebook explores the EuroSAT satellite-image dataset, used here for land-use classification reduced to a binary City vs. Farmland task.

## Objectives:
1. Load and visualize sample images
2. Analyze class distribution
3. Examine image statistics
4. Preview data augmentation

In [None]:
import sys
sys.path.append('../src')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import datasets

from dataset import (
    BinaryEuroSATDataset, get_dataloaders, get_transforms,
    EUROSAT_CLASSES, BINARY_CLASS_NAMES, CLASS_TO_BINARY
)

# Plot styling shared by every figure in this notebook
sns.set_style(style='whitegrid')
plt.rcParams.update({'figure.figsize': (12, 8)})

## 1. Load Dataset

In [None]:
# Load EuroSAT with the evaluation (non-augmenting) transform; download=True
# fetches it into ../data if it is not already present.
print("Loading EuroSAT dataset...")
transform = get_transforms(is_training=False)
eurosat_dataset = datasets.EuroSAT(
    root='../data',
    download=True,
    transform=transform,
)

# Wrap it for the binary task; the size difference is what the wrapper excludes
binary_dataset = BinaryEuroSATDataset(eurosat_dataset)

n_total = len(eurosat_dataset)
n_binary = len(binary_dataset)
print(f"Total images: {n_total}")
print(f"Binary dataset size: {n_binary}")
print(f"Excluded images: {n_total - n_binary}")

## 2. Class Distribution

### Original 10-Class Distribution

In [None]:
# Count samples in each original EuroSAT class.
# Labels are read from `eurosat_dataset.targets` (torchvision's EuroSAT is an
# ImageFolder) instead of indexing `eurosat_dataset[idx]`, which would decode and
# transform every image just to read its label.
original_class_counts = {cls: 0 for cls in EUROSAT_CLASSES}
for target in eurosat_dataset.targets:
    # NOTE(review): assumes EUROSAT_CLASSES is ordered like torchvision's class
    # indices (eurosat_dataset.classes) — confirm in dataset.py.
    original_class_counts[EUROSAT_CLASSES[target]] += 1

# Plot: classes kept by the binary mapping are green, excluded ones gray
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
classes = list(original_class_counts.keys())
counts = list(original_class_counts.values())
colors = ['green' if c in CLASS_TO_BINARY else 'gray' for c in classes]

ax.barh(classes, counts, color=colors, alpha=0.7)
ax.set_xlabel('Number of Samples', fontsize=12)
ax.set_title('EuroSAT Original Class Distribution\n(Green = Used, Gray = Excluded)',
             fontsize=14, fontweight='bold')
ax.invert_yaxis()  # first class at the top

# Value labels just past the end of each bar
for i, count in enumerate(counts):
    ax.text(count + 50, i, str(count), va='center', fontsize=10)

plt.tight_layout()
plt.show()

print("\nSample counts by original class:")
for cls, count in original_class_counts.items():
    status = "[USED]" if cls in CLASS_TO_BINARY else "[EXCLUDED]"
    print(f"  {status} {cls:<25} {count:>5}")

### Binary Class Distribution

In [None]:
# Get binary distribution
binary_distribution = binary_dataset.get_class_distribution()

# Plot
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
classes = list(binary_distribution.keys())
counts = list(binary_distribution.values())
colors = ['#3498db', '#2ecc71']

bars = ax.bar(classes, counts, color=colors, alpha=0.7, edgecolor='black')
ax.set_ylabel('Number of Samples', fontsize=12)
ax.set_title('Binary Class Distribution', fontsize=14, fontweight='bold')

# Add value labels and percentages
total = sum(counts)
for bar, count in zip(bars, counts):
    height = bar.get_height()
    percentage = (count / total) * 100
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{count}\n({percentage:.1f}%)',
            ha='center', va='bottom', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

print("\nBinary class distribution:")
for cls, count in binary_distribution.items():
    percentage = (count / total) * 100
    print(f"  {cls:<10} {count:>5} ({percentage:.1f}%)")
print(f"\nTotal: {total}")

## 3. Sample Images

### City Samples

In [None]:
# Collect the first City samples (binary label 0), one per grid panel.
# Scanning stops as soon as the grid is full: indexing `binary_dataset` returns
# the transformed image tensor, so scanning every index before slicing would
# load and transform the whole dataset just to keep 12 images.
n_rows, n_cols = 3, 4
city_indices = []
for i in range(len(binary_dataset)):
    if binary_dataset[i][1] == 0:
        city_indices.append(i)
        if len(city_indices) == n_rows * n_cols:
            break

# Denormalization parameters (also used by later cells in this notebook).
# NOTE(review): these are the ImageNet mean/std; assumes get_transforms()
# normalizes with the same values — confirm in dataset.py.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Plot
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 9))
fig.suptitle('City Samples', fontsize=16, fontweight='bold')

# Hide every panel up front so unused slots (fewer matches than panels) stay blank
for ax in axes.flat:
    ax.axis('off')

for idx, ax in zip(city_indices, axes.flat):
    img = binary_dataset[idx][0]

    # Denormalize: CHW tensor -> HWC array clipped to [0, 1] for imshow
    img_np = img.numpy().transpose(1, 2, 0)
    img_np = np.clip(std * img_np + mean, 0, 1)

    ax.imshow(img_np)
    ax.set_title(f'Sample {idx}', fontsize=10)

plt.tight_layout()
plt.show()

### Farmland Samples

In [None]:
# Collect the first Farmland samples (binary label 1), one per grid panel.
# Scanning stops as soon as the grid is full: indexing `binary_dataset` returns
# the transformed image tensor, so scanning every index before slicing would
# load and transform the whole dataset just to keep 12 images.
n_rows, n_cols = 3, 4
farmland_indices = []
for i in range(len(binary_dataset)):
    if binary_dataset[i][1] == 1:
        farmland_indices.append(i)
        if len(farmland_indices) == n_rows * n_cols:
            break

# Plot (uses the module-level `mean`/`std` denormalization constants defined
# earlier in the notebook)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 9))
fig.suptitle('Farmland Samples', fontsize=16, fontweight='bold')

# Hide every panel up front so unused slots (fewer matches than panels) stay blank
for ax in axes.flat:
    ax.axis('off')

for idx, ax in zip(farmland_indices, axes.flat):
    img = binary_dataset[idx][0]

    # Denormalize: CHW tensor -> HWC array clipped to [0, 1] for imshow
    img_np = img.numpy().transpose(1, 2, 0)
    img_np = np.clip(std * img_np + mean, 0, 1)

    ax.imshow(img_np)
    ax.set_title(f'Sample {idx}', fontsize=10)

plt.tight_layout()
plt.show()

## 4. Image Statistics

In [None]:
# Per-channel pixel statistics of the normalized model inputs, computed on a
# subset of the binary dataset for speed.
# The subset is drawn at random (seeded for reproducibility) rather than taking
# the first `sample_size` indices. torchvision's ImageFolder orders samples by
# class, and BinaryEuroSATDataset presumably keeps that order (verify). A prefix
# would therefore cover only the first class or two and bias the statistics.
STATS_SEED = 42
sample_size = min(1000, len(binary_dataset))
rng = np.random.default_rng(STATS_SEED)
sample_indices = np.sort(rng.choice(len(binary_dataset), size=sample_size, replace=False))

channel_names = ['R', 'G', 'B']
pixel_chunks = {name: [] for name in channel_names}

print(f"Computing statistics on {sample_size} samples...")
for i in sample_indices:
    img_np = binary_dataset[int(i)][0].numpy()
    for c, name in enumerate(channel_names):
        pixel_chunks[name].append(img_np[c].flatten())

# One flat array per channel; compute the summary stats once and reuse them
pixel_values = {name: np.concatenate(chunks) for name, chunks in pixel_chunks.items()}
channel_stats = {name: (vals.mean(), vals.std()) for name, vals in pixel_values.items()}

# Plot distributions (note: `mean_val`/`std_val` avoid clobbering the global
# `mean`/`std` denormalization constants used by later cells)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, name, color in zip(axes, channel_names, ['red', 'green', 'blue']):
    mean_val, std_val = channel_stats[name]
    ax.hist(pixel_values[name], bins=50, color=color, alpha=0.7, edgecolor='black')
    ax.set_title(f'{name} Channel Distribution', fontsize=12, fontweight='bold')
    ax.set_xlabel('Normalized Pixel Value', fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.axvline(mean_val, color='black', linestyle='--', linewidth=2,
               label=f'Mean: {mean_val:.3f}')
    ax.legend()

plt.tight_layout()
plt.show()

print("\nChannel statistics (normalized):")
for name in channel_names:
    mean_val, std_val = channel_stats[name]
    print(f"  {name}: Mean={mean_val:.4f}, "
          f"Std={std_val:.4f}")

## 5. Data Augmentation Preview

In [None]:
# Second view of EuroSAT using the training-time (augmenting) transforms
transform_aug = get_transforms(is_training=True)
eurosat_aug = datasets.EuroSAT(root='../data', download=False, transform=transform_aug)
binary_aug = BinaryEuroSATDataset(eurosat_aug)

# Seed torch's global RNG so the augmented examples are the same on every re-run.
# NOTE(review): assumes the training transforms draw their randomness from
# torch's RNG (as torchvision transforms do) — confirm in dataset.py.
AUG_PREVIEW_SEED = 42
torch.manual_seed(AUG_PREVIEW_SEED)

# One sample: unaugmented (eval transform) vs. several augmented draws
sample_idx = 0
original_img, label = binary_dataset[sample_idx]

# 3x3 grid: original in the centre cell, augmented versions in the 8 around it
positions = [(row, col) for row in range(3) for col in range(3) if (row, col) != (1, 1)]
aug_images = [binary_aug[sample_idx][0] for _ in positions]


def _to_display(img_tensor):
    """Undo normalization and turn a CHW tensor into an HWC array clipped to [0, 1].

    Uses the module-level `mean`/`std` denormalization constants defined earlier
    in the notebook.
    """
    img_np = img_tensor.numpy().transpose(1, 2, 0)
    return np.clip(std * img_np + mean, 0, 1)


fig, axes = plt.subplots(3, 3, figsize=(10, 10))
# int() so the lookup works whether the label is a Python int or a 0-d tensor
fig.suptitle(f'Data Augmentation Examples - {BINARY_CLASS_NAMES[int(label)]}',
             fontsize=14, fontweight='bold')

# Original (center)
axes[1, 1].imshow(_to_display(original_img))
axes[1, 1].set_title('Original', fontsize=12, fontweight='bold', color='green')
axes[1, 1].axis('off')

# Augmented (surrounding)
for img, (row, col) in zip(aug_images, positions):
    axes[row, col].imshow(_to_display(img))
    axes[row, col].set_title('Augmented', fontsize=10)
    axes[row, col].axis('off')

plt.tight_layout()
plt.show()

## 6. DataLoader Test

In [None]:
# Smoke-test the dataloaders by pulling one training batch and checking its shape
print("Creating dataloaders...")
train_loader, val_loader, test_loader = get_dataloaders(batch_size=32, num_workers=0)

# Get a batch
images, labels = next(iter(train_loader))

print("\nBatch information:")
print(f"  Images shape: {images.shape}")
print(f"  Labels shape: {labels.shape}")
print(f"  Labels in batch: {labels.tolist()}")
print(f"  City count: {(labels == 0).sum().item()}")
print(f"  Farmland count: {(labels == 1).sum().item()}")

# Sanity checks so "successful" below is only printed for a well-formed batch.
# Per-sample images are CHW (see the denormalization cells), so a batch is 4-D.
assert images.ndim == 4, f"expected (batch, C, H, W) images, got shape {tuple(images.shape)}"
assert images.shape[0] == labels.shape[0], "images/labels batch size mismatch"
assert labels.shape[0] <= 32, "batch larger than the requested batch_size"
assert ((labels == 0) | (labels == 1)).all(), "labels must be 0 (City) or 1 (Farmland)"

print("\nDataLoader test successful!")

## Summary

This notebook demonstrated:
- ✅ EuroSAT dataset loading and binary mapping
- ✅ Class distribution analysis (original 10 classes and the binary City/Farmland split; see the printed percentages for the class balance)
- ✅ Visual inspection of City vs Farmland samples
- ✅ Image statistics and channel distributions
- ✅ Preview of the training-time data augmentations
- ✅ DataLoader functionality

The dataset is ready for model training!