In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Dummy dataset
class DummyDataset(torch.utils.data.Dataset):
    """In-memory dataset of random Gaussian (input, target) pairs.

    Args:
        n: Number of samples to generate.
        d_model: Feature width shared by every input and target vector.
    """

    def __init__(self, n=1000, d_model=64):
        shape = (n, d_model)
        # Inputs are drawn before targets, so a fixed seed reproduces both.
        self.data = torch.randn(*shape)
        self.targets = torch.randn(*shape)

    def __len__(self):
        """Return the number of stored samples."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair stored at ``idx``."""
        sample = self.data[idx]
        target = self.targets[idx]
        return sample, target

# Dummy transformer wrapper
class MoETransformerBlock(nn.Module):
    """Minimal wrapper that owns a linear gate and linear experts for one MoE layer.

    Args:
        d_model: Width of the token vectors; each expert maps d_model -> d_model.
        num_experts: Number of expert networks (and gate outputs).
        top_k: Number of experts selected per token, forwarded to SimpleMoELayer.
        thermal_signal_generator: Object forwarded unchanged to SimpleMoELayer.
    """

    def __init__(self, d_model, num_experts, top_k, thermal_signal_generator):
        super().__init__()
        # The gate is registered on this block and also handed to the MoE layer.
        self.gate = nn.Linear(d_model, num_experts)
        expert_modules = [nn.Linear(d_model, d_model) for _ in range(num_experts)]
        self.moe_layer = SimpleMoELayer(
            self.gate,
            nn.ModuleList(expert_modules),
            top_k,
            thermal_signal_generator=thermal_signal_generator,
        )

    def forward(self, x):
        """Delegate to the wrapped MoE layer and return its result unchanged."""
        moe_result = self.moe_layer(x)
        return moe_result

# Setup: hyper-parameters, data, model and optimizer for a toy MoE training run.
# NOTE(review): this cell uses SimpleMoELayer (via MoETransformerBlock) and
# ThermalSignalGenerator, which are defined in later cells of this notebook, and
# compute_energy_loss, which is not defined in the visible part of the file.
# A fresh-kernel "Run All" will fail unless those definitions are run first.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
d_model = 64
num_experts = 8
top_k = 2
epochs = 3
batch_size = 32

# 500 random (input, target) pairs of width d_model, reshuffled every epoch.
dataset = DummyDataset(n=500, d_model=d_model)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# The generator's constructor initialises NVML and starts a background polling thread.
thermal_signal = ThermalSignalGenerator(device_id=0)
model = MoETransformerBlock(d_model, num_experts, top_k, thermal_signal_generator=thermal_signal).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Training loop: MSE task loss plus an energy penalty, one optimizer step per batch.
for epoch in range(epochs):
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        # NOTE(review): SimpleMoELayer.forward as defined in this file returns
        # (output, aux_loss, metrics). The load-balancing aux_loss is discarded here
        # and `selected_experts` receives the metrics dict, not expert indices --
        # confirm that this is what compute_energy_loss expects.
        output, _, selected_experts = model(x)
        task_loss = criterion(output, y)
        # NOTE(review): expert_profiles is initialised to an empty dict in
        # ThermalSignalGenerator.__init__ and nothing visible here fills it.
        energy_loss = compute_energy_loss(selected_expert_indices=selected_experts,
                                          expert_profiles=thermal_signal.expert_profiles,
                                          alpha=0.001)
        loss = task_loss + energy_loss
        loss.backward()
        optimizer.step()

    # Reports the loss of the last batch of the epoch only (not an epoch average).
    print(f"Epoch {epoch+1} done. Loss: {loss.item():.4f}")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional, List, Any
import time
import numpy as np
class SimpleMoELayer(nn.Module):
    """Mixture-of-experts layer with per-expert timing instrumentation.

    Each token is sent to the experts selected by the router and the expert
    outputs are summed, weighted by the router's probabilities.

    Args:
        gate: Module mapping ``[num_tokens, d_model]`` to per-expert logits.
        experts: Expert modules; each maps ``d_model -> d_model``.
        top_k: Number of experts per token. Must not exceed ``len(experts)``.
        capacity_factor: Stored for interface compatibility; not used here.
        thermal_signal_generator: Forwarded unchanged to ``AdaptiveRouter``.

    Raises:
        ValueError: If ``top_k`` is larger than the number of experts.
    """

    def __init__(self, gate: nn.Module, experts: nn.ModuleList, top_k: int = 2, capacity_factor: float = 1.25,
                 thermal_signal_generator=None):
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.n_experts = len(experts)
        self.top_k = top_k
        self.capacity_factor = capacity_factor

        if top_k > self.n_experts:
            raise ValueError(f"top_k ({top_k}) cannot be greater than n_experts ({self.n_experts})")

        # AdaptiveRouter is defined outside this block. In forward() its result is
        # used as (top_k_indices, top_k_probs), both indexed as [token, slot].
        self.router = AdaptiveRouter(self.n_experts, top_k, thermal_signal_generator)
        # Cumulative per-expert forward time in milliseconds across all calls.
        self.expert_timings: Dict[int, float] = {}

    @staticmethod
    def _timed_expert_call(expert: nn.Module, expert_input: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """Run ``expert`` on ``expert_input`` and return ``(output, elapsed_ms)``.

        The timing method follows the tensor's device rather than global CUDA
        availability, so CPU tensors on a CUDA-capable host are timed correctly.
        """
        if expert_input.is_cuda:
            input_device = expert_input.device
            torch.cuda.synchronize(input_device)
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            expert_output = expert(expert_input)
            end_event.record()
            torch.cuda.synchronize(input_device)
            return expert_output, start_event.elapsed_time(end_event)

        # perf_counter is monotonic and higher-resolution than time.time().
        start_time = time.perf_counter()
        expert_output = expert(expert_input)
        return expert_output, (time.perf_counter() - start_time) * 1000.0

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """Route tokens through their selected experts.

        Args:
            x: Token matrix of shape ``[num_tokens, d_model]``.

        Returns:
            ``(output, aux_loss, metrics)`` where ``output`` has the shape of
            ``x``, ``aux_loss`` is the scalar load-balancing term
            ``n_experts * sum(top1_fraction * mean_gate_prob)``, and ``metrics``
            holds per-expert token counts and timings for this call.
        """
        num_tokens, d_model = x.shape
        device = x.device

        gate_logits = self.gate(x)
        top_k_indices, top_k_probs = self.router(gate_logits)
        gate_probs_all = F.softmax(gate_logits, dim=-1)

        # Load-balancing loss based on first-choice assignments.
        top1_indices = top_k_indices[:, 0]
        expert_mask_top1 = F.one_hot(top1_indices, num_classes=self.n_experts).float()
        fraction_per_expert = expert_mask_top1.sum(dim=0) / (num_tokens + 1e-8)
        avg_gate_prob = gate_probs_all.mean(dim=0)
        aux_loss = (fraction_per_expert * avg_gate_prob).sum() * self.n_experts

        output = torch.zeros_like(x)
        metrics: Dict[str, Any] = {}
        expert_usage_counts = torch.zeros(self.n_experts, device=device)
        expert_batch_timings: Dict[int, float] = {}

        for expert_id in range(self.n_experts):
            # True where this expert occupies a token's routing slot.
            slot_mask = top_k_indices == expert_id
            expert_token_indices = torch.where(slot_mask.any(dim=-1))[0]
            if expert_token_indices.numel() == 0:
                continue

            # Per-token weight is the sum of the probabilities in the slots that
            # selected this expert, computed for all tokens at once.
            token_weights = top_k_probs.masked_fill(~slot_mask, 0.0).sum(dim=-1)
            expert_weights_for_tokens = token_weights[expert_token_indices]

            expert_output, duration_ms = self._timed_expert_call(
                self.experts[expert_id], x[expert_token_indices]
            )

            expert_batch_timings[expert_id] = duration_ms
            self.expert_timings[expert_id] = self.expert_timings.get(expert_id, 0.0) + duration_ms

            # Indices come from torch.where on a mask, so they are unique.
            output[expert_token_indices] += expert_output * expert_weights_for_tokens.unsqueeze(-1)
            expert_usage_counts[expert_id] = expert_token_indices.numel()

        metrics['expert_usage_current'] = expert_usage_counts.cpu().numpy()
        metrics['total_assignments'] = expert_usage_counts.sum().item()
        metrics['expert_batch_timings_ms'] = expert_batch_timings
        # Snapshot, so later forward passes do not mutate metrics already returned.
        metrics['expert_cumulative_timings_ms'] = dict(self.expert_timings)

        return output, aux_loss, metrics


class TransformerBlock(nn.Module):
    """Post-norm transformer block whose feed-forward part is an MoE layer or a dense FFN.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads.
        d_ff: Hidden width of each expert / of the dense FFN.
        n_experts: Number of experts when ``use_moe`` is True.
        top_k: Experts selected per token when ``use_moe`` is True.
        dropout: Dropout probability used in attention, experts/FFN and residuals.
        use_moe: Build a ``SimpleMoELayer`` (True) or a dense FFN (False).
        capacity_factor: Forwarded to ``SimpleMoELayer`` for interface compatibility.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, n_experts: int = 8,
                 top_k: int = 2, dropout: float = 0.1, use_moe: bool = True, capacity_factor: float = 1.25):
        super().__init__()
        self.d_model = d_model
        self.use_moe = use_moe

        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)

        self.norm1 = nn.LayerNorm(d_model)

        if use_moe:
            # Bias-free gating network with small-variance initial weights.
            gate_layer = nn.Linear(d_model, n_experts, bias=False)
            nn.init.normal_(gate_layer.weight, mean=0.0, std=0.02)

            # n_experts independent two-layer MLPs.
            experts_list = nn.ModuleList([
                nn.Sequential(nn.Linear(d_model, d_ff),
                              nn.ReLU(),
                              nn.Dropout(dropout),
                              nn.Linear(d_ff, d_model)) for _ in range(n_experts)
            ])

            self.moe_layer = SimpleMoELayer(
                gate=gate_layer,
                experts=experts_list,
                top_k=top_k,
                capacity_factor=capacity_factor  # Not used in SimpleMoELayer but kept for compatibility
            )

        else:
            self.feed_forward = nn.Sequential(nn.Linear(d_model, d_ff),
                                              nn.ReLU(),
                                              nn.Dropout(dropout),
                                              nn.Linear(d_ff, d_model))

        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _run_timed(fn, x: torch.Tensor, enabled: bool):
        """Call ``fn(x)`` and return ``(result, elapsed_ms)``; ``elapsed_ms`` is None when disabled.

        CUDA events are used only when ``x`` itself lives on a CUDA device;
        otherwise a monotonic wall-clock timer is used, so profiling also
        produces a value on CPU.
        """
        if not enabled:
            return fn(x), None

        if x.is_cuda:
            torch.cuda.synchronize(x.device)
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            result = fn(x)
            end_event.record()
            torch.cuda.synchronize(x.device)
            return result, start_event.elapsed_time(end_event)

        started = time.perf_counter()
        result = fn(x)
        return result, (time.perf_counter() - started) * 1000.0

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        profile: bool = False
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Forward pass with optional profiling.

        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Passed as ``attn_mask`` to ``nn.MultiheadAttention``
            profile: Whether to collect timing information

        Returns:
            output: Transformed tensor, same shape as ``x``
            metrics: Routing metrics and timings. With MoE it contains the MoE
                layer's metrics plus ``aux_loss``; with ``profile=True`` it also
                contains ``moe_forward_time_ms`` or ``ffn_time_ms``.
        """
        metrics: Dict[str, Any] = {}

        # Self-attention sub-layer (residual, then LayerNorm).
        residual = x
        attn_out, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(residual + self.dropout(attn_out))

        # MoE or dense FFN sub-layer.
        residual = x

        if self.use_moe:
            # The MoE layer works on a flat token matrix [batch_size * seq_len, d_model].
            batch_size, seq_len, d_model = x.shape
            x_flat = x.reshape(-1, d_model)

            (moe_out_flat, aux_loss, moe_metrics), elapsed_ms = self._run_timed(
                self.moe_layer, x_flat, profile
            )
            if elapsed_ms is not None:
                metrics['moe_forward_time_ms'] = elapsed_ms

            # Expert-level usage and timings reported by the MoE layer.
            metrics.update(moe_metrics)

            moe_out = moe_out_flat.reshape(batch_size, seq_len, d_model)

            x = residual + self.dropout(moe_out)
            metrics['aux_loss'] = aux_loss

        else:
            ffn_out, elapsed_ms = self._run_timed(self.feed_forward, x, profile)
            if elapsed_ms is not None:
                metrics['ffn_time_ms'] = elapsed_ms

            x = residual + self.dropout(ffn_out)

        x = self.norm2(x)

        return x, metrics


class MoETransformer(nn.Module):
    """Simple MoE Transformer language model.

    Token and learned position embeddings feed a stack of ``TransformerBlock``
    layers, followed by a final LayerNorm and a bias-free projection to the
    vocabulary.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 6,
        d_ff: int = 2048,
        n_experts: int = 8,
        top_k: int = 2,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        use_moe_layers: Optional[list] = None,  # Which layers use MoE
        capacity_factor: float = 1.25,
    ):
        """
        Args:
            vocab_size: Number of token ids.
            d_model: Model width.
            n_heads: Attention heads per block.
            n_layers: Number of transformer blocks.
            d_ff: Hidden width of experts / dense FFNs.
            n_experts: Experts per MoE block.
            top_k: Experts selected per token.
            max_seq_len: Size of the position-embedding table.
            dropout: Dropout probability passed to every block.
            use_moe_layers: One boolean per layer; None selects MoE for odd-indexed layers.
            capacity_factor: Forwarded to every block.

        Raises:
            ValueError: If ``use_moe_layers`` is given with a length other than ``n_layers``.
        """
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers

        # Determine which layers use MoE
        if use_moe_layers is None:
            # By default, use MoE in every other layer starting from layer 1
            use_moe_layers = [i % 2 == 1 for i in range(n_layers)]
        elif len(use_moe_layers) != n_layers:
            raise ValueError(f"Length of use_moe_layers ({len(use_moe_layers)}) must match n_layers ({n_layers})")

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                n_experts=n_experts,
                top_k=top_k,
                dropout=dropout,
                use_moe=use_moe_layers[i],
                capacity_factor=capacity_factor,
            )
            for i in range(n_layers)
        ])

        # Output projection
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and reset LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        profile: bool = False
    ) -> Dict:
        """
        Forward pass.

        Args:
            input_ids: Token indices [batch_size, seq_len]
            attention_mask: Optional mask handed unchanged to every block, where it
                is used as ``attn_mask`` of ``nn.MultiheadAttention`` (so it must
                have a shape that argument accepts, e.g. [seq_len, seq_len]).
                When None, a boolean causal mask is built.
            profile: Whether to collect profiling information

        Returns:
            Dictionary with ``logits`` [batch_size, seq_len, vocab_size],
            ``aux_loss`` (always a scalar tensor; zero when no layer reports one)
            and, when ``profile`` is True, ``metrics`` keyed as
            ``layer_{i}__{name}``.

        Raises:
            ValueError: If ``seq_len`` exceeds the position-embedding table.
        """
        _, seq_len = input_ids.shape
        device = input_ids.device

        # Fail early with a readable message instead of an embedding index error.
        max_positions = self.position_embedding.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length ({seq_len}) exceeds max_seq_len ({max_positions})"
            )

        # Embeddings
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)

        if attention_mask is None:
            # True strictly above the diagonal marks future positions as masked.
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=device), diagonal=1
            ).bool()
        else:
            causal_mask = attention_mask

        # Tensor accumulator so callers can rely on tensor methods such as .item()
        # even when no layer contributes an auxiliary loss.
        total_aux_loss = x.new_zeros(())
        all_metrics = {} if profile else None

        for i, layer in enumerate(self.layers):
            x, layer_metrics = layer(x, mask=causal_mask, profile=profile)

            # Accumulate auxiliary loss from MoE layers
            if 'aux_loss' in layer_metrics:
                total_aux_loss = total_aux_loss + layer_metrics['aux_loss']

            if profile:
                # Keys carry the layer index, so each one is written exactly once per call.
                for key, value in layer_metrics.items():
                    if key != 'aux_loss':
                        all_metrics[f'layer_{i}__{key}'] = value

        # Final layer norm and projection
        x = self.ln_f(x)
        logits = self.head(x)

        output = {
            'logits': logits,
            'aux_loss': total_aux_loss,
        }

        if profile:
            output['metrics'] = all_metrics

        return output

In [None]:
import pynvml
import time
import threading

class ThermalSignalGenerator:
    """Polls GPU temperature and power via NVML and derives per-expert routing priorities.

    A daemon thread refreshes ``thermal_state`` and ``expert_priorities`` every
    ``update_interval`` seconds. Readers should use ``get_expert_priorities()``,
    which returns a copy taken under the lock.

    Args:
        device_id: NVML device index to monitor.
        update_interval: Seconds between polls.
        num_experts: Number of experts to generate priorities for (keys "0".."n-1").
    """

    # Priority of expert k is slope * k: higher-indexed experts are penalised
    # more strongly as the device gets hotter.
    _PRIORITY_SLOPES = {
        "cool": 0.0,
        "warm": -0.1,
        "hot": -0.2,
        "critical": -0.5,
    }

    def __init__(self, device_id=0, update_interval=0.5, num_experts=16):
        self.device_id = device_id
        self.update_interval = update_interval
        self.num_experts = num_experts
        self.expert_profiles = {}  # To be filled with real or estimated values
        self.thermal_state = "cool"
        self.expert_priorities = {}
        self.lock = threading.Lock()
        self._stop_event = threading.Event()

        # Populate priorities for the initial state so readers never see an
        # empty dict before the first poll completes.
        self._update_expert_priorities()

        pynvml.nvmlInit()
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        self._start_background_update()

    def _start_background_update(self):
        """Start the daemon polling thread."""
        self._thread = threading.Thread(target=self._update_loop, daemon=True)
        self._thread.start()

    def stop(self):
        """Ask the polling thread to exit after its current iteration."""
        self._stop_event.set()

    @staticmethod
    def _classify_thermal_state(temperature, power_draw):
        """Map a temperature (deg C) and power draw (W) reading to a state name."""
        if temperature > 85 or power_draw > 250:
            return "critical"
        if temperature > 75:
            return "hot"
        if temperature > 60:
            return "warm"
        return "cool"

    def _update_loop(self):
        """Poll NVML until ``stop()`` is called, updating state under the lock."""
        while not self._stop_event.is_set():
            # Query the driver outside the lock so readers are not blocked on NVML.
            temperature = pynvml.nvmlDeviceGetTemperature(self.handle, pynvml.NVML_TEMPERATURE_GPU)
            power_draw = pynvml.nvmlDeviceGetPowerUsage(self.handle) / 1000.0  # mW -> W

            with self.lock:
                self.thermal_state = self._classify_thermal_state(temperature, power_draw)
                self._update_expert_priorities()

            # Interruptible sleep: returns early once stop() is called.
            self._stop_event.wait(self.update_interval)

    def _update_expert_priorities(self):
        """Recompute ``expert_priorities`` for the current ``thermal_state``.

        An unrecognised state leaves the existing priorities untouched.
        """
        slope = self._PRIORITY_SLOPES.get(self.thermal_state)
        if slope is None:
            return
        self.expert_priorities = {str(k): slope * k for k in range(self.num_experts)}

    def get_expert_priorities(self):
        """Return a copy of the current priorities, keyed by expert index as a string."""
        with self.lock:
            return self.expert_priorities.copy()


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from typing import Dict, Tuple, Optional, List, Any
import time
import logging
import math
from pathlib import Path
import numpy as np

class WikiText2Dataset(Dataset):
    """Placeholder for WikiText-2 that serves random token sequences.

    This is skeleton code: no real text is loaded. Every access draws a fresh
    random sequence, so the same index does not return the same sample twice.

    Args:
        vocab_size: Token ids are drawn uniformly from ``[0, vocab_size)``.
        seq_len: Length of every generated sequence.
        num_samples: Reported dataset length.
    """

    def __init__(self, vocab_size: int = 1000, seq_len: int = 512, num_samples: int = 1000):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
        logging.info(f"Initialized WikiText2Dataset with {num_samples} samples, seq_len={seq_len}, vocab_size={vocab_size}")

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Generate one sample.

        Returns:
            Dict with ``input_ids`` (random token ids, shape ``[seq_len]``) and
            ``labels`` (``input_ids`` shifted left by one, padded with a 0).

        Raises:
            IndexError: If ``idx`` is outside ``[-num_samples, num_samples)``.
                Without this, plain iteration (``for item in dataset``) would
                never terminate because no index ever fails.
        """
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.num_samples}"
            )

        # Simulate text data: random token IDs
        input_ids = torch.randint(0, self.vocab_size, (self.seq_len,))
        # Next-token targets: shift left by one and pad the last position with 0.
        labels = torch.cat([input_ids[1:], input_ids.new_zeros(1)])

        return {"input_ids": input_ids, "labels": labels}

# --- Evaluation Loop ---

def _accumulate_moe_metrics(aggregated: Dict[str, List[Any]], metrics: Dict[str, Any]) -> None:
    """Append one batch's model metrics to the per-key lists in ``aggregated``."""
    for key, value in metrics.items():
        if isinstance(value, np.ndarray):
            # Convert arrays to lists for consistent aggregation
            value = value.tolist()

        if isinstance(value, dict):
            # For dictionaries (like expert timings), aggregate per sub-key,
            # e.g. 'expert_batch_timings_ms__0'.
            for sub_key, sub_value in value.items():
                aggregated.setdefault(f'{key}__{sub_key}', []).append(sub_value)
        elif isinstance(value, (int, float, list)):
            aggregated.setdefault(key, []).append(value)
        else:
            logging.warning(f"Skipping aggregation for metric {key} with unsupported type {type(value)}")


def _summarize_moe_metrics(aggregated: Dict[str, List[Any]]) -> Dict[str, Any]:
    """Average each aggregated metric across batches, keeping raw values when averaging fails."""
    final_moe_metrics: Dict[str, Any] = {}
    for key, values_list in aggregated.items():
        if not values_list:
            continue

        first = values_list[0]
        if isinstance(first, (int, float)):
            final_moe_metrics[f'avg_{key}'] = np.mean(values_list)
        elif isinstance(first, (list, np.ndarray)):
            # Element-wise mean; fails if the per-batch lists differ in shape.
            try:
                final_moe_metrics[f'avg_{key}'] = np.mean([np.array(v) for v in values_list], axis=0).tolist()
            except Exception as e:
                logging.warning(f"Could not average list/array metric {key}: {e}")
                final_moe_metrics[f'raw_{key}'] = values_list  # Store raw list if averaging fails
        elif isinstance(first, dict):
            logging.warning(f"Metric {key} contains dictionaries, averaging not supported directly.")
            final_moe_metrics[f'raw_{key}'] = values_list  # Store raw list of dicts
    return final_moe_metrics


def evaluate_model(
    model: "MoETransformer",
    dataloader: DataLoader,
    device: torch.device,
    profiler: "GPUProfiler",  # Accept GPUProfiler instance
    thermal_signal_generator: "ThermalSignalGenerator",  # Accept ThermalSignalGenerator instance
    log_interval: int = 10,
) -> Dict[str, Any]:
    """
    Performs an evaluation loop for the MoE model, logging inference time
    and GPU metrics, and collecting thermal signals.

    The project types in the signature are string annotations so this function
    can be defined before the cells that define those classes have run.

    Args:
        model: The MoETransformer model; called as ``model(input_ids, profile=True)``
            and expected to return a dict with ``logits`` and optional ``metrics``.
        dataloader: Iterable of batches with ``input_ids`` and ``labels`` tensors.
        device: Device to run evaluation on (e.g., 'cuda' or 'cpu').
        profiler: Object whose ``get_current_metrics()`` returns GPU metrics or a falsy value.
        thermal_signal_generator: Object whose ``get_thermal_signal()`` returns a
            thermal signal or a falsy value.
        log_interval: How often (in batches) to log progress and metrics.

    Returns:
        A dictionary containing final perplexity, total inference duration,
        average per-batch inference time, average power draw / temperature /
        utilization, aggregated MoE metrics and the collected thermal signals.
    """
    model.eval()  # Set model to evaluation mode
    # Time with CUDA events only when the evaluation itself runs on a CUDA device.
    use_cuda_timing = torch.device(device).type == "cuda"

    total_loss = 0.0    # sum of per-token losses over non-ignored tokens
    total_tokens = 0    # number of non-ignored tokens

    inference_times_ms = []
    power_draws_watts = []
    temperatures_c = []
    gpu_utilizations_percent = []

    # Aggregated MoE metrics across all layers and batches
    aggregated_moe_metrics: Dict[str, List[Any]] = {}

    thermal_signals: List["ThermalSignal"] = []

    start_time = time.time()

    logging.info(f"Starting evaluation on device: {device}")

    with torch.no_grad():  # Disable gradient calculations
        for batch_idx, batch in enumerate(dataloader):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            # Get thermal signal before computation; it is logged only, not acted on.
            thermal_signal = thermal_signal_generator.get_thermal_signal()
            if thermal_signal:
                thermal_signals.append(thermal_signal)
                logging.info(
                    f"Batch {batch_idx+1} Thermal Signal: State={thermal_signal.thermal_state.value}, "
                    f"PowerMode={thermal_signal.power_mode.value}, Temp={thermal_signal.temperature:.1f}°C, "
                    f"Power={thermal_signal.power_draw:.1f}W"
                )

            # Measure inference time for this batch only.
            if use_cuda_timing:
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
            else:
                batch_start = time.perf_counter()

            # Forward pass with profiling enabled for MoE metrics
            model_output = model(input_ids, profile=True)
            logits = model_output['logits']
            metrics = model_output.get('metrics', {})

            if use_cuda_timing:
                end_event.record()
                torch.cuda.synchronize()
                batch_inference_time_ms = start_event.elapsed_time(end_event)
            else:
                batch_inference_time_ms = (time.perf_counter() - batch_start) * 1000.0
            inference_times_ms.append(batch_inference_time_ms)

            # Loss for perplexity. Label 0 is treated as padding and ignored, so the
            # loss is summed and normalised by the count of non-ignored tokens only.
            logits_flat = logits.reshape(-1, logits.size(-1))
            labels_flat = labels.reshape(-1)
            loss_sum = F.cross_entropy(logits_flat, labels_flat, ignore_index=0, reduction='sum')

            total_loss += loss_sum.item()
            total_tokens += int((labels_flat != 0).sum().item())

            # Collect and store GPU metrics for this batch
            gpu_metrics = profiler.get_current_metrics()
            if gpu_metrics:
                power_draws_watts.append(gpu_metrics.power_draw)
                temperatures_c.append(gpu_metrics.temperature)
                gpu_utilizations_percent.append(gpu_metrics.gpu_utilization)

            # Aggregate MoE specific metrics
            if metrics:
                _accumulate_moe_metrics(aggregated_moe_metrics, metrics)

            if (batch_idx + 1) % log_interval == 0:
                avg_batch_loss = total_loss / total_tokens if total_tokens > 0 else 0
                current_perplexity = math.exp(avg_batch_loss) if avg_batch_loss < 100 else float('inf')  # Avoid overflow

                log_msg = (
                    f"Batch {batch_idx+1}/{len(dataloader)} | "
                    f"Loss: {avg_batch_loss:.4f} | "
                    f"Perplexity: {current_perplexity:.2f} | "
                    f"Batch Time: {batch_inference_time_ms:.2f} ms"
                )
                # Log collected GPU metrics at log interval
                if gpu_metrics:
                    log_msg += (
                        f" | Power: {gpu_metrics.power_draw:.1f}W | "
                        f"Temp: {gpu_metrics.temperature:.1f}°C | "
                        f"GPU Util: {gpu_metrics.gpu_utilization:.1f}%"
                    )
                logging.info(log_msg)

    end_time = time.time()
    total_inference_duration_sec = end_time - start_time

    # Calculate overall averages
    avg_loss = total_loss / total_tokens if total_tokens > 0 else 0
    final_perplexity = math.exp(avg_loss) if avg_loss < 100 else float('inf')

    # Calculate averages for collected GPU metrics
    avg_inference_time_ms = np.mean(inference_times_ms) if inference_times_ms else 0
    avg_power_draw_watts = np.mean(power_draws_watts) if power_draws_watts else 0
    avg_temperature_c = np.mean(temperatures_c) if temperatures_c else 0
    avg_gpu_utilization_percent = np.mean(gpu_utilizations_percent) if gpu_utilizations_percent else 0

    # Aggregate MoE metrics (e.g., average expert usage across batches)
    final_moe_metrics = _summarize_moe_metrics(aggregated_moe_metrics)

    results = {
        "final_perplexity": final_perplexity,
        "total_inference_duration_sec": total_inference_duration_sec,
        "avg_inference_time_per_batch_ms": avg_inference_time_ms,
        "avg_power_draw_watts": avg_power_draw_watts,
        "avg_temperature_c": avg_temperature_c,
        "avg_gpu_utilization_percent": avg_gpu_utilization_percent,
        "aggregated_moe_metrics": final_moe_metrics,
        "thermal_signals": thermal_signals  # Include collected thermal signals
    }

    logging.info("\n--- Evaluation Summary ---")
    logging.info(f"Final Perplexity: {final_perplexity:.2f}")
    logging.info(f"Total Inference Duration: {total_inference_duration_sec:.2f} seconds")
    logging.info(f"Average Batch Inference Time: {results['avg_inference_time_per_batch_ms']:.2f} ms")
    if results['avg_power_draw_watts'] > 0:
        logging.info(f"Average Power Draw: {results['avg_power_draw_watts']:.1f} W")
        logging.info(f"Average Temperature: {results['avg_temperature_c']:.1f} °C")
        logging.info(f"Average GPU Utilization: {results['avg_gpu_utilization_percent']:.1f} %")
    logging.info("Aggregated MoE Metrics:")
    for k, v in final_moe_metrics.items():
        # Format array output nicely
        if isinstance(v, list) and all(isinstance(i, (int, float)) for i in v):
            logging.info(f"  {k}: {np.array(v)}")
        else:
            logging.info(f"  {k}: {v}")

    logging.info("\nCollected Thermal Signals:")
    for i, signal in enumerate(results['thermal_signals']):
        logging.info(f"  Signal {i+1}: Temp={signal.temperature:.1f}°C, Power={signal.power_draw:.1f}W, State={signal.thermal_state.value}")

    return results

# --- Main execution block ---

if __name__ == "__main__":
    # Baseline inference evaluation of a randomly initialised MoETransformer on
    # simulated data, with GPU profiling running in the background.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    # 1. Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    # 2. Initialize GPUProfiler
    profiler = GPUProfiler()
    profiler.start_profiling()  # Start the profiling thread

    # Everything after start_profiling() runs inside try/finally so the profiling
    # thread is always stopped, even if model construction or evaluation raises.
    try:
        # 3. Initialize ThermalSignalGenerator
        # NOTE(review): the ThermalSignalGenerator defined earlier in this notebook
        # takes (device_id, update_interval) and has no get_thermal_signal();
        # this call presumably targets a different definition -- confirm which
        # one is in scope when this cell runs.
        thermal_signal_generator = ThermalSignalGenerator(profiler=profiler)
        logging.info("ThermalSignalGenerator initialized.")

        # 4. Model Parameters (Adjust as needed for your specific MoE setup)
        VOCAB_SIZE = 10000  # Example vocab size
        D_MODEL = 512
        N_HEADS = 8
        N_LAYERS = 6
        D_FF = 2048
        N_EXPERTS = 8
        TOP_K = 2
        MAX_SEQ_LEN = 512
        BATCH_SIZE = 4

        # Set which layers use MoE (e.g., every other layer)
        USE_MOE_LAYERS = [i % 2 == 1 for i in range(N_LAYERS)]  # [False, True, False, True, False, True]

        # 5. Instantiate MoE Model
        logging.info("Initializing MoETransformer model...")
        model = MoETransformer(
            vocab_size=VOCAB_SIZE,
            d_model=D_MODEL,
            n_heads=N_HEADS,
            n_layers=N_LAYERS,
            d_ff=D_FF,
            n_experts=N_EXPERTS,
            top_k=TOP_K,
            max_seq_len=MAX_SEQ_LEN,
            use_moe_layers=USE_MOE_LAYERS
        ).to(device)
        logging.info(f"Model instantiated with {sum(USE_MOE_LAYERS)} MoE layers.")

        # Optional: Load a pre-trained checkpoint if you have one
        # checkpoint_path = "path/to/your/checkpoint.pth"
        # if Path(checkpoint_path).exists():
        #     logging.info(f"Loading model checkpoint from {checkpoint_path}...")
        #     model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        #     logging.info("Model checkpoint loaded.")
        # else:
        #     logging.warning("No model checkpoint found. Using randomly initialized weights.")

        # 6. Prepare Dataset and DataLoader (simulated WikiText-2 for the baseline).
        # For actual WikiText-2, load and preprocess the real corpus instead.
        logging.info("Preparing dataset...")
        eval_dataset = WikiText2Dataset(vocab_size=VOCAB_SIZE, seq_len=MAX_SEQ_LEN, num_samples=100)  # Small sample count for baseline
        eval_dataloader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
        logging.info(f"Evaluation DataLoader ready with {len(eval_dataloader)} batches.")

        # 7. Run Evaluation
        logging.info("\n--- Starting Baseline Inference Evaluation ---")
        baseline_results = evaluate_model(
            model=model,
            dataloader=eval_dataloader,
            device=device,
            profiler=profiler,  # Pass the profiler instance
            thermal_signal_generator=thermal_signal_generator,  # Pass thermal signal generator
            log_interval=10
        )

        # 8. Final Sanity Checks
        logging.info("\n--- Sanity Checks ---")
        if baseline_results['final_perplexity'] < float('inf'):
            logging.info(f"Perplexity sanity check: {baseline_results['final_perplexity']:.2f} (lower is better, typically starts high for untrained models)")
        else:
            logging.warning("Perplexity is infinite. This might indicate issues like very high loss or training with random weights.")

        logging.info(f"Average Power Draw: {baseline_results['avg_power_draw_watts']:.2f} W")
        logging.info(f"Average Inference Time per Batch: {baseline_results['avg_inference_time_per_batch_ms']:.2f} ms")

        # Access detailed MoE metrics
        if 'aggregated_moe_metrics' in baseline_results:
            logging.info("\nDetailed MoE Metrics (Averaged):")
            for key, value in baseline_results['aggregated_moe_metrics'].items():
                logging.info(f"  {key}: {value}")

        logging.info("\nCollected Thermal Signals:")
        for i, signal in enumerate(baseline_results['thermal_signals']):
            logging.info(f"  Signal {i+1}: Temp={signal.temperature:.1f}°C, Power={signal.power_draw:.1f}W, State={signal.thermal_state.value}")
    finally:
        profiler.stop_profiling()  # Stop the profiling thread

    logging.info("Evaluation complete.")

In [None]:
# Access aggregated MoE metrics and thermal signals from the results
aggregated_moe_metrics = baseline_results.get('aggregated_moe_metrics', {})
thermal_signals = baseline_results.get('thermal_signals', [])


def _has_no_data(value):
    """Return True for None or an empty container.

    A plain truthiness test would also treat a legitimate scalar 0 / 0.0 as
    "no data", so only None and empty containers/strings are considered empty.
    """
    if value is None:
        return True
    return isinstance(value, (list, tuple, dict, str)) and len(value) == 0


def _enum_distribution(enum_cls, observed):
    """Count how often each member of ``enum_cls`` occurs in ``observed``.

    Members are visited in ``enum_cls`` iteration order; members that never
    occur are omitted. The result is keyed by ``member.value``.
    """
    counts = {}
    for member in enum_cls:
        occurrences = observed.count(member)
        if occurrences > 0:
            counts[member.value] = occurrences
    return counts


def report_moe_metrics(metrics):
    """Print a per-key summary of the aggregated MoE metrics.

    Args:
        metrics: mapping of metric name -> collected value. A value may be a
            scalar (printed as-is) or a list with one entry per batch; list
            entries are summarised according to the key name and the type of
            the first entry.
    """
    print("\n--- Detailed MoE Metrics Analysis ---")
    if not metrics:
        print("No aggregated MoE metrics collected.")
        return

    for key, values_list in metrics.items():
        label = key.replace('avg_', '')

        if _has_no_data(values_list):
            print(f"  {key}: No data collected")
            continue

        # Anything that is not a list is treated as an already-aggregated value.
        if not isinstance(values_list, list):
            print(f"  {label}: {values_list}")
            continue

        first = values_list[0]

        if 'usage_current' in key and isinstance(first, list):
            # Expert usage: one list of per-expert counts per batch.
            try:
                per_batch = [np.array(v) for v in values_list if v is not None]
                avg_usage_per_expert = np.mean(per_batch, axis=0)
                print(f"  Average {label} across batches: {avg_usage_per_expert}")
            except Exception as e:
                # Deliberate best-effort: a malformed metric must not abort the report.
                logging.warning(f"Could not average list metric {key}: {e}")
                print(f"  Raw {key}: {values_list}")

        # The cumulative branches must be tested BEFORE the generic timing
        # branch: 'cumulative_timings_ms' also contains 'timings_ms', so the
        # generic test would otherwise shadow them.
        elif 'cumulative_timings_ms' in key and isinstance(first, dict):
            # One {expert_id: time} dict per batch; sum per expert across batches.
            total_cumulative_timings = defaultdict(float)
            for batch_dict in values_list:
                if batch_dict:  # skip None / empty dicts
                    for expert_id, timing in batch_dict.items():
                        total_cumulative_timings[expert_id] += timing
            print(f"  Total {label}: {dict(total_cumulative_timings)}")

        elif 'cumulative_timings_ms' in key and isinstance(first, float):
            # One float per batch; report the mean of those values.
            print(f"  Average Final {label}: {np.mean(values_list):.2f} ms")

        elif ('timings_ms' in key or 'time_ms' in key) and isinstance(first, float):
            # Per-batch timings stored as floats.
            print(f"  Average {label} across batches: {np.mean(values_list):.2f} ms")

        elif isinstance(first, (int, float)):
            # Other simple scalar metrics collected across batches.
            print(f"  {label}: {np.mean(values_list):.4f}")

        else:
            print(f"  Raw {key}: {values_list}")


def report_thermal_signals(signals):
    """Print summary statistics for the collected thermal signals.

    Args:
        signals: sequence of objects exposing ``temperature``, ``power_draw``,
            ``thermal_state`` and ``power_mode`` attributes.
    """
    print("\n--- Thermal Signal Analysis ---")
    if not signals:
        print("No thermal signals collected.")
        return

    temperatures = [s.temperature for s in signals]
    power_draws = [s.power_draw for s in signals]
    thermal_states = [s.thermal_state for s in signals]
    power_modes = [s.power_mode for s in signals]

    print(f"  Total Thermal Signals Collected: {len(signals)}")
    print(f"  Temperature: Avg={np.mean(temperatures):.1f}°C, Max={np.max(temperatures):.1f}°C, Min={np.min(temperatures):.1f}°C")
    print(f"  Power Draw: Avg={np.mean(power_draws):.1f}W, Max={np.max(power_draws):.1f}W, Min={np.min(power_draws):.1f}W")

    # Summarize thermal states and power modes
    print(f"  Thermal State Distribution: {_enum_distribution(ThermalState, thermal_states)}")
    print(f"  Power Mode Distribution: {_enum_distribution(PowerMode, power_modes)}")


# The report used to be pasted twice in this cell with slightly different
# logic; it is now produced exactly once.
report_moe_metrics(aggregated_moe_metrics)
report_thermal_signals(thermal_signals)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Dummy dataset
class DummyDataset(torch.utils.data.Dataset):
    """Synthetic dataset of random (input, target) vector pairs.

    Args:
        n: number of samples to generate.
        d_model: length of every input and target vector.
        seed: optional integer seed. When given, the samples are drawn from a
            private ``torch.Generator`` so the dataset is reproducible and the
            global RNG state is left untouched. When ``None`` (the default),
            the global RNG is used.
    """

    def __init__(self, n=1000, d_model=64, seed=None):
        generator = None
        if seed is not None:
            generator = torch.Generator().manual_seed(seed)
        # Inputs are drawn first, then targets, both from a standard normal.
        self.data = torch.randn(n, d_model, generator=generator)
        self.targets = torch.randn(n, d_model, generator=generator)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair stored at ``idx``."""
        return self.data[idx], self.targets[idx]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

# Dummy transformer wrapper
class MoETransformerBlock(nn.Module):
    """Thin ``nn.Module`` wrapper around a single ``SimpleMoELayer``.

    Builds a linear gate (``d_model -> num_experts``) and ``num_experts``
    linear experts (``d_model -> d_model``), then hands them to
    ``SimpleMoELayer`` together with ``top_k`` and the thermal signal generator.
    """

    def __init__(self, d_model, num_experts, top_k, thermal_signal_generator):
        super().__init__()
        # The gate is created before the experts so parameter initialisation
        # happens in the same order as before.
        self.gate = nn.Linear(d_model, num_experts)
        expert_bank = nn.ModuleList(
            nn.Linear(d_model, d_model) for _ in range(num_experts)
        )
        self.moe_layer = SimpleMoELayer(
            self.gate,
            expert_bank,
            top_k,
            thermal_signal_generator=thermal_signal_generator,
        )

    def forward(self, x):
        """Run ``x`` through the wrapped MoE layer and return its result unchanged."""
        moe_result = self.moe_layer(x)
        return moe_result

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
d_model = 64
num_experts = 8
top_k = 2
epochs = 3
batch_size = 32
energy_alpha = 0.001  # passed to compute_energy_loss as `alpha`

dataset = DummyDataset(n=500, d_model=d_model)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

thermal_signal = ThermalSignalGenerator(device_id=0)
model = MoETransformerBlock(d_model, num_experts, top_k, thermal_signal_generator=thermal_signal).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Training Loop
for epoch in range(epochs):
    # Accumulate a sample-weighted loss so the epoch summary reflects every
    # batch. With 500 samples and batch_size=32 the last batch holds only 20
    # samples, so reporting that batch alone is noisy and unrepresentative.
    epoch_loss_sum = 0.0
    epoch_sample_count = 0

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        # The model returns a 3-tuple; only the first and last items are used here.
        output, _, selected_experts = model(x)
        task_loss = criterion(output, y)
        energy_loss = compute_energy_loss(selected_expert_indices=selected_experts,
                                          expert_profiles=thermal_signal.expert_profiles,
                                          alpha=energy_alpha)
        loss = task_loss + energy_loss
        loss.backward()
        optimizer.step()

        batch_samples = x.size(0)
        epoch_loss_sum += loss.item() * batch_samples
        epoch_sample_count += batch_samples

    # max(..., 1) keeps the summary well-defined if the dataloader yields nothing.
    epoch_mean_loss = epoch_loss_sum / max(epoch_sample_count, 1)
    print(f"Epoch {epoch+1} done. Mean loss: {epoch_mean_loss:.4f}")
