## Description

This is a minimal implementation of a (prior-aware) training data reconstruction attack against a model trained with differential privacy, as described in [Hayes et al. (2023)](https://arxiv.org/abs/2302.07225).

The adversary is given a sequence of model parameters $\{\theta_1, \theta_2, \ldots, \theta_T\}$ and a sequence of privatized gradients $\{g_1, g_2, \ldots, g_T\}$, where $\theta_k$ and $g_k$ denote the model parameters and the privatized gradient at update step $k$. The adversary also has access to a prior set of candidate inputs $\{z_1, z_2, \ldots, z_n\}$; the model is trained on $Z\cup\{z_i\}$, where $Z$ is a set of inputs known to the adversary, which we refer to as the *fixed set*, and $z_i$ is a single candidate sampled uniformly at random from the prior. The goal of the attack is to infer which candidate was used in training. This is achieved by iterating over each candidate $z_j$ and computing the sum $\sum_{k=1}^T \langle g_{\theta_{k}}(z_j), g_k\rangle$, where $g_{\theta_{k}}(z_j)$ denotes the (clipped) gradient of the loss on input $z_j$ with respect to the model parameters, evaluated at $\theta_{k}$. The adversary selects the candidate that maximizes this sum as their guess for the sample from the prior that was included in training.


## Imports


In [None]:
!pip install dm-haiku
!pip install dp-accounting
!pip install ml-collections
!pip install optax

from dp_accounting import dp_event
from dp_accounting.rdp import rdp_privacy_accountant

import functools
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds


Collecting dm-haiku
  Downloading dm_haiku-0.0.14-py3-none-any.whl.metadata (19 kB)
Collecting jmp>=0.0.2 (from dm-haiku)
  Downloading jmp-0.0.4-py3-none-any.whl.metadata (8.9 kB)
Downloading dm_haiku-0.0.14-py3-none-any.whl (373 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m373.8/373.8 kB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.14 jmp-0.0.4
Collecting dp-accounting
  Downloading dp_accounting-0.5.0-py3-none-any.whl.metadata (2.0 kB)
Collecting attrs<24,>=22 (from dp-accounting)
  Downloading attrs-23.2.0-py3-none-any.whl.metadata (9.5 kB)
Downloading dp_accounting-0.5.0-py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading attrs-23.2.0-py3-none-any.whl (60 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32

## Set up DP accounting

In [None]:

RdpAccountant = rdp_privacy_accountant.RdpAccountant


def get_rdp_epsilon(
    sampling_probability, noise_multiplier, steps, delta, orders
):
  """Get privacy budget from Renyi DP.

  Accounts for `steps` compositions of a Gaussian mechanism (with the given
  noise multiplier) applied to a Poisson-subsampled batch, and converts the
  resulting RDP guarantee into an (epsilon, delta) guarantee.

  Args:
    sampling_probability: Probability that an example is included in a step.
    noise_multiplier: Noise multiplier of the Gaussian mechanism.
    steps: Number of composed steps.
    delta: Target delta.
    orders: RDP orders to optimise over.

  Returns:
    An `(epsilon, optimal_order)` pair.
  """
  accountant = RdpAccountant(orders=orders)
  gaussian_step = dp_event.GaussianDpEvent(noise_multiplier)
  subsampled_step = dp_event.PoissonSampledDpEvent(
      sampling_probability, event=gaussian_step
  )
  accountant.compose(subsampled_step, steps)
  epsilon, best_order = accountant.get_epsilon_and_optimal_order(delta)
  return epsilon, best_order

## Define a simple MLP

In [None]:
def net_fn(x):
  """Standard MLP network.

  Flattens the input, then applies Linear(10) -> ELU -> Linear(10). The final
  10 outputs are used as unnormalised class logits (the loss one-hot encodes
  targets over 10 classes).
  """
  mlp = hk.Sequential([
      hk.Flatten(),
      hk.Linear(10),
      jax.nn.elu,
      hk.Linear(10),
  ])
  return mlp(x)

## Model training hyperparameters

With the default values below, the privacy budget is roughly $\epsilon\approx 35$ (the exact value is printed when the cell runs).

In [None]:
config = config_dict.ConfigDict()

# Noise multiplier used in DP training (the Gaussian noise added to the summed
# clipped gradients has std noise_multiplier * l2_norm_clip).
config.noise_multiplier = 2  # @param

# All individual sample gradients will be clipped to have a maximum L2 norm.
config.l2_norm_clip = 0.1  # @param

# Number of epochs.
config.epochs = 100  # @param

# Learning rate.
config.learning_rate = 9  # @param

# Total number of examples in the prior set.
# Attack base rate will be 1 / config.num_in_prior.
config.num_in_prior = 8  # @param

# Batch size used in training.
config.batch_size = 1000  # @param

# Training data sub-sampling probability.
config.q = 1 # @param

# The attack needs every epoch to consist of a whole number of steps: the
# training set is batched with drop_remainder=True, and the improved attack
# splits the per-step scores into config.epochs equal chunks. Hence 1 / q must
# be a whole number.
if not np.isclose(1 / config.q, round(1 / config.q)):
  raise ValueError(
      f'1 / config.q must be a whole number of steps per epoch, got q={config.q}.'
  )

# Total size of the training dataset (fixed set plus the single target drawn
# from the prior). Determined by config.batch_size and config.q: it equals
# config.batch_size / config.q, so it is a multiple of config.batch_size.
# round() (not int()) so float error such as 9.999999 cannot drop an example.
config.total_num = round(config.batch_size / config.q)

# Number of update training steps (config.epochs passes of 1 / config.q steps).
config.steps = round(config.epochs / config.q)

# Probability of DP failure.
config.delta = 1e-5

# Seed used to initialize parameters and random noise used in DP training.
config.seed = 1  # @param

# Generate orders used in RDP accounting
orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))

# Get privacy budget for the above configuration.
eps, opt_order = get_rdp_epsilon(
    config.q,
    config.noise_multiplier,
    config.steps,
    config.delta,
    orders,
)
print(f'Epsilon: {eps:.10f}')

Epsilon: 35.0817541508


In [None]:
# @title Reconstruction hyperparameters (for gradient attack)

attack_config = config_dict.ConfigDict()
# NOTE: `attack_config.target_idx_from_prior` (which prior example is included
# in training) is not set in this cell; the attack loop assigns it per run.

# If True, adversary can subtract fixed set clipped gradients from the
# privatised gradient, i.e. candidates are scored against the residual
# `maybe_noisy_target_grad` instead of the full private gradient.
attack_config.deduct_fixed_set_grads = True  # @param

# If True, adversary can rescale gradients by the batch size.
# Technically, there is a mismatch between theory and practice here since in DP
# accounting there isn't a "fixed" batch size, rather a probability with which
# an example is included in a batch.
# Note: this divides every candidate's gradient by the same positive constant
# (config.batch_size), so all attack scores are scaled equally and the guessed
# candidate does not change.
attack_config.rescale_by_batch_size = True  # @param


In [None]:
# @title Get prior and fixed data


def load_dataset(
    split: str,
    *,
    is_training: bool,
    batch_size: int,
    total_num: int,
    start_idx: int = 0,
    repeat: bool = False,
):
  """Loads the MNIST dataset as a generator of batches.

  Only the examples at indices [start_idx, start_idx + total_num) of `split`
  are loaded; they are cached, batched (a final partial batch is kept) and,
  if `repeat` is True, cycled indefinitely.

  Args:
    split: TFDS split name, e.g. 'train' or 'test'.
    is_training: Unused; kept so existing call sites keep working.
    batch_size: Number of examples per yielded batch.
    total_num: Number of examples to take from the split.
    start_idx: Index of the first example to take.
    repeat: Whether to repeat the data forever.

  Returns:
    An iterator over dicts of NumPy arrays (e.g. 'image', 'label').
  """
  del is_training  # Not used by this loader.
  stop_idx = start_idx + total_num
  dataset = tfds.load('mnist', split=f'{split}[{start_idx}:{stop_idx}]')
  dataset = dataset.cache().batch(batch_size)
  if repeat:
    dataset = dataset.repeat()
  return iter(tfds.as_numpy(dataset))


# The fixed set (known to the adversary) is built from the first
# config.total_num - 1 MNIST training images. When config.total_num <= 1 there
# is no fixed set and training uses only the sample selected from the prior.
if config.total_num > 1:
  train_data = load_dataset(
      split='train',
      is_training=True,
      repeat=False,
      batch_size=config.batch_size,
      total_num=config.total_num - 1,
      start_idx=0,
  )
  fixed_batches = list(train_data)
  # Pixel values rescaled from [0, 255] to [0, 1]; one array row per example.
  fixed_images = np.array(
      [image for b in fixed_batches for image in b['image'] / 255.0]
  )
  fixed_labels = np.array(
      [label for b in fixed_batches for label in b['label']]
  )

# The prior set: config.num_in_prior training images starting at index
# config.total_num, loaded as a single batch.
prior_data = load_dataset(
    split='train',
    is_training=True,
    repeat=True,
    batch_size=config.num_in_prior,
    total_num=config.num_in_prior,
    start_idx=config.total_num,
)
prior_batch = next(prior_data)
prior_images, prior_labels = prior_batch['image'] / 255.0, prior_batch['label']


# MNIST test set, scaled to [0, 1] and flattened to one row per image.
ds = tfds.load('mnist', split='test', batch_size=-1, as_supervised=True)
raw_test_images, raw_test_labels = tfds.as_numpy(ds)
num_test_images = raw_test_images.shape[0]
test_images = jnp.reshape(
    raw_test_images.astype(np.float32) / 255.0, (num_test_images, -1)
)
test_labels = jnp.array(raw_test_labels)



Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mnist/incomplete.F72UZZ_3.0.1/mnist-train.tfrecord*...:   0%|          | 0…

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mnist/incomplete.F72UZZ_3.0.1/mnist-test.tfrecord*...:   0%|          | 0/…

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


## Training set up

In [None]:
# Haiku-transformed `net_fn`. `without_apply_rng` allows calling
# `net.apply(params, x)` without an RNG key; `net_fn` has no stochastic layers.
net = hk.without_apply_rng(hk.transform(net_fn))


def broadcast_axis(data, ndims, axis):
  """Reshape `data` to rank `ndims`, placing all of its elements along `axis`.

  Every other dimension has size 1, so the result broadcasts against a
  rank-`ndims` array along `axis` (e.g. a per-example mask against a batch of
  per-example gradients).
  """
  target_shape = [1 for _ in range(ndims)]
  target_shape[axis] = -1
  return data.reshape(tuple(target_shape))


@jax.jit
def loss(params, batch):
  """Cross-entropy loss, averaged over the examples in `batch`.

  `batch` is an (inputs, targets, is_fixed) triple; the fixed-set marker is
  ignored here. Inputs in [0, 1] are mapped to [-1, 1] before the forward pass.
  """
  images, class_ids, _ = batch
  scaled = images * 2.0 - 1.0
  log_p = jax.nn.log_softmax(net.apply(params, scaled))
  one_hot = jax.nn.one_hot(class_ids, 10)
  per_example_nll = -(one_hot * log_p).sum(axis=1)
  return per_example_nll.mean()


@jax.jit
def clipped_grad(params, l2_norm_clip, single_example_batch):
  """Loss gradient on a single-example batch, clipped to L2 norm `l2_norm_clip`.

  The norm is the global norm over all parameter leaves jointly. Gradients
  already within the bound are returned unchanged.

  Returns:
    A `(clipped_gradient_tree, loss_value)` pair.
  """
  loss_val, grads = jax.value_and_grad(loss)(params, single_example_batch)
  leaves = jax.tree_util.tree_leaves(grads)
  leaf_norms = jnp.array([jnp.linalg.norm(leaf.ravel()) for leaf in leaves])
  global_norm = jnp.linalg.norm(leaf_norms)
  # Dividing by max(norm / clip, 1) only shrinks gradients whose norm > clip.
  scale = jnp.maximum(global_norm / l2_norm_clip, 1.0)
  clipped = jax.tree_util.tree_map(lambda leaf: leaf / scale, grads)
  return clipped, loss_val


@functools.partial(jax.jit, static_argnums=(3, 4, 5))
def privatise_gradient(
    params, batch, rng, l2_norm_clip, noise_multiplier, batch_size
):
  """Return differentially private gradients for params, evaluated on batch.

  Each example's gradient is clipped, the clipped gradients are summed,
  Gaussian noise with std `l2_norm_clip * noise_multiplier` is added to every
  parameter leaf (one independent key per leaf), and the result is divided by
  `batch_size`.

  Returns:
    `(private_grad, per_example_losses, per_example_clipped_grads)`.
  """
  per_example_clip = jax.vmap(
      clipped_grad, in_axes=(None, None, 0), out_axes=(0, 0)
  )
  clipped_grads, loss_vals = per_example_clip(params, l2_norm_clip, batch)

  leaves, treedef = jax.tree_util.tree_flatten(clipped_grads)
  leaf_keys = jax.random.split(rng, len(leaves))
  noise_std = l2_norm_clip * noise_multiplier

  private_leaves = []
  for key, per_example_leaf in zip(leaf_keys, leaves):
    summed = per_example_leaf.sum(0)
    noisy = summed + noise_std * jax.random.normal(key, summed.shape)
    private_leaves.append(noisy / batch_size)

  private_grad = jax.tree_util.tree_unflatten(treedef, private_leaves)
  return private_grad, loss_vals, clipped_grads


def compute_epsilon(
    steps, num_examples, batch_size, noise_multiplier, target_delta=1e-5
):
  """Compute privacy budget at a given step.

  Uses a sampling probability of batch_size / num_examples and the same RDP
  orders as the config cell.
  """
  sampling_rate = batch_size / float(num_examples)
  rdp_orders = [*jnp.linspace(1.1, 10.9, 99), *range(11, 64)]
  epsilon, _unused_order = get_rdp_epsilon(
      sampling_rate, noise_multiplier, steps, target_delta, rdp_orders
  )
  return epsilon


def shape_as_image(batch, dummy_dim=False):
  """Reshape the inputs of an (inputs, targets, is_fixed) batch to 28x28x1 images.

  With `dummy_dim=True` each image gets an extra leading axis of size 1, so a
  vmap over the batch hands each per-example call a batch of one. Targets and
  the fixed-set marker are passed through unchanged.
  """
  inputs, targets, is_fixed = batch
  image_shape = (28, 28, 1)
  if dummy_dim:
    image_shape = (1,) + image_shape
  return jnp.reshape(inputs, (-1,) + image_shape), targets, is_fixed


@functools.partial(jax.jit, static_argnums=(4, 5, 6, 7))
def private_update(
    params,
    opt_state,
    batch,
    rng,
    batch_size,
    l2_norm_clip,
    noise_multiplier,
    opt_updater,
):
  """Update model parameters with a privatized gradient.

  Also derives what the attacker can compute from the release: the private
  gradient minus the (batch-size-averaged) clipped gradients of the fixed-set
  examples in the batch. That residual is noise / batch_size when the target
  is not in the batch, and (target_gradient + noise) / batch_size when it is.

  Returns:
    `(new_params, new_opt_state, mean_loss, private_grad,
    maybe_noisy_target_grad)`.
  """
  private_grad, per_example_loss, per_example_clipped = privatise_gradient(
      params, batch, rng, l2_norm_clip, noise_multiplier, batch_size
  )
  updates, new_opt_state = opt_updater(private_grad, opt_state)
  new_params = optax.apply_updates(params, updates)

  _, _, fixed_mask = batch

  def fixed_set_share(per_example):
    # Sum clipped gradients of fixed-set examples only, then average.
    mask = broadcast_axis(fixed_mask, per_example.ndim, 0)
    return (mask * per_example).sum(0) / batch_size

  fixed_part = jax.tree_util.tree_map(fixed_set_share, per_example_clipped)
  residual = jax.tree_util.tree_map(
      lambda released, known: released - known, private_grad, fixed_part
  )
  return (
      new_params,
      new_opt_state,
      per_example_loss.mean(),
      private_grad,
      residual,
  )


def train_fn(attack_config, config):
  """Train a differentially private model.

  The training set is the fixed set plus the single prior example at index
  `attack_config.target_idx_from_prior`. At every step we record what the
  attacker is assumed to observe.

  Args:
    attack_config: Must define `target_idx_from_prior`.
    config: Training hyperparameters (see the config cell).

  Returns:
    `(info_for_attacker, eps)`: a list with one
    `(params_before_step, private_grad, maybe_noisy_target_grad)` tuple per
    step, and the privacy budget spent after `config.steps` steps.
  """
  # Grab the target image from the prior set.
  target_idx = attack_config.target_idx_from_prior
  target_image = prior_images[target_idx][None, ...]
  target_label = prior_labels[target_idx][None, ...]

  # Create a training set of images, with a marker that is 1 for fixed-set
  # examples and 0 for the target.
  if config.total_num > 1:
    train_images = np.concatenate((fixed_images, target_image))
    train_labels = np.concatenate((fixed_labels, target_label))
    is_fixed = np.concatenate(
        (np.ones((len(fixed_images),)), np.zeros((len(target_image),)))
    )
    total_num = config.total_num
    batch_size = config.batch_size
  else:
    train_images = target_image
    train_labels = target_label
    total_num = 1
    batch_size = 1
    is_fixed = np.zeros((len(target_image),))

  ds = tf.data.Dataset.from_tensor_slices(
      (train_images, train_labels, is_fixed)
  )
  # Buffer the whole training set so each epoch is a full shuffle (a fixed
  # 1000-element buffer is only a partial shuffle for larger datasets), and
  # seed it so a run is reproducible from config.seed.
  ds = ds.shuffle(len(train_images), seed=config.seed)
  ds = ds.batch(batch_size, drop_remainder=True).repeat()
  ds = iter(tfds.as_numpy(ds))

  # Set up optimiser.
  opt = optax.sgd(config.learning_rate)

  # Separate keys for initialisation and for the per-step noise stream, so no
  # key is both consumed and split.
  init_rng, rng = jax.random.split(jax.random.PRNGKey(config.seed))
  params = net.init(init_rng, np.ones((1, 28, 28, 1)))
  opt_state = opt.init(params)

  # Information available to the attacker at each step:
  # * model parameters the gradient was computed at.
  # * private gradient.
  # * private gradient with clipped gradients from fixed set subtracted.
  info_for_attacker = []
  for _ in range(config.steps):
    batch = next(ds)

    # Fresh noise key for this step; `rng` is only ever split.
    rng, noise_rng = jax.random.split(rng)

    new_params, opt_state, _, private_grad, maybe_noisy_target_grad = (
        private_update(
            params,
            opt_state,
            shape_as_image(batch, dummy_dim=True),
            noise_rng,
            batch_size,
            config.l2_norm_clip,
            config.noise_multiplier,
            opt.update,
        )
    )

    info_for_attacker.append((params, private_grad, maybe_noisy_target_grad))
    params = new_params

  # Privacy loss after all steps (computed once; does not depend on the loop
  # variable, so it also works when config.steps == 0).
  eps = compute_epsilon(
      config.steps, total_num, batch_size, config.noise_multiplier, config.delta
  )

  # `loss` maps inputs from [0, 1] to [-1, 1] during training; apply the same
  # preprocessing at evaluation time so the reported accuracy is meaningful.
  logits = net.apply(params, 2.0 * test_images - 1.0)
  preds = jnp.argmax(logits, axis=-1)
  accuracy = jnp.mean(preds == test_labels)
  print("Test accuracy:", float(accuracy))

  return info_for_attacker, eps

## Reconstruction attack set up

In [None]:
def reconstruction_upper_bound(
    pmode, q, noise_mul, steps, mc_samples=10000, rng=None
):
  """Monte Carlo estimate of the reconstruction-success upper bound.

  Draws `mc_samples` noise trajectories x ~ N(0, noise_mul^2)^steps (the
  "target absent" distribution). It then computes, per trajectory, the
  likelihood ratio r of the per-step mixture (1 - q) N(0, s^2) + q N(1, s^2)
  to N(0, s^2), multiplied over steps. The returned value is
  1 - (1 - pmode) * mean(smallest (1 - pmode) fraction of r), clipped to [0, 1].

  Args:
    pmode: Probability of the prior's most likely candidate (1 / n for a
      uniform prior over n candidates).
    q: Per-step sampling probability.
    noise_mul: Noise multiplier (noise std relative to the clipping norm).
    steps: Number of DP-SGD steps.
    mc_samples: Number of Monte Carlo trajectories.
    rng: Optional `np.random.Generator` for reproducibility; defaults to the
      global NumPy random state.

  Returns:
    A float in [0, 1].
  """
  sampler = np.random if rng is None else rng
  x = sampler.normal(0.0, noise_mul, (mc_samples, steps))
  # Per-step log likelihood ratio N(1, s^2) vs N(0, s^2), then mixed with q.
  gaussian_log_ratio = ((x) ** 2 - (x - 1.0) ** 2) / (2 * noise_mul**2)
  per_step_log_ratio = np.log(1 - q + q * np.exp(gaussian_log_ratio))
  ratios = np.exp(np.sort(np.sum(per_step_log_ratio, axis=1)))

  num_kept = int(mc_samples * (1 - pmode))
  if num_kept <= 0:
    # The top-pmode region covers (essentially) everything: trivial bound.
    # Without this guard the mean of an empty slice is NaN and the result
    # silently collapses to 0.0.
    return 1.0
  upper_bound = max(0.0, 1 - (1 - pmode) * np.mean(ratios[:num_kept]))
  return min(1.0, upper_bound)

@jax.jit
def compute_dot_prod(g1, g2):
  """Compute dot product between two trees.

  Both trees must share a structure. Each pair of leaves contributes the sum
  of its elementwise product, and the contributions are then summed.
  """
  per_leaf = jax.tree_util.tree_map(lambda a, b: jnp.sum(a * b), g1, g2)
  return jnp.stack(jax.tree_util.tree_leaves(per_leaf)).sum()


def reconstruction_attack(config, attack_config):
  """Train a differentially private model and perform a reconstruction attack.

  Every prior candidate gets a per-step score: the dot product between the
  gradient observed by the attacker and the candidate's clipped gradient at
  that step's parameters. The baseline attack sums all per-step scores. The
  improved attack splits them into `config.epochs` equal chunks and sums the
  per-chunk maxima. Each attack guesses the highest-scoring candidate.

  Returns:
    `(baseline_correct, improved_correct)`.
  """
  correct_idx = attack_config.target_idx_from_prior

  info_for_attacker, _ = train_fn(attack_config, config)
  (
      params_over_time,
      private_grad_over_time,
      maybe_noisy_target_grad_over_time,
  ) = zip(*info_for_attacker)

  # An attacker who knows the fixed set correlates against the residual
  # gradient; otherwise it uses the raw private gradient.
  if attack_config.deduct_fixed_set_grads:
    observed_over_time = maybe_noisy_target_grad_over_time
  else:
    observed_over_time = private_grad_over_time

  def per_step_scores(image, label):
    """Dot product of the candidate's clipped gradient with each observation."""
    scores = []
    for params, observed in zip(params_over_time, observed_over_time):
      cand_grad, _ = clipped_grad(
          params, config.l2_norm_clip, (image[None, ...], label[None, ...], None)
      )
      if attack_config.rescale_by_batch_size:
        cand_grad = jax.tree_util.tree_map(
            lambda g: g / config.batch_size, cand_grad
        )
      scores.append(compute_dot_prod(observed, cand_grad))
    return scores

  baseline_scores = []
  improved_scores = []
  for image, label in zip(prior_images, prior_labels):
    step_scores = per_step_scores(image, label)
    baseline_scores.append(sum(step_scores))
    # Mismatch with theory: DP-SGD accounting assumes Poisson sub-sampling,
    # whereas training shuffles, so the target appears once per epoch.
    per_epoch = np.array(np.split(np.array(step_scores), config.epochs))
    improved_scores.append(sum(np.max(per_epoch, axis=-1)))

  guess = np.argmax(baseline_scores)
  improved_guess = np.argmax(improved_scores)
  return correct_idx == guess, correct_idx == improved_guess

## Run attack

In [None]:
from absl import logging
logging.set_verbosity(logging.ERROR)  # or logging.FATAL

# Number of times we run the attack to compute lower bound estimate.
n = 80  # @param

# Privacy budget of each training run. All runs share the config (only the
# seed changes), so compute it here from the current config rather than
# relying on an `eps` global left over from an earlier cell, which may be stale.
eps = compute_epsilon(
    config.steps,
    config.total_num,
    config.batch_size,
    config.noise_multiplier,
    config.delta,
)

num_corr = 0
improved_num_corr = 0
for i in range(1, n + 1):
  # Seed used to init params and add random noise.
  config.seed = i

  # Index of prior point we include from the prior set.
  attack_config.target_idx_from_prior = np.random.randint(
      0, config.num_in_prior
  )

  # Run attack and return if we correctly identified the target image from
  # the prior.
  correct, improved_attack_correct = reconstruction_attack(config, attack_config)
  print(
      f'{i}/{n}\t Privacy budget: {eps:.2f}\t Attack successful: {correct}\t'
      f' Improved attack successful: {improved_attack_correct}'
  )
  num_corr += correct
  improved_num_corr += improved_attack_correct


# Upper bound uses base rate 1 / num_in_prior (uniform prior); the lower
# bounds are empirical success rates over the n runs.
rub = reconstruction_upper_bound(1/config.num_in_prior, config.q, config.noise_multiplier, config.steps)
rlb = num_corr/n
rlb_improved = improved_num_corr/n

print(f"\nReconstruction upper bound: {rub:.3f}, Lower bound: {rlb:.3f}, Lower bound from improved attack: {rlb_improved:.3f}")

Test accuracy: 0.45809999108314514
1/100	 Privacy budget: 1.01	 Attack successful: False	 Improved attack successful: False
Test accuracy: 0.36800000071525574
2/100	 Privacy budget: 1.01	 Attack successful: False	 Improved attack successful: False
Test accuracy: 0.3353999853134155
3/100	 Privacy budget: 1.01	 Attack successful: False	 Improved attack successful: False
Test accuracy: 0.32029998302459717
4/100	 Privacy budget: 1.01	 Attack successful: False	 Improved attack successful: False
Test accuracy: 0.36879998445510864
5/100	 Privacy budget: 1.01	 Attack successful: False	 Improved attack successful: False
Test accuracy: 0.3877999782562256
6/100	 Privacy budget: 1.01	 Attack successful: False	 Improved attack successful: False
Test accuracy: 0.2930999994277954
7/100	 Privacy budget: 1.01	 Attack successful: False	 Improved attack successful: False
Test accuracy: 0.4465000033378601
8/100	 Privacy budget: 1.01	 Attack successful: False	 Improved attack successful: False
Test accurac