In [1]:
import json
import os
import sys
import torch
from torch import nn
from torch.utils.data import DataLoader

from fastai.data.all import *
from fastai.learner import *
from fastai.optimizer import *
from fastai.metrics import *
from fastai.interpret import *

from game_engine import *

In [2]:
#!conda install --yes --prefix {sys.prefix} -c fastchan fastai

In [3]:
def read_records(path):
    with open(path) as f:
        return json.load(f)

def load_data(limit=None):
    DATA_DIR = "bidder-data"
    records = []
    for filename in os.listdir(DATA_DIR):
        if filename.endswith(".json"):
            path = os.path.join(DATA_DIR, filename)
            records.extend(read_records(path))
        if limit is not None and len(records) > limit:
            break
    return records

In [4]:
ex = load_data(limit = 100000)
len(ex)

100500

In [5]:
class FirstCallModel(nn.Module):
    def __init__(self):
        super(FirstCallModel, self).__init__()
        embed_dims = 32
        hidden_dims = 64
        self.network = nn.Sequential(
            nn.Linear(len(CARDS), embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, len(CALLS)),
        )

    def forward(self, x):
        return self.network(x)

In [6]:
model = FirstCallModel()
print(model)

FirstCallModel(
  (network): Sequential(
    (0): Linear(in_features=52, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=38, bias=True)
  )
)


In [7]:
def get_x(record):
    board = get_board_from_identifier(record["board"])
    dealer = board["dealer"]
    return board["hands"][dealer]

def get_y(record):
    return record["calls"][0]

In [8]:
calls = DataBlock(
    (MultiCategoryBlock(vocab=CARDS), CategoryBlock(vocab=CALLS)),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_x=get_x,
    get_y=get_y)

In [9]:
dls = calls.dataloaders(ex)

dls.valid.show_batch(max_n=4, nrows=1)

3C;6C;TC;JC;QC;QD;KD;3H;6H;8H;JH;KH;KS
3C;4C;2D;3D;5D;9D;QD;KD;6H;9H;AH;7S;TS
4C;5C;7C;JC;QC;2D;6D;3H;9H;5S;7S;9S;TS
2C;TC;KC;6D;9D;JD;KD;2H;3H;7H;TH;QH;AH
1H
2D
P
1H


In [10]:
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)

In [11]:
learn.fit(10)

[0, 0.3436969220638275, 0.3299921154975891, 0.8819900751113892, '01:04']
[1, 0.2396089881658554, 0.2473532259464264, 0.9050248861312866, '01:04']
[2, 0.2073841392993927, 0.19623896479606628, 0.9237313270568848, '01:04']
[3, 0.1649974137544632, 0.15869030356407166, 0.9428855776786804, '01:03']
[4, 0.14295418560504913, 0.13530460000038147, 0.9532338380813599, '01:04']
[5, 0.11900314688682556, 0.11900375783443451, 0.9576616883277893, '01:03']
[6, 0.09826595336198807, 0.1006813794374466, 0.9640796184539795, '01:03']
[7, 0.08119724690914154, 0.08666863292455673, 0.9723383188247681, '01:04']
[8, 0.07418646663427353, 0.07853342592716217, 0.9720398187637329, '01:03']
[9, 0.06653723865747452, 0.06898316740989685, 0.9783582091331482, '01:04']


In [12]:
#get_board_from_identifier("6-6523f1878e3deb685418c8cbb5")

In [13]:
guess = learn.predict({"board": "3-f1a4bd93c331ae44b672e5d209"})
guess

('1N',
 TensorMultiCategory(3),
 TensorMultiCategory([2.7321e-02, 4.9154e-12, 2.2769e-09, 9.7265e-01, 2.8375e-05, 2.0464e-23,
         1.8532e-33, 2.3831e-35, 7.2237e-12, 1.1544e-25, 5.7445e-26, 2.0987e-37,
         9.4070e-38, 1.0991e-15, 4.5519e-35, 1.7990e-29, 0.0000e+00, 0.0000e+00,
         1.1612e-15, 1.4531e-40, 6.5857e-16, 7.8899e-16, 5.4065e-16, 6.3280e-16,
         1.8089e-15, 3.4075e-16, 1.8014e-15, 6.3265e-16, 8.6318e-16, 1.2617e-15,
         1.9135e-15, 3.6701e-16, 8.4928e-16, 3.6255e-15, 1.2613e-15, 2.1113e-15,
         2.0689e-15, 1.1817e-15]))

In [19]:
#learn.export("100k-model.pkl")
#learn_rst = load_learner("100k-model.pkl")
#learn_rst.predict({"board": "6-6523f1878e3deb685418c8cbb5"})

In [15]:
interp = ClassificationInterpretation.from_learner(learn)

In [21]:
bad = interp.top_losses(3, items=True)

In [31]:
bad[2][2]

{'rules': ['StrongTwoClubs',
  'DefaultPass',
  'NotrumpResponseToStrongTwoClubs',
  'DefaultPass',
  'None',
  'DefaultPass'],
 'board': '8-302a67ca590925fc81f523a5ff',
 'calls': ['2C', 'P', '2N', 'P', 'P', 'P']}

In [32]:
get_board_from_identifier("8-302a67ca590925fc81f523a5ff")

{'dealer': 'W',
 'vulnerability': 'None',
 'hands': {'N': ['2C',
   '4C',
   '5C',
   '6C',
   '2D',
   '9D',
   'TD',
   'KD',
   '7H',
   '9H',
   'TH',
   '3S',
   '5S'],
  'E': ['TC',
   'QC',
   '5D',
   '6D',
   '8D',
   'QD',
   '2H',
   '3H',
   'JH',
   'AH',
   '2S',
   '9S',
   'TS'],
  'S': ['7C',
   '8C',
   '9C',
   'JC',
   '3D',
   '4D',
   '7D',
   'JD',
   'AD',
   '8H',
   '4S',
   '7S',
   '8S'],
  'W': ['3C',
   'KC',
   'AC',
   '4H',
   '5H',
   '6H',
   'QH',
   'KH',
   '6S',
   'JS',
   'QS',
   'KS',
   'AS']}}

In [33]:
learn.predict({"board":"8-302a67ca590925fc81f523a5ff"})

('1S',
 TensorMultiCategory(4),
 TensorMultiCategory([2.9297e-03, 1.8323e-17, 1.5481e-02, 4.1887e-15, 9.7765e-01, 1.0642e-03,
         0.0000e+00, 0.0000e+00, 2.8772e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 3.2340e-24, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         1.1406e-24, 0.0000e+00, 4.8773e-25, 4.0415e-24, 5.9424e-25, 5.9681e-24,
         1.1316e-24, 3.0965e-24, 3.2288e-25, 3.7973e-26, 1.9639e-25, 5.3251e-25,
         1.3841e-24, 4.3719e-25, 7.9587e-24, 4.2739e-24, 4.2142e-25, 1.0466e-36,
         1.2731e-24, 6.7557e-25]))

In [17]:
interp.most_confused()

[('2S', 'P', 39),
 ('2D', 'P', 36),
 ('2H', 'P', 31),
 ('2C', '2N', 29),
 ('3C', 'P', 28),
 ('1D', '1N', 27),
 ('P', '1S', 16),
 ('P', '1C', 15),
 ('P', '3C', 15),
 ('3D', '2D', 14),
 ('1D', '2N', 13),
 ('2S', '1S', 11),
 ('1C', '2N', 10),
 ('1C', '1N', 9),
 ('4H', '3H', 9),
 ('P', '3S', 9),
 ('3H', '2H', 8),
 ('3H', 'P', 8),
 ('P', '1D', 8),
 ('P', '1H', 8),
 ('4C', '3C', 7),
 ('P', '2H', 7),
 ('4S', '3S', 6),
 ('1D', '3D', 5),
 ('1N', '1C', 5),
 ('1N', '1S', 5),
 ('1H', '3H', 4),
 ('4D', '3D', 4),
 ('P', '3H', 4),
 ('1H', '1S', 3),
 ('2C', '1H', 3),
 ('2C', '1S', 3),
 ('2N', '1S', 3),
 ('3C', '1C', 3),
 ('P', '2D', 3),
 ('P', '2S', 3),
 ('1C', '3C', 2),
 ('1C', 'P', 2),
 ('1S', '2N', 2),
 ('2N', '1C', 2),
 ('3D', '1D', 2),
 ('3D', 'P', 2),
 ('3H', '1H', 2),
 ('1D', '2D', 1),
 ('1H', '2N', 1),
 ('1H', 'P', 1),
 ('1N', '1H', 1),
 ('1S', '3S', 1),
 ('2N', '1H', 1),
 ('3S', '1S', 1),
 ('3S', 'P', 1),
 ('4C', 'P', 1),
 ('P', '3D', 1)]