# First contact with the dataset
This Notebook has as objective to replicate the results of Minixhofer et al. (2021)

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
from scipy.interpolate import interp1d
from sklearn.preprocessing import RobustScaler


## Importing the data in a unique dictionary

In [2]:
filesList = os.listdir('../src')
print(filesList)

['soil_data.csv', 'train_timeseries', 'test_timeseries', 'validation_timeseries']


In [3]:
dataDic = {"train": pd.read_csv("../src/train_timeseries/train_timeseries.csv"),
           "test": pd.read_csv("../src/test_timeseries/test_timeseries.csv"),
           "validation": pd.read_csv("../src/validation_timeseries/validation_timeseries.csv"),
           "soil" : pd.read_csv("../src/soil_data.csv"),
           }


In [4]:
dataDic["train"].columns

Index(['fips', 'date', 'PRECTOT', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET',
       'T2M_MAX', 'T2M_MIN', 'T2M_RANGE', 'TS', 'WS10M', 'WS10M_MAX',
       'WS10M_MIN', 'WS10M_RANGE', 'WS50M', 'WS50M_MAX', 'WS50M_MIN',
       'WS50M_RANGE', 'score'],
      dtype='object')

In [19]:
class2id = {
    'None': 0,
    'D0': 1,
    'D1': 2,
    'D2': 3,
    'D3': 4,
    'D4': 5,
}
id2class = {v: k for k, v in class2id.items()}

In [5]:
dfs = {
    k: dataDic[k].set_index(['fips', 'date'])
    for k in dataDic.keys() if k != "soil"
}

dfs["soil"] = dataDic["soil"]

In [6]:
dfs["train"]

Unnamed: 0_level_0,Unnamed: 1_level_0,PRECTOT,PS,QV2M,T2M,T2MDEW,T2MWET,T2M_MAX,T2M_MIN,T2M_RANGE,TS,WS10M,WS10M_MAX,WS10M_MIN,WS10M_RANGE,WS50M,WS50M_MAX,WS50M_MIN,WS50M_RANGE,score
fips,date,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
1001,2000-01-01,0.22,100.51,9.65,14.74,13.51,13.51,20.96,11.46,9.50,14.65,2.20,2.94,1.49,1.46,4.85,6.04,3.23,2.81,
1001,2000-01-02,0.20,100.55,10.42,16.69,14.71,14.71,22.80,12.61,10.18,16.60,2.52,3.43,1.83,1.60,5.33,6.13,3.72,2.41,
1001,2000-01-03,3.65,100.15,11.76,18.49,16.52,16.52,22.73,15.32,7.41,18.41,4.03,5.33,2.66,2.67,7.53,9.52,5.87,3.66,
1001,2000-01-04,15.95,100.29,6.42,11.40,6.09,6.10,18.09,2.16,15.92,11.31,3.84,5.67,2.08,3.59,6.73,9.31,3.74,5.58,1.0
1001,2000-01-05,0.00,101.15,2.95,3.86,-3.29,-3.20,10.82,-2.66,13.48,2.65,1.60,2.50,0.52,1.98,2.94,4.85,0.65,4.19,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
56043,2016-12-27,0.16,82.88,1.63,-7.97,-13.49,-12.81,-1.39,-13.60,12.21,-9.41,5.90,7.63,3.61,4.02,8.58,10.39,5.92,4.47,0.0
56043,2016-12-28,0.02,83.33,1.41,-8.71,-14.10,-13.84,-2.49,-13.56,11.07,-10.55,6.50,11.43,4.11,7.32,9.92,14.49,7.26,7.22,
56043,2016-12-29,0.00,83.75,1.59,-7.96,-13.30,-13.03,0.42,-14.51,14.93,-10.29,4.29,6.24,2.03,4.22,6.56,10.07,3.20,6.87,
56043,2016-12-30,1.22,82.49,2.63,-2.94,-7.40,-7.33,3.76,-6.86,10.62,-4.14,4.98,7.34,1.99,5.35,7.28,10.12,3.24,6.89,


## Interpolation pour les données manquantes

In [7]:
def interpolate_nans(padata, pkind='linear'):
    """
    see: https://stackoverflow.com/a/53050216/2167159
    """
    aindexes = np.arange(padata.shape[0])
    agood_indexes, = np.where(np.isfinite(padata))
    f = interp1d(agood_indexes
               , padata[agood_indexes]
               , bounds_error=False
               , copy=False
               , fill_value="extrapolate"
               , kind=pkind)
    return f(aindexes)

## Function to encode the cycling feature: year-day, using sin/cos

In [8]:
def date_encode(date):
    if isinstance(date, str):
        date = datetime.strptime(date, "%Y-%m-%d")
    return (
        np.sin(2 * np.pi * date.timetuple().tm_yday / 366),
        np.cos(2 * np.pi * date.timetuple().tm_yday / 366),
    )

## Function to load the data

In [9]:
def loadXY(
    df,
    random_state=42, # keep this at 42
    window_size=180, # how many days in the past (default/competition: 180)
    target_size=6, # how many weeks into the future (default/competition: 6)
    fuse_past=True, # add the past drought observations? (default: True)
    return_fips=False, # return the county identifier (do not use for predictions)
    encode_season=True, # encode the season using the function above (default: True) 
    use_prev_year=False, # add observations from 1 year prior?
):
    df = dfs[df]
    soil_df = dfs["soil"]
    time_data_cols = sorted(
        [c for c in df.columns if c not in ["fips", "date", "score"]]
    )
    static_data_cols = sorted(
        [c for c in soil_df.columns if c not in ["soil", "lat", "lon"]]
    )
    count = 0
    score_df = df.dropna(subset=["score"])
    X_static = np.empty((len(df) // window_size, len(static_data_cols)))
    X_fips_date = []
    add_dim = 0
    if use_prev_year:
        add_dim += len(time_data_cols)
    if fuse_past:
        add_dim += 1
        if use_prev_year:
            add_dim += 1
    if encode_season:
        add_dim += 2
    X_time = np.empty(
        (len(df) // window_size, window_size, len(time_data_cols) + add_dim)
    )
    y_past = np.empty((len(df) // window_size, window_size))
    y_target = np.empty((len(df) // window_size, target_size))
    if random_state is not None:
        np.random.seed(random_state)
    for fips in tqdm(score_df.index.get_level_values(0).unique()):
        if random_state is not None:
            start_i = np.random.randint(1, window_size)
        else:
            start_i = 1
        fips_df = df[(df.index.get_level_values(0) == fips)]
        X = fips_df[time_data_cols].values
        y = fips_df["score"].values
        X_s = soil_df[soil_df["fips"] == fips][static_data_cols].values[0]
        for i in range(start_i, len(y) - (window_size + target_size * 7), window_size):
            X_fips_date.append((fips, fips_df.index[i : i + window_size][-1]))
            X_time[count, :, : len(time_data_cols)] = X[i : i + window_size]
            if use_prev_year:
                if i < 365 or len(X[i - 365 : i + window_size - 365]) < window_size:
                    continue
                X_time[count, :, -len(time_data_cols) :] = X[
                    i - 365 : i + window_size - 365
                ]
            if not fuse_past:
                y_past[count] = interpolate_nans(y[i : i + window_size])
            else:
                X_time[count, :, len(time_data_cols)] = interpolate_nans(
                    y[i : i + window_size]
                )
            if encode_season:
                enc_dates = [
                    date_encode(d) for f, d in fips_df.index[i : i + window_size].values
                ]
                d_sin, d_cos = [s for s, c in enc_dates], [c for s, c in enc_dates]
                X_time[count, :, len(time_data_cols) + (add_dim - 2)] = d_sin
                X_time[count, :, len(time_data_cols) + (add_dim - 2) + 1] = d_cos
            temp_y = y[i + window_size : i + window_size + target_size * 7]
            y_target[count] = np.array(temp_y[~np.isnan(temp_y)][:target_size])
            X_static[count] = X_s
            count += 1
    print(f"loaded {count} samples")
    results = [X_static[:count], X_time[:count], y_target[:count]]
    if not fuse_past:
        results.append(y_past[:count])
    if return_fips:
        results.append(X_fips_date)
    return results

In [10]:
scaler_dict = {}
scaler_dict_static = {}
scaler_dict_past = {}


def normalize(X_static, X_time, y_past=None, fit=False):
    for index in tqdm(range(X_time.shape[-1])):
        if fit:
            scaler_dict[index] = RobustScaler().fit(X_time[:, :, index].reshape(-1, 1))
        X_time[:, :, index] = (
            scaler_dict[index]
            .transform(X_time[:, :, index].reshape(-1, 1))
            .reshape(-1, X_time.shape[-2])
        )
    for index in tqdm(range(X_static.shape[-1])):
        if fit:
            scaler_dict_static[index] = RobustScaler().fit(
                X_static[:, index].reshape(-1, 1)
            )
        X_static[:, index] = (
            scaler_dict_static[index]
            .transform(X_static[:, index].reshape(-1, 1))
            .reshape(1, -1)
        )
    index = 0
    if y_past is not None:
        if fit:
            scaler_dict_past[index] = RobustScaler().fit(y_past.reshape(-1, 1))
        y_past[:, :] = (
            scaler_dict_past[index]
            .transform(y_past.reshape(-1, 1))
            .reshape(-1, y_past.shape[-1])
        )
        return X_static, X_time, y_past
    return X_static, X_time

In [11]:
X_tabular_train, X_time_train, y_target_train = loadXY("train")
print("train shape", X_time_train.shape)
X_tabular_validation, X_time_valid, y_target_valid, valid_fips = loadXY("validation", return_fips=True)
print("validation shape", X_time_valid.shape)
X_tabular_train, X_time_train = normalize(X_tabular_train, X_time_train, fit=True)
X_tabular_validation, X_time_valid = normalize(X_tabular_validation, X_time_valid)

100%|██████████| 3108/3108 [04:25<00:00, 11.68it/s]


loaded 103390 samples
train shape (103390, 180, 21)


100%|██████████| 3108/3108 [00:25<00:00, 120.05it/s]


loaded 8748 samples
validation shape (8748, 180, 21)


100%|██████████| 21/21 [00:16<00:00,  1.27it/s]
100%|██████████| 30/30 [00:00<00:00, 466.52it/s]
100%|██████████| 21/21 [00:00<00:00, 42.65it/s]
100%|██████████| 30/30 [00:00<00:00, 13648.89it/s]


In [12]:
batch_size = 128
output_weeks = 6
use_static = True
hidden_dim = 512
n_layers = 2
ffnn_layers = 2
dropout = 0.1
one_cycle = True
lr = 7e-5
epochs = 10
clip = 5

In [15]:
import torch
from torch.utils.data import TensorDataset, DataLoader

train_data = TensorDataset(
    torch.tensor(X_time_train),
    torch.tensor(X_tabular_train),
    torch.tensor(y_target_train[:, :output_weeks]),
)
train_loader = DataLoader(
    train_data, shuffle=True, batch_size=batch_size, drop_last=False
)
valid_data = TensorDataset(
    torch.tensor(X_time_valid),
    torch.tensor(X_tabular_validation),
    torch.tensor(y_target_valid[:, :output_weeks]),
)
valid_loader = DataLoader(
    valid_data, shuffle=False, batch_size=batch_size, drop_last=False
)

In [16]:
import torch
from torch import nn
from sklearn.metrics import f1_score, mean_absolute_error

class DroughtNetLSTM(nn.Module):
    def __init__(
        self,
        output_size,
        num_input_features,
        hidden_dim,
        n_layers,
        ffnn_layers,
        drop_prob,
        static_dim=0,
    ):
        super(DroughtNetLSTM, self).__init__()
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        self.lstm = nn.LSTM(
            num_input_features,
            hidden_dim,
            n_layers,
            dropout=drop_prob,
            batch_first=True,
        )
        self.dropout = nn.Dropout(drop_prob)
        self.fflayers = []
        for i in range(ffnn_layers - 1):
            if i == 0:
                self.fflayers.append(nn.Linear(hidden_dim + static_dim, hidden_dim))
            else:
                self.fflayers.append(nn.Linear(hidden_dim, hidden_dim))
        self.fflayers = nn.ModuleList(self.fflayers)
        self.final = nn.Linear(hidden_dim, output_size)

    def forward(self, x, hidden, static=None):
        batch_size = x.size(0)
        x = x.to(dtype=torch.float32)
        if static is not None:
            static = static.to(dtype=torch.float32)
        lstm_out, hidden = self.lstm(x, hidden)
        lstm_out = lstm_out[:, -1, :]

        out = self.dropout(lstm_out)
        for i in range(len(self.fflayers)):
            if i == 0 and static is not None:
                out = self.fflayers[i](torch.cat((out, static), 1))
            else:
                out = self.fflayers[i](out)
        out = self.final(out)

        out = out.view(batch_size, -1)
        return out, hidden

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = (
            weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device),
            weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device),
        )
        return hidden

In [20]:
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device("cuda")
    print("using GPU")
else:
    device = torch.device("cpu")
    print("using CPU")
static_dim = 0
if use_static:
    static_dim = X_tabular_train.shape[-1]
model = DroughtNetLSTM(
    output_weeks,
    X_time_train.shape[-1],
    hidden_dim,
    n_layers,
    ffnn_layers,
    dropout,
    static_dim,
)
model.to(device)
loss_function = nn.MSELoss()
if one_cycle:
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=lr, steps_per_epoch=len(train_loader), epochs=epochs
    )
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
counter = 0
valid_loss_min = np.inf
torch.manual_seed(42)
np.random.seed(42)
for i in range(epochs):
    h = model.init_hidden(batch_size)

    for k, (inputs, static, labels) in tqdm(
        enumerate(train_loader),
        desc=f"epoch {i+1}/{epochs}",
        total=len(train_loader),
    ):
        model.train()
        counter += 1
        if len(inputs) < batch_size:
            h = model.init_hidden(len(inputs))
        h = tuple([e.data for e in h])
        inputs, labels, static = (
            inputs.to(device),
            labels.to(device),
            static.to(device),
        )
        model.zero_grad()
        if use_static:
            output, h = model(inputs, h, static)
        else:
            output, h = model(inputs, h)
        loss = loss_function(output, labels.float())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        if one_cycle:
            scheduler.step()

        with torch.no_grad():
            if k == len(train_loader) - 1 or k == (len(train_loader) - 1) // 2:
                val_h = model.init_hidden(batch_size)
                val_losses = []
                model.eval()
                labels = []
                preds = []
                raw_labels = []
                raw_preds = []
                for inp, stat, lab in valid_loader:
                    if len(inp) < batch_size:
                        val_h = model.init_hidden(len(inp))
                    val_h = tuple([each.data for each in val_h])
                    inp, lab, stat = inp.to(device), lab.to(device), stat.to(device)
                    if use_static:
                        out, val_h = model(inp, val_h, stat)
                    else:
                        out, val_h = model(inp, val_h)
                    val_loss = loss_function(out, lab.float())
                    val_losses.append(val_loss.item())
                    for labs in lab:
                        labels.append([int(l.round()) for l in labs])
                        raw_labels.append([float(l) for l in labs])
                    for pred in out:
                        preds.append([int(p.round()) for p in pred])
                        raw_preds.append([float(p) for p in pred])
                # log data
                labels = np.array(labels)
                preds = np.clip(np.array(preds), 0, 5)
                raw_preds = np.array(raw_preds)
                raw_labels = np.array(raw_labels)
                for i in range(output_weeks):
                    log_dict = {
                        "loss": float(loss),
                        "epoch": counter / len(train_loader),
                        "step": counter,
                        "lr": optimizer.param_groups[0]["lr"],
                        "week": i + 1,
                    }
                    # w = f'week_{i+1}_'
                    w = ""
                    log_dict[f"{w}validation_loss"] = np.mean(val_losses)
                    log_dict[f"{w}macro_f1"] = f1_score(
                        labels[:, i], preds[:, i], average="macro"
                    )
                    log_dict[f"{w}micro_f1"] = f1_score(
                        labels[:, i], preds[:, i], average="micro"
                    )
                    log_dict[f"{w}mae"] = mean_absolute_error(
                        raw_labels[:, i], raw_preds[:, i]
                    )
                    print(log_dict)
                    for j, f1 in enumerate(
                        f1_score(labels[:, i], preds[:, i], average=None)
                    ):
                        log_dict[f"{w}{id2class[j]}_f1"] = f1
                    model.train()
                if np.mean(val_losses) <= valid_loss_min:
                    torch.save(model.state_dict(), "./state_dict.pt")
                    print(
                        "Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...".format(
                            valid_loss_min, np.mean(val_losses)
                        )
                    )
                    valid_loss_min = np.mean(val_losses)

using GPU


epoch 1/10:  50%|█████     | 404/808 [02:27<26:34,  3.95s/it]

{'loss': 0.4094846844673157, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 1, 'validation_loss': np.float64(0.41741873647855676), 'macro_f1': np.float64(0.49162178476306456), 'micro_f1': np.float64(0.7056470050297211), 'mae': np.float64(0.4319537513668917)}
{'loss': 0.4094846844673157, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 2, 'validation_loss': np.float64(0.41741873647855676), 'macro_f1': np.float64(0.42621899660617557), 'micro_f1': np.float64(0.6494055784179241), 'mae': np.float64(0.485662974291411)}
{'loss': 0.4094846844673157, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 3, 'validation_loss': np.float64(0.41741873647855676), 'macro_f1': np.float64(0.3832968291751599), 'micro_f1': np.float64(0.6251714677640604), 'mae': np.float64(0.5237166664131289)}
{'loss': 0.4094846844673157, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 4, 'validation_loss': np.float64(0.41741873647855676), 'macro_f1': np.float64(0.3

epoch 1/10: 100%|██████████| 808/808 [04:56<00:00,  2.73it/s]


{'loss': 0.22092455625534058, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 1, 'validation_loss': np.float64(0.24655242415441983), 'macro_f1': np.float64(0.7675159300585278), 'micro_f1': np.float64(0.8699131229995427), 'mae': np.float64(0.21856773082608355)}
{'loss': 0.22092455625534058, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 2, 'validation_loss': np.float64(0.24655242415441983), 'macro_f1': np.float64(0.6794433732123664), 'micro_f1': np.float64(0.8182441700960219), 'mae': np.float64(0.2667220282446425)}
{'loss': 0.22092455625534058, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 3, 'validation_loss': np.float64(0.24655242415441983), 'macro_f1': np.float64(0.6269059077376175), 'micro_f1': np.float64(0.7801783264746228), 'mae': np.float64(0.3134492456122819)}
{'loss': 0.22092455625534058, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 4, 'validation_loss': np.float64(0.24655242415441983), 'macro_f1': np.flo

epoch 2/10:  50%|█████     | 404/808 [02:28<26:48,  3.98s/it]

{'loss': 0.3058808743953705, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 1, 'validation_loss': np.float64(0.23733237315563188), 'macro_f1': np.float64(0.780181140314852), 'micro_f1': np.float64(0.8810013717421125), 'mae': np.float64(0.18999120643393813)}
{'loss': 0.3058808743953705, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 2, 'validation_loss': np.float64(0.23733237315563188), 'macro_f1': np.float64(0.7061587730609968), 'micro_f1': np.float64(0.8278463648834019), 'mae': np.float64(0.24193130400416185)}
{'loss': 0.3058808743953705, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 3, 'validation_loss': np.float64(0.23733237315563188), 'macro_f1': np.float64(0.6461879449883297), 'micro_f1': np.float64(0.7866941015089163), 'mae': np.float64(0.2951739036022269)}
{'loss': 0.3058808743953705, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 4, 'validation_loss': np.float64(0.23733237315563188), 'macro_f1': np.flo

epoch 2/10: 100%|██████████| 808/808 [04:56<00:00,  2.72it/s]


{'loss': 0.24517464637756348, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 1, 'validation_loss': np.float64(0.22640269865160403), 'macro_f1': np.float64(0.6430461783243833), 'micro_f1': np.float64(0.8776863283036123), 'mae': np.float64(0.17175334666851186)}
{'loss': 0.24517464637756348, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 2, 'validation_loss': np.float64(0.22640269865160403), 'macro_f1': np.float64(0.56380394980159), 'micro_f1': np.float64(0.8245313214449017), 'mae': np.float64(0.23194525412324735)}
{'loss': 0.24517464637756348, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 3, 'validation_loss': np.float64(0.22640269865160403), 'macro_f1': np.float64(0.47319659026853006), 'micro_f1': np.float64(0.7788065843621399), 'mae': np.float64(0.2881543252586252)}
{'loss': 0.24517464637756348, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 4, 'validation_loss': np.float64(0.22640269865160403), 'macro_f1': np.flo

epoch 3/10:  50%|█████     | 404/808 [02:28<26:32,  3.94s/it]

{'loss': 0.24621926248073578, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 1, 'validation_loss': np.float64(0.2279852118505084), 'macro_f1': np.float64(0.7076798848220333), 'micro_f1': np.float64(0.8852309099222679), 'mae': np.float64(0.17500406669844157)}
{'loss': 0.24621926248073578, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 2, 'validation_loss': np.float64(0.2279852118505084), 'macro_f1': np.float64(0.6309731875793292), 'micro_f1': np.float64(0.8270461819844536), 'mae': np.float64(0.23403148351083772)}
{'loss': 0.24621926248073578, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 3, 'validation_loss': np.float64(0.2279852118505084), 'macro_f1': np.float64(0.5986587970334974), 'micro_f1': np.float64(0.782350251486054), 'mae': np.float64(0.2949911428322209)}
{'loss': 0.24621926248073578, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 4, 'validation_loss': np.float64(0.2279852118505084), 'macro_f1': np.float64

epoch 3/10: 100%|██████████| 808/808 [04:56<00:00,  2.72it/s]


{'loss': 0.25213363766670227, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 1, 'validation_loss': np.float64(0.22551781765144804), 'macro_f1': np.float64(0.7811816570604232), 'micro_f1': np.float64(0.8830589849108368), 'mae': np.float64(0.16936440724378404)}
{'loss': 0.25213363766670227, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 2, 'validation_loss': np.float64(0.22551781765144804), 'macro_f1': np.float64(0.7099954484769864), 'micro_f1': np.float64(0.8285322359396433), 'mae': np.float64(0.23013675750972004)}
{'loss': 0.25213363766670227, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 3, 'validation_loss': np.float64(0.22551781765144804), 'macro_f1': np.float64(0.6647671499070297), 'micro_f1': np.float64(0.7874942844078646), 'mae': np.float64(0.28225138873771266)}
{'loss': 0.25213363766670227, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 4, 'validation_loss': np.float64(0.22551781765144804), 'macro_f1': np.float

epoch 4/10:  48%|████▊     | 387/808 [02:10<02:22,  2.96it/s]


KeyboardInterrupt: 

In [None]:
def predict(x, static=None):
    if static is None:
        out, _ = model(torch.tensor(x), val_h)
    else:
        out, _ = model(torch.tensor(x), val_h, static)
    return out

In [None]:
dict_map = {
    "y_pred": [],
    "y_pred_rounded": [],
    "fips": [],
    "date": [],
    "y_true": [],
    "week": [],
}
i = 0
for x, static, y in tqdm(
    valid_loader,
    desc="validation predictions...",
):
    val_h = tuple([each.data.cpu() for each in model.init_hidden(len(x))])
    with torch.no_grad():
        if use_static:
            pred = predict(x, static).clone().detach()
        else:
            pred = predict(x).clone().detach()
    for w in range(output_weeks):
        dict_map["y_pred"] += [float(p[w]) for p in pred]
        dict_map["y_pred_rounded"] += [int(p.round()[w]) for p in pred]
        dict_map["fips"] += [f[1][0] for f in valid_fips[i : i + len(x)]]
        dict_map["date"] += [f[1][1] for f in valid_fips[i : i + len(x)]]
        dict_map["y_true"] += [float(item[w]) for item in y]
        dict_map["week"] += [w] * len(x)
    i += len(x)
df = pd.DataFrame(dict_map)

In [None]:
for w in range(6):
    wdf = df[df['week']==w]
    mae = mean_absolute_error(wdf['y_true'], wdf['y_pred']).round(3)
    f1 = f1_score(wdf['y_true'].round(),wdf['y_pred'].round(), average='macro').round(3)
    print(f"Week {w+1}", f"MAE {mae}", f"F1 {f1}")