In [9]:
!nvidia-smi


Tue Jun 17 16:26:34 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   47C    P0             27W /   70W |     102MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [1]:
!pip install pynvml



In [2]:
from pynvml import *
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(0)
print("Power draw (W):", nvmlDeviceGetPowerUsage(handle) / 1000)

Power draw (W): 10.051


In [3]:
import threading
import time
import json
import logging
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict, deque
import numpy as np
import pynvml

import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity

PYNVML_AVAILABLE = True

@dataclass
class GPUMetrics:
  def __post_init__(self):
    if self.timestamp is None:
      self.timestamp = time.time()
  timestamp: float
  power_draw: float  # Watts
  temperature: float  # Celsius
  memory_used: int   # Bytes
  memory_total: int  # Bytes
  gpu_utilization: float  # Percentage
  memory_utilization: float  # Percentage}


@dataclass
class ExpertProfile:
    """Profile data for a specific expert."""
    expert_id: str
    flops: int
    memory_footprint: int
    avg_latency: float
    energy_cost: float  # Estimated Joules
    activation_count: int
    last_updated: float


class GPUProfiler:
    """
    Continuously monitors GPU metrics and provides energy profiling for MoE experts.
    Runs in a separate thread to avoid blocking main computation.
    """

    def __init__(self, device_id: int = 0, poll_interval: float = 0.1):
        self.device_id = device_id
        self.poll_interval = poll_interval
        self.is_running = False
        self.metrics_history = deque(maxlen=1000)  # Keep last 1000 samples
        self.expert_profiles = {}
        self.operation_stack = []  # Stack for nested operations
        self.lock = threading.Lock()

        if PYNVML_AVAILABLE:
            try:
                pynvml.nvmlInit()
                self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
                self.nvml_available = True
            except Exception as e:
                logging.warning(f"Failed to initialize NVML: {e}")
                self.nvml_available = False
        else:
            self.nvml_available = False

        self.polling_thread = None

    def start_profiling(self):
        """Start the GPU monitoring thread."""
        if self.is_running:
            return

        self.is_running = True
        self.polling_thread = threading.Thread(target=self._polling_loop, daemon=True)
        self.polling_thread.start()
        logging.info("GPU profiling started")

    def stop_profiling(self):
        """Stop the GPU monitoring thread."""
        self.is_running = False
        if self.polling_thread:
            self.polling_thread.join(timeout=1.0)
        logging.info("GPU profiling stopped")

    def _polling_loop(self):
        """Main polling loop running in separate thread."""
        while self.is_running:
            try:
                metrics = self._collect_gpu_metrics()
                if metrics:
                    with self.lock:
                        self.metrics_history.append(metrics)
                time.sleep(self.poll_interval)
            except Exception as e:
                logging.error(f"Error in GPU polling: {e}")
                time.sleep(self.poll_interval)

    def _collect_gpu_metrics(self) -> Optional[GPUMetrics]:
        """Collect current GPU metrics."""
        if not self.nvml_available:
            # Return dummy metrics for testing
            return GPUMetrics(
                timestamp=time.time(),
                power_draw=150.0 + np.random.normal(0, 10),
                temperature=65.0 + np.random.normal(0, 5),
                memory_used=int(4e9 + np.random.normal(0, 1e8)),
                memory_total=int(8e9),
                gpu_utilization=80.0 + np.random.normal(0, 10),
                memory_utilization=50.0 + np.random.normal(0, 5)
            )

        try:
            power_draw = pynvml.nvmlDeviceGetPowerUsage(self.handle) / 1000.0  # mW to W
            temperature = pynvml.nvmlDeviceGetTemperature(self.handle, pynvml.NVML_TEMPERATURE_GPU)

            mem_info = pynvml.nvmlDeviceGetMemoryInfo(self.handle)
            memory_used = mem_info.used
            memory_total = mem_info.total

            util_rates = pynvml.nvmlDeviceGetUtilizationRates(self.handle)
            gpu_util = util_rates.gpu
            memory_util = util_rates.memory

            return GPUMetrics(
                timestamp=time.time(),
                power_draw=power_draw,
                temperature=temperature,
                memory_used=memory_used,
                memory_total=memory_total,
                gpu_utilization=gpu_util,
                memory_utilization=memory_util
            )
        except Exception as e:
            logging.error(f"Failed to collect GPU metrics: {e}")
            return None

    def get_current_metrics(self) -> Optional[GPUMetrics]:
        """Get the most recent GPU metrics."""
        with self.lock:
            return self.metrics_history[-1] if self.metrics_history else None

    def get_metrics_window(self, duration: float = 1.0) -> List[GPUMetrics]:
        """Get metrics from the last `duration` seconds."""
        current_time = time.time()
        cutoff_time = current_time - duration

        with self.lock:
            return [m for m in self.metrics_history if m.timestamp >= cutoff_time]

    def start_operation(self, operation_name: str, expert_id: Optional[str] = None):
        """Start tracking an operation (e.g., expert forward pass)."""
        operation_data = {
            'name': operation_name,
            'expert_id': expert_id,
            'start_time': time.time(),
            'start_metrics': self.get_current_metrics()
        }
        self.operation_stack.append(operation_data)

    def end_operation(self) -> Optional[Dict]:
        """End tracking the current operation and return profiling data."""
        if not self.operation_stack:
            return None

        operation_data = self.operation_stack.pop()
        end_time = time.time()
        end_metrics = self.get_current_metrics()

        duration = end_time - operation_data['start_time']

        # Calculate energy consumption during operation
        energy_consumed = 0.0
        if operation_data['start_metrics'] and end_metrics:
            avg_power = (operation_data['start_metrics'].power_draw + end_metrics.power_draw) / 2.0
            energy_consumed = avg_power * duration  # Joules

        result = {
            'operation': operation_data['name'],
            'expert_id': operation_data['expert_id'],
            'duration': duration,
            'energy_consumed': energy_consumed,
            'start_metrics': operation_data['start_metrics'],
            'end_metrics': end_metrics
        }

        # Update expert profile if applicable
        if operation_data['expert_id']:
            self._update_expert_profile(operation_data['expert_id'], result)

        return result

    def _update_expert_profile(self, expert_id: str, operation_result: Dict):
        """Update the profile for a specific expert."""
        if expert_id not in self.expert_profiles:
            self.expert_profiles[expert_id] = ExpertProfile(
                expert_id=expert_id,
                flops=0,
                memory_footprint=0,
                avg_latency=0.0,
                energy_cost=0.0,
                activation_count=0,
                last_updated=time.time()
            )

        profile = self.expert_profiles[expert_id]
        profile.activation_count += 1

        # Update running averages
        alpha = 0.1  # Exponential moving average factor
        profile.avg_latency = (1 - alpha) * profile.avg_latency + alpha * operation_result['duration']
        profile.energy_cost = (1 - alpha) * profile.energy_cost + alpha * operation_result['energy_consumed']
        profile.last_updated = time.time()

    def estimate_expert_flops(self, expert_module: nn.Module, input_shape: Tuple[int, ...]) -> int:
        """
        Estimate FLOPs for an expert module.
        This is a simplified estimation - you might want to use more sophisticated methods.
        """
        total_flops = 0

        # Create dummy input
        dummy_input = torch.randn(input_shape)

        def flop_count_hook(module, input, output):
            nonlocal total_flops
            if isinstance(module, nn.Linear):
                # For linear layers: input_size * output_size * batch_size
                in_features = module.in_features
                out_features = module.out_features
                batch_size = input[0].shape[0] if input else 1
                total_flops += in_features * out_features * batch_size
            elif isinstance(module, nn.Conv2d):
                # Simplified conv2d FLOP estimation
                kernel_flops = module.kernel_size[0] * module.kernel_size[1] * module.in_channels
                output_elements = output.numel() if hasattr(output, 'numel') else 1
                total_flops += kernel_flops * output_elements

        # Register hooks
        hooks = []
        for module in expert_module.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                hooks.append(module.register_forward_hook(flop_count_hook))

        try:
            with torch.no_grad():
                expert_module(dummy_input)
        finally:
            # Remove hooks
            for hook in hooks:
                hook.remove()

        return total_flops

    def get_expert_profile(self, expert_id: str) -> Optional[ExpertProfile]:
        """Get the profile for a specific expert."""
        return self.expert_profiles.get(expert_id)

    def get_all_expert_profiles(self) -> Dict[str, ExpertProfile]:
        """Get profiles for all experts."""
        return self.expert_profiles.copy()

    def save_profiles(self, filepath: str):
        """Save expert profiles to JSON file."""
        profiles_dict = {
            expert_id: asdict(profile)
            for expert_id, profile in self.expert_profiles.items()
        }

        with open(filepath, 'w') as f:
            json.dump(profiles_dict, f, indent=2)

    def load_profiles(self, filepath: str):
        """Load expert profiles from JSON file."""
        try:
            with open(filepath, 'r') as f:
                profiles_dict = json.load(f)

            self.expert_profiles = {
                expert_id: ExpertProfile(**profile_data)
                for expert_id, profile_data in profiles_dict.items()
            }
        except Exception as e:
            logging.error(f"Failed to load profiles: {e}")

    def get_power_statistics(self, duration: float = 10.0) -> Dict[str, float]:
        """Get power consumption statistics over a time window."""
        metrics = self.get_metrics_window(duration)
        if not metrics:
            return {}

        power_values = [m.power_draw for m in metrics]

        return {
            'mean_power': np.mean(power_values),
            'max_power': np.max(power_values),
            'min_power': np.min(power_values),
            'std_power': np.std(power_values),
            'total_energy': np.sum(power_values) * self.poll_interval  # Approximate
        }

    def __enter__(self):
        """Context manager entry."""
        self.start_profiling()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.stop_profiling()


# Example usage and testing
if __name__ == "__main__":
    # Example of how to use the GPUProfiler
    profiler = GPUProfiler(device_id=0, poll_interval=0.1)

    # Start profiling
    profiler.start_profiling()

    # Simulate some operations
    time.sleep(1.0)

    # Example expert operation
    profiler.start_operation("expert_forward", "expert_0")
    time.sleep(0.5)  # Simulate computation
    result = profiler.end_operation()

    print("Operation result:", result)

    # Get current metrics
    current = profiler.get_current_metrics()
    if current:
        print(f"Current power: {current.power_draw:.2f}W, Temp: {current.temperature:.1f}°C")

    # Get power statistics
    stats = profiler.get_power_statistics(duration=5.0)
    print("Power statistics:", stats)

    # Stop profiling
    profiler.stop_profiling()

Operation result: {'operation': 'expert_forward', 'expert_id': 'expert_0', 'duration': 0.5008032321929932, 'energy_consumed': 5.083403208374977, 'start_metrics': GPUMetrics(timestamp=1750177592.0765254, power_draw=10.051, temperature=45, memory_used=277872640, memory_total=16106127360, gpu_utilization=0, memory_utilization=0), 'end_metrics': GPUMetrics(timestamp=1750177592.5354972, power_draw=10.25, temperature=45, memory_used=277872640, memory_total=16106127360, gpu_utilization=0, memory_utilization=0)}
Current power: 10.25W, Temp: 45.0°C
Power statistics: {'mean_power': np.float64(10.143), 'max_power': np.float64(10.349), 'min_power': np.float64(10.051), 'std_power': np.float64(0.10237815057074293), 'total_energy': np.float64(14.200200000000002)}


In [11]:
import json
import time
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum
from pathlib import Path

class ThermalState(Enum):
    """Enumeration of thermal states for the system."""
    COOL = "cool"
    WARM = "warm"
    HOT = "hot"
    CRITICAL = "critical"
    THROTTLED = "throttled" # Added for explicit throttled state

class PowerMode(Enum):
    """Enumeration of power modes."""
    LOW_POWER = "low_power"
    BALANCED = "balanced"
    PERFORMANCE = "performance"
    EMERGENCY = "emergency"


@dataclass
class ThermalBudget:
    """Represents current thermal constraints and budgets."""
    max_temperature: float
    max_power: float
    max_energy_per_token: float # This might be dynamic based on expert costs
    current_temperature: float
    current_power: float
    thermal_headroom: float
    power_headroom: float
    recommended_experts: List[str] # Now a list for specific expert IDs
    throttle_factor: float = 1.0


@dataclass
class ThermalSignal:
    """Complete thermal signal with all relevant information."""
    timestamp: float
    thermal_state: ThermalState
    power_mode: PowerMode
    temperature: float
    power_draw: float
    thermal_budget: ThermalBudget
    expert_recommendations: Dict[str, float]  # expert_id -> priority score
    throttle_recommendations: Dict[str, float]  # operation -> throttle factor (e.g., 'global_throttle': 0.8)
    emergency_actions: List[str]


class ThermalSignalGenerator:
    """
    Generates thermal signals based on GPU profiling data and thermal budgets.
    This module determines when to throttle computation, switch experts, or
    take emergency actions based on thermal constraints.
    """

    def __init__(self,
                 profiler: GPUProfiler,
                 cost_table_path: str = "energy/cost_table.json"):
        self.profiler = profiler
        self.cost_table_path = Path(cost_table_path)
        self.cost_table = self._load_cost_table()

        # Thermal parameters from cost table
        self.thermal_params = self.cost_table.get("thermal_parameters", {})
        self.energy_budgets = self.cost_table.get("energy_budgets", {})
        self.expert_profiles = self.cost_table.get("expert_profiles", {})

        # State tracking
        self.current_mode = PowerMode.BALANCED
        self.thermal_history: List[ThermalSignal] = []
        self.last_signal_time = 0.0
        self.emergency_cooldown_active = False # Flag for emergency state
        self.emergency_cooldown_duration = self.thermal_params.get("emergency_cooldown_duration", 30.0)
        self.emergency_cooldown_start_time = 0.0

        # Thermal model parameters
        self.base_temp = self.thermal_params.get("base_temperature", 45.0)
        self.warm_temp_threshold = self.thermal_params.get("warm_temperature_threshold", self.base_temp + 20)
        self.hot_temp_threshold = self.thermal_params.get("hot_temperature_threshold", 83.0) # Renamed for consistency
        self.critical_temp = self.thermal_params.get("critical_temperature", 87.0)
        self.thermal_time_constant = self.thermal_params.get("thermal_time_constant", 15.0)

        logging.basicConfig(level=logging.INFO) # Ensure logging is configured
        logging.info(f"ThermalSignal initialized with thresholds: {self.hot_temp_threshold}°C hot, {self.critical_temp}°C critical")

    def _load_cost_table(self) -> Dict:
        """Load the cost table configuration."""
        try:
            with open(self.cost_table_path, 'r') as f:
                return json.load(f)
        except Exception as e:
            logging.error(f"Failed to load cost table from {self.cost_table_path}: {e}")
            return self._get_default_cost_table()

    def _get_default_cost_table(self) -> Dict:
        """Return default cost table if loading fails."""
        return {
            "thermal_parameters": {
                "base_temperature": 45.0,
                "warm_temperature_threshold": 65.0, # Added specific warm threshold
                "hot_temperature_threshold": 83.0,
                "critical_temperature": 87.0,
                "thermal_time_constant": 15.0,
                "emergency_cooldown_duration": 30.0
            },
            "energy_budgets": {
                "low_power": {"max_power_watts": 200, "max_temperature": 75.0, "max_energy_per_token_mj": 5.0},
                "balanced": {"max_power_watts": 350, "max_temperature": 80.0, "max_energy_per_token_mj": 3.0},
                "performance": {"max_power_watts": 450, "max_temperature": 85.0, "max_energy_per_token_mj": 1.5}
            },
            "expert_profiles": {
                "expert_A": {"average_power_watts": 50, "energy_per_token_mj": 2.0, "thermal_impact": 0.1},
                "expert_B": {"average_power_watts": 70, "energy_per_token_mj": 2.5, "thermal_impact": 0.15},
                "expert_C": {"average_power_watts": 30, "energy_per_token_mj": 1.0, "thermal_impact": 0.05}
            }
        }

    def get_thermal_signal(self) -> Optional[ThermalSignal]:
        """
        Generate current thermal signal based on GPU metrics and thermal model.
        """
        current_metrics = self.profiler.get_current_metrics()
        if not current_metrics:
            logging.warning("No GPU metrics available to generate thermal signal.")
            return None

        current_time = time.time()

        # Check for active emergency cooldown
        if self.emergency_cooldown_active:
            if current_time - self.emergency_cooldown_start_time < self.emergency_cooldown_duration:
                # Still in emergency cooldown, prioritize cool-down actions
                logging.info(f"Emergency cooldown active. Remaining: {self.emergency_cooldown_duration - (current_time - self.emergency_cooldown_start_time):.1f}s")
                # Force emergency mode during cooldown
                power_mode = PowerMode.EMERGENCY
                thermal_state = ThermalState.THROTTLED # Or CRITICAL if temp is still high
                if current_metrics.temperature < self.warm_temp_threshold: # Exit emergency if sufficiently cooled
                    self.emergency_cooldown_active = False
                    logging.info("Exiting emergency cooldown: temperature has dropped.")
                else:
                    # Still hot during cooldown
                    if current_metrics.temperature >= self.critical_temp:
                         thermal_state = ThermalState.CRITICAL
                    elif current_metrics.temperature >= self.hot_temp_threshold:
                         thermal_state = ThermalState.HOT
                    else:
                         thermal_state = ThermalState.THROTTLED # Indicates active throttling for cooldown
            else:
                self.emergency_cooldown_active = False
                logging.info("Emergency cooldown period ended.")

        # Determine thermal state if not in active cooldown
        if not self.emergency_cooldown_active:
            thermal_state = self._classify_thermal_state(current_metrics)

        # Determine power mode (may be overridden by thermal state)
        # If emergency cooldown is active, power_mode is already set to EMERGENCY
        if not self.emergency_cooldown_active:
            power_mode = self._determine_power_mode(current_metrics, thermal_state)
        self.current_mode = power_mode # Update internal state

        # Calculate thermal budget
        thermal_budget = self._calculate_thermal_budget(current_metrics, power_mode)

        # Generate expert recommendations
        expert_recommendations = self._generate_expert_recommendations(
            current_metrics, thermal_budget, thermal_state
        )

        # Generate throttle recommendations
        throttle_recommendations = self._generate_throttle_recommendations(
            current_metrics, thermal_state, thermal_budget
        )

        # Check for emergency actions (can trigger cooldown)
        emergency_actions = self._check_emergency_actions(current_metrics, thermal_state)
        if emergency_actions and "initiate_emergency_cooldown" in emergency_actions and not self.emergency_cooldown_active:
            self.emergency_cooldown_active = True
            self.emergency_cooldown_start_time = current_time
            logging.warning("Initiating emergency cooldown due to critical thermal state!")


        signal = ThermalSignal(
            timestamp=current_time,
            thermal_state=thermal_state,
            power_mode=power_mode,
            temperature=current_metrics.temperature,
            power_draw=current_metrics.power_draw,
            thermal_budget=thermal_budget,
            expert_recommendations=expert_recommendations,
            throttle_recommendations=throttle_recommendations,
            emergency_actions=emergency_actions
        )

        # Update history
        self.thermal_history.append(signal)
        if len(self.thermal_history) > 100:  # Keep a reasonable history size
            self.thermal_history.pop(0)

        self.last_signal_time = current_time

        return signal

    def _classify_thermal_state(self, metrics: GPUMetrics) -> ThermalState:
        """Classify current thermal state based on temperature."""
        temp = metrics.temperature

        if temp >= self.critical_temp:
            return ThermalState.CRITICAL
        elif temp >= self.hot_temp_threshold:
            return ThermalState.HOT
        elif temp >= self.warm_temp_threshold:
            return ThermalState.WARM
        else:
            return ThermalState.COOL

    def _determine_power_mode(self, metrics: GPUMetrics, thermal_state: ThermalState) -> PowerMode:
        """
        Determine the appropriate power mode based on thermal state and current usage.
        This could involve hysteresis to prevent rapid mode switching.
        """
        proposed_mode = self.current_mode # Start with current mode

        # Logic for power mode transition
        if thermal_state == ThermalState.CRITICAL:
            proposed_mode = PowerMode.EMERGENCY
        elif thermal_state == ThermalState.HOT:
            if proposed_mode == PowerMode.PERFORMANCE: # Drop from performance if hot
                proposed_mode = PowerMode.BALANCED
            elif proposed_mode == PowerMode.BALANCED and metrics.power_draw > self.energy_budgets["balanced"]["max_power_watts"] * 0.9:
                # If balanced and still drawing too much power, consider low power
                proposed_mode = PowerMode.LOW_POWER
        elif thermal_state == ThermalState.WARM:
            if proposed_mode == PowerMode.PERFORMANCE and metrics.power_draw > self.energy_budgets["performance"]["max_power_watts"]:
                # If warm but drawing too much for performance, drop to balanced
                proposed_mode = PowerMode.BALANCED
            elif proposed_mode == PowerMode.LOW_POWER and metrics.temperature < self.warm_temp_threshold - 5 and metrics.gpu_utilization < 0.5:
                # If low power and cooled down, can move to balanced if not highly utilized
                proposed_mode = PowerMode.BALANCED
        elif thermal_state == ThermalState.COOL:
            if proposed_mode == PowerMode.LOW_POWER and metrics.temperature < self.base_temp + 5 and metrics.gpu_utilization < 0.3:
                proposed_mode = PowerMode.BALANCED
            elif proposed_mode == PowerMode.BALANCED and metrics.temperature < self.base_temp + 10 and metrics.gpu_utilization > 0.8:
                # If balanced and good temperature, consider performance if high utilization
                proposed_mode = PowerMode.PERFORMANCE

        # Apply hysteresis: only change mode if conditions persist or change is drastic
        # For simplicity, we'll keep it direct for now, but a real system might use a timer
        # or history to prevent oscillations.

        if proposed_mode != self.current_mode:
            logging.info(f"Power mode transition: {self.current_mode.value} -> {proposed_mode.value} due to {thermal_state.value} state.")
            self.current_mode = proposed_mode # Update internal state

        return proposed_mode

    def _calculate_thermal_budget(self, metrics: GPUMetrics, power_mode: PowerMode) -> ThermalBudget:
        """
        Calculate current thermal budget based on the selected power mode and current metrics.
        """
        mode_budget = self.energy_budgets.get(power_mode.value, self.energy_budgets["balanced"])

        max_temp = mode_budget.get("max_temperature", 80.0)
        max_power = mode_budget.get("max_power_watts", 350.0)
        max_energy_per_token = mode_budget.get("max_energy_per_token_mj", 3.0)

        thermal_headroom = max_temp - metrics.temperature
        power_headroom = max_power - metrics.power_draw

        # Simple throttle factor calculation based on temperature headroom
        throttle_factor = 1.0
        if metrics.temperature >= self.hot_temp_threshold:
            # Linear throttle as temperature approaches critical
            temp_range = self.critical_temp - self.hot_temp_threshold
            if temp_range > 0:
                throttle_factor = 1.0 - ((metrics.temperature - self.hot_temp_threshold) / temp_range) * 0.5 # Up to 50% throttle
            throttle_factor = max(0.1, throttle_factor) # Ensure it doesn't go below 10%
            logging.warning(f"Temperature is hot ({metrics.temperature}°C), applying throttle factor: {throttle_factor:.2f}")

        # Also consider power headroom for throttling
        if metrics.power_draw > max_power * 1.1: # 10% overshoot on power budget
            power_throttle = 1.0 - ((metrics.power_draw - max_power) / (max_power * 0.5)) # throttle up to 50% for 50% power overshoot
            throttle_factor = min(throttle_factor, max(0.1, power_throttle))
            logging.warning(f"Power draw ({metrics.power_draw:.1f}W) exceeds budget, adjusting throttle factor: {throttle_factor:.2f}")


        # Dummy recommended experts for now - a real system would use a model
        # or heuristics based on expert costs and current budget.
        # This part should be driven by the _generate_expert_recommendations method.
        recommended_experts: List[str] = [] # This field is populated by _generate_expert_recommendations

        return ThermalBudget(
            max_temperature=max_temp,
            max_power=max_power,
            max_energy_per_token=max_energy_per_token,
            current_temperature=metrics.temperature,
            current_power=metrics.power_draw,
            thermal_headroom=thermal_headroom,
            power_headroom=power_headroom,
            recommended_experts=recommended_experts, # Will be filled by another method
            throttle_factor=throttle_factor
        )

    def _generate_expert_recommendations(
        self, metrics: GPUMetrics, budget: ThermalBudget, thermal_state: ThermalState
    ) -> Dict[str, float]:
        """
        Generate recommendations for expert prioritization based on thermal budget
        and expert profiles. Experts with lower energy costs and thermal impact
        are prioritized when resources are constrained.
        """
        expert_priority: Dict[str, float] = {}

        for expert_id, profile in self.expert_profiles.items():
            # Calculate a basic priority score. Lower is better (more efficient).
            # This is a simplified heuristic. A more advanced system might
            # consider recent expert usage, task requirements, etc.
            energy_cost = profile.get("energy_per_token_mj", 100.0)
            thermal_impact = profile.get("thermal_impact", 1.0)
            avg_power = profile.get("average_power_watts", 100.0)

            score = 0.0

            # Prioritize experts with lower energy consumption
            score += energy_cost / budget.max_energy_per_token # Lower is better, so divide

            # Prioritize experts that are less thermally impactful
            score += thermal_impact * (metrics.temperature / self.critical_temp) * 5.0 # Higher impact at higher temps lowers priority more

            # Adjust based on power headroom: if power is tight, penalize high-power experts
            if budget.power_headroom < 50: # If power is getting tight (e.g., less than 50W headroom)
                score += (avg_power / budget.max_power) * 2.0 # Penalize high power consumption

            # If current state is hot or critical, heavily penalize experts that generate a lot of heat/power
            if thermal_state in [ThermalState.HOT, ThermalState.CRITICAL]:
                score += (thermal_impact * 10) + (avg_power / budget.max_power * 5)

            expert_priority[expert_id] = score

        # Sort experts by priority score (lower score means higher priority)
        sorted_experts = sorted(expert_priority.items(), key=lambda item: item[1])

        # Convert back to dictionary, perhaps only returning top N or all with scores
        return {expert_id: score for expert_id, score in sorted_experts}

    def _generate_throttle_recommendations(
        self, metrics: GPUMetrics, thermal_state: ThermalState, budget: ThermalBudget
    ) -> Dict[str, float]:
        """
        Generate throttle recommendations for various operations.
        A global throttle factor is primarily determined by the thermal budget.
        """
        throttle_recs: Dict[str, float] = {}

        # The primary throttle factor comes from the thermal budget calculation
        global_throttle_factor = budget.throttle_factor
        throttle_recs["global_compute_throttle"] = global_throttle_factor

        # Add more specific throttles if needed, e.g., memory bandwidth throttle
        # based on memory temperature or utilization, if available in metrics.
        if metrics.memory_utilization > 0.9 and thermal_state in [ThermalState.HOT, ThermalState.CRITICAL]:
            throttle_recs["memory_bandwidth_throttle"] = max(0.5, 1.0 - (metrics.memory_utilization - 0.9) * 5) # 50% throttle if 100% util
            logging.warning(f"High memory utilization and thermal state: {throttle_recs['memory_bandwidth_throttle']:.2f} memory throttle.")


        if thermal_state == ThermalState.CRITICAL:
            throttle_recs["global_compute_throttle"] = min(0.1, global_throttle_factor) # Aggressive throttle
            logging.critical("CRITICAL thermal state: Aggressive global throttle applied.")
        elif thermal_state == ThermalState.HOT:
            throttle_recs["global_compute_throttle"] = min(0.5, global_throttle_factor) # Moderate throttle
            logging.warning("HOT thermal state: Moderate global throttle applied.")
        elif thermal_state == ThermalState.THROTTLED: # Explicit throttled state
            throttle_recs["global_compute_throttle"] = min(0.2, global_throttle_factor) # More aggressive throttle
            logging.warning("THROTTLED thermal state: Global throttle applied for cooldown.")


        return throttle_recs

    def _check_emergency_actions(self, metrics: GPUMetrics, thermal_state: ThermalState) -> List[str]:
        """
        Check if any emergency actions are required, such as initiating a full cooldown.
        """
        emergency_actions: List[str] = []

        if thermal_state == ThermalState.CRITICAL:
            logging.critical(f"Temperature {metrics.temperature}°C is CRITICAL! Recommending emergency cooldown.")
            emergency_actions.append("initiate_emergency_cooldown")
            emergency_actions.append("alert_system_operator") # Example of other emergency actions

        # Add other conditions that might trigger emergency actions, e.g.,
        # persistent high power draw despite throttling, or fan failure detection.

        return emergency_actions

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple
import numpy as np


class TopKRouter(nn.Module):
    """
    Top-K router for MoE with load balancing loss and detailed profiling.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        top_k: int = 2,
        capacity_factor: float = 1.25,
        gate_noise: float = 1e-2,
        expert_dropout: float = 0.0,
        balance_loss_weight: float = 0.01,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_experts = n_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.gate_noise = gate_noise
        self.expert_dropout = expert_dropout
        self.balance_loss_weight = balance_loss_weight

        # Gating network - simple linear layer
        self.gate = nn.Linear(d_model, n_experts, bias=False)

        # Initialize gate weights
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)

        # For tracking expert usage statistics
        self.register_buffer('expert_usage_counts', torch.zeros(n_experts))
        self.register_buffer('total_tokens_processed', torch.tensor(0.0))

    def add_noise(self, logits: torch.Tensor) -> torch.Tensor:
        """Add noise to gate logits for better exploration."""
        if self.training and self.gate_noise > 0:
            noise = torch.randn_like(logits) * self.gate_noise
            return logits + noise
        return logits

    def compute_load_balancing_loss(
        self,
        gate_logits: torch.Tensor,
        selected_experts: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute load balancing loss to encourage even expert usage.

        Args:
            gate_logits: Raw gate logits [batch_size * seq_len, n_experts]
            selected_experts: Selected expert indices [batch_size * seq_len, top_k]

        Returns:
            Load balancing loss scalar
        """
        # Compute gate probabilities
        gate_probs = F.softmax(gate_logits, dim=-1)  # [B*T, n_experts]

        # Fraction of tokens assigned to each expert
        expert_mask = F.one_hot(selected_experts, num_classes=self.n_experts).float()  # [B*T, top_k, n_experts]
        expert_assignment = expert_mask.sum(dim=1)  # [B*T, n_experts]
        tokens_per_expert = expert_assignment.sum(dim=0)  # [n_experts]

        # Normalize by total tokens
        total_tokens = gate_logits.shape[0] * self.top_k
        fraction_per_expert = tokens_per_expert / total_tokens

        # Average gate probability for each expert
        avg_gate_prob = gate_probs.mean(dim=0)  # [n_experts]

        # Load balancing loss: minimize the dot product of these two distributions
        # This encourages both distributions to be uniform
        balance_loss = (fraction_per_expert * avg_gate_prob).sum() * self.n_experts

        return balance_loss

    def compute_capacity(self, batch_size: int, seq_len: int) -> int:
        """Compute expert capacity based on capacity factor."""
        tokens_per_expert = (batch_size * seq_len * self.top_k) / self.n_experts
        capacity = int(tokens_per_expert * self.capacity_factor)
        return max(capacity, 4)  # Minimum capacity

    def profile_expert_timing(self, selected_experts: torch.Tensor) -> Dict:
        """
        Profile per-expert timing and usage.

        Args:
            selected_experts: Selected expert indices [batch_size * seq_len, top_k]

        Returns:
            Dictionary with profiling information
        """
        metrics = {}

        # Count expert usage
        expert_counts = torch.zeros(self.n_experts, device=selected_experts.device)
        for expert_id in range(self.n_experts):
            expert_counts[expert_id] = (selected_experts == expert_id).sum().float()

        # Update global statistics
        self.expert_usage_counts += expert_counts
        self.total_tokens_processed += selected_experts.numel()

        # Compute usage statistics
        total_assignments = expert_counts.sum()
        expert_utilization = expert_counts / (total_assignments + 1e-8)

        metrics.update({
            'expert_usage_current': expert_counts.cpu().numpy(),
            'expert_utilization_current': expert_utilization.cpu().numpy(),
            'expert_usage_cumulative': self.expert_usage_counts.cpu().numpy(),
            'total_assignments': total_assignments.item(),
            'usage_variance': expert_utilization.var().item(),
            'max_expert_usage': expert_utilization.max().item(),
            'min_expert_usage': expert_utilization.min().item(),
        })

        return metrics

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Forward pass of the router.

        Args:
            x: Input tensor [batch_size, seq_len, d_model]

        Returns:
            Dictionary containing routing decisions and metrics
        """
        batch_size, seq_len, d_model = x.shape

        # Reshape for easier processing
        x_flat = x.view(-1, d_model)  # [batch_size * seq_len, d_model]

        # Timing for gate computation
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            gate_start = torch.cuda.Event(enable_timing=True)
            gate_end = torch.cuda.Event(enable_timing=True)
            gate_start.record()

        # Compute gate logits
        gate_logits = self.gate(x_flat)  # [batch_size * seq_len, n_experts]

        # Add noise for exploration
        gate_logits = self.add_noise(gate_logits)

        # Get top-k experts
        top_k_values, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)

        if torch.cuda.is_available():
            gate_end.record()
            torch.cuda.synchronize()
            gate_time = gate_start.elapsed_time(gate_end)
        else:
            gate_time = 0.0

        # Compute routing probabilities (softmax over top-k)
        top_k_probs = F.softmax(top_k_values, dim=-1)  # [batch_size * seq_len, top_k]

        # Compute load balancing loss
        balance_loss = self.compute_load_balancing_loss(gate_logits, top_k_indices)

        # Profile expert usage
        usage_metrics = self.profile_expert_timing(top_k_indices)

        # Compute expert capacity
        capacity = self.compute_capacity(batch_size, seq_len)

        # Apply expert dropout during training
        if self.training and self.expert_dropout > 0:
            dropout_mask = torch.rand_like(top_k_probs) > self.expert_dropout
            top_k_probs = top_k_probs * dropout_mask
            # Renormalize
            top_k_probs = top_k_probs / (top_k_probs.sum(dim=-1, keepdim=True) + 1e-8)

        # Prepare output
        output = {
            'expert_indices': top_k_indices,  # [batch_size * seq_len, top_k]
            'expert_weights': top_k_probs,    # [batch_size * seq_len, top_k]
            'gate_logits': gate_logits,       # [batch_size * seq_len, n_experts]
            'balance_loss': balance_loss * self.balance_loss_weight,
            'capacity': capacity,
            'metrics': {
                'gate_computation_time_ms': gate_time,
                'balance_loss_raw': balance_loss.item(),
                'balance_loss_weighted': (balance_loss * self.balance_loss_weight).item(),
                'capacity': capacity,
                'avg_top_k_confidence': top_k_probs.mean().item(),
                'gate_entropy': -F.softmax(gate_logits, dim=-1).mul(F.log_softmax(gate_logits, dim=-1)).sum(-1).mean().item(),
                **usage_metrics,
            }
        }

        return output

    def get_expert_stats(self) -> Dict:
        """Get comprehensive expert usage statistics."""
        if self.total_tokens_processed == 0:
            return {'message': 'No tokens processed yet'}

        cumulative_usage = self.expert_usage_counts / self.total_tokens_processed

        return {
            'total_tokens_processed': self.total_tokens_processed.item(),
            'expert_usage_distribution': cumulative_usage.cpu().numpy(),
            'usage_std': cumulative_usage.std().item(),
            'usage_coefficient_of_variation': cumulative_usage.std().item() / (cumulative_usage.mean().item() + 1e-8),
            'most_used_expert': cumulative_usage.argmax().item(),
            'least_used_expert': cumulative_usage.argmin().item(),
            'perfect_balance_target': 1.0 / self.n_experts,
        }

    def reset_stats(self):
        """Reset expert usage statistics."""
        self.expert_usage_counts.zero_()
        self.total_tokens_processed.zero_()


# Auxiliary router for comparison/experimentation
class SwitchRouter(TopKRouter):
    """
    Switch Transformer style router (top-1 with capacity dropping).
    Inherits from TopKRouter but modifies behavior for top-1 routing.
    """

    def __init__(self, d_model: int, n_experts: int, **kwargs):
        # Force top_k = 1 for Switch routing
        super().__init__(d_model, n_experts, top_k=1, **kwargs)

    def forward(self, x: torch.Tensor) -> Dict:
        """Switch-style routing with capacity dropping."""
        # Get base routing decisions
        output = super().forward(x)

        # For Switch, we need to implement capacity dropping
        batch_size, seq_len = x.shape[:2]
        capacity = output['capacity']
        expert_indices = output['expert_indices'].squeeze(-1)  # [batch_size * seq_len]

        # Count tokens per expert
        expert_counts = torch.zeros(self.n_experts, device=x.device)
        for expert_id in range(self.n_experts):
            expert_counts[expert_id] = (expert_indices == expert_id).sum()

        # Create capacity mask
        capacity_mask = torch.ones_like(expert_indices, dtype=torch.bool)

        for expert_id in range(self.n_experts):
            expert_tokens = (expert_indices == expert_id).nonzero(as_tuple=True)[0]
            if len(expert_tokens) > capacity:
                # Randomly drop tokens exceeding capacity
                dropped_indices = expert_tokens[capacity:]
                capacity_mask[dropped_indices] = False

        # Update metrics
        tokens_dropped = (~capacity_mask).sum().item()
        output['metrics']['tokens_dropped'] = tokens_dropped
        output['metrics']['drop_rate'] = tokens_dropped / expert_indices.numel()
        output['capacity_mask'] = capacity_mask

        return output


# Example usage and testing
if __name__ == "__main__":
    # Test the router
    d_model = 256
    n_experts = 8
    top_k = 2
    batch_size = 4
    seq_len = 32

    router = TopKRouter(
        d_model=d_model,
        n_experts=n_experts,
        top_k=top_k,
        capacity_factor=1.5,
        balance_loss_weight=0.01,
    )

    # Test input
    x = torch.randn(batch_size, seq_len, d_model)

    # Forward pass
    with torch.no_grad():
        output = router(x)

    print("Router output keys:", output.keys())
    print(f"Expert indices shape: {output['expert_indices'].shape}")
    print(f"Expert weights shape: {output['expert_weights'].shape}")
    print(f"Balance loss: {output['balance_loss'].item():.6f}")
    print(f"Expert capacity: {output['capacity']}")

    print("\nMetrics:")
    for key, value in output['metrics'].items():
        if isinstance(value, (list, np.ndarray)):
            print(f"  {key}: {np.array(value)}")
        else:
            print(f"  {key}: {value}")

    print("\nExpert statistics:")
    stats = router.get_expert_stats()
    for key, value in stats.items():
        if isinstance(value, (list, np.ndarray)):
            print(f"  {key}: {np.array(value)}")
        else:
            print(f"  {key}: {value}")

    # Test Switch router
    print("\n" + "="*50)
    print("Testing Switch Router")

    switch_router = SwitchRouter(d_model=d_model, n_experts=n_experts)

    with torch.no_grad():
        switch_output = switch_router(x)

    print(f"Switch expert indices shape: {switch_output['expert_indices'].shape}")
    print(f"Tokens dropped: {switch_output['metrics']['tokens_dropped']}")
    print(f"Drop rate: {switch_output['metrics']['drop_rate']:.3f}")

Router output keys: dict_keys(['expert_indices', 'expert_weights', 'gate_logits', 'balance_loss', 'capacity', 'metrics'])
Expert indices shape: torch.Size([128, 2])
Expert weights shape: torch.Size([128, 2])
Balance loss: 0.010017
Expert capacity: 48

Metrics:
  gate_computation_time_ms: 0.3401919901371002
  balance_loss_raw: 1.0016876459121704
  balance_loss_weighted: 0.01001687627285719
  capacity: 48
  avg_top_k_confidence: 0.5
  gate_entropy: 2.0340373516082764
  expert_usage_current: [30. 27. 33. 30. 34. 38. 30. 34.]
  expert_utilization_current: [0.1171875  0.10546875 0.12890625 0.1171875  0.1328125  0.1484375
 0.1171875  0.1328125 ]
  expert_usage_cumulative: [30. 27. 33. 30. 34. 38. 30. 34.]
  total_assignments: 256.0
  usage_variance: 0.0001787458168109879
  max_expert_usage: 0.1484375
  min_expert_usage: 0.10546875

Expert statistics:
  total_tokens_processed: 256.0
  expert_usage_distribution: [0.1171875  0.10546875 0.12890625 0.1171875  0.1328125  0.1484375
 0.1171875  0.13

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional, List, Any
import time
import numpy as np
# from fairscale.nn.moe import MOELayer # Removing fairscale dependency

class SimpleMoELayer(nn.Module):
    """
    A simplified, non-distributed MoE layer for profiling.
    Handles top-k routing, expert dispatch, and output combining.
    """
    def __init__(self, gate: nn.Module, experts: nn.ModuleList, top_k: int = 2, capacity_factor: float = 1.25):
        super().__init__()
        self.gate = gate # The gating network (e.g., a Linear layer)
        self.experts = experts # A ModuleList of expert networks
        self.n_experts = len(experts)
        self.top_k = top_k
        self.capacity_factor = capacity_factor # Keep for potential future use/comparison, not strictly used in this simple version

        if top_k > self.n_experts:
            raise ValueError(f"top_k ({top_k}) cannot be greater than n_experts ({self.n_experts})")

        # Dictionary to store accumulated expert timings
        self.expert_timings: Dict[int, float] = {}

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """
        Forward pass of the simple MoE layer.

        Args:
            x: Input tensor [batch_size * seq_len, d_model] (assuming flattened input)

        Returns:
            output: Output tensor [batch_size * seq_len, d_model]
            aux_loss: Load balancing loss scalar
            metrics: Dictionary containing routing metrics
        """
        # Input shape is expected to be [num_tokens, d_model]
        num_tokens, d_model = x.shape
        device = x.device

        # Compute gate logits
        gate_logits = self.gate(x)  # [num_tokens, n_experts]

        # Get top-k experts and their probabilities
        top_k_values, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1) # [num_tokens, top_k]
        top_k_probs = F.softmax(top_k_values, dim=-1) # [num_tokens, top_k]

        # Compute load balancing loss (similar to the router's logic)
        gate_probs_all = F.softmax(gate_logits, dim=-1) # [num_tokens, n_experts]

        # Fraction of tokens assigned to each expert (based on top-1 for simplicity in loss)
        # For top-k, a more accurate loss might be needed, but this is a common approximation
        top1_indices = top_k_indices[:, 0] # [num_tokens]
        expert_mask_top1 = F.one_hot(top1_indices, num_classes=self.n_experts).float() # [num_tokens, n_experts]
        tokens_per_expert_for_loss = expert_mask_top1.sum(dim=0) # [n_experts]
        total_tokens_for_loss = num_tokens # Use total tokens for normalization

        # Normalize by total tokens
        fraction_per_expert = tokens_per_expert_for_loss / (total_tokens_for_loss + 1e-8) # Add epsilon for stability

        # Average gate probability for each expert (average over all tokens)
        avg_gate_prob = gate_probs_all.mean(dim=0) # [n_experts]

        # Load balancing loss
        # This formulation encourages both fraction_per_expert and avg_gate_prob to be uniform
        aux_loss = (fraction_per_expert * avg_gate_prob).sum() * self.n_experts


        # Dispatch tokens to experts
        output = torch.zeros_like(x) # [num_tokens, d_model]
        metrics: Dict[str, Any] = {} # Use Dict[str, Any] for flexibility

        # Collect metrics for profiling
        expert_usage_counts = torch.zeros(self.n_experts, device=device)
        expert_batch_timings: Dict[int, float] = {} # Timings for this batch

        # Loop through each expert
        for expert_id in range(self.n_experts):
            # Find tokens routed to this expert (at any of the top-k positions)
            expert_tokens_mask = (top_k_indices == expert_id).any(dim=-1) # [num_tokens]
            expert_token_indices = torch.where(expert_tokens_mask)[0] # Indices of tokens for this expert

            if expert_token_indices.numel() > 0:
                # Get inputs for this expert
                expert_input = x[expert_token_indices] # [num_expert_tokens, d_model]

                # Get weights for this expert's tokens
                expert_weights_for_tokens = torch.zeros(expert_token_indices.numel(), device=device)
                for i, token_idx in enumerate(expert_token_indices):
                    pos_in_topk = torch.where(top_k_indices[token_idx] == expert_id)[0]
                    if pos_in_topk.numel() > 0:
                         expert_weights_for_tokens[i] = top_k_probs[token_idx, pos_in_topk].sum()


                # Record start time for expert computation
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                    start_event = torch.cuda.Event(enable_timing=True)
                    end_event = torch.cuda.Event(enable_timing=True)
                    start_event.record()
                else:
                    start_time = time.time()


                # Run expert forward pass
                expert_output = self.experts[expert_id](expert_input) # [num_expert_tokens, d_model]

                # Record end time and calculate duration
                if torch.cuda.is_available():
                    end_event.record()
                    torch.cuda.synchronize()
                    duration_ms = start_event.elapsed_time(end_event)
                else:
                    duration_ms = (time.time() - start_time) * 1000.0 # Convert to ms

                expert_batch_timings[expert_id] = duration_ms
                self.expert_timings[expert_id] = self.expert_timings.get(expert_id, 0.0) + duration_ms # Accumulate timing


                # Weight expert output by the routing probability and scatter back
                weighted_expert_output = expert_output * expert_weights_for_tokens.unsqueeze(-1)

                # Scatter weighted output back to the original token positions
                output[expert_token_indices] += weighted_expert_output

                # Update usage counts
                expert_usage_counts[expert_id] = expert_token_indices.numel()


        # Add metrics to the dictionary
        metrics['expert_usage_current'] = expert_usage_counts.cpu().numpy()
        metrics['total_assignments'] = expert_usage_counts.sum().item() # Total tokens dispatched
        metrics['expert_batch_timings_ms'] = expert_batch_timings
        metrics['expert_cumulative_timings_ms'] = self.expert_timings


        return output, aux_loss, metrics


class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, n_experts: int = 8,
                 top_k: int = 2, dropout: float = 0.1, use_moe: bool = True, capacity_factor: float = 1.25):
        super().__init__()
        self.d_model = d_model
        self.use_moe = use_moe

        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)

        self.norm1 = nn.LayerNorm(d_model)

        if use_moe:
            # Gating network for the MoE layer
            gate_layer = nn.Linear(d_model, n_experts, bias=False)
             # Initialize gate weights
            nn.init.normal_(gate_layer.weight, mean=0.0, std=0.02)


            # Create a ModuleList with n_experts distinct instances of the expert_module
            experts_list = nn.ModuleList([
                nn.Sequential(nn.Linear(d_model, d_ff),
                                   nn.ReLU(),
                                   nn.Dropout(dropout),
                                   nn.Linear(d_ff, d_model)) for _ in range(n_experts)
            ])

            self.moe_layer = SimpleMoELayer(
                gate=gate_layer,
                experts=experts_list,
                top_k=top_k,
                capacity_factor=capacity_factor # Not used in SimpleMoELayer but kept for compatibility
            )

        else:
            self.feed_forward = nn.Sequential(nn.Linear(d_model, d_ff),
                                              nn.ReLU(),
                                              nn.Dropout(dropout),
                                              nn.Linear(d_ff, d_model))


        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.expert_timings = {} # Keep for potential future profiling

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        profile: bool = False
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Forward pass with optional profiling.

        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Attention mask
            profile: Whether to collect timing information

        Returns:
            output: Transformed tensor
            metrics: Dictionary containing routing metrics and timings
        """
        metrics = {}

        # Self-attention
        residual = x
        attn_out, attn_weights = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(residual + self.dropout(attn_out))

        # MoE or FFN
        residual = x

        if self.use_moe:
            # Reshape input for MoE layer: [batch_size * seq_len, d_model]
            batch_size, seq_len, d_model = x.shape
            x_flat = x.view(-1, d_model)

            # Profile MoE forward pass
            if profile and torch.cuda.is_available():
                torch.cuda.synchronize()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                start_event.record()

            # MoE forward pass using SimpleMoELayer
            moe_out_flat, aux_loss, moe_metrics = self.moe_layer(x_flat)

            if profile and torch.cuda.is_available():
                end_event.record()
                torch.cuda.synchronize()

                total_time = start_event.elapsed_time(end_event)
                metrics['moe_forward_time_ms'] = total_time
                metrics.update(moe_metrics) # Add simple moe metrics

            # Reshape output back to [batch_size, seq_len, d_model]
            moe_out = moe_out_flat.view(batch_size, seq_len, d_model)

            x = residual + self.dropout(moe_out)
            metrics['aux_loss'] = aux_loss

        else:
            # Standard FFN
            if profile and torch.cuda.is_available():
                torch.cuda.synchronize()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                start_event.record()
                ffn_out = self.feed_forward(x)
                end_event.record()
                torch.cuda.synchronize()

                metrics['ffn_time_ms'] = start_event.elapsed_time(end_event)
            else:
                ffn_out = self.feed_forward(x)

            x = residual + self.dropout(ffn_out)

        x = self.norm2(x)

        return x, metrics


class MoETransformer(nn.Module):
    """Simple MoE Transformer model."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 6,
        d_ff: int = 2048,
        n_experts: int = 8,
        top_k: int = 2,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        use_moe_layers: Optional[list] = None,  # Which layers use MoE
        capacity_factor: float = 1.25,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers

        # Determine which layers use MoE
        if use_moe_layers is None:
            # By default, use MoE in every other layer starting from layer 1
            use_moe_layers = [i % 2 == 1 for i in range(n_layers)]
        elif len(use_moe_layers) != n_layers:
             raise ValueError(f"Length of use_moe_layers ({len(use_moe_layers)}) must match n_layers ({n_layers})")


        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                n_experts=n_experts,
                top_k=top_k,
                dropout=dropout,
                use_moe=use_moe_layers[i],
                capacity_factor=capacity_factor,
            )
            for i in range(n_layers)
        ])

        # Output projection
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        profile: bool = False
    ) -> Dict:
        """
        Forward pass.

        Args:
            input_ids: Token indices [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            profile: Whether to collect profiling information

        Returns:
            Dictionary containing logits, aux_loss, and optional metrics
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Embeddings
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)

        # Attention mask for causal modeling
        if attention_mask is None:
            # Create causal mask
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=device), diagonal=1
            ).bool()
        else:
            causal_mask = attention_mask

        total_aux_loss = 0.0
        all_metrics = {} if profile else None

        for i, layer in enumerate(self.layers):
            # Pass profiler instance to the layer if needed, or handle profiling inside layer
            x, layer_metrics = layer(x, mask=causal_mask, profile=profile)

            # Accumulate auxiliary loss from MoE layers
            if 'aux_loss' in layer_metrics:
                total_aux_loss += layer_metrics['aux_loss']

            # Collect metrics
            if profile:
                for key, value in layer_metrics.items():
                    if key != 'aux_loss':
                        # Append metrics to lists if they exist, otherwise create
                        metric_key = f'layer_{i}__{key}'
                        if metric_key in all_metrics:
                            if isinstance(all_metrics[metric_key], list):
                                all_metrics[metric_key].append(value)
                            else: # Convert to list if first time appending
                                all_metrics[metric_key] = [all_metrics[metric_key], value]
                        else:
                            all_metrics[metric_key] = value


        # Final layer norm and projection
        x = self.ln_f(x)
        logits = self.head(x)

        output = {
            'logits': logits,
            'aux_loss': total_aux_loss,
        }

        if profile:
            output['metrics'] = all_metrics

        return output


# Example usage and testing
if __name__ == "__main__":
    # Test the model
    model = MoETransformer(
        vocab_size=1000,
        d_model=256,
        n_heads=8,
        n_layers=4,
        d_ff=1024,
        n_experts=4,
        top_k=2,
        use_moe_layers=[False, True, False, True],  # MoE in layers 1 and 3
    )

    # Test input
    batch_size, seq_len = 2, 32
    input_ids = torch.randint(0, 1000, (batch_size, seq_len))

    # Forward pass with profiling
    with torch.no_grad():
        output = model(input_ids, profile=True)

    print(f"Output logits shape: {output['logits'].shape}")
    print(f"Auxiliary loss: {output['aux_loss'].item():.6f}")

    if 'metrics' in output:
        print("\nProfiling metrics:")
        for key, value in output['metrics'].items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: {value.shape}")
            elif isinstance(value, list):
                 print(f"  {key}: {value}") # Print list content
            elif isinstance(value, dict):
                 print(f"  {key}: {value}") # Print dict content
            else:
                print(f"  {key}: {value}")

Output logits shape: torch.Size([2, 32, 1000])
Auxiliary loss: 2.034574

Profiling metrics:
  layer_0__ffn_time_ms: 2.3575680255889893
  layer_1__moe_forward_time_ms: 9.569855690002441
  layer_1__expert_usage_current: [28. 31. 38. 31.]
  layer_1__total_assignments: 128.0
  layer_1__expert_batch_timings_ms: {0: 0.7941120266914368, 1: 0.9747200012207031, 2: 0.9547520279884338, 3: 0.8738560080528259}
  layer_1__expert_cumulative_timings_ms: {0: 0.7941120266914368, 1: 0.9747200012207031, 2: 0.9547520279884338, 3: 0.8738560080528259}
  layer_2__ffn_time_ms: 1.3054399490356445
  layer_3__moe_forward_time_ms: 9.73363208770752
  layer_3__expert_usage_current: [28. 30. 24. 46.]
  layer_3__total_assignments: 128.0
  layer_3__expert_batch_timings_ms: {0: 0.9148160219192505, 1: 0.9413120150566101, 2: 0.8222399950027466, 3: 1.049407958984375}
  layer_3__expert_cumulative_timings_ms: {0: 0.9148160219192505, 1: 0.9413120150566101, 2: 0.8222399950027466, 3: 1.049407958984375}


In [14]:
import torch
import torch.distributed as dist
import os

# Set environment variables for distributed training
# In a real distributed setup, these would be set differently on each node/process
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500' # Use a free port
os.environ['RANK'] = '0' # Rank of the current process
os.environ['WORLD_SIZE'] = '1' # Total number of processes

# Initialize the distributed environment
if not dist.is_initialized():
    dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo")
    print(f"Distributed environment initialized with backend: {dist.get_backend()}")
    print(f"Rank {dist.get_rank()} out of {dist.get_world_size()} processes")
else:
    print("Distributed environment already initialized.")

Distributed environment already initialized.


# Task
Integrate GPU profiling and thermal signal generation into the provided model evaluation code. Modify the `SimpleMoELayer` and `TransformerBlock` to collect detailed expert usage metrics. Update the evaluation loop to use `GPUProfiler` and `ThermalSignalGenerator`, collect metrics, and demonstrate basic analysis.

## Refine `simplemoelayer` for detailed profiling

### Subtask:
Modify the `SimpleMoELayer` class to record the execution time and token counts for each expert it dispatches to.


**Reasoning**:
Modify the `SimpleMoELayer` class to include expert timing and token count collection within its forward pass, storing this information in the returned metrics dictionary.



In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional, List
import time
import numpy as np

class SimpleMoELayer(nn.Module):
    """
    A simplified, non-distributed MoE layer for profiling.
    Handles top-k routing, expert dispatch, and output combining.
    """
    def __init__(self, gate: nn.Module, experts: nn.ModuleList, top_k: int = 2, capacity_factor: float = 1.25):
        super().__init__()
        self.gate = gate # The gating network (e.g., a Linear layer)
        self.experts = experts # A ModuleList of expert networks
        self.n_experts = len(experts)
        self.top_k = top_k
        self.capacity_factor = capacity_factor # Keep for potential future use/comparison, not strictly used in this simple version

        if top_k > self.n_experts:
            raise ValueError(f"top_k ({top_k}) cannot be greater than n_experts ({self.n_experts})")

        # Dictionary to store accumulated expert timings
        self.expert_timings: Dict[int, float] = {}

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """
        Forward pass of the simple MoE layer.

        Args:
            x: Input tensor [batch_size * seq_len, d_model] (assuming flattened input)

        Returns:
            output: Output tensor [batch_size * seq_len, d_model]
            aux_loss: Load balancing loss scalar
            metrics: Dictionary containing routing metrics
        """
        # Input shape is expected to be [num_tokens, d_model]
        num_tokens, d_model = x.shape
        device = x.device

        # Compute gate logits
        gate_logits = self.gate(x)  # [num_tokens, n_experts]

        # Get top-k experts and their probabilities
        top_k_values, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1) # [num_tokens, top_k]
        top_k_probs = F.softmax(top_k_values, dim=-1) # [num_tokens, top_k]

        # Compute load balancing loss (similar to the router's logic)
        gate_probs_all = F.softmax(gate_logits, dim=-1) # [num_tokens, n_experts]

        # Fraction of tokens assigned to each expert (based on top-1 for simplicity in loss)
        # For top-k, a more accurate loss might be needed, but this is a common approximation
        top1_indices = top_k_indices[:, 0] # [num_tokens]
        expert_mask_top1 = F.one_hot(top1_indices, num_classes=self.n_experts).float() # [num_tokens, n_experts]
        tokens_per_expert_for_loss = expert_mask_top1.sum(dim=0) # [n_experts]
        total_tokens_for_loss = num_tokens # Use total tokens for normalization

        # Normalize by total tokens
        fraction_per_expert = tokens_per_expert_for_loss / (total_tokens_for_loss + 1e-8) # Add epsilon for stability

        # Average gate probability for each expert (average over all tokens)
        avg_gate_prob = gate_probs_all.mean(dim=0) # [n_experts]

        # Load balancing loss
        # This formulation encourages both fraction_per_expert and avg_gate_prob to be uniform
        aux_loss = (fraction_per_expert * avg_gate_prob).sum() * self.n_experts


        # Dispatch tokens to experts
        output = torch.zeros_like(x) # [num_tokens, d_model]
        metrics: Dict[str, Any] = {} # Use Dict[str, Any] for flexibility

        # Collect metrics for profiling
        expert_usage_counts = torch.zeros(self.n_experts, device=device)
        expert_batch_timings: Dict[int, float] = {} # Timings for this batch

        # Loop through each expert
        for expert_id in range(self.n_experts):
            # Find tokens routed to this expert (at any of the top-k positions)
            expert_tokens_mask = (top_k_indices == expert_id).any(dim=-1) # [num_tokens]
            expert_token_indices = torch.where(expert_tokens_mask)[0] # Indices of tokens for this expert

            if expert_token_indices.numel() > 0:
                # Get inputs for this expert
                expert_input = x[expert_token_indices] # [num_expert_tokens, d_model]

                # Get weights for this expert's tokens
                expert_weights_for_tokens = torch.zeros(expert_token_indices.numel(), device=device)
                for i, token_idx in enumerate(expert_token_indices):
                    pos_in_topk = torch.where(top_k_indices[token_idx] == expert_id)[0]
                    if pos_in_topk.numel() > 0:
                         expert_weights_for_tokens[i] = top_k_probs[token_idx, pos_in_topk].sum()


                # Record start time for expert computation
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                    start_event = torch.cuda.Event(enable_timing=True)
                    end_event = torch.cuda.Event(enable_timing=True)
                    start_event.record()
                else:
                    start_time = time.time()


                # Run expert forward pass
                expert_output = self.experts[expert_id](expert_input) # [num_expert_tokens, d_model]

                # Record end time and calculate duration
                if torch.cuda.is_available():
                    end_event.record()
                    torch.cuda.synchronize()
                    duration_ms = start_event.elapsed_time(end_event)
                else:
                    duration_ms = (time.time() - start_time) * 1000.0 # Convert to ms

                expert_batch_timings[expert_id] = duration_ms
                self.expert_timings[expert_id] = self.expert_timings.get(expert_id, 0.0) + duration_ms # Accumulate timing


                # Weight expert output by the routing probability and scatter back
                weighted_expert_output = expert_output * expert_weights_for_tokens.unsqueeze(-1)

                # Scatter weighted output back to the original token positions
                output[expert_token_indices] += weighted_expert_output

                # Update usage counts
                expert_usage_counts[expert_id] = expert_token_indices.numel()


        # Add metrics to the dictionary
        metrics['expert_usage_current'] = expert_usage_counts.cpu().numpy()
        metrics['total_assignments'] = expert_usage_counts.sum().item() # Total tokens dispatched
        metrics['expert_batch_timings_ms'] = expert_batch_timings
        metrics['expert_cumulative_timings_ms'] = self.expert_timings


        return output, aux_loss, metrics


class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, n_experts: int = 8,
                 top_k: int = 2, dropout: float = 0.1, use_moe: bool = True, capacity_factor: float = 1.25):
        super().__init__()
        self.d_model = d_model
        self.use_moe = use_moe

        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)

        self.norm1 = nn.LayerNorm(d_model)

        if use_moe:
            # Gating network for the MoE layer
            gate_layer = nn.Linear(d_model, n_experts, bias=False)
             # Initialize gate weights
            nn.init.normal_(gate_layer.weight, mean=0.0, std=0.02)


            # Create a ModuleList with n_experts distinct instances of the expert_module
            experts_list = nn.ModuleList([
                nn.Sequential(nn.Linear(d_model, d_ff),
                                   nn.ReLU(),
                                   nn.Dropout(dropout),
                                   nn.Linear(d_ff, d_model)) for _ in range(n_experts)
            ])

            self.moe_layer = SimpleMoELayer(
                gate=gate_layer,
                experts=experts_list,
                top_k=top_k,
                capacity_factor=capacity_factor # Not used in SimpleMoELayer but kept for compatibility
            )

        else:
            self.feed_forward = nn.Sequential(nn.Linear(d_model, d_ff),
                                              nn.ReLU(),
                                              nn.Dropout(dropout),
                                              nn.Linear(d_ff, d_model))


        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.expert_timings = {} # Keep for potential future profiling

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        profile: bool = False
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Forward pass with optional profiling.

        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Attention mask
            profile: Whether to collect timing information

        Returns:
            output: Transformed tensor
            metrics: Dictionary containing routing metrics and timings
        """
        metrics = {}

        # Self-attention
        residual = x
        attn_out, attn_weights = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(residual + self.dropout(attn_out))

        # MoE or FFN
        residual = x

        if self.use_moe:
            # Reshape input for MoE layer: [batch_size * seq_len, d_model]
            batch_size, seq_len, d_model = x.shape
            x_flat = x.view(-1, d_model)

            # Profile MoE forward pass
            if profile and torch.cuda.is_available():
                torch.cuda.synchronize()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                start_event.record()

            # MoE forward pass using SimpleMoELayer
            moe_out_flat, aux_loss, moe_metrics = self.moe_layer(x_flat)

            if profile and torch.cuda.is_available():
                end_event.record()
                torch.cuda.synchronize()

                total_time = start_event.elapsed_time(end_event)
                metrics['moe_forward_time_ms'] = total_time
                metrics.update(moe_metrics) # Add simple moe metrics

            # Reshape output back to [batch_size, seq_len, d_model]
            moe_out = moe_out_flat.view(batch_size, seq_len, d_model)

            x = residual + self.dropout(moe_out)
            metrics['aux_loss'] = aux_loss

        else:
            # Standard FFN
            if profile and torch.cuda.is_available():
                torch.cuda.synchronize()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                start_event.record()
                ffn_out = self.feed_forward(x)
                end_event.record()
                torch.cuda.synchronize()

                metrics['ffn_time_ms'] = start_event.elapsed_time(end_event)
            else:
                ffn_out = self.feed_forward(x)

            x = residual + self.dropout(ffn_out)

        x = self.norm2(x)

        return x, metrics


class MoETransformer(nn.Module):
    """Simple MoE Transformer model."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 6,
        d_ff: int = 2048,
        n_experts: int = 8,
        top_k: int = 2,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        use_moe_layers: Optional[list] = None,  # Which layers use MoE
        capacity_factor: float = 1.25,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers

        # Determine which layers use MoE
        if use_moe_layers is None:
            # By default, use MoE in every other layer starting from layer 1
            use_moe_layers = [i % 2 == 1 for i in range(n_layers)]
        elif len(use_moe_layers) != n_layers:
             raise ValueError(f"Length of use_moe_layers ({len(use_moe_layers)}) must match n_layers ({n_layers})")


        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                n_experts=n_experts,
                top_k=top_k,
                dropout=dropout,
                use_moe=use_moe_layers[i],
                capacity_factor=capacity_factor,
            )
            for i in range(n_layers)
        ])

        # Output projection
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        profile: bool = False
    ) -> Dict:
        """
        Forward pass.

        Args:
            input_ids: Token indices [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            profile: Whether to collect profiling information

        Returns:
            Dictionary containing logits, aux_loss, and optional metrics
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Embeddings
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)

        # Attention mask for causal modeling
        if attention_mask is None:
            # Create causal mask
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=device), diagonal=1
            ).bool()
        else:
            causal_mask = attention_mask

        total_aux_loss = 0.0
        all_metrics = {} if profile else None

        for i, layer in enumerate(self.layers):
            # Pass profiler instance to the layer if needed, or handle profiling inside layer
            x, layer_metrics = layer(x, mask=causal_mask, profile=profile)

            # Accumulate auxiliary loss from MoE layers
            if 'aux_loss' in layer_metrics:
                total_aux_loss += layer_metrics['aux_loss']

            # Collect metrics
            if profile:
                for key, value in layer_metrics.items():
                    if key != 'aux_loss':
                        # Append metrics to lists if they exist, otherwise create
                        metric_key = f'layer_{i}__{key}'
                        if metric_key in all_metrics:
                            if isinstance(all_metrics[metric_key], list) or isinstance(all_metrics[metric_key], dict):
                                all_metrics[metric_key].append(value)
                            else: # Convert to list if first time appending
                                all_metrics[metric_key] = [all_metrics[metric_key], value]
                        else:
                            all_metrics[metric_key] = value


        # Final layer norm and projection
        x = self.ln_f(x)
        logits = self.head(x)

        output = {
            'logits': logits,
            'aux_loss': total_aux_loss,
        }

        if profile:
            output['metrics'] = all_metrics

        return output

# Example usage and testing (optional, moved from main execution block)
if __name__ == "__main__":
    # Test the model
    model = MoETransformer(
        vocab_size=1000,
        d_model=256,
        n_heads=8,
        n_layers=4,
        d_ff=1024,
        n_experts=4,
        top_k=2,
        use_moe_layers=[False, True, False, True],  # MoE in layers 1 and 3
    )

    # Test input
    batch_size, seq_len = 2, 32
    input_ids = torch.randint(0, 1000, (batch_size, seq_len))

    # Forward pass with profiling
    with torch.no_grad():
        output = model(input_ids, profile=True)

    print(f"Output logits shape: {output['logits'].shape}")
    print(f"Auxiliary loss: {output['aux_loss'].item():.6f}")

    if 'metrics' in output:
        print("\nProfiling metrics:")
        for key, value in output['metrics'].items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: {value.shape}")
            elif isinstance(value, list):
                 print(f"  {key}: {value}") # Print list content
            elif isinstance(value, dict):
                 print(f"  {key}: {value}") # Print dict content
            else:
                print(f"  {key}: {value}")


Output logits shape: torch.Size([2, 32, 1000])
Auxiliary loss: 2.008421

Profiling metrics:
  layer_0__ffn_time_ms: 1.5088000297546387
  layer_1__moe_forward_time_ms: 9.496352195739746
  layer_1__expert_usage_current: [30. 33. 32. 33.]
  layer_1__total_assignments: 128.0
  layer_1__expert_batch_timings_ms: {0: 0.8581439852714539, 1: 0.942080020904541, 2: 0.9077439904212952, 3: 0.8822720050811768}
  layer_1__expert_cumulative_timings_ms: {0: 0.8581439852714539, 1: 0.942080020904541, 2: 0.9077439904212952, 3: 0.8822720050811768}
  layer_2__ffn_time_ms: 1.1815999746322632
  layer_3__moe_forward_time_ms: 9.417920112609863
  layer_3__expert_usage_current: [25. 34. 35. 34.]
  layer_3__total_assignments: 128.0
  layer_3__expert_batch_timings_ms: {0: 0.8087360262870789, 1: 0.9081599712371826, 2: 0.9254720211029053, 3: 0.9127039909362793}
  layer_3__expert_cumulative_timings_ms: {0: 0.8087360262870789, 1: 0.9081599712371826, 2: 0.9254720211029053, 3: 0.9127039909362793}


**Reasoning**:
The `SimpleMoELayer` class has been successfully modified to include expert timing and token counts in the metrics. The next step is to update the evaluation loop to utilize the `GPUProfiler` and `ThermalSignalGenerator` and collect the detailed metrics from the model's forward pass. This involves instantiating the `ThermalSignalGenerator` and modifying the `evaluate_model` function to get thermal signals and aggregate the detailed MoE metrics.



**Reasoning**:
The ThermalSignalGenerator failed to load the cost table because the file does not exist. Create a dummy `energy/cost_table.json` file with default values to allow the `ThermalSignalGenerator` to initialize successfully.



**Reasoning**:
Now that the dummy cost table exists, the `ThermalSignalGenerator` should be able to initialize correctly. The next step is to rerun the evaluation loop code to integrate the GPU profiling and thermal signal generation.



**Reasoning**:
The traceback indicates a `NameError` because `avg_inference_time_per_batch_ms` was used in the logging message outside the scope where it was defined within the `evaluate_model` function. This variable is part of the `results` dictionary returned by the function. Access the variable from the `baseline_results` dictionary instead.



## Update `transformerblock` to pass detailed metrics

### Subtask:
Update the `TransformerBlock` to collect the detailed expert metrics from the `SimpleMoELayer` and include them in the `layer_metrics` dictionary returned by its `forward` method.


**Reasoning**:
Update the TransformerBlock's forward method to include the detailed expert metrics from the SimpleMoELayer in its returned metrics dictionary. The previous subtask modified the SimpleMoELayer to return these metrics, and this step ensures they are propagated correctly through the TransformerBlock.



In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional, List, Any
import time
import numpy as np

class SimpleMoELayer(nn.Module):
    """
    A simplified, non-distributed MoE layer for profiling.
    Handles top-k routing, expert dispatch, and output combining.
    """
    def __init__(self, gate: nn.Module, experts: nn.ModuleList, top_k: int = 2, capacity_factor: float = 1.25):
        super().__init__()
        self.gate = gate # The gating network (e.g., a Linear layer)
        self.experts = experts # A ModuleList of expert networks
        self.n_experts = len(experts)
        self.top_k = top_k
        self.capacity_factor = capacity_factor # Keep for potential future use/comparison, not strictly used in this simple version

        if top_k > self.n_experts:
            raise ValueError(f"top_k ({top_k}) cannot be greater than n_experts ({self.n_experts})")

        # Dictionary to store accumulated expert timings
        self.expert_timings: Dict[int, float] = {}

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """
        Forward pass of the simple MoE layer.

        Args:
            x: Input tensor [batch_size * seq_len, d_model] (assuming flattened input)

        Returns:
            output: Output tensor [batch_size * seq_len, d_model]
            aux_loss: Load balancing loss scalar
            metrics: Dictionary containing routing metrics
        """
        # Input shape is expected to be [num_tokens, d_model]
        num_tokens, d_model = x.shape
        device = x.device

        # Compute gate logits
        gate_logits = self.gate(x)  # [num_tokens, n_experts]

        # Get top-k experts and their probabilities
        top_k_values, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1) # [num_tokens, top_k]
        top_k_probs = F.softmax(top_k_values, dim=-1) # [num_tokens, top_k]

        # Compute load balancing loss (similar to the router's logic)
        gate_probs_all = F.softmax(gate_logits, dim=-1) # [num_tokens, n_experts]

        # Fraction of tokens assigned to each expert (based on top-1 for simplicity in loss)
        # For top-k, a more accurate loss might be needed, but this is a common approximation
        top1_indices = top_k_indices[:, 0] # [num_tokens]
        expert_mask_top1 = F.one_hot(top1_indices, num_classes=self.n_experts).float() # [num_tokens, n_experts]
        tokens_per_expert_for_loss = expert_mask_top1.sum(dim=0) # [n_experts]
        total_tokens_for_loss = num_tokens # Use total tokens for normalization

        # Normalize by total tokens
        fraction_per_expert = tokens_per_expert_for_loss / (total_tokens_for_loss + 1e-8) # Add epsilon for stability

        # Average gate probability for each expert (average over all tokens)
        avg_gate_prob = gate_probs_all.mean(dim=0) # [n_experts]

        # Load balancing loss
        # This formulation encourages both fraction_per_expert and avg_gate_prob to be uniform
        aux_loss = (fraction_per_expert * avg_gate_prob).sum() * self.n_experts


        # Dispatch tokens to experts
        output = torch.zeros_like(x) # [num_tokens, d_model]
        metrics: Dict[str, Any] = {} # Use Dict[str, Any] for flexibility

        # Collect metrics for profiling
        expert_usage_counts = torch.zeros(self.n_experts, device=device)
        expert_batch_timings: Dict[int, float] = {} # Timings for this batch

        # Loop through each expert
        for expert_id in range(self.n_experts):
            # Find tokens routed to this expert (at any of the top-k positions)
            expert_tokens_mask = (top_k_indices == expert_id).any(dim=-1) # [num_tokens]
            expert_token_indices = torch.where(expert_tokens_mask)[0] # Indices of tokens for this expert

            if expert_token_indices.numel() > 0:
                # Get inputs for this expert
                expert_input = x[expert_token_indices] # [num_expert_tokens, d_model]

                # Get weights for this expert's tokens
                expert_weights_for_tokens = torch.zeros(expert_token_indices.numel(), device=device)
                for i, token_idx in enumerate(expert_token_indices):
                    pos_in_topk = torch.where(top_k_indices[token_idx] == expert_id)[0]
                    if pos_in_topk.numel() > 0:
                         expert_weights_for_tokens[i] = top_k_probs[token_idx, pos_in_topk].sum()


                # Record start time for expert computation
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                    start_event = torch.cuda.Event(enable_timing=True)
                    end_event = torch.cuda.Event(enable_timing=True)
                    start_event.record()
                else:
                    start_time = time.time()


                # Run expert forward pass
                expert_output = self.experts[expert_id](expert_input) # [num_expert_tokens, d_model]

                # Record end time and calculate duration
                if torch.cuda.is_available():
                    end_event.record()
                    torch.cuda.synchronize()
                    duration_ms = start_event.elapsed_time(end_event)
                else:
                    duration_ms = (time.time() - start_time) * 1000.0 # Convert to ms

                expert_batch_timings[expert_id] = duration_ms
                self.expert_timings[expert_id] = self.expert_timings.get(expert_id, 0.0) + duration_ms # Accumulate timing


                # Weight expert output by the routing probability and scatter back
                weighted_expert_output = expert_output * expert_weights_for_tokens.unsqueeze(-1)

                # Scatter weighted output back to the original token positions
                output[expert_token_indices] += weighted_expert_output

                # Update usage counts
                expert_usage_counts[expert_id] = expert_token_indices.numel()


        # Add metrics to the dictionary
        metrics['expert_usage_current'] = expert_usage_counts.cpu().numpy()
        metrics['total_assignments'] = expert_usage_counts.sum().item() # Total tokens dispatched
        metrics['expert_batch_timings_ms'] = expert_batch_timings
        metrics['expert_cumulative_timings_ms'] = self.expert_timings


        return output, aux_loss, metrics


class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, n_experts: int = 8,
                 top_k: int = 2, dropout: float = 0.1, use_moe: bool = True, capacity_factor: float = 1.25):
        super().__init__()
        self.d_model = d_model
        self.use_moe = use_moe

        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)

        self.norm1 = nn.LayerNorm(d_model)

        if use_moe:
            # Gating network for the MoE layer
            gate_layer = nn.Linear(d_model, n_experts, bias=False)
             # Initialize gate weights
            nn.init.normal_(gate_layer.weight, mean=0.0, std=0.02)


            # Create a ModuleList with n_experts distinct instances of the expert_module
            experts_list = nn.ModuleList([
                nn.Sequential(nn.Linear(d_model, d_ff),
                                   nn.ReLU(),
                                   nn.Dropout(dropout),
                                   nn.Linear(d_ff, d_model)) for _ in range(n_experts)
            ])

            self.moe_layer = SimpleMoELayer(
                gate=gate_layer,
                experts=experts_list,
                top_k=top_k,
                capacity_factor=capacity_factor # Not used in SimpleMoELayer but kept for compatibility
            )

        else:
            self.feed_forward = nn.Sequential(nn.Linear(d_model, d_ff),
                                              nn.ReLU(),
                                              nn.Dropout(dropout),
                                              nn.Linear(d_ff, d_model))


        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        profile: bool = False
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Forward pass with optional profiling.

        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Attention mask
            profile: Whether to collect timing information

        Returns:
            output: Transformed tensor
            metrics: Dictionary containing routing metrics and timings
        """
        metrics = {}

        # Self-attention
        residual = x
        attn_out, attn_weights = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(residual + self.dropout(attn_out))

        # MoE or FFN
        residual = x

        if self.use_moe:
            # Reshape input for MoE layer: [batch_size * seq_len, d_model]
            batch_size, seq_len, d_model = x.shape
            x_flat = x.view(-1, d_model)

            # Profile MoE forward pass
            if profile and torch.cuda.is_available():
                torch.cuda.synchronize()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                start_event.record()

            # MoE forward pass using SimpleMoELayer
            # moe_metrics now includes expert-level timings and usage
            moe_out_flat, aux_loss, moe_metrics = self.moe_layer(x_flat)

            if profile and torch.cuda.is_available():
                end_event.record()
                torch.cuda.synchronize()

                total_time = start_event.elapsed_time(end_event)
                metrics['moe_forward_time_ms'] = total_time

            # Add all metrics from the moe_layer to the transformer block's metrics
            metrics.update(moe_metrics)

            # Reshape output back to [batch_size, seq_len, d_model]
            moe_out = moe_out_flat.view(batch_size, seq_len, d_model)

            x = residual + self.dropout(moe_out)
            metrics['aux_loss'] = aux_loss

        else:
            # Standard FFN
            if profile and torch.cuda.is_available():
                torch.cuda.synchronize()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                start_event.record()
                ffn_out = self.feed_forward(x)
                end_event.record()
                torch.cuda.synchronize()

                metrics['ffn_time_ms'] = start_event.elapsed_time(end_event)
            else:
                ffn_out = self.feed_forward(x)

            x = residual + self.dropout(ffn_out)

        x = self.norm2(x)

        return x, metrics


class MoETransformer(nn.Module):
    """Simple MoE Transformer model."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 6,
        d_ff: int = 2048,
        n_experts: int = 8,
        top_k: int = 2,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        use_moe_layers: Optional[list] = None,  # Which layers use MoE
        capacity_factor: float = 1.25,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers

        # Determine which layers use MoE
        if use_moe_layers is None:
            # By default, use MoE in every other layer starting from layer 1
            use_moe_layers = [i % 2 == 1 for i in range(n_layers)]
        elif len(use_moe_layers) != n_layers:
             raise ValueError(f"Length of use_moe_layers ({len(use_moe_layers)}) must match n_layers ({n_layers})")


        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                n_experts=n_experts,
                top_k=top_k,
                dropout=dropout,
                use_moe=use_moe_layers[i],
                capacity_factor=capacity_factor,
            )
            for i in range(n_layers)
        ])

        # Output projection
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        profile: bool = False
    ) -> Dict:
        """
        Forward pass.

        Args:
            input_ids: Token indices [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            profile: Whether to collect profiling information

        Returns:
            Dictionary containing logits, aux_loss, and optional metrics
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Embeddings
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)

        # Attention mask for causal modeling
        if attention_mask is None:
            # Create causal mask
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=device), diagonal=1
            ).bool()
        else:
            causal_mask = attention_mask

        total_aux_loss = 0.0
        all_metrics = {} if profile else None

        for i, layer in enumerate(self.layers):
            # Pass profiler instance to the layer if needed, or handle profiling inside layer
            x, layer_metrics = layer(x, mask=causal_mask, profile=profile)

            # Accumulate auxiliary loss from MoE layers
            if 'aux_loss' in layer_metrics:
                total_aux_loss += layer_metrics['aux_loss']

            # Collect metrics
            if profile:
                for key, value in layer_metrics.items():
                    if key != 'aux_loss':
                        # Append metrics to lists if they exist, otherwise create
                        metric_key = f'layer_{i}__{key}'
                        if metric_key in all_metrics:
                            if isinstance(all_metrics[metric_key], list) or isinstance(all_metrics[metric_key], dict):
                                all_metrics[metric_key].append(value)
                            else: # Convert to list if first time appending
                                all_metrics[metric_key] = [all_metrics[metric_key], value]
                        else:
                            all_metrics[metric_key] = value


        # Final layer norm and projection
        x = self.ln_f(x)
        logits = self.head(x)

        output = {
            'logits': logits,
            'aux_loss': total_aux_loss,
        }

        if profile:
            output['metrics'] = all_metrics

        return output

In [18]:
import json
from pathlib import Path

# Define default cost table data
default_cost_table = {
    "thermal_parameters": {
        "base_temperature": 45.0,
        "warm_temperature_threshold": 65.0,
        "hot_temperature_threshold": 83.0,
        "critical_temperature": 87.0,
        "thermal_time_constant": 15.0,
        "emergency_cooldown_duration": 30.0
    },
    "energy_budgets": {
        "low_power": {"max_power_watts": 200, "max_temperature": 75.0, "max_energy_per_token_mj": 5.0},
        "balanced": {"max_power_watts": 350, "max_temperature": 80.0, "max_energy_per_token_mj": 3.0},
        "performance": {"max_power_watts": 450, "max_temperature": 85.0, "max_energy_per_token_mj": 1.5},
        "emergency": {"max_power_watts": 100, "max_temperature": 70.0, "max_energy_per_token_mj": 10.0} # Added emergency budget
    },
    "expert_profiles": {
        "expert_0": {"average_power_watts": 50, "energy_per_token_mj": 2.0, "thermal_impact": 0.1},
        "expert_1": {"average_power_watts": 70, "energy_per_token_mj": 2.5, "thermal_impact": 0.15},
        "expert_2": {"average_power_watts": 30, "energy_per_token_mj": 1.0, "thermal_impact": 0.05},
        "expert_3": {"average_power_watts": 60, "energy_per_token_mj": 2.2, "thermal_impact": 0.12},
        "expert_4": {"average_power_watts": 55, "energy_per_token_mj": 1.8, "thermal_impact": 0.09},
        "expert_5": {"average_power_watts": 75, "energy_per_token_mj": 2.8, "thermal_impact": 0.18},
        "expert_6": {"average_power_watts": 40, "energy_per_token_mj": 1.2, "thermal_impact": 0.06},
        "expert_7": {"average_power_watts": 65, "energy_per_token_mj": 2.3, "thermal_impact": 0.14}
    }
}

# Define the directory and file path
file_path = Path("energy/cost_table.json")

# Check if the file exists
if not file_path.exists():
    # Create the directory if it doesn't exist
    file_path.parent.mkdir(parents=True, exist_ok=True)

    # Write the default data to the JSON file
    with open(file_path, 'w') as f:
        json.dump(default_cost_table, f, indent=2)

    print(f"Created dummy cost table at: {file_path}")
else:
    print(f"Cost table already exists at: {file_path}")

Created dummy cost table at: energy/cost_table.json


## Integrate `gpuprofiler` in the evaluation loop

### Subtask:
Integrate `gpuprofiler` in the evaluation loop.


**Reasoning**:
Modify the evaluate_model function to accept the GPUProfiler instance and collect GPU metrics periodically, storing them for later averaging and logging.



## Integrate `thermalsignalgenerator` in the evaluation loop

### Subtask:
Integrate `thermalsignalgenerator` in the evaluation loop.


**Reasoning**:
Modify the evaluate\_model function to accept the thermal signal generator, call it within the loop, store the signals, and include them in the return dictionary, updating the logging accordingly. In the main block, initialize and pass the thermal signal generator.



In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from typing import Dict, Tuple, Optional, List, Any
import time
import logging
import math
from pathlib import Path
import numpy as np

class WikiText2Dataset(Dataset):
    # just a simulation for now, skeleton code
    def __init__(self, vocab_size: int = 1000, seq_len: int = 512, num_samples: int = 1000):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
        logging.info(f"Initialized WikiText2Dataset with {num_samples} samples, seq_len={seq_len}, vocab_size={vocab_size}")

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Simulate text data: random token IDs
        input_ids = torch.randint(0, self.vocab_size, (self.seq_len,))
        # For language modeling, the target is usually the next token
        labels = torch.cat([input_ids[1:], torch.tensor([0])]) # Simple shift

        return {"input_ids": input_ids, "labels": labels}

# --- Evaluation Loop ---

def evaluate_model(
    model: MoETransformer,
    dataloader: DataLoader,
    device: torch.device,
    profiler: GPUProfiler, # Accept GPUProfiler instance
    thermal_signal_generator: ThermalSignalGenerator, # Accept ThermalSignalGenerator instance
    log_interval: int = 10,
) -> Dict[str, Any]:
    """
    Performs an evaluation loop for the MoE model, logging inference time
    and GPU metrics, and collecting thermal signals.

    Args:
        model: The MoETransformer model.
        dataloader: DataLoader for the evaluation dataset.
        device: Device to run evaluation on (e.g., 'cuda' or 'cpu').
        profiler: GPUProfiler instance for logging metrics.
        thermal_signal_generator: ThermalSignalGenerator instance.
        log_interval: How often to log progress and metrics.

    Returns:
        A dictionary containing average perplexity, total inference time,
        average power draw, and aggregated MoE metrics.
    """
    model.eval() # Set model to evaluation mode
    total_loss = 0.0
    total_tokens = 0
    total_batches = 0

    inference_times_ms = []
    power_draws_watts = []
    temperatures_c = []
    gpu_utilizations_percent = []

    # Aggregated MoE metrics across all layers and batches
    aggregated_moe_metrics: Dict[str, List[Any]] = {} # Use Any to handle dicts/lists

    thermal_signals: List[ThermalSignal] = []

    start_time = time.time()

    logging.info(f"Starting evaluation on device: {device}")

    with torch.no_grad(): # Disable gradient calculations
        for batch_idx, batch in enumerate(dataloader):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            # Get thermal signal before computation
            thermal_signal = thermal_signal_generator.get_thermal_signal()
            if thermal_signal:
                thermal_signals.append(thermal_signal)
                # Log thermal state and recommendations
                logging.info(
                    f"Batch {batch_idx+1} Thermal Signal: State={thermal_signal.thermal_state.value}, "
                    f"PowerMode={thermal_signal.power_mode.value}, Temp={thermal_signal.temperature:.1f}°C, "
                    f"Power={thermal_signal.power_draw:.1f}W"
                )
                # Note: In a real system, you would use these signals to adapt model behavior
                # (e.g., select different experts, apply throttle factors). For this
                # baseline evaluation, we just log them.


            # Measure inference time for the batch
            if torch.cuda.is_available():
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()

            # Forward pass with profiling enabled for MoE metrics
            model_output = model(input_ids, profile=True) # Enable profiling in model for detailed metrics
            logits = model_output['logits']
            aux_loss = model_output.get('aux_loss', torch.tensor(0.0)).item()
            metrics = model_output.get('metrics', {})

            if torch.cuda.is_available():
                end_event.record()
                torch.cuda.synchronize()
                batch_inference_time_ms = start_event.elapsed_time(end_event)
                inference_times_ms.append(batch_inference_time_ms)
            else:
                # Approximate time for CPU
                batch_inference_time_ms = (time.time() - start_time) * 1000 # Rough estimate
                inference_times_ms.append(batch_inference_time_ms)


            # Calculate loss (for perplexity)
            # Reshape logits and labels for CrossEntropyLoss
            logits_flat = logits.view(-1, logits.size(-1))
            labels_flat = labels.view(-1)
            loss = F.cross_entropy(logits_flat, labels_flat, ignore_index=0) # Assuming 0 is padding/ignore

            total_loss += loss.item() * labels.numel() # Accumulate loss weighted by number of elements
            total_tokens += labels.numel()
            total_batches += 1

            # Collect and store GPU metrics for this batch
            gpu_metrics = profiler.get_current_metrics()
            if gpu_metrics:
                power_draws_watts.append(gpu_metrics.power_draw)
                temperatures_c.append(gpu_metrics.temperature)
                gpu_utilizations_percent.append(gpu_metrics.gpu_utilization)

            # Aggregate MoE specific metrics
            if metrics:
                for key, value in metrics.items():
                    if isinstance(value, np.ndarray):
                        # Convert arrays to lists for consistent aggregation
                        value = value.tolist()

                    if key not in aggregated_moe_metrics:
                         aggregated_moe_metrics[key] = []

                    if isinstance(value, (int, float, list)):
                         aggregated_moe_metrics[key].append(value)
                    elif isinstance(value, dict):
                         # For dictionaries (like expert timings), aggregate per key
                         for sub_key, sub_value in value.items():
                             agg_key = f'{key}__{sub_key}' # e.g., 'expert_batch_timings_ms__0'
                             if agg_key not in aggregated_moe_metrics:
                                 aggregated_moe_metrics[agg_key] = []
                             aggregated_moe_metrics[agg_key].append(sub_value)
                    else:
                         logging.warning(f"Skipping aggregation for metric {key} with unsupported type {type(value)}")


            if (batch_idx + 1) % log_interval == 0:
                avg_batch_loss = total_loss / total_tokens if total_tokens > 0 else 0
                current_perplexity = math.exp(avg_batch_loss) if avg_batch_loss < 100 else float('inf') # Avoid overflow

                log_msg = (
                    f"Batch {batch_idx+1}/{len(dataloader)} | "
                    f"Loss: {avg_batch_loss:.4f} | "
                    f"Perplexity: {current_perplexity:.2f} | "
                    f"Batch Time: {batch_inference_time_ms:.2f} ms"
                )
                # Log collected GPU metrics at log interval
                if gpu_metrics:
                    log_msg += (
                        f" | Power: {gpu_metrics.power_draw:.1f}W | "
                        f"Temp: {gpu_metrics.temperature:.1f}°C | "
                        f"GPU Util: {gpu_metrics.gpu_utilization:.1f}%"
                    )
                logging.info(log_msg)

    end_time = time.time()
    total_inference_duration_sec = end_time - start_time

    # Calculate overall averages
    avg_loss = total_loss / total_tokens if total_tokens > 0 else 0
    final_perplexity = math.exp(avg_loss) if avg_loss < 100 else float('inf')

    # Calculate averages for collected GPU metrics
    avg_inference_time_ms = np.mean(inference_times_ms) if inference_times_ms else 0
    avg_power_draw_watts = np.mean(power_draws_watts) if power_draws_watts else 0
    avg_temperature_c = np.mean(temperatures_c) if temperatures_c else 0
    avg_gpu_utilization_percent = np.mean(gpu_utilizations_percent) if gpu_utilizations_percent else 0

    # Aggregate MoE metrics (e.g., average expert usage across batches)
    final_moe_metrics: Dict[str, Any] = {}
    for key, values_list in aggregated_moe_metrics.items():
        if not values_list:
            continue

        if isinstance(values_list[0], (int, float)):
            final_moe_metrics[f'avg_{key}'] = np.mean(values_list)
        elif isinstance(values_list[0], list) or isinstance(values_list[0], np.ndarray):
             # For list/array metrics (like expert_usage), average the lists/arrays
             # Ensure all lists/arrays have the same shape before averaging
             try:
                 final_moe_metrics[f'avg_{key}'] = np.mean([np.array(v) for v in values_list], axis=0).tolist()
             except Exception as e:
                 logging.warning(f"Could not average list/array metric {key}: {e}")
                 final_moe_metrics[f'raw_{key}'] = values_list # Store raw list if averaging fails
        elif isinstance(values_list[0], dict):
             # This case should ideally be handled by the sub_key aggregation above,
             # but as a fallback, log a warning.
             logging.warning(f"Metric {key} contains dictionaries, averaging not supported directly.")
             final_moe_metrics[f'raw_{key}'] = values_list # Store raw list of dicts

    results = {
        "final_perplexity": final_perplexity,
        "total_inference_duration_sec": total_inference_duration_sec,
        "avg_inference_time_per_batch_ms": avg_inference_time_ms,
        "avg_power_draw_watts": avg_power_draw_watts,
        "avg_temperature_c": avg_temperature_c,
        "avg_gpu_utilization_percent": avg_gpu_utilization_percent,
        "aggregated_moe_metrics": final_moe_metrics,
        "thermal_signals": thermal_signals # Include collected thermal signals
    }

    logging.info("\n--- Evaluation Summary ---")
    logging.info(f"Final Perplexity: {final_perplexity:.2f}")
    logging.info(f"Total Inference Duration: {total_inference_duration_sec:.2f} seconds")
    # Corrected access to avg_inference_time_per_batch_ms
    logging.info(f"Average Batch Inference Time: {results['avg_inference_time_per_batch_ms']:.2f} ms")
    if results['avg_power_draw_watts'] > 0:
        logging.info(f"Average Power Draw: {results['avg_power_draw_watts']:.1f} W")
        logging.info(f"Average Temperature: {results['avg_temperature_c']:.1f} °C")
        logging.info(f"Average GPU Utilization: {results['avg_gpu_utilization_percent']:.1f} %")
    logging.info("Aggregated MoE Metrics:")
    for k, v in final_moe_metrics.items():
        # Format array output nicely
        if isinstance(v, list) and all(isinstance(i, (int, float)) for i in v):
             logging.info(f"  {k}: {np.array(v)}")
        else:
             logging.info(f"  {k}: {v}")

    logging.info("\nCollected Thermal Signals:")
    for i, signal in enumerate(results['thermal_signals']):
         logging.info(f"  Signal {i+1}: Temp={signal.temperature:.1f}°C, Power={signal.power_draw:.1f}W, State={signal.thermal_state.value}")


    return results

# --- Main execution block ---

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    # 1. Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    # 2. Initialize GPUProfiler
    profiler = GPUProfiler()
    profiler.start_profiling() # Start the profiling thread

    # 3. Initialize ThermalSignalGenerator
    # Assuming a default cost table exists or is handled by the class
    thermal_signal_generator = ThermalSignalGenerator(profiler=profiler)
    logging.info("ThermalSignalGenerator initialized.")


    # 4. Model Parameters (Adjust as needed for your specific MoE setup)
    VOCAB_SIZE = 10000 # Example vocab size
    D_MODEL = 512
    N_HEADS = 8
    N_LAYERS = 6
    D_FF = 2048
    N_EXPERTS = 8
    TOP_K = 2
    MAX_SEQ_LEN = 512
    BATCH_SIZE = 4

    # Set which layers use MoE (e.g., every other layer)
    USE_MOE_LAYERS = [i % 2 == 1 for i in range(N_LAYERS)] # [False, True, False, True, False, True]

    # 5. Instantiate MoE Model
    logging.info("Initializing MoETransformer model...")
    model = MoETransformer(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        d_ff=D_FF,
        n_experts=N_EXPERTS,
        top_k=TOP_K,
        max_seq_len=MAX_SEQ_LEN,
        use_moe_layers=USE_MOE_LAYERS
    ).to(device)
    logging.info(f"Model instantiated with {sum(USE_MOE_LAYERS)} MoE layers.")

    # Optional: Load a pre-trained checkpoint if you have one
    # checkpoint_path = "path/to/your/checkpoint.pth"
    # if Path(checkpoint_path).exists():
    #     logging.info(f"Loading model checkpoint from {checkpoint_path}...")
    #     model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    #     logging.info("Model checkpoint loaded.")
    # else:
    #     logging.warning("No model checkpoint found. Using randomly initialized weights.")


    # 6. Prepare Dataset and DataLoader (using simulated WikiText-2 for baseline)
    # For actual WikiText-2, you'd use torchtext or similar to load and preprocess.
    # Example: from torchtext.datasets import WikiText2
    # For now, we use our dummy dataset.
    logging.info("Preparing dataset...")
    eval_dataset = WikiText2Dataset(vocab_size=VOCAB_SIZE, seq_len=MAX_SEQ_LEN, num_samples=100) # Use a small number of samples for baseline
    eval_dataloader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
    logging.info(f"Evaluation DataLoader ready with {len(eval_dataloader)} batches.")

    # 7. Run Evaluation
    logging.info("\n--- Starting Baseline Inference Evaluation ---")
    baseline_results = evaluate_model(
        model=model,
        dataloader=eval_dataloader,
        device=device,
        profiler=profiler, # Pass the profiler instance
        thermal_signal_generator=thermal_signal_generator, # Pass thermal signal generator
        log_interval=10
    )

    # 8. Final Sanity Checks and Cleanup
    logging.info("\n--- Sanity Checks ---")
    if baseline_results['final_perplexity'] < float('inf'):
        logging.info(f"Perplexity sanity check: {baseline_results['final_perplexity']:.2f} (lower is better, typically starts high for untrained models)")
    else:
        logging.warning("Perplexity is infinite. This might indicate issues like very high loss or training with random weights.")

    logging.info(f"Average Power Draw: {baseline_results['avg_power_draw_watts']:.2f} W")
    logging.info(f"Average Inference Time per Batch: {baseline_results['avg_inference_time_per_batch_ms']:.2f} ms")

    # Access detailed MoE metrics
    if 'aggregated_moe_metrics' in baseline_results:
        logging.info("\nDetailed MoE Metrics (Averaged):")
        for key, value in baseline_results['aggregated_moe_metrics'].items():
            logging.info(f"  {key}: {value}")

    logging.info("\nCollected Thermal Signals:")
    for i, signal in enumerate(baseline_results['thermal_signals']):
         logging.info(f"  Signal {i+1}: Temp={signal.temperature:.1f}°C, Power={signal.power_draw:.1f}W, State={signal.thermal_state.value}")


    profiler.stop_profiling() # Stop the profiling thread
    logging.info("Evaluation complete.")

## Demonstrate basic metric analysis

### Subtask:
After the evaluation loop in `evaluate_model`, process and present the collected profiling and thermal metrics. This could include calculating averages, visualizing trends (if storing historical data), and showing expert usage distributions.


**Reasoning**:
Process and present the collected profiling and thermal metrics as requested in the instructions. This involves iterating through the collected metrics, calculating averages and summaries, and printing them.



In [28]:
# Access aggregated MoE metrics and thermal signals from the results
aggregated_moe_metrics = baseline_results.get('aggregated_moe_metrics', {})
thermal_signals = baseline_results.get('thermal_signals', [])

print("\n--- Detailed MoE Metrics Analysis ---")
if aggregated_moe_metrics:
    for key, values_list in aggregated_moe_metrics.items():
        if not values_list:
            print(f"  {key}: No data collected")
            continue

        # Check if values_list is actually a list before accessing elements
        if not isinstance(values_list, list):
             # If it's not a list, it's likely a single scalar value already
             print(f"  {key.replace('avg_', '')}: {values_list}")
             continue


        # Handle different types of aggregated metrics (assuming values_list is now a list)
        if 'usage_current' in key and values_list and isinstance(values_list[0], list):
            # Expert usage is a list of counts per expert per batch
            try:
                avg_usage_per_expert = np.mean([np.array(v) for v in values_list if v is not None], axis=0) # Added check for None
                print(f"  Average {key.replace('avg_', '')} across batches: {avg_usage_per_expert}")
            except Exception as e:
                 logging.warning(f"Could not average list metric {key}: {e}")
                 print(f"  Raw {key}: {values_list}")

        elif ('timings_ms' in key or 'time_ms' in key) and values_list and isinstance(values_list[0], float): # Added 'time_ms' check
             # Expert batch timings or FFN timings are floats per batch
             avg_timing = np.mean(values_list)
             print(f"  Average {key.replace('avg_', '')} across batches: {avg_timing:.2f} ms")

        elif 'cumulative_timings_ms' in key and values_list and isinstance(values_list[0], dict): # Changed to check for dict
             # Cumulative timings are stored as dicts per batch
             # Aggregate by summing up times for each expert across batches
             total_cumulative_timings = defaultdict(float)
             for batch_dict in values_list:
                 if batch_dict: # Handle cases where the dict might be empty or None
                     for expert_id, timing in batch_dict.items():
                         total_cumulative_timings[expert_id] += timing

             print(f"  Total {key.replace('avg_', '')}: {dict(total_cumulative_timings)}")

        elif values_list and isinstance(values_list[0], (int, float)):
            # Other simple scalar metrics collected across batches
            print(f"  {key.replace('avg_', '')}: {np.mean(values_list):.4f}")
        else:
            print(f"  Raw {key}: {values_list}")


else:
    print("No aggregated MoE metrics collected.")


print("\n--- Thermal Signal Analysis ---")
if thermal_signals:
    temperatures = [s.temperature for s in thermal_signals]
    power_draws = [s.power_draw for s in thermal_signals]
    thermal_states = [s.thermal_state for s in thermal_signals]
    power_modes = [s.power_mode for s in thermal_signals]

    print(f"  Total Thermal Signals Collected: {len(thermal_signals)}")
    print(f"  Temperature: Avg={np.mean(temperatures)}°C, Max={np.max(temperatures)}°C, Min={np.min(temperatures)}°C") # Added min/max formatting
    print(f"  Power Draw: Avg={np.mean(power_draws)}W, Max={np.max(power_draws)}W, Min={np.min(power_draws)}W") # Added min/max formatting

    # Summarize thermal states and power modes
    thermal_state_counts = {}
    if thermal_signals: # Ensure thermal_signals is not empty before iterating
        for state in ThermalState:
            count = thermal_states.count(state)
            if count > 0:
                thermal_state_counts[state.value] = count
        print(f"  Thermal State Distribution: {thermal_state_counts}")

        power_mode_counts = {}
        for mode in PowerMode:
            count = power_modes.count(mode)
            if count > 0:
                power_mode_counts[mode.value] = count
        print(f"  Power Mode Distribution: {power_mode_counts}")

else:
    print("No thermal signals collected.")

  Total Thermal Signals Collected: 25
  Temperature: Avg=63.72°C, Max=66°C, Min=62°C
  Power Draw: Avg=31.777959999999997W, Max=34.495W, Min=30.564W
  Thermal State Distribution: {'cool': 17, 'warm': 8}
  Power Mode Distribution: {'balanced': 25}


**Reasoning**:
The error `IndexError: invalid index to scalar variable.` in the line `elif isinstance(values_list[0], (int, float)):` suggests that `values_list` is not a list but a scalar value (int or float) in some cases. This can happen because the aggregation logic in `evaluate_model` might store scalar metrics directly under their original key if there's only one value collected (e.g., in the case of FFN layers which don't return nested metrics). The code should check if `values_list` is a list before attempting to access `values_list[0]`.



In [30]:
# Access aggregated MoE metrics and thermal signals from the results
aggregated_moe_metrics = baseline_results.get('aggregated_moe_metrics', {})
thermal_signals = baseline_results.get('thermal_signals', [])

logging.info("\n--- Detailed MoE Metrics Analysis ---")
if aggregated_moe_metrics:
    for key, values_list in aggregated_moe_metrics.items():
        if not values_list:
            print(f"  {key}: No data collected")
            continue

        # Check if values_list is actually a list before accessing elements
        if not isinstance(values_list, list):
             # If it's not a list, it's likely a single scalar value already
             print(f"  {key.replace('avg_', '')}: {values_list}")
             continue

        # Handle different types of aggregated metrics (assuming values_list is now a list)
        if 'usage_current' in key and isinstance(values_list[0], list):
            # Expert usage is a list of counts per expert per batch
            try:
                avg_usage_per_expert = np.mean([np.array(v) for v in values_list], axis=0)
                print(f"  Average {key.replace('avg_', '')} across batches: {avg_usage_per_expert}")
            except Exception as e:
                 logging.warning(f"Could not average list metric {key}: {e}")
                 print(f"  Raw {key}: {values_list}")

        elif 'timings_ms' in key and isinstance(values_list[0], float):
             # Expert batch timings are floats per expert per batch (aggregated by sub_key)
             avg_timing = np.mean(values_list)
             print(f"  Average {key.replace('avg_', '')} across batches: {avg_timing:.2f} ms")

        elif 'cumulative_timings_ms' in key and isinstance(values_list[0], float):
             # Cumulative timings (aggregated by sub_key) - the last value is the total
             # Note: The aggregation logic stored the *last* cumulative value per batch,
             # so we can just take the mean of these last values across batches,
             # or, more accurately, the sum if we wanted total cumulative time across all batches.
             # Let's just print the averaged 'last' cumulative value for simplicity here.
             avg_last_cumulative_timing = np.mean(values_list)
             print(f"  Average Final {key.replace('avg_', '')}: {avg_last_cumulative_timing:.2f} ms")

        elif isinstance(values_list[0], (int, float)):
            # Other simple scalar metrics
            print(f"  {key.replace('avg_', '')}: {np.mean(values_list):.4f}")
        else:
            print(f"  Raw {key}: {values_list}")

else:
    print("No aggregated MoE metrics collected.")

print("\n--- Thermal Signal Analysis ---")
if thermal_signals:
    temperatures = [s.temperature for s in thermal_signals]
    power_draws = [s.power_draw for s in thermal_signals]
    thermal_states = [s.thermal_state for s in thermal_signals]
    power_modes = [s.power_mode for s in thermal_signals]

    print(f"  Total Thermal Signals Collected: {len(thermal_signals)}")
    print(f"  Temperature: Avg={np.mean(temperatures):.1f}°C, Max={np.max(temperatures):.1f}°C, Min={np.min(temperatures):.1f}°C")
    print(f"  Power Draw: Avg={np.mean(power_draws):.1f}W, Max={np.max(power_draws):.1f}W, Min={np.min(power_draws):.1f}W")

    # Summarize thermal states and power modes
    thermal_state_counts = {}
    for state in ThermalState:
        count = thermal_states.count(state)
        if count > 0:
            thermal_state_counts[state.value] = count
    print(f"  Thermal State Distribution: {thermal_state_counts}")

    power_mode_counts = {}
    for mode in PowerMode:
        count = power_modes.count(mode)
        if count > 0:
            power_mode_counts[mode.value] = count
    print(f"  Power Mode Distribution: {power_mode_counts}")

else:
    print("No thermal signals collected.")

  layer_0__ffn_time_ms: 3.5900300788879393
  layer_1__moe_forward_time_ms: 658.0288623046876
  layer_1__expert_usage_current: 512.0000
  layer_1__total_assignments: 4096.0
  layer_1__expert_batch_timings_ms__0: 1.241844472885132
  layer_1__expert_batch_timings_ms__1: 1.1928422594070434
  layer_1__expert_batch_timings_ms__2: 0.8485900783538818
  layer_1__expert_batch_timings_ms__3: 1.3191347074508668
  layer_1__expert_batch_timings_ms__4: 1.4396569681167604
  layer_1__expert_batch_timings_ms__5: 0.9450944042205811
  layer_1__expert_batch_timings_ms__6: 0.9551884770393372
  layer_1__expert_batch_timings_ms__7: 0.9880921649932861
  layer_1__expert_cumulative_timings_ms__0: 16.085323395729066
  layer_1__expert_cumulative_timings_ms__1: 15.411649560928344
  layer_1__expert_cumulative_timings_ms__2: 11.117712593078613
  layer_1__expert_cumulative_timings_ms__3: 17.216543798446654
  layer_1__expert_cumulative_timings_ms__4: 18.825777950286867
  layer_1__expert_cumulative_timings_ms__5: 13.203

## Summary:

### Data Analysis Key Findings

*   The `SimpleMoELayer` was successfully modified to capture and report execution time per expert and token counts per expert for each batch.
*   The `TransformerBlock` correctly propagates the detailed expert metrics from the `SimpleMoELayer` to the main model's output metrics dictionary.
*   The evaluation loop in the `evaluate_model` function was successfully integrated with the `GPUProfiler` to collect GPU metrics (power draw, temperature, utilization) at each batch step.
*   The `ThermalSignalGenerator` was integrated into the evaluation loop, allowing for the collection of thermal state information (temperature, power draw, thermal state, power mode) at the start of each batch.
*   The collected MoE metrics (expert usage, timings) and thermal signals were aggregated and analyzed after the evaluation loop, including calculating averages, min/max values, and summarizing distributions.

### Insights or Next Steps

*   The collected detailed expert timings and usage counts provide a foundation for analyzing expert performance and load balancing under different thermal conditions.
*   The integration of thermal signals enables future work on implementing thermal-aware expert routing or dynamic power management strategies during inference based on real-time thermal state.
