In [1]:
from model.model import *
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

In [66]:
# NumPy is needed for the random feature ordering below, so import it before
# first use instead of relying on the star import to have provided `np`.
import numpy as np

# Toy MoE configuration: 100 features, 10 experts with 10 hidden units each,
# one active expert per input, no load-balancing loss.
config = Config(
    n_features=100,
    n_hidden=10,
    n_experts=10,
    n_active_experts=1,
    load_balancing_loss=False,
)

# Per-feature importance is 0.9 ** rank, where the ranks are a random
# permutation of 0..n_features-1 (sampling without replacement).
# Alternatives tried previously:
#   torch.exp(torch.randn(config.n_features, generator=torch.Generator().manual_seed(42)) * 3)
#   torch.linspace(0.8, 0.1, config.n_features)
#   torch.tensor(1)
feature_ranks = np.random.choice(config.n_features, config.n_features, replace=False)

model = MoEModel(
    config=config,
    device=DEVICE,
    importance=0.9 ** torch.from_numpy(feature_ranks),
    feature_probability=torch.tensor(0.1),
)

# Hand-initialise the gate: draw 10 distinct feature indices per gate row
# (without replacement across the whole draw, so no feature is shared between
# rows) and set those gate entries to 1.
# Earlier variants of this cell used a single random feature per row, a
# diagonal fill, or Xavier-normal initialisation instead.
n_gate_rows, n_gate_cols = model.gate.shape
indices = np.random.choice(n_gate_cols, size=int(n_gate_rows * 10), replace=False)
indices = torch.from_numpy(indices.reshape(n_gate_rows, -1))
print(indices)

# The gate is a Parameter, so write into it with autograd disabled.
with torch.no_grad():
    for i in range(n_gate_rows):
        model.gate[i, indices[i]] = 1

print(model.gate)

tensor([[44, 78, 26, 81, 46, 32, 92, 75, 89, 93],
        [ 0, 13, 50, 48, 27, 82, 67, 85, 58, 73],
        [30,  7, 90, 86, 97, 72, 45, 33, 57, 39],
        [34, 94,  6, 23, 15, 60, 74, 19,  9, 68],
        [41,  2, 10, 16, 64, 79, 40, 69, 63, 91],
        [28,  5, 83, 52, 43, 56, 62, 71, 31, 25],
        [99, 38, 49, 29, 20, 24, 98,  8, 96, 54],
        [76, 37, 88, 65, 87, 84, 77, 70, 35, 21],
        [18, 53, 11, 61, 47, 55, 59, 51, 12,  1],
        [42, 95,  3, 66, 22, 14, 80,  4, 17, 36]])
Parameter containing:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0

In [67]:
# Train the hand-initialised model; `optimize` comes from the star import of
# model.model. Presumably logs the reconstruction loss every `print_freq`
# steps (see the cell output) — confirm against its definition.
optimize(model, n_batch=1024, steps=10000, print_freq=1000)

Step 0: Reconstruction loss=0.332469, lr=0.001000
Step 1000: Reconstruction loss=0.049277, lr=0.001000
Step 2000: Reconstruction loss=0.046912, lr=0.001000
Step 3000: Reconstruction loss=0.048780, lr=0.001000
Step 4000: Reconstruction loss=0.049735, lr=0.001000
Step 5000: Reconstruction loss=0.045407, lr=0.001000
Step 6000: Reconstruction loss=0.047213, lr=0.001000
Step 7000: Reconstruction loss=0.045360, lr=0.001000
Step 8000: Reconstruction loss=0.049064, lr=0.001000
Step 9000: Reconstruction loss=0.046823, lr=0.001000
Step 9999: Reconstruction loss=0.044261, lr=0.001000


In [198]:
# configs = [
#     Config(n_features=5, n_hidden=3, n_experts=5, n_active_experts=2, load_balancing_loss=True),
#     Config(n_features=5, n_hidden=3, n_experts=5, n_active_experts=2, load_balancing_loss=False),
# ]

# feature_probs = [torch.tensor(0.1), torch.tensor(0.2),]
# importances = [0.9**torch.arange(5), 0.8**torch.arange(5),]

# optimize_vectorized(configs, feature_probs, importances, n_batch=10, steps=100, device=DEVICE)

In [6]:
def render_features(model):
    """Build a Plotly figure summarising every expert's feature weights.

    One subplot row is drawn per expert:

    * Column 1 is a bar chart of the L2 norm of each feature's weight vector.
      Bars are coloured by "polysemanticity": the L2 norm, over all other
      features, of the projection of those features onto this feature's
      (approximately) unit direction. The colour scale is fixed to [0, 1].
      A thin vertical line is drawn between feature indices ``n_hidden - 1``
      and ``n_hidden`` when that boundary lies inside the plotted range.
    * Column 2 is an image of ``W @ W.T`` for the expert, mapped from
      [-1, 1] onto matplotlib's ``coolwarm`` colormap; the raw values are
      attached as hover data.

    Args:
        model: object exposing ``config`` (with ``n_features``, ``n_experts``
            and ``n_hidden``) and ``W_experts``, indexed as
            (expert, feature, hidden).

    Returns:
        The assembled ``plotly`` figure (not shown; the caller displays it).
    """
    cfg = model.config
    n_features = cfg.n_features

    # Expert weights, detached so nothing here touches the autograd graph.
    W_exp = model.W_experts.detach()
    # The small epsilon keeps the division finite for all-zero weight rows.
    W_norm = W_exp / (1e-5 + torch.linalg.norm(W_exp, 2, dim=2, keepdim=True))

    # interference[e, f, g] = <unit(W_f), W_g> within expert e.
    interference = torch.einsum('ifh,igh->ifg', W_norm, W_exp)
    diag = torch.arange(n_features)
    interference[:, diag, diag] = 0  # ignore each feature's overlap with itself

    polysemanticity = torch.linalg.norm(interference, dim=-1).cpu()
    norms = torch.linalg.norm(W_exp, 2, dim=-1).cpu()
    WWT = torch.einsum('eij,ekj->eik', W_exp, W_exp).cpu()  # (n_experts, n_features, n_features)

    # Plain Python ints for the x axis so Plotly never has to coerce tensors.
    x = list(range(n_features))

    fig = make_subplots(
        rows=cfg.n_experts,  # one row per expert
        cols=2,  # column 1: bar graphs, column 2: heatmaps
        shared_xaxes=False,
        vertical_spacing=0.05,
        horizontal_spacing=0.1,
        # Only the left subplot of each row carries a title.
        subplot_titles=[f"expert {i}" if j == 0 else ""
                        for i in range(cfg.n_experts) for j in range(2)],
    )

    for expert_idx in range(cfg.n_experts):
        # Column 1: per-feature weight norms, coloured by polysemanticity.
        fig.add_trace(
            go.Bar(
                x=x,
                y=norms[expert_idx].numpy(),
                marker=dict(
                    color=polysemanticity[expert_idx].numpy(),
                    cmin=0,
                    cmax=1,
                ),
            ),
            row=expert_idx + 1, col=1,
        )
        # Column 2: W W^T heatmap for this expert.
        wwt = WWT[expert_idx].numpy()
        fig.add_trace(
            go.Image(
                z=plt.cm.coolwarm((1 + wwt) / 2, bytes=True),
                colormodel='rgba256',
                customdata=wwt,
                hovertemplate=f'expert {expert_idx}<br>In: %{{x}}<br>Out: %{{y}}<br> weight: %{{customdata:0.2f}}',
            ),
            row=expert_idx + 1, col=2,
        )

    # Marker halfway between feature n_hidden-1 and feature n_hidden. It is
    # skipped when that boundary is outside the axis (n_hidden <= 0 or
    # n_hidden >= n_features), which previously raised an IndexError or
    # produced a meaningless position. Added after the traces so every target
    # subplot already has content.
    if 0 < cfg.n_hidden < n_features:
        boundary = cfg.n_hidden - 0.5
        for expert_idx in range(cfg.n_experts):
            fig.add_vline(
                x=boundary,
                line=dict(width=0.5),
                row=expert_idx + 1,
                col=1,
            )

    fig.update_layout(
        showlegend=False,
        width=800,  # fixed width for 2 columns
        height=300 * cfg.n_experts,  # scale height with the number of experts
        margin=dict(t=50, b=20),  # space for subplot titles
    )

    # Show x-axes for the bar graphs (column 1) and hide them for the heatmaps.
    fig.update_xaxes(visible=True, col=1)
    fig.update_xaxes(visible=False, col=2)

    fig.update_yaxes(visible=True, col=1)
    fig.update_yaxes(visible=True, title="norm", row=1, col=1)  # label only the first row

    return fig

In [68]:
# Display the per-expert weight-norm bars and W W^T heatmaps for the trained model.
render_features(model)

The dimensionality of feature $i$ within expert $e$ is given by:
$$
D_i^{(e)} = \frac{\left\| W_i^{(e)} \right\|^2}{\sum_j \left( \hat{W}_i^{(e)} \cdot W_j^{(e)} \right)^2}
$$
- where $\hat{W}_i^{(e)} = \frac{W_i^{(e)}}{\left\| W_i^{(e)} \right\|}$ is the unit vector in the direction of $W_i^{(e)}$, and the sum over $j$ includes $j = i$

In [8]:
def expert_feat_dimensionality(model):
    """Fraction of a dimension within each expert that is occupied by each feature.

    For expert ``e`` and feature ``i`` this computes

        D_i^(e) = ||W_i^(e)||^2 / sum_j (W_hat_i^(e) . W_j^(e))^2

    where ``W_hat_i`` is ``W_i`` scaled to unit length and the sum over ``j``
    includes ``j == i``.

    Args:
        model: object exposing ``W_experts``, indexed as
            (expert, feature, hidden).

    Returns:
        Tensor of shape ``[n_experts, n_features]`` rounded to 3 decimals.
        It is detached from autograd (the rounding would zero any gradient
        anyway) and lives on the same device as ``model.W_experts``.
        Features whose weight vector is exactly zero get dimensionality 0
        rather than NaN.
    """
    with torch.no_grad():
        W = model.W_experts  # [n_experts, n_features, n_hidden]
        # Smallest positive normal number: large enough to avoid 0/0, small
        # enough not to perturb any non-degenerate value.
        tiny = torch.finfo(W.dtype).tiny

        norm_sq = (W ** 2).sum(dim=-1)  # ||W_i||^2, [n_experts, n_features]
        W_hat = W / norm_sq.sqrt().clamp_min(tiny).unsqueeze(-1)  # unit directions
        # overlaps[e, i, j] = W_hat_i . W_j within expert e
        overlaps = torch.einsum('eih,ejh->eij', W_hat, W)
        interference = (overlaps ** 2).sum(dim=-1)
        dimensionality = norm_sq / interference.clamp_min(tiny)

    return torch.round(dimensionality, decimals=3)

In [None]:
# Rank features by polysemanticity within each expert (lower = more monosemantic).
W_exp = model.W_experts.detach()
# Take the sizes from the weights themselves instead of hard-coding 100 features / 10 experts.
n_experts, n_features, _ = W_exp.shape
# The small epsilon keeps the division finite for all-zero weight rows.
W_norm = W_exp / (1e-5 + torch.linalg.norm(W_exp, 2, dim=2, keepdim=True))

interference = torch.einsum('ifh,igh->ifg', W_norm, W_exp)  # (n_experts, n_features, n_features)
diag = torch.arange(n_features)
interference[:, diag, diag] = 0  # set diagonal to 0
polysemanticity = torch.linalg.norm(interference, dim=-1).cpu()

# Report at most 10 features; topk would raise if asked for more than exist.
top_k = min(10, n_features)
for i in range(n_experts):
    print(f"Most monosemantic neurons for expert {i}: {polysemanticity[i].topk(top_k, largest=False)[1]}")

print(f"First expert: {polysemanticity[0].topk(top_k, largest=False)[0]}")

Most monosemantic neurons for expert 0: tensor([45, 51, 91, 52, 33, 53, 12,  0, 18, 84])
Most monosemantic neurons for expert 1: tensor([ 0, 52, 51, 18, 12, 91, 84, 53, 87, 64])
Most monosemantic neurons for expert 2: tensor([33, 45, 53, 12, 64, 18, 91, 52, 49, 37])
Most monosemantic neurons for expert 3: tensor([ 6, 84, 51, 52, 12, 53, 18, 91, 37, 68])
Most monosemantic neurons for expert 4: tensor([91, 12, 33, 51, 49, 52, 29, 53, 45, 64])
Most monosemantic neurons for expert 5: tensor([52, 25, 33, 45, 84, 18, 91, 53, 12, 87])
Most monosemantic neurons for expert 6: tensor([49, 52, 98,  0, 18, 24, 33, 12, 29, 51])
Most monosemantic neurons for expert 7: tensor([87, 84, 37, 33, 91, 52, 36, 51, 18,  0])
Most monosemantic neurons for expert 8: tensor([51, 18, 53, 12, 91, 49, 52, 64, 33, 11])
Most monosemantic neurons for expert 9: tensor([36, 51, 12, 53,  0, 52, 91, 33, 84, 18])
First expert: tensor([0.9733, 0.9797, 0.9924, 0.9929, 0.9952, 1.0034, 1.0088, 1.0143, 1.0149,
        1.0211])

The specialization of feature $i$ to expert $e$ is given by:

$$
S_i^{(e)} = \frac{\left\| W_i^{(e)} \right\|^2}{\sum_{e'} \left\| W_i^{(e')} \right\|^2}
$$

- $S_i^{(e)}$ is the fraction of feature $i$'s total squared weight norm (summed over all experts) that is allocated to expert $e$, so $\sum_e S_i^{(e)} = 1$ for every feature


In [9]:
def feature_specialization(model):
    """Fraction of feature i's squared weight norm that is allocated to expert e.

    Computes ``S_i^(e) = ||W_i^(e)||^2 / sum_e' ||W_i^(e')||^2``.

    Args:
        model: object exposing ``W_experts``, indexed as
            (expert, feature, hidden).

    Returns:
        Tensor of shape ``[n_experts, n_features]`` rounded to 3 decimals and
        detached from autograd (the rounding would zero any gradient anyway).
        Each column sums to 1 up to rounding, except for features whose
        weights are zero in every expert, which get 0 everywhere rather than
        NaN.
    """
    with torch.no_grad():
        W = model.W_experts
        norm_sq = (W ** 2).sum(dim=2)  # ||W_i^(e)||^2, [n_experts, n_features]
        total_per_feature = norm_sq.sum(dim=0, keepdim=True)  # summed over experts
        # Clamp with the smallest positive normal number so an all-zero
        # feature yields 0/tiny = 0 without perturbing any other value.
        tiny = torch.finfo(W.dtype).tiny
        specialization = norm_sq / total_per_feature.clamp_min(tiny)
    return torch.round(specialization, decimals=3)

In [103]:
# Show the [n_experts, n_features] specialization matrix for the current model.
feature_specialization(model)

tensor([[0.2550, 0.2470, 0.2560, 0.2130, 0.2410, 0.2540, 0.0010, 0.4950, 0.8570,
         0.5010, 0.4990, 0.0000, 0.0010, 0.0330, 0.5370, 0.3480, 0.1010, 0.1540,
         0.0750, 0.0570],
        [0.2490, 0.2510, 0.2470, 0.2620, 0.2450, 0.2440, 0.4170, 0.0020, 0.0000,
         0.0000, 0.0000, 0.9120, 0.3170, 0.1000, 0.0950, 0.1480, 0.7540, 0.0610,
         0.5580, 0.3130],
        [0.2460, 0.2510, 0.2520, 0.2620, 0.2570, 0.2250, 0.5540, 0.5040, 0.1040,
         0.0000, 0.5010, 0.0870, 0.0000, 0.0340, 0.1200, 0.0290, 0.0220, 0.0160,
         0.1000, 0.1650],
        [0.2500, 0.2510, 0.2450, 0.2640, 0.2570, 0.2770, 0.0290, 0.0000, 0.0390,
         0.4990, 0.0000, 0.0010, 0.6820, 0.8320, 0.2490, 0.4750, 0.1220, 0.7700,
         0.2660, 0.4640]], grad_fn=<RoundBackward1>)

In [180]:
# Sanity check: each feature's specialization should sum to ~1 across experts
# (small deviations come from the 3-decimal rounding in feature_specialization).
specialization_tensor = feature_specialization(model)
# sum the feature specialization across experts (dimension 0)
summed_specialization = torch.sum(specialization_tensor, dim=0)
print(summed_specialization)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9990,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9990, 1.0000,
        1.0000, 1.0000], grad_fn=<SumBackward1>)


The usage-weighted global dimensionality of feature $i$ is given by:

$$
D_i^{\text{global}} = \sum_e p_e \cdot D_i^{(e)}
$$

where:

- $D_i^{(e)}$ is the dimensionality of feature $i$ in expert $e$  
- $p_e$ is the empirical usage of expert $e$: the fraction of inputs in a sampled batch for which expert $e$ receives the highest gate score (top-1 routing frequency)  
- $D_i^{\text{global}}$ aggregates the feature's dimensionality across experts, weighted by expert usage


In [38]:
def global_feature_dimensionality(model, input_features=None, n_batch=4096):
    """Global feature dimensionality: per-expert dimensionality weighted by expert usage.

    Computes ``D_i^global = sum_e p_e * D_i^(e)`` where ``p_e`` is the
    fraction of inputs whose highest gate score belongs to expert ``e``.

    Args:
        model: object exposing ``config.n_experts``, ``gate`` (indexed as
            (expert, feature)), ``generate_batch`` and whatever
            ``expert_feat_dimensionality`` reads.
        input_features: optional batch of inputs with features on the last
            axis. When ``None`` a batch is drawn with
            ``model.generate_batch(n_batch)``.
        n_batch: batch size used only when ``input_features`` is ``None``.

    Returns:
        ``(usage, global_dims)`` where ``usage`` has one entry per expert
        (all zeros if the batch is empty) and ``global_dims`` has one entry
        per feature, rounded to 3 decimals.
    """
    # Per-expert dimensionalities, shape [n_experts, n_features].
    expert_dimensionalities = expert_feat_dimensionality(model)

    if input_features is None:
        input_features = model.generate_batch(n_batch)

    with torch.no_grad():
        # Gate score of every expert for every input.
        gate_scores = torch.einsum("...f,ef->...e", input_features, model.gate)

        # Top-1 expert per input. Softmax is strictly increasing, so taking
        # the argmax of the raw scores selects the same expert.
        selected = gate_scores.argmax(dim=-1)

        # `minlength` guarantees one slot per expert even when the
        # highest-indexed experts are never selected, and keeps the counts on
        # the same device as the gate scores.
        counts = torch.bincount(selected.flatten(), minlength=model.config.n_experts)
        # clamp_min(1) turns an empty batch into all-zero usage instead of NaN.
        usage = counts / counts.sum().clamp_min(1)

        # Align device and dtype before contracting over the expert axis.
        expert_dimensionalities = expert_dimensionalities.to(
            device=usage.device, dtype=usage.dtype
        )
        # D_i^global = sum_e p_e * D_i^(e)
        global_dims = torch.einsum("e,ef->f", usage, expert_dimensionalities)

    return usage, torch.round(global_dims, decimals=3)

In [73]:
# Estimate expert usage and usage-weighted global dimensionality on a freshly
# generated batch, then print summary statistics.
usage, global_feature_dims = global_feature_dimensionality(model)
print(f"Usage: {usage}")
print(f"Global feature dims: {global_feature_dims}")
print(f"Usage mean: {usage.mean()}, median: {usage.median()}")
print(f"Global feature dims mean: {global_feature_dims.mean()}, median: {global_feature_dims.median()}")
# Count a feature as "represented" when its global dimensionality exceeds 0.05
# (ad-hoc threshold).
sums = (global_feature_dims > 0.05).sum()
print(f"Number of features represented: {sums}")

Usage: tensor([0.1001, 0.0930, 0.0903, 0.0928, 0.1028, 0.1035, 0.0994, 0.0969, 0.1118,
        0.1094])
Global feature dims: tensor([0.5240, 0.0000, 0.0010, 0.0000, 0.0000, 0.0310, 0.2520, 0.0000, 0.0000,
        0.0000, 0.0000, 0.2450, 0.5550, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.5450, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0490, 0.4360, 0.0000,
        0.0000, 0.0000, 0.3790, 0.3500, 0.0000, 0.0470, 0.5480, 0.0000, 0.0000,
        0.4430, 0.4210, 0.0000, 0.0000, 0.0030, 0.0000, 0.0000, 0.0000, 0.0000,
        0.5130, 0.0000, 0.0140, 0.0000, 0.5020, 0.0000, 0.5570, 0.5570, 0.5520,
        0.0000, 0.0040, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0480,
        0.0000, 0.4490, 0.0000, 0.0000, 0.0000, 0.0020, 0.0850, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1770, 0.0000,
        0.0000, 0.0000, 0.0000, 0.4770, 0.0000, 0.0000, 0.5010, 0.0000, 0.0000,
        0.0000, 0.5530, 0.0000, 0.0000, 0.0000, 0.0480, 0.0000, 0.0000, 0.1

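An alternative, specialization-weighted global dimensionality of feature $i$ (weighting each expert by the feature's specialization $S_i^{(e)}$ instead of by expert usage $p_e$) is given by: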

$$
D_i^{\text{global}} = \sum_e S_i^{(e)} \cdot D_i^{(e)}
$$

where:

- $D_i^{(e)}$ is the dimensionality of feature $i$ in expert $e$  
- $S_i^{(e)}$ is the specialization of feature $i$ to expert $e$


In [83]:
def weighted_global_dim(model):
    """Specialization-weighted global feature dimensionality.

    Computes ``D_i = sum_e S_i^(e) * D_i^(e)``, where ``D`` comes from
    ``expert_feat_dimensionality`` and ``S`` from ``feature_specialization``
    (both shaped ``[n_experts, n_features]``).

    Args:
        model: the model passed through to both helper functions.

    Returns:
        Tensor with one entry per feature, detached from autograd.
    """
    D = expert_feat_dimensionality(model)
    S = feature_specialization(model)
    with torch.no_grad():
        # The two helpers may return tensors on different devices; align them
        # before the elementwise product.
        D = D.to(S.device)
        D_global = torch.sum(S * D, dim=0)  # one value per feature
    return D_global

In [84]:
weighted_global_dim(model)

tensor([0.5729, 0.5738, 0.5688, 0.5363, 0.5393, 0.4627, 0.5133, 0.4802, 0.4662,
        0.4261, 0.3790, 0.3246, 0.1376, 0.0055, 0.0044, 0.0136, 0.0000, 0.0007,
        0.0000, 0.0000], grad_fn=<SumBackward1>)