## Load libraries

In [1]:
import numpy as np
from tqdm import tqdm
import time
from torch.utils.tensorboard import SummaryWriter
from multiprocessing.pool import ThreadPool
from datetime import datetime
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from torch.distributions.categorical import Categorical
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper

## Parameters

In [3]:
# Run on the GPU when one is available and fall back to the CPU otherwise, so
# that every later `.to(device)` call works on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Network architecture ---
dim = 256               # width of the hidden Linear layers and of the LSTM state
memory_length = 64      # number of past observations fed to the history encoder
bidirect = False        # whether the history LSTM is bidirectional
lstm_layers = 2         # number of stacked LSTM layers

# --- Optimisation ---
critic_lr = 3e-4        # Adam learning rate of both critics and of log-alpha
actor_lr = critic_lr / 3.0
reparam_noise = 1e-6    # not referenced by the code visible in this notebook
gamma = 0.99            # discount factor used in the TD target
tau = 0.005             # interpolation factor of the target-critic update
alpha_start = 1         # initial entropy temperature

# --- Replay buffer and schedule ---
max_size = 1000000      # replay buffer capacity (transitions)
batch_size = 512        # transitions sampled per gradient step
total_plays = 50000     # number of training episodes (also the LR schedulers' horizon)
num_epochs = 1          # sampled batches per call to Agent.gradient_step
N = 1                   # run a gradient step every N environment steps

## Environment setup

In [4]:
# Launch the Unity build without graphics and expose it through the Gym-style
# wrapper. The path is a raw string: in a normal literal the Windows separators
# form invalid escape sequences ("\P", "\S", "\B"), which newer Python versions
# warn about. The resulting path is unchanged.
unity_env = UnityEnvironment(r"D:\Practice\SentisInfrence\Build\SentisInfrence.exe", no_graphics=True)
env = UnityToGymWrapper(unity_env)

# Alternative environment kept for reference:
#env = gym.make("CartPole-v1")

# Problem sizes read by the replay buffer and by every network below.
obs_dim = env.observation_space.shape
n_actions = env.action_space.n

print(obs_dim)
print(n_actions)

(12,)
4


In [8]:
# Shut the environment down. The training and export cells call env.reset(),
# so run this only once those cells are no longer needed.
env.close()

## Replay buffer

In [5]:
class ReplayBuffer():
    """Fixed-size ring buffer of transitions with observation-history lookups.

    Transitions are written at ``mem_cntr % mem_size``. In addition to the
    transition arrays the buffer keeps ONE copy of the most recently stored
    LSTM states (``hidden_actor``, ``hidden_cr1``, ``hidden_cr2``); they are
    not stored per transition.
    """

    def __init__(self, max_size, input_shape, n_actions):
        """Allocate zero-filled storage for ``max_size`` transitions.

        Args:
            max_size: capacity of the ring buffer.
            input_shape: shape of one observation (an iterable of ints).
            n_actions: width of one row of ``action_memory``.
        """
        self.mem_size = max_size
        self.mem_cntr = 0
        self.state_memory = np.zeros((self.mem_size, *input_shape))
        self.new_state_memory = np.zeros((self.mem_size, *input_shape))
        self.action_memory = np.zeros((self.mem_size, n_actions))
        self.reward_memory = np.zeros(self.mem_size)
        self.done_memory = np.zeros((self.mem_size), dtype=bool)
        self.hidden_cr1 = self._zero_hidden()
        self.hidden_cr2 = self._zero_hidden()
        self.hidden_actor = self._zero_hidden()

    @staticmethod
    def _zero_hidden():
        """Return a zero (h, c) pair of shape (1, dim) on the module-level device."""
        return (torch.zeros((1, dim), dtype=torch.float).to(device),
                torch.zeros((1, dim), dtype=torch.float).to(device))

    def store_transition(self, state, action, reward, state_, dones, hidden_actor, hidden_critic1, hidden_critic2):
        """Write one transition and, if all three are given, the latest hidden states.

        ``action`` is broadcast over the whole ``action_memory`` row. The hidden
        states are only replaced when none of them is ``None``; each must be an
        ``(h, c)`` pair of tensors and is stored detached.
        """
        index = self.mem_cntr % self.mem_size

        self.state_memory[index] = state
        self.new_state_memory[index] = state_
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.done_memory[index] = dones
        if hidden_actor is not None and hidden_critic1 is not None and hidden_critic2 is not None:
            # Keep both halves of each LSTM state: element 0 is h, element 1 is c.
            self.hidden_cr1 = (hidden_critic1[0].detach(), hidden_critic1[1].detach())
            self.hidden_cr2 = (hidden_critic2[0].detach(), hidden_critic2[1].detach())
            self.hidden_actor = (hidden_actor[0].detach(), hidden_actor[1].detach())

        self.mem_cntr += 1

    def _history_start(self, end_index, history_length):
        """Return the first index of the history window that ends before ``end_index``.

        The window spans at most ``history_length`` entries, never starts below
        index 0, and is cut so that it begins right after the last ``done`` flag
        found in ``done_memory[start:end_index]``.
        """
        start = max(int(end_index) - history_length, 0)
        done_offsets = np.flatnonzero(self.done_memory[start:end_index])
        if done_offsets.size:
            start += int(done_offsets[-1]) + 1
        return start

    def sample_history_sequence(self, history_length):
        """Return the flattened history of the most recently stored states.

        The result is ``dict(history_obs=tensor)`` where the tensor has shape
        ``(1, history_length * obs_size)``. The history is left-aligned and
        zero-padded. While the buffer holds ``history_length`` entries or fewer
        an all-zero history is returned.

        The window ends with the newest stored state, so that (as in
        ``sample_buffer_history``) the history is the run of states immediately
        preceding the observation the caller is about to act on.
        """
        max_mem = min(self.mem_cntr, self.mem_size)
        hist_states = np.zeros([1, history_length, self.state_memory.shape[1]])

        if max_mem > history_length:
            # One past the newest written slot; also correct after the ring wraps.
            end = (self.mem_cntr - 1) % self.mem_size + 1
            start = self._history_start(end, history_length)
            hist_states[0, :end - start, :] = self.state_memory[start:end]

        hist_part2 = torch.tensor(hist_states, dtype=torch.float).reshape(1, -1)
        return dict(history_obs=hist_part2)

    def sample_buffer_history(self, batch_size, history_length):
        """Sample a random batch of transitions together with their histories.

        Indices are drawn uniformly from ``[history_length, max_mem)``. For each
        index the history holds the states stored before it (left-aligned,
        zero-padded, cut at episode boundaries) flattened to one row.

        Returns:
            A dict with ``history_obs`` and the three stored hidden states; when
            ``batch_size <= max_mem`` it also contains ``states``, ``states_``,
            ``actions``, ``rewards`` and ``dones`` as NumPy arrays.

        Raises:
            ValueError: if the buffer holds ``history_length`` entries or fewer.
        """
        max_mem = min(self.mem_cntr, self.mem_size)
        if max_mem <= history_length:
            raise ValueError(
                f"need more than {history_length} stored transitions to sample, have {max_mem}")
        batch = np.random.randint(history_length, max_mem, batch_size)
        obs_size = self.state_memory.shape[1]

        if history_length == 0:
            hist_states = np.zeros([batch_size, 1, obs_size])
        else:
            hist_states = np.zeros([batch_size, history_length, obs_size])
            # NOTE(review): once the ring has wrapped, a window may straddle the
            # write position and mix old and new data; not handled here.
            for row, index in enumerate(batch):
                start = self._history_start(index, history_length)
                hist_states[row, :index - start, :] = self.state_memory[start:index]

        hist_part2 = torch.tensor(hist_states, dtype=torch.float).reshape(batch_size, -1)

        dictionary = dict(history_obs=hist_part2,
                          hidden_critic1=self.hidden_cr1,
                          hidden_critic2=self.hidden_cr2,
                          hidden_actor=self.hidden_actor)
        if batch_size <= max_mem:
            dictionary.update(states=self.state_memory[batch],
                              states_=self.new_state_memory[batch],
                              actions=self.action_memory[batch],
                              rewards=self.reward_memory[batch],
                              dones=self.done_memory[batch])

        return dictionary

## Networks

In [None]:
class HistoryNetwork(nn.Module):
    """Encoder for the flattened observation history.

    A Linear+ReLU layer maps a history row of width
    ``obs_dim[-1] * memory_length`` to ``dim`` features, which are then passed
    through an LSTM stack of ``lstm_layers`` layers.

    NOTE(review): the callers in this file pass a 2-D ``history_obs`` and 2-D
    hidden tensors, which ``nn.LSTM`` treats as one unbatched sequence (rows
    act as time steps) — confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        flat_history_width = obs_dim[-1] * memory_length
        self.fc_layers = nn.Sequential(
            nn.Linear(flat_history_width, dim),
            nn.ReLU(),
        )
        self.lstm_layers = nn.LSTM(
            input_size=dim,
            hidden_size=dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=bidirect,
        )

        self.to(device)

    def forward(self, history_obs, hidden: tuple[torch.Tensor, torch.Tensor]):
        """Return ``(lstm_output, new_hidden)`` for the given history and ``(h, c)`` state."""
        embedded = self.fc_layers(history_obs)
        return self.lstm_layers(embedded, hidden)
     
class CriticNetwork(nn.Module):
    """Q-network producing one value per discrete action.

    The history encoding (``critic_me``) and the encoding of the current state
    (``critic_cf``) are concatenated and mapped by ``critic_pi`` to
    ``n_actions`` outputs. The module owns its Adam optimiser and a linear
    learning-rate scheduler.
    """

    def __init__(self):
        super().__init__()

        self.critic_me = HistoryNetwork()

        self.critic_cf = nn.Sequential(
            nn.Linear(obs_dim[-1], dim),
            nn.ReLU(),
        )

        # A bidirectional LSTM emits 2 * dim features, a unidirectional one dim;
        # the state branch always adds another dim.
        fused_width = dim * 3 if bidirect else dim * 2
        self.critic_pi = nn.Sequential(
            nn.Linear(fused_width, dim),
            nn.ReLU(),
            nn.Linear(dim, n_actions),
        )

        self.optimizer = optim.Adam(self.parameters(), lr=critic_lr)
        self.scheduler = optim.lr_scheduler.LinearLR(
            self.optimizer,
            start_factor=1,
            end_factor=0,
            total_iters=total_plays,
        )

        self.to(device)

    def forward(self, state, history_obs, hidden: tuple[torch.Tensor, torch.Tensor]):
        """Return ``(q_values, new_hidden)`` for the state, its history and the LSTM state."""
        history_features, hidden_ = self.critic_me(history_obs, hidden)
        state_features = self.critic_cf(state)
        fused = torch.cat([history_features, state_features], dim=1)

        return self.critic_pi(fused), hidden_

class ActorNetwork(nn.Module):
    """Policy network emitting a probability for each discrete action.

    The history encoding (``actor_me``) and the encoding of the current state
    (``actor_cf``) are concatenated and mapped by ``actor_pi`` (Linear +
    Softmax) to ``n_actions`` probabilities. The module owns its Adam
    optimiser and a linear learning-rate scheduler.
    """

    def __init__(self):
        super(ActorNetwork, self).__init__()

        self.actor_me = HistoryNetwork()
        self.actor_cf = nn.Sequential(
            nn.Linear(obs_dim[-1], dim),
            nn.ReLU())
        self.actor_pi = nn.Sequential(nn.Linear(dim * 3 if bidirect else dim * 2, int(n_actions)),
                                      nn.Softmax(dim=-1))

        self.optimizer = optim.Adam(self.parameters(), lr=actor_lr)
        self.scheduler = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1, end_factor=0, total_iters=total_plays)

        self.to(device)

    def forward(self, state, history_obs, hidden: tuple[torch.Tensor, torch.Tensor]):
        """Return ``(action_probabilities, new_hidden)``."""
        me, hidden_ = self.actor_me(history_obs, hidden)
        cf = self.actor_cf(state)
        pi = self.actor_pi(torch.cat([me, cf], dim=-1))

        return pi, hidden_

    def sample_normal(self, state, history_obs, hidden=None, deterministic=False):
        """Pick an action per row and return it with the policy's (log-)probabilities.

        Args:
            state: current observation(s).
            history_obs: flattened observation history matching ``state``.
            hidden: ``(h, c)`` LSTM state, or ``None``.
            deterministic: when ``True`` return the most probable action
                (useful for evaluation); by default the action is SAMPLED from
                the categorical policy. Always taking the argmax, as before,
                means the agent never explores through its own stochasticity.

        Returns:
            ``(action, log_probs, action_probs, new_hidden)``; ``action`` holds
            one integer index per row, the other two cover all actions.
        """
        action_probs, hidden_ = self.forward(state, history_obs, hidden)

        if deterministic:
            action = torch.argmax(action_probs, dim=-1)
        else:
            action = Categorical(probs=action_probs).sample()

        # log(0) would be -inf: nudge exactly-zero probabilities before the log.
        zero_mask = (action_probs == 0.0).float()
        log_probs = torch.log(action_probs + zero_mask * 1e-8)

        return action, log_probs, action_probs, hidden_
    
class Agent():
    """Discrete-action soft actor-critic agent with recurrent history encoders.

    Holds the replay buffer, one actor, two critics with their target copies,
    and a learnable entropy temperature (``log_alpha`` / ``alpha``).
    """

    def __init__(self):
        self.memory = ReplayBuffer(max_size, obs_dim, n_actions)

        def init_weights(m):
            # Xavier-normal weights and zero biases for every Linear layer.
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

        self.actor = ActorNetwork()
        self.critic_1 = CriticNetwork()
        self.critic_2 = CriticNetwork()
        self.critic_1_target = CriticNetwork()
        self.critic_2_target = CriticNetwork()

        self.actor.apply(init_weights)
        self.critic_1.apply(init_weights)
        self.critic_2.apply(init_weights)

        # Targets start as exact copies of the (re-initialised) critics.
        self.critic_1_target.load_state_dict(self.critic_1.state_dict())
        self.critic_2_target.load_state_dict(self.critic_2.state_dict())

        self.alpha = alpha_start
        # Keep log_alpha consistent with alpha_start (log(1) == 0 for the default).
        self.log_alpha = torch.full((1,), float(np.log(alpha_start)), requires_grad=True, device=device)
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=critic_lr)
        self.alpha_scheduler = optim.lr_scheduler.LinearLR(self.alpha_optimizer, start_factor=1, end_factor=0, total_iters=total_plays)
        # The entropy of a categorical policy lies in [0, log(n_actions)]. The
        # continuous-action convention (-n_actions) can never be reached and
        # pushes alpha monotonically towards zero, so aim for a fraction of
        # the maximum entropy instead.
        self.target_entropy = 0.98 * float(np.log(n_actions))

    def choose_action(self, observation, hidden):
        """Return ``(action, new_actor_hidden)`` for a single observation.

        The history is taken from the replay buffer; no autograd graph is
        built because the result is only used to step the environment.
        """
        dic = self.memory.sample_history_sequence(memory_length)
        history_obs = dic["history_obs"].to(device)
        state = torch.tensor(np.array([observation]), dtype=torch.float).to(device)
        with torch.no_grad():
            actions, _, _, hidden_ = self.actor.sample_normal(state, history_obs, hidden)

        return actions.cpu().numpy()[0], hidden_

    def remember(self, state, action, reward, new_state, dones, hidden_critic1, hidden_critic2, hidden_actor):
        """Store one transition; the arguments are forwarded positionally, in this order.

        NOTE(review): the three hidden-state parameters are passed on in the
        order they are declared here, while ``ReplayBuffer.store_transition``
        declares them as (actor, critic1, critic2). The training loop calls
        this method with (actor, critic1, critic2), so positions line up even
        though the parameter names here do not.
        """
        self.memory.store_transition(state, action, reward, new_state, dones, hidden_critic1, hidden_critic2, hidden_actor)

    @staticmethod
    def _soft_update(source, target):
        """Move ``target``'s parameters a fraction ``tau`` towards ``source``'s."""
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def gradient_step(self, hidden_in_a, hidden_in_c1, hidden_in_c2):
        """Run ``num_epochs`` updates of critics, actor, alpha and target critics.

        Args:
            hidden_in_a, hidden_in_c1, hidden_in_c2: ``(h, c)`` LSTM states fed
                to the actor and the two critics for every sampled batch.

        Returns:
            ``(actor_hidden, critic1_hidden, critic2_hidden)`` produced by the
            last epoch (detached), or the inputs unchanged when the buffer does
            not yet hold enough transitions.
        """
        # A full batch must be available, and the sampler needs indices beyond
        # the history window.
        if self.memory.mem_cntr < max(batch_size, memory_length + 1):
            return hidden_in_a, hidden_in_c1, hidden_in_c2

        hidden_actor_, hidden_critic1_, hidden_critic2_ = hidden_in_a, hidden_in_c1, hidden_in_c2

        for _ in range(num_epochs):
            dct = self.memory.sample_buffer_history(batch_size, memory_length)

            history_obs = dct["history_obs"].detach().to(device)
            reward = torch.tensor(dct["rewards"], dtype=torch.float).to(device)
            state_ = torch.tensor(dct["states_"], dtype=torch.float).to(device)
            state = torch.tensor(dct["states"], dtype=torch.float).to(device)
            actions = torch.tensor(dct["actions"], dtype=torch.float).to(device)
            dones = torch.tensor(dct["dones"], dtype=torch.bool).to(device)

            # Every column of a stored action row holds the same index (the
            # scalar action is broadcast on write), so the first one suffices.
            taken = actions[:, :1].long()

            # ---- Critics gradient step ----
            # NOTE(review): the same history is used for state and state_.
            with torch.no_grad():
                _, next_log_probs, next_probs, _ = self.actor.sample_normal(state_, history_obs, hidden_in_a)
                q1_next, _ = self.critic_1_target.forward(state_, history_obs, hidden_in_c1)
                q2_next, _ = self.critic_2_target.forward(state_, history_obs, hidden_in_c2)
                # Soft state value: expectation over actions of (min Q - alpha * log pi).
                # The expectation is a SUM over the action axis weighted by pi;
                # leaving it per-action scales the bootstrap by 1 / n_actions.
                soft_q_next = torch.min(q1_next, q2_next) - self.alpha * next_log_probs
                v_next = (next_probs * soft_q_next).sum(dim=1, keepdim=True)
                not_done = (~dones.view(batch_size, -1)).float()
                q_hat = reward.view(batch_size, -1) + gamma * not_done * v_next

            q1_value, hidden_critic1_ = self.critic_1.forward(state, history_obs, hidden_in_c1)
            q2_value, hidden_critic2_ = self.critic_2.forward(state, history_obs, hidden_in_c2)
            q1_loss = 0.5 * F.mse_loss(q1_value.gather(1, taken), q_hat)
            q2_loss = 0.5 * F.mse_loss(q2_value.gather(1, taken), q_hat)

            q_loss = q1_loss + q2_loss
            self.critic_1.zero_grad()
            self.critic_2.zero_grad()
            q_loss.backward()
            self.critic_1.optimizer.step()
            self.critic_2.optimizer.step()

            # ---- Policy gradient step ----
            _, log_probs, action_probs, hidden_actor_ = self.actor.sample_normal(state, history_obs, hidden_in_a)
            with torch.no_grad():
                # The critics are constants for the actor update.
                q1_pi, _ = self.critic_1.forward(state, history_obs, hidden_in_c1)
                q2_pi, _ = self.critic_2.forward(state, history_obs, hidden_in_c2)
                q_value = torch.min(q1_pi, q2_pi)
            actor_loss = (action_probs * (self.alpha * log_probs - q_value)).mean()
            self.actor.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # ---- Alpha gradient step ----
            with torch.no_grad():
                # Only log_alpha is optimised here; the policy is a constant.
                _, log_probs, action_probs, _ = self.actor.sample_normal(state, history_obs, hidden_in_a)
            alpha_loss = -(action_probs * self.log_alpha * (log_probs + self.target_entropy)).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.alpha = self.log_alpha.exp().item()

            # ---- Target critic weights update ----
            self._soft_update(self.critic_1, self.critic_1_target)
            self._soft_update(self.critic_2, self.critic_2_target)

            # Drop the autograd graphs attached to the returned states.
            hidden_actor_ = (hidden_actor_[0].detach(), hidden_actor_[1].detach())
            hidden_critic1_ = (hidden_critic1_[0].detach(), hidden_critic1_[1].detach())
            hidden_critic2_ = (hidden_critic2_[0].detach(), hidden_critic2_[1].detach())

        # Returning inside the loop would silently cap training at one epoch.
        return hidden_actor_, hidden_critic1_, hidden_critic2_

    def save_model(self):
        """Script the actor with TorchScript and save it under ``models/``."""
        import os

        os.makedirs("models", exist_ok=True)  # saving fails if the folder is missing
        model_scripted = torch.jit.script(self.actor)
        model_scripted.save("models/unity_test" + "_final.pth")

## Training

In [None]:
# Build a fresh agent: replay buffer, actor, two critics and their target copies.
agent = Agent()

In [None]:
# Script the current actor with TorchScript and write it to models/unity_test_final.pth.
agent.save_model()

In [7]:
# Re-running the cell: close the progress bar left over from the previous run.
if "pbar" in globals():
    pbar.close()
pbar = tqdm(total=total_plays)
pbar.reset()

# Take the timestamp once so the log directory and the text step always agree.
_run_started = datetime.now()
_run_stamp = str(_run_started.day) + str(_run_started.hour) + str(_run_started.minute)
writer = SummaryWriter("logs/unity_test" + _run_stamp)

writer.add_text(
          "Hyperparameters",
          "|param|value|\n|-|-|\n%s" % ("\n".join(
               [f"|Critic lr|{critic_lr}|",
                f"|Actor lr|{actor_lr}|",
                f"|Layer dim|{dim}|",
                f"|Batch size|{batch_size}|",
                f"|Gamma|{gamma}|",
                f"|Tau|{tau}|",
                ]
          )),
          int(_run_stamp))


def _zero_hidden_state():
    """Return an all-zero (h, c) pair sized for one history LSTM stack."""
    shape = (lstm_layers * (2 if bidirect else 1), dim)
    return (torch.zeros(shape, dtype=torch.float).to(device),
            torch.zeros(shape, dtype=torch.float).to(device))


agent = Agent()
best_score = -100000
score_history = []

global_step = 0
try:
    for i in range(total_plays):
        observation = env.reset()
        done = False
        score = 0
        iter_steps = 0
        # Every episode starts from zeroed recurrent states.
        hidden_critic1 = _zero_hidden_state()
        hidden_critic2 = _zero_hidden_state()
        hidden_actor = _zero_hidden_state()
        while not done:
            # Detach so no autograd graph is carried across environment steps.
            hidden_in_c1 = (hidden_critic1[0].detach(), hidden_critic1[1].detach())
            hidden_in_c2 = (hidden_critic2[0].detach(), hidden_critic2[1].detach())
            hidden_in_a = (hidden_actor[0].detach(), hidden_actor[1].detach())
            action, hidden_actor = agent.choose_action(observation, hidden_in_a)
            observation_, reward, terminated, _ = env.step(action)
            done = terminated
            score += reward
            # Abort hopeless episodes; the stored flag below stays `terminated`.
            if score < -1000:
                done = True
            if iter_steps % N == 0:
                _, hidden_critic1, hidden_critic2 = agent.gradient_step(hidden_in_a, hidden_in_c1, hidden_in_c2)
            agent.remember(observation, action, reward, observation_, terminated, hidden_actor, hidden_critic1, hidden_critic2)
            iter_steps += 1
            global_step += 1
            observation = observation_
        score_history.append(score)
        avg_score = np.mean(score_history[-100:])

        # Export the actor whenever the 100-episode average improves.
        if avg_score > best_score:
            best_score = avg_score
            agent.save_model()

        writer.add_scalar("charts/reward", avg_score, global_step=global_step)
        writer.add_scalar("charts/step_count", iter_steps, global_step=global_step)

        # Learning-rate schedulers are currently disabled:
        # if global_step > batch_size:
        #     agent.critic_1.scheduler.step()
        #     agent.critic_2.scheduler.step()
        #     agent.actor.scheduler.step()
        #     agent.alpha_scheduler.step()

        pbar.update()
finally:
    # Also runs on KeyboardInterrupt / errors, so pending TensorBoard events are
    # flushed and the progress bar is released.
    writer.close()
    pbar.close()

  0%|          | 157/50000 [45:54<450:27:39, 32.54s/it]

KeyboardInterrupt: 

## Weights save

In [None]:
import os

# torch.save does not create missing directories.
os.makedirs("models", exist_ok=True)

# Full training checkpoint: networks, their optimisers and the entropy
# temperature (without it a resumed run restarts alpha from its initial value).
# `i` is the index of the last episode started by the training loop.
torch.save({
            'actor_state_dict': agent.actor.state_dict(),
            'critic1_state_dict': agent.critic_1.state_dict(),
            'critic2_state_dict': agent.critic_2.state_dict(),
            'target_critic1_state_dict': agent.critic_1_target.state_dict(),
            'target_critic2_state_dict': agent.critic_2_target.state_dict(),
            'actor_optimizer_state_dict': agent.actor.optimizer.state_dict(),
            'critic1_optimizer_state_dict': agent.critic_1.optimizer.state_dict(),
            'critic2_optimizer_state_dict': agent.critic_2.optimizer.state_dict(),
            'log_alpha': agent.log_alpha.detach().cpu(),
            'alpha_optimizer_state_dict': agent.alpha_optimizer.state_dict(),
            }, "models/unity" + str(i) + "_steps_weights.pt")

## Load weights

In [None]:
# map_location lets a checkpoint written on a GPU be restored on whatever
# device this session uses.
ckpt = torch.load("models/unity49999_steps_weights.pt", map_location=device)
agent.actor.load_state_dict(ckpt['actor_state_dict'])
agent.actor.optimizer.load_state_dict(ckpt['actor_optimizer_state_dict'])
agent.critic_1.load_state_dict(ckpt['critic1_state_dict'])
agent.critic_2.load_state_dict(ckpt['critic2_state_dict'])
agent.critic_1.optimizer.load_state_dict(ckpt['critic1_optimizer_state_dict'])
agent.critic_2.optimizer.load_state_dict(ckpt['critic2_optimizer_state_dict'])
agent.critic_1_target.load_state_dict(ckpt['target_critic1_state_dict'])
agent.critic_2_target.load_state_dict(ckpt['target_critic2_state_dict'])

# The entropy temperature is optional so that checkpoints without it still load.
if 'log_alpha' in ckpt:
    with torch.no_grad():
        agent.log_alpha.copy_(ckpt['log_alpha'].to(agent.log_alpha.device))
    agent.alpha = agent.log_alpha.exp().item()
if 'alpha_optimizer_state_dict' in ckpt:
    agent.alpha_optimizer.load_state_dict(ckpt['alpha_optimizer_state_dict'])

# Put the trainable networks back into training mode.
agent.actor.train()
agent.critic_1.train()
agent.critic_2.train()

## Export to ONNX (Unity format)

In [None]:
# One real observation, batched, serves as the example input for tracing.
observation = env.reset() 
observation = torch.tensor(np.array([observation]), dtype=torch.float).to(device)

class WrapperNet(torch.nn.Module):
    def __init__(self, actor):
        """
        Wraps the ActorNetwork, adding the extra constant outputs required by
        runtime inference with Sentis.

        The actor needs an observation history and an LSTM state in addition
        to the observation. The exported graph has a single input, so both are
        fed as zeros here.
        TODO: expose the history / recurrent state as model inputs so that
        inference matches what the actor saw during training.

        For environment continuous actions outputs would need to add them
        similarly to how discrete action outputs work, both in the wrapper
        and in the ONNX output_names / dynamic_axes.
        """
        super(WrapperNet, self).__init__()
        self.qnet = actor

        # Sizes of the zero history / hidden state handed to the actor.
        self.history_width = obs_dim[-1] * memory_length
        self.hidden_rows = lstm_layers * (2 if bidirect else 1)
        self.hidden_width = dim

        # version_number
        #   MLAgents1_0 = 2   (not covered by this example)
        #   MLAgents2_0 = 3
        version_number = torch.Tensor([3])
        self.version_number = nn.Parameter(version_number, requires_grad=False)

        # memory_size
        # TODO: document case where memory is not zero.
        memory_size = torch.Tensor([0])
        self.memory_size = nn.Parameter(memory_size, requires_grad=False)

        # discrete_action_output_shape
        output_shape = torch.Tensor([n_actions])
        self.discrete_shape = nn.Parameter(output_shape, requires_grad=False)


    def forward(self, obs: torch.Tensor):
        """Return the greedy action plus the constant tensors for ``obs``."""
        history = obs.new_zeros((obs.shape[0], self.history_width))
        hidden = (obs.new_zeros((self.hidden_rows, self.hidden_width)),
                  obs.new_zeros((self.hidden_rows, self.hidden_width)))
        # The actor returns (probabilities, new_hidden); only the former is used.
        probs, _ = self.qnet(obs, history, hidden)
        
        action = torch.argmax(probs, dim=-1, keepdim=True)
        
        return [action], self.discrete_shape, self.version_number, self.memory_size
    
model = WrapperNet(agent.actor)

# NOTE(review): confirm that the ONNX exporter (opset 11) handles the actor's
# LSTM with the 2-D, unbatched inputs used here.
torch.onnx.export(
    model,
    observation,
    'UnityTest.onnx',
    opset_version=11,
    # input_names must correspond to the WrapperNet forward parameters
    # obs will be obs_0, obs_1, etc.
    input_names=["obs_0"],
    # output_names must correspond to the return tuple of the WrapperNet
    # forward function.
    output_names=["discrete_actions", "discrete_action_output_shape",
                  "version_number", "memory_size"],
    # All inputs and outputs should have their 0th dimension be designated
    # as 'batch'
    dynamic_axes={'obs_0': {0: 'batch'},
                  'discrete_actions': {0: 'batch'},
                  'discrete_action_output_shape': {0: 'batch'}
                 }
    )