In [1]:
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.metrics import (
    classification_report,
    confusion_matrix
)

In [2]:
# Preprocessed training table; relative to the kernel's working directory.
INPUT_FILE = '../../data/preprocessed/final_training_dataset.csv'

In [3]:
# Load the full training table. NOTE: the file carries a saved row index,
# which pandas surfaces as an 'Unnamed: 0' column (see the head() output).
df = pd.read_csv(filepath_or_buffer=INPUT_FILE)

In [4]:
# Peek at the first five rows (pandas elides the middle of this wide frame)
df.head(n=5)

Unnamed: 0,valid_time,latitude,longitude,u10,v10,d2m,t2m,lai_hv,lai_lv,wind_speed,...,lulc_class_2,lulc_class_3,lulc_class_5,lulc_class_15,lulc_class_19,lulc_class_9,lulc_class_17,lulc_class_13,lulc_class_11,fire_occurred
0,2016-01-01 00:00:00,31.3,77.5,-1.796256,-0.72764,-15.016418,-1.26651,2.174683,1.621314,1.938039,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2016-01-01 00:00:00,31.3,77.75,-2.54003,-1.461649,-20.002747,-6.835846,2.435181,1.207129,2.930558,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
2,2016-01-01 00:00:00,31.3,78.0,-2.340568,-1.2595,-23.637512,-11.565338,3.491089,1.080359,2.657931,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,2016-01-01 00:00:00,31.3,78.25,-1.654532,-0.617289,-29.244934,-14.932526,4.195557,0.847266,1.765934,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,2016-01-01 00:00:00,31.3,78.5,-1.472403,-0.341532,-34.576965,-19.01065,3.301758,0.702429,1.511495,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [5]:
# Parse timestamps, then order rows chronologically.
# A stable sort keeps rows that share a timestamp in their original file order.
# That makes the resulting row order (and anything that depends on it, such as
# index-based RandomForest bootstrap draws) independent of the sort algorithm.
df['valid_time'] = pd.to_datetime(df['valid_time'])
df = df.sort_values(by='valid_time', kind='stable').reset_index(drop=True)

In [6]:
# Label column the classifiers learn to predict (0 = no fire, 1 = fire)
target = "fire_occurred"

In [7]:
# Columns that must NOT be used as model inputs. This list is only used as an
# exclusion filter, so naming a column absent from the CSV is harmless.
drop_cols = [
    # Saved row index from the CSV (values 0, 1, 2, ... in df.head()); it is a
    # row identifier, not a predictor, and would otherwise leak row order.
    'Unnamed: 0',
    'valid_time', 'latitude', 'longitude',  # Identifiers
    'number', 'expver',  # Metadata from ERA5
    'fire_occurred'  # The target itself
]

In [8]:
# Names of the lulc_class_* indicator columns.
# NOTE(review): not referenced by any visible cell; confirm it is still needed.
lulc_cols = list(filter(lambda name: name.startswith('lulc_class_'), df.columns))

In [9]:
# Every column not explicitly excluded is treated as a predictor, in file order
features = [name for name in df.columns.tolist() if name not in drop_cols]

In [10]:
# Replace missing predictor values with 0; identifier and target columns are untouched.
# NOTE(review): 0 is a physically meaningful value for fields like t2m/d2m;
# confirm zero-filling (rather than e.g. a per-column statistic) is intended.
df[features] = df[features].fillna(value=0)

## Splitting data by time

In [11]:
# Total number of fire-positive rows in each calendar year
fire_counts = df['fire_occurred'].groupby(df['valid_time'].dt.year).sum()
print(fire_counts)

valid_time
2016    47472.0
2017    32344.0
2018    40920.0
Name: fire_occurred, dtype: float64


In [12]:
# Select all years that contain at least one fire
valid_years = fire_counts[fire_counts > 0].index.tolist()
print(f"Found {len(valid_years)} years with fire data: {valid_years}")

# A temporal split needs at least one year to train on and one to test on.
# Raise instead of calling exit(): exit() does not cleanly stop a notebook run
# (in IPython it asks the kernel to shut down), whereas an exception halts
# "Run All" at this cell with a visible message.
if len(valid_years) < 2:
    raise ValueError(
        f"Not enough years with fire data to split. Found: {valid_years}"
    )

# Hold out the most recent fire year for testing. valid_years is ascending
# because groupby sorts its keys by default.
test_year = valid_years[-1]
train_years = valid_years[:-1]

Found 3 years with fire data: [2016, 2017, 2018]


In [13]:
# Report which years ended up on each side of the temporal split
print(f"Training on years: {train_years}", f"Testing on year:   {test_year}", sep="\n")

Training on years: [2016, 2017]
Testing on year:   2018


In [14]:
# Time-based split: earlier fire years for training, the latest fire year for testing
train_df = df.loc[df['valid_time'].dt.year.isin(train_years)]
test_df = df.loc[df['valid_time'].dt.year == test_year]

In [15]:
# Separate each partition into its predictor matrix and label vector
X_train, y_train = train_df[features], train_df[target]
X_test, y_test = test_df[features], test_df[target]

In [16]:
# Negative-to-positive label ratio in the training split; passed to XGBoost
# as scale_pos_weight to up-weight the rare fire class.
train_counts = y_train.value_counts()
imbalance_ratio = train_counts.loc[0] / train_counts.loc[1]
print(f"\nImbalance Ratio (No Fire / Fire) in Training Data: {imbalance_ratio:.2f}")


Imbalance Ratio (No Fire / Fire) in Training Data: 10.49


# XGBoost

In [17]:
# Gradient-boosted trees; scale_pos_weight is set to the no-fire/fire ratio
# computed from the training labels to compensate for class imbalance.
xgb_model = XGBClassifier(
    eval_metric='logloss',
    n_estimators=100,
    n_jobs=-1,
    random_state=42,
    scale_pos_weight=imbalance_ratio,
)

In [18]:
# Train the boosted-tree model on the training-year rows
xgb_model.fit(X=X_train, y=y_train)
print("XGBoost training complete.")

XGBoost training complete.


# Random Forest

In [19]:
# Random forest baseline; class_weight='balanced' weights each class
# inversely to its frequency in the training labels.
rf_model = RandomForestClassifier(
    class_weight='balanced',
    n_estimators=100,
    n_jobs=-1,
    random_state=42,
)

In [20]:
# Train the random forest on the same training-year rows
rf_model.fit(X=X_train, y=y_train)
print("Random Forest training complete.")

Random Forest training complete.


# Evaluate Models

In [21]:
def evaluate_model(model, X_test, y_test):
    """Print test-set metrics for a fitted binary fire / no-fire classifier.

    Prints a scikit-learn classification report, then the 2x2 confusion matrix
    with its four cells spelled out (TN, FP, FN, TP).

    Args:
        model: Fitted estimator exposing ``predict(X)``.
        X_test: Feature matrix of the evaluation split.
        y_test: Labels aligned with ``X_test``; 0 = no fire, 1 = fire.
            Expected to be a pandas Series or NumPy array (must support
            ``== 1`` and ``.any()``).

    Returns:
        None. Results are printed. If ``y_test`` contains no fire (1) rows,
        a warning is printed and the model is not evaluated.
    """
    # Fire-class precision/recall are undefined without at least one fire event.
    if not (y_test == 1).any():
        print("--- WARNING: No fire events found in the test set. Cannot calculate metrics. ---")
        return

    # Get predictions
    y_pred = model.predict(X_test)

    # Pin the label order explicitly so report rows always match target_names
    # and the confusion matrix is always 2x2, even when a class is missing from
    # y_test or y_pred. Otherwise sklearn infers the labels from the data, which
    # can yield a 1x1 matrix and an IndexError below.
    labels = [0, 1]

    print("\n--- Classification Report ---")
    # Focus on the 'Fire (1)' row for Precision, Recall, F1-Score
    print(classification_report(
        y_test,
        y_pred,
        labels=labels,
        target_names=['No Fire (0)', 'Fire (1)'],
    ))

    print("--- Confusion Matrix ---")
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    print(cm)
    # Rows are true labels, columns are predictions, both in [0, 1] order,
    # so the flattened matrix is (TN, FP, FN, TP).
    tn, fp, fn, tp = cm.ravel()
    print(f"True Negatives (No Fire predicted as No Fire): {tn}")
    print(f"False Positives (No Fire predicted as Fire): {fp} (False Alarms)")
    print(f"False Negatives (Fire predicted as No Fire): {fn} (Missed Fires)")
    print(f"True Positives (Fire predicted as Fire):    {tp} (Correct Detections)")


## Model 1: XGBoost Results

In [22]:
# Score the XGBoost model on the held-out test year
evaluate_model(model=xgb_model, X_test=X_test, y_test=y_test)


--- Classification Report ---
              precision    recall  f1-score   support

 No Fire (0)       0.97      0.83      0.90    417120
    Fire (1)       0.31      0.75      0.44     40920

    accuracy                           0.83    458040
   macro avg       0.64      0.79      0.67    458040
weighted avg       0.91      0.83      0.86    458040

--- Confusion Matrix ---
[[347441  69679]
 [ 10156  30764]]
True Negatives (No Fire predicted as No Fire): 347441
False Positives (No Fire predicted as Fire): 69679 (False Alarms)
False Negatives (Fire predicted as No Fire): 10156 (Missed Fires)
True Positives (Fire predicted as Fire):    30764 (Correct Detections)


## Model 2: Random Forest Results

In [23]:
# Score the random forest on the same held-out test year
evaluate_model(model=rf_model, X_test=X_test, y_test=y_test)


--- Classification Report ---
              precision    recall  f1-score   support

 No Fire (0)       0.92      0.99      0.95    417120
    Fire (1)       0.50      0.14      0.22     40920

    accuracy                           0.91    458040
   macro avg       0.71      0.56      0.59    458040
weighted avg       0.88      0.91      0.89    458040

--- Confusion Matrix ---
[[411288   5832]
 [ 35129   5791]]
True Negatives (No Fire predicted as No Fire): 411288
False Positives (No Fire predicted as Fire): 5832 (False Alarms)
False Negatives (Fire predicted as No Fire): 35129 (Missed Fires)
True Positives (Fire predicted as Fire):    5791 (Correct Detections)
