In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import *

#TODO actually load real weights from a generated model
class ALSModel(nn.Module):
    def __init__(self, n_items: int, n_factors: int):
        super(ALSModel, self).__init__()
        #TODO add user --> item
        self.emb = nn.Embedding(n_items, n_factors)
        self.it2ind = {str(i): i for i in range(n_items)}# TODO use real values
        self.ind2it = {v:k for k,v in self.it2ind.items()}
        
    def map_out(self, scores: torch.Tensor, indices: torch.Tensor) -> Dict[str, float]:
        inds: List[int] = indices.squeeze().tolist()
        scorez: List[float] = scores.squeeze().tolist()
        return {self.ind2it[i]: s for i, s in zip(inds, scorez)}
    
    #def forward(self, ind: torch.Tensor, size: int = 30)-> Dict[str, float]:
    #def forward(self, ind: torch.Tensor)-> torch.Tensor:
    def forward(self, ind: torch.Tensor, size: int = 5)-> torch.Tensor:
        #ind = torch.tensor([self.it2ind[inp]]).long()
        u = self.emb(ind)
        scores = u @ self.emb.weight.t()
        s, i = scores.topk(size)
        #return self.map_out(s, i)
        return s.squeeze()
        
        
n_items, n_factors = (10000, 128)     
m = ALSModel(n_items, n_factors)
m2 = torch.jit.script(m)
    
#TODO can we actually map strings and stuff in ONNX? It seems like we cant???
#m("5", size=5), m2("5", size=5)


In [42]:
o = m2(torch.ones(1).long())
o, o.sort(descending=True)

(tensor([118.0973,  37.7835,  36.9943,  35.9663,  35.9311],
        grad_fn=<SqueezeBackward0>),
 torch.return_types.sort(
 values=tensor([118.0973,  37.7835,  36.9943,  35.9663,  35.9311],
        grad_fn=<SortBackward0>),
 indices=tensor([0, 1, 2, 3, 4])))

In [43]:
#x = (torch.ones(1).long())
x = (torch.ones(1).long(), 3)
file_path = "data/model.onnx"
torch.onnx.export(m2, x, file_path, verbose=True, opset_version=15)

graph(%ind.1 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %size.1 : Long(device=cpu),
      %emb.weight : Float(10000, 128, strides=[128, 1], requires_grad=0, device=cpu),
      %11 : Float(128, 10000, strides=[1, 128], requires_grad=0, device=cpu)):
  %u : Float(1, 128, strides=[128, 1], device=cpu) = onnx::Gather(%emb.weight, %ind.1) # /home/n651042/micromamba/envs/attention/lib/python3.10/site-packages/torch/nn/functional.py:2183:11
  %scores : Float(1, 10000, strides=[10000, 1], device=cpu) = onnx::MatMul(%u, %11) # /tmp/ipykernel_12857/2553407528.py:25:17
  %6 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}]() # /tmp/ipykernel_12857/2553407528.py:26:15
  %7 : Long(1, strides=[1], device=cpu) = onnx::Reshape[allowzero=0](%size.1, %6) # /tmp/ipykernel_12857/2553407528.py:26:15
  %s : Float(*, *, device=cpu), %i : Long(*, *, device=cpu) = onnx::TopK[axis=-1, largest=1, sorted=1](%scores, %7) # /tmp/ipykernel_12857/2553407528.py:26:15
  %10 : Float(3, strid

In [44]:
import onnx
# ensure a valid onnx file
onnx.checker.check_model(onnx.load(file_path))

In [45]:
import onnxruntime
import numpy as np

# sanity test outputs
ort_session = onnxruntime.InferenceSession(file_path)
inp = np.ones(1, dtype=np.int64) * 99
ort_inputs = {ort_session.get_inputs()[0].name: inp, ort_session.get_inputs()[1].name: inp}
ort_outs = ort_session.run(None, ort_inputs)

print(ort_outs)

[array([132.13054 ,  40.181793,  40.153183,  39.523598,  38.097862,
        37.84709 ,  37.330074,  36.924046,  36.715164,  36.378918,
        35.684757,  35.634724,  35.04893 ,  34.54045 ,  34.072304,
        33.694126,  33.0558  ,  32.82387 ,  32.82374 ,  32.617374,
        32.533703,  32.441418,  32.32119 ,  31.737667,  31.354448,
        30.98542 ,  30.5975  ,  30.574152,  30.384832,  30.371317,
        30.342276,  30.330914,  30.131212,  30.08917 ,  30.084301,
        30.064453,  29.63688 ,  29.489979,  29.363363,  29.354372,
        29.288212,  29.18877 ,  29.18855 ,  29.039358,  29.004469,
        28.882124,  28.814892,  28.809755,  28.714895,  28.598053,
        28.595556,  28.589615,  28.504133,  28.389652,  28.378061,
        28.31546 ,  28.30276 ,  28.286125,  28.201939,  28.05688 ,
        27.745224,  27.590948,  27.58031 ,  27.515045,  27.487463,
        27.479776,  27.441105,  27.385202,  27.330118,  27.30194 ,
        27.231031,  27.118498,  27.115707,  27.100737,  27.08

2022-10-26 11:22:12.099381850 [W:onnxruntime:, execution_frame.cc:812 VerifyOutputSizes] Expected shape from model of {3} does not match actual shape of {99} for output 10


In [33]:
ort_session.get_inputs()[1].name

'size.1'