In [None]:
import sys
import torch
from model.model import Config, MoEModel, optimize

In [35]:


# Set device
if torch.cuda.is_available():
  DEVICE = 'cuda'
else:
  DEVICE = 'cpu'

config = Config(
    n_features = 3,
    n_hidden = 2,
    n_experts = 2,
    n_active_experts = 1,
    load_balancing_loss = False,
)

# Configure importance and feature probability (sparsity)
model = MoEModel(
    config=config,
    device=DEVICE,
    importance = torch.tensor(0.5**torch.arange(3)),
    feature_probability = torch.tensor(0.1)
)

# Train the model
print("Training model...")
optimize(model, n_batch=512, steps=5000, print_freq=500, lr=1e-3)

print("Gate matrix:")
print(model.gate)
print("Expert weights:")
print(model.W_experts)

# Print final model parameters
print("\nFinal model parameters:")
print(f"Feature probability: {model.feature_probability.item()}")
print(f"Importance weights: {model.importance}")


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Training model...
Step 0: loss=0.106905, lr=0.001000
Step 500: loss=0.049430, lr=0.001000
Step 1000: loss=0.032918, lr=0.001000
Step 1500: loss=0.005400, lr=0.001000
Step 2000: loss=0.004574, lr=0.001000
Step 2500: loss=0.001364, lr=0.001000
Step 3000: loss=0.003468, lr=0.001000
Step 3500: loss=0.002156, lr=0.001000
Step 4000: loss=0.002542, lr=0.001000
Step 4500: loss=0.003625, lr=0.001000
Step 4999: loss=0.006719, lr=0.001000
Gate matrix:
Parameter containing:
tensor([[-0.2965, -0.4569,  0.5824],
        [-0.1849,  0.7613, -0.2996]], requires_grad=True)
Expert weights:
Parameter containing:
tensor([[[ 3.7119e-05, -4.1666e-05],
         [ 6.6592e-01,  7.4549e-01],
         [ 7.4576e-01, -6.6615e-01]],

        [[-1.0011e+00, -1.2473e-02],
         [-1.4240e-02,  1.0007e+00],
         [ 9.7353e-01,  1.8987e-02]]], requires_grad=True)

Final model parameters:
Feature probability: 0.10000000149011612
Importance weights: tensor([1.0000, 0.5000, 0.2500])
