# Model experiments for Animals classification

In [1]:
import json
import os
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.io import ImageReadMode, decode_image
from tqdm import tqdm

## Dataset

In [2]:
class EnglishLabel(str, Enum):
    """English names of the animal classes; each member's value is the lowercase name."""
    # Members also behave as plain strings because of the `str` mixin.
    DOG = "dog"
    HORSE = "horse"
    ELEPHANT = "elephant"
    BUTTERFLY = "butterfly"
    CHICKEN = "chicken"
    CAT = "cat"
    COW = "cow"
    SHEEP = "sheep"
    SPIDER = "spider"
    SQUIRREL = "squirrel"


class ItalianLabel(str, Enum):
    """Italian names of the animal classes.

    `build_dataframes` uses each member's value as the name of the
    sub-directory that holds the images of that class.
    """
    CANE = "cane"
    CAVALLO = "cavallo"
    ELEFANTE = "elefante"
    FARFALLA = "farfalla"
    GALLINA = "gallina"
    GATTO = "gatto"
    MUCCA = "mucca"
    PECORA = "pecora"
    RAGNO = "ragno"
    SCOIATTOLO = "scoiattolo"


def translate_labels(labels: list[EnglishLabel]) -> list[ItalianLabel]:
    """Translate English animal labels into their Italian counterparts.

    Parameters
    ----------
    labels : list[EnglishLabel]
        Labels to translate.

    Returns
    -------
    list[ItalianLabel]
        The Italian label for each input label, in the same order.

    Raises
    ------
    KeyError
        If a label has no entry in the translation table.
    """
    english_to_italian = dict(
        (
            (EnglishLabel.DOG, ItalianLabel.CANE),
            (EnglishLabel.HORSE, ItalianLabel.CAVALLO),
            (EnglishLabel.ELEPHANT, ItalianLabel.ELEFANTE),
            (EnglishLabel.BUTTERFLY, ItalianLabel.FARFALLA),
            (EnglishLabel.CHICKEN, ItalianLabel.GALLINA),
            (EnglishLabel.CAT, ItalianLabel.GATTO),
            (EnglishLabel.COW, ItalianLabel.MUCCA),
            (EnglishLabel.SHEEP, ItalianLabel.PECORA),
            (EnglishLabel.SPIDER, ItalianLabel.RAGNO),
            (EnglishLabel.SQUIRREL, ItalianLabel.SCOIATTOLO),
        )
    )
    translated = []
    for english in labels:
        translated.append(english_to_italian[english])
    return translated


def build_dataframes(
    data_dir: str | Path,
    english_labels: list[EnglishLabel],
    test_size: float = 0.2,
    random_state: int | None = 42,
) -> tuple[pd.DataFrame, pd.DataFrame, int]:
    """Index the image files of the selected classes and split them into train/test frames.

    Each class is expected to live in `data_dir/<italian label value>/`. Every
    regular file found directly inside such a directory becomes one row.

    Parameters
    ----------
    data_dir : str | Path
        Root directory containing one sub-directory per (Italian) class name.
    english_labels : list[EnglishLabel]
        Classes to include. The position of a label in this list is its `label_idx`.
    test_size : float
        Forwarded to `train_test_split`.
    random_state : int | None
        Seed forwarded to `train_test_split` so the split is reproducible across
        kernel restarts. Pass `None` to get a different split on every call.

    Returns
    -------
    train_df, test_df : pd.DataFrame
        Frames with columns `italian_label`, `english_label`, `label_idx` and `path`.
    num_labels : int
        Number of selected classes.

    Raises
    ------
    ValueError
        If `data_dir` is not an existing directory, or if no files were found for
        the requested labels.
    """
    data_dir = Path(data_dir)
    if not data_dir.is_dir():
        raise ValueError(f"Invalid data directory: {data_dir}")

    italian_labels = translate_labels(english_labels)
    num_labels = len(italian_labels)

    records = []
    for label_idx, (english_label, italian_label) in enumerate(
        zip(english_labels, italian_labels)
    ):
        label_dir = data_dir / italian_label.value
        # Directory listing order is filesystem dependent; sort it so that a fixed
        # random_state always yields the same split.
        for fpath in sorted(label_dir.glob("*")):
            if fpath.is_file():
                records.append(
                    {
                        "italian_label": italian_label,
                        "english_label": english_label,
                        "label_idx": label_idx,
                        "path": str(fpath),
                    }
                )

    if not records:
        # Fail here with a clear message instead of deep inside train_test_split.
        raise ValueError(
            f"No image files found under {data_dir} for labels: "
            f"{[label.value for label in italian_labels]}"
        )
    data_df = pd.DataFrame(records)

    train_df, test_df = train_test_split(
        data_df, test_size=test_size, random_state=random_state
    )
    return train_df, test_df, num_labels


class AnimalsDataset(Dataset):
    def __init__(
        self, data_df: pd.DataFrame, num_labels: int, transform: Any | None = None
    ):
        self.data_df = data_df
        self.num_labels = num_labels
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data_df)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        entry = self.data_df.iloc[idx]
        image = decode_image(entry["path"])
        if self.transform is not None:
            image = self.transform(image)

        labels = np.zeros(self.num_labels, dtype=int)
        labels[entry["label_idx"]] = 1
        return image, torch.tensor(labels)

In [3]:
# Index the dog and horse images under ../data/raw-img and hold out 20% of them for testing.
# num_labels is the number of selected classes (2 here).
train_df, test_df, num_labels = build_dataframes(
    "../data/raw-img", [EnglishLabel.DOG, EnglishLabel.HORSE], test_size=0.2
)

### SENN

#### Conceptizers

In [4]:
class Conceptizer(nn.Module):
    def __init__(self):
        """
        Base class for conceptizers (autoencoders that extract concepts).

        Subclasses are expected to provide their own encode() and decode().
        The encoder and decoder containers start out empty.
        """
        super().__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

    def forward(self, x):
        """
        Encode the input into concepts and reconstruct the input from them.

        Parameters
        ----------
        x : torch.Tensor
            Input batch of shape (BATCH, *); only the leading batch dimension is assumed.

        Returns
        -------
        encoded : torch.Tensor
            Output of encode(x).
        decoded : torch.Tensor
            Output of decode(encoded), viewed with the same shape as x.
        """
        concepts = self.encode(x)
        reconstruction = self.decode(concepts).view_as(x)
        return concepts, reconstruction

    @abstractmethod
    def encode(self, x):
        """
        Map an input batch to its concept representation (to be overridden).

        Parameters
        ----------
        x : torch.Tensor
            Input batch of shape (BATCH, *).
        """

    @abstractmethod
    def decode(self, encoded):
        """
        Map a concept representation back to input space (to be overridden).

        Parameters
        ----------
        encoded : torch.Tensor
            Latent representation produced by encode().
        """


class ConvConceptizer(Conceptizer):
    def __init__(self, image_size, num_concepts, concept_dim, image_channels=1, encoder_channels=(10,),
                 decoder_channels=(16, 8), kernel_size_conv=5, kernel_size_upsample=(5, 5, 2),
                 stride_conv=1, stride_pool=2, stride_upsample=(2, 1, 2),
                 padding_conv=0, padding_upsample=(0, 0, 1), filter=True, **kwargs):
        """
        CNN autoencoder used to learn the concepts present in an input image.

        Parameters
        ----------
        image_size : int
            the width of the (square) input image
        num_concepts : int
            the number of concepts
        concept_dim : int
            the dimension of each concept to be learned
        image_channels : int
            the number of channels of the input images
        encoder_channels : tuple[int]
            the number of channels for the hidden convolutional layers
        decoder_channels : tuple[int]
            the number of channels for the hidden upsampling layers
        kernel_size_conv : int, tuple[int]
            the size of the kernels to be used for convolution
        kernel_size_upsample : int, tuple[int]
            the size of the kernels to be used for upsampling
        stride_conv : int, tuple[int]
            the stride of the convolutional layers
        stride_pool : int, tuple[int]
            the stride of the pooling layers
        stride_upsample : int, tuple[int]
            the stride of the upsampling layers
        padding_conv : int, tuple[int]
            the padding to be used by the convolutional layers
        padding_upsample : int, tuple[int]
            the padding to be used by the upsampling layers
        filter : bool
            if truthy and ``concept_dim == 1``, every concept feature map is reduced to a
            scalar by its own linear layer (ScalarMapping); otherwise the feature maps are
            flattened and passed through one shared linear layer. Defaults to True, which
            matches the behaviour callers got before this flag was a real parameter
            (the attribute used to hold the always-truthy builtin ``filter``).
        """
        super().__init__()
        self.num_concepts = num_concepts
        self.filter = filter
        # Spatial size of the feature maps; updated while the encoder is assembled.
        self.dout = image_size

        # Encoder params: one entry per conv block (hidden blocks + the final concept block).
        encoder_channels = (image_channels,) + encoder_channels
        kernel_size_conv = handle_integer_input(kernel_size_conv, len(encoder_channels))
        stride_conv = handle_integer_input(stride_conv, len(encoder_channels))
        stride_pool = handle_integer_input(stride_pool, len(encoder_channels))
        padding_conv = handle_integer_input(padding_conv, len(encoder_channels))
        encoder_channels += (num_concepts,)

        # Decoder params: one entry per upsampling block (hidden blocks + the final image block).
        decoder_channels = (num_concepts,) + decoder_channels
        kernel_size_upsample = handle_integer_input(kernel_size_upsample, len(decoder_channels))
        stride_upsample = handle_integer_input(stride_upsample, len(decoder_channels))
        padding_upsample = handle_integer_input(padding_upsample, len(decoder_channels))
        decoder_channels += (image_channels,)

        # Encoder implementation
        self.encoder = nn.ModuleList()
        for i in range(len(encoder_channels) - 1):
            self.encoder.append(self.conv_block(in_channels=encoder_channels[i],
                                                out_channels=encoder_channels[i + 1],
                                                kernel_size=kernel_size_conv[i],
                                                stride_conv=stride_conv[i],
                                                stride_pool=stride_pool[i],
                                                padding=padding_conv[i]))
            # Track the spatial size after convolution + pooling.
            # NOTE(review): this bookkeeping does not include the padding that conv_block
            # also hands to the pooling layer, so it looks exact only for padding_conv == 0
            # (the default) - confirm before using a non-zero padding.
            self.dout = (self.dout - kernel_size_conv[i] + 2 * padding_conv[i] + stride_conv[i] * stride_pool[i]) // (
                    stride_conv[i] * stride_pool[i])

        if self.filter and concept_dim == 1:
            self.encoder.append(ScalarMapping((self.num_concepts, self.dout, self.dout)))
        else:
            self.encoder.append(Flatten())
            self.encoder.append(nn.Linear(self.dout ** 2, concept_dim))

        # Decoder implementation: a linear layer expands each concept back to a
        # dout x dout map, then the transposed convolutions upsample to image size.
        self.unlinear = nn.Linear(concept_dim, self.dout ** 2)
        decoder = []
        for i in range(len(decoder_channels) - 1):
            decoder.append(self.upsample_block(in_channels=decoder_channels[i],
                                               out_channels=decoder_channels[i + 1],
                                               kernel_size=kernel_size_upsample[i],
                                               stride_deconv=stride_upsample[i],
                                               padding=padding_upsample[i]))
            decoder.append(nn.ReLU(inplace=True))
        # The last block is followed by Tanh instead of ReLU.
        decoder.pop()
        decoder.append(nn.Tanh())
        self.decoder = nn.ModuleList(decoder)

    def encode(self, x):
        """
        The encoder part of the autoencoder which takes an image as input
        and computes its hidden representation (concepts).

        Parameters
        ----------
        x : Image (batch_size, channels, width, height)

        Returns
        -------
        encoded : torch.Tensor (batch_size, concept_number, concept_dimension)
            the concepts representing an image
        """
        encoded = x
        for module in self.encoder:
            encoded = module(encoded)
        return encoded

    def decode(self, z):
        """
        The decoder part of the autoencoder which takes a hidden representation as input
        and tries to reconstruct the original image.

        Parameters
        ----------
        z : torch.Tensor (batch_size, concept_number, concept_dimension)
            the concepts of an image, as returned by encode()

        Returns
        -------
        reconst : torch.Tensor (batch_size, channels, width, height)
            the reconstructed image
        """
        reconst = self.unlinear(z)
        # One dout x dout map per concept, used as the channels of the first upsampling block.
        reconst = reconst.view(-1, self.num_concepts, self.dout, self.dout)
        for module in self.decoder:
            reconst = module(reconst)
        return reconst

    def conv_block(self, in_channels, out_channels, kernel_size, stride_conv, stride_pool, padding):
        """
        A helper function that constructs a convolution block with pooling and activation.

        Parameters
        ----------
        in_channels : int
            the number of input channels
        out_channels : int
            the number of output channels
        kernel_size : int
            the size of the convolutional kernel
        stride_conv : int
            the stride of the convolution
        stride_pool : int
            the kernel size of the average-pooling layer (its stride defaults to the same value)
        padding : int
            the size of padding, applied to both the convolution and the pooling layer

        Returns
        -------
        sequence : nn.Sequential
            a sequence of convolutional, pooling and activation modules
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      stride=stride_conv,
                      padding=padding),
            # nn.BatchNorm2d(out_channels),
            nn.AvgPool2d(kernel_size=stride_pool,
                         padding=padding),
            nn.ReLU(inplace=True)
        )

    def upsample_block(self, in_channels, out_channels, kernel_size, stride_deconv, padding):
        """
        A helper function that constructs an upsampling (transposed convolution) block.

        The activation is not part of this block; it is added by the caller.

        Parameters
        ----------
        in_channels : int
            the number of input channels
        out_channels : int
            the number of output channels
        kernel_size : int
            the size of the convolutional kernel
        stride_deconv : int
            the stride of the deconvolution
        padding : int
            the size of padding

        Returns
        -------
        sequence : nn.Sequential
            a sequence holding the transposed convolution
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=stride_deconv,
                               padding=padding),
        )


class Flatten(nn.Module):
    def forward(self, x):
        """
        Collapse every dimension after the second into a single trailing dimension.

        Parameters
        ----------
        x : torch.Tensor
            Input data tensor of shape (dim1, dim2, *).

        Returns
        -------
        flattened : torch.Tensor
            View of the input with shape (dim1, dim2, dim3).
        """
        first, second = x.size(0), x.size(1)
        return x.view(first, second, -1)


def handle_integer_input(input, desired_len):
    """
    Normalise a per-layer hyperparameter to a tuple of the desired length.

    If an integer is given, it is replicated ``desired_len`` times.
    If a tuple or list is given, its length is checked and it is returned as a tuple.

    Parameters
    ----------
    input : int, tuple, list
        Either a single parameter to be replicated or a sequence of parameters.
    desired_len : int
        The length of the desired tuple.

    Returns
    -------
    input : tuple[int]
        a tuple of parameters which has the proper length.

    Raises
    ------
    AssertionError
        If a sequence is given whose length differs from ``desired_len``.
    TypeError
        If the input is neither an int nor a tuple/list.
    """
    # Exact type check on purpose: bool is a subclass of int and is not a valid size.
    if type(input) is int:
        return (input,) * desired_len
    if isinstance(input, (tuple, list)):
        if len(input) != desired_len:
            raise AssertionError("The sizes of the parameters for the CNN conceptizer do not match. "
                                 f"Expected '{desired_len}', but got '{len(input)}'")
        return tuple(input)
    raise TypeError(f"Wrong type of the parameters. Expected tuple or int but got '{type(input)}'")


class ScalarMapping(nn.Module):
    def __init__(self, conv_block_size):
        """
        Module that reduces each filter of a convolutional block to one scalar,
        using a separate linear layer per filter.

        Parameters
        ----------
        conv_block_size : tuple (int iterable)
            Size of the input convolutional block: (NUM_CHANNELS, FILTER_HEIGHT, FILTER_WIDTH)
        """
        super().__init__()
        self.num_filters, self.filter_height, self.filter_width = conv_block_size

        pixels_per_filter = self.filter_height * self.filter_width
        self.layers = nn.ModuleList(
            [nn.Linear(pixels_per_filter, 1) for _ in range(self.num_filters)]
        )

    def forward(self, x):
        """
        Map every 2D filter of the block to a scalar with that filter's own linear layer.

        Parameters
        ----------
        x : torch.Tensor
            Input data tensor of shape (BATCH, CHANNELS, HEIGHT, WIDTH).

        Returns
        -------
        mapped : torch.Tensor
            Reduced input (BATCH, CHANNELS, 1)
        """
        flat = x.view(-1, self.num_filters, self.filter_height * self.filter_width)
        # Indexing with a one-element list keeps the channel axis, so each piece is (BATCH, 1, 1).
        per_filter = [layer(flat[:, [idx], :]) for idx, layer in enumerate(self.layers)]
        return torch.cat(per_filter, dim=1)


#### Parameterizers

In [5]:
class LinearParameterizer(nn.Module):
    def __init__(self, num_concepts, num_classes, hidden_sizes=(10, 5, 5, 10), dropout=0.5, **kwargs):
        """Fully connected parameterizer.

        A stack of Linear -> Dropout -> ReLU stages; the ReLU after the last
        stage is omitted.

        Parameters
        ----------
        num_concepts : int
            Number of concepts for which relevances are computed.
        num_classes : int
            Number of classes distinguished by the classifier.
        hidden_sizes : iterable of int
            Layer sizes. The first element is the number of input features; consecutive
            pairs define the linear layers.
        dropout : float
            Dropout probability.
        """
        super().__init__()
        self.num_concepts = num_concepts
        self.num_classes = num_classes
        self.hidden_sizes = hidden_sizes
        self.dropout = dropout

        stages = []
        for n_in, n_out in zip(hidden_sizes, hidden_sizes[1:]):
            stages += [nn.Linear(n_in, n_out), nn.Dropout(self.dropout), nn.ReLU()]
        # Drop the trailing activation (raises IndexError when there is no layer at all).
        del stages[-1]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Compute the relevance parameters theta.

        Parameters
        ----------
        x : torch.Tensor
            Input batch whose first dimension is the batch size.

        Returns
        -------
        parameters : torch.Tensor
            Relevance scores of shape (BATCH, NUM_CONCEPTS, NUM_CLASSES).
        """
        batch = x.size(0)
        theta = self.layers(x)
        return theta.view(batch, self.num_concepts, self.num_classes)


class ConvParameterizer(nn.Module):
    def __init__(self, num_concepts, num_classes, cl_sizes=(1, 10, 20), kernel_size=5, hidden_sizes=(10, 5, 5, 10), dropout=0.5,
                 **kwargs):
        """Convolutional parameterizer.

        A convolutional feature extractor followed by a fully connected head
        that ends in Tanh.

        Parameters
        ----------
        num_concepts : int
            Number of concepts for which relevances are computed.
        num_classes : int
            Number of classes distinguished by the classifier.
        cl_sizes : iterable of int
            Number of kernels of each convolutional layer. The first element is the
            number of input channels.
        kernel_size : int
            Kernel window size of the convolutional layers.
        hidden_sizes : iterable of int
            Sizes of the fully connected layers. The first element is the number of
            input features; the last must equal num_concepts * num_classes.
        dropout : float
            Dropout probability.
        """
        super().__init__()
        self.num_concepts = num_concepts
        self.num_classes = num_classes
        self.hidden_sizes = hidden_sizes
        self.cl_sizes = cl_sizes
        self.kernel_size = kernel_size
        self.dropout = dropout

        # Each conv stage is Conv2d -> 2x2 average pooling -> ReLU.
        # TODO: maybe adaptable parameters for pool kernel size and stride
        conv_stack = []
        for c_in, c_out in zip(cl_sizes, cl_sizes[1:]):
            conv_stack.extend((
                nn.Conv2d(c_in, c_out, kernel_size=self.kernel_size),
                nn.AvgPool2d(2, stride=2),
                nn.ReLU(),
            ))
        # Channel dropout goes right before the pooling layer of the last stage.
        conv_stack.insert(-2, nn.Dropout2d(self.dropout))
        self.cl_layers = nn.Sequential(*conv_stack)

        # Each dense stage is Linear -> Dropout -> ReLU; the final ReLU is swapped for Tanh.
        dense_stack = []
        for n_in, n_out in zip(hidden_sizes, hidden_sizes[1:]):
            dense_stack.extend((nn.Linear(n_in, n_out), nn.Dropout(self.dropout), nn.ReLU()))
        dense_stack[-1] = nn.Tanh()
        self.fc_layers = nn.Sequential(*dense_stack)

    def forward(self, x):
        """Compute the relevance parameters theta.

        Parameters
        ----------
        x : torch.Tensor
            Input batch whose first dimension is the batch size.

        Returns
        -------
        parameters : torch.Tensor
            Relevance scores of shape (BATCH, NUM_CONCEPTS, NUM_CLASSES).
        """
        features = self.cl_layers(x).view(x.size(0), -1)
        theta = self.fc_layers(features)
        return theta.view(-1, self.num_concepts, self.num_classes)


#### Aggregators

In [6]:
class SumAggregator(nn.Module):
    def __init__(self, num_classes, **kwargs):
        """Aggregator that combines concepts and relevances by summing their products."""
        super().__init__()
        self.num_classes = num_classes

    def forward(self, concepts, relevances):
        """Combine concepts and relevances into per-class log-probabilities.

        Parameters
        ----------
        concepts : torch.Tensor
            Conceptizer output of shape (BATCH, NUM_CONCEPTS, DIM_CONCEPT=1).
        relevances : torch.Tensor
            Parameterizer output of shape (BATCH, NUM_CONCEPTS, NUM_CLASSES).

        Returns
        -------
        class_predictions : torch.Tensor
            Log-softmax over classes. Shape - (BATCH, NUM_CLASSES)
        """
        # (BATCH, NUM_CLASSES, NUM_CONCEPTS) @ (BATCH, NUM_CONCEPTS, 1) -> (BATCH, NUM_CLASSES, 1)
        class_major = relevances.transpose(1, 2)
        scores = torch.bmm(class_major, concepts)
        return F.log_softmax(scores.squeeze(-1), dim=1)


#### SENN

In [7]:
class SENN(nn.Module):
    def __init__(self, conceptizer, parameterizer, aggregator):
        """Self Explaining Neural Network (SENN).
        (https://papers.nips.cc/paper/8003-towards-robust-interpretability-with-self-explaining-neural-networks)

        The model is assembled from three submodules:
            - conceptizer
                Encodes the raw input into interpretable feature representations
                (concepts) and also returns a reconstruction of the input.
            - parameterizer
                Computes the parameters theta from the input. Each concept has one
                associated theta, acting as a ``relevance score'' for that concept.
            - aggregator
                Defines the function g in g(theta_1 * h_1, ..., theta_n * h_n), i.e. how
                the concepts h_i and their relevance scores are combined into a prediction.

        Parameters
        ----------
        conceptizer : Pytorch Module
            Called as conceptizer(x); must return (concepts, reconstruction).

        parameterizer : Pytorch Module
            Called as parameterizer(x); must return the relevances.

        aggregator : Pytorch Module
            Called as aggregator(concepts, relevances); must return the predictions.
        """
        super().__init__()
        self.conceptizer = conceptizer
        self.parameterizer = parameterizer
        self.aggregator = aggregator

    def forward(self, x):
        """Run the three submodules on the input.

        Parameters
        ----------
        x : torch.Tensor
            Input batch whose first dimension is the batch size.

        Returns
        -------
        predictions : torch.Tensor
            Output of the aggregator.

        explanations : tuple
            The pair (concepts, relevances): the conceptizer's concept output and
            the parameterizer's output.

        recon_x : torch.Tensor
            Reconstruction of the input returned by the conceptizer.
        """
        h, x_reconstructed = self.conceptizer(x)
        theta = self.parameterizer(x)
        return self.aggregator(h, theta), (h, theta), x_reconstructed

#### Losses

In [8]:
def mnist_robustness_loss(x, aggregates, concepts, relevances):
    """Robustness loss of Alvarez-Melis & Jaakkola (2018).

    [https://papers.nips.cc/paper/8003-towards-robust-interpretability-with-self-explaining-neural-networks.pdf]
    Penalises the difference between the Jacobian of the aggregates w.r.t. the input
    and the Jacobian of the concepts w.r.t. the input multiplied by the relevances.
    The concept dimension is assumed to be 1.

    Parameters
    ----------
    x            : torch.tensor
                 Input batch; must be part of the autograd graph of both
                 `aggregates` and `concepts`.
    aggregates   : torch.tensor
                 Aggregates from SENN as (batch_size x num_classes [x 1])
    concepts     : torch.tensor
                 Concepts from Conceptizer as (batch_size x num_concepts x 1)
    relevances   : torch.tensor
                 Relevances from Parameterizer as (batch_size x num_concepts x num_classes)

    Returns
    -------
    robustness_loss  : torch.tensor
        Frobenius norm of the (batch_size x num_features x num_classes) difference.
    """
    n_samples = x.size(0)

    def _jacobian_wrt_input(outputs):
        """Stack d outputs[:, k] / d x for every k as (batch, num_features, num_outputs)."""
        n_outputs = outputs.size(1)
        columns = []
        for k in range(n_outputs):
            # Selecting a single output column for the whole batch gives one Jacobian column.
            selector = torch.zeros(n_samples, n_outputs).to(x.device)
            selector[:, k] = 1.
            grad_k = torch.autograd.grad(outputs=outputs, inputs=x,
                                         grad_outputs=selector, create_graph=True, only_inputs=True)[0]
            # flatten all input dimensions: (batch, *) -> (batch, num_features, 1)
            columns.append(grad_k.view(n_samples, -1).unsqueeze(-1))
        return torch.cat(columns, dim=2)

    # concept_dim is always 1, so drop the trailing axis
    h = concepts.squeeze(-1)
    y = aggregates.squeeze(-1)

    jac_y = _jacobian_wrt_input(y)  # batch x num_features x num_classes
    jac_h = _jacobian_wrt_input(h)  # batch x num_features x num_concepts

    # batch x num_features x num_classes
    residual = jac_y - torch.bmm(jac_h, relevances)
    return residual.norm(p='fro')


def BVAE_loss(x, x_hat, z_mean, z_logvar):
    """Compute the two components of the Beta-VAE loss of [1].

    The components are returned separately; weighting the KL term by beta and
    summing is left to the caller.

    Parameters
    ----------
    x : torch.tensor
        input data to the Beta-VAE

    x_hat : torch.tensor
        input data reconstructed by the Beta-VAE

    z_mean : torch.tensor
        mean of the latent distribution of shape
        (batch_size, latent_dim)

    z_logvar : torch.tensor
        diagonal log variance of the latent distribution of shape
        (batch_size, latent_dim)

    Returns
    -------
    recon_loss : torch.tensor
        mean squared error between x_hat and (detached) x
    kl_loss : torch.tensor
        KL divergence of the latent distribution, as computed by kl_div

    References
    ----------
        [1] Higgins, Irina, et al. "beta-vae: Learning basic visual concepts with
        a constrained variational framework." (2016).
    """
    reconstruction_term = F.mse_loss(x_hat, x.detach(), reduction="mean")
    divergence_term = kl_div(z_mean, z_logvar)
    return reconstruction_term, divergence_term

def mse_l1_sparsity(x, x_hat, concepts, sparsity_reg):
    """Reconstruction MSE plus an L1 sparsity penalty on the concepts.

    Parameters
    ----------
    x : torch.tensor
        Input data to the encoder (detached before use as the target).
    x_hat : torch.tensor
        Reconstructed input by the decoder.
    concepts : torch.Tensor
        Concept (latent code) activations.
    sparsity_reg : float
        Weight (xi) of the L1 term.

    Returns
    -------
    loss : torch.tensor
        mse(x_hat, x) + sparsity_reg * sum(|concepts|)
    """
    reconstruction_error = F.mse_loss(x_hat, x.detach())
    l1_penalty = torch.abs(concepts).sum()
    return reconstruction_error + sparsity_reg * l1_penalty


def kl_div(mean, logvar):
    """KL divergence between a diagonal normal distribution and the standard normal.

    Parameters
    ----------
    mean : torch.tensor
        mean of the normal distribution of shape (batch_size x latent_dim)

    logvar : torch.tensor
        diagonal log variance of the normal distribution of shape (batch_size x latent_dim)

    Returns
    -------
    loss : torch.tensor
        Closed-form KL divergence, averaged over the batch and summed over the
        remaining dimensions.
    """
    elementwise = mean.pow(2) + logvar.exp() - logvar - 1
    per_dimension = 0.5 * elementwise.mean(dim=0)
    return per_dimension.sum()


def zero_loss(*args, **kwargs):
    """Placeholder loss that ignores its arguments and is always zero.

    Parameters
    ----------
    args : list
        Any positional arguments (unused).
    kwargs : dict
        Any keyword arguments (unused).

    Returns
    -------
    loss : torch.tensor
        Integer scalar tensor holding 0.
    """
    return torch.zeros((), dtype=torch.int64)


#### Config

In [9]:
def get_device():
    """Pick the compute device: CUDA if available, otherwise MPS, otherwise the CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)


class Config:
    """Hyperparameters and runtime settings shared by the cells of this notebook."""
    device = get_device()  # resolved once, when this class body is executed
    batch_size = 128  # batch size of the train and test DataLoaders
    epochs = 10  # number of passes over the training data in train_model
    # NOTE(review): the optimizer in the model-setup cell is created with a literal
    # learning rate - confirm whether it is meant to read this value instead.
    lr = 1e-4
    image_size = 256  # images are resized to image_size x image_size
    num_classes = 2
    num_concepts = 5
    concept_dim = 1
    sparsity_reg = 1e-3  # weight of the L1 term inside mse_l1_sparsity
    robust_reg = 1.0  # weight of the robustness loss in the total training loss
    concept_reg = 1.0  # weight of the concept (reconstruction + sparsity) loss
    save_path = './results'  # directory that receives the per-epoch checkpoints


# Single configuration instance used by the cells below.
config = Config()

In [10]:
# Preprocessing applied to every decoded image: convert the tensor to a PIL image,
# resize it to a square of side config.image_size, then convert it back to a tensor.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
])

In [11]:
# Wrap the train/test dataframes in datasets that share the same preprocessing.
train_dataset = AnimalsDataset(
    train_df, num_labels, transform
)
test_dataset = AnimalsDataset(
    test_df, num_labels, transform
)

# Only the training data is shuffled; the test loader keeps the dataframe order.
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

In [12]:
# Convolutional stack of the parameterizer: number of input channels followed by the
# number of kernels of each convolutional layer, and the (square) kernel size.
parameterizer_channels = (3, 10, 20)
parameterizer_kernel_size = 5

# Every ConvParameterizer stage is an unpadded stride-1 convolution followed by 2x2
# average pooling with stride 2, so each stage maps a side length n to (n - k + 1) // 2.
# Deriving the flattened feature count keeps the model valid when config.image_size
# changes (it evaluates to 74420 for 256x256 inputs).
feature_map_size = config.image_size
for _ in parameterizer_channels[1:]:
    feature_map_size = (feature_map_size - parameterizer_kernel_size + 1) // 2
parameterizer_fc_inputs = parameterizer_channels[-1] * feature_map_size ** 2

conceptizer = ConvConceptizer(
    image_size=config.image_size,
    num_concepts=config.num_concepts,
    concept_dim=config.concept_dim,
    image_channels=parameterizer_channels[0],
)
parameterizer = ConvParameterizer(
    num_concepts=config.num_concepts,
    num_classes=config.num_classes,
    cl_sizes=parameterizer_channels,
    kernel_size=parameterizer_kernel_size,
    hidden_sizes=(parameterizer_fc_inputs, 512, config.num_concepts * config.num_classes),
)
aggregator = SumAggregator(num_classes=config.num_classes)

model = SENN(conceptizer, parameterizer, aggregator).to(config.device)

# NOTE(review): the learning rate is the literal 0.001, not config.lr (1e-4). It is kept
# as-is so the recorded training run stays comparable - confirm which value is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [13]:
@dataclass
class TrainHistory:
    """Container for the loss values collected during a training run."""
    # Recorded training losses, in the order they were observed.
    train_losses: list[float]

In [None]:
def train_model(model, config, train_loader, optim=None):
    """Train a SENN model and save a checkpoint after every epoch.

    The total loss of a batch is
    ``classification + config.concept_reg * concept + config.robust_reg * robustness``.

    Parameters
    ----------
    model : nn.Module
        Model returning ``(predictions, (concepts, relevances), reconstruction)``.
    config : Config
        Provides ``save_path``, ``epochs``, ``device``, ``sparsity_reg``,
        ``concept_reg`` and ``robust_reg``.
    train_loader : DataLoader
        Yields ``(images, one_hot_labels)`` batches.
    optim : torch.optim.Optimizer, optional
        Optimizer that updates ``model``. When omitted, the notebook-level
        ``optimizer`` variable is used, so existing three-argument calls keep working.

    Returns
    -------
    history : TrainHistory
        ``train_losses`` holds the average loss of every completed epoch.
    """
    # Resolve the optimizer once instead of silently reading a global inside the loop.
    active_optimizer = optimizer if optim is None else optim

    os.makedirs(config.save_path, exist_ok=True)
    model.train()
    epoch_losses = []
    # Guard against an empty loader so the epoch average never divides by zero.
    num_batches = max(len(train_loader), 1)
    for epoch in range(config.epochs):
        total_loss = 0.0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")
        for images, labels in loop:
            images = images.to(config.device)
            labels = labels.to(config.device)
            # The robustness loss differentiates the outputs with respect to the input.
            images.requires_grad_(True)

            outputs, (concepts, relevances), recons = model(images)

            # Labels arrive one-hot encoded; nll_loss expects class indices.
            targets = labels.argmax(dim=1)
            classification_loss = F.nll_loss(outputs, targets)
            concept_loss = mse_l1_sparsity(images, recons, concepts, config.sparsity_reg)
            robustness_loss = mnist_robustness_loss(images, outputs, concepts, relevances)

            loss = classification_loss + config.concept_reg * concept_loss + config.robust_reg * robustness_loss
            active_optimizer.zero_grad()
            loss.backward()
            active_optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            loop.set_postfix(loss=batch_loss)

        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1} completed. Avg Loss: {avg_loss:.4f}")
        torch.save(model.state_dict(), os.path.join(config.save_path, f"model_epoch_{epoch+1}.pth"))

    print("Training complete.")
    return TrainHistory(train_losses=epoch_losses)

In [15]:
# Train the SENN model on the training DataLoader using the shared configuration.
train_model(model, config, train_dataloader)

Epoch 1/10:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch 1/10: 100%|██████████| 47/47 [06:29<00:00,  8.28s/it, loss=1]    


Epoch 1 completed. Avg Loss: 1.0853


Epoch 2/10: 100%|██████████| 47/47 [12:22<00:00, 15.79s/it, loss=0.912]


Epoch 2 completed. Avg Loss: 0.9460


Epoch 3/10: 100%|██████████| 47/47 [06:30<00:00,  8.32s/it, loss=0.815]


Epoch 3 completed. Avg Loss: 0.8564


Epoch 4/10:  21%|██▏       | 10/47 [01:33<05:44,  9.31s/it, loss=0.801]


KeyboardInterrupt: 