<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/Multimodal%E2%80%91Safe_Agent_Cell_(fixed).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install minigrid

In [None]:
!pip install gymnasium

In [None]:
# === cell 18: run a multimodal-safe agent (shape-safe version) ===
import torch
import gymnasium as gym
import minigrid  # registers all MiniGrid-* environments

# ----- 1. Create environment -----
env = gym.make("MiniGrid-MemoryS7-v0", render_mode=None)

# ----- 2. Detect actual observation dimension -----
# Reset once to get a sample observation. If it is a dict with an "image"
# entry, only that entry is measured; otherwise the whole observation is.
sample_obs, _ = env.reset()
obs_dim = torch.tensor(
    sample_obs["image"]
    if isinstance(sample_obs, dict) and "image" in sample_obs
    else sample_obs,
    dtype=torch.float32,
).numel()  # total element count == length of the flattened observation

print(f"Detected obs_dim = {obs_dim}")

# ----- 3. Placeholder SafeDreamer -----
class SafeDreamer(torch.nn.Module):
    """Placeholder world model: a small MLP over flat observation vectors.

    The network is ``Linear(obs_dim, 64) -> ReLU -> Linear(64, action_dim)``.
    ``cost_coef`` is stored on the instance but is not used by this class.

    Args:
        obs_dim: Number of input features expected by the first layer.
        action_dim: Number of output features produced by the last layer.
        cost_coef: Coefficient kept as an attribute (default ``1.0``).
    """

    def __init__(self, obs_dim, action_dim, cost_coef=1.0):
        super().__init__()
        self.obs_dim, self.action_dim = obs_dim, action_dim
        self.cost_coef = cost_coef

        hidden_units = 64
        # Layers are built in network order so parameter initialisation
        # consumes the RNG in the same sequence as the layers appear.
        stack = (
            torch.nn.Linear(obs_dim, hidden_units),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_units, action_dim),
        )
        self.net = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the MLP and return its output."""
        output = self.net(x)
        return output

# ----- 4. Placeholder MAML + AGIAgent -----
class DummyMAML:
    """Placeholder meta-learner that only holds a small policy network.

    ``model`` is ``Linear(obs_dim, 64) -> ReLU -> Linear(64, action_dim)``.
    This class is a plain Python object, not a ``torch.nn.Module``.

    Args:
        obs_dim: Number of input features expected by the first layer.
        action_dim: Number of output features produced by the last layer.
    """

    def __init__(self, obs_dim, action_dim):
        hidden_units = 64
        # Built in network order so layer indices (model[0], model[2]) and
        # the order of parameter initialisation are unchanged.
        stack = (
            torch.nn.Linear(obs_dim, hidden_units),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_units, action_dim),
        )
        self.model = torch.nn.Sequential(*stack)

class AGIAgent:
    """Greedy placeholder agent driven by the policy network in ``self.maml``.

    Args:
        obs_dim: Length of the flattened observation vector fed to the policy.
        action_dim: Number of discrete actions the policy scores.
    """

    def __init__(self, obs_dim, action_dim):
        self.maml = DummyMAML(obs_dim, action_dim)
        self.world_model = None  # no world model until one is assigned
        self.action_dim = action_dim

    @staticmethod
    def _flatten_obs(obs):
        """Return ``obs`` (or its ``"image"`` entry, if a dict) as a flat float32 tensor."""
        raw = obs["image"] if isinstance(obs, dict) and "image" in obs else obs
        if isinstance(raw, torch.Tensor):
            # torch.tensor(<Tensor>) emits a copy-construct warning; cast the
            # detached tensor directly instead.
            return raw.detach().to(torch.float32).flatten()
        return torch.tensor(raw, dtype=torch.float32).flatten()

    def act(self, obs):
        """Pick the action with the highest policy score for ``obs``.

        Args:
            obs: An observation, or a dict whose ``"image"`` entry holds it.

        Returns:
            The index of the largest logit, as a Python ``int``.
        """
        # Action selection never backpropagates, so skip building the
        # autograd graph for this forward pass.
        with torch.no_grad():
            logits = self.maml.model(self._flatten_obs(obs))
        return torch.argmax(logits).item()

    def observe(self, obs, action, reward, next_obs, done):
        """Receive one transition; currently a no-op hook for a replay buffer."""
        pass  # hook for replay buffer

    def learn(self):
        """Run one learning step; currently a no-op hook for training updates."""
        pass  # hook for training updates

# ----- 5. Instantiate agent with correct obs_dim -----
agent = AGIAgent(obs_dim, env.action_space.n)

# Attach SafeDreamer world model, sized from the policy network's first
# layer so both networks expect the same flattened observation length.
agent.world_model = SafeDreamer(
    agent.maml.model[0].in_features,
    env.action_space.n,
    cost_coef=7.5,
)

# ----- 6. Dummy logging -----
def log_metrics(agent, episode):
    """Print a one-line completion message for ``episode``.

    Args:
        agent: Accepted for interface compatibility; not read here.
        episode: Value interpolated into the printed message.
    """
    message = f"[Episode {episode}] Training step complete."
    print(message)

# ----- 7. Training loop -----
num_episodes = 5  # small for quick test
try:
    for episode in range(num_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            action = agent.act(obs)
            next_obs, reward, terminated, truncated, info = env.step(action)
            # An episode ends on either a terminal state or a truncation.
            done = terminated or truncated
            agent.observe(obs, action, reward, next_obs, done)
            agent.learn()
            obs = next_obs
        log_metrics(agent, episode)
finally:
    # Release the environment even if an episode raises part-way through.
    env.close()