In [1]:
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter
from torch.nn.utils.rnn import PackedSequence
import numpy as np
from typing import List, Tuple
import sys
sys.path.append("../model_utils/")
from custom_layers import Conv_bn_mask
from custom_gru import BidirRNNLayer, GRUCell, GRU_hiddenCell,RNNLayer

## Extract PaddlePaddle weights

In [2]:
def load_parameter(file_name, header_bytes=16):
    """Read a raw float32 parameter file into a flat NumPy array.

    The first ``header_bytes`` bytes of the file are discarded; everything
    after them is interpreted as native-endian ``float32`` values.

    Args:
        file_name: Path of the binary parameter file.
        header_bytes: Number of leading bytes to skip. Defaults to 16, the
            value this notebook has always used (presumably the size of the
            PaddlePaddle parameter header -- TODO confirm against the
            checkpoint format).

    Returns:
        1-D ``np.ndarray`` of dtype ``float32``; callers reshape it.
    """
    with open(file_name, 'rb') as f:
        f.read(header_bytes)  # skip header.
        return np.fromfile(f, dtype=np.float32)
    
    
def compute_difference(paddle_outputs, torch_outputs):
    """Return the root-mean-square difference between two arrays.

    Args:
        paddle_outputs: Reference values (array-like supporting ``-``).
        torch_outputs: Values to compare; must broadcast against the reference.

    Returns:
        Scalar RMSE taken over every element.
    """
    residual = paddle_outputs - torch_outputs
    mean_squared_error = np.mean(residual ** 2.0)
    return np.sqrt(mean_squared_error)

In [3]:
class GRUlayer(nn.Module):
    """Bidirectional recurrent layer with a batch-normalised input projection.

    For each direction the packed input is multiplied by that direction's
    ``(input_size, 3 * hidden_size)`` weight and passed through its own
    ``BatchNorm1d`` before the pair of projected sequences is handed to
    ``BidirRNNLayer``.

    Args:
        cell: Cell class forwarded unchanged to ``BidirRNNLayer``.
        input_size: Feature width of the packed input.
        hidden_size: Per-direction hidden width; the gate projection is three
            times as wide.
        gate_act: Forwarded to ``BidirRNNLayer``.
        state_act: Forwarded to ``BidirRNNLayer``.
    """

    def __init__(self, cell, input_size, hidden_size, gate_act, state_act):
        super().__init__()
        gate_width = 3 * hidden_size

        # Attribute names and creation order define the state_dict keys
        # (f_weight_i, b_weight_i, f_bn.*, b_bn.*, rnn.*); keep them stable.
        self.f_weight_i = Parameter(torch.randn(input_size, gate_width))
        self.b_weight_i = Parameter(torch.randn(input_size, gate_width))
        self.f_bn = nn.BatchNorm1d(gate_width)
        self.b_bn = nn.BatchNorm1d(gate_width)
        self.rnn = BidirRNNLayer(cell, hidden_size=hidden_size, gate_act=gate_act, state_act=state_act)

    def forward(self, input: PackedSequence, hidden: List[Tensor]) -> Tuple[PackedSequence, List[Tensor]]:
        """Project ``input`` once per direction and run the bidirectional RNN.

        Args:
            input: Packed batch; its ``data`` must be 2-D because it is fed to
                a matrix-matrix product.
            hidden: Initial state list, passed through to ``self.rnn`` as is.

        Returns:
            Whatever ``self.rnn`` returns for the two projected sequences.
        """
        assert isinstance(input, PackedSequence)
        packed_data = input.data
        layout = (input.batch_sizes, input.sorted_indices, input.unsorted_indices)

        # Forward direction first, then backward. The batch norm sees every
        # packed time step at once, i.e. it normalises sequence-wise over the
        # input part of the gates (DeepSpeech-style).
        projections = ((self.f_weight_i, self.f_bn), (self.b_weight_i, self.b_bn))
        directional_inputs = [
            PackedSequence(norm(packed_data.mm(weight)), *layout)
            for weight, norm in projections
        ]

        return self.rnn(directional_inputs, hidden)

In [4]:
class cbmX2_bigru_layer(nn.Module):
    """Two masked conv+BN blocks, three bidirectional GRU layers and a softmax head.

    ``forward`` flattens the convolution output to ``41 * 32`` features per
    time step, packs it with the given lengths, runs it through the three
    ``GRUlayer`` modules and applies ``Linear(2048, 29)`` + softmax to every
    packed time step.

    The initial hidden states are created with the same device and dtype as
    the activations, so the module also works after ``.to(device)`` or
    ``.double()`` instead of only on default-dtype CPU tensors.
    """

    # Sizes shared by __init__ and forward.
    CONV_FEATURES = 41 * 32  # flattened conv output width fed to bigru0
    HIDDEN_SIZE = 1024       # per-direction GRU state width
    NUM_CLASSES = 28 + 1     # 28 of char + 1 blank

    def __init__(self):
        super(cbmX2_bigru_layer, self).__init__()
        self.conv_bn_mask0 = Conv_bn_mask(ichannel=1,
                                          ochannel=32,
                                          kernel_size=(11, 41),
                                          stride=(3, 2),
                                          padding=(5, 20),
                                          bias=False,
                                          track_running_stats=True)

        self.conv_bn_mask1 = Conv_bn_mask(ichannel=32,
                                          ochannel=32,
                                          kernel_size=(11, 21),
                                          stride=(1, 2),
                                          padding=(5, 10),
                                          bias=False,
                                          track_running_stats=True)

        # Each GRU layer emits both directions concatenated, hence the
        # 2 * HIDDEN_SIZE input width of the layers that follow.
        bidir_width = 2 * self.HIDDEN_SIZE
        self.bigru0 = GRUlayer(GRU_hiddenCell, input_size=self.CONV_FEATURES, hidden_size=self.HIDDEN_SIZE, gate_act="sigmoid", state_act="relu")
        self.bigru1 = GRUlayer(GRU_hiddenCell, input_size=bidir_width, hidden_size=self.HIDDEN_SIZE, gate_act="sigmoid", state_act="relu")
        self.bigru2 = GRUlayer(GRU_hiddenCell, input_size=bidir_width, hidden_size=self.HIDDEN_SIZE, gate_act="sigmoid", state_act="relu")

        # 28 of char + 1 blank
        # vocab list:  ["'", ' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
        self.bottleneck = nn.Linear(bidir_width, self.NUM_CLASSES)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input, length):
        """Run the full stack and return per-step class probabilities.

        Args:
            input: 4-D batch fed to the first convolution block.
            length: Tensor of sequence lengths. The same tensor is handed to
                both mask layers and, flattened, to ``pack_padded_sequence``,
                so it must be sorted in decreasing order.

        Returns:
            ``PackedSequence`` whose data holds a softmax distribution over
            the 29 classes for every packed time step.
        """
        x = input
        seq_length = length
        batch_size = x.shape[0]
        x = self.conv_bn_mask0(x, seq_length)
        x = self.conv_bn_mask1(x, seq_length)
        # Swap dims 1 and 2 so that the features of one time step are
        # contiguous before they are flattened.
        x = x.transpose(2, 1).contiguous()
        flattened_x = x.reshape(batch_size, -1, self.CONV_FEATURES)

        # pack_padded_sequence requires the lengths tensor to live on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(flattened_x, seq_length.flatten().cpu(), batch_first=True)

        for gru in (self.bigru0, self.bigru1, self.bigru2):
            # Fresh zero state per layer, matching the activations' device and
            # dtype (a plain torch.zeros would always be default-dtype on CPU).
            initial_state = flattened_x.new_zeros((2, batch_size, self.HIDDEN_SIZE))
            packed, _ = gru(packed, [initial_state])

        data_x, batch_sizes, sorted_indices, unsorted_indices = packed
        data_x = self.bottleneck(data_x)
        data_x = self.softmax(data_x)  # dim=1: normalise over the classes of each step
        return nn.utils.rnn.PackedSequence(data_x, batch_sizes, sorted_indices, unsorted_indices)


In [5]:
# Instantiate the network and put it in inference mode; nn.Module.eval()
# returns the module itself, so the call can be chained.
cbmX2_bigru_test = cbmX2_bigru_layer().eval()

# TODO: double-check the normalization part.

# Trailing expression so the cell still displays the module tree.
cbmX2_bigru_test

cbmX2_bigru_layer(
  (conv_bn_mask0): Conv_bn_mask(
    (conv): Conv2d(1, 32, kernel_size=(11, 41), stride=(3, 2), padding=(5, 20), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): BReLU()
    (mask): Mask()
  )
  (conv_bn_mask1): Conv_bn_mask(
    (conv): Conv2d(32, 32, kernel_size=(11, 21), stride=(1, 2), padding=(5, 10), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): BReLU()
    (mask): Mask()
  )
  (bigru0): GRUlayer(
    (f_bn): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (b_bn): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rnn): BidirRNNLayer(
      (directions): ModuleList(
        (0): RNNLayer(
          (cell): GRU_hiddenCell()
        )
        (1): ReverseRNNLayer(
          original_name=ReverseRNNLayer
          (cell): GRU_hiddenCell()
        )
 

In [6]:
# List every entry of the freshly built model's state dict; these are the
# names the converted checkpoint arrays have to be mapped onto.
cbmX2_bigru_test.state_dict().keys()
# TODO: load parameters

odict_keys(['conv_bn_mask0.conv.weight', 'conv_bn_mask0.bn.weight', 'conv_bn_mask0.bn.bias', 'conv_bn_mask0.bn.running_mean', 'conv_bn_mask0.bn.running_var', 'conv_bn_mask0.bn.num_batches_tracked', 'conv_bn_mask1.conv.weight', 'conv_bn_mask1.bn.weight', 'conv_bn_mask1.bn.bias', 'conv_bn_mask1.bn.running_mean', 'conv_bn_mask1.bn.running_var', 'conv_bn_mask1.bn.num_batches_tracked', 'bigru0.f_weight_i', 'bigru0.b_weight_i', 'bigru0.f_bn.weight', 'bigru0.f_bn.bias', 'bigru0.f_bn.running_mean', 'bigru0.f_bn.running_var', 'bigru0.f_bn.num_batches_tracked', 'bigru0.b_bn.weight', 'bigru0.b_bn.bias', 'bigru0.b_bn.running_mean', 'bigru0.b_bn.running_var', 'bigru0.b_bn.num_batches_tracked', 'bigru0.rnn.directions.0.cell.weight_h', 'bigru0.rnn.directions.0.cell.bias', 'bigru0.rnn.directions.1.cell.weight_h', 'bigru0.rnn.directions.1.cell.bias', 'bigru1.f_weight_i', 'bigru1.b_weight_i', 'bigru1.f_bn.weight', 'bigru1.f_bn.bias', 'bigru1.f_bn.running_mean', 'bigru1.f_bn.running_var', 'bigru1.f_bn.nu

In [7]:
# Directory holding the raw PaddlePaddle parameter files.
PARAMS_DIR = "../models/baidu_en8k/params/"
GRU_HIDDEN = 1024  # per-direction GRU state width


def _load_param(name):
    """Load one parameter file from PARAMS_DIR as a flat float32 array."""
    return load_parameter(PARAMS_DIR + name)


def _load_conv_weight(index, out_channels, in_channels, kernel_size):
    """Load ``___conv_<index>__.w0`` as an (out, in, kH, kW) array.

    ``kernel_size`` is given in PyTorch order ``(kH, kW)``. The file stores the
    two kernel axes the other way round, so they are swapped after reshaping.
    """
    kernel_h, kernel_w = kernel_size
    stored = _load_param("___conv_%d__.w0" % index).reshape(out_channels, in_channels, kernel_w, kernel_h)
    return np.transpose(stored, (0, 1, 3, 2))


def _load_batch_norm(index):
    """Return ``(gamma, beta, running_mean, running_var)`` of ``___batch_norm_<index>__``.

    File suffixes: ``.w0`` -> gamma/weight, ``.wbias`` -> beta/bias,
    ``.w1`` -> running mean, ``.w2`` -> running variance.
    """
    prefix = "___batch_norm_%d__" % index
    return (_load_param(prefix + ".w0"),
            _load_param(prefix + ".wbias"),
            _load_param(prefix + ".w1"),
            _load_param(prefix + ".w2"))


def _load_gru_direction(index, input_size):
    """Load every tensor of one GRU direction.

    Direction ``index`` reads ``___fc_layer_<index>__`` (input projection),
    ``___gru_<index>__`` (recurrent weight and bias) and
    ``___batch_norm_<index + 2>__`` (batch norm 0 and 1 belong to the convs).

    Returns:
        ``(weight_i, weight_h, bias, bn_bias, bn_weight, bn_running_mean,
        bn_running_var)``.
    """
    weight_i = _load_param("___fc_layer_%d__.w0" % index).reshape(input_size, 3 * GRU_HIDDEN)

    # The recurrent weight file holds a (H, 2H) block followed by a (H, H)
    # block (called w_u_r and w_c in the original notebook, presumably the
    # update/reset gates and the candidate state). They are laid side by side
    # to form one (H, 3H) matrix.
    flat_h = _load_param("___gru_%d__.w0" % index).flatten()
    split = 2 * GRU_HIDDEN * GRU_HIDDEN
    w_u_r = flat_h[:split].reshape(GRU_HIDDEN, 2 * GRU_HIDDEN)
    w_c = flat_h[split:].reshape(GRU_HIDDEN, GRU_HIDDEN)
    weight_h = np.concatenate([w_u_r, w_c], 1)

    bias = _load_param("___gru_%d__.wbias" % index)
    bn_weight, bn_bias, bn_mean, bn_var = _load_batch_norm(index + 2)
    return weight_i, weight_h, bias, bn_bias, bn_weight, bn_mean, bn_var


# conv0 / conv1
conv0_weights = _load_conv_weight(0, out_channels=32, in_channels=1, kernel_size=(11, 41))
conv0_bn_gamma, conv0_bn_beta, conv0_bn_mean, conv0_bn_var = _load_batch_norm(0)

conv1_weights = _load_conv_weight(1, out_channels=32, in_channels=32, kernel_size=(11, 21))
conv1_bn_gamma, conv1_bn_beta, conv1_bn_mean, conv1_bn_var = _load_batch_norm(1)

# gru0 (input is the flattened conv output: 41 * 32 features)
(bigru0_directions_0_cell_weight_i,
 bigru0_directions_0_cell_weight_h,
 bigru0_directions_0_cell_bias,
 bigru0_directions_0_cell_bn_bias,
 bigru0_directions_0_cell_bn_weight,
 bigru0_directions_0_cell_bn_running_mean,
 bigru0_directions_0_cell_bn_running_var) = _load_gru_direction(0, input_size=41 * 32)

(bigru0_directions_1_cell_weight_i,
 bigru0_directions_1_cell_weight_h,
 bigru0_directions_1_cell_bias,
 bigru0_directions_1_cell_bn_bias,
 bigru0_directions_1_cell_bn_weight,
 bigru0_directions_1_cell_bn_running_mean,
 bigru0_directions_1_cell_bn_running_var) = _load_gru_direction(1, input_size=41 * 32)

# gru1
(bigru1_directions_0_cell_weight_i,
 bigru1_directions_0_cell_weight_h,
 bigru1_directions_0_cell_bias,
 bigru1_directions_0_cell_bn_bias,
 bigru1_directions_0_cell_bn_weight,
 bigru1_directions_0_cell_bn_running_mean,
 bigru1_directions_0_cell_bn_running_var) = _load_gru_direction(2, input_size=2048)

(bigru1_directions_1_cell_weight_i,
 bigru1_directions_1_cell_weight_h,
 bigru1_directions_1_cell_bias,
 bigru1_directions_1_cell_bn_bias,
 bigru1_directions_1_cell_bn_weight,
 bigru1_directions_1_cell_bn_running_mean,
 bigru1_directions_1_cell_bn_running_var) = _load_gru_direction(3, input_size=2048)

# gru2
(bigru2_directions_0_cell_weight_i,
 bigru2_directions_0_cell_weight_h,
 bigru2_directions_0_cell_bias,
 bigru2_directions_0_cell_bn_bias,
 bigru2_directions_0_cell_bn_weight,
 bigru2_directions_0_cell_bn_running_mean,
 bigru2_directions_0_cell_bn_running_var) = _load_gru_direction(4, input_size=2048)

(bigru2_directions_1_cell_weight_i,
 bigru2_directions_1_cell_weight_h,
 bigru2_directions_1_cell_bias,
 bigru2_directions_1_cell_bn_bias,
 bigru2_directions_1_cell_bn_weight,
 bigru2_directions_1_cell_bn_running_mean,
 bigru2_directions_1_cell_bn_running_var) = _load_gru_direction(5, input_size=2048)

# Output layer. The file is reshaped to (2048, 29) and transposed to the
# (29, 2048) layout of nn.Linear.weight. The "battleneck" spelling is kept
# because the mapping cell refers to these names.
battleneck_weight = _load_param("___fc_layer_6__.w0").reshape(2048, 29).transpose(1, 0)
battleneck_bias = _load_param("___fc_layer_6__.wbias")

In [8]:
# Sanity check: joining the (1024, 2048) and (1024, 1024) blocks along axis 1
# should give a (1024, 3072) recurrent weight matrix.
bigru2_directions_1_cell_weight_h.shape

(1024, 3072)

In [9]:
# Maps each state_dict key of cbmX2_bigru_layer to the NumPy array loaded from
# the checkpoint. The f_* / directions.0 entries take the `directions_0`
# arrays and the b_* / directions.1 entries take the `directions_1` arrays.
# BatchNorm `num_batches_tracked` entries are not listed; the loading cell
# skips them.
pretrained_weights = {   "conv_bn_mask0.conv.weight"                 : conv0_weights,
                         "conv_bn_mask0.bn.weight"                   : conv0_bn_gamma,
                         "conv_bn_mask0.bn.bias"                     : conv0_bn_beta,
                         "conv_bn_mask0.bn.running_mean"             : conv0_bn_mean,
                         "conv_bn_mask0.bn.running_var"              : conv0_bn_var ,
                         "conv_bn_mask1.conv.weight"                 : conv1_weights,
                         "conv_bn_mask1.bn.weight"                   : conv1_bn_gamma,
                         "conv_bn_mask1.bn.bias"                     : conv1_bn_beta,
                         "conv_bn_mask1.bn.running_mean"             : conv1_bn_mean,
                         "conv_bn_mask1.bn.running_var"              : conv1_bn_var ,
                      
 
                         # First bidirectional GRU layer.
                         "bigru0.f_weight_i"                         : bigru0_directions_0_cell_weight_i,
                         "bigru0.rnn.directions.0.cell.weight_h"     : bigru0_directions_0_cell_weight_h,
                         "bigru0.rnn.directions.0.cell.bias"         : bigru0_directions_0_cell_bias,
                         "bigru0.f_bn.bias"                          : bigru0_directions_0_cell_bn_bias        ,
                         "bigru0.f_bn.weight"                        : bigru0_directions_0_cell_bn_weight      ,
                         "bigru0.f_bn.running_mean"                  : bigru0_directions_0_cell_bn_running_mean,
                         "bigru0.f_bn.running_var"                   : bigru0_directions_0_cell_bn_running_var ,
                         "bigru0.b_weight_i"                         : bigru0_directions_1_cell_weight_i,
                         "bigru0.rnn.directions.1.cell.weight_h"     : bigru0_directions_1_cell_weight_h,
                         "bigru0.rnn.directions.1.cell.bias"         : bigru0_directions_1_cell_bias,
                         "bigru0.b_bn.bias"                          : bigru0_directions_1_cell_bn_bias        ,
                         "bigru0.b_bn.weight"                        : bigru0_directions_1_cell_bn_weight      ,
                         "bigru0.b_bn.running_mean"                  : bigru0_directions_1_cell_bn_running_mean,
                         "bigru0.b_bn.running_var"                   : bigru0_directions_1_cell_bn_running_var ,
                      
                         # Second bidirectional GRU layer.
                         "bigru1.f_weight_i"                         : bigru1_directions_0_cell_weight_i,
                         "bigru1.rnn.directions.0.cell.weight_h"     : bigru1_directions_0_cell_weight_h,
                         "bigru1.rnn.directions.0.cell.bias"         : bigru1_directions_0_cell_bias,
                         "bigru1.f_bn.bias"                          : bigru1_directions_0_cell_bn_bias        ,
                         "bigru1.f_bn.weight"                        : bigru1_directions_0_cell_bn_weight      ,
                         "bigru1.f_bn.running_mean"                  : bigru1_directions_0_cell_bn_running_mean,
                         "bigru1.f_bn.running_var"                   : bigru1_directions_0_cell_bn_running_var ,
                         "bigru1.b_weight_i"                         : bigru1_directions_1_cell_weight_i,
                         "bigru1.rnn.directions.1.cell.weight_h"     : bigru1_directions_1_cell_weight_h,
                         "bigru1.rnn.directions.1.cell.bias"         : bigru1_directions_1_cell_bias,
                         "bigru1.b_bn.bias"                          : bigru1_directions_1_cell_bn_bias        ,
                         "bigru1.b_bn.weight"                        : bigru1_directions_1_cell_bn_weight      ,
                         "bigru1.b_bn.running_mean"                  : bigru1_directions_1_cell_bn_running_mean,
                         "bigru1.b_bn.running_var"                   : bigru1_directions_1_cell_bn_running_var ,
                      
                         # Third bidirectional GRU layer.
                         "bigru2.f_weight_i"                         : bigru2_directions_0_cell_weight_i,
                         "bigru2.rnn.directions.0.cell.weight_h"     : bigru2_directions_0_cell_weight_h,
                         "bigru2.rnn.directions.0.cell.bias"         : bigru2_directions_0_cell_bias,
                         "bigru2.f_bn.bias"                          : bigru2_directions_0_cell_bn_bias        ,
                         "bigru2.f_bn.weight"                        : bigru2_directions_0_cell_bn_weight      ,
                         "bigru2.f_bn.running_mean"                  : bigru2_directions_0_cell_bn_running_mean,
                         "bigru2.f_bn.running_var"                   : bigru2_directions_0_cell_bn_running_var ,
                         "bigru2.b_weight_i"                         : bigru2_directions_1_cell_weight_i,
                         "bigru2.rnn.directions.1.cell.weight_h"     : bigru2_directions_1_cell_weight_h,
                         "bigru2.rnn.directions.1.cell.bias"         : bigru2_directions_1_cell_bias,
                         "bigru2.b_bn.bias"                          : bigru2_directions_1_cell_bn_bias        ,
                         "bigru2.b_bn.weight"                        : bigru2_directions_1_cell_bn_weight      ,
                         "bigru2.b_bn.running_mean"                  : bigru2_directions_1_cell_bn_running_mean,
                         "bigru2.b_bn.running_var"                   : bigru2_directions_1_cell_bn_running_var ,
                      
                         # Output layer (variables are spelled "battleneck" where they are loaded).
                         "bottleneck.weight"                         : battleneck_weight, 
                         "bottleneck.bias"                           : battleneck_bias,
                     }

In [10]:
# Start from the model's own state dict so that the entries we do not convert
# (the BatchNorm ``num_batches_tracked`` counters) keep their current values.
check_dict = cbmX2_bigru_test.state_dict()
expected_keys = [key for key in check_dict if 'num_batches_tracked' not in key]

# Validate the mapping in both directions before touching anything: every
# model entry needs a converted array, and every converted array must belong
# to a model entry (this catches stale or misspelled names).
missing_keys = [key for key in expected_keys if key not in pretrained_weights]
unexpected_keys = [key for key in pretrained_weights if key not in check_dict]
assert not missing_keys, "no converted weights for: %s" % missing_keys
assert not unexpected_keys, "converted weights unknown to the model: %s" % unexpected_keys

for key in expected_keys:
    check_dict[key] = torch.from_numpy(pretrained_weights[key])

# load_state_dict checks shapes and copies the values into the parameters;
# as the last expression it also displays the key-matching report.
cbmX2_bigru_test.load_state_dict(check_dict)

<All keys matched successfully>

In [11]:
# NOTE(review): allow_pickle=True unpickles arbitrary objects. Only load
# dumps that were produced by a trusted source.
raw_samples = np.load("paddle/cbmX2_bigru_f_input.npy", allow_pickle=True, encoding="bytes")
np_output = np.load("paddle/cbmX2_bigru_f_output.npy", allow_pickle=True, encoding="bytes")

# Every record is indexable: field 0 is stacked into the model input and
# field 3 is used as that sample's sequence length.
np_input = np.array([sample[0] for sample in raw_samples])
np_length = np.array([sample[3] for sample in raw_samples]).astype(int)

torch_input = torch.from_numpy(np_input)
np.testing.assert_array_almost_equal(torch_input, np_input, decimal=10)

# Insert a singleton dim at position 1 and swap the last two axes to get the
# 4-D layout the first convolution block is fed with.
torch_input = torch.unsqueeze(torch_input, 1).transpose(3, 2)

# Pure inference: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    torch_output = cbmX2_bigru_test(torch_input, torch.from_numpy(np_length))

decimal = 5  # not used by the cells shown here; kept in case later cells rely on it

In [12]:
n = 0
# Unpack to a padded (batch, time, classes) array and pick sample n.
padded_output = nn.utils.rnn.pad_packed_sequence(torch_output, batch_first=True)[0]
test = padded_output.detach().numpy()[n]
# Compare only the valid time steps; their count comes from the reference
# output itself instead of a hard-coded length.
compute_difference(np_output[n], test[:np_output[n].shape[0]])

5.3051828e-08

In [13]:
n = 1
# Unpack to a padded (batch, time, classes) array and pick sample n.
padded_output = nn.utils.rnn.pad_packed_sequence(torch_output, batch_first=True)[0]
test = padded_output.detach().numpy()[n]
# Compare only the valid time steps; their count comes from the reference
# output itself instead of a hard-coded length.
compute_difference(np_output[n], test[:np_output[n].shape[0]])

3.2889425e-08

In [14]:
n = 2
# Unpack to a padded (batch, time, classes) array and pick sample n.
padded_output = nn.utils.rnn.pad_packed_sequence(torch_output, batch_first=True)[0]
test = padded_output.detach().numpy()[n]
# Compare only the valid time steps; their count comes from the reference
# output itself instead of a hard-coded length.
compute_difference(np_output[n], test[:np_output[n].shape[0]])

5.0018734e-08

In [15]:
# Shape of the first reference output: (time steps, classes).
np_output[0].shape

(85, 29)

In [16]:
# Display the module tree of the model after the converted weights were loaded.
cbmX2_bigru_test

cbmX2_bigru_layer(
  (conv_bn_mask0): Conv_bn_mask(
    (conv): Conv2d(1, 32, kernel_size=(11, 41), stride=(3, 2), padding=(5, 20), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): BReLU()
    (mask): Mask()
  )
  (conv_bn_mask1): Conv_bn_mask(
    (conv): Conv2d(32, 32, kernel_size=(11, 21), stride=(1, 2), padding=(5, 10), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): BReLU()
    (mask): Mask()
  )
  (bigru0): GRUlayer(
    (f_bn): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (b_bn): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rnn): BidirRNNLayer(
      (directions): ModuleList(
        (0): RNNLayer(
          (cell): GRU_hiddenCell()
        )
        (1): ReverseRNNLayer(
          original_name=ReverseRNNLayer
          (cell): GRU_hiddenCell()
        )
 

In [17]:
# Shape of the tensor that was fed to the model (after the unsqueeze/transpose step).
torch_input.shape

torch.Size([3, 1, 255, 161])