In [1]:
import glob
import inspect
import os
import sys

import numpy as np
import onnx
import torch
from onnxsim import simplify

sys.path.append("../../")

In [2]:
# Get the pretrained models.
# glob() returns paths in filesystem order, which differs between machines,
# so sort them to keep this listing reproducible on Restart & Run All.
print("Pre-trained models available:")

pt_models = sorted(glob.glob('../../pretrained_models/*.pt'))
for model_name in pt_models:
    print(model_name)


Pre-trained models available:
../../pretrained_models/Improved_Sudormrf_U36_Bases2048_WSJ02mix.pt
../../pretrained_models/Improved_Sudormrf_U16_Bases512_WSJ02mix.pt
../../pretrained_models/Improved_Sudormrf_U16_Bases2048_WHAMRexclmark.pt
../../pretrained_models/GroupCom_Sudormrf_U8_Bases512_WSJ02mix.pt
../../pretrained_models/Improved_Sudormrf_U36_Bases4096_WHAMRexclmark.pt


In [3]:
'''utils'''


def simplify_model(ckpt, save_path, input_shapes, dynamic_input_shape=True):
    """Simplify an ONNX file with onnxsim and save the result.

    Args:
        ckpt: path of the ONNX model to load.
        save_path: path the simplified model is written to.
        input_shapes: mapping ``graph input name -> example shape`` (or None).
            Shapes are converted to plain ``list[int]`` before use.
        dynamic_input_shape: if True, the shapes are only used to validate the
            simplified graph and its symbolic (dynamic) input dims are kept;
            if False, the shapes are written into the graph as static dims.

    Raises:
        AssertionError: if onnxsim reports that the simplified model could not
            be validated against the original. Nothing is saved in that case.
    """
    shapes = (None if input_shapes is None
              else {name: list(shape) for name, shape in input_shapes.items()})
    onnx_model = onnx.load(ckpt)

    if 'test_input_shapes' in inspect.signature(simplify).parameters:
        # NOTE: onnxsim >= 0.4 deprecates `input_shapes`/`dynamic_input_shape`
        # and treats `input_shapes` as `overwrite_input_shapes`, which would
        # bake the example shape into the graph and drop the dynamic axes
        # declared at export time. Pick the keyword that matches the intent.
        if dynamic_input_shape:
            model_simp, check = simplify(onnx_model, test_input_shapes=shapes)
        else:
            model_simp, check = simplify(
                onnx_model, overwrite_input_shapes=shapes)
    else:
        # Legacy onnxsim API.
        model_simp, check = simplify(
            onnx_model, dynamic_input_shape=dynamic_input_shape,
            input_shapes=shapes)

    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if not check:
        raise AssertionError("Simplified ONNX model could not be validated")
    onnx.save(model_simp, save_path)
    print('finished exporting onnx')


def export_onnx(model, inputs_tensor, onnx_file_name, **kwargs):
    """Export ``model`` to ONNX, then write a simplified copy alongside it.

    Args:
        model: the torch module to export.
        inputs_tensor: example input used to trace ``model``; its shape is also
            passed to ``simplify_model`` for onnxsim's validation check.
        onnx_file_name: output path of the raw export. The simplified model is
            written to the same path with ``.simplify`` inserted before the
            extension (``foo.onnx`` -> ``foo.simplify.onnx``).
        **kwargs: forwarded unchanged to ``torch.onnx.export``. If
            ``input_names`` is given, its first entry is the graph input name
            handed to the simplifier; otherwise ``'x'`` is assumed.
    """
    # Use the name the exporter actually gives the graph input.
    input_names = kwargs.get('input_names') or ['x']

    # splitext only touches the final suffix: directory names containing
    # '.onnx' are left alone, and a path without an '.onnx' suffix can no
    # longer map onto itself (which would overwrite the raw export).
    root, ext = os.path.splitext(onnx_file_name)
    simplified_file_name = f'{root}.simplify{ext}'

    # No separate eager forward pass: its result was unused, and running it on
    # a module still in training mode would update any BatchNorm running
    # statistics with the dummy input before export.
    with torch.no_grad():
        torch.onnx.export(model, inputs_tensor, onnx_file_name, **kwargs)

    simplify_model(onnx_file_name, simplified_file_name,
                   {input_names[0]: list(inputs_tensor.shape)})


In [4]:
'''convert groupcomm_sudormrf_v2 pt model to onnx'''

import sudo_rm_rf.dnn.models.groupcomm_sudormrf_v2 as sudormrf_gc_v2

# Group size is not read from the checkpoint here; it is fixed to 16.
# NOTE(review): presumably this must match the training configuration — confirm.
GROUP_SIZE = 16

name = 'GroupCom_Sudormrf_U8_Bases512_WSJ02mix'

# torch.load returns the pickled nn.Module itself (its attributes and
# state_dict() are read below). map_location='cpu' keeps the weights on the
# same device as the CPU dummy input even if the checkpoint was saved on GPU.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# cannot unpickle a full module; pass weights_only=False there (trusted files only).
old_model = torch.load(f"../../pretrained_models/{name}.pt", map_location='cpu')

# Rebuild the architecture from the hyper-parameters stored on the loaded
# model; the same dict is printed and passed to the constructor.
model_config = {
    'in_audio_channels': old_model.in_audio_channels,
    'out_channels': old_model.out_channels,
    'in_channels': old_model.in_channels,
    'num_blocks': old_model.num_blocks,
    'upsampling_depth': old_model.upsampling_depth,
    'enc_kernel_size': old_model.enc_kernel_size,
    'enc_num_basis': old_model.enc_num_basis,
    'num_sources': old_model.num_sources,
    'group_size': GROUP_SIZE,
}
print(model_config)

new_model = sudormrf_gc_v2.GroupCommSudoRmRfExp(**model_config)

# Copy the trained weights; load_state_dict is strict by default, so missing
# or unexpected keys raise instead of silently loading a partial model.
with torch.no_grad():
    new_model.load_state_dict(old_model.state_dict())
# Inference mode, so any normalization/dropout layers are not in training
# behaviour when the model is run for export.
new_model.eval()

# Dummy input for tracing; dims 0 and 2 are exported as dynamic below.
dummy_input = torch.randn(1, 1, 32000).float()
print(dummy_input.shape)

onnx_file_name = f'../../pretrained_models/{name}.onnx'

export_onnx(
    new_model, dummy_input, onnx_file_name,
    verbose=False,
    input_names=['x'],
    output_names=['y'],
    dynamic_axes={
        'x': {0: 'batch_size', 2: 'samples'},
        'y': {0: 'batch_size', 2: 'samples'},
    },
    opset_version=13,
)


{'in_audio_channels': 1, 'out_channels': 256, 'in_channels': 512, 'num_blocks': 8, 'upsampling_depth': 5, 'enc_kernel_size': 21, 'enc_num_basis': 512, 'num_sources': 2, 'group_size': 16}
torch.Size([1, 1, 32000])


  torch.nn.init.xavier_uniform(self.encoder.weight)
  torch.nn.init.xavier_uniform(self.decoder.weight)


finished exporting onnx


In [5]:
'''convert Improved_Sudormrf_U16_Bases512_WSJ02mix pt model to onnx'''

import sudo_rm_rf.dnn.models.improved_sudormrf as improved_sudormrf

name = 'Improved_Sudormrf_U16_Bases512_WSJ02mix'

# torch.load returns the pickled nn.Module itself (its attributes and
# state_dict() are read below). map_location='cpu' keeps the weights on the
# same device as the CPU dummy input even if the checkpoint was saved on GPU.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# cannot unpickle a full module; pass weights_only=False there (trusted files only).
old_model = torch.load(f"../../pretrained_models/{name}.pt", map_location='cpu')

# Rebuild the architecture from the hyper-parameters stored on the loaded
# model; the same dict is printed and passed to the constructor.
model_config = {
    'out_channels': old_model.out_channels,
    'in_channels': old_model.in_channels,
    'num_blocks': old_model.num_blocks,
    'upsampling_depth': old_model.upsampling_depth,
    'enc_kernel_size': old_model.enc_kernel_size,
    'enc_num_basis': old_model.enc_num_basis,
    'num_sources': old_model.num_sources,
}
print(model_config)

new_model = improved_sudormrf.SuDORMRFExp(**model_config)

# Copy the trained weights; load_state_dict is strict by default, so missing
# or unexpected keys raise instead of silently loading a partial model.
with torch.no_grad():
    new_model.load_state_dict(old_model.state_dict())
# Inference mode, so any normalization/dropout layers are not in training
# behaviour when the model is run for export.
new_model.eval()

# Dummy input for tracing; dims 0 and 2 are exported as dynamic below.
dummy_input = torch.randn(1, 1, 32000).float()
print(dummy_input.shape)

onnx_file_name = f'../../pretrained_models/{name}.onnx'

export_onnx(
    new_model, dummy_input, onnx_file_name,
    verbose=False,
    input_names=['x'],
    output_names=['y'],
    dynamic_axes={
        'x': {0: 'batch_size', 2: 'samples'},
        'y': {0: 'batch_size', 2: 'samples'},
    },
    opset_version=13,
)


{'out_channels': 256, 'in_channels': 512, 'num_blocks': 16, 'upsampling_depth': 5, 'enc_kernel_size': 21, 'enc_num_basis': 512, 'num_sources': 2}
torch.Size([1, 1, 32000])


  torch.nn.init.xavier_uniform(self.encoder.weight)
  torch.nn.init.xavier_uniform(self.decoder.weight)


finished exporting onnx


In [6]:
'''convert Improved_Sudormrf_U16_Bases2048_WHAMRexclmark pt model to onnx'''

import sudo_rm_rf.dnn.models.improved_sudormrf as improved_sudormrf

name = 'Improved_Sudormrf_U16_Bases2048_WHAMRexclmark'

# torch.load returns the pickled nn.Module itself (its attributes and
# state_dict() are read below). map_location='cpu' keeps the weights on the
# same device as the CPU dummy input even if the checkpoint was saved on GPU.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# cannot unpickle a full module; pass weights_only=False there (trusted files only).
old_model = torch.load(f"../../pretrained_models/{name}.pt", map_location='cpu')

# Rebuild the architecture from the hyper-parameters stored on the loaded
# model; the same dict is printed and passed to the constructor.
model_config = {
    'out_channels': old_model.out_channels,
    'in_channels': old_model.in_channels,
    'num_blocks': old_model.num_blocks,
    'upsampling_depth': old_model.upsampling_depth,
    'enc_kernel_size': old_model.enc_kernel_size,
    'enc_num_basis': old_model.enc_num_basis,
    'num_sources': old_model.num_sources,
}
print(model_config)

new_model = improved_sudormrf.SuDORMRFExp(**model_config)

# Copy the trained weights; load_state_dict is strict by default, so missing
# or unexpected keys raise instead of silently loading a partial model.
with torch.no_grad():
    new_model.load_state_dict(old_model.state_dict())
# Inference mode, so any normalization/dropout layers are not in training
# behaviour when the model is run for export.
new_model.eval()

# Dummy input for tracing; dims 0 and 2 are exported as dynamic below.
dummy_input = torch.randn(1, 1, 32000).float()
print(dummy_input.shape)

onnx_file_name = f'../../pretrained_models/{name}.onnx'

export_onnx(
    new_model, dummy_input, onnx_file_name,
    verbose=False,
    input_names=['x'],
    output_names=['y'],
    dynamic_axes={
        'x': {0: 'batch_size', 2: 'samples'},
        'y': {0: 'batch_size', 2: 'samples'},
    },
    opset_version=13,
)


{'out_channels': 256, 'in_channels': 512, 'num_blocks': 16, 'upsampling_depth': 5, 'enc_kernel_size': 21, 'enc_num_basis': 2048, 'num_sources': 2}
torch.Size([1, 1, 32000])
finished exporting onnx
