In [3]:
import random

class Game:
    def __init__(self):
        self.board = []
        self.discard = []
        
    def randomizeBoard(self):
        self.board = []
        while random.random() > 0.2:
            a = random.randint(0,35)
            b = random.randint(0,35) if random.random() > 0.5 else None
            self.board.append([a,b])
            
    def randomizeDiscard(self):
        self.discard = []
        while random.random() > 0.2:
            self.discard.append(random.randint(0,35))
        
game = Game()
game.randomizeBoard()
game.randomizeDiscard()

print(game.board)
print(game.discard)

[[26, 0], [27, None], [31, 12]]
[26, 12, 26, 20]


In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Categorical

def inspect(pile):
    return pile

def countBoard():
    return len([card for pair in game.board for card in pair])

def getUncoveredOnBoard():
    return [pair[0] for pair in game.board if pair[1] == None]

def canCoverCard(player, bottom, top):
    if player != game.defender:
        return False
    if player.hand.index(top) == -1:
        return False
    if getSuit(top) == getSuit(bottom):
        return getRank(top) > getRank(top)
    if getSuit(top) == getSuit(game.trump):
        return True
    return False

class Query:
    def __init__(self, verb, args=None):
        self.verb = verb
        self.args = args
        
    def oneHot(self):
        return F.one_hot(torch.Tensor([queries.index(self)]).long().cuda(),num_classes=10).float()
    
    def oneHotArg(self):
        return F.one_hot(torch.Tensor([queries.index(self)]).long().cuda(),num_classes=10).float()
            
    def __hash__(self):
        return hash(self.verb + str(self.args))
    
    def __eq__(self, other):
        try:
            return self.verb == other.verb and self.args == other.args
        except:
            return False
        
    def __str__(self):
        return f'{self.verb} {self.args}'

def getArgs(queries):
    args = []
    for q in queries:
        args += q.args if type(q.args) == list else [q.args]
    return list(set(args))
    
q0 = Query('inspect', 'board')
q1 = Query('getUncoveredOnBoard')
q2 = Query('countBoard')
q3 = Query('inspect', 'discard')
q4 = Query('')

queries = [q0, q1, q2, q3]
args = getArgs(queries)

class DurakSub(nn.Module):
    def __init__(self):
        super(DurakSub, self).__init__()
        self.subs = [inspect, getUncoveredOnBoard, countBoard]
        self.objs = [game, game.board, game.discard, None]
        
        # Parse query to determine plan
        self.parseq0 = nn.Linear(10,3).float().cuda()
        self.parsea0 = nn.Linear(10,4).float().cuda()
    
    def forward(self, query):
        x = query.oneHot()
        y = query.oneHotArg()
        a = self.parseq0(x)
        b = self.parsea0(y)
        
        subCat = Categorical(logits=a)
        objCat = Categorical(logits=b)
        
        subIdx = subCat.sample()
        objIdx = objCat.sample()
        
        def lossFn(reward):
            return -reward*(subCat.log_prob(subIdx)+objCat.log_prob(objIdx))
        
        try:
            if self.objs[objIdx] == None:
                return self.subs[subIdx](), lossFn
            else:
                return self.subs[subIdx](self.objs[objIdx]), lossFn
        except:
            return None, lossFn
    
durakSub = DurakSub()
optim = torch.optim.Adam(durakSub.parameters(), lr=1e-4, weight_decay=0)
        
print('Complete')

Complete


In [42]:
rewardsTab = {
    q0: lambda : inspect(game.board),
    q1: lambda : getUncoveredOnBoard(),
    q2: lambda : countBoard(),
    q3: lambda : inspect(game.discard)
}

query = q0

print('Complete')

Complete


In [45]:
nEpochs = 100_000
running = []
pPeriod = 1000
window = 100
nReset = 0

for epoch in range(nEpochs):
    optim.zero_grad()
    res, lossFn = durakSub(query)
    if res == rewardsTab[query]():
        reward = 1
    else:
        reward = -1
    loss = lossFn(reward)
    loss.backward()
    optim.step()
    running.append(reward)
    if len(running) == window:
        running.pop(0)
    if epoch % pPeriod == 0 or epoch == nEpochs-1:
        print(f'epoch {epoch} loss {sum(running)/len(running)}')
    if sum(running)/len(running) > 0.9:
        query = random.choice(queries)
        print(f'Switched to {query.verb} {query.args}')

epoch 0 loss -1.0
epoch 1000 loss 0.696969696969697
epoch 2000 loss 0.6161616161616161
epoch 3000 loss 0.7373737373737373
epoch 4000 loss 0.7575757575757576
Switched to getUncoveredOnBoard None
Switched to inspect board
Switched to inspect discard
Switched to countBoard None
Switched to inspect discard
Switched to inspect discard
Switched to inspect discard
Switched to getUncoveredOnBoard None
Switched to inspect board
epoch 5000 loss 0.41414141414141414
epoch 6000 loss 0.47474747474747475
epoch 7000 loss 0.47474747474747475
epoch 8000 loss 0.6161616161616161
epoch 9000 loss 0.6767676767676768
epoch 10000 loss 0.8383838383838383
Switched to inspect board
epoch 11000 loss 0.7777777777777778
Switched to getUncoveredOnBoard None
Switched to inspect board
Switched to countBoard None
Switched to countBoard None
Switched to countBoard None
Switched to inspect board
Switched to countBoard None
Switched to inspect board
Switched to countBoard None
Switched to countBoard None
Switched to inspec

Switched to inspect discard
Switched to getUncoveredOnBoard None
Switched to countBoard None
Switched to inspect board
Switched to getUncoveredOnBoard None
Switched to getUncoveredOnBoard None
Switched to inspect discard
Switched to getUncoveredOnBoard None
Switched to countBoard None
Switched to getUncoveredOnBoard None
Switched to inspect discard
Switched to countBoard None
Switched to getUncoveredOnBoard None
Switched to getUncoveredOnBoard None
Switched to countBoard None
Switched to getUncoveredOnBoard None
Switched to inspect board
Switched to countBoard None
Switched to inspect board
Switched to countBoard None
Switched to inspect board
Switched to countBoard None
Switched to inspect discard
Switched to getUncoveredOnBoard None
Switched to inspect discard
Switched to getUncoveredOnBoard None
Switched to countBoard None
Switched to inspect discard
Switched to getUncoveredOnBoard None
Switched to inspect discard
Switched to inspect discard
Switched to getUncoveredOnBoard None
Swit

KeyboardInterrupt: 

In [33]:
print(query.args)

None


In [63]:
x = query.oneHot()
y = durakSub.parseq0(x)
c = Categorical(logits=y[:,:2])
d = Categorical(logits=y[:,2:])
e = [(int(c.sample()), int(d.sample())) for i in range(1000)]
for a,b in e:
    if a == 0 and b == 0:
        print('ok') 

ok
ok
ok
ok
ok


In [80]:
a = durakSub.subs[0](durakSub.objs[0])
a == countCards(game.board)

False

In [None]:
from inspect import signature
from itertools import product

class DurakAI:
    def __init__(self):
        self.subs = [inspect, getUncoveredOnBoard, countBoard]
        self.seqs = []
        
    def perform(self, query, target=None):
        # Check if we already have a sequence to do what we need
        for seq in self.seqs:
            if seq.query == query:
                return seq.perform()
            if seq.query.verb == seq.query.verb:
                pass
        # Build up and test sequences
        methods = iter(self.subs + self.seqs)
        for i in range(1,5):
            tstSeq = Sequence(query)
            product(self.subs + self.seqs)
            for j in range(i):
                try:
                    sub = next(methods)
                if sub in self.seqs:
                    sub = self.seqs.perform
                    nParam = 0
                else:
                    sig = signature(sub)
                    nParam = len(sig.parameters)
                params = []
                for n in range(nParam):
                    for o in range(len(self.objs)):
                        params.append(self.objs[o])

                    
                        
    
class Sequence:
    def __init__(self, query):
        self.query = query
        self.seqs = []
        self.state = []
        self.bad = None
        self.perform = None
    
    def add(self, seq):
        self.seqs.append(seq)
        self.state.append(None)
    
    def getResult(self, seqIdx, *stateIdcs):
        try:
            self.state[seqIdx] = self.seqs[seqIdx](*[self.state[idx] for idx in stateIdcs])
            return self.state[seqIdx]
        except Exception as ex:
            self.bad = ex
            raise
            
    def finalize(self, method):
        self.perform = method
            