# Imports

In [1]:
from Transformation import Transformation
import numpy as np
# from sklearn.neighbors import NearestNeighbors
import networkx as nx
import torch
import torch.nn as nn
from igraph import Graph as igraphGraph
import os
from torch.utils.data import Dataset
# from scipy.spatial.distance import cdist
# from shapely.geometry import LineString, MultiLineString, Polygon, MultiPolygon, LinearRing

# Input

In [2]:
# torch.manual_seed(0)  #  for repeatable results
# transformation = Transformation()

# user_number_triangles = 500   #à diminuer si le process est trop long
# number_neigh_tri = 20

# # Create objects
# stl_file_path = "3d_models/stl/Handle.stl"
# mesh_data = transformation.stl_to_mesh(stl_file_path)
# graph = transformation.mesh_to_graph(mesh_data)

# transformation.print_graph_properties(graph, display_graph=False, display_labels=False)

In [3]:
# if len(graph._node)<20:
#     raise Exception("Input mesh does not have enough vertices. (More than 20 is needed)")

# Point Sampler

### DevConv

In [4]:
# def relu(array):
#     return np.maximum(array, 0)

# def sigmoid(array):
#     return 1 / (1 + np.exp(-array))

In [5]:
# graph_nodes = torch.Tensor(np.array(graph))
# graph_adjacency_matrix = torch.Tensor(nx.adjacency_matrix(graph).toarray())

In [6]:
class DevConv(nn.Module):
    def __init__(self, nodes, adjacency_matrix, output_dimension):
        super().__init__()
        self.size = output_dimension
        self.nodes = nodes
        self.adjacency_matrix = adjacency_matrix
        self.W_phi = nn.Parameter(torch.Tensor(output_dimension))
        self.W_theta = nn.Parameter(torch.Tensor(size=(3,1)))

        nn.init.normal_(self.W_phi)
        nn.init.normal_(self.W_theta)

        # print("self.W_phi.shape : ", self.W_phi.shape)
        # print("self.W_theta.shape : ", self.W_theta.shape)
    
    def forward(self, previous_inclusion_score, return_flatten=True):
        list_inc_score = torch.zeros((self.nodes.shape[0], self.size))                                          #list of "output_dimension" for each "list_node" element
        for index_current_node, list_neighbors in enumerate(self.adjacency_matrix):                             # for each node and its adjacency nodes
            neighbors = self.nodes[list_neighbors.nonzero()]                                                    # get neighbors nodes
            diff = self.nodes[index_current_node] - neighbors                                                   # Compute the differences between current_node and all neighbor nodes   (x_i - x_j)
            to_norm = self.W_theta.T.unsqueeze(1).repeat(1, diff.shape[0], 1)[0] * diff.squeeze(1)              # Compute W_theta * (x_i - x_j)
            neigh_distances = torch.norm(to_norm, dim=1)                                                        # Compute the norm for each vector difference  ||W_theta * (x_i - x_j)||
            list_inc_score[index_current_node] = (self.W_phi * neigh_distances.max()).clone()                   # Add (W_phi * ||W_theta * (x_i - x_j)||) to the inclusion score list

        if len(previous_inclusion_score)==0:                            # return if no previous inclusion score
            if return_flatten:
                list_inc_score = list_inc_score.flatten()
            return list_inc_score
        
        if list_inc_score.shape[1]!=1:                                  # If inclusion score is not vector
            list_inc_score = torch.mean(list_inc_score, dim=1)            # Mean the matrix for each node

        # array of array to array
        if len(list_inc_score.shape)==2:                 
            if list_inc_score.shape[1]==1:
                list_inc_score = list_inc_score.flatten()

        result_np = torch.stack([previous_inclusion_score, torch.tensor(list_inc_score)])
        
        result_np = torch.mean(result_np, dim=0)
        
        return result_np


class GNN_Model(nn.Module):
    def __init__(self, nodes, adjacency_matrix):
        super(GNN_Model, self).__init__()
        self.devconv = DevConv(nodes, adjacency_matrix, 1)
        self.relu = nn.ReLU()
        self.devconv2 = DevConv(nodes, adjacency_matrix, 64)
        self.relu2 = nn.ReLU()
        self.devconv3 = DevConv(nodes, adjacency_matrix,1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        x= self.devconv(x)
        x= self.relu(x)
        x= self.devconv2(x)
        x= self.relu2(x)
        x= self.devconv3(x)
        x= self.sigmoid(x)
        return x


# gnn = GNN_Model(graph_nodes, graph_adjacency_matrix)
# inclusion_score = gnn(torch.empty(0))
# print(inclusion_score.shape)

### Multinomial Sampling

In [7]:
class MultinomialLayer(nn.Module):
    def __init__(self, target_number_points,nodes):
        super().__init__()
        self.target_number_points = target_number_points
        self.nodes = nodes

    def forward(self, f):
        normalized_inclusion_score = f / torch.sum(f)                           # normalize for multinomial sampling

        mult_sampling = torch.distributions.multinomial.Multinomial(total_count=10*normalized_inclusion_score.shape[0], probs=normalized_inclusion_score).sample()      # small:more randomness    |   big:less randomness
        mult_indices = mult_sampling.topk(k=self.target_number_points).indices
        selected_nodes = self.nodes[mult_indices]

        return selected_nodes


# target_number_point = min(len(graph._node), user_number_triangles*3)   # number of points for the simplification
# layer = MultinomialLayer(target_number_point, graph_nodes)
# extended_graph_nodes = layer.forward(inclusion_score)
# print(extended_graph_nodes.shape)

# KNN extended graph

In [8]:
class KNNSimple(nn.Module):
    """
    Create a graph based on k-nearest neighbors using PyTorch.
    Parameters:
    - nodes: Tensor of shape (n, 3) representing 3D nodes.
    - k: Number of nearest neighbors.
    Returns:
    - adjacency_matrix: Binary adjacency matrix representing the graph.
    """
    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, nodes):
    
        expanded_x1 = nodes.unsqueeze(1)
        expanded_x2 = nodes.unsqueeze(0)
        distances = torch.norm(expanded_x1 - expanded_x2, dim=2)        # distance matrix

        _, indices = torch.topk(distances, self.k + 1, largest=False, sorted=True, dim=1)
        indices = indices[:, 1:]  # Exclude the node itself

        # Create adjacency matrix
        adjacency_matrix = torch.zeros(nodes.shape[0], nodes.shape[0], dtype=torch.float32)
        adjacency_matrix.scatter_(1, indices, 1)

        return adjacency_matrix

# extended_graph_adjacency_matrix = KNNSimple(k=15)(extended_graph_nodes)
# print(extended_graph_adjacency_matrix.shape)
# transformation.print_graph_properties(graph=nx.from_numpy_array(extended_graph_adjacency_matrix.numpy()), display_graph=False, display_labels=False)

# Edge Predictor

In [9]:
# devconv = DevConv(extended_graph_nodes,extended_graph_adjacency_matrix, 64)
# inclusion_score_edge = devconv(previous_inclusion_score=torch.empty((0)), return_flatten=False)
# inclusion_score_edge.shape

In [10]:
class SparseAttentionEdgePredictorLayer(nn.Module):
    def __init__(self, nodes, neighbors, size=64):
        super().__init__()
        self.size = size
        self.nodes = nodes
        self.neighbors = neighbors
        self.wq = nn.Parameter(torch.Tensor(size))
        self.wk = nn.Parameter(torch.Tensor(size))

        nn.init.normal_(self.wq)
        nn.init.normal_(self.wk)

    def forward(self, f):
        wq_f = self.wq.reshape(-1, 1) * f                   # Wq*f
        wk_f = self.wk.reshape(-1, 1) * f                   # Wq*f
        S = torch.exp(torch.matmul(wq_f.T, wk_f))           # e^((wq_f.T)*(wk_f))
        
        nonzero_neigh = self.neighbors.nonzero()                                                    # Find indexes of neighbors in graph
        unique_first_elements, counts = torch.unique(nonzero_neigh[:, 0], return_counts=True)       # Count number of neighbors per node
        split_tensors = list(torch.split(nonzero_neigh, tuple(counts)))                             # split indexes of neighbors into a list (1 element = 1 tensor of indexes)

        temp = [[S[n[i,0], n[i,1]] for i in range(len(n))] for n in split_tensors]                  # For each node, get the S value for the neighbors indexes
        summed = torch.Tensor([torch.sum(torch.Tensor(e)) for e in temp])                           # Sum these results for each nodes
        division = summed.unsqueeze(0).repeat(1, S.shape[1], 1)[0]                                  # Repeat the sum in S.shape[1] array => division per columns
        final_term  = S / division

        return final_term


# f = torch.mean(inclusion_score_edge, dim=1)                            # Flatten the matrix of inclusion score
# layer = SparseAttentionEdgePredictorLayer(extended_graph_nodes, extended_graph_adjacency_matrix)
# S = layer.forward(f)
# print(S.shape)

### Sparse Attention

In [11]:
# S = S*np.random.choice([0, 1], size=S.shape)      # Add a random mask to emulate the 'sparse'

# Face Candidates

#### Inputs

In [12]:
class FaceCandidatesLayer(nn.Module):
    def __init__(self, adjacency_matrix):
        super().__init__()
        self.adjacency_matrix = adjacency_matrix

    def forward(self, S):
        A_s = torch.matmul(torch.matmul(S, self.adjacency_matrix), S.T)     # A_s = S * A * S.T
        A_s = A_s/A_s.max()                                                 # Normalize
        return A_s


# layer = FaceCandidatesLayer(extended_graph_adjacency_matrix)
# A_s = layer(torch.Tensor(S))
# print(A_s.shape)

# Face Classifier

### TriConv

#### Inputs

In [13]:
class TriangleIndexes(nn.Module):
    def __init__(self, adjacency_matrix):
        super().__init__()
        self.adjacency_matrix = adjacency_matrix

    def forward(self):
        # tensor of indexes of each neighbors of each nodes 
        nonzero = self.adjacency_matrix.nonzero()
        neighbors_one_indexes = nonzero.reshape(self.adjacency_matrix.shape[0],15,2)[:,:,1].clone()
        neighbors_two_indexes = neighbors_one_indexes[neighbors_one_indexes]        # Tensor for each 2 neighbors for each nodes (neighbors of neighbors)
        neighbors_three_indexes = neighbors_one_indexes[neighbors_two_indexes]      # Tensor for each 3 neighbors for each nodes (neighbors of neighbors of neighbors)

        # Find the indices where the current index is present along the last dimension => where start node = final node (= cycle)
        values_index_reshape = torch.arange(neighbors_three_indexes.shape[0]).repeat((15,15,15,1)).T
        indices = (neighbors_three_indexes == values_index_reshape).nonzero()

        i, j, k, l = indices[:,0], indices[:,1], indices[:,2], indices[:,3]         # First node index, Second node index, third node index, Fourth node index
        temp_j = neighbors_one_indexes[i,j]                                         # number of the nodes firsts neighbors
        temp_k = neighbors_two_indexes[i,j,k]                                       # number of the nodes seconds neighbors
        temp_l = neighbors_three_indexes[i,j,k,l]                                   # number of the nodes thirds neighbors
        triangles_indexes_test = torch.stack((i, temp_j, temp_k, temp_l), dim=1)    # nodes for each path 
        triangles_indexes_test = triangles_indexes_test[:,:3]                       # remove virtual 4th point (same as the first one (cycle))


        # filter triangles indexes to clean the clones (=> divide the number of triangles by 6)
        sorted_tensor, _ = torch.sort(triangles_indexes_test, dim=-1)
        triangles_ids_igraph = torch.unique(sorted_tensor, dim=0)

        return triangles_ids_igraph

        


# layer_find_triangles_indexes = TriangleIndexes(extended_graph_adjacency_matrix)
# triangles_ids_igraph = layer_find_triangles_indexes()
# print(triangles_ids_igraph.shape)

In [14]:
class TriangleNodes(nn.Module):
    def __init__(self, nodes):
        super().__init__()
        self.nodes = nodes

    def forward(self, triangles_indexes):
        return self.nodes[triangles_indexes]

# layer_get_triangles = TriangleNodes(extended_graph_nodes)
# triangles = layer_get_triangles(triangles_ids_igraph)
# print(triangles.shape)

In [15]:
class FirstPInitLayer(nn.Module):
    def __init__(self, A_s, triangles):
        super().__init__()
        self.A_s = A_s
        self.triangles = triangles

    def forward(self, triangles_indexes):
        # Extract indices for each triangle
        i, j, k = triangles_indexes.T

        # Extract probabilities using advanced indexing
        A_s_ij = self.A_s[i, j]
        A_s_ik = self.A_s[i, k]
        A_s_jk = self.A_s[j, k]

        # Calculate the barycenter probabilities
        p_init = torch.zeros(self.triangles.shape[0])
        p_init = (A_s_ij + A_s_ik + A_s_jk) / 3
        return p_init

# p_init_layer = FirstPInitLayer()
# p_init = p_init_layer(triangles_ids_igraph)
# print(p_init.shape)

#### Calculate barycenter

In [16]:
class BarycentersLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, triangles):
        return triangles.mean(1)

# barycenters_layer = BarycentersLayer()
# barycenters = barycenters_layer(triangles)
# print(barycenters.shape)

#### KNN Tri

In [17]:
class KNN(nn.Module):
    def __init__(self, barycenters):
        super().__init__()
        self.barycenters = barycenters

    def forward(self, x, k=20, batch_size=100):
        indices_knn = torch.empty(size=(x.shape[0], k))
        
        modulo = x.shape[0]%batch_size
        nb_iter = int((x.shape[0] - modulo) / batch_size)

        for i in range(nb_iter):
            i_start, i_end = i*batch_size, (i+1)*batch_size
            distances = torch.norm(self.barycenters[i_start:i_end].unsqueeze(1) - self.barycenters.unsqueeze(0), dim=2)

            neighbors = distances.topk(k, dim=1, largest=False).indices.clone()  # Indices of the k-nearest neighbors
            indices_knn[i_start:i_end] = neighbors

        # last piece of computation
        distances = torch.norm(self.barycenters[-modulo:].unsqueeze(1) - self.barycenters.unsqueeze(0), dim=2)
        neighbors = distances.topk(k, dim=1, largest=False).indices.clone()  # Indices of the k-nearest neighbors
        indices_knn[-modulo:] = neighbors


        return indices_knn

# knn_layer = KNN()
# indices_neigh_tri = knn_layer(barycenters).int()  #change datatype
# print(indices_neigh_tri.shape)

In [18]:
class RMatrix(nn.Module):
    def __init__(self, triangles, barycenters, indices_neigh_tri, number_neigh_tri):
        super().__init__()
        self.triangles = triangles
        self.barycenters = barycenters
        self.indices_neigh_tri = indices_neigh_tri
        self.number_neigh_tri = number_neigh_tri

    def forward(self):
        # DIFF BARYCENTERS
        barycenters_diff = np.subtract(self.barycenters[self.indices_neigh_tri[:, 0]][:, np.newaxis], self.barycenters[self.indices_neigh_tri[:, 1:]])   #Inverser la différence des barycentres si nécéssaire


        # TRIANGLE EDGES NORM
        v0, v1, v2 = self.triangles[:, 0], self.triangles[:, 1], self.triangles[:, 2]

        # Calculate edge vectors
        e_ij = torch.norm(v0 - v1, dim=1)
        e_ik = torch.norm(v0 - v2, dim=1)
        e_jk = torch.norm(v1 - v2, dim=1)

        # Stack the edge vectors along the last dimension
        diff_vectors = torch.stack([e_ij, e_ik, e_jk], dim=1)


        # MAX/MIN DIFF VECTORS
        max_diff_vectors = diff_vectors.max(dim=1).values       # calculate t_n_max
        min_diff_vectors = diff_vectors.min(dim=1).values       # calculate t_n_min

        max_diff_vectors_diff = max_diff_vectors[self.indices_neigh_tri[:, 0]][:, None] - max_diff_vectors[self.indices_neigh_tri[:, 1:]]   #Inverser la différence des barycentres si nécéssaire   # calculate t_n_max - t_m_max
        min_diff_vectors_diff = min_diff_vectors[self.indices_neigh_tri[:, 0]][:, None] - min_diff_vectors[self.indices_neigh_tri[:, 1:]]   #Inverser la différence des barycentres si nécéssaire   # calculate t_n_min - t_m_min


        # R MATRIX COMPUTATION
        r_matrix = torch.zeros((self.triangles.shape[0], self.number_neigh_tri-1, 5))

        r_matrix[:, :, 0]   = min_diff_vectors_diff
        r_matrix[:, :, 1]   = max_diff_vectors_diff
        r_matrix[:, :, 2:5] = barycenters_diff
        
        return r_matrix


# r_matrix_layer = RMatrix(triangles, barycenters, indices_neigh_tri)
# r_matrix = r_matrix_layer()
# r_matrix.shape

#### Calculate f

In [19]:
class MLP(nn.Module):
  def __init__(self, r_matrix, indices_neigh_tri, hidden_size):
    super().__init__()
    self.r_matrix = r_matrix
    self.indices_neigh_tri = indices_neigh_tri
    self.hidden_size = hidden_size

  def forward(self, p_init):
    neigh_all = self.indices_neigh_tri[:,1:]

    # Triconv 1
    f = p_init
    diff_p_all = (f.repeat((neigh_all.shape[1],1)).T - f[neigh_all])
    r_diff = torch.cat((self.r_matrix, diff_p_all.unsqueeze(-1)), dim=2)
    
    x = nn.Flatten()(r_diff)
    x = nn.Linear(r_diff.shape[1]*r_diff.shape[2], self.hidden_size)(x)
    x = nn.ReLU()(x)
    f = nn.Linear(self.hidden_size, 1)(x).squeeze()

    # Triconv 2
    diff_p_all = (f.repeat((neigh_all.shape[1],1)).T - f[neigh_all])
    r_diff = torch.cat((self.r_matrix, diff_p_all.unsqueeze(-1)), dim=2)

    x = nn.Flatten()(r_diff)
    x = nn.Linear(r_diff.shape[1]*r_diff.shape[2], self.hidden_size)(x)
    x = nn.ReLU()(x)
    f = nn.Linear(self.hidden_size, 1)(x).squeeze()

    # Triconv 3
    diff_p_all = (f.repeat((neigh_all.shape[1],1)).T - f[neigh_all])
    r_diff = torch.cat((self.r_matrix, diff_p_all.unsqueeze(-1)), dim=2)

    x = nn.Flatten()(r_diff)
    x = nn.Linear(r_diff.shape[1]*r_diff.shape[2], self.hidden_size)(x)
    x = nn.ReLU()(x)
    f = nn.Linear(self.hidden_size, 1)(x).squeeze()

    f_softmax = nn.Softmax()(f)
    
    return f_softmax

# mlp = MLP(r_matrix, indices_neigh_tri, 128)
# final_scores = mlp(p_init)
# final_scores.shape

# Simplified Mesh

In [20]:
# selected_triangles_indexes = torch.topk(final_scores, k=user_number_triangles).indices
# selected_triangles = triangles[selected_triangles_indexes]
# selected_triangles.shape

In [21]:
# selected_triangles_np = selected_triangles.numpy()

# simplified_final_graph = nx.Graph()
# for index_poly, poly in enumerate(selected_triangles_np):
#     for index_current_node in range(len(poly)):
#         current_node = tuple(poly[index_current_node])
#         for index_other_node in range(index_current_node+1, len(poly)):
#             edge = current_node, tuple(poly[index_other_node])
#             simplified_final_graph.add_edge(*edge)
#             # if attribute do not exists
#             if len(simplified_final_graph.nodes[current_node])==0:
#                 simplified_final_graph.nodes[current_node]['index_triangle'] = set()
#             simplified_final_graph.nodes[current_node]['index_triangle'].add(index_poly)
#             if len(simplified_final_graph.nodes[tuple(poly[index_other_node])])==0:
#                 simplified_final_graph.nodes[tuple(poly[index_other_node])]['index_triangle'] = set()
#             simplified_final_graph.nodes[tuple(poly[index_other_node])]['index_triangle'].add(index_poly)
            
# transformation.print_graph_properties(graph=simplified_final_graph, display_graph=False, display_labels=False)

In [22]:
# simplified_final_mesh = transformation.graph_to_mesh(simplified_final_graph)

# #Affichage
# transformation.mesh_to_display_vtk(mesh_data)
# transformation.mesh_to_display_vtk(simplified_final_mesh)

# Fonctions Loss

## Probabilistic Chamfer distance

In [23]:
def torch_d_P_Ps(p_y, x, y):
    """All Tensors in input"""
    # print(p_y.shape, x.shape, y.shape)

    expanded_x1 = x.unsqueeze(1)
    expanded_x2 = y.unsqueeze(0)
    distances = torch.norm(expanded_x1 - expanded_x2, dim=2)        # distance matrix

    min_x = distances.min(dim=1).values
    min_y = distances.min(dim=0)

    first_term = torch.sum(torch.index_select(p_y, 0, min_y.indices) * min_y.values)
    second_term = torch.sum(min_x * p_y)

    return first_term + second_term


# d_P_Ps = torch_d_P_Ps(inclusion_score, graph_nodes, extended_graph_nodes)
# d_P_Ps

## Probabilistic Surfaces Distance

### d_f_S_Ss

In [24]:
def torch_d_f_S_Ss(p_b_hat, b_hat, b):
    # print(p_b_hat.shape, b_hat.shape, b.shape)

    expanded_x1 = b_hat.unsqueeze(1)
    expanded_x2 = b.unsqueeze(0)
    distances = torch.norm(expanded_x1 - expanded_x2, dim=2)

    min_b = distances.min(dim=1).values

    final_term = torch.sum(p_b_hat * min_b)

    return final_term



# igraph_g_original = igraphGraph(directed=False).from_networkx(graph)
# triangles_ids_igraph_original = np.array(igraph_g_original.cliques(min=3, max=3))
# triangles_original = np.array(igraph_g_original.vs['_nx_name'])[triangles_ids_igraph_original]
# b = torch.Tensor(np.mean(triangles_original, axis=1))

# b_hat = selected_triangles.mean(dim=1)

# p_b_hat = final_scores[selected_triangles_indexes]

# d_f_S_Ss = torch_d_f_S_Ss(p_b_hat, b_hat, b)
# d_f_S_Ss

### d_r_S_Ss

In [25]:
def torch_d_r_S_Ss(p_x, p_y, x ,y):
    expanded_x = x.unsqueeze(1)
    expanded_y = y.unsqueeze(0)
    distances = torch.norm(expanded_x - expanded_y, dim=2)
    min_d = distances.min(dim=0).values
    first_term = p_y * min_d


    indices_knn = distances.topk(k=50, dim=0, largest=False).indices.T  # Indices of the k-nearest neighbors
    knn_labels = x[indices_knn]
    xtk = torch.reshape(knn_labels, shape=((y.shape[0])*50, 3))

    expanded_xtk = xtk.unsqueeze(1)
    distances_knn = torch.norm(expanded_xtk - expanded_y, dim=2)
    min_knn = distances_knn.min(dim=1).values
    min_knn_reshaped = min_knn.reshape(((y.shape[0]), 50))

    ptk_time_norm = p_x[indices_knn] * min_knn_reshaped
    factor = (1-p_y) * (1/50)
    second_term = factor * torch.sum(ptk_time_norm, dim=1)

    final_term = torch.sum(first_term + second_term)

    return final_term

# d_f_S_Ss = torch_d_r_S_Ss(final_scores, p_b_hat, b, b_hat)
# d_f_S_Ss

# LOSS

In [None]:
def total_loss(inclusion_score, graph_nodes, extended_graph_nodes, final_scores, selected_triangles, selected_triangles_indexes, graph):
    d_P_Ps = torch_d_P_Ps(inclusion_score, graph_nodes, extended_graph_nodes)

    igraph_g_original = igraphGraph(directed=False).from_networkx(graph)
    triangles_ids_igraph_original = np.array(igraph_g_original.cliques(min=3, max=3))
    triangles_original = np.array(igraph_g_original.vs['_nx_name'])[triangles_ids_igraph_original]
    b = torch.Tensor(np.mean(triangles_original, axis=1))

    b_hat = selected_triangles.mean(dim=1)

    p_b_hat = final_scores[selected_triangles_indexes]

    d_f_S_Ss = torch_d_f_S_Ss(p_b_hat, b_hat, b)

    d_f_S_Ss = torch_d_r_S_Ss(final_scores, p_b_hat, b, b_hat)

    loss = d_P_Ps + d_f_S_Ss + d_f_S_Ss

    return loss
    

# TRAINING MODEL

In [26]:
class GNNSimplificationMesh(nn.Module):
    def __init__(self, graph_nodes, graph_adjacency_matrix, number_neigh_tri):
        super().__init__()
        self.graph_nodes = graph_nodes
        self.graph_adjacency_matrix = graph_adjacency_matrix
        self.number_neigh_tri = number_neigh_tri

    def forward(self, user_number_triangles):
        # POINT SAMPLER
        gnn = GNN_Model(self.graph_nodes, self.graph_adjacency_matrix)
        inclusion_score = gnn(torch.empty(0))

        target_number_point = min(self.graph_nodes.shape[0], user_number_triangles*3)   # number of points for the simplification
        layer = MultinomialLayer(target_number_point, self.graph_nodes)
        extended_graph_nodes = layer.forward(inclusion_score)

        extended_graph_adjacency_matrix = KNNSimple(k=15)(extended_graph_nodes)

        # EDGE PREDICTOR
        devconv = DevConv(extended_graph_nodes,extended_graph_adjacency_matrix, 64)
        inclusion_score_edge = devconv(previous_inclusion_score=torch.empty((0)), return_flatten=False)

        f = torch.mean(inclusion_score_edge, dim=1)                            # Flatten the matrix of inclusion score
        layer = SparseAttentionEdgePredictorLayer(extended_graph_nodes, extended_graph_adjacency_matrix)
        S = layer.forward(f)

        # FACE CANDIDATES
        layer = FaceCandidatesLayer(extended_graph_adjacency_matrix)
        A_s = layer(torch.Tensor(S))

        # FACE CLASSIFIER
        layer_find_triangles_indexes = TriangleIndexes(extended_graph_adjacency_matrix)
        triangles_ids_igraph = layer_find_triangles_indexes()

        layer_get_triangles = TriangleNodes(extended_graph_nodes)
        triangles = layer_get_triangles(triangles_ids_igraph)

        p_init_layer = FirstPInitLayer(A_s, triangles)
        p_init = p_init_layer(triangles_ids_igraph)

        barycenters_layer = BarycentersLayer()
        barycenters = barycenters_layer(triangles)

        knn_layer = KNN(barycenters)
        indices_neigh_tri = knn_layer(barycenters).int()  #change datatype

        r_matrix_layer = RMatrix(triangles, barycenters, indices_neigh_tri, self.number_neigh_tri)
        r_matrix = r_matrix_layer()

        mlp = MLP(r_matrix, indices_neigh_tri, 128)
        final_scores = mlp(p_init)

        selected_triangles_indexes = torch.topk(final_scores, k=user_number_triangles).indices
        selected_triangles = triangles[selected_triangles_indexes]

        return selected_triangles

In [27]:
transformation = Transformation()

number_neigh_tri = 20
stl_file_path = "3d_models/stl/Handle.stl"
mesh_data = transformation.stl_to_mesh(stl_file_path)
graph = transformation.mesh_to_graph(mesh_data)


if len(graph._node)<20:
    raise Exception("Input mesh does not have enough vertices. (More than 20 is needed)")

graph_nodes = torch.Tensor(np.array(graph))
graph_adjacency_matrix = torch.Tensor(nx.adjacency_matrix(graph).toarray())

In [94]:
class MeshDataset(Dataset):
    def __init__(self, mesh_dir):
        self.mesh_dir = mesh_dir
        self.len = 0
        self.filepaths = list()
        for filename in os.listdir(mesh_dir):
            f = os.path.join(mesh_dir, filename)
            if os.path.isfile(f):
                self.len += 1
                self.filepaths.append(f)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        mesh_path = self.filepaths[idx]
        mesh_data = transformation.stl_to_mesh(mesh_path)
        graph = transformation.mesh_to_graph(mesh_data)
        return mesh_data, graph

torch_dataset = MeshDataset("3d_models/stl/")

(<Mesh: b'STLB ATF 12.9.0.99 COLOR=\xb3\xb3\xb3\xff' 3532 vertices>,
 <networkx.classes.graph.Graph at 0x19209cad330>)

In [28]:
# gnn_model = GNNSimplificationMesh(graph_nodes, graph_adjacency_matrix, number_neigh_tri)
# selected_triangles = gnn_model(500)
# selected_triangles.shape

  result_np = torch.stack([previous_inclusion_score, torch.tensor(list_inc_score)])
  values_index_reshape = torch.arange(neighbors_three_indexes.shape[0]).repeat((15,15,15,1)).T
  return self._call_impl(*args, **kwargs)


torch.Size([500, 3, 3])

In [33]:

gnn_model = GNNSimplificationMesh(graph_nodes, graph_adjacency_matrix, number_neigh_tri)
optimizer = torch.optim.Adam(gnn_model.parameters(), lr=1e-5, weight_decay=0.99)

# Run the training loop
for epoch in range(0, 5): 
    print(f'Starting epoch {epoch+1}')
    
    current_loss = 0.0
    
    for i, data in enumerate(trainloader, 0):
        
        inputs, targets = data
        
        optimizer.zero_grad()
        
        outputs = gnn_model(inputs)
        
        loss = total_loss(outputs, targets)
        
        loss.backward()
        
        optimizer.step()
        
        current_loss += loss.item()
        if i % 500 == 499:
            print('Loss after mini-batch %5d: %.3f' %
                (i + 1, current_loss / 500))
            current_loss = 0.0
# Process is complete.
print('Training process has finished.')

TypeError: Module.parameters() missing 1 required positional argument: 'self'

In [29]:
selected_triangles_np = selected_triangles.numpy()

simplified_final_graph = nx.Graph()
for index_poly, poly in enumerate(selected_triangles_np):
    for index_current_node in range(len(poly)):
        current_node = tuple(poly[index_current_node])
        for index_other_node in range(index_current_node+1, len(poly)):
            edge = current_node, tuple(poly[index_other_node])
            simplified_final_graph.add_edge(*edge)
            # if attribute do not exists
            if len(simplified_final_graph.nodes[current_node])==0:
                simplified_final_graph.nodes[current_node]['index_triangle'] = set()
            simplified_final_graph.nodes[current_node]['index_triangle'].add(index_poly)
            if len(simplified_final_graph.nodes[tuple(poly[index_other_node])])==0:
                simplified_final_graph.nodes[tuple(poly[index_other_node])]['index_triangle'] = set()
            simplified_final_graph.nodes[tuple(poly[index_other_node])]['index_triangle'].add(index_poly)
            
transformation.print_graph_properties(graph=simplified_final_graph, display_graph=False, display_labels=False)

#Affichage
simplified_final_mesh = transformation.graph_to_mesh(simplified_final_graph)
transformation.mesh_to_display_vtk(mesh_data)
transformation.mesh_to_display_vtk(simplified_final_mesh)

Number of nodes: 129
Number of edges: 621


# END - END - END - END - END - END - END - END - END - END - END - END - END - END - END - END

## Triangle Collision Loss

In [26]:
def compute_lc_le_lo(p_t, m_c_e_o, Fs):
    """
    Compute the collision loss term L_c.

    Parameters:
    - p_t: 1D numpy array containing the probabilities of each triangle (indices)
    - m_c_t: 2D numpy array containing the number of faces penetrated by each triangle
    - Fs: 3D numpy array representing the vertices of triangles

    Returns:
    - L_c: Collision loss term
    """
    assert len(p_t) == len(m_c_e_o), "Input arrays must have the same length"

    penalty_per_triangle = p_t * m_c_e_o

    # Sum the penalties for all selected triangles
    total_penalty = np.sum(penalty_per_triangle)

    # Compute the collision loss term L_c
    L_c_e_o = (1 / len(Fs)) * total_penalty

    return L_c_e_o

# Example usage:
# Replace the arrays below with your actual data
# p_t = selected_triangles_indexes
# m_c_t = numpy array containing the number of faces penetrated by each triangle
# Fs = 3D numpy array representing the vertices of triangles
p_t = selected_triangles_indexes
Fs = triangles  # Given data

In [27]:
number_neigh_barycenters = min(50, len(b_hat))
_, indexes_neigh_selected_barycenters = NearestNeighbors(n_neighbors=number_neigh_barycenters).fit(b_hat).kneighbors(b_hat)

NameError: name 'NearestNeighbors' is not defined

### Lc

In [None]:
mc = np.zeros((500))

for index_neigh_barycenters in indexes_neigh_selected_barycenters:
    current_triangle = selected_triangles[index_neigh_barycenters[0]]
    others_triangles = selected_triangles[index_neigh_barycenters[1:]]

    lines_current_triangle = LinearRing(current_triangle)
    polygons_others_tri = MultiPolygon([Polygon(others_triangle) for others_triangle in others_triangles]).buffer(0)    # buffer 0 to correct invalid polygons => take the exterior of the shape


    intersection = lines_current_triangle.intersection(polygons_others_tri)
    if intersection.is_empty:
        continue
    if intersection.geom_type == 'MultiLineString':
        mc[index_neigh_barycenters[0]] = len(intersection.geoms)
    else:
        mc[index_neigh_barycenters[0]] += 1
L_c = compute_lc_le_lo(p_t, mc, Fs)
print("L_c : ", L_c)

L_c :  4225.18776054694


### Le

In [None]:
me = np.zeros((500))

for index_neigh_barycenters in indexes_neigh_selected_barycenters:
    current_triangle = selected_triangles[index_neigh_barycenters[0]]
    others_triangles = selected_triangles[index_neigh_barycenters[1:]]

    lines_current_triangle = LinearRing(current_triangle)
    lines_others_tri = MultiLineString([LineString([other_tri[0], other_tri[1], other_tri[2], other_tri[0]]) for other_tri in others_triangles])

    intersection = lines_current_triangle.intersection(lines_others_tri)
    if intersection.is_empty:
        continue
    if intersection.geom_type == 'MultiLineString':
        me[index_neigh_barycenters[0]] = len(intersection.geoms)
    else:
        me[index_neigh_barycenters[0]] += 1
L_e = compute_lc_le_lo(p_t, me, Fs)
print("L_e : ", L_e)

L_e :  24105.658829414708


### Lo

In [None]:
mo = np.zeros((500))

Lo - échantillonnage de 100 points à partir de chaque triangle

In [None]:

def sample_points_from_triangle(t, num_points=100):
    v1, v2, v3 = t
    bary_coords = np.random.rand(num_points, 2)
    sqrt_bary_coords = np.sqrt(bary_coords[:, 0])

    u = sqrt_bary_coords
    v = bary_coords[:, 1]

    """
    La formule spécifique est dérivée de l'expression générale d'interpolation barycentrique 
    sommets A, B et C
    coord barycentriques: u et v
    coord cartésiennes: x,y,z 
    """
    x_coords = (1 - u - v) * v1[0] + u * v2[0] + v * v3[0]
    y_coords = (1 - u - v) * v1[1] + u * v2[1] + v * v3[1]
    z_coords = (1 - u - v) * v1[2] + u * v2[2] + v * v3[2]

    sampled_points = np.column_stack((x_coords, y_coords, z_coords))
    return sampled_points

points100 = sample_points_from_triangle(triangle, num_points=100)
points100
points100.shape

(100, 3)

Les Aires

In [None]:
number_neigh_selected_barycenters = min(50, len(b_hat))
def knnbar(nn):
  _, indexes_neigh_selected_barycenters = NearestNeighbors(n_neighbors=nn).fit(b_hat).kneighbors(b_hat)
  return _, indexes_neigh_selected_barycenters

In [None]:
d, indices = knnbar(number_neigh_selected_barycenters)
indices

array([[  0, 292,  20, ..., 181, 180, 141],
       [  1,   6, 344, ..., 397, 151, 473],
       [  2, 186,  10, ..., 185, 184, 189],
       ...,
       [497, 421, 418, ..., 379, 238, 457],
       [498, 261, 303, ..., 385,  40, 327],
       [499, 264, 126, ..., 404, 338, 320]], dtype=int64)

In [None]:
from scipy.spatial import distance as scipy_distance

def calculate_triangle_area(triangle):
    # Fonction pour calculer l'aire d'un triangle en utilisant la formule de Héron
    side_lengths = [scipy_distance.euclidean(triangle[i], triangle[(i + 1) % 3]) for i in range(3)]
    s = sum(side_lengths) / 2
    return np.sqrt(s * np.prod([s - length for length in side_lengths]))

def penalize_overlapping_triangles(points_list, triangles, k=50):
    # Fonction pour pénaliser les triangles qui se chevauchent
    assignment_results = []
    overlapping_triangles = np.zeros(len(triangles), dtype=int)  # Déclaration en dehors de la boucle

    # ajustement de la valeur maximale de k en fonction du nombre de triangles
    k = min(k, len(triangles))

    # NearestNeighbors pour trouver les k triangles les plus proches pour chaque point
    knn_model = NearestNeighbors(n_neighbors=k).fit(np.vstack(triangles))
    d, indices = knn_model.kneighbors(np.array(points_list)[:, 0, :])

    for i, point_list in enumerate(points_list):
        point = point_list[0]

        # Verif que les indices sont valides
        closest_triangle_indices = indices[i, :k]
        closest_triangle_indices = closest_triangle_indices[closest_triangle_indices < len(triangles)]

        # Accéder aux triangles en utilisant les indices valides
        valid_triangle_indices = []
        for idx in closest_triangle_indices:
            modified_triangle = np.copy(triangles[idx])
            modified_triangle[0] = point

            area_original = calculate_triangle_area(triangles[idx])
            area_modified = calculate_triangle_area(modified_triangle)

            # Vérifier si la somme des aires est proche de l'aire du triangle
            if not np.isclose(area_original, area_modified, rtol=1e-5):
                assignment_results.append((point, triangles[idx]))
                valid_triangle_indices.append(idx)

        # MAJ du overlapping_triangles avec les indices valides
        overlapping_triangles[valid_triangle_indices] += 1

    penalties = overlapping_triangles
    total_penalty = np.sum(penalties)
    Lo = (1 / len(triangles)) * total_penalty

    return Lo


L_o = penalize_overlapping_triangles([points100], triangles, k=50)
print("Pénalité pour les triangles qui se chevauchent :", L_o)


Pénalité pour les triangles qui se chevauchent : 0.0008337502084375521


## Final Loss

lambda_c, lambda_e, lambda_o = 0.01, 0.01, 0.01


Loss_L = prob_chamfer_dist + d_f_S_Ss + d_r_S_Ss + (lambda_c * L_c) + (lambda_e * L_e) + (lambda_o * L_o)
print(Loss_L)