In [None]:
import os
import onnx
import torch
import torchaudio
import hydra
import IPython.display as ipd

from hydra import compose, initialize
from omegaconf import OmegaConf
from torch import nn
from dyn_experiments.utils import register_resolvers, pretty_configs



In [6]:
# reset hydra (just in case this cell is re-run in the same kernel)
hydra.core.global_hydra.GlobalHydra.instance().clear()

# initialize hydra
initialize(version_base=None, config_path="../dyn_experiments/config/")

# tell hydra to parse the configs (use `overrides` to pass custom arguments)
config = compose(config_name="train.yaml", overrides=[
    "model=nsnet2_baseline",
    # "checkpoint_path=/path/to/checkpoint.ckpt",
])

# tell hydra to use custom configuration "resolvers"
register_resolvers()

# print the model config
print(f"Model configs:\n{pretty_configs(config.model)}")

# instantiate the model
model = hydra.utils.instantiate(config.model)

# load checkpoint
# The path can be overridden with the NSNET2_CHECKPOINT environment variable so
# the notebook also runs on machines other than the original author's.
model_weights = os.environ.get(
    "NSNET2_CHECKPOINT",
    '/home/rmiccini/checkpoints/nsnet2/baseline/pytorch_model.bin',
)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(model_weights, map_location='cpu')
# strict=False tolerates key mismatches; they are printed so they do not go unnoticed
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print(f"Missing keys: {missing_keys}")
print(f"Unexpected keys: {unexpected_keys}")

# show the model (switched to eval mode first)
model.eval()
print(model)

Model configs:
_target_: dyn_experiments.models.NsNet2
hidden_1: 400
hidden_2: 400
hidden_3: 600
loss:
  _target_: dyn_experiments.models.losses.DynCompMSE
  normalize: true
  normalize_framelen: 512
  normalize_threshold: 0.025
n_features: 257
n_fft: 512
postproc:
  _target_: torchaudio.transforms.InverseSpectrogram
  n_fft: 512
preproc:
  _target_: torchaudio.transforms.Spectrogram
  n_fft: 512
  power: null

Missing keys: []
Unexpected keys: []
NsNet2(
  (loss_function): DynCompMSE()
  (preproc): Spectrogram()
  (postproc): InverseSpectrogram()
  (fc1): Linear(in_features=257, out_features=400, bias=True)
  (rnn1): GRU(400, 400, batch_first=True)
  (rnn2): GRU(400, 400, batch_first=True)
  (fc2): Linear(in_features=400, out_features=600, bias=True)
  (fc3): Linear(in_features=600, out_features=600, bias=True)
  (fc4): Linear(in_features=600, out_features=257, bias=True)
)


# Convert to ONNX (stateless: GRU hidden states are explicit inputs and outputs)

In [None]:
# monkey-patch model forward to take an additional hidden state
def forward(self, log_stft_noisy, h1, h2):
    """Run the network layers with the GRU hidden states passed in explicitly.

    Args:
        log_stft_noisy: features fed to ``fc1``.
        h1: initial hidden state handed to ``rnn1``.
        h2: initial hidden state handed to ``rnn2``.

    Returns:
        Tuple ``(mask, h1n, h2n)``: the sigmoid output of ``fc4`` and the
        updated hidden states returned by ``rnn1`` and ``rnn2``.
    """
    # recurrent part: each GRU returns (sequence output, new hidden state)
    recurrent, h1n = self.rnn1(self.fc1(log_stft_noisy), h1)
    recurrent, h2n = self.rnn2(recurrent, h2)
    # feed-forward head: two ReLU layers, then a sigmoid output layer
    hidden = nn.functional.relu(self.fc2(recurrent))
    hidden = nn.functional.relu(self.fc3(hidden))
    mask = torch.sigmoid(self.fc4(hidden))
    return mask, h1n, h2n


def inference_forward(self, stft_noisy, h1=None, h2=None):
    """Apply the predicted mask to a noisy STFT, carrying the hidden states.

    Args:
        stft_noisy: noisy STFT. Dimension 1 is squeezed and the last two
            dimensions are swapped before the network runs
            (presumably ``[batch, 1, freqs, time_frames]`` -- TODO confirm).
        h1: optional hidden state forwarded to ``self.forward``.
        h2: optional hidden state forwarded to ``self.forward``.

    Returns:
        Tuple ``(stft_pred, h1n, h2n)``: the masked STFT and the hidden
        states returned by ``self.forward``.
    """
    # log power
    log_stft_noisy = torch.log(stft_noisy.abs() ** 2 + self.eps)
    # sort shape
    log_stft_noisy = log_stft_noisy.squeeze(1).permute(0, 2, 1)
    # run neural network layers
    x, h1n, h2n = self.forward(log_stft_noisy, h1, h2)
    # sort shape (undo the permute/squeeze above)
    mask_pred = x.permute(0, 2, 1).unsqueeze(1)
    # apply mask
    stft_pred = stft_noisy * mask_pred
    return stft_pred, h1n, h2n


# bind the patched functions to this model instance (the assignment targets the instance, not the class)
model.forward = forward.__get__(model, model.__class__)
model.inference_forward = inference_forward.__get__(model, model.__class__)

In [None]:
# convert to onnx
path = "./nsnet2_simple.onnx"

# dummy inputs used to trace the patched forward(log_stft_noisy, h1, h2)
in1 = torch.randn(1, 1, 257) # input shape: [batch_size, time_frames, freqs]
# nn.GRU hidden states are [num_layers, batch_size, hidden_size], even with batch_first=True
h1 = torch.randn(1, 1, 400) # hidden state shape: [num_layers, batch_size, hidden_size]
h2 = torch.randn(1, 1, 400) # hidden state shape: [num_layers, batch_size, hidden_size]
dummy_input = (in1, h1, h2)


torch.onnx.export(
    model, 
    dummy_input,
    path,
    output_names=['mask_pred', 'h1n', 'h2n'],
    input_names = ['in_noisy', 'h1', 'h2'],
    verbose=False)

# reload the exported file, run ONNX shape inference, and overwrite it with the result
onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path)

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


# Convert to ONNX (GRUCell variant, one frame per call; stateless: hidden states are explicit inputs and outputs)

In [11]:
# replace gru with gru cell, keeping the trained weights
def _gru_to_cell(gru):
    """Return an ``nn.GRUCell`` carrying the weights of a single-layer ``nn.GRU``.

    A freshly constructed ``nn.GRUCell`` is randomly initialised, so the
    parameters loaded from the checkpoint must be copied over explicitly.

    Args:
        gru: a single-layer, unidirectional ``nn.GRU``. An ``nn.GRUCell`` is
            returned unchanged so that re-running the cell is harmless.

    Returns:
        An ``nn.GRUCell`` with the same sizes, parameters and train/eval mode.

    Raises:
        ValueError: if ``gru`` has several layers or is bidirectional, which a
            single cell cannot represent.
    """
    if isinstance(gru, nn.GRUCell):
        return gru
    if gru.num_layers != 1 or gru.bidirectional:
        raise ValueError("only single-layer, unidirectional GRUs can be converted to a GRUCell")
    reference = gru.weight_ih_l0
    cell = nn.GRUCell(gru.input_size, gru.hidden_size, bias=gru.bias)
    cell = cell.to(device=reference.device, dtype=reference.dtype)
    with torch.no_grad():
        # both modules store the weights as stacked (reset, update, new) gate blocks
        cell.weight_ih.copy_(gru.weight_ih_l0)
        cell.weight_hh.copy_(gru.weight_hh_l0)
        if gru.bias:
            cell.bias_ih.copy_(gru.bias_ih_l0)
            cell.bias_hh.copy_(gru.bias_hh_l0)
    cell.train(gru.training)
    return cell


model.rnn1 = _gru_to_cell(model.rnn1)
model.rnn2 = _gru_to_cell(model.rnn2)

# monkey-patch model forward to take an additional hidden state
def forward(self, log_stft_noisy, h1, h2):
    """Run the network layers for one step using GRU cells.

    Args:
        log_stft_noisy: features fed to ``fc1``.
        h1: hidden state handed to the ``rnn1`` cell.
        h2: hidden state handed to the ``rnn2`` cell.

    Returns:
        Tuple ``(mask, h1n, h2n)``: the sigmoid output of ``fc4`` and the new
        hidden states produced by ``rnn1`` and ``rnn2``.
    """
    # a cell returns only its new hidden state, which is also the layer output
    h1n = self.rnn1(self.fc1(log_stft_noisy), h1)
    h2n = self.rnn2(h1n, h2)
    # feed-forward head: two ReLU layers, then a sigmoid output layer
    hidden = nn.functional.relu(self.fc2(h2n))
    hidden = nn.functional.relu(self.fc3(hidden))
    mask = torch.sigmoid(self.fc4(hidden))
    return mask, h1n, h2n


def inference_forward(self, stft_noisy, h1=None, h2=None):
    """Apply the predicted mask to a noisy STFT, one time frame at a time.

    The cell-based ``forward`` handles a single frame per call, so the frames
    are fed to it one by one and the hidden states are threaded through the
    loop. Passing the whole 3-D sequence at once, as before, is rejected by
    ``nn.GRUCell``.

    Args:
        stft_noisy: noisy STFT. Dimension 1 is squeezed and the last two
            dimensions are swapped before the network runs
            (presumably ``[batch, 1, freqs, time_frames]`` -- TODO confirm).
        h1: optional hidden state for the first frame, forwarded to ``self.forward``.
        h2: optional hidden state for the first frame, forwarded to ``self.forward``.

    Returns:
        Tuple ``(stft_pred, h1n, h2n)``: the masked STFT and the hidden states
        after the last frame (the inputs are returned as-is when there are no frames).
    """
    # log power
    log_stft_noisy = torch.log(stft_noisy.abs() ** 2 + self.eps)
    # sort shape
    log_stft_noisy = log_stft_noisy.squeeze(1).permute(0, 2, 1)
    # run neural network layers frame by frame, carrying the hidden states along
    h1n, h2n = h1, h2
    frame_masks = []
    for frame in log_stft_noisy.unbind(dim=1):
        frame_mask, h1n, h2n = self.forward(frame, h1n, h2n)
        frame_masks.append(frame_mask)
    if frame_masks:
        x = torch.stack(frame_masks, dim=1)
    else:
        # no frames: keep an empty mask with the matching shape
        x = torch.empty_like(log_stft_noisy)
    # sort shape (undo the permute/squeeze above)
    mask_pred = x.permute(0, 2, 1).unsqueeze(1)
    # apply mask
    stft_pred = stft_noisy * mask_pred
    return stft_pred, h1n, h2n


# bind the cell-based functions to this model instance (the assignment targets the instance, not the class)
model.forward = forward.__get__(model, model.__class__)
model.inference_forward = inference_forward.__get__(model, model.__class__)


In [12]:
# export the cell-based model to ONNX
path = "./nsnet2_simple_cell.onnx"

# dummy tensors used to trace forward(log_stft_noisy, h1, h2)
in1 = torch.randn(1, 257)  # [batch_size, freqs]
h1 = torch.randn(1, 400)   # [batch_size, hidden_size]
h2 = torch.randn(1, 400)   # [batch_size, hidden_size]
dummy_input = (in1, h1, h2)

torch.onnx.export(
    model,
    dummy_input,
    path,
    input_names=['in_noisy', 'h1', 'h2'],
    output_names=['mask_pred', 'h1n', 'h2n'],
    verbose=False,
)

# reload the exported file, run ONNX shape inference, and overwrite it with the result
onnx.save(
    onnx.shape_inference.infer_shapes(onnx.load(path)),
    path,
)
