# Section 3

In [None]:
from model.model import *
import matplotlib.pyplot as plt
import torch
import numpy

### Reproducing Figure 5

In [None]:
def other_optimize(model,
                   render=False,
                   n_batch=1024,
                   steps=10_000,
                   print_freq=100,
                   lr=1e-3,
                   lr_scale=constant_lr,
                   stop_loss=0.015,
                   drop_loss=0.02,
                   hooks=None):
    """Train ``model`` with AdamW, with a one-time LR drop and an early stop on loss thresholds.

    Each step draws ``model.generate_batch(n_batch)`` and runs ``model(batch)``, which must
    return ``(out, load_balance_loss)``. It then minimises the importance-weighted squared
    error between ``batch.abs()`` and ``out`` (mean over the batch, summed over features).
    ``load_balance_loss`` is added when it is not None.

    Args:
        model: Model exposing ``parameters()``, ``generate_batch``, ``importance`` and
            ``__call__`` as used below.
        render: Unused; kept for call compatibility.
        n_batch: Batch size per step.
        steps: Maximum number of optimisation steps.
        print_freq: Print the loss every this many steps (and on the last step).
        lr: Base learning rate.
        lr_scale: Callable ``lr_scale(step, steps)``; its result multiplies the current
            base LR each step.
        stop_loss: Return as soon as the loss falls below this value.
        drop_loss: The first time the loss falls below this value, the base LR is divided
            by 10. This happens only once.
        hooks: Optional callables. Each is called every step with a dict of training state.

    Returns:
        The loss (float) at the step where training stopped, or NaN if ``steps`` is 0.
    """
    hooks = [] if hooks is None else hooks
    opt = torch.optim.AdamW(list(model.parameters()), lr=lr)
    dropped = False
    loss_value = float("nan")

    for step in range(steps):
        step_lr = lr * lr_scale(step, steps)
        for group in opt.param_groups:
            group['lr'] = step_lr
        opt.zero_grad(set_to_none=True)
        batch = model.generate_batch(n_batch)
        out, load_balance_loss = model(batch)
        error = (model.importance * (batch.abs() - out) ** 2)
        reconstruction_loss = einops.reduce(error, 'b f -> f', 'mean').sum()

        loss = reconstruction_loss
        if load_balance_loss is not None:
            loss = loss + load_balance_loss

        loss.backward()
        opt.step()

        if hooks:
            hook_data = dict(model=model,
                             step=step,
                             opt=opt,
                             error=error,
                             loss=loss,
                             reconstruction_loss=reconstruction_loss,
                             load_balance_loss=load_balance_loss,
                             lr=step_lr)
            for h in hooks:
                h(hook_data)

        # Sync the scalar once per step instead of calling .item() repeatedly.
        loss_value = loss.item()
        if step % print_freq == 0 or (step + 1 == steps):
            print(f"Step {step}: loss={loss_value:.6f}, lr={step_lr:.6f}")

        # Drop the LR only once; otherwise it would shrink 10x on every
        # subsequent step that stays below the threshold and training would stall.
        if not dropped and loss_value < drop_loss:
            print(f"Dropping at step {step} with loss {loss_value:.6f}")
            lr = lr * 0.1
            dropped = True
        if loss_value < stop_loss:
            print(f"Stopping at step {step} with loss {loss_value:.6f}")
            return loss_value

    return loss_value

In [None]:
# Seed both RNGs so the Figure 5 runs below are reproducible.
np.random.seed(30)
torch.manual_seed(27)

In [None]:
config = Config(
    n_features=2,
    n_hidden=1,
    n_experts=3,
    n_active_experts=1,
    load_balancing_loss=False,
)

# Training settings shared by all three Figure 5 runs.
_fig5_train_kwargs = dict(n_batch=100, steps=10000, print_freq=2000, stop_loss=0.01, drop_loss=0.015)


def _new_fig5_model():
    """Build a fresh Figure 5 model (p=0.5, unit importance) on the CPU."""
    return MoEModel(config, device='cpu', feature_probability=torch.tensor(0.5), importance=torch.tensor(1))


# Xavier-normal gate.
model_base = _new_fig5_model()
nn.init.xavier_normal_(model_base.gate)
loss_base = other_optimize(model_base, **_fig5_train_kwargs)

# All-zero gate.
model_zero = _new_fig5_model()
nn.init.constant_(model_zero.gate, 0)
print(model_zero.gate)
loss_zero = other_optimize(model_zero, **_fig5_train_kwargs)

# Zero gate with ones on the diagonal ("k-hot").
model_khot = _new_fig5_model()
nn.init.constant_(model_khot.gate, 0)
with torch.no_grad():
    model_khot.gate.fill_diagonal_(1)
print(model_khot.gate)
loss_khot = other_optimize(model_khot, **_fig5_train_kwargs)

In [None]:
def render_expert_specialization_multi(models_dict, resolution=100, save_path=None, losses_dict=None):
    """
    Plot, for each model, which expert its gate routes every point of [0, 1]^2 to.

    A ``resolution x resolution`` grid over two features is scored against
    ``model.gate`` (``einsum("bf,ef->be")``, so the gate must have 2 feature
    columns). Every grid point is then colored by its top-1 expert. There is one
    subplot per model, with at most 3 per row. Matplotlib is switched to LaTeX
    text rendering through a global ``rcParams`` update, so a working LaTeX
    installation is required.

    Args:
        models_dict: Dictionary with {num_experts: model} mappings or {label: model}.
            Subplots are drawn in sorted-key order, so keys must be mutually comparable.
        resolution: Number of points per dimension for the grid.
        save_path: If provided, saves to {save_path}.pdf and {save_path}.pgf.
        losses_dict: Optional dictionary with {label: loss_value} to display below each plot.

    Returns:
        None. The figure is shown with ``plt.show()`` and saved if ``save_path`` is set.

    Raises:
        ValueError: If ``models_dict`` is empty.
    """
    import torch.nn.functional as F
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap, BoundaryNorm

    if not models_dict:
        raise ValueError("models_dict must contain at least one model")

    # Set up matplotlib for LaTeX
    plt.rcParams.update({
        'text.usetex': True,
        'font.family': 'serif',
        'font.serif': ['Computer Modern'],
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'legend.fontsize': 12,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12
    })

    # Grid of feature combinations, flattened to a (resolution**2, 2) batch.
    x = np.linspace(0, 1, resolution)
    y = np.linspace(0, 1, resolution)
    X, Y = np.meshgrid(x, y)
    feature_grid = np.stack([X.flatten(), Y.flatten()], axis=1)
    feature_tensor = torch.tensor(feature_grid, dtype=torch.float32)

    n_models = len(models_dict)
    n_cols = min(3, n_models)  # Max 3 columns
    n_rows = (n_models + n_cols - 1) // n_cols

    # squeeze=False keeps `axes` 2-D for every grid shape, so one indexing scheme works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)

    for plot_idx, (label, model) in enumerate(sorted(models_dict.items())):
        ax = axes[plot_idx // n_cols, plot_idx % n_cols]

        num_experts = int(getattr(getattr(model, "config", None), "n_experts", model.gate.shape[0]))

        # Top-1 expert for every grid point. The input goes to the gate's device, and
        # the result comes back to the CPU so .numpy() works for non-CPU models.
        with torch.no_grad():
            gate = model.gate
            gate_scores = torch.einsum("bf,ef->be", feature_tensor.to(gate.device), gate)
            gate_probs = F.softmax(gate_scores, dim=-1)
            _, top_expert_indices = torch.topk(gate_probs, k=1, dim=-1)
            expert_assignments = top_expert_indices.squeeze(-1).cpu()

        expert_grid = expert_assignments.numpy().reshape(resolution, resolution)

        # One discrete color per expert. tab10 is cycled so BoundaryNorm always gets
        # at least as many colors as bins, even with more than 10 experts.
        colors = [plt.cm.tab10(i % 10) for i in range(num_experts)]
        cmap = ListedColormap(colors)
        boundaries = np.arange(-0.5, num_experts + 0.5, 1)
        norm = BoundaryNorm(boundaries, cmap.N)

        im = ax.imshow(expert_grid, extent=[0, 1, 0, 1], origin='lower',
                       cmap=cmap, norm=norm, interpolation='nearest')

        ax.set_xlabel('Feature 1', fontsize=12)
        ax.set_ylabel('Feature 2', fontsize=12)

        cbar = fig.colorbar(im, ax=ax, ticks=range(num_experts), boundaries=boundaries)
        cbar.set_ticklabels([f'Expert {i}' for i in range(num_experts)])

        ax.grid(True, alpha=0.3)

        # Loss annotation below this subplot, if provided.
        if losses_dict is not None and label in losses_dict:
            loss = losses_dict[label]
            ax.text(0.5, -0.25, f'\\textbf{{Loss: {loss:.3f}}}',
                    transform=ax.transAxes, ha='center', va='top',
                    fontsize=16, color='black')

    # Hide any unused subplots in the last row.
    for idx in range(n_models, n_rows * n_cols):
        axes[idx // n_cols, idx % n_cols].set_visible(False)

    # Leave a little extra room at the bottom for the loss annotations.
    if losses_dict is not None:
        fig.tight_layout(rect=[0, 0.02, 1, 0.98])
    else:
        fig.tight_layout()

    if save_path:
        fig.savefig(f"{save_path}.pdf", bbox_inches='tight', dpi=300)
        fig.savefig(f"{save_path}.pgf", bbox_inches='tight')
        print(f"Saved to {save_path}.pdf and {save_path}.pgf")

    plt.show()

In [None]:
# The renderer draws panels in sorted-key order: zero-init, Xavier-init, k-hot-init.
fig5_models = {"1": model_zero, "2": model_base, "3": model_khot}
losses = {label: value for label, value in (("1", loss_zero), ("2", loss_base), ("3", loss_khot))}
render_expert_specialization_multi(fig5_models, losses_dict=losses)

### Reproducing Figure 6


In [None]:
from model.model import *
import matplotlib.pyplot as plt
import torch
import numpy

# Figure 6 runs on the CPU with both RNGs reseeded.
DEVICE = "cpu"
torch.manual_seed(41)
np.random.seed(41)

In [None]:
config_fig6 = Config(
    n_features = 20,
    n_hidden = 5,
    n_experts = 4,
    n_active_experts = 1,
    load_balancing_loss = False,
)

# 6a and 6b: importance decays geometrically with feature index (0.7**i).
model_6a = MoEModel(
    config=config_fig6,
    device=DEVICE,
    importance = 0.7**torch.from_numpy(np.arange(config_fig6.n_features)),
    feature_probability = torch.tensor(0.09)
)

model_6b = MoEModel(
    config=config_fig6,
    device=DEVICE,
    importance = 0.7**torch.from_numpy(np.arange(config_fig6.n_features)),
    feature_probability = torch.tensor(0.1)
)

# 6c: the same importance values, assigned to features in a random order.
model_6c = MoEModel(
    config=config_fig6,
    device=DEVICE,
    importance = 0.7**torch.from_numpy(np.random.choice(config_fig6.n_features, config_fig6.n_features, replace=False)),
    feature_probability = torch.tensor(0.1)
)


# 6a: "diagonal" one-hot gate; expert i starts on feature i only.
nn.init.constant_(model_6a.gate, 0)
with torch.no_grad():
    model_6a.gate.fill_diagonal_(1)

# 6b: deterministic contiguous blocks (not random). Expert i starts on features
# [i*k, (i+1)*k) with k = n_features // n_experts (5 here); gate is (n_experts, n_features).
indices = np.arange(model_6b.gate.shape[1])
indices = torch.from_numpy(indices.reshape(model_6b.gate.shape[0], -1))
print(indices)
print(model_6b.importance)
nn.init.constant_(model_6b.gate, 0)
with torch.no_grad():
    for i in range(model_6b.gate.shape[0]):
        model_6b.gate[i, indices[i]] = 1

# 6c: random k-hot. Each expert starts on a random, disjoint set of
# n_features // n_experts features.
features_per_expert = model_6c.gate.shape[1] // model_6c.gate.shape[0]
indices = np.random.choice(model_6c.gate.shape[1], size=model_6c.gate.shape[0] * features_per_expert, replace=False)
indices = torch.from_numpy(indices.reshape(model_6c.gate.shape[0], -1))
nn.init.constant_(model_6c.gate, 0)
with torch.no_grad():
    for i in range(model_6c.gate.shape[0]):
        model_6c.gate[i, indices[i]] = 1


In [None]:
# Train the three Figure 6 models with identical optimizer settings.
print("Training first model...")
optimize(model_6a, n_batch=1024, steps=10000, print_freq=1000)
print("Training second model...")
optimize(model_6b, n_batch=1024, steps=10000, print_freq=1000)
print("Training last model...")
optimize(model_6c, n_batch=1024, steps=10000, print_freq=1000)

In [None]:
def render_features_bar(model, save_path=None):
    """Render per-expert bar plots of feature norms, colored by superposition.

    For expert ``e`` and feature ``f``, the bar height is ``||W_experts[e, f]||``. The
    color is the norm of row ``f`` of the interference matrix
    ``W_norm[e] @ W_experts[e].T`` (diagonal zeroed), scaled by that expert's maximum.
    Matplotlib is switched to LaTeX rendering through a global ``rcParams`` update.

    Args:
        model: Model with ``config`` (``n_features``, ``n_hidden``, ``n_experts``) and a
            3-D ``W_experts`` tensor indexed ``[expert, feature, :]``.
        save_path: If given, saves ``{save_path}_bar.pdf`` and ``{save_path}_bar.pgf``.
    """
    cfg = model.config
    W_exp = model.W_experts.detach()
    W_norm = W_exp / (1e-5 + torch.linalg.norm(W_exp, 2, dim=2, keepdim=True))

    interference = torch.einsum('ifh,igh->ifg', W_norm, W_exp)  # (n_experts, n_features, n_features)
    diag = torch.arange(cfg.n_features)
    interference[:, diag, diag] = 0  # ignore each feature's overlap with itself

    polysemanticity = torch.linalg.norm(interference, dim=-1).cpu()
    norms = torch.linalg.norm(W_exp, 2, dim=-1).cpu()

    x = torch.arange(cfg.n_features)

    # set up matplotlib for LaTeX
    plt.rcParams.update({
        'text.usetex': True,
        'font.family': 'serif',
        'font.serif': ['Computer Modern'],
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'legend.fontsize': 12,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12
    })

    # squeeze=False keeps `axes` 2-D even for a single expert.
    fig, axes = plt.subplots(cfg.n_experts, 1, figsize=(6, 4 * cfg.n_experts), squeeze=False)

    for expert_idx in range(cfg.n_experts):
        ax = axes[expert_idx, 0]
        poly = polysemanticity[expert_idx]

        # Scale colors by the expert's max. Fall back to 1 when every value is 0,
        # which would otherwise give NaN colors and a degenerate colorbar.
        poly_max = poly.max().item()
        color_scale = poly_max if poly_max > 0 else 1.0

        ax.bar(x, norms[expert_idx],
               color=plt.cm.viridis((poly / color_scale).numpy()),
               width=0.6)

        sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=0, vmax=color_scale))
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax)
        cbar.set_label('Superposition')

        # Vertical line at the n_hidden boundary.
        ax.axvline(x=(cfg.n_hidden - 0.5), color='red', linestyle='--', alpha=0.7, linewidth=1)

        ax.set_title(f'Expert {expert_idx}')
        ax.grid(True, alpha=0.3)

        # Ticks every 5 features (0, 5, 10, 15 for 20 features).
        ax.set_xticks(list(range(0, cfg.n_features, 5)))
        ax.set_xlim(-0.5, cfg.n_features - 0.5)

    fig.tight_layout()

    if save_path:
        fig.savefig(f"{save_path}_bar.pdf", bbox_inches='tight')
        fig.savefig(f"{save_path}_bar.pgf", bbox_inches='tight')

    plt.show()

In [None]:
render_features_bar(model=model_6a)  # model 6a: diagonal gate initialization

In [None]:
render_features_bar(model=model_6b)  # model 6b

In [None]:
render_features_bar(model=model_6c)
# When cells run in order, `indices` still holds the random k-hot feature sets
# used to initialize model_6c's gate.
print(indices)
print(model_6c.importance)

### Reproducing Table 1

In [None]:
from model.model import optimize_vectorized
from helpers.helpers import my_gate_init, save_initial_weights
from functools import partial
import os
cfg = Config(
    n_features = 100,
    n_hidden = 10,
    n_experts = 10,
    n_active_experts = 1,
    load_balancing_loss = False,
)
N = 200  # number of models trained together
# torch.backends.mps.is_available() is the documented, version-stable MPS check.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

feature_prob = torch.tensor(0.1).to(DEVICE)
# One random permutation of exponents 0..n_features-1, shared by every model.
importance = 0.9**torch.from_numpy(np.random.choice(cfg.n_features, cfg.n_features, replace=False)).to(DEVICE)

feature_probs = [feature_prob for _ in range(N)]
configs = [cfg for _ in range(N)]
importances = [importance for _ in range(N)]

os.makedirs("models", exist_ok=True)

# Train all N models in one vectorized run.
# NOTE(review): later cells treat models [0, N/2) as Xavier-initialized and
# [N/2, N) as k-hot; presumably my_gate_init produces that split — confirm.
final_losses, stacked_params = optimize_vectorized(
    configs,
    feature_probs,
    importances,
    device=DEVICE,
    n_batch=1024,
    steps=10000,
    lr=1e-3,
    print_freq=1000,
    init_fn=my_gate_init,
    on_initialized=partial(save_initial_weights, save_path="models/hundred_feats_vectorized_init")
)
print("Per-model losses:", final_losses.tolist())

print(stacked_params.keys())
gate = stacked_params["gate"]
print(gate.shape)
# Move every tensor to the CPU before saving so the snapshot also loads on
# machines without the training device (feature_probs used to stay on DEVICE).
snapshot = {
    "configs": configs,
    "feature_probs": [torch.as_tensor(fp).detach().cpu() for fp in feature_probs],
    "importances": [imp.detach().cpu() for imp in importances],
    "stacked_params": {k: v.detach().cpu().clone() for k, v in stacked_params.items()},
}
torch.save(snapshot, "models/post_training_vectorized_weights.pt")

In [None]:
# weights_only=False because the snapshot also pickles the Config objects.
# Only load snapshots you produced yourself.
vectorized_models = torch.load("models/post_training_vectorized_weights.pt", weights_only=False)
stacked_params = vectorized_models['stacked_params']
W_exp = stacked_params['W_experts']  # 4-D: (n_models, n_experts, n_features, last dim contracted below)
n_models, n_experts, n_features = W_exp.shape[:3]
W_norm = W_exp / (1e-5 + torch.linalg.norm(W_exp, 2, dim=3, keepdim=True))

interference = torch.einsum('...ifh,...igh->...ifg', W_norm, W_exp)  # (n_models, n_experts, n_features, n_features)
feature_ids = torch.arange(n_features)
interference[:, :, feature_ids, feature_ids] = 0  # ignore each feature's overlap with itself
polysemanticity = torch.linalg.norm(interference, dim=-1).cpu()

# A feature counts as monosemantic in an expert when its total interference is below 0.2.
mask = polysemanticity < 0.2
polysemanticity[~mask] = float('inf')  # push non-monosemantic features past the ascending topk
counts = mask.sum(dim=2)  # (n_models, n_experts): monosemantic features per expert


def _monosemantic_features(model_idx):
    """Return [(expert_id, feature_indices), ...] for experts of `model_idx` with >=1 monosemantic feature.

    ``feature_indices`` are sorted by increasing interference.
    """
    expert_list = []
    for expert_id in range(n_experts):
        k = int(counts[model_idx][expert_id])
        if k > 0:
            expert_list.append((expert_id, polysemanticity[model_idx][expert_id].topk(k, largest=False)[1]))
    return expert_list


# The first half of the models is the Xavier-init group; the second half is k-hot.
n_xavier = n_models // 2

xavier_mono_feats = []
for i in range(n_xavier):
    expert_list = _monosemantic_features(i)
    print(expert_list)
    xavier_mono_feats.append(expert_list)
print("Moving from xavier init to khot init")
khot_mono_feats = []
for i in range(n_xavier, n_models):
    expert_list = _monosemantic_features(i)
    print(expert_list)
    khot_mono_feats.append(expert_list)

print(xavier_mono_feats)
print(khot_mono_feats)

In [None]:
# Handle on the unmodified generator, so the wrappers below can call it directly.
batch_gen_func = MoEModel.generate_batch


def dataset_conditional_batch_generation(model, info, n_batch=1024, clamp_value: Optional[float] = None):
    """Sample a batch from the model's data distribution, then overwrite selected features.

    Args:
        model: Model whose ``generate_batch`` produces the base batch.
        info: ``(expert_id, feature_indices)`` pair; ``feature_indices`` selects the
            columns to overwrite. ``expert_id`` is only used in the log line.
        n_batch: Number of rows to generate.
        clamp_value: If given, every selected feature is set to this value. Otherwise
            each row gets one value drawn from U[0, 1), shared by that row's selected
            features.

    Returns:
        The modified batch tensor.
    """
    expert_id, feature_indices = info
    print(f'Expert {expert_id} monosemantically represents features: {feature_indices}')
    batch = batch_gen_func(model, n_batch)
    if clamp_value is not None:
        batch[:, feature_indices] = clamp_value
    else:
        # One U[0, 1) draw per row, broadcast across the selected columns. It is
        # created on the batch's device/dtype so the assignment also works off-CPU.
        clamps = torch.rand(batch.shape[0], 1, device=batch.device, dtype=batch.dtype)
        batch[:, feature_indices] = clamps
    return batch

def batch_only_feature_active_gen(model, info, n_batch=1024, clamp_value: Optional[float] = None):
    """Build a CPU batch in which only the selected features are non-zero.

    Args:
        model: Provides ``model.config.n_features`` (the batch width).
        info: ``(expert_id, feature_indices)`` pair; ``feature_indices`` selects the
            columns to activate. ``expert_id`` is only used in the log line.
        n_batch: Number of rows.
        clamp_value: If given, the selected features are set to this value, matching
            ``dataset_conditional_batch_generation``. Previously this argument was
            silently ignored. Otherwise each row gets one U[0, 1) value shared by its
            selected features.

    Returns:
        A ``(n_batch, n_features)`` tensor; every non-selected feature is 0.
    """
    expert_id, feature_indices = info
    print(f'Expert {expert_id} monosemantically represents features: {feature_indices}')
    batch = torch.zeros(n_batch, model.config.n_features)
    if clamp_value is not None:
        batch[:, feature_indices] = clamp_value
    else:
        clamps = torch.rand(batch.shape[0], 1)
        batch[:, feature_indices] = clamps
    return batch


In [None]:
# Selects the batch generator used by the Table 1 usage analysis. It must be a key
# of `batch_gen_functions`: "unchanged", "active (clamped to 1)", or
# "only feature active (all other features=0)".
COLUMN_CHOICE = "active (clamped to 1)"

In [None]:
gates = stacked_params['gate']
# Keys are the number of monosemantic features an expert has; values are that
# expert's routing share. Extra keys are added on demand below.
xav_mean_usage_counts = {1: [], 2: [], 3: [], 4: [],}
khot_mean_usage_counts = {1: [], 2: [], 3: [], 4: [], 5: [],}
xav_medians = []
khot_medians = []
model = MoEModel(config=vectorized_models['configs'][0], device="cpu") # simply used for generating the input batch

batch_gen_functions = {
    "unchanged": batch_gen_func,
    "active (clamped to 1)": dataset_conditional_batch_generation,
    "only feature active (all other features=0)": batch_only_feature_active_gen
}


def _expert_usage(gate, mono_feat_info):
    """Return each expert's share of top-1 routing decisions on a batch from COLUMN_CHOICE's generator."""
    func = batch_gen_functions[COLUMN_CHOICE]
    if COLUMN_CHOICE == "unchanged":
        batch = func(model, n_batch=8192)
    else:
        batch = func(model, mono_feat_info, n_batch=8192, clamp_value=1.0)
    gate_scores = torch.einsum("...f,ef->...e", batch, gate)
    gate_probs = F.softmax(gate_scores, dim=-1)
    _, top_idx = gate_probs.topk(1, dim=-1)
    # minlength keeps one slot per expert even if some expert is never chosen.
    counts = top_idx.flatten().bincount(minlength=model.config.n_experts)
    return counts / counts.sum()


def _collect_usage(mono_feats, gate_offset, medians, usage_counts):
    """Append per-expert median usage and own-expert usage for every (model, expert) in `mono_feats`."""
    for model_idx, expert_infos in enumerate(mono_feats):
        gate = gates[model_idx + gate_offset]
        for info in expert_infos:
            usage = _expert_usage(gate, info)
            medians.append(torch.median(usage).item())
            exp_num, feature_idx = info
            # setdefault avoids a KeyError for experts with more features than the preset keys.
            usage_counts.setdefault(feature_idx.shape[0], []).append(usage[exp_num].item())


# Xavier models come first in `gates`, followed by the k-hot models.
_collect_usage(xavier_mono_feats, 0, xav_medians, xav_mean_usage_counts)
print("ENDING XAV, STARTING KHOT")
_collect_usage(khot_mono_feats, len(xavier_mono_feats), khot_medians, khot_mean_usage_counts)

In [None]:
gates = stacked_params['gate']


def _mean_or_zero(values):
    """Arithmetic mean of ``values``; 0.0 for an empty list instead of raising ZeroDivisionError."""
    return sum(values) / len(values) if values else 0.0


def _report_expert_usage(usage_counts, blank_line_after_counts):
    """Print expert counts and mean usage for 1-4 monosemantic features, and return the 4 means.

    ``usage_counts`` maps a monosemantic-feature count to the list of usage shares
    collected for experts with that many monosemantic features.
    """
    for n_feats in range(1, 5):
        suffix = "\n" if blank_line_after_counts and n_feats == 4 else ""
        print(f"Number of experts with {n_feats} monosemantic feature: {len(usage_counts.get(n_feats, []))}{suffix}")

    means = [_mean_or_zero(usage_counts.get(n_feats, [])) for n_feats in range(1, 5)]
    for n_feats, mean_usage in zip(range(1, 5), means):
        prefix = "\n" if n_feats == 1 else ""
        print(f"{prefix}Average usage of experts w/ {n_feats} monosemantic features, when that feature(s) is {COLUMN_CHOICE}:\n {mean_usage}")
    return means


print("Xavier Initialization Scheme\n")
mean_usage_one, mean_usage_two, mean_usage_three, mean_usage_four = _report_expert_usage(xav_mean_usage_counts, True)

xav_mean_med = _mean_or_zero(xav_medians)
xav_mean_std = torch.std(torch.as_tensor(xav_medians)).item()

print("\nK-Hot Initialization Scheme\n")
mean_usage_one, mean_usage_two, mean_usage_three, mean_usage_four = _report_expert_usage(khot_mean_usage_counts, False)
mean_usage_five = _mean_or_zero(khot_mean_usage_counts.get(5, []))

khot_mean_med = _mean_or_zero(khot_medians)
khot_mean_std = torch.std(torch.as_tensor(khot_medians)).item()