In [1]:
import os
import numpy as np
import trimesh
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

def get_voxel_matrix(mesh):
    voxel = mesh.voxelized(pitch=1.0 / 64)
    voxel_matrix = voxel.matrix.astype(np.float32)
    dimension = 512
    target_shape = (dimension, dimension, dimension)
    
    # Initialize a padded matrix with zeros
    padded_matrix = np.zeros(target_shape, dtype=np.float32)
    offset = 128

    # Ensure we don't go out of bounds
    if voxel_matrix.shape[0] + offset > dimension:
        offset = dimension - voxel_matrix.shape[0]
    if voxel_matrix.shape[1] + offset > dimension:
        offset = dimension - voxel_matrix.shape[1]
    if voxel_matrix.shape[2] + offset > dimension:
        offset = dimension - voxel_matrix.shape[2]

    # Place the original matrix in the padded matrix
    padded_matrix[offset:offset + voxel_matrix.shape[0],
                  offset:offset + voxel_matrix.shape[1],
                  offset:offset + voxel_matrix.shape[2]] = voxel_matrix
    
    padded_voxelized = trimesh.voxel.VoxelGrid(padded_matrix)
    return padded_voxelized.matrix.astype(np.float32)

def get_coronal_view(voxel_matrix):
    # Use numpy's transpose to avoid loop
    coronal_grid = np.transpose(voxel_matrix, (1, 0, 2))
    return coronal_grid

def get_item_data(mesh):
    voxel_matrix = get_voxel_matrix(mesh)
    return get_coronal_view(voxel_matrix)

def process_file(filename):
    mesh = trimesh.load_mesh('dataset_3d/train/train/' + filename)
    ground_truth = trimesh.load_mesh('dataset_3d/ground_truth/ground_truth/' + filename)

    coronal_grid = get_item_data(mesh)
    gt_coronal_grid = get_item_data(ground_truth)

    base_filename = filename.split('.')[0]
    np.save(f'coronal_dataset/train/{base_filename}.npy', coronal_grid)
    np.save(f'coronal_dataset/ground_truth/{base_filename}.npy', gt_coronal_grid)

os.makedirs('coronal_dataset', exist_ok=True)
os.makedirs('coronal_dataset/train', exist_ok=True)
os.makedirs('coronal_dataset/ground_truth', exist_ok=True)

# Create a thread pool to process files concurrently
with ThreadPoolExecutor() as executor:
    futures = []
    for filename in tqdm(os.listdir('dataset_3d/train/train')):
        futures.append(executor.submit(process_file, filename))

    # Optionally wait for all futures to complete and handle exceptions
    for future in as_completed(futures):
        try:
            future.result()  # This will raise any exception caught during processing
        except Exception as e:
            print(f'Error processing file: {e}')


100%|██████████| 8473/8473 [00:00<00:00, 111374.37it/s]
