# Finding counterfactuals on the data manifold 

In this notebook we show how to find counterfactuals that lie on the data manifold and how they differ from adversarial examples, which lie off the data manifold.

Let's assume our data lies on a one-dimensional manifold, a helix, embedded in three-dimensional space. We define a simple classifier that divides the data into two classes. To approximate the data manifold we train a normalizing flow. With the classifier $f$ alone we can produce adversarial examples by doing gradient ascent in the data space $X$:

$$
x^{(t+1)} = x^{(t)} + \lambda \frac{\partial f}{\partial x}(x^{(t)})
$$

where $\lambda$ is the learning rate.
The resulting adversarial examples often lie off the data manifold.


In contrast, if we search for counterfactuals by doing gradient ascent in the latent space $Z$ of our normalizing flow $g$ we stay (approximately) on the data manifold:

$$
z^{(t+1)} = z^{(t)} + \lambda \frac{\partial (f\circ g)}{\partial z}(z^{(t)})
$$

For more detail please refer to the paper: 




## Data distribution

Let's start by defining a uniform data distribution on a helix. 

In [None]:
# some imports and plot definitions

import torch
import numpy as np
import scipy
import os
from tqdm import tqdm
from scipy.optimize import NonlinearConstraint

import torch.optim as optim
import torch.nn as nn

import matplotlib
from matplotlib import pyplot as plt
from matplotlib import rc

# Global plot style. With "text.usetex" off, Matplotlib renders math with its
# built-in mathtext engine ("cm" = Computer Modern look) instead of calling LaTeX.
# (Earlier rc('text', usetex=True) / plt.rc('font', family='serif') calls were
# removed: this update overrode them immediately, so they had no effect.)
plt.rcParams.update({
    "text.usetex": False,
    "font.family": "serif",
    "font.serif": ["Nimbus Roman"],
    "mathtext.fontset": "cm",
    "font.size": 30,
})

label_font_size = 32
title_font_size = 32

# make output directories (no-op for directories that already exist)
directories = ["models", "plots", "results"]
for d in directories:
    os.makedirs(d, exist_ok=True)

In [None]:
# our target distribution
class Helix:
    """Uniform distribution on a finite segment of a helix in R^3.

    Points are ``(r*sin(z), r*cos(z), z)``, where the curve parameter ``z`` is
    drawn uniformly from ``[-sigma/2, sigma/2)``.

    Args:
        r: radius of the helix.
        sigma: width of the interval the curve parameter ``z`` is drawn from
            (a range, not a standard deviation).
    """

    def __init__(self, r=1, sigma=8):
        self.sigma = sigma
        self.r = r

    def sample(self, num_samples):
        """Draw ``num_samples`` points; returns a float tensor of shape (num_samples, 3)."""
        z = (torch.rand(num_samples) - 0.5) * self.sigma

        samples = torch.zeros([num_samples, 3])
        samples[:, 0] = self.r * torch.sin(z)
        samples[:, 1] = self.r * torch.cos(z)
        samples[:, 2] = z
        return samples

    def get_class(self, samples):
        """Label points: 1.0 where the third coordinate is > 0, else 0.0."""
        return (samples[:, 2] > 0).to(dtype=torch.float32)

    def dist(self, sample):
        """Return the Euclidean distance from ``sample`` to the helix curve.

        The helix is parametrised by ``z``, so the distance is the square root
        of the 1-D, smooth, unconstrained problem

            min_z (r sin z - p0)^2 + (r cos z - p1)^2 + (z - p2)^2.

        (The previous formulation used an equality constraint written as a
        norm, which is non-differentiable exactly on the feasible set.) The
        curve is treated as unbounded in ``z``: the ``sigma`` sampling range is
        not enforced. The search is local and starts at ``z = p2``. For points
        far from the helix, it could therefore settle on a winding that is not
        the nearest one.

        Args:
            sample: one 3-D point, as a torch tensor or array-like; singleton
                dimensions such as shape (1, 3) are flattened.

        Returns:
            The distance as a Python float.

        Raises:
            ValueError: if ``sample`` does not contain exactly 3 values.
        """
        if isinstance(sample, torch.Tensor):
            sample = sample.detach().cpu().numpy()
        p = np.asarray(sample, dtype=np.float64).reshape(-1)
        if p.shape != (3,):
            raise ValueError(f"expected a single 3-D point, got {p.size} values")
        r = self.r

        def sq_dist(z):
            z = z[0]
            return (r * np.sin(z) - p[0]) ** 2 + (r * np.cos(z) - p[1]) ** 2 + (z - p[2]) ** 2

        def sq_dist_grad(z):
            z = z[0]
            return np.array([2.0 * r * (p[1] * np.sin(z) - p[0] * np.cos(z)) + 2.0 * (z - p[2])])

        res = scipy.optimize.minimize(sq_dist, x0=np.array([p[2]]), jac=sq_dist_grad, method="BFGS")
        # guard against tiny negative round-off before the square root
        return float(np.sqrt(max(float(res.fun), 0.0)))

In [None]:
# plot function
def scatter_plot(ax, samples, title=' ', lim_min=None, lim_max=None, alpha=1.0, s=1, color=None, label=None):
    """Scatter 3-D ``samples`` onto the 3-D axes ``ax`` and label the axes.

    Column k of ``samples`` (``samples[:, k]``) is plotted on axis k. Axis
    limits are applied only when BOTH ``lim_min`` and ``lim_max`` are given;
    each is indexed per coordinate (0, 1, 2).
    """
    first, second, third = (samples[:, k] for k in range(3))
    ax.scatter(first, second, third, alpha=alpha, s=s, c=color, label=label)
    ax.set_title(title, fontsize=title_font_size)
    ax.set_xlabel("\n" + r"$x_1$")
    ax.set_ylabel("\n" + r"$x_2$")
    ax.set_zlabel("\n" + r"$x_3$")
    if lim_min is None or lim_max is None:
        return
    for k, set_limits in enumerate((ax.set_xlim, ax.set_ylim, ax.set_zlim)):
        set_limits(lim_min[k], lim_max[k])

In [None]:
target_distr = Helix()

samples_target = target_distr.sample(500)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Per-axis plot limits taken from the samples (the earlier global-min/max
# limits were overwritten immediately and have been removed).
lim_min = torch.min(samples_target, dim=0)[0].detach().numpy()
lim_max = torch.max(samples_target, dim=0)[0].detach().numpy()
scatter_plot(ax, samples_target.detach().numpy(), 'target', lim_min, lim_max, s=10)
plt.show()

## Flow

Now let's define a normalizing flow and train it to map samples from a multivariate standard normal distribution to the data distribution.

In [None]:
# definition of latent distribution: standard normal on R^3
# (zero mean vector, identity covariance matrix)
latent_distr = torch.distributions.MultivariateNormal(
    loc=torch.zeros(3),
    covariance_matrix=torch.eye(3),
)

In [None]:
# definition of the normalizing flow

def nets():
    """Build the scale network s(.) of one coupling layer; Tanh bounds its output to (-1, 1)."""
    return nn.Sequential(
        nn.Linear(3, 256), nn.LeakyReLU(),
        nn.Linear(256, 256), nn.LeakyReLU(),
        nn.Linear(256, 3), nn.Tanh(),
    )


def nett():
    """Build the translation network t(.) of one coupling layer (unbounded output)."""
    return nn.Sequential(
        nn.Linear(3, 256), nn.LeakyReLU(),
        nn.Linear(256, 256), nn.LeakyReLU(),
        nn.Linear(256, 3),
    )


class RealNVP(torch.nn.Module):
    """RealNVP normalizing flow on R^3 made of 12 affine coupling layers.

    Layer i keeps the coordinates where ``mask[i] == 1`` and applies
    ``x * exp(s) + t`` to the others, with ``s``/``t`` computed from the kept
    coordinates. ``forward`` maps latent -> data; ``reverse`` maps data ->
    latent and also returns the log-determinant of that inverse map.
    """

    def __init__(self):
        super().__init__()
        # six binary coordinate masks, the whole set used twice -> 12 layers
        patterns = [[0, 1, 1], [1, 0, 1], [1, 1, 0],
                    [0, 0, 1], [1, 0, 0], [0, 1, 0]]
        masks = torch.tensor(patterns * 2, dtype=torch.float32)
        # non-trainable Parameter (requires_grad=False) so it is part of the module's state
        self.mask = nn.Parameter(masks, requires_grad=False)
        n_layers = masks.shape[0]
        # construction order (all t-nets, then all s-nets) determines the random initialisation
        self.t = torch.nn.ModuleList(nett() for _ in range(n_layers))
        self.s = torch.nn.ModuleList(nets() for _ in range(n_layers))

    def forward(self, z):
        """Push latent points ``z`` (N, 3) through the coupling layers in order."""
        x = z
        for mask, s_net, t_net in zip(self.mask, self.s, self.t):
            kept = x * mask
            free = 1 - mask
            scale = s_net(kept) * free
            shift = t_net(kept) * free
            x = kept + free * (x * torch.exp(scale) + shift)
        return x

    def reverse(self, x):
        """Map data points ``x`` (N, 3) back to latent space.

        Returns:
            ``(z, log_det_J)``: the latent points and, per sample, the
            log-determinant of the Jacobian of this inverse map (shape (N,)).
        """
        z = x
        log_det_J = x.new_zeros(x.shape[0])
        layers = list(zip(self.mask, self.s, self.t))
        for mask, s_net, t_net in reversed(layers):
            kept = mask * z
            free = 1 - mask
            scale = s_net(kept) * free
            shift = t_net(kept) * free
            z = free * (z - shift) * torch.exp(-scale) + kept
            log_det_J -= scale.sum(dim=1)
        return z, log_det_J

In [None]:
# funtion to train the flow

def train_flow(flow, latent_distr, target_distr, batch_size, epochs, optimizer, save_as, best_loss = np.inf):
    avg_loss = 0
    flow.train()
    with tqdm(total=epochs) as progress_bar:
        for epoch in range(epochs):
            samples = target_distr.sample(batch_size)
            z, log_det = flow.reverse(samples)
            log_prob = latent_distr.log_prob(z)
            loss = -(log_det.mean() + log_prob.mean())
            flow.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss = moving_avg(avg_loss, loss, epoch + 1).item()
            progress_bar.set_postfix(loss=avg_loss)
            progress_bar.update(1)

            if (epoch+1)%100==0:
                flow.eval()
                avg_loss_val = 0
                for epoch_val in range(100):
                    samples = target_distr.sample(batch_size)
                    z, log_det = flow.reverse(samples)
                    log_prob = latent_distr.log_prob(z)
                    loss = -(log_det.mean() + log_prob.mean())
                    avg_loss_val = moving_avg(avg_loss_val, loss, epoch_val + 1).item()

                if avg_loss_val<best_loss:
                    best_loss=avg_loss_val
                    torch.save(flow, save_as)
                flow.train()
                    
    flow.eval()
    avg_loss = 0
    with tqdm(total=100) as progress_bar:
        for epoch in range(100):
            samples = target_distr.sample(batch_size)
            z, log_det = flow.reverse(samples)
            log_prob = latent_distr.log_prob(z)
            loss = -(log_det.mean() + log_prob.mean())
            avg_loss = moving_avg(avg_loss, loss, epoch + 1).item()
            progress_bar.set_postfix(loss=avg_loss)
            progress_bar.update(1)

    return 1

def moving_avg(current_avg, new_value, idx):
    return current_avg + (new_value - current_avg) / idx

In [None]:
# train flow or load checkpoint

flow = RealNVP()
lr = 1e-4
epochs = 5000
batch_size = 500
optimizer = optim.Adam(flow.parameters(), lr=lr)
save_as = "models/flow.pth"

if os.path.isfile(save_as):
    # The checkpoint is a whole pickled nn.Module (train_flow calls torch.save(flow, ...)).
    # Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
    # unpickle modules, so opt out explicitly. Only load checkpoints you created yourself.
    flow = torch.load(save_as, weights_only=False)
else:
    train_flow(flow, latent_distr, target_distr, batch_size, epochs, optimizer, save_as)


Let's check what our flow has learned:

In [None]:
num_samples = 1500
torch.manual_seed(1)
samples_latent = latent_distr.sample((num_samples,))
samples_target = target_distr.sample(num_samples)
with torch.no_grad():
    # image of an independent latent batch under the trained flow
    samples_flow = flow.forward(latent_distr.sample((num_samples,)))

# shared per-axis limits covering both the true and the learned samples
tensors = torch.cat([samples_target, samples_flow], dim=0)
lim_min = torch.min(tensors, dim=0)[0].detach().numpy()
lim_max = torch.max(tensors, dim=0)[0].detach().numpy()

fig = plt.figure(figsize=(21, 8))
ax = fig.add_subplot(1, 3, 1, projection='3d')
scatter_plot(ax, samples_latent.detach().numpy(), 'latent distribution', lim_min*2, lim_max*2)
ax = fig.add_subplot(1, 3, 2, projection='3d')
scatter_plot(ax, samples_target.detach().numpy(), 'target distribution', lim_min, lim_max)
ax = fig.add_subplot(1, 3, 3, projection='3d')
scatter_plot(ax, samples_flow.detach().numpy(), 'learned distribution', lim_min, lim_max)
fig.savefig('plots/learned_distribution.png')
plt.show()

## Adding a classifier

We add a classifier. For simplicity, points whose third coordinate ($x_3$) is smaller than zero belong to one class, and points whose third coordinate is larger than zero belong to the other class.
The binary classifier is a simple neural network with one hidden layer (256 neurons) and a single output neuron. A sigmoid maps the output to the interval $(0, 1)$.

In [None]:
# classifier definition: 3 -> 256 (ReLU) -> 1, squashed into (0, 1) by a sigmoid
classifier = nn.Sequential(nn.Linear(3, 256), nn.ReLU(),
                           nn.Linear(256, 1), nn.Sigmoid())


# training function
def train_classifier(classifier, target_distr, batch_size, epochs, optimizer, save_as):
    """Train ``classifier`` with binary cross-entropy on samples from ``target_distr``.

    Each "epoch" is one optimizer step on a fresh batch of ``batch_size``
    points labelled by ``target_distr.get_class``. After training, the whole
    module is saved with ``torch.save(classifier, save_as)``. Its loss over
    100 fresh batches is then shown, computed in eval mode without gradients.

    Relies on ``moving_avg`` and ``tqdm`` from the notebook's namespace.

    Returns:
        1 (kept for backward compatibility).
    """
    loss_fun = torch.nn.BCELoss(reduction="mean")

    avg_loss = 0
    classifier.train()
    with tqdm(total=epochs) as progress_bar:
        for epoch in range(epochs):
            samples = target_distr.sample(batch_size)
            target = target_distr.get_class(samples)
            # squeeze only the output dim so batch_size == 1 keeps shape (1,) like target
            prediction = classifier.forward(samples).squeeze(-1)
            loss = loss_fun(prediction, target)
            classifier.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss = moving_avg(avg_loss, loss, epoch + 1).item()
            progress_bar.set_postfix(loss=avg_loss)
            progress_bar.update(1)

    torch.save(classifier, save_as)

    # Evaluate the classifier itself (this previously called the global flow.eval() by mistake).
    classifier.eval()
    avg_loss = 0
    with torch.no_grad(), tqdm(total=100) as progress_bar:
        for epoch in range(100):
            samples = target_distr.sample(batch_size)
            target = target_distr.get_class(samples)
            prediction = classifier.forward(samples).squeeze(-1)
            loss = loss_fun(prediction, target)
            avg_loss = moving_avg(avg_loss, loss, epoch + 1).item()
            progress_bar.set_postfix(loss=avg_loss)
            progress_bar.update(1)

    return 1

        

In [None]:
# train classifier or load checkpoint

epochs = 3000
batch_size = 500
optimizer = optim.Adam(classifier.parameters(), lr=1e-4)
save_as = "models/classifier.pth"

if os.path.isfile(save_as):
    # The checkpoint is a whole pickled nn.Module (saved with torch.save(classifier, ...)).
    # Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
    # unpickle modules, so opt out explicitly. Only load checkpoints you created yourself.
    classifier = torch.load(save_as, weights_only=False)
else:
    train_classifier(classifier, target_distr, batch_size, epochs, optimizer, save_as)


# test classifier on original data (inference only, no autograd graph needed)
samples_manifold = target_distr.sample(500)
with torch.no_grad():
    predictions = classifier(samples_manifold).squeeze()
classes = target_distr.get_class(samples_manifold).squeeze()
correct = torch.sum(torch.round(predictions) == classes).item()
print(f"classifier accuracy on data from target distribution:\t{correct*100/len(samples_manifold):.2f}%")

fig = plt.figure(figsize=(16, 8))
ax = fig.add_subplot(1, 2, 1, projection='3d')
mask = predictions > 0.5
samples_manifold1 = samples_manifold[mask, :].detach().numpy()
samples_manifold2 = samples_manifold[~mask, :].detach().numpy()

scatter_plot(ax, samples_manifold1, ' ', alpha=.2, lim_min=lim_min, lim_max=lim_max, s=64, color="tab:orange")
scatter_plot(ax, samples_manifold2, 'classification of data samples', alpha=.2, lim_min=lim_min, lim_max=lim_max, s=64, color="gray")

# test classifier on generated data
z = latent_distr.sample((500,))
with torch.no_grad():
    samples = flow.forward(z)
    predictions = classifier(samples).squeeze()
classes = target_distr.get_class(samples).squeeze()
correct = torch.sum(torch.round(predictions) == classes).item()
print(f"classifier accuracy on data from learned distribution:\t{correct*100/len(samples):.2f}%")

mask = predictions > 0.5
samples_flow1 = samples[mask, :].detach().numpy()
samples_flow2 = samples[~mask, :].detach().numpy()

ax = fig.add_subplot(1, 2, 2, projection='3d')
scatter_plot(ax, samples_flow1, ' ', alpha=.2, lim_min=lim_min, lim_max=lim_max, s=64, color="tab:orange")
scatter_plot(ax, samples_flow2, 'classification on generated samples', alpha=.2, lim_min=lim_min, lim_max=lim_max, s=64, color="gray")
fig.savefig('plots/predictions_from_classifier.png')
plt.show()

## Finding adversarial examples and counterfactuals

Now that we have a classifier and a flow we can generate adversarial examples and counterfactuals.

In [None]:
# function for finding adversarial examples
def conv_attack(classifier, x_org, target, steps, lr=1e-3):
    
    x = x_org.clone()
    x.requires_grad = True
    
    optimizer = optim.Adam(params=[x], lr=lr)
    xs = []
    
    for i in range(steps):
        xs.append(x.detach().clone())
        optimizer.zero_grad()
        prediction = classifier.forward(x).squeeze()
        
        if target >= 0.5 and prediction >= target:
            break
        if target < 0.5 and prediction <= target:
            break
            
        loss = (prediction-target)**2
        loss.backward()
        optimizer.step()
    
    return xs

# function for finding counterfactuals on the data manifold
def z_attack(classifier, flow, x_org, target, steps, lr=1e-3):
    
    z, _ = flow.reverse(x_org)
    z = z.detach()
    z.requires_grad = True
    xs = []

    optimizer = optim.Adam(params=[z], lr=lr)
    
    for i in range(steps):
        x = flow.forward(z)
        xs.append(x.detach().clone())
        optimizer.zero_grad()
        prediction = classifier.forward(x).squeeze()
        
        if target >= 0.5 and prediction >= target:
            break
        if target < 0.5 and prediction <= target:
            break
            
        loss = (prediction-target)**2
        loss.backward()
        optimizer.step()   

    return xs


We take gradient ascent steps until we reach a target value of $f(x^\prime)\geq0.9$ if the original data point was predicted to belong to class 0 $(f(x)<0.5)$, and $f(x^\prime)\leq0.1$ if the original data point was predicted to belong to class 1 $(f(x)\geq0.5)$.
We plot the steps of the attack in $X$ as well as in $Z$ space to see how the coordinates of the modified points $x^\prime$ and $g(z^\prime)$ change during the attack. As expected, the attack in $X$ quickly leads off the data manifold, while the attack in $Z$ stays on the data manifold.

In [None]:
# adv attack for one sample

# define original datapoint on Helix (curve parameter z_value, radius 1)
z_value = 2
x_org = torch.tensor([[float(np.sin(z_value)), float(np.cos(z_value)), float(z_value)]], dtype=torch.float32)
with torch.no_grad():
    pred_x_org = classifier.forward(x_org).squeeze().item()
print(f"x_org:                 {x_org}")
print(f"prediction x_org:      {pred_x_org:.4f}\n")

steps = 1000
# attack towards the opposite class
target = 0.9 if pred_x_org < 0.5 else 0.1

# conventional adversarial attack (gradient steps in data space X)
xs_conv = conv_attack(classifier, x_org, target, steps, lr=2e-2)
with torch.no_grad():
    pred_x_conv = classifier.forward(xs_conv[-1]).squeeze().item()
print(f"conv_attack steps:     {len(xs_conv)}")
print(f"x_conv:                {xs_conv[-1]}")
print(f"prediction x_conv:     {pred_x_conv:.4f}")
print(f"dist to manifold:      {target_distr.dist(xs_conv[-1].squeeze()):.4f}\n")

xs_conv = torch.cat(xs_conv).numpy()

# attack in the latent space of the flow
xs_z = z_attack(classifier, flow, x_org, target, steps, lr=2e-2)
with torch.no_grad():
    pred_x_z = classifier.forward(xs_z[-1]).squeeze().item()
print(f"z-attack steps:        {len(xs_z)}")
print(f"x_z:                   {xs_z[-1]}")
print(f"prediction x_z:        {pred_x_z:.4f}")
print(f"dist to manifold:      {target_distr.dist(xs_z[-1].squeeze()):.4f}")

xs_z = torch.cat(xs_z).numpy()

fig = plt.figure(figsize=(14, 8))
ax = fig.add_subplot(1, 1, 1, projection='3d')

scatter_plot(ax, samples_flow1[:150], ' ', alpha=.2, lim_min=lim_min, lim_max=lim_max, s=64, color="tab:orange")
scatter_plot(ax, samples_flow2[:150], r'Attacks in $\mathcal{X}$ and $\mathcal{Z}$', alpha=.2, lim_min=lim_min, lim_max=lim_max, s=64, color="gray")

ax.plot(xs_conv[:, 0], xs_conv[:, 1], xs_conv[:, 2], lw=3, color="tab:red", zorder=15, label='grad asc in $\\mathcal{X}$')
ax.plot(xs_z[:, 0], xs_z[:, 1], xs_z[:, 2], lw=3, color="tab:green", zorder=15, label='grad asc in $\\mathcal{Z}$')
ax.plot(x_org[0, 0], x_org[0, 1], x_org[0, 2], color="tab:blue", marker='x', linestyle="None", zorder=15, markersize=30, markeredgewidth=4, label=r'$x$')
ax.plot(xs_conv[-1, 0], xs_conv[-1, 1], xs_conv[-1, 2], color="tab:red", linestyle="None", marker='x', zorder=10, markersize=17, markeredgewidth=3, label=r'$x^\prime$')
ax.plot(xs_z[-1, 0], xs_z[-1, 1], xs_z[-1, 2], color="tab:green", marker='x', linestyle="None", zorder=10, markersize=17, markeredgewidth=3, label=r'$g(z^\prime)$')

ax.legend(fontsize=label_font_size, bbox_to_anchor=[1.1, 0.9])
ax.set_xticks([-1, 0, 1])
ax.set_yticks([-1, 0, 1])
ax.set_zticks([-3, 0, 3])
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
for tick in ax.zaxis.get_major_ticks():
    # Tick.label is a deprecated alias of Tick.label1 in recent Matplotlib releases
    tick.label1.set_fontsize(30)

plt.subplots_adjust(left=-0.3, bottom=0.01, right=0.95, top=0.99, wspace=0.2, hspace=0.05)
fig.savefig('plots/adv_attack.png')
plt.show()

### Statistical evaluation

Let's run more attacks and look at some statistics for the retrieved adversarial examples and counterfactuals. Since the helix, and thus the data manifold, has an analytical definition, we can easily calculate the distance from a found adversarial example or counterfactual to the data manifold. Comparing these distances, we see that they are typically much larger for adversarial examples found in $X$ than for counterfactuals found in $Z$.

In [None]:
# adv attacks statistics
num_attacks = 100

# NaN marks attacks that did not reach their target; those entries are dropped below.
dists_conv = np.full(num_attacks, np.nan)
dists_z = np.full(num_attacks, np.nan)

dists_conv_file = "results/dists_conv.txt"
dists_z_file = "results/dists_z.txt"


def _attack_succeeded(classifier, x, target):
    """Whether the classifier output at ``x`` crossed ``target`` (>= for targets >= 0.5, else <=)."""
    with torch.no_grad():
        prediction = classifier(x).item()
    return prediction >= target if target >= 0.5 else prediction <= target


with tqdm(total=num_attacks) as progress_bar:
    for i in range(num_attacks):
        # one reproducible starting point per attack
        torch.manual_seed(i)
        x_org = target_distr.sample(1)
        with torch.no_grad():
            pred_x_org = classifier.forward(x_org).squeeze().item()

        target = 0.9 if pred_x_org < 0.5 else 0.1

        # conv attack; keep the distance only if the attack was successful
        xs_conv = conv_attack(classifier, x_org, target, steps=1000, lr=3e-2)
        if _attack_succeeded(classifier, xs_conv[-1], target):
            dists_conv[i] = target_distr.dist(xs_conv[-1].squeeze())

        # z attack; keep the distance only if the attack was successful
        xs_z = z_attack(classifier, flow, x_org, target, steps=1000, lr=3e-2)
        if _attack_succeeded(classifier, xs_z[-1], target):
            dists_z[i] = target_distr.dist(xs_z[-1].squeeze())

        progress_bar.update()

np.savetxt(dists_conv_file, dists_conv)
np.savetxt(dists_z_file, dists_z)

dists_conv = dists_conv[~np.isnan(dists_conv)]
dists_z = dists_z[~np.isnan(dists_z)]

In [None]:
print(
    "successful attacks:",
    f"             in X: {len(dists_conv)}/{num_attacks}",
    f"             in Z: {len(dists_z)}/{num_attacks}",
    sep="\n",
)

In [None]:

# Reload the saved distances. Failed attacks were stored as NaN and must be
# filtered out again: percentiles of data containing NaN are NaN, which would
# break the boxplot statistics.
dists_conv = np.atleast_1d(np.loadtxt(dists_conv_file))
dists_z = np.atleast_1d(np.loadtxt(dists_z_file))
dists_conv = dists_conv[~np.isnan(dists_conv)]
dists_z = dists_z[~np.isnan(dists_z)]

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(1, 1, 1)
ax.grid()
c = "k"
c2 = "lightblue"
ax.boxplot([dists_conv, dists_z], positions=[1, 2], widths=.6, notch=False, patch_artist=True,
           boxprops=dict(facecolor=c2, color=c),
           capprops=dict(color=c),
           whiskerprops=dict(color=c),
           flierprops=dict(color=c, markeredgecolor=c),
           medianprops=dict(color="red", lw=2),
           )

ax.set_ylabel("distance to helix", fontsize=label_font_size)
ax.set_xlabel("gradient ascent", fontsize=label_font_size)
# raw strings: "\m" is not a valid escape sequence in a normal string literal
ax.set_xticklabels([r'in $\mathcal{X}$', r'in $\mathcal{Z}$'])
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
plt.subplots_adjust(left=0.2, bottom=0.2, right=0.95, top=0.95, wspace=0.2, hspace=0.05)
fig.savefig('plots/distances.png')
plt.show()
