In [None]:
#!/usr/bin/env python3
"""
Feudal RL: 3D Sculpting with a Beautiful 3D Voxel Visualization
---------------------------------------------------------------
We have:
 - A 3D NxNxN environment with an initial stock (True => present) 
   and a spherical "shape" to protect (True => shape).
 - The manager picks subgoals in [0..N^3 -1] (the subgoal is a coordinate in [0..N-1]^3).
 - The worker picks discrete moves among 6 directions (±x, ±y, ±z).
 - We do difference-based manager rewards based on fraction of outside removed.
 - We produce a "pretty" 3D voxel visualization for each step, storing frames and 
   compiling them into a GIF. The shape is displayed in a distinctive color, 
   the stock in partial alpha, and the router is shown as a bright 3D marker (voxel).
Usage:
  python feudal_3d_pretty.py
"""

import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import math
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import io
import imageio

# -------------------------
# 1) ENVIRONMENT
# -------------------------
class Sculpt3DEnv:
    """
    3D voxel-based sculpting environment.

     - ``stock``: NxNxN bool grid, True where material is still present
       (all True after ``reset``).
     - ``shape``: NxNxN bool grid, True on the protected region: every voxel
       whose squared distance to (N//2, N//2, N//2) is <= (N//3)**2.
     - The manager sets subgoals as integer coordinates, clipped to [0, N-1].
     - The worker moves the router one voxel per step along one of 6 directions.
     - Stepping out of bounds or touching the shape costs -5 and ends the
       episode; reaching ``max_steps`` also ends it.
     - Worker reward: +1 for each stock voxel the router clears.
     - Manager reward: fraction of the non-shape stock removed so far
       (the training loop differences it between subgoals).
    """

    # Action index -> displacement: 0=+x, 1=-x, 2=+y, 3=-y, 4=+z, 5=-z.
    # Unknown action values fall back to (0, 0, 0): the router stays put.
    _MOVES = {
        0: (1, 0, 0),
        1: (-1, 0, 0),
        2: (0, 1, 0),
        3: (0, -1, 0),
        4: (0, 0, 1),
        5: (0, 0, -1),
    }

    def __init__(self, grid_size=8, max_steps=200, manager_update_freq=10):
        self.grid_size = grid_size
        self.max_steps = max_steps
        # Worker steps per manager decision; stored here for the training loop.
        self.manager_update_freq = manager_update_freq
        self.reset()

    def reset(self):
        """
        Restore full stock, rebuild the protected ball, and park the router at the origin.

        Returns:
            (manager_obs, worker_obs) for the fresh episode.
        """
        n = self.grid_size
        # stock True => material present
        self.stock = np.ones((n, n, n), dtype=bool)

        # shape True => protected region. Vectorised ball test instead of a
        # triple Python loop over all N**3 voxels.
        c = n // 2
        r = n // 3
        x, y, z = np.indices((n, n, n))
        self.shape = (x - c) ** 2 + (y - c) ** 2 + (z - c) ** 2 <= r * r

        self.router_pos = np.array([0, 0, 0], dtype=int)
        self.steps_taken = 0
        self.done = False
        # Default subgoal is the grid centre until the manager picks one.
        self.subgoal = np.array([c, c, c], dtype=int)
        return self._get_manager_obs(), self._get_worker_obs()

    def set_subgoal(self, coord):
        """Clip ``coord`` into [0, grid_size-1] per axis and store it as the subgoal."""
        coord = np.clip(coord, 0, self.grid_size - 1)
        self.subgoal = coord

    def worker_step(self, action):
        """
        Move the router one voxel and carve stock along the way.

        Actions: 0=+x, 1=-x, 2=+y, 3=-y, 4=+z, 5=-z.

        Returns:
            (worker_obs, reward, done). Once the episode is done, further calls
            return reward 0.0 without changing state or the step counter.
        """
        if self.done:
            return self._get_worker_obs(), 0.0, True

        move = np.array(self._MOVES.get(action, (0, 0, 0)), dtype=int)
        oldp = self.router_pos.copy()
        newp = oldp + move
        reward = 0.0

        if not np.all((newp >= 0) & (newp < self.grid_size)):
            # Out of bounds: penalise, end the episode, router does not move.
            reward -= 5.0
            self.done = True
        else:
            # Check the current cell, then the target cell. Touching the shape
            # aborts before any stock at that cell is cleared.
            for vx, vy, vz in (oldp, newp):
                if self.shape[vx, vy, vz]:
                    reward -= 5.0
                    self.done = True
                    break
                if self.stock[vx, vy, vz]:
                    self.stock[vx, vy, vz] = False
                    reward += 1.0
            if not self.done:
                self.router_pos = newp

        self.steps_taken += 1
        if self.steps_taken >= self.max_steps:
            self.done = True
        return self._get_worker_obs(), reward, self.done

    def manager_reward(self):
        """Return the fraction (0..1) of non-shape voxels whose stock has been removed."""
        outside = ~self.shape
        outside_total = np.sum(outside)
        outside_removed = outside_total - np.sum(self.stock[outside])
        # The epsilon guards the degenerate case with no outside voxels.
        return outside_removed / (outside_total + 1e-8)

    def manager_done(self):
        """Return whether the episode has ended."""
        return self.done

    def _get_manager_obs(self):
        """Manager observation: [fraction_outside_removed, rx, ry, rz] as floats."""
        rx, ry, rz = self.router_pos
        return np.array([self.manager_reward(), rx, ry, rz], dtype=float)

    def _get_worker_obs(self):
        """Worker observation: [rx, ry, rz, sx, sy, sz] (router position, subgoal) as floats."""
        rx, ry, rz = self.router_pos
        sx, sy, sz = self.subgoal
        return np.array([rx, ry, rz, sx, sy, sz], dtype=float)

# -------------------------
# 2) NETWORKS
# -------------------------
class ManagerNet(nn.Module):
    """
    Three-layer MLP scoring every voxel of the grid as a candidate subgoal.

    Output width is ``grid_size ** 3``: one unnormalised score per flattened
    voxel index.
    """

    def __init__(self, input_dim, grid_size=8, hidden_dim=128):
        super().__init__()
        self.grid_size = grid_size
        self.output_dim = grid_size ** 3

        widths = [input_dim, hidden_dim, hidden_dim, self.output_dim]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            # A ReLU separates consecutive Linear layers (none before the first).
            if depth > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw per-voxel scores for the batch ``x``."""
        scores = self.net(x)
        return scores

class WorkerNet(nn.Module):
    """Three-layer MLP mapping a worker observation to one score per move action."""

    def __init__(self, input_dim=6, hidden_dim=128, n_actions=6):
        super().__init__()
        widths = [input_dim, hidden_dim, hidden_dim, n_actions]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            # A ReLU separates consecutive Linear layers (none before the first).
            if depth > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw per-action scores for the batch ``x``."""
        scores = self.net(x)
        return scores

# -------------------------
# 3) REPLAY BUFFERS
# -------------------------
class _TransitionReplayBuffer:
    """
    Fixed-capacity FIFO store of (state, action, reward, next_state, done) tuples.

    Shared implementation for the manager and worker buffers, which were
    previously two identical copies.
    """

    def __init__(self, cap=5000):
        # deque(maxlen=...) silently evicts the oldest transition once full.
        self.buffer = deque(maxlen=cap)

    def push(self, s, a, r, s_next, d):
        """Append one transition (evicting the oldest if at capacity)."""
        self.buffer.append((s, a, r, s_next, d))

    def sample(self, batch_size):
        """
        Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Five numpy arrays (states, actions, rewards, next_states, dones),
            each stacked along a leading batch axis.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions
            (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        s, a, r, ns, d = zip(*batch)
        return np.array(s), np.array(a), np.array(r), np.array(ns), np.array(d)

    def __len__(self):
        return len(self.buffer)


class ManagerReplayBuffer(_TransitionReplayBuffer):
    """Replay buffer for manager transitions (action = flat subgoal index)."""


class WorkerReplayBuffer(_TransitionReplayBuffer):
    """Replay buffer for worker transitions (action = move index)."""

# -------------------------
# 4) TRAINING
# -------------------------
def _dqn_update(net, optimizer, replay, batch_size, gamma):
    """
    Apply one one-step Q-learning gradient step to ``net`` from a replay minibatch.

    Target is ``r + gamma * max_a' Q(s', a')``, with the bootstrap term zeroed
    for terminal transitions. The same network provides the prediction and the
    gradient-free target (no separate target network).
    """
    s_arr, a_arr, r_arr, ns_arr, d_arr = replay.sample(batch_size)
    s_t = torch.FloatTensor(s_arr)
    a_t = torch.LongTensor(a_arr)
    r_t = torch.FloatTensor(r_arr)
    ns_t = torch.FloatTensor(ns_arr)
    d_t = torch.BoolTensor(d_arr)

    q_s_a = net(s_t).gather(1, a_t.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        max_q_next, _ = torch.max(net(ns_t), dim=1)
        max_q_next[d_t] = 0.0
    target = r_t + gamma * max_q_next
    loss = nn.MSELoss()(q_s_a, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def feudal_train(env,
                 manager_episodes=100,
                 batch_size=32,
                 manager_lr=1e-3,
                 worker_lr=1e-3,
                 gamma_manager=0.9,
                 gamma_worker=0.9,
                 manager_eps_start=1.0,
                 worker_eps_start=1.0,
                 eps_min=0.1,
                 eps_decay=0.995):
    """
    Train manager and worker Q-networks jointly on ``env`` with epsilon-greedy DQN.

    In each manager segment:

    1. The manager picks a flat subgoal index, which is decoded to (x, y, z)
       and passed to ``env.set_subgoal``.
    2. The worker takes up to ``env.manager_update_freq`` steps toward it.
    3. The manager's reward is the change in ``env.manager_reward()`` over
       the segment.

    Each network is updated online from its own replay buffer once the buffer
    holds ``batch_size`` transitions. Both epsilons decay multiplicatively once
    per episode, floored at ``eps_min``.

    Returns:
        (manager_net, worker_net): the trained networks.

    Raises:
        ValueError: if ``env.manager_update_freq`` < 1. The worker would never
        step, so the episode could never finish.
    """
    if env.manager_update_freq < 1:
        raise ValueError("env.manager_update_freq must be >= 1")

    device = "cpu"
    manager_obs_dim = len(env._get_manager_obs())
    worker_obs_dim = len(env._get_worker_obs())
    manager_net = ManagerNet(manager_obs_dim, grid_size=env.grid_size).to(device)
    worker_net = WorkerNet(worker_obs_dim, 128, 6).to(device)

    manager_opt = optim.Adam(manager_net.parameters(), lr=manager_lr)
    worker_opt = optim.Adam(worker_net.parameters(), lr=worker_lr)

    manager_replay = ManagerReplayBuffer(5000)
    worker_replay = WorkerReplayBuffer(5000)

    manager_eps = manager_eps_start
    worker_eps = worker_eps_start

    n = env.grid_size
    n_subgoals = n ** 3

    for ep in range(manager_episodes):
        m_obs, _ = env.reset()
        done = False
        manager_return = 0.0
        while not done:
            # Manager picks a subgoal (epsilon-greedy over all voxels).
            manager_s = m_obs
            if random.random() < manager_eps:
                subgoal_id = random.randint(0, n_subgoals - 1)
            else:
                with torch.no_grad():
                    logits = manager_net(torch.FloatTensor(manager_s).unsqueeze(0))
                    subgoal_id = logits.argmax(dim=1).item()
            # Flat index -> (x, y, z) in C order: id = x*n*n + y*n + z.
            env.set_subgoal(np.array(np.unravel_index(subgoal_id, (n, n, n))))
            # The worker observation embeds the subgoal, so re-read it now.
            # Reusing the previous one would condition the worker's first
            # action (and its stored transition) on a stale subgoal.
            w_obs = env._get_worker_obs()

            old_manager_r = env.manager_reward()
            for _ in range(env.manager_update_freq):
                if env.done:
                    break
                worker_s = w_obs
                if random.random() < worker_eps:
                    action_worker = random.randint(0, 5)
                else:
                    with torch.no_grad():
                        qvals = worker_net(torch.FloatTensor(worker_s).unsqueeze(0))
                        action_worker = qvals.argmax(dim=1).item()
                w_obs_next, w_r, w_done = env.worker_step(action_worker)
                worker_replay.push(worker_s, action_worker, w_r, w_obs_next, w_done)
                w_obs = w_obs_next

                if len(worker_replay) >= batch_size:
                    _dqn_update(worker_net, worker_opt, worker_replay,
                                batch_size, gamma_worker)
                if w_done:
                    break

            # Manager transition: reward is the progress made during this segment.
            m_obs_next = env._get_manager_obs()
            delta_r = env.manager_reward() - old_manager_r
            manager_return += delta_r
            manager_done = env.manager_done()
            manager_replay.push(manager_s, subgoal_id, delta_r, m_obs_next, manager_done)
            m_obs = m_obs_next
            done = manager_done

            if len(manager_replay) >= batch_size:
                _dqn_update(manager_net, manager_opt, manager_replay,
                            batch_size, gamma_manager)

        manager_eps = max(eps_min, manager_eps * eps_decay)
        worker_eps = max(eps_min, worker_eps * eps_decay)
        print(f"Episode {ep+1}/{manager_episodes}, manager_return={manager_return:.3f}, epsM={manager_eps:.2f}, epsW={worker_eps:.2f}")

    return manager_net, worker_net

# -------------------------
# 5) 3D Voxel Visualization
# -------------------------
def create_3d_demo_gif(manager_net, worker_net, env, output_gif="feudal_3d_sculpt_demo.gif"):
    """
    Run one greedy episode and save a 3D voxel rendering of every worker step as a GIF.

    Both networks act greedily (argmax, no exploration). One frame is rendered
    via ``render_3d_voxel`` after each worker step. The GIF plays at 2 frames
    per second and loops forever.

    Raises:
        ValueError: if ``env.manager_update_freq`` < 1. The worker would never
        step, so the episode could never finish.
    """
    if env.manager_update_freq < 1:
        raise ValueError("env.manager_update_freq must be >= 1")

    n = env.grid_size
    frames = []
    m_obs, _ = env.reset()
    done = False

    while not done:
        # Manager picks a subgoal greedily; flat index -> (x, y, z) in C order.
        with torch.no_grad():
            logits = manager_net(torch.FloatTensor(m_obs).unsqueeze(0))
            subgoal_id = logits.argmax(dim=1).item()
        env.set_subgoal(np.array(np.unravel_index(subgoal_id, (n, n, n))))
        # Re-read the worker observation so it carries the subgoal just set.
        w_obs = env._get_worker_obs()

        # Up to env.manager_update_freq greedy worker steps.
        for _ in range(env.manager_update_freq):
            if env.done:
                break
            with torch.no_grad():
                qvals = worker_net(torch.FloatTensor(w_obs).unsqueeze(0))
                act = qvals.argmax(dim=1).item()
            w_obs, _, w_done = env.worker_step(act)
            frames.append(render_3d_voxel(env))
            if w_done:
                break

        m_obs = env._get_manager_obs()
        done = env.manager_done()

    # Current imageio routes GIFs through the Pillow plugin, which rejects the
    # old `fps` keyword. `duration` is milliseconds per frame (500 ms == 2 fps);
    # loop=0 means loop forever.
    imageio.mimsave(output_gif, frames, duration=500, loop=0)
    print(f"Demo run saved to {output_gif}")

def render_3d_voxel(env):
    """
    Render the environment's voxel grids and return the frame as an RGB image array.

    Colouring:
     - shape voxels that still have stock: opaque red-ish (1.0, 0.2, 0.2)
     - non-shape voxels that still have stock: half-transparent grey
     - voxels with no stock: not drawn
     - the router's voxel: opaque green (drawn over whatever is there)

    Returns:
        np.ndarray of dtype uint8 and shape (height, width, 3).
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    n = env.grid_size
    stock = np.asarray(env.stock, dtype=bool)
    shape = np.asarray(env.shape, dtype=bool)

    # RGBA per voxel. The default is transparent white (removed material);
    # the two masks below are disjoint.
    colors = np.zeros((n, n, n, 4), dtype=float)
    colors[..., :3] = 1.0
    colors[shape & stock] = [1.0, 0.2, 0.2, 1.0]
    colors[~shape & stock] = [0.5, 0.5, 0.5, 0.5]
    rx, ry, rz = env.router_pos
    colors[rx, ry, rz] = [0.0, 1.0, 0.0, 1.0]

    # ax.voxels draws only cells marked filled; anything with alpha > 0.
    filled = colors[..., 3] > 0.0

    fig = plt.figure(figsize=(8, 6))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.voxels(filled, facecolors=colors, edgecolor=None)
        ax.set_xlim(0, n)
        ax.set_ylim(0, n)
        ax.set_zlim(0, n)
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        ax.set_title("Feudal 3D Sculpt")

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which Matplotlib deprecated in
        # 3.8 and removed in 3.10. Its shape is the true pixel size, so this also
        # works on HiDPI canvases. Copy before the figure (which owns the buffer)
        # is closed.
        rgb = np.asarray(canvas.buffer_rgba())[:, :, :3].copy()
    finally:
        # Always release the figure so pyplot does not accumulate open figures.
        plt.close(fig)

    return rgb

# -------------------------
# 6) MAIN
# -------------------------
def main():
    """Train on an 8x8x8 grid for 40 episodes, then save a greedy demo GIF."""
    env_config = dict(grid_size=8, max_steps=60, manager_update_freq=5)
    env = Sculpt3DEnv(**env_config)
    trained_manager, trained_worker = feudal_train(env, manager_episodes=40)
    print("Creating final 3D sculpting demo with pretty voxel rendering ...")
    gif_path = "feudal_3d_sculpt_demo.gif"
    create_3d_demo_gif(trained_manager, trained_worker, env, gif_path)


if __name__ == "__main__":
    main()

Episode 1/60, manager_return=0.052, epsM=0.99, epsW=0.99
Episode 2/60, manager_return=0.000, epsM=0.99, epsW=0.99
Episode 3/60, manager_return=0.006, epsM=0.99, epsW=0.99
Episode 4/60, manager_return=0.008, epsM=0.98, epsW=0.98
Episode 5/60, manager_return=0.015, epsM=0.98, epsW=0.98
Episode 6/60, manager_return=0.000, epsM=0.97, epsW=0.97
Episode 7/60, manager_return=0.000, epsM=0.97, epsW=0.97
Episode 8/60, manager_return=0.000, epsM=0.96, epsW=0.96
Episode 9/60, manager_return=0.000, epsM=0.96, epsW=0.96
Episode 10/60, manager_return=0.000, epsM=0.95, epsW=0.95
Episode 11/60, manager_return=0.004, epsM=0.95, epsW=0.95
Episode 12/60, manager_return=0.004, epsM=0.94, epsW=0.94
Episode 13/60, manager_return=0.010, epsM=0.94, epsW=0.94
Episode 14/60, manager_return=0.000, epsM=0.93, epsW=0.93
Episode 15/60, manager_return=0.004, epsM=0.93, epsW=0.93
Episode 16/60, manager_return=0.015, epsM=0.92, epsW=0.92
Episode 17/60, manager_return=0.004, epsM=0.92, epsW=0.92
Episode 18/60, manager_

In [2]:
!pip3 install imageio

Collecting imageio
  Downloading imageio-2.37.0-py3-none-any.whl.metadata (5.2 kB)
Downloading imageio-2.37.0-py3-none-any.whl (315 kB)
Installing collected packages: imageio
Successfully installed imageio-2.37.0
