In [1]:
import pickle as pkl
from pathlib import Path

In [2]:
import plotly.graph_objects as go
import numpy as np

In [3]:
# Run configuration. In the cells shown here only `num_layer` and
# `type_of_intervention` are read (to build the prediction file names);
# `device` and `batch_size` are not used in this notebook.
args = dict(
    device="cuda",
    batch_size=1,
    num_layer=6,
    type_of_intervention="type1",  # or "type2"
)

# Index bounds per intervention type, built from multiples of 256.
# NOTE(review): neither table is read by the cells shown here.
# type1: one [start, end] pair per entry -> four consecutive 256-wide ranges,
# the last one open-ended (None).
type_1_args = {
    f"index{k + 1}": [256 * k, 256 * (k + 1) if k < 3 else None]
    for k in range(4)
}

# type2: four values per entry, which look like two [start, end] pairs that
# together cover everything except the matching type1 range
# (e.g. index1 -> [256, 768] + [768, None], skipping [0, 256]) — TODO confirm.
type_2_args = {
    "index1": [256, 256 * 3, 256 * 3, None],
    "index2": [0, 256, 256 * 2, None],
    "index3": [0, 256 * 2, 256 * 3, None],
    "index4": [0, 256 * 3, None, None],
}

In [4]:
import sys

# Make modules in the current working directory importable.
# NOTE(review): none of the cells shown import a local module — possibly leftover.
sys.path += ["."]

In [5]:
# Load the pickled per-module predictions for the modular model and for the
# non-modular model (file names containing "nmodel"), keyed by module name,
# and record each module's mean prediction.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
final_dict_modular, mean_acc_modular = {}, []
final_dict_non_modular, mean_acc_non_modular = {}, []

for module in ["mod1", "mod2", "mod3", "mod4"]:
    modular_path = f"data/prediction_{args['type_of_intervention']}_layer{args['num_layer']}_{module}.pkl"
    non_modular_path = f"data/prediction_nmodel_{args['type_of_intervention']}_layer{args['num_layer']}_{module}.pkl"
    # Modular file first, then the non-modular one, for every module.
    for path, store, means in (
        (modular_path, final_dict_modular, mean_acc_modular),
        (non_modular_path, final_dict_non_modular, mean_acc_non_modular),
    ):
        with open(path, "rb") as f:
            prediction = pkl.load(f)
        store[module] = prediction
        means.append(np.mean(prediction))

In [6]:
def _flipped_rows(pred_by_module, modules=("mod1", "mod2", "mod3", "mod4")):
    """Per-sample tuples of ``1 - prediction`` across modules, dropping all-zero rows.

    For each module key, ``1 - prediction`` is computed element-wise. The
    per-module vectors are then zipped sample by sample, and samples whose
    entries are zero for every module are discarded.
    Presumably each prediction is a 0/1 flag after intervening on that module,
    so a 1 here marks a module whose intervention changed the outcome (a
    "contributing cluster" in the plot labels) — TODO confirm.

    Args:
        pred_by_module: dict mapping module name -> 1-D sequence of predictions.
        modules: module keys to read; their order is the tuple/column order.

    Returns:
        list of tuples (one float per module) for the kept samples.

    Raises:
        ValueError: if the modules' prediction arrays differ in shape.
    """
    flipped = [1.0 - np.asarray(pred_by_module[m], dtype=float) for m in modules]
    shapes = {arr.shape for arr in flipped}
    if len(shapes) > 1:
        # Checked explicitly so a length mismatch fails with a clear message
        # instead of a numpy broadcasting error (or silent broadcasting).
        raise ValueError(f"prediction arrays differ in shape across modules: {sorted(shapes)}")
    # Keep a sample when at least one module has a non-zero entry.
    return [row for row in zip(*flipped) if any(v != 0 for v in row)]


filtered_lists_modular = _flipped_rows(final_dict_modular)
filtered_lists_non_modular = _flipped_rows(final_dict_non_modular)

In [7]:
# Per-sample total across modules: the sum of each filtered tuple.
sums_model_modular = np.array([sum(row) for row in filtered_lists_modular])
sums_model_non_modular = np.array([sum(row) for row in filtered_lists_non_modular])

In [8]:
# Histogram of the per-sample sums, with bootstrap standard errors per bin.
# Edges 1..5 give unit-width bins [1,2), [2,3), [3,4), [4,5] (numpy closes the
# last bin). With density=True and unit width, each bar is the fraction of
# the samples passed in that falls into the bin.
bins = np.arange(1, 6)
n_bootstraps = 1000
# Fixed seed so the reported error bars are reproducible across re-runs.
hist_bootstrap_rng = np.random.default_rng(0)


def _bootstrap_histograms(values, bins, n_bootstraps, rng):
    """Density histograms of `n_bootstraps` resamples (with replacement) of `values`.

    Returns an array of shape (n_bootstraps, len(bins) - 1); row b is the
    density histogram of the b-th resample.
    """
    boot = np.zeros((n_bootstraps, len(bins) - 1))
    for b in range(n_bootstraps):
        resample = rng.choice(values, size=len(values), replace=True)
        boot[b] = np.histogram(resample, bins=bins, density=True)[0]
    return boot


# Modular model
hist_modular, bin_edges_modular = np.histogram(sums_model_modular, bins=bins, density=True)
bootstrapped_histograms_modular = _bootstrap_histograms(sums_model_modular, bins, n_bootstraps, hist_bootstrap_rng)
stderr_modular = np.std(bootstrapped_histograms_modular, axis=0)

# Non-modular model
hist_non_modular, bin_edges_non_modular = np.histogram(sums_model_non_modular, bins=bins, density=True)
bootstrapped_histograms_non_modular = _bootstrap_histograms(sums_model_non_modular, bins, n_bootstraps, hist_bootstrap_rng)
stderr_non_modular = np.std(bootstrapped_histograms_non_modular, axis=0)

In [9]:
# Display the per-bin bootstrap standard errors: (modular, non-modular).
(stderr_modular, stderr_non_modular)

(array([0.03312371, 0.03030927, 0.02026697, 0.01089174]),
 array([0.02158261, 0.01824948, 0.01698408, 0.02114604]))

In [64]:
# Left figure: distribution of the per-sample number of contributing clusters.
# The error bars are recomputed from the per-bin bootstrap replicates rather
# than read from `stderr_modular` / `stderr_non_modular`. The bootstrap-means
# cell (In [11]) reassigns those names to per-cluster SEs, and this cell ran
# after it (In [64]), so the histogram picked up the wrong error bars.
hist_err_modular = np.std(bootstrapped_histograms_modular, axis=0)
hist_err_non_modular = np.std(bootstrapped_histograms_non_modular, axis=0)

fig = go.Figure()
fig.add_trace(go.Bar(
    x=(bin_edges_modular[:-1] + bin_edges_modular[1:]) / 2,
    y=hist_modular,
    name='Modular',
    marker=dict(color='darkblue', pattern=dict(shape='x')),
    width=0.4,
    error_y=dict(type='data', array=hist_err_modular, visible=True),
))
fig.add_trace(go.Bar(
    x=(bin_edges_non_modular[:-1] + bin_edges_non_modular[1:]) / 2,
    y=hist_non_modular,
    name='Non-Modular',
    marker=dict(color='darkred', pattern=dict(shape='/')),
    width=0.4,
    error_y=dict(type='data', array=hist_err_non_modular, visible=True),
))

# Publication styling: white background, serif fonts, legend boxed inside the plot.
# NOTE(review): the densities are normalized over the filtered samples (rows
# with at least one non-zero entry), so the count behind the fractions may not
# be 1000 — confirm the "(n=1000)" label. The bins start at 1, so N=0 has no bar.
axis_font = dict(family='serif', size=18, color='black')
fig.update_layout(
    title="",
    xaxis_title="Total number of contributing clusters",
    yaxis_title="Fraction of samples (n=1000)",
    showlegend=True,
    plot_bgcolor='rgba(255, 255, 255, 1)',
    legend=dict(x=0.55, y=0.93, traceorder="normal", bgcolor="white", bordercolor="black", borderwidth=1),
    width=500,
    height=500,
    font=dict(family='serif', size=15, color='black'),
)
# Bars sit at bin centres 1.5, 2.5, ...; label each centre with its integer sum.
fig.update_xaxes(
    tickvals=[0.5, 1.5, 2.5, 3.5, 4.5],
    ticktext=['N=0', 'N=1', 'N=2', 'N=3', 'N=4'],
    title_font=axis_font,
    tickfont=axis_font,
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=False,
)
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', title_font=axis_font, tickfont=axis_font)

fig.show()

In [65]:
# Save the left figure as PDF. The output directory is created first so a
# fresh checkout does not fail inside write_image.
# NOTE(review): plotly static export needs an image engine (e.g. kaleido) installed.
Path("beautiful_plots").mkdir(parents=True, exist_ok=True)
fig.write_image("beautiful_plots/gpt2_small_left.pdf")

In [10]:
# Stack the filtered per-sample tuples into 2-D arrays
# (rows = kept samples, columns = modules mod1..mod4).
np_2d_modular, np_2d_non_modular = np.array(filtered_lists_modular), np.array(filtered_lists_non_modular)

In [11]:

# Number of bootstrap replicates
n_bootstraps = 1000


def bootstrap_means(data, n_bootstraps, rng=None):
    """Bootstrap the column means of a 2-D array by resampling its rows.

    Args:
        data: 2-D array; rows are resampled with replacement.
        n_bootstraps: number of bootstrap replicates.
        rng: optional ``numpy.random.Generator`` for reproducible resampling.
            When None, the legacy global ``np.random`` state is used, so
            ``np.random.seed`` still controls the result.

    Returns:
        Array of shape (n_bootstraps, data.shape[1]); row b holds the column
        means of the b-th resample.

    Raises:
        ValueError: if `data` has no rows (every resample would be empty and
            its mean NaN).
    """
    n_rows = data.shape[0]
    if n_rows == 0:
        raise ValueError("bootstrap_means needs at least one row of data")
    sampler = np.random if rng is None else rng
    boot_means = np.zeros((n_bootstraps, data.shape[1]))
    for b in range(n_bootstraps):
        sample_indices = sampler.choice(n_rows, size=n_rows, replace=True)
        boot_means[b, :] = np.mean(data[sample_indices], axis=0)
    return boot_means

# Resample the rows of each contribution matrix (modular first, then
# non-modular) and keep the per-cluster means of every replicate.
bootstrap_modular, bootstrap_non_modular = (
    bootstrap_means(matrix, n_bootstraps) for matrix in (np_2d_modular, np_2d_non_modular)
)

# Spread of the replicates = standard error of each cluster's mean.
stderr_modular, stderr_non_modular = (
    np.std(replicates, axis=0) for replicates in (bootstrap_modular, bootstrap_non_modular)
)

# Centre of the replicates = bootstrap estimate of each cluster's mean.
bootstrap_modular_mean, bootstrap_non_modular_mean = (
    np.mean(replicates, axis=0) for replicates in (bootstrap_modular, bootstrap_non_modular)
)

# Display: (modular means, modular SEs, non-modular means, non-modular SEs)
bootstrap_modular_mean, stderr_modular, bootstrap_non_modular_mean, stderr_non_modular

(array([0.37824597, 0.46537097, 0.41506855, 0.37889113]),
 array([0.03166875, 0.03190285, 0.03096528, 0.03229528]),
 array([0.53966885, 0.55264706, 0.58037255, 0.74389107]),
 array([0.0222477 , 0.02317992, 0.02238122, 0.01998227]))

In [42]:
# Right figure: how often each cluster (module) contributes, modular vs
# non-modular. Bar colour encodes the cluster; hatch pattern encodes the model.
fig = go.Figure()
colors = ['red', 'green', 'orange', 'blue']

# One bar pair per cluster: modular bars at 1.5, 2.5, ..., non-modular bars
# shifted right by 0.4, with tick labels centred between the two.
cluster_x = np.arange(len(colors)) + 1.5

# The bars are counts (column sums over the filtered samples), so the error
# bars must be counts too. Every bootstrap resample has exactly n_rows rows,
# hence SE(sum) = n_rows * SE(mean). `stderr_*` from the bootstrap cell are
# SEs of the per-cluster *mean* (a fraction), which is on a different scale.
count_err_modular = len(np_2d_modular) * np.std(bootstrap_modular, axis=0)
count_err_non_modular = len(np_2d_non_modular) * np.std(bootstrap_non_modular, axis=0)

fig.add_trace(go.Bar(
    x=cluster_x,
    y=np.sum(np_2d_modular, axis=0),
    name='Modular',
    marker=dict(color=colors, pattern=dict(shape='x')),
    width=0.4,
    error_y=dict(type='data', array=count_err_modular, visible=True),
    showlegend=False,
))
fig.add_trace(go.Bar(
    x=cluster_x + 0.4,
    y=np.sum(np_2d_non_modular, axis=0),
    name='Non-Modular',
    marker=dict(color=colors, pattern=dict(shape='/')),
    width=0.4,
    error_y=dict(type='data', array=count_err_non_modular, visible=True),
    showlegend=False,
))

# Legend-only dummy traces (no data) that show just the hatch pattern,
# because the real traces are coloured per cluster.
for legend_name, hatch in (('Modular', 'x'), ('Non-Modular', '/')):
    fig.add_trace(go.Bar(
        x=[None],
        y=[None],
        name=legend_name,
        marker=dict(color='white', pattern=dict(shape=hatch)),
        showlegend=True,
    ))

# Publication styling: white background, serif fonts, boxed legend.
# NOTE(review): "n=1000" is hard-coded, but the bars count only the filtered
# samples (rows with at least one non-zero entry) — confirm the label.
axis_font = dict(family='serif', size=18, color='black')
fig.update_layout(
    title="",
    xaxis_title="Cluster (or module)",
    yaxis_title="Contribution Frequency (n=1000)",
    showlegend=True,
    plot_bgcolor='rgba(255, 255, 255, 1)',
    legend=dict(x=0.08, y=0.95, traceorder="normal", bgcolor="white", bordercolor="black", borderwidth=1),
    width=500,
    height=500,
    font=dict(family='serif', size=15, color='black'),
)
fig.update_xaxes(
    tickvals=cluster_x + 0.2,
    ticktext=['A', 'B', 'C', 'D'],
    title_font=axis_font,
    tickfont=axis_font,
    showline=True,
    linewidth=1,
    linecolor='black',
    mirror=False,
)
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', title_font=axis_font, tickfont=axis_font)
fig.show()

In [44]:
# Save the right figure as PDF. The output directory is created first so a
# fresh checkout does not fail inside write_image.
# NOTE(review): plotly static export needs an image engine (e.g. kaleido) installed.
Path("beautiful_plots").mkdir(parents=True, exist_ok=True)
fig.write_image("beautiful_plots/gpt2_small_right.pdf")