In [1]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np

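Reference signature: `torch.fft.fftn(input, s=None, dim=None, norm=None, *, out=None) → Tensor` (N-dimensional discrete Fourier transform). The layer below uses `torch.fft.fft2`, the two-dimensional variant, which by default transforms the last two dimensions (`dim=(-2, -1)`).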

In [3]:
class ResidualFourierTransformLayer(nn.Module):
    """Parameter-free FNet token-mixing sublayer with a residual connection.

    Applies a 2-D discrete Fourier transform to the input, keeps only the
    real part of the spectrum, and adds the original input back.
    """

    def __init__(self):
        super(ResidualFourierTransformLayer, self).__init__()

    def forward(self, x):
        # torch.fft.fft2 transforms dims (-2, -1) by default. With the
        # (batch, seq, features) tensors used in this notebook, that mixes
        # the sequence and feature axes. NOTE(review): assumes that layout.
        spectrum = torch.fft.fft2(x)
        mixed = spectrum.real
        return mixed + x

In [4]:
# Instantiate the parameter-free Fourier layer and show its module repr.
model = ResidualFourierTransformLayer()
print(repr(model))

ResidualFourierTransformLayer()


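100 trames par seconde (fenêtres de 25 ms avec un pas de 10 ms, soit 15 ms de recouvrement entre trames consécutives) : `seq_length = 1000` correspond donc à environ 10 s de signal.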

In [5]:
# Dummy batch of feature sequences with shape (batch_size, seq_length, input_dim).
# Seed the RNG so the random input, and every weight initialised after this
# cell, is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
input_dim = 39
batch_size = 16
seq_length = 1000
x = torch.rand(batch_size, seq_length, input_dim)

In [6]:
# Apply the Fourier-mixing layer; the shapes are printed in the next cell.
z = model(x)

In [7]:
# Output and input shapes, one per line: the Fourier sublayer should not change the shape.
print(z.shape, x.shape, sep="\n")

torch.Size([16, 1000, 39])
torch.Size([16, 1000, 39])


In [8]:
class ResidualFeedForwardLayer(nn.Module):
    """Position-wise feed-forward sublayer with a residual connection.

    Computes ``x + Dropout(Linear(act(Linear(x))))``. The first linear layer
    projects the last dimension from ``input_dim`` up to ``hidden_dim``; the
    second projects it back.

    Args:
        input_dim: size of the last dimension of the input (and the output).
        hidden_dim: width of the intermediate projection.
        dropout_rate: dropout probability applied to the sublayer output
            before the residual addition.
        activation: ``nn.Module`` instance placed between the two linear
            layers. If omitted (``None``), a fresh ``nn.GELU()`` is used.
    """

    def __init__(self, input_dim, hidden_dim, dropout_rate, activation=None):
        super().__init__()
        if activation is None:
            # Build the default per instance instead of sharing one module
            # object through a default argument.
            activation = nn.GELU()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Linear(hidden_dim, input_dim),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        return self.layers(x) + x

In [9]:
# Feed-forward sublayer hyper-parameters.
hidden_dim = 3072
dropout_rate = 0.1
# ResidualFeedForwardLayer takes the activation module as its fourth argument;
# pass GELU explicitly (the call previously omitted it).
ff_model = ResidualFeedForwardLayer(input_dim, hidden_dim, dropout_rate, nn.GELU())

In [10]:
# Sanity check of the feature dimension used to size the layers.
print(input_dim)

39


In [11]:
# Apply the residual feed-forward layer to the Fourier-mixed output z.
out = ff_model(z)

In [12]:
# The feed-forward sublayer also preserves the shape.
out.shape

torch.Size([16, 1000, 39])

In [23]:
class FNetEncoderBlock(nn.Module):
    """One FNet encoder block in post-LayerNorm arrangement.

    Computes::

        h = LayerNorm(x + Re(FFT2(x)))
        y = LayerNorm(h + FeedForward(h))

    Args:
        input_dim: size of the last (feature) dimension; preserved by the block.
        hidden_dim: width of the feed-forward intermediate layer.
        dropout_rate: dropout probability inside the feed-forward sublayer.
        activation: ``nn.Module`` used inside the feed-forward sublayer.
        layer_norm_eps: epsilon for both LayerNorms. Defaults to PyTorch's
            ``1e-5``; an earlier version of this notebook used ``1e-12``.
    """

    def __init__(self, input_dim, hidden_dim, dropout_rate, activation, layer_norm_eps=1e-5):
        super().__init__()
        self.layers = nn.Sequential(
            ResidualFourierTransformLayer(),
            nn.LayerNorm(input_dim, eps=layer_norm_eps),
            ResidualFeedForwardLayer(input_dim, hidden_dim, dropout_rate, activation),
            nn.LayerNorm(input_dim, eps=layer_norm_eps),
        )

    def forward(self, x):
        return self.layers(x)

In [None]:
# Activation module instance passed to FNetEncoderBlock in the next cell.
activation = nn.GELU()

In [24]:
# A single encoder block using the GELU instance defined above.
block = FNetEncoderBlock(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    dropout_rate=dropout_rate,
    activation=activation,
)

In [26]:
# Run one encoder block on the dummy batch and show the resulting shape.
out_block = block(x)
print(out_block.size())

torch.Size([16, 1000, 39])


In [36]:
class FNetEncoder(nn.Module):
    """Stack of ``n_encoders`` FNet encoder blocks followed by a linear projection.

    Args:
        n_encoders: number of ``FNetEncoderBlock`` layers to stack.
        input_dim: feature size seen by every block.
        hidden_dim: feed-forward width inside each block.
        dropout_rate: dropout probability inside each block.
        output_dim: size of the last dimension after the final ``nn.Linear``.
        activation: zero-argument callable (e.g. a module class). It is called
            once per block, so each block receives its own activation instance.
    """

    def __init__(self, n_encoders, input_dim, hidden_dim, dropout_rate, output_dim, activation=nn.GELU):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(n_encoders):
            self.layers.append(
                FNetEncoderBlock(input_dim, hidden_dim, dropout_rate, activation())
            )
        # Final projection to output_dim, using nn.Linear's default init.
        # TODO(review): decide on the init. The original note says the paper
        # adds nn.Tanh() after this layer; that is not done here.
        self.layers.append(nn.Linear(input_dim, output_dim))

    def forward(self, x):
        hidden = x
        for module in self.layers:
            hidden = module(hidden)
        return hidden

In [2]:
def f():
    """Return a fresh ``nn.ELU(alpha=0.5)``; used as the activation factory for FNetEncoder."""
    return nn.ELU(alpha=0.5)

In [3]:
# Preview the module the factory builds (displayed as the cell's last expression).
f()

ELU(alpha=0.5)

In [44]:
# Full encoder: N stacked FNet blocks using the activation factory f, then a
# projection from input_dim to output_dim features.
N = 6
# Note: `activation` now holds a factory (called once per block) rather than
# the nn.GELU() instance it held earlier in the notebook.
activation = f
output_dim = 48
fnet_encoder = FNetEncoder(N, input_dim, hidden_dim, dropout_rate, output_dim, activation)
out_encoder = fnet_encoder(x)
# Only the last dimension changes (input_dim -> output_dim).
assert out_encoder.shape == (*x.shape[:-1], output_dim), out_encoder.shape
print(out_encoder.shape)

torch.Size([16, 1000, 48])
