In [3]:
#!/usr/bin/env python3
"""
A complete Feudal Reinforcement Learning example from scratch, using a
3D milling environment as the training testbed. This code demonstrates:

1. A 3D voxel-based milling simulation (stock + shape).
2. A hierarchical RL setup (Feudal RL) with:
   - Manager Network: Produces a subgoal (3D router target or region).
   - Worker Network: Executes fine-grained actions to achieve that subgoal.
3. Two replay buffers (one for Manager, one for Worker).
4. Training loop with separate updates for Manager and Worker.
5. TensorBoard logging for monitoring progress.

DISCLAIMER:
- This code is a didactic, end-to-end *example*. It may need tuning,
  optimizations, or expansions (e.g., advanced replay strategies, 
  target networks, deeper architectures, etc.) to perform well in 
  real applications. But it lays out the essential pieces.
- The environment is simplified for demonstration; real CNC milling 
  involves more complex geometry and physics.
- The feudal approach here is fairly minimal; advanced FeUdal networks 
  or Option-Critic expansions can refine the concept.

USAGE:
    python feudal_milling.py

Then run:
    tensorboard --logdir=runs
to monitor training curves.

Enjoy!
"""

import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import matplotlib.pyplot as plt
import imageio
import os
import datetime

from torch.utils.tensorboard import SummaryWriter

########################################################
# 1) ENVIRONMENT DEFINITION (3D VOXEL MILLING)
########################################################

class Milling3DEnvFeudal:
    """
    A simplified 3D voxel milling environment for a manager/worker (feudal) agent.

    State:
    - stock[x,y,z] = 1 means material present, 0 means removed
    - shape[x,y,z] = 1 means protected shape (a sphere centred in the grid)
    - router_pos = (rx, ry, rz), always clamped to [0, grid_size-1]
    - subgoal = (sx, sy, sz), written by the manager through set_subgoal()

    Worker actions are 6 discrete unit moves: 0:+x, 1:-x, 2:+y, 3:-y, 4:+z, 5:-z.
    After every move the voxel under the router is cut if it still holds stock.

    NOTE: The manager does NOT act on the environment directly. Its subgoal is
    stored here only so that it shows up in the worker observation; choosing
    the subgoal is done externally in the training loop.
    """

    # Worker action id -> (axis index into router_pos, signed unit step).
    _ACTION_DELTAS = (
        (0, 1),   # 0: +x
        (0, -1),  # 1: -x
        (1, 1),   # 2: +y
        (1, -1),  # 3: -y
        (2, 1),   # 4: +z
        (2, -1),  # 5: -z
    )

    def __init__(
        self,
        grid_size=8,
        max_steps=200,
        penalty_cut_shape=10.0,
        reward_outside_cut=1.0,
        success_bonus=30.0,
        step_penalty=0.1,
        min_radius=2,
        max_radius=3
    ):
        """
        :param grid_size: NxNxN
        :param max_steps: episode terminates after these many worker steps
        :param penalty_cut_shape: penalty for removing shape voxel
        :param reward_outside_cut: reward for removing outside voxel
        :param success_bonus: reward if all outside is removed
        :param step_penalty: small penalty each step
        :param min_radius, max_radius: inclusive range the sphere radius is
            drawn from (via np.random.randint) on every reset
        """
        self.grid_size = grid_size
        self.max_steps = max_steps
        self.penalty_cut_shape = penalty_cut_shape
        self.reward_outside_cut = reward_outside_cut
        self.success_bonus = success_bonus
        self.step_penalty = step_penalty
        self.min_radius = min_radius
        self.max_radius = max_radius

        # Worker has 6 discrete actions
        self.worker_action_space = [0, 1, 2, 3, 4, 5]  # +x,-x,+y,-y,+z,-z
        self.reset()

    def reset(self):
        """
        Start a new episode: full stock, a freshly sampled sphere as the
        protected shape, router at the origin, subgoal at the grid centre.

        :return: (manager_obs, worker_obs)
        """
        n = self.grid_size
        self.stock = np.ones((n, n, n), dtype=int)

        # Sphere of random radius centred in the grid. The mask is computed
        # with broadcasting instead of a Python triple loop; the result is the
        # same set of voxels (squared distance <= radius**2).
        center = n // 2
        radius = np.random.randint(self.min_radius, self.max_radius + 1)
        xs, ys, zs = np.ogrid[:n, :n, :n]
        dist_sq = (xs - center) ** 2 + (ys - center) ** 2 + (zs - center) ** 2
        self.shape = (dist_sq <= radius ** 2).astype(int)

        # router at (0,0,0)
        self.router_pos = np.array([0, 0, 0], dtype=int)
        self.steps_taken = 0

        # subgoal (set by manager externally), initialize to center
        self.subgoal = np.array([center, center, center], dtype=int)

        # track done or not
        self.done = False

        return self._get_manager_obs(), self._get_worker_obs()

    def set_subgoal(self, subgoal):
        """
        Store the manager's subgoal, shape = (3,).
        Coordinates are clamped to the grid and cast to int.
        """
        self.subgoal = np.clip(subgoal, 0, self.grid_size - 1).astype(int)

    def worker_step(self, action):
        """
        Apply one discrete worker action in [0..5].

        The reward is the sum of:
          + reward_outside_cut   if an outside voxel was cut this step
          - penalty_cut_shape    if a shape voxel was cut (also ends the episode)
          - step_penalty         every step
          + success_bonus        when no outside stock is left (ends the episode)
          - 0.1 * ||router_pos - subgoal||   distance shaping towards the subgoal

        A move that would leave the grid is clamped, so the router stays put
        but the step is still counted. An action outside [0..5] moves nothing
        and is otherwise treated like any other step.

        :return: (worker_obs, worker_reward, done). Once the episode is done,
            further calls change nothing and return reward 0.0.
        """
        if self.done:
            # If environment is already done, no further changes
            return self._get_worker_obs(), 0.0, True

        # decode action (compared with == so any int-like value is accepted)
        for action_id, (axis, step) in enumerate(self._ACTION_DELTAS):
            if action == action_id:
                self.router_pos[axis] += step
                break

        # clamp
        self.router_pos = np.clip(self.router_pos, 0, self.grid_size - 1)
        rx, ry, rz = self.router_pos

        self.steps_taken += 1

        worker_reward = 0.0

        # cut the voxel under the router if material is still there
        if self.stock[rx, ry, rz] == 1:
            self.stock[rx, ry, rz] = 0
            if self.shape[rx, ry, rz] == 1:
                # shape => big penalty, episode fails
                worker_reward -= self.penalty_cut_shape
                self.done = True
            else:
                # outside => small positive
                worker_reward += self.reward_outside_cut

        # step penalty
        worker_reward -= self.step_penalty

        # check if all outside stock is removed
        outside_mask = (self.shape == 0)
        if np.sum(self.stock[outside_mask]) == 0:
            worker_reward += self.success_bonus
            self.done = True

        # check step limit
        if self.steps_taken >= self.max_steps:
            self.done = True

        # Intrinsic shaping: penalise the current distance to the subgoal
        # (not a before/after difference).
        dist = np.linalg.norm(self.router_pos - self.subgoal)
        worker_reward -= 0.1 * dist

        return self._get_worker_obs(), worker_reward, self.done

    def manager_reward(self):
        """
        Global progress score for the manager:

          R = 10 * (fraction of outside stock removed)
              - 5 * (number of shape voxels cut)

        This is an absolute measure of the current state; the training loop
        is expected to difference it across a manager step.
        """
        # count how many shape voxels are cut:
        shape_cut = (self.shape == 1) & (self.stock == 0)
        shape_cut_count = np.sum(shape_cut)

        # fraction outside removed (epsilon guards an all-shape grid):
        outside_mask = (self.shape == 0)
        outside_total = np.sum(outside_mask)
        outside_removed = outside_total - np.sum(self.stock[outside_mask])
        frac_outside_removed = outside_removed / (outside_total + 1e-8)

        return 10.0 * frac_outside_removed - 5.0 * shape_cut_count

    def manager_done(self):
        """
        The manager's episode ends exactly when the environment's does.
        """
        return self.done

    def _flat_grid_state(self):
        """
        Return [stock, shape, router_pos] as flat float32 arrays; shared by
        both observation builders.
        """
        return [
            self.stock.flatten().astype(np.float32),
            self.shape.flatten().astype(np.float32),
            self.router_pos.astype(np.float32),
        ]

    def _get_manager_obs(self):
        """
        Manager observation: flattened stock, flattened shape and router
        position concatenated into one float32 vector of length 2*N^3 + 3.
        This naive flatten is for demonstration only; it does not scale to
        large grids.
        """
        return np.concatenate(self._flat_grid_state(), axis=0)

    def _get_worker_obs(self):
        """
        Worker observation: the manager observation with the current subgoal
        appended (float32 vector of length 2*N^3 + 6).
        """
        parts = self._flat_grid_state()
        parts.append(self.subgoal.astype(np.float32))
        return np.concatenate(parts, axis=0)

########################################################
# 2) FEUDAL NETWORKS: MANAGER & WORKER ARCHITECTURES
########################################################

class ManagerNet(nn.Module):
    """
    Manager network: scores every possible subgoal cell of the grid.

    A subgoal is a single voxel coordinate (x, y, z) with each component in
    [0..grid_size-1]. The three coordinates are folded into one categorical
    index, so the network emits grid_size^3 values (512 for the default 8^3
    grid), one per candidate subgoal.

    The architecture is a plain MLP:
      input_dim -> hidden_dim -> hidden_dim -> grid_size^3
    with a ReLU after each hidden layer and no activation on the output.
    """
    def __init__(self, input_dim, grid_size=8, hidden_dim=256):
        super().__init__()
        self.grid_size = grid_size
        # one output per voxel: the subgoal is a single categorical choice
        self.output_dim = grid_size**3

        widths = [input_dim, hidden_dim, hidden_dim, self.output_dim]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # the output layer stays linear
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """
        x shape: [batch, input_dim]
        return: [batch, grid_size^3] (one raw score per subgoal index)
        """
        scores = self.net(x)
        return scores

class WorkerNet(nn.Module):
    """
    Worker network: scores the low-level router moves.

    The worker has 6 discrete actions (+x, -x, +y, -y, +z, -z) by default.
    The architecture is a plain MLP:
      input_dim -> hidden_dim -> hidden_dim -> n_actions
    with a ReLU after each hidden layer and no activation on the output.
    """
    def __init__(self, input_dim, n_actions=6, hidden_dim=256):
        super().__init__()
        self.n_actions = n_actions

        widths = [input_dim, hidden_dim, hidden_dim, n_actions]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # the output layer stays linear
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """
        x shape: [batch, input_dim]
        output: [batch, n_actions] (one raw score per discrete action)
        """
        scores = self.net(x)
        return scores

########################################################
# 3) REPLAY BUFFERS
########################################################

class ManagerReplayBuffer:
    """
    Fixed-capacity FIFO store of manager transitions
    (s, subgoal, reward, s_next, done).

    The manager's action is the subgoal expressed as a single discrete index
    in 0..grid_size^3 - 1. Once `capacity` transitions are held, pushing a new
    one drops the oldest.
    """
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def push(self, s, a, r, s_next, done):
        """Append one transition to the buffer."""
        transition = (s, a, r, s_next, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """
        Draw `batch_size` distinct transitions uniformly at random and return
        them column-wise as five numpy arrays:
        (states, actions, rewards, next_states, dones).
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = map(np.array, zip(*picked))
        return states, actions, rewards, next_states, dones


class WorkerReplayBuffer:
    """
    Fixed-capacity FIFO store of worker transitions
    (s, action, reward, s_next, done), with actions in [0..5].

    Once `capacity` transitions are held, pushing a new one drops the oldest.
    """
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def push(self, s, a, r, s_next, done):
        """Append one transition to the buffer."""
        transition = (s, a, r, s_next, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """
        Draw `batch_size` distinct transitions uniformly at random and return
        them column-wise as five numpy arrays:
        (states, actions, rewards, next_states, dones).
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = map(np.array, zip(*picked))
        return states, actions, rewards, next_states, dones

########################################################
# 4) TRAINING FUNCTIONS
########################################################

def manager_select_subgoal(manager_net, manager_state, grid_size, epsilon=0.1):
    """
    Epsilon-greedy selection of the manager's subgoal.

    With probability `epsilon` a uniformly random subgoal index is drawn;
    otherwise the index with the highest network output is taken.

    :param manager_net: module mapping [1, input_dim] -> [1, grid_size^3]
    :param manager_state: 1-D manager observation (array-like of floats)
    :param grid_size: side length N of the voxel grid
    :param epsilon: exploration probability
    :return: (subgoal_id, (x, y, z)) where subgoal_id = x*N^2 + y*N + z
    """
    if random.random() < epsilon:
        # random subgoal
        subgoal_id = random.randint(0, grid_size**3 - 1)
    else:
        # Build the input on the network's own device; a CPU-only tensor
        # would fail as soon as the network has been moved to an accelerator.
        first_param = next(manager_net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        with torch.no_grad():
            inp = torch.as_tensor(
                manager_state, dtype=torch.float32, device=device
            ).unsqueeze(0)  # [1, input_dim]
            logits = manager_net(inp)  # [1, grid_size^3]
            subgoal_id = logits.argmax(dim=1).item()

    # decode subgoal_id -> (x,y,z), z being the fastest-varying coordinate
    rest, z = divmod(subgoal_id, grid_size)
    x, y = divmod(rest, grid_size)
    x %= grid_size  # keeps x in range even for an out-of-range index
    return subgoal_id, (x, y, z)

def worker_select_action(worker_net, worker_state, n_actions=6, epsilon=0.1):
    """
    Epsilon-greedy selection of the worker's discrete action.

    With probability `epsilon` a uniformly random action in [0..n_actions-1]
    is returned; otherwise the action with the highest network output.

    :param worker_net: module mapping [1, input_dim] -> [1, n_actions]
    :param worker_state: 1-D worker observation (array-like of floats)
    :param n_actions: number of discrete actions (only used when exploring)
    :param epsilon: exploration probability
    :return: action index as a Python int
    """
    if random.random() < epsilon:
        return random.randint(0, n_actions - 1)

    # Build the input on the network's own device; a CPU-only tensor would
    # fail as soon as the network has been moved to an accelerator.
    first_param = next(worker_net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        inp = torch.as_tensor(
            worker_state, dtype=torch.float32, device=device
        ).unsqueeze(0)  # [1, input_dim]
        logits = worker_net(inp)  # [1, n_actions]
        return logits.argmax(dim=1).item()

def manager_update(
    manager_net,
    manager_optimizer,
    replay_buffer,
    batch_size,
    gamma=0.9,
    grid_size=8,
    device="cpu"
):
    """
    One DQN-style gradient step for the manager.

    The network outputs are treated as Q(s, subgoal), with the stored action
    being the subgoal index in [0..grid_size^3). The bootstrap target is
    r + gamma * max_a' Q(s', a'), with the bootstrap term zeroed for terminal
    transitions; the same network provides both the prediction and the target.
    The loss is the mean squared error between the two.

    Does nothing until the buffer holds at least `batch_size` transitions.
    `grid_size` is accepted for call-site symmetry but is not used here.
    """
    if len(replay_buffer) < batch_size:
        return

    states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

    # numpy batch -> tensors on the training device
    states_t = torch.FloatTensor(states).to(device)
    actions_t = torch.LongTensor(actions).to(device)
    rewards_t = torch.FloatTensor(rewards).to(device)
    next_states_t = torch.FloatTensor(next_states).to(device)
    terminal_t = torch.BoolTensor(dones).to(device)

    # Q(s, a) for the subgoal that was actually chosen
    q_all = manager_net(states_t)  # [batch, grid_size^3]
    q_taken = q_all.gather(1, actions_t.unsqueeze(1)).squeeze(1)

    # bootstrap value of the next state; no gradient flows through the target
    with torch.no_grad():
        best_next = torch.max(manager_net(next_states_t), dim=1)[0]
        best_next[terminal_t] = 0.0

    td_target = rewards_t + gamma * best_next

    criterion = nn.MSELoss()
    loss = criterion(q_taken, td_target)

    manager_optimizer.zero_grad()
    loss.backward()
    manager_optimizer.step()

def worker_update(
    worker_net,
    worker_optimizer,
    replay_buffer,
    batch_size,
    gamma=0.9,
    n_actions=6,
    device="cpu"
):
    """
    One DQN-style gradient step for the worker.

    The network outputs are treated as Q(s, action). The bootstrap target is
    r + gamma * max_a' Q(s', a'), with the bootstrap term zeroed for terminal
    transitions; the same network provides both the prediction and the target.
    The loss is the mean squared error between the two.

    Does nothing until the buffer holds at least `batch_size` transitions.
    `n_actions` is accepted for call-site symmetry but is not used here.
    """
    if len(replay_buffer) < batch_size:
        return

    states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

    # numpy batch -> tensors on the training device
    states_t = torch.FloatTensor(states).to(device)
    actions_t = torch.LongTensor(actions).to(device)
    rewards_t = torch.FloatTensor(rewards).to(device)
    next_states_t = torch.FloatTensor(next_states).to(device)
    terminal_t = torch.BoolTensor(dones).to(device)

    # Q(s, a) for the action that was actually taken
    q_all = worker_net(states_t)  # [batch, n_actions]
    q_taken = q_all.gather(1, actions_t.unsqueeze(1)).squeeze(1)

    # bootstrap value of the next state; no gradient flows through the target
    with torch.no_grad():
        best_next = torch.max(worker_net(next_states_t), dim=1)[0]
        best_next[terminal_t] = 0.0

    td_target = rewards_t + gamma * best_next

    criterion = nn.MSELoss()
    loss = criterion(q_taken, td_target)

    worker_optimizer.zero_grad()
    loss.backward()
    worker_optimizer.step()

########################################################
# 5) MAIN TRAINING LOOP
########################################################

def train_feudal_milling(
    num_episodes=1000,
    grid_size=8,
    manager_update_freq=10,  # T steps
    manager_gamma=0.9,
    worker_gamma=0.9,
    lr_manager=1e-3,
    lr_worker=1e-3,
    batch_size=32,
    manager_epsilon_start=1.0,
    worker_epsilon_start=1.0,
    manager_epsilon_end=0.1,
    worker_epsilon_end=0.1,
    manager_epsilon_decay=0.995,
    worker_epsilon_decay=0.995,
    device="cpu"
):
    """
    A full training loop for Feudal RL on the 3D milling environment.

    Each episode is a sequence of manager "macro-steps". In one macro-step the
    manager picks a subgoal, the worker takes up to `manager_update_freq`
    primitive steps (one worker update after each), and the manager is then
    rewarded with the change in env.manager_reward() over that window and
    updated once.

    :param num_episodes: number of episodes to train for
    :param grid_size: side length of the voxel grid
    :param manager_update_freq: worker steps per manager subgoal (must be >= 1)
    :param manager_gamma, worker_gamma: discount factors for the two updates
    :param lr_manager, lr_worker: Adam learning rates
    :param batch_size: replay batch size for both levels
    :param *_epsilon_start / *_epsilon_end / *_epsilon_decay: per-episode
        multiplicative epsilon schedule for each level
    :param device: torch device the networks and batches live on
    :return: (manager_net, worker_net)
    :raises ValueError: if manager_update_freq < 1 (the episode loop could
        otherwise never make progress)
    """
    if manager_update_freq < 1:
        raise ValueError("manager_update_freq must be >= 1")

    env = Milling3DEnvFeudal(grid_size=grid_size)
    # manager / worker input dims
    manager_input_dim = env._get_manager_obs().shape[0]
    worker_input_dim = env._get_worker_obs().shape[0]
    manager_net = ManagerNet(manager_input_dim, grid_size=grid_size).to(device)
    worker_net = WorkerNet(worker_input_dim, n_actions=6).to(device)

    manager_optimizer = optim.Adam(manager_net.parameters(), lr=lr_manager)
    worker_optimizer = optim.Adam(worker_net.parameters(), lr=lr_worker)

    manager_replay = ManagerReplayBuffer(capacity=50000)
    worker_replay = WorkerReplayBuffer(capacity=50000)

    manager_epsilon = manager_epsilon_start
    worker_epsilon = worker_epsilon_start

    # TensorBoard
    log_dir = os.path.join("runs", f"feudal_milling_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
    writer = SummaryWriter(log_dir=log_dir)

    # try/finally so the event file is flushed and closed even if training
    # is interrupted or raises.
    try:
        for ep in range(num_episodes):
            m_obs, _ = env.reset()

            done = False
            episode_reward_manager = 0.0
            episode_reward_worker = 0.0
            manager_steps = 0

            # choose initial subgoal
            m_a_id, subgoal_xyz = manager_select_subgoal(manager_net, m_obs, grid_size, epsilon=manager_epsilon)
            env.set_subgoal(np.array(subgoal_xyz))
            # The worker observation embeds the subgoal, so it has to be
            # rebuilt after every set_subgoal(); otherwise the first step of
            # the macro-step would act on (and store) the previous subgoal.
            w_obs = env._get_worker_obs()

            while not done:
                # "Macro-step": after T worker actions (or termination) the
                # manager sees the new state and gets its reward.
                manager_s = m_obs

                # global score before the worker acts, to difference later
                old_manager_reward = env.manager_reward()

                # run T worker steps (or until done)
                for _ in range(manager_update_freq):
                    if env.done:
                        break
                    w_s = w_obs
                    action_worker = worker_select_action(worker_net, w_s, n_actions=6, epsilon=worker_epsilon)
                    w_s_next, worker_r, w_done = env.worker_step(action_worker)
                    w_obs = w_s_next
                    episode_reward_worker += worker_r

                    # store in worker replay
                    worker_replay.push(w_s, action_worker, worker_r, w_s_next, w_done)

                    # update worker
                    worker_update(
                        worker_net,
                        worker_optimizer,
                        worker_replay,
                        batch_size=batch_size,
                        gamma=worker_gamma,
                        n_actions=6,
                        device=device
                    )

                    if w_done:
                        break

                # after T steps or done, manager sees new state
                m_obs_next = env._get_manager_obs()
                manager_r = env.manager_reward()
                manager_done = env.manager_done()

                # The manager's step reward is the change in the global score
                # over this macro-step.
                delta_manager_r = manager_r - old_manager_reward
                episode_reward_manager += delta_manager_r

                # store manager transition
                manager_replay.push(manager_s, m_a_id, delta_manager_r, m_obs_next, manager_done)

                # manager update
                manager_update(
                    manager_net,
                    manager_optimizer,
                    manager_replay,
                    batch_size=batch_size,
                    gamma=manager_gamma,
                    grid_size=grid_size,
                    device=device
                )

                # next manager state
                m_obs = m_obs_next
                manager_steps += 1

                done = manager_done
                if not done:
                    # pick a new subgoal and refresh the worker's view of it
                    m_a_id, subgoal_xyz = manager_select_subgoal(manager_net, m_obs, grid_size, epsilon=manager_epsilon)
                    env.set_subgoal(np.array(subgoal_xyz))
                    w_obs = env._get_worker_obs()

            # end of episode: decay epsilon
            manager_epsilon = max(manager_epsilon_end, manager_epsilon * manager_epsilon_decay)
            worker_epsilon = max(worker_epsilon_end, worker_epsilon * worker_epsilon_decay)

            episode_return = episode_reward_worker + episode_reward_manager

            writer.add_scalar("Episode/ManagerReturn", episode_reward_manager, ep)
            writer.add_scalar("Episode/WorkerReturn", episode_reward_worker, ep)
            writer.add_scalar("Episode/TotalReturn", episode_return, ep)
            writer.add_scalar("Epsilon/Manager", manager_epsilon, ep)
            writer.add_scalar("Epsilon/Worker", worker_epsilon, ep)
            writer.add_scalar("Manager/EpisodeSteps", manager_steps, ep)

            print(f"Episode {ep+1}/{num_episodes}, ManagerR={episode_reward_manager:.2f}, "
                  f"WorkerR={episode_reward_worker:.2f}, EpsM={manager_epsilon:.2f}, EpsW={worker_epsilon:.2f}")
    finally:
        writer.close()

    print("Training complete.")
    return manager_net, worker_net

########################################################
# 6) DEMO / GIF GENERATION (OPTIONAL)
########################################################

def create_3d_demo_gif(manager_net, worker_net, grid_size=8, max_steps=300, output_gif="feudal_milling_demo.gif",
                       manager_update_freq=10):
    """
    Run one greedy episode (epsilon = 0 for both levels) with the given
    manager and worker networks and save it as an animated GIF.

    Despite the name, no 3D rendering is done: each frame is a text picture
    of the z=0 slice of the grid, drawn with Pillow's default font:
      R = router (only when it is in the z=0 plane)
      S = uncut shape voxel, O = uncut outside stock, . = removed

    One frame is captured before every worker step, plus a final frame.

    :param manager_net: trained manager network
    :param worker_net: trained worker network
    :param grid_size: side length of the voxel grid
    :param max_steps: step limit passed to the environment and to this loop
    :param output_gif: path of the GIF to write
    :param manager_update_freq: worker steps per manager subgoal (must be >= 1)
    :raises ValueError: if manager_update_freq < 1
    """
    # Pillow is only needed for this optional demo, so it is imported locally.
    from PIL import Image, ImageDraw, ImageFont

    if manager_update_freq < 1:
        raise ValueError("manager_update_freq must be >= 1")

    env = Milling3DEnvFeudal(grid_size=grid_size, max_steps=max_steps)
    m_obs, _ = env.reset()
    frames = []

    # act greedily during the demo
    manager_epsilon = 0.0
    worker_epsilon = 0.0

    def snapshot_text(env):
        # Text picture of the z=0 slice, one text line per y, one char per x.
        slice_2d = env.stock[:, :, 0]
        shape_2d = env.shape[:, :, 0]
        router_in_slice = env.router_pos[2] == 0
        router_xy = (env.router_pos[0], env.router_pos[1])
        lines = []
        for y in range(env.grid_size):
            row = []
            for x in range(env.grid_size):
                if router_in_slice and (x, y) == router_xy:
                    row.append("R")
                elif slice_2d[x, y] == 1 and shape_2d[x, y] == 1:
                    row.append("S")  # shape uncut
                elif slice_2d[x, y] == 1 and shape_2d[x, y] == 0:
                    row.append("O")  # outside stock
                else:
                    row.append(".")  # empty
            lines.append("".join(row))
        return "\n".join(lines)

    done = False

    # initial subgoal
    _, subgoal_xyz = manager_select_subgoal(manager_net, m_obs, grid_size, epsilon=manager_epsilon)
    env.set_subgoal(np.array(subgoal_xyz))
    # The worker observation embeds the subgoal, so rebuild it after every
    # set_subgoal(); otherwise the worker would act on the previous subgoal.
    w_obs = env._get_worker_obs()

    step_count = 0
    while not done and step_count < max_steps:
        # run T worker steps
        for _ in range(manager_update_freq):
            if env.done:
                break
            # record frame as text
            frames.append(snapshot_text(env))

            action_worker = worker_select_action(worker_net, w_obs, epsilon=worker_epsilon)
            w_obs, _, done_worker = env.worker_step(action_worker)
            step_count += 1

            if done_worker:
                break

        m_obs = env._get_manager_obs()
        done = env.manager_done()

        # pick new subgoal
        if not done:
            _, subgoal_xyz = manager_select_subgoal(manager_net, m_obs, grid_size, epsilon=manager_epsilon)
            env.set_subgoal(np.array(subgoal_xyz))
            w_obs = env._get_worker_obs()

    # final frame
    frames.append(snapshot_text(env))

    # Convert text frames to images: black text on a white background.
    font = ImageFont.load_default()

    def text_to_image(txt):
        lines = txt.split("\n")
        w = max(len(line) for line in lines)
        h = len(lines)
        # fixed cell size per character: 10x15 pixels
        char_w, char_h = 10, 15
        img = Image.new("RGB", (char_w * w, char_h * h), color="white")
        d = ImageDraw.Draw(img)
        for i, line in enumerate(lines):
            d.text((0, i * char_h), line, font=font, fill=(0, 0, 0))
        return img

    rendered_frames = [text_to_image(f_txt) for f_txt in frames]

    rendered_frames[0].save(
        output_gif,
        save_all=True,
        append_images=rendered_frames[1:],
        duration=500,
        loop=0
    )
    print(f"Demo GIF saved: {output_gif}")


########################################################
# 7) RUN EXAMPLE
########################################################

if __name__ == "__main__":
    # Script entry point: train both levels on an 8x8x8 grid, then replay one
    # greedy episode with the trained networks and save it as a GIF.
    # 1000 episodes is a demonstration budget; increase for real training.
    manager_net, worker_net = train_feudal_milling(num_episodes=1000, grid_size=8)

    create_3d_demo_gif(
        manager_net,
        worker_net,
        grid_size=8,
        max_steps=300,
        output_gif="feudal_milling_demo.gif",
    )

    print("Done.")

Episode 1/1000, ManagerR=-4.69, WorkerR=-8.68, EpsM=0.99, EpsW=0.99
Episode 2/1000, ManagerR=-4.54, WorkerR=-20.95, EpsM=0.99, EpsW=0.99
Episode 3/1000, ManagerR=-3.66, WorkerR=-28.11, EpsM=0.99, EpsW=0.99
Episode 4/1000, ManagerR=-3.89, WorkerR=-77.10, EpsM=0.98, EpsW=0.98
Episode 5/1000, ManagerR=-4.61, WorkerR=-19.36, EpsM=0.98, EpsW=0.98
Episode 6/1000, ManagerR=-3.97, WorkerR=-54.69, EpsM=0.97, EpsW=0.97
Episode 7/1000, ManagerR=-4.61, WorkerR=-13.31, EpsM=0.97, EpsW=0.97
Episode 8/1000, ManagerR=-2.93, WorkerR=-13.76, EpsM=0.96, EpsW=0.96
Episode 9/1000, ManagerR=-4.51, WorkerR=-20.97, EpsM=0.96, EpsW=0.96
Episode 10/1000, ManagerR=1.46, WorkerR=-54.31, EpsM=0.95, EpsW=0.95
Episode 11/1000, ManagerR=-4.21, WorkerR=-31.37, EpsM=0.95, EpsW=0.95
Episode 12/1000, ManagerR=-4.69, WorkerR=-7.39, EpsM=0.94, EpsW=0.94
Episode 13/1000, ManagerR=-3.56, WorkerR=-40.10, EpsM=0.94, EpsW=0.94
Episode 14/1000, ManagerR=-4.56, WorkerR=-12.18, EpsM=0.93, EpsW=0.93
Episode 15/1000, ManagerR=-3.35,

In [28]:
#!/usr/bin/env python3
"""
Feudal RL Example (Line-Based Milling):
  - Worker chooses a destination coordinate in [0..grid_size^3-1].
  - We remove stock along the line from the current router position to the target.
    If shape or out-of-bounds is encountered, episode ends immediately in failure.
  - We track:
    - Worker moves (worker_move_count)
    - Fraction of outside stock removed
    - Total outside removed
  - We store separate experiences for Manager and Worker in their own replay buffers.
  - The manager picks subgoals every `manager_update_freq` worker steps.

Usage:
  python feudal_milling.py
Then:
  tensorboard --logdir=runs
"""

import os
import random
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from torch.utils.tensorboard import SummaryWriter

########################################################################
# 1) 3D Milling Environment
########################################################################

class Milling3DEnvFeudal:
    """
    3D voxel milling environment for a Feudal (manager/worker) RL approach.

    State:
      - stock[x,y,z] = 1 => material present, 0 => removed
      - shape[x,y,z] = 1 => protected shape (a sphere centred in the grid)
      - router_pos   => current router position in the grid
      - subgoal      => coordinate most recently set by the manager

    Worker action:
      A flat index in [0..grid_size^3 - 1] that is decoded into a target
      (x,y,z). Stock is removed along the voxel path from router_pos to the
      target. Hitting a shape voxel or leaving the grid ends the episode.

    Bookkeeping:
      worker_move_count and initial_outside_count are tracked so callers can
      report how many moves were made and how much outside stock was removed.
    """

    def __init__(
        self,
        grid_size=8,
        max_steps=200,
        penalty_cut_shape=10.0,
        reward_outside_cut=1.0,
        success_bonus=30.0,
        step_penalty=0.1,
        min_radius=2,
        max_radius=3
    ):
        """
        Store the configuration and immediately build the first episode.

        Args:
            grid_size: edge length of the cubic voxel grid.
            max_steps: maximum number of worker steps per episode.
            penalty_cut_shape: reward subtracted when a path hits the shape.
            reward_outside_cut: reward added per stock voxel removed.
            success_bonus: reward added once no outside stock remains.
            step_penalty: reward subtracted per voxel traversed.
            min_radius: smallest sphere radius drawn in reset() (inclusive).
            max_radius: largest sphere radius drawn in reset() (inclusive).
        """
        self.grid_size = grid_size
        self.max_steps = max_steps
        self.penalty_cut_shape = penalty_cut_shape
        self.reward_outside_cut = reward_outside_cut
        self.success_bonus = success_bonus
        self.step_penalty = step_penalty
        self.min_radius = min_radius
        self.max_radius = max_radius

        self.worker_move_count = 0
        self.initial_outside_count = 0

        # Build stock/shape/router state right away so the observation
        # helpers can be used without an explicit reset() call.
        self.reset()

    def reset(self):
        """
        Start a new episode.

        Refills the stock, draws a sphere of random radius in
        [min_radius, max_radius] centred in the grid, puts the router at
        (0,0,0) and sets the subgoal to the grid centre.

        Returns:
            (manager_obs, worker_obs) for the fresh episode.
        """
        n = self.grid_size

        # Full block of stock.
        self.stock = np.ones((n, n, n), dtype=int)

        # Sphere in the centre, built as a vectorised mask instead of a
        # triple Python loop (same voxels, same single RNG draw).
        cx = cy = cz = n // 2
        r = np.random.randint(self.min_radius, self.max_radius + 1)
        xs, ys, zs = np.ogrid[:n, :n, :n]
        dist_sq = (xs - cx) ** 2 + (ys - cy) ** 2 + (zs - cz) ** 2
        self.shape = (dist_sq <= r ** 2).astype(int)

        self.router_pos = np.array([0, 0, 0], dtype=int)
        self.steps_taken = 0
        self.done = False
        self.worker_move_count = 0

        # Default subgoal until the manager sets one.
        self.subgoal = np.array([cx, cy, cz], dtype=int)

        # Number of voxels outside the shape at episode start.
        self.initial_outside_count = int(np.sum(self.shape == 0))

        return self._get_manager_obs(), self._get_worker_obs()

    def set_subgoal(self, subgoal):
        """
        Called by the manager. The subgoal is clamped to [0..grid_size-1]
        on every axis and stored as an integer array.
        """
        subgoal = np.clip(subgoal, 0, self.grid_size - 1)
        self.subgoal = subgoal.astype(int)

    def _in_bounds(self, voxel):
        """Return True when every coordinate of `voxel` lies inside the grid."""
        return all(0 <= c < self.grid_size for c in voxel)

    def worker_step(self, action_idx):
        """
        Apply one worker action.

        `action_idx` in [0..grid_size^3 - 1] is decoded into (x,y,z) and the
        router is driven along the voxel path from router_pos to that target,
        removing stock on the way. The episode ends when the path leaves the
        grid or hits the shape, when all outside stock is gone, or when
        max_steps is reached.

        Returns:
            (worker_obs, reward, done)
        """
        if self.done:
            return self._get_worker_obs(), 0.0, True

        # Decode the flat index (z varies fastest).
        n = self.grid_size
        z = action_idx % n
        y = (action_idx // n) % n
        x = (action_idx // (n * n)) % n
        target_pos = np.array([x, y, z], dtype=int)

        path_voxels = line_3d_voxels(self.router_pos, target_pos)

        total_reward = 0.0
        terminated = False

        for voxel in path_voxels:
            # All three axes must be validated: a negative index would wrap
            # around silently and a too-large one would raise IndexError.
            if not self._in_bounds(voxel):
                total_reward -= 999.0  # big penalty
                terminated = True
                break

            vx, vy, vz = voxel

            # Touching the protected shape ends the episode. Stock already
            # removed earlier on this path stays removed and the router does
            # not move.
            if self.shape[vx, vy, vz] == 1:
                total_reward -= self.penalty_cut_shape
                terminated = True
                break

            # Remove stock if present.
            if self.stock[vx, vy, vz] == 1:
                self.stock[vx, vy, vz] = 0
                total_reward += self.reward_outside_cut

            # Small penalty for each voxel traversed.
            total_reward -= self.step_penalty

        if not terminated:
            # The whole path was clear, so the router reaches the target.
            self.router_pos = target_pos

            # Success: no stock left outside the shape.
            outside_mask = (self.shape == 0)
            if np.sum(self.stock[outside_mask]) == 0:
                total_reward += self.success_bonus
                terminated = True

        # Shaping term: penalise distance between router and subgoal.
        dist = np.linalg.norm(self.router_pos - self.subgoal)
        total_reward -= 0.1 * dist

        self.steps_taken += 1
        self.worker_move_count += 1

        if self.steps_taken >= self.max_steps:
            terminated = True

        if terminated:
            self.done = True

        return self._get_worker_obs(), total_reward, self.done

    def manager_reward(self):
        """
        Manager reward: 10 * fraction of outside stock removed, minus
        5 * number of shape voxels whose stock has been removed.

        Note: worker_step stops before removing stock at a shape voxel, so
        the second term stays 0 unless stock is modified externally.
        """
        shape_cut = (self.shape == 1) & (self.stock == 0)
        cut_count = np.sum(shape_cut)

        outside_mask = (self.shape == 0)
        outside_total = np.sum(outside_mask)
        outside_removed = outside_total - np.sum(self.stock[outside_mask])
        # The epsilon guards against division by zero on a grid with no
        # outside voxels.
        frac_outside_removed = outside_removed / (outside_total + 1e-8)

        r = 10.0 * frac_outside_removed - 5.0 * cut_count
        return r

    def manager_done(self):
        """Return the episode-done flag (shared with the worker)."""
        return self.done

    def fraction_outside_removed(self):
        """
        Fraction of the initial outside stock that has been removed.
        """
        return self.total_outside_removed() / (self.initial_outside_count + 1e-8)

    def total_outside_removed(self):
        """
        Number of outside voxels that have been removed this episode.
        """
        outside_mask = (self.shape == 0)
        current_outside = np.sum(self.stock[outside_mask])
        return self.initial_outside_count - current_outside

    def _get_manager_obs(self):
        """
        Manager sees the entire stock+shape plus router coords, flattened
        into one float32 vector of length 2*grid_size^3 + 3.
        """
        stock_f = self.stock.flatten().astype(np.float32)
        shape_f = self.shape.flatten().astype(np.float32)
        router_f = self.router_pos.astype(np.float32)
        return np.concatenate([stock_f, shape_f, router_f], axis=0)

    def _get_worker_obs(self):
        """
        Worker sees the entire stock+shape plus router coords & subgoal,
        flattened into one float32 vector of length 2*grid_size^3 + 6.
        """
        stock_f = self.stock.flatten().astype(np.float32)
        shape_f = self.shape.flatten().astype(np.float32)
        router_f = self.router_pos.astype(np.float32)
        subgoal_f = self.subgoal.astype(np.float32)
        return np.concatenate([stock_f, shape_f, router_f, subgoal_f], axis=0)


def line_3d_voxels(start, end):
    """
    List the voxels visited when walking in a straight line start -> end.

    The segment is sampled at 2 * floor(euclidean length) evenly spaced
    steps (a naive stand-in for Bresenham), each sample is rounded to the
    nearest voxel, and repeated voxels are dropped while keeping the order
    of first appearance.

    Args:
        start: integer numpy array (x, y, z) of the first voxel.
        end: integer numpy array (x, y, z) of the last voxel.

    Returns:
        A list of coordinate tuples. When the floored length is 0 the list
        holds only `start`.
    """
    origin = start.astype(float)
    delta = end.astype(float) - origin

    whole_length = int(np.linalg.norm(delta))
    if whole_length == 0:
        return [tuple(start)]

    n_steps = max(1, whole_length * 2)
    samples = (
        tuple(np.round(origin + delta * (i / n_steps)).astype(int))
        for i in range(n_steps + 1)
    )
    # dict keys keep insertion order, giving an order-preserving dedup.
    return list(dict.fromkeys(samples))

########################################################################
# 2) Manager & Worker Networks + Replay
########################################################################

class ManagerNet(nn.Module):
    """
    Manager value network: a three-layer MLP that maps a flat observation
    to one score per grid cell (grid_size^3 outputs), i.e. one score per
    candidate subgoal id.
    """

    def __init__(self, input_dim, grid_size=8, hidden_dim=256):
        super().__init__()
        self.grid_size = grid_size
        self.output_dim = grid_size ** 3

        # Two hidden ReLU layers followed by a linear output head.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the [batch, grid_size^3] score tensor for observations `x`."""
        scores = self.net(x)
        return scores

class WorkerNet(nn.Module):
    """
    Worker value network: a three-layer MLP that maps a flat observation
    to one score per worker action (grid_size^3 outputs, one per target
    voxel index).
    """

    def __init__(self, input_dim, grid_size=8, hidden_dim=256):
        super().__init__()
        self.grid_size = grid_size
        self.output_dim = grid_size ** 3

        # input -> hidden -> hidden -> per-action scores
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.output_dim),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the [batch, grid_size^3] score tensor for observations `x`."""
        action_scores = self.net(x)
        return action_scores

class ManagerReplayBuffer:
    """
    Fixed-capacity FIFO store of manager transitions
    (state, subgoal id, reward, next state, done).
    """

    def __init__(self, capacity=50000):
        # A bounded deque drops the oldest transition once full.
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s_next, done):
        """Append one transition."""
        transition = (s, a, r, s_next, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """
        Draw `batch_size` distinct transitions uniformly at random and
        return them column-wise as five numpy arrays:
        (states, actions, rewards, next_states, dones).
        """
        picked = random.sample(self.buffer, batch_size)
        columns = zip(*picked)
        return tuple(np.array(column) for column in columns)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

class WorkerReplayBuffer:
    """
    Fixed-capacity FIFO store of worker transitions
    (state, action index, reward, next state, done).
    """

    def __init__(self, capacity=50000):
        # Oldest entries are evicted automatically once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s_next, done):
        """Record a single transition."""
        self.buffer.append((s, a, r, s_next, done))

    def sample(self, batch_size):
        """
        Pick `batch_size` distinct transitions uniformly at random and
        return (states, actions, rewards, next_states, dones) as numpy arrays.
        """
        minibatch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards),
            np.array(next_states),
            np.array(dones),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

########################################################################
# 3) Action Selection and Update Functions
########################################################################

def manager_select_subgoal(manager_net, manager_state, grid_size, epsilon=0.1):
    """
    Pick a subgoal with an epsilon-greedy policy.

    With probability `epsilon` a uniformly random subgoal id in
    [0..grid_size^3 - 1] is drawn; otherwise the id with the highest
    network output for `manager_state` is taken.

    Returns:
        (subgoal_id, (x, y, z)) where the coordinates are the decoded id
        (z varies fastest, then y, then x).
    """
    explore = random.random() < epsilon
    if explore:
        subgoal_id = random.randint(0, grid_size**3 - 1)
    else:
        with torch.no_grad():
            batch = torch.FloatTensor(manager_state).unsqueeze(0)
            scores = manager_net(batch)  # shape [1, grid_size^3]
            subgoal_id = scores.argmax(dim=1).item()

    # Decode the flat id into grid coordinates.
    plane, z = divmod(subgoal_id, grid_size)
    x, y = divmod(plane, grid_size)
    x %= grid_size
    return subgoal_id, (x, y, z)

def worker_select_action(worker_net, worker_state, n_actions, epsilon=0.1):
    """
    Epsilon-greedy choice of the worker's action index.

    With probability `epsilon` a uniformly random index in
    [0..n_actions - 1] is returned; otherwise the index with the highest
    network output for `worker_state`.
    """
    # Exploration branch first, as a guard clause.
    if random.random() < epsilon:
        return random.randint(0, n_actions - 1)

    with torch.no_grad():
        batch = torch.FloatTensor(worker_state).unsqueeze(0)
        scores = worker_net(batch)  # shape [1, n_actions]
        best = scores.argmax(dim=1)
    return best.item()

def manager_update(net, optimizer, replay, batch_size, gamma=0.9):
    """
    Run one Q-learning gradient step for the manager network.

    Does nothing until `replay` holds at least `batch_size` transitions.
    The bootstrap target is r + gamma * max_a' Q(s', a'), with the
    bootstrap term zeroed for terminal transitions. The same network
    provides both the prediction and the target (no target network).
    """
    if len(replay) < batch_size:
        return

    states, actions, rewards, next_states, dones = replay.sample(batch_size)
    states_t = torch.FloatTensor(states)
    actions_t = torch.LongTensor(actions)
    rewards_t = torch.FloatTensor(rewards)
    next_states_t = torch.FloatTensor(next_states)
    dones_t = torch.BoolTensor(dones)

    # Q(s, a) for the subgoal actually taken; net output is [batch, grid_size^3].
    q_taken = net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        best_next = net(next_states_t).max(dim=1).values
        # Terminal transitions have no future value.
        best_next = best_next.masked_fill(dones_t, 0.0)
    td_target = rewards_t + gamma * best_next

    criterion = nn.MSELoss()
    loss = criterion(q_taken, td_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

def worker_update(net, optimizer, replay, batch_size, gamma=0.9, n_actions=8**3):
    """
    Run one Q-learning gradient step for the worker network.

    Does nothing until `replay` holds at least `batch_size` transitions.
    The bootstrap target is r + gamma * max_a' Q(s', a'), with the
    bootstrap term zeroed for terminal transitions. The same network
    provides both the prediction and the target (no target network).

    `n_actions` is accepted for interface compatibility but is not used.
    """
    if len(replay) < batch_size:
        return

    obs, acts, rews, next_obs, terminal = replay.sample(batch_size)
    obs_t = torch.FloatTensor(obs)
    acts_t = torch.LongTensor(acts)
    rews_t = torch.FloatTensor(rews)
    next_obs_t = torch.FloatTensor(next_obs)
    terminal_t = torch.BoolTensor(terminal)

    # Q(s, a) for the action actually taken; net output is [batch, n_actions].
    q_pred = net(obs_t).gather(1, acts_t.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        q_next_best = net(next_obs_t).max(dim=1).values
        # No bootstrapping past the end of an episode.
        q_next_best = q_next_best.masked_fill(terminal_t, 0.0)
    q_target = rews_t + gamma * q_next_best

    mse = nn.MSELoss()
    loss = mse(q_pred, q_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

########################################################################
# 4) Main Training Loop
########################################################################

def train_feudal_milling(
    num_episodes=50,
    grid_size=8,
    manager_update_freq=10,
    lr_manager=1e-3,
    lr_worker=1e-3,
    batch_size=32,
    manager_gamma=0.9,
    worker_gamma=0.9,
    manager_eps_start=1.0,
    manager_eps_end=0.1,
    manager_eps_decay=0.995,
    worker_eps_start=1.0,
    worker_eps_end=0.1,
    worker_eps_decay=0.995
):
    """
    Train a manager/worker pair on the 3D milling environment.

    The manager picks a subgoal, the worker then acts for up to
    `manager_update_freq` steps, and both networks are updated from their
    own replay buffers. Per-episode statistics go to TensorBoard under
    runs/feudal_milling_<timestamp> and are printed to stdout.

    Args:
        num_episodes: number of training episodes.
        grid_size: edge length of the voxel grid.
        manager_update_freq: worker steps per manager decision.
        lr_manager, lr_worker: Adam learning rates.
        batch_size: minibatch size for both updates.
        manager_gamma, worker_gamma: discount factors.
        manager_eps_*, worker_eps_*: epsilon-greedy schedule; epsilon is
            multiplied by the decay after each episode and floored at the
            end value.

    Returns:
        (manager_net, worker_net) after training.
    """
    log_dir = os.path.join("runs", f"feudal_milling_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
    writer = SummaryWriter(log_dir=log_dir)

    try:
        # Create environment and get observation dims
        env = Milling3DEnvFeudal(grid_size=grid_size)
        manager_in_dim = env._get_manager_obs().shape[0]
        worker_in_dim = env._get_worker_obs().shape[0]

        # Networks
        manager_net = ManagerNet(manager_in_dim, grid_size=grid_size)
        worker_net = WorkerNet(worker_in_dim, grid_size=grid_size)

        # Optimizers
        manager_opt = optim.Adam(manager_net.parameters(), lr=lr_manager)
        worker_opt = optim.Adam(worker_net.parameters(), lr=lr_worker)

        # Replay Buffers
        manager_replay = ManagerReplayBuffer()
        worker_replay = WorkerReplayBuffer()

        # Epsilons
        manager_eps = manager_eps_start
        worker_eps = worker_eps_start

        for ep in range(num_episodes):
            m_obs, w_obs = env.reset()
            done = False
            manager_return = 0.0
            worker_return = 0.0

            # The manager is rewarded with the change in manager_reward()
            # over each of its decision intervals.
            old_mgr_r = env.manager_reward()
            sg_id, sg_xyz = manager_select_subgoal(manager_net, m_obs, grid_size, epsilon=manager_eps)
            env.set_subgoal(np.array(sg_xyz))

            while not done:
                manager_s = m_obs

                # Run up to manager_update_freq worker steps under the
                # current subgoal.
                for _ in range(manager_update_freq):
                    if env.done:
                        break
                    w_s = w_obs
                    action_worker = worker_select_action(worker_net, w_s, env.grid_size**3, worker_eps)
                    w_s_next, w_r, w_done = env.worker_step(action_worker)
                    w_obs = w_s_next
                    worker_return += w_r
                    done = w_done

                    # Store worker experience and learn from it.
                    worker_replay.push(w_s, action_worker, w_r, w_s_next, w_done)
                    worker_update(worker_net, worker_opt, worker_replay, batch_size, gamma=worker_gamma)

                    if done:
                        break

                m_obs_next = env._get_manager_obs()
                new_mgr_r = env.manager_reward()
                delta_r = new_mgr_r - old_mgr_r
                old_mgr_r = new_mgr_r
                manager_return += delta_r

                m_done = env.manager_done()
                done = m_done

                manager_replay.push(manager_s, sg_id, delta_r, m_obs_next, m_done)
                manager_update(manager_net, manager_opt, manager_replay, batch_size, gamma=manager_gamma)

                m_obs = m_obs_next

                if not done:
                    sg_id, sg_xyz = manager_select_subgoal(manager_net, m_obs, grid_size, epsilon=manager_eps)
                    env.set_subgoal(np.array(sg_xyz))

            # Episode end: decay both epsilons, floored at their end values.
            manager_eps = max(manager_eps_end, manager_eps * manager_eps_decay)
            worker_eps = max(worker_eps_end, worker_eps * worker_eps_decay)

            total_return = manager_return + worker_return
            frac_removed = env.fraction_outside_removed()
            total_removed = env.total_outside_removed()
            moves_before_end = env.worker_move_count

            # Log to TensorBoard
            writer.add_scalar("Episode/ManagerReturn", manager_return, ep)
            writer.add_scalar("Episode/WorkerReturn", worker_return, ep)
            writer.add_scalar("Episode/TotalReturn", total_return, ep)
            writer.add_scalar("Episode/OutsideRemovedFraction", frac_removed, ep)
            writer.add_scalar("Episode/OutsideRemovedCount", total_removed, ep)
            writer.add_scalar("Episode/WorkerMovesBeforeEnd", moves_before_end, ep)
            writer.add_scalar("Epsilon/Manager", manager_eps, ep)
            writer.add_scalar("Epsilon/Worker", worker_eps, ep)

            print(f"[Ep {ep+1}/{num_episodes}] "
                  f"ManagerR={manager_return:.2f}, WorkerR={worker_return:.2f}, "
                  f"Total={total_return:.2f}, Moves={moves_before_end}, "
                  f"FracRemoved={frac_removed:.3f}, RemovedCount={total_removed}, "
                  f"EpsM={manager_eps:.2f}, EpsW={worker_eps:.2f}")
    finally:
        # Close the writer even when training is interrupted (e.g. Ctrl+C in
        # a notebook) or raises, so buffered events reach disk and the file
        # handle is released.
        writer.close()

    print("Training complete.")
    return manager_net, worker_net

def main():
    """Script entry point: run a long training session on an 8x8x8 grid."""
    episodes = 100000
    train_feudal_milling(num_episodes=episodes, grid_size=8)
    print("Done.")


# Only start training when executed as a script, not when imported.
if __name__ == "__main__":
    main()

[Ep 1/100000] ManagerR=0.25, WorkerR=-0.56, Total=-0.31, Moves=2, FracRemoved=0.025, RemovedCount=12, EpsM=0.99, EpsW=0.99
[Ep 2/100000] ManagerR=0.92, WorkerR=19.51, Total=20.42, Moves=14, FracRemoved=0.092, RemovedCount=44, EpsM=0.99, EpsW=0.99
[Ep 3/100000] ManagerR=0.17, WorkerR=-3.62, Total=-3.45, Moves=2, FracRemoved=0.017, RemovedCount=8, EpsM=0.99, EpsW=0.99
[Ep 4/100000] ManagerR=0.31, WorkerR=1.87, Total=2.18, Moves=3, FracRemoved=0.031, RemovedCount=15, EpsM=0.98, EpsW=0.98
[Ep 5/100000] ManagerR=0.17, WorkerR=-3.53, Total=-3.37, Moves=1, FracRemoved=0.017, RemovedCount=8, EpsM=0.98, EpsW=0.98
[Ep 6/100000] ManagerR=0.13, WorkerR=-5.31, Total=-5.19, Moves=1, FracRemoved=0.013, RemovedCount=6, EpsM=0.97, EpsW=0.97
[Ep 7/100000] ManagerR=0.13, WorkerR=-6.33, Total=-6.20, Moves=1, FracRemoved=0.013, RemovedCount=5, EpsM=0.97, EpsW=0.97
[Ep 8/100000] ManagerR=0.63, WorkerR=14.23, Total=14.86, Moves=5, FracRemoved=0.063, RemovedCount=30, EpsM=0.96, EpsW=0.96
[Ep 9/100000] Manager

KeyboardInterrupt: 

In [29]:
!tensorboard --logdir runs

TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C
