In [1]:
import torch
import csv
import xarray as xr
import fsspec
import zarr

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

from tqdm import tqdm
import torchvision
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn.functional as F
from torchvision import models, transforms

torch.multiprocessing.set_sharing_strategy('file_system')

from pytorch_lightning import LightningModule
from torchmetrics.classification.accuracy import Accuracy
from torchmetrics import AUC, ConfusionMatrix, AUROC, AveragePrecision
from wildfire_forecasting.models.greece_fire_models import LSTM_fire_model, ConvLSTM_fire_model
# from wildfire_forecasting.models.modules.greece_fire_models import LSTM_fire_model, ConvLSTM_fire_model 

import pickle
import json
from pathlib import Path
import random
import time
import warnings
import gc

random.seed(16)



In [2]:
dimensions = ['time',
 'x',
 'y',
]

all_dynamic_features = [
 '1 km 16 days NDVI',
 'LST_Day_1km',
 'LST_Night_1km',
 'era5_max_d2m',
 'era5_max_t2m',
 'era5_max_sp',
 'era5_max_tp',
 'sminx',
 'era5_max_wind_speed',
 'era5_min_rh']

all_dynamic_features_bolam = [
 '1 km 16 days NDVI',
 'LST_Day_1km',
 'LST_Night_1km',
 'era5_max_d2m',
 'era5_max_t2m',
 'era5_max_tp',
 'sminx',
 'era5_max_wind_speed',
 'era5_min_rh']

all_static_features = [
 'dem_mean',
 'slope_mean',
 'roads_distance',
 'waterway_distance',
 'population_density'
]

all_categorical_features = 'clc_vec'
len_clc = 10

In [3]:
best_settings = {
    'lstm' : {'dynamic_features':all_dynamic_features, 'static_features':all_static_features, 'hidden_size':64, 'lstm_layers':1, 'dropout':0.5}
}

best_settings_bolam = {
    'lstm_bolam' : {'dynamic_features':all_dynamic_features_bolam, 'static_features':all_static_features, 'hidden_size':64, 'lstm_layers':1, 'dropout':0.5}
}

In [4]:
#Follow readme for downloading the models and complete the models path here
models_path = Path.home() / 'hdd1/iprapas/uc3/models'
models_path_bolam = Path.home() / 'hdd1/diogenis/observatory/wildfire_forecasting/logs/runs/2022-09-20/13-44-33/checkpoints'
model = {}

In [5]:
model['lstm'] = LSTM_fire_model(**best_settings['lstm']).load_from_checkpoint(models_path / 'lstm.ckpt')
model['lstm'].eval()



LSTM_fire_model(
  (model): SimpleLSTM(
    (ln1): LayerNorm((25,), eps=1e-05, elementwise_affine=True)
    (lstm): LSTM(25, 64, batch_first=True)
    (fc1): Linear(in_features=64, out_features=64, bias=True)
    (drop1): Dropout(p=0.5, inplace=False)
    (relu): ReLU()
    (fc2): Linear(in_features=64, out_features=32, bias=True)
    (drop2): Dropout(p=0.5, inplace=False)
    (fc3): Linear(in_features=32, out_features=2, bias=True)
    (fc_nn): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=64, out_features=32, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_features=32, out_features=2, bias=True)
    )
  )
  (criterion): NLLLoss()
  (train_accuracy): Accuracy()
  (train_auc): AUROC()
  (train_auprc): AveragePrecision()
  (val_accuracy): Accuracy()
  (val_auc): AUROC()
  (val_auprc): AveragePrecision()
  (test_accuracy): Accurac

In [6]:
model['lstm_bolam'] = LSTM_fire_model(**best_settings_bolam['lstm_bolam']).load_from_checkpoint(models_path_bolam / 'last.ckpt')
model['lstm_bolam'].eval()

LSTM_fire_model(
  (model): SimpleLSTM(
    (ln1): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
    (lstm): LSTM(24, 64, batch_first=True)
    (fc1): Linear(in_features=64, out_features=64, bias=True)
    (drop1): Dropout(p=0.5, inplace=False)
    (relu): ReLU()
    (fc2): Linear(in_features=64, out_features=32, bias=True)
    (drop2): Dropout(p=0.5, inplace=False)
    (fc3): Linear(in_features=32, out_features=2, bias=True)
    (fc_nn): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=64, out_features=32, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_features=32, out_features=2, bias=True)
    )
  )
  (criterion): NLLLoss()
  (train_accuracy): Accuracy()
  (train_auc): AUROC()
  (train_auprc): AveragePrecision()
  (val_accuracy): Accuracy()
  (val_auc): AUROC()
  (val_auprc): AveragePrecision()
  (test_accuracy): Accurac

In [7]:
dataset_root = Path.home() / 'hdd1/diogenis/observatory'

minmax_dataset_root = Path.home() / 'jh-shared/skondylatos/datasets'

variable_dict_path = dataset_root / 'variable_dict.json' 

In [8]:
with open(variable_dict_path) as f:
    variable_dict = json.load(f)

In [9]:
with open(minmax_dataset_root / 'minmax_clc_v3.json') as f:
    min_max_dict = json.load(f)

In [10]:
class FireDataset_npy(Dataset):
    def __init__(self,  src, access_mode: str = 'temporal',
                 problem_class: str = 'classification',
                 train_val_test: str = 'test', dynamic_features: list = None, static_features: list = None,
                 categorical_features: list = None, nan_fill: float = -1., neg_pos_ratio: int = 2, clc: str = None):
        """
        @param access_mode: spatial, temporal or spatiotemporal
        @param problem_class: classification or segmentation
        @param train_val_test:
                'train' gets samples from [2009-2018].
                'val' gets samples from 2019.
                test' get samples from 2020
        @param dynamic_features: selects the dynamic features to return
        @param static_features: selects the static features to return
        @param categorical_features: selects the categorical features
        @param nan_fill: Fills nan with the value specified here
        """
        if static_features is None:
            static_features = all_static_features
        if dynamic_features is None:
            dynamic_features = all_dynamic_features
            
        self.static_features = static_features
        self.dynamic_features = dynamic_features
        self.categorical_features = categorical_features
        self.access_mode = access_mode
        self.problem_class = problem_class
        self.nan_fill = nan_fill
        self.clc = clc
        self.src = src
        
        assert problem_class in ['classification', 'segmentation']
        if problem_class == 'classification':
            self.target = 'burned'
        else:
            self.target = 'burned_areas'
            
        assert self.access_mode in ['spatial', 'temporal', 'spatiotemporal']
        
        dataset_path = dataset_root
        if self.src == 'bolam':
            self.positives_list = list((dataset_path / 'positives_bolam').glob('*dynamic.npy'))
            self.negatives_list = list((dataset_path / 'negatives_bolam').glob('*dynamic.npy'))
        else:
            self.positives_list = list((dataset_path / 'positives_era5_bolam').glob('*dynamic.npy'))
            self.negatives_list = list((dataset_path / 'negatives_era5_bolam').glob('*dynamic.npy'))
        
        self.positives_list = list(zip(self.positives_list, [1] * (len(self.positives_list))))
        self.negatives_list = list(zip(self.negatives_list, [0] * (len(self.negatives_list))))
        
        val_year = 2020
        test_year = min(val_year + 1, 2021)

        self.test_positive_list = [(x, y) for (x, y) in self.positives_list if int(x.stem[:4]) == test_year]
        self.test_negative_list = random.sample(
            [(x, y) for (x, y) in self.negatives_list if int(x.stem[:4]) == test_year],
            8638)

        self.dynamic_idxfeat = [(i, feat) for i, feat in enumerate(variable_dict['dynamic']) if
                                feat in self.dynamic_features]
        self.static_idxfeat = [(i, feat) for i, feat in enumerate(variable_dict['static']) if
                               feat in self.static_features]
        
        self.dynamic_idx = [x for (x, _) in self.dynamic_idxfeat]
        self.static_idx = [x for (x, _) in self.static_idxfeat]

        if train_val_test == 'train':
            print(f'Positives: {len(self.train_positive_list)} / Negatives: {len(self.train_negative_list)}')
            self.path_list = self.train_positive_list + self.train_negative_list
        elif train_val_test == 'val':
            print(f'Positives: {len(self.val_positive_list)} / Negatives: {len(self.val_negative_list)}')
            self.path_list = self.val_positive_list + self.val_negative_list
        elif train_val_test == 'test':
            print(f'Positives: {len(self.test_positive_list)} / Negatives: {len(self.test_negative_list)}')
            self.path_list = self.test_positive_list + self.test_negative_list
            
        print("Dataset length", len(self.path_list))
        
        random.shuffle(self.path_list)
        
        self.mm_dict = self._min_max_vec()

    def _min_max_vec(self):
        mm_dict = {'min': {}, 'max': {}}
        for agg in ['min', 'max']:
            if self.access_mode == 'spatial':
                mm_dict[agg]['dynamic'] = np.ones((len(self.dynamic_features), 1, 1))
                mm_dict[agg]['static'] = np.ones((len(self.static_features), 1, 1))
                for i, (_, feat) in enumerate(self.dynamic_idxfeat):
                    mm_dict[agg]['dynamic'][i, :, :] = min_max_dict[agg][self.access_mode][feat]
                for i, (_, feat) in enumerate(self.static_idxfeat):
                    mm_dict[agg]['static'][i, :, :] = min_max_dict[agg][self.access_mode][feat]

            if self.access_mode == 'temporal':
                mm_dict[agg]['dynamic'] = np.ones((1, len(self.dynamic_features)))
                mm_dict[agg]['static'] = np.ones((len(self.static_features)))
                for i, (_, feat) in enumerate(self.dynamic_idxfeat):
                    mm_dict[agg]['dynamic'][:, i] = min_max_dict[agg][self.access_mode][feat]
                for i, (_, feat) in enumerate(self.static_idxfeat):
                    mm_dict[agg]['static'][i] = min_max_dict[agg][self.access_mode][feat]

            if self.access_mode == 'spatiotemporal':
                mm_dict[agg]['dynamic'] = np.ones((1, len(self.dynamic_features), 1, 1))
                mm_dict[agg]['static'] = np.ones((len(self.static_features), 1, 1))
                for i, (_, feat) in enumerate(self.dynamic_idxfeat):
                    mm_dict[agg]['dynamic'][:, i, :, :] = min_max_dict[agg][self.access_mode][feat]
                for i, (_, feat) in enumerate(self.static_idxfeat):
                    mm_dict[agg]['static'][i, :, :] = min_max_dict[agg][self.access_mode][feat]
        return mm_dict

    def __len__(self):
        return len(self.path_list)

    def __getitem__(self, idx):
        path, labels = self.path_list[idx]
        dynamic = np.load(path)
        static = np.load(str(path).replace('dynamic', 'static'))
        
        if self.access_mode == 'spatial':
            dynamic = dynamic[self.dynamic_idx]
            static = static[self.static_idx]
        elif self.access_mode == 'temporal':
            dynamic = dynamic[:, self.dynamic_idx, ...]
            static = static[self.static_idx]
        else:
            dynamic = dynamic[:, self.dynamic_idx, ...]
            static = static[self.static_idx]

        def _min_max_scaling(in_vec, max_vec, min_vec):
            return (in_vec - min_vec) / (max_vec - min_vec)

        dynamic = _min_max_scaling(dynamic, self.mm_dict['max']['dynamic'], self.mm_dict['min']['dynamic'])
        static = _min_max_scaling(static, self.mm_dict['max']['static'], self.mm_dict['min']['static'])

        if self.access_mode == 'temporal':
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                feat_mean = np.nanmean(dynamic, axis=0)
                # Find indices that you need to replace
                inds = np.where(np.isnan(dynamic))
                # Place column means in the indices. Align the arrays using take
                dynamic[inds] = np.take(feat_mean, inds[1])

        elif self.access_mode == 'spatiotemporal':
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                feat_mean = np.nanmean(dynamic, axis=(2, 3))
                feat_mean = feat_mean[..., np.newaxis, np.newaxis]
                feat_mean = np.repeat(feat_mean, dynamic.shape[2], axis=2)
                feat_mean = np.repeat(feat_mean, dynamic.shape[3], axis=3)
                dynamic = np.where(np.isnan(dynamic), feat_mean, dynamic)
        if self.nan_fill:
            dynamic = np.nan_to_num(dynamic, nan=self.nan_fill)
            static = np.nan_to_num(static, nan=self.nan_fill)

        if self.clc == 'mode':
            clc = np.load(str(path).replace('dynamic', 'clc_mode'))
        elif self.clc == 'vec':
            clc = np.load(str(path).replace('dynamic', 'clc_vec'))
            clc = np.nan_to_num(clc, nan=0)
        else:
            clc = 0
        return dynamic, static, clc, labels

In [11]:
cuda_device = 1
positive_weight = 0.5
device = torch.device("cuda:" + str(cuda_device) if torch.cuda.is_available() else "cpu")

weights = [1 - positive_weight, positive_weight]
class_weights = torch.FloatTensor(weights)
criterion = nn.NLLLoss(weight=class_weights)
num_epochs=40

dataloaders = {}

dataloaders['lstm'] ={'bolam' : torch.utils.data.DataLoader(FireDataset_npy('bolam', train_val_test='test', access_mode = 'temporal', clc = 'vec'), batch_size=256, num_workers=16),
                      'era5' : torch.utils.data.DataLoader(FireDataset_npy('era5', train_val_test='test', access_mode = 'temporal', clc = 'vec'), batch_size=256, num_workers=16)}

dataloaders['lstm_bolam'] = {'bolam_2' : torch.utils.data.DataLoader(FireDataset_npy('bolam', train_val_test='test', access_mode = 'temporal', clc = 'vec', dynamic_features = all_dynamic_features_bolam), batch_size=256, num_workers=16),
                           'era5_2': torch.utils.data.DataLoader(FireDataset_npy('era5', train_val_test='test', access_mode = 'temporal', clc = 'vec', dynamic_features = all_dynamic_features_bolam), batch_size=256, num_workers=16)}

Positives: 4319 / Negatives: 8638
Dataset length 12957
Positives: 4319 / Negatives: 8638
Dataset length 12957
Positives: 4319 / Negatives: 8638
Dataset length 12957
Positives: 4319 / Negatives: 8638
Dataset length 12957


In [12]:
def get_precision_and_recall(output, labels, running_true_positives_fire, running_false_positives_fire, running_false_negatives_fire, running_true_positives_non_fire, running_false_positives_non_fire, running_false_negatives_non_fire):
    for j in range(output.size()[0]):
        if output[j] == 1 and labels[j] == 1:
            running_true_positives_fire +=1
        if output[j] == 1 and labels[j] == 0:
            running_false_positives_fire +=1
        if output[j] == 0 and labels[j] == 1:
            running_false_negatives_fire +=1
        if output[j] == 0 and labels[j] == 0:
            running_true_positives_non_fire +=1
        if output[j] == 0 and labels[j] == 1:
            running_false_positives_non_fire +=1
        if output[j] == 1 and labels[j] == 0:
            running_false_negatives_non_fire +=1
    return running_true_positives_fire, running_false_positives_fire, running_false_negatives_fire, running_true_positives_non_fire, running_false_positives_non_fire, running_false_negatives_non_fire

In [13]:
since = time.time()

for data in ['bolam', 'era5']:
    preds = []
    true_labels = []
    running_loss = 0.0
    running_corrects = 0
    running_true_positives_fire = 0
    running_false_positives_fire = 0
    running_false_negatives_fire = 0
    running_true_positives_non_fire = 0
    running_false_positives_non_fire = 0
    running_false_negatives_non_fire = 0
    # Iterate over data.
    for i, (dynamic, static, clc, labels) in enumerate(tqdm(dataloaders['lstm'][data])):
        static = static.unsqueeze(1).repeat(1, dynamic.shape[1], 1)
        clc = clc.unsqueeze(1).repeat(1, dynamic.shape[1], 1)
        input_ = torch.cat([dynamic, static, clc], dim = 2).float()
        with torch.set_grad_enabled(data == 'train'):
            outputs_list = []
            outputs_list.append(model['lstm'](input_))
            outputs = torch.stack(outputs_list, dim=1)
            mean = outputs.mean(1)
            loss = criterion(mean, labels)
        # statistics
        running_loss += loss.item() * dynamic.size(0)
        output = mean

        preds.append(output[:,1])
        true_labels.append(labels)
        output = torch.argmax(output, dim=1)

        correct = (output == labels).float().sum()
        running_corrects += correct

        running_true_positives_fire, running_false_positives_fire, running_false_negatives_fire, running_true_positives_non_fire, running_false_positives_non_fire, running_false_negatives_non_fire = get_precision_and_recall(output, labels, running_true_positives_fire, running_false_positives_fire, running_false_negatives_fire, running_true_positives_non_fire, running_false_positives_non_fire, running_false_negatives_non_fire)

    preds = torch.cat(preds, dim=0).detach().cpu().numpy()
    true_labels = torch.cat(true_labels, dim=0).detach().cpu().numpy()
    
    auc = roc_auc_score(true_labels, preds)
    aucpr = average_precision_score(true_labels, preds)

    epoch_loss = running_loss / len(dataloaders['lstm'][data].dataset)
    time_elapsed = time.time() - since

    print('{} Loss: {:.4f}, Accuracy: {:.4f}, PrecisionFire: {:.4f}, RecallFire: {:.4f}, F1Fire: {:4f}, PrecisionNonFire: {:.4f}, RecallNonFire: {:.4f}, F1NonFire: {:4f}, AUC: {:4f}, AUPRC: {:4f} in {:.4f}m'.format(data, 
                                                                                  epoch_loss, running_corrects/len(dataloaders['lstm'][data].dataset),
                                                                                  running_true_positives_fire/(running_true_positives_fire + running_false_positives_fire),
                                                                                  running_true_positives_fire/(running_true_positives_fire + running_false_negatives_fire),
                                                                                  running_true_positives_fire/(running_true_positives_fire + (1/2)*(running_false_positives_fire + running_false_negatives_fire)),
                                                                                  running_true_positives_non_fire/(running_true_positives_non_fire + running_false_positives_non_fire),
                                                                                  running_true_positives_non_fire/(running_true_positives_non_fire + running_false_negatives_non_fire),
                                                                                  running_true_positives_non_fire/(running_true_positives_non_fire + (1/2)*(running_false_positives_non_fire + running_false_negatives_non_fire)),
                                                                                  auc, aucpr, time_elapsed/60))

    print('Confusion Matrix')
    print(running_true_positives_non_fire, running_false_negatives_non_fire)
    print(running_false_negatives_fire, running_true_positives_fire)

gc.collect()

100%|██████████| 51/51 [00:03<00:00, 13.32it/s]


bolam Loss: 0.4146, Accuracy: 0.8321, PrecisionFire: 0.7147, RecallFire: 0.8263, F1Fire: 0.766455, PrecisionNonFire: 0.9058, RecallNonFire: 0.8350, F1NonFire: 0.868984, AUC: 0.912188, AUPRC: 0.826506 in 0.0645m
Confusion Matrix
7213 1425
750 3569


100%|██████████| 51/51 [00:02<00:00, 17.83it/s]


era5 Loss: 0.2635, Accuracy: 0.8934, PrecisionFire: 0.8233, RecallFire: 0.8662, F1Fire: 0.844184, PrecisionNonFire: 0.9313, RecallNonFire: 0.9070, F1NonFire: 0.919008, AUC: 0.955644, AUPRC: 0.915685 in 0.1130m
Confusion Matrix
7835 803
578 3741


0

In [14]:
since = time.time()

for data in ['bolam_2', 'era5_2']:
    preds = []
    true_labels = []
    running_loss = 0.0
    running_corrects = 0
    running_true_positives_fire = 0
    running_false_positives_fire = 0
    running_false_negatives_fire = 0
    running_true_positives_non_fire = 0
    running_false_positives_non_fire = 0
    running_false_negatives_non_fire = 0
    # Iterate over data.
    for i, (dynamic, static, clc, labels) in enumerate(tqdm(dataloaders['lstm_bolam'][data])):
        static = static.unsqueeze(1).repeat(1, dynamic.shape[1], 1)
        clc = clc.unsqueeze(1).repeat(1, dynamic.shape[1], 1)
        input_ = torch.cat([dynamic, static, clc], dim = 2).float()
        with torch.set_grad_enabled(data == 'train'):
            outputs_list = []
            outputs_list.append(model['lstm_bolam'](input_))
            outputs = torch.stack(outputs_list, dim=1)
            mean = outputs.mean(1)
            loss = criterion(mean, labels)
        # statistics
        running_loss += loss.item() * dynamic.size(0)
        output = mean

        preds.append(output[:,1])
        true_labels.append(labels)
        output = torch.argmax(output, dim=1)

        correct = (output == labels).float().sum()
        running_corrects += correct

        running_true_positives_fire, running_false_positives_fire, running_false_negatives_fire, running_true_positives_non_fire, running_false_positives_non_fire, running_false_negatives_non_fire = get_precision_and_recall(output, labels, running_true_positives_fire, running_false_positives_fire, running_false_negatives_fire, running_true_positives_non_fire, running_false_positives_non_fire, running_false_negatives_non_fire)

    preds = torch.cat(preds, dim=0).detach().cpu().numpy()
    true_labels = torch.cat(true_labels, dim=0).detach().cpu().numpy()
    
    auc = roc_auc_score(true_labels, preds)
    aucpr = average_precision_score(true_labels, preds)

    epoch_loss = running_loss / len(dataloaders['lstm_bolam'][data].dataset)
    time_elapsed = time.time() - since

    print('{} Loss: {:.4f}, Accuracy: {:.4f}, PrecisionFire: {:.4f}, RecallFire: {:.4f}, F1Fire: {:4f}, PrecisionNonFire: {:.4f}, RecallNonFire: {:.4f}, F1NonFire: {:4f}, AUC: {:4f}, AUPRC: {:4f} in {:.4f}m'.format(data, 
                                                                                  epoch_loss, running_corrects/len(dataloaders['lstm_bolam'][data].dataset),
                                                                                  running_true_positives_fire/(running_true_positives_fire + running_false_positives_fire),
                                                                                  running_true_positives_fire/(running_true_positives_fire + running_false_negatives_fire),
                                                                                  running_true_positives_fire/(running_true_positives_fire + (1/2)*(running_false_positives_fire + running_false_negatives_fire)),
                                                                                  running_true_positives_non_fire/(running_true_positives_non_fire + running_false_positives_non_fire),
                                                                                  running_true_positives_non_fire/(running_true_positives_non_fire + running_false_negatives_non_fire),
                                                                                  running_true_positives_non_fire/(running_true_positives_non_fire + (1/2)*(running_false_positives_non_fire + running_false_negatives_non_fire)),
                                                                                  auc, aucpr, time_elapsed/60))

    print('Confusion Matrix')
    print(running_true_positives_non_fire, running_false_negatives_non_fire)
    print(running_false_negatives_fire, running_true_positives_fire)

gc.collect()

100%|██████████| 51/51 [00:02<00:00, 21.79it/s]


bolam_2 Loss: 0.4084, Accuracy: 0.8353, PrecisionFire: 0.7072, RecallFire: 0.8634, F1Fire: 0.777523, PrecisionNonFire: 0.9232, RecallNonFire: 0.8213, F1NonFire: 0.869256, AUC: 0.917479, AUPRC: 0.824759 in 0.0391m
Confusion Matrix
7094 1544
590 3729


100%|██████████| 51/51 [00:02<00:00, 20.46it/s]


era5_2 Loss: 0.2740, Accuracy: 0.8844, PrecisionFire: 0.7903, RecallFire: 0.8891, F1Fire: 0.836784, PrecisionNonFire: 0.9408, RecallNonFire: 0.8820, F1NonFire: 0.910492, AUC: 0.955511, AUPRC: 0.917751 in 0.0809m
Confusion Matrix
7619 1019
479 3840


0