In [1]:
import torch 
import numpy as np
import torch.nn as nn

In [77]:
b_gause  = torch.randn(size=(256,3))

In [28]:
data = torch.rand((1024,3))

In [83]:
x_proj = (2.*np.pi*data)@b_gause.T
torch.cat([torch.sin(x_proj), torch.cos(x_proj)], axis=-1).shape

torch.Size([1024, 512])

In [30]:
import numpy as np
import torch.nn as nn
import torch
from torch.autograd import grad
from math import pi


def gradient(inputs, outputs):
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=d_points,
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0][:, -3:]
    return points_grad


class FourierLayer(nn.Module):
    def __init__(self, in_features, out_features, k):
        super().__init__()
        B = torch.randn(in_features, out_features) * k
        print(B.shape)
        self.register_buffer("B", B)

    def forward(self, x):
        print(x.shape)
        x_proj = torch.matmul(2 * pi * x, self.B)
        print(x_proj.shape)
        out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        print(out.shape)
        return out


class ImplicitNet(nn.Module):
    def __init__(
            self,
            FF=False,
            k=3,
            d_in=3,
            dims=[512,512,512,512,512,512,512,512],
            skip_in={4},
            geometric_init=True,
            radius_init=1,
            beta=100
            
    ):
        super().__init__()
        self.B_gause  = torch.randn(size=(256,3))*k

        dims = [512] + dims + [1]
    
        self.num_layers = len(dims)
        self.skip_in = skip_in

        for layer in range(0, self.num_layers - 1):

            if layer + 1 in skip_in:
                out_dim = dims[layer + 1] - d_in
            else:
                out_dim = dims[layer + 1]

            lin = nn.Linear(dims[layer], out_dim)

            # if true preform preform geometric initialization
            if geometric_init:

                if layer == self.num_layers - 2:

                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[layer]), std=0.00001)
                    torch.nn.init.constant_(lin.bias, -radius_init)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)

                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            setattr(self, "lin" + str(layer), lin)

        if beta > 0:
            self.activation = nn.Softplus(beta=beta)

        # vanilla relu
        else:
            self.activation = nn.ReLU()

    def forward(self, input):

        x = input

        x_proj = (2.*np.pi*data)@self.B_gause.T
        x = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], axis=-1)
            

        for layer in range(0, self.num_layers - 1):

            lin = getattr(self, "lin" + str(layer))

            if layer in self.skip_in:
                x = torch.cat([x, input], -1) / np.sqrt(2)

            x = lin(x)

            if layer < self.num_layers - 2:
                x = self.activation(x)

        return x

In [31]:
model1 = ImplicitNet(FF=True)

In [32]:
data_pred = model1(data)

In [23]:
ImplicitNet(FF=False), ImplicitNet(FF=True)

(ImplicitNet(
   (lin0): Linear(in_features=512, out_features=512, bias=True)
   (lin1): Linear(in_features=512, out_features=512, bias=True)
   (lin2): Linear(in_features=512, out_features=512, bias=True)
   (lin3): Linear(in_features=512, out_features=509, bias=True)
   (lin4): Linear(in_features=512, out_features=512, bias=True)
   (lin5): Linear(in_features=512, out_features=512, bias=True)
   (lin6): Linear(in_features=512, out_features=512, bias=True)
   (lin7): Linear(in_features=512, out_features=512, bias=True)
   (lin8): Linear(in_features=512, out_features=1, bias=True)
   (activation): Softplus(beta=100, threshold=20)
 ),
 ImplicitNet(
   (lin0): Linear(in_features=512, out_features=512, bias=True)
   (lin1): Linear(in_features=512, out_features=512, bias=True)
   (lin2): Linear(in_features=512, out_features=512, bias=True)
   (lin3): Linear(in_features=512, out_features=509, bias=True)
   (lin4): Linear(in_features=512, out_features=512, bias=True)
   (lin5): Linear(in_fea

In [16]:
import numpy as np
import torch.nn as nn
import torch
from torch.autograd import grad
from math import pi


def gradient(inputs, outputs):
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=d_points,
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0][:, -3:]
    return points_grad


def doubleWellPotential(s):
    """
    double well potential function with zeros at -1 and 1
    """
    return (s ** 2) - 2 * (s.abs()) + 1.


class FourierLayer(nn.Module):
    def __init__(self, in_features, out_features, k):
        super().__init__()
        B = torch.randn(in_features, out_features) * k
        self.register_buffer("B", B)

    def forward(self, x):
        x_proj = torch.matmul(2 * pi * x, self.B)
        out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return out


class ImplicitNet1(nn.Module):
    def __init__(
            self,
            
            k = 3,
            d_in = 3,
            dims = [512,512,512,512,512,512,512,512],
            skip_in={4},
            geometric_init=True,
            radius_init=1,
            beta=100
    ):
        super().__init__()

        self.FF = FF
        self.k = k

        if FF:
            self.ffLayer = FourierLayer(in_features=3, out_features=dims[0]//2, k=self.k)
            dims = [dims[0]] + dims + [1]
        else:
            dims = [d_in] + dims + [1]

        self.num_layers = len(dims)
        self.skip_in = skip_in

        for layer in range(0, self.num_layers - 1):

            if layer + 1 in skip_in:
                print(dims[layer + 1], d_in)
                out_dim = dims[layer + 1] - d_in
            else:
                out_dim = dims[layer + 1]

            lin = nn.Linear(dims[layer], out_dim)

            # if true preform preform geometric initialization
            if geometric_init:

                if layer == self.num_layers - 2:

                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[layer]), std=0.00001)
                    torch.nn.init.constant_(lin.bias, -radius_init)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)

                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            setattr(self, "lin" + str(layer), lin)

        if beta > 0:
            self.activation = nn.Softplus(beta=beta)

        # vanilla relu
        else:
            self.activation = nn.ReLU()

    def forward(self, input):

        x = input

        if self.FF:
            x = self.ffLayer(x)  # apply the fourier

        for layer in range(0, self.num_layers - 1):

            lin = getattr(self, "lin" + str(layer))

            if layer in self.skip_in:
                x = torch.cat([x, input], -1) / np.sqrt(2)

            x = lin(x)

            if layer < self.num_layers - 2:
                x = self.activation(x)

        return x

In [34]:
ImplicitNet1(FF=True)

torch.Size([3, 256])
512 3


ImplicitNet1(
  (ffLayer): FourierLayer()
  (lin0): Linear(in_features=512, out_features=512, bias=True)
  (lin1): Linear(in_features=512, out_features=512, bias=True)
  (lin2): Linear(in_features=512, out_features=512, bias=True)
  (lin3): Linear(in_features=512, out_features=509, bias=True)
  (lin4): Linear(in_features=512, out_features=512, bias=True)
  (lin5): Linear(in_features=512, out_features=512, bias=True)
  (lin6): Linear(in_features=512, out_features=512, bias=True)
  (lin7): Linear(in_features=512, out_features=512, bias=True)
  (lin8): Linear(in_features=512, out_features=1, bias=True)
  (activation): Softplus(beta=100, threshold=20)
)

tensor([[ 9.3540],
        [ 9.0854],
        [11.8543],
        ...,
        [12.0148],
        [10.1440],
        [ 8.8734]], grad_fn=<AddmmBackward>)