In [6]:
import os
import sys
import json
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [43]:
def alg_dual_link_torch(power_total: float, weight: torch.Tensor, H_links: torch.Tensor,
                  rate_diff: float, Sigma=None, device=None):
    """
    PyTorch version of dual-link algorithm
    
    Args:
        power_total: Total power constraint
        weight: Weight tensor of shape (n_links,)
        H_links: Channel matrix tensor of shape (n_links, n_links, n_rx, n_tx)
        rate_diff: Rate difference threshold for convergence
        Sigma: Initial covariance matrices (optional)
        device: Device to run computation on
    
    Returns:
        sum_rate_list: List of sum rates per iteration
        final_sum_rate: Final sum rate
        rates: Individual link rates
        Sigma: Final covariance matrices
        Sigma_hat: Final dual covariance matrices
    """
    if device is None:
        device = H_links.device
    
    # Move tensors to device
    H_links = H_links.to(device)
    weight = weight.to(device)
    
    #------------------------------------------------
    # find basic parameters
    #------------------------------------------------
    n_links = H_links.shape[0]
    rates = torch.zeros(n_links, device=device)
    Sigma_hat = [None] * n_links
    
    #------------------------------------------------
    # If initial Sigma not given, make it identity
    #------------------------------------------------
    if Sigma is None:
        Sigma = [None] * n_links
        for l_link in range(n_links):
            lt_l = H_links[l_link, l_link].shape[1]
            # power constraint may not be satisfied
            Sigma[l_link] = (power_total / lt_l / n_links * 
                           torch.eye(lt_l, dtype=torch.complex64, device=device))
    
    #---------------------------------------------
    # repeat until rate change is small
    sum_rate_temp = torch.tensor(-float('inf'), device=device)
    sum_rate_list = []
    
    num_epochs = 0
    while num_epochs < 20:
        num_epochs += 1
        # calculate reverse link Sigma_hat's
        power_normalizer = torch.tensor(0.0, device=device)
        sum_rate = torch.tensor(0.0, device=device)
        
        for l_link in range(n_links):
            # calculate forward link interference
            total_Cov_l = torch.eye(H_links[l_link, l_link].shape[0], 
                                  dtype=torch.complex64, device=device)
            
            for k_link in range(n_links):
                total_Cov_l += (H_links[l_link, k_link] @ Sigma[k_link] @ 
                              H_links[l_link, k_link].T.conj())
            
            Omega_l = (total_Cov_l - H_links[l_link, l_link] @ Sigma[l_link] @ 
                      H_links[l_link, l_link].T.conj())
            
            # -------------------------------------------------------------------
            # calculate the rates using log determinant
            hy = torch.logdet(total_Cov_l).real
            hy_x = torch.logdet(Omega_l).real
            rates[l_link] = hy - hy_x
            # -------------------------------------------------------------------
            sum_rate += weight[l_link] * rates[l_link]
            
            Sigma_hat[l_link] = weight[l_link] * (torch.linalg.inv(Omega_l) -
                                                torch.linalg.inv(total_Cov_l))
            power_normalizer += torch.trace(Sigma_hat[l_link]).real
        
        # -------------------------------------------------------------------
        # break out of loop if rate change is small
        sum_rate_list.append((sum_rate / torch.log(torch.tensor(2.0))).item())
        
        # if (sum_rate / torch.log(torch.tensor(2.0)) - 
        #     sum_rate_temp / torch.log(torch.tensor(2.0))) < rate_diff:
        #     break
        
        # --------------------------------------------------------------------
        sum_rate_temp = sum_rate
        
        # Normalize Sigma_hat
        for l_link in range(n_links):
            Sigma_hat[l_link] = Sigma_hat[l_link] * power_total / power_normalizer
        
        # --------------------------------------------------------------------
        # calculate forward link Sigma's
        # --------------------------------------------------------------------
        power_normalizer = torch.tensor(0.0, device=device)
        
        for l_link in range(n_links):
            # calculate forward link interference
            total_Cov_hat_l = torch.eye(H_links[l_link, l_link].shape[1], 
                                      dtype=torch.complex64, device=device)
            
            for k_link in range(n_links):
                total_Cov_hat_l += (H_links[k_link, l_link].T.conj() @ 
                                  Sigma_hat[k_link] @ H_links[k_link, l_link])
            
            Omega_hat_l = (total_Cov_hat_l - H_links[l_link, l_link].T.conj() @ 
                          Sigma_hat[l_link] @ H_links[l_link, l_link])
            
            Sigma[l_link] = weight[l_link] * (torch.linalg.inv(Omega_hat_l) -
                                            torch.linalg.inv(total_Cov_hat_l))
            power_normalizer += torch.trace(Sigma[l_link]).real
        
        # Normalize Sigma
        for l_link in range(n_links):
            Sigma[l_link] = Sigma[l_link] * power_total / power_normalizer
    
    return (sum_rate_list, 
            (sum_rate / torch.log(torch.tensor(2.0))),
            (rates / torch.log(torch.tensor(2.0))),
            Sigma, 
            Sigma_hat)

In [None]:
class ChannelCNN(nn.Module):
    def __init__(self, setup):
        super(ChannelCNN, self).__init__()
        self.L = setup.L            # number of outer and inner links (L x L)
        self.n_rx = setup.n_rx      # receive antennas per channel matrix
        self.n_tx = setup.n_tx      # transmit antennas
        self.d = setup.d            # beamforming rank

        in_channels = 2 * self.L * self.L  # real + imag for each of L x L complex matrices

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64 * self.n_rx * self.n_tx, 256)
        self.fc2 = nn.Linear(256, 2 * self.L * self.n_tx * self.d)  # 2 (real+imag) x L x n_tx x d

    def df_to_tensor(self, x_df):
        """
        Input: DataFrame [B, L], where each cell is a list of L complex matrices [n_rx x n_tx]
        Output: Tensor [B, 2*L*L, n_rx, n_tx]
        """
        B = len(x_df)
        L = self.L
        N_r = self.n_rx
        N_t = self.n_tx

        real_parts = torch.empty((B, L, L, N_r, N_t))
        imag_parts = torch.empty((B, L, L, N_r, N_t))

        for i in range(B):
            for l1 in range(L):
                link_list = x_df.iloc[i, l1]  # list of L complex matrices
                for l2 in range(L):
                    H = link_list[l2]
                    real_parts[i, l1, l2] = H.real
                    imag_parts[i, l1, l2] = H.imag

        # Reshape to [B, 2*L*L, n_rx, n_tx]
        real_flat = real_parts.view(B, L * L, N_r, N_t)
        imag_flat = imag_parts.view(B, L * L, N_r, N_t)
        x_tensor = torch.cat([real_flat, imag_flat], dim=1)
        return x_tensor

    def forward(self, x):
        x = self.conv1(x)
        x = nn.ReLU()(x)
        x = self.conv2(x)
        x = nn.ReLU()(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.fc2(x)

        # Reshape to [B, L, 2, n_tx, d]
        B = x.size(0)
        x = x.view(B, self.L, 2, self.n_tx, self.d)

        real = x[:, :, 0, :, :]
        imag = x[:, :, 1, :, :]
        return torch.complex(real, imag)  # [B, L, n_tx, d]

    def predict(self, x_df: pd.DataFrame):
        """
        Input: DataFrame of shape [B, L], where each cell contains a list of L complex matrices
        Output: Dict[str][str] → maps sample and link index to complex beamforming matrix
        """
        self.eval()
        x_tensor = self.df_to_tensor(x_df)
        with torch.no_grad():
            out = self.forward(x_tensor)  # [B, L, n_tx, d]

        # Convert to nested dict
        B = out.shape[0]
        output_dict = {}
        for i in range(B):
            link_dict = {}
            for l in range(self.L):
                link_dict[str(l)] = out[i, l]  # [n_tx, d]
            output_dict[str(i)] = link_dict
        return output_dict

In [None]:
class ChannelCNNTrainer():
    def __init__(self, model: ChannelCNN, setup, lr=1e-3, lambda_power=1.0):
        self.model = model
        self.setup = setup
        self.lambda_power = lambda_power
        self.optimizer = optim.Adam(model.parameters(), lr=lr)

    def power_penalty(self, V_dict):
        """Compute penalty if Tr(V V†) > Pmax for any link."""
        penalty = 0.0
        for sample in V_dict.values():
            for V in sample.values():
                power = torch.real(torch.trace(V @ V.conj().T))
                overflow = torch.relu(power - self.setup.Pmax)
                penalty += overflow
        return penalty

    def train(self, train_df, dual_link_fn, num_epochs=100, batch_size=16):
        self.model.train()
        num_samples = len(train_df)

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            perm = torch.randperm(num_samples)

            for i in range(0, num_samples, batch_size):
                idx = perm[i:i + batch_size]
                batch_df = train_df.iloc[idx]

                self.optimizer.zero_grad()

                # Step 1: Predict initial beamforming V_init from CNN
                V_init_dict = self.model.predict(batch_df)

                # Step 2: Run dual-link to get final beamforming matrices
                V_final_dict = dual_link_fn(V_init_dict, batch_df)

                # Step 3: Compute loss
                power_loss = self.power_penalty(V_final_dict)
                sum_rate = self.compute_sum_rate(V_final_dict, batch_df)
                loss = self.lambda_power * power_loss - sum_rate

                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

            print(f"Epoch {epoch + 1} | Loss: {epoch_loss:.4f}")

    def compute_sum_rate(self, V_dict, H_df):
        """Sum-rate computation: ∑ log det(I + HVV†H† / noise)"""
        # Dummy version — you should replace with your system's real sum-rate calc
        total = 0.0
        for i in range(len(H_df)):
            for l in range(self.setup.L):
                H = H_df.iloc[i, l]  # channel matrix
                V = V_dict[str(i)][str(l)]  # beamforming
                signal = H @ V
                power_matrix = signal @ signal.conj().T
                rate = torch.logdet(torch.eye(power_matrix.shape[0]) + power_matrix).real
                total += rate
        return total

In [None]:
def sum_rate_loss(H, V, alpha, sig):
    # Calculate sum rate for single cell
    sum_rate = 0
    for k in range(len(H)):
        Nr = H[str(k)].shape[0]
        # Calculate Omega
        S = 0
        for l in range(len(H)):
            if l == k: pass
            else:
                S += H[str(k)] @ V[str(l)] @ V[str(l)].conj().T @ H[str(k)].conj().T
        S += sig[k] * torch.eye(Nr, dtype=torch.cdouble)
        tmp = torch.eye(Nr, dtype=torch.cdouble) + H[str(k)] @ V[str(k)] @ V[str(k)].conj().T @ H[str(k)].conj().T @ torch.linalg.inv(S)
        R = torch.log2(torch.linalg.det(tmp))
        sum_rate += alpha[k] * R
    return sum_rate

In [None]:
class ChannelCNN(nn.Module):
    def __init__(self, setup):
        super(ChannelCNN, self).__init__()
        self.L = setup.L
        self.n_rx = setup.n_rx
        self.n_tx = setup.n_tx
        self.d = setup.d

        in_channels = 2 * self.L * self.L

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64 * self.n_rx * self.n_tx, 256)
        self.fc2 = nn.Linear(256, 2 * self.L * self.n_tx * self.d)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.ReLU()(x)
        x = self.conv2(x)
        x = nn.ReLU()(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.fc2(x)

        # Reshape to [B, L, 2, n_tx, d]
        B = x.size(0)
        x = x.view(B, self.L, 2, self.n_tx, self.d)

        real = x[:, :, 0, :, :]
        imag = x[:, :, 1, :, :]
        return torch.complex(real, imag)  # [B, L, n_tx, d]

    def predict(self, x):
        self.eval()
        with torch.no_grad():
            out = self.forward(x)  # [B, L, n_tx, d]

        return out
    

class ChannelCNNTrainer():
    def __init__(self, model: ChannelCNN, setup, lr=1e-3):
        self.model = model
        self.setup = setup
        self.optimizer = optim.Adam(model.parameters(), lr=lr)

    def train(self, train_list, dual_link_fn, num_epochs=100, batch_size=2):
        self.model.train()
        num_samples = len(train_list)

        def V_to_Sigma(V):
            Sigma = torch.zeros_like(V)
            for i in range(len(V)):
                for j in range(self.setup.L):
                    Sigma[i][j] = V[i, j] @ V[i, j].conj().T
            return Sigma


        def proj_power(Sigma):
            for i in range(len(Sigma)):
                s = 0
                for j in range(self.setup.L):
                    s += torch.trace(Sigma[i, j])
                for j in range(self.setup.L):
                    Sigma[i, j] = (self.setup.PT/s) * Sigma[i, j]
            return Sigma
        
        def sum_rate_loss(Sigma, H_list):
            total = 0
            for i in range(len(H_list)):
                s_rate = 0
                for l in range(H_list[0].shape[0]):
                    Omeg = 0
                    for k in range(H_list[0].shape[0]):
                        if k != l:
                            Omeg += H_list[i][l, k] @ Sigma[k] @ H_list[i][l, k].conj().T
                    Omeg = torch.eye(H_list[0].shape[2]) + Omeg
                    rate = torch.log2(torch.linalg.det(torch.eye(H_list[0].shape[2]) + H_list[i][l, l] @ Sigma[l] @ H_list[i][l, l].conj().T @ torch.linalg.inv(Omeg)))
                    s_rate += rate.real
                total += s_rate
            return total/len(H_list)
        
        def sep_real_imag(x):
            real = x.real
            imag = x.imag
            B, L, _, n_rx, n_tx = real.shape
            real = real.view(B, L * L, n_rx, n_tx)
            imag = imag.view(B, L * L, n_rx, n_tx)
            x_prepared = torch.cat([real, imag], dim=1)
            return x_prepared

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            for i in range(0, num_samples, batch_size):
                batch_list = train_list[i:i + batch_size]
                batch_tensor = torch.stack(batch_list)
                batch_tensor_sep = sep_real_imag(batch_tensor)

                self.optimizer.zero_grad()

                V_init = self.model.forward(batch_tensor_sep)

                Sigma_init = V_to_Sigma(V_init)

                Sigma_proj = proj_power(Sigma_init)

                Sigma_final_list = []

                for j in range(len(Sigma_proj)):
                    _, _, _, Sigma_final, _ = dual_link_fn(power_total=self.setup.PT, weight=torch.ones(self.setup.L), H_links=batch_tensor[j], rate_diff=.001, Sigma=Sigma_proj[j], device=None)

                    Sigma_final_list.append(Sigma_final)

                loss = sum_rate_loss(Sigma_final_list, batch_list)

                loss.backward()

                self.optimizer.step()

                epoch_loss += loss.item()

            print(f"Epoch {epoch + 1} | Loss: {epoch_loss:.4f}")

In [32]:
class setup():
    def __init__(self, L, n_tx, n_rx, d, PT):
        self.L = L
        self.n_rx = n_rx
        self.n_tx = n_tx
        self.d = d
        self.PT = PT

In [49]:
set_up = setup(2, 2, 2, 2, 10)
H_l = [torch.randn(2, 2, 2, 2, dtype=torch.cfloat) for _ in range(1)]
CHCNN = ChannelCNN(set_up)
tr = ChannelCNNTrainer(CHCNN, set_up)
tr.train(train_list=H_l, dual_link_fn=alg_dual_link_torch, num_epochs=100, batch_size=2)

tensor([[[[-0.3946+0.0486j,  0.7317+0.3820j],
          [-0.1764-0.6837j, -0.6579+0.5496j]],

         [[ 1.0405+0.7412j,  0.8382+0.1582j],
          [ 1.2651-1.2669j,  0.0047-0.8160j]]],


        [[[-0.7422-0.3130j,  0.6426+0.6178j],
          [-0.8853+1.5156j,  0.1836-0.7439j]],

         [[ 0.5190+0.5220j,  0.8251+1.7575j],
          [-0.3020-0.4075j, -0.1992+1.0411j]]]])


IndexError: list index out of range

In [22]:
H_l

[tensor([[[-0.4214+0.2341j, -0.7542-0.9747j],
          [-0.2162-0.5214j, -0.4185+0.0505j]],
 
         [[-0.3557-0.3714j,  0.4628-1.7675j],
          [-0.2169+1.1543j,  1.0290+0.3764j]]], dtype=torch.complex128)]

In [23]:
torch.stack(H_l)

tensor([[[[-0.4214+0.2341j, -0.7542-0.9747j],
          [-0.2162-0.5214j, -0.4185+0.0505j]],

         [[-0.3557-0.3714j,  0.4628-1.7675j],
          [-0.2169+1.1543j,  1.0290+0.3764j]]]], dtype=torch.complex128)

In [None]:
def sep_real_imag(x):
    real = x.real
    imag = x.imag
    B, L, _, n_rx, n_tx = real.shape
    real = real.view(B, L * L, n_rx, n_tx)
    imag = imag.view(B, L * L, n_rx, n_tx)
    x_prepared = torch.cat([real, imag], dim=1)
    return x_prepared

In [14]:
x = torch.randn(2, 3, 3, 4, 2, dtype=torch.complex64)
x

tensor([[[[[-0.8338+1.0405e+00j,  0.4757+3.4996e-01j],
           [-0.0303+4.2392e-02j,  0.1230-1.2014e+00j],
           [ 0.3458+5.6647e-01j, -0.4790+4.3923e-02j],
           [-0.5486+4.2100e-01j,  0.1023-4.6791e-01j]],

          [[-0.2507-2.4065e-01j,  0.3092-1.5110e-01j],
           [-1.1487-3.9568e-01j,  0.9772-1.1364e-01j],
           [-0.9313-2.0152e+00j,  0.5331-3.6845e-01j],
           [-0.3436+5.6671e-01j,  0.0743+1.4087e-01j]],

          [[ 0.2113+1.0401e-01j,  0.2039-1.6424e-02j],
           [ 0.3469+7.0318e-01j, -0.6420+2.9779e-03j],
           [-0.5378-1.5603e+00j, -0.6760+7.2648e-01j],
           [-0.7727+2.3655e-01j,  0.3408-7.9605e-01j]]],


         [[[-0.2264+2.4068e-01j, -1.1467+1.3749e+00j],
           [-0.3290+1.1172e+00j, -1.2517+6.2754e-01j],
           [-1.0357-2.3168e-01j, -0.3360+1.9195e-02j],
           [-0.4317-6.3169e-01j,  0.0816-1.5335e-01j]],

          [[ 0.8999+4.8393e-01j, -0.4039+8.9042e-01j],
           [ 0.1793+2.2209e-01j, -0.0104+1.8233e+00j],


In [15]:
x_prepared = sep_real_imag(x)
x_prepared

tensor([[[[-8.3385e-01,  4.7575e-01],
          [-3.0258e-02,  1.2301e-01],
          [ 3.4576e-01, -4.7901e-01],
          [-5.4856e-01,  1.0233e-01]],

         [[-2.5067e-01,  3.0921e-01],
          [-1.1487e+00,  9.7721e-01],
          [-9.3128e-01,  5.3308e-01],
          [-3.4357e-01,  7.4333e-02]],

         [[ 2.1133e-01,  2.0391e-01],
          [ 3.4690e-01, -6.4199e-01],
          [-5.3777e-01, -6.7599e-01],
          [-7.7272e-01,  3.4079e-01]],

         [[-2.2641e-01, -1.1467e+00],
          [-3.2902e-01, -1.2517e+00],
          [-1.0357e+00, -3.3601e-01],
          [-4.3171e-01,  8.1640e-02]],

         [[ 8.9993e-01, -4.0392e-01],
          [ 1.7927e-01, -1.0412e-02],
          [-1.1773e-02, -1.0850e+00],
          [-7.6020e-01, -5.4563e-02]],

         [[ 7.6847e-01, -5.4209e-01],
          [-6.9276e-01, -7.6274e-01],
          [ 1.1837e-01,  5.5798e-02],
          [ 6.7826e-01,  5.7400e-01]],

         [[ 1.2757e+00, -9.8498e-01],
          [-8.8962e-01,  4.0953e-02],


In [8]:
H_list = [torch.randn(5, 5, 3, 2, dtype=torch.cdouble) for _ in range(10)]
Sigma = torch.randn(5, 2, 2, dtype=torch.cdouble)
sum_rate_loss(Sigma, H_list)

tensor(3.5378-0.0980j, dtype=torch.complex128)

In [3]:
import torch

x = torch.tensor(3.0, requires_grad=True)

def f(x):
    z = x**2
    return z

y = f(x)

y.backward()

x.grad

tensor(6.)

In [4]:
len(torch.randn(2, 3, 4))

2