<a href="https://colab.research.google.com/github/OneFineStarstuff/Pinn/blob/main/Navigation_in_a_Dynamic_GridWorld.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import heapq
import torch
import torch.nn as nn
import torch.nn.functional as F

def mlp(in_dim, hidden, out_dim, act=nn.ReLU):
    """Build a fully connected network as an ``nn.Sequential``.

    Every hidden ``nn.Linear`` is followed by a fresh ``act()`` instance; the
    final ``nn.Linear`` has no activation after it.

    Args:
        in_dim: Size of the input features.
        hidden: Iterable of hidden layer widths (list, tuple, ...). May be
            empty, in which case the result is a single linear layer.
        out_dim: Size of the output features.
        act: Zero-argument callable returning an activation module.

    Returns:
        nn.Sequential: ``Linear -> act -> ... -> Linear``.
    """
    # Unpacking (instead of list concatenation) lets ``hidden`` be any
    # iterable, e.g. a tuple, not only a list.
    dims = [in_dim, *hidden, out_dim]
    last = len(dims) - 2  # position of the final Linear, which gets no activation
    layers = []
    for pos, (d_in, d_out) in enumerate(zip(dims, dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        if pos < last:
            layers.append(act())
    return nn.Sequential(*layers)

class MemoryModule(nn.Module):
    """Fixed-capacity key/value memory written as a ring buffer.

    Features are projected to keys and values by two linear layers and stored
    in the ``keys`` / ``vals`` buffers. ``ptr`` is the next slot to write and
    ``full`` records whether every slot has been written at least once. Both
    are plain Python attributes, so they are not part of ``state_dict()``.

    Args:
        d_model: Width of the incoming feature vectors.
        d_mem: Width of the stored keys and values.
        capacity: Number of slots in the ring buffer.
    """

    def __init__(self, d_model, d_mem, capacity=100):
        super().__init__()
        self.key_proj = nn.Linear(d_model, d_mem)
        self.val_proj = nn.Linear(d_model, d_mem)
        self.capacity = capacity
        self.register_buffer("keys", torch.zeros(capacity, d_mem))
        self.register_buffer("vals", torch.zeros(capacity, d_mem))
        self.ptr = 0
        self.full = False

    def write(self, feat):
        """Store the projections of ``feat`` in the next free slots.

        Args:
            feat: Tensor whose first dimension indexes the entries to store
                and whose last dimension is ``d_model``.

        Stored entries are detached from autograd: the buffers are long-lived
        state, and keeping the graph attached would make it grow with every
        write and turn the buffers into non-leaf tensors.
        """
        K = self.key_proj(feat).detach()
        V = self.val_proj(feat).detach()
        B = K.size(0)
        if B == 0:
            # Nothing to store; in particular do not flag the memory as full.
            return

        start = self.ptr
        if B > self.capacity:
            # Only the newest ``capacity`` rows can survive one lap of the
            # ring; dropping the older ones up front avoids duplicate indices
            # in a single indexed assignment.
            skipped = B - self.capacity
            K, V = K[skipped:], V[skipped:]
            start = (start + skipped) % self.capacity

        offsets = torch.arange(K.size(0), device=self.keys.device)
        idx = (offsets + start) % self.capacity
        self.keys[idx] = K
        self.vals[idx] = V

        # Reaching or passing the last slot means every slot has been written,
        # even when the pointer does not land exactly on slot 0.
        if self.ptr + B >= self.capacity:
            self.full = True
        self.ptr = (self.ptr + B) % self.capacity

    def retrieve(self, feat):
        """Return the stored value whose key best matches each query.

        Args:
            feat: Query features; first dimension is the batch, last is
                ``d_model``.

        Returns:
            tuple: ``(ctx, None)``. ``ctx`` has one ``d_mem``-wide row per
            query: the value at the slot with the highest dot-product
            similarity, or zeros when nothing has been written yet. The
            second element is always ``None``.
        """
        # Before the first wrap-around only slots [0, ptr) hold real entries.
        n_valid = self.capacity if self.full else self.ptr
        if n_valid == 0:
            return feat.new_zeros(feat.size(0), self.vals.size(-1)), None
        # Restrict the search to written slots so an untouched all-zero key
        # (similarity 0) cannot beat real entries with negative similarity.
        sims = self.key_proj(feat) @ self.keys[:n_valid].T
        top_idx = sims.argmax(dim=-1)
        ctx = self.vals[top_idx]
        return ctx, None

class Reasoner(nn.Module):
    """Feed-forward head that maps encoded features to a reasoning context."""

    def __init__(self, d_model, d_reason):
        super().__init__()
        # One hidden layer, as wide as the incoming features.
        hidden_widths = [d_model]
        self.net = mlp(d_model, hidden_widths, d_reason)

    def forward(self, feat):
        """Return the ``d_reason``-wide output of the network for ``feat``."""
        reason_ctx = self.net(feat)
        return reason_ctx

class MemoryConditioner(nn.Module):
    """Fuse features with memory and reasoning context into one vector.

    The three inputs are concatenated along the last dimension, projected
    back to ``d_model`` by a single linear layer and squashed with ``tanh``.
    """

    def __init__(self, d_model, d_mem, d_reason):
        super().__init__()
        fused_width = d_model + d_mem + d_reason
        self.proj = nn.Linear(fused_width, d_model)

    def forward(self, feat, mem_ctx, reason_ctx):
        """Return ``tanh(proj([feat, mem_ctx, reason_ctx]))``."""
        fused = torch.cat((feat, mem_ctx, reason_ctx), dim=-1)
        projected = self.proj(fused)
        return projected.tanh()

class ActorCritic(nn.Module):
    """Actor-critic whose heads read memory- and reasoner-conditioned features.

    Args:
        obs_dim: Width of the observation vector.
        n_act: Number of policy logits.
        d_model: Width of the encoded features.
        d_mem: Width of the memory keys and values.
        d_reason: Width of the reasoner output.
    """

    def __init__(self, obs_dim, n_act, d_model=32, d_mem=16, d_reason=8):
        super().__init__()
        # Observation encoder shared by every downstream component.
        self.encoder = mlp(obs_dim, [64], d_model)
        # Context providers and the layer that fuses them with the features.
        self.memory = MemoryModule(d_model, d_mem)
        self.reasoner = Reasoner(d_model, d_reason)
        self.conditioner = MemoryConditioner(d_model, d_mem, d_reason)
        # Output heads on top of the conditioned features.
        self.policy = mlp(d_model, [32], n_act)
        self.value = mlp(d_model, [32], 1)

    def forward(self, obs):
        """Return ``(logits, value)`` for a batch of observations."""
        encoded = self.encoder(obs)
        recalled, _ = self.memory.retrieve(encoded)
        fused = self.conditioner(encoded, recalled, self.reasoner(encoded))
        return self.policy(fused), self.value(fused)

    def write_memory(self, feat):
        """Store ``feat`` in the memory module."""
        self.memory.write(feat)

class MemoryReasonerPlanner:
    """Best-first search planner whose node scores are nudged by the reasoner.

    Each successor is scored as ``rollout_fn(successor)`` plus
    ``reason_weight`` times the mean reasoner output at its *parent* state.

    Args:
        actor_critic: Object exposing ``encoder`` and ``reasoner`` callables.
        gamma: Stored on the instance; ``plan`` does not use it.
        reason_weight: Weight of the reasoner term in the fused score.
    """

    def __init__(self, actor_critic, gamma=0.99, reason_weight=0.1):
        self.net = actor_critic
        self.gamma = gamma
        self.reason_weight = reason_weight

    def _reason_bias(self, state):
        """Return the weighted mean reasoner output for a scalar ``state``."""
        obs = torch.tensor([[float(state)]], dtype=torch.float32)
        # Pure scoring pass: no gradients are needed, so skip graph building.
        with torch.no_grad():
            feat = self.net.encoder(obs)
            reason_ctx = self.net.reasoner(feat)
        return self.reason_weight * reason_ctx.mean().item()

    def plan(self, start_state, transition_fn, rollout_fn, eval_fn, max_expansions=100):
        """Search from ``start_state`` for a state accepted by ``eval_fn``.

        Args:
            start_state: Scalar state; must be convertible with ``float()``.
            transition_fn: Maps a state to an iterable of successor states.
            rollout_fn: Maps a state to a numeric score (higher is better).
            eval_fn: Returns a truthy value for goal states.
            max_expansions: Upper bound on the number of distinct states
                popped from the frontier.

        Returns:
            The direct successor of ``start_state`` that begins the path to
            the first goal found, or ``None`` if the frontier empties or the
            expansion limit is hit. ``start_state`` itself is never tested
            against ``eval_fn`` unless the search returns to it.

        States must be hashable (for the visited set) and mutually orderable:
        they break ties between equal scores in the heap.
        """
        # Priority queue entries: (neg_score, state, first_action, depth)
        frontier = []
        visited = set()

        # Initial expansion: each successor is its own "first action".
        # The bias depends only on the parent, so compute it once per node.
        bias = self._reason_bias(start_state)
        for ns in transition_fn(start_state):
            fused_score = rollout_fn(ns) + bias
            heapq.heappush(frontier, (-fused_score, ns, ns, 1))

        while frontier and len(visited) < max_expansions:
            _, state, first_action, depth = heapq.heappop(frontier)
            if state in visited:
                continue
            visited.add(state)

            if eval_fn(state):
                print(f"[Planner] Goal reached at state {state} via first action {first_action}")
                return first_action

            bias = self._reason_bias(state)
            for ns in transition_fn(state):
                if ns in visited:
                    continue
                fused_score = rollout_fn(ns) + bias
                heapq.heappush(frontier, (-fused_score, ns, first_action, depth + 1))

        print("[Planner] No goal found within expansion limit")
        return None

# Dummy environment callbacks for the planner demo (integer states on a line)
def dummy_transition_fn(state):
    """Return the successors of ``state``: ``state + 1`` first, then ``state - 1``."""
    step_up, step_down = state + 1, state - 1
    return [step_up, step_down]

def dummy_rollout_fn(state, goal=5):
    """Score ``state`` by its negated distance to ``goal`` (0 is best).

    Args:
        state: Numeric state to score.
        goal: Target state; defaults to 5.

    Returns:
        ``-abs(state - goal)``.
    """
    return -abs(state - goal)

def dummy_eval_fn(state, goal=5):
    """Return whether ``state`` equals ``goal``.

    Args:
        state: State to test.
        goal: Target state; defaults to 5.
    """
    return state == goal

# Build the memory-conditioned actor-critic and hand it to the planner.
net = ActorCritic(obs_dim=1, n_act=2, d_model=32, d_mem=16, d_reason=8)
planner = MemoryReasonerPlanner(actor_critic=net)

# Plan from state 0 using the dummy callbacks.
start_state = 0
next_action = planner.plan(
    start_state,
    transition_fn=dummy_transition_fn,
    rollout_fn=dummy_rollout_fn,
    eval_fn=dummy_eval_fn,
)
print(f"AGI Agent's planned next action: {next_action}")