In [1]:
import os
import pandas as pd
import trimesh
from torch.utils.data import Dataset
import numpy as np
import torch


In [2]:
def resize_array_if_large(arr):
    shape = arr.shape
    
    if any(dim > 256 for dim in shape):
        resized_arr = arr[:256, :256, :256]
    else:
        resized_arr = arr

    return resized_arr

In [3]:
def get_voxel_matrix(mesh):
    voxel = mesh.voxelized(pitch=1.0/64)
    voxel_matrix = voxel.matrix.astype(np.float32) 
    
    target_shape = [256, 256, 256]

    voxel_matrix = resize_array_if_large(voxel_matrix)
    
    padded_matrix = np.zeros((256,256,256))
    offset_x = 0
    offset_y = 0
    offset_z = 0
    
    # Place the original matrix in the padded matrix
    padded_matrix[0 + offset_x : offset_x + voxel_matrix.shape[0] ,
                0  + offset_y : offset_y + voxel_matrix.shape[1],
                0 + offset_z : offset_z + voxel_matrix.shape[2]] = voxel_matrix
    
    padded_voxelized = trimesh.voxel.VoxelGrid(padded_matrix)
    voxel_matrix = padded_voxelized.matrix.astype(np.float32) 
    return voxel_matrix

def create_view(voxel_matrix):
    axial_grid = [voxel_matrix[:, :, i] for i in range(voxel_matrix.shape[2])]
    coronal_grid = [voxel_matrix[:, i, :] for i in range(voxel_matrix.shape[1])]
    sagittal_grid = [voxel_matrix[i, :, :] for i in range(voxel_matrix.shape[0])]
    axial_grid = np.array(axial_grid)
    coronal_grid = np.array(coronal_grid)
    sagittal_grid = np.array(sagittal_grid)
    
    first_view = np.zeros((256,256,256,3))
    first_view[..., 0] = axial_grid
    first_view[..., 1] = sagittal_grid
    first_view[..., 2] = coronal_grid
    return first_view

def transform_mesh(mesh):
    voxel_matrix = get_voxel_matrix(mesh)
    view = create_view(voxel_matrix)
    return view


In [4]:
def get_slices(obj):
    axial_slice = obj[:, :, 0]
    sagittal_slice = obj[:, :, 1]
    coronal_slice = obj[:, :, 2]

    combined_slice = torch.stack([torch.tensor(axial_slice, dtype=torch.float32),
                                    torch.tensor(sagittal_slice, dtype=torch.float32),
                                    torch.tensor(coronal_slice, dtype=torch.float32)], dim=0)
    return combined_slice

In [5]:
class SliceVoxelDataset(Dataset):
    def __init__(self, train_path, label_path):
        self.label_dir = label_path
        self.train_dir = train_path

        self.label_list = os.listdir(self.label_dir)
        self.obj_list = os.listdir(self.label_dir)

    def __len__(self):
        return len(self.obj_list)

    def __getitem__(self, idx):

        train_obj = trimesh.load_mesh(self.train_dir + '/' + self.obj_list[idx])
        label_obj = trimesh.load_mesh(self.label_dir + '/' + self.label_list[idx])

        train_obj = transform_mesh(train_obj)
        label_obj = transform_mesh(label_obj)

        train_obj = get_slices(train_obj)
        label_obj = get_slices(label_obj)
        
        print("Train shape:", train_obj.shape)
        print("Label shape:", label_obj.shape)

        return train_obj, label_obj

In [6]:
voxel_dataset = SliceVoxelDataset("dataset_3d/train", "dataset_3d/ground_truth")

In [7]:
from torch.utils.data import random_split, DataLoader

# Example: 80% training, 20% validation
train_size = int(0.8 * len(voxel_dataset))
val_size = len(voxel_dataset) - train_size

# Split dataset into training and validation sets
train_dataset, val_dataset = random_split(voxel_dataset, [train_size, val_size])

# Create DataLoaders for both sets
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

In [8]:
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # [3, 256, 256] -> [32, 128, 128]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # [32, 128, 128] -> [64, 64, 64]
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(64 * 64 * 64, 128)
        self.fc_logvar = nn.Linear(64 * 64 * 64, 128)
        self.fc_decode = nn.Linear(128, 64 * 64 * 64)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # [64, 64, 64] -> [32, 128, 128]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),  # [32, 128, 128] -> [3, 256, 256]
            nn.Sigmoid(),  # To get output in the range [0, 1]
        )

    def encode(self, x):
        h1 = self.encoder(x)
        h1 = h1.view(h1.size(0), -1)  # Flatten
        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h3 = self.fc_decode(z)
        h3 = h3.view(-1, 64, 64, 64)
        return self.decoder(h3)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Initialize the model, optimizer, and loss function
model = VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [9]:
for epoch in range(10):  # Example of 10 epochs
    model.train()
    optimizer.zero_grad()
    print(train_dataset[0][0].shape)
    reconstructed, mu, logvar = model(train_dataset[0][0])
    
    # Calculate the loss
    reconstruction_loss = nn.functional.binary_cross_entropy(train_dataset.view(-1, 3 * 256 * 256), train_dataset.view(-1, 3 * 256 * 256), reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss = reconstruction_loss + kl_divergence
    
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

Train shape: torch.Size([3, 256, 256, 3])
Label shape: torch.Size([3, 256, 256, 3])
torch.Size([3, 256, 256, 3])
Train shape: torch.Size([3, 256, 256, 3])
Label shape: torch.Size([3, 256, 256, 3])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x4096 and 262144x128)