In [None]:
import json
import random
import os
from tqdm import tqdm
from pathlib import Path
from typing import List, Tuple


import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
# --- Synthetic layout generation ------------------------------------------
NUM_LAYOUTS = 10_000           # how many synthetic layouts to produce
MIN_BOXES, MAX_BOXES = 5, 20   # box-count bounds; generate_layout divides them by the row count
NUM_CLASSES = 6                # distinct product types; class ids are 0..NUM_CLASSES-1
OUT_PATH = "layouts.json"      # presumably where generated layouts get written (unused in this chunk)

# --- Shelf geometry (fractions of the normalized [0, 1] canvas) -------------
MIN_ROWS, MAX_ROWS = 2, 5      # inclusive range for the number of shelf rows
ROW_GAP = 0.02                 # vertical space between consecutive rows
X_MARGIN = Y_MARGIN = 0.02     # horizontal / vertical border around the shelf

In [None]:
def generate_layout(min_box_width: float = 0.05):
    """Generate one random synthetic shelf layout on the unit canvas.

    The canvas is split vertically into ``n_rows`` rows (``MIN_ROWS``..``MAX_ROWS``)
    separated by ``ROW_GAP`` and framed by ``Y_MARGIN``. Each row is split
    horizontally into consecutive, non-overlapping segments that exactly tile
    ``[X_MARGIN, 1 - X_MARGIN]``. One box is placed at the left edge of each
    segment, so boxes in a row never overlap each other and never leave the
    canvas.

    Args:
        min_box_width: Lower bound on a box's normalized width. It is reduced
            automatically when a row holds too many boxes for it to fit.

    Returns:
        List of ``[cx, cy, w, h, cls]`` entries: normalized centre and size,
        plus an integer class id in ``[0, NUM_CLASSES - 1]``.
    """
    layout = []
    n_rows = random.randint(MIN_ROWS, MAX_ROWS)
    total_height = 1.0 - 2 * Y_MARGIN - ROW_GAP * (n_rows - 1)
    row_height = total_height / n_rows
    usable_width = 1.0 - 2 * X_MARGIN

    # Split the global box-count bounds across rows, always allowing >= 1 box per row.
    lo = max(1, MIN_BOXES // n_rows)
    hi = max(lo, MAX_BOXES // n_rows)

    y_start = Y_MARGIN
    for _ in range(n_rows):
        n_boxes = random.randint(lo, hi)
        # Every segment is guaranteed min_w. The leftover "slack" width is split
        # at random cut points, so segments tile the usable width exactly.
        min_w = min(min_box_width, usable_width / n_boxes)
        slack = usable_width - n_boxes * min_w
        cuts = sorted(random.random() * slack for _ in range(n_boxes - 1))
        bounds = [0.0] + cuts + [slack]

        for i in range(n_boxes):
            seg_x0 = X_MARGIN + bounds[i] + i * min_w
            seg_w = min_w + (bounds[i + 1] - bounds[i])
            # seg_w >= min_w and seg_w * u <= seg_w, so the box stays inside its segment.
            w = max(min_w, seg_w * random.uniform(0.8, 1.0))
            cx = seg_x0 + w / 2
            h = row_height * random.uniform(0.8, 1.0)
            cy = y_start + h / 2  # the box's upper edge (smaller y) sits at y_start
            cls = random.randint(0, NUM_CLASSES - 1)
            layout.append([cx, cy, w, h, cls])

        y_start += row_height + ROW_GAP

    return layout

In [None]:
class Config:
    """Paths and hyper-parameters for the layout WGAN-GP experiment.

    Used as a plain namespace: values are read as ``Config.<name>``.
    """

    # data
    coco_json = "data/annotations.json"  # replace path
    images_dir = "data/images"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed = 42

    # dataset
    N_max = 32         # pad/truncate boxes per layout
    num_classes = 50   # set your number of classes

    # training
    z_dim = 128           # latent vector size fed to the generator
    batch_size = 64
    lr = 1e-4
    betas = (0.5, 0.9)    # optimizer (beta1, beta2)
    epochs = 200
    n_critic = 5          # critic updates per generator update
    lambda_gp = 10.0      # gradient-penalty weight
    lambda_overlap = 1.0  # weight of the overlap penalty in the generator loss
    save_every = 10       # epochs between sample/checkpoint dumps
    out_dir = "runs/baseline"


# reproducibility: seed every RNG the notebook uses
torch.manual_seed(Config.seed)
np.random.seed(Config.seed)
random.seed(Config.seed)


os.makedirs(Config.out_dir, exist_ok=True)

In [None]:
def parse_coco_to_layouts(coco_json_path: str, images_dir: str, N_max: int, num_classes: int) -> List[dict]:
    """Parse a COCO-style annotation file into per-image layouts.

    Each returned layout dict has:
        'image_id': the COCO image id,
        'boxes':    float32 array (k, 4) of normalized [cx, cy, w, h]; centres
                    are clipped to [0, 1] and sizes to [1e-4, 1],
        'labels':   int64 array (k,) of contiguous class indices in
                    [0, num_classes - 1],
        'width', 'height': image size in pixels.

    COCO category ids are frequently non-contiguous (e.g. 1..90 for 80 classes)
    and may start at 0. They are remapped to 0..K-1 in ascending id order:
    categories declared in ``coco['categories']`` come first, and any
    undeclared ids seen in the annotations are appended after them. For the
    common contiguous ids 1..K this equals ``category_id - 1``. Indices still
    above ``num_classes - 1`` are clamped to the last class.

    Images without annotations are skipped.

    Args:
        coco_json_path: Path to the COCO JSON file.
        images_dir: Unused here; kept for interface compatibility.
        N_max: Unused here (padding/truncation happens in the dataset).
        num_classes: Number of classes the model uses.
    """
    with open(coco_json_path, 'r', encoding='utf-8') as f:
        coco = json.load(f)

    imgs = {img['id']: img for img in coco['images']}
    anns_by_img = {}
    for ann in coco['annotations']:
        anns_by_img.setdefault(ann['image_id'], []).append(ann)

    # Contiguous class indices: declared categories first, then any stray ids.
    declared = sorted({c['id'] for c in coco.get('categories', [])})
    stray = sorted({a.get('category_id', 1) for a in coco['annotations']} - set(declared))
    cat_to_idx = {cid: i for i, cid in enumerate(declared + stray)}

    layouts = []
    for iid, img in imgs.items():
        anns = anns_by_img.get(iid, [])
        if not anns:
            continue
        H, W = img['height'], img['width']
        boxes = []
        labels = []
        for a in anns:
            x, y, w, h = a['bbox']  # COCO: top-left x/y plus width/height, in pixels
            cx = min(max((x + w / 2.0) / W, 0.0), 1.0)
            cy = min(max((y + h / 2.0) / H, 0.0), 1.0)
            nw = min(max(w / W, 1e-4), 1.0)
            nh = min(max(h / H, 1e-4), 1.0)
            boxes.append([cx, cy, nw, nh])
            labels.append(min(cat_to_idx[a.get('category_id', 1)], num_classes - 1))

        layouts.append({'image_id': iid, 'boxes': np.array(boxes, dtype=np.float32),
                        'labels': np.array(labels, dtype=np.int64), 'width': W, 'height': H})
    return layouts

In [None]:
class PlanogramDataset(Dataset):
    """Fixed-size tensor view over parsed layouts.

    Each item is ``(data, mask)``:
        data: float32 tensor (N_max, 4 + num_classes + 1). Each slot holds
              [cx, cy, w, h], a one-hot class vector and an occupancy score
              (1.0 for a real box, 0.0 for padding).
        mask: float32 tensor (N_max,), 1.0 for occupied slots.
    Layouts with more than ``N_max`` boxes are truncated (the first N_max are
    kept); shorter layouts are zero-padded.
    """

    def __init__(self, layouts: List[dict], N_max: int, num_classes: int, shuffle=True):
        # Copy so that shuffling never reorders the caller's list in place.
        self.layouts = list(layouts)
        self.N_max = N_max
        self.num_classes = num_classes
        if shuffle:
            random.shuffle(self.layouts)

    def __len__(self):
        return len(self.layouts)

    def __getitem__(self, idx):
        item = self.layouts[idx]
        all_boxes = np.asarray(item['boxes'], dtype=np.float32)
        n = min(len(all_boxes), self.N_max)
        boxes = all_boxes[:n].reshape(n, 4)
        labels = np.asarray(item['labels'][:n], dtype=np.int64)
        # Negative labels would silently one-hot the wrong (last) class.
        if n and (labels.min() < 0 or labels.max() >= self.num_classes):
            raise ValueError(
                f"labels must lie in [0, {self.num_classes - 1}], "
                f"got min={labels.min()}, max={labels.max()} (item {idx})")

        data = np.zeros((self.N_max, 4 + self.num_classes + 1), dtype=np.float32)
        mask = np.zeros((self.N_max,), dtype=np.float32)
        data[:n, :4] = boxes
        data[np.arange(n), 4 + labels] = 1.0  # one-hot class block
        data[:n, -1] = 1.0  # occupancy score
        mask[:n] = 1.0

        return torch.from_numpy(data), torch.from_numpy(mask)

In [None]:
class Generator(nn.Module):
    """MLP generator mapping a latent vector to N_max box slots.

    The output has shape (B, N_max, 4 + num_classes + 1). Per slot:
        [:4]                box [cx, cy, w, h], squashed to (0, 1) by a sigmoid,
        [4:4 + num_classes] class probabilities (softmax over classes),
        [-1]                occupancy probability (sigmoid).
    """

    def __init__(self, z_dim, N_max, num_classes, hidden=512):
        super().__init__()
        self.z_dim = z_dim
        self.N_max = N_max
        self.num_classes = num_classes
        out_dim = N_max * (4 + num_classes + 1)  # [cx,cy,w,h] + onehot classes + occupancy

        self.net = nn.Sequential(
            nn.Linear(z_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(True),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, z):
        """Map z of shape (B, z_dim) to slots of shape (B, N_max, 4 + num_classes + 1)."""
        raw = self.net(z).view(z.shape[0], self.N_max, 4 + self.num_classes + 1)
        # Build the activated tensor out-of-place rather than writing back into
        # slices of `raw`, which is part of the autograd graph.
        boxes = torch.sigmoid(raw[..., :4])
        probs = F.softmax(raw[..., 4:4 + self.num_classes], dim=-1)
        occupancy = torch.sigmoid(raw[..., -1:])
        return torch.cat([boxes, probs, occupancy], dim=-1)


class Critic(nn.Module):
    """MLP WGAN critic producing one unbounded score per layout.

    The flattened slot tensor and the flattened mask are concatenated and fed
    through a three-layer MLP.
    """

    def __init__(self, N_max, num_classes, hidden=512):
        super().__init__()
        in_dim = N_max * (4 + num_classes + 1) + N_max  # include mask as extra channel
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.LeakyReLU(0.2, True),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2, True),
            nn.Linear(hidden, 1),
        )

    def forward(self, x, mask):
        """Score layouts: x (B, N_max, D), mask (B, N_max) -> scores (B,)."""
        B = x.shape[0]
        # reshape (not view) so non-contiguous inputs such as slices or
        # transposed tensors are accepted too.
        inp = torch.cat([x.reshape(B, -1), mask.reshape(B, -1)], dim=1)
        return self.net(inp).squeeze(1)

In [None]:
def gradient_penalty(critic, real, fake, mask, device, lambda_gp=10.0):
    """WGAN-GP gradient penalty.

    Draws one point per sample on the straight line between paired real and
    fake layouts, then penalizes the critic's input-gradient L2 norm for
    deviating from 1.

    Args:
        critic: callable ``critic(x, mask)`` returning one score per sample.
        real, fake: tensors of identical shape (B, ...).
        mask: passed unchanged to the critic (the real mask is used as an
            approximation for the interpolated samples).
        device: device on which the mixing coefficients are drawn.
        lambda_gp: penalty weight.

    Returns:
        Scalar tensor ``lambda_gp * mean((||grad||_2 - 1) ** 2)``. The gradient
        graph is kept (``create_graph=True``) so the penalty can itself be
        backpropagated into the critic.
    """
    B = real.shape[0]
    # One mixing coefficient per sample, broadcast over all remaining dims.
    alpha = torch.rand((B,) + (1,) * (real.dim() - 1), device=device).expand_as(real)
    # Detach so the penalty constrains only the critic and never sends
    # gradients into whatever produced `real` / `fake` (e.g. the generator).
    interpolates = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)

    d_interpolates = critic(interpolates, mask)
    grads = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    grads = grads.reshape(B, -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp




def overlap_penalty(boxes_tensor, mask_tensor, iou_threshold=0.05):
    """Soft penalty on pairwise overlap between the active boxes of each layout.

    For each layout, IoU is computed between every pair of distinct active boxes
    (``mask > 0.5``). The excess ``relu(IoU - iou_threshold)`` is summed over
    ordered pairs, so each unordered pair counts twice. The result is averaged
    over the batch.

    Args:
        boxes_tensor: (B, N, 4) boxes as [cx, cy, w, h].
        mask_tensor: (B, N) occupancy; entries > 0.5 mark active boxes. The mask
            only selects boxes, so no gradient flows through it.
        iou_threshold: IoU tolerated without penalty.

    Returns:
        Scalar tensor on the boxes' device and dtype. It is 0 when no layout has
        two overlapping active boxes, and is differentiable w.r.t. boxes_tensor.
    """
    B, N, _ = boxes_tensor.shape
    eps = 1e-6
    cx, cy, w, h = boxes_tensor.unbind(-1)  # each (B, N)
    x1, y1 = cx - w / 2, cy - h / 2
    x2, y2 = cx + w / 2, cy + h / 2

    # Pairwise intersections and IoU, shape (B, N, N).
    inter_w = (torch.min(x2.unsqueeze(2), x2.unsqueeze(1))
               - torch.max(x1.unsqueeze(2), x1.unsqueeze(1))).clamp(min=0)
    inter_h = (torch.min(y2.unsqueeze(2), y2.unsqueeze(1))
               - torch.max(y1.unsqueeze(2), y1.unsqueeze(1))).clamp(min=0)
    inter = inter_w * inter_h
    area = w * h
    union = area.unsqueeze(2) + area.unsqueeze(1) - inter + eps
    iou = inter / union

    # Only pairs of two distinct active boxes contribute.
    active = mask_tensor > 0.5
    not_self = ~torch.eye(N, dtype=torch.bool, device=boxes_tensor.device)
    pair_mask = active.unsqueeze(2) & active.unsqueeze(1) & not_self
    over = torch.where(pair_mask, F.relu(iou - iou_threshold), torch.zeros_like(iou))

    # Always a tensor, so callers can rely on .item() / .backward().
    return over.sum() / max(B, 1)

In [None]:
def plot_layout(boxes, mask, title=None, save_path=None):
    """Draw the active boxes of one layout as outlines on the unit canvas.

    Args:
        boxes: (N, 4) normalized [cx, cy, w, h] boxes.
        mask: (N,) occupancy; boxes with mask < 0.5 are skipped.
        title: optional axes title.
        save_path: if given, the figure is saved there and closed; otherwise
            it is shown with ``plt.show()``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(6, 3))
    ax.set_xlim(0, 1)
    ax.set_ylim(1, 0)  # inverted so y grows downwards, as in image coordinates
    ax.set_xticks([])
    ax.set_yticks([])
    for i, (cx, cy, w, h) in enumerate(boxes):
        if mask[i] < 0.5:
            continue
        ax.add_patch(plt.Rectangle((cx - w / 2, cy - h / 2), w, h, fill=False, edgecolor='C0'))
    if title:
        ax.set_title(title)
    if save_path:
        # Save this figure explicitly instead of pyplot's "current" figure.
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

In [None]:
def train():
    """Train the layout WGAN-GP end to end.

    Steps:
    - Parse layouts from ``Config.coco_json`` with ``parse_coco_to_layouts`` and
      wrap them in a ``PlanogramDataset`` / ``DataLoader``.
    - Build the generator and critic with their Adam optimizers.
    - For every batch, run ``Config.n_critic`` critic updates (Wasserstein loss
      plus gradient penalty), then one generator update (adversarial loss plus
      ``Config.lambda_overlap`` times the overlap penalty).
    - After the first epoch and every ``Config.save_every`` epochs, write sample
      plots and a checkpoint to ``Config.out_dir``.

    Raises:
        ValueError: if the annotation file yields no annotated images.
    """
    layouts = parse_coco_to_layouts(Config.coco_json, Config.images_dir,
                                    Config.N_max, Config.num_classes)
    if not layouts:
        raise ValueError(f"no annotated images found in {Config.coco_json}")
    dataset = PlanogramDataset(layouts, Config.N_max, Config.num_classes)
    loader = DataLoader(dataset, batch_size=Config.batch_size, shuffle=True)

    G = Generator(Config.z_dim, Config.N_max, Config.num_classes).to(Config.device)
    D = Critic(Config.N_max, Config.num_classes).to(Config.device)
    optG = torch.optim.Adam(G.parameters(), lr=Config.lr, betas=Config.betas)
    optD = torch.optim.Adam(D.parameters(), lr=Config.lr, betas=Config.betas)

    # Fixed latent batch so sample plots are comparable across epochs.
    fixed_z = torch.randn(4, Config.z_dim, device=Config.device)

    for epoch in range(Config.epochs):
        for real_batch, mask_batch in tqdm(loader, desc=f"epoch {epoch + 1}", leave=False):
            real_batch = real_batch.to(Config.device)  # (B, N, D)
            mask_batch = mask_batch.to(Config.device)
            B = real_batch.shape[0]

            # -----------------
            # Update critic n_critic times
            # -----------------
            for _ in range(Config.n_critic):
                z = torch.randn(B, Config.z_dim, device=Config.device)
                fake = G(z).detach()

                D_real = D(real_batch, mask_batch)
                D_fake = D(fake, mask_batch)  # use same mask shape as approximation

                gp = gradient_penalty(D, real_batch, fake, mask_batch, Config.device, Config.lambda_gp)
                loss_D = D_fake.mean() - D_real.mean() + gp

                optD.zero_grad()
                loss_D.backward()
                optD.step()

            # -----------------
            # Update generator
            # -----------------
            z = torch.randn(B, Config.z_dim, device=Config.device)
            fake = G(z)
            loss_G = -D(fake, mask_batch).mean()
            # soft overlap penalty on the boxes the generator marks as occupied
            boxes_pred = fake[:, :, :4]
            occ = fake[:, :, -1]
            overlap = overlap_penalty(boxes_pred, (occ > 0.5).float(), iou_threshold=0.05)
            loss_G = loss_G + Config.lambda_overlap * overlap

            # D also receives grads here; optD.zero_grad() clears them before its next step.
            optG.zero_grad()
            loss_G.backward()
            optG.step()

        # end epoch: losses of the last batch. float() also accepts a plain-number overlap.
        print(f"Epoch {epoch+1}/{Config.epochs} | loss_D {loss_D.item():.4f} | loss_G {loss_G.item():.4f} | overlap {float(overlap):.4f}")

        # sample and save
        if (epoch + 1) % Config.save_every == 0 or epoch == 0:
            with torch.no_grad():
                samples = G(fixed_z).cpu()
            for i in range(min(4, samples.shape[0])):
                boxes = samples[i, :, :4].numpy()
                occ = (samples[i, :, -1].numpy() > 0.5).astype(np.float32)
                save_p = os.path.join(Config.out_dir, f"epoch{epoch+1}_sample{i}.png")
                plot_layout(boxes, occ, title=f"epoch{epoch+1}", save_path=save_p)

            # checkpoint
            torch.save({'G': G.state_dict(), 'D': D.state_dict(),
                        'optG': optG.state_dict(), 'optD': optD.state_dict()},
                       os.path.join(Config.out_dir, f"ckpt_epoch{epoch+1}.pth"))

    print("Training finished")


# In a notebook kernel __name__ is '__main__', so executing this cell starts training.
if __name__ == '__main__':
    train()