In [None]:
# Standard library / progress bar
import os
from tqdm import tqdm

# Numerics and plotting
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.spatial import distance
from sklearn.preprocessing import normalize

import torch
import torch.nn.functional as F

from pytorch_metric_learning.distances import SNRDistance

# Project-local helpers: dataset loaders and the distance summaries used below
from src.data.utils import load_mnist_dataset, load_cifar10_dataset
from src.metrics import return_distances, return_distances_label_wise

import plotly_express as px
import plotly
from plotly import version
# NOTE(review): this prints whatever object `plotly.version` is; if the intent
# was the version string, `plotly.__version__` may be what is wanted -- confirm.
print (version)
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# Initialise plotly's offline notebook mode for this session.
init_notebook_mode(connected=True)

import warnings
# Installs a blanket "ignore" filter: every warning raised after this point is
# suppressed for the rest of the session, including deprecation notices.
warnings.filterwarnings('ignore')

In [None]:
# --- Experiment configuration ------------------------------------------------
# These values select which results directory is read and how later cells plot.
embedding_dim = 128
optim='sgd'
dataset='mnist'
adverserial = False       # True -> read the adversarial-attack results instead
plot_pct_change = False   # True -> plotting cells use the *_pct_change_* columns

if adverserial:
    path = f'../results/data={dataset}/{optim}/adverserial_attacks/embedding_dim={embedding_dim}/'
else:
    path = f'../results/data={dataset}/{optim}/embedding_dim={embedding_dim}/'

# --- Load the pickled result frames -------------------------------------------
# SECURITY: unpickling can execute arbitrary code. Only point `path` at result
# files produced by this project, never at files from an untrusted source.
frames = []
skipped = []
# Sorted so the row order of `df` does not depend on the filesystem's listing order.
for fname in tqdm(sorted(os.listdir(path))):
    # Outside adversarial mode only the first run is used and the
    # randomly-initialised baseline is left out.
    if not adverserial and ('run0' not in fname or 'random' in fname):
        continue
    try:
        frames.append(pd.read_pickle(os.path.join(path, fname)))
    except Exception as exc:
        # Best effort: an unreadable entry (sub-directory, stray file, corrupt
        # pickle) is skipped, but it is recorded and reported below instead of
        # being silently dropped. KeyboardInterrupt is no longer swallowed.
        skipped.append((fname, exc))

if skipped:
    print(f'Skipped {len(skipped)} unreadable file(s) in {path}:')
    for fname, exc in skipped:
        print(f'  {fname}: {exc!r}')

# Concatenate once (linear) rather than growing the frame inside the loop
# (quadratic). `pd.concat` rejects an empty list, hence the fallback.
df = pd.concat(frames) if frames else pd.DataFrame()

# --- Human-readable class names, keyed by integer label -----------------------
if dataset=='cifar10':
    classes = dict(enumerate([
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]))
else:
    classes = {digit: f'Digit={digit}' for digit in range(10)}

# Generate results

In [None]:
# Keys under which each trained model appears in the results; the
# 'random_init' baseline is intentionally not part of this list.
models = [
    f'{prefix}_{dataset}'
    for prefix in ('xent', 'tripent', 'trip_sup', 'ntxent', 'trip')
]

# Display names, positionally aligned with `models`. The trailing 'Random'
# entry has no partner in `models` and is dropped by zip() below.
models_nice_name = [
    'Cross-Entropy',
    'Triplet-Entropy',
    'Triplet-Supervised',
    'NT-XENT',
    'Triplet-Loss',
    'Random'
]

# One hex colour per model, again aligned positionally with `models`.
cols = ['#003f5c', '#5ab81c', '#b853ae', '#b83014', '#3f94b8']

colors = dict(zip(models, cols))
nice_names = dict(zip(models, models_nice_name))

# Candidate x-axis values for the two experiment types.
epsilons = np.linspace(0, 1, num=100)
pgds = [*range(30)]

# Number of steps, x-axis caption and metadata columns depend on whether the
# adversarial results or the epsilon sweep are being analysed.
n, xlabel, meta_data = (
    (30, 'PGD Iterations', ['pgd_iterations', 'model', 'image_index', 'label'])
    if adverserial else
    (100, 'Epsilon', ['epsilon', 'model', 'image_index', 'label'])
)
all_distance_matrices = torch.zeros((1, 10, n, n))

In [None]:
# NOTE(review): the centroid plotting cells further down read
# `centroid_label_based` and `centroid_based`, which are only produced by the
# two disabled calls kept here for reference:
#   centroid_label_based = return_distances_label_wise(df, models, 10, n, meta_data, 2, embedding_dim, True)
#   centroid_based = return_distances(df, models, n, meta_data, 2, embedding_dim, True)
# NOTE(review): the point plotting cells read `points_label_based` and
# `point_based`, which this cell does not assign either -- confirm how they map
# onto the names below.

# Summaries broken down per label (first element is a plotting frame,
# second the aggregated frame displayed and plotted below).
plot_label_based, label_based = return_distances_label_wise(
    df, models, 10, n, meta_data, 1, embedding_dim, False
)

# Summaries pooled over all labels.
plot_points_based, points_based = return_distances(
    df, models, n, meta_data, 2, embedding_dim, False
)

# Drop the raw results; re-run the loading cell before re-running this one.
del df

In [None]:
# Show the label-wise summary frame. The plotting cell further down reads its
# `model`, `label`, `total_distance_moved_average` and
# `total_distance_moved_std_error` columns.
label_based

In [None]:
# Show the pooled summary frame. The next cell plots its
# `total_distance_moved_average` and `total_distance_moved_std_error` columns
# against `model`.
points_based

In [None]:
# Average total distance moved for each model, with a +/- one standard-error
# band drawn in the same colour as the model's line.
plt.figure(figsize=(15,10))
for model in points_based.model.unique():
    data = points_based[points_based.model==model]
    spikes = data.total_distance_moved_average.values
    std = data.total_distance_moved_std_error.values
    (line,) = plt.plot(
        data.model,
        data.total_distance_moved_average,
        label=nice_names[model],
        marker='o',
    )
    # alpha must lie in [0, 1]; the previous value of 5 makes matplotlib raise
    # ValueError. 0.06 matches the band opacity used in the per-label plot.
    plt.fill_between(
        data.model,
        (spikes-std),
        (spikes+std),
        color=line.get_color(),
        alpha=.06,
    )

plt.ylabel('Total spike measurement average')
# The x axis holds model keys (data.model), not class labels.
plt.xlabel('Model')
plt.legend()
plt.show()

In [None]:
# Average total distance moved per class label: one line per model, each with
# a faint +/- one standard-error band in the model's colour.
fig, ax = plt.subplots(figsize=(15, 10))
for model_key in label_based.model.unique():
    per_model = label_based.loc[label_based.model == model_key]
    centre = per_model.total_distance_moved_average.values
    half_width = per_model.total_distance_moved_std_error.values
    shade = colors[model_key]

    ax.plot(
        per_model.label,
        per_model.total_distance_moved_average,
        label=nice_names[model_key],
        marker='o',
        c=shade,
    )
    ax.fill_between(
        per_model.label,
        centre - half_width,
        centre + half_width,
        color=shade,
        alpha=.06,
    )

ax.set_ylabel('Total spike measurement average over label')
ax.set_xlabel('Label')
ax.legend()
plt.show()

# Plot centroid results

In [None]:
# NOTE(review): `centroid_label_based` only appears earlier in the notebook in
# a commented-out assignment; it must be computed before this cell can run.

def _plot_centroid_label_grid(frame, reference, ecdf=False):
    """Draw one 2x5 grid (one panel per class label) of centroid distances.

    Parameters
    ----------
    frame : DataFrame with `model`, `label` and the `<reference>_mean` /
        `<reference>_pct_change_mean` columns.
    reference : 'original' or 'previous' -- which centroid the distance is
        measured against; selects the column and the y-axis caption.
    ecdf : if True draw the empirical CDF of the values (log-scaled data
        axis); otherwise draw the values against the perturbation steps.

    Reads the notebook-level `plot_pct_change`, `adverserial`, `pgds`,
    `epsilons`, `xlabel`, `classes` and `nice_names`.
    """
    column = f'{reference}_pct_change_mean' if plot_pct_change else f'{reference}_mean'
    quantity = 'Percentage change in distance' if plot_pct_change else 'Distance'
    steps = pgds if adverserial else epsilons

    plt.figure(figsize=(45, 40))
    for label in tqdm(frame.label.unique()):
        # Assumes integer labels 0..9 so that each lands in its own panel.
        plt.subplot(2, 5, label + 1)
        for model in frame.model.unique():
            data = frame[(frame.model == model) & (frame.label == label)]
            if data.empty:
                # steps[-0:] would be the *whole* axis and no longer match
                # the (empty) y values, so skip missing combinations.
                continue
            legend_name = nice_names.get(model, model)
            if ecdf:
                sns.ecdfplot(y=data[column], label=legend_name, log_scale=True)
            else:
                plt.errorbar(
                    # Right-align shorter series with the end of the step axis.
                    steps[-len(data):],
                    data[column],
                    yerr=0,
                    label=legend_name,
                    fmt='-o'
                )

        plt.title(f'{classes[label]}')
        if label == 4:
            plt.legend(prop={'size': 25}, loc=7)
        if label == 0 or label == 5:
            plt.ylabel(f'{quantity} compared to {reference} centroid')
        # With y=..., the ECDF puts the cumulative proportion on the x axis.
        plt.xlabel('Proportion' if ecdf else xlabel)
    plt.show()


plt.rcParams.update({'font.size': 30})
# Same order as before: step plots (original, previous), then ECDFs.
for use_ecdf in (False, True):
    for reference in ('original', 'previous'):
        _plot_centroid_label_grid(centroid_label_based, reference, ecdf=use_ecdf)

In [None]:
# NOTE(review): `centroid_based` only appears earlier in the notebook in a
# commented-out assignment; it must be computed before this cell can run.

def _plot_centroid_overall(frame, reference, ecdf=False):
    """Draw one figure of centroid distances pooled over all labels.

    Parameters
    ----------
    frame : DataFrame with `model` and the `<reference>_mean` /
        `<reference>_pct_change_mean` columns.
    reference : 'original' or 'previous' -- which centroid the distance is
        measured against; selects the column and the y-axis caption.
    ecdf : if True draw the empirical CDF of the values; otherwise draw the
        values against the perturbation steps.

    Reads the notebook-level `plot_pct_change`, `adverserial`, `pgds`,
    `epsilons`, `xlabel` and `nice_names`.
    """
    column = f'{reference}_pct_change_mean' if plot_pct_change else f'{reference}_mean'
    quantity = 'Percentage change in distance' if plot_pct_change else 'Distance'
    steps = pgds if adverserial else epsilons

    plt.figure(figsize=(18, 10))
    for model in frame.model.unique():
        data = frame[frame.model == model]
        legend_name = nice_names.get(model, model)
        if ecdf:
            sns.ecdfplot(y=data[column], label=legend_name, log_scale=False)
        else:
            plt.errorbar(
                # Right-align shorter series with the end of the step axis.
                steps[-len(data):],
                data[column],
                yerr=0,
                label=legend_name,
                fmt='-o'
            )

    plt.legend()
    plt.ylabel(f'{quantity} compared to {reference} centroid')
    # With y=..., the ECDF puts the cumulative proportion on the x axis.
    plt.xlabel('Proportion' if ecdf else xlabel)
    plt.show()


plt.rcParams.update({'font.size': 15})
# Same order as before: step plots (original, previous), then ECDFs.
for use_ecdf in (False, True):
    for reference in ('original', 'previous'):
        _plot_centroid_overall(centroid_based, reference, ecdf=use_ecdf)

# Plot point results

In [None]:
# NOTE(review): `points_label_based` is not assigned anywhere in the visible
# notebook; confirm which frame from the distance-computation cell it refers to.

def _plot_point_label_grid(frame, reference, ecdf=False):
    """Draw one 2x5 grid (one panel per class label) of point-wise distances.

    Parameters
    ----------
    frame : DataFrame with `model`, `label` and the `<reference>_mean` /
        `<reference>_pct_change_mean` columns.
    reference : 'original' or 'previous' -- which point the distance is
        measured against; selects the column and the y-axis caption.
    ecdf : if True draw the empirical CDF of the values (log-scaled data
        axis); otherwise draw the values against the perturbation steps.

    Reads the notebook-level `plot_pct_change`, `adverserial`, `dataset`,
    `pgds`, `epsilons`, `xlabel`, `classes` and `nice_names`.
    """
    column = f'{reference}_pct_change_mean' if plot_pct_change else f'{reference}_mean'
    quantity = 'Percentage change in distance' if plot_pct_change else 'Distance'
    steps = pgds if adverserial else epsilons
    # Title follows the configured experiment instead of a fixed string.
    dataset_name = {'cifar10': 'Cifar 10', 'mnist': 'MNIST'}.get(dataset, dataset)
    experiment = 'PGD' if adverserial else 'Noise'

    plt.figure(figsize=(45, 40))
    plt.suptitle(f'{dataset_name} - {experiment}')
    for label in tqdm(frame.label.unique()):
        # Assumes integer labels 0..9 so that each lands in its own panel.
        plt.subplot(2, 5, label + 1)
        for model in frame.model.unique():
            if model == 'random_init':
                continue
            data = frame[(frame.model == model) & (frame.label == label)]
            if data.empty:
                # steps[-0:] would be the *whole* axis and no longer match
                # the (empty) y values, so skip missing combinations.
                continue
            if ecdf:
                sns.ecdfplot(y=data[column], label=nice_names[model], log_scale=True)
            else:
                plt.errorbar(
                    # Right-align shorter series with the end of the step axis.
                    steps[-len(data):],
                    data[column],
                    label=nice_names[model],
                    fmt='-o'
                )

        plt.title(f'{classes[label]}')
        if label == 4:
            plt.legend(prop={'size': 25}, loc=7)
        if label == 0 or label == 5:
            plt.ylabel(f'{quantity} compared to {reference} point')
        # With y=..., the ECDF puts the cumulative proportion on the x axis.
        plt.xlabel('Proportion' if ecdf else xlabel)
    plt.show()


plt.rcParams.update({'font.size': 30})
# Same order as before: step plots (original, previous), then ECDFs.
# The "previous" step plot now shows the mean like its siblings; it used to
# plot `previous_std` under a "distance" caption.
for use_ecdf in (False, True):
    for reference in ('original', 'previous'):
        _plot_point_label_grid(points_label_based, reference, ecdf=use_ecdf)

In [None]:
# NOTE(review): `point_based` is not assigned anywhere in the visible
# notebook; confirm which frame from the distance-computation cell it refers to.

def _plot_point_overall(frame, reference, ecdf=False):
    """Draw one figure of point-wise distances pooled over all labels.

    Parameters
    ----------
    frame : DataFrame with `model` and the `<reference>_mean` / `_std` (or
        `<reference>_pct_change_mean` / `_pct_change_std`) columns.
    reference : 'original' or 'previous' -- which point the distance is
        measured against; selects the columns and the y-axis caption.
    ecdf : if True draw the empirical CDF of the mean values; otherwise draw
        the mean (solid, markers) and the standard deviation (dashed) against
        the perturbation steps.

    Reads the notebook-level `plot_pct_change`, `adverserial`, `dataset`,
    `pgds`, `epsilons`, `xlabel`, `nice_names` and `colors`.
    """
    stat = 'pct_change_' if plot_pct_change else ''
    mean_col = f'{reference}_{stat}mean'
    std_col = f'{reference}_{stat}std'
    quantity = 'Percentage change in distance' if plot_pct_change else 'Distance'
    steps = pgds if adverserial else epsilons
    # Title follows the configured experiment instead of a fixed string.
    dataset_name = {'cifar10': 'Cifar 10', 'mnist': 'MNIST'}.get(dataset, dataset)
    experiment = 'PGD' if adverserial else 'Noise'

    plt.figure(figsize=(18, 10))
    plt.title(f'{dataset_name} - {experiment}')
    for model in frame.model.unique():
        if model == 'random_init':
            continue
        data = frame[frame.model == model]
        if ecdf:
            sns.ecdfplot(
                y=data[mean_col],
                label=nice_names[model],
                log_scale=False,
                color=colors[model],
            )
            continue

        # Right-align shorter series with the end of the step axis.
        x_values = steps[-len(data):]
        plt.errorbar(
            x_values,
            data[mean_col],
            label=nice_names[model],
            fmt='-o',
            c=colors[model]
        )
        # Dashed companion curve; the ' STD' suffix keeps its legend entry
        # distinguishable from the mean curve of the same model.
        plt.errorbar(
            x_values,
            data[std_col],
            label=nice_names[model] + ' STD',
            fmt='--',
            c=colors[model]
        )

    plt.legend(prop={'size': 12})
    plt.ylabel(f'{quantity} compared to {reference} point')
    # With y=..., the ECDF puts the cumulative proportion on the x axis.
    plt.xlabel('Proportion' if ecdf else xlabel)
    plt.show()


plt.rcParams.update({'font.size': 15})
# Same order as before: step plots (original, previous), then ECDFs.
for use_ecdf in (False, True):
    for reference in ('original', 'previous'):
        _plot_point_overall(point_based, reference, ecdf=use_ecdf)