Tutorial for MMD with the TorchDrift library: https://towardsai.net/p/machine-learning/drift-detection-using-torchdrift-for-tabular-and-time-series-data

More documentation on MMD in TorchDrift: https://torchdrift.org/notebooks/note_on_mmd.html

In [1]:
import pandas as pd
import torch
import torchdrift.detectors as detectors
from joblib import Parallel, delayed
from tqdm import tqdm
import numpy as np

In [2]:
SAVE_DIR = '20 bin PPO 500 results/'  # forward slashes also work on Windows
SAMPLE = SAVE_DIR + 'baseline_obs.csv'
SAVE_NAME = 'MMD_baseline_random_daily_samples'
REPETITIONS = 80
JOBS = 4  # 5 parallel jobs only just fit in VRAM; 4 use about 10 GB

##### On the (Statistical) Detection of Adversarial Examples

**Two-sample hypothesis testing.** As stated before, the test we chose is appropriate for high-dimensional inputs and small sample sizes. We compute the biased estimate of MMD using a **Gaussian kernel**, and then apply **10 000 bootstrapping iterations** to estimate the distributions. Based on this, we compute the **p-value** and compare it to the threshold, in our experiments **0.05**. For samples of **legitimate data, the observed p-value should always be very high**, whereas for sample sets containing adversarial examples, we expect it to be low, since they are sampled from a different distribution and thus the hypothesis should be rejected. The test is more likely to detect a difference in two distributions when it considers samples of large size (i.e., the sample contains more inputs from the distribution).
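To make the procedure concrete, here is a minimal NumPy sketch of an MMD two-sample test with a Gaussian kernel. This notebook passes its iteration count to TorchDrift's `kernel_mmd` as `n_perm`, i.e. the resampling permutes the pooled sample, and the sketch does the same. The function names (`pairwise_sq_dists`, `mmd2_biased`, `mmd_permutation_test`) are hypothetical, and the median-heuristic bandwidth and the `(count + 1) / (n_perm + 1)` p-value are common conventions assumed here. TorchDrift's exact estimator, bandwidth rule, and p-value formula may differ.

```python
import numpy as np


def pairwise_sq_dists(x, y):
    """Squared Euclidean distances between every row of x and every row of y."""
    d = np.sum(x**2, axis=1)[:, None] + np.sum(y**2, axis=1)[None, :] - 2.0 * x @ y.T
    return np.maximum(d, 0.0)  # clip tiny negatives caused by floating-point error


def mmd2_biased(k, n):
    """Biased (V-statistic) MMD^2 from the joint kernel matrix of [x; y]; x has n rows."""
    k_xx = k[:n, :n]
    k_yy = k[n:, n:]
    k_xy = k[:n, n:]
    return k_xx.mean() + k_yy.mean() - 2.0 * k_xy.mean()


def mmd_permutation_test(x, y, n_perm=1000, sigma=None, seed=0):
    """Return (MMD^2, p-value) for H0: x and y come from the same distribution."""
    rng = np.random.default_rng(seed)
    xy = np.concatenate([x, y], axis=0)
    n = len(x)
    sq = pairwise_sq_dists(xy, xy)
    if sigma is None:
        # median heuristic (assumption): sigma = median pairwise distance
        sigma = np.sqrt(np.median(sq[sq > 0]))
    k = np.exp(-sq / (2.0 * sigma**2))  # Gaussian kernel matrix of the pooled sample
    observed = mmd2_biased(k, n)

    count = 0
    for _ in range(n_perm):
        perm = rng.permutation(len(xy))
        # permuting rows and columns of k relabels which pooled points count as "x"
        if mmd2_biased(k[np.ix_(perm, perm)], n) >= observed:
            count += 1
    p_value = (count + 1) / (n_perm + 1)
    return observed, p_value
```

For two samples drawn from the same distribution the p-value is typically large. For a clearly shifted sample (e.g. `rng.normal(loc=1.0, size=(60, 3))` against standard normals) it drops to about `1 / (n_perm + 1)`, and the null hypothesis is rejected at 0.05.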

In [3]:
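BOOTSTRAP = 10_000  # resampling iterations; passed to kernel_mmd as n_perm (permutation test)
PVAL = 0.05  # significance threshold for rejecting "same distribution"
kernel = detectors.mmd.GaussianKernel()  # Gaussian kernel with TorchDrift's default settings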

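Because the dataset is a time series, we do not shuffle it globally. Instead, each day (24 consecutive samples) is shuffled internally and split in half, and MMD compares the two halves, so both samples cover the same time segments.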
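The day-wise split can also be written without a Python loop. Below is a minimal vectorized NumPy sketch. `split_days` is a hypothetical helper, not part of the original notebook. It assumes the observations are a 2-D array with whole days of `samples_per_day` consecutive rows, and it drops any trailing partial day (the pandas loop in `process_func` keeps it).

```python
import numpy as np


def split_days(obs, samples_per_day=24, seed=None):
    """Shuffle rows within each day, then split every day into two equal halves.

    obs: array of shape (n_rows, n_features). A trailing partial day is dropped.
    """
    rng = np.random.default_rng(seed)
    n_days = len(obs) // samples_per_day
    days = obs[: n_days * samples_per_day].reshape(n_days, samples_per_day, -1)
    # an independent random row order for each day
    order = np.argsort(rng.random((n_days, samples_per_day)), axis=1)
    days = np.take_along_axis(days, order[:, :, None], axis=1)
    half = samples_per_day // 2
    first = days[:, :half].reshape(-1, obs.shape[1])   # first half of every day
    second = days[:, half:].reshape(-1, obs.shape[1])  # second half of every day
    return first, second
```

In this notebook it would be called as `split_days(df_obs.to_numpy())`. The two arrays could then be converted with `torch.from_numpy` exactly as in `process_func`.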

Load the unperturbed observations from the untargeted adversarial attack.

In [4]:
df_obs = pd.read_csv(SAMPLE,
                     index_col=0,
                     dtype='float32',
                     )
df_obs.set_index(df_obs.index.astype(int), inplace=True)  # dtype='float32' also applies to the index, so cast it back to int

Remove the action column if it is stored in the DataFrame.

In [5]:
# drop whichever action column is present ('a' and/or 'actions'); missing ones are ignored
df_obs.drop(columns=['a', 'actions'], errors='ignore', inplace=True)

In [6]:
def process_func(df_obs):
    """Run one MMD permutation test between two random halves of every day."""
    samples_per_day = 24
    half = samples_per_day // 2

    # Shuffle each day internally, then put the first half of its rows in df1
    # and the second half in df2, so both samples cover the same days.
    parts1, parts2 = [], []
    for i in range(0, len(df_obs), samples_per_day):
        daily_samples = df_obs.iloc[i:i + samples_per_day].sample(frac=1)  # shuffle within the day
        parts1.append(daily_samples.iloc[:half])
        parts2.append(daily_samples.iloc[half:])

    # DataFrame.append was removed in pandas 2.0; concatenate once instead
    df1 = pd.concat(parts1, ignore_index=True)
    df2 = pd.concat(parts2, ignore_index=True)

    # compute MMD and its permutation p-value on the GPU
    result = detectors.kernel_mmd(torch.from_numpy(df1.values).to('cuda'),
                                  torch.from_numpy(df2.values).to('cuda'),
                                  n_perm=BOOTSTRAP,
                                  kernel=kernel)

    # convert the CUDA scalar tensors (MMD, p-value) to Python floats
    cpu_result = [tensor.item() for tensor in result]
    return cpu_result

In [7]:
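# threads share df_obs in memory instead of pickling it to worker processes;
# tqdm counts dispatched tasks, so the bar can run slightly ahead of finished jobs
result = Parallel(n_jobs=JOBS,
                  prefer='threads',
                  )(delayed(process_func)(df_obs) for _ in tqdm(range(REPETITIONS)))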

100%|██████████| 80/80 [55:41<00:00, 41.77s/it]


In [8]:
mmd_savename = SAVE_DIR + 'MMDs.csv'
index = [SAVE_NAME + f'_{i}' for i in range(len(result))]
try:
    # append this run's rows to the existing results file
    df_mmd = pd.read_csv(mmd_savename, index_col=0)
    df_mmd = pd.concat([df_mmd,
                        pd.DataFrame(result, columns=df_mmd.columns, index=index)])
    df_mmd.to_csv(mmd_savename)
    print(f'{mmd_savename} updated')
except FileNotFoundError:
    # first run: create the results file
    df_mmd = pd.DataFrame(result, columns=['MMD', 'p_value'], index=index)
    df_mmd.to_csv(mmd_savename)
    print(f'{mmd_savename} created')

20 bin PPO 500 results/MMDs.csv updated
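The notebook stores the p-values but never applies the `PVAL` threshold. A minimal sketch of that last step follows. `rejection_rate` is a hypothetical helper that assumes `results` is an iterable of `(mmd, p_value)` pairs, like the list returned by `Parallel` above. Both halves here come from the same baseline run, so for a calibrated test one would expect roughly a `PVAL` fraction of repetitions to be rejected.

```python
import numpy as np


def rejection_rate(results, alpha=0.05):
    """Fraction of repetitions whose p-value falls below alpha (i.e. drift flagged)."""
    p_values = np.array([p for _, p in results], dtype=float)
    return float(np.mean(p_values < alpha))
```

In the notebook this would be `rejection_rate(result, PVAL)`, or equivalently `(df_mmd['p_value'] < PVAL).mean()` on the saved rows.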
