In [1]:
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt

## extract paddlepaddle weights

In [2]:
def load_parameter(file_name):
    with open(file_name, 'rb') as f:
        f.read(16)  # skip header.
        return np.fromfile(f, dtype=np.float32)
    
def compute_difference(paddle_outputs, torch_outputs):
    return np.sqrt(np.mean((paddle_outputs - torch_outputs)**2.0))

After two conv layers, the output will be transform from an image to a sequence. The dim of each frame is 1312. 
The input image size is (161, X) where X is the length of the image. 
In paddlepaddle.v2, the kernel size of the conv1 is (11,41) and the kernel size of the conv2 is (11,21).
In order to get correct output, I find that second-dim of both kernels is related to the fisrt-dim of the input image.

So I guess I should switch the the position of these two dims for kernel, stride and padding in pytorch.

In [3]:
class conv1(nn.Module):
    def __init__(self):
        super(conv1, self).__init__()
        self.conv = nn.Conv2d(1, 32, kernel_size=(11, 41), stride=(3, 2), padding=(5, 20), bias=False)
        
    def forward(self, x):
        x = self.conv(x)
        return x

In [4]:
conv1_weights = load_parameter("../models/baidu_en8k/params/___conv_0__.w0")
# v2 to fluid: reshape to [ochannel, inchannel, filer_height, filer_width], height is x-axis while width is y-axis
conv1_weights = conv1_weights.reshape(32, 1, 41, 11)
conv1_weights = np.transpose(conv1_weights, (0, 1, 3, 2))

In [5]:
a = torch.from_numpy(conv1_weights).numpy()

In [6]:
conv_layer = conv1()
conv_layer.eval()
# conv_layer.weight.data = torch.from_numpy(conv1_weights)
pretrained_weights = {"conv.weight"    : torch.from_numpy(conv1_weights),
                     }
conv_layer.load_state_dict(pretrained_weights)

<All keys matched successfully>

### paddlepaddle conv outputs shape: nchannel, feature_dim, frame_size

In [7]:
np_input = np.load("paddle/conv1_input.npy", allow_pickle=True, encoding="bytes")
np_output = np.load("paddle/conv1_output.npy", allow_pickle=True, encoding="bytes")
np_output = np_output.reshape(3, 32,81, 85)
temp = []
for i in np_input:
    temp.append(i[0])
np_input = np.array(temp)
# np_input = np.flip(np_input).copy()

torch_input = torch.from_numpy(np_input)
# torch_input = torch_input.type(torch.FloatTensor)
np.testing.assert_array_almost_equal(torch_input, np_input, decimal=10)

torch_input = torch.unsqueeze(torch_input, 1)
torch_input = torch_input.transpose(3,2)

In [8]:
torch_output = conv_layer(torch_input)
torch_output = torch_output.transpose(3,2)
torch_output = torch_output.data.numpy()
compute_difference(paddle_outputs=np_output, torch_outputs=torch_output)

1.1807945e-06

# CBM 

In [9]:
import sys
sys.path.append("../model_utils/")
from custom_layers import Conv_bn_mask

In [10]:
class cbm_layer(nn.Module):
    def __init__(self):
        super(cbm_layer, self).__init__()
        self.conv_bn_mask0 = Conv_bn_mask(ichannel=1,
                                          ochannel=32,
                                          kernel_size=(11, 41),
                                          stride=(3, 2),
                                          padding=(5, 20),
                                          bias=False,
                                          track_running_stats=True)
        
        self.conv_bn_mask1 = Conv_bn_mask(ichannel=32,
                                          ochannel=32,
                                          kernel_size=(11, 21),
                                          stride=(1, 2),
                                          padding=(5, 10),
                                          bias=False,
                                          track_running_stats=True)
    def forward(self, x, length):
        x1 = self.conv_bn_mask0(x, length)
        x1 = self.conv_bn_mask1(x1, length)
        return x1

In [11]:
cbm_test = cbm_layer()
cbm_test.eval()
temp = cbm_test.state_dict()
for i in temp:
    print(temp[i].shape)

torch.Size([32, 1, 11, 41])
torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([])
torch.Size([32, 32, 11, 21])
torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([])


In [12]:
conv1_weights = load_parameter("../models/baidu_en8k/params/___conv_0__.w0")
conv1_weights = conv1_weights.reshape(32, 1, 41, 11)
conv1_weights = np.transpose(conv1_weights, (0, 1, 3, 2))

bn1_mean  = load_parameter("../models/baidu_en8k/params/___batch_norm_0__.w1")
bn1_var   = load_parameter("../models/baidu_en8k/params/___batch_norm_0__.w2")
bn1_gamma = load_parameter("../models/baidu_en8k/params/___batch_norm_0__.w0")
bn1_beta  = load_parameter("../models/baidu_en8k/params/___batch_norm_0__.wbias")

cbm_test.conv_bn_mask0.conv.weight.data     = torch.from_numpy(conv1_weights)
cbm_test.conv_bn_mask0.bn.bias.data         = torch.from_numpy(bn1_beta)
cbm_test.conv_bn_mask0.bn.weight.data       = torch.from_numpy(bn1_gamma)
cbm_test.conv_bn_mask0.bn.running_mean.data = torch.from_numpy(bn1_mean)
cbm_test.conv_bn_mask0.bn.running_var.data  = torch.from_numpy(bn1_var)

conv2_weights = load_parameter("../models/baidu_en8k/params/___conv_1__.w0")
conv2_weights = conv2_weights.reshape(32, 32, 21, 11)
conv2_weights = np.transpose(conv2_weights, (0, 1, 3, 2))

bn2_mean  = load_parameter("../models/baidu_en8k/params/___batch_norm_1__.w1")
bn2_var   = load_parameter("../models/baidu_en8k/params/___batch_norm_1__.w2")
bn2_gamma = load_parameter("../models/baidu_en8k/params/___batch_norm_1__.w0")
bn2_beta  = load_parameter("../models/baidu_en8k/params/___batch_norm_1__.wbias")

cbm_test.conv_bn_mask1.conv.weight.data     = torch.from_numpy(conv2_weights)
cbm_test.conv_bn_mask1.bn.bias.data         = torch.from_numpy(bn2_beta)
cbm_test.conv_bn_mask1.bn.weight.data       = torch.from_numpy(bn2_gamma)
cbm_test.conv_bn_mask1.bn.running_mean.data = torch.from_numpy(bn2_mean)
cbm_test.conv_bn_mask1.bn.running_var.data  = torch.from_numpy(bn2_var)

In [13]:
np_input = np.load("paddle/cbm2_seq_input.npy", allow_pickle=True, encoding="bytes")
np_output = np.load("paddle/cbm2_seq_output.npy", allow_pickle=True, encoding="bytes")
np_output = np_output.reshape(3, 32, 41, 85)
temp_data = []
temp_length = []
for i in np_input:
    temp_data.append(i[0])
    temp_length.append(i[3])
np_input = np.array(temp_data)
np_length = np.array(temp_length).astype(int)

torch_input = torch.from_numpy(np_input)
# torch_input = torch_input.type(torch.FloatTensor)
np.testing.assert_array_almost_equal(torch_input, np_input, decimal=10)

torch_input = torch.unsqueeze(torch_input, 1)
torch_input = torch_input.transpose(3,2)

torch_output = cbm_test(torch_input, torch.from_numpy(np_length))
torch_output = torch_output.transpose(3,2)
torch_output = torch_output.data.numpy()
# np.testing.assert_array_almost_equal(np_output, torch_output,decimal=5)
compute_difference(paddle_outputs=np_output, torch_outputs=torch_output)

0.34108225

# cbmX2_bigru 

In [14]:
from custom_gru import BidirRNNLayer, GRUCell, RNNLayer

In [15]:
class cbmX2_bigru_layer(nn.Module):
    def __init__(self):
        super(cbmX2_bigru_layer, self).__init__()
        self.conv_bn_mask0 = Conv_bn_mask(ichannel=1,
                                          ochannel=32,
                                          kernel_size=(11, 41),
                                          stride=(3, 2),
                                          padding=(5, 20),
                                          bias=False,
                                          track_running_stats=True)

        self.conv_bn_mask1 = Conv_bn_mask(ichannel=32,
                                          ochannel=32,
                                          kernel_size=(11, 21),
                                          stride=(1, 2),
                                          padding=(5, 10),
                                          bias=False,
                                          track_running_stats=True)

        self.bigru0 = BidirRNNLayer(GRUCell, input_size=41 * 32, hidden_size=1024, gate_act="relu", state_act="tanh")
        # self.bigru1 = BidirRNNLayer(GRUCell, input_size=2048, hidden_size=1024, gate_act="relu", state_act="tanh")
        # self.bigru2 = BidirRNNLayer(GRUCell, input_size=2048, hidden_size=1024, gate_act="relu", state_act="tanh")

        # 28 of char + 1 blank
        # vocab list:  ["'", ' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
        # self.bottleneck = nn.Linear(2048, 28 + 1)
        # self.softmax = nn.Softmax(dim=0)


    def forward(self, input, length):
        x = input
        seq_length = length 
        batch_size = x.shape[0]
        x = self.conv_bn_mask0(x, seq_length)
        x = self.conv_bn_mask1(x, seq_length)

        flattened_x = x.view(batch_size, -1, 41 * 32)
        flattened_x = nn.utils.rnn.pack_padded_sequence(flattened_x, seq_length.flatten(), batch_first=True)
        flattened_x, _ = self.bigru0(flattened_x, (torch.zeros((2, batch_size,1024)), ) )

        bottleneck_data, batch_sizes, _, _ = flattened_x
        # bottleneck_result = self.bottleneck(bottleneck_data)
        # output = self.softmax(bottleneck_result)
        output = nn.utils.rnn.PackedSequence(bottleneck_data, batch_sizes)
        # this is a special request for use CTC loss in pytorch
        output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        # include both padded data and valid length for each sample
        # data: [batch_size, length of sample, feature size]
        # valid length: [valid length for each sample]
        return output


In [16]:
cbmX2_bigru_test = cbmX2_bigru_layer()
cbmX2_bigru_test.eval()

# TODO, double check the normalization part

cbmX2_bigru_layer(
  (conv_bn_mask0): Conv_bn_mask(
    (conv): Conv2d(1, 32, kernel_size=(11, 41), stride=(3, 2), padding=(5, 20), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): BReLU()
    (mask): Mask()
  )
  (conv_bn_mask1): Conv_bn_mask(
    (conv): Conv2d(32, 32, kernel_size=(11, 21), stride=(1, 2), padding=(5, 10), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): BReLU()
    (mask): Mask()
  )
  (bigru0): BidirRNNLayer(
    original_name=BidirRNNLayer
    (directions): _ConstModuleList(
      original_name=_ConstModuleList
      (0): RNNLayer(
        original_name=RNNLayer
        (cell): GRUCell(
          original_name=GRUCell
          (bn): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ReverseRNNLayer(
        original_name=ReverseRNNLayer
        (cell): GRUCell(
    

In [17]:
cbmX2_bigru_test.state_dict().keys()
# TODO: load parameters

odict_keys(['conv_bn_mask0.conv.weight', 'conv_bn_mask0.bn.weight', 'conv_bn_mask0.bn.bias', 'conv_bn_mask0.bn.running_mean', 'conv_bn_mask0.bn.running_var', 'conv_bn_mask0.bn.num_batches_tracked', 'conv_bn_mask1.conv.weight', 'conv_bn_mask1.bn.weight', 'conv_bn_mask1.bn.bias', 'conv_bn_mask1.bn.running_mean', 'conv_bn_mask1.bn.running_var', 'conv_bn_mask1.bn.num_batches_tracked', 'bigru0.directions.0.cell.weight_i', 'bigru0.directions.0.cell.weight_h', 'bigru0.directions.0.cell.bias', 'bigru0.directions.0.cell.bn.weight', 'bigru0.directions.0.cell.bn.bias', 'bigru0.directions.0.cell.bn.running_mean', 'bigru0.directions.0.cell.bn.running_var', 'bigru0.directions.0.cell.bn.num_batches_tracked', 'bigru0.directions.1.cell.weight_i', 'bigru0.directions.1.cell.weight_h', 'bigru0.directions.1.cell.bias', 'bigru0.directions.1.cell.bn.weight', 'bigru0.directions.1.cell.bn.bias', 'bigru0.directions.1.cell.bn.running_mean', 'bigru0.directions.1.cell.bn.running_var', 'bigru0.directions.1.cell.bn.

In [18]:
conv0_weights   = load_parameter("../models/baidu_en8k/params/___conv_0__.w0")
conv0_weights   = conv0_weights.reshape(32, 1, 41, 11)
conv0_weights   = load_parameter("../models/baidu_en8k/params/___conv_0__.w0")
conv0_weights   = conv0_weights.reshape(32, 1, 41, 11)
conv0_weights   = np.transpose(conv0_weights, (0, 1, 3, 2))
conv0_bn_mean  = load_parameter("../models/baidu_en8k/params/___batch_norm_0__.w1")
conv0_bn_var   = load_parameter("../models/baidu_en8k/params/___batch_norm_0__.w2")
conv0_bn_gamma = load_parameter("../models/baidu_en8k/params/___batch_norm_0__.w0")
conv0_bn_beta  = load_parameter("../models/baidu_en8k/params/___batch_norm_0__.wbias")


conv1_weights = load_parameter("../models/baidu_en8k/params/___conv_1__.w0")
conv1_weights = conv1_weights.reshape(32, 32, 21, 11)
conv1_weights = np.transpose(conv1_weights, (0, 1, 3, 2))
conv1_bn_mean  = load_parameter("../models/baidu_en8k/params/___batch_norm_1__.w1")
conv1_bn_var   = load_parameter("../models/baidu_en8k/params/___batch_norm_1__.w2")
conv1_bn_gamma = load_parameter("../models/baidu_en8k/params/___batch_norm_1__.w0")
conv1_bn_beta  = load_parameter("../models/baidu_en8k/params/___batch_norm_1__.wbias")

# gru0
bigru0_directions_0_cell_weight_i         = load_parameter("../models/baidu_en8k/params/___fc_layer_0__.w0")
bigru0_directions_0_cell_weight_i         = bigru0_directions_0_cell_weight_i.reshape(41 * 32, 1024*3)
bigru0_directions_0_cell_weight_h         = load_parameter("../models/baidu_en8k/params/___gru_0__.w0")
w_u_r = bigru0_directions_0_cell_weight_h.flatten()[:1024*1024*2].reshape(1024,1024*2)
w_c   = bigru0_directions_0_cell_weight_h.flatten()[1024*1024*2:].reshape(1024,1024)
bigru0_directions_0_cell_weight_h = np.concatenate([w_u_r,w_c], 1)
bigru0_directions_0_cell_bias             = load_parameter("../models/baidu_en8k/params/___gru_0__.wbias")
bigru0_directions_0_cell_bn_bias          = load_parameter("../models/baidu_en8k/params/___batch_norm_2__.wbias")
bigru0_directions_0_cell_bn_weight        = load_parameter("../models/baidu_en8k/params/___batch_norm_2__.w0")
bigru0_directions_0_cell_bn_running_mean  = load_parameter("../models/baidu_en8k/params/___batch_norm_2__.w1")
bigru0_directions_0_cell_bn_running_var   = load_parameter("../models/baidu_en8k/params/___batch_norm_2__.w2")


bigru0_directions_1_cell_weight_i         = load_parameter("../models/baidu_en8k/params/___fc_layer_1__.w0")
bigru0_directions_1_cell_weight_i         = bigru0_directions_1_cell_weight_i.reshape(41 * 32, 1024*3)
bigru0_directions_1_cell_weight_h         = load_parameter("../models/baidu_en8k/params/___gru_1__.w0")
w_u_r = bigru0_directions_1_cell_weight_h.flatten()[:1024*1024*2].reshape(1024,1024*2)
w_c = bigru0_directions_1_cell_weight_h.flatten()[1024*1024*2:].reshape(1024,1024)
bigru0_directions_1_cell_weight_h = np.concatenate([w_u_r,w_c], 1)
bigru0_directions_1_cell_bias             = load_parameter("../models/baidu_en8k/params/___gru_1__.wbias")
bigru0_directions_1_cell_bn_bias          = load_parameter("../models/baidu_en8k/params/___batch_norm_3__.wbias")
bigru0_directions_1_cell_bn_weight        = load_parameter("../models/baidu_en8k/params/___batch_norm_3__.w0")
bigru0_directions_1_cell_bn_running_mean  = load_parameter("../models/baidu_en8k/params/___batch_norm_3__.w1")
bigru0_directions_1_cell_bn_running_var   = load_parameter("../models/baidu_en8k/params/___batch_norm_3__.w2")


In [19]:
pretrained_weights = {   "conv_bn_mask0.conv.weight"                 : conv0_weights,
                         "conv_bn_mask0.bn.weight"                   : conv0_bn_gamma,
                         "conv_bn_mask0.bn.bias"                     : conv0_bn_beta,
                         "conv_bn_mask0.bn.running_mean"             : conv0_bn_mean,
                         "conv_bn_mask0.bn.running_var"              : conv0_bn_var ,
                         "conv_bn_mask1.conv.weight"                 : conv1_weights,
                         "conv_bn_mask1.bn.weight"                   : conv1_bn_gamma,
                         "conv_bn_mask1.bn.bias"                     : conv1_bn_beta,
                         "conv_bn_mask1.bn.running_mean"             : conv1_bn_mean,
                         "conv_bn_mask1.bn.running_var"              : conv1_bn_var ,
                         "bigru0.directions.0.cell.weight_i"         : bigru0_directions_0_cell_weight_i,
                         "bigru0.directions.0.cell.weight_h"         : bigru0_directions_0_cell_weight_h,
                         "bigru0.directions.0.cell.bias"             : bigru0_directions_0_cell_bias,
                         "bigru0.directions.0.cell.bn.bias"          : bigru0_directions_0_cell_bn_bias        ,
                         "bigru0.directions.0.cell.bn.weight"        : bigru0_directions_0_cell_bn_weight      ,
                         "bigru0.directions.0.cell.bn.running_mean"  : bigru0_directions_0_cell_bn_running_mean,
                         "bigru0.directions.0.cell.bn.running_var"   : bigru0_directions_0_cell_bn_running_var ,
                         "bigru0.directions.1.cell.weight_i"         : bigru0_directions_1_cell_weight_i,
                         "bigru0.directions.1.cell.weight_h"         : bigru0_directions_1_cell_weight_h,
                         "bigru0.directions.1.cell.bias"             : bigru0_directions_1_cell_bias,
                         "bigru0.directions.1.cell.bn.bias"          : bigru0_directions_1_cell_bn_bias        ,
                         "bigru0.directions.1.cell.bn.weight"        : bigru0_directions_1_cell_bn_weight      ,
                         "bigru0.directions.1.cell.bn.running_mean"  : bigru0_directions_1_cell_bn_running_mean,
                         "bigru0.directions.1.cell.bn.running_var"   : bigru0_directions_1_cell_bn_running_var 
                     }

In [20]:
check_dict = cbmX2_bigru_test.state_dict()
for key in check_dict:
    if 'num_batches_tracked' in key:
        continue
    check_dict[key] = torch.from_numpy(pretrained_weights[key])
cbmX2_bigru_test.load_state_dict(check_dict)

<All keys matched successfully>

In [21]:
np_input = np.load("paddle/cbmX2_bigru_input.npy", allow_pickle=True, encoding="bytes")
np_output = np.load("paddle/cbmX2_bigru_output.npy", allow_pickle=True, encoding="bytes")
temp_data = []
temp_length = []
for i in np_input:
    temp_data.append(i[0])
    temp_length.append(i[3])
np_input = np.array(temp_data)
np_length = np.array(temp_length).astype(int)

torch_input = torch.from_numpy(np_input)
# torch_input = torch_input.type(torch.FloatTensor)
np.testing.assert_array_almost_equal(torch_input, np_input, decimal=10)

torch_input = torch.unsqueeze(torch_input, 1)
torch_input = torch_input.transpose(3,2)

torch_output = cbmX2_bigru_test(torch_input, torch.from_numpy(np_length))

# torch_output = torch_output.transpose(3,2)
# torch_output = torch_output.data.numpy()
compute_difference(paddle_outputs=np_output, torch_outputs=torch_output)
# np.testing.assert_array_almost_equal(np_output, torch_output,decimal=5)

<class 'list'>
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
torch.Size([3, 1024])
torch.Size([3, 1024])
tensor(3)
tor

NameError: name 'finished_hiddens' is not defined

In [22]:
torch_output

array([[[[6.97557926e-01, 9.48624849e-01, 1.12963164e+00, ...,
          8.86039555e-01, 1.18699563e+00, 1.44217920e+00],
         [5.79633713e-01, 1.06949759e+00, 1.22481656e+00, ...,
          1.19459677e+00, 1.12745988e+00, 1.50603473e+00],
         [6.97608709e-01, 1.55525589e+00, 1.15804243e+00, ...,
          1.32352829e+00, 1.74903929e+00, 1.68702316e+00],
         ...,
         [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
          5.78890800e-01, 7.94465423e-01, 9.24851716e-01],
         [0.00000000e+00, 1.03183210e-01, 2.57356226e-01, ...,
          3.60597074e-01, 3.74310911e-01, 7.80914187e-01],
         [2.18170881e-02, 0.00000000e+00, 0.00000000e+00, ...,
          2.18415737e-01, 3.11732471e-01, 7.04114318e-01]],

        [[4.64425862e-01, 7.25744128e-01, 6.27169371e-01, ...,
          6.85277581e-01, 5.75725138e-01, 3.38575602e-01],
         [4.18201059e-01, 4.63923573e-01, 4.44142163e-01, ...,
          4.63579386e-01, 4.70006764e-01, 3.40491116e-01],
        

In [57]:
n = 0
test = nn.utils.rnn.pad_packed_sequence(torch_output, batch_first=True)[0].data.numpy()[n]
compute_difference(np_output[n], test[:85])

AttributeError: 'tuple' object has no attribute 'batch_sizes'

In [None]:
torch_output[1]

In [None]:
temp = torch_output[0].data.numpy()

In [None]:
np.testing.assert_array_almost_equal(np_output[0], temp[0],decimal=5)

In [None]:
torch_tile(torch.tensor([[1,2,3,4]]), 0, 3)

In [None]:
flatten = result.view(1, -1, 2208)

In [None]:
flatten.size()

In [None]:
mask = torch.ByteTensor([[1]*3,[1]*3, [0]*3])
hehe = torch.rand((3,3))
hehe

In [None]:
torch.masked_select(hehe,mask)

In [None]:
mask = torch.zeros([6, 1, 10,10], dtype=torch.float)
mask[:,:,3:7,3:7] = 1

In [None]:
class mask(nn.Module):
    def __init__(self):
        super(mask, self).__init__()
        
    def forward(self,x, batch_info):
        c1,c2, w1,w2, h1,h2 = batch_info
        mask = torch.zeros_like(x, dtype=torch.float)
        mask[c1:c2, w1:w2, h1:h2] = 1
        return x * mask

In [None]:
mask_layer = mask()

In [None]:
def VariableRecurrent(batch_sizes, inner):
    def forward(input, hidden, weight):
        output = []
        input_offset = 0
        last_batch_size = batch_sizes[0]
        hiddens = []
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
        for batch_size in batch_sizes:
            step_input = input[input_offset:input_offset + batch_size]
            input_offset += batch_size

            dec = last_batch_size - batch_size
            if dec > 0:
                hiddens.append(tuple(h[-dec:] for h in hidden))
                hidden = tuple(h[:-dec] for h in hidden)
            last_batch_size = batch_size

            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)

            output.append(hidden[0])
        hiddens.append(hidden)
        hiddens.reverse()

        hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
        assert hidden[0].size(0) == batch_sizes[0]
        if flat_hidden:
            hidden = hidden[0]
        output = torch.cat(output, 0)

        return hidden, output
    
    return forward

In [None]:
a = torch.tensor([[11]*3,[12]*3, [13]*3])
b = torch.tensor([[21]*3,[22]*3, [23]*3, [24]*3])
c = torch.tensor([[31]*3,[32]*3, [33]*3, [34]*3, [35]*3])
seq = nn.utils.rnn.pack_sequence([c,b,a], )
input, batch_sizes, _,_ = seq

In [None]:
seq.unsorted_indices = seq.to

In [None]:
seq.unsorted_indices

In [None]:
def gru_test(input, hidden, weight):
    h, c = hidden
    hy = h + 1
    cy = c + 1
    return hy, cy 

In [None]:
forward = VariableRecurrent(batch_sizes, gru_test)

In [None]:
forward(input, (torch.tensor([0]*3), torch.tensor([0]*3)), [0])