In [369]:
#Import nescessary libraries
import trimesh
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader,TensorDataset
from torchvision import datasets,transforms
import numpy as np
import tqdm



In [None]:
#Configure the neural network properties
epochs = 2000
batch_size = 4
sample_size =100 #Number of random values to sample

#The generator and discriminator learning rate
g_lr = 1.5e-5
d_lr = 0.5e-5

#Size 
size = 11

In [371]:
folder = "models"
#Load the models
obj_files = [file for file in os.listdir(folder) if file.endswith(".obj")]
meshes = []
for fname in obj_files:
    mesh = trimesh.load(folder+"/"+fname)
    #Normalize the model
    mesh.apply_scale(1.0/(max(mesh.extents)))
    meshes.append(mesh)

In [372]:
#Visualize mesh
meshes[4].show()

In [373]:
#Padding a matrix
def resize_voxel(array,size):
    x,y,z = array.shape

    if max(x,y,z) > size:
        raise ValueError("Target size is smaller than array size")
    
    pad_x_before  = (size-x)//2
    pad_x_after = size -x-pad_x_before

    
    pad_y_before  = (size-y)//2
    pad_y_after = size -y-pad_y_before

    
    pad_z_before  = (size-z)//2
    pad_z_after = size -z-pad_z_before
    
    return np.pad(array,((pad_x_before,pad_x_after),(pad_y_before,pad_y_after),(pad_z_before,pad_z_after)),mode='constant',constant_values=0)

In [374]:
#Voxelize the mesh
voxels = []

for mesh in meshes:
    print()
    voxels.append(trimesh.voxel.creation.voxelize(mesh,0.1).fill())
    

print(voxels)
voxels[0].show()













[<trimesh.VoxelGrid(11, 11, 11)>, <trimesh.VoxelGrid(11, 11, 11)>, <trimesh.VoxelGrid(11, 11, 11)>, <trimesh.VoxelGrid(11, 11, 11)>, <trimesh.VoxelGrid(11, 11, 11)>, <trimesh.VoxelGrid(11, 11, 11)>, <trimesh.VoxelGrid(11, 3, 11)>, <trimesh.VoxelGrid(11, 3, 11)>, <trimesh.VoxelGrid(11, 3, 11)>, <trimesh.VoxelGrid(11, 3, 11)>, <trimesh.VoxelGrid(11, 3, 11)>, <trimesh.VoxelGrid(11, 3, 11)>]


In [375]:
#Convert to tensors
real_X = []
real_Y = []
for voxel in voxels:
    real_X.append(torch.from_numpy(resize_voxel(voxel.matrix.astype(float),11)))
    real_Y.append(torch.Tensor([1.0]))

real_X = torch.stack(real_X)
real_Y = torch.stack(real_Y)

dataset = TensorDataset(real_X,real_Y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [None]:
#Our generator network
class Generator(nn.Sequential):
    def __init__(self,sample_size:int):
        super().__init__(
            nn.Linear(sample_size, 256),
            nn.LeakyReLU(0.2),
            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1331),
       
            nn.Sigmoid())
        #Random value
        self.sample_size = sample_size
    
    #Forward propagation

    def forward(self, batch_size:int):
        z = torch.randn(batch_size,self.sample_size)
        #Get output
        output = super().forward(z)

        #Create  images
        generated_models = output.reshape(batch_size,11,11,11)
        return generated_models

In [None]:
#Discriminator model
#We define our discriminator network
class Discriminator(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Linear(1331,512),
            nn.LeakyReLU(0.1),
            nn.Linear(512,256),
            nn.LeakyReLU(0.1),
            nn.Linear(256,1)
        )

    def forward(self,models:torch.Tensor,targets:torch.Tensor):
        #Predict the value
        models = models.to(torch.float32)
        prediction = super().forward(models.reshape(-1,1331))
        loss = F.binary_cross_entropy_with_logits(prediction,targets)
        return loss
     

In [378]:
#Create real and fake labels
real_targets = torch.ones(batch_size,1)
fake_targets = torch.zeros(batch_size,1)

In [379]:
#Create the models
generator=  Generator(sample_size)
discriminator =Discriminator()

In [380]:
#Optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(),lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(),g_lr)

In [381]:
#Training loop
for epoch in range(epochs):
    d_losses = []
    g_losses = []

    for images,labels in dataloader:
        #===============================
        # Discriminator Network Training
        #===============================
        
        # Loss with MNIST image inputs and real_targets as labels
        discriminator.train()
        d_loss = discriminator(images, labels.float())

        # Generate images in eval mode
        generator.eval()
        with torch.no_grad():
            generated_images = generator(batch_size)

        # Loss with generated image inputs and fake_targets as labels
        d_loss += discriminator(generated_images, fake_targets.float())

        # Optimizer updates the discriminator parameters
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #-+-+-+-+-+-+-+-+-+-+-+-+
        #Training Generator
        #-+-+-+-+-+-+-+-+-+-+-+-+

        generator.train()
        generated_images = generator(batch_size)

        #We dont want to train the discriminator, but we still want the gradients
        discriminator.eval()

        g_loss = discriminator(generated_images,real_targets)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        #Append the losses
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())


    #Print average loss
    print(epoch,np.mean(d_losses),np.mean(g_losses))

0 1.404555877049764 0.6454720894495646
1 1.4030295213063557 0.6469100117683411
2 1.4018524885177612 0.6473880807558695
3 1.4006664752960205 0.6485551993052164
4 1.3995620807011921 0.6491700808207194
5 1.3983822266260784 0.6503300269444784
6 1.397233208020528 0.650722881158193
7 1.396047314008077 0.6520423889160156
8 1.394943356513977 0.6529469887415568
9 1.3931035995483398 0.6535709301630656
10 1.3925067981084187 0.6547974348068237
11 1.3914105892181396 0.6560151974360148
12 1.389693061510722 0.6564305822054545
13 1.3890950679779053 0.6574374636014303
14 1.3882768551508586 0.6584858099619547
15 1.3864700396855671 0.6592233975728353
16 1.385070562362671 0.6598249077796936
17 1.384057879447937 0.6609099507331848
18 1.3830976088841755 0.6615760127703348
19 1.3816368182500203 0.6625077724456787
20 1.3803117275238037 0.6635643045107523
21 1.3792122999827068 0.664529045422872
22 1.3782739639282227 0.6652958790461222
23 1.3769925038019817 0.6665679017702738
24 1.3758204380671184 0.66732219854

In [403]:
#Show output
gen_model = generator.forward(1).detach().numpy().reshape(11,11,11)

gen_model = np.where(gen_model>=0.5,1,0)

voxel = trimesh.voxel.VoxelGrid(gen_model)
voxel.show()