In [1]:
import os
import numpy as np
import torch
import torch.nn as nn

from model.unet import Unet3D
from model.Diffusion import Diffusion_Control
from model.EMA import ExponentialMovingAverage

from dataset.SimpleShapeDataset.VoxelData import VoxelDataset, ConditionDataset
from torch.utils.data import Dataset, DataLoader

from utils.visualization import visualize_voxel_map
from diffusers.optimization import get_cosine_schedule_with_warmup

In [2]:
voxel_dataset_dir = "dataset/SimpleShapeDataset/voxel_datasets"

voxel_dataset = VoxelDataset(voxel_dataset_dir)

# Create the DataLoader
batch_size = 2
voxel_dataloader = DataLoader(voxel_dataset, batch_size=batch_size, shuffle=False)


condition_dataset_dir = "dataset/SimpleShapeDataset/condition_datasets"

condition_dataset = ConditionDataset(voxel_dataset_dir)

# Create the DataLoader
batch_size = 2
condition_dataloader = DataLoader(condition_dataset, batch_size=batch_size, shuffle=False)

In [3]:
class RMSELoss(nn.Module):
    def __init__(self):
        super(RMSELoss, self).__init__()
        self.mse = nn.MSELoss(reduction='mean')

    def forward(self, output, target):
        return torch.sqrt(self.mse(output, target))

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

batch_size = 2
model_ema_steps = 10
num_epochs = 70
model_ema_decay = 0.995

# Define the model
model = Diffusion_Control(
    timesteps=1000,
    image_size=64,
    in_channels=1,
    base_dim=32,
    dim_mults=[1, 2, 4, 8]
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=(len(voxel_dataloader) * num_epochs),
)
adjust = 1 * batch_size * model_ema_steps / num_epochs
alpha = 1.0 - model_ema_decay
alpha = min(1.0, alpha * adjust)
model_ema = ExponentialMovingAverage(model, device=device, decay=1.0 - alpha)

loss_fn = nn.MSELoss(reduction='mean')
min_loss = np.inf
global_steps = 0

os.makedirs("results", exist_ok=True)
checkpoint_path = "condition_results/steps_00140000.pt"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["model"])
model_ema.load_state_dict(checkpoint["model_ema"])
# Load checkpoint if exists
# if os.path.exists(checkpoint_path):
#     print("Loading best checkpoint...")
#     checkpoint = torch.load(checkpoint_path)
#     unet_state_dict = {k.replace("model.", " "): v for k, v in checkpoint['model'].items() if k.startswith("model.")}
#     unet_state_dict = {k.strip(): v for k, v in unet_state_dict.items() if k.startswith(" ")}

#     model.unet.load_state_dict(unet_state_dict)  # Load UNet parameters only
#     # model_ema.load_state_dict(ema_state_dict)  
#     print("Checkpoint loaded successfully!")

# Training loop
for epoch in range(num_epochs):
    model.train()
    for (voxel_batch, condition_batch) in zip(voxel_dataloader, condition_dataloader):
        # Prepare inputs
        noise = torch.randn_like(voxel_batch).to(device)
        voxel_batch = voxel_batch.to(device)
        condition_batch = condition_batch.to(device)

        # Model forward pass

        pred = model(voxel_batch,condition_batch, noise)  # Ensure model supports conditioning
        noise = noise.unsqueeze(1)
        loss = loss_fn(pred, noise)

        # Backpropagation
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Update EMA
        if global_steps % model_ema_steps == 0:
            model_ema.update_parameters(model)

        # Logging
        if global_steps % 100 == 0:
            print(f"Epoch[{epoch + 1}/{num_epochs}], Step[{global_steps}], Loss: {loss.item():.4f}, lr: {lr_scheduler.get_last_lr()[0]:.6f}")

        # Save best model
        if loss.item() < min_loss and epoch > 1:
            min_loss = loss.item()
            torch.save(
                {"model": model.state_dict(), "model_ema": model_ema.state_dict()},
                f"results/best.pt"
            )
        global_steps += 1

    # Save checkpoint
    torch.save(
        {"model": model.state_dict(), "model_ema": model_ema.state_dict()},
        f"condition_results/steps_{global_steps:08d}.pt"
    )

print("Training complete!")

Device: cuda
Loading best checkpoint...
Checkpoint loaded successfully!
Epoch[1/70], Step[0], Loss: 0.0722, lr: 0.000000
Epoch[1/70], Step[100], Loss: 0.0067, lr: 0.000010
Epoch[1/70], Step[200], Loss: 0.0114, lr: 0.000010
Epoch[1/70], Step[300], Loss: 0.0054, lr: 0.000010
Epoch[1/70], Step[400], Loss: 0.0219, lr: 0.000010
Epoch[1/70], Step[500], Loss: 0.0033, lr: 0.000010
Epoch[1/70], Step[600], Loss: 0.0018, lr: 0.000010
Epoch[1/70], Step[700], Loss: 0.0827, lr: 0.000010
Epoch[1/70], Step[800], Loss: 0.0825, lr: 0.000010
Epoch[1/70], Step[900], Loss: 0.0106, lr: 0.000010
Epoch[1/70], Step[1000], Loss: 0.0145, lr: 0.000010
Epoch[1/70], Step[1100], Loss: 0.0243, lr: 0.000010
Epoch[1/70], Step[1200], Loss: 0.0119, lr: 0.000010
Epoch[1/70], Step[1300], Loss: 0.0045, lr: 0.000010
Epoch[1/70], Step[1400], Loss: 0.0064, lr: 0.000010
Epoch[1/70], Step[1500], Loss: 0.0029, lr: 0.000010
Epoch[1/70], Step[1600], Loss: 0.0153, lr: 0.000010
Epoch[1/70], Step[1700], Loss: 0.0762, lr: 0.000010
Epoc

In [6]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Paths and directories
checkpoint_path = "condition_results/steps_00140000.pt"
save_dir = "generated_samples"
os.makedirs(save_dir, exist_ok=True)

# Load the trained model and EMA
model = Diffusion_Control(
    timesteps=1000,
    image_size=64,
    in_channels=1,
    base_dim=32,
    dim_mults=[1, 2, 4, 8]
).to(device)

model_ema = ExponentialMovingAverage(model, device=device, decay=1.0 - 0.995)

if os.path.exists(checkpoint_path):
    print("Loading best checkpoint...")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["model"])
    model_ema.load_state_dict(checkpoint["model_ema"])
    print("Checkpoint loaded successfully!")

model.eval()  # Switch to evaluation mode
model_ema.eval()

# Data loader for control dataset
# control_dataset_loader = ...  # Define your control dataset loader here
condition_dataloader = DataLoader(condition_dataset, batch_size=1, shuffle=True)
samples = []
with torch.no_grad():
    for i, condition_batch in enumerate(condition_dataloader):
        condition_batch = condition_batch.to(device)

        # Generate noise for sampling
        batch_size = condition_batch.size(0)
        noise = torch.randn(1, 1, 64, 64,64).to(device)

        # Perform sampling with the model
        generated_samples = model_ema.module.sampling(batch_size, condition_batch, noise)
        
        # Process samples
        for j in range(batch_size):
            voxel_1d_array = generated_samples[j].cpu().numpy()

            # Convert to binary voxel data
            binary_data = (voxel_1d_array > 0.5).astype(int)

            # Parameters for voxel map
            voxel_size = 0.25
            grid_size = 64

            # Reshape binary data into 8x8x8
            reshaped_data = binary_data.reshape(grid_size, 1, grid_size, 1, grid_size, 1).mean(axis=(1, 3, 5))
            voxel_data = (reshaped_data > 0.5).astype(int)

            # Prepare the 3D plot
            fig = plt.figure(figsize=(10, 10))
            ax = fig.add_subplot(111, projection='3d')

            # Create a 3D grid for the voxel map
            x, y, z = np.indices((grid_size + 1, grid_size + 1, grid_size + 1)) * voxel_size

            # Display voxels
            filled_voxels = (voxel_data == 1)
            ax.voxels(x, y, z, filled_voxels, facecolors="blue", edgecolors="black", alpha=0.7)

            # Set labels and aspect ratio
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')
            ax.set_aspect('auto')
            plt.title("8x8x8 Voxel Map")

            # Save the plot
            save_path = os.path.join(save_dir, f"sample_{i * batch_size + j + 1}.png")
            plt.savefig(save_path)
            plt.close(fig)

        # Break loop after generating desired number of samples
        if len(samples) >= 2:
            break

print("Generation complete!")


Device: cuda
Loading best checkpoint...
Checkpoint loaded successfully!


Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.68it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.61it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.68it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.60it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.61it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.57it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.47it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.50it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.36it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.46it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.42it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.44it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.37it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.62it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.59it/s]
Sampling: 100%|██████████| 1000/1000 [00:32<00:00, 30.65it/s]
Sampling