# In [None]:
import numpy as np
import soccer_twos
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.logger import CustomLogger
from src.utils import shape_rewards


class PolicyNetwork(nn.Module):
    """Fully connected network mapping an observation vector to action logits.

    Layout: ``input_size -> 256 -> 128 -> 64 -> 32 -> action_size``. A ReLU
    follows every hidden layer; the output layer has no activation, so
    ``forward`` returns raw, unnormalized logits (one per discrete action).

    The layer attribute names below become the ``state_dict()`` keys, so they
    are part of the checkpoint format and must not be renamed.
    """

    def __init__(self, input_size=336, action_size=9):
        super().__init__()
        # Created in a fixed order so parameter initialization draws from the
        # RNG in the same sequence under a given seed.
        self.dense_256 = nn.Linear(input_size, 256)
        self.mu_dense_128 = nn.Linear(256, 128)
        self.mu_dense_64 = nn.Linear(128, 64)
        self.mu_dense_32 = nn.Linear(64, 32)
        self.action_out = nn.Linear(32, action_size)

    def forward(self, state):
        hidden_stack = (
            self.dense_256,
            self.mu_dense_128,
            self.mu_dense_64,
            self.mu_dense_32,
        )
        features = state
        for layer in hidden_stack:
            features = F.relu(layer(features))
        # Final projection stays linear: callers apply softmax/log_softmax.
        return self.action_out(features)


class PGAgent:
    """Policy-gradient agent over a discrete action space.

    Holds a ``PolicyNetwork`` that maps an observation vector to action logits,
    samples actions from the induced categorical distribution, and updates the
    policy once per episode from the collected (state, action, reward) data.

    Args:
        state_size: Length of the observation vector fed to the policy.
        action_size: Number of discrete actions (policy output size).
        learning_rate: Adam learning rate.
    """

    def __init__(self, state_size=336, action_size=9, learning_rate=0.0003):
        self.policy = PolicyNetwork(state_size, action_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Move the model before building the optimizer (recommended by the
        # PyTorch optimizer docs) so the optimizer tracks on-device params.
        self.policy.to(self.device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.action_size = action_size
        self.num_agents = 1

    def act(self, state):
        """Sample one action index for a single (unbatched) observation.

        Args:
            state: Array-like observation of length ``state_size``.

        Returns:
            int: Action index in ``[0, action_size)`` drawn from the softmax
            of the policy logits.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            action_logits = self.policy(state)
            action_probs = F.softmax(action_logits, dim=-1)
        action = torch.multinomial(action_probs, 1).item()
        return action

    @staticmethod
    def _discounted_returns(rewards, gamma):
        """Return the reward-to-go ``G_t = sum_k gamma**k * r_{t+k}`` per step.

        With ``gamma == 0`` this is exactly ``rewards`` (as float64).
        """
        returns = np.zeros(len(rewards), dtype=np.float64)
        running = 0.0
        for t in range(len(rewards) - 1, -1, -1):
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns

    def update(self, states, actions, rewards, gamma=0.0):
        """Take one policy-gradient step on a single episode.

        The per-step weights are the discounted returns computed with
        ``gamma``. The default ``gamma=0.0`` weights each log-probability by
        its immediate reward, which preserves the original behaviour; use
        ``gamma > 0`` for REINFORCE-style reward-to-go. The weights are
        standardized (zero mean, unit std) when the episode has more than
        one step.

        Args:
            states: Array of shape ``(T, state_size)``.
            actions: Integer array of shape ``(T,)`` with the taken actions.
            rewards: Array of shape ``(T,)`` with per-step rewards.
            gamma: Discount factor for the reward-to-go (default 0.0).
        """
        if len(rewards) == 0:
            # Nothing to learn from; mean() over an empty tensor would give NaN.
            return

        returns = self._discounted_returns(np.asarray(rewards, dtype=np.float64), gamma)

        states = torch.FloatTensor(np.asarray(states)).to(self.device)
        actions = torch.LongTensor(np.asarray(actions)).to(self.device)
        returns = torch.FloatTensor(returns).to(self.device)

        # Standardize only when the std is defined: the unbiased std of a
        # single sample is NaN and would poison the weights.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # Log-probability of each action that was actually taken.
        action_logits = self.policy(states)
        log_probs = F.log_softmax(action_logits, dim=-1)
        selected_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

        loss = -(selected_log_probs * returns).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save(self, filename):
        """Write the policy's ``state_dict`` to ``filename``."""
        torch.save(self.policy.state_dict(), filename)

    def load(self, filename):
        """Load a policy ``state_dict`` from ``filename``.

        ``map_location`` remaps tensors onto this agent's device, so a
        checkpoint saved on GPU can be loaded on a CPU-only machine.
        """
        self.policy.load_state_dict(torch.load(filename, map_location=self.device))


def action_to_env_format(action):
    """Convert a flat discrete action index into a 3-branch environment action.

    The 9 flat actions form 3 groups of 3. ``action // 3`` selects which of
    the three branches is set, and ``action % 3 - 1`` gives its value in
    ``{-1, 0, 1}``; the other two branches stay 0. For example
    ``0 -> [-1, 0, 0]``, ``4 -> [0, 0, 0]``, ``8 -> [0, 0, 1]``.

    NOTE(review): the training loop uses ``[0, 0, 0]`` as "no action". Confirm
    that the environment accepts ``-1`` as a branch value; a
    ``MultiDiscrete([3, 3, 3])``-style space would expect ``{0, 1, 2}``.

    Args:
        action: Integer (Python or NumPy) in ``[0, 9)``.

    Returns:
        list[int]: The three branch values.

    Raises:
        ValueError: If ``action`` is outside ``[0, 9)``. Without this check,
            negative values would silently wrap via negative list indexing.
    """
    if not 0 <= action < 9:
        raise ValueError(f"action must be in [0, 9), got {action}")
    branch, value = divmod(action, 3)
    env_action = [0, 0, 0]
    env_action[branch] = value - 1
    return env_action


def train_pg(n_games, n_agents):
    """Train one shared policy-gradient agent in the ``soccer_twos`` environment.

    Each step, the first ``n_agents`` of the 4 players act through the same
    ``PGAgent``; the remaining players receive ``[0, 0, 0]``. Per-player
    scores come from ``shape_rewards(info, player)``. Only player 0's
    transitions (observation, action index, shaped score) are stored and
    used for the single policy update made at the end of each episode.
    Episode results are passed to ``CustomLogger`` and printed.

    Args:
        n_games: Number of episodes to run.
        n_agents: How many of the 4 players (lowest indices first) are
            controlled by the learning agent.
    """
    env = soccer_twos.make()
    pg_agent = PGAgent()
    logger = CustomLogger("pg_agent", "pg_simple")
    try:
        for episode in range(n_games):
            obs = env.reset()
            done = False
            episode_states = []
            episode_actions = []
            episode_rewards = []
            scores = {}
            while not done:
                actions = {}
                agent0_action = None
                for player in range(4):
                    if player < n_agents:
                        action = pg_agent.act(obs[player])
                        if player == 0:
                            # Remember player 0's own action: it must pair with obs[0].
                            agent0_action = action
                        actions[player] = action_to_env_format(action)
                    else:
                        actions[player] = [0, 0, 0]  # No action for non-agent players

                next_obs, reward, step_dones, info = env.step(actions)
                done = step_dones["__all__"]

                for player in range(4):
                    scores[player] = shape_rewards(info, player)

                # Only player 0 is trained on; skip if it is not agent-controlled.
                if agent0_action is not None:
                    episode_states.append(obs[0])
                    episode_actions.append(agent0_action)
                    episode_rewards.append(scores[0])

                obs = next_obs

            # One policy update per episode (only if player 0 produced data).
            if episode_states:
                pg_agent.update(
                    np.array(episode_states),
                    np.array(episode_actions),
                    np.array(episode_rewards),
                )
            logger.write_logs_and_tensorboard(
                episode, scores, next_obs, reward, done, info, actions, pg_agent
            )
            print(f"Episode {episode+1} finished. Total reward: {sum(episode_rewards)}")
    finally:
        # Release the simulator even if training is interrupted or fails.
        env.close()


if __name__ == "__main__":
    # Script entry point: a 1000-episode training run with a single
    # learning-agent-controlled player.
    train_pg(n_agents=1, n_games=1000)