# Conv-TasNet

In [2]:
import onnx
import onnxruntime as ort
import numpy as np
from onnx import shape_inference
import sys
import torch
import numpy as np
import os

```markdown
Clone the Conv-TasNet repository into this directory and rename the cloned folder to `network` (the code below imports `network.conv_tasnet`):
`git@github.com:naplab/Conv-TasNet.git`
```

Define the hyperparameters of Conv-TasNet.

In [3]:
# Frame (window) length in samples and hop size; the hop is half the window.
# NOTE(review): L is not passed to TasNet below, so it must match the model's
# own encoder window — TODO confirm against network/conv_tasnet.py.
L = 32
stride = L // 2

# Number of speakers. Currently only 1 is supported in streaming mode.
num_spk = 1

# Causality flag. The non-causal model uses gLN, whose statistics span the whole
# sequence (past and future), so only the causal model can run in streaming mode.
causal = True
casual = causal  # legacy misspelled alias, still read by the cells below

Create a Conv-TasNet model with random weights, export it to ONNX using a dummy input, and compute its receptive field.

In [None]:
from network.conv_tasnet import TasNet

# Seed the RNG so the random weights and the dummy input are reproducible.
torch.manual_seed(0)

# Small random-weight Conv-TasNet. Streaming currently supports only 1 speaker.
nnet = TasNet(enc_dim=128, feature_dim=64, layer=2, stack=3,
              kernel=3, num_spk=num_spk, causal=casual)

receptive_field = nnet.receptive_field

onnx_filename = "convtasnet_orig.onnx"
current_directory = os.getcwd()
onnx_path = os.path.join(current_directory, onnx_filename)

x = torch.rand(1, 32000)  # dummy input: 1 waveform of 32000 samples
torch.onnx.export(nnet, x, onnx_path, input_names=["input"], output_names=["output"])
print(f" Model exported to {onnx_path}")

# Number of frames (timesteps) spanned by the receptive field, treating it as a
# sample count: frames = (samples - window) // hop + 1.
T = (receptive_field - L) // stride + 1

print(f"receptive field is { nnet.receptive_field} samples and {T} timesteps ")

Run the non-streaming ONNX model with ONNX Runtime.

In [None]:
ns_model_path = "convtasnet_orig.onnx"

# Profiling must be enabled on the SessionOptions passed to the session at
# construction time; setting an attribute on the session afterwards has no effect.
sess_options = ort.SessionOptions()
sess_options.enable_profiling = True
ort_sess = ort.InferenceSession(
    ns_model_path, sess_options, providers=["CPUExecutionProvider"]
)

# Feed the same dummy signal that was used for the ONNX export.
input_data = x.numpy().astype(np.float32)

# Take the input name from the session rather than re-loading the ONNX file.
input_name = ort_sess.get_inputs()[0].name
final_output_model = ort_sess.run(None, {input_name: input_data})

# Stop profiling and report where ONNX Runtime wrote the JSON trace.
profile_file = ort_sess.end_profiling()
print(f"Profiling trace written to {profile_file}")

print("Model output:", final_output_model[0].shape)
print(final_output_model[0])

Convert the non-streaming model to a streaming one.

In [None]:
from streamease.onnx_streamer.streamer import StreamingConverter

streaming_name, time_steps = "torch_streaming_model.onnx", 1

# Wrap the exported non-streaming graph in the streaming converter.
model = onnx.load(ns_model_path)
streaming = StreamingConverter(model, time_steps=time_steps)

# Convert, print a summary, and write the streaming model to the current directory.
streaming.run()
streaming.print_info()
streaming.save_streaming_onnx(".", onnx_filename=streaming_name)

In this step, we prepare the correct input for streaming inference. Conv-TasNet pads the waveform with the following function before its first layer, so the streaming input must be padded the same way.

In [7]:
from torch.autograd import Variable
def pad_signal(input, stride, win):
    """Pad a waveform the same way Conv-TasNet does before its first layer.

    First, zeros are appended so the length fits the window/hop grid. Then
    ``stride`` zeros are added at both ends.

    Args:
        input: waveform tensor of shape (B, T) or (B, 1, T).
        stride: hop size in samples; this many zeros are added on each side.
        win: window length in samples.

    Returns:
        A tuple ``(padded, rest)``. ``padded`` has shape
        (B, 1, stride + T + rest + stride), with the same dtype and device as
        ``input``. ``rest`` is the number of zeros appended before the side
        padding; it is always in the range 1..win.

    Raises:
        RuntimeError: if ``input`` is not 2- or 3-dimensional.
    """
    if input.dim() not in (2, 3):
        raise RuntimeError("Input can only be 2 or 3 dimensional.")

    if input.dim() == 2:
        input = input.unsqueeze(1)
    batch_size = input.size(0)
    nsample = input.size(2)

    rest = win - (stride + nsample % win) % win
    if rest > 0:
        # new_zeros matches input's dtype/device; it replaces the deprecated
        # Variable(torch.zeros(...)).type(input.type()) idiom.
        pad = input.new_zeros(batch_size, 1, rest)
        input = torch.cat([input, pad], 2)

    pad_aux = input.new_zeros(batch_size, 1, stride)
    input = torch.cat([pad_aux, input, pad_aux], 2)
    return input, rest

In [None]:
from utils.onnx_inference.streaming_inference import Inference

# Pad exactly as Conv-TasNet does internally, then flatten the (1, 1, N)
# padded signal to N samples.
input_data_pad, rest = pad_signal(torch.tensor(input_data), stride, L)
input_data_re = np.reshape(input_data_pad, (input_data_pad.shape[-1]))

# Streaming ONNX runner configured with receptive_field=T frames.
s_test = Inference(streaming_name, receptive_field=T, causal=casual, time_steps=1)
s_test.init_buffers()

str_output = s_test.run_audio(
    input_data_re,
    frame_length=L,
    stride=stride,
    transpose=False,
    dim=2,
)

# Strip the `stride` leading and `rest + stride` trailing padding samples.
out1 = str_output[0][stride:-(rest + stride)]
print(out1)


**TODO:** Explain the cLN (cumulative layer normalization) code here.

```markdown
You might notice some differences in the outputs, particularly for timesteps greater than 1. This is because cLN in the original (training) model accumulates its statistics over the entire input sequence, whereas the streaming model only sees the frames inside its receptive field. If you set the receptive field in the following code to cover the whole input, as done below, the results will be identical.
```

In [None]:
from utils.onnx_inference.streaming_inference import Inference

input_data_pad, rest = pad_signal(torch.tensor(input_data), stride, L)

# Let the "receptive field" cover the whole padded input so the streaming cLN
# sees the same history as the non-streaming model:
# frames = (samples - window) // hop + 1.
# NOTE(review): this rebinds the global T, so later cells that read T (e.g. the
# NNTool cell) see this full-input value rather than the model's receptive field.
T = (input_data_pad.shape[-1] - L) // stride + 1
print(f"Using the full input as receptive field: {T} timesteps "
      f"(model receptive field is {nnet.receptive_field})")

input_data_re = np.reshape(input_data_pad, (input_data_pad.shape[-1]))
s_test = Inference(streaming_name, receptive_field=T, causal=casual, time_steps=1)
s_test.init_buffers()
str_output = s_test.run_audio(input_data_re, frame_length=L, stride=stride, transpose=False, dim=2)

# Strip the `stride` leading and `rest + stride` trailing padding samples.
out1 = str_output[0][stride:-(rest + stride)]
print(out1)

## NNTOOL

In [10]:
from nntool.api import NNGraph
from nntool.api.utils import model_settings, quantization_options, tensor_plot
import logging
# nntool_log = logging.getLogger('nntool')
# nntool_log.setLevel(logging.ERROR)

In [None]:
# Load the streaming ONNX graph into NNTool, keeping the ONNX node names.
streaming_model = "torch_streaming_model.onnx"
s_model = NNGraph.load_graph(streaming_model, use_onnx_names=True)

# Call s_model.draw() before/after this step to inspect the graph.
s_model.adjust_order()

# Equivalent of the NNTool `fusions --scale8` command; fusions() also accepts
# several fusion names, e.g. fusions('name1', 'name2').
s_model.fusions('scaled_match_group')


In [None]:
from utils.nntool_inference.streaming_inference import Inference

# NOTE(review): T is whichever value was assigned last. After a top-to-bottom
# run, that is the full-input length from the cLN comparison cell, not the
# model's receptive field — confirm which is intended here.
print(f"Receptive field is {T}")

# Use the configured causality flag instead of a hardcoded True.
s_test = Inference(s_model, streaming_model, time_steps=1, receptive_field=T, causal=casual)
s_test.init_buffers()

input_data_pad, rest = pad_signal(torch.tensor(input_data), stride, L)
input_data_re = np.reshape(input_data_pad, (input_data_pad.shape[-1]))
print(input_data_re.shape)

# transpose=True here, unlike the ONNX runner (transpose=False).
nn_str_output = s_test.run_audio(input_data_re, frame_length=L, stride=stride, transpose=True, dim=2)

# Strip the `stride` leading and `rest + stride` trailing padding samples.
print(nn_str_output[0][stride:-(rest + stride)])