# Test daily_earlyrnn model
Smoke-test the new `DailyEarlyRNN` model on BreizhCrops to check that training and prediction run correctly.

In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each cell runs, so edits to the project's .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os 
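# Script-style setup kept for reference (disabled; `__file__` is not defined inside a notebook):
# os.environ['MPLCONFIGDIR'] = "$HOME"
# os.environ["WANDB_DIR"] = os.path.join(os.path.dirname(__file__), "..", "wandb")
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
# Add the parent directory to the import path so the project packages below resolve.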
sys.path.append("..")
from data import BavarianCrops, BreizhCrops, SustainbenchCrops, ModisCDL
from torch.utils.data import DataLoader
from models.earlyrnn import EarlyRNN
from models.daily_earlyrnn import DailyEarlyRNN
import torch
from tqdm import tqdm
from utils.losses.early_reward_loss import EarlyRewardLoss
from utils.losses.stopping_time_proximity_loss import StoppingTimeProximityLoss, sample_three_uniform_numbers
import sklearn.metrics
import pandas as pd
import wandb
from utils.plots import plot_label_distribution_datasets, boxplot_stopping_times
from utils.doy import get_doys_dict_test, get_doy_stop, create_sorted_doys_dict_test, get_approximated_doys_dict
from utils.helpers_training import parse_args_sweep, train_epoch
from utils.helpers_testing import test_epoch
from utils.metrics import harmonic_mean_score
from models.model_helpers import count_parameters
import matplotlib.pyplot as plt


In [12]:
# config
class Config():
    """Run settings for this notebook, bundled on a single object.

    The attribute names match the fields the rest of the notebook reads
    (``config.device``, ``config.batchsize``, ...). Every value is set in
    ``__init__``; there are no constructor arguments.

    Two values are resolved at construction time rather than hard-coded:

    * ``dataroot`` -- ``<home>/elects_data``, where ``<home>`` is ``$HOME``,
      else ``%USERPROFILE%``, else ``os.path.expanduser("~")``. The last
      fallback avoids a ``TypeError`` when neither variable is set.
    * ``device`` -- ``"cuda"`` when a CUDA device is available, otherwise
      ``"cpu"``, so the notebook also runs on machines without a GPU.
    """

    def __init__(self):
        # --- data ---
        home_dir = (
            os.environ.get("HOME")
            or os.environ.get("USERPROFILE")
            or os.path.expanduser("~")
        )
        self.dataroot = os.path.join(home_dir, "elects_data")
        self.dataset = "breizhcrops"
        self.validation_set = "valid"  # partition used for the evaluation dataset
        self.sequencelength = 365
        self.corrected = True
        self.daily_timestamps = True
        self.original_time_serie_lengths = [102]
        self.extra_padding_list = [0]

        # --- model ---
        self.backbonemodel = "LSTM"
        self.hidden_dims = 64

        # --- loss ---
        self.loss = "stopping_time_proximity"  # or "early_reward"
        self.loss_weight = "balanced"
        self.alpha = 0.6
        self.epsilon = 10

        # --- optimisation ---
        # Fall back to the CPU instead of failing on `.to("cuda")` without a GPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.batchsize = 256
        self.epochs = 100
        self.learning_rate = 0.001
        self.weight_decay = 0
        self.patience = 30
        self.resume = False


config = Config()

In [7]:
# Build the evaluation and training datasets plus their loaders.
# Both BreizhCrops instances share every setting except the partition;
# `test_ds` holds the partition named by `config.validation_set`.
nclasses, input_dim = 9, 13
dataroot = os.path.join(config.dataroot, "breizhcrops")

shared_ds_kwargs = dict(
    root=dataroot,
    sequencelength=config.sequencelength,
    corrected=config.corrected,
    daily_timestamps=config.daily_timestamps,
    original_time_serie_lengths=config.original_time_serie_lengths,
)
test_ds = BreizhCrops(partition=config.validation_set, **shared_ds_kwargs)
train_ds = BreizhCrops(partition="train", **shared_ds_kwargs)

# Same batch size for both loaders; default DataLoader ordering is kept.
traindataloader = DataLoader(train_ds, batch_size=config.batchsize)
testdataloader = DataLoader(test_ds, batch_size=config.batchsize)

2559635960 2559635960


loading data into RAM: 100%|██████████| 67523/67523 [00:28<00:00, 2330.86it/s]


2253658856 2253658856


loading data into RAM: 100%|██████████| 85310/85310 [00:34<00:00, 2485.60it/s]


## Cost function: sample three alpha values

In [8]:
# Draw three values with the project helper `sample_three_uniform_numbers` and print them.
# NOTE(review): no other cell visible here reads alpha1/alpha2/alpha3 -- the loss
# cell passes `config.alpha` to StoppingTimeProximityLoss instead; confirm which is intended.
alpha1, alpha2, alpha3 = sample_three_uniform_numbers()
print(f"alphas: {alpha1}, {alpha2}, {alpha3}")

alphas: 0.1693766564130783, 0.12460131198167801, 0.7060220837593079


## Model

In [9]:
# Instantiate the daily early-classification model from the config values and
# move it to the configured device.
model = DailyEarlyRNN(
    config.backbonemodel,
    nclasses=nclasses,
    input_dim=input_dim,
    sequencelength=config.sequencelength,
    hidden_dims=config.hidden_dims,
).to(config.device)


## Optimizer

In [10]:
# Two AdamW parameter groups: the bias of the stopping-decision head's linear
# layer is kept out of weight decay, every other parameter uses the optimizer defaults.
decision_bias_name = "stopping_decision_head.projection.0.bias"

decay, no_decay = [], []
for name, param in model.named_parameters():
    bucket = no_decay if name == decision_bias_name else decay
    bucket.append(param)

param_groups = [
    {'params': no_decay, 'weight_decay': 0, "lr": config.learning_rate},
    {'params': decay},
]
optimizer = torch.optim.AdamW(
    param_groups,
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)



## Loss

In [13]:
# Select the training criterion from the config.
# With loss_weight == "balanced" the class weights come from the training set
# (moved to the training device); otherwise the loss is unweighted.
if config.loss_weight == "balanced":
    class_weights = train_ds.get_class_weights().to(config.device)
else:
    class_weights = None

if config.loss == "early_reward":
    criterion = EarlyRewardLoss(alpha=config.alpha, epsilon=config.epsilon, weight=class_weights)
elif config.loss == "stopping_time_proximity":
    # NOTE(review): `alphas` receives the scalar config.alpha here, while an earlier
    # cell samples three alphas with sample_three_uniform_numbers() -- confirm which
    # form StoppingTimeProximityLoss expects.
    criterion = StoppingTimeProximityLoss(alphas=config.alpha, weight=class_weights)
else:
    # Fail loudly instead of leaving `criterion` undefined (or stale from an earlier run).
    raise ValueError(
        f"unknown config.loss {config.loss!r}; expected 'early_reward' or 'stopping_time_proximity'"
    )

## Training example

In [16]:
# ----------------------------- TRAINING -----------------------------
# Runs one train_epoch and one test_epoch per epoch, for epochs
# start_epoch .. config.epochs inclusive, and derives summary metrics from the
# `stats` returned by test_epoch. The metrics are only assigned to local names
# in this cell; nothing here logs, plots or stores them.
# NOTE(review): the recorded run of this cell stopped in the first epoch with
# "IndexError: too many indices for tensor of dimension 1"; the traceback is not
# shown, so the failing line (here or inside the helpers / criterion) needs confirming.
start_epoch = 1
print("starting training...")
with tqdm(range(start_epoch, config.epochs + 1)) as pbar:
    for epoch in pbar:
        trainloss = train_epoch(model, traindataloader, optimizer, criterion, device=config.device)
        # testdataloader wraps the partition named by config.validation_set.
        testloss, stats = test_epoch(model, testdataloader, criterion, config.device, return_id=test_ds.return_id, daily_timestamps=config.daily_timestamps)

        # statistic logging and visualization...
        # The [:, 0] indexing assumes these stats entries are 2-D with the value
        # in column 0 -- TODO confirm against test_epoch.
        # Macro-averaged precision / recall / F-score; zero_division=0 is passed to sklearn.
        precision, recall, fscore, support = sklearn.metrics.precision_recall_fscore_support(
            y_pred=stats["predictions_at_t_stop"][:, 0], y_true=stats["targets"][:, 0], average="macro",
            zero_division=0)
        accuracy = sklearn.metrics.accuracy_score(
            y_pred=stats["predictions_at_t_stop"][:, 0], y_true=stats["targets"][:, 0])
        kappa = sklearn.metrics.cohen_kappa_score(
            stats["predictions_at_t_stop"][:, 0], stats["targets"][:, 0])

        classification_loss = stats["classification_loss"].mean()
        earliness_reward = stats["earliness_reward"].mean()
        # Earliness: 1 minus the mean stopping index divided by (sequencelength - 1).
        earliness = 1 - (stats["t_stop"].mean() / (config.sequencelength - 1))
        harmonic_mean = harmonic_mean_score(accuracy, stats["classification_earliness"])

starting training...


  0%|          | 0/100 [00:00<?, ?it/s]

cross_entropy.shape:  torch.Size([256, 365])
classification_loss.shape:  torch.Size([365])


  0%|          | 0/100 [00:00<?, ?it/s]

earliness_reward.shape:  torch.Size([256, 365])
earliness_reward.shape:  torch.Size([])





IndexError: too many indices for tensor of dimension 1

In [16]:
# Debug inspection: display the full `log_class_probs` tensor.
# NOTE(review): no cell shown here assigns `log_class_probs`, so this presumably
# relies on leftover kernel state and would fail after Restart & Run All -- confirm, then define or remove.
log_class_probs

tensor([[[-2.0982, -2.2887, -2.2046,  ..., -2.1571, -2.2635, -2.3104],
         [-2.0976, -2.2901, -2.2015,  ..., -2.1580, -2.2649, -2.3124],
         [-2.0978, -2.2876, -2.1999,  ..., -2.1596, -2.2658, -2.3163],
         ...,
         [-2.0849, -2.2865, -2.1784,  ..., -2.1551, -2.2731, -2.3009],
         [-2.0886, -2.2877, -2.1736,  ..., -2.1570, -2.2730, -2.3015],
         [-2.0836, -2.2922, -2.1846,  ..., -2.1593, -2.2680, -2.3033]],

        [[-2.0980, -2.2894, -2.2048,  ..., -2.1548, -2.2627, -2.3100],
         [-2.0965, -2.2894, -2.2042,  ..., -2.1555, -2.2657, -2.3114],
         [-2.0957, -2.2876, -2.1982,  ..., -2.1533, -2.2757, -2.3100],
         ...,
         [-2.0932, -2.2766, -2.1814,  ..., -2.1685, -2.2705, -2.3067],
         [-2.0946, -2.2819, -2.1834,  ..., -2.1703, -2.2645, -2.3107],
         [-2.0944, -2.2828, -2.1855,  ..., -2.1670, -2.2623, -2.3126]],

        [[-2.0982, -2.2886, -2.2041,  ..., -2.1561, -2.2637, -2.3101],
         [-2.0961, -2.2885, -2.2031,  ..., -2

In [17]:
# Debug inspection: display the full `timestamps_left` tensor.
# NOTE(review): the only visible assignment of `timestamps_left` is in the later
# `model.predict(...)` cell, so this cell depends on out-of-order execution -- confirm.
timestamps_left

tensor([[364., 364., 364.,  ..., 364., 364., 364.],
        [364., 364., 364.,  ..., 364., 364., 364.],
        [364., 364., 364.,  ..., 364., 364., 364.],
        ...,
        [364., 364., 364.,  ..., 364., 364., 364.],
        [364., 364., 364.,  ..., 364., 364., 364.],
        [364., 364., 364.,  ..., 364., 364., 364.]], device='cuda:0',
       grad_fn=<MulBackward0>)

In [27]:
# Run the model's predict() on one batch `X` and report the shape of each of
# its four outputs.
# NOTE(review): `X` is not assigned in any cell shown here -- presumably a batch
# left over in the kernel; confirm where it comes from.
logprobabilities, timestamps_left, predictions_at_t_stop, t_stop = model.predict(
    X.to(config.device)
)

print("shapes")
for label, output in (
    ("logprobabilities", logprobabilities),
    ("timestamps_left", timestamps_left),
    ("predictions_at_t_stop", predictions_at_t_stop),
    ("t_stop", t_stop),
):
    print(label, output.shape)


shapes
logprobabilities torch.Size([256, 365, 9])
timestamps_left torch.Size([256, 365])
predictions_at_t_stop torch.Size([256])
t_stop torch.Size([256])


In [28]:
# Print the raw values of the four predict() outputs, each preceded by its name.
for label, output in (
    ("logprobabilities", logprobabilities),
    ("timestamps_left", timestamps_left),
    ("predictions_at_t_stop", predictions_at_t_stop),
    ("t_stop", t_stop),
):
    print(label, output)

logprobabilities tensor([[[-2.0979, -2.2886, -2.2047,  ..., -2.1557, -2.2622, -2.3116],
         [-2.0968, -2.2899, -2.2028,  ..., -2.1575, -2.2649, -2.3119],
         [-2.0901, -2.2942, -2.1983,  ..., -2.1591, -2.2760, -2.3093],
         ...,
         [-2.0942, -2.2873, -2.1760,  ..., -2.1637, -2.2672, -2.3101],
         [-2.0943, -2.2934, -2.1836,  ..., -2.1627, -2.2637, -2.3133],
         [-2.0884, -2.2939, -2.1864,  ..., -2.1599, -2.2589, -2.3139]],

        [[-2.0975, -2.2887, -2.2056,  ..., -2.1547, -2.2636, -2.3114],
         [-2.0956, -2.2872, -2.2062,  ..., -2.1556, -2.2662, -2.3116],
         [-2.0879, -2.2856, -2.2042,  ..., -2.1556, -2.2781, -2.3077],
         ...,
         [-2.0989, -2.2763, -2.1828,  ..., -2.1632, -2.2744, -2.3030],
         [-2.0996, -2.2766, -2.1820,  ..., -2.1661, -2.2648, -2.3058],
         [-2.0968, -2.2773, -2.1851,  ..., -2.1677, -2.2617, -2.3111]],

        [[-2.0971, -2.2883, -2.2053,  ..., -2.1545, -2.2617, -2.3121],
         [-2.0945, -2.2890, 