# Train Notebook - for developing code 


In [1]:
import sys
import os 
os.environ['MPLCONFIGDIR'] = '/myhome'
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from data import BavarianCrops, BreizhCrops, SustainbenchCrops, ModisCDL
from torch.utils.data import DataLoader
from earlyrnn import EarlyRNN
import torch
from tqdm import tqdm
from loss import EarlyRewardLoss
import sklearn.metrics
import pandas as pd
import wandb
from utils.plots import plot_label_distribution_datasets, boxplot_stopping_times
from utils.doy import get_doys_dict_test, get_doy_stop, create_sorted_doys_dict_test, get_approximated_doys_dict
from utils.helpers_training import parse_args, train_epoch, test_epoch
from utils.metrics import harmonic_mean_score
import matplotlib.pyplot as plt

In [2]:
import argparse

def parse_args(args=None):
    parser = argparse.ArgumentParser(description='Run ELECTS Early Classification training on the BavarianCrops dataset.')
    parser.add_argument('--dataset', type=str, default="bavariancrops", choices=["bavariancrops","breizhcrops", "ghana", "southsudan","unitedstates"], help="dataset")
    parser.add_argument('--alpha', type=float, default=0.5, help="trade-off parameter of earliness and accuracy (eq 6): "
                                                                 "1=full weight on accuracy; 0=full weight on earliness")
    parser.add_argument('--epsilon', type=float, default=10, help="additive smoothing parameter that helps the "
                                                                  "model recover from too early classifications (eq 7)")
    parser.add_argument('--learning-rate', type=float, default=1e-3, help="Optimizer learning rate")
    parser.add_argument('--weight-decay', type=float, default=0, help="weight_decay")
    parser.add_argument('--patience', type=int, default=30, help="Early stopping patience")
    parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        choices=["cuda", "cpu"], help="'cuda' (GPU) or 'cpu' device to run the code. "
                                                     "defaults to 'cuda' if GPU is available, otherwise 'cpu'")
    parser.add_argument('--epochs', type=int, default=100, help="number of training epochs")
    parser.add_argument('--sequencelength', type=int, default=70, help="sequencelength of the time series. If samples are shorter, "
                                                                "they are zero-padded until this length; "
                                                                "if samples are longer, they will be undersampled")
    parser.add_argument('--batchsize', type=int, default=256, help="number of samples per batch")
    parser.add_argument('--dataroot', type=str, default=os.path.join(os.environ.get("HOME", os.environ.get("USERPROFILE")),"elects_data"), help="directory to download the "
                                                                                 "BavarianCrops dataset (400MB)."
                                                                                 "Defaults to home directory.")
    parser.add_argument('--snapshot', type=str, default="snapshots/model.pth",
                        help="pytorch state dict snapshot file")
    parser.add_argument('--resume', action='store_true')

    if args is not None:
        args = parser.parse_args(args)
    else:
        args = parser.parse_args()

    if args.patience < 0:
        args.patience = None

    return args


In [3]:
# Example of how to call parse_args with custom arguments in a notebook
custom_args = "--dataset breizhcrops --epochs 10".split()
args = parse_args(custom_args)
print("cuda is available: ", args.device)
print(args)

cuda is available:  cuda
Namespace(dataset='breizhcrops', alpha=0.5, epsilon=10, learning_rate=0.001, weight_decay=0, patience=30, device='cuda', epochs=10, sequencelength=70, batchsize=256, dataroot='C:\\Users\\anyam\\elects_data', snapshot='snapshots/model.pth', resume=False)


In [4]:
dir = "/mydata/studentanya/anya/wandb/"
wandb.init(
        dir=None,
        project="ELECTS",
        notes="first experimentations with ELECTS",
        tags=["ELECTS", args.dataset, "with_doys_boxplot"],
        # track hyperparameters and run metadata
        config={
        "dataset": args.dataset,
        "alpha": args.alpha,
        "epsilon": args.epsilon,
        "learning_rate": args.learning_rate,
        "weight_decay": args.weight_decay,
        "patience": args.patience,
        "device": args.device,
        "epochs": args.epochs,
        "sequencelength": args.sequencelength,
        "batchsize": args.batchsize,
        "dataroot": args.dataroot,
        "snapshot": args.snapshot,
        "resume": args.resume,
        "architecture": "EarlyRNN",
        "optimizer": "AdamW",
        "criterion": "EarlyRewardLoss",
        }
    )


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: aurenore. Use `wandb login --relogin` to force relogin


In [5]:
# ----------------------------- LOAD DATASET -----------------------------

if args.dataset == "bavariancrops":
    dataroot = os.path.join(args.dataroot,"bavariancrops")
    nclasses = 7
    input_dim = 13
    class_weights = None
    train_ds = BavarianCrops(root=dataroot,partition="train", sequencelength=args.sequencelength)
    test_ds = BavarianCrops(root=dataroot,partition="valid", sequencelength=args.sequencelength)
    class_names = test_ds.classes
elif args.dataset == "unitedstates":
    args.dataroot = "/data/modiscdl/"
    args.sequencelength = 24
    dataroot = args.dataroot
    nclasses = 8
    input_dim = 1
    train_ds = ModisCDL(root=dataroot,partition="train", sequencelength=args.sequencelength)
    test_ds = ModisCDL(root=dataroot,partition="valid", sequencelength=args.sequencelength)
elif args.dataset == "breizhcrops":
    dataroot = os.path.join(args.dataroot,"breizhcrops")
    nclasses = 9
    input_dim = 13
    print("get doys dict test")
    doys_dict_test = get_doys_dict_test(dataroot=os.path.join(args.dataroot,args.dataset))
    length_sorted_doy_dict_test = create_sorted_doys_dict_test(doys_dict_test)
    print("get doys dict test done")
    print("get test and validation data...")
    train_ds = BreizhCrops(root=dataroot,partition="train", sequencelength=args.sequencelength)
    test_ds = BreizhCrops(root=dataroot,partition="valid", sequencelength=args.sequencelength)
    class_names = test_ds.ds.classname
    print("class names:", class_names)
elif args.dataset in ["ghana"]:
    use_s2_only = False
    average_pixel = False
    max_n_pixels = 50
    dataroot = args.dataroot
    nclasses = 4
    input_dim = 12 if use_s2_only else 19  # 12 sentinel 2 + 3 x sentinel 1 + 4 * planet
    args.epochs = 500
    args.sequencelength = 365
    train_ds = SustainbenchCrops(root=dataroot,partition="train", sequencelength=args.sequencelength,
                                    country="ghana",
                                    use_s2_only=use_s2_only, average_pixel=average_pixel,
                                    max_n_pixels=max_n_pixels)
    val_ds = SustainbenchCrops(root=dataroot,partition="val", sequencelength=args.sequencelength,
                                country="ghana", use_s2_only=use_s2_only, average_pixel=average_pixel,
                                max_n_pixels=max_n_pixels)

    train_ds = torch.utils.data.ConcatDataset([train_ds, val_ds])

    test_ds = SustainbenchCrops(root=dataroot,partition="test", sequencelength=args.sequencelength,
                                country="ghana", use_s2_only=use_s2_only, average_pixel=average_pixel,
                                max_n_pixels=max_n_pixels)
    class_names = test_ds.classes
elif args.dataset in ["southsudan"]:
    use_s2_only = False
    dataroot = args.dataroot
    nclasses = 4
    args.sequencelength = 365
    input_dim = 12 if use_s2_only else 19 # 12 sentinel 2 + 3 x sentinel 1 + 4 * planet
    args.epochs = 500
    train_ds = SustainbenchCrops(root=dataroot,partition="train", sequencelength=args.sequencelength, country="southsudan", use_s2_only=use_s2_only)
    val_ds = SustainbenchCrops(root=dataroot,partition="val", sequencelength=args.sequencelength, country="southsudan", use_s2_only=use_s2_only)

    train_ds = torch.utils.data.ConcatDataset([train_ds, val_ds])
    test_ds = SustainbenchCrops(root=dataroot, partition="val", sequencelength=args.sequencelength,
                                country="southsudan", use_s2_only=use_s2_only)
    class_names = test_ds.classes

else:
    raise ValueError(f"dataset {args.dataset} not recognized")

traindataloader = DataLoader(
    train_ds,
    batch_size=args.batchsize)
testdataloader = DataLoader(
    test_ds,
    batch_size=args.batchsize)


get doys dict test
get doys dict test done
get test and validation data...
2559635960 2559635960


loading data into RAM: 100%|██████████| 178613/178613 [01:13<00:00, 2433.51it/s]


2253658856 2253658856


loading data into RAM: 100%|██████████| 140645/140645 [00:56<00:00, 2484.31it/s]


2493572704 2493572704


loading data into RAM: 100%|██████████| 166391/166391 [01:05<00:00, 2522.00it/s]


class names: ['barley' 'wheat' 'rapeseed' 'corn' 'sunflower' 'orchards' 'nuts'
 'permanent meadows' 'temporary meadows']


In [6]:
# ----------------------------- VISUALIZATION: label distribution -----------------------------
datasets = [train_ds, test_ds]
sets_labels = ["Train", "Validation"]
fig, ax = plt.subplots(figsize=(15, 7))
fig, ax = plot_label_distribution_datasets(datasets, sets_labels, fig, ax, title='Label distribution', labels_names=class_names)
wandb.log({"label_distribution": wandb.Image(fig)})
plt.close(fig)  


Extracting labels from dataset Train.
Extracting labels from dataset Validation.


In [69]:
# TempCNN
import os

import torch
import torch.nn as nn
import torch.utils.data

"""
Source: https://github.com/dl4sits/BreizhCrops/blob/master/breizhcrops/models/TempCNN.py 

Pytorch re-implementation of Pelletier et al. 2019
https://github.com/charlotte-pel/temporalCNN

https://www.mdpi.com/2072-4292/11/5/523

"""

__all__ = ['TempCNN']

class TempCNN(torch.nn.Module):
    def __init__(self, input_dim=13, num_classes=9, sequencelength=45, kernel_size=7, hidden_dims=128, dropout=0.18203942949809093):
        super(TempCNN, self).__init__()
        self.modelname = f"TempCNN_input-dim={input_dim}_num-classes={num_classes}_sequencelenght={sequencelength}_" \
                         f"kernelsize={kernel_size}_hidden-dims={hidden_dims}_dropout={dropout}"

        self.hidden_dims = hidden_dims

        self.conv_bn_relu1 = Conv1D_BatchNorm_Relu_Dropout(input_dim, hidden_dims, kernel_size=kernel_size,
                                                           drop_probability=dropout)
        self.conv_bn_relu2 = Conv1D_BatchNorm_Relu_Dropout(hidden_dims, hidden_dims, kernel_size=kernel_size,
                                                           drop_probability=dropout)
        self.conv_bn_relu3 = Conv1D_BatchNorm_Relu_Dropout(hidden_dims, hidden_dims, kernel_size=kernel_size,
                                                           drop_probability=dropout)
        self.flatten = Flatten()
        self.dense = FC_BatchNorm_Relu_Dropout(hidden_dims * sequencelength, 4 * hidden_dims, drop_probability=dropout)
        self.logsoftmax = nn.Sequential(nn.Linear(4 * hidden_dims, hidden_dims), nn.Tanh())

    def forward(self, x):
        # require NxTxD
        x = x.transpose(1,2)
        x = self.conv_bn_relu1(x)
        x = self.conv_bn_relu2(x)
        x = self.conv_bn_relu3(x)
        x = self.flatten(x)
        x = self.dense(x)
        return self.logsoftmax(x)

    def save(self, path="model.pth", **kwargs):
        print("\nsaving model to " + path)
        model_state = self.state_dict()
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(dict(model_state=model_state, **kwargs), path)

    def load(self, path):
        print("loading model from " + path)
        snapshot = torch.load(path, map_location="cpu")
        model_state = snapshot.pop('model_state', snapshot)
        self.load_state_dict(model_state)
        return snapshot


class Conv1D_BatchNorm_Relu_Dropout(torch.nn.Module):
    def __init__(self, input_dim, hidden_dims, kernel_size=5, drop_probability=0.5):
        super(Conv1D_BatchNorm_Relu_Dropout, self).__init__()

        self.block = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dims, kernel_size, padding=(kernel_size // 2)),
            nn.BatchNorm1d(hidden_dims),
            nn.ReLU(),
            nn.Dropout(p=drop_probability)
        )

    def forward(self, X):
        return self.block(X)


class FC_BatchNorm_Relu_Dropout(torch.nn.Module):
    def __init__(self, input_dim, hidden_dims, drop_probability=0.5):
        super(FC_BatchNorm_Relu_Dropout, self).__init__()

        self.block = nn.Sequential(
            nn.Linear(input_dim, hidden_dims),
            nn.BatchNorm1d(hidden_dims),
            nn.ReLU(),
            nn.Dropout(p=drop_probability)
        )

    def forward(self, X):
        return self.block(X)


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

In [70]:
# earlyrnn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import os
#from models.EarlyClassificationModel import EarlyClassificationModel
from torch.nn.modules.normalization import LayerNorm
import json
# from models.TempCNN import TempCNN


class EarlyRNN(nn.Module):
    def __init__(self, backbone_model:str="LSTM", input_dim:int=13, hidden_dims:int=64, nclasses:int=7, num_rnn_layers:int=2, dropout:float=0.2, sequencelength:int=70, kernel_size:int=7):
        super(EarlyRNN, self).__init__()

        # input transformations
        self.intransforms = nn.Sequential(
            nn.LayerNorm(input_dim), # normalization over D-dimension. T-dimension is untouched
            nn.Linear(input_dim, hidden_dims) # project to hidden_dims length
        )

        # backbone model 
        self.initialize_model(backbone_model, input_dim, hidden_dims, nclasses, num_rnn_layers, dropout, sequencelength, kernel_size)

        # Heads
        self.classification_head = ClassificationHead(hidden_dims, nclasses)
        self.stopping_decision_head = DecisionHead(hidden_dims)

    def forward(self, x):
        x = self.intransforms(x)
        # TO DO: CORRECT THE SHAPE OF THE OUTPUTS
        # outputs, last_state_list = self.backbone(x)
        output_tupple = self.backbone(x)
        if type(output_tupple) == tuple:
            outputs = output_tupple[0]
        else:
            outputs = output_tupple
        print("outputs shape:", outputs.shape)
        log_class_probabilities = self.classification_head(outputs)
        probabilitiy_stopping = self.stopping_decision_head(outputs)

        return log_class_probabilities, probabilitiy_stopping

    @torch.no_grad()
    def predict(self, x):
        logprobabilities, deltas = self.forward(x)

        def sample_stop_decision(delta):
            dist = torch.stack([1 - delta, delta], dim=1)
            return torch.distributions.Categorical(dist).sample().bool()

        batchsize, sequencelength, nclasses = logprobabilities.shape

        stop = list()
        for t in range(sequencelength):
            # stop if sampled true and not stopped previously
            if t < sequencelength - 1:
                stop_now = sample_stop_decision(deltas[:, t])
                stop.append(stop_now)
            else:
                # make sure to stop last
                last_stop = torch.ones(tuple(stop_now.shape)).bool()
                if torch.cuda.is_available():
                    last_stop = last_stop.cuda()
                stop.append(last_stop)

        # stack over the time dimension (multiple stops possible)
        stopped = torch.stack(stop, dim=1).bool()

        # is only true if stopped for the first time
        first_stops = (stopped.cumsum(1) == 1) & stopped

        # time of stopping
        t_stop = first_stops.long().argmax(1)

        # all predictions
        predictions = logprobabilities.argmax(-1)

        # predictions at time of stopping
        predictions_at_t_stop = torch.masked_select(predictions, first_stops)

        return logprobabilities, deltas, predictions_at_t_stop, t_stop
    
    def initialize_model(self, backbone_model, input_dim, hidden_dims, nclasses, num_rnn_layers, dropout, sequencelength, kernel_size):
        self.backbone = get_backbone_model(backbone_model, input_dim, hidden_dims, nclasses, num_rnn_layers, dropout, sequencelength, kernel_size)
        

def get_backbone_model(backbone_model, input_dim, hidden_dims, nclasses, num_rnn_layers, dropout, sequencelength, kernel_size):
    model_map = {
        "LSTM": {
            "class": nn.LSTM,
            "config": {
                "input_size": hidden_dims,
                "hidden_size": hidden_dims,
                "num_layers": num_rnn_layers,
                "bias": False,
                "batch_first": True,
                "dropout": dropout,
                "bidirectional": False
            }
        },
        "TempCNN": {
            "class": TempCNN,
            "config": {
                "input_dim": hidden_dims,
                "num_classes": nclasses,
                "sequencelength": sequencelength,
                "kernel_size": kernel_size,
                "hidden_dims": hidden_dims, 
                "dropout": dropout
            }
        }
    }
    
    if backbone_model in model_map:
        model_info = model_map[backbone_model]
        return model_info["class"](**model_info["config"])
    else:
        raise ValueError(f"Backbone model {backbone_model} is not implemented yet.")


class ClassificationHead(torch.nn.Module):

    def __init__(self, hidden_dims, nclasses):
        super(ClassificationHead, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(hidden_dims, nclasses, bias=True),
            nn.LogSoftmax(dim=2))

    def forward(self, x):
        return self.projection(x)

class DecisionHead(torch.nn.Module):

    def __init__(self, hidden_dims):
        super(DecisionHead, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(hidden_dims, 1, bias=True),
            nn.Sigmoid()
        )

        # initialize bias to predict late in first epochs
        torch.nn.init.normal_(self.projection[0].bias, mean=-2e1, std=1e-1)


    def forward(self, x):
        return self.projection(x).squeeze(2)

if __name__ == "__main__":
    model = EarlyRNN()


In [71]:
for i, (x, y) in enumerate(traindataloader):
    print("x shape:", x.shape)
    print("y shape:", y.shape)
    break

x shape: torch.Size([256, 70, 13])
y shape: torch.Size([256, 70])


In [72]:
# ----------------------------- SET UP MODEL -----------------------------
#model = EarlyRNN(nclasses=nclasses, input_dim=input_dim, sequencelength=args.sequencelength).to(args.device)
model = EarlyRNN("TempCNN", nclasses=nclasses, input_dim=input_dim, sequencelength=args.sequencelength, kernel_size=7).to(args.device)


#optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

# exclude decision head linear bias from weight decay
decay, no_decay = list(), list()
for name, param in model.named_parameters():
    if name == "stopping_decision_head.projection.0.bias":
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW([{'params': no_decay, 'weight_decay': 0, "lr": args.learning_rate}, {'params': decay}],
                                lr=args.learning_rate, weight_decay=args.weight_decay)

criterion = EarlyRewardLoss(alpha=args.alpha, epsilon=args.epsilon)

if args.resume and os.path.exists(args.snapshot):
    model.load_state_dict(torch.load(args.snapshot, map_location=args.device))
    optimizer_snapshot = os.path.join(os.path.dirname(args.snapshot),
                                        os.path.basename(args.snapshot).replace(".pth", "_optimizer.pth")
                                        )
    optimizer.load_state_dict(torch.load(optimizer_snapshot, map_location=args.device))
    df = pd.read_csv(args.snapshot + ".csv")
    train_stats = df.to_dict("records")
    start_epoch = train_stats[-1]["epoch"]
    print(f"resuming from {args.snapshot} epoch {start_epoch}")
else:
    train_stats = []
    start_epoch = 1

not_improved = 0


In [73]:
# ----------------------------- TRAINING -----------------------------
print("starting training...")
with tqdm(range(start_epoch, args.epochs + 1)) as pbar:
    for epoch in pbar:
        trainloss = train_epoch(model, traindataloader, optimizer, criterion, device=args.device)
        testloss, stats = test_epoch(model, testdataloader, criterion, args.device)

        # statistic logging and visualization...
        precision, recall, fscore, support = sklearn.metrics.precision_recall_fscore_support(
            y_pred=stats["predictions_at_t_stop"][:, 0], y_true=stats["targets"][:, 0], average="macro",
            zero_division=0)
        accuracy = sklearn.metrics.accuracy_score(
            y_pred=stats["predictions_at_t_stop"][:, 0], y_true=stats["targets"][:, 0])
        kappa = sklearn.metrics.cohen_kappa_score(
            stats["predictions_at_t_stop"][:, 0], stats["targets"][:, 0])

        classification_loss = stats["classification_loss"].mean()
        earliness_reward = stats["earliness_reward"].mean()
        earliness = 1 - (stats["t_stop"].mean() / (args.sequencelength - 1))
        harmonic_mean = harmonic_mean_score(accuracy, stats["classification_earliness"])

        # ----------------------------- LOGGING -----------------------------
        train_stats.append(
            dict(
                epoch=epoch,
                trainloss=trainloss,
                testloss=testloss,
                accuracy=accuracy,
                precision=precision,
                recall=recall,
                fscore=fscore,
                kappa=kappa,
                elects_earliness=earliness,
                classification_loss=classification_loss,
                earliness_reward=earliness_reward,
                classification_earliness=stats["classification_earliness"],
                harmonic_mean=harmonic_mean,
            )
        )
        fig_boxplot, ax_boxplot = plt.subplots(figsize=(15, 7))
        doys_dict = get_approximated_doys_dict(stats["seqlengths"], length_sorted_doy_dict_test)
        doys_stop = get_doy_stop(stats, doys_dict)
        fig_boxplot, _ = boxplot_stopping_times(doys_stop, stats, fig_boxplot, ax_boxplot, class_names)
        wandb.log({
                "loss": {"trainloss": trainloss, "testloss": testloss},
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "fscore": fscore,
                "kappa": kappa,
                "elects_earliness": earliness,
                "classification_loss": classification_loss,
                "earliness_reward": earliness_reward,
                "classification_earliness": stats["classification_earliness"],
                "harmonic_mean": harmonic_mean,
                "boxplot": wandb.Image(fig_boxplot),
                "conf_mat" : wandb.plot.confusion_matrix(probs=None,
                        y_true=stats["targets"][:,0], preds=stats["predictions_at_t_stop"][:,0],
                        class_names=class_names, title="Confusion Matrix")
            })
        plt.close(fig_boxplot)

        df = pd.DataFrame(train_stats).set_index("epoch")

        savemsg = ""
        if len(df) > 2:
            if testloss < df.testloss[:-1].values.min():
                savemsg = f"saving model to {args.snapshot}"
                os.makedirs(os.path.dirname(args.snapshot), exist_ok=True)
                torch.save(model.state_dict(), args.snapshot)

                optimizer_snapshot = os.path.join(os.path.dirname(args.snapshot),
                                                    os.path.basename(args.snapshot).replace(".pth", "_optimizer.pth")
                                                    )
                torch.save(optimizer.state_dict(), optimizer_snapshot)
                wandb.log_artifact(args.snapshot, type="model")  

                df.to_csv(args.snapshot + ".csv")
                not_improved = 0 # reset early stopping counter
            else:
                not_improved += 1 # increment early stopping counter
                if args.patience is not None:
                    savemsg = f"early stopping in {args.patience - not_improved} epochs."
                else:
                    savemsg = ""

        pbar.set_description(f"epoch {epoch}: trainloss {trainloss:.2f}, testloss {testloss:.2f}, "
                     f"accuracy {accuracy:.2f}, earliness {earliness:.2f}. "
                     f"classification loss {classification_loss:.2f}, earliness reward {earliness_reward:.2f}, harmonic mean {harmonic_mean:.2f}. {savemsg}")
        
            
        if args.patience is not None:
            if not_improved > args.patience:
                print(f"stopping training. testloss {testloss:.2f} did not improve in {args.patience} epochs.")
                break
    
# ----------------------------- SAVE FINAL MODEL -----------------------------
wandb.log_artifact(args.snapshot, type="model")

starting training...


  0%|          | 0/10 [00:00<?, ?it/s]

outputs shape: torch.Size([256, 64])


  0%|          | 0/10 [00:00<?, ?it/s]


IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

In [11]:
wandb.finish()

VBox(children=(Label(value='0.131 MB of 0.139 MB uploaded\r'), FloatProgress(value=0.9465552404430497, max=1.0…

0,1
accuracy,▁█
classification_earliness,█▁
classification_loss,█▁
earliness_reward,▁█
elects_earliness,▁█
fscore,▁█
harmonic_mean,▁█
kappa,▁█
precision,▁█
recall,▁█

0,1
accuracy,0.76197
classification_earliness,0.37094
classification_loss,7.63235
earliness_reward,3.43585
elects_earliness,0.65988
fscore,0.53408
harmonic_mean,0.68916
kappa,0.68952
precision,0.52835
recall,0.54136
