In [10]:
from model.model import *
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

In [11]:
config = Config(
    n_features = 20,
    n_hidden = 5,
    n_experts = 4,
    n_active_experts = 1,
    load_balancing_loss = False,
)


model = MoEModel(
    config=config,
    device=DEVICE,
    importance = 0.7**torch.arange(config.n_features),
    feature_probability = torch.tensor(0.1)
)

# importance inits
#torch.exp(torch.randn(config.n_features, generator=torch.Generator().manual_seed(42)) * 3), 
# #torch.linspace(0.8, 0.1, config.n_features), #torch.tensor(1),
# 0.7**torch.from_numpy(np.random.choice(config.n_features, config.n_features, replace=False))

#print(model.importance)

# ORIGINAL APPROACH
# import numpy as np
# indices = np.random.choice(model.gate.shape[1], model.gate.shape[0], replace=False)
# print(indices)
# with torch.no_grad():
#     model.gate[np.arange(model.gate.shape[0]), indices] = 1
# import numpy as np
# indices = np.random.choice(model.gate.shape[1], size=int(model.gate.shape[0]*10), replace=False)
# indices = torch.from_numpy(indices.reshape(model.gate.shape[0], -1))
# print(indices)
# indices_2 = np.random.choice(model.gate.shape[1], size=int(model.gate.shape[0]), replace=False)
# indices_2 = torch.from_numpy(indices_2.reshape(model.gate.shape[0], -1))
#indx = torch.cat((indices, indices_2), dim=1)
#print(indx)
# with torch.no_grad():
#     for i in range(model.gate.shape[0]):
#         model.gate[i, indices[i]] = 1
# with torch.no_grad():
#     model.gate.fill_diagonal_(1)
# nn.init.xavier_normal_(model.gate)

# print(model.gate)

In [12]:
optimize(model, n_batch=1024, steps=10000, print_freq=1000)

Step 0: loss=0.095757, lr=0.001000
Step 1000: loss=0.005813, lr=0.001000
Step 2000: loss=0.006552, lr=0.001000
Step 3000: loss=0.003519, lr=0.001000
Step 4000: loss=0.005844, lr=0.001000
Step 5000: loss=0.004596, lr=0.001000
Step 6000: loss=0.005320, lr=0.001000
Step 7000: loss=0.005715, lr=0.001000
Step 8000: loss=0.005125, lr=0.001000
Step 9000: loss=0.005375, lr=0.001000
Step 9999: loss=0.006094, lr=0.001000


In [13]:
# configs = [
#     Config(n_features=5, n_hidden=3, n_experts=5, n_active_experts=2, load_balancing_loss=True),
#     Config(n_features=5, n_hidden=3, n_experts=5, n_active_experts=2, load_balancing_loss=False),
# ]

# feature_probs = [torch.tensor(0.1), torch.tensor(0.2),]
# importances = [0.9**torch.arange(5), 0.8**torch.arange(5),]

# optimize_vectorized(configs, feature_probs, importances, n_batch=10, steps=100, device=DEVICE)

In [14]:
def render_features(model):
  cfg = model.config
  # expert weights
  W_exp = model.W_experts.detach()
  W_norm = W_exp / (1e-5 + torch.linalg.norm(W_exp, 2, dim=2, keepdim=True))

  interference = torch.einsum('ifh,igh->ifg', W_norm, W_exp) # (n_experts, n_features, n_features)
  interference[:, torch.arange(cfg.n_features), torch.arange(cfg.n_features)] = 0 # set diagonal to 0

  polysemanticity = torch.linalg.norm(interference, dim=-1).cpu()

  if model.feature_probability.dim() == 0:
    feature_prob_expanded = model.feature_probability.expand(cfg.n_features)
  else:
    feature_prob_expanded = model.feature_probability
  net_interference = (interference**2 * feature_prob_expanded[None, None, :]).sum(-1).cpu()
  norms = torch.linalg.norm(W_exp, 2, dim=-1).cpu()

  WWT = torch.einsum('eij,ekj->eik', W_exp, W_exp).cpu() # (n_experts, n_features, n_features)

  # width = weights[0].cpu()
  # x = torch.cumsum(width+0.1, 0) - width[0]
  x = torch.arange(cfg.n_features)
  width = 0.9

  fig = make_subplots(
      rows=cfg.n_experts,  # Row per expert
      cols=2,  # Column 1: bar graphs, Column 2: heatmaps
      shared_xaxes=False,
      vertical_spacing=0.05, # 0.05
      horizontal_spacing=0.1,
      subplot_titles=[f"expert {i}" if j == 0 else "" 
               for i in range(cfg.n_experts) for j in range(2)],
    #   subplot_titles=[f"expert {i}" if j == 0 else f"expert {i}" 
                    #  for i in range(cfg.n_experts) for j in range(2)],
  )
  for expert_idx in range(cfg.n_experts):
        fig.add_trace(
          go.Bar(
              x=x, 
              y=norms[expert_idx],
              marker=dict(
                  color=polysemanticity[expert_idx],
                  cmin=0,
                  cmax=1
              ),
          ),
          row=expert_idx+1, col=1  # Left column for bar graphs
      )
      # Column 2: WWT heatmap for this expert
        fig.add_trace(
          go.Image(
              z=plt.cm.coolwarm((1 + WWT[expert_idx].numpy())/2, bytes=True),
              colormodel='rgba256',
              customdata = WWT[expert_idx].numpy(),
              hovertemplate=f'expert {expert_idx}<br>In: %{{x}}<br>Out: %{{y}}<br> weight: %{{customdata:0.2f}}',
          ),
          row=expert_idx+1, col=2  # Right column for heatmaps
      )

  for expert_idx in range(cfg.n_experts):
      fig.add_vline(
          x=(x[cfg.n_hidden-1]+x[cfg.n_hidden])/2, 
          line=dict(width=0.5),
          row=expert_idx+1,
          col=1,
      )
    
  fig.update_layout(
      showlegend=False, 
      width=800,  # Fixed width for 2 columns
      height=300 * cfg.n_experts,  # Scale height with number of experts
      margin=dict(t=50, b=20)  # Space for subplot titles
  )
  
  # Show x-axes for bar graphs (column 1) and hide for heatmaps (column 2)
  fig.update_xaxes(visible=True, col=1)
  fig.update_xaxes(visible=False, col=2)
  
  fig.update_yaxes(visible=True, col=1)
  fig.update_yaxes(visible=True, title="norm", row=1, col=1)  # Only on first row

  return fig

In [15]:
render_features(model)

feature dimensionality for a feature $i$ in expert $e$ is given by:
$$
D_i^{(e)} = \frac{\left\| W_i^{(e)} \right\|^2}{\sum_j \left( \hat{W}_i^{(e)} \cdot W_j^{(e)} \right)^2}
$$
- $\hat{W}_i^{(e)} = \frac{W_i^{(e)}}{\left| W_i^{(e)} \right|}$ is the unit vector in the direction of $W_i^{(e)}$

In [16]:
def expert_feat_dimensionality(model):
    """fraction of a dimension within an expert that is occupied by a feature."""
    n_experts = model.config.n_experts
    n_features = model.config.n_features
    
    # Initialize tensor to store results for all experts
    all_dimensionalities = torch.zeros(n_experts, n_features)
    
    for expert_idx in range(n_experts):
        W_expert = model.W_experts[expert_idx]  # shape: [n_features, n_hidden]
        W_norm_squared = torch.sum(W_expert**2, dim=1)  # shape: [n_features], ||W_i||^2
        W_hat = W_expert / torch.sqrt(W_norm_squared).unsqueeze(1) # W_hat_i = W_i / ||W_i||
        dot_products = torch.mm(W_hat, W_expert.T) # W_hat_i · W_j
        squared = dot_products**2
        interference = torch.sum(squared, dim=1) # sum of squared dot products for each feature
        dimensionality = W_norm_squared / interference # D_i^(e) = ||W_i^(e)||^2 / sum_j(W_hat_i · W_j)^2
        all_dimensionalities[expert_idx] = dimensionality
    
    return torch.round(all_dimensionalities, decimals=3)

In [22]:
expert_feat_dimensionality(model)

tensor([[0.5010, 0.4980, 0.5760, 0.5050, 1.0000, 0.4920, 0.9950, 0.0000, 0.0010,
         0.4230, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0010],
        [0.5060, 0.5070, 0.4940, 0.9950, 0.4930, 0.5220, 0.0010, 0.3310, 0.3790,
         0.4780, 0.0010, 0.2860, 0.0000, 0.0040, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
        [0.4980, 0.5020, 1.0000, 0.5000, 0.5000, 0.5120, 0.4880, 0.5050, 0.0000,
         0.4940, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
        [0.9990, 0.9980, 0.5050, 0.4950, 0.5190, 0.5020, 0.4960, 0.0010, 0.4810,
         0.0010, 0.0010, 0.0010, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000]], grad_fn=<RoundBackward1>)

In [18]:
W_exp = model.W_experts.detach()
W_norm = W_exp / (1e-5 + torch.linalg.norm(W_exp, 2, dim=2, keepdim=True))

interference = torch.einsum('ifh,igh->ifg', W_norm, W_exp) # (n_experts, n_features, n_features)
interference[:, torch.arange(100), torch.arange(100)] = 0 # set diagonal to 0
polysemanticity = torch.linalg.norm(interference, dim=-1).cpu()
for i in range(10):
    print(f"Most monosemantic neurons for expert {i}: {polysemanticity[i].topk(10, largest=False)[1]}")

print(f"First expert: {polysemanticity[0].topk(10, largest=False)[0]}")

IndexError: index 20 is out of bounds for dimension 0 with size 20

specialization of feature $i$ to expert $e$ is given by:

$$
S_i^{(e)} = \frac{\left\| W_i^{(e)} \right\|^2}{\sum_{e'} \left\| W_i^{(e')} \right\|^2}
$$

- $S_i^{(e)}$ represents the fraction of feature $i$'s total weight magnitude allocated to expert $e$


In [23]:
def feature_specialization(model):
    """fraction of feature i allocated to expert e"""
    W_expert = model.W_experts  
    W_norm_squared = torch.sum(W_expert**2, dim=2)  # shape: [n_experts, n_features], ||W_i||^2
    total_norms_per_feature = torch.sum(W_norm_squared, dim=0)  # sum over experts
    specialization = W_norm_squared / total_norms_per_feature.unsqueeze(0) # ||W_i^(e)||^2 for each expert
    return torch.round(specialization, decimals=3)

In [24]:
feature_specialization(model)

tensor([[0.2510, 0.2490, 0.2510, 0.2530, 0.2470, 0.2450, 0.3360, 0.0000, 0.0010,
         0.2790, 0.3800, 0.0010, 0.4950, 0.0940, 0.7400, 0.1120, 0.1880, 0.1670,
         0.1380, 0.8810],
        [0.2550, 0.2540, 0.2490, 0.2480, 0.2470, 0.2510, 0.0000, 0.5010, 0.5460,
         0.3490, 0.3240, 0.9970, 0.2820, 0.8810, 0.1480, 0.5750, 0.4830, 0.6670,
         0.7370, 0.0840],
        [0.2470, 0.2490, 0.2480, 0.2520, 0.2500, 0.2500, 0.3240, 0.4980, 0.0000,
         0.3710, 0.0070, 0.0010, 0.0750, 0.0220, 0.0600, 0.2120, 0.1060, 0.0650,
         0.0410, 0.0240],
        [0.2470, 0.2480, 0.2520, 0.2480, 0.2560, 0.2540, 0.3400, 0.0010, 0.4530,
         0.0000, 0.2890, 0.0010, 0.1480, 0.0020, 0.0520, 0.1010, 0.2230, 0.1010,
         0.0840, 0.0110]], grad_fn=<RoundBackward1>)

In [25]:
specialization_tensor = feature_specialization(model)
# sum the feature specialization across experts (dimension 0)
summed_specialization = torch.sum(specialization_tensor, dim=0)
print(summed_specialization)

tensor([1.0000, 1.0000, 1.0000, 1.0010, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        0.9990, 1.0000, 1.0000, 1.0000, 0.9990, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000], grad_fn=<SumBackward1>)


In [26]:
def global_feature_dimensionality(model, input_features=None, n_batch=4096):
    """global feature dimensionality weighted across all experts."""
    
    # get per-expert dimensionalities (you already have this)
    expert_dimensionalities = expert_feat_dimensionality(model)  # shape: [n_experts, n_features]
    
    if input_features is None:
        input_features = model.generate_batch(n_batch) 
    
    # compute gating probabilities
    gate_scores = torch.einsum("...f,ef->...e", input_features, model.gate)
    gate_probs = F.softmax(gate_scores, dim=-1)  # shape: [batch, n_experts]

    #Find the top expert for each point
    selected = gate_probs.argmax(dim=-1)
    counts = torch.bincount(selected.flatten())
    if counts.shape[0] < model.config.n_experts:
        counts = torch.cat([counts, torch.zeros(model.config.n_experts - counts.shape[0])])
    usage = counts / counts.sum()
    #print(f"Usage: {usage}")
    
    # average probabilities across batch to get p_e
    #p_e = torch.mean(gate_probs, dim=0)  # shape: [n_experts]
    #print(f"p_e: {p_e}")
    
    # D_i^global = Σ_e p_e · D_i^(e)
    #global_dims_old = torch.einsum("e,ef->f", p_e, expert_dimensionalities)
    global_dims_new = torch.einsum("e,ef->f", usage, expert_dimensionalities)
    return usage, torch.round(global_dims_new, decimals=3)

In [27]:
usage, global_feature_dims = global_feature_dimensionality(model)
print(f"Usage: {usage}")
print(f"Global feature dims: {global_feature_dims}")
print(f"Usage mean: {usage.mean()}, median: {usage.median()}")
print(f"Global feature dims mean: {global_feature_dims.mean()}, median: {global_feature_dims.median()}")
sums = (global_feature_dims > 0.05).sum()
print(f"Number of features represented: {sums}")

Usage: tensor([0.2207, 0.1406, 0.3037, 0.3350])
Global feature dims: tensor([0.6680, 0.6680, 0.6690, 0.5690, 0.6160, 0.5060, 0.5340, 0.2000, 0.2150,
        0.3110, 0.0010, 0.0410, 0.0000, 0.0010, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000], grad_fn=<RoundBackward1>)
Usage mean: 0.25, median: 0.220703125
Global feature dims mean: 0.24995000660419464, median: 0.04100000113248825
Number of features represented: 10


global dimensionality of feature $i$ is given by:

$$
D_i^{\text{global}} = \sum_e S_i^{(e)} \cdot D_i^{(e)}
$$

where:

- $D_i^{(e)}$ is the dimensionality of feature $i$ in expert $e$  
- $S_i^{(e)}$ is the specialization of feature $i$ to expert $e$


In [39]:
def weighted_global_dim(model):
    """weighted global feature dimensionality"""
    D = expert_feat_dimensionality(model)
    S = feature_specialization(model)
    D_global = torch.sum(S * D, dim=0) # n_features
    return torch.round(D_global, decimals=3)

In [40]:
weighted_global_dim(model)

tensor([0.6250, 0.6250, 0.6430, 0.6230, 0.6270, 0.5070, 0.6610, 0.4170, 0.4250,
        0.4680, 0.0010, 0.2850, 0.0000, 0.0040, 0.0010, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0010], grad_fn=<RoundBackward1>)

global Feature Dimensionality for feature $i$ is given by:

$$
D_i^{\text{global}} \;=\; \sum_{e} \alpha_i^{(e)} \cdot D_i^{(e)}
$$

where:

- $D_i^{(e)}$ = dimensionality of feature $i$ in expert $e$  
- $\alpha_i^{(e)}$ = activation rate of expert $e$ **given** that feature $i$ is present:
$$
\alpha_i^{(e)} \;=\; \frac{\text{\# times expert $e$ active when feature $i$ present}}{\text{\# times feature $i$ present}}
$$

measures how much each expert contributes to representing feature \(i\), weighted by the feature-specific routing frequency.


In [None]:
def global_feat_dim(model, input_features=None, n_batch=1024):
    """
    compute global feature dimensionality for MoE model.
    
    D_i^global = Σ_e alpha_i^(e) · D_i^(e)
    
    where:
    - D_i^(e) is dimensionality of feature i in expert e
    - alpha_i^(e) is feature-specific activation rate of expert e for feature i
    """
    if input_features is None:
        input_features = model.generate_batch(n_batch)
    
    # expert dimensionalities D_i^(e) for all features and experts
    expert_dims = expert_feat_dimensionality(model) # [n_experts, n_features]
    
    # expert activation patterns
    expert_weights, top_k_indices, _ = model.compute_active_experts(input_features)
    # expert_weights: [n_batch, n_experts]
    # top_k_indices: [n_batch, n_active_experts]
    
    # binary activation mask for all experts
    # active_mask[b, e] = 1 if expert e is active for batch b, 0 otherwise
    active_mask = torch.zeros_like(expert_weights)  # [n_batch, n_experts]
    active_mask = active_mask.scatter(-1, top_k_indices, 1.0)
    
    # feature-specific activation rates α_i^(e)
    # for each feature i, we need the activation rate of expert e 
    # only when feature i is present (non-zero)
    
    # which batch examples have each feature present
    feature_present = (input_features > 0).float()  # [n_batch, n_features]
    
    # Count how many times each feature appears
    feature_counts = torch.sum(feature_present, dim=0)  # [n_features]
    
    # for each feature i and expert e, count how often expert e is active when feature i is present
    # feature_present: [n_batch, n_features]
    # active_mask: [n_batch, n_experts]
    
    # [n_batch, n_features, 1] * [n_batch, 1, n_experts]
    feature_expert_cooccurrence = torch.einsum('bf,be->fe', feature_present, active_mask)
    
    # activation rates α_i^(e) = cooccurrence / feature_counts
    # handle division by zero for features that never appear
    alpha = torch.where(
        feature_counts.unsqueeze(-1) > 0,
        feature_expert_cooccurrence / feature_counts.unsqueeze(-1),
        torch.zeros_like(feature_expert_cooccurrence)
    )  # [n_features, n_experts]
    
    # global dimensionality D_i^global = Σ_e α_i^(e) · D_i^(e)
    # alpha: [n_features, n_experts]
    # expert_dims: [n_experts, n_features] -> transpose to [n_features, n_experts]
    global_dims = torch.sum(alpha * expert_dims.T, dim=-1)  # [n_features]
    
    return torch.round(global_dims, decimals=3)

In [38]:
global_feat_dim(model, n_batch=1024)

tensor([0.8210, 0.8180, 0.7640, 0.7850, 0.7360, 0.5090, 0.7320, 0.3690, 0.3590,
        0.3470, 0.0010, 0.1790, 0.0000, 0.0020, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000], grad_fn=<RoundBackward1>)