In [1]:
from model.model import *
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

In [98]:
config = Config(
    n_features = 20,
    n_hidden = 5,
    n_experts = 4,
    n_active_experts = 1,
    load_balancing_loss = True,
)


model = MoEModel(
    config=config,
    device=DEVICE,
    importance = torch.exp(torch.randn(config.n_features, generator=torch.Generator().manual_seed(42)) * 3), #0.7**torch.arange(config.n_features), #torch.linspace(0.8, 0.1, config.n_features), #torch.tensor(1),
    feature_probability = torch.tensor(0.1)
)

In [99]:
optimize(model, n_batch=1024, steps=10000, print_freq=1000)

Step 0: loss=16.839659, lr=0.001000
Step 1000: loss=1.680822, lr=0.001000
Step 2000: loss=0.740856, lr=0.001000
Step 3000: loss=0.773929, lr=0.001000
Step 4000: loss=0.848712, lr=0.001000
Step 5000: loss=0.581775, lr=0.001000
Step 6000: loss=0.812468, lr=0.001000
Step 7000: loss=1.026672, lr=0.001000
Step 8000: loss=0.579521, lr=0.001000
Step 9000: loss=0.354577, lr=0.001000
Step 9999: loss=0.566039, lr=0.001000


In [100]:
# configs = [
#     Config(n_features=5, n_hidden=3, n_experts=5, n_active_experts=2, load_balancing_loss=True),
#     Config(n_features=5, n_hidden=3, n_experts=5, n_active_experts=2, load_balancing_loss=False),
# ]

# feature_probs = [torch.tensor(0.1), torch.tensor(0.2),]
# importances = [0.9**torch.arange(5), 0.8**torch.arange(5),]

# optimize_vectorized(configs, feature_probs, importances, n_batch=10, steps=100, device=DEVICE)

In [101]:
def render_features(model):
  cfg = model.config
  # expert weights
  W_exp = model.W_experts.detach()
  W_norm = W_exp / (1e-5 + torch.linalg.norm(W_exp, 2, dim=2, keepdim=True))

  interference = torch.einsum('ifh,igh->ifg', W_norm, W_exp) # (n_experts, n_features, n_features)
  interference[:, torch.arange(cfg.n_features), torch.arange(cfg.n_features)] = 0 # set diagonal to 0

  polysemanticity = torch.linalg.norm(interference, dim=-1).cpu()

  if model.feature_probability.dim() == 0:
    feature_prob_expanded = model.feature_probability.expand(cfg.n_features)
  else:
    feature_prob_expanded = model.feature_probability
  net_interference = (interference**2 * feature_prob_expanded[None, None, :]).sum(-1).cpu()
  norms = torch.linalg.norm(W_exp, 2, dim=-1).cpu()

  WWT = torch.einsum('eij,ekj->eik', W_exp, W_exp).cpu() # (n_experts, n_features, n_features)

  # width = weights[0].cpu()
  # x = torch.cumsum(width+0.1, 0) - width[0]
  x = torch.arange(cfg.n_features)
  width = 0.9

  fig = make_subplots(
      rows=cfg.n_experts,  # Row per expert
      cols=2,  # Column 1: bar graphs, Column 2: heatmaps
      shared_xaxes=False,
      vertical_spacing=0.05, # 0.05
      horizontal_spacing=0.1,
      subplot_titles=[f"expert {i}" if j == 0 else "" 
               for i in range(cfg.n_experts) for j in range(2)],
    #   subplot_titles=[f"expert {i}" if j == 0 else f"expert {i}" 
                    #  for i in range(cfg.n_experts) for j in range(2)],
  )
  for expert_idx in range(cfg.n_experts):
        fig.add_trace(
          go.Bar(
              x=x, 
              y=norms[expert_idx],
              marker=dict(
                  color=polysemanticity[expert_idx],
                  cmin=0,
                  cmax=1
              ),
          ),
          row=expert_idx+1, col=1  # Left column for bar graphs
      )
      # Column 2: WWT heatmap for this expert
        fig.add_trace(
          go.Image(
              z=plt.cm.coolwarm((1 + WWT[expert_idx].numpy())/2, bytes=True),
              colormodel='rgba256',
              customdata = WWT[expert_idx].numpy(),
              hovertemplate=f'expert {expert_idx}<br>In: %{{x}}<br>Out: %{{y}}<br> weight: %{{customdata:0.2f}}'
          ),
          row=expert_idx+1, col=2  # Right column for heatmaps
      )

  for expert_idx in range(cfg.n_experts):
      fig.add_vline(
          x=(x[cfg.n_hidden-1]+x[cfg.n_hidden])/2, 
          line=dict(width=0.5),
          row=expert_idx+1,
          col=1,
      )
    
  fig.update_layout(
      showlegend=False, 
      width=800,  # Fixed width for 2 columns
      height=300 * cfg.n_experts,  # Scale height with number of experts
      margin=dict(t=50, b=20)  # Space for subplot titles
  )
  
  # Show x-axes for bar graphs (column 1) and hide for heatmaps (column 2)
  fig.update_xaxes(visible=True, col=1)
  fig.update_xaxes(visible=False, col=2)
  
  fig.update_yaxes(visible=True, col=1)
  fig.update_yaxes(visible=True, title="norm", row=1, col=1)  # Only on first row
  fig.write_html("demo-superposition.html")

  return fig

In [102]:
render_features(model)

In [17]:
def expert_feature_sim(model):
    """computes feature similarity tensor between each expert"""
    cfg = model.config
    W_exp = model.W_experts.detach()  # (n_experts, n_features, n_hidden)
    W_norm = W_exp / (1e-5 + torch.linalg.norm(W_exp, 2, dim=2, keepdim=True))

    # compute pairwise feature similarity
    feature_sim = torch.einsum('ifh,jfh->ijf', W_norm, W_norm) # (n_experts, n_experts, n_features)
    return feature_sim.cpu()

In [18]:
f_sim = expert_feature_sim(model)
f_sim.shape

torch.Size([8, 8, 20])

In [40]:
def plot_expert_pair_similarity(sim_tensor, expert1_idx, expert2_idx):
    similarities = sim_tensor[expert1_idx, expert2_idx, :]
    
    fig = go.Figure()
    
    fig.add_trace(go.Scatter(
        x=list(range(len(similarities))),
        y=similarities,
        mode='lines+markers',
        # name='Feature Similarity',
        line=dict(width=2),
        marker=dict(size=6)
    ))
    
    # add horizontal reference lines
    fig.add_hline(y=0.6, line_dash="dash", line_color="red", opacity=0.5)
    
    fig.update_layout(
        title=dict(
            text=f'feature similarity b/w expert {expert1_idx} and expert {expert2_idx}',
            x=0.5
        ),
        xaxis_title='feature index',
        yaxis_title='cosine similarity',
        width=800,
        height=500,
        showlegend=False
    )
    
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
    
    fig.show()

In [41]:
plot_expert_pair_similarity(f_sim, 0, 1)

In [None]:
def feature_redundancy(sim_tensor):
    n_experts, _, n_features = sim_tensor.shape
    
    # initialize redundancy scores
    redundancy_scores = torch.zeros(n_features)
    
    # for each feature, compute redundancy across all expert pairs
    for feature_idx in range(n_features):
        total_similarity = 0.0
        num_pairs = 0
        
        # sum absolute similarities across all expert pairs (excluding diagonal)
        for e1 in range(n_experts):
            for e2 in range(n_experts):
                if e1 != e2:  # exclude self-similarity
                    total_similarity += abs(sim_tensor[e1, e2, feature_idx])
                    num_pairs += 1
        
        # compute redundancy score: average absolute similarity across pairs
        # normalized by (E-1) where E is number of experts
        redundancy_scores[feature_idx] = total_similarity / (n_experts * (n_experts - 1))
    
    return redundancy_scores

In [32]:
feature_redundancy(f_sim)

tensor([0.3487, 0.3415, 0.3518, 0.3413, 0.3630, 0.3780, 0.3961, 0.4109, 0.3526,
        0.4289, 0.3271, 0.5459, 0.4342, 0.3408, 0.4075, 0.4206, 0.3544, 0.3514,
        0.3497, 0.3299])

In [35]:
def expert_feat_dimensionality(model, expert_idx):
    """
    compute expert-specific feature dimensionality for a given expert.
    
    for feature i in expert e, dimensionality is defined as:
    D_i,e = ||W_e,i||^2 / (sum_{j=1}^{n_features} (W_e,i · W_e,j)^2)
    
    where:
    - W_e,i is the weight vector for feature i in expert e
    - W_e,i / ||W_e,i||_2 is the unit vector
    
    this measures how much 'space' feature i occupies within expert e,
    accounting for interference from other features in that same expert.
    """
    W_expert = model.W_experts[expert_idx]  # shape: [n_features, n_hidden]
    
    n_features = W_expert.shape[0]
    dimensionalities = torch.zeros(n_features)
    
    for i in range(n_features):
        W_ei = W_expert[i] 
        
        # compute ||W_e,i||^2 (squared norm)
        norm_squared = torch.sum(W_ei**2)
        
        denominator = 0.0
        for j in range(n_features):
            W_ej = W_expert[j]  # shape: [n_hidden]
            dot_product = torch.sum(W_ei * W_ej)
            denominator += dot_product**2
        
        # compute dimensionality
        if denominator > 0:
            dimensionalities[i] = norm_squared / denominator
        else:
            dimensionalities[i] = 0.0
    
    return dimensionalities

In [39]:
expert_feat_dimensionality(model, 0)

tensor([0.7638, 0.7665, 0.7708, 0.5216, 0.7008, 1.0181, 0.4755, 0.6947, 0.5449,
        0.6731, 0.4936, 0.6997, 0.6203, 0.8021, 0.6099, 0.6730, 0.6912, 0.5017,
        0.6382, 0.7751], grad_fn=<CopySlices>)