# NeuralFlex-MoE Colab Notebook

This notebook allows you to run and test the NeuralFlex-MoE model directly in Google Colab or a local Jupyter environment.

In [None]:
# @title 1. Setup Environment
# Install required dependencies
!pip install torch transformers accelerate deepspeed bitsandbytes einops flash-attn xformers triton

In [None]:
# @title 2. Configure Path
import sys
import os

# Add the src directory to the python path so we can import neuraflex_moe.
# The path is resolved relative to the current working directory, so the
# notebook must be run from a folder that sits next to `src`.
# If running in Colab, you might need to clone the repo first
if 'google.colab' in sys.modules:
    # Assuming you have uploaded the code to Drive or cloned it
    # !git clone https://github.com/your-repo/NeuralFlex-MoE.git
    # %cd NeuralFlex-MoE
    pass

src_dir = os.path.abspath('../src')

# Surface a wrong working directory here rather than as an ImportError in the
# next cell, where the cause is much less obvious.
if not os.path.isdir(src_dir):
    print(f"Warning: {src_dir} does not exist; importing neuraflex_moe will fail "
          "until the repository is available at that location.")

# Add src to path (only once, so re-running this cell does not pile up
# duplicate entries in sys.path)
if src_dir not in sys.path:
    sys.path.append(src_dir)

print("Python path configured.")

In [None]:
# @title 3. Import Model and Config
# The neuraflex_moe imports below resolve only if the `src` directory was
# added to sys.path by the "Configure Path" cell (step 2).
import torch
# Model class instantiated in step 4.
from neuraflex_moe.models.neuraflex_moe import NeuralFlexMoE
# DEBUG_CONFIG is the configuration used in step 4. MODEL_CONFIG is imported
# but not referenced by the cells in this notebook — presumably the full-size
# configuration; confirm against neuraflex_moe.config before switching to it.
from neuraflex_moe.config import MODEL_CONFIG, DEBUG_CONFIG

print("Imports successful.")

In [None]:
# @title 4. Initialize Model
# Use DEBUG_CONFIG for a smaller model that fits easily in Colab free tier memory
config = DEBUG_CONFIG

print(f"Initializing {config['model_name']} with config:")
print(config)

# Seed PyTorch's random number generators before any weights are created so
# that a fresh "Restart & Run All" produces the same initial parameters and
# the same random test inputs in the later cells.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Build the model from the selected config and move its parameters to `device`.
model = NeuralFlexMoE(config)
model.to(device)

print("Model initialized successfully.")

In [None]:
# @title 5. Run Inference Test
# Create a dummy input: random token ids in [0, vocab_size), allocated directly
# on the target device instead of on the CPU followed by a copy.
batch_size = 2
seq_len = 16
input_ids = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)

# Gradients are not needed for a smoke test, so skip building the autograd graph.
print("Running forward pass...")
with torch.no_grad():
    outputs = model(input_ids)

# The model output is read as a mapping with 'logits' and 'aux_loss' entries.
logits = outputs['logits']
aux_loss = outputs['aux_loss']

print(f"Logits shape: {logits.shape}")
print(f"Aux loss: {aux_loss.item()}")

# Only report success after actually checking the output.
# Assumes logits are laid out as (batch, seq, ...), matching the simulated
# logits used in step 6 — adjust if the model returns a different layout.
assert tuple(logits.shape[:2]) == (batch_size, seq_len), (
    f"Unexpected logits shape {tuple(logits.shape)}; "
    f"expected leading dims {(batch_size, seq_len)}"
)
assert torch.isfinite(logits).all(), "Logits contain NaN or Inf values"
print("Test passed!")

In [None]:
# @title 6. Test Uncertainty-Aware Generation
# Relies on `config`, `device`, `model`, `batch_size`, `seq_len` and `input_ids`
# defined by the earlier cells; run them first.
from neuraflex_moe.core.uncertainty_aware_generation import UncertaintyAwareGeneration

uag = UncertaintyAwareGeneration(config)
uag.to(device)

# Simulate logits and hidden states with random values, allocated directly on
# the target device rather than created on the CPU and then copied over.
sim_logits = torch.randn(batch_size, seq_len, config['vocab_size'], device=device)
sim_hidden = torch.randn(batch_size, seq_len, config['base_hidden_size'], device=device)

# Inference-only check: disable autograd, consistent with the forward-pass
# test in step 5, so no computation graph is kept in memory.
with torch.no_grad():
    result = uag(model, input_ids, sim_logits, sim_hidden)

# `result` is read as a mapping with 'confidence', 'uncertainty_flag' and
# 'alternatives' entries.
print(f"Confidence: {result['confidence']:.4f}")
print(f"Uncertainty Flag: {result['uncertainty_flag']}")
if result['alternatives']:
    print(f"Generated {len(result['alternatives'])} alternatives.")