## Installing JAX  
#### CPU Version:  
`pip install --upgrade pip`  
`pip install --upgrade jax jaxlib  # CPU-only version`  
#### GPU Version:  
`pip install --upgrade pip`  
`pip install --upgrade jax jaxlib==0.1.55+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html`

In [17]:
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev, random

In [20]:
key = random.PRNGKey(42)
A = random.normal(key, (4,3))

Lets make some objects to work with
and some basic operations like *matrix product*, *Hadamard product*, *vec* operation, and *Kronecker product* $\bigotimes$. 

In [10]:
A = random.normal(key, (2,3))
B = random.normal(key, (3,2))

print('A:')
print(A)
print('\nB:')
print(B)

A:
[[ 0.61226517  1.1225882   1.1373315 ]
 [-0.8127326  -0.8904051   0.12623137]]

B:
[[ 0.61226517  1.1225882 ]
 [ 1.1373315  -0.8127326 ]
 [-0.8904051   0.12623137]]


In [11]:
# Matrix product
print('\nAB:')
print(jnp.dot(A,B))
print('\nBA:')
print(jnp.dot(B,A))

print("\nA' (transpose):")
print(A.T)

print('\nvec operation (note we have to transpose before calling flatten (row major order)):')
vec = lambda A: A.T.flatten()
print(vec(A))


AB:
[[ 0.6389377  -0.08147542]
 [-1.6226907  -0.17276834]]

BA:
[[-0.5374953  -0.31223667  0.8380543 ]
 [ 1.3568827   2.000416    1.1909306 ]
 [-0.6477564  -1.1119553  -0.9967514 ]]

A' (transpose):
[[ 0.61226517 -0.8127326 ]
 [ 1.1225882  -0.8904051 ]
 [ 1.1373315   0.12623137]]

vec operation (note we have to transpose before calling flatten (row major order)):
[ 0.61226517 -0.8127326   1.1225882  -0.8904051   1.1373315   0.12623137]


In [12]:
# Hadamard product
print("\nHadamard product (A*B')")
print(A*B.T)

# Kronecker product
print(u'\nKronecker product A \u2297 B')
print(jnp.kron(A,B))


Hadamard product (A*B')
[[ 0.37486863  1.2767549  -1.0126858 ]
 [-0.91236395  0.72366124  0.01593436]]

Kronecker product A ⊗ B
[[ 0.37486863  0.6873216   0.6873216   1.2602042   0.6963484   1.2767549 ]
 [ 0.6963484  -0.49760786  1.2767549  -0.91236395  1.293523   -0.9243463 ]
 [-0.54516405  0.07728707 -0.99955827  0.14170584 -1.0126858   0.14356692]
 [-0.49760786 -0.91236395 -0.54516405 -0.99955827  0.07728707  0.14170584]
 [-0.9243463   0.66053426 -1.0126858   0.72366124  0.14356692 -0.10259235]
 [ 0.72366124 -0.10259235  0.7928213  -0.11239706 -0.11239706  0.01593436]]


Calculating $Df(x)$ for $f: \mathbb{R} \to \mathbb{R}$

In [13]:
f = lambda x: x**2

x = jnp.float32(3.0)

print(f'f(3) = {f(x)}')
print(f'Df(3) = {grad(f)(x)}')

f(3) = 9.0
Df(3) = 6.0


Calculating $Df(X)$ for $f: \mathbb{R}^n \to \mathbb{R}$

In [14]:
def f(X):
    col_sum = jnp.sum(X, axis=0)
    squared = col_sum**2
    return jnp.sum(squared)

X = jnp.array([[2,3,1], [-2,0,6]], dtype=jnp.float32)

print(f'f(X) = {f(X)}')
print(f'Df(X) = \n{grad(f)(X)}')

f(X) = 58.0
Df(X) = 
[[ 0.  6. 14.]
 [ 0.  6. 14.]]


Notation for taking derivative with respect to different variables.  
For:  
$U: \mathbb{R}^{10} \to \mathbb{R}^3$  
$W: \mathbb{R}^{15} \to \mathbb{R}^{10}$  
$V: \mathbb{R}^4 \to \mathbb{R}^{15}$  
let $A = UWV$ making $A: \mathbb{R}^4 \to \mathbb{R}^3$  
and let $f: \mathbb{R}^4 \to \mathbb{R}$ with $f(A) = \sum_{i,j} a_{ij}$ .  
We can calculate
$$
\frac{\partial f}{\partial U}, \frac{\partial f}{\partial W}, \frac{\partial f}{\partial V}
$$
by:

In [15]:
U = random.normal(key, (3, 10), dtype=jnp.float32)
W = random.normal(key, (10, 15), dtype=jnp.float32)
V = random.normal(key, (15, 4), dtype=jnp.float32)

A = lambda a, b, c: jnp.dot(jnp.dot(a, b), c)
f = lambda a, b, c: jnp.sum(A(a,b,c))

dU = grad(f)(U,W,V)
print('shape of dU:', dU.shape)
dW = grad(f, argnums=1)(U,W,V)
print('shape of dW:', dW.shape)
dV = grad(f, argnums=2)(U,W,V)
print('shape of dV:', dV.shape)
dU, dW, dV = grad(f, argnums=[0,1,2])(U,W,V)
print('shape of dU:', dU.shape)
print('shape of dW:', dW.shape)
print('shape of dV:', dV.shape)

shape of dU: (3, 10)
shape of dW: (10, 15)
shape of dV: (15, 4)
shape of dU: (3, 10)
shape of dW: (10, 15)
shape of dV: (15, 4)


Failing to calculate $DF(x)$ for $F: \mathbb{R} \to \mathbb{R}^n$

In [16]:
print('F(x) = [x^0, x^1, x^2, x^3, x^4, x^5]')
print('DF(x) = [0, 1, 2x, 3x^2, 4x^3, 5x^4]')

def F(x):
    powers = jnp.arange(6)
    return x**powers

x = jnp.float32(3.0)

print(f'F(3) =\n {F(x)}')

try:
    print(f'DF(x) = \n{grad(F)(x)}')
except Exception as e:
    print('\n... Oh noes ...')
    print('Exception raised:', e)

F(x) = [x^0, x^1, x^2, x^3, x^4, x^5]
DF(x) = [0, 1, 2x, 3x^2, 4x^3, 5x^4]
F(3) =
 [  1.   3.   9.  27.  81. 243.]

... Oh noes ...
Exception raised: Gradient only defined for scalar-output functions. Output had shape: (6,).


Calculating $DF(x)$ for $F: \mathbb{R} \to \mathbb{R}^n$ using the Jacobian

In [21]:
print('F(x) = [x^0, x^1, x^2, x^3, x^4, x^5]')
print('DF(x) = [0, 1, 2x, 3x^2, 4x^3, 5x^4]')

def F(x):
    powers = jnp.arange(6)
    return x**powers

x = jnp.float32(3.0)

print(f'F(3) =\n {F(x)}')

print(f'DF(3) = \n{jacfwd(F)(x)}')

print(f'DDF(3) = \n{jacfwd(jacfwd(F))(x)}')

print(f'D^3F(3) = \n{jacfwd(jacfwd(jacfwd(F)))(x)}')

print(f'D^4F(3) = \n{jacfwd(jacfwd(jacfwd(jacfwd(F))))(x)}')

def D(f, dtimes=[0]):
    for i, n in enumerate(dtimes):
        while n > 0:
            f = jacfwd(f, argnums=i)
            n -= 1
    return f

print(f'D^4F(3) = \n{D(F,[4])(x)}')

F(x) = [x^0, x^1, x^2, x^3, x^4, x^5]
DF(x) = [0, 1, 2x, 3x^2, 4x^3, 5x^4]
F(3) =
 [  1.   3.   9.  27.  81. 243.]
DF(3) = 
[  0.   1.   6.  27. 108. 405.]
DDF(3) = 
[  0.   0.   2.  18. 108. 540.]
D^3F(3) = 
[  0.   0.   0.   6.  72. 540.]
D^4F(3) = 
[  0.   0.   0.   0.  24. 360.]
D^4F(3) = 
[  0.   0.   0.   0.  24. 360.]


Calculating $DF(X)$ for $F: \mathbb{R}^n \to \mathbb{R}^m$

In [None]:
print('function:')
print('[x_4sin(x_1) + x_2sin(x_3),')
print(' x_3sin(x_2) + x_1sin(x_4)]')
print()

def f(x):
    y = x[::-1]*jnp.sin(x)
    return y[:2] + y[2:]

v = random.normal(key, (4, ), dtype=jnp.float32)

print('Forward Jacobian ( use when for mxn matrix, when m >= n ):')
print(jacfwd(f)(v)); print()

print('Reverse Jacobian ( use when for mxn matrix, when m < n ):')
print(jacrev(f)(v)); print()

print('Norm difference:', jnp.linalg.norm(jacfwd(f)(v) - jacrev(f)(v))); print()

def hessian(f):
    return jacfwd(jacrev(f))

print('Hessian:')
print(hessian(f)(v)); print()

Recreating the result of Magnus and Neudecker  
  
For constant matrices $A, B$ and $F(X) = AXB$, Calculate $DF(X)$ and verify $DF(X) = B' \bigotimes A$.

In [27]:
#A = random.normal(key, (3, 10), dtype=jnp.float32)
#X = random.normal(key, (10, 15), dtype=jnp.float32)
#B = random.normal(key, (15, 4), dtype=jnp.float32)

A = random.normal(key, (2, 3), dtype=jnp.float32)
X = random.normal(key, (3, 4), dtype=jnp.float32)
B = random.normal(key, (4, 2), dtype=jnp.float32)

F1 = lambda X, A, B: jnp.dot(jnp.dot(A, X), B)

R1 = jacfwd(F1)(X, A, B)
R2 = jnp.kron(vec(B), A)

print('shapes:', R1.shape, R2.shape)
print('size:', R1.size, R2.size)

Xshape = X.shape

vecinv = lambda x, shape: x.reshape(shape[::-1]).T

F2 = lambda X, A, B: vec(jnp.dot(jnp.dot(A, vecinv(X, Xshape)), B))

R1 = jacfwd(F2)(vec(X), A, B)
R2 = jnp.kron(vec(B.T), A)

print('shapes:', R1.shape, R2.shape)
print('size:', R1.size, R2.size)

shapes: (2, 2, 3, 4) (2, 24)
size: 48 48
shapes: (4, 12) (2, 24)
size: 48 48


In [None]:
import matplotlib.pyplot as plt
def getA(x,y,z):
    py2 = x**2 + y**2
    py3 = x**2 + y**2 + z**2
    sq2 = jnp.sqrt(py2)
    sq3 = jnp.sqrt(py3)
    sq23 = jnp.sqrt(py2*py3)
    A = jnp.array([[x/sq3,   x*z/sq23, -y/sq2],
                   [y/sq3,   y*z/sq23,  x/sq2],
                   [z/sq3, -sq2/sq3,      0]])
    return A

fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111, projection='3d')

'''for x in jnp.linspace(-10,10,5):
    for y in jnp.linspace(-10,10,5):
        for z in jnp.linspace(-10,10,5):
            A = getA(x,y,z)

            e = jnp.array([[1,0,0],[0,1,0],[0,0,1]])
            v = jnp.dot(A, e)
            for i in range(3):
                #ax.plot([x, x+v[0,i]],[y, y+v[1,i]],[z, z+v[2,i]], color='black')
                ax.plot([x, x+v[0,i]],[y, y+v[1,i]],[z, z+v[2,i]], color='black')'''

#plt.show()
r=10
for theta in jnp.linspace(0, jnp.pi, 10)[1:-1]:
    for phi in jnp.linspace(0, 2*jnp.pi, 10):
        x = 10*jnp.sin(theta)*jnp.cos(phi)
        y = 10*jnp.sin(theta)*jnp.sin(phi)
        z = 10*jnp.cos(theta)
        
        A = getA(x,y,z)

        e = jnp.array([[1,0,0],[0,1,0],[0,0,1]])
        v = jnp.dot(A, e)
        for i in range(3):
            #ax.plot([x, x+v[0,i]],[y, y+v[1,i]],[z, z+v[2,i]], color='black')
            ax.plot([x, x+v[0,i]],[y, y+v[1,i]],[z, z+v[2,i]], color='black')

plt.show()