In [1]:
import torch
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import time
import random
from typing import Any
import copy
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from collections import deque 
from torch.nn.utils.clip_grad import clip_grad_norm_,clip_grad_value_

In [2]:
# Random-number configuration. Set SEED to an int for a reproducible run; with
# None a seed is drawn at random and printed so the run can be replayed later.
SEED = None
seed = SEED if SEED is not None else random.randint(0, 1000)
print(f"seed: {seed}")
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Parameters of the normal distribution used for small random initial values.
mean = 0.0
stddev = 0.05


def random_init_fn():
    """Draw one sample from N(mean, stddev**2) and return it as a Python float."""
    return nn.init.normal_(torch.empty(1), mean=mean, std=stddev).item()


random_init = torch.nn.init.normal_(torch.empty(1), mean=mean, std=stddev).item()
class TicTacToe(gym.Env):
    """Two-player 3x3 tic-tac-toe environment used for self-play.

    ``X`` always moves first (see ``turn``). Observations come from
    ``encode_board``: a (2, 3, 3) float array seen from the side to move.

    ``step`` returns ``(observation, reward, done)`` (not the Gymnasium 5-tuple).
    It reports -1 when the move just played wins, 0.25 for a draw, 0 otherwise,
    and 1.1 (ending the episode) when the requested square is not legal.
    """

    def __init__(self, grid_size=(3, 3)):
        super().__init__()
        self.grid_size = grid_size
        self.action_space = spaces.Discrete(9)
        # encode_board() emits float64 values in {0, 1}.
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(2, grid_size[0], grid_size[1]), dtype=np.float64
        )
        # Initialise per-game state exactly as a reset would.
        self.reset()

    def reset(self):
        """Start a new game and return the initial observation."""
        self.steps = 0
        self.agent_sym = self.get_agent()
        self.opponent_sym = "X" if self.agent_sym == "O" else "O"
        self.board = self.build_board()
        self.placed = []
        return self._get_observation()

    def get_agent(self):
        """Randomly choose the symbol stored as ``agent_sym``."""
        return random.choice(['X', 'O'])

    def build_board(self):
        """Return an empty 3x3 board; empty cells hold the empty string."""
        return np.zeros([3, 3], dtype='str')

    def encode_board(self):
        """Encode the board from the perspective of the player about to move.

        Channel 0 marks the pieces of ``self.turn()`` and channel 1 marks the
        other player's pieces. Previously the planes were keyed on the randomly
        chosen ``agent_sym``, which flipped the encoding whenever
        ``agent_sym == "O"``.
        """
        to_move = self.turn()
        other = "O" if to_move == "X" else "X"
        planes = np.zeros((2, 3, 3))
        planes[0] = self.board == to_move
        planes[1] = self.board == other
        return planes

    def GetPosMoves(self, board):
        """Return the row-major indices (0-8) of empty squares ("" or "0") on ``board``."""
        return [i * 3 + j for i in range(3) for j in range(3) if board[i, j] in ("", "0")]

    def is_game_over(self, moves):
        """Classify the current position.

        Args:
            moves: legal moves still available (e.g. from ``GetPosMoves``).

        Returns:
            ``(1, 1)`` if ``agent_sym`` has won, ``(2, -1)`` if ``opponent_sym``
            has won, ``(-1, 0.5)`` for a draw (more than 8 steps and no moves
            left), otherwise ``(0, 0)``.
        """
        # Bug fix: get_winner requires the board argument.
        winner = self.get_winner(self.board)
        if winner == self.agent_sym:
            return 1, 1
        if winner == self.opponent_sym:
            return 2, -1
        if self.steps > 8 and len(moves) == 0:
            return -1, 0.5
        return 0, 0

    def get_winner(self, board):
        """Return the symbol completing a row, column or diagonal, else None."""
        for t in range(3):
            if board[t, 0] != "" and board[t, 0] == board[t, 1] == board[t, 2]:
                return board[t, 0]
            if board[0, t] != "" and board[0, t] == board[1, t] == board[2, t]:
                return board[0, t]
        if board[1, 1] != "" and board[0, 0] == board[1, 1] == board[2, 2]:
            return board[0, 0]
        if board[1, 1] != "" and board[0, 2] == board[1, 1] == board[2, 0]:
            return board[1, 1]
        return None

    def turn(self):
        """Return the symbol to move: "X" on even step counts, "O" on odd."""
        return "X" if self.steps % 2 == 0 else "O"

    def step(self, action):
        """Play ``action`` (0-8, row-major) for the side to move.

        Returns ``(observation, reward, done)``; rewards are listed in the class
        docstring.
        """
        if not self.is_valid(action):
            # Illegal move: the turn counter still advances and the game ends.
            self.steps += 1
            return self.encode_board(), 1.1, True
        i, j = self.decode_move(action)
        self.board[i, j] = self.turn()
        self.steps += 1
        winner = self.get_winner(self.board)
        done = winner is not None or self.steps > 8
        reward = 0
        if done:
            reward = 0.25 if winner is None else -1
        return self.encode_board(), reward, done

    def render(self):
        """Print the board to stdout, showing empty squares as ``0``."""
        for row in self.board:
            print(" | ".join(cell if cell != "" else "0" for cell in row))
            print("-" * 9)

    def _get_observation(self):
        """Return the current observation (see ``encode_board``)."""
        return self.encode_board()

    def encode_move(self, move):
        """Convert an ``(i, j)`` square into a row-major action index."""
        i, j = move
        return i * 3 + j

    def decode_move(self, move):
        """Convert a row-major action index into an ``(i, j)`` square."""
        return (move // 3, move % 3)

    def is_valid(self, move):
        """Return True if ``move`` is an on-board index (0-8) whose square is empty."""
        # Out-of-range indices previously raised IndexError (or wrapped around
        # for negatives); treat them as illegal moves instead.
        if not 0 <= move < 9:
            return False
        i, j = self.decode_move(move)
        return bool(self.board[i, j] == "")

    def close(self):
        """Destroy the optional Tk window if one was created; otherwise do nothing."""
        # The Tk GUI is disabled, so ``self.root`` normally does not exist.
        root = getattr(self, "root", None)
        if root is not None:
            root.destroy()

In [9]:
class Actor(nn.Module):
    """Policy network mapping a board encoding to action probabilities.

    The network ends in a softmax, so ``forward`` returns probabilities, not
    logits. ``train`` performs PPO-style clipped-surrogate updates with an
    entropy bonus.

    NOTE(review): ``train`` shadows ``nn.Module.train(mode)``. ``nn.Module.eval()``
    calls ``self.train(False)``, so ``actor.eval()`` is expected to fail. The name
    is kept because ``A2C.train`` calls it.
    """

    # Update hyper-parameters (previously inline literals in ``train``).
    clip_eps = 0.2
    entropy_coef = 0.4
    n_epochs = 2
    n_batches = 2
    max_grad_norm = 0.5
    max_grad_value = 10000

    def __init__(self, state_dim, action_dim, learning_rate, horizon, actor):
        """
        Args:
            state_dim: observation shape indexed as (channels, height, width).
            action_dim: number of discrete actions (size of the output layer).
            learning_rate: Adam learning rate.
            horizon: rollout length collected per update; stored for reference only.
            actor: unused; accepted for interface compatibility.
        """
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.horizon = horizon
        self.model = self.build_model()
        self.model.apply(self.init_weights)  # N(0, 0.05) weights, zero biases
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

    def forward(self, x):
        """Return action probabilities for a batch of encoded boards."""
        return self.model(x)

    def build_model(self):
        """Build the conv -> MLP -> softmax policy network from ``self.state_dim``."""
        # Bug fix: use the instance's dimensions, not notebook globals.
        channels, height, width = self.state_dim[0], self.state_dim[1], self.state_dim[2]
        return nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * height * width, 128),
            nn.ReLU(),
            nn.Linear(128, self.action_dim),
            nn.Softmax(dim=-1),
        )

    def init_weights(self, layer):
        """Initialise conv/linear layers with N(0, 0.05) weights and zero bias."""
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            init.normal_(layer.weight, mean=0.0, std=0.05)
            init.constant_(layer.bias, 0.0)

    def get_actions_prob(self, states):
        """Return action probabilities (numpy array, one row per state) without tracking gradients."""
        states_tensor = torch.tensor(np.array(states), dtype=torch.float32)
        with torch.no_grad():
            action_probs_tensor = self.model(states_tensor)
        return action_probs_tensor.numpy()

    def train(self, states, actions, advantages, masks):
        """Run PPO clipped-surrogate updates on one rollout.

        Args:
            states: N encoded observations.
            actions: N action indices.
            advantages: N advantage estimates (any shape that flattens to N).
            masks: legal-move masks; accepted for interface compatibility but
                not used by the loss.
        """
        states_tensor = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        actions_tensor = torch.as_tensor(np.asarray(actions), dtype=torch.long)
        # reshape(-1) instead of squeeze() so a single-sample batch stays 1-D.
        advantages_tensor = torch.as_tensor(np.asarray(advantages), dtype=torch.float32).reshape(-1)
        n_samples = states_tensor.shape[0]
        if n_samples == 0:
            return

        # Log-probs of the behaviour policy, frozen for the whole update.
        with torch.no_grad():
            old_log_probs = torch.distributions.Categorical(
                probs=self.model(states_tensor)
            ).log_prob(actions_tensor)

        # Batch over the data actually supplied rather than self.horizon.
        batch_size = max(1, n_samples // self.n_batches)
        for _ in range(self.n_epochs):
            order = np.random.permutation(n_samples)
            for start in range(0, n_samples, batch_size):
                batch = torch.as_tensor(order[start:start + batch_size])
                dist = torch.distributions.Categorical(probs=self.model(states_tensor[batch]))
                entropy = dist.entropy().mean()
                ratio = torch.exp(dist.log_prob(actions_tensor[batch]) - old_log_probs[batch])
                adv = advantages_tensor[batch]
                clipped_ratio = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps)
                surrogate = torch.min(ratio * adv, clipped_ratio * adv)
                loss = -surrogate.mean() - self.entropy_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                clip_grad_value_(self.model.parameters(), self.max_grad_value)
                clip_grad_norm_(self.model.parameters(), max_norm=self.max_grad_norm)
                self.optimizer.step()


class Critic(nn.Module):
    """State-value network: maps a board encoding to a scalar value estimate.

    ``train`` fits the value head to the supplied returns with a PPO-style
    clipped value loss over shuffled minibatches, then takes one plain MSE step
    over the whole batch.

    NOTE(review): ``train`` shadows ``nn.Module.train(mode)``, so ``critic.eval()``
    (which calls ``self.train(False)``) is expected to fail. The name is kept
    because ``A2C.train`` calls it.
    """

    # Update hyper-parameters (previously inline literals in ``train``).
    n_epochs = 2
    n_batches = 2
    value_clip = 0.2
    max_grad_norm = 0.5
    max_grad_value = 10000

    def __init__(self, state_dim, learning_rate, horizon, critic):
        """
        Args:
            state_dim: observation shape indexed as (channels, height, width).
            learning_rate: Adam learning rate.
            horizon: rollout length collected per update; stored for reference only.
            critic: unused; accepted for interface compatibility.
        """
        super().__init__()
        self.state_dim = state_dim
        self.horizon = horizon
        self.model = self.build_model()
        self.model.apply(self.init_weights)  # N(0, 0.05) weights, zero biases
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

    def build_model(self):
        """Build the conv -> MLP value network from ``self.state_dim``."""
        # Bug fix: use the instance's dimensions, not the notebook global.
        channels, height, width = self.state_dim[0], self.state_dim[1], self.state_dim[2]
        return nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * height * width, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Return value estimates of shape (N, 1) for a batch of encoded boards."""
        return self.model(x)

    def init_weights(self, layer):
        """Initialise conv/linear layers with N(0, 0.05) weights and zero bias."""
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            init.normal_(layer.weight, mean=0.0, std=0.05)
            init.constant_(layer.bias, 0.0)

    def get_values(self, states):
        """Return value estimates as a numpy array with all size-1 dims squeezed."""
        states_tensor = torch.tensor(np.array(states), dtype=torch.float32)
        with torch.no_grad():
            values_tensor = self.model(states_tensor)
        return torch.squeeze(values_tensor).numpy()

    def train(self, states, discounted_rewards):
        """Regress the value head towards ``discounted_rewards``.

        Predictions are flattened to shape (N,) before being compared with the
        targets. Previously an (N, 1) prediction minus (N,) targets broadcast to
        an (N, N) loss matrix.
        """
        states_tensor = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        targets = torch.as_tensor(np.asarray(discounted_rewards), dtype=torch.float32).reshape(-1)
        n_samples = states_tensor.shape[0]
        if n_samples == 0:
            return

        with torch.no_grad():
            old_values = self.model(states_tensor).reshape(-1)

        batch_size = max(1, n_samples // self.n_batches)
        for _ in range(self.n_epochs):
            order = np.random.permutation(n_samples)
            # Bug fix: step by batch_size (previously arange(0, horizon // n_batches)
            # produced many overlapping batches).
            for start in range(0, n_samples, batch_size):
                batch = torch.as_tensor(order[start:start + batch_size])
                new_values = self.model(states_tensor[batch]).reshape(-1)
                old_batch = old_values[batch]
                target_batch = targets[batch]
                clipped_values = old_batch + torch.clamp(
                    new_values - old_batch, -self.value_clip, self.value_clip
                )
                unclipped_loss = (target_batch - new_values) ** 2
                # Bug fix: the clipped term must be squared as well.
                clipped_loss = (target_batch - clipped_values) ** 2
                self._apply_gradients(0.5 * torch.max(unclipped_loss, clipped_loss).mean())

        # Final plain-MSE step over the full batch (kept from the original design).
        values = self.model(states_tensor).reshape(-1)
        self._apply_gradients(torch.mean(0.5 * (targets - values) ** 2))

    def _apply_gradients(self, loss):
        """Backpropagate ``loss``, clip gradients and take one optimizer step."""
        self.optimizer.zero_grad()
        loss.backward()
        clip_grad_value_(self.model.parameters(), self.max_grad_value)
        clip_grad_norm_(self.model.parameters(), max_norm=self.max_grad_norm)
        self.optimizer.step()

# A2C agent: pairs the Actor and Critic above (both are updated with PPO-style clipped objectives).
class A2C:
    """Actor-critic agent pairing an ``Actor`` (policy) with a ``Critic`` (value)."""

    def __init__(self, state_dim, action_dim, learning_rate_actor, learning_rate_critic, horizon, actor, critic):
        """
        Args:
            state_dim: observation shape, forwarded to both networks.
            action_dim: number of discrete actions.
            learning_rate_actor: Adam learning rate for the policy.
            learning_rate_critic: Adam learning rate for the value network.
            horizon: rollout length per update, forwarded to both networks.
            actor, critic: forwarded unchanged to the Actor/Critic constructors.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.horizon = horizon
        self.actor = Actor(state_dim, action_dim, learning_rate_actor, horizon, actor)
        self.critic = Critic(state_dim, learning_rate_critic, horizon, critic)

    def select_action(self, state, legal_actions, temp=0):
        """Choose an action for a single observation.

        Args:
            state: a batch containing one observation (callers pass ``[obs]``);
                only the first row of probabilities is used.
            legal_actions: non-empty list of legal action indices.
            temp: ``1`` returns the most probable *legal* action. Any other
                value samples from the full (unmasked) policy, which may return
                an illegal move.

        Returns:
            The chosen action index (a numpy integer).
        """
        assert len(legal_actions) > 0
        probs = np.asarray(self.actor.get_actions_prob(state)[0], dtype=np.float64)
        if temp == 1:
            legal_mask = np.zeros((self.action_dim,), dtype=np.float64)
            legal_mask[legal_actions] = 1
            legal_probs = probs * legal_mask
            if not np.isfinite(legal_probs).all() or legal_probs.sum() <= 0:
                # Every legal move has (numerically) zero probability. Dividing by
                # that sum produced NaNs, and argmax then returned index 0, which
                # may be illegal. Pick uniformly among the legal moves instead.
                return np.random.choice(legal_actions)
            # Normalising would not change the argmax, so it is skipped.
            return np.argmax(legal_probs)
        # Renormalise in float64 so rounding in the float32 softmax cannot trip
        # numpy's "probabilities do not sum to 1" check.
        return np.random.choice(self.action_dim, p=probs / probs.sum())

    def train(self, states_batch, action_batch, returns_batch, masks):
        """Update actor and critic from one rollout and return the advantages used.

        Advantages are ``returns_batch - V(states_batch)``, computed with the
        critic *before* it is updated on this batch.
        """
        values = self.critic.get_values(states_batch)
        advantages = returns_batch - values
        self.actor.train(states_batch, action_batch, advantages, masks)
        self.critic.train(states_batch, returns_batch)
        return advantages

In [10]:
def test_vs_random_player(a2c: A2C, n_games: int):
    """Play ``n_games`` games of the agent against a uniformly random opponent.

    Each game the agent is randomly assigned to move first or second. Its moves
    come from ``a2c.select_action`` with the default ``temp`` (sampling from the
    full policy), so it can choose an illegal square; ``TicTacToe.step`` then ends
    the game, and that game does not count as a win.

    Returns:
        Percentage (0-100) of games the agent won; 0.0 when ``n_games`` <= 0
        (previously a ZeroDivisionError).
    """
    if n_games <= 0:
        return 0.0
    env = TicTacToe()
    wins = 0
    for _ in range(n_games):
        agent_player = random.randint(0, 1)
        state = env.reset()
        done = False
        current_player = 0
        last_mover = None
        reward = 0
        while not done:
            moves = env.GetPosMoves(env.board)
            if current_player == agent_player:
                action = a2c.select_action([state], moves)
            else:
                action = np.random.choice(moves)
            state, reward, done = env.step(action)
            last_mover = current_player
            current_player = 1 - current_player

        # step() reports -1 when the move just played won the game.
        if last_mover == agent_player and reward <= -0.25:
            wins += 1

    return wins / n_games * 100


def calculate_returns(rewards, dones, last_value, gamma=1.0):
    """Compute negamax returns for an alternating-turn self-play rollout.

    ``rewards[t]`` is negated to obtain the reward of the player who moved at t,
    and ``G_t = -rewards[t] - gamma * G_{t+1}``, because consecutive steps belong
    to opposite players. ``last_value`` bootstraps the step after the final one.

    Args:
        rewards: per-step rewards as produced by ``TicTacToe.step``.
        dones: per-step episode-termination flags.
        last_value: value estimate for the state following the last step.
        gamma: discount factor (defaults to the previous hard-coded 1.0).

    Returns:
        A float numpy array of returns, same length as ``rewards``.
    """
    # Force a float dtype: an all-integer reward list previously produced an int
    # array that silently truncated fractional returns.
    rewards = np.asarray(rewards, dtype=float)
    returns = np.zeros_like(rewards)

    flag = 0
    temp = -2
    next_value = last_value
    for t in reversed(range(len(rewards))):
        current_reward = -rewards[t]
        if dones[t]:
            # Episode boundary: never bootstrap from the following episode.
            # Previously a winning move (reward -1) skipped this reset.
            next_value = 0
            # NOTE(review): legacy branch, only taken for a reward of exactly 1,
            # which TicTacToe.step never returns. It marks the earlier steps of
            # that episode with a constant 0.005 return.
            if rewards[t] == 1:
                flag = 1
                temp = t
            else:
                flag = 0
        current_value = current_reward - gamma * next_value
        if flag and temp == t:
            returns[t] = current_value
            temp = -2
        elif flag:
            returns[t] = 0.005
        else:
            returns[t] = current_value
        next_value = current_value
    return returns

In [None]:
if __name__ == "__main__":
    # Self-play training: a single agent chooses the moves for both sides.
    env = TicTacToe()
    state_dim = env.observation_space.shape
    action_dim = 9
    learning_rate_actor = 0.0001
    learning_rate_critic = 0.0001
    horizon = 64            # transitions collected per update
    max_steps = 20000       # number of collect-then-update iterations
    a2c = A2C(state_dim, action_dim, learning_rate_actor, learning_rate_critic,
              horizon, actor=None, critic=None)

    illegal_moves = 0       # illegal moves sampled since the last progress report
    total_steps = 0
    state = env.reset()
    while total_steps < max_steps:
        # Fresh buffers for every rollout. Previously they were created once and
        # grew forever, so every update re-trained on the first `horizon` transitions.
        states, actions, rewards, dones, all_masks = [], [], [], [], []
        for _ in range(horizon):
            moves = env.GetPosMoves(env.board)
            masks = np.zeros((action_dim,), dtype=np.float32)
            masks[moves] = 1
            action = int(a2c.select_action([state], moves))
            if not env.is_valid(action):
                illegal_moves += 1
            next_state, reward, done = env.step(action)
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            all_masks.append(masks)
            state = env.reset() if done else next_state
        total_steps += 1

        # Bootstrap from the critic's estimate of the state after the rollout.
        last_value = a2c.critic.get_values([state])
        returns = calculate_returns(rewards, dones, last_value)
        a2c.train(
            np.array(states),
            np.array(actions),
            np.array(returns),
            np.array(all_masks),
        )

        if total_steps % 100 == 99:
            print("count:", illegal_moves)
            print("steps:", total_steps + 1, end=" ")
            win_ratio = test_vs_random_player(a2c, 50)
            print(f"Win ratio against random player {win_ratio:0.2f}")
            illegal_moves = 0

count: 1586
steps: 100 Win ratio against random player 8.00
count: 1793
steps: 200 Win ratio against random player 4.00


In [14]:
# Smoke-test the trained agent against a uniformly random opponent
# (only two games, so this is not a reliable win-rate estimate).
win_ratio = test_vs_random_player(a2c, n_games=2)

[['' '' '']
 ['' '' '']
 ['' '' '']]
[2.6911090e-08 6.6771291e-09 7.8482175e-05 9.9991953e-01 1.9276690e-06
 7.4280870e-10 3.2360006e-09 9.3368534e-17 3.1466850e-17]
[['' '' '']
 ['X' '' '']
 ['' '' '']]


 4


[['' '' '']
 ['X' 'O' '']
 ['' '' '']]
[4.4219507e-12 1.0566520e-10 9.9999571e-01 5.9772226e-22 1.4183807e-33
 3.9399789e-11 4.3136674e-06 4.3865581e-24 5.9690835e-16]
[['' '' 'X']
 ['X' 'O' '']
 ['' '' '']]


 8


[['' '' 'X']
 ['X' 'O' '']
 ['' '' 'O']]
[5.5009048e-02 9.4499081e-01 1.3876283e-19 1.9920178e-10 3.6901888e-25
 7.6213652e-08 2.5072731e-13 5.0766733e-15 1.7916055e-34]
[['' 'X' 'X']
 ['X' 'O' '']
 ['' '' 'O']]


 1


[['' '' '']
 ['' '' '']
 ['' '' '']]
[2.6911090e-08 6.6771291e-09 7.8482175e-05 9.9991953e-01 1.9276690e-06
 7.4280870e-10 3.2360006e-09 9.3368534e-17 3.1466850e-17]
[['' '' '']
 ['X' '' '']
 ['' '' '']]


 0


[['O' '' '']
 ['X' '' '']
 ['' '' '']]
[3.5119617e-21 2.1493612e-08 4.8478257e-09 2.6436589e-19 6.5619803e-05
 7.4102671e-04 9.9919337e-01 1.3085020e-18 4.8949346e-16]
[['O' '' '']
 ['X' '' '']
 ['X' '' '']]


 8


[['O' '' '']
 ['X' '' '']
 ['X' '' 'O']]
[1.3153297e-12 2.4482242e-07 6.1439894e-13 3.5134391e-15 5.6451043e-08
 9.9999976e-01 2.9993248e-25 8.0584688e-14 1.5398738e-32]
[['O' '' '']
 ['X' '' 'X']
 ['X' '' 'O']]


 4
