In [1]:
from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode
from typing import Dict
from argparse import Namespace
import jax
from lux.utils import direction_to
from luxai_s3.params import EnvParams
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import namedtuple, defaultdict
import torch.nn.functional as F
from gen_data import gen_data
from agent import Agent

# from lux.config import EnvConfig
from lux.kit import from_json
env = LuxAIS3GymEnv()



In [2]:
# Replay Buffer
Transition = namedtuple("Transition", ("state", "actions", "rewards", "next_state", "dones"))

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, *args):
        if len(self.memory) >= self.capacity:
            self.memory.pop(0)
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        indices = np.random.choice(len(self.memory), batch_size, replace=False)
        batch = [self.memory[idx] for idx in indices]
        return Transition(*zip(*batch))

    def __len__(self):
        return len(self.memory)

In [3]:
# 에이전트 별 q-network
class AgentQNetwork(nn.Module):
    def __init__(self, obs_shape=(5, 24, 24), action_space=(6, 24, 24)):
        super(AgentQNetwork, self).__init__()
        self.action_space = action_space
        self.conv1 = nn.Conv2d(obs_shape[0], 16, kernel_size=3, stride=1, padding=1) 
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.LazyLinear(128)
        self.fc2 = nn.LazyLinear(action_space[0] * action_space[1] * action_space[2])

    def forward(self, obs): # Input dimension -> (bs, 5, 24, 24)
        if len(obs.shape) < 4:
            obs = obs.unsqueeze(0) # 1개 짜리 input을 받은 경우 (1,5,24,24)로 변환
        x = F.relu(self.conv1(obs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = nn.Flatten()(x)  # Flatten for the fully connected layers
        x = F.relu(self.fc1(x))
        q_values = self.fc2(x).view((obs.shape)[0], *self.action_space) 
        return q_values # Out dimension -> (bs, 6, 24, 24).

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixerNetwork(nn.Module):
    def __init__(self, n_agents=16, hidden_dim=64):
        super(MixerNetwork, self).__init__()
        self.n_agents = n_agents

        # CNN으로 상태 정보 처리
        self.state_encoder = nn.Sequential(
            nn.Conv2d(5, 32, kernel_size=3, stride=1, padding=1),  # (5, 24, 24) -> (32, 24, 24)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # (32, 24, 24) -> (64, 12, 12)
            nn.ReLU(),
            nn.Flatten(),  # (64, 12, 12) -> (64 * 12 * 12)
        )
        # Flatten된 state 차원 계산
        self.flattened_state_dim = 64 * 12 * 12

        # Hypernetwork for producing weights and biases
        self.hyper_w_1 = nn.Linear(self.flattened_state_dim, n_agents * hidden_dim)
        self.hyper_b_1 = nn.Linear(self.flattened_state_dim, hidden_dim)

        self.hyper_w_2 = nn.Linear(self.flattened_state_dim, hidden_dim)
        self.hyper_b_2 = nn.Sequential(
            nn.Linear(self.flattened_state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Non-linear transformation
        self.elu = nn.ELU()

    def forward(self, agent_qs, states):
        """
        Args:
            agent_qs (torch.Tensor): 각 에이전트의 Q-values. (batch_size, n_agents)
            states (torch.Tensor): 전역 상태 정보. (batch_size, 5, 24, 24)
        Returns:
            torch.Tensor: 팀의 전체 Q-value. (batch_size, 1)
        """
        batch_size = agent_qs.size(0)

        # State encoding
        encoded_state = self.state_encoder(states)  # (batch_size, flattened_state_dim)

        # Hypernetwork 1: Generate weights and biases for the first layer
        w1 = self.hyper_w_1(encoded_state)  # (batch_size, n_agents * hidden_dim)
        w1 = w1.view(batch_size, self.n_agents, -1)  # (batch_size, n_agents, hidden_dim)
        b1 = self.hyper_b_1(encoded_state).unsqueeze(1)  # (batch_size, 1, hidden_dim)

        # First layer: Multiply agent Q-values with generated weights
        hidden = torch.bmm(agent_qs.unsqueeze(1), w1).squeeze(1) + b1  # (batch_size, hidden_dim)
        hidden = self.elu(hidden)

        # Hypernetwork 2: Generate weights and biases for the second layer
        w2 = self.hyper_w_2(encoded_state).unsqueeze(-1)  # (batch_size, hidden_dim, 1)
        b2 = self.hyper_b_2(encoded_state)  # (batch_size, 1)

        # Second layer: Combine hidden layer outputs
        team_q = torch.bmm(hidden, w2).squeeze(1) + b2  # (batch_size, 1)

        return team_q

In [5]:
def optimal_action_from_qval(q_value): # sinlge agent의 batched q value(bs, 6,24,24)를 list로 받음
    batch_size = len(q_value)
    max_q, max_id = torch.max(q_value.view(batch_size, -1), dim =1)
    unraveled_idx = torch.tensor([torch.unravel_index(idx, q_value.shape[1:]) for idx in max_id])
    q_values_batched = [max_q,unraveled_idx]
    return q_values_batched # output -> list, q_values_batched = [(bs,1), (bs,3)], 첫 째는 optimal q value, 둘 째는 해당하는 action 자체

In [6]:
def decentralized_action(num_agent, obs, player_id, model, device, eps = 0.05): # 에이전트 별 모델을 list로 받아냄
    num_agent = 16 
    action = []
    imaged_obs = torch.tensor(gen_data(obs)[player_id], dtype=torch.float, device= device)
    with torch.no_grad():
        for agent_id ,q_network in zip(range(num_agent), model):
            q_network.eval()
            q_obs = q_network(imaged_obs)
            optimal_q = optimal_action_from_qval(q_obs)

            if np.random.rand() < eps:
                rand_act = torch.tensor([np.random.randint(6),np.random.randint(24),np.random.randint(24)])
                rand_act = rand_act.repeat(q_obs.size(0),1)
                action.append(rand_act)
            else:
                action.append(optimal_q[1])
    return action # action을 담은 list

In [7]:
# Envobs observation-> dict 으로 변환하는 함수
from dataclasses import asdict
def to_dict(obs):
    obs_0= asdict(obs['player_0'])
    obs_1= asdict(obs['player_1'])
    return {'player_0': obs_0, 'player_1': obs_1}

In [8]:
# Training Loop
# 초기값 및 모델 하이퍼파라미터 세팅
num_agents = 16
np.random.seed(16)
env = LuxAIS3GymEnv()
env = RecordEpisode(env, save_dir="episodes")
env_params = EnvParams(map_type=1, max_steps_in_match=100)
N = env_params.max_steps_in_match * env_params.match_count_per_episode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

q_networks = [AgentQNetwork().to(device) for _ in range(num_agents)]
optimizers = [optim.Adam(q.parameters(), lr=0.0005) for q in q_networks]
mixing_network = MixerNetwork(hidden_dim=64).to(device)
mixing_optimizer = optim.Adam(mixing_network.parameters(), lr=0.001)

buffer = ReplayBuffer(capacity=5000)
batch_size = 128
epsilon = 0.05
num_episodes = 125
GAMMA = 0.90

In [9]:
# training process
# 해당 셀을 실행하면 학습이 진행됩니다.
import time
self_play = False

for episode in range(num_episodes):
    states, info = env.reset(seed=np.random.seed(16), options=dict(params=env_params))
    done = False
    episode_transitions = []
    i = 0
    step = 0
    prev_reward = 0
    sparse_prev_reward = 0
    team_pt = []
    opp_agent = Agent('player_1', info['params'])

    while not done and i < N:

        i += 1
        actions = np.array(decentralized_action(16, states, 'player_0',  q_networks, device, epsilon)).reshape(16,3)

        if self_play:
        # self-play 기반 opponent action 생성
            opp_actions = np.array(decentralized_action(16, states, 'player_1',  q_networks, device, epsilon)).reshape(16,3)
        else:
            opp_actions = opp_agent.act(step =step , obs= to_dict(states))
                
        #opp_actions = env.action_space.sample()['player_1'] # random action opponents로 학습도 가능
        act = {'player_0': actions, 'player_1': opp_actions}

        next_state, rewards, terminated, truncated, info = env.step(act)
        dones = terminated['player_0'] + truncated['player_0']

        # Rewards engineering
        dense_rewards = next_state['player_0'].team_points[0] - prev_reward 
        sparse_rewards = (rewards['player_0'] - sparse_prev_reward) * 100

        buffer.push(gen_data(states)['player_0'], actions, dense_rewards, gen_data(next_state)['player_0'], dones) # 버퍼에는 내 플레이어의 정보들만 저장

        states = next_state
        done = dones
        prev_reward = dense_rewards
        sparse_prev_reward = sparse_rewards

        step +=1

        # reward reset per match
        if done or i % 100 == 0:
            step =0
            prev_reward = 0
            sparse_prev_reward = 0
            team_pt.append(states['player_0'].team_points[0].item())

    # optimizing 하는 루프
    if len(buffer) >= batch_size:
        batch = buffer.sample(batch_size)
        state_batch = torch.tensor(np.array([s for s in batch.state]), dtype=torch.float32, device=device)
        action_batch = torch.tensor(np.array([a for a in batch.actions]), dtype=torch.int64, device=device)
        reward_batch = torch.tensor(np.array([r for r in batch.rewards]), dtype=torch.float32, device=device)
        next_state_batch = torch.tensor(np.array([ns for ns in batch.next_state]), dtype=torch.float32, device=device)
        done_batch = torch.tensor(np.array([d for d in batch.dones]), dtype=torch.float32, device=device)

    # Calculate individual Q-values
        agent_qs = []; agent_qs_next = []
        for i, q_network in enumerate(q_networks):
            q_value_single = q_network(state_batch)
            action_single = action_batch[:,i,:]
            opt_q_value = torch.tensor([q_value_single[tuple(torch.hstack([torch.tensor(i),idx]).tolist())] for i, idx in zip(range(batch_size),action_single)])
            agent_qs.append(opt_q_value)

            next_q_value_single = q_network(next_state_batch)
            opt_next_q_value = optimal_action_from_qval(next_q_value_single)[0]
            agent_qs_next.append(opt_next_q_value)

        agent_qs = torch.stack(agent_qs, dim=1)
        agent_qs_next = torch.stack(agent_qs_next, dim=1)

        # Calculate total Q-value using mixing network
        state_inputs = state_batch.view(batch_size, 5,24,24)
        q_total = mixing_network(agent_qs, state_inputs)

        # Calculate total next Q-value using mixing network
        next_state_inputs = next_state_batch.view(batch_size, 5,24,24)
        with torch.no_grad():
            q_total_next = mixing_network(agent_qs_next, next_state_inputs)

        # Compute loss and update networks
        loss = torch.mean((reward_batch + GAMMA*q_total_next - q_total) ** 2)

        for optimizer in optimizers:
            optimizer.zero_grad()
        mixing_optimizer.zero_grad()
        loss.backward()

        for optimizer in optimizers:
            optimizer.step()
        mixing_optimizer.step()

    if episode % 1 == 0:
        print(f"Episode {episode+1}, Loss: {loss.item():.4f}, my_team_points: {team_pt},  my_team_wins: {rewards['player_0']}")

    if episode % 15 == 0:
        q_net = nn.ModuleList(q_networks)
        torch.save(q_net, "q_net.pt")



Episode 1, Loss: 57.0249, my_team_points: [0, 0, 13, 0, 0],  my_team_wins: 0
Episode 2, Loss: 3818.2107, my_team_points: [35, 60, 38, 45, 20],  my_team_wins: 0
Episode 3, Loss: 8082.2231, my_team_points: [0, 3, 0, 0, 8],  my_team_wins: 0
Episode 4, Loss: 6447.7573, my_team_points: [58, 21, 55, 44, 26],  my_team_wins: 0
Episode 5, Loss: 1432.2297, my_team_points: [8, 0, 0, 0, 2],  my_team_wins: 0
Episode 6, Loss: 903.3110, my_team_points: [7, 0, 0, 2, 0],  my_team_wins: 0
Episode 7, Loss: 2583.2397, my_team_points: [32, 62, 35, 9, 27],  my_team_wins: 0
Episode 8, Loss: 498.6404, my_team_points: [0, 4, 1, 0, 0],  my_team_wins: 0
Episode 9, Loss: 670.8605, my_team_points: [43, 58, 45, 48, 31],  my_team_wins: 0
Episode 10, Loss: 480.0352, my_team_points: [0, 0, 0, 0, 4],  my_team_wins: 0
Episode 11, Loss: 373.6354, my_team_points: [55, 42, 69, 31, 17],  my_team_wins: 0
Episode 12, Loss: 465.4753, my_team_points: [0, 0, 39, 21, 5],  my_team_wins: 0
Episode 13, Loss: 447.3589, my_team_points

KeyboardInterrupt: 

In [10]:
# 수동 저장
# 체크포인트 생성 및 모델의 자동 업데이트 기능은 아직 없습니다
q_net = nn.ModuleList(q_networks)
torch.save(q_net, "q_net.pt")