# Temporal Difference

In [8]:
import itertools
import random

class TicTacToeEnv:
    """Rules and reward function for tic-tac-toe on a flat 9-cell board.

    A board is a sequence of nine cells read row by row, each holding
    "X", "O" or " " (empty).
    """

    def __init__(self):
        # Every assignment of the three symbols to the nine cells, in
        # itertools.product order (unreachable boards included).
        symbols = ["X", "O", " "]
        self.states = list(itertools.product(symbols, repeat=9))

    @staticmethod
    def actions(state):
        """Return the indices of the empty cells of `state`, in ascending order."""
        return [index for index, mark in enumerate(state) if mark == " "]

    @staticmethod
    def transition_model(state, action, player):
        """Return a new board tuple with `player` written into cell `action`.

        The input board is not modified, and the target cell is not checked
        for being empty.
        """
        board = [*state]
        board[action] = player
        return tuple(board)

    @staticmethod
    def reward(state, player):
        """Score `state` from the point of view of `player`.

        Returns 1 if `player` owns a complete line, -1 if the other mark
        does, and 0 for a full board with no line. On an unfinished board
        the result is a shaping value: 0.5 when `player` holds the centre,
        -0.1 when `player` holds any corner or edge cell, and 0 otherwise.
        """
        lines = (
            (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
            (0, 4, 8), (2, 4, 6),             # diagonals
        )
        opponent = "O" if player == "X" else "X"

        def owns_line(mark):
            return any(state[a] == state[b] == state[c] == mark for a, b, c in lines)

        # A win for `player` takes precedence over a win for the opponent.
        if owns_line(player):
            return 1
        if owns_line(opponent):
            return -1
        if " " not in state:
            return 0

        # Shaping rewards for unfinished games.
        if state[4] == player:
            return 0.5
        # Corners are examined first, then edges; both carry the same penalty.
        outer_cells = (0, 2, 6, 8, 1, 3, 5, 7)
        if any(state[cell] == player for cell in outer_cells):
            return -0.1
        return 0

    @staticmethod
    def is_terminal(state):
        """Return True when either mark has a complete line or the board is full."""
        for mark in ("X", "O"):
            if TicTacToeEnv.reward(state, mark) == 1:
                return True
        return " " not in state

    @staticmethod
    def get_available_actions(state):
        """Return the indices of the empty cells; same result as `actions`."""
        return TicTacToeEnv.actions(state)

class TicTacToeTD:
    """Tic-tac-toe agent for "X" combining fixed heuristics with SARSA-learned values.

    Move selection always takes an immediate win, then an immediate block.
    Among the remaining moves it skips positions that previously led to a
    loss, prefers corners once "X" holds the centre, and otherwise picks
    epsilon-greedily from the learned Q-values.

    Args:
        game: Environment object providing `transition_model(state, action, player)`
            and `reward(state, player)`.
        alpha: Learning rate of the TD update.
        gamma: Discount factor of the TD update.
        epsilon: Probability of picking a random candidate move when exploring.
    """

    def __init__(self, game, alpha=0.01, gamma=0.99, epsilon=0.1):
        self.q_values = {}  # Maps (state, action) -> estimated return for "X"
        self.game = game
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        # Positions reached right after an "X" move from which the AI went on
        # to lose, in discovery order and without duplicates.
        self.replay_buffer = []
        # Same positions as replay_buffer, kept as a set so that
        # choose_action can test membership in O(1).
        self._losing_positions = set()

    @staticmethod
    def actions(state):
        """Return the indices of the empty cells of `state`, in ascending order."""
        return [i for i, s in enumerate(state) if s == " "]

    def _remember_loss(self, position):
        """Record `position` (the board right after an "X" move) as leading to a loss."""
        if position not in self._losing_positions:
            self._losing_positions.add(position)
            self.replay_buffer.append(position)

    def choose_action(self, state, explore=True):
        """Pick a move for "X" in `state`.

        Args:
            state: Board with "X" to move.
            explore: When True, the final value-based choice is epsilon-greedy;
                when False it is purely greedy.

        Returns:
            The index of an empty cell, or None when the board has no empty
            cell (there is no legal move to return).
        """
        legal = self.actions(state)
        if not legal:
            return None

        # First, take a winning move if one exists.
        for a in legal:
            if self.game.reward(self.game.transition_model(state, a, "X"), "X") == 1:
                return a

        # Then, block a move that would let "O" win immediately.
        for a in legal:
            if self.game.reward(self.game.transition_model(state, a, "O"), "O") == 1:
                return a

        # Skip moves whose resulting position is known to have led to a loss,
        # unless every move is like that.
        safe = [
            a for a in legal
            if self.game.transition_model(state, a, "X") not in self._losing_positions
        ]
        candidates = safe or legal

        # Prioritize corners once "X" holds the center.
        if state[4] == "X":
            corners = [a for a in candidates if a in (0, 2, 6, 8)]
            if corners:
                return random.choice(corners)

        # Epsilon-greedy over the learned values. Unseen pairs count as 0 and
        # ties are broken at random, so an untrained agent picks uniformly.
        if explore and random.random() < self.epsilon:
            return random.choice(candidates)
        best = max(self.q_values.get((state, a), 0) for a in candidates)
        return random.choice(
            [a for a in candidates if self.q_values.get((state, a), 0) == best]
        )

    def train(self, num_episodes):
        """Run `num_episodes` games of "X" against a uniformly random "O".

        After each "X" move and the opponent's reply, the value of the move is
        updated with the SARSA rule
        Q(s, a) += alpha * (r + gamma * Q(s', a') - Q(s, a)),
        where a' is the action that is actually played next and Q(s', a') is 0
        when s' ends the game. Positions from which "O" went on to win are
        recorded so that choose_action can avoid them.
        """
        for _ in range(num_episodes):
            state = (" ",) * 9
            action = self.choose_action(state)
            while not self.is_terminal(state):
                # Board right after the AI's move.
                position = self.game.transition_model(state, action, "X")
                next_state = position
                if not self.is_terminal(position):
                    # The opponent answers with a random legal move.
                    reply = random.choice(self.actions(position))
                    next_state = self.game.transition_model(position, reply, "O")

                reward = self.game.reward(next_state, "X")
                if reward == -1:
                    self._remember_loss(position)

                if self.is_terminal(next_state):
                    # Nothing follows a finished game, so do not bootstrap.
                    next_action = None
                    next_q = 0
                else:
                    next_action = self.choose_action(next_state)
                    next_q = self.q_values.get((next_state, next_action), 0)

                old_q = self.q_values.get((state, action), 0)
                self.q_values[(state, action)] = old_q + self.alpha * (
                    reward + self.gamma * next_q - old_q
                )

                # Carry the chosen action forward so that the action used in
                # the update is the one that gets played.
                state, action = next_state, next_action

    def play_game(self):
        """Play one interactive game on the console: the AI is "X", the human is "O"."""
        state = (" ",) * 9  # Initial state
        current_player = "X"
        while not self.is_terminal(state):
            self.print_board(state)
            if current_player == "X":  # "X" is the AI player
                action = self.choose_action(state, explore=False)
            else:
                action = self.human_move(state)

            state = self.game.transition_model(state, action, current_player)
            current_player = "X" if current_player == "O" else "O"

        # Final outcome
        self.print_board(state)
        if self.game.reward(state, "X") == 1:
            print("AI wins!")
        elif self.game.reward(state, "O") == 1:
            print("You win!")
        else:
            print("It's a draw.")

    @staticmethod
    def is_terminal(state):
        """Return True when either player has a complete line or the board is full."""
        return TicTacToeEnv.is_terminal(state)

    def print_board(self, state):
        """Print the board as three row tuples followed by a blank separator."""
        print(state[0:3])
        print(state[3:6])
        print(state[6:9])
        print("\n")

    def human_move(self, state):
        """Prompt on stdin until the user enters the index of an empty cell; return it."""
        while True:
            try:
                action = int(input("Enter your move (0-8): "))
            except ValueError:
                print("Invalid input. Please enter a number (0-8).")
                continue
            if action in self.actions(state):
                return action
            print("Invalid move. Try again.")

if __name__ == "__main__":
    # Script entry point: build the environment, train the agent on it, then
    # start an interactive console game against the trained agent.
    game_env = TicTacToeEnv()
    game = TicTacToeTD(game=game_env, epsilon=0.1, gamma=0.9, alpha=0.1)
    game.train(100000)
    game.play_game()


(' ', ' ', ' ')
(' ', ' ', ' ')
(' ', ' ', ' ')


(' ', ' ', ' ')
('X', ' ', ' ')
(' ', ' ', ' ')


(' ', ' ', ' ')
('X', 'O', ' ')
(' ', ' ', ' ')


(' ', ' ', ' ')
('X', 'O', ' ')
(' ', 'X', ' ')


('O', ' ', ' ')
('X', 'O', ' ')
(' ', 'X', ' ')


('O', ' ', ' ')
('X', 'O', ' ')
(' ', 'X', 'X')


('O', ' ', ' ')
('X', 'O', ' ')
('O', 'X', 'X')


('O', ' ', 'X')
('X', 'O', ' ')
('O', 'X', 'X')


('O', ' ', 'X')
('X', 'O', 'O')
('O', 'X', 'X')


('O', 'X', 'X')
('X', 'O', 'O')
('O', 'X', 'X')


It's a draw.
