In [0]:
# Install a Drive FUSE wrapper (google-drive-ocamlfuse) so Google Drive can be
# mounted as a local directory by the cells below.
# https://github.com/astrada/google-drive-ocamlfuse
!apt-get install -y -qq software-properties-common python-software-properties module-init-tools
# NOTE: with `2>&1 > /dev/null` stderr is duplicated onto the original stdout
# *before* stdout is redirected, so normal output is discarded while error
# messages still show up in the cell output.
!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!apt-get update -qq 2>&1 > /dev/null
!apt-get -y install -qq google-drive-ocamlfuse fuse

In [0]:
# Generate auth tokens for Colab.
# NOTE(review): the next cell reads the application-default credentials;
# presumably this call is what makes them available -- confirm.
from google.colab import auth
auth.authenticate_user()

In [0]:
# Generate creds for the Drive FUSE library.
from oauth2client.client import GoogleCredentials
creds = GoogleCredentials.get_application_default()
import getpass
# First pass: run the tool headless with stdin closed and keep only the output
# line containing "URL" (presumably the authorisation link to open -- confirm).
!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
# Prompt for the verification code without echoing it into the notebook.
vcode = getpass.getpass()
# Second pass: feed the code to the tool on stdin.
# NOTE(review): {vcode}, {creds.client_id} and {creds.client_secret} are
# interpolated unquoted into a shell command line; shell metacharacters in any
# of them would be interpreted by the shell.
!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}

In [0]:
# Create a directory and mount Google Drive using that directory.
# `mkdir -p` makes the cell safe to re-run when ./drive already exists.
!mkdir -p drive
!google-drive-ocamlfuse drive

# List the mount point as a quick check that the mount succeeded.
print('Files in Drive:')
!ls drive/

# Create a file in Drive.
# NOTE: this overwrites drive/created.txt every time the cell is run.
!echo "This newly created file will appear in your Drive file list." > drive/created.txt

In [0]:
import os
import sys
import tensorflow as tf
!pip install discord.py
import discord
import asyncio
from discord.ext.commands import Bot
from discord.ext import commands
!pip install dblpy
import dbl
import aiohttp
import asyncio
import logging

import resource

import nltk
nltk.download('punkt')

In [0]:
class DiscordBotsOrgAPI:
    """Cog that reports the bot's server count to the discordbots.org API.

    Relies on a module-level ``logger`` (assigned by ``setup``) being in place
    before ``update_stats`` starts running.
    """

    def __init__(self, bot):
        """Create the DBL client and schedule the periodic stats task.

        Args:
            bot: The Discord bot instance; its ``loop`` is used to schedule
                ``update_stats``.
        """
        self.bot = bot
        # Placeholder value -- set this to your DBL token.
        self.token = 'Token'
        self.dblpy = dbl.Client(bot, self.token)
        bot.loop.create_task(self.update_stats())

    async def update_stats(self):
        """Post the server count in an endless loop, sleeping 30 minutes between attempts.

        Any exception raised while posting (or while logging the success
        message) is logged and swallowed so the loop keeps running.
        """
        pause_seconds = 1800  # 30 minutes

        while True:
            logger.info('attempting to post server count')
            try:
                await self.dblpy.post_server_count()
                server_total = len(self.bot.servers)
                logger.info('posted server count ({})'.format(server_total))
            except Exception as err:
                failure = 'Failed to post server count\n{}: {}'.format(
                    type(err).__name__, err)
                logger.exception(failure)
            await asyncio.sleep(pause_seconds)
            
def setup(bot):
    """Extension entry point: install the module logger and register the cog.

    Assigns the module-global ``logger`` (used by ``DiscordBotsOrgAPI``) and
    adds a ``DiscordBotsOrgAPI`` instance to ``bot``.
    """
    global logger
    logger = logging.getLogger('bot')
    stats_cog = DiscordBotsOrgAPI(bot)
    bot.add_cog(stats_cog)

In [0]:
from __future__ import print_function
import numpy as np


class Board(object):
    """Rectangular n-in-a-row board.

    Cells are numbered row-major: ``move = h * width + w`` where ``h`` is the
    row index and ``w`` the column index. ``states`` maps each occupied move
    to the player (1 or 2) that owns it; ``availables`` lists the free moves.
    ``init_board`` must be called before playing.
    """

    def __init__(self, **kwargs):
        """Read ``width``, ``height`` (default 8) and ``n_in_row`` (default 5)."""
        self.width = int(kwargs.get('width', 8))
        self.height = int(kwargs.get('height', 8))
        # board states stored as a dict,
        # key: move as location on the board,
        # value: player as pieces type
        self.states = {}
        # need how many pieces in a row to win
        self.n_in_row = int(kwargs.get('n_in_row', 5))
        self.players = [1, 2]  # player1 and player2

    def init_board(self, start_player=0):
        """Reset the board for a new game.

        Args:
            start_player: Index into ``self.players`` of the side that moves first.

        Raises:
            Exception: If either board dimension is smaller than ``n_in_row``.
        """
        if self.width < self.n_in_row or self.height < self.n_in_row:
            raise Exception('board width and height can not be '
                            'less than {}'.format(self.n_in_row))
        self.current_player = self.players[start_player]  # start player
        # keep available moves in a list
        self.availables = list(range(self.width * self.height))
        self.states = {}
        self.last_move = -1

    def move_to_location(self, move):
        """Convert a move index into ``[h, w]``.

        3*3 board's moves like:
        6 7 8
        3 4 5
        0 1 2
        and move 5's location is (1,2)
        """
        h = move // self.width
        w = move % self.width
        return [h, w]

    def location_to_move(self, location):
        """Convert ``[h, w]`` into a move index, or -1 if it is not on the board.

        Both coordinates are range-checked individually: checking only the
        combined index would let an out-of-range column (e.g. ``w == width``
        or ``w == -1``) silently wrap onto a neighbouring row.
        """
        if len(location) != 2:
            return -1
        h = location[0]
        w = location[1]
        if not (0 <= h < self.height and 0 <= w < self.width):
            return -1
        return h * self.width + w

    def current_state(self):
        """Return the board state from the perspective of the current player.

        state shape: 4*height*width (rows first, then columns)
          plane 0: stones of the player to move
          plane 1: stones of the opponent
          plane 2: the last move
          plane 3: all ones when an even number of stones is on the board
        The row axis is flipped in the returned view.
        """
        square_state = np.zeros((4, self.height, self.width))
        if self.states:
            moves, players = np.array(list(zip(*self.states.items())))
            move_curr = moves[players == self.current_player]
            move_oppo = moves[players != self.current_player]
            # Row and column must both be derived from the width (the
            # row-major stride); using the height for the column only
            # happens to work on square boards.
            square_state[0][move_curr // self.width,
                            move_curr % self.width] = 1.0
            square_state[1][move_oppo // self.width,
                            move_oppo % self.width] = 1.0
            # indicate the last move location
            square_state[2][self.last_move // self.width,
                            self.last_move % self.width] = 1.0
        if len(self.states) % 2 == 0:
            square_state[3][:, :] = 1.0  # indicate the colour to play
        return square_state[:, ::-1, :]

    def do_move(self, move):
        """Place the current player's stone on ``move`` and pass the turn.

        Raises:
            ValueError: If ``move`` is not in ``availables`` (from ``list.remove``).
        """
        self.states[move] = self.current_player
        self.availables.remove(move)
        self.current_player = (
            self.players[0] if self.current_player == self.players[1]
            else self.players[1]
        )
        self.last_move = move

    def has_a_winner(self):
        """Return ``(True, player)`` if some player has ``n_in_row`` aligned stones.

        Returns ``(False, -1)`` otherwise. Each occupied cell is tested as the
        start of a horizontal, vertical, diagonal and anti-diagonal run.
        """
        width = self.width
        height = self.height
        states = self.states
        n = self.n_in_row

        moved = list(set(range(width * height)) - set(self.availables))
        # A win needs n stones from one side and at least n - 1 replies, so
        # 2n - 1 is the exact lower bound (n + 2 skipped real wins for n < 3).
        if len(moved) < 2 * n - 1:
            return False, -1

        for m in moved:
            h = m // width
            w = m % width
            player = states[m]

            # horizontal run starting at m
            if (w in range(width - n + 1) and
                    len(set(states.get(i, -1) for i in range(m, m + n))) == 1):
                return True, player

            # vertical run starting at m
            if (h in range(height - n + 1) and
                    len(set(states.get(i, -1) for i in range(m, m + n * width, width))) == 1):
                return True, player

            # diagonal run (row and column both increasing)
            if (w in range(width - n + 1) and h in range(height - n + 1) and
                    len(set(states.get(i, -1) for i in range(m, m + n * (width + 1), width + 1))) == 1):
                return True, player

            # anti-diagonal run (row increasing, column decreasing)
            if (w in range(n - 1, width) and h in range(height - n + 1) and
                    len(set(states.get(i, -1) for i in range(m, m + n * (width - 1), width - 1))) == 1):
                return True, player

        return False, -1

    def game_end(self):
        """Check whether the game is ended or not.

        Returns ``(True, winner)`` on a win, ``(True, -1)`` on a full board
        with no winner, and ``(False, -1)`` while the game is still running.
        """
        win, winner = self.has_a_winner()
        if win:
            return True, winner
        elif not len(self.availables):
            return True, -1
        return False, -1

    def get_current_player(self):
        """Return the player (1 or 2) whose turn it is."""
        return self.current_player


class Game(object):
    """Game server: runs human-vs-AI games over Discord and MCTS self-play."""

    def __init__(self, board, **kwargs):
        """Store the ``Board`` instance the games are played on."""
        self.board = board

    @staticmethod
    def _board_text(board, player1, player2):
        """Render the board as multi-line text (X = player1, O = player2, _ = empty)."""
        width = board.width
        height = board.height

        # Column header, then one text row per board row, top row first.
        lines = [''.join("{0:8}".format(x) for x in range(width)), '']
        for i in range(height - 1, -1, -1):
            cells = ["{0:4d}".format(i)]
            for j in range(width):
                p = board.states.get(i * width + j, -1)
                if p == player1:
                    mark = 'X'
                elif p == player2:
                    mark = 'O'
                else:
                    mark = '_'
                cells.append(mark.center(8))
            lines.append(''.join(cells).rstrip())
            lines.append('')
        return '\n'.join(lines)

    async def graphic(self, board, player1, player2, Message_LOC2):
        """Draw the board and post it to ``Message_LOC2`` as ONE message.

        The board is sent as a single code-block message instead of one
        message per cell: per-cell sending issues dozens of API calls per
        render and includes whitespace-only messages.
        NOTE: very large boards could exceed the per-message length limit.
        """
        rendered = self._board_text(board, player1, player2)
        await client.send_message(Message_LOC2, "```\n" + rendered + "\n```")

    async def start_play(self, player1, player2, Message_LOC1, start_player=0, is_shown=1):
        """Start a game between two players and return the winner (-1 for a tie).

        Args:
            player1, player2: Player objects providing ``set_player_ind``,
                ``get_action(board)`` and a ``player`` attribute.
            Message_LOC1: Destination passed to ``client.send_message``.
            start_player: 0 if player1 moves first, 1 if player2 does.
            is_shown: When truthy, post the board after every move and
                announce the result.

        Raises:
            Exception: If ``start_player`` is not 0 or 1.
        """
        if start_player not in (0, 1):
            raise Exception('start_player should be either 0 (player1 first) '
                            'or 1 (player2 first)')
        self.board.init_board(start_player)
        p1, p2 = self.board.players
        player1.set_player_ind(p1)
        player2.set_player_ind(p2)
        players = {p1: player1, p2: player2}
        if is_shown:
            await self.graphic(self.board, player1.player, player2.player, Message_LOC1)
        while True:
            current_player = self.board.get_current_player()
            player_in_turn = players[current_player]
            # NOTE: get_action is called synchronously, so the event loop is
            # blocked for as long as the player takes to choose a move.
            move = player_in_turn.get_action(self.board)
            self.board.do_move(move)
            if is_shown:
                await self.graphic(self.board, player1.player, player2.player, Message_LOC1)
            end, winner = self.board.game_end()
            if end:
                if is_shown:
                    if winner != -1:
                        # send_message takes a single content string, not
                        # print-style positional pieces.
                        await client.send_message(
                            Message_LOC1,
                            "Game end. Winner is {}".format(players[winner]))
                    else:
                        await client.send_message(Message_LOC1, "Game end. Tie")
                return winner

    def start_self_play(self, player, is_shown=0, temp=1e-3):
        """Start a self-play game using a MCTS player, reusing the search tree.

        Stores the self-play data ``(state, mcts_probs, z)`` for training.

        Returns:
            ``(winner, data)`` where ``data`` is a ``zip`` of
            ``(state, mcts_probs, winner_z)`` triples, one per move played.
        """
        self.board.init_board()
        p1, p2 = self.board.players
        states, mcts_probs, current_players = [], [], []
        while True:
            move, move_probs = player.get_action(self.board,
                                                 temp=temp,
                                                 return_prob=1)
            # store the data
            states.append(self.board.current_state())
            mcts_probs.append(move_probs)
            current_players.append(self.board.current_player)
            # perform a move
            self.board.do_move(move)
            if is_shown:
                # Self-play is synchronous and has no Discord destination, so
                # the board goes to stdout rather than through graphic().
                print(self._board_text(self.board, p1, p2))
            end, winner = self.board.game_end()
            if end:
                # winner from the perspective of the current player of each state
                winners_z = np.zeros(len(current_players))
                if winner != -1:
                    winners_z[np.array(current_players) == winner] = 1.0
                    winners_z[np.array(current_players) != winner] = -1.0
                # reset MCTS root node
                player.reset_player()
                if is_shown:
                    if winner != -1:
                        print("Game end. Winner is player:", winner)
                    else:
                        print("Game end. Tie")
                return winner, zip(states, mcts_probs, winners_z)

In [0]:
import numpy as np
import copy


def softmax(x):
    """Return the softmax of ``x`` computed over all of its elements.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow.
    """
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)


class TreeNode(object):
    """One node of the MCTS tree.

    Stores the running mean value ``_Q``, the prior probability ``_P``, the
    visit count ``_n_visits`` and the most recently computed exploration
    bonus ``_u``.
    """

    def __init__(self, parent, prior_p):
        """Create a node below ``parent`` (``None`` for the root) with prior ``prior_p``."""
        self._parent = parent
        self._children = {}  # action -> TreeNode
        self._n_visits = 0
        self._Q = 0
        self._u = 0
        self._P = prior_p

    def expand(self, action_priors):
        """Add a child for every action that does not have one yet.

        action_priors: iterable of ``(action, prior probability)`` pairs from
            the policy function. Actions already present keep their node.
        """
        known = self._children
        for action, prob in action_priors:
            if action in known:
                continue
            known[action] = TreeNode(self, prob)

    def select(self, c_puct):
        """Pick the child with the largest ``Q + u`` score.

        Return: A tuple of (action, next_node)
        """
        def score(entry):
            return entry[1].get_value(c_puct)

        return max(self._children.items(), key=score)

    def update(self, leaf_value):
        """Record one visit and fold ``leaf_value`` into the running mean Q.

        leaf_value: the subtree evaluation from the current player's
            perspective.
        """
        self._n_visits += 1
        delta = leaf_value - self._Q
        self._Q += 1.0 * delta / self._n_visits

    def update_recursive(self, leaf_value):
        """Apply ``update`` to this node and every ancestor.

        The sign of the value is flipped at each step towards the root.
        """
        node, value = self, leaf_value
        while node is not None:
            node.update(value)
            node, value = node._parent, -value

    def get_value(self, c_puct):
        """Compute this node's selection score ``Q + u`` and cache ``u``.

        c_puct: a number in (0, inf) weighting the prior-based bonus ``u``
            against the value ``Q``.
        """
        parent_visits = self._parent._n_visits
        self._u = (c_puct * self._P *
                   np.sqrt(parent_visits) / (1 + self._n_visits))
        return self._Q + self._u

    def is_leaf(self):
        """Return True when no children have been expanded below this node."""
        return not self._children

    def is_root(self):
        """Return True when this node has no parent."""
        return self._parent is None


class MCTS(object):
    """Monte Carlo Tree Search guided by a policy/value function."""

    def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
        """
        policy_value_fn: callable taking a board state and returning
            ``(action_probs, value)`` -- an iterable of (action, probability)
            pairs plus a score in [-1, 1] from the current player's
            perspective.
        c_puct: a number in (0, inf); higher values make the search rely on
            the prior more.
        n_playout: number of playouts run per call to ``get_move_probs``.
        """
        self._root = TreeNode(None, 1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    def _playout(self, state):
        """Run one playout from the root to a leaf and back the value up.

        ``state`` is modified in place, so callers must pass a copy.
        """
        node = self._root
        # Descend greedily until an unexpanded node is reached.
        while not node.is_leaf():
            action, node = node.select(self._c_puct)
            state.do_move(action)

        # The policy is queried before the end-of-game check; its value is
        # only used when the game is still running.
        action_probs, leaf_value = self._policy(state)
        end, winner = state.game_end()
        if end:
            # Terminal state: use the true outcome instead of the estimate.
            if winner == -1:  # tie
                leaf_value = 0.0
            elif winner == state.get_current_player():
                leaf_value = 1.0
            else:
                leaf_value = -1.0
        else:
            node.expand(action_probs)

        # Back up through the traversed nodes; the leaf receives the negated value.
        node.update_recursive(-leaf_value)

    def get_move_probs(self, state, temp=1e-3):
        """Run all playouts sequentially and return ``(actions, probabilities)``.

        state: the current game state (deep-copied for every playout)
        temp: temperature parameter in (0, 1] controlling exploration
        """
        for _ in range(self._n_playout):
            self._playout(copy.deepcopy(state))

        # Probabilities are derived from the root children's visit counts.
        visit_table = [(act, child._n_visits)
                       for act, child in self._root._children.items()]
        acts, visits = zip(*visit_table)
        scaled_log_visits = 1.0/temp * np.log(np.array(visits) + 1e-10)
        return acts, softmax(scaled_log_visits)

    def update_with_move(self, last_move):
        """Re-root the tree at the child for ``last_move``, keeping its subtree.

        If the root has no such child, the tree is replaced by a fresh root.
        """
        children = self._root._children
        if last_move not in children:
            self._root = TreeNode(None, 1.0)
            return
        self._root = children[last_move]
        self._root._parent = None

    def __str__(self):
        return "MCTS"


class MCTSPlayer(object):
    """AI player that picks its moves with MCTS."""

    def __init__(self, policy_value_function,
                 c_puct=5, n_playout=2000, is_selfplay=0):
        """Build the underlying ``MCTS`` and remember the self-play flag."""
        self.mcts = MCTS(policy_value_function, c_puct, n_playout)
        self._is_selfplay = is_selfplay

    def set_player_ind(self, p):
        """Record which player id this object plays as."""
        self.player = p

    def reset_player(self):
        """Discard the search tree (re-roots on move -1)."""
        self.mcts.update_with_move(-1)

    def get_action(self, board, temp=1e-3, return_prob=0):
        """Choose a move for ``board``.

        Returns the move, or ``(move, move_probs)`` when ``return_prob`` is
        truthy, where ``move_probs`` is the pi vector over all board cells as
        in the AlphaGo Zero paper. When the board has no free cell a warning
        is printed and ``None`` is returned.
        """
        open_moves = board.availables
        pi = np.zeros(board.width*board.height)
        if not len(open_moves) > 0:
            print("WARNING: the board is full")
            return None

        acts, probs = self.mcts.get_move_probs(board, temp)
        pi[list(acts)] = probs
        if self._is_selfplay:
            # Mix in Dirichlet noise for exploration (needed for self-play
            # training), then keep the subtree below the chosen move.
            noise = np.random.dirichlet(0.3*np.ones(len(probs)))
            move = np.random.choice(acts, p=0.75*probs + 0.25*noise)
            self.mcts.update_with_move(move)
        else:
            # With the default temp=1e-3 this is almost equivalent to
            # choosing the move with the highest probability.
            move = np.random.choice(acts, p=probs)
            # Start from a fresh tree on the next call.
            self.mcts.update_with_move(-1)

        return (move, pi) if return_prob else move

    def __str__(self):
        return "MCTS {}".format(self.player)

In [0]:
from __future__ import print_function
import numpy as np


# some utility functions
def softmax(x):
    """Softmax over all elements of ``x``, shifted by the maximum for stability.

    NOTE: identical to the ``softmax`` defined in the MCTS cell earlier in
    this notebook; whichever cell runs last provides the binding.
    """
    shifted_exp = np.exp(x - np.max(x))
    return shifted_exp / np.sum(shifted_exp)


def relu(X):
    """Element-wise rectifier: negative entries of ``X`` become 0."""
    return np.maximum(X, 0)


def conv_forward(X, W, b, stride=1, padding=1):
    """Forward pass of a 2-D convolution implemented via im2col.

    Args:
        X: 4-D input; its shape is unpacked as (n_x, d_x, h_x, w_x).
        W: 4-D filters; shape unpacked as (n_filters, d_filter, h_filter, w_filter).
        b: Bias added to every filter response.
        stride, padding: Convolution stride and zero padding.

    Returns:
        Array of shape (n_x, n_filters, h_out, w_out).
    """
    n_filters, _, h_filter, w_filter = W.shape
    n_x, _, h_x, w_x = X.shape
    # theano conv2d flips the filters (rotate 180 degree) first
    # while doing the calculation
    flipped = W[:, :, ::-1, ::-1]
    h_out = int((h_x - h_filter + 2 * padding) / stride + 1)
    w_out = int((w_x - w_filter + 2 * padding) / stride + 1)

    columns = im2col_indices(X, h_filter, w_filter,
                             padding=padding, stride=stride)
    kernel_matrix = flipped.reshape(n_filters, -1)
    # Transposing twice lets the bias be added along the filter axis.
    response = (np.dot(kernel_matrix, columns).T + b).T
    return response.reshape(n_filters, h_out, w_out, n_x).transpose(3, 0, 1, 2)


def fc_forward(X, W, b):
    """Fully connected layer: return ``X . W + b``."""
    return np.dot(X, W) + b


def get_im2col_indices(x_shape, field_height,
                       field_width, padding=1, stride=1):
    """Build the (channel, row, column) index arrays used by im2col.

    Args:
        x_shape: Input shape, unpacked as (N, C, H, W).
        field_height, field_width: Receptive-field (filter) size.
        padding: Zero padding applied on each side of H and W.
        stride: Step between neighbouring receptive fields.

    Returns:
        Tuple ``(k, i, j)`` of integer arrays. ``k`` has shape
        (C * field_height * field_width, 1); ``i`` and ``j`` have shape
        (C * field_height * field_width, out_height * out_width).

    Raises:
        AssertionError: If the padded input is not evenly covered by the
            stride in either dimension.
    """
    # First figure out what the size of the output should be
    N, C, H, W = x_shape
    assert (H + 2 * padding - field_height) % stride == 0
    # The width check must use the field *width*; testing against the height
    # rejected valid non-square filters and accepted invalid ones.
    assert (W + 2 * padding - field_width) % stride == 0
    out_height = int((H + 2 * padding - field_height) / stride + 1)
    out_width = int((W + 2 * padding - field_width) / stride + 1)

    # Row offsets inside one receptive field, repeated for every channel.
    i0 = np.repeat(np.arange(field_height), field_width)
    i0 = np.tile(i0, C)
    # Row origin of every output position.
    i1 = stride * np.repeat(np.arange(out_height), out_width)
    # Column offsets inside one receptive field and column origins.
    j0 = np.tile(np.arange(field_width), field_height * C)
    j1 = stride * np.tile(np.arange(out_width), out_height)
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)

    # Channel index of every receptive-field element.
    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)

    return (k.astype(int), i.astype(int), j.astype(int))


def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """Rearrange receptive fields of ``x`` into columns using fancy indexing.

    ``x`` is zero-padded on its last two axes, gathered with the index arrays
    from ``get_im2col_indices`` and returned as a 2-D array with
    ``field_height * field_width * C`` rows (C = ``x.shape[1]``).
    """
    pad = padding
    padded = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)),
                    mode='constant')

    k, i, j = get_im2col_indices(x.shape, field_height,
                                 field_width, padding, stride)

    patches = padded[:, k, i, j]
    channels = x.shape[1]
    n_rows = field_height * field_width * channels
    return patches.transpose(1, 2, 0).reshape(n_rows, -1)


class PolicyValueNetNumpy():
    """Policy-value network forward pass in plain numpy."""

    def __init__(self, board_width, board_height, net_params):
        """Store the board size and the pre-trained parameter list.

        ``net_params`` is indexed 0..15 by ``policy_value_fn``.
        """
        self.board_width = board_width
        self.board_height = board_height
        self.params = net_params

    def policy_value_fn(self, board):
        """
        input: board
        output: an iterator of (action, probability) tuples for each available
        action and the score of the board state

        NOTE: the first return value is a ``zip`` object and can only be
        iterated once.
        """
        legal_positions = board.availables
        current_state = board.current_state()

        # Only add the batch axis and keep the board's own plane layout.
        # Re-imposing a fixed (width, height) order here would scramble any
        # state laid out rows-first on a non-square board.
        X = current_state[np.newaxis, ...]
        # first 3 conv layers with ReLu nonlinearity
        for i in [0, 2, 4]:
            X = relu(conv_forward(X, self.params[i], self.params[i+1]))
        # policy head
        X_p = relu(conv_forward(X, self.params[6], self.params[7], padding=0))
        X_p = fc_forward(X_p.flatten(), self.params[8], self.params[9])
        act_probs = softmax(X_p)
        # value head
        X_v = relu(conv_forward(X, self.params[10],
                                self.params[11], padding=0))
        X_v = relu(fc_forward(X_v.flatten(), self.params[12], self.params[13]))
        value = np.tanh(fc_forward(X_v, self.params[14], self.params[15]))[0]
        # Keep only the probabilities of the moves that are still available.
        act_probs = zip(legal_positions, act_probs.flatten()[legal_positions])
        return act_probs, value

In [0]:
import numpy as np
import tensorflow as tf


class PolicyValueNet():
    """Policy-value network built with the TensorFlow 1.x graph API.

    NOTE: layers are created without explicit names, so the auto-generated
    variable names depend on the order of the calls in ``__init__``. Keep that
    order unchanged or previously saved checkpoints will no longer restore.
    """

    def __init__(self, board_width, board_height, model_file=None):
        """Build the graph, start a session and optionally restore weights.

        Args:
            board_width, board_height: Board dimensions.
            model_file: Optional checkpoint path to restore from.
        """
        self.board_width = board_width
        self.board_height = board_height

        # Define the tensorflow neural network
        # 1. Input: batches of 4-plane states, converted to channels-last.
        self.input_states = tf.placeholder(
                tf.float32, shape=[None, 4, board_height, board_width])
        self.input_state = tf.transpose(self.input_states, [0, 2, 3, 1])
        # 2. Common Networks Layers
        self.conv1 = tf.layers.conv2d(inputs=self.input_state,
                                      filters=32, kernel_size=[3, 3],
                                      padding="same", data_format="channels_last",
                                      activation=tf.nn.relu)
        self.conv2 = tf.layers.conv2d(inputs=self.conv1, filters=64,
                                      kernel_size=[3, 3], padding="same",
                                      data_format="channels_last",
                                      activation=tf.nn.relu)
        self.conv3 = tf.layers.conv2d(inputs=self.conv2, filters=128,
                                      kernel_size=[3, 3], padding="same",
                                      data_format="channels_last",
                                      activation=tf.nn.relu)
        # 3-1 Action Networks
        self.action_conv = tf.layers.conv2d(inputs=self.conv3, filters=4,
                                            kernel_size=[1, 1], padding="same",
                                            data_format="channels_last",
                                            activation=tf.nn.relu)
        # Flatten the tensor
        self.action_conv_flat = tf.reshape(
                self.action_conv, [-1, 4 * board_height * board_width])
        # 3-2 Full connected layer, the output is the log probability of moves
        # on each slot on the board
        self.action_fc = tf.layers.dense(inputs=self.action_conv_flat,
                                         units=board_height * board_width,
                                         activation=tf.nn.log_softmax)
        # 4 Evaluation Networks
        self.evaluation_conv = tf.layers.conv2d(inputs=self.conv3, filters=2,
                                                kernel_size=[1, 1],
                                                padding="same",
                                                data_format="channels_last",
                                                activation=tf.nn.relu)
        self.evaluation_conv_flat = tf.reshape(
                self.evaluation_conv, [-1, 2 * board_height * board_width])
        self.evaluation_fc1 = tf.layers.dense(inputs=self.evaluation_conv_flat,
                                              units=64, activation=tf.nn.relu)
        # output the score of evaluation on current state
        self.evaluation_fc2 = tf.layers.dense(inputs=self.evaluation_fc1,
                                              units=1, activation=tf.nn.tanh)

        # Define the Loss function
        # 1. Label: the array containing if the game wins or not for each state
        self.labels = tf.placeholder(tf.float32, shape=[None, 1])
        # 2. Predictions: the array containing the evaluation score of each state
        # which is self.evaluation_fc2
        # 3-1. Value Loss function
        self.value_loss = tf.losses.mean_squared_error(self.labels,
                                                       self.evaluation_fc2)
        # 3-2. Policy Loss function (cross-entropy against the MCTS probabilities)
        self.mcts_probs = tf.placeholder(
                tf.float32, shape=[None, board_height * board_width])
        self.policy_loss = tf.negative(tf.reduce_mean(
                tf.reduce_sum(tf.multiply(self.mcts_probs, self.action_fc), 1)))
        # 3-3. L2 penalty (regularization) on every non-bias variable
        l2_penalty_beta = 1e-4
        trainable_vars = tf.trainable_variables()
        l2_penalty = l2_penalty_beta * tf.add_n(
            [tf.nn.l2_loss(v) for v in trainable_vars
             if 'bias' not in v.name.lower()])
        # 3-4 Add up to be the Loss function
        self.loss = self.value_loss + self.policy_loss + l2_penalty

        # Define the optimizer we use for training
        self.learning_rate = tf.placeholder(tf.float32)
        self.optimizer = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate).minimize(self.loss)

        # Make a session
        self.session = tf.Session()

        # calc policy entropy, for monitoring only
        self.entropy = tf.negative(tf.reduce_mean(
                tf.reduce_sum(tf.exp(self.action_fc) * self.action_fc, 1)))

        # Initialize variables
        init = tf.global_variables_initializer()
        self.session.run(init)

        # For saving and restoring
        self.saver = tf.train.Saver()
        if model_file is not None:
            self.restore_model(model_file)

    def policy_value(self, state_batch):
        """
        input: a batch of states
        output: a batch of action probabilities and state values
        """
        log_act_probs, value = self.session.run(
                [self.action_fc, self.evaluation_fc2],
                feed_dict={self.input_states: state_batch}
                )
        # The network emits log-probabilities; convert them back.
        act_probs = np.exp(log_act_probs)
        return act_probs, value

    def policy_value_fn(self, board):
        """
        input: board
        output: an iterator of (action, probability) tuples for each available
        action and the score of the board state

        NOTE: the first return value is a ``zip`` object and can only be
        iterated once.
        """
        legal_positions = board.availables
        # Only add the batch axis and keep the board's own plane layout. The
        # placeholder is declared (height, width); reshaping to
        # (width, height) here scrambled rows-first states on non-square
        # boards.
        current_state = np.ascontiguousarray(
                board.current_state()[np.newaxis, ...])
        act_probs, value = self.policy_value(current_state)
        act_probs = zip(legal_positions, act_probs[0][legal_positions])
        return act_probs, value

    def train_step(self, state_batch, mcts_probs, winner_batch, lr):
        """Perform one training step and return ``(loss, entropy)``."""
        winner_batch = np.reshape(winner_batch, (-1, 1))
        loss, entropy, _ = self.session.run(
                [self.loss, self.entropy, self.optimizer],
                feed_dict={self.input_states: state_batch,
                           self.mcts_probs: mcts_probs,
                           self.labels: winner_batch,
                           self.learning_rate: lr})
        return loss, entropy

    def save_model(self, model_path):
        """Save the session's variables to ``model_path``."""
        self.saver.save(self.session, model_path)

    def restore_model(self, model_path):
        """Restore the session's variables from ``model_path``."""
        self.saver.restore(self.session, model_path)

In [0]:
from __future__ import print_function
import pickle
#from mcts_pure import MCTSPlayer as MCTS_Pure
# from policy_value_net import PolicyValueNet  # Theano and Lasagne
# from policy_value_net_pytorch import PolicyValueNet  # Pytorch
# from policy_value_net_tensorflow import PolicyValueNet # Tensorflow
# from policy_value_net_keras import PolicyValueNet  # Keras


class Human(object):
    """
    human player: reads moves typed as "row,column" from standard input
    """

    def __init__(self):
        self.player = None

    def set_player_ind(self, p):
        """Record which player id this object plays as."""
        self.player = p

    def get_action(self, board):
        """Prompt until a legal move is entered and return its move index.

        Invalid input prints "invalid move" and asks again. Retrying is done
        with a loop rather than recursion, so any number of bad inputs is
        handled without hitting the recursion limit.

        Raises:
            EOFError: If standard input is closed; retrying could never
                succeed in that case.
        """
        while True:
            try:
                location = input("Your move: ")
                if isinstance(location, str):  # for python3
                    location = [int(n, 10) for n in location.split(",")]
                move = board.location_to_move(location)
            except EOFError:
                raise
            except Exception:
                # Anything unparsable counts as an invalid move.
                move = -1
            if move != -1 and move in board.availables:
                return move
            print("invalid move")

    def __str__(self):
        return "Human {}".format(self.player)


async def run(Message_LOC):
    """Play one human-vs-AI game on an 8x8 board (5 in a row) over Discord.

    Loads the pre-trained numpy policy from Drive, wraps it in an MCTS player
    and runs ``Game.start_play`` with the human moving first.

    Args:
        Message_LOC: Destination forwarded to ``Game.start_play`` for the
            board and result messages.
    """
    n = 5
    width, height = 8, 8
    model_file = 'drive/AgateV2/best_policy_8_8_5.model'
    try:
        board = Board(width=width, height=height, n_in_row=n)
        game = Game(board)

        # ############### human VS AI ###################
        # To use the TensorFlow network instead of the numpy one:
        # best_policy = PolicyValueNet(width, height, model_file = model_file)
        # mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400)

        # load the provided model (trained in Theano/Lasagne) into a MCTS player written in pure numpy
        # SECURITY: pickle.load can execute arbitrary code; only load model
        # files from a source you trust.
        with open(model_file, 'rb') as model_fh:
            try:
                policy_param = pickle.load(model_fh)
            except Exception:
                # Retry from the start of the file with byte strings, to
                # support pickles written under Python 2.
                model_fh.seek(0)
                policy_param = pickle.load(model_fh, encoding='bytes')
        best_policy = PolicyValueNetNumpy(width, height, policy_param)
        mcts_player = MCTSPlayer(best_policy.policy_value_fn,
                                 c_puct=5,
                                 n_playout=400)  # set larger n_playout for better performance

        # uncomment the following line to play with pure MCTS (it's much weaker even with a larger n_playout)
        # mcts_player = MCTS_Pure(c_puct=5, n_playout=1000)

        # human player, input your move in the format: 2,3
        human = Human()

        # set start_player=0 for human first
        await game.start_play(human, mcts_player, Message_LOC, start_player=0, is_shown=1)
    except KeyboardInterrupt:
        print('\n\rquit')

In [0]:
import nltk
import os
import string
import tensorflow as tf

import codecs
import os
import tensorflow as tf

from collections import namedtuple
from tensorflow.python.ops import lookup_ops

import codecs
import json
import os
import tensorflow as tf


class HParams:
    """Thin holder around the hyper-parameters read from ``<model_dir>/hparams.json``."""

    def __init__(self, model_dir):
        """
        Args:
            model_dir: Name of the folder storing the hparams.json file.
        """
        self.hparams = self.load_hparams(model_dir)

    @staticmethod
    def load_hparams(model_dir):
        """Read hparams.json from ``model_dir``.

        Returns:
            A ``tf.contrib.training.HParams`` built from the JSON values, or None
            when the file is absent or loading it raises ValueError.
        """
        json_path = os.path.join(model_dir, "hparams.json")
        if not tf.gfile.Exists(json_path):
            return None

        print("# Loading hparams from {} ...".format(json_path))
        utf8_reader = codecs.getreader("utf-8")
        with utf8_reader(tf.gfile.GFile(json_path, "rb")) as stream:
            try:
                loaded_values = json.load(stream)
                return tf.contrib.training.HParams(**loaded_values)
            except ValueError:
                print("Error loading hparams file.")
                return None

# Corpus-file markers. NOTE(review): neither constant is referenced in the code visible
# here; presumably COMMENT_LINE_STT starts a comment line and CONVERSATION_SEP separates
# conversations -- confirm against the corpus tooling.
COMMENT_LINE_STT = "#=="
CONVERSATION_SEP = "==="

# Sub-folders of the corpus directory; TokenizedData._load_corpus reads the .txt files
# found in each of them.
AUG0_FOLDER = "Augment0"
AUG1_FOLDER = "Augment1"
AUG2_FOLDER = "Augment2"

MAX_LEN = 1000  # Assume no line in the training data has more than this number of characters
# File name of the vocabulary, looked up inside the corpus directory.
VOCAB_FILE = "vocab.txt"


class TokenizedData:
    """Builds the tf.data input pipelines (training and inference) for the seq2seq model.

    The vocabulary file inside ``corpus_dir`` supplies the word -> id table. In training
    mode the corpus text files are loaded and converted to id sequences at construction
    time; in inference mode only the lookup tables are prepared.
    """
    def __init__(self, corpus_dir, hparams=None, training=True, buffer_size=8192):
        """
        Args:
            corpus_dir: Name of the folder storing corpus files for training. The vocabulary
                    file (VOCAB_FILE) is read from this folder as well.
            hparams: The object containing the loaded hyper parameters. If None, it will be
                    initialized here from corpus_dir.
            training: Whether to use this object for training.
            buffer_size: The prefetch buffer size applied after each mapping step in
                    _convert_to_tokens.
        """
        if hparams is None:
            # NOTE(review): HParams(...).hparams is None when hparams.json is missing or
            # unreadable; the attribute reads below would then raise AttributeError.
            self.hparams = HParams(corpus_dir).hparams
        else:
            self.hparams = hparams

        self.src_max_len = self.hparams.src_max_len
        self.tgt_max_len = self.hparams.tgt_max_len

        self.training = training
        # Filled in by _load_corpus / _convert_to_tokens when training is True.
        self.text_set = None
        self.id_set = None

        vocab_file = os.path.join(corpus_dir, VOCAB_FILE)
        self.vocab_size, _ = check_vocab(vocab_file)
        # Word -> id table; words missing from the vocab map to hparams.unk_id.
        self.vocab_table = lookup_ops.index_table_from_file(vocab_file,
                                                            default_value=self.hparams.unk_id)
        # print("vocab_size = {}".format(self.vocab_size))

        if training:
            # Training needs the lower-casing table and the corpus, but no id -> word table.
            self.case_table = prepare_case_table()
            self.reverse_vocab_table = None
            self._load_corpus(corpus_dir)
            self._convert_to_tokens(buffer_size)
        else:
            # Inference only needs the id -> word table to turn predictions back into text.
            self.case_table = None
            self.reverse_vocab_table = \
                lookup_ops.index_to_string_table_from_file(vocab_file,
                                                           default_value=self.hparams.unk_token)

    def get_training_batch(self, num_threads=4):
        """Build the shuffled, padded (and optionally bucketed) training batches.

        Args:
            num_threads: Passed as num_parallel_calls to the dataset map() steps.

        Returns:
            A BatchedInput whose tensors come from an initializable iterator over the
            batched dataset; its initializer must be run before the tensors are used.
        """
        assert self.training

        buffer_size = self.hparams.batch_size * 400

        # Comment this line for debugging.
        train_set = self.id_set.shuffle(buffer_size=buffer_size)

        # Create a target input prefixed with BOS and a target output suffixed with EOS.
        # After this mapping, each element in the train_set contains 3 columns/items.
        train_set = train_set.map(lambda src, tgt:
                                  (src, tf.concat(([self.hparams.bos_id], tgt), 0),
                                   tf.concat((tgt, [self.hparams.eos_id]), 0)),
                                  num_parallel_calls=num_threads).prefetch(buffer_size)

        # Add in sequence lengths.
        # The target length is taken from tgt_in, so it counts the BOS prefix.
        train_set = train_set.map(lambda src, tgt_in, tgt_out:
                                  (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
                                  num_parallel_calls=num_threads).prefetch(buffer_size)

        def batching_func(x):
            """Pad every element of dataset x to a common length and batch it."""
            return x.padded_batch(
                self.hparams.batch_size,
                # The first three entries are the source and target line rows, these have unknown-length
                # vectors. The last two entries are the source and target row sizes, which are scalars.
                padded_shapes=(tf.TensorShape([None]),  # src
                               tf.TensorShape([None]),  # tgt_input
                               tf.TensorShape([None]),  # tgt_output
                               tf.TensorShape([]),      # src_len
                               tf.TensorShape([])),     # tgt_len
                # Pad the source and target sequences with eos tokens. Though we don't generally need to
                # do this since later on we will be masking out calculations past the true sequence.
                padding_values=(self.hparams.eos_id,  # src
                                self.hparams.eos_id,  # tgt_input
                                self.hparams.eos_id,  # tgt_output
                                0,       # src_len -- unused
                                0))      # tgt_len -- unused

        if self.hparams.num_buckets > 1:
            # Ceiling division of the maximum source length by the number of buckets.
            bucket_width = (self.src_max_len + self.hparams.num_buckets - 1) // self.hparams.num_buckets

            # Parameters match the columns in each element of the dataset.
            def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
                """Return the int64 bucket id for one (src, tgt_in, tgt_out, src_len, tgt_len) element."""
                # Calculate bucket_width by maximum source sequence length. Pairs with length [0, bucket_width)
                # go to bucket 0, length [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with
                # length over ((num_bucket-1) * bucket_width) words all go into the last bucket.
                # Bucket sentence pairs by the length of their source sentence and target sentence.
                bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
                # NOTE(review): the clamp below uses num_buckets rather than num_buckets - 1, so
                # the returned id can equal num_buckets -- confirm this is intended.
                return tf.to_int64(tf.minimum(self.hparams.num_buckets, bucket_id))

            # No key to filter the dataset. Therefore the key is unused.
            def reduce_func(unused_key, windowed_data):
                """Batch the elements that share a bucket id."""
                return batching_func(windowed_data)

            batched_dataset = train_set.apply(
                tf.contrib.data.group_by_window(key_func=key_func,
                                                reduce_func=reduce_func,
                                                window_size=self.hparams.batch_size))
        else:
            batched_dataset = batching_func(train_set)

        batched_iter = batched_dataset.make_initializable_iterator()
        (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len, tgt_seq_len) = (batched_iter.get_next())

        return BatchedInput(initializer=batched_iter.initializer,
                            source=src_ids,
                            target_input=tgt_input_ids,
                            target_output=tgt_output_ids,
                            source_sequence_length=src_seq_len,
                            target_sequence_length=tgt_seq_len)

    def get_inference_batch(self, src_dataset):
        """Turn a dataset of source sentences into padded id batches for inference.

        Args:
            src_dataset: Dataset whose elements are source sentences; each element is split
                    into word tokens with tf.string_split (presumably scalar strings --
                    confirm against the caller).

        Returns:
            A BatchedInput with only the source fields and the iterator initializer set;
            the target fields are None.
        """
        text_dataset = src_dataset.map(lambda src: tf.string_split([src]).values)

        # Truncate to the inference-time maximum source length, when one is configured.
        if self.hparams.src_max_len_infer:
            text_dataset = text_dataset.map(lambda src: src[:self.hparams.src_max_len_infer])
        # Convert the word strings to ids
        id_dataset = text_dataset.map(lambda src: tf.cast(self.vocab_table.lookup(src),
                                                          tf.int32))
        if self.hparams.source_reverse:
            id_dataset = id_dataset.map(lambda src: tf.reverse(src, axis=[0]))
        # Add in the word counts.
        id_dataset = id_dataset.map(lambda src: (src, tf.size(src)))

        def batching_func(x):
            """Pad every source row of dataset x to a common length and batch it."""
            return x.padded_batch(
                self.hparams.batch_size_infer,
                # The entry is the source line rows; this has unknown-length vectors.
                # The last entry is the source row size; this is a scalar.
                padded_shapes=(tf.TensorShape([None]),  # src
                               tf.TensorShape([])),     # src_len
                # Pad the source sequences with eos tokens. Though notice we don't generally need to
                # do this since later on we will be masking out calculations past the true sequence.
                padding_values=(self.hparams.eos_id,  # src
                                0))                   # src_len -- unused

        id_dataset = batching_func(id_dataset)

        infer_iter = id_dataset.make_initializable_iterator()
        (src_ids, src_seq_len) = infer_iter.get_next()

        return BatchedInput(initializer=infer_iter.initializer,
                            source=src_ids,
                            target_input=None,
                            target_output=None,
                            source_sequence_length=src_seq_len,
                            target_sequence_length=None)

    def _load_corpus(self, corpus_dir):
        """Load the (question, answer) text pairs from the three augment folders into self.text_set.

        The folders are visited in the order Augment2, Augment1, Augment0. In each one, all
        ``.txt`` files (sorted by name) are read line by line; lines starting with ``Q:``
        become sources and lines starting with ``A:`` become targets, with the 2-character
        prefix removed. The per-folder datasets are concatenated in visiting order.

        Args:
            corpus_dir: Folder containing the Augment0/1/2 sub-folders.
        """
        for fd in range(2, -1, -1):
            file_list = []
            if fd == 0:
                file_dir = os.path.join(corpus_dir, AUG0_FOLDER)
            elif fd == 1:
                file_dir = os.path.join(corpus_dir, AUG1_FOLDER)
            else:
                file_dir = os.path.join(corpus_dir, AUG2_FOLDER)

            for data_file in sorted(os.listdir(file_dir)):
                full_path_name = os.path.join(file_dir, data_file)
                if os.path.isfile(full_path_name) and data_file.lower().endswith('.txt'):
                    file_list.append(full_path_name)

            # Every augment folder must contain at least one .txt file.
            assert len(file_list) > 0
            dataset = tf.data.TextLineDataset(file_list)

            # NOTE(review): tf.size counts tensor elements, not characters, so for a scalar
            # line the `tf.size(line) > 0` term presumably never filters anything -- confirm.
            src_dataset = dataset.filter(lambda line:
                                         tf.logical_and(tf.size(line) > 0,
                                                        tf.equal(tf.substr(line, 0, 2), tf.constant('Q:'))))
            # Drop the 2-character prefix; at most MAX_LEN characters are kept.
            src_dataset = src_dataset.map(lambda line:
                                          tf.substr(line, 2, MAX_LEN)).prefetch(4096)
            tgt_dataset = dataset.filter(lambda line:
                                         tf.logical_and(tf.size(line) > 0,
                                                        tf.equal(tf.substr(line, 0, 2), tf.constant('A:'))))
            tgt_dataset = tgt_dataset.map(lambda line:
                                          tf.substr(line, 2, MAX_LEN)).prefetch(4096)

            # Questions and answers are paired purely by position: the i-th 'Q:' line is
            # zipped with the i-th 'A:' line.
            src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
            # Augment1 and Augment2 data are repeated as configured; Augment0 is used once.
            if fd == 1:
                src_tgt_dataset = src_tgt_dataset.repeat(self.hparams.aug1_repeat_times)
            elif fd == 2:
                src_tgt_dataset = src_tgt_dataset.repeat(self.hparams.aug2_repeat_times)

            if self.text_set is None:
                self.text_set = src_tgt_dataset
            else:
                self.text_set = self.text_set.concatenate(src_tgt_dataset)

    def _convert_to_tokens(self, buffer_size):
        """Lower-case, tokenize, truncate and id-encode self.text_set, storing the result in self.id_set.

        self.text_set itself is replaced by the tokenized (still string-valued) dataset.

        Args:
            buffer_size: Prefetch buffer size applied after every mapping step.
        """
        # The following 3 steps act as a python String lower() function
        # Split to characters
        self.text_set = self.text_set.map(lambda src, tgt:
                                          (tf.string_split([src], delimiter='').values,
                                           tf.string_split([tgt], delimiter='').values)
                                          ).prefetch(buffer_size)
        # Convert all upper case characters to lower case characters
        self.text_set = self.text_set.map(lambda src, tgt:
                                          (self.case_table.lookup(src), self.case_table.lookup(tgt))
                                          ).prefetch(buffer_size)
        # Join characters back to strings
        self.text_set = self.text_set.map(lambda src, tgt:
                                          (tf.reduce_join([src]), tf.reduce_join([tgt]))
                                          ).prefetch(buffer_size)

        # Split to word tokens
        self.text_set = self.text_set.map(lambda src, tgt:
                                          (tf.string_split([src]).values, tf.string_split([tgt]).values)
                                          ).prefetch(buffer_size)
        # Truncate (not drop) sentences to the maximum lengths the model allows
        self.text_set = self.text_set.map(lambda src, tgt:
                                          (src[:self.src_max_len], tgt[:self.tgt_max_len])
                                          ).prefetch(buffer_size)

        # Reverse the source sentence if applicable
        if self.hparams.source_reverse:
            self.text_set = self.text_set.map(lambda src, tgt:
                                              (tf.reverse(src, axis=[0]), tgt)
                                              ).prefetch(buffer_size)

        # Convert the word strings to ids.  Word strings that are not in the vocab get
        # the lookup table's default_value integer.
        self.id_set = self.text_set.map(lambda src, tgt:
                                        (tf.cast(self.vocab_table.lookup(src), tf.int32),
                                         tf.cast(self.vocab_table.lookup(tgt), tf.int32))
                                        ).prefetch(buffer_size)


def check_vocab(vocab_file):
    """Read the vocabulary file and return ``(vocab_size, vocab_list)``.

    Each line of the UTF-8 file is one entry, stripped of surrounding whitespace.

    Raises:
        ValueError: If vocab_file does not exist.
    """
    if not tf.gfile.Exists(vocab_file):
        raise ValueError("The vocab_file does not exist. Please run the script to create it.")

    utf8_reader = codecs.getreader("utf-8")
    with utf8_reader(tf.gfile.GFile(vocab_file, "rb")) as stream:
        entries = [raw_line.strip() for raw_line in stream]

    return len(entries), entries


def prepare_case_table():
    """Build a hash table mapping each character in the ASCII range 32..126 to its lower-case form.

    Upper-case letters map to the matching lower-case letters; every other key maps to
    itself. Keys outside the table resolve to a single space.
    """
    printable = [chr(code) for code in range(32, 127)]

    keys = tf.constant(printable)
    # str.lower only alters 'A'..'Z' within this range, leaving the rest unchanged.
    values = tf.constant([char.lower() for char in printable])

    initializer = tf.contrib.lookup.KeyValueTensorInitializer(keys, values)
    return tf.contrib.lookup.HashTable(initializer, ' ')


class BatchedInput(namedtuple("BatchedInput",
                              "initializer source target_input target_output "
                              "source_sequence_length target_sequence_length")):
    """Named bundle of an iterator initializer plus the source/target batch tensors and their lengths."""

import tensorflow as tf
import tensorflow as tf


def get_initializer(init_op, seed=None, init_weight=None):
    """Return the variable initializer named by init_op.

    Args:
        init_op: One of "uniform", "glorot_normal" or "glorot_uniform".
        seed: Seed forwarded to the initializer.
        init_weight: Half-width of the uniform range; required (and only used) for "uniform".

    Raises:
        ValueError: If init_op is not one of the supported names.
    """
    if init_op == "uniform":
        assert init_weight
        lower, upper = -init_weight, init_weight
        return tf.random_uniform_initializer(lower, upper, seed=seed)

    if init_op == "glorot_normal":
        return tf.contrib.keras.initializers.glorot_normal(seed=seed)

    if init_op == "glorot_uniform":
        return tf.contrib.keras.initializers.glorot_uniform(seed=seed)

    raise ValueError("Unknown init_op %s" % init_op)


# def get_device_str(device_id, num_gpus):
#     """Return a device string for multi-GPU setup."""
#     if num_gpus == 0:
#         return "/cpu:0"
#     device_str_output = "/gpu:%d" % (device_id % num_gpus)
#     return device_str_output


def create_embbeding(vocab_size, embed_size, dtype=tf.float32, scope=None):
    """Get or create the ``[vocab_size, embed_size]`` embedding variable shared by encoder and decoder.

    The variable is named "embedding" and lives in ``scope`` (default "embeddings").
    """
    scope_name = scope or "embeddings"
    with tf.variable_scope(scope_name, dtype=dtype):
        return tf.get_variable("embedding", [vocab_size, embed_size], dtype)


def _single_cell(num_units, keep_prob, device_str=None):
    """Build one GRU cell, wrapped with input dropout and/or a device placement when requested.

    Args:
        num_units: Number of units of the GRU cell.
        keep_prob: Input keep probability; dropout is added only when it is below 1.0.
        device_str: Optional device string; when truthy the cell is wrapped in a DeviceWrapper.
    """
    cell = tf.contrib.rnn.GRUCell(num_units)

    wants_dropout = keep_prob < 1.0
    if wants_dropout:
        cell = tf.contrib.rnn.DropoutWrapper(cell=cell, input_keep_prob=keep_prob)

    # Device Wrapper
    if device_str:
        cell = tf.contrib.rnn.DeviceWrapper(cell, device_str)

    return cell


def create_rnn_cell(num_units, num_layers, keep_prob):
    """Build ``num_layers`` single cells and return them as one RNN cell.

    A single layer is returned as the bare cell; any other count is stacked in a
    MultiRNNCell.
    """
    layers = [_single_cell(num_units=num_units, keep_prob=keep_prob)
              for _ in range(num_layers)]

    if len(layers) != 1:  # Multi layers
        return tf.contrib.rnn.MultiRNNCell(layers)
    return layers[0]  # Single layer.


def gradient_clip(gradients, max_gradient_norm):
    """Clip gradients by global norm and build summaries of the norm before and after clipping.

    Returns:
        ``(clipped_gradients, summaries)`` where summaries holds the "grad_norm" and
        "clipped_gradient" scalar summaries, in that order.
    """
    clipped, pre_clip_norm = tf.clip_by_global_norm(gradients, max_gradient_norm)
    summaries = [
        tf.summary.scalar("grad_norm", pre_clip_norm),
        tf.summary.scalar("clipped_gradient", tf.global_norm(clipped)),
    ]

    return clipped, summaries

from tensorflow.python.layers import core as layers_core


class ModelCreator(object):
    """Sequence-to-sequence model creator to create models for training or inference"""
    def __init__(self, training, tokenized_data, batch_input, scope=None):
        """
        Create the model.

        Args:
            training: A boolean value to indicate whether this model will be used for training.
            tokenized_data: The data object containing all information required for the model
                (vocab tables, vocab size and hparams are read from it).
            batch_input: The batched input whose source/target tensors and sequence lengths
                feed the graph.
            scope: scope of the model.
        """
        self.training = training
        self.batch_input = batch_input
        self.vocab_table = tokenized_data.vocab_table
        self.vocab_size = tokenized_data.vocab_size
        self.reverse_vocab_table = tokenized_data.reverse_vocab_table

        hparams = tokenized_data.hparams
        self.hparams = hparams

        self.num_layers = hparams.num_layers
        self.time_major = hparams.time_major

        # Initializer
        # Installed on the current variable scope, so variables created afterwards use it
        # by default.
        initializer = get_initializer(
            hparams.init_op, hparams.random_seed, hparams.init_weight)
        tf.get_variable_scope().set_initializer(initializer)

        # Embeddings
        # A single matrix is used for both the encoder and the decoder lookups.
        self.embedding = (create_embbeding(vocab_size=self.vocab_size,
                                                        embed_size=hparams.num_units,
                                                        scope=scope))
        # This batch_size might vary among each batch instance due to the bucketing and/or reach
        # the end of the training set. Treat it as size_of_the_batch.
        self.batch_size = tf.size(self.batch_input.source_sequence_length)

        # Projection
        # Maps decoder outputs to vocabulary-sized logits.
        with tf.variable_scope(scope or "build_network"):
            with tf.variable_scope("decoder/output_projection"):
                self.output_layer = layers_core.Dense(
                    self.vocab_size, use_bias=False, name="output_projection")

        # Training or inference graph
        print("# Building graph for the model ...")
        # res is (logits, loss, final_context_state, sample_id).
        res = self.build_graph(hparams, scope=scope)

        if training:
            self.train_loss = res[1]
            self.word_count = tf.reduce_sum(self.batch_input.source_sequence_length) + \
                              tf.reduce_sum(self.batch_input.target_sequence_length)
            # Count the number of predicted words for compute perplexity.
            self.predict_count = tf.reduce_sum(self.batch_input.target_sequence_length)
        else:
            self.infer_logits, _, self.final_context_state, self.sample_id = res
            # Convert the predicted ids back to word strings.
            self.sample_words = self.reverse_vocab_table.lookup(tf.to_int64(self.sample_id))

        self.global_step = tf.Variable(0, trainable=False)

        params = tf.trainable_variables()

        # Gradients update operation for training the model.
        if training:
            # The learning rate is fed on every train_step call.
            self.learning_rate = tf.placeholder(tf.float32, shape=[], name='learning_rate')
            opt = tf.train.AdamOptimizer(self.learning_rate)

            gradients = tf.gradients(self.train_loss, params)

            clipped_gradients, gradient_norm_summary = gradient_clip(
                gradients, max_gradient_norm=hparams.max_gradient_norm)

            self.update = opt.apply_gradients(
                zip(clipped_gradients, params), global_step=self.global_step)

            # Summary
            self.train_summary = tf.summary.merge([
                tf.summary.scalar("learning_rate", self.learning_rate),
                tf.summary.scalar("train_loss", self.train_loss),
            ] + gradient_norm_summary)
        else:
            self.infer_summary = tf.no_op()

        # Saver
        self.saver = tf.train.Saver(tf.global_variables())

        # Print trainable variables
        if training:
            print("# Trainable variables:")
            for param in params:
                print("  {}, {}, {}".format(param.name, str(param.get_shape()), param.op.device))

    def train_step(self, sess, learning_rate):
        """Run one step of training.

        Args:
            sess: The session to run the update in.
            learning_rate: Value fed to the learning_rate placeholder for this step.

        Returns:
            The sess.run results, in order: update op result, train loss, predict count,
            train summary, global step, word count and batch size.
        """
        assert self.training

        return sess.run([self.update,
                         self.train_loss,
                         self.predict_count,
                         self.train_summary,
                         self.global_step,
                         self.word_count,
                         self.batch_size],
                        feed_dict={self.learning_rate: learning_rate})

    def build_graph(self, hparams, scope=None):
        """Creates a sequence-to-sequence model with dynamic RNN decoder API.

        Returns:
            (logits, loss, final_context_state, sample_id); loss is None when not training.
        """
        dtype = tf.float32

        with tf.variable_scope(scope or "dynamic_seq2seq", dtype=dtype):
            # Encoder
            encoder_outputs, encoder_state = self._build_encoder(hparams)

            # Decoder
            logits, sample_id, final_context_state = self._build_decoder(
                encoder_outputs, encoder_state, hparams)

            # Loss
            if self.training:
                loss = self._compute_loss(logits)
            else:
                loss = None

            return logits, loss, final_context_state, sample_id

    def _build_encoder(self, hparams):
        """Build an encoder.

        Returns:
            (encoder_outputs, encoder_state) as produced by tf.nn.dynamic_rnn.
        """
        source = self.batch_input.source
        if self.time_major:
            source = tf.transpose(source)

        with tf.variable_scope("encoder") as scope:
            dtype = scope.dtype
            # Look up embedding, emb_inp: [max_time, batch_size, num_units] when time_major
            encoder_emb_inp = tf.nn.embedding_lookup(self.embedding, source)

            # Encoder_outputs: [max_time, batch_size, num_units] when time_major
            cell = self._build_encoder_cell(hparams)

            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                cell,
                encoder_emb_inp,
                dtype=dtype,
                sequence_length=self.batch_input.source_sequence_length,
                time_major=self.time_major)

        return encoder_outputs, encoder_state

    def _build_encoder_cell(self, hparams):
        """Build a multi-layer RNN cell that can be used by encoder."""
        return create_rnn_cell(
            num_units=hparams.num_units,
            num_layers=hparams.num_layers,
            keep_prob=hparams.keep_prob)

    def _build_decoder(self, encoder_outputs, encoder_state, hparams):
        """Build and run a RNN decoder with a final projection layer.

        In training mode a TrainingHelper feeds the target inputs and the projection layer
        is applied to the decoder outputs afterwards. In inference mode decoding uses beam
        search when hparams.beam_width > 0 and greedy decoding otherwise.

        Returns:
            (logits, sample_id, final_context_state). With beam search, logits is a no-op
            and sample_id holds the predicted ids of the beams.
        """
        # The start/end ids are looked up from the vocab table using the configured tokens.
        bos_id = tf.cast(self.vocab_table.lookup(tf.constant(hparams.bos_token)), tf.int32)
        eos_id = tf.cast(self.vocab_table.lookup(tf.constant(hparams.eos_token)), tf.int32)

        # maximum_iteration: The maximum decoding steps.
        # Only passed to dynamic_decode in the inference branch below.
        if hparams.tgt_max_len_infer:
            maximum_iterations = hparams.tgt_max_len_infer
        else:
            # Fall back to twice the longest source sequence in the batch.
            decoding_length_factor = 2.0
            max_encoder_length = tf.reduce_max(self.batch_input.source_sequence_length)
            maximum_iterations = tf.to_int32(tf.round(
                tf.to_float(max_encoder_length) * decoding_length_factor))

        # Decoder.
        with tf.variable_scope("decoder") as decoder_scope:
            cell, decoder_initial_state = self._build_decoder_cell(
                hparams, encoder_outputs, encoder_state,
                self.batch_input.source_sequence_length)

            # Training
            if self.training:
                # decoder_emb_inp: [max_time, batch_size, num_units] when time_major
                target_input = self.batch_input.target_input
                if self.time_major:
                    target_input = tf.transpose(target_input)
                decoder_emb_inp = tf.nn.embedding_lookup(self.embedding, target_input)

                # Helper
                helper = tf.contrib.seq2seq.TrainingHelper(
                    decoder_emb_inp, self.batch_input.target_sequence_length,
                    time_major=self.time_major)

                # Decoder
                my_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell,
                    helper,
                    decoder_initial_state,)

                # Dynamic decoding
                outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder,
                    output_time_major=self.time_major,
                    swap_memory=True,
                    scope=decoder_scope)

                sample_id = outputs.sample_id
                # The projection is applied once to all decoder outputs here, rather than
                # per timestep inside the decoder.
                logits = self.output_layer(outputs.rnn_output)
            # Inference
            else:
                beam_width = hparams.beam_width
                length_penalty_weight = hparams.length_penalty_weight
                start_tokens = tf.fill([self.batch_size], bos_id)
                end_token = eos_id

                if beam_width > 0:
                    my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                        cell=cell,
                        embedding=self.embedding,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=decoder_initial_state,
                        beam_width=beam_width,
                        output_layer=self.output_layer,
                        length_penalty_weight=length_penalty_weight)
                else:
                    # Helper
                    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                        self.embedding, start_tokens, end_token)

                    # Decoder
                    my_decoder = tf.contrib.seq2seq.BasicDecoder(
                        cell,
                        helper,
                        decoder_initial_state,
                        output_layer=self.output_layer  # applied per timestep
                    )

                # Dynamic decoding
                outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder,
                    maximum_iterations=maximum_iterations,
                    output_time_major=self.time_major,
                    swap_memory=True,
                    scope=decoder_scope)

                if beam_width > 0:
                    # Beam search exposes no logits; a no-op keeps the return shape uniform.
                    logits = tf.no_op()
                    sample_id = outputs.predicted_ids
                else:
                    logits = outputs.rnn_output
                    sample_id = outputs.sample_id

        return logits, sample_id, final_context_state

    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Build a RNN cell with attention mechanism that can be used by decoder.

        Returns:
            (cell, decoder_initial_state). The initial state is the wrapper's zero state,
            with its cell state replaced by encoder_state when hparams.pass_hidden_state
            is set.
        """
        num_units = hparams.num_units
        num_layers = hparams.num_layers
        beam_width = hparams.beam_width

        dtype = tf.float32

        # The attention memory is passed batch-major, so time-major encoder outputs are
        # transposed first.
        if self.time_major:
            memory = tf.transpose(encoder_outputs, [1, 0, 2])
        else:
            memory = encoder_outputs

        # For beam search at inference, every batch entry is tiled beam_width times.
        if not self.training and beam_width > 0:
            memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
            source_sequence_length = tf.contrib.seq2seq.tile_batch(source_sequence_length,
                                                                   multiplier=beam_width)
            encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state,
                                                          multiplier=beam_width)
            batch_size = self.batch_size * beam_width
        else:
            batch_size = self.batch_size

        attention_mechanism = tf.contrib.seq2seq.LuongAttention(
            num_units, memory, memory_sequence_length=source_sequence_length)

        cell = create_rnn_cell(
            num_units=num_units,
            num_layers=num_layers,
            keep_prob=hparams.keep_prob)

        # Only generate alignment in greedy INFER mode.
        alignment_history = (not self.training and beam_width == 0)
        cell = tf.contrib.seq2seq.AttentionWrapper(
            cell,
            attention_mechanism,
            attention_layer_size=num_units,
            alignment_history=alignment_history,
            name="attention")

        if hparams.pass_hidden_state:
            decoder_initial_state = cell.zero_state(batch_size, dtype).clone(cell_state=encoder_state)
        else:
            decoder_initial_state = cell.zero_state(batch_size, dtype)

        return cell, decoder_initial_state

    def _compute_loss(self, logits):
        """Compute optimization loss.

        The loss is the cross entropy summed over all target positions inside each
        sequence's length, divided by the batch size.
        """
        target_output = self.batch_input.target_output
        if self.time_major:
            target_output = tf.transpose(target_output)
        max_time = self.get_max_time(target_output)
        crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=target_output, logits=logits)
        # Mask with 1 for positions inside each target sequence and 0 for padding.
        target_weights = tf.sequence_mask(
            self.batch_input.target_sequence_length, max_time, dtype=logits.dtype)
        if self.time_major:
            target_weights = tf.transpose(target_weights)

        loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(self.batch_size)
        return loss

    def get_max_time(self, tensor):
        """Return the size of tensor's time axis: the static value if known, else a dynamic tensor."""
        time_axis = 0 if self.time_major else 1
        return tensor.shape[time_axis].value or tf.shape(tensor)[time_axis]

    def infer(self, sess):
        """Run one inference step.

        Args:
            sess: The session to run the inference ops in.

        Returns:
            (sample_words, infer_summary), where sample_words are the predicted word strings.
        """
        assert not self.training
        _, infer_summary, _, sample_words = sess.run([
            self.infer_logits, self.infer_summary, self.sample_id, self.sample_words
        ])

        # make sure outputs is of shape [batch_size, time]
        # NOTE(review): with beam search the result presumably carries an extra beam axis,
        # and transpose() without arguments reverses every axis -- confirm callers expect this.
        if self.time_major:
            sample_words = sample_words.transpose()

        return sample_words, infer_summary
      
import os

# Fixed file names looked up inside the knowledge-base folder.
UPPER_FILE = "upper_words.txt"
STORIES_FILE = "stories.txt"
JOKES_FILE = "jokes.txt"


class KnowledgeBase:
    """In-memory store of the upper-case word map, the named stories and the jokes."""

    def __init__(self):
        self.upper_words = {}
        self.stories = {}
        self.jokes = []

    @staticmethod
    def _content_lines(file_name):
        """Yield each stripped line of file_name, skipping blank lines and '#' comment lines."""
        with open(file_name, 'r') as handle:
            for raw_line in handle:
                text = raw_line.strip()
                if text and not text.startswith('#'):
                    yield text

    def load_knbase(self, knbase_dir):
        """
        Args:
             knbase_dir: Name of the KnowledgeBase folder. The file names inside are fixed.
        """
        upper_path = os.path.join(knbase_dir, UPPER_FILE)
        stories_path = os.path.join(knbase_dir, STORIES_FILE)
        jokes_path = os.path.join(knbase_dir, JOKES_FILE)

        # Comma-separated words, keyed by their lower-case form.
        for text in self._content_lines(upper_path):
            for piece in text.split(','):
                word = piece.strip()
                self.upper_words[word.lower()] = word

        # A story is a '_NAME:' line, then a '_CONTENT:' line, then optional continuation
        # lines that are appended with a single space.
        name, content = '', ''
        for text in self._content_lines(stories_path):
            if text.startswith('_NAME:'):
                if name != '' and content != '':
                    self.stories[name] = content
                    content = ''
                name = text[6:].strip().lower()
            elif text.startswith('_CONTENT:'):
                content = text[9:].strip()
            else:
                content += ' ' + text.strip()

        if name != '' and content != '':  # The last one
            self.stories[name] = content

        # One joke per line.
        self.jokes.extend(self._content_lines(jokes_path))
                
class SessionData:
    """Registry of chat sessions keyed by an increasing integer ID."""

    def __init__(self):
        # session ID -> ChatSession
        self.session_dict = {}

    def add_session(self):
        """
        Create a ChatSession under the next free ID and return that ID.

        The new ID is one more than the largest existing key, or 1 when the
        registry is still empty.
        """
        sessions = self.session_dict
        new_id = (max(sessions) if sessions else 0) + 1
        sessions[new_id] = ChatSession(new_id)
        return new_id

    def get_session(self, session_id):
        """Return the session stored under session_id; raises KeyError for an unknown ID."""
        return self.session_dict[session_id]


class ChatSession:
    """Mutable per-conversation state read and written by the rule functions."""

    def __init__(self, session_id):
        """
        Args:
            session_id: The integer ID of the chat session.
        """
        self.session_id = session_id

        # Whether the bot has already asked the user how they are.
        self.howru_asked = False

        # What the user said their name is / how they want to be addressed.
        self.user_name = None
        self.call_me = None

        # The most recent question/answer pair; update_pair gates whether the
        # next prediction overwrites it.
        self.last_question = None
        self.last_answer = None
        self.update_pair = True

        # The most recent topic; keep_topic gates whether the next prediction
        # clears it.
        self.last_topic = None
        self.keep_topic = False

        # The pending action: the action function name, the parameter for
        # answer yes, and the parameter for answer no.
        self.pending_action = dict.fromkeys(('func', 'Yes', 'No'))

    def before_prediction(self):
        """Restore the per-prediction flags to their defaults."""
        self.keep_topic = False
        self.update_pair = True

    def after_prediction(self, new_question, new_answer):
        """Record the finished exchange, then drop the topic unless it was kept."""
        self._update_last_pair(new_question, new_answer)
        self._clear_last_topic()

    def _update_last_pair(self, new_question, new_answer):
        """
        Store the pair as the last question/answer, unless update_pair was
        switched off during this prediction.
        """
        if not self.update_pair:
            return
        self.last_question, self.last_answer = new_question, new_answer

    def _clear_last_topic(self):
        """
        Forget the last topic, unless keep_topic was switched on during this
        prediction.
        """
        if self.keep_topic:
            return
        self.last_topic = None

    def update_pending_action(self, func_name, yes_para, no_para):
        """Remember an action together with its parameters for a yes and a no answer."""
        self.pending_action.update(func=func_name, Yes=yes_para, No=no_para)

    def clear_pending_action(self):
        """
        Pending action is, and only is, cleared at the end of function: execute_pending_action_and_reply.
        """
        for slot in ('func', 'Yes', 'No'):
            self.pending_action[slot] = None
        
import re


def check_patterns_and_replace(question):
    """
    Run the pattern checkers in a fixed order and stop at the first match.

    Order: arithmetic, "my name is not ...", then "my name is ..." / "call me ...".
    Each checker receives the sentence returned by the previous one.

    Returns:
        The (matched, sentence, para_list) triple of the first checker that
        matched, or of the last checker when none did.
    """
    checkers = (
        _check_arithmetic_pattern_and_replace,
        _check_not_username_pattern_and_replace,
        _check_username_callme_pattern_and_replace,
    )

    outcome = (False, question, [])
    for checker in checkers:
        outcome = checker(outcome[1])
        if outcome[0]:
            break

    return outcome


def _check_arithmetic_pattern_and_replace(sentence):
    pat_matched, ind_list, num_list = _contains_arithmetic_pattern(sentence)
    if pat_matched:
        s1, e1 = ind_list[0]
        s2, e2 = ind_list[1]
        # Leave spaces around the special tokens so that NLTK knows they are separate tokens
        new_sentence = sentence[:s1] + ' _num1_ ' + sentence[e1:s2] + ' _num2_ ' + sentence[e2:]
        return True, new_sentence, num_list
    else:
        return False, sentence, num_list


def _contains_arithmetic_pattern(sentence):
    numbers = [
        "zero", "one", "two", "three", "four", "five", "six", "seven",
        "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen",
        "fifteen", "sixteen", "seventeen", "eighteen", "nineteen",
        "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety",
        "hundred", "thousand", "million", "billion", "trillion"]

    pat_op1 = re.compile(
        r'\s(plus|add|added|\+|minus|subtract|subtracted|-|times|multiply|multiplied|\*|divide|(divided\s+by)|/)\s',
        re.IGNORECASE)
    pat_op2 = re.compile(r'\s((sum\s+of)|(product\s+of))\s', re.IGNORECASE)
    pat_as = re.compile(r'((\bis\b)|=|(\bequals\b)|(\bget\b))', re.IGNORECASE)

    mat_op1 = re.search(pat_op1, sentence)
    mat_op2 = re.search(pat_op2, sentence)
    mat_as = re.search(pat_as, sentence)
    if (mat_op1 or mat_op2) and mat_as:  # contains an arithmetic operator and an assign operator
        # Replace all occurrences of word "and" with 3 whitespaces before feeding to
        # the pattern matcher.
        pat_and = re.compile(r'\band\b', re.IGNORECASE)
        if mat_op1:
            tmp_sentence = pat_and.sub('   ', sentence)
        else:  # Do not support word 'and' in the English numbers any more as that can be ambiguous.
            tmp_sentence = pat_and.sub('_T_', sentence)

        number_rx = r'(?:{})'.format('|'.join(numbers))
        pat_num = re.compile(r'\b{0}(?:(?:\s+(?:and\s+)?|-){0})*\b|\d+'.format(number_rx),
                             re.IGNORECASE)
        ind_list = [(m.start(0), m.end(0)) for m in re.finditer(pat_num, tmp_sentence)]
        num_list = []
        if len(ind_list) == 2:  # contains exactly two numbers
            for start, end in ind_list:
                text = sentence[start:end]
                text_int = _text2int(text)
                if text_int == -1:
                    return False, [], []
                num_list.append(text_int)

            return True, ind_list, num_list

    return False, [], []


def _text2int(text):
    if text.isdigit():
        return int(text)

    num_words = {}
    units = [
        "zero", "one", "two", "three", "four", "five", "six", "seven", "eight",
        "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
        "sixteen", "seventeen", "eighteen", "nineteen",
    ]
    tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
    scales = ["hundred", "thousand", "million", "billion", "trillion"]

    num_words["and"] = (1, 0)
    for idx, word in enumerate(units):
        num_words[word] = (1, idx)
    for idx, word in enumerate(tens):
        num_words[word] = (1, idx * 10)
    for idx, word in enumerate(scales):
        num_words[word] = (10 ** (idx * 3 or 2), 0)

    current = result = 0
    for word in text.replace("-", " ").lower().split():
        if word not in num_words:
            return -1

        scale, increment = num_words[word]
        current = current * scale + increment
        if scale > 100:
            result += current
            current = 0

    return result + current


def _check_not_username_pattern_and_replace(sentence):
    import nltk

    tokens = nltk.word_tokenize(sentence)
    tmp_sentence = ' '.join(tokens[:]).strip()

    pat_not_but = re.compile(r'(\s|^)my\s+name\s+is\s+(not|n\'t)\s+(.+?)(\s\.|\s,|\s!)\s*but\s+(.+?)(\s\.|\s,|\s!|$)',
                             re.IGNORECASE)
    mat_not_but = re.search(pat_not_but, tmp_sentence)

    pat_not = re.compile(r'(\s|^)my\s+name\s+is\s+(not|n\'t)\s+(.+?)(\s\.|\s,|\s!|$)', re.IGNORECASE)
    mat_not = re.search(pat_not, tmp_sentence)

    para_list = []
    found = 0
    if mat_not_but:
        wrong_name = mat_not_but.group(3).strip()
        correct_name = mat_not_but.group(5).strip()
        para_list.append(correct_name)
        new_sentence = sentence.replace(wrong_name, ' _ignored_ ', 1).replace(correct_name, ' _name_ ', 1)
        # print("User name is not: {}, but {}.".format(wrong_name, correct_name))
        found += 1
    elif mat_not:
        wrong_name = mat_not.group(3).strip()
        new_sentence = sentence.replace(wrong_name, ' _ignored_ ', 1)
        # print("User name is not: {}.".format(wrong_name))
        found += 1
    else:
        new_sentence = sentence
        # print("Wrong name not found.")

    if found >= 1:
        return True, new_sentence, para_list
    else:
        return False, sentence, para_list


def _check_username_callme_pattern_and_replace(sentence):
    import nltk

    tokens = nltk.word_tokenize(sentence)
    tmp_sentence = ' '.join(tokens[:]).strip()

    pat_name = re.compile(r'(\s|^)my\s+name\s+is\s+(.+?)(\s\.|\s,|\s!|$)', re.IGNORECASE)
    pat_call = re.compile(r'(\s|^)call\s+me\s+(.+?)(\s(please|pls))?(\s\.|\s,|\s!|$)', re.IGNORECASE)

    mat_name = re.search(pat_name, tmp_sentence)
    mat_call = re.search(pat_call, tmp_sentence)

    para_list = []
    found = 0
    if mat_name:
        user_name = mat_name.group(2).strip()
        para_list.append(user_name)
        new_sentence = sentence.replace(user_name, ' _name_ ', 1)
        # print("User name is: {}.".format(user_name))
        found += 1
    else:
        para_list.append('')  # reserve the slot
        new_sentence = sentence
        # print("User name not found.")

    if mat_call:
        call_me = mat_call.group(2).strip()
        para_list.append(call_me)
        new_sentence = new_sentence.replace(call_me, ' _callme_ ')
        # print("Call me {}.".format(call_me))
        found += 1
    else:
        para_list.append('')
        # print("call me not found.")

    if found >= 1:
        return True, new_sentence, para_list
    else:
        return False, sentence, para_list


# Manual smoke test: print each sample sentence followed by its rewritten form.
if __name__ == "__main__":
    # Samples for the "my name is ..." / "call me ..." checker.
    for sentence in [
        "My name is jack brown. Please call me Mr. Brown.",
        "My name is Bo Shao.",
        "You can call me Dr. Shao.",
        "Call me Ms. Tailor please.",
        "My name is Mark. Please call me Mark D.",
    ]:
        print("# {}".format(sentence))
        _, ns, _ = _check_username_callme_pattern_and_replace(sentence)
        print(ns)

    # Samples for the "my name is not ..." checker.
    for sentence in [
        "My name is not just Shao, but Bo Shao.",
        "My name is not just Shao.",
    ]:
        print("# {}".format(sentence))
        _, ns, _ = _check_not_username_pattern_and_replace(sentence)
        print(ns)
    
import calendar as cal
import datetime as dt
import random
import re
import time


class FunctionData:
    """
    Rule functions whose return values are spliced into the bot's replies.

    The static methods need no state; the instance methods read and write the
    knowledge base and the chat session passed to the constructor.
    """

    # Reply prefixes; the empty entries make "no prefix" a possible choice.
    easy_list = [
        "", "",
        "Here you are: ",
        "Here is the result: ",
        "That's easy: ",
        "That was an easy one: ",
        "It was a piece of cake: ",
        "That's simple, and I know how to solve it: ",
        "That wasn't hard. Here is the result: ",
        "Oh, I know how to deal with this: "
    ]
    hard_list = [
        "", "",
        "Here you are: ",
        "Here is the result: ",
        "That's a little hard: ",
        "That was an tough one, and I had to use a calculator: ",
        "That's a little difficult, but I know how to solve it: ",
        "It was hard and took me a little while to figure it out. Here is the result: ",
        "It took me a little while, and finally I got the result: ",
        "I had to use my cell phone for this calculation. Here is the outcome: "
    ]
    ask_howru_list = [
        "And you?",
        "How are you?",
        "How about yourself?"
    ]
    ask_name_list = [
        "May I also have your name, please?",
        "Would you also like to tell me your name, please?",
        "And, how should I call you, please?",
        "And, what do you want me to call you, dear sir or madam?"
    ]

    def __init__(self, knowledge_base, chat_session, html_format):
        """
        Args:
            knowledge_base: The knowledge base data needed for prediction.
            chat_session: The chat session object that can be read and written.
            html_format: Whether out_sentence is in HTML format.
        """
        self.knowledge_base = knowledge_base
        self.chat_session = chat_session
        self.html_format = html_format

    # ------------------------------------------------------------------
    # Rule 2: Date and Time
    # ------------------------------------------------------------------
    @staticmethod
    def get_date_time():
        """Return the local date and time as 'YYYY-MM-DD HH:MM'."""
        return time.strftime("%Y-%m-%d %H:%M")

    @staticmethod
    def get_time():
        """Return the local time on a 12-hour clock, e.g. '09:05 PM'."""
        return time.strftime("%I:%M %p")

    @staticmethod
    def get_today():
        """Return today's date as 'Month DD, YYYY'."""
        return "{:%B %d, %Y}".format(dt.date.today())

    @staticmethod
    def get_weekday(day_delta):
        """
        Return 'Weekday, Month DD, YYYY' for a day relative to now.

        Args:
            day_delta: 'd_2', 'd_1', 'd1' or 'd2' for two days ago, yesterday,
                tomorrow and the day after tomorrow; any other value means today.
        """
        offsets = {'d_2': -2, 'd_1': -1, 'd1': 1, 'd2': 2}
        day_time = dt.datetime.now() + dt.timedelta(days=offsets.get(day_delta, 0))

        weekday = cal.day_name[day_time.weekday()]
        return "{}, {:%B %d, %Y}".format(weekday, day_time)

    # ------------------------------------------------------------------
    # Rule 3: Stories and Jokes, and last topic
    # ------------------------------------------------------------------
    def _mark_topic(self, topic):
        """Record the topic on the session and ask for it to survive this prediction."""
        self.chat_session.last_topic = topic
        self.chat_session.keep_topic = True

    def _render(self, content):
        """Remove the '_np_' markers from content unless the output is HTML."""
        if not self.html_format:
            content = re.sub(r'_np_', '', content)
        return content

    def get_story_any(self):
        """Return a randomly chosen story and set the last topic to STORY."""
        self._mark_topic("STORY")

        stories = self.knowledge_base.stories
        _, content = random.choice(list(stories.items()))
        return self._render(content)

    def get_story_name(self, story_name):
        """
        Return the story stored under story_name and set the last topic to STORY.

        Raises:
            KeyError: when the knowledge base has no story with that name.
        """
        self._mark_topic("STORY")

        return self._render(self.knowledge_base.stories[story_name])

    def get_joke_any(self):
        """Return a randomly chosen joke and set the last topic to JOKE."""
        self._mark_topic("JOKE")

        return self._render(random.choice(self.knowledge_base.jokes))

    def continue_last_topic(self):
        """Tell another story or joke, depending on the last topic; otherwise ask for a topic."""
        if self.chat_session.last_topic == "STORY":
            self.chat_session.keep_topic = True
            return self.get_story_any()
        elif self.chat_session.last_topic == "JOKE":
            self.chat_session.keep_topic = True
            return self.get_joke_any()
        else:
            return "Sorry, but what topic do you prefer?"

    # ------------------------------------------------------------------
    # Rule 4: Arithmetic ops
    # ------------------------------------------------------------------
    @staticmethod
    def get_number_plus(num1, num2):
        """Return 'num1 + num2 = result' with a random prefix from easy_list."""
        res = num1 + num2
        desc = random.choice(FunctionData.easy_list)
        return "{}{} + {} = {}".format(desc, num1, num2, res)

    @staticmethod
    def get_number_minus(num1, num2):
        """Return 'num1 - num2 = result' with a random prefix from easy_list."""
        res = num1 - num2
        desc = random.choice(FunctionData.easy_list)
        return "{}{} - {} = {}".format(desc, num1, num2, res)

    @staticmethod
    def get_number_multiply(num1, num2):
        """
        Return 'num1 * num2 = result' with a random prefix.

        The prefix comes from hard_list when both factors are odd and above
        100, otherwise from easy_list.
        """
        res = num1 * num2
        if num1 > 100 and num2 > 100 and num1 % 2 == 1 and num2 % 2 == 1:
            desc = random.choice(FunctionData.hard_list)
        else:
            desc = random.choice(FunctionData.easy_list)
        return "{}{} * {} = {}".format(desc, num1, num2, res)

    @staticmethod
    def get_number_divide(num1, num2):
        """
        Return 'num1 / num2 = result' with a random prefix, or an apology when
        the divisor is zero.

        When both operands are integers and the division is exact, the result
        is printed as an integer; otherwise it is printed with two decimals.
        """
        if num2 == 0:
            return "Sorry, but that does not make sense as the divisor cannot be zero."

        # On Python 3 the '/' operator always yields a float, so exactness has
        # to be decided from the remainder rather than from the result's type.
        exact = isinstance(num1, int) and isinstance(num2, int) and num1 % num2 == 0
        if exact:
            res = num1 // num2
            if 50 < num1 != num2 > 50:
                desc = random.choice(FunctionData.hard_list)
            else:
                desc = random.choice(FunctionData.easy_list)
            return "{}{} / {} = {}".format(desc, num1, num2, res)

        res = num1 / num2
        if num1 > 20 and num2 > 20:
            desc = random.choice(FunctionData.hard_list)
        else:
            desc = random.choice(FunctionData.easy_list)
        return "{}{} / {} = {:.2f}".format(desc, num1, num2, res)

    # ------------------------------------------------------------------
    # Rule 5: User name, call me information, and last question and answer
    # ------------------------------------------------------------------
    def ask_howru_if_not_yet(self):
        """Return a 'how are you' question once per session, and '' afterwards."""
        if self.chat_session.howru_asked:
            return ""

        self.chat_session.howru_asked = True
        return random.choice(FunctionData.ask_howru_list)

    def ask_name_if_not_yet(self):
        """Return a question asking for the user's name, or '' if a name or call-me is known."""
        if self.chat_session.user_name or self.chat_session.call_me:
            return ""
        return random.choice(FunctionData.ask_name_list)

    def get_user_name_and_reply(self):
        """Return the stored user name, or an apology when none is stored."""
        user_name = self.chat_session.user_name
        if user_name and user_name.strip() != '':
            return user_name
        return "Did you tell me your name? Sorry, I missed that."

    def get_callme(self, punc_type):
        """
        Return how to address the user: call_me if set, else user_name, else ''.

        Args:
            punc_type: 'comma0' prefixes a non-empty result with ', '.
        """
        call_me = self.chat_session.call_me
        user_name = self.chat_session.user_name

        if call_me and call_me.strip() != '':
            address = call_me
        elif user_name and user_name.strip() != '':
            address = user_name
        else:
            return ""

        if punc_type == 'comma0':
            return ", {}".format(address)
        return address

    def get_last_question(self):
        """Repeat the user's previous question without recording this exchange."""
        # Do not record this pair as the last question and answer
        self.chat_session.update_pair = False

        last_question = self.chat_session.last_question
        if last_question is None or last_question.strip() == '':
            return "You did not say anything."
        return "You have just said: {}".format(last_question)

    def get_last_answer(self):
        """Repeat the bot's previous answer without recording this exchange."""
        # Do not record this pair as the last question and answer
        self.chat_session.update_pair = False

        last_answer = self.chat_session.last_answer
        if last_answer is None or last_answer.strip() == '':
            return "I did not say anything."
        return "I have just said: {}".format(last_answer)

    def update_user_name(self, new_name):
        """Shortcut for update_user_name_and_call_me with only a new name."""
        return self.update_user_name_and_call_me(new_name=new_name)

    def update_call_me(self, new_call):
        """Shortcut for update_user_name_and_call_me with only a new call-me."""
        return self.update_user_name_and_call_me(new_call=new_call)

    def update_user_name_and_call_me(self, new_name=None, new_call=None):
        """
        Store a new name and/or call-me, asking for confirmation on conflicts.

        If a different name (or call-me) is already stored, nothing is changed;
        a pending action is registered instead and a confirmation question is
        returned. Comparisons ignore case.
        """
        user_name = self.chat_session.user_name
        call_me = self.chat_session.call_me

        if user_name and new_name and new_name.strip() != '':
            if new_name.lower() != user_name.lower():
                self.chat_session.update_pending_action('update_user_name_confirmed', None, new_name)
                return "I am confused. I have your name as {}. Did I get it correctly?".format(user_name)
            else:
                return "You told me your name already. Thank you, {}, for assuring me.".format(user_name)

        if call_me and new_call and new_call.strip() != '':
            if new_call.lower() != call_me.lower():
                self.chat_session.update_pending_action('update_call_me_confirmed', new_call, None)
                return "You wanted me to call you {}. Would you like me to call you {} now?"\
                    .format(call_me, new_call)
            else:
                return "Thank you for letting me again, {}.".format(call_me)

        if new_call and new_call.strip() != '':
            if new_name and new_name.strip() != '':
                self.chat_session.user_name = new_name

            self.chat_session.call_me = new_call
            return "Thank you, {}.".format(new_call)
        elif new_name and new_name.strip() != '':
            self.chat_session.user_name = new_name
            return "Thank you, {}.".format(new_name)

        return "Sorry, I am confused. I could not figure out what you meant."

    def update_user_name_enforced(self, new_name):
        """Overwrite the user name without confirmation; a blank name clears it."""
        if new_name and new_name.strip() != '':
            self.chat_session.user_name = new_name
            return "OK, thank you, {}.".format(new_name)
        else:
            self.chat_session.user_name = None  # Clear the existing user_name, if any.
            return "Sorry, I am lost."

    def update_call_me_enforced(self, new_call):
        """Overwrite the call-me without confirmation; a blank value clears it."""
        if new_call and new_call.strip() != '':
            self.chat_session.call_me = new_call
            return "OK, got it. Thank you, {}.".format(new_call)
        else:
            self.chat_session.call_me = None  # Clear the existing call_me, if any.
            return "Sorry, I am totally lost."

    def update_user_name_and_reply_papaya(self, new_name):
        """
        Store the user's name (confirming on conflict) and introduce the bot as Papaya.
        """
        user_name = self.chat_session.user_name

        if new_name and new_name.strip() != '':
            if user_name:
                if new_name.lower() != user_name.lower():
                    self.chat_session.update_pending_action('update_user_name_confirmed', None, new_name)
                    return "I am confused. I have your name as {}. Did I get it correctly?".format(user_name)
                else:
                    return "Thank you, {}, for assuring me your name. My name is Papaya.".format(user_name)
            else:
                self.chat_session.user_name = new_name
                return "Thank you, {}. BTW, my name is Papaya.".format(new_name)
        else:
            return "My name is Papaya. Thanks."

    def correct_user_name(self, new_name):
        """Replace the user name; a blank name clears both the name and the call-me."""
        if new_name and new_name.strip() != '':
            self.chat_session.user_name = new_name
            return "Thank you, {}.".format(new_name)
        else:
            # Clear the existing user_name and call_me information
            self.chat_session.user_name = None
            self.chat_session.call_me = None
            return "I am totally lost."

    def clear_user_name_and_call_me(self):
        """Forget the user name and the call-me; returns None."""
        self.chat_session.user_name = None
        self.chat_session.call_me = None

    def execute_pending_action_and_reply(self, answer):
        """
        Apply the session's pending action according to a yes/no answer.

        Args:
            answer: 'yes' (any letter case) confirms; anything else is a no.

        Returns:
            The reply text. The pending action is always cleared afterwards.
        """
        func = self.chat_session.pending_action['func']
        if func == 'update_user_name_confirmed':
            if answer.lower() == 'yes':
                reply = "Thank you, {}, for confirming this.".format(self.chat_session.user_name)
            else:
                new_name = self.chat_session.pending_action['No']
                self.chat_session.user_name = new_name
                reply = "Thank you, {}, for correcting me.".format(new_name)
        elif func == 'update_call_me_confirmed':
            if answer.lower() == 'yes':
                new_call = self.chat_session.pending_action['Yes']
                self.chat_session.call_me = new_call
                reply = "Thank you, {}, for correcting me.".format(new_call)
            else:
                reply = "Thank you. I will continue to call you {}.".format(self.chat_session.call_me)
        else:
            reply = "OK, thanks."  # Just presents a reply that is good for most situations

        # Clear the pending action anyway
        self.chat_session.clear_pending_action()
        return reply

    # ------------------------------------------------------------------
    # Other Rules: Client Code
    # ------------------------------------------------------------------
    def client_code_show_picture_randomly(self, picture_name):
        """Return the client-code token for showing a picture; '' when not in HTML mode."""
        if not self.html_format:  # Ignored in the command line interface
            return ''
        else:
            return ' _cc_start_show_picture_randomly_para1_' + picture_name + '_cc_end_'


def call_function(func_info, knowledge_base=None, chat_session=None, para_list=None,
                  html_format=False):
    """
    Dispatch a '_func_val_' token emitted by the model to a FunctionData rule.

    func_info has one of three shapes:
        'name'                          -> name()
        'name_para1_X'                  -> name(value of X)
        'name_para1_X_para2_Y'          -> name(value of X, value of Y)
    A single parameter '_name_' or '_callme_' is taken from para_list[0] or
    para_list[1]; any other single parameter is passed as the literal text.
    Two parameters must be '_num1_'/'_num2_' (either order) or
    '_name_'/'_callme_', and are taken from para_list.

    Args:
        func_info: The token text after the '_func_val_' prefix.
        knowledge_base: Passed to FunctionData.
        chat_session: Passed to FunctionData.
        para_list: Values extracted from the user's sentence, or None.
        html_format: Passed to FunctionData.

    Returns:
        Whatever the rule function returns, or a fixed fallback sentence when
        the function name is unknown or the parameters cannot be resolved.
    """
    fallback = "You beat me to it, and I cannot tell which is which for this question."

    func_data = FunctionData(knowledge_base, chat_session, html_format=html_format)

    func_dict = {
        'get_date_time': FunctionData.get_date_time,
        'get_time': FunctionData.get_time,
        'get_today': FunctionData.get_today,
        'get_weekday': FunctionData.get_weekday,

        'get_story_any': func_data.get_story_any,
        'get_story_name': func_data.get_story_name,
        'get_joke_any': func_data.get_joke_any,
        'continue_last_topic': func_data.continue_last_topic,

        'get_number_plus': FunctionData.get_number_plus,
        'get_number_minus': FunctionData.get_number_minus,
        'get_number_multiply': FunctionData.get_number_multiply,
        'get_number_divide': FunctionData.get_number_divide,

        'ask_howru_if_not_yet': func_data.ask_howru_if_not_yet,
        'ask_name_if_not_yet': func_data.ask_name_if_not_yet,
        'get_user_name_and_reply': func_data.get_user_name_and_reply,
        'get_callme': func_data.get_callme,
        'get_last_question': func_data.get_last_question,
        'get_last_answer': func_data.get_last_answer,

        'update_user_name': func_data.update_user_name,
        'update_call_me': func_data.update_call_me,
        'update_user_name_and_call_me': func_data.update_user_name_and_call_me,
        'update_user_name_enforced': func_data.update_user_name_enforced,
        'update_call_me_enforced': func_data.update_call_me_enforced,
        'update_user_name_and_reply_papaya': func_data.update_user_name_and_reply_papaya,

        'correct_user_name': func_data.correct_user_name,
        'clear_user_name_and_call_me': func_data.clear_user_name_and_call_me,

        'execute_pending_action_and_reply': func_data.execute_pending_action_and_reply,

        'client_code_show_picture_randomly': func_data.client_code_show_picture_randomly
    }

    marker_len = len('_para1_')  # '_para2_' has the same length
    para1_index = func_info.find('_para1_')
    para2_index = func_info.find('_para2_')

    func_name = func_info if para1_index == -1 else func_info[:para1_index]
    # The name comes from model output, so an unknown name must not raise:
    # every call shape falls back to the same sentence.
    func = func_dict.get(func_name)
    if func is None:
        return fallback

    if para1_index == -1:  # No parameter at all
        return func()

    if para2_index == -1:  # Only one parameter
        func_para = func_info[para1_index + marker_len:]
        if func_para == '_name_' and para_list is not None and len(para_list) >= 1:
            return func(para_list[0])
        if func_para == '_callme_' and para_list is not None and len(para_list) >= 2:
            return func(para_list[1])
        # The parameter value was embedded in the text (part of the string) of the training example.
        return func(func_para)

    func_para1 = func_info[para1_index + marker_len:para2_index]
    func_para2 = func_info[para2_index + marker_len:]
    if para_list is not None and len(para_list) >= 2:
        para1_val, para2_val = para_list[0], para_list[1]

        if func_para1 == '_num1_' and func_para2 == '_num2_':
            return func(para1_val, para2_val)
        if func_para1 == '_num2_' and func_para2 == '_num1_':
            return func(para2_val, para1_val)
        if func_para1 == '_name_' and func_para2 == '_callme_':
            return func(para1_val, para2_val)

    return fallback

# Ask TensorFlow's native (C++) logging to stay quiet; '3' is the level commonly
# documented as hiding INFO, WARNING and ERROR messages.
# NOTE(review): tensorflow is imported in an earlier cell, so this may be set too
# late to affect the already-loaded library — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


class BotPredictor(object):
    """Wraps the restored inference model together with the knowledge base and chat sessions."""

    def __init__(self, session, corpus_dir, knbase_dir, result_dir, result_file):
        """
        Args:
            session: The TensorFlow session.
            corpus_dir: Name of the folder storing corpus files and vocab information.
            knbase_dir: Name of the folder storing data files for the knowledge base.
            result_dir: The folder containing the trained result files.
            result_file: The file name of the trained model.
        """
        self.session = session

        # Prepare data and hyper parameters
        print("# Prepare dataset placeholder and hyper parameters ...")
        tokenized_data = TokenizedData(corpus_dir=corpus_dir, training=False)

        self.knowledge_base = KnowledgeBase()
        self.knowledge_base.load_knbase(knbase_dir)

        self.session_data = SessionData()

        self.hparams = tokenized_data.hparams
        # The sentence to answer is fed through this string placeholder.
        self.src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        src_dataset = tf.data.Dataset.from_tensor_slices(self.src_placeholder)
        self.infer_batch = tokenized_data.get_inference_batch(src_dataset)

        # Create model
        print("# Creating inference model ...")
        self.model = ModelCreator(training=False, tokenized_data=tokenized_data,
                                  batch_input=self.infer_batch)
        # Restore model weights
        print("# Restoring model weights ...")
        self.model.saver.restore(session, os.path.join(result_dir, result_file))

        self.session.run(tf.tables_initializer())

    def predict(self, session_id, question, html_format=False):
        """
        Answer one question within the given chat session.

        The question is first run through the pattern checkers. When a pattern
        matched, the rewritten sentence is tried first; if the model's output
        for it contains no function token, the original question is tried
        instead. The exchange is recorded on the session before returning.

        Args:
            session_id: ID of an existing chat session.
            question: The user's sentence.
            html_format: Passed on to the rule functions.

        Returns:
            The answer text.
        """
        chat_session = self.session_data.get_session(session_id)
        chat_session.before_prediction()  # Reset before each prediction

        if question.strip() == '':
            answer = "Don't you want to say something to me?"
            chat_session.after_prediction(question, answer)
            return answer

        pat_matched, new_sentence, para_list = check_patterns_and_replace(question)

        # At most two passes: the pattern-rewritten sentence, then the original.
        for pre_time in range(2):
            tokens = nltk.word_tokenize(new_sentence.lower())
            tmp_sentence = [' '.join(tokens[:]).strip()]

            self.session.run(self.infer_batch.initializer,
                             feed_dict={self.src_placeholder: tmp_sentence})

            outputs, _ = self.model.infer(self.session)

            if self.hparams.beam_width > 0:
                outputs = outputs[0]

            eos_token = self.hparams.eos_token.encode("utf-8")
            outputs = outputs.tolist()[0]

            # Keep only the tokens before the end-of-sentence token.
            if eos_token in outputs:
                outputs = outputs[:outputs.index(eos_token)]

            if pat_matched and pre_time == 0:
                out_sentence, if_func_val = self._get_final_output(outputs, chat_session,
                                                                   para_list=para_list,
                                                                   html_format=html_format)
                if if_func_val:
                    chat_session.after_prediction(question, out_sentence)
                    return out_sentence
                else:
                    # The rewrite did not lead to a rule function: retry as typed.
                    new_sentence = question
            else:
                out_sentence, _ = self._get_final_output(outputs, chat_session,
                                                         html_format=html_format)
                chat_session.after_prediction(question, out_sentence)
                return out_sentence

    def _get_final_output(self, sentence, chat_session, para_list=None, html_format=False):
        """
        Turn the model's output tokens into display text.

        Tokens starting with '_func_val_' are replaced by the result of the
        named rule function; other tokens get their capitalization restored
        from the knowledge base and are capitalized at the start of a
        sentence. Tokens are then joined with spaces, except before
        punctuation and contractions and after an opening bracket, an opening
        quote or '$'.

        Args:
            sentence: A list of UTF-8 encoded byte strings.
            chat_session: The session passed to the rule functions.
            para_list: Values extracted from the user's sentence, or None.
            html_format: Passed on to the rule functions.

        Returns:
            (text, whether any function token was seen).
        """
        sentence = b' '.join(sentence).decode('utf-8')
        if sentence == '':
            return "I don't know what to say.", False

        openers = ['(', '[', '{', '``', '$']
        sentence_enders = ['.', '!', '?']

        if_func_val = False
        last_word = None
        word_list = []
        for word in sentence.split(' '):
            word = word.strip()
            if not word:
                continue

            if word.startswith('_func_val_'):
                if_func_val = True
                word = call_function(word[10:], knowledge_base=self.knowledge_base,
                                     chat_session=chat_session, para_list=para_list,
                                     html_format=html_format)
                if word is None or word == '':
                    continue
            else:
                if word in self.knowledge_base.upper_words:
                    word = self.knowledge_base.upper_words[word]

                if (last_word is None or last_word in sentence_enders) and not word[0].isupper():
                    word = word.capitalize()

            # Remember the token WITHOUT the separating space added below;
            # otherwise the comparison against the openers could never match
            # and a space would always follow an opening bracket.
            plain_word = word

            if not word.startswith('\'') and word != 'n\'t' \
                    and (word[0] not in string.punctuation or word in openers) \
                    and last_word not in openers:
                word = ' ' + word

            word_list.append(word)
            last_word = plain_word

        return ''.join(word_list).strip(), if_func_val

In [0]:
# Ask TensorFlow's native (C++) logging to stay quiet ('3' is the level commonly
# documented as hiding INFO, WARNING and ERROR messages).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Root folder of the bot's data; initBot joins 'Data/Corpus', 'Data/KnowledgeBase'
# and 'Data/Result' onto it. The 'drive' prefix is the directory an earlier cell
# mounts Google Drive on, and the path is relative to the working directory.
PROJECT_ROOT = "drive/AgateV2"

def initBot():
    """Create the global ``predictor`` and ``session_id`` on first use.

    Subsequent calls are no-ops, so it is safe to call this before every
    prediction.

    The readiness check inspects the module globals directly instead of
    running a probe prediction inside a bare ``except``. The probe executed a
    full inference for the text "Hi" on every call, and the bare ``except``
    rebuilt the predictor (with a brand-new ``tf.Session``) on *any* failure
    while hiding the real error. Errors raised while building the predictor
    now propagate to the caller.
    """
    global predictor, session_id

    module_state = globals()
    if module_state.get('predictor') is None:
        data_dir = os.path.join(PROJECT_ROOT, 'Data')
        predictor = BotPredictor(
            tf.Session(),
            corpus_dir=os.path.join(data_dir, 'Corpus'),
            knbase_dir=os.path.join(data_dir, 'KnowledgeBase'),
            result_dir=os.path.join(data_dir, 'Result'),
            result_file='basic',
        )
        # A fresh predictor always gets a fresh chat session.
        session_id = predictor.session_data.add_session()
    elif module_state.get('session_id') is None:
        # Predictor exists but no session was registered yet.
        session_id = predictor.session_data.add_session()

def callBot(sentence):
    """Return the predictor's reply to ``sentence`` for the shared chat session.

    Uses the module-level ``predictor`` and ``session_id`` that ``initBot``
    assigns; a ``NameError`` is raised if they have not been created yet.
    """
    reply = predictor.predict(session_id, sentence)
    return reply

In [0]:
# Discord client used by every event handler below. The command prefix is left empty;
# the "AG " / "ag " prefix is checked by hand inside on_message.
client = Bot(description="JADE AI", command_prefix="", pm_help = False)
# `setup` is defined elsewhere in this notebook.
# NOTE(review): presumably it attaches the discordbots.org stats reporter to the client — confirm.
setup(client)

@client.event
async def on_ready():
    """Runs once the client is connected: prepares the predictor, logs a banner, sets the presence.

    Note that ``initBot()`` is a synchronous call, so the event loop does not
    process anything else until it returns.
    """
    initBot()

    # Compute the counts once and reuse them for both the banner and the presence text.
    server_count = str(len(client.servers))
    user_count = str(len(set(client.get_all_members())))

    # str() keeps the concatenation valid whether the library exposes the id as str or int.
    print('Logged in as '+client.user.name+' (ID:'+str(client.user.id)+') | Connected to '+server_count+' servers | Connected to '+ user_count +' users')
    print('--------')
    print('You are running Agate AI v0.2') #Do not change this. This will really help us support you, if you need support.
    print('--------')
    # (A bare `resource.getrusage(...).ru_maxrss` expression used to sit here; its
    # value was discarded, so it has been removed.)
    return await client.change_presence(game=discord.Game(name="with my sister, Jade! ||| I've been invited to "+server_count+" homes, and "+ user_count+ " people are my friends! ||| AG is my prefix!")) #This is buggy, let us know if it doesn't work.

In [0]:
def _agate_persona_filter(answer):
    """Rewrite a raw model reply so that it reads as coming from Agate."""
    import re  # local import: only this helper needs it

    answer = answer.replace("Papaya", "Agate")
    # Must run before the "boy" -> "girl" swap below; otherwise the sentence has
    # already been altered and this replacement can never match.
    answer = answer.replace("Although being a robot, I look like a normal 9 year old boy.", "I look pretty good ;D")
    answer = answer.replace("father", "daughter")
    # Leave an existing "female" alone instead of turning it into "fefemale".
    answer = re.sub(r"(?<!fe)male", "female", answer)
    answer = answer.replace("boy", "girl")
    return answer


async def _agate_respond(message, text):
    """Answer ``text`` (the user's message without any prefix) in the message's channel."""
    initBot()
    print('\n' + str(message.server) + '\n' + str(message.author) + ": " + text)

    # The model is addressed under the name "Papaya"; the reverse mapping is
    # applied to its reply in _agate_persona_filter.
    ModMessage = text.replace("Agate", "Papaya")
    await client.send_typing(message.channel)
    if ModMessage == "Play GO":
        await run(message.channel)

    answer = _agate_persona_filter(callBot(ModMessage))

    # `.lower()` must be *called*; comparing the bound method itself to a string
    # is always False, which made this branch unreachable.
    if ModMessage.lower() == "what's your memory usage?":
        # Converted to str so the print() concatenation below keeps working.
        answer = str(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)

    await client.send_message(message.channel, answer)
    print("Agate: " + answer)


@client.event
async def on_message(message):
    """Reply to direct messages and to server messages that start with "AG " / "ag ".

    Messages written by bots are ignored. The presence text is refreshed after
    every message, whether or not a reply was sent.
    """
    if not message.author.bot:
        if message.server is None:
            # Direct message: the whole content is the question.
            await _agate_respond(message, message.content)
        elif message.content.startswith(('AG ', 'ag ')):
            # Server message: drop the 3-character prefix.
            await _agate_respond(message, message.content[3:])

    await client.change_presence(game=discord.Game(name="with my sister, Jade! ||| I've been invited to "+str(len(client.servers))+" homes, and "+ str(len(set(client.get_all_members())))+ " people are my friends! ||| AG is my prefix!"))
        
# Start the bot; this call blocks until the client stops.
# Supply the real token through the DISCORD_TOKEN environment variable so the
# secret never has to be saved in the notebook. The 'Token' placeholder remains
# as the fallback, so existing setups that edit it in place keep working.
client.run(os.environ.get('DISCORD_TOKEN', 'Token'))