In [1]:
import mine_sweeper_env
import torch
import numpy as np


In [2]:
import random

class ReplayMemory:
    def __init__(self, max_size, batch_size) -> None:
        self.max_size = max_size
        self.batch_size = batch_size

        self.memory = list()
        self.pointer = 0
        
    def log_memory(self, memory):
        if len(self.memory < self.max_size):
            self.memory.append(memory)
        else:
            self.memory[self.pointer] = memory
            self.pointer += 1
            if self.pointer >= len(self.memory):
                self.pointer = 0
    
    def get_sample(self, size:int=0):
        if size == 0:
            size = self.batch_size
        return random.sample(self.memory, size)


In [4]:
import torch.nn as nn
import torch

class Actor(nn.Module):
    def __init__(self, board_width, board_height) -> None:
        super().__init__()
        input_size = board_width * board_height
        output_size = 2
        hidden_size = int((input_size + output_size)/2)
        
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input:torch.Tensor) -> torch.Tensor:
        output = self.relu(self.linear1(input))
        output = self.sigmoid(self.linear2(output))
        return output


class Critic(nn.Module):
    def __init__(self, board_width, board_height) -> None:
        super().__init__()
        input_size = board_width * board_height + 2
        output_size = 1
        hidden_size = int((input_size + output_size)/2)

        self.linear1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, states:torch.Tensor, actions:torch.Tensor) -> torch.Tensor:
        output = torch.concat([states, actions], dim=1)
        output = self.relu(self.linear1(output))
        output = self.linear2(output)
        return output
        

In [8]:
import mlflow
import mlflow.pytorch
import os

os.environ['GIT_PYTHON_REFRESH'] = 'quiet'

def train():
    exp = mlflow.set_experiment("Actor_critic")
    mlflow.start_run(experiment_id=exp.experiment_id)

    # Start of Hyperparameters
    BOARD_SIZE = 10
    NUMBER_OF_MINTES = 15

    LEARNING_RATE = 1E-3
    BETAS = (0.9, 0.99)
    MAX_ITERATIONS_FOR_CONVERGENCE = 1000

    MAX_REPLAY_MEMORY_SIZE = 10000
    BATCH_SIZE = 512
    MIN_REPLAY_MEMORY_SIZE_TO_START_TRAINING = 1000
    UPDATE_TARGET_NET_PER_STEPS = 100

    MAX_EPISODES = 10000

    TD_LAMBDA = 5
    DISCOUNT_FACTOR = 0.99
    # End of hyperparameters

    # Logging hyperparameters using mlflow
    mlflow.log_param('BOARD_SIZE', BOARD_SIZE)
    mlflow.log_param('NUMBER_OF_MINTES', NUMBER_OF_MINTES)
    mlflow.log_param('LEARNING_RATE', LEARNING_RATE)
    mlflow.log_param('BETAS', BETAS)
    mlflow.log_param('MAX_ITERATIONS_FOR_CONVERGENCE', MAX_ITERATIONS_FOR_CONVERGENCE)
    mlflow.log_param('MAX_REPLAY_MEMORY_SIZE', MAX_REPLAY_MEMORY_SIZE)
    mlflow.log_param('BATCH_SIZE', BATCH_SIZE)
    mlflow.log_param('MIN_REPLAY_MEMORY_SIZE_TO_START_TRAINING', MIN_REPLAY_MEMORY_SIZE_TO_START_TRAINING)
    mlflow.log_param('UPDATE_TARGET_NET_PER_STEPS', UPDATE_TARGET_NET_PER_STEPS)
    mlflow.log_param('MAX_EPISODES', MAX_EPISODES)
    mlflow.log_param('TD_LAMBDA', TD_LAMBDA)
    mlflow.log_param('DISCOUNT_FACTOR', DISCOUNT_FACTOR)
    # End of logging hyperparameters
    
    # Initializing networks
    actor_net = Actor(board_height=BOARD_SIZE, board_width=BOARD_SIZE)
    critic_net = Critic(board_height=BOARD_SIZE, board_width=BOARD_SIZE)
    target_actor_net = Actor(board_height=BOARD_SIZE, board_width=BOARD_SIZE)
    target_critic_net = Critic(board_height=BOARD_SIZE, board_width=BOARD_SIZE)
    target_actor_net.load_state_dict(actor_net.state_dict())
    target_critic_net.load_state_dict(critic_net.state_dict())
    target_actor_net.eval()
    target_critic_net.eval()

    # Loss function and optimizer
    loss_function = torch.nn.functional.mse_loss 
    optimizer_actor = torch.optim.Adam(params=actor_net.parameters(), lr=LEARNING_RATE, betas=BETAS)
    optimizer_critic = torch.optim.Adam(params=critic_net.parameters(), lr=LEARNING_RATE, betas=BETAS)

    # Preparing environment, and replay memory
    env = mine_sweeper_env.MinesweeperEnv(board_size=BOARD_SIZE, num_mines=NUMBER_OF_MINTES)
    replay_memory = ReplayMemory(max_size=MAX_REPLAY_MEMORY_SIZE, batch_size=BATCH_SIZE)

    total_step_count = 0
    for episode in range(MAX_EPISODES):
        is_episode_done = False
        total_reward = 0

        steps_in_this_epoch = 0
        while is_episode_done is False:
            # Go through the environment for labmda number of steps
            states_list = list()
            actions_list = list()
            rewards_list = list()

            td_lambda_count = 0
            while td_lambda_count < TD_LAMBDA and not(is_episode_done):
                # Get an observation
                observation = torch.tensor(env.my_board).flatten()
                states_list.append(observation)

                # Prediction of action using actor net
                actor_net.eval()
                action = actor_net(observation[None, :]).flatten()
                actions_list.append(action)
                
                # Add noise to action here for exploration purpose

                # Performing the action in the environment
                final_action = action * BOARD_SIZE
                new_state, reward, is_episode_done, _ = env.step(final_action.numpy())
                
                total_reward += reward
                rewards_list.append(reward)

                td_lambda_count += 1
            
            # Saving the observations obtained
            target_q_values = [0]*len(states_list)
            if is_episode_done:
                last_target = 0
            else:
                output_of_actor = target_actor_net(torch.tensor(new_state)[None, :]).flatten()
                last_target = target_critic_net(torch.tensor(new_state)[None, :], output_of_actor).item()
            target_q_values [-1] = rewards_list[-1] + DISCOUNT_FACTOR * last_target

            # Calculating all target q values
            for i in range(len(target_q_values)-2, -1, -1):
                target_q_values[i] = rewards_list[i] + DISCOUNT_FACTOR * target_q_values[i+1]

            for item in zip(states_list, actions_list, target_q_values):
                replay_memory.log_memory(item)

            # Training the nets
            if len(replay_memory.memory) > MIN_REPLAY_MEMORY_SIZE_TO_START_TRAINING:
                # Training critic net first
                critic_net.train()
                prev_loss = 0
                for i in range(MAX_ITERATIONS_FOR_CONVERGENCE):
                    optimizer_critic.zero_grad()
                    batch = replay_memory.get_sample()
                    
                    states_list = [item[0] for item in batch]
                    actions_list = [item[1] for item in batch]
                    target_q_values = [item[2] for item in batch]
                    
                    predicted_q_values = critic_net(states_list)
                    target_q_values = torch.tensor(target_q_values)

                    loss = loss_function(predicted_q_values.flatten(), target_q_values)
                    loss.backward()
                    optimizer_critic.step()

                    # Check for convergence for an early exit
                    change_in_loss = abs(prev_loss - loss.item()) / loss.item()
                    if change_in_loss < 0.001:
                        print('Exiting due to early convergence!!! at step', i, '-> prev loss:', prev_loss, 'current loss:', loss.item())
                        break
                    else:
                        prev_loss = loss.item()
                        print('Steps in training', i)
                
                critic_net.eval()

                # Now, training actor net to maximize Q value.
                prev_loss = 0
                actor_net.train()
                for i in range(MAX_ITERATIONS_FOR_CONVERGENCE):
                    optimizer_actor.zero_grad()
                    batch = replay_memory.get_sample()
                    
                    states_list = [item[0] for item in batch]
                    states_list = torch.stack(states_list)

                    output_of_actor = actor_net(states_list)
                    final_q_values = critic_net(states_list, output_of_actor)
                    loss = -1 * torch.sum(final_q_values)
                    loss.backward()
                    optimizer_actor.step()

                    # Check for convergence for an early exit
                    change_in_loss = abs(prev_loss - loss.item()) / loss.item()
                    if change_in_loss < 0.001:
                        print('Exiting due to early convergence!!! at step', i, '-> prev loss:', prev_loss, 'current loss:', loss.item())
                        break
                    else:
                        prev_loss = loss.item()
                        print('Steps in training', i)
                
                actor_net.eval()
            
            steps_in_this_epoch += 1
            total_step_count += 1
            if total_step_count % UPDATE_TARGET_NET_PER_STEPS == 0:
                target_actor_net.load_state_dict(actor_net.state_dict())
                target_critic_net.load_state_dict(critic_net.state_dict())
            
        # Saving metrics and models after every epoch
        mlflow.log_metric('total_reward', total_reward, step=episode)

        # evaluate() # Write an evaluation function then log metrics with step as episode number

        mlflow.pytorch.log_model(actor_net, artifact_path='Actor_net_after_episode_'+str(episode))
        mlflow.pytorch.log_model(critic_net, artifact_path='Critic_net_after_episode_'+str(episode))

        print(f"Episode: {episode:5d}" ,f"Total reward: {total_reward:.2f}")

        # Resetting the environment
        env.reset()
        is_episode_done = False

    mlflow.end_run()


In [9]:
train()

Exception: Run with UUID 3cd7c8667e424a3f9b20d3ffe90625f9 is already active. To start a new run, first end the current run with mlflow.end_run(). To start a nested run, call start_run with nested=True