### Comparing `deepspeech_internal` and `myrtlespeech` rnnt implementations
Conclusion - there were differences that made the results not equivalent:
1. The packed sequence in the `myrtlespeech` implementation requires the length of the sequence. This was off by one because the SOS token was added. This means the final element in the `myrtlespeech` target sequence was not being considered during training. When this is fixed, the WER = 24.6058% on dev-clean. 
2. There is no subsampling in the preprocessing of `myrtlespeech`, so the network receives audio inputs at twice the resolution that the weights were trained at in `deepspeech_internal`.
3. Different preprocessing (`python_speech_features` vs `torchaudio`): the discrepancy may simply be because the feature values are not exactly the same, OR it may suggest that the `myrtlespeech` preprocessing is not good at extracting features.
4. Activations - x2 extra hardtanhs in encoder. 
5. Left context vs symmetric context - this may be relevant at the start of the sequence, i.e. the network may learn that it is at the SOS when there are zeros in the leftmost frames?

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
cd ..

In [None]:
from dsi.rnnt_dsi import Network as RNNT_DSI
from myrtlespeech.builders.rnn_t import build as build_rnn_t
from myrtlespeech.protos import task_config_pb2


import torch
from google.protobuf import text_format

In [None]:
import os
# An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA in this process.
# NOTE(review): presumably set so that both networks are compared on the CPU - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
# Dimensions shared by the two networks built in this notebook.
input_features = 80
ncontext = 4
vocab_size=28
# The DSI network is constructed with input_features * ncontext as its first
# argument and the vocabulary size as its second.
rnnt_dsi = RNNT_DSI(input_features * ncontext, vocab_size)
# NOTE(review): presumably disables half precision inside the DSI network - confirm.
rnnt_dsi.is_half = False
rnnt_dsi.eval()


In [None]:
# Build the myrtlespeech (MS) RNN-T from its text-format protobuf config.
# The config path is relative, so the working directory must contain `src/`.
with open("src/myrtlespeech/configs/rnn_t_en_ds_int.config") as f:
    task_config = text_format.Merge(f.read(), task_config_pb2.TaskConfig())

rnnt_ms = build_rnn_t(
    task_config.speech_to_text.rnn_t,
    input_features=input_features,
    input_channels=ncontext + 1,
    vocab_size=vocab_size,
)
rnnt_ms.eval()

In [None]:
# Total number of parameter elements in each network, printed MS first, then DSI.
ms_params = sum(p.numel() for p in rnnt_ms.parameters())
print(ms_params)
dsi_params = sum(p.numel() for p in rnnt_dsi.parameters())
print(dsi_params)


### Make weights the same

In [None]:
                   #ds_int: ms
dict_map_partial = {"encoder.0": "encode.fc1.fully_connected.0",
               "encoder.3": "encode.fc1.fully_connected.3",
               "encoder.6.layers.0": "encode.rnn1",
               "encoder.6.layers.2.rnn.weight_ih_l0": "encode.rnn1.rnn.weight_ih_l1",
               "encoder.6.layers.2.rnn.weight_hh_l0": "encode.rnn1.rnn.weight_hh_l1",
               "encoder.6.layers.2.rnn.bias_ih_l0": "encode.rnn1.rnn.bias_ih_l1",
               "encoder.6.layers.2.rnn.bias_hh_l0": "encode.rnn1.rnn.bias_hh_l1",
                "encoder.8": "encode.fc2.fully_connected.0",
                "encoder.11": "encode.fc2.fully_connected.3",
                "prediction.dec_rnn.layers.0": "predict_net.dec_rnn",
                "prediction.dec_rnn.layers.2.rnn.weight_ih_l0": "predict_net.dec_rnn.rnn.weight_ih_l1",
                "prediction.dec_rnn.layers.2.rnn.weight_hh_l0": "predict_net.dec_rnn.rnn.weight_hh_l1",
                "prediction.dec_rnn.layers.2.rnn.bias_ih_l0": "predict_net.dec_rnn.rnn.bias_ih_l1",
                "prediction.dec_rnn.layers.2.rnn.bias_hh_l0": "predict_net.dec_rnn.rnn.bias_hh_l1",
                "prediction.embed": "predict_net.embed",
            "joint_net.0": "joint_net.fully_connected.fully_connected.0",
            "joint_net.3": "joint_net.fully_connected.fully_connected.3"}

def get_keys(model_):
    """Return the parameter names of ``model_`` as a list.

    The names come from ``model_.named_parameters()`` and keep the order in
    which that iterator yields them.
    """
    return [name for name, _ in model_.named_parameters()]

# Expand the fragment-level mapping into a full
# {DSI parameter name: MS parameter name} dictionary.
dict_map = {}
ms_keys = get_keys(rnnt_ms)
dsi_keys = get_keys(rnnt_dsi)
for mskey in ms_keys:
    # Collect every fragment pair whose MS fragment occurs in this MS name.
    matches = [
        (p_dsikey, p_mskey)
        for p_dsikey, p_mskey in dict_map_partial.items()
        if p_mskey in mskey
    ]
    assert matches, f"Did not find key={mskey}"
    # Several fragments can match one name: "encode.rnn1" is a substring of
    # "encode.rnn1.rnn.weight_ih_l1", which also has its own explicit entry.
    # Use only the longest (most specific) fragment so that each MS name
    # yields exactly one DSI name instead of an extra, unintended entry.
    p_dsikey, p_mskey = max(matches, key=lambda pair: len(pair[1]))
    dict_map[mskey.replace(p_mskey, p_dsikey)] = mskey
dict_map

In [None]:
## Give both networks the same parameters by loading the DSI weights into the MS network.
# Start from the MS state dict and overwrite each entry that dict_map pairs
# with a DSI parameter; a DSI name missing from dict_map raises KeyError.
state_dict_ms = rnnt_ms.state_dict()
state_dict_ms.update(
    {dict_map[dsikey]: param for dsikey, param in rnnt_dsi.named_parameters()}
)

rnnt_ms.load_state_dict(state_dict_ms)


In [None]:
# Move both networks to the CPU before comparing them.
rnnt_dsi.cpu()
rnnt_ms.cpu()

### Create input data

In [None]:
# Sizes of the random test batch.
batch = 2
input_channels = ncontext + 1
label_seq_len = 3
seq_len = 4

# MS input: ((features, labels), (feature lengths, label lengths)).
# torch.randint excludes `high`, so the sampled lengths stay below
# seq_len / label_seq_len and the labels stay below vocab_size - 1.
x = torch.empty((batch, input_channels, input_features, seq_len)).normal_()
seq_lens = torch.randint(1, seq_len, (batch,), dtype=torch.long)
y = torch.randint(0, vocab_size - 1, (batch, label_seq_len), dtype=torch.long)
label_seq_lens = torch.randint(1, label_seq_len, (batch,), dtype=torch.long)
input_ms = ((x, y), (seq_lens, label_seq_lens))

## now create for dsi
# Drop the first channel so that ncontext channels remain.
x_dsi = x[:, 1:]
C = x_dsi.shape[1]  # new channel size
assert C == ncontext

# Fold channels and features into one dim, then move time to the front.
x_dsi = (
    x_dsi.view(batch, C * input_features, seq_len)  # B, F, T
    .permute(2, 0, 1)  # T, B, F
    .contiguous()
)

input_dsi = (x_dsi, y)


In [None]:
# Fixed-length variant of the inputs: every sequence and every label sequence
# uses its full length. The DSI input tuple carries no lengths, so presumably
# this is what makes the two networks directly comparable - confirm.
fixed_len = True
if fixed_len:
    x = torch.empty((batch, input_channels, input_features, seq_len)).normal_()

    # All feature sequences have the full length seq_len.
    seq_lens = torch.IntTensor([seq_len] * batch)
    y = torch.randint(
        low=0,
        high=vocab_size - 1,
        size=(batch, label_seq_len),
        dtype=torch.long,
    )

    # All label sequences have the full length label_seq_len.
    label_seq_lens = torch.IntTensor([label_seq_len] * batch)

    input_ms = ((x, y), (seq_lens, label_seq_lens))

    ## now create for dsi
    # Drop the first channel so that ncontext channels remain.
    x_dsi = x[:, 1:]

    _, C, _, _ = x_dsi.shape  # get new channel size

    assert C == ncontext, f"{C} != {ncontext}"
    x_dsi = x_dsi.view(batch, C * input_features, seq_len)  # B, F, T
    x_dsi = x_dsi.permute(2, 0, 1).contiguous()  # T, B, F

    input_dsi = (x_dsi, y)

In [None]:
# Forward pass through the MS network; the second returned value is discarded.
ms_out, _ = rnnt_ms(input_ms)
ms_out = ms_out.cpu()
ms_out.shape

# Forward pass through the DSI network on the corresponding input.
dsi_out = rnnt_dsi(input_dsi)
dsi_out.shape

In [None]:
# End-to-end check: the outputs of the two networks are expected to agree
# within torch.allclose's default tolerances.
assert torch.allclose(dsi_out, ms_out)

### check individual elements of network

In [None]:
# encoder: run only the encoder of each network and compare the outputs.

enc_dsi_out = rnnt_dsi.encode(input_dsi[0])
# The MS encoder is called with (features, feature lengths).
enc_ms_out = rnnt_ms.encode((input_ms[0][0], input_ms[1][0]))
enc_dsi_out = enc_dsi_out.cpu()
# Move only the first element of the MS result to the CPU; the second is kept as is.
enc_ms_out = enc_ms_out[0].cpu(), enc_ms_out[1]

# transpose(0, 1) swaps the first two dims of the MS output before comparing.
# NOTE(review): presumably batch-first vs time-first layouts - confirm.
assert torch.allclose(enc_dsi_out, enc_ms_out[0].transpose(0,1))

In [None]:
# prediction: run only the prediction network of each model on the labels.
# The DSI call returns a pair; its second element is discarded here.
pred_dsi_out, _ = rnnt_dsi.predict(input_dsi[1])
# The MS call receives (labels, label lengths).
pred_ms_out = rnnt_ms.prediction((input_ms[0][1], input_ms[1][1]))

# Print both shapes before comparing values.
print(pred_dsi_out.shape)
print(pred_ms_out[0].shape)
assert torch.allclose(pred_dsi_out, pred_ms_out[0])

In [None]:
# dec_rnn: feed the same random input directly to the decoder RNN of each network.
# NOTE(review): 256 is presumably the dec_rnn input size - confirm against the config.
y = torch.empty((batch, label_seq_len, 256)).normal_()
# The DSI module receives the input with its first two dims swapped.
y_dsi = y.transpose(1, 0)
# Every sequence uses the full label length; sized by `batch` rather than a
# hard-coded two-element list so it stays valid if `batch` changes.
y_lens = torch.IntTensor([label_seq_len] * batch)
state = None

dec_rnn_dsi_out, dec_rnn_dsi_hid = rnnt_dsi.prediction["dec_rnn"](y_dsi, state)
(dec_rnn_ms_out, dec_rnn_ms_hid), lengths = rnnt_ms.dec_rnn(((y, state), y_lens))
lengths, label_seq_len

In [None]:

# squeeze(1) removes dim 1 of each of the two DSI hidden-state tensors when
# that dim has size 1, and leaves the tensor unchanged otherwise.
dec_rnn_dsi_hid = dec_rnn_dsi_hid[0].squeeze(1), dec_rnn_dsi_hid[1].squeeze(1)
# Outputs: swap the first two dims of the MS output to line it up with the DSI output.
assert torch.allclose(dec_rnn_dsi_out, dec_rnn_ms_out.transpose(1, 0))
# Hidden state: compare both elements of the hidden-state pair.
assert torch.allclose(dec_rnn_ms_hid[0], dec_rnn_dsi_hid[0])
assert torch.allclose(dec_rnn_ms_hid[1], dec_rnn_dsi_hid[1])


In [None]:


# Check that the dec_rnn parameters of the two networks agree pairwise, in
# iteration order. zip walks both iterators in lockstep and stops at the
# shorter one, which is what the index-matching nested loops did.
dsi_dec_params = rnnt_dsi.prediction["dec_rnn"].named_parameters()
ms_dec_params = rnnt_ms.dec_rnn.named_parameters()
for (dsi_name, dsi_param), (rnnt_name, rnnt_param) in zip(dsi_dec_params, ms_dec_params):
    print(dsi_name, rnnt_param.type(), dsi_param.type())
    assert torch.allclose(rnnt_param, dsi_param)

In [None]:
# Re-check the first hidden-state tensor and the full output tensor.
assert torch.allclose(dec_rnn_dsi_hid[0], dec_rnn_ms_hid[0])
# Compare the whole MS output. Indexing it with [0] would select a single
# element along the first dim and silently broadcast it against the
# transposed DSI output.
assert torch.allclose(dec_rnn_ms_out, dec_rnn_dsi_out.transpose(1, 0))

In [None]:
# Print the tensor type string of every parameter in the MS network.
for param in rnnt_ms.parameters():
    print(param.type())