In [1]:
import os
import sys
from pathlib import Path

os.chdir(Path(os.getcwd()).parents[0])
sys.path.append(os.getcwd())

import glob
import numpy as np
import torch
import pandas as pd
from scipy.stats import pearsonr

def left_align_facet_plot_titles(fig):

    ## figure out number of columns in each facet
    facet_col_wrap = len(np.unique([a['x'] for a in fig.layout.annotations]))

    # x x x x
    # x x x x <-- then these annotations
    # x x     <-- these annotations are created first

    ## we need to know the remainder
    ## because when we iterate through annotations
    ## they need to know what column they are in 
    ## (and annotations natively contain no such information)

    remainder = len(fig.data) % facet_col_wrap
    number_of_full_rows = len(fig.data) // facet_col_wrap

    annotations = fig.layout.annotations

    xaxis_col_strings = list(range(1, facet_col_wrap+1))
    xaxis_col_strings[0] = ''
    x_axis_start_positions = [fig.layout[f'xaxis{i}']['domain'][0] for i in xaxis_col_strings]

    if remainder == 0:
        x_axis_start_positions_iterator = x_axis_start_positions*number_of_full_rows
    else:
        x_axis_start_positions_iterator = x_axis_start_positions[:remainder] + x_axis_start_positions*number_of_full_rows

    for a, x in zip(annotations, x_axis_start_positions_iterator):
        a['x'] = x
        a['xanchor'] = 'left'
    fig.layout.annotations = annotations
    return fig

  from .autonotebook import tqdm as notebook_tqdm


### Import Eval Results

In [104]:
# File loading per dataset

file_image_inet = "/Image/eval_imagenet_dataset.npz"
file_image_oct = "/Image/eval_oct_dataset.npz"
file_image_r45 = "/Image/eval_resisc45_dataset.npz"

file_voxel_adr = "/Voxel/eval_AdrenalMNIST3D_dataset.npz"
file_voxel_org = "/Voxel/eval_OrganMNIST3D_dataset.npz"
file_voxel_ves = "/Voxel/eval_VesselMNIST3D_dataset.npz"

file_pc_coma = "/Point_Cloud/eval_coma_dataset.npz"
file_pc_m40 = "/Point_Cloud/eval_modelnet40_dataset.npz"
file_pc_shpn = "/Point_Cloud/eval_shapenet_dataset.npz"

file_loc = os.getcwd() + "/data/evaluation"

file = np.load(file_loc + file_image_inet, allow_pickle=True)
arr_image_inet = [file['arr_0'], file['arr_1'], file['arr_2']]
file = np.load(file_loc + file_image_oct, allow_pickle=True)
arr_image_oct = [file['arr_0'], file['arr_1'], file['arr_2']]
file = np.load(file_loc + file_image_r45, allow_pickle=True)
arr_image_r45 = [file['arr_0'], file['arr_1'], file['arr_2']]

file = np.load(file_loc + file_voxel_adr, allow_pickle=True)
arr_voxel_adr = [file['arr_0'], file['arr_1'], file['arr_2']]
file = np.load(file_loc + file_voxel_org, allow_pickle=True)
arr_voxel_org = [file['arr_0'], file['arr_1'], file['arr_2']]
file = np.load(file_loc + file_voxel_ves, allow_pickle=True)
arr_voxel_ves = [file['arr_0'], file['arr_1'], file['arr_2']]

file = np.load(file_loc + file_pc_coma, allow_pickle=True)
arr_pc_coma = [file['arr_0'], file['arr_1'], file['arr_2']]
file = np.load(file_loc + file_pc_m40, allow_pickle=True)
arr_pc_m40 = [file['arr_0'], file['arr_1'], file['arr_2']]
file = np.load(file_loc + file_pc_shpn, allow_pickle=True)
arr_pc_shpn = [file['arr_0'], file['arr_1'], file['arr_2']]

### Visualize Eval Results

In [116]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
colors = px.colors.qualitative.G10

data = arr_pc_shpn

def NormalizeData(data, min, max):
    return (data - min) / ((max - min) + 0.00000000001)

titles = ["<b>[F]</b> Faithfulness Correlation \u2191", "[F] Faithfulness Estimate \u2191", "[F] Monotonicity Correlation \u2191", "[F] Pixel Flipping (AUC) \u2193", "[F] Region Perturbation (AUC) \u2191", "[F] Insertion \u2191", "[F] Deletion \u2193", "[F] IROF (AOC) \u2191", "[F] ROAD (AUC) \u2193", "[F] Sufficiency \u2191", "<b>[R]</b> Local Lipschitz Estimate \u2193", 
           "[R] Max Sensitivity \u2193", "[R] Continuity (PCC) \u2191", "[R] Relative Input Stability \u2193", "[R] Relative Output Stability \u2193", "[R] Relative Repr. Stability \u2193", "[R] Infidelity \u2193", "<b>[C]</b> Sparseness \u2191", "[C] Complexity \u2193", "[C] Effective Complexity \u2193"]

methods = ["OC","LI","KS","SA","IxG", "GB","GC","SC","C+", "IG", "EG", "DL", "DLS", "LRP", "RA", "RoA", "LA"] if data[0].shape[0] >= 14 else ["OC","LI","KS","SA","IxG", "GB", "IG", "EG", "DL", "DLS", "LRP", "RA", "RoA", "LA"]

fig = make_subplots(
    rows=5,
    cols=4,
    vertical_spacing = 0.05,
    horizontal_spacing= 0.03,
    subplot_titles=titles,

)

plot_row = [1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5]
plot_col = [1,2,3,4,1,2,3,4,1,2,3,4,1,2,3,4,1,2,3,4]
plot_x = ["attr","attr","attr","attr","attr","attr","attr","attr","attr","attr","attr","attr","attr","attr","atten","atten","atten",] if data[0].shape[1] == 20 else ["attr","attr","attr","attr","attr","attr","attr","attr","attr","attr","attr","atten","atten","atten",]

for i in range(20): # per eval
    d = np.vstack([data[0][:,i,:], data[1][:,i,:], data[2][(data[2].shape[0] - 3):data[2].shape[0],i,:]]).flatten()
    q_h = np.quantile(d, 0.9)
    q_l = np.quantile(d, 0.1)
    d = np.clip(d, q_l, q_h)
    d_max = d.max()
    d_min = d.min()

    data[0][:,i,:] = NormalizeData(np.clip(data[0][:,i,:], q_l, q_h),d_min,d_max)
    data[1][:,i,:] = NormalizeData(np.clip(data[1][:,i,:], q_l, q_h),d_min,d_max)
    data[2][(data[2].shape[0] - 3):data[2].shape[0],i,:] = NormalizeData(np.clip(data[2][(data[2].shape[0] - 3):data[2].shape[0],i,:], q_l, q_h),d_min,d_max)

    for j in range(data[0].shape[0]): # per attribution
        fig.add_trace(go.Box(y=np.concatenate((data[0][j,i,:], data[1][j,i,:])),name = methods[j], marker_color=colors[0], showlegend=False), # model, explain, eval, n
            row=plot_row[i],
            col=plot_col[i],
        )        
    for j in range(data[2].shape[0] - 3,data[2].shape[0]): # per attention
        fig.add_trace(go.Box(y=data[2][j,i,:],name = methods[j], marker_color=colors[2], showlegend=False), # model, explain, eval, n
            row=plot_row[i],
            col=plot_col[i],
        )

    fig.add_hline(y=np.median(np.concatenate((data[0][:,i,:], data[1][:,i,:]))),x0=0, x1=(1/17)*14, line_dash="dot", row=plot_row[i], col=plot_col[i], line_color="#000000", line_width=2)
    fig.add_hline(y=np.median(data[2][(data[2].shape[0] - 3) : data[2].shape[0],i,:]),x0=(1/17)*14, x1=1, line_dash="dot", row=plot_row[i], col=plot_col[i], line_color="#000000", line_width=2)

fig.update_layout(
    height=1000,
    width=2000,
    margin=dict(t=60, b=10, r=10, l=10),
    font=dict(
        family="Helvetica",
        color="#000000",
    ),
    title_font=dict(
        family="Helvetica",
        color="#000000",
    ),
    title={
        'text': "Evaluation Score Distributions for ShapeNet Dataset per XAI Method and grouped into Attribution and Attention",
        # 'y':0.9,
        'x':0.012,
        }

)

fig = left_align_facet_plot_titles(fig)
fig.update_annotations(font_size=12)
fig.write_image(os.getcwd() + "/data/figures/eval_distr/shapenet.png", scale=2)
fig.show()


### Ranking Computation

In [3]:
arr_image = [arr_image_inet, arr_image_oct, arr_image_r45]
arr_voxel = [arr_voxel_adr, arr_voxel_org, arr_voxel_ves]
arr_pc = [arr_pc_coma, arr_pc_m40, arr_pc_shpn]
arr_modalities = [arr_image,arr_voxel,arr_pc]

arr_ranking = np.empty([3,3,3,17,20], dtype=float) #modality, dataset, model, xai, eval
arr_ranking[:] = np.nan

bup_order = [0,1,2,4,5,7,9,12,17]

for modality in range(3):
    for dataset in range(3):
        for model in range(3):
            for xai in range(arr_modalities[modality][dataset][model].shape[0]):
                for eval in range(3):
                    ranking = np.median(arr_modalities[modality][dataset][model][:,eval,:],-1).argsort() # compute ranking based on median obs score
                    if eval in bup_order:
                        ranking = ranking[::-1] # reverse ranking to bottom up if larger is better

                    pos = ranking.argsort()[xai] + 1 # get rankin position of xai method (+1 so ranking starts at 1 and not 0)
                    arr_ranking[modality,dataset,model,xai,eval] = pos

### Ranking Table

In [9]:
import pandas as pd

arr_table = [] 
for eval in [(0,9),(9,16),(16,19)]:
    for modality in range(3):
        for dataset in range(3):
                arr_col_val = []
                arr_col_std = []
                for xai in range(17):
                    if modality == 2 and xai == 6:
                        arr_col_val = arr_col_val + ["-","-","-"]
                        arr_col_std = arr_col_std + [" "," "," "]
                    if modality == 2 and xai == 14:
                         break        
                    x = arr_ranking[modality,dataset,:,xai,eval[0]:eval[1]]
                    val = np.round(np.mean(x[~np.isnan(x)]))
                    std = np.round(np.std(x[~np.isnan(x)]),2)
                    if not np.isnan(val):
                        val = int(val)
                    else:
                         val = "-"
                         std = "-"
                    arr_col_val.append(val)
                    arr_col_std.append("±" + str(std))
                arr_table.append(arr_col_val)
                arr_table.append(arr_col_std)

df_table = pd.DataFrame(arr_table).transpose()
df_table.index = ["OC","LI","KS","SA","IxG", "GB","GC","SC","C+", "IG", "EG", "DL", "DLS", "LRP", "RA", "RoA", "LA"]
from IPython.display import display, HTML
display(HTML(df_table.to_html()))
df_table.to_csv("table.csv", encoding='utf-8', index=False, header=False)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean, casting='unsafe',
  ret = ret.dtype.type(ret / rcount)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53
OC,7,±4.94,8,±4.37,7,±3.46,8,±3.09,4,±2.71,7,±2.67,5,±5.09,2,±0.94,8,±4.9,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-
LI,11,±3.19,13,±1.17,10,±4.03,10,±3.0,12,±3.27,11,±2.18,5,±3.39,5,±3.69,5,±3.77,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-
KS,12,±2.04,11,±2.26,10,±1.93,9,±2.82,11,±1.85,11,±1.63,7,±2.79,7,±1.69,6,±3.0,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-
SA,7,±2.78,9,±4.62,7,±2.36,12,±2.31,8,±5.07,11,±4.32,9,±2.62,8,±1.73,9,±2.08,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-
IxG,7,±5.26,5,±4.61,10,±6.02,3,±1.83,4,±3.2,3,±1.59,6,±3.16,5,±2.33,5,±2.11,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-
GB,6,±4.97,8,±3.92,7,±6.09,10,±4.54,6,±3.43,7,±4.95,6,±3.51,4,±0.92,6,±1.83,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-
GC,7,±4.28,6,±4.31,6,±2.67,11,±2.7,10,±3.33,12,±2.6,-,,-,,-,,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,,-,,-,,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,,-,,-,
SC,8,±3.77,8,±3.74,9,±4.17,7,±3.51,11,±2.91,7,±4.61,-,,-,,-,,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,,-,,-,,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,,-,,-,
C+,7,±4.78,10,±3.98,7,±4.21,11,±4.76,11,±1.7,13,±1.75,-,,-,,-,,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,,-,,-,,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,,-,,-,
IG,8,±3.71,6,±4.52,10,±5.08,3,±1.94,6,±3.65,5,±1.63,5,±1.41,7,±2.26,6,±1.79,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-,-,±-


### Ranking Visualization

In [13]:
arr_table = [] 
for eval in range(20):
    for modality in range(3):
        for dataset in range(3):
                    arr_col_val = []
                    for xai in range(17):
                        if modality == 2 and xai == 6:
                            arr_col_val = arr_col_val + [np.round(np.mean(arr_ranking[(0,1),:,:,6,eval])),np.round(np.mean(arr_ranking[(0,1),:,:,7,eval])),np.round(np.mean(arr_ranking[(0,1),:,:,8,eval]))]
                        if modality == 2 and xai == 14:
                            break   
                        x = arr_ranking[modality,dataset,:,xai,eval]
                        val = np.round(np.mean(x[~np.isnan(x)]))
                        arr_col_val.append(val)
                    arr_table.append(arr_col_val)

df_table = pd.DataFrame(arr_table).transpose()
df_table.index = ["OC","LI","KS","SA","IxG", "GB","GC","SC","C+", "IG", "EG", "DL", "DLS", "LRP", "RA", "RoA", "LA"]

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [14]:
from IPython.display import display, HTML
display(HTML(df_table.to_html()))
df_table.to_csv("table.csv", encoding='utf-8', index=False, header=False)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179
OC,4.0,10.0,5.0,10.0,4.0,8.0,3.0,2.0,4.0,4.0,8.0,5.0,7.0,5.0,7.0,6.0,2.0,9.0,13.0,6.0,11.0,6.0,4.0,7.0,6.0,2.0,10.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
LI,9.0,13.0,6.0,12.0,12.0,12.0,7.0,8.0,7.0,10.0,14.0,10.0,10.0,12.0,10.0,5.0,2.0,4.0,15.0,14.0,14.0,8.0,11.0,10.0,4.0,4.0,5.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
KS,13.0,9.0,9.0,10.0,10.0,12.0,7.0,6.0,5.0,11.0,14.0,9.0,10.0,13.0,11.0,7.0,8.0,6.0,13.0,11.0,12.0,8.0,10.0,9.0,9.0,8.0,6.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
SA,8.0,11.0,6.0,11.0,11.0,15.0,10.0,7.0,8.0,7.0,13.0,6.0,13.0,12.0,13.0,9.0,9.0,9.0,7.0,4.0,8.0,12.0,1.0,6.0,8.0,9.0,9.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
IxG,8.0,3.0,15.0,1.0,2.0,2.0,7.0,7.0,6.0,11.0,2.0,14.0,3.0,2.0,2.0,6.0,4.0,4.0,2.0,10.0,2.0,6.0,7.0,5.0,6.0,5.0,4.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
GB,4.0,6.0,4.0,6.0,4.0,5.0,7.0,4.0,5.0,4.0,6.0,3.0,9.0,4.0,5.0,7.0,4.0,5.0,11.0,11.0,14.0,15.0,10.0,11.0,5.0,3.0,7.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
GC,7.0,8.0,6.0,11.0,11.0,11.0,9.0,9.0,9.0,7.0,8.0,6.0,12.0,12.0,12.0,9.0,9.0,9.0,6.0,2.0,8.0,11.0,8.0,13.0,8.0,8.0,8.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
SC,8.0,8.0,6.0,9.0,11.0,7.0,8.0,8.0,8.0,9.0,8.0,9.0,7.0,10.0,7.0,8.0,8.0,8.0,9.0,8.0,10.0,5.0,12.0,8.0,9.0,9.0,9.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
C+,7.0,11.0,7.0,14.0,12.0,12.0,10.0,10.0,10.0,7.0,10.0,6.0,11.0,10.0,13.0,10.0,10.0,10.0,7.0,9.0,9.0,7.0,10.0,13.0,9.0,9.0,9.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
IG,10.0,3.0,13.0,3.0,4.0,5.0,6.0,9.0,7.0,10.0,4.0,13.0,1.0,4.0,5.0,4.0,8.0,5.0,4.0,12.0,3.0,5.0,9.0,4.0,5.0,5.0,6.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


#### MDS Plot

In [16]:
from sklearn.manifold import MDS
mds = MDS(n_components=2, random_state=0)
X_transformed = mds.fit_transform(df_table.iloc[:,0:27])

import plotly.graph_objects as go
import plotly.express as px
colors = px.colors.qualitative.G10

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=X_transformed[:-3,0],
    y=X_transformed[:-3,1],
    mode='markers+text',
    text = ["OC","LI","KS","SA","IxG", "GB","GC","SC","C+", "IG", "EG", "DL", "DLS", "LRP"],
    textposition="top left",
    name = "Attribution",
    marker = dict(color=colors[0], size = 8)
))

fig.add_trace(go.Scatter(
    x=X_transformed[-3:,0],
    y=X_transformed[-3:,1],
    mode='markers+text',
    text = ["RA", "RoA", "LA"],
    textposition="top left",
    name = "Attention",
    marker = dict(color=colors[2], size = 8)
))

#fig.update_xaxes(range=[-15.2,15.2], tickvals=[-15,-7.5,0,7.5,15], zerolinewidth = 3)
#fig.update_yaxes(range=[-10.1,10.1], zerolinewidth = 3)

fig.update_layout(
    height=500,
    width=500,
    font=dict(
        family="Helvetica",
        color="#000000",
        size = 13
    ),
    template = "plotly_white"
)

fig.write_image(os.getcwd() + "/data/figures/mds_plot.png", scale=2)
fig.show()



