# In [10]:
import numpy as np
from stl import mesh
from mpl_toolkits import mplot3d
from matplotlib import pyplot as plt
import cv2
import os

def generate_dataset(mesh_file, output_dir, angles, image_size=(256, 256)):
    """Render an STL mesh from several viewpoints and save image/mask pairs.

    For every viewpoint in *angles* the mesh is drawn with matplotlib's 3D
    toolkit and saved as ``image_{i}.png`` in *output_dir*. The PNG is then
    reloaded as grayscale, resized to *image_size*, and written back over the
    same path. A binary mask of the rendered object is written alongside it
    as ``mask_{i}.png`` (255 = object, 0 = background).

    Args:
        mesh_file: Path to the STL file to load.
        output_dir: Directory for the generated PNGs; created if missing.
        angles: Sequence of viewpoints. Element ``[0]`` is the elevation and
            ``[1]`` the azimuth (degrees), passed to ``Axes3D.view_init``.
        image_size: Output size as ``(width, height)``, the order
            ``cv2.resize`` expects. The tightly cropped render is stretched
            to this size, so its aspect ratio is not preserved.

    Raises:
        OSError: If a rendered image cannot be read back or a result cannot
            be written.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    your_mesh = mesh.Mesh.from_file(mesh_file)
    # Axis limits depend only on the mesh, not on the view, so compute once.
    scale = your_mesh.points.flatten()

    for i, angle in enumerate(angles):
        image_path = os.path.join(output_dir, f'image_{i}.png')

        fig = plt.figure()
        try:
            ax = fig.add_subplot(111, projection='3d')
            # A fresh collection per figure: a matplotlib artist cannot be
            # shared between figures.
            ax.add_collection3d(mplot3d.art3d.Poly3DCollection(your_mesh.vectors))
            ax.auto_scale_xyz(scale, scale, scale)
            ax.view_init(elev=angle[0], azim=angle[1])
            ax.axis('off')
            # Force a white background so the mask threshold below is valid
            # regardless of the user's rcParams.
            fig.savefig(image_path, bbox_inches='tight', pad_inches=0,
                        facecolor='white')
        finally:
            # Close even when rendering fails so figures don't accumulate.
            plt.close(fig)

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise OSError(f'could not read rendered image {image_path!r}')
        image_resized = cv2.resize(image, image_size)

        # The background is pure white (255). The mesh is drawn in
        # matplotlib's default non-white face colour, so the object is every
        # pixel below 255. THRESH_BINARY_INV maps those pixels to 255 and the
        # background to 0. A plain THRESH_BINARY at a low threshold would
        # mark the white background as foreground too.
        _, segmented = cv2.threshold(image_resized, 254, 255,
                                     cv2.THRESH_BINARY_INV)

        mask_path = os.path.join(output_dir, f'mask_{i}.png')
        if not cv2.imwrite(mask_path, segmented):
            raise OSError(f'could not write mask {mask_path!r}')
        if not cv2.imwrite(image_path, image_resized):
            raise OSError(f'could not write image {image_path!r}')

# Camera viewpoints as (elevation, azimuth) pairs in degrees; generate_dataset
# hands them to Axes3D.view_init.
angles = [
    (90, 90),  # top-down
    (90, 0),   # top-down, rotated azimuth
    (0, 90),   # edge-on
    (0, 0),    # edge-on, rotated azimuth
    (45, 45),  # oblique
]

# Both paths are relative, so they resolve against the current working
# directory of the kernel/process.
mesh_file = '../data/moon.stl'
output_dir = '../output'

generate_dataset(mesh_file=mesh_file, output_dir=output_dir, angles=angles)
