<a href="https://colab.research.google.com/github/yblee110/jax-flax-book/blob/main/ch02_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
# Seed NumPy's legacy global RNG so the random draws in the cells that follow are repeatable.
np.random.seed(0)

In [2]:
def print_truncated_random_state(max_chars=460):
    """Print the leading part of NumPy's legacy global RNG state.

    The full state is too long to display usefully, so only the first
    ``max_chars`` characters of its string form are printed, followed by
    a ``...`` marker.

    Args:
        max_chars: Number of leading characters of ``str(state)`` to show.
            Defaults to 460, the cut-off this notebook has always used, so
            calling the function with no argument prints exactly what it
            printed before.
    """
    state_text = str(np.random.get_state())
    print(state_text[:max_chars], '...')


print_truncated_random_state()



('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...


In [3]:
# Reset the global RNG to seed 0 and show its state before any draw.
np.random.seed(0)
print_truncated_random_state()
# Draw one sample and discard it; the draw itself advances the global state,
# which is why the second printout differs from the first.
_ = np.random.uniform()
print_truncated_random_state()


('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...


In [4]:
# Re-seed the global RNG, then draw three uniform samples in a single call.
np.random.seed(0)
print(np.random.uniform(size=3))

[0.5488135  0.71518937 0.60276338]


In [5]:
# Draw three samples one call at a time from a freshly seeded global RNG ...
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

# ... then re-seed and draw all three in one call. The recorded output shows
# both approaches yield the same sequence.
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))

individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338]


In [6]:
import numpy as np

np.random.seed(0)


def bar():
    """Draw one uniform sample from NumPy's global RNG."""
    return np.random.uniform()


def baz():
    """Draw one uniform sample from NumPy's global RNG."""
    return np.random.uniform()


def foo():
    """Return ``bar() + 2 * baz()``.

    Both helpers consume the same global RNG, so the value depends on
    bar() being evaluated before baz().
    """
    first_draw = bar()
    second_draw = baz()
    return first_draw + 2 * second_draw


print(foo())

1.9791922366721637


In [7]:
# New API: draw from an explicit Generator object created with a seed.
from numpy.random import default_rng
rng = default_rng(seed=0)
vals = rng.standard_normal(10)
more_vals = rng.standard_normal(10)

# Legacy API: module-level functions on numpy.random.
# These assignments overwrite the `vals` / `more_vals` produced above.
# NOTE: this binds the name `random` to numpy.random; a later cell's
# `from jax import random` rebinds it to jax.random.
from numpy import random
vals = random.standard_normal(10)
more_vals = random.standard_normal(10)


In [8]:
from numpy.random import SeedSequence, default_rng

# A single root SeedSequence spawns ten child sequences, and each child
# seeds its own Generator.
ss = SeedSequence(12345)
child_seeds = ss.spawn(10)
streams = list(map(default_rng, child_seeds))
print(streams)


[Generator(PCG64) at 0x7F951A307300, Generator(PCG64) at 0x7F951A307840, Generator(PCG64) at 0x7F951A3074C0, Generator(PCG64) at 0x7F951A3075A0, Generator(PCG64) at 0x7F951A307920, Generator(PCG64) at 0x7F951A307680, Generator(PCG64) at 0x7F951A307760, Generator(PCG64) at 0x7F951A307BC0, Generator(PCG64) at 0x7F951A2525E0, Generator(PCG64) at 0x7F951A252420]


In [9]:
import numpy as np
from numpy.random import SeedSequence, default_rng

# Spawn two child seeds from one root and build a separate Generator for each,
# so bar() and baz() no longer share RNG state.
ss = SeedSequence(12345)
seeds = ss.spawn(2)
stream = list(map(default_rng, seeds))


def bar():
    """Draw one uniform sample from the first Generator in ``stream``."""
    return stream[0].uniform()


def baz():
    """Draw one uniform sample from the second Generator in ``stream``."""
    return stream[1].uniform()


def foo():
    """Return ``bar() + 2 * baz()``, each term drawn from its own Generator."""
    first_draw = bar()
    second_draw = baz()
    return first_draw + 2 * second_draw


print(foo())


1.6241496684412051


In [10]:
# NOTE: rebinds `random` (bound to numpy.random by an earlier cell) to jax.random.
from jax import random
# Build an explicit PRNG key from the integer seed 42; the key is passed to
# jax.random functions instead of relying on hidden global state.
key = random.PRNGKey(42)
print(key)

[ 0 42]


In [11]:
# The same key is passed to both calls; the recorded output shows the two
# samples are identical, i.e. sampling does not advance the key.
print(random.normal(key))
print(random.normal(key))

-0.18471177
-0.18471177


In [13]:
# NOTE(review): the recorded output shows an "old key" that differs from the
# [ 0 42] printed when `key` was created, and the execution counter skips from
# 11 to 13 -- presumably this cell ran after `key` had already been advanced.
# A fresh Restart & Run All would likely print different values; confirm.
print("old key", key)
new_key, subkey = random.split(key)
del key  # Discard the old key so it cannot be reused later.
normal_sample = random.normal(subkey)
print(r"    \---SPLIT --> new key   ", new_key)
print(r"             \--> new subkey", subkey, "--> normal", normal_sample)
del subkey  # The subkey should also be discarded once it has been used.
key = new_key  # new_key becomes `key`, the one to split the next time fresh keys are needed.

old key [2465931498 3679230171]
    \---SPLIT --> new key    [3164236999 3984487275]
             \--> new subkey [3923418436 1366451097] --> normal -0.19947024


In [14]:
# Compact form of the previous cell: split `key`, rebind `key` to the first
# result and keep the second as `subkey`.
key, subkey = random.split(key)

In [15]:
# Draw three normal samples one at a time, each from its own subkey ...
key = random.PRNGKey(42)
subkeys = random.split(key, 3)
sequence = np.stack(list(map(random.normal, subkeys)))
print("individually:", sequence)
# ... then draw three at once from a single key. The recorded output shows
# that the two approaches give different numbers.
key = random.PRNGKey(42)
print("all at once: ", random.normal(key, (3,)))


individually: [-0.04838832  0.10796154 -1.2226542 ]
all at once:  [ 0.18693547 -1.2806505  -1.5593132 ]
