# 🧠 GAN for 3D Voxel Shape Generation

**Project Title:** ShapeGAN — Generating 3D Voxel Chairs with Generative Adversarial Networks

**Student Name:** Yazdan Ghanavati

**Course:** 3D Vision / ICT  

**Date:** July 2025  

**Institution:** University of Padova

---

### 📍 Project Summary

This notebook presents a complete pipeline for training a Generative Adversarial Network (GAN) to synthesize 3D voxel-based shapes. It uses chair models from the ShapeNet dataset, represented as signed distance fields in `.npy` format, and trains a Generator and Discriminator using PyTorch.

🛠 The entire codebase is embedded directly in the notebook for portability, clarity, and grading convenience.

---

### 💡 Inspiration

> Most of the original code structure and core ideas were inspired by the open-source repository:
> [marian42/shapegan](https://github.com/marian42/shapegan)

This project builds upon ShapeGAN’s foundational ideas and adapts them into a compact, streamlined notebook format suitable for academic demonstration and hands-on experimentation.

---

### 🔍 Highlights

- Fully embedded GAN architecture (Generator & Discriminator) using 3D convolutions
- Integrated training loop with checkpoint resumption
- Voxel shape generation with slice-based ASCII visualization
- Optional rendering support for 3D previews
- Nearly self-contained: the only project-local import is `VoxelDataset` from `datasets.py`

---


# 📦 Section 1: Imports and Device Setup

This section loads all required Python packages for the GAN pipeline, including PyTorch for deep learning, NumPy for numerical operations, and utilities for file handling. It also detects whether a GPU is available and sets the correct device for computation.

💡 Training and generation are typically much faster on a CUDA GPU; when none is available, the code falls back to the CPU.

---


In [1]:
# Source: util.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np
import os
import sys
import time
import re

# Set device based on availability (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


# ⚙️ Section 2: Configuration and Hyperparameters

This section defines all the key training and dataset parameters for the GAN model. It includes file paths for saving checkpoints, model settings like batch size and learning rate, and the location of the voxel data.

📌 These values shape how your GAN learns and where it stores outputs. Updating them here ensures consistency across your pipeline.
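
As a rough sanity check on these settings, the number of optimizer steps follows from the batch size and the dataset size. The sketch below assumes the 6232 samples reported by the dataset cell in Section 3 and the `DataLoader` default of `drop_last=False`; `batches_per_epoch` is an illustrative helper, not part of the original code.

```python
import math

def batches_per_epoch(num_samples, batch_size):
    """Return (number of batches, size of the final batch) when drop_last=False."""
    n_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (n_batches - 1) * batch_size
    return n_batches, last_batch

# Values taken from this notebook: 6232 samples, BATCH_SIZE = 32, 160 target epochs
n_batches, last_batch = batches_per_epoch(6232, 32)
total_steps = n_batches * 160
print(n_batches, last_batch, total_steps)
```

The final batch is smaller than `BATCH_SIZE`, which is why the training loop reads the batch size from each batch rather than assuming a fixed value.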

---


In [2]:
# Source: train_gan.py

# --- Define Paths ---
MODEL_PATH = "models"
CHECKPOINT_PATH = os.path.join(MODEL_PATH, 'checkpoints')

# Create model directories if they don't exist
# makedirs also creates the parent MODEL_PATH; exist_ok avoids an error on reruns
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# --- Training Parameters ---
BATCH_SIZE = 32
TOTAL_TARGET_EPOCHS = 160
SAVE_EVERY_N_EPOCHS = 10

LEARNING_RATE_G = 0.0002
LEARNING_RATE_D = 0.0002
BETA1 = 0.5
LATENT_CODE_SIZE = 128

# --- Data Path & Workers ---
DATASET_PATH = 'C:/Users/yghan/Documents/ICT/3D-Vision/test/dataset/ShapeNet_SDF/chairs/voxels_32/*.npy'
NUM_WORKERS = 4


# 🗂️ Section 3: Dataset Loading

This section uses the `VoxelDataset` class to load voxel data files from the ShapeNet dataset. These 3D shapes are represented as `.npy` files containing signed distance fields (SDFs). The dataset loader normalizes the values and prepares batches for training.

💾 By using `VoxelDataset.glob()`, the script searches through the specified folder and loads every `.npy` file matching the pattern.
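
The real `VoxelDataset` lives in `datasets.py` and is not reproduced here. As a minimal sketch of the idea, assuming one 32×32×32 array per `.npy` file, a map-style dataset only needs `__len__` and `__getitem__`. The class name `NpyVoxelFolder` is hypothetical, and the sketch omits the normalization step mentioned above.

```python
import glob
import os
import tempfile

import numpy as np

class NpyVoxelFolder:
    """Hypothetical stand-in for VoxelDataset: one .npy volume per sample."""

    def __init__(self, files):
        self.files = sorted(files)

    @classmethod
    def glob(cls, pattern):
        # Collect every file matching the pattern, e.g. ".../voxels_32/*.npy"
        return cls(glob.glob(pattern))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Load lazily so only the requested volume is read from disk
        return np.load(self.files[index]).astype(np.float32)

# Demo on three random 32x32x32 volumes written to a temporary folder
with tempfile.TemporaryDirectory() as folder:
    for i in range(3):
        np.save(os.path.join(folder, f"chair_{i}.npy"), np.random.randn(32, 32, 32))
    demo = NpyVoxelFolder.glob(os.path.join(folder, "*.npy"))
    demo_len = len(demo)
    demo_shape = demo[0].shape
    demo_dtype = demo[0].dtype
print(demo_len, demo_shape, demo_dtype)
```

An object of this shape can be passed to a PyTorch `DataLoader`, whose default collation converts NumPy arrays into batched tensors.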

---


In [3]:
# Source: train_gan.py + datasets.py

from datasets import VoxelDataset

# Load voxel dataset from files matching the glob pattern
print(f"Loading dataset from: {DATASET_PATH}")
dataset = VoxelDataset.glob(DATASET_PATH)

# Wrap the dataset in a PyTorch DataLoader for batching
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"Dataset loaded. Found {len(dataset)} samples.")


Loading dataset from: C:/Users/yghan/Documents/ICT/3D-Vision/test/dataset/ShapeNet_SDF/chairs/voxels_32/*.npy
Dataset loaded. Found 6232 samples.


# 🧠 Section 4: Model Definitions (Generator & Discriminator)

This section defines the neural networks used in the GAN pipeline. The `Generator` uses 3D transposed convolutions to transform random noise vectors into voxel-based shapes, while the `Discriminator` tries to distinguish between real and fake voxel inputs using 3D convolutional layers.

🧩 Both models inherit from `SavableModule`, allowing them to be saved and loaded from checkpoints during training and evaluation. Their architectures are tuned for voxel resolution 32×32×32, which matches the dataset format.
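
The 32×32×32 claim can be checked with the standard output-size formulas for convolutions, without running the networks. The sketch below takes the `(kernel, stride, padding)` triples from the layer definitions in this section; the helper names are illustrative, and the formulas assume dilation 1 and no `output_padding`.

```python
def conv_out(size, kernel, stride, padding):
    """Output edge length of a Conv3d layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding):
    """Output edge length of a ConvTranspose3d layer along one axis."""
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) per layer, read off the Generator and Discriminator
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
discriminator_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 1  # the latent code is reshaped to a 1x1x1 volume
generator_sizes = [size]
for k, s, p in generator_layers:
    size = conv_transpose_out(size, k, s, p)
    generator_sizes.append(size)

size = 32  # the discriminator receives a 32x32x32 volume
discriminator_sizes = [size]
for k, s, p in discriminator_layers:
    size = conv_out(size, k, s, p)
    discriminator_sizes.append(size)

print(generator_sizes)      # [1, 4, 8, 16, 32]
print(discriminator_sizes)  # [32, 16, 8, 4, 1]
```

The generator grows a 1×1×1 latent volume to 32×32×32, and the discriminator reduces the same resolution back to a single score per sample.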

---


In [13]:
# Source: Combined from __init__.py, gan.py, util.py

import torch
import torch.nn as nn
import os
import numpy as np

# --- Constants & Utility Setup ---
LATENT_CODE_SIZE = 128
MODEL_PATH = "models"
CHECKPOINT_PATH = os.path.join(MODEL_PATH, 'checkpoints')
standard_normal_distribution = torch.distributions.normal.Normal(0, 1)

# --- SavableModule Base Class ---
class SavableModule(nn.Module):
    def __init__(self, filename):
        super(SavableModule, self).__init__()
        self.filename = filename

    def get_filename(self, epoch=None, filename=None):
        if filename is None:
            filename = self.filename
        if epoch is None:
            return os.path.join(MODEL_PATH, filename)
        else:
            filename = filename.split('.')
            filename[-2] += '-epoch-{:05d}'.format(epoch)
            filename = '.'.join(filename)
            return os.path.join(CHECKPOINT_PATH, filename)

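    def load(self, epoch=None):
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
        state = torch.load(self.get_filename(epoch=epoch), map_location=self.device)
        self.load_state_dict(state, strict=False)

    def save(self, epoch=None):
        # Create the target directory (and any missing parents) before writing
        os.makedirs(CHECKPOINT_PATH if epoch is not None else MODEL_PATH, exist_ok=True)
        torch.save(self.state_dict(), self.get_filename(epoch=epoch))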

    @property
    def device(self):
        return next(self.parameters()).device

# --- Lambda Layer for Custom Operations ---
class Lambda(nn.Module):
    def __init__(self, function):
        super(Lambda, self).__init__()
        self.function = function

    def forward(self, x):
        return self.function(x)

# --- Generator Model ---
class Generator(SavableModule):
    def __init__(self):
        super(Generator, self).__init__(filename="generator.to")

        self.layers = nn.Sequential(
            nn.ConvTranspose3d(LATENT_CODE_SIZE, 256, 4, 1),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose3d(256, 128, 4, 2, 1),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose3d(128, 64, 4, 2, 1),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose3d(64, 1, 4, 2, 1),
            nn.Tanh()
        )
        self.to(device)

    def forward(self, x):
        x = x.reshape((-1, LATENT_CODE_SIZE, 1, 1, 1))
        return self.layers(x)

    def generate(self, sample_size=1):
        shape = torch.Size((sample_size, LATENT_CODE_SIZE))
        x = standard_normal_distribution.sample(shape).to(self.device)
        return self(x)

# --- Discriminator Model ---
class Discriminator(SavableModule):
    def __init__(self):
        super(Discriminator, self).__init__(filename="discriminator.to")
        self.use_sigmoid = True

        self.layers = nn.Sequential(
            nn.Conv3d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv3d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv3d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv3d(256, 1, 4, 1, 0, bias=False),
            Lambda(lambda x: torch.sigmoid(x) if self.use_sigmoid else x)
        )
        self.to(device)

    def forward(self, x):
        if len(x.shape) < 5:
            x = x.unsqueeze(1)
        return self.layers(x).view(-1, 1)


# 🔁 Section 5: Training Logic and Checkpoint Resume

This section defines the GAN training loop using Binary Cross-Entropy loss and optimizers for both Generator and Discriminator. It also includes logic to resume training from the last saved checkpoint, ensuring training progress isn’t lost between runs.
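
Resuming relies on a file-naming convention: each checkpoint carries a zero-padded epoch number, and the newest one is found by scanning the checkpoint folder. The following is a standalone sketch of that convention using empty placeholder files in a temporary folder; `checkpoint_name` and `last_saved_epoch` are illustrative names, not the notebook's own functions.

```python
import os
import re
import tempfile

def checkpoint_name(filename, epoch):
    """Insert '-epoch-NNNNN' before the extension: generator.to -> generator-epoch-00010.to."""
    stem, extension = filename.rsplit(".", 1)
    return f"{stem}-epoch-{epoch:05d}.{extension}"

def last_saved_epoch(folder, prefix="generator"):
    """Return the highest epoch among matching checkpoint files, or 0 if none exist."""
    pattern = re.compile(rf"{re.escape(prefix)}-epoch-(\d{{5}})\.to$")
    epochs = [int(m.group(1)) for m in map(pattern.match, os.listdir(folder)) if m]
    return max(epochs, default=0)

with tempfile.TemporaryDirectory() as folder:
    empty_result = last_saved_epoch(folder)  # no checkpoints yet
    for epoch in (10, 20, 30):
        # Empty placeholder files stand in for real checkpoints
        open(os.path.join(folder, checkpoint_name("generator.to", epoch)), "w").close()
    latest = last_saved_epoch(folder)

print(checkpoint_name("generator.to", 10), empty_result, latest)
```

A result of 0 means no checkpoint exists, so training starts from the first epoch.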

🧠 The `train()` function handles real vs fake discrimination, generator updates, loss logging, and periodic model saving.
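
To make the two objectives concrete, here is a small worked example of binary cross-entropy in plain Python. The discriminator outputs used are assumed illustrative probabilities, not measured results, and `bce` is a hypothetical helper for a single prediction.

```python
import math

def bce(probability, label):
    """Binary cross-entropy for one prediction in (0, 1) and a 0/1 label."""
    return -(label * math.log(probability) + (1 - label) * math.log(1 - probability))

# Illustrative discriminator outputs (assumed values, not measured results)
d_real, d_fake = 0.9, 0.1

# Discriminator: real samples should score 1, generated samples 0
d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)
# Generator: wants its samples to be scored as real, so it uses label 1
g_loss = bce(d_fake, 1.0)

print(f"{d_loss:.4f} {g_loss:.4f}")  # 0.2107 2.3026
```

A confident discriminator has a low loss while the generator's loss is high. As the generator improves, `d_fake` rises and the balance shifts.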

📋 Losses and discriminator predictions are logged per epoch for visualization and analysis.
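
`train()` writes one comma-separated line per epoch: epoch number, epoch duration in seconds, average discriminator loss, average generator loss, and the mean discriminator outputs on real and generated batches. A minimal sketch for reading such a log back follows; the field names and the sample lines are assumptions made for illustration, not recorded training results.

```python
import csv
import io

# Column names are illustrative; the log file itself has no header row
FIELDS = ["epoch", "seconds", "d_loss", "g_loss", "d_real", "d_fake"]

def parse_log(text):
    """Parse log text into a list of dicts, skipping blank or malformed lines."""
    rows = []
    for record in csv.reader(io.StringIO(text)):
        if len(record) != len(FIELDS):
            continue
        row = dict(zip(FIELDS, map(float, record)))
        row["epoch"] = int(row["epoch"])
        rows.append(row)
    return rows

# Made-up sample lines in the same format; not real training results
sample = "1,42.0,1.2034,0.9021,0.6100,0.4200\n2,41.5,1.1500,0.9500,0.6300,0.4000\n"
history = parse_log(sample)
print(history[-1]["epoch"], history[-1]["g_loss"])
```

The resulting list of dictionaries can be passed to any plotting library to chart the loss curves over epochs.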

---


In [11]:
# Source: train_gan.py

import torch.nn.functional as F
import re

# --- Loss Function ---
criterion = nn.BCEWithLogitsLoss()

# --- Helper: Find Last Saved Epoch ---
def find_last_saved_epoch():
    max_epoch = 0
    for f in os.listdir(CHECKPOINT_PATH):
        match = re.match(r'generator-epoch-(\d{5})\.to', f)
        if match:
            epoch = int(match.group(1))
            if epoch > max_epoch:
                max_epoch = epoch
    return max_epoch


# Source: embedded from previous section (fixed scope)

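# Initialize models
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# BCEWithLogitsLoss applies the sigmoid itself, so the discriminator must output
# raw logits; otherwise the sigmoid would be applied twice.
discriminator.use_sigmoid = False

# Set up optimizers in global scope so train() can use them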
generator_optimizer = torch.optim.Adam(
    generator.parameters(),
    lr=LEARNING_RATE_G,
    betas=(BETA1, 0.999)
)

discriminator_optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=LEARNING_RATE_D,
    betas=(BETA1, 0.999)
)

# --- Training Function ---
def train():
    global log_file

    for epoch_num in range(start_epoch_global + 1, TOTAL_TARGET_EPOCHS + 1):
        d_losses, g_losses = [], []
        d_preds_real, d_preds_fake = [], []
        epoch_start_time = time.time()

        for i, real_voxels in enumerate(data_loader):
            real_voxels = real_voxels.to(device)
            current_batch_size = real_voxels.size(0)

            real_labels = torch.full((current_batch_size,), 1.0, device=device)
            fake_labels = torch.full((current_batch_size,), 0.0, device=device)

            # --- Train Discriminator ---
            discriminator_optimizer.zero_grad()
            output_real = discriminator(real_voxels).view(-1)
            d_loss_real = criterion(output_real, real_labels)
            d_loss_real.backward()
            d_preds_real.append(torch.sigmoid(output_real).mean().item())

            noise = torch.randn(current_batch_size, LATENT_CODE_SIZE, device=device)
            fake_voxels = generator(noise)
            output_fake = discriminator(fake_voxels.detach()).view(-1)
            d_loss_fake = criterion(output_fake, fake_labels)
            d_loss_fake.backward()
            d_preds_fake.append(torch.sigmoid(output_fake).mean().item())

            d_loss = d_loss_real + d_loss_fake
            discriminator_optimizer.step()

            # --- Train Generator ---
            generator_optimizer.zero_grad()
            output_gen = discriminator(fake_voxels).view(-1)
            g_loss = criterion(output_gen, real_labels)
            g_loss.backward()
            generator_optimizer.step()

            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())

        # --- Epoch Summary ---
        avg_d_loss = sum(d_losses) / len(d_losses)
        avg_g_loss = sum(g_losses) / len(g_losses)
        avg_d_real = sum(d_preds_real) / len(d_preds_real)
        avg_d_fake = sum(d_preds_fake) / len(d_preds_fake)

        print(f"Epoch {epoch_num}/{TOTAL_TARGET_EPOCHS} - D_Loss: {avg_d_loss:.4f}, G_Loss: {avg_g_loss:.4f}, D(Real): {avg_d_real:.4f}, D(Fake): {avg_d_fake:.4f}")
        log_file.write(f"{epoch_num},{time.time() - epoch_start_time:.1f},{avg_d_loss:.4f},{avg_g_loss:.4f},{avg_d_real:.4f},{avg_d_fake:.4f}\n")
        log_file.flush()

        if epoch_num % SAVE_EVERY_N_EPOCHS == 0 or epoch_num == TOTAL_TARGET_EPOCHS:
            generator.save(epoch=epoch_num)
            discriminator.save(epoch=epoch_num)
            print(f"Models saved at epoch {epoch_num}.")

    print("Training complete.")


# 🚀 Section 6: Entry Point and Execution

This block resumes training from the last saved checkpoint if one exists, opens the log file for recording metrics, and starts the training loop. The models and optimizers themselves are created in Section 5. Wrapping the block in a main guard (`if __name__ == '__main__':`) lets the same code run both in a notebook and as a script.

📋 You can adjust the resume epoch or the logging behavior here without modifying the core `train()` loop. The total number of epochs is set by `TOTAL_TARGET_EPOCHS` in Section 2.

---


In [None]:
# Source: train_gan.py

if __name__ == '__main__':
    # --- Resume from last saved epoch (0 means no checkpoint was found) ---
    start_epoch_global = find_last_saved_epoch()
    if start_epoch_global > 0:
        print(f"Resuming from epoch {start_epoch_global}.")
        generator.load(epoch=start_epoch_global)
        discriminator.load(epoch=start_epoch_global)
        print("Models loaded successfully from checkpoint.")
    else:
        print("No checkpoint found. Starting training from scratch.")

    # --- Open log file for metric recording; the with-block closes it even if training fails ---
    log_filename = os.path.join(MODEL_PATH, "log.txt")
    with open(log_filename, "a") as log_file:
        # --- Begin training ---
        train()


# 🎨 Section 7: Shape Generation and Visualization

This section loads the trained Generator model from a selected epoch and uses it to generate new voxel shapes from random latent vectors. The output shapes can be visualized as ASCII text slices using `create_text_slice()` or rendered in 3D if the rendering modules are available.

📸 This stage showcases what the GAN has learned and lets us preview creative voxel outputs. You can customize sample size, checkpoint epoch, and visualization method.
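
The slice preview can be tried without a trained model. The sketch below renders one slice of a synthetic volume, the signed distance to a sphere, using NumPy only. `slice_to_text` and its `mark_above` flag are hypothetical. Whether occupied voxels lie above or below the threshold depends on the dataset's sign convention and normalization, so the flag is left configurable here.

```python
import numpy as np

def slice_to_text(volume, slice_index, threshold=0.0, mark_above=True):
    """Render one slice along the third axis as rows of '*' and ' '."""
    plane = volume[:, :, slice_index]
    mask = plane > threshold if mark_above else plane < threshold
    # Text rows follow the second axis and columns the first
    return "\n".join(
        "".join("*" if mask[col, row] else " " for col in range(plane.shape[0]))
        for row in range(plane.shape[1])
    )

# Synthetic test volume: signed distance to a sphere of radius 3 on an 8x8x8 grid
axis = np.arange(8) - 3.5
x, y, z = np.meshgrid(axis, axis, axis, indexing="ij")
sphere_sdf = np.sqrt(x**2 + y**2 + z**2) - 3.0

# Mark negative values, i.e. points inside the sphere under this sign convention
text = slice_to_text(sphere_sdf, slice_index=4, mark_above=False)
print(text)
```

The printed slice shows a filled disc, the cross-section of the sphere. Switching `mark_above` to `True` marks the complementary region instead.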

---


In [None]:
# Source: visualization.py + generator preview

# --- Optional Visualization Helpers ---
def create_text_slice(voxel_data, threshold=0.0):
    """
    Creates a slice-by-slice text view of voxel tensor for previewing shape.
    Voxels with value above the threshold are shown as '*', below as ' '.
    """
    voxel_data = voxel_data.detach().cpu().numpy()
    shape = voxel_data.shape
    print(f"Voxel shape: {shape}")
    for slice_index in range(shape[2]):
        print(f"\nSlice {slice_index:02d}")
        for row in range(shape[1]):
            line = ""
            for col in range(shape[0]):
                line += "*" if voxel_data[col][row][slice_index] > threshold else " "
            print(line)

# --- Generate and Visualize Samples ---
# Load a trained generator from a specific epoch
preview_epoch = TOTAL_TARGET_EPOCHS  # or any previously saved epoch, e.g. 80
generator.load(epoch=preview_epoch)
print(f"Generator loaded from epoch {preview_epoch}")

# Sample new voxel shape
with torch.no_grad():
    generated_voxel = generator.generate(sample_size=1)[0, 0]  # [batch, channel, ...]

# Show ASCII slice preview
print("\n--- Voxel Slice Preview ---")
create_text_slice(generated_voxel, threshold=0.05)

# Optional: call MeshRenderer if available (skipped here)
# You can integrate rendering code here if graphics libraries are configured
