In [1]:
#import sys
#!conda install --yes --prefix {sys.prefix} -c fastchan fastai

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader

In [2]:
import json
import os

In [3]:
def read_records(path):
    with open(path) as f:
        return json.load(f)

In [5]:
def load_data(limit=None):
    DATA_DIR = "bidder-data"
    records = []
    for filename in os.listdir(DATA_DIR):
        if filename.endswith(".json"):
            path = os.path.join(DATA_DIR, filename)
            records.extend(read_records(path))
        if limit is not None and len(records) > limit:
            break
    return records

In [6]:
ex = load_data(limit = 10000)

In [7]:
len(ex)

10500

In [8]:
from game_engine import *

In [9]:
from fastai.data.all import *
from fastai.learner import *
from fastai.optimizer import *
from fastai.metrics import *

In [10]:
class FirstCallModel(nn.Module):
    def __init__(self):
        super(FirstCallModel, self).__init__()
        embed_dims = 64
        hidden_dims = 64
        self.stack = nn.Sequential(
            nn.Linear(len(CARDS), embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, len(CALLS)),
        )

    def forward(self, x):
        logits = self.stack(x)
        return logits

In [11]:
model = FirstCallModel()

In [12]:
print(model)

FirstCallModel(
  (stack): Sequential(
    (0): Linear(in_features=52, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=38, bias=True)
  )
)


In [13]:
def get_x(record):
    board = get_board_from_identifier(record["board"])
    dealer = board["dealer"]
    return board["hands"][dealer]

In [14]:
def get_y(record):
    return record["calls"][0]

In [15]:
calls = DataBlock(
    (MultiCategoryBlock(vocab=CARDS), CategoryBlock(vocab=CALLS)),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_x=get_x,
    get_y=get_y)

In [16]:
dls = calls.dataloaders(ex)

In [17]:
dls.valid.show_batch(max_n=4, nrows=1)

6C;TC;KC;3D;TD;QD;AD;4H;8H;QH;2S;4S;JS
5C;7C;TC;KC;5D;JD;QD;AD;2H;8H;QH;2S;AS
2C;6C;8C;JC;QC;9D;TD;QD;KD;AD;2H;KH;8S
2C;4C;6C;9C;TC;JC;7D;JD;2H;8H;3S;5S;JS
P
1N
1D
P


In [18]:
learn = Learner(dls, model, loss_func=nn.CrossEntropyLoss(), metrics=accuracy)

In [19]:
learn.fit(10)

[0, 1.5828391313552856, 1.3902479410171509, 0.5533333420753479, '00:08']
[1, 1.0506082773208618, 0.8968051075935364, 0.7433333396911621, '00:09']
[2, 0.7192568182945251, 0.6471738815307617, 0.8114285469055176, '00:08']
[3, 0.5653246641159058, 0.5014913082122803, 0.8628571629524231, '00:09']
[4, 0.4592345356941223, 0.43888789415359497, 0.8557142615318298, '00:08']
[5, 0.4077032506465912, 0.38646945357322693, 0.8761904835700989, '00:09']
[6, 0.3714393675327301, 0.36285993456840515, 0.8704761862754822, '00:09']
[7, 0.34942030906677246, 0.34484753012657166, 0.8723809719085693, '00:08']
[8, 0.328656941652298, 0.3403625190258026, 0.8752381205558777, '00:08']
[9, 0.3214888572692871, 0.317013680934906, 0.8899999856948853, '00:08']
