## Install dependencies

In [None]:
# Install PyTorch-Ignite, used by the Ignite training cell below.
# NOTE(review): other cells also import `catalyst`, `sklearn`, `pandas` and `tqdm`,
# which are not installed here — confirm they are available in the runtime.
!pip install pytorch-ignite

## Dataset


In [None]:
!curl -o mnist_test_seq.npy http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy

In [None]:
import numpy as np
from torch.utils.data import Dataset


class MovingMnistDataset(Dataset):
    """Moving MNIST sequences split into the first 10 frames and the rest.

    The ``.npy`` file is loaded eagerly and re-ordered from ``(t, N, H, W)``
    to ``(N, t, C=1, H, W)``.  The training phase keeps the first 1000
    sequences; the test phase keeps every sequence from index 9000 onward.

    Parameters
    ----------
    path: str
        Location of the Moving MNIST ``.npy`` file.
    phase_train: bool
        Select the training subset (``True``) or the test subset (``False``).
    """

    def __init__(self, path="./mnist_test_seq.npy", phase_train=True):
        raw = np.load(path)
        # (t, N, H, W) -> (N, t, H, W), then insert a singleton channel axis.
        sequences = np.expand_dims(np.transpose(raw, (1, 0, 2, 3)), axis=2)
        self.data = sequences[:1000] if phase_train else sequences[9000:]

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, i):
        """Return ``(input_frames, target_frames)`` scaled to [0, 1] as float32."""
        inputs = (self.data[i, :10] / 255).astype(np.float32)
        targets = (self.data[i, 10:] / 255).astype(np.float32)
        return inputs, targets

## Network

In [None]:
"""
Copyright (c) 2020 Masafumi Abeta. All Rights Reserved.
Released under the MIT license
"""
import math
import uuid
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair


class ConvLSTMCell(nn.Module):
    """Single ConvLSTM step whose gates use 'SAME'-padded convolutions.

    Parameters
    ----------
    in_channels: int
        Number of channels of the input tensor.
    hidden_channels: int
        Number of channels of the hidden and cell states.
    kernel_size: int or (int, int)
        Size of the convolutional kernel.
    stride: int or (int, int)
        Stride of the convolution.  NOTE(review): the same stride is applied
        to the hidden-state convolutions, so shapes appear to line up only for
        stride 1 — confirm before using another value.
    image_size: (int, int)
        Spatial size used to precompute the static padding.
    """

    def __init__(self, in_channels, hidden_channels,
                 kernel_size, stride=1, image_size=None):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)

        def make_conv(source_channels, **extra):
            return Conv2dStaticSamePadding(
                source_channels, self.hidden_channels, self.kernel_size,
                self.stride, image_size=image_size, **extra)

        # Registers Wxi, Whi, Wxf, Whf, Wxg, Whg, Wxo, Who in this order, so
        # parameter names and seeded initialisation stay the same.  Only the
        # input convolutions carry a bias; it already covers the gate sum.
        for gate in "ifgo":
            setattr(self, f"Wx{gate}", make_conv(self.in_channels))
            setattr(self, f"Wh{gate}", make_conv(self.hidden_channels, bias=False))

    def forward(self, x, hidden_state):
        """Advance the cell by one time step.

        Parameters
        ----------
        x: torch.Tensor
            4-D tensor of shape (b, c, h, w).
        hidden_state: tuple
            Previous ``(h, c)`` pair.

        Returns
        -------
            (h_next, c_next)
        """
        h_prev, c_prev = hidden_state

        input_gate = torch.sigmoid(self.Wxi(x) + self.Whi(h_prev))
        forget_gate = torch.sigmoid(self.Wxf(x) + self.Whf(h_prev))
        output_gate = torch.sigmoid(self.Wxo(x) + self.Who(h_prev))
        candidate = torch.tanh(self.Wxg(x) + self.Whg(h_prev))

        c_next = forget_gate * c_prev + input_gate * candidate
        return output_gate * torch.tanh(c_next), c_next


class ConvLSTM(nn.Module):
    def __init__(self, in_channels, hidden_channels,
                 kernel_size, stride=1, image_size=None):
        """ConvLSTM layer that unrolls one ``ConvLSTMCell`` over time.

        Parameters
        ----------
        in_channels: int
            Number of channels of input tensor.
        hidden_channels: int
            Number of channels of hidden state.
        kernel_size: int or (int, int)
            Size of the convolutional kernel.
        stride: int or (int, int)
            Stride of the convolution.
        image_size: (int, int)
            Shape of image.
        """

        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.image_size = image_size

        self.lstm_cell = ConvLSTMCell(
            self.in_channels, self.hidden_channels, self.kernel_size, self.stride, image_size=self.image_size)

    def forward(self, xs, hidden_state=None):
        """Run the cell over every time step of ``xs``.

        Parameters
        ----------
        xs: torch.Tensor
            5-D Tensor of shape (b, t, c, h, w).
        hidden_state: tuple or None
            Initial ``(h_0, c_0)`` pair; zeros of shape
            (b, hidden_channels, h, w) when ``None``.

        Returns
        -------
        output: torch.Tensor
            Hidden state of every step stacked on dim 1, i.e.
            (b, t, hidden_channels, h, w) when stride is 1.
        hidden_state: tuple
            The ``(h, c)`` pair after the last step.
        """

        batch_size, sequence_length, _, height, width = xs.size()

        if hidden_state is None:
            # new_zeros takes both device and dtype from xs, so the layer also
            # works after model.double() / model.half().
            state_shape = (batch_size, self.hidden_channels, height, width)
            hidden_state = (xs.new_zeros(state_shape), xs.new_zeros(state_shape))

        output_list = []
        for t in range(sequence_length):
            hidden_state = self.lstm_cell(xs[:, t, ...], hidden_state)
            h, _ = hidden_state
            output_list.append(h)

        output = torch.stack(output_list, dim=1)

        return output, hidden_state


class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style 'SAME' padding for a fixed input size.

    The zero padding is derived once in ``__init__`` from ``image_size`` and
    applied in ``forward`` right before the convolution.

    # Copyright: lukemelas (github username)
    # Released under the MIT License <https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/LICENSE>
    # <https://github.com/lukemelas/EfficientNet-PyTorch/blob/4d63a1f77eb51a58d6807a384dda076808ec02c0/efficientnet_pytorch/utils.py>
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

        # The padding depends on the input size, so it must be provided.
        assert image_size is not None
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        in_h, in_w = image_size
        k_h, k_w = self.weight.size()[-2:]

        def total_pad(in_len, k_len, step, dil):
            # Padding needed so the output length is ceil(in_len / step).
            out_len = math.ceil(in_len / step)
            return max((out_len - 1) * step + (k_len - 1) * dil + 1 - in_len, 0)

        pad_h = total_pad(in_h, k_h, self.stride[0], self.dilation[0])
        pad_w = total_pad(in_w, k_w, self.stride[1], self.dilation[1])

        if pad_h <= 0 and pad_w <= 0:
            self.static_padding = nn.Identity()
        else:
            # Odd totals put the extra pixel on the right / bottom side.
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))

    def forward(self, x):
        padded = self.static_padding(x)
        return F.conv2d(padded, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


In [None]:
import torch
import torch.nn as nn


class ConvLSTMEncoderPredictor(nn.Module):
    def __init__(self, image_size, in_channels=1, hidden_channels=64, kernel_size=3):
        """ConvLSTM Encoder Predictor.

        Three stacked ConvLSTM encoders read the input sequence.  Three
        predictor ConvLSTMs start from the encoders' final states and unroll
        for the same number of steps.  A 1x1 convolution maps their hidden
        features back to frame channels.

        Parameters
        ----------
        image_size: (int, int)
            Shape of image.
        in_channels: int
            Channels per input frame, which is also the channel count of each
            predicted frame (default 1).
        hidden_channels: int
            Hidden channels of every ConvLSTM layer (default 64).
        kernel_size: int or (int, int)
            Kernel size of every ConvLSTM layer (default 3).
        """

        super().__init__()

        def conv_lstm(source_channels):
            return ConvLSTM(in_channels=source_channels, hidden_channels=hidden_channels,
                            kernel_size=kernel_size, stride=1, image_size=image_size)

        # Attribute names and creation order are kept so saved state_dicts
        # still load and seeded initialisation is unchanged.
        self.encoder_1 = conv_lstm(in_channels)
        self.encoder_2 = conv_lstm(hidden_channels)
        self.encoder_3 = conv_lstm(hidden_channels)

        self.predictor_1 = conv_lstm(hidden_channels)
        self.predictor_2 = conv_lstm(hidden_channels)
        self.predictor_3 = conv_lstm(hidden_channels)

        # 1x1 convolution from hidden features back to frame channels.
        self.conv2d = nn.Conv2d(hidden_channels, in_channels, 1)

    def forward(self, x):
        """Encode ``x`` and predict as many future frames as it has steps.

        x: (b, t, in_channels, h, w) -> output: (b, t, in_channels, h, w).
        """
        x, hidden_state_1 = self.encoder_1(x)
        x, hidden_state_2 = self.encoder_2(x)
        x, hidden_state_3 = self.encoder_3(x)

        # The first predictor receives zeros as input; its information comes
        # solely from the encoders' final states.
        x, _ = self.predictor_1(torch.zeros_like(x), hidden_state_1)
        x, _ = self.predictor_2(x, hidden_state_2)
        x, _ = self.predictor_3(x, hidden_state_3)

        frames = [self.conv2d(x[:, t]) for t in range(x.shape[1])]
        return torch.stack(frames, 1)


## Train

In [None]:
import argparse
import datetime
import os
import time
import uuid
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset
from tqdm import tqdm as tqdm

import catalyst
from catalyst.callbacks.checkpoint import CheckpointCallback
from catalyst.callbacks.misc import EarlyStoppingCallback
from catalyst.dl import SupervisedRunner



def main(args=None):
    """Train ``ConvLSTMEncoderPredictor`` on Moving MNIST using Catalyst.

    Parameters
    ----------
    args: argparse.Namespace or None
        Options as produced by ``argument_paser``; parsed from the command
        line when ``None``.
    """
    if args is None:
        args = argument_paser()

    # Set experiment id
    exp_id = str(uuid.uuid4())[:8] if args.exp_id is None else args.exp_id
    print(f'Experiment Id: {exp_id}', flush=True)

    # Fix seed
    torch.manual_seed(args.seed)

    # Config gpu
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Prepare data; random_state ties the train/valid split to --seed as well.
    dataset = MovingMnistDataset()
    train_index, valid_index = train_test_split(
        range(len(dataset)), test_size=0.3, random_state=args.seed)
    train_loader = DataLoader(
        Subset(dataset, train_index), batch_size=args.batch_size, shuffle=True)
    valid_loader = DataLoader(
        Subset(dataset, valid_index), batch_size=args.test_batch_size, shuffle=False)
    loaders = {"train": train_loader, "valid": valid_loader}

    model = ConvLSTMEncoderPredictor(image_size=(64, 64)).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    criterion = nn.MSELoss()

    runner = SupervisedRunner(device=catalyst.utils.get_device())
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=None,
        loaders=loaders,
        # model will be saved to {logdir}/checkpoints
        logdir=os.path.join(args.log_dir, exp_id),
        callbacks=[CheckpointCallback(save_n_best=args.n_saved),
                   EarlyStoppingCallback(patience=args.es_patience,
                                         metric="loss",
                                         minimize=True,)],
        num_epochs=args.epochs,
        main_metric="loss",
        minimize_metric=True,
        fp16=None,
        verbose=True
    )


def argument_paser(argv=None):
    """Build the Catalyst training CLI and parse ``argv``.

    Parameters
    ----------
    argv: list of str or None
        Arguments to parse. ``None`` falls back to argparse's default
        (``sys.argv[1:]``), which inside a notebook kernel may hold unrelated
        flags, so pass an explicit list there.

    Returns
    -------
    argparse.Namespace
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=16, metavar='N',
                        help='input batch size for training (default: 16)')
    parser.add_argument('--test-batch-size', type=int, default=16,
                        metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    # Unlike the Ignite variant, Catalyst needs no --save-model-path or
    # --log-interval: checkpoints live under {log_dir}/{exp_id}.
    parser.add_argument('--n-saved', type=int, default=1,
                        help='For Saving the current Model (default: 1)')
    parser.add_argument('--log-dir', type=str, default='./logs',
                        help='path to snapshot file (default: ./logs)')
    parser.add_argument('--es-patience', type=int, default=10,
                        help='Early stop patience (default: 10)')
    parser.add_argument('--exp-id', type=str, default=None,
                        help='experiment id')
    return parser.parse_args(argv)

In [None]:
import argparse
import datetime
import os
import time
import uuid

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset
from tqdm import tqdm as tqdm

from ignite.contrib.handlers import ProgressBar
from ignite.engine import (Events, create_supervised_evaluator,
                           create_supervised_trainer)
from ignite.handlers import EarlyStopping, ModelCheckpoint, Timer
from ignite.metrics import Accuracy, Loss, RunningAverage



def write_metrics(metrics, writer, timer, mode: str, epoch: int):
    """Print the epoch's average ``'mse'`` and, if a writer is given, log it.

    ``timer.value()`` supplies the elapsed time shown in the message; the
    scalar is written under ``"{mode}/avg_loss"`` at step ``epoch``.
    """
    loss_value = metrics['mse']
    message = (f"{mode} Results - Epoch: {epoch} -- Avg loss: {loss_value:.5f} "
               f"-- Elapsed time: {timer.value():.2f}")
    print(message)
    if writer is None:
        return
    writer.add_scalar(f"{mode}/avg_loss", loss_value, epoch)


def score_function(engine):
    """Score for EarlyStopping: the negated ``'mse'`` metric, so higher is better."""
    return -engine.state.metrics['mse']


def _epoch(engine, event_name):
    """``global_step_transform`` for ModelCheckpoint: the engine's current epoch.

    ``event_name`` is accepted to match the callback signature but is unused.
    """
    state = engine.state
    return state.epoch


def run(exp_id, epochs, model, criterion, optimizer, scheduler,
        train_loader, valid_loader, device, writer, log_interval,
        n_saved, save_dir, es_patience):
    """Train ``model`` with Ignite: per-epoch evaluation, checkpoints, early stopping.

    Parameters
    ----------
    exp_id: str
        Experiment id; checkpoints and optimizer state go to ``{save_dir}/{exp_id}``.
    epochs: int
        Maximum number of epochs.
    model, criterion, optimizer:
        Passed to Ignite's supervised trainer / evaluators.
    scheduler:
        Optional LR scheduler stepped once per epoch (may be ``None``).
    train_loader, valid_loader:
        Training and validation data loaders.
    device:
        Device passed to the Ignite trainer / evaluators.
    writer:
        Optional writer with ``add_scalar`` (may be ``None``).
    log_interval: int
        Print the iteration loss every ``log_interval`` iterations of an
        epoch; 0 disables it.
    n_saved: int
        Number of model checkpoints kept on disk.
    save_dir: str
        Root directory for checkpoints and optimizer / scheduler state.
    es_patience: int
        Early-stopping patience, counted in validation evaluations.
    """

    # check parameters
    assert exp_id is not None
    assert model is not None
    assert criterion is not None
    assert optimizer is not None
    assert train_loader is not None
    assert valid_loader is not None
    assert device is not None
    assert save_dir is not None

    run_dir = os.path.join(save_dir, exp_id)
    # Ensure the directory exists before the per-epoch torch.save calls below.
    os.makedirs(run_dir, exist_ok=True)

    trainer = create_supervised_trainer(
        model, optimizer, criterion, device=device)
    # One evaluator per split: EarlyStopping is attached to the validation
    # evaluator only, so training-set scores cannot influence its patience.
    train_evaluator = create_supervised_evaluator(
        model,
        metrics={'mse': Loss(criterion)},
        device=device
    )
    valid_evaluator = create_supervised_evaluator(
        model,
        metrics={'mse': Loss(criterion)},
        device=device
    )

    # # Timer
    timer = Timer(average=False)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 pause=Events.EPOCH_COMPLETED,
                 resume=Events.EPOCH_STARTED,
                 step=Events.EPOCH_COMPLETED)

    # Running average of the per-iteration training loss, shown in the progress bar.
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names='all')

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        if log_interval > 0:
            # Iteration index within the current epoch (1-based).
            i = (engine.state.iteration - 1) % len(train_loader) + 1
            if i % log_interval == 0:
                print(f"Epoch[{engine.state.epoch}] -- Iteration[{i}/{len(train_loader)}] -- "
                      f"Loss: {engine.state.output:.5f} -- Elapsed time: {timer.value():.2f}")
                if writer is not None:
                    writer.add_scalar("training/loss", engine.state.output,
                                      engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        write_metrics(metrics, writer, timer, 'Training', engine.state.epoch)

        if scheduler is not None:
            scheduler.step()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        valid_evaluator.run(valid_loader)
        metrics = valid_evaluator.state.metrics
        write_metrics(metrics, writer, timer, 'Validation', engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_optimizer(engine):
        # Save optimizer
        optimizer_file_name = os.path.join(run_dir, 'optimizer.pth')
        torch.save(optimizer.state_dict(), optimizer_file_name)
        print("Save optimizer :", optimizer_file_name)

        # Save scheduler
        if scheduler is not None:
            scheduler_file_name = os.path.join(run_dir, 'scheduler.pth')
            torch.save(scheduler.state_dict(), scheduler_file_name)
            print("Save scheduler :", scheduler_file_name)

    # # Checkpoint setting
    # {save_dir}/{exp_id}/best_mymodel_{engine.state.epoch}
    # Keep at most n_saved models. NOTE(review): no score_function is given, so
    # these are the most recent epochs rather than the best by validation loss,
    # despite the 'best' prefix.
    checkpoint_handler = ModelCheckpoint(dirname=run_dir, filename_prefix='best',
                                         n_saved=n_saved, create_dir=True,
                                         global_step_transform=_epoch)
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              checkpoint_handler, {'mymodel': model})

    # # Early stopping
    early_stopping = EarlyStopping(
        patience=es_patience, score_function=score_function, trainer=trainer)
    # Attached to the validation evaluator, which runs once per epoch.
    valid_evaluator.add_event_handler(Events.COMPLETED, early_stopping)

    trainer.run(train_loader, max_epochs=epochs)


def argument_paser(argv=None):
    """Build the Ignite training CLI and parse ``argv``.

    Parameters
    ----------
    argv: list of str or None
        Arguments to parse. ``None`` falls back to argparse's default
        (``sys.argv[1:]``), which inside a notebook kernel may hold unrelated
        flags, so pass an explicit list there.

    Returns
    -------
    argparse.Namespace
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=16, metavar='N',
                        help='input batch size for training (default: 16)')
    parser.add_argument('--test-batch-size', type=int, default=16,
                        metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model-path', type=str, default='./checkpoints',
                        help='For Saving the current Model (default: ./checkpoints)')
    parser.add_argument('--n-saved', type=int, default=1,
                        help='For Saving the current Model (default: 1)')
    parser.add_argument('--log-interval', type=int, default=0,
                        help='logging interval (default: 0)')
    parser.add_argument('--log-dir', type=str, default='./logs',
                        help='path to snapshot file (default: ./logs)')
    parser.add_argument('--es-patience', type=int, default=10,
                        help='Early stop patience (default: 10)')
    parser.add_argument('--exp-id', type=str, default=None,
                        help='experiment id')
    return parser.parse_args(argv)


def main(args=None):
    """Train ``ConvLSTMEncoderPredictor`` on Moving MNIST with the Ignite loop in ``run``.

    Parameters
    ----------
    args: argparse.Namespace or None
        Options as produced by ``argument_paser``; parsed from the command
        line when ``None``.

    Returns
    -------
    (str, nn.Module)
        The experiment id and the trained model.
    """
    if args is None:
        args = argument_paser()

    # Set experiment id
    exp_id = str(uuid.uuid4())[:8] if args.exp_id is None else args.exp_id
    print(f'Experiment Id: {exp_id}', flush=True)

    # Fix seed
    torch.manual_seed(args.seed)

    # Set logger (no writer when log_dir is None)
    log_writer = SummaryWriter(log_dir=os.path.join(
        args.log_dir, exp_id)) if args.log_dir is not None else None

    # Prepare data; random_state ties the train/valid split to --seed as well.
    dataset = MovingMnistDataset()
    train_index, valid_index = train_test_split(
        range(len(dataset)), test_size=0.3, random_state=args.seed)
    train_loader = DataLoader(
        Subset(dataset, train_index), batch_size=args.batch_size, shuffle=True)
    valid_loader = DataLoader(
        Subset(dataset, valid_index), batch_size=args.test_batch_size, shuffle=False)

    # Prepare model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = ConvLSTMEncoderPredictor(image_size=(64, 64)).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    criterion = nn.MSELoss()

    try:
        run(
            exp_id=exp_id,
            epochs=args.epochs,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=None,
            train_loader=train_loader,
            valid_loader=valid_loader,
            device=device,
            writer=log_writer,
            log_interval=args.log_interval,
            n_saved=args.n_saved,
            save_dir=args.save_model_path,
            es_patience=args.es_patience
        )
    finally:
        # Close the writer even if training raises; it is None when logging is off.
        if log_writer is not None:
            log_writer.close()

    return exp_id, model

In [None]:
# Notebook arguments: parse an explicit list instead of sys.argv, which inside
# a kernel may hold the kernel's own flags. Batch sizes default to 4 here.
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                    help='input batch size for training (default: 4)')
parser.add_argument('--test-batch-size', type=int, default=4,
                    metavar='N',
                    help='input batch size for testing (default: 4)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--save-model-path', type=str, default='./checkpoints',
                    help='For Saving the current Model (default: ./checkpoints)')
parser.add_argument('--n-saved', type=int, default=1,
                    help='For Saving the current Model (default: 1)')
parser.add_argument('--log-interval', type=int, default=0,
                    help='logging interval (default: 0)')
parser.add_argument('--log-dir', type=str, default='./logs',
                    help='path to snapshot file (default: ./logs)')
parser.add_argument('--es-patience', type=int, default=10,
                    help='Early stop patience (default: 10)')
parser.add_argument('--exp-id', type=str, default=None,
                    help='experiment id')
args = parser.parse_args(args=['--epochs', '3'])

In [None]:
# Train with the notebook arguments parsed above; main returns the experiment
# id (used to locate logs/checkpoints) and the trained model.
exp_id, model = main(args)

In [None]:
# NOTE(review): the experiment id below is hard-coded from an earlier run, and
# nothing in the Ignite training path above writes a log.csv (it logs through a
# TensorBoard SummaryWriter), so this file may not exist — consider
# `!ls ./logs/{exp_id}` instead.
!cat ./logs/2528be32/log.csv

In [None]:
import gc

# Release memory held by unreachable objects (e.g. after training).
gc.collect()

In [None]:
# Load the TensorBoard notebook extension and display the runs logged under ./logs.
%load_ext tensorboard
%tensorboard --logdir ./logs

## Test

In [None]:
import argparse

import numpy as np
import torch
from PIL import Image



def inference(args=None):
    """Predict the future frames of one test sequence.

    Parameters
    ----------
    args: argparse.Namespace or None
        Needs ``model_path`` (checkpoint to load, or ``None`` to keep the
        random initial weights) and ``id`` (index into the test split).
        Parsed from the command line when ``None``.

    Returns
    -------
    numpy.ndarray
        Model output for the single sequence, with a leading batch axis of 1.
    """
    if args is None:
        parser = argparse.ArgumentParser()
        parser.add_argument('--model-path', '-m', type=str, default=None)
        parser.add_argument('--id', '-i', type=int, default=0)
        args = parser.parse_args()

    test = MovingMnistDataset(phase_train=False)

    model = ConvLSTMEncoderPredictor(image_size=(64, 64))

    if args.model_path is not None:
        print("loading model from " + args.model_path)
        # map_location lets checkpoints saved from a GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
    model.eval()

    # The dataset already yields float32 arrays; the target frames are not needed.
    data, _ = test[args.id]
    # Add a batch axis of size 1.
    batch = torch.from_numpy(np.expand_dims(data, 0))

    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        prediction = model(batch)
    return prediction.cpu().numpy().copy()

In [None]:
!ls ./checkpoints/18d8b87a/

In [None]:
import glob
import re

# Use the latest-epoch checkpoint of the run trained above instead of a
# hard-coded experiment id / epoch (files are named best_mymodel_<epoch>.pt).
checkpoints = glob.glob(f'./checkpoints/{exp_id}/best_mymodel_*.pt')
if not checkpoints:
    raise FileNotFoundError(f'no checkpoints found in ./checkpoints/{exp_id}/')
latest_checkpoint = max(
    checkpoints, key=lambda p: int(re.search(r'_(\d+)\.pt$', p).group(1)))

parser = argparse.ArgumentParser()
parser.add_argument('--model-path', '-m', type=str, default=None)
parser.add_argument('--id', '-i', type=int, default=0)
args = parser.parse_args(args=['--model-path', latest_checkpoint, '--id', '0'])

In [None]:
# Predict the future frames of test sequence args.id with the selected checkpoint.
result = inference(args)

In [None]:
# Drop the leading batch axis of size 1 added in inference: (1, t, c, h, w) -> (t, c, h, w).
result = result.reshape(result.shape[1:])

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from ipywidgets import interact

def f(k):
    plt.imshow(result[k][0], 'gray')
    plt.show()

interact(f, k=(0,9,1) )