<a href="https://colab.research.google.com/github/Shona173/codes/blob/main/3D_OT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install POT
!pip install trimesh

Collecting POT
  Downloading POT-0.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (34 kB)
Downloading POT-0.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (897 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: POT
Successfully installed POT-0.9.5
Collecting trimesh
  Downloading trimesh-4.6.8-py3-none-any.whl.metadata (18 kB)
Downloading trimesh-4.6.8-py3-none-any.whl (709 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m709.3/709.3 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trimesh
Successfully installed trimesh-4.6.8


In [3]:
import numpy as np
import trimesh
import skimage.measure
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import ot
from google.colab import files

In [6]:
def clamp(value, min_val, max_val):
    """Limit ``value`` to the range bounded by ``min_val`` and ``max_val``.

    Operates element-wise through NumPy, so scalars and arrays are both
    accepted and are broadcast against each other.
    """
    # Apply the upper bound first, then the lower bound.
    upper_limited = np.minimum(value, max_val)
    return np.maximum(min_val, upper_limited)

In [5]:
def sdf_sphere(p, s=0.5):
    """Signed distance from points to a sphere of radius ``s`` centred at the origin.

    Parameters
    ----------
    p : array_like, shape (..., D)
        Point coordinates; the last axis holds the components of each point.
        Lists and tuples are accepted as well as arrays.
    s : float, optional
        Sphere radius.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Euclidean norm of each point minus ``s``, with shape ``p.shape[:-1]``:
        negative inside the sphere, zero on it, positive outside.
    """
    # The input is only read, so a view is enough; copying a full sampling
    # grid here would just waste memory.
    pts = np.asarray(p)
    return np.sqrt(np.sum(pts**2, axis=-1)) - s

In [7]:
def sdf_octahedron(p, s=0.5):
    """Distance function of the octahedron ``|x| + |y| + |z| = s`` for a batch of points.

    ``p`` is indexed as an ``(N, 3)`` array and a length-``N`` array of
    distances is returned. Each point is routed, by the first matching test
    on its absolute coordinates, to one of three branches that work on a
    cyclic permutation of the coordinates; points matching none of the tests
    get the scaled plane distance ``m / sqrt(3)``.
    """
    a = np.abs(p)
    m = a.sum(axis=1) - s

    # Branch selection: the first test that holds wins (x, then y, then z).
    use_x = 3.0 * a[:, 0] < m
    use_y = ~use_x & (3.0 * a[:, 1] < m)
    use_z = ~(use_x | use_y) & (3.0 * a[:, 2] < m)
    permuted = use_x | use_y | use_z

    # Permuted coordinates per branch; rows outside all branches stay zero
    # and their intermediate result below is discarded.
    q = np.zeros_like(a)
    q[use_x] = a[use_x]
    q[use_y] = a[use_y][:, [1, 2, 0]]
    q[use_z] = a[use_z][:, [2, 0, 1]]

    k = clamp(0.5 * (q[:, 2] - q[:, 1] + s), 0.0, s)
    offsets = np.stack([q[:, 0], q[:, 1] - s + k, q[:, 2] - k], axis=1)
    branch_dist = np.linalg.norm(offsets, axis=1)

    dist = np.zeros(a.shape[0])
    dist[permuted] = branch_dist[permuted]
    dist[~permuted] = m[~permuted] * 0.57735027  # 1/sqrt(3)
    return dist

In [10]:
def r_intersection(f1, f2):
    """Combine two implicit functions as ``f1 + f2 - sqrt(f1**2 + f2**2)``.

    The result is positive exactly where both ``f1`` and ``f2`` are positive,
    i.e. it intersects the regions ``{f > 0}``.

    NOTE(review): the SDFs in this notebook are negative inside the solid;
    under that convention this expression keeps the union of the solids
    rather than their intersection -- confirm which convention is intended.
    """
    linear_part = f1 + f2
    radial_part = np.sqrt(f1**2 + f2**2)
    return linear_part - radial_part

In [8]:
def gen_grid(resolution, low=-1.0, high=1.0, indexing="xy"):
    """Return the points of a regular 3D grid as a ``(resolution**3, 3)`` array.

    Parameters
    ----------
    resolution : int
        Number of samples along each axis.
    low, high : float, optional
        Bounds of the cube ``[low, high]**3`` (both ends are sampled).
    indexing : {"xy", "ij"}, optional
        Passed to ``np.meshgrid``. The default ``"xy"`` keeps the historical
        point order of this function. With ``"xy"``, reshaping per-point values
        to ``(resolution, resolution, resolution)`` puts *y* on axis 0 and *x*
        on axis 1; pass ``"ij"`` to get x, y, z on axes 0, 1, 2, which is the
        layout used by the ``export_sdf_*_to_obj`` helpers.

    Returns
    -------
    numpy.ndarray
        One ``(x, y, z)`` point per row.
    """
    axis = np.linspace(low, high, num=resolution)
    x, y, z = np.meshgrid(axis, axis, axis, indexing=indexing)
    return np.stack((x.ravel(), y.ravel(), z.ravel()), axis=1)

In [9]:
def sample_and_normalize(f, grid, grid_size):
    '''
    Sample f on the grid and return the normalized indicator of the solid.
    Assume f>0 outside and <0 inside.

    Parameters
    ----------
    f : callable
        Implicit function; called once as ``f(grid)`` and expected to return
        one value per grid point.
    grid : array of sample points handed to ``f`` unchanged.
    grid_size : int
        Samples per axis; the result is reshaped to
        ``(grid_size, grid_size, grid_size)``.

    Returns
    -------
    numpy.ndarray
        Array that is constant on the points where ``f <= 0``, zero elsewhere,
        and sums to 1.

    Raises
    ------
    ValueError
        If no grid point lies inside the solid, so there is no mass to normalize.
    '''
    fv = np.asarray(f(grid))

    # Characteristic function of the solid {f <= 0}. The mass has to sit
    # inside the shape; thresholding the raw values at f >= 0 would put it on
    # the exterior instead. A fresh array is built so that the values returned
    # by f are not modified in place.
    density = (fv <= 0.0).astype(float)

    total_sum = np.sum(density)
    if total_sum > 0:
        density = density / total_sum
    else:
        raise ValueError("The sum of the function values is zero; normalization is not possible.")

    # reshape to have the same shape as grid
    return density.reshape(grid_size, grid_size, grid_size)

In [11]:
def export_sdf_sphere_to_obj(grid_size=64, bounds=(-2, 2), radius=1.0, output_file="sdf_sphere.obj"):
    """
    Mesh the zero level set of the sphere SDF and write it to ``output_file``.

    The SDF is sampled on a ``grid_size**3`` grid spanning ``bounds`` on every
    axis, the surface is extracted with marching cubes, and the resulting
    mesh is exported through trimesh. Nothing is returned.
    """
    lo, hi = bounds[0], bounds[1]
    axis = np.linspace(lo, hi, grid_size)
    coords = np.meshgrid(axis, axis, axis, indexing="ij")
    points = np.stack([c.ravel() for c in coords], axis=1)

    volume = sdf_sphere(points, radius).reshape(grid_size, grid_size, grid_size)
    verts, faces, normals, _ = skimage.measure.marching_cubes(volume, level=0.0)

    # Map vertex coordinates from grid-index units back to world coordinates.
    step = (hi - lo) / (grid_size - 1)
    verts = verts * step + lo

    trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals).export(output_file)

# Write the radius-1 sphere, sampled on a 64^3 grid over [-2, 2]^3, to disk.
export_sdf_sphere_to_obj(
    grid_size=64,
    bounds=(-2.0, 2.0),
    radius=1.0,
    output_file="sdf_sphere.obj",
)

In [12]:
def export_sdf_octahedron_to_obj(grid_size=64, bounds=(-2, 2), s=1.0, output_file="sdf_octahedron.obj"):
    """
    Mesh the zero level set of the octahedron SDF and write it to ``output_file``.

    The SDF (size parameter ``s``) is sampled on a ``grid_size**3`` grid
    spanning ``bounds`` on every axis, the surface is extracted with marching
    cubes, and the resulting mesh is exported through trimesh. Nothing is
    returned.
    """
    lo, hi = bounds[0], bounds[1]
    axis = np.linspace(lo, hi, grid_size)
    coords = np.meshgrid(axis, axis, axis, indexing="ij")
    points = np.stack([c.ravel() for c in coords], axis=1)

    volume = sdf_octahedron(points, s).reshape(grid_size, grid_size, grid_size)
    verts, faces, normals, _ = skimage.measure.marching_cubes(volume, level=0.0)

    # Map vertex coordinates from grid-index units back to world coordinates.
    step = (hi - lo) / (grid_size - 1)
    verts = verts * step + lo

    trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals).export(output_file)

# Write the s=1 octahedron, sampled on a 64^3 grid over [-2, 2]^3, to disk.
export_sdf_octahedron_to_obj(
    grid_size=64,
    bounds=(-2.0, 2.0),
    s=1.0,
    output_file="sdf_octahedron.obj",
)

In [13]:
def convolutional_barycenter3d(mu_list, alpha_list, Ht_func, a, max_iter=100, sharpen_entropy=None):
    """Weighted barycenter of several distributions by iterative kernel scaling.

    Parameters
    ----------
    mu_list : sequence of array_like
        Input distributions, all with the same number of samples. They may be
        flat or N-dimensional (e.g. ``(g, g, g)`` volumes); they are flattened
        internally and the result is returned in the shape of ``mu_list[0]``.
    alpha_list : sequence of float
        One exponent (barycentric weight) per distribution.
    Ht_func : callable
        Kernel operator. It is always called with a flat vector of length
        ``n`` and its result is flattened before use.
        (Presumably a heat/Gaussian convolution -- confirm against the caller.)
    a : float or array_like
        Weight multiplied into every vector before ``Ht_func`` is applied and
        used in the entropy and normalization sums; arrays are flattened.
        (Area weights in the usual formulation -- TODO confirm.)
    max_iter : int, optional
        Number of scaling iterations; must be at least 1.
    sharpen_entropy : float or None, optional
        If given, the iterate is sharpened whenever its entropy exceeds this
        bound.

    Returns
    -------
    numpy.ndarray
        The barycenter after ``max_iter`` iterations, shaped like ``mu_list[0]``.

    Raises
    ------
    ValueError
        If ``mu_list`` is empty, ``alpha_list`` is too short, the
        distributions differ in size, or ``max_iter`` is smaller than 1.
    """
    k = len(mu_list)
    if k == 0:
        raise ValueError("mu_list must contain at least one distribution.")
    if len(alpha_list) < k:
        raise ValueError("alpha_list must provide one weight per distribution.")
    if max_iter < 1:
        # Without an iteration there is no barycenter estimate to return.
        raise ValueError("max_iter must be at least 1.")

    # The scaling vectors below are flat, so every input is flattened too;
    # otherwise N-D volumes cannot be combined with them.
    out_shape = np.shape(mu_list[0])
    mus = [np.asarray(m, dtype=float).ravel() for m in mu_list]
    n = mus[0].size
    if any(m.size != n for m in mus):
        raise ValueError("All distributions in mu_list must have the same number of samples.")
    if np.ndim(a) > 0:
        a = np.ravel(a)

    v = [np.ones(n) for _ in range(k)]

    for _ in range(max_iter):
        d = []
        mu = np.ones(n)
        for i in range(k):
            Hv = np.ravel(Ht_func(a * v[i]))
            # Guard the division without modifying Ht_func's output in place.
            Hv = np.where(Hv == 0, 1e-8, Hv)
            w_i = mus[i] / Hv
            d_i = v[i] * np.ravel(Ht_func(a * w_i))
            d.append(d_i)
            # Weighted geometric mean of the per-distribution estimates.
            mu *= np.power(d_i, alpha_list[i])

        if sharpen_entropy is not None:
            entropy = -np.sum(a * mu * np.log(np.maximum(mu, 1e-8)))
            if entropy > sharpen_entropy:
                # Heuristic sharpening: raise mu to the ratio of target to
                # current entropy, then renormalize so that sum(mu * a) == 1.
                beta = sharpen_entropy / entropy
                mu = np.power(mu, beta)
                mu /= np.sum(mu * a)

        for i in range(k):
            v[i] = v[i] * mu / (d[i] + 1e-8)

    return mu.reshape(out_shape)

In [None]:
def visualize_OT(f1, f2, t, grid_size=64, bounds=(-2, 2), output_file="blended_mesh.obj"):
    """Unfinished stub for visualizing a blend between ``f1`` and ``f2``.

    At present the function only builds the ``(grid_size**3, 3)`` array of
    sample points spanning ``bounds`` on each axis and then returns ``None``;
    ``f1``, ``f2``, ``t`` and ``output_file`` are not used yet.

    TODO: implement the blending and the mesh export.
    """
    axis = np.linspace(bounds[0], bounds[1], grid_size)
    coords = np.meshgrid(axis, axis, axis, indexing="ij")
    grid = np.stack([c.ravel() for c in coords], axis=1)


In [None]:
# Sample both shapes on the same grid_size^3 grid (gen_grid's default cube).
grid_size = 64
# The resolution argument must be the integer grid size; passing the
# gen_grid function itself makes np.linspace fail.
grid = gen_grid(grid_size)

f1 = sample_and_normalize(sdf_sphere, grid, grid_size)
f2 = sample_and_normalize(sdf_octahedron, grid, grid_size)

