# 🚀 WGAN-GP 3D Voxel Shape Generation Notebook

**Project Title:** ShapeGAN Reboot — Implementing WGAN-GP for 3D Chair Synthesis  
**Student Name:** Yazdan Ghanavati
**Course:** 3D Vision / ICT  
**Date:** July 2025  
**Institution:** University of Padova

---

### 📘 Summary

This notebook implements a WGAN-GP (Wasserstein GAN with Gradient Penalty) that generates 3D voxel representations of chairs. It trains on Signed Distance Fields (SDFs) derived from the ShapeNet dataset and embeds the training loop, sampling routines, and visualization tools directly in the notebook for portability and transparency.

🧱 All code, data loading, optimization, and rendering are contained in the notebook. No external imports from local modules.

---

### 💡 Attribution

Most architectural ideas and training flow were inspired by the open-source project:  
🔗 [marian42/shapegan](https://github.com/marian42/shapegan)

This implementation adapts, restructures, and expands upon that foundation to support WGAN-GP with gradient penalties, checkpointing, and interactive shape previews.

---


# 🔍 Section 0: Library Imports

This cell imports all necessary libraries used throughout the notebook, including:
- PyTorch for deep learning
- NumPy for tensor and array manipulation
- `os`, `re`, and `sys` for file and system operations
- Visualization libraries such as `pygame` and `OpenGL` for 3D rendering
- Skimage and Trimesh for mesh creation and processing

📦 These imports cover model definitions, training loop, voxel handling, gradient penalty logic, and real-time previews.


In [1]:
# Source: train_wgan.py (Global Imports)

# --- Deep Learning & Data Handling ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.autograd as autograd  # For gradient penalty

# --- Math, System & File Tools ---
import numpy as np
import os
import sys
import re
import time

# --- Visualization: 3D Viewer & Mesh ---
import pygame  # type: ignore
from pygame.locals import *
import pygame.image

from OpenGL.GL import *
from OpenGL.GLU import *
from OpenGL.arrays import vbo

# --- Scientific & Rendering Tools ---
import cv2
import skimage.measure
import trimesh
from threading import Thread, Lock

# 🔧 Custom Utilities (embedded separately later)
# - create_text_slice()
# - crop_image()
# - get_camera_transform()
# - MeshRenderer
# - SavableModule


pygame 2.6.1 (SDL 2.28.4, Python 3.10.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


# 🧰 Section 1: Constants & Utility Definitions

This section defines the foundational tools and parameters used throughout the notebook:

- ⚙️ `device`: Automatically selects CUDA or CPU for computation
- 📊 `standard_normal_distribution`: Enables Gaussian noise sampling for latent space
- 📁 Directory setup ensures output folders (`models`, `plots`, `data`) are created before use
- 🎨 `create_text_slice()`: Converts a 2D slice of a voxel grid into an ASCII art preview for inline visualization
- 🧪 Additional helpers support image cropping, voxel coordinate generation, and 3D point sampling

These utilities power data loading, model generation, shape rendering, and the training pipeline.
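
As a rough illustration of how `create_text_slice()` chooses a glyph, the sketch below applies the same mapping to single SDF values in plain Python. The helper name `sdf_to_char` is hypothetical and exists only for this example; the character ramp is the one defined in the cell that follows.

```python
# Character ramp used by the ASCII renderer: blank for outside, dense glyphs for inside
CHARACTERS = '      `.-:/+osyhdmm###############'

def sdf_to_char(value):
    """Map one SDF value to a glyph (hypothetical helper, for illustration only)."""
    # Negative (inside) values map toward 1, positive (outside) values toward 0
    scaled = min(max(value * -0.5 + 0.5, 0.0), 1.0)
    # Truncate to an integer index into the ramp, like the tensor cast to int
    return CHARACTERS[int(scaled * (len(CHARACTERS) - 1))]

row = [1.0, 0.5, 0.0, -0.5, -1.0]
print(''.join(sdf_to_char(v) for v in row))
```

Values outside [-1, 1] are clamped, so far-outside voxels render as blanks and deep-inside voxels as `#`.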


In [3]:
# Source: util.py (Device setup, Distribution, Text Slice, Misc Tools)

import torch
import numpy as np
import os

# --- Hardware Selection ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# --- Standard Gaussian Noise Distribution for Latent Sampling ---
standard_normal_distribution = torch.distributions.normal.Normal(0, 1)

# --- Directory Setup ---
def ensure_directory(directory):
    if not os.path.exists(directory):
        os.makedirs(directory)

ensure_directory('plots')
ensure_directory('models')
ensure_directory('data')

# --- ASCII Slice Renderer ---
CHARACTERS = '      `.-:/+osyhdmm###############'

def create_text_slice(voxels):
    if isinstance(voxels, np.ndarray):
        voxels = torch.from_numpy(voxels).float().to(device)
    voxel_resolution = voxels.shape[-1]
    # The slice is taken a quarter of the way along the first axis, not at the midpoint
    slice_index = voxels.shape[0] // 4
    data = voxels[slice_index, :, :]

    data = torch.clamp(data * -0.5 + 0.5, 0, 1) * (len(CHARACTERS) - 1)
    data = data.type(torch.int).cpu()
    lines = ['|' + ''.join([CHARACTERS[i] for i in line]) + '|' for line in data]

    frame = '+' + '—' * voxel_resolution + '+\n'
    return frame + '\n'.join(lines) + '\n' + frame

# --- Optional Utilities ---
def get_points_in_unit_sphere(n, device):
    x = torch.rand(int(n * 2.5), 3, device=device) * 2 - 1
    mask = (torch.norm(x, dim=1) < 1).nonzero().squeeze()
    mask = mask[:n]
    x = x[mask, :]
    if x.shape[0] < n:
        print("Warning: Did not find enough points.")
    return x

def crop_image(image, background=255):
    mask = image[:, :] != background
    coords = np.array(np.nonzero(mask))
    if coords.size != 0:
        top_left = np.min(coords, axis=1)
        bottom_right = np.max(coords, axis=1)
    else:
        top_left = np.array((0, 0))
        bottom_right = np.array(image.shape)
        print("Warning: Image contains only background pixels.")
    half_size = int(max(bottom_right[0] - top_left[0], bottom_right[1] - top_left[1]) / 2)
    center = ((top_left + bottom_right) / 2).astype(int)
    center = (min(max(half_size, center[0]), image.shape[0] - half_size),
              min(max(half_size, center[1]), image.shape[1] - half_size))
    if half_size > 100:
        image = image[center[0] - half_size:center[0] + half_size,
                      center[1] - half_size:center[1] + half_size]
    return image

def get_voxel_coordinates(resolution=32, size=1, center=0, return_torch_tensor=False):
    if type(center) == int:
        center = (center, center, center)
    points = np.meshgrid(
        np.linspace(center[0] - size, center[0] + size, resolution),
        np.linspace(center[1] - size, center[1] + size, resolution),
        np.linspace(center[2] - size, center[2] + size, resolution)
    )
    points = np.stack(points)
    points = np.swapaxes(points, 1, 2)
    points = points.reshape(3, -1).transpose()
    if return_torch_tensor:
        return torch.tensor(points, dtype=torch.float32, device=device)
    else:
        return points.astype(np.float32)


cpu


# 🧠 Section 2: Model Definitions (Generator, Critic, SavableModule)

This section defines the two core neural networks used in the WGAN-GP framework:

---

### 🔹 Generator

- Uses transposed 3D convolution layers to convert a latent noise vector into a 3D voxel grid
- Each hidden block applies BatchNorm followed by a LeakyReLU(0.2) activation for smooth gradient flow
- Final layer uses `Tanh` to output values in the range [-1, 1], suitable for Signed Distance Fields (SDF)
- Inherits from `SavableModule` to support saving/loading checkpoints

---

### 🔹 Discriminator (Critic)

- Uses 3D convolutions to assign scores to voxel shapes
- No sigmoid applied at the output — raw score is used to estimate Wasserstein distance
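- Includes BatchNorm layers in the two middle blocks. Note that the WGAN-GP paper advises against batch normalization in the critic, because the gradient penalty is defined per sample, and suggests layer normalization instead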
- `use_sigmoid` flag determines whether to use traditional GAN output or WGAN-style critic scoring

---

### 🧱 SavableModule & Lambda

- `SavableModule` is a lightweight base class that adds checkpointing logic to any PyTorch module
- `Lambda` wraps arbitrary functions (like optional activations) into a layer-compatible object

These models form the heart of shape synthesis and quality evaluation in this notebook.
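
To check that the layer settings line up, the following sketch works through the spatial sizes with the standard output-size formulas for (transposed) convolutions, assuming dilation 1 and no output padding. The helper names are hypothetical; the `(kernel, stride, padding)` triples are read off the `Generator` and `Discriminator` definitions in the next cell.

```python
def conv_transpose_out(size, kernel, stride=1, padding=0):
    # Output size of a transposed convolution (dilation 1, no output padding)
    return (size - 1) * stride - 2 * padding + kernel

def conv_out(size, kernel, stride=1, padding=0):
    # Output size of a regular convolution (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) per layer
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
critic_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

# The latent code enters the generator as a 1x1x1 volume
size, generator_sizes = 1, []
for kernel, stride, padding in generator_layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    generator_sizes.append(size)

# The critic consumes the generator's output resolution
critic_sizes = []
for kernel, stride, padding in critic_layers:
    size = conv_out(size, kernel, stride, padding)
    critic_sizes.append(size)

print(generator_sizes, critic_sizes)
```

The generator grows a 1×1×1 latent volume to 32×32×32, and the critic mirrors that path back down to a single score per sample.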


In [4]:
# Source: init.py and gan.py (Generator, Critic, Lambda, SavableModule)

import torch
import torch.nn as nn
import os

LATENT_CODE_SIZE = 128
MODEL_PATH = "models"
CHECKPOINT_PATH = os.path.join(MODEL_PATH, 'checkpoints')

# --- Lambda Layer Wrapper ---
class Lambda(nn.Module):
    def __init__(self, function):
        super(Lambda, self).__init__()
        self.function = function
    def forward(self, x):
        return self.function(x)

# --- Savable Base Module ---
class SavableModule(nn.Module):
    def __init__(self, filename):
        super(SavableModule, self).__init__()
        self.filename = filename

    def get_filename(self, epoch=None, filename=None):
        if filename is None:
            filename = self.filename
        if epoch is None:
            return os.path.join(MODEL_PATH, filename)
        else:
            filename = filename.split('.')
            filename[-2] += '-epoch-{:05d}'.format(epoch)
            filename = '.'.join(filename)
            return os.path.join(CHECKPOINT_PATH, filename)

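    def load(self, epoch=None, filename=None):
        load_filename = filename if filename is not None else self.filename
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
        state_dict = torch.load(self.get_filename(epoch=epoch, filename=load_filename), map_location=device)
        self.load_state_dict(state_dict, strict=False)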

    def save(self, epoch=None):
        if epoch is not None and not os.path.exists(CHECKPOINT_PATH):
            os.makedirs(CHECKPOINT_PATH)
        torch.save(self.state_dict(), self.get_filename(epoch=epoch))

    @property
    def device(self):
        return next(self.parameters()).device

# --- Generator (WGAN-GP Style) ---
class Generator(SavableModule):
    def __init__(self):
        super(Generator, self).__init__(filename="generator_wgan.to")
        self.layers = nn.Sequential(
            nn.ConvTranspose3d(LATENT_CODE_SIZE, 256, 4, 1),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose3d(256, 128, 4, 2, 1),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose3d(128, 64, 4, 2, 1),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose3d(64, 1, 4, 2, 1),
            nn.Tanh()
        )

        self.to(device)

    def forward(self, x):
        x = x.reshape((-1, LATENT_CODE_SIZE, 1, 1, 1))
        return self.layers(x)

    def generate(self, sample_size=1):
        shape = torch.Size((sample_size, LATENT_CODE_SIZE))
        x = standard_normal_distribution.sample(shape).to(self.device)
        return self(x)

# --- Discriminator (Critic with Raw Output) ---
class Discriminator(SavableModule):
    def __init__(self):
        super(Discriminator, self).__init__(filename="discriminator_wgan.to")
        self.use_sigmoid = False  # For WGAN-GP: raw score, no sigmoid

        self.layers = nn.Sequential(
            nn.Conv3d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv3d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv3d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm3d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv3d(256, 1, 4, 1, 0, bias=False),
            Lambda(lambda x: torch.sigmoid(x) if self.use_sigmoid else x)
        )

        self.to(device)

    def forward(self, x):
        if x.ndim < 5:
            x = x.unsqueeze(1)
        return self.layers(x).view(-1, 1)

    def clip_weights(self, value):
        for parameter in self.parameters():
            parameter.data.clamp_(-value, value)


# 🌊 Section 3: WGAN-GP Gradient Penalty Function

The key change in WGAN-GP relative to the original WGAN is the gradient penalty used to enforce the Lipschitz constraint on the Critic (Discriminator). It replaces weight clipping, which tends to give smoother optimization and better stability.

---

### 💡 Purpose of `calc_gradient_penalty()`

- Calculates the **gradient norm** of the Critic with respect to interpolated samples
- Penalizes deviations from a norm of 1, encouraging more well-behaved gradients
- Uses PyTorch’s `autograd.grad()` to compute gradients for inputs with `requires_grad=True`

---

### 📌 How It Works

1. **Interpolation**: Mixes real and fake samples
2. **Forward Pass**: Computes the Critic score for these interpolates
3. **Gradient Calculation**: Derives the Critic's gradient w.r.t inputs
4. **Penalty**: Applies penalty proportional to deviation from unit norm

This function is invoked during each Discriminator update step to stabilize training.
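
For reference, the four steps above correspond to the usual WGAN-GP critic objective. The notation is chosen for this note and does not appear in the code: $D$ is the Critic, $x$ a real sample, $\tilde{x}$ a generated sample, $\hat{x}$ the interpolate, and $\lambda$ the weight passed as `lambda_gp`.

$$
\hat{x} = \alpha x + (1 - \alpha)\,\tilde{x}, \qquad \alpha \sim U[0, 1]
$$

$$
L_D = \mathbb{E}\big[D(\tilde{x})\big] - \mathbb{E}\big[D(x)\big] + \lambda\, \mathbb{E}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]
$$

The last term is what `calc_gradient_penalty()` returns; the first two are the mean critic scores on generated and real batches.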


In [5]:
# Source: train_wgan.py (WGAN-GP Gradient Penalty Function)

def calc_gradient_penalty(discriminator, real_data, fake_data, lambda_gp=10.0):
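    # Real batches from VoxelDataset have no channel axis; add one so both inputs
    # are (batch, 1, D, H, W) and alpha broadcasts per sample instead of per pair
    if real_data.ndim == 4:
        real_data = real_data.unsqueeze(1)
    if fake_data.ndim == 4:
        fake_data = fake_data.unsqueeze(1)

    batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, 1, device=real_data.device)

    # Interpolate between real and fake samples
    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    interpolates.requires_grad_(True)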

    # Evaluate discriminator on interpolated data
    disc_interpolates = discriminator(interpolates)

    # Compute gradients of outputs w.r.t inputs
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    # Flatten and compute gradient norm
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)

    # Compute penalty: deviation from unit norm
    gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    return gradient_penalty


# 🔁 Section 4: WGAN-GP Training Loop

This section defines the complete training procedure for both the Generator and Critic using the WGAN-GP framework. The process follows a carefully balanced update strategy to ensure stable learning:

---

### 🧪 Critic Optimization
- The Critic is updated multiple times per Generator update (`critic_iterations`)
- Calculates real and fake scores, and applies the **gradient penalty** to enforce Lipschitz continuity

---

### 🎨 Generator Optimization
- The Generator is updated once per cycle to fool the Critic
- Uses the negative of the Critic’s score on generated samples as its loss

---

### 📋 Features in the Training Loop
- Models saved every `save_every` epochs
- Voxel outputs visualized in ASCII using `create_text_slice()`
- Metrics logged: Critic loss, Generator loss, Gradient penalty
- Samples generated at `preview_every` intervals for inline display

This loop is the heartbeat of the notebook — training the networks to generate realistic, high-quality 3D shapes.
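
The two losses described above reduce to simple arithmetic on critic scores. The sketch below uses plain Python lists and made-up scores to show the signs involved; the function names are hypothetical and the numbers are not training results.

```python
def mean(values):
    return sum(values) / len(values)

def critic_loss(real_scores, fake_scores, gradient_penalty):
    # The critic is rewarded for scoring real samples above generated ones
    return mean(fake_scores) - mean(real_scores) + gradient_penalty

def generator_loss(fake_scores):
    # The generator is rewarded when the critic scores its samples highly
    return -mean(fake_scores)

real_scores = [1.0, 3.0]   # made-up critic outputs for a real batch
fake_scores = [0.0, 1.0]   # made-up critic outputs for a generated batch
print(critic_loss(real_scores, fake_scores, gradient_penalty=0.5))
print(generator_loss(fake_scores))
```

A more negative critic loss means the critic separates real from generated samples more strongly, which is why the generator loss moves in the opposite direction.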


In [6]:
# Source: train_wgan.py (Main Training Loop for WGAN-GP)

def train_wgan(generator, discriminator, dataloader, epochs=100, critic_iterations=5, save_every=10, preview_every=10):
    optimizer_G = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))

    for epoch in range(epochs):
        for i, real_data in enumerate(dataloader):
            real_data = real_data.to(device)

            # ---------------------
            #  Train Critic
            # ---------------------
            for _ in range(critic_iterations):
                discriminator.zero_grad()

                # Generate fake data
                fake_data = generator.generate(sample_size=real_data.size(0))

                # Critic loss components
                score_real = discriminator(real_data).mean()
                score_fake = discriminator(fake_data.detach()).mean()
                gradient_penalty = calc_gradient_penalty(discriminator, real_data, fake_data.detach())

                loss_D = score_fake - score_real + gradient_penalty
                loss_D.backward()
                optimizer_D.step()

            # ---------------------
            #  Train Generator
            # ---------------------
            generator.zero_grad()

            fake_data = generator.generate(sample_size=real_data.size(0))
            score = discriminator(fake_data)
            loss_G = -score.mean()
            loss_G.backward()
            optimizer_G.step()

        # --- Logging ---
        print(f"[Epoch {epoch+1}/{epochs}] Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}, GP: {gradient_penalty.item():.4f}")

        # --- Preview Output ---
        if (epoch + 1) % preview_every == 0:
            preview = generator.generate().detach().cpu().squeeze()
            print(create_text_slice(preview))

        # --- Save Models ---
        if (epoch + 1) % save_every == 0:
            generator.save(epoch + 1)
            discriminator.save(epoch + 1)


# 📦 Section 5: Data Loading with `VoxelDataset`

This section defines the `VoxelDataset` class, which is responsible for loading 3D voxel data stored in `.npy` files. These files typically represent Signed Distance Fields (SDFs) extracted from the ShapeNet dataset.

---

### 🧠 What `VoxelDataset` Does
- Loads voxel grids from a list of file paths
- Clamps and normalizes SDF values to a usable range
- Provides PyTorch compatibility for use in `DataLoader`
- Includes convenient static methods to load datasets via glob patterns or split lists

---

### 🔍 Features
- `clamp`: Controls truncation limits for extreme SDF values
- `rescale_sdf`: Normalizes the values to [-1, 1] after clamping
- `glob(pattern)`: Finds and sorts voxel files based on a glob path
- `from_split(pattern, split_file)`: Loads files using an external split list
- `show()`: Allows visual inspection of samples using real-time rendering

This class gives structure and control over dataset usage during training and previewing.
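
A minimal sketch of what `clamp` and `rescale_sdf` do to individual distance values, using plain Python rather than tensors. The function name `clamp_and_rescale` is hypothetical; the default of 0.1 matches the `VoxelDataset` constructor below.

```python
def clamp_and_rescale(value, clamp=0.1, rescale_sdf=True):
    # Truncate the distance to [-clamp, clamp]
    clamped = min(max(value, -clamp), clamp)
    # Optionally stretch the truncated range to [-1, 1]
    return clamped / clamp if rescale_sdf else clamped

for sdf in (0.25, 0.05, -0.03, -0.4):
    print(sdf, '->', clamp_and_rescale(sdf))
```

After rescaling, the data covers the same [-1, 1] range as the Generator's `Tanh` output, so real and generated grids are directly comparable for the Critic.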


In [7]:
# Source: voxel_dataset.py (VoxelDataset Class for Data Loading)

import torch
from torch.utils.data import Dataset
import os
import numpy as np

class VoxelDataset(Dataset):
    def __init__(self, files, clamp=0.1, rescale_sdf=True):
        self.files = files
        self.clamp = clamp
        self.rescale_sdf = rescale_sdf

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        array = np.load(self.files[index])
        result = torch.from_numpy(array)
        if self.clamp is not None:
            result.clamp_(-self.clamp, self.clamp)
            if self.rescale_sdf:
                result /= self.clamp
        return result

    @staticmethod
    def glob(pattern):
        import glob
        files = glob.glob(pattern, recursive=True)
        if len(files) == 0:
            raise Exception(f'No files found for glob pattern {pattern}.')
        return VoxelDataset(sorted(files))

    @staticmethod
    def from_split(pattern, split_file_name):
        with open(split_file_name, 'r') as split_file:
            ids = split_file.readlines()
        files = [pattern.format(id.strip()) for id in ids]
        files = [file for file in files if os.path.exists(file)]
        return VoxelDataset(files)

    def show(self):
        # MeshRenderer is defined in Section 6 of this notebook, so no local-module import is needed
        import time
        from tqdm import tqdm

        viewer = MeshRenderer()
        for item in tqdm(self):
            viewer.set_voxels(item.numpy())
            time.sleep(0.5)


# 🎥 Section 6: Shape Visualization with `MeshRenderer`

This section enables real-time 3D rendering of voxel shapes using OpenGL and PyGame. It builds a flexible and efficient viewer that can display both generated and training set shapes with shading, lighting, floor reference, and rotation controls.

---

### 🔧 Features of `MeshRenderer`

- Uses **marching cubes** (`skimage.measure.marching_cubes` at level 0) to convert SDF voxel grids into triangle meshes
- Renders shapes with custom shaders for light and shadow control
- Visualizes shape geometry from latent codes during training
- Includes real-time preview options with mouse-controlled rotation
- Floor plane support for grounding visuals and spatial reference

---

### 🪞 Preview Utilities
- `set_voxels(voxels)`: Converts a voxel grid into a mesh and updates OpenGL buffers
- `set_mesh(mesh)`: Renders a precomputed mesh (e.g., via `trimesh`)
- `get_image()`: Captures the rendering window into a NumPy image
- `save_screenshot()`: Stores the current frame as a PNG

This viewer is used during model training and debugging to inspect voxel outputs in a spatially meaningful way.


In [8]:
# Source: rendering.py (MeshRenderer Class for 3D Voxel Visualization)

import numpy as np
import pygame
import os
from OpenGL.GL import *
from OpenGL.arrays import vbo
import trimesh

SHADOW_TEXTURE_SIZE = 1024
DEFAULT_ROTATION = [25, 15]

class MeshRenderer:
    def __init__(self, size=512):
        self.size = size
        self.model_color = [1.0, 1.0, 1.0]
        self.background_color = [0.2, 0.2, 0.2, 1]
        self.model_size = 1.0
        self.rotation = list(DEFAULT_ROTATION)
        self.render_lock = Lock()
        self.mouse = None
        self.request_render = True
        self.running = True
        self.vertex_buffer = None
        self.normal_buffer = None

    def set_voxels(self, voxels):
        marching_cubes = skimage.measure.marching_cubes(
            voxels, level=0.0, spacing=(1.0 / voxels.shape[0],) * 3)
        mesh = trimesh.Trimesh(vertices=marching_cubes[0], faces=marching_cubes[1])
        self.set_mesh(mesh, smooth=False, center_and_scale=True)

    def set_mesh(self, mesh, smooth=False, center_and_scale=False):
        if mesh is None:
            return

        vertices = np.array(mesh.triangles, dtype=np.float32).reshape(-1, 3)
        if center_and_scale:
            vertices -= mesh.bounding_box.centroid[np.newaxis, :]
            vertices /= np.max(np.linalg.norm(vertices, axis=1))
        self.ground_level = np.min(vertices[:, 1]).item()
        vertices = vertices.reshape((-1))

        if smooth:
            normals = mesh.vertex_normals[mesh.faces.reshape(-1)].astype(np.float32) * -1
        else:
            normals = np.repeat(mesh.face_normals, 3, axis=0).astype(np.float32)

        self._update_buffers(vertices, normals)
        self.model_size = 1.08

    def _update_buffers(self, vertices, normals):
        self.vertex_buffer = vbo.VBO(vertices)
        self.normal_buffer = vbo.VBO(normals)
        self.vertex_buffer_size = len(vertices) // 3

    def _poll_mouse(self):
        left_mouse, _, right_mouse = pygame.mouse.get_pressed()
        pressed = left_mouse == 1 or right_mouse == 1
        current_mouse = pygame.mouse.get_pos()
        if self.mouse is not None and pressed:
            movement = (current_mouse[0] - self.mouse[0], current_mouse[1] - self.mouse[1])
            self.rotation = [self.rotation[0] + movement[0], max(-90, min(90, self.rotation[1] + movement[1]))]
        self.mouse = current_mouse
        return pressed

    def _draw_mesh(self, use_normals=True):
        if self.vertex_buffer is None or self.normal_buffer is None:
            return
        glEnableClientState(GL_VERTEX_ARRAY)
        self.vertex_buffer.bind()
        glVertexPointer(3, GL_FLOAT, 0, self.vertex_buffer)
        if use_normals:
            glEnableClientState(GL_NORMAL_ARRAY)
            self.normal_buffer.bind()
            glNormalPointer(GL_FLOAT, 0, self.normal_buffer)
        glDrawArrays(GL_TRIANGLES, 0, self.vertex_buffer_size)

    def prepare_floor(self):
        size = 6
        mesh = trimesh.Trimesh([
            [-size, 0, -size], [-size, 0, +size], [+size, 0, +size],
            [-size, 0, -size], [+size, 0, +size], [+size, 0, -size]
        ], faces=[[0, 1, 2], [3, 4, 5]])
        vertices = np.array(mesh.triangles, dtype=np.float32).reshape(-1, 3)
        normals = np.repeat(mesh.face_normals, 3, axis=0).astype(np.float32)
        self.floor_vertices = vbo.VBO(vertices.reshape(-1))
        self.floor_normals = vbo.VBO(normals)

    def _draw_floor(self):
        self.shader.set_y_offset(self.ground_level)
        glEnableClientState(GL_VERTEX_ARRAY)
        self.floor_vertices.bind()
        glVertexPointer(3, GL_FLOAT, 0, self.floor_vertices)
        glEnableClientState(GL_NORMAL_ARRAY)
        self.floor_normals.bind()
        glNormalPointer(GL_FLOAT, 0, self.floor_normals)
        glDrawArrays(GL_TRIANGLES, 0, 6)

    def _initialize_opengl(self):
        pygame.init()
        pygame.display.set_caption('Model Viewer')
        pygame.display.gl_set_attribute(pygame.GL_MULTISAMPLEBUFFERS, 1)
        pygame.display.gl_set_attribute(pygame.GL_MULTISAMPLESAMPLES, 4)
        self.window = pygame.display.set_mode((self.size, self.size), pygame.OPENGLBLIT)

        self.shader = Shader()
        script_dir = os.path.dirname(__file__)
        self.shader.initShader(
            open(os.path.join(script_dir, 'vertex.glsl')).read(),
            open(os.path.join(script_dir, 'fragment.glsl')).read()
        )

        self.shadow_framebuffer = glGenFramebuffers(1)
        self.shadow_texture = create_shadow_texture()

        self.depth_shader = Shader()
        self.depth_shader.initShader(
            open(os.path.join(script_dir, 'depth_vertex.glsl')).read(),
            open(os.path.join(script_dir, 'depth_fragment.glsl')).read()
        )

        self.prepare_floor()

    def _render_shadow_texture(self, light_vp_matrix):
        glBindFramebuffer(GL_FRAMEBUFFER, self.shadow_framebuffer)
        glFramebufferTexture2D(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, GL_TEXTURE_2D, self.shadow_texture, 0)
        glDrawBuffer(GL_NONE)
        glReadBuffer(GL_NONE)
        glClear(GL_DEPTH_BUFFER_BIT)
        glViewport(0, 0, SHADOW_TEXTURE_SIZE, SHADOW_TEXTURE_SIZE)
        glEnable(GL_DEPTH_TEST)
        glDepthMask(GL_TRUE)
        glDepthFunc(GL_LESS)
        glDepthRange(0.0, 1.0)
        glDisable(GL_CULL_FACE)
        glDisable(GL_BLEND)

        self.depth_shader.use()
        self.depth_shader.set_vp_matrix(light_vp_matrix)
        self._draw_mesh(use_normals=False)

        glBindFramebuffer(GL_FRAMEBUFFER, 0)

    def _render(self):
        self.request_render = False
        self.render_lock.acquire()

        light_vp_matrix = get_camera_transform(6, self.rotation[0], 50, project=True)
        self._render_shadow_texture(light_vp_matrix)

        self.shader.use()
        self.shader.set_floor(False)
        self.shader.set_color(self.model_color)
        self.shader.set_y_offset(0)
        camera_vp_matrix = get_camera_transform(self.model_size * 2, self.rotation[0], self.rotation[1], project=True)
        self.shader.set_vp_matrix(camera_vp_matrix)
        self.shader.set_light_vp_matrix(light_vp_matrix)

        glClearColor(*self.background_color)
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
        glClearDepth(1.0)
        glDepthMask(GL_TRUE)
        glDepthFunc(GL_LESS)
        glDepthRange(0.0, 1.0)
        glEnable(GL_CULL_FACE)
        glEnable(GL_DEPTH_TEST)
        glViewport(0, 0, self.size, self.size)

        glActiveTexture(GL_TEXTURE1)
        glBindTexture(GL_TEXTURE_2D, self.shadow_texture)
        self.shader.set_shadow_texture(1)

        self._draw_mesh()
        self.shader.set_floor(True)
        self._draw_floor()
        self.render_lock.release()

    def _run(self):
        self._initialize_opengl()
        self._render()
        while self.running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    return
                if event.type == pygame.KEYDOWN:
                    if pygame.key.get_pressed()[pygame.K_F12]:
                        self.save_screenshot()
                    if pygame.key.get_pressed()[pygame.K_r]:
                        self.rotation = list(DEFAULT_ROTATION)
                        self.request_render = True
            if self._poll_mouse() or self.request_render:
                self._render()
                pygame.display.flip()
            pygame.time.wait(10)
        self.delete_buffers()

    def get_image(self, crop=False, output_size=None, greyscale=False, flip_red_blue=False):
        if self.request_render:
            self._render()
        if output_size is None:
            output_size = self.size
        string_image = pygame.image.tostring(self.window, 'RGB')
        image = pygame.image.fromstring(string_image, (self.size, self.size), 'RGB')
        if greyscale:
            array = np.transpose(pygame.surfarray.array3d(image)[:, :, 0])
        else:
            array = np.transpose(pygame.surfarray.array3d(image)[:, :, (2, 1, 0) if flip_red_blue else slice(None)], (1, 0, 2))
        if crop:
            array = crop_image(array)
        if output_size != self.size:
            array = cv2.resize(array, dsize=(output_size, output_size), interpolation=cv2.INTER_CUBIC)
        return array

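    def save_screenshot(self):
        ensure_directory('screenshots')
        FILENAME_FORMAT = "screenshots/{:04d}.png"
        # Assumption: use the first index that does not exist on disk yet
        index = 0
        while os.path.isfile(FILENAME_FORMAT.format(index)):
            index += 1
        filename = FILENAME_FORMAT.format(index)
        # cv2.imwrite expects BGR channel order, so swap red and blue
        cv2.imwrite(filename, self.get_image(flip_red_blue=True))
        print(f"Screenshot saved to {filename}.")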

# 🔍 Section 7: Inference & 3D Sample Visualization

Now that the Generator is trained and saved at specific epochs, we can inspect its output using this embedded sampling and viewing routine.

---

### 🧪 What It Does
- Loads a Generator checkpoint for a chosen epoch (e.g., epoch 140)
- Generates a few voxel samples using random latent vectors
- Displays their 2D slices in ASCII format using `create_text_slice()`
- Optionally launches a 3D mesh viewer (`MeshRenderer`) if OpenGL rendering is available

---

### ⚙️ Parameters
- `EPOCH_TO_LOAD`: Defines which checkpoint to inspect
- `NUM_SAMPLES`: Number of fake voxel grids to generate
- `ENABLE_3D_VIEWER`: Flag that enables OpenGL mesh visualization

This block is ideal for evaluating the quality and diversity of generated shapes after training — either inline in the notebook or with rich 3D previews.
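
To see which file a given `EPOCH_TO_LOAD` refers to, the sketch below rebuilds the checkpoint path with the same naming rule as `SavableModule.get_filename()`. The function name `checkpoint_path` is hypothetical and only mirrors that rule for illustration.

```python
import os

MODEL_PATH = "models"
CHECKPOINT_PATH = os.path.join(MODEL_PATH, "checkpoints")

def checkpoint_path(filename, epoch):
    # Insert "-epoch-NNNNN" before the extension, zero-padded to five digits
    stem, extension = filename.rsplit('.', 1)
    return os.path.join(CHECKPOINT_PATH, f"{stem}-epoch-{epoch:05d}.{extension}")

print(checkpoint_path("generator_wgan.to", 140))
```

If that file is missing, the loading cell reports the expected path and exits.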


In [None]:
# 🔍 Inference & Viewer Logic from generate_and_view.py

import torch
import numpy as np
import sys
import os
import time

# --- Constants ---
MODEL_PATH = "models"
CHECKPOINT_PATH = os.path.join(MODEL_PATH, 'checkpoints')
EPOCH_TO_LOAD = 140
NUM_SAMPLES = 3

# --- Load Generator ---
generator = Generator()
try:
    generator.load(epoch=EPOCH_TO_LOAD, filename="generator_wgan.to")
    generator.eval()
    print(f"✅ Generator loaded successfully from epoch {EPOCH_TO_LOAD}")
except FileNotFoundError:
    base_filename = "generator_wgan.to"
    filename_parts = base_filename.split('.')
    filename_parts[-2] += f'-epoch-{EPOCH_TO_LOAD:05d}'
    expected_filename = '.'.join(filename_parts)
    expected_path = os.path.join(CHECKPOINT_PATH, expected_filename)
    print(f"❌ Generator checkpoint not found: {expected_path}")
    print("Please verify that training ran long enough and model exists.")
    sys.exit(1)
except Exception as e:
    print(f"❌ Unexpected error loading Generator: {e}")
    sys.exit(1)

# --- Viewer Setup ---
ENABLE_3D_VIEWER = True
viewer = None
try:
    viewer = MeshRenderer()
    print("🟢 3D viewer initialized")
except Exception as e:
    print(f"⚠️ MeshRenderer init failed: {e}")
    ENABLE_3D_VIEWER = False

# --- Sample Generation ---
print(f"\n🎨 Generating {NUM_SAMPLES} voxel samples...")
with torch.no_grad():
    for i in range(NUM_SAMPLES):
        voxel_tensor = generator.generate(sample_size=1)
        voxel_numpy = voxel_tensor.squeeze().cpu().numpy()

        print(f"\n--- Sample {i+1} (2D Slice Preview) ---")
        if voxel_numpy.shape == (32, 32, 32):
            print(create_text_slice(voxel_numpy))
        else:
            print(f"⚠️ Unexpected shape: {voxel_numpy.shape}")
            print(voxel_numpy[:5, :5, :5])

        if ENABLE_3D_VIEWER and viewer:
            try:
                viewer.set_voxels(voxel_numpy)  # set_voxels() already extracts the surface at level 0.0
                print(f"🧊 Sample {i+1} sent to 3D viewer")
            except Exception as e:
                print(f"❌ Error visualizing Sample {i+1}: {e}")

print("\n✅ Inference complete. Close 3D window(s) manually if open.")


# 🚀 Section 8: Entry Point and Execution

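This block initializes the Generator and Discriminator, reloads their weights from the last saved checkpoint if one exists, opens a log file, and starts the training loop.

✅ You can adjust the dataset path, batch size, `epochs`, or checkpoint paths here. Note that `train_wgan` as defined in Section 4 always counts from epoch 1 and does not write to the log file, so a resumed run restores the weights but numbers its checkpoints from the beginning again. Wrapping the logic in `if __name__ == '__main__':` keeps it safe to run both as a notebook cell and as a script.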
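
The resume step depends on recovering an epoch number from checkpoint file names. The sketch below shows one way to do that on a plain list of names, taking the numerically largest epoch; the function name `last_saved_epoch` and the example file names are assumptions for illustration, and unlike the cell below it does not read a directory.

```python
import re

def last_saved_epoch(filenames, prefix="generator_wgan"):
    # Collect the epoch number from every checkpoint name with the given prefix
    epochs = []
    for name in filenames:
        match = re.search(r'epoch-(\d+)', name)
        if name.startswith(prefix) and name.endswith('.to') and match:
            epochs.append(int(match.group(1)))
    # 0 signals that no checkpoint exists and training starts from scratch
    return max(epochs, default=0)

names = [
    "generator_wgan-epoch-00010.to",
    "generator_wgan-epoch-00140.to",
    "discriminator_wgan-epoch-00150.to",
]
print(last_saved_epoch(names))
```

Comparing parsed integers avoids relying on the zero-padded names sorting in epoch order.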


In [None]:
# Source: train_wgan.py

if __name__ == '__main__':
    # --- Resume from last saved epoch ---
    def find_last_saved_epoch(filename="generator_wgan.to"):
        if not os.path.exists(CHECKPOINT_PATH):
            return 0
        files = sorted([
            f for f in os.listdir(CHECKPOINT_PATH)
            if f.startswith(filename.split('.')[0]) and f.endswith('.to')
        ])
        if not files:
            return 0
        last_file = files[-1]
        import re
        match = re.search(r'epoch-(\d+)', last_file)
        return int(match.group(1)) if match else 0

    start_epoch = find_last_saved_epoch()
    if start_epoch > 0:
        print(f"Resuming from epoch {start_epoch}...")
    else:
        print("No checkpoint found; starting from scratch.")
    generator = Generator()
    discriminator = Discriminator()

    if start_epoch > 0:
        generator.load(epoch=start_epoch)
        discriminator.load(epoch=start_epoch)
        print("✅ Models loaded from checkpoint.")

    # --- Logging Setup ---
    log_filename = os.path.join(MODEL_PATH, "log.txt")
    log_file = open(log_filename, "a")

    # --- Dataset ---
    dataset = VoxelDataset.glob("data/YOUR_DATA_PATH/**/*.npy")
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    # --- Begin Training ---
    train_wgan(generator, discriminator, dataloader, epochs=150)

    log_file.close()
