# End2end ConvGRU

> Autoencoder and forecaster trained together in the same training loop. Based on [01_train_example.ipynb](https://github.com/tcapelle/moving_mnist/blob/master/01_train_example.ipynb) and [02_train_cross_entropy_loss-Copy1.ipynb](https://github.com/tcapelle/moving_mnist/blob/master/02_train_cross_entropy_loss-Copy1.ipynb).

In [None]:
import sys
sys.path.append('..')
from fastai.vision.all import *
from mocatml.utils import *
convert_uuids_to_indices()
from mocatml.data import *
from mocatml.models.conv_rnn import *
from mygrad import sliding_window_view
from tsai.imports import my_setup
from tsai.utils import yaml2dict, dict2attrdict
from fastai.callback.schedule import valley, steep
from fastai.callback.wandb import WandbCallback
import wandb

In [None]:
from fastai.callback.schedule import LRFinder

@patch_to(LRFinder)
def after_fit(self):
    "Clear gradients, then restore the temporary `_tmp` checkpoint (loaded with `device='cpu'`) if it exists"
    # Gradients have to be zeroed before the optimizer is detached for future fits
    self.learn.opt.zero_grad()
    checkpoint = self.path/self.model_dir/self.tmp_p/'_tmp.pth'
    # Nothing to restore or clean up when the checkpoint file is absent
    if not checkpoint.exists(): return
    self.learn.load(f'{self.tmp_p}/_tmp', with_opt=True, device='cpu')
    self.tmp_d.cleanup()

In [None]:
# Setup helper imported from tsai.imports (presumably reports the software/hardware environment — confirm)
my_setup()

In [None]:
# Build the experiment configuration: the base settings file plus the ConvGRU
# settings file, the latter stored under the `convgru` key.
config_base = yaml2dict('./config/base.yaml', attrdict=True)
convgru_settings = yaml2dict('./config/convgru/convgru.yaml', attrdict=True)
config_base.convgru = convgru_settings
config = AttrDict(config_base)
config

In [None]:
# Set device: 'cpu' is translated to 0, any other configured value is passed through unchanged
# (presumably fastai's `default_device` treats a falsy argument as "use the CPU" — confirm)
use_device = 0 if config.device == 'cpu' else config.device
default_device(use_device)

In [None]:
# Start a Weights & Biases run only when tracking is enabled. In that case the
# run's config replaces the local one for the rest of the notebook; otherwise
# `run` is None and the local config is kept.
if config.wandb.enabled:
    run = wandb.init(dir=ifnone(config.wandb.dir, '../'),
                     project=config.wandb.project,
                     config=config,
                     group=config.wandb.group,
                     mode=config.wandb.mode,
                     anonymous='never')
    config = dict2attrdict(run.config)
else:
    run = None
print(config)

In [None]:
# Load the data array (memory-mapped with mode 'c' when `config.mmap` is set)
# and keep only the first `config.sel_steps` entries along axis 1.
data_path = Path(config.data.path).expanduser()
mmap_mode = 'c' if config.mmap else None
data = np.load(data_path, mmap_mode=mmap_mode)[:, :config.sel_steps]
data.shape

In [None]:
# Cut the data into windows of `lookback + horizon` steps along axis 1.
# The window (and its step) spans the full extent of axes 0, -2 and -1, so the
# only axis with more than one window placement is axis 1, advanced by `config.stride`.
# NOTE(review): assumes `data` is 4-D, (simulations, steps, height, width) — confirm.
window_len = config.lookback + config.horizon
window_shape = (data.shape[0], window_len, data.shape[-2], data.shape[-1])
window_step = (data.shape[0], config.stride, data.shape[-2], data.shape[-1])
data_sw = sliding_window_view(data, window_shape, window_step)
samples_per_simulation = data_sw.shape[1]
# Remove only the three unit-length placement axes (0, 2, 3). A bare `squeeze()`
# would also drop any other axis of length 1 (a single simulation, a single
# window, a unit spatial dimension), leaving fewer than the 5 axes the
# transpose below requires.
data_sw = data_sw.squeeze(axis=(0, 2, 3)).transpose([1,0,2,3,4])
# Merge the first two axes so every window becomes one sample
data_sw = data_sw.reshape(-1, *data_sw.shape[2:])
data_sw.shape

Split the data and compute the normalization statistics (mean and standard deviation) from the training set.

In [None]:
# Split by simulation
splits = RandomSplitter()(data)
splits

In [None]:
# Wrap the windowed samples in the project dataset, then translate the train/valid
# splits into sample indices using the number of windows per simulation.
ds = DensityData(data_sw, lbk=config.lookback, h=config.horizon)
train_idxs, valid_idxs = [calculate_sample_idxs(sim_idxs, samples_per_simulation)
                          for sim_idxs in (splits[0], splits[1])]
len(train_idxs), len(valid_idxs)

In [None]:
# Normalization statistics (mean, standard deviation) computed on the training split only.
# Index the training subset once: `data[splits[0]]` is fancy indexing, so every
# evaluation materializes a fresh copy of the training data.
train_data = data[splits[0]]
mocat_stats = (np.mean(train_data), np.std(train_data))
# Release the temporary copy so it does not stay resident for the rest of the notebook
del train_data
mocat_stats

In [None]:
# Create dataloaders
# Batch-level normalization with the training statistics is applied only when enabled in the config
batch_tfms = None
if config.normalize:
    batch_tfms = [Normalize.from_stats(*mocat_stats)]
train_tl = TfmdLists(train_idxs, DensityTupleTransform(ds))
valid_tl = TfmdLists(valid_idxs, DensityTupleTransform(ds))
dls = DataLoaders.from_dsets(train_tl, valid_tl,
                             bs=config.bs,
                             device=default_device(),
                             after_batch=batch_tfms,
                             num_workers=config.num_workers)
dls.show_batch()
# Pull one batch to inspect the length of its two components
foo, bar = dls.one_batch()
len(foo), len(bar)

In [None]:
# Flattened MSE loss wrapped in StackLoss (project helper; presumably applies the
# loss to the stacked sequence of frames — confirm)
loss_func = StackLoss(MSELossFlat())

In [None]:
# Translate the configured norm into a fastai NormType: 'batch' selects batch norm,
# anything else disables normalization. The isinstance guard keeps this cell
# idempotent: without it, a second execution would compare NormType.Batch with
# the string 'batch', fail, and silently reset the norm to None.
if config.convgru.norm == 'batch':
    config.convgru.norm = NormType.Batch
elif not isinstance(config.convgru.norm, NormType):
    config.convgru.norm = None
model = StackUnstack(SimpleModel(**config.convgru)).to(default_device())
# The W&B callback is attached only when tracking is enabled (adding None to an L adds nothing)
wandbc = WandbCallback(log_preds=False, log_model=False) if config.wandb.enabled else None
cbs = L() + wandbc
learn = Learner(dls, model, loss_func=loss_func, cbs=cbs).to_fp16()
# Use the configured learning rate when given; otherwise run the LR finder.
# `lr_find` returns a SuggestedLRs named tuple, so take its `valley` suggestion
# explicitly rather than passing the whole tuple on as the learning rate.
lr_max = config.lr_max if config.lr_max is not None else learn.lr_find().valley

In [None]:
# Train for `config.n_epoch` epochs with the one-cycle schedule, peaking at `lr_max`
learn.fit_one_cycle(config.n_epoch, lr_max=lr_max)

In [None]:
# Predictions (p) and targets (t) from `get_preds` with its default arguments
# (presumably the validation set — confirm)
p,t = learn.get_preds()
len(p), p[0].shape

In [None]:
def show_res(t, idx, figsize=(8,4)):
    "Show, as one `DensitySeq`, the item at position `idx` taken from every element of `t`"
    frames = [step[idx] for step in t]
    DensitySeq.create(frames).show(figsize=figsize);

In [None]:
# Pick a random validation sample and show its target sequence followed by the prediction.
# `random.randint` includes both endpoints, so the upper bound must be n - 1 to
# keep `k` a valid index into the `dls.valid.n` samples.
k = random.randint(0, dls.valid.n - 1)
figsize=(12,8)
print(k)
show_res(t,k, figsize=figsize)
show_res(p,k, figsize=figsize)

In [None]:
#|hide
# Print the validation loss and persist it with IPython's %store magic so that
# other notebooks (e.g. an Optuna study) can read it for hyperparameter optimization.
# `validate()` returns a list; element 0 is taken as the loss.
valid_loss = learn.validate()[0] 
print(valid_loss)
%store valid_loss

In [None]:
# Remove the wandb callback before exporting, to avoid errors when the learner
# is later downloaded and loaded
if config.wandb.enabled:
    learn.remove_cb(wandbc)

# Save locally, and in wandb if a run is active and learner logging is enabled.
# The model directory is pointed at the temporary folder first so that the
# weights (with optimizer state) and the exported learner share one location.
learn.model_dir = config.tmp_folder
learn.save('model', with_opt=True)
learn.export(f'{config.tmp_folder}/learner.pkl')
if run is not None and config.wandb.log_learner:
    # Log the whole temporary folder as a single 'learner' artifact
    # (all tmp/dls, tmp/model.pth, and tmp/learner.pkl).
    run.log_artifact(config.tmp_folder, type='learner', name='density-forecaster')

In [None]:
# Close the wandb run, if one was started
if run is not None:
    run.finish()