# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd
import cvxpy as cp
import torch
import torch.optim as optim
import torch.nn as nn

# Reproducibility: fix both the NumPy and the PyTorch global RNGs.
np.random.seed(42)
torch.manual_seed(42)

# Step 1: Define System and Simulation Parameters
# Array / user dimensions
N = 64  # BS antennas
K = 4   # users
M = 4   # RF chains

# Algorithm settings
omega = 0.3  # rate-vs-sensing tradeoff weight
I_max = 10   # outer (unfolded) iterations
J = 10       # inner analog-precoder steps per outer iteration (1, 10 or 20)

# Power and noise
SNR_dB = 12
sigma_n2 = 1.0
P_BS = sigma_n2 * 10**(SNR_dB / 10)  # transmit power derived from the SNR in dB

# Step sizes and channel model
mu = 0.01       # analog precoder step size
lambda_ = 0.01  # digital precoder step size
L = 20          # number of channel paths
num_realizations = 2

# Dataset parameters
num_channels = 10
num_epochs = 3 if J != 1 else 10
snr_min, snr_max = 0, 12  # dB

# Step 2: Define Sensing Parameters
P = 3  # number of desired sensing directions
theta_d = np.array([-60, 0, 60]) * np.pi / 180  # desired angles (radians)
delta_theta = 5 * np.pi / 180  # half beamwidth (radians)
theta_grid = np.linspace(-np.pi / 2, np.pi / 2, 181)  # 181-point grid over [-90, 90] degrees

# Desired beampattern: 1 where a grid angle lies within delta_theta of any
# desired angle, 0 elsewhere (float64, one entry per grid angle).
B_d = np.any(
    np.abs(theta_grid[:, None] - theta_d[None, :]) <= delta_theta, axis=1
).astype(float)

# Wavenumber and antenna spacing (half-wavelength ULA)
lambda_wave = 1  # normalized wavelength
k = 2 * np.pi / lambda_wave
d = lambda_wave / 2

# Step 3: Channel Matrix Generation (Saleh-Valenzuela Model)
def generate_channel(N, M, L, device='cpu'):
    """Draw one multipath channel realization as an (M, N) complex matrix.

    The result is a sum of ``L`` rank-one path terms
    ``sqrt(N*M/L) * alpha * a_r(phi_r) a_t(phi_t)^H``, where
    ``alpha`` is a complex gain built from two standard-normal draws scaled
    by 1/sqrt(2). ``phi_r`` and ``phi_t`` are drawn uniformly from [0, 2*pi).
    Array responses use the module-level ``k`` and ``d`` and are normalized
    by sqrt(M) and sqrt(N).

    Args:
        N: number of columns (transmit antennas).
        M: number of rows of the returned matrix.
        L: number of paths; with L <= 0 the zero matrix is returned.
        device: torch device used for every tensor.

    Returns:
        torch.cfloat tensor of shape (M, N).
    """
    row_idx = torch.arange(M, dtype=torch.float32, device=device)
    col_idx = torch.arange(N, dtype=torch.float32, device=device)
    root_two = torch.sqrt(torch.tensor(2.0, device=device))
    rx_norm = torch.sqrt(torch.tensor(M, dtype=torch.float32, device=device))
    tx_norm = torch.sqrt(torch.tensor(N, dtype=torch.float32, device=device))

    channel = torch.zeros((M, N), dtype=torch.cfloat, device=device)
    for _path in range(L):
        # Per path, the RNG is consumed as: gain (randn(2)), rx angle, tx angle.
        gain = torch.randn(2, device=device).view(torch.cfloat)[0] / root_two
        rx_angle = torch.rand(1, device=device) * 2 * torch.pi
        tx_angle = torch.rand(1, device=device) * 2 * torch.pi
        rx_resp = torch.exp(1j * k * d * row_idx * torch.sin(rx_angle)) / rx_norm
        tx_resp = torch.exp(1j * k * d * col_idx * torch.sin(tx_angle)) / tx_norm
        # Computed inside the loop so that L == 0 never evaluates N * M / L.
        path_scale = torch.sqrt(torch.tensor(N * M / L, dtype=torch.float32, device=device))
        channel += path_scale * gain * torch.outer(rx_resp, tx_resp.conj())
    return channel

# Steering vector function
def steering_vector(theta, N, device='cpu'):
    """Return the normalized ULA steering vector for angle ``theta`` (radians).

    Element n equals exp(1j * k * d * n * sin(theta)) / sqrt(N) for n = 0..N-1.
    It uses the module-level wavenumber ``k`` and spacing ``d``. A non-tensor
    ``theta`` is converted to a float32 tensor on ``device``. A scalar
    ``theta`` yields a length-N complex vector.
    """
    if not torch.is_tensor(theta):
        theta = torch.tensor(theta, dtype=torch.float32, device=device)
    positions = torch.arange(N, dtype=torch.float32, device=device)
    phase = 1j * k * d * positions * torch.sin(theta)
    return torch.exp(phase) / torch.sqrt(torch.tensor(N, dtype=torch.float32, device=device))

# Compute benchmark covariance matrix Psi
def compute_psi(N, Bd, theta_grid, PBS):
    """Fit a benchmark transmit covariance ``Psi`` to the desired beampattern ``Bd``.

    Solves with SCS (eps=1e-3):
        minimize_{Psi, alpha}  sum_t (alpha * Bd[t] - Re(a_t^H Psi a_t))^2
        subject to             diag(Psi) == PBS / N,  Psi >> 0 (Hermitian),  alpha >= 0
    where a_t = exp(1j * pi * n * sin(theta_grid[t])), n = 0..N-1.
    The steering vectors here are NOT divided by sqrt(N).

    Returns:
        ``(Psi, alpha)`` when the solver status is OPTIMAL or OPTIMAL_INACCURATE.
        Otherwise it prints a diagnostic and returns ``((PBS / N) * I, None)``.
        This includes the case where CVXPY raises ``SolverError``.
    """
    def _identity_fallback():
        # Scaled identity that satisfies the diagonal (per-antenna power) constraint.
        return (PBS / N) * np.eye(N, dtype=complex), None

    Abar_grid = np.exp(1j * np.pi * np.arange(N)[:, None] @ np.sin(theta_grid)[None, :])
    psi_var = cp.Variable((N, N), hermitian=True)
    scale_var = cp.Variable(nonneg=True)

    constraints = [
        cp.diag(psi_var) == (PBS / N) * np.ones(N),
        psi_var >> 0,
    ]

    target = Bd.flatten()
    residual_terms = (
        cp.square(scale_var * target[t] - cp.real(cp.quad_form(Abar_grid[:, t], psi_var)))
        for t in range(len(target))
    )
    problem = cp.Problem(cp.Minimize(sum(residual_terms)), constraints)

    try:
        problem.solve(solver=cp.SCS, eps=1e-3)
    except cp.error.SolverError as err:
        print(f"CVXPY Solver Error: {err}")
        print("Falling back to identity matrix.")
        return _identity_fallback()

    if problem.status not in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
        print(f"CVXPY Status: {problem.status}. Problem could not be solved to optimality.")
        return _identity_fallback()
    return psi_var.value, scale_var.value

# Compute communication rate R
def compute_rate(H, A, D, sigma_n2):
    """Return the downlink sum rate (bits/s/Hz) of the hybrid precoder ``A @ D``.

    Let G = conj(H @ A) @ D. Then G[k, j] is the gain of stream j at user k,
    with the same conjugation convention as ``h_k.conj().T @ d_j`` applied to
    the effective channel h_k = (H @ A)[k, :]. The rate is
        SINR_k = |G[k, k]|^2 / (sum_{j != k} |G[k, j]|^2 + sigma_n2)
        R      = sum_k log2(1 + SINR_k).

    Args:
        H: channel with one row per user, shape (K, N); K is taken as H.shape[0].
        A: analog precoder, shape (N, M).
        D: digital precoder, shape (M, >= K); only its first K columns are used.
        sigma_n2: noise variance (Python number or tensor).

    Returns:
        0-dim real tensor, differentiable with respect to ``A`` and ``D``.
    """
    if not torch.is_tensor(sigma_n2):
        sigma_n2 = torch.tensor(sigma_n2, dtype=torch.float32, device=H.device)
    num_users = H.shape[0]
    # One matmul yields every user/stream gain. Adding per-user (1,)-shaped
    # results into a 0-dim tensor with `+=` raises a broadcast RuntimeError.
    gains = (H @ A).conj() @ D[:, :num_users]
    power = gains.abs() ** 2
    signal = power.diagonal()
    own_stream = torch.eye(num_users, dtype=torch.bool, device=power.device)
    # Zero the diagonal explicitly instead of subtracting it from the row sum,
    # which avoids cancellation when the desired signal dominates.
    interference = power.masked_fill(own_stream, 0.0).sum(dim=1)
    sinr = signal / (interference + sigma_n2)
    return torch.log2(1 + sinr).sum()

# Compute sensing error tau
def compute_tau(A, D, Psi, theta_grid):
    """Return the mean squared beampattern mismatch between ``A @ D`` and ``Psi``.

    With V = A @ D and a(theta) = steering_vector(theta, N):
        tau = mean_theta (Re(a^H V V^H a) - Re(a^H Psi a))^2

    Args:
        A: analog precoder, shape (N, M).
        D: digital precoder, shape (M, K).
        Psi: reference covariance, an (N, N) complex tensor on A's device.
        theta_grid: non-empty 1-D array or tensor of angles in radians.

    Returns:
        0-dim real tensor.

    Raises:
        ValueError: if ``theta_grid`` is empty.
    """
    if not torch.is_tensor(theta_grid):
        theta_grid = torch.tensor(theta_grid, dtype=torch.float32, device=A.device)
    if len(theta_grid) == 0:
        raise ValueError("theta_grid must contain at least one angle")
    num_antennas = A.shape[0]
    V = A @ D
    # Columns are steering vectors, one per grid angle: shape (N, T).
    # Working on whole vectors avoids adding (1, 1)-shaped per-angle results
    # into a 0-dim accumulator in place, which raises a broadcast RuntimeError.
    steering = torch.stack(
        [steering_vector(theta, num_antennas, device=A.device) for theta in theta_grid],
        dim=1,
    )
    # Re(a^H V V^H a) = ||V^H a||^2 for every angle at once.
    achieved = (V.conj().transpose(-2, -1) @ steering).abs().pow(2).sum(dim=0)
    reference = (steering.conj() * (Psi @ steering)).sum(dim=0).real
    return ((achieved - reference) ** 2).mean()

def gradient_R_A(H, A, D, sigma_n2):
    """Return the rate gradient used for the analog-precoder ascent step.

    Computes
        sum_k xi * ( Ht_k A V  / (tr(A V  A^H Ht_k) + sigma_n2)
                   - Ht_k A Vb_k / (tr(A Vb_k A^H Ht_k) + sigma_n2) )
    where:
    - xi = 1 / ln 2;
    - Ht_k = h_k h_k^H with h_k = H[k, :] as a column;
    - V = D D^H;
    - Vb_k = Db_k Db_k^H, with Db_k being D with column k zeroed.
    This is the derivative with respect to conj(A) of
    xi * sum_k [ln(tr(A V A^H Ht_k) + s) - ln(tr(A Vb_k A^H Ht_k) + s)].

    Args:
        H: channel with one row per user; the number of users is H.shape[0].
        A: analog precoder, shape (N, M).
        D: digital precoder with at least H.shape[0] columns.
        sigma_n2: noise variance (Python number or tensor).

    Returns:
        Complex tensor with the same shape as ``A``.
    """
    if not torch.is_tensor(sigma_n2):
        sigma_n2 = torch.tensor(sigma_n2, dtype=torch.float32, device=H.device)
    xi = 1 / torch.log(torch.tensor(2.0, dtype=A.dtype, device=A.device))
    A_H = A.conj().transpose(-2, -1)  # loop-invariant
    V = D @ D.conj().transpose(-2, -1)
    grad_A = torch.zeros_like(A, dtype=torch.cfloat)

    # The user count comes from H itself rather than a module-level K.
    for user in range(H.shape[0]):
        h_k = H[user, :].reshape(-1, 1)
        H_tilde_k = h_k @ h_k.conj().transpose(-2, -1)
        D_bar_k = D.clone()
        D_bar_k[:, user] = 0.0  # remove user's own stream -> interference-only covariance
        V_bar_k = D_bar_k @ D_bar_k.conj().transpose(-2, -1)
        denom1 = torch.trace(A @ V @ A_H @ H_tilde_k) + sigma_n2
        denom2 = torch.trace(A @ V_bar_k @ A_H @ H_tilde_k) + sigma_n2
        term1 = H_tilde_k @ A @ V / denom1
        term2 = H_tilde_k @ A @ V_bar_k / denom2
        grad_A += xi * (term1 - term2)
    return grad_A

def gradient_R_D(H, A, D, sigma_n2):
    """Return the rate gradient used for the digital-precoder ascent step.

    Computes
        sum_k xi * ( Hb_k D    / (tr(D D^H Hb_k) + sigma_n2)
                   - Hb_k Db_k / (tr(Db_k Db_k^H Hb_k) + sigma_n2) )
    where:
    - xi = 1 / ln 2;
    - Hb_k = A^H h_k h_k^H A with h_k = H[k, :] as a column;
    - Db_k is D with column k zeroed.

    Args:
        H: channel with one row per user; the number of users is H.shape[0].
        A: analog precoder, shape (N, M).
        D: digital precoder with at least H.shape[0] columns.
        sigma_n2: noise variance (Python number or tensor).

    Returns:
        Complex tensor with the same shape as ``D``.
    """
    if not torch.is_tensor(sigma_n2):
        sigma_n2 = torch.tensor(sigma_n2, dtype=torch.float32, device=H.device)
    xi = 1 / torch.log(torch.tensor(2.0, dtype=A.dtype, device=A.device))
    A_H = A.conj().transpose(-2, -1)  # loop-invariant
    grad_D = torch.zeros_like(D, dtype=torch.cfloat)

    # The user count comes from H itself rather than a module-level K.
    for user in range(H.shape[0]):
        h_k = H[user, :].reshape(-1, 1)
        H_tilde_k = h_k @ h_k.conj().transpose(-2, -1)
        H_bar_k = A_H @ H_tilde_k @ A
        D_bar_k = D.clone()
        D_bar_k[:, user] = 0.0  # remove user's own stream -> interference-only precoder
        denom1 = torch.trace(D @ D.conj().transpose(-2, -1) @ H_bar_k) + sigma_n2
        denom2 = torch.trace(D_bar_k @ D_bar_k.conj().transpose(-2, -1) @ H_bar_k) + sigma_n2
        term1 = (H_bar_k @ D) / denom1
        term2 = (H_bar_k @ D_bar_k) / denom2
        grad_D += xi * (term1 - term2)
    return grad_D

def gradient_tau_A(A, D, Psi):
    """Return 2 (A D D^H A^H - Psi) A D D^H.

    For Hermitian ``Psi`` this is the gradient with respect to conj(A) of
    ||A D D^H A^H - Psi||_F^2.
    """
    D_H = D.conj().transpose(-2, -1)
    A_H = A.conj().transpose(-2, -1)
    mismatch = A @ D @ D_H @ A_H - Psi
    return 2 * mismatch @ A @ D @ D_H

def gradient_tau_D(A, D, Psi):
    """Return 2 A^H (A D D^H A^H - Psi) A D.

    For Hermitian ``Psi`` this is the gradient with respect to conj(D) of
    ||A D D^H A^H - Psi||_F^2.
    """
    A_H = A.conj().transpose(-2, -1)
    mismatch = A @ D @ D.conj().transpose(-2, -1) @ A_H - Psi
    return 2 * A_H @ mismatch @ A @ D

def proposed_initialization(H, theta_d, N, M, K, P_BS):
    """Build a channel-matched analog precoder and a ZF-based digital precoder.

    Steps:
    - A0 takes the negated phases of the first ``M`` columns of ``H.T``.
    - D0 = pinv(A0) @ pinv(H), rescaled so that ||A0 @ D0||_F^2 = P_BS.

    ``theta_d`` is converted to a tensor when it is a NumPy array but is not
    otherwise used. ``N`` and ``K`` are accepted for interface symmetry and
    are unused.

    Returns:
        ``(A0, D0)``. For H of shape (K, N): A0 is (N, M) and D0 is (M, K).
    """
    if isinstance(theta_d, np.ndarray):
        theta_d = torch.tensor(theta_d, dtype=torch.float32, device=H.device)
    A0 = torch.exp(-1j * torch.angle(H.T[:, :M]))
    D0 = torch.linalg.pinv(A0) @ torch.linalg.pinv(H)
    power_root = torch.sqrt(torch.tensor(P_BS, dtype=A0.dtype, device=A0.device))
    return A0, power_root * D0 / torch.norm(A0 @ D0, p='fro')

def random_initialization(N, M, H, P_BS, device='cpu'):
    """Random unit-modulus analog precoder plus pseudo-inverse digital precoder.

    Steps:
    - A0[n, m] = exp(1j * phi) with phi drawn uniformly from [0, 2*pi).
    - D0 = pinv(H @ A0), rescaled so that ||A0 @ D0||_F^2 = P_BS.

    Returns:
        ``(A0, D0)``. A0 is (N, M) complex with |A0| == 1 elementwise.
    """
    # The phases must be real. Drawing them as complex values makes
    # exp(1j * phase) scale each entry by exp(-2*pi*imag), so |A0| != 1.
    phases = torch.rand(N, M, dtype=torch.float32, device=device) * 2 * torch.pi
    A0 = torch.exp(1j * phases)
    D0 = torch.linalg.pinv(H @ A0)
    norm_factor = torch.norm(A0 @ D0, p='fro')
    D0 = torch.sqrt(torch.tensor(P_BS, dtype=A0.dtype, device=A0.device)) * D0 / norm_factor
    return A0, D0

def svd_initialization(H, N, M, K, P_BS, device='cpu'):
    """Initialize from the phases of the channel's leading right singular vectors.

    Steps:
    - A0 takes the phases of the first ``M`` right singular vectors of ``H``
      (the columns of V = Vh^H from ``scipy.linalg.svd``).
    - D0 = pinv(H @ A0), rescaled so that ||A0 @ D0||_F^2 = P_BS.
      If the pseudo-inverse raises, a 1e-6 diagonal loading is retried.

    ``H`` may be a tensor (it is detached before the SVD) or an array-like.
    ``N`` and ``K`` are accepted for interface symmetry and are unused.

    Returns:
        ``(A0, D0)``. For H of shape (K, N): A0 is (N, M) and D0 is (M, K).
    """
    H_np = H.detach().cpu().numpy() if torch.is_tensor(H) else np.asarray(H)
    _, _, Vh = svd(H_np, full_matrices=False)
    # Right singular vectors are the columns of Vh^H (conjugate transpose).
    # Vh.T alone would give their complex conjugates.
    V = Vh.conj().T
    A0 = torch.tensor(V[:, :M], dtype=torch.cfloat, device=device)
    A0 = torch.exp(1j * torch.angle(A0))
    H_t = torch.as_tensor(H, dtype=torch.cfloat, device=device)
    H_A = H_t @ A0
    try:
        D0 = torch.linalg.pinv(H_A)
    except RuntimeError:
        # The loading matrix is shaped like H_A, so it also works when K != M.
        loading = 1e-6 * torch.eye(H_A.shape[0], H_A.shape[1], dtype=torch.cfloat, device=device)
        D0 = torch.linalg.pinv(H_A + loading)
    norm_factor = torch.norm(A0 @ D0, p='fro')
    D0 = torch.sqrt(torch.tensor(P_BS, dtype=A0.dtype, device=A0.device)) * D0 / norm_factor
    return A0, D0

def project_unit_modulus(A):
    """Project every entry of ``A`` onto the unit circle, keeping its phase.

    Zero entries have phase 0 and therefore map to 1.
    """
    phase = torch.angle(A)
    return torch.exp(1j * phase)

def project_power_constraint(D, A, P_BS):
    """Rescale ``D`` so that the hybrid precoder satisfies ||A @ D||_F^2 = P_BS.

    Args:
        D: digital precoder.
        A: analog precoder.
        P_BS: power budget, either a Python number or a tensor. It is read
            with ``torch.as_tensor``, so tensor inputs are not copy-constructed;
            ``torch.tensor`` on a tensor emits a UserWarning.

    Returns:
        The rescaled ``D`` (a new tensor; the input is not modified).
    """
    # Frobenius norm (2-norm of the flattened matrix); a real-valued tensor.
    norm_factor = torch.linalg.norm(A @ D)
    power = torch.as_tensor(P_BS, dtype=norm_factor.dtype, device=D.device)
    return D * (torch.sqrt(power) / norm_factor)

def run_pga(H, A0, D0, J, I_max, mu, lambda_, omega, sigma_n2, Psi, theta_grid, p_bs=None):
    """Run the fixed-step projected gradient ascent baseline on R - omega * tau.

    Each outer iteration does the following:
    - Take ``J`` ascent steps on the analog precoder (step ``mu``, with a
      unit-modulus projection after each).
    - Take one ascent step on the digital precoder (step ``lambda_``, with the
      sensing gradient weighted by eta = 1 / #antennas, then a power projection).
    - Record and print the objective R - omega * tau.

    Args:
        H: channel with one row per user.
        A0: initial analog precoder (N, M); N = A0.shape[0] is the antenna count.
        D0: initial digital precoder.
        J: inner analog steps per outer iteration.
        I_max: number of outer iterations.
        mu, lambda_: analog and digital step sizes.
        omega: rate-vs-sensing tradeoff weight.
        sigma_n2: noise variance.
        Psi: reference covariance for the sensing term.
        theta_grid: angles (radians) at which the sensing error is evaluated.
        p_bs: transmit power budget. When ``None``, the module-level ``P_BS``
            is used, which was the only behaviour before this argument existed.

    Returns:
        List with one objective value (R - omega * tau) per outer iteration.
    """
    power = P_BS if p_bs is None else p_bs
    A = A0.clone()
    D = D0.clone()
    objectives = []
    # eta is 1 / (number of BS antennas). A has one row per antenna; H.shape
    # does not give the antenna count first, because H has one row per user.
    eta = 1 / A.shape[0]

    for i in range(I_max):
        print(f"\n===== Outer Iteration {i+1}/{I_max} =====")
        A_hat = A.clone()
        for _ in range(J):
            grad_A = gradient_R_A(H, A_hat, D, sigma_n2) - omega * gradient_tau_A(A_hat, D, Psi)
            A_hat = project_unit_modulus(A_hat + mu * grad_A)
        A = A_hat
        grad_D = gradient_R_D(H, A, D, sigma_n2) - omega * eta * gradient_tau_D(A, D, Psi)
        D = project_power_constraint(D + lambda_ * grad_D, A, power)
        R = compute_rate(H, A, D, sigma_n2)
        tau = compute_tau(A, D, Psi, theta_grid)
        objective = R - omega * tau
        objectives.append(objective)
        print(f"Iteration {i+1}: R = {R:.4f}, τ = {tau:.4e}, Objective = {objective:.4f}")

    return objectives

class UPGANetLayer(nn.Module):
    """One unfolded projected-gradient-ascent iteration with learnable step sizes.

    Learnable parameters, registered in this order:
    - ``mu``: ``J`` analog step sizes, each initialised to 0.01.
    - ``lambda_``: one digital step size, initialised to 0.01.

    ``forward`` first takes ``J`` projected ascent steps on the analog precoder
    A, with a unit-modulus projection after each. It then takes one projected
    ascent step on the digital precoder D, with a power projection to ``P_BS``.
    """

    def __init__(self, N, M, K, omega, J=10, eta=None):
        super().__init__()
        self.J = J
        self.N, self.M, self.K = N, M, K
        self.omega = omega
        # Weight of the sensing gradient in the D step; 1 / N unless given.
        self.eta = 1 / N if eta is None else eta
        self.mu = nn.Parameter(torch.full((J,), 0.01, dtype=torch.float32))
        self.lambda_ = nn.Parameter(torch.tensor(0.01, dtype=torch.float32))

    def forward(self, H, A, D, Psi, sigma_n2, P_BS):
        """Apply one unfolded iteration and return the updated ``(A, D)``."""
        dev = H.device
        if not torch.is_tensor(sigma_n2):
            sigma_n2 = torch.tensor(sigma_n2, dtype=torch.float32, device=dev)
        if not torch.is_tensor(P_BS):
            P_BS = torch.tensor(P_BS, dtype=torch.float32, device=dev)

        for step in range(self.J):
            ascent_A = gradient_R_A(H, A, D, sigma_n2) - self.omega * gradient_tau_A(A, D, Psi)
            A = project_unit_modulus(A + self.mu[step] * ascent_A)

        ascent_D = gradient_R_D(H, A, D, sigma_n2) - self.omega * self.eta * gradient_tau_D(A, D, Psi)
        D = project_power_constraint(D + self.lambda_ * ascent_D, A, P_BS)
        return A, D

class UPGANet(nn.Module):
    """Deep-unfolded PGA network: ``I_max`` ``UPGANetLayer`` iterations applied in sequence.

    Each layer has its own learnable step sizes, built with the shared
    ``N``, ``M``, ``K``, ``omega`` and ``J`` settings.
    """

    def __init__(self, N, M, K, omega, I_max=120, J=10):
        super().__init__()
        self.layers = nn.ModuleList(
            UPGANetLayer(N, M, K, omega, J=J) for _ in range(I_max)
        )
        self.I_max = I_max
        self.omega = omega

    def forward(self, H, A0, D0, Psi, sigma_n2, P_BS):
        """Run the first ``I_max`` layers from ``(A0, D0)`` and return the final ``(A, D)``."""
        state = (A0, D0)
        for idx in range(self.I_max):
            state = self.layers[idx](H, *state, Psi, sigma_n2, P_BS)
        A_out, D_out = state
        return A_out, D_out

def upganet_loss(H, A, D, Psi, sigma_n2, omega):
    """Return the training loss -(R - omega * tau) for the precoders ``A`` and ``D``.

    The sensing error is evaluated on the module-level ``theta_grid``.
    Non-tensor ``sigma_n2`` and ``omega`` are converted to float32 tensors on
    H's device.
    """
    dev = H.device
    if not torch.is_tensor(sigma_n2):
        sigma_n2 = torch.tensor(sigma_n2, dtype=torch.float32, device=dev)
    if not torch.is_tensor(omega):
        omega = torch.tensor(omega, dtype=torch.float32, device=dev)
    rate = compute_rate(H, A, D, sigma_n2)
    sensing_error = compute_tau(A, D, Psi, theta_grid)
    return -(rate - omega * sensing_error)

# Instantiate model on the selected device, then create its optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UPGANet(N, M, K, omega, I_max=I_max, J=J)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Loop-invariant: desired sensing angles as a tensor.
theta_d_t = torch.tensor(theta_d, dtype=torch.float32, device=device)

for epoch in range(num_epochs):
    total_loss = 0.0
    for _ in range(num_channels):
        # generate_channel(N, rows, L) returns a (rows, N) matrix. With rows = K,
        # this is the (K, N) one-row-per-user layout that proposed_initialization,
        # compute_rate and the gradient functions index. Do not transpose it.
        # NOTE(review): 3 paths are used here although the config defines L = 20;
        # confirm which is intended.
        H = generate_channel(N, K, L=3, device=device)
        snr_db = np.random.uniform(snr_min, snr_max)
        # Per-sample power budget, kept local so the configured global P_BS is not overwritten.
        p_bs = 10 ** (snr_db / 10)
        Psi, alpha_opt = compute_psi(N, B_d, theta_grid, p_bs)
        Psi = np.array(Psi, dtype=np.complex64)
        H_t = H
        Psi_t = torch.tensor(Psi, dtype=torch.cfloat, device=device)
        A0, D0 = proposed_initialization(H_t, theta_d_t, N, M, K, p_bs)
        A_final, D_final = model(H_t, A0, D0, Psi_t, sigma_n2, p_bs)
        loss = upganet_loss(H_t, A_final, D_final, Psi_t, sigma_n2, omega)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {total_loss/num_channels:.6f}")

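# --- Scratch lines pasted from a debugging session. They were indented
# --- fragments that prevented this file from parsing, so they are kept as comments. ---
#   H_t = torch.tensor(H, dtype=torch.cfloat)
#   A0_t = torch.tensor(A0, dtype=torch.cfloat)
#   D0_t = torch.tensor(D0, dtype=torch.cfloat)
#   sigma_n2 = torch.tensor(sigma_n2, dtype=torch.float32, device=H.device)
#   sigma_n2 = torch.tensor(sigma_n2, dtype=torch.float32, device=H.device)
#   D = D * (torch.sqrt(torch.tensor(P_BS, dtype=D.dtype, device=D.device)) / norm_factor)
#   sigma_n2 = torch.tensor(sigma_n2, dtype=torch.float32, device=H.device)
#
# Error observed when running the cell above:
# RuntimeError: output with shape [] doesn't match the broadcast shape [1, 1]
# (PyTorch raises this when an in-place `+=` adds a (1, 1)-shaped tensor to a 0-dim tensor.)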