In [1]:
from model.model import *
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

In [31]:
# Experiment setup: 20 input features, 5 hidden units, 4 experts with
# n_active_experts=1, and the load-balancing loss switched on.
config = Config(
    n_features=20,
    n_hidden=5,
    n_experts=4,
    n_active_experts=1,
    load_balancing_loss=True,
)

# Feature i gets importance 0.7**i (geometric decay over the feature index);
# feature_probability is a single 0-dim tensor rather than a per-feature vector.
model = MoEModel(
    config=config,
    device=DEVICE,
    importance=0.7 ** torch.arange(config.n_features),
    feature_probability=torch.tensor(0.1),
)

In [32]:
# Train `model` for 5000 steps with n_batch=1024; the recorded output below shows
# a loss line every `print_freq` steps plus the final step. The return value is
# not captured, so `optimize` presumably updates `model` in place — verify
# against model.model.
optimize(model, n_batch=1024, steps=5000, print_freq=1000)

Step 0: loss=0.103726, lr=0.001000
Step 1000: loss=0.023843, lr=0.001000
Step 2000: loss=0.023010, lr=0.001000
Step 3000: loss=0.019490, lr=0.001000
Step 4000: loss=0.018620, lr=0.001000
Step 4999: loss=0.018397, lr=0.001000


In [33]:
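# Disabled example (not executed): calls optimize_vectorized with two configs
# that differ only in load_balancing_loss. Uncomment the lines below to run it.
# configs = [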
#     Config(n_features=5, n_hidden=3, n_experts=5, n_active_experts=2, load_balancing_loss=True),
#     Config(n_features=5, n_hidden=3, n_experts=5, n_active_experts=2, load_balancing_loss=False),
# ]

# feature_probs = [torch.tensor(0.1), torch.tensor(0.2),]
# importances = [0.9**torch.arange(5), 0.8**torch.arange(5),]

# optimize_vectorized(configs, feature_probs, importances, n_batch=10, steps=100, device=DEVICE)

In [38]:
def render_features(model):
  """Plot per-expert feature norms and feature-by-feature Gram matrices.

  The figure has one column per expert and two rows:
    * row 1 - horizontal bars, one per feature. Bar length is the L2 norm of
      that feature's weight vector in the expert; bar colour is the norm of its
      interference with every other feature (colour range clamped to [0, 1]).
      A thin horizontal line is drawn between feature indices
      ``n_hidden - 1`` and ``n_hidden``.
    * row 2 - the expert's Gram matrix over features (entry ``[i, j]`` is the
      dot product of the weight vectors of features ``i`` and ``j``), coloured
      with matplotlib's ``coolwarm`` map so that -1 and +1 hit the two ends.

  Args:
    model: object exposing ``config`` (with ``n_features``, ``n_hidden`` and
      ``n_experts``) and ``W_experts``, a tensor indexed as
      (expert, feature, hidden) - the layout required by the indexing below.

  Returns:
    The plotly figure produced by ``make_subplots`` with all traces added.
  """
  cfg = model.config
  n_features = cfg.n_features
  W_exp = model.W_experts.detach()

  # Unit-length copy of every feature vector; the epsilon keeps the division
  # finite for features whose weight vector is all zeros.
  W_unit = W_exp / (1e-5 + torch.linalg.norm(W_exp, 2, dim=2, keepdim=True))

  # interference[e, f, g] = <unit(W[e, f]), W[e, g]>; a feature does not
  # interfere with itself, so the diagonal is cleared.
  interference = torch.einsum('ifh,igh->ifg', W_unit, W_exp)
  diag = torch.arange(n_features)
  interference[:, diag, diag] = 0

  polysemanticity = torch.linalg.norm(interference, dim=-1).cpu()
  norms = torch.linalg.norm(W_exp, 2, dim=-1).cpu()

  # Gram matrix over features: contract the hidden axis so the result is
  # (n_experts, n_features, n_features). Contracting the feature axis instead
  # would give a hidden-by-hidden matrix.
  WtW = torch.einsum('eih,ejh->eij', W_exp, W_exp).cpu()

  feature_positions = list(range(n_features))
  # Midpoint between feature indices n_hidden-1 and n_hidden on the bar axis.
  separator_y = cfg.n_hidden - 0.5

  fig = make_subplots(
    rows=2,  # Row 1: norms, Row 2: WtW matrices
    cols=cfg.n_experts,  # One column per expert
    shared_xaxes=True,
    vertical_spacing=0.02,
    horizontal_spacing=0.1,
    subplot_titles=[f"expert {i}" for i in range(cfg.n_experts)],
  )

  for expert_idx in range(cfg.n_experts):
    col = expert_idx + 1  # plotly subplot indices are 1-based
    gram = WtW[expert_idx].numpy()

    # Row 1: feature norms, coloured by polysemanticity.
    fig.add_trace(
      go.Bar(
        y=feature_positions,
        x=norms[expert_idx].numpy(),
        orientation='h',
        marker=dict(
          color=polysemanticity[expert_idx].numpy(),
          cmin=0,
          cmax=1,
        ),
      ),
      row=1, col=col,
    )
    fig.add_hline(
      y=separator_y,
      line=dict(width=0.5),
      row=1, col=col,
    )

    # Row 2: Gram matrix as an RGBA image; (1 + v) / 2 maps [-1, 1] onto the
    # colormap's [0, 1] input range. Raw values are kept for the hover text.
    fig.add_trace(
      go.Image(
        z=plt.cm.coolwarm((1 + gram) / 2, bytes=True),
        colormodel='rgba256',
        customdata=gram,
        hovertemplate=f'Expert {expert_idx}<br>In: %{{x}}<br>Out: %{{y}}<br>Weight: %{{customdata:0.2f}}'
      ),
      row=2, col=col,
    )

  fig.update_layout(
    showlegend=False,
    width=200 * cfg.n_experts,  # Scale width with number of experts
    height=300,  # Room for the two rows
    margin=dict(t=30, b=0),  # Space for subplot titles
  )
  fig.update_xaxes(visible=False)
  fig.update_yaxes(visible=False)

  return fig

In [39]:
# Bare last expression of the cell: the notebook displays the returned figure.
render_features(model)