In [2]:
import torch

from src.dataset import PowerAllocationDataset

In [3]:
# General parameters
SEED = 42
INFERENCE_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset parameters
NUM_SAMPLES = 5000
NUM_USERS = [4][0]
SCENARIO_NAME = "city_6_miami"
BS_IDX = 1

# Training Parameters
EPOCHS = 15
WARMUP_EPOCHS = 5
LEARNING_RATE = 1e-3
FINETUNE_LEARNING_RATE = 1e-5
TRAINING_RATIOS = [1.0][0]

D_MODEL = 128

optimizer_config = {"task_head_lr": LEARNING_RATE,
                    "encoder_lr": FINETUNE_LEARNING_RATE}

dataset = PowerAllocationDataset(num_samples=NUM_SAMPLES,
                                 num_users=NUM_USERS,
                                 scenario_name=SCENARIO_NAME,
                                 bs_idx=BS_IDX)



Basestation 1

UE-BS Channels


Reading ray-tracing: 100%|██████████| 42984/42984 [00:00<00:00, 133979.64it/s]
Generating channels: 100%|██████████| 42984/42984 [00:03<00:00, 12517.61it/s]
Generating Scenarios: 100%|██████████| 5000/5000 [00:05<00:00, 936.18it/s] 


In [None]:
dataset.

(10441, 16)

In [27]:
from torch import nn
import torch.nn.functional as F

class RefineBlock(nn.Module):
    def __init__(self, filters):
        super().__init__()

        # Layer 1
        self.bn1 = nn.BatchNorm2d(filters)
        self.conv1 = nn.Conv2d(filters, filters, kernel_size=3, padding=1)

        # Layer 2
        self.bn2 = nn.BatchNorm2d(filters)
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, padding=1)

    def forward(self, x):
        shortcut = x

        # Path
        out = self.bn1(x)
        out = F.relu(out)
        out = self.conv1(out)
        
        out = self.bn2(out)
        out = F.relu(out)
        out = self.conv2(out)
        
        # Residual Connection
        out = out + shortcut
        return out

class BeamNetDecoder(nn.Module):
    def __init__(self, num_users, num_carriers, num_antennas, embedding_dim, filters=64, p_max=1.0):
        super(BeamNetDecoder, self).__init__()
        
        self.K = num_users
        self.Nc = num_carriers
        self.Nt = num_antennas
        self.filters = filters
        self.p_max = p_max
        
        # 1. Projection
        # Input: (Batch, K, Emb) -> (Batch, K, Filters * Nc)
        self.user_projection = nn.Linear(embedding_dim, filters * num_carriers)
        
        # 2. RefineNet Core (Same as before)
        self.conv_input = nn.Conv2d(filters, filters, kernel_size=3, padding=1)
        self.refine1 = RefineBlock(filters) # (defined in previous turn)
        self.refine2 = RefineBlock(filters)
        
        # 3. Beam Head (Expansion)
        # We need to output Real and Imag parts for every Antenna
        # Output Channels = 2 * Nt
        self.beam_head = nn.Conv2d(filters, 2 * num_antennas, kernel_size=3, padding=1)

    def forward(self, z):
        B = z.size(0)
        
        # --- A. Expand & Reshape ---
        x = self.user_projection(z)
        x = x.view(B, self.K, self.filters, self.Nc)
        x = x.permute(0, 2, 1, 3) # (Batch, Filters, Users, Carriers)
        
        # --- B. RefineNet (Interference Management) ---
        x = self.conv_input(x)
        x = self.refine1(x)
        x = self.refine2(x)
        
        # --- C. Generate Raw Complex Vectors ---
        # Shape: (Batch, 2*Nt, Users, Carriers)
        raw_beams = self.beam_head(x)
        
        # Separate Real and Imaginary parts
        # Reshape to: (Batch, Users, Carriers, Antennas, 2)
        # 1. Permute to put channels last: (Batch, Users, Carriers, 2*Nt)
        raw_beams = raw_beams.permute(0, 2, 3, 1)
        
        # 2. View to separate Real/Imag: (Batch, Users, Carriers, Nt, 2)
        w = raw_beams.view(B, self.K, self.Nc, self.Nt, 2)
        
        # --- D. Power Constraint (L2 Normalization) ---
        # Calculate Power per beam: |w|^2 = Re^2 + Im^2
        # Sum over Antennas (dim 3) and Re/Im (dim 4)
        beam_power = torch.sum(w ** 2, dim=(3, 4), keepdim=True) # (B, K, Nc, 1, 1)
        
        # Total Power used in the system (sum over Users and Carriers)
        total_system_power = torch.sum(beam_power, dim=(1, 2), keepdim=True)
        
        # Scale factor to satisfy P_max
        # We clamp the divisor to avoid exploding gradients if power is near 0
        scale = torch.sqrt(self.p_max / (total_system_power + 1e-8))
        
        # Apply scaling
        w_final = w * scale
        
        w_final = w_final.permute(0, 2, 1, 3, 4).contiguous()
        
        return w_final
    
class Wrapper(nn.Module):
    def __init__(self,
                 model,
                 task_head):
        super().__init__()

        self.encoder = model
        self.task_head = task_head

        for param in self.encoder.parameters():
            param.requires_grad = False
    
    def fine_tune(self, fine_tune_layers="full"):
        if fine_tune_layers == "full":
             for param in self.encoder.parameters():
                    param.requires_grad = True
        else:
            for name, param in self.encoder.named_parameters():
                if any(layer in name for layer in fine_tune_layers):
                    param.requires_grad = True
         
    def forward(self, tokens):
        """
        Forward function for Wrapper
        
        Inputs:
        channels (torch.tensor): Channel matrix [B, S, K, N]

        Outputs:
        power_weights (torch.tensor): Normalized weights [B, S, K]
        """
        # 1. Extract channel shape
        B, K, S, F = tokens.shape

        # 2. Flatten for Encoder
        # Shape: [B, K, S, F] : [B*K, S, F]
        x = tokens.view(B*K, S, F)

        # 3. Encoder
        # Embeddings shape: [B*K, S, d_model]
        embeddings, _ = self.encoder(x)

        # 4. Extract CLS Token
        # Shape: [B*K, d_model]
        cls_embedding = embeddings[:, 0, :]

        # 5. Reshape to [Batch, Users, D] BEFORE the head
        # This is the key change. We reconstruct the user dimension here.
        cls_structured = cls_embedding.view(B, K, -1)

        # 6. Head Pass
        # The head now takes the structured data and returns normalized power
        W = self.task_head(cls_structured)

        return W

In [1]:
import torch
import torch.nn as nn

class BeamformingSumRateLoss(nn.Module):
    def __init__(self, P_total=1.0, noise_variance=1.0):
        super().__init__()
        self.P_total = P_total
        self.noise_variance = noise_variance

    def forward(self, W_pred, H):
        """
        Inputs:
            W_pred: Predicted Beamforming Vectors [Batch, S, K, N] (Complex)
                    or [Batch, S, K, N, 2] (Real, Imag)
            H:      Channel Matrix [Batch, S, K, N] (Complex)
                    or [Batch, S, K, N, 2] (Real, Imag)
        """
        
        # --- 1. Prepare Complex Tensors ---
        if W_pred.shape[-1] == 2:
            W_complex = torch.view_as_complex(W_pred)
        else:
            W_complex = W_pred
            
        if H.shape[-1] == 2:
            H_complex = torch.view_as_complex(H)
        else:
            H_complex = H

        # --- 2. Enforce System-Wide Power Constraint ---
        # Calculate total power used in the current prediction
        # Sum over: Subcarriers (1), Users (2), Antennas (3)
        # Result shape: [Batch, 1, 1, 1] for broadcasting
        current_total_power = torch.sum(W_complex.abs() ** 2, dim=(1, 2, 3), keepdim=True)
        
        # Scale W to match P_total exactly
        # If the model output is already normalized, this scaling factor will be 1.0 (no change).
        # We add epsilon to avoid division by zero.
        scaling_factor = torch.sqrt(self.P_total / (current_total_power + 1e-8))
        
        W_normalized = W_complex * scaling_factor

        # --- 3. Compute Interaction Matrix (Signal & Interference) ---
        # We want h_k^H * w_j for all k, j.
        # H_complex: [B, S, K, N]
        # W_normalized: [B, S, K, N] -> Transpose to [B, S, N, K]
        
        # Result: [B, S, K, K]
        # Element [b, s, i, j] = Signal at User i from Beam j
        interaction_matrix = torch.matmul(H_complex.conj(), W_normalized.transpose(-2, -1))
        
        # Power = Magnitude Squared
        power_matrix = interaction_matrix.abs() ** 2

        # --- 4. Extract Signal vs Interference ---
        # Diagonal elements (i=j) are the desired signals
        signal_power = torch.diagonal(power_matrix, dim1=-2, dim2=-1) # [B, S, K]
        
        # Sum of rows is Total Received Power at User i
        total_received_power = torch.sum(power_matrix, dim=-1) # [B, S, K]
        
        # Interference = Total - Signal
        interference_power = total_received_power - signal_power
        
        # --- 5. Calculate SINR & Rate ---
        # SINR = Signal / (Interference + Noise)
        sinr = signal_power / (interference_power + self.noise_variance + 1e-10)
        
        # Rate = log2(1 + SINR)
        rate = torch.log2(1 + sinr)

        # --- 6. Sum Rate ---
        # Sum over Users and Subcarriers
        sum_rate_per_sample = torch.sum(rate, dim=(1, 2))
        
        # Average over the batch (Minimize negative rate)
        loss = -torch.mean(sum_rate_per_sample)
        
        return loss

In [None]:
from src.utils import prepare_loaders, get_subset, load_lwm_model
from src.lwm_model import lwm
from src.downstream_models import RegressionHead, Wrapper
from src.metrics import SumRateLoss, benchmark
from src.train import train_downstream_model

P_TOTAL = 1.0
NOISE_VARIANCE = 1e-3

channels = dataset.raw_channels.permute(0, 3, 1, 2) * 1e6

train_loader, val_loader, test_loader = prepare_loaders(channels, dataset.data_tokens, seed=SEED)

results_folder = f"./results/new_test"
fraction_train_loader = get_subset(train_loader, TRAINING_RATIOS, seed=SEED)

lwm_model = load_lwm_model(lwm(), "./models/model.pth", INFERENCE_DEVICE)
task_head = RegressionHead(D_MODEL, num_subcarriers=32)

model = Wrapper(lwm_model, task_head).to(INFERENCE_DEVICE)

criterion = SumRateLoss(P_TOTAL, NOISE_VARIANCE)

model= train_downstream_model(model=model,
                              train_loader=fraction_train_loader,
                              val_loader=val_loader,
                              optimizer_config=optimizer_config,
                              criterion=criterion,
                              epochs=EPOCHS,
                              device=INFERENCE_DEVICE,
                              results_folder=results_folder)

#print(benchmark(model, test_loader, P_TOTAL, NOISE_VARIANCE, INFERENCE_DEVICE))

Loading LWM model...
Model loaded successfully.


  0%|          | 0/15 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (4) must match the size of tensor b (32) at non-singleton dimension 3

In [None]:
from src.utils import prepare_loaders, get_subset, load_lwm_model
from src.lwm_model import lwm
from src.metrics import SumRateLoss, benchmark
from src.train import train_downstream_model

P_TOTAL = 1.0
NOISE_VARIANCE = 1e-3

channels = dataset.raw_channels.permute(0, 3, 1, 2) * 1e6

train_loader, val_loader, test_loader = prepare_loaders(channels, dataset.data_tokens, seed=SEED)

results_folder = f"./results/new_test_csi_net"
fraction_train_loader = get_subset(train_loader, TRAINING_RATIOS, seed=SEED)

lwm_model = load_lwm_model(lwm(), "./models/model.pth", INFERENCE_DEVICE)
task_head = BeamNetDecoder(NUM_USERS, 32, 16, D_MODEL).to(INFERENCE_DEVICE)

model = Wrapper(lwm_model, task_head).to(INFERENCE_DEVICE)

criterion = BeamformingSumRateLoss(NOISE_VARIANCE)

model= train_downstream_model(model=model,
                              train_loader=fraction_train_loader,
                              val_loader=val_loader,
                              optimizer_config=optimizer_config,
                              criterion=criterion,
                              epochs=EPOCHS,
                              device=INFERENCE_DEVICE,
                              results_folder=results_folder)

# print(benchmark(model, test_loader, P_TOTAL, NOISE_VARIANCE, INFERENCE_DEVICE))

Loading LWM model...
Model loaded successfully.


100%|██████████| 15/15 [08:03<00:00, 32.25s/it, Train Loss=-390, Validation Loss=-390, LR=0.001]
Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 5

In [2]:
def calculate_noise_power(bandwidth_ghz, noise_figure_db=9):
    """
    Calculates noise variance (sigma^2) in linear scale (Watts).
    """
    k_B = 1.380649e-23  # Boltzmann constant
    T = 290             # Temperature (Kelvin)
    BW_Hz = bandwidth_ghz * 1e9 # Convert GHz to Hz
    
    # Thermal Noise Density (N0)
    noise_spectral_density = k_B * T 
    
    # Noise Figure in Linear Scale
    noise_figure_linear = 10 ** (noise_figure_db / 10)
    
    # Total Noise Power
    noise_power_watts = noise_spectral_density * BW_Hz * noise_figure_linear
    
    return noise_power_watts
