In [1]:
import torch
import torch.nn as nn
import numpy as np

## extract paddlepaddle weights

In [2]:
def load_parameter(file_name):
    with open(file_name, 'rb') as f:
        f.read(16)  # skip header.
        return np.fromfile(f, dtype=np.float32)
    

After two conv layers, the output will be transform from an image to a sequence. The dim of each frame is 1312. 
The input image size is (161, X) where X is the length of the image. 
In paddlepaddle.v2, the kernel size of the conv1 is (11,41) and the kernel size of the conv2 is (11,21).
In order to get correct output, I find that second-dim of both kernels is related to the fisrt-dim of the input image.

So I guess I should switch the the position of these two dims for kernel, stride and padding in pytorch.

In [3]:
test = load_parameter("baidu_models/baidu_en8k/params/___conv_0__.w0")

In [21]:
# v2 to fluid: reshape to [ochannel, inchannel, filer_height, filer_width], height is x-axis while width is y-axis
test = test.reshape(32, 1, 11,41)

In [22]:
conv = nn.Conv2d(1, 32, kernel_size=(11, 41), stride=(3, 2), padding=(5, 20), bias=False)

In [23]:
conv.weight.data = torch.from_numpy(conv1_weights)

In [24]:
for i in conv.parameters():
    assert torch.all(torch.eq(i, torch.from_numpy(conv1_weights)))

In [28]:
input = torch.rand((3,1,255, 161))

In [36]:
input.type()

'torch.FloatTensor'

In [29]:
result = conv(input)

In [30]:
result.size()

torch.Size([3, 32, 85, 81])

In [44]:
test.transpose(0,1,3,2)

array([[[[ 9.29311588e-02,  1.35485291e-01, -4.28381890e-01, ...,
          -2.97493190e-01,  6.46222010e-02,  9.92319882e-02],
         [-4.26919967e-01, -1.15375303e-01, -6.09744966e-01, ...,
           2.34946698e-01, -3.63475271e-02,  2.28529319e-01],
         [ 2.62146175e-01, -8.91227052e-02, -4.09002453e-01, ...,
           1.52953044e-01,  1.15739137e-01, -5.51388562e-02],
         ...,
         [ 3.39823887e-02,  1.53475493e-01, -1.54391572e-01, ...,
           1.09738052e-01, -3.44127007e-02,  3.09078366e-01],
         [ 9.09232646e-02, -3.43415588e-02, -3.24262597e-04, ...,
          -1.17909603e-01, -3.33734527e-02, -1.92636419e-02],
         [ 2.30847187e-02, -1.79902717e-01, -6.95237964e-02, ...,
          -3.12168628e-01,  1.49332166e-01, -5.31788319e-02]]],


       [[[ 1.52736083e-02,  1.70024708e-02, -4.49127220e-02, ...,
           8.97117108e-02,  3.27343680e-02, -2.51081754e-02],
         [ 9.26711969e-03,  6.20251521e-02, -5.39838634e-02, ...,
           4.5889027

# Batch mapping 

In [13]:
batch_test = nn.BatchNorm2d(32)

In [14]:
bn1_mean  = load_parameter("baidu_models/baidu_en8k/params/___batch_norm_0__.w1")
bn1_var   = load_parameter("baidu_models/baidu_en8k/params/___batch_norm_0__.w2")
bn1_gamma = load_parameter("baidu_models/baidu_en8k/params/___batch_norm_0__.w0")
bn1_beta  = load_parameter("baidu_models/baidu_en8k/params/___batch_norm_0__.wbias")

In [15]:
batch_test.bias.data         = torch.from_numpy(bn1_beta)
batch_test.weight.data       = torch.from_numpy(bn1_gamma)
batch_test.running_mean.data = torch.from_numpy(bn1_mean)
batch_test.running_var.data  = torch.from_numpy(bn1_var)

In [16]:
weight_dict = {"bias"        :torch.from_numpy(bn1_beta),
               "weight"      : torch.from_numpy(bn1_gamma),
               "running_mean": torch.from_numpy(bn1_mean),
               "running_var" : torch.from_numpy(bn1_var),
              }

In [17]:
result_batch = batch_test(result)
result_batch.size()

torch.Size([1, 32, 81, 2208])

In [18]:
for key, value in batch_test.state_dict().items():
    if key == 'num_batches_tracked':
        continue
    print(value.size())
    assert torch.all(torch.eq(value, weight_dict[key] ))

torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([32])


# Conv_bn_mask

In [19]:
class conv_bn_mask(nn.Module):
    def __init__(self, ichannel, ochannel, kernel_size, padding, stride, bias=False, track_running_stats=False):
        super(conv_bn_mask, self).__init__()
        self.conv = nn.Conv2d(ichannel, ochannel, kernel_size=kernel_size, padding=padding, stride=stride, bias=bias)
        self.bn = nn.BatchNorm2d(ochannel, track_running_stats=track_running_stats)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

In [20]:
conv_bn_mask_layer = conv_bn_mask(1, 32, kernel_size=(41,11), padding=(20, 5), stride=(2,3), track_running_stats=True)
result_conv_bn_mask = conv_bn_mask_layer.forward(input)

In [21]:
pretrained_weights = {"conv.weight"    : torch.from_numpy(conv1_weights),
                      "bn.bias"        : torch.from_numpy(bn1_beta),
                      "bn.weight"      : torch.from_numpy(bn1_gamma),
                      "bn.running_mean": torch.from_numpy(bn1_mean),
                      "bn.running_var" : torch.from_numpy(bn1_var)
                     }

In [22]:
conv_bn_mask_layer.load_state_dict(pretrained_weights)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [23]:
check_dict = conv_bn_mask_layer.state_dict()
for key, value in check_dict.items():
    if key == 'bn.num_batches_tracked':
        continue
    assert torch.all(torch.eq(value, pretrained_weights[key]))
    

In [24]:
input = torch.rand((1,1,161,6624))
result = conv_bn_mask_layer.forward(input)
result.size()

torch.Size([1, 32, 81, 2208])

In [28]:
flatten = result.view(1, -1, 2208)

In [29]:
flatten.size()

torch.Size([1, 2592, 2208])

In [36]:
mask = torch.ByteTensor([[1]*3,[1]*3, [0]*3])
hehe = torch.rand((3,3))
hehe

tensor([[0.0750, 0.6854, 0.1815],
        [0.0151, 0.4943, 0.6329],
        [0.8049, 0.5829, 0.0303]])

In [37]:
torch.masked_select(hehe,mask)

tensor([0.0750, 0.6854, 0.1815, 0.0151, 0.4943, 0.6329])

In [57]:
mask = torch.zeros([6, 1, 10,10], dtype=torch.float)
mask[:,:,3:7,3:7] = 1

In [72]:
class mask(nn.Module):
    def __init__(self):
        super(mask, self).__init__()
        
    def forward(self,x, batch_info):
        c1,c2, w1,w2, h1,h2 = batch_info
        mask = torch.zeros_like(x, dtype=torch.float)
        mask[c1:c2, w1:w2, h1:h2] = 1
        return x * mask

In [73]:
mask_layer = mask()

In [20]:
def VariableRecurrent(batch_sizes, inner):
    def forward(input, hidden, weight):
        output = []
        input_offset = 0
        last_batch_size = batch_sizes[0]
        hiddens = []
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
        for batch_size in batch_sizes:
            step_input = input[input_offset:input_offset + batch_size]
            input_offset += batch_size

            dec = last_batch_size - batch_size
            if dec > 0:
                hiddens.append(tuple(h[-dec:] for h in hidden))
                hidden = tuple(h[:-dec] for h in hidden)
            last_batch_size = batch_size

            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)

            output.append(hidden[0])
        hiddens.append(hidden)
        hiddens.reverse()

        hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
        assert hidden[0].size(0) == batch_sizes[0]
        if flat_hidden:
            hidden = hidden[0]
        output = torch.cat(output, 0)

        return hidden, output
    
    return forward

In [54]:
a = torch.tensor([[11]*3,[12]*3, [13]*3])
b = torch.tensor([[21]*3,[22]*3, [23]*3, [24]*3])
c = torch.tensor([[31]*3,[32]*3, [33]*3, [34]*3, [35]*3])
seq = nn.utils.rnn.pack_sequence([c,b,a], )
input, batch_sizes, _,_ = seq

In [55]:
seq.unsorted_indices = seq.to

PackedSequence(data=tensor([[31, 31, 31],
        [21, 21, 21],
        [11, 11, 11],
        [32, 32, 32],
        [22, 22, 22],
        [12, 12, 12],
        [33, 33, 33],
        [23, 23, 23],
        [13, 13, 13],
        [34, 34, 34],
        [24, 24, 24],
        [35, 35, 35]]), batch_sizes=tensor([3, 3, 3, 2, 1]), sorted_indices=None, unsorted_indices=None)

In [52]:
seq.unsorted_indices

In [48]:
def gru_test(input, hidden, weight):
    h, c = hidden
    hy = h + 1
    cy = c + 1
    return hy, cy 

In [49]:
forward = VariableRecurrent(batch_sizes, gru_test)

In [50]:
forward(input, (torch.tensor([0]*3), torch.tensor([0]*3)), [0])

((tensor([5, 4, 3]), tensor([5, 4, 3])),
 tensor([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5]))