# B-cos Networks: Alignment is All We Need for Interpretability

---
In this notebook, we present a new direction for increasing the interpretability of Deep Neural Networks (DNNs) by proposing to replace the linear transforms in DNNs with the **B-cos transform**.

The B-cos transform is designed to be compatible with existing architectures and
we show that it can easily be integrated into common models such as *VGGs*, *ResNets*, *InceptionNets*, and *DenseNets*, whilst maintaining similar performance.

The resulting explanations are of high visual quality and perform well under quantitative metrics for interpretability.

---

## Setup

In [1]:
import argparse
import numpy as np
import os
import pandas as pd
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint as cp
import torchvision
import warnings

from collections import namedtuple, OrderedDict
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score
from typing import Any, Callable, Dict, List, Optional, Type, Tuple, Union, cast
from torch import Tensor
from torchvision.transforms import transforms
from torchvision import utils
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader
from tqdm import tqdm

## Data Preparation

---
**Datasets**

We evaluate the accuracy of several B-cos networks on the *CIFAR-10* dataset.

---

In [2]:
class loadData:
    def __init__(self, args):

        self.batch_size = args.batch_size

        self.dataPath = args.cifar10Path
        self.create_paths(self.dataPath)

    @staticmethod
    def create_paths(path):
        if not os.path.exists(path):
            os.makedirs(path)

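    def getDataLoader(self):

        # CIFAR-10 channel means and standard deviations used for normalisation.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        # Data augmentation (random crops and flips) is applied to the training set only.
        train_transforms = transforms.Compose([
                                               transforms.RandomCrop(32, padding=4),
                                               transforms.RandomHorizontalFlip(),
                                               transforms.ToTensor(),
                                               normalize,
                                              ])

        # The validation set is only converted and normalised, so that evaluation is deterministic.
        val_transforms = transforms.Compose([transforms.ToTensor(), normalize])

        train_loader = DataLoader(torchvision.datasets.CIFAR10(self.dataPath, train=True, download=True, transform=train_transforms), batch_size=self.batch_size, shuffle=True)

        val_loader = DataLoader(torchvision.datasets.CIFAR10(self.dataPath, train=False, download=True, transform=val_transforms), batch_size=self.batch_size, shuffle=False)

        return train_loader, val_loader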

## Model

---
**The B-cos transform**

Typically, the individual neurons in a DNN compute the dot product between their weights **w** and an input **x**:

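$$f(\mathbf{x}; \mathbf{w}) = \mathbf{w}^\top \mathbf{x} = \|\mathbf{w}\| \, \|\mathbf{x}\| \, c(\mathbf{x}, \mathbf{w}) \quad \text{with} \quad c(\mathbf{x}, \mathbf{w}) = \cos(\angle(\mathbf{x}, \mathbf{w})).$$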

Here, `∠(x, w)` returns the angle between the vectors **x** and **w**.

In this work, we seek to improve the interpretability of DNNs by promoting weight-input alignment during optimisation.

To achieve this, we propose the ***B-cos transform***:

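$$\text{B-cos}(\mathbf{x}; \mathbf{w}) = \|\hat{\mathbf{w}}\| \, \|\mathbf{x}\| \, |c(\mathbf{x}, \hat{\mathbf{w}})|^{B} \times \operatorname{sgn}(c(\mathbf{x}, \hat{\mathbf{w}})).$$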

Here, *B* is a hyperparameter, the hat operator scales **w** to unit norm (so that ||**ŵ**|| = 1), and `sgn` denotes the *sign* function.

Note that this introduces only minor changes with respect to the first equation: since |c| × sgn(c) = c, for *B* = 1 the B-cos transform is equivalent to a linear transform with **ŵ**.

These changes maintain an important property of the linear transform: just as a sequence of linear transforms can be summarised by a single linear transform, a sequence of B-cos transforms can still be faithfully summarised by a single linear transform, which now depends on the input.
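To make the equation concrete, here is a minimal single-neuron sketch in NumPy (the helper `b_cos` is illustrative, not part of the notebook). It checks the *B* = 1 case against the linear transform with **ŵ**, and shows how larger *B* damps the output for inputs that are poorly aligned with **w** while leaving perfectly aligned inputs unchanged.

```python
import numpy as np


def b_cos(x: np.ndarray, w: np.ndarray, B: float) -> float:
    """B-cos transform of a single neuron, following the equation above term by term."""
    w_hat = w / np.linalg.norm(w)                 # hat operator: scale w to unit norm
    c = w_hat @ x / np.linalg.norm(x)             # c(x, w_hat), the cosine similarity
    return float(np.linalg.norm(w_hat) * np.linalg.norm(x) * abs(c) ** B * np.sign(c))


rng = np.random.default_rng(0)
w = rng.normal(size=16)
x = rng.normal(size=16)
w_hat = w / np.linalg.norm(w)

# B = 1: identical to the linear transform with the unit-norm weights w_hat.
linear_out = float(w_hat @ x)

# Two inputs of norm 3: one perfectly aligned with w, one at cos = 0.5.
u = x - (w_hat @ x) * w_hat                       # component of x orthogonal to w_hat
u /= np.linalg.norm(u)
x_aligned = 3.0 * w_hat
x_half = 3.0 * (0.5 * w_hat + np.sqrt(0.75) * u)

for B in (1, 2, 3):
    # The aligned input keeps its output of 3; the misaligned one is damped to 3 * 0.5**B.
    print(B, round(b_cos(x_aligned, w, B), 4), round(b_cos(x_half, w, B), 4))
```

Since only the cosine is raised to the power *B*, the damping depends on the angle between **x** and **w** alone, not on the input norm.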

---

In [3]:
class NormedConv2d(nn.Conv2d):

    def forward(self, x):
        weight_shape = self.weight.shape

        w_hat = self.weight.view(weight_shape[0], -1)
        w_hat = w_hat/(w_hat.norm(p=2, dim=1, keepdim=True))
        w_hat = w_hat.view(weight_shape)

        return F.conv2d(x, w_hat, self.bias, self.stride, self.padding, self.dilation, self.groups)


class BcosConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, max_out=2, b=2, scale=None, scale_fact=100):
        super().__init__()

        self.NormedConv2d = NormedConv2d(in_channels, out_channels * max_out, kernel_size, stride, padding, dilation=1, groups=1, bias=False)

        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.max_out = max_out
        self.b = b

        if scale is None:
            kernel_size_scale = kernel_size if not isinstance(kernel_size, tuple) else np.sqrt(np.prod(kernel_size))
            self.scale = (kernel_size_scale * np.sqrt(in_channels)) / scale_fact
        else:
            self.scale = scale

        self.detach = False

        self.kernel_size_power = kernel_size**2 if not isinstance(kernel_size, tuple) else np.prod(kernel_size)

    def explanation_mode(self, detach=True):
        self.detach = detach

    def forward(self, x):

        out_normed_conv2d = self.NormedConv2d(x)
        batch_size, _, h, w = out_normed_conv2d.shape

        # MaxOut computation.
        if self.max_out > 1:
            out_normed_conv2d = out_normed_conv2d.view(batch_size, -1, self.max_out, h, w)
            out_normed_conv2d = out_normed_conv2d.max(dim=2, keepdim=False)[0]

        # If B=1, no further calculation necessary.
        if self.b == 1:
            return out_normed_conv2d / self.scale

        # Calculating the norm of input patches.
        norm = (F.avg_pool2d((x**2).sum(1, keepdim=True), self.kernel_size, padding=self.padding, stride=self.stride) * self.kernel_size_power + 1e-6).sqrt_()

        # Get absolute value of cos.
        abs_cos = (out_normed_conv2d/norm).abs() + 1e-6

        # In explanation mode, the factor |cos| is detached, i.e., treated as a constant in the backward pass.
        if self.detach:
            abs_cos = abs_cos.detach()

        # Additional factor of cos^(b-1).
        out_normed_conv2d = out_normed_conv2d * abs_cos.pow(self.b-1)

        return out_normed_conv2d / self.scale
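The least obvious line in `BcosConv2d.forward` is the patch-norm computation. The following standalone sketch (arbitrary shapes, not part of the notebook's models) checks that the average-pooling trick matches an explicit patch extraction with `F.unfold`, and that dividing the output of a unit-norm convolution by this norm gives the cosine between each filter and each input patch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C_in, C_out, H, W = 2, 3, 4, 8, 8
k, s, p = 3, 1, 1                                  # kernel size, stride, padding
x = torch.randn(N, C_in, H, W)

# Patch norms as in BcosConv2d.forward: average-pool the per-pixel squared channel sum
# (zero padding is counted, since count_include_pad=True by default), then multiply by
# the number of positions in the k x k window to turn the mean back into a sum.
norm_pool = (F.avg_pool2d((x ** 2).sum(1, keepdim=True), k, padding=p, stride=s) * k ** 2 + 1e-6).sqrt()

# Reference: extract every k x k x C_in patch explicitly and take its L2 norm.
patches = F.unfold(x, kernel_size=k, padding=p, stride=s)          # (N, C_in*k*k, H*W)
norm_unfold = patches.norm(dim=1).view(N, 1, H, W)

# With unit-norm filters, conv output / patch norm is the cosine between filter and patch.
w = torch.randn(C_out, C_in, k, k)
w_flat = w.view(C_out, -1)
w_hat = (w_flat / w_flat.norm(dim=1, keepdim=True)).view_as(w)
cos = F.conv2d(x, w_hat, stride=s, padding=p) / norm_pool           # (N, C_out, H, W)

# The same cosine computed from the explicit patches.
cos_ref = ((w_hat.view(C_out, -1) @ patches) / patches.norm(dim=1, keepdim=True)).view(N, C_out, H, W)
```

The pooling version avoids materialising the `C_in*k*k`-sized patch tensor, which is why it is the cheaper choice inside every layer.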

---
**B-cos networks**

The B-cos transform is designed as a *drop-in* replacement for the linear transform, i.e., it can be used in exactly the same way.

For example, a conventional fully connected multi-layer neural network f(**x**; θ) of L layers can be represented as:

$$f(\mathbf{x}; \theta) = l_L \circ l_{L-1} \circ \dots \circ l_2 \circ l_1(\mathbf{x}),$$

with lⱼ denoting layer j with parameters **w**ᵏⱼ for neuron k in layer j, and θ the collection of all model parameters.

In such a model, each layer lⱼ typically computes:

$$l_j(\mathbf{a}_j; \mathbf{W}_j) = \phi(\mathbf{W}_j \mathbf{a}_j),$$

with aⱼ the input to layer j, φ a non-linear activation function (e.g., ReLU), and the row k of Wⱼ given by the weight vector **w**ᵏⱼ of the k-th neuron in that layer.

A corresponding **B-cos network** f with layers lⱼ can be formulated in exactly the same way, with the only difference being that every dot product (here between rows of Wⱼ and the input aⱼ) is replaced by the B-cos transform.

In matrix form, this equates to:

$$l_j(\mathbf{a}_j; \mathbf{W}_j) = |c(\mathbf{a}_j; \widehat{\mathbf{W}}_j)|^{B-1} \times (\widehat{\mathbf{W}}_j \mathbf{a}_j).$$

Here, the power, absolute value, and `×` operators are applied element-wise, `c(aⱼ; Ŵⱼ)` computes the cosine similarity between input aⱼ and the rows of Ŵⱼ, and the hat operator scales the rows of Ŵⱼ to unit norm.

Finally, note that for *B* > 1 the layer transform lⱼ is non-linear.
As a result, a non-linearity φ is not required for a B-cos network to model non-linear relationships.
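The claim that a B-cos network can be summarised by a single, input-dependent linear transform can be checked numerically. Below is a minimal sketch of a two-layer fully connected B-cos network in the matrix form above, assuming *B* = 2, no MaxOut and arbitrary layer sizes; `bcos_layer` and `net` are hypothetical helpers. If the factors |c|^(B-1) are held constant (as `BcosConv2d` does in `explanation_mode`), the output equals W(**x**) **x**, and the rows of W(**x**) can be read off with autograd.

```python
import torch


def bcos_layer(a, W, B=2.0, detach=False):
    """Hypothetical helper: matrix-form B-cos layer |c(a; W_hat)|^(B-1) x (W_hat a)."""
    W_hat = W / W.norm(dim=1, keepdim=True)     # scale every row to unit norm
    lin = W_hat @ a                             # W_hat a
    cos = lin / a.norm()                        # cosine between a and every row of W_hat
    factor = cos.abs().pow(B - 1)
    if detach:                                  # treat the factor as a constant, like explanation_mode
        factor = factor.detach()
    return factor * lin


torch.manual_seed(0)
W1, W2 = torch.randn(8, 5), torch.randn(3, 8)   # two fully connected B-cos layers, no MaxOut


def net(x, detach=False):
    return bcos_layer(bcos_layer(x, W1, detach=detach), W2, detach=detach)


x = torch.randn(5, requires_grad=True)
out = net(x, detach=True)

# Row i of W(x) is the gradient of output i with respect to x, with the factors held constant.
W_x = torch.stack([torch.autograd.grad(out[i], x, retain_graph=True)[0] for i in range(3)])
reconstructed = W_x @ x                          # equals out: f(x) = W(x) x
```

Detaching changes only the backward pass: the forward outputs of `net(x)` and `net(x, detach=True)` are identical.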

---
---

**MaxOut to increase modelling capacity**

As discussed, a deep B-cos network with *B* > 1 does not require a non-linearity between subsequent layers to model non-linear relationships.
This does not, of course, mean that it cannot benefit from one.

In this work, we specifically explore combining the B-cos transform with the MaxOut operation.
In particular, we model every neuron in a B-cos network as two B-cos transforms, of which only the maximal activation is forwarded:

$$\operatorname{MaxOut}(\mathbf{x}) = \max_{i \in \{1, 2\}} \text{B-cos}(\mathbf{x}; \mathbf{w}_i).$$

We noticed that networks with the MaxOut operation were much easier to optimise than networks using the ReLU operation.
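`BcosConv2d` implements MaxOut by computing `out_channels * max_out` linear maps and taking the maximum over groups of *consecutive* channels *before* applying the |cos|^(B-1) factor. Because both units of a neuron see the same input patch and ||**x**|| |c|^B sgn(c) is increasing in c, this should select the same unit as taking the maximum over the full B-cos outputs, as in the equation. A standalone sketch checking both points, with arbitrary shapes and *B* = 2 (variable names are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, M, H, W, B = 2, 4, 2, 5, 5, 2      # batch, neurons, MaxOut units per neuron, height, width, B

x = torch.randn(N, 3, H, W)
w = torch.randn(C * M, 3, 3, 3)
w_hat = w / w.view(C * M, -1).norm(dim=1).view(-1, 1, 1, 1)    # unit-norm filters, as in NormedConv2d

lin = F.conv2d(x, w_hat, padding=1)                             # (N, C*M, H, W): one map per unit
patch_norm = F.unfold(x, 3, padding=1).norm(dim=1).view(N, 1, H, W)


def apply_cos_factor(lin):
    """Multiply the linear output by |cos|^(B-1), as at the end of BcosConv2d.forward."""
    return lin * (lin / patch_norm).abs().pow(B - 1)


# Order used in BcosConv2d.forward: MaxOut over consecutive channel groups, then the factor.
lin_max = lin.view(N, C, M, H, W).max(dim=2)[0]                 # (N, C, H, W)
maxout_first = apply_cos_factor(lin_max)

# Order in the MaxOut equation: full B-cos output of every unit, then the max over the M units.
bcos_first = apply_cos_factor(lin).view(N, C, M, H, W).max(dim=2)[0]
```

With `view(N, C, M, H, W)`, neuron k groups output channels k·M, ..., k·M + M − 1 of the underlying convolution.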

---

---
**Advanced B-cos networks**

To test the generality of this approach, we evaluate how integrating the B-cos transform into commonly used DNN architectures affects their classification performance and interpretability.

In order to "convert" such models to B-cos networks we proceed as follows:

*   First, every convolutional kernel / fully connected layer is replaced by the corresponding B-cos version with two MaxOut units.

*   Second, any other non-linearities (e.g., ReLU or MaxPool), as well as any batch norm layers, are removed to maintain the alignment pressure and to ensure that the model can be summarised via a single linear transform.
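The two conversion steps can be sketched as a small recursive rewrite of an existing PyTorch model. This is a hypothetical helper, not code from the B-cos authors: `convert_to_bcos` and `StandInBcosConv` are illustrative names, in practice the factory would build a `BcosConv2d`, and replacing MaxPool by an average pooling with the same window (rather than dropping the downsampling altogether) is an assumption that mirrors the `AvgPool2d` layers in the ResNet and DenseNet code of this notebook.

```python
import torch
import torch.nn as nn


def convert_to_bcos(module: nn.Module, make_bcos_conv) -> nn.Module:
    """Rewrite `module` in place: swap convolutions, drop ReLU/BatchNorm, make pooling linear.

    `make_bcos_conv` is a hypothetical factory mapping an nn.Conv2d to its B-cos
    replacement (e.g. a BcosConv2d with matching channels, kernel size, stride, padding).
    """
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Conv2d):
            setattr(module, name, make_bcos_conv(child))
        elif isinstance(child, (nn.ReLU, nn.BatchNorm2d)):
            setattr(module, name, nn.Identity())
        elif isinstance(child, nn.MaxPool2d):
            # Assumption: keep the downsampling, but with the linear average pooling.
            setattr(module, name, nn.AvgPool2d(child.kernel_size, child.stride,
                                               child.padding, ceil_mode=child.ceil_mode))
        else:
            convert_to_bcos(child, make_bcos_conv)  # recurse into containers and blocks
    return module


class StandInBcosConv(nn.Module):
    """Hypothetical placeholder for BcosConv2d so that this sketch runs on its own."""

    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        self.conv = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                              stride=conv.stride, padding=conv.padding, bias=False)

    def forward(self, x):
        return self.conv(x)


vgg_like = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(inplace=True), nn.MaxPool2d(2),
    nn.Conv2d(8, 10, 3, padding=1), nn.ReLU(inplace=True),
)
converted = convert_to_bcos(vgg_like, StandInBcosConv)
out = converted(torch.randn(2, 3, 8, 8))
```

Replacing layers with `nn.Identity` rather than deleting them keeps the module names and indices stable, which matters if code elsewhere refers to layers by position.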

---
---

**Models**

For the experiments, we rely on the publicly available implementations of the VGG-11, ResNet-34, InceptionNet (v3), and DenseNet-121 model architectures. We adapt those architectures to B-cos networks as described before.

---

### ResNet-34

In [12]:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> BcosConv2d:
    """3x3 convolution with padding"""
    return BcosConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> BcosConv2d:
    """1x1 convolution without padding"""
    return BcosConv2d(in_planes, out_planes, kernel_size=1, stride=stride)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(conv1x1(in_planes, self.expansion*planes, stride))

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out += self.shortcut(x)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()

        self.conv1 = conv1x1(in_planes, planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.conv3 = conv1x1(planes, planes * self.expansion)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(conv1x1(in_planes, self.expansion*planes, stride))

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out += self.shortcut(x)
        return out


class ResNet(nn.Module):

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()

        self.in_planes = 64

        self.conv1 = BcosConv2d(3, self.in_planes, kernel_size=7, stride=2, padding=3)
        self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.linear = BcosConv2d(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):

        out = self.conv1(x)
        out = self.avgpool(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.linear(out)
        out = out.view(out.size(0), -1)

        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])


def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])


def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])


def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])
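`ResNet.forward` flattens the output of the final 1×1 B-cos layer with `out.view(out.size(0), -1)`, which yields exactly `num_classes` logits only if the last feature map is 1×1. A small sketch of the spatial arithmetic, using the kernel sizes, strides and paddings of the code above (the helper names are illustrative):

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a convolution or pooling layer (floor mode) along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def resnet_sizes(n: int) -> list:
    sizes = [n]
    sizes.append(out_size(sizes[-1], kernel=7, stride=2, padding=3))   # conv1
    sizes.append(out_size(sizes[-1], kernel=3, stride=2, padding=1))   # avgpool
    for stride in (1, 2, 2, 2):                                        # first 3x3 conv of layer1..layer4
        sizes.append(out_size(sizes[-1], kernel=3, stride=stride, padding=1))
    return sizes


print(resnet_sizes(32))    # CIFAR-10 inputs
print(resnet_sizes(224))   # ImageNet-sized inputs, for comparison
```

For 32×32 inputs the map ends at 1×1, so the flattened output has `num_classes` entries; for 224×224 inputs it ends at 7×7, and the same `view` would produce `num_classes * 49` values instead.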

### VGG-11

In [5]:
import math

import torch.nn as nn
import torch.nn.init as init

__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19',
]


class VGG(nn.Module):

    def __init__(self, features, num_classes=10):
        super(VGG, self).__init__()

        self.features = features

        self.classifier = nn.Sequential(
                                        BcosConv2d(512, 4096, kernel_size=7, padding=3, scale_fact=1000),
                                        BcosConv2d(4096, 4096, scale_fact=1000),
                                        BcosConv2d(4096, num_classes, scale_fact=1000),
                                        )

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        x = x.view(x.size(0), -1)
        return x


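def make_layers(cfg, batch_norm=False):
    # B-cos conversion: no ReLU after the convolutions, and every 'M' (MaxPool in the
    # original VGG) becomes a linear 2x2 average pooling, as in the ResNet and DenseNet
    # models of this notebook. Without this downsampling, a 32x32 input stays 32x32 and
    # the final view in VGG.forward would not produce num_classes logits.
    layers = []
    in_channels = 3

    for v in cfg:
        if v == 'M':
            layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = BcosConv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                # Batch norm is removed in the B-cos conversion described above;
                # the *_bn variants are kept only for compatibility.
                layers += [conv2d, nn.BatchNorm2d(v)]
            else:
                layers += [conv2d]
            in_channels = v

    return nn.Sequential(*layers)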


cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
          512, 512, 512, 512, 'M'],
}


def vgg11():
    """VGG 11-layer model (configuration "A")"""
    return VGG(make_layers(cfg['A']))


def vgg11_bn(channels=3):
    """VGG 11-layer model (configuration "A") with batch normalization"""
    return VGG(make_layers(cfg['A'], batch_norm=True))


def vgg13():
    """VGG 13-layer model (configuration "B")"""
    return VGG(make_layers(cfg['B']))


def vgg13_bn():
    """VGG 13-layer model (configuration "B") with batch normalization"""
    return VGG(make_layers(cfg['B'], batch_norm=True))


def vgg16():
    """VGG 16-layer model (configuration "D")"""
    return VGG(make_layers(cfg['D']))


def vgg16_bn(channels=3):
    """VGG 16-layer model (configuration "D") with batch normalization"""
    return VGG(make_layers(cfg['D'], batch_norm=True))


def vgg19():
    """VGG 19-layer model (configuration "E")"""
    return VGG(make_layers(cfg['E']))


def vgg19_bn():
    """VGG 19-layer model (configuration 'E') with batch normalization"""
    return VGG(make_layers(cfg['E'], batch_norm=True))
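`VGG.forward` flattens the classifier output with `x.view(x.size(0), -1)`, which yields exactly `num_classes` logits only if the feature map reaching the classifier is 1×1. A short sketch of that arithmetic for configuration 'A' on 32×32 inputs, assuming each 'M' stage is a 2×2, stride-2 pooling; the 3×3, padding-1 convolutions and the first classifier layer (kernel 7, padding 3) preserve the size. The helper names are illustrative.

```python
cfg_A = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']


def out_size(n, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling layer (floor mode) along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def final_size(cfg, n=32, downsample_at_M=True):
    for v in cfg:
        if v == 'M':
            if downsample_at_M:
                n = out_size(n, kernel=2, stride=2)   # a 2x2, stride-2 pooling stage
        else:
            n = out_size(n, kernel=3, padding=1)       # 3x3 conv with padding 1 keeps the size
    return out_size(n, kernel=7, padding=3)            # first classifier layer keeps the size too


print(final_size(cfg_A))                         # with downsampling at every 'M'
print(final_size(cfg_A, downsample_at_M=False))  # if the 'M' entries are skipped entirely
```

Five halvings take 32 down to 1; without them the map stays 32×32 and the flattened output would hold `num_classes * 32 * 32` values.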

### DenseNet-121

In [6]:
__all__ = ['DenseNet', 'densenet121']

model_url = {'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth'}


class _DenseLayer(nn.Module):
    def __init__(
        self,
        num_input_features: int,
        growth_rate: int,
        bn_size: int,
        drop_rate: float,
        memory_efficient: bool = False,
        b_exp=2,
        max_out=2
    ) -> None:
        super(_DenseLayer, self).__init__()

        self.conv1 = BcosConv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, max_out=max_out, padding=0, b=b_exp)

        self.conv2 = BcosConv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, max_out=max_out, padding=1, b=b_exp)

        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(concated_features)
        return bottleneck_output

    def any_requires_grad(self, input: List[Tensor]) -> bool:
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:

        def closure(*inputs):
            return self.bn_function(inputs)

        return cp.checkpoint(closure, *input)

    @torch.jit._overload_method
    def forward(self, input: List[Tensor]) -> Tensor:
        pass

    @torch.jit._overload_method
    def forward(self, input: Tensor) -> Tensor:
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input: Tensor) -> Tensor:

        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")
            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(bottleneck_output)

        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)

        return new_features


class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
        b_exp = 2,
        max_out = 2
    ) -> None:
        super(_DenseBlock, self).__init__()

        for i in range(num_layers):
            layer = _DenseLayer(
                                num_input_features + i * growth_rate,
                                growth_rate=growth_rate,
                                bn_size=bn_size,
                                drop_rate=drop_rate,
                                memory_efficient=memory_efficient,
                                b_exp=b_exp,
                                max_out=max_out
                                )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features: int, num_output_features: int, b_exp=2, max_out=2) -> None:
        super(_Transition, self).__init__()
        self.conv = BcosConv2d(num_input_features, num_output_features, kernel_size=1, stride=1, padding=0, b=b_exp, max_out=max_out)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 10,
        memory_efficient: bool = False,
        b_exp = 2,
        max_out = 2
    ) -> None:

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
                                                  ('conv0', BcosConv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, b=b_exp, max_out=max_out)),
                                                  ('pool0', nn.AvgPool2d(kernel_size=3, stride=2, padding=1)),
                                                  ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):

            block = _DenseBlock(
                                  num_layers=num_layers,
                                  num_input_features=num_features,
                                  bn_size=bn_size,
                                  growth_rate=growth_rate,
                                  drop_rate=drop_rate,
                                  memory_efficient=memory_efficient,
                                  b_exp=b_exp,
                                  max_out=max_out
                                )

            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate

            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, b_exp=b_exp, max_out=max_out)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        self.classifier = BcosConv2d(num_features, num_classes, kernel_size=1, stride=1, padding=0, b=b_exp, max_out=max_out)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def get_features(self, x):
        return self.features(x)

    def get_sequential_model(self):
        model = nn.Sequential(*[m for m in self.features], self.classifier)
        return model

    def get_layer_idx(self, idx):
        return idx

    def forward(self, x: Tensor) -> Tensor:
        features = self.features(x)
        out = self.classifier(features)
        out = out.view(out.shape[0], -1)
        return out


def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.

    pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)


def _densenet(
    arch: str,
    growth_rate: int,
    block_config: Tuple[int, int, int, int],
    num_init_features: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> DenseNet:

    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)

    if pretrained:
        _load_state_dict(model, model_url[arch], progress)

    return model


def densenet121(pretrained: bool = False, progress: bool = True, num_init_features=64, growth_rate=32, **kwargs: Any) -> DenseNet:

    r"""Densenet-121 model from "Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        progress (bool): If True, displays a progress bar of the download to stderr.
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """

    return _densenet('densenet121', growth_rate, (6, 12, 24, 16), num_init_features, pretrained, progress, **kwargs)
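The channel bookkeeping in `DenseNet.__init__` is easy to lose track of: every dense layer concatenates `growth_rate` new feature maps, and every transition halves the channel count. A small sketch of the arithmetic for the DenseNet-121 defaults (`growth_rate=32`, `num_init_features=64`, `block_config=(6, 12, 24, 16)`), together with the spatial size for 32×32 inputs; variable names are illustrative.

```python
growth_rate, num_init_features = 32, 64
block_config = (6, 12, 24, 16)

num_features = num_init_features
trace = []
for i, num_layers in enumerate(block_config):
    num_features += num_layers * growth_rate   # each dense layer adds growth_rate channels
    trace.append(num_features)
    if i != len(block_config) - 1:
        num_features //= 2                     # each transition halves the channel count
        trace.append(num_features)

# Spatial size for 32 x 32 inputs.
n = (32 + 2 * 3 - 7) // 2 + 1                  # conv0: 7x7, stride 2, padding 3
n = (n + 2 * 1 - 3) // 2 + 1                   # pool0: 3x3, stride 2, padding 1
n = n // 2 ** (len(block_config) - 1)          # three 2x2, stride-2 transitions

print(trace, n)
```

The last entry of `trace` is the number of input channels of `self.classifier`, and a final spatial size of 1 is what lets `out.view(out.shape[0], -1)` return `num_classes` logits on CIFAR-10.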

### InceptionNet (v3)
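The shape comments in `Inception3._forward` assume 299×299 inputs. A standalone sketch of the spatial arithmetic for the stem (everything before `Mixed_5b`), using the kernel sizes, strides and paddings from the code; the helper names are illustrative.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling layer (floor mode) along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of the stem layers, in forward order.
stem = [
    (3, 2, 0),  # Conv2d_1a_3x3
    (3, 1, 0),  # Conv2d_2a_3x3
    (3, 1, 1),  # Conv2d_2b_3x3
    (3, 2, 0),  # avgpool1
    (1, 1, 0),  # Conv2d_3b_1x1
    (3, 1, 0),  # Conv2d_4a_3x3
    (3, 2, 0),  # avgpool2
]


def stem_sizes(n):
    sizes = [n]
    for k, s, p in stem:
        n = out_size(n, k, s, p)
        sizes.append(n)
    return sizes


print(stem_sizes(299))
print(stem_sizes(32))
```

For 299×299 inputs this reproduces the comments (149, 147, 147, 73, 73, 71, 35). For 32×32 CIFAR-10 images the stem alone already reduces the feature map to 1×1, which leaves little room for the later Inception stages; resizing the inputs (which the data loader above does not do) is one option.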

In [7]:
__all__ = ['Inception3','InceptionOutputs', '_InceptionOutputs', 'inception_v3']

model_url = {'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'}

InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs


class Inception3(nn.Module):

    def __init__(
        self,
        num_classes: int = 10,
        aux_logits: bool = False,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None
    ) -> None:

        super(Inception3, self).__init__()

        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]

        if init_weights is None:
            warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
            init_weights = True

        assert len(inception_blocks) == 7
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.avgpool1 = nn.AvgPool2d(kernel_size=3, stride=2)

        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.avgpool2 = nn.AvgPool2d(kernel_size=3, stride=2)

        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)

        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)

        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)

        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)

        self.fc = BcosConv2d(2048, num_classes, kernel_size=1, stride=1, padding=0, scale_fact=200)

        self.debug = False

        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    import scipy.stats as stats
                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                    X = stats.truncnorm(-2, 2, scale=stddev)
                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                    values = values.view(m.weight.size())
                    with torch.no_grad():
                        m.weight.copy_(values)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def get_features(self, x):
        return self.get_sequential_model()[:-1](x)

    def _transform_input(self, x: Tensor) -> Tensor:
        return x

    def get_sequential_model(self):
        """ For evaluation purposes only, to extract layers at roughly the same relative network depth between different models. """
        model = nn.Sequential(
                                self.Conv2d_1a_3x3,
                                self.Conv2d_2a_3x3,
                                self.Conv2d_2b_3x3,
                                self.avgpool1,
                                self.Conv2d_3b_1x1,
                                self.Conv2d_4a_3x3,
                                self.avgpool2,
                                self.Mixed_5b,
                                self.Mixed_5c,
                                self.Mixed_5d,
                                self.Mixed_6a,
                                self.Mixed_6b,
                                self.Mixed_6c,
                                self.Mixed_6d,
                                self.Mixed_6e,
                                self.Mixed_7a,
                                self.Mixed_7b,
                                self.Mixed_7c,
                                self.fc
                              )
        return model

    def get_layer_idx(self, idx):
        """ For evaluation purposes only, to extract layers at roughly the same relative network depth between different models. """
        return int(np.ceil(len(self.get_sequential_model())*idx/10))

    def print(self, layer_name, x):
        if self.debug:
            print(layer_name, x.shape)

    def _forward(self, x: Tensor):
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.avgpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.avgpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                self.aux_out = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        x = self.fc(x)
        # N x 1000 (num_classes)
        x = x.view(x.shape[0], -1)
        return x

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]):
        return x

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        x = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, None)
        else:
            return self.eager_outputs(x, None)


class InceptionA(nn.Module):

    def __init__(
        self,
        in_channels: int,
        pool_features: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:

        super(InceptionA, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d

        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = self.pool(x)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:

        super(InceptionB, self).__init__()

        if conv_block is None:
            conv_block = BasicConv2d

        # Grid reduction 35 -> 17: every branch must end at the same spatial size,
        # otherwise torch.cat along the channel dimension in forward() fails.
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2, padding=0)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2, padding=0)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = self.pool(x)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):

    def __init__(
        self,
        in_channels: int,
        channels_7x7: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:

        super(InceptionC, self).__init__()

        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = self.pool(x)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:

        super(InceptionD, self).__init__()

        if conv_block is None:
            conv_block = BasicConv2d

        # Grid reduction 17 -> 8: every branch must end at the same spatial size,
        # otherwise torch.cat along the channel dimension in forward() fails.
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2, padding=0)

        self.pool = nn.AvgPool2d(kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2, padding=0)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = self.pool(x)

        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:

        super(InceptionE, self).__init__()

        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.pool = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
                      self.branch3x3_2a(branch3x3),
                      self.branch3x3_2b(branch3x3),
                    ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
                          self.branch3x3dbl_3a(branch3x3dbl),
                          self.branch3x3dbl_3b(branch3x3dbl),
                       ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = self.pool(x)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:

        super(InceptionAux, self).__init__()

        if conv_block is None:
            conv_block = BasicConv2d

        self.conv0 = conv_block(in_channels, 128, kernel_size=1)

        self.pool = nn.AvgPool2d(kernel_size=5, stride=3)

        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01

        self.fc = BcosConv2d(768, num_classes, kernel_size=1, stride=1, padding=0, scale_fact=200)
        self.fc.stddev = 0.001

    def forward(self, x: Tensor) -> Tensor:
        # N x 768 x 17 x 17
        x = self.pool(x)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        x = self.fc(x)
        # N x num_classes x 1 x 1
        x = F.adaptive_avg_pool2d(x, (1, 1))[..., 0, 0]
        # N x num_classes
        return x


class BasicConv2d(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        **kwargs: Any
    ) -> None:

        super(BasicConv2d, self).__init__()

        self.conv = BcosConv2d(in_channels, out_channels, scale_fact=200, **kwargs)

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)


def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":

    r"""Inception v3 model architecture from `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note:: **Important**: In contrast to the other models, inception_v3 expects tensors with a size of N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        progress (bool): If True, displays a progress bar of the download to stderr.
        aux_logits (bool): If True, add an auxiliary branch that can improve training. Default: *True*.
        transform_input (bool): If True, preprocesses the input according to the method with which it was trained on ImageNet. Default: *False*.
    """

    if pretrained:

        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True

        if 'aux_logits' in kwargs:
            original_aux_logits = kwargs['aux_logits']
            kwargs['aux_logits'] = True
        else:
            original_aux_logits = True

        kwargs['init_weights'] = False  # We are loading weights from a pretrained model.

        model = Inception3(**kwargs)

        state_dict = load_state_dict_from_url(model_url['inception_v3_google'], progress=progress)
        model.load_state_dict(state_dict)

        if not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None
        return model

    return Inception3(**kwargs)
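The shape comments in `_forward` follow the usual output-size rule for convolution and pooling layers, `(size + 2*padding - kernel) // stride + 1`. The sketch below recomputes them for a 299 x 299 input. The stem settings (kernel, stride and padding of `Conv2d_1a_3x3` up to `avgpool2`) are not shown in this section, so they are an assumption here: torchvision's Inception v3 values, with its max pools swapped for average pools of the same size. It also illustrates why all branches of a grid-reduction block (`InceptionB`, `InceptionD`) must end at the same spatial size, since `torch.cat` along the channel dimension fails otherwise. `conv_out` is a hypothetical helper.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer (no dilation, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# Assumed (kernel, stride, padding) per stem layer, taken from torchvision's Inception v3;
# the actual stem definitions are not part of this section.
stem = [
    ("Conv2d_1a_3x3", 3, 2, 0),
    ("Conv2d_2a_3x3", 3, 1, 0),
    ("Conv2d_2b_3x3", 3, 1, 1),
    ("avgpool1", 3, 2, 0),
    ("Conv2d_3b_1x1", 1, 1, 0),
    ("Conv2d_4a_3x3", 3, 1, 0),
    ("avgpool2", 3, 2, 0),
]

size = 299
sizes = {}
for name, kernel, stride, padding in stem:
    size = conv_out(size, kernel, stride, padding)
    sizes[name] = size
print(sizes)

# Grid reduction (Mixed_6a, Mixed_7a): 3x3 stride-2 convs and 3x3 stride-2 pooling agree.
after_6a = conv_out(sizes["avgpool2"], 3, 2)  # 35 -> 17
after_7a = conv_out(after_6a, 3, 2)           # 17 -> 8

# A 1x1 stride-2 conv next to a 3x3 stride-1 pool with padding 1 does not agree,
# so concatenating those two branches would fail.
print(conv_out(35, 1, 2), conv_out(35, 3, 1, 1))  # 18 35
```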

## Training

---
**Training**

We trained our models for 5 epochs with Adam, using an initial learning rate of 0.001 and a batch size of 64 (note that `getArgs` below sets `args.epochs = 15`).

---

In [8]:
def getArgs():
    parser = argparse.ArgumentParser(description="BCOS_TRAINING")
    parser.add_argument('-f')
    args = parser.parse_args()

    args.batch_size = 64
    args.dataset = 'CIFAR10'
    args.epochs = 15
    args.learning_rate = 0.001
    args.model_name = 'ResNet34'
    args.model = ResNet34()

    args.cifar10Path = '/var/datasets/CIFAR10'

    args.save_losses = '/content/'
    args.save_ckpt = '/content/ckpt/'

    args.load_ckpt = f'/content/ckpt/Epoch_{args.epochs}.pt'
    args.save_results_path = '/content/'

    return args
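`parser.add_argument('-f')` is there because Jupyter starts the kernel with `-f <connection file>`, which `parse_args()` would otherwise reject. A possible alternative, sketched below, is `parse_known_args`, which returns unrecognised arguments separately instead of exiting. The `--epochs` flag and the simulated argument list are hypothetical and only for illustration.

```python
import argparse

parser = argparse.ArgumentParser(description="BCOS_TRAINING")
parser.add_argument("--epochs", type=int, default=15)  # hypothetical flag, for illustration

# Simulate the kernel's own arguments; in a notebook you would pass nothing and let
# argparse read sys.argv. Unknown arguments come back in a separate list.
args, unknown = parser.parse_known_args(["-f", "/tmp/kernel.json"])
print(args.epochs, unknown)
```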

---

**Optimising B-cos networks for classification**

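*   First, note that the output of each B-cos neuron is bounded: in the B-cos transform the weight vector is normalised to unit length and |cos(x, ŵ)| ≤ 1, so the magnitude of a neuron's output cannot exceed the norm of its input, ‖x‖.
Since the output of a B-cos network is computed as a sequence of such bounded transforms, the output of the network as a whole is also bounded.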

*   Secondly, note that a B-cos network as a whole can only achieve its upper bound for a given input if the units in each layer achieve their upper bound.
The individual units, in turn, can only achieve their maxima by aligning with their inputs.
Hence, optimising a B-cos network to maximise its output over a set of inputs will optimise the model weights to align with those inputs.
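To make the boundedness argument concrete, here is a single-unit sketch assuming the B-cos formulation from the paper, B-cos(x; w) = ŵᵀx · |cos(x, ŵ)|^(B - 1) with ŵ = w / ‖w‖. It ignores any extra scaling or MaxOut that `BcosConv2d` may apply, and `bcos_unit` is a hypothetical helper rather than part of this notebook. The output magnitude never exceeds ‖x‖, and only a weight vector pointing in the direction of x reaches that bound.

```python
import numpy as np


def bcos_unit(x: np.ndarray, w: np.ndarray, b: float = 2.0) -> float:
    """Hypothetical single B-cos unit: (w_hat . x) * |cos(x, w_hat)|^(b - 1)."""
    w_hat = w / np.linalg.norm(w)              # unit-norm weight
    lin = float(w_hat @ x)                     # = ||x|| * cos(x, w_hat)
    cos = lin / (np.linalg.norm(x) + 1e-12)
    return lin * abs(cos) ** (b - 1)


rng = np.random.default_rng(0)
x = rng.normal(size=8)

# Random weights: the output never exceeds ||x|| in magnitude.
outs = [bcos_unit(x, rng.normal(size=8)) for _ in range(1000)]
print(max(abs(o) for o in outs) <= np.linalg.norm(x))  # True

# Weights aligned with x reach the bound (the scale of w does not matter).
print(np.isclose(bcos_unit(x, 3.0 * x), np.linalg.norm(x)))  # True
```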

In order to take advantage of this when optimising for classification, we train the B-cos networks with the **Binary Cross Entropy** (**BCE**) loss:

$$\mathcal{L}(x_i, y_i) = \mathrm{BCE}\big(\sigma(f(x_i; \theta) + b),\, y_i\big)$$

for input $x_i$ and its corresponding one-hot encoded class label $y_i$.
Here, $\sigma$ denotes the sigmoid function, $b$ a bias, and $\theta$ the model parameters.
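A minimal sketch of this loss in PyTorch: `nn.BCEWithLogitsLoss` applies the sigmoid internally, so passing f(x_i; θ) + b as logits gives the expression above, averaged over batch and classes. The helper names, the logits and the example value b = -2 are illustrative only; the `Trainer` below calls `nn.BCEWithLogitsLoss()` on the raw logits, which corresponds to b = 0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def bcos_bce_loss(logits: torch.Tensor, labels: torch.Tensor, b: float = 0.0) -> torch.Tensor:
    """BCE(sigmoid(f(x) + b), y) with one-hot targets, averaged over batch and classes."""
    y = F.one_hot(labels, num_classes=logits.shape[-1]).float()
    return nn.BCEWithLogitsLoss()(logits + b, y)


def bcos_bce_loss_manual(logits: torch.Tensor, labels: torch.Tensor, b: float = 0.0) -> torch.Tensor:
    """The same loss written out explicitly, for comparison."""
    y = F.one_hot(labels, num_classes=logits.shape[-1]).float()
    p = torch.sigmoid(logits + b)
    return -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()


logits = torch.tensor([[2.0, -1.0, 0.5, -3.0, 0.0, 1.0, -2.0, 0.3, -0.7, 0.1],
                       [0.2, 0.4, -1.5, 2.5, -0.3, 0.0, 1.1, -2.2, 0.6, -0.9]])
labels = torch.tensor([0, 3])  # class indices, as returned by a CIFAR-10 DataLoader

for b in (0.0, -2.0):  # b = 0 matches the Trainer; -2.0 is an arbitrary example
    same = torch.allclose(bcos_bce_loss(logits, labels, b), bcos_bce_loss_manual(logits, labels, b), atol=1e-6)
    print(b, same)
```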

We choose the BCE loss because it directly entails output maximisation: to reduce the BCE loss, the network is optimised to maximise the logit of the correct class and to maximise the negative logit (i.e. to minimise the logit) of every incorrect class.
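The direction of this pressure can be read off the gradient of the BCE-with-logits loss, which under mean reduction is (σ(z_c) - y_c) / (N · C) for logit z_c. The toy check below (made-up values, not from the notebook) starts from all-zero logits: the gradient is negative for the correct class and positive for all others, so a gradient-descent step raises the correct logit and lowers the incorrect ones.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes = 10
true_class = 3

logits = torch.zeros(1, num_classes, requires_grad=True)  # toy starting point: all logits 0
target = F.one_hot(torch.tensor([true_class]), num_classes=num_classes).float()

loss = nn.BCEWithLogitsLoss()(logits, target)
loss.backward()

grad = logits.grad[0]
# At z = 0, sigmoid(z) = 0.5: the gradient is (0.5 - 1) / 10 = -0.05 for the true class
# and (0.5 - 0) / 10 = 0.05 for every other class.
print(grad)
```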

Finally, note that increasing B specifically reduces the output of badly aligned weights in each layer.
For badly aligned weights, this decreases the layer's output strength and thus the output of the network as a whole, which increases the alignment pressure during optimisation (thus, **higher B** → **higher alignment**).
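The effect of B can be seen on a single unit. With x fixed, a weight at angle φ to x produces |cos φ|^(B - 1) · cos φ times the output of a perfectly aligned weight. The sketch below assumes the single-unit B-cos formulation ŵᵀx · |cos(x, ŵ)|^(B - 1) and ignores any extra scaling in `BcosConv2d`; `bcos_output` is a hypothetical helper. It compares a weight at 60° (cos = 0.5) with an aligned one for B = 1, 2, 3.

```python
import numpy as np


def bcos_output(x: np.ndarray, w: np.ndarray, B: float) -> float:
    """Hypothetical single B-cos unit: (w_hat . x) * |cos(x, w_hat)|^(B - 1)."""
    w_hat = w / np.linalg.norm(w)
    lin = float(w_hat @ x)
    cos = lin / np.linalg.norm(x)
    return lin * abs(cos) ** (B - 1)


x = np.array([1.0, 0.0])
aligned = np.array([1.0, 0.0])
misaligned = np.array([np.cos(np.pi / 3), np.sin(np.pi / 3)])  # 60 degrees away, cos = 0.5

ratios = {}
for B in (1, 2, 3):
    # Output of the misaligned weight relative to the aligned one
    ratios[B] = bcos_output(x, misaligned, B) / bcos_output(x, aligned, B)
    print(f"B={B}: relative output {ratios[B]:.3f}")
# B=1: relative output 0.500
# B=2: relative output 0.250
# B=3: relative output 0.125
```

With B = 1 the unit is an ordinary (normalised) linear unit; each increase of B halves the misaligned weight's output again in this example, which is the extra alignment pressure described above.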

---

In [11]:
class Trainer:

    def __init__(self, args):

        self.loader = loadData(args)
        self.model = args.model
        self.create_paths(args.save_ckpt, args.save_losses)

    @staticmethod
    def create_paths(ckpt_path, losses_path):

        if not os.path.exists(ckpt_path):
            os.makedirs(ckpt_path)

        if not os.path.exists(losses_path):
            os.makedirs(losses_path)

    def training(self, args):

        # LOADING DATA #
        train_dataloader, val_dataloader = self.loader.getDataLoader()

        # TRAINING PARAMETERS #
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=args.learning_rate)

        all_train_loss = []
        all_val_loss = []

        # print(self.model)

        #############################
        #       TRAINING LOOP       #
        #############################

        print(f'\nSTARTING TRAINING WITH {args.model_name} FOR {args.dataset}')

        for epoch in range(args.epochs):
            epoch_train_loss = []
            epoch_val_loss = []

            self.model.train()
            for imgs, labels in tqdm(train_dataloader):
                labels = F.one_hot(labels, num_classes=10)

                logits = self.model(imgs)

                optimizer.zero_grad()

                loss = criterion(logits, labels.float())
                loss.backward()

                optimizer.step()

                epoch_train_loss.append(loss.item())

            # Validation: no parameter updates, so gradients are not needed.
            self.model.eval()
            with torch.no_grad():
                for imgs, labels in tqdm(val_dataloader):
                    labels = F.one_hot(labels, num_classes=10)

                    logits = self.model(imgs)

                    loss = criterion(logits, labels.float())

                    epoch_val_loss.append(loss.item())

            train_loss, val_loss = np.mean(epoch_train_loss), np.mean(epoch_val_loss)

            all_train_loss.append(train_loss)
            all_val_loss.append(val_loss)

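            # Save losses and a checkpoint every 10 epochs and after the last epoch; the last
            # checkpoint is Epoch_{args.epochs}.pt, which is the file args.load_ckpt points to.
            if (epoch + 1) % 10 == 0 or epoch + 1 == args.epochs:
                np.save(os.path.join(args.save_losses, 'train_loss.npy'), all_train_loss)
                np.save(os.path.join(args.save_losses, 'val_loss.npy'), all_val_loss)
                torch.save(self.model.state_dict(), os.path.join(args.save_ckpt, f'Epoch_{epoch + 1}.pt'))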

        print(f'\n--- FINISHED {args.model_name} TRAINING ---')


args = getArgs()

# TRAINING
Trainer(args).training(args)

Files already downloaded and verified
Files already downloaded and verified

STARTING TRAINING WITH ResNet34 FOR CIFAR10


 10%|█         | 79/782 [00:54<08:07,  1.44it/s]


KeyboardInterrupt: ignored

## Inference

---

**Evaluating explanations**

We evaluate each of the B-cos networks to investigate which architecture provides the best explanations.

This makes it possible to compare explanations between different models and to evaluate the explainability gain achieved by converting conventional models to B-cos networks.

---
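`getExplanationImage` below scores each pixel by summing input × gradient over the colour channels. For a bias-free linear model this is exact: the per-pixel contributions add up to the explained logit. The sketch below checks this on a toy linear classifier with random weights (`W` and the 3 × 8 × 8 input are made up for illustration). It is not a test of the B-cos layers themselves, which compute an input-dependent linear mapping W(x)x.

```python
import torch

torch.manual_seed(0)

# Toy bias-free linear classifier over 3 x 8 x 8 "images" (random weights, illustration only)
W = torch.randn(10, 3 * 8 * 8)
x = torch.rand(3, 8, 8, requires_grad=True)

logits = W @ x.flatten()
c = int(logits.argmax())   # explain the highest logit
logits[c].backward()

# Same aggregation as getExplanationImage: input x gradient, summed over colour channels
contribution = (x * x.grad).sum(0)

print(contribution.shape)  # torch.Size([8, 8])
print(torch.allclose(contribution.sum(), logits[c].detach(), atol=1e-4))  # True
```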

In [None]:
def getResults(y_true, y_pred):

    conf_matrix = confusion_matrix(y_true, y_pred)

    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)

    dict_metrics = {'accuracy' : [accuracy],
                    'f1-score' : [f1],
                    'precision' : [precision],
                    'recall' : [recall]}

    df = pd.DataFrame.from_dict(dict_metrics)

    print('\n\n-- CONFUSION MATRIX --\n')
    print(conf_matrix)

    print('\n-- CLASSIFICATION METRICS --\n')
    print(df.to_string(index=False))

    return dict_metrics


# From the original implementation
def getExplanationImage(img, grads, smooth=15, alpha_percent=99.5):

    contribution = (img*grads).sum(0, keepdim=True)
    # print('contribution', contribution.shape)

    rgb_grad = (grads / (grads.abs().max(0, keepdim=True)[0] + 1e-12))
    rgb_grad = rgb_grad.clamp(0).numpy()
    # print('rgb_grad', rgb_grad.shape)

    alpha = (grads.norm(p=2, dim=0, keepdim=True))
    # print('alpha', alpha.shape)

    # Only show positive contributions
    alpha = torch.where(contribution < 0, torch.zeros_like(alpha) + 1e-12, alpha)

    if smooth:
        alpha = F.avg_pool2d(alpha, smooth, stride=1, padding=(smooth-1)//2)

    alpha = alpha.numpy()
    alpha = (alpha / np.percentile(alpha, alpha_percent)).clip(0, 1)
    # print('alpha', alpha.shape)

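    # Stack colour and opacity channels, then drop the first channel so that the
    # explanation has as many channels as the input image (torch.cat in
    # getExplanations requires matching channel counts).
    rgb_grad = np.concatenate([rgb_grad, alpha], axis=0)[1:]
    # print('rgb_grad', rgb_grad.shape)

    # Channels-first [C, H, W] tensor, as expected by torchvision.utils.save_image
    grad_image = torch.from_numpy(rgb_grad)
    # print('grad_image', grad_image.shape)

    return grad_image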


class Inference:

    def __init__(self, args):

        self.loader = loadData(args)
        self.model = self.getModels(args)

    def getModels(self, args):

        # LOADING MODEL
        model = args.model

        if args.load_ckpt is not None:
            path = args.load_ckpt.split('ckpt/')[-1]
            print(f'LOADING MODEL: {path}\n')
            model.load_state_dict(torch.load(args.load_ckpt), strict=False)
        else:
            print('LOADING MODEL: no checkpoint -> initializing randomly\n')

        return model

    def evaluateMetrics(self, args):

        # LOADING DATA #
        train_dataloader, val_dataloader = self.loader.getDataLoader()

        # TRAINING PARAMETERS #
        criterion = nn.BCEWithLogitsLoss()

        #############################
        #       INFERENCE LOOP      #
        #############################

        print(f'\nSTARTING CLASSIFICATION WITH {args.model_name} FOR {args.dataset}')

        all_pred = []
        all_labels = []

        self.model.eval()
        with torch.no_grad():
            for imgs, labels in tqdm(val_dataloader):

                logits = self.model(imgs)

                # softmax is monotonic per row, so the argmax of the logits is the predicted class
                pred = logits.argmax(dim=-1)

                all_pred = np.concatenate((all_pred, pred.numpy()))
                all_labels = np.concatenate((all_labels, labels.numpy()))

        getResults(all_labels, all_pred)

    def getExplanations(self, args):

        # LOADING DATA #
        train_dataloader, val_dataloader = self.loader.getDataLoader()

        # TRAINING PARAMETERS #
        criterion = nn.BCEWithLogitsLoss()

        #############################
        #       INFERENCE LOOP      #
        #############################

        print(f'\nSTARTING EXPLANATIONS GENERATION WITH {args.model_name} FOR {args.dataset}')

        ori_images = []
        exp_images = []

        self.model.eval()
        for imgs, labels in tqdm(val_dataloader):

            # from the original implementation
            imgs = imgs.requires_grad_(True)

            # Single forward pass; the maximum logit over the whole batch is explained,
            # so its gradient only reaches the image that produced it.
            logits = self.model(imgs)
            max_logit = logits.max()

            imgs.grad = None
            self.model.zero_grad()
            max_logit.backward()

            # Index of the image holding the maximum logit
            max_values, _ = torch.max(logits, dim=-1)
            _, max_img_in_batch = torch.max(max_values, dim=0)

            explanation = getExplanationImage(imgs[max_img_in_batch].detach(), imgs.grad[max_img_in_batch])

            ori_images.append(imgs[max_img_in_batch].detach())
            exp_images.append(explanation)

        ori_images = torch.stack(ori_images)
        # print(ori_images.shape)

        exp_images = torch.stack(exp_images)
        # print(exp_images.shape)

        real_fake_images = torch.cat((ori_images[:5], exp_images.add(1).mul(0.5)[:5]))
        utils.save_image(real_fake_images, os.path.join(args.save_results_path, 'explanation_results.jpg'), nrow=5)

        print(f'\n\n--- FINISHED {args.model_name} EXPLANATIONS GENERATION ---')


args = getArgs()

# Inference(args).evaluateMetrics(args)
Inference(args).getExplanations(args)
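`evaluateMetrics` only needs the predicted class. Because softmax is strictly increasing within each row, the argmax of the probabilities equals the argmax of the raw logits, so the softmax can be skipped. The toy example below (made-up logits and labels, not notebook results) checks this and then computes accuracy and macro-averaged F1 with the same scikit-learn settings as `getResults`.

```python
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score

logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.3, 0.2, 0.9],
                       [-0.5, 1.5, 0.0],
                       [1.0, 1.2, -2.0]])
y_true = np.array([0, 2, 1, 0])

# softmax does not change the ordering within a row, hence not the argmax
pred_softmax = F.softmax(logits, dim=-1).argmax(dim=-1)
pred_logits = logits.argmax(dim=-1)
print(torch.equal(pred_softmax, pred_logits))  # True

y_pred = pred_logits.numpy()
print(accuracy_score(y_true, y_pred))  # 0.75
print(f1_score(y_true, y_pred, average="macro", zero_division=0))
```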