In [9]:
import numba
import numpy as np
import cProfile
%load_ext memory_profiler

The memory_profiler extension is already loaded. To reload it, use:
  %reload_ext memory_profiler


In [10]:
import llvmlite.binding as llvm
llvm.set_option('', '--debug-only=loop-vectorize')

In [11]:
# vectorized outer product
einsum_outer = lambda a, b: np.einsum('...i,...j->...ij', a, b)

# shape
N = 1000
M = 500

test_a = np.random.rand(N, M)
test_b = np.random.rand(N, M)

In [12]:
cProfile.run('einsum_outer(test_a, test_b)')
%memit einsum_outer(test_a, test_b)

         105 function calls in 0.739 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.485    0.485 2922171395.py:2(<lambda>)
        1    0.085    0.085    0.569    0.569 <string>:1(<module>)
        2    0.000    0.000    0.000    0.000 base_events.py:731(time)
        5    0.000    0.000    0.000    0.000 einsumfunc.py:1049(_einsum_dispatcher)
        1    0.000    0.000    0.485    0.485 einsumfunc.py:1057(einsum)
        1    0.000    0.000    0.000    0.000 events.py:82(_run)
        1    0.000    0.000    0.000    0.000 history.py:839(_writeout_output_cache)
        1    0.000    0.000    0.000    0.000 ioloop.py:742(_run_callback)
        1    0.000    0.000    0.000    0.000 iostream.py:616(_flush)
        1    0.000    0.000    0.000    0.000 iostream.py:710(_flush_buffers)
        1    0.000    0.000    0.000    0.000 iostream.py:718(_rotate_buffers)
        1    0.000    0.000    0

In [13]:
@numba.jit(numba.float64[:,:,:](numba.float64[:,:], numba.float64[:,:]))
def numba_outer(a, b):
    n = a.shape[0]
    result = np.empty((n, M, M), dtype=np.float64)
    for i in numba.prange(n):
        for j in range(M):
            for k in range(M):
                result[i, j, k] = a[i, j] * b[i, k]
    return result

x = numba_outer(test_a, test_b)

In [14]:
cProfile.run('numba_outer(test_a, test_b)')
%memit numba_outer(test_a, test_b)

         92 function calls (90 primitive calls) in 0.087 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 <frozen abc>:121(__subclasscheck__)
        2    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:1390(_handle_fromlist)
        1    0.086    0.086    0.086    0.086 <string>:1(<module>)
        2    0.000    0.000    0.000    0.000 attrsettr.py:43(__getattr__)
        2    0.000    0.000    0.000    0.000 attrsettr.py:66(_get_attr_opt)
        1    0.000    0.000    0.086    0.086 base_events.py:1894(_run_once)
        2    0.000    0.000    0.000    0.000 base_events.py:731(time)
        7    0.000    0.000    0.000    0.000 enum.py:1116(__new__)
        1    0.000    0.000    0.000    0.000 enum.py:1531(__or__)
        3    0.000    0.000    0.000    0.000 enum.py:1541(__and__)
        7    0.000    0.000    0.000    0.000 enum.py:713(__call__)
        1    

In [15]:
import torch

class OuterProductModule(torch.nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, y):
        return torch.einsum ('ij, ik -> ijk', x, y)
    
a = OuterProductModule()
scripted_model = torch.jit.script(a)

In [16]:
cProfile.run('scripted_model(torch.from_numpy(test_a), torch.from_numpy(test_b))')
%memit scripted_model(torch.from_numpy(test_a), torch.from_numpy(test_b))

         104 function calls in 0.269 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.102    0.102    0.102    0.102 <string>:1(<module>)
        1    0.000    0.000    0.102    0.102 base_events.py:1894(_run_once)
        2    0.000    0.000    0.000    0.000 base_events.py:731(time)
        1    0.000    0.000    0.000    0.000 events.py:82(_run)
        1    0.000    0.000    0.000    0.000 history.py:839(_writeout_output_cache)
        1    0.000    0.000    0.000    0.000 ioloop.py:742(_run_callback)
        1    0.000    0.000    0.000    0.000 iostream.py:616(_flush)
        1    0.000    0.000    0.000    0.000 iostream.py:710(_flush_buffers)
        1    0.000    0.000    0.000    0.000 iostream.py:718(_rotate_buffers)
        1    0.000    0.000    0.106    0.106 module.py:1747(_wrapped_call_impl)
        1    0.106    0.106    0.106    0.106 module.py:1755(_call_impl)
        1    0.000    0.000    