# Train Notebook - for developing code 


In [1]:
import sys
import os 
os.environ['MPLCONFIGDIR'] = '/myhome'
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from data import BavarianCrops, BreizhCrops, SustainbenchCrops, ModisCDL
from torch.utils.data import DataLoader
from earlyrnn import EarlyRNN
import torch
from tqdm import tqdm
from loss import EarlyRewardLoss
import sklearn.metrics
import pandas as pd
import wandb
from utils.plots import plot_label_distribution_datasets, boxplot_stopping_times
from utils.doy import get_doys_dict_test, get_doy_stop, create_sorted_doys_dict_test, get_approximated_doys_dict
from utils.helpers_training import parse_args, train_epoch
from utils.helpers_testing import test_epoch
from utils.metrics import harmonic_mean_score
import matplotlib.pyplot as plt
from models.model_helpers import count_parameters

In [2]:
import argparse

def parse_args(args=None):
    """Parse the ELECTS training options.

    Args:
        args: optional list of command-line tokens, e.g. ``"--epochs 10".split()``.
            When ``None``, argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option. ``patience`` is replaced by
        ``None`` when a negative value is given, and ``extra_padding_list`` is
        flattened into a single list of ints.
    """
    def int_list(value):
        # Split the string on whitespace and convert each token to an integer.
        return [int(token) for token in value.split()]

    def str_to_bool(value):
        # ``type=bool`` would turn every non-empty string (including "False") into True,
        # so boolean-looking strings are parsed explicitly instead.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    # HOME first, then USERPROFILE (as before); expanduser is a last resort so that
    # os.path.join never receives None when neither variable is set.
    home_dir = os.environ.get("HOME") or os.environ.get("USERPROFILE") or os.path.expanduser("~")

    parser = argparse.ArgumentParser(description='Run ELECTS Early Classification training on the BavarianCrops dataset.')
    parser.add_argument('--backbonemodel', type=str, default="LSTM", choices=["LSTM", "TempCNN", "Transformer"], help="backbone model")
    parser.add_argument('--dataset', type=str, default="bavariancrops", choices=["bavariancrops","breizhcrops", "ghana", "southsudan","unitedstates"], help="dataset")
    parser.add_argument('--alpha', type=float, default=0.5, help="trade-off parameter of earliness and accuracy (eq 6): "
                                                                 "1=full weight on accuracy; 0=full weight on earliness")
    parser.add_argument('--epsilon', type=float, default=10, help="additive smoothing parameter that helps the "
                                                                  "model recover from too early classifications (eq 7)")
    parser.add_argument('--learning-rate', type=float, default=1e-3, help="Optimizer learning rate")
    parser.add_argument('--weight-decay', type=float, default=0, help="weight_decay")
    parser.add_argument('--patience', type=int, default=30, help="Early stopping patience")
    parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        choices=["cuda", "cpu"], help="'cuda' (GPU) or 'cpu' device to run the code. "
                                                     "defaults to 'cuda' if GPU is available, otherwise 'cpu'")
    parser.add_argument('--epochs', type=int, default=100, help="number of training epochs")
    parser.add_argument('--sequencelength', type=int, default=70, help="sequencelength of the time series. If samples are shorter, "
                                                                "they are zero-padded until this length; "
                                                                "if samples are longer, they will be undersampled")
    # Each token is parsed by int_list, giving a list of lists that is flattened below.
    parser.add_argument('--extra-padding-list', type=int_list, default=[[0]], nargs='+', help="extra padding for the TempCNN model")
    parser.add_argument('--hidden-dims', type=int, default=64, help="number of hidden dimensions in the backbone model")
    parser.add_argument('--batchsize', type=int, default=256, help="number of samples per batch")
    parser.add_argument('--dataroot', type=str, default=os.path.join(home_dir, "elects_data"), help="directory to download the "
                                                                                 "BavarianCrops dataset (400MB). "
                                                                                 "Defaults to home directory.")
    parser.add_argument('--snapshot', type=str, default="snapshots/model.pth",
                        help="pytorch state dict snapshot file")
    # Accepts "--left_padding", "--left_padding True" and "--left_padding False".
    parser.add_argument('--left_padding', type=str_to_bool, nargs='?', const=True, default=False,
                        help="left padding for the TempCNN model")
    parser.add_argument('--resume', action='store_true')

    # parse_args(None) makes argparse read sys.argv[1:], so one call covers both cases.
    parsed = parser.parse_args(args)

    # A negative patience disables early stopping.
    if parsed.patience < 0:
        parsed.patience = None
    parsed.extra_padding_list = [item for sublist in parsed.extra_padding_list for item in sublist]

    return parsed


In [3]:
# Example of how to call parse_args with custom arguments in a notebook.
# An explicit token list is passed so that argparse does not read sys.argv.
custom_args = "--backbonemodel TempCNN --dataset breizhcrops --epochs 10 --hidden-dims 16 --sequencelength 70 --extra-padding-list 35 0".split()
args = parse_args(custom_args)
# Report CUDA availability and the selected device separately: the device can be
# forced with --device, so printing it under a "cuda is available" label is misleading.
print("cuda is available: ", torch.cuda.is_available())
print("device: ", args.device)
print(args)

cuda is available:  cuda
Namespace(backbonemodel='TempCNN', dataset='breizhcrops', alpha=0.5, epsilon=10, learning_rate=0.001, weight_decay=0, patience=30, device='cuda', epochs=10, sequencelength=70, extra_padding_list=[35, 0], hidden_dims=16, batchsize=256, dataroot='C:\\Users\\anyam\\elects_data', snapshot='snapshots/model.pth', left_padding=False, resume=False)


In [4]:
# ----------------------------- WANDB RUN -----------------------------
# `dir=None` is passed to wandb.init below, so no output path is configured here.
# Nothing is bound to the name `dir`: doing so would shadow the builtin dir() for
# the rest of the kernel session.
tag_padding = "left padding" if args.left_padding else "right padding"

# Parsed arguments that are logged one-to-one as run hyper-parameters.
logged_arg_names = [
    "backbonemodel",
    "dataset",
    "alpha",
    "epsilon",
    "learning_rate",
    "weight_decay",
    "patience",
    "device",
    "epochs",
    "sequencelength",
    "extra_padding_list",
    "hidden_dims",
    "batchsize",
    "dataroot",
    "snapshot",
    "resume",
    "left_padding",
]
run_config = {arg_name: getattr(args, arg_name) for arg_name in logged_arg_names}
# Fixed descriptors of the training setup used in this notebook.
run_config.update(
    architecture="EarlyRNN",
    optimizer="AdamW",
    criterion="EarlyRewardLoss",
)

wandb.init(
    dir=None,
    project="ELECTS",
    notes="first experimentations with ELECTS",
    tags=["ELECTS", args.dataset, "with_doys_boxplot", args.backbonemodel, tag_padding, "in notebook_train"],
    config=run_config,
)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: aurenore. Use `wandb login --relogin` to force relogin


In [5]:
# ----------------------------- LOAD DATASET -----------------------------
# Produces: train_ds, test_ds, nclasses, input_dim, class_names, traindataloader, testdataloader
# (plus the day-of-year lookups for BreizhCrops).
# NOTE(review): several branches overwrite args.dataroot / args.sequencelength / args.epochs
# after these values were passed to wandb.init(), so the logged run config can be stale.

# Reset the per-dataset extras so that re-running this cell with another dataset
# cannot silently reuse values left in the kernel by a previous run.
class_weights = None
doys_dict_test = None
length_sorted_doy_dict_test = None

if args.dataset == "bavariancrops":
    dataroot = os.path.join(args.dataroot, "bavariancrops")
    nclasses = 7
    input_dim = 13
    train_ds = BavarianCrops(root=dataroot, partition="train", sequencelength=args.sequencelength)
    test_ds = BavarianCrops(root=dataroot, partition="valid", sequencelength=args.sequencelength)
    class_names = test_ds.classes
elif args.dataset == "unitedstates":
    args.dataroot = "/data/modiscdl/"
    args.sequencelength = 24
    dataroot = args.dataroot
    nclasses = 8
    input_dim = 1
    train_ds = ModisCDL(root=dataroot, partition="train", sequencelength=args.sequencelength)
    test_ds = ModisCDL(root=dataroot, partition="valid", sequencelength=args.sequencelength)
    # The training cell needs class_names for its plots. Whether ModisCDL exposes a
    # `classes` attribute is not visible here, so fall back to generic labels.
    class_names = getattr(test_ds, "classes", None)
    if class_names is None:
        class_names = [f"class {i}" for i in range(nclasses)]
elif args.dataset == "breizhcrops":
    dataroot = os.path.join(args.dataroot, "breizhcrops")
    nclasses = 9
    input_dim = 13
    print("get doys dict test")
    # Same directory as before: os.path.join(args.dataroot, args.dataset) with dataset == "breizhcrops".
    doys_dict_test = get_doys_dict_test(dataroot=dataroot)
    length_sorted_doy_dict_test = create_sorted_doys_dict_test(doys_dict_test)
    print("get doys dict test done")
    print("get train and validation data...")
    train_ds = BreizhCrops(root=dataroot, partition="train", sequencelength=args.sequencelength)
    test_ds = BreizhCrops(root=dataroot, partition="valid", sequencelength=args.sequencelength)
    class_names = test_ds.ds.classname
    print("class names:", class_names)
elif args.dataset in ["ghana"]:
    use_s2_only = False
    average_pixel = False
    max_n_pixels = 50
    dataroot = args.dataroot
    nclasses = 4
    input_dim = 12 if use_s2_only else 19  # 12 sentinel 2 + 3 x sentinel 1 + 4 * planet
    args.epochs = 500
    args.sequencelength = 365
    train_ds = SustainbenchCrops(root=dataroot, partition="train", sequencelength=args.sequencelength,
                                 country="ghana",
                                 use_s2_only=use_s2_only, average_pixel=average_pixel,
                                 max_n_pixels=max_n_pixels)
    val_ds = SustainbenchCrops(root=dataroot, partition="val", sequencelength=args.sequencelength,
                               country="ghana", use_s2_only=use_s2_only, average_pixel=average_pixel,
                               max_n_pixels=max_n_pixels)

    # Train on train + val; evaluate on the held-out test partition.
    train_ds = torch.utils.data.ConcatDataset([train_ds, val_ds])

    test_ds = SustainbenchCrops(root=dataroot, partition="test", sequencelength=args.sequencelength,
                                country="ghana", use_s2_only=use_s2_only, average_pixel=average_pixel,
                                max_n_pixels=max_n_pixels)
    class_names = test_ds.classes
elif args.dataset in ["southsudan"]:
    use_s2_only = False
    dataroot = args.dataroot
    nclasses = 4
    args.sequencelength = 365
    input_dim = 12 if use_s2_only else 19  # 12 sentinel 2 + 3 x sentinel 1 + 4 * planet
    args.epochs = 500
    train_ds = SustainbenchCrops(root=dataroot, partition="train", sequencelength=args.sequencelength, country="southsudan", use_s2_only=use_s2_only)
    val_ds = SustainbenchCrops(root=dataroot, partition="val", sequencelength=args.sequencelength, country="southsudan", use_s2_only=use_s2_only)

    train_ds = torch.utils.data.ConcatDataset([train_ds, val_ds])
    # Evaluate on "test" as in the ghana branch. Using "val" here would evaluate on
    # samples that were just merged into the training set (data leakage).
    # NOTE(review): assumes a "test" partition exists for southsudan — confirm.
    test_ds = SustainbenchCrops(root=dataroot, partition="test", sequencelength=args.sequencelength,
                                country="southsudan", use_s2_only=use_s2_only)
    class_names = test_ds.classes
else:
    raise ValueError(f"dataset {args.dataset} not recognized")

# Shuffle the training samples every epoch (DataLoader defaults to shuffle=False);
# the evaluation loader keeps its fixed order.
traindataloader = DataLoader(
    train_ds,
    batch_size=args.batchsize,
    shuffle=True)
testdataloader = DataLoader(
    test_ds,
    batch_size=args.batchsize)


get doys dict test
get doys dict test done
get train and validation data...
2559635960 2559635960


loading data into RAM: 100%|██████████| 178613/178613 [01:17<00:00, 2316.30it/s]


2253658856 2253658856


loading data into RAM: 100%|██████████| 140645/140645 [00:59<00:00, 2379.38it/s]


2493572704 2493572704


loading data into RAM: 100%|██████████| 166391/166391 [01:11<00:00, 2339.74it/s]


class names: ['barley' 'wheat' 'rapeseed' 'corn' 'sunflower' 'orchards' 'nuts'
 'permanent meadows' 'temporary meadows']


In [6]:
# ----------------------------- VISUALIZATION: label distribution -----------------------------
# datasets = [train_ds, test_ds]
# sets_labels = ["Train", "Validation"]
# fig, ax = plt.subplots(figsize=(15, 7))
# fig, ax = plot_label_distribution_datasets(datasets, sets_labels, fig, ax, title='Label distribution', labels_names=class_names)
# wandb.log({"label_distribution": wandb.Image(fig)})
# plt.close(fig)  


In [6]:
# Build the padding schedule: every multiple of `step_timestamp_padding` that fits into
# args.sequencelength, and for each the number of timestamps missing to the full length.
step_timestamp_padding = 25
sequence_lengths_train = list(range(step_timestamp_padding, args.sequencelength + 1, step_timestamp_padding))
print("sequence_lengths_train", sequence_lengths_train)
extra_padding_list = [args.sequencelength - length for length in sequence_lengths_train]
print("extra_padding_list", extra_padding_list)

# This list, not args.extra_padding_list, is what train_epoch/test_epoch receive in the
# training cell, so overwrite the value that was logged when the run was initialised.
if wandb.run is not None:
    wandb.config.update({"extra_padding_list": extra_padding_list}, allow_val_change=True)

sequence_lengths_train [25, 50]
extra_padding_list [45, 20]


In [7]:
# ----------------------------- SET UP MODEL -----------------------------
# Produces: model, optimizer, criterion, train_stats, start_epoch, not_improved.
model = EarlyRNN(
    args.backbonemodel,
    nclasses=nclasses,
    input_dim=input_dim,
    sequencelength=args.sequencelength,
    kernel_size=7,
    hidden_dims=args.hidden_dims,
    left_padding=args.left_padding,
).to(args.device)
print(f"The model has {count_parameters(model):,} trainable parameters.")

# exclude decision head linear bias from weight decay
decay, no_decay = [], []
for name, param in model.named_parameters():
    if name == "stopping_decision_head.projection.0.bias":
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW(
    [
        {'params': no_decay, 'weight_decay': 0, "lr": args.learning_rate},
        {'params': decay},
    ],
    lr=args.learning_rate,
    weight_decay=args.weight_decay,
)

criterion = EarlyRewardLoss(alpha=args.alpha, epsilon=args.epsilon)

# Defaults for a fresh run; overwritten below when resuming from a snapshot.
train_stats = []
start_epoch = 1

if args.resume and os.path.exists(args.snapshot):
    model.load_state_dict(torch.load(args.snapshot, map_location=args.device))
    # The optimizer state is stored next to the model as "<name>_optimizer.pth".
    optimizer_snapshot = os.path.join(os.path.dirname(args.snapshot),
                                      os.path.basename(args.snapshot).replace(".pth", "_optimizer.pth"))
    optimizer.load_state_dict(torch.load(optimizer_snapshot, map_location=args.device))
    # Per-epoch statistics written by the training cell together with the snapshot.
    train_stats = pd.read_csv(args.snapshot + ".csv").to_dict("records")
    if train_stats:
        # The last logged epoch is already complete: continue with the next one.
        # Restarting at the same epoch would append a second row for it.
        start_epoch = int(train_stats[-1]["epoch"]) + 1
    print(f"resuming from {args.snapshot} epoch {start_epoch}")

# Early-stopping counter: consecutive epochs without a new best test loss.
not_improved = 0


The model has 8,196 trainable parameters.


In [8]:
# ----------------------------- TRAINING -----------------------------
print("starting training...")

# The day-of-year lookup (length_sorted_doy_dict_test) is only built in the BreizhCrops
# branch of the dataset cell, so the stopping-time boxplot is limited to that dataset.
log_doy_boxplot = args.dataset == "breizhcrops"

# The optimizer state is stored next to the model snapshot as "<name>_optimizer.pth".
snapshot_dir = os.path.dirname(args.snapshot)
optimizer_snapshot = os.path.join(snapshot_dir,
                                  os.path.basename(args.snapshot).replace(".pth", "_optimizer.pth"))

with tqdm(range(start_epoch, args.epochs + 1)) as pbar:
    for epoch in pbar:
        trainloss = train_epoch(model, traindataloader, optimizer, criterion, device=args.device, extra_padding_list=extra_padding_list)
        print("finished training for epoch ", epoch)
        testloss, stats = test_epoch(model, testdataloader, criterion, args.device, extra_padding_list=extra_padding_list, return_id=test_ds.return_id)

        # statistic logging and visualization...
        # Only the first column of predictions/targets is evaluated, as before.
        y_pred = stats["predictions_at_t_stop"][:, 0]
        y_true = stats["targets"][:, 0]
        precision, recall, fscore = sklearn.metrics.precision_recall_fscore_support(
            y_pred=y_pred, y_true=y_true, average="macro", zero_division=0)[:3]
        accuracy = sklearn.metrics.accuracy_score(y_pred=y_pred, y_true=y_true)
        kappa = sklearn.metrics.cohen_kappa_score(y_pred, y_true)

        classification_loss = stats["classification_loss"].mean()
        earliness_reward = stats["earliness_reward"].mean()
        # Mean stopping index as a fraction of the last index, inverted so that earlier is larger.
        earliness = 1 - (stats["t_stop"].mean() / (args.sequencelength - 1))
        harmonic_mean = harmonic_mean_score(accuracy, stats["classification_earliness"])

        # ----------------------------- LOGGING -----------------------------
        train_stats.append(
            dict(
                epoch=epoch,
                trainloss=trainloss,
                testloss=testloss,
                accuracy=accuracy,
                precision=precision,
                recall=recall,
                fscore=fscore,
                kappa=kappa,
                elects_earliness=earliness,
                classification_loss=classification_loss,
                earliness_reward=earliness_reward,
                classification_earliness=stats["classification_earliness"],
                harmonic_mean=harmonic_mean,
            )
        )

        log_payload = {
            "loss": {"trainloss": trainloss, "testloss": testloss},
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "fscore": fscore,
            "kappa": kappa,
            "elects_earliness": earliness,
            "classification_loss": classification_loss,
            "earliness_reward": earliness_reward,
            "classification_earliness": stats["classification_earliness"],
            "harmonic_mean": harmonic_mean,
            "conf_mat": wandb.plot.confusion_matrix(probs=None,
                                                    y_true=y_true, preds=y_pred,
                                                    class_names=class_names, title="Confusion Matrix"),
        }

        fig_boxplot = None
        if log_doy_boxplot:
            fig_boxplot, ax_boxplot = plt.subplots(figsize=(15, 7))
            doys_dict = get_approximated_doys_dict(stats["seqlengths"], length_sorted_doy_dict_test)
            doys_stop = get_doy_stop(stats, doys_dict)
            fig_boxplot, _ = boxplot_stopping_times(doys_stop, stats, fig_boxplot, ax_boxplot, class_names)
            log_payload["boxplot"] = wandb.Image(fig_boxplot)

        wandb.log(log_payload)
        if fig_boxplot is not None:
            # Close the figure each epoch so figures do not accumulate in memory.
            plt.close(fig_boxplot)

        df = pd.DataFrame(train_stats).set_index("epoch")

        # Test losses of all earlier epochs (including those loaded when resuming).
        previous_losses = df.testloss[:-1].values
        # The first recorded epoch is always the best so far. Saving from the start
        # guarantees that a snapshot exists even for very short runs.
        improved = len(previous_losses) == 0 or testloss < previous_losses.min()

        savemsg = ""
        if improved:
            savemsg = f"saving model to {args.snapshot}"
            if snapshot_dir:
                # os.makedirs("") raises, so skip it when the snapshot lives in the cwd.
                os.makedirs(snapshot_dir, exist_ok=True)
            torch.save(model.state_dict(), args.snapshot)
            torch.save(optimizer.state_dict(), optimizer_snapshot)
            wandb.log_artifact(args.snapshot, type="model")

            df.to_csv(args.snapshot + ".csv")
            not_improved = 0  # reset early stopping counter
        else:
            not_improved += 1  # increment early stopping counter
            if args.patience is not None:
                # Training stops once not_improved exceeds patience, i.e. after
                # patience + 1 epochs without improvement.
                savemsg = f"early stopping in {args.patience - not_improved + 1} epochs."

        pbar.set_description(f"epoch {epoch}: trainloss {trainloss:.2f}, testloss {testloss:.2f}, "
                     f"accuracy {accuracy:.2f}, earliness {earliness:.2f}. "
                     f"classification loss {classification_loss:.2f}, earliness reward {earliness_reward:.2f}, harmonic mean {harmonic_mean:.2f}. {savemsg}")

        if args.patience is not None and not_improved > args.patience:
            print(f"stopping training. testloss {testloss:.2f} did not improve in {not_improved} epochs.")
            break

# ----------------------------- SAVE FINAL MODEL -----------------------------
# Only log the snapshot if one exists on disk; log_artifact fails on a missing file.
if os.path.exists(args.snapshot):
    wandb.log_artifact(args.snapshot, type="model")
else:
    print(f"no snapshot found at {args.snapshot}; no final model artifact logged.")

starting training...


  0%|          | 0/10 [00:00<?, ?it/s]

finished training for epoch  1


  0%|          | 0/10 [04:31<?, ?it/s]


KeyboardInterrupt: 

In [9]:
# End the wandb run that was opened with wandb.init() earlier in this notebook.
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))