In [1]:
import torch


In [27]:
## Load model
# Build DeepSpeech1 from a myrtlespeech task config and inspect its state_dict.
from myrtlespeech.protos import task_config_pb2
from google.protobuf import text_format
# NOTE(review): `build`, `DeepSpeech1` and `copy` imported in this cell are not used
# anywhere in this notebook as shown; `Path` is imported again in the ONNX helper cell.
from myrtlespeech.builders.task_config import build
from myrtlespeech.builders.speech_to_text import build as build_stt
import os
# Hide every GPU so all work below runs on CPU.
# NOTE(review): torch is already imported in the first cell, so this line does NOT run
# "before importing torch". It presumably still takes effect only as long as CUDA has
# not been initialised in this kernel yet -- confirm, or move it into the very first cell.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import torch
from myrtlespeech.model.deep_speech_1 import DeepSpeech1
from pathlib import Path
import copy

# Parse the example DeepSpeech1 task config and build the speech-to-text
# pipeline from it; `ds1` is the underlying torch model.
config_path = "../src/myrtlespeech/configs/deep_speech_1_2048_en.config"
with open(config_path) as config_file:
    config_text = config_file.read()
task_config = text_format.Merge(config_text, task_config_pb2.TaskConfig())

stt = build_stt(task_config.speech_to_text)
ds1 = stt.model

# List each state_dict entry with its shape; these are the target key names
# for the checkpoint remapping further down.
for param_name, param in ds1.state_dict().items():
    print(param_name, param.shape)

fc1.0.weight torch.Size([2048, 494])
fc1.0.bias torch.Size([2048])
fc2.0.weight torch.Size([2048, 2048])
fc2.0.bias torch.Size([2048])
fc3.0.weight torch.Size([4096, 2048])
fc3.0.bias torch.Size([4096])
bi_lstm.rnn.weight_ih_l0 torch.Size([8192, 4096])
bi_lstm.rnn.weight_hh_l0 torch.Size([8192, 2048])
bi_lstm.rnn.bias_ih_l0 torch.Size([8192])
bi_lstm.rnn.bias_hh_l0 torch.Size([8192])
bi_lstm.rnn.weight_ih_l0_reverse torch.Size([8192, 4096])
bi_lstm.rnn.weight_hh_l0_reverse torch.Size([8192, 2048])
bi_lstm.rnn.bias_ih_l0_reverse torch.Size([8192])
bi_lstm.rnn.bias_hh_l0_reverse torch.Size([8192])
fc4.0.weight torch.Size([2048, 4096])
fc4.0.bias torch.Size([2048])
out.weight torch.Size([29, 2048])
out.bias torch.Size([29])


In [28]:
# Checkpoint to convert (presumably a 96%-sparse DeepSpeech1, going by the file name).
state_fp = '/home/julian/models/ds1/96_sparsity.pt'
# Load from `state_fp` (the previous `fp` was undefined on a fresh kernel, and
# later in this notebook names the ONNX file). map_location='cpu' because GPUs are
# hidden above: a checkpoint saved from CUDA tensors could not be restored otherwise.
# torch.load unpickles the file, so only use it on trusted checkpoints.
state_dict = torch.load(state_fp, map_location='cpu')
for key, p in state_dict.items():
    print(key, p.shape)

bi_lstm.weight_ih_l0_reverse torch.Size([8192, 4096])
fc3.module.0.weight torch.Size([4096, 2048])
bi_lstm.bias_ih_l0 torch.Size([8192])
fc4.module.0.weight torch.Size([2048, 4096])
fc2.module.0.weight torch.Size([2048, 2048])
fc1.module.0.weight torch.Size([2048, 494])
bi_lstm.bias_hh_l0 torch.Size([8192])
bi_lstm.bias_ih_l0_reverse torch.Size([8192])
fc4.module.0.bias torch.Size([2048])
fc1.module.0.bias torch.Size([2048])
bi_lstm.weight_hh_l0 torch.Size([8192, 2048])
bi_lstm.weight_ih_l0 torch.Size([8192, 4096])
out.module.0.weight torch.Size([29, 2048])
bi_lstm.weight_hh_l0_reverse torch.Size([8192, 2048])
out.module.0.bias torch.Size([29])
fc3.module.0.bias torch.Size([4096])
fc2.module.0.bias torch.Size([2048])
bi_lstm.bias_hh_l0_reverse torch.Size([8192])


In [37]:
# Rename checkpoint keys to match DeepSpeech1's state_dict layout:
#   fc layers: fc1.module.0.weight  -> fc1.0.weight
#   LSTM:      bi_lstm.weight_ih_l0 -> bi_lstm.rnn.weight_ih_l0
#   output:    out.module.0.weight  -> out.weight
# Rules are tried in order and the first whose marker occurs anywhere in the key
# wins; keys matching no rule are kept unchanged.
_KEY_RENAME_RULES = (
    ('fc', lambda name: name.replace('.module', '')),
    ('lstm', lambda name: name.replace('lstm.', 'lstm.rnn.')),
    ('out', lambda name: name.replace('module.0.', '')),
)


def _rename_checkpoint_key(name):
    """Return the DeepSpeech1 state_dict key for checkpoint key ``name``."""
    for marker, rename in _KEY_RENAME_RULES:
        if marker in name:
            return rename(name)
    return name


state_dict_ = {_rename_checkpoint_key(name): tensor for name, tensor in state_dict.items()}

ds1.load_state_dict(state_dict_, strict=True)

<All keys matched successfully>

In [None]:
# Persist the remapped weights in DeepSpeech1's own key layout.
myrtle_state_fp = '/home/julian/models/ds1/96_sparsity_myrtle.pt'
ds1.eval()
torch.save(ds1.state_dict(), myrtle_state_fp)

## ONNX helper functions

In [42]:
from pathlib import Path
import onnx
import onnxruntime as ort
# Directory into which export_and_check writes the exported .onnx files.
log_dir = '/home/julian/exp/onnx/tmp/'

def export_and_check(model, args, fname, input_names, output_names, example_outputs=None,
                     dynamic_axes=None, verbose=False, opset_version=11):
    """Export ``model`` to ONNX under the global ``log_dir`` and check ONNX Runtime agrees.

    Args:
        model: torch module (or traced module) to export; it is put into eval mode.
        args: tuple of positional example inputs for ``model``.
        fname: ONNX file name, created inside ``log_dir``.
        input_names: ONNX input names, in the same order as ``args``.
        output_names: ONNX output names, in the order the model returns them.
        example_outputs: forwarded to ``torch.onnx.export``.
        dynamic_axes: forwarded to ``torch.onnx.export``.
        verbose: if True, print a human-readable form of the exported graph.
        opset_version: ONNX opset used for the export.

    Returns:
        The ``onnxruntime.InferenceSession`` for the exported file.

    Raises:
        AssertionError: (via ``compare_outputs``) if ONNX Runtime and PyTorch disagree.
    """
    fp = str(Path(log_dir) / fname)
    model.eval()

    # The PyTorch reference outputs are computed inside compare_outputs, so no
    # forward pass is needed here.
    torch.onnx.export(model, args, fp, export_params=True, verbose=False,
                      example_outputs=example_outputs, dynamic_axes=dynamic_axes,
                      input_names=input_names, output_names=output_names,
                      opset_version=opset_version)

    if verbose:
        # Only needed for printing; ONNX Runtime loads the file itself below.
        model_onnx = onnx.load(fp)
        print("Printing graph...")
        print(onnx.helper.printable_graph(model_onnx.graph))
        # TODO(review): onnx.checker.check_model(model_onnx) was disabled in this
        # notebook; re-enable once it is known why it was turned off.

    ort_session = ort.InferenceSession(fp)
    compare_outputs(ort_session, model, args, output_names, input_names)
    return ort_session

def compare_outputs(ort_session, model, args, output_names, input_names):
    """Run ``model`` (PyTorch) and ``ort_session`` (ONNX Runtime) on ``args`` and check they agree.

    Args:
        ort_session: ``onnxruntime.InferenceSession`` of the exported model.
        model: the PyTorch model the session was exported from.
        args: positional inputs; tensors are converted to numpy for ONNX Runtime.
        output_names: ONNX output names to fetch, in the model's output order.
        input_names: ONNX input names, one per element of ``args`` (same order).

    Raises:
        ValueError: if ``args`` and ``input_names`` have different lengths.
        AssertionError: if the outputs do not match.
    """
    if len(args) != len(input_names):
        raise ValueError(f'got {len(args)} args for {len(input_names)} input names')

    # PyTorch reference outputs; no autograd graph is needed for a comparison.
    with torch.no_grad():
        exp_outputs = model(*args)

    # ONNX Runtime consumes numpy arrays, fed by input name in positional order.
    np_args = [x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x for x in args]
    feed = dict(zip(input_names, np_args))
    outputs = ort_session.run(output_names, feed)

    check_outputs_as_expected(outputs, exp_outputs)

    print('model correct!')
    

    
def check_outputs_as_expected(outputs, exp_outputs, atol=1e-4, rtol=1e-2):
    """Recursively assert that ONNX Runtime ``outputs`` match PyTorch ``exp_outputs``.

    Args:
        outputs: an array-like (e.g. a numpy array from ``InferenceSession.run``)
            or a tuple/list of them, nested the same way as ``exp_outputs``.
        exp_outputs: a ``torch.Tensor`` or a tuple/list of (nested) tensors.
        atol: absolute tolerance passed to ``torch.allclose``.
        rtol: relative tolerance passed to ``torch.allclose``.

    Raises:
        AssertionError: on a length, shape or value mismatch.
        ValueError: if the two structures have incompatible types.
    """
    if isinstance(exp_outputs, torch.Tensor):
        actual = torch.as_tensor(outputs)
        expected = exp_outputs.detach().cpu()
        # torch.allclose broadcasts, so a shape mismatch could otherwise pass silently.
        assert actual.shape == expected.shape, (
            f"shape {tuple(actual.shape)} != {tuple(expected.shape)}")
        assert torch.allclose(actual, expected, atol=atol, rtol=rtol), (
            "values differ, max abs diff "
            f"{(actual.to(torch.float64) - expected.to(torch.float64)).abs().max().item()}")
    elif isinstance(exp_outputs, (tuple, list)) and isinstance(outputs, (tuple, list)):
        assert len(exp_outputs) == len(outputs), f"{len(exp_outputs)} != {len(outputs)}"
        for out, exp in zip(outputs, exp_outputs):
            check_outputs_as_expected(out, exp, atol=atol, rtol=rtol)
    else:
        raise ValueError(f'Unexpected output type(outputs)={type(outputs)} '
                         f'with type(exp_outputs)={type(exp_outputs)} ')
        



### Trace, export to ONNX and check

In [43]:
def create_ds1_data(seq_len, batch, in_features=26 * 19, channels=1):
    """Create random DeepSpeech1 inputs of the requested size.

    Args:
        seq_len: number of time steps in the padded input (must be >= 1).
        batch: number of sequences.
        in_features: features per time step; the default 26 * 19 = 494 matches the
            input size of the model's first layer (``fc1.0.weight`` is [2048, 494]).
        channels: number of input channels.

    Returns:
        Tuple ``(inp, in_lens)``:
          * ``inp``: float tensor of shape (batch, channels, in_features, seq_len).
          * ``in_lens``: int64 tensor of shape (batch,), values in [1, seq_len],
            sorted in descending order.

    Per the model notes, its outputs are (seq_len, batch, out_feat) and (batch,).
    """
    # Use the requested sizes; they must not be overwritten, or every caller gets
    # the same shape and the variable-size checks test nothing.
    inp = torch.randn(batch, channels, in_features, seq_len)
    # randint's `high` is exclusive: seq_len + 1 lets full-length sequences occur
    # and keeps seq_len == 1 valid.
    in_lens = torch.randint(low=1, high=seq_len + 1, size=(batch,), dtype=torch.int64)
    in_lens = in_lens.sort(descending=True)[0]

    return inp, in_lens

In [45]:
# Example-input sizes used for tracing; the checks below enlarge them to verify
# the trace is not tied to these particular sizes.
seq_len = 2
batch = 6

args = create_ds1_data(seq_len=seq_len, batch=batch)

# ONNX export settings. Inputs are (batch, channels, in_features, seq_len) and
# (batch,); outputs are (seq_len, batch, out_feat) and (batch,).
# NOTE(review): the file name says 1024 but the model is built from the *_2048_* config.
onnx_fname = 'ds1_1024_traced.onnx'
input_names = ['input', 'in_lens']
output_names = ['output', 'out_lens']
dynamic_axes = {
    'input': {0: 'batch', 3: 'seq_len'},
    'in_lens': {0: 'batch'},
    'output': {0: 'seq_len', 1: 'batch'},
    'out_lens': {0: 'batch'},
}
opset_version = 11
log_dir = '/home/julian/exp/onnx/tmp/'
fp = Path(log_dir) / onnx_fname
    
class CollapseArgs(torch.nn.Module):
    """Adapter giving a tuple-input model a two-positional-argument ``forward``.

    The wrapped ``submodel`` is called with a single ``(x, lens)`` tuple, while
    this wrapper accepts ``x`` and ``lens`` separately, matching how ``args`` is
    unpacked as ``model(*args)`` for tracing and export in this notebook.
    """

    def __init__(self, submodel):
        super(CollapseArgs, self).__init__()
        # Stored as an attribute of an nn.Module, so it is registered as a child module.
        self.submodel = submodel

    def forward(self, x: torch.Tensor, lens: torch.Tensor):
        packed = (x, lens)
        return self.submodel(packed)

# Put DeepSpeech1 in inference mode, then wrap it so it takes (input, in_lens) as
# two positional arguments, which is how trace/export pass `args`.
ds1.eval()
model = CollapseArgs(ds1)

# Trace the wrapped model with the example inputs.
traced = torch.jit.trace(model, args)

# Check the trace reproduces the eager module on the inputs it was traced with.
with torch.no_grad():
    example_outputs = model(*args)
    trace_outputs = traced(*args)

assert torch.allclose(example_outputs[0], trace_outputs[0])
assert torch.allclose(example_outputs[1], trace_outputs[1])

print("trace and Module versions equivalent")

# ...and on larger inputs, since a trace can bake in the example input sizes.
seq_len *= 2
batch *= 2
args = create_ds1_data(seq_len=seq_len, batch=batch)

with torch.no_grad():
    example_outputs = model(*args)
    trace_outputs = traced(*args)

assert torch.allclose(example_outputs[0], trace_outputs[0], rtol=1e-2, atol=1e-4)
assert torch.allclose(example_outputs[1], trace_outputs[1], rtol=1e-2, atol=1e-4)

print("trace and Module versions equivalent for variable batch")

# Export the traced model and check ONNX Runtime against it. export_and_check
# writes Path(log_dir) / onnx_fname itself, so no separate torch.onnx.export call
# is needed (it would only write the same file twice).
model = traced
ort_session = export_and_check(model=traced,
                               args=args,
                               fname=onnx_fname,
                               input_names=input_names,
                               output_names=output_names,
                               dynamic_axes=dynamic_axes,
                               opset_version=opset_version,
                               example_outputs=example_outputs)

# Re-run the exported graph at yet another size to exercise the dynamic axes.
batch = 10
seq_len = 40
args = create_ds1_data(seq_len=seq_len, batch=batch)

compare_outputs(ort_session, model, args, output_names=output_names, input_names=input_names)
print("done")


trace and Module versions equivalent
trace and Module versions equivalent for variable batch
model correct!
model correct!

