# Load and test network on any input

In [43]:
import torch
import numpy as np

import poker_env.datatypes as pdt
from poker_env.env import Poker
from poker_env.config import Config
from models.networks import OmahaActor,OmahaQCritic,OmahaObsQCritic,CombinedNet
from models.model_layers import ProcessHandBoard
from models.model_utils import strip_padding,unspool,hardcode_handstrength

In [4]:
config = Config()
game_object = pdt.Globals.GameTypeDict[pdt.GameTypes.OMAHAHI]

# Environment parameters for a heads-up Omaha Hi game, taken from the game's rule/state params.
env_params = {
    'game': pdt.GameTypes.OMAHAHI,
    'betsizes': game_object.rule_params['betsizes'],
    'bet_type': game_object.rule_params['bettype'],
    'n_players': 2,
    'pot': 1,
    'stacksize': game_object.state_params['stacksize'],
    'cards_per_player': game_object.state_params['cards_per_player'],
    'starting_street': game_object.starting_street,
    'global_mapping': config.global_mapping,
    'state_mapping': config.state_mapping,
    'obs_mapping': config.obs_mapping,
    'shuffle': True,
}

env = Poker(env_params)

nS = env.state_space
nA = env.action_space
nB = env.betsize_space
# NOTE(review): `seed` is defined but not applied to any RNG in this notebook.
seed = 1235
gpu1 = 'cuda:0'
gpu2 = 'cuda:1'
# The original used `cuda_dict[args.gpu]`, but neither `cuda_dict` nor `args`
# (CLI arguments of a training script) exists in this notebook, so that line
# raised NameError. Use the first GPU when CUDA is available.
device = torch.device(gpu1 if torch.cuda.is_available() else "cpu")
# `network_params` is the same dict object as `config.network_params`, so the
# device assignment below is visible through both names.
network_params = config.network_params
network_params['device'] = device

# Instantiate net and copy weights

In [23]:
def copy_weights(net, path):
    """Copy matching parameters from a saved checkpoint into ``net`` and freeze them.

    Every parameter of ``net`` whose name is a key of the object loaded from
    ``path`` is overwritten with the stored tensor and gets
    ``requires_grad = False``. Parameters missing from the checkpoint are left
    untouched (and stay trainable). A line ``update_weights <name>`` is printed
    for each copied parameter.

    Args:
        net: ``torch.nn.Module`` updated in place.
        path: file readable by ``torch.load``; the loaded object is expected to
            map parameter names to tensors (e.g. a ``state_dict``).
    """
    # Always deserialize onto the CPU: without map_location, tensors saved on a
    # specific GPU (e.g. cuda:1) fail to load on a machine lacking that device.
    # ``copy_`` below moves the data to whatever device each parameter uses.
    layer_weights = torch.load(path, map_location=torch.device('cpu'))
    # Copy under no_grad instead of mutating ``.data`` so autograd never
    # records the in-place write to a leaf parameter.
    with torch.no_grad():
        for name, param in net.named_parameters():
            if name not in layer_weights:
                continue
            print('update_weights', name)
            param.copy_(layer_weights[name])
            param.requires_grad = False

In [17]:
# Build the hand/board recognizer; hand_length=4 presumably matches the four
# Omaha hole cards (game type is OMAHAHI) — confirm against ProcessHandBoard.
net = ProcessHandBoard(network_params, hand_length=4)

In [24]:
# Load the pretrained hand-recognizer weights into `net` (matching params get frozen).
copy_weights(net, path=config.network_params['actor_hand_recognizer_path'])

update_weights suit_conv.0.weight
update_weights suit_conv.0.bias
update_weights suit_conv.1.weight
update_weights suit_conv.1.bias
update_weights rank_conv.0.weight
update_weights rank_conv.0.bias
update_weights rank_conv.1.weight
update_weights rank_conv.1.bias
update_weights hidden_layers.0.weight
update_weights hidden_layers.0.bias
update_weights hidden_layers.1.weight
update_weights hidden_layers.1.bias
update_weights categorical_output.weight
update_weights categorical_output.bias


In [25]:
# List every parameter with its trainable flag; weights copied by
# copy_weights report False.
for name, param in net.named_parameters():
    print(f"{name} {param.requires_grad}")

suit_conv.0.weight False
suit_conv.0.bias False
suit_conv.1.weight False
suit_conv.1.bias False
rank_conv.0.weight False
rank_conv.0.bias False
rank_conv.1.weight False
rank_conv.1.bias False
hand_out.weight True
hand_out.bias True
hidden_layers.0.weight False
hidden_layers.0.bias False
hidden_layers.1.weight False
hidden_layers.1.bias False
categorical_output.weight False
categorical_output.bias False


In [35]:
# One batch holding the same 18-value row twice -> shape (1, 2, 18), float32.
# NOTE(review): the row looks like 9 (rank, suit) pairs, presumably 4 hole
# cards followed by 5 board cards — confirm against ProcessHandBoard's layout.
net_input = torch.tensor([
    [[8., 3., 11., 4., 8., 4., 11., 2., 11., 1., 13., 1., 7., 4., 11., 3., 3., 2.]] * 2
])

In [39]:
net_input.shape  # same torch.Size as net_input.size()

torch.Size([1, 2, 18])

# Compare the network output against the hard-coded hand-strength baseline

In [41]:
# Forward pass of the recognizer on the test input.
# NOTE(review): the recorded run of this cell printed a tensor of 5-rank
# combinations and then failed with "NameError: name 'asdf' is not defined" —
# presumably leftover debug code inside the network's forward; confirm and remove.
out = net(net_input)

tensor([[[[ 8,  8,  3,  7, 11],
          [ 8,  8,  3,  7, 11],
          [ 8,  8,  3,  7, 13],
          [ 8,  8,  3, 11, 11],
          [ 8,  8,  3, 11, 13],
          [ 8,  8,  3, 11, 13],
          [ 8,  8,  7, 11, 11],
          [ 8,  8,  7, 11, 13],
          [ 8,  8,  7, 11, 13],
          [ 8,  8, 11, 11, 13],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 13],
          [ 8, 11,  3, 11, 11],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  7, 11, 11],
          [ 8, 11,  7, 11, 13],
          [ 8, 11,  7, 11, 13],
          [ 8, 11, 11, 11, 13],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 11],
          [ 8, 11,  3,  7, 13],
          [ 8, 11,  3, 11, 11],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  3, 11, 13],
          [ 8, 11,  7, 11, 11],
          [ 8, 11,  7, 11, 13],
          [ 8, 11,  7, 11, 13],
          [ 8, 11, 11, 11, 13],
          [ 8, 11,  3,  7, 11],
        

NameError: name 'asdf' is not defined

In [38]:
# Baseline: hand strength from the hard-coded evaluator on the same input,
# to compare against the network's output.
print(hardcode_handstrength(net_input))

tensor([[[48.],
         [48.]]])
