In [None]:
from tda.experiments.ocsvm_detector.ocsvm_detector_binary import *
from tda.embeddings import EmbeddingType, KernelType
from tda.models.architectures import mnist_lenet

In [None]:
config = Config(
    # What to embed and which kernel compares the embeddings.
    embedding_type=EmbeddingType.PersistentDiagram,
    kernel_type=KernelType.SlicedWasserstein,
    # Data, model and training.
    dataset="MNIST",
    architecture=mnist_lenet.name,
    dataset_size=10,
    epochs=50,
    train_noise=0.0,
    # Other experiment settings.
    attack_type="FGSM",
    successful_adv=1,
    noise=0.0,
    identical_train_samples=1,
    # NOTE(review): thresholds is given as a string; presumably parsed downstream — confirm.
    thresholds='0.9',
    # NOTE(review): these look like graph-kernel settings that the persistent-diagram
    # embedding may not use — TODO confirm before tuning them.
    num_iter=1,
    height=1,
    hash_size=1,
    node_labels=0,
    steps=1,
)

In [None]:
# Build all embeddings for the configured experiment; adv_embeddings is indexed by epsilon below.
(embedding_train, embedding_test, adv_embeddings, thresholds, stats, stats_inf) = get_all_embeddings(config)

In [None]:
logger.info(f"Using kernel {config.kernel_type} with embeddings {config.embedding_type}")

# Hyper-parameter grid for the selected kernel; one train Gram matrix is computed per entry.
if config.kernel_type == KernelType.RBF:
    # 10 log-spaced gamma values between 1e-6 and 1e-3.
    param_space = [{'gamma': gamma} for gamma in np.logspace(-6, -3, 10)]
elif config.kernel_type == KernelType.SlicedWasserstein:
    # A single setting; M and sigma are forwarded to get_gram_matrix as params.
    param_space = [{'M': 20, 'sigma': 0.5}]
else:
    raise NotImplementedError(f"Unknown kernel {config.kernel_type}")

# Gram matrices on the training embeddings, keyed by the index of the parameter set
# in param_space (embeddings_out=None — presumably train-vs-train; confirm in get_gram_matrix).
gram_train_matrices = {
    i: get_gram_matrix(
        kernel_type=config.kernel_type,
        embeddings_in=embedding_train,
        embeddings_out=None,
        params=param,
    )
    for i, param in enumerate(param_space)
}
logger.info("Computed all Gram train matrices !")

In [None]:
# Run the evaluation once per epsilon key of adv_embeddings, reusing the
# pre-computed train Gram matrices and the same parameter grid each time.
all_results = {
    eps: evaluate_embeddings(
        gram_train_matrices=gram_train_matrices,
        embeddings_train=embedding_train,
        embeddings_test=embedding_test,
        adv_embeddings=adv_embeddings[eps],
        param_space=param_space,
        kernel_type=config.kernel_type,
    )
    for eps in adv_embeddings
}

logger.info(all_results)