In [None]:
import json
from enum import Enum
import numpy as np
from typing import Any, Dict
from torch.utils.data import Dataset
from pathlib import Path

class BuildingVoxelDataset(Dataset):
    """Map-style dataset backed by a directory holding ``voxels.npy`` and ``metadata.json``.

    The voxel array is loaded eagerly with ``np.load`` and indexed along its
    first axis; the JSON metadata is kept as-is on ``self.metadata`` and is not
    returned by ``__getitem__``.

    Args:
        dataset_path: Directory containing ``voxels.npy`` and ``metadata.json``.
        transform: Optional callable applied to each indexed voxel entry
            before it is returned.
    """

    def __init__(self, dataset_path, transform=None):
        root = Path(dataset_path)
        self.dataset_path = root
        self.voxels = np.load(root / "voxels.npy")
        with open(root / "metadata.json", "r") as metadata_file:
            self.metadata = json.load(metadata_file)
        self.transform = transform

    def __len__(self):
        """Number of entries along the first axis of the voxel array."""
        return len(self.voxels)

    def __getitem__(self, idx):
        """Return ``{"voxels": ...}`` for entry ``idx``, transformed if a transform is set."""
        sample = self.voxels[idx]
        if not self.transform:
            return {"voxels": sample}
        return {"voxels": self.transform(sample)}

# Load the dataset from disk. The directory must contain the two files that
# BuildingVoxelDataset reads: "voxels.npy" and "metadata.json".
dataset_path = "../../training_data/n250_training_data_20250206_092106"  # Relative path; update this for your checkout
dataset = BuildingVoxelDataset(dataset_path)

In [None]:
# Largest value stored anywhere in the voxel array. NOTE: despite the name and
# the printed label this is the maximum class index, not a channel count; the
# number of one-hot channels is `used_voxel_channels` below.
voxel_channels = dataset.voxels.max()
print("Number of voxel channels:", voxel_channels)

used_voxel_channels = voxel_channels + 1  # Add 1 to include 0 as a class

# Per-sample grid shape (every axis after the first), with each dimension
# rounded up to the next power of two. The result is a list, one entry per axis.
voxel_size = dataset.voxels.shape[1:]
voxel_size = [2 ** int(np.ceil(np.log2(size))) for size in voxel_size]
print("Voxel size:", voxel_size)

In [None]:
# Transform the voxels to a one-hot encoding: Each channel is a block type

def voxel_transform(voxel_array):
    """
    Zero-pad a voxel array to the global ``voxel_size`` and one-hot encode it.

    Args:
        voxel_array: Numpy array of integer class indices. Presumably of shape
            (H, W, D) -- the final ``permute`` requires exactly three axes.

    Returns:
        Float tensor of shape (C, H', W', D') with values in {-1.0, 1.0}, where
        C is ``voxel_channels + 1`` and the spatial dimensions equal
        ``voxel_size``. Padding cells hold class 0.
    """
    # Split the missing extent of every axis between the two sides; the extra
    # cell of an odd difference goes to the trailing side.
    pad_width = []
    for current, wanted in zip(voxel_array.shape, voxel_size):
        missing = wanted - current
        leading = missing // 2
        pad_width.append((leading, missing - leading))

    padded = np.pad(voxel_array, pad_width=pad_width, mode='constant', constant_values=0)

    # Class indices -> one-hot, with the class axis appended last by one_hot.
    labels = torch.from_numpy(padded).long()
    encoded = torch.nn.functional.one_hot(labels, num_classes=voxel_channels + 1)

    # Channels first: (H', W', D', C) -> (C, H', W', D').
    encoded = encoded.permute(3, 0, 1, 2).float()

    # Rescale {0, 1} to {-1, 1}.
    return 2.0 * encoded - 1.0

def onehot_to_voxel(onehot_tensor):
    """Convert a (normalised) one-hot volume back to integer class indices.

    Args:
        onehot_tensor: Tensor or numpy array laid out as (C, H, W, D), or
            (B, C, H, W, D) with a leading batch axis. Values are expected in
            the [-1, 1] range produced by ``voxel_transform``; since only the
            arg-max over channels is used, any monotonic rescaling of that
            range yields the same result.

    Returns:
        Long tensor of class indices with shape (H, W, D) for unbatched input
        or a batch of size one, and (B, H, W, D) for larger batches.

    Raises:
        ValueError: If the input is neither 4- nor 5-dimensional.
    """
    # Accept numpy arrays as well as tensors; tensors are detached and moved to CPU.
    if torch.is_tensor(onehot_tensor):
        onehot_tensor = onehot_tensor.detach().cpu()
    else:
        onehot_tensor = torch.as_tensor(onehot_tensor)

    if onehot_tensor.dim() not in (4, 5):
        raise ValueError(
            f"Expected a 4-D (C, H, W, D) or 5-D (B, C, H, W, D) input, "
            f"got {onehot_tensor.dim()} dimensions"
        )

    # Denormalize from [-1, 1] range
    onehot = (onehot_tensor + 1.0) / 2.0

    # The channel axis is always the fourth from the end, batched or not.
    voxels = torch.argmax(onehot, dim=-4)

    # Preserve the historical behaviour of dropping a singleton batch axis.
    if onehot.dim() == 5 and onehot.shape[0] == 1:
        voxels = voxels.squeeze(0)

    return voxels

# Install the transform on the already-loaded dataset: from here on,
# dataset[i]["voxels"] is whatever voxel_transform returns for entry i.
dataset.transform = voxel_transform

In [None]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyper-parameters and output settings for the training run.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field (un-annotated class attributes are ignored by the
    decorator). ``TrainingConfig()`` yields the same values as before, and
    individual settings can now be overridden by keyword.
    """

    # Edge length of the voxel grid; only the first entry of the global
    # `voxel_size` list is used.
    voxel_size: int = voxel_size[0]
    voxel_channels: int = voxel_channels + 1  # Add 1 to include 0 as a class
    train_batch_size: int = 16
    eval_batch_size: int = 4  # how many samples to generate during evaluation
    num_epochs: int = 100
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_sample_epochs: int = 10
    save_model_epochs: int = 25
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "ddpm-buildings-3d"  # the model name locally
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0

config = TrainingConfig()

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

def visualize_voxels(voxel_data, ax=None, threshold=0.5, elev=20, azim=90, roll=270, transpose=True):
    """
    Visualize voxel data in a 3D plot.

    Cells whose value exceeds ``threshold`` are drawn; each distinct value
    among the drawn cells gets its own colour from the rainbow colormap.

    Args:
        voxel_data: Tensor or numpy array of voxel data. Singleton axes are
            squeezed away; the result must be three-dimensional.
        ax: Matplotlib 3D axis for plotting. If None, a new figure and axis
            are created.
        threshold: Cells strictly greater than this value are drawn.
        elev, azim, roll: View angles for the 3D plot.
        transpose: Whether to swap the first two axes of the data before
            plotting (for grid visualization).

    Returns:
        ``(fig, ax)`` when ``ax`` was None, otherwise just ``ax``.
    """
    # Convert tensor to numpy if needed
    if torch.is_tensor(voxel_data):
        voxel_data = voxel_data.detach().cpu().numpy()

    voxel_data = voxel_data.squeeze()

    # Convert to binary voxels
    binary_voxels = voxel_data > threshold

    # Transpose if needed (for grid visualization)
    if transpose:
        binary_voxels = np.transpose(binary_voxels, (1, 0, 2))
        voxel_data = np.transpose(voxel_data, (1, 0, 2))

    # One colour per distinct value among the drawn cells
    unique_values = np.unique(voxel_data[binary_voxels])
    colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_values)))
    color_dict = dict(zip(unique_values, colors))

    # RGBA colour for every cell; undrawn cells stay all-zero
    facecolors = np.zeros((*binary_voxels.shape, 4))
    for value in unique_values:
        mask = (voxel_data == value) & binary_voxels
        facecolors[mask] = color_dict[value]

    # Create or use provided axis
    if ax is None:
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection='3d')
        return_fig = True
    else:
        return_fig = False

    # Plot voxels
    ax.voxels(binary_voxels,
              facecolors=facecolors,
              edgecolor=None,
              alpha=0.9)

    # Labels follow the data: after a transpose, the plot's x axis shows the
    # original second axis and vice versa.
    ax.set_xlabel('Y' if transpose else 'X')
    ax.set_ylabel('X' if transpose else 'Y')
    ax.set_zlabel('Z')
    ax.view_init(elev=elev, azim=azim, roll=roll)
    ax.grid(True, alpha=0.3)

    # `binary_voxels` has already been transposed at this point, so array
    # axis 0 is what is drawn along x and axis 1 along y. Reading the limits
    # straight from the plotted array keeps non-cubic grids from being clipped.
    ax.set_xlim(0, binary_voxels.shape[0])
    ax.set_ylim(0, binary_voxels.shape[1])
    ax.set_zlim(0, binary_voxels.shape[2])

    return (fig, ax) if return_fig else ax

def visualize_voxel_grid(samples, metadata=None, threshold=0.5, rows=1, cols=4):
    """
    Draw several voxel volumes side by side, optionally titled with their prompts.

    At most ``rows * cols`` samples are drawn; any further samples are ignored.

    Args:
        samples: Sequence of voxel volumes, each accepted by ``visualize_voxels``.
        metadata: Optional sequence of dicts aligned with ``samples``; the
            ``'prompt'`` entry (empty string when absent) becomes the subplot title.
        threshold: Forwarded to ``visualize_voxels``.
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.

    Returns:
        fig: The Matplotlib figure holding all subplots.
    """
    fig = plt.figure(figsize=(5*cols, 5*rows))
    shown = min(len(samples), rows*cols)
    line_width = 30  # characters per title line

    for position, sample in enumerate(samples[:shown], start=1):
        ax = fig.add_subplot(rows, cols, position, projection='3d')
        visualize_voxels(sample, ax=ax, threshold=threshold)

        if metadata is None:
            continue

        # Break the prompt into fixed-width chunks so long titles stay readable.
        prompt = metadata[position - 1].get('prompt', '')
        chunks = [prompt[start:start + line_width] for start in range(0, len(prompt), line_width)]
        ax.set_title('\n'.join(chunks), pad=10, wrap=True, fontsize=8)

    plt.tight_layout()
    return fig

# Example usage:
# visualize_voxels(your_voxel_data)

# Visualize raw dataset entries 4..7 (read directly from dataset.voxels, so
# the dataset transform is not applied) with their metadata prompts as titles.
fig = visualize_voxel_grid(
    samples=[dataset.voxels[i] for i in range(4, 8)],
    metadata=[dataset.metadata[i] for i in range(4, 8)]
)
plt.show()

In [None]:
# Shuffled training batches of size config.train_batch_size; each batch is a
# dict with a "voxels" entry, as returned by the dataset's __getitem__.
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def get_timestep_embedding(timesteps, embedding_dim=256):
    """
    Create sinusoidal timestep embeddings.

    The first half of each row holds sines and the second half cosines of the
    timestep multiplied by geometrically spaced frequencies.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
    :param embedding_dim: the dimension of the output. Odd values are
        supported: the last column is then zero padding.
    :return: an [N x embedding_dim] Tensor of positional embeddings.
    """
    half_dim = embedding_dim // 2
    # Guard the denominator so embedding_dim of 2 or 3 (half_dim == 1) does
    # not divide by zero; for half_dim >= 2 this is the original formula.
    scale = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -scale)
    args = timesteps[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if embedding_dim % 2 == 1:
        # sin/cos halves only cover an even width; pad one zero column so the
        # result really has `embedding_dim` columns.
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
    return emb

class SelfAttention3D(nn.Module):
    """Self-attention over all positions of a 3-D feature volume.

    The three trailing axes are flattened into one sequence axis, passed
    through a pre-norm 4-head attention layer and a small feed-forward block
    (both with residual connections), then reshaped back to the input layout.
    ``channels`` must be divisible by the 4 attention heads.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        feed_forward = [
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.ff_self = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply attention to ``x`` (batch, channels, then three spatial axes); shape is preserved."""
        batch = x.shape[0]
        spatial = x.shape[-3:]

        # (batch, channels, positions) -> (batch, positions, channels)
        tokens = x.reshape(batch, self.channels, -1).transpose(1, 2)

        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        attended = attended + tokens
        refined = self.ff_self(attended) + attended

        # Back to channels-first and the original spatial extent.
        return refined.transpose(1, 2).reshape(batch, self.channels, *spatial)

class DoubleConv3D(nn.Module):
    """Two 3x3x3 convolutions, each followed by GroupNorm (8 groups) and GELU.

    Padding of 1 keeps the spatial size unchanged. ``out_channels`` must be
    divisible by 8 because of the group normalisation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First pass maps in_channels -> out_channels, second keeps out_channels.
        for source_channels in (in_channels, out_channels):
            stages.append(nn.Conv3d(source_channels, out_channels, kernel_size=3, padding=1))
            stages.append(nn.GroupNorm(8, out_channels))
            stages.append(nn.GELU())
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv/norm/activation stages on ``x``."""
        features = self.double_conv(x)
        return features

class Down3D(nn.Module):
    """Downsampling stage: 2x max-pool, a DoubleConv3D, and optional self-attention."""

    def __init__(self, in_channels, out_channels, use_attention=False):
        super().__init__()
        pool = nn.MaxPool3d(2)
        conv = DoubleConv3D(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)
        self.use_attention = use_attention
        # The attention sub-module only exists when it was requested.
        if use_attention:
            self.attention = SelfAttention3D(out_channels)

    def forward(self, x):
        """Halve the spatial size of ``x`` and map it to ``out_channels`` channels."""
        pooled = self.maxpool_conv(x)
        if not self.use_attention:
            return pooled
        return self.attention(pooled)

class Up3D(nn.Module):
    """Upsampling stage with a skip connection.

    The incoming features are upsampled 2x by a transposed convolution that
    halves their channel count, zero-padded to the skip tensor's spatial size,
    concatenated after the skip tensor along the channel axis, and passed
    through a DoubleConv3D (plus optional self-attention).
    """

    def __init__(self, in_channels, out_channels, use_attention=False):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv3D(in_channels, out_channels)
        self.use_attention = use_attention
        if use_attention:
            self.attention = SelfAttention3D(out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, merge it with the skip tensor ``x2`` and convolve."""
        upsampled = self.up(x1)

        # F.pad takes (left, right) pairs starting from the LAST axis, hence
        # the axes are visited in the order 4, 3, 2.
        padding = []
        for axis in (4, 3, 2):
            gap = x2.size()[axis] - upsampled.size()[axis]
            padding.extend([gap // 2, gap - gap // 2])
        upsampled = F.pad(upsampled, padding)

        merged = self.conv(torch.cat([x2, upsampled], dim=1))
        return self.attention(merged) if self.use_attention else merged

# Modified VAE with simpler decoder that doesn't use skip connections
class VoxelVAE(nn.Module):
    """Convolutional VAE over cubic one-hot voxel volumes.

    The encoder reduces the spatial edge by a factor of 4 (two ``Down3D``
    stages) to 128 channels before the linear mean / log-variance heads. The
    decoder starts from a 128-channel cube with edge ``input_size // 8`` and
    doubles it three times, so the reconstruction has the input's size.

    Args:
        in_channels: Number of one-hot channels of the input (and output).
        latent_dim: Size of the latent vector.
        input_size: Edge length of the cubic input volume. Must be a positive
            multiple of 8. The default of 32 reproduces the layer sizes that
            were previously hard-coded.

    Raises:
        ValueError: If ``input_size`` is not a positive multiple of 8.
    """

    def __init__(self, in_channels=used_voxel_channels, latent_dim=4, input_size=32):
        super().__init__()
        if input_size <= 0 or input_size % 8 != 0:
            raise ValueError(f"input_size must be a positive multiple of 8, got {input_size}")
        self.input_size = input_size
        # Edge of the encoder output (two 2x poolings) and of the decoder's
        # starting cube (three 2x upsamplings back to input_size).
        self.encoded_size = input_size // 4
        self.decoder_start_size = input_size // 8

        # Encoder
        self.encoder = nn.Sequential(
            DoubleConv3D(in_channels, 32),
            Down3D(32, 64),
            Down3D(64, 128)
        )
        # Latent space
        flat_features = 128 * self.encoded_size ** 3
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Modified decoder without skip connections
        self.decoder_input = nn.Linear(latent_dim, 128 * self.decoder_start_size ** 3)
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2),
            DoubleConv3D(64, 64),
            nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2),
            DoubleConv3D(32, 32),
            nn.ConvTranspose3d(32, 32, kernel_size=2, stride=2),
            DoubleConv3D(32, 32),
            nn.Conv3d(32, in_channels, kernel_size=1)
        )

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def decode(self, z):
        """Map latent vectors ``z`` back to a volume with ``in_channels`` channels."""
        x = self.decoder_input(z)
        edge = self.decoder_start_size
        x = x.view(x.size(0), 128, edge, edge, edge)
        x = self.decoder(x)
        return x

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mu, log_var)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

class UNet3DModel(nn.Module):
    """Timestep-conditioned 3-D U-Net with three down and three up stages.

    The sinusoidal timestep embedding is passed through ``time_mlp`` and then,
    via one small projection per resolution level (``time_proj``), added as a
    per-channel bias to the encoder feature maps. Previously the embedding was
    computed but never used, so the network could not tell noise levels apart.

    NOTE: ``time_proj`` adds parameters, so state dicts saved from the earlier
    (unconditioned) version of this class will not load with ``strict=True``.

    Args:
        in_channels: Channels of the noisy input volume.
        out_channels: Channels of the prediction.
        block_out_channels: Channel widths of the four resolution levels.
        embedding_dim: Width of the sinusoidal timestep embedding.
    """

    def __init__(self,
                 in_channels=used_voxel_channels,
                 out_channels=used_voxel_channels,
                 block_out_channels=(64, 128, 256, 512),
                 embedding_dim=256):
        super().__init__()
        # Remembered so forward() builds an embedding that matches time_mlp's input width.
        self.embedding_dim = embedding_dim

        # Initial convolution
        self.inc = DoubleConv3D(in_channels, block_out_channels[0])

        # Down blocks
        self.down1 = Down3D(block_out_channels[0], block_out_channels[1])
        self.down2 = Down3D(block_out_channels[1], block_out_channels[2])
        self.down3 = Down3D(block_out_channels[2], block_out_channels[3], use_attention=True)

        # Up blocks
        self.up1 = Up3D(block_out_channels[3], block_out_channels[2], use_attention=True)
        self.up2 = Up3D(block_out_channels[2], block_out_channels[1])
        self.up3 = Up3D(block_out_channels[1], block_out_channels[0])

        # Output convolution
        self.outc = nn.Conv3d(block_out_channels[0], out_channels, kernel_size=1)

        # Time embedding
        time_emb_dim = block_out_channels[0] * 4
        self.time_mlp = nn.Sequential(
            nn.Linear(embedding_dim, time_emb_dim),
            nn.GELU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        # One projection per resolution level: time embedding -> per-channel bias.
        self.time_proj = nn.ModuleList([
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, channels))
            for channels in block_out_channels
        ])

    def _add_time(self, features, emb, level):
        """Add the level-specific projection of ``emb`` to ``features`` as a channel bias."""
        bias = self.time_proj[level](emb)
        return features + bias[:, :, None, None, None]

    def forward(self, x, timesteps):
        """Predict an output volume for input ``x`` at the given ``timesteps``.

        Args:
            x: Input of shape (batch, in_channels, then three spatial axes).
            timesteps: 1-D tensor with one timestep per batch element.

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of ``x``.
        """
        # Time embedding
        emb = get_timestep_embedding(timesteps, self.embedding_dim)
        emb = self.time_mlp(emb)

        # Initial conv
        x1 = self._add_time(self.inc(x), emb, 0)

        # Downsample
        x2 = self._add_time(self.down1(x1), emb, 1)
        x3 = self._add_time(self.down2(x2), emb, 2)
        x4 = self._add_time(self.down3(x3), emb, 3)

        # Upsample with skip connections
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        # Output conv
        output = self.outc(x)

        return output

class VoxelDiffusion(nn.Module):
    """Chains a ``VoxelVAE`` and a ``UNet3DModel`` into one module.

    NOTE(review): ``forward`` feeds the VAE's *decoded reconstruction* (the
    first value returned by ``VoxelVAE.forward``), not a latent code, into the
    U-Net, so the U-Net operates at full voxel resolution.
    """

    def __init__(self):
        super().__init__()
        channels = used_voxel_channels
        self.vae = VoxelVAE(in_channels=channels, latent_dim=4)
        self.unet = UNet3DModel(in_channels=channels, out_channels=channels)

    def encode(self, x):
        """Delegate to the VAE encoder; returns ``(mu, log_var)``."""
        mu, log_var = self.vae.encode(x)
        return mu, log_var

    def decode(self, z):
        """Delegate to the VAE decoder."""
        return self.vae.decode(z)

    def forward(self, x, timesteps):
        """Return ``(unet_output, mu, log_var)`` for input ``x`` at ``timesteps``."""
        # Full VAE pass: encode, sample, decode.
        reconstruction, mu, log_var = self.vae(x)

        # The U-Net sees the VAE reconstruction together with the timesteps.
        return self.unet(reconstruction, timesteps), mu, log_var

# Create model instance
model = VoxelDiffusion()

# Smoke test with random input: one sample with used_voxel_channels channels
# on a 32x32x32 grid (the size the VAE's linear layers are built for).
voxel_data = torch.randn(1, used_voxel_channels, 32, 32, 32)  # Random stand-in for voxel data
timesteps = torch.tensor([500])  # Current timestep in the diffusion process
# The model returns a 3-tuple: the U-Net output plus the VAE's mu and log_var.
output, mu, log_var = model(voxel_data, timesteps)
print("Output shape:", output.shape)
print("Mu shape:", mu.shape)
print("Log variance shape:", log_var.shape)

In [None]:
# Test the model with a real sample: entry 0 after the dataset transform,
# with a leading batch axis added.
sample_voxels = dataset[0]["voxels"].unsqueeze(0)
print("Input shape:", sample_voxels.shape)

# Round-trip check: decode the one-hot sample back to class indices and plot it.
revert_voxels = onehot_to_voxel(sample_voxels)
visualize_voxel_grid([revert_voxels])
plt.show()

# Create a proper timesteps tensor (one entry for the single sample)
timesteps = torch.LongTensor([0])
# `output` is the tuple returned by the model; element 0 is the U-Net output.
output = model(sample_voxels, timesteps=timesteps)
print("Output shape:", output[0].shape)

In [None]:
# Visualize each layer of the one hot encoding as a separate voxel grid
viz_voxels = dataset[0]["voxels"]
# One entry per index along the first axis (the channel axis produced by the transform).
layer_voxels = [viz_voxels[i, :] for i in range(viz_voxels.shape[0])]

print("Layer shapes:", [voxel.shape for voxel in layer_voxels])

# Visualize each layer. The grid holds rows*cols = 8 plots, so any channels
# beyond the first eight are not drawn.
layer_fig = visualize_voxel_grid(layer_voxels, rows=2, cols=4)
layer_fig.show()

In [None]:
import torch
from diffusers import DDPMScheduler

# Forward-diffusion demo: noise the sample from the earlier cell at timestep 150.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(sample_voxels.shape)  # unseeded, so this differs between runs
timesteps = torch.LongTensor([150])
noisy_voxels = noise_scheduler.add_noise(sample_voxels, noise, timesteps)

# print shapes
print("Input shape:", sample_voxels.shape)
print("Noise shape:", noise.shape)

# Drop the batch axis and split along the first remaining (channel) axis.
noisy_voxels = noisy_voxels.squeeze(0)
layer_voxels = [noisy_voxels[i, :] for i in range(noisy_voxels.shape[0])]

# Visualize each layer; threshold 0.0 draws every cell with a positive value.
fig = visualize_voxel_grid(layer_voxels, rows=2, cols=4, threshold=0.0)
fig.show()

In [None]:
import torch.nn.functional as F

# Single-sample loss check. Relies on `noisy_voxels`, `timesteps` and `noise`
# left behind by the noise-scheduler cell above.
test_vector = noisy_voxels.unsqueeze(0)  # restore the batch axis removed for plotting
test_output, _, _ = model(test_vector, timesteps)  # mu / log_var are discarded
print("Noise prediction shape:", test_output.shape)
# Same objective as in training: MSE between the model output and the added noise.
loss = F.mse_loss(test_output, noise)
print("Loss:", loss.item())

In [None]:
from diffusers.optimization import get_cosine_schedule_with_warmup

# AdamW over all parameters of the combined model (VAE and U-Net).
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
# Cosine schedule with warm-up; the total step count is one scheduler step
# per batch for every epoch.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)

In [None]:
import os
import torch
from safetensors.torch import save_file

class VoxelDiffusionPipeline:
    """Sampling pipeline: a denoising model plus a diffusers-style scheduler.

    Args:
        unet: The denoising model. It is called as ``unet(voxels, timestep)``
            and element ``[0]`` of its result is used as the prediction, so it
            must return an indexable result (``VoxelDiffusion`` returns a tuple).
        scheduler: Scheduler providing ``set_timesteps``, ``timesteps``,
            ``step(...).prev_sample`` and ``config``.
    """

    def __init__(self, unet, scheduler):
        self.unet = unet
        self.scheduler = scheduler
        # Device of the first model parameter, captured once at construction.
        self.device = next(unet.parameters()).device

    def save_pretrained(self, save_directory):
        """Save the pipeline's models and scheduler.

        Writes ``model.safetensors`` (the model's state dict) and
        ``scheduler_config.json`` into ``save_directory``, creating it if needed.
        """
        os.makedirs(save_directory, exist_ok=True)

        # Save the model state
        model_path = os.path.join(save_directory, "model.safetensors")
        model_state = self.unet.state_dict()
        save_file(model_state, model_path)

        # Save the scheduler config
        scheduler_config = self.scheduler.config
        scheduler_path = os.path.join(save_directory, "scheduler_config.json")
        with open(scheduler_path, "w") as f:
            json.dump(scheduler_config, f, indent=2)

    def __call__(
        self,
        batch_size=1,
        generator=None,
        return_dict=True,
        num_inference_steps=None,
    ):
        """Generate voxel volumes by iterative denoising from Gaussian noise.

        Args:
            batch_size: Number of volumes to generate.
            generator: Optional ``torch.Generator`` for the initial noise. If
                it lives on a different device type than the model, a new
                generator seeded with its initial seed is created on the
                model's device.
            return_dict: Return ``{"voxels": tensor}`` instead of the bare tensor.
            num_inference_steps: Number of denoising steps. Defaults to the
                scheduler's ``num_train_timesteps`` (1000 for the scheduler
                configured in this notebook).

        Returns:
            CPU tensor (or dict wrapping it) of shape (batch_size,
            config.voxel_channels, then three axes of config.voxel_size),
            shifted from the [-1, 1] range to [0, 1].
        """
        # Create generator on the correct device
        if generator is not None:
            if generator.device.type != self.device.type:
                generator = torch.Generator(device=self.device).manual_seed(generator.initial_seed())

        # Start from random noise
        shape = (batch_size, config.voxel_channels, config.voxel_size, config.voxel_size, config.voxel_size)
        noise = torch.randn(shape, device=self.device, generator=generator)

        # Set timesteps
        if num_inference_steps is None:
            num_inference_steps = self.scheduler.config.num_train_timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Denoising loop
        voxels = noise
        for t in self.scheduler.timesteps:
            # Broadcast the scalar timestep to one entry per batch element.
            timestep = torch.tensor([t], device=self.device)
            timestep = timestep.expand(batch_size)

            # Get model prediction
            with torch.no_grad():
                noise_pred = self.unet(voxels, timestep)[0]

            # Scheduler step
            voxels = self.scheduler.step(noise_pred, t, voxels).prev_sample

        # Convert from [-1, 1] range back to [0, 1]
        voxels = (voxels + 1.0) / 2.0

        # Move results to CPU for visualization
        voxels = voxels.cpu()

        if return_dict:
            return {"voxels": voxels}
        return voxels

    @classmethod
    def load_pretrained(cls, save_directory, scheduler_class=None):
        """Load a pretrained pipeline from a directory.

        Reads the files written by ``save_pretrained``. The weights are in
        safetensors format and are therefore read with
        ``safetensors.torch.load_file``; ``torch.load`` cannot parse them.

        Args:
            save_directory: Directory containing ``model.safetensors`` and
                ``scheduler_config.json``.
            scheduler_class: Scheduler class exposing ``from_config``;
                defaults to ``DDPMScheduler``.
        """
        from safetensors.torch import load_file

        # Load the model state
        model_path = os.path.join(save_directory, "model.safetensors")
        model_state = load_file(model_path)

        # Create a new model instance and load state
        model = VoxelDiffusion()  # You might need to pass config here
        model.load_state_dict(model_state)

        # Load scheduler config
        scheduler_path = os.path.join(save_directory, "scheduler_config.json")
        with open(scheduler_path, "r") as f:
            scheduler_config = json.load(f)

        # Create scheduler
        if scheduler_class is None:
            from diffusers import DDPMScheduler
            scheduler_class = DDPMScheduler

        scheduler = scheduler_class.from_config(scheduler_config)

        return cls(model, scheduler)

def evaluate(config, epoch, pipeline):
    """Sample ``config.eval_batch_size`` volumes and save them as one PNG grid.

    The image is written to ``<config.output_dir>/samples/<epoch>.png`` with
    the epoch zero-padded to four digits. Sampling is seeded with ``config.seed``.
    """
    # Generate voxels from random noise with a seeded generator.
    seeded = torch.Generator().manual_seed(config.seed)
    generated = pipeline(batch_size=config.eval_batch_size, generator=seeded)["voxels"]

    # Channel arg-max turns every generated sample into a class-index volume.
    voxels_for_viz = []
    for sample_index in range(config.eval_batch_size):
        voxels_for_viz.append(onehot_to_voxel(generated[sample_index]))

    # Drawing the grid makes it the current pyplot figure, which is what
    # savefig/close below operate on.
    visualize_voxel_grid(voxels_for_viz)

    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    plt.savefig(f"{test_dir}/{epoch:04d}.png")
    plt.close()

In [None]:
from accelerate import Accelerator
from tqdm.auto import tqdm
from pathlib import Path
import os
import torch.nn.functional as F

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    """Train ``model`` to predict the noise added by ``noise_scheduler``.

    Args:
        config: Settings object; reads ``mixed_precision``,
            ``gradient_accumulation_steps``, ``output_dir``, ``num_epochs``,
            ``save_sample_epochs`` and ``save_model_epochs``.
        model: Module called as ``model(noisy_voxels, timesteps)``. It may
            return a tensor or a tuple whose first element is the prediction.
        noise_scheduler: Scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.
        optimizer: Optimizer over the model parameters.
        train_dataloader: Yields dicts with a ``"voxels"`` tensor.
        lr_scheduler: Learning-rate scheduler, stepped once per batch.

    Returns:
        The trained model, unwrapped from any accelerator wrapper.
    """
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )

    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    # Prepare everything
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    # Training loop
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            clean_voxels = batch["voxels"]  # Already one-hot encoded from dataset transform

            # Sample noise to add to the voxels
            noise = torch.randn_like(clean_voxels)
            bs = clean_voxels.shape[0]

            # Sample a random timestep for each voxel
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,),
                device=clean_voxels.device, dtype=torch.int64
            )

            # Add noise to the clean voxels according to the noise magnitude at each timestep
            noisy_voxels = noise_scheduler.add_noise(clean_voxels, noise, timesteps)

            with accelerator.accumulate(model):
                # Forward pass
                model_output = model(noisy_voxels, timesteps)

                if isinstance(model_output, tuple):
                    # If model returns multiple outputs (e.g., for VAE), only
                    # the first is used; the remaining ones (mu / log_var for
                    # VoxelDiffusion) do not contribute to the loss.
                    noise_pred = model_output[0]
                else:
                    noise_pred = model_output

                # Calculate loss
                loss = F.mse_loss(noise_pred, noise)

                # Backward pass
                accelerator.backward(loss)

                # Clip only on steps where gradients are actually synchronised.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Logging
            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "step": global_step
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # Finish this epoch's bar so it is not left open next to the next one.
        progress_bar.close()

        # After each epoch - sampling and saving
        if accelerator.is_main_process:
            pipeline = VoxelDiffusionPipeline(
                unet=accelerator.unwrap_model(model),
                scheduler=noise_scheduler
            )

            if (epoch + 1) % config.save_sample_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                pipeline.save_pretrained(config.output_dir)

    # Finalise the trackers initialised above so buffered logs are written out.
    accelerator.end_training()

    return accelerator.unwrap_model(model)

In [None]:
from accelerate import notebook_launcher

# Positional arguments for train_loop, in the order of its signature.
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

# Run the training loop from the notebook in a single process.
notebook_launcher(train_loop, args, num_processes=1)