In [2]:
import pandas as pd
import numpy as np

import warnings
warnings.filterwarnings("ignore")

## Load Data

In [3]:
# Load the three preprocessed input tables (paths are relative to the notebook's working directory).
# fire_data: later cells read its ACQ_DATE, FIRE_ID, coordinate and *_50KM / *_1KM columns
# (presumably one row per detected fire - confirm against the preprocessing step).
fire_data = pd.read_csv('input_data/processed/fire_data.csv')
# cell_static_data: merge_static_data keeps only the columns whose name ends with the resolution suffix.
cell_static_data = pd.read_csv('input_data/processed/cell_static.csv')
# cell_dynamic_data: joined later on OBLAST_ID and ACQ_DATE by merge_dynamic_data.
cell_dynamic_data = pd.read_csv('input_data/processed/cell_dynamic.csv')

## Preprocess Data

In [4]:
# Only keep data in the date range from 2022-02-24 to 2024-09-30 (both bounds inclusive), i.e., war-time data
WAR_START, WAR_END = '2022-02-24', '2024-09-30'

# Parse the date columns so the comparisons below are done on datetimes, not on strings
fire_data['ACQ_DATE'] = pd.to_datetime(fire_data['ACQ_DATE'])
cell_dynamic_data['ACQ_DATE'] = pd.to_datetime(cell_dynamic_data['ACQ_DATE'])

# .copy() detaches each filtered frame from the frame it was sliced from, so that the in-place
# edits performed by later cells act on an independent DataFrame instead of on a slice.
fire_data = fire_data[
    (fire_data['ACQ_DATE'] >= WAR_START) & (fire_data['ACQ_DATE'] <= WAR_END)
].copy()
cell_dynamic_data = cell_dynamic_data[
    (cell_dynamic_data['ACQ_DATE'] >= WAR_START) & (cell_dynamic_data['ACQ_DATE'] <= WAR_END)
].copy()

In [5]:
# Snapshot the identifying columns (fire id, 50 km cell, date, coordinates) before
# the next cell removes some of them from fire_data
fire_data_copy = fire_data.loc[
    :, ['FIRE_ID', 'GRID_CELL_50KM', 'ACQ_DATE', 'LONGITUDE', 'LATITUDE']
].copy()

In [None]:
# Drop the columns that are not used at the 50 km resolution, collapse the rows that
# became identical as a result, and renumber the index from 0.
# The steps are chained without inplace=True so the frame produced by the date filter
# is never mutated; fire_data is simply rebound to the cleaned result.
fire_data = (
    fire_data
    .drop(columns=[
        'FIRE_ID', 'LATITUDE', 'LONGITUDE', 'GRID_CELL', 'OBLAST_ID',
        'LATITUDE_1KM', 'LONGITUDE_1KM', 'GRID_CELL_1KM', 'OBLAST_ID_1KM',
        'FIRE_COUNT_CELL_1KM',
    ])
    .drop_duplicates()
    .reset_index(drop=True)
)
fire_data.head()

In [7]:
def merge_static_data(data, static_data, resolution='50KM'):
    """Left-join the static cell attributes of one grid resolution onto ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Left table. Must contain the four key columns ``GRID_CELL_<resolution>``,
        ``OBLAST_ID_<resolution>``, ``LATITUDE_<resolution>`` and
        ``LONGITUDE_<resolution>``.
    static_data : pandas.DataFrame
        Right table. Only the columns whose name ends with ``resolution`` are used;
        the frame itself is not modified.
    resolution : str, default '50KM'
        Column-name suffix that selects both the static columns and the join keys.

    Returns
    -------
    pandas.DataFrame
        ``data`` with the de-duplicated static columns attached (``how='left'``).
    """
    # Keep only the columns that belong to the requested resolution
    resolution_columns = [col for col in static_data.columns if col.endswith(resolution)]
    # De-duplicate on a fresh frame instead of in place on a column slice
    static_subset = static_data[resolution_columns].drop_duplicates()
    # Build the join keys from the resolution so that non-default resolutions work too
    join_keys = [
        '{}_{}'.format(prefix, resolution)
        for prefix in ('GRID_CELL', 'OBLAST_ID', 'LATITUDE', 'LONGITUDE')
    ]
    return pd.merge(data, static_subset, how='left', on=join_keys)

In [None]:
# Left-join the static cell attributes onto the cleaned fire table.
# merge_static_data is called with its default resolution ('50KM').
fire_data_processed = merge_static_data(fire_data, cell_static_data)
# Preview the first rows of the merged table
fire_data_processed.head()

In [9]:
def generate_fire_time_series(data, start_date, end_date, resolution='50KM'):
    """Expand every grid cell to one row per day between ``start_date`` and ``end_date``.

    For each distinct value of ``GRID_CELL_<resolution>`` the rows are re-indexed on a
    complete daily date range. Days without a row are filled with 0, ``DAY_OF_YEAR`` and
    ``ACQ_DATE`` are recomputed from the date range, and every other column except
    ``FIRE_COUNT_CELL_<resolution>`` is overwritten with the value found in the cell's
    first row (these columns are treated as constant per cell).

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``ACQ_DATE``, ``DAY_OF_YEAR``, ``GRID_CELL_<resolution>`` and
        ``FIRE_COUNT_CELL_<resolution>``. Not modified.
    start_date, end_date :
        Bounds passed to ``pandas.date_range`` (both inclusive, daily frequency).
    resolution : str, default '50KM'
        Suffix of the cell and fire-count columns.

    Returns
    -------
    pandas.DataFrame
        The per-cell daily frames concatenated in order of first appearance of each cell,
        with ``ACQ_DATE`` as the first column. The index restarts at 0 for every cell.
        An empty frame with the same columns is returned when ``data`` has no rows.

    Raises
    ------
    ValueError
        Raised by ``reindex`` when a cell contains the same ``ACQ_DATE`` more than once.
    """
    cell_column = 'GRID_CELL_{}'.format(resolution)
    count_column = 'FIRE_COUNT_CELL_{}'.format(resolution)
    # The target date range is identical for every cell, so build it once
    full_range = pd.date_range(start=start_date, end=end_date, freq='D')

    per_cell_frames = []
    # One pass over the data; sort=False keeps the cells in order of first appearance.
    # Rows whose cell id is missing are skipped by groupby.
    for _, cell_rows in data.groupby(cell_column, sort=False):
        # Values of the per-cell constant columns, taken from the cell's first row
        static_values = cell_rows.iloc[0].drop(['ACQ_DATE', 'DAY_OF_YEAR', count_column])
        # Index by date (on a new frame, not in place) and fill the missing days with 0
        daily = cell_rows.set_index('ACQ_DATE')
        daily.index = pd.to_datetime(daily.index)
        daily = daily.reindex(full_range, fill_value=0)
        # Recompute the calendar columns from the complete date range
        daily['DAY_OF_YEAR'] = daily.index.dayofyear
        daily['ACQ_DATE'] = daily.index
        daily = daily.reset_index(drop=True)
        daily = daily[['ACQ_DATE'] + [col for col in daily.columns if col != 'ACQ_DATE']]
        # Restore the constant columns that the 0-fill overwrote on the added days
        for col in static_values.index:
            daily[col] = static_values[col]
        per_cell_frames.append(daily)

    if not per_cell_frames:
        # pd.concat rejects an empty sequence; return an empty frame with the output layout
        return pd.DataFrame(
            columns=['ACQ_DATE'] + [col for col in data.columns if col != 'ACQ_DATE']
        )
    return pd.concat(per_cell_frames)

In [None]:
# Expand every 50 km cell to one row per day from 2022-02-24 to 2024-09-30 (war-time window).
# Days without a record are filled with 0 by generate_fire_time_series.
# Note: fire_data_processed is rebound to the expanded frame.
fire_data_processed = generate_fire_time_series(fire_data_processed, '2022-02-24', '2024-09-30', '50KM')
# Preview the first rows of the daily time series
fire_data_processed.head()

In [11]:
def merge_dynamic_data(data, dynamic_data):
    """Left-join the time-varying table onto ``data`` by oblast and date.

    ``data`` is matched on (``OBLAST_ID_50KM``, ``ACQ_DATE``) against
    (``OBLAST_ID``, ``ACQ_DATE``) of ``dynamic_data``. The redundant ``OBLAST_ID``
    key column coming from the right table is removed from the result.
    """
    joined = pd.merge(
        left=data,
        right=dynamic_data,
        how='left',
        left_on=['OBLAST_ID_50KM', 'ACQ_DATE'],
        right_on=['OBLAST_ID', 'ACQ_DATE'],
    )
    # OBLAST_ID duplicates OBLAST_ID_50KM after the join
    return joined.drop(columns=['OBLAST_ID'])

In [None]:
# Left-join the time-varying table onto the daily per-cell series by oblast id and date.
# Note: fire_data_processed is rebound to the merged frame.
fire_data_processed = merge_dynamic_data(fire_data_processed, cell_dynamic_data)
# Preview the first rows of the merged table
fire_data_processed.head()

## Create Test Data

In [13]:
# Define the features and target variable.
# The date, the cell id, the oblast id and the target itself are excluded from the features.
features = fire_data_processed.drop(columns=['ACQ_DATE', 'GRID_CELL_50KM', 'OBLAST_ID_50KM', 'FIRE_COUNT_CELL_50KM'])
target = fire_data_processed['FIRE_COUNT_CELL_50KM']

# Keep the date and the cell id of every row separately; they are not model features.
# Presumably they are the `dates` / `grid_cells` inputs of RecalculateConfidenceScores.transform - confirm.
acq_date = fire_data_processed['ACQ_DATE']
grid_cell = fire_data_processed['GRID_CELL_50KM']

# Bring the data in the correct format for the sklearn pipeline (plain NumPy arrays)
X_test = features.values
y_test = target.values

## Load the Pipeline

In [14]:
from sklearn.base import BaseEstimator, TransformerMixin

def calculate_significance_score(value, threshold):
    """Score how far ``value`` lies from ``threshold``, relative to the larger of the two.

    * ``value < threshold``  -> ``(value - threshold) / threshold`` (negative for a
      positive threshold; -1 when ``value`` is 0).
    * ``value > threshold``  -> ``(value - threshold) / value`` (between 0 and 1 for
      positive inputs).
    * ``value == threshold`` -> ``0.0``.

    The equality case is handled explicitly because both formulas give 0 there, except
    for ``value == threshold == 0`` where they would divide by zero.
    """
    if value == threshold:
        return 0.0
    if value < threshold:
        return (value - threshold) / threshold
    return (value - threshold) / value

class ThresholdStep(BaseEstimator, TransformerMixin):
    """Flag observations whose positive prediction error exceeds a fixed threshold.

    ``threshold`` is the error level above which an observation is labelled abnormal.
    """

    def __init__(self, threshold):
        self.threshold = threshold

    def fit(self, X, y=None):
        """Nothing is learned; returns ``self`` unchanged."""
        return self

    def transform(self, X, y=None):
        """Return a 2-row array: abnormal flags (0/1) and significance scores.

        ``X`` holds the predicted values and ``y`` the true values; the error is
        ``y - X`` with negative errors clamped to 0. Raises ``ValueError`` when
        ``y`` is not given.
        """
        if y is None:
            raise ValueError("True values (y) are required for the threshold step.")

        # Only under-prediction counts as error; over-prediction is clamped to zero
        excess = y - X
        excess[excess < 0] = 0
        # 1 where the error exceeds the threshold, 0 otherwise
        flags = (excess > self.threshold).astype(int)
        # Element-wise relative distance of the error from the threshold
        scores = pd.Series(excess).apply(
            lambda value: calculate_significance_score(value, threshold=self.threshold)
        )

        return np.array([flags, np.array(scores)])

In [15]:
from sklearn.base import BaseEstimator, TransformerMixin

class RecalculateConfidenceScores(BaseEstimator, TransformerMixin):
    """Carry the score of a war-related fire forward in time within its grid cell.

    An event with a positive score is treated as a war-related fire. Later events in
    the same grid cell whose own score is not positive inherit a decayed share of that
    score, as long as the decayed value is positive and larger than their own score.

    ``decay_rate`` and ``midpoint`` shape the sigmoid decay (in days); ``cutoff`` is the
    number of days after which no influence is carried over at all.
    """

    def __init__(self, decay_rate, midpoint, cutoff):
        self.decay_rate = decay_rate
        self.midpoint = midpoint
        self.cutoff = cutoff

    def sigmoid_decay(self, time_diff):
        """Decay weight for an age of ``time_diff`` days; 0 once past ``cutoff``."""
        if time_diff > self.cutoff:
            return 0  # Influence reaches zero after the cut-off
        exponent = self.decay_rate * (time_diff - self.midpoint)
        return 1 / (1 + np.exp(exponent))

    def fit(self, X, y=None):
        """Nothing is learned; returns ``self`` unchanged."""
        return self

    def transform(self, X, dates=None, grid_cells=None):
        """Recompute the scores in chronological order and derive new labels.

        ``X[0]`` is read as the original labels and ``X[1]`` as the original scores
        (presumably the output of ThresholdStep.transform - confirm against the pipeline).
        ``dates`` and ``grid_cells`` are aligned element-wise with the scores.

        Returns ``[labels, recalculated_scores, X[0], X[1]]`` where ``labels`` is 1
        wherever the recalculated score is positive. Raises ``ValueError`` when
        ``dates`` or ``grid_cells`` is missing.
        """
        if dates is None or grid_cells is None:
            raise ValueError("Dates and grid cells are required for the recalculation of confidence scores.")
        y_pred, y_scores = X[0], X[1]

        # (original position, score, date, cell) for every event, ordered by date.
        # The sort is stable, so same-day events keep their input order.
        events = [
            (position, score, date, cell)
            for position, (score, date, cell) in enumerate(zip(y_scores, dates, grid_cells))
        ]
        events.sort(key=lambda event: event[2])

        # Results are written back at the original positions
        updated_scores = [None] * len(y_scores)
        # Most recent war-related fire seen so far, per grid cell
        latest_war_fire = {}

        for position, score, date, cell in events:
            if score > 0:
                # A war-related fire restarts the decay for its cell and keeps its own score
                latest_war_fire[cell] = {
                    'ACQ_DATE': date,
                    'SIGNIFICANCE_SCORE': score
                }
                updated_scores[position] = score
                continue

            if cell not in latest_war_fire:
                # No earlier war-related fire in this cell: nothing to carry over
                updated_scores[position] = score
                continue

            reference = latest_war_fire[cell]
            # Age of the reference fire in (fractional) days
            days_elapsed = (date - reference['ACQ_DATE']) / np.timedelta64(1, 'D')
            carried_over = self.sigmoid_decay(days_elapsed) * reference['SIGNIFICANCE_SCORE']
            # Only replace the score when the carried-over value is positive and larger
            if carried_over > score and carried_over > 0:
                updated_scores[position] = carried_over
            else:
                updated_scores[position] = score

        updated_scores = np.array(updated_scores)
        labels = np.where(updated_scores > 0, 1, 0)
        return [labels, updated_scores, y_pred, y_scores]

In [None]:
# Load the pipeline
# SECURITY: pickle.load can execute arbitrary code while deserialising - only load
# pipeline files that come from a trusted source.
# NOTE(review): the pickle presumably references ThresholdStep and RecalculateConfidenceScores,
# which is why both classes are defined in the cells above before loading - confirm.
import pickle

with open('saved_models/pipeline.pkl', 'rb') as file:
    pipeline = pickle.load(file)

# Show the steps of the loaded pipeline
print(pipeline)

## Explain Local Prediction

In [42]:
import lime.lime_tabular
import matplotlib.pyplot as plt

# Create a LIME explainer.
# LIME explains a prediction by sampling random perturbations around the instance;
# a fixed random_state makes the explanation reproducible across runs.
LIME_RANDOM_STATE = 42
explainer = lime.lime_tabular.LimeTabularExplainer(
    training_data=X_test, # TODO - replace with training data
    feature_names=features.columns,
    class_names=['Normal', 'Abnormal'],
    mode='regression',
    random_state=LIME_RANDOM_STATE,
)

# Choose an instance to explain
instance_idx = 5000
instance = X_test[instance_idx]


def predict_fn(x):
    """Predict the 0.95 quantile with the regressor after the scaler and PCA steps.

    LIME calls this with a 2-D array of perturbed samples in the original feature space,
    so the same preprocessing as in the pipeline is applied before the regressor.
    """
    scaled = pipeline.named_steps['scaler'].transform(x)
    reduced = pipeline.named_steps['pca'].transform(scaled)
    return pipeline.named_steps['regressor'].predict(reduced, quantiles=[0.95])


# Explain the instance using every feature
exp = explainer.explain_instance(instance, predict_fn, num_features=X_test.shape[1])

# Save the explanation as a standalone HTML file
exp.save_to_file('output_data/lime_explanation.html')

# Show the explanation
# exp.show_in_notebook(show_table=True, show_all=True)

# Show the explanation as a plot
exp.as_pyplot_figure()
plt.show()

## Save Explainer

In [None]:
# Save explainer to file so it can be reloaded later without rebuilding it.
# The output_data directory must already exist; open() does not create it.
import pickle

with open('output_data/lime_explainer.pkl', 'wb') as file:
    pickle.dump(explainer, file)