In [60]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import *
import onnx
import onnxruntime
import numpy as np

def make_mapping_node(amap, input_name, output_name, keys='strings', values='int64s'):
    valid_key_values = ['strings', 'int64s', 'floats']
    assert keys in valid_key_values and values in valid_key_values, f'Keys or Values not in valid set of {valid_key_values}'
    other_inps = {f'keys_{keys}': amap.keys(), f'values_{values}': amap.values()}
    n = onnx.helper.make_node('LabelEncoder', inputs=[input_name], outputs=[output_name], domain='ai.onnx.ml', **other_inps)
    return n
    

class MLTModel(nn.Module):
    def __init__(self, n_items: int, n_factors: int):
        super(MLTModel, self).__init__()
        self.emb = nn.Embedding(n_items, n_factors)
        self.it2ind = {str(i): i for i in range(n_items)}
        self.ind2it = {v:k for k,v in self.it2ind.items()}
    
    def forward(self, ind: torch.Tensor, size: int = 5)-> Tuple[torch.Tensor, torch.Tensor]:
        u = self.emb(ind)
        scores = u @ self.emb.weight.t()
        s, i = scores.topk(size)
        return s.squeeze(), i.squeeze()
    
    # most of the export code is here to map strings in and out of the network
    # which is not supported by default when exporting a PyTorch graph --> ONNX
    def add_onnx_mappings(self, path, op_set):
        import onnx
        om = onnx.load(path)
        
        #TODO obviously simplify this, and possibly make it not neccessary to think about for the modeller.
        # Map input
        replace_node_name = 'contentId_ind'
        in_node_value = onnx.helper.make_tensor_value_info('contentId', onnx.TensorProto.STRING, [None])
        n = make_mapping_node(self.it2ind, in_node_value.name, replace_node_name)

        existing_node = list(filter(lambda x: x.name == replace_node_name, om.graph.input))[0]
        om.graph.input.remove(existing_node)
        om.graph.input.append(in_node_value)
        # TODO something that sorts the graph nodes? Since they need to be placed in the correct position.
        om.graph.node.insert(0, n)
        
        # Map output
        out_node_value = onnx.helper.make_tensor_value_info('contentId', onnx.TensorProto.STRING, [None])
        replace_node_name = 'indices'
        n = make_mapping_node(self.ind2it, replace_node_name, out_node_value.name, keys='int64s', values='strings')
        
        existing_node = list(filter(lambda x: x.name == replace_node_name, om.graph.output))[0]
        om.graph.output.remove(existing_node)
        om.graph.output.append(out_node_value)
        om.graph.node.append(n)
        
        # finalize model
        model = onnx.helper.make_model(om.graph, opset_imports=[onnx.helper.make_opsetid('ai.onnx.ml', 2), onnx.helper.make_opsetid('', op_set)])
        onnx.checker.check_model(model)
        onnx.save(model, path)
        
    
    def export(self, path='model.onnx', onnx_op_set=16):
        input_names = ['contentId_ind', 'size']
        output_names = ['scores', 'indices'] # this should be a convention? 
        # Dynamic axes does nothing since we re-export the model atm.
        dynamic_axes = {name: [0] for name in output_names}
        jit_model = torch.jit.script(self)
        dummy_input = (torch.ones(1).long(), 3)
        dummy_input = tuple(1 for _ in input_names)
        torch.onnx.export(jit_model, dummy_input, path, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, verbose=True, opset_version=onnx_op_set)
        self.add_onnx_mappings(path, onnx_op_set)

In [61]:
a = MLTModel(10000, 128)

In [62]:
a.export()

Exported graph: graph(%contentId_ind : Long(requires_grad=0, device=cpu),
      %size : Long(device=cpu),
      %emb.weight : Float(10000, 128, strides=[128, 1], requires_grad=0, device=cpu),
      %onnx::MatMul_12 : Float(128, 10000, strides=[1, 128], requires_grad=0, device=cpu)):
  %u : Float(128, strides=[1], device=cpu) = onnx::Gather[onnx_name="Gather_0"](%emb.weight, %contentId_ind) # /home/n651042/micromamba/envs/onnx/lib/python3.10/site-packages/torch/nn/functional.py:2199:11
  %scores.1 : Float(10000, strides=[1], device=cpu) = onnx::MatMul[onnx_name="MatMul_1"](%u, %onnx::MatMul_12) # /tmp/ipykernel_5861/1164184637.py:26:17
  %onnx::Reshape_6 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="Constant_2"]() # /tmp/ipykernel_5861/1164184637.py:27:15
  %onnx::TopK_7 : Long(1, strides=[1], device=cpu) = onnx::Reshape[allowzero=0, onnx_name="Reshape_3"](%size, %onnx::Reshape_6) # /tmp/ipykernel_5861/1164184637.py:27:15
  %s : Float(*, device=cpu), %i : Lon

ValidationError: Graph must be in single static assignment (SSA) form, however 'contentId' has been used as output names multiple times.

In [59]:
import onnxruntime
import numpy as np

ort_session = onnxruntime.InferenceSession('model.onnx')
inp = np.array(["5"])
ort_inputs = {'contentId': inp, 'size': np.ones(1, dtype=np.int64) * 5}
ort_session.run(None, ort_inputs)

2022-10-27 12:26:59.361970707 [W:onnxruntime:, execution_frame.cc:812 VerifyOutputSizes] Expected shape from model of {} does not match actual shape of {5} for output scores


[array([154.96382 ,  46.32555 ,  42.35973 ,  41.462833,  41.073437],
       dtype=float32),
 array(['5', '9418', '9878', '7925', '5728'], dtype=object)]

In [None]:
#TODO actually load real weights from a generated model
#TODO add user --> item
#TODO add dynamic_axes to final output model, to stop warning