In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.io as pio
from math import sqrt

from typing_extensions import overload

# Make "browser" the default plotly renderer for every fig.show() call in this notebook.
pio.renderers.default = "browser"

In [2]:
# The idea is to develop a 2D toy problem that demonstrates GINN. It must cover the entire workflow, establish some evaluations, and produce concrete, comparable metrics. The choice of models or "data" is arbitrary; however, it must also be easily transferable to VecSet-type workflows, and should ideally be modular.
# Also, let's use neural fields instead of the clouds from step 1.
# Number of samples per axis; SDF builds its N x N evaluation grid from this value.
N = 100
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# def plot_cloud(points, title="Point Cloud"):
#     x, y = zip(*points)
#     fig = go.Figure(data=go.Scatter(x=x, y=y, mode='markers'))
#     fig.update_layout(title=title,
#                       xaxis_title='X',
#                       yaxis_title='Y',
#                       width=600, height=600)
#     fig.show()
# 
# def plot_field(grid_y, grid_x, values, plotly=False):
#     if not plotly:
#         plt.contourf(grid_x.cpu().numpy(), grid_y.cpu().numpy(), values.cpu(), levels=50)
#         plt.colorbar()
#         plt.contour(grid_x.cpu(), grid_y.cpu(), values.reshape(grid_x.shape).cpu(), levels=[0.0], colors='black')  # plot zero level
#         plt.show()
#         return
#     
#     fig = go.Figure(
#         data=go.Contour(
#             x=grid_x[0, :].cpu().numpy(),   # X-axis from meshgrid
#             y=grid_y[:, 0].cpu().numpy(),   # Y-axis from meshgrid
#             z=values.cpu().numpy(),         # field values
#             colorscale="Viridis",           # color map
#             contours=dict(showlines=True, coloring="fill"),
#             ncontours=100,
#             showscale=True
#         )
#     )
# 
#     # Add zero-level contour in black (like your plt.contour)
#     fig.add_trace(
#         go.Contour(
#             x=grid_x[0, :].cpu().numpy(),
#             y=grid_y[:, 0].cpu().numpy(),
#             z=values.cpu().numpy(),
#             contours=dict(start=-0.1, end=0.1, size=1, coloring="none"),
#             line=dict(color="black", width=2),
#             showscale=False
#         )
#     )
#     
#     fig.update_layout(
#         title="Contour Plot",
#         xaxis=dict(
#             scaleanchor="y",  # Lock x and y axes together
#         ),
#         yaxis=dict(
#             scaleanchor="x",  # Lock y and x axes together (optional, can be removed if set on xaxis)
#         ),
#         autosize=True
#     )
# 
#     fig.show()
# 
# def get_point_rep(fun):
#     grid_x, grid_y = torch.meshgrid(
#             torch.linspace(-1, 1, N), torch.linspace(-1, 1, N), indexing="ij"
#         )
#     coords = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=-1).to(device)
#     
#     with torch.no_grad():
#         values = fun(coords).reshape(N, N)
#         
#     return grid_x.to(device), grid_y.to(device), values.to(device)
# 
# def circle_SDF(x0=[0,0], r=1):
#     return lambda points: torch.asarray([abs(sqrt((x[0]-x0[0])**2 + (x[1]-x0[1])**2)-r) for x in points])
#     

In [26]:
class SDF:
    """Scalar field sampled on a regular 2D grid, with contour-plot helpers.

    The field comes from one of three sources, in order of precedence:
    ``model`` (callable applied to a ``(M, 2)`` batch of ``(x, y)`` points),
    ``fun`` (callable applied to one point at a time), or precomputed ``values``.

    Grid convention: ``grid_x`` varies along columns and ``grid_y`` along rows,
    so ``values[i, j]`` is the field at ``(grid_x[0, j], grid_y[i, 0])``. This is
    the layout that both plotting branches of ``plot_field`` index into.

    Relies on the module-level names ``N`` (samples per axis when the grid is
    built from ``xy_lims``) and ``device`` (where sample points are evaluated).

    Attributes:
        grid_x, grid_y: 2D coordinate tensors of identical shape.
        values: field samples; reshaped to the grid shape when computed from a callable.
        model: the batch callable used to compute ``values``, or ``None`` when
            ``values`` was supplied directly.
        fig: the most recent plotly figure created by ``plot_field``, else ``None``.
    """

    def __init__(self, fun=None, model=None, grid_x=None, grid_y=None, values=None, xy_lims=None):
        """Build (or adopt) the sampling grid and evaluate the field on it.

        Args:
            fun: optional per-point callable, used only when ``model`` is None.
            model: optional batch callable taking a ``(M, 2)`` tensor of points.
            grid_x, grid_y: pre-built 2D coordinate grids; required when
                ``xy_lims`` is None.
            values: precomputed field samples; required when neither ``fun``
                nor ``model`` is given.
            xy_lims: ``(x_min, x_max, y_min, y_max)``; when given, an ``N x N``
                grid is generated and ``grid_x`` / ``grid_y`` are ignored.

        Raises:
            AssertionError: if no grid source or no field source is provided.
        """
        if xy_lims is not None:
            x_min, x_max, y_min, y_max = xy_lims[:4]
            # With indexing="ij" the first output varies along rows and the second
            # along columns, so the y range must come first and the x range second.
            self.grid_y, self.grid_x = torch.meshgrid(
                torch.linspace(y_min, y_max, N),
                torch.linspace(x_min, x_max, N),
                indexing="ij",
            )
        else:
            assert grid_x is not None and grid_y is not None, \
                "either xy_lims or both grid_x and grid_y must be provided"
            self.grid_x, self.grid_y = grid_x, grid_y

        self.values, self.model = None, None
        self.fig = None
        self.update(fun, model, values)

    def plot_field(self, plotly=False, layers=100):
        """Draw the field as filled contours with the zero level outlined in black.

        Args:
            plotly: when False (default) draw with matplotlib; when True build a
                plotly figure, store it in ``self.fig`` and show it.
            layers: number of contour levels (``levels`` for matplotlib,
                ``ncontours`` for plotly).
        """
        grid_x, grid_y = self.grid_x, self.grid_y
        values = self.values
        if not plotly:
            plt.contourf(grid_x.cpu().numpy(), grid_y.cpu().numpy(), values.cpu(), levels=layers)
            plt.colorbar()
            plt.contour(grid_x.cpu(), grid_y.cpu(), values.reshape(grid_x.shape).cpu(), levels=[0.0], colors='black')  # plot zero level
            plt.show()
            return

        # 1D axes recovered from the meshgrid: x along the first row, y along the first column.
        x_axis = grid_x[0, :].cpu().numpy()
        y_axis = grid_y[:, 0].cpu().numpy()
        z_vals = values.cpu().numpy()

        self.fig = go.Figure(
            data=go.Contour(
                x=x_axis,
                y=y_axis,
                z=z_vals,                       # field values
                colorscale="Viridis",           # color map
                contours=dict(showlines=False, coloring="fill"),
                ncontours=layers,
                showscale=True
            )
        )

        # Zero-level contour in black, matching plt.contour(levels=[0.0]) above.
        # Plotly draws levels at start, start + size, ... up to end, so a single
        # level at exactly 0 needs start == end == 0 (start=-0.1 would draw -0.1).
        self.fig.add_trace(
            go.Contour(
                x=x_axis,
                y=y_axis,
                z=z_vals,
                contours=dict(start=0.0, end=0.0, size=1, coloring="none"),
                line=dict(color="black", width=2),
                showscale=False
            )
        )

        self.fig.update_layout(
            title="Contour Plot",
            xaxis=dict(
                scaleanchor="y",  # Lock x and y axes together
            ),
            yaxis=dict(
                scaleanchor="x",  # Lock y and x axes together (optional, can be removed if set on xaxis)
            ),
            autosize=True
        )

        self.fig.show()

    def update(self, fun=None, model=None, values=None):
        """Replace the field, either from a callable or from precomputed samples.

        Args:
            fun: per-point callable; wrapped into a batch callable when ``model`` is None.
            model: batch callable taking a ``(M, 2)`` tensor of ``(x, y)`` points;
                takes precedence over ``fun``.
            values: precomputed samples, used only when both callables are None.

        Raises:
            AssertionError: if ``fun``, ``model`` and ``values`` are all None.
        """
        if fun is None and model is None:
            assert values is not None, "one of fun, model or values must be provided"
            self.values = values
            self.model = None
            return

        if model is None:
            model = lambda points: torch.asarray([fun(point) for point in points])
        self.model = model

        # Flatten the grid into an (M, 2) batch of (x, y) points on the compute device.
        coords = torch.stack([self.grid_x.reshape(-1), self.grid_y.reshape(-1)], dim=-1).to(device)
        with torch.no_grad():
            # Reshape to the grid's own shape so caller-supplied grids need not be N x N.
            self.values = model(coords).reshape(self.grid_x.shape)
                
class CircleSDF(SDF):
    """Distance-to-circle field: ``| ||p - x0|| - r |`` for each sample point ``p``.

    NOTE(review): because of the absolute value this is an *unsigned* distance
    (non-negative inside the circle as well). A signed distance field would drop
    the ``abs`` so the interior is negative — confirm which one is intended.

    Args:
        x0: circle centre as an ``(x, y)`` pair.
        r: circle radius.
        xy_lims: grid limits, forwarded unchanged to ``SDF.__init__``.
    """

    def __init__(self, x0=(0,0), r=1, xy_lims=(-1, 1, -1, 1)):
        def model(points):
            # Vectorised over the whole batch: avoids a Python loop that pulls one
            # scalar at a time out of the tensor, and keeps the result on the
            # same device as the input points.
            points = torch.as_tensor(points)
            dx = points[:, 0] - x0[0]
            dy = points[:, 1] - x0[1]
            return torch.abs(torch.sqrt(dx ** 2 + dy ** 2) - r)

        super().__init__(model=model, xy_lims=xy_lims)

In [23]:
# Circle of radius 0.2 centred at the origin, on CircleSDF's default xy_lims.
circle_sdf = CircleSDF(x0=[0,0], r=0.2)
# plotly=True: render with the plotly branch of plot_field.
circle_sdf.plot_field(True)

In [24]:
# Field of unseeded uniform noise, one sample per grid point; values differ on every run.
# NOTE(review): the name `random` would shadow the stdlib module if it were ever imported here.
random = SDF(model=lambda x: torch.asarray(np.random.random((x.shape[0], 1))), xy_lims=(-1, 1, -1, 1))

In [25]:
# Plotly rendering of the noise field with layers=5 (forwarded as ncontours).
random.plot_field(True, 5)