<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    This is Set 2: Data Operations (Exercises 11-20) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

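* To use TPUs, JAX has to be configured as shown in the code cell below. When no TPU is available (`TPU_NAME` is not set), the setup is skipped and JAX keeps its default backend, which is the CPU in this run.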

In [150]:
#!python3 -m pip install jax

In [151]:
## import packages
import jax
import jax.numpy as jnp
import os
import requests

## setup JAX to use TPUs if available
try:
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:  # e.g. TPU_NAME is not set: keep JAX's default backend
    pass

jax.devices()

[CpuDevice(id=0)]
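To check which backend the setup cell actually ended up with, one option is to ask JAX directly. This is a small sketch using `jax.default_backend()` and `jax.devices()`; the printed names depend on the machine it runs on.

```python
import jax

# Platform name of the default backend, e.g. 'cpu', 'gpu' or 'tpu'
backend = jax.default_backend()
print("default backend:", backend)

# Devices of the default backend, each tagged with its platform
devices = jax.devices()
for d in devices:
    print(d.id, d.platform)
```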

**Exercise 11: Create a matrix with values [[10, 1, 24], [20, 15, 14]] and assign it to `data`**

In [152]:
data = jnp.array([[10,1,24], [20,15,14]])
data

DeviceArray([[10,  1, 24],
             [20, 15, 14]], dtype=int32)
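`jnp.array` infers the dtype from the Python values: a nested list of integers becomes an integer array (32-bit by default, because JAX disables 64-bit types unless `jax_enable_x64` is set). A minimal sketch for inspecting the result and for fixing the dtype up front:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])

# Basic metadata of the 2 x 3 matrix
print(data.shape, data.ndim, data.dtype)

# The dtype can also be requested explicitly instead of being inferred
data_f = jnp.array([[10, 1, 24], [20, 15, 14]], dtype=jnp.float32)
print(data_f.dtype)
```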

**Exercise 12: Assign the transpose of `data` to `dataT`**

In [153]:
dataT = data.T
dataT

DeviceArray([[10, 20],
             [ 1, 15],
             [24, 14]], dtype=int32)
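`data.T` is one of several equivalent ways to swap the two axes of a 2-D array. A short sketch comparing it with `jnp.transpose` and `jnp.swapaxes`:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])

# Three equivalent ways to transpose a 2-D array
t1 = data.T
t2 = jnp.transpose(data)
t3 = jnp.swapaxes(data, 0, 1)

print(t1.shape)
# Transposing twice restores the original matrix
print(jnp.array_equal(t1.T, data))
```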

**Exercise 13: Assign the element of `data` at index [0, 2] to `value`**

In [154]:
value = data[0][2]
value1 = data[0, 2]
if value1 == value:
    print('value1 and value are equal')
    print(value)
    print(value1)

value1 and value are equal
24
24
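Both `data[0][2]` and `data[0, 2]` return the same element, but `data[0][2]` first builds the whole row and then indexes it, while `data[0, 2]` indexes once. In both cases the result is a 0-d JAX array rather than a Python `int`; a sketch of how to get a plain number back and how negative indices work:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])

# Indexing with a tuple returns a 0-d JAX array, not a Python int
value = data[0, 2]
print(value.shape)

# .item() (or int(value)) copies the scalar back to the host as a Python number
print(value.item())

# Negative indices count from the end of each axis: last row, last column
print(data[-1, -1])
```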


**Exercise 14: Update the value of `data` at index [1, 1] to `100`**

In [155]:
data = data.at[1,1].set(100)
data

DeviceArray([[ 10,   1,  24],
             [ 20, 100,  14]], dtype=int32)
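The `.at[...]` syntax is needed because JAX arrays are immutable: NumPy-style item assignment raises a `TypeError`, and `.set` returns a new array instead of modifying the old one. A minimal sketch:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 15, 14]])

# NumPy-style in-place assignment is rejected for JAX arrays
try:
    data[1, 1] = 100
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)

# .at[...].set(...) builds a new array; `data` itself keeps its old value
updated = data.at[1, 1].set(100)
print(data[1, 1], updated[1, 1])  # 15 100
```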

**Exercise 15: Add `41` to the value of `data` at index [0, 0]**

In [156]:
data = data.at[0, 0].add(41)
data

DeviceArray([[ 51,   1,  24],
             [ 20, 100,  14]], dtype=int32)
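Besides `.add`, the `.at[...]` property offers other functional updates such as `.multiply`, `.min` and `.max`. One behavior worth knowing: when an index appears more than once, `.add` applies every contribution, whereas for `.set` the JAX documentation leaves unspecified which duplicate write wins. A short sketch, with arbitrary example values:

```python
import jax.numpy as jnp

data = jnp.array([[10, 1, 24], [20, 100, 14]])

doubled = data.at[0, 0].multiply(2)   # entry [0, 0]: 10 -> 20
capped = data.at[1, :].min(50)        # row 1: elementwise min(old value, 50)
print(doubled[0, 0], capped[1])

# Index 0 appears twice, so it receives two increments
counts = jnp.zeros(3, dtype=jnp.int32).at[jnp.array([0, 0, 2])].add(1)
print(counts)
```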

**Exercise 16: Calculate the minimum values over axis=1 and assign them to `mins`**

In [157]:
mins = data.min(axis=1)
mins

DeviceArray([ 1, 14], dtype=int32)
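The `axis` argument picks the dimension that is reduced away: `axis=1` collapses the columns and leaves one value per row, `axis=0` leaves one value per column. A sketch on the same matrix, also showing `jnp.argmin` for the position of the minimum:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])

row_mins = data.min(axis=1)            # one value per row: 1 and 14
col_mins = data.min(axis=0)            # one value per column: 20, 1 and 14
overall = data.min()                   # no axis: the whole array reduces to 1
where_min = jnp.argmin(data, axis=1)   # column index of each row's minimum: 1 and 2

print(row_mins, col_mins, overall, where_min)
```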

**Exercise 17: Select the first row of values of `data` and assign it to `data_select`**

In [158]:
print(data)
# Basic slicing: data[start:end] takes indices start, ..., end - 1
# data[:7] -> indices 0 to 6
# data[7:] -> index 7 to the end

# A 2-D array takes one index or slice per dimension: data[rows, columns]
# data[:, 1] -> column 1 of every row
# data[:, 0] -> column 0 of every row
# data[:, 2] -> column 2 of every row

# data[:, 7:] -> every row, columns from index 7 to the end
# data[:, :7] -> every row, columns 0 to 6

# data[:] -> every row (all columns are kept)
# data[7:, :] -> rows from index 7 to the end
# data[:7, :] -> rows 0 to 6

# data[:, :] -> every row and every column
# data[:2, :7] -> rows 0 to 1 and columns 0 to 6

data_select = data[0,:]
data_select

[[ 51   1  24]
 [ 20 100  14]]


DeviceArray([51,  1, 24], dtype=int32)
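The slicing rules listed in the comments can be checked directly on the 2 x 3 matrix. A small sketch; note that a slice running past the end of an axis yields an empty result instead of an error:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])

first_row = data[0, :]     # row 0, every column (same as data[0])
first_col = data[:, 0]     # column 0, one value from each row
block = data[:2, 1:]       # rows 0-1, columns 1-2
last_row = data[-1]        # negative indices count from the end
empty = data[7:]           # slice past the last row: empty, shape (0, 3)

print(first_row, first_col, block.shape, last_row, empty.shape)
```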

**Exercise 18: Append the row `data_select` to `data`**

In [159]:
# Scaled Legendre (LegS), non-vectorized
def build_LegS(N):
    q = jnp.arange(N, dtype=jnp.float64)  # q holds the indices 0, 1, ..., N-1
    k, n = jnp.meshgrid(q, q)  # n varies down the rows, k across the columns
    r = 2 * q + 1
    M = -(jnp.where(n >= k, r, 0) - jnp.diag(q))  # state matrix M: -(2k+1) below the diagonal, -(n+1) on it, 0 above
    D = jnp.sqrt(jnp.diag(2 * q + 1))  # diagonal matrix D = diag(sqrt(2n+1)), n = 0, ..., N-1
    A = D @ M @ jnp.linalg.inv(D)
    B = jnp.diag(D)[:, None]
    B = B.copy()  # kept from the NumPy/PyTorch original (avoids a non-writeable-array warning in torch.as_tensor); optional for immutable JAX arrays
    
    return A, B

In [160]:
#Scaled Legendre (LegS) vectorized
def build_LegS_V(N):
    q = jnp.arange(N, dtype=jnp.float64)
    k, n = jnp.meshgrid(q, q)
    pre_D = jnp.sqrt(jnp.diag(2*q+1))
    B = D = jnp.diag(pre_D)[:, None]
    
    A_base = (-jnp.sqrt(2*n+1)) * jnp.sqrt(2*k+1)
    case_2 = (n+1)/(2*n+1)
    
    A = jnp.where(n > k, A_base, 0.0) # if n > k, then A_base is used, otherwise 0
    A = jnp.where(n == k, (A_base * case_2), A) # if n == k, use A_base * (n+1)/(2n+1) = -(n+1), otherwise keep A
    
    return A, B
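Reading the entries off the two LegS functions shows why they should agree. The non-vectorized version forms $A = D M D^{-1}$ with $D = \operatorname{diag}(\sqrt{2n+1})$, where $M_{nk} = -(2k+1)$ for $n > k$, $M_{nn} = -(n+1)$ and $M_{nk} = 0$ for $n < k$. Scaling row $n$ by $\sqrt{2n+1}$ and column $k$ by $1/\sqrt{2k+1}$ gives exactly the piecewise entries that the vectorized version writes down directly (with $n$ the row index, $k$ the column index, $0 \le n, k < N$):

```latex
A_{nk} = \begin{cases}
  -\sqrt{2n+1}\,\sqrt{2k+1} & n > k \\
  -(n+1) & n = k \\
  0 & n < k
\end{cases}
\qquad
B_n = \sqrt{2n+1}
```

For example, $A_{10} = -\sqrt{3} \approx -1.7320508$ and $A_{43} = -\sqrt{9}\,\sqrt{7} \approx -7.937254$, matching the printed matrices.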

In [161]:
nv_A, nv_B = build_LegS(5)
A, B = build_LegS_V(5)
print(f"nv:\n ", nv_A)
print(f"v:\n ", A)
print(f"A Comparison:\n ", jnp.allclose(nv_A, A))
print(f"B Comparison:\n ", jnp.allclose(nv_B, B))

nv:
  [[-1.         0.         0.         0.         0.       ]
 [-1.7320508 -2.         0.         0.         0.       ]
 [-2.236068  -3.8729832 -3.         0.         0.       ]
 [-2.6457512 -4.5825753 -5.9160795 -4.         0.       ]
 [-3.        -5.196152  -6.7082043 -7.937254  -5.       ]]
v:
  [[-1.         0.         0.         0.         0.       ]
 [-1.7320508 -2.         0.         0.         0.       ]
 [-2.236068  -3.8729832 -3.         0.         0.       ]
 [-2.6457512 -4.5825753 -5.9160795 -4.         0.       ]
 [-3.        -5.196152  -6.7082043 -7.937254  -5.       ]]
A Comparison:
  True
B Comparison:
  True


(JAX emitted a `UserWarning` here: `float64` was requested in `arange` while `jax_enable_x64` is disabled, so the arrays were created as `float32`.)


In [163]:
import numpy as np

In [165]:
# truncated Fourier (FouT)
def build_FouT(N):
    freqs = jnp.arange(N//2)
    d = jnp.stack([jnp.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
    A = jnp.pi*(-jnp.diag(d, 1) + jnp.diag(d, -1))
    
    B = jnp.zeros(A.shape[1])
    B = B.at[0::2].set(jnp.sqrt(2))
    B = B.at[0].set(1)
    
    A = A - B[:, None] * B[None, :]
    B = B[:, None]
    
    return A, B

In [166]:
# truncated Fourier (FouT), vectorized
def build_FouT_V(N):
    size = (N // 2) * 2  # same state size as build_FouT (odd N is rounded down)
    B = jnp.zeros(size, dtype=jnp.float64)
    q = jnp.arange(size, dtype=jnp.float64)
    n, k = jnp.meshgrid(q, q)  # n varies across the columns, k down the rows
    n_even = n % 2 == 0
    k_even = k % 2 == 0
    
    case_1 = (n==k) & (n==0)                              # A[0, 0] = -1
    case_2_3 = ((k==0) & (n_even)) | ((n==0) & (k_even))  # rest of row 0 / column 0 at even positions: -sqrt(2)
    case_4 = (n_even) & (k_even)                          # remaining even/even entries: -2
    case_5 = (n-k==1) & (k_even)                          # superdiagonal at even rows: -pi * (n//2)
    case_6 = (k-n==1) & (n_even)                          # subdiagonal at even columns: pi * (k//2)
    
    # nested where: the first matching case wins, everything else is 0
    A = jnp.where(case_1, -1.0, 
                  jnp.where(case_2_3, -jnp.sqrt(2),
                            jnp.where(case_4, -2, 
                                      jnp.where(case_5, -jnp.pi * (n//2), 
                                                jnp.where(case_6, jnp.pi * (k//2), 0.0)))))
    
    B = B.at[::2].set(jnp.sqrt(2))
    B = B.at[0].set(1)
    
    B = B[:, None]
    
    return A, B
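`build_FouT_V` picks each entry with nested `jnp.where` calls, where the first matching case wins. `jnp.select` expresses the same priority order as flat lists of conditions and values, which can be easier to read. A sketch on a toy vector; the conditions and values below are illustrative only, not the FouT cases:

```python
import jax.numpy as jnp

x = jnp.arange(6)

# Nested where: the first matching condition wins
nested = jnp.where(x == 0, -1,
                   jnp.where(x % 2 == 0, -2,
                             jnp.where(x == 3, 30, 0)))

# select: same priority order, written as parallel lists plus a default
selected = jnp.select([x == 0, x % 2 == 0, x == 3],
                      [-1, -2, 30],
                      default=0)

print(nested, selected)
```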

In [167]:
N=7
A, B = build_FouT(N)
A_b, B_b = build_FouT_V(N)
print(f"A Comparison:\n {jnp.allclose(A, A_b)}")
print(f"B Comparison:\n {jnp.allclose(B, B_b)}")

A Comparison:
 True
B Comparison:
 True


(JAX emitted a `UserWarning` here: `float64` was requested in `zeros` while `jax_enable_x64` is disabled, so the arrays were created as `float32`.)


In [168]:
data = jnp.vstack((data, data_select))
data

DeviceArray([[ 51,   1,  24],
             [ 20, 100,  14],
             [ 51,   1,  24]], dtype=int32)
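`jnp.vstack` accepts the 1-D row directly because it treats it as shape `(1, 3)`. With `jnp.concatenate` or `jnp.append`, the row axis has to be added explicitly so both inputs have the same rank. A sketch showing the three variants give the same result:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14]])
data_select = data[0, :]            # shape (3,)

# vstack promotes the 1-D row to shape (1, 3) before stacking
a = jnp.vstack((data, data_select))
# concatenate needs matching ranks, so add the row axis with [None, :]
b = jnp.concatenate([data, data_select[None, :]], axis=0)
# append with an explicit axis behaves like concatenate
c = jnp.append(data, data_select[None, :], axis=0)

print(a.shape)  # (3, 3)
```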

**Exercise 19: Multiply the matrices `data` and `dataT` and assign it to `data_prod`**

In [169]:
data_prod = jnp.dot(data, dataT)
other_data_prod = data @ dataT

if jnp.array_equal(data_prod, other_data_prod):
    print("these are equal")
    
data_prod

these are equal


DeviceArray([[1087, 1371],
             [ 636, 2096],
             [1087, 1371]], dtype=int32)
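The product is defined because the inner dimensions match: `data` is 3 x 3 and `dataT` is 3 x 2, so the result is 3 x 2. Each entry is the dot product of a row of `data` with a column of `dataT`, while `*` would instead multiply elementwise. A sketch that re-creates both matrices as they stand at this point:

```python
import jax.numpy as jnp

data = jnp.array([[51, 1, 24], [20, 100, 14], [51, 1, 24]])  # (3, 3)
dataT = jnp.array([[10, 20], [1, 15], [24, 14]])             # (3, 2)

# Matrix product: (3, 3) @ (3, 2) -> (3, 2)
prod = jnp.matmul(data, dataT)
print(prod.shape)

# A single entry: row 1 of data dotted with column 1 of dataT
# 20*20 + 100*15 + 14*14 = 2096
print(jnp.dot(data[1], dataT[:, 1]))

# '*' is elementwise, so it squares each entry here
sq = data * data
```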

**Exercise 20: Convert the dtype of `data_prod` to `float32`**

In [170]:
data_prod = data_prod.astype('float32')
data_prod

DeviceArray([[1087., 1371.],
             [ 636., 2096.],
             [1087., 1371.]], dtype=float32)
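`astype` accepts either a dtype object such as `jnp.float32` or its string name. Converting in the other direction, from float to int, drops the fractional part (truncation toward zero). A minimal sketch:

```python
import jax.numpy as jnp

data_prod = jnp.array([[1087, 1371], [636, 2096], [1087, 1371]])

# A dtype object and its string name give the same result
f1 = data_prod.astype(jnp.float32)
f2 = data_prod.astype('float32')
print(f1.dtype, f2.dtype)

# float -> int conversion truncates toward zero: 1.7 -> 1, -1.7 -> -1
back = jnp.array([1.7, -1.7]).astype(jnp.int32)
print(back)
```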

<center>
    This completes Set 2: Data Operations (Exercises 11-20) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>