In [None]:
import copy
import gc
import json
import time
import warnings
from pathlib import Path

import numpy as np
from sklearn.ensemble import RandomForestClassifier
# NOTE(review): plot_confusion_matrix / plot_precision_recall_curve were removed in
# scikit-learn 1.2 — confirm the pinned scikit-learn version still provides them.
from sklearn.metrics import (
    average_precision_score,
    classification_report,
    confusion_matrix,
    plot_confusion_matrix,
    plot_precision_recall_curve,
    roc_auc_score,
)
from sklearn.model_selection import StratifiedKFold
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm

from wildfire_forecasting.datamodules.datasets import FireDataset_npy

In [None]:
# !IMPORTANT: set `dataset_path` to the root folder of the dataset you have downloaded.
dataset_path = None
if dataset_path is None:
    # Fail with an actionable message instead of the opaque TypeError that
    # Path(None) would raise.
    raise ValueError(
        "Set `dataset_path` to the root folder of the downloaded dataset before running this notebook."
    )
dataset_root = Path(dataset_path)

# One DataLoader per split, all built with the same dataset options.
# batch_size=1 yields one sample per iteration; only the training split is shuffled.
# NOTE(review): no torch seed is set in this cell, so the shuffled order of the
# training samples presumably differs between runs — confirm whether that matters.
dataloaders = {
    split: torch.utils.data.DataLoader(
        FireDataset_npy(
            dataset_root=dataset_root,
            train_val_test=split,
            access_mode='temporal',
            clc='vec',
        ),
        batch_size=1,
        shuffle=(split == 'train'),
        num_workers=16,
    )
    for split in ('train', 'val', 'test')
}

In [None]:
# Create the training, val and test datasets


def _build_features(dynamic, static, clc):
    """Flatten one sample from a loader into a single feature vector.

    The vector is the concatenation of:
      * the NaN-ignoring mean of ``dynamic`` over dim 1,
      * the last slice of ``dynamic`` along dim 1,
      * ``static``,
      * ``clc``.

    Every part is passed through ``.squeeze()``, which drops all size-1
    dimensions (including the batch dimension when the loader uses
    batch_size=1).
    NOTE(review): dim 1 of ``dynamic`` is presumably the time axis — confirm
    against FireDataset_npy.

    Returns:
        numpy array produced by concatenating the squeezed parts along dim 0.
    """
    dynamic_avg = dynamic.nanmean(dim=1)
    parts = [
        dynamic_avg.squeeze(),
        dynamic[:, -1, :].squeeze(),
        static.squeeze(),
        clc.squeeze(),
    ]
    return torch.cat(parts, dim=0).numpy()


def _collect_split(loader):
    """Iterate over ``loader`` once and stack its samples into numpy arrays.

    Args:
        loader: iterable yielding ``(dynamic, static, clc, label)`` tensors.

    Returns:
        ``(X, y)`` where ``X`` stacks the per-sample feature vectors and ``y``
        stacks the per-sample labels, both along a new leading axis.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    features, labels = [], []
    for dynamic, static, clc, label in loader:
        features.append(_build_features(dynamic, static, clc))
        labels.append(label.numpy())
    if not features:
        raise ValueError("The data loader yielded no samples; cannot build a dataset from it.")
    return np.stack(features, axis=0), np.stack(labels, axis=0)


X_train, y_train = _collect_split(dataloaders['train'])
X_val, y_val = _collect_split(dataloaders['val'])
X_test, y_test = _collect_split(dataloaders['test'])

In [None]:
# Random-forest hyperparameters consumed by the classifier cell below.
n_est, max_depth = 100, 10                   # number of trees, maximum tree depth
min_samples_split, min_samples_leaf = 2, 1   # minimum samples to split a node / to form a leaf

In [None]:
# Build the random forest from the hyperparameters defined above; the fixed
# random_state seeds the estimator's own randomness.
clf = RandomForestClassifier(
    n_estimators=n_est,
    max_depth=max_depth,
    min_samples_split=min_samples_split,
    min_samples_leaf=min_samples_leaf,
    random_state=1234,
)
# Labels are flattened to 1-D before fitting.
clf.fit(X_train, np.ravel(y_train))

In [None]:
# Evaluate the fitted classifier on the held-out test split.
# Flatten the labels once into a local name so the metric functions get a 1-D
# target and the X_test / y_test arrays from the dataset cell are left untouched.
y_true = np.ravel(y_test)

y_pred = clf.predict(X_test)
# Column 1 of predict_proba is the probability of the second entry of clf.classes_.
# NOTE(review): presumably that is the positive (fire) class — confirm the label encoding.
probs_pred = clf.predict_proba(X_test)[:, 1]

auc = roc_auc_score(y_true, probs_pred)
aucpr = average_precision_score(y_true, probs_pred)

print(auc)
print(aucpr)
print(classification_report(y_true, y_pred, digits=3))

In [None]:
# import pickle
# filename = 'rf.sav'
# pickle.dump(clf, open(filename, 'wb'))