# Dropout Mechanism

In [1]:
import jax.numpy as jnp
from jax import random, jit

## Big No-No!

To make sure the neurons selected for dropout differ from batch to batch, you should split the key you use. Otherwise the same neurons are dropped for every batch, just like this:

In [2]:
def dropout(x: jnp.ndarray, key, p=0.5, train=True):
    """Naive inverted dropout that uses ``key`` directly, without splitting it.

    This is the "what not to do" version: because the key is neither split
    nor returned, calling it repeatedly with the same ``key`` produces the
    same mask every time.

    Args:
        x: Input array.
        key: JAX PRNG key used as-is to sample the keep-mask.
        p: Probability of dropping each element.
        train: When False, dropout is disabled and ``x`` is returned unchanged.

    Returns:
        An array shaped like ``x``. In training mode the kept elements are
        scaled by ``1 / (1 - p)`` and the dropped ones are set to 0.
    """
    if not train:
        # Return the bare array so both branches have the same return type
        # (the training branch below returns a single array, not a tuple).
        return x
    p_keep = 1 - p
    mask = random.bernoulli(key, p_keep, x.shape)

    # Scale kept values so the expected activation matches evaluation mode.
    return jnp.where(mask, x / p_keep, 0)

In [3]:
key = random.PRNGKey(41)
x = jnp.ones((100))  # Input data

p = 0.5

import time
# perf_counter is a monotonic, high-resolution clock meant for measuring intervals.
s = time.perf_counter()
for i in range(1000):  # batches
    # The same `key` is passed on every iteration, so every batch gets the same mask.
    x_dropout = dropout(x, key, p, True)
    if i % 200 == 0:
        print(f"Batch {i + 1} :\n", x_dropout[::5])

# JAX dispatches work asynchronously: wait for the last result so the timer
# covers the computation itself and not just the time needed to enqueue it.
x_dropout.block_until_ready()
print(f'time: {time.perf_counter() - s}')

Batch 1 :
 [0. 2. 0. 2. 0. 2. 2. 2. 0. 2. 2. 0. 2. 2. 2. 0. 2. 2. 0. 2.]
Batch 201 :
 [0. 2. 0. 2. 0. 2. 2. 2. 0. 2. 2. 0. 2. 2. 2. 0. 2. 2. 0. 2.]
Batch 401 :
 [0. 2. 0. 2. 0. 2. 2. 2. 0. 2. 2. 0. 2. 2. 2. 0. 2. 2. 0. 2.]
Batch 601 :
 [0. 2. 0. 2. 0. 2. 2. 2. 0. 2. 2. 0. 2. 2. 2. 0. 2. 2. 0. 2.]
Batch 801 :
 [0. 2. 0. 2. 0. 2. 2. 2. 0. 2. 2. 0. 2. 2. 2. 0. 2. 2. 0. 2.]
time: 0.213700532913208


Every batch gets the same mask. To avoid this, split the key inside the function, return the new key explicitly, and have the caller replace the key it passed in with the returned one.

In [4]:
def dropout(x: jnp.ndarray, key, p=0.5, train=True):
    """Inverted dropout that threads the PRNG key through the call.

    Args:
        x: Input array.
        key: JAX PRNG key. It is split once per training call.
        p: Probability of dropping each element.
        train: When False, dropout is disabled.

    Returns:
        A tuple ``(out, key_out)``. In training mode ``out`` is ``x`` with
        dropped elements set to 0 and kept elements scaled by ``1 / (1 - p)``,
        and ``key_out`` is a fresh key to pass to the next call. In eval mode
        ``x`` and ``key`` are returned unchanged.
    """
    if train:
        keep_prob = 1 - p
        # First half of the split is handed back to the caller, second half
        # is consumed here to sample the mask.
        next_key, mask_key = random.split(key)
        keep = random.bernoulli(mask_key, keep_prob, x.shape)
        # Rescale kept values so the expected activation matches eval mode.
        scaled = x / keep_prob
        return jnp.where(keep, scaled, 0), next_key

    return x, key

In [5]:
key = random.PRNGKey(41)
x = jnp.ones((100))  # Input data

p = 0.5

import time
# perf_counter is a monotonic, high-resolution clock meant for measuring intervals.
s = time.perf_counter()
for i in range(1000):  # batches
    # Rebind `key` to the key returned by dropout so the next batch uses a new one.
    x_dropout, key = dropout(x, key, p, True)
    if i % 200 == 0:
        print(f"Batch {i + 1} :\n", x_dropout[::5])

# JAX dispatches work asynchronously: wait for the last result so the timer
# covers the computation itself and not just the time needed to enqueue it.
x_dropout.block_until_ready()
print(f'time: {time.perf_counter() - s}')

Batch 1 :
 [0. 0. 0. 0. 2. 0. 2. 2. 0. 0. 0. 2. 2. 0. 0. 0. 0. 0. 0. 0.]
Batch 201 :
 [0. 0. 2. 0. 2. 0. 2. 0. 2. 0. 0. 0. 2. 0. 0. 2. 0. 0. 2. 2.]
Batch 401 :
 [2. 0. 2. 0. 2. 0. 2. 2. 2. 0. 2. 0. 2. 0. 0. 0. 0. 2. 0. 2.]
Batch 601 :
 [0. 0. 2. 2. 0. 2. 2. 0. 2. 2. 2. 0. 2. 0. 2. 2. 0. 2. 0. 0.]
Batch 801 :
 [2. 2. 2. 2. 2. 0. 2. 0. 2. 2. 0. 2. 2. 0. 0. 2. 2. 0. 0. 2.]
time: 0.15825676918029785


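If you want to use `jit`, you should mark the `train` parameter as static. Otherwise `train` becomes a traced value, and this Python-level branch, which needs a concrete value of `train` to decide what the function returns, cannot be evaluated while tracing: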

```python
if not train:
    return x, key
```

In [6]:
# `train` drives a Python-level `if` inside dropout, so it is declared static
# (a single argument name can be passed as a plain string).
dropout = jit(dropout, static_argnames='train')  # use jit here

In [14]:
key = random.PRNGKey(41)
x = jnp.ones((100))  # Input data

p = 0.5

import time
# perf_counter is a monotonic, high-resolution clock meant for measuring intervals.
# NOTE: if this is the first call since `dropout` was wrapped with jit, the
# measured time also includes tracing/compilation.
s = time.perf_counter()
for i in range(1000):  # batches
    # Rebind `key` to the key returned by dropout so the next batch uses a new one.
    x_dropout, key = dropout(x, key, p, True)
    if i % 200 == 0:
        print(f"Batch {i + 1} :\n", x_dropout[::5])

# JAX dispatches work asynchronously: wait for the last result so the timer
# covers the computation itself and not just the time needed to enqueue it.
x_dropout.block_until_ready()
print(f'time with JIT: {time.perf_counter() - s}')

Batch 1 :
 [0. 0. 0. 0. 2. 0. 2. 2. 0. 0. 0. 2. 2. 0. 0. 0. 0. 0. 0. 0.]
Batch 201 :
 [0. 0. 2. 0. 2. 0. 2. 0. 2. 0. 0. 0. 2. 0. 0. 2. 0. 0. 2. 2.]
Batch 401 :
 [2. 0. 2. 0. 2. 0. 2. 2. 2. 0. 2. 0. 2. 0. 0. 0. 0. 2. 0. 2.]
Batch 601 :
 [0. 0. 2. 2. 0. 2. 2. 0. 2. 2. 2. 0. 2. 0. 2. 2. 0. 2. 0. 0.]
Batch 801 :
 [2. 2. 2. 2. 2. 0. 2. 0. 2. 2. 0. 2. 2. 0. 0. 2. 2. 0. 0. 2.]
time with JIT: 0.014821052551269531
