
# PPO + Transformer for 3D Box Packing (CPU‑only, Colab‑ready)

This notebook implements a **Transformer‑based policy** trained with **PPO** for a simplified 3D bin‑packing environment.  
It is designed to run **on CPU (no CUDA required)** and is modularized into clear functions.

**Key design choices (PPO‑focused):**
- Policy predicts **scores over candidate placements** (Categorical over N candidates).  
- Proper **PPO returns/advantages (GAE)** with bootstrap value.  
- **Reward scaling & clipping** for stable learning.  
- **Pre‑LN Transformer**, gradient clipping, and entropy regularization.  
- CPU‑only — safe for Colab without GPU setup.


In [1]:

# !pip install numpy pyyaml matplotlib imageio pandas tensorboard

import math
import random
import numpy as np
import yaml
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import imageio
import pandas as pd
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

# Everything runs on the CPU; later cells pass this `device` to torch.tensor(...)
# and to policy.to(...).
device = torch.device("cpu")
print("Using device:", device)


Using device: cpu


In [2]:

# Central hyper-parameter dictionary read by the cells below.
CONFIG = {
    # Passed to set_seed(...) once at start-up.
    "seed": 1234,
    # plot_interval: episodes between progress prints/plots in train();
    # save_dir: directory that receives the PNG reports and GIFs.
    "log": {"plot_interval": 20, "save_dir": "/mnt/data/ppo_tf_logs"},
    # Unpacked into EnvConfig(**CONFIG["env"]).
    # NOTE(review): thin_threshold, max_weight and zone_limits are not read by any
    # code visible in this file.
    "env": {
        "bin_size": [25, 32, 50],
        "num_boxes": 60,
        "min_size": 4,
        "max_size": 12,
        "thin_threshold": 6,
        "max_weight": 500000,
        "zone_limits": [50000, 60000, 100000, 50000]
    },
    # Unpacked into CandConfig(**CONFIG["candidates"]).
    # NOTE(review): floor_first is not read by generate_candidates.
    "candidates": {"max_N": 32, "grid_step": 3, "floor_first": True},
    # Weights of the reward terms summed in PackingEnv.step; "penalty" is the
    # magnitude of the negative reward for an invalid placement.
    # NOTE(review): w_layer, w_small and w_unload are not used by PackingEnv.step.
    "reward": {
        "w_bbox": 0.8, "w_contact": 1.0, "w_volume": 3.0,
        "w_wall": 0.3, "w_ems": 0.3, "w_height": 0.1,
        "w_layer": 0.5, "w_small": 0.3, "w_ycenter": 0.2,
        "w_unload": 0.2, "penalty": 0.5
    },
    # The raw reward is multiplied by "scale" and then clamped to [-clip, clip].
    "reward_norm": {"scale": 5.0, "clip": 5.0},
    # Unpacked into PPOCfg(**CONFIG["ppo"]).
    # NOTE(review): "clip" is stored on PPOCfg but PPOAgent.update never reads it.
    "ppo": {
        "gamma": 0.995, "lam": 0.95, "clip": 0.15,
        "value_coeff": 0.5, "entropy_coeff": 0.01,
        "lr": 3e-4, "max_grad_norm": 1.0
    },
    # episodes: number of training episodes; save_gif_every: episodes between GIF dumps.
    "train": {"episodes": 200, "save_gif_every": 50}
}
print("CONFIG loaded.")


CONFIG loaded.


In [4]:
import os

def set_seed(seed:int=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with the same value."""
    # Same order as before: random -> numpy -> torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def ensure_dir(path:str):
    """Create directory `path` (including parents) if it is missing and return `path`."""
    # exist_ok=True makes the call a no-op when the directory already exists.
    os.makedirs(path, exist_ok=True)
    return path

# Seed the RNGs and create the output directory before any later cell writes to it.
set_seed(CONFIG["seed"])
_ = ensure_dir(CONFIG["log"]["save_dir"])


## Environment (3D bin packing, simplified)


In [5]:

from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class EnvConfig:
    """Settings of one packing problem; train() builds it from CONFIG["env"]."""
    # Bin extents; unpacked as (Lx, Ly, Lz) by generate_candidates.
    bin_size: Tuple[int,int,int]
    # Number of boxes drawn per episode in PackingEnv.reset.
    num_boxes: int
    # Inclusive lower / upper bound of every random box edge (see generate_boxes).
    min_size: int
    max_size: int
    # NOTE(review): the three fields below are not read by any code visible in this file.
    thin_threshold: int
    max_weight: float
    zone_limits: List[float]

@dataclass
class CandConfig:
    """Settings for generate_candidates; train() builds it from CONFIG["candidates"]."""
    # Upper bound on the number of candidates returned per box.
    max_N: int
    # Stride of the anchor grid (generate_candidates clamps it to at least 1).
    grid_step: int
    # NOTE(review): not read by generate_candidates as visible in this file.
    floor_first: bool

def sort_boxes_by_priority(boxes):
    """Return a new list of boxes ordered by descending face area.

    The sort key is ``(has_small_edge, -largest_face, -middle_face, -smallest_face)``
    where ``has_small_edge`` is True when the shortest edge is below
    ``CONFIG["env"]["min_size"]`` (such boxes sort last).

    NOTE(review): boxes drawn with that same ``min_size`` as lower bound can never
    have a shorter edge, so the flag is False for all of them; perhaps
    ``thin_threshold`` was intended -- confirm.
    """
    def priority(b):
        face_areas = [b[0]*b[1], b[1]*b[2], b[0]*b[2]]
        face_areas.sort(reverse=True)
        largest, middle, smallest = face_areas
        has_small_edge = min(b) < CONFIG["env"]["min_size"]
        return (has_small_edge, -largest, -middle, -smallest)

    return sorted(boxes, key=priority)

def generate_boxes(n, min_s, max_s):
    """Draw `n` random boxes with integer edges in [min_s, max_s] and sort them by priority."""
    raw = []
    for _ in range(n):
        # randint's upper bound is exclusive, hence the +1.
        dims = np.random.randint(min_s, max_s+1, size=3)
        raw.append(tuple(dims))
    return sort_boxes_by_priority(raw)

def get_orientations(box):
    """Return the six axis-aligned orientations of a 3-edge box, as a list of tuples.

    The orientations are the six permutations of the edges in a fixed order;
    duplicates are not removed, so a cube yields six identical tuples.
    """
    l, w, h = box  # unpacking also enforces exactly three edges
    edges = (l, w, h)
    # Index permutations in the order the rest of the file relies on.
    axis_orders = ((0, 1, 2), (1, 0, 2), (2, 0, 1), (0, 2, 1), (1, 2, 0), (2, 1, 0))
    return [tuple(edges[i] for i in order) for order in axis_orders]

def is_collision(pos, size, placed, bin_size):
    """Return True when a box at `pos` with extents `size` cannot occupy that spot.

    A placement is rejected when it leaves the bin on any side or when its
    interior intersects an already placed box. Boxes that merely touch
    (shared face, edge or corner) do not collide.

    Args:
        pos: (x, y, z) of the box corner closest to the origin.
        size: (dx, dy, dz) extents of the box in its chosen orientation.
        placed: iterable of ((x, y, z), (dx, dy, dz)) pairs already in the bin.
        bin_size: (Lx, Ly, Lz) extents of the bin.

    Returns:
        bool: True if the placement is out of bounds or overlaps a placed box.
    """
    x, y, z = pos
    dx, dy, dz = size

    # Lower bounds: a negative origin is outside the bin just like an overshoot
    # on the far side (the original check only covered the far side).
    if x < 0 or y < 0 or z < 0:
        return True
    if x+dx > bin_size[0] or y+dy > bin_size[1] or z+dz > bin_size[2]:
        return True

    for (px, py, pz), (pd, pw, ph) in placed:
        # Two axis-aligned boxes are disjoint iff they are separated on some axis.
        separated = (x+dx <= px or x >= px+pd or
                     y+dy <= py or y >= py+pw or
                     z+dz <= pz or z >= pz+ph)
        if not separated:
            return True
    return False

def is_stable(pos, size, placed):
    """Return whether a box at `pos` with extents `size` is sufficiently supported.

    A box resting on the floor (z == 0) is always stable. Otherwise the
    footprint overlap with every placed box whose top face lies exactly at the
    box's bottom height is summed, and the box is stable when that area is at
    least half of its own footprint.
    """
    x, y, z = pos
    dx, dy, dz = size
    if z == 0:
        return True

    supported_area = 0
    for (bx, by, bz), (bdx, bdy, bdz) in placed:
        if bz + bdz != z:
            continue  # top face of this box is not at our bottom height
        overlap_x = max(0, min(x+dx, bx+bdx) - max(x, bx))
        overlap_y = max(0, min(y+dy, by+bdy) - max(y, by))
        supported_area += overlap_x*overlap_y
    return supported_area >= 0.5*dx*dy

def get_ems(placed, bin_size):
    """Return candidate anchor points derived from the placed boxes.

    For every placed box the three corners reached by moving its origin along
    one axis by the box extent are collected; points outside the half-open bin
    range [0, L) on any axis are dropped. With nothing placed the only point
    is the origin. Despite the name, the result is a list of (x, y, z) points,
    not spaces with an extent, and its order is that of a set.
    """
    if not placed:
        return [(0,0,0)]

    corners = set()
    for (x, y, z), (dx, dy, dz) in placed:
        for pt in ((x+dx, y+0, z+0), (x+0, y+dy, z+0), (x+0, y+0, z+dz)):
            inside = all(0 <= coord < limit for coord, limit in zip(pt, bin_size[:3]))
            if inside:
                corners.add(pt)
    return list(corners)

def overlap_with_ems(pos, size, ems):
    """Sum, over all `ems` entries, the clamped per-axis overlap product with the box.

    NOTE(review): each entry of `ems` is a single point (ex, ey, ez) used as
    both ends of an interval, so every per-axis term is
    ``min(lo + length, e) - max(lo, e)``, which is never positive. After
    clamping at 0 the function therefore always returns 0 for such input.
    Callers that subtract this value or use it as a reward term get no signal
    from it; a real overlap would need each EMS to carry an extent -- confirm
    the intended definition.
    """
    x, y, z = pos
    dx, dy, dz = size

    def clamped_span(lo, length, e):
        # Interval [lo, lo+length] intersected with the degenerate interval [e, e].
        return max(0, min(lo+length, e) - max(lo, e))

    total = 0
    for ex, ey, ez in ems:
        total += clamped_span(x, dx, ex)*clamped_span(y, dy, ey)*clamped_span(z, dz, ez)
    return total

def generate_candidates(bin_size, placed, box, cand_cfg: CandConfig):
    """Enumerate feasible (position, orientation) placements for `box`.

    Args:
        bin_size: (Lx, Ly, Lz) extents of the bin.
        placed: list of ((x, y, z), (dx, dy, dz)) boxes already in the bin.
        box: the three edge lengths of the box to place.
        cand_cfg: supplies `grid_step` (anchor stride) and `max_N` (result cap).
            `floor_first` is not read here.

    Returns:
        (valid, feats): `valid` is a list of (pos, rot, wasted) tuples sorted by
        ascending `wasted` and truncated to `max_N`; `feats` is a float32 array
        with one row [x, y, z, dx, dy, dz] per entry of `valid`.
        The list is never empty: see the fallbacks below.
    """
    Lx,Ly,Lz = bin_size
    step = max(1, cand_cfg.grid_step)
    anchors=set()
    # Regular grid of anchors on the floor (z = 0).
    for x in range(0, Lx, step):
        for y in range(0, Ly, step):
            anchors.add((x,y,0))
    # Anchors along the y = 0 and x = 0 walls at several heights; the z stride is
    # Lz // (Lz // step), clamped to at least 1.
    for z in range(0, Lz, max(1, Lz//max(1, (Lz//step)))):
        for x in range(0, Lx, step):
            anchors.add((x,0,z))
        for y in range(0, Ly, step):
            anchors.add((0,y,z))
    # Corner points of the placed boxes are added as extra anchors.
    ems = get_ems(placed, bin_size)
    anchors |= set(ems)

    valid=[]
    # Iterating a set: the order of equally-ranked candidates follows set order.
    for a in anchors:
        for rot in get_orientations(box):
            if is_collision(a, rot, placed, bin_size) or not is_stable(a, rot, placed):
                continue
            # "Wasted" volume: box volume minus its overlap with the EMS points.
            wasted = np.prod(rot) - overlap_with_ems(a, rot, ems)
            valid.append((a, rot, wasted))
    if not valid:
        # Fallback 1: up to 10 random floor positions. The sampling range uses the
        # un-rotated box[0] / box[1]; the first orientation that fits is kept.
        for _ in range(10):
            rx = np.random.randint(0, max(1, Lx-box[0]))
            ry = np.random.randint(0, max(1, Ly-box[1]))
            pos = (rx, ry, 0)
            for rot in get_orientations(box):
                if not is_collision(pos, rot, placed, bin_size) and is_stable(pos, rot, placed):
                    valid.append((pos, rot, np.prod(rot)))
                    break
        if not valid:
            # Fallback 2: a single candidate at the origin that is NOT validated
            # against collisions or stability (marked with wasted = inf).
            valid=[((0,0,0), get_orientations(box)[0], float('inf'))]
    # Lowest wasted volume first, then cap the list length.
    valid.sort(key=lambda t: t[2])
    valid = valid[:cand_cfg.max_N]
    feats = np.array([[*p, *r] for p,r,_ in valid], dtype=np.float32)
    return valid, feats

def compute_reward_components(pos, size, placed, bin_size):
    """Compute the individual shaping terms for placing a box at `pos` with extents `size`.

    `placed` holds the boxes already in the bin (PackingEnv.step passes the list
    without the box being scored). Note that the "bbox" and "height" terms are
    computed from `placed` only and do not depend on `pos` / `size`.

    Returns:
        dict with keys "bbox", "contact", "wall", "height", "ems", "ycenter".
    """
    # --- bbox: how tightly the already placed boxes fill their bounding box ---
    if placed:
        x_ends = [e for (x, _, _), (dx, _, _) in placed for e in (x, x+dx)]
        y_ends = [e for (_, y, _), (_, dy, _) in placed for e in (y, y+dy)]
        z_ends = [e for (_, _, z), (_, _, dz) in placed for e in (z, z+dz)]
        hull_volume = ((max(x_ends)-min(x_ends)) *
                       (max(y_ends)-min(y_ends)) *
                       (max(z_ends)-min(z_ends)))
        filled_volume = sum(bw*bd*bh for _, (bw, bd, bh) in placed)
        bbox_score = 1.0 - (hull_volume - filled_volume)/float(np.prod(bin_size))
    else:
        bbox_score = 1.0

    x0, y0, z0 = pos
    dx, dy, dz = size

    # --- contact: fraction of the footprint resting on boxes directly below ---
    if z0 == 0:
        contact = 1.0
    else:
        contact = 0.0
        for (bx, by, bz), (bw, bd, bh) in placed:
            if bz+bh != z0:
                continue
            span_x = max(0, min(x0+dx, bx+bw) - max(x0, bx))
            span_y = max(0, min(y0+dy, by+bd) - max(y0, by))
            contact += (span_x*span_y)/(dx*dy+1e-6)
        contact = min(1.0, contact)

    # --- wall: share of the three pairwise face areas lying on x=0, y=0 or z=0 ---
    wall_area = 0.0
    for touches, face_area in ((x0 == 0, dy*dz), (y0 == 0, dx*dz), (z0 == 0, dx*dy)):
        if touches:
            wall_area += face_area
    wall_score = wall_area/((dx*dy + dy*dz + dx*dz)+1e-6)

    # --- height: one minus the mean top height of the placed boxes (relative to Lz) ---
    if placed:
        mean_top = np.mean([bz+bh for (_, _, bz), (_, _, bh) in placed])
        height_score = 1.0 - mean_top/bin_size[2]
    else:
        height_score = 1.0

    # --- ems: overlap with the EMS points, relative to the box volume ---
    ems_points = get_ems(placed, bin_size)
    ems_overlap = overlap_with_ems(pos, size, ems_points)/float(np.prod(size)+1e-6)

    # --- ycenter: larger when the box centre is closer to y = 0 ---
    y_score = 1.0 - ((y0 + dy/2)/bin_size[1])

    return {"bbox": bbox_score, "contact": contact, "wall": wall_score,
            "height": height_score, "ems": ems_overlap, "ycenter": y_score}

class PackingEnv:
    """Sequential 3D bin-packing environment.

    Boxes are presented one at a time (largest volume first). Each `step`
    receives a (position, orientation) pair for the current box, places it if
    the placement is collision-free and stable, and returns a shaped reward.

    Attributes:
        cfg: the EnvConfig the environment was built with.
        cand_cfg: the CandConfig used by `valid_candidates`.
        bin_size: bin extents as a tuple.
        placed: list of (pos, rot) pairs successfully placed this episode.
        index: index into `boxes` of the box to place next.
        used_volume: total volume of the placed boxes.
        boxes: this episode's boxes, sorted by descending volume.
    """

    def __init__(self, cfg:EnvConfig, cand_cfg:CandConfig):
        self.cfg = cfg
        self.cand_cfg = cand_cfg
        self.bin_size = tuple(cfg.bin_size)
        self.reset()

    def reset(self):
        """Start a new episode with freshly drawn boxes.

        Returns:
            The first box to place, or None when the episode has no boxes.
        """
        self.placed: List[Tuple[Tuple[int,int,int], Tuple[int,int,int]]] = []
        self.index=0
        self.used_volume=0.0
        boxes = generate_boxes(self.cfg.num_boxes, self.cfg.min_size, self.cfg.max_size)
        # Stable sort: boxes of equal volume keep the priority order from generate_boxes.
        self.boxes = sorted(boxes, key=lambda b: b[0]*b[1]*b[2], reverse=True)
        # Guard against num_boxes == 0 instead of raising IndexError.
        return self.boxes[0] if self.boxes else None

    def _transition(self, reward):
        """Build the (next_box, reward, done) triple for the current `index`."""
        if self.index >= len(self.boxes):
            return None, reward, True
        return self.boxes[self.index], reward, False

    def step(self, action):
        """Try to place the current box.

        Args:
            action: (pos, rot) -- origin corner and oriented extents of the box.

        Returns:
            (next_box, reward, done). An invalid placement (collision, out of
            bounds or unstable) skips the box and yields
            ``-CONFIG["reward"]["penalty"]``. A valid placement yields the
            weighted sum of the reward components plus the volume-utilisation
            term, multiplied by ``CONFIG["reward_norm"]["scale"]`` and clamped
            to ``[-clip, clip]``. After the last box, `next_box` is None and
            `done` is True.
        """
        if self.index >= len(self.boxes): return None, 0.0, True
        pos, rot = action
        if is_collision(pos, rot, self.placed, self.bin_size) or not is_stable(pos, rot, self.placed):
            # The box is dropped, not retried: the episode moves on to the next one.
            self.index += 1
            return self._transition(-CONFIG["reward"]["penalty"])

        self.placed.append((pos, rot))
        self.index += 1
        self.used_volume += float(np.prod(rot))

        # Components are evaluated against the bin state *before* this placement.
        comps = compute_reward_components(pos, rot, self.placed[:-1], self.bin_size)
        vol_ratio = self.used_volume/float(np.prod(self.bin_size))
        w = CONFIG["reward"]
        raw = (w["w_bbox"]*comps["bbox"] +
               w["w_contact"]*comps["contact"] +
               w["w_wall"]*comps["wall"] +
               w["w_height"]*comps["height"] +
               w["w_ems"]*comps["ems"] +
               w["w_ycenter"]*comps["ycenter"] +
               w["w_volume"]*vol_ratio)
        scale = CONFIG["reward_norm"]["scale"]
        clipv = CONFIG["reward_norm"]["clip"]
        r = max(-clipv, min(clipv, raw*scale))

        return self._transition(r)

    def valid_candidates(self, box):
        """Return generate_candidates(...) for `box` against the current bin state.

        Uses the CandConfig handed to the constructor (previously a new one was
        rebuilt from the global CONFIG on every call, ignoring `self.cand_cfg`).
        """
        return generate_candidates(self.bin_size, self.placed, box, self.cand_cfg)


In [6]:

def draw_bin_with_boxes(bin_size, placed, gif_path):
    """Render the packing as an animated GIF, adding one box per frame.

    Boxes are drawn in order of ascending (z, y, x) origin. Frame i shows the
    first i boxes. When `placed` is empty a single frame of the empty bin is
    written, so the frame list handed to imageio is never empty.

    Args:
        bin_size: (Lx, Ly, Lz) extents used as axis limits.
        placed: iterable of ((x, y, z), (dx, dy, dz)) boxes.
        gif_path: destination file path.

    Returns:
        `gif_path`, unchanged.
    """
    # Corner indices (into the 8-vertex list below) of the six cuboid faces.
    face_corner_ids = [[0,1,2,3],[4,5,6,7],[0,1,5,4],[2,3,7,6],[1,2,6,5],[0,3,7,4]]

    placed_sorted = sorted(placed, key=lambda b: (b[0][2], b[0][1], b[0][0]))
    # Start at 0 boxes only when there is nothing to draw (-> one empty frame).
    first_count = 1 if placed_sorted else 0

    frames = []
    for shown in range(first_count, len(placed_sorted)+1):
        fig = plt.figure(figsize=(5,5))
        ax = fig.add_subplot(111, projection='3d')
        ax.set(xlim=(0,bin_size[0]), ylim=(0,bin_size[1]), zlim=(0,bin_size[2]))
        for (x,y,z),(dx,dy,dz) in placed_sorted[:shown]:
            v = [[x,y,z],[x+dx,y,z],[x+dx,y+dy,z],[x,y+dy,z],
                 [x,y,z+dz],[x+dx,y,z+dz],[x+dx,y+dy,z+dz],[x,y+dy,z+dz]]
            faces = [[v[ii] for ii in face] for face in face_corner_ids]
            ax.add_collection3d(Poly3DCollection(faces, facecolors='skyblue', edgecolors='gray', alpha=0.7))
        fig.canvas.draw()
        w,h = fig.canvas.get_width_height()
        # ARGB bytes -> (h, w, 4); dropping channel 0 (alpha) leaves RGB.
        buf = np.frombuffer(fig.canvas.tostring_argb(), dtype='uint8').reshape(h, w, 4)[:, :,1:]
        frames.append(buf)
        plt.close(fig)  # release the figure; many frames are created per call
    imageio.mimsave(gif_path, frames, duration=0.3)
    return gif_path



## Transformer Policy (scores over candidates)


In [7]:

class CandidatePolicy(nn.Module):
    """Transformer actor-critic that scores a variable-length list of candidates.

    Two context tokens (environment summary, current box) are encoded together;
    the encoded environment token then attends over the encoded candidates.
    Every candidate receives one score (used as a logit) and the state value is
    read from the encoded environment token alone.
    """

    def __init__(self, d_model=256, nhead=4, nlayers_ctx=2, nlayers_cand=2, pre_ln=True):
        super().__init__()

        def build_encoder(depth):
            # Batch-first encoder stack; `pre_ln` selects norm-first layers.
            block = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                               batch_first=True, norm_first=pre_ln)
            return nn.TransformerEncoder(block, num_layers=depth)

        def build_head(in_features):
            # Two-layer MLP ending in a single scalar.
            return nn.Sequential(nn.Linear(in_features, d_model),
                                 nn.ReLU(),
                                 nn.Linear(d_model, 1))

        # Context branch: 6-dim environment summary and 3-dim box size.
        self.env_proj  = nn.Linear(6, d_model)
        self.box_proj  = nn.Linear(3, d_model)
        self.ctx_enc   = build_encoder(nlayers_ctx)

        # Candidate branch: 6 features per candidate.
        self.cand_proj = nn.Linear(6, d_model)
        self.cand_enc  = build_encoder(nlayers_cand)

        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.score_head = build_head(3*d_model)
        self.value_head = build_head(d_model)

    def forward(self, env_feat, box_feat, cand_feats):
        """Score every candidate and estimate the state value.

        Args:
            env_feat: tensor whose last dimension is 6 (input of `env_proj`).
            box_feat: tensor whose last dimension is 3 (input of `box_proj`).
            cand_feats: rank-3 tensor [B, N, 6] of candidate features.

        Returns:
            (scores, value): `scores` has one entry per candidate ([B, N]);
            `value` is the value-head output with its last dimension squeezed.
        """
        batch, n_cand, _ = cand_feats.shape

        # Context: stack the two projected tokens -> [B, 2, d]; keep token 0 (env).
        env_token = self.env_proj(env_feat)
        box_token = self.box_proj(box_feat)
        ctx = self.ctx_enc(torch.stack([env_token, box_token], dim=1))[:, 0]

        cand_enc = self.cand_enc(self.cand_proj(cand_feats))

        # The context vector queries the candidate set; result is [B, 1, d].
        pooled, _ = self.cross_attn(ctx.unsqueeze(1), cand_enc, cand_enc)

        # Broadcast the per-sample vectors across the candidate axis.
        pooled_rep = pooled.expand(batch, n_cand, -1)
        ctx_rep = ctx.unsqueeze(1).expand(batch, n_cand, -1)

        joint = torch.cat([pooled_rep*cand_enc, ctx_rep, cand_enc], dim=-1)
        scores = self.score_head(joint).squeeze(-1)
        value = self.value_head(ctx).squeeze(-1)
        return scores, value



## PPO Agent (returns, GAE, clipped objective)


In [9]:

@dataclass
class PPOCfg:
    """Optimisation hyper-parameters; train() builds it from CONFIG["ppo"]."""
    # Discount factor and GAE lambda. train() passes the CONFIG values of the
    # same name to compute_returns_advantages.
    gamma: float
    lam: float
    # NOTE(review): stored but never read by PPOAgent.update as visible in this file.
    clip: float
    # Weights of the value loss and of the entropy bonus in the total loss.
    value_coeff: float
    entropy_coeff: float
    # Adam learning rate.
    lr: float
    # Maximum gradient norm passed to clip_grad_norm_.
    max_grad_norm: float

class PPOAgent:
    """Actor-critic trainer around a candidate-scoring policy.

    The agent samples one candidate per step and, once per rollout, performs a
    single gradient step on the loss
    ``policy_loss + value_coeff * value_loss - entropy_coeff * entropy``.

    NOTE(review): `update` applies the plain advantage-weighted log-probability
    objective; `cfg.clip` is not used, so no ratio clipping takes place.
    """

    def __init__(self, policy: "CandidatePolicy", cfg: "PPOCfg"):
        self.policy = policy.to(device)
        self.cfg = cfg
        self.opt = torch.optim.Adam(self.policy.parameters(), lr=cfg.lr)

    @staticmethod
    def compute_returns_advantages(rewards, values, last_value, gamma, lam):
        """Compute GAE advantages and the matching value targets.

        Args:
            rewards: sequence of T per-step rewards (floats).
            values: sequence of T value estimates (floats).
            last_value: bootstrap value for the state after the last step.
            gamma: discount factor.
            lam: GAE lambda.

        Returns:
            (returns, advantages): two float32 tensors of length T, where
            ``returns[t] = advantages[t] + values[t]``.
        """
        vals = list(values) + [last_value]
        T = len(rewards)
        adv, advs = 0.0, [0.0]*T
        # Backward recursion: A_t = delta_t + gamma * lam * A_{t+1}.
        for t in reversed(range(T)):
            delta = rewards[t] + gamma*vals[t+1] - vals[t]
            adv = delta + gamma*lam*adv
            advs[t] = adv
        returns = [advs[t] + vals[t] for t in range(T)]
        return torch.tensor(returns, dtype=torch.float32), torch.tensor(advs, dtype=torch.float32)

    def select_action(self, env_feat, box_feat, cand_feats):
        """Sample a candidate index from the policy's scores.

        Returns:
            (index, log_prob, entropy, value): the sampled index as a Python
            int (this requires a single sample, i.e. batch size 1) and the
            three tensors needed later by `update`.
        """
        scores, value = self.policy(env_feat, box_feat, cand_feats)
        # Replace NaN / +-inf scores by 0 so Categorical receives finite logits.
        scores = torch.nan_to_num(scores, nan=0.0, posinf=0.0, neginf=0.0)
        dist = Categorical(logits=scores)
        a = dist.sample()
        logp = dist.log_prob(a)
        ent = dist.entropy().mean()
        return a.item(), logp.squeeze(), ent, value.squeeze()

    def update(self, logps, values, advantages, entropies, returns):
        """Run one optimisation step on a finished rollout.

        Args:
            logps: list of per-step log-probability tensors (with grad).
            values: list of per-step value tensors (with grad).
            advantages: tensor of per-step advantages.
            entropies: list of per-step entropy tensors.
            returns: tensor of per-step value targets.

        Returns:
            dict with float entries "loss", "policy_loss", "value_loss" and
            "entropy". For an empty rollout nothing is optimised and all
            entries are 0.0.
        """
        if len(logps) == 0:
            # Nothing was sampled this episode; torch.stack would fail on [].
            return {"loss": 0.0, "policy_loss": 0.0, "value_loss": 0.0, "entropy": 0.0}

        # Standardise advantages only when there are at least two of them: the
        # default (unbiased) std of a single value is NaN, which would poison
        # the loss and, through the optimiser step, every parameter.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean())/(advantages.std()+1e-8)

        # reshape(-1) keeps a length-T vector even for T == 1 or [1]-shaped
        # entries, avoiding silent broadcasting against advantages / returns.
        policy_loss = -(torch.stack(logps).reshape(-1) * advantages).mean()
        value_loss  = F.mse_loss(torch.stack(values).reshape(-1), returns)
        entropy_loss= torch.stack(entropies).mean()
        loss = policy_loss + self.cfg.value_coeff*value_loss - self.cfg.entropy_coeff*entropy_loss
        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.cfg.max_grad_norm)
        self.opt.step()
        return {"loss": float(loss.item()), "policy_loss": float(policy_loss.item()),
                "value_loss": float(value_loss.item()), "entropy": float(entropy_loss.item())}



## Training Loop


In [11]:
def env_summary(env:PackingEnv):
    """Summarise the bin state as a 6-element float32 vector.

    Entries, in order: volume utilisation, fraction of boxes placed, mean top
    height / Lz, mean y-centre / Ly, free fraction along x, free fraction along
    z. The two means are 0.0 while nothing is placed.
    """
    size_x, size_y, size_z = env.bin_size[0], env.bin_size[1], env.bin_size[2]

    vol_util = env.used_volume/float(np.prod(env.bin_size))
    placed = len(env.placed)/max(1, env.cfg.num_boxes)

    if env.placed:
        avg_top = np.mean([p[2]+s[2] for p, s in env.placed])/size_z
        y_center = np.mean([p[1]+s[1]/2 for p, s in env.placed])/size_y
    else:
        avg_top = 0.0
        y_center = 0.0

    # Furthest occupied coordinate along x and z; the appended 0 covers the empty bin.
    reach_x = max([p[0]+s[0] for p, s in env.placed]+[0])
    reach_z = max([p[2]+s[2] for p, s in env.placed]+[0])
    spare_x = 1.0 - (reach_x/size_x)
    spare_z = 1.0 - (reach_z/size_z)

    return np.array([vol_util, placed, avg_top, y_center, spare_x, spare_z], dtype=np.float32)

def train():
    """Train the candidate policy for CONFIG["train"]["episodes"] episodes.

    Each episode rolls out one full packing sequence, then performs a single
    agent.update(...) on that rollout. Progress plots and GIFs of the best
    layout so far are written to CONFIG["log"]["save_dir"].

    Returns:
        (reward_log, entropy_log, best, last_png_path, last_gif_path) where the
        logs hold one value per episode, `best` is a dict with the best "score"
        and its "placed" list, and the paths point at the most recent files
        written (None if none were written).
    """
    env_cfg  = EnvConfig(**CONFIG["env"])
    cand_cfg = CandConfig(**CONFIG["candidates"])
    env = PackingEnv(env_cfg, cand_cfg)

    policy = CandidatePolicy(d_model=256, nhead=4, nlayers_ctx=2, nlayers_cand=2, pre_ln=True)
    agent  = PPOAgent(policy, PPOCfg(**CONFIG["ppo"]))

    reward_log=[]; entropy_log=[]; best = {"score": -1e9, "placed": []}
    episodes = CONFIG["train"]["episodes"]
    plot_interval = CONFIG["log"]["plot_interval"]
    save_dir = CONFIG["log"]["save_dir"]

    last_png_path = None
    last_gif_path = None

    for ep in range(1, episodes+1):
        box = env.reset()
        # Per-step rollout buffers: log-probs, values, entropies, rewards.
        logps=[]; vals=[]; ents=[]; rews=[]
        while env.index < env.cfg.num_boxes:
            cands, feats = env.valid_candidates(box)
            # NOTE(review): generate_candidates always returns at least one entry
            # (it has a last-resort fallback), so this skip branch appears unreachable.
            if not cands:
                env.index += 1
                if env.index >= env.cfg.num_boxes: break
                box = env.boxes[env.index]
                continue

            # Build batch-of-one inputs for the policy.
            env_feat = torch.tensor(env_summary(env), dtype=torch.float32, device=device).unsqueeze(0)
            box_feat = torch.tensor([box], dtype=torch.float32, device=device)[:, :3]
            cand_feats = torch.tensor(feats, dtype=torch.float32, device=device).unsqueeze(0)

            idx, logp, ent, val = agent.select_action(env_feat, box_feat, cand_feats)

            # Defensive clamp of the sampled index to the candidate list.
            idx = int(np.clip(idx, 0, len(cands)-1))
            pos, rot, _ = cands[idx] # candidates are (pos, rot, wasted); wasted is not needed here
            next_box, r, done = env.step((pos, rot)) # env.step expects only (pos, rot)

            logps.append(logp); vals.append(val); ents.append(ent); rews.append(float(r))
            if done: break
            box = next_box

        # Bootstrap value for GAE: 0 once every box has been consumed.
        # NOTE(review): the loop above only exits once env.index has reached
        # num_boxes, so the else branch looks unreachable -- confirm.
        if env.index>=env.cfg.num_boxes or box is None:
            last_val = 0.0
        else:
            env_feat = torch.tensor(env_summary(env), dtype=torch.float32, device=device).unsqueeze(0)
            box_feat = torch.tensor([box], dtype=torch.float32, device=device)[:, :3]
            # The value head does not use the candidates, so a zero placeholder is passed.
            dummy = torch.zeros((1,1,6), dtype=torch.float32, device=device)
            _, last_val = policy(env_feat, box_feat, dummy)

        # Values are passed as plain floats, so returns/advantages carry no gradient.
        returns, advs = agent.compute_returns_advantages(rews, [v.item() for v in vals], float(last_val),
                                                         CONFIG["ppo"]["gamma"], CONFIG["ppo"]["lam"])
        stats = agent.update(logps, vals, advs, ents, returns)

        total_r = sum(rews)
        reward_log.append(total_r); entropy_log.append(stats["entropy"])

        # Track the best layout by episode reward plus 100 x volume utilisation.
        vol_util = env.used_volume/float(np.prod(env.bin_size))
        score = total_r + 100.0*vol_util
        if score > best["score"]:
            best = {"score": score, "placed": list(env.placed)}

        # Periodic console line plus reward/entropy curves saved as a PNG.
        if ep % plot_interval == 0 or ep == episodes:
            print(f"[Ep {ep:04d}] R={total_r:.2f} | Ent={stats['entropy']:.3f} | VolUtil={vol_util:.3f} | score={score:.2f}")
            plt.figure(figsize=(10,4))
            plt.subplot(1,2,1); plt.plot(reward_log); plt.title("Reward"); plt.grid(True)
            plt.subplot(1,2,2); plt.plot(entropy_log); plt.title("Entropy"); plt.grid(True)
            plt.tight_layout();
            png_path = f"{save_dir}/report_ep{ep}.png"
            plt.savefig(png_path); plt.close()
            last_png_path = png_path

        # Periodic GIF of the best layout found so far (not of the current episode).
        if ep % CONFIG["train"]["save_gif_every"] == 0 or ep == episodes:
            gif_path = f"{save_dir}/best_ep{ep}.gif"
            draw_bin_with_boxes(env.bin_size, best["placed"], gif_path)
            last_gif_path = gif_path

    return reward_log, entropy_log, best, last_png_path, last_gif_path

# Run the full training loop when the cell executes and keep its results at module level.
reward_log, entropy_log, best, last_png_path, last_gif_path = train()
print("Training finished.")



[Ep 0020] R=223.00 | Ent=0.649 | VolUtil=0.680 | score=290.99
[Ep 0040] R=201.00 | Ent=0.567 | VolUtil=0.734 | score=274.44
[Ep 0060] R=228.50 | Ent=0.499 | VolUtil=0.711 | score=299.55
[Ep 0080] R=272.50 | Ent=0.556 | VolUtil=0.689 | score=341.38
[Ep 0100] R=239.50 | Ent=0.490 | VolUtil=0.667 | score=306.18
[Ep 0120] R=289.00 | Ent=0.325 | VolUtil=0.707 | score=359.65
[Ep 0140] R=234.00 | Ent=0.257 | VolUtil=0.626 | score=296.57
[Ep 0160] R=300.00 | Ent=0.176 | VolUtil=0.651 | score=365.12
[Ep 0180] R=261.50 | Ent=0.108 | VolUtil=0.746 | score=336.10
[Ep 0200] R=234.00 | Ent=0.087 | VolUtil=0.601 | score=294.08
Training finished.


In [14]:

# List everything currently in the log directory, sorted by path string.
print("Saved under:", CONFIG["log"]["save_dir"])
from glob import glob
saved_files = sorted(glob(CONFIG["log"]["save_dir"]+"/*"))
print("Saved files:")
for f in saved_files:
    print(f)


Saved under: /mnt/data/ppo_tf_logs
Saved files:
/mnt/data/ppo_tf_logs/best_ep100.gif
/mnt/data/ppo_tf_logs/best_ep150.gif
/mnt/data/ppo_tf_logs/best_ep200.gif
/mnt/data/ppo_tf_logs/best_ep50.gif
/mnt/data/ppo_tf_logs/report_ep100.png
/mnt/data/ppo_tf_logs/report_ep120.png
/mnt/data/ppo_tf_logs/report_ep140.png
/mnt/data/ppo_tf_logs/report_ep160.png
/mnt/data/ppo_tf_logs/report_ep180.png
/mnt/data/ppo_tf_logs/report_ep20.png
/mnt/data/ppo_tf_logs/report_ep200.png
/mnt/data/ppo_tf_logs/report_ep40.png
/mnt/data/ppo_tf_logs/report_ep60.png
/mnt/data/ppo_tf_logs/report_ep80.png



### PPO + Transformer Stabilization recap
- Candidate categorical action space
- Proper GAE with bootstrap
- Reward scaling & clipping
- Pre‑LN Transformer + grad‑clip + entropy reg
