In [1]:
from pathlib import Path
import onnx
import onnxruntime as ort
log_dir = '/home/julian/exp/onnx/tmp/'

def export_and_check(model, args, fname, input_names, output_names, example_outputs=None, 
                     dynamic_axes=None, verbose=False, opset_version=11):
    fp = Path(log_dir) / fname
    model.eval()
    
    # run model in torch to get expected outputs
    exp_outputs = model(*args)
    
    # export onnx model
    torch.onnx.export(model, args, fp, export_params=True, verbose=False,  example_outputs=example_outputs, 
                      dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names, 
                      opset_version=opset_version)
    
    
    # Load the ONNX model
    model_onnx = onnx.load(fp)
    
    # Print a human readable representation of the graph
    if verbose:
        print("Printing graph...")
        print(onnx.helper.printable_graph(model_onnx.graph))
    
    
    # Check that the IR is well formed
    #onnx.checker.check_model(model_onnx)
    
    
    # onnx runtime
    ort_session = ort.InferenceSession(str(fp))
    
    compare_outputs(ort_session, model, args, output_names, input_names)
    
    return ort_session

def compare_outputs(ort_session, model, args, output_names, input_names):
    # convert input args to numpy
    exp_outputs = model(*args)
    args = [x.numpy() if isinstance(x, torch.Tensor) else x for x in args]
    
    outputs = ort_session.run(output_names, {k: args[idx] for idx, k in enumerate(input_names)})
    
    check_outputs_as_expected(outputs, exp_outputs)
    
    print('model correct!')
    

    
def check_outputs_as_expected(outputs, exp_outputs):
    if isinstance(exp_outputs, torch.Tensor):
        assert torch.allclose(torch.tensor(outputs), exp_outputs.cpu(), atol=1e-4, rtol=1e-2)
    elif isinstance(exp_outputs, tuple) and isinstance(outputs, (tuple, list)):
        assert len(exp_outputs) == len(outputs), f"{len(exp_outputs)} != {len(outputs)}"
        for idx, x in enumerate(outputs):
            check_outputs_as_expected(x, exp_outputs[idx])
    else:
        raise ValueError(f'Unexpected output type(outputs)={type(outputs)} '
                         f'with type(exp_outputs)={type(exp_outputs)} ')
        



In [2]:
from myrtlespeech.protos import task_config_pb2
from google.protobuf import text_format
from myrtlespeech.builders.task_config import build
from myrtlespeech.builders.speech_to_text import build as build_stt
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "" # set this before importing torch
import torch
from myrtlespeech.model.deep_speech_1 import DeepSpeech1
from pathlib import Path

# parse example config file
with open("../src/myrtlespeech/configs/deep_speech_1_en.config") as f:
    task_config = text_format.Merge(f.read(), task_config_pb2.TaskConfig())

stt = build_stt(task_config.speech_to_text)
ds1 = stt.model

# inputs: Tuple: [(batch, channels, in_features, seq_len), (batch,)]
# outputs: Tuple: [(seq_len, batch, out_feat), (batch,)]
seq_len = 2
batch = 5
in_features = 26 * 19
channels = 1



# inputs
inp = torch.randn(batch, channels, in_features, seq_len)
in_lens = torch.randint(low=1, high=seq_len, size=(batch,)).type(torch.int64)
in_lens = in_lens.sort(descending=True)[0]

args = (inp, in_lens)


# init onnx params
onnx_fname = f'ds1_1024_traced.onnx'
input_names = ['input', 'in_lens']
output_names = ['output', 'out_lens']
dynamic_axes = {'input': {0: 'batch', 3: 'seq_len'}, 
                     'in_lens': {0: 'batch'},
                     'output': {0: 'seq_len', 1: 'batch'},  
                     'out_lens': {0: 'batch'}}
opset_version = 11
log_dir = '/home/julian/exp/onnx/tmp/'
fp = Path(log_dir) / onnx_fname
    
class CollapseArgs(torch.nn.Module):
    def __init__(self, submodel):
        super().__init__()
        self.submodel = submodel
        
    def forward(self, x: torch.Tensor, lens: torch.Tensor):
        return self.submodel((x, lens))

# get model
model = ds1
model.eval()


# Collapse input args
model = CollapseArgs(ds1)

# trace model
traced = torch.jit.trace(model, args)

# script model
#scripted = torch.jit.script(model)


# expected outputs
example_outputs = model(*args)
trace_outputs = traced(*args)
#script_outputs = scripted(*args)


assert torch.allclose(example_outputs[0], trace_outputs[0])
assert torch.allclose(example_outputs[1], trace_outputs[1])

print("trace and Module versions equivalent")

# assert torch.allclose(example_outputs[0], script_outputs[0])
# assert torch.allclose(example_outputs[1], script_outputs[1])

# print("script and Module versions equivalent")

# check they are equivalent with new values
seq_len *= 2
batch *= 2
inp = torch.randn(batch, channels, in_features, seq_len)
in_lens = torch.randint(low=1, high=seq_len, size=(batch,)).type(torch.int64)
in_lens = in_lens.sort(descending=True)[0]

args = (inp, in_lens)

# expected outputs
example_outputs = model(*args)
trace_outputs = traced(*args)
#script_outputs = scripted(*args)


assert torch.allclose(example_outputs[0], trace_outputs[0], rtol=1e-2, atol=1e-4)
assert torch.allclose(example_outputs[1], trace_outputs[1], rtol=1e-2, atol=1e-4)

print("trace and Module versions equivalent for variable batch")

# assert torch.allclose(example_outputs[0], script_outputs[0], rtol=1e-2, atol=1e-4)
# assert torch.allclose(example_outputs[1], script_outputs[1], rtol=1e-2, atol=1e-4)

# print("script and Module versions equivalent for variable batch")


model = traced

# export onnx model
torch.onnx.export(model, args, fp, export_params=True, verbose=False,  example_outputs=example_outputs, 
                  dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names, 
                  opset_version=opset_version)

    
ort_session = export_and_check(model=traced, 
                 args = args,
                 fname = onnx_fname,
                 input_names = input_names,
                 output_names = output_names,
                 dynamic_axes = dynamic_axes,
                 opset_version=11,
                 example_outputs=example_outputs,
    )


batch = 10
seq_len = 40

inp = torch.randn(batch, channels, in_features, seq_len)
in_lens = torch.randint(low=1, high=seq_len, size=(batch,))
in_lens = in_lens.sort(descending=True)[0]

compare_outputs(ort_session, model, args, output_names=output_names, input_names=input_names)



trace and Module versions equivalent
trace and Module versions equivalent for variable batch


  "or define the initial states (h0/c0) as inputs of the model. ")


model correct!
model correct!


### Exporting to ONNX notes

* No `Union` allowed in types
* MAKE SURE cuda is disabled with `os.environ["CUDA_VISIBLE_DEVICES"] = ""` before exporting. 
* Make sure you place in .eval() mode before exporting. 

## Everything below is an offcut

In [8]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "" # set this before importing torch
import torch
import onnx


from myrtlespeech.protos import task_config_pb2
from google.protobuf import text_format
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.run import Saver
from myrtlespeech.model.fully_connected import FullyConnected
from myrtlespeech.model.rnn import RNN

from myrtlespeech.run.train import fit
from pathlib import Path

import torch
import onnx

import onnxruntime as ort
import numpy as np

In [3]:
log_dir = '/home/julian/exp/onnx/tmp/'

## Create test to check identity

In [24]:

def export_and_check(model, args, fname, input_names, output_names, example_outputs=None, 
                     dynamic_axes=None, verbose=False, opset_version=11):
    fp = Path(log_dir) / fname
    model.eval()
    
    # run model in torch to get expected outputs
    exp_outputs = model(*args)
    
    # export onnx model
    torch.onnx.export(model, args, fp, export_params=True, verbose=False,  example_outputs=example_outputs, 
                      dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names, 
                      opset_version=opset_version)
    
    
    # Load the ONNX model
    model_onnx = onnx.load(fp)
    
    # Print a human readable representation of the graph
    if verbose:
        print("Printing graph...")
        print(onnx.helper.printable_graph(model_onnx.graph))
    
    
    # Check that the IR is well formed
    #onnx.checker.check_model(model_onnx)
    
    
    # onnx runtime
    ort_session = ort.InferenceSession(str(fp))
    
    compare_outputs(ort_session, model, args, output_names, input_names)
    
    return ort_session

def compare_outputs(ort_session, model, args, output_names, input_names):
    # convert input args to numpy
    exp_outputs = model(*args)
    args = [x.numpy() if isinstance(x, torch.Tensor) else x for x in args]
    
    outputs = ort_session.run(output_names, {k: args[idx] for idx, k in enumerate(input_names)})
    
    check_outputs_as_expected(outputs, exp_outputs)
    
    print('model correct!')
    

    
def check_outputs_as_expected(outputs, exp_outputs):
    if isinstance(exp_outputs, torch.Tensor):
        assert torch.allclose(torch.tensor(outputs), exp_outputs.cpu())
    elif isinstance(exp_outputs, tuple) and isinstance(outputs, (tuple, list)):
        assert len(exp_outputs) == len(outputs), f"{len(exp_outputs)} != {len(outputs)}"
        for idx, x in enumerate(outputs):
            check_outputs_as_expected(x, exp_outputs[idx])
    else:
        raise ValueError(f'Unexpected output type(outputs)={type(outputs)} '
                         f'with type(exp_outputs)={type(exp_outputs)} ')
        

In [27]:
# create wrapper to unwrap tuple args
class CollapseTupleArgs(torch.nn.Module):
    def __init__(self, submodel):
        super().__init__()
        self.submodel = submodel
    def forward(self, *args):
        return self.submodel(args)

class FlattenTupleArgs(torch.nn.Module):
    """Flatten Tuple Args before returning."""
    def __init__(self, submodel):
        super().__init__()
        self.submodel = submodel
    def forward(self, *args):
        res = self.submodel(*args)
        ret = []
        if isinstance(res, tuple):
            for x in res:
                if isinstance(x, tuple):
                    for y in x:
                        ret.append(y)
                else:
                    ret.append(x)
        return tuple(ret)

## Linear

In [21]:
model=torch.nn.Linear(2, 3)
batch = 5
args = (torch.randn(batch, 2),)
exp_out = model(*args)
model = torch.jit.trace(model, args)
#model = torch.jit.script(model)
ort_session = export_and_check(model=model, 
                 args = args,
                 fname = 'linear.onnx',
                 input_names = ['input'],
                 output_names = ['output'],
                 dynamic_axes = {'input': {0: 'batch'}, 'output': {0: 'batch'}},
                 opset_version=11,
                 example_outputs = exp_out,
    )
batch = 900
compare_outputs(ort_session, model, (torch.randn(batch, 2),), ['output'], ['input'])

model correct!
model correct!


## RNN

In [12]:
class _Wrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, input, hx):
        return self.module(input, hx)


### RNN with trace


In [22]:
import torch 

from pathlib import Path


log_dir = Path('/home/julian/exp/onnx/tmp/')

class _Wrapper(torch.nn.Module):
    """Wrapper for to fix pytorch 1.4 bug."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, input, hx):
        return self.module(input, hx)

# args
seq_len = 1
batch = 1
input_size = 100
num_layers = 8
hidden_size = 300

# args: (inp, h_n)
# inp: (seq_len, batch, input_size)
# h_n: (num_layers * num_directions, batch, hidden_size)

args = torch.randn(seq_len, batch, input_size), torch.randn((num_layers, batch, hidden_size))

#onnx args
input_names = ['input', 'hx']
output_names = ['output', 'hy']
dynamic_axes = {'input':  {0: 'seq_len', 1: 'batch'}, 'hx': {1: 'batch'},
                'output': {0: 'seq_len', 1: 'batch'}, 'hy': {1: 'batch'},}

model = torch.nn.GRU(input_size, hidden_size, num_layers)
model.eval()



rnn_trace= torch.jit.trace(_Wrapper(model), args)

example_outputs = model(*args) #check it runs
trace_output = rnn_trace(*args)

assert torch.allclose(example_outputs[0], trace_output[0])
assert torch.allclose(example_outputs[1], trace_output[1])

print("trace_output and nn.Module give same values")

torch.onnx.export(rnn_trace, args, log_dir / 'gru.onnx', export_params=True, verbose=False,  
                  example_outputs=example_outputs, 
                  dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names, 
                  opset_version=11)


## check for new args

# args
seq_len = 2000
batch = 20


# args: (inp, h_n)
# inp: (seq_len, batch, input_size)
# h_n: (num_layers * num_directions, batch, hidden_size)

args = torch.randn(seq_len, batch, input_size), torch.randn((num_layers, batch, hidden_size))

example_outputs = model(*args) #check it runs
trace_output = rnn_trace(*args)

assert torch.allclose(example_outputs[0], trace_output[0])
assert torch.allclose(example_outputs[1], trace_output[1])

print("trace_output and nn.Module give same values for new dimensions")

trace_output and nn.Module give same values
trace_output and nn.Module give same values for new dimensions


### RNN with script

In [25]:
import torch 

from pathlib import Path


log_dir = Path('/home/julian/exp/onnx/tmp/')

class _Wrapper(torch.nn.Module):
    """Wrapper for to fix pytorch 1.4 bug."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, input, hx):
        return self.module(input, hx)

# args
seq_len = 2
batch = 1
input_size = 2
num_layers = 1
hidden_size = 3

# args: (inp, h_n)
# inp: (seq_len, batch, input_size)
# h_n: (num_layers * num_directions, batch, hidden_size)

args = torch.randn(seq_len, batch, input_size), torch.randn((num_layers, batch, hidden_size))

#onnx args
input_names = ['input', 'hx']
output_names = ['output', 'hy']
dynamic_axes = {'input':  {0: 'seq_len', 1: 'batch'}, 'hx': {1: 'batch'},
                'output': {0: 'seq_len', 1: 'batch'}, 'hy': {1: 'batch'},}

model = torch.nn.GRU(input_size, hidden_size, num_layers)
model.eval()



rnn_script= torch.jit.trace(_Wrapper(model), args)

example_outputs = model(*args) #check it runs
script_output = rnn_script(*args)

assert torch.allclose(example_outputs[0], script_output[0])
assert torch.allclose(example_outputs[1], script_output[1])

print("script and nn.Module give same values")

torch.onnx.export(rnn_script, args, log_dir / 'gru.onnx', export_params=True, verbose=True,  
                  example_outputs=example_outputs, 
                  dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names, 
                  opset_version=11)



script and nn.Module give same values
graph(%input : Float(2, 1, 2),
      %hx : Float(1, 1, 3),
      %2 : Float(9, 2),
      %3 : Float(9, 3),
      %4 : Float(9),
      %5 : Float(9)):
  %6 : Tensor? = prim::Constant(), scope: __module.module
  %7 : Tensor = onnx::Constant[value={0}](), scope: __module.module
  %8 : Tensor = onnx::Constant[value={3}](), scope: __module.module
  %9 : Tensor = onnx::Constant[value={6}](), scope: __module.module
  %10 : Tensor = onnx::Slice(%2, %8, %9, %7), scope: __module.module
  %11 : Tensor = onnx::Constant[value={0}](), scope: __module.module
  %12 : Tensor = onnx::Constant[value={0}](), scope: __module.module
  %13 : Tensor = onnx::Constant[value={3}](), scope: __module.module
  %14 : Tensor = onnx::Slice(%2, %12, %13, %11), scope: __module.module
  %15 : Tensor = onnx::Constant[value={0}](), scope: __module.module
  %16 : Tensor = onnx::Constant[value={6}](), scope: __module.module
  %17 : Tensor = onnx::Constant[value={9}](), scope: __module.mo

In [20]:
model = torch.nn.GRU(input_size, hidden_size, num_layers)

example_outputs = model(*args)

rnn_script= torch.jit.script(_Wrapper(model))
rnn_script_= torch.jit.script(_Wrapper(torch.jit.script(model)))

print(rnn_script)
print(rnn_script_)
assert torch.allclose(rnn_script(*args)[0], rnn_script_(*args)[0])
assert torch.allclose(rnn_script(*args)[1], rnn_script_(*args)[1])

RecursiveScriptModule(
  original_name=_Wrapper
  (module): RecursiveScriptModule(original_name=GRU)
)
RecursiveScriptModule(
  original_name=_Wrapper
  (module): RecursiveScriptModule(original_name=GRU)
)


In [9]:
input_size = 3
num_layers = 2
seq_len = 3
batch = 3
args = (torch.randn(seq_len, batch, input_size),)


scripted_lstm = torch.jit.script(_Wrapper(torch.nn.LSTM(input_size, 3, num_layers)))

scripted_lstm(*args)
#model = torch.jit.trace(FlattenTupleArgs(scripted_lstm), args)


(tensor([[[-0.0915, -0.0938,  0.0228],
          [-0.1540, -0.1174,  0.0207],
          [-0.0293, -0.0392,  0.0219]],
 
         [[-0.1900, -0.1471,  0.0298],
          [-0.1172, -0.1263,  0.0386],
          [-0.1067, -0.1139,  0.0077]],
 
         [[-0.2705, -0.1657,  0.0360],
          [-0.1083, -0.1431,  0.0272],
          [-0.1427, -0.1523,  0.0046]]], grad_fn=<StackBackward>),
 (tensor([[[-0.1653,  0.4739,  0.2383],
           [ 0.2439,  0.1348,  0.1578],
           [ 0.1438,  0.3261,  0.1405]],
  
          [[-0.2705, -0.1657,  0.0360],
           [-0.1083, -0.1431,  0.0272],
           [-0.1427, -0.1523,  0.0046]]], grad_fn=<StackBackward>),
  tensor([[[-0.2369,  0.6427,  0.5650],
           [ 0.4649,  0.1760,  0.3083],
           [ 0.3006,  0.4258,  0.2456]],
  
          [[-0.3988, -0.4197,  0.1099],
           [-0.1745, -0.2874,  0.0790],
           [-0.2203, -0.3272,  0.0133]]], grad_fn=<StackBackward>)))

## LSTM

## myrtlespeech submodules

In [1]:
# test myrtlespeech fully_connected
# args: ([batch, seq_len, in_features], (batch,))

seq_len = 2
batch = 1
in_features = 5
model = CollapseTupleArgs(FullyConnected(in_features, out_features=3, 
                                    num_hidden_layers=2,  hidden_size = 2, hidden_activation_fn=torch.nn.ReLU()),)

inp = torch.randn(batch, seq_len, in_features)
in_lens = torch.randint(low=1, high=seq_len, size=(batch,))
args = (inp, in_lens)

model  = torch.jit.trace(model, args)

output_names= ['output', 'out_lens']
input_names=['input', 'in_lens']
ort_session = export_and_check(model=model, 
                 args = args,
                 fname = 'fc_myrtlespeech.onnx',
                 input_names = input_names,
                 output_names = output_names,
                 dynamic_axes = {'input': {0: 'batch', 1: 'seq_len'}, 
                                 'in_lens': {0: 'batch'},
                                 'output': {0: 'batch', 1: 'seq_len'}, 
                                 'out_lens': {0: 'batch'}},                        
    )

batch = 9
seq_len = 4

inp = torch.randn(batch, seq_len, in_features)
in_lens = torch.randint(low=1, high=seq_len, size=(batch,))
args = (inp, in_lens)
compare_outputs(ort_session, model, args, output_names=output_names, input_names=input_names)

NameError: name 'CollapseTupleArgs' is not defined

In [33]:

from typing import Optional
import torch 
from myrtlespeech.model.rnn import RNN
import onnx
import onnxruntime as ort
class FlattenRNNArgs(torch.nn.Module):
    """Flatten Tuple Args before returning and 'unflattens' at input."""
    def __init__(self, submodel):
        super().__init__()
        self.submodel = submodel
    

    def forward(self, *args):
        # (inp, in_lens, hx) -> ((inp, in_lens), hx)
        args_in = (args[0], args[1]), args[2]
        res = self.submodel(*args_in)
        ret = []
        if isinstance(res, tuple):
            for x in res:
                if isinstance(x, tuple):
                    for y in x:
                        ret.append(y)
                else:
                    ret.append(x)
        return tuple(ret)
    
# test myrtlespeech RNN
# args: ([seq_len, batch, in_features], (batch,))


seq_len = 2
batch = 2
input_size = 2
num_layers = 1
hidden_size = 3
rnn_type = 1 # GRU
# h_n = num_layers * num_directions, batch, hidden_size

# inputs
inp = torch.randn(seq_len, batch, input_size)
in_lens = torch.randint(low=1, high=seq_len, size=(batch,)).type(torch.int32)
in_lens = in_lens.sort(descending=True)[0]

hx: Optional[torch.Tensor] = torch.randn(num_layers, batch, hidden_size)
args = (inp, in_lens, hx)

# init onnx params
onnx_fname = 'RNN_myrtlespeech.onnx'
input_names = ['input', 'in_lens', 'hx']
output_names = ['output', 'out_lens', 'hx']
dynamic_axes = {'input':  {0: 'seq_len', 1: 'batch'}, 'in_lens': {0: 'batch'}, 'hx': {1: 'batch'},
                'output': {0: 'seq_len', 1: 'batch'}, 'out_lens': {0: 'batch'}, 'hx': {1: 'batch'},}

# init model
model = RNN(rnn_type=1, input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)

# script model
scripted = torch.jit.script(model)

# flatten tuple args
model = FlattenRNNArgs(model)
#scripted = torch.jit.trace(FlattenTupleArgs(scripted, unflatten_fn=unflatten_fn), args)
scripted = FlattenRNNArgs(scripted)


# expected outputs
example_outputs = model(*args)
script_outputs = scripted(*args)


assert torch.allclose(example_outputs[0], script_outputs[0])
assert torch.allclose(example_outputs[1], script_outputs[1])
assert torch.allclose(example_outputs[2], script_outputs[2])

print("script and Module versions equivalent")

traced = torch.jit.trace(model, args)

print("sucessfully traced module")


ort_session = export_and_check(model=traced, 
                 args = args,
                 fname = onnx_fname,
                 input_names = input_names,
                 output_names = output_names,
                 dynamic_axes = dynamic_axes,
                 opset_version=11,
                 example_outputs=example_outputs,
    )

batch = 9
seq_len = 4

inp = torch.randn(batch, seq_len, in_features)
in_lens = torch.randint(low=1, high=seq_len, size=(batch,))
args = (inp, in_lens)
compare_outputs(ort_session, model, args, output_names=output_names, input_names=input_names)




script and Module versions equivalent
sucessfully traced module


  "or define the initial states (h0/c0) as inputs of the model. ")


Fail: [ONNXRuntimeError] : 1 : FAIL : Type Error: Type (tensor(int64)) of output arg (out_lens) of node () does not match expected type (tensor(int32)).

### RNN notes
* empty_like - not avavilable in torch 1.2. Hence upgrade to 1.4
* TopK - not avavilable in most recent release. 

## ds1

In [None]:
from myrtlespeech.builders.speech_to_text import build as build_stt

In [None]:
ds1 = CollapseTupleArgs(stt.model)
ds1

In [13]:
ds1

DeepSpeech1(
  (fc1): Sequential(
    (0): Linear(in_features=494, out_features=8, bias=True)
    (1): Hardtanh(min_val=0.0, max_val=20.0, inplace=True)
    (2): Dropout(p=0.3, inplace=False)
  )
  (fc2): Sequential(
    (0): Linear(in_features=8, out_features=8, bias=True)
    (1): Hardtanh(min_val=0.0, max_val=20.0, inplace=True)
    (2): Dropout(p=0.3, inplace=False)
  )
  (fc3): Sequential(
    (0): Linear(in_features=8, out_features=16, bias=True)
    (1): Hardtanh(min_val=0.0, max_val=20.0, inplace=True)
    (2): Dropout(p=0.3, inplace=False)
  )
  (bi_lstm): LSTM(
    (rnn): LSTM(16, 8, batch_first=True, bidirectional=True)
  )
  (fc4): Sequential(
    (0): Linear(in_features=16, out_features=8, bias=True)
    (1): Hardtanh(min_val=0.0, max_val=20.0, inplace=True)
    (2): Dropout(p=0.3, inplace=False)
  )
  (out): Linear(in_features=8, out_features=3, bias=True)
)

In [10]:
scripted.submodel.fc1

RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=Linear)
  (1): RecursiveScriptModule(original_name=Hardtanh)
  (2): RecursiveScriptModule(original_name=Dropout)
)

In [9]:
inp = torch.randn(batch, channels, in_features, seq_len)
in_lens = torch.randint(low=1, high=seq_len, size=(batch,)).type(torch.int64)
in_lens = in_lens.sort(descending=True)[0]
log_dir = Path(log_dir)
args = (inp, in_lens)

input_names = ['input',]
output_names = ['output',]
dynamic_axes = {'input': {0: 'batch', 1: 'seq_len'}, 
                     'output': {0: 'batch', 1: 'seq_len'}}
opset_version = 11

###
h, seq_lens = args
batch, channels, features, seq_len = h.size()
h = h.view(batch, channels * features, seq_len).permute(0, 2, 1)

assert torch.allclose(model.submodel.fc1(h), scripted.submodel.fc1(h))
print("fc1 correct")
torch.onnx.export(scripted.submodel.fc1, args, log_dir / 'ds1_fc1', export_params=True, verbose=True,  example_outputs=model.submodel.fc1(h), 
                  dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names, 
                  opset_version=opset_version)


h = model.submodel.fc1(h)
assert torch.allclose(model.submodel.fc2(h), scripted.submodel.fc2(h))


print("fc2 correct")
h = model.submodel.fc2(h)
assert torch.allclose(model.submodel.fc3(h), scripted.submodel.fc3(h))

print("fc3 correct")

fc1 correct


RuntimeError: isTensor() INTERNAL ASSERT FAILED at /opt/conda/conda-bld/pytorch_1579040055865/work/aten/src/ATen/core/ivalue_inl.h:90, please report a bug to PyTorch. Expected Tensor but got Bool (toTensor at /opt/conda/conda-bld/pytorch_1579040055865/work/aten/src/ATen/core/ivalue_inl.h:90)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x47 (0x7f82941a2627 in /home/julian/miniconda3/envs/myrtlespeech/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: torch::jit::LowerGraph(torch::jit::Graph&, c10::intrusive_ptr<c10::ivalue::Object, c10::detail::intrusive_target_default_null_type<c10::ivalue::Object> > const&) + 0x674 (0x7f82592b6794 in /home/julian/miniconda3/envs/myrtlespeech/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: <unknown function> + 0x6b6479 (0x7f8284b71479 in /home/julian/miniconda3/envs/myrtlespeech/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #3: <unknown function> + 0x28ba06 (0x7f8284746a06 in /home/julian/miniconda3/envs/myrtlespeech/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #4: _PyMethodDef_RawFastCallKeywords + 0x264 (0x5635de6446e4 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #5: _PyCFunction_FastCallKeywords + 0x21 (0x5635de644801 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #6: _PyEval_EvalFrameDefault + 0x4e8c (0x5635de6a02bc in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #7: _PyEval_EvalCodeWithName + 0x2f9 (0x5635de5e14f9 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #8: _PyFunction_FastCallKeywords + 0x387 (0x5635de643a27 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #9: _PyEval_EvalFrameDefault + 0x14ce (0x5635de69c8fe in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #10: _PyEval_EvalCodeWithName + 0x2f9 (0x5635de5e14f9 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #11: _PyFunction_FastCallKeywords + 0x387 (0x5635de643a27 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #12: _PyEval_EvalFrameDefault + 0x14ce (0x5635de69c8fe in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #13: _PyEval_EvalCodeWithName + 0x2f9 (0x5635de5e14f9 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #14: _PyFunction_FastCallKeywords + 0x325 (0x5635de6439c5 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #15: _PyEval_EvalFrameDefault + 0x4aa9 (0x5635de69fed9 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #16: _PyEval_EvalCodeWithName + 0x2f9 (0x5635de5e14f9 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #17: _PyFunction_FastCallKeywords + 0x387 (0x5635de643a27 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #18: _PyEval_EvalFrameDefault + 0x14ce (0x5635de69c8fe in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #19: _PyEval_EvalCodeWithName + 0x2f9 (0x5635de5e14f9 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #20: PyEval_EvalCodeEx + 0x44 (0x5635de5e23c4 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #21: PyEval_EvalCode + 0x1c (0x5635de5e23ec in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #22: <unknown function> + 0x1e004d (0x5635de6ab04d in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #23: _PyMethodDef_RawFastCallKeywords + 0xe9 (0x5635de644569 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #24: _PyCFunction_FastCallKeywords + 0x21 (0x5635de644801 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #25: _PyEval_EvalFrameDefault + 0x4755 (0x5635de69fb85 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #26: _PyGen_Send + 0x2a2 (0x5635de63d672 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #27: _PyEval_EvalFrameDefault + 0x1a6d (0x5635de69ce9d in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #28: _PyGen_Send + 0x2a2 (0x5635de63d672 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #29: _PyEval_EvalFrameDefault + 0x1a6d (0x5635de69ce9d in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #30: _PyGen_Send + 0x2a2 (0x5635de63d672 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #31: _PyMethodDef_RawFastCallKeywords + 0x8c (0x5635de64450c in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #32: _PyMethodDescr_FastCallKeywords + 0x4f (0x5635de64486f in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #33: _PyEval_EvalFrameDefault + 0x4c4c (0x5635de6a007c in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #34: _PyFunction_FastCallKeywords + 0xfb (0x5635de64379b in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #35: _PyEval_EvalFrameDefault + 0x416 (0x5635de69b846 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #36: _PyFunction_FastCallKeywords + 0xfb (0x5635de64379b in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #37: _PyEval_EvalFrameDefault + 0x6a0 (0x5635de69bad0 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #38: _PyEval_EvalCodeWithName + 0x2f9 (0x5635de5e14f9 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #39: _PyFunction_FastCallDict + 0x400 (0x5635de5e2800 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #40: _PyObject_Call_Prepend + 0x63 (0x5635de5f9c43 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #41: PyObject_Call + 0x6e (0x5635de5ee95e in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #42: _PyEval_EvalFrameDefault + 0x1e20 (0x5635de69d250 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #43: _PyEval_EvalCodeWithName + 0x5da (0x5635de5e17da in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #44: _PyFunction_FastCallKeywords + 0x387 (0x5635de643a27 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #45: _PyEval_EvalFrameDefault + 0x14ce (0x5635de69c8fe in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #46: <unknown function> + 0x171cc6 (0x5635de63ccc6 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #47: <unknown function> + 0x171ecb (0x5635de63cecb in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #48: _PyMethodDef_RawFastCallKeywords + 0xe9 (0x5635de644569 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #49: _PyCFunction_FastCallKeywords + 0x21 (0x5635de644801 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #50: _PyEval_EvalFrameDefault + 0x4755 (0x5635de69fb85 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #51: _PyEval_EvalCodeWithName + 0x5da (0x5635de5e17da in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #52: _PyFunction_FastCallKeywords + 0x387 (0x5635de643a27 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #53: _PyEval_EvalFrameDefault + 0x6a0 (0x5635de69bad0 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #54: <unknown function> + 0x171cc6 (0x5635de63ccc6 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #55: <unknown function> + 0x171ecb (0x5635de63cecb in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #56: _PyMethodDef_RawFastCallKeywords + 0xe9 (0x5635de644569 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #57: _PyCFunction_FastCallKeywords + 0x21 (0x5635de644801 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #58: _PyEval_EvalFrameDefault + 0x4755 (0x5635de69fb85 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #59: _PyEval_EvalCodeWithName + 0x5da (0x5635de5e17da in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #60: _PyFunction_FastCallKeywords + 0x387 (0x5635de643a27 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #61: _PyEval_EvalFrameDefault + 0x416 (0x5635de69b846 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #62: <unknown function> + 0x171cc6 (0x5635de63ccc6 in /home/julian/miniconda3/envs/myrtlespeech/bin/python)
frame #63: <unknown function> + 0x171ecb (0x5635de63cecb in /home/julian/miniconda3/envs/myrtlespeech/bin/python)


tensor([[[0.0000, 0.0000, 0.1624, 0.0000, 0.2714, 0.8505, 0.0000, 1.6649],
         [0.1109, 0.0000, 0.0000, 0.8201, 0.3121, 0.0000, 0.0000, 0.0000]]],
       grad_fn=<DifferentiableGraphBackward>)

## rnnt

In [None]:
# parse example config file
with open("../src/myrtlespeech/configs/rnn_t_en_2L.config") as f:
    task_config = text_format.Merge(f.read(), task_config_pb2.TaskConfig())


In [None]:
log_dir = '/home/julian/exp/rnnt/wer_down/2L/2/'


In [None]:
# create all components for config
# FYI: if using train-clean-100 & dev-clean this cell takes O(60s) 
seq_to_seq, epochs, train_loader, eval_loader = build(task_config, accumulation_steps=2)
seq_to_seq

In [None]:
load_model = True
epoch = 68
training_state = {}
if load_model:
    fp = log_dir + f'state_dict_{epoch}.pt'
    #fp = '/home/julian/exp/rnnt/wer_down/2D/1/model_saved.pt'
    training_state = load_seq_to_seq(seq_to_seq, fp)
    
seq_to_seq.model.eval()