In [1]:
import numpy as np

def getSuit(card):
    """Return the suit index of a card id (or of each id in a numpy array).

    Card ids are 0..35 with 9 consecutive ids per suit, so the suit is the
    integer part of ``card / 9``.
    """
    quotient = card / 9
    return np.floor(quotient).astype('int')

def getRank(card):
    """Return the rank (0..8) of a card id within its suit; works elementwise on arrays."""
    remainder = card % 9
    return np.floor(remainder).astype('int')
        
def getLastAction(player):
    """Return the verb of ``player``'s most recent action, or None if it has none yet."""
    if not player.actions:
        return None
    return player.actions[-1].verb

def getNumCardsOnBoard(board):
    """Count the physical cards on ``board``.

    ``board`` is a list of pairs such as ``[attack, cover]``. An uncovered
    attack is stored as ``[attack, None]``, and the ``None`` slot is not a
    card, so only non-None entries are counted. Pairs of any length (e.g. a
    bare ``[attack]``) are also accepted.
    """
    return sum(1 for pair in board for c in pair if c is not None)

def getNumNotCovered(board):
    """Count the attacks on ``board`` that have not been covered yet.

    ``board`` is a list of ``[attack, cover]`` pairs. A pair is uncovered
    when its cover slot is ``None``.
    """
    return sum(1 for _, cover in board if cover is None)

def getNotCoveredBoardId(board, target):
    """Return the index of the uncovered pair whose attack card is ``target``.

    ``board`` is a list of ``[attack, cover]`` pairs. Returns -1 when no
    uncovered pair has ``target`` as its attack card, including the case
    where the matching attack is already covered.
    """
    return next(
        (i for i, (attack, cover) in enumerate(board)
         if attack == target and cover is None),
        -1,
    )

def boardAttacksAllSameRank(board):
    """Return True when every attack card on ``board`` shares a single rank.

    An empty board counts as "all same rank". This used to be an assert;
    returning True means the learning penalty that depends on this check
    does not fire on an empty table.
    """
    if not board:
        return True
    firstRank = getRank(board[0][0])
    return all(getRank(attack) == firstRank for attack, _ in board)

def beats(over, under, trump):
    """Return whether card ``over`` beats card ``under`` given the trump card.

    Within the same suit the higher rank wins. Across suits only a card of
    the trump's suit wins.
    """
    overSuit = getSuit(over)
    if overSuit != getSuit(under):
        return bool(overSuit == getSuit(trump))
    return getRank(over) > getRank(under)

def allPass(players):
    """Return True when every player's most recent action was 'pass'.

    A player with no actions yet has not passed (``getLastAction`` returns
    None for them). An empty ``players`` sequence returns True.
    """
    # The original looped over the misspelled name ``playrs``, so it always
    # raised NameError.
    return all(getLastAction(p) == 'pass' for p in players)

def rankOnBoard(board, card):
    """Return True if ``card``'s rank already appears anywhere on ``board``.

    ``board`` is a list of ``[attack, cover]`` pairs. A rank counts as
    present if it matches either the attack card or the cover card.
    Uncovered pairs have ``cover is None`` and contribute only their attack.
    A ``None`` card never matches.
    """
    if card is None:
        return False
    r = getRank(card)
    for attack, cover in board:
        # ``or``: a rank on either card of the pair counts. The previous
        # ``and`` required both to match, and called getRank(None) whenever an
        # uncovered attack matched.
        if getRank(attack) == r or (cover is not None and getRank(cover) == r):
            return True
    return False

class Player:
    """A seat at the table: the cards currently held and the actions taken."""

    def __init__(self):
        self.actions = []
        self.hand = []

    def __str__(self):
        # Show the raw card ids followed by their ranks.
        ranks = getRank(np.array(self.hand))
        return f'{self.hand} {ranks}'
        
class Action:
    """One proposed move: who acts, against whom, which verb, and which cards.

    ``player`` and ``defender`` are seat indices into ``Game.players``.
    ``card`` is the card being used and ``target`` is the board card it is
    aimed at. Either may be ``None`` when the verb does not use it.
    """

    def __init__(self, player, defender, verb, card=None, target=None):
        # Participants first, then the move itself.
        self.player, self.defender = player, defender
        self.verb = verb
        self.card, self.target = card, target

class Game:
    """State and reward shaping for a game of Durak with a 36-card deck.

    Cards are ints 0..35 (see ``getSuit``/``getRank``). ``trump`` is
    ``deck[0]``. Cards are drawn from the end of the deck array, so the
    trump card is dealt last. ``board`` is a list of ``[attack, cover]``
    pairs, where ``cover`` is ``None`` until the attack is beaten.
    """

    def __init__(self, nPlayers):
        self.players = [Player() for _ in range(nPlayers)]
        self.deck = np.arange(36)
        np.random.shuffle(self.deck)
        self.discard = []
        self.board = []
        self.trump = self.deck[0]
        # Provisional seats so deal() can compute a play direction; the real
        # first attacker is chosen from the dealt hands right after.
        self.attacker = self.players[0]
        self.defender = self.players[1]
        self.deal()
        self.attacker = self.getFirstAttacker()
        self.defender = self.players[(self.players.index(self.attacker)+1)%len(self.players)]
        self.turn = 0

    def __str__(self):
        return f'''{len(self.players)} players {len(self.deck)} cards in deck {self.trump} {getSuit(self.trump)} trump
{self.board}
{[str(p) for p in self.players]}
{self.players.index(self.attacker)} attacker {self.players.index(self.defender)} defender
'''

    def deal(self):
        """Top every hand up to six cards, drawing from the end of the deck.

        Seats are served starting with the defender and moving in the
        current play direction. Each hand gets at most as many cards as
        remain in the deck. Hands already holding six or more get nothing.
        """
        d = self.getDirection()
        n = len(self.players)
        # Offset seat indices by n so stepping backwards (d == -1) never
        # produces a negative slice bound; ``% n`` maps back to real seats.
        start = self.players.index(self.defender)+n
        dealOrder = np.arange(3*n)[start:start+d*n:d]
        for i in dealOrder:
            h = self.players[i%n].hand
            # Clamp to what is left. The old reversed slice
            # deck[len-1:len-m-1:-1] wrapped around to -1 when m == len(deck),
            # handing out nothing while still emptying the deck.
            m = max(0, min(6-len(h), len(self.deck)))
            # Last m cards, taken from the end inward (same order as before).
            h.extend(self.deck[len(self.deck)-m:][::-1])
            self.deck = self.deck[:len(self.deck)-m]

    def randomState(self):
        """Overwrite the game with a random state for training.

        Picks a random attacker with an adjacent defender in a random
        direction. Draws a freshly shuffled deck and trump. Deals hands of 6
        cards, or of 0..6 cards with probability 0.2. Lays 0..4 random board
        pairs; their covers are not checked for legality.
        """
        # Choose attacker and defender
        idx = np.random.randint(0,len(self.players))
        direc = np.random.randint(0,2)*2-1
        self.attacker = self.players[idx]
        self.defender = self.players[(idx+direc)%len(self.players)]
        # Random deck, trump, and random number of cards
        self.deck = np.arange(36)
        np.random.shuffle(self.deck)
        self.trump = self.deck[0]
        for i in range(len(self.players)):
            self.players[i].hand = []
            if np.random.rand() < 0.2:
                nc = np.random.randint(0,7)
            else:
                # Always 6; kept as a random draw so the RNG stream is unchanged.
                nc = np.random.randint(6,7)
            self.players[i].hand[:nc] = self.deck[len(self.deck)-1:len(self.deck)-1-nc:-1]
            self.deck = self.deck[:len(self.deck)-nc]
        # Random board (actual covers don't make sense)
        nc = np.random.randint(0,5)
        self.board = []
        for i in range(nc):
            if np.random.randint(0,2) == 0:
                self.board.append([self.deck[-1], None])
                self.deck = self.deck[:-1]
            else:
                self.board.append([self.deck[-1], self.deck[-2]])
                self.deck = self.deck[:-2]

    def getDirection(self):
        """Return +1 or -1: the seat step from the attacker to the defender.

        Raises:
            ValueError: if the attacker and defender are not adjacent seats.
        """
        i = self.players.index(self.defender)
        j = self.players.index(self.attacker)
        d = i-j
        if abs(d) == 1:
            return d
        # Wrap-around neighbours at the ends of the seat list.
        if i == 0 and j == len(self.players)-1:
            return 1
        if i == len(self.players)-1 and j == 0:
            return -1
        # Was ``assert False``, which disappears under ``python -O``.
        raise ValueError('attacker and defender are not adjacent seats')

    def getFirstAttacker(self):
        """Pick who attacks first.

        Chooses the holder of the lowest trump card. If nobody holds a trump,
        chooses the holder of the lowest-ranked card.
        """
        s = getSuit(self.trump)
        m = np.zeros([len(self.players),2])
        for i,p in enumerate(self.players):
            h = np.array(p.hand)
            m[i,1] = np.amin(h%9)
            # Mask non-trumps with a sentinel larger than any card id.
            h[getSuit(h) != s] = 100
            m[i,0] = np.amin(h)
        if np.min(m[:,0]) < 100:
            return self.players[np.argmin(m[:,0])]
        else:
            return self.players[np.argmin(m[:,1])]

    def action(self, e):
        """Score a proposed action ``e`` without applying it to the game.

        Returns a length-5 float32 tensor with one slot per part of the action:
          [0] valid player index
          [1] names the current defender
          [2] verb allowed for this player's role
          [3] card plausibility
          [4] target plausibility
        A slot is 10 when that part looks right, with -2 penalties for
        specific rule violations. The game state is not modified; the state
        updates are left as commented-out TODOs.
        """
        # torch is only imported in a later notebook cell; import it here so
        # this method works on its own.
        import torch
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        reward = torch.zeros(5, dtype=torch.float32, device=device)
        if e.player < len(self.players):
            reward[0] = 10
        else:
            return reward
        p = self.players[e.player]
        if e.defender == self.players.index(self.defender):
            reward[1] = 10
        if p == self.defender:
            # A defender who has already picked up may not act again.
            if getLastAction(p) == 'pickup':
                return reward
            if e.verb in ['reverse', 'cover', 'pickup']:
                reward[2] = 10
            if e.verb == 'reverse':
                if p.hand.count(e.card) > 0:
                    reward[3] = 10
                if e.target is None:
                    reward[4] = 10
                if len(self.board) == 0:
                    reward[2] -= 2
                # Cannot reverse once any attack has been covered.
                if getNumNotCovered(self.board) < len(self.board):
                    reward[2] -= 2
                if not boardAttacksAllSameRank(self.board):
                    reward[2] -= 2
                # Penalise a reverse card whose rank differs from the first
                # attack. A missing card never matches (and getRank(None)
                # would raise).
                if len(self.board) > 0 and (e.card is None or getRank(self.board[0][0]) != getRank(e.card)):
                    reward[3] -= 2
                # TODO: apply the move
                #board.append([e.card])
                #p.hand.remove(e.card)
                #self.attacker, self.defender = self.defender, self.attacker
                #p.actions.append(e)
            elif e.verb == 'cover':
                if p.hand.count(e.card) > 0:
                    reward[3] = 10
                if e.target is not None:
                    reward[4] = 10
                if e.card is not None and e.target is not None and not beats(e.card, e.target, self.trump):
                    reward[3] -= 2
                    reward[4] -= 2
                boardId = getNotCoveredBoardId(self.board, e.target)
                if boardId == -1:
                    reward[4] -= 2
                # TODO: apply the move
                #self.board[boardId][1] = e.card
                #p.hand.remove(e.card)
                #p.actions.append(e)
            elif e.verb == 'pickup':
                if e.card is None:
                    reward[3] = 10
                if e.target is None:
                    reward[4] = 10
                # TODO: apply the move
                #p.actions.append(e)
        else:
            if e.verb in ['play', 'pass']:
                reward[2] = 10
            if e.verb == 'play':
                if p.hand.count(e.card) > 0:
                    reward[3] = 10
                if e.target is None:
                    reward[4] = 10
                # More open attacks than the defender has cards to answer.
                if getNumNotCovered(self.board) > len(self.defender.hand):
                    reward[2] -= 2
                # Acceptable: the attacker opening an empty board, or a card
                # for which rankOnBoard holds. A None card is never
                # acceptable, and is not passed to rankOnBoard.
                opening = len(self.board) == 0 and p == self.attacker
                if not (opening or (e.card is not None and rankOnBoard(self.board, e.card))):
                    reward[3] -= 2
                # TODO: apply the move
                #p.hand.remove(e.card)
                #p.actions.append(e)
            elif e.verb == 'pass':
                if e.card is None:
                    reward[3] = 10
                if e.target is None:
                    reward[4] = 10
                # TODO: apply the move
                #p.actions.append(e)
        return reward
    
# Deal a fresh 2-player game and print it, then overwrite it with a random
# state (random seats, hands and board) and print that too.
g = Game(2)
print(g)

g.randomState()
print(g)

2 players 24 cards in deck 4 0 trump
[]
['[26, 15, 8, 5, 20, 10] [8 6 8 5 2 1]', '[2, 24, 16, 28, 13, 23] [2 6 7 1 4 5]']
1 attacker 0 defender

2 players 23 cards in deck 20 2 trump
[[2, None]]
['[14, 19, 7, 4, 27, 0] [5 1 7 4 0 0]', '[12, 13, 18, 3, 16, 17] [3 4 0 3 7 8]']
0 attacker 1 defender



In [4]:
# Score a sample 'play' action (player 0, defender 1, card 32) against the
# current state. Game.action returns the 5-slot reward and does not change g.
g.action(Action(0,1,'play',32))

tensor([10., 10., 10., -2., 10.], device='cuda:0')

In [5]:
# Inspect the table: a list of [attack, cover] pairs (cover is None while unbeaten).
g.board

[[2, None]]

In [321]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn.dense import DenseGraphConv

def maskDiag(mat):
    """Return ``mat`` with its main diagonal zeroed (elementwise mask).

    The mask is a float32 ``mat.shape[0] x mat.shape[0]`` matrix, so ``mat``
    is expected to be square. It is built on ``mat``'s device rather than
    forced onto CUDA, so CPU and GPU inputs both work.
    """
    eye = torch.eye(mat.shape[0], dtype=torch.float32, device=mat.device)
    return mat*(1-eye)

class DurakBoard(nn.Module):
    """Encode the table (a list of [attack, cover] pairs) into one vector.

    Each pair becomes a 74-wide row:
      - a 1 at column ``attack`` in the first 37 columns;
      - a 1 at ``37 + cover`` in the second 37 columns, or at ``37 + 36``
        while the pair is uncovered.
    An all-ones readout row is appended. Every row is linked to every other
    row (no self-loops) with weight 1/rows. One DenseGraphConv layer with
    200 output channels mixes the rows, and the readout row's output is
    returned.
    """
    def __init__(self):
        super().__init__()
        # Use the GPU when present instead of failing on CPU-only machines.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # p is short for pairs
        self.p0 = DenseGraphConv(37+37,200).float().to(device)

    def forward(self, board):
        # Build inputs wherever the parameters live (follows .to()/.cuda()).
        device = next(self.parameters()).device
        nRows = len(board)+1
        p = torch.zeros(nRows, 37*2, dtype=torch.float32, device=device)

        for i,(a,b) in enumerate(board):
            p[i,a] = 1
            p[i,37+(36 if b is None else b)] = 1

        # Last position is readout
        p[-1,:] = 1

        # Fully connected adjacency without self-loops, normalised by the row
        # count (equivalent to maskDiag(ones) / nRows).
        Ap = (1-torch.eye(nRows, dtype=torch.float32, device=device))/nRows

        p = self.p0(p,Ap).squeeze(0)

        return p[-1]
        
class DurakQueryCardsBoard(nn.Module):
    """Linear head mapping a 200-d board embedding to 36 per-card scores."""
    def __init__(self):
        super().__init__()
        # Use the GPU when present instead of failing on CPU-only machines.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.fc0 = nn.Linear(200,36).float().to(device)

    def forward(self, board):
        return self.fc0(board)

# Instantiate the board encoder and the card-query head, and optimise both
# with one Adam optimiser (one parameter group per module).
durakBoard = DurakBoard()
durakQueryCardsBoard = DurakQueryCardsBoard()

optim = torch.optim.Adam([{'params': mod.parameters()} for mod in [durakBoard, durakQueryCardsBoard]], 
                         lr=1e-4, weight_decay=0)

print('Complete')

Complete


In [330]:
# Save the whole module objects (not state_dicts).
# NOTE(review): loading these back presumably needs the class definitions
# importable under the same names — confirm before relying on the files.
torch.save(durakBoard, 'DurakBoard.pyt')
torch.save(durakQueryCardsBoard, 'DurakQueryCardsBoard.pyt')

print('Complete')

Complete


In [317]:
def recalcTgt(board=None):
    """Build the card-presence target for a board, plus two fixed helper tensors.

    Args:
        board: list of ``[attack, cover]`` pairs (cover may be None).
            Defaults to the global game ``g``'s board, as before.

    Returns:
        tgt: (1, 36) float tensor, 1 where a card id appears anywhere on the
            board (attack or cover), else 0.
        card: (36, 36) identity matrix (one one-hot row per card id).
        location: (36, 1) tensor of ones (a single location slot).
    """
    if board is None:
        board = g.board
    # Use the GPU when present instead of failing on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tgt = torch.zeros(1, 36, dtype=torch.float32, device=device)
    # One pass over the board instead of rescanning it for each of 36 cards.
    for pair in board:
        for c in pair:
            if c is not None:
                tgt[0, int(c)] = 1
    card = torch.eye(36, dtype=torch.float32, device=device)
    location = torch.ones(36, 1, dtype=torch.float32, device=device)
    return tgt, card, location

# Target tensors for the current global game's board.
tgt, card, location = recalcTgt()

def multiCrossEntropyLoss(inp, tgt):
    """Multi-label loss: log(1 + sum of per-element "wrongness" ratios).

    For logit x and target t in [0, 1], each element contributes

        (t*e^-x + (1-t)*e^x) / (t*e^x + (1-t)*e^-x),

    which for binary t is e^(-2x) when t=1 and e^(2x) when t=0. The result
    is log(1 + sum over all elements), a scalar. ``inp`` and ``tgt`` must be
    broadcastable.

    Evaluated in log space. The direct form overflows exp(x) to inf for
    |x| above about 88 in float32, producing inf*0 = nan.
    """
    logT = torch.log(tgt)
    logNotT = torch.log1p(-tgt)
    # log of numerator (C + D) and denominator (A + B) of the ratio above.
    logNum = torch.logaddexp(logT-inp, logNotT+inp)
    logDen = torch.logaddexp(logT+inp, logNotT-inp)
    logRatio = (logNum-logDen).reshape(-1)
    # log(sum(ratio) + 1) == logsumexp of the log-ratios plus a zero entry.
    return torch.logsumexp(torch.cat([logRatio, logRatio.new_zeros(1)]), dim=0)

# Sanity checks on the target tensors: number of cards on the board, the
# target itself, its softmax, the loss of the target scored against itself,
# and the sizes of the fixed card / location tensors.
print(torch.sum(tgt))
print(tgt)
print(tgt.softmax(dim=1))
print(multiCrossEntropyLoss(tgt, tgt))
print(torch.sum(card))
print(torch.sum(location))

tensor(4., device='cuda:0')
tensor([[0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       device='cuda:0')
tensor([[0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0634, 0.0233, 0.0634, 0.0233,
         0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0233,
         0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0634, 0.0634, 0.0233,
         0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0233, 0.0233]],
       device='cuda:0')
tensor(3.5128, device='cuda:0')
tensor(36., device='cuda:0')
tensor(36., device='cuda:0')


In [318]:
# Turn the 0/1 target into confident logits (+5 for present cards, -5 for
# absent ones) and check that the loss is near zero for such a prediction.
tgta = tgt.detach().clone()
tgta[tgt == 0] = -5
tgta[tgt > 0] = 5
print(multiCrossEntropyLoss(tgta, tgt))
print(tgt)
print(tgta)

tensor(0.0016, device='cuda:0')
tensor([[0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       device='cuda:0')
tensor([[-5., -5., -5., -5., -5.,  5., -5.,  5., -5., -5., -5., -5., -5., -5.,
         -5., -5., -5., -5., -5., -5., -5., -5., -5., -5.,  5.,  5., -5., -5.,
         -5., -5., -5., -5., -5., -5., -5., -5.]], device='cuda:0')


In [306]:
# Draw a new random state, rebuild the target, and show the board and the
# loss of the target scored against itself.
g.randomState()
tgt, card, location = recalcTgt()
print(g.board)
print(multiCrossEntropyLoss(tgt, tgt))

[[20, None], [35, None], [34, 14]]
tensor(0.8397, device='cuda:0')


In [307]:
# Show which card ids the current target marks as on the board.
print(tgt)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]],
       device='cuda:0')


In [323]:
# Train the encoder + query head to recover which cards are on the table.
# The same board is trained on until its loss drops below 1e-2, then a new
# random state is drawn. nReset counts how many boards were cleared in each
# reporting period.
nEpochs = 500_000
running = []  # NOTE(review): unused while the running-average lines below are commented out
pPeriod = 1000
window = 100
nReset = 0

lossFn = nn.CrossEntropyLoss()  # NOTE(review): unused; multiCrossEntropyLoss is used instead

for epoch in range(nEpochs):
    optim.zero_grad()
    b = durakBoard(g.board)
    # unsqueeze(0) gives (1, 36), matching the (1, 36) target from recalcTgt.
    res = durakQueryCardsBoard(b).unsqueeze(0)
    loss = multiCrossEntropyLoss(res, tgt)
    loss.backward()
    optim.step()
#     running.append(loss)
#     if len(running) == window:
#         running.pop(0)
    if epoch % pPeriod == 0 or epoch == nEpochs-1:
        print(f'epoch {epoch} loss {loss} nReset {nReset}')
        nReset = 0
    # Board learned well enough: move on to a fresh random board.
    if loss < 1e-2:
        g.randomState()
        tgt, card, location = recalcTgt()
        nReset += 1

epoch 0 loss 0.0002574589161667973 nReset 0
epoch 1000 loss 3.2543604902457446e-05 nReset 973
epoch 2000 loss 2.5033637939486653e-05 nReset 988
epoch 3000 loss 2.658331868587993e-05 nReset 997
epoch 4000 loss 0.0005138983833603561 nReset 965
epoch 5000 loss 0.0002686616498976946 nReset 999
epoch 6000 loss 1.4781842764932662e-05 nReset 1000
epoch 7000 loss 0.002256944077089429 nReset 1000
epoch 8000 loss 0.001392586505971849 nReset 1000
epoch 9000 loss 0.001650877296924591 nReset 1000
epoch 10000 loss 0.00016544880054425448 nReset 1000


KeyboardInterrupt: 

In [328]:
# Spot-check on a fresh random board: print the predicted scores stacked
# above the 0/1 target, then the loss.
g.randomState()
b = durakBoard(g.board).unsqueeze(0)
tgt, card, location = recalcTgt()
res = durakQueryCardsBoard(b)
print(torch.cat([res, tgt], dim=0))
print(multiCrossEntropyLoss(res, tgt))

tensor([[-5.7436, -6.0122, -6.1552, -5.3784, -6.1279, -5.5721, -5.6316, -6.6657,
         -4.7557, -6.0169,  7.8229, -5.4099, -5.3436, -5.6879, -6.1757,  6.5619,
         -6.2521, -5.8146, -6.5479,  7.2914, -5.8215, -6.0255, -6.3573, -5.8374,
          7.6083, -5.4972, -6.0702, -6.3186,  6.6246, -6.5314, -5.0410, -5.6181,
         -6.5129, -5.8469, -5.5663,  7.4019],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,
          0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          1.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  1.0000]], device='cuda:0',
       grad_fn=<CatBackward0>)
tensor(0.0004, device='cuda:0', grad_fn=<LogBackward0>)


In [None]:
from torch.distributions import Categorical, OneHotCategorical

# Verb vocabulary; actions sample and encode their verb as an index into this list.
verbs = ['reverse','cover','pickup','play','pass']

def makeAction(x):
    """Sample a full Action from the flat logit vector ``x``.

    ``x`` is cut into five consecutive sections:
      - acting player (6)
      - defender (6)
      - verb (5)
      - card (37)
      - target (the rest)
    One Categorical distribution is built per section and sampled in that
    order. Index 36 of the card or target section means "none" and maps to
    ``None``.

    Returns:
        (Action, the five distributions, the five sampled index tensors).
        The last two let callers compute per-part log-probabilities.
    """
    bounds = [0, 6, 12, 17, 17+37, None]
    dists = [Categorical(logits=x[lo:hi]) for lo, hi in zip(bounds[:-1], bounds[1:])]
    samples = [dist.sample() for dist in dists]
    player, defender, verbIdx, cardIdx, targetIdx = samples
    card = None if int(cardIdx) == 36 else int(cardIdx)
    target = None if int(targetIdx) == 36 else int(targetIdx)
    return Action(int(player), int(defender), verbs[verbIdx], card, target), dists, samples

def makeActionOneHot(act):
    """Encode ``act`` as a flat 91-wide one-hot vector.

    Sections are laid out in order: player (6), defender (6), verb (5),
    card (37), target (37). A ``None`` card or target is encoded at index 36
    of its section.
    """
    x = torch.zeros(6+6+5+37+37).float().cuda()
    offsets = (0, 6, 12, 17, 17+37)
    slots = (
        lambda: act.player,
        lambda: act.defender,
        lambda: verbs.index(act.verb),
        lambda: 36 if act.card is None else act.card,
        lambda: 36 if act.target is None else act.target,
    )
    # Resolve each index lazily so any error surfaces in the original order.
    for offset, slot in zip(offsets, slots):
        x[offset+slot()] = 1
    return x

class History:
    """Novelty penalty over the sampled parts of an action.

    Keeps one decaying count table per action part (player, defender, verb,
    card, target). Rewards are penalised for choosing the same value
    repeatedly.

    Args:
        discount: penalty per unit of accumulated count.
        forget: multiplicative decay applied to every stored count on each
            ``process`` call.
        nParts: number of action parts tracked (5 by default).
    """
    def __init__(self, discount, forget, nParts=5):
        # One independent dict per part. ``5*[{}]`` made all parts share a
        # single dict, mixing e.g. player 0 with verb 0.
        self.hist = [{} for _ in range(nParts)]
        self.discount = discount
        self.forget = forget

    def process(self, reward, act):
        """Return ``reward`` minus ``discount`` times the summed counts of ``act``'s parts.

        Each part's table is decayed by ``forget`` first. A value seen for
        the first time is recorded with count 0 (no penalty); a repeat adds 1.
        """
        disc = 0
        for a, counts in zip((int(part) for part in act), self.hist):
            for k in counts:
                counts[k] *= self.forget
            if a in counts:
                counts[a] += 1
            else:
                counts[a] = 0
            disc += counts[a]
        return reward - disc*self.discount

# REINFORCE-style training of the action generator against Game.action's
# shaped rewards, with a repetition penalty from History.
# NOTE(review): `moves` and `gen` (with gen.lookup / gen.add) are not defined
# in the visible source; their behaviour here is assumed, not verified.
# NOTE(review): the visible `optim` was built over durakBoard /
# durakQueryCardsBoard parameters only — confirm it also covers `gen`.
# NOTE(review): Game.action only scores moves, so `g` never advances here.
nEpochs = 10_000
running = []
pPeriod = 100
window = 50
hist = History(2,0.95)

for i in range(nEpochs):
    optim.zero_grad()
    xb, x = moves(g.board, g.attacker.hand, g.players.index(g.attacker), g.players.index(g.defender))
    x = gen(xb)  # NOTE(review): overwrites the x returned by moves()
    action, mod, act = makeAction(x)
    actOneHot = makeActionOneHot(action)
    actionReward = g.action(action).detach()
    reward = torch.sum(actionReward)
    # 50 == all five reward slots at their maximum of 10, checked before the History penalty.
    add = reward == 50
    reward = hist.process(reward, act)
    # Policy-gradient loss: -reward * log-prob of each sampled action part.
    loss = sum([-(reward)*m.log_prob(a) for m,a in zip(mod,act)])
    loss.backward()
    if add and not gen.lookup(actOneHot):
        gen.add(actOneHot)
    optim.step()
    running.append(reward)
    if len(running) == window:
        running.pop(0)
    if i % pPeriod == 0:
        print(f'epoch {i} running {sum(running)/len(running)}')
        print(action.verb)

In [94]:
# Inspect recent rewards and the generator's stored actions.
# NOTE(review): `running[-10:]` is not the cell's last expression, so its value is not displayed.
running[-10:]
# print(torch.stack(gen.notLegal[-10:]))
print(gen.num)
print(gen.legal[:10])

1
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.

In [43]:
# Print the 20 entries of hist.hist with the largest counts, highest first.
# NOTE(review): History.hist is built as a list of per-part dicts, which has
# no .get, so this raises AttributeError with that class. The output below
# (string keys) looks like it came from an earlier version where hist was a
# single dict keyed by concatenated action indices.
keys = sorted(hist.hist, key=hist.hist.get)
keys.reverse()
for k in keys[:20]:
    print(k,hist.hist[k])

013836 26.17921006037081
0132136 17.842742822579773
013336 15.198590524843553
0131836 14.510441507120012
0131236 13.964993451288487
0132336 1.8996215775555378
0133336 1.583071768619456
013236 1.2731325898092947
0132836 1.0348757202153298
0131136 1.0054513079341796
0132436 0.7994947495712481
0131230 0.6744559581129893
013815 0.6117180999749755
0132113 0.5739762378493748
5132136 0.5582661385478638
013828 0.5439750549573139
013813 0.31342453672908893
013830 0.23930717852278607
013310 0.2023022985967337
0131228 0.1697567740825631
