In [2]:
import torch
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import MAE, SMAPE
from pytorch_forecasting.models.temporal_fusion_transformer.plot import plot_prediction
import matplotlib.pyplot as plt
import pandas as pd
import datetime
from datetime import datetime

# 1. Restore the trained Temporal Fusion Transformer from its checkpoint file
best_model_path = "../models/best-checkpoint.ckpt"
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

# Read the JSON export into a DataFrame and discard its '_id' column
df = pd.read_json('../data/wildfire_collection_feature_engineering_final.json').drop(columns='_id')

def return_date(date_string):
    """Convert an extended-JSON date wrapper into a ``datetime.date``.

    Args:
        date_string: Mapping with a ``'$date'`` key whose value is a UTC
            ISO-8601 timestamp string ending in ``Z``, with or without
            fractional seconds (e.g. ``"2023-05-17T12:34:56.789Z"`` or
            ``"2023-05-17T00:00:00Z"``).

    Returns:
        The calendar date of the timestamp; the time of day is discarded.

    Raises:
        KeyError: If the ``'$date'`` key is missing.
        ValueError: If the value matches neither supported layout.
    """
    date_value = date_string['$date']  # Get the value of the '$date' key
    # The fractional-seconds part is optional in the input, so pick the
    # layout that matches instead of always requiring ".%f".
    if '.' in date_value:
        layout = "%Y-%m-%dT%H:%M:%S.%fZ"
    else:
        layout = "%Y-%m-%dT%H:%M:%SZ"
    return datetime.strptime(date_value, layout).date()

# Compute both converted columns first, then write them back onto the frame:
# 'rep_date' entries become date objects, 'cfb' becomes float-typed.
parsed_rep_dates = df['rep_date'].apply(return_date)
cfb_as_float = df['cfb'].astype(float)
df['rep_date'] = parsed_rep_dates
df['cfb'] = cfb_as_float

NameError: name 'TemporalFusionTransformer' is not defined

In [None]:
# Split the records by report date: rows before the split date go to the
# training frame, rows on or after it go to the test frame.
# `datetime` is bound to the datetime *class* in this notebook
# (`from datetime import datetime`), so `datetime.date(2023, 1, 1)` would call
# the unbound instance method and raise TypeError. Build the date via the class.
split_date = datetime(2023, 1, 1).date()
train_data = df[df['rep_date'] < split_date]
test_data = df[df['rep_date'] >= split_date]

In [None]:
# 2. Build the test dataset
# No `training_dataset` is created anywhere in this notebook, so the dataset
# definition is recovered from the parameters stored on the loaded model.
# NOTE(review): assumes the checkpointed model was created with
# `TemporalFusionTransformer.from_dataset(...)`, which is what records these
# parameters - confirm against the training notebook.
dataset_parameters = best_tft.dataset_parameters
if dataset_parameters is None:
    raise ValueError(
        "The loaded checkpoint carries no dataset parameters; "
        "recreate the training TimeSeriesDataSet and use "
        "TimeSeriesDataSet.from_dataset instead."
    )

test_dataset = TimeSeriesDataSet.from_parameters(
    dataset_parameters,
    test_data,
    predict=True,
    stop_randomization=True,
)

In [None]:
# Create DataLoader
test_dataloader = test_dataset.to_dataloader(train=False, batch_size=128, num_workers=4)

# 3. Evaluate the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Each batch is unpacked as (x, (y, weight)); only the targets are collected.
actuals = torch.cat([y for x, (y, weight) in test_dataloader]).to(device)
predictions = best_tft.predict(test_dataloader).to(device)

# The raw-mode result is not a single tensor, so it has no `.to(device)`.
# `return_x=True` additionally wraps it together with the inputs; the cells
# below only need the raw network output, so request just that.
raw_predictions = best_tft.predict(test_dataloader, mode="raw")

# 4. Calculate metrics
mae = MAE()(predictions, actuals)
smape = SMAPE()(predictions, actuals)
print(f"MAE: {mae}")
print(f"SMAPE: {smape}")

In [None]:
# 5. Plot actual vs predictions
# Tensor.numpy() only works on CPU tensors, and both tensors were moved to the
# CUDA device when one is available, so bring them back to host memory first.
actual_values = actuals.detach().cpu().numpy()
predicted_values = predictions.detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(actual_values, label='Actual')
ax.plot(predicted_values, label='Predicted')
ax.legend()
ax.set_title('Actual vs Predicted Values')
ax.set_xlabel('Time')
ax.set_ylabel('Value')
plt.show()

In [None]:
# 6. Plot feature importances
# interpret_output returns a dictionary of tensors rather than a pandas
# object, so it has no `.plot`; the model's own plot_interpretation turns that
# dictionary into figures.
feature_importances = best_tft.interpret_output(raw_predictions, reduction="sum")
importance_figures = best_tft.plot_interpretation(feature_importances)
# NOTE(review): plot_interpretation is expected to return a mapping of
# matplotlib figures - confirm for the installed pytorch-forecasting version.
for importance_figure in importance_figures.values():
    importance_figure.tight_layout()
plt.show()

In [None]:
# 7. Interpret the predictions for a specific sample
# Select the first sample with a slice so the batch dimension is kept
# (an integer index would drop it), and reduce over that one-element batch so
# plot_interpretation receives per-variable values.
# NOTE(review): assumes plot_interpretation needs a reduced interpretation,
# as in the feature-importance cell - confirm.
first_sample_output = raw_predictions.iget(slice(0, 1))
interpretation = best_tft.interpret_output(first_sample_output, reduction="sum")
best_tft.plot_interpretation(interpretation)
plt.show()