In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import sys
sys.path.append("..")

from torch.utils.data import DataLoader

from bridge.rnn.datasets import BidDataset
from bridge.rnn.models import HandsClassifier

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Held-out test split of the bidding data.
test_file = "../data/test.txt"
dataset = BidDataset(test_file)
dataset

<bridge.rnn.datasets.BidDataset at 0x7f55103f1c70>

In [5]:
# Batches of 16 samples in dataset order (DataLoader defaults to shuffle=False).
BATCH_SIZE = 16
loader = DataLoader(dataset, batch_size=BATCH_SIZE)
loader

<torch.utils.data.dataloader.DataLoader at 0x7f537eb5e1c0>

In [6]:
CKPT_PATH = "../ckpt/hands/20220602_0740/hands-epoch=40-valid_card_acc=0.46-valid_loss=0.28.ckpt"
# Freshly built modules start in training mode; switch to eval mode so dropout and any
# other train-only behaviour is disabled for the inference cells below.
model = HandsClassifier.load_from_checkpoint(CKPT_PATH).cuda().eval()
model

HandsClassifier(
  (gru): GRU(36, 36, batch_first=True, bidirectional=True)
  (fc): Sequential(
    (0): Dropout(p=0.0, inplace=False)
    (1): Linear(in_features=280, out_features=208, bias=True)
    (2): Sigmoid()
  )
  (loss): BCELoss()
)

In [106]:
import torch

SEED = 0  # seed for the random-score baseline below

# Only the first batch is used, so take it directly instead of looping and breaking.
masked_hand, bidding, length, target = next(iter(loader))
with torch.no_grad():
    masked_hand = masked_hand.cuda()
    bidding = bidding.cuda()
    # `length` is not moved to the GPU — presumably it feeds pack_padded_sequence,
    # which expects CPU lengths (TODO confirm in HandsClassifier.forward).
    output = model(masked_hand, bidding, length)
    hints_cpu = masked_hand.cpu()
    # Decode the model scores, plus all-zero and uniform-random scores (presumably as
    # baselines), all with the same hints.
    hinted_result = model.greedy_generate(output.cpu(), hints=hints_cpu)
    null_result = model.greedy_generate(torch.zeros(output.shape), hints=hints_cpu)
    random_scores = torch.rand(output.shape, generator=torch.Generator().manual_seed(SEED))
    random_result = model.greedy_generate(random_scores, hints=hints_cpu)

In [109]:
# The sampling cells below work on CPU tensors.
output, hints = output.cpu(), masked_hand.cpu()

In [143]:
n = 10  # number of deals sampled per board

# Treat `hints is None` as "nothing known yet": an all-zero hint tensor gives zero counts,
# no selected cards and an unchanged `output`, so hand_cnts / total_cnts / cards_selected
# are defined on both paths for the sampling cell.
result = torch.zeros_like(output) if hints is None else hints.clone()
result[result < 0] = 0  # negative hint entries (presumably masked/unknown) are cleared
per_side = torch.stack(result.split(52, dim=1))  # (4, batch, 52), assuming result is (batch, 4 * 52)
hand_cnts = per_side.sum(dim=2)           # (4, batch): cards already held by each side
total_cnts = result.sum(dim=1)            # (batch,): cards already placed per board
cards_selected = per_side.sum(dim=0) > 0  # (batch, 52): card already placed with some side
# Soft penalty: lower the scores of already-placed cards (for every side) and of every
# card for sides that already hold 13 cards.
# NOTE: this rebinds `output`, so re-running this cell alone applies the penalty twice;
# re-run the cell that produces `output` first.
output = output - cards_selected.float().repeat(1, 4) - (hand_cnts.T == 13).float().repeat_interleave(52, dim=1)
result = result.unsqueeze(1).repeat(1, n, 1)

In [161]:
import numpy as np
from torch.nn.functional import softmax

t = 1.0  # softmax temperature applied to the scores before sampling
SEED = 0
rng = np.random.default_rng(SEED)  # seeded so the sampled deals are reproducible

# Complete every partial deal in `result` (presumably (batch, n, 4 * 52)) by sampling the
# still-unassigned cards from the model scores, n times per board.
for i, o in enumerate(output):
    for j in range(n):
        # Rebuild the partial-deal state from the row itself rather than from the
        # hand_cnts / total_cnts / cards_selected globals: this works whether or not hints
        # were given, and a re-run leaves complete rows (52 cards) untouched instead of
        # dealing a second set of cards on top of them.
        deal = result[i, j].reshape(4, 52)
        hand_cnt = deal.sum(dim=1)
        total_cnt = deal.sum()
        card_selected = deal.sum(dim=0) > 0
        current_o = o.clone()
        while total_cnt < 52:
            probs = softmax(current_o / t, dim=0).numpy()
            probs /= probs.sum()  # absorb float32 rounding so the probabilities sum to 1
            indices = torch.tensor(rng.choice(len(probs), size=52, replace=False, p=probs))
            side_indices = torch.div(indices, 52, rounding_mode='floor')
            card_indices = indices % 52
            for idx, side, card in zip(indices, side_indices, card_indices):
                # Skip draws that would overfill a hand or reuse an already dealt card
                if (hand_cnt[side] == 13) or card_selected[card]:
                    continue
                result[i, j, idx] = 1
                hand_cnt[side] += 1
                total_cnt += 1
                card_selected[card] = True
                # Reduce the probability of selected cards (soft penalty on all four sides)
                current_o[torch.arange(4) * 52 + card] -= 1
                if hand_cnt[side] == 13:
                    # Reduce the probability of players with full hands
                    current_o[side*52:(side+1)*52] -= 1

In [149]:
# Sanity-check the sampled deals: view each row as (side, card) and require that every
# card belongs to exactly one side and every side holds exactly 13 cards.
deals = result.reshape(*result.shape[:-1], 4, 52)
cards_dealt_once = (deals.sum(dim=-2) == 1).all()
hands_full = (deals.sum(dim=-1) == 13).all()
assert cards_dealt_once, "some card is missing or dealt to more than one side"
assert hands_full, "some hand does not hold exactly 13 cards"
cards_dealt_once & hands_full

tensor(True)

In [157]:
# Debug: flat positions (side * 52 + card) on all four sides of the last `card` seen by
# the sampling loop; reads a loop variable leaked from that cell.
card + 52 * torch.arange(4)

tensor([ 31,  83, 135, 187])

In [159]:
# Debug: last value `idx` took in the sampling loop's inner loop (it may have been a
# skipped draw, not a placed card); depends on that cell having run just before.
idx

tensor(83)