In [12]:
import os
# Expose only GPU 7 to this process. This is set before torch is imported so
# the restriction is already in place when CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import sys
# Add the parent directory to the import path (presumably where the `bridge`
# package lives; the notebook appears to sit one level below it).
sys.path.append("..")

import torch
from torch.utils.data import DataLoader

from bridge.rnn.datasets import BidDataset
from bridge.rnn.models import HandsClassifier

In [13]:
# Load the test split; the bare name on the last line displays its repr.
test_split_path = "../data/test.txt"
dataset = BidDataset(test_split_path)
dataset

<bridge.rnn.datasets.BidDataset at 0x7f78613fef10>

In [14]:
# Batch the test split. shuffle=False is the DataLoader default, spelled out
# because the evaluation later concatenates per-batch results in loader order.
loader = DataLoader(dataset=dataset, batch_size=256, shuffle=False)
loader

<torch.utils.data.dataloader.DataLoader at 0x7f78600c3340>

In [15]:
# Read the training checkpoint into a plain dict.
# - map_location="cpu": deserialise every tensor onto the CPU so the file can
#   be opened regardless of which GPU it was saved from; the weights reach the
#   GPU later together with the model.
# - NOTE: torch.load unpickles the file, so only open checkpoints that come
#   from a trusted source.
checkpoint = torch.load(
    "../ckpt/hands/20220605_2200/hands-epoch=12-valid_card_acc=0.47-valid_loss=0.26.ckpt",
    map_location="cpu",
)
checkpoint.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])

In [16]:
# Preview the constructor arguments: the checkpoint's hyper-parameters with
# both GRU widths overridden to 38 (the checkpoint's GRU weights are 36 wide).
dict(checkpoint["hyper_parameters"], gru_input_size=38, gru_hidden_size=38)

{'hand_hidden_size': 208,
 'gru_hidden_size': 38,
 'num_layers': 1,
 'dropout': 0.0,
 'bidirectional': True,
 'lr': 0.002,
 'weight_decay': 0,
 'gru_input_size': 38}

In [17]:
# Instantiate the classifier from the checkpoint's hyper-parameters, with the
# GRU input and hidden sizes overridden to 38.
widened_hparams = dict(checkpoint["hyper_parameters"])
widened_hparams.update(gru_input_size=38, gru_hidden_size=38)
model = HandsClassifier(**widened_hparams)
model

HandsClassifier(
  (gru): GRU(38, 38, batch_first=True, bidirectional=True)
  (fc): Sequential(
    (0): Dropout(p=0.0, inplace=False)
    (1): Linear(in_features=284, out_features=208, bias=True)
  )
  (sigmoid): Sigmoid()
  (loss): BCELoss()
)

In [18]:
# Target parameters (the freshly built, wider model) and source parameters
# (the trained weights stored in the checkpoint).
model_state_dict = model.state_dict()
pretrained_state_dict = checkpoint["state_dict"]

In [19]:
# Shape of every parameter stored in the checkpoint.
{name: tensor.shape for name, tensor in pretrained_state_dict.items()}

{'gru.weight_ih_l0': torch.Size([108, 36]),
 'gru.weight_hh_l0': torch.Size([108, 36]),
 'gru.bias_ih_l0': torch.Size([108]),
 'gru.bias_hh_l0': torch.Size([108]),
 'gru.weight_ih_l0_reverse': torch.Size([108, 36]),
 'gru.weight_hh_l0_reverse': torch.Size([108, 36]),
 'gru.bias_ih_l0_reverse': torch.Size([108]),
 'gru.bias_hh_l0_reverse': torch.Size([108]),
 'fc.1.weight': torch.Size([208, 280]),
 'fc.1.bias': torch.Size([208])}

In [20]:
# Shape of every parameter in the newly constructed model, for comparison.
{name: tensor.shape for name, tensor in model_state_dict.items()}

{'gru.weight_ih_l0': torch.Size([114, 38]),
 'gru.weight_hh_l0': torch.Size([114, 38]),
 'gru.bias_ih_l0': torch.Size([114]),
 'gru.bias_hh_l0': torch.Size([114]),
 'gru.weight_ih_l0_reverse': torch.Size([114, 38]),
 'gru.weight_hh_l0_reverse': torch.Size([114, 38]),
 'gru.bias_ih_l0_reverse': torch.Size([114]),
 'gru.bias_hh_l0_reverse': torch.Size([114]),
 'fc.1.weight': torch.Size([208, 284]),
 'fc.1.bias': torch.Size([208])}

In [36]:
def _pad_param(padded, pretrained, n_blocks):
    """Copy `pretrained` into the zero tensor `padded`, block by block along dim 0.

    Both tensors are treated as `n_blocks` equally sized blocks stacked along
    dim 0. Block i of `pretrained` is written into the leading corner of block
    i of `padded`; every other element of `padded` is left as it was (zero).

    Args:
        padded: zero-initialised destination; modified in place and returned.
        pretrained: source tensor, no larger than `padded` in any dimension.
        n_blocks: number of stacked blocks along dim 0 (1 = plain corner copy).

    Raises:
        ValueError: if the ranks differ, the source does not fit into the
            destination, or dim 0 is not divisible by `n_blocks`.
    """
    fits = pretrained.dim() == padded.dim() and all(
        small <= large for small, large in zip(pretrained.shape, padded.shape)
    )
    divisible = (
        pretrained.dim() > 0
        and pretrained.size(0) >= n_blocks
        and pretrained.size(0) % n_blocks == 0
        and padded.size(0) % n_blocks == 0
    )
    if not (fits and divisible):
        raise ValueError(
            f"cannot pad shape {tuple(pretrained.shape)} into "
            f"{tuple(padded.shape)} using {n_blocks} block(s)"
        )

    src_rows = pretrained.size(0) // n_blocks
    dst_rows = padded.size(0) // n_blocks
    for block_index, chunk in enumerate(pretrained.split(src_rows, dim=0)):
        start = block_index * dst_rows
        # One slice per dimension, so weights (2-D) and biases (1-D) share a path.
        corner = (slice(start, start + chunk.size(0)),) + tuple(
            slice(0, extent) for extent in chunk.shape[1:]
        )
        padded[corner] = chunk
    return padded


# nn.GRU stacks its three gate parameter blocks along dim 0 (3 * hidden_size
# rows), so each gate has to be padded separately when hidden_size grows.
GRU_GATES = 3

# NOTE(review): for `fc.1.weight` the extra input columns are simply left at
# the end (a plain corner copy, as before). If HandsClassifier feeds the two
# GRU directions side by side into `fc`, widening each direction shifts the
# later columns and this copy would misalign them. Confirm against the
# model's forward(); the downstream "Hinted" and "Random" accuracies being
# almost equal may be a symptom.
padded_state_dict = {}
for name, target in model_state_dict.items():
    n_blocks = 1 if name.startswith("fc") else GRU_GATES
    padded_state_dict[name] = _pad_param(
        torch.zeros_like(target), pretrained_state_dict[name], n_blocks
    )

# Replace the entries only after the loop, so the dict is not reassigned
# while it is being iterated.
model_state_dict.update(padded_state_dict)

In [37]:
# Install the zero-padded weights (strict=True is the default: every key must
# match), then move the model to the GPU.
model.load_state_dict(model_state_dict, strict=True)
model = model.cuda()

In [38]:
from tqdm import tqdm

# Inference only: switch off training-time behaviour such as dropout. This is
# a no-op at dropout p=0.0, but stays correct if the hyper-parameters change.
model.eval()

# Dedicated, seeded generator so the random baseline is reproducible without
# touching the global RNG state.
baseline_rng = torch.Generator().manual_seed(0)

hints = []
hinted_results = []
random_results = []
targets = []
biddings = []
with torch.no_grad():
    for masked_hand, bidding, length, target in tqdm(loader):
        # Keep a CPU copy for the decoding and bookkeeping below instead of
        # copying the GPU tensor back several times per batch.
        hand_cpu = masked_hand.cpu()

        # Zero-pad the last dimension of the bidding features up to the input
        # width the model's GRU was built with.
        pad_width = model.gru.input_size - bidding.size(2)
        bidding = torch.concat(
            [bidding, torch.zeros((bidding.size(0), bidding.size(1), pad_width))],
            dim=2,
        )
        biddings.append(bidding)

        # `length` is passed through exactly as the loader yields it.
        output = model(masked_hand.cuda(), bidding.cuda(), length).cpu()

        hints.append(hand_cpu)
        targets.append(target)
        # Decode from the model's scores, and from uniform noise of the same
        # shape as a baseline; both receive the same hints.
        hinted_results.append(model.greedy_generate(output, hints=hand_cpu))
        random_results.append(
            model.greedy_generate(
                torch.rand(output.shape, generator=baseline_rng), hints=hand_cpu
            )
        )

# Concatenate the per-batch lists along the batch dimension.
hints = torch.concat(hints)
targets = torch.concat(targets)
hinted_results = torch.concat(hinted_results)
random_results = torch.concat(random_results)
biddings = torch.concat(biddings)

100%|██████████| 40/40 [00:38<00:00,  1.03it/s]


In [39]:
# Report accuracy for each decoded result set, first as-is and then with the
# hints passed to get_accuracy. Add further (label, results) pairs to the
# tuple to report more baselines.
for label, predictions in (("Hinted", hinted_results), ("Random", random_results)):
    print(f"{label} Result: {model.get_accuracy(predictions, targets)}")
    print(f"{label} Result (out of hints): {model.get_accuracy(predictions, targets, hints=hints)}")

Hinted Result: (0.6265807747840881, 0.37439998984336853, 0.0)
Hinted Result (out of hints): (0.40310224890708923, 0.0, 0.0)
Random Result: (0.6251288056373596, 0.37439998984336853, 0.0)
Random Result (out of hints): (0.40078142285346985, 0.0, 0.0)
