In [1]:
# Reference: まるばつゲームを強化学習で実装する(Python)｜にしまつ@マーケティング・サイエンティスト
# ("Implementing tic-tac-toe with reinforcement learning (Python)")
# https://note.com/daikinishimatsu/n/n749a6f743a9f

In [23]:
from enum import Enum
import copy
from collections import defaultdict, deque
import random
import time

import numpy as np

class DefaultDict:
    """Small dictionary wrapper that supplies a default for missing keys.

    Reading a missing key stores an independent deep copy of ``default_val``
    under that key and returns it, mirroring ``collections.defaultdict``.
    Copying matters when the default is mutable (for example a nested
    ``DefaultDict``): handing out the prototype itself would make every key
    share, and mutate, the same object while the outer mapping stayed empty.

    Args:
        default_val: Prototype value for missing keys. It is deep-copied on
            each miss and never handed out directly.
    """

    def __init__(self, default_val):
        self.default_val = default_val
        self.dict = dict()

    def __getitem__(self, arg):
        """Return the value for ``arg``, creating it from the default if absent."""
        if arg not in self.dict:
            # Store a private copy so later in-place updates stay per-key.
            self.dict[arg] = copy.deepcopy(self.default_val)
        return self.dict[arg]

    def __setitem__(self, k, v):
        self.dict[k] = v

    def keys(self):
        """Return the stored keys as a list snapshot (indexable by position)."""
        return list(self.dict.keys())

    def values(self):
        """Return the stored values as a list snapshot, aligned with ``keys()``."""
        return list(self.dict.values())

    def items(self):
        """Return the stored ``(key, value)`` pairs as a list snapshot."""
        return list(self.dict.items())

    def __iter__(self):
        # Without this, iteration would fall back to __getitem__(0), (1), ...
        # which never raises IndexError and therefore never terminates.
        return iter(self.dict)

    def __contains__(self, item):
        return item in self.dict

    def __len__(self):
        return len(self.dict)

    def __repr__(self):
        return f"{type(self).__name__}({self.default_val!r}, {self.dict!r})"

    

class IssEnv():
    """Thumbs-guessing game environment: one agent (player 0) vs. random opponents.

    On every step each player raises between 0 and their remaining number of
    thumbs, and the player in turn calls a number. A call equal to the total
    number of raised thumbs removes one of the caller's thumbs; a player whose
    thumbs reach 0 is out of the game.

    Attributes:
        n_players: Total number of players.
        remaining_players: Players that are not out yet.
        called_number: Initialised to 0 by ``reset``; not updated by ``step``.
        player_in_turn: Index of the player whose turn it is (0 is the agent).
        thumbs: Thumbs raised by each player in the latest step.
        max_thumbs: Thumbs each player still has (2, 1, or 0 when out).
        penalty: Reward returned for a call outside the valid range.
    """

    def __init__(self):
        self.reset()
        self.penalty = -0.2

    def reset(self):
        """Restore the start-of-game state (5 players, 2 thumbs each)."""
        self.n_players = 5
        self.remaining_players = self.n_players
        self.called_number = 0
        self.player_in_turn = 0
        self.thumbs = [0] * self.n_players
        # 2: both thumbs left, 1: one thumb left, 0: the player is out.
        self.max_thumbs = [2] * self.n_players

    def step(self, thumbs, called_number):
        """Play one turn.

        Args:
            thumbs: Thumbs player 0 wants to raise; clamped to the range
                ``0..max_thumbs[0]``.
            called_number: Number called by player 0. Ignored (replaced by a
                random call) when an opponent is in turn.

        Returns:
            ``(thumbs, reward, done)``: the raised thumbs of all players, the
            reward for the call, and whether the episode is finished. The
            episode is finished when player 0 goes out or when at most one
            player is still in the game.
        """
        own_thumbs = max(0, min(thumbs, self.max_thumbs[0]))
        # Opponents can only raise the thumbs they still have, so players who
        # are out always contribute 0.
        opponent_thumbs = [random.randint(0, limit) for limit in self.max_thumbs[1:]]
        self.thumbs = [own_thumbs] + opponent_thumbs

        if self.player_in_turn != 0:
            called_number = random.randint(0, sum(self.max_thumbs))

        reward, done = self.calc_reward(called_number)
        # With a single player left there is nobody to play against.
        done = done or self.is_game_over()
        return self.thumbs, reward, done

    def calc_reward(self, called_number):
        """Score a call made by the player in turn and update ``max_thumbs``.

        Returns:
            ``(reward, done)``. The reward is ``penalty`` for an out-of-range
            call, 0 for a wrong call, and ``remaining_players`` for a correct
            one. ``done`` is True only when player 0 loses their last thumb.

        Raises:
            ValueError: If the caller made a correct call but has no thumbs left.
        """
        # A call outside 0..(total remaining thumbs) is penalised.
        if called_number < 0 or called_number > sum(self.max_thumbs):
            return self.penalty, False

        # Wrong call: nothing changes and the game continues.
        if called_number != sum(self.thumbs):
            return 0, False

        caller = self.player_in_turn
        thumbs_left = self.max_thumbs[caller]
        if thumbs_left == 2:
            self.max_thumbs[caller] -= 1
            # Finishing with more players remaining scores higher.
            return self.remaining_players, False
        if thumbs_left == 1:
            self.max_thumbs[caller] -= 1
            self.remaining_players -= 1
            # When an opponent goes out the game goes on.
            return self.remaining_players, caller == 0
        raise ValueError(f'self.max_thumbs[{caller}] is not 1 or 2')

    def is_game_over(self):
        """Return True when at most one player still has thumbs."""
        return self.max_thumbs.count(0) >= self.n_players - 1

    def encode_state(self, state=None):
        """Encode raised thumbs as a single integer id.

        The state is summarised by ``(thumb0, thumb1)``: how many entries are
        0 and how many are 1. Pairs are numbered consecutively, first by
        ``thumb0`` and then by ``thumb1``, so for 5 players the ids are
        exactly 0..20 with no collisions.

        Args:
            state: List of per-player thumbs; defaults to ``self.thumbs``.

        Returns:
            The integer id of the state.
        """
        if state is None:
            state = self.thumbs
        n = len(state)
        thumb0 = state.count(0)
        thumb1 = state.count(1)
        # Rows with fewer zeros come first; row k holds (n + 1 - k) ids, so
        # the rows before thumb0 occupy thumb0*(n+1) - thumb0*(thumb0-1)/2 ids.
        offset = thumb0 * (n + 1) - thumb0 * (thumb0 - 1) // 2
        return offset + thumb1

    def __len__(self):
        return len(self.thumbs)
    
class IssAgent():
    """Epsilon-greedy Monte Carlo agent that plays as player 0 of an environment.

    Args:
        env: Environment exposing ``reset``, ``step``, ``encode_state``,
            ``n_players``, ``player_in_turn`` and ``max_thumbs``.
        epsilon: Probability of taking a random action on the agent's turn.
        min_alpha: Lower bound for the learning rate ``1 / N(s, a)``.
        gamma: Discount factor applied to future rewards.
    """

    def __init__(self, env, epsilon, min_alpha, gamma=0.9):
        self.epsilon = epsilon
        self.min_alpha = min_alpha
        self.gamma = gamma

        # state -> action -> visit count / value estimate.
        # The factories take no arguments, which is what defaultdict requires.
        self.N = defaultdict(lambda: defaultdict(int))
        self.Q = defaultdict(lambda: defaultdict(float))

        self.env = env

        self.round = 0
        self.prev_states = deque()
        self.game_round = 0

    def __policy_random(self):
        """Pick random thumbs and a random call within the current limits."""
        max_thumbs = self.env.max_thumbs[self.env.player_in_turn]
        thumbs = random.randint(0, max_thumbs)
        called_number = random.randint(0, sum(self.env.max_thumbs))
        return thumbs, called_number

    def policy(self):
        """Choose ``(thumbs, called_number)`` for the current turn.

        A random action is used when nothing has been learned yet, on the
        first turn of a game, with probability ``epsilon``, when an opponent
        is in turn, or when the current state has no recorded action.
        Otherwise the action with the highest Q value is returned.
        """
        explore = (
            not self.Q
            or self.game_round < 1
            or random.random() < self.epsilon
            or self.env.player_in_turn != 0
        )
        if explore:
            return self.__policy_random()

        # .get() avoids creating an empty entry for an unseen state.
        action_values = self.Q.get(self.env.encode_state())
        if not action_values:
            return self.__policy_random()
        return max(action_values, key=action_values.get)

    def _advance_turn(self):
        """Move the turn to the next player who still has thumbs.

        At most ``n_players`` positions are tried, so the search terminates
        even if every player is out.
        """
        env = self.env
        for _ in range(env.n_players):
            env.player_in_turn = (env.player_in_turn + 1) % env.n_players
            if env.max_thumbs[env.player_in_turn] != 0:
                return

    def play(self):
        """Play one game and update N and Q from the agent's own turns.

        Returns:
            The list of rewards the agent received on its turns, in order.

        Raises:
            TimeoutError: If the game runs longer than 60 seconds.
        """
        self.env.reset()

        # Random starting player; 0 means the agent starts.
        self.env.player_in_turn = random.randint(0, self.env.n_players - 1)

        self.game_round = 0
        done = False
        experience = []

        t0 = time.time()
        timeout = 60
        while not done:
            # Record the state the action is chosen in, so learning and
            # policy() index Q by the same state.
            state = self.env.encode_state()
            thumbs, called_number = self.policy()

            _, reward, done = self.env.step(thumbs, called_number)

            if self.env.player_in_turn == 0:
                experience.append({
                    "state": state,
                    "action": (thumbs, called_number),
                    "reward": reward,
                    })

            # Only hand over the turn while the game is still running; after
            # the final step there may be no active player left to find.
            if not done:
                self._advance_turn()

            self.game_round += 1

            # prevent infinite loop
            if time.time() - t0 > timeout:
                raise TimeoutError('while not done:')

        # Discounted return from each step to the end of the game, computed
        # in a single backward pass: G_i = reward_i + gamma * G_{i+1}.
        returns = [0] * len(experience)
        G = 0
        for idx in reversed(range(len(experience))):
            G = G * self.gamma + experience[idx]["reward"]
            returns[idx] = G

        # Every-visit update with step size max(1 / N, min_alpha).
        for record, G in zip(experience, returns):
            s, a = record["state"], record["action"]
            self.N[s][a] += 1
            alpha = max(1 / self.N[s][a], self.min_alpha)
            self.Q[s][a] += alpha * (G - self.Q[s][a])

        return [e['reward'] for e in experience]

In [None]:
# TODO: the code often freezes during play.
# Find the cause and fix it.


# Driver: play ten games with a fresh agent and track the running score,
# where a game's score is the last reward the agent received in it.
env = IssEnv()
agent = IssAgent(env, epsilon=0.1, min_alpha=0.01)

score, game = 0, 0

for i in range(10):
    rewards = agent.play()

    # Accumulate the final reward of this game and count the game.
    score = score + rewards[-1]
    game = game + 1

    # Report the encoded end state, this game's rewards, and the totals.
    print(env.encode_state())
    print(str(i) + "th play", rewards)
    print(f"score / game : {score} / {game}")