In [73]:
# https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [74]:
import math
import copy
import random
from dataclasses import dataclass

In [75]:
from algo.dynamicProgramming import dynamicPlayer
from algo.iplayer import RandomPlayer
from algo.board import Board, GameState

In [76]:
# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    _backend = "cuda"
elif torch.backends.mps.is_available():
    _backend = "mps"
else:
    _backend = "cpu"

device = torch.device(_backend)
device

device(type='cuda')

In [77]:
class DQN(nn.Module):
	"""
	Fully connected value network with a layout similar to NNUE:
	https://www.chessprogramming.org/File:StockfishNNUELayers.png

	Input: the one-hot encoded board. Each of the 18 cells holds one of
	-2, -1, 0, 1, 2 (5 possibilities), which gives 18 * 5 = 90 input features;
	per the original design note at most 12 of them are on at once.

	Output: a single state value instead of one score per action, see
	https://www.reddit.com/r/reinforcementlearning/comments/1b1te73/help_me_understand_why_use_a_policy_net_instead/
	"""

	def __init__(self):
		super().__init__()

		# Widths of consecutive layers: 90 inputs narrowing down to 1 output.
		widths = (90, 60, 20, 20, 1)
		self.layers = nn.ModuleList([
			nn.Linear(fan_in, fan_out)
			for fan_in, fan_out in zip(widths, widths[1:])
		])

	def forward(self, board: Board) -> torch.Tensor:
		"""
		Evaluate `board`: encode it with `board.to_tensor(device)`, apply ReLU
		after every layer except the last, and return the raw final output.
		"""
		activations = board.to_tensor(device)
		*hidden_layers, output_layer = self.layers
		for hidden in hidden_layers:
			activations = F.relu(hidden(activations))
		return output_layer(activations)

In [78]:
# Opponent the agent trains against: its decide_move() supplies the replies in
# make_environment_step and the opening move of every training episode.
enemy = RandomPlayer()

def make_environment_step(state: Board, action: tuple[tuple[int, int], tuple[int, int]]) -> tuple[Board, torch.Tensor]:
	"""
	Play `action` for the side to move, let `enemy` answer until the turn comes
	back to that side (or the game ends), and score the result.

	`state` is advanced in place and handed back, so pass a copy when the
	original position must be kept.

	The reward is taken from the point of view of the side that played
	`action`: the sign-adjusted result of its own make_move call minus the
	sign-adjusted results of the enemy's calls (apparently captured pieces),
	plus 40 for a win, or minus 40 for a loss or a draw, once the game is over.

	Returns the pair (state, reward) with the reward in a one-element tensor.
	Raises ValueError when a finished game reports an unrecognised outcome.
	"""
	mover = state.turn_sign
	gained = state.make_move(*action) * mover

	# The enemy keeps moving for as long as it holds the turn.
	lost = 0
	while state.game_state == GameState.NOT_OVER and state.turn_sign != mover:
		reply = enemy.decide_move(state)
		lost += state.make_move(*reply) * mover * (-1)

	reward = gained - lost

	outcome = state.game_state
	if outcome == GameState.NOT_OVER:
		return state, torch.Tensor([reward])

	# Terminal bonus / penalty; a draw is penalised exactly like a loss.
	if outcome == GameState.DRAW:
		reward -= 40
	elif outcome == GameState(mover):
		reward += 40
	elif outcome == GameState(-mover):
		reward -= 40
	else:
		raise ValueError("Unexpected game state")

	return state, torch.Tensor([reward])

In [79]:
GAMMA = 0.99 # discount rate applied to the estimated value of the following state

@dataclass
class Transition:
	"""One candidate move from a position together with its simulated outcome."""
	# Board reached after the move and the opponent's replies were played.
	new_state: Board
	# The (s, e) pair handed to make_move — presumably start and end cell coordinates; TODO confirm.
	action: tuple[tuple[int, int], tuple[int, int]]
	# Reward returned by make_environment_step for this move.
	immediate_reward: torch.Tensor
	# Score used to rank moves: q_s fills in reward + discounted network estimate,
	# while select_action stores a zero placeholder for random exploratory moves.
	value: torch.Tensor

def q_s(dqn: DQN, current_state: Board) -> list[Transition]:
	"""
	Score every legal move from `current_state` with a one-step lookahead.

	Each move is played on a deep copy of the board through
	make_environment_step, and is valued as

		immediate_reward + GAMMA * dqn(next_state)

	When the move ends the game there is nothing left to bootstrap from, so
	the network term is masked to zero and the value is the immediate reward
	alone (matching how optimize_model treats finished games).

	Return: list[Transition(new_state, action, immediate_reward, value)],
	one entry per legal move; empty when there is no legal move.
	"""
	ret: list[Transition] = []
	for s in current_state.get_possible_pos():
		for e in current_state.get_correct_moves(s):
			next_state = copy.deepcopy(current_state)
			next_state, reward = make_environment_step(next_state, (s, e))
			# 1.0 while the game goes on, 0.0 once it is over. A multiplicative
			# mask (instead of skipping the forward pass) keeps `value` attached
			# to the autograd graph, so callers may still backpropagate through it.
			not_over = float(next_state.game_state == GameState.NOT_OVER)
			value = dqn(next_state) * (GAMMA * not_over) + reward.to(device)
			ret.append(Transition(next_state, (s, e), reward, value))
	return ret

In [80]:
BATCH_SIZE = 128 # number of transitions sampled from the replay buffer per optimize_model call

EPS_START = 0.9 # initial probability of picking a random move (epsilon-greedy exploration)
EPS_END = 0.05 # floor the exploration probability decays towards
EPS_DECAY = 1000 # decay constant, in select_action calls; larger means slower decay
TAU = 0.005 # soft-update rate: fraction of policy_net blended into target_net per update
LR = 1e-4 # AdamW learning rate

policy_net = DQN().to(device) # trained directly by the optimizer
target_net = DQN().to(device) # slowly tracks policy_net through the TAU soft update
target_net.load_state_dict(policy_net.state_dict()) # start both networks from identical weights

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
steps_done = 0 # number of select_action calls so far; drives the epsilon decay


def select_action(board: Board) -> Transition:
	"""
	Choose the agent's next move epsilon-greedily and simulate it.

	The exploration probability decays exponentially from EPS_START towards
	EPS_END as the global `steps_done` counter grows; the counter is bumped on
	every call.

	Exploring: a uniformly random legal move is played on a copy of `board`
	and returned with a zero placeholder as its value.
	Exploiting: the highest-valued Transition produced by q_s(policy_net, board)
	is returned, computed without tracking gradients.
	"""
	global steps_done
	roll = random.random()
	decay = math.exp(-1. * steps_done / EPS_DECAY)
	explore_probability = EPS_END + (EPS_START - EPS_END) * decay
	steps_done += 1

	if roll <= explore_probability:
		candidates = [
			(s, e)
			for s in board.get_possible_pos()
			for e in board.get_correct_moves(s)
		]
		chosen = random.choice(candidates)
		outcome, reward = make_environment_step(copy.deepcopy(board), chosen)
		return Transition(outcome, chosen, reward, torch.Tensor([0]))

	with torch.no_grad():
		return max(q_s(policy_net, board), key=lambda t: t.value.item())

In [81]:
@dataclass
class TransitionRecord:
	"""Replay-buffer entry: a position, the position reached from it, and the reward earned in between."""
	# Board the agent moved from.
	current_state: Board
	# Board after the agent's move and the opponent's replies (a Transition.new_state).
	next_state: Board
	# Reward for that step; the training loop stores it as a tensor on `device`.
	immediate_reward: torch.Tensor

def optimize_model(memory: list[TransitionRecord]):
	"""
	Train policy_net on BATCH_SIZE transitions drawn at random from `memory`.

	Does nothing until `memory` holds at least BATCH_SIZE records. Each sampled
	record gets its own gradient step:

	* prediction - the best value q_s(policy_net, ...) finds from the record's
	  current_state;
	* target     - the stored immediate reward plus GAMMA times the best value
	  q_s(target_net, ...) finds from next_state (computed without gradients),
	  or the reward alone when next_state is a finished game;
	* loss       - Huber (SmoothL1) between the two, with gradients clipped
	  element-wise to [-100, 100] before the optimizer step.

	NOTE(review): the prediction is taken from the greedy move of a fresh
	simulation of current_state, not from the move that actually produced
	next_state - confirm this is intended.
	"""
	if len(memory) < BATCH_SIZE:
		return

	def best_value(net, position):
		# Value of the highest-scoring move `net` sees from `position`.
		return max(q_s(net, position), key=lambda t: t.value.item()).value

	huber = nn.SmoothL1Loss()

	for record in random.sample(memory, BATCH_SIZE):
		predicted = best_value(policy_net, record.current_state)

		# Bootstrap from the slower-moving target network; a finished game
		# contributes no future value.
		if record.next_state.game_state == GameState.NOT_OVER:
			with torch.no_grad():
				bootstrap = best_value(target_net, record.next_state)
		else:
			bootstrap = 0
		target = (bootstrap * GAMMA) + record.immediate_reward

		loss = huber(predicted.to(device), target.to(device))

		optimizer.zero_grad()
		loss.backward()
		# Clip gradients in place before applying the update.
		torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
		optimizer.step()

In [71]:
# Train for longer when an accelerator is available.
if torch.cuda.is_available() or torch.backends.mps.is_available():
	num_episodes = 600
else:
	num_episodes = 50

# Sign the agent plays with; used below to decide whether an episode was won.
our_sign = -1
# Replay buffer of every transition seen so far.
memory: list[TransitionRecord] = []

# One boolean per finished episode: True when the agent won.
win_rate = []

for i_episode in range(num_episodes):
	# Initialize the environment and let the enemy make the opening move.
	# NOTE(review): assumes this single move hands the turn to the side with
	# sign `our_sign` - confirm against Board.
	cur_state = Board()
	cur_state.make_move(*enemy.decide_move(cur_state))
	while True:
		t = select_action(cur_state)
		# The reward is already a tensor: move it to the training device
		# directly instead of rebuilding a tensor from a list that wraps it.
		memory.append(TransitionRecord(cur_state, t.new_state, t.immediate_reward.to(device)))
		cur_state = t.new_state

		# Soft update of the target network's weights
		# θ′ ← τ θ + (1 −τ )θ′
		target_net_state_dict = target_net.state_dict()
		policy_net_state_dict = policy_net.state_dict()
		for key in policy_net_state_dict:
			target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
		target_net.load_state_dict(target_net_state_dict)

		if cur_state.game_state != GameState.NOT_OVER:
			win_rate.append(cur_state.game_state == GameState(our_sign))
			break

	# Perform one optimization pass (on the policy network) per episode
	optimize_model(memory)

print('Complete')

KeyboardInterrupt: 

In [72]:
# Outcome of each finished episode, in play order: True where the side with sign our_sign won.
win_rate

[True, True, True, False, False, True, True, True, True, True]