In [2]:
import os
import sys
import json
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [25]:
def alg_dual_link_torch(power_total: float, weight: torch.Tensor, H_links: torch.Tensor,
                  rate_diff: float, Sigma=None, device=None, max_iters: int = 1):
    """
    PyTorch version of the dual-link algorithm.

    Each iteration (1) evaluates the per-link rates for the current forward
    covariances ``Sigma`` with identity noise covariance, (2) derives the
    reverse-link covariances ``Sigma_hat`` and rescales them so that their
    traces sum to ``power_total``, and (3) derives new forward covariances
    from ``Sigma_hat`` and rescales them the same way.

    Args:
        power_total: Total power constraint (sum of traces of all covariances).
        weight: Weight tensor of shape (n_links,).
        H_links: Channel tensor of shape (n_links, n_links, n_rx, n_tx);
            ``H_links[l, k]`` is used as the channel from transmitter k to
            receiver l.
        rate_diff: Convergence threshold in bits. Iteration stops once the
            weighted sum rate improves by less than this amount between two
            consecutive iterations. The test can never fire on the first
            iteration, so with the default ``max_iters=1`` it has no effect.
        Sigma: Optional initial forward covariances, either a sequence of
            (n_tx, n_tx) tensors or a tensor of shape (n_links, n_tx, n_tx).
            The argument is never modified. Defaults to equal-power scaled
            identities whose traces sum to ``power_total``.
        device: Device to run the computation on (defaults to H_links.device).
        max_iters: Maximum number of dual-link iterations (>= 1). The default
            of 1 matches the single pass of the previous implementation.

    Returns:
        sum_rate_list: Python floats, weighted sum rate in bits per iteration.
        final_sum_rate: 0-dim tensor, weighted sum rate (bits) of the last
            evaluated iteration.
        rates: Tensor of shape (n_links,), per-link rates in bits.
        Sigma: Final forward covariances. A list of tensors, or a stacked
            tensor when the ``Sigma`` argument was a tensor. Unless the
            ``rate_diff`` test stopped the loop, these are one update ahead
            of the covariances the reported rates were computed from.
        Sigma_hat: List of final (power-normalized) reverse covariances.

    Raises:
        ValueError: If ``max_iters`` is smaller than 1.
    """
    if max_iters < 1:
        raise ValueError("max_iters must be at least 1")

    if device is None:
        device = H_links.device

    # Move tensors to device
    H_links = H_links.to(device)
    weight = weight.to(device)

    #------------------------------------------------
    # find basic parameters
    #------------------------------------------------
    n_links = H_links.shape[0]
    n_rx, n_tx = H_links.shape[-2], H_links.shape[-1]

    # Follow the channel dtype so matmul never sees mismatched dtypes.
    dtype = H_links.dtype
    eye_rx = torch.eye(n_rx, dtype=dtype, device=device)
    eye_tx = torch.eye(n_tx, dtype=dtype, device=device)

    # Hermitian transposes of every channel block, computed once.
    H_herm = H_links.conj().transpose(-2, -1)

    # Natural-log -> bits conversion factor.
    ln2 = torch.log(torch.tensor(2.0))

    #------------------------------------------------
    # Work on a private list so the caller's Sigma is never overwritten
    # (the input may be part of an autograd graph).
    #------------------------------------------------
    return_as_tensor = isinstance(Sigma, torch.Tensor)
    if Sigma is None:
        # Equal power split: traces sum to power_total.
        Sigma = [power_total / n_tx / n_links * eye_tx for _ in range(n_links)]
    else:
        Sigma = [Sigma[l_link].to(device) for l_link in range(n_links)]

    Sigma_hat = [None] * n_links
    sum_rate_list = []
    prev_sum_rate_bits = -float('inf')

    for _ in range(max_iters):
        # ------------------------------------------------------------------
        # rates for the current Sigma and reverse link Sigma_hat's
        # ------------------------------------------------------------------
        rates = torch.zeros(n_links, device=device)
        sum_rate = torch.tensor(0.0, device=device)
        power_normalizer = torch.tensor(0.0, device=device)

        for l_link in range(n_links):
            # signal + interference + noise covariance at receiver l
            total_Cov_l = eye_rx.clone()
            for k_link in range(n_links):
                total_Cov_l = total_Cov_l + (H_links[l_link, k_link] @ Sigma[k_link] @
                                             H_herm[l_link, k_link])

            # interference + noise only
            Omega_l = (total_Cov_l - H_links[l_link, l_link] @ Sigma[l_link] @
                       H_herm[l_link, l_link])

            # rate in nats: log det(total) - log det(interference + noise)
            rates[l_link] = torch.logdet(total_Cov_l).real - torch.logdet(Omega_l).real
            sum_rate += weight[l_link] * rates[l_link]

            Sigma_hat[l_link] = weight[l_link] * (torch.linalg.inv(Omega_l) -
                                                  torch.linalg.inv(total_Cov_l))
            power_normalizer += torch.trace(Sigma_hat[l_link]).real

        sum_rate_bits = (sum_rate / ln2).item()
        sum_rate_list.append(sum_rate_bits)

        # Normalize Sigma_hat (done before the convergence test so the
        # returned Sigma_hat always satisfies the power constraint).
        Sigma_hat = [S_hat * power_total / power_normalizer for S_hat in Sigma_hat]

        # ------------------------------------------------------------------
        # stop if the rate change is small; prev is -inf on the first pass,
        # so this can only trigger from the second iteration on
        # ------------------------------------------------------------------
        if sum_rate_bits - prev_sum_rate_bits < rate_diff:
            break
        prev_sum_rate_bits = sum_rate_bits

        # ------------------------------------------------------------------
        # calculate forward link Sigma's
        # ------------------------------------------------------------------
        power_normalizer = torch.tensor(0.0, device=device)
        Sigma_next = [None] * n_links

        for l_link in range(n_links):
            # reverse-link signal + interference + noise covariance
            total_Cov_hat_l = eye_tx.clone()
            for k_link in range(n_links):
                total_Cov_hat_l = total_Cov_hat_l + (H_herm[k_link, l_link] @
                                                     Sigma_hat[k_link] @ H_links[k_link, l_link])

            Omega_hat_l = (total_Cov_hat_l - H_herm[l_link, l_link] @
                           Sigma_hat[l_link] @ H_links[l_link, l_link])

            Sigma_next[l_link] = weight[l_link] * (torch.linalg.inv(Omega_hat_l) -
                                                   torch.linalg.inv(total_Cov_hat_l))
            power_normalizer += torch.trace(Sigma_next[l_link]).real

        # Normalize Sigma
        Sigma = [S * power_total / power_normalizer for S in Sigma_next]

    if return_as_tensor:
        # Preserve the container type the caller handed in.
        Sigma = torch.stack(Sigma)

    return (sum_rate_list,
            (sum_rate / ln2),
            (rates / ln2),
            Sigma,
            Sigma_hat)

In [None]:
class ChannelCNN(nn.Module):
    """
    Small CNN that maps stacked real/imaginary channel matrices to one
    complex (n_tx, d) matrix per link.

    The input must have ``2 * L * L`` channels and spatial size
    ``n_rx x n_tx``: both 3x3 convolutions use padding 1, so the spatial size
    is preserved and the first linear layer expects ``64 * n_rx * n_tx``
    features.
    """

    def __init__(self, setup):
        super().__init__()
        # Problem dimensions are read from the setup object.
        self.L, self.n_rx = setup.L, setup.n_rx
        self.n_tx, self.d = setup.n_tx, setup.d

        # Real and imaginary parts of all L*L channel blocks are stacked
        # along the channel axis.
        stacked_channels = 2 * self.L * self.L

        self.conv1 = nn.Conv2d(stacked_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64 * self.n_rx * self.n_tx, 256)
        # Two real numbers (real, imag) per entry of every (n_tx, d) output.
        self.fc2 = nn.Linear(256, 2 * self.L * self.n_tx * self.d)

    def forward(self, x):
        """Return a complex tensor of shape [B, L, n_tx, d] for input batch x."""
        features = torch.relu(self.conv1(x))
        features = torch.relu(self.conv2(features))

        batch = features.size(0)
        hidden = torch.relu(self.fc1(features.view(batch, -1)))

        # [B, L, 2, n_tx, d]: axis 2 separates real and imaginary parts.
        parts = self.fc2(hidden).view(batch, self.L, 2, self.n_tx, self.d)
        real_part, imag_part = parts.unbind(dim=2)
        return torch.complex(real_part, imag_part)  # [B, L, n_tx, d]

    def predict(self, x):
        """Switch to eval mode and run forward without tracking gradients."""
        self.eval()
        with torch.no_grad():
            return self.forward(x)  # [B, L, n_tx, d]
    

class ChannelCNNTrainer():
    """
    Trains a model that proposes initial precoders for a dual-link solver.

    For every batch the model output V is turned into covariances
    ``V V^H``, rescaled to the total power ``setup.PT``, passed through
    ``dual_link_fn`` one sample at a time, and the negative mean sum rate of
    the returned covariances is minimized with Adam.

    Attributes:
        model: The network being trained.
        setup: Object providing ``L`` (number of links) and ``PT`` (total power).
        optimizer: Adam optimizer over the model parameters.
        loss_history: Summed batch loss of each epoch of the latest
            ``train`` call.
    """

    def __init__(self, model: "ChannelCNN", setup, lr=1e-3):
        self.model = model
        self.setup = setup
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.loss_history = []

    def train(self, train_list, dual_link_fn, num_epochs=100, batch_size=2):
        """
        Run the training loop and print the summed batch loss of every epoch.

        Args:
            train_list: List of complex channel tensors, each of shape
                (L, L, n_rx, n_tx).
            dual_link_fn: Callable invoked with the keyword arguments
                ``power_total, weight, H_links, rate_diff, Sigma, device``;
                the fourth element of its return value is used as the final
                covariances of that sample.
            num_epochs: Number of passes over ``train_list``.
            batch_size: Number of samples per optimizer step.

        The per-epoch losses are also stored in ``self.loss_history``.
        """
        self.model.train()
        self.loss_history = []
        num_samples = len(train_list)
        n_links = self.setup.L
        power_total = self.setup.PT

        def V_to_Sigma(V):
            """Convert V matrices [B, L, n_tx, d] to covariances V V^H."""
            return V @ V.conj().transpose(-2, -1)

        def proj_power(Sigma):
            """Rescale each sample so its traces sum to the power budget."""
            # Total power per sample: sum over links of trace(Sigma) -> [B]
            total_power = torch.diagonal(Sigma, dim1=-2, dim2=-1).real.sum(dim=(-2, -1))
            scale = power_total / total_power
            return Sigma * scale.view(-1, 1, 1, 1)

        def sum_rate_loss(Sigma, H_list):
            """Mean over samples of the sum over links of the rate in bits."""
            total = 0
            for H, Sigma_i in zip(H_list, Sigma):
                # Identity with the channel's dtype/device: a real CPU eye
                # breaks matmul when there is no interference term (L == 1)
                # and whenever the channels live on another device.
                identity = torch.eye(H.shape[2], dtype=H.dtype, device=H.device)
                for l in range(H.shape[0]):
                    interference = torch.zeros_like(identity)
                    for k in range(H.shape[0]):
                        if k != l:
                            interference = interference + H[l, k] @ Sigma_i[k] @ H[l, k].conj().T
                    noise_plus_interference = identity + interference
                    signal = H[l, l] @ Sigma_i[l] @ H[l, l].conj().T
                    rate = torch.log2(torch.linalg.det(
                        identity + signal @ torch.linalg.inv(noise_plus_interference)))
                    total = total + rate.real
            return total / len(H_list)

        def sep_real_imag(x):
            """Stack real and imaginary parts of [B, L, L, n_rx, n_tx] along dim 1."""
            B, L, _, n_rx, n_tx = x.shape
            real = x.real.reshape(B, L * L, n_rx, n_tx)
            imag = x.imag.reshape(B, L * L, n_rx, n_tx)
            return torch.cat([real, imag], dim=1)

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            for start in range(0, num_samples, batch_size):
                batch_list = train_list[start:start + batch_size]
                batch_tensor = torch.stack(batch_list)
                batch_tensor_sep = sep_real_imag(batch_tensor)

                self.optimizer.zero_grad()

                # Call the module (not .forward) so registered hooks run.
                V_init = self.model(batch_tensor_sep)
                Sigma_proj = proj_power(V_to_Sigma(V_init))

                # Unit weights, created on the same device as the channels.
                weight = torch.ones(n_links, device=batch_tensor.device)

                Sigma_final_list = []
                for j in range(len(Sigma_proj)):
                    _, _, _, Sigma_final, _ = dual_link_fn(
                        power_total=power_total, weight=weight,
                        H_links=batch_tensor[j], rate_diff=.001,
                        Sigma=Sigma_proj[j], device=None)
                    Sigma_final_list.append(Sigma_final)

                # Maximizing the sum rate == minimizing its negative.
                loss = -1 * sum_rate_loss(Sigma_final_list, batch_list)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

            self.loss_history.append(epoch_loss)
            print(f"Epoch {epoch + 1} | Loss: {epoch_loss:.4f}")

In [None]:
class setup():
    def __init__(self, L, n_tx, n_rx, d, PT):
        self.L = L
        self.n_rx = n_rx
        self.n_tx = n_tx
        self.d = d
        self.PT = PT

In [36]:
# Problem setup: L=3 links, n_tx=2, n_rx=2, d=2, total power PT=100.
set_up = setup(3, 2, 2, 2, 100)

# Seed so the random channels and the network initialisation are
# reproducible on Restart & Run All.
torch.manual_seed(0)

# Training channels, each of shape (L, L, n_rx, n_tx). This must be defined
# in this cell: without it the cell only runs when H_l is left over in the kernel.
H_l = [torch.randn(3, 3, 2, 2, dtype=torch.cfloat) for _ in range(1)]

CHCNN = ChannelCNN(set_up)
tr = ChannelCNNTrainer(CHCNN, set_up, lr=1e-2)
tr.train(train_list=H_l, dual_link_fn=alg_dual_link_torch, num_epochs=1000, batch_size=2)

Epoch 1 | Loss: -6.3336
Epoch 2 | Loss: -8.9640
Epoch 3 | Loss: -11.8240
Epoch 4 | Loss: -12.5363
Epoch 5 | Loss: -13.1393
Epoch 6 | Loss: -13.8673
Epoch 7 | Loss: -14.7271
Epoch 8 | Loss: -14.9656
Epoch 9 | Loss: -15.4253
Epoch 10 | Loss: -15.5965
Epoch 11 | Loss: -15.6489
Epoch 12 | Loss: -15.7284
Epoch 13 | Loss: -15.7740
Epoch 14 | Loss: -15.7921
Epoch 15 | Loss: -15.8212
Epoch 16 | Loss: -15.8641
Epoch 17 | Loss: -15.9011
Epoch 18 | Loss: -15.9208
Epoch 19 | Loss: -15.9272
Epoch 20 | Loss: -15.9269
Epoch 21 | Loss: -15.9248
Epoch 22 | Loss: -15.9232
Epoch 23 | Loss: -15.9224
Epoch 24 | Loss: -15.9211
Epoch 25 | Loss: -15.9184
Epoch 26 | Loss: -15.9158
Epoch 27 | Loss: -15.9160
Epoch 28 | Loss: -15.9198
Epoch 29 | Loss: -15.9251
Epoch 30 | Loss: -15.9286
Epoch 31 | Loss: -15.9287
Epoch 32 | Loss: -15.9269
Epoch 33 | Loss: -15.9262
Epoch 34 | Loss: -15.9283
Epoch 35 | Loss: -15.9322
Epoch 36 | Loss: -15.9357
Epoch 37 | Loss: -15.9380
Epoch 38 | Loss: -15.9392
Epoch 39 | Loss: -15.94