# Train Notebook - for developing code 
## test epoch with extra padding 
Test the new training and evaluation functions that support extra (left) padding of the input sequences.


In [1]:
import sys
import os 
os.environ['MPLCONFIGDIR'] = '/myhome'
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from data import BavarianCrops, BreizhCrops, SustainbenchCrops, ModisCDL
from torch.utils.data import DataLoader
from earlyrnn import EarlyRNN
import torch
from tqdm import tqdm
from loss import EarlyRewardLoss
import sklearn.metrics
import pandas as pd
import wandb
from utils.plots import plot_label_distribution_datasets, boxplot_stopping_times
from utils.doy import get_doys_dict_test, get_doy_stop, create_sorted_doys_dict_test, get_approximated_doys_dict
from utils.helpers_training import parse_args, train_epoch, test_epoch
from utils.metrics import harmonic_mean_score
import matplotlib.pyplot as plt
from models.model_helpers import count_parameters

In [2]:
import argparse

def parse_args(args=None):
    """Parse the ELECTS training options.

    Args:
        args: optional list of command-line tokens (e.g. ``["--epochs", "10"]``),
            which is the convenient form inside a notebook. When ``None``,
            argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option. A negative
        ``--patience`` is converted to ``None`` (early stopping disabled).
    """
    # Resolve the home directory defensively: HOME first, then USERPROFILE
    # (Windows), then whatever the platform reports. Joining with None would
    # raise a TypeError when neither variable is set.
    home_dir = os.environ.get("HOME") or os.environ.get("USERPROFILE") or os.path.expanduser("~")

    parser = argparse.ArgumentParser(description='Run ELECTS Early Classification training on the BavarianCrops dataset.')
    parser.add_argument('--backbonemodel', type=str, default="LSTM", choices=["LSTM", "TempCNN", "Transformer"], help="backbone model")
    parser.add_argument('--dataset', type=str, default="bavariancrops", choices=["bavariancrops","breizhcrops", "ghana", "southsudan","unitedstates"], help="dataset")
    parser.add_argument('--alpha', type=float, default=0.5, help="trade-off parameter of earliness and accuracy (eq 6): "
                                                                 "1=full weight on accuracy; 0=full weight on earliness")
    parser.add_argument('--epsilon', type=float, default=10, help="additive smoothing parameter that helps the "
                                                                  "model recover from too early classifications (eq 7)")
    parser.add_argument('--learning-rate', type=float, default=1e-3, help="Optimizer learning rate")
    parser.add_argument('--weight-decay', type=float, default=0, help="weight_decay")
    parser.add_argument('--patience', type=int, default=30, help="Early stopping patience")
    parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        choices=["cuda", "cpu"], help="'cuda' (GPU) or 'cpu' device to run the code. "
                                                     "defaults to 'cuda' if GPU is available, otherwise 'cpu'")
    parser.add_argument('--epochs', type=int, default=100, help="number of training epochs")
    parser.add_argument('--sequencelength', type=int, default=70, help="sequencelength of the time series. If samples are shorter, "
                                                                "they are zero-padded until this length; "
                                                                "if samples are longer, they will be undersampled")
    parser.add_argument('--hidden-dims', type=int, default=64, help="number of hidden dimensions in the backbone model")
    parser.add_argument('--batchsize', type=int, default=256, help="number of samples per batch")
    parser.add_argument('--dataroot', type=str, default=os.path.join(home_dir, "elects_data"), help="directory to download the "
                                                                                 "BavarianCrops dataset (400MB)."
                                                                                 "Defaults to home directory.")
    parser.add_argument('--snapshot', type=str, default="snapshots/model.pth",
                        help="pytorch state dict snapshot file")
    parser.add_argument('--resume', action='store_true')

    # parse_args(None) reads sys.argv[1:], so no separate branch is needed.
    args = parser.parse_args(args)

    # a negative patience disables early stopping
    if args.patience < 0:
        args.patience = None

    return args


In [3]:
# Notebook usage: hand parse_args an explicit token list instead of relying on sys.argv.
custom_args = [
    "--backbonemodel", "TempCNN",
    "--dataset", "breizhcrops",
    "--epochs", "10",
    "--hidden-dims", "16",
]
args = parse_args(args=custom_args)
print("cuda is available: ", args.device)
print(args)

cuda is available:  cuda
Namespace(backbonemodel='TempCNN', dataset='breizhcrops', alpha=0.5, epsilon=10, learning_rate=0.001, weight_decay=0, patience=30, device='cuda', epochs=10, sequencelength=70, hidden_dims=16, batchsize=256, dataroot='C:\\Users\\anyam\\elects_data', snapshot='snapshots/model.pth', resume=False)


In [4]:
# Candidate wandb log directory. It is kept for reference only: wandb.init below
# receives dir=None. Renamed from `dir` so the builtin dir() is no longer shadowed.
wandb_dir = "/mydata/studentanya/anya/wandb/"
wandb.init(
        dir=None,
        project="ELECTS",
        notes="first experimentations with ELECTS",
        tags=["ELECTS", args.dataset, "with_doys_boxplot", args.backbonemodel],
        # Log every parsed option (so the config cannot drift from parse_args)
        # plus the fixed descriptors of this experiment.
        config=dict(
            vars(args),
            architecture="EarlyRNN",
            optimizer="AdamW",
            criterion="EarlyRewardLoss",
        ),
    )


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: aurenore. Use `wandb login --relogin` to force relogin


In [5]:
# ----------------------------- LOAD DATASET -----------------------------

# Optional per-dataset metadata. Not every branch below provides these, so they
# are bound to None up front; later cells can then test for None instead of
# hitting a NameError.
class_names = None
class_weights = None
doys_dict_test = None
length_sorted_doy_dict_test = None

if args.dataset == "bavariancrops":
    dataroot = os.path.join(args.dataroot,"bavariancrops")
    nclasses = 7
    input_dim = 13
    train_ds = BavarianCrops(root=dataroot,partition="train", sequencelength=args.sequencelength)
    test_ds = BavarianCrops(root=dataroot,partition="valid", sequencelength=args.sequencelength)
    class_names = test_ds.classes
elif args.dataset == "unitedstates":
    # NOTE(review): this branch overrides the user-supplied dataroot and
    # sequencelength with fixed values and does not set class_names.
    args.dataroot = "/data/modiscdl/"
    args.sequencelength = 24
    dataroot = args.dataroot
    nclasses = 8
    input_dim = 1
    train_ds = ModisCDL(root=dataroot,partition="train", sequencelength=args.sequencelength)
    test_ds = ModisCDL(root=dataroot,partition="valid", sequencelength=args.sequencelength)
elif args.dataset == "breizhcrops":
    dataroot = os.path.join(args.dataroot,"breizhcrops")
    nclasses = 9
    input_dim = 13
    print("get doys dict test")
    # day-of-year lookup tables, consumed later by the stopping-time boxplot
    doys_dict_test = get_doys_dict_test(dataroot=dataroot)
    length_sorted_doy_dict_test = create_sorted_doys_dict_test(doys_dict_test)
    print("get doys dict test done")
    print("get train and validation data...")
    train_ds = BreizhCrops(root=dataroot,partition="train", sequencelength=args.sequencelength)
    test_ds = BreizhCrops(root=dataroot,partition="valid", sequencelength=args.sequencelength)
    class_names = test_ds.ds.classname
    print("class names:", class_names)
elif args.dataset in ["ghana"]:
    use_s2_only = False
    average_pixel = False
    max_n_pixels = 50
    dataroot = args.dataroot
    nclasses = 4
    input_dim = 12 if use_s2_only else 19  # 12 sentinel 2 + 3 x sentinel 1 + 4 * planet
    args.epochs = 500
    args.sequencelength = 365
    train_ds = SustainbenchCrops(root=dataroot,partition="train", sequencelength=args.sequencelength,
                                    country="ghana",
                                    use_s2_only=use_s2_only, average_pixel=average_pixel,
                                    max_n_pixels=max_n_pixels)
    val_ds = SustainbenchCrops(root=dataroot,partition="val", sequencelength=args.sequencelength,
                                country="ghana", use_s2_only=use_s2_only, average_pixel=average_pixel,
                                max_n_pixels=max_n_pixels)

    # train on train + val, evaluate on the held-out test partition
    train_ds = torch.utils.data.ConcatDataset([train_ds, val_ds])

    test_ds = SustainbenchCrops(root=dataroot,partition="test", sequencelength=args.sequencelength,
                                country="ghana", use_s2_only=use_s2_only, average_pixel=average_pixel,
                                max_n_pixels=max_n_pixels)
    class_names = test_ds.classes
elif args.dataset in ["southsudan"]:
    use_s2_only = False
    dataroot = args.dataroot
    nclasses = 4
    args.sequencelength = 365
    input_dim = 12 if use_s2_only else 19 # 12 sentinel 2 + 3 x sentinel 1 + 4 * planet
    args.epochs = 500
    train_ds = SustainbenchCrops(root=dataroot,partition="train", sequencelength=args.sequencelength, country="southsudan", use_s2_only=use_s2_only)
    val_ds = SustainbenchCrops(root=dataroot,partition="val", sequencelength=args.sequencelength, country="southsudan", use_s2_only=use_s2_only)

    train_ds = torch.utils.data.ConcatDataset([train_ds, val_ds])
    # NOTE(review): the "val" partition is both merged into the training set above
    # and used for evaluation here (the ghana branch evaluates on "test").
    # Confirm whether this overlap is intended.
    test_ds = SustainbenchCrops(root=dataroot, partition="val", sequencelength=args.sequencelength,
                                country="southsudan", use_s2_only=use_s2_only)
    class_names = test_ds.classes

else:
    raise ValueError(f"dataset {args.dataset} not recognized")

# NOTE(review): the training loader is created without shuffle=True, so batches
# are drawn in dataset order every epoch. Confirm this is intended.
traindataloader = DataLoader(
    train_ds,
    batch_size=args.batchsize)
testdataloader = DataLoader(
    test_ds,
    batch_size=args.batchsize)


get doys dict test
get doys dict test done
get train and validation data...
2559635960 2559635960


loading data into RAM: 100%|██████████| 178613/178613 [01:11<00:00, 2487.49it/s]


2253658856 2253658856


loading data into RAM: 100%|██████████| 140645/140645 [00:57<00:00, 2436.71it/s]


2493572704 2493572704


loading data into RAM: 100%|██████████| 166391/166391 [01:03<00:00, 2611.02it/s]


class names: ['barley' 'wheat' 'rapeseed' 'corn' 'sunflower' 'orchards' 'nuts'
 'permanent meadows' 'temporary meadows']


In [6]:
# ----------------------------- VISUALIZATION: label distribution -----------------------------
# datasets = [train_ds, test_ds]
# sets_labels = ["Train", "Validation"]
# fig, ax = plt.subplots(figsize=(15, 7))
# fig, ax = plot_label_distribution_datasets(datasets, sets_labels, fig, ax, title='Label distribution', labels_names=class_names)
# wandb.log({"label_distribution": wandb.Image(fig)})
# plt.close(fig)  


In [7]:
# Visible prefix lengths are the multiples of the step that fit into the full
# sequence; the extra padding for each one is the number of timesteps it hides.
step_timestamp_padding = 25
sequence_lengths_train = list(range(step_timestamp_padding, args.sequencelength + 1, step_timestamp_padding))
print("sequence_lengths_train", sequence_lengths_train)
extra_padding_list = [args.sequencelength - visible_length for visible_length in sequence_lengths_train]
print("extra_padding_list", extra_padding_list)

sequence_lengths_train [25, 50]
extra_padding_list [45, 20]


In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import os
#from models.EarlyClassificationModel import EarlyClassificationModel
from torch.nn.modules.normalization import LayerNorm
from models.TempCNN import TempCNN


class EarlyRNN(nn.Module):
    """Early-classification network: padding -> input projection -> backbone -> two heads.

    For every timestep the model emits class log-probabilities and a
    probability of stopping (i.e. of committing to the prediction at that step).
    """

    def __init__(self, backbone_model:str="LSTM", input_dim:int=13, hidden_dims:int=64, nclasses:int=7, num_rnn_layers:int=2, dropout:float=0.2, sequencelength:int=70, kernel_size:int=7):
        """Build the layers.

        Args:
            backbone_model: name of the backbone, resolved by ``get_backbone_model``.
            input_dim: number of input features per timestep.
            hidden_dims: width of the projected inputs and of the backbone output.
            nclasses: number of target classes.
            num_rnn_layers, dropout, sequencelength, kernel_size: forwarded to the
                backbone factory.
        """
        super(EarlyRNN, self).__init__()
        # padding layer, for certain models that require padding
        self.padding = Padding()

        # input transformations
        self.intransforms = nn.Sequential(
            nn.LayerNorm(input_dim), # normalization over D-dimension. T-dimension is untouched
            nn.Linear(input_dim, hidden_dims) # project to hidden_dims length
        )

        # backbone model
        self.initialize_model(backbone_model, input_dim, hidden_dims, nclasses, num_rnn_layers, dropout, sequencelength, kernel_size)

        # Heads
        self.classification_head = ClassificationHead(hidden_dims, nclasses)
        self.stopping_decision_head = DecisionHead(hidden_dims)

    def forward(self, x, **kwargs):
        """Return ``(log_class_probabilities, probability_stopping)`` for input ``x``.

        ``kwargs`` are passed to the padding layer (it reads ``extra_padding``).
        """
        x = self.padding(x, **kwargs)
        x = self.intransforms(x)
        backbone_output = self.backbone(x)
        # recurrent backbones return (outputs, state); keep only the per-step outputs
        if isinstance(backbone_output, tuple):
            outputs = backbone_output[0]
        else:
            outputs = backbone_output # shape : (batch_size, sequencelength, hidden_dims)
        log_class_probabilities = self.classification_head(outputs)
        probabilitiy_stopping = self.stopping_decision_head(outputs)

        return log_class_probabilities, probabilitiy_stopping

    @torch.no_grad()
    def predict(self, x, **kwargs):
        """Sample stopping decisions and report the prediction at the first stop.

        Returns:
            ``(logprobabilities, deltas, predictions_at_t_stop, t_stop)`` where
            ``t_stop`` is the index of the first sampled stop, shifted back by
            ``extra_padding`` when that keyword is given (so it can be negative
            if the stop falls inside the padded region).
        """
        logprobabilities, deltas = self.forward(x, **kwargs)

        def sample_stop_decision(delta):
            # Bernoulli draw per sample: stop with probability delta
            dist = torch.stack([1 - delta, delta], dim=1)
            return torch.distributions.Categorical(dist).sample().bool()

        batchsize, sequencelength, nclasses = logprobabilities.shape

        stop = list()
        for t in range(sequencelength):
            if t < sequencelength - 1:
                # stop decision is true with probability delta
                stop.append(sample_stop_decision(deltas[:, t]))
            else:
                # Always stop at the last step. The tensor is created on the
                # device of the model outputs (not on CUDA merely because a GPU
                # exists) and does not depend on an earlier sample, so a
                # sequence of length 1 works as well.
                stop.append(torch.ones(batchsize, dtype=torch.bool, device=deltas.device))

        # stack over the time dimension (multiple stops possible)
        stopped = torch.stack(stop, dim=1).bool() # has a true if the model decides to stop at time t

        # is only true if stopped for the first time
        first_stops = (stopped.cumsum(1) == 1) & stopped

        # time of stopping
        t_stop = first_stops.long().argmax(1) # get the index of the first stop

        # if there is an extra padding, the stopping time is shifted:
        t_stop = t_stop - kwargs.get("extra_padding", 0)

        # all predictions
        predictions = logprobabilities.argmax(-1)

        # predictions at time of stopping
        predictions_at_t_stop = torch.masked_select(predictions, first_stops)

        return logprobabilities, deltas, predictions_at_t_stop, t_stop

    def initialize_model(self, backbone_model, input_dim, hidden_dims, nclasses, num_rnn_layers, dropout, sequencelength, kernel_size):
        """Create ``self.backbone`` through the ``get_backbone_model`` factory."""
        self.backbone = get_backbone_model(backbone_model, input_dim, hidden_dims, nclasses, num_rnn_layers, dropout, sequencelength, kernel_size)
        

def get_backbone_model(backbone_model, input_dim, hidden_dims, nclasses, num_rnn_layers, dropout, sequencelength, kernel_size):
    """Instantiate the backbone network selected by name.

    Supported names are "LSTM" and "TempCNN"; both consume inputs that were
    already projected to ``hidden_dims`` features. Any other name raises a
    ValueError.
    """
    # Each entry pairs a constructor with the keyword arguments it is called with.
    lstm_spec = (nn.LSTM, dict(
        input_size=hidden_dims,
        hidden_size=hidden_dims,
        num_layers=num_rnn_layers,
        bias=False,
        batch_first=True,
        dropout=dropout,
        bidirectional=False,
    ))
    tempcnn_spec = (TempCNN, dict(
        input_dim=hidden_dims,
        sequencelength=sequencelength,
        kernel_size=kernel_size,
        hidden_dims=hidden_dims,
        dropout=dropout,
    ))
    registry = {"LSTM": lstm_spec, "TempCNN": tempcnn_spec}

    if backbone_model not in registry:
        raise ValueError(f"Backbone model {backbone_model} is not implemented yet.")

    constructor, constructor_kwargs = registry[backbone_model]
    return constructor(**constructor_kwargs)


class Padding(nn.Module):
    """Shift a batch of sequences to the right by ``extra_padding`` zero timesteps.

    The zeros are inserted at the start of dimension 1 and the same number of
    timesteps is dropped from the end, so the output has the same shape as the
    input and only the first ``T - extra_padding`` original timesteps remain.
    """

    def __init__(self):
        super(Padding, self).__init__()

    def forward(self, x, **kwargs):
        """Return ``x`` shifted by ``kwargs["extra_padding"]`` (default 0, i.e. unchanged)."""
        extra_padding = kwargs.get("extra_padding", 0)
        if extra_padding <= 0:
            return x
        # new_zeros inherits dtype and device from x, so half/double/integer
        # inputs are not silently promoted by the concatenation.
        left_pad = x.new_zeros(x.shape[0], extra_padding, x.shape[2])
        # add the padding on the left, then cut the same amount from the right
        shifted = torch.cat([left_pad, x], dim=1)
        return shifted[:, :-extra_padding, :]
        

class ClassificationHead(torch.nn.Module):
    """Maps backbone features to class log-probabilities (log-softmax over dim 2)."""

    def __init__(self, hidden_dims, nclasses):
        super().__init__()
        to_logits = nn.Linear(hidden_dims, nclasses, bias=True)
        to_log_probabilities = nn.LogSoftmax(dim=2)
        # kept as a Sequential named `projection` so parameter names stay
        # "projection.0.weight" / "projection.0.bias"
        self.projection = nn.Sequential(to_logits, to_log_probabilities)

    def forward(self, x):
        log_probabilities = self.projection(x)
        return log_probabilities

class DecisionHead(torch.nn.Module):
    """Maps backbone features to one stopping probability per timestep (sigmoid output)."""

    def __init__(self, hidden_dims):
        super().__init__()
        linear = nn.Linear(hidden_dims, 1, bias=True)
        self.projection = nn.Sequential(linear, nn.Sigmoid())

        # A strongly negative bias keeps the stopping probability near zero at
        # the start of training, so the model predicts late in the first epochs.
        torch.nn.init.normal_(linear.bias, mean=-2e1, std=1e-1)

    def forward(self, x):
        stop_probability = self.projection(x)
        # drop the trailing singleton feature axis
        return stop_probability.squeeze(2)


# Smoke test carried over from the module version of this code: builds an
# EarlyRNN with all default arguments.
# NOTE(review): in a notebook kernel __name__ is normally "__main__", so this
# presumably runs every time the cell executes; the `model` it creates is
# replaced by the model set-up cell that follows.
if __name__ == "__main__":
    model = EarlyRNN()

In [26]:
# ----------------------------- SET UP MODEL -----------------------------
model = EarlyRNN(args.backbonemodel, nclasses=nclasses, input_dim=input_dim, sequencelength=args.sequencelength, kernel_size=7, hidden_dims=args.hidden_dims).to(args.device)
print(f"The model has {count_parameters(model):,} trainable parameters.")

# exclude decision head linear bias from weight decay
# (the parameter-group order below is part of the saved optimizer state, keep it stable)
decay, no_decay = list(), list()
for name, param in model.named_parameters():
    if name == "stopping_decision_head.projection.0.bias":
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW([{'params': no_decay, 'weight_decay': 0, "lr": args.learning_rate}, {'params': decay}],
                                lr=args.learning_rate, weight_decay=args.weight_decay)

criterion = EarlyRewardLoss(alpha=args.alpha, epsilon=args.epsilon)

if args.resume and os.path.exists(args.snapshot):
    model.load_state_dict(torch.load(args.snapshot, map_location=args.device))
    optimizer_snapshot = os.path.join(os.path.dirname(args.snapshot),
                                        os.path.basename(args.snapshot).replace(".pth", "_optimizer.pth")
                                        )
    optimizer.load_state_dict(torch.load(optimizer_snapshot, map_location=args.device))
    df = pd.read_csv(args.snapshot + ".csv")
    train_stats = df.to_dict("records")
    # The snapshot and its CSV are written after the last logged epoch has been
    # trained, so training continues with the following epoch. Restarting at the
    # logged epoch would retrain it and append a duplicate record.
    start_epoch = train_stats[-1]["epoch"] + 1
    print(f"resuming from {args.snapshot} epoch {start_epoch}")
else:
    train_stats = []
    start_epoch = 1

# early-stopping counter: epochs since the test loss last improved
not_improved = 0


The model has 8,196 trainable parameters.


In [27]:
# helpers_training.py

import numpy as np 

def train_epoch(model, dataloader, optimizer, criterion, device, extra_padding_list=None):
    """Train ``model`` for one pass over ``dataloader``.

    Every batch is used for one optimisation step per entry of
    ``extra_padding_list``; the entry is forwarded to the model as the
    ``extra_padding`` keyword.

    Args:
        model: module called as ``model(X, extra_padding=...)`` and returning
            ``(log_class_probabilities, probability_stopping)``.
        dataloader: iterable of ``(X, y_true)`` batches.
        optimizer: optimizer over the model parameters.
        criterion: callable ``criterion(log_class_probabilities, probability_stopping, y_true)``
            returning the loss tensor.
        device: device the batches are moved to.
        extra_padding_list: paddings to train with; defaults to ``[0]``.

    Returns:
        Mean of the finite step losses, or NaN when no step produced a
        finite loss (empty dataloader, or every loss was NaN).
    """
    if extra_padding_list is None:
        extra_padding_list = [0]

    losses = []
    model.train()
    for X, y_true in dataloader:
        # move the batch once; it is reused for every padding
        X, y_true = X.to(device), y_true.to(device)
        for extra_padding in extra_padding_list:
            optimizer.zero_grad()
            log_class_probabilities, probability_stopping = model(X, extra_padding=extra_padding)

            loss = criterion(log_class_probabilities, probability_stopping, y_true)

            # a NaN loss would corrupt the weights: skip the step
            if loss.isnan().any():
                continue

            loss.backward()
            optimizer.step()
            losses.append(loss.cpu().detach().numpy())

    if not losses:
        # np.stack raises on an empty list; report "no valid loss" instead
        return float("nan")
    return np.stack(losses).mean()



In [28]:
# Run a single training epoch, taking one optimisation step per batch for every padding in extra_padding_list.
trainloss = train_epoch(model, traindataloader, optimizer, criterion, device=args.device, extra_padding_list=extra_padding_list)


In [30]:
def test_epoch(model, dataloader, criterion, device, extra_padding_list=[0]):
    model.eval()

    stats = []
    losses = []
    slengths = []

    # sort the padding in descending order
    extra_padding_list = sorted(extra_padding_list, reverse=True)

    for ids, batch in enumerate(dataloader):
        X, y_true = batch
        X, y_true = X.to(device), y_true.to(device)

        seqlengths = (X[:,:,0] != 0).sum(1)
        slengths.append(seqlengths.cpu().detach())

        unpredicted_seq_mask = torch.ones(X.shape[0], dtype=bool).to(device) # mask for sequences that are not predicted yet
        # by default, we predict the sequence with the smallest padding
        extra_padding = extra_padding_list[-1]
        dict_padding = {"extra_padding": extra_padding}
        log_class_probabilities, probability_stopping, predictions_at_t_stop, t_stop = model.predict(X, **dict_padding)
        loss, stat = criterion(log_class_probabilities, probability_stopping, y_true, return_stats=True)
            
        if len(extra_padding_list) > 1:
            i=0 # index for the extra_padding_list
            while unpredicted_seq_mask.any() and i < len(extra_padding_list)-1:
                extra_padding = extra_padding_list[i]
                dict_padding = {"extra_padding": extra_padding}
                log_class_probabilities_temp, probability_stopping_temp, predictions_at_t_stop_temp, t_stop_temp = model.predict(X, **dict_padding)
                loss_temp, stat_temp = criterion(log_class_probabilities, probability_stopping, y_true, return_stats=True)
                # update the mask if t_stop is different from the length of the sequence (i.e. the sequence is predicted before its end)
                unpredicted_seq_mask = unpredicted_seq_mask*(t_stop >= seqlengths-extra_padding)
            
                # update the metrics data with the mask of predicted sequences
                log_class_probabilities = torch.where(~unpredicted_seq_mask.unsqueeze(1).unsqueeze(-1), log_class_probabilities_temp, log_class_probabilities)
                probability_stopping = torch.where(~unpredicted_seq_mask.unsqueeze(1), probability_stopping_temp, probability_stopping)
                predictions_at_t_stop = torch.where(~unpredicted_seq_mask, predictions_at_t_stop_temp, predictions_at_t_stop)
                t_stop = torch.where(~unpredicted_seq_mask, t_stop_temp, t_stop)
                loss = torch.where(~unpredicted_seq_mask, loss_temp, loss)
                stat = {k: np.where(~np.expand_dims(unpredicted_seq_mask.cpu().numpy(), axis=1), v, stat[k]) for k, v in stat_temp.items()}
                i+=1
            # mean: so not exact measure, as it takes into account predictions that are not kept at the end. Only indicative. 
            stat["classification_loss"] = stat["classification_loss"].mean() 
            stat["earliness_reward"] = stat["earliness_reward"].mean()
            loss = loss.mean()
        stat["loss"] = loss.cpu().detach().numpy()
        stat["probability_stopping"] = probability_stopping.cpu().detach().numpy()
        stat["class_probabilities"] = log_class_probabilities.exp().cpu().detach().numpy()
        stat["predictions_at_t_stop"] = predictions_at_t_stop.unsqueeze(-1).cpu().detach().numpy()
        stat["t_stop"] = t_stop.unsqueeze(-1).cpu().detach().numpy()
        stat["targets"] = y_true.cpu().detach().numpy()
        stat["ids"] = ids

        stats.append(stat)
        losses.append(loss.mean().cpu().detach().numpy())

    # list of dicts to dict of lists
    stats = {k: np.vstack([dic[k] for dic in stats]) for k in stats[0]}
    stats["seqlengths"] = torch.cat(slengths).numpy()
    stats["classification_earliness"] = np.mean(stats["t_stop"].flatten()/stats["seqlengths"])

    return np.stack(losses).mean(), stats

In [16]:
# Evaluate once on the validation loader with the same padding schedule that is used for training.
testloss, stats = test_epoch(model, testdataloader, criterion, args.device, extra_padding_list=extra_padding_list)


In [32]:
def test_epoch_old(model, dataloader, criterion, device):
    model.eval()

    stats = []
    losses = []
    slengths = []
    for ids, batch in enumerate(dataloader):
        X, y_true = batch
        X, y_true = X.to(device), y_true.to(device)

        seqlengths = (X[:,:,0] != 0).sum(1)
        slengths.append(seqlengths.cpu().detach())

        log_class_probabilities, probability_stopping, predictions_at_t_stop, t_stop = model.predict(X)
        loss, stat = criterion(log_class_probabilities, probability_stopping, y_true, return_stats=True)

        stat["loss"] = loss.cpu().detach().numpy()
        stat["probability_stopping"] = probability_stopping.cpu().detach().numpy()
        stat["class_probabilities"] = log_class_probabilities.exp().cpu().detach().numpy()
        stat["predictions_at_t_stop"] = predictions_at_t_stop.unsqueeze(-1).cpu().detach().numpy()
        stat["t_stop"] = t_stop.unsqueeze(-1).cpu().detach().numpy()
        stat["targets"] = y_true.cpu().detach().numpy()
        stat["ids"] = ids

        stats.append(stat)

        losses.append(loss.cpu().detach().numpy())

    # list of dicts to dict of lists
    stats = {k: np.vstack([dic[k] for dic in stats]) for k in stats[0]}
    stats["seqlengths"] = torch.cat(slengths).numpy()
    stats["classification_earliness"] = np.mean(stats["t_stop"].flatten()/stats["seqlengths"])

    return np.stack(losses).mean(), stats


In [53]:
# measure the time of the test_epoch function
import time

seed = 0


def _timed_eval(label, evaluate):
    """Reseed the RNGs, run ``evaluate()``, print the elapsed seconds and return its result.

    model.predict samples its stop decisions, so every run has to start from
    the same RNG state; seeding only once would give each run different random
    draws and make the results incomparable.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    start = time.time()
    result = evaluate()
    end = time.time()
    print(label, end-start)
    return result


testloss, stats = _timed_eval(
    "time test_epoch",
    lambda: test_epoch(model, testdataloader, criterion, args.device, extra_padding_list=[0]))
testloss_old, stats_old = _timed_eval(
    "time test_epoch_old",
    lambda: test_epoch_old(model, testdataloader, criterion, args.device))
testloss2, stats_old2 = _timed_eval(
    "time test_epoch_old2",
    lambda: test_epoch_old(model, testdataloader, criterion, args.device))


time test_epoch 63.539100646972656
time test_epoch_old 61.57906174659729
time test_epoch_old2 61.03752112388611


In [57]:
# With extra_padding_list=[0] the new test_epoch should reproduce test_epoch_old.
# NOTE(review): model.predict samples its stop decisions, so the losses can only
# match exactly when every run starts from the same RNG state.
print("testloss", testloss)
print("testloss_old", testloss_old)
print("testloss2", testloss2)
print("difference:", np.abs(testloss-testloss_old))
print("difference2:", np.abs(testloss_old-testloss2))


testloss 4.576465
testloss_old 4.574134
testloss2 4.577822
difference: 0.0023312569
difference2: 0.0036883354


In [56]:
# Statistic to compare between the new and the old evaluation; change `key` to inspect another entry of stats.
key = "targets"
# check that the new test_epoch and both runs of test_epoch_old agree on stats[key]
print("allclose:", np.allclose(stats[key], stats_old[key]))
print("allclose2:", np.allclose(stats[key], stats_old2[key]))


allclose: True
allclose2: True


In [29]:
# ----------------------------- TRAINING -----------------------------
print("starting training...")
with tqdm(range(start_epoch, args.epochs + 1)) as pbar:
    for epoch in pbar:
        trainloss = train_epoch(model, traindataloader, optimizer, criterion, device=args.device, extra_padding_list=extra_padding_list)
        print("finished training for epoch ", epoch)
        # test_epoch as defined in this notebook has no stop-threshold parameter,
        # so only the arguments it accepts are passed.
        testloss, stats = test_epoch(model, testdataloader, criterion, args.device, extra_padding_list=extra_padding_list)

        # statistic logging and visualization...
        precision, recall, fscore, support = sklearn.metrics.precision_recall_fscore_support(
            y_pred=stats["predictions_at_t_stop"][:, 0], y_true=stats["targets"][:, 0], average="macro",
            zero_division=0)
        accuracy = sklearn.metrics.accuracy_score(
            y_pred=stats["predictions_at_t_stop"][:, 0], y_true=stats["targets"][:, 0])
        kappa = sklearn.metrics.cohen_kappa_score(
            stats["predictions_at_t_stop"][:, 0], stats["targets"][:, 0])

        classification_loss = stats["classification_loss"].mean()
        earliness_reward = stats["earliness_reward"].mean()
        # earliness relative to the fixed sequence length (1 = stop at the first step)
        earliness = 1 - (stats["t_stop"].mean() / (args.sequencelength - 1))
        harmonic_mean = harmonic_mean_score(accuracy, stats["classification_earliness"])

        # ----------------------------- LOGGING -----------------------------
        train_stats.append(
            dict(
                epoch=epoch,
                trainloss=trainloss,
                testloss=testloss,
                accuracy=accuracy,
                precision=precision,
                recall=recall,
                fscore=fscore,
                kappa=kappa,
                elects_earliness=earliness,
                classification_loss=classification_loss,
                earliness_reward=earliness_reward,
                classification_earliness=stats["classification_earliness"],
                harmonic_mean=harmonic_mean,
            )
        )
        # NOTE(review): length_sorted_doy_dict_test is only created by the
        # breizhcrops branch of the dataset cell; the boxplot needs it.
        fig_boxplot, ax_boxplot = plt.subplots(figsize=(15, 7))
        doys_dict = get_approximated_doys_dict(stats["seqlengths"], length_sorted_doy_dict_test)
        doys_stop = get_doy_stop(stats, doys_dict)
        fig_boxplot, _ = boxplot_stopping_times(doys_stop, stats, fig_boxplot, ax_boxplot, class_names)
        wandb.log({
                "loss": {"trainloss": trainloss, "testloss": testloss},
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "fscore": fscore,
                "kappa": kappa,
                "elects_earliness": earliness,
                "classification_loss": classification_loss,
                "earliness_reward": earliness_reward,
                "classification_earliness": stats["classification_earliness"],
                "harmonic_mean": harmonic_mean,
                "boxplot": wandb.Image(fig_boxplot),
                "conf_mat" : wandb.plot.confusion_matrix(probs=None,
                        y_true=stats["targets"][:,0], preds=stats["predictions_at_t_stop"][:,0],
                        class_names=class_names, title="Confusion Matrix")
            })
        # close the figure so figures do not accumulate over the epochs
        plt.close(fig_boxplot)

        df = pd.DataFrame(train_stats).set_index("epoch")

        savemsg = ""
        if len(df) > 2:
            # save only when the current test loss beats every previous epoch
            if testloss < df.testloss[:-1].values.min():
                savemsg = f"saving model to {args.snapshot}"
                snapshot_dir = os.path.dirname(args.snapshot)
                if snapshot_dir:
                    # os.makedirs("") raises, which happens for a bare file name
                    os.makedirs(snapshot_dir, exist_ok=True)
                torch.save(model.state_dict(), args.snapshot)

                optimizer_snapshot = os.path.join(snapshot_dir,
                                                    os.path.basename(args.snapshot).replace(".pth", "_optimizer.pth")
                                                    )
                torch.save(optimizer.state_dict(), optimizer_snapshot)
                wandb.log_artifact(args.snapshot, type="model")

                df.to_csv(args.snapshot + ".csv")
                not_improved = 0 # reset early stopping counter
            else:
                not_improved += 1 # increment early stopping counter
                if args.patience is not None:
                    savemsg = f"early stopping in {args.patience - not_improved} epochs."
                else:
                    savemsg = ""

        pbar.set_description(f"epoch {epoch}: trainloss {trainloss:.2f}, testloss {testloss:.2f}, "
                     f"accuracy {accuracy:.2f}, earliness {earliness:.2f}. "
                     f"classification loss {classification_loss:.2f}, earliness reward {earliness_reward:.2f}, harmonic mean {harmonic_mean:.2f}. {savemsg}")

        if args.patience is not None:
            if not_improved > args.patience:
                print(f"stopping training. testloss {testloss:.2f} did not improve in {args.patience} epochs.")
                break

# ----------------------------- SAVE FINAL MODEL -----------------------------
# A snapshot is only written once the test loss has improved after the first two
# epochs; skip the upload when no snapshot file exists.
if os.path.exists(args.snapshot):
    wandb.log_artifact(args.snapshot, type="model")

starting training...


  0%|          | 0/10 [00:00<?, ?it/s]

finished training for epoch  1


epoch 1: trainloss 5.31, testloss 6.09, accuracy 0.75, earliness 0.75. classification loss 15.43, earliness reward 3.26, harmonic mean 0.74. :  10%|█         | 1/10 [03:53<34:57, 233.07s/it]


KeyboardInterrupt: 

In [None]:
# Close the wandb run that was opened with wandb.init at the top of the notebook.
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

wandb: ERROR Control-C detected -- Run data was not synced
