In [20]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.io as pio
from math import sqrt
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset

pio.renderers.default = "browser"

In [21]:
# The idea basically is to develop a 2D Toy problem to demonstrate GINN. It must cover the entire workflow, establish some evaluations, and produce concrete comparable metrics. The choice of models or "data" is arbitrary. However, it must be easily transferable to VecSet type workflows too. Ideally modular.
# Also lets use neural fields instead of clouds from step 1.
N = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# def plot_cloud(points, title="Point Cloud"):
#     x, y = zip(*points)
#     fig = go.Figure(data=go.Scatter(x=x, y=y, mode='markers'))
#     fig.update_layout(title=title,
#                       xaxis_title='X',
#                       yaxis_title='Y',
#                       width=600, height=600)
#     fig.show()
# 
# def plot_field(grid_y, grid_x, values, plotly=False):
#     if not plotly:
#         plt.contourf(grid_x.cpu().numpy(), grid_y.cpu().numpy(), values.cpu(), levels=50)
#         plt.colorbar()
#         plt.contour(grid_x.cpu(), grid_y.cpu(), values.reshape(grid_x.shape).cpu(), levels=[0.0], colors='black')  # plot zero level
#         plt.show()
#         return
#     
#     fig = go.Figure(
#         data=go.Contour(
#             x=grid_x[0, :].cpu().numpy(),   # X-axis from meshgrid
#             y=grid_y[:, 0].cpu().numpy(),   # Y-axis from meshgrid
#             z=values.cpu().numpy(),         # field values
#             colorscale="Viridis",           # color map
#             contours=dict(showlines=True, coloring="fill"),
#             ncontours=100,
#             showscale=True
#         )
#     )
# 
#     # Add zero-level contour in black (like your plt.contour)
#     fig.add_trace(
#         go.Contour(
#             x=grid_x[0, :].cpu().numpy(),
#             y=grid_y[:, 0].cpu().numpy(),
#             z=values.cpu().numpy(),
#             contours=dict(start=-0.1, end=0.1, size=1, coloring="none"),
#             line=dict(color="black", width=2),
#             showscale=False
#         )
#     )
#     
#     fig.update_layout(
#         title="Contour Plot",
#         xaxis=dict(
#             scaleanchor="y",  # Lock x and y axes together
#         ),
#         yaxis=dict(
#             scaleanchor="x",  # Lock y and x axes together (optional, can be removed if set on xaxis)
#         ),
#         autosize=True
#     )
# 
#     fig.show()
# 
# def get_point_rep(fun):
#     grid_x, grid_y = torch.meshgrid(
#             torch.linspace(-1, 1, N), torch.linspace(-1, 1, N), indexing="ij"
#         )
#     coords = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=-1).to(device)
#     
#     with torch.no_grad():
#         values = fun(coords).reshape(N, N)
#         
#     return grid_x.to(device), grid_y.to(device), values.to(device)
# 
# def circle_SDF(x0=[0,0], r=1):
#     return lambda points: torch.asarray([abs(sqrt((x[0]-x0[0])**2 + (x[1]-x0[1])**2)-r) for x in points])
#     

In [71]:
class SDF:
    def __init__(self, fun=None, model=None, grid_x=None, grid_y=None, values=None, xy_lims=None):
        if xy_lims is not None:
            self.grid_y, self.grid_x = torch.meshgrid(torch.linspace(xy_lims[0], xy_lims[1], N), torch.linspace(xy_lims[2], xy_lims[3], N), indexing="ij") 
        else:
            assert grid_x is not None and grid_y is not None
            self.grid_x, self.grid_y = grid_x, grid_y
        
        self.values, self.model, self.values = None, None, None
        self.update(fun, model, values, device_=device)
        
        self.device = device
        
        self.fig = None
                
    def plot_field(self, plotly=False, layers=100, preprocess=None, domain=None, newN=None):
        values = None
        if domain is None:
            grid_x, grid_y = self.grid_x, self.grid_y
            values = self.values if (preprocess is None) else preprocess(self.values)
        else:
            assert self.model is not None
            model = self.model.to(self.device)
            N_ = N if newN is None else newN
            grid_y, grid_x = torch.meshgrid(torch.linspace(domain[0], domain[1], N_), torch.linspace(domain[0], domain[1], N_))
            coords = torch.stack([self.grid_x.reshape(-1), self.grid_y.reshape(-1)], dim=-1).to(self.device)
            with torch.no_grad():
                values = self.model(coords).reshape(N_, N_)
        if not plotly:
            plt.contourf(grid_x.cpu().numpy(), grid_y.cpu().numpy(), values.cpu(), levels=layers)
            plt.colorbar()
            plt.contour(grid_x.cpu(), grid_y.cpu(), values.reshape(grid_x.shape).cpu(), levels=[0.0], colors='black')  # plot zero level
            plt.show()
            return
        
        self.fig = go.Figure(
            data=go.Contour(
                x=grid_x[0, :].cpu().numpy(),   # X-axis from meshgrid
                y=grid_y[:, 0].cpu().numpy(),   # Y-axis from meshgrid
                z=values.cpu().numpy(),         # field values
                colorscale="Viridis",           # color map
                contours=dict(showlines=False, coloring="fill"),
                ncontours=layers,
                showscale=True
            )
        )
    
        # Add zero-level contour in black (like your plt.contour)
        self.fig.add_trace(
            go.Contour(
                x=grid_x[0, :].cpu().numpy(),
                y=grid_y[:, 0].cpu().numpy(),
                z=values.cpu().numpy(),
                contours=dict(start=-0.1, end=0.1, size=1, coloring="none"),
                line=dict(color="black", width=2),
                showscale=False
            )
        )
        
        self.fig.update_layout(
            title="Contour Plot",
            xaxis=dict(
                scaleanchor="y",  # Lock x and y axes together
            ),
            yaxis=dict(
                scaleanchor="x",  # Lock y and x axes together (optional, can be removed if set on xaxis)
            ),
            autosize=True
        )
    
        self.fig.show()
        
    def update(self, fun=None, model=None, values=None, device_="cuda"):
        if fun is None and model is None:
            assert values is not None
            self.values = values
            self.model = None
        else:
            if model is None:
                assert fun is not None
                model = lambda points: torch.asarray([fun(point) for point in points])
            self.model = model
            coords = torch.stack(
                [self.grid_x.reshape(-1), self.grid_y.reshape(-1)], dim=-1
            ).to(device_)
            with torch.no_grad():
                self.values = model(coords).reshape(N, N)
                
class CircleSDF(SDF):
    def __init__(self, x0=(0,0), r=1, xy_lims=(-1, 1, -1, 1)):
        model = lambda points: torch.asarray([abs(sqrt((x[0]-x0[0])**2 + (x[1]-x0[1])**2)-r) for x in points])
        super().__init__(model=model, xy_lims=xy_lims)

In [45]:
circle_sdf = CircleSDF(x0=[0,0], r=0.2)
circle_sdf.plot_field(True)

In [24]:
random = SDF(model=lambda x: torch.asarray(np.random.random((x.shape[0], 1))), xy_lims=(-1, 1, -1, 1))

In [25]:
random.plot_field(True, 5)

In [72]:
class FieldMLP(nn.Module):
    def __init__(self, in_dim=2, hidden=128, depth=4, out_dim=1):
        super().__init__()
        layers = []
        dims = [in_dim] + [hidden]*(depth-1) + [out_dim]
        for i in range(len(dims)-1):
            layers += [nn.Linear(dims[i], dims[i+1])]
            if i < len(dims)-2:
                layers += [nn.SiLU()]     # smooth activation is helpful for gradients
        self.net = nn.Sequential(*layers)

    def forward(self, x):   # x: (B, 2) points in [-1,1]^2
        return self.net(x)  # (B, 1)
    
    
class DummyDataset(Dataset):
    def __len__(self): return 10
    def __getitem__(self, idx): return torch.tensor(0.0)
    
    
# Implement a Constraint class here that you can just add new constraints to.
class ConstraintLoss:
    def __init__(self, model, coords):
        field = model(coords)

        self.area = area = self.constraint_area_loss(field, target_area=1.0)
        self.tv = tv   = self.total_variation_loss(field, coords)
        self.cont = cont = self.containment_loss(coords, field, thresh=0.5)
        self.cent = cent = self.center_loss(field)
        # misc = torch.zeros((), device=self.device)

        self.loss = (4.0 * area + 0.1 * tv + 10.0 * cont + 0.2 * cent) # + misc
       
    @property
    def total_loss(self)->torch.Tensor:
        return self.loss
        
    def constraint_area_loss(self, field, target_area=0.2):
        p = torch.sigmoid(field)
        area_est = p.mean()
        return (area_est - target_area)**2
    
    def total_variation_loss(self, field, coords):
        # TV on the field via gradient magnitude ||∇f||; needs grads w.r.t. coords
        grad = torch.autograd.grad(
            field, coords,
            grad_outputs=torch.ones_like(field),
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]  # shape (B,2)
        tv = (grad.pow(2).sum(dim=-1) + 1e-8).sqrt().mean()
        return tv
    
    def containment_loss(self, coords, field, thresh=0.5):
        # Example: penalize predictions outside a disk of radius 1 (just a demo)
        # Encourage "inside" probs > thresh where r<=1, and < thresh outside
        r = coords.norm(dim=-1, keepdim=True)
        target = (r <= 1.0).float()
        p = torch.sigmoid(field)
        return ((p - (thresh*0 + target)).abs()).mean()
    
    def center_loss(self, field):
        return field.mean()**2
    
    def symmetry_loss(self, field, coords):
        pass
    

class ConstraintTrainer(pl.LightningModule):
    def __init__(self, sdf, lr=1e-3, n_points=1024, domain=[-1,1], seed=None):
        super().__init__()
        self.sdf = sdf
        self.model = sdf.model
        self.lr = lr
        self.n_points = n_points
        self.register_buffer('domain_min', torch.tensor([domain[0], domain[0]]), persistent=False)
        self.register_buffer('domain_max', torch.tensor([domain[1], domain[1]]), persistent=False)
        if seed is not None:
            pl.seed_everything(seed)
            
        self.all_frames = []

    def sample_coords(self, n):
        # Uniform in the box
        u = torch.rand(n, 2, device=self.device)
        return self.domain_min + (self.domain_max - self.domain_min) * u

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    
    def train_dataloader(self):
        return DataLoader(DummyDataset(), batch_size=1)
    
    # --------------------- HOOKS ---------------------------------
    def training_step(self, batch, batch_idx):
        # No dataloader needed; we draw fresh collocation points each step
        coords = self.sample_coords(self.n_points).requires_grad_()
        field = self.model(coords)

        # loss = abs(field.mean()) # Dummy loss
        cl = ConstraintLoss(self.model, coords)
        loss = cl.total_loss

        self.log_dict(
            {"loss": loss,  "area": cl.area, "tv": cl.tv, "cont": cl.cont, "center": cl.cent},
            prog_bar=True, on_step=True, on_epoch=True
        )
        return loss
    
    def on_train_start(self) -> None:
        self.sdf.update(model=self.model, device_=self.device)
        self.sdf.plot_field(True, 5)
        
    def on_train_epoch_end(self):
        self.sdf.update(model=self.model, device_=self.device)
        # self.sdf.plot_field(True)
        self.all_frames.append(self.sdf.values.cpu().numpy())
        
    def on_train_end(self):
        # Build plotly animation
        frames = [go.Frame(
            data=[go.Contour(z=z, x=self.sdf.grid_x[0,:].cpu().numpy(),
                             y=self.sdf.grid_y[:,0].cpu().numpy())],
            name=f"epoch{k}"
        ) for k, z in enumerate(self.all_frames)]
    
        fig = go.Figure(
            data=frames[0].data,
            frames=frames
        )
    
        fig.update_layout(
            updatemenus=[{
                "buttons": [
                    {"args": [None, {"frame": {"duration": 200, "redraw": True},
                                     "fromcurrent": True}], "label": "▶ Play", "method": "animate"},
                    {"args": [[None], {"frame": {"duration": 0, "redraw": False},
                                       "mode": "immediate"}], "label": "⏸ Pause", "method": "animate"}
                ]
            }]
        )
        fig.show()

In [73]:
model = FieldMLP(hidden=32, depth=2).to(device)
sdf = SDF(model=model, xy_lims=(-1, 1, -1, 1))
lit = ConstraintTrainer(sdf, lr=1e-3, n_points=128, domain=[-1,1], seed=42)
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="auto",
    precision="bf16-mixed" if torch.cuda.is_available() else "32-true",
    gradient_clip_val=1.0,
    log_every_n_steps=10
)
trainer.fit(lit)

Seed set to 42
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type     | Params | Mode 
-------------------------------------------
0 | model | FieldMLP | 129    | train
-------------------------------------------
129       Trainable params
0         Non-trainable params
129       Total params
0.001     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode


Epoch 4: 100%|██████████| 10/10 [00:00<00:00, 46.00it/s, v_num=33, loss_step=4.700, area_step=0.121, tv_step=0.153, cont_step=0.412, center_step=0.396, loss_epoch=4.820, area_epoch=0.128, tv_epoch=0.165, cont_epoch=0.422, center_epoch=0.343]  

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 10/10 [00:00<00:00, 42.98it/s, v_num=33, loss_step=4.700, area_step=0.121, tv_step=0.153, cont_step=0.412, center_step=0.396, loss_epoch=4.820, area_epoch=0.128, tv_epoch=0.165, cont_epoch=0.422, center_epoch=0.343]


In [78]:
metrics = trainer.logged_metrics
print(metrics)

{'loss_step': tensor(4.6950), 'area_step': tensor(0.1209), 'tv_step': tensor(0.1534), 'cont_step': tensor(0.4117), 'center_step': tensor(0.3955), 'loss_epoch': tensor(4.8173), 'area_epoch': tensor(0.1278), 'tv_epoch': tensor(0.1653), 'cont_epoch': tensor(0.4221), 'center_epoch': tensor(0.3431)}


In [77]:
sdf.plot_field(
    True, 50, 
)