<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/safety_lambda_experiment_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install gymnasium minigrid torch matplotlib

In [None]:
import gymnasium as gym
import minigrid
import numpy as np
import torch
import matplotlib.pyplot as plt

# ===== 1) Violation wrapper (version-agnostic) =====
# Resolve MiniGrid's Lava class defensively. If the import fails for any
# reason, LavaObj is left as None and ViolationWrapper.step falls back to
# comparing the tile object's `type` attribute against the string "lava".
try:
    from minigrid.core.world_object import Lava as LavaObj
except Exception:
    LavaObj = None

class ViolationWrapper(gym.Wrapper):
    """
    Gym wrapper that reports lava contact through ``info['violation']``.

    ``reset`` always reports 0. ``step`` reports 1 when the tile under the
    agent (looked up via the unwrapped env's ``grid`` and ``agent_pos``) is
    lava, and 0 otherwise. Any error while inspecting the env internals is
    treated as "no violation".
    """

    def reset(self, **kwargs):
        """Reset the wrapped env and tag the returned info with violation=0."""
        obs, info = self.env.reset(**kwargs)
        info = {} if info is None else info
        info["violation"] = 0
        return obs, info

    def step(self, action):
        """Step the wrapped env and set info['violation'] to 0 or 1."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        info = {} if info is None else info

        try:
            on_lava = self._agent_on_lava()
        except Exception:
            # If internals change, remain safe: no violation flagged
            on_lava = False

        info["violation"] = int(on_lava)
        return obs, reward, terminated, truncated, info

    def _agent_on_lava(self):
        """Return True when the tile at the agent's position is lava."""
        base = self.env.unwrapped
        grid = getattr(base, "grid", None)
        position = getattr(base, "agent_pos", None)
        if grid is None or position is None:
            return False

        tile = grid.get(*position)
        if tile is None:
            return False

        # Preferred check: the imported Lava class, when it is available.
        if LavaObj is not None and isinstance(tile, LavaObj):
            return True

        # Fallback by name/type attribute
        tile_type = getattr(tile, "type", None)
        return isinstance(tile_type, str) and tile_type.lower() == "lava"

# ===== 2) Simple SafeDreamer-like agent =====
class SafeDreamer(torch.nn.Module):
    """Small per-action value regressor with a safety-cost penalty.

    A two-layer MLP maps a flattened observation to one score per action.
    ``act`` greedily picks the highest-scoring action; ``learn`` regresses the
    score of the action that was actually taken toward
    ``reward - cost_coef * violation``.
    """

    def __init__(self, obs_dim, action_dim, cost_coef=1.0, lr=1e-3):
        """
        Args:
            obs_dim: Number of elements in a flattened observation.
            action_dim: Number of discrete actions (network output size).
            cost_coef: Weight (lambda) applied to the violation signal.
            lr: Learning rate of the Adam optimizer.
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.cost_coef = cost_coef
        self.net = torch.nn.Sequential(
            torch.nn.Linear(obs_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, action_dim)
        )
        self.opt = torch.optim.Adam(self.parameters(), lr=lr)

    def _flatten_obs(self, obs):
        """Return ``obs`` (or ``obs["image"]`` for dict observations) as a
        1-D float32 tensor."""
        if isinstance(obs, dict) and "image" in obs:
            obs = obs["image"]
        return torch.tensor(obs, dtype=torch.float32).flatten()

    def act(self, obs):
        """Return the index (int) of the action with the highest score."""
        with torch.no_grad():
            scores = self.net(self._flatten_obs(obs))
        return int(torch.argmax(scores))

    def learn(self, obs, action, reward, violation):
        """Take one gradient step on the score of the executed action.

        Args:
            obs: Observation the action was chosen from.
            action: Index of the action that was executed.
            reward: Scalar reward received for the step.
            violation: Violation signal for the step (0 or 1 in this file).
        """
        scores = self.net(self._flatten_obs(obs))
        # Supervise only the action that was taken. Regressing the remaining
        # outputs toward an arbitrary 0 target would overwrite the estimates
        # of actions for which this transition carries no information.
        target = torch.tensor(
            float(reward) - float(self.cost_coef) * float(violation),
            dtype=scores.dtype,
            device=scores.device,
        )
        loss = torch.nn.functional.mse_loss(scores[action], target)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

# ===== 3) Training loop with λ sweep =====
def train_with_lambda(env_id, lambdas, episodes=100, max_steps=200):
    """Train a fresh SafeDreamer agent for each cost coefficient.

    For every value in ``lambdas`` a new environment (wrapped in
    ViolationWrapper) and a new agent are created, the agent is trained online
    for ``episodes`` episodes, and a one-line summary is printed.

    Args:
        env_id: Environment id passed to ``gym.make``.
        lambdas: Iterable of cost coefficients to sweep over.
        episodes: Number of episodes per coefficient.
        max_steps: Upper bound on the steps taken in a single episode.

    Returns:
        A list of ``(lam, mean_reward, mean_violations)`` tuples, one per
        coefficient, in input order. With ``episodes=0`` both means are NaN.
    """
    results = []
    for lam in lambdas:
        env = ViolationWrapper(gym.make(env_id))
        rewards, violations = [], []

        # try/finally guarantees the environment is released even when
        # training raises part-way through the sweep.
        try:
            # Detect obs_dim from a reset
            sample_obs, _ = env.reset()
            if isinstance(sample_obs, dict) and "image" in sample_obs:
                sample_obs = sample_obs["image"]
            obs_dim = torch.tensor(sample_obs, dtype=torch.float32).numel()

            agent = SafeDreamer(obs_dim, env.action_space.n, cost_coef=lam)

            for _ in range(episodes):
                obs, info = env.reset()
                ep_reward, ep_viol = 0.0, 0.0
                for _step in range(max_steps):
                    action = agent.act(obs)
                    next_obs, reward, terminated, truncated, info = env.step(action)
                    viol = info.get("violation", 0)
                    # Learn from the observation the action was chosen on.
                    agent.learn(obs, action, reward, viol)
                    obs = next_obs
                    ep_reward += float(reward)
                    ep_viol += float(viol)
                    if terminated or truncated:
                        break

                rewards.append(ep_reward)
                violations.append(ep_viol)
        finally:
            env.close()

        # np.mean([]) is NaN but emits a RuntimeWarning; produce NaN directly.
        mean_r = float(np.mean(rewards)) if rewards else float("nan")
        mean_v = float(np.mean(violations)) if violations else float("nan")
        results.append((lam, mean_r, mean_v))
        print(f"λ={lam:.2f} → reward={mean_r:.2f}, violations={mean_v:.2f}")

    return results

# ===== 4) Run sweep on a lava environment =====
if __name__ == "__main__":
    # Sweep the cost coefficient on a lava environment.
    env_id = "MiniGrid-LavaGapS7-v0"
    lambdas = [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
    results = train_with_lambda(env_id, lambdas, episodes=100, max_steps=200)

    # Pareto plot: one point per coefficient, x = violations, y = reward.
    rewards = [r for _, r, _ in results]
    violations = [v for _, _, v in results]

    plt.figure(figsize=(6, 4))
    for (lam, r, v) in results:
        plt.scatter(v, r, label=f"λ={lam}")
    plt.xlabel("Avg violations per episode")
    plt.ylabel("Avg reward per episode")
    plt.title("Reward–violation Pareto frontier")
    plt.legend()
    plt.grid(True)
    plt.show()

    # AUC (sorted by x=violations)
    sorted_points = sorted(zip(violations, rewards), key=lambda p: p[0])
    x = [p[0] for p in sorted_points]
    y = [p[1] for p in sorted_points]
    # np.trapezoid only exists on NumPy >= 2.0; older releases expose the same
    # integration routine as np.trapz.
    _trapezoid = getattr(np, "trapezoid", None)
    if _trapezoid is None:
        _trapezoid = np.trapz
    auc = _trapezoid(y, x) if len(x) > 1 else 0.0
    print(f"Area under curve (AUC): {auc:.3f}")