# Frequency Band Analysis of Adversarial Filters
In this notebook we run adversarial attacks on a trained model for a given dataset (ESC-50 is used for demonstration purposes).  
The adversarially generated filters are then analysed with respect to the frequency bands they affect.

In [None]:
# IPython magic: change the working directory to the parent folder.
# NOTE(review): presumably needed so the project-local imports (training, data, attacks)
# and the relative "./demos/..." checkpoint paths used below resolve — confirm.
%cd ..
# Device string passed to checkpoint loading and to the attack in later cells.
DEVICE = "cuda" # Set to "cpu" if no cuda device available

## Load Model

### CNN14

In [None]:
from training.cnn14_adv_train import CNN14Adv
CNN14_CHECKPOINT = "./demos/cnn14_esc50.ckpt" # Must correspond to model type and dataset

# Map the checkpoint onto DEVICE while loading (same as the PaSST cell), so a checkpoint
# saved from a GPU run can also be restored on a CPU-only machine.
model = CNN14Adv.load_from_checkpoint(CNN14_CHECKPOINT, map_location=DEVICE)
if DEVICE == "cpu": # Bugfix
    # Explicitly move the pre-emphasis coefficient to the CPU.
    # NOTE(review): presumably this tensor is not moved by map_location — confirm.
    model.mel.preemphasis_coef = model.mel.preemphasis_coef.cpu()

### PaSST

In [None]:
from training.passt_adv_train import PasstAdv
# Path of the PaSST checkpoint restored below.
PASST_CHECKPOINT = "./demos/passt_esc50.ckpt" # Must correspond to model type and dataset

# Restore the model from the checkpoint, mapping it onto DEVICE.
# Running this cell replaces any `model` loaded by an earlier cell.
model = PasstAdv.load_from_checkpoint(PASST_CHECKPOINT, map_location=DEVICE)
if DEVICE == "cpu": # Bugfix
    # Explicitly move the pre-emphasis coefficient to the CPU.
    # NOTE(review): presumably this tensor is not moved by map_location — confirm.
    model.mel.preemphasis_coef = model.mel.preemphasis_coef.cpu()

## Load Dataset

### ESC-50

In [None]:
from data.esc50 import ESC50DataModule
# Location of the ESC-50 dataset, relative to the working directory.
ESC50DIR = "../ESC-50/"
# The loader is built with batch_size=1, i.e. it yields one sample per iteration.
data_module = ESC50DataModule(dir=ESC50DIR, batch_size=1, num_workers=1)
# Prepare the "test" stage and fetch its dataloader; the attack cell iterates over `loader`.
data_module.setup("test")
loader = data_module.test_dataloader()
# Dataset name and sample rate; neither is referenced by the other cells shown here.
# NOTE(review): presumably `sr` is the sampling rate in Hz expected by the models — confirm.
dataset = "ESC-50"
sr = 32000

## Run Attack on complete Test Set

In [None]:
from attacks.filter_pgd import run_pgd_batched
import numpy as np
from tqdm import tqdm
# NOTE: BATCH_SIZE only takes effect if the dataloader is constructed with it; the loop
# below works with whatever batch size `loader` actually yields.
BATCH_SIZE = 32 # Set to what your gpu memory can fit
EPSILON = 0.5
ALPHA = EPSILON / 10
filters = []   # one adversarial filter per successfully attacked sample
metadata = []  # parallel to `filters`: sample index and predictions before/after the attack
num_seen = 0   # samples processed so far, used to derive a running sample index
for samples, labels in tqdm(loader, "Running attacks on batches"):
    samples = samples.to(DEVICE)
    labels = labels.to(DEVICE)
    # Per-sample predicted class; argmax over the last axis keeps one prediction per
    # batch element instead of collapsing the whole batch into a single scalar.
    # assumes the model output is (batch, n_classes) — TODO confirm
    preds_before = np.argmax(model(samples).cpu().detach().numpy(), axis=-1)
    res_dict = run_pgd_batched(model, samples, labels, device=DEVICE,
                               eps=EPSILON, alpha=ALPHA, max_iters=10, restarts=10)
    preds_after = np.argmax(model(res_dict['perturbed_inputs']).cpu().detach().numpy(), axis=-1)
    # Positions within the batch whose prediction was changed by the attack.
    success_idx = np.flatnonzero(preds_before != preds_after)
    for i in success_idx:
        i = int(i)
        filters.append(res_dict['filters'][i].cpu().detach().numpy())
        metadata.append({
            # Running index over the samples yielded by `loader` (equals the dataset
            # index only if the loader does not shuffle).
            "idx": num_seen + i,
            "pred_before": int(preds_before[i]),
            "pred_after": int(preds_after[i]),
        })
    num_seen += len(preds_before)

## Analyse the Filters

In [None]:
import matplotlib.pyplot as plt
import librosa
# Across all classes
# Convert the collected list of per-sample filters into a single NumPy array so it can be
# aggregated along axis 0 (one row per filter, assuming all filters have equal length).
filters = np.array(filters)
def plot_filter_stats(filters, fmin=0, fmax=8000):
    """Plot the per-band value range and mean of a set of adversarial filters.

    For every band (column) a bar is drawn from the smallest to the largest value
    observed across all filters, and a black horizontal line marks the mean.
    X-ticks are labelled with the frequencies librosa assigns to the band indices.

    Args:
        filters: Sequence or array of equally long 1-D filters, one filter per row.
            A single 1-D filter is accepted as well.
        fmin: Lowest frequency in Hz used to compute the x-tick labels.
        fmax: Highest frequency in Hz used to compute the x-tick labels. The defaults
            reproduce the previously hard-coded 0-8000 Hz range.
            NOTE(review): should match the mel front end of the model — confirm.

    Returns:
        None. Draws into a new matplotlib figure; when `filters` is empty a notice is
        printed and nothing is drawn.
    """
    filters = np.atleast_2d(np.asarray(filters))
    if filters.size == 0:
        # E.g. a class-wise selection that matched no successful attack.
        print("No filters to plot.")
        return

    # Aggregate mean, min, max
    mean_filter = np.mean(filters, axis=0)
    min_filter = np.min(filters, axis=0)
    max_filter = np.max(filters, axis=0)
    n_bands = len(mean_filter)
    band_idx = np.arange(n_bands)
    bar_width = 0.6

    plt.figure(figsize=(14, 5), dpi=300)
    # Bars span exactly [min, max]: the height is the range, independent of the sign of min.
    plt.bar(band_idx, max_filter - min_filter, bottom=min_filter, width=bar_width, color="skyblue")
    # Mean marker across the full width of each (centre-aligned) bar.
    plt.hlines(mean_filter, band_idx - bar_width / 2, band_idx + bar_width / 2,
               colors='black', linewidth=2)

    # Map indices to mel frequencies
    mel_freqs = librosa.mel_frequencies(n_mels=n_bands, fmin=fmin, fmax=fmax)
    # Set x-ticks to mel frequencies (every 5th band to keep the axis readable)
    xtick_locs = np.arange(0, n_bands, 5)
    xtick_labels = [f"{int(mel_freqs[i])} Hz" for i in xtick_locs]
    plt.xticks(xtick_locs, xtick_labels, rotation=45)
    plt.xlim((-1, n_bands))
    plt.title("Aggregated Adversarial Mel Filters")
    plt.xlabel("Mel Filter Bank Index")
    plt.ylabel("Filter Gain (min, mean, max)")
# Plot the aggregate over all collected filters, regardless of class.
plot_filter_stats(filters)

In [None]:
# Class-wise analysis: from class i to class ANY
CLASS_IDX = 0 # Change to desired class index
# Keep only the filters whose sample was predicted as CLASS_IDX before the attack.
class_filters = [
    flt
    for pos, flt in enumerate(filters)
    if metadata[pos]['pred_before'] == CLASS_IDX
]
plot_filter_stats(class_filters)

In [None]:
# Class-wise analysis: from class ANY to class i
CLASS_IDX = 0 # Change to desired class index
# Select on the prediction AFTER the attack: these are the filters that pushed a sample
# (from whatever class it was predicted as before) into CLASS_IDX.
class_filters = [
    flt
    for pos, flt in enumerate(filters)
    if metadata[pos]['pred_after'] == CLASS_IDX
]
plot_filter_stats(class_filters)

In [None]:
# Class-wise analysis: from class i to class k
CLASS_IDX_FROM = 0 # Change to desired class index
CLASS_IDX_TO = 1   # Change to desired class index
# Keep only the filters that changed a prediction from CLASS_IDX_FROM to CLASS_IDX_TO.
class_filters = [
    flt
    for pos, flt in enumerate(filters)
    if metadata[pos]['pred_before'] == CLASS_IDX_FROM
    and metadata[pos]['pred_after'] == CLASS_IDX_TO
]
plot_filter_stats(class_filters)