In [6]:
import json
import os
import sys
import torch
from torch import nn
from torch.utils.data import DataLoader

from fastai.data.all import *
from fastai.learner import *
from fastai.optimizer import *
from fastai.metrics import *
from fastai.interpret import *

from game_engine import *

In [2]:
#!conda install --yes --prefix {sys.prefix} -c fastchan fastai

In [4]:
def read_records(path):
    with open(path) as f:
        return json.load(f)

def load_data(limit=None):
    DATA_DIR = "bidder-data"
    records = []
    for filename in os.listdir(DATA_DIR):
        if filename.endswith(".json"):
            path = os.path.join(DATA_DIR, filename)
            records.extend(read_records(path))
        if limit is not None and len(records) > limit:
            break
    return records

In [5]:
ex = load_data(limit = 10000)
len(ex)

10500

In [35]:
class FirstCallModel(nn.Module):
    def __init__(self):
        super(FirstCallModel, self).__init__()
        embed_dims = 64
        hidden_dims = 64
        # self.embed = nn.EmbeddingBag(len(CARDS), embed_dims)
        self.stack = nn.Sequential(
            nn.Linear(len(CARDS), hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, len(CALLS)),
        )

    def forward(self, x):
        # features = self.embed(x)
        return self.stack(x)


In [36]:
model = FirstCallModel()
print(model)

FirstCallModel(
  (stack): Sequential(
    (0): Linear(in_features=52, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=38, bias=True)
  )
)


In [37]:
def get_x(record):
    board = get_board_from_identifier(record["board"])
    dealer = board["dealer"]
    return board["hands"][dealer]

def get_y(record):
    return record["calls"][0]

In [38]:
calls = DataBlock(
    (MultiCategoryBlock(vocab=CARDS), CategoryBlock(vocab=CALLS)),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_x=get_x,
    get_y=get_y)

In [39]:
dls = calls.dataloaders(ex)

dls.valid.show_batch(max_n=4, nrows=1)

6C;TC;KC;3D;TD;QD;AD;4H;8H;QH;2S;4S;JS
5C;7C;TC;KC;5D;JD;QD;AD;2H;8H;QH;2S;AS
2C;6C;8C;JC;QC;9D;TD;QD;KD;AD;2H;KH;8S
2C;4C;6C;9C;TC;JC;7D;JD;2H;8H;3S;5S;JS
P
1N
1D
P


In [40]:
learn = Learner(dls, model, loss_func=nn.CrossEntropyLoss(), metrics=accuracy)

In [41]:
learn.fit(10)

[0, 1.6102399826049805, 1.3720250129699707, 0.5533333420753479, '00:07']
[1, 1.1104165315628052, 0.9801210165023804, 0.6819047331809998, '00:07']
[2, 0.8262050151824951, 0.751828670501709, 0.7900000214576721, '00:07']
[3, 0.6717875003814697, 0.6250484585762024, 0.8447619080543518, '00:08']
[4, 0.5737388730049133, 0.5501239895820618, 0.8409523963928223, '00:07']
[5, 0.511540412902832, 0.49650174379348755, 0.8547618985176086, '00:08']
[6, 0.471888929605484, 0.4637906849384308, 0.8566666841506958, '00:07']
[7, 0.44418567419052124, 0.434711754322052, 0.8666666746139526, '00:07']
[8, 0.4098462760448456, 0.4091794788837433, 0.8709523677825928, '00:07']
[9, 0.39280152320861816, 0.3977205157279968, 0.8633333444595337, '00:07']


In [29]:
get_board_from_identifier("6-6523f1878e3deb685418c8cbb5")

{'dealer': 'E',
 'vulnerability': 'E-W',
 'hands': {'N': ['6C',
   '8C',
   'QC',
   '2D',
   '6D',
   '9D',
   '7H',
   'JH',
   'QH',
   '2S',
   '4S',
   '6S',
   '8S'],
  'E': ['2C',
   '4C',
   '5C',
   'KC',
   '3D',
   'QD',
   '4H',
   '8H',
   '9H',
   'TH',
   'KH',
   'KS',
   'AS'],
  'S': ['3C',
   '7C',
   'AC',
   '5D',
   '8D',
   'AD',
   '2H',
   '5H',
   '6H',
   'AH',
   '5S',
   '9S',
   'JS'],
  'W': ['9C',
   'TC',
   'JC',
   '4D',
   '7D',
   'TD',
   'JD',
   'KD',
   '3H',
   '3S',
   '7S',
   'TS',
   'QS']}}

In [42]:
guess = learn.predict({"board": "6-6523f1878e3deb685418c8cbb5"})

In [43]:
type(guess[0])

fastai.data.transforms.Category

In [44]:
learn.dls.vocab

(#2) [['2C', '3C', '4C', '5C', '6C', '7C', '8C', '9C', 'TC', 'JC', 'QC', 'KC', 'AC', '2D', '3D', '4D', '5D', '6D', '7D', '8D', '9D', 'TD', 'JD', 'QD', 'KD', 'AD', '2H', '3H', '4H', '5H', '6H', '7H', '8H', '9H', 'TH', 'JH', 'QH', 'KH', 'AH', '2S', '3S', '4S', '5S', '6S', '7S', '8S', '9S', 'TS', 'JS', 'QS', 'KS', 'AS'],['1C', '1D', '1H', '1N', '1S', '2C', '2D', '2H', '2N', '2S', '3C', '3D', '3H', '3N', '3S', '4C', '4D', '4H', '4N', '4S', '5C', '5D', '5H', '5N', '5S', '6C', '6D', '6H', '6N', '6S', '7C', '7D', '7H', '7N', '7S', 'P', 'X', 'XX']]

In [45]:
guess

("['1H', '1C', '1S', '1H', '7N', 'P', '7D', 'X', 'X', '6S', '7N', '6S', 'X', '6N', '6N', '7D', '7C', '7N', '6N', '6N', '6N', '6N', '6N', '6N', '6N', '6N', '6N', '6N', '6N', '6N', '6N', '6N', '6N', '6S', '6N', 'X', '6N', '6N']",
 TensorMultiCategory([  2.9963,  -0.4025,   4.5483,   2.3362,  -5.2607,  -3.0720,  -7.0893,
          -2.3715,  -2.3027,  -9.9483,  -5.9619,  -9.0173,  -2.3082, -10.4847,
         -10.7180,  -7.1170,  -8.8305,  -5.0130, -10.3853, -10.1604, -10.6705,
         -10.4270, -10.0458, -10.1629, -10.0986, -10.0894, -10.7580, -10.5821,
         -10.2830, -10.0075, -10.4226, -10.0135, -10.3547,  -9.8596, -10.3355,
          -2.0653, -10.3591, -10.3004]),
 TensorMultiCategory([  2.9963,  -0.4025,   4.5483,   2.3362,  -5.2607,  -3.0720,  -7.0893,
          -2.3715,  -2.3027,  -9.9483,  -5.9619,  -9.0173,  -2.3082, -10.4847,
         -10.7180,  -7.1170,  -8.8305,  -5.0130, -10.3853, -10.1604, -10.6705,
         -10.4270, -10.0458, -10.1629, -10.0986, -10.0894, -10.7580, -10.