# RNN-T: load pretrained ds_int weights


In [None]:
# Directory for training/eval logs. Later cells build file paths by plain
# string concatenation (e.g. log_dir + "log.csv"), so keep the trailing slash.
log_dir = "/home/julian/exp/speech/myrtlespeech/rnnt/debug/2/"

In [None]:
# IPython magics: reload edited modules (e.g. local myrtlespeech changes)
# automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib
import typing

import torch
from google.protobuf import text_format

from myrtlespeech.model.rnn_t import RNNTEncoder, RNNT
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.clip_grad_norm import ClipGradNorm
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.callbacks.rnn_t_training import RNNTTraining
from myrtlespeech.run.run import ReportRNNTDecoder

In [None]:
from myrtlespeech.run.train import run_stage
from myrtlespeech.run.callbacks.callback import CallbackHandler

In [None]:
import os

# Pin the visible GPU. This presumably only takes effect if no CUDA context
# has been created yet in this kernel.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
# Inputs vary in size, so cuDNN's per-shape autotuning would keep re-running;
# leave it off.
torch.backends.cudnn.benchmark = False

In [None]:
weights_fp = "/home/julian/models/rnnt-53.pt"

# Load the ds_int checkpoint and list every tensor stored under "network".
model = torch.load(weights_fp)
for name, tensor in model["network"].items():
    print(name, tensor.shape)



In [None]:
# Inspect one checkpoint tensor, for comparison with the loaded model later.
# Reuse the checkpoint already in memory instead of re-reading it from disk.
model["network"]["joint_net.3.bias"]

Build the RNN-T model (and the data loaders) defined in the config file:

In [None]:
# parse example config file
config_fp = "../src/myrtlespeech/configs/rnn_t_en_ds_int.config"
with open(config_fp) as config_file:
    config_text = config_file.read()
task_config = text_format.Merge(config_text, task_config_pb2.TaskConfig())


In [None]:
# create all components for config
# FYI: if using train-clean-100 & dev-clean this cell takes O(60s)
components = build(task_config)
seq_to_seq, epochs, train_loader, eval_loader = components


In [None]:
# Display the built seq_to_seq object (the notebook renders the last expression).
seq_to_seq

In [None]:
                    #ds_int: ms
# Partial key map {ds_int checkpoint prefix: myrtlespeech parameter prefix}.
# For the recurrent stacks, ds_int's separate `layers.2` RNN (its `_l0`
# tensors) maps onto the `_l1` tensors of myrtlespeech's single rnn module.
dict_map_partial = {
    # encoder input fully-connected block
    "encoder.0": "encode.fc1.fully_connected.0",
    "encoder.3": "encode.fc1.fully_connected.3",
    # encoder recurrent layers
    "encoder.6.layers.0": "encode.rnn1",
    "encoder.6.layers.2.rnn.weight_ih_l0": "encode.rnn1.rnn.weight_ih_l1",
    "encoder.6.layers.2.rnn.weight_hh_l0": "encode.rnn1.rnn.weight_hh_l1",
    "encoder.6.layers.2.rnn.bias_ih_l0": "encode.rnn1.rnn.bias_ih_l1",
    "encoder.6.layers.2.rnn.bias_hh_l0": "encode.rnn1.rnn.bias_hh_l1",
    # encoder output fully-connected block
    "encoder.8": "encode.fc2.fully_connected.0",
    "encoder.11": "encode.fc2.fully_connected.3",
    # prediction network recurrent layers
    "prediction.dec_rnn.layers.0": "predict_net.dec_rnn",
    "prediction.dec_rnn.layers.2.rnn.weight_ih_l0": "predict_net.dec_rnn.rnn.weight_ih_l1",
    "prediction.dec_rnn.layers.2.rnn.weight_hh_l0": "predict_net.dec_rnn.rnn.weight_hh_l1",
    "prediction.dec_rnn.layers.2.rnn.bias_ih_l0": "predict_net.dec_rnn.rnn.bias_ih_l1",
    "prediction.dec_rnn.layers.2.rnn.bias_hh_l0": "predict_net.dec_rnn.rnn.bias_hh_l1",
    # prediction network embedding
    "prediction.embed": "predict_net.embed",
    # joint network
    "joint_net.0": "joint_net.fully_connected.fully_connected.0",
    "joint_net.3": "joint_net.fully_connected.fully_connected.3",
}



In [None]:
def get_keys(model_):
    """Return the names of all parameters of ``model_`` in ``named_parameters`` order."""
    return [name for name, _ in model_.named_parameters()]


def _matches_prefix(key, prefix):
    # True when `prefix` is all of `key` or a leading run of its dot-separated
    # components (so "encode.rnn1" does not match a hypothetical "encode.rnn10").
    return key == prefix or key.startswith(prefix + ".")


# Expand dict_map_partial into a full map {ds_int key: myrtlespeech key}.
# Each myrtlespeech parameter uses its *longest* matching partial prefix. For
# example, "encode.rnn1.rnn.weight_ih_l1" takes its dedicated layer-1 entry
# rather than the generic "encode.rnn1" entry, which would add a spurious key.
dict_map = {}
ms_keys = get_keys(seq_to_seq.model)
for mskey in ms_keys:
    candidates = [
        (p_dsikey, p_mskey)
        for p_dsikey, p_mskey in dict_map_partial.items()
        if _matches_prefix(mskey, p_mskey)
    ]
    if not candidates:
        raise KeyError(f"Did not find key={mskey}")
    p_dsikey, p_mskey = max(candidates, key=lambda pair: len(pair[1]))
    dict_map[p_dsikey + mskey[len(p_mskey):]] = mskey
dict_map

In [None]:
## pad input fc with zeros:
# When add_zeros is set, prepend zero-valued columns along the last dim of the
# checkpoint's first encoder weight (encoder.0.weight). That weight is
# presumably (out_features, in_features), so this widens the accepted input.
# The assert pins the expected result to (1152, 400).
add_zeros = False
if add_zeros:
    n_pad_features = 80
    val = model["network"]["encoder.0.weight"]
    print(val.shape, type(val), val.dtype, val.type())
    # Create the zeros directly in the weight's dtype and on its device so
    # torch.cat needs no conversion.
    zeros = torch.zeros(
        val.shape[0], n_pad_features, dtype=val.dtype, device=val.device
    )
    val = torch.cat([zeros, val], dim=-1).contiguous()
    assert val.shape == (1152, 400)
    model["network"]["encoder.0.weight"] = val

In [None]:
## update to new params
# Overwrite every model entry that has a ds_int counterpart, then load the
# merged dict back into the model (the cell displays load_state_dict's result).
state_dict = seq_to_seq.model.state_dict()
state_dict.update(
    (dict_map[dsikey], param) for dsikey, param in model["network"].items()
)
seq_to_seq.model.load_state_dict(state_dict)


In [None]:
# Spot-check one loaded tensor against the checkpoint's "joint_net.3.bias" shown earlier.
seq_to_seq.model.state_dict()["joint_net.fully_connected.fully_connected.3.bias"]

In [None]:
## assert it has worked
# Compare each ds_int tensor with the parameter it was mapped to. Matching the
# specific pair (rather than "any parameter of the same shape") catches two
# parameters that share a shape but were swapped by the mapping.
ms_params = dict(seq_to_seq.model.named_parameters())
checked_keys = set()
for dsikey, dsparam in model["network"].items():
    key = dict_map[dsikey]
    param = ms_params[key].detach()
    assert param.shape == dsparam.shape, f"Shape mismatch for param = {key}"
    # Compare in the checkpoint tensor's dtype and on its device. The original
    # hard-coded `.half()`, presumably because the checkpoint is fp16.
    assert torch.allclose(param.to(dsparam), dsparam), f"Did not find param = {key}"
    checked_keys.add(key)
    print(f"found param {key}")

unchecked = sorted(set(ms_params) - checked_keys)
assert not unchecked, f"Did not find params = {unchecked}"

In [None]:
# save weights
model_dir = "/home/julian/models/"

save_model = True
# The file name records whether encoder.0.weight was zero-padded above.
fp_out = model_dir + (
    "dsint_imported_w_zeros.pt" if add_zeros else "dsint_imported_no_zeros.pt"
)
if save_model:
    torch.save(seq_to_seq.model.state_dict(), fp_out)

In [None]:
# check number of params
check = False

if check:
    old_model_dict = torch.load(weights_fp)["network"]
    new_model_dict = torch.load(fp_out)

    # Count elements once per state dict. There is deliberately no
    # `for model in ...` loop here: that would rebind the notebook-level
    # `model` checkpoint and make the `del model["network"]` below raise KeyError.
    old = sum(p.numel() for p in old_model_dict.values())
    new = sum(p.numel() for p in new_model_dict.values())
    print(old)
    print(new)

    if add_zeros:
        # The padding cell adds a (1152, 80) block of zeros.
        assert new - old == 1152 * 80, "Failed"
    else:
        assert new == old
    print(f"passed, new={new}, old={old}")

# Drop the checkpoint's "network" entry now that its values are loaded into
# seq_to_seq.model. Earlier cells that read model["network"] will fail if re-run.
del model["network"]

### Optionally change the decoder

In [None]:
use_beam = False
no_max = False

from myrtlespeech.post_process.rnn_t_beam_decoder import RNNTBeamDecoder
from myrtlespeech.post_process.rnn_t_greedy_decoder import RNNTGreedyDecoder

# Arguments shared by both decoders. NOTE(review): 28 is presumably the blank
# symbol's index in seq_to_seq.alphabet; confirm.
decoder_kwargs = dict(blank_index=28, max_symbols_per_step=4, model=seq_to_seq.model)
if use_beam:
    decoder = RNNTBeamDecoder(beam_width=4, length_norm=False, **decoder_kwargs)
else:
    decoder = RNNTGreedyDecoder(**decoder_kwargs)

seq_to_seq.post_process = decoder

# Per-step symbol limit, set after construction: 100 when no_max, otherwise 4.
seq_to_seq.post_process.max_symbols_per_step = 100 if no_max else 4

## Callbacks
* Use callbacks to inject features into the training loop.
* It is necessary (for now) to use the `RNNTTraining()` callback, but the others are optional.


In [None]:
# custom callback that prints a one-line summary after each evaluation epoch
class PrintCB(Callback):
    """Print loss, WER and one example transcript at the end of each eval epoch.

    Nothing is printed while ``self.training`` is true. Values are read from
    ``kwargs["reports"]``: the ``"ReportMeanBatchLoss"`` entry and the entry
    named after the class of ``seq_to_seq.post_process`` (its ``"wer"`` and
    ``"transcripts"`` items).
    """

    def on_batch_end(self, **kwargs):
        # Per-batch printing is disabled; this override is an explicit no-op.
        return None

    def on_epoch_end(self, **kwargs):
        if self.training:
            return
        epoch = kwargs["epoch"]

        # Look the reports up separately so a missing loss is not
        # misreported as a missing WER.
        try:
            loss = kwargs["reports"]["ReportMeanBatchLoss"]
        except KeyError:
            print("no loss - is ReportMeanBatchLoss in the callbacks?")
            return

        try:
            wer_reports = kwargs["reports"][seq_to_seq.post_process.__class__.__name__]
            wer = wer_reports["wer"]
            transcripts = wer_reports["transcripts"]
        except KeyError:
            print("no wer - using new decoder?")
            return

        # WER is reported even when no example transcript is available.
        out_str = "{}, loss: {:.8f}, wer: {:.4f}".format(epoch, loss, wer)
        if len(transcripts) > 0:
            pred, exp = transcripts[0]  # show the first (prediction, expected) pair
            out_str += ", pred: {}, exp: {},".format("".join(pred), "".join(exp))
        print(out_str)

# Report keys written by CSVLogger: epoch, the active decoder's WER and the mean batch loss.
decoder_report_name = seq_to_seq.post_process.__class__.__name__
keys_to_log = ["epoch", f"reports/{decoder_report_name}/wer", "reports/ReportMeanBatchLoss"]


In [None]:
run_eval = True

eval_cbs = None
if run_eval:
    # One evaluation pass over eval_loader, without training.
    eval_cbs = [
        RNNTTraining(),
        ReportMeanBatchLoss(),
        ReportRNNTDecoder(seq_to_seq.post_process, seq_to_seq.alphabet),
        CSVLogger(log_dir + "log_eval.csv", keys=keys_to_log),
        PrintCB(),
    ]
    cb_handler = CallbackHandler(eval_cbs, False)
    # NOTE(review): on_train_begin is invoked by hand, presumably to initialise
    # callback state before run_stage; epochs=2 looks arbitrary - confirm.
    cb_handler.on_train_begin(epochs=2)
    run_stage(seq_to_seq, cb_handler, eval_loader, is_training=False)

In [None]:


# TensorBoardLogger and Saver are used below but were not imported anywhere in
# this notebook, so this cell raised NameError.
# NOTE(review): assumed to be defined in myrtlespeech.run.run next to
# ReportRNNTDecoder; confirm the module path.
from myrtlespeech.run.run import Saver, TensorBoardLogger

rnnt_decoder_cb = ReportRNNTDecoder(
    seq_to_seq.post_process,
    seq_to_seq.alphabet,
    eval_every=1,
    skip_first_epoch=True,
)

keys_to_log = [
    "epoch",
    f"reports/{seq_to_seq.post_process.__class__.__name__}/wer",
    "reports/ReportMeanBatchLoss",
]

# Keep this order: callbacks are presumably invoked in list order.
callbacks = [
    RNNTTraining(),
    ReportMeanBatchLoss(),
    TensorBoardLogger(log_dir, seq_to_seq.model, histograms=True),
    MixedPrecision(seq_to_seq),
    ClipGradNorm(seq_to_seq, 200),  # presumably a max grad norm of 200
    rnnt_decoder_cb,
    # stop prematurely (useful for debug). Ensure following line is commented out to perform full training
    # StopEpochAfter(epoch_batches=2),
    # logging
    CSVLogger(log_dir + "log.csv", keys=keys_to_log),
    PrintCB(),
    Saver(log_dir, seq_to_seq.model),
]


### Optionally change the learning rate

In [None]:
change_lr = True
new_lr = 0.0005

# Show the current learning rate(s), optionally overwrite them, then show the result.
param_groups = seq_to_seq.optim.param_groups

for group in param_groups:
    print("current lr: ", group['lr'])

if change_lr:
    for group in param_groups:
        group['lr'] = new_lr

for group in param_groups:
    print("new lr: ", group['lr'])

In [None]:
# Train. NOTE: the hard-coded 40 is used instead of the `epochs` value that
# build(task_config) returned.
fit(
    seq_to_seq,
    callbacks=callbacks,
    eval_loader=eval_loader,
    train_loader=train_loader,
    epochs=40,
)



In [None]:
import torch
import gc

# Debug aid: list every gc-tracked object that is a tensor, or that wraps one
# in `.data`, together with the tensor's size.
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj):
            tensor = obj
        elif hasattr(obj, "data") and torch.is_tensor(obj.data):
            # Print the wrapped tensor's size; the wrapper may have no .size().
            tensor = obj.data
        else:
            continue
        print(type(obj), tensor.size())
    except Exception:
        # Best effort: some tracked objects may raise on attribute access
        # (e.g. dead weakref proxies raise ReferenceError). Skip them, but do
        # not swallow KeyboardInterrupt as the bare `except:` did.
        continue

### Optionally run evaluation

In [None]:
run_eval = True

eval_cbs = None
if run_eval:
    # One evaluation pass over eval_loader, without training.
    eval_cbs = [
        RNNTTraining(),
        ReportMeanBatchLoss(),
        ReportRNNTDecoder(seq_to_seq.post_process, seq_to_seq.alphabet),
        CSVLogger(log_dir + "log_eval.csv", keys=keys_to_log),
        PrintCB(),
    ]
    cb_handler = CallbackHandler(eval_cbs, False)
    # NOTE(review): on_train_begin is invoked by hand, presumably to initialise
    # callback state before run_stage; epochs=2 looks arbitrary - confirm.
    cb_handler.on_train_begin(epochs=2)
    run_stage(seq_to_seq, cb_handler, eval_loader, is_training=False)