In [None]:
#!/usr/bin/env python3
"""
Feudal RL: 3D Sculpting + Demo using Trained Models
---------------------------------------------------
We train a hierarchical RL agent in a 3D voxel environment. The manager picks
subgoals (x,y,z) in [0..grid_size-1], and the worker carves small steps among
±x, ±y, ±z. The environment terminates if the agent cuts into the shape or
goes out-of-bounds.

At the end, we store the trained models in Python variables
(manager_model, worker_model) so we can reuse them, and we run a final "demo"
episode using those models to create a 3D voxel GIF.

Usage:
  python feudal_3d_sculpt.py
"""

import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import math
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import imageio

# ------------------------------------------------
# 1) ENVIRONMENT
# ------------------------------------------------
class Sculpt3DEnv:
    """
    3D voxel sculpting environment for a manager/worker (feudal) agent.

    State held by the environment:
     - ``stock``: NxNxN bool array, True where material is still present.
     - ``shape``: NxNxN bool array, True for the protected target shape
       (a discrete sphere centred in the grid).
     - ``router_pos``: integer (x, y, z) position of the cutting tool.
     - ``subgoal``: integer (x, y, z) target most recently set by the manager.

    An episode ends when the router would leave the grid, would touch a
    protected shape voxel, or when ``max_steps`` worker steps have been taken.
    """

    # Unit displacement for each worker action id:
    # 0=+x, 1=-x, 2=+y, 3=-y, 4=+z, 5=-z
    _MOVES = (
        (1, 0, 0), (-1, 0, 0),
        (0, 1, 0), (0, -1, 0),
        (0, 0, 1), (0, 0, -1),
    )

    def __init__(self, grid_size=8, max_steps=200, manager_update_freq=10):
        """
        :param grid_size: NxNxN dimension
        :param max_steps: max worker steps per episode
        :param manager_update_freq: number of worker steps per manager subgoal
        """
        self.grid_size = grid_size
        self.max_steps = max_steps
        self.manager_update_freq = manager_update_freq
        self.reset()

    def reset(self):
        """
        Restore full stock, rebuild the protected sphere, and put the router
        back at (0,0,0) with the subgoal at the grid centre.

        :return: tuple ``(manager_obs, worker_obs)``
        """
        n = self.grid_size
        # Every voxel starts with material present.
        self.stock = np.ones((n, n, n), dtype=bool)

        # Protected shape: all voxels within radius n//3 of the centre voxel.
        # Built with broadcasting instead of a triple Python loop.
        centre = n // 2
        radius = n // 3
        gx, gy, gz = np.ogrid[:n, :n, :n]
        dist_sq = (gx - centre) ** 2 + (gy - centre) ** 2 + (gz - centre) ** 2
        self.shape = dist_sq <= radius * radius

        self.router_pos = np.array([0, 0, 0], dtype=int)
        self.steps_taken = 0
        self.done = False
        # Default manager subgoal until set_subgoal() is called.
        self.subgoal = np.array([centre, centre, centre], dtype=int)

        return self._get_manager_obs(), self._get_worker_obs()

    def set_subgoal(self, coord):
        """Store the manager's subgoal, clipped into [0, grid_size-1] per axis."""
        self.subgoal = np.clip(coord, 0, self.grid_size - 1)

    def worker_step(self, action):
        """
        Advance the router by one voxel.

        action in [0..5]: 0=+x,1=-x,2=+y,3=-y,4=+z,5=-z. Any other value
        leaves the router in place but still consumes a step.

        Rewards:
         - -5.0 and episode end if the move leaves the grid or touches the
           protected shape (the router does not move in either case).
         - +1.0 for each voxel of stock removed along the move (the voxel
           being left and the voxel being entered are both checked).

        :return: tuple ``(worker_obs, reward, done)``
        """
        if self.done:
            return self._get_worker_obs(), 0.0, True

        move = np.array([0, 0, 0], dtype=int)
        # Compare with == (rather than indexing) so any value that equals an
        # action id is accepted and anything else falls through to "no move".
        for action_id, delta in enumerate(self._MOVES):
            if action == action_id:
                move = np.array(delta)
                break

        oldp = self.router_pos.copy()
        newp = oldp + move
        reward = 0.0

        in_bounds = np.all((newp >= 0) & (newp < self.grid_size))
        if not in_bounds:
            reward -= 5.0
            self.done = True
        else:
            for (vx, vy, vz) in (oldp, newp):
                if self.shape[vx, vy, vz]:
                    # Cutting into the protected shape ends the episode.
                    reward -= 5.0
                    self.done = True
                    break
                if self.stock[vx, vy, vz]:
                    self.stock[vx, vy, vz] = False
                    reward += 1.0
            if not self.done:
                self.router_pos = newp

        self.steps_taken += 1
        if self.steps_taken >= self.max_steps:
            self.done = True

        return self._get_worker_obs(), reward, self.done

    def _fraction_removed(self):
        """Fraction of the non-shape voxels whose stock has been removed."""
        outside_mask = ~self.shape
        outside_total = np.sum(outside_mask)
        outside_removed = outside_total - np.sum(self.stock[outside_mask])
        # Small epsilon guards against a grid with no outside voxels.
        return outside_removed / (outside_total + 1e-8)

    def manager_reward(self):
        """
        Return the fraction of outside (non-shape) stock removed so far.
        Callers take the difference between two calls to get a step reward.
        """
        return self._fraction_removed()

    def manager_done(self):
        """Return True once the episode has ended."""
        return self.done

    def _get_manager_obs(self):
        """Manager observation: [fraction_removed, router_x, router_y, router_z]."""
        rx, ry, rz = self.router_pos
        return np.array([self._fraction_removed(), rx, ry, rz], dtype=float)

    def _get_worker_obs(self):
        """Worker observation: router (x,y,z) followed by subgoal (x,y,z)."""
        rx, ry, rz = self.router_pos
        sx, sy, sz = self.subgoal
        return np.array([rx, ry, rz, sx, sy, sz], dtype=float)

# ------------------------------------------------
# 2) NETWORKS
# ------------------------------------------------
class ManagerNet(nn.Module):
    """
    Q-network for the manager: maps a manager observation to one value per
    candidate subgoal voxel (grid_size**3 outputs).
    """

    def __init__(self, input_dim, grid_size=8, hidden_dim=128):
        super().__init__()
        self.grid_size = grid_size
        self.output_dim = grid_size**3

        # Two hidden layers with ReLU, then a linear output head.
        widths = [input_dim, hidden_dim, hidden_dim, self.output_dim]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the output layer
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-subgoal values for a batch of observations."""
        q_values = self.net(x)
        return q_values

class WorkerNet(nn.Module):
    """
    Q-network for the worker: maps a worker observation to one value per
    discrete move (n_actions outputs).
    """

    def __init__(self, input_dim=6, hidden_dim=128, n_actions=6):
        super().__init__()

        # Two hidden layers with ReLU, then a linear output head.
        widths = [input_dim, hidden_dim, hidden_dim, n_actions]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the output layer
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-action values for a batch of observations."""
        q_values = self.net(x)
        return q_values

# ------------------------------------------------
# 3) REPLAY BUFFERS
# ------------------------------------------------
class ManagerReplayBuffer:
    """Bounded FIFO store of manager transitions with uniform sampling."""

    def __init__(self, capacity=5000):
        # deque(maxlen=...) silently drops the oldest entry once full.
        self.buffer = deque(maxlen=capacity)

    def push(self, s,a,r,s_next,d):
        """Append one (state, action, reward, next_state, done) transition."""
        transition = (s, a, r, s_next, d)
        self.buffer.append(transition)

    def sample(self,batch_size):
        """
        Draw batch_size transitions without replacement and return them as
        five arrays: (states, actions, rewards, next_states, dones).
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        columns = (states, actions, rewards, next_states, dones)
        return tuple(map(np.array, columns))

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

class WorkerReplayBuffer:
    """Bounded FIFO store of worker transitions with uniform sampling."""

    def __init__(self, capacity=5000):
        # deque(maxlen=...) silently drops the oldest entry once full.
        self.buffer = deque(maxlen=capacity)

    def push(self,s,a,r,s_next,d):
        """Append one (state, action, reward, next_state, done) transition."""
        transition = (s, a, r, s_next, d)
        self.buffer.append(transition)

    def sample(self,batch_size):
        """
        Draw batch_size transitions without replacement and return them as
        five arrays: (states, actions, rewards, next_states, dones).
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        columns = (states, actions, rewards, next_states, dones)
        return tuple(map(np.array, columns))

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# ------------------------------------------------
# 4) TRAINING
# ------------------------------------------------
def _greedy_action(model, obs, device):
    """Return the index of the highest-valued output of model for one observation."""
    with torch.no_grad():
        inp = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        return model(inp).argmax(dim=1).item()


def _decode_subgoal(subgoal_id, grid_size):
    """Convert a flat subgoal id in [0, grid_size**3) to an (x, y, z) array (z varies fastest)."""
    return np.array(np.unravel_index(subgoal_id, (grid_size, grid_size, grid_size)))


def _dqn_update(model, optimizer, replay, batch_size, gamma, device):
    """Take one Q-learning gradient step on a minibatch sampled from replay."""
    s_arr, a_arr, r_arr, ns_arr, d_arr = replay.sample(batch_size)
    states = torch.as_tensor(s_arr, dtype=torch.float32, device=device)
    actions = torch.as_tensor(a_arr, dtype=torch.long, device=device)
    rewards = torch.as_tensor(r_arr, dtype=torch.float32, device=device)
    next_states = torch.as_tensor(ns_arr, dtype=torch.float32, device=device)
    dones = torch.as_tensor(d_arr, dtype=torch.bool, device=device)

    q_taken = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # The same network provides the bootstrap value (no target network);
        # terminal transitions bootstrap from zero.
        next_best = model(next_states).max(dim=1).values
        next_best = next_best.masked_fill(dones, 0.0)
        target = rewards + gamma * next_best

    loss = nn.functional.mse_loss(q_taken, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def feudal_train(env,
                 manager_episodes=50,
                 batch_size=32,
                 manager_lr=1e-3,
                 worker_lr=1e-3,
                 gamma_manager=0.9,
                 gamma_worker=0.9,
                 manager_eps_start=1.0,
                 worker_eps_start=1.0,
                 eps_min=0.1,
                 eps_decay=0.995):
    """
    Train a manager and a worker Q-network with epsilon-greedy exploration.

    Each manager decision picks a subgoal voxel, then the worker takes up to
    ``env.manager_update_freq`` steps. The worker is trained on the per-step
    environment reward; the manager is trained on the change in
    ``env.manager_reward()`` across its decision window.

    :param env: environment exposing reset/set_subgoal/worker_step/
                manager_reward/manager_done and the two observation getters
    :param manager_episodes: number of training episodes
    :param batch_size: minibatch size; updates start once a buffer holds this many
    :param manager_lr: Adam learning rate for the manager
    :param worker_lr: Adam learning rate for the worker
    :param gamma_manager: manager discount factor
    :param gamma_worker: worker discount factor
    :param manager_eps_start: initial manager exploration rate
    :param worker_eps_start: initial worker exploration rate
    :param eps_min: lower bound for both exploration rates
    :param eps_decay: per-episode multiplicative decay of both exploration rates
    :return: tuple ``(manager_model, worker_model)``
    """
    device = "cpu"  # or "cuda" if GPU available
    manager_obs_dim = len(env._get_manager_obs())
    worker_obs_dim  = len(env._get_worker_obs())

    manager_model = ManagerNet(manager_obs_dim, grid_size=env.grid_size, hidden_dim=128).to(device)
    worker_model  = WorkerNet(worker_obs_dim, 128, 6).to(device)

    manager_opt = optim.Adam(manager_model.parameters(), lr=manager_lr)
    worker_opt  = optim.Adam(worker_model.parameters(),  lr=worker_lr)

    manager_replay = ManagerReplayBuffer(5000)
    worker_replay  = WorkerReplayBuffer(5000)

    manager_eps = manager_eps_start
    worker_eps  = worker_eps_start
    n_subgoals = env.grid_size**3

    for ep in range(manager_episodes):
        m_obs, w_obs = env.reset()
        done = False
        manager_return = 0.0

        while not done:
            # Manager picks a subgoal (epsilon-greedy over all voxels).
            manager_s = m_obs
            if random.random() < manager_eps:
                subgoal_id = random.randint(0, n_subgoals - 1)
            else:
                subgoal_id = _greedy_action(manager_model, manager_s, device)

            env.set_subgoal(_decode_subgoal(subgoal_id, env.grid_size))
            # The cached worker observation still carries the previous
            # subgoal; refresh it so the worker acts on (and stores) the
            # subgoal that was just chosen.
            w_obs = env._get_worker_obs()
            old_manager_r = env.manager_reward()

            # Worker acts for up to manager_update_freq steps.
            for _ in range(env.manager_update_freq):
                if env.done:
                    break
                worker_s = w_obs
                if random.random() < worker_eps:
                    action_worker = random.randint(0, 5)
                else:
                    action_worker = _greedy_action(worker_model, worker_s, device)

                w_obs_next, w_r, w_done = env.worker_step(action_worker)
                worker_replay.push(worker_s, action_worker, w_r, w_obs_next, w_done)
                w_obs = w_obs_next

                if len(worker_replay) >= batch_size:
                    _dqn_update(worker_model, worker_opt, worker_replay,
                                batch_size, gamma_worker, device)

                if w_done:
                    break

            # Manager reward is the progress made during this window.
            m_obs_next = env._get_manager_obs()
            delta_r = env.manager_reward() - old_manager_r
            manager_return += delta_r
            manager_done = env.manager_done()
            manager_replay.push(manager_s, subgoal_id, delta_r, m_obs_next, manager_done)
            m_obs = m_obs_next
            done = manager_done

            if len(manager_replay) >= batch_size:
                _dqn_update(manager_model, manager_opt, manager_replay,
                            batch_size, gamma_manager, device)

        manager_eps = max(eps_min, manager_eps*eps_decay)
        worker_eps  = max(eps_min, worker_eps*eps_decay)
        print(f"Ep {ep+1}/{manager_episodes}, manager_return={manager_return:.2f}, epsM={manager_eps:.2f}, epsW={worker_eps:.2f}")

    print("Feudal training complete.")
    return manager_model, worker_model

# ------------------------------------------------
# 5) DEMO WITH 3D VOXEL RENDERING
# ------------------------------------------------
def create_3d_demo_gif(manager_model, worker_model, env, output_gif="feudal_3d_sculpt_demo.gif"):
    """
    Use the trained manager_model and worker_model in a single episode (no exploration).
    Capture frames with a 3D voxel rendering each step. Then compile into a GIF.

    :param manager_model: network mapping a manager observation to per-subgoal values
    :param worker_model: network mapping a worker observation to per-action values
    :param env: environment to run; it is reset at the start
    :param output_gif: path of the GIF file to write
    """
    frames = []
    m_obs, w_obs = env.reset()
    done = False
    n = env.grid_size

    while not done:
        # Manager picks subgoal greedily
        with torch.no_grad():
            inp = torch.FloatTensor(m_obs).unsqueeze(0)
            subgoal_id = manager_model(inp).argmax(dim=1).item()

        # Decode the flat id to (x, y, z); z varies fastest.
        env.set_subgoal(np.array(np.unravel_index(subgoal_id, (n, n, n))))
        # The cached worker observation still carries the previous subgoal;
        # refresh it so the worker acts on the subgoal that was just chosen.
        w_obs = env._get_worker_obs()

        # up to env.manager_update_freq steps
        for _ in range(env.manager_update_freq):
            if env.done:
                break
            # Worker picks move greedily
            with torch.no_grad():
                inp = torch.FloatTensor(w_obs).unsqueeze(0)
                act = worker_model(inp).argmax(dim=1).item()

            w_obs, _reward, w_done = env.worker_step(act)
            frames.append(render_3d_voxel(env))  # capture a 3D snapshot
            if w_done:
                break

        m_obs = env._get_manager_obs()
        done = env.manager_done()

    # compile frames into a gif
    imageio.mimsave(output_gif, frames, fps=2)
    print(f"Demo run saved to {output_gif}")

def render_3d_voxel(env):
    """
    Return a single RGB frame as a numpy array from a 3D voxel plot:
     - shape in bright color
     - stock in partial alpha
     - removed => alpha=0
     - router => single bright voxel

    :param env: object exposing grid_size, shape, stock and router_pos
    :return: uint8 array of shape (height, width, 3)
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    n = env.grid_size
    shape = env.shape
    stock = env.stock

    # RGBA colour per voxel. Default is fully transparent white (removed).
    data_color = np.zeros((n, n, n, 4), dtype=float)
    data_color[...] = [1.0, 1.0, 1.0, 0.0]
    # Remaining stock outside the shape => partial alpha grey.
    data_color[stock & ~shape] = [0.6, 0.6, 0.6, 0.4]
    # Remaining stock inside the shape => bright red.
    data_color[stock & shape] = [1.0, 0.2, 0.2, 1.0]

    # router => bright green (drawn even if its voxel has been removed)
    rx, ry, rz = env.router_pos
    data_color[rx, ry, rz] = [0.0, 1.0, 0.0, 1.0]

    # Only voxels with non-zero alpha are drawn.
    filled = data_color[..., 3] > 0.0

    fig = plt.figure(figsize=(6, 5))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.voxels(filled, facecolors=data_color, edgecolor=None)

        ax.set_xlim(0, n)
        ax.set_ylim(0, n)
        ax.set_zlim(0, n)
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        ax.set_title("3D Sculpt")

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb(); it already has
        # shape (height, width, 4), so no manual reshape is needed. Drop the
        # alpha channel and copy so the frame outlives the figure.
        rgba = np.asarray(canvas.buffer_rgba())
        frame = rgba[..., :3].copy()
    finally:
        # Always release the figure, even if drawing raised.
        plt.close(fig)
    return frame

# ------------------------------------------------
# 6) MAIN
# ------------------------------------------------
def main():
    """Train the manager/worker pair on an 8x8x8 grid, then record a demo GIF."""
    # Environment shared by training and the final demo run.
    sculpt_env = Sculpt3DEnv(
        grid_size=8,
        max_steps=10000,
        manager_update_freq=5,
    )

    # Train both levels; the returned networks stay available as locals.
    trained_models = feudal_train(sculpt_env, manager_episodes=1000)
    manager_model, worker_model = trained_models

    # Greedy rollout with the trained networks, one rendered frame per step.
    print("Running final demo with the trained model, saving to feudal_3d_sculpt_demo.gif ...")
    create_3d_demo_gif(
        manager_model,
        worker_model,
        sculpt_env,
        "feudal_3d_sculpt_demo.gif",
    )

    print("All done. The final manager_model and worker_model are stored as Python variables.")


# Run the full train-and-demo pipeline only when executed as a script.
if __name__ == "__main__":
    main()

Ep 1/1000, manager_return=0.00, epsM=0.99, epsW=0.99
Ep 2/1000, manager_return=0.00, epsM=0.99, epsW=0.99
Ep 3/1000, manager_return=0.00, epsM=0.99, epsW=0.99
Ep 4/1000, manager_return=0.01, epsM=0.98, epsW=0.98
Ep 5/1000, manager_return=0.01, epsM=0.98, epsW=0.98
Ep 6/1000, manager_return=0.01, epsM=0.97, epsW=0.97
Ep 7/1000, manager_return=0.01, epsM=0.97, epsW=0.97
Ep 8/1000, manager_return=0.02, epsM=0.96, epsW=0.96
Ep 9/1000, manager_return=0.00, epsM=0.96, epsW=0.96
Ep 10/1000, manager_return=0.00, epsM=0.95, epsW=0.95
Ep 11/1000, manager_return=0.02, epsM=0.95, epsW=0.95
Ep 12/1000, manager_return=0.00, epsM=0.94, epsW=0.94
Ep 13/1000, manager_return=0.00, epsM=0.94, epsW=0.94
Ep 14/1000, manager_return=0.02, epsM=0.93, epsW=0.93
Ep 15/1000, manager_return=0.00, epsM=0.93, epsW=0.93
Ep 16/1000, manager_return=0.01, epsM=0.92, epsW=0.92
Ep 17/1000, manager_return=0.00, epsM=0.92, epsW=0.92
Ep 18/1000, manager_return=0.00, epsM=0.91, epsW=0.91
Ep 19/1000, manager_return=0.01, epsM

  buf = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
