### Step 1: Implementation

$$
\mathcal{M}_{c,i,j} = \sum_{m=1}^{H} \sum_{n=1}^{W} \mathcal{I}_{c,m,n} \max\left(0, 1 - \left|i + \mathbf{V}_{i,j} - m\right|\right) \max\left(0, 1 - \left|j + \mathbf{U}_{i,j} - n\right|\right)
$$
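
As a quick sanity check of the formula (a worked example, not taken from the paper): for integer $i$ and $m$, the factor $\max(0, 1 - |i - m|)$ equals 1 when $m = i$ and 0 otherwise, so a zero flow returns the input unchanged. Assuming a fractional displacement $\mathbf{V}_{i,j} = 0.25$, $\mathbf{U}_{i,j} = 0$ and $i + 1 \le H$, only the rows $m = i$ and $m = i + 1$ contribute:

$$
\mathcal{M}_{c,i,j} = (1 - 0.25)\,\mathcal{I}_{c,i,j} + (1 - 0.75)\,\mathcal{I}_{c,i+1,j} = 0.75\,\mathcal{I}_{c,i,j} + 0.25\,\mathcal{I}_{c,i+1,j}
$$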

The author uses: http://beta.mxnet.io/r/api/mx.symbol.GridGenerator.html

Corresponding code:

```python
def flow_conv(data, num_filter, flows, weight, bias, name):
    assert isinstance(flows, list)
    warpped_data = []
    for i in range(len(flows)):
        flow = flows[i]
        grid = mx.sym.GridGenerator(data=-flow, transform_type="warp")
        ele_dat = mx.sym.BilinearSampler(data=data, grid=grid)
        warpped_data.append(ele_dat)
    data = mx.sym.concat(*warpped_data, dim=1)
    ret = mx.sym.Convolution(data=data,
                             num_filter=num_filter,
                             kernel=(1, 1),
                             weight=weight,
                             bias=bias,
                             name=name)
    return ret
```

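The 1x1 convolution corresponds to the three neighborhood convolutions in the paper's formulas. `flows` is a list; each element has shape (batch, 2, h, w) and corresponds to $U_l$, $V_l$.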
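
To make the data flow of `flow_conv` concrete, here is a minimal NumPy sketch of the same three steps: warp once per flow, concatenate along the channel axis, then apply a 1x1 convolution. The function name `flow_conv_sketch` and the `warp_fn` argument are hypothetical; the warp itself is passed in as a callable rather than implemented here.

```python
import numpy as np

def flow_conv_sketch(data, flows, weight, bias, warp_fn):
    """data: (B, C, H, W); flows: list of L arrays of shape (B, 2, H, W);
    weight: (num_filter, L*C); bias: (num_filter,).
    warp_fn(data, flow) is an assumed callable returning a (B, C, H, W) array."""
    # One warped copy of the data per flow field (the MXNet code negates the flow).
    warped = [warp_fn(data, -flow) for flow in flows]
    # Concatenate along channels: (B, L*C, H, W).
    stacked = np.concatenate(warped, axis=1)
    # A 1x1 convolution is a per-pixel linear map over the channel axis.
    out = np.einsum('oc,bchw->bohw', weight, stacked)
    return out + bias[None, :, None, None]

# Shape check with an identity "warp" as a placeholder.
rng = np.random.default_rng(0)
data = rng.standard_normal((2, 4, 5, 6))
flows = [rng.standard_normal((2, 2, 5, 6)) for _ in range(3)]
weight = rng.standard_normal((8, 3 * 4))
bias = rng.standard_normal(8)
out = flow_conv_sketch(data, flows, weight, bias, lambda d, f: d)
print(out.shape)  # (2, 8, 5, 6)
```

With L flows and C input channels, the 1x1 convolution sees L*C input channels, which is why the number of flows directly sets the input width of the final convolution.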

First, a naive implementation:

In [1]:
import torch
import sys
sys.path.insert(0, '../')
%reload_ext autoreload
%autoreload 2
import logging
from nowcasting.hko.dataloader import HKOIterator
from nowcasting.config import cfg
import numpy as np

Test data

In [2]:
channels = 3
h, w = 5, 6
batch_size = 1
U, V = torch.randn(batch_size, h, w), torch.randn(batch_size, h, w)
input = torch.randn(batch_size, channels,  h, w)

In [3]:
M = torch.zeros_like(input)

The optical flow must lie in [-1, 1], so this part differs from mxnet.

In [4]:
for c in range(channels):
    for i in range(h):
        for j in range(w):
            for m in range(h):
                for n in range(w):
                    t1 = 1 - torch.abs(i+V[:, i, j]-m)
                    t2 = 1 - torch.abs(j+U[:, i, j]-n)
                    M[:, c, i, j] += input[:, c, m, n]*torch.max(t1, torch.zeros_like(t1))*torch.max(t2, torch.zeros_like(t2))
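
The five nested loops above are slow even for tiny inputs. As a sketch, the same sum can be written with broadcasting and `einsum`; this version uses NumPy and the hypothetical name `warp_by_formula`, and it is meant only as a reference for checking the formula, not as the author's implementation.

```python
import numpy as np

def warp_by_formula(inp, U, V):
    """inp: (B, C, H, W); U, V: (B, H, W). Returns M with the shape of inp."""
    B, C, H, W = inp.shape
    rows = np.arange(H, dtype=inp.dtype)
    cols = np.arange(W, dtype=inp.dtype)
    # Source coordinates i + V[i, j] and j + U[i, j] for every target pixel.
    src_y = rows[None, :, None] + V                                # (B, H, W)
    src_x = cols[None, None, :] + U                                # (B, H, W)
    # Bilinear weights max(0, 1 - |i + V - m|) and max(0, 1 - |j + U - n|).
    wy = np.maximum(0.0, 1.0 - np.abs(src_y[..., None] - rows))    # (B, H, W, H)
    wx = np.maximum(0.0, 1.0 - np.abs(src_x[..., None] - cols))    # (B, H, W, W)
    # Sum over source rows m and source columns n.
    return np.einsum('bcmn,bijm,bijn->bcij', inp, wy, wx)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 5, 6))
zero = np.zeros((2, 5, 6))
print(np.allclose(warp_by_formula(x, zero, zero), x))  # True: zero flow is the identity
```

The weight tensors grow as H*W*(H+W), so this is only practical for small test inputs like the 5 x 6 example here.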

mxnet implementation

In [5]:
import mxnet as mx
mx_flow = torch.stack([V, U], dim=1)  # (batch_size, 2, h, w); stacking also stays correct for batch_size > 1
mx_flow = mx.nd.array(mx_flow.numpy())
mx_input = mx.nd.array(input.numpy())

In [6]:
mxnet_grid = mx.ndarray.GridGenerator(data=mx_flow, transform_type="warp")

In [7]:
mxnet_output = mx.ndarray.BilinearSampler(data=mx_input, grid=mxnet_grid)

In [8]:
import numpy as np

PyTorch implementation

input: [B, C, H, W]

flow: [B, 2, H, W]

In [9]:
B, C, H, W = input.size()
# mesh grid 
xx = torch.arange(0, W).view(1,-1).repeat(H,1)
yy = torch.arange(0, H).view(-1,1).repeat(1,W)
xx = xx.view(1,1,H,W).repeat(B,1,1,1)
yy = yy.view(1,1,H,W).repeat(B,1,1,1)
grid = torch.cat((xx,yy),1).float()
vgrid = grid + torch.from_numpy(mx_flow.asnumpy())

# scale grid to [-1,1] 
vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
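
The scaling step maps pixel index 0 to -1 and the last index (W-1 or H-1) to +1. A small pure-Python sketch of that mapping and its inverse (the helper names `to_normalized` and `to_pixel` are hypothetical):

```python
def to_normalized(coord, size):
    """Map a pixel coordinate in [0, size-1] to the [-1, 1] range."""
    return 2.0 * coord / max(size - 1, 1) - 1.0

def to_pixel(norm, size):
    """Inverse mapping: from [-1, 1] back to a pixel coordinate."""
    return (norm + 1.0) * max(size - 1, 1) / 2.0

W = 6
left, right = to_normalized(0, W), to_normalized(W - 1, W)
print(left, right)  # -1.0 1.0
# A coordinate pushed past the border by the flow lands outside [-1, 1].
outside = to_normalized(W - 1 + 0.5, W)
```

Coordinates that fall outside [-1, 1] after adding the flow point outside the image, and the sampler has to pad them.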

Test the grid

In [10]:
np.sum(np.abs(vgrid.numpy()-mxnet_grid.asnumpy()))

0.0

Great, the two grids match exactly!

In [11]:
# https://pytorch.org/docs/0.3.1/nn.html#torch.nn.functional.grid_sample
vgrid = vgrid.permute(0,2,3,1)

In [12]:
np.sum(np.abs(mxnet_output.asnumpy()-torch.nn.functional.grid_sample(input, vgrid).numpy()))

2.0936131e-06

OK, done. References:
* [NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py#L139](https://github.com/NVlabs/PWC-Net/blob/master/PyTorch/models/PWCNet.py#L139)
* [https://discuss.pytorch.org/t/solved-how-to-do-the-interpolating-of-optical-flow/5019/12](https://discuss.pytorch.org/t/solved-how-to-do-the-interpolating-of-optical-flow/5019/12)
* [torch.nn.functional.grid_sample](https://pytorch.org/docs/0.3.1/nn.html#torch.nn.functional.grid_sample)
* [incubator-mxnet/src/operator/bilinear_sampler.cc](https://github.com/apache/incubator-mxnet/blob/992c3c0dd90c0723de6934e826a49bad6569eeac/src/operator/bilinear_sampler.cc)

Mainly the first one, which works perfectly!

In [13]:
from nowcasting.models.trajGRU import wrap

In [14]:
output = wrap(input.to(cfg.GLOBAL.DEVICE), torch.from_numpy(mx_flow.asnumpy()).to(cfg.GLOBAL.DEVICE))
np.sum(np.abs(mxnet_output.asnumpy()-output.cpu().numpy()))

3.1171367e-06
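
The source of `wrap` is not shown in this notebook. A minimal sketch of what such a function could look like, pieced together from cells In [9] to In [12] above, is given below. The name `warp_sketch` is hypothetical, and `align_corners=True` is an assumption: on recent PyTorch versions it is needed so that the `W-1` / `H-1` normalization maps back to exact pixel centers, whereas the 0.3.1 API linked above has no such argument.

```python
import torch
import torch.nn.functional as F

def warp_sketch(inp, flow):
    """inp: (B, C, H, W); flow: (B, 2, H, W) with channel 0 = x and channel 1 = y displacement."""
    B, C, H, W = inp.shape
    # Base mesh grid of pixel coordinates, channel 0 = x, channel 1 = y.
    xx = torch.arange(W, dtype=inp.dtype, device=inp.device).view(1, 1, 1, W).expand(B, 1, H, W)
    yy = torch.arange(H, dtype=inp.dtype, device=inp.device).view(1, 1, H, 1).expand(B, 1, H, W)
    vgrid = torch.cat((xx, yy), dim=1) + flow
    # Scale the sampling positions to [-1, 1].
    x_norm = 2.0 * vgrid[:, 0] / max(W - 1, 1) - 1.0
    y_norm = 2.0 * vgrid[:, 1] / max(H - 1, 1) - 1.0
    # grid_sample expects the grid as (B, H, W, 2) with (x, y) in the last axis.
    grid = torch.stack((x_norm, y_norm), dim=-1)
    return F.grid_sample(inp, grid, mode='bilinear', padding_mode='zeros', align_corners=True)

x = torch.randn(1, 3, 5, 6)
same = warp_sketch(x, torch.zeros(1, 2, 5, 6))
print(torch.allclose(same, x, atol=1e-5))
```

With a zero flow the sampling grid coincides with the pixel centers, so the output should reproduce the input up to floating-point error.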

### Step 2

In [15]:
import torch

In [16]:
from nowcasting.models.trajGRU import TrajGRU

The RNN in the encoder's first layer takes an input of shape 8 x 96 x 96; the hidden size is 64.

So the output should have shape (S, B, hidden_size, H, W).

H and W are the height and width of the input feature map.

See the author's [traj_rnn.py](https://github.com/sxjscience/HKO-7/blob/b9235ca7edd4e5bd275f1abaf6f735f8059121be/nowcasting/operators/traj_rnn.py): its forward function handles an input of shape (1, B, C, H, W). Then see [unroll](https://github.com/sxjscience/HKO-7/blob/b9235ca7edd4e5bd275f1abaf6f735f8059121be/nowcasting/operators/base_rnn.py#L20), which simply unrolls the RNN over the sequence.

So TrajGRU outputs the hidden state at every sequence step, and the last hidden state is used as the forecaster's input.
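
The unrolling pattern described here can be sketched independently of TrajGRU. The example below uses NumPy and a toy stand-in cell; `unroll` and `toy_cell` are hypothetical names for illustration, not the author's code.

```python
import numpy as np

def unroll(cell, inputs, h0):
    """inputs: (S, B, C, H, W); h0: (B, hidden, H, W).
    Returns all hidden states (S, B, hidden, H, W) and the last one."""
    h = h0
    outputs = []
    for x_t in inputs:          # iterate over the S time steps
        h = cell(x_t, h)        # one recurrent step
        outputs.append(h)
    return np.stack(outputs, axis=0), h

def toy_cell(x_t, h):
    # Placeholder recurrence: mixes the channel-averaged input into the state.
    return np.tanh(x_t.mean(axis=1, keepdims=True) + h)

S, B, C, H, W, hidden = 5, 3, 8, 4, 4, 6
seq = np.random.default_rng(0).standard_normal((S, B, C, H, W))
outputs, next_h = unroll(toy_cell, seq, np.zeros((B, hidden, H, W)))
print(outputs.shape, next_h.shape)  # (5, 3, 6, 4, 4) (3, 6, 4, 4)
```

The last slice of `outputs` and `next_h` are the same state, which is the one handed to the forecaster.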

In [17]:
S, B, C, H, W = 5, 3, 8, 96, 96
input = torch.randn((S, B, C, H, W), dtype=torch.float).to(cfg.GLOBAL.DEVICE)

The `b_h_w` argument is the size of the feature map after downsampling or upsampling; it has to be computed in advance.

`zoneout` defaults to 0; for the other parameters see [encoder_forecaster.py](https://github.com/sxjscience/HKO-7/blob/b9235ca7edd4e5bd275f1abaf6f735f8059121be/nowcasting/encoder_forecaster.py#L37) and [trajgru_55_55_33_1_64_1_192_1_192_13_13_9_b4.yml](https://github.com/sxjscience/HKO-7/blob/b9235ca7edd4e5bd275f1abaf6f735f8059121be/experiments/hko/configurations/trajgru_55_55_33_1_64_1_192_1_192_13_13_9_b4.yml#L27)

In [18]:
trajGRU = TrajGRU(input_channel=8, num_filter=64, b_h_w=(B, H, W), zoneout=0.0, L=13,
                 i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                 h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                 act_type=cfg.MODEL.RNN_ACT_TYPE).to(cfg.GLOBAL.DEVICE)

TrajGRU 96 96


In [19]:
trajGRU

TrajGRU(
  (i2h): Conv2d(8, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (i2f_conv1): Conv2d(8, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (h2f_conv1): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (flows_conv): Conv2d(32, 26, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (ret): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1))
)
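
The channel counts in this printout can be reproduced by simple arithmetic. The sketch below encodes one plausible reading, which is an assumption rather than something stated in the notebook: `3 * num_filter` for the three GRU gates, `2 * L` for the (U, V) pair of each of the L flows, and `L * num_filter` for the concatenated warped hidden states. The flow-branch width of 32 is read off the printout, and the helper name `trajgru_conv_channels` is hypothetical.

```python
def trajgru_conv_channels(input_channel, num_filter, L, flow_hidden=32):
    """Return (in_channels, out_channels) for each conv in the printed module."""
    return {
        "i2h": (input_channel, 3 * num_filter),      # input to the three gates
        "i2f_conv1": (input_channel, flow_hidden),   # input to the flow branch
        "h2f_conv1": (num_filter, flow_hidden),      # hidden state to the flow branch
        "flows_conv": (flow_hidden, 2 * L),          # two components per flow
        "ret": (L * num_filter, 3 * num_filter),     # 1x1 conv over L warped states
    }

channels = trajgru_conv_channels(input_channel=8, num_filter=64, L=13)
for name, (c_in, c_out) in channels.items():
    print(f"{name}: {c_in} -> {c_out}")
```

For `input_channel=8, num_filter=64, L=13` this gives 8 -> 192, 8 -> 32, 64 -> 32, 32 -> 26 and 832 -> 192, matching the five `Conv2d` layers shown above.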

In [20]:
outputs, next_h = trajGRU(input)

In [21]:
outputs.size()

torch.Size([5, 3, 64, 96, 96])

In [22]:
next_h.size()

torch.Size([3, 64, 96, 96])

### The network from the paper

In [23]:
from collections import OrderedDict
from nowcasting.models.encoder import Encoder
from nowcasting.models.model import EF
from nowcasting.models.forecaster import Forecaster
# build model
encoder_params = [
    [
        OrderedDict({'conv1_leaky_1': [1, 8, 7, 5, 1]}),
        OrderedDict({'conv2_leaky_1': [64, 192, 5, 3, 1]}),
        OrderedDict({'conv3_leaky_1': [192, 192, 3, 2, 1]}),
    ],

    [
        # ConvLSTM(8, 64, 3),
        # ConvLSTM(192, 192, 3),
        # ConvLSTM(192, 192, 3)
        TrajGRU(input_channel=8, num_filter=64, b_h_w=(batch_size, 96, 96), zoneout=0.0, L=13,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                act_type=cfg.MODEL.RNN_ACT_TYPE),

        TrajGRU(input_channel=192, num_filter=192, b_h_w=(batch_size, 32, 32), zoneout=0.0, L=13,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                act_type=cfg.MODEL.RNN_ACT_TYPE),
        TrajGRU(input_channel=192, num_filter=192, b_h_w=(batch_size, 16, 16), zoneout=0.0, L=9,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                act_type=cfg.MODEL.RNN_ACT_TYPE)
    ]
]


encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)

forecaster_params = [
    [
        OrderedDict({'deconv1_leaky_1': [192, 192, 4, 2, 1]}),
        OrderedDict({'deconv2_leaky_1': [192, 64, 5, 3, 1]}),
        OrderedDict({
            'deconv3_leaky_1': [64, 8, 7, 5, 1],
            'conv3_leaky_2': [8, 8, 3, 1, 1],
            'conv3_3': [8, 1, 1, 1, 0]  # note: forgot to remove the activation function here
            # also forgot the convolution layer (classification)
        }),
    ],

    [
        TrajGRU(input_channel=192, num_filter=192, b_h_w=(batch_size, 16, 16), zoneout=0.0, L=13,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                act_type=cfg.MODEL.RNN_ACT_TYPE),

        TrajGRU(input_channel=192, num_filter=192, b_h_w=(batch_size, 32, 32), zoneout=0.0, L=13,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                act_type=cfg.MODEL.RNN_ACT_TYPE),
        TrajGRU(input_channel=64, num_filter=64, b_h_w=(batch_size, 96, 96), zoneout=0.0, L=9,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                act_type=cfg.MODEL.RNN_ACT_TYPE)
    ]
]

forecaster = Forecaster(forecaster_params[0], forecaster_params[1]).to(cfg.GLOBAL.DEVICE)

encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

TrajGRU 96 96
TrajGRU 32 32
TrajGRU 16 16
TrajGRU 16 16
TrajGRU 32 32
TrajGRU 96 96
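
The feature-map sizes printed above are the `b_h_w` values passed to each TrajGRU, and they can be computed in advance with the usual convolution size formulas. The sketch below assumes a 480 x 480 input (as used later in this notebook) and reads each parameter list as `[in, out, kernel, stride, padding]`; the helper names are hypothetical.

```python
def conv_out(size, kernel, stride, pad):
    """Output size of a standard convolution along one spatial dimension."""
    return (size + 2 * pad - kernel) // stride + 1

def deconv_out(size, kernel, stride, pad):
    """Output size of a transposed convolution (no output padding, no dilation)."""
    return (size - 1) * stride - 2 * pad + kernel

def trace_sizes(size, layers, fn):
    """Apply fn layer by layer and collect the resulting sizes."""
    sizes = []
    for kernel, stride, pad in layers:
        size = fn(size, kernel, stride, pad)
        sizes.append(size)
    return sizes

# (kernel, stride, padding) taken from encoder_params / forecaster_params above.
encoder_layers = [(7, 5, 1), (5, 3, 1), (3, 2, 1)]
forecaster_layers = [(4, 2, 1), (5, 3, 1), (7, 5, 1)]

enc_sizes = trace_sizes(480, encoder_layers, conv_out)
dec_sizes = trace_sizes(enc_sizes[-1], forecaster_layers, deconv_out)
print(enc_sizes, dec_sizes)  # [96, 32, 16] [32, 96, 480]
```

The encoder RNNs run at 96, 32 and 16; the forecaster RNNs run at 16, 32 and 96, each followed by a transposed convolution that upsamples to the next size.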


In [24]:
encoder_forecaster

EF(
  (encoder): Encoder(
    (stage1): Sequential(
      (conv1_leaky_1): Conv2d(1, 8, kernel_size=(7, 7), stride=(5, 5), padding=(1, 1))
      (leaky_conv1_leaky_1): LeakyReLU(negative_slope=0.2, inplace)
    )
    (rnn1): TrajGRU(
      (i2h): Conv2d(8, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (i2f_conv1): Conv2d(8, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (h2f_conv1): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (flows_conv): Conv2d(32, 26, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (ret): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1))
    )
    (stage2): Sequential(
      (conv2_leaky_1): Conv2d(64, 192, kernel_size=(5, 5), stride=(3, 3), padding=(1, 1))
      (leaky_conv2_leaky_1): LeakyReLU(negative_slope=0.2, inplace)
    )
    (rnn2): TrajGRU(
      (i2h): Conv2d(192, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (i2f_conv1): Conv2d(192, 32, kernel_size=(5, 5), stride=(1, 1),

In [25]:
S, B, C, H, W = 5, 2, 1, 480, 480
input = torch.randn((S, B, C, H, W), dtype=torch.float).to(cfg.GLOBAL.DEVICE)

In [26]:
output = encoder_forecaster(input)

In [28]:
output.size()

torch.Size([20, 2, 1, 480, 480])