# Analysis of Split ECG latent space for Model Poisoning Detection

In [146]:
import numpy as np
import os
from torch.nn.functional import max_pool1d, avg_pool1d
import torch
import pickle
import pandas as pd
import scipy.spatial as sp
from scipy.interpolate import make_interp_spline, BSpline
from scipy.special import kl_div, rel_entr
from scipy.spatial.distance import jensenshannon
from sklearn.manifold import TSNE
from sklearn.neighbors import KNeighborsClassifier, KernelDensity
from sklearn.decomposition import PCA
from functools import partial
import multiprocessing
import matplotlib.pyplot as plt
from matplotlib import ticker
import matplotlib as mpl
from contextlib import closing
import itertools
#from tqdm.notebook import tqdm
from tqdm import tqdm
import math
from server.security.analysis import *
import client.utils as utils
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Density Difference Analysis

In [125]:
# Experiment configuration for the density analysis below.
similarity = "euclidean"
epochs = 30  # epochs are loaded as 1..epochs (inclusive)
num_clients = 10
# Root directory of one experiment's dumped latent space. Override it with the
# SPLIT_ECG_LATENT_DIR environment variable instead of editing this path.
base_path = os.environ.get(
    "SPLIT_ECG_LATENT_DIR",
    "/home/mohkoh/projects/Split_ECG_Classification/latent_space/N=10_M=8_type=LF_p=0.25",
)
# NOTE: pickle.load can execute arbitrary code -- only load trusted experiment dumps.
# The context manager closes the file handle (the previous open() was never closed).
with open(os.path.join(base_path, "metadata.pickle"), "rb") as metadata_file:
    metadata = pickle.load(metadata_file)
assert metadata["num_clients"] == num_clients, "num_clients does not match the experiment metadata"

In [126]:
# Display the experiment metadata: client count, per-client malicious flags and poisoning probabilities.
metadata

{'num_clients': 10,
 'exp_name': 'N=10_M=8_type=LF_p=0.25',
 'is_malicious': {0: False,
  1: True,
  2: True,
  3: True,
  4: True,
  5: False,
  6: True,
  7: True,
  8: True,
  9: True},
 'batchsize': 64,
 'data_poisoning_prob': 0.0,
 'label_flipping_prob': 0.25}

In [127]:
# Gather every client's per-epoch latent-space dump into one frame,
# tagging each row with the id of the client it came from.
frames = []
for idx in tqdm(range(num_clients), desc=f"Load Data Frames"):
    # The client directory does not depend on the epoch, so build it once per client.
    client_path = os.path.join(base_path, "client_" + str(idx))
    for epoch in range(1, epochs + 1):
        df = pd.read_pickle(os.path.join(client_path, "epoch_{}.pickle".format(epoch)))
        df["client_id"] = idx
        frames.append(df)
# One concat at the end avoids re-copying the growing frame on every iteration
# (quadratic cost); fall back to an empty frame when nothing was loaded.
df_base = pd.concat(frames, axis=0, ignore_index=True) if frames else pd.DataFrame()

Load Data Frames:   0%|          | 0/10 [00:00<?, ?it/s]

In [184]:
def _kde_density(train, points, kernel, bandwidth):
    """Fit a KernelDensity on ``train`` and return its density (not log-density) at ``points``."""
    estimator = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(train)
    return np.exp(estimator.score_samples(points))


def density_diff(df_base, client_id, epoch, split="knn_10_euclidean", pooling=None, k=2.0, kernel="gaussian", bandwidth="scott", pca=True):
    """Compare the latent density of class ``k`` with that of all samples of one client/epoch.

    Selects the rows of ``df_base`` belonging to ``client_id`` at ``epoch``, relabels
    them with ``split_labels(..., split)``, pools their ``client_output`` vectors with
    ``pool_latent_vectors`` and fits two kernel density estimates:

    * ``q`` -- KDE fitted on all pooled vectors of the client/epoch,
    * ``p`` -- KDE fitted only on the vectors whose ``label`` equals ``k``.

    Both densities are evaluated at the class-``k`` vectors and compared with
    ``scipy.special.rel_entr`` / ``kl_div`` in both directions.

    Args:
        df_base: Frame with at least ``client_id``, ``epoch`` and ``client_output`` columns.
        client_id: Client to analyse.
        epoch: Epoch to analyse.
        split: Label-splitting scheme forwarded to ``split_labels``.
        pooling: Pooling mode forwarded to ``pool_latent_vectors`` (e.g. ``"avg"``).
        k: Label value selecting the class of interest.
        kernel: ``KernelDensity`` kernel name.
        bandwidth: ``KernelDensity`` bandwidth (a float or a rule name such as ``"scott"``).
        pca: If True, project the vectors with a PCA fitted on all client vectors
            before density estimation.

    Returns:
        Tuple ``(re, kl, sre, skl)``: summed ``rel_entr(p, q)``, ``kl_div(p, q)``,
        ``rel_entr(q, p)`` and ``kl_div(q, p)``.

    Raises:
        ValueError: If no rows match ``client_id``/``epoch``, or none carry label ``k``.
    """
    subset = df_base[(df_base["client_id"] == client_id) & (df_base["epoch"] == epoch)]
    if subset.empty:
        raise ValueError(f"no latent vectors for client {client_id} at epoch {epoch}")
    subset = split_labels(subset, split)
    subset_k = subset[subset.label == k]
    if subset_k.empty:
        raise ValueError(f"no samples with label {k} for client {client_id} at epoch {epoch}")

    # High-dimensional (pooled) latent vectors.
    hd_all = pool_latent_vectors(subset.client_output.to_list(), pooling=pooling)
    hd_k = pool_latent_vectors(subset_k.client_output.to_list(), pooling=pooling)

    if pca:
        # n_components='mle' lets PCA choose the number of components itself. The
        # projection is fitted on all vectors so both sets share one basis.
        pca_model = PCA(n_components="mle").fit(hd_all)
        ld_all = pca_model.transform(hd_all)
        ld_k = pca_model.transform(hd_k)
    else:
        ld_all, ld_k = hd_all, hd_k

    # q: density of all client vectors; p: density of the class-k vectors.
    # Both are evaluated at the class-k points.
    q = _kde_density(ld_all, ld_k, kernel, bandwidth)
    p = _kde_density(ld_k, ld_k, kernel, bandwidth)

    # NOTE(review): p and q are raw density values at sample points, not normalised
    # probability vectors. These sums are therefore not true KL divergences, and the
    # rel_entr sums can be negative. Consider normalising p/q, or using
    # mean(log p - log q) as a Monte Carlo estimate of KL(p || q).
    kl = kl_div(p, q).sum()
    re = rel_entr(p, q).sum()

    skl = kl_div(q, p).sum()
    sre = rel_entr(q, p).sum()

    return re, kl, sre, skl

In [185]:
# Class label values evaluated one by one as `k` in the density comparison.
labels = [2.0, 16.0, 4.0, 8.0, 1.0]

In [188]:
# For a single epoch, print the four divergence scores of every label class, client by client.
epoch = 2
rule = "-------------------"
for client_id in range(num_clients):
    print(rule)
    print(f"ID: {client_id}")
    for k in labels:
        dd = density_diff(df_base, client_id, epoch, k=k, bandwidth=.4, pca=False, pooling="avg")
        rel, kld, srel, skld = dd
        print(f"--- Class: {k}, RE: {rel:.2f}, KLD: {kld:.2f}, SRE: {srel:.2f}, SKLD: {skld:.2f}")
    print(rule + "\n")

-------------------
ID: 0
--- Class: 2.0, RE: 0.56, KLD: 0.16, SRE: -0.27, SKLD: 0.13
--- Class: 16.0, RE: 1.71, KLD: 1.02, SRE: -0.19, SKLD: 0.50
--- Class: 4.0, RE: 1.32, KLD: 0.69, SRE: -0.24, SKLD: 0.39
--- Class: 8.0, RE: 2.85, KLD: 2.09, SRE: -0.07, SKLD: 0.68
--- Class: 1.0, RE: 1.27, KLD: 0.65, SRE: -0.25, SKLD: 0.38
-------------------

-------------------
ID: 1
--- Class: 2.0, RE: 0.51, KLD: 0.13, SRE: -0.26, SKLD: 0.11
--- Class: 16.0, RE: 1.73, KLD: 1.03, SRE: -0.19, SKLD: 0.51
--- Class: 4.0, RE: 1.38, KLD: 0.74, SRE: -0.23, SKLD: 0.41
--- Class: 8.0, RE: 2.80, KLD: 2.05, SRE: -0.08, SKLD: 0.68
--- Class: 1.0, RE: 1.35, KLD: 0.71, SRE: -0.24, SKLD: 0.40
-------------------

-------------------
ID: 2
--- Class: 2.0, RE: 0.52, KLD: 0.14, SRE: -0.27, SKLD: 0.11
--- Class: 16.0, RE: 1.76, KLD: 1.06, SRE: -0.18, SKLD: 0.51
--- Class: 4.0, RE: 1.38, KLD: 0.73, SRE: -0.23, SKLD: 0.41
--- Class: 8.0, RE: 2.79, KLD: 2.03, SRE: -0.08, SKLD: 0.68
--- Class: 1.0, RE: 1.30, KLD: 0.67, 