In [1]:
import torch
import torch.nn.functional as F
from torch import nn


class CenteredLayer(nn.Module):
    """Parameter-free layer that subtracts the global mean from its input.

    The mean is taken over *every* element of ``X`` (not per feature), so the
    returned tensor has an overall mean of zero up to floating-point error.
    No ``__init__`` override is needed: the layer owns no parameters, and the
    inherited ``nn.Module.__init__`` performs all required setup.
    """

    def forward(self, X):
        global_mean = X.mean()
        return X - global_mean

In [2]:
layer = CenteredLayer()
# Centering 1..5 should give values symmetric around zero.
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32))

tensor([-2., -1.,  0.,  1.,  2.])

In [3]:
net = nn.Sequential(
    nn.Linear(8, 128),
    CenteredLayer(),
)

In [4]:
batch = torch.rand(4, 8)
Y = net(batch)
Y.mean()

tensor(1.8626e-09, grad_fn=<MeanBackward0>)

In [5]:
class MyLinear(nn.Module):
    """Fully connected layer followed by a ReLU, built from raw parameters.

    :param in_units: number of input features
    :param units: number of output features
    """

    def __init__(self, in_units, units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units))

    def forward(self, X):
        """Return ``relu(X @ weight + bias)``.

        :param X: input whose last dimension has size ``in_units``
        :return: activations whose last dimension has size ``units``
        """
        # Use the Parameters themselves: going through ``.data`` detaches them
        # from autograd, so no gradient would ever reach weight/bias and the
        # layer could never be trained.
        linear = torch.matmul(X, self.weight) + self.bias
        return F.relu(linear)

In [6]:
linear = MyLinear(in_units=5, units=3)
linear.weight

Parameter containing:
tensor([[ 0.0668,  0.5748, -1.2197],
        [-0.3411,  0.8489, -0.1694],
        [-0.7381, -0.1566,  0.2411],
        [ 1.7724, -0.7835,  0.4507],
        [-1.6869, -0.0959, -0.6579]], requires_grad=True)

In [7]:
linear(torch.rand((2, 5)))

tensor([[0.0000, 1.9488, 0.0000],
        [1.6845, 1.1124, 0.0000]])

In [8]:
net = nn.Sequential(
    MyLinear(64, 8),
    MyLinear(8, 1),
)
net(torch.rand(2, 64))

tensor([[3.7027],
        [1.9504]])

In [15]:
# class DRLayer(nn.Module):
#     def __init__(self, i, j, k):
#         super().__init__()
#         self.k = k
#         self.weight = nn.Parameter(torch.randn((k,i,j)))

#     def forward(self,X):
#         z = torch.ones((1,self.k))
#         for i in range(X.shape[0]):
#             y = torch.matmul(X[i,:].reshape(-1,1),X[i,:].reshape(1,-1))
#             tmp_z = torch.matmul(self.weight,y).sum(axis=[1,2]).reshape(1,-1)
#             z = torch.cat([z,tmp_z],0)
#         return z[1:]


# net = DRLayer(5,2)
# print(net(torch.rand(4,5)).shape)

TypeError: __init__() missing 1 required positional argument: 'k'

In [16]:
class HalfFFT(nn.Module):
    """Parameter-free layer returning the first half of the FFT of its input."""

    def __init__(self):
        super().__init__()

    def forward(self, X):
        """
        Compute the FFT along the last dimension and return its first half.

        :param X: size = B*L (extra leading batch dimensions are also accepted)
        :return: complex tensor of size = B*ceil(L/2)
        """
        length = X.shape[-1]
        # Ceil division rounds .5 up consistently. The builtin round() uses
        # banker's rounding (round(1.5) == 2 but round(2.5) == 2), which gave
        # inconsistent sizes for odd L and an empty result for L == 1.
        half_len = (length + 1) // 2
        X_f = torch.fft.fft(X)  # transforms along the last dimension
        return X_f[..., :half_len]


myNet2 = HalfFFT()
half_spectrum = myNet2(torch.rand(2, 3))
print(half_spectrum)

tensor([[ 1.3686+0.0000j, -0.5299+0.0503j],
        [ 1.4565+0.0000j, -0.4599+0.4170j]])
