# In [22]:
from dataclasses import dataclass
from dataclasses import replace
from functools import partial

import jax
import jax.numpy as jnp
from flax import struct
import optax as ox

# Packages that actually perform the Sinkhorn algorithm
from ott.geometry.pointcloud import PointCloud
from ott.problems.linear.linear_problem import LinearProblem
from ott.solvers.linear.sinkhorn import Sinkhorn

from sklearn.metrics import mean_squared_error, mean_absolute_error, explained_variance_score
from sklearn.kernel_ridge import KernelRidge
from sklearn.model_selection import GridSearchCV, KFold, train_test_split
from sklearn.gaussian_process import kernels
## Preprocessing step
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import FunctionTransformer

import numpy as np
import matplotlib.pyplot as plt

import pandas as pd
import seaborn as sns
import random # to select randomly the reference distribution

import h5py
#import plotly.graph_objects as go
#from mpl_toolkits.mplot3d import Axes3D

def read_cgns_coordinates(file_path):
    """Read the X, Y and Z grid coordinates stored in a CGNS (HDF5) file.

    Args:
        file_path: path to the ``.cgns`` file, opened read-only with h5py.

    Returns:
        A tuple ``(x, y, z)`` of numpy arrays, one per coordinate, taken from
        ``Base_2_3/Zone/GridCoordinates/Coordinate{X,Y,Z}``.
    """
    coordinate_groups = (
        'Base_2_3/Zone/GridCoordinates/CoordinateX',
        'Base_2_3/Zone/GridCoordinates/CoordinateY',
        'Base_2_3/Zone/GridCoordinates/CoordinateZ',
    )
    with h5py.File(file_path, 'r') as cgns:
        # The dataset is literally named ' data' (leading space) in these files.
        # The unpacking consumes the generator while the file is still open.
        x, y, z = (np.array(cgns[group].get(' data')) for group in coordinate_groups)
    return x, y, z



@struct.dataclass
class WeightedPointCloud:
  """A weighted point cloud.

  Attributes:
    cloud: Array of shape (n, d) where n is the number of points and d the dimension.
    weights: Array of shape (n,) where n is the number of points.
      Some helpers (e.g. ``to_simplex``) also accept ``None`` here.
  """
  # ``jnp.array`` is a constructor function, not a type; ``jnp.ndarray`` is the array type.
  cloud: jnp.ndarray
  weights: jnp.ndarray

  def __len__(self):
    """Number of points, i.e. the size of the first axis of ``cloud``."""
    return self.cloud.shape[0]


@struct.dataclass
class VectorizedWeightedPointCloud:
  """Vectorized version of WeightedPointCloud.

  Holds b clouds that all have n points of dimension d (``pad_point_clouds``
  builds one from clouds of different sizes).

  Attributes:
    _private_cloud: Array of shape (b, n, d) where n is the number of points and d the dimension.
    _private_weights: Array of shape (b, n) where n is the number of points.

  Methods:
    unpack: returns the cloud and weights.
  """
  _private_cloud: jnp.ndarray
  _private_weights: jnp.ndarray

  def __getitem__(self, idx):
    """Return the ``idx``-th cloud as a WeightedPointCloud paired with its own weights."""
    return WeightedPointCloud(self._private_cloud[idx], self._private_weights[idx])

  def __len__(self):
    """Number of clouds b."""
    return self._private_cloud.shape[0]

  def __iter__(self):
    """Iterate over the clouds one WeightedPointCloud at a time."""
    for i in range(len(self)):
      yield self[i]

  def unpack(self):
    """Return the stacked ``(clouds, weights)`` arrays of shapes (b, n, d) and (b, n)."""
    return self._private_cloud, self._private_weights

def pad_point_cloud(point_cloud, max_cloud_size, fail_on_too_big=True):
  """Pad a single point cloud up to ``max_cloud_size`` points.

  Padding points are placed at the per-coordinate mean of the cloud
  (``jnp.pad(..., mode='mean')``), not at zero. They receive tiny weights, each
  a small fraction of the smallest existing weight. The weights are then
  rescaled so that the total mass of the cloud is unchanged. Balanced optimal
  transport compares measures of equal mass, so padding must not alter it.

  Args:
    point_cloud: a WeightedPointCloud with cloud of shape (n, d) and weights of shape (n,).
    max_cloud_size: number of points the padded cloud must have.
    fail_on_too_big: if True, raise when the cloud already has at least
      ``max_cloud_size`` points. Otherwise return it unchanged.

  Returns:
    a WeightedPointCloud with cloud of shape (max_cloud_size, d) and weights of
    shape (max_cloud_size,), or ``point_cloud`` itself when no padding is
    needed and ``fail_on_too_big`` is False.

  Raises:
    ValueError: if ``fail_on_too_big`` is True and the cloud cannot be padded.
  """
  cloud, weights = point_cloud.cloud, point_cloud.weights
  delta = max_cloud_size - cloud.shape[0]
  if delta <= 0:
    if fail_on_too_big:
      # Explicit raise: an ``assert`` would be stripped under ``python -O``.
      raise ValueError('Cloud is too big for padding.')
    return point_cloud

  ratio = 1e-3  # padding gets at most 0.1% of the total mass.
  total_mass = jnp.sum(weights)
  smallest_weight = jnp.min(weights) / delta * ratio
  small_weights = jnp.ones(delta) * smallest_weight

  weights = jnp.concatenate([weights * (1 - ratio), small_weights], axis=0)
  # Rescale so the padded cloud carries exactly the original mass.
  weights = weights * (total_mass / jnp.sum(weights))

  cloud = jnp.pad(cloud, pad_width=((0, delta), (0, 0)), mode='mean')

  return WeightedPointCloud(cloud, weights)

def pad_point_clouds(cloud_list):
  """Pad a list of clouds to a common size and stack them.

  Every cloud, including the largest one, is padded with ``pad_point_cloud``
  to one point more than the largest cloud. All clouds therefore go through
  the same padding path. Padding points sit at the cloud's per-coordinate
  mean and carry tiny weights.

  Note: this function should be used outside of jax.jit because the computation graph
        is huge. O(len(cloud_list)) nodes are generated.

  Args:
    cloud_list: a non-empty list of WeightedPointCloud.

  Returns:
    a VectorizedWeightedPointCloud with padded clouds and weights.
  """
  # The "+ 1" is a sentinel slot so that even the biggest cloud is padded.
  target_size = 1 + max(len(point_cloud) for point_cloud in cloud_list)
  padded = [pad_point_cloud(point_cloud, max_cloud_size=target_size)
            for point_cloud in cloud_list]
  stacked_coordinates = jnp.stack([item.cloud for item in padded])
  stacked_weights = jnp.stack([item.weights for item in padded])
  return VectorizedWeightedPointCloud(stacked_coordinates, stacked_weights)

def clouds_barycenter(points):
  """Average over all clouds of each cloud's weighted sum of points.

  For each cloud this computes ``sum_i w_i * x_i``. The per-cloud results are
  then averaged. This equals the mean of the clouds' barycenters only when
  every cloud's weights sum to 1.

  Args:
    points: a VectorizedWeightedPointCloud whose ``unpack()`` gives clouds of
      shape (b, n, d) and weights of shape (b, n).

  Returns:
    Array of shape (1, d).
  """
  coords, masses = points.unpack()
  weighted_coords = coords * masses[..., jnp.newaxis]
  per_cloud = weighted_coords.sum(axis=1)  # (b, d)
  return per_cloud.mean(axis=0, keepdims=True)


def to_simplex(mu):
  """Map the weights of ``mu`` onto the probability simplex with a softmax.

  The stored weights are treated as logits, so constant weights become a
  uniform distribution. A ``None`` weight field is passed through unchanged.

  Args:
    mu: a dataclass-like WeightedPointCloud.

  Returns:
    a copy of ``mu`` (via ``dataclasses.replace``) with softmax-normalized weights.
  """
  projected = None if mu.weights is None else jax.nn.softmax(mu.weights)
  return replace(mu, weights=projected)


def reparametrize_mu(mu, cloud_barycenter, scale):
  """Re-parametrize the support of ``mu`` around a given barycenter.

  The support is first centred on its own mean, which makes the result
  invariant to translations of ``mu.cloud``. Each coordinate is then squashed
  with ``scale * tanh``, which bounds it within ``scale`` of the centre. The
  result is finally shifted onto ``cloud_barycenter``.

  Args:
    mu: a WeightedPointCloud.
    cloud_barycenter: Array of shape (1, d) where d is the dimension.
    scale: float, scaling parameter for the re-parametrization of mu.

  Returns:
    a copy of ``mu`` with the transformed cloud; the weights are left untouched.
  """
  support = mu.cloud
  centred = support - jnp.mean(support, axis=0, keepdims=True)
  squashed = scale * jnp.tanh(centred)
  return replace(mu, cloud=squashed + cloud_barycenter)


def clouds_to_dual_sinkhorn(points, 
                            mu, 
                            init_dual=(None, None),
                            scale=1.,
                            has_aux=False,
                            sinkhorn_solver_kwargs=None, 
                            parallel: bool = True,
                            batch_size: int = -1):
  """Compute the embeddings of the clouds with regularized OT towards mu.

  Each cloud of ``points`` is matched to the reference measure ``mu`` with an
  entropy-regularized OT problem solved by OTT's ``Sinkhorn``. The embedding
  of a cloud is the dual potential ``g``, which lives on the points of ``mu``.

  Args:
    points: a VectorizedWeightedPointCloud holding b clouds of n points each.
    mu: a WeightedPointCloud with m points used as reference measure. Its
      weights are projected with a softmax (``to_simplex``). Its support is
      re-parametrized around the barycenter of ``points`` (``reparametrize_mu``).
    init_dual: tuple ``(f_init, g_init)`` of arrays of shape (b, n) and (b, m),
      where b is the number of clouds, n the number of points in each cloud
      and m the number of points in mu. ``None`` entries mean no initialization.
    scale: float, scaling parameter for the re-parametrization of mu.
    has_aux: bool. If True, also return the raw Sinkhorn output(s).
    sinkhorn_solver_kwargs: dict, kwargs for the Sinkhorn solver.
      Must contain the key 'epsilon' for the regularization parameter; the
      other keys are forwarded to ``Sinkhorn``. The dict is not modified.
    parallel: if True, solve all problems at once with ``jax.vmap``;
      otherwise solve them one by one in a Python loop.
    batch_size: only ``-1`` (no batching) is supported in parallel mode.

  Returns:
    Array of shape (b, m) with the ``g`` potential of every cloud. If
    ``has_aux`` is True, a tuple ``(g, outputs)`` instead: ``outputs`` is the
    vmapped Sinkhorn output in parallel mode, or a list with one output per
    cloud in sequential mode.

  Raises:
    KeyError: if ``sinkhorn_solver_kwargs`` has no 'epsilon' entry.
    ValueError: if ``parallel`` is True and ``batch_size`` is not -1.
  """
  # Work on a copy: popping from the caller's dict would make a second call
  # with the same dict fail with a KeyError.
  solver_kwargs = dict(sinkhorn_solver_kwargs) if sinkhorn_solver_kwargs else {}
  if 'epsilon' not in solver_kwargs:
    raise KeyError("sinkhorn_solver_kwargs must contain the key 'epsilon'.")
  sinkhorn_epsilon = solver_kwargs.pop('epsilon')

  # weight projection
  barycenter = clouds_barycenter(points)
  mu = to_simplex(mu)

  # cloud projection
  mu = reparametrize_mu(mu, barycenter, scale)

  # The solver configuration is the same for every cloud: build it once.
  solver = Sinkhorn(**solver_kwargs)

  def sinkhorn_single_cloud(cloud, weights, init_dual):
    geom = PointCloud(cloud, mu.cloud,
                      epsilon=sinkhorn_epsilon)
    ot_prob = LinearProblem(geom,
                            weights,
                            mu.weights)
    return solver(ot_prob, init=init_dual)

  if parallel:
    if batch_size != -1:
      raise ValueError("Not coded yet")
    # Clouds, weights and both initial duals are batched along axis 0.
    parallel_sinkhorn = jax.vmap(sinkhorn_single_cloud,
                                 in_axes=(0, 0, (0, 0)),
                                 out_axes=0)
    outs = parallel_sinkhorn(*points.unpack(), init_dual)
    return (outs.g, outs) if has_aux else outs.g

  clouds, weights = points.unpack()
  g_potentials, outputs = [], []
  for i, (cloud, cloud_weights) in enumerate(zip(clouds, weights)):
    # Each problem gets the i-th row of the batched initialization.
    cloud_init = tuple(None if dual is None else dual[i] for dual in init_dual)
    ot = sinkhorn_single_cloud(cloud, cloud_weights, cloud_init)
    g_potentials.append(ot.g)
    if has_aux:
      outputs.append(ot)
  g_potentials_array = jnp.stack(g_potentials)
  return (g_potentials_array, outputs) if has_aux else g_potentials_array
  

# Set the seeds for reproducibility. Both generators are seeded because the
# reference measure is drawn with Python's `random` module, which
# np.random.seed does not affect.
np.random.seed(42)
random.seed(42)

# In [23]:
## Blades to consider. Training uses the train_250 split.
blades_number_train = [3,6,7,16,20,21,22,23,29,33,34,39,46,56,57,71,76,77,78,81,83,95,99,101,102,105,115,117,124,130,143,145,152,154,157,159,160,167,173,174,180,182,187,190,196,198,201,203,204,210,212,217,220,223,224,229,233,246,247,250,251,252,264,268,270,278,288,289,300,312,314,316,317,319,320,324,334,335,337,339,348,356,357,359,367,369,370,371,375,376,377,379,383,389,395,396,398,400,404,405,408,413,414,415,416,420,426,428,431,435,436,441,443,444,449,452,463,468,469,471,472,479,483,490,501,512,513,516,518,519,523,524,525,526,527,528,530,532,553,556,557,558,561,567,568,570,572,573,575,589,593,595,597,601,606,612,616,621,622,624,628,629,631,638,641,643,647,648,657,662,663,673,677,681,692,699,703,704,705,711,713,715,721,728,731,732,741,742,747,754,757,760,763,766,769,772,779,781,782,783,784,798,800,806,812,813,816,823,826,832,833,834,836,842,843,846,852,854,857,864,866,871,872,876,877,884,892,896,901,909,922,927,931,936,937,939,946,956,959,965,975,978,982,985,987,993,994,995,996,999]
## Test split: every sample from 1000 to 1199 (inclusive).
blades_number_test = list(range(1000, 1200))

## Sample directories are named with the blade number zero-padded to 9 digits.
padded_numbers_train = [str(i).zfill(9) for i in blades_number_train]
# The test file numbers must come from the test split, not the training one.
padded_numbers_test = [str(i).zfill(9) for i in blades_number_test]


def load_blade_point_clouds(padded_numbers):
    """Read the mesh nodes of each sample as an (n_points, 3) numpy array.

    Args:
        padded_numbers: iterable of zero-padded sample numbers (strings).

    Returns:
        list of arrays whose columns are the x, y and z node coordinates.
    """
    blades = []
    for number in padded_numbers:
        ## File paths Personal Computer
        cgns_file_path = f'Rotor37/dataset/samples/sample_{number}/meshes/mesh_000000000.cgns'
        x, y, z = read_cgns_coordinates(cgns_file_path)
        blades.append(np.column_stack((x, y, z)))
    return blades


def to_uniform_point_clouds(distributions):
    """Wrap each (n, d) array into a WeightedPointCloud with weights summing to 1.

    clouds_barycenter computes sum(points * weights), which is a barycenter
    only for normalized weights. Balanced OT also expects each cloud to carry
    the same mass as the softmax-normalized reference measure.
    """
    return [
        WeightedPointCloud(
            cloud=jnp.array(sample),
            weights=jnp.ones(len(sample)) / len(sample))
        for sample in distributions
    ]


## Point clouds of the blades (only the geometry is loaded here).
distributions_train = load_blade_point_clouds(padded_numbers_train)
distributions_test = load_blade_point_clouds(padded_numbers_test)

#### Here the reference measure is the same for train and test.
# A dedicated seeded generator keeps the choice reproducible even if this cell is re-run.
mu = random.Random(42).choice(distributions_train)
## Convert mu into a WeightedPointCloud. Its weights act as logits:
## clouds_to_dual_sinkhorn applies a softmax, so constant weights give a uniform measure.
mu_cloud = WeightedPointCloud(
    cloud=jnp.array(mu),
    weights=jnp.ones(len(mu))
)

## Convert all the sampled distributions to WeightedPointCloud objects.
list_of_weighted_point_clouds_train = to_uniform_point_clouds(distributions_train)
list_of_weighted_point_clouds_test = to_uniform_point_clouds(distributions_test)

## Pad the clouds to a common size and stack them into VectorizedWeightedPointClouds.
x_cloud_train = pad_point_clouds(list_of_weighted_point_clouds_train)
x_cloud_test = pad_point_clouds(list_of_weighted_point_clouds_test)

## Entropic regularization of the Sinkhorn algorithm.
sinkhorn_solver_kwargs = {'epsilon': 0.01}
# A fresh copy is passed to each call: clouds_to_dual_sinkhorn may pop 'epsilon'
# from the dict it receives, which would break the second call.
# Train
sinkhorn_potentials_train = clouds_to_dual_sinkhorn(points=x_cloud_train, mu=mu_cloud,
                                                    sinkhorn_solver_kwargs=dict(sinkhorn_solver_kwargs),
                                                    parallel=False,  # sequential loop over the clouds
                                                    batch_size=-1)
np.savetxt("sinkhorn_potentials_train_250_epsilon_001.csv", np.asarray(sinkhorn_potentials_train), delimiter=";")

# Test
sinkhorn_potentials_test = clouds_to_dual_sinkhorn(points=x_cloud_test, mu=mu_cloud,
                                                   sinkhorn_solver_kwargs=dict(sinkhorn_solver_kwargs),
                                                   parallel=False,  # sequential loop over the clouds
                                                   batch_size=-1)
np.savetxt("sinkhorn_potentials_test_epsilon_001.csv", np.asarray(sinkhorn_potentials_test), delimiter=";")
