In [5]:
import os
import numpy as np
import torch
import torch.nn as nn

from model.unet import Unet3D
from model.Diffusion import Diffusion_Control
from model.EMA import ExponentialMovingAverage

from dataset.SimpleShapeDataset.VoxelData import VoxelDataset, ConditionDataset
from torch.utils.data import Dataset, DataLoader

from utils.visualization import visualize_voxel_map
from diffusers.optimization import get_cosine_schedule_with_warmup

In [6]:
voxel_dataset_dir = "dataset/SimpleShapeDataset/voxel_datasets"

voxel_dataset = VoxelDataset(voxel_dataset_dir)

# Create the DataLoader
batch_size = 2
voxel_dataloader = DataLoader(voxel_dataset, batch_size=batch_size, shuffle=True)


condition_dataset_dir = "dataset/SimpleShapeDataset/condition_datasets"

condition_dataset = ConditionDataset(voxel_dataset_dir)

# Create the DataLoader
batch_size = 2
condition_dataloader = DataLoader(condition_dataset, batch_size=batch_size, shuffle=True)

In [3]:
class RMSELoss(nn.Module):
    def __init__(self):
        super(RMSELoss, self).__init__()
        self.mse = nn.MSELoss(reduction='mean')

    def forward(self, output, target):
        return torch.sqrt(self.mse(output, target))

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

batch_size = 2
model_ema_steps = 10
num_epochs = 70
model_ema_decay = 0.995

# Define the model
model = Diffusion_Control(
    timesteps=1000,
    image_size=64,
    in_channels=1,
    base_dim=32,
    dim_mults=[1, 2, 4, 8]
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=(len(voxel_dataloader) * num_epochs),
)
adjust = 1 * batch_size * model_ema_steps / num_epochs
alpha = 1.0 - model_ema_decay
alpha = min(1.0, alpha * adjust)
model_ema = ExponentialMovingAverage(model, device=device, decay=1.0 - alpha)

loss_fn = nn.MSELoss(reduction='mean')
min_loss = np.inf
global_steps = 0

os.makedirs("results", exist_ok=True)
checkpoint_path = "results/steps_01244000.pt"

# Load checkpoint if exists
if os.path.exists(checkpoint_path):
    print("Loading best checkpoint...")
    checkpoint = torch.load(checkpoint_path)
    unet_state_dict = {k.replace("model.", " "): v for k, v in checkpoint['model'].items() if k.startswith("model.")}
    unet_state_dict = {k.strip(): v for k, v in unet_state_dict.items() if k.startswith(" ")}

    model.unet.load_state_dict(unet_state_dict)  # Load UNet parameters only
    # model_ema.load_state_dict(ema_state_dict)  
    print("Checkpoint loaded successfully!")

# Training loop
for epoch in range(num_epochs):
    model.train()
    for (voxel_batch, condition_batch) in zip(voxel_dataloader, condition_dataloader):
        # Prepare inputs
        noise = torch.randn_like(voxel_batch).to(device)
        voxel_batch = voxel_batch.to(device)
        condition_batch = condition_batch.to(device)

        # Model forward pass

        pred = model(voxel_batch,condition_batch, noise)  # Ensure model supports conditioning
        noise = noise.unsqueeze(1)
        loss = loss_fn(pred, noise)

        # Backpropagation
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Update EMA
        if global_steps % model_ema_steps == 0:
            model_ema.update_parameters(model)

        # Logging
        if global_steps % 100 == 0:
            print(f"Epoch[{epoch + 1}/{num_epochs}], Step[{global_steps}], Loss: {loss.item():.4f}, lr: {lr_scheduler.get_last_lr()[0]:.6f}")

        # Save best model
        if loss.item() < min_loss and epoch > 1:
            min_loss = loss.item()
            torch.save(
                {"model": model.state_dict(), "model_ema": model_ema.state_dict()},
                f"results/best.pt"
            )
        global_steps += 1

    # Save checkpoint
    torch.save(
        {"model": model.state_dict(), "model_ema": model_ema.state_dict()},
        f"condition_results/steps_{global_steps:08d}.pt"
    )

print("Training complete!")

Device: cuda
Loading best checkpoint...
Checkpoint loaded successfully!
Epoch[1/70], Step[1074000], Loss: 0.1951, lr: 0.000001
Epoch[1/70], Step[1074100], Loss: 0.0522, lr: 0.000100
Epoch[1/70], Step[1074200], Loss: 0.1357, lr: 0.000100
Epoch[1/70], Step[1074300], Loss: 0.3089, lr: 0.000100
Epoch[1/70], Step[1074400], Loss: 0.1186, lr: 0.000100
Epoch[1/70], Step[1074500], Loss: 0.0588, lr: 0.000100
Epoch[1/70], Step[1074600], Loss: 0.0775, lr: 0.000100
Epoch[1/70], Step[1074700], Loss: 0.0315, lr: 0.000100
Epoch[1/70], Step[1074800], Loss: 0.0510, lr: 0.000100
Epoch[1/70], Step[1074900], Loss: 0.0799, lr: 0.000100
Epoch[1/70], Step[1075000], Loss: 0.1750, lr: 0.000100
Epoch[1/70], Step[1075100], Loss: 0.0313, lr: 0.000100
Epoch[1/70], Step[1075200], Loss: 0.0396, lr: 0.000100
Epoch[1/70], Step[1075300], Loss: 0.1378, lr: 0.000100
Epoch[1/70], Step[1075400], Loss: 0.1062, lr: 0.000100
Epoch[1/70], Step[1075500], Loss: 0.0733, lr: 0.000100
Epoch[1/70], Step[1075600], Loss: 0.0884, lr: 0.

RuntimeError: Parent directory condition_results does not exist.

In [None]:
device = "cuda"
print(device)
batch_size = 2; model_ema_steps = 10;num_epochs = 20; model_ema_decay = 0.995

# Define the model
model = Diffusion(timesteps=1000,
                        image_size=64,
                        in_channels=1,
                        base_dim=32,
                        dim_mults=[1,2, 4,8]).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=100,
num_training_steps=(len(voxel_dataloader) * num_epochs),
)
adjust = 1* batch_size * model_ema_steps / num_epochs
alpha = 1.0 - model_ema_decay
alpha = min(1.0, alpha * adjust)
model_ema = ExponentialMovingAverage(model, device=device, decay=1.0 - alpha)
# Example usage:
# loss_fn = RMSELoss()
loss_fn = nn.MSELoss(reduction='mean')
min_loss = np.inf
global_steps = 40000

os.makedirs("results", exist_ok=True)
checkpoint_path = "results/steps_01174000.pt"

# if os.path.exists(checkpoint_path):
#     print("Loading best checkpoint...")

checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])

save_dir = "output_voxel_maps"
os.makedirs(save_dir, exist_ok=True)

model_ema.load_state_dict(checkpoint["model_ema"])
# model_ema = ExponentialMovingAverage(model, device=device, decay=1.0 - 0.1)
samples = []
for i in range(32):
        samples.append(model_ema.module.sampling(1, device=device))

for i in range(32):
        voxel_1d_array = samples[i].cpu().numpy()
        # voxel_1d_array += 1
        # voxel_1d_array *=0.5
        binary_data = (voxel_1d_array > 0.5).astype(int)
        import numpy as np
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
        # print(binary_data)
        # Parameters for voxel map
        voxel_size = 0.25
        grid_size = 64

        # Reshape binary data into 8x8x8 by averaging blocks of 4x4x4
        reshaped_data = binary_data.reshape(grid_size, 1,grid_size, 1, grid_size, 1).mean(axis=(1, 3, 5))
        voxel_data = (reshaped_data > 0.5).astype(int)  # Convert to binary based on average

        # Prepare the 3D plot
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection='3d')

        # Create a 3D grid for the voxel map (dimensions + 1 to align with voxel corners)
        x, y, z = np.indices((grid_size + 1, grid_size + 1, grid_size + 1)) * voxel_size

        # Display voxels
        filled_voxels = (voxel_data == 1)

        ax.voxels(x, y, z, filled_voxels, 
                facecolors="blue", edgecolors="black", alpha=0.7)

        # Set labels and aspect ratio
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_aspect('auto')
        plt.title("8x8x8 Voxel Map")

        # Show plot
        # plt.show()
        save_path = os.path.join(save_dir, f"sample_{i+33}.png")
        plt.savefig(save_path)
        plt.close(fig)

    # plt.savefig(save_dir)