# In [None]:  (notebook cell marker left over from export)
# -----------------------------------------------------------
#  Deep-Unfolded Dual-Link Network with
#   • CNN initialiser (learns Σ(0))
#   • GRU inverse controller  (learns X,Y,α per layer)
#   • XA+Y inverse surrogate
# -----------------------------------------------------------

import torch
import torch.nn as nn
from typing import Dict, List
from torch.utils.data import DataLoader, TensorDataset

# ---------------------------------------------------------------------
#  Problem "setup" helper (antennas, users, power, etc.)
# ---------------------------------------------------------------------
class Setup:
    """Plain container for the static parameters of the MIMO problem.

    The constructor argument ``streams`` is stored under the attribute
    name ``d``; every other argument is stored under its own name.
    """

    def __init__(self,
                 n_tx: int,
                 n_rx: int,
                 K: int,
                 streams: int,
                 P_T: float):
        # Antenna counts: base station, and per user (same for all users).
        self.n_tx, self.n_rx = n_tx, n_rx
        # Number of users, and streams per user (same for all users).
        self.K, self.d = K, streams
        # Total transmit power budget.
        self.P_T = P_T


# ---------------------------------------------------------------------
#  1)  CNN initialiser  H -> Σ(0)
# ---------------------------------------------------------------------
class InitCNN(nn.Module):
    """
    Small initialiser network mapping a channel tensor to starting beamformers.

    Two 1x1x1 3-D convolutions (each followed by ReLU) act on the
    real/imag-stacked channel, the result is averaged over the last two
    dimensions, and one linear layer shared by all users projects the
    pooled features to ``n_tx * d`` values per user.

    Input   : H_real_imag, expected as [B, 2, K, n_rx, n_tx]
    Output  : real-valued tensor viewed as [B, K, n_tx, d]
    """

    def __init__(self, setup: Setup, hidden=32):
        super().__init__()
        self.setup = setup
        in_channels = 2                                   # real + imag planes
        self.conv1 = nn.Conv3d(in_channels, hidden, kernel_size=1)
        self.conv2 = nn.Conv3d(hidden, hidden, kernel_size=1)
        self.relu = nn.ReLU()
        # Per-user projection to a flattened (n_tx x d) beamformer.
        self.out = nn.Linear(hidden, setup.n_tx * setup.d)

    def forward(self, H_real_imag):
        """Return the initial beamformers for a batch of stacked channels."""
        cfg = self.setup
        feats = H_real_imag
        for conv in (self.conv1, self.conv2):
            feats = self.relu(conv(feats))                # [B, hidden, K, n_rx, n_tx]
        pooled = feats.mean(dim=[3, 4])                   # [B, hidden, K]
        per_user = self.out(pooled.transpose(1, 2))       # [B, K, n_tx*d]
        return per_user.view(-1, cfg.K, cfg.n_tx, cfg.d)


# ---------------------------------------------------------------------
#  2) GRU-based inverse controller   (one cell per layer; the hidden state is carried across layers)
# ---------------------------------------------------------------------
class GRUInverseCell(nn.Module):
    """
    Recurrent controller that emits the diagonals used by the XA+Y surrogate.

    Inputs  : A_diag  real tensor [B, K, n_rx]  (matrix diagonals, per user)
              h       hidden state [B, hidden]
    Outputs : X_d     [B, K, n_rx]   diagonal of X
              Y_d     [B, K, n_rx]   diagonal of Y
              alpha   [B, 1, 1, 1]   sigmoid output in (0, 1), used for damping
              h_next  [B, hidden]    updated hidden state
    """

    def __init__(self, setup: Setup, hidden=64):
        super().__init__()
        self.setup = setup
        self.hidden = hidden
        n_feat = setup.K * setup.n_rx          # flattened diagonal length
        self.gru = nn.GRUCell(n_feat, hidden)
        self.to_xy = nn.Linear(hidden, 2 * n_feat)
        self.to_a = nn.Linear(hidden, 1)

    def forward(self, A_diag, h):
        """Advance the GRU by one step and decode X, Y and alpha."""
        batch, users, dim = A_diag.shape
        h_next = self.gru(A_diag.reshape(batch, -1), h)

        # One linear head yields both diagonals; split them along dim 1.
        X_d, Y_d = self.to_xy(h_next).view(batch, 2, users, dim).unbind(dim=1)

        alpha = torch.sigmoid(self.to_a(h_next)).view(batch, 1, 1, 1)
        return X_d, Y_d, alpha, h_next


# ---------------------------------------------------------------------
#  Helper: normalise power across all users
# ---------------------------------------------------------------------
def project_total_power(V, setup: Setup):
    """
    Rescale each sample so that sum_k tr(V_k V_k^H) equals ``setup.P_T``.

    V     : [..., K, n_tx, d] real or complex beamformers
    setup : object exposing the total power budget as ``P_T``

    Returns a tensor of the same shape and dtype as ``V``.

    tr(V_k V_k^H) is the sum of squared magnitudes of the entries of V_k,
    so the power is accumulated element-wise. Summing every entry of
    V^H V would also add the off-diagonal cross terms, which are not part
    of the trace and can cancel the power to zero or below.
    """
    # (V * conj(V)).real is |V|^2 for both real and complex inputs.
    power = (V * V.conj()).real.sum(dim=(-3, -2, -1), keepdim=True)
    # Guard against an all-zero sample, which would otherwise produce NaNs.
    power = power.clamp_min(torch.finfo(power.dtype).tiny)
    scale = torch.sqrt(setup.P_T / power)
    return V * scale


# ---------------------------------------------------------------------
#  3) One unfolded Dual-Link layer
# ---------------------------------------------------------------------
class DLLayer(nn.Module):
    """
    One unfolded Dual-Link style iteration with a learned inverse surrogate.

    Each call:
      1. builds the per-user matrices A_l = I + sum_{k != l} H_l V_k V_k^H H_l^H,
      2. approximates A_l^{-1} by X A_l + Y, with diagonal X and Y produced
         by the GRU controller,
      3. forms receivers U, a diagonal weight W, and the transmit-side
         matrix B = I + sum_k H_k^H U_k W_k U_k^H H_k,
      4. updates every user's beamformer as V_k = B^{-1} H_k^H U_k W_k,
         applies the learned damping alpha, and renormalises the power.

    The controller is sized for K * n_rx diagonal entries. It is reused
    for B only when n_tx == K * n_rx; otherwise B is inverted exactly via
    ``torch.linalg.solve`` because its diagonal does not fit the controller.
    """

    def __init__(self, setup: Setup, hidden=64):
        super().__init__()
        self.setup = setup
        self.inv_cell = GRUInverseCell(setup, hidden)

    # ---------------- core DL math helpers ----------------------------
    @staticmethod
    def _herm(M):
        """Conjugate-transpose the last two dimensions of ``M``."""
        return M.conj().transpose(-1, -2)

    @staticmethod
    def _affine_inverse(M, x_diag, y_diag):
        """Return diag(x) @ M + diag(y), the XA+Y inverse surrogate."""
        # Left-multiplying by a diagonal matrix scales the rows of M.
        return x_diag.unsqueeze(-1) * M + torch.diag_embed(y_diag)

    def forward(self, V, H, h):
        """
        V: [B,K,n_tx,d]    current beamformers
        H: [B,K,n_rx,n_tx] channel per user (complex)
        h: [B,hidden]      GRU hidden state

        Returns (V_next, h_next) with the same shapes as V and h.
        """
        B, K, n_tx, d = V.shape
        n_rx = self.setup.n_rx
        herm = self._herm

        def eye(n):
            return torch.eye(n, device=H.device, dtype=H.dtype)

        # -----  A_l = I + H_l (sum_{k != l} V_k V_k^H) H_l^H  ------------
        VVh = V @ herm(V)                                   # [B,K,n_tx,n_tx]
        others = VVh.sum(dim=1, keepdim=True) - VVh         # leave-one-out sum
        A_stack = eye(n_rx) + H @ others @ herm(H)          # [B,K,n_rx,n_rx]
        A_diag = A_stack.diagonal(dim1=-2, dim2=-1).real    # GRU works on reals

        # -----  inverse surrogate via GRU cell -------------------------
        X_d, Y_d, alpha, h_next = self.inv_cell(A_diag, h)
        A_inv_hat = self._affine_inverse(A_stack, X_d, Y_d)  # [B,K,n_rx,n_rx]

        # -----  receiver U_k and weight W_k ----------------------------
        HV = H @ V                                          # [B,K,n_rx,d]
        U = A_inv_hat @ HV                                  # [B,K,n_rx,d]
        E = eye(d) - herm(U) @ HV                           # [B,K,d,d]
        # Diagonal surrogate for W: invert only the diagonal of E.
        W = torch.diag_embed(1 / E.diagonal(dim1=-2, dim2=-1))

        # -----  B = I + sum_k H_k^H U_k W_k U_k^H H_k  -----------------
        HhU = herm(H) @ U                                   # [B,K,n_tx,d]
        B_mat = eye(n_tx) + (HhU @ W @ herm(HhU)).sum(dim=1)  # [B,n_tx,n_tx]
        rhs = HhU @ W                                       # [B,K,n_tx,d]

        # -----  V_k = B^{-1} H_k^H U_k W_k  (one solve per user) -------
        if n_tx == self.setup.K * n_rx:
            # The diagonal of B has exactly as many entries as the
            # controller expects, so the same weights can be reused.
            # alpha and the hidden state from this second call are discarded.
            B_diag = B_mat.diagonal(dim1=-2, dim2=-1).real
            Xb_d, Yb_d, _, _ = self.inv_cell(
                B_diag.reshape(B, self.setup.K, n_rx), h)
            B_inv_hat = self._affine_inverse(
                B_mat, Xb_d.reshape(B, n_tx), Yb_d.reshape(B, n_tx))
            V_new = B_inv_hat.unsqueeze(1) @ rhs            # [B,K,n_tx,d]
        else:
            # Controller size does not match n_tx: fall back to an exact solve.
            V_new = torch.linalg.solve(B_mat.unsqueeze(1), rhs)

        # ----- damping + power normalisation ---------------------------
        V = alpha * V_new + (1 - alpha) * V
        V = project_total_power(V, self.setup)
        return V, h_next


# ---------------------------------------------------------------------
#  4) Full deep-unfolded network (L layers)
# ---------------------------------------------------------------------
class DLNet(nn.Module):
    """
    Full unfolded network: CNN initialiser followed by a stack of DL layers.

    Each layer owns its own parameters; the GRU hidden state is threaded
    from one layer to the next.
    """

    def __init__(self, setup: Setup,
                 num_layers: int = 6,
                 hidden: int = 64):
        super().__init__()
        self.setup = setup
        self.init_cnn = InitCNN(setup)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(DLLayer(setup, hidden))
        self.hidden_dim = hidden

    def forward(self, H):
        """
        H : complex tensor [B,K,n_rx,n_tx]

        Returns the beamformers produced by the last layer.
        """
        batch = H.size(0)

        # The CNN consumes real and imaginary parts as two channels.
        H_ri = torch.stack((H.real, H.imag), dim=1)        # [B,2,K,n_rx,n_tx]
        V = project_total_power(self.init_cnn(H_ri).to(H.dtype), self.setup)

        # The hidden state starts at zero and is updated by every layer.
        h = torch.zeros(batch, self.hidden_dim, device=H.device)
        for layer in self.layers:
            V, h = layer(V, H, h)
        return V


# ---------------------------------------------------------------------
#  Loss functions
# ---------------------------------------------------------------------
def sum_rate_loss(H, V, setup: Setup):
    """
    Negative (unweighted) sum-rate in nats, averaged over the batch.

    H: [B,K,n_rx,n_tx]   V: [B,K,n_tx,d]   (same dtype)
    setup is accepted for interface compatibility and is not read.

    For user l the rate is log det(I + N_l^{-1} S_l), where
    S_l = H_l V_l V_l^H H_l^H is the signal covariance and
    N_l = I + sum_{k != l} H_l V_k V_k^H H_l^H is noise plus interference.
    """
    K, n_rx = H.shape[1], H.shape[2]
    eye_rx = torch.eye(n_rx, device=H.device, dtype=H.dtype)

    # HV[b, l, k] = H_l V_k for every (receiver, transmitter) pair.
    HV = H.unsqueeze(2) @ V.unsqueeze(1)                   # [B,K,K,n_rx,d]
    cov = HV @ HV.conj().transpose(-1, -2)                 # [B,K,K,n_rx,n_rx]

    # Select the own-signal term (k == l) and the interference (k != l).
    own = torch.eye(K, device=H.device, dtype=H.dtype)[:, :, None, None]
    signal = (cov * own).sum(dim=2)                        # [B,K,n_rx,n_rx]
    noise = eye_rx + (cov * (1 - own)).sum(dim=2)          # [B,K,n_rx,n_rx]

    sinr = torch.linalg.solve(noise, signal)               # N^{-1} S
    # det(I + N^{-1} S) is real and positive in exact arithmetic, so
    # log|det| is the rate and stays real for complex inputs.
    rate = torch.linalg.slogdet(eye_rx + sinr)[1].sum(dim=1)   # [B]
    return -rate.mean()


def inversion_loss(A_inv, A):
    """
    Mean squared Frobenius norm of the residual ``A_inv @ A - I``.

    A_inv, A : [..., n, n] tensors of the same dtype (real or complex).
    Returns a real scalar tensor.

    The squared modulus is used instead of a plain square so that complex
    residuals give a real, non-negative loss. For real inputs the two are
    identical.
    """
    eye = torch.eye(A.size(-1), device=A.device, dtype=A.dtype)
    residual = A_inv @ A - eye
    return (residual * residual.conj()).real.sum(dim=[-2, -1]).mean()


# ---------------------------------------------------------------------
#  Example training loop
# ---------------------------------------------------------------------
def train(model: DLNet,
          loader: DataLoader,
          setup: Setup,
          epochs_warmup=2,
          epochs_main=8,
          lr=3e-4):
    """
    Train ``model`` with Adam on the negative sum-rate.

    model         : network mapping a complex channel batch to beamformers
    loader        : yields either a channel tensor or a list/tuple whose
                    first element is the channel tensor (as a DataLoader
                    over a TensorDataset does)
    setup         : problem parameters forwarded to the loss
    epochs_warmup : number of initial epochs labelled "warm-up" in the log;
                    the optimisation itself is the same in both phases
    epochs_main   : number of remaining epochs, labelled "main"
    lr            : Adam learning rate

    Prints the loss of the last batch once per epoch (``nan`` if the
    loader produced no batches). Returns None.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    # Batches are moved to wherever the model's parameters live.
    device = next(model.parameters()).device

    for epoch in range(epochs_warmup + epochs_main):
        loss = None
        for batch in loader:
            # A DataLoader over a TensorDataset wraps each batch in a
            # list/tuple; unwrap it before calling tensor methods.
            H_batch = batch[0] if isinstance(batch, (list, tuple)) else batch
            H_batch = H_batch.to(device=device, dtype=torch.complex64)

            opt.zero_grad()
            V_pred = model(H_batch)

            loss = sum_rate_loss(H_batch, V_pred, setup)
            # you can add a small inversion regulariser if unstable
            loss.backward()
            opt.step()

        phase = "warm-up" if epoch < epochs_warmup else "main"
        loss_value = float("nan") if loss is None else loss.item()
        print(f"Epoch {epoch:02d} ({phase})  loss = {loss_value:.4f}")


# ---------------------------------------------------------------------
#  Usage example (dummy data)
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Toy configuration: 8 BS antennas, 4 users with 2 antennas and 2 streams each.
    setup = Setup(n_tx=8, n_rx=2, K=4, streams=2, P_T=1.0)
    B = 128                                    # number of toy channel samples

    # Random complex channels; the real part is drawn first, then the
    # imaginary part, and the sum is scaled by 1/sqrt(3).
    _shape = (B, setup.K, setup.n_rx, setup.n_tx)
    H_fake = (torch.randn(_shape) + 1j * torch.randn(_shape)) / 3**0.5

    dataset = TensorDataset(H_fake)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    net = DLNet(setup, num_layers=6, hidden=64)
    train(net, loader, setup)
