This notebook shows how to generate counterfactual explanations ("counterfactual maps") for a CNN classifier trained on the MNIST dataset, and how to visualize them.

## Import dependencies

In [5]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

## Building a basic classifier

In [6]:
class MNISTDataset(Dataset):
    """Minimal map-style dataset that pairs images with their labels.

    ``images`` and ``labels`` are indexed in lockstep; the dataset length is
    taken from ``images``.
    """

    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label

class MNISTClassifier(nn.Module):
    """Small two-conv CNN producing 10 class logits per image.

    Shapes for a 1x28x28 input: conv5 -> 10x24x24 -> pool -> 10x12x12
    -> conv5 -> 20x8x8 -> pool -> 20x4x4, i.e. the 320 features consumed by
    the first linear layer.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as listed so parameter
        # initialisation consumes the RNG identically.
        feature_layers = [
            nn.Conv2d(1, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5), nn.Dropout(), nn.MaxPool2d(2), nn.ReLU(),
        ]
        self.conv_layers = nn.Sequential(*feature_layers)
        head_layers = [nn.Linear(320, 50), nn.ReLU(), nn.Dropout(), nn.Linear(50, 10)]
        self.fc_layers = nn.Sequential(*head_layers)

    def forward(self, x):
        features = torch.flatten(self.conv_layers(x), 1)
        return self.fc_layers(features)

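## Prerequisites

Note: several names used below are not defined by any cell in this notebook: `Data`, `AutodocABCMeta`, `DashFigure`, `ProgressBar`, `is_tf_available`, `is_torch_available` and `tf`. They must be imported or defined before the notebook can be run from a fresh kernel.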

In [9]:
from typing import Union, List
import numpy as np
from PIL import Image as PilImage

class Image(Data):
    """Batched image container backed by a NumPy array of shape ``(N, H, W, C)``.

    The constructor normalises its input so that a batch axis is always present
    and channels are last; 2-D (grayscale) images get a trailing channel axis of
    size 1. Indexing always returns another batched ``Image``.
    """

    data_type = "image"

    def __init__(self, data: Union[np.ndarray, PilImage.Image] = None, batched: bool = False, channel_last: bool = True):
        """
        :param data: image array or PIL image; ``None`` creates an empty container.
        :param batched: whether the first axis of an array ``data`` is the batch axis.
        :param channel_last: whether a colour-channel axis in ``data`` is last (HWC)
            rather than first (CHW). PIL images are always treated as HWC.
        :raises ValueError: if ``data`` is neither an ``np.ndarray`` nor a ``PIL.Image``.
        """
        super().__init__()
        if data is None:
            self.data = None
        elif isinstance(data, np.ndarray):
            self.data = self._check_and_unify(data, batched, channel_last)
        elif isinstance(data, PilImage.Image):
            self.data = self._check_and_unify(np.array(data), batched=False, channel_last=True)
        else:
            raise ValueError(f"`data` must be `np.ndarray` or `PIL.Image`, not {type(data)}")

    @staticmethod
    def _check_and_unify(data: np.ndarray, batched: bool, channel_last: bool):
        """Validate ``data`` and return it as a 4-D ``(N, H, W, C)`` array.

        Accepted ranks are 2-D/3-D when unbatched and 3-D/4-D when batched. If a
        sample has a channel axis, it must hold at most 4 channels.
        """
        # Check the rank and pick one sample to sanity-check the channel axis.
        if batched:
            assert data.ndim in [3, 4], (
                f"`batched = True`, but got shape {data.shape}"
            )
            img = data[0]
        else:
            assert data.ndim in [2, 3], (
                f"`batched = False`, but got shape {data.shape}"
            )
            img = data

        if img.ndim == 3:
            if channel_last:
                assert img.shape[2] <= 4, "Last dimension should be color channels."
            else:
                assert img.shape[0] <= 4, "First dimension should be color channels."

        # Single images get a leading batch axis.
        if not batched:
            data = np.expand_dims(data, axis=0)
        if data.ndim == 4 and not channel_last:
            # (N, C, H, W) -> (N, H, W, C)
            data = np.transpose(data, (0, 2, 3, 1))
        elif data.ndim == 3:
            # Grayscale (N, H, W) -> (N, H, W, 1)
            data = np.expand_dims(data, axis=-1)
        return data

    def __len__(self) -> int:
        return self.data.shape[0]

    def __getitem__(self, i: Union[int, slice, list]):
        """Index the batch; the result is always a batched ``Image``."""
        if isinstance(i, (int, np.integer)):
            # Index, then re-add the batch axis. Unlike `data[i:i+1]`, this also
            # handles negative indices and NumPy integers, and raises IndexError
            # when out of range instead of producing an empty batch.
            return Image(np.expand_dims(self.data[int(i)], axis=0), batched=True, channel_last=True)
        return Image(self.data[i], batched=True, channel_last=True)

    def __iter__(self):
        """Yield each sample as a batched ``Image`` of size 1."""
        return (self[i] for i in range(self.shape[0]))

    def __repr__(self):
        return repr(self.data)

    @property
    def shape(self) -> tuple:
        """Shape of the underlying ``(N, H, W, C)`` array."""
        return self.data.shape

    @property
    def image_shape(self) -> tuple:
        """Per-image shape ``(H, W, C)``."""
        return self.data.shape[1:]

    @property
    def values(self) -> np.ndarray:
        """The underlying array (not copied)."""
        return self.data

    def num_samples(self) -> int:
        return self.data.shape[0]

    def to_numpy(self, hwc=True, copy=True, keepdim=False) -> np.ndarray:
        """Return the pixel array.

        With a single channel and ``keepdim=False``, the channel axis is dropped
        (giving ``(N, H, W)``) and ``hwc`` has no effect. Otherwise the result is
        ``(N, H, W, C)`` if ``hwc`` else ``(N, C, H, W)``. With ``copy=False``
        the underlying array (or a view of it) is returned.
        """
        data = self.data.copy() if copy else self.data
        if not keepdim and self.shape[-1] == 1:
            return data.squeeze(axis=-1)
        return data if hwc else np.transpose(data, (0, 3, 1, 2))

    def to_pil(self) -> Union[PilImage.Image, List[PilImage.Image]]:
        """Convert to PIL: a single image for a batch of one, otherwise a list.

        Values are cast to ``uint8`` without rescaling.
        """
        x = self.data.squeeze(axis=-1) if self.shape[-1] == 1 else self.data
        if self.shape[0] == 1:
            return PilImage.fromarray(x[0].astype(np.uint8))
        return [PilImage.fromarray(x[i].astype(np.uint8)) for i in range(self.shape[0])]

    def copy(self):
        """Deep copy of the container (the array is copied)."""
        return Image(data=self.data.copy(), batched=True, channel_last=True)

## Load the dataset

In [10]:
# Load the MNIST training and test splits (downloaded into ../data on first run).
train_data = torchvision.datasets.MNIST(root='../data', train=True, download=True)
test_data = torchvision.datasets.MNIST(root='../data', train=False, download=True)

# Swap the datasets' pixel tensors for NumPy arrays so `Image` can wrap them.
train_data.data, test_data.data = train_data.data.numpy(), test_data.data.numpy()

# The class labels are simply the digits 0-9.
class_names = tuple(range(10))

# Wrap the pixel arrays in the batched `Image` container; labels stay as tensors.
x_train = Image(train_data.data, batched=True)
y_train = train_data.targets
x_test = Image(test_data.data, batched=True)
y_test = test_data.targets

100%|██████████| 9.91M/9.91M [00:00<00:00, 52.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 2.19MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.2MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.31MB/s]


## Setting up the device, model and preprocessing

In [13]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate the CNN directly on the selected device.
model = MNISTClassifier().to(device)

# Conversion applied to each PIL image before stacking.
transform = transforms.Compose([transforms.ToTensor()])


def preprocess(ims):
    """Turn an iterable of single-image `Image` objects into one stacked tensor.

    Each element is converted to PIL, passed through `transform` and the
    resulting tensors are stacked along a new leading batch axis.
    """
    return torch.stack([transform(im.to_pil()) for im in ims])

## Training and Evaluation

In [15]:
# Hyperparameters
learning_rate = 1e-3
batch_size = 128
num_epochs = 10


def make_loader(images, labels, shuffle):
    """Preprocess an `Image` batch once and serve it with its labels in mini-batches."""
    return DataLoader(
        dataset=MNISTDataset(preprocess(images), labels),
        batch_size=batch_size,
        shuffle=shuffle,
    )


# Data loaders (only the training split is shuffled)
train_loader = make_loader(x_train, y_train, shuffle=True)
test_loader = make_loader(x_test, y_test, shuffle=False)

# Optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_func = nn.CrossEntropyLoss()


def run_training_epoch(loader):
    """One pass of mini-batch gradient descent over `loader`."""
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        batch_loss = loss_func(model(batch_images), batch_labels)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()


# Training loop
model.train()
for epoch in range(num_epochs):
    run_training_epoch(train_loader)

# Evaluation: count correct and total predictions per class on the test set.
correct_pred = {name: 0 for name in class_names}
total_pred = {name: 0 for name in class_names}

model.eval()
# Inference only: disabling autograd avoids building a graph for every test batch.
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        _, preds = torch.max(outputs, 1)
        # Convert to Python ints once per batch instead of indexing tensors element by element.
        for label, pred in zip(labels.tolist(), preds.cpu().tolist()):
            name = class_names[label]
            if label == pred:
                correct_pred[name] += 1
            total_pred[name] += 1

# Print class-wise accuracy (NaN for a class that never appears in the test set).
for name, correct_count in correct_pred.items():
    total = total_pred[name]
    accuracy = 100 * float(correct_count) / total if total else float("nan")
    print(f"Accuracy for class {name} is: {accuracy:.1f} %")

Accuracy for class 0 is: 99.5 %
Accuracy for class 1 is: 99.7 %
Accuracy for class 2 is: 99.2 %
Accuracy for class 3 is: 98.7 %
Accuracy for class 4 is: 98.6 %
Accuracy for class 5 is: 99.1 %
Accuracy for class 6 is: 99.2 %
Accuracy for class 7 is: 97.5 %
Accuracy for class 8 is: 99.1 %
Accuracy for class 9 is: 97.4 %


## The Explainer Module

In [16]:
from collections import defaultdict
import inspect

_EXPLAINERS = defaultdict(list)

# Metaclass that registers every concrete (non-abstract) explainer class into
# `_EXPLAINERS`, grouped by a name derived from the class's module.
class ExplainerABCMeta(AutodocABCMeta):
    def __new__(mcls, classname, bases, cls_dict):
        cls = super().__new__(mcls, classname, bases, cls_dict)
        if not inspect.isabstract(cls):
            # Dotted package modules ("a.b.<group>.c") are grouped by their third
            # component. Shorter module names (e.g. "__main__" in a notebook) made
            # `split(".")[2]` raise IndexError, so fall back to the last component.
            parts = cls.__module__.split(".")
            module_name = parts[2] if len(parts) > 2 else parts[-1]
            class_name = cls.__name__
            registry = _EXPLAINERS[module_name]
            # The registry stores classes, so compare by class name. The original
            # string-vs-class comparison never matched, so the check never fired.
            for pos, existing in enumerate(registry):
                if existing.__name__ == class_name:
                    if existing.__module__ != cls.__module__:
                        raise RuntimeError(
                            f"Explainer class `{class_name}` already exists in `{module_name}`. Please use a unique name."
                        )
                    # Same module re-defining the class (e.g. a re-run notebook
                    # cell): replace the stale entry instead of raising.
                    registry[pos] = cls
                    return cls
            registry.append(cls)
        return cls

In [17]:
import inspect
import os
import dill
from copy import deepcopy
from abc import abstractmethod

class ExplainerBase(metaclass=AutodocABCMeta):
    """Abstract base class for explainers.

    Subclasses implement :meth:`explain`. The base class adds dill-based
    persistence of the instance state through :meth:`save` / :meth:`load`.
    """

    def __init__(self):
        # Nothing to initialise here; present so subclasses can call super().__init__().
        pass

    @abstractmethod
    def explain(self, **kwargs):
        """Generate explanations; must be overridden by concrete explainers."""
        raise NotImplementedError

    @property
    def explanation_type(self):
        """Kind of explanation produced; ``"local"`` unless a subclass overrides it."""
        return "local"

    def __getstate__(self):
        """Return a deep copy of every instance attribute, keyed by name."""
        snapshot = {}
        for key, value in self.__dict__.items():
            snapshot[key] = deepcopy(value)
        return snapshot

    def __setstate__(self, state):
        """Restore attributes from a dict produced by ``__getstate__``."""
        for key in state:
            setattr(self, key, state[key])

    def save(self, directory: str, filename: str = None, **kwargs):
        """Pickle the explainer state with dill into ``directory/filename``.

        :param directory: target directory; created if missing.
        :param filename: file name; defaults to ``"<ClassName>.pkl"``.
        :param kwargs: ``ignored_attributes`` lists attribute names to leave out.
        """
        os.makedirs(directory, exist_ok=True)
        target_name = filename if filename is not None else f"{type(self).__name__}.pkl"
        state = self.__getstate__()
        for ignored in kwargs.get("ignored_attributes", []):
            state.pop(ignored, None)
        path = os.path.join(directory, target_name)
        with open(path, "wb") as handle:
            dill.dump(state, handle)

    @classmethod
    def load(cls, directory: str, filename: str = None, **kwargs):
        """Rebuild an explainer from a file written by :meth:`save`.

        WARNING: dill, like pickle, can execute arbitrary code while loading;
        only load files from trusted sources.
        """
        target_name = f"{cls.__name__}.pkl" if filename is None else filename
        path = os.path.join(directory, target_name)
        with open(path, "rb") as handle:
            state = dill.load(handle)
        # Skip __init__ (which may need constructor arguments) and restore the saved attributes.
        instance = super(ExplainerBase, cls).__new__(cls)
        instance.__setstate__(state)
        return instance

In [18]:
# Abstract base class for explanation results
class ExplanationBase(metaclass=AutodocABCMeta):
    """Abstract interface for explanation result objects.

    Subclasses implement :meth:`get_explanations` and the plotting hooks
    :meth:`plot`, :meth:`plotly_plot` and :meth:`ipython_plot`. Binary
    serialisation uses dill; text serialisation uses JSON.
    """

    @abstractmethod
    def get_explanations(self, **kwargs):
        """Return the stored explanations."""
        raise NotImplementedError

    @abstractmethod
    def plot(self, **kwargs):
        """Render the explanations as a static figure."""
        raise NotImplementedError

    @abstractmethod
    def plotly_plot(self, **kwargs):
        """Render the explanations as a plotly-based figure."""
        raise NotImplementedError

    @abstractmethod
    def ipython_plot(self, **kwargs):
        """Render the explanations inline in IPython / Jupyter."""
        raise NotImplementedError

    def dump(self, file):
        """Serialise this object into the open binary ``file`` with dill."""
        dill.dump(self, file)

    def load(self, file):
        """Deserialise and return an object from the open binary ``file`` (``self`` is unused).

        WARNING: dill can execute arbitrary code while loading; only use trusted files.
        """
        return dill.load(file)

    def dumps(self):
        """Serialise this object to bytes with dill."""
        return dill.dumps(self)

    def loads(self, byte_string):
        """Deserialise an object from dill bytes (``self`` is unused); trusted input only."""
        return dill.loads(byte_string)

    @staticmethod
    def _s(s, max_len=15):
        """Shorten a value for display.

        Strings longer than ``max_len`` are truncated and marked with ``*``;
        integral floats become ints, other floats a 3-decimal string; anything
        else is returned unchanged.
        """
        if isinstance(s, str):
            return s if len(s) <= max_len else s[:max_len] + "*"
        if isinstance(s, float):
            return int(s) if s.is_integer() else "{:.3f}".format(s)
        return s

    def to_json(self):
        """Serialise to a JSON string using the project's ``DefaultJsonEncoder``."""
        import json
        # NOTE(review): a relative import only works when this class lives inside a
        # package; executed from a notebook (module "__main__") it raises ImportError.
        from .utils import DefaultJsonEncoder
        return json.dumps(self, cls=DefaultJsonEncoder)

    @classmethod
    def from_json(cls, s):
        """Rebuild an explanation object from a JSON string produced by :meth:`to_json`."""
        import json
        return ExplanationBase.from_dict(json.loads(s))

    @classmethod
    def from_dict(cls, d):
        """Dispatch to ``from_dict`` of the class named by ``d["module"]`` / ``d["class"]``.

        NOTE(review): this imports an arbitrary module named in ``d``; only use it on trusted data.
        """
        import importlib
        explanation_class = getattr(importlib.import_module(d["module"]), d["class"])
        return explanation_class.from_dict(d["data"])

In [20]:
# Container for counterfactual explanations of image classifiers, with
# matplotlib and plotly visualisations.
class CFExplanation(ExplanationBase):
    """Collection of counterfactual explanations for images.

    Each entry is a dict with keys ``"image"`` (the original image), ``"label"``
    (its predicted label), ``"cf"`` (the counterfactual image, or ``None`` when
    none was found) and ``"cf_label"`` (the counterfactual's predicted label, or
    ``None``).
    """

    def __init__(self):
        super().__init__()
        self.explanations = []

    def __repr__(self):
        return repr(self.explanations)

    def add(self, image, label, cf, cf_label):
        """Append one explanation entry."""
        self.explanations.append({
            "image": image,
            "label": label,
            "cf": cf,
            "cf_label": cf_label
        })

    def get_explanations(self, index=None):
        """Return all entries (a list), or the single entry (a dict) at ``index``."""
        return self.explanations if index is None else self.explanations[index]

    @staticmethod
    def _rescale(im):
        """Min-max scale ``im`` to uint8 ``[0, 255]``; 2-D input is tiled to 3 channels."""
        min_val, max_val = np.min(im), np.max(im)
        im = (im - min_val) / (max_val - min_val + 1e-8) * 255
        if im.ndim == 2:  # Grayscale to RGB
            im = np.tile(im[..., np.newaxis], (1, 1, 3))
        return im.astype(np.uint8)

    @staticmethod
    def _panels(e, class_names=None):
        """Return ``(titles, images)`` for the original / counterfactual / difference panels of entry ``e``."""
        def _name(label):
            return str(class_names[label]) if class_names else str(label)

        image = np.asarray(e["image"])
        if e["cf"] is None:
            # No counterfactual was found: show blank panels instead of failing on `None - image`.
            blank = np.zeros(image.shape, dtype=float)
            return [_name(e["label"]), "CF: not found", "Difference"], [image, blank, blank]
        cf = np.asarray(e["cf"])
        # Subtract in float so unsigned-integer images cannot wrap around.
        diff = np.abs(cf.astype(float) - image.astype(float))
        return [_name(e["label"]), f"CF: {_name(e['cf_label'])}", "Difference"], [image, cf, diff]

    def plot(self, index=None, class_names=None, **kwargs):
        """Plot up to 5 explanations with matplotlib, one row of 3 panels each.

        :param index: entry to plot; ``None`` plots all (capped at 5 with a warning).
        :param class_names: optional mapping from label to display name.
        :return: the matplotlib figure, or ``None`` when there is nothing to plot.
        """
        import matplotlib.pyplot as plt
        import warnings

        explanations = self.get_explanations(index)
        explanations = (
            {index: explanations} if isinstance(explanations, dict) else dict(enumerate(explanations))
        )

        indices = sorted(explanations.keys())
        if len(indices) > 5:
            warnings.warn(f"Too many instances ({len(indices)} > 5), only showing first 5.")
            indices = indices[:5]
        if not indices:
            return None

        num_rows = len(indices)
        # squeeze=False keeps `axes` 2-D even for a single row.
        fig, axes = plt.subplots(num_rows, 3, figsize=(9, 3 * num_rows), squeeze=False)

        for row, idx in enumerate(indices):
            titles, images = self._panels(explanations[idx], class_names)
            for col in range(3):
                ax = axes[row, col]
                ax.imshow(self._rescale(images[col]))
                ax.set_title(titles[col])
                ax.axis("off")

        plt.tight_layout()
        return fig

    def _plotly_figure(self, index, class_names=None, **kwargs):
        """Build a 1x3 plotly figure (original, counterfactual, difference) for entry ``index``."""
        import plotly.express as px
        from plotly.subplots import make_subplots

        titles, images = self._panels(self.explanations[index], class_names)
        fig = make_subplots(rows=1, cols=3, subplot_titles=titles)

        for col, img in enumerate(images, start=1):
            fig.add_trace(px.imshow(self._rescale(img)).data[0], row=1, col=col)

        fig.update_layout(showlegend=False)
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
        return fig

    def plotly_plot(self, index=0, class_names=None, **kwargs):
        """Wrap the plotly figure for entry ``index`` in a ``DashFigure``.

        NOTE(review): ``DashFigure`` is not defined in this notebook.
        """
        assert index is not None, "`index` must be provided for plotly_plot."
        return DashFigure(self._plotly_figure(index, class_names=class_names, **kwargs))

    def ipython_plot(self, index=0, class_names=None, **kwargs):
        """Display the plotly figure for entry ``index`` inline via ``plotly.offline.iplot``."""
        import plotly
        assert index is not None, "`index` must be provided for ipython_plot."
        return plotly.offline.iplot(self._plotly_figure(index, class_names=class_names, **kwargs))

    @classmethod
    def from_dict(cls, d):
        """Rebuild from a dict with an ``"explanations"`` list of entry dicts."""
        obj = cls()
        obj.explanations = [
            {
                "image": np.array(e["image"]),
                "label": e["label"],
                # Keep a missing counterfactual as None rather than a 0-d object array.
                "cf": None if e["cf"] is None else np.array(e["cf"]),
                "cf_label": e["cf_label"],
            }
            for e in d["explanations"]
        ]
        return obj

## Counterfactual Explainer

In [27]:
from typing import Callable
class CounterfactualExplainer(ExplainerBase):
    """Counterfactual explainer for image classifiers (PyTorch or TF/Keras).

    For every image it runs a ``CounterfactualOptimizer`` that searches for a
    modified image receiving a different predicted label, and collects the
    results in a :class:`CFExplanation`.
    """

    explanation_type = "local"
    alias = ["ce", "counterfactual"]

    def __init__(
        self,
        model,
        preprocess_function: Callable,
        mode: str = "classification",
        c=10.0,
        kappa=10.0,
        binary_search_steps=5,
        learning_rate=1e-2,
        num_iterations=100,
        grad_clip=1e3,
        **kwargs,
    ):
        """
        :param model: classifier; a ``tf.keras.Model`` or a ``torch.nn.Module``.
        :param preprocess_function: maps an :class:`Image` batch to model inputs
            (array or tensor). If ``None``, ``Image.to_numpy()`` is used.
        :param mode: task type; only ``"classification"`` is supported.
        :param c: initial weight of the classification term; adjusted by binary search.
        :param kappa: margin used in the classification hinge term of the objective.
        :param binary_search_steps: number of values of ``c`` tried.
        :param learning_rate: initial gradient-descent step size.
        :param num_iterations: maximum descent steps per binary-search step.
        :param grad_clip: gradients are clipped elementwise to ``[-grad_clip, grad_clip]``.
        :param kwargs: accepted for interface compatibility; ignored.
        :raises AssertionError: if ``mode`` is not ``"classification"``.
        :raises ValueError: if ``model`` is neither a Keras model nor a torch module.
        """
        super().__init__()
        assert mode == "classification", "CE supports classification tasks only."

        model_type = None
        if is_tf_available():
            if isinstance(model, tf.keras.Model):
                model_type = "tf"
        if model_type is None and is_torch_available():
            if isinstance(model, nn.Module):
                model_type = "torch"
        if model_type is None:
            raise ValueError(f"`model` should be a tf.keras.Model " f"or a torch.nn.Module instead of {type(model)}")

        self.model = model
        self.preprocess_function = preprocess_function
        # Factory so each explained image gets a fresh optimizer with shared hyperparameters.
        self.create_optimizer = lambda x, y, m: CounterfactualOptimizer(
            x,
            y,
            m,
            c=c,
            kappa=kappa,
            binary_search_steps=binary_search_steps,
            learning_rate=learning_rate,
            num_iterations=num_iterations,
            grad_clip=grad_clip,
        )

    def _preprocess(self, inputs: Image):
        """Convert an :class:`Image` batch to a NumPy array of model inputs (``None`` if empty)."""
        if inputs.values is None:
            return None
        if self.preprocess_function is not None:
            inputs = self.preprocess_function(inputs)
            if not isinstance(inputs, np.ndarray):
                try:
                    inputs = inputs.detach().cpu().numpy()
                except AttributeError:
                    inputs = inputs.numpy()
        else:
            inputs = inputs.to_numpy()
        return inputs

    def _predict(self, inputs):
        """Return the predicted class indices (int array) for preprocessed ``inputs``."""
        if isinstance(self.model, nn.Module):
            self.model.eval()
            param = next(self.model.parameters())
            x = inputs if isinstance(inputs, torch.Tensor) else torch.tensor(inputs, dtype=torch.get_default_dtype())
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                scores = self.model(x.to(param.device)).detach().cpu().numpy()
        else:
            # TF/Keras path. Errors in the torch path are no longer swallowed by a bare `except:`.
            scores = self.model(inputs).numpy()
        return np.argmax(scores, axis=1).astype(int)

    @staticmethod
    def _to_display(arr):
        """Drop singleton axes and move a leading 3-channel axis last (CHW -> HWC)."""
        arr = arr.squeeze()
        if arr.ndim == 3 and arr.shape[0] == 3:
            arr = np.transpose(arr, (1, 2, 0))
        return arr

    def explain(self, X: Image, **kwargs) -> CFExplanation:
        """Compute a counterfactual for every image in ``X``.

        :param X: batch of images; height and width must both exceed 4.
        :param kwargs: ``verbose`` (default ``True``) shows the optimizer progress bar.
        :return: a :class:`CFExplanation` with one entry per image; ``cf`` and
            ``cf_label`` are ``None`` when no counterfactual was found.
        """
        assert min(X.shape[1:3]) > 4, f"The image size ({X.shape[1]}, {X.shape[2]}) is too small."
        # The key was previously misspelled as "kwargs", so `verbose=False` was ignored.
        verbose = kwargs.get("verbose", True)
        explanations = CFExplanation()
        y = self._predict(self._preprocess(X))

        for i in range(len(X)):
            x = self._preprocess(X[i])
            optimizer = self.create_optimizer(x=x, y=y[i], m=self.model)
            # Get the counterfactual example
            cf = optimizer.optimize(verbose=verbose)
            if cf is not None:
                cf_label = self._predict(cf)[0]
                cf = self._to_display(cf)
            else:
                cf_label = None
            explanations.add(image=self._to_display(x), label=y[i], cf=cf, cf_label=cf_label)
        return explanations

In [29]:
import importlib
# Wrap the trained classifier; `preprocess` turns an `Image` batch into model-ready tensors.
explainer = CounterfactualExplainer(model, preprocess)

## Optimizer Module

In [30]:
# Optimizer module: gradient-based search for a counterfactual example.
class CounterfactualOptimizer:
    """Gradient-based search for a counterfactual of a single input ``x0``.

    ``optimize`` runs ``binary_search_steps`` rounds. Each round builds an
    objective for the current trade-off constant ``c`` (``_init_functions``) and
    takes up to ``num_iterations`` projected gradient steps, clipping the iterate
    to ``bounds``. An iterate is accepted whenever the objective's second output
    is negative; among accepted iterates the one with the smallest L1 distance to
    ``x0`` is returned. ``c`` is then updated by ``_update_const``.
    """

    def __init__(
            self,
            x0,
            target,
            model,
            c=10.0,
            kappa=10.0,
            binary_search_steps=5,
            learning_rate=1e-2,
            num_iterations=1000,
            grad_clip=1e3,
            gamma=None,
            bounds=None,
    ):
        """
        :param x0: the input to explain, with a leading batch axis of size 1
            (NumPy array, or a tensor exposing ``.numpy()``).
        :param target: label handed to the objective (the explainer passes the
            model's original prediction).
        :param model: classifier (torch module, Keras model, or other callable).
        :param c: initial weight of the classification term.
        :param kappa: margin passed to the objective.
        :param binary_search_steps: number of rounds adjusting ``c``.
        :param learning_rate: initial step size (decayed by ``_learning_rate``).
        :param num_iterations: maximum gradient steps per round.
        :param grad_clip: elementwise gradient clipping threshold.
        :param gamma: optional per-feature weights for the L1 distance.
        :param bounds: ``(low, high)`` box for iterates; defaults to ``x0``'s min/max.
        """
        assert x0.shape[0] == 1
        if not isinstance(x0, np.ndarray):
            try:
                x0 = x0.detach().cpu().numpy()
            except AttributeError:
                x0 = x0.numpy()

        self.x0 = x0
        self.target = target
        self.model = model
        self.c = c
        self.kappa = kappa
        self.binary_search_steps = binary_search_steps
        self.learning_rate = learning_rate
        self.num_iterations = num_iterations
        self.grad_clip = grad_clip
        self.gamma = gamma
        self.bounds = (np.min(x0), np.max(x0)) if bounds is None else bounds

        self.model_type = None
        if is_torch_available():
            if isinstance(self.model, nn.Module):
                self.model_type = "torch"
        if self.model_type is None and is_tf_available():
            if isinstance(self.model, tf.keras.Model):
                self.model_type = "tf"
        if self.model_type is None:
            self.model_type = "other"
        self.num_classes = self._predict(self.x0).shape[1]

    def _init_functions(self, c):
        """Build the objective ``self.func`` for trade-off constant ``c``."""
        if self.model_type == "torch":
            self.func = _ObjectiveTorch(self.x0, self.target, self.model, c=c, kappa=self.kappa, gamma=self.gamma)
        elif self.model_type == "tf":
            self.func = _ObjectiveTF(self.x0, self.target, self.model, c=c, kappa=self.kappa, gamma=self.gamma)
        else:
            self.func = self.model

    def _predict(self, inputs):
        """Return the model's raw outputs for ``inputs`` as a NumPy array."""
        if self.model_type == "tf":
            inputs = tf.convert_to_tensor(inputs, dtype=tf.keras.backend.floatx())
            return self.model(inputs).numpy()
        elif self.model_type == "torch":
            self.model.eval()
            param = next(self.model.parameters())
            inputs = torch.tensor(inputs, dtype=torch.get_default_dtype(), device=param.device)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                return self.model(inputs).detach().cpu().numpy()
        else:
            return self.model(inputs)

    def _compute_gradient(self, model, inputs):
        """Return ``(clipped gradient of the objective w.r.t. inputs, margin loss)``.

        ``model`` is an objective returning ``(objective, loss)``.
        :raises NotImplementedError: for models that are neither torch nor tf.
        """
        if self.model_type == "tf":
            inputs = tf.convert_to_tensor(inputs, dtype=tf.keras.backend.floatx())
            with tf.GradientTape() as tape:
                tape.watch(inputs)
                predictions, loss = model(inputs)
                gradients = tape.gradient(predictions, inputs).numpy()
                loss = loss.numpy()

        elif self.model_type == "torch":
            # Imported here so this method does not depend on a later notebook cell.
            from torch.autograd import grad

            model.eval()
            param = next(model.parameters())
            # Create the leaf tensor directly on the model's device so gradients are taken w.r.t. it.
            inputs = torch.tensor(inputs, dtype=torch.get_default_dtype(), device=param.device, requires_grad=True)
            objective, loss = model(inputs)
            gradients = (
                grad(outputs=objective, inputs=inputs, grad_outputs=torch.ones_like(objective))[0]
                .detach()
                .cpu()
                .numpy()
            )
            loss = loss.detach().cpu().numpy()
        else:
            # Can also apply numerical differentiation to achieve better results
            raise NotImplementedError
        gradients = np.clip(gradients, -self.grad_clip, self.grad_clip)
        return gradients, loss

    def _learning_rate(self, i):
        """Step size at iteration ``i``: ``learning_rate * sqrt(1 - i / num_iterations)``."""
        return self.learning_rate * (1 - i / self.num_iterations) ** 0.5

    @staticmethod
    def _update_const(c, c_lb, c_ub, sol):
        """Binary-search update of ``c``; returns ``(c, c_lb, c_ub)``.

        If the round found a solution, ``c`` becomes an upper bound. Otherwise it
        becomes a lower bound, and ``c`` is multiplied by 10 until an upper bound
        has been found.
        """
        if sol is not None:
            c_ub = min(c_ub, c)
            if c_ub < 1e9:
                c = (c_lb + c_ub) * 0.5
        else:
            c_lb = max(c_lb, c)
            if c_ub < 1e9:
                c = (c_lb + c_ub) * 0.5
            else:
                c *= 10
        return c, c_lb, c_ub

    def optimize(self, verbose=True) -> np.ndarray:
        """Run the search; return the best counterfactual found (batch of 1), or ``None``.

        NOTE(review): ``verbose=True`` needs ``ProgressBar``, which this notebook does not define.
        """
        bar = ProgressBar(self.num_iterations) if verbose else None

        c_lb, c_ub, c = 0, 1e10, self.c
        # Start from +inf so an accepted solution is recorded however large the
        # input (a fixed 1e8 silently rejected large-image counterfactuals).
        best_solution, best_loss = None, np.inf
        for step in range(self.binary_search_steps):
            self._init_functions(c)
            x = self.x0.copy()
            current_best_sol, current_best_loss = None, np.inf

            for iteration in range(self.num_iterations):
                # Compute the gradient and loss
                gradient, loss = self._compute_gradient(self.func, x)
                if loss < 0:
                    # Negative margin: the current iterate is a counterfactual; keep the closest (L1).
                    f = np.sum(np.abs(x - self.x0))
                    if f < current_best_loss:
                        current_best_loss, current_best_sol = f, x
                    if f < best_loss:
                        best_loss, best_solution = f, x
                # Projected gradient step
                new_x = x - self._learning_rate(iteration) * gradient
                new_x = np.minimum(np.maximum(new_x, self.bounds[0]), self.bounds[1])
                if np.sum(np.abs(x - new_x)) < 1e-6:
                    break
                x = new_x
                if verbose:
                    bar.print(iteration, prefix=f"Binary step: {step + 1}", suffix="")

            c, c_lb, c_ub = self._update_const(c, c_lb, c_ub, current_best_sol)
        return best_solution

In [32]:
from torch.autograd import grad
class _ObjectiveTorch(nn.Module):
        def __init__(self, x0, target, model, c, kappa, gamma=None):
            super().__init__()
            param = next(model.parameters())
            if isinstance(x0, np.ndarray):
                self.x0 = torch.tensor(x0, dtype=torch.get_default_dtype()).to(param.device)
            else:
                self.x0 = x0.to(param.device)
            self.num_classes = model(self.x0).shape[1]
            self.target = torch.tensor(np.eye(1, self.num_classes, target), dtype=torch.get_default_dtype()).to(
                param.device
            )
            if gamma is None:
                self.gamma = 1
            else:
                self.gamma = torch.tensor(
                    np.expand_dims(np.abs(gamma) + 1e-8, axis=0), dtype=torch.get_default_dtype()
                ).to(param.device)

            self.model = model.eval()
            self.c = c
            self.kappa = kappa
            self.reduce_dims = list(range(1, len(x0.shape)))

        def forward(self, x):
            # Regularization term
            regularization = torch.sum(torch.abs(self.x0 - x) / self.gamma, dim=self.reduce_dims)
            # Loss function
            prob = self.model(x)
            a = torch.sum(prob * self.target, dim=1)
            b = torch.max((1 - self.target) * prob - self.target * 10000, dim=1)[0]
            loss = nn.functional.relu(a - b + self.kappa)
            return torch.mean(self.c * loss + regularization), torch.mean(a - b)

## Displaying the Results

In [33]:
# Explain the first five test images, then plot the fifth one interactively.
explanations = explainer.explain(x_test[:5])
explanations.ipython_plot(index=4)

Binary step: 5 |███████████████████████████████████████-| 99.0% 