# First contact with the dataset
This Notebook has as objective to replicate the results of Minixhofer et al. (2021)

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
from scipy.interpolate import interp1d
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import f1_score, mean_absolute_error

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter



**Note :** Just to have a normalized methodology to save and visualize the results of all the experiments trought this projet I add and configure a Tensorboard-SummaryWriter. In the same way I have changed the training cycle to put the results in the tensorboard format.

## Importing the data in a unique dictionary

In [2]:
filesList = os.listdir('../src')
print(filesList)

['soil_data.csv', 'train_timeseries', 'counties.geojson', 'test_timeseries', 'validation_timeseries', 'counties.zip']


In [3]:
dataDic = {"train": pd.read_csv("../src/train_timeseries/train_timeseries.csv"),
           "test": pd.read_csv("../src/test_timeseries/test_timeseries.csv"),
           "validation": pd.read_csv("../src/validation_timeseries/validation_timeseries.csv"),
           "soil" : pd.read_csv("../src/soil_data.csv"),
           }


In [4]:
dataDic["train"].columns

Index(['fips', 'date', 'PRECTOT', 'PS', 'QV2M', 'T2M', 'T2MDEW', 'T2MWET',
       'T2M_MAX', 'T2M_MIN', 'T2M_RANGE', 'TS', 'WS10M', 'WS10M_MAX',
       'WS10M_MIN', 'WS10M_RANGE', 'WS50M', 'WS50M_MAX', 'WS50M_MIN',
       'WS50M_RANGE', 'score'],
      dtype='object')

In [5]:
class2id = {
    'None': 0,
    'D0': 1,
    'D1': 2,
    'D2': 3,
    'D3': 4,
    'D4': 5,
}
id2class = {v: k for k, v in class2id.items()}

In [6]:
dfs = {
    k: dataDic[k].set_index(['fips', 'date'])
    for k in dataDic.keys() if k != "soil"
}

dfs["soil"] = dataDic["soil"]

In [7]:
dfs["train"]

Unnamed: 0_level_0,Unnamed: 1_level_0,PRECTOT,PS,QV2M,T2M,T2MDEW,T2MWET,T2M_MAX,T2M_MIN,T2M_RANGE,TS,WS10M,WS10M_MAX,WS10M_MIN,WS10M_RANGE,WS50M,WS50M_MAX,WS50M_MIN,WS50M_RANGE,score
fips,date,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
1001,2000-01-01,0.22,100.51,9.65,14.74,13.51,13.51,20.96,11.46,9.50,14.65,2.20,2.94,1.49,1.46,4.85,6.04,3.23,2.81,
1001,2000-01-02,0.20,100.55,10.42,16.69,14.71,14.71,22.80,12.61,10.18,16.60,2.52,3.43,1.83,1.60,5.33,6.13,3.72,2.41,
1001,2000-01-03,3.65,100.15,11.76,18.49,16.52,16.52,22.73,15.32,7.41,18.41,4.03,5.33,2.66,2.67,7.53,9.52,5.87,3.66,
1001,2000-01-04,15.95,100.29,6.42,11.40,6.09,6.10,18.09,2.16,15.92,11.31,3.84,5.67,2.08,3.59,6.73,9.31,3.74,5.58,1.0
1001,2000-01-05,0.00,101.15,2.95,3.86,-3.29,-3.20,10.82,-2.66,13.48,2.65,1.60,2.50,0.52,1.98,2.94,4.85,0.65,4.19,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
56043,2016-12-27,0.16,82.88,1.63,-7.97,-13.49,-12.81,-1.39,-13.60,12.21,-9.41,5.90,7.63,3.61,4.02,8.58,10.39,5.92,4.47,0.0
56043,2016-12-28,0.02,83.33,1.41,-8.71,-14.10,-13.84,-2.49,-13.56,11.07,-10.55,6.50,11.43,4.11,7.32,9.92,14.49,7.26,7.22,
56043,2016-12-29,0.00,83.75,1.59,-7.96,-13.30,-13.03,0.42,-14.51,14.93,-10.29,4.29,6.24,2.03,4.22,6.56,10.07,3.20,6.87,
56043,2016-12-30,1.22,82.49,2.63,-2.94,-7.40,-7.33,3.76,-6.86,10.62,-4.14,4.98,7.34,1.99,5.35,7.28,10.12,3.24,6.89,


## Interpolation pour les données manquantes

In [8]:
def interpolate_nans(padata, pkind='linear'):
    """
    see: https://stackoverflow.com/a/53050216/2167159
    """
    aindexes = np.arange(padata.shape[0])
    agood_indexes, = np.where(np.isfinite(padata))
    f = interp1d(agood_indexes
               , padata[agood_indexes]
               , bounds_error=False
               , copy=False
               , fill_value="extrapolate"
               , kind=pkind)
    return f(aindexes)

## Function to encode the cycling feature: year-day, using sin/cos

In [9]:
def date_encode(date):
    if isinstance(date, str):
        date = datetime.strptime(date, "%Y-%m-%d")
    return (
        np.sin(2 * np.pi * date.timetuple().tm_yday / 366),
        np.cos(2 * np.pi * date.timetuple().tm_yday / 366),
    )

**Add** A dictionary to encode the categorical soil variables

## Function to load the data (Modified from Minixhofer)

In [10]:
def loadXY(
    df,
    random_state=42,
    window_size=180, # how many days in the past (default/competition: 180)
    target_size=6, # how many weeks into the future (default/competition: 6)
    fuse_past=True, # add the past drought observations? (default: True)
    return_fips=False, # return the county identifier (do not use for predictions)
    encode_season=True, # encode the season using the function above (default: True) 
    use_prev_year=False, # add observations from 1 year prior?
):
    dico_trad = {}
    df = dfs[df]
    soil_df = dfs["soil"]
    for cat in ["SQ1", "SQ2", "SQ3", "SQ4", "SQ5", "SQ6", "SQ7"]:
        dico_trad[cat] = {j: i for i,j in enumerate(sorted(soil_df[cat].unique()))}
        soil_df[cat] = soil_df[cat].map(dico_trad[cat])
    time_data_cols = sorted(
        [c for c in df.columns if c not in ["fips", "date", "score"]]
    )
    static_data_cols = sorted(
        [c for c in soil_df.columns if c not in ["fips", "lat", "lon",
                                                 "SQ1", "SQ2", "SQ3",
                                                 "SQ4", "SQ5", "SQ6",
                                                 "SQ7"]]
    )
    static_cat_cols = sorted(
        [c for c in soil_df.columns if c in ["SQ1", "SQ2", "SQ3",
                                             "SQ4", "SQ5", "SQ6",
                                             "SQ7"]]
    )

    count = 0
    score_df = df.dropna(subset=["score"])
    X_static = np.empty((len(df) // window_size, len(static_data_cols)))
    X_static_cat = np.empty((len(df) // window_size, len(static_cat_cols)))
    X_fips_date = []
    add_dim = 0
    if use_prev_year:
        add_dim += len(time_data_cols)
    if fuse_past:
        add_dim += 1
        if use_prev_year:
            add_dim += 1
    if encode_season:
        add_dim += 2
    X_time = np.empty(
        (len(df) // window_size, window_size, len(time_data_cols) + add_dim)
    )
    y_past = np.empty((len(df) // window_size, window_size))
    y_target = np.empty((len(df) // window_size, target_size))
    if random_state is not None:
        np.random.seed(random_state)
    for fips in tqdm(score_df.index.get_level_values(0).unique()):
        if random_state is not None:
            start_i = np.random.randint(1, window_size)
        else:
            start_i = 1
        fips_df = df[(df.index.get_level_values(0) == fips)]
        X = fips_df[time_data_cols].values
        y = fips_df["score"].values
        X_s = soil_df[soil_df["fips"] == fips][static_data_cols].values[0]
        X_s_cat = soil_df[soil_df["fips"] == fips][static_cat_cols].values[0]
        for i in range(start_i, len(y) - (window_size + target_size * 7), window_size):
            X_fips_date.append((fips, fips_df.index[i : i + window_size][-1]))
            X_time[count, :, : len(time_data_cols)] = X[i : i + window_size]
            if use_prev_year:
                if i < 365 or len(X[i - 365 : i + window_size - 365]) < window_size:
                    continue
                X_time[count, :, -len(time_data_cols) :] = X[
                    i - 365 : i + window_size - 365
                ]
            if not fuse_past:
                y_past[count] = interpolate_nans(y[i : i + window_size])
            else:
                X_time[count, :, len(time_data_cols)] = interpolate_nans(
                    y[i : i + window_size]
                )
            if encode_season:
                enc_dates = [
                    date_encode(d) for f, d in fips_df.index[i : i + window_size].values
                ]
                d_sin, d_cos = [s for s, c in enc_dates], [c for s, c in enc_dates]
                X_time[count, :, len(time_data_cols) + (add_dim - 2)] = d_sin
                X_time[count, :, len(time_data_cols) + (add_dim - 2) + 1] = d_cos
            temp_y = y[i + window_size : i + window_size + target_size * 7]
            y_target[count] = np.array(temp_y[~np.isnan(temp_y)][:target_size])
            X_static[count] = X_s
            X_static_cat[count] = X_s_cat
            count += 1
    print(f"loaded {count} samples")
    results = [X_static[:count], X_static_cat[:count], X_time[:count], y_target[:count]]
    if not fuse_past:
        results.append(y_past[:count])
    if return_fips:
        results.append(X_fips_date)
    return results

In [11]:
scaler_dict = {}
scaler_dict_static = {}
scaler_dict_past = {}


def normalize(X_static, X_time, y_past=None, fit=False):
    for index in tqdm(range(X_time.shape[-1])):
        if fit:
            scaler_dict[index] = RobustScaler().fit(X_time[:, :, index].reshape(-1, 1))
        X_time[:, :, index] = (
            scaler_dict[index]
            .transform(X_time[:, :, index].reshape(-1, 1))
            .reshape(-1, X_time.shape[-2])
        )
    for index in tqdm(range(X_static.shape[-1])):
        if fit:
            scaler_dict_static[index] = RobustScaler().fit(
                X_static[:, index].reshape(-1, 1)
            )
        X_static[:, index] = (
            scaler_dict_static[index]
            .transform(X_static[:, index].reshape(-1, 1))
            .reshape(1, -1)
        )
    index = 0
    if y_past is not None:
        if fit:
            scaler_dict_past[index] = RobustScaler().fit(y_past.reshape(-1, 1))
        y_past[:, :] = (
            scaler_dict_past[index]
            .transform(y_past.reshape(-1, 1))
            .reshape(-1, y_past.shape[-1])
        )
        return X_static, X_time, y_past
    return X_static, X_time

In [12]:
X_tabular_train, X_tabular_cat_train, X_time_train, y_target_train = loadXY("train")
print("train shape", X_time_train.shape)
X_tabular_validation, X_tabular_cat_validation, X_time_valid, y_target_valid, valid_fips = loadXY("validation", return_fips=True)
print("validation shape", X_time_valid.shape)
X_tabular_train, X_time_train = normalize(X_tabular_train, X_time_train, fit=True)
X_tabular_validation, X_time_valid = normalize(X_tabular_validation, X_time_valid)

100%|██████████| 3108/3108 [09:33<00:00,  5.42it/s]


loaded 103390 samples
train shape (103390, 180, 21)


100%|██████████| 3108/3108 [00:51<00:00, 60.07it/s]


loaded 8748 samples
validation shape (8748, 180, 21)


100%|██████████| 21/21 [00:23<00:00,  1.10s/it]
100%|██████████| 22/22 [00:00<00:00, 336.59it/s]
100%|██████████| 21/21 [00:00<00:00, 30.02it/s]
100%|██████████| 22/22 [00:00<00:00, 8162.29it/s]


In [13]:
X_tabular_train[0]

array([ 5.56228848e-01,  8.17409819e-01,  6.51668564e-01, -2.83773085e-01,
       -3.14493385e-01,  2.79404984e+01, -1.97667772e-01,  9.97399986e-01,
        1.00506757e-01, -3.72835005e-01,  4.31893688e-01,  3.36440100e-02,
        1.99835526e-01, -5.95441595e-01,  2.23973455e-01,  2.08807389e-02,
       -1.77997076e-01,  5.69166279e-01,  1.13585746e+00,  3.01293900e-01,
        0.00000000e+00,  0.00000000e+00])

In [14]:
X_tabular_cat_train[10000]

array([3., 3., 1., 3., 1., 1., 1.])

In [15]:
X_time_train

array([[[ 3.55963303e+00,  5.12135922e-01,  4.82894737e-01, ...,
          8.30131398e-01,  6.75855651e-01, -1.32960775e-01],
        [ 2.28440367e+00,  4.87864078e-01,  2.34210526e-01, ...,
          8.30131398e-01,  6.73180468e-01, -1.44851493e-01],
        [-4.58715596e-02,  4.66019417e-01,  3.22368421e-01, ...,
          8.30131398e-01,  6.70303333e-01, -1.56694133e-01],
        ...,
        [-8.25688073e-02,  6.18932039e-01, -3.59210526e-01, ...,
          3.91880129e+00, -7.11799149e-01,  9.73796555e-02],
        [-8.25688073e-02,  7.18446602e-01, -5.90789474e-01, ...,
          3.91880129e+00, -7.10350132e-01,  1.09484829e-01],
        [-8.25688073e-02,  7.50000000e-01, -5.00000000e-01, ...,
          3.91880129e+00, -7.08695333e-01,  1.21563125e-01]],

       [[-8.25688073e-02,  7.03883495e-01, -4.40789474e-01, ...,
          3.91871828e+00, -7.06835240e-01,  1.33610987e-01],
        [-8.25688073e-02,  6.82038835e-01, -4.31578947e-01, ...,
          3.95185238e+00, -7.04770402e

In [16]:
# Really important line !!!
list_cat = [dfs["soil"][cat].nunique() for cat in ["SQ1", "SQ2", "SQ3", "SQ4", "SQ5", "SQ6", "SQ7"]]

In [17]:
num_categorical_features = 7
num_numerical_features = 22
num_time_series_features = 21
hidden_size = 200
num_lstm_layers = 2
embedding_dims = 50
learning_rate = 7e-5
num_fc_tabular_layers = 2
num_fc_combined_layers = 2
num_epochs_lstm = 10  # Change this
num_epochs_entire = 10 # Change this
batch_size = 128
output_weeks = 6

In [18]:

# batch_size = 128
# output_weeks = 6
# use_static = True
# hidden_dim = 512
# n_layers = 2
# ffnn_layers = 2
# dropout = 0.1
# one_cycle = True
# lr = 7e-5
# epochs = 10
# clip = 5

In [19]:
train_data = TensorDataset(
    torch.tensor(X_time_train).type(torch.FloatTensor),
    torch.tensor(X_tabular_train).type(torch.FloatTensor),
    torch.tensor(X_tabular_cat_train).type(torch.LongTensor),
    torch.tensor(y_target_train[:, :output_weeks]).type(torch.FloatTensor),
)
train_loader = DataLoader(
    train_data, shuffle=True, batch_size=batch_size, drop_last=False
)
valid_data = TensorDataset(
    torch.tensor(X_time_valid).type(torch.FloatTensor),
    torch.tensor(X_tabular_validation).type(torch.FloatTensor),
    torch.tensor(X_tabular_cat_validation).type(torch.LongTensor),
    torch.tensor(y_target_valid[:, :output_weeks]).type(torch.FloatTensor),
)
valid_loader = DataLoader(
    valid_data, shuffle=False, batch_size=batch_size, drop_last=False
)

In [20]:
X_time, X_static, X_static_cat, y_target = next(iter(train_loader))

In [21]:
X_static.shape, X_static_cat.shape, X_time.shape, y_target.shape

(torch.Size([128, 22]),
 torch.Size([128, 7]),
 torch.Size([128, 180, 21]),
 torch.Size([128, 6]))

In [None]:
y_target

In [None]:
np.round(y_target, 0)

In [None]:
cross_test = nn.CrossEntropyLoss()
print(cross_test(torch.tensor([[0.1, 0.2, 0.7]]), torch.tensor([2])))

In [24]:
class HybridModel(nn.Module):

    def __init__(
        self,
        num_categorical_features,
        list_unic_cat,
        num_numerical_features,
        num_time_series_features,
        hidden_size,
        num_lstm_layers,
        embedding_dims,
        num_fc_tabular_layers,
        num_fc_combined_layers,
        output_size,
    ):
        super(HybridModel, self).__init__()

        # Embeddings for categorical variables
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(num_embeddings=i, embedding_dim=embedding_dims)
                for i in list_unic_cat
            ]
        )

        total_embedding_dim = num_categorical_features * embedding_dims

        # Tabular part: dinamic creation of layers
        tabular_fc_layers = []
        input_size = total_embedding_dim + num_numerical_features
        for _ in range(num_fc_tabular_layers):
            tabular_fc_layers.append(nn.Linear(input_size, 128))
            tabular_fc_layers.append(nn.ReLU())
            input_size = 128
        self.tabular_fc_layers = nn.Sequential(
            *tabular_fc_layers, nn.Linear(128, 64), nn.ReLU()
        )

        # Temporal series
        self.lstm = nn.LSTM(
            input_size=num_time_series_features,
            hidden_size=hidden_size,
            num_layers=num_lstm_layers,
            batch_first=True,
        )

        # Atenttion
        self.attention = nn.Linear(hidden_size, 1)
        
        self.dropout = nn.Dropout(0.2)
        # Combined part
        self.fc_after_context = nn.Linear(hidden_size, 64)
        combined_fc_layers = []
        input_dim = (64 + 64)  # Assuming 128 from tabular output and 64 from LSTM output after attention
        for _ in range(num_fc_combined_layers):
            combined_fc_layers.append(nn.Linear(input_dim, 64))
            combined_fc_layers.append(nn.ReLU())
            input_dim = 64
        self.combined_fc_layers = nn.Sequential(
            *combined_fc_layers, nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, output_size)
        )

    def forward(self, categorical_data, numerical_data, time_series_data):
        # Embeddings for categorical data
        embeddings = [
            emb(categorical_data[:, i]) for i, emb in enumerate(self.embeddings)
        ]
        x_cat = torch.cat(embeddings, dim=1)

        # Concatenate categorical and numerical data
        x_tabular = torch.cat((x_cat, numerical_data), dim=1)

        # Pass the tabular data through FC layers
        x1 = self.tabular_fc_layers(x_tabular)

        # Pass the time series data through the LSTM
        lstm_out, (hn, cn) = self.lstm(time_series_data)
        # Pass the data through the attention mechanism
        attention_weights = torch.softmax(self.attention(lstm_out), dim=1)
        context_vector = torch.sum(attention_weights * lstm_out, dim=1)
        droped_out = self.dropout(context_vector)

        x2 = torch.relu(self.fc_after_context(droped_out))

        # Concatenate the outputs from the tabular and the temporal data and pass it through FC layers
        x = torch.cat((x1, x2), dim=1)
        x = self.combined_fc_layers(x)

        return x

In [25]:
len(train_loader), len(valid_loader)

(808, 69)

## Here I initialize a Tensorboard Writer 

In [61]:
writer = SummaryWriter('../logs/HybridModel')

In [62]:
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device("cuda")
    print("using GPU")
else:
    device = torch.device("cpu")
    print("using CPU")

model = HybridModel(
    num_categorical_features,
    list_cat,
    num_numerical_features,
    num_time_series_features,
    hidden_size,
    num_lstm_layers,
    embedding_dims,
    num_fc_tabular_layers,
    num_fc_combined_layers,
    output_size=output_weeks,
)
model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=len(train_loader), epochs=num_epochs_entire)
valid_loss_min = np.inf
counter = 0
# Seeds
torch.manual_seed(42)
np.random.seed(42)

model.train()
for i in range(num_epochs_entire):
    for k, batch in tqdm(enumerate(train_loader), desc=f"Epoch {i+1}/{num_epochs_entire}", total=len(train_loader)):
        X_time, X_static, X_static_cat, y_target = [data.to(device) for data in batch]
        model.train()
        counter += 1
        optimizer.zero_grad()
        output = model(X_static_cat, X_static, X_time)
        loss = criterion(output, y_target)
        loss.backward()
        optimizer.step()
        scheduler.step()

        with torch.no_grad():
            if k == len(train_loader) - 1 or k == (len(train_loader) - 1) // 2:
                model.eval()
                labels = []
                raw_labels = []
                preds = []
                raw_preds = []
                val_losses = []
                for batch in valid_loader:
                    X_time_val, X_static_val, X_static_cat_val, y_target_val = [data.to(device) for data in batch]
                    output = model(X_static_cat_val, X_static_val, X_time_val)
                    val_loss = criterion(output, y_target_val)
                    val_losses.append(val_loss.item())
                    for label in y_target_val:
                        labels.append([int(l.round()) for l in label])
                        raw_labels.append([float(l) for l in label])
                    for pred in output:
                        preds.append([int(p.round()) for p in pred])
                        raw_preds.append([float(p) for p in pred])
                    
                labels = np.array(labels)
                preds = np.clip(np.array(preds), 0, 5)
                raw_labels = np.array(raw_labels)
                raw_preds = np.array(raw_preds)

                for i in range(output_weeks):
                    log_dict = {
                        "loss": float(loss),
                        "epoch": counter / len(train_loader),
                        "step": counter,
                        "lr": optimizer.param_groups[0]["lr"],
                        "week": i + 1,
                    }
                    # w = f'week_{i+1}_'
                    w = ""
                    log_dict[f"{w}validation_loss"] = np.mean(val_losses)
                    log_dict[f"{w}macro_f1"] = f1_score(
                        labels[:, i], preds[:, i], average="macro"
                    )
                    log_dict[f"{w}micro_f1"] = f1_score(
                        labels[:, i], preds[:, i], average="micro"
                    )
                    log_dict[f"{w}mae"] = mean_absolute_error(
                        raw_labels[:, i], raw_preds[:, i]
                    )
                    print(log_dict)
                    writer.add_scalars("Loss(MSE)", {'train': loss,
                                                     'validation': log_dict[f"{w}validation_loss"]},
                                                     counter)
                    writer.add_scalars("F1(MSE)", {'macro': log_dict[f"{w}macro_f1"],
                                                   'micro': log_dict[f"{w}micro_f1"]},
                                                   counter)
                    writer.add_scalar("MAE", log_dict[f"{w}mae"],
                                      counter)
                    writer.add_scalar("Learning-Rate", log_dict["lr"],
                                      counter)
                    for j, f1 in enumerate(
                        f1_score(labels[:, i], preds[:, i], average=None)
                    ):
                        log_dict[f"{w}{id2class[j]}_f1"] = f1
                    model.train()
                if np.mean(val_losses) <= valid_loss_min:
                    torch.save(model.state_dict(), "./state_dict_HybridModel.pt")
                    print(
                        "Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...".format(
                            valid_loss_min, np.mean(val_losses)
                        )
                    )
                    valid_loss_min = np.mean(val_losses)

using GPU


Epoch 1/10:  50%|█████     | 407/808 [00:34<03:34,  1.87it/s]

{'loss': 1.5886046886444092, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 1, 'validation_loss': np.float64(1.1733881108786748), 'macro_f1': np.float64(0.13287963153914897), 'micro_f1': np.float64(0.6628943758573388), 'mae': np.float64(0.5994869498093328)}
{'loss': 1.5886046886444092, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 2, 'validation_loss': np.float64(1.1733881108786748), 'macro_f1': np.float64(0.13275553721282157), 'micro_f1': np.float64(0.6618655692729767), 'mae': np.float64(0.5980574556242103)}
{'loss': 1.5886046886444092, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 3, 'validation_loss': np.float64(1.1733881108786748), 'macro_f1': np.float64(0.13290718720645175), 'micro_f1': np.float64(0.6631229995427527), 'mae': np.float64(0.7155682356324623)}
{'loss': 1.5886046886444092, 'epoch': 0.5, 'step': 404, 'lr': 7.305177512317032e-06, 'week': 4, 'validation_loss': np.float64(1.1733881108786748), 'macro_f1': np.float64(0.133

Epoch 1/10: 100%|██████████| 808/808 [01:09<00:00, 11.57it/s]


{'loss': 1.3227063417434692, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 1, 'validation_loss': np.float64(0.8102497981078383), 'macro_f1': np.float64(0.16948519547272592), 'micro_f1': np.float64(0.5555555555555556), 'mae': np.float64(0.6751894175637752)}
{'loss': 1.3227063417434692, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 2, 'validation_loss': np.float64(0.8102497981078383), 'macro_f1': np.float64(0.10103908649617287), 'micro_f1': np.float64(0.21250571559213535), 'mae': np.float64(0.7774371078244114)}
{'loss': 1.3227063417434692, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 3, 'validation_loss': np.float64(0.8102497981078383), 'macro_f1': np.float64(0.17090317825863088), 'micro_f1': np.float64(0.5254915409236397), 'mae': np.float64(0.6692232408236752)}
{'loss': 1.3227063417434692, 'epoch': 1.0, 'step': 808, 'lr': 1.9612577643465342e-05, 'week': 4, 'validation_loss': np.float64(0.8102497981078383), 'macro_f1': np.float64(

Epoch 2/10:  50%|█████     | 407/808 [00:35<04:03,  1.65it/s]

{'loss': 0.682614266872406, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 1, 'validation_loss': np.float64(0.5924713624560315), 'macro_f1': np.float64(0.261688681993346), 'micro_f1': np.float64(0.5692729766803841), 'mae': np.float64(0.5658560298058959)}
{'loss': 0.682614266872406, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 2, 'validation_loss': np.float64(0.5924713624560315), 'macro_f1': np.float64(0.2834471956564339), 'micro_f1': np.float64(0.569387288523091), 'mae': np.float64(0.5671519964969427)}
{'loss': 0.682614266872406, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 3, 'validation_loss': np.float64(0.5924713624560315), 'macro_f1': np.float64(0.2819904124046037), 'micro_f1': np.float64(0.5775034293552812), 'mae': np.float64(0.5396382211391137)}
{'loss': 0.682614266872406, 'epoch': 1.5, 'step': 1212, 'lr': 3.6421782399043904e-05, 'week': 4, 'validation_loss': np.float64(0.5924713624560315), 'macro_f1': np.float64(0.2482

Epoch 2/10: 100%|██████████| 808/808 [01:10<00:00, 11.43it/s]


{'loss': 0.42485830187797546, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 1, 'validation_loss': np.float64(0.40237428831017535), 'macro_f1': np.float64(0.4425948179304811), 'micro_f1': np.float64(0.6966163694558757), 'mae': np.float64(0.38199450352082986)}
{'loss': 0.42485830187797546, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 2, 'validation_loss': np.float64(0.40237428831017535), 'macro_f1': np.float64(0.41247041125500034), 'micro_f1': np.float64(0.6791266575217193), 'mae': np.float64(0.40770335698466265)}
{'loss': 0.42485830187797546, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 3, 'validation_loss': np.float64(0.40237428831017535), 'macro_f1': np.float64(0.39126910417309957), 'micro_f1': np.float64(0.6558070416095108), 'mae': np.float64(0.4355482700787529)}
{'loss': 0.42485830187797546, 'epoch': 2.0, 'step': 1616, 'lr': 5.322514587043574e-05, 'week': 4, 'validation_loss': np.float64(0.40237428831017535), 'macro_f1': np.

Epoch 3/10:  50%|█████     | 407/808 [00:35<04:09,  1.61it/s]

{'loss': 0.3733770251274109, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 1, 'validation_loss': np.float64(0.3048547545744889), 'macro_f1': np.float64(0.5464634831105389), 'micro_f1': np.float64(0.7767489711934157), 'mae': np.float64(0.26841535920413584)}
{'loss': 0.3733770251274109, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 2, 'validation_loss': np.float64(0.3048547545744889), 'macro_f1': np.float64(0.49798969445748814), 'micro_f1': np.float64(0.7481710105166895), 'mae': np.float64(0.3013880558338603)}
{'loss': 0.3733770251274109, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 3, 'validation_loss': np.float64(0.3048547545744889), 'macro_f1': np.float64(0.46544074940528857), 'micro_f1': np.float64(0.7256515775034293), 'mae': np.float64(0.34565562713968484)}
{'loss': 0.3733770251274109, 'epoch': 2.5, 'step': 2020, 'lr': 6.551658857891442e-05, 'week': 4, 'validation_loss': np.float64(0.3048547545744889), 'macro_f1': np.float64(

Epoch 3/10: 100%|██████████| 808/808 [01:10<00:00, 11.39it/s]


{'loss': 0.30743539333343506, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 1, 'validation_loss': np.float64(0.26335468391577405), 'macro_f1': np.float64(0.677171212230751), 'micro_f1': np.float64(0.8305898491083676), 'mae': np.float64(0.21467656809979582)}
{'loss': 0.30743539333343506, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 2, 'validation_loss': np.float64(0.26335468391577405), 'macro_f1': np.float64(0.6016733911469613), 'micro_f1': np.float64(0.7904663923182441), 'mae': np.float64(0.2647743278932365)}
{'loss': 0.30743539333343506, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 3, 'validation_loss': np.float64(0.26335468391577405), 'macro_f1': np.float64(0.5290282101533196), 'micro_f1': np.float64(0.7525148605395519), 'mae': np.float64(0.3140586959235387)}
{'loss': 0.30743539333343506, 'epoch': 3.0, 'step': 2424, 'lr': 6.99999946009513e-05, 'week': 4, 'validation_loss': np.float64(0.26335468391577405), 'macro_f1': np.float64(

Epoch 4/10:  50%|█████     | 407/808 [00:35<04:10,  1.60it/s]

{'loss': 0.3117940425872803, 'epoch': 3.5, 'step': 2828, 'lr': 6.911814926126814e-05, 'week': 1, 'validation_loss': np.float64(0.24230611632051674), 'macro_f1': np.float64(0.7175195230469368), 'micro_f1': np.float64(0.848079561042524), 'mae': np.float64(0.18519502068489782)}
{'loss': 0.3117940425872803, 'epoch': 3.5, 'step': 2828, 'lr': 6.911814926126814e-05, 'week': 2, 'validation_loss': np.float64(0.24230611632051674), 'macro_f1': np.float64(0.6527466926364504), 'micro_f1': np.float64(0.80612711476909), 'mae': np.float64(0.2312402873451452)}
{'loss': 0.3117940425872803, 'epoch': 3.5, 'step': 2828, 'lr': 6.911814926126814e-05, 'week': 3, 'validation_loss': np.float64(0.24230611632051674), 'macro_f1': np.float64(0.5937945261296119), 'micro_f1': np.float64(0.771490626428898), 'mae': np.float64(0.27802060277496854)}
{'loss': 0.3117940425872803, 'epoch': 3.5, 'step': 2828, 'lr': 6.911814926126814e-05, 'week': 4, 'validation_loss': np.float64(0.24230611632051674), 'macro_f1': np.float64(0.

Epoch 4/10: 100%|██████████| 808/808 [01:11<00:00, 11.38it/s]


{'loss': 0.19234295189380646, 'epoch': 4.0, 'step': 3232, 'lr': 6.652548447282524e-05, 'week': 1, 'validation_loss': np.float64(0.23017254412390184), 'macro_f1': np.float64(0.7061783971928953), 'micro_f1': np.float64(0.8568815729309557), 'mae': np.float64(0.17576925486526657)}
{'loss': 0.19234295189380646, 'epoch': 4.0, 'step': 3232, 'lr': 6.652548447282524e-05, 'week': 2, 'validation_loss': np.float64(0.23017254412390184), 'macro_f1': np.float64(0.6541712183399417), 'micro_f1': np.float64(0.8155006858710563), 'mae': np.float64(0.2241289531294612)}
{'loss': 0.19234295189380646, 'epoch': 4.0, 'step': 3232, 'lr': 6.652548447282524e-05, 'week': 3, 'validation_loss': np.float64(0.23017254412390184), 'macro_f1': np.float64(0.5911507491941492), 'micro_f1': np.float64(0.7789208962048468), 'mae': np.float64(0.27622370832890636)}
{'loss': 0.19234295189380646, 'epoch': 4.0, 'step': 3232, 'lr': 6.652548447282524e-05, 'week': 4, 'validation_loss': np.float64(0.23017254412390184), 'macro_f1': np.fl

Epoch 5/10:  50%|█████     | 407/808 [00:35<04:09,  1.61it/s]

{'loss': 0.2893592417240143, 'epoch': 4.5, 'step': 3636, 'lr': 6.235200727414045e-05, 'week': 1, 'validation_loss': np.float64(0.2354024387896061), 'macro_f1': np.float64(0.7246549282222013), 'micro_f1': np.float64(0.8559670781893004), 'mae': np.float64(0.1641908312147015)}
{'loss': 0.2893592417240143, 'epoch': 4.5, 'step': 3636, 'lr': 6.235200727414045e-05, 'week': 2, 'validation_loss': np.float64(0.2354024387896061), 'macro_f1': np.float64(0.6596593924707699), 'micro_f1': np.float64(0.8161865569272977), 'mae': np.float64(0.20880640162532907)}
{'loss': 0.2893592417240143, 'epoch': 4.5, 'step': 3636, 'lr': 6.235200727414045e-05, 'week': 3, 'validation_loss': np.float64(0.2354024387896061), 'macro_f1': np.float64(0.6135120379625463), 'micro_f1': np.float64(0.7800640146319159), 'mae': np.float64(0.25736719962217214)}
{'loss': 0.2893592417240143, 'epoch': 4.5, 'step': 3636, 'lr': 6.235200727414045e-05, 'week': 4, 'validation_loss': np.float64(0.2354024387896061), 'macro_f1': np.float64(0.

Epoch 5/10: 100%|██████████| 808/808 [01:11<00:00, 11.38it/s]


{'loss': 0.36580950021743774, 'epoch': 5.0, 'step': 4040, 'lr': 5.680699323887897e-05, 'week': 1, 'validation_loss': np.float64(0.22500768336264984), 'macro_f1': np.float64(0.7538028809696051), 'micro_f1': np.float64(0.8691129401005944), 'mae': np.float64(0.17147523218675237)}
{'loss': 0.36580950021743774, 'epoch': 5.0, 'step': 4040, 'lr': 5.680699323887897e-05, 'week': 2, 'validation_loss': np.float64(0.22500768336264984), 'macro_f1': np.float64(0.6861334768543887), 'micro_f1': np.float64(0.8256744398719708), 'mae': np.float64(0.21945650411844114)}
{'loss': 0.36580950021743774, 'epoch': 5.0, 'step': 4040, 'lr': 5.680699323887897e-05, 'week': 3, 'validation_loss': np.float64(0.22500768336264984), 'macro_f1': np.float64(0.6412393546174723), 'micro_f1': np.float64(0.7862368541380887), 'mae': np.float64(0.27202264996668024)}
{'loss': 0.36580950021743774, 'epoch': 5.0, 'step': 4040, 'lr': 5.680699323887897e-05, 'week': 4, 'validation_loss': np.float64(0.22500768336264984), 'macro_f1': np.f

Epoch 6/10:  50%|█████     | 407/808 [00:35<04:03,  1.64it/s]

{'loss': 0.23386359214782715, 'epoch': 5.5, 'step': 4444, 'lr': 5.0168492524730965e-05, 'week': 1, 'validation_loss': np.float64(0.2253026554118032), 'macro_f1': np.float64(0.7200483770568632), 'micro_f1': np.float64(0.8718564243255601), 'mae': np.float64(0.16071433109676814)}
{'loss': 0.23386359214782715, 'epoch': 5.5, 'step': 4444, 'lr': 5.0168492524730965e-05, 'week': 2, 'validation_loss': np.float64(0.2253026554118032), 'macro_f1': np.float64(0.6460258278163221), 'micro_f1': np.float64(0.823045267489712), 'mae': np.float64(0.2111546583722211)}
{'loss': 0.23386359214782715, 'epoch': 5.5, 'step': 4444, 'lr': 5.0168492524730965e-05, 'week': 3, 'validation_loss': np.float64(0.2253026554118032), 'macro_f1': np.float64(0.5888120353794805), 'micro_f1': np.float64(0.7830361225422954), 'mae': np.float64(0.26464974494892923)}
{'loss': 0.23386359214782715, 'epoch': 5.5, 'step': 4444, 'lr': 5.0168492524730965e-05, 'week': 4, 'validation_loss': np.float64(0.2253026554118032), 'macro_f1': np.flo

Epoch 6/10: 100%|██████████| 808/808 [01:10<00:00, 11.39it/s]


{'loss': 0.26488155126571655, 'epoch': 6.0, 'step': 4848, 'lr': 4.276938727746874e-05, 'week': 1, 'validation_loss': np.float64(0.22395657170293987), 'macro_f1': np.float64(0.75646557013388), 'micro_f1': np.float64(0.872085048010974), 'mae': np.float64(0.16371772425661935)}
{'loss': 0.26488155126571655, 'epoch': 6.0, 'step': 4848, 'lr': 4.276938727746874e-05, 'week': 2, 'validation_loss': np.float64(0.22395657170293987), 'macro_f1': np.float64(0.6891215428336798), 'micro_f1': np.float64(0.8262459990855052), 'mae': np.float64(0.21556654313168105)}
{'loss': 0.26488155126571655, 'epoch': 6.0, 'step': 4848, 'lr': 4.276938727746874e-05, 'week': 3, 'validation_loss': np.float64(0.22395657170293987), 'macro_f1': np.float64(0.6349932943926787), 'micro_f1': np.float64(0.7854366712391404), 'mae': np.float64(0.26812294587023744)}
{'loss': 0.26488155126571655, 'epoch': 6.0, 'step': 4848, 'lr': 4.276938727746874e-05, 'week': 4, 'validation_loss': np.float64(0.22395657170293987), 'macro_f1': np.floa

Epoch 7/10:  50%|█████     | 407/808 [00:35<04:03,  1.64it/s]

{'loss': 0.233161062002182, 'epoch': 6.5, 'step': 5252, 'lr': 3.498069953016286e-05, 'week': 1, 'validation_loss': np.float64(0.21988838662703833), 'macro_f1': np.float64(0.7463314495670618), 'micro_f1': np.float64(0.8735711019661637), 'mae': np.float64(0.14974365337342643)}
{'loss': 0.233161062002182, 'epoch': 6.5, 'step': 5252, 'lr': 3.498069953016286e-05, 'week': 2, 'validation_loss': np.float64(0.21988838662703833), 'macro_f1': np.float64(0.6899807950167435), 'micro_f1': np.float64(0.8296753543667124), 'mae': np.float64(0.20045258908590305)}
{'loss': 0.233161062002182, 'epoch': 6.5, 'step': 5252, 'lr': 3.498069953016286e-05, 'week': 3, 'validation_loss': np.float64(0.21988838662703833), 'macro_f1': np.float64(0.6270480135858142), 'micro_f1': np.float64(0.7898948331047096), 'mae': np.float64(0.25265945930176237)}
{'loss': 0.233161062002182, 'epoch': 6.5, 'step': 5252, 'lr': 3.498069953016286e-05, 'week': 4, 'validation_loss': np.float64(0.21988838662703833), 'macro_f1': np.float64(0

Epoch 7/10: 100%|██████████| 808/808 [01:11<00:00, 11.37it/s]


{'loss': 0.35186290740966797, 'epoch': 7.0, 'step': 5656, 'lr': 2.7192986609190955e-05, 'week': 1, 'validation_loss': np.float64(0.21739354394916174), 'macro_f1': np.float64(0.7470608657059801), 'micro_f1': np.float64(0.8788294467306813), 'mae': np.float64(0.1519082824484153)}
{'loss': 0.35186290740966797, 'epoch': 7.0, 'step': 5656, 'lr': 2.7192986609190955e-05, 'week': 2, 'validation_loss': np.float64(0.21739354394916174), 'macro_f1': np.float64(0.6863746724294462), 'micro_f1': np.float64(0.8309327846364883), 'mae': np.float64(0.20275314242484335)}
{'loss': 0.35186290740966797, 'epoch': 7.0, 'step': 5656, 'lr': 2.7192986609190955e-05, 'week': 3, 'validation_loss': np.float64(0.21739354394916174), 'macro_f1': np.float64(0.629978600401459), 'micro_f1': np.float64(0.789551897576589), 'mae': np.float64(0.2570684143128312)}
{'loss': 0.35186290740966797, 'epoch': 7.0, 'step': 5656, 'lr': 2.7192986609190955e-05, 'week': 4, 'validation_loss': np.float64(0.21739354394916174), 'macro_f1': np.f

Epoch 8/10:  50%|█████     | 407/808 [00:35<04:03,  1.64it/s]

{'loss': 0.22014831006526947, 'epoch': 7.5, 'step': 6060, 'lr': 1.9796756959067725e-05, 'week': 1, 'validation_loss': np.float64(0.22259975818620212), 'macro_f1': np.float64(0.7567816721381053), 'micro_f1': np.float64(0.8728852309099223), 'mae': np.float64(0.16595080061926432)}
{'loss': 0.22014831006526947, 'epoch': 7.5, 'step': 6060, 'lr': 1.9796756959067725e-05, 'week': 2, 'validation_loss': np.float64(0.22259975818620212), 'macro_f1': np.float64(0.6933043995736622), 'micro_f1': np.float64(0.8238454503886603), 'mae': np.float64(0.22051319878830564)}
{'loss': 0.22014831006526947, 'epoch': 7.5, 'step': 6060, 'lr': 1.9796756959067725e-05, 'week': 3, 'validation_loss': np.float64(0.22259975818620212), 'macro_f1': np.float64(0.6342780527380163), 'micro_f1': np.float64(0.7852080475537265), 'mae': np.float64(0.26968920683772796)}
{'loss': 0.22014831006526947, 'epoch': 7.5, 'step': 6060, 'lr': 1.9796756959067725e-05, 'week': 4, 'validation_loss': np.float64(0.22259975818620212), 'macro_f1': 

Epoch 8/10: 100%|██████████| 808/808 [01:10<00:00, 11.38it/s]


{'loss': 0.21695873141288757, 'epoch': 8.0, 'step': 6464, 'lr': 1.316288841841575e-05, 'week': 1, 'validation_loss': np.float64(0.21615426172164903), 'macro_f1': np.float64(0.769932711527065), 'micro_f1': np.float64(0.8835162322816644), 'mae': np.float64(0.15163859907036978)}
{'loss': 0.21695873141288757, 'epoch': 8.0, 'step': 6464, 'lr': 1.316288841841575e-05, 'week': 2, 'validation_loss': np.float64(0.21615426172164903), 'macro_f1': np.float64(0.7024054587666405), 'micro_f1': np.float64(0.8328760859625057), 'mae': np.float64(0.20523496172161604)}
{'loss': 0.21695873141288757, 'epoch': 8.0, 'step': 6464, 'lr': 1.316288841841575e-05, 'week': 3, 'validation_loss': np.float64(0.21615426172164903), 'macro_f1': np.float64(0.6309566296477928), 'micro_f1': np.float64(0.7898948331047096), 'mae': np.float64(0.2583488266062124)}
{'loss': 0.21695873141288757, 'epoch': 8.0, 'step': 6464, 'lr': 1.316288841841575e-05, 'week': 4, 'validation_loss': np.float64(0.21615426172164903), 'macro_f1': np.flo

Epoch 9/10:  50%|█████     | 407/808 [00:35<04:03,  1.64it/s]

{'loss': 0.17218248546123505, 'epoch': 8.5, 'step': 6868, 'lr': 7.624030856485954e-06, 'week': 1, 'validation_loss': np.float64(0.21700713173418806), 'macro_f1': np.float64(0.7597430922629131), 'micro_f1': np.float64(0.8771147690900777), 'mae': np.float64(0.15425598572436347)}
{'loss': 0.17218248546123505, 'epoch': 8.5, 'step': 6868, 'lr': 7.624030856485954e-06, 'week': 2, 'validation_loss': np.float64(0.21700713173418806), 'macro_f1': np.float64(0.6997003245304286), 'micro_f1': np.float64(0.8292181069958847), 'mae': np.float64(0.2066701749414078)}
{'loss': 0.17218248546123505, 'epoch': 8.5, 'step': 6868, 'lr': 7.624030856485954e-06, 'week': 3, 'validation_loss': np.float64(0.21700713173418806), 'macro_f1': np.float64(0.6494965473074373), 'micro_f1': np.float64(0.7897805212620027), 'mae': np.float64(0.25888861616252473)}
{'loss': 0.17218248546123505, 'epoch': 8.5, 'step': 6868, 'lr': 7.624030856485954e-06, 'week': 4, 'validation_loss': np.float64(0.21700713173418806), 'macro_f1': np.fl

Epoch 9/10: 100%|██████████| 808/808 [01:10<00:00, 11.38it/s]


{'loss': 0.23346787691116333, 'epoch': 9.0, 'step': 7272, 'lr': 3.4579257196884897e-06, 'week': 1, 'validation_loss': np.float64(0.21721695169158603), 'macro_f1': np.float64(0.7703359983827184), 'micro_f1': np.float64(0.877914951989026), 'mae': np.float64(0.15779049922312352)}
{'loss': 0.23346787691116333, 'epoch': 9.0, 'step': 7272, 'lr': 3.4579257196884897e-06, 'week': 2, 'validation_loss': np.float64(0.21721695169158603), 'macro_f1': np.float64(0.7021273721883804), 'micro_f1': np.float64(0.8291037951531779), 'mae': np.float64(0.2090569390370108)}
{'loss': 0.23346787691116333, 'epoch': 9.0, 'step': 7272, 'lr': 3.4579257196884897e-06, 'week': 3, 'validation_loss': np.float64(0.21721695169158603), 'macro_f1': np.float64(0.6443951713935259), 'micro_f1': np.float64(0.7889803383630544), 'mae': np.float64(0.26093002627904016)}
{'loss': 0.23346787691116333, 'epoch': 9.0, 'step': 7272, 'lr': 3.4579257196884897e-06, 'week': 4, 'validation_loss': np.float64(0.21721695169158603), 'macro_f1': np

Epoch 10/10:  50%|█████     | 407/808 [00:35<04:04,  1.64it/s]

{'loss': 0.27507051825523376, 'epoch': 9.5, 'step': 7676, 'lr': 8.734789157224429e-07, 'week': 1, 'validation_loss': np.float64(0.21586166937714038), 'macro_f1': np.float64(0.7695783301634155), 'micro_f1': np.float64(0.880429812528578), 'mae': np.float64(0.15447442143486148)}
{'loss': 0.27507051825523376, 'epoch': 9.5, 'step': 7676, 'lr': 8.734789157224429e-07, 'week': 2, 'validation_loss': np.float64(0.21586166937714038), 'macro_f1': np.float64(0.7011688814589112), 'micro_f1': np.float64(0.831390032007316), 'mae': np.float64(0.2065214219692887)}
{'loss': 0.27507051825523376, 'epoch': 9.5, 'step': 7676, 'lr': 8.734789157224429e-07, 'week': 3, 'validation_loss': np.float64(0.21586166937714038), 'macro_f1': np.float64(0.6331086862601194), 'micro_f1': np.float64(0.7893232738911752), 'mae': np.float64(0.258840336747746)}
{'loss': 0.27507051825523376, 'epoch': 9.5, 'step': 7676, 'lr': 8.734789157224429e-07, 'week': 4, 'validation_loss': np.float64(0.21586166937714038), 'macro_f1': np.float6

Epoch 10/10: 100%|██████████| 808/808 [01:11<00:00, 11.38it/s]

{'loss': 0.21028056740760803, 'epoch': 10.0, 'step': 8080, 'lr': 2.853990486928992e-10, 'week': 1, 'validation_loss': np.float64(0.21604836246241693), 'macro_f1': np.float64(0.7676129362869686), 'micro_f1': np.float64(0.8802011888431641), 'mae': np.float64(0.15364417431386415)}
{'loss': 0.21028056740760803, 'epoch': 10.0, 'step': 8080, 'lr': 2.853990486928992e-10, 'week': 2, 'validation_loss': np.float64(0.21604836246241693), 'macro_f1': np.float64(0.6989052064326254), 'micro_f1': np.float64(0.831275720164609), 'mae': np.float64(0.2060145811782345)}
{'loss': 0.21028056740760803, 'epoch': 10.0, 'step': 8080, 'lr': 2.853990486928992e-10, 'week': 3, 'validation_loss': np.float64(0.21604836246241693), 'macro_f1': np.float64(0.638953548436669), 'micro_f1': np.float64(0.7897805212620027), 'mae': np.float64(0.25820790535856036)}
{'loss': 0.21028056740760803, 'epoch': 10.0, 'step': 8080, 'lr': 2.853990486928992e-10, 'week': 4, 'validation_loss': np.float64(0.21604836246241693), 'macro_f1': np.




## Copying the best model params (if needed)

In [27]:
device = torch.device("cpu")

model = HybridModel(
    num_categorical_features,
    list_cat,
    num_numerical_features,
    num_time_series_features,
    hidden_size,
    num_lstm_layers,
    embedding_dims,
    num_fc_tabular_layers,
    num_fc_combined_layers,
    output_size=output_weeks,
)

model.load_state_dict(torch.load("./state_dict_HybridModel.pt"))
model.to(device)

  model.load_state_dict(torch.load("./state_dict_HybridModel.pt"))


HybridModel(
  (embeddings): ModuleList(
    (0-4): 5 x Embedding(7, 50)
    (5): Embedding(6, 50)
    (6): Embedding(8, 50)
  )
  (tabular_fc_layers): Sequential(
    (0): Linear(in_features=372, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): ReLU()
  )
  (lstm): LSTM(21, 200, num_layers=2, batch_first=True)
  (attention): Linear(in_features=200, out_features=1, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc_after_context): Linear(in_features=200, out_features=64, bias=True)
  (combined_fc_layers): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
    (6): Linear(in_features=32, out_features=6, bias=True)
  )
)

In [28]:
def predict(x, static, static_cat):
    out = model(static_cat, static, x)
    return out

In [30]:
dict_map = {
    "y_pred": [],
    "y_pred_rounded": [],
    "fips": [],
    "date": [],
    "y_true": [],
    "week": [],
}
i = 0
for x, static, static_cat, y in tqdm(
    valid_loader,
    desc="Validation predictions...",
):
    x, static, y = x.to(device), static.to(device), y.to(device)
    with torch.no_grad():
        pred = predict(x, static, static_cat).clone().detach()
    for w in range(output_weeks):
        dict_map["y_pred"] += [float(p[w]) for p in pred]
        dict_map["y_pred_rounded"] += [int(p.round()[w]) for p in pred]
        dict_map["fips"] += [f[1][0] for f in valid_fips[i : i + len(x)]]
        dict_map["date"] += [f[1][1] for f in valid_fips[i : i + len(x)]]
        dict_map["y_true"] += [float(item[w]) for item in y]
        dict_map["week"] += [w] * len(x)
    i += len(x)
df = pd.DataFrame(dict_map)

validation predictions...: 100%|██████████| 69/69 [00:03<00:00, 22.86it/s]


In [31]:
for w in range(6):
    wdf = df[df['week']==w]
    mae = mean_absolute_error(wdf['y_true'], wdf['y_pred']).round(3)
    f1 = f1_score(wdf['y_true'].round(),wdf['y_pred'].round(), average='macro').round(3)
    print(f"Week {w+1}", f"MAE {mae}", f"F1 {f1}")

Week 1 MAE 0.16 F1 0.749
Week 2 MAE 0.213 F1 0.687
Week 3 MAE 0.264 F1 0.642
Week 4 MAE 0.316 F1 0.571
Week 5 MAE 0.363 F1 0.499
Week 6 MAE 0.405 F1 0.452
