1. Imports

In [1]:
# Standard library imports
import warnings

# Third-party imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
from tqdm import tqdm

# Local application/library specific imports
import DeepMIMOv3
from LWM_v1_1 import lwm_model
from scenario_props import *
from generate_data import patch_maker, tokenizer
from inference import lwm_inference

warnings.filterwarnings('ignore')

In [2]:
# --- Configuration Constants ---
# Folder passed to DeepMIMOv3 as the 'dataset_folder' parameter.
DATASET_FOLDER = './scenarios'
# Value used for the DeepMIMOv3 'num_paths' parameter.
DEFAULT_NUM_PATHS = 20
# Base-station antenna rotation as (x, y, z) angles in degrees.
DEFAULT_BS_ROTATION = np.array([0, 0, -135])
# Subcarrier spacing in Hz; multiplied by the subcarrier count to get the bandwidth.
DEFAULT_SUBCARRIER_SPACING = 30e3
# First entry of the UE antenna array shape.
DEFAULT_NUM_UE_ANTENNAS = 1

def get_parameters(scenario: str, bs_idx: int = 1) -> dict:
    """Build the DeepMIMOv3 parameter dictionary for one scenario / base station.

    Args:
        scenario: Scenario name, used as the key into the table returned by
            scenario_prop() (e.g., 'city_6_miami_v1'). Everything from the
            first "_v" onward is dropped to form the DeepMIMO scenario name.
        bs_idx: Index of the base station to activate.

    Returns:
        The DeepMIMOv3 default parameters with the dataset folder, scenario,
        active base station, user rows, antenna shapes and OFDM settings
        overridden, intended for DeepMIMOv3.generate_data.
    """
    scenario_table = scenario_prop()
    params = DeepMIMOv3.default_params()

    # --- Base configuration ---
    base_settings = {
        'dataset_folder': DATASET_FOLDER,
        # Head of the name before the first "_v" (whole name if "_v" is absent).
        'scenario': scenario.partition("_v")[0],
        'active_BS': np.array([bs_idx]),
        'enable_BS2BS': False,
        'num_paths': DEFAULT_NUM_PATHS,
    }
    for key, value in base_settings.items():
        params[key] = value

    # --- Scenario-specific configuration ---
    props = scenario_table[scenario]
    n_ant_bs = props['n_ant_bs']
    n_subcarriers = props['n_subcarriers']
    row_spec = props['n_rows']

    # An int means rows [0, n); otherwise the first two entries are [start, end).
    row_bounds = (row_spec,) if isinstance(row_spec, int) else (row_spec[0], row_spec[1])
    params['user_rows'] = np.arange(*row_bounds)

    # --- Antenna configuration ---
    bs_antenna = params['bs_antenna']
    bs_antenna['shape'] = np.array([n_ant_bs, 1])  # [Horizontal, Vertical]
    bs_antenna['rotation'] = DEFAULT_BS_ROTATION
    params['ue_antenna']['shape'] = np.array([DEFAULT_NUM_UE_ANTENNAS, 1])

    # --- OFDM configuration ---
    ofdm = params['OFDM']
    ofdm['subcarriers'] = n_subcarriers
    ofdm['selected_subcarriers'] = np.arange(n_subcarriers)
    # Spacing (Hz) times subcarrier count, converted from Hz to GHz.
    ofdm['bandwidth'] = (DEFAULT_SUBCARRIER_SPACING * n_subcarriers) / 1e9

    return params

def deepmimo_data_cleaning(deepmimo_data: dict, scaling_factor: float = 1e6) -> np.ndarray:
    """Drop users that have no propagation path and scale the remaining channels.

    A user is kept whenever its 'LoS' status differs from -1. In DeepMIMO's
    convention -1 marks a user with no path at all, so users that only have
    NLoS paths are kept as well; this is not a LoS-only filter.

    Args:
        deepmimo_data: Dictionary with a 'user' entry holding 'LoS' (one status
            value per user) and 'channel' (array indexed by user on axis 0).
        scaling_factor: Multiplier applied to the kept channel coefficients so
            their magnitudes are better conditioned for later processing.
            Defaults to 1e6.

    Returns:
        np.ndarray: The channels of the kept users multiplied by
        `scaling_factor`; axis 0 has one entry per kept user (possibly zero).
    """
    user_data = deepmimo_data['user']

    # Boolean mask over users: True where at least one path exists.
    has_path = np.asarray(user_data['LoS']) != -1

    return user_data['channel'][has_path] * scaling_factor

def deepmimo_data_gen(scenario_names: list[str], bs_idxs: list[int] | None = None) -> list[dict]:
    """Generate cleaned DeepMIMO channels for every (scenario, base station) pair.

    Args:
        scenario_names: Scenario names to generate data for.
        bs_idxs: Base station indices to use with each scenario.
            Falls back to [1, 2, 3] when None.

    Returns:
        One dictionary per pair, in scenario-major order, with keys
        'scenario' (a "<name> - BS<idx>" label) and 'channels' (the output of
        deepmimo_data_cleaning for that pair).
    """
    bs_list = [1, 2, 3] if bs_idxs is None else bs_idxs

    # Enumerate all pairs up front so the progress bar knows the total.
    pairs = []
    for name in scenario_names:
        pairs.extend((name, idx) for idx in bs_list)

    print(f"Generating data for {len(pairs)} scenario-BS pairs...")

    generated = []
    for name, idx in tqdm(pairs, desc="Data Generation"):
        # [0] takes the first entry of the generate_data output; only one BS
        # is activated per call — presumably that BS's dataset, TODO confirm.
        first_entry = DeepMIMOv3.generate_data(get_parameters(name, idx))[0]
        generated.append({
            "scenario": f"{name} - BS{idx}",
            "channels": deepmimo_data_cleaning(first_entry),
        })

    return generated

def sample(deepmimo_data: list[dict], N_samples: int, n_users: int) -> list[dict]:
    """Draw random multi-user samples from the generated scenarios.

    For every sample one entry of `deepmimo_data` is picked uniformly at
    random, then `n_users` distinct users are picked from that entry.

    Args:
        deepmimo_data: List of dictionaries with 'scenario' (str) and
                       'channels' (np.ndarray indexed by user on axis 0).
        N_samples: Number of samples to draw.
        n_users: Number of distinct users per sample.

    Returns:
        A list of N_samples dictionaries with 'scenario' and 'channels'
        (np.ndarray of shape (n_users, ...)).
    """
    def _draw_one() -> dict:
        # Pick the scenario first, then the users, using NumPy's global RNG.
        entry = deepmimo_data[np.random.randint(0, len(deepmimo_data))]
        pool = entry["channels"]
        # replace=False guarantees the chosen user indices are unique.
        picked = np.random.choice(pool.shape[0], n_users, replace=False)
        return {"scenario": entry["scenario"], "channels": pool[picked]}

    return [_draw_one() for _ in tqdm(range(N_samples), desc="Sampling")]

def load_lwm_model(model_path: str, device: torch.device) -> nn.Module:
    """Load the pre-trained LWM model and prepare it for inference.

    Args:
        model_path: Path to the saved state dict.
        device: Device the model and the loaded weights are placed on.

    Returns:
        The model in evaluation mode. It is wrapped in nn.DataParallel when
        `device` is a CUDA device and more than one GPU is visible.
    """
    print("Loading LWM model...")
    model = lwm_model.lwm().to(device)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    state_dict = torch.load(model_path, map_location=device)
    # Strip only a *leading* 'module.' (added when saving a DataParallel
    # model). A plain replace() would also mangle keys that merely contain
    # the substring, e.g. 'attn_module.weight'.
    cleaned_state_dict = {
        key.removeprefix("module."): value for key, value in state_dict.items()
    }
    model.load_state_dict(cleaned_state_dict)

    # DataParallel needs the parameters on a GPU, so skip it for a CPU
    # device even when several GPUs are visible.
    if torch.device(device).type == "cuda" and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs for inference.")
        model = nn.DataParallel(model)

    model.eval()  # Set model to evaluation mode
    print("Model loaded successfully.")
    return model
# --- Dataset Generation ---

SCENARIO_NAMES = ["city_6_miami"]
BS_IDXS = [1, 2, 3]
N_SAMPLES = 100
N_USERS = 4

deepmimo_data = deepmimo_data_gen(SCENARIO_NAMES, BS_IDXS)
dataset = sample(deepmimo_data, N_SAMPLES, N_USERS)

# --- Main Feature Extraction Logic ---
# 1. Define configuration and load the model ONCE
INFERENCE_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LWM_MODEL_PATH = "LWM_v1_1/models/model.pth"
lwm_model_instance = load_lwm_model(LWM_MODEL_PATH, INFERENCE_DEVICE)

# 2. Prepare all data for batch inference
all_tokens = []
for item in tqdm(dataset, desc="Preparing data"):
    # Prepare input tokens for the LWM model
    patches = patch_maker(item["channels"], 4, 4)
    tokens = tokenizer(patches)
    all_tokens.append(tokens)

    # Reshape channel data for the next stage (the regressor model)
    # Original shape: (K, 1, M, S) -> (4, 1, 16, 32)
    # Squeezed shape: (K, M, S) -> (4, 16, 32)
    # Transposed shape: (S, K, M) -> (32, 4, 16)
    # Squeeze only the singleton axis 1: a bare squeeze() would also drop the
    # user axis when N_USERS == 1 and break the transpose.
    item["channels"] = item["channels"].squeeze(axis=1).transpose(2, 0, 1)

# 3. Run inference on the entire batch at once
print("Running batch inference...")
# Concatenate the per-sample token tensors along dim 0 so the model is called once.
all_tokens_tensor = torch.cat(all_tokens, dim=0)

# Perform inference ONCE for all samples
all_embeddings_tensor = lwm_inference(lwm_model_instance, all_tokens_tensor, "cls_emb", INFERENCE_DEVICE)

# 4. Regroup the embeddings per sample
print("Assigning embeddings...")
# Reshape embeddings to match the dataset structure: [N_SAMPLES, N_USERS, EMBED_DIM]
# (reshape also accepts a non-contiguous tensor, unlike view)
all_embeddings_tensor = all_embeddings_tensor.reshape(N_SAMPLES, N_USERS, -1)

# 5. Consolidate channels into a single complex tensor
# The channel data was already reshaped and stored as numpy arrays in step 2.
all_channels_tensor = torch.from_numpy(np.stack([d['channels'] for d in dataset])).cfloat()

# 6. Define split sizes and batch size
# The test split takes the remainder so the three sizes always sum to
# N_SAMPLES; three independent int() truncations can fall short, which makes
# random_split raise.
N_TRAIN = int(0.7 * N_SAMPLES)
N_VAL = int(0.2 * N_SAMPLES)
SPLIT = [N_TRAIN, N_VAL, N_SAMPLES - N_TRAIN - N_VAL]
BATCH_SIZE = 32

# 7. Create the TensorDataset and split it
base_dataset = TensorDataset(all_channels_tensor, all_embeddings_tensor)
train_subset, val_subset, test_subset = random_split(base_dataset, SPLIT)

# 8. Create DataLoaders for training, validation, and testing
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE, shuffle=False)

Generating data for 3 scenario-BS pairs...


Data Generation:   0%|          | 0/3 [00:00<?, ?it/s]


Basestation 1

UE-BS Channels


Reading ray-tracing: 100%|██████████| 42984/42984 [00:00<00:00, 157412.33it/s]
Generating channels: 100%|██████████| 42984/42984 [00:02<00:00, 17373.04it/s]
Data Generation:  33%|███▎      | 1/3 [00:05<00:10,  5.02s/it]


Basestation 2

UE-BS Channels


Reading ray-tracing: 100%|██████████| 42984/42984 [00:00<00:00, 197432.40it/s]
Generating channels: 100%|██████████| 42984/42984 [00:02<00:00, 17994.24it/s]
Data Generation:  67%|██████▋   | 2/3 [00:09<00:04,  4.62s/it]


Basestation 3

UE-BS Channels


Reading ray-tracing: 100%|██████████| 42984/42984 [00:00<00:00, 178360.42it/s]
Generating channels: 100%|██████████| 42984/42984 [00:02<00:00, 15708.61it/s]
Data Generation: 100%|██████████| 3/3 [00:13<00:00,  4.53s/it]
Sampling: 100%|██████████| 100/100 [00:00<00:00, 3449.38it/s]


Loading LWM model...
Model loaded successfully.


Preparing data: 100%|██████████| 100/100 [00:00<00:00, 2103.59it/s]

Running batch inference...





Assigning embeddings...


In [3]:
# Small constant added to denominators to avoid division by zero.
EPS = 1e-8
# Natural log of 2 as a 0-dim tensor; dividing a natural-log value by it
# converts nats to bits.
LN2 = torch.tensor(2.0).log()

# -------------------------
# Example regressor model
# -------------------------
class RegressorBeamformer(nn.Module):
    """
    Maps per-user embeddings to beamforming directions and power weights.

    Input:
      embeddings: (B, K, emb_dim)
    Output of forward():
      (u_r, u_i): real and imag parts of the unit-norm directions
        shapes: (B, S, K, M)
      alpha: positive, unnormalized power weights (B, S, K)
    Parameterization: direction (unit-norm complex vector) + power scalars.
    The weights are not scaled to any power budget here, and no precoder is
    formed; both are left to the caller.

    Constructor args:
      emb_dim: size of each user embedding
      K, M, S: sizes of the K, M and S axes of the outputs
      hidden: width of the hidden layers (must be divisible by num_heads)
      num_heads: number of attention heads in the cross-user attention
    """
    def __init__(self, emb_dim, K, M, S, hidden=256, num_heads=4):
        super().__init__()
        self.K = K
        self.M = M
        self.S = S

        # Per-user encoder: shared MLP applied to each user's embedding.
        # We process (B, K, emb_dim) as a sequence of K tokens.
        self.user_encoder = nn.Sequential(
            nn.Linear(emb_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )

        # Cross-user interaction: a single self-attention layer over the K
        # users, followed by a residual connection + LayerNorm in forward().
        self.attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=num_heads, batch_first=True)
        self.attn_ln = nn.LayerNorm(hidden)

        # Heads produce per-subcarrier outputs for each user:
        # - direction head outputs 2*M*S values (real+imag flattened) per user token
        # - power head outputs S scalars per user token
        self.dir_head = nn.Linear(hidden, S * M * 2)  # real+imag
        self.pow_head = nn.Linear(hidden, S)          # raw scalars per subcarrier

        # Small MLP applied after the attention block.
        self.post = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())

    def forward(self, embeddings):
        """Return ((u_r, u_i), alpha) for a batch of embeddings (B, K, emb_dim).

        Raises:
            ValueError: if the K axis of `embeddings` differs from the K the
                model was built with.
        """
        B, K, emb_dim = embeddings.shape
        # Explicit check instead of assert, which is stripped under `python -O`.
        if K != self.K:
            raise ValueError(f"K mismatch: {K} vs {self.K}")

        # encode per user; reshape (not view) so non-contiguous inputs work too
        z = self.user_encoder(embeddings.reshape(B * K, emb_dim))  # (B*K, hidden)
        z = z.reshape(B, K, -1)  # (B, K, hidden)

        # cross-user attention (query=key=value=z)
        attn_out, _ = self.attn(z, z, z, need_weights=False)  # (B,K,hidden)
        z = self.attn_ln(z + attn_out)
        z = self.post(z)  # (B,K,hidden)

        # produce heads
        dir_raw = self.dir_head(z)  # (B,K, S*M*2)
        pow_raw = self.pow_head(z)  # (B,K,S)

        # reshape directions: we want (B, S, K, M, 2)
        dir_raw = dir_raw.view(B, K, self.S, self.M, 2).permute(0, 2, 1, 3, 4)  # (B,S,K,M,2)
        pow_raw = pow_raw.permute(0, 2, 1)  # -> (B, S, K)  (from (B,K,S))

        # separate real/imag
        dir_r = dir_raw[..., 0]  # (B,S,K,M)
        dir_i = dir_raw[..., 1]  # (B,S,K,M)

        # normalize directions to unit-norm per (b,s,k): build the complex
        # vector and divide by its norm over the M axis
        u_complex = torch.complex(dir_r, dir_i)
        norm = torch.linalg.vector_norm(u_complex, dim=-1, keepdim=True)  # (B,S,K,1)
        u = u_complex / (norm + EPS)

        # positive scalars via softplus
        alpha = F.softplus(pow_raw) + EPS  # (B,S,K)

        return (u.real, u.imag), alpha  # directions unit-norm; alpha positive but unscaled

# -------------------------
# Sum-rate computation
# -------------------------
def compute_sumrate_from_directions(u_r, u_i, alpha, H_r, H_i, P_total, noise_var):
    """
    Computes the sum-rate achieved by precoders built from directions and power weights.

    Inputs:
      u_r, u_i: (B, S, K, M) unit-norm directions (real, imag)
      alpha: (B, S, K) positive raw scalars, rescaled here to meet P_total
      H_r, H_i: (B, S, K, M) channel real/imag (user k's row vector)
      P_total: total power per sample; Python/NumPy scalar, 0-dim tensor,
        or tensor with B elements
      noise_var: scalar, or tensor broadcastable to (B, S, K)
    Returns:
      sumrate: (B,) bits per channel use (sum over S and K)
      f_r, f_i: precoder columns real/imag shapes (B,S,K,M)
      p: powers (B,S,K)
    """
    B, S, K, M = u_r.shape

    # Scale alpha to meet the total-power constraint per sample
    alpha_flat = alpha.reshape(B, -1)  # (B, S*K)
    alpha_sum = alpha_flat.sum(dim=-1, keepdim=True)  # (B,1)

    # Normalize P_total to a column so a scalar / 0-dim tensor broadcasts as
    # (1,1) and a per-sample tensor as (B,1). The previous .view(B, 1) failed
    # for 0-dim tensors and NumPy scalars.
    P_total = torch.as_tensor(P_total, dtype=alpha.dtype, device=alpha.device).reshape(-1, 1)

    p_flat = P_total * (alpha_flat / (alpha_sum + EPS))  # (B, S*K)
    p = p_flat.reshape(B, S, K)  # (B,S,K)

    # Construct complex tensors for channel H, direction u, and the precoder.
    # The precoder is deliberately not called `F`, which would shadow the
    # module-level alias for torch.nn.functional.
    H = torch.complex(H_r, H_i)  # (B,S,K,M)
    u = torch.complex(u_r, u_i)  # (B,S,K,M)
    sqrtp = torch.sqrt(p).unsqueeze(-1)  # (B,S,K,1)
    precoder = sqrtp * u  # (B,S,K,M)

    # Pairwise inner products G[b,s,k,j] = sum_m conj(H[b,s,k,m]) * precoder[b,s,j,m]
    G = torch.einsum('bskm,bsjm->bskj', H.conj(), precoder)  # (B,S,K,K)
    power_matrix = G.abs()**2

    # desired signal power: diagonal j==k; everything else in row k is interference
    sig = torch.diagonal(power_matrix, dim1=-2, dim2=-1)  # (B,S,K)
    tot = power_matrix.sum(dim=-1)  # (B,S,K)
    interf = tot - sig

    # noise_var: allow scalar or tensor
    noise = float(noise_var) if isinstance(noise_var, (float, int)) else noise_var

    sinr = sig / (noise + interf + EPS)
    rate = torch.log1p(sinr) / LN2  # log2(1 + SINR): bits/channel use
    sumrate = rate.sum(dim=(1, 2))  # sum over S and K -> (B,)
    return sumrate, precoder.real, precoder.imag, p

In [4]:
# --- Training Configuration ---
N_EPOCHS = 25
LEARNING_RATE = 1e-4
P_TOTAL = 1.0
NOISE_VARIANCE = 1e-3

# --- Model and Optimizer Initialization ---
# Extract dimensions from data
_, S, K, M = all_channels_tensor.shape
_, _, EMBED_DIM = all_embeddings_tensor.shape

# Instantiate the regressor model and move it to the correct device
regressor_model = RegressorBeamformer(emb_dim=EMBED_DIM, K=K, M=M, S=S).to(INFERENCE_DEVICE)

# Setup the optimizer
optimizer = torch.optim.Adam(regressor_model.parameters(), lr=LEARNING_RATE)


def _run_epoch(model, loader, device, p_total, noise_var, optimizer=None):
    """Run one pass over `loader` and return the per-sample mean loss.

    The loss is the negative sum-rate. When `optimizer` is given the model is
    put in training mode and updated after every batch; otherwise it is put
    in evaluation mode and gradients are disabled.

    Batch losses are weighted by batch size, so a short final batch does not
    skew the epoch average. Returns NaN for an empty loader instead of
    raising ZeroDivisionError.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    n_seen = 0
    with torch.set_grad_enabled(is_training):
        for channels_batch, embeddings_batch in loader:
            # Move batch to device
            channels_batch = channels_batch.to(device)
            embeddings_batch = embeddings_batch.to(device)

            # Forward pass: get directions and power scalars from the model
            (u_r, u_i), alpha = model(embeddings_batch)

            sumrate_batch, _, _, _ = compute_sumrate_from_directions(
                u_r, u_i, alpha,
                channels_batch.real, channels_batch.imag,
                p_total, noise_var,
            )

            # The loss is the negative of the sum-rate
            loss = -torch.mean(sumrate_batch)

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = sumrate_batch.shape[0]
            loss_sum += loss.item() * batch_size
            n_seen += batch_size

    return loss_sum / n_seen if n_seen else float("nan")


history = {"train_loss": [], "val_loss": []}
print("Starting training...")

for epoch in range(N_EPOCHS):
    # --- Training Step ---
    avg_train_loss = _run_epoch(regressor_model, train_loader, INFERENCE_DEVICE,
                                P_TOTAL, NOISE_VARIANCE, optimizer)
    history["train_loss"].append(avg_train_loss)

    # --- Validation Step ---
    avg_val_loss = _run_epoch(regressor_model, val_loader, INFERENCE_DEVICE,
                              P_TOTAL, NOISE_VARIANCE)
    history["val_loss"].append(avg_val_loss)

    print(f"Epoch {epoch+1}/{N_EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

Starting training...
Epoch 1/25 | Train Loss: -51.3718 | Val Loss: -52.7987
Epoch 2/25 | Train Loss: -52.5839 | Val Loss: -53.1258
Epoch 3/25 | Train Loss: -53.1126 | Val Loss: -53.3495
Epoch 4/25 | Train Loss: -53.7383 | Val Loss: -53.7634
Epoch 5/25 | Train Loss: -54.3815 | Val Loss: -54.2066
Epoch 6/25 | Train Loss: -54.5436 | Val Loss: -54.7398
Epoch 7/25 | Train Loss: -55.8725 | Val Loss: -55.2034
Epoch 8/25 | Train Loss: -57.6768 | Val Loss: -55.6789
Epoch 9/25 | Train Loss: -58.0199 | Val Loss: -55.9529
Epoch 10/25 | Train Loss: -58.4181 | Val Loss: -56.5114
Epoch 11/25 | Train Loss: -60.4479 | Val Loss: -57.4674
Epoch 12/25 | Train Loss: -61.2471 | Val Loss: -58.2714
Epoch 13/25 | Train Loss: -60.2542 | Val Loss: -59.2976
Epoch 14/25 | Train Loss: -63.6758 | Val Loss: -60.7282
Epoch 15/25 | Train Loss: -65.3671 | Val Loss: -61.5123
Epoch 16/25 | Train Loss: -64.1171 | Val Loss: -62.7029
Epoch 17/25 | Train Loss: -65.3228 | Val Loss: -63.6381
Epoch 18/25 | Train Loss: -71.6271 |