In [2]:
import numpy as np
import tensorflow as tf

from collections import namedtuple

from lead_binary import DealMeta, suit_index_lookup
from binary_righty import binary_hand, get_card_index, encode_card

In [2]:
# TF1-style session; InteractiveSession installs itself as the default session.
sess = tf.InteractiveSession()

In [3]:
# Checkpoint prefix: the graph is read from model_path + '.meta' and the weights
# are restored from this prefix. NOTE(review): relative to the notebook's working directory.
model_path = './dummy_model/dummy-920000'

In [4]:
# Rebuild the saved graph from its .meta file, then load the checkpoint weights into sess.
saver = tf.train.import_meta_graph('%s.meta' % model_path)
saver.restore(sess, model_path)

INFO:tensorflow:Restoring parameters from ./dummy_model/dummy-920000


In [5]:
# The restored graph (import_meta_graph imports into the default graph); tensors are looked up by name below.
graph = tf.get_default_graph()

In [6]:
# Named tensors of the restored graph: sequence input/output, the keep_prob
# feed, and the card logit/target tensors (shapes inspected in the next cell).
(seq_in, seq_out, keep_prob,
 out_card_logit, out_card_target) = map(graph.get_tensor_by_name, [
    'seq_in:0',
    'seq_out:0',
    'keep_prob:0',
    'out_card_logit:0',
    'out_card_target:0',
])

In [7]:
# Sanity-check shapes: logits/targets are (?, 32) and seq_out is (?, ?, 32).
# NOTE(review): axis meaning (batch / time) is assumed from the names — confirm.
out_card_logit.shape, out_card_target.shape, seq_out.shape

(TensorShape([Dimension(None), Dimension(32)]),
 TensorShape([Dimension(None), Dimension(32)]),
 TensorShape([Dimension(None), Dimension(None), Dimension(32)]))

In [8]:
# Tensors for single-step inference: the per-layer (c, h) state feeds (state_*),
# the updated per-layer states the graph produces (next_*), and the step
# input x_in / output out_card. Three layers: 0, 1, 2.
(state_c_0, state_h_0,
 state_c_1, state_h_1,
 state_c_2, state_h_2,
 next_c_0, next_h_0,
 next_c_1, next_h_1,
 next_c_2, next_h_2,
 x_in, out_card) = map(graph.get_tensor_by_name, [
    'state_c_0:0', 'state_h_0:0',
    'state_c_1:0', 'state_h_1:0',
    'state_c_2:0', 'state_h_2:0',
    'next_c_0:0', 'next_h_0:0',
    'next_c_1:0', 'next_h_1:0',
    'next_c_2:0', 'next_h_2:0',
    'x_in:0', 'out_card:0',
])

In [14]:
# The single-step graph takes one 298-feature row and emits one 32-slot card output.
x_in.shape, out_card.shape

(TensorShape([Dimension(1), Dimension(298)]),
 TensorShape([Dimension(1), Dimension(32)]))

In [15]:
# Recurrent state of one LSTM layer: cell state `c` and hidden state `h`.
State = namedtuple('State', ['c', 'h'])

# NOTE(review): must match the layer width of the restored checkpoint.
lstm_size = 128

# Initial recurrent state: one State of all-zero (1, lstm_size) arrays per
# layer (three layers, matching state_c_0..2 / state_h_0..2). Every array is
# a separate object.
zero_state = tuple(
    State(c=np.zeros((1, lstm_size)), h=np.zeros((1, lstm_size)))
    for layer_idx in range(3)
)

In [16]:
def model(sess, p_keep=1.0):
    """Build a single-step prediction function over the restored graph.

    Args:
        sess: TensorFlow session holding the restored graph.
        p_keep: value fed to the ``keep_prob`` placeholder. 1.0 is the
            inference setting if keep_prob drives dropout, as its name
            suggests (TODO confirm).

    Returns:
        ``pred_fun(x, state_in) -> (cards, next_state)``, where:

        - ``x`` is fed to ``x_in``;
        - ``state_in`` is a 3-tuple of ``State(c, h)``, one per layer, fed to
          ``state_c_0..2`` / ``state_h_0..2``;
        - ``cards`` is the evaluated ``out_card``;
        - ``next_state`` is a 3-tuple of ``State`` built from
          ``next_c_0..2`` / ``next_h_0..2``.
    """
    def pred_fun(x, state_in):
        """Run one network step; return (card output, next per-layer state)."""
        feed_dict = {
            keep_prob: p_keep,
            x_in: x,
            state_c_0: state_in[0].c,
            state_h_0: state_in[0].h,
            state_c_1: state_in[1].c,
            state_h_1: state_in[1].h,
            state_c_2: state_in[2].c,
            state_h_2: state_in[2].h,
        }
        # Fetch all outputs in ONE sess.run. Separate runs would each evaluate
        # the graph again (7x the work). With p_keep < 1 they could also draw
        # different random dropout masks, so `cards` and the returned states
        # would not come from the same forward pass.
        (cards,
         c_0, h_0,
         c_1, h_1,
         c_2, h_2) = sess.run(
            [out_card,
             next_c_0, next_h_0,
             next_c_1, next_h_1,
             next_c_2, next_h_2],
            feed_dict=feed_dict)
        next_state = (
            State(c=c_0, h=h_0),
            State(c=c_1, h=h_1),
            State(c=c_2, h=h_2),
        )
        return cards, next_state
    return pred_fun

In [17]:
# Step predictor using the default p_keep=1.0.
dummy = model(sess)

In [19]:
# Deal under test. Format: a 2-character '<seat>:' prefix, then four
# space-separated hands; each hand lists its suits separated by '.'.
# Alternative deal/outcome kept for experimentation:
# deal_str = 'W:T9.T86.AK3.AT873 AKQ87.AKQ5.Q862. J65.7432.J94.KJ6 432.J9.T75.Q9542'

# outcome = 'S - 4S.=.N'

deal_str = 'W:AJT9.J84.JT9872. 6.Q9.AKQ3.KJ9743 Q87542.A73.6.AQ8 K3.KT652.54.T652'

# Contract/result string; its format is defined by DealMeta.from_str (lead_binary).
outcome = 'W NS 4S.+2.E'

In [20]:
# Parse the deal (after the '<seat>:' prefix) into hands -> suits -> list of rank characters.
hands = [
    [list(suit_str) for suit_str in hand_str.split('.')]
    for hand_str in deal_str[2:].split()
]

In [21]:
# Parsed hands in deal_str order (presumably starting with the seat before ':').
hands

[[['A', 'J', 'T', '9'], ['J', '8', '4'], ['J', 'T', '9', '8', '7', '2'], []],
 [['6'], ['Q', '9'], ['A', 'K', 'Q', '3'], ['K', 'J', '9', '7', '4', '3']],
 [['Q', '8', '7', '5', '4', '2'], ['A', '7', '3'], ['6'], ['A', 'Q', '8']],
 [['K', '3'], ['K', 'T', '6', '5', '2'], ['5', '4'], ['T', '6', '5', '2']]]

In [22]:
# Binary hand encodings: hands[2] is used as declarer, hands[0] as dummy.
# NOTE(review): confirm these indices agree with the seats in `outcome`.
declarer_bin = binary_hand(hands[2])
dummy_bin = binary_hand(hands[0])

In [23]:
# View the 32-entry dummy encoding as 4 suits x 8 rank slots. The last ('x')
# column holds a count, not a flag: it is 2 for the diamond 7 and 2 in hands[0].
dummy_bin.reshape((4, 8))

array([[ 1.,  0.,  0.,  1.,  1.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.,  0.,  1.,  1.],
       [ 0.,  0.,  0.,  1.,  1.,  1.,  1.,  2.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float16)

In [24]:
# Contract metadata; get_x_in reads its .level and .strain.
deal_meta = DealMeta.from_str(outcome)

In [25]:
def get_x_in(dummy_bin, declarer_bin, last_lead, last_trick, this_trick, d_meta):
    """Assemble the single (1, 298) float16 input row fed to ``x_in``.

    Column layout (end-exclusive):
        0:32      dummy_bin
        32:64     declarer_bin
        64:192    last_trick[0..3], 32 columns each (encode_card)
        192:288   this_trick[0..2], 32 columns each (encode_card)
        288:292   one-hot of last_lead
        292       d_meta.level
        293       set when d_meta.strain == 'N' (presumably no-trump)
        294:298   one-hot of the suit strain via suit_index_lookup

    Args:
        dummy_bin, declarer_bin: 32-entry hand encodings (e.g. from binary_hand).
        last_lead: int in 0..3. TODO confirm meaning (presumably the seat that
            led the last trick).
        last_trick: at least 4 card strings; only the first 4 are used.
        this_trick: at least 3 card strings; only the first 3 are used.
            The notebook uses '>>' as a placeholder card.
        d_meta: object with ``level`` and ``strain`` (e.g. DealMeta).

    Returns:
        np.ndarray of shape (1, 298), dtype float16.

    Raises:
        ValueError: if last_lead is outside 0..3.
        IndexError: if last_trick / this_trick have too few cards.
    """
    # 288 + last_lead must stay inside the 288:292 lead block. Any other value
    # silently corrupts another feature: 4 overwrites the level column 292, and
    # a negative value is later overwritten by the this_trick encoding.
    if not 0 <= last_lead < 4:
        raise ValueError('last_lead must be in 0..3, got {!r}'.format(last_lead))

    card_width = 32
    x = np.zeros((1, 298), np.float16)

    x[0, 0:32] = dummy_bin
    x[0, 32:64] = declarer_bin

    # Card slots: 4 cards of the last trick from column 64, then 3 cards of the
    # current trick from column 192. Indexing keeps the original IndexError on
    # short inputs.
    for start, trick, n_cards in ((64, last_trick, 4), (192, this_trick, 3)):
        for i in range(n_cards):
            col = start + i * card_width
            x[0, col:col + card_width] = encode_card(trick[i])

    x[0, 288 + last_lead] = 1

    x[0, 292] = d_meta.level
    if d_meta.strain == 'N':
        x[0, 293] = 1
    else:
        x[0, 294 + suit_index_lookup[d_meta.strain]] = 1

    return x

In [26]:
# Start the recurrent state from zeros before the first model step.
next_state = zero_state

In [27]:
# First step: no previous trick (all '>>', presumably "no card"), and the
# current trick holds two placeholders followed by the club 2.
x = get_x_in(
    dummy_bin,
    declarer_bin,
    last_lead=0,
    last_trick=['>>'] * 4,
    this_trick=['>>', '>>', 'C2'],
    d_meta=deal_meta,
)

In [29]:
# One model step: card output plus the updated state, which is fed back on the next call.
card, next_state = dummy(x, next_state)

In [30]:
# Card scores as 4 suits x 8 rank slots (same order as the S/H/D/C x AKQJT98x legend).
# NOTE(review): values look like probabilities (softmax?) — confirm against the graph.
card.reshape((4, 8))

array([[  3.29213362e-04,   2.32883322e-04,   2.70162109e-05,
          7.86976540e-04,   1.16306928e-03,   2.29660362e-01,
          3.67557514e-05,   8.82916447e-06],
       [  2.56797939e-05,   1.05918407e-04,   3.02418375e-05,
          7.48313905e-04,   1.68862766e-06,   1.14631139e-05,
          3.73032154e-03,   3.37206751e-01],
       [  6.27911882e-04,   2.13620260e-05,   2.63371312e-05,
          1.27133244e-04,   1.03821229e-04,   6.16876059e-04,
          1.03778671e-02,   4.13878024e-01],
       [  2.57621537e-06,   2.52471400e-05,   3.15151046e-05,
          9.07687809e-06,   2.45828051e-05,   1.72677173e-05,
          4.45334763e-06,   4.99855730e-07]], dtype=float32)

In [3]:
# Scratch: np.argmax returns the index of the maximum (2 here). TODO: remove before sharing.
np.argmax(np.array([1,2,5,3]))

2

In [20]:
# Legend for the 32-slot card layout: slot i corresponds to the i-th label
# (suits S, H, D, C; ranks A K Q J T 9 8, then 'x' as the low-card slot).
[(suit + rank) for suit in 'SHDC' for rank in 'AKQJT98x']

['SA',
 'SK',
 'SQ',
 'SJ',
 'ST',
 'S9',
 'S8',
 'Sx',
 'HA',
 'HK',
 'HQ',
 'HJ',
 'HT',
 'H9',
 'H8',
 'Hx',
 'DA',
 'DK',
 'DQ',
 'DJ',
 'DT',
 'D9',
 'D8',
 'Dx',
 'CA',
 'CK',
 'CQ',
 'CJ',
 'CT',
 'C9',
 'C8',
 'Cx']

In [8]:
# Scratch: min on rank characters. String order matches card order only for
# digit ranks; letters sort by ASCII ('A' < 'J' < 'K' < 'Q' < 'T'), not by rank.
min(['5', '2', '7'])

'2'

In [9]:
# Scratch: list deletion check. TODO: remove before sharing.
lst1 = [1,2,3,4]

In [10]:
# Scratch: not idempotent — each re-run removes another element.
del lst1[0]

In [11]:
# Scratch: inspect the list after deletion.
lst1

[2, 3, 4]

In [19]:
# Scratch: digit ranks compare correctly as strings; note 'A' < 'T' would also be True.
'7' < '8'

True