In [21]:
import json
import random
import os
from tqdm import tqdm
from pathlib import Path
from typing import List, Tuple


import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [22]:
NUM_LAYOUTS = 10000       # total layouts to generate
MIN_BOXES = 5
MAX_BOXES = 20
NUM_CLASSES = 6           # number of product types
OUT_PATH = "/kaggle/working/layouts.json"

# Shelf structure parameters
MIN_ROWS = 2
MAX_ROWS = 5
ROW_GAP = 0.02            # vertical space between rows
X_MARGIN = 0.02
Y_MARGIN = 0.02

In [23]:
def generate_layout():
    layout = []
    n_rows = random.randint(MIN_ROWS, MAX_ROWS)
    total_height = 1.0 - 2 * Y_MARGIN - ROW_GAP * (n_rows - 1)
    row_height = total_height / n_rows

    y_start = Y_MARGIN
    for r in range(n_rows):
        n_boxes = random.randint(MIN_BOXES // n_rows, MAX_BOXES // n_rows)
        # Random horizontal segmentation
        x_positions = sorted([random.random() for _ in range(n_boxes - 1)])
        x_positions = [0.0] + x_positions + [1.0]

        for i in range(n_boxes):
            x0 = x_positions[i]
            x1 = x_positions[i + 1]
            w = max(0.05, (x1 - x0) * random.uniform(0.8, 1.0))
            cx = X_MARGIN + x0 + w / 2
            h = row_height * random.uniform(0.8, 1.0)
            cy = y_start + h / 2
            cls = random.randint(0, NUM_CLASSES - 1)
            layout.append([cx, cy, w, h, cls])

        y_start += row_height + ROW_GAP

    return layout

In [24]:
if __name__ == "__main__":
    layouts = []
    for _ in tqdm(range(NUM_LAYOUTS), desc="Generating layouts"):
        layouts.append(generate_layout())

    os.makedirs(os.path.dirname(OUT_PATH) or ".", exist_ok=True)
    with open(OUT_PATH, "w") as f:
        json.dump(layouts, f)

    print(f"\n✅ Generated {NUM_LAYOUTS} synthetic layouts and saved to '{OUT_PATH}'")

Generating layouts: 100%|██████████| 10000/10000 [00:00<00:00, 33952.63it/s]



✅ Generated 10000 synthetic layouts and saved to '/kaggle/working/layouts.json'


In [25]:
import torch
from torch.utils.data import Dataset
import numpy as np
import json

class PlanogramDataset(Dataset):
    def __init__(self, json_path, N_max=32, num_classes=6):
        self.N_max = N_max
        self.num_classes = num_classes
        with open(json_path, 'r') as f:
            self.layouts = json.load(f)

    def __len__(self):
        return len(self.layouts)

    def __getitem__(self, idx):
        layout = self.layouts[idx]
        k = len(layout)
        n = min(k, self.N_max)
        data = np.zeros((self.N_max, 4 + self.num_classes + 1), dtype=np.float32)
        mask = np.zeros((self.N_max,), dtype=np.float32)

        for i in range(n):
            cx, cy, w, h, cls = layout[i]
            onehot = np.zeros((self.num_classes,), dtype=np.float32)
            onehot[int(cls)] = 1.0
            data[i, :4] = [cx, cy, w, h]
            data[i, 4:4+self.num_classes] = onehot
            data[i, -1] = 1.0
            mask[i] = 1.0

        return torch.from_numpy(data), torch.from_numpy(mask)


In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Critic(nn.Module):
    def __init__(self, N_max=32, num_classes=6, hidden=512):
        super().__init__()
        in_dim = N_max * (4 + num_classes + 1) + N_max  # include mask
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.LeakyReLU(0.2, True),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2, True),
            nn.Linear(hidden, 1)
        )

    def forward(self, x, mask):
        B = x.shape[0]
        flat = x.view(B, -1)
        flat_mask = mask.view(B, -1)
        inp = torch.cat([flat, flat_mask], dim=1)
        return self.net(inp).squeeze(1)


In [27]:
import torch
import torch.nn.functional as F

def gradient_penalty(critic, real, fake, mask, device, lambda_gp=10.0):
    B = real.shape[0]
    alpha = torch.rand(B, 1, 1, device=device).expand_as(real)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interpolates = critic(interpolates, mask)
    grads = torch.autograd.grad(
        outputs=d_interpolates, inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grads = grads.view(B, -1)
    gp = ((grads.norm(2, dim=1) - 1)**2).mean() * lambda_gp
    return gp

def overlap_penalty(boxes_tensor, mask_tensor, iou_threshold=0.05):
    B, N, _ = boxes_tensor.shape
    penalty = 0.0
    eps = 1e-6
    for b in range(B):
        boxes = boxes_tensor[b]
        mask = mask_tensor[b]
        valid_idx = (mask > 0.5).nonzero(as_tuple=False).squeeze(1)
        m = valid_idx.numel()
        if m <= 1: continue
        sel = boxes[valid_idx]
        cx, cy, w, h = sel[:,0], sel[:,1], sel[:,2], sel[:,3]
        x1 = cx - w/2; y1 = cy - h/2; x2 = cx + w/2; y2 = cy + h/2
        xx1 = torch.max(x1.unsqueeze(1), x1.unsqueeze(0))
        yy1 = torch.max(y1.unsqueeze(1), y1.unsqueeze(0))
        xx2 = torch.min(x2.unsqueeze(1), x2.unsqueeze(0))
        yy2 = torch.min(y2.unsqueeze(1), y2.unsqueeze(0))
        inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)
        area = w * h
        union = area.unsqueeze(1) + area.unsqueeze(0) - inter + eps
        iou = inter / union
        iou = iou - torch.diag(torch.diag(iou))
        penalty += F.relu(iou - iou_threshold).sum()
    return penalty / max(B,1)


In [28]:
import matplotlib.pyplot as plt
import numpy as np

def plot_layout(boxes, mask, title=None, save_path=None):
    fig, ax = plt.subplots(1,1, figsize=(6,3))
    ax.set_xlim(0,1); ax.set_ylim(1,0); ax.set_xticks([]); ax.set_yticks([])
    n = boxes.shape[0]
    for i in range(n):
        if mask[i] < 0.5: continue
        cx, cy, w, h = boxes[i]
        rect = plt.Rectangle((cx-w/2, cy-h/2), w, h, fill=False, edgecolor='C0')
        ax.add_patch(rect)
    if title: ax.set_title(title)
    if save_path:
        plt.savefig(save_path, bbox_inches='tight'); plt.close(fig)
    else:
        plt.show()


In [29]:
import os
import torch
from torch.utils.data import DataLoader
from dataset import PlanogramDataset
from model_generator import Generator
from model_critic import Critic
from losses import gradient_penalty, overlap_penalty
from visualize import plot_layout
import json

def train():
    # -------------------------
    # Load synthetic layouts
    # -------------------------
    with open("layouts.json", "r") as f:
        layouts = json.load(f)

    ds = PlanogramDataset(layouts, N_max=32, num_classes=6)
    dl = DataLoader(ds, batch_size=64, shuffle=True, num_workers=2, drop_last=True)

    # -------------------------
    # Initialize models and optimizers
    # -------------------------
    G = Generator(z_dim=128, N_max=32, num_classes=6).to("cuda")
    D = Critic(N_max=32, num_classes=6).to("cuda")

    optG = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optD = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.9))

    fixed_z = torch.randn(16, 128, device="cuda")  # fixed noise for visualization

    # -------------------------
    # Training loop
    # -------------------------
    for epoch in range(10):  # set number of epochs
        for real_batch, mask_batch in dl:
            real_batch = real_batch.to("cuda")
            mask_batch = mask_batch.to("cuda")
            B = real_batch.shape[0]

            # --------- Update critic ---------
            for _ in range(5):  # n_critic
                z = torch.randn(B, 128, device="cuda")
                fake = G(z).detach()
                D_real = D(real_batch, mask_batch)
                D_fake = D(fake, mask_batch)
                gp = gradient_penalty(D, real_batch, fake, mask_batch, device="cuda")
                loss_D = D_fake.mean() - D_real.mean() + gp

                optD.zero_grad()
                loss_D.backward()
                optD.step()

            # --------- Update generator ---------
            z = torch.randn(B, 128, device="cuda")
            fake = G(z)
            D_fake_forG = D(fake, mask_batch)
            loss_G = -D_fake_forG.mean()

            boxes_pred = fake[:, :, :4]
            occ = fake[:, :, -1]
            overlap = overlap_penalty(boxes_pred, (occ > 0.5).float())
            loss_G += 10.0 * overlap  # lambda_overlap = 10

            optG.zero_grad()
            loss_G.backward()
            optG.step()

        # --------- End epoch: print and save sample ---------
        print(f"Epoch {epoch+1}/10 | loss_D {loss_D.item():.4f} | "
              f"loss_G {loss_G.item():.4f} | overlap {overlap.item():.4f}")

        # visualize a few samples
        with torch.no_grad():
            samples = G(fixed_z).cpu()
        for i in range(min(4, samples.shape[0])):
            boxes = samples[i, :, :4].numpy()
            occ_mask = (samples[i, :, -1].numpy() > 0.5).astype(float)
            os.makedirs("samples", exist_ok=True)
            save_path = f"samples/epoch{epoch+1}_sample{i}.png"
            plot_layout(boxes, occ_mask, title=f"epoch {epoch+1}", save_path=save_path)

    print("Training finished")


if __name__ == "__main__":
    train()


ModuleNotFoundError: No module named 'dataset'