In [1]:
import numpy as np

import torch
import torch.nn as nn

In [None]:
class ComplexGaborLayer(nn.Module):
    """Linear layer followed by a complex Gabor wavelet nonlinearity.

    The output is ``exp(1j * omega_0 * z - |scale_0 * z|**2)`` with
    ``z = self.linear(input)``: a complex exponential of frequency
    ``omega_0`` modulated by a Gaussian envelope of width ``scale_0``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the inner ``nn.Linear`` has a bias term.
        is_first: if True the inner linear layer is real-valued
            (``torch.float``) so it can consume real coordinates; otherwise
            it is complex-valued (``torch.cfloat``) and expects complex input.
        omega0: frequency of the complex exponential.
        sigma0: scale of the Gaussian envelope.
        trainable: whether ``omega_0`` and ``scale_0`` require gradients.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega0=10.0, sigma0=40.0,
                 trainable=False):
        super().__init__()
        self.is_first = is_first
        self.in_features = in_features

        # torch.float / torch.cfloat are dtype objects, not callables.
        dtype = torch.float if is_first else torch.cfloat

        # 1-element Parameters so they follow .to()/.cuda() with the module and
        # can optionally be learned; requires_grad is controlled by `trainable`.
        self.omega_0 = nn.Parameter(omega0 * torch.ones(1), trainable)
        self.scale_0 = nn.Parameter(sigma0 * torch.ones(1), trainable)

        self.linear = nn.Linear(in_features,
                                out_features,
                                bias=bias,
                                dtype=dtype)

    def forward(self, input):
        """Apply the linear map, then the Gabor activation.

        Returns a complex tensor with last dimension ``out_features``.
        """
        lin = self.linear(input)
        omega = self.omega_0 * lin
        scale = self.scale_0 * lin

        # Oscillation term (1j*omega) times Gaussian envelope exp(-|scale|^2).
        return torch.exp(1j * omega - scale.abs().square())

In [None]:
class Wire(nn.Module):
    """Stack of complex Gabor layers followed by a complex linear head.

    Architecture: one real-input ``ComplexGaborLayer`` (``is_first=True``),
    ``hidden_layers`` complex ``ComplexGaborLayer`` blocks, then a complex
    ``nn.Linear`` to ``out_features``. ``forward`` returns the real part of
    the output.

    Args:
        in_features: input coordinate dimension.
        hidden_features: nominal hidden width; the actual width used is
            ``int(hidden_features / sqrt(2))``.
        hidden_layers: number of hidden Gabor layers after the first one.
        out_features: output dimension.
        first_omega_0: ``omega0`` of the first Gabor layer.
        hidden_omega_0: ``omega0`` of the hidden Gabor layers.
        scale: ``sigma0`` used by every Gabor layer.

    NOTE(review): ``outermost_linear``, ``pos_encode``, ``side_length``,
    ``fn_samples`` and ``use_nyquist`` are accepted but ignored. Positional
    encoding is always disabled.
    """

    def __init__(self, in_features, hidden_features, hidden_layers,
                 out_features, outermost_linear=True,
                 first_omega_0=30, hidden_omega_0=30, scale=10.0,
                 pos_encode=False, side_length=512, fn_samples=None,
                 use_nyquist=True):
        super().__init__()

        self.nonlin = ComplexGaborLayer

        # Presumably shrinks the width because each complex weight carries two
        # real numbers, keeping the real parameter count comparable — confirm.
        hidden_features = int(hidden_features / np.sqrt(2))
        dtype = torch.cfloat
        self.complex = True
        self.wavelet = 'gabor'

        self.pos_encode = False

        layers = [self.nonlin(in_features,
                              hidden_features,
                              omega0=first_omega_0,
                              sigma0=scale,
                              is_first=True,
                              trainable=False)]

        for _ in range(hidden_layers):
            layers.append(self.nonlin(hidden_features,
                                      hidden_features,
                                      omega0=hidden_omega_0,
                                      sigma0=scale))

        # Complex head: hidden activations are complex, so weights are too.
        layers.append(nn.Linear(hidden_features,
                                out_features,
                                dtype=dtype))

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Evaluate the network on ``coords``.

        Returns the real part of the complex output when ``self.wavelet`` is
        ``'gabor'`` (always the case as constructed), else the raw output.
        """
        output = self.net(coords)

        if self.wavelet == 'gabor':
            return output.real

        return output