In [1]:
import json
import os
import sys
import torch
from torch import nn
from torch.utils.data import DataLoader

from fastai.data.all import *
from fastai.learner import *
from fastai.optimizer import *
from fastai.metrics import *
from fastai.interpret import *

from game_engine import *

In [2]:
#!conda install --yes --prefix {sys.prefix} -c fastchan fastai

In [3]:
def read_records(path):
    with open(path) as f:
        return json.load(f)

def load_data(limit=None):
    DATA_DIR = "bidder-data"
    records = []
    for filename in os.listdir(DATA_DIR):
        if filename.endswith(".json"):
            path = os.path.join(DATA_DIR, filename)
            records.extend(read_records(path))
        if limit is not None and len(records) > limit:
            break
    return records

In [4]:
ex = load_data(limit = 10000)
len(ex)

10500

In [12]:
class FirstCallModel(nn.Module):
    def __init__(self):
        super(FirstCallModel, self).__init__()
        embed_dims = 32
        hidden_dims = 64
        self.network = nn.Sequential(
            nn.Linear(len(CARDS), embed_dims),
            nn.ReLU(),
            nn.Linear(embed_dims, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, len(CALLS)),
        )

    def forward(self, x):
        return self.network(x)

In [13]:
model = FirstCallModel()
print(model)

FirstCallModel(
  (network): Sequential(
    (0): Linear(in_features=52, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=38, bias=True)
  )
)


In [14]:
def get_x(record):
    board = get_board_from_identifier(record["board"])
    dealer = board["dealer"]
    return board["hands"][dealer]

def get_y(record):
    return record["calls"][0]

In [15]:
calls = DataBlock(
    (MultiCategoryBlock(vocab=CARDS), CategoryBlock(vocab=CALLS)),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_x=get_x,
    get_y=get_y)

In [16]:
dls = calls.dataloaders(ex)

dls.valid.show_batch(max_n=4, nrows=1)

6C;TC;KC;3D;TD;QD;AD;4H;8H;QH;2S;4S;JS
5C;7C;TC;KC;5D;JD;QD;AD;2H;8H;QH;2S;AS
2C;6C;8C;JC;QC;9D;TD;QD;KD;AD;2H;KH;8S
2C;4C;6C;9C;TC;JC;7D;JD;2H;8H;3S;5S;JS
P
1N
1D
P


In [23]:
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)

In [24]:
learn.fit(10)

[0, 0.30785059928894043, 0.31195273995399475, 0.8871428370475769, '00:06']
[1, 0.3044830858707428, 0.3158399164676666, 0.8895238041877747, '00:07']
[2, 0.28620657324790955, 0.2956961989402771, 0.8876190185546875, '00:06']
[3, 0.27962198853492737, 0.28526052832603455, 0.8966666460037231, '00:06']
[4, 0.2730786204338074, 0.29114019870758057, 0.8899999856948853, '00:06']
[5, 0.28293636441230774, 0.28194937109947205, 0.8904761672019958, '00:06']
[6, 0.2651417553424835, 0.28510308265686035, 0.8857142925262451, '00:06']
[7, 0.2592635750770569, 0.2771367132663727, 0.8976190686225891, '00:06']
[8, 0.2550404369831085, 0.2666270136833191, 0.8971428275108337, '00:07']
[9, 0.24958106875419617, 0.27550047636032104, 0.8942857384681702, '00:06']


In [None]:
#get_board_from_identifier("6-6523f1878e3deb685418c8cbb5")

In [25]:
guess = learn.predict({"board": "6-6523f1878e3deb685418c8cbb5"})
guess

('1H',
 TensorMultiCategory(2),
 TensorMultiCategory([5.0569e-02, 1.1625e-03, 8.0847e-01, 1.3979e-01, 3.4253e-08, 1.4969e-07,
         3.8897e-12, 6.4868e-08, 8.0385e-06, 8.9395e-14, 7.6276e-11, 2.9659e-14,
         1.8703e-10, 2.0775e-10, 5.7927e-19, 1.8504e-09, 3.3493e-15, 2.0638e-09,
         4.1511e-10, 6.5685e-10, 1.7889e-09, 5.1221e-10, 5.3024e-10, 3.6880e-10,
         2.8066e-09, 8.1370e-10, 7.6159e-10, 7.7340e-10, 2.2054e-10, 6.6506e-10,
         2.4187e-10, 4.6417e-10, 6.4745e-10, 4.1299e-10, 3.7143e-10, 9.9636e-07,
         3.6346e-10, 2.5795e-10]))

In [None]:
#learn.export("100k-model.pkl")
#learn_rst = load_learner("100k-model.pkl")
#learn_rst.predict({"board": "6-6523f1878e3deb685418c8cbb5"})