In [1]:
import sys
# Put the parent directory first on the import path so the project package
# (presumably `nowcasting`, imported below) resolves from this notebook's folder.
sys.path.insert(0, '../')
# Auto-reload edited project modules before every cell execution.
%reload_ext autoreload
%autoreload 2
import logging

In [2]:
# Silence every logging call at DEBUG severity or lower, for all loggers.
logging.disable(level=logging.DEBUG)

# 测试 ConvLSTM (Testing ConvLSTM)

In [3]:
from nowcasting.models.convLSTM import ConvLSTM, ConvLSTMCell
import torch

In [None]:
# Fix the RNG so the cell's initial weights and the random tensors drawn in the
# following cells are reproducible across Restart & Run All.
torch.manual_seed(0)
# Sizes for the single-cell test. `heigt` (sic) is kept as-is because later
# cells refer to that name.
batch_size, input_channel, hidden_size, width, heigt = 2, 10, 10, 64, 32
# Arguments: (input channels, hidden channels, kernel size), matching the
# channel dimensions of the `input` and `hidden` tensors built next.
c = ConvLSTMCell(input_channel, hidden_size, 3)

In [None]:
# Random input/target frames and an all-zero initial (hidden, cell) state
# for a single ConvLSTMCell step.
frame_shape = (batch_size, input_channel, heigt, width)
state_shape = (batch_size, hidden_size, heigt, width)
input = torch.randn(frame_shape)
target = torch.randn(frame_shape)
hidden, cell = torch.zeros(state_shape), torch.zeros(state_shape)

In [None]:
# One ConvLSTMCell step from the zero state; keep only the first returned element.
step_result = c(input, (hidden, cell))
output = step_result[0]

In [None]:
# Numerically verify ConvLSTMCell's backward pass.
# gradcheck only differentiates the function it is handed, so that function
# must be the cell itself. Checking MSELoss on the cell's output would treat
# the output as a constant and never exercise the cell's gradients.
# gradcheck needs float64 and runs two forward passes per input element, so a
# tiny double-precision cell is used instead of the large tensors above.
# NOTE(review): assumes the cell accepts any spatial size and runs in float64;
# the L19 constructor takes no spatial-size argument, so this should hold.
loss_fn = torch.nn.MSELoss()  # later cells still refer to this name
gc_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gc_cell = ConvLSTMCell(2, 3, 3).double().to(gc_device)


def cell_output(x, h, cx):
    """Return the first element of one ConvLSTMCell step on (x, (h, cx))."""
    return gc_cell(x, (h, cx))[0]


gc_x = torch.randn(1, 2, 5, 5, dtype=torch.float64, device=gc_device, requires_grad=True)
gc_h = torch.randn(1, 3, 5, 5, dtype=torch.float64, device=gc_device, requires_grad=True)
gc_c = torch.randn(1, 3, 5, 5, dtype=torch.float64, device=gc_device, requires_grad=True)
assert torch.autograd.gradcheck(cell_output, (gc_x, gc_h, gc_c), eps=1e-6, raise_exception=True)

In [None]:
# A 10-step random sequence laid out as (seq, batch, channels, H, W),
# starting from the same zero (hidden, cell) state as the single-step test.
seq_num = 10
sequence_shape = (seq_num, batch_size, input_channel, heigt, width)
input = torch.randn(sequence_shape)
target = torch.randn(sequence_shape)
state = (hidden, cell)

In [None]:
# Sequence-level ConvLSTM: (input channels, hidden channels, kernel size 3).
conv_lstm_args = (input_channel, hidden_size, 3)
conv_lstm = ConvLSTM(*conv_lstm_args)

In [None]:
# Run the whole sequence from `state`; the result unpacks into two values.
lstm_result = conv_lstm(input, state)
output, stack = lstm_result

In [None]:
# Numerically verify ConvLSTM's backward pass through time.
# gradcheck must wrap the model itself. Checking a loss on the model's
# (constant) output would only test the loss. On the ~10^5-element tensors
# above it would also need ~10^5 forward passes, so use a tiny float64
# model over a short sequence instead.
gc_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gc_lstm = ConvLSTM(2, 3, 3).double().to(gc_device)


def lstm_output(x, h, cx):
    """Return the first value produced by ConvLSTM on sequence x from state (h, cx)."""
    return gc_lstm(x, (h, cx))[0]


# Sequence layout mirrors `input` above: (seq, batch, channels, H, W).
gc_seq = torch.randn(3, 1, 2, 5, 5, dtype=torch.float64, device=gc_device, requires_grad=True)
gc_h0 = torch.randn(1, 3, 5, 5, dtype=torch.float64, device=gc_device, requires_grad=True)
gc_c0 = torch.randn(1, 3, 5, 5, dtype=torch.float64, device=gc_device, requires_grad=True)
assert torch.autograd.gradcheck(lstm_output, (gc_seq, gc_h0, gc_c0), eps=1e-6, raise_exception=True)

# 测试 Encoder 和 Forecaster (Testing the Encoder and Forecaster)

In [4]:
from nowcasting.models.forecaster import Forecaster
from nowcasting.models.encoder import Encoder
from collections import OrderedDict
from nowcasting.config import cfg

### Encoder

In [5]:
# Encoder definition: stage i is a conv block followed by ConvLSTM i
# (see stage1/rnn1 ... stage3/rnn3 in the module repr).
# Conv spec: [in_channels, out_channels, kernel_size, stride, padding].
# Spatial size goes 480 -> 96 -> 32 -> 16 through the three convs.
encoder_subnets = [
    OrderedDict([('conv1_1', [1, 8, 7, 5, 1])]),
    OrderedDict([('conv2_1', [64, 192, 5, 3, 1])]),
    OrderedDict([('conv3_1', [192, 192, 3, 2, 1])]),
]
encoder_rnns = [
    ConvLSTM(8, 64, 3),
    ConvLSTM(192, 192, 3),
    ConvLSTM(192, 192, 3),
]
encoder_params = [encoder_subnets, encoder_rnns]

In [6]:
# Build the encoder from (conv blocks, ConvLSTMs) and move it to the configured device.
encoder = Encoder(*encoder_params).to(cfg.GLOBAL.DEVICE)

In [7]:
# Display the encoder's module tree to check each stage's channel counts.
encoder

Encoder(
  (stage1): Sequential(
    (conv1_1): Conv2d(1, 8, kernel_size=(7, 7), stride=(5, 5), padding=(1, 1))
    (relu_conv1_1): ReLU(inplace)
  )
  (rnn1): ConvLSTM(
    (_cell): ConvLSTMCell(
      (_conv): Conv2d(72, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (stage2): Sequential(
    (conv2_1): Conv2d(64, 192, kernel_size=(5, 5), stride=(3, 3), padding=(1, 1))
    (relu_conv2_1): ReLU(inplace)
  )
  (rnn2): ConvLSTM(
    (_cell): ConvLSTMCell(
      (_conv): Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (stage3): Sequential(
    (conv3_1): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (relu_conv3_1): ReLU(inplace)
  )
  (rnn3): ConvLSTM(
    (_cell): ConvLSTMCell(
      (_conv): Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
)

In [8]:
# Dummy input laid out as (seq, batch, channels, H, W):
# 5 frames, batch of 2, single-channel 480x480 images.
frames_shape = (5, 2, 1, 480, 480)
data = torch.randn(frames_shape).to(cfg.GLOBAL.DEVICE)

In [9]:
# Encode the sequence. NOTE: this rebinds `state`, which earlier held the
# ConvLSTM test's (hidden, cell) pair.
encoded = encoder(data)
state, size = encoded

In [10]:
# Shape of state[0][0]. The recorded (2, 64, 96, 96) matches rnn1's 64 hidden
# channels at the stage-1 resolution; presumably state[0] is stage 1's state.
state[0][0].size()

torch.Size([2, 64, 96, 96])

In [11]:
# Second value returned by the encoder. The recorded value equals the stage-3
# feature shape (seq, batch, 192, 16, 16); presumably the forecaster uses it
# to size its decoding.
size

torch.Size([5, 2, 192, 16, 16])

### Forecaster

In [12]:
# Forecaster definition, listed from the deepest stage (16x16) back to full
# resolution: the first entries become rnn3/stage3 in the module repr.
# Deconv spec: [in_channels, out_channels, kernel_size, stride, padding].
# Each OrderedDict is built from (name, spec) pairs so layer order within a
# stage is explicit. A dict literal only guarantees insertion order on
# Python 3.7+, which matters for the two-layer final stage.
forecaster_params = [
    [
        OrderedDict([('deconv1_1', [192, 192, 4, 2, 1])]),  # 16 -> 32
        OrderedDict([('deconv2_1', [192, 64, 5, 3, 1])]),   # 32 -> 96
        OrderedDict([
            ('deconv3_1', [64, 8, 7, 5, 1]),                # 96 -> 480
            ('deconv4_2', [8, 1, 1, 1, 0]),                 # 1x1 projection to 1 channel
        ]),
    ],

    [
        ConvLSTM(192, 192, 3),
        ConvLSTM(192, 192, 3),
        ConvLSTM(64, 64, 3),
    ],
]

In [13]:
# Build the forecaster from (deconv blocks, ConvLSTMs) and move it to the configured device.
forecaster = Forecaster(*forecaster_params).to(cfg.GLOBAL.DEVICE)

In [14]:
# Display the forecaster's module tree; stages are listed deepest-first
# (rnn3/stage3 before stage1).
forecaster

Forecaster(
  (rnn3): ConvLSTM(
    (_cell): ConvLSTMCell(
      (_conv): Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (stage3): Sequential(
    (deconv1_1): ConvTranspose2d(192, 192, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (relu_deconv1_1): ReLU(inplace)
  )
  (rnn2): ConvLSTM(
    (_cell): ConvLSTMCell(
      (_conv): Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (stage2): Sequential(
    (deconv2_1): ConvTranspose2d(192, 64, kernel_size=(5, 5), stride=(3, 3), padding=(1, 1))
    (relu_deconv2_1): ReLU(inplace)
  )
  (rnn1): ConvLSTM(
    (_cell): ConvLSTMCell(
      (_conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (stage1): Sequential(
    (deconv3_1): ConvTranspose2d(64, 8, kernel_size=(7, 7), stride=(5, 5), padding=(1, 1))
    (relu_deconv3_1): ReLU(inplace)
    (deconv4_2): ConvTranspose2d(8, 1, kernel_size=(1, 1), stride=(1, 1))
    (relu_deconv4_2): ReLU(i

In [15]:
# Decode the encoder's states into an output sequence.
forecaster_inputs = (state, size)
output = forecaster(*forecaster_inputs)

In [16]:
# The forecast must match `data` in batch, channels and 480x480 spatial size.
# Only dim 0 (sequence length) may differ.
assert output.size()[1:] == data.size()[1:], (output.size(), data.size())
output.size()

torch.Size([20, 2, 1, 480, 480])