In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import layers, models
import geopandas as gpd
import ast
import os
import json
import joblib
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

# 1. Define Data Paths and Parameters
time_periods = ['july-2022', 'august-2022', 'mar-oct-2022', 'all-2022']
patch_size = 5  # Fixed to 5x5 based on GEE code
data_folders = [
    '/kaggle/input/july2022',  # Update with actual path
    '/kaggle/input/august',  # Update with actual path
    '/kaggle/input/mar-oct',  # Update with actual path
    '/kaggle/input/all2022'  # Update with actual path
]

# Define bands
base_bands = [
    'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12',
    'NDVI', 'EVI', 'SAVI', 'NDWI', 'DEM'
]
time_period_configs = {
    'july-2022': {'months': [''], 'num_bands': 17},
    'august-2022': {'months': [''], 'num_bands': 17},
    'mar-oct-2022': {'months': ['', '_1', '_2', '_3', '_4', '_5', '_6', '_7'], 'num_bands': 136},
    'all-2022': {'months': ['', '_1', '_2', '_3', '_4', '_5', '_6', '_7', '_8', '_9', '_10', '_11'], 'num_bands': 204}
}

# 2. Function to Load Data for a Given Time Period
def load_data_for_time_period(time_period, folder_path):
    all_features = []
    all_labels = []
    invalid_samples = []
    # Get band columns for the time period
    band_columns = [band + month for month in time_period_configs[time_period]['months'] for band in base_bands]
    # Verify available columns
    available_columns = None
    for file in os.listdir(folder_path):
        if file.endswith(".geojson"):
            gdf = gpd.read_file(os.path.join(folder_path, file))
            available_columns = [col for col in band_columns if col in gdf.columns]
            # Debugging: Print sample data for first few rows
            print(f"\nInspecting first 2 rows of {file}:")
            for idx in range(min(2, len(gdf))):
                print(f"Row {idx}:")
                for col in available_columns[:5]:  # Print first 5 bands for brevity
                    print(f"  {col}: {gdf[col].iloc[idx]}")
                print(f"  l3_species: {gdf['l3_species'].iloc[idx]}")
            break
    if not available_columns:
        raise ValueError(f"No valid GeoJSON files found in {folder_path}")
    print(f"Time period {time_period}: Available columns: {len(available_columns)}/{len(band_columns)}")
    expected_shape = (patch_size, patch_size, len(available_columns))

    # Load all GeoJSON files
    for file in os.listdir(folder_path):
        if file.endswith(".geojson"):
            gdf = gpd.read_file(os.path.join(folder_path, file))
            for idx, row in gdf.iterrows():
                try:
                    if row['l3_species'] is None or not isinstance(row['l3_species'], str):
                        raise ValueError(f"Invalid label at row {idx}: {row['l3_species']}")
                    patches = []
                    for col in available_columns:
                        val = row[col]
                        if val is None:
                            raise ValueError(f"Null value for band {col}")
                        if isinstance(val, (list, np.ndarray)):
                            arr = np.array(val, dtype=np.float32).reshape(patch_size, patch_size)
                        elif isinstance(val, str):
                            try:
                                arr = np.array(json.loads(val), dtype=np.float32).reshape(patch_size, patch_size)
                            except json.JSONDecodeError:
                                arr = np.array(ast.literal_eval(val), dtype=np.float32).reshape(patch_size, patch_size)
                        elif isinstance(val, (float, np.float32, np.float64)):
                            arr = np.full((patch_size, patch_size), val, dtype=np.float32)
                        else:
                            raise ValueError(f"Unexpected data type for band {col}: {type(val)}")
                        patches.append(arr)

                    patch = np.stack(patches, axis=-1)
                    if patch.shape != expected_shape:
                        raise ValueError(f"Unexpected patch shape: {patch.shape}, expected {expected_shape}")
                    all_features.append(patch)
                    all_labels.append(row['l3_species'])
                except Exception as e:
                    invalid_samples.append((file, idx, str(e)))
                    continue

    if invalid_samples:
        print(f"\nTime period {time_period}: Skipped {len(invalid_samples)} invalid samples")
        for file, idx, error in invalid_samples[:5]:
            print(f"File: {file}, Row: {idx}, Error: {error}")
        with open(f'invalid_samples_{time_period}.json', 'w') as f:
            json.dump(invalid_samples, f, indent=4)

    if not all_features:
        raise ValueError(f"No valid samples loaded for time period {time_period}. Check GeoJSON files.")
    X = np.array(all_features, dtype=np.float32)
    y = np.array(all_labels)
    print(f"Time period {time_period}: Loaded {len(all_features)} valid samples with shape {X.shape}")
    print(f"Time period {time_period}: Unique labels: {np.unique(y).tolist()}")
    return X, y, invalid_samples, available_columns

# 3. Define CNN Model
def build_cnn(input_shape, num_classes):
    model = models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax')
    ])
    return model

# 4. Data Augmentation
def get_data_augmentation():
    return tf.keras.Sequential([
        layers.RandomFlip("horizontal_and_vertical"),
        layers.RandomRotation(0.2),
    ])

# 5. Train and Evaluate for Each Time Period
results = {}
label_encoders = {}  # Store separate LabelEncoder for each time period

for time_period, folder_path in zip(time_periods, data_folders):
    print(f"\nProcessing time period: {time_period}")
    if not os.path.exists(folder_path):
        print(f"Folder {folder_path} does not exist.")
        continue

    # Load data
    try:
        X, y, invalid_samples, available_columns = load_data_for_time_period(time_period, folder_path)
    except Exception as e:
        print(f"Error loading data for time period {time_period}: {e}")
        continue

    # Handle NaN values
    print(f"\nTime period {time_period}: Checking for NaN values...")
    nan_mask = np.any(np.isnan(X), axis=(1, 2, 3))
    nan_count = np.sum(nan_mask)
    if nan_count > 0:
        print(f"Time period {time_period}: Removing {nan_count} samples with NaN values")
        valid_mask = ~nan_mask
        X = X[valid_mask]
        y = y[valid_mask]
        print(f"Time period {time_period}: New data shape after removing NaN: {X.shape}")

    # Normalize data
    print(f"Time period {time_period}: Normalizing data...")
    X_min = np.nanmin(X, axis=(0, 1, 2), keepdims=True)
    X_max = np.nanmax(X, axis=(0, 1, 2), keepdims=True)
    X = (X - X_min) / (X_max - X_min + 1e-6)
    print(f"Time period {time_period}: Data range after normalization: min={np.nanmin(X):.4f}, max={np.nanmax(X):.4f}")

    # Encode labels
    try:
        label_encoder = LabelEncoder()
        y_encoded = label_encoder.fit_transform(y)
        y_onehot = tf.keras.utils.to_categorical(y_encoded)
        num_classes = len(label_encoder.classes_)
        label_encoders[time_period] = label_encoder  # Store encoder
        print(f"Time period {time_period}: Data shape: {X.shape}, Number of classes: {num_classes}")
        print(f"Time period {time_period}: Class names: {label_encoder.classes_.tolist()}")
    except Exception as e:
        print(f"Error encoding labels for time period {time_period}: {e}")
        continue

    # Train-test-validation split
    try:
        X_train, X_test, y_train, y_test = train_test_split(X, y_onehot, test_size=0.15, random_state=42, stratify=y_onehot)
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1765, random_state=42, stratify=y_train)
        print(f"Time period {time_period}: Train shape: {X_train.shape}, Validation shape: {X_val.shape}, Test shape: {X_test.shape}")
    except Exception as e:
        print(f"Error splitting data for time period {time_period}: {e}")
        continue

    # Build and compile model
    try:
        model = build_cnn(input_shape=(patch_size, patch_size, len(available_columns)), num_classes=num_classes)
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    except Exception as e:
        print(f"Error building model for time period {time_period}: {e}")
        continue

    # Train model
    try:
        history = model.fit(
            get_data_augmentation()(X_train), y_train,
            validation_data=(X_val, y_val),
            epochs=50,
            batch_size=16,
            callbacks=[
                tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
                tf.keras.callbacks.ModelCheckpoint(f'best_model_{time_period}.keras', save_best_only=True)
            ],
            verbose=1
        )
    except Exception as e:
        print(f"Error training model for time period {time_period}: {e}")
        continue

    # Evaluate model
    try:
        test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
        print(f"Time period {time_period}: Test Accuracy: {test_accuracy:.4f}, Test Loss: {test_loss:.4f}")

        # Compute additional metrics
        y_pred = model.predict(X_test, verbose=0)
        y_pred_classes = np.argmax(y_pred, axis=1)
        y_test_classes = np.argmax(y_test, axis=1)

        # Classification report
        class_report = classification_report(y_test_classes, y_pred_classes, target_names=label_encoder.classes_, output_dict=True)
        print(f"\nTime period {time_period}: Classification Report:")
        print(classification_report(y_test_classes, y_pred_classes, target_names=label_encoder.classes_))

        # Confusion matrix
        cm = confusion_matrix(y_test_classes, y_pred_classes)
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Confusion Matrix: Time Period {time_period}')
        plt.savefig(f'confusion_matrix_{time_period}.png')
        plt.close()

        # Store results
        results[time_period] = {
            'time_period': time_period,
            'test_accuracy': test_accuracy,
            'test_loss': test_loss,
            'classification_report': class_report,
            'history': history.history,
            'num_classes': num_classes,
            'class_names': label_encoder.classes_.tolist()
        }
    except Exception as e:
        print(f"Error evaluating model for time period {time_period}: {e}")
        continue

# 6. Compare Results
print("\nSummary of Results:")
for key, result in results.items():
    print(f"\nTime Period {result['time_period']}:")
    print(f"  Test Accuracy: {result['test_accuracy']:.4f}")
    print(f"  Test Loss: {result['test_loss']:.4f}")
    print(f"  Weighted Precision: {result['classification_report']['weighted avg']['precision']:.4f}")
    print(f"  Weighted Recall: {result['classification_report']['weighted avg']['recall']:.4f}")
    print(f"  Weighted F1-Score: {result['classification_report']['weighted avg']['f1-score']:.4f}")
    print(f"  Macro Precision: {result['classification_report']['macro avg']['precision']:.4f}")
    print(f"  Macro Recall: {result['classification_report']['macro avg']['recall']:.4f}")
    print(f"  Macro F1-Score: {result['classification_report']['macro avg']['f1-score']:.4f}")
    print(f"  Number of Classes: {result['num_classes']}")
    print(f"  Class Names: {result['class_names']}")

# Plot accuracy comparison
if results:
    plt.figure(figsize=(10, 6))
    completed_periods = [key for key in results]
    accuracies = [results[key]['test_accuracy'] for key in results]
    plt.bar(completed_periods, accuracies)
    plt.xlabel('Time Period')
    plt.ylabel('Test Accuracy')
    plt.title('Test Accuracy by Time Period')
    plt.tight_layout()
    plt.savefig('accuracy_comparison_time_periods.png')
    plt.close()

    # Plot weighted F1-score comparison
    plt.figure(figsize=(10, 6))
    f1_scores = [results[key]['classification_report']['weighted avg']['f1-score'] for key in results]
    plt.bar(completed_periods, f1_scores)
    plt.xlabel('Time Period')
    plt.ylabel('Weighted F1-Score')
    plt.title('Weighted F1-Score by Time Period')
    plt.tight_layout()
    plt.savefig('f1_score_comparison_time_periods.png')
    plt.close()

# 7. Save Label Encoders and Results
for time_period, encoder in label_encoders.items():
    joblib.dump(encoder, f'label_encoder_{time_period}.pkl')
with open('results_summary_time_periods.json', 'w') as f:
    json.dump({k: {kk: vv for kk, vv in v.items() if kk != 'history'} for k, v in results.items()}, f, indent=4)





2025-06-29 00:32:55.528389: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751157175.722394      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751157175.774939      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered



Processing time period: july-2022

Inspecting first 2 rows of needleleaf_pine_weymouth pinejuly-2022.geojson:
Row 0:
  B1: None
  B2: None
  B3: None
  B4: None
  B5: None
  l3_species: weymouth pine
Row 1:
  B1: None
  B2: None
  B3: None
  B4: None
  B5: None
  l3_species: weymouth pine
Time period july-2022: Available columns: 17/17

Time period july-2022: Skipped 25577 invalid samples
File: needleleaf_pine_weymouth pinejuly-2022.geojson, Row: 0, Error: Null value for band B1
File: needleleaf_pine_weymouth pinejuly-2022.geojson, Row: 1, Error: Null value for band B1
File: needleleaf_pine_weymouth pinejuly-2022.geojson, Row: 2, Error: Null value for band B1
File: needleleaf_pine_weymouth pinejuly-2022.geojson, Row: 3, Error: Null value for band B1
File: needleleaf_pine_weymouth pinejuly-2022.geojson, Row: 4, Error: Null value for band B1
Time period july-2022: Loaded 12330 valid samples with shape (12330, 5, 5, 17)
Time period july-2022: Unique labels: ['alder', 'birch', 'black pine

I0000 00:00:1751157214.457117      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Epoch 1/50


I0000 00:00:1751157219.843108      60 service.cc:148] XLA service 0x7f93a80044f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1751157219.843623      60 service.cc:156]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
I0000 00:00:1751157220.255995      60 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m 61/540[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m1s[0m 3ms/step - accuracy: 0.1765 - loss: 3.0117

I0000 00:00:1751157222.965892      60 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 10ms/step - accuracy: 0.2376 - loss: 2.5532 - val_accuracy: 0.1730 - val_loss: 2.4300
Epoch 2/50
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.2900 - loss: 2.2572 - val_accuracy: 0.1876 - val_loss: 3.6922
Epoch 3/50
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.3143 - loss: 2.1534 - val_accuracy: 0.1265 - val_loss: 4.3923
Epoch 4/50
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.3248 - loss: 2.0954 - val_accuracy: 0.1816 - val_loss: 2.8325
Epoch 5/50
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.3403 - loss: 2.0692 - val_accuracy: 0.3259 - val_loss: 2.1615
Epoch 6/50
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.3479 - loss: 2.0463 - val_accuracy: 0.3070 - val_loss: 2.1784
Epoch 7/50
[1m540/540[0m [32m━━━━━

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Time period july-2022: Classification Report:
                precision    recall  f1-score   support

         alder       0.00      0.00      0.00        77
         birch       0.15      0.12      0.14       131
    black pine       0.12      0.62      0.20        34
        cherry       0.00      0.00      0.00         8
   douglas fir       0.88      0.34      0.49       188
   english oak       0.30      0.62      0.40       177
  european ash       0.21      0.07      0.11        57
european beech       0.29      0.73      0.42       120
european larch       0.22      0.06      0.09        35
japanese larch       0.23      0.11      0.15       128
        linden       0.00      0.00      0.00         7
 norway spruce       0.25      0.83      0.38       101
        poplar       0.00      0.00      0.00        11
       red oak       0.75      0.05      0.09       124
    scots pine       0.61      0.43      0.51       423
   sessile oak       1.00      0.01      0.02        87


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Time period august-2022: Classification Report:
                precision    recall  f1-score   support

         alder       0.24      0.10      0.14       322
         birch       0.31      0.35      0.33       370
    black pine       0.50      0.08      0.14        62
        cherry       0.00      0.00      0.00        37
   douglas fir       0.60      0.57      0.58       328
   english oak       0.39      0.66      0.49       421
  european ash       0.51      0.25      0.34       330
european beech       0.47      0.85      0.61       714
european larch       0.47      0.19      0.27       171
japanese larch       0.93      0.33      0.49       242
        linden       0.00      0.00      0.00        24
 norway spruce       0.65      0.84      0.73       756
        poplar       1.00      0.02      0.03        58
       red oak       0.53      0.43      0.47       219
    scots pine       0.81      0.75      0.78       808
   sessile oak       0.42      0.39      0.41       31

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Time period mar-oct-2022: Classification Report:
                precision    recall  f1-score   support

         alder       0.42      0.77      0.55        53
         birch       0.47      0.38      0.42        55
    black pine       0.76      0.90      0.82        31
        cherry       0.00      0.00      0.00         1
   douglas fir       0.37      0.97      0.54        70
   english oak       0.53      0.53      0.53       106
  european ash       0.65      0.31      0.42        36
european beech       0.65      0.55      0.59        44
european larch       1.00      0.33      0.50        18
japanese larch       0.88      0.65      0.75        55
        linden       1.00      0.33      0.50         6
 norway spruce       0.75      0.67      0.70        57
        poplar       1.00      0.14      0.25         7
       red oak       0.44      0.68      0.53        41
    scots pine       0.83      0.63      0.71       252
   sessile oak       0.45      0.15      0.22        

In [2]:
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import layers, models
import geopandas as gpd
import ast
import os
import json
import joblib
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter

# 1. Define Parameters
time_period = 'all-2022'
patch_size = 5  # Fixed to 5x5 based on GEE code
data_folder = '/kaggle/input/all2022'  # Update with actual path

# Define bands
base_bands = [
    'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12',
    'NDVI', 'EVI', 'SAVI', 'NDWI', 'DEM'
]
time_period_config = {
    'all-2022': {
        'months': ['', '_1', '_2', '_3', '_4', '_5', '_6', '_7', '_8', '_9', '_10', '_11'],
        'num_bands': 204  # 17 bands × 12 months
    }
}

# 2. Function to Load Data
def load_data(folder_path):
    all_features = []
    all_labels = []
    invalid_samples = []
    # Get band columns
    band_columns = [band + month for month in time_period_config[time_period]['months'] for band in base_bands]
    # Verify available columns
    available_columns = None
    for file in os.listdir(folder_path):
        if file.endswith(".geojson"):
            gdf = gpd.read_file(os.path.join(folder_path, file))
            available_columns = [col for col in band_columns if col in gdf.columns]
            # Debugging: Print sample data for first few rows
            print(f"\nInspecting first 2 rows of {file}:")
            for idx in range(min(2, len(gdf))):
                print(f"Row {idx}:")
                for col in available_columns[:5]:  # Print first 5 bands
                    print(f"  {col}: {gdf[col].iloc[idx]}")
                print(f"  l3_species: {gdf['l3_species'].iloc[idx]}")
            break
    if not available_columns:
        raise ValueError(f"No valid GeoJSON files found in {folder_path}")
    print(f"Available columns: {len(available_columns)}/{len(band_columns)}")
    expected_shape = (patch_size, patch_size, len(available_columns))

    # Load all GeoJSON files
    for file in os.listdir(folder_path):
        if file.endswith(".geojson"):
            gdf = gpd.read_file(os.path.join(folder_path, file))
            for idx, row in gdf.iterrows():
                try:
                    if row['l3_species'] is None or not isinstance(row['l3_species'], str):
                        raise ValueError(f"Invalid label at row {idx}: {row['l3_species']}")
                    patches = []
                    for col in available_columns:
                        val = row[col]
                        if val is None:
                            raise ValueError(f"Null value for band {col}")
                        if isinstance(val, (list, np.ndarray)):
                            arr = np.array(val, dtype=np.float32).reshape(patch_size, patch_size)
                        elif isinstance(val, str):
                            try:
                                arr = np.array(json.loads(val), dtype=np.float32).reshape(patch_size, patch_size)
                            except json.JSONDecodeError:
                                arr = np.array(ast.literal_eval(val), dtype=np.float32).reshape(patch_size, patch_size)
                        elif isinstance(val, (float, np.float32, np.float64)):
                            arr = np.full((patch_size, patch_size), val, dtype=np.float32)
                        else:
                            raise ValueError(f"Unexpected data type for band {col}: {type(val)}")
                        patches.append(arr)

                    patch = np.stack(patches, axis=-1)
                    if patch.shape != expected_shape:
                        raise ValueError(f"Unexpected patch shape: {patch.shape}, expected {expected_shape}")
                    all_features.append(patch)
                    all_labels.append(row['l3_species'])
                except Exception as e:
                    invalid_samples.append((file, idx, str(e)))
                    continue

    if invalid_samples:
        print(f"\nSkipped {len(invalid_samples)} invalid samples")
        for file, idx, error in invalid_samples[:5]:
            print(f"File: {file}, Row: {idx}, Error: {error}")
        with open('invalid_samples_all-2022.json', 'w') as f:
            json.dump(invalid_samples, f, indent=4)

    if not all_features:
        raise ValueError("No valid samples loaded. Check GeoJSON files.")
    X = np.array(all_features, dtype=np.float32)
    y = np.array(all_labels)
    print(f"Loaded {len(all_features)} valid samples with shape {X.shape}")
    print(f"Unique labels: {np.unique(y).tolist()}")
    print(f"Label distribution: {dict(Counter(y))}")
    return X, y, invalid_samples, available_columns

# 3. Define CNN Model
def build_cnn(input_shape, num_classes):
    model = models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax')
    ])
    return model

# 4. Data Augmentation
def get_data_augmentation():
    return tf.keras.Sequential([
        layers.RandomFlip("horizontal_and_vertical"),
        layers.RandomRotation(0.2),
    ])

# 5. Main Processing
print(f"\nProcessing time period: {time_period}")
if not os.path.exists(data_folder):
    raise ValueError(f"Folder {data_folder} does not exist.")

# Load data
try:
    X, y, invalid_samples, available_columns = load_data(data_folder)
except Exception as e:
    raise ValueError(f"Error loading data: {e}")

# Handle NaN values
print("\nChecking for NaN values...")
nan_mask = np.any(np.isnan(X), axis=(1, 2, 3))
nan_count = np.sum(nan_mask)
if nan_count > 0:
    print(f"Removing {nan_count} samples with NaN values")
    valid_mask = ~nan_mask
    X = X[valid_mask]
    y = y[valid_mask]
    print(f"New data shape after removing NaN: {X.shape}")

# Normalize data
print("Normalizing data...")
X_min = np.nanmin(X, axis=(0, 1, 2), keepdims=True)
X_max = np.nanmax(X, axis=(0, 1, 2), keepdims=True)
X = (X - X_min) / (X_max - X_min + 1e-6)
print(f"Data range after normalization: min={np.nanmin(X):.4f}, max={np.nanmax(X):.4f}")

# Encode labels
try:
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)
    y_onehot = tf.keras.utils.to_categorical(y_encoded)
    num_classes = len(label_encoder.classes_)
    print(f"Data shape: {X.shape}, Number of classes: {num_classes}")
    print(f"Class names: {label_encoder.classes_.tolist()}")
except Exception as e:
    raise ValueError(f"Error encoding labels: {e}")

# Train-test-validation split
try:
    X_train, X_test, y_train, y_test = train_test_split(X, y_onehot, test_size=0.15, random_state=42, stratify=y_onehot)
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1765, random_state=42, stratify=y_train)
    print(f"Train shape: {X_train.shape}, Validation shape: {X_val.shape}, Test shape: {X_test.shape}")
except Exception as e:
    raise ValueError(f"Error splitting data: {e}")

# Build and compile model
try:
    model = build_cnn(input_shape=(patch_size, patch_size, len(available_columns)), num_classes=num_classes)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
except Exception as e:
    raise ValueError(f"Error building model: {e}")

# Train model
try:
    history = model.fit(
        get_data_augmentation()(X_train), y_train,
        validation_data=(X_val, y_val),
        epochs=50,
        batch_size=16,
        callbacks=[
            tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
            tf.keras.callbacks.ModelCheckpoint('best_model_all-2022.keras', save_best_only=True)
        ],
        verbose=1
    )
except Exception as e:
    raise ValueError(f"Error training model: {e}")

# Evaluate model
try:
    test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
    print(f"Test Accuracy: {test_accuracy:.4f}, Test Loss: {test_loss:.4f}")

    # Compute additional metrics
    y_pred = model.predict(X_test, verbose=0)
    y_pred_classes = np.argmax(y_pred, axis=1)
    y_test_classes = np.argmax(y_test, axis=1)

    # Classification report with explicit labels
    labels = np.arange(len(label_encoder.classes_))
    class_report = classification_report(y_test_classes, y_pred_classes, labels=labels, target_names=label_encoder.classes_, output_dict=True)
    print(f"\nClassification Report:")
    print(classification_report(y_test_classes, y_pred_classes, labels=labels, target_names=label_encoder.classes_))

    # Confusion matrix
    cm = confusion_matrix(y_test_classes, y_pred_classes, labels=labels)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix: all-2022')
    plt.savefig('confusion_matrix_all-2022.png')
    plt.close()

    # Store results
    results = {
        'time_period': time_period,
        'test_accuracy': test_accuracy,
        'test_loss': test_loss,
        'classification_report': class_report,
        'history': history.history,
        'num_classes': num_classes,
        'class_names': label_encoder.classes_.tolist()
    }
except Exception as e:
    raise ValueError(f"Error evaluating model: {e}")

# 6. Save Results
print("\nSummary of Results:")
print(f"Test Accuracy: {results['test_accuracy']:.4f}")
print(f"Test Loss: {results['test_loss']:.4f}")
print(f"Weighted Precision: {results['classification_report']['weighted avg']['precision']:.4f}")
print(f"Weighted Recall: {results['classification_report']['weighted avg']['recall']:.4f}")
print(f"Weighted F1-Score: {results['classification_report']['weighted avg']['f1-score']:.4f}")
print(f"Macro Precision: {results['classification_report']['macro avg']['precision']:.4f}")
print(f"Macro Recall: {results['classification_report']['macro avg']['recall']:.4f}")
print(f"Macro F1-Score: {results['classification_report']['macro avg']['f1-score']:.4f}")
print(f"Number of Classes: {results['num_classes']}")
print(f"Class Names: {results['class_names']}")

joblib.dump(label_encoder, 'label_encoder_all-2022.pkl')
with open('results_summary_all-2022.json', 'w') as f:
    json.dump({k: v for k, v in results.items() if k != 'history'}, f, indent=4)


Processing time period: all-2022

Inspecting first 2 rows of needleleaf_pine_weymouth pineall-2022.geojson:
Row 0:
  B1: None
  B2: None
  B3: None
  B4: None
  B5: None
  l3_species: weymouth pine
Row 1:
  B1: None
  B2: None
  B3: None
  B4: None
  B5: None
  l3_species: weymouth pine
Available columns: 204/204

Skipped 31805 invalid samples
File: needleleaf_pine_weymouth pineall-2022.geojson, Row: 0, Error: Null value for band B1
File: needleleaf_pine_weymouth pineall-2022.geojson, Row: 1, Error: Null value for band B1
File: needleleaf_pine_weymouth pineall-2022.geojson, Row: 2, Error: Null value for band B1
File: needleleaf_pine_weymouth pineall-2022.geojson, Row: 3, Error: Null value for band B1
File: needleleaf_pine_weymouth pineall-2022.geojson, Row: 4, Error: Null value for band B1
Loaded 2676 valid samples with shape (2676, 5, 5, 204)
Unique labels: ['alder', 'birch', 'black pine', 'cherry', 'douglas fir', 'english oak', 'european ash', 'european beech', 'european larch', 'ja

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))



Classification Report:
                precision    recall  f1-score   support

         alder       0.77      0.87      0.82        31
         birch       0.12      0.06      0.08        17
    black pine       0.00      0.00      0.00         0
        cherry       0.00      0.00      0.00         1
   douglas fir       0.94      0.94      0.94        33
   english oak       0.74      0.80      0.77        54
  european ash       0.61      0.94      0.74        18
european beech       0.59      0.77      0.67        22
european larch       0.44      0.62      0.52        13
japanese larch       0.67      0.53      0.59        15
        linden       0.60      0.60      0.60         5
 norway spruce       0.78      0.50      0.61        36
        poplar       0.00      0.00      0.00         2
       red oak       1.00      0.08      0.14        13
    scots pine       0.85      0.96      0.90       116
   sessile oak       0.53      0.47      0.50        17
    silver fir       0.