In [None]:
# pip install -U trimesh[easy]

In [None]:
# pip install pyglet==1.5.11

In [1]:
import numpy as np
import pyglet
import trimesh
import os
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import gzip
import io
import matplotlib.pyplot as plt

  warn(f"Failed to load image Python extension: {e}")


In [2]:
class McGillDataset(Dataset):
    def __init__(self, root_dir: str, file_type: str = 'ply', transform=None):
        """
        Custom dataset for McGill dataset.

        Parameters:
        - root_dir (str): Root directory of the dataset.
        - file_type (str, optional): Type of file to load data from. Defaults to 'ply'.
        - transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.file_type = file_type
        self.transform = transform

        # List all the categories (subdirectories) in the root directory
        self.categories = os.listdir(root_dir)

        # Create a dictionary to map category names to their respective subdirectories
        self.category_paths = {category: os.path.join(root_dir, category) for category in self.categories}

        self.all_data_paths = self.get_all_data_paths()

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.all_data_paths)

    def __getitem__(self, idx):
        """
        Returns a sample from the dataset.

        Parameters:
        - idx (int): Index of the sample.

        Returns:
        - sample (dict): A dictionary containing 'data' (3D mesh) and 'label' (category name).
        """
        # Load the data (3D mesh) using the appropriate method
        data_path = self.all_data_paths[idx]
        category = self.get_category_from_path(data_path)

        if self.file_type == 'Im':
            with gzip.open(data_path, 'rb') as f:
                data = Image.open(io.BytesIO(f.read()))
        elif self.file_type == 'Ply':
            if data_path.endswith('.gz'):
                with gzip.open(data_path, 'rb') as f:
                    file_extension = os.path.splitext(data_path[:-3])[1]
                    data = trimesh.load(file_obj=io.BytesIO(f.read()), file_type=file_extension, process=False)
            else:
                file_extension = os.path.splitext(data_path)[1]
                data = trimesh.load(data_path, file_type=file_extension, process=False)

        # Apply the optional transform
        if self.transform:
            data = self.transform(data)

        # Prepare the sample dictionary
        sample = {'data': data, 'label': category}

        return sample

    def get_category_from_path(self, data_path):
        """
        Extract category name from the data path.

        Parameters:
        - data_path (str): Full path to the data file.

        Returns:
        - category (str): Category name.
        """
        # Extract category name from the path
        category = os.path.basename(os.path.dirname(os.path.dirname(data_path)))
        return category


    def get_all_data_paths(self):
        """
        Return a list of all data paths in the dataset.
        """
        all_paths = []
        for category in self.categories:
            category_path = self.category_paths[category]
            data_paths = os.listdir(os.path.join(category_path, f'{category}{self.file_type}'))
            full_paths = [os.path.join(category_path, f'{category}{self.file_type}', data_name) for data_name in data_paths]
            all_paths.extend(full_paths)
        return all_paths

    def get_mesh(self, idx):
        """
        Return the 3D mesh.

        Parameters:
        - idx (int): Index of the sample.
        """
        return self.__getitem__(idx)['data']
    

In [3]:
root_dir = './articulated/'
# Create an instance of the McGillDataset class for Ply files
mcgill_dataset_ply = McGillDataset(root_dir, file_type='Ply')

In [5]:
print(len(mcgill_dataset_ply))
mcgill_dataset_ply.all_data_paths

255


['./articulated/ants\\antsPly\\1.ply.gz',
 './articulated/ants\\antsPly\\11.ply.gz',
 './articulated/ants\\antsPly\\12.ply.gz',
 './articulated/ants\\antsPly\\13.ply.gz',
 './articulated/ants\\antsPly\\14.ply.gz',
 './articulated/ants\\antsPly\\16.ply.gz',
 './articulated/ants\\antsPly\\17.ply.gz',
 './articulated/ants\\antsPly\\18.ply.gz',
 './articulated/ants\\antsPly\\19.ply.gz',
 './articulated/ants\\antsPly\\2.ply.gz',
 './articulated/ants\\antsPly\\20.ply.gz',
 './articulated/ants\\antsPly\\21.ply.gz',
 './articulated/ants\\antsPly\\22.ply.gz',
 './articulated/ants\\antsPly\\23.ply.gz',
 './articulated/ants\\antsPly\\26.ply.gz',
 './articulated/ants\\antsPly\\27.ply.gz',
 './articulated/ants\\antsPly\\28.ply.gz',
 './articulated/ants\\antsPly\\29.ply.gz',
 './articulated/ants\\antsPly\\3.ply.gz',
 './articulated/ants\\antsPly\\30.ply.gz',
 './articulated/ants\\antsPly\\31.ply.gz',
 './articulated/ants\\antsPly\\32.ply.gz',
 './articulated/ants\\antsPly\\33.ply.gz',
 './articulate

In [None]:
def generate2dImage(mesh,input_file_name,item): 
    # Assuming meshes is an array of Trimesh objects
    scene = mesh.scene()
    label = item['label']
    # Define your ranges for theta and phi

    number_of_rotation = 3
    theta = 2*np.pi/number_of_rotation

    for i in range(number_of_rotation):
        for j in range(number_of_rotation):
            # Rotation matrix around the Y-axis (theta)
            rotate_y = trimesh.transformations.rotation_matrix(
                angle=theta, direction=[0, 1, 0], point=scene.centroid)

            # Rotation matrix around the X-axis (phi)
            rotate_x = trimesh.transformations.rotation_matrix(
                angle=theta, direction=[1, 0, 0], point=scene.centroid)

            # Combine the rotations
            rotate_combined = trimesh.transformations.concatenate_matrices(rotate_x, rotate_y)
            # Apply the combined transform to the camera view transform
            camera_old, _geometry = scene.graph[scene.camera.name]
            camera_new = np.dot(rotate_combined, camera_old)
            scene.graph[scene.camera.name] = camera_new

            if not os.path.exists(f'2dImages\{label}\{input_file_name}'):
                os.makedirs(f'2dImages\{label}\{input_file_name}')

            # Save the rendered image
            try:
                # Increment the file name
                file_name = os.path.join(f'2dImages\{label}\{input_file_name}',f"{input_file_name}_{i}{j}.png")
                # Save a render of the object as a png
                png = scene.save_image(resolution=[256, 256])
                with open(file_name, "wb") as f:
                    f.write(png)
                    f.close()
            except BaseException as E:
                trimesh.constants.log.debug("Unable to save image", str(E))

In [None]:
generate2dImage(mcgill_dataset_ply.get_mesh(0),os.path.basename(mcgill_dataset_ply.all_data_paths[0]),mcgill_dataset_ply[0])

In [None]:
for i in range(len(mcgill_dataset_ply) - 1, -1, -1):
    # print(f"{i/len(mcgill_dataset_ply)*100}%")
    generate2dImage(mcgill_dataset_ply.get_mesh(i),os.path.basename(mcgill_dataset_ply.all_data_paths[i]),mcgill_dataset_ply[i])