In [1]:
import jax
import jax.numpy as jnp
from jax import jit
import numpy as np
import plotly.graph_objects as go
from make_masks import random_gates, deterministic_gates

In [2]:
# Configuration
n_vectors = 50      # number of gate vectors (repeats) drawn per condition
d_in = 100          # length of each gate vector
d_out = 1
random_seed = 42
mask_density = 0.1  # Bernoulli probability of a unit being active in `random_masks`

# Root PRNG key. It is a typed key (`jax.random.key`), so its shape is ().
key = jax.random.key(random_seed)
print(key.shape)

# Split off a dedicated subkey before deriving the per-vector keys. The original
# split `key` directly and later cells split `key` again, i.e. the same key was
# consumed twice. JAX keys must never be reused.
key, masks_key = jax.random.split(key)
vector_keys = jax.random.split(masks_key, n_vectors)

# One Bernoulli(mask_density) mask of shape (d_in, d_out) per vector key,
# stacked to (n_vectors, d_in, d_out).
random_masks = jax.vmap(lambda k: jax.random.bernoulli(k, mask_density, (d_in, d_out)))(vector_keys)
print(f"{random_masks.shape=}")


()
random_masks.shape=(50, 100, 1)


In [3]:
# Draw one set of random gates (density 0.1) from a fresh subkey; the first
# half of the split becomes the new `key` for subsequent cells.
split_pair = jax.random.split(key)
key, sub_key = split_pair[0], split_pair[1]
g_rand = random_gates(sub_key, 0.1,
                      d_in, d_out, n_vectors)
print(g_rand.shape)

(50, 100)


In [4]:
# Deterministic gates with similarity 0.5 and sparsity 0.5 (positional order:
# key, similarity, sparsity, d_in, d_out, n_vectors). Here the FIRST subkey is
# used for the gates and the second is kept as the new `key`.
determ_key, key = jax.random.split(key, num=2)
g_determ = deterministic_gates(determ_key, 0.5, 0.5,
                               d_in, d_out, n_vectors)
print(g_determ.shape)
print(f"{g_determ[0,:].shape=}")


(50, 100)
g_determ[0,:].shape=(100,)


In [9]:
# Sweep: for every sparsity level `s` and task similarity `v`, draw two gate sets
# with each generator and measure how much they overlap. `s` is labelled
# "sparsity", but at s=0.00 the random gates have no active units, so it behaves
# as the active fraction.
similarities = jnp.linspace(0, 1, 11)
sparsities = jnp.linspace(0, 1, 11)
n_tasks = 2  # NOTE(review): currently unused; the sweep always compares exactly two gate sets.


def _overlap(g1, g2):
    """Per-vector fraction of the d_in units that are active in both gate sets."""
    return jnp.sum(jnp.multiply(g1, g2), axis=-1) / d_in


def _ones_per_vector(g1, g2):
    """Active-unit counts of every vector of both gate sets, concatenated."""
    return jnp.concatenate([jnp.sum(g1, axis=-1), jnp.sum(g2, axis=-1)])


all_r_overlaps = {}
all_d_overlaps = {}
for s in sparsities:
    s_label = f"{s.item():.2f}"
    r_overlaps = []
    r_ones = []
    d_overlaps = {}
    for v in similarities:
        key, r_key, r_key2, d_key, d_key2 = jax.random.split(key, 5)

        # Random gates do not take `v`: they are the chance-overlap baseline.
        g1_random = random_gates(r_key, s, d_in, d_out, n_vectors)
        g2_random = random_gates(r_key2, s, d_in, d_out, n_vectors)
        r_overlaps.append(_overlap(g1_random, g2_random))
        # Count active units per vector, exactly as for the deterministic gates
        # below (previously a grand total over all vectors, which is not comparable).
        r_ones.append(_ones_per_vector(g1_random, g2_random))

        # Deterministic gates, parameterised by similarity `v` and sparsity `s`.
        g1_determ = deterministic_gates(d_key, v, s, d_in, d_out, n_vectors)
        g2_determ = deterministic_gates(d_key2, v, s, d_in, d_out, n_vectors)
        d_overlap = _overlap(g1_determ, g2_determ)
        d_ones = _ones_per_vector(g1_determ, g2_determ)

        # All statistics are stored as Python floats (previously the *_ones
        # entries were 0-d jax arrays, unlike "mean"/"std").
        d_overlaps[f"{v.item():.2f}"] = {"mean": jnp.mean(d_overlap, axis=0).item(),
                                         "std": jnp.std(d_overlap, axis=0).item(),
                                         "mean_ones": jnp.mean(d_ones).item(),
                                         "std_ones": jnp.std(d_ones).item()}

    # One row per similarity, transposed so that the statistics run over vectors
    # and give one value per similarity.
    r_stack = jnp.vstack(r_overlaps).T
    r_ones_all = jnp.concatenate(r_ones)
    all_r_overlaps[s_label] = {'mean': jnp.mean(r_stack, axis=0),
                               "std": jnp.std(r_stack, axis=0),
                               "m_ones": jnp.mean(r_ones_all).item(),
                               "std_ones": jnp.std(r_ones_all).item()}
    all_d_overlaps[s_label] = d_overlaps
    

In [10]:
# Deterministic gates: overlap vs. task similarity (fig) and active-unit counts
# vs. task similarity (fig1), one trace per sparsity level.
fig = go.Figure()
fig1 = go.Figure()

for s, by_similarity in all_d_overlaps.items():
    # Keys are "0.00"-style labels. Convert them to floats so the x-axis is numeric
    # (like the random-gates figure) instead of categorical.
    x = [float(v) for v in by_similarity]
    stats = list(by_similarity.values())
    # float() also accepts 0-d jax arrays, so plotly always receives plain floats.
    overlap_means = [float(st['mean']) for st in stats]
    overlap_stds = [float(st['std']) for st in stats]
    means_ones = [float(st['mean_ones']) for st in stats]
    stds_ones = [float(st['std_ones']) for st in stats]

    fig.add_trace(go.Scatter(x=x,
                             y=overlap_means,
                             error_y=dict(type='data', array=overlap_stds, visible=True),
                             mode='lines+markers',
                             name=f"sparsity:{s}"))
    fig1.add_trace(go.Scatter(x=x,
                              y=means_ones,
                              error_y=dict(type='data', array=stds_ones, visible=True),
                              mode='lines+markers',
                              name=f"sparsity:{s}"))

fig.update_layout(showlegend=True,
                  title="Deterministic Gates",
                  xaxis_title='Task similarity',
                  yaxis_title='Overlap Values')
fig1.update_layout(showlegend=True,
                   title="Number of Active Units",
                   xaxis_title="Task similarity",
                   yaxis_title="Sum of AU")

fig.show()
fig1.show()

In [11]:
# Sanity check: the random-overlap "mean" at sparsity "0.00" holds one entry per similarity value.
all_r_overlaps['0.00']['mean'].shape[0]

11

In [13]:
# Random-gate summary at sparsity 0.00 (key formatted the same way as in the sweep).
zero_sparsity_label = f"{0.0:.2f}"
all_r_overlaps[zero_sparsity_label]

{'mean': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'std': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 'm_ones': 0.0,
 'std_ones': 0.0}

In [16]:
# Random gates: overlap vs. task similarity (fig, one trace per sparsity) and
# mean active units vs. sparsity (fig1, a single curve).
fig = go.Figure()
fig1 = go.Figure()
# Use the similarity grid from the sweep instead of a hard-coded linspace, so the
# x positions stay in sync with the length of each "mean"/"std" array.
x = np.asarray(similarities)
sparsity_labels = list(all_r_overlaps.keys())
for s in sparsity_labels:
    stats = all_r_overlaps[s]
    fig.add_trace(go.Scatter(x=x,
                             y=np.asarray(stats['mean']),
                             error_y=dict(type='data',
                                          array=np.asarray(stats['std']),
                                          visible=True),
                             mode='lines+markers',
                             name=f"sparsity:{s}"))

# There is one scalar per sparsity level. Previously each level got its own trace
# with 11 x values but a single y value, so every point landed at x=0 with the
# same legend name. Plot them as one curve over sparsity instead.
fig1.add_trace(go.Scatter(x=[float(s) for s in sparsity_labels],
                          y=[float(all_r_overlaps[s]['m_ones']) for s in sparsity_labels],
                          error_y=dict(type='data',
                                       array=[float(all_r_overlaps[s]['std_ones']) for s in sparsity_labels],
                                       visible=True),
                          mode='lines+markers',
                          name="mean N 1's"))

fig.update_layout(showlegend=True,
                  title="Random Gates",
                  xaxis_title='Task similarity',
                  yaxis_title='Overlap Values',
                  width=1000,
                  height=500)
fig1.update_layout(showlegend=True,
                   title="Number of Active Units",
                   xaxis_title="Sparsity",
                   yaxis_title="Sum of AU")
fig.show()
fig1.show()