# First contact with the dataset
This Notebook has as objective to replicate the baseline results from Minixhofer et al. (2021)

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
from scipy.interpolate import interp1d
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import f1_score, mean_absolute_error

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter



**Note :** Just to have a normalized methodology to save and visualize the results of all the experiments trought this projet I add and configure a Tensorboard-SummaryWriter. In the same way I have changed the training cycle to put the results in the tensorboard format.

In [2]:
writer = SummaryWriter('../logs/LSTM_Baseline')

## Importing the data in a unique dictionary

In [3]:
filesList = os.listdir('../src')
print(filesList)

['soil_data.csv', 'train_timeseries', 'test_timeseries', 'validation_timeseries']


In [4]:
dataDic = {"train": pd.read_csv("../src/train_timeseries/train_timeseries.csv"),
           "test": pd.read_csv("../src/test_timeseries/test_timeseries.csv"),
           "validation": pd.read_csv("../src/validation_timeseries/validation_timeseries.csv"),
           "soil" : pd.read_csv("../src/soil_data.csv"),
           }


In [5]:
dataDic["train"].columns

Index(['fips', 'date', 'PRECTOT', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET',
       'T2M_MAX', 'T2M_MIN', 'T2M_RANGE', 'TS', 'WS10M', 'WS10M_MAX',
       'WS10M_MIN', 'WS10M_RANGE', 'WS50M', 'WS50M_MAX', 'WS50M_MIN',
       'WS50M_RANGE', 'score'],
      dtype='object')

In [6]:
class2id = {
    'None': 0,
    'D0': 1,
    'D1': 2,
    'D2': 3,
    'D3': 4,
    'D4': 5,
}
id2class = {v: k for k, v in class2id.items()}

In [7]:
dfs = {
    k: dataDic[k].set_index(['fips', 'date'])
    for k in dataDic.keys() if k != "soil"
}

dfs["soil"] = dataDic["soil"]

In [8]:
dfs["train"]

Unnamed: 0_level_0,Unnamed: 1_level_0,PRECTOT,PS,QV2M,T2M,T2MDEW,T2MWET,T2M_MAX,T2M_MIN,T2M_RANGE,TS,WS10M,WS10M_MAX,WS10M_MIN,WS10M_RANGE,WS50M,WS50M_MAX,WS50M_MIN,WS50M_RANGE,score
fips,date,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
1001,2000-01-01,0.22,100.51,9.65,14.74,13.51,13.51,20.96,11.46,9.50,14.65,2.20,2.94,1.49,1.46,4.85,6.04,3.23,2.81,
1001,2000-01-02,0.20,100.55,10.42,16.69,14.71,14.71,22.80,12.61,10.18,16.60,2.52,3.43,1.83,1.60,5.33,6.13,3.72,2.41,
1001,2000-01-03,3.65,100.15,11.76,18.49,16.52,16.52,22.73,15.32,7.41,18.41,4.03,5.33,2.66,2.67,7.53,9.52,5.87,3.66,
1001,2000-01-04,15.95,100.29,6.42,11.40,6.09,6.10,18.09,2.16,15.92,11.31,3.84,5.67,2.08,3.59,6.73,9.31,3.74,5.58,1.0
1001,2000-01-05,0.00,101.15,2.95,3.86,-3.29,-3.20,10.82,-2.66,13.48,2.65,1.60,2.50,0.52,1.98,2.94,4.85,0.65,4.19,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
56043,2016-12-27,0.16,82.88,1.63,-7.97,-13.49,-12.81,-1.39,-13.60,12.21,-9.41,5.90,7.63,3.61,4.02,8.58,10.39,5.92,4.47,0.0
56043,2016-12-28,0.02,83.33,1.41,-8.71,-14.10,-13.84,-2.49,-13.56,11.07,-10.55,6.50,11.43,4.11,7.32,9.92,14.49,7.26,7.22,
56043,2016-12-29,0.00,83.75,1.59,-7.96,-13.30,-13.03,0.42,-14.51,14.93,-10.29,4.29,6.24,2.03,4.22,6.56,10.07,3.20,6.87,
56043,2016-12-30,1.22,82.49,2.63,-2.94,-7.40,-7.33,3.76,-6.86,10.62,-4.14,4.98,7.34,1.99,5.35,7.28,10.12,3.24,6.89,


## Interpolation pour les données manquantes

In [9]:
def interpolate_nans(padata, pkind='linear'):
    """
    see: https://stackoverflow.com/a/53050216/2167159
    """
    aindexes = np.arange(padata.shape[0])
    agood_indexes, = np.where(np.isfinite(padata))
    f = interp1d(agood_indexes
               , padata[agood_indexes]
               , bounds_error=False
               , copy=False
               , fill_value="extrapolate"
               , kind=pkind)
    return f(aindexes)

## Function to encode the cycling feature: year-day, using sin/cos

In [10]:
def date_encode(date):
    if isinstance(date, str):
        date = datetime.strptime(date, "%Y-%m-%d")
    return (
        np.sin(2 * np.pi * date.timetuple().tm_yday / 366),
        np.cos(2 * np.pi * date.timetuple().tm_yday / 366),
    )

## Function to load the data

In [11]:
def loadXY(
    df,
    random_state=42,
    window_size=180, # how many days in the past (default/competition: 180)
    target_size=6, # how many weeks into the future (default/competition: 6)
    fuse_past=True, # add the past drought observations? (default: True)
    return_fips=False, # return the county identifier (do not use for predictions)
    encode_season=True, # encode the season using the function above (default: True) 
    use_prev_year=False, # add observations from 1 year prior?
):
    df = dfs[df]
    soil_df = dfs["soil"]
    time_data_cols = sorted(
        [c for c in df.columns if c not in ["fips", "date", "score"]]
    )
    static_data_cols = sorted(
        [c for c in soil_df.columns if c not in ["soil", "lat", "lon"]]
    )
    count = 0
    score_df = df.dropna(subset=["score"])
    X_static = np.empty((len(df) // window_size, len(static_data_cols)))
    X_fips_date = []
    add_dim = 0
    if use_prev_year:
        add_dim += len(time_data_cols)
    if fuse_past:
        add_dim += 1
        if use_prev_year:
            add_dim += 1
    if encode_season:
        add_dim += 2
    X_time = np.empty(
        (len(df) // window_size, window_size, len(time_data_cols) + add_dim)
    )
    y_past = np.empty((len(df) // window_size, window_size))
    y_target = np.empty((len(df) // window_size, target_size))
    if random_state is not None:
        np.random.seed(random_state)
    for fips in tqdm(score_df.index.get_level_values(0).unique()):
        if random_state is not None:
            start_i = np.random.randint(1, window_size)
        else:
            start_i = 1
        fips_df = df[(df.index.get_level_values(0) == fips)]
        X = fips_df[time_data_cols].values
        y = fips_df["score"].values
        X_s = soil_df[soil_df["fips"] == fips][static_data_cols].values[0]
        for i in range(start_i, len(y) - (window_size + target_size * 7), window_size):
            X_fips_date.append((fips, fips_df.index[i : i + window_size][-1]))
            X_time[count, :, : len(time_data_cols)] = X[i : i + window_size]
            if use_prev_year:
                if i < 365 or len(X[i - 365 : i + window_size - 365]) < window_size:
                    continue
                X_time[count, :, -len(time_data_cols) :] = X[
                    i - 365 : i + window_size - 365
                ]
            if not fuse_past:
                y_past[count] = interpolate_nans(y[i : i + window_size])
            else:
                X_time[count, :, len(time_data_cols)] = interpolate_nans(
                    y[i : i + window_size]
                )
            if encode_season:
                enc_dates = [
                    date_encode(d) for f, d in fips_df.index[i : i + window_size].values
                ]
                d_sin, d_cos = [s for s, c in enc_dates], [c for s, c in enc_dates]
                X_time[count, :, len(time_data_cols) + (add_dim - 2)] = d_sin
                X_time[count, :, len(time_data_cols) + (add_dim - 2) + 1] = d_cos
            temp_y = y[i + window_size : i + window_size + target_size * 7]
            y_target[count] = np.array(temp_y[~np.isnan(temp_y)][:target_size])
            X_static[count] = X_s
            count += 1
    print(f"loaded {count} samples")
    results = [X_static[:count], X_time[:count], y_target[:count]]
    if not fuse_past:
        results.append(y_past[:count])
    if return_fips:
        results.append(X_fips_date)
    return results

In [12]:
scaler_dict = {}
scaler_dict_static = {}
scaler_dict_past = {}


def normalize(X_static, X_time, y_past=None, fit=False):
    for index in tqdm(range(X_time.shape[-1])):
        if fit:
            scaler_dict[index] = RobustScaler().fit(X_time[:, :, index].reshape(-1, 1))
        X_time[:, :, index] = (
            scaler_dict[index]
            .transform(X_time[:, :, index].reshape(-1, 1))
            .reshape(-1, X_time.shape[-2])
        )
    for index in tqdm(range(X_static.shape[-1])):
        if fit:
            scaler_dict_static[index] = RobustScaler().fit(
                X_static[:, index].reshape(-1, 1)
            )
        X_static[:, index] = (
            scaler_dict_static[index]
            .transform(X_static[:, index].reshape(-1, 1))
            .reshape(1, -1)
        )
    index = 0
    if y_past is not None:
        if fit:
            scaler_dict_past[index] = RobustScaler().fit(y_past.reshape(-1, 1))
        y_past[:, :] = (
            scaler_dict_past[index]
            .transform(y_past.reshape(-1, 1))
            .reshape(-1, y_past.shape[-1])
        )
        return X_static, X_time, y_past
    return X_static, X_time

In [None]:
X_tabular_train, X_time_train, y_target_train = loadXY("train")
print("train shape", X_time_train.shape)
X_tabular_validation, X_time_valid, y_target_valid, valid_fips = loadXY("validation", return_fips=True)
print("validation shape", X_time_valid.shape)
X_tabular_train, X_time_train = normalize(X_tabular_train, X_time_train, fit=True)
X_tabular_validation, X_time_valid = normalize(X_tabular_validation, X_time_valid)

100%|██████████| 3108/3108 [10:32<00:00,  4.92it/s]


loaded 103390 samples
train shape (103390, 180, 21)


 73%|███████▎  | 2261/3108 [00:35<00:13, 64.56it/s]

In [None]:
batch_size = 128
output_weeks = 6
use_static = True
hidden_dim = 512
n_layers = 2
ffnn_layers = 2
dropout = 0.1
one_cycle = True
lr = 7e-5
epochs = 10
clip = 5

In [None]:
train_data = TensorDataset(
    torch.tensor(X_time_train),
    torch.tensor(X_tabular_train),
    torch.tensor(y_target_train[:, :output_weeks]),
)
train_loader = DataLoader(
    train_data, shuffle=True, batch_size=batch_size, drop_last=False
)
valid_data = TensorDataset(
    torch.tensor(X_time_valid),
    torch.tensor(X_tabular_validation),
    torch.tensor(y_target_valid[:, :output_weeks]),
)
valid_loader = DataLoader(
    valid_data, shuffle=False, batch_size=batch_size, drop_last=False
)

In [None]:
class DroughtNetLSTM(nn.Module):
    def __init__(
        self,
        output_size,
        num_input_features,
        hidden_dim,
        n_layers,
        ffnn_layers,
        drop_prob,
        static_dim=0,
    ):
        super(DroughtNetLSTM, self).__init__()
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        self.lstm = nn.LSTM(
            num_input_features,
            hidden_dim,
            n_layers,
            dropout=drop_prob,
            batch_first=True,
        )
        self.dropout = nn.Dropout(drop_prob)
        self.fflayers = []
        for i in range(ffnn_layers - 1):
            if i == 0:
                self.fflayers.append(nn.Linear(hidden_dim + static_dim, hidden_dim))
            else:
                self.fflayers.append(nn.Linear(hidden_dim, hidden_dim))
        self.fflayers = nn.ModuleList(self.fflayers)
        self.final = nn.Linear(hidden_dim, output_size)

    def forward(self, x, hidden, static=None):
        batch_size = x.size(0)
        x = x.to(dtype=torch.float32)
        if static is not None:
            static = static.to(dtype=torch.float32)
        lstm_out, hidden = self.lstm(x, hidden)
        lstm_out = lstm_out[:, -1, :]

        out = self.dropout(lstm_out)
        for i in range(len(self.fflayers)):
            if i == 0 and static is not None:
                out = self.fflayers[i](torch.cat((out, static), 1))
            else:
                out = self.fflayers[i](out)
        out = self.final(out)

        out = out.view(batch_size, -1)
        return out, hidden

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = (
            weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device),
            weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device),
        )
        return hidden

In [None]:
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device("cuda")
    print("using GPU")
else:
    device = torch.device("cpu")
    print("using CPU")
static_dim = 0
if use_static:
    static_dim = X_tabular_train.shape[-1]
model = DroughtNetLSTM(
    output_weeks,
    X_time_train.shape[-1],
    hidden_dim,
    n_layers,
    ffnn_layers,
    dropout,
    static_dim,
)
model.to(device)
loss_function = nn.MSELoss()
if one_cycle:
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=lr, steps_per_epoch=len(train_loader), epochs=epochs
    )
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
counter = 0
valid_loss_min = np.inf
torch.manual_seed(42)
np.random.seed(42)
for i in range(epochs):
    h = model.init_hidden(batch_size)

    for k, (inputs, static, labels) in tqdm(
        enumerate(train_loader),
        desc=f"epoch {i+1}/{epochs}",
        total=len(train_loader),
    ):
        model.train()
        counter += 1
        if len(inputs) < batch_size:
            h = model.init_hidden(len(inputs))
        h = tuple([e.data for e in h])
        inputs, labels, static = (
            inputs.to(device),
            labels.to(device),
            static.to(device),
        )
        model.zero_grad()
        if use_static:
            output, h = model(inputs, h, static)
        else:
            output, h = model(inputs, h)
        loss = loss_function(output, labels.float())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        if one_cycle:
            scheduler.step()

        with torch.no_grad():
            if k == len(train_loader) - 1 or k == (len(train_loader) - 1) // 2:
                val_h = model.init_hidden(batch_size)
                val_losses = []
                model.eval()
                labels = []
                preds = []
                raw_labels = []
                raw_preds = []
                for inp, stat, lab in valid_loader:
                    if len(inp) < batch_size:
                        val_h = model.init_hidden(len(inp))
                    val_h = tuple([each.data for each in val_h])
                    inp, lab, stat = inp.to(device), lab.to(device), stat.to(device)
                    if use_static:
                        out, val_h = model(inp, val_h, stat)
                    else:
                        out, val_h = model(inp, val_h)
                    val_loss = loss_function(out, lab.float())
                    val_losses.append(val_loss.item())
                    for labs in lab:
                        labels.append([int(l.round()) for l in labs])
                        raw_labels.append([float(l) for l in labs])
                    for pred in out:
                        preds.append([int(p.round()) for p in pred])
                        raw_preds.append([float(p) for p in pred])
                # log data
                labels = np.array(labels)
                preds = np.clip(np.array(preds), 0, 5)
                raw_preds = np.array(raw_preds)
                raw_labels = np.array(raw_labels)
                for i in range(output_weeks):
                    log_dict = {
                        "loss": float(loss),
                        "epoch": counter / len(train_loader),
                        "step": counter,
                        "lr": optimizer.param_groups[0]["lr"],
                        "week": i + 1,
                    }
                    # w = f'week_{i+1}_'
                    w = ""
                    log_dict[f"{w}validation_loss"] = np.mean(val_losses)
                    log_dict[f"{w}macro_f1"] = f1_score(
                        labels[:, i], preds[:, i], average="macro"
                    )
                    log_dict[f"{w}micro_f1"] = f1_score(
                        labels[:, i], preds[:, i], average="micro"
                    )
                    log_dict[f"{w}mae"] = mean_absolute_error(
                        raw_labels[:, i], raw_preds[:, i]
                    )
                    print(log_dict)
                    writer.add_scalars("Loss(MSE)", {'train': loss,
                                                     'validation': log_dict[f"{w}validation_loss"]},
                                                     counter)
                    writer.add_scalars("F1(MSE)", {'macro': log_dict[f"{w}macro_f1"],
                                                   'micro': log_dict[f"{w}micro_f1"]},
                                                   counter)
                    writer.add_scalar("MAE", log_dict[f"{w}mae"],
                                      counter)
                    writer.add_scalar("Learning-Rate", log_dict["lr"],
                                      counter)
                    for j, f1 in enumerate(
                        f1_score(labels[:, i], preds[:, i], average=None)
                    ):
                        log_dict[f"{w}{id2class[j]}_f1"] = f1
                    model.train()
                if np.mean(val_losses) <= valid_loss_min:
                    torch.save(model.state_dict(), "./state_dict.pt")
                    print(
                        "Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...".format(
                            valid_loss_min, np.mean(val_losses)
                        )
                    )
                    valid_loss_min = np.mean(val_losses)

using GPU


epoch 1/10:  50%|█████     | 404/808 [02:24<22:52,  3.40s/it]

{'loss': 0.40540656447410583, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 1, 'validation_loss': np.float64(0.417090016754641), 'macro_f1': np.float64(0.4940091142245857), 'micro_f1': np.float64(0.7078189300411523), 'mae': np.float64(0.4269706536039974)}
{'loss': 0.40540656447410583, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 2, 'validation_loss': np.float64(0.417090016754641), 'macro_f1': np.float64(0.43930610203028847), 'micro_f1': np.float64(0.6696387745770462), 'mae': np.float64(0.44217968909021965)}
{'loss': 0.40540656447410583, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 3, 'validation_loss': np.float64(0.417090016754641), 'macro_f1': np.float64(0.39762044439712385), 'micro_f1': np.float64(0.6274577046181985), 'mae': np.float64(0.5103915951933662)}
{'loss': 0.40540656447410583, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 4, 'validation_loss': np.float64(0.417090016754641), 'macro_f1': np.float64(0.353

epoch 1/10: 100%|██████████| 808/808 [04:52<00:00,  2.76it/s]


{'loss': 0.22620344161987305, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 1, 'validation_loss': np.float64(0.2483354107193325), 'macro_f1': np.float64(0.7564181479492601), 'micro_f1': np.float64(0.8683127572016461), 'mae': np.float64(0.22443508443987303)}
{'loss': 0.22620344161987305, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 2, 'validation_loss': np.float64(0.2483354107193325), 'macro_f1': np.float64(0.6855000665796022), 'micro_f1': np.float64(0.8176726108824874), 'mae': np.float64(0.26823645023587495)}
{'loss': 0.22620344161987305, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 3, 'validation_loss': np.float64(0.2483354107193325), 'macro_f1': np.float64(0.6347174043509513), 'micro_f1': np.float64(0.7790352080475538), 'mae': np.float64(0.3154690905890926)}
{'loss': 0.22620344161987305, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 4, 'validation_loss': np.float64(0.2483354107193325), 'macro_f1': np.float6

epoch 2/10:  50%|█████     | 404/808 [02:27<23:05,  3.43s/it]

{'loss': 0.3091033399105072, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 1, 'validation_loss': np.float64(0.23797678591116614), 'macro_f1': np.float64(0.7745300266900594), 'micro_f1': np.float64(0.880315500685871), 'mae': np.float64(0.19007316553764364)}
{'loss': 0.3091033399105072, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 2, 'validation_loss': np.float64(0.23797678591116614), 'macro_f1': np.float64(0.7058232436011548), 'micro_f1': np.float64(0.8279606767261088), 'mae': np.float64(0.24571470459828948)}
{'loss': 0.3091033399105072, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 3, 'validation_loss': np.float64(0.23797678591116614), 'macro_f1': np.float64(0.649100043056262), 'micro_f1': np.float64(0.7848651120256058), 'mae': np.float64(0.29648405554843155)}
{'loss': 0.3091033399105072, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 4, 'validation_loss': np.float64(0.23797678591116614), 'macro_f1': np.flo

epoch 2/10: 100%|██████████| 808/808 [04:54<00:00,  2.74it/s]


{'loss': 0.2376997321844101, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 1, 'validation_loss': np.float64(0.2270330902243006), 'macro_f1': np.float64(0.6335386981742628), 'micro_f1': np.float64(0.8756287151348879), 'mae': np.float64(0.17309437069325576)}
{'loss': 0.2376997321844101, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 2, 'validation_loss': np.float64(0.2270330902243006), 'macro_f1': np.float64(0.5577408352846581), 'micro_f1': np.float64(0.823045267489712), 'mae': np.float64(0.23149042904565956)}
{'loss': 0.2376997321844101, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 3, 'validation_loss': np.float64(0.2270330902243006), 'macro_f1': np.float64(0.4570009401387251), 'micro_f1': np.float64(0.7765203475080018), 'mae': np.float64(0.2888461066796681)}
{'loss': 0.2376997321844101, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 4, 'validation_loss': np.float64(0.2270330902243006), 'macro_f1': np.float64(0.4

epoch 3/10:  50%|█████     | 404/808 [02:26<23:01,  3.42s/it]

{'loss': 0.25073882937431335, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 1, 'validation_loss': np.float64(0.22737590412514797), 'macro_f1': np.float64(0.7078199458082969), 'micro_f1': np.float64(0.883973479652492), 'mae': np.float64(0.17051862244237584)}
{'loss': 0.25073882937431335, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 2, 'validation_loss': np.float64(0.22737590412514797), 'macro_f1': np.float64(0.6423269350899753), 'micro_f1': np.float64(0.8273891175125743), 'mae': np.float64(0.23382204041227642)}
{'loss': 0.25073882937431335, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 3, 'validation_loss': np.float64(0.22737590412514797), 'macro_f1': np.float64(0.5872817909270741), 'micro_f1': np.float64(0.7824645633287609), 'mae': np.float64(0.2929405957964503)}
{'loss': 0.25073882937431335, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 4, 'validation_loss': np.float64(0.22737590412514797), 'macro_f1': np.flo

epoch 3/10: 100%|██████████| 808/808 [04:53<00:00,  2.75it/s]


{'loss': 0.2585916221141815, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 1, 'validation_loss': np.float64(0.2271125737739646), 'macro_f1': np.float64(0.7757741172943581), 'micro_f1': np.float64(0.8808870598994056), 'mae': np.float64(0.17477316567382534)}
{'loss': 0.2585916221141815, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 2, 'validation_loss': np.float64(0.2271125737739646), 'macro_f1': np.float64(0.7116820630992358), 'micro_f1': np.float64(0.8291037951531779), 'mae': np.float64(0.23097523363249625)}
{'loss': 0.2585916221141815, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 3, 'validation_loss': np.float64(0.2271125737739646), 'macro_f1': np.float64(0.6688056528716951), 'micro_f1': np.float64(0.7865797896662095), 'mae': np.float64(0.2858168105063925)}
{'loss': 0.2585916221141815, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 4, 'validation_loss': np.float64(0.2271125737739646), 'macro_f1': np.float64(0.6008

epoch 4/10:  50%|█████     | 404/808 [02:26<22:59,  3.41s/it]

{'loss': 0.26915982365608215, 'epoch': 3.5, 'step': 2828, 'lr': 6.911814926126814e-05, 'week': 1, 'validation_loss': np.float64(0.2121999178999576), 'macro_f1': np.float64(0.753308167107217), 'micro_f1': np.float64(0.8910608139003201), 'mae': np.float64(0.15324005312755828)}
{'loss': 0.26915982365608215, 'epoch': 3.5, 'step': 2828, 'lr': 6.911814926126814e-05, 'week': 2, 'validation_loss': np.float64(0.2121999178999576), 'macro_f1': np.float64(0.675353898522952), 'micro_f1': np.float64(0.8345907636031092), 'mae': np.float64(0.21276099123102263)}
{'loss': 0.26915982365608215, 'epoch': 3.5, 'step': 2828, 'lr': 6.911814926126814e-05, 'week': 3, 'validation_loss': np.float64(0.2121999178999576), 'macro_f1': np.float64(0.6318832386969778), 'micro_f1': np.float64(0.7927526291723822), 'mae': np.float64(0.26472011925474775)}
{'loss': 0.26915982365608215, 'epoch': 3.5, 'step': 2828, 'lr': 6.911814926126814e-05, 'week': 4, 'validation_loss': np.float64(0.2121999178999576), 'macro_f1': np.float64

epoch 4/10: 100%|██████████| 808/808 [04:53<00:00,  2.75it/s]


{'loss': 0.17046396434307098, 'epoch': 4.0, 'step': 3232, 'lr': 6.652548447282524e-05, 'week': 1, 'validation_loss': np.float64(0.20936721471556718), 'macro_f1': np.float64(0.8192831504050435), 'micro_f1': np.float64(0.8975765889346137), 'mae': np.float64(0.1570147752547805)}
{'loss': 0.17046396434307098, 'epoch': 4.0, 'step': 3232, 'lr': 6.652548447282524e-05, 'week': 2, 'validation_loss': np.float64(0.20936721471556718), 'macro_f1': np.float64(0.7117567426174917), 'micro_f1': np.float64(0.8399634202103338), 'mae': np.float64(0.21640699181080214)}
{'loss': 0.17046396434307098, 'epoch': 4.0, 'step': 3232, 'lr': 6.652548447282524e-05, 'week': 3, 'validation_loss': np.float64(0.20936721471556718), 'macro_f1': np.float64(0.617779831763933), 'micro_f1': np.float64(0.7941243712848651), 'mae': np.float64(0.26445058561223667)}
{'loss': 0.17046396434307098, 'epoch': 4.0, 'step': 3232, 'lr': 6.652548447282524e-05, 'week': 4, 'validation_loss': np.float64(0.20936721471556718), 'macro_f1': np.flo

epoch 5/10:  50%|█████     | 404/808 [02:27<23:01,  3.42s/it]

{'loss': 0.28071296215057373, 'epoch': 4.5, 'step': 3636, 'lr': 6.235200727414045e-05, 'week': 1, 'validation_loss': np.float64(0.20913583882476972), 'macro_f1': np.float64(0.8135486795074095), 'micro_f1': np.float64(0.897119341563786), 'mae': np.float64(0.1343047316907915)}
{'loss': 0.28071296215057373, 'epoch': 4.5, 'step': 3636, 'lr': 6.235200727414045e-05, 'week': 2, 'validation_loss': np.float64(0.20913583882476972), 'macro_f1': np.float64(0.7362881763717551), 'micro_f1': np.float64(0.8399634202103338), 'mae': np.float64(0.19634164324250616)}
{'loss': 0.28071296215057373, 'epoch': 4.5, 'step': 3636, 'lr': 6.235200727414045e-05, 'week': 3, 'validation_loss': np.float64(0.20913583882476972), 'macro_f1': np.float64(0.6673860132629591), 'micro_f1': np.float64(0.7982395976223137), 'mae': np.float64(0.2514857023306337)}
{'loss': 0.28071296215057373, 'epoch': 4.5, 'step': 3636, 'lr': 6.235200727414045e-05, 'week': 4, 'validation_loss': np.float64(0.20913583882476972), 'macro_f1': np.floa

epoch 5/10: 100%|██████████| 808/808 [04:53<00:00,  2.75it/s]


{'loss': 0.3733613193035126, 'epoch': 5.0, 'step': 4040, 'lr': 5.680699323887897e-05, 'week': 1, 'validation_loss': np.float64(0.21664870940688727), 'macro_f1': np.float64(0.8168460573717534), 'micro_f1': np.float64(0.9003200731595793), 'mae': np.float64(0.13867737110831477)}
{'loss': 0.3733613193035126, 'epoch': 5.0, 'step': 4040, 'lr': 5.680699323887897e-05, 'week': 2, 'validation_loss': np.float64(0.21664870940688727), 'macro_f1': np.float64(0.720735042682071), 'micro_f1': np.float64(0.8393918609967993), 'mae': np.float64(0.20460070804762517)}
{'loss': 0.3733613193035126, 'epoch': 5.0, 'step': 4040, 'lr': 5.680699323887897e-05, 'week': 3, 'validation_loss': np.float64(0.21664870940688727), 'macro_f1': np.float64(0.659146491980208), 'micro_f1': np.float64(0.7956104252400549), 'mae': np.float64(0.2636768748557184)}
{'loss': 0.3733613193035126, 'epoch': 5.0, 'step': 4040, 'lr': 5.680699323887897e-05, 'week': 4, 'validation_loss': np.float64(0.21664870940688727), 'macro_f1': np.float64(

epoch 6/10:  50%|█████     | 404/808 [02:26<23:06,  3.43s/it]

{'loss': 0.22723953425884247, 'epoch': 5.5, 'step': 4444, 'lr': 5.0168492524730965e-05, 'week': 1, 'validation_loss': np.float64(0.21212584395771442), 'macro_f1': np.float64(0.8209283580973689), 'micro_f1': np.float64(0.9005486968449932), 'mae': np.float64(0.13764022572509405)}
{'loss': 0.22723953425884247, 'epoch': 5.5, 'step': 4444, 'lr': 5.0168492524730965e-05, 'week': 2, 'validation_loss': np.float64(0.21212584395771442), 'macro_f1': np.float64(0.7327254653421839), 'micro_f1': np.float64(0.8401920438957476), 'mae': np.float64(0.20198607876860714)}
{'loss': 0.22723953425884247, 'epoch': 5.5, 'step': 4444, 'lr': 5.0168492524730965e-05, 'week': 3, 'validation_loss': np.float64(0.21212584395771442), 'macro_f1': np.float64(0.6698703307900332), 'micro_f1': np.float64(0.7960676726108825), 'mae': np.float64(0.2594293405944281)}
{'loss': 0.22723953425884247, 'epoch': 5.5, 'step': 4444, 'lr': 5.0168492524730965e-05, 'week': 4, 'validation_loss': np.float64(0.21212584395771442), 'macro_f1': n

epoch 6/10: 100%|██████████| 808/808 [04:53<00:00,  2.75it/s]


{'loss': 0.241116002202034, 'epoch': 6.0, 'step': 4848, 'lr': 4.276938727746874e-05, 'week': 1, 'validation_loss': np.float64(0.21238080660502115), 'macro_f1': np.float64(0.8192286794250653), 'micro_f1': np.float64(0.8956332876085963), 'mae': np.float64(0.15106211559416877)}
{'loss': 0.241116002202034, 'epoch': 6.0, 'step': 4848, 'lr': 4.276938727746874e-05, 'week': 2, 'validation_loss': np.float64(0.21238080660502115), 'macro_f1': np.float64(0.7218114501385862), 'micro_f1': np.float64(0.8367626886145405), 'mae': np.float64(0.21324201691264955)}
{'loss': 0.241116002202034, 'epoch': 6.0, 'step': 4848, 'lr': 4.276938727746874e-05, 'week': 3, 'validation_loss': np.float64(0.21238080660502115), 'macro_f1': np.float64(0.6576817035675956), 'micro_f1': np.float64(0.7940100594421582), 'mae': np.float64(0.26489463212439573)}
{'loss': 0.241116002202034, 'epoch': 6.0, 'step': 4848, 'lr': 4.276938727746874e-05, 'week': 4, 'validation_loss': np.float64(0.21238080660502115), 'macro_f1': np.float64(0

epoch 7/10:  50%|█████     | 404/808 [02:26<22:55,  3.40s/it]

{'loss': 0.2058635652065277, 'epoch': 6.5, 'step': 5252, 'lr': 3.498069953016286e-05, 'week': 1, 'validation_loss': np.float64(0.21062526754710986), 'macro_f1': np.float64(0.8207398132663268), 'micro_f1': np.float64(0.9012345679012346), 'mae': np.float64(0.13006116838171786)}
{'loss': 0.2058635652065277, 'epoch': 6.5, 'step': 5252, 'lr': 3.498069953016286e-05, 'week': 2, 'validation_loss': np.float64(0.21062526754710986), 'macro_f1': np.float64(0.7218763120041135), 'micro_f1': np.float64(0.8406492912665752), 'mae': np.float64(0.2001784140687354)}
{'loss': 0.2058635652065277, 'epoch': 6.5, 'step': 5252, 'lr': 3.498069953016286e-05, 'week': 3, 'validation_loss': np.float64(0.21062526754710986), 'macro_f1': np.float64(0.6602241695125078), 'micro_f1': np.float64(0.7961819844535893), 'mae': np.float64(0.2558813735814422)}
{'loss': 0.2058635652065277, 'epoch': 6.5, 'step': 5252, 'lr': 3.498069953016286e-05, 'week': 4, 'validation_loss': np.float64(0.21062526754710986), 'macro_f1': np.float64

epoch 7/10: 100%|██████████| 808/808 [04:53<00:00,  2.75it/s]


{'loss': 0.34890875220298767, 'epoch': 7.0, 'step': 5656, 'lr': 2.7192986609190955e-05, 'week': 1, 'validation_loss': np.float64(0.20927275571486223), 'macro_f1': np.float64(0.8207608602019011), 'micro_f1': np.float64(0.9015775034293553), 'mae': np.float64(0.13284475238731885)}
{'loss': 0.34890875220298767, 'epoch': 7.0, 'step': 5656, 'lr': 2.7192986609190955e-05, 'week': 2, 'validation_loss': np.float64(0.20927275571486223), 'macro_f1': np.float64(0.7241111479564358), 'micro_f1': np.float64(0.8407636031092821), 'mae': np.float64(0.19834518557729142)}
{'loss': 0.34890875220298767, 'epoch': 7.0, 'step': 5656, 'lr': 2.7192986609190955e-05, 'week': 3, 'validation_loss': np.float64(0.20927275571486223), 'macro_f1': np.float64(0.6578813826150239), 'micro_f1': np.float64(0.7959533607681756), 'mae': np.float64(0.25585868423998737)}
{'loss': 0.34890875220298767, 'epoch': 7.0, 'step': 5656, 'lr': 2.7192986609190955e-05, 'week': 4, 'validation_loss': np.float64(0.20927275571486223), 'macro_f1': 

epoch 8/10:  50%|█████     | 404/808 [02:26<22:58,  3.41s/it]

{'loss': 0.22240698337554932, 'epoch': 7.5, 'step': 6060, 'lr': 1.9796756959067725e-05, 'week': 1, 'validation_loss': np.float64(0.20809780930479368), 'macro_f1': np.float64(0.8248580919619064), 'micro_f1': np.float64(0.9023776863283036), 'mae': np.float64(0.131243526444397)}
{'loss': 0.22240698337554932, 'epoch': 7.5, 'step': 6060, 'lr': 1.9796756959067725e-05, 'week': 2, 'validation_loss': np.float64(0.20809780930479368), 'macro_f1': np.float64(0.7250273294942948), 'micro_f1': np.float64(0.8403063557384545), 'mae': np.float64(0.2060660828328308)}
{'loss': 0.22240698337554932, 'epoch': 7.5, 'step': 6060, 'lr': 1.9796756959067725e-05, 'week': 3, 'validation_loss': np.float64(0.20809780930479368), 'macro_f1': np.float64(0.6600873721896237), 'micro_f1': np.float64(0.7951531778692272), 'mae': np.float64(0.2594929841328307)}
{'loss': 0.22240698337554932, 'epoch': 7.5, 'step': 6060, 'lr': 1.9796756959067725e-05, 'week': 4, 'validation_loss': np.float64(0.20809780930479368), 'macro_f1': np.f

epoch 8/10: 100%|██████████| 808/808 [04:53<00:00,  2.75it/s]


{'loss': 0.22848336398601532, 'epoch': 8.0, 'step': 6464, 'lr': 1.316288841841575e-05, 'week': 1, 'validation_loss': np.float64(0.20845238794235216), 'macro_f1': np.float64(0.8232155092142336), 'micro_f1': np.float64(0.901920438957476), 'mae': np.float64(0.13034316929794837)}
{'loss': 0.22848336398601532, 'epoch': 8.0, 'step': 6464, 'lr': 1.316288841841575e-05, 'week': 2, 'validation_loss': np.float64(0.20845238794235216), 'macro_f1': np.float64(0.7248106791282622), 'micro_f1': np.float64(0.8407636031092821), 'mae': np.float64(0.1975983227562314)}
{'loss': 0.22848336398601532, 'epoch': 8.0, 'step': 6464, 'lr': 1.316288841841575e-05, 'week': 3, 'validation_loss': np.float64(0.20845238794235216), 'macro_f1': np.float64(0.6581804557021772), 'micro_f1': np.float64(0.7959533607681756), 'mae': np.float64(0.25811041306974936)}
{'loss': 0.22848336398601532, 'epoch': 8.0, 'step': 6464, 'lr': 1.316288841841575e-05, 'week': 4, 'validation_loss': np.float64(0.20845238794235216), 'macro_f1': np.flo

epoch 9/10:  50%|█████     | 404/808 [02:26<22:54,  3.40s/it]

{'loss': 0.17318235337734222, 'epoch': 8.5, 'step': 6868, 'lr': 7.624030856485954e-06, 'week': 1, 'validation_loss': np.float64(0.21004490144010904), 'macro_f1': np.float64(0.8173206014510397), 'micro_f1': np.float64(0.9014631915866483), 'mae': np.float64(0.13065018267006817)}
{'loss': 0.17318235337734222, 'epoch': 8.5, 'step': 6868, 'lr': 7.624030856485954e-06, 'week': 2, 'validation_loss': np.float64(0.21004490144010904), 'macro_f1': np.float64(0.7311173295581028), 'micro_f1': np.float64(0.8413351623228167), 'mae': np.float64(0.19994804528761276)}
{'loss': 0.17318235337734222, 'epoch': 8.5, 'step': 6868, 'lr': 7.624030856485954e-06, 'week': 3, 'validation_loss': np.float64(0.21004490144010904), 'macro_f1': np.float64(0.660347767872389), 'micro_f1': np.float64(0.7948102423411065), 'mae': np.float64(0.26080482703905505)}
{'loss': 0.17318235337734222, 'epoch': 8.5, 'step': 6868, 'lr': 7.624030856485954e-06, 'week': 4, 'validation_loss': np.float64(0.21004490144010904), 'macro_f1': np.fl

epoch 9/10: 100%|██████████| 808/808 [04:53<00:00,  2.75it/s]


{'loss': 0.23849865794181824, 'epoch': 9.0, 'step': 7272, 'lr': 3.4579257196884897e-06, 'week': 1, 'validation_loss': np.float64(0.2083185490058816), 'macro_f1': np.float64(0.8191776634385995), 'micro_f1': np.float64(0.9032921810699589), 'mae': np.float64(0.13007749269767713)}
{'loss': 0.23849865794181824, 'epoch': 9.0, 'step': 7272, 'lr': 3.4579257196884897e-06, 'week': 2, 'validation_loss': np.float64(0.2083185490058816), 'macro_f1': np.float64(0.7273341930170633), 'micro_f1': np.float64(0.8405349794238683), 'mae': np.float64(0.1989411776995064)}
{'loss': 0.23849865794181824, 'epoch': 9.0, 'step': 7272, 'lr': 3.4579257196884897e-06, 'week': 3, 'validation_loss': np.float64(0.2083185490058816), 'macro_f1': np.float64(0.6629718413244697), 'micro_f1': np.float64(0.7958390489254686), 'mae': np.float64(0.25643559501065144)}
{'loss': 0.23849865794181824, 'epoch': 9.0, 'step': 7272, 'lr': 3.4579257196884897e-06, 'week': 4, 'validation_loss': np.float64(0.2083185490058816), 'macro_f1': np.fl

epoch 10/10:  50%|█████     | 404/808 [02:26<22:54,  3.40s/it]

{'loss': 0.27015119791030884, 'epoch': 9.5, 'step': 7676, 'lr': 8.734789157224429e-07, 'week': 1, 'validation_loss': np.float64(0.20880755166644635), 'macro_f1': np.float64(0.8252242120088881), 'micro_f1': np.float64(0.9027206218564243), 'mae': np.float64(0.12898026911622468)}
{'loss': 0.27015119791030884, 'epoch': 9.5, 'step': 7676, 'lr': 8.734789157224429e-07, 'week': 2, 'validation_loss': np.float64(0.20880755166644635), 'macro_f1': np.float64(0.7238188539615852), 'micro_f1': np.float64(0.840992226794696), 'mae': np.float64(0.19768315465381342)}
{'loss': 0.27015119791030884, 'epoch': 9.5, 'step': 7676, 'lr': 8.734789157224429e-07, 'week': 3, 'validation_loss': np.float64(0.20880755166644635), 'macro_f1': np.float64(0.6643762571319954), 'micro_f1': np.float64(0.7957247370827618), 'mae': np.float64(0.2559755084648536)}
{'loss': 0.27015119791030884, 'epoch': 9.5, 'step': 7676, 'lr': 8.734789157224429e-07, 'week': 4, 'validation_loss': np.float64(0.20880755166644635), 'macro_f1': np.flo

epoch 10/10: 100%|██████████| 808/808 [04:53<00:00,  2.75it/s]

{'loss': 0.21371784806251526, 'epoch': 10.0, 'step': 8080, 'lr': 2.853990486928992e-10, 'week': 1, 'validation_loss': np.float64(0.20844748243689537), 'macro_f1': np.float64(0.8204156835263084), 'micro_f1': np.float64(0.9023776863283036), 'mae': np.float64(0.12869707940342906)}
{'loss': 0.21371784806251526, 'epoch': 10.0, 'step': 8080, 'lr': 2.853990486928992e-10, 'week': 2, 'validation_loss': np.float64(0.20844748243689537), 'macro_f1': np.float64(0.7285784985633622), 'micro_f1': np.float64(0.8413351623228167), 'mae': np.float64(0.1975264953271462)}
{'loss': 0.21371784806251526, 'epoch': 10.0, 'step': 8080, 'lr': 2.853990486928992e-10, 'week': 3, 'validation_loss': np.float64(0.20844748243689537), 'macro_f1': np.float64(0.6652830890569245), 'micro_f1': np.float64(0.7967535436671239), 'mae': np.float64(0.2559223513724822)}
{'loss': 0.21371784806251526, 'epoch': 10.0, 'step': 8080, 'lr': 2.853990486928992e-10, 'week': 4, 'validation_loss': np.float64(0.20844748243689537), 'macro_f1': np




In [None]:
def predict(x, static=None):
    if static is None:
        out, _ = model(torch.tensor(x), val_h)
    else:
        out, _ = model(torch.tensor(x), val_h, static)
    return out

In [None]:
dict_map = {
    "y_pred": [],
    "y_pred_rounded": [],
    "fips": [],
    "date": [],
    "y_true": [],
    "week": [],
}
i = 0
for x, static, y in tqdm(
    valid_loader, # ou valid_loader
    desc="validation predictions...",
):
    val_h = tuple([each.data.to(device) for each in model.init_hidden(len(x))])
    x, static, y = x.to(device), static.to(device), y.to(device)
    with torch.no_grad():
        if use_static:
            pred = predict(x, static).clone().detach()
        else:
            pred = predict(x).clone().detach()
    for w in range(output_weeks):
        dict_map["y_pred"] += [float(p[w]) for p in pred]
        dict_map["y_pred_rounded"] += [int(p.round()[w]) for p in pred]
        dict_map["fips"] += [f[1][0] for f in valid_fips[i : i + len(x)]]
        dict_map["date"] += [f[1][1] for f in valid_fips[i : i + len(x)]]
        dict_map["y_true"] += [float(item[w]) for item in y]
        dict_map["week"] += [w] * len(x)
    i += len(x)
df = pd.DataFrame(dict_map)

  out, _ = model(torch.tensor(x), val_h, static)
validation predictions...: 100%|██████████| 69/69 [00:11<00:00,  6.22it/s]


In [None]:
for w in range(6):
    wdf = df[df['week']==w]
    mae = mean_absolute_error(wdf['y_true'], wdf['y_pred']).round(3)
    f1 = f1_score(wdf['y_true'].round(),wdf['y_pred'].round(), average='macro').round(3)
    print(f"Week {w+1}", f"MAE {mae}", f"F1 {f1}")

Week 1 MAE 0.133 F1 0.816
Week 2 MAE 0.201 F1 0.723
Week 3 MAE 0.258 F1 0.65
Week 4 MAE 0.312 F1 0.574
Week 5 MAE 0.36 F1 0.542
Week 6 MAE 0.403 F1 0.499
