In [17]:
!pip install -U scikit-learn==1.3.2 imbalanced-learn==0.12.3 xgboost==2.0.3 seaborn matplotlib
import pandas as pd
import numpy as np
import os
from collections import Counter
import ast
import geopandas as gpd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import classification_report, recall_score, f1_score
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.svm import LinearSVC
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix
from sklearn.calibration import CalibratedClassifierCV
import seaborn as sns
import matplotlib.pyplot as plt
import joblib
import json
import xgboost as xgb
import gc




In [18]:
# Now, the imports should work if the environment is truly clean
try:
    from imblearn.over_sampling import SMOTE
    print("\nSuccessfully imported SMOTE.")
except ImportError as e:
    print(f"\nCRITICAL ERROR: Failed to import SMOTE even after aggressive reinstallation: {e}")
    print("This indicates a severe, persistent environment issue.")
    print("Please double-check that you performed a 'Factory reset runtime' (Colab) or equivalent.")
    exit()



Successfully imported SMOTE.


In [19]:
# 1. Inspect and Load GeoJSON Files (Modified for Zero Imputation)
from collections import Counter
data_dir = "/Users/anoushk/Library/CloudStorage/GoogleDrive-anoushkv@gmail.com/My Drive/TUM Munich MSC ITBE SEM 2/[ED110087]  Data Science in Earth Observation/Working/Mar-Oct"  # Replace with your folder path
all_features = []
all_labels = []
invalid_samples = []
invalid_bands = Counter()
species_counts = Counter()

# Updated bands list to include all relevant bands
bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12', 'NDVI', 'EVI', 'SAVI', 'NDWI', 'DEM']
months = ['', '_1', '_2', '_3', '_4', '_5', '_6', '_7']
band_columns = [band + month for month in months for band in bands]

# Inspect first file
first_file = os.path.join(data_dir, os.listdir(data_dir)[0]) if os.listdir(data_dir) else None
if first_file and first_file.endswith(".geojson"):
    gdf = gpd.read_file(first_file)
    print("Inspecting first 2 rows of first GeoJSON file:")
    for idx in range(min(2, len(gdf))):
        print(f"\nRow {idx}:")
        for band in ['B1', 'B2', 'B11', 'NDVI', 'DEM', 'B2_1', 'NDVI_7']:
            if band in gdf.columns:
                data = gdf[band].iloc[idx]
                try:
                    parsed_data = ast.literal_eval(data) if isinstance(data, str) else data
                    array = np.array(parsed_data, dtype=np.float32)
                    print(f"  Band {band}: shape={array.shape}, first few values={array.flatten()[:5]}")
                except (ValueError, SyntaxError, TypeError) as e:
                    print(f"  Band {band}: Error parsing/converting: {e}")
            else:
                print(f"  Band {band}: Not found in GeoJSON file")


Inspecting first 2 rows of first GeoJSON file:

Row 0:
  Band B1: shape=(5, 5), first few values=[0.122 0.122 0.122 0.122 0.122]
  Band B2: shape=(5, 5), first few values=[0.1302 0.127  0.127  0.1267 0.1267]
  Band B11: shape=(5, 5), first few values=[0.30185 0.30185 0.30185 0.27495 0.27495]
  Band NDVI: shape=(5, 5), first few values=[0.28222454 0.26450115 0.26450115 0.25972718 0.25972718]
  Band DEM: shape=(5, 5), first few values=[124. 124. 123. 123. 123.]
  Band B2_1: shape=(5, 5), first few values=[0.141  0.14   0.14   0.1393 0.1393]
  Band NDVI_7: shape=(5, 5), first few values=[0.58914065 0.58784527 0.58784527 0.55536085 0.55536085]

Row 1:
  Band B1: shape=(5, 5), first few values=[0.11365 0.11365 0.11365 0.11365 0.11365]
  Band B2: shape=(5, 5), first few values=[0.1143  0.11055 0.11055 0.1117  0.11315]
  Band B11: shape=(5, 5), first few values=[0.14995    0.14995    0.14995    0.14985001 0.14985001]
  Band NDVI: shape=(5, 5), first few values=[0.33995953 0.2707577  0.2707577

In [20]:
# Load all GeoJSON files
total_samples_attempted = 0
for file in os.listdir(data_dir):
    if file.endswith(".geojson"):
        try:
            gdf = gpd.read_file(os.path.join(data_dir, file))
            print(f"Processing file: {file}, Rows: {len(gdf)}")
            total_samples_attempted += len(gdf)
            for idx, row in gdf.iterrows():
                try:
                    patch = []
                    for col in band_columns:
                        if col not in gdf.columns:
                            invalid_bands[col] += 1
                            array = np.zeros((5, 5), dtype=np.float32)  # Impute with zeros
                            patch.append(array)
                            continue
                        data = row[col]
                        if data is None or (isinstance(data, str) and data.lower() == 'none'):
                            invalid_bands[col] += 1
                            array = np.zeros((5, 5), dtype=np.float32)  # Impute with zeros
                            patch.append(array)
                            continue
                        try:
                            parsed_data = ast.literal_eval(data) if isinstance(data, str) else data
                            array = np.array(parsed_data, dtype=np.float32).reshape(5, 5)
                        except (ValueError, SyntaxError, TypeError) as e:
                            invalid_bands[col] += 1
                            array = np.zeros((5, 5), dtype=np.float32)  # Impute for parsing errors
                            patch.append(array)
                            continue
                        patch.append(array)
                    patch = np.stack(patch, axis=-1)
                    if patch.shape != (5, 5, 136):  # Expected shape: 17 bands * 8 months
                        raise ValueError(f"Unexpected patch shape: {patch.shape}")
                    all_features.append(patch)
                    all_labels.append(row['l3_species'])
                    species_counts[row['l3_species']] += 1
                except (ValueError, SyntaxError, TypeError) as e:
                    invalid_samples.append((file, idx, str(e)))
                    continue
        except Exception as e:
            print(f"Failed to process file {file}: {e}")
            continue


Processing file: broadleaf_long-lived deciduous_sycamore maplemar-oct-2022.geojson, Rows: 2096
Processing file: needleleaf_pine_scots pinemar-oct-2022.geojson, Rows: 5389
Processing file: broadleaf_short-lived deciduous_poplarmar-oct-2022.geojson, Rows: 387
Processing file: broadleaf_long-lived deciduous_cherrymar-oct-2022.geojson, Rows: 247
Processing file: needleleaf_spruce_norway sprucemar-oct-2022.geojson, Rows: 5037
Processing file: needleleaf_larch_japanese larchmar-oct-2022.geojson, Rows: 1613
Processing file: needleleaf_fir_silver firmar-oct-2022.geojson, Rows: 811
Processing file: broadleaf_long-lived deciduous_lindenmar-oct-2022.geojson, Rows: 161
Processing file: broadleaf_oak_english oakmar-oct-2022.geojson, Rows: 2808
Processing file: broadleaf_beech_european beechmar-oct-2022.geojson, Rows: 4756
Processing file: needleleaf_pine_black pinemar-oct-2022.geojson, Rows: 412
Processing file: broadleaf_short-lived deciduous_birchmar-oct-2022.geojson, Rows: 2468
Processing file: 

In [21]:
# Log invalid samples and bands
print(f"\nTotal samples attempted: {total_samples_attempted}")
print(f"Valid samples processed: {len(all_features)}")
if invalid_samples:
    print(f"\nSkipped {len(invalid_samples)} invalid samples:")
    for file, idx, error in invalid_samples:
        print(f"File: {file}, Row: {idx}, Error: {error}")
if invalid_bands:
    print("\nBands with None or missing values:")
    for band, count in invalid_bands.most_common():
        print(f"  {band}: {count} times")
print("\nValid samples per species:")
for species, count in species_counts.most_common():
    print(f"  {species}: {count}")



Total samples attempted: 37907
Valid samples processed: 37907

Bands with None or missing values:
  B2_4: 25577 times
  B3_4: 25577 times
  B4_4: 25577 times
  B5_4: 25577 times
  B6_4: 25577 times
  B7_4: 25577 times
  B8_4: 25577 times
  B8A_4: 25577 times
  B11_4: 25577 times
  B12_4: 25577 times
  NDVI_4: 25577 times
  EVI_4: 25577 times
  SAVI_4: 25577 times
  NDWI_4: 25577 times
  B1_4: 25569 times
  B9_4: 25569 times
  DEM_4: 23721 times
  B2_7: 21786 times
  B3_7: 21786 times
  B4_7: 21786 times
  B5_7: 21786 times
  B6_7: 21786 times
  B7_7: 21786 times
  B8_7: 21786 times
  B8A_7: 21786 times
  B11_7: 21786 times
  B12_7: 21786 times
  NDVI_7: 21786 times
  EVI_7: 21786 times
  SAVI_7: 21786 times
  NDWI_7: 21786 times
  B1_7: 21781 times
  B9_7: 21781 times
  DEM_7: 20915 times
  B2_1: 5484 times
  B3_1: 5484 times
  B4_1: 5484 times
  B5_1: 5484 times
  B6_1: 5484 times
  B7_1: 5484 times
  B8_1: 5484 times
  B8A_1: 5484 times
  B11_1: 5484 times
  B12_1: 5484 times
  NDVI

In [22]:
# Convert to NumPy arrays
if not all_features:
    print("\nError: No valid samples loaded. Using Random Forest with dummy data.")
    X_dummy = np.random.rand(100, 5*5*136)  # Updated for 136 channels
    y_dummy = np.random.randint(0, 5, 100)
    rf = RandomForestClassifier(n_estimators=100, random_state=42)
    rf.fit(X_dummy, y_dummy)
    print("Random Forest dummy accuracy:", rf.score(X_dummy, y_dummy))
    print("Please re-export data with updated GEE code.")
    exit()

X = np.array(all_features, dtype=np.float32)  # Shape: (N, 5, 5, 136)
y = np.array(all_labels)

In [23]:
# 2. Preprocess Data (Corrected)

# Encode class labels to integers
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Flatten features for SMOTE
X_flat = X.reshape(X.shape[0], -1)

# Split the original data into training and temp (for later val/test)
X_train_raw, X_temp, y_train_raw, y_temp = train_test_split(
    X_flat, y_encoded, test_size=0.3, stratify=y_encoded, random_state=42
)

# Apply SMOTE only to the training set to avoid data leakage
smote = SMOTE(random_state=42)
X_train_resampled, y_train_resampled = smote.fit_resample(X_train_raw, y_train_raw)

# Standardize the resampled training data
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_resampled)

# Apply same scaling to validation/test split
X_temp_scaled = scaler.transform(X_temp)

# Split temp into validation and test sets
X_val, X_test, y_val, y_test = train_test_split(
    X_temp_scaled, y_temp, test_size=0.5, stratify=y_temp, random_state=42
)







In [24]:
# Compute class weights from the resampled training data
class_weights = compute_class_weight('balanced', classes=np.unique(y_train_resampled), y=y_train_resampled)
class_weight_dict = {i: class_weights[i] for i in range(len(class_weights))}


# 3. Define Ensemble Model (Corrected for Proper SMOTE and Early Stopping)

# SVM with probability estimates
svm = CalibratedClassifierCV(
    LinearSVC(C=0.1, max_iter=1000, class_weight=class_weight_dict, random_state=42),
    cv=3
)

# Random Forest with class weights
rf = RandomForestClassifier(
    n_estimators=50,
    max_depth=10,
    class_weight=class_weight_dict,
    random_state=42,
    n_jobs=-1
)

# XGBoost with early stopping and probability output for mlogloss
xgb_model = xgb.XGBClassifier(
    objective='multi:softprob',
    num_class=len(np.unique(y_train_resampled)),
    max_depth=6,
    learning_rate=0.05,
    n_estimators=1000,
    subsample=0.8,
    colsample_bytree=0.8,
    device='cuda',
    random_state=42,
    eval_metric='mlogloss',
    early_stopping_rounds=10
)

# Ensemble model with soft voting
model7 = VotingClassifier(
    estimators=[
        ('svm', svm),
        ('rf', rf),
        ('xgb', xgb_model)
    ],
    voting='soft'
)



In [26]:
# 4. Train Ensemble Model (Corrected Training Flow)

# --- Use SGDClassifier for much faster SVM training on high-dimensional data ---
from sklearn.linear_model import SGDClassifier
from sklearn.calibration import CalibratedClassifierCV

svm = CalibratedClassifierCV(
    SGDClassifier(loss="hinge", class_weight=class_weight_dict, max_iter=1000, tol=1e-3, random_state=42),
    cv=3
)

# Train SVM
print("\nTraining SVM...")
svm.fit(X_train_scaled, y_train_resampled)

# Train Random Forest
print("Training Random Forest...")
rf.fit(X_train_scaled, y_train_resampled)

# Train XGBoost with early stopping using validation set
print("\nTraining XGBoost with early stopping...")
xgb_model.fit(
    X_train_scaled,
    y_train_resampled,
    eval_set=[(X_train_scaled, y_train_resampled), (X_val, y_val)],
    verbose=True
)

# Plot training and validation log loss for XGBoost
evals_result = xgb_model.evals_result()
train_loss = evals_result['validation_0']['mlogloss']
val_loss = evals_result['validation_1']['mlogloss']
epochs = range(1, len(train_loss) + 1)

plt.figure(figsize=(10, 6), dpi=100)
plt.plot(epochs, train_loss, label='Training Loss')
plt.plot(epochs, val_loss, label='Validation Loss')
plt.xlabel('Boosting Rounds')
plt.ylabel('Multi-class Log Loss')
plt.title('XGBoost Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.savefig('overfitting_plot.png', dpi=100)
plt.close()
print("Saved: overfitting_plot.png")

joblib.dump(svm,      'svm_model.pkl')
joblib.dump(rf,       'rf_model.pkl')
joblib.dump(xgb_model,'xgb_model.pkl')

print("Saved: svm_model.pkl, rf_model.pkl, xgb_model.pkl")


Training SVM...




Training Random Forest...

Training XGBoost with early stopping...
[0]	validation_0-mlogloss:2.73016	validation_1-mlogloss:2.78296
[1]	validation_0-mlogloss:2.56870	validation_1-mlogloss:2.66199
[2]	validation_0-mlogloss:2.43473	validation_1-mlogloss:2.56176
[3]	validation_0-mlogloss:2.32003	validation_1-mlogloss:2.47600
[4]	validation_0-mlogloss:2.21890	validation_1-mlogloss:2.39996
[5]	validation_0-mlogloss:2.12870	validation_1-mlogloss:2.33188
[6]	validation_0-mlogloss:2.04741	validation_1-mlogloss:2.26982
[7]	validation_0-mlogloss:1.97312	validation_1-mlogloss:2.21356
[8]	validation_0-mlogloss:1.90519	validation_1-mlogloss:2.16244
[9]	validation_0-mlogloss:1.84186	validation_1-mlogloss:2.11406
[10]	validation_0-mlogloss:1.78322	validation_1-mlogloss:2.07015
[11]	validation_0-mlogloss:1.72888	validation_1-mlogloss:2.02886
[12]	validation_0-mlogloss:1.67749	validation_1-mlogloss:1.98927
[13]	validation_0-mlogloss:1.62992	validation_1-mlogloss:1.95344
[14]	validation_0-mlogloss:1.5845

In [27]:
# After xgb_model.fit(...) with early stopping
best_n_estimators = xgb_model.best_iteration + 1  # or xgb_model.best_ntree_limit

# Redefine xgb_model for VotingClassifier
xgb_model_ensemble = xgb.XGBClassifier(
    objective='multi:softprob',
    num_class=len(np.unique(y_train_resampled)),
    max_depth=6,
    learning_rate=0.05,
    n_estimators=best_n_estimators,  # Use best number of trees
    subsample=0.8,
    colsample_bytree=0.8,
    device='cuda',
    random_state=42,
    eval_metric='mlogloss'
)

model7 = VotingClassifier(
    estimators=[
        ('svm', svm),
        ('rf', rf),
        ('xgb', xgb_model_ensemble)
    ],
    voting='soft'
)

# 5. Train Final Ensemble Model

print("\nTraining ensemble model...")
model7.fit(X_train_scaled, y_train_resampled)







Training ensemble model...




In [28]:

# 6. Evaluate Model on Validation Set

y_pred = model7.predict(X_test)
test_accuracy = (y_pred == y_test).mean()
print(f"\nValidation Set Accuracy: {test_accuracy:.4f}")

# Generate classification report
report7 = classification_report(y_test, y_pred, target_names=label_encoder.classes_, output_dict=True)
print("\nClassification Report (Validation Set):")
print(json.dumps(report7, indent=4))
print(f"Recall: {recall_score(y_test, y_pred, average='weighted'):.4f}")
print(f"F1-Score: {f1_score(y_test, y_pred, average='weighted'):.4f}")


Validation Set Accuracy: 0.7229

Classification Report (Validation Set):
{
    "alder": {
        "precision": 0.6426229508196721,
        "recall": 0.6086956521739131,
        "f1-score": 0.6251993620414673,
        "support": 322.0
    },
    "birch": {
        "precision": 0.5289256198347108,
        "recall": 0.518918918918919,
        "f1-score": 0.52387448840382,
        "support": 370.0
    },
    "black pine": {
        "precision": 0.5512820512820513,
        "recall": 0.6935483870967742,
        "f1-score": 0.6142857142857143,
        "support": 62.0
    },
    "cherry": {
        "precision": 0.5121951219512195,
        "recall": 0.5675675675675675,
        "f1-score": 0.5384615384615384,
        "support": 37.0
    },
    "douglas fir": {
        "precision": 0.891156462585034,
        "recall": 0.7987804878048781,
        "f1-score": 0.842443729903537,
        "support": 328.0
    },
    "english oak": {
        "precision": 0.6560364464692483,
        "recall": 0.6840855

In [29]:
# 7. Confusion Matrix

cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(10, 8), dpi=100)
sns.heatmap(cm, annot=True, fmt='d', xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_, cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix (Validation Set)')
plt.savefig('confusion7.png', dpi=100)
plt.close()


# 8. Save Outputs

joblib.dump(model7, 'model7.pkl')
json.dump(report7, open('report7.json', 'w'), indent=4)
np.save('confusion7.npy', cm)
joblib.dump(label_encoder, 'labelencoder7.pkl')
joblib.dump(scaler, 'scaler7.pkl')

print("Saved: model7.pkl, report7.json, confusion7.npy, confusion7.png, labelencoder7.pkl, scaler7.pkl")

Saved: model7.pkl, report7.json, confusion7.npy, confusion7.png, labelencoder7.pkl, scaler7.pkl


Test data

In [30]:
# 9. Evaluate Model on Final Test Data
test_data_dir = "/Users/anoushk/Library/CloudStorage/GoogleDrive-anoushkv@gmail.com/My Drive/TUM Munich MSC ITBE SEM 2/[ED110087]  Data Science in Earth Observation/Working/Final-Test-Data"
test_features = []
test_labels = []
test_invalid_samples = []
test_invalid_bands = Counter()
total_samples_attempted = 0  # Track total samples processed

# Check if directory exists and list files
if not os.path.exists(test_data_dir):
    print(f"\nError: Test data directory {test_data_dir} does not exist.")
    exit()
geojson_files = [f for f in os.listdir(test_data_dir) if f.endswith(".geojson")]
print(f"\nFound {len(geojson_files)} GeoJSON files in {test_data_dir}")

# Load all GeoJSON files from test data directory
for file in geojson_files:
    try:
        file_path = os.path.join(test_data_dir, file)
        gdf = gpd.read_file(file_path)
        print(f"Processing file: {file}, Rows: {len(gdf)}")
        total_samples_attempted += len(gdf)  # Count all rows in the file
        for idx, row in gdf.iterrows():
            try:
                patch = []
                for col in band_columns:
                    if col not in gdf.columns:
                        test_invalid_bands[col] += 1
                        array = np.zeros((5, 5), dtype=np.float32)  # Impute with zeros
                        patch.append(array)
                        continue
                    data = row[col]
                    if data is None or (isinstance(data, str) and data.lower() == 'none'):
                        test_invalid_bands[col] += 1
                        array = np.zeros((5, 5), dtype=np.float32)  # Impute with zeros
                        patch.append(array)
                        continue
                    try:
                        parsed_data = ast.literal_eval(data) if isinstance(data, str) else data
                        array = np.array(parsed_data, dtype=np.float32).reshape(5, 5)
                    except (ValueError, SyntaxError, TypeError) as e:
                        test_invalid_bands[col] += 1
                        array = np.zeros((5, 5), dtype=np.float32)  # Impute for parsing errors
                        patch.append(array)
                        continue
                    patch.append(array)
                patch = np.stack(patch, axis=-1)
                if patch.shape != (5, 5, 136):  # Expected shape: 17 bands * 8 months
                    raise ValueError(f"Unexpected patch shape: {patch.shape}")
                test_features.append(patch)
                test_labels.append(row['l3_species'])
            except (ValueError, SyntaxError, TypeError) as e:
                test_invalid_samples.append((file, idx, str(e)))
                continue
    except Exception as e:
        print(f"Failed to process file {file}: {e}")
        continue

# Log invalid samples and bands
print(f"\nTotal samples attempted: {total_samples_attempted}")
print(f"Valid samples processed: {len(test_features)}")
if test_invalid_samples:
    print(f"\nSkipped {len(test_invalid_samples)} invalid test samples:")
    for file, idx, error in test_invalid_samples:
        print(f"File: {file}, Row: {idx}, Error: {error}")
if test_invalid_bands:
    print("\nBands with missing/None/parsing issues in test data:")
    for band, count in test_invalid_bands.most_common():
        print(f"  {band}: {count} times")

# Convert to NumPy arrays
if not test_features:
    print("\nError: No valid test samples loaded. Cannot evaluate model.")
    exit()

X_test_final = np.array(test_features, dtype=np.float32)  # Shape: (N, 5, 5, 136)
y_test_final = np.array(test_labels)


Found 19 GeoJSON files in /Users/anoushk/Library/CloudStorage/GoogleDrive-anoushkv@gmail.com/My Drive/TUM Munich MSC ITBE SEM 2/[ED110087]  Data Science in Earth Observation/Working/Final-Test-Data
Processing file: broadleaf_long-lived deciduous_sycamore maplemar-oct-2022.geojson, Rows: 725
Processing file: needleleaf_pine_scots pinemar-oct-2022.geojson, Rows: 1202
Processing file: broadleaf_short-lived deciduous_poplarmar-oct-2022.geojson, Rows: 77
Processing file: broadleaf_long-lived deciduous_cherrymar-oct-2022.geojson, Rows: 57
Processing file: needleleaf_spruce_norway sprucemar-oct-2022.geojson, Rows: 746
Processing file: needleleaf_larch_japanese larchmar-oct-2022.geojson, Rows: 135
Processing file: needleleaf_fir_silver firmar-oct-2022.geojson, Rows: 173
Processing file: broadleaf_long-lived deciduous_lindenmar-oct-2022.geojson, Rows: 51
Processing file: broadleaf_oak_english oakmar-oct-2022.geojson, Rows: 645
Processing file: broadleaf_beech_european beechmar-oct-2022.geojson

In [32]:
# Preprocess test data
try:
    y_test_final_encoded = label_encoder.transform(y_test_final)  # Use same LabelEncoder
except ValueError as e:
    print(f"Error in label encoding: {e}")
    unknown_labels = set(y_test_final) - set(label_encoder.classes_)
    print(f"Unknown labels in test data: {unknown_labels}")
    exit()
X_test_final_flat = X_test_final.reshape(X_test_final.shape[0], -1)  # Flatten for prediction
X_test_final_scaled = scaler.transform(X_test_final_flat)  # Use same StandardScaler




In [34]:
# Diagnostic: Compare class distributions
print("\nTraining class distribution:")
print(pd.Series(label_encoder.inverse_transform(y_train_resampled)).value_counts())
print("\nTest class distribution:")
print(pd.Series(y_test_final).value_counts())
missing_classes = set(label_encoder.classes_) - set(y_test_final)
print(f"Classes missing in test data: {missing_classes}")

# Diagnostic: Compare feature distributions
print("\nTraining feature stats (after scaling):")
print(f"Mean: {X_train_scaled.mean():.4f}, Std: {X_train_scaled.std():.4f}")
print("\nTest feature stats (after scaling):")
print(f"Mean: {X_test_final_scaled.mean():.4f}, Std: {X_test_final_scaled.std():.4f}")

# Evaluate model on test data
y_pred_final = model7.predict(X_test_final_scaled)
test_accuracy_final = (y_pred_final == y_test_final_encoded).mean()
print(f"\nFinal Test Data Accuracy: {test_accuracy_final:.4f}")

# Get unique labels in test data to avoid mismatch
unique_test_labels = np.unique(y_test_final_encoded)
unique_test_label_names = label_encoder.inverse_transform(unique_test_labels)

# Additional metrics for test data
report_final = classification_report(
    y_test_final_encoded,
    y_pred_final,
    labels=unique_test_labels,
    target_names=unique_test_label_names,
    output_dict=True
)
print("\nClassification Report for Final Test Data:")
print(json.dumps(report_final, indent=4))
print(f"Recall (Final Test): {recall_score(y_test_final_encoded, y_pred_final, average='weighted'):.4f}")
print(f"F1-Score (Final Test): {f1_score(y_test_final_encoded, y_pred_final, average='weighted'):.4f}")

# Confusion matrix for test data
cm_final = confusion_matrix(y_test_final_encoded, y_pred_final, labels=unique_test_labels)
plt.figure(figsize=(10, 8), dpi=100)
sns.heatmap(cm_final, annot=True, fmt='d', xticklabels=unique_test_label_names, yticklabels=unique_test_label_names, cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix for Final Test Data')
plt.savefig('confusion_final.png', dpi=100)
plt.close()

# Total number of test points
print(f"\nTotal Number of Test Points: {len(y_test_final)}")

# Save outputs for test data
json.dump(report_final, open('report_final.json', 'w'), indent=4)
np.save('confusion_final.npy', cm_final)


Training class distribution:
sycamore maple    3772
european ash      3772
poplar            3772
alder             3772
douglas fir       3772
black pine        3772
birch             3772
cherry            3772
sessile oak       3772
japanese larch    3772
red oak           3772
european larch    3772
scots pine        3772
linden            3772
norway spruce     3772
silver fir        3772
european beech    3772
english oak       3772
weymouth pine     3772
Name: count, dtype: int64

Test class distribution:
european beech    1703
scots pine        1202
norway spruce      746
sycamore maple     725
english oak        645
douglas fir        506
sessile oak        493
european ash       432
alder              420
red oak            381
birch              353
european larch     221
silver fir         173
japanese larch     135
poplar              77
cherry              57
linden              51
weymouth pine       22
black pine           9
Name: count, dtype: int64
Classes missing in