<a href="https://colab.research.google.com/github/afairley/ColaboratoryNotebooks/blob/main/JaxBuildingMatrixAndTensorOperationsWithVMap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
jax.print_environment_info(return_string=False)

jax:    0.4.23
jaxlib: 0.4.23
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1


In [2]:
import jax.numpy as jnp
import jax.tree_util as tree_util
from jax import vmap


vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)

In [9]:
identityMatrix = jnp.array(((1,0,0),
                  (0,1,0),
                  (0,0,1)))
permuteABMatrix = jnp.array(((0,1,0),
                  (1,0,0),
                  (0,0,1)))
permuteBCMatrix = jnp.array(((1,0,0),
                  (0,0,1),
                  (0,1,0)))
oneRowVector = jnp.array((1,1,1))
oneColumnVector = jnp.array(((1),
                             (1),
                             (1)))
columnMatrix = jnp.array(((1,0,0),
                  (2,0,0),
                  (3,0,0)))
rowMatrix = jnp.array(((1,4,9),
                  (0,0,0),
                  (0,0,0)))
columnMatrix2 = jnp.array(((0,0,1),
                  (0,0,2),
                  (0,0,3)))
rowMatrix2 = jnp.array(((0,0,0),
                  (0,0,0),
                  (1,4,9)))

#expansionMatrix = map( lambda x: map( lambda y: y *2, x) *2 , identityMatrix)
#You can no longer do the above.
def right_multiply_matrix_by_scalar(matrix, scalar):
  return tree_util.tree_map(lambda x : x * scalar, matrix)
expansionMatrix = right_multiply_matrix_by_scalar(identityMatrix, 2)
reductionMatrix = right_multiply_matrix_by_scalar(identityMatrix, 0.5)



In [10]:

print("Identity Matrix I:\n" , identityMatrix)
print("Permutation Matrix P_AB:\n" , permuteABMatrix)
print("Permutation Matrix P_BC:\n" , permuteBCMatrix)
print("RowVector v1:\n" , oneRowVector)
print("Column Vector vT1:\n" , oneColumnVector)
print("Expansion Matrix E:\n" , expansionMatrix)
print("Reduction Matrix R:\n" , reductionMatrix)
print("Column Matrix C:\n" , columnMatrix)
print("Row Matrix R:\n" , rowMatrix)
print("Column Matrix C2:\n" , columnMatrix2)
print("Row Matrix R2:\n" , rowMatrix2)




Identity Matrix I:
 [[1 0 0]
 [0 1 0]
 [0 0 1]]
Permutation Matrix P_AB:
 [[0 1 0]
 [1 0 0]
 [0 0 1]]
Permutation Matrix P_BC:
 [[1 0 0]
 [0 0 1]
 [0 1 0]]
RowVector v1:
 [1 1 1]
Column Vector vT1:
 [1 1 1]
Expansion Matrix E:
 [[2 0 0]
 [0 2 0]
 [0 0 2]]
Reduction Matrix R:
 [[0.5 0.  0. ]
 [0.  0.5 0. ]
 [0.  0.  0.5]]
Column Matrix C:
 [[1 0 0]
 [2 0 0]
 [3 0 0]]
Row Matrix R:
 [[1 4 9]
 [0 0 0]
 [0 0 0]]
Column Matrix C2:
 [[0 0 1]
 [0 0 2]
 [0 0 3]]
Row Matrix R2:
 [[0 0 0]
 [0 0 0]
 [1 4 9]]


In [11]:
#It would appear that while tree_map is perfectly happy to map over tuple trees
#vmap does not handle them unless they are upgraded to jnp.array
#mm(expansionMatrix,reductionMatrix) #appears to throw a shape error
print("mm product E * R:\n",
    mm(expansionMatrix,reductionMatrix))
print("mm product C * R:\n",
    mm(columnMatrix,rowMatrix))
print("mm product R * C:\n",
    mm(rowMatrix,columnMatrix))
print("mm product C2 * R2:\n",
    mm(columnMatrix2,rowMatrix2))
print("mm product R2 * C2:\n",
    mm(rowMatrix2,columnMatrix2))



mm product E * R:
 [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
mm product C * R:
 [[ 1  4  9]
 [ 2  8 18]
 [ 3 12 27]]
mm product R * C:
 [[36  0  0]
 [ 0  0  0]
 [ 0  0  0]]
mm product C2 * R2:
 [[ 1  4  9]
 [ 2  8 18]
 [ 3 12 27]]
mm product R2 * C2:
 [[ 0  0  0]
 [ 0  0  0]
 [ 0  0 36]]
mv1 product E * E:
 [4 4 4]
mv1 product C * R:
 [1 0 0]
mv1 product R * C:
 [1 0 0]


In [None]:
print("mv1 product E * E:\n",
          (mv1(expansionMatrix,expansionMatrix)))
print("mv1 product C * R:\n",
    mv1(columnMatrix,rowMatrix))
print("mv1 product R * C:\n",
    mv1(rowMatrix,columnMatrix))

In [8]:
rank3DiagonalUnityVectorTensor = jnp.array(
                  (
                    (
                    ((1,0,0),
                    (0,0,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ),
                    (
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,1,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ),
                    (
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,0,0),
                    (0,0,1))
                    )
))
print("mm2 product sillyTensor * sillyTensor:\n",
    mm2(rank3DiagonalUnityVectorTensor,rank3DiagonalUnityVectorTensor))

mm2 product sillyTensor * sillyTensor:
 [[1 0 0]
 [0 1 0]
 [0 0 1]]


In [None]:
#Doing this isn't actually needed as colab instances
#come with jax installed
projects = [
    ("jax","https://github.com/google/jax")]
JAX_DIR = f"/content/{projects[0][0]}"
for project, repo in projects:
  !rm -rf ./{project}
  !git clone {repo}

In [None]:
#Doing this isn't actually needed as colab instances
#come with jax installed
!pip install jaxlib
!cd {JAX_DIR} && pip install -e .
