# Prepare the MNIST dataset and its splits

|  | name | label | normal_data | backdoor_data | description |
| :--- | :--- | :---: | :---: | :---: | :--- |
| $D_{train}$ | Clean training dataset | $\checkmark$ | 50,000 | 0 | Train baseline model |
| $D^p_{train}$ | Poison training dataset | $\checkmark$ | 49,900 | 100 | Train backdoor model |
| $D_{dist}$ | Distillation training dataset |  | 10,000 | 0 | Train distilled model |
| $D_{test}$ | Clean test dataset | $\checkmark$ | 10,000 | 0 | Validate stealthiness |
| $D_{p}$ | Poison test dataset | $\checkmark$ | 0 | 8,972 | Validate attack feasibility (every test image whose true label is not the target 7, with the trigger added) |

In [1]:
import os
import numpy as np

import torch
import torchvision
from torchvision import transforms

from backdoor_attack import create_poison_data

## Prepare MNIST dataset

In [2]:
# Every generated .npz split is written under ./results/datasets; the raw MNIST
# files are downloaded into its original_data subfolder.
ds_root = os.path.join(os.curdir, 'results', 'datasets')
os.makedirs(ds_root, exist_ok=True)
original_data_path = os.path.join(ds_root, 'original_data')

# Load the train split first, then the test split (download=True fetches the raw
# archives, as the log below shows).
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(original_data_path, train=is_train, download=True)
    for is_train in (True, False)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./results/datasets/original_data/MNIST/raw/train-images-idx3-ubyte.gz
Widget Javascript not detected.  It may not be installed or enabled properly.



Extracting ./results/datasets/original_data/MNIST/raw/train-images-idx3-ubyte.gz to ./results/datasets/original_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./results/datasets/original_data/MNIST/raw/train-labels-idx1-ubyte.gz
Widget Javascript not detected.  It may not be installed or enabled properly.



Extracting ./results/datasets/original_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./results/datasets/original_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./results/datasets/original_data/MNIST/raw/t10k-images-idx3-ubyte.gz
Widget Javascript not detected.  It may not be installed or enabled properly.



Extracting ./results/datasets/original_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./results/datasets/original_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./results/datasets/original_data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Widget Javascript not detected.  It may not be installed or enabled properly.



Extracting ./results/datasets/original_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./results/datasets/original_data/MNIST/raw
Processing...
Done!


### Configure training datasets

In [3]:
# Seed NumPy's global (legacy) RNG so that the np.random.permutation calls below,
# the train/distillation split and the poison-set shuffle, are reproducible.
np.random.seed(seed=20200703)

In [4]:
# Raw MNIST training images and labels as NumPy arrays.
train_x, train_t = mnist_train.data.numpy(), mnist_train.targets.numpy()
# A single shuffled index order; the cells below slice it into the first 50,000
# (clean / poison training) and the remainder (distillation).
idx = np.random.permutation(train_x.shape[0])

In [5]:
# Clean training dataset D_train: the first 50,000 shuffled training images with their true labels
x, t = train_x[idx[:50000]], train_t[idx[:50000]]
np.savez(
    os.path.join(ds_root, 'clean_training_dataset.npz'),
    x=x,
    t=t,
)

In [6]:
# Poison training dataset
# D^p_train: the same 50,000 images as D_train, except that the first
# `num_of_poison_data` images (in shuffled order) whose true label is not
# `poisoned_target` receive the one-dot trigger and are relabelled as `poisoned_target`.
poisoned_target = 7  # poisoned target: label the trigger should force
num_of_poison_data = 100

# Integer-array indexing returns copies, so editing x_p / t_p leaves train_x / train_t untouched.
x_p = train_x[idx[:50000]]
t_p = train_t[idx[:50000]]
t = train_t[idx[:50000]]  # true labels, saved below as t_correct

# Build each poisoned image from x_p's own (still clean) copy instead of the global `x`
# left by the previous cell: later cells rebind `x`, so re-running this cell out of order
# would otherwise stamp the trigger onto the wrong images. `.copy()` guards against the
# helper modifying its argument in place (its implementation is not visible here).
poison_idx = np.flatnonzero(t != poisoned_target)[:num_of_poison_data]
if poison_idx.size < num_of_poison_data:
    raise ValueError(
        f'only {poison_idx.size} training images are not labelled {poisoned_target}; '
        f'cannot create {num_of_poison_data} poisoned samples'
    )
for i in poison_idx:
    x_p[i] = create_poison_data.one_dot_mnist(x_p[i].copy())
    t_p[i] = poisoned_target

# Shuffle so the poisoned samples are not clustered at the start of the file.
shuffle_idx = np.random.permutation(np.arange(x_p.shape[0]))
np.savez(os.path.join(ds_root, 'poison_training_dataset.npz'), x=x_p[shuffle_idx], t=t_p[shuffle_idx], t_correct=t[shuffle_idx])

In [7]:
# Distillation training dataset D_dist: the remaining 10,000 shuffled training images
x, t = train_x[idx[50000:]], train_t[idx[50000:]]
np.savez(
    os.path.join(ds_root, 'distillation_training_dataset.npz'),
    x=x,
    t=t,
)

### Configure test datasets

In [8]:
# Raw MNIST test images and labels as NumPy arrays.
test_x, test_t = mnist_test.data.numpy(), mnist_test.targets.numpy()

In [9]:
# Clean test dataset D_test: the untouched MNIST test split
np.savez(
    os.path.join(ds_root, 'clean_test_dataset.npz'),
    x=test_x,
    t=test_t,
)

In [10]:
# Poison test dataset
# D_p: every clean test image whose true label differs from `poisoned_target`
# (defined in the poison-training cell) gets the one-dot trigger and the label
# `poisoned_target`; its true label is kept as `t_correct`. Images already
# labelled `poisoned_target` are skipped, so this set has fewer than 10,000 samples.
is_source = test_t != poisoned_target

# Boolean-mask indexing copies the selected rows, so the trigger cannot leak back
# into test_x even if the helper edits its argument in place. Using a comprehension
# variable also stops this cell from clobbering the `x` / `t` globals of earlier cells.
x_p = np.stack([create_poison_data.one_dot_mnist(img) for img in test_x[is_source]], axis=0)
t_p = np.full(x_p.shape[0], poisoned_target, dtype=np.int32)
# True labels as an int32 array. The previous code converted the leftover loop variable
# `t` (the last test label) instead of the collected labels.
target = test_t[is_source].astype(np.int32)

np.savez(os.path.join(ds_root, 'poison_test_dataset.npz'), x=x_p, t=t_p, t_correct=target)