In [106]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import *
import onnx
import onnxruntime
import numpy as np

def make_mapping_node(amap, input_name, output_name, keys='strings', values='int64s'):
    valid_key_values = ['strings', 'int64s', 'floats']
    assert keys in valid_key_values and values in valid_key_values, f'Keys or Values not in valid set of {valid_key_values}'
    other_inps = {f'keys_{keys}': amap.keys(), f'values_{values}': amap.values()}
    return onnx.helper.make_node('LabelEncoder', inputs=[input_name], outputs=[output_name], domain='ai.onnx.ml', **other_inps)
    
def replace_node_value(input_output, new_node_value, replace_node_name):
    existing_node = list(filter(lambda x: x.name == replace_node_name, input_output))[0]
    input_output.remove(existing_node)
    input_output.append(new_node_value)
    

class MLTModel(nn.Module):
    def __init__(self, it2ind_mapping: Dict[str, int], vectors):
        super(MLTModel, self).__init__()
        n_items = len(vectors)
        n_factors = len(vectors[0])
        self.emb = nn.Embedding(n_items, n_factors, _weight=torch.from_numpy(vectors))
        self.it2ind = it2ind_mapping
        self.ind2it = {v:k for k,v in self.it2ind.items()}
    
    # model forward pass
    def forward(self, ind: torch.Tensor, size: int = 5)-> Tuple[torch.Tensor, torch.Tensor]:
        u = self.emb(ind)
        scores = u @ self.emb.weight.t()
        s, i = scores.topk(size)
        return s.squeeze(), i.squeeze()
    
    # most of the export code is here to map strings in and out of the network
    # which is not supported by default when exporting a PyTorch graph --> ONNX
    def add_onnx_mappings(self, path, op_set):
        import onnx
        om = onnx.load(path)
        
        #TODO obviously simplify this, and possibly make it not neccessary to think about for the modeller.
        # Map input
        in_node_value = onnx.helper.make_tensor_value_info('contentId', onnx.TensorProto.STRING, [None])
        replace_node_name = 'contentId_ind'
        n = make_mapping_node(self.it2ind, in_node_value.name, replace_node_name)
        om.graph.node.insert(0, n)
        replace_node_value(om.graph.input, in_node_value, replace_node_name)
        
        # Map output
        out_node_value = onnx.helper.make_tensor_value_info('contentIdd', onnx.TensorProto.STRING, [None])
        replace_node_name = 'indices'
        n = make_mapping_node(self.ind2it, replace_node_name, out_node_value.name, keys='int64s', values='strings')
        om.graph.node.append(n)
        replace_node_value(om.graph.output, out_node_value, replace_node_name)
        
        # finalize model
        model = onnx.helper.make_model(om.graph, opset_imports=[onnx.helper.make_opsetid('ai.onnx.ml', 2), onnx.helper.make_opsetid('', op_set)])
        onnx.checker.check_model(model)
        onnx.save(model, path)
        
    
    def export(self, path='model.onnx', onnx_op_set=16):
        input_names = ['contentId_ind', 'size']
        output_names = ['scores', 'indices'] # this should be a convention? 
        # Dynamic axes does nothing since we re-export the model atm.
        dynamic_axes = {name: [0] for name in output_names}
        jit_model = torch.jit.script(self)
        dummy_input = (torch.ones(1).long(), 3)
        dummy_input = tuple(1 for _ in input_names)
        torch.onnx.export(jit_model, dummy_input, path, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, verbose=True, opset_version=onnx_op_set)
        self.add_onnx_mappings(path, onnx_op_set)

In [107]:
#TODO actually load real weights from a generated model
#TODO add user --> item
#TODO add dynamic_axes to final output model, to stop warning
#TODO meta level error handling. etc, having size>num_items_in_model

In [68]:
om = onnx.load('model.onnx')

In [72]:
list(map(lambda x: {x.name, x.op_type}, om.graph.node))

[{'', 'LabelEncoder'},
 {'Gather', 'Gather_0'},
 {'MatMul', 'MatMul_1'},
 {'Constant', 'Constant_2'},
 {'Reshape', 'Reshape_3'},
 {'TopK', 'TopK_4'},
 {'Squeeze', 'Squeeze_5'},
 {'Squeeze', 'Squeeze_6'},
 {'', 'LabelEncoder'}]

# Use some real data
Pull down a model, and import the newline-delimited json file.

In [125]:
import json

with open('data/model.json', 'r') as f:
    rows = [json.loads(line) for line in f]

In [126]:
real_model_data = {x['contentId']: x['factors'] for x in rows}

vectors = np.array([x for x in real_model_data.values()])
name2ind = {n:i for i, n in enumerate(real_model_data.keys())}

In [108]:
model = MLTModel(name2ind, vectors)
model.export()

In [124]:
# just make sure all or ducks are in a row
np.isclose(vectors[0], np.array(rows[0]['factors'])).all() , np.isclose(model.emb.weight[0].detach().numpy(), np.array(rows[0]['factors'])).all()

(True, True)

In [116]:
list(name2ind.keys())[110:115]

['lover', 'lunsj', 'mamma', 'marit', 'match']

In [118]:
import onnxruntime
import numpy as np

ort_session = onnxruntime.InferenceSession('model.onnx')
inp = np.array(["ski-vm-junior-og-u23"])
ort_inputs = {'contentId': inp, 'size': np.ones(1, dtype=np.int64) * 10}
res = ort_session.run(None, ort_inputs)
{n:s for n, s in zip(res[1], res[0])}

{'ski-vm-junior-og-u23': 1.0000000000000002,
 'dama-til': 0.999999999638621,
 'KOIF43007811': 0.9999999995384318,
 'KMNO10008822': 0.9999999992626512,
 'KOIF75000417': 0.9999999976467383,
 'friidrett-nm': 0.9999999963539263,
 'verdens-beste-landslag': 0.9999999922260039,
 'kriger': 0.9999999904107656,
 'lunsj': 0.9999999903992889,
 'der-ingen-skulle-tru-at-nokon-kunne-bu': 0.9999999876724911}

2022-10-27 13:28:00.150713884 [W:onnxruntime:, execution_frame.cc:812 VerifyOutputSizes] Expected shape from model of {} does not match actual shape of {10} for output scores
