In [585]:
import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl

In [630]:
class GaussianRBF(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.output_dim = output_dim
        #Definition and initialization of centers and scales
        self.mus = torch.nn.Parameter(torch.ones(output_dim, input_dim))
        self.log_sigmas = torch.nn.Parameter(torch.ones(output_dim))
        nn.init.uniform_(self.mus, 0, 1)
        nn.init.constant_(self.log_sigmas, 0)
        
    def forward(self, x):
        d_scaled = (x[:,None,:] - self.mus[None,:,:]/torch.exp(self.log_sigmas[None,:,None]))
        return torch.exp(-(torch.linalg.vector_norm(d_scaled, axis=-1, ord=2))**2/2)

In [631]:
model = GaussianRBF(2,72)

In [632]:
output = model.forward(torch.rand(100,2))

In [621]:
output.shape

torch.Size([100, 72])

In [590]:
a = torch.rand(2)
a
torch.tile(a,(10,1)).shape

torch.Size([10, 2])

In [591]:
class NLBranchNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.ConvTranspose2d(in_channels=1, out_channels=16, kernel_size=4, stride=1))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.BatchNorm2d(num_features=16))
        self.layers.append(nn.ConvTranspose2d(in_channels=16, out_channels=32, kernel_size=4, stride=1))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.BatchNorm2d(num_features=32))
        self.layers.append(nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=2, stride=2))
        self.layers.append(nn.ReLU())

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [592]:
model = NLBranchNet().forward(torch.rand((100,1,12,12)))

In [593]:
output.shape

torch.Size([100, 72])

In [662]:
class LBranchNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_dim, output_dim))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        return x

In [663]:
model = LBranchNet(144, 144)

In [664]:
model(torch.rand(144))

tensor([ 1.0826e-01,  3.2304e-01, -1.9980e-01,  8.9366e-02, -1.1215e-02,
        -4.2666e-01,  3.9343e-01,  1.7136e-01, -1.9564e-01,  3.2310e-01,
        -3.3134e-01, -4.5084e-01,  4.1863e-01,  4.4906e-01,  6.4499e-01,
         4.6688e-01, -2.2056e-02,  8.0327e-01, -3.6321e-02, -4.0772e-01,
         2.0600e-01,  2.1587e-01, -2.7418e-01,  2.6230e-01, -9.9448e-02,
        -8.0348e-03,  3.3014e-01, -5.7964e-01, -1.0197e+00, -1.7571e-01,
         2.1947e-01,  3.7182e-01, -8.1308e-02, -5.1688e-01, -1.3647e-01,
         2.0460e-04, -5.0699e-01, -3.7830e-01,  2.6403e-01,  6.2820e-01,
        -7.1961e-02, -5.0715e-01, -1.5096e-01, -1.8724e-01,  1.5253e-01,
         8.2365e-02,  2.5304e-01,  4.4619e-01, -2.4975e-01, -1.3530e-01,
        -4.5496e-01, -4.1148e-01, -1.3446e-01,  4.0538e-01, -2.2410e-01,
         5.9388e-01, -4.4017e-02, -5.7302e-01,  1.6264e-01,  1.7279e-01,
        -3.4509e-01,  4.9541e-01,  5.1375e-01,  8.1608e-02,  2.7153e-01,
        -5.1271e-01,  5.6308e-01, -4.2395e-01,  7.2

In [665]:
class VarMiON(pl.LightningModule):
    def __init__(self, params):
        super().__init__()
        self.hparams.update(params['hparams'])

    def forward(self, theta, F, N, x):
        NLBranch = NLBranchNet().forward(theta)
        NLBranch = NLBranch.reshape((NLBranch.shape[0], NLBranch.shape[2], NLBranch.shape[3]))
        LBranch = LBranchNet(144,72).forward(F) + LBranchNet(144,72).forward(N)
        Branch = torch.einsum('nij,nj->ni', NLBranch, LBranch)
        Trunk = GaussianRBF(2,72).forward(x)
        u_hat = torch.einsum('ni,ni->n', Branch, Trunk)
        return u_hat

In [666]:
params = {}
params['hparams'] = {}
model = VarMiON(params)

In [667]:
theta = torch.rand((100,1,12,12))
F = torch.rand(100,144)
N = torch.rand(100,144)
x = torch.rand(100,2)

In [668]:
model(theta, F, N, x)

tensor([ -8.9597,  -2.5934, -26.0627, -14.9835, -33.3732,  -9.6151,  -3.1857,
        -19.3091, -21.1174, -18.1950,  -7.5267, -16.4898,  -6.8867, -15.4041,
        -26.2049, -11.8599, -17.0219,  -6.8183, -18.8837,  -3.7176, -25.0287,
        -19.8458, -10.1516, -10.0587, -12.6031, -13.2446, -24.4722, -11.8340,
        -22.7031, -35.9719,  -9.2498,  -0.2187,  -5.0522, -12.5399, -15.7212,
        -16.7408, -13.9912, -10.0065, -13.5594, -14.3105, -24.2220,  -8.4578,
        -17.8024, -19.3186, -31.2625, -20.5583,  -4.7125, -19.8892,  -3.2145,
         -9.1456,  -3.5943, -14.1928, -12.6701, -19.2701, -20.2234,  -5.6060,
         -7.9946, -12.5741, -18.4971,  -4.2853, -13.8165, -11.8999, -15.9461,
         -2.1674, -13.1560,  -9.3719, -23.5992, -18.2455, -13.6656,   6.1686,
         -8.3715, -30.2949,  -0.5465,  -5.5226, -19.2469, -11.4430, -20.4002,
         -0.3315,  -9.3901, -15.3602, -14.8297,  -7.4635, -28.1031,  -7.2694,
        -31.7850,  -6.0084, -12.6843,  -1.7178, -11.7688, -13.76