<a href="https://colab.research.google.com/github/adamrawashdeh/3D-SDF-Shapes/blob/main/3D_deepSDF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DeepSDF in Pictures

*(This notebook is based on a Google Translate translation, from Russian into English, of the notebook https://github.com/Oktosha/DeepSDF-explained/blob/master/deepSDF-explained.ipynb)*



## Project for the course "Additional Chapters of Machine Learning", Kolodzei 599b

In this notebook, I retell the main ideas of the article [DeepSDF: Learning Continuous Signed Distance Functions for Shape Representation](https://arxiv.org/abs/1901.05103). Why retell it when the article itself exists? Because here there is code, and pictures generated by that code!

The article deals with 3D shapes. For simplicity, the retelling illustrates the ideas with pictures of 2D shapes.

<img src="https://github.com/Oktosha/DeepSDF-explained/blob/master/3Dvs2D.png?raw=1">

In addition, my shapes will be simple, and the neural networks that encode them will be very shallow. The goal of the retelling is not to reproduce the state-of-the-art result, but to visualize the ideas of the article so that you can start playing with them right in your notebook.

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch import optim

from torch import linalg as LA
from torch.autograd import grad
from torch.autograd.functional import hessian as hess
from torch.autograd.functional import jacobian as jac

---
## Signed Distance Function (SDF)

So let's say we have some *shape* in 3D or 2D. How can it be represented in a computer-readable form? Many ways have been devised, for example:

+ Point clouds: store many points sampled from the surface of the shape (such data often comes from sensors)
+ Meshes: store an approximation of the shape's surface by polygons (most often triangles)
+ Voxels: divide space into small cubes and, for each cube, store whether it is occupied by the shape or not
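To make two of these representations concrete, here is a minimal NumPy sketch for a sphere of radius 1 centred at the origin. The sphere, sample count, grid extent and resolution are arbitrary choices for illustration, and the function names are hypothetical. The point cloud is a set of surface samples, made by normalising Gaussian random vectors, which gives uniformly distributed directions. The voxel grid is a boolean occupancy array. Meshes need a triangulation algorithm and are left out here.

```python
import numpy as np

def sphere_point_cloud(n_points, radius=1.0, seed=0):
    """Sample n_points on the sphere surface by normalising Gaussian vectors."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=(n_points, 3))
    return radius * v / np.linalg.norm(v, axis=1, keepdims=True)

def sphere_voxels(resolution, radius=1.0, extent=1.5):
    """Boolean occupancy grid over [-extent, extent]^3: True where the voxel centre is inside."""
    step = 2 * extent / resolution
    c = -extent + step * (np.arange(resolution) + 0.5)   # voxel centres along one axis
    x, y, z = np.meshgrid(c, c, c, indexing="ij")
    return x**2 + y**2 + z**2 <= radius**2

cloud = sphere_point_cloud(1000)    # (1000, 3) array of surface points
voxels = sphere_voxels(32)          # (32, 32, 32) boolean array
print(cloud.shape, voxels.shape, voxels.sum())
```

The voxel count multiplied by the volume of one voxel approximates the sphere's volume $\frac{4}{3}\pi r^3$. The approximation improves as the resolution grows, at a memory cost of resolution³.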

The article discusses another way: encoding a shape with a *signed distance function* (SDF). For a point $x$ outside the shape, its signed distance function $sdf(x)$ is the distance $\rho(x, S)$ from this point to the surface $S$ of the shape. For a point inside, it is $-\rho(x, S)$, i.e., the distance to the surface taken with a minus sign. On the surface itself $sdf(x) = 0$, so the surface is the zero level set of the function.

$$sdf(x) = \begin{cases}
  \rho(x, S),  & \text{if } x \text{ is outside} \\
  -\rho(x, S), & \text{if } x \text{ is inside}
\end{cases}$$
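The definition can be checked numerically. Below is a minimal sketch that assumes a circle of radius $r$ in 2D. Its surface $S$ is approximated by dense samples, with an arbitrarily chosen sample count, and the helper names are hypothetical. Computing $\pm\rho(x, S)$ by brute force should agree with the closed form $|x| - r$ for a circle.

```python
import numpy as np

def brute_force_sdf(x, surface_points, is_inside):
    """Signed distance by definition: rho(x, S) is approximated by the distance
    to the nearest sample of S, and the sign is negative when x is inside."""
    rho = np.min(np.linalg.norm(surface_points - x, axis=1))
    return -rho if is_inside(x) else rho

r = 1.0
# dense samples of the surface S of a circle of radius r
theta = np.linspace(0.0, 2.0 * np.pi, 10_000, endpoint=False)
circle = np.stack([r * np.cos(theta), r * np.sin(theta)], axis=1)

def inside(x):
    return np.linalg.norm(x) < r

for x in [np.array([3.0, 4.0]), np.array([0.3, 0.4]), np.array([0.0, 0.0])]:
    exact = np.linalg.norm(x) - r      # closed form for a circle: |x| - r
    print(x, brute_force_sdf(x, circle, inside), exact)
```

Brute force costs one distance per surface sample for every query point. This cost is why closed-form SDFs, like the ones coded below, are convenient.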

### Geometry and Shape Classes

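Let's code some example shapes. The formulas are adapted from Inigo Quilez's distance-function articles (2D: https://iquilezles.org/articles/distfunctions2d/, 3D: https://iquilezles.org/articles/distfunctions/). The classes below implement 3D shapes such as spheres, tori, boxes, cylinders and cones, each centred at the origin.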
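For the 2D polygons covered in the first article linked above, one simple approach is the following. The unsigned distance is the distance to the nearest edge, which is the job of `Geometry.distance_from_segment_to_point` below. The sign is then flipped when the point lies inside. The minimal NumPy sketch below assumes the even-odd (ray-casting) rule as the inside test. All function names are hypothetical, and this is an illustration rather than the formula from the article.

```python
import numpy as np

def segment_distance(a, b, p):
    """Euclidean distance from point p to segment [a, b] (2D numpy arrays)."""
    ab, ap = b - a, p - a
    denom = np.dot(ab, ab)
    # parameter of p's projection onto the line through a and b, clamped to the segment
    t = 0.0 if denom == 0.0 else np.clip(np.dot(ap, ab) / denom, 0.0, 1.0)
    return np.linalg.norm(ap - t * ab)

def point_in_polygon(vertices, p):
    """Even-odd rule: count edge crossings of a horizontal ray from p to the right."""
    inside = False
    n = len(vertices)
    for i in range(n):
        a, b = vertices[i], vertices[(i + 1) % n]
        if (a[1] > p[1]) != (b[1] > p[1]):   # edge straddles the ray's height
            x_cross = a[0] + (p[1] - a[1]) * (b[0] - a[0]) / (b[1] - a[1])
            if p[0] < x_cross:
                inside = not inside
    return inside

def polygon_sdf(vertices, p):
    """Distance to the nearest edge, negated inside the polygon."""
    n = len(vertices)
    d = min(segment_distance(vertices[i], vertices[(i + 1) % n], p) for i in range(n))
    return -d if point_in_polygon(vertices, p) else d

square = np.array([[-1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [-1.0, 1.0]])
print(polygon_sdf(square, np.array([0.0, 0.0])))   # centre of the square: -1.0
print(polygon_sdf(square, np.array([3.0, 0.0])))   # two units right of the right edge: 2.0
```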

In [None]:
import numpy as np
import torch
from torch import linalg as LA


## -----------------------------------------------------------------------------
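class Geometry:
    """
    Class-namespace with geometric helper functions.
    """
    EPS = 1e-12

    @staticmethod
    def distance_from_segment_to_point(a, b, p):
        # distance from the 2D point p to the 2D segment [a, b]
        ans = min(torch.linalg.norm(a - p), torch.linalg.norm(b - p))
        if (torch.linalg.norm(a - b) > Geometry.EPS
            and torch.dot(p - a, b - a) > Geometry.EPS
            and torch.dot(p - b, a - b) > Geometry.EPS):
            # projection falls inside the segment: |2D cross product| / |b - a|
            # (torch.cross only accepts 3-component vectors, so compute it by hand)
            cross = (p - a)[0] * (b - a)[1] - (p - a)[1] * (b - a)[0]
            ans = torch.abs(cross) / torch.linalg.norm(b - a)
        return ans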

## -----------------------------------------------------------------------------
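class Shape:
    """Base class: subclasses implement sdf_torch(p) for a 3D point p (a torch tensor)."""
    def sdf(self, p):
        raise NotImplementedError
    def sdf_grad(self, p):
        raise NotImplementedError
    def sdf_torch(self, p):
        raise NotImplementedError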

## -----------------------------------------------------------------------------
class Sphere(Shape):
    def __init__(self, c, r):
        self.c = c
        self.r = r

    def sdf(self, p):
        d = np.linalg.norm(p - self.c)
        return d - self.r

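    def sdf_torch(self, x):
        # like the other shapes below, assumes the sphere is centred at the origin (self.c is not used)
        d = LA.vector_norm(x)
        return d - self.r

    def sdf_grad(self, p):
        # the gradient of |p - c| - r is the unit vector (p - c) / |p - c|
        d = np.linalg.norm(p - self.c)
        dfdx = (p - self.c) / d
        return dfdx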

## -----------------------------------------------------------------------------

# GLSL original (iquilezles.org):
#   vec2 q = vec2(length(p.xz)-t.x,p.y);
#   return length(q)-t.y;
class Torus(Shape):
  def __init__(self, p, t):
    self.p = p    # position (unused: the torus is centred at the origin)
    self.t = t    # t[0] = major radius, t[1] = minor (tube) radius

  # p is a 3D point (p \in R^3); p[0], p[1], p[2] are scalars (\in R^1)
  def sdf_torch(self, p):
    ring = LA.vector_norm(p[[0, 2]]) - self.t[0]   # distance to the ring in the xz-plane, \in R^1
    q = torch.stack([ring, p[1]])                  # q \in R^2
    return LA.vector_norm(q) - self.t[1]

## -----------------------------------------------------------------------------
class CappedTorus(Shape):
  def __init__(self, p, sc, ra, rb):
    self.p = p
    self.sc = sc
    self.ra = ra
    self.rb = rb

  def sdf_torch(self, p):
    p = p.clone()            # avoid modifying the caller's tensor in place
    p[0] = torch.abs(p[0])
    tmp = torch.tensor([p[0], p[1]])
    len = LA.vector_norm(tmp)
    k = torch.dot(tmp,self.sc) if (self.sc[1]*p[0]>self.sc[0]*p[1]) else len
    return torch.sqrt(torch.dot(p,p) + self.ra*self.ra - 2.0*self.ra*k) - self.rb

## -----------------------------------------------------------------------------
class Link(Shape):
  def __init__(self, p, le, r1, r2):
    self.p = p
    self.le = le
    self.r1 = r1
    self.r2 = r2

  def sdf_torch(self, p):
    q = (p[0], max(torch.abs(p[1])-self.le,0.0), p[2])
    tmp = torch.tensor([q[0], q[1]])
    len = LA.vector_norm(tmp)
    a = len - self.r1
    tmp2 = torch.tensor([a, q[2]])
    return LA.vector_norm(tmp2) - self.r2

## -----------------------------------------------------------------------------
class BoxFrame(Shape):
  def __init__(self, p, e, b):
    self.p = p
    self.e = e
    self.b = b

  def sdf_torch(self, p):
    p = torch.abs(p) - self.b
    q = torch.abs(p + self.e) - self.e
    def arm(x, y, z):
      # GLSL: length(max(vec3(x,y,z),0.0)) + min(max(x,max(y,z)),0.0)
      v = torch.stack([x, y, z])
      return LA.vector_norm(torch.clamp(v, min=0.0)) + torch.clamp(torch.max(v), max=0.0)
    return torch.min(torch.stack([arm(p[0], q[1], q[2]),
                                  arm(q[0], p[1], q[2]),
                                  arm(q[0], q[1], p[2])]))


## -----------------------------------------------------------------------------
class Plane(Shape):
  def __init__(self, p, h, n):
    self.p = p
    self.h = h
    self.n = n


  def sdf_torch(self, p):
    return torch.dot(p, self.n) + self.h

## -----------------------------------------------------------------------------
class CutHollowSphere(Shape):
  def __init__(self, p, r, h, t):
    self.p = p
    self.r = r
    self.h = h
    self.t = t

  def sdf_torch(self, p):
    w = torch.sqrt(  torch.tensor(self.r*self.r-self.h*self.h))
    pxz = torch.tensor((p[0],p[2]))
    q = torch.tensor((LA.vector_norm(pxz), p[1]))

    # the shell thickness t is subtracted in both branches, as in the GLSL original
    if (self.h*q[0]<w*q[1]):
      return LA.vector_norm(q - torch.tensor((w,self.h))) - self.t
    else:
      return torch.abs(LA.vector_norm(q) - self.r) - self.t

## -----------------------------------------------------------------------------
class Octahedron(Shape):
  def __init__(self, p, s):
    self.p = p
    self.s = s

  def sdf_torch(self, p):
    p = torch.abs(p)
    return (p[0]+p[1]+p[2]-self.s)*0.57735027

## -----------------------------------------------------------------------------
class TriangularPrism(Shape):
  def __init__(self, p, h):
    self.p = p
    self.h = h

  def sdf_torch(self, p):
    q = abs(p)
    return max(q[2]-self.h[1],max(q[0]*0.866025+p[1]*0.5,-p[1])-self.h[0]*0.5)

## -----------------------------------------------------------------------------
class InfiniteCylinder(Shape):
  def __init__(self, p, c):
    self.p = p
    self.c = c

  def sdf_torch(self, p):
    tmp1 = torch.tensor([p[0], p[2]])
    tmp2 = torch.tensor([self.c[0], self.c[1]])
    return LA.vector_norm(tmp1-tmp2)-self.c[2]

## -----------------------------------------------------------------------------
class Cone(Shape):
  def __init__(self, p, c, h):
    self.p = p
    self.c = c
    self.h = h

  def sdf_torch(self, p):
    tmp1 = torch.tensor([p[0], p[2]])
    q = LA.vector_norm(tmp1)
    tmp2 = torch.tensor([self.c[0], self.c[1]])
    tmp3 = torch.tensor([q, p[1]])
    a = -self.h-p[1]
    return max(torch.dot(tmp2, tmp3), a)

## -----------------------------------------------------------------------------
class Rectangle(Shape):
  def __init__(self, p, b):
    self.p = p
    self.b = b

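  def sdf_torch(self, p):
    # axis-aligned box with half-sizes b: length(max(q,0)) + min(max(q.x,max(q.y,q.z)),0)
    q = torch.abs(p) - self.b
    return LA.vector_norm(torch.clamp(q, min=0.0)) + torch.clamp(torch.max(q), max=0.0)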

## -----------------------------------------------------------------------------
class VerticalCapsule(Shape):
  def __init__(self, p, h, r):
    self.p = p
    self.h = h
    self.r = r

  def sdf_torch(self, p):
    p = p.clone()                            # avoid modifying the caller's tensor in place
    p[1] -= torch.clamp(p[1], 0.0, self.h)
    return LA.vector_norm(p) - self.r

## -----------------------------------------------------------------------------
class Ellipsoid(Shape):
  def __init__(self, p, r):
    self.p = p
    self.r = r

  def sdf_torch(self, p):
    k0 = LA.vector_norm(p/self.r)
    k1 = LA.vector_norm(p/(self.r*self.r))
    return k0*(k0-1.0)/k1

## -----------------------------------------------------------------------------
class CappedCylinder(Shape):
  def __init__(self, p, h, r):
    self.p = p
    self.h = h
    self.r = r

  def sdf_torch(self, p):
    tmp1 = torch.tensor([p[0], p[2]])
    len = LA.vector_norm(tmp1)
    a = torch.tensor([len, p[1]])
    b = torch.tensor([self.r, self.h])     # radius pairs with length(p.xz), height with p.y
    d = torch.abs(a) - b
    return LA.vector_norm(torch.clamp(d, min=0.0)) + torch.clamp(torch.max(d), max=0.0)

## -----------------------------------------------------------------------------
class SolidAngle(Shape):
  def __init__(self, p, c, ra):
    self.p = p
    self.c = c
    self.ra = ra


  def sdf_torch(self, p):
    tmp1 = torch.tensor([p[0], p[2]])
    len = LA.vector_norm(tmp1)
    q = torch.tensor([len, p[1]])
    l = LA.vector_norm(q) - self.ra
    m = LA.vector_norm(q - self.c*torch.clamp(torch.dot(q,self.c),0.0,self.ra))
    return max(l,m*torch.sign(self.c[1]*q[0]-self.c[0]*q[1]))

## -----------------------------------------------------------------------------
class Rhombus(Shape):
  def __init__(self, p, h, ra, rb):
    self.p = p
    self.h = h
    self.ra = ra
    self.rb = rb

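  def sdf_torch(self, p):
    # NOTE: despite the class name, this evaluates the rounded-cylinder formula
    # (the same as RoundedCylinder below), not the rhombus SDF from iquilezles.org
    tmp1 = torch.tensor([p[0], p[2]])
    len = LA.vector_norm(tmp1)
    a = len-2.0*self.ra+self.rb
    b = torch.abs(p[1]) - self.h
    d = torch.stack([a, b])
    return torch.clamp(torch.max(d), max=0.0) + LA.vector_norm(torch.clamp(d, min=0.0)) - self.rb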

## -----------------------------------------------------------------------------
class Capsule(Shape):
  def __init__(self, p, a, b, r):
    self.p = p
    self.a = a
    self.b = b
    self.r = r

  def sdf_torch(self, p):
    pa = p - self.a
    ba = self.b - self.a
    h = torch.clamp(torch.dot(pa,ba)/torch.dot(ba,ba), 0.0, 1.0 )
    return LA.vector_norm(pa - ba*h) - self.r

## -----------------------------------------------------------------------------
class RoundedCylinder(Shape):
  def __init__(self, p, ra, rb, h):
    self.p = p
    self.ra = ra
    self.rb = rb
    self.h = h

  def sdf_torch(self, p):
    tmp = LA.vector_norm(torch.tensor([p[0], p[2]]))
    a = tmp - 2.0 * self.ra + self.rb
    b = torch.abs(p[1]) - self.h
    d = torch.stack([a, b])
    return torch.clamp(torch.max(d), max=0.0) + LA.vector_norm(torch.clamp(d, min=0.0)) - self.rb

## -----------------------------------------------------------------------------
class CutSphere(Shape):
  def __init__(self, p, h, r):
    self.p = p
    self.h = float(h)   # height of the cutting plane (float() also accepts 1-element tensors)
    self.r = float(r)   # sphere radius, must satisfy |h| <= r

  def sdf_torch(self, p):
    w = (self.r*self.r - self.h*self.h) ** 0.5     # radius of the cut disc
    q = torch.stack([LA.vector_norm(p[[0, 2]]), p[1]])
    s = max((self.h-self.r) * q[0] * q[0] + w * w * (self.h + self.r - 2.0 * q[1]),
            self.h * q[0] - w * q[1])
    if (s<0.0):
      return LA.vector_norm(q)-self.r
    elif(q[0]<w):
      return self.h - q[1]
    else:
      return LA.vector_norm(q - torch.tensor([w, self.h]))

## -----------------------------------------------------------------------------
class DeathStar(Shape):
  def __init__(self, p2, ra, rb, d):
    self.p2 = p2
    self.ra = ra
    self.rb = rb
    self.d = d


  def sdf_torch(self, p2):
    a = (self.ra * self.ra - self.rb * self.rb + self.d * self.d) / (2.0 * self.d)
    b = torch.sqrt(max(self.ra * self.ra - a * a, 0.0))
    tmp1 = torch.tensor([p2[1], p2[2]])
    p = torch.tensor([p2[0], LA.vector_norm(tmp1)])
    tmp2 = torch.tensor([self.d, 0.0])
    if (p[0] * b - p[1] * a > self.d * max(b - p[1],0.0)):
      return LA.vector_norm(p - torch.tensor([a,b]))
    else:
      return max((LA.vector_norm(p) - self.ra), -(LA.vector_norm(p - tmp2) -self.rb))

## -----------------------------------------------------------------------------
class HexagonalPrism(Shape):
  def __init__(self, p, h):
    self.p = p
    self.h = h


  def sdf_torch(self, p):
    kx, ky, kz = -0.8660254, 0.5, 0.57735
    hx, hy = float(self.h[0]), float(self.h[1])   # hexagon size, half-height
    p = torch.abs(p)
    # reflect p.xy into the fundamental sector (GLSL: p.xy -= 2.0*min(dot(k.xy,p.xy),0.0)*k.xy)
    k_xy = torch.tensor([kx, ky], dtype=p.dtype)
    xy = p[:2] - 2.0 * torch.clamp(torch.dot(k_xy, p[:2]), max=0.0) * k_xy
    edge = torch.stack([torch.clamp(xy[0], -kz*hx, kz*hx), torch.tensor(hx, dtype=p.dtype)])
    d = torch.stack([LA.vector_norm(xy - edge) * torch.sign(xy[1] - hx), p[2] - hy])
    return torch.clamp(torch.max(d), max=0.0) + LA.vector_norm(torch.clamp(d, min=0.0))

## -----------------------------------------------------------------------------
class RoundCone(Shape):
  def __init__(self, p, r1, r2, h):
    self.p = p
    self.r1 = r1
    self.r2 = r2
    self.h = h


  def sdf_torch(self, p):
    # shape-only constants (float() also accepts 1-element tensors)
    h = float(self.h)
    b = (float(self.r1) - float(self.r2)) / h
    a = (1.0 - b*b) ** 0.5
    q = torch.stack([LA.vector_norm(p[[0, 2]]), p[1]])
    k = torch.dot(q, torch.tensor([-b, a], dtype=q.dtype))
    if (k <0.0):
      return LA.vector_norm(q) - self.r1
    elif (k > a * h):
      return LA.vector_norm(q - torch.tensor([0.0, h], dtype=q.dtype)) - self.r2
    else:
      return torch.dot(q, torch.tensor([a, b], dtype=q.dtype)) - self.r1
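An exact SDF satisfies $|\nabla sdf(x)| = 1$ wherever the gradient exists (the eikonal property). For a sphere centred at $c$, the gradient is $(x - c)/|x - c|$, which is what `Sphere.sdf_grad` computes. The minimal sketch below checks this with `torch.autograd.grad` (imported at the top of the notebook) for an origin-centred sphere; the test point and radius are arbitrary. Some formulas from iquilezles.org are only bounds rather than exact distances, for example the Octahedron, Cone and Ellipsoid versions used above, so they need not pass this check. Several `sdf_torch` methods above also rebuild vectors with `torch.tensor([...])`. That call copies values and cuts the autograd graph, whereas `torch.stack` keeps it.

```python
import torch

def sphere_sdf(x, r=2.0):
    # SDF of a sphere of radius r centred at the origin
    return torch.linalg.vector_norm(x) - r

x = torch.tensor([1.0, -2.0, 0.5], requires_grad=True)
(g,) = torch.autograd.grad(sphere_sdf(x), x)   # gradient of the scalar SDF w.r.t. x

print(g)                                   # equals x / |x|
print(torch.linalg.vector_norm(g).item())  # 1.0 up to float rounding
```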

### SDF Plot Function
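Let's write a function that visualizes a signed distance function in 3D. It samples the SDF on a regular grid and draws either the sampled points coloured by their SDF value, or the isosurface $sdf(x) = \text{iso}$.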

In [None]:
import plotly.express as px
import plotly.graph_objects as go

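def plot_sdf_3d(sdf_func, density=20, iso=0, volume=False):

    # sample the SDF on a density^3 grid over the cube [-3, 3]^3
    dx = torch.linspace(-3, 3, density)
    # indexing='ij' gives x[i,j,k] = dx[i], y[i,j,k] = dx[j], z[i,j,k] = dx[k],
    # matching the loop order used to fill `sdf` below
    x, y, z = torch.meshgrid(dx, dx, dx, indexing='ij')
    # float(...) also accepts SDFs that return 1-element tensors
    sdf = torch.tensor([[[float(sdf_func(torch.tensor([x_, y_, z_])))
                for z_ in dx]
                for y_ in dx]
                for x_ in dx])

    # plotly expects array-like inputs: convert the tensors to flat NumPy arrays
    xs, ys, zs, vals = (t.flatten().numpy() for t in (x, y, z, sdf))

    if not volume:
        # scatter plot of the grid points, coloured by SDF value
        fig = px.scatter_3d(
            x=xs,
            y=ys,
            z=zs,
            color=vals,
            size=np.abs(1.0 - vals),
            opacity=0.1,
            range_x=(-3, 3),
            range_y=(-3, 3),
            range_z=(-3, 3),
            color_continuous_scale='RdBu',
            color_continuous_midpoint=0.0
            )
    else:
        # draw the level set sdf = iso (iso = 0 is the shape's surface);
        # go.Volume(...) with the same arguments is an alternative renderer
        fig = go.Figure(data=go.Isosurface(
            x=xs,
            y=ys,
            z=zs,
            value=vals,
            isomin=iso,
            isomax=iso,
            opacity=0.85,     # below 1 so that inner surfaces remain visible
            surface_count=3,  # number of surfaces drawn between isomin and isomax
            colorscale='RdBu',
            ))

    # tight layout
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
    return fig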
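`plot_sdf_3d` calls `sdf_func` once per grid point from a Python loop, which is density³ calls. If an SDF is written with array operations, the whole grid can be evaluated in a single call. The minimal NumPy sketch below illustrates this for the torus. `torus_sdf` is a hypothetical array version of the same formula as `Torus.sdf_torch`, and its default radii are arbitrary. The resulting arrays could be flattened and passed to `go.Isosurface` in the same way `plot_sdf_3d` does.

```python
import numpy as np

def torus_sdf(x, y, z, major=1.5, minor=0.5):
    """Torus SDF (same formula as Torus.sdf_torch), evaluated elementwise on arrays."""
    q = np.sqrt(x**2 + z**2) - major      # distance to the ring in the xz-plane
    return np.sqrt(q**2 + y**2) - minor

def sdf_grid(sdf, density=20, extent=3.0):
    """Evaluate sdf on a density^3 grid over [-extent, extent]^3 in one call."""
    d = np.linspace(-extent, extent, density)
    x, y, z = np.meshgrid(d, d, d, indexing="ij")
    return x, y, z, sdf(x, y, z)

x, y, z, values = sdf_grid(torus_sdf, density=64)
print(values.shape, (values < 0).mean())   # grid shape and fraction of samples inside
```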

## SDF Examples

In [None]:
# t = [1.5,1.5]
# shape = Torus(torch.tensor([0, 0]), t)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
# h = 1.25
# ra = 0.5
# rb = 0.5
# shape = RoundedCylinder(torch.tensor([0, 0]), ra, rb, h)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
# h = torch.tensor([1])
# r1 = 1.5
# r2 = 1.5
# shape = RoundCone(torch.tensor([0, 0]), r1, r2, h)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
# h = torch.tensor([1])
# r = 2.5
# shape = CutSphere(torch.tensor([0, 0]), h, r)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
# a = torch.tensor([1])
# b = torch.tensor([2, 0, 2], dtype=torch.float32)
# r = 1
# shape = Capsule(torch.tensor([0, 0]), a, b, r)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
# h = 2.0
# ra = 0.0
# rb = 0.0
# shape = Rhombus(torch.tensor([0, 0]), h, ra, rb)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
# c = torch.tensor([8,9], dtype=torch.float32)
# ra = 3.0
# shape = SolidAngle(torch.tensor([0, 0]), c, ra)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
# p2 = torch.tensor([2])
# ra = torch.tensor([2])
# rb = 1
# d = 1
# shape = DeathStar(torch.tensor([0, 0]), ra, rb, d)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()
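The DeathStar SDF above finishes with `max(length(p) - ra, -(length(p - (d, 0)) - rb))`, which describes a sphere with a second sphere subtracted. This is one of the standard min/max combinations of SDFs described on iquilezles.org. The minimal NumPy sketch below uses two spheres laid out like the DeathStar example (ra = 2, rb = 1, d = 1); the function names are hypothetical. These combinations always get the sign (inside/outside) right, but in general they only bound the true distance, so the result is not always an exact SDF.

```python
import numpy as np

def sphere_sdf(p, center, radius):
    return np.linalg.norm(p - center, axis=-1) - radius

def union(d1, d2):        return np.minimum(d1, d2)   # inside either shape
def intersection(d1, d2): return np.maximum(d1, d2)   # inside both shapes
def subtraction(d1, d2):  return np.maximum(d1, -d2)  # inside shape 1 but not shape 2

p = np.array([[-1.0, 0.0, 0.0],   # inside the big sphere, away from the bite
              [1.5, 0.0, 0.0],    # inside both spheres, so carved away
              [0.0, 0.0, 3.0]])   # outside both spheres
big = sphere_sdf(p, np.array([0.0, 0.0, 0.0]), 2.0)    # ra = 2
bite = sphere_sdf(p, np.array([1.0, 0.0, 0.0]), 1.0)   # rb = 1, centre at distance d = 1
print(subtraction(big, bite))
```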

In [None]:
# h = 3.0
# r = 4.0
# shape = CappedCylinder(torch.tensor([0, 0]), h, r)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
#r = 1.5
#shape = Ellipsoid(torch.tensor([0, 0]), r)

#fig = plot_sdf_3d(shape.sdf_torch, volume = True)
#fig.show()

In [None]:
#h = 2.0
#r = 1.0
#shape = VerticalCapsule(torch.tensor([0, 0]), h, r)

#fig = plot_sdf_3d(shape.sdf_torch, volume = True)
#fig.show()

In [None]:
# b = 5.75
# shape = Rectangle(torch.tensor([0, 0]), b)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()
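The box formula from iquilezles.org that `Rectangle.sdf_torch` is based on is $\lVert\max(q, 0)\rVert + \min(\max(q_x, q_y, q_z), 0)$ with $q = |p| - b$. The first term is the distance when outside the box, and the second is the negative depth when inside. The minimal NumPy sketch below works through a few points for a cube with half-size 1; the function name is hypothetical. With a scalar `b`, as in the example above, broadcasting treats all three half-sizes as equal.

```python
import numpy as np

def box_sdf(p, b):
    """SDF of an axis-aligned box with half-sizes b, centred at the origin."""
    q = np.abs(p) - b
    outside = np.linalg.norm(np.maximum(q, 0.0))   # positive only outside the box
    inside = min(np.max(q), 0.0)                   # negative only inside the box
    return outside + inside

b = np.array([1.0, 1.0, 1.0])
print(box_sdf(np.array([2.0, 0.0, 0.0]), b))   # 1.0: one unit beyond a face
print(box_sdf(np.array([0.0, 0.0, 0.0]), b))   # -1.0: centre, one unit from every face
print(box_sdf(np.array([2.0, 2.0, 2.0]), b))   # sqrt(3): the nearest point is a corner
```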

In [None]:
# c = torch.tensor([1, 1], dtype=torch.float32)
# h = 8.0
# shape = Cone(torch.tensor([0, 0]), c, h)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
#c = torch.tensor([0, 1, 1])
#shape = InfiniteCylinder(torch.tensor([0, 0]), c)

#fig = plot_sdf_3d(shape.sdf_torch, volume = True)
#fig.show()

In [None]:
# s = 2.5
# shape = Octahedron(torch.tensor([0, 0]), s)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
#r = 1.0
#h = 1.0
#t = 1.0
#shape = CutHollowSphere(torch.tensor([0, 0]), r, h, t)

#fig = plot_sdf_3d(shape.sdf_torch, volume = True)
#fig.show()

In [None]:
#h = 0.5
#n = torch.tensor([1, 2, 4], dtype=torch.float32)
#shape = Plane(torch.tensor([0, 0]), h, n)

#fig = plot_sdf_3d(shape.sdf_torch, volume = True)
#fig.show()

In [None]:
# le = 1
# r1 = 1
# r2 = 1
# shape = Link(torch.tensor([0, 0]), le, r1, r2)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
# sc = torch.tensor([1,2], dtype=torch.float32)
# ra = 0.75
# rb = 0.75
# shape = CappedTorus(torch.tensor([0, 0]), sc, ra, rb)

# fig = plot_sdf_3d(shape.sdf_torch, volume = True)
# fig.show()

In [None]:
# e = 1.5
# b = 4.0
# shape = BoxFrame(torch.tensor([0, 0]), e, b)

# fig = plot_sdf_3d(shape.sdf_torch, volume= True)
# fig.show()

In [None]:
#h = [3,2]
#shape = TriangularPrism(torch.tensor([0, 0]), h)

#fig = plot_sdf_3d(shape.sdf_torch, volume = True)
#fig.show()

In [None]:
# 1. create a shape (sphere)
r = 2.0
shape = Sphere(torch.tensor([0.0, 0.0, 0.0]), r)

# 2. 3d draw the shape by passing the sdf_torch function to plot_sdf_3d
fig = plot_sdf_3d(shape.sdf_torch, volume = True)
fig.show()
