<a href="https://colab.research.google.com/github/Leila828/Learning_JAX_for_deepLearning/blob/main/jax_Linear_Algebra.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Vectors**

Since we are assuming you are fairly familiar with NumPy, this chapter is by no means a thorough coverage of the topic and will just serve as a quick refresher. We will practice using JAX while discovering the equivalence between normal NumPy and the JAX version of NumPy’s syntax.

Note: Throughout, we use np for standard NumPy and jnp for the JAX version of NumPy in our code.

**Inner product**
The inner product of two vectors can be calculated by any of the three syntaxes (dot(), inner() and @), as shown below:

In [None]:
import jax.numpy as jnp

a = jnp.arange(1,10)
b = jnp.arange(11,20)

print(jnp.dot(a,b))
print(jnp.inner(a,b))
print(a@b)
print(b@a)
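
As a quick sanity check, the inner product is just the sum of elementwise products. The minimal sketch below reuses the vectors from the cell above and compares a manual computation (the name `manual` is ours, not part of JAX) against `jnp.dot()`:

```python
import jax.numpy as jnp

a = jnp.arange(1, 10)
b = jnp.arange(11, 20)

# Inner product written out as a sum of elementwise products
manual = jnp.sum(a * b)

print(manual)
print(jnp.dot(a, b))  # should agree with the manual sum
```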

**Linear functions**
A linear function can be represented as:

$$f(x) = \alpha x_1 + \beta x_2 + \gamma x_3 + \dots$$

We can compose it as a lambda function too:

In [None]:
a= 1.2
b = 2.4

FuncA = lambda x : a*x[0]+b*x[1]

x = jnp.arange(1,3)
print(FuncA(x))
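
A linear function of a vector is an inner product with a fixed coefficient vector, so the same function can be written with `@`. The sketch below is illustrative only: `coeffs`, `linear_f`, and the scalars `s` and `t` are names we introduce here. It checks the superposition property f(s·x + t·y) = s·f(x) + t·f(y), which every linear function satisfies.

```python
import jax.numpy as jnp

# Coefficient vector (illustrative values, matching a and b in the cell above)
coeffs = jnp.array([1.2, 2.4])

# A linear function is an inner product with a fixed vector
linear_f = lambda x: coeffs @ x

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, -1.0])
s, t = 0.5, 2.0

lhs = linear_f(s * x + t * y)
rhs = s * linear_f(x) + t * linear_f(y)

print(lhs, rhs)
print(jnp.allclose(lhs, rhs))  # superposition holds up to rounding
```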

**Taylor approximation**
Taylor approximation lies at the core of major optimization algorithms in machine learning. We can first-order approximate the value of a function near a point z as:

$$\hat{f}(x) = f(z) + \nabla f(z)^T (x - z)$$

Having already seen the efficiency of grad(), we can easily implement the Taylor approximation. For example, if at z = (1,-1) we want to approximate the value of the following function:

$$f(x) = 2x_1 + 3x_2$$


In [None]:
from jax import grad

FuncA = lambda x : 2*x[0]+3*x[1]
Gradient = lambda x : grad(FuncA)(x)
TaylorApprox = lambda x:FuncA(z)+Gradient(z)@(x-z) 

z = jnp.array([1.0,-1.0])
print("Actual function value at z is:",FuncA(z))
print("Taylor Approximation at same value is:",TaylorApprox(z))
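
Because the function above is linear, its first-order Taylor approximation is exact everywhere. To see the approximation error, the sketch below uses a nonlinear function `g` that we made up purely for illustration, and compares the error at a point close to z with the error at a point further away.

```python
import jax.numpy as jnp
from jax import grad

# Hypothetical nonlinear function, chosen only for illustration
g = lambda x: x[0]**2 + x[0]*x[1]

z = jnp.array([1.0, -1.0])

# First-order Taylor approximation of g around z
taylor_g = lambda x: g(z) + grad(g)(z) @ (x - z)

near = z + jnp.array([0.01, 0.01])
far = z + jnp.array([1.0, 1.0])

err_near = jnp.abs(g(near) - taylor_g(near))
err_far = jnp.abs(g(far) - taylor_g(far))

print("Error near z:", err_near)
print("Error far from z:", err_far)
```

The approximation is only trustworthy in a small neighborhood of z, which is why gradient-based optimizers take small steps.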

**Norms**
The Euclidean norm of any vector is defined as:

$$\|x\| = \sqrt{x_1^2 + x_2^2 + \dots + x_n^2}$$

Norms can be calculated using the linalg submodule (jnp.linalg.norm()). We can also use norms to find the Euclidean distance between two vectors:

In [None]:
x = jnp.arange(1,100)
y = jnp.arange(11,110)
x_norm = jnp.linalg.norm(x)

print("Norm of x is: ",x_norm)
print("Distance between x and y vectors is:",jnp.linalg.norm(x-y))
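
The definition above says the Euclidean norm is the square root of the inner product of a vector with itself. A minimal sketch, using a small vector of our own choosing, confirms that this matches `jnp.linalg.norm()`:

```python
import jax.numpy as jnp

x = jnp.array([3.0, 4.0])

norm_from_inner = jnp.sqrt(x @ x)   # sqrt(3^2 + 4^2)
norm_builtin = jnp.linalg.norm(x)

print(norm_from_inner, norm_builtin)
```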

**Cosine similarity**
A common way of finding the similarity between two vectors in machine learning, especially in NLP applications, is to calculate the cosine similarity. Using the definition of the dot product:

$$x \cdot y = x^T y = \|x\| \, \|y\| \cos\theta$$

We can then calculate the cosine similarity:

$$\cos\theta = \frac{x^T y}{\|x\| \, \|y\|}$$


In [None]:
x = jnp.arange(1,100)
y = jnp.arange(11,110)

cosine = lambda x,y: (x@y)/(jnp.linalg.norm(x)*jnp.linalg.norm(y))
print("Cosine similarity between x and y is:",cosine(x,y))
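
To build some intuition for the values cosine similarity can take, the sketch below evaluates it on a few hand-picked vectors (our own illustrative choices) and recovers an angle in degrees with `jnp.arccos()`:

```python
import jax.numpy as jnp

cosine = lambda x, y: (x @ y) / (jnp.linalg.norm(x) * jnp.linalg.norm(y))

u = jnp.array([1.0, 0.0])

same_dir = cosine(u, jnp.array([2.0, 0.0]))    # same direction
orthogonal = cosine(u, jnp.array([0.0, 3.0]))  # perpendicular
opposite = cosine(u, jnp.array([-1.0, 0.0]))   # opposite direction
print(same_dir, orthogonal, opposite)

# Recover the angle; clip guards against rounding slightly outside [-1, 1]
cos_val = jnp.clip(cosine(u, jnp.array([1.0, 1.0])), -1.0, 1.0)
angle_deg = jnp.degrees(jnp.arccos(cos_val))
print(angle_deg)
```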

**Standard deviation**
The standard deviation of a vector, where $n$ is the size of the vector, is defined as:

$$\operatorname{std}(x) = \frac{\|x - \bar{x}\|}{\sqrt{n}}$$

We can directly calculate this using jnp.std(). The following example shows the use of both direct std() and implementation using the above formula:

In [None]:
x = jnp.arange(20,61)
y = jnp.arange(1,101)

stdv = lambda x: jnp.linalg.norm(x-jnp.average(x))/(jnp.sqrt(x.size))

print("SD of x using our function is: ",stdv(x))
print("SD of y using our function is:",stdv(y))
#Answers will be same
print("Actual SD of x is: ",jnp.std(x))
print("Actual SD of y is:",jnp.std(y))
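
The formula above divides by n, which corresponds to the default `ddof=0` of `jnp.std()`. As a hedged aside, the sketch below contrasts this with `ddof=1`, which divides by n - 1 instead; the variable names are ours.

```python
import jax.numpy as jnp

x = jnp.arange(20, 61)

population_sd = jnp.std(x)       # ddof=0: divides by n, as in the formula above
sample_sd = jnp.std(x, ddof=1)   # divides by n - 1 instead
manual_sd = jnp.linalg.norm(x - jnp.mean(x)) / jnp.sqrt(x.size)

print(population_sd, manual_sd)  # these two should agree
print(sample_sd)                 # slightly larger
```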

# **Matrices**

Matrices are at the core of almost any data application. Even vectors can be treated as a matrix with either a row or column dimension of 1.

Note: We will follow the common linear algebra convention of using capital letters for matrices in our code as well.

In [None]:
a = jnp.arange(5)
b = 2*jnp.arange(5)
c = -1*jnp.arange(5)

A = jnp.array((a,b,c)) #Concatenation of vectors to make a matrix

print(A)
print(A.shape)

**Slicing**
We can make submatrices by using the slicing (:) notation.

Note: Python uses 0-based indexing, as does JAX.

For example, as shown in line 10 below, a submatrix containing the first 2 rows and first 3 columns of the above matrix will be:

B = A[0:2,0:3]

**Reshaping functions**
Reshaping functions are commonly used in several applications, especially computer vision.

The rule behind any reshaping function is simple: if the input and output matrices have dimensions $m \times n$ and $j \times k$ respectively, then:

$$m \times n = j \times k$$

We can reshape a matrix A of dimensions (m,n) into B as:

B = A.reshape((j,k))

Note: reshape() is not a linear algebra operation; it is just a useful feature of NumPy and JAX programming.
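
The element-count rule can be checked directly. In the sketch below (shapes chosen by us for illustration), a 3x4 matrix is reshaped to 2x6, `-1` is used to let JAX infer one dimension, and a mismatched shape is shown to fail:

```python
import jax.numpy as jnp

A = jnp.arange(12).reshape((3, 4))   # 3 x 4 = 12 elements
B = A.reshape((2, 6))                # valid: 2 x 6 = 12
C = A.reshape((-1, 3))               # -1 asks JAX to infer the dimension (here 4)

print(B.shape, C.shape)

# A target shape with a different element count is rejected
try:
    A.reshape((5, 2))
except (TypeError, ValueError) as err:
    print("Cannot reshape:", type(err).__name__)
```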

**Stacking**
We can directly stack a collection of vectors:

Vertically (for row vectors): Using vstack() or r_[]
Horizontally (for column vectors): Using hstack() or c_[]

In [None]:
a = jnp.arange(5)
b = 2*jnp.arange(5)
c = -1*jnp.arange(5)

A = jnp.array((a,b,c)) 

print(A)
print(A.shape)

B = A[0:2,0:3]

print(B)

C = A[1:3,1:4]
print(C)

D = jnp.vstack([B,C])
print(D)

E = jnp.r_[B,C]
print(E) #Same as D
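
The cell above only demonstrates vertical stacking. A minimal sketch of the horizontal case follows, assuming we want the vectors to become columns; `a_col`, `b_col`, `F`, and `G` are names introduced here for illustration.

```python
import jax.numpy as jnp

a = jnp.arange(5)
b = 2 * jnp.arange(5)

# Treat a and b as column vectors of shape (5, 1)
a_col = a.reshape((-1, 1))
b_col = b.reshape((-1, 1))

F = jnp.hstack([a_col, b_col])   # columns placed side by side
G = jnp.c_[a, b]                 # c_ turns the 1-D inputs into columns

print(F)
print(F.shape, G.shape)
print(jnp.array_equal(F, G))
```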

**Multiplication**
Matrices are multiplied by the usual linear algebra rules using either @ or matmul():

In [None]:
a = jnp.arange(5)
b = 2*jnp.arange(5)
c = -1*jnp.arange(5)

A = jnp.array((a,b,c)) 

print("A and its shape")
print(A)
print(A.shape)

B = A@A.T #Multiplying A with its transpose

print("B = A@A.T and its shape")
print(B)
print(B.shape)

C= jnp.matmul(A,A.T) #matmul() yields the same answer.

print("C will have same contents as B and its shape")
print(C)
print(C.shape)
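
A common source of confusion is that `*` is elementwise multiplication, not the matrix product. The short sketch below, with two small matrices of our own choosing, shows the difference:

```python
import jax.numpy as jnp

P = jnp.array([[1, 2], [3, 4]])
Q = jnp.array([[0, 1], [1, 0]])

matrix_product = P @ Q   # rows of P combined with columns of Q
elementwise = P * Q      # entry-by-entry product

print(matrix_product)
print(elementwise)
```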

**Random matrices**
There will often be scenarios where we need to sample a matrix stochastically. This will come in handy in many real-life applications.

One of the core features of JAX is its PRNG. To sample a matrix from a given distribution, we need to specify the PRNG key first, followed by its shape. For example, we can sample a 4x4 matrix as:

In [None]:
import jax
A = jax.random.uniform(jax.random.PRNGKey(0),(4,4)) #PRNG key first, then the shape

Note: The next chapter covers both JAX’s PRNG and probability distributions in detail.
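
Because the key fully determines the sample, reusing a key reproduces the same matrix. A minimal sketch follows; the seed 42 is an arbitrary choice of ours.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)  # arbitrary seed, chosen for illustration

A1 = jax.random.uniform(key, (4, 4))
A2 = jax.random.uniform(key, (4, 4))   # same key, so the same samples

print(jnp.array_equal(A1, A2))

# uniform() samples from the interval [0, 1) by default
print(A1.min() >= 0.0, A1.max() < 1.0)
```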

**Norm**
The most common matrix norm, known as the Frobenius norm, is defined as:

$$\|A\| = \sqrt{\sum_{i=1}^m \sum_{j=1}^n A_{ij}^2}$$

We can use the same linalg.norm() for the Frobenius norm as well.

In [None]:
from jax import random
from jax.numpy import linalg

a = jnp.arange(5)
b = 2*jnp.arange(5)
c = -1*jnp.arange(5)

A = jnp.array((a,b,c))

print("A and its shape")
print(A)
print(A.shape)

A_norm = linalg.norm(A)
print("A's norm is: ",A_norm)

key = random.PRNGKey(0) #Don't worry about it. Next chapter covers it in depth
B = random.normal(key,(100,100)) 

B_norm = linalg.norm(B)
print("B's norm is:", B_norm)

Note: While it’s correct to refer to the Frobenius norm as the Euclidean norm, the L2 norm of a matrix is instead defined as:

$$\|X\|_2 = \sigma_{max}(X) = \sqrt{\lambda_{max}(X^T X)}$$

It is used less frequently than the Frobenius norm.
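
The two norms can be compared numerically. The sketch below, on a small matrix we picked for illustration, computes the L2 norm three ways (via `ord=2`, via the largest singular value, and via the largest eigenvalue of X^T X) alongside the Frobenius norm:

```python
import jax.numpy as jnp
from jax.numpy import linalg

X = jnp.array([[3.0, 0.0], [4.0, 5.0]])

fro = linalg.norm(X)                                 # Frobenius norm (the default)
l2 = linalg.norm(X, ord=2)                           # L2 norm of the matrix
sigma_max = linalg.svd(X, compute_uv=False).max()    # largest singular value
lambda_max = linalg.eigvalsh(X.T @ X).max()          # largest eigenvalue of X^T X

print("Frobenius:", fro)
print("L2:", l2, sigma_max, jnp.sqrt(lambda_max))
```

The three L2 values should agree up to rounding, and the L2 norm never exceeds the Frobenius norm.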

# **More on Matrices**

While the world of matrices and their implementation is more than even a dedicated course can cover, we’ll try to do the subject some justice by concluding with another lesson covering a few important concepts.

**Determinant**
Like the norm, the determinant is a scalar value characterizing a square matrix. It is quite useful in defining some properties of the matrix.

For example, a matrix is singular if (and only if) its determinant is zero.

We can calculate it as:

In [None]:
A = jnp.array([[1.0, 2.0],[3.0, 4.0]]) #det() requires a square matrix
a = linalg.det(A)

**Inverse**
The inverse of a matrix is defined as:

$$A^{-1} = \frac{\operatorname{adj}(A)}{|A|}$$

The linalg sublibrary provides the function for inverse calculation as:

A_inv = linalg.inv(A)

Note: linalg.pinv() can be used to calculate a pseudo-inverse; it doesn’t support integer-typed arrays.

**Power**
Similarly, we can calculate the power of A using matrix_power() as:

A_square = linalg.matrix_power(A,2)


In [None]:
A = jnp.array([[1, 2, 3],[ 4, 5, 6],[7, 8, 9]])
print(A)

B = linalg.matrix_power(A,-1)
C = linalg.inv(A)
d = linalg.det(A)

print(B)
print(C) # will have same answer as B
print("Matrix's determinant is: ",d)

A is singular, so the outputs above should not be a surprise. Below is a non-singular matrix with its square and inverse respectively:

In [None]:
A = jnp.array([[1, -1, 3],[ 3, 5, 1],[7, 0, 9]])
print(A)

B = linalg.matrix_power(A,2)
C = linalg.inv(A)
D = linalg.inv(C) #Allowing rounding-off, D=A
e = linalg.det(A)

print("Matrix's determinant is: ",e)

print(B)
print(C)
print(D)

**Eigenvectors and eigenvalues**
If we recall, an eigenvector v of a square matrix A is defined as:

$$Av = \lambda v$$

where λ is the corresponding eigenvalue.

lambda_values, eigen_vect = linalg.eig(A)

Compared to normal NumPy, the JAX version treats every value as complex, so don’t be surprised by the j we observe in the output values.

It may be tempting to use “lambda” as a variable identifier. Don’t do this since lambda is a keyword. Instead, we can use λ itself as an identifier.



In [None]:
A = jnp.array([[2,1,5],[3,0,4],[2,1,-2]])

λ, eigen_vectors = linalg.eig(A)

print("The λ values are:",λ)
print("And Eigen vectors (as a matrix) are:",eigen_vectors)
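
We can verify the defining relation Av = λv numerically. In the sketch below, column i of the returned eigenvector matrix pairs with eigenvalue i, so broadcasting scales each column by its eigenvalue; the names `eig_values`, `eig_vectors`, and `residual` are ours.

```python
import jax.numpy as jnp
from jax.numpy import linalg

A = jnp.array([[2.0, 1.0, 5.0], [3.0, 0.0, 4.0], [2.0, 1.0, -2.0]])
eig_values, eig_vectors = linalg.eig(A)

# A @ V should equal V with each column scaled by its eigenvalue
residual = A @ eig_vectors - eig_vectors * eig_values

print(jnp.max(jnp.abs(residual)))  # close to zero, up to rounding
```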

**Definiteness**
Eigenvalues of a matrix introduce another related concept.

We call a square matrix a positive definite matrix if all the eigenvalues are positive while a positive semidefinite matrix is one having all its eigenvalues either zero or positive.

We can extend this concept to negative definite and semidefinite matrices as well.

If a matrix has both positive and negative eigenvalues, it’s called an indefinite matrix. These concepts are quite helpful, as we will shortly see in the follow-up lessons.

Note: The concept of definiteness is usually restricted to real symmetric matrices, whose eigenvalues are always real, but can be extended to complex Hermitian matrices if needed.

In [None]:
def TestDefiniteness(X):
  eigenvalues = linalg.eigvals(X) #Compute the eigenvalues once and reuse them
  if(jnp.all(eigenvalues > 0)):
    print("It's Positive Definite")
  elif(jnp.all(eigenvalues >= 0)):
    print("It's Positive Semidefinite")
  elif(jnp.all(eigenvalues < 0)):
    print("It's Negative Definite")
  elif(jnp.all(eigenvalues <= 0)):
    print("It's Negative Semidefinite")
  else:
    print("It's Indefinite")

#Let's test some matrices

A = jnp.array(([2.0,3.0,4.0],[1.2,2.4,5.2],[2.8,3.6,4.9]))
B = jnp.array(([2.4,4.3],[6.1, 13.6]))
TestDefiniteness(A)
TestDefiniteness(B)
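
For symmetric matrices, a variant of this test can use `linalg.eigvalsh()`, which returns real eigenvalues directly. The sketch below is one possible approach, not the only one; the helper `is_positive_definite` and the matrices `P` and `Q` are hypothetical names introduced here.

```python
import jax.numpy as jnp
from jax.numpy import linalg

# Hypothetical symmetric matrices, chosen only for illustration
P = jnp.array([[2.0, -1.0], [-1.0, 2.0]])   # eigenvalues 1 and 3
Q = jnp.array([[1.0, 2.0], [2.0, 1.0]])     # eigenvalues -1 and 3

def is_positive_definite(X):
    # eigvalsh assumes a symmetric (Hermitian) input and returns real eigenvalues
    return bool(jnp.all(linalg.eigvalsh(X) > 0))

print(is_positive_definite(P))
print(is_positive_definite(Q))
```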

**Singular Value Decomposition**
We’ll conclude the lesson by performing a singular value decomposition (SVD) on a given matrix. If you recall, an SVD of matrix A is:

$$U \Sigma V^* = A$$

If $A$ is $m \times n$, then the dimensions of the decomposed matrices are:

$$U: m \times m$$

$$\Sigma: m \times n$$

$$V^*: n \times n$$

We can verify the dimensions in this example:

In [None]:
A = jnp.array([[2,1,5],[3,0,4],[2,1,-2],[4,0,7]])

U, Sigma, V = linalg.svd(A)

print("Original matrix shape:",A.shape)
print("U shape:",U.shape)
print("Σ shape:",Sigma.shape)
print("V* shape:", V.shape)

If we look closely, while the shapes of $U$ and $V^*$ are consistent with the formula, $\Sigma$ is returned as a one-dimensional vector of singular values, which is consistent with NumPy.
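
To recover the original matrix, the vector of singular values has to be placed on the diagonal of an m×n matrix first. The sketch below shows one way to do this; the names `singular_values`, `Vh`, `Sigma`, and `A_rebuilt` are ours, and the matrix is the same as in the cell above but with float entries.

```python
import jax.numpy as jnp
from jax.numpy import linalg

A = jnp.array([[2.0, 1.0, 5.0], [3.0, 0.0, 4.0], [2.0, 1.0, -2.0], [4.0, 0.0, 7.0]])
U, singular_values, Vh = linalg.svd(A)   # Vh is V* (already conjugate-transposed)

# Place the singular values on the diagonal of an m x n matrix of zeros
m, n = A.shape
k = min(m, n)
Sigma = jnp.zeros((m, n)).at[:k, :k].set(jnp.diag(singular_values))

A_rebuilt = U @ Sigma @ Vh

print(Sigma.shape)
print(jnp.allclose(A, A_rebuilt, atol=1e-4))
```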