I trained on Google Colab to avoid overheating my local machine.

In [None]:
# Refresh the apt package index, then install a separate Python 3.10
# interpreter (plus distutils) next to the runtime's default Python.
# NOTE(review): presumably needed because mlagents-envs does not support the
# runtime's default Python version — confirm against the mlagents-envs release notes.
!sudo apt-get update -y
!sudo apt-get install python3.10 python3.10-distutils -y

# Bootstrap pip for the python3.10 interpreter installed above.
!curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
!python3.10 get-pip.py

# Sanity check: print the interpreter and pip versions that will be used.
!python3.10 --version
!python3.10 -m pip --version

# Install the training dependencies into python3.10: torch from the CUDA 11.8
# wheel index, then the ML-Agents environment API and matplotlib.
!python3.10 -m pip install torch --index-url https://download.pytorch.org/whl/cu118
!python3.10 -m pip install mlagents-envs matplotlib

In [None]:
# Extract the uploaded archive; the training script loads the Unity build
# from "a/PongBuild.x86_64".
!unzip a.zip

In [None]:
%%writefile train1.py
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.base_env import ActionTuple
import random
import numpy as np
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
os.environ.pop("MPLBACKEND", None)
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

class ReplayMemory:
    """Bounded replay buffer that samples transitions in proportion to a priority.

    Transitions and priorities are kept in two parallel deques with the same
    ``maxlen``, so once the buffer is full the oldest transition and its
    priority are evicted together.

    Args:
        capacity: Maximum number of transitions retained.
        batch_size: Number of transitions returned by ``sample``.
        w: Exponent applied to ``|delta| + 1e-6`` when priorities are refreshed.
    """

    def __init__(self, capacity=100000, batch_size=32, w=0.5):
        self.batch_size = batch_size
        self.w = w
        self.memory = deque(maxlen=capacity)
        self.priorities = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one transition with the largest priority currently held (1.0 if none)."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)
        # A new transition starts at the current maximum so that it is likely
        # to be drawn before its priority is replaced by a real TD error.
        initial_priority = max(self.priorities) if self.priorities else 1.0
        self.priorities.append(initial_priority)

    def sample(self):
        """Draw ``batch_size`` distinct transitions, weighted by priority.

        The drawn positions are remembered in ``self.indices`` so that a
        following ``update_priorities`` call can address the same entries.
        """
        priority_array = np.array(self.priorities)
        distribution = priority_array / priority_array.sum()
        self.indices = np.random.choice(
            len(self.priorities),
            size=self.batch_size,
            replace=False,
            p=distribution,
        )
        return [self.memory[position] for position in self.indices]

    def update_priorities(self, delta_for_priorities):
        """Overwrite the priorities of the most recently sampled batch.

        ``delta_for_priorities`` is paired element-wise with ``self.indices``
        from the last ``sample`` call; each priority becomes
        ``(|delta| + 1e-6) ** w``.
        """
        exponent = self.w
        for position, td_error in zip(self.indices, delta_for_priorities):
            self.priorities[position] = (abs(td_error) + 1e-6) ** exponent

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)


class QNetwork(nn.Module):
    """Three-layer fully connected network mapping a state vector to one Q-value per action.

    Args:
        state_dim: Size of the input feature vector.
        action_dim: Number of outputs (one per discrete action).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # The attribute names fc1/fc2/fc3 become the state_dict keys, so they
        # must stay stable for saved checkpoints to load.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Apply two ReLU-activated hidden layers followed by a linear output layer."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class DQNAgent:
    """DQN agent with an online network, a target network and linear epsilon decay.

    Args:
        state_dim: Size of the observation vector fed to the Q-network.
        action_dim: Number of discrete actions.
        lr: Adam learning rate for the online network.
        epsilon: Initial exploration probability.
        min_epsilon: Floor for the exploration probability.
        epsilon_decay_steps: Number of ``decay_epsilon`` calls needed to go
            from ``epsilon`` down to ``min_epsilon``.
        gamma: Discount factor.
    """

    def __init__(self, state_dim, action_dim, lr=1e-3, epsilon=1.0, min_epsilon=0.1, epsilon_decay_steps=200000.0, gamma=0.99):
        self.q_net = QNetwork(state_dim, action_dim)
        self.target_net = QNetwork(state_dim, action_dim)
        # Start the target network as an exact copy of the online network.
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_start = epsilon
        self.min_epsilon = min_epsilon
        self.epsilon_decay_steps = epsilon_decay_steps
        # Amount subtracted from epsilon on every decay_epsilon() call.
        self.epsilon_decay = (self.epsilon_start - self.min_epsilon) / self.epsilon_decay_steps

    def select_action(self, state, action_space, train=True):
        """Pick an action epsilon-greedily (purely greedy when ``train`` is False)."""
        explore = train and random.random() < self.epsilon
        if explore:
            return random.choice(action_space)

        observation = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_net(observation)
        return int(torch.argmax(q_values, dim=-1).item())

    def decay_epsilon(self):
        """Lower epsilon by one linear step, never going below ``min_epsilon``."""
        if self.epsilon <= self.min_epsilon:
            return
        self.epsilon = max(self.epsilon - self.epsilon_decay, self.min_epsilon)

    def train_step(self, batch, n_step):
        """Run one gradient step on ``batch`` and return the raw TD errors.

        Args:
            batch: Iterable of ``(state, action, reward, next_state, done)`` tuples.
            n_step: Exponent applied to ``gamma`` for the bootstrapped term.

        Returns:
            1-D numpy array with the unclipped TD error of every transition.
        """
        state_col, action_col, reward_col, next_state_col, done_col = zip(*batch)

        def as_column(values, dtype):
            # Shape (batch, 1) so the column lines up with gather()/max(keepdim=True).
            return torch.from_numpy(np.array(values, dtype=dtype)).unsqueeze(1)

        states = torch.from_numpy(np.array(state_col, dtype=np.float32))
        next_states = torch.from_numpy(np.array(next_state_col, dtype=np.float32))
        actions = as_column(action_col, np.int64)
        rewards = as_column(reward_col, np.float32)
        dones = as_column(done_col, np.float32)

        predicted_q = self.q_net(states).gather(1, actions)
        with torch.no_grad():
            bootstrap_q = self.target_net(next_states).max(1, keepdim=True)[0]
            # (1 - dones) removes the bootstrap term for terminal transitions.
            target_q = rewards + (self.gamma ** n_step) * bootstrap_q * (1 - dones)

        td_error = target_q - predicted_q
        raw_td_error = td_error.detach().cpu().numpy().flatten()

        # Straight-through clipping: the forward value is clamp(td_error, -1, 1)
        # while the gradient flows through td_error unchanged.
        clipped = td_error + (td_error.clamp(-1, 1) - td_error).detach()
        loss = clipped.pow(2).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return raw_td_error

    def update_target(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.q_net.state_dict())


def collect_transitions(memory, env, behavior_name, action_space, replay_start_size=2000):
    """Pre-fill ``memory`` with transitions generated by uniformly random actions.

    Episodes are played back to back (one action per environment step) until
    ``memory`` holds ``replay_start_size`` transitions.

    NOTE(review): these are single-step transitions, whereas the training loop
    stores n-step returns and ``train_step`` discounts the bootstrap with
    ``gamma ** n_step`` — confirm whether that mismatch is intended.

    Args:
        memory: Buffer exposing ``push(state, action, reward, next_state, done)``
            and ``len()``.
        env: Environment object exposing ``reset``, ``get_steps``,
            ``set_actions`` and ``step``.
        behavior_name: Behavior whose single agent is controlled.
        action_space: Sequence of discrete action ids to draw from.
        replay_start_size: Number of transitions to have in ``memory`` on return.
    """
    separator = "-"*100
    while len(memory) < replay_start_size:
        env.reset()
        decision_steps, _ = env.get_steps(behavior_name)
        episode_over = False
        # Stop mid-episode as soon as the buffer reaches the requested size.
        while not episode_over and len(memory) < replay_start_size:
            state = decision_steps.obs[0][0]
            action = random.choice(action_space)
            env.set_actions(
                behavior_name,
                ActionTuple(discrete=np.array([[action]], dtype=np.int32)),
            )
            env.step()
            decision_steps, terminal_steps = env.get_steps(behavior_name)
            episode_over = len(terminal_steps) > 0
            # The terminal batch carries the final observation/reward when the episode ended.
            latest = terminal_steps if episode_over else decision_steps
            next_state = latest.obs[0][0]
            reward = latest.reward[0]
            memory.push(state, action, reward, next_state, episode_over)
    print(separator)
    print(f"[Info] {len(memory)} transitions collected")
    print(separator + "\n")


def collect_eval_states(env, behavior_name, action_space, num_eval_states=500):
    """Gather a fixed set of states visited under a uniformly random policy.

    Episodes are played back to back (one random action per environment step)
    and the observation seen before each action is recorded, until
    ``num_eval_states`` observations have been collected.

    Args:
        env: Environment object exposing ``reset``, ``get_steps``,
            ``set_actions`` and ``step``.
        behavior_name: Behavior whose single agent is controlled.
        action_space: Sequence of discrete action ids to draw from.
        num_eval_states: Number of states to collect.

    Returns:
        List with ``num_eval_states`` observations.
    """
    eval_states = []
    while len(eval_states) < num_eval_states:
        env.reset()
        decision_steps, _ = env.get_steps(behavior_name)
        done = False
        # Stop mid-episode as soon as enough states have been recorded.
        while not done and len(eval_states) < num_eval_states:
            eval_states.append(decision_steps.obs[0][0])
            action = random.choice(action_space)
            action_tuple = ActionTuple(discrete=np.array([[action]], dtype=np.int32))
            env.set_actions(behavior_name, action_tuple)
            env.step()
            decision_steps, terminal_steps = env.get_steps(behavior_name)
            # Only the termination flag matters here; the next observation is
            # read from decision_steps at the top of the loop.
            done = len(terminal_steps) > 0
    print("-"*100)
    print(f"[Info] {len(eval_states)} evaluation states collected")
    print("-"*100 + "\n")
    return eval_states


def compute_avg_max_q(agent, eval_states):
    """Return the mean, over ``eval_states``, of the largest Q-value per state.

    All states are evaluated in a single batched forward pass of
    ``agent.q_net`` instead of one pass per state.

    Args:
        agent: Object exposing a callable ``q_net`` that maps a float tensor
            of stacked states to Q-values.
        eval_states: Sequence of equally shaped state arrays.

    Returns:
        The average maximum Q-value as a Python float, or ``nan`` when
        ``eval_states`` is empty (there is nothing to average).
    """
    num_states = len(eval_states)
    if num_states == 0:
        return float("nan")
    with torch.no_grad():
        batch = torch.stack(
            [torch.as_tensor(state, dtype=torch.float32) for state in eval_states]
        )
        q_values = agent.q_net(batch)
        # Flatten everything but the batch axis so the max is taken over all
        # Q-values belonging to one state.
        per_state_max = q_values.reshape(num_states, -1).max(dim=1).values
    # Average in float64 to stay close to a Python-float running sum.
    return per_state_max.double().mean().item()


def plot_avg_max_q(episode_plot, avg_reward_plot, avg_max_q_plot, episode_count):
    """Save a two-panel training curve (average reward, average max Q) to ``plot.png``.

    Args:
        episode_plot: X values (episode numbers) shared by both panels.
        avg_reward_plot: Y values for the left panel.
        avg_max_q_plot: Y values for the right panel.
        episode_count: Accepted for the caller's convenience; not used here.
    """
    panels = (
        (avg_reward_plot, 'Average Reward', 'Average Reward per 100 Episode'),
        (avg_max_q_plot, 'Average Max Q', 'Average Max Q per 100 Episode'),
    )
    fig, axes = plt.subplots(1, 2, figsize=(15,5))
    for axis, (series, y_label, title) in zip(axes, panels):
        axis.plot(episode_plot, series)
        axis.set_xlabel('Episodes')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.grid(True)
    plt.tight_layout()
    plt.savefig("plot.png")
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def _n_step_return(transitions, gamma):
    """Return the discounted sum of the rewards in ``transitions`` (oldest first)."""
    return sum((gamma ** k) * step[2] for k, step in enumerate(transitions))


def _run_test_episode(env, behavior_name, agent, action_space, reward_limit):
    """Play one greedy (``train=False``) episode, choosing an action on every step.

    The episode stops when the environment reports a terminal step or as soon
    as the accumulated reward reaches ``reward_limit``.

    Returns:
        ``(total_reward, limit_reached)``.
    """
    env.reset()
    decision_steps, _ = env.get_steps(behavior_name)
    total_reward = 0
    while True:
        state = decision_steps.obs[0][0]
        action = agent.select_action(state, action_space, train=False)
        action_tuple = ActionTuple(discrete=np.array([[action]], dtype=np.int32))
        env.set_actions(behavior_name, action_tuple)
        env.step()
        decision_steps, terminal_steps = env.get_steps(behavior_name)
        terminated = len(terminal_steps) > 0
        latest = terminal_steps if terminated else decision_steps
        total_reward += latest.reward[0]
        limit_reached = total_reward >= reward_limit
        if terminated or limit_reached:
            return total_reward, limit_reached


if __name__ == "__main__":
    # Training entry point: n-step DQN with prioritized replay and frame
    # skipping against the Unity build. Runs until interrupted; the best
    # greedy-evaluation weights are written to "best_model.pth" and the
    # learning curves to "plot.png".
    channel = EngineConfigurationChannel()
    channel.set_configuration_parameters(time_scale=20.0)
    env = UnityEnvironment(file_name="a/PongBuild.x86_64", side_channels=[channel], no_graphics=True)
    try:
        env.reset()
        available_behaviors = list(env.behavior_specs.keys())
        behavior_name = available_behaviors[0]
        spec = env.behavior_specs[behavior_name]

        state_dim = spec.observation_specs[0].shape[0]
        action_dim = spec.action_spec.discrete_branches[0]
        agent = DQNAgent(state_dim, action_dim)
        action_space = list(range(action_dim))
        memory = ReplayMemory()
        collect_transitions(memory, env, behavior_name, action_space)
        eval_states = collect_eval_states(env, behavior_name, action_space)

        episode_count = 0
        total_steps = 0
        best_test_reward = -float('inf')

        # Each chosen action is repeated for up to `state_skip` environment steps.
        state_skip = 4
        target_update_step, next_target_update_step = 5000, 5000
        test_reward_limit, test_reward_limit_increment = 200, 200
        episode_count_interval = 100
        recent_rewards = deque(maxlen=episode_count_interval)

        episode_plot = []
        avg_reward_plot = []
        avg_max_q_plot = []

        n_step = 3
        n_step_buffer = deque(maxlen=n_step)

        while True:
            env.reset()
            decision_steps, _ = env.get_steps(behavior_name)
            done = False
            episode_reward = 0
            episode_count += 1
            # Set when the target network is refreshed during this episode; the
            # greedy evaluation then runs once the episode has finished.
            eval_pending = False
            n_step_buffer.clear()

            while not done:
                state = decision_steps.obs[0][0]
                action = agent.select_action(state, action_space)

                skip_reward = 0
                for _ in range(state_skip):
                    action_tuple = ActionTuple(discrete=np.array([[action]], dtype=np.int32))
                    env.set_actions(behavior_name, action_tuple)
                    env.step()
                    decision_steps, terminal_steps = env.get_steps(behavior_name)
                    done = len(terminal_steps) > 0
                    latest = terminal_steps if done else decision_steps
                    next_state = latest.obs[0][0]
                    skip_reward += latest.reward[0]
                    total_steps += 1
                    agent.decay_epsilon()
                    if done:
                        break

                n_step_buffer.append((state, action, skip_reward, next_state, done))
                if done:
                    # Flush every buffered start state exactly once, each with
                    # the (possibly shorter) return up to the terminal state.
                    # Handling this in its own branch avoids storing the oldest
                    # transition twice when the buffer is full at episode end.
                    while n_step_buffer:
                        R = _n_step_return(n_step_buffer, agent.gamma)
                        s, a, _, _, _ = n_step_buffer.popleft()
                        memory.push(s, a, R, next_state, True)
                elif len(n_step_buffer) == n_step:
                    R = _n_step_return(n_step_buffer, agent.gamma)
                    s, a, _, _, _ = n_step_buffer[0]
                    memory.push(s, a, R, next_state, False)

                episode_reward += skip_reward

                if len(memory) >= memory.batch_size:
                    batch = memory.sample()
                    delta_for_priorities = agent.train_step(batch, n_step)
                    memory.update_priorities(delta_for_priorities)

                if total_steps >= next_target_update_step:
                    agent.update_target()
                    next_target_update_step += target_update_step
                    # Evaluating here would reset the environment in the middle
                    # of the training episode and mix two episodes in the
                    # n-step buffer, so the evaluation is deferred.
                    eval_pending = True

            # Record every episode so the plotted value really is the mean over
            # the last `episode_count_interval` episodes.
            recent_rewards.append(episode_reward)

            if eval_pending:
                test_reward, limit_reached = _run_test_episode(
                    env, behavior_name, agent, action_space, test_reward_limit
                )
                if limit_reached:
                    # Raise the cap so later evaluations may run longer.
                    test_reward_limit += test_reward_limit_increment

                print("="*100)
                print(f"[Test] Episode {episode_count}: Test reward = {test_reward}")
                if test_reward > best_test_reward:
                    best_test_reward = test_reward
                    print(f"[Test] Episode {episode_count}: New best test reward = {best_test_reward}")
                    torch.save(agent.q_net.state_dict(), "best_model.pth")
                    print(f"[Test] Episode {episode_count}: Model saved as 'best_model.pth'")
                print("="*100 + "\n")

            if episode_count % episode_count_interval == 0:
                print("-"*100)
                print(f"[Info] Episode {episode_count} in progress, Step {total_steps}, Epsilon {agent.epsilon}")
                print("-"*100 + "\n")
                episode_plot.append(episode_count)
                avg_reward = sum(recent_rewards) / len(recent_rewards)
                avg_reward_plot.append(avg_reward)
                avg_max_q = compute_avg_max_q(agent, eval_states)
                avg_max_q_plot.append(avg_max_q)
                plot_avg_max_q(episode_plot, avg_reward_plot, avg_max_q_plot, episode_count)
    finally:
        # The loop above only ends through an interrupt or an error; always
        # shut the Unity environment down so it is not left running.
        env.close()

In [None]:
# Run the training script written by the %%writefile cell above. Its main
# loop is `while True`, so it keeps training until the cell is interrupted.
!python3.10 train1.py

Disconnect the runtime and reconnect.

In [None]:
# Upgrade the onnx package, which the export cell below imports.
!pip install --upgrade onnx

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx

class QNetwork(nn.Module):
    """Three-layer fully connected Q-network, redefined here so the checkpoint can be loaded.

    Args:
        state_dim: Size of the input feature vector.
        action_dim: Number of outputs (one per discrete action).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # The attribute names fc1/fc2/fc3 are the state_dict keys and must
        # match the ones used when the checkpoint was saved.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Apply two ReLU-activated hidden layers followed by a linear output layer."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

# Load the checkpoint first so the layer sizes are read from the saved tensors
# rather than hard-coded; a size mismatch would otherwise only surface as a
# load_state_dict error. nn.Linear stores its weight as
# (out_features, in_features).
state_dict = torch.load("best_model.pth", map_location="cpu", weights_only=True)
hidden_dim, state_dim = state_dict["fc1.weight"].shape
action_dim = state_dict["fc3.weight"].shape[0]

model = QNetwork(state_dim, action_dim, hidden_dim)
model.load_state_dict(state_dict)
# Switch to inference mode before tracing the model for export.
model.eval()

# Example input with a batch of one state; the exporter traces the model with it.
dummy_input = torch.randn(1, state_dim)

torch.onnx.export(
    model,
    dummy_input,
    "best_model.onnx",
    export_params=True,
    opset_version=15,
    do_constant_folding=True,
    input_names=['X'],
    output_names=['Y'],
    dynamo=False
)