### Imports & Settings

In [95]:
import torch

from torch import nn, optim

from torch.utils.data import DataLoader, TensorDataset

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.colors

from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

import cv2

import copy
import time

In [96]:
# Fix the global PyTorch RNG seed so random operations (e.g. layer initialisation) are reproducible.
torch.manual_seed(123)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

print(torch.__version__)
print(device)

2.2.0+cu121
cuda


### Modules

In [97]:
class SineActivation(nn.Module):
    '''
    Element-wise sine activation, ``y = sin(x)``.

    Wrapped as an ``nn.Module`` so it can be instantiated and placed inside
    ``nn.Sequential`` like any built-in activation. It has no parameters.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.sin()

In [98]:
class FullyConnected(nn.Module):
    '''
    Builds a fully connected network (MLP) from a list of layer sizes.

    Every consecutive pair of sizes becomes an ``nn.Linear`` layer. A freshly
    constructed activation follows every linear layer except the last, so the
    output layer is purely linear.

    Args:
        fc_sizes (list): Layer sizes, input size first and output size last.
            Must contain at least two entries.
        activation (type): Zero-argument callable (typically an ``nn.Module``
            subclass) called once per hidden layer to create its activation.
            Defaults to SineActivation.

    Raises:
        ValueError: If ``fc_sizes`` has fewer than two entries.
    '''
    def __init__(self, fc_sizes: list, activation=SineActivation):
        super().__init__()

        if len(fc_sizes) < 2:
            raise ValueError(
                f'fc_sizes needs at least an input and an output size, got {fc_sizes!r}')

        layers = []
        # Hidden layers: every (in, out) pair except the final one, each followed by an activation.
        for size_in, size_out in zip(fc_sizes[:-2], fc_sizes[1:-1]):
            layers.append(nn.Linear(size_in, size_out))
            layers.append(activation())

        # Output layer: linear only, no activation.
        layers.append(nn.Linear(fc_sizes[-2], fc_sizes[-1]))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        '''Apply the network; the last dimension of ``x`` must equal ``fc_sizes[0]``.'''
        return self.model(x)

### Utils

In [99]:
def timer(func):
    '''
    Decorator that prints the wall-clock duration of each call to the decorated function.

    Elapsed time is measured with ``time.perf_counter`` (monotonic and high
    resolution, unlike ``time.time``). The wrapper keeps ``func``'s metadata
    (``__name__``, ``__doc__``, ``__wrapped__``) via ``functools.wraps``.
    If ``func`` raises, the exception propagates and nothing is printed.

    Args:
        func (callable): Function to measure.

    Returns:
        callable: Wrapped function that returns ``func``'s result unchanged.
    '''
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f'Executed in {elapsed} seconds.')
        return result
    return wrapper

In [100]:
def get_normal_field(f):
    '''
    Build a function that returns unit normals of the implicit surface ``f(x) = 0``.

    Args:
        f (callable): Differentiable field mapping points (coordinates along the
            last dimension) to one value per point.

    Returns:
        callable: ``normal_field(x)`` computing ``grad f(x) / |grad f(x)|`` per point.
    '''
    def normal_field(x):
        '''
        Compute the normal to the implicit surface represented by the field ``f``.

        Args:
            x (torch.Tensor): A single point of shape ``(d,)`` or a batch of points
                of shape ``(..., d)``.

        Returns:
            torch.Tensor: Same shape as ``x``. Each point's gradient is normalized
            to unit length along the last dimension. A zero gradient yields a
            zero vector instead of NaN.
        '''
        # NOTE: this flags the caller's tensor as requiring grad (in place).
        x = x.requires_grad_(True)

        f_value = f(x)
        # create_graph=True keeps the normals differentiable so they can be used inside a loss.
        grad = torch.autograd.grad(
            outputs=f_value,
            inputs=x,
            grad_outputs=torch.ones_like(f_value),
            create_graph=True)[0]

        # Normalize each point separately (last dim), not by the norm of the whole batch;
        # the eps clamp inside normalize avoids 0/0 where the gradient vanishes.
        return torch.nn.functional.normalize(grad, dim=-1)
    return normal_field


In [103]:
def get_sphere_sdf(r):
    '''
    Build the exact signed distance function of a sphere (a circle in 2D) of radius ``r``
    centered at the origin.

    Args:
        r (float): Sphere radius.

    Returns:
        callable: ``sphere_sdf(x)`` as documented below.
    '''
    def sphere_sdf(x):
        '''
        Computes the signed distance from the point(s) ``x`` to the sphere's surface.

        Args:
            x (torch.Tensor): A point of shape ``(d,)`` or a batch of shape ``(..., d)``;
                the last dimension holds the coordinates.

        Returns:
            torch.Tensor: Shape ``x.shape[:-1]``. Negative inside, zero on the
            surface, positive outside.
        '''
        # Use the Euclidean norm (not |x|^2 - r^2) so the values are true distances
        # with a unit-length gradient, like the other SDFs in this file.
        return torch.linalg.vector_norm(x, dim=-1) - r
    return sphere_sdf

In [104]:
def get_star_sdf(r, n, m):
    '''
    Build a 2D signed distance function for an ``n``-pointed star centered at the origin.

    The construction folds each point into a single angular sector of width
    ``2*pi/n`` and measures the distance to one star edge there. With the
    angle measured by ``atan2(y, x)``, one tip lies on the positive x axis at
    distance ``r``.

    Args:
        r (float): Distance from the origin to each tip of the star.
        n (int): Number of points (tips).
        m (float): Controls the inner angle of the star. ``pi / m`` is the edge
            direction angle; ``m = 2`` produces a regular n-gon. Presumably meant to
            lie in ``[2, n]`` — TODO confirm intended range.

    Returns:
        callable: ``star_sdf(x)`` mapping points of shape ``(..., 2)`` to signed
        distances of shape ``(...)`` (negative inside).
    '''
    import math

    # Shape constants depend only on (r, n, m): compute them once, on the host.
    an = math.pi / n
    en = math.pi / m
    acs = (math.cos(an), math.sin(an))
    ecs = (math.cos(en), math.sin(en))
    edge_limit = r * acs[1] / ecs[1]

    def star_sdf(x):
        # Match the input's dtype/device instead of forcing float32.
        acs_t = torch.tensor(acs, dtype=x.dtype, device=x.device)
        ecs_t = torch.tensor(ecs, dtype=x.dtype, device=x.device)

        # Fold the point into the first sector (by symmetry) using its polar angle.
        bn = torch.remainder(torch.atan2(x[..., 1], x[..., 0]), 2.0 * an) - an
        radius = torch.linalg.vector_norm(x, dim=-1, keepdim=True)
        p = radius * torch.stack((torch.cos(bn), torch.abs(torch.sin(bn))), dim=-1)

        # Distance to the edge segment starting at the tip, clamped to the segment's extent.
        p = p - r * acs_t
        dot_product = torch.sum(p * ecs_t, dim=-1, keepdim=True)
        p = p + ecs_t * torch.clamp(-dot_product, 0.0, edge_limit)

        return torch.linalg.vector_norm(p, dim=-1) * torch.sign(p[..., 0])
    return star_sdf

In [105]:
def make_annular(sdf_func, r):
    '''
    Wrap a signed distance function so that its surface becomes a shell ("onion").

    The returned function evaluates ``|sdf_func(pos)| - r``. It is negative wherever
    the original value lies strictly within ``r`` of zero, so the zero-level set of
    ``sdf_func`` is thickened into a band. The band has total width ``2 * r`` when
    ``sdf_func`` returns true distances.

    Args:
        sdf_func (callable): Function mapping points to tensor-valued SDF values.
        r (float): Half-thickness of the shell.

    Returns:
        callable: ``onion_sdf(pos)``.
    '''
    def onion_sdf(pos):
        return torch.abs(sdf_func(pos)) - r
    return onion_sdf

In [106]:
def create_grid(res=100):
    '''
    Sample a regular ``res x res`` grid over the square [-1, 1] x [-1, 1].

    Args:
        res (int): Number of samples along each axis, default is 100.

    Returns:
        grid (torch.Tensor): ``(res*res, 2)`` tensor of (x, y) points, moved to the
            module-level ``device``.
        x (torch.Tensor): ``(res, res)`` meshgrid of x coordinates ('xy' indexing), on CPU.
        y (torch.Tensor): ``(res, res)`` meshgrid of y coordinates ('xy' indexing), on CPU.
    '''
    axis = torch.linspace(-1, 1, res)
    xs, ys = torch.meshgrid(axis, axis, indexing='xy')
    points = torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=1)
    return points.to(device), xs, ys

In [107]:
def plot_sdf(net, res=100, contour_lines=15, colormap='RdBu'):
    '''
    Visualizes the SDF generated by a neural network over a fixed range from -1 to 1 for both x and y axes.

    Draws filled contours with a colorbar, plus the zero-level set as a thick black line.

    Args:
        net (torch.nn.Module): Model mapping the ``(res*res, 2)`` grid of points to
            ``res*res`` values (its output is reshaped to ``(res, res)``).
        res (int): Number of points along each axis (res x res grid).
        contour_lines (int): Number of contour levels in the filled plot.
        colormap (str): Colormap used for visualization.
    '''
    grid, x, y = create_grid(res)

    with torch.no_grad():
        z = net(grid).view(res, res).cpu()

    xs, ys, zs = x.numpy(), y.numpy(), z.numpy()

    # Symmetric level range so that 0 (the surface) sits at the colormap's midpoint.
    max_abs_value = z.abs().max().item()
    if max_abs_value == 0.0:
        # An identically-zero field would give identical levels, which contourf rejects.
        max_abs_value = 1.0
    contour_levels = np.linspace(-max_abs_value, max_abs_value, contour_lines)

    fig, ax = plt.subplots()

    # `origin` is omitted: contourf ignores it when X and Y are given.
    c = ax.contourf(xs, ys, zs, levels=contour_levels, cmap=colormap)
    fig.colorbar(c, ax=ax)

    ax.contour(xs, ys, zs, levels=[0], colors='black', linewidths=2)
    ax.set_aspect('equal', 'box')

    plt.show()

In [108]:
def generate_sdf_video(net, total_time_steps=100, res=200, fps=10, filename='sdf_evolution.mp4'):
    '''
    Render the evolution of a time-dependent field ``net(x, y, t)`` to an MP4 video.

    For each of ``total_time_steps`` evenly spaced times ``t`` in [0, 1], the network
    is evaluated on a ``res x res`` grid over [-1, 1]^2, with ``t`` appended as a third
    input column. Each frame shows filled contours and the zero-level set in black.

    Args:
        net (torch.nn.Module): Model mapping ``(res*res, 3)`` inputs ``[x, y, t]`` to
            ``res*res`` values.
        total_time_steps (int): Number of frames.
        res (int): Grid resolution per axis; frames are ``2*res x 2*res`` pixels.
        fps (int): Frame rate of the output video.
        filename (str): Output path. Defaults to 'sdf_evolution.mp4'.

    Raises:
        RuntimeError: If the video writer cannot be opened.
    '''
    from matplotlib.figure import Figure

    frame_size = (2 * res, 2 * res)  # (width, height) in pixels, as cv2 expects
    dpi = 100
    figsize = (frame_size[0] / dpi, frame_size[1] / dpi)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(filename, fourcc, fps, frame_size)
    if not out.isOpened():
        raise RuntimeError(f'Could not open video writer for {filename!r}')

    grid, x, y = create_grid(res)
    xs, ys = x.numpy(), y.numpy()

    try:
        for t in torch.linspace(0, 1, total_time_steps):
            with torch.no_grad():
                # Append t as a constant third column, built directly on the grid's device.
                t_col = torch.full((grid.size(0), 1), t.item(), dtype=grid.dtype, device=grid.device)
                z = net(torch.cat((grid, t_col), dim=1)).view(res, res).cpu()
            zs = z.numpy()

            max_abs_value = z.abs().max().item()
            if max_abs_value == 0.0:
                # Identical levels would make contourf raise.
                max_abs_value = 1.0
            contour_levels = np.linspace(-max_abs_value, max_abs_value, 15)

            # A standalone Figure with an explicit Agg canvas renders at exactly
            # figsize * dpi pixels, whatever pyplot backend is active. It is not
            # registered with pyplot, so no figures accumulate across frames.
            fig = Figure(figsize=figsize, dpi=dpi)
            canvas = FigureCanvas(fig)
            ax = fig.add_subplot()
            ax.contourf(xs, ys, zs, levels=contour_levels, cmap='RdBu')
            ax.contour(xs, ys, zs, levels=[0], colors='black', linewidths=2)
            ax.set_aspect('equal', 'box')
            ax.axis('off')

            canvas.draw()
            # buffer_rgba replaces tostring_rgb, which is deprecated (Matplotlib 3.8) and removed (3.10).
            img = cv2.cvtColor(np.asarray(canvas.buffer_rgba()), cv2.COLOR_RGBA2BGR)
            if (img.shape[1], img.shape[0]) != frame_size:
                # VideoWriter silently drops frames whose size differs from frame_size.
                img = cv2.resize(img, frame_size)
            out.write(img)
    finally:
        # Without release() the MP4 container is never finalized.
        out.release()

In [None]:
# 3 inputs -> eight hidden layers of width 64 -> 1 output.
V_layers = [3, 64, 64, 64, 64, 64, 64, 64, 64, 1]
# Place the model on `device`: create_grid builds its input points there, as it does for `pinn`.
V = FullyConnected(V_layers).to(device)

In [None]:
# 3 inputs -> tapering hidden widths (1024 ... 128) -> 1 output.
pinn_layers = [3, 1024, 1024, 1024, 1024, 512, 512, 256, 128, 1]
pinn = FullyConnected(pinn_layers)
pinn = pinn.to(device)  # nn.Module.to moves parameters in place and returns the module itself