<a href="https://colab.research.google.com/github/DrVenkataRajeshKumar/ForensicGAN/blob/main/Teeth_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
if torch.__version__=='1.6.0+cu101' and sys.platform.startswith('linux'):
    !pip install pytorch3d
else:
    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

Collecting git+https://github.com/facebookresearch/pytorch3d.git@stable
  Cloning https://github.com/facebookresearch/pytorch3d.git (to revision stable) to /tmp/pip-req-build-p8gjxx82
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/pytorch3d.git /tmp/pip-req-build-p8gjxx82
  Running command git checkout -q 297020a4b1d7492190cb4a909cafbd2c81a12cb5
  Resolved https://github.com/facebookresearch/pytorch3d.git to commit 297020a4b1d7492190cb4a909cafbd2c81a12cb5
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting fvcore (from pytorch3d==0.7.4)
  Using cached fvcore-0.1.5.post20221221.tar.gz (50 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting iopath (from pytorch3d==0.7.4)
  Using cached iopath-0.1.10.tar.gz (42 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.6 (from fvcore->pytorch3d==0.7.4)
  Using cached yacs-0.1.8-py3-none-any.whl (14 kB)
Collecting portalocker (from iopath->pytorch3d=

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
import numpy as np
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.renderer import look_at_view_transform, OpenGLPerspectiveCameras, RasterizationSettings, MeshRenderer, MeshRasterizer, TexturesVertex
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# Set device (GPU if available, else CPU)
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [55]:
!pip install numpy-stl
!pip install trimesh

Collecting numpy-stl
  Downloading numpy_stl-3.0.1-py3-none-any.whl (19 kB)
Installing collected packages: numpy-stl
Successfully installed numpy-stl-3.0.1
Collecting trimesh
  Downloading trimesh-3.23.5-py3-none-any.whl (685 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m685.4/685.4 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: trimesh
Successfully installed trimesh-3.23.5


In [56]:
from stl import mesh

In [83]:
class CustomDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.input_dir = os.path.join(self.data_dir, 'input_smpl')
        self.output_dir = os.path.join(self.data_dir, 'target_smpl')
        self.file_list = os.listdir(self.input_dir)


    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        input_mesh_path = os.path.join(self.input_dir, self.file_list[idx])
        target_mesh_path = os.path.join(self.output_dir, self.file_list[idx])


        # Load STL files using PyTorch3D
        #input_mesh = load_obj(input_mesh_path)
        input_mesh = mesh.Mesh.from_file(input_mesh_path, dtype=torch.float32)
        target_mesh = mesh.Mesh.from_file(target_mesh_path, dtype=torch.float32)
        #target_mesh = load_obj(target_mesh_path)

        return input_mesh, target_mesh


In [84]:
import torch.nn as nn


class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim

        # Encoder layers
        self.encoder = nn.Sequential(
                              nn.Conv3d(4, 32, kernel_size=3, stride=1, padding=1),
                              nn.ReLU(),
                              nn.MaxPool3d(kernel_size=2, stride=2),
                              nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1),
                              nn.ReLU(),
                              nn.MaxPool3d(kernel_size=2, stride=2),
                              nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1),
                              nn.ReLU(),
                              nn.MaxPool3d(kernel_size=2, stride=2)
                              )


        # Latent space layers
        self.fc_mu = nn.Linear(128 * 4 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(128 * 4 * 4 * 4, latent_dim)
               # x = x.view(x.size(0), -1)

        # Decoder layers
        self.decoder = nn.Sequential(
            # Define your 3D object decoder architecture here
            nn.Linear(latent_dim, 128 * 4 * 4 * 4),
            #x.view(x.size(0), 128, 4, 4, 4),
            nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(32, 4, kernel_size=4, stride=2, padding=1),
            #torch.sigmoid()
        )

    def encode(self, x):
        # Encode input to mean and logvar of the latent space
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick for sampling from latent space
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        # Decode from latent space to 3D object
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed_x = self.decode(z)
        return reconstructed_x, mu, logvar


In [85]:
def vae_loss(reconstructed_x, x, mu, logvar):
    # Define your VAE loss function (e.g., reconstruction loss + KL divergence)
    # Chamfer distance is a common choice for 3D object reconstruction
    recon_loss = chamfer_distance(reconstructed_x, x)
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_divergence


In [86]:
# Hyperparameters
latent_dim = 128
batch_size = 16
learning_rate = 1e-3
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the VAE model
model = VAE(latent_dim).to(device)

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Create data loaders
data_dir = '/content/drive/Shareddrives/3D/3Ddataset'
dataset = CustomDataset(data_dir)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [87]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for input_mesh, target_mesh in data_loader:
        input_mesh = input_mesh.to(device)
        target_mesh = target_mesh.to(device)

        # Forward pass
        reconstructed_mesh, mu, logvar = model(input_mesh)

        # Calculate loss
        loss = vae_loss(reconstructed_mesh, target_mesh, mu, logvar)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print epoch statistics
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {total_loss / len(data_loader)}")

# Save the trained model
torch.save(model.state_dict(), 'vae_model.pth')


TypeError: ignored