<a href="https://colab.research.google.com/github/Felix-Obite/Active-Inference-World-Models-for-Continual-Multimodal-Intelligence/blob/main/Embodied_Agents.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================
# ACTIVE INFERENCE + TRANSFORMER WORLD MODEL (MINIGRID)
# Neuroscience-Inspired Active Inference Agents with Multimodal Latent World Models
# ============================================================

# ---------- INSTALL DEPENDENCIES ----------
!pip install gymnasium minigrid stable-baselines3 transformers einops matplotlib

# ---------- IMPORTS ----------
import gymnasium as gym
from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from transformers import GPT2Config, GPT2Model
from stable_baselines3 import PPO
from einops import rearrange

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---------- ENVIRONMENT ----------
def make_env():
    """Create the MiniGrid-Empty-8x8 environment with image observations.

    The base env is wrapped so that the rest of the script can treat every
    observation as a plain H x W x C RGB image array.
    """
    base_env = gym.make("MiniGrid-Empty-8x8-v0", render_mode=None)
    return ImgObsWrapper(RGBImgObsWrapper(base_env))

# Shared environment instance used for data collection and evaluation.
env = make_env()
obs_shape, n_actions = env.observation_space.shape, env.action_space.n

# ============================================================
# 1. MULTIMODAL VAE (VISION + SYMBOLIC GOAL TOKEN)
# NOTE(review): only the vision pathway is implemented below;
# no symbolic goal token is encoded yet.
# ============================================================

LATENT_DIM = 64  # width of the VAE latent and of the transformer embedding

class MultimodalVAE(nn.Module):
    """Convolutional VAE mapping RGB frames to a diagonal-Gaussian latent.

    NOTE(review): despite the name, only the image modality is encoded; no
    symbolic goal token is consumed anywhere in this class.

    The encoder halves each spatial side twice (two stride-2 convolutions)
    and the decoder mirrors that with two stride-2 transposed convolutions.
    Reconstructions therefore match the input's height/width whenever both
    are divisible by 4.

    Args:
        img_hw: ``(height, width)`` of the input frames. The default
            ``(64, 64)`` is assumed to be what MiniGrid-Empty-8x8 yields
            through ``RGBImgObsWrapper`` (8x8 grid, default 8-px tiles).
            Confirm against ``env.observation_space.shape``.
        latent_dim: latent size; ``None`` uses the module-level ``LATENT_DIM``.

    Raises:
        ValueError: if either side of ``img_hw`` is not divisible by 4.
    """

    def __init__(self, img_hw=(64, 64), latent_dim=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = LATENT_DIM
        height, width = img_hw
        if height % 4 or width % 4:
            raise ValueError(f"img_hw sides must be divisible by 4, got {tuple(img_hw)}")
        # Spatial size after the two stride-2 convolutions. The flatten size
        # must be derived from it rather than hard-coded, otherwise the
        # Linear layer only fits one input resolution.
        feat_h, feat_w = height // 4, width // 4
        feat_dim = 64 * feat_h * feat_w

        self.img_hw = (height, width)
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten(),
            # Emits mean and log-variance concatenated along dim 1.
            nn.Linear(feat_dim, latent_dim * 2),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, feat_dim),
            nn.ReLU(),
            nn.Unflatten(1, (64, feat_h, feat_w)),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
            nn.Sigmoid(),  # pixel intensities in (0, 1)
        )

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Args:
            x: image batch of shape (B, 3, H, W) with ``(H, W) == img_hw``;
                presumably scaled to [0, 1] to match the Sigmoid output.

        Returns:
            Tuple ``(recon, mu, logvar, z)``.
        """
        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=1)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z)
        return recon, mu, logvar, z

# Build the VAE, move it to the compute device, and attach an Adam optimiser.
vae = MultimodalVAE()
vae = vae.to(device)
vae_opt = optim.Adam(params=vae.parameters(), lr=1e-3)

# ============================================================
# 2. COLLECT DATA (OFFLINE)
# ============================================================

def collect_data(episodes=200, environment=None):
    """Roll out a uniformly random policy and return every visited frame.

    Args:
        episodes: number of episodes to roll out.
        environment: Gymnasium-style env to sample from; defaults to the
            module-level ``env``.

    Returns:
        Float tensor of all pre-step observations stacked along a new leading
        axis, divided by 255. The final observation of each episode is not
        included.
    """
    if environment is None:
        environment = env
    frames = []
    for _ in range(episodes):
        obs, _ = environment.reset()
        done = False
        while not done:
            action = environment.action_space.sample()
            next_obs, _reward, terminated, truncated, _ = environment.step(action)
            frames.append(obs)
            # Gymnasium signals the end of an episode by termination OR
            # truncation (time limit). Ignoring `truncated` keeps the loop
            # running past the step limit until the goal happens to be reached.
            done = terminated or truncated
            obs = next_obs
    return torch.tensor(np.array(frames)).float() / 255.0

# Random-policy frames, converted to channels-first for the conv encoder.
dataset = rearrange(collect_data(), "b h w c -> b c h w").to(device)

# ============================================================
# 3. TRAIN VAE
# ============================================================

# VAE training: per-pixel MSE reconstruction plus a KL term weighted by 0.001.
# Each of the 10 epochs draws mini-batches of 32 from a fresh permutation.
vae_losses = []
for epoch in range(10):
    order = torch.randperm(len(dataset))
    for batch_idx in order.split(32):
        batch = dataset[batch_idx]
        recon, mu, logvar, z = vae(batch)
        recon_loss = (recon - batch).pow(2).mean()
        kl = -0.5 * torch.mean(1 + logvar - mu**2 - logvar.exp())
        loss = recon_loss + 0.001 * kl

        vae_opt.zero_grad()
        loss.backward()
        vae_opt.step()

        vae_losses.append(loss.item())

# ============================================================
# 4. TRANSFORMER WORLD MODEL (LATENT DYNAMICS)
# ============================================================

# Small GPT-2 that consumes VAE latents directly via `inputs_embeds`.
# In this file it is never fed token ids, so a single-entry vocab suffices.
config = GPT2Config(n_embd=LATENT_DIM, n_layer=4, n_head=4, vocab_size=1)
world_model = GPT2Model(config)
world_model = world_model.to(device)
wm_opt = optim.Adam(params=world_model.parameters(), lr=1e-4)

# Create latent trajectories: encode every frame once, then cut the latent
# stream into non-overlapping windows of WM_SEQ_LEN consecutive frames.
# NOTE(review): windows may straddle episode boundaries because `dataset`
# keeps no episode markers.
WM_SEQ_LEN = 5
latents = []
with torch.no_grad():
    # Batched encoding, chunked to bound peak memory, replaces one VAE call
    # per frame. Element 3 of the VAE output is the sampled latent z.
    z_all = torch.cat([vae(chunk)[3] for chunk in dataset.split(256)])
    # `+ 1` keeps the final complete window; the old bound
    # `len(dataset) - 5` silently dropped it.
    for start in range(0, len(z_all) - WM_SEQ_LEN + 1, WM_SEQ_LEN):
        latents.append(z_all[start:start + WM_SEQ_LEN])

# World-model training: teacher-forced next-latent regression. Each position
# of a window is trained to predict the latent that follows it, with one
# window per optimiser step, over 10 epochs.
wm_losses = []
for epoch in range(10):
    for window in latents:
        inputs = window[:-1].unsqueeze(0)   # add a batch axis of size 1
        targets = window[1:].unsqueeze(0)
        hidden = world_model(inputs_embeds=inputs).last_hidden_state
        loss = (hidden - targets).pow(2).mean()

        wm_opt.zero_grad()
        loss.backward()
        wm_opt.step()

        wm_losses.append(loss.item())

# ============================================================
# 5. ACTIVE INFERENCE AGENT
# ============================================================

def active_inference_episode(free_energy_log=None):
    """Run one episode while scoring the world model's one-step surprise.

    Each frame is encoded with ``vae``. From the second step on,
    ``world_model`` predicts the current latent from the previous one, and
    the mean squared prediction error is used as the free-energy (surprise)
    proxy.

    NOTE(review): actions are still sampled uniformly at random. The free
    energy does not influence action selection, and ``world_model`` takes no
    action input, so as written it cannot rank candidate actions.

    Args:
        free_energy_log: optional list; when given, each step's free-energy
            value is appended to it as a Python float.

    Returns:
        The summed reward of the episode.
    """

    def _to_input(frame):
        # H x W x C frame -> 1 x C x H x W float batch on `device`, scaled by 1/255.
        return torch.tensor(frame).float().permute(2, 0, 1).unsqueeze(0).to(device) / 255.0

    obs, _ = env.reset()
    obs = _to_input(obs)
    done = False
    total_reward = 0
    # A one-step prediction only needs the previous latent, so keep just that
    # instead of an unbounded history list.
    prev_z = None

    while not done:
        with torch.no_grad():
            _, _, _, z = vae(obs)
            z = z.squeeze(0)
            if prev_z is not None:
                pred = world_model(inputs_embeds=prev_z.reshape(1, 1, -1)).last_hidden_state
                free_energy = ((pred - z.reshape(1, 1, -1)) ** 2).mean()
                if free_energy_log is not None:
                    free_energy_log.append(free_energy.item())
            prev_z = z
        action = env.action_space.sample()  # epistemic exploration proxy
        obs, reward, terminated, truncated, _ = env.step(action)
        # Gymnasium ends an episode on termination OR truncation (time limit).
        done = terminated or truncated
        obs = _to_input(obs)
        total_reward += reward

    return total_reward

# Evaluate the active-inference agent over 50 episodes.
ai_rewards = []
for _ in range(50):
    ai_rewards.append(active_inference_episode())

# ============================================================
# 6. BASELINES
# ============================================================

# Random: uniformly random actions, same 50-episode budget as the other agents.
random_rewards = []
for _ in range(50):
    obs, _ = env.reset()
    done = False
    total = 0
    while not done:
        obs, r, terminated, truncated, _ = env.step(env.action_space.sample())
        # Stop on reaching the goal OR on the time limit (Gymnasium 5-tuple).
        done = terminated or truncated
        total += r
    random_rewards.append(total)

# PPO: train on a separate env instance, then evaluate over 50 episodes.
ppo_env = make_env()
model = PPO("CnnPolicy", ppo_env, verbose=0)
model.learn(total_timesteps=20_000)

ppo_rewards = []
for _ in range(50):
    obs, _ = ppo_env.reset()
    done = False
    total = 0
    while not done:
        # NOTE(review): predict() is called without `deterministic=`; check
        # SB3's default and whether greedy evaluation is intended here.
        action, _ = model.predict(obs)
        obs, r, terminated, truncated, _ = ppo_env.step(action)
        # Stop on reaching the goal OR on the time limit (Gymnasium 5-tuple).
        done = terminated or truncated
        total += r
    ppo_rewards.append(total)

# ============================================================
# 7. PLOTS
# ============================================================

# Training curves, one figure per model.
for series, title in (
    (vae_losses, "Multimodal VAE Training Loss"),
    (wm_losses, "Transformer World Model Loss"),
):
    plt.figure()
    plt.plot(series)
    plt.title(title)
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.show()

# Mean episode reward per agent, in the order they are plotted and printed.
mean_rewards = {
    "Random": np.mean(random_rewards),
    "PPO": np.mean(ppo_rewards),
    "Active Inference": np.mean(ai_rewards),
}

plt.figure()
plt.bar(list(mean_rewards.keys()), list(mean_rewards.values()))
plt.title("Navigation Performance Comparison")
plt.ylabel("Average Reward")
plt.show()

print("Random Avg Reward:", mean_rewards["Random"])
print("PPO Avg Reward:", mean_rewards["PPO"])
print("Active Inference Avg Reward:", mean_rewards["Active Inference"])




Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)
