# Check the dependencies 

In [17]:
import tensorly as tl
import tltorch
import neuralop as no

print(f'{tl.__version__=}')
print(f'{tltorch.__version__=}')
print(f'{no.__version__=}')

tl.__version__='0.8.1'
tltorch.__version__='0.4.0'
no.__version__='0.1.3'


# FFT and Spectral Convolution


In [18]:
from neuralop.models.fno_block import FactorizedSpectralConv
from neuralop.models import TFNO2d
import torch

In [19]:
fourier_conv = FactorizedSpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),
                                      factorization=None, implementation='reconstructed')

In [20]:
in_data = torch.randn((2, 3, 16, 16))

In [21]:
out = fourier_conv(in_data)

In [22]:
out.shape

torch.Size([2, 10, 16, 16])

In [23]:
fourier_conv

FactorizedSpectralConv(
  (weight): ModuleList(
    (0-1): 2 x ComplexDenseTensor(shape=torch.Size([3, 10, 2, 2]), rank=None)
  )
)

A spectral convolution transforms the input to the Fourier domain, multiplies the retained (complex) Fourier coefficients by (complex) weights that are learned end-to-end, and transforms the result back.
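To make the multiplication of coefficients by weights concrete, here is a minimal NumPy sketch of a 2D spectral convolution. It is an illustration under stated assumptions, not `neuralop`'s actual code: the function name `spectral_conv2d` is hypothetical, and the choice of two weight blocks (one for the low positive and one for the low negative frequencies along the first spatial axis) is assumed here to mirror the two weight tensors printed above.

```python
import numpy as np

def spectral_conv2d(x, weights):
    """Sketch of a spectral convolution (hypothetical helper, not neuralop's API).

    x:       real array of shape (batch, in_channels, height, width)
    weights: two complex arrays of shape (in_channels, out_channels, m1, m2)
    """
    batch, _, height, width = x.shape
    out_channels, m1, m2 = weights[0].shape[1:]

    # Real FFT over the two spatial axes: (batch, in_channels, height, width // 2 + 1)
    x_hat = np.fft.rfft2(x, axes=(-2, -1))

    # All modes that are not explicitly kept stay at zero (truncation).
    out_hat = np.zeros((batch, out_channels, height, width // 2 + 1), dtype=complex)

    # Multiply the retained coefficients by the complex weights, mixing channels.
    out_hat[:, :, :m1, :m2] = np.einsum(
        'bixy,ioxy->boxy', x_hat[:, :, :m1, :m2], weights[0])
    out_hat[:, :, -m1:, :m2] = np.einsum(
        'bixy,ioxy->boxy', x_hat[:, :, -m1:, :m2], weights[1])

    # Back to physical space.
    return np.fft.irfft2(out_hat, s=(height, width), axes=(-2, -1))

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 16, 16))
weights = [rng.standard_normal((3, 10, 2, 2)) + 1j * rng.standard_normal((3, 10, 2, 2))
           for _ in range(2)]
out = spectral_conv2d(x, weights)
print(out.shape)  # (2, 10, 16, 16)
```

The output has the same spatial size as the input because the truncation happens only in the Fourier domain; the number of channels changes from 3 to 10 through the weights.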

# Tensorized Spectral Convolutions

It is possible to express the weights of one or more layers in factorized form, as a low-rank decomposition of the full weights.

`neuralop` supports tensorization out of the box: simply specify the factorization to use, e.g. `factorization='tucker'` for a Tucker factorization.

In [24]:
fourier_conv = FactorizedSpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),
                                      factorization='tucker', implementation='reconstructed')

In [25]:
fourier_conv

FactorizedSpectralConv(
  (weight): ModuleList(
    (0-1): 2 x ComplexTuckerTensor(shape=(3, 10, 2, 2), rank=(1, 5, 1, 1))
  )
)
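As a rough illustration of what the factorization saves, the following arithmetic sketch counts the (complex) entries of one weight tensor in dense form and in Tucker form, using the shape and rank printed above. It assumes the standard Tucker layout of one core tensor plus one factor matrix per mode; how `tltorch` stores these internally is not shown here.

```python
from math import prod

shape = (3, 10, 2, 2)  # dense weight shape, from the printed module
rank = (1, 5, 1, 1)    # Tucker rank, from the printed module

dense_params = prod(shape)                                    # full tensor
core_params = prod(rank)                                      # Tucker core
factor_params = sum(dim * r for dim, r in zip(shape, rank))   # one matrix per mode
tucker_params = core_params + factor_params

print(dense_params)   # 120
print(tucker_params)  # 62
```

At this toy size the saving is modest; the gap grows quickly with the number of channels and modes, since the dense count is a product of the dimensions while the factor count is a sum.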

## Efficient forward pass

When factorizing the weights, we have two main options during the forward pass:
1. reconstruct the full weights and use that for the forward pass 
2. contract the input directly with the factorized weights to predict the output

When the factorized weights are small, the second option can lead to large speedups or memory reduction, particularly when coupled with checkpointing. 

In `neuralop`, you can select either option by specifying `implementation='reconstructed'` or `implementation='factorized'`:

In [26]:
fourier_conv = FactorizedSpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),
                                      factorization='tucker', implementation='factorized')
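The two forward-pass options described above can be compared in a small NumPy sketch. It builds a random Tucker-factorized weight with the shape and rank printed earlier, then computes the spectral contraction both ways. The helper `crandn`, the variable names, and the particular contraction order are illustrative assumptions, not what `neuralop` or `tltorch` do internally.

```python
import numpy as np

rng = np.random.default_rng(0)

def crandn(*shape):
    """Complex Gaussian tensor of the given shape (hypothetical helper)."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

shape, rank = (3, 10, 2, 2), (1, 5, 1, 1)   # values from the printed module
core = crandn(*rank)
u_in, u_out, u_x, u_y = (crandn(dim, r) for dim, r in zip(shape, rank))
x_hat = crandn(2, 3, 2, 2)                  # retained Fourier modes, batch of 2

# Option 1: rebuild the full (in, out, x, y) weight, then contract with the input.
full = np.einsum('abcd,ia,ob,xc,yd->ioxy', core, u_in, u_out, u_x, u_y)
out_reconstructed = np.einsum('nixy,ioxy->noxy', x_hat, full)

# Option 2: contract step by step; the full weight tensor is never formed.
h = np.einsum('nixy,ia->naxy', x_hat, u_in)            # input channels -> rank
modes = np.einsum('abcd,xc,yd->abxy', core, u_x, u_y)  # core expanded over modes only
h = np.einsum('naxy,abxy->nbxy', h, modes)             # mix in the small rank space
out_factorized = np.einsum('nbxy,ob->noxy', h, u_out)  # rank -> output channels

print(np.allclose(out_reconstructed, out_factorized))  # True
```

Both paths give the same result; they differ only in which intermediate tensors are materialized, which is where the speed and memory trade-off comes from.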

# Full Tensorized Fourier Neural Operator 

The full architecture is composed of:
1. a lifting layer taking the number of input channels and lifting that to the desired number of hidden channels
2. a number of spectral convolutions, as shown above
3. a projection layer projecting back from the number of hidden channels to the desired number of output channels
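The three stages listed above can be traced with a shape-only NumPy sketch. This is a deliberately simplified stand-in, not the `TFNO2d` implementation: the helpers `pointwise` and `spectral` are hypothetical, the projection is collapsed to a single linear map, and activations and skip connections are omitted so that only the flow of channels is visible.

```python
import numpy as np

rng = np.random.default_rng(0)

def pointwise(x, w):
    """Per-pixel linear map over channels (the role a 1x1 convolution plays)."""
    return np.einsum('oi,bihw->bohw', w, x)

def spectral(x, w):
    """Multiply the lowest Fourier modes by complex weights (one block only)."""
    m1, m2 = w.shape[-2:]
    x_hat = np.fft.rfft2(x)
    out_hat = np.zeros_like(x_hat)  # valid because in_channels == out_channels here
    out_hat[..., :m1, :m2] = np.einsum('bixy,ioxy->boxy', x_hat[..., :m1, :m2], w)
    return np.fft.irfft2(out_hat, s=x.shape[-2:])

in_channels, hidden_channels, out_channels, n_layers = 3, 16, 1, 4
x = rng.standard_normal((2, in_channels, 16, 16))

lift = rng.standard_normal((hidden_channels, in_channels))
proj = rng.standard_normal((out_channels, hidden_channels))
spec = [rng.standard_normal((hidden_channels, hidden_channels, 4, 4))
        + 1j * rng.standard_normal((hidden_channels, hidden_channels, 4, 4))
        for _ in range(n_layers)]

h = pointwise(x, lift)        # lifting: 3 -> 16 channels
for w in spec:                # spectral convolutions: 16 -> 16 channels
    h = spectral(h, w)
y = pointwise(h, proj)        # projection: 16 -> 1 channel

print(x.shape, h.shape, y.shape)  # (2, 3, 16, 16) (2, 16, 16, 16) (2, 1, 16, 16)
```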


In [27]:
tfno = TFNO2d(n_modes_height=16, n_modes_width=16, hidden_channels=16, 
              factorization=None, skip='linear')

In [28]:
tfno

TFNO2d(
  (convs): FactorizedSpectralConv2d(
    (weight): ModuleList(
      (0-7): 8 x ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)
    )
  )
  (fno_skips): ModuleList(
    (0-3): 4 x Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (lifting): Lifting(
    (fc): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
  )
  (projection): Projection(
    (fc1): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
    (fc2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)

## Lifting layer

The lifting layer increases the number of channels, here from 3 input channels to 16 hidden channels.

In [29]:
tfno.lifting

Lifting(
  (fc): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
)

## Spectral convolutions

In [30]:
tfno.convs

FactorizedSpectralConv2d(
  (weight): ModuleList(
    (0-7): 8 x ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)
  )
)

## Skip connections: recovering non-periodicity

Recall that the FNO architecture has skip connections: the FFT loses non-periodic information, which has to be reinjected through skip connections. These skip connections also help with learning.

![FNO_layer](./images/fourier_layer.png)

Here, the skip connection is a linear layer (represented by the weight W in the image). We can also use identity skips (`skip='identity'`) or soft-gated connections (`skip='soft-gating'`).

In [31]:
tfno.fno_skips

ModuleList(
  (0-3): 4 x Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
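A small 1D NumPy experiment can illustrate why the skip path matters for non-periodic inputs. It is a sketch under simplifying assumptions (a single channel, a plain ramp as input, spectral weights equal to one, and a skip weight of one), not a reproduction of the `neuralop` layer: the truncated spectral path alone cannot represent the ramp near the boundary, while a pointwise linear path carries it through unchanged.

```python
import numpy as np

n_points, n_modes = 64, 4
x = np.arange(n_points) / n_points   # ramp from 0 to just below 1: not periodic

# Spectral path: keep only the n_modes lowest frequencies, zero the rest.
x_hat = np.fft.rfft(x)
x_hat[n_modes:] = 0
spectral_path = np.fft.irfft(x_hat, n=n_points)

# Skip path: pointwise linear map with weight 1 (an illustrative choice).
skip_path = 1.0 * x

spectral_error = np.abs(spectral_path - x)
skip_error = np.abs(skip_path - x)

print(spectral_error[0] > 0.4)  # True: large error at the boundary
print(skip_error.max())         # 0.0
```

The truncated Fourier reconstruction treats the signal as periodic, so it sees a jump between the last and first sample and settles near the midpoint of that jump at the boundary.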

## Projection: going back to the target number of channels 

Finally, the projection layer maps the hidden channels to `projection_channels` (here, 256) and then to the actual number of output channels (here, 1).

In [32]:
tfno.projection

Projection(
  (fc1): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
  (fc2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
)
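As a quick check on the size of this block, the parameter count of the two 1x1 convolutions printed above can be worked out by hand. The helper `conv1x1_params` is hypothetical; the channel sizes come from the printed module, and the presence of a bias is inferred from the printout not showing `bias=False`.

```python
def conv1x1_params(in_channels, out_channels, bias=True):
    """Parameter count of a 1x1 convolution: one weight per channel pair, plus bias."""
    return in_channels * out_channels + (out_channels if bias else 0)

fc1_params = conv1x1_params(16, 256)   # hidden channels -> projection channels
fc2_params = conv1x1_params(256, 1)    # projection channels -> output channels

print(fc1_params)               # 4352
print(fc2_params)               # 257
print(fc1_params + fc2_params)  # 4609
```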