In [2]:
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import layers, models
import geopandas as gpd
import ast
import os
import json
import joblib
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

# 1. Define Data Paths and Parameters
patch_sizes = [1, 5, 9, 13]
patch_size_folders = [f"/kaggle/input/treesatai-patch{p}x{p}" for p in patch_sizes]

# Define bands (as provided, no monthly suffixes based on GeoJSON sample)
bands = [
    'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12',
    'NDVI', 'EVI', 'SAVI', 'NDWI', 'DEM'
]
band_columns = bands  # 17 bands, no monthly data

# 2. Function to Load Data for a Given Patch Size
def load_data_for_patch_size(patch_size, folder_path):
    all_features = []
    all_labels = []
    invalid_samples = []
    # Verify available columns
    available_columns = None
    for file in os.listdir(folder_path):
        if file.endswith(".geojson"):
            gdf = gpd.read_file(os.path.join(folder_path, file))
            available_columns = [col for col in band_columns if col in gdf.columns]
            # Debugging: Print sample data for first few rows
            print(f"\nInspecting first 2 rows of {file}:")
            for idx in range(min(2, len(gdf))):
                print(f"Row {idx}:")
                for col in available_columns:
                    print(f"  {col}: {gdf[col].iloc[idx]}")
                print(f"  l3_species: {gdf['l3_species'].iloc[idx]}")
            break
    if not available_columns:
        raise ValueError(f"No valid GeoJSON files found in {folder_path}")
    print(f"Patch size {patch_size}: Available columns: {available_columns}")
    expected_shape = (patch_size, patch_size, len(available_columns))

    # Load all GeoJSON files
    for file in os.listdir(folder_path):
        if file.endswith(".geojson"):
            gdf = gpd.read_file(os.path.join(folder_path, file))
            for idx, row in gdf.iterrows():
                try:
                    if row['l3_species'] is None or not isinstance(row['l3_species'], str):
                        raise ValueError(f"Invalid label at row {idx}: {row['l3_species']}")
                    patches = []
                    for col in available_columns:
                        val = row[col]
                        if val is None:
                            raise ValueError(f"Null value for band {col}")
                        if isinstance(val, (list, np.ndarray)):
                            arr = np.array(val, dtype=np.float32).reshape(patch_size, patch_size)
                        elif isinstance(val, str):
                            try:
                                arr = np.array(json.loads(val), dtype=np.float32).reshape(patch_size, patch_size)
                            except json.JSONDecodeError:
                                arr = np.array(ast.literal_eval(val), dtype=np.float32).reshape(patch_size, patch_size)
                        elif isinstance(val, (float, np.float32, np.float64)):
                            arr = np.full((patch_size, patch_size), val, dtype=np.float32)
                        else:
                            raise ValueError(f"Unexpected data type for band {col}: {type(val)}")
                        patches.append(arr)

                    patch = np.stack(patches, axis=-1)
                    if patch.shape != expected_shape:
                        raise ValueError(f"Unexpected patch shape: {patch.shape}, expected {expected_shape}")
                    all_features.append(patch)
                    all_labels.append(row['l3_species'])
                except Exception as e:
                    invalid_samples.append((file, idx, str(e)))
                    continue

    if invalid_samples:
        print(f"\nPatch size {patch_size}: Skipped {len(invalid_samples)} invalid samples")
        for file, idx, error in invalid_samples[:5]:
            print(f"File: {file}, Row: {idx}, Error: {error}")
        # Save invalid samples to a file for inspection
        with open(f'invalid_samples_patch_size_{patch_size}.json', 'w') as f:
            json.dump(invalid_samples, f, indent=4)

    if not all_features:
        raise ValueError(f"No valid samples loaded for patch size {patch_size}. Check GeoJSON files.")
    X = np.array(all_features, dtype=np.float32)
    y = np.array(all_labels)
    print(f"Patch size {patch_size}: Loaded {len(all_features)} valid samples with shape {X.shape}")
    return X, y, invalid_samples, available_columns

# 3. Define CNN Model
def build_cnn(input_shape, num_classes):
    model = models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)) if input_shape[0] >= 5 else layers.Lambda(lambda x: x),
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax')
    ])
    return model

# 4. Data Augmentation
def get_data_augmentation(patch_size):
    if patch_size > 1:
        return tf.keras.Sequential([
            layers.RandomFlip("horizontal_and_vertical"),
            layers.RandomRotation(0.2),
        ])
    return lambda x: x  # Identity function for 1x1 patches (no augmentation)

# 5. Train and Evaluate for Each Patch Size
results = {}
label_encoder = LabelEncoder()
available_columns_dict = {}

for patch_size, folder_path in zip(patch_sizes, patch_size_folders):
    print(f"\nProcessing patch size: {patch_size}x{patch_size}")
    if not os.path.exists(folder_path):
        print(f"Folder {folder_path} does not exist.")
        continue

    # Load data
    try:
        X, y, invalid_samples, available_columns = load_data_for_patch_size(patch_size, folder_path)
        available_columns_dict[patch_size] = available_columns
    except Exception as e:
        print(f"Error loading data for patch size {patch_size}: {e}")
        continue

    # Handle NaN values
    print(f"\nPatch size {patch_size}: Checking for NaN values...")
    nan_mask = np.any(np.isnan(X), axis=(1, 2, 3))
    nan_count = np.sum(nan_mask)
    if nan_count > 0:
        print(f"Patch size {patch_size}: Removing {nan_count} samples with NaN values")
        valid_mask = ~nan_mask
        X = X[valid_mask]
        y = y[valid_mask]
        print(f"Patch size {patch_size}: New data shape after removing NaN: {X.shape}")

    # Normalize data
    print(f"Patch size {patch_size}: Normalizing data...")
    X_min = np.nanmin(X, axis=(0, 1, 2), keepdims=True)
    X_max = np.nanmax(X, axis=(0, 1, 2), keepdims=True)
    X = (X - X_min) / (X_max - X_min + 1e-6)
    print(f"Patch size {patch_size}: Data range after normalization: min={np.nanmin(X):.4f}, max={np.nanmax(X):.4f}")

    # Encode labels
    try:
        if not results:
            y_encoded = label_encoder.fit_transform(y)
        else:
            y_encoded = label_encoder.transform(y)
        y_onehot = tf.keras.utils.to_categorical(y_encoded)
        num_classes = len(label_encoder.classes_)
        print(f"Patch size {patch_size}: Data shape: {X.shape}, Number of classes: {num_classes}")
    except Exception as e:
        print(f"Error encoding labels for patch size {patch_size}: {e}")
        continue

    # Train-test-validation split
    try:
        X_train, X_test, y_train, y_test = train_test_split(X, y_onehot, test_size=0.15, random_state=42, stratify=y_onehot)
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1765, random_state=42, stratify=y_train)
        print(f"Patch size {patch_size}: Train shape: {X_train.shape}, Validation shape: {X_val.shape}, Test shape: {X_test.shape}")
    except Exception as e:
        print(f"Error splitting data for patch size {patch_size}: {e}")
        continue

    # Build and compile model
    try:
        model = build_cnn(input_shape=(patch_size, patch_size, len(available_columns)), num_classes=num_classes)
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    except Exception as e:
        print(f"Error building model for patch size {patch_size}: {e}")
        continue

    # Train model
    try:
        history = model.fit(
            get_data_augmentation(patch_size)(X_train), y_train,
            validation_data=(X_val, y_val),
            epochs=50,
            batch_size=16,  # Reduced to mitigate memory issues
            callbacks=[
                tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
                tf.keras.callbacks.ModelCheckpoint(f'best_model_patch_size_{patch_size}.keras', save_best_only=True)
            ],
            verbose=1
        )
    except Exception as e:
        print(f"Error training model for patch size {patch_size}: {e}")
        continue

    # Evaluate model
    try:
        test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
        print(f"Patch size {patch_size}: Test Accuracy: {test_accuracy:.4f}, Test Loss: {test_loss:.4f}")

        # Compute additional metrics
        y_pred = model.predict(X_test, verbose=0)
        y_pred_classes = np.argmax(y_pred, axis=1)
        y_test_classes = np.argmax(y_test, axis=1)

        # Classification report (precision, recall, F1-score)
        class_report = classification_report(y_test_classes, y_pred_classes, target_names=label_encoder.classes_, output_dict=True)
        print(f"\nPatch size {patch_size}: Classification Report:")
        print(classification_report(y_test_classes, y_pred_classes, target_names=label_encoder.classes_))

        # Confusion matrix
        cm = confusion_matrix(y_test_classes, y_pred_classes)
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Confusion Matrix: Patch Size {patch_size}x{patch_size}')
        plt.savefig(f'confusion_matrix_patch_size_{patch_size}.png')
        plt.close()

        # Store results
        results[f'patch_size_{patch_size}'] = {
            'patch_size': patch_size,
            'test_accuracy': test_accuracy,
            'test_loss': test_loss,
            'classification_report': class_report,
            'history': history.history
        }
    except Exception as e:
        print(f"Error evaluating model for patch size {patch_size}: {e}")
        continue

# 6. Compare Results
print("\nSummary of Results:")
for key, result in results.items():
    print(f"\nPatch Size {result['patch_size']}x{result['patch_size']}:")
    print(f"  Test Accuracy: {result['test_accuracy']:.4f}")
    print(f"  Test Loss: {result['test_loss']:.4f}")
    print(f"  Weighted Precision: {result['classification_report']['weighted avg']['precision']:.4f}")
    print(f"  Weighted Recall: {result['classification_report']['weighted avg']['recall']:.4f}")
    print(f"  Weighted F1-Score: {result['classification_report']['weighted avg']['f1-score']:.4f}")
    print(f"  Macro Precision: {result['classification_report']['macro avg']['precision']:.4f}")
    print(f"  Macro Recall: {result['classification_report']['macro avg']['recall']:.4f}")
    print(f"  Macro F1-Score: {result['classification_report']['macro avg']['f1-score']:.4f}")

# Plot accuracy comparison
plt.figure(figsize=(10, 6))
accuracies = [results[key]['test_accuracy'] for key in results]
patch_size_labels = [f"{results[key]['patch_size']}x{results[key]['patch_size']}" for key in results]
plt.bar(patch_size_labels, accuracies)
plt.xlabel('Patch Size')
plt.ylabel('Test Accuracy')
plt.title('Test Accuracy by Patch Size')
plt.tight_layout()
plt.savefig('accuracy_comparison_patch_sizes.png')
plt.close()

# Plot weighted F1-score comparison
plt.figure(figsize=(10, 6))
f1_scores = [results[key]['classification_report']['weighted avg']['f1-score'] for key in results]
plt.bar(patch_size_labels, f1_scores)
plt.xlabel('Patch Size')
plt.ylabel('Weighted F1-Score')
plt.title('Weighted F1-Score by Patch Size')
plt.tight_layout()
plt.savefig('f1_score_comparison_patch_sizes.png')
plt.close()

# 7. Save Label Encoder and Results
joblib.dump(label_encoder, 'label_encoder.pkl')
with open('results_summary_patch_sizes.json', 'w') as f:
    json.dump({k: {kk: vv for kk, vv in v.items() if kk != 'history'} for k, v in results.items()}, f, indent=4)


Processing patch size: 1x1

Inspecting first 2 rows of needleleaf_douglas fir_douglas fir_Patch1x1_2022-07.geojson:
Row 0:
  B1: None
  B2: None
  B3: None
  B4: None
  B5: None
  B6: None
  B7: None
  B8: None
  B8A: None
  B9: None
  B11: None
  B12: None
  NDVI: None
  EVI: None
  SAVI: None
  NDWI: None
  DEM: None
  l3_species: douglas fir
Row 1:
  B1: [[0.1298000067472458]]
  B2: [[0.12809999287128448]]
  B3: [[0.1462000012397766]]
  B4: [[0.13349999487400055]]
  B5: [[0.17399999499320984]]
  B6: [[0.3084999918937683]]
  B7: [[0.3635999858379364]]
  B8: [[0.38960000872612]]
  B8A: [[0.3939000070095062]]
  B9: [[0.3813999891281128]]
  B11: [[0.24650000035762787]]
  B12: [[0.16859999299049377]]
  NDVI: [[0.4895813763141632]]
  EVI: [[0.5205919569583571]]
  SAVI: [[0.3754765119992361]]
  NDWI: [[0.22496463358402252]]
  DEM: [[96.0]]
  l3_species: douglas fir
Patch size 1: Available columns: ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12', 'NDVI', 'EVI', '

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Patch size 1: Classification Report:
                precision    recall  f1-score   support

         alder       0.23      0.21      0.22        77
         birch       0.21      0.27      0.23       131
    black pine       0.58      0.62      0.60        34
        cherry       0.00      0.00      0.00         8
   douglas fir       0.68      0.57      0.62       188
   english oak       0.55      0.31      0.39       177
  european ash       0.33      0.39      0.36        57
european beech       0.39      0.62      0.48       120
european larch       0.17      0.09      0.11        35
japanese larch       0.35      0.49      0.41       128
        linden       0.00      0.00      0.00         7
 norway spruce       0.67      0.40      0.50       101
        poplar       0.00      0.00      0.00        11
       red oak       0.35      0.63      0.45       124
    scots pine       0.65      0.76      0.70       423
   sessile oak       0.60      0.03      0.07        87
    silve

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Patch size 5: Classification Report:
                precision    recall  f1-score   support

         alder       0.40      0.03      0.05        77
         birch       0.21      0.11      0.14       131
    black pine       0.63      0.56      0.59        34
        cherry       0.00      0.00      0.00         8
   douglas fir       0.77      0.59      0.67       188
   english oak       0.34      0.56      0.43       177
  european ash       0.57      0.07      0.12        57
european beech       0.38      0.72      0.50       120
european larch       0.44      0.11      0.18        35
japanese larch       0.39      0.20      0.26       128
        linden       0.00      0.00      0.00         7
 norway spruce       0.69      0.62      0.66       101
        poplar       0.00      0.00      0.00        11
       red oak       0.38      0.59      0.46       124
    scots pine       0.61      0.89      0.73       423
   sessile oak       0.54      0.43      0.47        87
    silve

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Patch size 9: Classification Report:
                precision    recall  f1-score   support

         alder       0.00      0.00      0.00        77
         birch       0.12      0.22      0.15       131
    black pine       0.26      0.50      0.34        34
        cherry       0.60      0.38      0.46         8
   douglas fir       0.68      0.63      0.65       188
   english oak       0.40      0.18      0.25       177
  european ash       0.00      0.00      0.00        57
european beech       0.44      0.51      0.47       120
european larch       0.50      0.03      0.05        35
japanese larch       0.26      0.05      0.09       128
        linden       0.00      0.00      0.00         7
 norway spruce       0.70      0.56      0.63       101
        poplar       0.00      0.00      0.00        11
       red oak       0.26      0.77      0.39       124
    scots pine       0.63      0.84      0.72       423
   sessile oak       0.69      0.10      0.18        87
    silve

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Patch size 13: Classification Report:
                precision    recall  f1-score   support

         alder       0.00      0.00      0.00        77
         birch       0.00      0.00      0.00       131
    black pine       0.00      0.00      0.00        34
        cherry       0.00      0.00      0.00         8
   douglas fir       0.71      0.41      0.52       188
   english oak       0.17      0.83      0.29       177
  european ash       0.00      0.00      0.00        57
european beech       0.00      0.00      0.00       120
european larch       0.00      0.00      0.00        35
japanese larch       0.00      0.00      0.00       128
        linden       0.00      0.00      0.00         7
 norway spruce       0.00      0.00      0.00       101
        poplar       0.00      0.00      0.00        11
       red oak       1.00      0.01      0.02       124
    scots pine       0.44      0.93      0.60       423
   sessile oak       0.00      0.00      0.00        87
    silv