In [1]:
import torch
import numpy as np
import os

from networks import HandRankClassificationFive
from evaluate_card_models import load_weights
from network_config import NetworkConfig
import datatypes as dt

In [2]:
def load_weights(net):
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(examine_params['load_path']))
    else: 
        net.load_state_dict(torch.load(examine_params['load_path'],map_location=torch.device('cpu')))

In [3]:
datatype = 'handranksfive'
learning_category = dt.Globals.DatasetCategories[datatype]

In [4]:
network = NetworkConfig.DataModels[datatype]
network_name = NetworkConfig.DataModels[datatype].__name__
network_path = os.path.join('checkpoints',learning_category,network_name)

network_params = {
    'seed':346,
    'state_space':(13,2),
    'nA':dt.Globals.ACTION_SPACES[datatype],
    'channels':13,
    'kernel':2,
    'batchnorm':True,
    'conv_layers':1,
    'gpu1': torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    'gpu2': torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
}
examine_params = {
        'network':network,
        'load_path':network_path
    }

In [5]:
net = network(network_params)

In [6]:
load_weights(net)

# Inputs

In [7]:
ranks = torch.tensor([[[[ 8,  8,  3,  7, 11],
          [ 8,  8,  3,  7, 11],
          [ 8,  8,  3,  7, 13],
          [ 8,  8,  3, 11, 11],
          [ 8,  8,  3, 11, 13],
          [ 8,  8,  3, 11, 13],
          [ 8,  8,  7, 11, 11],
          [ 8,  8,  7, 11, 13],
          [ 8,  8,  7, 11, 13],
          [ 8,  8, 11, 11, 13],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 13],
          [ 8, 11,  3, 11, 11],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  7, 11, 11],
          [ 8, 11,  7, 11, 13],
          [ 8, 11,  7, 11, 13],
          [ 8, 11, 11, 11, 13],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 13],
          [ 8, 11,  3, 11, 11],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  7, 11, 11],
          [ 8, 11,  7, 11, 13],
          [ 8, 11,  7, 11, 13],
          [ 8, 11, 11, 11, 13],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 13],
          [ 8, 11,  3, 11, 11],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  7, 11, 11],
          [ 8, 11,  7, 11, 13],
          [ 8, 11,  7, 11, 13],
          [ 8, 11, 11, 11, 13],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 13],
          [ 8, 11,  3, 11, 11],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  7, 11, 11],
          [ 8, 11,  7, 11, 13],
          [ 8, 11,  7, 11, 13],
          [ 8, 11, 11, 11, 13],
          [11, 11,  3,  7, 11],
          [11, 11,  3,  7, 11],
          [11, 11,  3,  7, 13],
          [11, 11,  3, 11, 11],
          [11, 11,  3, 11, 13],
          [11, 11,  3, 11, 13],
          [11, 11,  7, 11, 11],
          [11, 11,  7, 11, 13],
          [11, 11,  7, 11, 13],
          [11, 11, 11, 11, 13]]]])


In [8]:
suits = torch.tensor([[[[3, 4, 2, 4, 1],
          [3, 4, 2, 4, 3],
          [3, 4, 2, 4, 1],
          [3, 4, 2, 1, 3],
          [3, 4, 2, 1, 1],
          [3, 4, 2, 3, 1],
          [3, 4, 4, 1, 3],
          [3, 4, 4, 1, 1],
          [3, 4, 4, 3, 1],
          [3, 4, 1, 3, 1],
          [3, 2, 2, 4, 1],
          [3, 2, 2, 4, 3],
          [3, 2, 2, 4, 1],
          [3, 2, 2, 1, 3],
          [3, 2, 2, 1, 1],
          [3, 2, 2, 3, 1],
          [3, 2, 4, 1, 3],
          [3, 2, 4, 1, 1],
          [3, 2, 4, 3, 1],
          [3, 2, 1, 3, 1],
          [3, 4, 2, 4, 1],
          [3, 4, 2, 4, 3],
          [3, 4, 2, 4, 1],
          [3, 4, 2, 1, 3],
          [3, 4, 2, 1, 1],
          [3, 4, 2, 3, 1],
          [3, 4, 4, 1, 3],
          [3, 4, 4, 1, 1],
          [3, 4, 4, 3, 1],
          [3, 4, 1, 3, 1],
          [4, 2, 2, 4, 1],
          [4, 2, 2, 4, 3],
          [4, 2, 2, 4, 1],
          [4, 2, 2, 1, 3],
          [4, 2, 2, 1, 1],
          [4, 2, 2, 3, 1],
          [4, 2, 4, 1, 3],
          [4, 2, 4, 1, 1],
          [4, 2, 4, 3, 1],
          [4, 2, 1, 3, 1],
          [4, 4, 2, 4, 1],
          [4, 4, 2, 4, 3],
          [4, 4, 2, 4, 1],
          [4, 4, 2, 1, 3],
          [4, 4, 2, 1, 1],
          [4, 4, 2, 3, 1],
          [4, 4, 4, 1, 3],
          [4, 4, 4, 1, 1],
          [4, 4, 4, 3, 1],
          [4, 4, 1, 3, 1],
          [2, 4, 2, 4, 1],
          [2, 4, 2, 4, 3],
          [2, 4, 2, 4, 1],
          [2, 4, 2, 1, 3],
          [2, 4, 2, 1, 1],
          [2, 4, 2, 3, 1],
          [2, 4, 4, 1, 3],
          [2, 4, 4, 1, 1],
          [2, 4, 4, 3, 1],
          [2, 4, 1, 3, 1]]]])

In [9]:
net_inputs = torch.stack((ranks,suits))

In [10]:
net_inputs = net_inputs.squeeze(1).squeeze(1)

In [11]:
net_inputs = net_inputs.permute(1,2,0)

In [12]:
net_inputs.shape

torch.Size([60, 5, 2])

In [13]:
outputs = net(net_inputs)

torch.Size([60, 5, 15]) torch.Size([60, 5, 5])


In [14]:
torch.argmax(torch.softmax(outputs,dim=-1),dim=-1)

tensor([4798, 4798, 4734, 2862, 4716, 4716, 2858, 4712, 4712, 2854, 4174, 4174,
        6850, 1857, 4069, 4069, 4129, 4065, 4065, 1822, 4174, 4174, 6850, 1857,
        4069, 4069, 4129, 4065, 4065, 1822, 4174, 4174, 6850, 1857, 4069, 4069,
        4129, 4065, 4065, 4065, 4174, 4174, 6850, 1857, 4069, 4069, 1853, 4065,
        4065, 1822, 1862, 1862, 4074,   57, 1827, 1827,   53, 1823, 4065,   48])