<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    This is Set 3: Pseudorandom Numbers (Exercises 21-30) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* JAX should be configured as shown in the code snippet below in order to use TPUs.
* A sample array `sample_data` will be used for the exercises.

In [1]:
!python3 -m pip install jax



In [2]:
## import packages
import jax
import jax.numpy as jnp
import os
import requests

## setup JAX to use TPUs if available
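try:
    # TPU_NAME is only set on TPU hosts; a KeyError here means no TPU is configured
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    # ask the TPU host for the nightly tpu_driver version
    resp = requests.post(url)
    # route JAX computations through the remote TPU driver
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    # no usable TPU: fall back to the default backend (CPU or GPU)
    pass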

jax.devices()


[CpuDevice(id=0)]

In [3]:
## sample data
sample_data = jnp.array([10, 1, 24, 20, 15, 14])
sample_data

DeviceArray([10,  1, 24, 20, 15, 14], dtype=int32)

**Exercise 21: Create a pseudorandom number generator key with seed=100 and assign it to `key`**

In [4]:
key = jax.random.PRNGKey(100)
key

DeviceArray([  0, 100], dtype=uint32)
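A key is just a small array derived from the seed, so the same key always reproduces the same numbers. The sketch below reuses the seed 100 from this exercise (the shape `(3,)` and the second seed 101 are arbitrary choices) to check that two keys from the same seed give identical draws, while a different seed gives different ones.

```python
import jax
import jax.numpy as jnp

# Two keys built from the same seed are identical uint32 arrays
key_a = jax.random.PRNGKey(100)
key_b = jax.random.PRNGKey(100)

# Identical keys give identical samples: there is no hidden global RNG state
draws_a = jax.random.normal(key_a, shape=(3,))
draws_b = jax.random.normal(key_b, shape=(3,))

# A different seed gives a different key and, in practice, different samples
draws_c = jax.random.normal(jax.random.PRNGKey(101), shape=(3,))

print(key_a)
print(jnp.array_equal(draws_a, draws_b))
print(jnp.array_equal(draws_a, draws_c))
```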

**Exercise 22: Create a subkey from `key` and assign it to `subkey`**

In [22]:
new_key, subkey = jax.random.split(key)
subkey

DeviceArray([3011861781, 1867493174], dtype=uint32)
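JAX's guidance is to treat a key as single-use: split it, carry one half forward, and spend the other half on exactly one sampling call. A minimal sketch of that pattern in a loop; the helper `draw_n_uniforms` and the count of 4 are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def draw_n_uniforms(key, n):
    """Hypothetical helper: draw n scalars, splitting the key before each draw."""
    values = []
    for _ in range(n):
        # carry `key` forward; `subkey` is consumed by a single sampling call
        key, subkey = jax.random.split(key)
        values.append(jax.random.uniform(subkey))
    return key, jnp.stack(values)

demo_key = jax.random.PRNGKey(100)
demo_key, values = draw_n_uniforms(demo_key, 4)
print(values)
```

Because every draw uses a fresh subkey, the four values differ from each other, yet rerunning from `PRNGKey(100)` reproduces them exactly.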

**Exercise 23: Split `key` into seven subkeys `key_1`, `key_2`, `key_3`, `key_4`, `key_5`, `key_6` and `key_7`**

In [6]:
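# note: this splits new_key from Exercise 22, not key as the exercise states;
# JAX recommends not reusing a key that has already been split
# new_key1 stands for key_1, and sub_key1 to sub_key6 stand for key_2 to key_7
new_key1, sub_key1, sub_key2, sub_key3, sub_key4, sub_key5, sub_key6 = jax.random.split(new_key, num=7)
print(new_key1)
print(sub_key1)
print(sub_key2)
print(sub_key3)
print(sub_key4)
print(sub_key5)
print(sub_key6)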

[ 402730500 1595431526]
[1424548836 4263965854]
[4262858109 2830712664]
[695247316 923860704]
[2593091143 4270345139]
[3741056579  424226150]
[ 913445040 1145123381]
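With `num=7`, `jax.random.split` returns one array of shape `(7, 2)`, one row per key, which is why it can be unpacked into seven names as in the cell above. A short sketch, assuming a fresh seed-100 key, that inspects the stacked result and indexes rows instead of unpacking them:

```python
import jax
import jax.numpy as jnp

base_key = jax.random.PRNGKey(100)

# a single call returns all seven keys stacked along the first axis
keys = jax.random.split(base_key, num=7)
print(keys.shape, keys.dtype)

# rows can be indexed instead of unpacked into separate variables
first_key = keys[0]

# splitting is deterministic: same input key and num give the same keys
same_keys = jax.random.split(base_key, num=7)
print(jnp.array_equal(keys, same_keys))
```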


**Exercise 24: Create a random permutation of `sample_data` using `key_1` and assign it to `data_permutation`**

In [7]:
# new_key1 stands for key_1; the result is a shuffled copy, sample_data is unchanged
data_permutation = jax.random.permutation(new_key1, sample_data)
data_permutation

DeviceArray([20, 14,  1, 10, 24, 15], dtype=int32)
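`jax.random.permutation` also accepts an integer `n`, in which case it shuffles `jnp.arange(n)`; that is handy when the same shuffled order should be applied to several arrays by indexing. A sketch with an arbitrary seed of 0:

```python
import jax
import jax.numpy as jnp

data = jnp.array([10, 1, 24, 20, 15, 14])
perm_key = jax.random.PRNGKey(0)

# shuffle the values directly
shuffled = jax.random.permutation(perm_key, data)

# an integer argument shuffles jnp.arange(n), giving a random index order
index_order = jax.random.permutation(perm_key, data.shape[0])
gathered = data[index_order]

print(shuffled)
print(index_order, gathered)
```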

**Exercise 25: Choose a random element from `sample_data` using `key_2` and assign it to `random_selection`**

In [8]:
# sub_key1 stands for key_2; by default every element is equally likely
random_selection = jax.random.choice(sub_key1, sample_data)
random_selection

DeviceArray(1, dtype=int32)
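`jax.random.choice` also takes `shape`, `replace` and `p` arguments, so it can draw several elements without repeats or with unequal weights. A sketch with an arbitrary seed of 0; the weight vector is purely illustrative and puts all probability mass on 24 and 20.

```python
import jax
import jax.numpy as jnp

data = jnp.array([10, 1, 24, 20, 15, 14])
k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# three distinct elements: sampling without replacement
no_repeats = jax.random.choice(k1, data, shape=(3,), replace=False)

# weighted sampling with replacement (illustrative weights: only 24 and 20 can occur)
weights = jnp.array([0.0, 0.0, 0.5, 0.5, 0.0, 0.0])
weighted = jax.random.choice(k2, data, shape=(5,), p=weights)

print(no_repeats, weighted)
```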

**Exercise 26: Sample an integer between 10 and 24 using `key_3` and assign it to `sample_int`**

In [16]:
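# sub_key2 stands for key_3; maxval is exclusive, so this draws from 10 to 23
# (use maxval=25 if 24 should be a possible outcome)
sample_int = jax.random.randint(sub_key2, shape=(1,), minval=10, maxval=24)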
sample_int

DeviceArray([14], dtype=int32)
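Since `maxval` in `jax.random.randint` is exclusive, an inclusive range 10 to 24 needs `maxval=25`, and `shape=()` returns a single scalar rather than a length-1 array. A sketch with an arbitrary seed of 0 that checks the bounds over many draws:

```python
import jax
import jax.numpy as jnp

draw_key = jax.random.PRNGKey(0)

# maxval is exclusive, so maxval=25 is needed for 24 to be a possible outcome
draws = jax.random.randint(draw_key, shape=(10_000,), minval=10, maxval=25)

# shape=() gives a scalar array when a single integer is wanted
single = jax.random.randint(draw_key, shape=(), minval=10, maxval=25)

print(draws.min(), draws.max(), single)
```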

**Exercise 27: Sample two values from a uniform distribution between 1 and 2 using `key_4` and assign them to `sample_uniform`**

In [17]:
# sub_key3 stands for key_4; samples lie in the half-open interval [1, 2)
sample_uniform = jax.random.uniform(sub_key3, shape=(2,), minval=1, maxval=2)
sample_uniform

DeviceArray([1.6274643, 1.1133162], dtype=float32)
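Samples from `jax.random.uniform` lie in the half-open interval `[minval, maxval)`, and for a uniform distribution on [a, b) the mean is (a + b) / 2. A sketch with an arbitrary seed of 0 and 100,000 draws that checks both properties:

```python
import jax
import jax.numpy as jnp

u_key = jax.random.PRNGKey(0)

# samples lie in [minval, maxval)
u = jax.random.uniform(u_key, shape=(100_000,), minval=1.0, maxval=2.0)

# for U(1, 2) the mean should be close to (1 + 2) / 2 = 1.5
print(u.min(), u.max(), u.mean())
```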

**Exercise 28: Sample three values from a Bernoulli distribution using `key_5` and assign them to `sample_bernoulli`**

In [18]:
# sub_key4 stands for key_5; p (the probability of True) defaults to 0.5
sample_bernoulli = jax.random.bernoulli(sub_key4, shape=(3,))
sample_bernoulli

DeviceArray([False,  True,  True], dtype=bool)
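The exercise relies on the default success probability `p=0.5`; passing `p` explicitly changes it. A sketch with an arbitrary seed of 0 and an illustrative `p=0.2`, checking that the fraction of `True` values is close to `p`:

```python
import jax
import jax.numpy as jnp

b_key = jax.random.PRNGKey(0)

# p is the probability of True (illustrative value)
flips = jax.random.bernoulli(b_key, p=0.2, shape=(100_000,))

# the mean of a boolean array is the fraction of True values
print(flips.dtype, flips.mean())
```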

**Exercise 29: Sample a 2x3 matrix from a Poisson distribution with λ=100 using `key_6` and assign it to `sample_poisson`**

In [19]:
# sub_key5 stands for key_6
sample_poisson = jax.random.poisson(sub_key5, shape=(2, 3), lam=100)
sample_poisson

DeviceArray([[ 88,  82, 110],
             [ 89,  85,  98]], dtype=int32)
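For a Poisson distribution with rate λ, both the mean and the variance equal λ, which explains why the values above cluster around 100 with a spread of roughly ±10. A sketch with an arbitrary seed of 0 and 100,000 draws that checks this empirically:

```python
import jax
import jax.numpy as jnp

p_key = jax.random.PRNGKey(0)

# for Poisson(lam), mean and variance should both be close to lam
counts = jax.random.poisson(p_key, lam=100, shape=(100_000,))

print(counts.dtype, counts.mean(), counts.var())
```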

**Exercise 30: Sample a 2x3x4 array from a normal distribution using `key_7` and assign it to `sample_normal`**

In [21]:
# sub_key6 stands for key_7; samples come from the standard normal N(0, 1)
sample_normal = jax.random.normal(sub_key6, shape=(2, 3, 4))
sample_normal

DeviceArray([[[ 0.25418106,  1.1962762 ,  1.3234271 ,  0.7971064 ],
              [-1.8524723 , -0.28634217,  0.22515148, -0.6195277 ],
              [ 2.4013762 ,  0.07618266,  1.2277744 , -0.75625014]],

             [[-0.45340505,  1.1029627 , -0.3986071 , -1.1235172 ],
              [-1.5689536 ,  0.46173117, -0.5607152 , -1.7508616 ],
              [ 0.5020061 , -1.4972548 , -1.6995512 ,  0.5555248 ]]],            dtype=float32)
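`jax.random.normal` always samples from the standard normal N(0, 1); a normal with mean μ and standard deviation σ is obtained by shifting and scaling, x = μ + σz. A sketch with an arbitrary seed of 0 and illustrative parameters μ = 5 and σ = 2:

```python
import jax
import jax.numpy as jnp

n_key = jax.random.PRNGKey(0)

# illustrative parameters for N(mu, sigma^2)
mu, sigma = 5.0, 2.0

# standard normal draws, then shift and scale
z = jax.random.normal(n_key, shape=(100_000,))
x = mu + sigma * z

print(z.mean(), z.std(), x.mean(), x.std())
```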

<center>
    This completes Set 3: Pseudorandom Numbers (Exercises 21-30) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>