Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import trimesh.exchange.binvox as binvox
import torch
import torch.nn as nn
import trimesh

The Autoencoder with Conv3D layers.

The input is size is 64.

In [2]:
class CNN2DVAE(nn.Module):
    def __init__(self, latent_dim=128):
        super(CNN2DVAE, self).__init__()

        reshape = 8  # Adjusted based on downsampling

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # [32, 128, 128]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # [64, 64, 64]
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # [128, 32, 32]
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # [256, 16, 16]
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),  # [512, 8, 8]
            nn.ReLU()
        )

        self.fc_mu = nn.Linear(512 * reshape * reshape, latent_dim)
        self.fc_logvar = nn.Linear(512 * reshape * reshape, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, 512 * reshape * reshape)

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # [256, 16, 16]
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # [128, 32, 32]
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # [64, 64, 64]
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # [32, 128, 128]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),  # [1, 256, 256]
            nn.Sigmoid()  # Output between 0 and 1 for binary image
        )

    def encode(self, x):
        x = self.encoder(x)
        x = x.reshape(x.size(0), -1)  # Flatten
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = self.fc_decode(z)
        x = x.reshape(x.size(0), 512, 8, 8)  # Reshape to convolutional shape
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Loss function remains the same
def vae_loss(recon_x, target_x, mu, logvar):
    recon_loss = nn.functional.binary_cross_entropy(recon_x, target_x, reduction='sum')
    # KL Divergence Loss
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss

Instantiate the model, set loss function and optimizer.

Put the model on to the device

Create the VoxelDataset for dataset loading. It does loads in the object from obj, converts it to voxel_grid and pads the grid to shape (64, 64, 64)

In [3]:
def resize_array_if_large(arr):
    shape = arr.shape
    
    if any(dim > 256 for dim in shape):
        resized_arr = arr[:256, :256, :256]
    else:
        resized_arr = arr

    return resized_arr

def get_voxel_matrix(mesh):
    voxel = mesh.voxelized(pitch=1.0/64)
    voxel_matrix = voxel.matrix.astype(np.float32) 
    
    target_shape = [256, 256, 256]

    voxel_matrix = resize_array_if_large(voxel_matrix)
    
    padded_matrix = np.zeros((256,256,256))
    offset_x = 0
    offset_y = 0
    offset_z = 0
    
    # Place the original matrix in the padded matrix
    padded_matrix[0 + offset_x : offset_x + voxel_matrix.shape[0] ,
                0  + offset_y : offset_y + voxel_matrix.shape[1],
                0 + offset_z : offset_z + voxel_matrix.shape[2]] = voxel_matrix
    
    padded_voxelized = trimesh.voxel.VoxelGrid(padded_matrix)
    voxel_matrix = padded_voxelized.matrix.astype(np.float32) 
    return voxel_matrix

def create_view(voxel_matrix):
    axial_grid = [voxel_matrix[:, :, i] for i in range(voxel_matrix.shape[2])]
    coronal_grid = [voxel_matrix[:, i, :] for i in range(voxel_matrix.shape[1])]
    sagittal_grid = [voxel_matrix[i, :, :] for i in range(voxel_matrix.shape[0])]
    axial_grid = np.array(axial_grid)
    coronal_grid = np.array(coronal_grid)
    sagittal_grid = np.array(sagittal_grid)
    
    first_view = np.zeros((256,256,256,3))
    first_view[..., 0] = axial_grid
    first_view[..., 1] = sagittal_grid
    first_view[..., 2] = coronal_grid
    return first_view

def transform_mesh(mesh):
    voxel_matrix = get_voxel_matrix(mesh)
    view = create_view(voxel_matrix)
    return view

def get_slices(obj):
    axial_slice = obj[:, :, 0]       # 256 x 256
    sagittal_slice = obj[:, 0, :]    # 256 x 256
    coronal_slice = obj[0, :, :]     # 256 x 256

    combined_slice = torch.stack([
        torch.tensor(axial_slice, dtype=torch.float32),
        torch.tensor(sagittal_slice, dtype=torch.float32),
        torch.tensor(coronal_slice, dtype=torch.float32)
    ], dim=0)  # Stack along a new dimension (3, 256, 256)
    
    return combined_slice

class Dataset2D(Dataset):
    def __init__(self, train_path, label_path):
        self.label_dir = label_path
        self.train_dir = train_path

        self.label_list = os.listdir(self.label_dir)
        self.obj_list = os.listdir(self.label_dir)

    def __len__(self):
        return len(self.obj_list)

    def __getitem__(self, idx):

        train_obj = trimesh.load_mesh(self.train_dir + '/' + self.obj_list[idx])
        label_obj = trimesh.load_mesh(self.label_dir + '/' + self.label_list[idx])

        train_obj = transform_mesh(train_obj)
        label_obj = transform_mesh(label_obj)

        train_obj = get_slices(train_obj)
        label_obj = get_slices(label_obj)
        
        print(train_obj.shape)
        print(label_obj.shape)
        

        return train_obj, label_obj

In [4]:
"""class Dataset2D(Dataset):
    def __init__(self, data_folder, label_folder):
        # Load list of files for both data and labels
        self.data_files = sorted([os.path.join(data_folder, f) for f in os.listdir(data_folder)])
        self.label_files = sorted([os.path.join(label_folder, f) for f in os.listdir(label_folder)])


        assert len(self.data_files) == len(self.label_files), "Mismatch between number of data and label files."

        self.num_slices_per_file = 256
        self.total_slices = len(self.data_files) * self.num_slices_per_file

    def __len__(self):
        return self.total_slices

    def __getitem__(self, idx):
        file_idx = idx // self.num_slices_per_file
        slice_idx = idx % self.num_slices_per_file

        data_file = self.data_files[file_idx]
        label_file = self.label_files[file_idx]
        
        data_array = np.load(data_file)
        label_array = np.load(label_file)

        data_slice = data_array[:, slice_idx, :, :]  
        label_slice = label_array[:, slice_idx, :, :]  

        # Convert to PyTorch tensors
        data_tensor = torch.tensor(data_slice, dtype=torch.float32)
        label_tensor = torch.tensor(label_slice, dtype=torch.float32)

        return data_tensor, label_tensor

"""

'class Dataset2D(Dataset):\n    def __init__(self, data_folder, label_folder):\n        # Load list of files for both data and labels\n        self.data_files = sorted([os.path.join(data_folder, f) for f in os.listdir(data_folder)])\n        self.label_files = sorted([os.path.join(label_folder, f) for f in os.listdir(label_folder)])\n\n\n        assert len(self.data_files) == len(self.label_files), "Mismatch between number of data and label files."\n\n        self.num_slices_per_file = 256\n        self.total_slices = len(self.data_files) * self.num_slices_per_file\n\n    def __len__(self):\n        return self.total_slices\n\n    def __getitem__(self, idx):\n        file_idx = idx // self.num_slices_per_file\n        slice_idx = idx % self.num_slices_per_file\n\n        data_file = self.data_files[file_idx]\n        label_file = self.label_files[file_idx]\n        \n        data_array = np.load(data_file)\n        label_array = np.load(label_file)\n\n        data_slice = data_array[:,

Input the data, set voxel resolution and set dataloader. 

Batch size of 1 because only 1 data object.

In [5]:
# Instantiate the model and optimizer
latent_dim = 512
model = CNN2DVAE(latent_dim=latent_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [6]:
batch_size = 128

# Create dataset and dataloader for training
voxel_dataset = Dataset2D("dataset_3d/train/train", "dataset_3d/ground_truth/ground_truth")
#voxel_dataloader = DataLoader(voxel_dataset, batch_size=batch_size, shuffle=True)

In [7]:
from torch.utils.data import random_split, DataLoader

# Example: 80% training, 20% validation
train_size = int(0.8 * len(voxel_dataset))
val_size = len(voxel_dataset) - train_size

# Split dataset into training and validation sets
train_dataset, val_dataset = random_split(voxel_dataset, [train_size, val_size])
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")

# Create DataLoaders for both sets
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Training set size: 6779
Validation set size: 1695


In [8]:
import os
import time

def train_vae_3d(model, train_loader, val_loader, optimizer, epochs=20, device='cuda', checkpoint_path='vae_checkpoint_2D.pth'):
    # Check if checkpoint exists and load it
    start_epoch = 0
    best_val_loss = float('inf')
    model.to(device)  # Move the model to the correct device
    print(f"Training on {device}, model now on {device}")

    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)  # Ensure checkpoint is loaded on the correct device
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1  # Start training from the next epoch
        best_val_loss = checkpoint['best_loss']
        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}, with best loss {best_val_loss:.4f}")

    model.train()
    print(f"Starting training from epoch {start_epoch}")

    for epoch in range(start_epoch, epochs):
        
        # Track time for each epoch
        start_time = time.time()
        train_loss = 0
        
        # Training loop
        model.train()  # Set model to training mode
        for input_batch, expected_batch in train_loader:
            print(input_batch.shape)
            


            input_batch = input_batch.to(device)
            expected_batch = expected_batch.to(device)

            optimizer.zero_grad()
            input_batch = input_batch.permute(0, 3, 1, 2) 
            recon_batch, mu, logvar = model(input_batch) 
            expected_batch = expected_batch.permute(0, 3, 1, 2)

            loss = vae_loss(recon_batch, expected_batch, mu, logvar)  
            loss.backward()
            train_loss += loss.item()
            optimizer.step()

        avg_train_loss = train_loss / len(train_loader.dataset)

        # Validation loop
        val_loss = 0
        model.eval()  # Set model to evaluation mode
        with torch.no_grad():  # No need to compute gradients for validation
            for input_batch, expected_batch in val_loader:
                input_batch = input_batch.to(device)
                expected_batch = expected_batch.to(device)

                recon_batch, mu, logvar = model(input_batch)  # Forward pass
                loss = vae_loss(recon_batch, expected_batch, mu, logvar)  # Compute loss
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader.dataset)

        # Track time at the end of the epoch
        end_time = time.time()
        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {end_time - start_time:.2f} seconds")

        # Save model checkpoint if validation loss is the best
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_loss': best_val_loss
            }, checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch+1}, with validation loss {best_val_loss:.4f}")


Train the model

In [9]:
train_vae_3d(model, train_loader, val_loader, optimizer, epochs=2000, device='mps')

Training on mps, model now on mps
Starting training from epoch 0
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
torch.S

KeyboardInterrupt: 