In [20]:
import numpy as np
import tensorflow as tf

from collections import namedtuple

from lead_binary import DealMeta, suit_index_lookup
from binary_righty import binary_hand, get_card_index, encode_card

In [2]:
# TF1 interactive session; used below to restore the checkpoint and to run predictions.
sess = tf.InteractiveSession()

In [3]:
# Checkpoint prefix (relative path): '<prefix>.meta' holds the graph definition and the
# prefix itself is passed to saver.restore(). The trailing number is presumably the
# training step — TODO confirm.
model_path = './righty_model/righty-350000'

In [4]:
# Rebuild the saved graph from the .meta file, then load the trained variable values into sess.
saver = tf.train.import_meta_graph(model_path + '.meta')
saver.restore(sess, model_path)

INFO:tensorflow:Restoring parameters from ./righty_model/righty-350000


In [5]:
# Handle to the default graph (where import_meta_graph placed the model), used for name lookups.
graph = tf.get_default_graph()

In [6]:
# Sequence-level tensors and the dropout-style keep_prob placeholder of the restored graph,
# looked up by their saved names.
seq_in, seq_out, keep_prob = (
    graph.get_tensor_by_name(tensor_name)
    for tensor_name in ('seq_in:0', 'seq_out:0', 'keep_prob:0')
)

# Card logits and targets (both shown as (None, 32) by their static shapes).
out_card_logit, out_card_target = (
    graph.get_tensor_by_name(tensor_name)
    for tensor_name in ('out_card_logit:0', 'out_card_target:0')
)

In [126]:
# Inspect static shapes of the card logits/targets and of seq_out (axis meanings not visible here).
out_card_logit.shape, out_card_target.shape, seq_out.shape

(TensorShape([Dimension(None), Dimension(32)]),
 TensorShape([Dimension(None), Dimension(32)]),
 TensorShape([Dimension(None), Dimension(None), Dimension(32)]))

In [8]:
# Per-layer (0, 1, 2) LSTM state tensors: state_* are fed as the incoming state,
# next_* are fetched as the state after one step.
state_c_0, state_h_0, state_c_1, state_h_1, state_c_2, state_h_2 = (
    graph.get_tensor_by_name(tensor_name)
    for tensor_name in (
        'state_c_0:0', 'state_h_0:0',
        'state_c_1:0', 'state_h_1:0',
        'state_c_2:0', 'state_h_2:0',
    )
)

next_c_0, next_h_0, next_c_1, next_h_1, next_c_2, next_h_2 = (
    graph.get_tensor_by_name(tensor_name)
    for tensor_name in (
        'next_c_0:0', 'next_h_0:0',
        'next_c_1:0', 'next_h_1:0',
        'next_c_2:0', 'next_h_2:0',
    )
)

# Single-step input row and card output of the restored graph.
x_in, out_card = (
    graph.get_tensor_by_name(tensor_name)
    for tensor_name in ('x_in:0', 'out_card:0')
)

In [9]:
# Inspect single-step shapes; get_x_in must build an array matching x_in's shape.
x_in.shape, out_card.shape

(TensorShape([Dimension(1), Dimension(298)]),
 TensorShape([Dimension(1), Dimension(32)]))

In [10]:
# One LSTM layer's state as a (c, h) pair — presumably cell and hidden state.
State = namedtuple('State', ['c', 'h'])

lstm_size = 128

# Initial state: three independent all-zero (1, lstm_size) State pairs, one per
# state_c_i / state_h_i placeholder pair (i = 0, 1, 2). Each entry gets its own arrays.
zero_state = tuple(
    State(c=np.zeros((1, lstm_size)), h=np.zeros((1, lstm_size)))
    for _ in range(3)
)

In [11]:
def model(sess, p_keep=1.0):
    """Wrap the restored righty graph as a one-step prediction function.

    Args:
        sess: session holding the restored graph; used for every prediction.
        p_keep: value fed to the ``keep_prob`` placeholder (default 1.0).

    Returns:
        pred_fun(x, state_in) -> (cards, next_state), where ``x`` is fed to
        ``x_in``, ``state_in`` is a 3-tuple of State(c, h) fed to the
        state_c_i / state_h_i placeholders, ``cards`` is the value of
        ``out_card`` and ``next_state`` is a 3-tuple of State built from the
        next_c_i / next_h_i tensors of the same run.
    """
    def pred_fun(x, state_in):
        feed_dict = {
            keep_prob: p_keep,
            x_in: x,
            state_c_0: state_in[0].c,
            state_h_0: state_in[0].h,
            state_c_1: state_in[1].c,
            state_h_1: state_in[1].h,
            state_c_2: state_in[2].c,
            state_h_2: state_in[2].h,
        }
        # Fetch the card output and all next-state tensors in ONE run. Separate
        # sess.run calls each re-execute the graph (7 passes per step), and if
        # keep_prob drives dropout (presumably), p_keep < 1 would sample a
        # different mask per call, so the cards and states would not come from
        # the same forward pass.
        cards, (c0, h0), (c1, h1), (c2, h2) = sess.run(
            [out_card,
             (next_c_0, next_h_0),
             (next_c_1, next_h_1),
             (next_c_2, next_h_2)],
            feed_dict=feed_dict,
        )
        next_state = (
            State(c=c0, h=h0),
            State(c=c1, h=h1),
            State(c=c2, h=h2),
        )
        return cards, next_state
    return pred_fun

In [12]:
# Step function for the righty model; p_keep stays at its default of 1.0.
righty = model(sess)

In [77]:
# Example deals as (deal string, outcome string) pairs. Select one by index instead of
# toggling commented-out assignments; index 0 is the deal this notebook's outputs use.
EXAMPLE_DEALS = [
    ('W:T9.T86.AK3.AT873 AKQ87.AKQ5.Q862. J65.7432.J94.KJ6 432.J9.T75.Q9542', 'S - 4S.=.N'),
    ('W:AJT9.J84.JT9872. 6.Q9.AKQ3.KJ9743 Q87542.A73.6.AQ8 K3.KT652.54.T652', 'W NS 4S.+2.E'),
]

deal_str, outcome = EXAMPLE_DEALS[0]

In [78]:
# Drop the 2-character seat prefix (e.g. 'W:'), split into space-separated hands, each hand
# into '.'-separated suits, and each suit into a list of single-character card ranks.
hands = [[list(suit) for suit in hand_str.split('.')] for hand_str in deal_str[2:].split()]

In [79]:
# Show the parsed hands: one list per hand, each holding one list of card ranks per suit.
hands

[[['T', '9'], ['T', '8', '6'], ['A', 'K', '3'], ['A', 'T', '8', '7', '3']],
 [['A', 'K', 'Q', '8', '7'], ['A', 'K', 'Q', '5'], ['Q', '8', '6', '2'], []],
 [['J', '6', '5'], ['7', '4', '3', '2'], ['J', '9', '4'], ['K', 'J', '6']],
 [['4', '3', '2'], ['J', '9'], ['T', '7', '5'], ['Q', '9', '5', '4', '2']]]

In [80]:
# Binary encodings of the first and last hands listed in deal_str (presumably W and S given
# the 'W:' prefix — TODO confirm seat order). They are mutated in place by later cells.
righty_bin, dummy_bin = binary_hand(hands[0]), binary_hand(hands[3])

In [81]:
# View the 32-value hand encoding as 4 rows of 8 (presumably one row per suit).
righty_bin.reshape((4, 8))

array([[ 0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  1.,  0.,  1.,  1.],
       [ 1.,  1.,  0.,  0.,  0.,  0.,  0.,  1.],
       [ 1.,  0.,  0.,  0.,  1.,  0.,  1.,  2.]], dtype=float16)

In [82]:
# Parse the outcome string; get_x_in reads .level and .strain from the result.
deal_meta = DealMeta.from_str(outcome)

In [83]:
def get_x_in(righty_bin, dummy_bin, last_lead, last_trick, this_trick, d_meta):
    """Build the (1, 298) float16 input row fed to ``x_in`` for one prediction step.

    Column layout (end-exclusive ranges):
        0:32      righty_bin
        32:64     dummy_bin
        64:192    encode_card(last_trick[0..3]), 32 columns per card
        192:288   encode_card(this_trick[0..2]), 32 columns per card
        288:292   one-hot at 288 + last_lead
        292       d_meta.level
        293       1 when d_meta.strain == 'N'
        294:298   otherwise one-hot at 294 + suit_index_lookup[d_meta.strain]
                  (assumes the lookup yields 0-3 — TODO confirm)
    """
    row = np.zeros((1, 298), np.float16)

    # Contract and lead features are written first, in the same order as before.
    row[0, 292] = d_meta.level
    if d_meta.strain != 'N':
        row[0, 294 + suit_index_lookup[d_meta.strain]] = 1
    else:
        row[0, 293] = 1
    row[0, 288 + last_lead] = 1

    row[0, 0:32] = righty_bin
    row[0, 32:64] = dummy_bin

    # Four cards of the previous trick, then three slots of the current trick.
    for i in range(4):
        start = 64 + 32 * i
        row[0, start:start + 32] = encode_card(last_trick[i])
    for i in range(3):
        start = 192 + 32 * i
        row[0, start:start + 32] = encode_card(this_trick[i])

    return row

In [84]:
# Start the LSTM from the all-zero state; re-run this cell to reset before replaying a deal.
next_state = zero_state

In [117]:
# Remove D5 from dummy's remaining-card encoding. This mutates dummy_bin in place, so
# re-running the cell removes the card again; the assert stops the count going negative.
d5_index = get_card_index('D5')
assert dummy_bin[d5_index] >= 1, 'D5 is not in dummy_bin (cell already run?)'
dummy_bin[d5_index] -= 1

In [121]:
# Remove HT from righty's remaining-card encoding. This mutates righty_bin in place, so
# re-running the cell removes the card again; the assert stops the count going negative.
ht_index = get_card_index('HT')
assert righty_bin[ht_index] >= 1, 'HT is not in righty_bin (cell already run?)'
righty_bin[ht_index] -= 1

In [122]:
# Input for the next step: the completed previous trick plus three '>>' entries for the
# current trick (presumably "no card played yet" — encode_card's handling is not visible here).
x = get_x_in(
    righty_bin,
    dummy_bin,
    last_lead=1,
    last_trick=['D4', 'D5', 'DK', 'D2'],
    this_trick=['>>', '>>', '>>'],
    d_meta=deal_meta,
)

In [123]:
# One model step: the out_card value for x plus the updated LSTM state.
# next_state is rebound, so re-running this cell advances the state a second time.
card, next_state = righty(x, next_state)

In [124]:
# View the (1, 32) card output in the same 4 x 8 layout as the hand encodings.
card.reshape((4, 8))

array([[  5.96527805e-09,   7.81264862e-06,   5.28886403e-06,
          2.85826559e-07,   1.16330312e-07,   1.17727176e-07,
          7.52518492e-09,   9.90443194e-09],
       [  8.55782833e-10,   4.29609418e-06,   3.56907585e-05,
          1.67684368e-04,   3.86898249e-01,   4.02634432e-05,
          1.51020974e-01,   1.29349564e-05],
       [  2.41957545e-01,   2.07751500e-03,   1.48116055e-06,
          1.61340085e-05,   2.50886887e-06,   8.16236061e-05,
          2.91481638e-06,   2.18722317e-02],
       [  4.27268151e-06,   2.33556329e-06,   2.17595170e-05,
          1.45187096e-06,   4.74915579e-02,   8.48249874e-06,
          6.43345118e-02,   8.39298293e-02]], dtype=float32)

In [94]:
# Inspect dummy's current encoding as 4 rows of 8 (reflects any in-place removals already run).
dummy_bin.reshape((4, 8))

array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.],
       [ 0.,  0.,  0.,  1.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  1.,  0.,  0.,  2.],
       [ 0.,  0.,  1.,  0.,  0.,  1.,  0.,  3.]], dtype=float16)

In [97]:
# Inspect righty's current encoding as 4 rows of 8 (reflects any in-place removals already run).
righty_bin.reshape((4, 8))

array([[ 0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  1.,  0.,  1.,  1.],
       [ 1.,  1.,  0.,  0.,  0.,  0.,  0.,  1.],
       [ 0.,  0.,  0.,  0.,  1.,  0.,  1.,  2.]], dtype=float16)