Imports

In [None]:
import time
time.sleep(10800)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import trimesh
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import trimesh.exchange.binvox as binvox

The Autoencoder with Conv3D layers.

The input is size is 64.

In [2]:
class CNN3DVAE(nn.Module):
    def __init__(self, latent_dim=128):
        super(CNN3DVAE, self).__init__()

        reshape = 4

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=4, stride=2, padding=1),  # [32, 32, 32, 32]
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=4, stride=2, padding=1),  # [64, 16, 16, 16]
            nn.ReLU(),
            nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1),  # [128, 8, 8, 8]
            nn.ReLU(),
            nn.Conv3d(128, 256, kernel_size=4, stride=2, padding=1),  # [256, 4, 4, 4]
            nn.ReLU(),
            nn.Conv3d(256, 512, kernel_size=4, stride=2, padding=1),  # [256, 4, 4, 4]
            nn.ReLU()
        )

        self.fc_mu = nn.Linear(512 * reshape * reshape * reshape, latent_dim)
        self.fc_logvar = nn.Linear(512 * reshape * reshape * reshape, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, 512 * reshape * reshape * reshape)

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(512, 256, kernel_size=4, stride=2, padding=1),  # [128, 8, 8, 8]
            nn.ReLU(),
            nn.ConvTranspose3d(256, 128, kernel_size=4, stride=2, padding=1),  # [128, 8, 8, 8]
            nn.ReLU(),
            nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1),  # [64, 16, 16, 16]
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1),  # [32, 32, 32, 32]
            nn.ReLU(),
            nn.ConvTranspose3d(32, 1, kernel_size=4, stride=2, padding=1),  # [1, 64, 64, 64]
            nn.Sigmoid()  # Output between 0 and 1
        )

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.fc_decode(z)
        x = x.view(x.size(0), 512, 4, 4, 4)  # Reshape to convolutional shape
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Loss function
def vae_loss(recon_x, target_x, mu, logvar):
    recon_loss = nn.functional.binary_cross_entropy(recon_x, target_x, reduction='sum')
    # KL Divergence Loss
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss


Instantiate the model, set loss function and optimizer.

Put the model on to the device

Create the VoxelDataset for dataset loading. It does loads in the object from obj, converts it to voxel_grid and pads the grid to shape (64, 64, 64)

In [3]:
class VoxelDataset(Dataset):
    def __init__(self, data_path, labels_path):
        self.data_path = data_path
        self.labels_path = labels_path

        self.data_files = os.listdir(data_path)
        self.label_files = os.listdir(labels_path)

    
    def __len__(self):
        return len(self.data_files)

    def load_pre_processed_voxel(self, file_path):

        with open(file_path, 'rb') as f:
            voxel_grid = binvox.load_binvox(f)
        
        return voxel_grid


    def __getitem__(self, idx):

        
        train_mesh_file = os.path.join(self.data_path, self.data_files[idx])
        train_voxel = self.load_pre_processed_voxel(train_mesh_file)
        train_voxel_matrix = train_voxel.matrix.astype(np.float32) 

        label_mesh_file = os.path.join(self.labels_path, self.label_files[idx])
        label_voxel = self.load_pre_processed_voxel(label_mesh_file)
        label_voxel_matrix = label_voxel.matrix.astype(np.float32)

        input_voxel = np.expand_dims(train_voxel_matrix, axis=0)
        target_voxel = np.expand_dims(label_voxel_matrix, axis=0)


        return torch.tensor(input_voxel), torch.tensor(target_voxel)       

Input the data, set voxel resolution and set dataloader. 

Batch size of 1 because only 1 data object.

In [4]:
# Instantiate the model and optimizer
latent_dim = 512
model = CNN3DVAE(latent_dim=latent_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [5]:
batch_size = 16

# Create dataset and dataloader for training
voxel_dataset = VoxelDataset("Data/train5/Pre-processed/128/train", "Data/train5/Pre-processed/128/labels")
#voxel_dataloader = DataLoader(voxel_dataset, batch_size=batch_size, shuffle=True)

In [6]:
from torch.utils.data import random_split, DataLoader

# Example: 80% training, 20% validation
train_size = int(0.8 * len(voxel_dataset))
val_size = len(voxel_dataset) - train_size

# Split dataset into training and validation sets
train_dataset, val_dataset = random_split(voxel_dataset, [train_size, val_size])

# Create DataLoaders for both sets
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
import os
import time

def train_vae_3d(model, train_loader, val_loader, optimizer, epochs=20, device='cuda', checkpoint_path='vae_checkpoint.pth'):
    # Check if checkpoint exists and load it
    start_epoch = 0
    best_val_loss = float('inf')
    model.to(device)  # Move the model to the correct device
    print(f"Training on {device}, model now on {device}")

    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)  # Ensure checkpoint is loaded on the correct device
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1  # Start training from the next epoch
        best_val_loss = checkpoint['best_loss']
        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}, with best loss {best_val_loss:.4f}")

    model.train()
    print(f"Starting training from epoch {start_epoch}")

    for epoch in range(start_epoch, epochs):
        # Track time for each epoch
        start_time = time.time()
        train_loss = 0
        
        # Training loop
        model.train()  # Set model to training mode
        for input_batch, expected_batch in train_loader:

            input_batch = input_batch.to(device)
            expected_batch = expected_batch.to(device)

            optimizer.zero_grad()
            recon_batch, mu, logvar = model(input_batch)  # Forward pass


            loss = vae_loss(recon_batch, expected_batch, mu, logvar)  # Compute loss
            loss.backward()
            train_loss += loss.item()
            optimizer.step()

        avg_train_loss = train_loss / len(train_loader.dataset)

        # Validation loop
        val_loss = 0
        model.eval()  # Set model to evaluation mode
        with torch.no_grad():  # No need to compute gradients for validation
            for input_batch, expected_batch in val_loader:
                input_batch = input_batch.to(device)
                expected_batch = expected_batch.to(device)

                recon_batch, mu, logvar = model(input_batch)  # Forward pass
                loss = vae_loss(recon_batch, expected_batch, mu, logvar)  # Compute loss
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader.dataset)

        # Track time at the end of the epoch
        end_time = time.time()
        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {end_time - start_time:.2f} seconds")

        # Save model checkpoint if validation loss is the best
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_loss': best_val_loss
            }, checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch+1}, with validation loss {best_val_loss:.4f}")


Train the model

In [None]:
train_vae_3d(model, train_loader, val_loader, optimizer, epochs=2000, device='cuda')

Training on cuda, model now on cuda


  checkpoint = torch.load(checkpoint_path, map_location=device)  # Ensure checkpoint is loaded on the correct device


Loaded checkpoint from epoch 0, with best loss 38703.2702
Starting training from epoch 1


Loaded checkpoint from epoch 2, with best loss 81482.4339
Starting training from epoch 3
Epoch 4, Train Loss: 78760.3150, Val Loss: 78828.6544, Time: 110.61 seconds
Checkpoint saved at epoch 4, with validation loss 78828.6544
Epoch 5, Train Loss: 78374.0195, Val Loss: 78582.1556, Time: 109.58 seconds
Checkpoint saved at epoch 5, with validation loss 78582.1556
Epoch 6, Train Loss: 77333.7012, Val Loss: 77044.7628, Time: 107.79 seconds
Checkpoint saved at epoch 6, with validation loss 77044.7628
Epoch 7, Train Loss: 76398.3599, Val Loss: 77033.0401, Time: 107.53 seconds
Checkpoint saved at epoch 7, with validation loss 77033.0401
Epoch 8, Train Loss: 75108.3008, Val Loss: 74105.2728, Time: 107.97 seconds
Checkpoint saved at epoch 8, with validation loss 74105.2728
Epoch 9, Train Loss: 73875.8382, Val Loss: 74564.8534, Time: 107.26 seconds
Epoch 10, Train Loss: 73709.9448, Val Loss: 72756.2539, Time: 107.31 seconds
Checkpoint saved at epoch 10, with validation loss 72756.2539
Epoch 11, Train Loss: 72947.6701, Val Loss: 73899.4487, Time: 107.45 seconds
Epoch 12, Train Loss: 72084.1303, Val Loss: 72219.8246, Time: 107.40 seconds
Checkpoint saved at epoch 12, with validation loss 72219.8246
Epoch 13, Train Loss: 71603.3918, Val Loss: 71182.3434, Time: 111.30 seconds
Checkpoint saved at epoch 13, with validation loss 71182.3434
Epoch 14, Train Loss: 71557.7964, Val Loss: 71981.1636, Time: 113.69 seconds
Epoch 15, Train Loss: 71426.3750, Val Loss: 70903.6943, Time: 113.50 seconds
Checkpoint saved at epoch 15, with validation loss 70903.6943
Epoch 16, Train Loss: 70913.7461, Val Loss: 70715.3565, Time: 111.35 seconds
Checkpoint saved at epoch 16, with validation loss 70715.3565
Epoch 17, Train Loss: 70652.2332, Val Loss: 71125.3929, Time: 110.15 seconds
Epoch 18, Train Loss: 70184.2880, Val Loss: 69756.0963, Time: 110.55 seconds
Checkpoint saved at epoch 18, with validation loss 69756.0963
Epoch 19, Train Loss: 70230.5434, Val Loss: 69405.5184, Time: 112.12 seconds
Checkpoint saved at epoch 19, with validation loss 69405.5184
Epoch 20, Train Loss: 69758.8516, Val Loss: 72031.8733, Time: 112.99 seconds
Epoch 21, Train Loss: 69550.8717, Val Loss: 69890.1854, Time: 111.46 seconds
Epoch 22, Train Loss: 69617.5468, Val Loss: 71135.8658, Time: 111.18 seconds
Epoch 23, Train Loss: 69368.2436, Val Loss: 68601.0923, Time: 111.74 seconds
Checkpoint saved at epoch 23, with validation loss 68601.0923
Epoch 24, Train Loss: 68895.7006, Val Loss: 68567.4258, Time: 110.98 seconds
Checkpoint saved at epoch 24, with validation loss 68567.4258
Epoch 25, Train Loss: 68823.1834, Val Loss: 68181.7952, Time: 112.11 seconds
Checkpoint saved at epoch 25, with validation loss 68181.7952
Epoch 26, Train Loss: 69036.1165, Val Loss: 69682.0627, Time: 114.03 seconds
Epoch 27, Train Loss: 68734.5376, Val Loss: 68092.8073, Time: 110.79 seconds
Checkpoint saved at epoch 27, with validation loss 68092.8073
Epoch 28, Train Loss: 68322.9151, Val Loss: 68022.2327, Time: 110.74 seconds
Checkpoint saved at epoch 28, with validation loss 68022.2327
Epoch 29, Train Loss: 68399.0886, Val Loss: 68552.3808, Time: 107.57 seconds
Epoch 30, Train Loss: 68623.9734, Val Loss: 68372.5990, Time: 112.72 seconds
Epoch 31, Train Loss: 68390.2851, Val Loss: 70227.0090, Time: 109.74 seconds
Epoch 32, Train Loss: 68243.8787, Val Loss: 67727.8196, Time: 108.38 seconds
Checkpoint saved at epoch 32, with validation loss 67727.8196
Epoch 33, Train Loss: 67820.6121, Val Loss: 68789.5640, Time: 109.06 seconds
Epoch 34, Train Loss: 67856.4738, Val Loss: 67286.1321, Time: 108.65 seconds
Checkpoint saved at epoch 34, with validation loss 67286.1321
Epoch 35, Train Loss: 68019.6468, Val Loss: 68550.6108, Time: 109.48 seconds
Epoch 36, Train Loss: 67811.4905, Val Loss: 72664.1964, Time: 109.53 seconds
Epoch 37, Train Loss: 67780.5393, Val Loss: 67641.0478, Time: 109.29 seconds
Epoch 38, Train Loss: 67641.5571, Val Loss: 66998.9193, Time: 110.17 seconds
Checkpoint saved at epoch 38, with validation loss 66998.9193
Epoch 39, Train Loss: 67849.5221, Val Loss: 67176.3940, Time: 109.53 seconds
Epoch 40, Train Loss: 67861.7164, Val Loss: 67395.5032, Time: 108.60 seconds
Epoch 41, Train Loss: 67428.7246, Val Loss: 67391.4132, Time: 110.91 seconds
Epoch 42, Train Loss: 67550.3604, Val Loss: 67247.8164, Time: 113.49 seconds