In [231]:
from magent2.environments import battle_v4
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import deque,defaultdict
import random
import numpy as np
from dqn_model import DQN


## Replay Buffer

In [232]:

class ReplayBuffer:
    def __init__(self, capacity,observation_shape,action_shape):
        self.capacity = capacity
        self.observation_shape = observation_shape
        self.action_shape = action_shape
        self.buffers = defaultdict(lambda: {
            'observation': deque(maxlen=capacity),
            'action': deque(maxlen=capacity),
            'reward': deque(maxlen=capacity),
            'next_observation': deque(maxlen=capacity),
            'done': deque(maxlen=capacity),
        })

    def update_reward(self,agent_id,new_reward):
        if agent_id not in self.buffers:
            return
        self.buffers[agent_id]['reward'][-1] = new_reward
    def add(self,agent_id, observation, action, reward, next_observation, done):
        self.buffers[agent_id]['observation'].append(observation)
        self.buffers[agent_id]['action'].append(action)
        self.buffers[agent_id]['reward'].append(reward)
        self.buffers[agent_id]['next_observation'].append(next_observation)
        self.buffers[agent_id]['done'].append(done)
    def sample(self,batch_size):
        if len(self.buffers.keys()) == 0 or sum(len(agent['observation']) for agent in self.buffers.values()) < batch_size:
            return None
        transitions = []
        for agent_id in self.buffers.keys():
            agent = self.buffers[agent_id]
            for i in range(len(agent['observation'])):
                transition = {
                    'observation': agent['observation'][i],
                    'action': agent['action'][i],
                    'reward': agent['reward'][i],
                    'next_observation': agent['next_observation'][i],
                    'done': agent['done'][i],
                }
                transitions.append(transition)
        samples_index = np.random.choice(len(transitions),batch_size,replace=False)
        return {
            'observation': np.array([transitions[i]['observation'] for i in samples_index]),
            'action' :np.array([transitions[i]['action'] for i in samples_index]),
            'reward' : np.array([transitions[i]['reward'] for i in samples_index]),
            'next_observation' : np.array([transitions[i]['next_observation'] for i in samples_index]),
            'done' : np.array([transitions[i]['done'] for i in samples_index]),
        }
    def clear(self, agent_id=None):
        if agent_id:
            self.buffers[agent_id]['observation'].clear()
            self.buffers[agent_id]['action'].clear()
            self.buffers[agent_id]['reward'].clear()
            self.buffers[agent_id]['next_observation'].clear()
            self.buffers[agent_id]['done'].clear()
        else:
            for agent_id in self.buffers:
                self.clear(agent_id)

## Khởi tạo môi trường

In [233]:
env = battle_v4.env(map_size=45, render_mode="rgb_array",max_cycles=300,attack_opponent_reward=3)
observation_shape = env.observation_space("blue_0").shape
action_shape = env.action_space("blue_0").n

BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize networks
policy_net = DQN(observation_shape, action_shape).to(device)
red_policy = DQN(observation_shape, action_shape).to(device)
target_net = DQN(observation_shape, action_shape).to(device)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)

# try:
#     checkpoint = torch.load("models/blue.pt", map_location=device, weights_only=True)
#     policy_net.load_state_dict(checkpoint["policy_net_state_dict"])
#     target_net.load_state_dict(checkpoint["target_net_state_dict"])
#     red_policy.load_state_dict(checkpoint["policy_net_state_dict"])
#     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
#     episode = checkpoint["episode"]
#     print(f"Start with episode: {episode}")
# except Exception as e:
#     print(f"No model found!")
episode = 0

# pretrained
from torch_model import QNetwork
pretrained = QNetwork(
    env.observation_space("red_0").shape, env.action_space("red_0").n
)
pretrained.load_state_dict(
    torch.load("red.pt", weights_only=True, map_location=device)
)

# trained
from final_torch_model import QNetwork as FinalQNetwork
trained = FinalQNetwork(
    env.observation_space("red_0").shape, env.action_space("red_0").n
)
trained.load_state_dict(
    torch.load("red_final.pt", weights_only=True, map_location=device)
)


buffer = ReplayBuffer(10000, observation_shape, action_shape)
steps_done = episode
episode_rewards = []
episode_losses = []
running_loss = 0.0
num_episodes = 60

def plot_durations(episode_rewards,episode_losses, show_result=False):
    plt.figure(1)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    rewards_t = torch.tensor(episode_rewards, dtype=torch.float)
    losses_t = torch.tensor(episode_losses, dtype=torch.float)

    plt.plot(rewards_t.numpy(), label='Reward')
    plt.plot(losses_t.numpy(), label='Loss')

    if len(rewards_t) >= 5:
        rewards_means = rewards_t.unfold(0, 5, 1).mean(1).view(-1)
        rewards_means = torch.cat((torch.zeros(4), rewards_means))
        plt.plot(rewards_means.numpy(), label='Reward (mean)')

    if len(losses_t) >= 5:
        losses_means = losses_t.unfold(0, 5, 1).mean(1).view(-1)
        losses_means = torch.cat((torch.zeros(4), losses_means))
        plt.plot(losses_means.numpy(), label='Loss (mean)')
    plt.legend()
    plt.pause(0.001)
    if not show_result:
        display.display(plt.gcf())
        display.clear_output(wait=True)
    else:
        display.display(plt.gcf())
def linear_epsilon(steps_done):
    return max(EPS_END, EPS_START - (EPS_START - EPS_END) * (steps_done / EPS_DECAY))

def policy(observation, q_network):
    global steps_done
    sample = random.random()
    if sample < linear_epsilon(steps_done):
        return env.action_space("blue_0").sample()
    else:
        observation = (
            torch.Tensor(observation).to(device)
        )
        with torch.no_grad():
            q_values = q_network(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

In [234]:
def save_model(i_episode, policy_net, target_net, optimizer, episode_rewards, episode_losses, path):
    torch.save({
        'episode': i_episode,
        'policy_net_state_dict': policy_net.state_dict(),
        'target_net_state_dict': target_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'episode_rewards': episode_rewards,
        'episode_losses': episode_losses,
    }, path)

## Training Loop

In [235]:
def optimize_model():
    global running_loss

    batch = buffer.sample(BATCH_SIZE)

    if batch is None:
        return

    state_batch = torch.from_numpy(batch['observation']).float().to(device)
    action_batch = torch.from_numpy(batch['action']).long().to(device)
    reward_batch = torch.from_numpy(batch['reward']).float().to(device)
    next_state_batch = torch.from_numpy(batch['next_observation']).float().to(device)
    done_batch = torch.from_numpy(batch['done']).float().to(device)

    # Reshape action_batch to (BATCH_SIZE, 1) for gather()
    action_batch = action_batch.unsqueeze(1)
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    non_final_mask = (done_batch == 0).squeeze()  # Create a mask for non-terminal states

    # Only compute for non-terminal states
    if non_final_mask.any():
        next_state_values[non_final_mask] = target_net(next_state_batch[non_final_mask]).max(1).values.detach()

    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    running_loss += loss.item()

    return loss.item()

### Training with pretrained policy red

In [None]:
for i_episode in range(episode, num_episodes):
    
    env.reset()
    episode_reward = 0
    running_loss = 0.0
    steps_done += 1

    for agent in env.agent_iter():

        observation, reward, termination, truncation, info = env.last()
        done = termination or truncation
        episode_reward += reward

        if done:
            action = None  # Agent is dead
            env.step(action)
        else:
            agent_handle = agent.split("_")
            agent_id = agent_handle[1]
            agent_team = agent_handle[0]
            if agent_team == "blue":

                buffer.update_reward(agent_id, reward) # update reward of last agent's action (bad environment!)

                action = policy(observation, policy_net)
                env.step(action)

                try:
                    next_observation = env.observe(agent)
                    agent_done = False
                except:
                    next_observation = None
                    agent_done = True

                reward = 0 # Wait for next time to be selected to get reward

                # Store the transition in buffer
                buffer.add(agent_id, observation, action, reward, next_observation, agent_done)

                # Perform one step of the optimization (on the policy network)
                optimize_model()

                # Soft update of the target network's weights
                # θ′ ← τ θ + (1 −τ )θ′
                target_net_state_dict = target_net.state_dict()
                policy_net_state_dict = policy_net.state_dict()
                for key in policy_net_state_dict:
                    target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
                target_net.load_state_dict(target_net_state_dict)

            else:
                # red agent
                action = policy(observation, red_policy)
                env.step(action)

        # Periodically update the red agent's policy with the blue agent's learned policy
        if i_episode % 4 == 0 and i_episode < 24:
            # Copy all weights and biases from the blue agent's policy network to the red agent's
            red_policy.load_state_dict(policy_net.state_dict())
        elif i_episode == 24: # more complex (pretrained) opponent
            red_policy.load_state_dict(pretrained.state_dict())



    # Add these lines at the end of each episode
    episode_rewards.append(episode_reward)
    episode_losses.append(running_loss)

    print(f'Episode {i_episode + 1}/{num_episodes}')
    print(f'Total Reward of previous episode: {episode_reward:.2f}')
    print(f'Average Loss: {running_loss:.4f}')
    print(f'Epsilon: {linear_epsilon(steps_done)}')
    print('-' * 40)
    save_model(i_episode, policy_net, target_net, optimizer, episode_rewards, episode_losses, path=f"models1/blue_{i_episode}.pt")

plot_durations(episode_rewards, episode_losses, show_result=True)
plt.ioff()
plt.show()

Episode 1/60
Total Reward of previous episode: -72.05
Average Loss: 765.7060
Epsilon: 0.89915
----------------------------------------
Episode 2/60
Total Reward of previous episode: 112.45
Average Loss: 1862.0758
Epsilon: 0.8983
----------------------------------------


### Training with trained policy red

In [224]:
for i_episode in range(episode, num_episodes):
    
    env.reset()
    episode_reward = 0
    running_loss = 0.0
    steps_done += 1

    for agent in env.agent_iter():

        observation, reward, termination, truncation, info = env.last()
        done = termination or truncation
        episode_reward += reward

        if done:
            action = None  # Agent is dead
            env.step(action)
        else:
            agent_handle = agent.split("_")
            agent_id = agent_handle[1]
            agent_team = agent_handle[0]
            if agent_team == "blue":

                buffer.update_reward(agent_id, reward)

                action = policy(observation, policy_net)
                env.step(action)

                try:
                    next_observation = env.observe(agent)
                    agent_done = False
                except:
                    next_observation = None
                    agent_done = True

                reward = 0 

                # Store the transition in buffer
                buffer.add(agent_id, observation, action, reward, next_observation, agent_done)

                # Perform one step of the optimization (on the policy network)
                optimize_model()

                # Soft update of the target network's weights
                # θ′ ← τ θ + (1 −τ )θ′
                target_net_state_dict = target_net.state_dict()
                policy_net_state_dict = policy_net.state_dict()
                for key in policy_net_state_dict:
                    target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
                target_net.load_state_dict(target_net_state_dict)

            elif agent_team == "red":
                action = policy(observation, red_policy)
                env.step(action)

        if i_episode % 4 == 0 and i_episode < 24:
            red_policy.load_state_dict(policy_net.state_dict())
        elif i_episode == 24: 
            red_policy.load_state_dict(trained.state_dict())

    episode_rewards.append(episode_reward)
    episode_losses.append(running_loss)

    print(f'Episode {i_episode + 1}/{num_episodes}')
    print(f'Total Reward of previous episode: {episode_reward:.2f}')
    print(f'Average Loss: {running_loss:.4f}')
    print(f'Epsilon: {linear_epsilon(steps_done)}')
    print('-' * 40)
    save_model(i_episode, policy_net, target_net, optimizer, episode_rewards, episode_losses, path=f"models2/blue_{i_episode}.pt")

plot_durations(episode_rewards, episode_losses, show_result=True)
plt.ioff()
plt.show()

KeyboardInterrupt: 