In [1]:
import json
import os
import sys
import torch
from torch import nn
from torch.utils.data import DataLoader

from fastai.data.all import *
from fastai.learner import *
from fastai.optimizer import *
from fastai.metrics import *
from fastai.interpret import *

from game_engine import *

In [2]:
#!conda install --yes --prefix {sys.prefix} -c fastchan fastai

In [3]:
def read_records(path):
    with open(path) as f:
        return json.load(f)

def load_data(limit=None):
    DATA_DIR = "bidder-data"
    records = []
    for filename in os.listdir(DATA_DIR):
        if filename.endswith(".json"):
            path = os.path.join(DATA_DIR, filename)
            records.extend(read_records(path))
        if limit is not None and len(records) > limit:
            break
    return records

In [14]:
ex = load_data(limit = 100000)
len(ex)

100500

In [5]:
class FirstCallModel(nn.Module):
    def __init__(self):
        super(FirstCallModel, self).__init__()
        embed_dims = 32
        hidden_dims = 64
        self.network = nn.Sequential(
            nn.Linear(len(CARDS), embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, len(CALLS)),
        )

    def forward(self, x):
        return self.network(x)


In [6]:
model = FirstCallModel()
print(model)

FirstCallModel(
  (network): Sequential(
    (0): Linear(in_features=52, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=38, bias=True)
  )
)


In [7]:
def get_x(record):
    board = get_board_from_identifier(record["board"])
    dealer = board["dealer"]
    return board["hands"][dealer]

def get_y(record):
    return record["calls"][0]

In [8]:
calls = DataBlock(
    (MultiCategoryBlock(vocab=CARDS), CategoryBlock(vocab=CALLS)),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_x=get_x,
    get_y=get_y)

In [15]:
dls = calls.dataloaders(ex)

dls.valid.show_batch(max_n=4, nrows=1)

3C;6C;TC;JC;QC;QD;KD;3H;6H;8H;JH;KH;KS
3C;4C;2D;3D;5D;9D;QD;KD;6H;9H;AH;7S;TS
4C;5C;7C;JC;QC;2D;6D;3H;9H;5S;7S;9S;TS
2C;TC;KC;6D;9D;JD;KD;2H;3H;7H;TH;QH;AH
1H
2D
P
1H


In [16]:
learn = Learner(dls, model, loss_func=nn.CrossEntropyLoss(), metrics=accuracy)

In [17]:
learn.fit(10)

[0, 0.2778310775756836, 0.2710099220275879, 0.8975622057914734, '01:03']
[1, 0.22101768851280212, 0.21754609048366547, 0.9166169166564941, '01:03']
[2, 0.18657149374485016, 0.1818346530199051, 0.9300000071525574, '01:02']
[3, 0.1503184586763382, 0.15930134057998657, 0.9388059973716736, '01:02']
[4, 0.13299152255058289, 0.13315653800964355, 0.9504975080490112, '01:02']
[5, 0.11138908565044403, 0.11270098388195038, 0.9613930583000183, '01:02']
[6, 0.10721228271722794, 0.09883623570203781, 0.9684079885482788, '01:03']
[7, 0.0919075757265091, 0.08849112689495087, 0.9706467390060425, '01:02']
[8, 0.07605760544538498, 0.08215095102787018, 0.9722885489463806, '01:03']
[9, 0.07031102478504181, 0.07381175458431244, 0.977661669254303, '01:03']


In [12]:
get_board_from_identifier("6-6523f1878e3deb685418c8cbb5")

{'dealer': 'E',
 'vulnerability': 'E-W',
 'hands': {'N': ['6C',
   '8C',
   'QC',
   '2D',
   '6D',
   '9D',
   '7H',
   'JH',
   'QH',
   '2S',
   '4S',
   '6S',
   '8S'],
  'E': ['2C',
   '4C',
   '5C',
   'KC',
   '3D',
   'QD',
   '4H',
   '8H',
   '9H',
   'TH',
   'KH',
   'KS',
   'AS'],
  'S': ['3C',
   '7C',
   'AC',
   '5D',
   '8D',
   'AD',
   '2H',
   '5H',
   '6H',
   'AH',
   '5S',
   '9S',
   'JS'],
  'W': ['9C',
   'TC',
   'JC',
   '4D',
   '7D',
   'TD',
   'JD',
   'KD',
   '3H',
   '3S',
   '7S',
   'TS',
   'QS']}}

In [13]:
guess = learn.predict({"board": "6-6523f1878e3deb685418c8cbb5"})
guess

("['1S', '1C', '2H', '1N', '7D', 'P', '6D', '7N', 'X', '5H', '6N', '5H', '7S', '6C', '4S', '6S', '5N', 'P', '6D', '6D', '6D', '6C', '6D', '6C', '6D', '6D', '6C', '6C', '6C', '6D', '6D', '6D', '6D', '6D', '6H', '7S', '6C', '6C']",
 TensorMultiCategory([  4.0435,   0.3618,   7.1648,   3.9504,  -7.0264,  -3.1271, -12.9049,
          -5.5908,  -2.4590, -16.0952, -10.6255, -16.3810,  -4.2499, -13.6865,
         -19.6547,  -9.3776, -15.4376,  -3.8003, -12.4712, -12.4213, -12.9135,
         -13.2232, -12.4455, -13.1009, -12.8377, -12.5504, -13.1488, -13.1714,
         -13.3027, -12.0333, -12.8069, -12.0041, -12.2694, -12.5018, -11.8812,
          -4.7238, -13.3926, -13.0429]),
 TensorMultiCategory([  4.0435,   0.3618,   7.1648,   3.9504,  -7.0264,  -3.1271, -12.9049,
          -5.5908,  -2.4590, -16.0952, -10.6255, -16.3810,  -4.2499, -13.6865,
         -19.6547,  -9.3776, -15.4376,  -3.8003, -12.4712, -12.4213, -12.9135,
         -13.2232, -12.4455, -13.1009, -12.8377, -12.5504, -13.1488, -1

In [19]:
#learn.export("100k-model.pkl")