In [1]:
from collections import defaultdict
from itertools import count
import numpy as np
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions
from torch.autograd import Variable

In [2]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive


In [2]:
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
Tensor = FloatTensor

In [3]:
class Environment(object):
    """
    The Tic-Tac-Toe Environment
    """
    # possible ways to win
    win_set = frozenset([(0,1,2), (3,4,5), (6,7,8), # horizontal
                         (0,3,6), (1,4,7), (2,5,8), # vertical
                         (0,4,8), (2,4,6)])         # diagonal
    # statuses
    STATUS_VALID_MOVE = 'valid'
    STATUS_INVALID_MOVE = 'inv'
    STATUS_WIN = 'win'
    STATUS_TIE = 'tie'
    STATUS_LOSE = 'lose'
    STATUS_DONE = 'done'

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset the game to an empty board."""
        self.grid = np.array([0] * 9) # grid
        self.turn = 1                 # whose turn it is
        self.done = False             # whether game is done
        return self.grid

    def render(self):
        """Print what is on the board."""
        map = {0:'.', 1:'x', 2:'o'} # grid label vs how to plot
        print(''.join(map[i] for i in self.grid[0:3]))
        print(''.join(map[i] for i in self.grid[3:6]))
        print(''.join(map[i] for i in self.grid[6:9]))
        print('====')

    def check_win(self):
        """Check if someone has won the game."""
        for pos in self.win_set:
            s = set([self.grid[p] for p in pos])
            if len(s) == 1 and (0 not in s):
                return True
        return False

    def step(self, action):
        """Mark a point on position action."""
        action = int(action)
        assert type(action) == int and action >= 0 and action < 9
        # done = already finished the game
        if self.done:
            return self.grid, self.STATUS_DONE, self.done
        # action already have something on it
        if self.grid[action] != 0:
            return self.grid, self.STATUS_INVALID_MOVE, self.done
        # play move
        self.grid[action] = self.turn
        if self.turn == 1:
            self.turn = 2
        else:
            self.turn = 1
        # check win
        if self.check_win():
            self.done = True
            return self.grid, self.STATUS_WIN, self.done
        # check tie
        if all([p != 0 for p in self.grid]):
            self.done = True
            return self.grid, self.STATUS_TIE, self.done
        return self.grid, self.STATUS_VALID_MOVE, self.done

    def random_step(self):
        """Choose a random, unoccupied move on the board to play."""
        pos = [i for i in range(9) if self.grid[i] == 0]
        move = random.choice(pos)
        return self.step(move)

    def play_against_random(self, action):
        """Play a move, and then have a random agent play the next move."""
        state, status, done = self.step(action)
        if not done and self.turn == 2:
            state, s2, done = self.random_step()
            if done:
                if s2 == self.STATUS_WIN:
                    status = self.STATUS_LOSE
                elif s2 == self.STATUS_TIE:
                    status = self.STATUS_TIE
                else:
                    raise ValueError("???")
        return state, status, done
     
    def play_qustom_player(self, action, player):
        """Play a move, and then have a random agent play the next move."""
        state, status, done = self.step(action)
        if not done and self.turn == 2:
            act, _ = select_action(player, state)
            state, s2, done = self.step(act)
            if done:
                if s2 == self.STATUS_WIN:
                    status = self.STATUS_LOSE
                elif s2 == self.STATUS_TIE:
                    status = self.STATUS_TIE
                else:
                    raise ValueError("???")
        return state, status, done

In [4]:
class Policy(nn.Module):
    """
    The Tic-Tac-Toe Policy
    """
    def __init__(self, input_size=27, hidden_size=64, output_size=9):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(input_size,hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        out = F.softmax(self.fc2(x))
        return out

In [5]:
def select_action(policy, state):
    """Samples an action from the policy at the state."""
    state = torch.from_numpy(state).long().unsqueeze(0)
    state = torch.zeros(3,9).scatter_(0,state,1).view(1,27)
    #pr = policy(state)
    pr = policy(Variable(state))
    m = torch.distributions.Categorical(pr) 
    action = m.sample()
    log_prob = torch.sum(m.log_prob(action))
    
    return action.data[0], log_prob 

In [6]:
def compute_returns(rewards, gamma=1.0):
    """
    Compute returns for each time step, given the rewards
      @param rewards: list of floats, where rewards[t] is the reward
                      obtained at time step t
      @param gamma: the discount factor
      @returns list of floats representing the episode's returns
          G_t = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + ... 

    >>> compute_returns([0,0,0,1], 1.0)
    [1.0, 1.0, 1.0, 1.0]
    >>> compute_returns([0,0,0,1], 0.9)
    [0.7290000000000001, 0.81, 0.9, 1.0]
    >>> compute_returns([0,-0.5,5,0.5,-10], 0.9)
    [-2.5965000000000003, -2.8850000000000002, -2.6500000000000004, -8.5, -10.0]
    """
    r = []
    x = 0
    for i in range(len(rewards)):
        x = rewards[i]
        for j in range(0,len(rewards)-i):
            if j!=0:
                x= x + rewards[i + j]*gamma**j
    
        r.append(x)
    return r    

In [7]:
def finish_episode(saved_rewards, saved_logprobs, gamma=1.0):
    """Samples an action from the policy at the state."""
    policy_loss = []
    returns = compute_returns(saved_rewards, gamma)
    returns = torch.Tensor(returns)
  
    # subtract mean and std for faster training
    returns = (returns - returns.mean()) / (returns.std() +
                                            np.finfo(np.float32).eps)
   
    for log_prob, reward in zip(saved_logprobs, returns):
        policy_loss.append(-log_prob * reward)
       
    policy_loss = torch.stack(policy_loss).sum()
   
    policy_loss.backward(retain_graph=True)
    


In [8]:
def get_reward(status):
    """Returns a numeric given an environment status."""
    return {
            Environment.STATUS_VALID_MOVE  : 0, 
            Environment.STATUS_INVALID_MOVE: -5,
            Environment.STATUS_WIN         : 1,
            Environment.STATUS_TIE         : 0.5,
            Environment.STATUS_LOSE        : -1
    }[status]


In [12]:
def train(policy, env, gamma=0.98, log_interval=10000):
    """Train policy gradient."""
    optimizer = optim.Adam(policy.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=10000, gamma=0.9)
    running_reward = 0

    for i_episode in count(1):
        saved_rewards = []
        saved_logprobs = []
        state = env.reset()
        done = False
        optimizer.zero_grad()
        while not done:
            action, logprob = select_action(policy, state)
            
            state, status, done = env.play_against_random(action)
       
            reward = get_reward(status)
            #if reward == -5:
             # done = True
           
            saved_logprobs.append(logprob)
            saved_rewards.append(reward)
           
 
        R = compute_returns(saved_rewards)[0]
        running_reward += R

        finish_episode(saved_rewards, saved_logprobs, gamma)
        
        
        if i_episode % log_interval == 0:
            print('Episode {}\tAverage return: {:.2f}'.format(
                i_episode,
                running_reward / log_interval))
            running_reward = 0

        if i_episode % (log_interval) == 0:
            torch.save(policy.state_dict(),
                       "./RL/policy-%d.pkl" % i_episode)

        if i_episode % 16 == 0: # batch_size
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

In [13]:
def first_move_distr(policy, env):
    """Display the distribution of first moves."""
    state = env.reset()
    state = torch.from_numpy(state).long().unsqueeze(0)
    state = torch.zeros(3,9).scatter_(0,state,1).view(1,27)
    pr = policy(Variable(state))
    return pr.data


In [14]:
def load_weights(policy, episode):
    """Load saved weights"""
    weights = torch.load("./RL/policy-%d.pkl" % episode)
    policy.load_state_dict(weights)


In [15]:
policy = Policy()
env = Environment()

In [17]:
train(policy, env)



Episode 10000	Average return: -23.85
Episode 20000	Average return: -15.83
Episode 30000	Average return: -5.84
Episode 40000	Average return: -1.65
Episode 50000	Average return: -0.59
Episode 60000	Average return: -0.39
Episode 70000	Average return: -0.18
Episode 80000	Average return: 0.22
Episode 90000	Average return: 0.32
Episode 100000	Average return: 0.41
Episode 110000	Average return: 0.66
Episode 120000	Average return: 0.51
Episode 130000	Average return: 0.64
Episode 140000	Average return: 0.61
Episode 150000	Average return: 0.75
Episode 160000	Average return: 0.58
Episode 170000	Average return: 0.63


KeyboardInterrupt: ignored

In [0]:
load_weights(policy, 800000)
print(first_move_distr(policy, env))

tensor([[3.4975e-09, 1.0639e-11, 9.9994e-01, 2.4727e-12, 2.9912e-11, 6.9818e-12,
         1.6910e-12, 7.2865e-12, 5.8808e-05]])


  


In [0]:
def hod_setki(policy, grid):
  action,_ = select_action(policy, grid)
  #state = torch.from_numpy(grid).long().unsqueeze(0)
  #state = torch.zeros(3,9).scatter_(0,state,1).view(1,27)
  #pr = policy(Variable(state))
  
  #pr2 = pr.data.max(1, keepdim=True)[1]
  return action

In [0]:
env.render()

...
...
...
====


In [0]:
env.step(hod_setki(policy, env.grid))

  


(array([0, 0, 1, 0, 0, 0, 0, 0, 0]), 'valid', False)

In [0]:
env.render()

..x
...
...
====


In [0]:
env.step(4)

(array([0, 0, 1, 0, 2, 0, 0, 0, 0]), 'valid', False)

In [0]:
env.render()

..x
.o.
...
====


In [0]:
env.step(hod_setki(policy, env.grid))

  


(array([0, 0, 1, 0, 2, 0, 0, 0, 1]), 'valid', False)

In [0]:
env.render()

..x
.o.
..x
====


In [0]:
env.step(5)

(array([0, 0, 1, 0, 2, 2, 0, 0, 1]), 'valid', False)

In [0]:
env.render()

..x
.oo
..x
====


In [0]:
env.step(hod_setki(policy, env.grid))

  


(array([1, 0, 1, 0, 2, 2, 0, 0, 1]), 'valid', False)

In [0]:
env.render()

x.x
.oo
..x
====


In [0]:
env.step(7)

(array([0, 0, 2, 0, 2, 1, 1, 2, 1]), 'valid', False)

In [0]:
env.render()

..o
.ox
xox
====


In [0]:
env.step(hod_setki(policy,env.grid))

  


(array([0, 1, 2, 0, 2, 1, 1, 2, 1]), 'valid', False)

In [0]:
env.render()

.xo
.ox
xox
====


In [0]:
env.reset()

array([0, 0, 0, 0, 0, 0, 0, 0, 0])

In [0]:
env.step(2)

(array([0, 0, 1, 0, 0, 0, 0, 0, 0]), 'valid', False)

In [0]:
env.step(hod_setki(policy, env.grid))

  


(array([0, 0, 1, 0, 0, 0, 0, 0, 2]), 'valid', False)

In [0]:
env.render()

..x
...
..o
====


In [0]:
env.step(4)

(array([0, 0, 1, 0, 1, 0, 0, 0, 2]), 'valid', False)

In [0]:
env.render()

..x
.x.
..o
====


In [0]:
env.step(hod_setki(policy, env.grid))

  


(array([0, 0, 1, 0, 1, 0, 2, 0, 2]), 'valid', False)

In [0]:
env.render()

..x
.x.
o.o
====


In [0]:
env.step(7)

(array([0, 0, 1, 0, 1, 0, 2, 1, 2]), 'valid', False)

In [0]:
env.render()

..x
.x.
oxo
====


In [0]:
env.step(hod_setki(policy, env.grid))

  


(array([0, 0, 1, 0, 1, 2, 2, 1, 2]), 'valid', False)

In [0]:
env.render()

..x
.xo
oxo
====
