In [None]:
def draw_network(position, radius):
    fig, ax = plt.subplots()
    circle = plt.Circle((0, 0), radius, fill=False, color='blue')
    ax.set_aspect('equal', adjustable='box')
    ax.scatter(position[0][0][0], position[0][0][1], color='green')
    ax.scatter(
        [node[0] for node in position[0][1:]],
        [node[1] for node in position[0][1:]],
        color='red'
    )
    ax.add_patch(circle)
    plt.show()

In [None]:
def generate_channels_cell_wireless(num_bs, num_users, num_samples, var_noise=1.0, radius=1):
    # Network: Consisting multiple pairs of Tx and Rx devices, each pair is considered an user.
    # Input:
    #     num_users: Number of users in the network
    #     num_samples: Number of samples using for the model
    #     var_noise: variance of the AWGN
    #     p_min: minimum power for each user
    # Output:
    #     Hs: channel matrices of all users in the network - size num_samples x num_users x num_users
    #        H(i,j) is the channel from Tx of the i-th pair to Rx or the j-th pair
    #     pos: position of all users in the network (?)
    #     pos[:num_bs] is the position of the BS(s)
    #     pos[num_bs:num_bs+num_users] is the position of the user(s)
    #     adj: adjacency matrix of all users in the network - only "1" if interference occurs

    print("Generating Data for training and testing")

    if num_bs != 1:
        raise Exception("Can not generate data for training and testing with more than 1 base station")
    # generate position
    dist_mat = []
    position = []

    # Calculate channel
    CH = 1 / np.sqrt(2) * (np.random.randn(num_samples, 1, num_users)
                           + 1j * np.random.randn(num_samples, 1, num_users))

    if radius == 0:
        Hs = abs(CH)
    else:
        for each_sample in range(num_samples):
            pos = []
            pos_BS = []

            for i in range(num_bs):
                r = 0.2 * radius * (np.random.rand())
                theta = np.random.rand() * 2 * np.pi
                pos_BS.append([r * np.sin(theta), r * np.cos(theta)])
                pos.append([r * np.sin(theta), r * np.cos(theta)])
            pos_user = []

            for i in range(num_users):
                r = 0.5 * radius + 0.5 * radius * np.random.rand()
                theta = np.random.rand() * 2 * np.pi
                pos_user.append([r * np.sin(theta), r * np.cos(theta)])
                pos.append([r * np.sin(theta), r * np.cos(theta)])

            pos = np.array(pos)
            pos_BS = np.array(pos_BS)
            dist_matrix = distance_matrix(pos_BS, pos_user)
            # dist_matrixp = distance_matrix(pos[1:], pos[1:])
            dist_mat.append(dist_matrix)
            position.append(pos)

        dist_mat = np.array(dist_mat)
        position = np.array(position)

        # Calculate Free space pathloss
        f = 6e9
        c = 3e8
        FSPL = 1 / ((4 * np.pi * f * dist_mat / c) ** 2)
        Hs = abs(CH * FSPL)

    adj = adj_matrix(num_users)

    return Hs, position, adj


In [None]:
def wmmse_cell_network(channel_matrix, power_matrix, weight_matrix, P_max, var_noise):
    print("Solving the cell network problem with WMMSE")
    num_user = channel_matrix.shape[2]
    num_BS = channel_matrix.shape[1]
    num_sample = channel_matrix.shape[0]

    # U = np.zeros(num_BS, num_user)
    # V = np.ones(num_BS, num_user)
    # W = np.zeros(num_BS, num_user)

    power = np.sqrt(power_matrix)
    
    all_rx_signal = channel_matrix.transpose(0, 2, 1) @ power
    desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
    desired_power = np.expand_dims(desired_power, axis=1)
    interference = np.square(all_rx_signal)
    interference = np.sum(interference, 2)  # interfernce at each UE => sum of columns
    interference = np.expand_dims(interference, axis=1)
    U = np.divide(desired_power, interference + var_noise)
    W = 1 / (1 - (U * desired_power))
    # The main loop
    count = 1
    while 1:
        # Calculate the V
        all_rx_signal = channel_matrix.transpose(0, 2, 1) @ U
        desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
        desired_power = np.expand_dims(desired_power, axis=1)
        desired_power = weight_matrix * W * desired_power
        interference = np.square(all_rx_signal)
        wei_exp = np.tile(weight_matrix,(1,10,1))
        W_exp = np.tile(W,(1,10,1))
        interference = wei_exp * interference * W_exp
        interference = np.sum(interference, 2)
        interference = np.expand_dims(interference, axis=1)

        V = desired_power / interference
        # print("================================================")
        # print(f'Loop {count}: {V}')
        # setting V for constraints p_max
        V = np.minimum(V, np.sqrt(P_max)) + np.maximum(V, np.zeros(V.shape)) - V
        # print(f'After justified: {V}')

        # Update U and W
        all_rx_signal = channel_matrix.transpose(0, 2, 1) @ V
        desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
        interference = np.square(all_rx_signal)
        interference = np.sum(interference, 2)
        U = np.divide(desired_power, interference + var_noise)
        W = 1 / (1 - (U * desired_power))

        count = count + 1

#         all_power.append(V)
        # Check break condition
        if count == 100:
            break

    # print(f'The total loop: {count}')
    return np.square(V)



In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import numpy as np

K = 1  # number of BS(s)
N = 3  # number of users
R = 0  # radius
p_mtx = np.ones((1, K, N)) * 1
p_max = np.ones((1, K, N)) * 4
all_power = []
X_train = np.array([[[0.49632743, 0.45383659, 0.44659692]]])

weights_matrix = np.array([[[1, 3, 2]]])
power = np.array([[[1, 2, 4]]])
var_noise = np.array([[[0.1, 0.1, 0.1]]])

p_wmmse = wmmse_cell_network(X_train, power, weights_matrix, p_max, var_noise)
print(p_wmmse)

In [None]:
fig, axs = plt.subplots(1, 3)

all_power = np.array(all_power)

axs[0].plot(all_power[:, 0, 0, 0])
axs[0].set_title('Plot 1')

axs[1].plot(all_power[:, 0, 0, 1])
axs[1].set_title('Plot 2')

axs[2].plot(all_power[:, 0, 0, 2])
axs[2].set_title('Plot 3')
# Display the plot
plt.show()

In [None]:
p_max

In [None]:
channel_matrix = np.array([[[0.49632743, 0.45383659, 0.44659692]]])
power_matrix = np.array([[[1, 2, 4]]])
weight_matrix = np.array([[[1, 1, 1]]])
P_max = np.ones((1, K, N)) * 4
var_noise = var_noise


In [None]:
print("Solving the cell network problem with WMMSE")
num_user = channel_matrix.shape[2]
num_BS = channel_matrix.shape[1]
num_sample = channel_matrix.shape[0]

# The WMMSE Approach

In [None]:
power = np.sqrt(power_matrix)

all_rx_signal = channel_matrix.transpose(0, 2, 1) @ power
desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
interference = np.square(all_rx_signal)
interference = np.sum(interference, 2)  # interfernce at each UE => sum of columns
U = np.divide(desired_power, interference + var_noise)
W = 1 / (1 - (U * desired_power))

In [None]:
# Calculate the V
all_rx_signal = channel_matrix.transpose(0, 2, 1) @ U
desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
desired_power = weight_matrix * W * desired_power
interference = np.square(all_rx_signal)
interference = weight_matrix * interference * W
interference = np.sum(interference, 2)

V = desired_power / interference
# print("================================================")
# print(f'Loop {count}: {V}')
# setting V for constraints p_max
V = np.minimum(V, np.sqrt(P_max)) + np.maximum(V, np.zeros(V.shape)) - V
# print(f'After justified: {V}')

In [None]:
# Update U and W
all_rx_signal = channel_matrix.transpose(0, 2, 1) @ V
desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
interference = np.square(all_rx_signal)
interference = np.sum(interference, 2)
U = np.divide(desired_power, interference + var_noise)
W = 1 / (1 - (U * desired_power))

# The direct Approach

In [None]:
h1 = channel_matrix[0,0,0] 
h2 = channel_matrix[0,0,1] 
h3 = channel_matrix[0,0,2] 
p1 = power[0,0,0]
p2 = power[0,0,1]
p3 = power[0,0,2]

In [None]:
i1 = (h1**2) * ((p2**2) + (p3**2) + (p1**2))
i2 = (h2**2) * ((p2**2) + (p3**2) + (p1**2))
i3 = (h3**2) * ((p2**2) + (p3**2) + (p1**2))
print(i1, i2, i3)

In [None]:
u1 = h1 * p1 / (i1 + var_noise[0,0,0])
u2 = h2 * p2 / (i2 + var_noise[0,0,1])
u3 = h3 * p3 / (i3 + var_noise[0,0,2])
print(u1, u2, u3)

In [None]:
w1 = 1/(1 - u1 * h1 * p1)
w2 = 1/(1 - u2 * h2 * p2)
w3 = 1/(1 - u3 * h3 * p3)
print(w1, w2, w3)

In [None]:
a1 = weight_matrix[0,0,0]
a2 = weight_matrix[0,0,1]
a3 = weight_matrix[0,0,2]

In [None]:
v_m1 = np.sqrt(p_max[0,0,0])
v_m2 = np.sqrt(p_max[0,0,1])
v_m3 = np.sqrt(p_max[0,0,2])

In [None]:
v_m3

In [None]:
v1 = (a1 * h1 * u1 * w1)/(h1**2 * (a1 * w1 * (u1**2) + a2 * w2 * (u2**2) + a3 * w3 * (u3**2)))
v2 = (a2 * h2 * u2 * w2)/(h2**2 * (a1 * w1 * (u1**2) + a2 * w2 * (u2**2) + a3 * w3 * (u3**2)))
v3 = (a3 * h3 * u3 * w3)/(h3**2 * (a1 * w1 * (u1**2) + a2 * w2 * (u2**2) + a3 * w3 * (u3**2)))

v1 = v_m1 if v1 > v_m1 else v1
v2 = v_m2 if v2 > v_m2 else v2
v3= v_m3 if v3 > v_m3 else v3

print(v1, v2, v3)

In [None]:
# print(f'desire: {desired_power}')
# print(f'interference: {interference}')

print(f'U matrix: {U}')

print(f'W matrix: {W}')

print(f'V matrix: {V}')

In [None]:
p1 = power[0,0,0]
p2 = power[0,0,1]
p3 = power[0,0,2]

In [None]:
p = np.array([[[p1, p2, p3]]])
v = np.array([[[v1, v2, v3]]])

In [None]:
np.linalg.norm(v - p)

In [None]:
print(p, v)

In [None]:
def wmmse_direct(channel_matrix, power_matrix, weight_matrix, P_max, var_noise):
    h1 = channel_matrix[0,0,0] 
    h2 = channel_matrix[0,0,1] 
    h3 = channel_matrix[0,0,2] 
    p1 = power[0,0,0]
    p2 = power[0,0,1]
    p3 = power[0,0,2]
    
    i1 = (h1**2) * ((p2**2) + (p3**2) + (p1**2))
    i2 = (h2**2) * ((p2**2) + (p3**2) + (p1**2))
    i3 = (h3**2) * ((p2**2) + (p3**2) + (p1**2))
    
    u1 = h1 * p1 / (i1 + var_noise[0,0,0])
    u2 = h2 * p2 / (i2 + var_noise[0,0,1])
    u3 = h3 * p3 / (i3 + var_noise[0,0,2])
    
    w1 = 1/(1 - u1 * h1 * p1)
    w2 = 1/(1 - u2 * h2 * p2)
    w3 = 1/(1 - u3 * h3 * p3)
    
    a1 = weight_matrix[0,0,0]
    a2 = weight_matrix[0,0,1]
    a3 = weight_matrix[0,0,2]
    
    count = 1
    while 1:
        v_m1 = np.sqrt(p_max[0,0,0])
        v_m2 = np.sqrt(p_max[0,0,1])
        v_m3 = np.sqrt(p_max[0,0,2])

        v1 = (a1 * h1 * u1 * w1)/(h1**2 * (a1 * w1 * (u1**2) + a2 * w2 * (u2**2) + a3 * w3 * (u3**2)))
        v2 = (a2 * h2 * u2 * w2)/(h2**2 * (a1 * w1 * (u1**2) + a2 * w2 * (u2**2) + a3 * w3 * (u3**2)))
        v3 = (a3 * h3 * u3 * w3)/(h3**2 * (a1 * w1 * (u1**2) + a2 * w2 * (u2**2) + a3 * w3 * (u3**2)))

        v1 = v_m1 if v1 > v_m1 else v1
        v2 = v_m2 if v2 > v_m2 else v2
        v3 = v_m3 if v3 > v_m3 else v3
        
        V = np.array([v1, v2, v3])
        count = count + 1

#         all_power.append(V)
        # Check break condition
        if count == 100:
            break
        
        i1 = (h1**2) * ((v2**2) + (v3**2) + (v1**2))
        i2 = (h2**2) * ((v2**2) + (v3**2) + (v1**2))
        i3 = (h3**2) * ((v2**2) + (v3**2) + (v1**2))

        u1 = h1 * v1 / (i1 + var_noise[0,0,0])
        u2 = h2 * v2 / (i2 + var_noise[0,0,1])
        u3 = h3 * v3 / (i3 + var_noise[0,0,2])

        w1 = 1/(1 - u1 * h1 * v1)
        w2 = 1/(1 - u2 * h2 * v2)
        w3 = 1/(1 - u3 * h3 * v3)

    # print(f'The total loop: {count}')
    return np.square(V)

        


In [None]:
channel_matrix = np.array([[[0.49632743, 0.45383659, 0.44659692]]])
power_matrix = np.array([[[1, 2, 4]]])
weight_matrix = np.array([[[1, 1, 1]]])
P_max = np.ones((1, K, N)) * 4
var_noise = var_noise


In [None]:
p_wmmse = wmmse_cell_network(channel_matrix, P_max, weight_matrix, P_max, var_noise)
print(p_wmmse)


In [None]:
p_dir = wmmse_direct(X_train, P_max, weights_matrix, p_max, var_noise)
print(p_dir)

In [None]:
def batch_WMMSE2(p_int, alpha, H, Pmax, var_noise):
    print("Solving WMMSE?")
    N = p_int.shape[0]
    K = p_int.shape[1]
    vnew = 0
    b = np.sqrt(p_int)
    f = np.zeros((N,K,1) )
    w = np.zeros( (N,K,1) )

    mask = np.eye(K)
    rx_power = np.multiply(H, b)
    rx_power_s = np.square(rx_power)
    valid_rx_power = np.sum(np.multiply(rx_power, mask), 1)
    
    interference = np.sum(rx_power_s, 2) + var_noise
    f = np.divide(valid_rx_power,interference)
    w = 1/(1-np.multiply(f,valid_rx_power))
    # vnew = np.sum(np.log2(w),1)

    for ii in range(100):
        fp = np.expand_dims(f,1)
        rx_power = np.multiply(H.transpose(0,2,1), fp)
        valid_rx_power = np.sum(np.multiply(rx_power, mask), 1)
        bup = np.multiply(alpha,np.multiply(w,valid_rx_power))
        rx_power_s = np.square(rx_power)
        wp = np.expand_dims(w,1)
        alphap = np.expand_dims(alpha,1)
        bdown = np.sum(np.multiply(alphap,np.multiply(rx_power_s,wp)),2)
        btmp = bup/bdown
        b = np.minimum(btmp, np.ones((N,K) )*np.sqrt(Pmax)) + np.maximum(btmp, np.zeros((N,K) )) - btmp
        
        bp = np.expand_dims(b,1)
        rx_power = np.multiply(H, bp)
        rx_power_s = np.square(rx_power)
        valid_rx_power = np.sum(np.multiply(rx_power, mask), 1)
        interference = np.sum(rx_power_s, 2) + var_noise
        f = np.divide(valid_rx_power,interference)
        w = 1/(1-np.multiply(f,valid_rx_power))
    p_opt = np.square(b)
    return p_opt

In [None]:
p_wmmse_code = batch_WMMSE2(P_max, weights_matrix, X_train, P_max, var_noise)
print(p_wmmse_code)

In [30]:
def generate_channels_cell_wireless(num_bs, num_users, num_samples, var_noise=1.0, radius=1):
    # Network: Consisting multiple pairs of Tx and Rx devices, each pair is considered an user.
    # Input:
    #     num_users: Number of users in the network
    #     num_samples: Number of samples using for the model
    #     var_noise: variance of the AWGN
    #     p_min: minimum power for each user
    # Output:
    #     Hs: channel matrices of all users in the network - size num_samples x num_users x num_users
    #        H(i,j) is the channel from Tx of the i-th pair to Rx or the j-th pair
    #     pos: position of all users in the network (?)
    #     pos[:num_bs] is the position of the BS(s)
    #     pos[num_bs:num_bs+num_users] is the position of the user(s)
    #     adj: adjacency matrix of all users in the network - only "1" if interference occurs

    print("Generating Data for training and testing")

    if num_bs != 1:
        raise Exception("Can not generate data for training and testing with more than 1 base station")
    # generate position
    dist_mat = []
    position = []

    # Calculate channel
    CH = 1 / np.sqrt(2) * (np.random.randn(num_samples, 1, num_users)
                           + 1j * np.random.randn(num_samples, 1, num_users))

    if radius == 0:
        Hs = abs(CH)
    else:
        for each_sample in range(num_samples):
            pos = []
            pos_BS = []

            for i in range(num_bs):
                r = 0.2 * radius * (np.random.rand())
                theta = np.random.rand() * 2 * np.pi
                pos_BS.append([r * np.sin(theta), r * np.cos(theta)])
                pos.append([r * np.sin(theta), r * np.cos(theta)])
            pos_user = []

            for i in range(num_users):
                r = 0.5 * radius + 0.5 * radius * np.random.rand()
                theta = np.random.rand() * 2 * np.pi
                pos_user.append([r * np.sin(theta), r * np.cos(theta)])
                pos.append([r * np.sin(theta), r * np.cos(theta)])

            pos = np.array(pos)
            pos_BS = np.array(pos_BS)
            dist_matrix = distance_matrix(pos_BS, pos_user)
            # dist_matrixp = distance_matrix(pos[1:], pos[1:])
            dist_mat.append(dist_matrix)
            position.append(pos)

        dist_mat = np.array(dist_mat)
        position = np.array(position)

        # Calculate Free space pathloss
        f = 6e9
        c = 3e8
        FSPL = 1 / ((4 * np.pi * f * dist_mat / c) ** 2)
        Hs = abs(CH * FSPL)

    adj = adj_matrix(num_users)

    return Hs, position, adj

def adj_matrix(num_users):
    adj = []
    for i in range(num_users):
        for j in range(num_users):
            if not (i == j):
                adj.append([i, j])
    return np.array(adj)


In [31]:
import numpy as np

K = 1  # number of BS(s)
N = 10  # number of users
R = 0  # radius

num_train = 100  # number of training samples
num_test = 10  # number of test samples

reg = 1e-2
pmax = 1
var_db = 10
var = 1 / 10 ** (var_db / 10)

X_train, pos_train, adj_train = generate_channels_cell_wireless(K, N, num_train, var, R)

Generating Data for training and testing


In [32]:
channel_matrix = X_train
weight_matrix = np.ones((num_train, K, N))
P_max = np.ones((num_train, K, N)) * 4
var_noise = np.ones((num_train, K, N)) * var


In [7]:
power = np.sqrt(P_max)

print(f'channel_matrix: {channel_matrix.shape}')
all_rx_signal = channel_matrix.transpose(0, 2, 1) @ power
print(f'all_rx_signal: {all_rx_signal.shape}')
desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
desired_power = np.expand_dims(desired_power, axis=1)
print(f'desired_power: {desired_power.shape}')
interference = np.square(all_rx_signal)
print(f'interference: {interference.shape}')
interference = np.sum(interference, 2)  # interfernce at each UE => sum of columns
interference = np.expand_dims(interference, axis=1)
print(f'interference: {interference.shape}')
U = np.divide(desired_power, interference + var_noise)
W = 1 / (1 - (U * desired_power))

channel_matrix: (100, 1, 10)
all_rx_signal: (100, 10, 10)
desired_power: (100, 1, 10)
interference: (100, 10, 10)
interference: (100, 1, 10)


In [8]:
# Calculate the V
all_rx_signal = channel_matrix.transpose(0, 2, 1) @ U
print(f'all_rx_signal: {all_rx_signal.shape}')
desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
desired_power = np.expand_dims(desired_power, axis=1)
desired_power = weight_matrix * W * desired_power
print(f'desired_power: {desired_power.shape}')
interference = np.square(all_rx_signal)
print(f'interference: {interference.shape}')
wei_exp = np.tile(weight_matrix,(1,10,1))
W_exp = np.tile(W,(1,10,1))
interference = wei_exp * interference * W_exp
interference = np.sum(interference, 2)
interference = np.expand_dims(interference, axis=1)
print(f'interference: {interference.shape}')


V = desired_power / interference
# print("================================================")
# print(f'Loop {count}: {V}')
# setting V for constraints p_max
V = np.minimum(V, np.sqrt(P_max)) + np.maximum(V, np.zeros(V.shape)) - V
# print(f'After justified: {V}')

all_rx_signal: (100, 10, 10)
desired_power: (100, 1, 10)
interference: (100, 10, 10)
interference: (100, 1, 10)


In [9]:
# Update U and W
all_rx_signal = channel_matrix.transpose(0, 2, 1) @ V
desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
interference = np.square(all_rx_signal)
interference = np.sum(interference, 2)
U = np.divide(desired_power, interference + var_noise)
W = 1 / (1 - (U * desired_power))

In [45]:
def wmmse_cell_network(channel_matrix, power_matrix, weight_matrix, p_max, noise, epsilon=1e-1):
    print("Solving the cell network problem with WMMSE")
    power = np.sqrt(power_matrix)

    all_rx_signal = channel_matrix.transpose(0, 2, 1) @ power
    desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
    desired_power = np.expand_dims(desired_power, axis=1)
    interference = np.square(all_rx_signal)
    interference = np.sum(interference, 2)  # interfernce at each UE => sum of columns
    interference = np.expand_dims(interference, axis=1)
    U = np.divide(desired_power, interference + var_noise)
    W = 1 / (1 - (U * desired_power))
    # The main loop
    count = 1

    while 1:
        # Calculate the V
        V_Prev = power
        all_rx_signal = channel_matrix.transpose(0, 2, 1) @ U
        desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
        desired_power = np.expand_dims(desired_power, axis=1)
        desired_power = weight_matrix * W * desired_power
        interference = np.square(all_rx_signal)
        wei_exp = np.tile(weight_matrix,(1,10,1))
        W_exp = np.tile(W,(1,10,1))
        interference = wei_exp * interference * W_exp
        interference = np.sum(interference, 2)
        interference = np.expand_dims(interference, axis=1)


        V = desired_power / interference

        # setting V for constraints p_max
        V = np.minimum(V, np.sqrt(p_max)) + np.maximum(V, np.zeros(V.shape)) - V

        # Update U and W
        all_rx_signal = channel_matrix.transpose(0, 2, 1) @ V
        desired_power = np.diagonal(all_rx_signal, axis1=1, axis2=2)
        desired_power = np.expand_dims(desired_power, axis=1)
        interference = np.square(all_rx_signal)
        interference = np.sum(interference, 2)
        interference = np.expand_dims(interference, axis=1)
        U = np.divide(desired_power, interference + noise)
        W = 1 / (1 - (U * desired_power))
        
        
        count = count + 1

        # Check break condition
        if np.linalg.norm(V - V_Prev) < epsilon or count == 100:
            break

    # print(f'The total loop: {count}')
    return np.square(V)


In [46]:
channel_matrix = X_train
weight_matrix = np.ones((num_train, K, N))
P_max = np.ones((num_train, K, N)) * 4
var_noise = np.ones((num_train, K, N)) * var

In [47]:
p_wmmse = wmmse_cell_network(X_train, P_max, weight_matrix, P_max, var_noise, epsilon=1e-1)
print(p_wmmse)

Solving the cell network problem with WMMSE
[[[2.90111941e-235 1.54641377e-086 1.05648036e-157 0.00000000e+000
   4.00000000e+000 1.38012634e-283 1.83067075e-244 0.00000000e+000
   1.69954516e-116 2.52736096e-230]]

 [[9.88131292e-324 6.87448908e-261 4.00000000e+000 1.64498325e-285
   3.94920188e-282 5.14271872e-167 0.00000000e+000 4.09285425e-165
   4.20896484e-176 2.14536545e-243]]

 [[2.55438226e-119 6.18995966e-086 8.52514795e-260 6.67987227e-143
   1.44868987e-134 4.00000000e+000 4.36665897e-185 3.20268162e-303
   3.20426052e-043 8.23080550e-227]]

 [[4.09657482e-139 2.61914684e-219 0.00000000e+000 4.00000000e+000
   1.73137668e-277 1.40153103e-186 0.00000000e+000 1.47906838e-290
   9.03025734e-273 0.00000000e+000]]

 [[9.03393111e-304 1.05728036e-137 0.00000000e+000 1.05483015e-320
   3.32731187e-235 4.00000000e+000 4.00000000e+000 0.00000000e+000
   4.75206529e-271 4.91825910e-302]]

 [[8.28366322e-143 4.91231101e-291 7.48897584e-130 0.00000000e+000
   4.00000000e+000 1.78992893

In [None]:
channel_matrix = np.array([[[0.49632743, 0.45383659, 0.44659692]]])
power_matrix = np.array([[[1, 2, 4]]])
weight_matrix = np.array([[[1, 1, 1]]])
P_max = np.ones((1, K, N)) * 4
var_noise = var_noise