# CCNN

General imports

In [None]:
import torch
import torchaudio
import torchmetrics
import math
import os
import sklearn
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.profilers import PyTorchProfiler
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from pytorch_lightning.loggers import TensorBoardLogger
from torch import optim
from typing import Optional
from omegaconf import OmegaConf
from tqdm import tqdm

from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
from torchvision.datasets import STL10
from torchaudio.datasets import SPEECHCOMMANDS
from torch.utils.data import random_split, DataLoader, TensorDataset

# Datasets

Utils

In [None]:
def split_data(tensor, stratify):
    """Split ``tensor`` into stratified 70% / 15% / 15% train / val / test parts.

    Both splits use fixed ``random_state`` values, so two calls with the same
    ``stratify`` labels produce index-aligned partitions (``process_data``
    calls this once for the features and once for the labels).

    Args:
        tensor: Array-like whose first dimension indexes samples.
        stratify: Per-sample class labels used for stratification.

    Returns:
        Tuple ``(train, val, test)`` of partitions of ``tensor``.
    """
    # ``import sklearn`` on its own does not guarantee that the
    # ``model_selection`` submodule is loaded, so import it explicitly.
    from sklearn.model_selection import train_test_split

    # First split: 70% train, 30% held out, stratified by label.
    train_tensor, testval_tensor, _, testval_stratify = train_test_split(
        tensor,
        stratify,
        train_size=0.7,
        random_state=0,
        shuffle=True,
        stratify=stratify,
    )

    # Second split: halve the held-out 30% into validation and test.
    val_tensor, test_tensor = train_test_split(
        testval_tensor,
        train_size=0.5,
        random_state=1,
        shuffle=True,
        stratify=testval_stratify,
    )
    return train_tensor, val_tensor, test_tensor

def save_data(dir, **tensors):
    """Save every keyword tensor to ``<dir>/<name>.pt``.

    Args:
        dir: Destination directory; it (and any missing parents) is created.
        **tensors: Mapping of file stem -> object passed to ``torch.save``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(dir, exist_ok=True)
    for tensor_name, tensor_value in tensors.items():
        torch.save(tensor_value, os.path.join(dir, f"{tensor_name}.pt"))


def load_data(dir):
    """Load every ``*.pt`` file directly inside ``dir`` into a dict.

    Args:
        dir: Directory to scan (not recursive).

    Returns:
        dict mapping each file name without its ``.pt`` suffix to the object
        returned by ``torch.load``.
    """
    tensors = {}
    for filename in os.listdir(dir):
        stem, ext = os.path.splitext(filename)
        if ext != ".pt":
            continue
        # splitext keeps dots inside the stem ("v1.2.pt" -> "v1.2"); the old
        # split(".")[0] truncated such names and let different files collide.
        tensors[stem] = torch.load(os.path.join(dir, filename))
    return tensors


def load_data_from_partition(data_loc, partition):
    """Load the ``(x, y)`` tensors of one partition saved by ``save_data``.

    Only ``<partition>_x.pt`` and ``<partition>_y.pt`` are read, rather than
    loading every tensor in ``data_loc`` and discarding the other partitions.

    Args:
        data_loc: Directory containing the saved ``.pt`` files.
        partition: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        Tuple ``(x, y)``.

    Raises:
        ValueError: If ``partition`` is not a known partition name.
    """
    if partition not in ("train", "val", "test"):
        raise ValueError(
            f"partition must be 'train', 'val' or 'test', got {partition!r}"
        )
    x = torch.load(os.path.join(data_loc, f"{partition}_x.pt"))
    y = torch.load(os.path.join(data_loc, f"{partition}_y.pt"))
    return x, y


def normalise_data(X, y):
    """Standardise each slice along the last axis of ``X`` with training statistics.

    Mean and standard deviation of every last-axis slice are estimated on the
    training partition returned by ``split_data`` only (NaNs ignored), then
    applied to the whole of ``X`` so validation/test data do not leak into
    the statistics. A small epsilon guards against zero variance.

    Args:
        X: Tensor whose last dimension indexes the features to normalise.
        y: Labels forwarded to ``split_data`` for stratification.

    Returns:
        Tensor with the same shape as ``X``.
    """
    train_X = split_data(X, y)[0]

    normalised = []
    for column, train_column in zip(X.unbind(dim=-1), train_X.unbind(dim=-1)):
        observed = train_column[~torch.isnan(train_column)]
        centre = observed.mean()
        spread = observed.std()
        normalised.append((column - centre) / (spread + 1e-5))
    return torch.stack(normalised, dim=-1)

MNIST

In [None]:
class MnistDataModule(pl.LightningDataModule):
    """Lightning data module for sequential (``smnist``) and permuted (``pmnist``) MNIST.

    Every image is flattened to a ``(1, 784)`` sequence. For ``pmnist`` one
    pixel permutation, drawn from a generator seeded with ``permutation_seed``,
    is applied identically to every image of every split.

    Args:
        cfg: OmegaConf config; ``cfg.data.dataset`` selects the variant and
            ``_yaml_parameters`` writes dataset-specific settings into it.
        data_dir: Root directory used by torchvision's ``MNIST``.
        permutation_seed: Seed of the fixed pmnist pixel permutation.
    """

    def __init__(self, cfg, data_dir : str = "../datasets", permutation_seed: int = 0):
        super().__init__()
        self.data_dir = data_dir
        self.type = cfg.data.dataset
        self.cfg = cfg
        self.num_workers = 7
        # Seed of the single permutation shared by all pmnist samples.
        self.permutation_seed = permutation_seed


    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)


    def _set_transform(self):
        """Build ``self.transform``: to-tensor, flatten, and (pmnist only) permute."""
        steps = [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.view(1, -1)),  # flatten the image to 784 pixels -> (1, 784)
        ]

        if self.type == "pmnist":
            # One permutation for the whole dataset: the seeded generator makes
            # it identical across splits and across repeated setup() calls.
            # (The previous version referenced an undefined ``self.size`` and
            # would have drawn a fresh permutation for every sample.)
            generator = torch.Generator().manual_seed(self.permutation_seed)
            permutation = torch.randperm(28 * 28, generator=generator)
            # Permute the pixel axis (dim 1); dim 0 is the single channel.
            steps.append(transforms.Lambda(lambda x: x[:, permutation]))

        self.transform = transforms.Compose(steps)


    def _yaml_parameters(self):
        """Write the hyper-parameters for this MNIST variant into ``self.cfg``."""
        hidden_channels = self.cfg.net.hidden_channels

        OmegaConf.update(self.cfg, "train.batch_size", 100)
        OmegaConf.update(self.cfg, "train.epochs", 210)
        OmegaConf.update(self.cfg, "net.in_channels", 1)
        OmegaConf.update(self.cfg, "net.out_channels", 10)
        OmegaConf.update(self.cfg, "net.data_dim", 1)

        if hidden_channels == 140:
            if self.type == "smnist":
                OmegaConf.update(self.cfg, "train.learning_rate", 0.01)
                OmegaConf.update(self.cfg, "train.dropout_rate", 0.1)
                OmegaConf.update(self.cfg, "train.weight_decay", 1e-6)
                OmegaConf.update(self.cfg, "kernel.omega_0", 2976.49)
            elif self.type == "pmnist":
                OmegaConf.update(self.cfg, "train.learning_rate", 0.02)
                OmegaConf.update(self.cfg, "train.dropout_rate", 0.2)
                OmegaConf.update(self.cfg, "train.weight_decay", 0)
                OmegaConf.update(self.cfg, "kernel.omega_0", 2985.63)
        elif hidden_channels == 380:
            OmegaConf.update(self.cfg, "train.weight_decay", 0)

            if self.type == "smnist":
                OmegaConf.update(self.cfg, "train.learning_rate", 0.01)
                OmegaConf.update(self.cfg, "train.dropout_rate", 0.1)
                OmegaConf.update(self.cfg, "kernel.omega_0", 2976.49)
            elif self.type == "pmnist":
                OmegaConf.update(self.cfg, "train.learning_rate", 0.02)
                OmegaConf.update(self.cfg, "train.dropout_rate", 0.2)
                OmegaConf.update(self.cfg, "kernel.omega_0", 2985.63)


    def setup(self, stage: str):
        self._set_transform()
        self._yaml_parameters()

        self.batch_size = self.cfg.train.batch_size

        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            self.mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                self.mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )
            print(f'Training set size: {len(self.mnist_train)}')
            print(f'Validation set size: {len(self.mnist_val)}')

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
            print(f'Test set size: {len(self.mnist_test)}')

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
            print(f'Prediction set size: {len(self.mnist_predict)}')


    def train_dataloader(self):
        # Reshuffle the training set every epoch; evaluation loaders keep a fixed order.
        return DataLoader(self.mnist_train,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.mnist_test,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)

    def teardown(self, stage: str):
        # Used to clean-up when the run is finished
        ...

SpeechCommands

In [None]:
class SpeechCommandsModule(pl.LightningDataModule):
    """Lightning data module for the 10-class Speech Commands v0.02 task.

    ``sc_raw`` feeds the raw waveform (1 channel, length 16000); ``sc_mfcc``
    feeds 20 MFCC coefficients per frame. The data is processed once in
    ``prepare_data`` and cached as ``.pt`` files under
    ``<data_dir>/SpeechCommands/processed_data``.
    """

    def __init__(self, cfg, data_dir : str = "../datasets"):
        super().__init__()
        self.data_dir = data_dir
        self.data_processed_location = self.data_dir + "/SpeechCommands/processed_data"
        self.download_location = self.data_dir + "/SpeechCommands/speech_commands_v0.02"
        self.type = cfg.data.dataset
        self.cfg = cfg
        self.num_workers = 7


    def process_data(self):
        """Load, filter, featurise, normalise and split the audio clips.

        Returns:
            Tuple ``(train_x, val_x, test_x, train_y, val_y, test_y)``.
        """
        # Pre-allocated for the expected number of full-length clips.
        x = torch.empty(34975, 16000, 1)
        y = torch.empty(34975, dtype=torch.long)

        batch_index = 0
        classes = ("yes", "no", "up", "down", "left", "right", "on", "off", "stop", "go")
        for y_index, foldername in enumerate(classes):
            loc = self.download_location + "/" + foldername
            # sorted(): os.listdir order is arbitrary, which would make the
            # seeded train/val/test split depend on the filesystem.
            for filename in tqdm(sorted(os.listdir(loc))):
                audio, _ = torchaudio.load(
                    loc + "/" + filename,
                    channels_first=False,
                )

                # A few samples are shorter than the full length; for simplicity we discard them.
                if len(audio) != 16000:
                    continue

                x[batch_index] = audio
                y[batch_index] = y_index
                batch_index += 1

        # Drop any rows that were never filled: torch.empty leaves them uninitialised.
        x = x[:batch_index]
        y = y[:batch_index]

        # If MFCC, then we compute these coefficients.
        if self.type == "sc_mfcc":
            x = torchaudio.transforms.MFCC(
                log_mels=True, n_mfcc=20, melkwargs=dict(n_fft=200, n_mels=64)
            )(x.squeeze(-1)).detach()
            # X is of shape (batch, channels=20, length=161)
        else:
            x = x.unsqueeze(1).squeeze(-1)
            # X is of shape (batch, channels=1, length=16000)

        # Normalize data
        if self.type == "sc_mfcc":
            # Per-coefficient statistics: move channels to the last axis.
            x = normalise_data(x.transpose(1, 2), y).transpose(1, 2)
        else:
            # NOTE(review): normalise_data works over the last axis, which here
            # is time, so every time step gets its own mean/std. Confirm this
            # is intended (the MFCC branch normalises per coefficient).
            x = normalise_data(x, y)

        # Same labels + fixed seeds inside split_data keep x and y aligned.
        train_x, val_x, test_x = split_data(x, y)
        train_y, val_y, test_y = split_data(y, y)

        return (
            train_x,
            val_x,
            test_x,
            train_y,
            val_y,
            test_y,
        )


    def prepare_data(self):
        # download, then process once and cache the result on disk
        SPEECHCOMMANDS(self.data_dir, download=True)
        if not os.path.exists(self.data_processed_location + "/train_x.pt"):
            train_x, val_x, test_x, train_y, val_y, test_y = self.process_data()

            save_data(
                self.data_processed_location,
                train_x=train_x,
                val_x=val_x,
                test_x=test_x,
                train_y=train_y,
                val_y=val_y,
                test_y=test_y,
            )


    def _yaml_parameters(self):
        """Write the Speech Commands hyper-parameters into ``self.cfg``."""
        OmegaConf.update(self.cfg, "net.out_channels", 10)
        OmegaConf.update(self.cfg, "net.data_dim", 1)
        OmegaConf.update(self.cfg, "train.dropout_rate", 0.2)
        OmegaConf.update(self.cfg, "train.learning_rate", 0.02)
        OmegaConf.update(self.cfg, "train.weight_decay", 1e-6)

        # 140 and 380 hidden_channels have same parameters
        if self.type == "sc_raw":
            OmegaConf.update(self.cfg, "net.in_channels", 1)
            OmegaConf.update(self.cfg, "train.batch_size", 20)
            OmegaConf.update(self.cfg, "train.epochs", 160)
            OmegaConf.update(self.cfg, "kernel.omega_0", 1295.61)
        elif self.type == "sc_mfcc":
            # 20 input channels = n_mfcc used in process_data (the value was
            # previously missing, which made this call fail).
            OmegaConf.update(self.cfg, "net.in_channels", 20)
            OmegaConf.update(self.cfg, "train.batch_size", 100)
            OmegaConf.update(self.cfg, "train.epochs", 110)
            OmegaConf.update(self.cfg, "kernel.omega_0", 750.18)


    def setup(self, stage: str):
        self._yaml_parameters()

        self.batch_size = self.cfg.train.batch_size

        if stage == "fit":
            # train
            x_train, y_train = load_data_from_partition(
                self.data_processed_location, partition="train"
            )
            self.train_dataset = TensorDataset(x_train, y_train)
            # validation
            x_val, y_val = load_data_from_partition(
                self.data_processed_location, partition="val"
            )
            self.val_dataset = TensorDataset(x_val, y_val)
        if stage == "test":
            # test
            x_test, y_test = load_data_from_partition(
                self.data_processed_location, partition="test"
            )
            self.test_dataset = TensorDataset(x_test, y_test)
        if stage == "predict":
            # predict (uses the test partition)
            x_test, y_test = load_data_from_partition(
                self.data_processed_location, partition="test"
            )
            self.test_dataset = TensorDataset(x_test, y_test)


    def train_dataloader(self):
        # Reshuffle the training set every epoch; evaluation loaders keep a fixed order.
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)

    def predict_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)

    def teardown(self, stage: str):
        # Used to clean-up when the run is finished
        ...

get_data_module function

In [None]:
def get_data_module(cfg : OmegaConf):
    """Return the Lightning data module for ``cfg.data.dataset``.

    Args:
        cfg: Config whose ``data.dataset`` names the dataset.

    Returns:
        ``MnistDataModule`` for smnist/pmnist, ``SpeechCommandsModule`` for
        sc_raw/sc_mfcc.

    Raises:
        ValueError: If the dataset name is not recognised.
        NotImplementedError: If the dataset is recognised but has no data
            module yet (previously this silently returned ``None``).
    """
    dataset = cfg.data.dataset
    supported = (
        "smnist", "pmnist", "cifar10", "scifar10", "cifar100", "stl10",
        "sc_mfcc", "sc_raw", "pathfinder", "path_x", "image",
    )
    if dataset not in supported:
        raise ValueError(f"Dataset not supported: {dataset}")

    # can be either sequential or permuted mnist
    if dataset in ("smnist", "pmnist"):
        return MnistDataModule(cfg)
    if dataset in ("sc_mfcc", "sc_raw"):
        return SpeechCommandsModule(cfg)

    # TODO: CIFAR10 / sCIFAR10, CIFAR100, STL10, Pathfinder, Path-X, Image.
    raise NotImplementedError(f"No data module implemented yet for dataset '{dataset}'")

# CCNN implementation

GetBatchNormalization function

In [None]:
def GetBatchNormalization(data_dim, num_features):
    """Return the BatchNorm layer matching ``data_dim`` spatial dimensions.

    Args:
        data_dim: 1, 2 or 3 (selects BatchNorm1d / 2d / 3d).
        num_features: Number of channels to normalise.

    Raises:
        ValueError: For any other ``data_dim`` (previously ``None`` was returned).
    """
    if data_dim == 1:
        return nn.BatchNorm1d(num_features)
    if data_dim == 2:
        return nn.BatchNorm2d(num_features)
    if data_dim == 3:
        return nn.BatchNorm3d(num_features)
    raise ValueError(f"Invalid dimension {data_dim}. Supported dimensions are 1, 2, and 3.")

GetDropout function

In [None]:
def GetDropout(data_dim, p: float = 0.5):
    """Return the channel-wise Dropout layer matching ``data_dim``.

    Args:
        data_dim: 1, 2 or 3 (selects Dropout1d / 2d / 3d).
        p: Drop probability (default 0.5, PyTorch's default).

    Raises:
        ValueError: For any other ``data_dim`` (previously ``None`` was returned).
    """
    if data_dim == 1:
        return nn.Dropout1d(p)
    if data_dim == 2:
        return nn.Dropout2d(p)
    if data_dim == 3:
        return nn.Dropout3d(p)
    raise ValueError(f"Invalid dimension {data_dim}. Supported dimensions are 1, 2, and 3.")

GetAdaptiveAvgPool function

In [None]:
def GetAdaptiveAvgPool(data_dim, output_size):
    """Return the adaptive average-pooling layer matching ``data_dim``.

    Args:
        data_dim: 1, 2 or 3 (selects AdaptiveAvgPool1d / 2d / 3d).
        output_size: Target spatial output size.

    Raises:
        ValueError: For any other ``data_dim`` (previously ``None`` was returned).
    """
    if data_dim == 1:
        return nn.AdaptiveAvgPool1d(output_size)
    if data_dim == 2:
        return nn.AdaptiveAvgPool2d(output_size)
    if data_dim == 3:
        return nn.AdaptiveAvgPool3d(output_size)
    raise ValueError(f"Invalid dimension {data_dim}. Supported dimensions are 1, 2, and 3.")

create_coordinates function

In [None]:
def create_coordinates(kernel_size, data_dim):
    """Return a dense coordinate grid in [-1, 1] for a ``data_dim``-D kernel.

    Args:
        kernel_size: Number of evenly spaced samples per axis.
        data_dim: Number of spatial axes.

    Returns:
        Tensor of shape ``[1, data_dim, kernel_size, ..., kernel_size]``;
        channel ``i`` holds the coordinate along spatial axis ``i``.
    """
    axis = torch.linspace(-1, 1, steps=kernel_size)
    # "ij" indexing: the i-th output grid varies along output axis i.
    mesh = torch.meshgrid(*([axis] * data_dim), indexing="ij")
    return torch.stack(mesh, dim=0).unsqueeze(0)

Linear Layer

In [None]:
class LinearLayer(nn.Module):
    """Pointwise (kernel size 1, stride 1) convolution used as a per-position linear map.

    Args:
        dim: Number of spatial dimensions; 1, 2 or 3 selects Conv1d/2d/3d.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        bias: Whether the convolution has a bias term.

    Raises:
        ValueError: If ``dim`` is not 1, 2 or 3.
    """

    def __init__(self, dim: int, in_channels: int, out_channels: int, bias: bool = True):
        super().__init__()
        conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
        if dim not in conv_classes:
            raise ValueError(f"Invalid dimension {dim}. Supported dimensions are 1, 2, and 3.")
        conv_cls = conv_classes[dim]
        self.layer = conv_cls(in_channels, out_channels, kernel_size=1, stride=1, bias=bias)

    def forward(self, x):
        return self.layer(x)

Anisotropic Gabor Layer

In [None]:
class AnisotropicGaborLayer(nn.Module):
    """One Gabor filter layer of MAGNet: a Gaussian envelope times a sinusoid.

    For coordinates ``x`` the output is
    ``prod_i exp(-0.5 * (gamma_i * (x_i - mu_i)) ** 2) * sin(W x + b)``, with a
    separate envelope width ``gamma_i`` and centre ``mu_i`` per hidden channel
    and per spatial axis ``i``.

    Args:
        data_dim: Number of spatial dimensions of the coordinates.
        hidden_channels: Number of output channels (filters).
        current_layer: Depth of this layer in the network; gamma is sampled
            from ``Gamma(alpha / (current_layer + 1), beta)``.
        omega_0: Frequency scale applied to the sinusoid weights at init.
        alpha: Concentration parameter base for the gamma distribution.
        beta: Rate parameter of the gamma distribution.
    """

    def __init__(
        self,
        data_dim: int,
        hidden_channels: int,
        current_layer: int,
        omega_0: float = 2976.49,
        alpha: float = 6.0,
        beta: float = 1.0,
    ):
        super().__init__()

        self.data_dim = data_dim

        # Sinusoid argument W x + b, implemented as a pointwise convolution.
        self.linear = LinearLayer(
            dim=data_dim,
            in_channels=data_dim,
            out_channels=hidden_channels,
            bias=True,
        )

        gamma_dist = torch.distributions.gamma.Gamma(alpha / (current_layer + 1), beta)

        # One gamma per spatial axis (gamma_x, gamma_y, ...), each [hidden_channels, 1].
        self.gamma = nn.ParameterList(
            [
                nn.Parameter(gamma_dist.sample((hidden_channels, 1)))
                for _ in range(data_dim)
            ]
        )

        normal_dist = torch.distributions.normal.Normal(0, 1)

        # One envelope centre per spatial axis (mu_x, mu_y, ...), each [hidden_channels, 1].
        self.mi = nn.ParameterList(
            [
                nn.Parameter(normal_dist.sample((hidden_channels, 1)))
                for _ in range(data_dim)
            ]
        )

        scaling_factor = 25.6

        # Initialise the convolution that forward() actually uses
        # (``self.linear.layer``). Assigning ``self.linear.weight`` instead only
        # registers unused parameters on the wrapper module, so the frequency
        # scaling never took effect.
        conv = self.linear.layer
        with torch.no_grad():
            # [hidden_channels, data_dim]: gamma of axis i scales the weights
            # reading coordinate i (identical to gamma[0] when data_dim == 1).
            gamma = torch.cat(list(self.gamma), dim=1)
            conv.weight.copy_(
                torch.randn_like(conv.weight)
                * omega_0
                * scaling_factor
                * gamma.view(*gamma.shape, *((1,) * data_dim))
            )
            conv.bias.uniform_(-np.pi, np.pi)

    def forward(self, x):
        """Evaluate the Gabor filters on a coordinate grid.

        Args:
            x: Coordinate grid; ``x[0][i]`` is read as the coordinates along
                axis ``i`` (e.g. ``[1, data_dim, *spatial]`` as produced by
                ``create_coordinates``).

        Returns:
            Tensor ``[1, hidden_channels, *spatial]`` (assumes a batch size of 1
            in ``x`` -- only ``x[0]`` feeds the envelope).
        """
        trailing = (1,) * self.data_dim
        g_envelope = None
        for i in range(self.data_dim):
            coord = x[0][i]
            # coords -> [1, 1, 1, *spatial]; gamma / mu -> [1, hidden_channels, 1, 1, ..., 1]
            coord = coord.view(1, 1, 1, *coord.shape)
            gamma = self.gamma[i].view(1, *self.gamma[i].shape, *trailing)
            mu = self.mi[i].view(1, *self.mi[i].shape, *trailing)
            envelope = torch.exp(-0.5 * (gamma * (coord - mu)) ** 2)
            # Out-of-place product: exp() keeps its output for backward, so an
            # in-place ``*=`` on it breaks autograd whenever data_dim > 1.
            g_envelope = envelope if g_envelope is None else g_envelope * envelope

        # Drop the singleton axis 2 -> [1, hidden_channels, *spatial].
        g_envelope = g_envelope.squeeze(2)

        # computing the sinusoidal
        sinusoidal = torch.sin(self.linear(x))

        return g_envelope * sinusoidal

MFN

In [None]:
class MFN(nn.Module):
    """Multiplicative Filter Network backbone used to generate convolution kernels.

    Subclasses must set ``self.gabor_filters`` (an ``nn.ModuleList`` of
    ``no_layers`` filter modules). ``forward`` computes
    ``h_0 = f_0(x)``, ``h_l = f_l(x) * L_{l-1}(h_{l-1})`` and returns ``L_out(h)``.
    """

    def __init__(
        self, data_dim: int, hidden_channels: int, out_channels: int, no_layers: int
    ):
        """
        Initializes an instance of the MFN class.
        Args:
            data_dim (int): The dimension of the input data.
            hidden_channels (int): The number of hidden channels in the linear layers.
            out_channels (int): The number of output channels in the final linear layer.
            no_layers (int): Number of filter layers; ``no_layers - 1`` hidden
                linear layers plus one output layer are created.
        Returns:
            None
        """
        super(MFN, self).__init__()

        # hidden layers
        self.linearLayer = nn.ModuleList(
            [
                LinearLayer(
                    dim=data_dim,
                    in_channels=hidden_channels,
                    out_channels=hidden_channels,
                    bias=True,
                )
                for _ in range(no_layers - 1)
            ]
        )

        # output layer
        self.linearLayer.append(
            LinearLayer(
                dim=data_dim,
                in_channels=hidden_channels,
                out_channels=out_channels,
                bias=True,
            )
        )

        self.reweighted_output_layer = False

    def re_weight_output_layer(self, kernel_positions: torch.Tensor, in_channels: int, data_dim: int):
        """
        Scale the output layer's weights once by 1 / sqrt(in_channels * kernel_size).

        Args:
            kernel_positions: Coordinate grid ``[1, data_dim, *kernel_shape]``;
                its trailing ``data_dim`` axes give the kernel size.
            in_channels: Number of input channels of the convolution that uses
                the generated kernel.
            data_dim: Number of spatial dimensions.
        Returns:
            None
        """
        if self.reweighted_output_layer:
            return

        # Number of kernel taps, e.g. [1, 2, 33, 33] -> 33 * 33 = 1089.
        # Taking the trailing data_dim axes is correct for every data_dim;
        # the previous shape[data_dim:] dropped one spatial axis for 3-D kernels.
        kernel_size = math.prod(kernel_positions.shape[-data_dim:])

        # gain / sqrt(in_channels * kernel_size) with gain = 1, by Chang et al. (2020)
        factor = 1.0 / math.sqrt(in_channels * kernel_size)
        self.linearLayer[-1].layer.weight.data *= factor

        # Only re-weight the first time.
        # NOTE(review): this flag is a plain attribute and is not stored in the
        # state_dict; weights restored into a fresh instance would be scaled a
        # second time on the first call -- confirm how checkpoints are loaded.
        self.reweighted_output_layer = True

    def forward(self, x):
        """Evaluate the network on coordinates ``x`` (uses ``self.gabor_filters``)."""
        h = self.gabor_filters[0](x)
        for l in range(1, len(self.gabor_filters)):
            h = self.gabor_filters[l](x) * self.linearLayer[l - 1](h)

        return self.linearLayer[-1](h)

MAGNet

In [None]:
class MAGNet(MFN):
    def __init__(
        self, data_dim: int, hidden_channels: int, out_channels: int, no_layers: int, omega_0: float
    ):
        """
        Multiplicative Anisotropic Gabor Network.

        Reuses ``MFN`` for the linear layers and the forward pass, and supplies
        ``no_layers`` ``AnisotropicGaborLayer`` filters as ``self.gabor_filters``;
        the filter at position ``depth`` receives ``current_layer=depth``.

        Args:
            data_dim (int): Number of spatial dimensions of the coordinates.
            hidden_channels (int): Channels of every filter / hidden layer.
            out_channels (int): Channels produced by the output layer.
            no_layers (int): Number of Gabor filter layers.
            omega_0 (float): Frequency scale passed to every filter.
        """
        super().__init__(data_dim, hidden_channels, out_channels, no_layers)

        filters = []
        for depth in range(no_layers):
            filters.append(
                AnisotropicGaborLayer(
                    data_dim=data_dim,
                    hidden_channels=hidden_channels,
                    current_layer=depth,
                    omega_0=omega_0,
                )
            )
        self.gabor_filters = nn.ModuleList(filters)

Conv functions

In [None]:
def conv1d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    bias: Optional[torch.Tensor],
    padding: int,
    groups: int,
):
    """Stride-1 1-D convolution (cross-correlation) of ``x`` with ``kernel``."""
    stride, dilation = 1, 1
    return F.conv1d(x, kernel, bias, stride, padding, dilation, groups)

def conv2d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    bias: Optional[torch.Tensor],
    padding: int,
    groups: int,
):
    """Stride-1 2-D convolution (cross-correlation) of ``x`` with ``kernel``."""
    stride, dilation = 1, 1
    return F.conv2d(x, kernel, bias, stride, padding, dilation, groups)

def conv3d(
    x: torch.Tensor,
    kernel: torch.Tensor,
    bias: Optional[torch.Tensor],
    padding: int,
    groups: int,
):
    """Stride-1 3-D convolution (cross-correlation) of ``x`` with ``kernel``."""
    stride, dilation = 1, 1
    return F.conv3d(x, kernel, bias, stride, padding, dilation, groups)


def get_conv_function(
    x: torch.Tensor,
    kernel: torch.Tensor,
    bias: Optional[torch.Tensor],
    padding: int,
    groups: int,
    dim: int,
):
    """
    Apply the stride-1 convolution matching ``dim`` (1, 2 or 3) and return its output.

    Raises:
        ValueError: If ``dim`` is not 1, 2 or 3.
    """
    conv_by_dim = {1: conv1d, 2: conv2d, 3: conv3d}
    if dim not in conv_by_dim:
        raise ValueError(f"Invalid dimension {dim}")
    conv_fn = conv_by_dim[dim]
    return conv_fn(x, kernel, bias, padding, groups)
    

def conv(
    x: torch.Tensor,
    kernel: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
):
    """Depthwise, size-preserving convolution of ``x`` with a generated kernel.

    Args:
        x: Input ``[batch_size, channels, *spatial]``; the number of spatial
            axes sets the convolution dimensionality.
        kernel: Kernel whose axis 1 is the channel axis and whose trailing axes
            are odd spatial sizes. The reshape below only succeeds when the
            leading axis has size 1, e.g. ``[1, channels, *k]``.
        bias: Optional per-channel bias.

    Returns:
        Output of the stride-1 convolution padded by ``k // 2`` per axis, so it
        keeps the spatial size of ``x``.
    """
    data_dim = x.dim() - 2

    spatial_size = torch.tensor(kernel.shape[-data_dim:])
    assert torch.all(
        spatial_size % 2 != 0
    ), f"Convolutional kernels must have odd dimensionality. Received {kernel.shape}"
    # "same" padding for odd kernels
    padding = (spatial_size // 2).tolist()

    channels = kernel.shape[1]
    # [1, channels, *k] -> [channels, 1, *k]: one single-input filter per group.
    depthwise_kernel = kernel.view(channels, 1, *kernel.shape[2:])

    return get_conv_function(
        x, depthwise_kernel, bias, padding=padding, groups=channels, dim=data_dim
    )


def fftconv(
    x: torch.Tensor,
    kernel: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Depthwise 1-D cross-correlation computed in the frequency domain.

    For ``kernel`` of shape ``[1, C, k]`` this matches (up to floating-point
    error) ``F.conv1d(x, kernel.view(C, 1, k), bias, padding=k // 2, groups=C)``,
    i.e. the output keeps the input length.

    Args:
        x: Input ``[batch_size, channels, length]``; only 1-D data is supported.
        kernel: Kernel ``[1, channels, k]`` with odd ``k``; it is broadcast
            against ``x`` in the frequency domain.
        bias: Optional per-channel bias of shape ``[channels]``.

    Returns:
        Tensor ``[batch_size, channels, length]`` cast to float32.
    """
    data_dim = x.dim() - 2
    assert data_dim == 1, f"fftconv only supports 1-D inputs, got data_dim={data_dim}"

    kernel_size = torch.tensor(kernel.shape[-data_dim:])
    assert torch.all(
        kernel_size % 2 != 0
    ), f"Convolutional kernels must have odd dimensionality. Received {kernel.shape}"

    # Zero-pad the input by k // 2 on both sides so the kept outputs equal a
    # "same"-padded linear correlation (no circular wrap-around).
    pad = kernel.shape[-1] // 2
    x_padded = F.pad(x, [pad, pad])

    # Even length lets irfftn recover the exact signal length by default.
    if x_padded.shape[-1] % 2 != 0:
        x_padded = F.pad(x_padded, [0, 1])

    # Zero-pad the kernel on the right up to the padded input length.
    kernel_padded = F.pad(
        kernel, [0, x_padded.shape[-1] - kernel.shape[-1]], mode="constant", value=0
    )

    spatial_dims = tuple(range(2, x_padded.ndim))
    x_fr = torch.fft.rfftn(x_padded, dim=spatial_dims)
    kernel_fr = torch.fft.rfftn(kernel_padded, dim=spatial_dims)

    # X * conj(K) in frequency is cross-correlation in time. The product is
    # element-wise per channel (depthwise): channels are not summed.
    output_fr = x_fr * torch.conj(kernel_fr)

    out = torch.fft.irfftn(output_fr, dim=spatial_dims).float()

    # Keep the first `length` outputs, dropping the padded tail.
    out = out[..., : x.shape[-1]]

    if bias is not None:
        out = out + bias.view(1, -1, 1)

    return out

SeparableFlexConv

In [None]:
class SepFlexConv(nn.Module):
    """
    SeparableFlexConv (SepFlexConv) is a depthwise separable version of FlexConv (Romero et al., 2022a)
    
    ConstructMaskedKernel is a continuous version of the kernel whichs multiplied by a Gaussian mask.
    
    The gaussian mask has learnable parameters and by learning it the model can learn the size of the convolutional kernel.

    The flow is the following:

        input
            |
            |
            |
            -------------- | 
            |              |input.length
            |              |
            |    ConstructMaskedKernel   
            |              |
            |              |
            SpatialConvolution
                    |
                    |
            DepthwiseConvolution

    """

    def __init__(
        self,
        data_dim: int,
        in_channels: int,
        net_cfg: OmegaConf,
        kernel_cfg: OmegaConf, 
    ):
        """
        Initializes the SepFlexConv module.

        Args:
            data_dim (int): Number of spatial dimensions of the input.
            in_channels (int): Number of input channels; the generated kernel
                has one depthwise filter per input channel.
            net_cfg (OmegaConf): Network config; reads ``hidden_channels``
                (output channels of the pointwise convolution) and ``bias``.
            kernel_cfg (OmegaConf): Kernel config; reads ``kernel_no_layers``,
                ``kernel_hidden_channels``, ``kernel_size``, ``conv_type``,
                ``fft_threshold`` and ``omega_0``.
        """
        super().__init__()

        # sep flex conv parameters
        self.data_dim = data_dim
        self.in_channels = in_channels
        hidden_channels = net_cfg.hidden_channels

        # kernel parameters
        kernel_no_layers = kernel_cfg.kernel_no_layers
        kernel_hidden_channels = kernel_cfg.kernel_hidden_channels
        self.kernel_size = kernel_cfg.kernel_size
        self.conv_type = kernel_cfg.conv_type
        self.fft_threshold = kernel_cfg.fft_threshold

        # Placeholder for the relative kernel positions (replaced on first use
        # by get_rel_positions). A non-persistent buffer follows the module
        # through .to(device) / dtype casts without entering the state_dict;
        # a plain tensor attribute would stay on the CPU.
        self.register_buffer("kernel_positions", torch.zeros(1), persistent=False)

        if net_cfg.bias:
            # Per-channel bias initialised to zero. A Parameter (not a plain
            # tensor) so it is optimised and moved together with the module.
            self.bias = torch.nn.Parameter(torch.zeros(in_channels))
        else:
            self.bias = None

        # init gaussian mask parameter
        self.mask_mean = torch.nn.Parameter(torch.zeros(data_dim)) # mi = 0
        self.mask_sigma = torch.nn.Parameter(torch.ones(data_dim)) # sigma = 1

        # Define the kernel net, in our case always a MAGNet
        self.KernelNet = MAGNet(
            data_dim=data_dim,
            hidden_channels=kernel_hidden_channels,
            out_channels=in_channels, # always in channel because separable
            no_layers=kernel_no_layers,
            omega_0=kernel_cfg.omega_0,
        )

        # Define the pointwise convolution layer (page 4 original paper)
        self.pointwise_conv = LinearLayer(
            dim=data_dim,
            in_channels=in_channels,
            out_channels=hidden_channels,
            bias=net_cfg.bias,
        )
        

    def construct_masked_kernel(self, x):
        """
        Generate the convolution kernel with KernelNet and apply the Gaussian mask.

        Flow: relative positions -> (one-time output-layer re-weighting) ->
        KernelNet kernel -> multiplied by the learnable Gaussian mask.

        Args:
            x: Input tensor, forwarded to ``get_rel_positions``.

        Returns:
            The element-wise product of the KernelNet output and the mask.
        """
        positions = self.get_rel_positions(x)

        # Rescales KernelNet's output layer; a no-op after the first call.
        self.KernelNet.re_weight_output_layer(positions, self.in_channels, self.data_dim)

        raw_kernel = self.KernelNet(positions)
        gaussian = self.gaussian_mask(
            kernel_pos=positions,
            mask_mean=self.mask_mean,
            mask_sigma=self.mask_sigma,
        )
        return raw_kernel * gaussian

    def get_rel_positions(self, x):
        """
        Return the grid of relative kernel positions fed to ``KernelNet``.

        The grid is created lazily on the first call, while
        ``self.kernel_positions`` is still the 1-element placeholder set in
        ``__init__``, and is cached afterwards.

        ``self.kernel_positions`` is a plain tensor attribute, not a registered
        buffer, so ``Module.to()`` / Lightning device placement never moves it.
        It is therefore aligned with the dtype and device of ``self.mask_mean``
        on every call. ``Tensor.to`` returns the tensor unchanged when they
        already match. Without this, the positions stayed on CPU/float32 and
        clashed with the mask parameters and KernelNet weights on GPU.

        Args:
            x (torch.Tensor): Layer input. Currently unused; kept for interface
                compatibility.

        Returns:
            torch.Tensor: Relative positions on the same device and dtype as
            the Gaussian-mask parameters.
        """
        if self.kernel_positions.shape[-1] == 1:
            # First call: build the coordinate grid, sized [kernel_size] * data_dim.
            self.kernel_positions = create_coordinates(
                kernel_size=self.kernel_size,
                data_dim=self.data_dim,
            )
            # TODO: dynamic cropping (it would need the linspace step size
            # (1.0 - (-1.0)) / (train_length - 1)) is not implemented yet.

        self.kernel_positions = self.kernel_positions.to(self.mask_mean)
        return self.kernel_positions

    def gaussian_mask(
        self,
        kernel_pos: torch.Tensor,
        mask_mean: torch.Tensor,
        mask_sigma: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate an axis-aligned Gaussian envelope on the kernel position grid.

            mask(p) = exp(-0.5 * sum_d (p_d - mean_d)^2 / sigma_d^2)

        Args:
            kernel_pos (torch.Tensor): Position grid whose dim 1 indexes the
                spatial axis, e.g. [1, data_dim, *spatial].
            mask_mean (torch.Tensor): Per-axis mean; must hold ``data_dim`` elements.
            mask_sigma (torch.Tensor): Per-axis standard deviation; must hold
                ``data_dim`` elements.

        Returns:
            torch.Tensor: The mask with dim 1 reduced to size 1, e.g.
            [1, 1, Y, X] in 2D or [1, 1, X] in 1D.

        Example (2D):
            For kernel_pos of shape [1, 2, 33, 33], mean and sigma (2 elements
            each) are viewed as [1, 2, 1, 1] so they broadcast. The squared,
            variance-scaled distances are then summed over dim 1, giving
            [1, 1, 33, 33].
        """
        stat_shape = (1, self.data_dim) + (1,) * self.data_dim
        mean = mask_mean.view(stat_shape)
        sigma = mask_sigma.view(stat_shape)

        inverse_variance = 1.0 / (sigma ** 2)
        weighted_sq_dist = inverse_variance * (kernel_pos - mean) ** 2
        exponent = -0.5 * weighted_sq_dist.sum(1, keepdim=True)
        return torch.exp(exponent)

    def forward(self, x):
        """
        Apply the separable flexible convolution to ``x``.

        Steps:
            1. Build the masked kernel (``construct_masked_kernel``).
            2. Convolve ``x`` with it. The FFT path (``fftconv``) is used when
               ``conv_type == "fftconv"`` and every spatial kernel extent is
               larger than ``fft_threshold``; otherwise the direct ``conv``
               helper is used.
            3. Mix channels with the pointwise convolution.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: Output of the pointwise convolution.

        Example (2D):
            x: [64, 140, 32, 32], masked kernel: [1, 140, 33, 33]
            -> spatial conv: [64, 140, 32, 32] -> pointwise conv: [64, 140, 32, 32]
        """
        masked_kernel = self.construct_masked_kernel(x)

        # Spatial extents of the kernel, e.g. (33, 33) for data_dim=2.
        kernel_extent = masked_kernel.shape[2:]
        # Previously read the misspelled ``self.fft_thresold``, which raised
        # AttributeError whenever conv_type == "fftconv".
        use_fft = self.conv_type == "fftconv" and all(
            extent > self.fft_threshold for extent in kernel_extent
        )
        if use_fft:
            out = fftconv(x=x, kernel=masked_kernel, bias=self.bias)
        else:
            out = conv(x=x, kernel=masked_kernel, bias=self.bias)

        # Pointwise convolution over the result of the spatial convolution.
        return self.pointwise_conv(out)

S4Block

In [None]:
class S4Block(nn.Module):
    """
    Residual S4-style block (Gu et al., 2022) as used in the Continuous CNN architecture.

          input
            |
    |-------------|
    |         BatchNorm
    |         SepFlexConv
    |         GELU
    |         Dropout
    |         PointwiseLinear
    |         GELU
    |             |
    |---->(+)<----|
           |
        output

    The skip path is the identity when ``in_channels == out_channels``.
    Otherwise it is a Kaiming-initialised pointwise linear projection.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        data_dim,
        net_cfg: OmegaConf,
        kernel_cfg: OmegaConf,
    ):
        """
        Create the layers of the block.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            data_dim: Spatial dimensionality of the data (e.g. 1 or 2).
            net_cfg: ``net`` section of the config, forwarded to ``SepFlexConv``.
            kernel_cfg: ``kernel`` section of the config, forwarded to ``SepFlexConv``.
        """
        super().__init__()

        # Sub-layers are created in a fixed order. Registration order (and,
        # presumably, the random initial weights) depends on it.
        self.batch_norm_layer = GetBatchNormalization(data_dim=data_dim, num_features=in_channels)

        self.sep_flex_conv_layer = SepFlexConv(
            data_dim=data_dim,
            in_channels=in_channels,
            net_cfg=net_cfg,
            kernel_cfg=kernel_cfg,
        )
        # NOTE(review): SepFlexConv's pointwise conv outputs net_cfg.hidden_channels
        # channels, while the pointwise layer below expects in_channels. This block
        # therefore appears to assume in_channels == net_cfg.hidden_channels (CCNN
        # always passes hidden_channels); confirm before reusing it elsewhere.

        # Plain list: both GELUs get registered through seq_modules below.
        self.gelu_layer = [nn.GELU() for _ in range(2)]

        self.dropout_layer = GetDropout(data_dim=data_dim)

        self.pointwise_linear_layer = LinearLayer(data_dim, in_channels, out_channels)

        first_gelu, second_gelu = self.gelu_layer
        self.seq_modules = nn.Sequential(
            self.batch_norm_layer,
            self.sep_flex_conv_layer,
            first_gelu,
            self.dropout_layer,
            self.pointwise_linear_layer,
            second_gelu,
        )

        # Residual path: a direct route from input to output that eases the
        # training of deep stacks.
        if in_channels == out_channels:
            # Same width: an empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            # Width change: project the skip path with a pointwise linear layer.
            projection = LinearLayer(data_dim, in_channels, out_channels)
            nn.init.kaiming_normal_(projection.weight)
            if projection.bias is not None:
                projection.bias.data.fill_(value=0.0)
            self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        """
        Return ``seq_modules(x) + shortcut(x)`` (the residual connection).
        """
        residual = self.shortcut(x)
        return self.seq_modules(x) + residual

CCNN

In [None]:
class CCNN(pl.LightningModule):
    """
    CCNN architecture (Romero et al., 2022) as defined in the original paper.

    input --> SepFlexConv --> BatchNorm --> GELU --> L x S4Block --> BatchNorm
          --> GlobalAvgPool --> PointwiseLinear --> output

    Training minimises cross-entropy over ``out_channels`` classes with AdamW
    and a linear warm-up followed by cosine annealing (see
    ``configure_optimizers``).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        data_dim: int,
        cfg: OmegaConf
    ):
        """
        Args:
            in_channels: Channels of the input data.
            out_channels: Number of classes. This is the width of the final
                layer and the ``num_classes`` of the metrics.
            data_dim: Spatial dimensionality of the data (e.g. 1 or 2).
            cfg: Full experiment config; ``cfg.net``, ``cfg.kernel`` and
                ``cfg.train`` are read.
        """
        super().__init__()
        # Store the constructor arguments in checkpoints. Without this,
        # ``CCNN.load_from_checkpoint(path)`` cannot rebuild the model.
        self.save_hyperparameters()

        self.no_blocks = cfg.net.no_blocks
        hidden_channels = cfg.net.hidden_channels

        self.learning_rate = cfg.train.learning_rate
        self.warmup_epochs = cfg.train.warmup_epochs
        self.epochs = cfg.train.epochs
        self.start_factor = cfg.train.start_factor
        self.end_factor = cfg.train.end_factor

        # Input stem: separable flexible convolution (in_channels -> hidden_channels).
        self.sep_flex_conv_layer = SepFlexConv(
            data_dim=data_dim,
            in_channels=in_channels,
            net_cfg=cfg.net,
            kernel_cfg=cfg.kernel
        )
        # The batch-norm layers and the S4 blocks live in plain lists. They are
        # registered as submodules through ``seq_modules`` below. Switching to
        # nn.ModuleList would add extra state_dict keys and break strict
        # loading of existing checkpoints.
        self.batch_norm_layer = [
            GetBatchNormalization(data_dim=data_dim, num_features=hidden_channels),
            GetBatchNormalization(data_dim=data_dim, num_features=hidden_channels)
        ]
        self.gelu_layer = nn.GELU()
        self.blocks = []
        for _ in range(self.no_blocks):
            s4 = S4Block(in_channels=hidden_channels, out_channels=hidden_channels, data_dim=data_dim, net_cfg=cfg.net, kernel_cfg=cfg.kernel)
            self.blocks.append(s4)

        # Global average pooling compresses each channel into a single value.
        self.global_avg_pool_layer = GetAdaptiveAvgPool(data_dim=data_dim, output_size=(1,) * data_dim)
        # Pointwise linear classifier head (hidden_channels -> out_channels).
        self.pointwise_linear_layer = LinearLayer(data_dim, hidden_channels, out_channels)

        self.seq_modules = nn.Sequential(
            self.sep_flex_conv_layer,
            self.batch_norm_layer[0],
            self.gelu_layer,
            *self.blocks,
            self.batch_norm_layer[1],
            self.global_avg_pool_layer,
            self.pointwise_linear_layer
        )

        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=out_channels)
        self.f1_score = torchmetrics.F1Score(task="multiclass", num_classes=out_channels)

    def forward(self, x):
        """
        Return class scores with the batch dimension preserved.

        The pooled head output is expected to be [batch, out_channels, 1, ...].
        ``flatten(start_dim=1)`` removes only the trailing singleton spatial
        dims. The previous ``squeeze()`` also removed the batch dim for a batch
        of one, which broke ``cross_entropy`` and ``argmax(dim=1)``.
        """
        out = self.seq_modules(x)
        return out.flatten(start_dim=1)

    # Probably works only for sMNIST
    def training_step(self, batch, batch_idx):
        """Compute the training loss and log loss/accuracy/F1 per epoch."""
        loss, scores, y = self._common_step(batch, batch_idx)
        accuracy = self.accuracy(scores, y)
        f1_score = self.f1_score(scores, y)
        metrics_dict = {
            'train_loss': loss,
            'train_accuracy': accuracy,
            'train_f1_score': f1_score
        }
        self.log_dict(dictionary=metrics_dict, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Log ``val_loss`` and ``val_acc``.

        ``val_acc`` is the key monitored by the ModelCheckpoint callback built
        in ``setup_trainer_components``. It was previously logged as
        ``accuracy``, so the callback could not find its monitor key.
        """
        loss, scores, y = self._common_step(batch, batch_idx)
        self.log('val_loss', loss)
        self.log('val_acc', self.accuracy(scores, y))
        return loss

    def test_step(self, batch, batch_idx):
        """Log test ``loss`` and ``accuracy``."""
        loss, scores, y = self._common_step(batch, batch_idx)
        metrics_dict = {
            'loss': loss,
            'accuracy': self.accuracy(scores, y),
        }
        self.log_dict(metrics_dict)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the predicted class index for every sample in ``batch``."""
        x, _ = batch
        scores = self(x)
        return torch.argmax(scores, dim=1)

    def _common_step(self, batch, batch_idx):
        """Run the model on ``(x, y)`` and return ``(cross-entropy loss, scores, y)``."""
        x, y = batch
        scores = self.forward(x)
        loss = F.cross_entropy(scores, y)
        return loss, scores, y

    def configure_optimizers(self):
        """
        Return AdamW with a linear warm-up followed by cosine annealing.

        No ``interval`` is given, so Lightning steps the scheduler once per
        epoch. The warm-up lasts ``warmup_epochs`` and cosine annealing covers
        the remaining ``epochs - warmup_epochs``.
        """
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)

        # Linear warm-up from start_factor to end_factor of the base LR over
        # warmup_epochs scheduler steps.
        linear_warmup = optim.lr_scheduler.LinearLR(optimizer=optimizer, start_factor=self.start_factor, end_factor=self.end_factor, total_iters=self.warmup_epochs)

        cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(self.epochs - self.warmup_epochs))

        # Switch from warm-up to cosine annealing at epoch warmup_epochs.
        scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[linear_warmup, cosine_scheduler], milestones=[self.warmup_epochs])

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

# Execution

In [None]:
def create_model(cfg: OmegaConf) -> CCNN:
    """
    Instantiate a CCNN from the config.

    Reads ``cfg.net.in_channels``, ``cfg.net.out_channels`` and
    ``cfg.net.data_dim``; the full ``cfg`` is passed through to the model.
    """
    net_cfg = cfg.net
    return CCNN(
        in_channels=net_cfg.in_channels,
        out_channels=net_cfg.out_channels,
        data_dim=net_cfg.data_dim,
        cfg=cfg,
    )

def setup_trainer_components(cfg: OmegaConf):
    """
    Build the optional Lightning logger, callbacks and profiler.

    Each component is enabled by the truthiness of ``cfg.train.logger``,
    ``cfg.train.callbacks`` and ``cfg.train.profiler`` respectively.

    Returns:
        tuple: ``(logger, callbacks, profiler)``, where
            - ``logger`` is a ``TensorBoardLogger`` or ``None``;
            - ``callbacks`` is either empty or
              ``[ModelSummary, ModelCheckpoint, EarlyStopping]``;
            - ``profiler`` is a ``PyTorchProfiler`` or ``None``.
    """
    train_cfg = cfg.train

    logger = None
    if train_cfg.logger:
        run_name = f"{cfg.data.dataset}_{cfg.net.no_blocks}_{cfg.net.hidden_channels}"
        logger = TensorBoardLogger("tb_logs", name=run_name)

    callbacks = []
    if train_cfg.callbacks:
        # Keep only the best checkpoint by validation accuracy. "val_acc" must
        # be logged during validation, or Lightning cannot find the monitor key.
        checkpoint_callback = ModelCheckpoint(
            monitor="val_acc",
            dirpath="checkpoints",
            save_top_k=1,
            mode="max"
        )
        early_stop_callback = EarlyStopping(monitor="val_loss")
        model_summary_callback = ModelSummary(max_depth=-1)
        callbacks.extend([model_summary_callback, checkpoint_callback, early_stop_callback])

    profiler = None
    if train_cfg.profiler:
        # PyTorchProfiler takes the output file name as ``filename``. The former
        # ``output_filename`` keyword is not a constructor argument; it landed in
        # ``**profiler_kwargs`` and was never used as the report file name.
        profiler = PyTorchProfiler(
            filename="profiler_output",
            group_by_input_shapes=True,
        )

    return logger, callbacks, profiler

def create_trainer(cfg: OmegaConf, logger: TensorBoardLogger, callbacks: list, profiler: PyTorchProfiler) -> pl.Trainer:
    """
    Create a Lightning ``Trainer`` from ``cfg.train`` and the prepared components.

    ``cfg.train.accelerator``, ``cfg.train.devices`` and ``cfg.train.epochs``
    (used as ``max_epochs``) configure the run. ``logger``, ``callbacks`` and
    ``profiler`` are passed through unchanged and may be ``None`` or empty.
    """
    train_cfg = cfg.train
    trainer_kwargs = dict(
        logger=logger,
        accelerator=train_cfg.accelerator,
        devices=train_cfg.devices,
        max_epochs=train_cfg.epochs,
        callbacks=callbacks,
        profiler=profiler,
    )
    return pl.Trainer(**trainer_kwargs)

def train_and_evaluate(trainer: pl.Trainer, model: CCNN, datamodule, callbacks: list) -> None:
    """
    Fit ``model`` on ``datamodule``, then run a validation and a test pass.

    If ``callbacks`` contains a ``ModelCheckpoint``, the path of the best
    checkpoint it saved is printed afterwards. ``callbacks`` may be empty,
    e.g. when callbacks are disabled in the config.
    """
    trainer.fit(model, datamodule)
    trainer.validate(model, datamodule)
    trainer.test(model, datamodule)

    # ``next`` without a default raised StopIteration when no checkpoint
    # callback was configured, crashing after an otherwise successful run.
    checkpoint_callback = next(
        (cb for cb in callbacks if isinstance(cb, ModelCheckpoint)), None
    )
    if checkpoint_callback is not None:
        print("Finished training, best model path: ", checkpoint_callback.best_model_path)
    else:
        print("Finished training (no ModelCheckpoint callback configured).")

def load_and_predict(trainer: pl.Trainer, model: CCNN, datamodule, path: str) -> Optional[list]:
    """
    Restore a model from the checkpoint at ``path`` and run ``trainer.predict``.

    Args:
        trainer: Lightning trainer used for prediction.
        model: A model instance. Only its class is used, because
            ``load_from_checkpoint`` is a classmethod that builds a new model.
            Lightning 2.x raises if it is called on an instance.
        datamodule: Data module providing the predict dataloader.
        path: Path (``str`` or ``os.PathLike``) of a ``.ckpt`` file. The
            checkpoint must store the constructor hyperparameters; otherwise
            they would have to be passed to ``load_from_checkpoint``.

    Returns:
        Whatever ``trainer.predict`` returns (previously discarded).

    Raises:
        TypeError: If ``path`` is not a string or path-like object.
    """
    if not isinstance(path, (str, os.PathLike)):
        raise TypeError(
            f"path must be a checkpoint file path, got {type(path).__name__}"
        )
    restored = type(model).load_from_checkpoint(path)
    return trainer.predict(restored, datamodule)

# Load the experiment configuration (relative to the current working directory).
cfg = OmegaConf.load("../config/config.yaml")

# 1. Create the dataset
datamodule = get_data_module(cfg)
# 2. Create the model
model = create_model(cfg)
# 3. Create the logger, callbacks, profiler and trainer
logger, callbacks, profiler = setup_trainer_components(cfg)
trainer = create_trainer(cfg, logger, callbacks, profiler)

# 4. Train the model or use a pretrained one (selected by cfg.pre_trained)
if not cfg.pre_trained:
    train_and_evaluate(trainer, model, datamodule, callbacks)
else:
    # NOTE(review): load_and_predict expects a checkpoint *path* as its last
    # argument, but the callbacks list is passed here, so this branch cannot
    # load a checkpoint as written. TODO: pass the .ckpt path (e.g. from a
    # config entry) instead.
    load_and_predict(trainer, model, datamodule, callbacks)