-
Notifications
You must be signed in to change notification settings - Fork 1
/
bit_flip_env.py
32 lines (25 loc) · 892 Bytes
/
bit_flip_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
class Env(object):
    """Bit-flipping environment over a binary vector of length ``num_bits``.

    Each episode draws a random binary ``state`` and a random binary
    ``target``. An action is the index of one bit of ``state`` to flip. A
    step yields reward ``0`` when ``state`` equals ``target`` after the flip
    and ``-1`` otherwise. The episode ends when the target is reached or when
    the step limit is hit.
    """

    def __init__(self, num_bits):
        """Create an environment whose state and target have ``num_bits`` bits.

        The episode fields stay ``None`` until :meth:`reset` is called.
        """
        self.num_bits = num_bits
        self.done = None
        self.num_steps = None
        self.state = None
        self.target = None

    def reset(self):
        """Start a new episode with a fresh random state and target.

        Returns:
            A ``(state, target)`` pair of integer arrays of length
            ``num_bits`` with entries in ``{0, 1}``. Both are copies, so
            later calls to :meth:`step` do not modify the arrays handed
            back here.
        """
        self.done = False
        self.num_steps = 0
        self.state = np.random.randint(2, size=self.num_bits)
        self.target = np.random.randint(2, size=self.num_bits)
        # step() flips bits of self.state in place; returning the internal
        # arrays would let that mutation leak into the caller's stored
        # initial observation.
        return np.copy(self.state), np.copy(self.target)

    def step(self, action):
        """Flip bit ``action`` of the state and advance the episode one step.

        Args:
            action: Index of the bit to flip.

        Returns:
            A ``(state, reward, done, info)`` tuple: a copy of the state
            after the flip, ``0`` if it equals the target else ``-1``, the
            episode-finished flag, and an empty dict.

        Raises:
            AssertionError: If called before :meth:`reset` or after the
                episode has finished.
        """
        # Raised explicitly (instead of a bare ``assert``) so the guard is
        # still enforced when Python runs with -O.
        if self.state is None:
            raise AssertionError("reset() must be called before step()")
        if self.done:
            raise AssertionError(
                "episode is finished; call reset() before stepping again"
            )

        self.state[action] = 1 - self.state[action]

        # The limit is tested before the counter is incremented, so an
        # episode that never reaches the target ends on step num_bits + 3.
        if self.num_steps > self.num_bits + 1:
            self.done = True
        self.num_steps += 1

        if np.array_equal(self.state, self.target):
            self.done = True
            reward = 0
        else:
            reward = -1
        return np.copy(self.state), reward, self.done, {}