# CNN-Based Architecture

## Input

3 channels: stock, shape, router  
We treat each 10×10 grid (stock, shape, and router location) as a separate channel, so the CNN can learn spatial relationships among the material, the protected shape, and the agent’s position.  
The channels are stacked along the channel dimension (like RGB), giving: 1. a stock channel (1 where uncut stock remains, 0 otherwise), 2. a shape channel (1 for the protected shape, 0 otherwise), and 3. a router channel (1 at the router's grid cell, 0 otherwise).
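
As a rough illustration of the three-channel encoding, the sketch below builds one observation with NumPy. The example state (a 4×4 protected square and a router at the origin) is hypothetical and only serves to show the layout.

```python
import numpy as np

GRID = 10  # grid side length used in this section

# Hypothetical example state: full stock, a 4x4 protected square, router at (0, 0).
stock = np.ones((GRID, GRID), dtype=np.float32)
shape = np.zeros((GRID, GRID), dtype=np.float32)
shape[3:7, 3:7] = 1.0
router_pos = (0, 0)

# One-hot router channel: 1 at the router's cell, 0 elsewhere.
router = np.zeros((GRID, GRID), dtype=np.float32)
router[router_pos] = 1.0

# Stack along a new leading channel axis -> (channels, height, width).
obs = np.stack([stock, shape, router], axis=0)
print(obs.shape)  # (3, 10, 10)
```

The channels-first layout matches what PyTorch convolution layers expect once a batch dimension is added in front.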

## Convolutional Layer 1

32 filters, kernel size 3×3, stride = 1, ReLU activation  
Why 32 filters? Enough capacity to capture local patterns (edges, corners, small features) without being too large for a small 10×10 input.  
Why a 3×3 kernel? This is a classic, effective “local receptive field” size that balances fine detail with efficient training.  
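
The effect of these settings on the feature-map size and parameter count can be checked with a little arithmetic. The text does not state a padding value; the sketch assumes `padding=1`, since that keeps the 10×10 resolution that the later 10×10 to 5×5 pooling step implies. The helper names are illustrative only.

```python
def conv_output_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1


def conv2d_params(in_channels: int, out_channels: int, kernel: int) -> int:
    """Weights plus biases of a 2D convolution with a square kernel."""
    return out_channels * in_channels * kernel * kernel + out_channels


no_pad = conv_output_size(10, kernel=3, stride=1, padding=0)
same_pad = conv_output_size(10, kernel=3, stride=1, padding=1)  # assumed padding
layer1_params = conv2d_params(in_channels=3, out_channels=32, kernel=3)

print(no_pad, same_pad)  # 8 10
print(layer1_params)     # 896
```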

## Convolutional Layer 2
64 filters, kernel size 3×3, stride = 1, ReLU activation  
Why another layer? Stacking conv layers lets the network learn higher-order features (combinations of the lower-layer edge/texture patterns), which is essential for distinguishing shape boundaries from outside stock.  
Why 64 filters? Doubling filters in the second layer is a common practice, giving more representational power for more complex patterns.  
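
One way to see what stacking buys is to track the receptive field, i.e. how many input cells along one axis influence a single output unit. The sketch below uses the standard recurrence for stacked layers; the function name is illustrative.

```python
def receptive_field(layers):
    """Receptive field (per axis) of one output unit.

    layers: list of (kernel, stride) pairs, ordered from input to output.
    """
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump  # each layer widens the view by (k - 1) input steps
        jump *= stride             # strides spread later layers' taps further apart
    return rf


one_conv = receptive_field([(3, 1)])
two_convs = receptive_field([(3, 1), (3, 1)])
print(one_conv, two_convs)  # 3 5
```

With two 3×3 layers each unit sees a 5×5 patch, a quarter of the area of a 10×10 grid, while using fewer weights per channel pair than a single 5×5 kernel.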

## Pooling Layer
2×2 max pooling  
Reduces spatial resolution by half (10×10 to 5×5), which lowers compute cost, enlarges the effective receptive field of later layers, and makes the network less sensitive to small details (which can help against overfitting).  
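
To make the pooling step concrete, here is a minimal NumPy sketch of 2×2 max pooling with stride 2 on a single feature map. It assumes even height and width, which holds for a 10×10 map.

```python
import numpy as np


def max_pool_2x2(fmap: np.ndarray) -> np.ndarray:
    """2x2 max pooling with stride 2 on an (H, W) map with even H and W."""
    h, w = fmap.shape
    # Split each axis into (blocks, 2) and take the max inside every 2x2 block.
    return fmap.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))


fmap = np.arange(100, dtype=np.float32).reshape(10, 10)
pooled = max_pool_2x2(fmap)
print(pooled.shape)  # (5, 5)
print(pooled[0, 0])  # 11.0, the max of [[0, 1], [10, 11]]
```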

## Flatten

Why? After extracting spatial features, flattening converts the feature map into a 1D vector for a fully connected layer. This merges all local feature activations into a representation for decision-making.  

## Fully Connected Layer
128 units, ReLU activation  
Why 128? It’s enough capacity to combine and interpret the learned spatial features without being excessively large for a 10×10 grid.  
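
The size of the flattened vector and of this layer follows directly from the earlier choices. The arithmetic below assumes the 64-channel, 5×5 feature map described above.

```python
channels, height, width = 64, 5, 5   # output of conv layer 2 after 2x2 pooling
fc_units = 128

flat_features = channels * height * width            # length of the flattened vector
fc_params = flat_features * fc_units + fc_units      # weights plus biases

print(flat_features)  # 1600
print(fc_params)      # 204928
```

Under these assumptions the dense layer holds most of the network's parameters, which is one reason to keep it at a moderate width.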

## Output Layer
Config: size = number of actions  
Why? In a DQN, this final layer directly outputs the Q-values for each discrete action. We only need as many outputs as there are possible moves.  
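
Turning the output vector into a move is typically done with an epsilon-greedy rule, as in the training loop further down. The sketch below is a plain-Python version; the four Q-values are made up for illustration, and four actions is an assumption for a 2D grid.

```python
import random


def select_action(q_values, epsilon, rng=random):
    """Epsilon-greedy choice over a list of Q-values, one per discrete action."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore: uniform random action
    # exploit: index of the largest Q-value
    return max(range(len(q_values)), key=lambda a: q_values[a])


q_values = [0.2, 1.5, -0.3, 0.9]  # hypothetical Q-values for 4 moves
greedy = select_action(q_values, epsilon=0.0)
print(greedy)  # 1
```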

By combining two convolutional layers (each capturing progressively higher-level spatial features) with a final dense layer (for integrating those features into action values), we get a compact but effective architecture. The small 10×10 input size means a deeper or wider network might overfit or be computationally wasteful, so 2 conv layers + 1 dense layer is a balanced choice for this milling task.  
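
Put together, the layers described in this section could look like the following PyTorch sketch. The class name `MillingCNN2D`, the choice of `padding=1`, and the default of four actions are assumptions made for illustration; the 3D implementation below uses different sizes.

```python
import torch
import torch.nn as nn


class MillingCNN2D(nn.Module):
    """Sketch of the 2D architecture: conv(32) -> conv(64) -> pool -> fc(128) -> Q-values."""

    def __init__(self, grid_size=10, n_channels=3, n_actions=4):
        super().__init__()
        self.features = nn.Sequential(
            # padding=1 is assumed so the 3x3 convs keep the 10x10 resolution
            nn.Conv2d(n_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 10x10 -> 5x5
        )
        pooled = grid_size // 2
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * pooled * pooled, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),  # one Q-value per discrete action
        )

    def forward(self, x):
        # x: [batch, 3, grid_size, grid_size]
        return self.head(self.features(x))


model = MillingCNN2D()
q = model(torch.zeros(2, 3, 10, 10))
print(tuple(q.shape))  # (2, 4)
```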

In [59]:
!pip3 install tensorboard

Defaulting to user installation because normal site-packages is not writeable
Collecting tensorboard
  Downloading tensorboard-2.19.0-py3-none-any.whl (5.5 MB)
Collecting markdown>=2.6.8
  Downloading Markdown-3.7-py3-none-any.whl (106 kB)
Collecting tensorboard-data-server<0.8.0,>=0.7.0
  Downloading tensorboard_data_server-0.7.2-py3-none-any.whl (2.4 kB)
Collecting protobuf!=4.24.0,>=3.19.6
  Downloading protobuf-6.30.2-cp39-abi3-macosx_10_9_universal2.whl (417 kB)
Collecting werkzeug>=1.0.1
  Downloading werkzeug-3.1.3-py3-none-any.whl (224 kB)
Collecting grpcio>=1.48.2
  Downloading grpcio-1.71.0-cp39-cp39-macosx_10_14_universal2.whl (11.3 MB)

In [71]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt
import imageio
from IPython.display import display, Image
from collections import deque
from torch.utils.tensorboard import SummaryWriter

# ----------------------------
# 1) 3D ENVIRONMENT DEFINITION
# ----------------------------

class Milling3DEnv:
    """
    A 3D milling environment on an N x N x N grid:
      - stock[x,y,z] = 1 means material present, 0 means removed
      - shape[x,y,z] = 1 means protected shape
      - router_pos = (rx, ry, rz) within [0, N-1]
      - 6 discrete actions: move ±1 in x, y, or z
      - always-on cutter
      - protected shape is a sphere of random radius centered in the grid
    """
    def __init__(
        self,
        grid_size=8,
        max_steps=200,
        reward_outside_cut=1.0,
        penalty_step=0.1,
        penalty_cut_shape=10.0,
        success_bonus=30.0,
        min_radius=2,
        max_radius=3
    ):
        """
        :param grid_size: size of the 3D grid
        :param max_steps: max steps before termination
        :param reward_outside_cut: reward for removing outside stock
        :param penalty_step: penalty each step
        :param penalty_cut_shape: penalty if we cut the shape
        :param success_bonus: bonus if all outside stock is removed
        :param min_radius, max_radius: random sphere radius range
        """
        self.grid_size = grid_size
        self.max_steps = max_steps
        
        # Reward parameters
        self.reward_outside_cut = reward_outside_cut
        self.penalty_step = penalty_step
        self.penalty_cut_shape = penalty_cut_shape
        self.success_bonus = success_bonus
        
        self.min_radius = min_radius
        self.max_radius = max_radius
        
        # 6 actions: 0=+x,1=-x,2=+y,3=-y,4=+z,5=-z
        self.action_space = [0,1,2,3,4,5]
        
        # Will be populated on reset
        self.stock = None
        self.shape = None
        self.router_pos = None
        self.steps_taken = 0
    
    def reset(self):
        """ Initialize the 3D grid and place a sphere of random radius at the center. """
        self.stock = np.ones((self.grid_size, self.grid_size, self.grid_size), dtype=int)
        self.shape = np.zeros((self.grid_size, self.grid_size, self.grid_size), dtype=int)
        
        cx = self.grid_size // 2
        cy = self.grid_size // 2
        cz = self.grid_size // 2
        
        radius = np.random.randint(self.min_radius, self.max_radius+1)
        
        # Mark shape = 1 for sphere region
        for x in range(self.grid_size):
            for y in range(self.grid_size):
                for z in range(self.grid_size):
                    dist_sq = (x-cx)**2 + (y-cy)**2 + (z-cz)**2
                    if dist_sq <= radius**2:
                        self.shape[x,y,z] = 1
        
        # Place router at (0,0,0) for simplicity
        self.router_pos = np.array([0,0,0], dtype=int)
        self.steps_taken = 0
        
        return self._get_observation()
    
    def step(self, action):
        """
        action in {0,1,2,3,4,5} => ±x, ±y, ±z
        router moves 1 voxel in the chosen direction, always-on cutter
        """
        if action == 0:  # +x
            self.router_pos[0] += 1
        elif action == 1:  # -x
            self.router_pos[0] -= 1
        elif action == 2:  # +y
            self.router_pos[1] += 1
        elif action == 3:  # -y
            self.router_pos[1] -= 1
        elif action == 4:  # +z
            self.router_pos[2] += 1
        elif action == 5:  # -z
            self.router_pos[2] -= 1
        
        # Clamp within [0, grid_size-1]
        self.router_pos = np.clip(self.router_pos, 0, self.grid_size-1)
        
        self.steps_taken += 1
        reward = 0.0
        
        rx, ry, rz = self.router_pos
        # If stock present => cut it
        if self.stock[rx, ry, rz] == 1:
            self.stock[rx, ry, rz] = 0
            # Check if it's shape
            if self.shape[rx, ry, rz] == 1:
                # cut shape => fail
                reward -= self.penalty_cut_shape
                done = True
                return self._get_observation(), reward, done, {}
            else:
                # outside => positive reward
                reward += self.reward_outside_cut
        
        # step penalty
        reward -= self.penalty_step
        
        # check if outside stock is all removed
        outside_mask = (self.shape == 0)  # shape=0 => outside
        if np.sum(self.stock[outside_mask]) == 0:
            # success
            reward += self.success_bonus
            done = True
        else:
            done = (self.steps_taken >= self.max_steps)
        
        return self._get_observation(), reward, done, {}
    
    def _get_observation(self):
        """
        Return a dict with:
         - self.stock
         - self.shape
         - self.router_pos
        We'll form 3 channels for a 3D CNN:
          channel0 = stock
          channel1 = shape
          channel2 = router location
        """
        return {
            "stock": self.stock.copy(),
            "shape": self.shape.copy(),
            "router_pos": self.router_pos.copy()
        }
    
    def render(self, azim=45, elev=30):
        """
        Create a 3D scatter plot of all voxels:
        - Blue: shape=1, stock=1  (uncut shape)
        - Red:  shape=1, stock=0  (shape was cut)
        - Gray: shape=0, stock=1  (outside stock present)
        - White: shape=0, stock=0 (outside removed)
        - Black: router
        Nothing is fully opaque; alpha < 1.0 so we can see all points simultaneously.
        """
        from mpl_toolkits.mplot3d import Axes3D
        
        shape_and_stock = (self.shape == 1) & (self.stock == 1)
        shape_cut       = (self.shape == 1) & (self.stock == 0)
        outside_stock   = (self.shape == 0) & (self.stock == 1)
        outside_cut     = (self.shape == 0) & (self.stock == 0)
        
        rx, ry, rz = self.router_pos
        
        fig = plt.figure(figsize=(6,6))
        ax = fig.add_subplot(111, projection='3d')
        
        # 1) Blue: shape & stock
        xs, ys, zs = np.where(shape_and_stock)
        if len(xs) > 0:
            ax.scatter(xs, ys, zs, c='blue', alpha=0.1, s=20, marker='o', depthshade=False, label="Shape Uncut")
        
        # 2) Red: shape cut
        xs, ys, zs = np.where(shape_cut)
        if len(xs) > 0:
            ax.scatter(xs, ys, zs, c='red', alpha=0.1, s=20, marker='o', depthshade=False, label="Shape Cut")
        
        # 3) Gray: outside stock
        xs, ys, zs = np.where(outside_stock)
        if len(xs) > 0:
            ax.scatter(xs, ys, zs, c='gray', alpha=0.1, s=20, marker='o', depthshade=False, label="Outside Stock")
        
        # 4) White: outside removed
        xs, ys, zs = np.where(outside_cut)
        if len(xs) > 0:
            ax.scatter(xs, ys, zs, c='white', alpha=0.1, s=20, marker='o', depthshade=False, label="Outside Removed")
        
        # 5) Router in black
        ax.scatter(rx, ry, rz, c='black', alpha=1.0, s=60, marker='o', depthshade=False, label="Router")
        
        ax.set_xlim(0, self.grid_size)
        ax.set_ylim(0, self.grid_size)
        ax.set_zlim(0, self.grid_size)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.view_init(elev=elev, azim=azim)
        ax.set_title(f"Steps={self.steps_taken}")
        
        # Optional legend (comment out if you don't need it)
        ax.legend(loc='upper right')
        
        fig.canvas.draw()
        # Copy the rendered canvas into an (h, w, 4) uint8 RGBA array.
        # buffer_rgba() is used here because tostring_rgb() is deprecated
        # in recent Matplotlib releases.
        rgba_img = np.asarray(fig.canvas.buffer_rgba()).copy()
        
        # Force a fully opaque alpha channel for the final image
        rgba_img[..., 3] = 255
        
        plt.close(fig)
        return rgba_img

# ----------------------------
# 2) 3D CNN ARCHITECTURE
# ----------------------------

class Milling3DCNN(nn.Module):
    """
    3D CNN that processes a (3, D, D, D) input for stock, shape, router
    and outputs Q-values for 6 actions.
    """
    def __init__(self, grid_size=8, n_channels=3, n_actions=6):
        super(Milling3DCNN, self).__init__()
        
        # We'll do two 3D conv layers, then a pooling
        # For a small 8x8x8 grid, we can keep it modest
        self.conv_net = nn.Sequential(
            nn.Conv3d(n_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2)  # from 8->4 in each dimension
        )
        
        # Then flatten => fully connected
        # 32 filters * 4*4*4 = 32*64 = 2048
        self.fc_net = nn.Sequential(
            nn.Linear(32*(grid_size//2)*(grid_size//2)*(grid_size//2), 256),
            nn.ReLU(),
            nn.Linear(256, n_actions)
        )
    
    def forward(self, x):
        """
        x: shape [batch_size, 3, D, D, D]
        """
        feats = self.conv_net(x)  # => [batch_size, 32, 4, 4, 4] if D=8
        feats = feats.view(feats.size(0), -1)
        out = self.fc_net(feats)
        return out

# ----------------------------
# 3) REPLAY BUFFER & UTIL
# ----------------------------

class ReplayBuffer:
    def __init__(self, capacity=2000):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, transition):
        self.buffer.append(transition)
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        s_arr, a_arr, r_arr, s_next_arr, done_arr = zip(*batch)
        return s_arr, a_arr, r_arr, s_next_arr, done_arr
    
    def __len__(self):
        return len(self.buffer)

def obs_to_3dtensor(obs, device="cpu"):
    """
    Convert {stock, shape, router_pos} into (3, D, D, D).
    """
    stock = obs["stock"].astype(np.float32)
    shape = obs["shape"].astype(np.float32)
    router_map = np.zeros_like(stock, dtype=np.float32)
    rx, ry, rz = obs["router_pos"]
    router_map[rx, ry, rz] = 1.0
    
    # stack channels
    vol_3ch = np.stack([stock, shape, router_map], axis=0)  # (3,D,D,D)
    return torch.tensor(vol_3ch, dtype=torch.float32, device=device)

# ----------------------------
# 4) TRAINING LOOP (DQN)
# ----------------------------

def train_dqn_3d(
    num_episodes=100,
    grid_size=8,
    max_steps=200,
    gamma=0.9,
    lr=1e-3,
    batch_size=32,
    epsilon_start=1.0,
    epsilon_end=0.1,
    epsilon_decay=0.995,
    replay_capacity=5000,
    updates_per_episode=10
):
    env = Milling3DEnv(grid_size=grid_size, max_steps=max_steps)
    device = "cpu"
    
    policy_net = Milling3DCNN(grid_size=grid_size, n_channels=3, n_actions=6).to(device)
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)
    replay_buffer = ReplayBuffer(capacity=replay_capacity)
    
    epsilon = epsilon_start
    
    for ep in range(num_episodes):
        obs = env.reset()
        state_t = obs_to_3dtensor(obs, device=device)
        
        total_reward = 0.0
        
        for step_i in range(env.max_steps):
            # Epsilon-greedy
            if random.random() < epsilon:
                action = random.choice(env.action_space)
            else:
                with torch.no_grad():
                    q_vals = policy_net(state_t.unsqueeze(0))  # [1,6]
                action = q_vals.argmax(dim=1).item()
            
            obs_next, reward, done, _ = env.step(action)
            next_state_t = obs_to_3dtensor(obs_next, device=device)
            
            # Store CPU arrays in replay
            replay_buffer.push((
                state_t.cpu().numpy(),
                action,
                reward,
                next_state_t.cpu().numpy(),
                done
            ))
            
            state_t = next_state_t
            total_reward += reward
            if done:
                break
        
        # After episode, do multiple training updates
        if len(replay_buffer) >= batch_size:
            for _ in range(updates_per_episode):
                s_arr, a_arr, r_arr, s_next_arr, done_arr = replay_buffer.sample(batch_size)
                
                # Stack each tuple of arrays first; building a tensor directly
                # from a tuple of NumPy arrays is very slow.
                s_batch = torch.tensor(np.stack(s_arr), dtype=torch.float32, device=device)
                a_batch = torch.tensor(a_arr, dtype=torch.long, device=device)
                r_batch = torch.tensor(r_arr, dtype=torch.float32, device=device)
                s_next_batch = torch.tensor(np.stack(s_next_arr), dtype=torch.float32, device=device)
                done_batch = torch.tensor(done_arr, dtype=torch.bool, device=device)
                
                # shape of s_batch => [B, 3, D, D, D]
                # forward pass
                q_values = policy_net(s_batch)
                q_chosen = q_values.gather(1, a_batch.unsqueeze(1)).squeeze(1)
                
                with torch.no_grad():
                    q_next = policy_net(s_next_batch)
                    max_q_next, _ = torch.max(q_next, dim=1)
                    max_q_next[done_batch] = 0.0
                target = r_batch + gamma * max_q_next
                
                loss = nn.MSELoss()(q_chosen, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
        # Decay epsilon
        epsilon = max(epsilon_end, epsilon * epsilon_decay)
        
        print(f"Episode {ep+1}/{num_episodes}, total_reward={total_reward:.2f}, eps={epsilon:.2f}")
    
    print("3D DQN training complete.")
    return policy_net

# ----------------------------
# 5) DEMO & GIF CREATION
# ----------------------------

def create_3d_demo_gif(policy_net, grid_size=8, max_steps=300, output_gif="milling_3d_demo.gif"):
    """
    Run a single greedy episode in 3D with the given (ideally well-trained)
    policy, capturing one frame per step with a fixed camera. The router's
    current position is drawn as a black marker in each frame.
    """
    env = Milling3DEnv(grid_size=grid_size, max_steps=max_steps)
    obs = env.reset()
    
    frames = []
    device = "cpu"
    
    for step_i in range(env.max_steps):
        # Render the current frame (no camera rotation)
        rgba = env.render()  
        frames.append(rgba)
        
        # Choose action from the trained policy
        state_t = obs_to_3dtensor(obs, device=device)
        with torch.no_grad():
            q_vals = policy_net(state_t.unsqueeze(0))  # shape [1,6]
            action = q_vals.argmax(dim=1).item()
        
        obs_next, reward, done, _ = env.step(action)
        obs = obs_next
        
        if done:
            # Capture one last frame after termination
            frames.append(env.render())
            break
    
    imageio.mimsave(output_gif, frames, fps=2)
    print(f"Demo GIF saved to {output_gif}")



In [None]:
# ----------------------------
# 6) EXAMPLE USAGE
# ----------------------------
model = train_dqn_3d(
    num_episodes=1000,    # Increase for better policies
    grid_size=8,
    max_steps=200,
    gamma=0.9,
    lr=1e-3,
    batch_size=16,      # smaller batch if your CPU is slow
    epsilon_start=1.0,
    epsilon_end=0.1,
    epsilon_decay=0.98,
    replay_capacity=5000,
    updates_per_episode=10
)

Episode 1/10, total_reward=8.30, eps=0.98
Episode 2/10, total_reward=17.80, eps=0.96
Episode 3/10, total_reward=-1.20, eps=0.94
Episode 4/10, total_reward=7.90, eps=0.92
Episode 5/10, total_reward=-3.10, eps=0.90
Episode 6/10, total_reward=53.00, eps=0.89
Episode 7/10, total_reward=-2.90, eps=0.87
Episode 8/10, total_reward=16.10, eps=0.85
Episode 9/10, total_reward=33.40, eps=0.83
Episode 10/10, total_reward=25.00, eps=0.82
3D DQN training complete.


In [62]:
!tensorboard --logdir runs

TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C


In [None]:
create_3d_demo_gif(model, grid_size=8, max_steps=10, output_gif="milling_3d_demo.gif")

Demo GIF saved to milling_3d_demo.gif
