In [None]:
!pip install torch datasets numpy matplotlib seaborn
!pip uninstall -y datasets fsspec huggingface_hub transformers tokenizers
!rm -rf ~/.cache/huggingface/datasets
!pip install datasets==2.14.7 fsspec==2023.10.0 huggingface_hub==0.17.3 transformers==4.35.2 tokenizers==0.15.0
!pip install torch-geometric

In [None]:
#print(torch.__version__)
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+cu124.html

<h1>Hyperparameter Configuration</h1>

In [None]:
# --- Cell 1: Hyperparameter Configuration ---
import copy
import json
import shlex
import subprocess
import sys

class TrainingConfig:
    """Default hyperparameters for one run of run_gnn_moe.py.

    Every attribute assigned in ``__init__`` is forwarded by ``to_cli_args()`` as
    ``--<attribute name>``, so each attribute name is used verbatim as the option
    name. Assignment order determines the order of the emitted arguments.
    """

    def __init__(self):
        # General settings and model architecture.
        self.run_name = "notebook_training_run"
        self.embed_dim = 256
        self.num_layers = 4
        self.num_experts = 4
        self.vocab_size = 50257  # GPT-2 tokenizer vocabulary size
        self.max_seq_length = 1024
        self.dropout_rate = 0.1

        # Expert coupler (graph / hypergraph neural network).
        self.coupler_type = "HGNN"  # one of "GNN", "HGNN"
        self.gnn_layers = 2
        self.hgnn_conv_type = "HypergraphConv"
        self.static_hyperedge_strategy = "all_pairs"  # one of "all_pairs", "all_triplets"
        self.hgnn_learnable_edge_weights = True

        # Optimisation.
        self.epochs = 5
        self.batch_size = 16
        self.learning_rate = 3e-4
        self.lr_scheduler_type = "cosine"  # one of "cosine", "linear", "step"
        self.warmup_steps = 500
        self.weight_decay = 0.01
        self.grad_clip_value = 1.0

        # Data.
        self.dataset_name = "wikitext"
        self.dataset_config_name = "wikitext-2-v1"
        self.num_train_samples = 10000
        self.num_eval_samples = 1000

        # Adaptive orthogonality (Phase 2.2).
        self.adaptive_weight_orthogonality = True
        self.initial_weight_orthogonality_strength = 0.1
        self.target_specialization_score = 0.95
        self.adaptive_decay_schedule = "cosine"

        # Static orthogonality (Phase 2.1); superseded when adaptive orthogonality is on.
        self.apply_weight_orthogonality_loss = False
        self.weight_orthogonality_loss_weight = 0.05

        # Logging and checkpointing (intervals are in optimisation steps).
        self.eval_every = 250
        self.log_every = 50
        self.checkpoint_dir = "checkpoints_notebook"

    def to_cli_args(self):
        """Serialise the config into argv tokens for run_gnn_moe.py.

        Each attribute is visited in insertion order:
          * ``True``        -> ``--name`` (bare flag)
          * ``False``       -> omitted
          * ``None``        -> omitted
          * anything else   -> ``--name`` followed by ``str(value)``

        NOTE: since ``False`` is just omitted, setting a flag to ``False`` here has
        no effect if the script itself defaults that flag to on.

        Returns:
            list[str]: tokens suitable for appending to a subprocess command.
        """
        tokens = []
        for name, setting in vars(self).items():
            flag = f"--{name}"
            if setting is True:
                tokens.append(flag)
            elif isinstance(setting, bool) or setting is None:
                continue
            else:
                tokens += [flag, str(setting)]
        return tokens

# Instantiate the configuration that every cell below reads from.
config = TrainingConfig()

# Override individual settings here before running the training cell, e.g.:
#   config.run_name = "my_special_experiment"
#   config.num_experts = 8

print("✅ TrainingConfig created. Modify the 'config' object to customize your run.")
print("▶️ Run Name: {}".format(config.run_name))


<h1>Execute Training Run</h1>

In [None]:
# --- Cell 2: Execute Training Run ---
import subprocess
import sys

# Convert the config object to CLI arguments
cli_args = config.to_cli_args()

# Training entry point, resolved relative to the notebook's working directory.
TRAIN_SCRIPT = "orthogon/adaptive-orthogonal/run_gnn_moe.py"
command = [sys.executable, TRAIN_SCRIPT] + cli_args


def _format_cli_lines(args):
    """Group argv tokens into one display line per option.

    A token starting with ``--`` opens a new option, and following tokens that do
    not start with ``--`` are attached to it as values. Value-less boolean flags
    (e.g. ``--hgnn_learnable_edge_weights``) therefore get their own line instead
    of being paired with the next option. Tokens are shell-quoted so the printed
    command can be pasted into a terminal.

    Args:
        args: list of argv tokens, e.g. from ``TrainingConfig.to_cli_args()``.

    Returns:
        list[str]: one display string per option.
    """
    groups = []
    for token in args:
        if token.startswith("--") or not groups:
            groups.append([token])
        else:
            groups[-1].append(token)
    return [" ".join(shlex.quote(token) for token in group) for group in groups]


print("🚀 Starting training run with the following command:")
# Readable version: a backslash continuation on every line except the last.
readable_lines = [f"python {TRAIN_SCRIPT}"] + [f"  {line}" for line in _format_cli_lines(cli_args)]
print(" \\\n".join(readable_lines))
print("-" * 50)

# Execute the command, streaming merged stdout/stderr into the notebook as it arrives.
with subprocess.Popen(
    command,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
    bufsize=1,  # line buffering (valid only in text mode)
) as process:
    try:
        for line in process.stdout:
            # Keep the child's indentation; the line already ends with its newline.
            print(line, end="", flush=True)
    except KeyboardInterrupt:
        # Interrupting the cell is not guaranteed to stop the child process, so
        # terminate it explicitly rather than leaving training running in the background.
        process.terminate()
        raise

# Leaving the `with` block waits for the child, so its exit code is set here.
rc = process.returncode
if rc == 0:
    print("\n✅ Training run completed successfully!")
else:
    print(f"\n❌ Training run failed with exit code {rc}.")



<h1>Hyperparameter Sweep Setup</h1>

In [None]:
# --- Cell 3: Hyperparameter Sweep Setup ---
import itertools
import pandas as pd

class HyperparameterSweep:
    """Grid search over config attributes, with one run_gnn_moe.py subprocess per combination.

    Each combination is applied to a shallow copy of ``base_config``, which is never
    modified. The copy gets a descriptive ``run_name`` and its run completes before
    the next one starts. Only the outcome of each run (status and exit code) is
    recorded; no metrics are parsed from the script's output.
    """

    # Training entry point, resolved relative to the notebook's working directory.
    script_path = "orthogon/adaptive-orthogonal/run_gnn_moe.py"

    def __init__(self, base_config):
        """Create an empty sweep.

        Args:
            base_config: config object (e.g. a TrainingConfig) that provides
                ``to_cli_args()``. It is copied for each run and never mutated.
        """
        self.base_config = base_config
        self.sweep_params = {}

    def add_sweep(self, param_name, values):
        """Add a parameter to sweep over.

        Calling this again with the same ``param_name`` replaces the earlier values.
        """
        self.sweep_params[param_name] = values
        print(f"Added sweep for '{param_name}' with values: {values}")

    @staticmethod
    def _run_name(overrides):
        """Build ``sweep_<key>-<value>_...``, with underscores in keys replaced by hyphens."""
        parts = [f"{key.replace('_', '-')}-{value}" for key, value in overrides.items()]
        return f"sweep_{'_'.join(parts)}"

    def _make_config(self, overrides):
        """Return a shallow copy of ``base_config`` with ``overrides`` applied and a sweep run name."""
        run_config = copy.copy(self.base_config)
        for key, value in overrides.items():
            setattr(run_config, key, value)
        run_config.run_name = self._run_name(overrides)
        return run_config

    def run_sweep(self):
        """Run every combination in the Cartesian product of the swept values.

        Returns:
            pandas.DataFrame: one row per combination. Columns are the swept values,
            ``status`` ('Success' or 'Failed'), ``run_name`` and ``returncode``.
        """
        keys = list(self.sweep_params)
        param_combinations = list(itertools.product(*self.sweep_params.values()))

        print(f"\n🔬 Starting hyperparameter sweep with {len(param_combinations)} combinations.")

        results = []
        for i, combo in enumerate(param_combinations):
            print(f"\n--- Running Combination {i+1}/{len(param_combinations)} ---")

            overrides = dict(zip(keys, combo))
            sweep_config = self._make_config(overrides)
            print(f"Run Name: {sweep_config.run_name}")

            command = [sys.executable, self.script_path] + sweep_config.to_cli_args()
            process = subprocess.run(command, capture_output=True, text=True)

            if process.returncode == 0:
                print("✅ Run completed successfully.")
                status = 'Success'
            else:
                print("❌ Run failed.")
                # Errors may be reported on either stream; fall back to stdout when stderr is empty.
                print(process.stderr or process.stdout)
                status = 'Failed'

            # Only the outcome is recorded. Extend this to parse metrics from process.stdout if needed.
            results.append({
                **overrides,
                'status': status,
                'run_name': sweep_config.run_name,
                'returncode': process.returncode,
            })

        results_df = pd.DataFrame(results)
        print("\n--- Sweep Results ---")
        print(results_df)
        return results_df

# --- Example Sweep ---
# Grid of values to try; every combination becomes one training run.
example_sweep_grid = {
    'learning_rate': [1e-4, 3e-4, 5e-4],
    'num_experts': [2, 4],
}

sweep = HyperparameterSweep(base_config=config)
for swept_param, swept_values in example_sweep_grid.items():
    sweep.add_sweep(swept_param, swept_values)

# Launching is opt-in because it trains one model per combination (3 x 2 = 6 runs here).
# Uncomment to start:
# sweep_results_df = sweep.run_sweep()


<h1>Demonstration Commands</h1>

In [None]:
# --- Cell 4: Demonstration Commands ---

def run_demo_command(demo_name, command_list):
    """Print and run one demo invocation of run_gnn_moe.py, then report the outcome.

    Args:
        demo_name: label shown in the header line.
        command_list: argv tokens for the script, already prefixed with ``--``
            (e.g. the output of ``dict_to_cli_args``). Any sequence of strings works.

    The script's output is captured, not streamed. On failure, stderr is printed,
    or stdout if stderr is empty.
    """
    print(f"--- Running Demo: {demo_name} ---")

    script_path = "orthogon/adaptive-orthogonal/run_gnn_moe.py"
    command = [sys.executable, script_path] + list(command_list)

    # Group each "--option" with the value tokens that follow it. Boolean flags then
    # sit on their own line, and tokens are printed as-is (no extra "--" prefix).
    option_lines = []
    for token in command_list:
        quoted = shlex.quote(token)
        if token.startswith("--") or not option_lines:
            option_lines.append(quoted)
        else:
            option_lines[-1] += f" {quoted}"
    print(" \\\n".join([f"python {script_path}"] + [f"  {line}" for line in option_lines]))
    print("-" * 50)

    # Execute the command
    process = subprocess.run(command, capture_output=True, text=True)

    if process.returncode == 0:
        print("✅ Demo run completed successfully.")
    else:
        print("❌ Demo run failed.")
        print(process.stderr or process.stdout)

# --- Demo Configurations ---
# Each dict maps a run_gnn_moe.py option name (without the leading "--") to its value.
# Convert one into argv tokens with dict_to_cli_args() and pass the result to
# run_demo_command(); the ready-made calls are at the end of this cell.
# NOTE(review): `max_batches_per_epoch` does not appear in the argument reference at the
# end of this notebook. Confirm run_gnn_moe.py accepts it before relying on these demos.

# 1. Standard GNN-MoE
demo_gnn_moe = {
    "run_name": "demo_gnn_moe",
    "coupler_type": "GNN",
    "num_experts": 4,
    "epochs": 1,
    "max_batches_per_epoch": 100,  # cap batches for a quick run
}

# 2. HGNN-MoE (Static Orthogonality)
demo_hgnn_moe = {
    "run_name": "demo_hgnn_moe_static_ortho",
    "coupler_type": "HGNN",
    "num_experts": 4,
    "apply_weight_orthogonality_loss": True,
    "weight_orthogonality_loss_weight": 0.05,
    "epochs": 1,
    "max_batches_per_epoch": 100,
}

# 3. HGNN-MoE (Adaptive Orthogonality)
demo_adaptive_hgnn_moe = {
    "run_name": "demo_hgnn_moe_adaptive_ortho",
    "coupler_type": "HGNN",
    "num_experts": 4,
    "adaptive_weight_orthogonality": True,
    "initial_weight_orthogonality_strength": 0.1,
    "epochs": 1,
    "max_batches_per_epoch": 100,
}

# Convert a demo dict into argv tokens (same rules as TrainingConfig.to_cli_args).
def dict_to_cli_args(d):
    """Turn ``{option: value}`` into a flat list of ``--option`` / value tokens.

    ``True`` becomes a bare ``--option`` flag, while ``False`` and ``None`` are
    dropped. Any other value becomes ``--option`` followed by ``str(value)``.
    Dict order is preserved.
    """
    def _tokens_for(name, setting):
        if isinstance(setting, bool):
            return [f"--{name}"] if setting else []
        if setting is None:
            return []
        return [f"--{name}", str(setting)]

    return [token for name, setting in d.items() for token in _tokens_for(name, setting)]

# To run a demo, uncomment one of the lines below
# run_demo_command("GNN-MoE", dict_to_cli_args(demo_gnn_moe))
# run_demo_command("HGNN-MoE (Static Ortho)", dict_to_cli_args(demo_hgnn_moe))
# run_demo_command("HGNN-MoE (Adaptive Ortho)", dict_to_cli_args(demo_adaptive_hgnn_moe))



<h1>Architectural Defaults</h1>

---
## ⚙️ Command-Line Argument Reference

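This section provides a reference for the command-line arguments available in the `run_gnn_moe.py` script. These arguments correspond to the fields in the `GNNMoEConfig` class. The **Default** column lists the script's own defaults. `TrainingConfig.to_cli_args()` passes every attribute that is not `None` or `False` explicitly, so the values set in this notebook's `TrainingConfig` take precedence over the defaults shown here (for example `embed_dim=256`, `num_layers=4`, `batch_size=16`).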

### **Model Architecture**

| Argument | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `--embed_dim` | int | 512 | The dimensionality of the token embeddings. |
| `--num_layers` | int | 8 | The number of transformer layers in the model. |
| `--num_heads` | int | 8 | The number of attention heads in each transformer layer. |
| `--num_experts` | int | 4 | The number of experts in each Mixture of Experts (MoE) layer. |
| `--ffn_dim_multiplier` | int | 4 | Multiplier for the feed-forward network dimension relative to `embed_dim`. |
| `--vocab_size` | int | 50257 | The size of the vocabulary (defaults to GPT-2's vocab size). |
| `--max_seq_length` | int | 1024 | The maximum sequence length the model can handle. |
| `--dropout_rate` | float | 0.1 | The dropout rate used in the model. |

### **GNN / HGNN Coupler**

| Argument | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `--coupler_type` | str | "GNN" | The type of expert coupler to use. Options: `"GNN"`, `"HGNN"`. |
| `--gnn_layers` | int | 1 | The number of layers in the GNN/HGNN coupler. |
| `--hgnn_conv_type` | str | "HypergraphConv" | The type of PyG convolution to use for the HGNN. |
| `--static_hyperedge_strategy` | str | "all_pairs" | The strategy for creating static hyperedges. Options: `"all_pairs"`, `"all_triplets"`. |
| `--hgnn_learnable_edge_weights`| bool | `False` | If set, the weights of the hyperedges will be learnable parameters. |

### **Training Parameters**

| Argument | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `--epochs` | int | 5 | The total number of training epochs. |
| `--batch_size` | int | 32 | The number of sequences in each training batch. |
| `--learning_rate` | float | 3e-4 | The initial learning rate for the AdamW optimizer. |
| `--lr_scheduler_type` | str | "cosine" | The learning rate scheduler type. Options: `"cosine"`, `"linear"`, `"step"`. |
| `--warmup_steps` | int | 1000 | The number of warmup steps for the learning rate scheduler. |
| `--weight_decay` | float | 0.01 | The weight decay to apply during optimization. |
| `--grad_clip_value` | float | 1.0 | The value to clip gradients at to prevent exploding gradients. |

### **Dataset & Data Handling**

| Argument | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `--dataset_name` | str | "wikitext" | The name of the Hugging Face dataset to use. |
| `--dataset_config_name` | str | "wikitext-2-v1"| The specific configuration of the dataset. |
| `--num_train_samples` | int | `None` | The number of training samples to use (if `None`, uses the full dataset). |
| `--num_eval_samples` | int | `None` | The number of evaluation samples to use (if `None`, uses the full dataset). |
| `--num_test_samples` | int | `None` | The number of test samples to use (if `None`, uses the full dataset). |

### **Static Weight Orthogonality (Phase 2.1)**

| Argument | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `--apply_weight_orthogonality_loss` | bool | `False` | If set, applies a static orthogonality loss to the expert weights. |
| `--weight_orthogonality_loss_weight` | float | 0.01 | The strength of the static orthogonality loss. |
| `--weight_orthogonality_normalization` | str | "frobenius" | The normalization method for the orthogonality loss. |

### **Adaptive Weight Orthogonality (Phase 2.2)**

| Argument | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `--adaptive_weight_orthogonality` | bool | `False` | If set, enables the adaptive orthogonality system. This overrides static settings. |
| `--initial_weight_orthogonality_strength`| float | 0.1 | The starting strength of the adaptive orthogonality constraint. |
| `--minimum_weight_orthogonality_strength`| float | 0.001 | The minimum strength the constraint can decay to. |
| `--maximum_weight_orthogonality_strength`| float | 0.3 | The maximum strength the constraint can be boosted to. |
| `--adaptive_decay_schedule` | str | "cosine" | The decay schedule for the constraint strength. Options: `"cosine"`, `"exponential"`, `"linear"`, `"step"`. |
| `--adaptation_frequency` | int | 500 | The number of training steps between adaptation checks. |
| `--target_specialization_score` | float | 0.95 | The target orthogonality score the controller aims for. |
| `--specialization_tolerance` | float | 0.02 | The tolerance band (±) around the target specialization score. |
| `--layer_specific_adaptation` | bool | `True` | Enabled by default; applies different constraint strengths to different layers. |
| `--deeper_layer_scaling` | float | 0.8 | The scaling factor for reducing constraint strength in deeper layers. |
| `--performance_aware_adaptation` | bool | `True` | Enabled by default; the controller considers model performance when adapting. |
| `--performance_monitor_window` | int | 100 | The number of steps to average performance over. |
| `--collapse_detection_threshold` | float | 0.1 | The specialization score threshold for detecting expert collapse. |
| `--emergency_constraint_boost` | bool | `True` | Enabled by default; activates the emergency boost mechanism to prevent collapse. |
| `--emergency_boost_multiplier` | float | 2.0 | The multiplier for the constraint strength during an emergency boost. |

### **Logging & System**

| Argument | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `--run_name` | str | `None` | A unique name for the training run. If `None`, a name is generated. |
| `--checkpoint_dir` | str | "./checkpoints" | The directory to save model checkpoints. |
| `--log_every` | int | 100 | The number of steps between logging training progress. |
| `--eval_every` | int | 500 | The number of steps between running evaluation. |
| `--save_best_only` | bool | `True` | Enabled by default; only the checkpoint with the best evaluation loss is saved. |
| `--seed` | int | 42 | The random seed for reproducibility. |

