In [16]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from utils import get_parameters, prepare_loaders
from train import train_downstream_model

In [13]:
# Experiment configuration: which dataset to use and its simulation settings.

DATASET = "sim_city_6_miami_jülich"

# The settings are read from the parameter file stored next to the channel data.
parameters = get_parameters(f"../data/{DATASET}/parameters.txt")

# Problem dimensions (the BS_ATENNAS spelling is kept because later cells use it).
N_SAMPLES, USERS, SUBCARRIERS, BS_ATENNAS = (
    parameters[key] for key in ("samples", "users", "subcarriers", "bs_antennas")
)

# Total power and noise variance handed to the training / evaluation routines.
P_TOTAL, NOISE_VARIANCE = parameters["p_total"], parameters["sigma2"]

# Load the stored channels and put their real and imaginary parts on a new
# trailing axis of size 2.
channel_array = np.load(f"../data/{DATASET}/channels.npy")
channel_array = np.stack([channel_array.real, channel_array.imag], axis=-1)
channel_tensor = torch.tensor(channel_array)

train_loader, val_loader, test_loader = prepare_loaders(channel_tensor)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    INFERENCE_DEVICE = torch.device("cuda")
else:
    INFERENCE_DEVICE = torch.device("cpu")

In [18]:
from lwm_model import lwm
from downstream_models import RegressionHead

class DownstreamWrapper(nn.Module):
    """
    Chains the ``lwm`` encoder with a ``RegressionHead``.

    The flattened channel sequence is passed through the encoder and the
    resulting embeddings, together with the power budget, are handed to the
    regression head.

    Args:
        n_carriers: forwarded to ``RegressionHead`` as ``n_carriers``.
        n_users: forwarded to ``RegressionHead`` as ``n_users``.
        n_antennas: forwarded to ``RegressionHead`` as ``n_antennas``.
        d_model: feature width handed to ``RegressionHead``. It is supplied by
            the caller (it is not inferred from the encoder), so it has to
            match the width of the embeddings the encoder emits.

    Output shape: presumably (B, C, K, Nt), as produced by ``RegressionHead``
    (defined elsewhere -- confirm there).
    """

    def __init__(
        self,
        n_carriers: int,
        n_users: int,
        n_antennas: int,
        d_model: int,
    ):
        super().__init__()

        self.encoder = lwm()
        self.n_carriers = n_carriers
        self.n_users = n_users
        self.n_antennas = n_antennas

        self.regressor = RegressionHead(
            d_model=d_model,
            n_carriers=n_carriers,
            n_users=n_users,
            n_antennas=n_antennas,
        )

    def load_weights(self, path, device):
        """
        Load a state dict from ``path`` into the encoder.

        A leading ``"module."`` prefix (as added when a model is saved while
        wrapped in ``nn.DataParallel``) is stripped from every key. Only the
        prefix is removed: occurrences elsewhere in a key, e.g. inside
        ``"submodule.weight"``, are left untouched.

        Args:
            path: checkpoint location passed to ``torch.load``.
            device: ``map_location`` used while loading.
        """
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from a trusted source.
        state_dict = torch.load(path, map_location=device)

        prefix = "module."
        new_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        self.encoder.load_state_dict(new_state_dict)

    def forward(self, channels, p_total):
        """
        Run encoder and regression head on a batch of channels.

        Args:
            channels: 5-D tensor. The first two axes (B, S) are kept and the
                remaining three are flattened into a single feature axis
                before the encoder is called.
            p_total: passed through unchanged to the regression head.

        Returns:
            Whatever the regression head returns; presumably a tensor of
            shape (B, C, K, Nt) -- confirm against ``RegressionHead``.
        """
        # The unpack doubles as a rank check: anything that is not 5-D raises.
        B, S, _, _, _ = channels.shape

        input_lwm = channels.reshape(B, S, -1)

        # The encoder returns a pair; only its first element is used.
        # Embeddings are presumably (B, S, d_model) -- confirm against lwm.
        embeddings, _ = self.encoder(input_lwm)

        power = self.regressor(embeddings, p_total)
        return power


In [21]:
# Train the LWM-based model on growing fractions of the training set and log
# one CSV per fraction.
EPOCHS = 15
LEARNING_RATE = 1e-2
TRAINING_RATIOS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
SUBSET_SEED = 0  # fixed seed so re-running the cell draws the same subsets

# Dedicated, seeded generator: the subsets no longer depend on whatever global
# NumPy random state the kernel happens to be in.
subset_rng = np.random.default_rng(SUBSET_SEED)

for training_ratio in TRAINING_RATIOS:
    LOGFILE = f"../data/{DATASET}/train_lwm_{training_ratio}.csv"

    # Random subset holding `training_ratio` of the training samples (at least one).
    dataset = train_loader.dataset
    n_subset = max(1, int(training_ratio * len(dataset)))
    indices = subset_rng.permutation(len(dataset))
    subset_indices = indices[:n_subset].tolist()
    subset = Subset(dataset, subset_indices)

    fraction_train_loader = DataLoader(subset, batch_size=train_loader.batch_size)

    # A fresh model per ratio so the runs do not share learned weights.
    model_lwm = DownstreamWrapper(SUBCARRIERS, USERS, BS_ATENNAS, 128).to(INFERENCE_DEVICE)

    # NOTE(review): load_weights() is never called here, so unless lwm()
    # restores pretrained weights itself the frozen encoder keeps its initial
    # weights -- confirm this is intended.
    for param in model_lwm.encoder.parameters():
        param.requires_grad = False

    # Only the parameters that are still trainable are handed to the optimizer.
    trainable_params = [p for p in model_lwm.parameters() if p.requires_grad]
    opt_lwm = optim.Adam(trainable_params, lr=LEARNING_RATE)

    scheduler_lwm = ReduceLROnPlateau(opt_lwm, mode='min', factor=0.1, patience=5)

    train_downstream_model(model_lwm, P_TOTAL, NOISE_VARIANCE,
                           fraction_train_loader, val_loader, opt_lwm, scheduler_lwm,
                           EPOCHS, INFERENCE_DEVICE, LOGFILE)


    

100%|██████████| 15/15 [00:07<00:00,  2.06it/s, Train Loss=-379, Validation Loss=-400]
100%|██████████| 15/15 [00:03<00:00,  4.38it/s, Train Loss=-430, Validation Loss=-419]
100%|██████████| 15/15 [00:03<00:00,  3.89it/s, Train Loss=-416, Validation Loss=-418]
100%|██████████| 15/15 [00:04<00:00,  3.42it/s, Train Loss=-421, Validation Loss=-419]
100%|██████████| 15/15 [00:05<00:00,  2.99it/s, Train Loss=-409, Validation Loss=-426]
100%|██████████| 15/15 [00:05<00:00,  2.65it/s, Train Loss=-412, Validation Loss=-426]
100%|██████████| 15/15 [00:06<00:00,  2.34it/s, Train Loss=-435, Validation Loss=-426]
100%|██████████| 15/15 [00:07<00:00,  2.08it/s, Train Loss=-397, Validation Loss=-427]
100%|██████████| 15/15 [00:07<00:00,  1.92it/s, Train Loss=-416, Validation Loss=-419]
100%|██████████| 15/15 [00:08<00:00,  1.78it/s, Train Loss=-420, Validation Loss=-426]


In [27]:
from downstream_models import RegressionHead

class RawDataWrapper(nn.Module):
    """
    Baseline model: the flattened raw channels go straight into a
    ``RegressionHead``; no encoder is involved.

    Args:
        n_carriers: forwarded to ``RegressionHead`` as ``n_carriers``.
        n_users: forwarded to ``RegressionHead`` as ``n_users``.
        n_antennas: forwarded to ``RegressionHead`` as ``n_antennas``.

    Output shape: presumably (B, C, K, Nt), as produced by ``RegressionHead``
    (defined elsewhere -- confirm there).
    """

    def __init__(
        self,
        n_carriers: int,
        n_users: int,
        n_antennas: int,
    ):
        super().__init__()

        self.n_carriers, self.n_users, self.n_antennas = n_carriers, n_users, n_antennas

        # Feature width handed to the head: two values per (antenna, user) pair.
        self.regressor = RegressionHead(
            d_model=2 * n_users * n_antennas,
            n_carriers=n_carriers,
            n_users=n_users,
            n_antennas=n_antennas,
        )

    def forward(self, channels, p_total):
        """
        Flatten the channels per sequence position and run the regression head.

        Args:
            channels: 5-D tensor. The first two axes are kept and the last
                three are merged into one feature axis.
            p_total: passed through unchanged to the regression head.

        Returns:
            The regression head's output; presumably (B, C, K, Nt) -- confirm
            against ``RegressionHead``.
        """
        # Unpacking five sizes rejects any input that is not 5-D.
        batch_size, seq_len, _, _, _ = channels.shape

        flat_channels = channels.reshape(batch_size, seq_len, -1)
        return self.regressor(flat_channels, p_total)


In [28]:
# Train the raw-channel baseline on growing fractions of the training set and
# log one CSV per fraction.
EPOCHS = 15
LEARNING_RATE = 1e-2
TRAINING_RATIOS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
SUBSET_SEED = 0  # fixed seed so re-running the cell draws the same subsets

# Dedicated, seeded generator: the subsets no longer depend on whatever global
# NumPy random state the kernel happens to be in.
subset_rng = np.random.default_rng(SUBSET_SEED)

for training_ratio in TRAINING_RATIOS:
    LOGFILE = f"../data/{DATASET}/train_raw_{training_ratio}.csv"

    # Random subset holding `training_ratio` of the training samples (at least one).
    dataset = train_loader.dataset
    n_subset = max(1, int(training_ratio * len(dataset)))
    indices = subset_rng.permutation(len(dataset))
    subset_indices = indices[:n_subset].tolist()
    subset = Subset(dataset, subset_indices)

    fraction_train_loader = DataLoader(subset, batch_size=train_loader.batch_size)

    # A fresh model per ratio so the runs do not share learned weights.
    model_raw = RawDataWrapper(SUBCARRIERS, USERS, BS_ATENNAS).to(INFERENCE_DEVICE)

    opt_raw = optim.Adam(model_raw.parameters(), lr=LEARNING_RATE)

    # The scheduler must wrap the optimizer that actually trains this model;
    # wrapping a different optimizer would leave opt_raw's learning rate
    # untouched when the validation loss plateaus.
    scheduler_raw = ReduceLROnPlateau(opt_raw, mode='min', factor=0.1, patience=5)

    train_downstream_model(model_raw, P_TOTAL, NOISE_VARIANCE,
                           fraction_train_loader, val_loader, opt_raw, scheduler_raw,
                           EPOCHS, INFERENCE_DEVICE, LOGFILE)


    

100%|██████████| 15/15 [00:00<00:00, 35.43it/s, Train Loss=-271, Validation Loss=-376]
100%|██████████| 15/15 [00:00<00:00, 30.10it/s, Train Loss=-409, Validation Loss=-421]
100%|██████████| 15/15 [00:00<00:00, 22.35it/s, Train Loss=-401, Validation Loss=-424]
100%|██████████| 15/15 [00:00<00:00, 20.64it/s, Train Loss=-429, Validation Loss=-404]
100%|██████████| 15/15 [00:00<00:00, 17.67it/s, Train Loss=-416, Validation Loss=-427]
100%|██████████| 15/15 [00:00<00:00, 15.73it/s, Train Loss=-413, Validation Loss=-425]
100%|██████████| 15/15 [00:01<00:00, 14.43it/s, Train Loss=-414, Validation Loss=-423]
100%|██████████| 15/15 [00:01<00:00, 13.11it/s, Train Loss=-419, Validation Loss=-418]
100%|██████████| 15/15 [00:01<00:00, 12.00it/s, Train Loss=-421, Validation Loss=-425]
100%|██████████| 15/15 [00:01<00:00, 11.17it/s, Train Loss=-416, Validation Loss=-427]


In [29]:
import torch
import math
from typing import Tuple

def compute_precoder_and_sumrate(
    channels: torch.Tensor,
    noise_variance: float,
    p_total: float = 1.0,
    method: str = "mmse",
    reg_eps: float = 1e-6,
    device: torch.device = None,
) -> Tuple[torch.Tensor, torch.Tensor, float]:
    """
    Compute linear precoders and sum-rate for multi-user MISO across carriers.

    For every carrier the precoder is ``W = H^H (H H^H + alpha I)^{-1}``,
    rescaled so that its total power (sum of squared magnitudes of all
    entries) equals ``p_total``.

    Args:
        channels: Tensor of shape (C, K, Nt) -- C carriers, K users, Nt transmit
                  antennas. Real floating point or complex dtype.
        noise_variance: scalar noise variance (sigma^2) at each user receiver.
        p_total: total transmit power per carrier (scalar).
        method: ``"zf"`` (case-insensitive) uses ``alpha = reg_eps``. Any other
                value selects the MMSE / regularized-ZF variant, where
                ``alpha = noise_variance / p_total * mean(diag(H H^H))``,
                clamped from below by ``reg_eps``.
        reg_eps: small regularizer for numerical stability.
        device: torch device for the computation; ``channels.device`` if None.
                ``channels`` is moved to this device when it lives elsewhere.

    Returns:
        precoders: Tensor of shape (C, Nt, K), same dtype as ``channels``.
                   For carrier c, W_c is the (Nt x K) precoding matrix.
        rates_per_carrier: real-valued Tensor of shape (C,), sum-rate per
                   carrier in bits/s/Hz (float32 for complex64 input, float64
                   for complex128 input, input dtype for real input).
        total_sum_rate: Python float, sum over carriers of rates_per_carrier.

    Raises:
        AssertionError: if ``channels`` is not 3-D.
    """
    if device is None:
        device = channels.device

    assert channels.ndim == 3, "channels must be shape (C, K, Nt)"

    # All helper tensors below are created on `device`; keep the channels there
    # as well so that mixing e.g. CPU channels with device="cuda" cannot fail.
    channels = channels.to(device)

    C, K, Nt = channels.shape
    dtype = channels.dtype
    # Rates are real numbers even when the channels are complex.
    real_dtype = channels.real.dtype

    precoders = torch.zeros((C, Nt, K), dtype=dtype, device=device)
    rates = torch.zeros(C, dtype=real_dtype, device=device)

    use_zf = method.lower() == "zf"
    identity = torch.eye(K, dtype=dtype, device=device)  # loop-invariant

    # Loop over carriers (vectorizing is possible but loop is clearer)
    for c in range(C):
        H = channels[c]  # (K, Nt)

        # Conjugate transpose; conj() leaves real tensors unchanged, so the same
        # expression serves real and complex channels.
        H_herm = H.conj().transpose(-1, -2)  # (Nt, K)
        gram = H @ H_herm  # (K, K)

        if use_zf:
            # Zero forcing: only a tiny diagonal load against singular Gram matrices.
            alpha = reg_eps
        else:
            # MMSE heuristic: noise-to-power ratio scaled by the mean diagonal of
            # the Gram matrix (its average eigenvalue).
            avg_eig = torch.diagonal(gram).real.mean()
            alpha = ((noise_variance / (p_total + 1e-12)) * avg_eig).clamp(min=reg_eps)

        loaded_gram = gram + alpha * identity
        try:
            inv = torch.linalg.inv(loaded_gram)
        except RuntimeError:
            # Singular despite the diagonal load: fall back to the pseudo-inverse.
            inv = torch.linalg.pinv(loaded_gram)

        W = H_herm @ inv  # (Nt, K)

        # trace(W W^H) is the sum of squared magnitudes of all entries of W;
        # computing it directly avoids forming the (Nt x Nt) product.
        current_power = (W.abs() ** 2).sum()

        # If current_power == 0 (rare), skip scaling to avoid div0
        if current_power <= 0:
            scale = 0.0
        else:
            scale = math.sqrt(float(p_total) / float(current_power))

        W = W * scale
        precoders[c] = W

        # Entry (k, j) of H W is h_k w_j: the diagonal carries each user's own
        # signal, the rest of row k is the interference seen by user k.
        mag2 = (H @ W).abs() ** 2  # (K, K)
        signal_powers = torch.diagonal(mag2)  # (K,)
        interference_powers = mag2.sum(dim=1) - signal_powers  # sum over j != k

        sinrs = signal_powers / (interference_powers + noise_variance + 1e-12)

        # rates in bits/s/Hz: log2(1 + SINR), summed over the users
        rates[c] = torch.log2(1.0 + sinrs).sum()

    total_sum_rate = float(rates.sum())
    return precoders, rates, total_sum_rate

In [33]:
import torch
import math


def precoder_and_sumrate_batch_realimag(
    dataloader,
    noise_variance: float,
    p_total: float = 1.0,
    method: str = "mmse",
    reg_eps: float = 1e-6,
    device: torch.device = None,
):
    """
    Compute average sum-rate over dataset when channels are given as (B, C, K, Nt, 2).

    Format:
        channels[..., 0] = real part
        channels[..., 1] = imaginary part

    Args:
        dataloader: iterable of batches. A batch is either the channel tensor
            itself or a tuple/list whose first element is the channel tensor.
        noise_variance: forwarded to ``compute_precoder_and_sumrate``.
        p_total: forwarded to ``compute_precoder_and_sumrate``.
        method: forwarded to ``compute_precoder_and_sumrate``.
        reg_eps: forwarded to ``compute_precoder_and_sumrate``.
        device: if given, every batch is moved there before processing.

    Returns:
        avg_sum_rate (float): per-sample total sum-rate, averaged over all
        samples the dataloader yielded.

    Raises:
        ValueError: if the dataloader yields no samples, since the average is
            undefined in that case.
    """
    total_rate = 0.0
    total_samples = 0

    # Pure evaluation: no autograd graph is needed for the closed-form precoder.
    with torch.no_grad():
        for batch in dataloader:
            # Datasets that yield (inputs, ...) tuples keep the channels first.
            channels_ri = batch[0] if isinstance(batch, (tuple, list)) else batch

            if device is not None:
                channels_ri = channels_ri.to(device)

            # Convert (B,C,K,Nt,2) -> complex tensor (B,C,K,Nt)
            channels = torch.complex(channels_ri[..., 0], channels_ri[..., 1])

            # The unpack doubles as a rank check on the complex tensor.
            B, _, _, _ = channels.shape

            for b in range(B):
                _, _, sample_sum_rate = compute_precoder_and_sumrate(
                    channels[b],  # (C, K, Nt)
                    noise_variance=noise_variance,
                    p_total=p_total,
                    method=method,
                    reg_eps=reg_eps,
                    device=device,
                )
                total_rate += sample_sum_rate

            total_samples += B

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute an average sum-rate")

    return total_rate / total_samples

# Closed-form MMSE precoding baseline, evaluated on the validation loader.
avg_rate = precoder_and_sumrate_batch_realimag(
    val_loader,
    NOISE_VARIANCE,
    p_total=P_TOTAL,
    method="mmse",
    device=INFERENCE_DEVICE,
)

print("Average sum-rate:", avg_rate)


Average sum-rate: 698.6423065185547
