In [164]:
from model.model import *
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

In [196]:
config = Config(
    n_features = 20,
    n_hidden = 5,
    n_experts = 4,
    n_active_experts = 1,
    load_balancing_loss = True,
)


model = MoEModel(
    config=config,
    device=DEVICE,
    importance = 0.7**torch.arange(config.n_features), #torch.exp(torch.randn(config.n_features, generator=torch.Generator().manual_seed(42)) * 3), #torch.linspace(0.8, 0.1, config.n_features), #torch.tensor(1),
    feature_probability = torch.tensor(0.1)
)

In [197]:
optimize(model, n_batch=1024, steps=10000, print_freq=1000)

Step 0: loss=0.101609, lr=0.001000
Step 1000: loss=0.021586, lr=0.001000
Step 2000: loss=0.017415, lr=0.001000
Step 3000: loss=0.017019, lr=0.001000
Step 4000: loss=0.020610, lr=0.001000
Step 5000: loss=0.016204, lr=0.001000
Step 6000: loss=0.017617, lr=0.001000
Step 7000: loss=0.019254, lr=0.001000
Step 8000: loss=0.019265, lr=0.001000
Step 9000: loss=0.021113, lr=0.001000
Step 9999: loss=0.023178, lr=0.001000


In [198]:
# configs = [
#     Config(n_features=5, n_hidden=3, n_experts=5, n_active_experts=2, load_balancing_loss=True),
#     Config(n_features=5, n_hidden=3, n_experts=5, n_active_experts=2, load_balancing_loss=False),
# ]

# feature_probs = [torch.tensor(0.1), torch.tensor(0.2),]
# importances = [0.9**torch.arange(5), 0.8**torch.arange(5),]

# optimize_vectorized(configs, feature_probs, importances, n_batch=10, steps=100, device=DEVICE)

In [199]:
def render_features(model):
  cfg = model.config
  # expert weights
  W_exp = model.W_experts.detach()
  W_norm = W_exp / (1e-5 + torch.linalg.norm(W_exp, 2, dim=2, keepdim=True))

  interference = torch.einsum('ifh,igh->ifg', W_norm, W_exp) # (n_experts, n_features, n_features)
  interference[:, torch.arange(cfg.n_features), torch.arange(cfg.n_features)] = 0 # set diagonal to 0

  polysemanticity = torch.linalg.norm(interference, dim=-1).cpu()

  if model.feature_probability.dim() == 0:
    feature_prob_expanded = model.feature_probability.expand(cfg.n_features)
  else:
    feature_prob_expanded = model.feature_probability
  net_interference = (interference**2 * feature_prob_expanded[None, None, :]).sum(-1).cpu()
  norms = torch.linalg.norm(W_exp, 2, dim=-1).cpu()

  WWT = torch.einsum('eij,ekj->eik', W_exp, W_exp).cpu() # (n_experts, n_features, n_features)

  # width = weights[0].cpu()
  # x = torch.cumsum(width+0.1, 0) - width[0]
  x = torch.arange(cfg.n_features)
  width = 0.9

  fig = make_subplots(
      rows=cfg.n_experts,  # Row per expert
      cols=2,  # Column 1: bar graphs, Column 2: heatmaps
      shared_xaxes=False,
      vertical_spacing=0.05, # 0.05
      horizontal_spacing=0.1,
      subplot_titles=[f"expert {i}" if j == 0 else "" 
               for i in range(cfg.n_experts) for j in range(2)],
    #   subplot_titles=[f"expert {i}" if j == 0 else f"expert {i}" 
                    #  for i in range(cfg.n_experts) for j in range(2)],
  )
  for expert_idx in range(cfg.n_experts):
        fig.add_trace(
          go.Bar(
              x=x, 
              y=norms[expert_idx],
              marker=dict(
                  color=polysemanticity[expert_idx],
                  cmin=0,
                  cmax=1
              ),
          ),
          row=expert_idx+1, col=1  # Left column for bar graphs
      )
      # Column 2: WWT heatmap for this expert
        fig.add_trace(
          go.Image(
              z=plt.cm.coolwarm((1 + WWT[expert_idx].numpy())/2, bytes=True),
              colormodel='rgba256',
              customdata = WWT[expert_idx].numpy(),
              hovertemplate=f'expert {expert_idx}<br>In: %{{x}}<br>Out: %{{y}}<br> weight: %{{customdata:0.2f}}',
          ),
          row=expert_idx+1, col=2  # Right column for heatmaps
      )

  for expert_idx in range(cfg.n_experts):
      fig.add_vline(
          x=(x[cfg.n_hidden-1]+x[cfg.n_hidden])/2, 
          line=dict(width=0.5),
          row=expert_idx+1,
          col=1,
      )
    
  fig.update_layout(
      showlegend=False, 
      width=800,  # Fixed width for 2 columns
      height=300 * cfg.n_experts,  # Scale height with number of experts
      margin=dict(t=50, b=20)  # Space for subplot titles
  )
  
  # Show x-axes for bar graphs (column 1) and hide for heatmaps (column 2)
  fig.update_xaxes(visible=True, col=1)
  fig.update_xaxes(visible=False, col=2)
  
  fig.update_yaxes(visible=True, col=1)
  fig.update_yaxes(visible=True, title="norm", row=1, col=1)  # Only on first row

  return fig

In [200]:
render_features(model)

feature dimensionality for a feature $i$ in expert $e$ is given by:
$$
D_i^{(e)} = \frac{\left\| W_i^{(e)} \right\|^2}{\sum_j \left( \hat{W}_i^{(e)} \cdot W_j^{(e)} \right)^2}
$$
- $\hat{W}_i^{(e)} = \frac{W_i^{(e)}}{\left| W_i^{(e)} \right|}$ is the unit vector in the direction of $W_i^{(e)}$

In [176]:
def expert_feat_dimensionality(model):
    """fraction of a dimension within an expert that is occupied by a feature."""
    n_experts = model.config.n_experts
    n_features = model.config.n_features
    
    # Initialize tensor to store results for all experts
    all_dimensionalities = torch.zeros(n_experts, n_features)
    
    for expert_idx in range(n_experts):
        W_expert = model.W_experts[expert_idx]  # shape: [n_features, n_hidden]
        W_norm_squared = torch.sum(W_expert**2, dim=1)  # shape: [n_features], ||W_i||^2
        W_hat = W_expert / torch.sqrt(W_norm_squared).unsqueeze(1) # W_hat_i = W_i / ||W_i||
        dot_products = torch.mm(W_hat, W_expert.T) # W_hat_i · W_j
        squared = dot_products**2
        interference = torch.sum(squared, dim=1) # sum of squared dot products for each feature
        dimensionality = W_norm_squared / interference # D_i^(e) = ||W_i^(e)||^2 / sum_j(W_hat_i · W_j)^2
        all_dimensionalities[expert_idx] = dimensionality
    
    return torch.round(all_dimensionalities, decimals=3)

In [177]:
expert_feat_dimensionality(model)

tensor([[0.2390, 0.1620, 0.1610, 0.2780, 0.2170, 0.2330, 0.2470, 0.2330, 0.1180,
         0.3630, 0.1240, 0.3210, 0.3480, 0.2280, 0.3150, 0.2340, 0.2610, 0.0640,
         0.3840, 0.0570],
        [1.0000, 0.0000, 0.9970, 0.4950, 0.5430, 0.5020, 0.4980, 0.5050, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4570, 0.0030,
         0.0000, 0.0000],
        [0.5150, 0.4960, 0.5040, 0.4850, 0.5050, 0.0000, 0.0000, 0.5000, 0.5000,
         0.0000, 0.5020, 0.0000, 0.4950, 0.0000, 0.4980, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
        [0.5120, 0.0000, 0.0000, 0.4850, 0.5150, 0.5070, 0.5000, 0.4930, 0.0000,
         0.5160, 0.0000, 0.4880, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.4840, 0.0000]], grad_fn=<RoundBackward1>)

specialization of feature $i$ to expert $e$ is given by:

$$
S_i^{(e)} = \frac{\left\| W_i^{(e)} \right\|^2}{\sum_{e'} \left\| W_i^{(e')} \right\|^2}
$$

- $S_i^{(e)}$ represents the fraction of feature $i$'s total weight magnitude allocated to expert $e$


In [178]:
def feature_specialization(model):
    """fraction of feature i allocated to expert e"""
    W_expert = model.W_experts  
    W_norm_squared = torch.sum(W_expert**2, dim=2)  # shape: [n_experts, n_features], ||W_i||^2
    total_norms_per_feature = torch.sum(W_norm_squared, dim=0)  # sum over experts
    specialization = W_norm_squared / total_norms_per_feature.unsqueeze(0) # ||W_i^(e)||^2 for each expert
    return torch.round(specialization, decimals=3)

In [179]:
feature_specialization(model)

tensor([[0.0160, 0.0260, 0.0220, 0.0220, 0.0160, 0.0320, 0.0210, 0.0180, 0.0200,
         0.0800, 0.0370, 0.0970, 0.0700, 0.0500, 0.0680, 0.9970, 0.0840, 0.7730,
         0.0820, 0.9740],
        [0.3240, 0.0000, 0.4870, 0.3270, 0.3230, 0.4770, 0.4870, 0.3280, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0010, 0.9150, 0.2230,
         0.0000, 0.0110],
        [0.3330, 0.9740, 0.4910, 0.3220, 0.3240, 0.0000, 0.0000, 0.3270, 0.9790,
         0.0000, 0.9630, 0.0000, 0.9300, 0.0000, 0.9320, 0.0010, 0.0000, 0.0020,
         0.0000, 0.0000],
        [0.3270, 0.0000, 0.0000, 0.3290, 0.3370, 0.4910, 0.4920, 0.3270, 0.0000,
         0.9200, 0.0000, 0.9030, 0.0000, 0.9500, 0.0000, 0.0010, 0.0000, 0.0020,
         0.9180, 0.0150]], grad_fn=<RoundBackward1>)

In [180]:
specialization_tensor = feature_specialization(model)
# sum the feature specialization across experts (dimension 0)
summed_specialization = torch.sum(specialization_tensor, dim=0)
print(summed_specialization)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9990,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9990, 1.0000,
        1.0000, 1.0000], grad_fn=<SumBackward1>)


global dimensionality of feature $i$ is given by:

$$
D_i^{\text{global}} = \sum_e p_e \cdot D_i^{(e)}
$$

where:

- $D_i^{(e)}$ is the dimensionality of feature $i$ in expert $e$  
- $p_e$ is the probability or proportion with which expert $e$ is selected (e.g. from routing or usage stats)  
- $D_i^{\text{global}}$ aggregates the feature's dimensionality across experts, weighted by expert usage


In [181]:
def global_feature_dimensionality(model, input_features=None, n_batch=1024):
    """global feature dimensionality weighted across all experts."""
    
    # get per-expert dimensionalities (you already have this)
    expert_dimensionalities = expert_feat_dimensionality(model)  # shape: [n_experts, n_features]
    
    if input_features is None:
        input_features = model.generate_batch(n_batch) 
    
    # compute gating probabilities
    gate_scores = torch.einsum("...f,ef->...e", input_features, model.gate)
    gate_probs = F.softmax(gate_scores, dim=-1)  # shape: [batch, n_experts]
    
    # average probabilities across batch to get p_e
    p_e = torch.mean(gate_probs, dim=0)  # shape: [n_experts]
    print(p_e)
    
    # D_i^global = Σ_e p_e · D_i^(e)
    global_dims = torch.einsum("e,ef->f", p_e, expert_dimensionalities)
    return torch.round(global_dims, decimals=3)

In [182]:
global_feature_dimensionality(model)

tensor([0.2094, 0.2636, 0.2635, 0.2636], grad_fn=<MeanBackward1>)


tensor([0.5840, 0.1650, 0.4290, 0.4440, 0.4570, 0.3150, 0.3150, 0.4440, 0.1560,
        0.2120, 0.1580, 0.1960, 0.2030, 0.1800, 0.1970, 0.0490, 0.1750, 0.0140,
        0.2080, 0.0120], grad_fn=<RoundBackward1>)

global dimensionality of feature $i$ is given by:

$$
D_i^{\text{global}} = \sum_e S_i^{(e)} \cdot D_i^{(e)}
$$

where:

- $D_i^{(e)}$ is the dimensionality of feature $i$ in expert $e$  
- $S_i^{(e)}$ is the specialization of feature $i$ to expert $e$


In [183]:
def weighted_global_dim(model):
    """weighted global feature dimensionality"""
    D = expert_feat_dimensionality(model)
    S = feature_specialization(model)
    D_global = torch.sum(S * D, dim=0) # n_features
    return D_global

In [184]:
weighted_global_dim(model)

tensor([0.6667, 0.4873, 0.7365, 0.4837, 0.5160, 0.4958, 0.4937, 0.4945, 0.4919,
        0.5038, 0.4880, 0.4718, 0.4847, 0.4864, 0.4856, 0.2333, 0.4401, 0.0501,
        0.4758, 0.0555], grad_fn=<SumBackward1>)