In [None]:
from tda.experiments.visualization.kernel_pca_binary import *
from tda.embeddings import EmbeddingType, KernelType
from tda.models.architectures import mnist_mlp, mnist_lenet, svhn_lenet
from tda.logging import get_logger
import copy

In [None]:
# Experiment settings for this notebook, gathered as keyword arguments and
# handed to Config in a single call.
config_kwargs = dict(
    # Embedding / kernel selection
    embedding_type=EmbeddingType.PersistentDiagram,
    kernel_type=KernelType.SlicedWasserstein,
    thresholds='0.4_0.03_0.4_0.03_0_0_0',
    # Model and dataset
    epochs=200,
    dataset="SVHN",
    architecture=svhn_lenet.name,
    train_noise=0.0,
    dataset_size=2,
    # Attack selection ("All" is the value the plotting cells test for)
    successful_adv=1,
    attack_type="All",
    identical_train_samples=1,
    noise=0.0,
    # Remaining settings
    num_iter=1,
    height=1,
    hash_size=1,
    node_labels=0,
    steps=1,
)
config = Config(**config_kwargs)

#### Get embeddings

In [None]:
# The result of get_all_embeddings is unpacked into seven names; the later
# cells read all of them except `thresholds`.
(
    clean_embeddings,
    noisy_embeddings,
    adv_embeddings,
    adv_embeddings_all,
    thresholds,
    all_noises,
    all_epsilons,
) = get_all_embeddings(config)

#### Add kernel parameters (the parameter space depends on the selected kernel type)

In [None]:
logger.info(f"Using kernel {config.kernel_type} with embeddings {config.embedding_type}")

# Reject unsupported kernels up front, then build the list of candidate
# parameter dictionaries for the selected kernel.
selected_kernel = config.kernel_type
if selected_kernel != KernelType.RBF and selected_kernel != KernelType.SlicedWasserstein:
    raise NotImplementedError(f"Unknown kernel {config.kernel_type}")

if selected_kernel == KernelType.RBF:
    # Ten gamma values, log-spaced between 1e-6 and 1e-3.
    param_space = [dict(gamma=gamma_value) for gamma_value in np.logspace(-6, -3, 10)]
else:
    # Sliced Wasserstein kernel: a single fixed parameter set.
    param_space = [dict(M=20, sigma=0.5)]

In [None]:
# Concatenate every embedding family into one list so that a single Gram matrix,
# and therefore a single kernel-PCA projection, covers all of them. The order
# (clean, noisy, adversarial) is the order the label-building cells rely on.
embedding_init = clean_embeddings + noisy_embeddings + adv_embeddings_all

# Only the last entry of param_space is used: the projection below consumes a
# single Gram matrix, so computing one per entry would be wasted work.
param = param_space[-1]
gram_matrix = get_gram_matrix(
    kernel_type=config.kernel_type,
    embeddings_in=embedding_init,
    embeddings_out=None,
    params=param,
)

# The Gram matrix is passed as a precomputed kernel; keep two components for a 2-D plot.
kpca = KernelPCA(2, kernel="precomputed")
transform_input = kpca.fit_transform(gram_matrix)

### Plotting

In [None]:
# Resolve the output directory for the figures. os.path.abspath('') is the
# current working directory; the directory two levels above it is used as the
# root under which tda/plots/visualization/kernel_pca/<architecture> is created.
binary_path = os.path.dirname(os.path.abspath(''))
binary_path_split = pathlib.Path(binary_path)
directory = str(pathlib.Path(*binary_path_split.parts[:-1])) + "/tda/plots/visualization/kernel_pca/" + str(config.architecture)

# exist_ok=True creates the directory only when it is missing, without the
# check-then-create race of a separate os.path.exists test.
os.makedirs(directory, exist_ok=True)
filename = directory + f"/{config.attack_type}.png"

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-<name>" and
# later releases dropped the old names; use whichever spelling this installation
# lists so the cell runs on both old and new versions.
dark_style = "seaborn-dark" if "seaborn-dark" in plt.style.available else "seaborn-v0_8-dark"
plt.style.use(dark_style)

In [None]:
# Build one hue label per embedding, in the same order the embeddings were
# concatenated: clean first, then noisy, then adversarial.
labels = ["Clean"] * len(clean_embeddings)

# The noisy embeddings are split evenly between the noise levels.
for noise in all_noises:
    per_noise_count = len(noisy_embeddings) // len(all_noises)
    labels.extend([f"Noisy {noise}"] * per_noise_count)

if config.attack_type == "All":
    # adv_embeddings is indexed by attack name first, then by level.
    for att in ["FGSM", "BIM", "DeepFool", "CW"]:
        if att not in ["FGSM", "BIM"]:
            # DeepFool and CW are stored under the single key 1.
            labels.extend([f"Adv {att} {1}"] * len(adv_embeddings[att][1]))
            continue
        for epsilon in all_epsilons:
            labels.extend([f"Adv {att} {epsilon}"] * len(adv_embeddings[att][epsilon]))
else:
    # Single attack: adv_embeddings is indexed directly by epsilon.
    for epsilon in all_epsilons:
        labels.extend([f"Adv {config.attack_type} {epsilon}"] * len(adv_embeddings[epsilon]))

In [None]:
# Palette layout: four fixed greys first, followed by the colours of the
# selected attack(s). Blues/Greens get one shade per epsilon.
le = len(all_epsilons)
grey_colors = ["#0F0F0F", "#C9C9C9", "#A0A0A0", "#757575"]

selected_attack = config.attack_type
if selected_attack == "FGSM":
    attack_colors = sns.color_palette("Blues", le)
elif selected_attack == "BIM":
    attack_colors = sns.color_palette("Greens", le)
elif selected_attack == "DeepFool":
    attack_colors = ["#FEB307"]
elif selected_attack == "CW":
    attack_colors = ["#C50000"]
else:
    # Any other value (e.g. "All"): FGSM blues, BIM greens, then DeepFool and CW.
    attack_colors = (
        sns.color_palette("Blues", le)
        + sns.color_palette("Greens", le)
        + ["#FEB307"]
        + ["#C50000"]
    )

pal = sns.color_palette(grey_colors + attack_colors)

In [None]:
# Scatter plot of the two kernel-PCA components (scaled by 1000), coloured by label.
logger.info("Plotting figure...")
x_coords = [1000 * point[0] for point in transform_input]
y_coords = [1000 * point[1] for point in transform_input]
p = sns.scatterplot(
    x=x_coords,
    y=y_coords,
    hue=labels,
    palette=pal,
    alpha=0.8,
)
# To write the figure to disk instead of displaying it, use:
#plt.savefig(filename, dpi=600)
#plt.clf()
plt.show()
logger.info("Closing figure")

In [None]:
# Per-attack views of the joint projection: for each attack, plot the clean and
# noisy points together with that attack's adversarial points only.

# Labels shared by every per-attack figure (clean + noisy part).
labs = ["Clean"] * len(clean_embeddings)
for noise in all_noises:
    labs += (len(noisy_embeddings) // len(all_noises)) * [f"Noisy {noise}"]

if config.attack_type == "All":
    grey_colors = ["#0F0F0F", "#C9C9C9", "#A0A0A0", "#757575"]
    # Colours appended after the greys for each attack; `le` is the number of epsilons.
    attack_colors = {
        "FGSM": sns.color_palette("Blues", le),
        "BIM": sns.color_palette("Greens", le),
        "DeepFool": ["#FEB307"],
        "CW": ["#C50000"],
    }

    clean_noisy_end = len(clean_embeddings) + len(noisy_embeddings)
    # Start of the current attack's rows in transform_input. Attacks are assumed
    # to be laid out in the loop order below, matching the label-building cell.
    block_start = clean_noisy_end

    for att in ["FGSM", "BIM", "DeepFool", "CW"]:
        filename2 = directory + f"/{config.attack_type}_viz_{att}.png"

        # FGSM and BIM have one entry per epsilon; DeepFool and CW use the single key 1.
        levels = all_epsilons if att in ["FGSM", "BIM"] else [1]
        adv_labels = []
        for epsilon in levels:
            adv_labels += len(adv_embeddings[att][epsilon]) * [f"Adv {att} {epsilon}"]

        # Take the slice width from the actual number of adversarial embeddings
        # (the same counts used for the labels) rather than assuming every level
        # holds exactly config.dataset_size samples, so points and hue labels
        # always have the same length.
        block_end = block_start + len(adv_labels)
        other_transform_input = list(transform_input[:clean_noisy_end]) + list(transform_input[block_start:block_end])
        block_start = block_end

        labels2 = labs + adv_labels
        pal2 = sns.color_palette(grey_colors + list(attack_colors[att]))

        logger.info("Plotting figure for all attacks...")
        p2 = sns.scatterplot(
            x=[1000 * item[0] for item in other_transform_input],
            y=[1000 * item[1] for item in other_transform_input],
            hue=labels2,
            palette=pal2,
            alpha=0.8,
        )
        # To write the figure to disk instead of displaying it, use:
        #plt.savefig(filename2, dpi=600)
        #plt.clf()
        plt.show()
        logger.info("Closing figure")