In [None]:
import torch
import numpy as np
from thesis.utils import prepare_loaders, calculate_noise_variance
from thesis.dataset import DeepMIMOGenerator

datasets = []
num_users = 8

generator = DeepMIMOGenerator()
total_samples = 1000

data_spatial, _ = generator.generate_dataset(
        num_samples=int(total_samples * 0.45),
        num_users=num_users,
        min_corr=0.90, 
        max_corr=0.99,
        max_gain_ratio=10.0 
    )

data_power, _ = generator.generate_dataset(
        num_samples=int(total_samples * 0.45),
        num_users=num_users,
        min_corr=0.0,
        max_corr=0.5,
        max_gain_ratio=20.0
    )

data_std, _ = generator.generate_dataset(
        num_samples=int(total_samples * 0.10),
        num_users=num_users,
        min_corr=0.0,
        max_corr=0.7,
        max_gain_ratio=20.0
    )

final_dataset = np.concatenate([data_spatial, data_power, data_std], axis=0)
dataset_tensor = torch.tensor(final_dataset)

train_loader, val_loader, test_loader = prepare_loaders(dataset_tensor)

noise_variance = calculate_noise_variance(dataset_tensor, 5, "mean")


Basestation 1

UE-BS Channels


Reading ray-tracing: 100%|██████████| 42984/42984 [00:00<00:00, 144732.07it/s]
Generating channels: 100%|██████████| 42984/42984 [00:05<00:00, 7485.48it/s] 
Generating Scenarios: 100%|██████████| 450/450 [00:00<00:00, 554.31it/s]
Generating Scenarios: 100%|██████████| 450/450 [00:02<00:00, 158.54it/s]
Generating Scenarios: 100%|██████████| 100/100 [00:00<00:00, 338.56it/s]


In [None]:
import torch
import torch.optim as optim
from thesis.user_scheduling import JointUtilityLoss

# --- Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
EPOCHS = 50
BATCH_SIZE = 32
LEARNING_RATE = 1e-4

# Annealing Schedule 
START_TEMP = 5.0   # High temp = Fuzzy exploration
END_TEMP = 0.1     # Low temp = Hard decision (Deployment mode)
ANNEAL_EPOCHS = 20 # Reach end_temp by epoch 20

def get_current_temperature(epoch):
    """Calculates Softmax temperature for the current epoch."""
    if epoch >= ANNEAL_EPOCHS:
        return END_TEMP
    decay = (START_TEMP - END_TEMP) / ANNEAL_EPOCHS
    return START_TEMP - (decay * epoch)

def validate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0

    # Use hard/sharp temperature for validation to mimic real performance
    val_temp = END_TEMP 
    
    with torch.no_grad():
        for batch in val_loader:
            channels_raw = batch[0].to(device)
            
            # Forward
            probs, powers = model(channels_raw, temperature=val_temp)
            
            # Physics World: Real Metrics
            loss = criterion(probs, powers, channels_raw)
            val_loss += loss.item()
            
    avg_loss = val_loss / len(val_loader)
    print(f"   >> Validation Loss: {avg_loss:.4f}")
    return avg_loss

# --- Main Training Loop ---
def train_joint_scheduler(model, train_loader, val_loader, noise_variance):
    print("Initializing Joint Training...")
    
    # Optimizer: Updates both LWM backbone (if unfrozen) and Heads
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # Loss: Joint Utility
    criterion = JointUtilityLoss(alpha_entropy=0.05, alpha_power=0.01, noise_var=noise_variance)
    
    model.to(DEVICE)
    
    for epoch in range(EPOCHS):
        model.train()
        current_temp = get_current_temperature(epoch)
        epoch_loss = 0.0
        
        for batch_idx, batch in enumerate(train_loader):
            channels_raw = batch[0].to(DEVICE)
            
            # 1. Forward Pass ---
            probs, powers = model(channels_raw, temperature=current_temp)
            
            # 2. Loss Calculation ---
            loss = criterion(probs, powers, channels_raw)
            
            # 3. Optimization ---
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient Clipping (Crucial for LWM stability)
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            epoch_loss += loss.item()
            
        # --- Logging & Validation ---
        avg_train_loss = epoch_loss / len(train_loader)
        
        print(f"Epoch {epoch+1}/{EPOCHS} | Temp: {current_temp:.2f} | Train Loss: {avg_train_loss:.4f}")
        
        # Validate every 5 epochs or at the end
        if (epoch + 1) % 5 == 0:
            validate(model, val_loader, criterion, DEVICE)

# --- Usage ---
from thesis.dataset import Tokenizer
from thesis.downstream_models import WrapperAllocation, AssignmentHead, PowerAllocationHead
from thesis.lwm_model import lwm

tokenizer = Tokenizer(4, 4)
lwm_model = lwm.from_pretrained("./models/model.pth")
assignment_head = AssignmentHead()
allocation_head = PowerAllocationHead()
full_model = WrapperAllocation(tokenizer, lwm_model, assignment_head, allocation_head)

train_joint_scheduler(full_model, train_loader, val_loader, noise_variance)

Model loaded successfully from ./models/model.pth
Initializing Joint Training...
Epoch 1/50 | Temp: 5.00 | Train Loss: -1.9597
Epoch 2/50 | Temp: 4.75 | Train Loss: -1.9717
Epoch 3/50 | Temp: 4.51 | Train Loss: -1.9884
Epoch 4/50 | Temp: 4.26 | Train Loss: -2.0043
Epoch 5/50 | Temp: 4.02 | Train Loss: -2.0339
   >> Validation Loss: -5.3532
Epoch 6/50 | Temp: 3.77 | Train Loss: -2.0683
Epoch 7/50 | Temp: 3.53 | Train Loss: -2.1084
Epoch 8/50 | Temp: 3.29 | Train Loss: -2.1662
Epoch 9/50 | Temp: 3.04 | Train Loss: -2.2254
Epoch 10/50 | Temp: 2.79 | Train Loss: -2.3090
   >> Validation Loss: -8.6664
Epoch 11/50 | Temp: 2.55 | Train Loss: -2.4167
Epoch 12/50 | Temp: 2.30 | Train Loss: -2.5548
Epoch 13/50 | Temp: 2.06 | Train Loss: -2.7464
Epoch 14/50 | Temp: 1.81 | Train Loss: -3.0064
Epoch 15/50 | Temp: 1.57 | Train Loss: -3.3839
   >> Validation Loss: -11.1768
Epoch 16/50 | Temp: 1.32 | Train Loss: -3.9520
Epoch 17/50 | Temp: 1.08 | Train Loss: -4.8173
Epoch 18/50 | Temp: 0.83 | Train Lo

In [34]:
import pandas as pd
import torch

class WirelessBenchmark:
    def __init__(self, model, patch_cols, noise_variance=1e-13, total_power=1.0, device="cpu"):
        self.model = model.to(device)
        self.patch_cols = patch_cols
        self.noise = noise_variance
        self.total_power = total_power
        self.device = device
        self.model.eval()
        
        # Removed "Proportional Fair (EPA)"
        self.global_results = {
            "Greedy (EPA)": [], 
            "Round Robin (EPA)": [], 
            "AI Model (Joint)": []
        }

    def _calculate_high_fidelity_rates(self, channels, allocation_map, power_map_per_block):
        # ... (Same Physics Logic as before) ...
        B, K, M, SC = channels.shape
        Num_Blocks = allocation_map.shape[1]
        SC_per_block = SC // Num_Blocks

        # Expand to Subcarriers
        alloc_sc = allocation_map.repeat_interleave(SC_per_block, dim=1)
        power_sc = (power_map_per_block / SC_per_block).repeat_interleave(SC_per_block, dim=1)
        
        # Physics (MRC)
        channel_gains_sc = torch.sum(torch.abs(channels)**2, dim=2)
        winner_gains_sc = torch.gather(channel_gains_sc, 1, alloc_sc.unsqueeze(1)).squeeze(1)
        
        # Rate Calculation
        snr_sc = (power_sc * winner_gains_sc) / self.noise
        rate_sc = torch.log2(1 + snr_sc)
        
        # Aggregation per User
        user_indices = torch.arange(K, device=self.device).view(1, K, 1)
        mask = (alloc_sc.unsqueeze(1) == user_indices).float()
        user_rates = torch.sum(rate_sc.unsqueeze(1) * mask, dim=2) # (B, K)
        
        return user_rates

    def run_baselines(self, channels):
        B, K, M, SC = channels.shape
        num_blocks = SC // self.patch_cols
        
        # Pre-calc block stats
        mag_sq = torch.abs(channels)**2
        gain_freq = torch.mean(mag_sq, dim=2)
        gain_reshaped = gain_freq.view(B, K, num_blocks, self.patch_cols)
        block_gains = torch.mean(gain_reshaped, dim=3)
        
        power_per_block = self.total_power / num_blocks
        epa_map = torch.full((B, num_blocks), power_per_block, device=self.device)
        
        block_snrs = (block_gains * power_per_block) / self.noise
        block_rates = torch.log2(1 + block_snrs)
        
        # 1. Greedy
        greedy_alloc = torch.argmax(block_rates, dim=1)
        self.global_results["Greedy (EPA)"].append(
            self._calculate_high_fidelity_rates(channels, greedy_alloc, epa_map))
            
        # 2. Round Robin
        start_offsets = torch.randint(0, K, (B, 1), device=self.device)
        rr_alloc = (torch.arange(num_blocks, device=self.device).unsqueeze(0) + start_offsets) % K
        self.global_results["Round Robin (EPA)"].append(
            self._calculate_high_fidelity_rates(channels, rr_alloc, epa_map))
            
        # REMOVED: Proportional Fair Logic

    def run_ai(self, channels):
        # ... (Same AI Logic as before) ...
        batch_max = torch.amax(torch.abs(channels), dim=(1,2,3), keepdim=True)
        channels_norm = channels / (batch_max + 1e-12)
        
        with torch.no_grad():
            probs, powers = self.model(channels_norm, temperature=0.01)
            ai_alloc = torch.argmax(probs, dim=1)
            ai_power_raw = torch.gather(powers, 1, ai_alloc.unsqueeze(1)).squeeze(1)
            
            winner_sum = torch.sum(ai_power_raw, dim=1, keepdim=True)
            scale_factor = self.total_power / (winner_sum + 1e-12)
            ai_power_final = ai_power_raw * scale_factor
            
        self.global_results["AI Model (Joint)"].append(
            self._calculate_high_fidelity_rates(channels, ai_alloc, ai_power_final))

    def get_summary(self):
        summary_data = []
        for alg, rates_list in self.global_results.items():
            if not rates_list: continue
            all_rates = torch.cat(rates_list, dim=0) 
            
            # 1. Sum Rate
            avg_sum_rate = torch.mean(torch.sum(all_rates, dim=1)).item()
            
            # 2. Fairness (Global over K users)
            K = all_rates.shape[1] 
            sum_r = torch.sum(all_rates, dim=1)
            sum_r_sq = torch.sum(all_rates**2, dim=1)
            
            jain_samples = (sum_r**2) / (K * sum_r_sq + 1e-12)
            avg_fairness = torch.mean(jain_samples).item()
            
            # 3. Edge Rate (5th Percentile of ALL users)
            edge_rate = torch.quantile(all_rates.float(), 0.05).item()
                
            summary_data.append({
                "Algorithm": alg,
                "Sum Rate": avg_sum_rate,
                "Fairness": avg_fairness,
                "Edge Rate": edge_rate
            })
            
        return pd.DataFrame(summary_data).sort_values("Sum Rate", ascending=False)

def run_benchmark(model, test_loader, patch_cols, noise_var, device):
    print(f"Running Benchmark (Noise={noise_var:.1e})...")
    benchmarker = WirelessBenchmark(model, patch_cols, noise_variance=noise_var, device=device)
    
    for batch in test_loader:
        channels = batch[0].to(device)
        benchmarker.run_baselines(channels)
        benchmarker.run_ai(channels)
        
    summary = benchmarker.get_summary()
    print("\n" + "="*80)
    print("FINAL BENCHMARK RESULTS")
    print("="*80)
    print(summary.to_string(index=False))

run_benchmark(full_model, test_loader, 4, noise_variance, DEVICE)

Running Benchmark (Noise=6.1e-10)...

FINAL BENCHMARK RESULTS
        Algorithm  Sum Rate  Fairness  Edge Rate
     Greedy (EPA) 63.265110  0.140501   0.000000
 AI Model (Joint) 39.401573  0.169112   0.000000
Round Robin (EPA) 26.151699  0.537629   0.000505


In [28]:
import torch
import torch.nn as nn

class SimpleCNNEncoder(nn.Module):
    """
    A standard Deep Learning baseline: 
    Learns features directly from raw complex channels (Real/Imag).
    """
    def __init__(self, num_antennas=1, num_subcarriers=64, embed_dim=128):
        super().__init__()
        
        # Input: (Batch, 2, Antennas, Subcarriers) - 2 for Real/Imag
        self.conv_net = nn.Sequential(
            # Layer 1
            nn.Conv2d(2, 16, kernel_size=(1, 3), padding=(0, 1)),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            
            # Layer 2
            nn.Conv2d(16, 32, kernel_size=(1, 3), padding=(0, 1)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            
            # Layer 3
            nn.Conv2d(32, 64, kernel_size=(1, 3), padding=(0, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            # Global Average Pooling over Frequency/Antennas
            nn.AdaptiveAvgPool2d((1, 1))
        )
        
        # Project to same embedding dimension as LWM (128)
        self.projection = nn.Linear(64, embed_dim)

    def forward(self, channels):
        # channels shape: (B, 1, M, SC) complex
        # Convert to (B, 2, M, SC) real
        x = torch.cat([channels.real, channels.imag], dim=1)
        
        # Feature Extraction
        features = self.conv_net(x) # (B, 64, 1, 1)
        features = features.view(features.size(0), -1) # Flatten -> (B, 64)
        
        # Project
        embeddings = self.projection(features) # (B, 128)
        return embeddings

class RawChannelModel(nn.Module):
    def __init__(self, num_users, num_blocks, embed_dim=128):
        super().__init__()
        
        # 1. The Baseline Encoder (Trainable)
        self.encoder = SimpleCNNEncoder(embed_dim=embed_dim)
        
        # 2. The SAME Heads as the LWM Model
        # (We assume CarrierAllocation class handles the heads internally)
        # Note: We need to replicate the 'CarrierAllocation' logic here manually
        # since CarrierAllocation expects an LWM encoder.
        
        self.patch_cols = 4 # Hardcoded for match
        self.allocation_head = PowerAllocationHead(embed_dim=embed_dim, total_power=1.0)
        
        # Simple Assignment Head (Dot product attention style)
        self.assignment_net = nn.Sequential(
            nn.Linear(embed_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_blocks)
        )

    def forward(self, channels, temperature=1.0):
        # channels: (B, K, M, SC)
        B, K, M, SC = channels.shape
        
        # 1. Encode Each User Independently
        # Flatten batch and users: (B*K, 1, M, SC)
        channels_flat = channels.view(-1, 1, M, SC)
        embeddings_flat = self.encoder(channels_flat)
        
        # Reshape back: (B, K, Embed_Dim)
        user_embeddings = embeddings_flat.view(B, K, -1)
        
        # --- HEADS (Replicating LWM Logic) ---
        
        # 2. Assignment Head
        # (B, K, Embed) -> (B, K, Blocks)
        logits = self.assignment_net(user_embeddings) 
        
        # Softmax over Users (dim 1) for each block
        # "Which user gets block b?"
        probs = F.softmax(logits / temperature, dim=1)
        
        # 3. Power Head (Using the specific class we wrote)
        # Expand embeddings to blocks: (B, K, Blocks, Embed)
        emb_expanded = user_embeddings.unsqueeze(2).expand(-1, -1, logits.shape[2], -1)
        
        # Predict Power
        powers = self.allocation_head(emb_expanded, probs)
        
        return probs, powers
    
raw_model = RawChannelModel(8, 8)

train_joint_scheduler(raw_model, train_loader, val_loader, noise_variance)

Initializing Joint Training...
Epoch 1/50 | Temp: 5.00 | Train Loss: -2.0051
Epoch 2/50 | Temp: 4.75 | Train Loss: -2.0814
Epoch 3/50 | Temp: 4.51 | Train Loss: -2.2339
Epoch 4/50 | Temp: 4.26 | Train Loss: -2.5164
Epoch 5/50 | Temp: 4.02 | Train Loss: -3.0656
   >> Validation Loss: -4.0803
Epoch 6/50 | Temp: 3.77 | Train Loss: -4.1231
Epoch 7/50 | Temp: 3.53 | Train Loss: -5.5099
Epoch 8/50 | Temp: 3.29 | Train Loss: -7.1368
Epoch 9/50 | Temp: 3.04 | Train Loss: -9.1425
Epoch 10/50 | Temp: 2.79 | Train Loss: -11.4419
   >> Validation Loss: -7.9193
Epoch 11/50 | Temp: 2.55 | Train Loss: -13.7310
Epoch 12/50 | Temp: 2.30 | Train Loss: -16.4275
Epoch 13/50 | Temp: 2.06 | Train Loss: -19.2832
Epoch 14/50 | Temp: 1.81 | Train Loss: -22.7844
Epoch 15/50 | Temp: 1.57 | Train Loss: -26.8041
   >> Validation Loss: -0.1292
Epoch 16/50 | Temp: 1.32 | Train Loss: -31.2163
Epoch 17/50 | Temp: 1.08 | Train Loss: -34.5708
Epoch 18/50 | Temp: 0.83 | Train Loss: -37.4593
Epoch 19/50 | Temp: 0.59 | Tra

In [31]:
class WirelessBenchmark:
    def __init__(self, model, noise_variance=1e-11, total_power=1.0, device="cpu"):
        self.model = model.to(device)
        self.patch_cols = 4
        self.noise = noise_variance
        self.total_power = total_power
        self.device = device
        self.model.eval()

        self.global_results = {
            "Greedy (EPA)": [], "Round Robin (EPA)": [], 
            "Proportional Fair (EPA)": [], "AI Model (Joint)": []
        }

    def _calculate_high_fidelity_rates(self, channels, allocation_map, power_map_per_block):
        B, K, M, SC = channels.shape
        Num_Blocks = allocation_map.shape[1]
        SC_per_block = SC // Num_Blocks

        # 1. Expand Block Decisions to Subcarriers
        # (Batch, Num_Blocks) -> (Batch, SC)
        alloc_sc = allocation_map.repeat_interleave(SC_per_block, dim=1)
        
        power_per_sc_val = power_map_per_block / SC_per_block
        power_sc = power_per_sc_val.repeat_interleave(SC_per_block, dim=1)
        
        # 2. Physics: Calculate SNR per Subcarrier (MRC)
        # (Batch, Users, SC)
        channel_gains_sc = torch.sum(torch.abs(channels)**2, dim=2)
        
        # Gather winner gains
        winner_gains_sc = torch.gather(channel_gains_sc, 1, alloc_sc.unsqueeze(1)).squeeze(1)
        
        # 3. Calculate Rate
        snr_sc = (power_sc * winner_gains_sc) / self.noise
        rate_sc = torch.log2(1 + snr_sc)
        
        # 4. Sum Rates per User
        user_indices = torch.arange(K, device=self.device).view(1, K, 1)
        mask = (alloc_sc.unsqueeze(1) == user_indices).float()
        
        # Result: (Batch, K) - Total Rate for each user in the batch
        user_rates = torch.sum(rate_sc.unsqueeze(1) * mask, dim=2)
        
        return user_rates

    def run_baselines(self, channels):
        B, K, M, SC = channels.shape
        num_blocks = SC // self.patch_cols
        
        # --- Pre-calculate Block Metrics for Decisions ---
        # 1. Block Gains
        mag_sq = torch.abs(channels)**2
        gain_freq = torch.mean(mag_sq, dim=2)
        gain_reshaped = gain_freq.view(B, K, num_blocks, self.patch_cols)
        block_gains = torch.mean(gain_reshaped, dim=3)
        
        # 2. EPA Power
        power_per_block = self.total_power / num_blocks
        epa_map = torch.full((B, num_blocks), power_per_block, device=self.device)

        # 3. Decision SNR (Linear)
        block_snrs_est = (block_gains * power_per_block) / self.noise
        
        # 4. Decision Rates (Logarithmic)
        block_rates_est = torch.log2(1 + block_snrs_est)

        # --- Algorithm 1: Greedy ---
        # Maximize Rate on each block
        greedy_alloc = torch.argmax(block_rates_est, dim=1)
        rates_greedy = self._calculate_high_fidelity_rates(channels, greedy_alloc, epa_map)
        self.global_results["Greedy (EPA)"].append(rates_greedy)

        # --- Algorithm 2: Round Robin (Randomized Start) ---
        # Fix: Randomize start index to prevent starvation of users K > N
        start_offsets = torch.randint(0, K, (B, 1), device=self.device)
        block_indices = torch.arange(num_blocks, device=self.device).unsqueeze(0) # (1, Blocks)
        rr_alloc = (block_indices + start_offsets) % K
        
        rates_rr = self._calculate_high_fidelity_rates(channels, rr_alloc, epa_map)
        self.global_results["Round Robin (EPA)"].append(rates_rr)

        # --- Algorithm 3: Proportional Fair (Corrected) ---
        # Metric: Rate_est / Avg_Rate_est
        # We approximate Avg_Rate_est as the mean rate available to the user across all blocks in this snapshot
        user_avg_rate = torch.mean(block_rates_est, dim=2, keepdim=True)
        pf_metric = block_rates_est / (user_avg_rate + 1e-8)
        
        pf_alloc = torch.argmax(pf_metric, dim=1)
        rates_pf = self._calculate_high_fidelity_rates(channels, pf_alloc, epa_map)
        self.global_results["Proportional Fair (EPA)"].append(rates_pf)

    def run_ai(self, channels):
        # Normalize
        user_pwr = torch.mean(torch.abs(channels)**2, dim=(2,3), keepdim=True)
        user_scale = torch.sqrt(user_pwr + 1e-12)
        channels_norm = channels / user_scale
        
        with torch.no_grad():
            probs, powers = self.model(channels_norm, temperature=0.01)
            
            # Assignment
            ai_alloc = torch.argmax(probs, dim=1)
            
            # Power Extraction & Re-normalization
            ai_power_raw = torch.gather(powers, 1, ai_alloc.unsqueeze(1)).squeeze(1)
            winner_sum = torch.sum(ai_power_raw, dim=1, keepdim=True)
            scale = self.total_power / (winner_sum + 1e-12)
            ai_power_map = ai_power_raw * scale
            
        rates_ai = self._calculate_high_fidelity_rates(channels, ai_alloc, ai_power_map)
        self.global_results["AI Model (Joint)"].append(rates_ai)

    def get_summary(self):
        """Computes global statistics across the entire dataset."""
        summary_data = []
        
        for alg_name, rate_list in self.global_results.items():
            if not rate_list: continue
            
            # Concatenate all batches -> (Total_Samples, K)
            all_rates = torch.cat(rate_list, dim=0)
            
            # Mask Ghosts globally
            # We assume a user is "active" if they have non-zero rate potential in general
            # For simplicity, we filter 0.0 rates if they are strictly 0 due to channel 
            # But safer: Filter based on rate > 1e-6 (Effective Zero)
            active_rates = all_rates[all_rates > 1e-6] 
            
            # 1. Sum Rate (Avg per cell)
            # Sum over users, then Mean over samples
            avg_sum_rate = torch.mean(torch.sum(all_rates, dim=1)).item()
            
            # 2. Fairness (Global Jain's)
            # Calculated per sample, then averaged
            sum_r = torch.sum(all_rates, dim=1)
            sum_r_sq = torch.sum(all_rates**2, dim=1)
            # Count active users per sample (users with Rate > 0)
            n_active = torch.sum(all_rates > 1e-6, dim=1)
            
            # Avoid div/0 for empty samples
            valid_mask = n_active > 0
            jain_samples = (sum_r[valid_mask]**2) / (n_active[valid_mask] * sum_r_sq[valid_mask] + 1e-12)
            avg_fairness = torch.mean(jain_samples).item()
            
            # 3. Edge Rate (True Global 5th Percentile)
            edge_rate = torch.quantile(active_rates, 0.05).item()
            
            summary_data.append({
                "Algorithm": alg_name,
                "Sum Rate (bps/Hz)": avg_sum_rate,
                "Fairness (Jain's)": avg_fairness,
                "Edge User Rate": edge_rate
            })
            
        df = pd.DataFrame(summary_data)
        return df.sort_values(by="Sum Rate (bps/Hz)", ascending=False)

def run_benchmark(model, test_loader, patch_cols, noise_var, device):
    print(f"Running Benchmark (Noise={noise_var:.1e})...")
    benchmarker = WirelessBenchmark(model, noise_variance=noise_var, device=device)
    
    for batch in test_loader:
        channels = batch[0].to(device)
        benchmarker.run_baselines(channels)
        benchmarker.run_ai(channels)
        
    summary = benchmarker.get_summary()
    print("\n" + "="*80)
    print("FINAL BENCHMARK RESULTS")
    print("="*80)
    print(summary.to_string(index=False))

run_benchmark(raw_model, test_loader, 4, noise_variance, DEVICE)


Running Benchmark (Noise=6.1e-10)...

FINAL BENCHMARK RESULTS
              Algorithm  Sum Rate (bps/Hz)  Fairness (Jain's)  Edge User Rate
           Greedy (EPA)          63.265110           0.973548       12.656716
       AI Model (Joint)          39.050255           0.968829        5.474275
      Round Robin (EPA)          26.218662           0.538065        0.000505
Proportional Fair (EPA)          25.770891           0.579959        0.000607
