In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchinfo import summary

class RealFFN(nn.Module):
    """A real-valued 2-layer feed-forward network: Linear -> ReLU -> Linear.

    The network maps the last dimension ``in_features -> hidden_features ->
    in_features``; all leading dimensions are passed through unchanged by
    ``nn.Linear``.

    Args:
        in_features: Size of the last input dimension (and of the output).
            Defaults to 128.
        hidden_features: Width of the hidden layer. Defaults to 256.
    """
    def __init__(self, in_features: int = 128, hidden_features: int = 256) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc2(relu(fc1(x)))``.

        Args:
            x: Tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Module-level instance of the real-valued network; summarized in a later cell.
realnet = RealFFN()

class ComplexFFN(nn.Module):
    """A complex-valued 2-layer feed-forward network: Linear -> split ReLU -> Linear.

    Both linear layers hold ``torch.cfloat`` parameters. ReLU is applied
    separately to the real and imaginary parts of the hidden activation.

    Args:
        in_features: Size of the last input dimension (and of the output).
            Defaults to 128.
        hidden_features: Width of the hidden layer. Defaults to 256.
    """
    def __init__(self, in_features: int = 128, hidden_features: int = 256) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features, dtype=torch.cfloat)
        self.fc2 = nn.Linear(hidden_features, in_features, dtype=torch.cfloat)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a real or complex input.

        Args:
            x: Tensor whose last dimension equals ``in_features``. A real
                tensor is lifted to the complex tensor ``x + 1j * x``; a
                complex tensor is used as-is.

        Returns:
            Complex tensor with the same shape as ``x``.
        """
        # torch.complex() only accepts real inputs, so lift only when needed;
        # this keeps real-input behaviour unchanged while also allowing
        # genuinely complex inputs.
        if not torch.is_complex(x):
            x = torch.complex(x, x)
        x = self.fc1(x)
        # Expose (real, imag) as a trailing dimension of size 2 so that ReLU
        # is applied to each part independently, then fold back to complex.
        x = torch.view_as_real(x)
        x = F.relu(x)
        x = torch.view_as_complex(x)
        x = self.fc2(x)
        return x

# Module-level instance of the complex-valued network; summarized in a later cell.
complexnet = ComplexFFN()
# Input shape handed to summary() for both networks. The trailing 128 matches
# the input width of each network's first Linear layer.
# NOTE(review): the leading dims (2, 101, 101) look like batch/spatial sizes — confirm.
input_size = (2, 101, 101, 128)

In [2]:
# Per-layer torchinfo table for the real network, shown as this cell's value.
# verbose=0 should keep summary() from printing on its own (per torchinfo docs).
summary(realnet, input_size=input_size, verbose=0)

Layer (type:depth-idx)                   Output Shape              Param #
RealFFN                                  [2, 101, 101, 128]        --
├─Linear: 1-1                            [2, 101, 101, 256]        33,024
├─Linear: 1-2                            [2, 101, 101, 128]        32,896
Total params: 65,920
Trainable params: 65,920
Non-trainable params: 0
Total mult-adds (M): 0.13
Input size (MB): 10.45
Forward/backward pass size (MB): 62.67
Params size (MB): 0.26
Estimated Total Size (MB): 73.38

In [3]:
# Per-layer torchinfo table for the complex network, shown as this cell's value.
# The same real-valued input_size is passed here; ComplexFFN.forward turns the
# input into a complex tensor itself.
summary(complexnet, input_size=input_size, verbose=0)

Layer (type:depth-idx)                   Output Shape              Param #
ComplexFFN                               [2, 101, 101, 128]        --
├─Linear: 1-1                            [2, 101, 101, 256]        66,048
├─Linear: 1-2                            [2, 101, 101, 128]        65,792
Total params: 131,840
Trainable params: 131,840
Non-trainable params: 0
Total mult-adds (M): 0.26
Input size (MB): 10.45
Forward/backward pass size (MB): 125.35
Params size (MB): 1.05
Estimated Total Size (MB): 136.85