<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    This is Set 2: Data Operations (Exercises 11-20) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* To use TPUs, configure JAX as shown in the code snippet below (the snippet relies on the legacy `tpu_driver` backend of older JAX releases).

In [645]:
#!python3 -m pip install jax

In [646]:
## import packages
import jax
import jax.numpy as jnp
import os
import requests

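## set up JAX to use TPUs if available (legacy tpu_driver setup for older JAX releases)
try:
    # TPU_NAME is expected to look like 'grpc://<ip>:8470'; request the nightly TPU driver from that host
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    # no TPU available (e.g. TPU_NAME unset) or an incompatible JAX version: keep the default backend
    pass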

jax.devices()

[GpuDevice(id=0, process_index=0)]
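The output above reports a GPU, so the TPU branch was not taken. To check which backend JAX actually selected, independently of the legacy setup code, a minimal sketch using `jax.default_backend()` and `jax.devices()` (the printed values depend on the machine):

```python
import jax

backend = jax.default_backend()  # platform name of the default backend, e.g. 'cpu', 'gpu' or 'tpu'
devices = jax.devices()  # devices available on the default backend
print(backend, len(devices))
```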

**Exercise 11: Create a matrix with values [[10, 1, 24], [20, 15, 14]] and assign it to `data`**

In [647]:
data = jnp.array([[10, 1, 24], [20, 15, 14]])
data

DeviceArray([[10,  1, 24],
             [20, 15, 14]], dtype=int32)
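With 64-bit mode disabled (the default), JAX infers `int32` from Python integers, as the output above shows. A small sketch that makes the dtype explicit and checks the shape:

```python
import jax.numpy as jnp

# Explicit dtype instead of relying on inference from Python ints
data = jnp.array([[10, 1, 24], [20, 15, 14]], dtype=jnp.int32)
print(data.shape, data.dtype)  # (2, 3) int32
```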

**Exercise 12: Assign the transpose of `data` to `dataT`**

In [648]:
dataT = data.T
dataT

DeviceArray([[10, 20],
             [ 1, 15],
             [24, 14]], dtype=int32)
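`.T` is one of several equivalent ways to transpose a 2-D array. A short sketch comparing it with `jnp.transpose` and `jnp.swapaxes`:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])

t1 = data.T
t2 = jnp.transpose(data)  # with no axes argument, reverses the axes
t3 = jnp.swapaxes(data, 0, 1)  # explicitly swaps axis 0 and axis 1
print(t1.shape)  # (3, 2)
```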

**Exercise 13: Assign the element of `data` at index [0, 2] to `value`**

In [649]:
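value = data[0][2]  # chained indexing: row 0 first, then element 2 of that row
value1 = data[0, 2]  # single multi-dimensional index (the idiomatic form)
if value1 == value:
    print('value1 and value are equal')
    print(value)
    print(value1)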

value1 and value are equal
24
24
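Either way, the result is a zero-dimensional JAX array rather than a Python number. A short sketch of inspecting it and converting it with `.item()` or `int()`:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])

value = data[0, 2]  # zero-dimensional array
print(value.shape)  # ()
print(value.item())  # 24, now a Python int
print(int(value) + 1)  # 25
```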


**Exercise 14: Update the value of `data` at index [1, 1] to `100`**

In [650]:
data = data.at[1, 1].set(100)
data

DeviceArray([[ 10,   1,  24],
             [ 20, 100,  14]], dtype=int32)
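The `.at[...]` syntax is needed because JAX arrays are immutable: NumPy-style item assignment raises a `TypeError`, while `.at[...].set(...)` returns an updated copy and leaves the original array untouched. A short sketch:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])

raised = False
try:
    data[1, 1] = 100  # in-place assignment is not supported
except TypeError:
    raised = True

updated = data.at[1, 1].set(100)  # out-of-place update
print(raised, int(data[1, 1]), int(updated[1, 1]))  # True 15 100
```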

**Exercise 15: Add `41` to the value of `data` at index [0, 0]**

In [651]:
data = data.at[0, 0].add(41)
data

DeviceArray([[ 51,   1,  24],
             [ 20, 100,  14]], dtype=int32)
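Besides `set` and `add`, `.at[...]` offers other update methods such as `multiply`, `min` and `max`. With `add`, repeated indices accumulate instead of overwriting each other, which is handy for counting. A sketch:

```python
import jax.numpy as jnp

counts = jnp.zeros(3, dtype=jnp.int32)
idx = jnp.array([0, 0, 2])

# Index 0 appears twice, so it receives the increment twice
counts = counts.at[idx].add(1)
print(counts)  # [2 0 1]

# Multiply a single element in place-style (returns a new array)
scaled = jnp.array([1.0, 2.0, 3.0]).at[1].multiply(10.0)
print(scaled)
```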

**Exercise 16: Calculate the minimum values over axis=1 and assign it to `mins`**

In [652]:
mins = data.min(axis=1)
mins

DeviceArray([ 1, 14], dtype=int32)
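`axis=1` reduces across columns, giving one minimum per row; `axis=0` gives one minimum per column. A sketch on the same values, also showing `jnp.argmin` and `keepdims`:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])

row_mins = data.min(axis=1)  # one value per row: [ 1 14]
col_mins = data.min(axis=0)  # one value per column: [20  1 14]
where_row_min = jnp.argmin(data, axis=1)  # column index of each row minimum: [1 2]
kept = data.min(axis=1, keepdims=True)  # shape (2, 1), convenient for broadcasting
print(row_mins, col_mins, where_row_min, kept.shape)
```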

**Exercise 17: Select the first row of values of `data` and assign it to `data_select`**

In [653]:
print(data)
# Basic slicing: data[start:stop] selects indices start, ..., stop-1
# data[:7] -> indices 0 to 6 (a stop past the end is clipped to the array size)
# data[7:] -> index 7 to the end

# For a 2-D array: data[rows, columns]
# data[:, 0] -> column 0 (all rows)
# data[:, 1] -> column 1
# data[:, 2] -> column 2

# data[:, :]  -> all rows and all columns
# data[:, 7:] -> all rows, columns from index 7 to the end
# data[:, :7] -> all rows, columns 0 to 6

# data[0, :]  -> row 0, all columns
# data[7:, :] -> rows from index 7 to the end
# data[:7, :] -> rows 0 to 6

# data[:2, :7] -> rows 0 to 1 and columns 0 to 6

data_select = data[0, :]  # first row; data[0] is equivalent
data_select
data_select

[[ 51   1  24]
 [ 20 100  14]]


DeviceArray([51,  1, 24], dtype=int32)
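The slicing rules in the comments above can be checked directly on the current values of `data`:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])

first_row = data[0, :]  # [51  1 24]
first_col = data[:, 0]  # [51 20]
tail_cols = data[:, 1:]  # columns 1 to the end, shape (2, 2)
clipped = data[:7, :7]  # stops past the end are clipped, shape (2, 3)
print(first_row, first_col, tail_cols.shape, clipped.shape)
```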

**Exercise 18: Append the row `data_select` to `data`**

In [654]:
# LegS variant without the minus sign: returns the negative of build_LegS's A
def my_LegS(N):
    q = jnp.arange(N, dtype=jnp.float64)  # float32 unless jax_enable_x64 is set
    n, k = jnp.meshgrid(q, q)  # default indexing='xy': n varies along columns, k along rows
    pre_D = jnp.sqrt(jnp.diag(2*q+1))
    B = jnp.diag(pre_D)[:, None]  # column vector sqrt(2q+1)

    A_base = jnp.sqrt(2*n+1) * jnp.sqrt(2*k+1)

    A = jnp.where(n > k, A_base, 0.0)  # if n > k, use A_base, otherwise 0
    A = jnp.where(n == k, n+1, A)  # on the diagonal (n == k), use n+1; elsewhere keep A

    return A, B

In [655]:
# Scaled Legendre (LegS), non-vectorized
def build_LegS_NV(N):
    q = jnp.arange(N, dtype=jnp.float64)  # q = 0, 1, ..., N-1 (float32 unless jax_enable_x64 is set)
    n, k = jnp.meshgrid(q, q)  # default indexing='xy': n varies along columns, k along rows
    M = -(jnp.where(n >= k, 2*q+1, 0) - jnp.diag(q))  # state matrix M
    D = jnp.sqrt(jnp.diag(2*q+1))  # diagonal matrix D = diag[(2n+1)^(1/2)], n = 0, ..., N-1
    A = D @ M @ jnp.linalg.inv(D)  # similarity transform A = D M D^{-1}
    B = jnp.diag(D)[:, None]  # diagonal of D as a column vector

    return A, B

In [656]:
# Scaled Legendre (LegS)
def build_LegS(N):
    q = jnp.arange(N, dtype=jnp.float64)  # float32 unless jax_enable_x64 is set
    n, k = jnp.meshgrid(q, q)  # default indexing='xy': n varies along columns, k along rows
    pre_D = jnp.sqrt(jnp.diag(2*q+1))
    B = jnp.diag(pre_D)[:, None]  # column vector sqrt(2q+1)

    A_base = -(jnp.sqrt(2*n+1) * jnp.sqrt(2*k+1))
    case_2 = (n+1)/(2*n+1)  # on the diagonal, case_2 * A_base = -(n+1)

    A = jnp.where(n > k, A_base, 0.0)  # if n > k, use A_base, otherwise 0
    A = jnp.where(n == k, case_2 * A_base, A)  # on the diagonal, use -(n+1); elsewhere keep A

    return A, B

In [657]:
#print(build_LegS(10))
#print(build_LegS_NV(10))
nv_A, nv_B = build_LegS_NV(5)
A, B = build_LegS(5)
other_A, other_B = my_LegS(5)
print(f"nv:\n ", nv_A)
print(f"v:\n ", A)
print(f"other:\n ", other_A)
print(f"A Comparison:\n ", jnp.allclose(nv_A, A))
print(f"Other A Comparison:\n ", jnp.allclose(nv_A, other_A))
print(f"B Comparison:\n ", jnp.allclose(nv_B, B))

nv:
  [[-1.        -1.7314453 -2.2363281 -2.6455078 -2.9992676]
 [ 0.        -1.9997292 -3.8751373 -4.5823975 -5.1966476]
 [ 0.         0.        -3.0015717 -5.9169617 -6.7066956]
 [ 0.         0.         0.        -4.00074   -7.935562 ]
 [ 0.         0.         0.         0.        -4.9987793]]
v:
  [[-1.        -1.7320508 -2.2360678 -2.6457512 -3.       ]
 [ 0.        -2.        -3.872983  -4.5825753 -5.196152 ]
 [ 0.         0.        -2.9999995 -5.916079  -6.7082033]
 [ 0.         0.         0.        -4.        -7.937254 ]
 [ 0.         0.         0.         0.        -5.       ]]
other:
  [[1.        1.7320508 2.2360678 2.6457512 3.       ]
 [0.        2.        3.872983  4.5825753 5.196152 ]
 [0.        0.        3.        5.916079  6.7082033]
 [0.        0.        0.        4.        7.937254 ]
 [0.        0.        0.        0.        5.       ]]
A Comparison:
  False
Other A Comparison:
  False
B Comparison:
  True
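The cells above call `jnp.meshgrid(q, q)` with its default `indexing='xy'`, so `n` varies along columns and every `A` comes out upper triangular. As a sketch, assuming the intended convention is `n` as the row index (A_nk = -(2n+1)^(1/2) (2k+1)^(1/2) for n > k, -(n+1) for n = k, 0 for n < k), the hypothetical helpers below build that matrix both directly and as D M D^{-1}. Because D is diagonal, the similarity transform is done with elementwise scaling instead of `jnp.linalg.inv` and matrix products; in float32 those matrix products may explain the small deviations in the `nv` output (on GPUs, JAX's default matmul precision can be lower than full float32). Note that `jnp.float64` only takes effect after `jax.config.update("jax_enable_x64", True)`; this sketch simply uses float32.

```python
import jax.numpy as jnp

def legs_closed_form(N):
    # Hypothetical helper: direct formula with n as the row index, k as the column index
    q = jnp.arange(N, dtype=jnp.float32)
    n, k = jnp.meshgrid(q, q, indexing="ij")  # n[i, j] = i, k[i, j] = j
    A = jnp.where(n > k, -jnp.sqrt(2 * n + 1) * jnp.sqrt(2 * k + 1), 0.0)
    A = jnp.where(n == k, -(n + 1), A)
    B = jnp.sqrt(2 * q + 1)[:, None]
    return A, B

def legs_similarity(N):
    # Hypothetical helper: same matrix as D M D^{-1}; since D is diagonal,
    # (D M D^{-1})[i, j] = d[i] * M[i, j] / d[j], so no inverse is needed
    q = jnp.arange(N, dtype=jnp.float32)
    n, k = jnp.meshgrid(q, q, indexing="ij")
    M = -(jnp.where(n >= k, 2 * k + 1, 0.0) - jnp.diag(q))
    d = jnp.sqrt(2 * q + 1)
    A = d[:, None] * M / d[None, :]
    B = d[:, None]
    return A, B

A1, B1 = legs_closed_form(5)
A2, B2 = legs_similarity(5)
print(jnp.allclose(A1, A2), jnp.allclose(B1, B2))
print(jnp.allclose(A1, jnp.tril(A1)))  # lower triangular with this convention
```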


In [658]:
# truncated Fourier (FouT); assumes N is even so that A and B have matching sizes
def build_FouT(N):
    freqs = jnp.arange(N//2)
    d = jnp.stack([jnp.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
    A = jnp.pi*(-jnp.diag(d, 1) + jnp.diag(d, -1))
    
    B = jnp.zeros(N)
    B = B.at[0::2].set(2**.5)
    B = B.at[0].set(1)
    B = B[:, None]
    
    return A, B

In [659]:
# truncated Fourier (FouT) with the rank-one correction A - B B^T
def build_FouT_RC(N):
    freqs = jnp.arange(N//2)
    d = jnp.stack([jnp.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
    A = jnp.pi*(-jnp.diag(d, 1) + jnp.diag(d, -1))
    
    B = jnp.zeros(N)
    B = B.at[0::2].set(2**.5)
    B = B.at[0].set(1)
    
    A = A - B[:, None] * B[None, :]
    B = B[:, None]
    
    return A, B

In [660]:
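N = 10  # build_FouT_RC needs N; the printed rc_A below is 10 x 10
A, B = build_FouT(5)  # odd N: A comes out 4 x 4 while B has 5 rows
rc_A, rc_B = build_FouT_RC(N)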

print(A)
print(rc_A)

[[ 0.         0.         0.         0.       ]
 [ 0.         0.         0.         0.       ]
 [ 0.         0.         0.        -3.1415927]
 [ 0.         0.         3.1415927  0.       ]]
[[ -1.          0.         -1.4142135   0.         -1.4142135   0.
   -1.4142135   0.         -1.4142135   0.       ]
 [  0.          0.          0.          0.          0.          0.
    0.          0.          0.          0.       ]
 [ -1.4142135   0.         -1.9999999  -3.1415927  -1.9999999   0.
   -1.9999999   0.         -1.9999999   0.       ]
 [  0.          0.          3.1415927   0.          0.          0.
    0.          0.          0.          0.       ]
 [ -1.4142135   0.         -1.9999999   0.         -1.9999999  -6.2831855
   -1.9999999   0.         -1.9999999   0.       ]
 [  0.          0.          0.          0.          6.2831855   0.
    0.          0.          0.          0.       ]
 [ -1.4142135   0.         -1.9999999   0.         -1.9999999   0.
   -1.9999999  -9.424778   -1
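With `N = 5`, the interleaved frequency layout only fills a 2*(N//2) = 4 square matrix, which is why `A` above is 4 x 4 while `B` has 5 rows. A sketch of a hypothetical `fout_even` variant that rejects odd `N` and checks two properties of the construction: the uncorrected `A` is skew-symmetric, and the corrected matrix differs from it by exactly -B B^T (computed here as an elementwise outer product):

```python
import jax.numpy as jnp

def fout_even(N):
    # Hypothetical variant of build_FouT that refuses odd N
    if N % 2:
        raise ValueError(f"N must be even, got {N}")
    freqs = jnp.arange(N // 2)
    d = jnp.stack([jnp.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:]  # length N - 1
    A = jnp.pi * (-jnp.diag(d, 1) + jnp.diag(d, -1))  # N x N, skew-symmetric
    b = jnp.zeros(N).at[0::2].set(2**0.5).at[0].set(1.0)
    return A, b[:, None]

A, B = fout_even(10)
A_rc = A - B * B.T  # (N, 1) * (1, N) broadcasts to the outer product B B^T
print(A.shape, jnp.allclose(A, -A.T))
```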

In [661]:
data = jnp.vstack((data, data_select))
data

DeviceArray([[ 51,   1,  24],
             [ 20, 100,  14],
             [ 51,   1,  24]], dtype=int32)
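`jnp.vstack` promotes the 1-D row to shape (1, 3) before stacking. `jnp.concatenate` and `jnp.append` give the same result once the row is reshaped explicitly. A sketch:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])
data_select = data[0, :]  # shape (3,)

stacked = jnp.vstack((data, data_select))
concatenated = jnp.concatenate([data, data_select[None, :]], axis=0)  # row reshaped to (1, 3)
appended = jnp.append(data, data_select[None, :], axis=0)  # with axis given, behaves like concatenate
print(stacked.shape)  # (3, 3)
```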

**Exercise 19: Multiply the matrices `data` and `dataT` and assign it to `data_prod`**

In [662]:
data_prod = jnp.dot(data, dataT)
other_data_prod = data @ dataT

if jnp.array_equal(data_prod, other_data_prod):
    print("these are equal")
    
data_prod

these are equal


DeviceArray([[1087, 1371],
             [ 636, 2096],
             [1087, 1371]], dtype=int32)
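`dataT` was computed once in Exercise 12, so it is still the transpose of the original 2 x 3 matrix, not of the updated `data`. That is why a 3 x 3 `data` times a 3 x 2 `dataT` gives a 3 x 2 product. A sketch reproducing the numbers:

```python
import jax.numpy as jnp

original = jnp.array([[10, 1, 24], [20, 15, 14]])
dataT = original.T  # (3, 2), as computed in Exercise 12

# data after Exercises 14, 15 and 18
data = jnp.array([[51, 1, 24], [20, 100, 14], [51, 1, 24]])

data_prod = data @ dataT  # (3, 3) @ (3, 2) -> (3, 2)
# e.g. first entry: 51*10 + 1*1 + 24*24 = 1087
print(data_prod)
```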

**Exercise 20: Convert the dtype of `data_prod` to `float32`**

In [663]:
data_prod = data_prod.astype('float32')
data_prod

DeviceArray([[1087., 1371.],
             [ 636., 2096.],
             [1087., 1371.]], dtype=float32)
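Every integer with magnitude up to 2**24 converts to `float32` exactly, so the values above are unaffected. Beyond that limit, neighbouring integers can round to the same float. A sketch of the edge case:

```python
import jax.numpy as jnp

big = jnp.array(2**24 + 1, dtype=jnp.int32)  # 16777217
as_float = big.astype(jnp.float32)  # rounds to the nearest float32, 16777216.0
print(int(big), float(as_float))

small = jnp.array([[1087, 1371], [636, 2096]]).astype(jnp.float32)
print(small.dtype)  # float32
```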

<center>
    This completes Set 2: Data Operations (Exercises 11-20) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>