In [None]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyperparameters and I/O settings for the 3D voxel DDPM training run.

    Every setting carries a type annotation so that ``@dataclass`` turns it into
    a real field. Without annotations the values were plain class attributes:
    defaults still worked, but ``TrainingConfig(num_epochs=5)`` raised
    ``TypeError``. Defaults are unchanged; individual values can now be
    overridden at construction time.
    """

    voxel_size: int = 64  # the generated voxel resolution
    voxel_channels: int = 1  # binary voxels
    train_batch_size: int = 16
    eval_batch_size: int = 16  # how many samples to generate during evaluation
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_sample_epochs: int = 10  # generate preview samples every N epochs (and after the last)
    save_model_epochs: int = 30  # save a checkpoint every N epochs (and after the last)
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "ddpm-buildings-3d"  # the model name locally
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0


config = TrainingConfig()

In [None]:
import numpy as np
import json
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader

class BuildingVoxelDataset(Dataset):
    """Map-style dataset backed by ``voxels.npy`` and ``metadata.json`` in one directory.

    Args:
        dataset_path: directory that contains both files.
        transform: optional callable applied to each voxel entry in ``__getitem__``.

    Each item is a dict ``{"voxels": value}``. ``value`` is the raw entry of the
    loaded array, or ``transform(entry)`` when a transform is set.
    """

    def __init__(self, dataset_path, transform=None):
        root = Path(dataset_path)
        self.dataset_path = root
        # The whole voxel array is read into memory up front.
        self.voxels = np.load(root / "voxels.npy")
        metadata_file = root / "metadata.json"
        with open(metadata_file, "r") as handle:
            self.metadata = json.load(handle)
        self.transform = transform

    def __len__(self):
        return len(self.voxels)

    def __getitem__(self, idx):
        sample = self.voxels[idx]
        return {"voxels": self.transform(sample) if self.transform else sample}

# Load the dataset. NOTE: this path is machine-specific; point it at the
# training-data directory that holds voxels.npy and metadata.json.
dataset_path = "../../training_data/n250_training_data_20250206_092106"
dataset = BuildingVoxelDataset(dataset_path=dataset_path)

In [None]:
def voxel_transform(voxel_array):
    """Convert one numpy voxel grid into a float32 tensor with a leading channel axis.

    For an input of shape (D, H, W) the result has shape (1, D, H, W).
    """
    # `[None, ...]` inserts the channel axis at position 0 (same as unsqueeze(0)).
    return torch.from_numpy(voxel_array).float()[None, ...]

# Attach the tensor conversion so every item is returned as a channel-first float tensor.
setattr(dataset, "transform", voxel_transform)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def visualize_voxels(voxel_data, threshold=0.5, ax=None):
    """Render one voxel grid as solid cubes in a 3D matplotlib plot.

    Args:
        voxel_data: numpy array or torch tensor holding a single 3D grid. A 4D
            input whose leading axis has size 1 (for example, the channel axis
            added by ``voxel_transform``) is reduced to 3D first.
        threshold: cells with a value strictly greater than this are drawn.
        ax: optional existing 3D axes to draw into, so several grids can share
            one figure. When omitted, a new 10x10-inch figure with a single 3D
            axes is created, as before.

    Returns:
        The matplotlib Figure that owns the axes that were drawn on.
    """
    # Convert tensor to numpy if needed (detach/cpu so grad or GPU tensors work).
    if torch.is_tensor(voxel_data):
        voxel_data = voxel_data.detach().cpu().numpy()
    voxel_data = np.asarray(voxel_data)
    # `ax.voxels` needs a 3D array; drop a singleton leading (channel) axis.
    if voxel_data.ndim == 4 and voxel_data.shape[0] == 1:
        voxel_data = voxel_data[0]

    binary_voxels = voxel_data > threshold

    if ax is None:
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection='3d')
    else:
        fig = ax.figure

    # Array axes are plotted as stored (axis 0 -> X, 1 -> Y, 2 -> Z); no permutation.
    ax.voxels(binary_voxels,
              facecolors='royalblue',
              edgecolor=None,
              alpha=0.9)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    ax.view_init(elev=20, azim=45)

    # Make the grid lines lighter
    ax.grid(True, alpha=0.3)

    # Axis limits cover the whole grid, starting from 0.
    ax.set_xlim(0, binary_voxels.shape[0])
    ax.set_ylim(0, binary_voxels.shape[1])
    ax.set_zlim(0, binary_voxels.shape[2])

    return fig


def visualize_voxel_grid(samples, rows=1, cols=4, threshold=0.5):
    """Plot up to ``rows * cols`` voxel grids side by side, one 3D subplot each.

    Args:
        samples: indexable sequence (list, tensor or array) of voxel grids. Each
            item is squeezed, so singleton channel/batch axes are dropped.
        rows, cols: subplot layout; samples beyond ``rows * cols`` are ignored.
        threshold: cells with a value strictly greater than this are drawn
            (default 0.5, the previously hard-coded value, matching
            ``visualize_voxels``).

    Returns:
        The matplotlib Figure containing the grid.
    """
    fig = plt.figure(figsize=(5 * cols, 5 * rows))
    n_panels = min(len(samples), rows * cols)
    for i in range(n_panels):
        ax = fig.add_subplot(rows, cols, i + 1, projection='3d')

        voxel_data = samples[i].squeeze()
        if torch.is_tensor(voxel_data):
            voxel_data = voxel_data.detach().cpu().numpy()

        binary_voxels = voxel_data > threshold
        # Swap the first two array axes: array axis 1 runs along the plot's X
        # axis and array axis 0 along Y (the axis labels below follow this swap).
        voxels_permuted = np.transpose(binary_voxels, (1, 0, 2))

        # Solid voxels without wireframe edges.
        ax.voxels(voxels_permuted,
                  facecolors='royalblue',
                  edgecolor=None,
                  alpha=0.9)

        ax.set_xlabel('Y')
        ax.set_ylabel('X')
        ax.set_zlabel('Z')
        # NOTE: the `roll` argument requires matplotlib >= 3.6.
        ax.view_init(elev=20, azim=90, roll=270)
        ax.grid(True, alpha=0.3)

        # Limits follow the permuted layout and start from 0.
        ax.set_xlim(0, binary_voxels.shape[1])
        ax.set_ylim(0, binary_voxels.shape[0])
        ax.set_zlim(0, binary_voxels.shape[2])

    # Lay out this figure explicitly rather than whatever pyplot considers "current".
    fig.tight_layout()
    return fig

# Example usage for a single grid:
# visualize_voxels(dataset[0]["voxels"])

# Preview the first training samples. The grid draws at most rows * cols
# panels, so only that many items are loaded. The count is also capped at
# the dataset size to avoid an IndexError on small datasets.
preview_rows, preview_cols = 1, 4
n_preview = min(preview_rows * preview_cols, len(dataset))
fig = visualize_voxel_grid([dataset[i]["voxels"] for i in range(n_preview)],
                           rows=preview_rows, cols=preview_cols)
plt.show()

In [None]:
# Shuffled training mini-batches. Items are dicts, so the default collate
# function yields batches shaped like {"voxels": <stacked tensor>}.
train_dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention3D(nn.Module):
    """Transformer-style self-attention over all spatial positions of a 5D feature map.

    Input and output have shape (batch, channels, D, H, W). Every spatial
    position is treated as one token of size ``channels``. The block applies:
    pre-norm multi-head attention + residual, then a pre-norm two-layer GELU
    MLP + residual.

    Args:
        channels: feature size per position; must be divisible by ``num_heads``.
        num_heads: number of attention heads (default 4, the previously
            hard-coded value).
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        batch = x.shape[0]
        spatial = x.shape[-3:]
        # (B, C, D, H, W) -> (B, D*H*W, C): one token per voxel.
        tokens = x.reshape(batch, self.channels, -1).transpose(1, 2)
        normed = self.ln(tokens)
        # Attention weights are not used, so skip computing/averaging them.
        attn, _ = self.mha(normed, normed, normed, need_weights=False)
        hidden = attn + tokens
        hidden = self.ff_self(hidden) + hidden
        # (B, N, C) -> (B, C, D, H, W)
        return hidden.transpose(1, 2).reshape(batch, self.channels, *spatial)

class DoubleConv3D(nn.Module):
    """Two stacked (Conv3d 3x3x3 -> GroupNorm(8 groups) -> GELU) stages.

    Spatial size is preserved (padding=1). Channels go in_channels -> out_channels
    in the first conv and stay at out_channels in the second. ``out_channels``
    must be divisible by 8 because of the GroupNorm group count. The layers
    live in ``self.double_conv`` (an ``nn.Sequential``), so state-dict keys
    are ``double_conv.<index>.*``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.GroupNorm(8, out_channels),
                nn.GELU(),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        return out

class Down3D(nn.Module):
    """Encoder stage: 2x max-pool downsampling -> DoubleConv3D -> optional self-attention.

    Input (B, in_channels, D, H, W) -> output (B, out_channels, D//2, H//2, W//2).
    Without pooling padding, an even size is halved exactly (64 -> 32 -> 16 -> 8).
    The stride-2 transposed conv in ``Up3D`` then reproduces the skip-connection
    size, so encoder and decoder feature maps stay voxel-aligned.
    """

    def __init__(self, in_channels, out_channels, use_attention=False):
        super().__init__()
        # Previously MaxPool3d(..., padding=1): 64 became 33 (then 17, 9), so Up3D
        # had to crop every upsampled map (18->17, 34->33, 66->64), misaligning it
        # with its skip input. The pool has no parameters, so state-dict keys are unchanged.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2, stride=2),
            DoubleConv3D(in_channels, out_channels),
        )
        self.use_attention = use_attention
        if use_attention:
            self.attention = SelfAttention3D(out_channels)

    def forward(self, x):
        x = self.maxpool_conv(x)
        if self.use_attention:
            x = self.attention(x)
        return x

class Up3D(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, DoubleConv3D, optional attention.

    ``forward(x1, x2)``:
    - ``x1`` is the coarser map with ``in_channels`` channels. After upsampling
      it has ``in_channels // 2`` channels.
    - It is then padded (or cropped, when the size difference is negative) to
      the spatial size of the skip tensor ``x2``.
    - Both are concatenated on the channel axis as ``[x2, upsampled]``.

    The concatenation feeds ``DoubleConv3D(in_channels, ...)``, so ``x2`` is
    expected to carry ``in_channels - in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, use_attention=False):
        super().__init__()
        # Halve the channels while doubling every spatial size.
        self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv3D(in_channels, out_channels)
        self.use_attention = use_attention
        if use_attention:
            self.attention = SelfAttention3D(out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Match `upsampled` to the skip tensor's spatial size. F.pad takes
        # (before, after) pairs starting from the LAST axis, hence axes 4, 3, 2.
        # Negative amounts crop instead of pad.
        pad_amounts = []
        for axis in (4, 3, 2):
            gap = x2.size(axis) - upsampled.size(axis)
            pad_amounts += [gap // 2, gap - gap // 2]
        upsampled = F.pad(upsampled, pad_amounts)

        merged = self.conv(torch.cat([x2, upsampled], dim=1))
        if not self.use_attention:
            return merged
        return self.attention(merged)

class UNet3DModel(nn.Module):
    """3D U-Net noise predictor for voxel diffusion, conditioned on the diffusion timestep.

    Architecture: ``inc`` -> three ``Down3D`` stages -> three ``Up3D`` stages
    with skip connections -> 1x1x1 ``outc`` conv. The deepest down stage and
    the first up stage use self-attention.

    Timestep conditioning works in three steps:
    - ``timesteps`` are mapped to a sinusoidal embedding and passed through ``time_mlp``.
    - One linear layer per stage (``time_proj``) turns the result into a
      per-channel bias.
    - That bias is added to the output of every stage.

    Previously the embedding was computed but never used, so the network could
    not tell how noisy its input was.

    Args:
        in_channels: channels of the input voxel grid.
        out_channels: channels of the predicted output.
        block_out_channels: feature widths of the four resolution levels.
            - Every width must be divisible by 8 (GroupNorm in DoubleConv3D).
            - The last two must also be divisible by 4 (attention heads).
            - Each level must be twice the previous one, because Up3D concatenates
              ``in // 2`` upsampled channels with the skip tensor.

    ``forward(x, timesteps, return_dict=True)``:
    - ``timesteps`` may be a python int, a 0-dim tensor, or a (batch,) tensor.
    - Returns ``{"sample": out}`` when ``return_dict`` is true, else ``(out,)``.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 block_out_channels=(64, 128, 256, 512)):
        super().__init__()
        c0, c1, c2, c3 = (block_out_channels[0], block_out_channels[1],
                          block_out_channels[2], block_out_channels[3])
        self.inc = DoubleConv3D(in_channels, c0)

        # Down blocks
        self.down1 = Down3D(c0, c1)
        self.down2 = Down3D(c1, c2)
        self.down3 = Down3D(c2, c3, use_attention=True)

        # Up blocks
        self.up1 = Up3D(c3, c2, use_attention=True)
        self.up2 = Up3D(c2, c1)
        self.up3 = Up3D(c1, c0)

        self.outc = nn.Conv3d(c0, out_channels, kernel_size=1)

        # Time embedding: sinusoidal features of width c0 -> MLP of width 4 * c0.
        self.time_freq_dim = c0
        time_emb_dim = c0 * 4
        self.time_mlp = nn.Sequential(
            nn.Linear(self.time_freq_dim, time_emb_dim),
            nn.GELU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        # One projection per stage, in forward order:
        # inc, down1, down2, down3, up1, up2, up3.
        stage_channels = (c0, c1, c2, c3, c2, c1, c0)
        self.time_proj = nn.ModuleList([nn.Linear(time_emb_dim, ch) for ch in stage_channels])

    @staticmethod
    def _timestep_embedding(timesteps, dim, max_period=10000.0):
        """Sinusoidal embedding of shape (N, dim) for a 1D tensor of N timesteps."""
        half = dim // 2
        exponents = torch.arange(half, dtype=torch.float32, device=timesteps.device) / max(half, 1)
        freqs = max_period ** (-exponents)
        args = timesteps.float().view(-1, 1) * freqs.view(1, -1)
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:  # keep the width exactly `dim` for odd sizes
            emb = F.pad(emb, (0, 1))
        return emb

    def _add_time(self, h, emb, stage):
        """Add the stage-specific projection of ``emb`` to ``h`` as a per-channel bias."""
        bias = self.time_proj[stage](emb).to(h.dtype)
        # (N, C) -> (N, C, 1, 1, 1); N is 1 or the batch size, so it broadcasts.
        return h + bias[:, :, None, None, None]

    def forward(self, x, timesteps, return_dict=True):
        # Accept python ints / 0-dim tensors as well as per-sample tensors.
        t = torch.as_tensor(timesteps, device=x.device).reshape(-1)
        freq = self._timestep_embedding(t, self.time_freq_dim)
        emb = self.time_mlp(freq.to(self.time_mlp[0].weight.dtype))

        # Encoder
        x1 = self._add_time(self.inc(x), emb, 0)
        x2 = self._add_time(self.down1(x1), emb, 1)
        x3 = self._add_time(self.down2(x2), emb, 2)
        x4 = self._add_time(self.down3(x3), emb, 3)

        # Decoder with skip connections
        h = self._add_time(self.up1(x4, x3), emb, 4)
        h = self._add_time(self.up2(h, x2), emb, 5)
        h = self._add_time(self.up3(h, x1), emb, 6)

        output = self.outc(h)

        if return_dict:
            return {"sample": output}
        return (output,)

# Instantiate the noise-prediction U-Net for single-channel voxel grids,
# with four resolution levels of 64/128/256/512 feature channels.
model = UNet3DModel(block_out_channels=(64, 128, 256, 512), in_channels=1, out_channels=1)

In [None]:
# Smoke-test the model on one sample. The prediction must have the input's
# shape because it is compared to the noise with an element-wise MSE loss.
sample_voxels = dataset[0]["voxels"].unsqueeze(0)  # add a batch axis
print("Input shape:", sample_voxels.shape)

# Create a proper timesteps tensor
timesteps = torch.LongTensor([0])
# No gradients are needed for a shape check; this avoids storing activations.
with torch.no_grad():
    output = model(sample_voxels, timesteps=timesteps)
print("Output shape:", output["sample"].shape)
assert output["sample"].shape == sample_voxels.shape, "model output must match the input shape"

In [None]:
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# Seeded so the illustrated noise (and the loss computed from it later) is reproducible.
noise = torch.randn(sample_voxels.shape, generator=torch.Generator().manual_seed(config.seed))
timesteps = torch.LongTensor([50])
noisy_voxels = noise_scheduler.add_noise(sample_voxels, noise, timesteps)

# Visualize the noisy voxels. plt.show() renders the figure once, instead of
# also echoing the returned Figure as the cell's output (a duplicate image).
visualize_voxels(noisy_voxels.squeeze())
plt.show()

In [None]:
import torch.nn.functional as F

# UNet3DModel returns a plain dict (not a diffusers output object), so the
# prediction is read with ["sample"]; `.sample` raised AttributeError.
noise_pred = model(noisy_voxels, timesteps)["sample"]
loss = F.mse_loss(noise_pred, noise)
loss.item()

In [None]:
from diffusers.optimization import get_cosine_schedule_with_warmup

# AdamW with a linear warm-up followed by cosine decay. The schedule length
# assumes one scheduler step per training batch over the whole run.
total_training_steps = len(train_dataloader) * config.num_epochs
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=total_training_steps,
)

In [None]:
from diffusers import DDPMPipeline
import os

def evaluate(config, epoch, pipeline, num_inference_steps=1000):
    """Generate voxel grids by reverse diffusion and save a preview image.

    ``DDPMPipeline.__call__`` cannot drive the custom ``UNet3DModel``:
    - it reads ``unet.config.sample_size``, and a plain ``nn.Module`` has no ``config``;
    - it expects a model output with a ``.sample`` attribute;
    - it returns ``.images``, not ``.voxels``.

    This function therefore runs the DDPM sampling loop itself, using the
    pipeline's ``unet`` and ``scheduler`` components.

    Args:
        config: TrainingConfig-like object. Reads ``eval_batch_size``,
            ``voxel_channels``, ``voxel_size``, ``seed`` and ``output_dir``.
        epoch: epoch index, used in the file name ``samples/{epoch:04d}.png``.
        pipeline: object exposing ``.unet`` and ``.scheduler``.
            ``.unet`` is called as ``unet(x, t, return_dict=False)[0]``.
            ``.scheduler`` is a diffusers scheduler with ``set_timesteps``/``step``,
            e.g. ``DDPMScheduler``.
        num_inference_steps: number of denoising steps (1000 is DDPMPipeline's default).

    Returns:
        The generated samples as a CPU tensor of shape
        (eval_batch_size, voxel_channels, voxel_size, voxel_size, voxel_size).
    """
    unet = pipeline.unet
    scheduler = pipeline.scheduler
    device = next(unet.parameters()).device

    # Initial noise comes from a fixed-seed CPU generator, so previews are comparable across epochs.
    generator = torch.Generator(device='cpu').manual_seed(config.seed)
    shape = (config.eval_batch_size, config.voxel_channels,
             config.voxel_size, config.voxel_size, config.voxel_size)
    sample = torch.randn(shape, generator=generator).to(device)

    scheduler.set_timesteps(num_inference_steps)
    was_training = unet.training
    unet.eval()
    try:
        with torch.no_grad():
            for t in scheduler.timesteps:
                t_batch = torch.full((shape[0],), int(t), device=device, dtype=torch.long)
                model_output = unet(sample, t_batch, return_dict=False)[0]
                # .float(): the output may be fp16 under mixed precision.
                sample = scheduler.step(model_output.float(), t, sample, generator=generator).prev_sample
    finally:
        # The same module is used for training; restore its previous mode.
        unet.train(was_training)

    samples = sample.cpu()

    # 1x4 grid of the first samples, each in its own 3D subplot of one figure.
    fig = visualize_voxel_grid(samples, rows=1, cols=4)

    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    fig.savefig(os.path.join(test_dir, f"{epoch:04d}.png"))
    plt.close(fig)
    return samples

In [None]:
from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from pathlib import Path
import os

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    """Train ``model`` to predict the noise added by ``noise_scheduler`` (DDPM objective).

    Uses Hugging Face Accelerate for device placement, mixed precision,
    gradient accumulation and TensorBoard logging (under ``<output_dir>/logs``).

    Each step does the following:
    - draws Gaussian noise and one random timestep per batch item;
    - noises the clean batch with ``noise_scheduler.add_noise``;
    - minimizes the MSE between the model's prediction and that noise.

    Gradients are clipped to a global norm of 1.0 whenever they are synchronized.

    At the end of every ``save_sample_epochs``-th epoch (and the last one),
    ``evaluate`` is called with a ``DDPMPipeline`` wrapping the unwrapped model.

    At the end of every ``save_model_epochs``-th epoch (and the last one):
    - the model ``state_dict`` is written to ``<output_dir>/unet3d.pt``;
    - the scheduler config is written to ``<output_dir>/scheduler``.

    ``DDPMPipeline.save_pretrained`` is not used for weights because it only
    serializes components of known library classes, so a plain ``nn.Module``
    UNet is skipped.

    Args:
        config: TrainingConfig-like object (epochs, output paths, precision, seeds).
        model: the noise-prediction network, called as ``model(x, t, return_dict=False)[0]``.
        noise_scheduler: diffusers noise scheduler, e.g. ``DDPMScheduler``.
        optimizer, train_dataloader, lr_scheduler: training objects; they are
            wrapped by ``accelerator.prepare``.
    """
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )

    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    # Prepare everything
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for batch in train_dataloader:
            clean_voxels = batch["voxels"]
            # Sample noise to add to the voxels
            noise = torch.randn(clean_voxels.shape, device=clean_voxels.device)
            bs = clean_voxels.shape[0]

            # Sample a random timestep for each voxel grid in the batch
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,),
                device=clean_voxels.device, dtype=torch.int64
            )

            # Forward diffusion: noise each clean grid to its sampled timestep
            noisy_voxels = noise_scheduler.add_noise(clean_voxels, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_voxels, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        progress_bar.close()

        # After selected epochs, sample demo voxels and/or save a checkpoint.
        if accelerator.is_main_process:
            is_last_epoch = epoch == config.num_epochs - 1
            unwrapped_model = accelerator.unwrap_model(model)

            if (epoch + 1) % config.save_sample_epochs == 0 or is_last_epoch:
                pipeline = DDPMPipeline(unet=unwrapped_model, scheduler=noise_scheduler)
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch:
                os.makedirs(config.output_dir, exist_ok=True)
                accelerator.save(unwrapped_model.state_dict(),
                                 os.path.join(config.output_dir, "unet3d.pt"))
                noise_scheduler.save_pretrained(os.path.join(config.output_dir, "scheduler"))

    # Flush and close the TensorBoard tracker.
    accelerator.end_training()

In [None]:
from accelerate import notebook_launcher

# Launch training through Accelerate's notebook launcher in a single process.
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
notebook_launcher(train_loop, args=args, num_processes=1)