In [None]:
!pip install nvidia-modulus
!pip install pyvista
!pip install trimesh
!pip install torch-geometric
!pip install torch-scatter
!pip install torch-sparse
!pip install torch-cluster
!pip install torch-spline-conv

In [10]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Mohamed Elrefaie, mohamed.elrefaie@mit.edu mohamed.elrefaie@tum.de

This module is part of the research presented in the paper:
"DrivAerNet++: A Large-Scale Multimodal Car Dataset with Computational Fluid Dynamics Simulations and Deep Learning Benchmarks".

This module is used to define both point-cloud based and graph-based models, including RegDGCNN, PointNet, and several Graph Neural Network (GNN) models
for the task of surrogate modeling of the aerodynamic drag.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import copy
import math
import numpy as np
import trimesh
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_max_pool, JumpingKnowledge
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d, Dropout
from torch_geometric.nn import BatchNorm

def knn(x, k):
    """
    Find the indices of the k closest points (Euclidean) for every point.

    Each point is its own nearest neighbour, so the point's own index is
    normally part of the result.

    Args:
        x (torch.Tensor): Point features of shape (batch_size, num_dims, num_points).
        k (int): Number of neighbours to return per point.

    Returns:
        torch.Tensor: Neighbour indices of shape (batch_size, num_points, k).
    """
    # -2 * <x_i, x_j> for every pair of points, shape (batch_size, num_points, num_points)
    cross_term = -2 * torch.matmul(x.transpose(2, 1), x)
    # ||x_j||^2 for every point, shape (batch_size, 1, num_points)
    sq_norms = torch.sum(x ** 2, dim=1, keepdim=True)

    # Negative squared distance: -||x_i - x_j||^2. Being negated, the largest
    # entries correspond to the closest points, which is what topk selects.
    neg_sq_dist = -sq_norms - cross_term - sq_norms.transpose(2, 1)

    return neg_sq_dist.topk(k=k, dim=-1)[1]


def get_graph_feature(x, k=20, idx=None):
    """
    Constructs local graph (edge) features for each point from its nearest neighbors.

    For every point the features of its neighbors are gathered, and the output
    channel axis holds the concatenation ``(neighbor - point, point)``.

    Args:
        x (torch.Tensor): The input tensor of shape (batch_size, num_dims, num_points).
        k (int): The number of neighbors to use when ``idx`` is not supplied.
        idx (torch.Tensor, optional): Precomputed neighbor indices of shape
            (batch_size, num_points, k'). When given, the neighbor count k' is
            taken from this tensor and the ``k`` argument is ignored.

    Returns:
        torch.Tensor: The constructed graph features of shape (batch_size, 2*num_dims, num_points, k).
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)

    if idx is None:
        idx = knn(x, k=k)
    else:
        # The reshape below must use the neighbor count of the supplied index
        # tensor; relying on the default k breaks whenever the two differ.
        k = idx.size(-1)

    # Offset per-sample indices so they address rows of the flattened
    # (batch_size * num_points, num_dims) point matrix.
    device = x.device
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = idx + idx_base
    idx = idx.view(-1)

    _, num_dims, _ = x.size()
    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)

    # Gather neighbors for each point to construct local regions
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)

    # Repeat every center point k times so it lines up with its neighbors
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    # (neighbor - center, center) -> (batch_size, 2*num_dims, num_points, k)
    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()

    return feature


class RegDGCNN(nn.Module):
    """
    Dynamic Graph CNN adapted for regression (RegDGCNN) on 3D point clouds.

    Four EdgeConv stages rebuild a k-nearest-neighbour graph in feature space,
    their per-point outputs are concatenated and embedded, pooled over all
    points, and an MLP head regresses the final value(s).
    """

    def __init__(self, args, output_channels=1):
        """
        Build the network.

        Args:
            args (dict): Configuration with the keys 'k' (number of neighbours),
                'emb_dims' (embedding width) and 'dropout' (dropout probability).
            output_channels (int): Number of regression outputs (1 for drag prediction).
        """
        super().__init__()
        self.args = args
        self.k = args['k']  # neighbours used when building each local graph

        emb_dims = args['emb_dims']
        drop_p = args['dropout']

        def _block(conv, norm):
            # conv -> batch norm -> LeakyReLU(0.2), shared by every stage
            return nn.Sequential(conv, norm, nn.LeakyReLU(negative_slope=0.2))

        # Normalisation layers are also kept as direct attributes (in addition
        # to living inside the Sequential blocks) so state-dict keys stay stable.
        self.bn1 = nn.BatchNorm2d(256)
        self.bn2 = nn.BatchNorm2d(512)
        self.bn3 = nn.BatchNorm2d(512)
        self.bn4 = nn.BatchNorm2d(1024)
        self.bn5 = nn.BatchNorm1d(emb_dims)

        # EdgeConv stages; each input width is twice the previous output width
        # because edge features are (neighbour - centre, centre).
        self.conv1 = _block(nn.Conv2d(6, 256, kernel_size=1, bias=False), self.bn1)
        self.conv2 = _block(nn.Conv2d(256 * 2, 512, kernel_size=1, bias=False), self.bn2)
        self.conv3 = _block(nn.Conv2d(512 * 2, 512, kernel_size=1, bias=False), self.bn3)
        self.conv4 = _block(nn.Conv2d(512 * 2, 1024, kernel_size=1, bias=False), self.bn4)
        # 2304 = 256 + 512 + 512 + 1024, the concatenated stage outputs
        self.conv5 = _block(nn.Conv1d(2304, emb_dims, kernel_size=1, bias=False), self.bn5)

        # Regression head: emb_dims*2 -> 128 -> 64 -> 32 -> 16 -> output_channels
        self.linear1 = nn.Linear(emb_dims * 2, 128, bias=False)
        self.bn6 = nn.BatchNorm1d(128)
        self.dp1 = nn.Dropout(p=drop_p)

        self.linear2 = nn.Linear(128, 64)
        self.bn7 = nn.BatchNorm1d(64)
        self.dp2 = nn.Dropout(p=drop_p)

        self.linear3 = nn.Linear(64, 32)
        self.bn8 = nn.BatchNorm1d(32)
        self.dp3 = nn.Dropout(p=drop_p)

        self.linear4 = nn.Linear(32, 16)
        self.bn9 = nn.BatchNorm1d(16)
        self.dp4 = nn.Dropout(p=drop_p)

        self.linear5 = nn.Linear(16, output_channels)

    def forward(self, x):
        """
        Predict the target value(s) for a batch of point clouds.

        Args:
            x (torch.Tensor): Batch of point clouds, (batch_size, 3, num_points).

        Returns:
            torch.Tensor: Predictions of shape (batch_size, output_channels).
        """
        batch_size = x.size(0)

        # Stage outputs: (B, 256, N), (B, 512, N), (B, 512, N), (B, 1024, N)
        stage_outputs = []
        hidden = x
        for edge_conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            edges = get_graph_feature(hidden, k=self.k)      # (B, 2*C, N, k)
            edges = edge_conv(edges)                         # (B, C_out, N, k)
            hidden = edges.max(dim=-1, keepdim=False)[0]     # max over neighbours
            stage_outputs.append(hidden)

        # (B, 2304, N) -> (B, emb_dims, N)
        x = torch.cat(stage_outputs, dim=1)
        x = self.conv5(x)

        # Global descriptor: max- and average-pooled over points, (B, emb_dims*2)
        pooled_max = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        pooled_avg = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        x = torch.cat((pooled_max, pooled_avg), 1)

        # linear -> batch norm -> LeakyReLU -> dropout, four times
        head = (
            (self.linear1, self.bn6, self.dp1),
            (self.linear2, self.bn7, self.dp2),
            (self.linear3, self.bn8, self.dp3),
            (self.linear4, self.bn9, self.dp4),
        )
        for fc, norm, drop in head:
            x = drop(F.leaky_relu(norm(fc(x)), negative_slope=0.2))

        return self.linear5(x)


class RegPointNet(nn.Module):
    """
    PointNet-style regressor for 3D point clouds.

    A stack of point-wise (kernel size 1) convolutions with a projected
    residual shortcut from the raw coordinates, followed by global max pooling
    and an MLP that outputs one scalar per sample.

    Args:
        args (dict): Configuration with 'emb_dims' (embedding width) and
            'dropout' (dropout probability).
    """
    def __init__(self, args):
        """
        Create all layers of the network.

        Parameters:
            args (dict): Must provide 'emb_dims' (int) and 'dropout' (float).
        """
        super().__init__()
        self.args = args
        emb_dims = args['emb_dims']
        drop_p = args['dropout']

        # Point-wise feature extractor: 3 -> 512 -> 1024 (x4) -> emb_dims
        self.conv1 = nn.Conv1d(3, 512, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(512, 1024, kernel_size=1, bias=False)
        self.conv3 = nn.Conv1d(1024, 1024, kernel_size=1, bias=False)
        self.conv4 = nn.Conv1d(1024, 1024, kernel_size=1, bias=False)
        self.conv5 = nn.Conv1d(1024, 1024, kernel_size=1, bias=False)
        self.conv6 = nn.Conv1d(1024, emb_dims, kernel_size=1, bias=False)

        # One batch norm per convolution
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(1024)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(1024)
        self.bn5 = nn.BatchNorm1d(1024)
        self.bn6 = nn.BatchNorm1d(emb_dims)

        # dropout_conv is applied after conv1..conv5; dropout_linear is
        # registered but not used by forward().
        self.dropout_conv = nn.Dropout(p=drop_p)
        self.dropout_linear = nn.Dropout(p=drop_p)

        # Shortcut branch projecting the raw xyz input to emb_dims channels
        self.conv_shortcut = nn.Conv1d(3, emb_dims, kernel_size=1, bias=False)
        self.bn_shortcut = nn.BatchNorm1d(emb_dims)

        # Regression head: emb_dims -> 512 -> 256 -> 128 -> 64 -> 1
        self.linear1 = nn.Linear(emb_dims, 512, bias=False)
        self.bn7 = nn.BatchNorm1d(512)
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn8 = nn.BatchNorm1d(256)
        self.linear3 = nn.Linear(256, 128)
        self.bn9 = nn.BatchNorm1d(128)
        self.linear4 = nn.Linear(128, 64)
        self.bn10 = nn.BatchNorm1d(64)
        self.final_linear = nn.Linear(64, 1)

    def forward(self, x):
        """
        Run the network on a batch of point clouds.

        Parameters:
            x (Tensor): Input tensor of shape (batch_size, 3, num_points).

        Returns:
            Tensor: Predicted scalar per sample, shape (batch_size, 1).
        """
        residual = self.bn_shortcut(self.conv_shortcut(x))

        # conv -> batch norm -> ReLU -> dropout for the first five layers
        trunk = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, norm in trunk:
            x = self.dropout_conv(F.relu(norm(conv(x))))

        # Last convolution has no dropout; the shortcut is added after its ReLU
        x = F.relu(self.bn6(self.conv6(x))) + residual

        # Collapse the point axis: (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x = F.adaptive_max_pool1d(x, 1).squeeze(-1)

        head = (
            (self.linear1, self.bn7),
            (self.linear2, self.bn8),
            (self.linear3, self.bn9),
            (self.linear4, self.bn10),
        )
        for fc, norm in head:
            x = F.relu(norm(fc(x)))

        return self.final_linear(x)

class DragGNN(torch.nn.Module):
    """
    Baseline graph network for drag prediction.

    Three GCNConv layers (3 -> 512 -> 1024 -> 512) with ReLU, mean pooling
    over the nodes of each graph, and a two-layer fully connected head.

    Args:
        None
    """
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(3, 512)
        self.conv2 = GCNConv(512, 1024)
        self.conv3 = GCNConv(1024, 512)
        self.fc1 = torch.nn.Linear(512, 128)
        self.fc2 = torch.nn.Linear(128, 1)

    def forward(self, data: Data) -> torch.Tensor:
        """
        Compute one prediction per graph in the batch.

        Args:
            data (Data): Graph batch; its ``x``, ``edge_index`` and ``batch``
                attributes are read.

        Returns:
            torch.Tensor: Output of the final linear layer for each graph.
        """
        node_feats, edge_index, batch = data.x, data.edge_index, data.batch

        # Message passing: every GCN layer is followed by a ReLU
        for gcn in (self.conv1, self.conv2, self.conv3):
            node_feats = F.relu(gcn(node_feats, edge_index))

        # Average node embeddings per graph, then regress
        graph_feats = global_mean_pool(node_feats, batch)
        hidden = F.relu(self.fc1(graph_feats))
        return self.fc2(hidden)


class DragGNN_XL(torch.nn.Module):
    """
    Deeper GCN variant for drag prediction with batch normalization and dropout.

    Four GCNConv layers (3 -> 64 -> 128 -> 128 -> 256), each followed by
    BatchNorm and ReLU, with dropout after the first three; node embeddings
    are mean-pooled per graph and passed to an MLP head.

    Args:
        None
    """
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(3, 64)
        self.conv2 = GCNConv(64, 128)
        self.conv3 = GCNConv(128, 128)
        self.conv4 = GCNConv(128, 256)

        self.bn1 = BatchNorm(64)
        self.bn2 = BatchNorm(128)
        self.bn3 = BatchNorm(128)
        self.bn4 = BatchNorm(256)

        # Shared dropout module used between the graph convolutions
        self.dropout = Dropout(0.4)

        # Head: 256 -> 128 -> 64 -> 1
        self.fc = Sequential(
            Linear(256, 128),
            ReLU(),
            Dropout(0.4),
            Linear(128, 64),
            ReLU(),
            Linear(64, 1),
        )

    def forward(self, data: Data) -> torch.Tensor:
        """
        Compute one prediction per graph in the batch.

        Args:
            data (Data): Graph batch; its ``x``, ``edge_index`` and ``batch``
                attributes are read.

        Returns:
            torch.Tensor: Output of the MLP head for each graph.
        """
        node_feats, edge_index, batch = data.x, data.edge_index, data.batch

        # First three stages: conv -> norm -> ReLU -> dropout
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for gcn, norm in stages:
            node_feats = self.dropout(F.relu(norm(gcn(node_feats, edge_index))))

        # Final stage has no dropout before pooling
        node_feats = F.relu(self.bn4(self.conv4(node_feats, edge_index)))

        graph_feats = global_mean_pool(node_feats, batch)
        return self.fc(graph_feats)


class EnhancedDragGNN(torch.nn.Module):
    """
    Hybrid GCN/GAT network for drag prediction with Jumping Knowledge.

    Three stages produce 256-channel node embeddings each; the stage outputs
    are concatenated by a JumpingKnowledge('cat') layer (768 channels),
    mean-pooled per graph and fed to a small fully connected head.

    Args:
        None
    """
    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 64 (GCN) -> 4 heads x 64 = 256 (GAT)
        self.gcn1 = GCNConv(3, 64)
        self.gat1 = GATConv(64, 64, heads=4, concat=True)

        # Stage 2: 256 -> 128 (GCN + BN) -> 2 heads x 128 = 256 (GAT)
        self.bn1 = BatchNorm1d(128)
        self.gcn2 = GCNConv(256, 128)
        self.gat2 = GATConv(128, 128, heads=2, concat=True)

        # Stage 3: 256 -> 256 (GCN + BN)
        self.bn2 = BatchNorm1d(256)
        self.gcn3 = GCNConv(256, 256)

        self.jk = JumpingKnowledge(mode='cat')

        self.fc1 = Sequential(
            Linear(256 * 3, 128),
            ReLU(),
            BatchNorm1d(128),
        )
        self.fc2 = Linear(128, 1)

    def forward(self, data: Data) -> torch.Tensor:
        """
        Compute one prediction per graph in the batch.

        Args:
            data (Data): Graph batch; its ``x``, ``edge_index`` and ``batch``
                attributes are read.

        Returns:
            torch.Tensor: Output of the final linear layer for each graph.
        """
        node_feats, edge_index, batch = data.x, data.edge_index, data.batch

        # Stage 1
        stage1 = F.relu(self.gcn1(node_feats, edge_index))
        stage1 = F.dropout(stage1, p=0.2, training=self.training)
        stage1 = self.gat1(stage1, edge_index)

        # Stage 2
        stage2 = F.relu(self.bn1(self.gcn2(stage1, edge_index)))
        stage2 = F.dropout(stage2, p=0.2, training=self.training)
        stage2 = self.gat2(stage2, edge_index)

        # Stage 3
        stage3 = F.relu(self.bn2(self.gcn3(stage2, edge_index)))

        # Concatenate all stage outputs per node, then pool per graph
        combined = self.jk([stage1, stage2, stage3])
        graph_feats = global_mean_pool(combined, batch)

        return self.fc2(self.fc1(graph_feats))


In [13]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Mohamed Elrefaie, mohamed.elrefaie@mit.edu mohamed.elrefaie@tum.de

This module is part of the research presented in the paper:
"DrivAerNet++: A Large-Scale Multimodal Car Dataset with Computational Fluid Dynamics Simulations and Deep Learning Benchmarks".

The module defines two PyTorch Datasets for loading and transforming 3D car models from the DrivAerNet++ dataset:
1. DrivAerNetDataset: Handles point cloud data, allowing loading, transforming, and augmenting 3D car models from STL files or existing point clouds.
2. DrivAerNetGNNDataset: Processes the dataset into graph format suitable for Graph Neural Networks (GNNs).
"""
import os
import logging
import torch
import numpy as np
import pandas as pd
import trimesh
from torch.utils.data import Dataset, DataLoader, random_split
import pyvista as pv
import seaborn as sns
from typing import Callable, Optional, Tuple, List
from torch_geometric.data import Data

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class DataAugmentation:
    """
    Class encapsulating various data augmentation techniques for point clouds.
    """
    @staticmethod
    def translate_pointcloud(pointcloud: torch.Tensor, translation_range: Tuple[float, float] = (2./3., 3./2.)) -> torch.Tensor:
        """
        Randomly scales and shifts the pointcloud along each axis.

        Every coordinate axis is multiplied by a factor drawn uniformly from
        ``translation_range`` and then offset by a value drawn uniformly from
        [-0.2, 0.2].

        Args:
            pointcloud: The input point cloud of shape (N, 3), either a
                torch.Tensor or a NumPy array.
            translation_range: A tuple specifying the range for the per-axis scale factors.

        Returns:
            Transformed point cloud as a float32 torch.Tensor (on the input's
            device when the input is a tensor).
        """
        # NumPy ufuncs applied to a torch.Tensor hand back a Tensor, which has
        # no ``astype``; convert to a plain array first so tensors work too.
        if torch.is_tensor(pointcloud):
            device = pointcloud.device
            points = pointcloud.detach().cpu().numpy()
        else:
            device = None
            points = np.asarray(pointcloud)

        xyz1 = np.random.uniform(low=translation_range[0], high=translation_range[1], size=[3])
        xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
        translated_pointcloud = np.add(np.multiply(points, xyz1), xyz2).astype('float32')
        return torch.tensor(translated_pointcloud, dtype=torch.float32, device=device)

    @staticmethod
    def jitter_pointcloud(pointcloud: torch.Tensor, sigma: float = 0.01, clip: float = 0.02) -> torch.Tensor:
        """
        Adds clipped Gaussian noise to the pointcloud.

        Args:
            pointcloud: The input point cloud as a torch.Tensor of shape (N, C).
            sigma: Standard deviation of the Gaussian noise.
            clip: Maximum absolute value for noise.

        Returns:
            Jittered point cloud as a torch.Tensor.
        """
        N, C = pointcloud.shape
        # Sample the noise on the same device as the points so the addition
        # also works for point clouds that do not live on the CPU.
        noise = sigma * torch.randn(N, C, device=pointcloud.device)
        jittered_pointcloud = pointcloud + torch.clamp(noise, -clip, clip)
        return jittered_pointcloud

    @staticmethod
    def drop_points(pointcloud: torch.Tensor, drop_rate: float = 0.1) -> torch.Tensor:
        """
        Randomly removes points from the point cloud based on the drop rate.

        The remaining points keep their original relative order.

        Args:
            pointcloud: The input point cloud as a torch.Tensor.
            drop_rate: The fraction of points to be randomly dropped.

        Returns:
            The point cloud with points dropped as a torch.Tensor.
        """
        num_points = pointcloud.size(0)
        num_drop = int(drop_rate * num_points)
        # Choose the rows to discard, then keep everything else (sorted)
        drop_indices = np.random.choice(num_points, num_drop, replace=False)
        keep_indices = np.setdiff1d(np.arange(num_points), drop_indices)
        dropped_pointcloud = pointcloud[keep_indices, :]
        return dropped_pointcloud

class DrivAerNetDataset(Dataset):
    """
    PyTorch Dataset class for the DrivAerNet dataset, handling loading, transforming, and augmenting 3D car models.

    Loaded geometry is cached per index *before* augmentation, so repeated
    accesses are cheap while augmentations stay random on every access.
    """
    def __init__(self, root_dir: str, csv_file: str, num_points: int, transform: Optional[Callable] = None, pointcloud_exist: bool = False):
        """
        Initializes the DrivAerNetDataset instance.

        Args:
            root_dir: Directory containing the STL files for 3D car models.
            csv_file: Path to the CSV file with metadata for the models.
            num_points: Fixed number of points to sample from each 3D model.
            transform: Optional transform function to apply to each sample.
            pointcloud_exist (bool): Whether the point clouds already exist as .pt files.

        Raises:
            Exception: Whatever ``pandas.read_csv`` raises if the CSV cannot be read
                (logged, then re-raised).
        """
        self.root_dir = root_dir
        try:
            self.data_frame = pd.read_csv(csv_file)
        except Exception as e:
            logging.error(f"Failed to load CSV file: {csv_file}. Error: {e}")
            raise

        self.transform = transform
        self.num_points = num_points
        self.augmentation = DataAugmentation()
        self.pointcloud_exist = pointcloud_exist
        # idx -> (un-augmented vertices, cd_value tensor)
        self.cache = {}

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return len(self.data_frame)

    def min_max_normalize(self, data: torch.Tensor) -> torch.Tensor:
        """
        Normalizes the data to the range [0, 1] based on min and max values.

        Columns with zero extent (max == min) are mapped to 0 instead of NaN.
        """
        min_vals, _ = data.min(dim=0, keepdim=True)
        max_vals, _ = data.max(dim=0, keepdim=True)
        span = max_vals - min_vals
        # Guard against division by zero for constant columns
        span = torch.where(span > 0, span, torch.ones_like(span))
        return (data - min_vals) / span

    def z_score_normalize(self, data: torch.Tensor) -> torch.Tensor:
        """
        Normalizes the data using z-score normalization (standard score).

        Columns whose standard deviation is zero or undefined are only centered.
        """
        mean_vals = data.mean(dim=0, keepdim=True)
        std_vals = data.std(dim=0, keepdim=True)
        # NaN (single sample) and 0 (constant column) both fail the > 0 test
        std_vals = torch.where(std_vals > 0, std_vals, torch.ones_like(std_vals))
        return (data - mean_vals) / std_vals

    def mean_normalize(self, data: torch.Tensor) -> torch.Tensor:
        """
        Normalizes the data to the range [-1, 1] based on mean and range.

        Columns with zero extent (max == min) are mapped to 0 instead of NaN.
        """
        mean_vals = data.mean(dim=0, keepdim=True)
        min_vals, _ = data.min(dim=0, keepdim=True)
        max_vals, _ = data.max(dim=0, keepdim=True)
        span = max_vals - min_vals
        span = torch.where(span > 0, span, torch.ones_like(span))
        return (data - mean_vals) / span

    def _sample_or_pad_vertices(self, vertices: torch.Tensor, num_points: int) -> torch.Tensor:
        """
        Subsamples or pads the vertices of the model to a fixed number of points.

        Args:
            vertices: The vertices of the 3D model as a torch.Tensor of shape (N, C).
            num_points: The desired number of points for the model.

        Returns:
            The vertices standardized to the specified number of points. Extra
            points are removed by random sampling without replacement; missing
            points are filled with zero rows.
        """
        num_vertices = vertices.size(0)
        if num_vertices > num_points:
            indices = np.random.choice(num_vertices, num_points, replace=False)
            vertices = vertices[indices]
        elif num_vertices < num_points:
            # new_zeros matches the dtype, device and column count of the input
            padding = vertices.new_zeros((num_points - num_vertices, vertices.size(1)))
            vertices = torch.cat((vertices, padding), dim=0)
        return vertices

    def _load_point_cloud(self, design_id: str) -> Optional[torch.Tensor]:
        """
        Loads the pre-computed point cloud ``<root_dir>/<design_id>.pt``.

        Returns:
            The loaded object, or None if the file is missing, empty, or cannot
            be read (the caller then skips the design).
        """
        load_path = os.path.join(self.root_dir, f"{design_id}.pt")
        if not (os.path.exists(load_path) and os.path.getsize(load_path) > 0):
            logging.debug("Point cloud file %s does not exist or is empty.", load_path)
            return None
        try:
            return torch.load(load_path)
        except (EOFError, RuntimeError) as e:
            logging.debug("Failed to load point cloud file %s: %s", load_path, e)
            return None

    def _load_sample(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Loads the un-augmented vertices and Cd label for ``idx``.

        When pre-computed point clouds are used and the one for ``idx`` is
        unavailable, the following designs are tried in order (wrapping around).

        Raises:
            RuntimeError: If no design in the dataset has a loadable point cloud.
            Exception: Re-raised from STL loading after being logged.
        """
        num_samples = len(self.data_frame)
        # Bounded scan: one full pass over the dataset at most, so a directory
        # without any usable point cloud cannot cause an endless loop.
        for _ in range(max(num_samples, 1)):
            row = self.data_frame.iloc[idx]
            design_id = row['Design']
            cd_value = row['Average Cd']

            if self.pointcloud_exist:
                vertices = self._load_point_cloud(design_id)
                if vertices is None:
                    idx = (idx + 1) % num_samples
                    continue
            else:
                geometry_path = os.path.join(self.root_dir, f"{design_id}.stl")
                try:
                    mesh = trimesh.load(geometry_path, force='mesh')
                    vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
                    vertices = self._sample_or_pad_vertices(vertices, self.num_points)
                except Exception as e:
                    logging.error(f"Failed to load STL file: {geometry_path}. Error: {e}")
                    raise

            label = torch.tensor(float(cd_value), dtype=torch.float32).view(-1)
            return vertices, label

        raise RuntimeError(f"No loadable point cloud found in {self.root_dir}")

    def __getitem__(self, idx: int, apply_augmentations: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves a sample and its corresponding label from the dataset, with an option to apply augmentations.

        Only the raw geometry is cached; augmentation, the optional transform and
        min-max normalization are applied on every call, so augmented samples
        differ between calls and ``apply_augmentations=False`` always yields the
        un-augmented point cloud.

        Args:
            idx (int): Index of the sample to retrieve.
            apply_augmentations (bool, optional): Whether to apply data augmentations. Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The sample (point cloud) and its label (Cd value).

        Raises:
            RuntimeError: If pre-computed point clouds are used and none can be loaded.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        cached = self.cache.get(idx)
        if cached is None:
            cached = self._load_sample(idx)
            self.cache[idx] = cached
        vertices, cd_value = cached

        if apply_augmentations:
            vertices = self.augmentation.translate_pointcloud(vertices.numpy())
            vertices = self.augmentation.jitter_pointcloud(vertices)

        if self.transform:
            vertices = self.transform(vertices)

        point_cloud_normalized = self.min_max_normalize(vertices)
        # Return a copy of the label so callers cannot alter the cached tensor
        return point_cloud_normalized, cd_value.clone()

    def split_data(self, train_ratio: float = 0.7, val_ratio: float = 0.15, test_ratio: float = 0.15) -> Tuple[List[int], List[int], List[int]]:
        """
        Splits the dataset into training, validation, and test sets.

        Args:
            train_ratio: The proportion of the data to be used for training.
            val_ratio: The proportion of the data to be used for validation.
            test_ratio: The proportion of the data to be used for testing.

        Returns:
            Index collections for the training, validation, and test sets, as
            produced by ``torch.utils.data.random_split``.

        Raises:
            AssertionError: If the ratios do not sum to 1 (within rounding error).
        """
        # Compare with a tolerance: sums such as 0.6 + 0.3 + 0.1 are not exactly 1.0
        assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1"
        num_samples = len(self)
        indices = list(range(num_samples))
        train_size = int(train_ratio * num_samples)
        val_size = int(val_ratio * num_samples)
        # The test split absorbs the remainder left by the integer truncation
        test_size = num_samples - train_size - val_size
        train_indices, val_indices, test_indices = random_split(indices, [train_size, val_size, test_size])
        return train_indices, val_indices, test_indices

    def visualize_mesh(self, idx):
        """
        Visualize the STL mesh for a specific design from the dataset.

        Args:
            idx (int): Index of the design to visualize in the dataset.

        This function loads the mesh from the STL file corresponding to the design ID at the given index,
        wraps it using PyVista for visualization, and then sets up a PyVista plotter to display the mesh.
        """
        row = self.data_frame.iloc[idx]
        design_id = row['Design']
        geometry_path = os.path.join(self.root_dir, f"{design_id}.stl")

        try:
            mesh = trimesh.load(geometry_path, force='mesh')
        except Exception as e:
            logging.error(f"Failed to load STL file: {geometry_path}. Error: {e}")
            raise

        pv_mesh = pv.wrap(mesh)
        plotter = pv.Plotter()
        plotter.add_mesh(pv_mesh, color='lightgrey', show_edges=True)
        plotter.add_axes()

        # Fixed (position, focal point, view-up) camera for a consistent view
        camera_position = [(-11.073024242161921, -5.621499358347753, 5.862225824910342),
                           (1.458462064391673, 0.002314306982062475, 0.6792134746589196),
                           (0.34000174095454166, 0.10379556639001211, 0.9346792479485448)]
        plotter.camera_position = camera_position
        plotter.show()

    def visualize_mesh_with_node(self, idx):
        """
        Visualizes the mesh for a specific design from the dataset with nodes highlighted.

        Args:
            idx (int): Index of the design to visualize in the dataset.

        This function loads the mesh from the STL file and highlights the nodes (vertices) of the mesh using spheres.
        It uses seaborn to obtain a colorblind-friendly color for the nodes.
        """
        row = self.data_frame.iloc[idx]
        design_id = row['Design']
        geometry_path = os.path.join(self.root_dir, f"{design_id}.stl")

        try:
            mesh = trimesh.load(geometry_path, force='mesh')
            pv_mesh = pv.wrap(mesh)
        except Exception as e:
            logging.error(f"Failed to load STL file: {geometry_path}. Error: {e}")
            raise

        plotter = pv.Plotter()
        sns_blue = sns.color_palette("colorblind")[0]

        plotter.add_mesh(pv_mesh, color='lightgrey', show_edges=True, edge_color='black')
        nodes = pv_mesh.points
        plotter.add_points(nodes, color=sns_blue, point_size=5, render_points_as_spheres=True)
        plotter.add_axes()
        plotter.show()

    def visualize_point_cloud(self, idx):
        """
        Visualizes the point cloud for a specific design from the dataset.

        Args:
            idx (int): Index of the design to visualize in the dataset.

        This function retrieves the sample via ``__getitem__`` with its default arguments
        (augmentations enabled), converts it into a point cloud, and uses the z-coordinate
        for color mapping. PyVista's Eye-Dome Lighting is enabled for improved depth perception.
        """
        vertices, _ = self.__getitem__(idx)
        vertices = vertices.numpy()

        # Convert vertices to a PyVista PolyData object for visualization
        point_cloud = pv.PolyData(vertices)
        colors = vertices[:, 2]  # Using the z-coordinate for color mapping
        point_cloud["colors"] = colors

        plotter = pv.Plotter()
        plotter.add_points(point_cloud, scalars="colors", cmap="Blues", point_size=3, render_points_as_spheres=True)

        # Enable Eye-Dome Lighting for better depth perception
        plotter.enable_eye_dome_lighting()

        plotter.add_axes()
        camera_position = [(-11.073024242161921, -5.621499358347753, 5.862225824910342),
                           (1.458462064391673, 0.002314306982062475, 0.6792134746589196),
                           (0.34000174095454166, 0.10379556639001211, 0.9346792479485448)]
        plotter.camera_position = camera_position

        plotter.show()

    def visualize_augmentations(self, idx):
        """
        Visualizes various augmentations applied to the point cloud of a specific design in the dataset.

        Args:
            idx (int): Index of the sample in the dataset to be visualized.

        This function retrieves the original point cloud for the specified design and then applies a series of augmentations,
        including translation, jittering, and point dropping. Each version of the point cloud (original and augmented) is then
        visualized in a 2x2 grid using PyVista to illustrate the effects of these augmentations.
        """
        # Retrieve the original point cloud without applying any augmentations
        vertices, _ = self.__getitem__(idx, apply_augmentations=False)
        original_pc = pv.PolyData(vertices.numpy())

        # The augmentations are chained: translate -> jitter -> drop
        translated_pc = self.augmentation.translate_pointcloud(vertices.numpy())
        jittered_pc = self.augmentation.jitter_pointcloud(translated_pc)
        dropped_pc = self.augmentation.drop_points(jittered_pc)

        plotter = pv.Plotter(shape=(2, 2))

        # Top left: original
        plotter.subplot(0, 0)
        plotter.add_text("Original Point Cloud", font_size=10)
        plotter.add_mesh(original_pc, color='black', point_size=3)

        # Top right: translated
        plotter.subplot(0, 1)
        plotter.add_text("Translated Point Cloud", font_size=10)
        plotter.add_mesh(pv.PolyData(translated_pc.numpy()), color='lightblue',
                         point_size=3)

        # Bottom left: jittered
        plotter.subplot(1, 0)
        plotter.add_text("Jittered Point Cloud", font_size=10)
        plotter.add_mesh(pv.PolyData(jittered_pc.numpy()), color='lightgreen',
                         point_size=3)

        # Bottom right: points dropped
        plotter.subplot(1, 1)
        plotter.add_text("Dropped Point Cloud", font_size=10)
        plotter.add_mesh(pv.PolyData(dropped_pc.numpy()), color='salmon',
                         point_size=3)

        plotter.show()

class DrivAerNetGNNDataset(Dataset):
    """
    PyTorch Dataset that turns DrivAerNet STL meshes into graphs suitable for GNNs.

    Every sample is a PyTorch Geometric ``Data`` object: the node features ``x`` are the
    mesh vertex positions, ``edge_index`` holds the mesh edges and ``y`` is the value of
    the 'Average Cd' column for that design. Samples are cached in memory after their
    first access, so repeated indexing returns the same ``Data`` instance.
    """

    def __init__(self, root_dir: str, csv_file: str, normalize: bool = True):
        """
        Initialize the dataset.

        Args:
            root_dir (str): Path to the directory containing the STL files.
            csv_file (str): Path to the CSV file containing metadata such as aerodynamic coefficients.
                The 'Design' and 'Average Cd' columns are read from it.
            normalize (bool): Whether to min-max normalize the node features.
        """
        self.root_dir = root_dir
        self.data_frame = pd.read_csv(csv_file)
        self.normalize = normalize
        # idx -> Data; filled lazily by __getitem__ so each STL file is parsed at most once
        self.cache = {}

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns:
            int: Number of rows in the metadata CSV, i.e. number of samples.
        """
        return len(self.data_frame)

    def min_max_normalize(self, data: torch.Tensor) -> torch.Tensor:
        """
        Normalizes the data to the range [0, 1] based on per-column min and max values.

        Columns whose values are all equal (max == min) are mapped to 0 instead of
        producing NaN from a division by zero.

        Args:
            data (torch.Tensor): The input data tensor to be normalized (reduced along dim 0).

        Returns:
            torch.Tensor: The normalized data tensor, same shape as the input.
        """
        min_vals, _ = data.min(dim=0, keepdim=True)
        max_vals, _ = data.max(dim=0, keepdim=True)
        value_range = max_vals - min_vals
        # A constant column has zero range; dividing by 1 there keeps (data - min) == 0
        # rather than turning the whole column into NaN.
        safe_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
        return (data - min_vals) / safe_range

    def __getitem__(self, idx: int) -> "Data":
        """
        Get a graph data item for GNN processing.

        Args:
            idx (int): Index of the item. A single-element tensor index is also accepted.

        Returns:
            Data: A PyTorch Geometric Data object containing edge_index, x (node features), and y (target variable).

        Raises:
            Exception: Whatever `trimesh.load` raises when the STL file cannot be read
                (the failure is logged and re-raised unchanged).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if idx in self.cache:
            return self.cache[idx]

        row = self.data_frame.iloc[idx]
        stl_path = os.path.join(self.root_dir, f"{row['Design']}.stl")
        cd_value = row['Average Cd']

        # Load the mesh from STL
        try:
            mesh = trimesh.load(stl_path, force='mesh')
        except Exception as e:
            logging.error(f"Failed to load STL file: {stl_path}. Error: {e}")
            raise

        # Convert mesh to graph: edges are transposed into the (2, num_edges) layout
        edge_index = torch.tensor(np.array(mesh.edges).T, dtype=torch.long)
        x = torch.tensor(mesh.vertices, dtype=torch.float)  # Using vertex positions as features

        if self.normalize:
            x = self.min_max_normalize(x)

        y = torch.tensor([cd_value], dtype=torch.float)  # Target variable as tensor

        # Create a graph data object and remember it for later accesses
        data = Data(x=x, edge_index=edge_index, y=y)
        self.cache[idx] = data
        return data

    def visualize_mesh_with_node(self, idx: int) -> None:
        """
        Visualizes the mesh of a given sample index with triangles in light grey and nodes highlighted as spheres.

        The camera is placed at a fixed, hard-coded position.

        Args:
            idx (int): Index of the sample to visualize.

        Raises:
            Exception: Whatever loading or wrapping the STL file raises (logged and re-raised).
        """
        row = self.data_frame.iloc[idx]
        design_id = row['Design']
        geometry_path = os.path.join(self.root_dir, f"{design_id}.stl")

        try:
            mesh = trimesh.load(geometry_path, force='mesh')
            pv_mesh = pv.wrap(mesh)
        except Exception as e:
            logging.error(f"Failed to load STL file: {geometry_path}. Error: {e}")
            raise

        plotter = pv.Plotter()
        sns_blue = sns.color_palette("colorblind")[0]

        # Add the mesh to the plotter with light grey color
        plotter.add_mesh(pv_mesh, color='lightgrey', show_edges=True, edge_color='black')

        # Highlight nodes as spheres
        nodes = pv_mesh.points
        plotter.add_points(nodes, color=sns_blue, point_size=5, render_points_as_spheres=True)  # Increase point_size as needed

        plotter.add_axes()
        camera_position = [(-11.073024242161921, -5.621499358347753, 5.862225824910342),
                           (1.458462064391673, 0.002314306982062475, 0.6792134746589196),
                           (0.34000174095454166, 0.10379556639001211, 0.9346792479485448)]

        # Set the camera position
        plotter.camera_position = camera_position
        plotter.show()

    def visualize_graph(self, idx: int) -> None:
        """
        Visualizes the graph representation of the 3D mesh using PyVista.

        Nodes are drawn at the positions stored in the sample's node features (these are
        the normalized positions when the dataset was created with `normalize=True`), and
        every graph edge is drawn as its own line segment.

        Args:
            idx (int): Index of the sample to visualize.
        """
        data = self[idx]  # Get the data object
        mesh = pv.PolyData(data.x.numpy())  # Create a PyVista mesh from node features

        # Each graph edge becomes one 2-point line cell: [2, start_node, end_node].
        # The cells are handed to PyVista as a flat connectivity array.
        edges = data.edge_index.t().numpy()
        line_cells = np.full((edges.shape[0], 3), 2, dtype=np.int_)
        line_cells[:, 1:] = edges

        mesh.lines = line_cells.ravel()
        mesh['scalars'] = np.random.rand(mesh.n_points)  # Random colors for nodes

        plotter = pv.Plotter()
        plotter.add_mesh(mesh, show_edges=True, line_width=1, color='white', point_size=8, render_points_as_spheres=True)
        # NOTE(review): the mesh above is drawn with a fixed color, so this bar may not
        # reflect the 'scalars' array — confirm the intended coloring.
        plotter.add_scalar_bar('Scalar Values', 'scalars')

        # Optional: highlight edges for clarity. The points are ordered as
        # (start, end, start, end, ...), so they must be joined pairwise; chaining them
        # into one polyline would draw spurious segments between unrelated edges.
        if edges.shape[0] > 0:
            edge_points = mesh.points[edges.flatten()]
            edge_segments = pv.line_segments_from_points(edge_points)
            plotter.add_mesh(edge_segments, color='blue', line_width=2)

        plotter.show()


# # Example usage
# if __name__ == '__main__':
#     dataset = DrivAerNetDataset(root_dir='../DrivAerNetPlusPlus_combined_all',
#                                 csv_file='../Combined_AeroCoefficients_DrivAerNet.csv',
#                                 num_points=100000,
#                                 pointcloud_exist=False  # Set to False if point clouds do not exist as .pt files
#                                 )
#
#     dataset.visualize_mesh_with_node(10)  # Visualize the mesh with nodes of the sample at index 10
#
#     dataset.visualize_point_cloud(10)  # Visualize the point cloud of the sample at index 10
#
#     # Splitting data into train, validation, and test sets
#     #train_indices, val_indices, test_indices = dataset.split_data()
#     #logging.info(f"Train size: {len(train_indices)}, Validation size: {len(val_indices)}, Test size: {len(test_indices)}")


In [12]:
!pip install pyvista

Collecting pyvista
  Downloading pyvista-0.44.2-py3-none-any.whl.metadata (15 kB)
Collecting vtk<9.4.0 (from pyvista)
  Downloading vtk-9.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.2 kB)
Downloading pyvista-0.44.2-py3-none-any.whl (2.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 38.8 MB/s eta 0:00:00
Downloading vtk-9.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (92.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 92.1/92.1 MB 9.7 MB/s eta 0:00:00
Installing collected packages: vtk, pyvista
Successfully installed pyvista-0.44.2 vtk-9.3.1


In [15]:
import os
import torch
import numpy as np
import time
import trimesh
import pyvista as pv
import torch.optim as optim
import torch.nn.functional as F

# Experiment configuration shared by the helpers below.
config = {
    # Run identification and reproducibility
    'exp_name': 'DragPrediction_DrivAerNet_PointNet_Showcase',  # also used to build the checkpoint file name
    'cuda': True,  # request GPU execution; only honored when CUDA is actually available
    'seed': 42,

    # Model / input hyper-parameters (handed to the model constructor as a whole)
    'num_points': 100000,  # number of points each geometry is sampled or padded to
    'dropout': 0.0,
    'emb_dims': 1024,
    'k': 40,
    'channels': [6, 64, 128, 256, 512, 1024],
    'linear_sizes': [128, 64, 32, 16],
    'output_channels': 1,

    # Data locations
    # NOTE(review): 'dataset_path' is later passed as the dataset root directory but
    # points at a CSV file — confirm the intended directory.
    'dataset_path': '/content/sample_data/DrivAerNet_ParametricData.csv',
    # NOTE(review): machine-specific absolute path — adjust to the local environment.
    'aero_coeff': '/Users/aswin/data/DrivAerNet/ParametricModels/DrivAerNet_ParametricData.csv',
}

# Device setup: start from the CPU and switch to the GPU only when it is both
# available and requested in the configuration.
device = torch.device("cpu")
if torch.cuda.is_available() and config['cuda']:
    device = torch.device("cuda")

# Utility Functions
def setup_seed(seed: int):
    """
    Seed the PyTorch (CPU and all CUDA devices) and NumPy random number generators.

    Args:
        seed (int): Seed value handed to every generator.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)

def initialize_model(config: dict) -> torch.nn.Module:
    """
    Build a RegPointNet, move it to the module-level `device`, and wrap it in
    DataParallel when GPU use is requested and more than one GPU is visible.

    Args:
        config (dict): Experiment configuration. It is passed to RegPointNet as `args`.
            This function itself reads 'cuda' and the optional 'device_ids' entry
            (GPU indices for DataParallel, default [0, 2, 3]).

    Returns:
        torch.nn.Module: The model, wrapped in torch.nn.DataParallel in the multi-GPU case.
    """
    model = RegPointNet(args=config).to(device)
    gpu_count = torch.cuda.device_count()
    if config['cuda'] and gpu_count > 1:
        device_ids = list(config.get('device_ids', [0, 2, 3]))
        # DataParallel rejects indices that do not exist on this machine (e.g. the
        # default ids on a 2-GPU host), so fall back to every visible GPU in that case.
        if not all(0 <= gpu_id < gpu_count for gpu_id in device_ids):
            device_ids = list(range(gpu_count))
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    return model

# Killer Feature 1: Real-Time Drag Prediction
def predict_drag(model, geometry_path: str, num_points: int, dataset) -> float:
    """
    Predict the drag coefficient for a single geometry file and report the inference time.

    The mesh vertices are resampled/padded through `dataset._sample_or_pad_vertices` when
    their count differs from `num_points`, normalized with `dataset.min_max_normalize`,
    and fed to the model as one batch with the point axis last.

    Args:
        model: Trained model; it is switched to evaluation mode.
        geometry_path (str): Path of the mesh file to load with trimesh.
        num_points (int): Number of points the model input should contain.
        dataset: Object providing `_sample_or_pad_vertices` and `min_max_normalize`.

    Returns:
        float: The predicted value (the model must return a single-element tensor).
    """
    model.eval()
    loaded_mesh = trimesh.load(geometry_path, force='mesh')
    points = torch.tensor(loaded_mesh.vertices, dtype=torch.float32)
    needs_resampling = points.size(0) != num_points
    if needs_resampling:
        points = dataset._sample_or_pad_vertices(points, num_points)

    # Add a batch axis and move the coordinate axis in front of the point axis
    normalized = dataset.min_max_normalize(points)
    batch = normalized.unsqueeze(0).permute(0, 2, 1).to(device)

    with torch.no_grad():
        started_at = time.time()
        output = model(batch).item()
        inference_time = time.time() - started_at

    print(f"Predicted Cd: {output:.4f}, Inference Time: {inference_time:.4f}s")
    return output

# Killer Feature 2: Drag Sensitivity Visualization
def visualize_sensitivity(model, geometry_path: str, num_points: int, dataset):
    """
    Plot, for every mesh vertex, how strongly the model output reacts to moving it.

    The sensitivity of a vertex is the sum of the absolute gradients of the model output
    with respect to that vertex's coordinates. Gradients are taken with respect to the
    original mesh vertices, so the result always has one value per mesh point, even when
    the model input had to be resampled or padded to `num_points`.

    Args:
        model: Trained model; it is switched to evaluation mode.
        geometry_path (str): Path of the mesh file to load with trimesh.
        num_points (int): Number of points the model input should contain.
        dataset: Object providing `_sample_or_pad_vertices` and `min_max_normalize`.

    Raises:
        RuntimeError: If no gradient reaches the mesh vertices (the preprocessing
            helpers broke the autograd graph).
    """
    model.eval()
    mesh = trimesh.load(geometry_path, force='mesh')

    # Keep a handle on the leaf tensor: resampling below returns a new (non-leaf)
    # tensor, and gradients must be requested for the original vertices.
    mesh_vertices = torch.tensor(mesh.vertices, dtype=torch.float32, requires_grad=True)
    model_points = mesh_vertices
    if model_points.size(0) != num_points:
        # NOTE(review): assumes the helper keeps the autograd graph intact (pure torch
        # indexing/concatenation) — confirm against its implementation.
        model_points = dataset._sample_or_pad_vertices(model_points, num_points)
    vertices_normalized = dataset.min_max_normalize(model_points).unsqueeze(0).permute(0, 2, 1).to(device)

    output = model(vertices_normalized)
    # autograd.grad returns the gradient directly and does not accumulate anything into
    # the model parameters' .grad fields, unlike output.backward().
    (vertex_grad,) = torch.autograd.grad(output.sum(), mesh_vertices, allow_unused=True)
    if vertex_grad is None:
        raise RuntimeError(
            "No gradient reached the mesh vertices; the sampling/normalization helpers "
            "must be differentiable torch operations."
        )
    sensitivity = vertex_grad.abs().sum(dim=1).cpu().numpy()

    pv_mesh = pv.wrap(mesh)
    pv_mesh.point_data['sensitivity'] = sensitivity
    plotter = pv.Plotter()
    plotter.add_mesh(pv_mesh, scalars='sensitivity', cmap='hot', show_edges=True)
    plotter.add_scalar_bar(title='Drag Sensitivity')
    plotter.show()

# Killer Feature 3: Automated Design Optimization
def optimize_geometry(model, geometry_path: str, num_points: int, dataset, steps: int = 100, lr: float = 0.01):
    """
    Minimize the model's predicted value by gradient descent (Adam) on the point positions.

    The points are optimized in the normalized coordinate space produced by
    `dataset.min_max_normalize`; the returned array is in that same space and is not
    mapped back to the original mesh coordinates.

    Args:
        model: Trained model; it is switched to evaluation mode. Its parameters are not updated.
        geometry_path (str): Path of the mesh file to load with trimesh.
        num_points (int): Number of points the model input should contain.
        dataset: Object providing `_sample_or_pad_vertices` and `min_max_normalize`.
        steps (int): Number of optimization steps.
        lr (float): Adam learning rate.

    Returns:
        numpy.ndarray: The optimized (normalized) point positions.
    """
    model.eval()
    mesh = trimesh.load(geometry_path, force='mesh')
    vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
    if vertices.size(0) != num_points:
        vertices = dataset._sample_or_pad_vertices(vertices, num_points)

    # Move to the target device BEFORE enabling gradients: calling .to(device) on a
    # tensor that already requires grad yields a non-leaf copy on CUDA, which the
    # optimizer refuses to optimize.
    vertices_normalized = (
        dataset.min_max_normalize(vertices).detach().clone().to(device).requires_grad_(True)
    )
    optimizer = optim.Adam([vertices_normalized], lr=lr)

    for step in range(steps):
        optimizer.zero_grad()
        input_data = vertices_normalized.unsqueeze(0).permute(0, 2, 1)
        cd_pred = model(input_data)
        # Only the gradient w.r.t. the points is needed; autograd.grad avoids
        # accumulating gradients into the (frozen-by-intent) model parameters each step.
        (point_grad,) = torch.autograd.grad(cd_pred.sum(), vertices_normalized)
        vertices_normalized.grad = point_grad
        optimizer.step()
        if step % 10 == 0:
            print(f"Step {step}, Predicted Cd: {cd_pred.item():.4f}")

    optimized_vertices = vertices_normalized.detach().cpu().numpy()
    return optimized_vertices

# Main Execution
if __name__ == "__main__":
    # Showcase script: load a trained checkpoint, then run the three demo helpers
    # (prediction, sensitivity plot, geometry optimization) on one sample geometry.

    # Set random seed for reproducibility (must happen before the model is created)
    setup_seed(config['seed'])

    # Initialize the model
    model = initialize_model(config)

    # Load the pre-trained model checkpoint from models/<exp_name>_best_model.pth
    # NOTE(review): if initialize_model wrapped the model in DataParallel, the checkpoint
    # keys presumably need the matching 'module.' prefix — confirm how it was saved.
    best_model_path = os.path.join('models', f'{config["exp_name"]}_best_model.pth')
    if not os.path.exists(best_model_path):
        raise FileNotFoundError(f"Pre-trained model not found at {best_model_path}. Please provide a trained model checkpoint.")
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    print(f"Loaded pre-trained model from {best_model_path}")

    # Initialize dataset for normalization utilities (sampling/padding and min-max scaling)
    # NOTE(review): config['dataset_path'] currently points at a CSV file although it is
    # used as the dataset root directory here — confirm the intended path.
    dataset = DrivAerNetDataset(
        root_dir=config['dataset_path'],
        csv_file=config['aero_coeff'],
        num_points=config['num_points'],
        pointcloud_exist=True
    )

    # Example geometry path (replace with an actual mesh file, e.g., .stl or .obj)
    # NOTE(review): the placeholder below is a CSV path; every helper that follows loads
    # it with trimesh, so it must be replaced by a real mesh file before running.
    sample_geometry = "/content/sample_data/DrivAerNet_ParametricData.csv"  # Replace with a valid mesh file path

    # Killer Feature 1: Real-Time Drag Prediction
    print("Killer Feature 1: Real-Time Drag Prediction")
    drag_coeff = predict_drag(model, sample_geometry, config['num_points'], dataset)

    # Killer Feature 2: Drag Sensitivity Visualization (opens an interactive PyVista window)
    print("Killer Feature 2: Drag Sensitivity Visualization")
    visualize_sensitivity(model, sample_geometry, config['num_points'], dataset)

    # Killer Feature 3: Automated Design Optimization (default steps and learning rate)
    print("Killer Feature 3: Automated Design Optimization")
    optimized_vertices = optimize_geometry(model, sample_geometry, config['num_points'], dataset)
    print(f"Optimized geometry vertices shape: {optimized_vertices.shape}")