https://gist.github.com/programming-datascience/d8b96346e347b0b6942e16a33e64039c#file-actor-critic-cartpole-ipynb

In [1]:
# Importing libraries
# import gym
from pettingzoo.classic import connect_four_v3
import numpy as np
from itertools import count
from collections import namedtuple, deque
from time import sleep
# import supersuit as ss

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
# Importing PyTorch here

In [3]:
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

In [4]:
env = connect_four_v3.env()
# env = ss.resize_v0(env, x_size=84, y_size=84)
print(env.observation_spaces)
print(env.action_spaces)

{'player_0': Dict(action_mask:Box(0, 1, (7,), int8), observation:Box(0, 1, (6, 7, 2), int8)), 'player_1': Dict(action_mask:Box(0, 1, (7,), int8), observation:Box(0, 1, (6, 7, 2), int8))}
{'player_0': Discrete(7), 'player_1': Discrete(7)}


In [99]:
# You can move either left or right to balance the pole
# Lets implement the Actor critic network
class ActorCritic(nn.Module):
    def __init__(self, name, obs_shape, act_shape, buffer_size, lr=1e-2):
        super(ActorCritic, self).__init__()
        self.obs_shape = obs_shape
        self.name = name
        self.games_played = 0
        self.wins = 0
        self.history = deque(maxlen=1000000)
        self.conv1 = nn.Conv2d(2, 8, 3)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 12, 3)
        self.bn2 = nn.BatchNorm2d(12)
        self.maxpool1 = nn.MaxPool2d(2)
        
        self.fc1 = nn.Linear(self.calc_input_size(), 64)
        self.actor = nn.Linear(64, act_shape) 
        self.critic = nn.Linear(64, 1) # Critic is always 1
        self.saved_actions = deque(maxlen=buffer_size)
        self.rewards = deque(maxlen=buffer_size)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        
    def calc_input_size(self):
        m = self.conv1(torch.zeros((1,)+self.obs_shape))
#         print(m.shape)
        m = self.bn1(m)
#         print(m.shape)
        m = self.conv2(m)
#         print(m.shape)
        m = self.maxpool1(m)
#         print(m.shape)
        return int(np.prod(m.size()))
        
    def forward(self, x):
        x = F.relu(self.conv1(x.view((1,)+self.obs_shape)))
        x = self.bn1(x)
        x = F.relu(self.conv2(x))
        x = self.bn2(x)
        x = self.maxpool1(x)
        x = F.relu(self.fc1(x.reshape(-1, 12)))
        action_prob = F.softmax(self.actor(x), dim=-1)
        state_values = self.critic(x)
        return action_prob, state_values
    
    def select_action(self, state, mask):
        state = torch.from_numpy(state).float()
        probs, state_value = self.forward(state)
        mask = torch.from_numpy(mask)
#         print(probs)
        m = Categorical(probs * mask)
        action = m.sample()
#         action = torch.argmax(probs * mask)
        self.saved_actions.append(SavedAction(m.log_prob(action), state_value))
        return action.item()
        # In this function, we decide whehter we want the block to move left or right,based on what the model decided
        
    def finish_episode(self):
        # We calculate the losses and perform backprop in this function
        R = 0
        saved_actions = [x for x in self.saved_actions]
    #     log_prob = torch.tensor([x.log_prob for x in model.saved_actions])
    #     value = 
        policy_losses = []
        value_losses =[]
        returns = []
        rewards = [x for x in self.rewards]

        for r in rewards[::-1]:
            R = r + 0.99 * R # 0.99 is our gamma number
            returns.insert(0, R)
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + eps)

        for (log_prob, value), R in zip(saved_actions, returns):
            advantage = R - value.item()

            policy_losses.append(-log_prob * advantage)
            value_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))

        self.optimizer.zero_grad()
        loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
        loss.float().backward()
        self.optimizer.step()

    #     del model.rewards[:]
    #     del model.saved_actions[:]
        self.rewards.clear()
        self.saved_actions.clear()
    
    def save(self, suff=''):
        if len(suff) > 0:
            suff = '_'+suff
        torch.save(self.state_dict(), f"Connect4_models/Connect4_{self.name}{suff}.pt")
        
    def load(self, suff=''):
        if len(suff) > 0:
            suff = '_'+suff
        self.load_state_dict(torch.load(f"Connect4_models/Connect4_{self.name}{suff}.pt"))
        
    def game_done(self, reward):
        self.games_played += 1
        self.wins += 1 if reward > 0 else 0
        self.history.append(reward)

In [100]:
def train(episodes_max, t_max=1000):
#     print('target reward:', env.spec.reward_threshold)
    running_reward = 0
    for i_episode in range(episodes_max): # We need around this much episodes
        env.reset()
        ep_reward = 0
        reward = [0,0]
        done = [0,0]
        for t in range(t_max):
            state, reward[t%2], done[t%2], _ = env.last()
            reward[t%2] = float(reward[t%2])
            players[t%2].rewards.append(reward[t%2])
            if done[t%2]:
                if all(done):
                    players[t%2].game_done(reward[t%2])
                    players[(t+1)%2].game_done(reward[(t+1)%2])
                    break
                env.step(None)
                continue
            action = players[t%2].select_action(state['observation'], state['action_mask'])
#             ep_reward += reward
            env.step(action)
            
            
        running_reward = 0.05 * ep_reward + (1-0.05) * running_reward
        model_1.finish_episode()
        model_2.finish_episode()

        print("\rEpisode {}\tmodel_1 wins: {:.2f}\tmodel_2 wins: {:.2f}".format(
                i_episode, model_1.wins, model_2.wins
            ), end=' '*10)
        if i_episode % 100 == 0: # We will print some things out
            print("\rEpisode {}\tmodel_1 wins: {:.2f}\tmodel_2 wins: {:.2f}".format(
                i_episode, model_1.wins, model_2.wins
            ), end=' '*10)
            print()
            model_1.save('last')
            model_2.save('last')
                  

In [101]:
saved_model = False
buffer_size = 50000

obs = env.observation_spaces['player_0'].spaces['observation'].shape
mask = env.observation_spaces['player_0'].spaces['action_mask'].shape[0]
# obs = obs[0] * obs[1] * obs[2] + mask

model_1 = ActorCritic('model1', obs[::-1], env.action_spaces['player_0'].n, buffer_size, lr=1e-3)
model_2 = ActorCritic('model2', obs[::-1], env.action_spaces['player_0'].n, buffer_size, lr=5e-3)
players = [model_1, model_2]

if saved_model:
    model_1.load('last')
    model_2.load('last')

eps = np.finfo(np.float32).eps.item()

In [113]:
eps

1.1920928955078125e-07

In [102]:
t_max = 500
episodes_max = 10000
train(episodes_max, t_max)

  value_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))


Episode 0	model_1 wins: 0.00	model_2 wins: 1.00          
Episode 100	model_1 wins: 51.00	model_2 wins: 50.00                                                                                                                     
Episode 200	model_1 wins: 111.00	model_2 wins: 90.00                                                                                                                                 
Episode 300	model_1 wins: 172.00	model_2 wins: 129.00                                                                                                                                           
Episode 400	model_1 wins: 215.00	model_2 wins: 186.00                                                                                                                                                                                    
Episode 500	model_1 wins: 248.00	model_2 wins: 253.00                                                                                                               

Episode 9800	model_1 wins: 4218.00	model_2 wins: 5575.00                                                                                                                        
Episode 9900	model_1 wins: 4265.00	model_2 wins: 5628.00                                                                                                                                  
Episode 9999	model_1 wins: 4299.00	model_2 wins: 5693.00                                                                                                                                            

In [135]:
np.unique(np.array(model_1.history), return_counts=True)

(array([-1.,  0.,  1.]), array([5731,   32, 4211], dtype=int64))

In [136]:
print(f'{model_1.wins = }; games_played = {model_1.games_played}')
print(f'{model_2.wins = }; games_played = {model_1.games_played}')

model_1.wins = 4211; games_played = 9974
model_2.wins = 5731; games_played = 9974


# There. we finished
### Lets see it in action

In [46]:
state['observation'].shape

(6, 7, 2)

In [47]:
state['action_mask'].shape

(7,)

In [49]:
state['observation'].flatten().shape

(84,)

In [112]:
pl1_wins = 0
pl2_wins = 0


with torch.no_grad():
    for _ in range(5):
        env.reset()
        reward = [0,0]
        done = [0,0]
        for t in range(100):
            state, reward[t%2], done[t%2], _ = env.last()
            env.render()
            if done[t%2]:
                if all(done):
                    locals()[f'pl{t%2+1}_wins'] += 1 if reward[t%2] > 0 else 0
                    locals()[f'pl{(t+1)%2+1}_wins'] += 1 if reward[(t+1)%2] > 0 else 0
                    break
                env.step(None)
                continue
            action = players[t%2].select_action(state['observation'], state['action_mask'])
            env.step(action)
            sleep(0.1)
    env.close()
print(f'{pl1_wins = }\t{pl2_wins = }')

pl1_wins = 1	pl2_wins = 4


In [10]:
players[0].select_action(state['observation'].flatten(), state['action_mask'])

4

In [2]:
!start ..

In [105]:
env.reset()

In [111]:
state, reward, done, _ = env.last()
print(state['observation'])
print(state['observation'].shape)

[[[0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]]

 [[0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]]

 [[0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]]

 [[0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]]

 [[0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]]

 [[1 0]
  [0 1]
  [0 0]
  [0 0]
  [0 0]
  [0 0]
  [0 0]]]
(6, 7, 2)


In [110]:
env.step(1)