In [6]:
from magent2.environments import battle_v4
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt
from collections import deque,defaultdict
import random
import numpy as np
from dqn_model import DQN


## Replay Buffer

In [7]:
class ReplayBuffer:
    def __init__(self, capacity,observation_shape,action_shape):
        self.capacity = capacity
        self.observation_shape = observation_shape
        self.action_shape = action_shape
        self.buffers = defaultdict(lambda: {
            'observation': deque(maxlen=capacity),
            'action': deque(maxlen=capacity),
            'reward': deque(maxlen=capacity),
            'next_observation': deque(maxlen=capacity),
            'done': deque(maxlen=capacity),
        })

    def update_reward(self,agent_id,new_reward):
        if agent_id not in self.buffers:
            return
        self.buffers[agent_id]['reward'][-1] = new_reward
        
    def add(self,agent_id, observation, action, reward, next_observation, done):
        self.buffers[agent_id]['observation'].append(observation)
        self.buffers[agent_id]['action'].append(action)
        self.buffers[agent_id]['reward'].append(reward)
        self.buffers[agent_id]['next_observation'].append(next_observation)
        self.buffers[agent_id]['done'].append(done)
        
    def sample(self,batch_size):
        if len(self.buffers.keys()) == 0 or sum(len(agent['observation']) for agent in self.buffers.values()) < batch_size:
            return None
        transitions = []
        for agent_id in self.buffers.keys():
            agent = self.buffers[agent_id]
            for i in range(len(agent['observation'])):
                transition = {
                    'observation': agent['observation'][i],
                    'action': agent['action'][i],
                    'reward': agent['reward'][i],
                    'next_observation': agent['next_observation'][i],
                    'done': agent['done'][i],
                }
                transitions.append(transition)
        samples_index = np.random.choice(len(transitions),batch_size,replace=False)
        return {
            'observation': np.array([transitions[i]['observation'] for i in samples_index]),
            'action' :np.array([transitions[i]['action'] for i in samples_index]),
            'reward' : np.array([transitions[i]['reward'] for i in samples_index]),
            'next_observation' : np.array([transitions[i]['next_observation'] for i in samples_index]),
            'done' : np.array([transitions[i]['done'] for i in samples_index]),
        }
    def clear(self, agent_id=None):
        if agent_id:
            self.buffers[agent_id]['observation'].clear()
            self.buffers[agent_id]['action'].clear()
            self.buffers[agent_id]['reward'].clear()
            self.buffers[agent_id]['next_observation'].clear()
            self.buffers[agent_id]['done'].clear()
        else:
            for agent_id in self.buffers:
                self.clear(agent_id)

## Khởi tạo môi trường

In [8]:
env = battle_v4.env(map_size=45, render_mode="rgb_array",max_cycles=300,attack_opponent_reward=1)

observation_shape = env.observation_space("blue_0").shape
action_shape = env.action_space("blue_0").n

policy_net = DQN(observation_shape, action_shape).to("cuda")

red_policy = DQN(observation_shape, action_shape).to("cuda")

target_net = DQN(observation_shape, action_shape).to("cuda")

target_net.load_state_dict(policy_net.state_dict())

BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4


optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)

episode = 0

# pretrained model
from torch_model import QNetwork
pretrained = QNetwork(
    env.observation_space("red_0").shape, env.action_space("red_0").n
)
pretrained.load_state_dict(
    torch.load("red.pt", weights_only=True, map_location="cuda")
)

# trained model
from final_torch_model import QNetwork as FinalQNetwork
trained = FinalQNetwork(
    env.observation_space("red_0").shape, env.action_space("red_0").n
)
trained.load_state_dict(
    torch.load("red_final.pt", weights_only=True, map_location="cuda")
)


buffer = ReplayBuffer(10000, observation_shape, action_shape)
steps_done = episode
episode_rewards = []
episode_losses = []
running_loss = 0.0
num_episodes = 60

def linear_epsilon(steps_done):
    return max(EPS_END, EPS_START - (EPS_START - EPS_END) * (steps_done / EPS_DECAY))

def policy(observation, q_network):
    global steps_done
    sample = random.random()
    if sample < linear_epsilon(steps_done):
        return env.action_space("blue_0").sample()
    else:
        observation = (
            torch.Tensor(observation).to("cuda")
        )
        with torch.no_grad():
            q_values = q_network(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

In [9]:
def save_model(i_episode, policy_net, target_net, optimizer, episode_rewards, episode_losses, path):
    torch.save({
        'episode': i_episode,
        'policy_net_state_dict': policy_net.state_dict(),
        'target_net_state_dict': target_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'episode_rewards': episode_rewards,
        'episode_losses': episode_losses,
    }, path)

## Training Loop

In [10]:
def optimize_model():
    global running_loss

    batch = buffer.sample(BATCH_SIZE)
    if batch is None:
        return

    state_batch = torch.from_numpy(batch['observation']).float().to("cuda")
    action_batch = torch.from_numpy(batch['action']).long().to("cuda")
    reward_batch = torch.from_numpy(batch['reward']).float().to("cuda")
    next_state_batch = torch.from_numpy(batch['next_observation']).float().to("cuda")
    done_batch = torch.from_numpy(batch['done']).float().to("cuda")

    action_batch = action_batch.unsqueeze(1)
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    next_state_values = torch.zeros(BATCH_SIZE, device="cuda")
    non_final_mask = (done_batch == 0).squeeze()
    
    if non_final_mask.any():
        next_state_values[non_final_mask] = target_net(next_state_batch[non_final_mask]).max(1).values.detach()

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    running_loss += loss.item()

    return loss.item()

In [11]:

for i_episode in range(episode, num_episodes):
    env.reset()
    episode_reward = 0
    running_loss = 0.0
    steps_done += 1

    for agent in env.agent_iter():

        observation, reward, termination, truncation, info = env.last()
        done = termination or truncation
        episode_reward += reward

        if done:
            action = None
            env.step(action)
        else:
            agent_handle = agent.split("_")
            agent_id = agent_handle[1]
            agent_team = agent_handle[0]
            if agent_team == "blue":

                buffer.update_reward(agent_id, reward)

                action = policy(observation, policy_net)
                env.step(action)

                try:
                    next_observation = env.observe(agent)
                    agent_done = False
                except:
                    next_observation = None
                    agent_done = True
                reward = 0 
                buffer.add(agent_id, observation, action, reward, next_observation, agent_done)

                optimize_model()

                # Soft update of the target network's weights
                # θ′ ← τ θ + (1 −τ )θ′
                target_net_state_dict = target_net.state_dict()
                policy_net_state_dict = policy_net.state_dict()
                for key in policy_net_state_dict:
                    target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
                target_net.load_state_dict(target_net_state_dict)

            elif agent_team == "red":
                action = env.action_space(agent).sample()
                env.step(action)

        if (i_episode + 1) % 3 == 0:
            red_policy.load_state_dict(policy_net.state_dict())

    episode_rewards.append(episode_reward)
    episode_losses.append(running_loss)

    print(f'Episode {i_episode + 1}/{num_episodes}')
    print(f'Total Reward of previous episode: {episode_reward:.2f}')
    print(f'Average Loss: {running_loss:.4f}')
    print(f'Epsilon: {linear_epsilon(steps_done)}')
    print('-' * 40)
    save_model(i_episode, policy_net, target_net, optimizer, episode_rewards, episode_losses, path=f"models/blue_{i_episode}.pt")

plt.ioff()
plt.show()

Episode 1/60
Total Reward of previous episode: -1314.07
Average Loss: 493.7680
Epsilon: 0.89915
----------------------------------------
Episode 2/60
Total Reward of previous episode: -1476.27
Average Loss: 1326.9272
Epsilon: 0.8983
----------------------------------------
Episode 3/60
Total Reward of previous episode: -1431.66
Average Loss: 965.0810
Epsilon: 0.89745
----------------------------------------
Episode 4/60
Total Reward of previous episode: -1397.04
Average Loss: 746.9168
Epsilon: 0.8966000000000001
----------------------------------------
Episode 5/60
Total Reward of previous episode: -1353.14
Average Loss: 841.8537
Epsilon: 0.89575
----------------------------------------


KeyboardInterrupt: 