In [85]:
import einops
import torch.nn as nn
import torch.nn.functional as F
import torch
from utils import state_mapping


class SelfAttention(nn.Module):
    def __init__(self,head_size,n_embd,block_size,dropout):
        super().__init__()
        self.head_size = head_size
        self.keys = nn.Linear(n_embd,head_size,bias=False)
        self.queries = nn.Linear(n_embd,head_size,bias=False)
        self.values = nn.Linear(n_embd,head_size,bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('tril',torch.tril(torch.ones((block_size,block_size))))


    def forward(self,x):
        # x.shape = B,T,C
        B,T,C = x.shape
        k = self.keys(x)
        q = self.queries(x)
        v = self.values(x)
        qk = (q @ einops.rearrange(k,'b t c -> b c t')) * self.head_size**0.5 # (b t t)
        qk = qk.masked_fill(self.tril[:T,:T] == 0,float('-inf'))
        qk = F.softmax(qk,dim=-1)
        qk = self.dropout(qk)
        out = qk @ v # (b t c)
        return out

class MultiHeadedAttention(nn.Module):
    def __init__(self,n_heads,head_size,n_embd,dropout,block_size):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(head_size,n_embd,block_size,dropout) for _ in range(n_heads)])
        self.projection = nn.Linear(n_embd,n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        out = torch.cat([h(x) for h in self.heads],dim=-1)
        out = self.dropout(self.projection(out))
        return out

class FeedForward(nn.Module):
    def __init__(self,n_embd,dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd,4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd,n_embd),
            nn.Dropout(dropout)
        )

    def forward(self,x):
        return self.net(x)

class TransformerBlock(nn.Module):
    def __init__(self,n_embd,n_heads,dropout,block_size):
        super().__init__()
        self.head = MultiHeadedAttention(n_heads,n_embd//n_heads,n_embd,dropout,block_size)
        self.ff = FeedForward(n_embd,dropout)
        self.layernorm1 = nn.LayerNorm(n_embd)
        self.layernorm2 = nn.LayerNorm(n_embd)

    def forward(self,x):
        x = x + self.head(self.layernorm1(x))
        x = x + self.ff(self.layernorm2(x))
        return x

class Transformer(nn.Module):
    def __init__(self,n_embd,n_heads,dropout,block_size,action_size,n_layers):
        super().__init__()
        self.tblocks = nn.Sequential(
            *[TransformerBlock(n_embd,n_heads,dropout,block_size) for _ in range(n_layers)]
            )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd,action_size)
        self.emb_position = nn.Embedding(7, 6, padding_idx=0)
        self.emb_action = nn.Embedding(12, 6, padding_idx=0)

    def forward(self, state):
        """ 
        state.shape = (B,T,49)
        T = 24 (max length of game)
        Game state is a vector of length 49
        We combine the T and state dimensions -> (B,T*49)
        This becomes our B by tokens input to the transformer.
        """
        B, M, C = state.shape
        stats = state.float()
        pot = state[:, :, state_mapping["pot"]] # (B, T, 1)
        amnt_to_call = state[:, :, state_mapping["amount_to_call"]] # (B, T, 1)
        previous_amount = state[:, :, state_mapping["previous_amount"]] # (B, T, 1)
        action = self.emb_action(state[:, :, state_mapping["previous_action"]].long()) # (B, T, 6)
        position = self.emb_position(
            state[:, :, state_mapping["previous_position"]].long()
        ) # (B, T, 6)
        x = torch.cat([pot.unsqueeze(-1), amnt_to_call.unsqueeze(-1), previous_amount.unsqueeze(-1), action, position], dim=-1).float() # (B, T, 3 + 6 + 6)
        # We flatten the sequence and convert the input to a (B, 15 * 24) tensor
        x = self.tblocks(x)
        x = self.ln_f(x)
        x = self.lm_head(x)
        return x

In [90]:
import numpy as np
# from models import Simple,Transformer
from torch.optim import AdamW
import torch.nn.functional as F
from config import Config
import torch


def train_network(training_params, game_states, target_actions, target_rewards,config):
    model = Transformer(config.n_embd,config.n_heads,config.dropout,config.block_size,config.action_size,config.n_layers)
    optimizer = AdamW(model.parameters(), lr=0.003)
    for e in range(training_params["epochs"]):
        out = model(game_states)
        print("out", out.shape)
        loss = F.cross_entropy(out[:,-1,:], target_actions)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss.item())

    return model


In [99]:

config = Config()
config.n_embd = 15
config.n_heads = 3
training_params = {
    "epochs": 10,
}
game_states = torch.from_numpy(np.load("data/states.npy"))
target_actions = torch.from_numpy(np.load("data/actions.npy") - 1)
target_rewards = torch.from_numpy(np.load("data/rewards.npy"))

In [100]:
print(game_states[0].shape)
print(set(target_actions.tolist()))
print(state_mapping)

torch.Size([24, 49])
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
{'hand_range': [0, 8], 'board_range': [8, 18], 'street': 18, 'num_players': 19, 'hero_pos': 20, 'hero_active': 21, 'vil1_active': 22, 'vil2_active': 23, 'vil3_active': 24, 'vil4_active': 25, 'vil5_active': 26, 'next_player': 27, 'next_player2': 28, 'next_player3': 29, 'next_player4': 30, 'next_player5': 31, 'last_agro_amount': 32, 'last_agro_action': 33, 'last_agro_position': 34, 'last_agro_is_blind': 35, 'hero_stack': 36, 'vil1_stack': 37, 'vil2_stack': 38, 'vil3_stack': 39, 'vil4_stack': 40, 'vil5_stack': 41, 'pot': 42, 'amount_to_call': 43, 'pot_odds': 44, 'previous_amount': 45, 'previous_position': 46, 'previous_action': 47, 'previous_bet_is_blind': 48}


In [101]:

train_network(training_params, game_states, target_actions, target_rewards,config)

out torch.Size([10595, 24, 11])
2.520909547805786
out torch.Size([10595, 24, 11])
2.2540693283081055
out torch.Size([10595, 24, 11])
2.1148977279663086
out torch.Size([10595, 24, 11])
2.035886526107788
out torch.Size([10595, 24, 11])
1.975250482559204
out torch.Size([10595, 24, 11])
1.9303780794143677
out torch.Size([10595, 24, 11])
1.8890833854675293
out torch.Size([10595, 24, 11])
1.8523496389389038
out torch.Size([10595, 24, 11])
1.8216984272003174
out torch.Size([10595, 24, 11])
1.7929282188415527


# Check predictions and visualize each poker situation.