In [1]:
import numpy as np

def WMMSE_BMAC_a(Pt, H, wt, V):
    L = len(wt)
    Nt = [H[l][l].shape[1] for l in range(L)]
    Nr = [H[l][l].shape[0] for l in range(L)]

    Sigma = [None] * L
    A = [None] * L
    B = [None] * L
    Omega = [None] * L
    U = [None] * L
    W = [None] * L
    dAD = [None] * L
    dC = [None] * L

    R = 0
    oldSumRate = 0
    rate = []

    while oldSumRate == 0 or R - oldSumRate > 0.01:
        oldSumRate = R
        A = [np.zeros((Nt[l], Nt[l]), dtype=complex) for l in range(L)]
        B = [np.zeros((Nt[l], Nr[l]), dtype=complex) for l in range(L)]
        R = 0

        for l in range(L):
            Omega[l] = np.eye(Nr[l], dtype=complex)
            for k in range(L):
                if k != l:
                    Omega[l] += H[l][k] @ V[k] @ V[k].conj().T @ H[l][k].conj().T

        for l in range(L):
            temp = H[l][l] @ V[l]
            R += wt[l] * np.log2(np.linalg.det(np.eye(Nr[l]) + temp @ temp.conj().T @ np.linalg.inv(Omega[l]))).real

        rate.append(R)

        for l in range(L):
            H0 = H[l][l]
            HVVhH = H0 @ V[l] @ V[l].conj().T @ H0.conj().T
            U[l] = np.linalg.inv(HVVhH + Omega[l]) @ H0 @ V[l]
            W[l] = np.eye(Nt[l]) + V[l].conj().T @ H0.conj().T @ np.linalg.inv(Omega[l]) @ H0 @ V[l]

        for l in range(L):
            for k in range(L):
                A[l] += wt[k] * H[k][l].conj().T @ U[k] @ W[k] @ U[k].conj().T @ H[k][l]
            B[l] += wt[l] * H[l][l].conj().T @ U[l] @ W[l]

        for l in range(L):
            AV, AD, AU = np.linalg.svd(A[l])
            dAD[l] = AD
            C = AV.conj().T @ B[l] @ B[l].conj().T @ AV
            dC[l] = np.real(np.diag(C))

        mu1 = 1.0
        mu0 = 0.0

        def compute_val(mu):
            return sum(np.sum(dC[l] / (dAD[l] + mu)**2) for l in range(L))

        val = compute_val(mu1)
        while val > Pt:
            mu1 *= 2
            val = compute_val(mu1)

        while mu1 - mu0 > 1e-6:
            mu = (mu1 + mu0) / 2
            val = compute_val(mu)
            if val > Pt:
                mu0 = mu
            else:
                mu1 = mu

        for l in range(L):
            V[l] = np.linalg.inv(mu1 * np.eye(Nt[l]) + A[l]) @ B[l]

    Sigma = [V[l] @ V[l].conj().T for l in range(L)]
    return rate, Sigma


In [None]:
SNR = 28.1
P = 10**(SNR/10)
K = 4  # number of cells
I_k = [1, 1, 1, 1]  # number of users in each cell
n_tx = [3, 3, 3, 3]  # number of antennas at each transmitter
n_rx = [[2], [2], [2], [2]]  # number of antennas at each user in each cell
P_k = [P, P, P, P]
sig_i_k = [[.1], [.1], [.1], [.1]]
d = [[2], [2], [2], [2]]
alpha = [[1], [1], [1], [1]]

# Initialize channel dictionary
H = {}
for k in range(K):  # transmitter cell index
    H[k] = {}
    for l in range(K):  # receiver cell index
        for i in range(I_k[l]):  # user index in cell l
            tx_ant = n_tx[k]
            rx_ant = n_rx[l][i]
            # Channel from transmitter k to user (l, i)
            H[k][(l, i)] = torch.randn(rx_ant, tx_ant)/(2**.5) + 1j*torch.randn(rx_ant, tx_ant)/(2**.5)

max_iter_mu = 1000
tol_mu = 1e-3
max_iter_alg = 1000
tol_alg = 1e-2

In [8]:
H = [
    [np.random.randn(4, 4) + 1j * np.random.randn(4, 4) for _ in range(4)]
    for _ in range(4)
]

V = [np.random.randn(4, 4) + 1j * np.random.randn(4, 4) for _ in range(4)]

Pt = 5

wt = [1, 1, 1, 1]

In [9]:
WMMSE_BMAC_a(Pt, H, wt, V)

([10.95730097590468,
  12.820065579589174,
  14.649889030841809,
  15.170524256748418,
  15.460010593450608,
  15.658149856971763,
  15.815284590286408,
  15.965390239021021,
  16.148680195200484,
  16.4030583835199,
  16.67420755770329,
  16.86964001750872,
  17.010106571504075,
  17.121842658259858,
  17.212668809720977,
  17.286589585656913,
  17.346761783912537,
  17.39580043324324,
  17.43584176561756,
  17.468621676899453,
  17.495529915286035,
  17.51768126105167,
  17.535965732974176,
  17.551098774959335,
  17.56365597932965,
  17.574097535264652,
  17.582797118113067],
 [array([[ 0.47922021-2.42398546e-18j, -0.06266653+2.73460230e-01j,
          -0.14775558+6.37695186e-02j, -0.25794486-2.18324538e-01j],
         [-0.06266653-2.73460230e-01j,  0.60452917+6.94752348e-18j,
           0.23544218+2.09640613e-01j,  0.15664061-1.67972403e-01j],
         [-0.14775558-6.37695186e-02j,  0.23544218-2.09640613e-01j,
           0.16798988-2.57075939e-19j,  0.04716191-1.13804736e-01j],
   