In [1]:
import numpy as np
import tensorflow as tf

from collections import namedtuple

from lead_binary import DealMeta, suit_index_lookup
from binary_declarer import binary_hand, get_card_index

In [2]:
# Session shared by the rest of the notebook: the checkpoint is restored into it
# and the prediction wrapper built by model() evaluates tensors through it.
sess = tf.InteractiveSession()

In [3]:
# Checkpoint prefix. '.meta' is appended to it to locate the graph definition,
# and the prefix itself is passed to saver.restore().
model_path = './declarer_model2/declarer-940000'

In [4]:
# Import the graph structure from the checkpoint's .meta file, then restore the
# saved parameters from the checkpoint into `sess`.
saver = tf.train.import_meta_graph(model_path + '.meta')
saver.restore(sess, model_path)

INFO:tensorflow:Restoring parameters from ./declarer_model2/declarer-940000


In [5]:
# Handle on the default graph; the cells below look tensors up in it by name.
graph = tf.get_default_graph()

In [6]:
# Named tensors fetched from the restored graph. Of these, only `keep_prob` is
# used by the cells visible in this notebook (it is fed on every prediction).
seq_in = graph.get_tensor_by_name('seq_in:0')
seq_out = graph.get_tensor_by_name('seq_out:0')
keep_prob = graph.get_tensor_by_name('keep_prob:0')

# NOTE(review): presumably the training-time logits/targets — not referenced by
# any later cell shown here; confirm whether they are still needed.
out_card_logit = graph.get_tensor_by_name('out_card_logit:0')
out_card_target = graph.get_tensor_by_name('out_card_target:0')

In [7]:
# Short alias for the by-name tensor lookup on the restored graph.
_tensor = graph.get_tensor_by_name

# Incoming recurrent state, one (c, h) pair per index 0..4.
state_c_0, state_h_0 = _tensor('state_c_0:0'), _tensor('state_h_0:0')
state_c_1, state_h_1 = _tensor('state_c_1:0'), _tensor('state_h_1:0')
state_c_2, state_h_2 = _tensor('state_c_2:0'), _tensor('state_h_2:0')
state_c_3, state_h_3 = _tensor('state_c_3:0'), _tensor('state_h_3:0')
state_c_4, state_h_4 = _tensor('state_c_4:0'), _tensor('state_h_4:0')

# Outgoing recurrent state after one step, same (c, h) pairing.
next_c_0, next_h_0 = _tensor('next_c_0:0'), _tensor('next_h_0:0')
next_c_1, next_h_1 = _tensor('next_c_1:0'), _tensor('next_h_1:0')
next_c_2, next_h_2 = _tensor('next_c_2:0'), _tensor('next_h_2:0')
next_c_3, next_h_3 = _tensor('next_c_3:0'), _tensor('next_h_3:0')
next_c_4, next_h_4 = _tensor('next_c_4:0'), _tensor('next_h_4:0')

# Single-step input vector and the per-card output fetched by the predictor.
x_in = _tensor('x_in:0')
out_card = _tensor('out_card:0')

In [8]:
x_in.shape, out_card.shape

(TensorShape([Dimension(1), Dimension(147)]),
 TensorShape([Dimension(1), Dimension(32)]))

In [9]:
# One recurrent-state entry: cell state `c` and hidden state `h`.
State = namedtuple('State', ['c', 'h'])

lstm_size = 128

# Initial state: five State entries (one per state_c_i/state_h_i pair), each
# holding its own freshly allocated (1, lstm_size) zero arrays.
zero_state = tuple(
    State(c=np.zeros((1, lstm_size)), h=np.zeros((1, lstm_size)))
    for _ in range(5)
)

In [10]:
def model(sess, p_keep=1.0):
    """Wrap the restored graph as a single-step prediction function.

    Args:
        sess: session used to evaluate the graph tensors.
        p_keep: value fed to the `keep_prob` tensor on every call.

    Returns:
        A function ``pred_fun(x, state_in) -> (cards, next_state)``:

        * ``x`` is fed to `x_in`.
        * ``state_in`` is an indexable sequence of five entries exposing ``.c``
          and ``.h``; entry ``i`` is fed to ``state_c_i`` / ``state_h_i``.
          An ``IndexError`` is raised if fewer than five entries are supplied.
        * ``cards`` is the evaluated `out_card` tensor.
        * ``next_state`` is a tuple of five ``State(c, h)`` entries evaluated
          from ``next_c_i`` / ``next_h_i``.
    """
    def pred_fun(x, state_in):
        # Tensors are looked up when the prediction runs (not when model() is
        # called), so rebinding the module-level tensors is still picked up.
        state_inputs = (
            (state_c_0, state_h_0),
            (state_c_1, state_h_1),
            (state_c_2, state_h_2),
            (state_c_3, state_h_3),
            (state_c_4, state_h_4),
        )
        state_outputs = (
            (next_c_0, next_h_0),
            (next_c_1, next_h_1),
            (next_c_2, next_h_2),
            (next_c_3, next_h_3),
            (next_c_4, next_h_4),
        )

        feed_dict = {keep_prob: p_keep, x_in: x}
        for layer, (c_in, h_in) in enumerate(state_inputs):
            # Every layer is fed its own state. Previously layers 3 and 4 were
            # fed layer 2's state, discarding their own recurrent history.
            feed_dict[c_in] = state_in[layer].c
            feed_dict[h_in] = state_in[layer].h

        fetches = [out_card]
        for c_out, h_out in state_outputs:
            fetches.extend((c_out, h_out))

        # Evaluate everything in one run: the outputs all come from the same
        # forward pass (this matters when p_keep < 1) and the graph is executed
        # once instead of eleven times.
        results = sess.run(fetches, feed_dict=feed_dict)

        cards = results[0]
        next_state = tuple(
            State(c=results[1 + 2 * layer], h=results[2 + 2 * layer])
            for layer in range(len(state_outputs))
        )
        return cards, next_state
    return pred_fun

In [11]:
# Single-step predictor bound to `sess`, with the default p_keep of 1.0.
decl = model(sess)

In [12]:
# Alternative example deal and outcome, kept for reference (disabled):
# deal_str = 'W:T9.T86.AK3.AT873 AKQ87.AKQ5.Q862. J65.7432.J94.KJ6 432.J9.T75.Q9542'

# outcome = 'S - 4S.=.N'

# Active deal. The next cell drops the first two characters ('W:'), splits the
# rest on whitespace into four hands and each hand on '.' into four holdings
# (an empty holding, as in the first hand here, becomes an empty list).
deal_str = 'W:AJT9.J84.JT9872. 6.Q9.AKQ3.KJ9743 Q87542.A73.6.AQ8 K3.KT652.54.T652'

# Parsed by DealMeta.from_str in a later cell to obtain the contract metadata.
outcome = 'W NS 4S.+2.E'

In [13]:
# hands[i][j] is the list of card characters in holding j of hand i, in the
# order they appear in deal_str (the two-character prefix is skipped).
hands = [
    [list(holding) for holding in hand_str.split('.')]
    for hand_str in deal_str[2:].split()
]

In [14]:
hands

[[['A', 'J', 'T', '9'], ['J', '8', '4'], ['J', 'T', '9', '8', '7', '2'], []],
 [['6'], ['Q', '9'], ['A', 'K', 'Q', '3'], ['K', 'J', '9', '7', '4', '3']],
 [['Q', '8', '7', '5', '4', '2'], ['A', '7', '3'], ['6'], ['A', 'Q', '8']],
 [['K', '3'], ['K', 'T', '6', '5', '2'], ['5', '4'], ['T', '6', '5', '2']]]

In [15]:
# Encode the third hand of deal_str as declarer and the first hand as dummy.
# NOTE(review): presumably chosen to match the declarer named in `outcome` —
# the indices are hard-coded, so confirm them whenever the deal is changed.
declarer_bin = binary_hand(hands[2])
dummy_bin = binary_hand(hands[0])

In [16]:
dummy_bin.reshape((4, 8))

array([[ 1.,  0.,  0.,  1.,  1.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.,  0.,  1.,  1.],
       [ 0.,  0.,  0.,  1.,  1.,  1.,  1.,  2.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float16)

In [17]:
# Contract metadata for the deal; get_x_in reads its level, strain, doubled,
# declarer_vuln() and leader_vuln().
deal_meta = DealMeta.from_str(outcome)

In [18]:
def get_x_in(declarer_bin, dummy_bin, card0, card1, d_meta, on_lead, on_play):
    """Assemble the (1, 147) float16 input vector for one prediction step.

    Column layout written by this function:
        0:32    copy of `dummy_bin`
        32:64   copy of `declarer_bin`
        64 + get_card_index(card0)   set to 1, unless `card0` is '>>'
        96 + get_card_index(card1)   set to 1
        128 + on_lead                set to 1
        132 + on_play                set to 1
        137     d_meta.level
        138     set to 1 when d_meta.strain is 'N'; otherwise column
                139 + suit_index_lookup[d_meta.strain] is set to 1
        144     int(d_meta.doubled)
        145     int(d_meta.declarer_vuln())
        146     int(d_meta.leader_vuln())
    All other columns stay zero.

    Returns:
        The newly allocated numpy array of shape (1, 147), dtype float16.
    """
    features = np.zeros((1, 147), np.float16)
    row = features[0]  # view onto the single row; writes go into `features`

    # Contract description.
    row[137] = d_meta.level
    strain_col = 138 if d_meta.strain == 'N' else 139 + suit_index_lookup[d_meta.strain]
    row[strain_col] = 1
    row[144] = int(d_meta.doubled)
    row[145] = int(d_meta.declarer_vuln())
    row[146] = int(d_meta.leader_vuln())

    # Position markers.
    row[128 + on_lead] = 1
    row[132 + on_play] = 1

    # Hand encodings.
    row[0:32] = dummy_bin[:, :]
    row[32:64] = declarer_bin[:, :]

    # Card markers; '>>' means there is no card0 to record.
    if card0 != '>>':
        row[64 + get_card_index(card0)] = 1
    row[96 + get_card_index(card1)] = 1

    return features

In [19]:
# Start the recurrent state from all zeros before the first prediction.
next_state = zero_state

In [44]:
# Subtract one from dummy's encoding at the index get_card_index returns for 'D5'.
# NOTE(review): this mutates dummy_bin in place and is not idempotent — every
# re-run subtracts again. Its execution count (44) is also out of order with the
# surrounding cells, so the displayed outputs may not match a fresh Run All.
dummy_bin[0, get_card_index('D5')] -= 1

In [40]:
# Subtract one from declarer's encoding at the index get_card_index returns for 'DQ'.
# NOTE(review): in-place and not idempotent — every re-run subtracts again. The
# execution count (40) is out of order with the surrounding cells.
declarer_bin[0, get_card_index('DQ')] -= 1

In [20]:
# Build one input vector: card0 is '>>' (so no card0 column is set), card1 is
# 'C2', on_lead=0 and on_play=1.
x = get_x_in(declarer_bin, dummy_bin, '>>', 'C2', deal_meta, 0, 1)

In [21]:
# Run one prediction step. `next_state` is rebound to the returned state, so
# re-running this cell feeds the previous run's output state back in rather
# than starting again from zero_state.
card, next_state = decl(x, next_state)

In [22]:
card.reshape((4, 8))

array([[  3.79360776e-04,   5.06378361e-04,   2.64651934e-03,
          1.04095484e-03,   2.08720285e-03,   3.46488327e-01,
          3.37117926e-05,   3.67203538e-05],
       [  3.78908262e-05,   4.04609782e-05,   6.53797542e-05,
          1.26800779e-03,   1.88201142e-04,   8.72436154e-04,
          8.94523342e-04,   3.09821695e-01],
       [  4.71577841e-05,   7.67370148e-05,   1.51492204e-04,
          7.73113628e-04,   4.55712696e-04,   1.73731986e-03,
          5.62496157e-03,   3.24690968e-01],
       [  3.67561256e-07,   5.27790417e-06,   1.48076344e-06,
          9.35530352e-06,   1.40150214e-05,   3.54903773e-06,
          4.59555423e-07,   1.81333078e-07]], dtype=float32)

In [57]:
dummy_bin.reshape(4, 8)

array([[ 1.,  0.,  0.,  1.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.,  0.,  1.,  1.],
       [ 0.,  0.,  0.,  1.,  1.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float16)

In [50]:
declarer_bin.reshape(4, 8)

array([[ 1.,  1.,  1.,  0.,  0.,  0.,  1.,  0.],
       [ 1.,  1.,  0.,  0.,  0.,  0.,  0.,  1.],
       [ 0.,  0.,  1.,  0.,  0.,  0.,  1.,  2.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float16)

In [37]:
state_h_2

<tf.Tensor 'state_h_2:0' shape=(1, 128) dtype=float32>

In [None]:
# Evaluate out_card directly, feeding only keep_prob and x_in.
# NOTE(review): unlike pred_fun, none of the state_c_i/state_h_i tensors are
# fed here. If they are placeholders without defaults this call will fail —
# confirm. This cell has no execution count, i.e. it was never run.
sess.run(out_card, {
    keep_prob: 1.0,
    x_in: x
})