In [1]:
import os
os.environ['OMP_NUM_THREADS'] = '1'
import time
import copy
import numpy as np
import jax
import jax.numpy as jnp
from jax_explicit_inv import *
# `from jax.config import config` was deprecated and later removed from JAX;
# the same configuration object is reachable as `jax.config` in both old and
# new releases, so bind the `config` name from there.
config = jax.config

# This line is critical for enabling 64-bit floats.
config.update("jax_enable_x64", True)



In [22]:
mats = np.random.rand(int(1e6), 4, 4)

In [30]:
%%time
_ = np.linalg.inv(mats)

CPU times: user 313 ms, sys: 8.41 ms, total: 322 ms
Wall time: 321 ms


In [29]:
%%time
_ = jnp.linalg.inv(mats)

CPU times: user 372 ms, sys: 224 ms, total: 596 ms
Wall time: 340 ms


In [2]:
def bench(f, N, d, iters=5):
    """Time ``jnp.linalg.inv`` against ``f`` on a batch of random matrices.

    A batch of ``int(N)`` random ``d x d`` matrices is generated with a fixed
    NumPy seed (10) and converted to a JAX array once, outside the timed
    region. Each candidate is called ``iters`` times; only the LAST call is
    reported, so the earlier calls act as warm-up (e.g. JIT compilation).

    Parameters
    ----------
    f : callable
        Called as ``f(jax_mats)``; the result must provide
        ``block_until_ready()``.
    N : int or float
        Number of matrices; truncated with ``int(N)``.
    d : int
        Matrix dimension.
    iters : int
        Number of calls per candidate (must be >= 1).

    Returns
    -------
    tuple of float
        ``(jnp_linalg_inv_us, f_us)``: microseconds per matrix for the last
        call of each candidate.

    Raises
    ------
    ValueError
        If ``int(N) < 1`` or ``iters < 1``.
    """
    n_mats = int(N)
    if n_mats < 1:
        raise ValueError(f'N must be at least 1, got {N!r}')
    if iters < 1:
        raise ValueError(f'iters must be at least 1, got {iters!r}')

    np.random.seed(10)
    mats = np.random.rand(n_mats, d, d)
    jax_mats = jnp.array(mats)

    def last_call_seconds(fn):
        # perf_counter is monotonic and has higher resolution than time.time.
        for _ in range(iters):
            start = time.perf_counter()
            fn(jax_mats).block_until_ready()
            elapsed = time.perf_counter() - start
        return elapsed

    jli_time = last_call_seconds(jnp.linalg.inv)
    f_time = last_call_seconds(f)
    # Normalise by the number of matrices actually processed (int(N)), not by
    # the possibly fractional N that was passed in.
    return jli_time / n_mats * 1e6, f_time / n_mats * 1e6

In [3]:
mats = np.random.rand(int(1e4), 15, 15)

In [5]:
%%timeit -n 10
vmap_inv_recurse(mats)

8.23 ms ± 848 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
bench(vmap_inv44, 1e5, 4), bench(vmap_inv33, 3e5, 3), bench(vmap_inv22, 1e6, 2)

((0.3070688247680664, 0.014290809631347656),
 (0.15472332636515299, 0.004496574401855469),
 (0.05975699424743652, 0.0023920536041259766))

In [7]:
mats.shape

(10000, 15, 15)

In [8]:
for d in range(1, 12):
    n = 6e5 / (d ** 2)
    print(d, n)
    print(bench(vmap_inv_recurse, n, d, iters=2))

1 600000.0
(0.03903826077779134, 0.000621477762858073)
2 150000.0
(0.07524649302164713, 0.0025113423665364585)
3 66666.66666666667
(0.16719102859497068, 0.005682706832885742)
4 37500.0
(0.3244781494140625, 0.020745595296223957)
5 24000.0
(0.5118846893310547, 0.04878640174865723)
6 16666.666666666668
(0.8373498916625977, 0.09006500244140625)
7 12244.897959183674
(1.2759443124135335, 0.10444164276123045)
8 9375.0
(1.6009521484375, 0.1348114013671875)
9 7407.407407407408
(2.3558914661407466, 0.2498960494995117)
10 6000.0
(3.130833307902018, 0.39931138356526696)
11 4958.677685950413
(4.201853672663371, 0.44994274775187176)


## Code generation?!

In [1]:
import os
os.environ['OMP_NUM_THREADS'] = '1'
import time
import copy
import numpy as np
import jax
import jax.numpy as jnp
from jax_explicit_inv import *
# `from jax.config import config` was deprecated and later removed from JAX;
# the same configuration object is reachable as `jax.config` in both old and
# new releases, so bind the `config` name from there.
config = jax.config

# This line is critical for enabling 64-bit floats.
config.update("jax_enable_x64", True)



In [2]:
# Reference LU factorisation (no pivoting) of a random 3x3 matrix in plain
# NumPy. `soln` packs both factors into one matrix: U on and above the
# diagonal, the multipliers of L below it.
orig = np.random.rand(3, 3)
U = orig.copy()
L = np.eye(U.shape[0])
for k in range(U.shape[0] - 1):
    invkk = 1.0 / U[k, k]
    # Column k of L: entries below the pivot scaled by the pivot's reciprocal.
    L[k + 1:, k] = U[k + 1:, k] * invkk
    # Subtract the (column below pivot) x (pivot row) outer product from the
    # rows below the pivot; the right-hand side is evaluated before the update.
    U[k + 1:, k:] -= np.outer(U[k + 1:, k], U[k, k:]) * invkk
LU = U.copy()
LL = L.copy()
np.fill_diagonal(LL, 0)
soln = LU + LL
np.testing.assert_allclose(L.dot(U), orig)

In [3]:
from dataclasses import dataclass
class CodeGenCtx:
    """Ordered collection of ``name = expression`` statements for code generation.

    ``assign`` records a statement and returns the name so it can be embedded
    in later expressions; ``lines`` renders all recorded statements in order.

    Statements are stored as (name, definition) pairs, so assigning the same
    name more than once emits each assignment with the definition it was given
    at that point, rather than replaying the most recent definition at every
    position where the name was assigned.
    """

    def __init__(self):
        # Names in assignment order (one entry per assign() call).
        self.assignments = []
        # Most recent definition recorded for each name.
        self.definitions = dict()
        # (name, definition) pairs exactly as assigned; source for lines().
        self._statements = []

    def assign(self, name, definition):
        """Record ``name = definition`` and return ``name``."""
        self.assignments.append(name)
        self.definitions[name] = definition
        self._statements.append((name, definition))
        return name

    def lines(self):
        """Return the recorded statements as source lines, in assignment order."""
        return [f'{name} = {definition}' for name, definition in self._statements]

In [4]:
def gen_lu(ctx, M, d):
    """Emit straight-line code for an unpivoted LU factorisation.

    ``M`` is a ``d x d`` nested list of expression strings. Every intermediate
    value is registered through ``ctx.assign`` and referred to afterwards by
    the value that call returns. ``M`` itself is not modified.

    Returns a ``d x d`` nested list holding the U expressions on and above the
    diagonal and the L multipliers below it.
    """
    upper = copy.deepcopy(M)
    lower = [[None for _ in range(d)] for _ in range(d)]
    for pivot in range(d - 1):
        trailing = range(pivot + 1, d)
        inv_pivot = ctx.assign(f'inv_{pivot}', f'1.0 / {upper[pivot][pivot]}')
        for row in trailing:
            lower[row][pivot] = ctx.assign(
                f'L_{row}{pivot}', f'{upper[row][pivot]} * {inv_pivot}'
            )
        for row in trailing:
            # The row right below the pivot is final after this step and gets
            # the plain name; deeper rows are tagged with the pivot step.
            prefix = 'U_' if row == pivot + 1 else f'U_{pivot}_'
            for col in trailing:
                upper[row][col] = ctx.assign(
                    f'{prefix}{row}{col}',
                    f'{upper[row][col]} - {upper[pivot][col]} * {upper[row][pivot]} * {inv_pivot}',
                )
    return [
        [upper[r][c] if r <= c else lower[r][c] for c in range(d)]
        for r in range(d)
    ]



def LU_decomp(m):
    inv_0 = 1.0 / m[0, 0]
    L_10 = m[1, 0] * inv_0
    L_20 = m[2, 0] * inv_0
    U_11 = m[1, 1] - m[0, 1] * m[1, 0] * inv_0
    U_12 = m[1, 2] - m[0, 2] * m[1, 0] * inv_0
    U_0_21 = m[2, 1] - m[0, 1] * m[2, 0] * inv_0
    U_0_22 = m[2, 2] - m[0, 2] * m[2, 0] * inv_0
    inv_1 = 1.0 / U_11
    L_21 = U_0_21 * inv_1
    U_22 = U_0_22 - U_12 * U_0_21 * inv_1
    return jnp.array([[m[0, 0], m[0, 1], m[0, 2]], [L_10, U_11, U_12], [L_20, L_21, U_22]])


In [None]:
def build_linalg(name, generator, d, print_code=True):
    """Build Python source for a function ``name(m)`` returning a ``d x d`` array.

    ``generator(ctx, M, d)`` receives a fresh ``CodeGenCtx`` and a ``d x d``
    nested list of the strings ``'m[i, j]'``; it must return a ``d x d`` nested
    list of expression strings. The statements recorded on the context become
    the function body, followed by a ``return jnp.array([...])`` of the
    returned expressions.

    The source is printed when ``print_code`` is true and always returned as a
    string (it is not executed here).
    """
    ctx = CodeGenCtx()
    symbols = [[f'm[{i}, {j}]' for j in range(d)] for i in range(d)]
    result = generator(ctx, symbols, d)
    rows = ['[' + ', '.join(result[i]) + ']' for i in range(d)]
    body = ctx.lines() + ['return jnp.array([' + ', '.join(rows) + '])']
    code = '\n'.join([f'def {name}(m):'] + ['    ' + stmt for stmt in body])
    if print_code:
        print(code)
    return code

exec(build_linalg('LU_decomp', gen_lu, 3))
np.testing.assert_allclose(LU_decomp(orig), soln)

In [5]:
def gen_upper_tri_inv(ctx, U, d):
    """Emit straight-line code inverting the upper triangle of ``U``.

    ``U`` is a ``d x d`` nested list of expression strings; only entries on and
    above the diagonal are read or replaced. Rows are processed from the last
    pivot to the first, registering every new value through ``ctx.assign``.
    ``U`` itself is not modified.

    Returns a ``d x d`` nested list whose upper triangle holds the names of the
    inverse's entries; entries below the diagonal are copied from ``U``.
    """
    inv = copy.deepcopy(U)
    for k in reversed(range(d)):
        right_of_pivot = range(k + 1, d)
        pivot_inv = ctx.assign(f'invU_{k}{k}', f'1.0 / {inv[k][k]}')
        inv[k][k] = pivot_inv
        # Scale the rest of the pivot row by the reciprocal of the pivot.
        for j in right_of_pivot:
            inv[k][j] = ctx.assign(f'invU_{k}{j}', f'{inv[k][j]} * {pivot_inv}')
        # Eliminate column k from every row above the pivot.
        for i in range(k):
            # Capture the multiplier before entry (i, k) is overwritten.
            negated = f'-{inv[i][k]}'
            inv[i][k] = ctx.assign(f'invU_{k}_{i}{k}', f'{negated} * {pivot_inv}')
            for j in right_of_pivot:
                inv[i][j] = ctx.assign(
                    f'invU_{k}_{i}{j}', f'{inv[i][j]} + {negated} * {inv[k][j]}'
                )
    return inv
exec(build_linalg('upper_tri_inv', gen_upper_tri_inv, 3))
np.testing.assert_allclose(np.triu(upper_tri_inv(soln)), np.linalg.inv(np.triu(soln)))

def upper_tri_inv(m):
    invU_22 = 1.0 / m[2, 2]
    invU_2_02 = -m[0, 2] * invU_22
    invU_2_12 = -m[1, 2] * invU_22
    invU_11 = 1.0 / m[1, 1]
    invU_12 = invU_2_12 * invU_11
    invU_1_01 = -m[0, 1] * invU_11
    invU_1_02 = invU_2_02 + -m[0, 1] * invU_12
    invU_00 = 1.0 / m[0, 0]
    invU_01 = invU_1_01 * invU_00
    invU_02 = invU_1_02 * invU_00
    return jnp.array([[invU_00, invU_01, invU_02], [m[1, 0], invU_11, invU_12], [m[2, 0], m[2, 1], invU_22]])


In [6]:
def transpose(A):
    """Return the transpose of a square nested-list matrix as new lists.

    The size is taken from ``len(A)`` and used for both dimensions.
    """
    size = len(A)
    flipped = []
    for col in range(size):
        flipped.append([A[row][col] for row in range(size)])
    return flipped

def gen_lu_inv(ctx, LU, d):
    """Emit straight-line code computing a matrix inverse from a packed LU.

    ``LU`` is a ``d x d`` nested list of expression strings with U on and above
    the diagonal and the L multipliers below it (the layout ``gen_lu``
    returns). ``LU`` is not modified. The notebook checks the generated
    function against ``np.linalg.inv``.

    Returns a ``d x d`` nested list of the names of the inverse's entries.
    """
    # Step 1: invert the upper-triangular factor. This used to be a verbatim
    # copy of gen_upper_tri_inv; calling it emits the same statements in the
    # same order.
    invU = gen_upper_tri_inv(ctx, LU, d)

    # Step 2: work on the transpose. After transposing, the strictly-upper
    # entries are the untouched L multipliers copied from LU; they are not part
    # of inv(U), so reset them to the literal '0'.
    invLU_T = transpose(invU)
    for i in range(d - 1):
        for j in range(i + 1, d):
            invLU_T[i][j] = '0'
    # Fold in the L factor: from the last row upwards, add -LU[k][i] times
    # row k to every row i above it.
    for k in range(d)[::-1]:
        for i in range(k):
            mult = f'-{LU[k][i]}'
            for j in range(d):
                name = f'invLU_T_{k}_{i}{j}'
                invLU_T[i][j] = ctx.assign(name, f'{invLU_T[i][j]} + {mult} * {invLU_T[k][j]}')
    return transpose(invLU_T)

exec(build_linalg('lu_inv', gen_lu_inv, 3))
np.testing.assert_allclose(lu_inv(LU_decomp(orig)), np.linalg.inv(orig), rtol=1e-5)

In [67]:
def gen_lu_solve(ctx, LU, B, d):
    """Emit straight-line code solving a linear system from a packed LU.

    ``LU`` is a ``d x d`` nested list of expression strings (U on and above the
    diagonal, L multipliers below) and ``B`` a length-``d`` list of right-hand
    side expressions. Forward substitution produces ``Y_i`` from the
    below-diagonal entries; back substitution then produces ``X_i`` from the
    diagonal and above-diagonal entries, last row first.

    Returns the list of solution names ``[X_0, ..., X_{d-1}]`` as returned by
    ``ctx.assign``.
    """
    forward = []
    for row in range(d):
        expr = f'{B[row]}'
        for col in range(row):
            expr += f'-{LU[row][col]}*{forward[col]}'
        forward.append(ctx.assign(f'Y_{row}', expr))

    solution = [None] * d
    for row in reversed(range(d)):
        inv_diag = ctx.assign(f'inv_{row}', f'1.0 / {LU[row][row]}')
        expr = f'{forward[row]}*{inv_diag}'
        for col in range(row + 1, d):
            expr += f'-{LU[row][col]}*{solution[col]}*{inv_diag}'
        solution[row] = ctx.assign(f'X_{row}', expr)
    return solution

In [77]:
def gen_solve(ctx, M, Y, d):
    """Emit code that LU-factorises ``M`` and then solves for right-hand side ``Y``.

    Chains ``gen_lu`` and ``gen_lu_solve`` on the same context and returns the
    list of solution names from ``gen_lu_solve``.
    """
    packed_lu = gen_lu(ctx, M, d)
    solution = gen_lu_solve(ctx, packed_lu, Y, d)
    return solution
    
def build_linalg_solve(name, generator, d, print_code=True):
    """Build Python source for a function ``name(m, y)`` returning a length-``d`` array.

    ``generator(ctx, M, Y, d)`` receives a fresh ``CodeGenCtx``, a ``d x d``
    nested list of the strings ``'m[i, j]'`` and a length-``d`` list of the
    strings ``'y[i]'``; it must return a list of expression strings. The
    statements recorded on the context become the function body, followed by a
    ``return jnp.array([...])`` of the returned expressions.

    The source is printed when ``print_code`` is true and always returned as a
    string (it is not executed here).
    """
    ctx = CodeGenCtx()
    M = [[f'm[{i}, {j}]' for j in range(d)] for i in range(d)]
    Y = [f'y[{i}]' for i in range(d)]
    # Use the generator that was passed in. Previously this parameter was
    # ignored and gen_solve was hard-coded, unlike build_linalg.
    X = generator(ctx, M, Y, d)
    lines = ctx.lines()
    lines.append('return jnp.array([' + ', '.join(X) + '])')
    lines = [f'def {name}(m, y):'] + ['    ' + l for l in lines]
    code = '\n'.join(lines)
    if print_code:
        print(code)
    return code
exec(build_linalg_solve('solve3', gen_solve, 3))

def solve3(m, y):
    inv_0 = 1.0 / m[0, 0]
    L_10 = m[1, 0] * inv_0
    L_20 = m[2, 0] * inv_0
    U_11 = m[1, 1] - m[0, 1] * m[1, 0] * inv_0
    U_12 = m[1, 2] - m[0, 2] * m[1, 0] * inv_0
    U_0_21 = m[2, 1] - m[0, 1] * m[2, 0] * inv_0
    U_0_22 = m[2, 2] - m[0, 2] * m[2, 0] * inv_0
    inv_1 = 1.0 / U_11
    L_21 = U_0_21 * inv_1
    U_22 = U_0_22 - U_12 * U_0_21 * inv_1
    Y_0 = y[0]
    Y_1 = y[1]-L_10*Y_0
    Y_2 = y[2]-L_20*Y_0-L_21*Y_1
    inv_2 = 1.0 / U_22
    X_2 = Y_2*inv_2
    inv_1 = 1.0 / U_11
    X_1 = Y_1*inv_1-U_12*X_2*inv_1
    inv_0 = 1.0 / m[0, 0]
    X_0 = Y_0*inv_0-m[0, 1]*X_1*inv_0-m[0, 2]*X_2*inv_0
    return jnp.array([X_0, X_1, X_2])


In [78]:
np.random.seed(0)
A = np.random.rand(3,3)
y = np.random.rand(3)
# Check the generated 3x3 solver against NumPy. The generated function is named
# `solve3`; this notebook never defines a plain `solve`.
np.testing.assert_allclose(solve3(A, y), np.linalg.solve(A,y))

In [79]:
def gen_inv(ctx, M, d):
    """Emit code that LU-factorises ``M`` and then inverts it from the packed LU.

    Chains ``gen_lu`` and ``gen_lu_inv`` on the same context and returns the
    ``d x d`` nested list of names produced by ``gen_lu_inv``.
    """
    return gen_lu_inv(ctx, gen_lu(ctx, M, d), d)

In [81]:
def bench_solve(f, N, d, iters=5):
    """Time ``jnp.linalg.solve`` against ``f`` on a batch of random systems.

    ``int(N)`` random ``d x d`` matrices and length-``d`` right-hand sides are
    generated with a fixed NumPy seed (10) and converted to JAX arrays once,
    outside the timed region. Each candidate is called ``iters`` times; only
    the LAST call is reported, so the earlier calls act as warm-up (e.g. JIT
    compilation).

    Parameters
    ----------
    f : callable
        Called as ``f(jax_mats, jax_bs)``; the result must provide
        ``block_until_ready()``.
    N : int or float
        Number of systems; truncated with ``int(N)``.
    d : int
        System dimension.
    iters : int
        Number of calls per candidate (must be >= 1).

    Returns
    -------
    tuple of float
        ``(jnp_linalg_solve_us, f_us)``: microseconds per system for the last
        call of each candidate.

    Raises
    ------
    ValueError
        If ``int(N) < 1`` or ``iters < 1``.
    """
    n_systems = int(N)
    if n_systems < 1:
        raise ValueError(f'N must be at least 1, got {N!r}')
    if iters < 1:
        raise ValueError(f'iters must be at least 1, got {iters!r}')

    np.random.seed(10)
    mats = np.random.rand(n_systems, d, d)
    bs = np.random.rand(n_systems, d)
    jax_mats = jnp.array(mats)
    # Convert the right-hand sides up front so host-to-device conversion is
    # not part of either timing (the matrices were already handled this way).
    jax_bs = jnp.array(bs)

    def reference_solve(a, b):
        # Pass each right-hand side as an explicit column: a batched 1-D `b`
        # is deprecated in jnp.linalg.solve and newer JAX treats a 2-D `b` as
        # a matrix instead.
        return jnp.linalg.solve(a, b[..., None])[..., 0]

    def last_call_seconds(fn):
        # perf_counter is monotonic and has higher resolution than time.time.
        for _ in range(iters):
            start = time.perf_counter()
            fn(jax_mats, jax_bs).block_until_ready()
            elapsed = time.perf_counter() - start
        return elapsed

    jli_time = last_call_seconds(reference_solve)
    f_time = last_call_seconds(f)
    # Normalise by the number of systems actually processed (int(N)), not by
    # the possibly fractional N that was passed in.
    return jli_time / n_systems * 1e6, f_time / n_systems * 1e6

In [82]:
# For every size d = 1..11: generate specialised inverse / LU / solve
# functions, validate them on one random input, then benchmark the
# jit+vmap'ed versions.
for d in range(1, 12):
    # exec() at cell level defines inv{d}, lu{d} and solve{d} as globals.
    exec(build_linalg(f'inv{d}', gen_inv, d, print_code=False))
    exec(build_linalg(f'lu{d}', gen_lu, d, print_code=False))
    exec(build_linalg_solve(f'solve{d}', gen_solve, d, print_code=False))
    f = globals()[f'inv{d}']
    f_lu = globals()[f'lu{d}']
    f_solve = globals()[f'solve{d}']
    # Correctness check on a single unseeded random matrix / right-hand side.
    mat = np.random.rand(d, d)
    b = np.random.rand(d)
    np.testing.assert_allclose(f(mat), np.linalg.inv(mat), rtol=1e-5)
    np.testing.assert_allclose(f_solve(mat, b), np.linalg.solve(mat, b), rtol=1e-5)
    # Batched, compiled variants used for the timings below.
    vmap = jax.jit(jax.vmap(f))
    vmap_lu = jax.jit(jax.vmap(f_lu))
    vmap_solve = jax.jit(jax.vmap(f_solve))
    # Publishes the batched inverse as a global; this replaces any existing
    # global with the same name.
    globals()[f'vmap_inv{d}'] = vmap
    print('\n', d)
    # Batch size shrinks with d**2; bench()/bench_solve() truncate it to int.
    n = 1e5 / (d ** 2)
    # Each result is a pair (reference, candidate) in microseconds per item.
    # bench() always uses jnp.linalg.inv as the reference, including for the
    # 'lu gen' row, where the candidate only computes the LU factorisation.
    print('recursive + cramer', bench(vmap_inv_recurse, n, d))
    print('code gen', bench(vmap, n, d))
    print('lu gen', bench(vmap_lu, n, d))
    print('solve', bench_solve(vmap_solve, n, d))


 1
recursive + cramer (0.038678646087646484, 0.0005769729614257812)
code gen (0.03645896911621094, 0.0003814697265625)
lu gen (0.037589073181152344, 0.00016927719116210938)
solve (0.03943920135498047, 0.0008916854858398438)

 2
recursive + cramer (0.07440567016601562, 0.002841949462890625)
code gen (0.07260322570800781, 0.0023174285888671875)
lu gen (0.08527755737304688, 0.0017547607421875)
solve (0.06092071533203126, 0.001430511474609375)

 3
recursive + cramer (0.17460107803344727, 0.009183883666992188)
code gen (0.17236948013305664, 0.01574993133544922)
lu gen (0.18990039825439453, 0.005300045013427734)
solve (0.1529073715209961, 0.004141330718994141)

 4
recursive + cramer (0.37250518798828125, 0.026092529296875)
code gen (0.3539276123046875, 0.03520965576171875)
lu gen (0.333404541015625, 0.01102447509765625)
solve (0.22563934326171875, 0.02544403076171875)

 5
recursive + cramer (0.6052255630493164, 0.06300210952758789)
code gen (0.6139278411865234, 0.0749826431274414)
lu gen (0

In [10]:
def inv_jax_lax(m):
    """Invert a square matrix using pivoted LU from ``jax.lax.linalg``.

    Parameters
    ----------
    m : array of shape ``(n, n)``
        Matrix to invert (use ``jax.vmap`` for batches).

    Returns
    -------
    array of shape ``(n, n)``
        The inverse of ``m``.
    """
    lu, pivot, perm = jax.lax.linalg.lu(m)
    # Identity sized from the input instead of a hard-coded 3x3, and created
    # in the factorisation's dtype so both triangular solves see matching types.
    identity = jnp.eye(m.shape[-1], dtype=lu.dtype)
    # triangular_solve defaults to left_side=False, i.e. it solves x @ a = b,
    # so these two calls produce inv(U) and then inv(U) @ inv(L).
    U_inv = jax.lax.linalg.triangular_solve(lu, identity, lower=False, unit_diagonal=False)
    full_inv = jax.lax.linalg.triangular_solve(lu, U_inv, lower=True, unit_diagonal=True)
    # lu() satisfies m[perm] == L @ U, hence inv(m) == inv(U) @ inv(L) @ P with
    # P = I[perm]. Right-multiplying by P moves column i to column perm[i],
    # which as a gather is indexing with the INVERSE permutation. Indexing
    # with `perm` itself is only correct when perm is its own inverse.
    return full_inv[:, jnp.argsort(perm)]
vmap_inv_jax_lax = jax.jit(jax.vmap(inv_jax_lax))
bench(vmap_inv_jax_lax, 1e4, 3)

NameError: name 'bench' is not defined

In [338]:
%%time
_ = vmap_inv_jax(mats)

CPU times: user 153 ms, sys: 17.9 ms, total: 171 ms
Wall time: 133 ms


## JAX is fast for large matrices

In [2]:
A = np.random.rand(5000, 5000)

In [9]:
flops = 5000 ** 3 * 2 / 3.
flops * 0.3 / 1e9 / 10 / 2

1.2499999999999998

In [6]:
%%time
np.linalg.inv(A)

CPU times: user 9.08 s, sys: 579 ms, total: 9.66 s
Wall time: 1.47 s


array([[-0.54415102, -0.45408855, -0.41497187, ..., -0.33757099,
         0.85470587, -0.45968236],
       [-0.34614283, -0.33987818, -0.26964381, ..., -0.2076328 ,
         0.60106392, -0.33870489],
       [-0.44630539, -0.43907991, -0.30957567, ..., -0.27006815,
         0.6991083 , -0.43488036],
       ...,
       [ 0.05952298,  0.0160277 ,  0.07382422, ...,  0.01481646,
        -0.09952914,  0.03751881],
       [ 0.51272287,  0.49819567,  0.33228203, ...,  0.21269653,
        -0.84855951,  0.53084048],
       [-0.21913774, -0.19052291, -0.13904184, ..., -0.07937356,
         0.39681116, -0.21506747]])

In [7]:
%%time
jnp.linalg.inv(jnp.array(A, dtype=jnp.float64))

CPU times: user 9.33 s, sys: 589 ms, total: 9.92 s
Wall time: 1.33 s


DeviceArray([[-0.54415102, -0.45408855, -0.41497187, ..., -0.33757099,
               0.85470587, -0.45968236],
             [-0.34614283, -0.33987818, -0.26964381, ..., -0.2076328 ,
               0.60106392, -0.33870489],
             [-0.44630539, -0.43907991, -0.30957567, ..., -0.27006815,
               0.6991083 , -0.43488036],
             ...,
             [ 0.05952298,  0.0160277 ,  0.07382422, ...,  0.01481646,
              -0.09952914,  0.03751881],
             [ 0.51272287,  0.49819567,  0.33228203, ...,  0.21269653,
              -0.84855951,  0.53084048],
             [-0.21913774, -0.19052291, -0.13904184, ..., -0.07937356,
               0.39681116, -0.21506747]], dtype=float64)