In [1]:
import jax.numpy as jnp
from jax import random, jit

In [8]:
def dropout(x: jnp.ndarray, key, p=0.5, train=True):
    if not train:
        return x, key
    p_keep = 1 - p
    mask = random.bernoulli(key, p_keep, x.shape)
    new_key, _ = random.split(key)  # update key

    return jnp.where(mask, x/p_keep, 0), new_key  # scale here to make E(X) the same while evaluating.

# dropout = jit(dropout, static_argnames=('train'))  # not use jit here

In [11]:
key = random.PRNGKey(41)
x = jnp.ones((5))  # 输入数据

train = True
p = 0.5

import time
s = time.time()
for i in range(3):  # 模拟3个batch
    x_dropout, key = dropout(x, key, p, train)
    print(f"Batch {i+1} 结果:\n", x_dropout)

print(f'time: {time.time() - s}')

Batch 1 结果:
 [0. 2. 2. 2. 0.]
Batch 2 结果:
 [2. 2. 2. 2. 2.]
Batch 3 结果:
 [2. 0. 2. 0. 2.]
time: 0.0011780261993408203
