### Random numbers

How important are random numbers when it comes to machine learning or deep learning? Well, helluva important. How else are you going to initialize your parameters?

Actually, a theoretical discussion on parameter initialization would take us hugely off track from the goal of these notebooks, so how about you read these wonderful pieces? (Definitely later, because if you start reading them now and get lost, you won't be able to get your interest back here.)

1. [Initializing neural networks](https://www.deeplearning.ai/ai-notes/initialization/)
2. [Weight Initialization in Neural Networks: A Journey From the Basics to Kaiming](https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79)
3. [Weight Initialization for Deep Learning Neural Networks](https://machinelearningmastery.com/weight-initialization-for-deep-learning-neural-networks/)

### Typical Random Numbers in python

In [1]:
import random

random.seed(42)
random.random()

0.6394267984578837

In [2]:
random.randrange(1, 10)

1

In [3]:
random.randrange(1, 10, 2)

5
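The seed is what makes these numbers repeatable. Here is a minimal sketch, using only the standard library, of two things the cells above rely on: re-seeding replays the same sequence, and the third argument of `randrange` is a step. The variable names are just for illustration.

```python
import random

# Seeding fixes the generator's hidden global state,
# so the same seed replays the same sequence.
random.seed(42)
first_run = [random.random() for _ in range(3)]

random.seed(42)
second_run = [random.random() for _ in range(3)]

print(first_run == second_run)  # True

# randrange(start, stop, step) draws from range(start, stop, step),
# so with (1, 10, 2) only 1, 3, 5, 7 or 9 can come out.
draws = [random.randrange(1, 10, 2) for _ in range(20)]
print(all(d in range(1, 10, 2) for d in draws))  # True
```

Notice that the state lives inside the `random` module itself. Keep that in mind for the next section.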

### Random in Jax

Jax uses a different Pseudo Random Number Generator (PRNG) from the default one in Python or numpy. You can read more [here](https://github.com/google/jax/blob/main/design_notes/prng.md). The primary goal behind Jax's implementation was to keep the state of the PRNG consistent across multiple devices (say GPUs).

What happens in a multi-device setting is that the state of a stateful PRNG can end up different on each device (even for the same seed), which can lead to inconsistent results between the devices. Ideally you would want all your devices to share the same PRNG configuration and produce random numbers the same way (i.e. without duplicates).

Obviously there were other design considerations, which they have mentioned in the link above.  

So in Jax, you generate a key, which plays the role of what you'd call a seed for a random generator, and you pass it explicitly to every random function. (If anybody knows why 42 is a commonly used seed, kindly let me know!)

In [4]:
import jax as J

key = J.random.PRNGKey(42) # holy 42!
key

DeviceArray([ 0, 42], dtype=uint32)

In [5]:
J.random.normal(key)

DeviceArray(-0.18471177, dtype=float32)

In [6]:
J.random.normal(key)

DeviceArray(-0.18471177, dtype=float32)

#### Whoa there! why the same number again? 

Weird, isn't it? The thing is, random functions in Jax don't keep track of whether a key was previously used or not: the same key always gives the same number. So every time you want to generate a new random number, you have to pass a new key. This is kinda weird.

But you can get around this "creating a new key every time" from scratch. You just have to __split__ the key you already have.


In [7]:
newkey, subkey = J.random.split(key)

In [8]:
J.random.normal(newkey)

DeviceArray(0.13790321, dtype=float32)

In [9]:
J.random.normal(newkey) # one more time .....

DeviceArray(0.13790321, dtype=float32)
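To actually get a fresh number on every draw, a common pattern is to keep splitting: carry one key forward and spend the other exactly once. Below is a minimal sketch of that pattern. `draw_samples` is a hypothetical helper made up for this example, and it uses `PRNGKey`, the long-standing key constructor.

```python
import jax as J

def draw_samples(seed, n):
    """Hypothetical helper: n normal samples, one fresh subkey per draw."""
    key = J.random.PRNGKey(seed)
    samples = []
    for _ in range(n):
        # Carry `key` forward and spend `subkey` exactly once.
        key, subkey = J.random.split(key)
        samples.append(float(J.random.normal(subkey)))
    return samples

first = draw_samples(42, 3)
second = draw_samples(42, 3)

print(first)            # three different numbers
print(first == second)  # True: same seed, same splits, same numbers
```

So the numbers differ from draw to draw, yet the whole sequence is reproducible from the seed alone.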

### Multiple Splits

![what if](./images/what_if.jpg)


Of course! Legit question. You can tell `split()` how many keys you need.

In [10]:
newkey, *subkeys = J.random.split(key, 9)

In [11]:
for sk in subkeys:
    print(sk)

[2015457675 1759218286]
[ 372461012 2067635993]
[1901925373 4124313125]
[ 163989101 2463678267]
[2310818676 2815820454]
[3589763416 3424932669]
[4221803610 1067686388]
[2667717255 2997884537]


In [12]:
for sk in subkeys:
    print(J.random.normal(sk))

-0.43648595
-2.2741115
1.0372064
-0.41117954
-0.21331418
0.9359986
-0.3840317
-0.92283344
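The Python loop above can also be written as one batched call. This is a sketch, assuming you keep the result of `split()` as a single stacked array instead of unpacking it: `J.vmap` then maps the sampler over the keys, one draw per key.

```python
import jax as J
import jax.numpy as jnp

key = J.random.PRNGKey(42)
subkeys = J.random.split(key, 8)  # one stacked array holding 8 keys

# vmap runs the sampler once per key along the leading axis.
batched = J.vmap(lambda k: J.random.normal(k))(subkeys)

# The explicit loop, for comparison.
looped = jnp.stack([J.random.normal(sk) for sk in subkeys])

print(batched.shape)  # (8,)
print(jnp.allclose(batched, looped))
```

Because each draw depends only on its own key, the draws don't have to happen in any particular order, which is what makes this design friendly to parallel hardware.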


### Randomized Vectors

This was also a reason for Jax's random generator design. Now why do you need vectors, you may ask? A quick reminder: algorithms in machine learning run on vectors and matrices ;)

So far we have been writing hard coded matrices and vectors in these notebooks. This ends now. We use RANDOM!!!!!!!

In [13]:
key = J.random.PRNGKey(99) # why should 42 get all the attention?

In [14]:
a = J.random.normal(key, shape=(10, 1))
a

DeviceArray([[-0.9513214 ],
             [ 1.7788743 ],
             [-1.4580659 ],
             [-0.33244848],
             [ 0.24782881],
             [ 1.0122505 ],
             [ 0.7675285 ],
             [ 0.3141468 ],
             [-0.01992193],
             [ 0.42590362]], dtype=float32)

In [15]:
mat_a = J.random.normal(key, shape=(6, 4))
mat_a

DeviceArray([[ 0.48635894, -1.0345986 , -1.1348429 ,  1.4427332 ],
             [-0.34284604,  1.4199504 ,  1.0333503 , -0.92543155],
             [-1.3507137 , -0.43571466, -0.21265179, -0.7674866 ],
             [ 0.05622242,  0.16131395, -1.1077237 ,  0.71504277],
             [-0.26486132,  0.16493076, -0.4747249 , -0.07925941],
             [ 0.06515129, -0.42943704,  0.10599458, -1.6549199 ]],            dtype=float32)
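The two cells above draw both arrays from the same key. That runs fine, but since a key fully determines the numbers it produces, the usual advice is to give every array its own subkey. Here is a minimal sketch of that, tied back to parameter initialization. `init_layer` and its `scale` default are hypothetical choices for illustration, not a recommended initialization scheme.

```python
import jax as J
import jax.numpy as jnp

def init_layer(key, n_in, n_out, scale=0.01):
    """Hypothetical helper: small random weights and zero biases."""
    weights = scale * J.random.normal(key, shape=(n_out, n_in))
    biases = jnp.zeros((n_out, 1))
    return weights, biases

key = J.random.PRNGKey(99)
# One subkey per layer, so no two layers share their random numbers.
key_1, key_2 = J.random.split(key)

w1, b1 = init_layer(key_1, n_in=4, n_out=6)
w2, b2 = init_layer(key_2, n_in=6, n_out=1)

print(w1.shape, b1.shape)  # (6, 4) (6, 1)
print(w2.shape, b2.shape)  # (1, 6) (1, 1)
```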

### Arrays in Jax

If you're coming fresh out of a beginner Python crash course, arrays are what you've known as lists, but faster, better, more memory efficient, etc. If you have experience working with numpy, Jax is just basic numpy with some magic sauce. The same applies to arrays: Jax arrays look and behave a lot like numpy arrays.

In [16]:
# should we hard code again? nein!

import jax.numpy as jnp

b = jnp.arange(1, 10)
b

DeviceArray([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [17]:
c = b * 2
c

DeviceArray([ 2,  4,  6,  8, 10, 12, 14, 16, 18], dtype=int32)

In [18]:
b / 2

DeviceArray([0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5], dtype=float32)

In [19]:
b + 2

DeviceArray([ 3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=int32)

In [20]:
new_array = jnp.array(
    [b, c]
)
new_array

DeviceArray([[ 1,  2,  3,  4,  5,  6,  7,  8,  9],
             [ 2,  4,  6,  8, 10, 12, 14, 16, 18]], dtype=int32)

In [21]:
new_array.shape

(2, 9)
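If you already have numpy code lying around, moving data between the two worlds is straightforward. Here is a small sketch, assuming numpy is installed alongside Jax. The variable names are made up for the example.

```python
import numpy as np
import jax.numpy as jnp

b = jnp.arange(1, 10)

as_numpy = np.asarray(b)              # Jax array -> numpy array
back_to_jax = jnp.asarray(as_numpy)   # numpy array -> Jax array

print(type(as_numpy).__name__)  # ndarray
print(back_to_jax.shape)        # (9,)

# Broadcasting follows the numpy rules too.
row = jnp.arange(3)                 # shape (3,)
col = jnp.arange(3).reshape(3, 1)   # shape (3, 1)
print((row + col).shape)            # (3, 3)
```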

### Since they are like numpy arrays......

You should be able to change a value at a specific index, right? Erm............

In [22]:
b[2]

DeviceArray(3, dtype=int32)

In [23]:
b[2] = 99
b

TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

JAX arrays are immutable, which means they can't be modified once created. But you may need to modify values at some point in your wonderful experiments with the glorious functions called AI. Jax allows that, but differently: the update creates a new array and leaves the older one untouched.

In [24]:
b.at[2].set(99)

DeviceArray([ 1,  2, 99,  4,  5,  6,  7,  8,  9], dtype=int32)

In [25]:
b

DeviceArray([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

But if I do this:

In [26]:
new_b = b.at[2].set(99)
new_b

DeviceArray([ 1,  2, 99,  4,  5,  6,  7,  8,  9], dtype=int32)

In [27]:
b

DeviceArray([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

Now you have both the old and the new array. All this behaviour seems weird. How dare Jax not let you do what you want? You'll have to wait a bit for the answer. Perhaps another notebook.
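`set` is not the only method hiding behind `.at[]`. Here is a short sketch of a few others, plus the idiom you'd use when you really don't need the old array anymore: rebind the name to the result.

```python
import jax.numpy as jnp

b = jnp.arange(1, 10)

added = b.at[2].add(10)     # add to one element
zeroed = b.at[0:3].set(0)   # slices work as well

print(added)   # [ 1  2 13  4  5  6  7  8  9]
print(zeroed)  # [0 0 0 4 5 6 7 8 9]
print(b)       # [1 2 3 4 5 6 7 8 9], still untouched

# "In-place" in spirit: point the same name at the new array.
b = b.at[2].set(99)
print(b)       # [ 1  2 99  4  5  6  7  8  9]
```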

### More weirdo

```
b.at[100].set(100)
```

Do you think this line will run? `b` only has a length of 9. Whatever you say, I won't blame you. You're neither wrong nor correct in the JAX world.

In [28]:
b.at[100].set(100)

DeviceArray([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

JAX silently ignores out-of-bounds updates and keeps your original array intact. I don't know why it works this way when JAX imposes so many strict restrictions in other areas. So keep track of your indexes in your projects, otherwise you'll be scratching your head over why some value didn't update.
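If the silence worries you, there are a couple of ways to make out-of-bounds access more visible. Here is a sketch: the `mode` and `fill_value` arguments of `.at[].get()` let you choose what an out-of-bounds read returns, and a plain Python check can guard updates when the index is an ordinary integer. `checked_set` is a hypothetical helper written for this example, not part of JAX.

```python
import jax.numpy as jnp

b = jnp.arange(1, 10)

# The update is dropped, so the result equals the original.
print(jnp.array_equal(b.at[100].set(100), b))     # True

# Reads can be told what to do when the index is out of bounds.
print(b.at[100].get(mode="clip"))                 # 9, clamped to the last element
print(b.at[100].get(mode="fill", fill_value=-1))  # -1, a sentinel you can spot

def checked_set(arr, idx, value):
    """Hypothetical helper: refuse out-of-bounds updates on a 1-D array."""
    n = arr.shape[0]
    if not -n <= idx < n:
        raise IndexError(f"index {idx} is out of bounds for length {n}")
    return arr.at[idx].set(value)

print(checked_set(b, 2, 99))  # [ 1  2 99  4  5  6  7  8  9]

try:
    checked_set(b, 100, 100)
except IndexError as err:
    print(err)  # index 100 is out of bounds for length 9
```

The Python-level check only works when the index is a regular integer known before the computation runs, which is the case in everything we've done so far.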