In [1]:
from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode
from typing import Dict
from argparse import Namespace
import jax
from lux.utils import direction_to
from luxai_s3.params import EnvParams
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import namedtuple, defaultdict

# from lux.config import EnvConfig
from lux.kit import from_json
env = LuxAIS3GymEnv()




In [5]:
## QMIX Neural Network
## 매우 제한적인 모델, 수정 필요
## 인풋으로 자기 자신의 observation만 받고, history를 받지도 않는다

class QMixNet(nn.Module):
    def __init__(self, num_agents, state_dim, action_dim=6, hidden_dim=32):
        super(QMixNet, self).__init__()
        self.num_agents = num_agents
        self.state_dim = state_dim
        self.action_dim = action_dim

        # State-dependent weights for mixing
        self.hyper_w1 = nn.LazyLinear(action_dim * hidden_dim)
        self.hyper_w2 = nn.LazyLinear(hidden_dim)
        self.hyper_w3 = nn.LazyLinear(1)
        self.hyper_b1 = nn.LazyLinear(num_agents * hidden_dim)
        self.hyper_b2 = nn.LazyLinear(num_agents)

        self.elu = nn.ELU()

    def forward(self, agent_qs, state_inputs):
        batch_size = agent_qs.size(0)
        
        # Flatten state inputs for the hypernetworks
        state_inputs = state_inputs.view(batch_size, -1)

        # Compute weights and biases
        w1 = self.hyper_w1(state_inputs).view(batch_size, self.action_dim, -1)
        w2 = self.hyper_w2(state_inputs).view(batch_size, -1, 1)
        b1 = self.hyper_b1(state_inputs).view(batch_size, self.num_agents, -1)
        b2 = self.hyper_b2(state_inputs).view(batch_size, self.num_agents, 1)

        # Mixing process
        hidden = self.elu(torch.bmm(agent_qs, w1) + b1)
        
        q_total = torch.bmm(hidden, w2) + b2
        q_total = self.hyper_w3(q_total.squeeze())

        return q_total

# Replay Buffer
Transition = namedtuple("Transition", ("state", "actions", "rewards", "next_state", "dones"))

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, *args):
        if len(self.memory) >= self.capacity:
            self.memory.pop(0)
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        indices = np.random.choice(len(self.memory), batch_size, replace=False)
        batch = [self.memory[idx] for idx in indices]
        return Transition(*zip(*batch))

    def __len__(self):
        return len(self.memory)

In [29]:
# Training Loop
# 초기값 및 모델 하이퍼파라미터 세팅

num_agents = 16
np.random.seed(16)
env = LuxAIS3GymEnv()
env = RecordEpisode(env, save_dir="episodes")
env_params = EnvParams(map_type=1, max_steps_in_match=100)
state_dim = 2
action_dim = 6
mixing_dim = 32
N = env_params.max_steps_in_match * env_params.match_count_per_episode

q_networks = [nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, action_dim)) for _ in range(num_agents)]
optimizers = [optim.Adam(q.parameters(), lr=0.001) for q in q_networks]
mixing_network = QMixNet(num_agents, state_dim, action_dim, mixing_dim)
mixing_optimizer = optim.Adam(mixing_network.parameters(), lr=0.001)

buffer = ReplayBuffer(capacity=3000)
batch_size = 32
epsilon = 0.08
num_episodes = 100

def select_action(q_values, epsilon):
    if np.random.rand() < epsilon:
        return np.random.randint(q_values.size(-1))
    return [torch.argmax(q_values).item(),0,0]

In [30]:
# training process
# 해당 셀을 실행하면 학습이 진행됩니다.

for episode in range(num_episodes):
    states, info = env.reset(seed=1, options=dict(params=env_params))
    done = False
    episode_transitions = []
    i = 0

    while not done or i < N:
        # 각 플레이어의 행동, q value, observation을 담는 object의 지정
        actions = np.zeros((num_agents, 3), dtype=int)
        opp_actions = np.zeros((num_agents, 3), dtype=int)
        q_values = {}
        opp_q_values = {}
        state = states['player_0'].units.position[0]
        opp_state = states['player_1'].units.position[0]
        
        # 플레이어 0의 에이전트 별 q value 계산 및 decentralized action 생성
        for agent_id, q_network in zip(range(num_agents), q_networks):
            obs = torch.tensor(np.array(state[agent_id]), dtype=torch.float32).unsqueeze(0)
            q_values[agent_id] = q_network(obs)
            actions[agent_id] = select_action(q_values[agent_id], epsilon)

        # self-play 기반 opponent action 생성
        # 위 for loop와 동일
        with torch.no_grad():
                for agent_id, q_network in zip(range(num_agents), q_networks):
                    obs_opp = torch.tensor(np.array(opp_state[agent_id]), dtype=torch.float32).unsqueeze(0)
                    opp_q_values[agent_id] = q_network(obs_opp)
                    opp_actions[agent_id] = select_action(opp_q_values[agent_id], epsilon)

        #opp_actions = env.action_space.sample()['player_1'] # random action opponents로 학습도 가능

        act = {'player_0': actions, 'player_1': opp_actions}

        # 버퍼에는 내 플레이어의 정보들만 저장
        next_state, rewards, terminated, truncated, info = env.step(act)
        dones = terminated['player_0'] + truncated['player_0']
        buffer.push(state, actions, rewards['player_0'], next_state['player_0'].units.position[0], dones)
        state = next_state
        done = dones
        i += 1

    # optimizing 하는 루프
    if len(buffer) >= batch_size:
        batch = buffer.sample(batch_size)
        state_batch = torch.tensor(np.array([s for s in batch.state]), dtype=torch.float32)
        action_batch = torch.tensor(np.array([a for a in batch.actions]), dtype=torch.int64)
        reward_batch = torch.tensor(np.array([r for r in batch.rewards]), dtype=torch.float32)
        next_state_batch = torch.tensor(np.array([ns for ns in batch.next_state]), dtype=torch.float32)
        done_batch = torch.tensor(np.array([d for d in batch.dones]), dtype=torch.float32)

        # Calculate individual Q-values
        agent_qs = []
        for i, q_network in enumerate(q_networks):
            agent_qs.append(q_network(state_batch[:, i, :]))
        agent_qs = torch.stack(agent_qs, dim=1)

        # Calculate total Q-value using mixing network
        state_inputs = state_batch.view(batch_size, -1)
        q_total = mixing_network(agent_qs, state_inputs)

        # Calculate individual next Q-values
        agent_qs_next = []
        for i, q_network in enumerate(q_networks):
            agent_qs_next.append(q_network(next_state_batch[:, i, :]))
        agent_qs_next = torch.stack(agent_qs_next, dim=1)

        # Calculate total next Q-value using mixing network
        next_state_inputs = next_state_batch.view(batch_size, -1)
        with torch.no_grad():
            q_total_next = mixing_network(agent_qs_next, next_state_inputs)

        # Compute loss and update networks
        loss = torch.mean((reward_batch + q_total_next - q_total) ** 2)

        for optimizer in optimizers:
            optimizer.zero_grad()
        mixing_optimizer.zero_grad()
        loss.backward()

        for optimizer in optimizers:
            optimizer.step()
        mixing_optimizer.step()

    if episode % 1 == 0:
        print(f"Episode {episode}, Loss: {loss.item():.4f}. my_team_wins: {rewards['player_0']}")


Episode 0, Loss: 725.9942. my_team_wins: 4
Episode 1, Loss: 666.0987. my_team_wins: 3
Episode 2, Loss: 1397.9777. my_team_wins: 3
Episode 3, Loss: 554.3622. my_team_wins: 1
Episode 4, Loss: 333.6446. my_team_wins: 1
Episode 5, Loss: 474.9640. my_team_wins: 3
Episode 6, Loss: 925.1939. my_team_wins: 4
Episode 7, Loss: 943.4852. my_team_wins: 2
Episode 8, Loss: 801.5966. my_team_wins: 3
Episode 9, Loss: 138.7711. my_team_wins: 1
Episode 10, Loss: 304.3309. my_team_wins: 3
Episode 11, Loss: 494.7083. my_team_wins: 2
Episode 12, Loss: 221.8595. my_team_wins: 0
Episode 13, Loss: 168.3093. my_team_wins: 4
Episode 14, Loss: 197.1183. my_team_wins: 4
Episode 15, Loss: 175.3127. my_team_wins: 4
Episode 16, Loss: 365.1474. my_team_wins: 3
Episode 17, Loss: 334.0553. my_team_wins: 3
Episode 18, Loss: 241.1416. my_team_wins: 0
Episode 19, Loss: 562.9294. my_team_wins: 3
Episode 20, Loss: 193.5354. my_team_wins: 3
Episode 21, Loss: 504.6407. my_team_wins: 2
Episode 22, Loss: 432.3690. my_team_wins:

In [31]:
# 수동 저장
# 체크포인트 생성 및 모델의 자동 업데이트 기능은 아직 없습니다
q_net = nn.ModuleList(q_networks)
torch.save(q_net, "q_net.pt")