<a href="https://colab.research.google.com/github/Krankile/ensemble_forecasting/blob/main/notebooks/autoencoder/1_fit_lstm_ae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%%capture
!pip install wandb --upgrade

In [None]:
%%capture
!git clone https://github.com/Krankile/ensemble_forecasting.git
!mv ensemble_forecasting ef

In [None]:
%%capture
!cd ef && git pull

In [None]:
import wandb as wb
wb.login()

[34m[1mwandb[0m: Currently logged in as: [33mkrankile[0m (use `wandb login --relogin` to force relogin)


True

In [None]:
import os
import copy
import random
from datetime import datetime as dt
import psutil
from collections import defaultdict
from pathlib import Path
import json
from importlib import reload
from functools import partial

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from tqdm.notebook import tqdm

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
from torch.optim import lr_scheduler

from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence, pad_sequence, PackedSequence

from ef.models import lstm_autoencoders
from ef.plotting.ae_plot import plot_examples
from ef.utils import normalizers, schedulers, optimizers
from ef.data import autoencoder_loaders

# Larger default figure size for the example-reconstruction plots.
plt.rcParams["figure.figsize"] = (16, 8)

# Use the GPU when PyTorch can see one; the training code below moves models and batches to `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

## Run training

### Define the training loop

In [None]:
def _plot_and_log_examples(figurepath, log_key, model, example_data, example_packed, example_lens, conf, epoch, step):
    """Plot the model's reconstruction of a fixed validation batch and log the image to W&B.

    The figure is written to `figurepath` by `plot_examples` and then uploaded under
    `log_key` together with the current `epoch`, at global step `step`.
    """
    plot_examples(
        figurepath,
        example_data,
        example_packed,
        model,
        lens=example_lens,
        conf=conf,
    )
    wb.log({log_key: wb.Image(figurepath), "epoch": epoch}, step=step)


def train_model(model, train_loader, val_loader, conf):
    """Train an LSTM autoencoder with an L1 reconstruction loss and keep the best-validation weights.

    Each epoch does one pass over `train_loader` (mixed precision when running on CUDA), one
    gradient-free pass over `val_loader`, steps the LR scheduler and logs losses to W&B. Whenever
    the mean validation loss improves, the weights are snapshotted, saved to ``model.torch``,
    uploaded with ``wb.save`` and an example reconstruction plot is logged.

    Args:
        model: Autoencoder already moved to ``device``; called as ``model(packed_batch, lens)``.
            Assumes it returns a PackedSequence aligned with its input — TODO confirm.
        train_loader: Yields ``(padded_batch, lens, _)``; batches are packed with ``batch_first=True``.
        val_loader: Same format as `train_loader`; its first batch is reused for example plots.
        conf: W&B run config. Reads ``optimizer``, ``learning_rate``, ``weight_decay``,
            ``scheduler``, ``batch_size``, ``epochs`` and optionally ``early_stop``.

    Returns:
        ``(model, filepath)``: the model in eval mode with the best-validation weights loaded,
        and the path of the saved state dict.
    """
    use_amp = device.type == "cuda"  # autocast/GradScaler only do anything on CUDA

    optimizer = optimizers[conf.optimizer](model.parameters(), lr=conf.learning_rate, weight_decay=conf.weight_decay)
    criterion = nn.L1Loss(reduction="mean").to(device)
    if conf.scheduler:
        scheduler = schedulers[conf.scheduler["name"]](optimizer, **conf.scheduler["kwargs"])
    else:
        # `schedulers[None]` is presumably a no-op schedule — verify in ef.utils.
        scheduler = schedulers[None](optimizer)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    filepath = "model.torch"
    b_size = conf.batch_size
    step = 0
    history = dict(train=[], val=[])
    plot_every = 10  # epochs between periodic example plots

    # Start from the initial weights so there is always something to restore and save,
    # even if the validation loss never improves (e.g. it is NaN every epoch or epochs == 0).
    best_loss = float("inf")
    best_model_wts = copy.deepcopy(model.state_dict())
    saved_best = False

    # One fixed validation batch is used for every example plot so plots are comparable across epochs.
    example_data, example_lens, _ = next(iter(val_loader))
    example_packed = pack_padded_sequence(example_data, example_lens, batch_first=True, enforce_sorted=False).to(device)

    it = tqdm(range(1, conf.epochs + 1))
    for epoch in it:

        # Training part of epoch
        model.train()
        train_losses = []
        for seq_true, lens, _ in train_loader:
            optimizer.zero_grad()
            packed_true = pack_padded_sequence(seq_true, lens, batch_first=True, enforce_sorted=False).to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                seq_pred = model(packed_true, lens)
                # Compare packed data only, so padding positions never contribute to the loss.
                loss = criterion(seq_pred.data, packed_true.data)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss.item())
            step += 1

        # Validation part of epoch
        model.eval()
        val_losses = []
        with torch.no_grad():
            for seq_true, lens, _ in val_loader:
                packed_true = pack_padded_sequence(seq_true, lens, batch_first=True, enforce_sorted=False).to(device)
                seq_pred = model(packed_true, lens)
                val_losses.append(criterion(seq_pred.data, packed_true.data).item())

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)

        scheduler.step()

        history["train"].append(train_loss)
        history["val"].append(val_loss)

        lr = optimizer.param_groups[0]["lr"]  # read after scheduler.step(), i.e. the next epoch's lr
        wb.log({"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch, "examples": step * b_size, "lr": lr}, step=step)

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, filepath)
            wb.save(filepath)
            saved_best = True
            _plot_and_log_examples(
                "best_val_plot.png", "best_val_expl", model,
                example_data, example_packed, example_lens, conf, epoch, step,
            )

        # Epochs are 1-based, so this fires on epochs 10, 20, 30, ...
        if epoch % plot_every == 0:
            _plot_and_log_examples(
                "periodic_val_plot.png", "periodic_val_expl", model,
                example_data, example_packed, example_lens, conf, epoch, step,
            )

        it.set_postfix(
            train_loss=train_loss,
            val_loss=val_loss,
            lr=f"{lr:.2e}",
        )

        # Code for early stopping.
        # NOTE(review): `early_stop` is neither defined nor imported in this notebook, so setting
        # `conf.early_stop` raises NameError here — TODO import it from where it is defined.
        if conf.get("early_stop") is not None and early_stop(history, ma1=5, ma2=20, em=1.05):
            wb.log({"early_stop": True, "epoch": epoch}, step=step)
            break

    if not saved_best:
        # No epoch beat the initial `inf`; still write a checkpoint so `filepath` exists for callers.
        torch.save(best_model_wts, filepath)

    model.load_state_dict(best_model_wts)
    return model.eval(), filepath

In [None]:
def df_from_art(run, artname, *, root="krankile/data-processing/"):
    """Download a W&B artifact and read its feather file into a DataFrame indexed by ``m4id``.

    Args:
        run: Active W&B run used to declare and fetch the artifact.
        artname: Artifact name with alias (e.g. ``"series_traval:latest"``), appended to `root`.
        root: ``entity/project/`` prefix the artifact lives under.

    Returns:
        The artifact's feather table with the ``m4id`` column as index.
    """
    artifact = run.use_artifact(root + artname)
    artifact.download()
    frame = pd.read_feather(artifact.file())
    return frame.set_index("m4id")


In [None]:
def train(config=None, project=None, entity=None, enablewb=True):
    """Run one W&B training job: load the train/val split, build the autoencoder, train it, log it.

    Args:
        config: Initial W&B config (hyperparameters); read back through ``run.config``.
        project: W&B project passed to ``wb.init``.
        entity: W&B entity passed to ``wb.init``.
        enablewb: When False the run is started with ``mode="disabled"``.

    Returns:
        The trained model, as returned by ``train_model`` (best-validation weights, eval mode).
    """
    wb_mode = "online" if enablewb else "disabled"
    with wb.init(config=config, project=project, entity=entity, job_type="training", mode=wb_mode) as run:
        conf = run.config

        load = partial(df_from_art, run)
        series = load("series_traval:latest")
        info = load("info_traval:latest")
        split = load("traval_split_80_20:latest")

        # Rows whose `val` flag is False train the model; rows flagged True validate it.
        is_val = split.val
        train_idx = split[is_val == False].index  # noqa: E712 - explicit comparison kept on purpose
        val_idx = split[is_val == True].index  # noqa: E712
        tra_data, val_data = (
            dict(series=series.loc[idx], info=info.loc[idx])
            for idx in (train_idx, val_idx)
        )

        train_loader, val_loader, seq_len, n_features = autoencoder_loaders(run, tra_data, val_data, cpus=None)

        build_model = lstm_autoencoders[conf.architecture]
        model = build_model(
            seq_len=seq_len,
            n_features=n_features,
            embedding_dim=conf.embedding_dim,
            hidden_dim=conf.hidden_dim,
            dropout=conf.dropout,
            num_layers=conf.num_layers,
        )

        print(f"Moving model {conf.architecture} to device: {device}")
        model = model.to(device)

        model, savepath = train_model(model, train_loader, val_loader, conf=conf)

        # Version the saved checkpoint as an artifact named after the architecture,
        # attaching the full run config as JSON metadata, and mark it as this run's output.
        artifact = wb.Artifact(conf.architecture, type='lstm-ae-model', metadata={"config": json.dumps(dict(conf))})
        artifact.add_file(savepath)
        run.log_artifact(artifact)
    return model
    

### Standalone training

#### Config

In [None]:
config = {
    "epochs": 500,  # number of passes over the training set in train_model
    "maxlen": 250,  # not read in this notebook; presumably consumed by autoencoder_loaders via run.config
    "embedding_dim": 32,  # passed to the autoencoder constructor
    "hidden_dim": 128,  # passed to the autoencoder constructor
    "learning_rate": 0.002,
    "architecture": "RecurrentAutoencoderV4",  # key into lstm_autoencoders; also the artifact name
    "num_layers": 2,
    "batch_size": 512,  # used for the "examples" counter; presumably also by the data loaders
    "optimizer": "adamw",  # key into ef.utils.optimizers
    "dropout": 0.2,
    "normalize_data": "normal",  # not read in this notebook; presumably used by autoencoder_loaders
    "weight_decay": 0.005,
    # None selects schedulers[None]. Example of a real schedule:
    # {"name": "MultiStepLR", "kwargs": {"milestones": [100, 200, 400, 800], "gamma": 0.5}}
    "scheduler": None,
}

#### Start

In [None]:
enablewb = True
sweepid = None  # set to a W&B sweep id to run as a sweep agent instead of a single run

if sweepid is None:
    # Single standalone run with the config above.
    model = train(config=config, project="lstm-ae-tmp", entity="krankile", enablewb=enablewb)
else:
    count = 100  # number of runs this sweep agent executes
    wb.agent(sweepid, function=partial(train, config=config), count=count)