# In [None]:
import os
import numpy as np
import trimesh
from skimage.measure import block_reduce
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Target voxel grid size (D, H, W); every mesh grid is reduced/padded/truncated
# to this shape. (32, 32, 32) is an alternative higher-resolution setting.
desired_shape = (16, 16, 16)
# Voxelization pitch passed to trimesh (presumably in the meshes' own
# coordinate units — confirm the source data's scale).
desired_voxel_size = 1.00

base_dir = 'Big-Data-CNN-3D-object-detection/data'

# Each subdirectory of base_dir is one class. Stray files (README, .DS_Store,
# ...) are skipped: they would otherwise be treated as a category and crash
# when listed as a directory. Sorting makes the order deterministic, since
# os.listdir returns entries in arbitrary order.
categories = sorted(
    entry
    for entry in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, entry))
)

# Integer class ids, aligned element-wise with `categories`.
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(categories)

# Collect labels during the generation process (appended to by generator()).
labels = []

def _voxelize_obj(path):
    """Load the mesh at *path* and return a float32 grid of shape desired_shape + (1,).

    The mesh is voxelized with pitch ``desired_voxel_size``. If the grid does
    not already have ``desired_shape``, it is max-pooled down with
    ``block_reduce``. It is then zero-padded and truncated to exactly
    ``desired_shape``, and a trailing channel axis is added.
    """
    mesh = trimesh.load_mesh(path)
    if isinstance(mesh, trimesh.Scene):
        # Files holding several geometries load as a Scene; merge them.
        mesh = trimesh.util.concatenate(mesh.dump())

    voxels = mesh.voxelized(pitch=desired_voxel_size).matrix
    if voxels.shape != desired_shape:
        # Per-axis integer downsampling factor. ceil() guarantees the reduced
        # size is <= the desired size on every axis (factor 1 where the grid
        # is already smaller).
        factor = tuple(
            int(np.ceil(n_in / n_out))
            for n_in, n_out in zip(voxels.shape, desired_shape)
        )
        voxels = block_reduce(voxels, block_size=factor, func=np.max)

    # Zero-pad any axis that came out short, then truncate any that is long.
    padding = [(0, max(0, want - have)) for have, want in zip(voxels.shape, desired_shape)]
    voxels = np.pad(voxels, padding, 'constant')
    voxels = voxels[:desired_shape[0], :desired_shape[1], :desired_shape[2]]

    # Cast explicitly so the dtype matches output_signature (float32) instead
    # of relying on tf.data's implicit conversion of the voxel matrix.
    return np.expand_dims(voxels, axis=-1).astype(np.float32)


def generator():
    """Yield ``(voxels, label)`` for every ``.obj`` file under ``base_dir``.

    Categories are the entries of ``categories``; each file's label is the
    matching entry of ``labels_encoded``. ``voxels`` is a float32 array of
    shape ``desired_shape + (1,)``.

    Side effect: each yielded label is appended to the module-level ``labels``
    list. tf.data calls this generator anew on every iteration of a dataset
    built from it, so ``labels`` grows by one full pass each time.
    """
    for label, category in zip(labels_encoded, categories):
        category_dir = os.path.join(base_dir, category)
        for file in os.listdir(category_dir):
            if not file.endswith('.obj'):
                continue
            voxels = _voxelize_obj(os.path.join(category_dir, file))
            # Record one label per yielded sample (previously one per category).
            labels.append(label)
            yield voxels, label
    print("Finished processing all .obj files.")

# Define the output signatures for the generator function
output_signature = (
    tf.TensorSpec(shape=(desired_shape[0], desired_shape[1], desired_shape[2], 1), dtype=tf.float32),
    tf.TensorSpec(shape=(), dtype=tf.int32)
)

# Datasets built with from_generator have UNKNOWN cardinality (-2). The split
# arithmetic below therefore used to yield take(-1) / skip(-1): every batch in
# train and an empty validation set. The generator yields exactly one element
# per .obj file, so count those up front and declare the cardinality
# (assert_cardinality raises if the count turns out wrong).
num_samples = sum(
    file.endswith('.obj')
    for category in categories
    for file in os.listdir(os.path.join(base_dir, category))
)

# Create a tf.data.Dataset from the generator
dataset = tf.data.Dataset.from_generator(
    generator,
    output_signature=output_signature
)
dataset = dataset.apply(tf.data.experimental.assert_cardinality(num_samples))

# Shuffle and batch the dataset. A fixed seed with reshuffle_each_iteration=False
# keeps the element order identical on every pass, so the take/skip split
# below stays disjoint. With the default per-epoch reshuffle, each epoch
# would draw a different split and leak validation samples into training.
# NOTE(review): the generator emits samples grouped by category; if
# num_samples is much larger than the 1000-element buffer, the split will be
# category-biased. Consider a larger buffer or a stratified file-level split.
batch_size = 16
dataset = dataset.shuffle(1000, seed=42, reshuffle_each_iteration=False).batch(batch_size)

# Split the dataset (in batches) into training (80%) and validation (the rest)
num_batches = int(dataset.cardinality().numpy())
train_size = int(0.8 * num_batches)
val_size = num_batches - train_size  # remainder, so no batch is dropped

train_dataset = dataset.take(train_size)
val_dataset = dataset.skip(train_size).take(val_size)

# Print out the shapes to check everything is as expected
print('Train Dataset element spec:', train_dataset.element_spec)
print('Validation Dataset element spec:', val_dataset.element_spec)