In [31]:
import tf2onnx.convert as conv
from sbNative.debugtools import log
import tensorflow as tf
import os
from os import path


In [32]:

model_path = path.join(os.getcwd(),"src","models","model_with_frets_and_strings.h5")
model = tf.keras.models.load_model(model_path)
model.output_names=['output']
input_signature = [tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype, name='digit')]


conv.from_keras(model, input_signature=input_signature, opset=13, output_path='model_with_frets_and_strings.onnx')

ERROR:tf2onnx.tfonnx:rewriter <function rewrite_constant_fold at 0x000001E264391C60>: exception `np.cast` was removed in the NumPy 2.0 release. Use `np.asarray(arr, dtype=dtype)` instead.


(ir_version: 7
 producer_name: "tf2onnx"
 producer_version: "1.16.1 15c810"
 graph {
   node {
     input: "digit"
     input: "sequential_16_1/normalization_14_1/Sub/y:0"
     output: "sequential_16_1/normalization_14_1/Sub:0"
     name: "sequential_16_1/normalization_14_1/Sub"
     op_type: "Sub"
   }
   node {
     input: "sequential_16_1/normalization_14_1/Sub:0"
     input: "ConstantFolding/sequential_16_1/normalization_14_1/truediv_recip:0"
     output: "sequential_16_1/normalization_14_1/truediv:0"
     name: "sequential_16_1/normalization_14_1/truediv"
     op_type: "Mul"
   }
   node {
     input: "sequential_16_1/normalization_14_1/truediv:0"
     input: "sequential_16_1/dense_64_1/Cast/ReadVariableOp:0"
     output: "sequential_16_1/dense_64_1/MatMul:0"
     name: "sequential_16_1/dense_64_1/MatMul"
     op_type: "MatMul"
   }
   node {
     input: "sequential_16_1/dense_64_1/MatMul:0"
     input: "sequential_16_1/dense_64_1/BiasAdd/ReadVariableOp:0"
     output: "sequential

In [33]:
import onnx
import onnxruntime
import numpy as np
from itertools import product, chain
import time


def get_all_options(notes: list[int], max_fret: int = 24, strings: int = 6) -> list[tuple[int, int]]:
    ## note 0 is A440Hz at 5th fret, first string
    options = []
    string_offsets = [
        0,
        5,
        9,
        14,
        19,
        24
    ]
    for note in notes:
        opt = []
        for string in range(strings):
            fret = note + 5 + string_offsets[string]
            if fret < 0:
                continue
            if fret > max_fret:
                continue
            opt.append((string, fret))
        options.append(opt)
    return product(*options)

onnx_model = onnx.load('model_with_frets_and_strings.onnx')
onnx.checker.check_model(onnx_model)

def print_data_as_tab(data):
    for row in range(6):
        for column in range(len(data)):
            print("--", end="")
            if data[column][0] == row:
                num = str(data[column][1])
                if len(num) == 1:
                    print(num + "-", end="")
                else:
                    print(num, end="")
            else:
                print("--", end="")
        print()


ort_session = onnxruntime.InferenceSession('model_with_frets_and_strings.onnx')

# Check the model's input names
for input_meta in ort_session.get_inputs():
    print(input_meta.name)

begin = time.perf_counter_ns()
options = list(get_all_options([-5, -3, -2, -0, 2]))
input = [list(chain(*y)) for y in options]

# Update the input feed to match the expected input names
input_feed = {'digit': np.array(input, dtype=np.float32)}
output = ort_session.run(None, input_feed)

output = list(map(lambda x: float(list(x)[0]), output[0]))
options_outputs = zip(options, output)
options_outputs = sorted(options_outputs, key=lambda x: x[1], reverse=True)
print(f"Time taken: {(time.perf_counter_ns() - begin) / 1e6}ms to compute {len(options)} options")
for idx, (option, prob) in enumerate(options_outputs):
    print(f"top {idx+1} with probability of {prob}")
    print_data_as_tab(option)
    print()

digit
Time taken: 13.532ms to compute 3000 options
top 1 with probability of 0.9996461868286133
--------------------
--5---------------12
----------12--------
------16------------
--------------24----
--------------------

top 2 with probability of 0.999273419380188
--0-----------------
------------------12
--------------------
------16--17--------
--------------24----
--------------------

top 3 with probability of 0.9992437362670898
--0---------------7-
----------8---------
--------------------
------16------------
--------------24----
--------------------

top 4 with probability of 0.9992058873176575
--0---------------7-
------7---8---10----
--------------------
--------------------
--------------------
--------------------

top 5 with probability of 0.9991062879562378
--------------------
--5---------------12
------11--12--14----
--------------------
--------------------
--------------------

top 6 with probability of 0.9990607500076294
--------------------
--------------------
--9