# Train Cogames

This notebook helps you get started with `cogames` using the simplest possible training example.

Make sure you have selected a T4 GPU runtime! The training cell at the end of the notebook uses `device="cuda"`.

## Install Dependencies

In [None]:
# Clone the Metta repository and make it the kernel's working directory, so that
# relative paths in later cells (e.g. `uv pip install .`) resolve inside it.
!git clone https://github.com/Metta-AI/metta.git
%cd metta

### Install Bazel

In [None]:
# Refresh apt package lists, install Bazel, and print its version to confirm the install.
# NOTE(review): the original "3." step numbering suggests earlier steps (e.g. registering
# Bazel's apt repository and signing key) may have been dropped. The `bazel` package may
# not be available from the default apt sources — confirm this cell works on a fresh runtime.
!sudo apt update && sudo apt install bazel -y
!bazel --version

### Install Nim

In [None]:
# Install system prerequisites (curl, git, compiler toolchain) for the Nim installer below.
!sudo apt-get update -y
!sudo apt-get install -y curl git build-essential

# Install Nim via choosenim (official installer); `-y` is passed through to the installer script.
# NOTE(review): this pipes a remote script straight into `sh` — pin/verify it if that matters.
!curl https://nim-lang.org/choosenim/init.sh -sSf | sh -s -- -y
# Add Nim to PATH for this Colab session. Updating os.environ changes the kernel process's
# environment, so subsequent `!` shell commands inherit it.
# Assumes choosenim installed into /root/.nimble/bin (i.e. running as root, as on Colab) — confirm.
import os
os.environ["PATH"] += ":/root/.nimble/bin"

# Verify installation
!nim --version
!nimble --version


### Install all the required packages

In [None]:
# Install the cloned metta repository (the current working directory) and PufferLib from GitHub.
!uv pip install .
!uv pip install git+https://github.com/PufferAI/PufferLib

In [None]:
# Upgrade numpy. Per the notebook author, errors and warnings printed by this cell can be ignored.
# NOTE(review): this uses plain `pip` while the previous cell used `uv pip`; confirm both
# target the same Python environment as the kernel.
!pip install numpy --upgrade

## Training Code + PPO Loss

This section contains a complete reinforcement-learning training loop:
- a rollout buffer that computes generalized advantage estimates (GAE);
- a PPO update function with ratio clipping, entropy regularization, and a clipped value loss;
- support for both simple and LSTM-based policies.

The loop interacts with the `MettaGridEnv` simulation, collecting observations, actions, and rewards over multiple epochs to optimize the policy network.

You can modify mission configurations, hyperparameters, and policy types to experiment with different training setups.

In [None]:
# imports
from mettagrid import MettaGridEnv
import torch
from torch.distributions import Categorical
import numpy as np

In [None]:
def get_cogames_mission(mission_name="training_facility.harvest", variants=None):
    """Resolve a cogames mission by name and return its environment config.

    Args:
        mission_name: mission identifier forwarded to cogames' ``get_mission``.
        variants: optional list of variant names, forwarded as ``variants_arg``.

    Returns:
        The middle element of the 3-tuple returned by ``get_mission``; this is
        the config that the training loop passes to ``MettaGridEnv(env_cfg=...)``.
    """
    # Imported lazily so this cell can be defined before cogames is importable.
    from cogames.cli.mission import get_mission

    _mission_label, env_config, _extra = get_mission(mission_name, variants_arg=variants)
    return env_config



In [None]:
# Rollout buffer that stores the experience collected during one rollout phase.

class RolloutBuffer:
    """Accumulates per-step rollout data and converts it into a PPO training batch.

    Each ``add`` call records one environment step. ``build_training_batch``
    stacks the recorded steps along a new leading (time) dimension and computes
    generalized advantage estimates (GAE) and value targets (returns).

    Args:
        device: torch device on which stored tensors are created.
        gamma: discount factor.
        gae_lambda: GAE lambda parameter.
    """

    def __init__(self, device, gamma, gae_lambda):
        self.device = device
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clear()

    def add(self, obs, action, reward, done, log_prob, value):
        """Record one environment step; every argument is copied into a new tensor.

        ``done`` is the flag returned by the environment step that ``action``
        triggered: a true value means the episode ended *after* this transition,
        so the next stored step belongs to a new episode.
        """
        self.observations.append(torch.as_tensor(obs, device=self.device, dtype=torch.float32).clone())
        self.actions.append(torch.as_tensor(action, device=self.device, dtype=torch.long).clone())
        self.rewards.append(torch.as_tensor(reward, device=self.device, dtype=torch.float32).clone())
        self.dones.append(torch.as_tensor(done, device=self.device, dtype=torch.float32).clone())
        self.log_probs.append(torch.as_tensor(log_prob, device=self.device, dtype=torch.float32).clone())
        self.values.append(torch.as_tensor(value, device=self.device, dtype=torch.float32).clone())

    def build_training_batch(self, last_value, last_done):
        """Stack the stored steps and compute GAE advantages and returns.

        Args:
            last_value: value estimate of the observation that follows the last
                stored step; used to bootstrap the final step.
            last_done: done flag(s) for the last stored step. The final bootstrap
                is masked by both this and the stored done flag, so passing the
                last step's ``done`` is equivalent to relying on the stored flag.

        Returns:
            dict with keys "obs", "actions", "old_log_probs", "old_values",
            "advantages" and "returns". "old_values", "advantages" and "returns"
            are reshaped to ``(num_steps, -1)``.

        Raises:
            ValueError: if no steps have been added.
        """
        if not self.observations:
            raise ValueError("RolloutBuffer is empty")

        num_steps = len(self.observations)
        obs = torch.stack(self.observations)
        actions = torch.stack(self.actions)
        old_log_probs = torch.stack(self.log_probs)
        values = torch.stack(self.values).view(num_steps, -1)
        rewards = torch.stack(self.rewards).view(num_steps, -1)
        dones = torch.stack(self.dones).view(num_steps, -1)

        last_value = torch.as_tensor(last_value, device=self.device, dtype=torch.float32).view(-1)
        last_done = torch.as_tensor(last_done, device=self.device, dtype=torch.float32).view(-1)

        advantages = torch.zeros_like(values)
        next_advantage = torch.zeros_like(last_value)
        next_value = last_value

        for step in reversed(range(num_steps)):
            # dones[step] is produced by the action taken at `step`. If the episode
            # ended there, the following value/advantage belong to a new episode
            # and must not be bootstrapped into this step.
            nonterminal = 1.0 - dones[step]
            if step == num_steps - 1:
                nonterminal = nonterminal * (1.0 - last_done)
            delta = rewards[step] + self.gamma * next_value * nonterminal - values[step]
            next_advantage = delta + self.gamma * self.gae_lambda * nonterminal * next_advantage
            advantages[step] = next_advantage
            next_value = values[step]

        returns = advantages + values

        return {
            "obs": obs,
            "actions": actions,
            "old_log_probs": old_log_probs,
            "old_values": values,
            "advantages": advantages,
            "returns": returns,
        }

    def clear(self):
        """Drop all stored steps."""
        self.observations = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
        self.values = []

    def __len__(self):
        """Number of stored steps."""
        return len(self.observations)





In [None]:
# Sample actions (plus log-probs, values and the next LSTM state) from the policy.

def get_action(policy, obs, device, lstm_state=None):
    """Run the policy on one observation and sample actions from its logits.

    Args:
        policy: object exposing ``network()`` (with ``eval``/``forward_eval``)
            and ``is_recurrent()``.
        obs: numpy observation; a leading batch dimension of 1 is added.
        device: device the observation tensor is moved to.
        lstm_state: optional ``(h, c)`` tuple carried from the previous step
            (only used for recurrent policies).

    Returns:
        ``(actions, log_probs, values, new_state)`` where ``actions`` is a numpy
        array (at least 1-D), ``log_probs``/``values`` are detached tensors and
        ``new_state`` is a detached ``(h, c)`` tuple or ``None``.
    """
    batched_obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)
    recurrent = policy.is_recurrent()

    state_dict = None
    if recurrent:
        prev_h, prev_c = (None, None) if lstm_state is None else lstm_state
        state_dict = {"lstm_h": prev_h, "lstm_c": prev_c}

    new_state = None
    with torch.no_grad():
        policy.network().eval()
        if not recurrent:
            logits, values = policy.network().forward_eval(batched_obs)
        else:
            # forward_eval is expected to write the updated state back into state_dict.
            logits, values = policy.network().forward_eval(batched_obs, state_dict)
            out_h = state_dict.get("lstm_h")
            out_c = state_dict.get("lstm_c")
            if out_h is not None and out_c is not None:
                new_state = (out_h.detach(), out_c.detach())

    distribution = Categorical(logits=logits)
    sampled = distribution.sample()
    sampled_log_probs = distribution.log_prob(sampled)

    # Guarantee at least one dimension so callers can index/iterate the result.
    if sampled.dim() == 0:
        sampled, sampled_log_probs = sampled.unsqueeze(0), sampled_log_probs.unsqueeze(0)

    return sampled.cpu().numpy(), sampled_log_probs.detach(), values.detach(), new_state

In [None]:
## PPO update function


def ppo_update(
    policy,
    optimizer,
    batch,
    device,
    clip_coef,
    vf_clip_coef,
    vf_coef,
    ent_coef,
    max_grad_norm,
    ppo_epochs,
    minibatch_size,
):
    """Run ``ppo_epochs`` passes of clipped-PPO minibatch updates over one rollout batch.

    Args:
        policy: object exposing ``network()`` (an ``nn.Module`` with ``forward_eval``)
            and ``is_recurrent()``.
        optimizer: optimizer over ``policy.network()`` parameters.
        batch: dict with keys "obs", "actions", "old_log_probs", "old_values",
            "advantages", "returns" (as built by ``RolloutBuffer.build_training_batch``).
        device: device the batch tensors are moved to.
        clip_coef: PPO probability-ratio clipping range.
        vf_clip_coef: clipping range of new value predictions around the old values.
        vf_coef: weight of the value loss.
        ent_coef: weight of the entropy bonus.
        max_grad_norm: gradient-norm clipping threshold.
        ppo_epochs: number of passes over the batch; must be >= 1.
        minibatch_size: number of leading-dimension (time-step) rows per minibatch.

    Returns:
        dict with the mean "policy_loss", "value_loss", "entropy",
        "clip_fraction" and "approx_kl" over all minibatch updates.

    Raises:
        ValueError: if the advantages tensor is empty or ``ppo_epochs < 1``.
    """
    obs = batch["obs"].to(device)
    actions = batch["actions"].to(device)
    old_log_probs = batch["old_log_probs"].to(device)
    old_values = batch["old_values"].to(device)
    advantages = batch["advantages"].to(device)
    returns = batch["returns"].to(device)

    if advantages.numel() == 0:
        raise ValueError("Advantages tensor is empty")
    if ppo_epochs < 1:
        raise ValueError("ppo_epochs must be >= 1")

    policy.network().train()

    # Normalize advantages for the policy loss only. The value target stays the
    # un-normalized GAE return from the batch; rebuilding it from normalized
    # advantages would train the critic toward a zero-mean, unit-variance signal
    # instead of the real returns.
    advantages = advantages - advantages.mean()
    advantages = advantages / (advantages.std(unbiased=False) + 1e-8)

    batch_size = obs.shape[0]

    policy_losses = []
    value_losses = []
    entropies = []
    clip_fractions = []
    approx_kls = []

    for _ in range(ppo_epochs):
        indices = torch.randperm(batch_size, device=device)
        for start in range(0, batch_size, minibatch_size):
            mb_idx = indices[start : start + minibatch_size]
            mb_obs = obs[mb_idx]

            # Flatten trailing (agent) dimensions so each row lines up with the policy outputs.
            mb_actions = actions[mb_idx].reshape(len(mb_idx), -1)
            mb_old_log_probs = old_log_probs[mb_idx].reshape(len(mb_idx), -1)
            mb_old_values = old_values[mb_idx].reshape(len(mb_idx), -1)
            mb_advantages = advantages[mb_idx].reshape(len(mb_idx), -1)
            mb_returns = returns[mb_idx].reshape(len(mb_idx), -1)

            if policy.is_recurrent():
                # NOTE(review): minibatches are shuffled and evaluated without a carried
                # LSTM state, and ``None`` is passed here whereas get_action passes
                # {"lstm_h": None, "lstm_c": None}; confirm forward_eval accepts this.
                logits, values = policy.network().forward_eval(mb_obs, None)
            else:
                logits, values = policy.network().forward_eval(mb_obs)

            dist = Categorical(logits=logits)
            new_log_probs = dist.log_prob(mb_actions.squeeze(-1) if mb_actions.shape[-1] == 1 else mb_actions)
            entropy = dist.entropy()

            new_log_probs = new_log_probs.unsqueeze(-1) if new_log_probs.dim() == 1 else new_log_probs
            log_ratio = new_log_probs - mb_old_log_probs
            ratio = log_ratio.exp()
            surr1 = ratio * mb_advantages
            surr2 = torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef) * mb_advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            value_pred = values.reshape(values.shape[0], -1)
            value_pred_clipped = mb_old_values + (value_pred - mb_old_values).clamp(-vf_clip_coef, vf_clip_coef)
            value_loss_unclipped = (value_pred - mb_returns) ** 2
            value_loss_clipped = (value_pred_clipped - mb_returns) ** 2
            value_loss = 0.5 * torch.max(value_loss_unclipped, value_loss_clipped).mean()

            # Diagnostics only (not part of the loss).
            kld = 0.5 * log_ratio.pow(2).mean()
            clip_fraction = (torch.abs(ratio - 1.0) > clip_coef).float().mean()

            loss = policy_loss + vf_coef * value_loss - ent_coef * entropy.mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.network().parameters(), max_grad_norm)
            optimizer.step()

            policy_losses.append(policy_loss.detach())
            value_losses.append(value_loss.detach())
            entropies.append(entropy.mean().detach())
            clip_fractions.append(clip_fraction.detach())
            approx_kls.append(kld.detach())

    return {
        "policy_loss": torch.stack(policy_losses).mean().item(),
        "value_loss": torch.stack(value_losses).mean().item(),
        "entropy": torch.stack(entropies).mean().item(),
        "clip_fraction": torch.stack(clip_fractions).mean().item(),
        "approx_kl": torch.stack(approx_kls).mean().item(),
    }

In [None]:


# Training loop


def _make_policy(policy_type, env, device):
    """Instantiate the cogames policy selected by ``policy_type`` ("simple" or "lstm")."""
    if policy_type == "simple":
        from cogames.policy.simple import SimplePolicy

        return SimplePolicy(env, device)
    if policy_type == "lstm":
        from cogames.policy.lstm import LSTMPolicy

        return LSTMPolicy(env, device)
    raise ValueError(f"Unknown policy type: {policy_type}")


def _bootstrap_value(policy, obs, lstm_state, device):
    """Return the detached CPU value estimate of ``obs``, used to bootstrap GAE."""
    with torch.no_grad():
        obs_tensor = torch.from_numpy(obs).float().unsqueeze(0).to(device)
        if policy.is_recurrent():
            # NOTE(review): passes ``None`` when no state is carried, whereas get_action
            # passes {"lstm_h": None, "lstm_c": None}; confirm forward_eval accepts both.
            state_arg = None
            if lstm_state is not None:
                h, c = lstm_state
                state_arg = {"lstm_h": h, "lstm_c": c}
            _, value = policy.network().forward_eval(obs_tensor, state_arg)
        else:
            _, value = policy.network().forward_eval(obs_tensor)

    value = value.squeeze(0)
    if value.dim() > 1:
        value = value.squeeze(-1)
    return value.detach().cpu()


def train_cogames(
    mission_name="training_facility.harvest",
    policy_type="simple",
    num_epochs=10,
    batch_size=256,
    minibatch_size=64,
    ppo_epochs=4,
    learning_rate=3e-4,
    # Metta-tuned hyperparameters from PPOConfig
    gamma=0.977,
    gae_lambda=0.891477,
    clip_coef=0.264407,
    vf_clip_coef=0.1,
    vf_coef=0.897619,
    ent_coef=0.01,
    max_grad_norm=0.5,
    save_path="cogames_policy.pt",
    device="cpu",
    variants=None,
    seed=42,
):
    """Train a cogames policy with PPO on a single MettaGrid environment.

    Each epoch:
    1. Resets the environment with ``seed + epoch``.
    2. Collects ``batch_size`` environment steps into a fresh ``RolloutBuffer``,
       resetting the environment whenever any agent is done.
    3. Bootstraps the value of the final observation.
    4. Runs ``ppo_update``.

    The policy is saved to ``save_path`` whenever an epoch's summed reward
    beats the best seen so far. The environment is always closed on exit,
    including on error.

    Args:
        mission_name: cogames mission to train on.
        policy_type: "simple" or "lstm".
        num_epochs: number of collect-and-update iterations.
        batch_size: environment steps collected per epoch.
        minibatch_size: rows per PPO minibatch; must be <= ``batch_size``.
        ppo_epochs: PPO passes over each collected batch.
        learning_rate: Adam learning rate.
        gamma, gae_lambda: GAE discount and lambda.
        clip_coef, vf_clip_coef, vf_coef, ent_coef, max_grad_norm: PPO loss settings.
        save_path: where ``policy.save_policy_data`` writes the best policy.
        device: torch device string, e.g. "cpu" or "cuda".
        variants: optional list of mission variant names.
        seed: base seed for environment resets.

    Raises:
        ValueError: on non-positive ``batch_size``/``minibatch_size``,
            ``minibatch_size > batch_size``, or an unknown ``policy_type``.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if minibatch_size <= 0:
        raise ValueError("minibatch_size must be positive")
    if minibatch_size > batch_size:
        raise ValueError("minibatch_size must be <= batch_size")

    device = torch.device(device)
    config = get_cogames_mission(mission_name, variants=variants)
    env = MettaGridEnv(env_cfg=config)
    try:
        # Initial reset before the policy is built from the environment.
        env.reset(seed=seed)
        policy = _make_policy(policy_type, env, device)

        optimizer = torch.optim.Adam(policy.network().parameters(), lr=learning_rate)
        best_reward = -float("inf")

        for epoch in range(num_epochs):
            buffer = RolloutBuffer(device=device, gamma=gamma, gae_lambda=gae_lambda)
            obs, _ = env.reset(seed=seed + epoch)
            epoch_reward = 0.0
            lstm_state = None
            last_done_flags = np.zeros(getattr(env, "num_agents", 1), dtype=np.float32)

            while len(buffer) < batch_size:
                actions_np, log_probs, values, next_state = get_action(policy, obs, device, lstm_state)
                next_obs, rewards, terminals, truncations, _ = env.step(actions_np)

                dones = np.logical_or(terminals, truncations)
                buffer.add(obs, actions_np, rewards, dones, log_probs, values)
                epoch_reward += float(np.sum(rewards))
                last_done_flags = np.asarray(dones, dtype=np.float32).copy()

                if np.any(dones):
                    # Start a new episode; recurrent state must not leak across it.
                    next_obs, _ = env.reset()
                    lstm_state = None
                else:
                    lstm_state = next_state

                obs = next_obs

            last_value = _bootstrap_value(policy, obs, lstm_state, device)

            last_done_tensor = torch.as_tensor(last_done_flags, dtype=torch.float32)
            if last_done_tensor.dim() == 0:
                last_done_tensor = last_done_tensor.unsqueeze(0)

            # Do not bootstrap from the value of a freshly reset episode.
            last_value = last_value * (1.0 - last_done_tensor)

            batch = buffer.build_training_batch(last_value=last_value, last_done=last_done_tensor)

            metrics = ppo_update(
                policy=policy,
                optimizer=optimizer,
                batch=batch,
                device=device,
                clip_coef=clip_coef,
                vf_clip_coef=vf_clip_coef,
                vf_coef=vf_coef,
                ent_coef=ent_coef,
                max_grad_norm=max_grad_norm,
                ppo_epochs=ppo_epochs,
                minibatch_size=minibatch_size,
            )

            print(
                f"Epoch {epoch + 1}/{num_epochs} - Reward: {epoch_reward:.2f}, "
                f"Policy loss: {metrics['policy_loss']:.4f}, Value loss: {metrics['value_loss']:.4f}"
            )

            if epoch_reward > best_reward:
                best_reward = epoch_reward
                policy.save_policy_data(save_path)
    finally:
        env.close()


## Train the Policy

In [None]:
# Run configuration. Available variants: mined_out, dark_side, super_charged,
# rough_terrain, solar_flare, desert, forest, city, caves, store_base,
# extractor_base, both_base, lonely_heart, pack_rat, energized, neutral_faced
TRAIN_CONFIG = {
    "mission_name": "training_facility.harvest",
    "policy_type": "simple",  # or "lstm"
    "num_epochs": 50000,
    "batch_size": 2048,
    "minibatch_size": 512,
    "ppo_epochs": 4,
    "learning_rate": 3e-4,
    "save_path": "cogames_policy.pt",
    "variants": ["neutral_faced"],
    "device": "cuda",  # requires a GPU runtime
}
train_cogames(**TRAIN_CONFIG)
