# Bit flipping game with DQN solver

This notebook implements a DQN solver for the bit flipping game introduced in [**Hindsight Experience Replay**](https://arxiv.org/abs/1707.01495) [1].

**Reference**:

1. Marcin Andrychowicz, Filip Wolski, Alex Ray, Jonas Schneider, Rachel Fong, Peter Welinder, Bob McGrew, Josh Tobin, Pieter Abbeel, Wojciech Zaremba. Hindsight Experience Replay. arXiv:1707.01495, 2017.


In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Local modules (not shown here): the game environment and the DQN agent
from bitflipping import bitflipping as bf
from DQN import DQN

plt.rcParams['figure.figsize'] = [15, 20]
%matplotlib inline

## Set up the bit flipping game environment

In [2]:
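n = 4            # number of bits in the game
bf_env = bf(n)   # the environment exposes its own state and goal (bf_env.state, bf_env.goal)

# Not used by the rest of this notebook.
init_state = np.array([0,1])
goal = np.ones((2,))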
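The `bitflipping` module itself is not shown in this notebook. As a rough illustration only, the class below sketches what such an environment might look like, assuming the setup described in the HER paper: the state and the goal are random length-`n` bit vectors, action `i` flips bit `i`, and the reward is 0 when the state matches the goal and -1 otherwise. The class name `BitFlipEnvSketch` is hypothetical; the method names `reset`, `update_state` and `reward` mirror the calls made in the test cell, but their bodies are assumptions, not the module's actual code.

```python
import numpy as np


class BitFlipEnvSketch:
    """Hypothetical n-bit flipping environment (not the real `bitflipping` module)."""

    def __init__(self, n, seed=None):
        self.n = n
        self.rng = np.random.default_rng(seed)
        self.reset()

    def reset(self):
        # Draw a random start state and a random goal, each an n-bit vector.
        self.state = self.rng.integers(0, 2, size=self.n)
        self.goal = self.rng.integers(0, 2, size=self.n)
        return self.state.copy()

    def update_state(self, action):
        # Action i flips bit i (0 <-> 1).
        self.state[action] = 1 - self.state[action]
        return self.state.copy()

    def reward(self, state):
        # Sparse reward: 0 if the state equals the goal, -1 otherwise.
        return 0 if np.array_equal(state, self.goal) else -1


env = BitFlipEnvSketch(4, seed=0)
env.state = np.array([0, 1, 0, 1])
env.goal = np.array([1, 1, 0, 1])
print(env.reward(env.state))  # -1
env.update_state(0)
print(env.reward(env.state))  # 0
```

With this reward, an untrained agent almost never sees anything but -1 for larger `n`, which is the sparse-reward problem that motivates hindsight replay.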

## Build the DQN neural network

In [3]:
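tf.reset_default_graph()

# Network input: the n-bit state concatenated with the n-bit goal.
x = tf.placeholder(tf.float32, shape=(None, 2*n))
# Regression targets for the Q-network (presumably one target Q-value per sampled transition).
y = tf.placeholder(tf.float32, shape=(None, 1))

# One hidden layer with 256 units; the third argument n is presumably the number of actions (one per bit).
hid = [256]
agent = DQN(x, hid, n, discount=0.98, eps=0.5, tau=0.95, replay_buffer_size=1e3, batch_size=32)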
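The constructor arguments look like standard DQN hyperparameters. Because the `DQN` class is not shown, the NumPy snippet below is only a sketch of how such parameters are commonly used, under these assumptions: `eps` is the exploration probability of an epsilon-greedy policy, `discount` is the factor γ in the one-step target y = r + γ·max_a' Q_target(s', a'), and `tau` is a Polyak-averaging coefficient that keeps a fraction `tau` of the old target-network weights at each update. The helper names are hypothetical, and the real class may define `tau` the other way around.

```python
import numpy as np


def epsilon_greedy(q_values, eps, rng):
    # With probability eps pick a uniformly random action, otherwise the greedy one.
    if rng.random() < eps:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))


def td_targets(rewards, q_next_target, dones, discount):
    # One-step Q-learning targets: r + discount * max_a' Q_target(s', a'),
    # with the bootstrap term dropped for terminal transitions (dones == 1).
    return rewards + discount * (1.0 - dones) * q_next_target.max(axis=1)


def soft_update(target_params, online_params, tau):
    # Polyak averaging: new_target = tau * old_target + (1 - tau) * online.
    return [tau * t + (1.0 - tau) * o for t, o in zip(target_params, online_params)]


rng = np.random.default_rng(0)
print(epsilon_greedy(np.array([0.1, 0.7, 0.2]), eps=0.0, rng=rng))  # 1

rewards = np.array([-1.0, 0.0])
q_next = np.array([[-2.0, -1.0], [-0.5, -3.0]])
dones = np.array([0.0, 1.0])
print(td_targets(rewards, q_next, dones, discount=0.98))

print(soft_update([np.zeros(2)], [np.ones(2)], tau=0.95))
```

A `tau` close to 1 makes the target network trail the online network slowly, which keeps the regression targets from shifting on every step.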

In [None]:
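# Presumably 8 epochs of 50 cycles, each cycle running 16 episodes followed by 50 optimization steps;
# the logged success rates are multiples of 1/16, consistent with 16 episodes per cycle.
losses, success_all = agent.train_Q(x, y, epoch=8, cycles=50, episode=16, iteration=50)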

Epoch 0 Cycle 0: loss is 0.0843, success rate 0.125
Epoch 0 Cycle 1: loss is 0.101, success rate 0.1875
Epoch 0 Cycle 2: loss is 0.105, success rate 0.1875
Epoch 0 Cycle 3: loss is 0.107, success rate 0.25
Epoch 0 Cycle 4: loss is 0.109, success rate 0.3125
Epoch 0 Cycle 5: loss is 0.153, success rate 0.0625
Epoch 0 Cycle 6: loss is 0.085, success rate 0.1875
Epoch 0 Cycle 7: loss is 0.121, success rate 0.3125
Epoch 0 Cycle 8: loss is 0.128, success rate 0.1875
Epoch 0 Cycle 9: loss is 0.0995, success rate 0.25
Epoch 0 Cycle 10: loss is 0.106, success rate 0.0625
Epoch 0 Cycle 11: loss is 0.146, success rate 0.25
Epoch 0 Cycle 12: loss is 0.143, success rate   0
Epoch 0 Cycle 13: loss is 0.154, success rate 0.3125
Epoch 0 Cycle 14: loss is 0.171, success rate 0.3125
Epoch 0 Cycle 15: loss is 0.163, success rate 0.0625
Epoch 0 Cycle 16: loss is 0.205, success rate 0.3125
Epoch 0 Cycle 17: loss is 0.158, success rate 0.25
Epoch 0 Cycle 18: loss is 0.191, success rate 0.0625
...

Epoch 3 Cycle 8: loss is 0.894, success rate 0.6875
Epoch 3 Cycle 9: loss is 1.15, success rate 0.75
Epoch 3 Cycle 10: loss is 0.912, success rate 0.875
Epoch 3 Cycle 11: loss is 1.16, success rate 0.8125
Epoch 3 Cycle 12: loss is 1.04, success rate 0.9375
Epoch 3 Cycle 13: loss is 0.926, success rate 0.9375
Epoch 3 Cycle 14: loss is 0.792, success rate   1
Epoch 3 Cycle 15: loss is 1.32, success rate 0.8125
Epoch 3 Cycle 16: loss is 1.06, success rate 0.625
Epoch 3 Cycle 17: loss is 0.801, success rate 0.9375
Epoch 3 Cycle 18: loss is 0.71, success rate 0.875
Epoch 3 Cycle 19: loss is 0.797, success rate   1
Epoch 3 Cycle 20: loss is 0.952, success rate 0.9375
Epoch 3 Cycle 21: loss is 1.02, success rate 0.625
Epoch 3 Cycle 22: loss is 1.03, success rate 0.75
Epoch 3 Cycle 23: loss is 1.11, success rate 0.6875
Epoch 3 Cycle 24: loss is 0.981, success rate   1
Epoch 3 Cycle 25: loss is 1.05, success rate 0.75
Epoch 3 Cycle 26: loss is 1.1, success rate 0.8125
...
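In the log above, the per-cycle success rate stays between 0 and 0.3125 in epoch 0 and mostly between 0.625 and 1 in epoch 3. This notebook does not show whether `DQN.train_Q` stores hindsight transitions. For reference, the sketch below illustrates the "final" relabeling strategy from the HER paper: every transition is stored once with the original goal and once more with the goal replaced by the state actually reached at the end of the episode, with the reward recomputed for that substitute goal. The transition layout `(state, action, next_state, goal)` and the function names are assumptions made for illustration.

```python
import numpy as np


def sparse_reward(state, goal):
    # 0 when the state matches the goal, -1 otherwise.
    return 0.0 if np.array_equal(state, goal) else -1.0


def her_relabel_final(episode):
    """Return the original transitions plus copies relabeled with the final state.

    episode: list of (state, action, next_state, goal) tuples (hypothetical layout).
    Each returned transition is (state, action, reward, next_state, goal).
    """
    final_state = episode[-1][2]  # state actually reached at the end of the episode
    transitions = []
    for state, action, next_state, goal in episode:
        # Original transition with the intended goal.
        transitions.append((state, action, sparse_reward(next_state, goal), next_state, goal))
        # Hindsight copy: pretend the final state was the goal all along.
        transitions.append((state, action, sparse_reward(next_state, final_state),
                            next_state, final_state))
    return transitions


episode_goal = np.array([1, 1, 1])
demo_episode = [
    (np.array([0, 0, 0]), 0, np.array([1, 0, 0]), episode_goal),
    (np.array([1, 0, 0]), 2, np.array([1, 0, 1]), episode_goal),
]
# Only the last hindsight copy (goal [1, 0, 1]) earns reward 0.0.
for s, a, r, s2, g in her_relabel_final(demo_episode):
    print(a, r, g)
```

Even though this episode never reaches `[1, 1, 1]`, the relabeled copy still provides a transition with reward 0, giving the Q-network a useful learning signal despite the sparse reward.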

In [None]:
plt.figure()
plt.plot(losses)
plt.ylabel('Loss')
plt.title('Training loss')
plt.show()
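The per-cycle loss in the log above fluctuates considerably from one cycle to the next. If the raw curve is hard to read, a trailing moving average can be overlaid on it. The helper `moving_average` and any window size are illustrative choices, not part of the notebook, and applying it assumes `losses` is a 1-D sequence of floats.

```python
import numpy as np


def moving_average(values, window):
    # Mean over each run of `window` consecutive values ('valid' mode, so the
    # result is window - 1 elements shorter than the input).
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='valid')


print(moving_average([1.0, 2.0, 3.0, 4.0], window=2))  # [1.5 2.5 3.5]
```

In the plotting cell one could then add, for example, `plt.plot(moving_average(losses, 10))` next to the raw curve.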

## Test DQN

In [None]:
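with tf.Session() as sess:
    # Restore trained weights from the checkpoint at /tmp/model.ckpt.
    saver = tf.train.Saver()
    saver.restore(sess, '/tmp/model.ckpt')

    n_test = 100
    success = 0
    for episode in range(n_test):
        bf_env.reset()

        # Allow at most n bit flips per episode.
        for t in range(n):
            # Network input: current state concatenated with the goal, shape (1, 2n).
            X = np.concatenate((bf_env.state.reshape((1, -1)), bf_env.goal.reshape((1, -1))), axis=1)
            # Q-values from the agent's target network.
            Q = sess.run(agent.targetModel, feed_dict={x: X})
            action = np.argmax(Q)  # greedy action, no exploration at test time
            bf_env.update_state(action)
            if bf_env.reward(bf_env.state) == 0:
                print('Success! state:{0}\t Goal state:{1}'.format(bf_env.state, bf_env.goal))
                success += 1
                break
            elif t == n - 1:
                print('Fail! state:{0}\t Goal state:{1}'.format(bf_env.state, bf_env.goal))

    print('Success rate {:.1f}%'.format(100.0 * success / n_test))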
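The test loop allows `n` actions per episode. That budget is sufficient in principle: each action flips exactly one bit, so the minimum number of actions equals the Hamming distance between the start state and the goal, which is at most `n`. The hypothetical oracle below, which simply flips the first mismatched bit, makes this concrete and gives a reference policy that always succeeds within the budget; it is an illustration, not part of the notebook.

```python
import numpy as np


def hamming_distance(state, goal):
    # Number of differing bits = minimum number of single-bit flips needed.
    return int(np.sum(state != goal))


def oracle_action(state, goal):
    # Flip the first bit that disagrees with the goal (requires state != goal).
    return int(np.flatnonzero(state != goal)[0])


def oracle_rollout(state, goal):
    # Follow the oracle until the goal is reached; return the number of flips used.
    state = state.copy()
    steps = 0
    while not np.array_equal(state, goal):
        a = oracle_action(state, goal)
        state[a] = 1 - state[a]
        steps += 1
    return steps


rng = np.random.default_rng(0)
n_bits = 4
for _ in range(5):
    s = rng.integers(0, 2, size=n_bits)
    g = rng.integers(0, 2, size=n_bits)
    # The oracle always uses exactly hamming_distance(s, g) <= n_bits flips.
    print(s, g, oracle_rollout(s, g), hamming_distance(s, g))
```

A learned policy that reaches 100% success within `n` steps is therefore matching what is achievable, while failures indicate that the greedy Q-values point to a wrong bit at some step.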

In [None]:
a=np.array([[1,2,3,2,1,3]])

In [None]:
a=np.array([3,1,2])

In [None]:
a[a<0] += 2

In [None]:
a

In [None]:
(a==None).all()

In [None]:
s=np.argmax(a)