In [1]:
from __init__ import *

In [2]:
import plotly.graph_objects as go
import numpy as np

In [18]:
# Run-wide settings. Only "type_of_intervention" and "num_layer" are read by
# the cells visible in this notebook (they select which prediction pickles are
# loaded); "device" and "batch_size" are not used below.
args = {
    "device": "cuda",
    "batch_size": 1,
    "num_layer": 6,
    # Set to "type2" to analyse the other intervention's predictions.
    "type_of_intervention": "type1",
}

# Index tables for the two intervention types. Every entry is a multiple of
# 256 (256, 512 = 2*256, 768 = 3*256) or None.
# NOTE(review): these look like [start, stop] slice bounds over four
# 256-wide groups — confirm against the intervention code; neither table is
# read by the cells visible in this notebook.
type_1_args = {
    "index1": [0, 256],
    "index2": [256, 512],
    "index3": [512, 768],
    "index4": [768, None],
}

# Four values per entry here, versus two per entry for type 1.
type_2_args = {
    "index1": [256, 768, 768, None],
    "index2": [0, 256, 512, None],
    "index3": [0, 512, 768, None],
    "index4": [0, 768, None, None],
}

In [19]:
import sys

# Put the current working directory on the import path. The membership check
# keeps the cell idempotent: re-running it in a live kernel no longer appends
# a duplicate '.' entry each time.
if '.' not in sys.path:
    sys.path.append('.')

In [20]:
# Per-module predictions and their means, kept separately for the two model
# variants ("prediction_" files vs. "prediction_nmodel_" files).
final_dict_modular = {}
mean_acc_modular = []

final_dict_non_modular = {}
mean_acc_non_modular = []

# NOTE(review): `pkl` is not imported in any visible cell — presumably it is
# provided by the star import at the top of the notebook; confirm.
# SECURITY: unpickling executes arbitrary code, so only load files from a
# trusted source.
for module in ("mod1", "mod2", "mod3", "mod4"):
    # Same order as before: the modular file is read first, then the
    # non-modular one, each feeding its own dict / list of means.
    sources = (
        (
            f"data/prediction_{args['type_of_intervention']}_layer{args['num_layer']}_{module}.pkl",
            final_dict_modular,
            mean_acc_modular,
        ),
        (
            f"data/prediction_nmodel_{args['type_of_intervention']}_layer{args['num_layer']}_{module}.pkl",
            final_dict_non_modular,
            mean_acc_non_modular,
        ),
    )
    for path, store, means in sources:
        with open(path, "rb") as f:
            prediction = pkl.load(f)
        store[module] = prediction
        means.append(np.mean(prediction))

In [21]:
def _nonzero_complement_rows(predictions_by_module):
    """Build per-sample (mod1, mod2, mod3, mod4) tuples of ``1 - prediction``.

    A row is dropped when all four of its entries equal 0; every other row is
    kept as a 4-tuple, in the original sample order.

    NOTE(review): presumably each prediction is a 0/1 flag, so ``1 - p``
    marks the samples where that module's intervention changed the outcome
    — confirm against the code that wrote the pickles.
    """
    # The array of ones is shaped after mod1 and reused for all four modules,
    # matching the original expression.
    ones = np.ones(np.array(predictions_by_module['mod1']).shape)
    columns = [
        ones - predictions_by_module[name]
        for name in ('mod1', 'mod2', 'mod3', 'mod4')
    ]

    kept = []
    for a, b, c, d in zip(*columns):
        if a == 0 and b == 0 and c == 0 and d == 0:
            continue
        kept.append((a, b, c, d))
    return kept


filtered_lists_modular = _nonzero_complement_rows(final_dict_modular)

filtered_lists_non_modular = _nonzero_complement_rows(final_dict_non_modular)

In [28]:
# Sanity check: show the second retained row. Each row is a 4-tuple of
# (1 - prediction) values in mod1..mod4 order.
filtered_lists_modular[1]

(np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(1.0))

In [56]:
# Column-wise total over the retained rows divided by the number of retained
# rows, i.e. a per-module mean (one value each for mod1..mod4), despite the
# `sum_` prefix in the variable names.
sum_modular = (
    np.asarray(filtered_lists_modular).sum(axis=0) / len(filtered_lists_modular)
)
sum_non_modular = (
    np.asarray(filtered_lists_non_modular).sum(axis=0) / len(filtered_lists_non_modular)
)

# Display both vectors as the cell output.
sum_modular, sum_non_modular

(array([0.37903226, 0.46370968, 0.41532258, 0.37903226]),
 array([0.54030501, 0.55337691, 0.5795207 , 0.74291939]))

In [22]:
# Row totals: for every retained sample, add up its four per-module entries.
# Rows whose four entries are all zero were filtered out earlier.
sums_model_modular = np.array([sum(row) for row in filtered_lists_modular])
sums_model_non_modular = np.array([sum(row) for row in filtered_lists_non_modular])

In [31]:
# Bin edges 1, 2, 3, 4, 5 -> four unit-width bins. np.histogram closes the
# last bin on the right, so a row total of 4 is counted in the [4, 5] bin.
# With unit-width bins, density=True yields the fraction of samples per bin.
bins = np.arange(1, 6)
hist_modular, bin_edges_modular = np.histogram(sums_model_modular, bins=bins, density=True)
hist_non_modular, bin_edges_non_modular = np.histogram(sums_model_non_modular, bins=bins, density=True)

# Bootstrapping to estimate error bars.
# A dedicated, seeded generator makes the error bars reproducible across
# kernel restarts and re-runs of this cell (the global np.random state used
# previously was never seeded, so the bars changed on every run).
n_bootstraps = 1000
bootstrap_seed = 0
bootstrap_rng = np.random.default_rng(bootstrap_seed)

bootstrapped_histograms_modular = np.zeros((n_bootstraps, len(hist_modular)))
bootstrapped_histograms_non_modular = np.zeros((n_bootstraps, len(hist_non_modular)))

for b in range(n_bootstraps):
    # Resample each set of row totals with replacement, at its original size.
    sample_modular = bootstrap_rng.choice(sums_model_modular, size=len(sums_model_modular), replace=True)
    sample_non_modular = bootstrap_rng.choice(sums_model_non_modular, size=len(sums_model_non_modular), replace=True)
    boot_hist_modular, _ = np.histogram(sample_modular, bins=bins, density=True)
    boot_hist_non_modular, _ = np.histogram(sample_non_modular, bins=bins, density=True)
    bootstrapped_histograms_modular[b] = boot_hist_modular
    bootstrapped_histograms_non_modular[b] = boot_hist_non_modular

# Standard error: per-bin standard deviation across the bootstrap replicates.
stderr_modular = np.std(bootstrapped_histograms_modular, axis=0)
stderr_non_modular = np.std(bootstrapped_histograms_non_modular, axis=0)

In [54]:
# Grouped bar chart of the row-total histograms (modular vs. non-modular)
# with bootstrap standard-error markers drawn on top of the bars.
fig = go.Figure()

# Lower / upper edge of every bin, per model variant.
lo_m, hi_m = bin_edges_modular[:-1], bin_edges_modular[1:]
lo_n, hi_n = bin_edges_non_modular[:-1], bin_edges_non_modular[1:]

# Bars are placed at the bin centres.
fig.add_trace(go.Bar(
    x=(lo_m + hi_m) / 2,
    y=hist_modular,
    name='Modular',
    marker=dict(color='blue', pattern=dict(shape="x")),
    width=0.4,
))
fig.add_trace(go.Bar(
    x=(lo_n + hi_n) / 2,
    y=hist_non_modular,
    name='Non-Modular',
    marker=dict(color='red', pattern=dict(shape="/")),
    width=0.4,
))

# Error-bar markers, shifted 0.2 left / right of the bin centre
# (the 0.4 inside the parenthesis is halved together with the edge sum).
fig.add_trace(go.Scatter(
    x=(lo_m + hi_m - 0.4) / 2,
    y=hist_modular,
    mode='markers',
    error_y=dict(type='data', array=stderr_modular, visible=True),
    marker=dict(color='darkblue', size=8),
    name='Standard Error',
))
fig.add_trace(go.Scatter(
    x=(lo_n + hi_n + 0.4) / 2,
    y=hist_non_modular,
    mode='markers',
    error_y=dict(type='data', array=stderr_non_modular, visible=True),
    marker=dict(color='darkred', size=8),
    name='Standard Error',
))

# Figure-level styling: white background, boxed legend inside the plot area,
# fixed 500x500 canvas, serif font for the paper.
fig.update_layout(
    title="",
    xaxis_title="Number of clusters contributing significantly",
    yaxis_title="Fraction of samples",
    plot_bgcolor='rgba(255, 255, 255, 1)',
    legend=dict(x=0.55, y=0.93, traceorder="normal", bgcolor="white", bordercolor="black", borderwidth=1),
    width=500,
    height=500,
    font=dict(family='serif', size=15, color='black'),
)

# Axis titles and tick labels share one larger serif font.
axis_font = dict(family='serif', size=18, color='black')

# x axis: visible black axis line, not mirrored; no grid lines requested.
fig.update_xaxes(
    title_font=axis_font,
    tickfont=axis_font,
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=False,
)

# y axis: light grid lines and a fixed range starting at 0.
fig.update_yaxes(
    showgrid=True,
    gridwidth=1,
    gridcolor='LightGray',
    range=[0, 0.6],
    title_font=axis_font,
    tickfont=axis_font,
)

fig.show()

In [44]:
# Plot: titled variant of the histogram figure with bootstrap error bars.
fig = go.Figure()

# Bin centres for each model variant.
centers_modular = (bin_edges_modular[:-1] + bin_edges_modular[1:]) / 2
centers_non_modular = (bin_edges_non_modular[:-1] + bin_edges_non_modular[1:]) / 2

# Horizontal layout of each bin's pair of bars, in x-axis units.
# The bar positions are pinned with an explicit `offset` instead of being left
# to plotly's automatic grouping, so the error-bar markers can be placed at
# exactly the same x as the bar centres. Previously the markers sat at
# centre -/+ 0.075 (the 0.15 shift was inside the parenthesis that is divided
# by 2), which is not the middle of a 0.3-wide bar.
bar_width = 0.3
bar_shift = 0.2  # distance from the bin centre to each bar's centre

# Modular bar plot (bar centred at bin centre - bar_shift)
fig.add_trace(go.Bar(
    x=centers_modular,
    y=hist_modular,
    name='Modular',
    marker=dict(color='royalblue', pattern=dict(shape="x")),
    width=bar_width,
    offset=-bar_shift - bar_width / 2
))

# Non-Modular bar plot (bar centred at bin centre + bar_shift)
fig.add_trace(go.Bar(
    x=centers_non_modular,
    y=hist_non_modular,
    name='Non-Modular',
    marker=dict(color='salmon', pattern=dict(shape="/")),
    width=bar_width,
    offset=bar_shift - bar_width / 2
))

# Modular error bars, centred on the modular bars
fig.add_trace(go.Scatter(
    x=centers_modular - bar_shift,
    y=hist_modular,
    mode='markers',
    error_y=dict(type='data', array=stderr_modular, visible=True),
    marker=dict(color='darkblue', size=8),
    name='Modular Error bars'
))

# Non-Modular error bars, centred on the non-modular bars
fig.add_trace(go.Scatter(
    x=centers_non_modular + bar_shift,
    y=hist_non_modular,
    mode='markers',
    error_y=dict(type='data', array=stderr_non_modular, visible=True),
    marker=dict(color='darkred', size=8),
    name='Non-Modular Error bars'
))

# Layout
fig.update_layout(
    title="Histogram with Patterns and Error Bars",
    xaxis_title="Number of clusters contributing significantly",
    yaxis_title="Fraction of samples",
    width=600,
    height=500,
    font=dict(family='serif', size=15, color='black'),
    plot_bgcolor='rgba(255, 255, 255, 1)',
    legend=dict(x=0.55, y=0.93, traceorder="normal", bgcolor="white", bordercolor="black", borderwidth=1)
)
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', title_font=dict(family='serif', size=18, color='black'))
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', title_font=dict(family='serif', size=18, color='black'), range=[0, 0.6])
fig.show()
