## Example script: training and predicting with ConvLSTM
Author          : SSI project team Wadden Sea <br>
First Built     : 2021.08.01 <br>
Last Update     : 2021.08.12 <br>
Description     : This notebook serves as an example of training and predicting with
                  Convolutional Long-Short Term Memeory Neural Network (ConvLSTM). <br>
Dependency      : os, numpy, pytorch <br>
Return Values   : time series / array <br>
Caveat!         : This module performs many-to-one prediction! It supports CUDA. <br>


In [1]:
import os
import numpy as np
import torch
import torch.nn.functional
from torch.autograd import Variable

# import convlstm
sys.path.append("../src")
import convlstm

### Path

In [2]:
# please specify output path for the model
output_path = './model'
if not os.path.exists(output_path):
   os.makedirs(output_path, exist_ok = True)

### Hyper-parameter of neural network

In [3]:
input_channels = 6 # number of input channels e.g. concentration heatmap, current, wind curl, etc.
hidden_channels = [6, 3, 1] # the last digit is the output channel
kernel_size = 3
batch_size = 1
learning_rate = 0.01
num_epochs = 20

### Hardware info and version of pytorch

In [4]:
print ("Pytorch version {}".format(torch.__version__))
# check if CUDA is available
use_cuda = torch.cuda.is_available()
# use GPU if possible
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device to be used for computation: {}".format(device))

Pytorch version 1.8.1
Device to be used for computation: cpu


### Initialize model

In [5]:
# initialize our model
model = convlstm.ConvLSTM(input_channels, hidden_channels, kernel_size).to(device)
# choose loss function
loss_fn = torch.nn.MSELoss()
# choose optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# check the model / loss function and optimizer
print(model)
print(loss_fn)
print(optimizer)

ConvLSTM(
  (cell0): ConvLSTMCell(
    (Wxi): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Whi): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (Wxf): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Whf): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (Wxc): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Whc): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (Wxo): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Who): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (cell1): ConvLSTMCell(
    (Wxi): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Whi): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (Wxf): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Whf): Conv2d(3, 3, kernel_size=(3, 3), stride=

### Create dummy data for testing

In [6]:
# training data
train_X_torch = Variable(torch.randn(10, 6, 5, 5)).to(device) # shape [timestep, channel, height, width]
train_y_torch = Variable(torch.randn(10, 6, 5, 5)).float().to(device)
# testing data
test_X_torch = Variable(torch.randn(2, 6, 5, 5)).to(device) # shape [timestep, channel, height, width]
test_y_torch = Variable(torch.randn(2, 6, 5, 5)).float().to(device)
# get length of training set
train_steps, channels, height, width = train_X_torch.shape

### Training

In [7]:
hist = np.zeros(num_epochs) # save the loss for every epoch
for epoch in range(num_epochs):
    # Clear stored gradient
    model.zero_grad()
    # loop through all timesteps
    for t in range(train_steps):
        var_X = torch.autograd.Variable(train_X_torch[t,:,:,:].view(-1,channels,height,width)).to(device) # record gradient
        var_y = torch.autograd.Variable(train_y_torch[t,:,:,:].view(-1,1,height,width)).to(device) # record gradient
        # Forward process
        pred_y, _ = model(var_X, t)
        # compute loss
        if t == 0:
            loss = loss_fn(pred_y, var_y)
        else:
            loss += loss_fn(pred_y, var_y)
    if epoch % 2 == 0:
        print("Epoch ", epoch, "MSE: ", loss.item())
    hist[epoch] = loss.item()

    # Zero out gradient, else they will accumulate between epochs
    optimizer.zero_grad()
    
    # Backward pass
    loss.backward()

    # Update parameters
    optimizer.step()

# save the general checkpoint
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item()
            }, os.path.join(output_path,'Lagrangian_ML_training_checkpoint.pt'))
print("The checkpoint of the model and training status is saved.")

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch  0 MSE:  10.254456520080566
Epoch  2 MSE:  10.23924732208252
Epoch  4 MSE:  10.231311798095703
Epoch  6 MSE:  10.17760181427002
Epoch  8 MSE:  10.135177612304688
Epoch  10 MSE:  10.110885620117188
Epoch  12 MSE:  10.123470306396484
Epoch  14 MSE:  9.997820854187012
Epoch  16 MSE:  9.983427047729492
Epoch  18 MSE:  9.888792037963867
The checkpoint of the model and training status is saved.
