<a href="https://colab.research.google.com/github/afairley/ColaboratoryNotebooks/blob/main/JaxBuildingMatrixAndTensorOperationsWithVMap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



# Jax Version Info



In [None]:
import jax
jax.print_environment_info(return_string=False)

jax:    0.4.23
jaxlib: 0.4.23
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1



# vmap examples


In [None]:
"""From the vmap docstring, plus other permutation variants.
I still don't have a solid operational definition of the exact behavior
of vmap yet, but I think these are most of the basic meaningful permutations.
There are quite a few other permutations that can be done even restricting
ourselves to a "pair" tuple with no tree structure for in_axes.

It basically looks to work along the lines you'd expect if you've implemented
this sort of code before. Theoretically, refactoring all the code in jax that
turns user-provided list and tuple pytrees into tuple pytrees, and getting all
the code backing vmap into a single module, might provide some efficiency gains,
depending on how the Python interpreter handles function calls across many
modules and back and forth over the jaxlib/jax boundary. It has presumably been
looked at internally.

I'm not really sure whether I want to proceed by writing a pretty printer for
matrices and tensors and their various products, or by unrolling all the code.
Most likely I'm going to do both, although I actually want to read through
the haiku and XLA source at present as well. vmap appears to call all the
way down into XLA (its pytree flattening, at least, appears to be C++ code in
the XLA repository), so I can probably scratch that itch while staying
"on task", as it were.

Perhaps there is some other way I can stash this text here; the following
attempt, based on advice from the hoi polloi
( https://stackoverflow.com/questions/4823468/comments-in-markdown ), doesn't
seem to work, most likely because that trick is meant for Markdown cells and
this text lives inside a Python string in a code cell.
[//]: # It would appear that while tree_map is perfectly happy to map over
[//]: # tuple trees, vmap does not handle them unless they are upgraded to
[//]: # jnp.array
[//]: # #mm(expansionMatrix,reductionMatrix) #appears to throw a shape error """
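A minimal sketch of the tuple-tree observation in the note above, using a small identity matrix written as nested tuples (an illustrative input, not one of the notebook's variables): `tree_util.tree_map` walks the nested tuples and applies the function to each scalar leaf, while `vmap` with its default `in_axes=0` asks for axis 0 of every leaf, and a bare Python int has no axis 0. Converting with `jnp.array` first makes the whole matrix a single leaf whose axis 0 is the rows. The exact exception type is an assumption that may vary across JAX versions, so the sketch catches both `ValueError` and `TypeError`.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as tree_util

nested = ((1, 0, 0),
          (0, 1, 0),
          (0, 0, 1))

# tree_map visits each scalar leaf and keeps the tuple-of-tuples structure.
doubled = tree_util.tree_map(lambda x: x * 2, nested)

# vmap over the same tuple tree asks for axis 0 of each scalar leaf, which
# does not exist, so the call raises instead of mapping over rows.
try:
  jax.vmap(jnp.sum)(nested)
  tuple_tree_vmap_ok = True
except (ValueError, TypeError):
  tuple_tree_vmap_ok = False

# As a jnp.array the matrix is one leaf, and axis 0 is its rows.
row_sums = jax.vmap(jnp.sum)(jnp.array(nested))

print(doubled)              # ((2, 0, 0), (0, 2, 0), (0, 0, 2))
print(tuple_tree_vmap_ok)   # False
print(row_sums)             # [1 1 1]
```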

In [None]:
import jax.numpy as jnp
import jax.tree_util as tree_util
from jax import vmap


vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)

mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)

# This one is actually a batched tensor contraction rather than a matrix product:
# tp1(X, Y)[k, i] = sum_j X[i, k, j] * Y[j, k, i]
tp1 = vmap(mv2, (1, 1), 0)  #  mm2 originally
                            #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)

vm = vmap(vv, (None, 0), 0)      #  ([a], [b,a]) -> [b]      vm(x, Y) == Y @ x   (b is the mapped axis)
mm3 = vmap(vm, (1, None), 1)     #  ([a,c], [b,a]) -> [b,c]  mm3(X, Y) == Y @ X  (c is the mapped axis)
                                 #  So this does use the reversed matrix convention,
                                 #  as far as I recall from yesterday: the product
                                 #  comes out in the opposite argument order.

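#other variants
vm2 = vmap(vv, (None, 1), 0)     #  ([a], [a,b]) -> [b]      vm2(x, Y) == x @ Y  (b is the mapped axis)
mm4 = vmap(vm, (1, None), 1)     #  as written, identical to mm3; built on vm2 instead,
                                 #  vmap(vm2, (1, None), 1)(X, Y) == Y.T @ X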

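# Switching the in_axes tuples in this way inverts the order of the
# constituent operations of the matrix multiplication:
vmapped_dot = vmap(vv, (1, None), 0)                    #  ([a,b], [a]) -> [b]      == y @ X
vmapped_vmapped_dot = vmap(vmapped_dot, (None, 0), 1)   #  ([a,b], [c,a]) -> [b,c]  == X.T @ Y.T

# Same constructions as mv and mm above, so vmapped_vmapped_dot2(X, Y) == X @ Y.
vmapped_dot2 = vmap(vv, (0, None), 0)
vmapped_vmapped_dot2 = vmap(vmapped_dot2, (None, 1), 1)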
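The shape comments above can be spot-checked against explicit `jnp` expressions. A sketch, assuming deliberately non-square shapes (b=2, a=3, c=4) so that a swapped axis fails loudly instead of passing by accident on 3x3 inputs: `mv` should match `A @ x`, `mm` should match `A @ B`, `mv1` is a row-wise dot product, and `mv2` picks out the diagonal of a matrix product.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)   # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)        # ([b,a], [a]) -> [b]
mm = vmap(mv, (None, 1), 1)        # ([b,a], [a,c]) -> [b,c]
mv1 = vmap(vv, (0, 0), 0)          # ([b,a], [b,a]) -> [b]
mv2 = vmap(vv, (0, 1), 0)          # ([b,a], [a,b]) -> [b]

# Non-square inputs: b=2, a=3, c=4.
A = jnp.arange(6.0).reshape(2, 3)    # [b,a]
B = jnp.arange(12.0).reshape(3, 4)   # [a,c]
x = jnp.array([1.0, -1.0, 2.0])      # [a]

print(mv(A, x).shape, mm(A, B).shape)                   # (2,) (2, 4)
print(jnp.allclose(mv(A, x), A @ x))                    # True
print(jnp.allclose(mm(A, B), A @ B))                    # True
print(jnp.allclose(mv1(A, A), jnp.sum(A * A, axis=1)))  # True
print(jnp.allclose(mv2(A, A.T), jnp.diag(A @ A.T)))     # True
```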

# Basic Matrices

In [None]:
identityMatrix = jnp.array(((1,0,0),
                  (0,1,0),
                  (0,0,1)))
permuteABMatrix = jnp.array(((0,1,0),
                  (1,0,0),
                  (0,0,1)))
permuteBCMatrix = jnp.array(((1,0,0),
                  (0,0,1),
                  (0,1,0)))
oneRowVector = jnp.array((1,1,1))
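oneColumnVector = jnp.array(((1,),   # (1,) is a one-element tuple; a bare (1) is
                             (1,),   # just the int 1, which made this identical to
                             (1,)))  # oneRowVector. The shape is now (3, 1).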
columnMatrix = jnp.array(((1,0,0),
                  (2,0,0),
                  (3,0,0)))
rowMatrix = jnp.array(((1,4,9),
                  (0,0,0),
                  (0,0,0)))
columnMatrix2 = jnp.array(((0,0,1),
                  (0,0,2),
                  (0,0,3)))
rowMatrix2 = jnp.array(((0,0,0),
                  (0,0,0),
                  (1,4,9)))

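#expansionMatrix = map( lambda x: map( lambda y: y *2, x) *2 , identityMatrix)
# You can no longer do the above: in Python 3, map() returns a lazy iterator,
# so `map(...) * 2` raises a TypeError. (In Python 2 map() returned a list, and
# `list * 2` repeats the list rather than scaling its entries.)
# A jnp array is a single pytree leaf, so tree_map below just computes matrix * scalar.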
def right_multiply_matrix_by_scalar(matrix, scalar):
  return tree_util.tree_map(lambda x : x * scalar, matrix)
expansionMatrix = right_multiply_matrix_by_scalar(identityMatrix, 2)
reductionMatrix = right_multiply_matrix_by_scalar(identityMatrix, 0.5)
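A short sketch of why the commented-out `map` version stopped working and what `tree_map` is doing instead, using the identity matrix as nested tuples (the variable names here are illustrative): `map(...) * 2` fails on a Python 3 iterator, nested comprehensions are the eager replacement, and `tree_map` only sees individual entries when the matrix is a tuple tree rather than a single `jnp` array leaf.

```python
import jax.numpy as jnp
import jax.tree_util as tree_util

identity_tuples = ((1, 0, 0),
                   (0, 1, 0),
                   (0, 0, 1))

# Python 3's map() is a lazy iterator, so multiplying it by 2 is a TypeError.
try:
  map(lambda y: y * 2, identity_tuples[0]) * 2
  map_times_two_ok = True
except TypeError:
  map_times_two_ok = False

# Eager replacement for the nested-map idea: scale every entry by 2.
expansion_lists = [[y * 2 for y in row] for row in identity_tuples]

# On a tuple tree, tree_map reaches all 9 scalar entries...
expansion_tuples = tree_util.tree_map(lambda y: y * 2, identity_tuples)
print(len(tree_util.tree_leaves(identity_tuples)))   # 9

# ...but a jnp array is a single leaf, so tree_map applies the lambda once to
# the whole array, which is the same as plain broadcasting.
identity_array = jnp.array(identity_tuples)
print(len(tree_util.tree_leaves(identity_array)))    # 1
print(jnp.array_equal(tree_util.tree_map(lambda y: y * 2, identity_array),
                      identity_array * 2))           # True
```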



In [None]:

print("Identity Matrix I:\n" , identityMatrix)
print("Permutation Matrix P_AB:\n" , permuteABMatrix)
print("Permutation Matrix P_BC:\n" , permuteBCMatrix)
print("RowVector v1:\n" , oneRowVector)
print("Column Vector vT1:\n" , oneColumnVector)
print("Expansion Matrix E:\n" , expansionMatrix)
print("Reduction Matrix R:\n" , reductionMatrix)
print("Column Matrix C:\n" , columnMatrix)
print("Row Matrix R:\n" , rowMatrix)
print("Column Matrix C2:\n" , columnMatrix2)
print("Row Matrix R2:\n" , rowMatrix2)




# Applications of various "Products"


In [None]:
"""Without a pretty printer that prints several matrices side by side in shared
console real estate, following the output below takes more cognitive overhead
than is sensible for reasoning. Math notation is handy for a reason.
I will probably write a narrowly scoped matrix pretty printer. Something taking
3x3, or at most maybe 4x4, matrices ought to be reasonably straightforward and
presumably has been written dozens of times. Maybe I can find one if I look
around a bit."""
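As a starting point for the narrowly scoped pretty printer described above, here is a minimal sketch. `side_by_side` is a hypothetical helper written for this notebook, not part of JAX or NumPy; it converts each array with `np.asarray`, right-aligns the entries within each matrix, and joins the resulting text blocks line by line. It only handles 1-D and 2-D arrays.

```python
import numpy as np
import jax.numpy as jnp

def side_by_side(*labeled, gap=4):
  """Render (label, array) pairs as text blocks placed next to each other.

  Hypothetical helper: 1-D arrays are shown as a single row, 2-D arrays row
  by row; higher-rank arrays are not handled.
  """
  blocks = []
  for label, m in labeled:
    rows = np.atleast_2d(np.asarray(m)).tolist()
    cells = [[str(v) for v in row] for row in rows]
    width = max(len(c) for row in cells for c in row)
    lines = [label] + ["[" + " ".join(c.rjust(width) for c in row) + "]"
                       for row in cells]
    block_width = max(len(line) for line in lines)
    blocks.append([line.ljust(block_width) for line in lines])
  height = max(len(b) for b in blocks)
  for b in blocks:
    # Pad shorter blocks with blank lines of the block's own width.
    b.extend([" " * len(b[0])] * (height - len(b)))
  return "\n".join((" " * gap).join(parts).rstrip() for parts in zip(*blocks))

C = jnp.array(((1, 0, 0), (2, 0, 0), (3, 0, 0)))
R = jnp.array(((1, 4, 9), (0, 0, 0), (0, 0, 0)))
print(side_by_side(("C", C), ("R", R), ("C @ R", C @ R)))
```

Assuming the earlier cells have run, something like `print(side_by_side(("E", expansionMatrix), ("R", reductionMatrix), ("mm(E, R)", mm(expansionMatrix, reductionMatrix))))` would put a product next to its factors.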

In [None]:

print("mm product E * R:\n",
    mm(expansionMatrix,reductionMatrix))
print("mm product C * R:\n",
    mm(columnMatrix,rowMatrix))
print("mm product R * C:\n",
    mm(rowMatrix,columnMatrix))
print("mm product C2 * R2:\n",
    mm(columnMatrix2,rowMatrix2))
print("mm product R2 * C2:\n",
    mm(rowMatrix2,columnMatrix2))
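One way to read the C * R and R * C results, sketched assuming `mm` computes the ordinary product `A @ B` (as its shape comment implies): C has a single nonzero column and R a single nonzero row, so C @ R is the outer product of that column with that row, while R @ C collapses to the single dot product 1·1 + 4·2 + 9·3 = 36 in the top-left corner.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)
mv = vmap(vv, (0, None), 0)
mm = vmap(mv, (None, 1), 1)

C = jnp.array(((1, 0, 0), (2, 0, 0), (3, 0, 0)))
R = jnp.array(((1, 4, 9), (0, 0, 0), (0, 0, 0)))

# Column times row: the outer product of C's first column and R's first row.
print(jnp.array_equal(mm(C, R), jnp.outer(C[:, 0], R[0, :])))  # True
# Row times column: one dot product, 1*1 + 4*2 + 9*3 = 36, in the corner.
print(mm(R, C))
```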



In [None]:
print("vm product rV1 * I:\n",
    vm(oneRowVector,identityMatrix))
print("vm product rV1 * E:\n",
    vm(oneRowVector,expansionMatrix))

print("vm2 product rV1 * I:\n",
    vm2(oneRowVector,identityMatrix))
print("vm2 product rV1 * E:\n",
    vm2(oneRowVector,expansionMatrix))
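Both `identityMatrix` and `expansionMatrix` are symmetric, so the four products above cannot tell `vm` and `vm2` apart. A sketch with a non-symmetric matrix (an `arange` matrix chosen here for illustration) shows the difference: `vm` maps over the rows of its second argument, so `vm(x, Y)` is `Y @ x`, while `vm2` maps over the columns, so `vm2(x, Y)` is the row-vector product `x @ Y`.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)
vm = vmap(vv, (None, 0), 0)    # maps over rows of Y:    vm(x, Y)  == Y @ x
vm2 = vmap(vv, (None, 1), 0)   # maps over columns of Y: vm2(x, Y) == x @ Y

x = jnp.array([1, 1, 1])
Y = jnp.arange(9).reshape(3, 3)   # rows (0 1 2), (3 4 5), (6 7 8): not symmetric

print(vm(x, Y))    # row sums 3, 12, 21
print(vm2(x, Y))   # column sums 9, 12, 15
```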

In [None]:
print("mm3 product E * R:\n",
    mm3(expansionMatrix,reductionMatrix))
print("mm3 product C * R:\n",
    mm3(columnMatrix,rowMatrix))
print("mm3 product R * C:\n",
    mm3(rowMatrix,columnMatrix))
print("mm3 product C2 * R2:\n",
    mm3(columnMatrix2,rowMatrix2))
print("mm3 product R2 * C2:\n",
    mm3(rowMatrix2,columnMatrix2))
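A sketch checking the reversed convention with non-square shapes (a=3, c=2, b=4, chosen here for illustration): `mm3(X, Y)` matches `Y @ X`, so `mm3(columnMatrix, rowMatrix)` above is `rowMatrix @ columnMatrix`.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)
vm = vmap(vv, (None, 0), 0)    # vm(x, Y) == Y @ x
mm3 = vmap(vm, (1, None), 1)   # column k of the result is Y @ X[:, k]

# Non-square shapes make the argument order visible.
X = jnp.arange(6.0).reshape(3, 2)    # [a, c]
Y = jnp.arange(12.0).reshape(4, 3)   # [b, a]

print(mm3(X, Y).shape)                  # (4, 2)
print(jnp.allclose(mm3(X, Y), Y @ X))   # True
```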

In [None]:
print("vmapped_vmapped_dot product E * R:\n",
    vmapped_vmapped_dot(expansionMatrix,reductionMatrix))
print("vmapped_vmapped_dot product C * R:\n",
    vmapped_vmapped_dot(columnMatrix,rowMatrix))
print("vmapped_vmapped_dot product R * C:\n",
    vmapped_vmapped_dot(rowMatrix,columnMatrix))
print("vmapped_vmapped_dot product C2 * R2:\n",
    vmapped_vmapped_dot(columnMatrix2,rowMatrix2))
print("vmapped_vmapped_dot product R2 * C2:\n",
    vmapped_vmapped_dot(rowMatrix2,columnMatrix2))
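A sketch of what the swapped `in_axes` produce, again with non-square shapes chosen for illustration: `vmapped_vmapped_dot(X, Y)` equals `X.T @ Y.T`, i.e. `(Y @ X).T`. For the square examples above that means the "C * R" line prints the transpose of R @ C, and the "R * C" line prints the transpose of C @ R.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)
vmapped_dot = vmap(vv, (1, None), 0)                    # (X, y) -> y @ X
vmapped_vmapped_dot = vmap(vmapped_dot, (None, 0), 1)   # column i is Y[i] @ X

X = jnp.arange(6.0).reshape(3, 2)    # [a, b]
Y = jnp.arange(12.0).reshape(4, 3)   # [c, a]

out = vmapped_vmapped_dot(X, Y)
print(out.shape)                       # (2, 4)
print(jnp.allclose(out, X.T @ Y.T))    # True
print(jnp.allclose(out, (Y @ X).T))    # True
```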

In [None]:
print("mv1 product E * E:\n",
          (mv1(expansionMatrix,expansionMatrix)))
print("mv1 product C * R:\n",
    mv1(columnMatrix,rowMatrix))
print("mv1 product R * C:\n",
    mv1(rowMatrix,columnMatrix))
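`mv1` is a batched, row-by-row dot product rather than a matrix product. A sketch of three equivalent spellings, using the notebook's C and R values: an elementwise product summed along axis 1, an `einsum`, and the diagonal of `X @ Y.T`.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)
mv1 = vmap(vv, (0, 0), 0)   # out[i] = X[i] . Y[i]

C = jnp.array(((1, 0, 0), (2, 0, 0), (3, 0, 0)))
R = jnp.array(((1, 4, 9), (0, 0, 0), (0, 0, 0)))

print(jnp.array_equal(mv1(C, R), jnp.sum(C * R, axis=1)))         # True
print(jnp.array_equal(mv1(C, R), jnp.einsum('ij,ij->i', C, R)))   # True
print(jnp.array_equal(mv1(C, R), jnp.diag(C @ R.T)))              # True
print(mv1(C, R))  # [1 0 0]: only row 0 overlaps, 1*1 + 0*4 + 0*9 = 1
```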

# Actual Proper "Tensors"


In [None]:
"""(Even though there's nothing here about covariant or contravariant indices, as
far as I can tell, which I'm pretty sure is important for proper tensors.)

Also, I'm pretty sure the tensor-oriented use of "rank" to mean the number of
indices (what NumPy calls ndim) is where the habit of saying "rank" for
"number of dimensions" comes from, as opposed to what modern mathematicians mean
by the rank of a matrix: the dimension of its column space, i.e. the effective
dimension of a matrix that may nominally be larger."""
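To make the two meanings concrete, a small sketch: `ndim` counts axes, while `jnp.linalg.matrix_rank` computes the linear-algebra rank. The notebook's `columnMatrix` (as floats here) has two axes but rank 1, because all of its columns are multiples of the first.

```python
import jax.numpy as jnp

column_matrix = jnp.array(((1, 0, 0),
                           (2, 0, 0),
                           (3, 0, 0)), dtype=jnp.float32)
identity = jnp.eye(3)

# "Rank" in the array sense: the number of indices (axes), NumPy's ndim.
print(column_matrix.ndim, identity.ndim)       # 2 2

# "Rank" in the linear-algebra sense: the dimension of the column space.
print(jnp.linalg.matrix_rank(column_matrix))   # 1
print(jnp.linalg.matrix_rank(identity))        # 3
```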

In [None]:
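# Note: despite the name, this array has shape (3, 3, 3, 3), i.e. four indices.
# It is the "diagonal" delta tensor: entry [i, j, k, l] is 1 iff i == j == k == l.
rank3DiagonalUnityVectorTensor = jnp.array(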
                  (
                    (
                    ((1,0,0),
                    (0,0,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ),
                    (
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,1,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ),
                    (
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,0,0),
                    (0,0,0))
                    ,
                    ((0,0,0),
                    (0,0,0),
                    (0,0,1))
                    )
))
print("tp1/\"mm2\" product sillyTensor * sillyTensor:\n",
    tp1(rank3DiagonalUnityVectorTensor,rank3DiagonalUnityVectorTensor))

tp1/"mm2" product sillyTensor * sillyTensor:
 [[1 0 0]
 [0 1 0]
 [0 0 1]]
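The printed identity can be checked against an explicit contraction. A sketch, following `tp1`'s shape comment: for 3-D inputs, `tp1(X, Y)[k, i] = Σ_j X[i, k, j] · Y[j, k, i]`, which is `jnp.einsum('ikj,jki->ki', X, Y)`. The tensor above actually has four indices; `tp1` still accepts it because `jnp.vdot` flattens its arguments, so the innermost step sums over the last two indices together, giving `jnp.einsum('ikjl,jkil->ki', X, Y)`. The 3-D test arrays and the `.at[...]` construction of the delta tensor are choices made here for illustration.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)   # vdot flattens both arguments
mv2 = vmap(vv, (0, 1), 0)
tp1 = vmap(mv2, (1, 1), 0)

# 3-D case: ([b,c,a], [a,c,b]) -> [c,b], with distinct sizes b=2, c=3, a=4.
X3 = jnp.arange(24.0).reshape(2, 3, 4)
Y3 = jnp.arange(24.0).reshape(4, 3, 2)
print(jnp.allclose(tp1(X3, Y3), jnp.einsum('ikj,jki->ki', X3, Y3)))   # True

# 4-D case: the all-indices-equal delta tensor from the cell above.
idx = jnp.arange(3)
delta = jnp.zeros((3, 3, 3, 3)).at[idx, idx, idx, idx].set(1.0)
print(jnp.allclose(tp1(delta, delta),
                   jnp.einsum('ikjl,jkil->ki', delta, delta)))         # True
print(tp1(delta, delta))   # the 3x3 identity
```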


# Irrelevance

In [None]:
## How deep does section folding work
# Not deep enough :D
"""
Vmap__Source = "https://github.com/google/jax/blob/main/jax/_src/api.py#L1061"

tree_flatten__Source = '''There are in fact numerous definitions of tree_flatten. Looking at imports it is presumably

https://github.com/google/jax/blob/main/jax/_src/tree_util.py#L71

https://github.com/google/jax/blob/main/jax/_src/lib/__init__.py#L93
Which is a layer of indirection around
jaxlib.xla_client._xla.pytree
which is somewhere in the C++ code

4.29.2024 Looking at help(jaxlib.xla_client) I see that it is from a file
installed as /usr/local/lib/python3.10/dist-packages/jaxlib/xla_client.py
which is in fact presumably some sort of file autogenerated during the build
process of jaxlib. There is some snarky commentary presumably from a googler at
https://stackoverflow.com/questions/66257662/what-exactly-is-xla-client-in-the-jax-library
which points to some outdated bazel files from the tensorflow project
(i.e. said files no longer exist in the tensorflow project in the locations
pointed to snarkily)

Looking at {JAX_DIR}/jaxlib/BUILD I see

...
symlink_files(
    name = "xla_client",
    srcs = ["@xla//xla/python:xla_client"],
    dst = ".",
    flatten = True,
)

symlink_files(
    name = "xla_extension",
    srcs = if_windows(
        ["@xla//xla/python:xla_extension.pyd"],
        ["@xla//xla/python:xla_extension.so"],
    ),
    dst = ".",
    flatten = True,
)
...
which looks like the kind of thing the stackoverflow commenter is talking about
Looking around the xla source for xla_client I first stumbled upon pytree.cc
which looks like it has the code I've been looking for at
https://github.com/openxla/xla/blob/main/xla/python/pytree.cc#L1264
https://github.com/openxla/xla/blob/main/xla/python/pytree.cc#L296
'''
lu.wrap_init__Source = "https://github.com/google/jax/blob/main/jax/_src/linear_util.py#L262"

batching.flatten_fun_for_vmap__Source = https://github.com/google/jax/blob/main/jax/_src/interpreters/batching.py#L304

flatten_axes__Source:https://github.com/google/jax/blob/main/jax/_src/api_util.py#L404

tree_unflatten__Source: Similar story to tree_flatten

https://github.com/google/jax/blob/main/jax/_src/tree_util.py#L107

cast__Source:(imported from python's typing module)
"""
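A small sketch of the pieces those source links point at, using only the public `jax.tree_util` API rather than the internal helpers: `tree_flatten` splits a pytree into a flat list of leaves plus a treedef, `tree_unflatten` rebuilds it, and `vmap`'s `in_axes` can itself be a tree that is a prefix of the argument structure. The `(matrix, vector)` pair and its values are illustrative.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as tree_util

# One (matrix, vector) pair; the values are arbitrary.
pair = (jnp.arange(6.0).reshape(2, 3), jnp.array([1.0, 0.0, -1.0]))

# Flatten into a list of leaves plus a treedef describing the structure...
leaves, treedef = tree_util.tree_flatten(pair)
print(len(leaves), treedef.num_leaves)   # 2 2
# ...and rebuild the original tuple from them.
rebuilt = tree_util.tree_unflatten(treedef, leaves)

# in_axes mirrors the tuple of positional arguments: there is one argument,
# the pair, so in_axes is a 1-tuple whose entry (0, None) says "map axis 0 of
# the matrix, broadcast the vector".
row_dots = jax.vmap(lambda p: jnp.vdot(p[0], p[1]), in_axes=((0, None),))(pair)
print(jnp.allclose(row_dots, pair[0] @ pair[1]))   # True
```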

In [None]:
"""Reading the docstring for vmap again now, an interesting, not necessarily
achievable exercise might be to write some sort of meaningful vector triple
product. The cross product isn't it, though I was thinking that.
There's the "scalar" triple product and the "vector" triple product.
https://en.wikipedia.org/wiki/Triple_product
I think theoretically you might be able to write the vector triple product
with vmap.
"""
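One possible reading of that idea, as a sketch: the vector triple product satisfies a × (b × c) = b (a · c) − c (a · b), so it can be written with the same `vdot` building block, and `vmap` then batches it over many triples at once. The scalar triple product a · (b × c) equals the determinant of the matrix with rows a, b, c, which gives an independent check. The random inputs, batch size, and tolerances are choices made here for illustration.

```python
import jax
import jax.numpy as jnp
from jax import vmap

def vector_triple(a, b, c):
  # a x (b x c) via the "BAC-CAB" identity, using only dot products.
  return b * jnp.vdot(a, c) - c * jnp.vdot(a, b)

def scalar_triple(a, b, c):
  # a . (b x c), which also equals det of the matrix with rows a, b, c.
  return jnp.vdot(a, jnp.cross(b, c))

batched_vector_triple = vmap(vector_triple, (0, 0, 0), 0)   # ([n,3] x 3) -> [n,3]
batched_scalar_triple = vmap(scalar_triple, (0, 0, 0), 0)   # ([n,3] x 3) -> [n]

key_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(key_a, (5, 3))
B = jax.random.normal(key_b, (5, 3))
C = jax.random.normal(key_c, (5, 3))

print(jnp.allclose(batched_vector_triple(A, B, C),
                   jnp.cross(A, jnp.cross(B, C)), atol=1e-4))             # True
print(jnp.allclose(batched_scalar_triple(A, B, C),
                   jnp.linalg.det(jnp.stack([A, B, C], axis=1)), atol=1e-4))  # True
```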


# Irrelevance

In [None]:
# Doing this isn't actually needed, as Colab instances come with jax installed
# (unless you want to browse the source in Colab). GitHub's web code browser
# has gotten good enough to read the source there, as long as you have
# something like Colab to play around in.
projects = [
    ("jax","https://github.com/google/jax")]
JAX_DIR = f"/content/{projects[0][0]}"
for project, repo in projects:
  !rm -rf ./{project}
  !git clone {repo}

In [None]:
# Doing this isn't actually needed, as Colab instances come with jax installed.
# Note: jax at main may require a newer jaxlib than the preinstalled one.
!pip install jaxlib
!cd {JAX_DIR} && pip install -e .
