In [1]:
# ********************************************************************************
# IMPORTANT: BEFORE RUNNING THIS CELL, PERFORM A "FACTORY RESET RUNTIME" (Colab)
# OR THE EQUIVALENT DEEPEST RESTART IN YOUR ENVIRONMENT (e.g., Kaggle Session Restart).
# THEN, RUN THIS CELL AS THE VERY FIRST CODE IN YOUR NOTEBOOK.
# ********************************************************************************
!pip install -U scikit-learn==1.3.2 imbalanced-learn==0.12.3
import pandas as pd
import numpy as np
import tensorflow as tf
import os
from collections import Counter
import ast
import geopandas as gpd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, precision_score, recall_score, f1_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras import layers, models
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import joblib

import gc


# Now, the imports should work if the environment is truly clean
try:
    from imblearn.over_sampling import SMOTE
    print("\nSuccessfully imported SMOTE.")
except ImportError as e:
    print(f"\nCRITICAL ERROR: Failed to import SMOTE even after aggressive reinstallation: {e}")
    print("This indicates a severe, persistent environment issue.")
    print("Please double-check that you performed a 'Factory reset runtime' (Colab) or equivalent.")
    exit()

# 1. Inspect and Load GeoJSON Files (Modified for Zero Imputation)
data_dir = "/kaggle/input/mar-oct"  # Replace with your folder path
all_features = []
all_labels = []
invalid_samples = []
invalid_bands = Counter()
species_counts = Counter()

# Updated bands list to include all relevant bands
bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12', 'NDVI', 'EVI', 'SAVI', 'NDWI', 'DEM']
months = ['', '_1', '_2', '_3', '_4', '_5', '_6', '_7']
band_columns = [band + month for month in months for band in bands]

# Inspect first file
first_file = os.path.join(data_dir, os.listdir(data_dir)[0]) if os.listdir(data_dir) else None
if first_file and first_file.endswith(".geojson"):
    gdf = gpd.read_file(first_file)
    print("Inspecting first 2 rows of first GeoJSON file:")
    for idx in range(min(2, len(gdf))):
        print(f"\nRow {idx}:")
        for band in ['B1', 'B2', 'B11', 'NDVI', 'DEM', 'B2_1', 'NDVI_7']:
            if band in gdf.columns:
                data = gdf[band].iloc[idx]
                try:
                    parsed_data = ast.literal_eval(data) if isinstance(data, str) else data
                    array = np.array(parsed_data, dtype=np.float32)
                    print(f"  Band {band}: shape={array.shape}, first few values={array.flatten()[:5]}")
                except (ValueError, SyntaxError, TypeError) as e:
                    print(f"  Band {band}: Error parsing/converting: {e}")
            else:
                print(f"  Band {band}: Not found in GeoJSON file")

# Load all GeoJSON files
total_samples_attempted = 0
for file in os.listdir(data_dir):
    if file.endswith(".geojson"):
        try:
            gdf = gpd.read_file(os.path.join(data_dir, file))
            print(f"Processing file: {file}, Rows: {len(gdf)}")
            total_samples_attempted += len(gdf)
            for idx, row in gdf.iterrows():
                try:
                    patch = []
                    for col in band_columns:
                        if col not in gdf.columns:
                            invalid_bands[col] += 1
                            array = np.zeros((5, 5), dtype=np.float32)  # Impute with zeros
                            patch.append(array)
                            continue
                        data = row[col]
                        if data is None or (isinstance(data, str) and data.lower() == 'none'):
                            invalid_bands[col] += 1
                            array = np.zeros((5, 5), dtype=np.float32)  # Impute with zeros
                            patch.append(array)
                            continue
                        try:
                            parsed_data = ast.literal_eval(data) if isinstance(data, str) else data
                            array = np.array(parsed_data, dtype=np.float32).reshape(5, 5)
                        except (ValueError, SyntaxError, TypeError) as e:
                            invalid_bands[col] += 1
                            array = np.zeros((5, 5), dtype=np.float32)  # Impute for parsing errors
                            patch.append(array)
                            continue
                        patch.append(array)
                    patch = np.stack(patch, axis=-1)
                    if patch.shape != (5, 5, 136):  # Expected shape: 17 bands * 8 months
                        raise ValueError(f"Unexpected patch shape: {patch.shape}")
                    all_features.append(patch)
                    all_labels.append(row['l3_species'])
                    species_counts[row['l3_species']] += 1
                except (ValueError, SyntaxError, TypeError) as e:
                    invalid_samples.append((file, idx, str(e)))
                    continue
        except Exception as e:
            print(f"Failed to process file {file}: {e}")
            continue

# Log invalid samples and bands
print(f"\nTotal samples attempted: {total_samples_attempted}")
print(f"Valid samples processed: {len(all_features)}")
if invalid_samples:
    print(f"\nSkipped {len(invalid_samples)} invalid samples:")
    for file, idx, error in invalid_samples:
        print(f"File: {file}, Row: {idx}, Error: {error}")
if invalid_bands:
    print("\nBands with None or missing values:")
    for band, count in invalid_bands.most_common():
        print(f"  {band}: {count} times")
print("\nValid samples per species:")
for species, count in species_counts.most_common():
    print(f"  {species}: {count}")

# Convert to NumPy arrays
if not all_features:
    print("\nError: No valid samples loaded. Using Random Forest with dummy data.")
    X_dummy = np.random.rand(100, 5*5*136)  # Updated for 136 channels
    y_dummy = np.random.randint(0, 5, 100)
    rf = RandomForestClassifier(n_estimators=100, random_state=42)
    rf.fit(X_dummy, y_dummy)
    print("Random Forest dummy accuracy:", rf.score(X_dummy, y_dummy))
    print("Please re-export data with updated GEE code.")
    exit()

X = np.array(all_features, dtype=np.float32)  # Shape: (N, 5, 5, 136)
y = np.array(all_labels)


Collecting scikit-learn==1.3.2
  Downloading scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting imbalanced-learn==0.12.3
  Downloading imbalanced_learn-0.12.3-py3-none-any.whl.metadata (8.3 kB)
Downloading scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m57.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading imbalanced_learn-0.12.3-py3-none-any.whl (258 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.3/258.3 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: scikit-learn, imbalanced-learn
  Attempting uninstall: scikit-learn
    Found existing installation: scikit-learn 1.2.2
    Uninstalling scikit-learn-1.2.2:
      Successfully uninstalled scikit-learn-1.2.2
  Attempting uninstall: imbalanced-learn
    Found existing installation: imbalanced-

2025-07-19 22:39:08.176419: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752964748.364226      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752964748.417591      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered



Successfully imported SMOTE.
Inspecting first 2 rows of first GeoJSON file:

Row 0:
  Band B1: shape=(5, 5), first few values=[0.1118 0.1158 0.1158 0.1158 0.1158]
  Band B2: shape=(5, 5), first few values=[0.11225 0.11335 0.11335 0.11955 0.11955]
  Band B11: shape=(5, 5), first few values=[0.17015 0.1891  0.1891  0.1891  0.1891 ]
  Band NDVI: shape=(5, 5), first few values=[0.27416557 0.2273243  0.2273243  0.28395182 0.28395182]
  Band DEM: shape=(5, 5), first few values=[88. 88. 88. 88. 88.]
  Band B2_1: shape=(), first few values=[nan]
  Band NDVI_7: shape=(), first few values=[nan]

Row 1:
  Band B1: shape=(5, 5), first few values=[0.125 0.124 0.124 0.124 0.124]
  Band B2: shape=(5, 5), first few values=[0.1239 0.1176 0.1176 0.1152 0.117 ]
  Band B11: shape=(5, 5), first few values=[0.1976 0.1797 0.1797 0.1797 0.1779]
  Band NDVI: shape=(5, 5), first few values=[0.40535372 0.36756483 0.36756483 0.36180538 0.34991294]
  Band DEM: shape=(5, 5), first few values=[94. 94. 93. 93. 93.]


In [2]:
import os
import ast
import numpy as np
import pandas as pd
import geopandas as gpd
from collections import Counter
from sklearn.metrics import classification_report, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import json
import joblib
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
import gc

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# 1. Enhanced Preprocessing Pipeline
def preprocess_data(X, y, test_size=0.3, validation_size=0.5):
    """Enhanced preprocessing with better memory management."""
    # Encode class labels
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)
    
    # Flatten features for initial split
    X_flat = X.reshape(X.shape[0], -1)
    
    # Initial train/temp split
    X_train_raw, X_temp, y_train_raw, y_temp = train_test_split(
        X_flat, y_encoded, test_size=test_size, stratify=y_encoded, random_state=42
    )
    
    # Apply SMOTE only to training set
    print("Applying SMOTE to balance training data...")
    smote = SMOTE(random_state=42, k_neighbors=3)
    X_train_resampled, y_train_resampled = smote.fit_resample(X_train_raw, y_train_raw)
    
    # Scale data
    print("Scaling features...")
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train_resampled)
    X_temp_scaled = scaler.transform(X_temp)
    
    # Split temp into val and test sets
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp_scaled, y_temp, test_size=validation_size, stratify=y_temp, random_state=42
    )
    
    return (X_train_scaled, X_val, X_test, 
            y_train_resampled, y_val, y_test, 
            label_encoder, scaler)

# 2. Advanced Transformer Architecture
class PositionalEncoding(layers.Layer):
    """Positional encoding for transformer."""
    def __init__(self, max_len, embed_dim):
        super(PositionalEncoding, self).__init__()
        self.pos_encoding = self.positional_encoding(max_len, embed_dim)
    
    def positional_encoding(self, max_len, embed_dim):
        pos = np.arange(max_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, embed_dim, 2) * -(np.log(10000.0) / embed_dim))
        
        pos_encoding = np.zeros((max_len, embed_dim))
        pos_encoding[:, 0::2] = np.sin(pos * div_term)
        pos_encoding[:, 1::2] = np.cos(pos * div_term)
        
        return tf.constant(pos_encoding, dtype=tf.float32)
    
    def call(self, x):
        return x + self.pos_encoding[:tf.shape(x)[1], :]

class MultiHeadSelfAttention(layers.Layer):
    """Enhanced multi-head self-attention with residual connections."""
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        
        self.projection_dim = embed_dim // num_heads
        self.query_dense = layers.Dense(embed_dim)
        self.key_dense = layers.Dense(embed_dim)
        self.value_dense = layers.Dense(embed_dim)
        self.combine_heads = layers.Dense(embed_dim)
        self.dropout = layers.Dropout(dropout)
        
    def attention(self, query, key, value):
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        weights = self.dropout(weights)
        output = tf.matmul(weights, value)
        return output
    
    def separate_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)
        
        query = self.separate_heads(query, batch_size)
        key = self.separate_heads(key, batch_size)
        value = self.separate_heads(value, batch_size)
        
        attention = self.attention(query, key, value)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        
        concat_attention = tf.reshape(attention, (batch_size, -1, self.embed_dim))
        output = self.combine_heads(concat_attention)
        return output

class TransformerBlock(layers.Layer):
    """Enhanced transformer block with improved normalization and residual connections."""
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.att = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.ffn = tf.keras.Sequential([
            layers.Dense(ff_dim, activation='gelu'),
            layers.Dropout(dropout),
            layers.Dense(embed_dim),
        ])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
    
    def call(self, inputs, training=None):
        attn_output = self.att(inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

def build_advanced_transformer(input_shape, num_classes, embed_dim=64, num_heads=8, 
                             ff_dim=128, num_blocks=3, dropout=0.15):
    """Build an advanced transformer model with multiple blocks."""
    inputs = layers.Input(shape=input_shape)
    
    x = layers.Dense(embed_dim, kernel_regularizer=regularizers.l2(0.001))(inputs)
    
    pos_encoding = PositionalEncoding(input_shape[0], embed_dim)
    x = pos_encoding(x)
    x = layers.Dropout(dropout)(x)
    
    for i in range(num_blocks):
        x = TransformerBlock(embed_dim, num_heads, ff_dim, dropout)(x)
    
    attention_weights = layers.Dense(1, activation='tanh')(x)
    attention_weights = layers.Softmax(axis=1)(attention_weights)
    x = layers.Multiply()([x, attention_weights])
    x = layers.GlobalAveragePooling1D()(x)
    
    x = layers.Dense(128, activation='gelu', kernel_regularizer=regularizers.l2(0.001))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout)(x)
    
    x = layers.Dense(64, activation='gelu', kernel_regularizer=regularizers.l2(0.001))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout)(x)
    
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    
    model = models.Model(inputs, outputs, name='advanced_transformer')
    return model

# 3. Enhanced Training Pipeline
def train_advanced_transformer(X_train, y_train, X_val, y_val, num_classes, 
                             sequence_length, class_weight_dict):
    """Train the advanced transformer with sophisticated callbacks."""
    
    stride = max(1, X_train.shape[1] // sequence_length)
    X_train_seq = X_train[:, ::stride][:, :sequence_length].reshape(-1, sequence_length, 1)
    X_val_seq = X_val[:, ::stride][:, :sequence_length].reshape(-1, sequence_length, 1)
    
    y_train_onehot = tf.keras.utils.to_categorical(y_train, num_classes)
    y_val_onehot = tf.keras.utils.to_categorical(y_val, num_classes)
    
    model = build_advanced_transformer(
        input_shape=(sequence_length, 1),
        num_classes=num_classes,
        embed_dim=64,
        num_heads=8,
        ff_dim=128,
        num_blocks=3,
        dropout=0.15
    )
    
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=0.001,
        weight_decay=0.01,
        beta_1=0.9,
        beta_2=0.999

    )
    
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy', 'top_k_categorical_accuracy']
    )
    
    callbacks = [
        EarlyStopping(
            monitor='val_accuracy',
            patience=15,
            restore_best_weights=True,
            verbose=1
        ),
        ModelCheckpoint(
            'best_advanced_transformer.keras',
            monitor='val_accuracy',
            save_best_only=True,
            save_weights_only=False,
            verbose=1
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=7,
            min_lr=1e-7,
            verbose=1
        ),
        tf.keras.callbacks.TerminateOnNaN()
    ]
    
    print(f"Training model with {model.count_params():,} parameters...")
    model.summary()
    
    history = model.fit(
        X_train_seq, y_train_onehot,
        validation_data=(X_val_seq, y_val_onehot),
        epochs=100,
        batch_size=8,
        class_weight=class_weight_dict,
        callbacks=callbacks,
        verbose=1
    )
    
    return model, history, X_val_seq, y_val_onehot

# 4. Enhanced Evaluation Function
def evaluate_model(model, X_test, y_test, label_encoder, sequence_length):
    """Comprehensive model evaluation with multiple metrics."""
    
    stride = max(1, X_test.shape[1] // sequence_length)
    X_test_seq = X_test[:, ::stride][:, :sequence_length].reshape(-1, sequence_length, 1)
    y_test_onehot = tf.keras.utils.to_categorical(y_test, len(label_encoder.classes_))
    
    test_loss, test_accuracy, test_top_k = model.evaluate(X_test_seq, y_test_onehot, verbose=0)
    
    y_pred_probs = model.predict(X_test_seq, verbose=0)
    y_pred = np.argmax(y_pred_probs, axis=1)
    
    report = classification_report(y_test, y_pred, target_names=label_encoder.classes_, 
                                 output_dict=True, zero_division=0)
    
    precision = precision_score(y_test, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_test, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)
    
    print(f"\nAdvanced Transformer Results:")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print(f"Test Top-K Accuracy: {test_top_k:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    
    # Generate and save confusion matrix
    cm = confusion_matrix(y_test, y_pred)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=label_encoder.classes_, 
                yticklabels=label_encoder.classes_, cmap='Blues')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Advanced Transformer - Confusion Matrix')
    plt.tight_layout()
    plt.savefig('advanced_transformer_confusion.png', dpi=150, bbox_inches='tight')
    plt.close()
    
    # Save confusion matrix as NumPy array
    np.save('advanced_transformer_confusion.npy', cm)
    
    # Save classification report
    with open('advanced_transformer_report.json', 'w') as f:
        json.dump(report, f, indent=4)
    
    print("Saved: advanced_transformer_confusion.png, advanced_transformer_confusion.npy, advanced_transformer_report.json")
    
    return report, y_pred_probs, y_pred

# 5. Main Training Pipeline
def main_training_pipeline(X, y):
    """Complete training pipeline with the advanced transformer."""
    
    print("Starting advanced transformer training pipeline...")
    
    (X_train_scaled, X_val, X_test, 
     y_train_resampled, y_val, y_test, 
     label_encoder, scaler) = preprocess_data(X, y)
    
    num_classes = len(label_encoder.classes_)
    sequence_length = min(512, X_train_scaled.shape[1] // 4)
    
    print(f"Data shapes - Train: {X_train_scaled.shape}, Val: {X_val.shape}, Test: {X_test.shape}")
    print(f"Number of classes: {num_classes}")
    print(f"Sequence length: {sequence_length}")
    
    class_weights = compute_class_weight('balanced', 
                                       classes=np.unique(y_train_resampled), 
                                       y=y_train_resampled)
    class_weight_dict = {i: class_weights[i] for i in range(len(class_weights))}
    
    model, history, X_val_seq, y_val_onehot = train_advanced_transformer(
        X_train_scaled, y_train_resampled, X_val, y_val, 
        num_classes, sequence_length, class_weight_dict
    )
    
    report, y_pred_probs, y_pred = evaluate_model(
        model, X_test, y_test, label_encoder, sequence_length
    )
    
    model.save('advanced_transformer_model.keras')
    joblib.dump(label_encoder, 'advanced_label_encoder.pkl')
    joblib.dump(scaler, 'advanced_scaler.pkl')
    
    with open('advanced_transformer_report.json', 'w') as f:
        json.dump(report, f, indent=4)
    
    print("\nTraining completed successfully!")
    print("Saved: advanced_transformer_model.keras, advanced_label_encoder.pkl, advanced_scaler.pkl")
    
    return model, history, label_encoder, scaler

# 6. Plot Training History
def plot_training_history(history):
    """Plot training and validation metrics."""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    
    ax1.plot(history.history['accuracy'], label='Training Accuracy')
    ax1.plot(history.history['val_accuracy'], label='Validation Accuracy')
    ax1.set_title('Model Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    ax1.legend()
    ax1.grid(True)
    
    ax2.plot(history.history['loss'], label='Training Loss')
    ax2.plot(history.history['val_loss'], label='Validation Loss')
    ax2.set_title('Model Loss')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    ax2.grid(True)
    
    if 'learning_rate' in history.history:
        ax3.plot(history.history['learning_rate'])
        ax3.set_title('Learning Rate')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_yscale('log')
        ax3.grid(True)
    
    if 'top_k_categorical_accuracy' in history.history:
        ax4.plot(history.history['top_k_categorical_accuracy'], label='Training Top-K')
        ax4.plot(history.history['val_top_k_categorical_accuracy'], label='Validation Top-K')
        ax4.set_title('Top-K Accuracy')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('Top-K Accuracy')
        ax4.legend()
        ax4.grid(True)
    
    plt.tight_layout()
    plt.savefig('advanced_transformer_training_history.png', dpi=150, bbox_inches='tight')
    plt.close()

# 7. Updated Enhanced Test Data Evaluation
def evaluate_on_final_test_data(model, label_encoder, scaler, sequence_length, test_data_dir="/kaggle/input/final-test-data"):
    """Evaluate the trained transformer model on final test data with proper preprocessing."""
    
    # Define band columns to match training data
    bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12', 'NDVI', 'EVI', 'SAVI', 'NDWI', 'DEM']
    months = ['', '_1', '_2', '_3', '_4', '_5', '_6', '_7']
    band_columns = [band + month for month in months for band in bands]
    
    test_features = []
    test_labels = []
    test_invalid_samples = []
    test_invalid_bands = Counter()
    total_samples_attempted = 0
    
    # Check if directory exists
    if not os.path.exists(test_data_dir):
        print(f"\nError: Test data directory {test_data_dir} does not exist.")
        return None, None, None
    
    geojson_files = [f for f in os.listdir(test_data_dir) if f.endswith(".geojson")]
    if not geojson_files:
        print(f"\nError: No GeoJSON files found in {test_data_dir}")
        return None, None, None
    
    print(f"\nFound {len(geojson_files)} GeoJSON files in {test_data_dir}")
    
    # Load all GeoJSON files
    for file in geojson_files:
        try:
            file_path = os.path.join(test_data_dir, file)
            gdf = gpd.read_file(file_path)
            print(f"Processing file: {file}, Rows: {len(gdf)}")
            total_samples_attempted += len(gdf)
            
            for idx, row in gdf.iterrows():
                try:
                    patch = []
                    for col in band_columns:
                        if col not in gdf.columns:
                            test_invalid_bands[col] += 1
                            array = np.zeros((5, 5), dtype=np.float32)
                            patch.append(array)
                            continue
                        
                        data = row[col]
                        if data is None or (isinstance(data, str) and data.lower() == 'none'):
                            test_invalid_bands[col] += 1
                            array = np.zeros((5, 5), dtype=np.float32)
                            patch.append(array)
                            continue
                        
                        try:
                            if isinstance(data, str):
                                parsed_data = ast.literal_eval(data)
                            else:
                                parsed_data = data
                            array = np.array(parsed_data, dtype=np.float32).reshape(5, 5)
                            if array.shape != (5, 5):
                                raise ValueError(f"Unexpected array shape for {col}: {array.shape}")
                        except (ValueError, SyntaxError, TypeError) as e:
                            test_invalid_bands[col] += 1
                            array = np.zeros((5, 5), dtype=np.float32)
                            patch.append(array)
                        
                        patch.append(array)
                    
                    patch = np.stack(patch, axis=-1)
                    if patch.shape != (5, 5, 136):
                        raise ValueError(f"Unexpected patch shape: {patch.shape}")
                    
                    test_features.append(patch)
                    test_labels.append(row['l3_species'])
                    
                except Exception as e:
                    test_invalid_samples.append((file, idx, str(e)))
                    continue
                    
        except Exception as e:
            print(f"Failed to process file {file}: {e}")
            continue
    
    print(f"\nTotal samples attempted: {total_samples_attempted}")
    print(f"Valid samples processed: {len(test_features)}")
    
    if not test_features:
        print("\nError: No valid test samples loaded. Cannot evaluate model.")
        return None, None, None
    
    if test_invalid_samples:
        print(f"\nSkipped {len(test_invalid_samples)} invalid test samples:")
        for file, idx, error in test_invalid_samples:
            print(f"File: {file}, Row: {idx}, Error: {error}")
    
    if test_invalid_bands:
        print("\nBands with missing/None/parsing issues in test data:")
        for band, count in test_invalid_bands.most_common():
            print(f"  {band}: {count} times")
    
    # Convert to NumPy arrays
    X_test_final = np.array(test_features, dtype=np.float32)  # Shape: (N, 5, 5, 136)
    y_test_final = np.array(test_labels)
    
    # Handle unknown labels
    try:
        known_classes = set(label_encoder.classes_)
        test_classes = set(y_test_final)
        unknown_labels = test_classes - known_classes
        
        if unknown_labels:
            print(f"\nWarning: Unknown labels in test data: {unknown_labels}")
            valid_mask = np.isin(y_test_final, list(known_classes))
            X_test_final = X_test_final[valid_mask]
            y_test_final = y_test_final[valid_mask]
            print(f"Filtered dataset size: {len(y_test_final)} samples")
        
        if len(y_test_final) == 0:
            print("Error: No valid labels found in test data after filtering.")
            return None, None, None
        
        y_test_final_encoded = label_encoder.transform(y_test_final)
        
    except Exception as e:
        print(f"Error in label encoding: {e}")
        return None, None, None
    
    # Preprocess test data
    try:
        # Flatten for scaling
        X_test_final_flat = X_test_final.reshape(X_test_final.shape[0], -1)  # Shape: (N, 5*5*136)
        X_test_final_scaled = scaler.transform(X_test_final_flat)  # Apply same scaler as training
        
        # Reshape for transformer input
        stride = max(1, X_test_final_scaled.shape[1] // sequence_length)
        X_test_final_seq = X_test_final_scaled[:, ::stride][:, :sequence_length].reshape(-1, sequence_length, 1)
        
    except Exception as e:
        print(f"Error in data preprocessing: {e}")
        return None, None, None
    
    # Evaluate model
    try:
        print("\nEvaluating model on final test data...")
        y_pred_final_probs = model.predict(X_test_final_seq, verbose=1)
        y_pred_final_classes = np.argmax(y_pred_final_probs, axis=1)
        
        # Calculate accuracy
        test_accuracy_final = (y_pred_final_classes == y_test_final_encoded).mean()
        print(f"\nFinal Test Data Accuracy: {test_accuracy_final:.4f}")
        
        # Calculate metrics
        precision_final = precision_score(y_test_final_encoded, y_pred_final_classes, 
                                        average='weighted', zero_division=0)
        recall_final = recall_score(y_test_final_encoded, y_pred_final_classes, 
                                  average='weighted', zero_division=0)
        f1_final = f1_score(y_test_final_encoded, y_pred_final_classes, 
                           average='weighted', zero_division=0)
        
        print(f"Precision (Final Test): {precision_final:.4f}")
        print(f"Recall (Final Test): {recall_final:.4f}")
        print(f"F1-Score (Final Test): {f1_final:.4f}")
        
        # Get unique labels in test data
        unique_test_labels = np.unique(y_test_final_encoded)
        unique_test_label_names = label_encoder.inverse_transform(unique_test_labels)
        
        # Classification report
        report_final = classification_report(
            y_test_final_encoded,
            y_pred_final_classes,
            labels=unique_test_labels,
            target_names=unique_test_label_names,
            output_dict=True,
            zero_division=0
        )
        
        print("\nClassification Report for Final Test Data:")
        print(json.dumps(report_final, indent=4))
        
        # Confusion matrix
        cm_final = confusion_matrix(y_test_final_encoded, y_pred_final_classes, 
                                   labels=unique_test_labels)
        
        plt.figure(figsize=(max(10, len(unique_test_labels)), max(8, len(unique_test_labels))))
        sns.heatmap(cm_final, annot=True, fmt='d', 
                   xticklabels=unique_test_label_names, 
                   yticklabels=unique_test_label_names, 
                   cmap='Blues', square=True)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Advanced Transformer - Final Test Data Confusion Matrix')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig('advanced_transformer_final_test_confusion.png', 
                   dpi=150, bbox_inches='tight')
        plt.close()
        
        # Save outputs
        with open('advanced_transformer_final_test_report.json', 'w') as f:
            json.dump(report_final, f, indent=4)
        np.save('advanced_transformer_final_test_confusion.npy', cm_final)
        
        print(f"\nTotal Number of Test Points Evaluated: {len(y_test_final)}")
        print("Saved: advanced_transformer_final_test_report.json, advanced_transformer_final_test_confusion.npy, advanced_transformer_final_test_confusion.png")
        
        return report_final, y_pred_final_probs, y_pred_final_classes
        
    except Exception as e:
        print(f"Error during model evaluation: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None

# 8. Complete Training and Evaluation Pipeline
def complete_pipeline(X, y, test_data_dir="/kaggle/input/final-test-data"):
    """Complete pipeline with better error handling."""
    
    print("=" * 80)
    print("ADVANCED TRANSFORMER TRAINING AND EVALUATION PIPELINE")
    print("=" * 80)
    
    try:
        # Step 1: Train the model
        print("\n1. Training Advanced Transformer Model...")
        model, history, label_encoder, scaler = main_training_pipeline(X, y)
        
        # Step 2: Plot training history
        print("\n2. Plotting Training History...")
        plot_training_history(history)
        
        # Step 3: Evaluate on final test data
        print("\n3. Evaluating on Final Test Data...")
        sequence_length = min(512, X.reshape(X.shape[0], -1).shape[1] // 4)
        
        final_report, final_probs, final_preds = evaluate_on_final_test_data(
            model, label_encoder, scaler, sequence_length, test_data_dir
        )
        
        # Step 4: Summary
        print("\n" + "=" * 80)
        print("PIPELINE COMPLETED!")
        print("=" * 80)
        print("\nGenerated Files:")
        print("- advanced_transformer_model.keras")
        print("- advanced_label_encoder.pkl")
        print("- advanced_scaler.pkl")
        print("- advanced_transformer_report.json")
        print("- advanced_transformer_confusion.png")
        print("- advanced_transformer_confusion.npy")
        print("- advanced_transformer_training_history.png")
        
        if final_report is not None:
            print("- advanced_transformer_final_test_report.json")
            print("- advanced_transformer_final_test_confusion.png")
            print("- advanced_transformer_final_test_confusion.npy")
            print("\nFinal test evaluation: SUCCESS")
        else:
            print("\nFinal test evaluation: SKIPPED (no test data or errors occurred)")
        
        return model, history, label_encoder, scaler, final_report
        
    except Exception as e:
        print(f"\nError in pipeline: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None, None

model, history, label_encoder, scaler, final_report = complete_pipeline(X, y)

ADVANCED TRANSFORMER TRAINING AND EVALUATION PIPELINE

1. Training Advanced Transformer Model...
Starting advanced transformer training pipeline...
Applying SMOTE to balance training data...
Scaling features...
Data shapes - Train: (71668, 3400), Val: (5686, 3400), Test: (5687, 3400)
Number of classes: 19
Sequence length: 512


I0000 00:00:1752965126.070526      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Training model with 119,188 parameters...


Epoch 1/100


I0000 00:00:1752965145.875352      69 service.cc:148] XLA service 0x7fc9280043d0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1752965145.875988      69 service.cc:156]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
I0000 00:00:1752965147.481250      69 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m   7/8959[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3:04[0m 21ms/step - accuracy: 0.0161 - loss: 3.1155 - top_k_categorical_accuracy: 0.2290

I0000 00:00:1752965157.413758      69 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m8959/8959[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.2469 - loss: 2.3856 - top_k_categorical_accuracy: 0.6806
Epoch 1: val_accuracy improved from -inf to 0.35878, saving model to best_advanced_transformer.keras
[1m8959/8959[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m224s[0m 22ms/step - accuracy: 0.2469 - loss: 2.3855 - top_k_categorical_accuracy: 0.6806 - val_accuracy: 0.3588 - val_loss: 2.0655 - val_top_k_categorical_accuracy: 0.8336 - learning_rate: 0.0010
Epoch 2/100
[1m8957/8959[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 20ms/step - accuracy: 0.4369 - loss: 1.7991 - top_k_categorical_accuracy: 0.8457
Epoch 2: val_accuracy did not improve from 0.35878
[1m8959/8959[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m182s[0m 20ms/step - accuracy: 0.4369 - loss: 1.7991 - top_k_categorical_accuracy: 0.8457 - val_accuracy: 0.3303 - val_loss: 2.4789 - val_top_k_categorical_accuracy: 0.7689 - learning_rate: 0.0010
Epoch 3/100
[1m8959/8