In [None]:
import numpy as np
from torch.utils.data import Dataset, DataLoader


class MovingMnistDataset(Dataset):
    """Moving MNIST sequences served as (input frames, target frames) pairs.

    The ``.npy`` file is stored as ``(t, N, H, W)``. It is reordered to
    ``(N, t, C, H, W)`` with a singleton channel axis. Each item is a pair of
    float32 arrays scaled by 1/255: the first ``n_input_frames`` frames
    (model input) and the remaining frames (prediction target).
    """

    def __init__(self, path="./mnist_test_seq.npy", n_input_frames=10):
        """
        Parameters
        ----------
        path: str
            Path to the ``.npy`` file. It is loaded fully into memory with ``np.load``.
        n_input_frames: int
            Number of leading frames returned as model input; the rest form the
            target. Defaults to 10.

        Raises
        ------
        ValueError
            If ``n_input_frames`` does not leave at least one frame for both the
            input and the target.
        """
        raw = np.load(path)
        # (t, N, H, W) -> (N, t, C, H, W)
        self.data = raw.transpose(1, 0, 2, 3)[:, :, None, ...]

        sequence_length = self.data.shape[1]
        if not 0 < n_input_frames < sequence_length:
            raise ValueError(
                f"n_input_frames must be in (0, {sequence_length}), got {n_input_frames}")
        self.n_input_frames = n_input_frames

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        sequence = self.data[i]
        split = self.n_input_frames
        return self._to_unit_float(sequence[:split]), self._to_unit_float(sequence[split:])

    @staticmethod
    def _to_unit_float(frames):
        # Divide first (float64), then cast, so values match the original rounding.
        return (frames / 255).astype(np.float32)

In [None]:
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import train_test_split


dataset = MovingMnistDataset()

# Fix the split seed so the train/valid partition is identical across kernel restarts.
SPLIT_SEED = 0
train_index, valid_index = train_test_split(
    range(len(dataset)), test_size=0.3, random_state=SPLIT_SEED)
assert len(train_index) + len(valid_index) == len(dataset)

batch_size = 16

train_dataset = Subset(dataset, train_index)
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
valid_dataset = Subset(dataset, valid_index)
valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)

In [None]:
"""
Copyright (c) 2020 Masafumi Abeta. All Rights Reserved.
Released under the MIT license
"""
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair


class ConvLSTMCell(nn.Module):

    def __init__(self, in_channels, hidden_channels,
                 kernel_size, stride=1, image_size=None):
        """ConvLSTM cell.

        Parameters
        ----------
        in_channels: int
            Number of channels of input tensor.
        hidden_channels: int
            Number of channels of hidden state.
        kernel_size: int or (int, int)
            Size of the convolutional kernel.
        stride: int or (int, int)
            Stride of the input-to-state convolutions. The state-to-state
            convolutions always use stride 1, so the hidden/cell state keeps a
            constant spatial size of ceil(image_size / stride) across steps.
        image_size: int or (int, int)
            Spatial size (H, W) of the input tensor. Required, because it is
            used to precompute 'SAME' padding.
        """

        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)

        # The strided input convs produce the state at ceil(H / stride). The
        # state-to-state convs must therefore run at stride 1 on that size.
        # Reusing the input stride would shrink h_prev again on every step and
        # make the gate sums and f * c_prev shape-incompatible.
        if image_size is None:
            state_size = None  # Conv2dStaticSamePadding rejects this
        else:
            in_h, in_w = _pair(image_size)
            state_size = (math.ceil(in_h / self.stride[0]),
                          math.ceil(in_w / self.stride[1]))

        def input_conv():
            # Input-to-state conv; it carries the gate bias.
            return Conv2dStaticSamePadding(
                self.in_channels, self.hidden_channels, self.kernel_size, self.stride,
                image_size=image_size)

        def hidden_conv():
            # State-to-state conv; no bias, since it is already in input_conv.
            return Conv2dStaticSamePadding(
                self.hidden_channels, self.hidden_channels, self.kernel_size, 1,
                image_size=state_size, bias=False)

        # Same registration order as before, so parameter order and
        # state_dict keys are unchanged.
        self.Wxi = input_conv()
        self.Whi = hidden_conv()
        self.Wxf = input_conv()
        self.Whf = hidden_conv()
        self.Wxg = input_conv()
        self.Whg = hidden_conv()
        self.Wxo = input_conv()
        self.Who = hidden_conv()

    def forward(self, x, hidden_state):
        """
        Parameters
        ----------
        x: torch.Tensor
            4-D Tensor of shape (b, c, h, w).
        hidden_state: tuple
            Previous state (h_prev, c_prev), each of shape
            (b, hidden_channels, ceil(h / stride[0]), ceil(w / stride[1])).

        Returns
        -------
            h_next, c_next
        """

        h_prev, c_prev = hidden_state
        i = torch.sigmoid(self.Wxi(x) + self.Whi(h_prev))
        f = torch.sigmoid(self.Wxf(x) + self.Whf(h_prev))
        o = torch.sigmoid(self.Wxo(x) + self.Who(h_prev))
        g = torch.tanh(self.Wxg(x) + self.Whg(h_prev))

        c_next = f * c_prev + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next


class ConvLSTM(nn.Module):
    def __init__(self, in_channels, hidden_channels,
                 kernel_size, stride=1, image_size=None):
        """ConvLSTM layer that unrolls a single ConvLSTMCell over time.

        Parameters
        ----------
        in_channels: int
            Number of channels of input tensor.
        hidden_channels: int
            Number of channels of hidden state.
        kernel_size: int or (int, int)
            Size of the convolutional kernel.
        stride: int or (int, int)
            Stride of the input-to-state convolution.
        image_size: int or (int, int)
            Spatial size (H, W) of each input frame.
        """

        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.image_size = image_size

        self.lstm_cell = ConvLSTMCell(
            self.in_channels, self.hidden_channels, self.kernel_size, self.stride, image_size=self.image_size)

    def forward(self, xs, hidden_state=None):
        """
        Parameters
        ----------
        xs: torch.Tensor
            5-D Tensor of shape (b, t, c, h, w).
        hidden_state: tuple of torch.Tensor, optional
            Initial (h_0, c_0), each of shape (b, hidden_channels, h', w'),
            where h' = ceil(h / stride[0]) and w' = ceil(w / stride[1]).
            Zeros with the dtype and device of ``xs`` are used when omitted.

        Returns
        -------
        output: torch.Tensor
            Hidden state of every time step stacked on dim 1:
            (b, t, hidden_channels, h', w').
        hidden_state: tuple
            (h_t, c_t) after the last time step.
        """

        batch_size, sequence_length, _, height, width = xs.size()

        if hidden_state is None:
            # State lives at the 'SAME' strided output size; equals (h, w) for stride 1.
            state_h = math.ceil(height / self.stride[0])
            state_w = math.ceil(width / self.stride[1])
            shape = (batch_size, self.hidden_channels, state_h, state_w)
            hidden_state = (torch.zeros(shape, dtype=xs.dtype, device=xs.device),
                            torch.zeros(shape, dtype=xs.dtype, device=xs.device))

        output_list = []
        for t in range(sequence_length):
            hidden_state = self.lstm_cell(xs[:, t, ...], hidden_state)
            h, _ = hidden_state
            output_list.append(h)

        output = torch.stack(output_list, dim=1)

        return output, hidden_state


class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution emulating TensorFlow's 'SAME' padding for a fixed input size.
       The zero padding is computed once in the constructor from ``image_size``
       and applied in ``forward`` before the convolution.

        # Copyright: lukemelas (github username)
        # Released under the MIT License <https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/LICENSE>
        # <https://github.com/lukemelas/EfficientNet-PyTorch/blob/4d63a1f77eb51a58d6807a384dda076808ec02c0/efficientnet_pytorch/utils.py>
    """

    # Same padding arithmetic as Conv2dDynamicSamePadding, done once up front.
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        # nn.Conv2d already stores a 2-tuple; this only guards a 1-element stride.
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

        assert image_size is not None
        if isinstance(image_size, int):
            in_h = in_w = image_size
        else:
            in_h, in_w = image_size
        k_h, k_w = self.weight.size()[-2:]
        s_h, s_w = self.stride

        def total_pad(size, kernel, step, dil):
            # Padding needed so that the output size is ceil(size / step).
            out = math.ceil(size / step)
            return max((out - 1) * step + (kernel - 1) * dil + 1 - size, 0)

        pad_h = total_pad(in_h, k_h, s_h, self.dilation[0])
        pad_w = total_pad(in_w, k_w, s_w, self.dilation[1])

        if pad_h <= 0 and pad_w <= 0:
            self.static_padding = nn.Identity()
        else:
            # Any odd leftover goes to the right/bottom side, as in TensorFlow.
            left, top = pad_w // 2, pad_h // 2
            self.static_padding = nn.ZeroPad2d((left, pad_w - left, top, pad_h - top))

    def forward(self, x):
        padded = self.static_padding(x)
        return F.conv2d(padded, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


In [None]:
import torch
import torch.nn as nn


class ConvLSTMEncoderPredictor(nn.Module):
    def __init__(self, image_size, in_channels=1, hidden_channels=32, kernel_size=3):
        """ConvLSTM Encoder Predictor.

        Three stacked ConvLSTM encoders read the input sequence. Their final
        (h, c) states initialise three stacked ConvLSTM predictors, which are
        unrolled on an all-zero input for the same number of steps. A 1x1
        convolution maps the last predictor's hidden channels back to
        ``in_channels``.

        Parameters
        ----------
        image_size: (int, int)
            Shape of image.
        in_channels: int
            Channels per input frame; also the number of output channels. Default 1.
        hidden_channels: int
            Hidden channels of every ConvLSTM layer. Default 32.
        kernel_size: int or (int, int)
            Kernel size of every ConvLSTM layer. Default 3.
        """

        super().__init__()

        def conv_lstm(layer_in_channels):
            return ConvLSTM(
                in_channels=layer_in_channels, hidden_channels=hidden_channels,
                kernel_size=kernel_size, stride=1, image_size=image_size)

        # Registration order is unchanged, so state_dict keys and init order are too.
        self.encoder_1 = conv_lstm(in_channels)
        self.encoder_2 = conv_lstm(hidden_channels)
        self.encoder_3 = conv_lstm(hidden_channels)

        self.predictor_1 = conv_lstm(hidden_channels)
        self.predictor_2 = conv_lstm(hidden_channels)
        self.predictor_3 = conv_lstm(hidden_channels)

        self.conv2d = nn.Conv2d(hidden_channels, in_channels, 1)

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            5-D Tensor of shape (b, t, in_channels, h, w).

        Returns
        -------
        torch.Tensor
            Predicted frames of shape (b, t, in_channels, h, w).
        """
        x, hidden_state_1 = self.encoder_1(x)
        x, hidden_state_2 = self.encoder_2(x)
        x, hidden_state_3 = self.encoder_3(x)

        # The predictors see no observations; all context comes from the encoder states.
        x, _ = self.predictor_1(torch.zeros_like(x), hidden_state_1)
        x, _ = self.predictor_2(x, hidden_state_2)
        x, _ = self.predictor_3(x, hidden_state_3)

        # Apply the 1x1 conv to every time step at once by folding time into batch.
        b, t, c, h, w = x.shape
        frames = self.conv2d(x.reshape(b * t, c, h, w))
        return frames.reshape(b, t, *frames.shape[1:])

In [None]:
def train(net, train_loader, optimizer, loss_fn, device):
    """Run one training epoch and return the mean per-batch loss.

    Parameters
    ----------
    net: torch.nn.Module
        Model to train; it is switched to train mode.
    train_loader: iterable of (input, target) batches
        Must support ``len`` (used for the progress bar total).
    optimizer: torch.optim.Optimizer
        Optimizer stepping ``net``'s parameters.
    loss_fn: callable
        Loss applied to the flattened prediction and flattened target.
    device: torch.device
        Device the batches are moved to.

    Returns
    -------
    float
        Average of ``loss.item()`` over all batches.

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batches.
    """
    net.train()

    running_loss = 0.0
    n_batches = 0
    for xx, yy in tqdm(train_loader, total=len(train_loader)):
        xx = xx.to(device)
        # Flatten both sides so loss_fn compares them element-wise.
        yy = yy.to(device).reshape(-1)

        y_pred = net(xx).reshape(-1)
        loss = loss_fn(y_pred, yy)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches; cannot compute a mean loss")
    return running_loss / n_batches


def eval_net(net, valid_loader, loss_fn, device):
    """Evaluate ``net`` on ``valid_loader`` and return the mean per-batch loss.

    Parameters
    ----------
    net: torch.nn.Module
        Model to evaluate; it is switched to eval mode.
    valid_loader: iterable of (input, target) batches
        Must support ``len`` (used for the progress bar total).
    loss_fn: callable
        Loss applied to the flattened prediction and flattened target.
    device: torch.device
        Device the batches are moved to.

    Returns
    -------
    float
        Average of the per-batch loss values.

    Raises
    ------
    ValueError
        If ``valid_loader`` yields no batches. Previously this failed with an
        unbound loop variable.
    """
    net.eval()
    score = 0.0
    n_batches = 0

    with torch.no_grad():
        for x, y in tqdm(valid_loader, total=len(valid_loader)):
            x = x.to(device)
            y = y.to(device).reshape(-1)
            y_pred = net(x).reshape(-1)
            score += loss_fn(y_pred, y).item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("valid_loader yielded no batches; cannot compute a mean loss")
    return score / n_batches

In [None]:
def save(epoch, net, optimizer, scheduler, train_losses, valid_losses, elapsed_time, save_model_path, save_log_path):
    """Checkpoint model/optimizer/scheduler state and append a row to ``log.csv``.

    Parameters
    ----------
    epoch: int
        Epoch number written to the log.
    net, optimizer: objects with ``state_dict()``
        Saved as ``weight_<timestamp>.pth`` / ``optimizer_<timestamp>.pth``.
    scheduler: object with ``state_dict()`` or None
        Saved as ``scheduler_<timestamp>.pth`` when not None.
    train_losses, valid_losses: sequence of float
        Only the last element of each is logged.
    elapsed_time: float
        Epoch duration written to the log.
    save_model_path: str or None
        Directory for checkpoints (created if missing); skipped when None.
    save_log_path: str or None
        Directory containing ``log.csv`` (created if missing); skipped when None.
    """
    # These are not imported anywhere else in the notebook.
    import datetime
    import os

    import pandas as pd

    now = datetime.datetime.today().strftime("%Y%m%d%H%M%S")
    log_columns = ["datetime", "epoch", "train_loss", "valid_loss", "elapsed_time"]

    if save_model_path is not None:
        os.makedirs(save_model_path, exist_ok=True)
        weight_path = os.path.join(save_model_path, f"weight_{now}.pth")
        torch.save(net.state_dict(), weight_path)
        print("Save model : " + weight_path)
        optimizer_path = os.path.join(save_model_path, f"optimizer_{now}.pth")
        torch.save(optimizer.state_dict(), optimizer_path)
        if scheduler is not None:
            torch.save(scheduler.state_dict(), os.path.join(save_model_path, f"scheduler_{now}.pth"))
        print("Save optimizer : " + optimizer_path)

    if save_log_path is not None:
        os.makedirs(save_log_path, exist_ok=True)
        log_path = os.path.join(save_log_path, "log.csv")
        new_row = pd.DataFrame(
            [[now, epoch, train_losses[-1], valid_losses[-1], elapsed_time]], columns=log_columns)
        if os.path.exists(log_path):
            # Load the log from earlier runs. Keep the timestamp as a string
            # rather than letting pandas parse it as an int.
            past_log = pd.read_csv(log_path, dtype={"datetime": str})
            log_df = pd.concat([past_log, new_row], ignore_index=True)
        else:
            log_df = new_row
        log_df.to_csv(log_path, index=False)
        print("Save log : " + log_path)

In [None]:
import time
# tqdm_notebook is deprecated; tqdm.auto picks the notebook widget when available.
from tqdm.auto import tqdm


# Seed torch so weight initialisation and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

train_losses = []
valid_losses = []

n_iter = 5  # number of training epochs
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = ConvLSTMEncoderPredictor(image_size=(64, 64)).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=0.005, betas=(0.9, 0.999))
loss_fn = nn.MSELoss()

for epoch in range(n_iter):
    start = time.time()

    train_score = train(net, train_dataloader, optimizer, loss_fn, device)
    train_losses.append(train_score)

    valid_score = eval_net(net, valid_dataloader, loss_fn, device)
    valid_losses.append(valid_score)

    elapsed_time = time.time() - start
    print(f"epoch:{epoch}",
          "--", "train loss:{:.5f}".format(train_losses[-1]),
          "--", "valid loss:{:.5f}".format(valid_losses[-1]),
          "--", "elapsed_time:{:.2f}".format(elapsed_time) + "[sec]", flush=True)

    # Append this epoch to ./log.csv every epoch (no checkpoint, no scheduler here).
    save(epoch, net, optimizer, None, train_losses, valid_losses, elapsed_time,
         save_model_path=None, save_log_path='.')