## CTDCRN


In [None]:
import torch
# Report the installed PyTorch build and the CUDA/cuDNN runtime it can see.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
print("Device count:", torch.cuda.device_count())
# Only ask for a device name when CUDA reports itself as available.
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))  # device index 0


PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1
cuDNN version: 90100
Device count: 1
GPU Name: NVIDIA L40S


===== Phase 1: Core Complex Layers for CTDCRN (PyTorch) =====
Implements:
- ComplexConv1dPacked  (Eq. 9–10 style complex 1D conv in a single call)
- ComplexGlobalLayerNormPhasePreserving (phase-preserving complex layer norm; statistics per channel across time)
- ComplexLeakyReLU (separate real/imag activations)
- ComplexLSTM (Eq. 11–12 from the paper)
- Quick sanity test to validate shapes & numerical behavior

Input/Output convention everywhere: (B, 2, C, T)
    dim=1 is [real, imag] stacked; C=channels/feature maps, T=time.

In [None]:


import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- housekeeping: performance-oriented global defaults ---
# NOTE(review): the original header said "LS-class GPU"; presumably this means
# the NVIDIA L40S reported by the environment cell — confirm.
# These assignments change process-wide PyTorch settings for every later cell.
torch.backends.cudnn.benchmark = True          # let cuDNN auto-tune conv algorithms
torch.backends.cuda.matmul.allow_tf32 = True   # allow TF32 for float32 matmuls on CUDA
torch.set_float32_matmul_precision("high")  # PyTorch ≥ 2.0

# If you're on a 30-core CPU, use them for CPU-side ops:
# NOTE(review): the thread count is hard-coded to 30 regardless of the host;
# any exception raised by the call is deliberately ignored (best effort).
try:
    torch.set_num_threads(30)
except Exception:
    pass


# ---------- helpers: pack/unpack complex ----------
def _stack_complex(r: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Pack real and imaginary parts along a new axis inserted at dim=1.

    For two (B, C, T) inputs the result is (B, 2, C, T): index 0 on dim=1
    holds the real part, index 1 the imaginary part.
    """
    real_plane = r.unsqueeze(1)
    imag_plane = i.unsqueeze(1)
    return torch.cat((real_plane, imag_plane), dim=1)

def _split_complex(x: torch.Tensor):
    """Unpack a (B, 2, C, T) tensor into its real and imaginary parts.

    Returns a pair ``(real, imag)``, each of shape (B, C, T) and made
    contiguous. The input must be 4-D with size 2 on dim=1.
    """
    assert x.dim() == 4 and x.size(1) == 2, f"Expected (B,2,C,T), got {tuple(x.shape)}"
    real, imag = x.unbind(dim=1)
    return real.contiguous(), imag.contiguous()


# ---------- activation ----------
class ComplexLeakyReLU(nn.Module):
    """LeakyReLU applied separately to the real and imaginary parts.

    Input and output are both (B, 2, C, T).
    """
    def __init__(self, negative_slope: float = 0.01, inplace: bool = False):
        super().__init__()
        self.relu = nn.LeakyReLU(negative_slope=negative_slope, inplace=inplace)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Activate the real part first, then the imaginary part, and re-pack.
        return _stack_complex(*map(self.relu, _split_complex(x)))


class ComplexGlobalLayerNormPhasePreserving(nn.Module):
    """
    Complex layer normalisation with one shared scale for both parts.

    Statistics are taken along the time axis only, separately for every
    batch item and channel:
      * the real and imaginary parts are each centred with their own mean;
      * a single variance, mean(r_c^2 + i_c^2), is computed and the same
        real factor 1/sqrt(var + eps) multiplies both parts, so the
        angle of each centred complex sample is unchanged by the scaling.
    With ``affine=True`` a per-channel real gain ``gamma`` (shared by both
    parts) and per-channel offsets ``beta_r`` / ``beta_i`` are learned.

    Input/Output: (B, 2, C, T)
    """
    def __init__(self, num_channels: int, eps: float = 1e-8, affine: bool = True):
        super().__init__()
        self.eps = eps
        self.affine = affine
        # Registration order (gamma, beta_r, beta_i) is kept for both the
        # affine and the non-affine case.
        initialisers = {"gamma": torch.ones, "beta_r": torch.zeros, "beta_i": torch.zeros}
        for name, make in initialisers.items():
            value = nn.Parameter(make(num_channels)) if affine else None
            self.register_parameter(name, value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        real, imag = _split_complex(x)  # each (B, C, T)

        # Centre each part over time.
        real_c = real - real.mean(dim=-1, keepdim=True)
        imag_c = imag - imag.mean(dim=-1, keepdim=True)

        # One variance for the complex sample -> one common scale factor.
        power = (real_c.pow(2) + imag_c.pow(2)).mean(dim=-1, keepdim=True)
        scale = torch.rsqrt(power + self.eps)
        real_n = real_c * scale
        imag_n = imag_c * scale

        if not self.affine:
            return _stack_complex(real_n, imag_n)

        gain = self.gamma.view(1, -1, 1)
        real_n = real_n * gain + self.beta_r.view(1, -1, 1)
        imag_n = imag_n * gain + self.beta_i.view(1, -1, 1)
        return _stack_complex(real_n, imag_n)

def _same_padding_1d(kernel_size: int, dilation: int) -> int:
    """Return the per-side padding ``((kernel_size - 1) * dilation) // 2``.

    This keeps the sequence length unchanged for stride 1 and odd kernels.
    """
    dilated_extent = dilation * (kernel_size - 1)
    return dilated_extent // 2


# ---------- complex conv1d (single kernel call, fast) ----------
class ComplexConv1dPacked(nn.Module):
    """
    Complex-valued 1D convolution evaluated with one real ``F.conv1d`` call.

    With real filters Wr, Wi of shape (Cout, Cin, K) and input parts
    xr, xi of shape (B, Cin, T):
      y_r = conv(xr, Wr) - conv(xi, Wi)
      y_i = conv(xi, Wr) + conv(xr, Wi)
    Both equations are obtained at once by convolving the channel-wise
    concatenation [xr; xi] with the block weight
      [[Wr, -Wi],
       [Wi,  Wr]]        # shape (2*Cout, 2*Cin, K)

    When ``padding`` is None it defaults to ``_same_padding_1d(kernel_size,
    dilation)``.

    Input:  (B, 2, Cin, T)
    Output: (B, 2, Cout, T_out)  (T_out == T for stride 1, odd kernels and
            the default padding)
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride: int = 1, padding: Optional[int] = None,
                 dilation: int = 1, bias: bool = True):
        super().__init__()
        self.in_c, self.out_c = in_channels, out_channels
        self.stride, self.dilation = stride, dilation
        # Fall back to length-preserving padding when the caller gives none.
        self.padding = _same_padding_1d(kernel_size, dilation) if padding is None else padding

        # Real and imaginary filter banks.
        filter_shape = (out_channels, in_channels, kernel_size)
        self.Wr = nn.Parameter(torch.empty(filter_shape))
        self.Wi = nn.Parameter(torch.empty(filter_shape))

        # Optional per-output-channel biases for the real and imaginary outputs.
        for bias_name in ("br", "bi"):
            bias_value = nn.Parameter(torch.zeros(out_channels)) if bias else None
            self.register_parameter(bias_name, bias_value)

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform init (leaky_relu, a=0.01), scaled by 1/sqrt(2); zero biases."""
        for weight in (self.Wr, self.Wi):
            nn.init.kaiming_uniform_(weight, a=0.01, nonlinearity='leaky_relu')

        # Each output sums a real-filter and an imaginary-filter term, so both
        # banks are shrunk by sqrt(2) to keep the combined variance at the
        # single-filter level.
        with torch.no_grad():
            for weight in (self.Wr, self.Wi):
                weight.div_(math.sqrt(2))

        if getattr(self, 'br', None) is not None:
            nn.init.zeros_(self.br)
            nn.init.zeros_(self.bi)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        real, imag = _split_complex(x)               # each (B, Cin, T)
        packed_in = torch.cat((real, imag), dim=1)   # (B, 2*Cin, T)

        # Block weight of shape (2*Cout, 2*Cin, K); rebuilt on every call so
        # gradients reach Wr and Wi.
        weight = torch.cat(
            (
                torch.cat((self.Wr, -self.Wi), dim=1),
                torch.cat((self.Wi, self.Wr), dim=1),
            ),
            dim=0,
        )

        # br is added to the real output channels, bi to the imaginary ones.
        bias = None if self.br is None else torch.cat((self.br, self.bi), dim=0)

        packed_out = F.conv1d(
            packed_in, weight, bias,
            stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1
        )  # (B, 2*Cout, T_out)

        out_r, out_i = packed_out.split(self.out_c, dim=1)
        return _stack_complex(out_r, out_i)


# ---------- complex LSTM (two real LSTMs) ----------
class ComplexLSTM(nn.Module):
    """
    Complex LSTM built from two real LSTMs, ``hr`` and ``hi``:
      Lr = hr(Pr) - hi(Pi)
      Li = hr(Pi) + hi(Pr)
    where Pr / Pi are the real / imaginary input sequences.

    I/O: (B, 2, C, T) -> (B, 2, out_channels, T), with
    ``out_channels = hidden_size`` (doubled when ``bidirectional``).
    """
    def __init__(self, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True,
                 bidirectional: bool = False, dropout: float = 0.0):
        super().__init__()
        shared_cfg = dict(num_layers=num_layers, bias=bias, batch_first=True,
                          bidirectional=bidirectional, dropout=dropout)
        self.hr = nn.LSTM(input_size, hidden_size, **shared_cfg)
        self.hi = nn.LSTM(input_size, hidden_size, **shared_cfg)
        self.out_channels = hidden_size * 2 if bidirectional else hidden_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        real, imag = _split_complex(x)                # each (B, C, T)
        seq_r = real.transpose(1, 2).contiguous()     # (B, T, C) == P_r
        seq_i = imag.transpose(1, 2).contiguous()     # (B, T, C) == P_i
        n_items = seq_r.size(0)

        # One call per LSTM: both sequences are stacked along the batch axis.
        # hr sees [Pr, Pi]; hi sees [Pi, Pr] (note the swapped order).
        out_hr, _ = self.hr(torch.cat((seq_r, seq_i), dim=0))   # (2B, T, H)
        out_hi, _ = self.hi(torch.cat((seq_i, seq_r), dim=0))   # (2B, T, H)

        hr_of_r, hr_of_i = out_hr[:n_items], out_hr[n_items:]
        hi_of_i, hi_of_r = out_hi[:n_items], out_hi[n_items:]

        # Combine per the two equations, then return to channel-first layout.
        out_real = (hr_of_r - hi_of_i).transpose(1, 2).contiguous()
        out_imag = (hr_of_i + hi_of_r).transpose(1, 2).contiguous()
        return _stack_complex(out_real, out_imag)


# ---------- quick sanity test ----------
def _sanity_tests(device="cpu"):
    """Shape smoke test for the Phase-1 layers.

    Chains conv -> norm -> activation -> LSTM with a constant channel width
    and asserts that every stage returns (B, 2, C, T).

    Args:
        device: device on which the random input and all modules are placed.
            (Previously this argument was accepted but ignored.)
    """
    # Constant working width C across all blocks.
    C = 64
    T = 128
    B = 4
    expected_shape = (B, 2, C, T)

    # Stand-in for an encoder output with C complex channels.
    x = torch.randn(B, 2, C, T, device=device)

    conv = ComplexConv1dPacked(in_channels=C, out_channels=C, kernel_size=3).to(device)  # C -> C
    norm = ComplexGlobalLayerNormPhasePreserving(num_channels=C).to(device)
    act = ComplexLeakyReLU(negative_slope=0.01).to(device)
    lstm = ComplexLSTM(input_size=C, hidden_size=C, num_layers=1).to(device)             # C -> C

    y1 = conv(x)
    assert y1.shape == expected_shape
    y2 = norm(y1)
    assert y2.shape == expected_shape
    y3 = act(y2)
    assert y3.shape == expected_shape
    y4 = lstm(y3)
    assert y4.shape == expected_shape

    print("OK — channels preserved across all blocks:", y4.shape)


# Run the Phase-1 shape smoke test as soon as this cell executes.
_sanity_tests("cpu")


OK — channels preserved across all blocks: torch.Size([4, 2, 64, 128])


===== Phase 2: Mid-Stack Encoding & Dilated Residual Blocks for CTDCRN (PyTorch) =====
Implements:

CHE — Complex Hierarchical Encoder (Fig. 2): Conv(C→M) → cGLN (phase-preserving) → Conv(M→C)

CDCM — Complex Dilated Convolution Module (Fig. 3): Conv → cLeakyReLU → cGLN → D-Conv(d) → cLeakyReLU → cGLN → Conv + residual

CDCMStack — TasNet-style dilation schedule repeated across N layers

Utility make_dilations — cycles exponential dilations (e.g., 1,2,4,8,16,32,64,128, …)

Input/Output convention everywhere: (B, 2, C, T)
 dim=1 packs [real, imag]; C = complex channels (feature maps), T = time frames. All ops are complex-aware and channel-preserving unless stated.

In [None]:
import torch
import torch.nn as nn

# ---------- CHE (Complex Hierarchical Encoder) ----------
class CHE(nn.Module):
    """
    Complex Hierarchical Encoder:
      ComplexConv(in_ch -> mid_ch) -> complex norm -> ComplexConv(mid_ch -> out_ch)
    There is no activation between the layers.
    Input (B, 2, in_ch, T) -> output (B, 2, out_ch, T) for odd ``ksize``.
    """
    def __init__(self, in_ch: int, mid_ch: int, out_ch: int, ksize: int = 3):
        super().__init__()
        # Expand to the intermediate width ...
        self.conv1 = ComplexConv1dPacked(in_ch, mid_ch, kernel_size=ksize, stride=1)
        # ... normalise there ...
        self.norm = ComplexGlobalLayerNormPhasePreserving(mid_ch)
        # ... and project to the output width.
        self.conv2 = ComplexConv1dPacked(mid_ch, out_ch, kernel_size=ksize, stride=1)

    def forward(self, x):
        expanded = self.conv1(x)
        return self.conv2(self.norm(expanded))



# ---------- CDCM (Complex Dilated Convolution Module) ----------
class CDCM(nn.Module):
    """
    Complex Dilated Convolution Module with a residual connection:
      ComplexConv(C->C) -> LeakyReLU -> norm ->
      dilated ComplexConv(C->C, dilation=d) -> LeakyReLU -> norm ->
      ComplexConv(C->C), then added to the block input.

    The channel count is constant so that ``x + f(x)`` is well defined;
    real and imaginary parts are summed separately by the tensor addition.
    Input/Output: (B, 2, C, T)
    """
    def __init__(self, channels: int, ksize: int = 3, dilation: int = 1, slope: float = 0.01):
        super().__init__()
        width = channels

        # Stage 1: plain (undilated) complex conv.
        self.conv1 = ComplexConv1dPacked(width, width, kernel_size=ksize, stride=1, dilation=1)
        self.norm1 = ComplexGlobalLayerNormPhasePreserving(width)
        self.act1 = ComplexLeakyReLU(negative_slope=slope)

        # Stage 2: dilated complex conv.
        self.dconv = ComplexConv1dPacked(width, width, kernel_size=ksize, stride=1, dilation=dilation)
        self.norm2 = ComplexGlobalLayerNormPhasePreserving(width)
        self.act2 = ComplexLeakyReLU(negative_slope=slope)

        # Stage 3: output conv of the residual branch.
        self.conv2 = ComplexConv1dPacked(width, width, kernel_size=ksize, stride=1, dilation=1)

    def forward(self, x):
        stages = (
            (self.conv1, self.act1, self.norm1),
            (self.dconv, self.act2, self.norm2),
        )
        branch = x
        for conv, act, norm in stages:
            # Activation comes before normalisation in both stages.
            branch = norm(act(conv(branch)))
        return x + self.conv2(branch)


# ---------- Utility: make a dilation schedule like TasNet / paper ----------
def make_dilations(n_layers: int, base_cycle=(1, 2, 4, 8, 16, 32, 64, 128)):
    """
    Repeat a dilation cycle until ``n_layers`` values have been produced.

    Example: n_layers=12 -> [1,2,4,8,16,32,64,128,1,2,4,8]

    Args:
        n_layers: number of dilation values wanted; values <= 0 give ``[]``.
        base_cycle: iterable of dilations to repeat. It is copied into a list
            first, so one-shot iterators are supported.

    Returns:
        A list with exactly ``n_layers`` dilation values.

    Raises:
        ValueError: if ``base_cycle`` is empty while ``n_layers`` > 0
            (the previous loop never terminated in that case).
    """
    from itertools import cycle, islice

    if n_layers <= 0:
        return []
    pattern = list(base_cycle)
    if not pattern:
        raise ValueError("base_cycle must contain at least one dilation")
    return list(islice(cycle(pattern), n_layers))


# ---------- CDCM stack (sequential modules with residual inside each) ----------
class CDCMStack(nn.Module):
    """
    ``n_layers`` CDCM blocks applied one after another.

    Block k uses the k-th value of ``make_dilations(n_layers, base_cycle)``
    as its dilation; each block carries its own residual connection.
    """
    def __init__(self, channels: int, n_layers: int,
                 ksize: int = 3, slope: float = 0.01,
                 base_cycle=(1,2,4,8,16,32,64,128)):
        super().__init__()
        self.blocks = nn.ModuleList()
        for dilation in make_dilations(n_layers, base_cycle):
            self.blocks.append(
                CDCM(channels=channels, ksize=ksize, dilation=dilation, slope=slope)
            )

    def forward(self, x):
        out = x
        for block in self.blocks:
            out = block(out)
        return out


# ---------- Quick sanity aligning with the paper ----------
def _phase2_sanity():
    """Shape smoke test for CHE, a single CDCM and a 6-block CDCMStack.

    Returns a dict with the output shape of each stage as a tuple.
    """
    torch.manual_seed(0)
    B, C, T = 2, 64, 256
    expected = (B, 2, C, T)
    x = torch.randn(B, 2, C, T, device="cpu")

    # CHE: widen to 128 channels internally, come back to C.
    encoder = CHE(in_ch=C, mid_ch=128, out_ch=C, ksize=3)
    encoded = encoder(x)
    assert encoded.shape == expected

    # One CDCM block: constant width, residual inside.
    block = CDCM(channels=C, ksize=3, dilation=4, slope=0.01)
    after_block = block(encoded)
    assert after_block.shape == expected

    # A short stack using the default dilation cycle.
    stack = CDCMStack(channels=C, n_layers=6, ksize=3, slope=0.01)
    after_stack = stack(after_block)
    assert after_stack.shape == expected

    report = {
        "che_out": tuple(encoded.shape),
        "cdcm_out": tuple(after_block.shape),
        "stack_out": tuple(after_stack.shape),
    }
    return report

# Run the Phase-2 shape smoke test and show the shapes it returns.
print("Phase-2 sanity:", _phase2_sanity())


Phase-2 sanity: {'che_out': (2, 2, 64, 256), 'cdcm_out': (2, 2, 64, 256), 'stack_out': (2, 2, 64, 256)}


In [None]:
import math
from typing import Optional, Dict, Any, List
import torch
import torch.nn as nn
import torch.nn.functional as F  # (2) import safety

# Assumes these Phase-1/2 symbols already exist in scope:
# _stack_complex, _split_complex, _same_padding_1d, ComplexLeakyReLU,
# ComplexGlobalLayerNormPhasePreserving, ComplexConv1dPacked,
# ComplexLSTM, CDCM, CDCMStack, CHE.
# (apply_crm_eq13 is not one of them: it is defined below as a static
# method of CTDCRNLink.)

# -------------------------------
# Mask head (independent CRMs)
# -------------------------------
class MaskHeadCRM(nn.Module):
    """
    Complex ratio-mask head: one mask per source from a shared feature map.

    A kernel-size-1 complex conv maps C channels to S*C channels, which are
    then viewed as (S, C). Tanh is applied element-wise, so the real and
    imaginary mask components each lie in [-1, 1].

    Input: (B, 2, C, T) -> Output: (B, 2, S, C, T)
    """
    def __init__(self, channels: int, n_sources: int = 2):
        super().__init__()
        self.S = n_sources
        self.head = ComplexConv1dPacked(channels, channels * n_sources, kernel_size=1, stride=1, dilation=1)
        self.act = nn.Tanh()

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        raw = self.head(feat)                              # (B, 2, S*C, T)
        batch, parts, packed, frames = raw.shape
        per_source = packed // self.S
        # Source-major split of the packed channel axis.
        masks = raw.view(batch, parts, self.S, per_source, frames)
        return self.act(masks)


# -----------------------------------------
# Complex Transposed Conv (packed) utility
# -----------------------------------------
class ComplexConvTranspose1dPacked(nn.Module):
    """
    Complex transposed 1D convolution evaluated with one real
    ``F.conv_transpose1d`` call on the packed input [xr; xi].

    Input: (B, 2, Cin, T) -> Output: (B, 2, Cout, T_out)

    NOTE(review): the block weight is assembled in the same
    [[Wr, -Wi], [Wi, Wr]] arrangement as ComplexConv1dPacked, but
    ``conv_transpose1d`` indexes its weight as (in, out, K). The outputs are
    therefore
      y_r = xr * Wr + xi * Wi
      y_i = xi * Wr - xr * Wi
    i.e. the input is combined with the conjugate kernel (Wr - j*Wi). As Wi
    is a free learned parameter this equals the usual complex product up to
    the sign of Wi; the arrangement is kept as-is so that existing weights
    keep their meaning.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride: int = 1, padding: int = 0, output_padding: int = 0,
                 dilation: int = 1, bias: bool = True):
        super().__init__()
        self.in_c = in_channels
        self.out_c = out_channels
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.dilation = dilation

        # conv_transpose1d expects weights shaped (Cin, Cout, K).
        filter_shape = (in_channels, out_channels, kernel_size)
        self.Wr = nn.Parameter(torch.empty(filter_shape))
        self.Wi = nn.Parameter(torch.empty(filter_shape))

        # Optional per-output-channel biases for the real and imaginary outputs.
        for bias_name in ("br", "bi"):
            bias_value = nn.Parameter(torch.zeros(out_channels)) if bias else None
            self.register_parameter(bias_name, bias_value)

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform init (leaky_relu, a=0.01), scaled by 1/sqrt(2); zero biases."""
        for weight in (self.Wr, self.Wi):
            nn.init.kaiming_uniform_(weight, a=0.01, nonlinearity='leaky_relu')

        # Each output sums a real-filter and an imaginary-filter term, so both
        # banks are shrunk by sqrt(2) to keep the combined variance at the
        # single-filter level.
        with torch.no_grad():
            for weight in (self.Wr, self.Wi):
                weight.div_(math.sqrt(2))

        if getattr(self, 'br', None) is not None:
            nn.init.zeros_(self.br)
            nn.init.zeros_(self.bi)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        real, imag = _split_complex(x)               # each (B, Cin, T)
        packed_in = torch.cat((real, imag), dim=1)   # (B, 2*Cin, T)

        # (2*Cin, 2*Cout, K): rows follow the input parts, columns the outputs.
        weight = torch.cat(
            (
                torch.cat((self.Wr, -self.Wi), dim=1),
                torch.cat((self.Wi, self.Wr), dim=1),
            ),
            dim=0,
        )

        # br is added to the real output channels, bi to the imaginary ones.
        bias = None if self.br is None else torch.cat((self.br, self.bi), dim=0)

        packed_out = F.conv_transpose1d(
            packed_in, weight, bias,
            stride=self.stride, padding=self.padding,
            output_padding=self.output_padding, dilation=self.dilation, groups=1
        )  # (B, 2*Cout, T_out)

        out_r, out_i = packed_out.split(self.out_c, dim=1)
        return _stack_complex(out_r, out_i)


# ---------------------------------------------------
# Decoder mirroring CHE: C -> che_mid -> 1 (ksize)
# ---------------------------------------------------
class ComplexDecoderMirrorCHE(nn.Module):
    """
    Decoder that mirrors CHE with complex transposed convolutions:
      deconv(channels -> che_mid) -> [optional complex norm] -> deconv(che_mid -> 1)
    Both layers use the same odd ``ksize`` with stride 1 and matching
    padding, so the time length is preserved.

    Input: (B, 2, channels, T) -> Output: (B, 2, 1, T)
    """
    def __init__(self, channels: int, che_mid: int, ksize: int = 3, use_norm: bool = True):
        super().__init__()
        assert ksize % 2 == 1, "Use odd ksize to preserve length with stride=1"
        same_pad = _same_padding_1d(ksize, dilation=1)

        self.deconv1 = ComplexConvTranspose1dPacked(
            channels, che_mid, kernel_size=ksize, stride=1, padding=same_pad, dilation=1
        )
        if use_norm:
            self.norm = ComplexGlobalLayerNormPhasePreserving(che_mid)
        else:
            self.norm = None
        self.deconv2 = ComplexConvTranspose1dPacked(
            che_mid, 1, kernel_size=ksize, stride=1, padding=same_pad, dilation=1
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        hidden = self.deconv1(z)
        hidden = hidden if self.norm is None else self.norm(hidden)
        return self.deconv2(hidden)  # (B, 2, 1, T)


# -----------------------------------------
# One separation link (per source)
# -----------------------------------------
class CTDCRNLink(nn.Module):
    """
    One separation branch (one per source):
      CDCM stack A -> complex LSTM -> norm -> CDCM stack B
      -> single-source mask head -> complex mask multiply -> decoder.

    The mask is multiplied with the output of CDCM stack B (the same
    features the mask is predicted from). ``forward`` can additionally
    return a few scalar monitoring statistics.
    """
    def __init__(self, channels: int, che_mid: int, ksize: int,
                 n_cdcm_a: int, n_cdcm_b: int, slope: float, dil_cycle):
        super().__init__()
        width = channels
        self.cdcm_a = CDCMStack(channels=width, n_layers=n_cdcm_a, ksize=ksize,
                                slope=slope, base_cycle=dil_cycle)
        self.clstm = ComplexLSTM(input_size=width, hidden_size=width, num_layers=1,
                                 bidirectional=False)
        self.post_lstm_norm = ComplexGlobalLayerNormPhasePreserving(width)
        self.cdcm_b = CDCMStack(channels=width, n_layers=n_cdcm_b, ksize=ksize,
                                slope=slope, base_cycle=dil_cycle)

        self.mask_head = MaskHeadCRM(channels=width, n_sources=1)
        # Decoder mirrors CHE: same ksize and intermediate width.
        self.decoder = ComplexDecoderMirrorCHE(channels=width, che_mid=che_mid,
                                               ksize=ksize, use_norm=True)

    @torch.no_grad()
    def _debug_metrics(self, masked_feat: torch.Tensor, mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Scalar monitors: mean |mask| and mean squared masked feature, per part.

        masked_feat: (B, 2, C, T); mask: (B, 2, 1, C, T).
        """
        mask_level = mask.abs().mean(dim=(-2, -1))        # (B, 2, 1)
        energy = masked_feat.pow(2).mean(dim=(-2, -1))    # (B, 2)
        real_idx, imag_idx = 0, 1
        return {
            "mask_mag_r": mask_level[:, real_idx, 0].mean(),
            "mask_mag_i": mask_level[:, imag_idx, 0].mean(),
            "predec_energy_r": energy[:, real_idx].mean(),
            "predec_energy_i": energy[:, imag_idx].mean(),
        }

    @staticmethod
    def apply_crm_eq13(feat: torch.Tensor, mask: torch.Tensor):
        """Complex-multiply features by each source's mask.

        feat: (B, 2, C, T); mask: (B, 2, S, C, T).
        Returns a list of S tensors, each (B, 2, C, T), with
          Q_r = z_r * M_r - z_i * M_i
          Q_i = z_r * M_i + z_i * M_r
        """
        feat_r = feat[:, 0].unsqueeze(1)    # (B, 1, C, T), broadcast over S
        feat_i = feat[:, 1].unsqueeze(1)
        mask_r, mask_i = mask[:, 0], mask[:, 1]   # each (B, S, C, T)

        out_r = feat_r * mask_r - feat_i * mask_i
        out_i = feat_r * mask_i + feat_i * mask_r

        packed = torch.stack([out_r, out_i], dim=1)   # (B, 2, S, C, T)
        return list(packed.unbind(dim=2))

    def forward(self, z_enc: torch.Tensor, return_debug: bool = False):
        feats = self.cdcm_a(z_enc)
        feats = self.clstm(feats)
        feats = self.post_lstm_norm(feats)
        feats = self.cdcm_b(feats)

        mask = self.mask_head(feats)                      # (B, 2, 1, C, T)
        masked = self.apply_crm_eq13(feats, mask)[0]      # single source -> (B, 2, C, T)
        decoded = self.decoder(masked)                    # (B, 2, 1, T)

        if not return_debug:
            return decoded
        return decoded, self._debug_metrics(masked, mask)


# ----------------
# CTDCRN (2 links)
# ----------------
class CTDCRN(nn.Module):
    """
    Full separator: one shared CHE encoder followed by ``n_src`` independent
    CTDCRNLink branches, one per source.

    Defaults use 8 CDCM blocks before and 8 after the LSTM in every branch.
    Input: (B, 2, 1, T). Output: a list with one (B, 2, 1, T) tensor per
    source; with ``return_debug=True`` a dict of monitoring scalars averaged
    over the branches is returned as well.
    """
    def __init__(self,
                 channels: int = 64,
                 che_mid: int = 128,
                 ksize: int = 3,
                 n_src: int = 2,
                 n_cdcm_a: int = 8,
                 n_cdcm_b: int = 8,
                 slope: float = 0.01,
                 dil_cycle=(1,2,4,8,16,32,64,128)):
        super().__init__()
        assert n_src == 2, "Paper uses two sources; extend if needed."
        assert ksize % 2 == 1, "Use odd ksize to preserve length"
        self.che = CHE(in_ch=1, mid_ch=che_mid, out_ch=channels, ksize=ksize)
        self.links = nn.ModuleList()
        for _ in range(n_src):
            self.links.append(
                CTDCRNLink(channels, che_mid, ksize, n_cdcm_a, n_cdcm_b, slope, dil_cycle)
            )

    def forward(self, x: torch.Tensor, return_debug: bool = False):
        # x: (B, 2, 1, T) -> shared encoding (B, 2, C, T)
        encoded = self.che(x)

        if not return_debug:
            return [link(encoded) for link in self.links]

        estimates, per_link = [], []
        for link in self.links:
            estimate, metrics = link(encoded, return_debug=True)
            estimates.append(estimate)
            per_link.append(metrics)

        # Average each monitoring scalar over the branches.
        averaged = {
            key: torch.stack([metrics[key] for metrics in per_link]).mean()
            for key in per_link[0]
        }
        return estimates, averaged


# ----------------
# Quick sanity
# ----------------
def _ctdcrn_phase3_sanity():
    """End-to-end shape smoke test for a small CTDCRN (2 + 2 CDCM blocks).

    Returns the two output shapes, the number of trainable parameters and
    the debug scalars converted to Python floats.
    """
    torch.manual_seed(0)
    B, T, C = 2, 512, 64
    x = torch.randn(B, 2, 1, T)

    # Shallow stacks keep the check fast; the class defaults are 8 and 8.
    model = CTDCRN(channels=C, che_mid=128, ksize=3, n_src=2,
                   n_cdcm_a=2, n_cdcm_b=2, slope=0.01)
    outs, dbg = model(x, return_debug=True)

    assert len(outs) == 2
    assert all(o.shape == (B, 2, 1, T) for o in outs)

    n_params = 0
    for param in model.parameters():
        if param.requires_grad:
            n_params += param.numel()

    report = {
        "out0": tuple(outs[0].shape),
        "out1": tuple(outs[1].shape),
        "params": n_params,
        "dbg": {name: float(value) for name, value in dbg.items()},
    }
    return report

# Run the end-to-end Phase-3 smoke test and show its report.
print("Phase-3 sanity:", _ctdcrn_phase3_sanity())


Phase-3 sanity: {'out0': (2, 2, 1, 512), 'out1': (2, 2, 1, 512), 'params': 897924, 'dbg': {'mask_mag_r': 0.7605569362640381, 'mask_mag_i': 0.760753333568573, 'predec_energy_r': 3.345418930053711, 'predec_energy_i': 3.334362030029297}}


In [None]:
# from rml_dataloader import get_dataloaders
# import torch

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# H5_PATH = "/home/bss/data/GOLD_XYZ_OSC.0001_1024.hdf5"
# MODS = ["BPSK", "QPSK"]   # MUST match the .py script
# FRAC = 0.20               # MUST match
# SEED = 1337               # MUST match
# SNR_MIN = 0               # MUST match
# SNR_MAX = 18              # MUST match
# BATCH_SIZE = 32

# train_loader, val_loader, test_loader = get_dataloaders(
#     h5_path=H5_PATH,
#     mods=MODS,
#     frac=FRAC,
#     batch_size=BATCH_SIZE,
#     seed=SEED,
#     snr_min=SNR_MIN,
#     snr_max=SNR_MAX,
#     num_workers=4,
#     pin_memory=True,
# )

# for xb, yb in train_loader:
#     xb = xb.to(device)
#     yb = yb.to(device)
#     # forward pass into Transformer / DPFT model...
#     break


In [None]:
# =========================
# Phase 4: Losses & Metrics
# =========================
import torch
import torch.nn as nn
import itertools
from typing import List, Tuple, Dict

# ---------- helpers ----------
def _to_list(x):
    """Normalise a collection of per-source tensors to a Python list.

    * list / tuple  -> shallow copy as a list;
    * Tensor shaped (B, 2, S, *, T), at least 5-D -> list of S slices taken
      along dim=2, each (B, 2, *, T).
    Anything else raises TypeError.
    """
    if torch.is_tensor(x):
        assert x.dim() >= 5, f"Stacked outputs must be (B,2,S,*,T), got {x.shape}"
        return list(x.unbind(dim=2))
    if isinstance(x, (list, tuple)):
        return list(x)
    raise TypeError("Unsupported outputs/targets container")

def _flatten_wave(x: torch.Tensor) -> torch.Tensor:
    """
    Check that ``x`` is 4-D, i.e. (B, 2, C, T), and return it as-is.
    No reshaping happens here; the loss functions accept a general (C, T).
    """
    rank = x.dim()
    assert rank == 4, f"Expected (B,2,C,T), got {x.shape}"
    return x


# ---------- losses ----------
def complex_mse(pred: torch.Tensor, target: torch.Tensor, reduce: bool=True) -> torch.Tensor:
    """
    Time-domain complex mean squared error.

    The squared errors of the real and imaginary parts are added, then
    averaged over channels and time, giving one value per batch item.
    pred/target: (B, 2, C, T)
    Returns a scalar (batch mean) when ``reduce`` is true, else shape (B,).
    """
    err = pred - target
    per_bin = err.pow(2).sum(dim=1)        # real + imag -> (B, C, T)
    per_item = per_bin.mean(dim=(1, 2))    # average over C, T -> (B,)
    if reduce:
        return per_item.mean()
    return per_item


def si_snr_real(pred_r: torch.Tensor, target_r: torch.Tensor, eps: float=1e-8) -> torch.Tensor:
    """
    Scale-invariant SNR (in dB) for real-valued waveforms.

    Both signals are made zero-mean over time, the estimate is projected
    onto the target, and the energy ratio of projection to residual is
    converted to dB. The result is averaged over channels.
    pred_r/target_r: (B, C, T) -> returns (B,)
    """
    est = pred_r - pred_r.mean(dim=-1, keepdim=True)
    ref = target_r - target_r.mean(dim=-1, keepdim=True)

    # Projection of the estimate onto the reference.
    ref_energy = ref.pow(2).sum(dim=-1, keepdim=True) + eps    # (B, C, 1)
    gain = (est * ref).sum(dim=-1, keepdim=True) / ref_energy
    target_part = gain * ref
    residual = est - target_part

    ratio = (target_part.pow(2).sum(dim=-1) + eps) / (residual.pow(2).sum(dim=-1) + eps)
    per_channel_db = 10 * torch.log10(ratio)                   # (B, C)
    return per_channel_db.mean(dim=1)                          # (B,)


def si_snr_complex(pred: torch.Tensor, target: torch.Tensor, eps: float=1e-8, reduce: bool=True) -> torch.Tensor:
    """
    Complex SI-SNR, defined here as the average of the SI-SNR of the real
    parts and the SI-SNR of the imaginary parts.
    pred/target: (B, 2, C, T)
    Returns the batch mean (scalar) when ``reduce`` is true, else shape (B,).
    """
    real_db = si_snr_real(pred[:, 0], target[:, 0], eps=eps)   # (B,)
    imag_db = si_snr_real(pred[:, 1], target[:, 1], eps=eps)   # (B,)
    per_item = 0.5 * (real_db + imag_db)
    if reduce:
        return per_item.mean()
    return per_item


# ---------- PIT wrapper ----------
def pit_permutation_min(
    preds: List[torch.Tensor],
    targets: List[torch.Tensor],
    pair_loss_fn,  # function(pred, target, reduce=False) -> (B,)
) -> Tuple[torch.Tensor, List[int], Dict[str, torch.Tensor]]:
    """
    Permutation-invariant loss with ONE permutation for the whole batch.

    All S! assignments of predictions to targets are evaluated; the one with
    the lowest batch-mean loss wins. ``best_perm[i]`` is the index of the
    target matched with ``preds[i]``.

    Args:
        preds: S prediction tensors.
        targets: S target tensors (same length as ``preds``).
        pair_loss_fn: callable ``(pred, target, reduce=False)`` returning a
            per-item loss of shape (B,).

    Returns:
        loss_mean (scalar), best_perm (list[int]), extras {'per_batch': (B,)}

    Raises:
        ValueError: if ``targets`` is empty or the two lists differ in length.
    """
    n_src = len(targets)
    if n_src == 0:
        raise ValueError("pit_permutation_min needs at least one target")
    if len(preds) != n_src:
        raise ValueError(
            f"Number of predictions ({len(preds)}) must match number of targets ({n_src})"
        )

    # pair[i, j] = per-item loss between preds[i] and targets[j].
    # Both stacks use dim=0 so the layout is (S, S, B). Stacking the inner
    # list on dim=1 would give (S, B, S) and make pair[i, j] index the batch
    # axis instead of the target axis.
    pair = torch.stack(
        [
            torch.stack([pair_loss_fn(p, t, reduce=False) for t in targets], dim=0)
            for p in preds
        ],
        dim=0,
    )  # (S, S, B)

    best_mean = None
    best_perm = None
    best_per_batch = None

    for perm in itertools.permutations(range(n_src)):
        # Sum the matched-pair losses for this assignment, still per item.
        per_batch = torch.stack(
            [pair[i, j] for i, j in enumerate(perm)], dim=0
        ).sum(dim=0)  # (B,)
        mean_loss = per_batch.mean()

        if (best_mean is None) or (mean_loss < best_mean):
            best_mean = mean_loss
            best_perm = list(perm)
            best_per_batch = per_batch

    return best_mean, best_perm, {"per_batch": best_per_batch}


class CTDCRNLoss(nn.Module):
    """
    Training loss with joint permutation-invariant assignment:
      pair loss = w_mse * MSE + w_sisnr * (-SI-SNR)
    The permutation search runs once on this combined pair loss, so the MSE
    and SI-SNR terms always share the same assignment.

    ``forward`` returns ``(total, logs)``. In ``logs`` the entries "mse",
    "neg_sisnr" and "si_snr_db" are recomputed under the chosen permutation,
    summed over sources and averaged over the batch.
    """
    def __init__(self, w_mse: float = 1.0, w_sisnr: float = 0.0):
        super().__init__()
        self.w_mse = w_mse
        self.w_sisnr = w_sisnr

    def forward(self, outs: List[torch.Tensor], tgts: List[torch.Tensor]) -> Tuple[torch.Tensor, Dict]:
        estimates = [_flatten_wave(o) for o in _to_list(outs)]
        references = [_flatten_wave(t) for t in _to_list(tgts)]

        def pair_total(p, t, reduce=False):
            # Weighted per-item loss of one (prediction, target) pair -> (B,)
            mse_term = complex_mse(p, t, reduce=False)
            sisnr_term = -si_snr_complex(p, t, reduce=False)
            return self.w_mse * mse_term + self.w_sisnr * sisnr_term

        total, best_perm, _ = pit_permutation_min(estimates, references, pair_total)

        # Break the loss into its components under the winning assignment.
        mse_parts, negsi_parts = [], []
        for estimate, target_idx in zip(estimates, best_perm):
            matched = references[target_idx]
            mse_parts.append(complex_mse(estimate, matched, reduce=False))
            negsi_parts.append(-si_snr_complex(estimate, matched, reduce=False))

        mse_loss = torch.stack(mse_parts).sum(dim=0).mean()
        negsi_loss = torch.stack(negsi_parts).sum(dim=0).mean()

        logs = {
            "mse": mse_loss.detach(),
            "neg_sisnr": negsi_loss.detach(),
            "si_snr_db": (-negsi_loss).detach(),  # sign flipped for monitoring
            "total": total.detach(),
            "pit_perm": torch.tensor(best_perm)
        }
        return total, logs


# ---------- metrics ----------
@torch.no_grad()
def corr_phase_invariant_batched(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Magnitude of the normalised complex inner product of x and y, one value
    per batch item. A constant phase rotation of either input leaves the
    result unchanged.

    x, y: 4-D tensors indexed as [batch, lane, channel, time], lane 0 = real
    and lane 1 = imaginary. When x has more than one channel, both inputs are
    averaged over the channel axis first.
    """
    # The collapse decision is taken from x alone and applied to both inputs.
    if x.dim() == 4 and x.size(2) > 1:
        x, y = x.mean(dim=2, keepdim=True), y.mean(dim=2, keepdim=True)

    def _centered(z, lane):
        # Time series of one lane with its time-mean removed -> (B, T)
        series = z[:, lane, 0, :]
        return series - series.mean(dim=-1, keepdim=True)

    xr = _centered(x, 0)
    xi = _centered(x, 1)
    yr = _centered(y, 0)
    yi = _centered(y, 1)

    # Real and imaginary parts of the complex inner product <x, y>.
    inner_re = (xr * yr + xi * yi).sum(dim=-1)
    inner_im = (xi * yr - xr * yi).sum(dim=-1)
    magnitude = torch.sqrt(inner_re.pow(2) + inner_im.pow(2))

    x_norm = torch.sqrt((xr.pow(2) + xi.pow(2)).sum(dim=-1) + eps)
    y_norm = torch.sqrt((yr.pow(2) + yi.pow(2)).sum(dim=-1) + eps)
    return magnitude / (x_norm * y_norm + eps)


# @torch.no_grad()
# def corr_phase_inv_PIT(outs: List[torch.Tensor], targets: List[torch.Tensor]) -> torch.Tensor:
#     """
#     S=2 PIT-aware phase-invariant correlation. Returns scalar mean in [0,1].
#     outs/targets: list of two tensors, each (B,2,1,T) or (B,2,C,T).
#     """
#     assert len(outs) == 2 and len(targets) == 2, "This helper assumes S=2."
#     r00 = corr_phase_invariant_batched(outs[0], targets[0])
#     r11 = corr_phase_invariant_batched(outs[1], targets[1])
#     r01 = corr_phase_invariant_batched(outs[0], targets[1])
#     r10 = corr_phase_invariant_batched(outs[1], targets[0])
#     r_best = torch.maximum(r00 + r11, r01 + r10) * 0.5  # best pairing, averaged
#     return r_best.mean()


# @torch.no_grad()
# def pearson_r_paper(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
#     """
#     Eq. (22) Pearson correlation on REAL waveforms (for final reporting).
#     pred/target: (B,2,1,T) or (B,2,C,T) -> uses real lane only, reduces over C.
#     Returns scalar mean in [-1,1].
#     """
#     # collapse C by mean if present
#     if pred.dim() == 4 and pred.size(2) > 1:
#         pred = pred.mean(dim=2, keepdim=True)
#         target = target.mean(dim=2, keepdim=True)

#     a = pred[:, 0, 0, :]  # real part (B,T)
#     b = target[:, 0, 0, :]
#     a = a - a.mean(dim=-1, keepdim=True)
#     b = b - b.mean(dim=-1, keepdim=True)
#     num = (a * b).sum(dim=-1)
#     den = (a.pow(2).sum(dim=-1).sqrt() * b.pow(2).sum(dim=-1).sqrt() + 1e-8)
#     return (num / den).mean()


def si_snr_improvement(mix: torch.Tensor, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    SI-SNR gain of the estimate over the unprocessed mixture:
    si_snr_complex(pred, target) minus si_snr_complex(mix, target).
    Both terms are computed with si_snr_complex's default arguments.
    """
    separated_score = si_snr_complex(pred, target)
    baseline_score = si_snr_complex(mix, target)
    return separated_score - baseline_score


In [None]:
# import importlib, dataset
# importlib.reload(dataset)
# from dataset import build_dataloader

# data_path = "/home/bss/data/GOLD_XYZ_OSC.0001_1024.hdf5"

# train_loader = build_dataloader(
#     data_path,
#     batch_size=256, num_workers=30, prefetch_factor=8,
#     persistent_workers=True, pin_memory=True,
#     filter_mods=["BPSK","QPSK"],
#     snr_values=list(range(-5,26,5)),    # -5..25 dB at 5 dB steps
#     sir_db_range=(-3.0, 3.0),
#     awgn_snr_db_range=(20,40),
#     normalize_frames=True,
#     return_meta=False,
# )


In [None]:
# import numpy as np

# @torch.no_grad()
# def rho_np(a: np.ndarray, b: np.ndarray) -> float:
#     """Scalar Pearson correlation using paper Eq. (22)."""
#     a = a - a.mean()
#     b = b - b.mean()
#     den = (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12)
#     return float((a @ b) / den) if den > 0 else 0.0

# @torch.no_grad()
# def pearson_rho_batch(est_iq: torch.Tensor, tgt_iq: torch.Tensor) -> float:
#     """
#     Paper Pearson ρ averaged over batch, real/imag separately, then mean.
#     est_iq/tgt_iq: (B,2,T) or (B,2,1,T)
#     """
#     est = est_iq.detach().cpu().numpy()
#     tgt = tgt_iq.detach().cpu().numpy()
#     if est.ndim == 4:  # (B,2,1,T)
#         est = est.squeeze(2)
#         tgt = tgt.squeeze(2)
#     B = est.shape[0]
#     vals = []
#     for i in range(B):
#         r_r = rho_np(est[i, 0], tgt[i, 0])
#         r_i = rho_np(est[i, 1], tgt[i, 1])
#         vals.append(0.5 * (r_r + r_i))
#     return float(np.mean(vals)) if vals else 0.0

# @torch.no_grad()
# def pearson_rho_PIT(outs, tgts) -> float:
#     """Apply PIT to the above rho."""
#     r00 = pearson_rho_batch(outs[0], tgts[0])
#     r11 = pearson_rho_batch(outs[1], tgts[1])
#     r01 = pearson_rho_batch(outs[0], tgts[1])
#     r10 = pearson_rho_batch(outs[1], tgts[0])
#     return max((r00 + r11) / 2, (r01 + r10) / 2)


In [None]:
# @torch.no_grad()
# def project_estimate_to_target(a: torch.Tensor, s: torch.Tensor, eps: float = 1e-8):
#     """
#     Project estimate `a` onto `s` along time axis.
#     a, s: (B, C, T) real tensors. Returns projected a_proj: (B, C, T)
#     """
#     s_pow = (s ** 2).sum(dim=-1, keepdim=True) + eps  # (B, C, 1)
#     dot = (a * s).sum(dim=-1, keepdim=True)  # (B, C, 1)
#     a_proj = (dot / s_pow) * s
#     return a_proj

# @torch.no_grad()
# def pearson_real_proj(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8):
#     """
#     Pearson correlation between real waveforms a,b after projection of a onto b.
#     a,b: (B, C, T) real
#     Returns (B, C) per-sample per-channel correlations.
#     """
#     # center
#     a_mean = a.mean(dim=-1, keepdim=True)
#     b_mean = b.mean(dim=-1, keepdim=True)
#     a0 = a - a_mean
#     b0 = b - b_mean

#     # project a onto b (removes scale)
#     a_proj = project_estimate_to_target(a0, b0, eps=eps)

#     am = a_proj - a_proj.mean(dim=-1, keepdim=True)
#     bm = b0    # already zero-mean
#     cov = (am * bm).sum(dim=-1)  # (B, C)
#     a_std = torch.sqrt((am ** 2).sum(dim=-1) + eps)
#     b_std = torch.sqrt((bm ** 2).sum(dim=-1) + eps)
#     rho = cov / (a_std * b_std + eps)  # (B, C)
#     return rho

# @torch.no_grad()
# def pearson_r_paper_PIT(outs: List[torch.Tensor], targets: List[torch.Tensor]) -> torch.Tensor:
#     """
#     PIT-aware Pearson for Eq.(22) used in paper reporting.
#     outs/targets: list length S (S=2), each tensor (B,2,1,T) or (B,2,C,T).
#     Uses ONLY real lane (index 0), averages over channels and batch, and applies the
#     permutation (PIT) that *maximizes* sum of per-source correlations.
#     Returns scalar mean in [-1,1].
#     """
#     assert len(outs) == len(targets) == 2, "Only S=2 supported here (matches your config)."

#     # helper to get real lane (B,C,T) from (B,2,1,T) or (B,2,C,T)
#     def real_lane(x):
#         if x.dim() == 4 and x.size(2) > 1:
#             x = x.mean(dim=2)  # (B,2,T) -> we want (B,T) per lane; but keep consistent below
#             # after mean, shape is (B,2,T)
#             return x[:, 0, :]  # (B, T)
#         elif x.dim() == 4 and x.size(2) == 1:
#             return x[:, 0, 0, :]  # (B, T)
#         elif x.dim() == 3 and x.size(1) == 2:
#             return x[:, 0, :]     # (B, T)
#         else:
#             raise ValueError(f"Unexpected tensor shape in pearson_r_paper_PIT: {x.shape}")

#     # extract real lanes and ensure shape (B, 1, T) to reuse pearson_real_proj which expects (B,C,T)
#     a0 = real_lane(outs[0]).unsqueeze(1)   # (B,1,T)
#     a1 = real_lane(outs[1]).unsqueeze(1)
#     t0 = real_lane(targets[0]).unsqueeze(1)
#     t1 = real_lane(targets[1]).unsqueeze(1)

#     # compute per-source per-sample rho (B, C) — here C==1
#     rho00 = pearson_real_proj(a0, t0)  # (B,1)
#     rho11 = pearson_real_proj(a1, t1)
#     rho01 = pearson_real_proj(a0, t1)
#     rho10 = pearson_real_proj(a1, t0)

#     # sum per-permutation
#     sum_00_11 = (rho00 + rho11).squeeze(1)  # (B,)
#     sum_01_10 = (rho01 + rho10).squeeze(1)  # (B,)

#     # choose best permutation per sample (max sum)
#     choose_first = (sum_00_11 >= sum_01_10).float().unsqueeze(1)  # (B,1)
#     # selected per-sample per-channel rhos
#     # selected_rho_source0 = where choose_first: rho00 else rho01
#     sel_r0 = choose_first * rho00 + (1.0 - choose_first) * rho01  # (B,1)
#     sel_r1 = choose_first * rho11 + (1.0 - choose_first) * rho10  # (B,1)

#     # now average real+imag? Paper Eq(22) asks real-lane only for r_paper in your code comments.
#     # So we average the two selected source correlations and then mean over batch.
#     per_sample = 0.5 * (sel_r0.squeeze(1) + sel_r1.squeeze(1))  # (B,)
#     return per_sample.mean().item()


In [None]:
import torch

@torch.no_grad()
def _pearson_1d(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Pearson correlation along the last axis; `eps` guards the denominator."""
    a_centered = a - a.mean(dim=-1, keepdim=True)
    b_centered = b - b.mean(dim=-1, keepdim=True)
    covariance = (a_centered * b_centered).sum(dim=-1)
    scale = a_centered.norm(dim=-1) * b_centered.norm(dim=-1) + eps
    return covariance / scale

@torch.no_grad()
def _collapse_channels(x: torch.Tensor) -> torch.Tensor:
    """Average over axis 1, e.g. (B, C, T) -> (B, T)."""
    return torch.mean(x, dim=1)

@torch.no_grad()
def pearson_r_paper_PIT_torch(outs, tgts) -> float:
    """
    Two-source, permutation-aware Pearson score.

    For every (estimate, target) pairing the real-lane and imaginary-lane
    correlations are averaged; the better of the two possible assignments is
    kept per batch item, averaged over the two sources, then over the batch.

    outs/tgts: length-2 sequences of tensors shaped (B,2,T) or (B,2,C,T).
    Returns a Python float.
    """
    assert len(outs) == len(tgts) == 2

    def lanes(x):
        # Split into (real, imag), each (B, T); 4-D inputs are channel-averaged.
        n_dims = x.dim()
        if n_dims == 3:
            return x[:, 0, :], x[:, 1, :]
        if n_dims == 4:
            return _collapse_channels(x[:, 0]), _collapse_channels(x[:, 1])
        raise ValueError(f"bad shape {tuple(x.shape)}")

    est_lanes = [lanes(outs[0]), lanes(outs[1])]
    ref_lanes = [lanes(tgts[0]), lanes(tgts[1])]

    def pair_score(i, j):
        # Mean of real-lane and imag-lane correlation for estimate i vs target j.
        (est_re, est_im), (ref_re, ref_im) = est_lanes[i], ref_lanes[j]
        return 0.5 * (_pearson_1d(est_re, ref_re) + _pearson_1d(est_im, ref_im))

    r00 = pair_score(0, 0)
    r11 = pair_score(1, 1)
    r01 = pair_score(0, 1)
    r10 = pair_score(1, 0)

    identity_sum = r00 + r11
    swapped_sum = r01 + r10
    best = torch.maximum(identity_sum, swapped_sum) * 0.5
    return float(best.mean())


In [None]:
# ===============================
# Phase 5: Training / Evaluation (Deterministic Split)
# ===============================
import os, time, json, random
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import amp
from torch.cuda.amp import GradScaler
from torch.utils.data import Subset, DataLoader
import matplotlib.pyplot as plt

# ---- your model/loss/dataset imports ----
from dataset import GOLDMixAWGNDataset, collate_mix
# CTDCRN, CTDCRNLoss assumed importable or defined

# ---------------- config ----------------
# Defaults; each can be replaced by the MODS / TRAINVAL_SNR / TEST_SNR
# environment variables parsed further down in this cell.
TWO_MODS = ["BPSK", "QPSK"]  # exactly two mods
TRAINVAL_SNR_RANGE = (-5, 25)  # AWGN SNR ~ Uniform(-5, 25) dB
# NOTE(review): TEST_SNR is not passed to the dataset built in this cell
# (base_ds uses TRAINVAL_SNR_RANGE for every split) - confirm intent.
TEST_SNR = 20  # fixed/easy SNR for test


class Cfg:
    """Training-run settings, declared as class-level defaults.

    Every value below lives on the class, not on the instance. An attribute
    only appears in ``vars(cfg)`` after it has been assigned on the instance
    (as the environment overrides and ``cfg.save_dir = ...`` do later).
    """
    root_dir = Path("./ctdcrn_runs")  # parent folder for all experiment dirs
    save_dir = None                   # filled in once an experiment dir is allocated
    epochs = 1
    lr = 1e-3
    weight_decay = 1e-4
    grad_clip = 5.0                   # max grad norm; a falsy value disables clipping
    amp = True                        # mixed precision (only honoured on CUDA)
    compile_model = True   # PyTorch ≥ 2.3
    log_interval = 100     # steps
    n_sources = 2
    batch_size = 256

    # NEW: Paths to deterministic split files
    split_dir = Path(".") # Current directory, or change to where .npy files are
    train_idx_file = "train_indices.npy"
    val_idx_file = "val_indices.npy"
    test_idx_file = "test_indices.npy"


# Single shared settings object used by the rest of this cell.
cfg = Cfg()

# --- ENV overrides ---
def _env_bool(k, default=False):
    """Read env var `k` as a flag; return `default` when it is unset.

    "1", "true", "yes", "y" and "on" (any letter case) mean True; every other
    value, including the empty string, means False.
    """
    raw = os.getenv(k)
    if raw is None:
        return default
    truthy = ("1", "true", "yes", "y", "on")
    return raw.lower() in truthy

mods_env = os.getenv("MODS")
if mods_env:
    # Commas count as whitespace, so "BPSK,QPSK" and "BPSK QPSK" both work.
    TWO_MODS = [tok for tok in mods_env.replace(",", " ").split() if tok]

# Scalar training settings: the env var is the upper-cased attribute name and
# the current cfg value is the fallback when the variable is unset.
for _attr, _cast in (
    ("epochs", int),
    ("batch_size", int),
    ("lr", float),
    ("weight_decay", float),
    ("grad_clip", float),
):
    setattr(cfg, _attr, _cast(os.getenv(_attr.upper(), getattr(cfg, _attr))))
cfg.amp = _env_bool("AMP", cfg.amp)
cfg.compile_model = _env_bool("COMPILE", cfg.compile_model)

# Model hyper-parameters.
MODEL_CHANNELS = int(os.getenv("MODEL_CHANNELS", 32))
MODEL_CHE_MID = int(os.getenv("MODEL_CHE_MID", 64))
MODEL_N_CDCM_A = int(os.getenv("MODEL_N_CDCM_A", 4))
MODEL_N_CDCM_B = int(os.getenv("MODEL_N_CDCM_B", 4))
SLOPE = float(os.getenv("SLOPE", "0.01"))

# TRAINVAL_SNR is "<low>,<high>"; TEST_SNR is a single number.
_trainval_snr_env = os.getenv("TRAINVAL_SNR")
if _trainval_snr_env:
    _snr_lo, _snr_hi = (float(x) for x in _trainval_snr_env.split(","))
    TRAINVAL_SNR_RANGE = (_snr_lo, _snr_hi)
_test_snr_env = os.getenv("TEST_SNR")
if _test_snr_env:
    TEST_SNR = float(_test_snr_env)

SUBSAMPLE_FRAC = float(os.getenv("SUBSAMPLE_FRAC", "1.0"))
SUBSAMPLE_SEED = int(os.getenv("SUBSAMPLE_SEED", "1337"))
BALANCE_MODS = True

# --- Reproducibility ---
def set_seed(seed=1337):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

# Fix the RNG state before anything random happens in this cell (the fallback
# split below shuffles indices with NumPy's global RNG).
set_seed(1337)
# NOTE(review): cudnn.benchmark and TF32 favour speed; PyTorch's reproducibility
# notes say benchmark mode can pick different kernels between runs, so results
# may not be bit-identical despite the fixed seed - confirm this is acceptable.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Device-type string handed to amp.autocast; mixed precision is only enabled on CUDA.
amp_dev = "cuda" if device.type == "cuda" else "cpu"
use_amp = cfg.amp and (device.type == "cuda")

# --------- Experiment Setup ----------
def allocate_exp_dir(root: Path, prefer: str = "exp1") -> Path:
    """Create and return a fresh experiment directory under `root`.

    `root / prefer` is used when it does not exist yet; otherwise the first
    free `exp<N>` with N starting at 2 is taken. `root` is created on demand.
    """
    root.mkdir(parents=True, exist_ok=True)
    candidate = root / prefer
    next_id = 2
    while candidate.exists():
        candidate = root / f"exp{next_id}"
        next_id += 1
    candidate.mkdir(parents=True, exist_ok=True)
    return candidate

cfg.save_dir = allocate_exp_dir(cfg.root_dir, "exp1")

# Persist the complete run configuration.
# vars(cfg) alone only contains attributes assigned on the instance, so the
# class-level defaults on Cfg (n_sources, log_interval, split paths, ...) are
# merged in first. The model hyper-parameters and SNR settings are module-level
# names rather than cfg fields, so they are added explicitly; the
# post-training evaluation cell reads "model", "n_sources", "mods" and
# "test_snr" from this file.
_cfg_snapshot = {k: v for k, v in vars(Cfg).items() if not k.startswith("_")}
_cfg_snapshot.update(vars(cfg))  # instance overrides win over class defaults
_cfg_snapshot.update({
    "mods": TWO_MODS,
    "trainval_snr_range": list(TRAINVAL_SNR_RANGE),
    "test_snr": TEST_SNR,
    "model": {
        "channels": MODEL_CHANNELS,
        "che_mid": MODEL_CHE_MID,
        "ksize": 3,  # matches the literal passed to CTDCRN(...) in this cell
        "n_cdcm_a": MODEL_N_CDCM_A,
        "n_cdcm_b": MODEL_N_CDCM_B,
        "slope": SLOPE,
    },
})

with open(cfg.save_dir / "config.json", "w") as f:
    # default=str turns Path objects into plain strings.
    json.dump(_cfg_snapshot, f, indent=2, default=str)

class TeeLogger:
    """Print each message and append it to a log file (best effort).

    The file is opened in append mode with line buffering. A failing file
    write never interrupts the run: only I/O-type errors are swallowed, so
    KeyboardInterrupt / SystemExit still propagate.
    """

    def __init__(self, path: Path):
        self.path = path
        self.f = open(path, "a", buffering=1, encoding="utf-8")

    def log(self, msg=""):
        """Echo `msg` to stdout and write it, plus a newline, to the file."""
        print(msg)
        try:
            # f-string so non-str messages reach the file exactly as printed.
            self.f.write(f"{msg}\n")
        except (OSError, ValueError):
            # OSError: disk/file-system failure. ValueError: file already
            # closed, or the text cannot be encoded.
            pass

    def close(self):
        """Close the log file; safe to call more than once."""
        try:
            self.f.close()
        except OSError:
            pass

# Run-wide logger: echoes to stdout and appends to train.log in the experiment directory.
LOG = TeeLogger(cfg.save_dir / "train.log")

# ---------------- Metrics ----------------
@torch.no_grad()
def _pearson_1d(a, b, eps=1e-8):
    """Pearson correlation along the last axis; `eps` keeps the denominator non-zero."""
    da = a - a.mean(dim=-1, keepdim=True)
    db = b - b.mean(dim=-1, keepdim=True)
    numerator = (da * db).sum(dim=-1)
    denominator = da.norm(dim=-1) * db.norm(dim=-1) + eps
    return numerator / denominator

@torch.no_grad()
def pearson_r_paper_PIT_torch(outs, tgts):
    """Two-source Pearson score under the better of the two assignments.

    For each estimate/target pairing the real-lane and imag-lane correlations
    are averaged; per batch item the larger of (0->0, 1->1) and (0->1, 1->0)
    is kept. Returns the batch mean, halved to average the two sources, as a
    Python float.
    """
    def lanes(x):
        # 3-D input: lanes are read directly. Otherwise each lane is averaged over axis 1.
        if x.dim() == 3:
            return x[:, 0, :], x[:, 1, :]
        return x[:, 0].mean(1), x[:, 1].mean(1)

    est = (lanes(outs[0]), lanes(outs[1]))
    ref = (lanes(tgts[0]), lanes(tgts[1]))

    def pair_score(i, j):
        (e_re, e_im), (t_re, t_im) = est[i], ref[j]
        return 0.5 * (_pearson_1d(e_re, t_re) + _pearson_1d(e_im, t_im))

    r00 = pair_score(0, 0)
    r11 = pair_score(1, 1)
    r01 = pair_score(0, 1)
    r10 = pair_score(1, 0)

    best_sum = torch.maximum(r00 + r11, r01 + r10)
    return float(best_sum.mean() * 0.5)

# =========================================================
#  DATA LOADING: DETERMINISTIC SPLIT LOGIC (Method 3)
# =========================================================
data_path = "/home/bss/data/GOLD_XYZ_OSC.0001_1024.hdf5"

# The deterministic split is used only when all three index files are present.
has_split_files = all(
    (cfg.split_dir / idx_name).exists()
    for idx_name in (cfg.train_idx_file, cfg.val_idx_file, cfg.test_idx_file)
)

# NOTE: index files hold global HDF5 row numbers, so the dataset must NOT
# filter by modulation itself or the stored indices would no longer line up.
# The .npy files are assumed to already point at the wanted (BPSK/QPSK) rows.
if has_split_files:
    filter_mods_arg = None
    LOG.log("[Setup] Deterministic split found! Loading raw indices and disabling internal filters.")
else:
    filter_mods_arg = TWO_MODS

_base_ds_kwargs = dict(
    h5_path=data_path,
    filter_mods=filter_mods_arg,  # None when global indices are used
    snr_range=(-20, 30),
    sir_db_range=(-3.0, 3.0),
    awgn_snr_db_range=TRAINVAL_SNR_RANGE,
    normalize_frames=True,
    return_meta=False,
)
base_ds = GOLDMixAWGNDataset(**_base_ds_kwargs)

# 1. Deterministic Split (Priority): reuse the index arrays stored on disk.
if has_split_files:
    LOG.log("[Split] Loading pre-computed indices from disk...")
    idx_train = np.load(cfg.split_dir / cfg.train_idx_file)
    idx_val   = np.load(cfg.split_dir / cfg.val_idx_file)
    idx_test  = np.load(cfg.split_dir / cfg.test_idx_file)

# 2. Dynamic/Random Split (Fallback): 80/10/10 shuffle of all dataset items.
else:
    LOG.log("[Split] No index files found. Falling back to RANDOM split (Method 1).")

    # The subsample / balance-mods step used to read the whole Y and Z label
    # arrays from the HDF5 file here and then discard them, without selecting
    # any indices. That dead read is removed; the settings are reported as
    # not applied instead of being silently ignored.
    if SUBSAMPLE_FRAC < 1.0 or BALANCE_MODS:
        LOG.log(
            "[Split] NOTE: SUBSAMPLE_FRAC / BALANCE_MODS are not applied in the "
            "random-split fallback; all dataset items are used."
        )

    N_total = len(base_ds)
    idxs = np.arange(N_total)
    # Uses NumPy's global RNG, so the split depends on the seed set earlier.
    np.random.shuffle(idxs)
    n_train = int(0.8 * N_total)
    n_val = int(0.1 * N_total)

    idx_train = idxs[:n_train]
    idx_val = idxs[n_train:n_train + n_val]
    idx_test = idxs[n_train + n_val:]  # remainder, so rounding leftovers land in test

# Create final Subsets. All three views share base_ds, so the same index space
# applies to train, val and test.
train_ds = Subset(base_ds, idx_train)
val_ds = Subset(base_ds, idx_val)
test_ds = Subset(base_ds, idx_test)

# Define Loaders
def make_loader(ds, batch_size=cfg.batch_size, shuffle=True):
    """Wrap `ds` in a DataLoader using this cell's fixed worker settings.

    The `batch_size` default is read from cfg once, when the function is defined.
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_mix,
        num_workers=4,
        prefetch_factor=4,
        persistent_workers=True,
        pin_memory=True,
    )
    return DataLoader(ds, **loader_kwargs)

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = make_loader(train_ds, shuffle=True)
val_loader = make_loader(val_ds, shuffle=False)
test_loader = make_loader(test_ds, shuffle=False)

# ---------------- Model Init ----------------
model = CTDCRN(
    channels=MODEL_CHANNELS, che_mid=MODEL_CHE_MID, ksize=3,
    n_src=cfg.n_sources, n_cdcm_a=MODEL_N_CDCM_A, n_cdcm_b=MODEL_N_CDCM_B, slope=SLOPE
).to(device)

if cfg.compile_model:
    # Best effort: keep training in eager mode if compilation cannot be set up.
    try:
        model = torch.compile(model)
    except Exception as e:
        LOG.log(f"compile skipped: {e}")

# NOTE: these initial loss weights are overwritten at the start of every epoch
# by the loss schedule in the main loop.
criterion = CTDCRNLoss(w_mse=1.0, w_sisnr=0.5)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
# torch.cuda.amp.GradScaler(...) is deprecated and emits a FutureWarning on
# recent PyTorch; torch.amp.GradScaler takes the device type explicitly.
scaler = amp.GradScaler(amp_dev, enabled=use_amp)

# ---------------- Helper Functions ----------------
def epoch_header_row():
    """Column header line for the per-epoch summary table."""
    header = "Epoch |  LR     |  Train(avg)             |  Val(full-epoch)        |  t_train   t_val"
    return header

def epoch_row(epoch, lr, tr, va):
    """Format one epoch summary line.

    `tr` needs 'loss' and 'time_s' ('mse' / 'neg_sisnr' default to 0);
    `va` needs 'val_loss', 'val_r_paper' and 'time_s'.
    """
    lead_cols = f"{epoch:>5d} | {lr:7.1e} |"
    train_cols = f"  {tr['loss']:8.4f}  {tr.get('mse',0):7.4f} {tr.get('neg_sisnr',0):6.4f} |"
    val_cols = f"  {va['val_loss']:8.4f}  {va['val_r_paper']:8.4f} |"
    time_cols = f"  {tr['time_s']:7.1f}s {va['time_s']:6.1f}s"
    return "".join((lead_cols, train_cols, val_cols, time_cols))

def move_batch(xb, yb):
    """Move the input tensor and every target tensor to the module-level `device`."""
    xb_dev = xb.to(device, non_blocking=True)
    yb_dev = []
    for target in yb:
        yb_dev.append(target.to(device, non_blocking=True))
    return xb_dev, yb_dev

def save_ckpt(path, epoch, model, opt, scaler, best):
    """Write a training checkpoint to `path`, creating parent directories.

    The file holds the epoch number, model weights, optimizer state, grad-scaler
    state and the `best` bookkeeping dict.

    A torch.compile-wrapped model keeps the real module under `_orig_mod`, and
    its state_dict keys carry an "_orig_mod." prefix. The wrapped module's
    weights are saved instead, so the checkpoint loads into a plain model
    without any key rewriting.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    raw_model = getattr(model, "_orig_mod", model)
    payload = {
        "epoch": epoch,
        "model": raw_model.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "best": best,
    }
    torch.save(payload, str(path))

# ---------------- Train Loop ----------------
def train_one_epoch(epoch):
    """Run one optimisation pass over `train_loader`.

    Uses the module-level model, criterion, optimizer and scaler. Returns a
    dict with the per-step averages of "loss", "mse" and "neg_sisnr" plus the
    wall-clock "time_s" of the epoch.
    """
    model.train()
    t0 = time.time()
    logs_accum = {"loss": 0.0, "mse": 0.0, "neg_sisnr": 0.0}

    # Pre-bind `step` so the averaging below does not raise NameError when the
    # loader yields no batches.
    step = 0
    for step, (xb_in, yb_in) in enumerate(train_loader, start=1):
        xb, yb = move_batch(xb_in, yb_in)

        optimizer.zero_grad(set_to_none=True)
        with amp.autocast(device_type=amp_dev, enabled=use_amp):
            outs = model(xb)
            loss, logs = criterion(outs, yb)

        scaler.scale(loss).backward()
        if cfg.grad_clip:
            # Gradients must be unscaled before clipping so the threshold
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        logs_accum["loss"] += float(logs["total"])
        logs_accum["mse"] += float(logs["mse"])
        logs_accum["neg_sisnr"] += float(logs["neg_sisnr"])

        if step % cfg.log_interval == 0:
            LOG.log(f"  ep{epoch:02d} step {step:05d} avg_loss={logs_accum['loss']/step:.4f}")

    n = max(1, step)
    for k in logs_accum:
        logs_accum[k] /= n
    logs_accum["time_s"] = time.time() - t0
    return logs_accum

@torch.no_grad()
def evaluate(loader):
    """Score the module-level model on `loader` without gradient tracking.

    Returns "val_loss" and "val_r_paper" as plain per-batch averages (every
    batch weighs the same regardless of its size) and the elapsed "time_s".
    """
    model.eval()
    started = time.time()
    running_loss = 0.0
    running_r = 0.0
    n_batches = 0
    for n_batches, (xb_raw, yb_raw) in enumerate(loader, start=1):
        xb, yb = move_batch(xb_raw, yb_raw)
        with amp.autocast(device_type=amp_dev, enabled=use_amp):
            outs = model(xb)
            batch_loss, _ = criterion(outs, yb)
        running_loss += float(batch_loss)
        running_r += pearson_r_paper_PIT_torch(outs, yb)
    denom = max(1, n_batches)  # avoid dividing by zero on an empty loader
    return {
        "val_loss": running_loss / denom,
        "val_r_paper": running_r / denom,
        "time_s": time.time() - started,
    }

# ---------------- Main Execution ----------------
# Report split sizes, then print the table header for the per-epoch rows.
LOG.log(f"\nSplit summary (Loaded from file: {has_split_files}):")
LOG.log(f"  Train: {len(train_ds):,}, Val: {len(val_ds):,}, Test: {len(test_ds):,}")
LOG.log("\n" + epoch_header_row())

# `best` tracks the epoch with the lowest validation loss; `history` keeps one
# entry per epoch and is written to history.json after the loop.
best = {"val_loss": float('inf'), "val_r_paper": float('-inf'), "epoch": 0}
history = {"train_loss": [], "val_loss": [], "val_r_paper": []}

for epoch in range(1, cfg.epochs + 1):
    # Loss Schedule
    # NOTE(review): both branches currently set the same weights (1.0, 1.0),
    # so the epoch threshold has no effect yet - confirm the intended schedule.
    if epoch <= 5:
        criterion.w_mse, criterion.w_sisnr = 1.0, 1.0 # (Or 0.0/1.0 if strictly phase 1)
    else:
        criterion.w_mse, criterion.w_sisnr = 1.0, 1.0

    tr_logs = train_one_epoch(epoch)
    va_logs = evaluate(val_loader)

    LOG.log(epoch_row(epoch, optimizer.param_groups[0]["lr"], tr_logs, va_logs))

    history["train_loss"].append(tr_logs["loss"])
    history["val_loss"].append(va_logs["val_loss"])
    history["val_r_paper"].append(va_logs["val_r_paper"])

    # Checkpoint selection is by validation loss only; val_r_paper is recorded
    # alongside it but does not influence the choice.
    if va_logs["val_loss"] < best["val_loss"]:
        best.update({"val_loss": va_logs["val_loss"], "val_r_paper": va_logs["val_r_paper"], "epoch": epoch})
        save_ckpt(cfg.save_dir / "best.pt", epoch, model, optimizer, scaler, best)

# NOTE(review): test_loader is built earlier in this cell but is not evaluated here.
with open(cfg.save_dir / "history.json", "w") as f: json.dump(history, f, indent=2)
LOG.close()

[GOLDMixAWGNDataset] kept 212,992 / 2,555,904 frames (mods=['BPSK', 'QPSK'], SNR_values=None, SNR_range=(-20, 30))
[balance_mods] Balanced mods ['BPSK', 'QPSK'] to 106496 samples each (total 212992).
[GOLDMixAWGNDataset] kept 212,992 / 2,555,904 frames (mods=['BPSK', 'QPSK'], SNR_values=None, SNR_range=(-20, 30))

Split summary:
  Train samples : 170,393
  Val samples   : 21,299
  Test samples  : 21,300

=== CTDCRN Training Setup ===
 Save dir     : /home/bss/ctdcrn_runs/exp70
 Epochs       : 1
 Batch size   : 256
 Steps/epoch  : 666
 Model params : 378,308

Epoch |   LR     |   Train(avg)               |   Val(full-epoch)        |  t_train   t_val
      |          |    loss      MSE   -SI     |    loss     r_paper      |                
--> Epoch 1: MSE disabled. Optimizing SI-SNR only.
[loss-weights] epoch=1  w_mse=0.000  w_sisnr=1.000


  scaler = GradScaler(enabled=use_amp)



[DATA CHECK] Mean Diff (Mix - S1 - S2): 0.383773
Check normalize_frames=True in dataset.py

  ep01 step 00100/666 avg_train_loss=19.4012  lr=1.00e-03

  ep01 step 00200/666 avg_train_loss=13.1248  lr=1.00e-03

  ep01 step 00300/666 avg_train_loss=10.5100  lr=1.00e-03

  ep01 step 00400/666 avg_train_loss=9.2383  lr=1.00e-03

  ep01 step 00500/666 avg_train_loss=8.4025  lr=1.00e-03

  ep01 step 00600/666 avg_train_loss=7.8730  lr=1.00e-03
  [PIT perms] epoch=1  [0,1]: 319 (47.90%), [1,0]: 347 (52.10%)
    1 | 1.0e-03 |    7.5296  5564.7898 8.4885 |    5.4226    0.0211 |     95.5s   19.0s
  ↳ New best @ epoch 1: val_loss=5.4226, val_r_paper=0.0211 (saved best.pt)

Saved training history to: ctdcrn_runs/exp70/history.json

[Test @ end] loss=1.6717  r_paper=0.0224  time=10.5s


In [None]:
# ===============================
# Phase 6: Post-Training Evaluation (self-contained)
# ===============================
import os, json, numpy as np
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# ---- REQUIREMENTS in scope from earlier phases ----
# - CTDCRN (model class)
# - GOLDMixAWGNDataset, collate_mix  (dataset & collator)
# If these aren't in scope (e.g., new kernel), import them from your project:
# from dataset import GOLDMixAWGNDataset, collate_mix

# ---------------- tiny helpers ----------------
@torch.no_grad()
def _pearson_1d(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Pearson correlation of a and b along the last axis (one value per row)."""
    a0 = a - a.mean(dim=-1, keepdim=True)
    b0 = b - b.mean(dim=-1, keepdim=True)
    cross = (a0 * b0).sum(dim=-1)
    # eps keeps the ratio finite when either input is constant.
    norm_product = a0.norm(dim=-1) * b0.norm(dim=-1) + eps
    return cross / norm_product

@torch.no_grad()
def _si_snr_real(pred_r: torch.Tensor, target_r: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SNR in dB for real signals, averaged over axis 1.

    pred_r/target_r: (B, C, T). Both are mean-centred over time, the estimate
    is projected onto the target, and the projection-to-residual energy ratio
    is converted to dB. Returns shape (B,).
    """
    est = pred_r - pred_r.mean(dim=-1, keepdim=True)
    ref = target_r - target_r.mean(dim=-1, keepdim=True)
    ref_energy = (ref ** 2).sum(dim=-1, keepdim=True) + eps
    gain = (est * ref).sum(dim=-1, keepdim=True) / ref_energy
    target_part = gain * ref
    residual = est - target_part
    ratio = (target_part.pow(2).sum(dim=-1) + eps) / (residual.pow(2).sum(dim=-1) + eps)
    si_db = 10 * torch.log10(ratio)
    return si_db.mean(dim=1)

@torch.no_grad()
def _pearson_lane(pred_lane: torch.Tensor, tgt_lane: torch.Tensor) -> torch.Tensor:
    """Pearson correlation per batch item after flattening all non-batch axes."""
    pred_flat = torch.flatten(pred_lane, start_dim=1)
    tgt_flat = torch.flatten(tgt_lane, start_dim=1)
    return _pearson_1d(pred_flat, tgt_flat)

def _tee_log(path: Path, msg: str):
    """Print `msg` and append it to `path` as one line (best effort).

    A trailing newline is added only when `msg` does not already end with
    one. Any failure while writing is ignored so logging never stops the run.
    """
    print(msg)
    try:
        with open(path, "a", encoding="utf-8") as sink:
            if msg.endswith("\n"):
                sink.write(msg)
            else:
                sink.write(msg + "\n")
    except Exception:
        pass

# ---------------- reload latest experiment ----------------
root = Path("./ctdcrn_runs")
# Only directories named exp<digits> take part. Anything else matching "exp*"
# (e.g. "exp_old") is skipped instead of crashing the integer sort key.
exp_folders = sorted(
    (d for d in root.glob("exp*") if d.is_dir() and d.name[3:].isdecimal()),
    key=lambda d: int(d.name[3:]),
)
if not exp_folders:
    raise FileNotFoundError("No experiment folders found in ./ctdcrn_runs/")
exp_dir = exp_folders[-1]  # highest experiment number = latest
print(f"Using latest experiment directory: {exp_dir}")

ckpt_path   = exp_dir / "best.pt"
config_path = exp_dir / "config.json"
hist_path   = exp_dir / "history.json"
log_path    = exp_dir / "train.log"

print(f"Loading from: {ckpt_path}")
# Context manager so the config file handle is closed right after parsing.
with open(config_path) as _cfg_file:
    cfg_dict = json.load(_cfg_file)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------- rebuild model ----------------
# config.json files written without a "model" section (or without
# "n_sources") would otherwise raise KeyError here. Missing entries are
# resolved from the same MODEL_* / SLOPE environment variables and defaults
# that the training cell uses.
_model_cfg = cfg_dict.get("model")
if _model_cfg is None:
    print("[warn] config.json has no 'model' section; "
          "falling back to MODEL_* env vars / training defaults.")
    _model_cfg = {}

model = CTDCRN(
    channels=int(_model_cfg.get("channels", os.getenv("MODEL_CHANNELS", 32))),
    che_mid=int(_model_cfg.get("che_mid", os.getenv("MODEL_CHE_MID", 64))),
    ksize=int(_model_cfg.get("ksize", 3)),
    n_src=int(cfg_dict.get("n_sources", 2)),
    n_cdcm_a=int(_model_cfg.get("n_cdcm_a", os.getenv("MODEL_N_CDCM_A", 4))),
    n_cdcm_b=int(_model_cfg.get("n_cdcm_b", os.getenv("MODEL_N_CDCM_B", 4))),
    slope=float(_model_cfg.get("slope", os.getenv("SLOPE", "0.01"))),
).to(device)

# ---- Load the checkpoint and normalise its parameter names ----
ck = torch.load(ckpt_path, map_location=device)

# The file is either a bare state_dict or a dict holding it under "model".
state = ck.get("model", ck)

# Weights saved from a torch.compile-wrapped model carry an "_orig_mod."
# prefix (e.g. "_orig_mod.che.conv1.Wr"); drop it so the names match the
# plain model built above.
prefix = "_orig_mod."
if any(name.startswith(prefix) for name in state):
    print("Detected _orig_mod.* keys in checkpoint → stripping prefix for eval.")
    state = {name.removeprefix(prefix): tensor for name, tensor in state.items()}

def _count_elems(sd):
    """Total number of elements across all values of a state-dict-like mapping."""
    total = 0
    for tensor in sd.values():
        total += tensor.numel()
    return total

# Reference set of names/shapes the model expects, used for the report below.
model_state = model.state_dict()
missing, unexpected = model.load_state_dict(state, strict=False)
print("Model loaded ✔ (strict=False)")

# ---- how many tensors lined up by name ----
n_model_keys = len(model_state)
n_state_keys = len(state)
n_missing, n_unexpected = len(missing), len(unexpected)
n_matched_keys = n_model_keys - n_missing

print(f"Tensor keys in model:      {n_model_keys}")
print(f"Tensor keys in checkpoint: {n_state_keys}")
print(f"Matched tensor keys:       {n_matched_keys} / {n_model_keys}")
print(f"Missing keys:              {n_missing}")
print(f"Unexpected keys:           {n_unexpected}")

# ---- how many scalar elements lined up (name AND shape must agree) ----
matched_elems = sum(
    tensor.numel()
    for name, tensor in state.items()
    if name in model_state and model_state[name].shape == tensor.shape
)

total_model_elems = _count_elems(model_state)
if total_model_elems > 0:
    match_frac = matched_elems / total_model_elems
else:
    match_frac = 0.0
print(f"Matched parameter elements: {matched_elems} / {total_model_elems} "
      f"({match_frac:.2%})")

if missing:
    print("Example missing keys:", list(missing)[:5], "...")
if unexpected:
    print("Example unexpected keys:", list(unexpected)[:5], "...")

model.eval()

# ---------------- make test loader (independent) ----------------
# You can override the dataset path via env:  DATA_PATH=/path/to/HDF5
# Modulations and the test SNR come from the saved config, with defaults used
# when the keys are absent.
# NOTE(review): no split indices are applied here - the dataset below is built
# over every frame that passes the mod/SNR filters, not only a held-out test
# subset. Confirm whether evaluation should be limited to the test split used
# during training.
data_path = os.getenv("DATA_PATH", "/home/bss/data/GOLD_XYZ_OSC.0001_1024.hdf5")
mods     = cfg_dict.get("mods", ["BPSK", "QPSK"])
test_snr = cfg_dict.get("test_snr", 20)

test_ds = GOLDMixAWGNDataset(
    h5_path=data_path,
    filter_mods=mods,
    snr_range=(-20, 30),
    sir_db_range=(-3.0, 3.0),
    awgn_snr_db_range=(test_snr, test_snr),  # fixed test SNR
    normalize_frames=True,
    return_meta=False,
)

from torch.utils.data import DataLoader
def make_loader(ds, batch_size=cfg_dict.get("batch_size", 256)):
    """Build a fixed-order DataLoader for evaluation.

    The worker count is derived from the CPU count (two cores kept free,
    clamped to 2..8), with 4 as the fallback if that computation fails.
    """
    try:
        cores = os.cpu_count() or 8
        num_workers = min(8, cores - 2)
        if num_workers < 2:
            num_workers = 2
    except Exception:
        num_workers = 4
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=collate_mix,
        num_workers=num_workers,
        prefetch_factor=4,
        persistent_workers=True,
        pin_memory=True,
        multiprocessing_context="spawn",
    )
    return DataLoader(ds, **loader_kwargs)

# Evaluation batch size is capped at 256 regardless of the training setting.
test_loader = make_loader(test_ds, batch_size=min(256, cfg_dict.get("batch_size", 256)))

# ---------------- enhanced test evaluation (per source × lane table) ----------------
def _metrics_to_row(m):
    """Turn a ``[mse, si_snr, pearson]`` triple into a ``[total, mse, -si-snr, pearson]`` row."""
    mse, sisnr, rho = m
    # "total" follows the training-loss convention used in this file: mse - si_snr
    return [mse - sisnr, mse, -sisnr, rho]


@torch.no_grad()
def evaluate_detailed(model, loader, device):
    """
    Evaluate ``model`` on ``loader`` and break the metrics down per source and lane.

    For each of the two sources and each lane (0 = real, 1 = imag) three metrics
    are computed on every batch: MSE (``F.mse_loss``), SI-SNR (mean of
    ``_si_snr_real``) and Pearson correlation (mean of ``_pearson_lane``).
    Each batch mean is weighted by the batch size, so a smaller final batch
    (the loader may keep it with ``drop_last=False``) contributes in proportion
    to the number of samples it holds rather than counting as a full batch.

    Args:
        model:  callable as ``model(xb)``; returns one tensor per source,
                indexed here as ``outs[s][:, lane]``.
        loader: iterable of ``(xb, yb)`` where ``yb`` holds one target tensor per source.
        device: device the inputs and targets are moved to.

    Returns dict with:
      totals: dict(total, mse, neg_sisnr, pearson)
      table:  shape (rows, cols) with rows in requested order and cols [total, mse, -si-snr, pearson]
      row_names: matching labels for 'table' rows

    An empty loader yields an all-zero table.
    """
    model.eval()
    n_samples = 0

    # Sample-weighted sums over batches:
    # per source (2) × lane (2) × metrics (mse, si_snr, pearson)
    acc = np.zeros((2, 2, 3), dtype=np.float64)

    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = [t.to(device, non_blocking=True) for t in yb]  # list of 2 tensors (B,2,1,T)
        outs = model(xb)  # list of 2 tensors (B,2,1,T)
        bsz = xb.size(0)

        for s in range(2):
            for lane in range(2):  # lane 0 = Real, lane 1 = Imag
                pred = outs[s][:, lane]
                tgt = yb[s][:, lane]
                # Multiply each batch mean by the batch size so the final
                # division by n_samples gives a per-sample average.
                acc[s, lane, 0] += bsz * F.mse_loss(pred, tgt).item()
                # SI-SNR reuses the real-valued helper on each lane separately
                acc[s, lane, 1] += bsz * _si_snr_real(pred, tgt).mean().item()
                acc[s, lane, 2] += bsz * _pearson_lane(pred, tgt).mean().item()

        n_samples += bsz

    acc /= max(1, n_samples)  # average over samples; max() guards an empty loader

    overall = acc.mean(axis=(0, 1))  # averaged across sources & lanes
    src_avg = acc.mean(axis=1)       # [source, metric], averaged across lanes
    lane_avg = acc.mean(axis=0)      # [lane, metric], averaged across sources

    # Table rows (each row has [total, mse, -si-snr, pearson]), in the requested order
    lane_names = ("real", "imag")
    names = ["total"]
    rows = [_metrics_to_row(overall)]
    for s in range(2):
        for lane, lane_name in enumerate(lane_names):
            names.append(f"src{s + 1}.{lane_name}")
            rows.append(_metrics_to_row(acc[s, lane]))
        names.append(f"src{s + 1}.avg")
        rows.append(_metrics_to_row(src_avg[s]))
    for lane, lane_name in enumerate(lane_names):
        names.append(f"{lane_name} avg across srcs")
        rows.append(_metrics_to_row(lane_avg[lane]))

    table = np.array(rows, dtype=np.float64)

    # also return simple totals for quick reference (same values as the "total" row)
    total_loss, mse_total, neg_sisnr_total, rho_total = rows[0]
    totals = {
        "total": total_loss,
        "mse": mse_total,
        "neg_sisnr": neg_sisnr_total,
        "pearson": rho_total,
    }
    return {"totals": totals, "table": table, "row_names": names}

# ---------- run test, print + append to train.log ----------
stats = evaluate_detailed(model, test_loader, device)

# Fixed-width table: a 20-character label column followed by four
# right-aligned 12-character numeric columns.
header = f"{'':<20} {'total':>12} {'mse':>12} {'-si-snr':>12} {'pearson':>12}"
sep = "-" * len(header)
lines = [header, sep]
lines.extend(
    f"{name:<20} {row[0]:12.6f} {row[1]:12.6f} {row[2]:12.6f} {row[3]:12.6f}"
    for name, row in zip(stats["row_names"], stats["table"])
)
pretty = "\n".join(lines)

_tee_log(log_path, "\n" + pretty + "\n")

# Follow the table with a compact one-line summary of the overall numbers
summary = (
    f"Summary → total={stats['totals']['total']:.6f}  "
    f"mse={stats['totals']['mse']:.6f}  "
    f"-si-snr={stats['totals']['neg_sisnr']:.6f}  "
    f"pearson={stats['totals']['pearson']:.6f}"
)
_tee_log(log_path, summary)

print("\nSaved detailed table to train.log ✔")

# ---------------- sample predictions (first batch) ----------------
xb, yb = next(iter(test_loader))
xb = xb.to(device, non_blocking=True)
yb = [t.to(device, non_blocking=True) for t in yb]
# Inference only: without no_grad the forward pass would record an autograd
# graph and keep it alive for a backward pass that never happens.
with torch.no_grad():
    outs = model(xb)

# Waveform GT vs Pred for sample idx=0
idx = 0
fig, axes = plt.subplots(2, 2, figsize=(10, 6))
titles = ["Source-1 Real vs Pred", "Source-1 Imag vs Pred",
          "Source-2 Real vs Pred", "Source-2 Imag vs Pred"]
for s in range(2):
    for c in range(2):  # 0=real, 1=imag
        ax = axes[s, c]
        # [sample, lane, channel 0, all time steps] -> 1-D NumPy array on the host
        gt   = yb[s][idx, c, 0, :].detach().cpu().numpy()
        pred = outs[s][idx, c, 0, :].detach().cpu().numpy()
        ax.plot(gt,   label="GT",   alpha=0.7)
        ax.plot(pred, label="Pred", alpha=0.7)
        ax.set_title(titles[s*2 + c])  # titles are laid out row-major: (source, lane)
        ax.legend(fontsize=8)
fig.tight_layout()
fig.savefig(exp_dir / "waveform_gt_vs_pred.png", dpi=150)
# Close this specific figure rather than whichever one happens to be current
plt.close(fig)
print("Saved waveform_gt_vs_pred.png ✔")

# ---------------- spectrograms: 2x2 grid in a single file ----------------
def _spec(ax, sig_1d, title):
    """
    Draw the spectrogram of a signal tensor on ``ax`` and label the axes.

    Args:
        ax:     matplotlib axes to draw on.
        sig_1d: tensor holding the signal; it is detached, moved to the CPU,
                cast to float and flattened before plotting.
        title:  title shown above the panel.
    """
    host_signal = sig_1d.detach().cpu()
    samples = host_signal.float().view(-1).numpy()
    # 128-sample FFT windows overlapping by 64 samples; Fs=1 means frequencies
    # are expressed relative to the sample rate
    ax.specgram(samples, NFFT=128, Fs=1, noverlap=64, cmap="magma")
    ax.set_title(title, fontsize=10)
    ax.set(xlabel="Time", ylabel="Freq")

# One row per source: ground truth on the left, prediction on the right.
# Only lane 0 (the real part) of channel 0 is shown.
fig, axes = plt.subplots(2, 2, figsize=(10, 6))
for src in range(2):
    _spec(axes[src, 0], yb[src][idx, 0, 0, :], f"GT Src-{src + 1} Real")
    _spec(axes[src, 1], outs[src][idx, 0, 0, :], f"Pred Src-{src + 1} Real")
plt.tight_layout()
plt.savefig(exp_dir / "spectrograms_2x2.png", dpi=150)
plt.close()
print("Saved spectrograms_2x2.png ✔")

# ---------------- training history curves (if present) ----------------
# Step 1: load the history. Only a genuinely missing file is reported as
# "not found"; any other read/parse problem gets its own message.
try:
    with open(hist_path, "r") as f:
        hist = json.load(f)
except FileNotFoundError as e:
    hist = None
    print(f"No history.json found ({e})")
except Exception as e:
    hist = None
    print(f"Could not read history.json ({type(e).__name__}: {e})")

# Step 2: plot. This stays best-effort (a failure here must not abort the
# evaluation), but the message now says what actually went wrong, e.g. a
# missing "val_r_paper" key.
if hist is not None:
    try:
        plt.figure(figsize=(6,4))
        plt.plot(hist["train_loss"], label="Train Loss")
        plt.plot(hist["val_loss"], label="Val Loss")
        plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend()
        plt.title("Train/Val Loss")
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.savefig(exp_dir / "loss_curves_eval.png", dpi=150)
        plt.close()

        plt.figure(figsize=(6,4))
        plt.plot(hist["val_r_paper"], label="Val Pearson ρ")
        # The curve carries a label, so draw the legend that displays it
        plt.xlabel("Epoch"); plt.ylabel("ρ"); plt.legend()
        plt.title("Validation Correlation (PIT)")
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.savefig(exp_dir / "val_corr_curves_eval.png", dpi=150)
        plt.close()

        print("Saved loss_curves_eval.png & val_corr_curves_eval.png ✔")
    except Exception as e:
        # Do not leave a half-drawn figure open after a failure
        plt.close()
        print(f"Could not plot training history ({type(e).__name__}: {e})")


Using latest experiment directory: ctdcrn_runs/exp70
Loading from: ctdcrn_runs/exp70/best.pt
Detected _orig_mod.* keys in checkpoint → stripping prefix for eval.
Model loaded ✔ (strict=False)
Tensor keys in model:      351
Tensor keys in checkpoint: 351
Matched tensor keys:       351 / 351
Missing keys:              0
Unexpected keys:           0
Matched parameter elements: 378308 / 378308 (100.00%)


  ck = torch.load(ckpt_path, map_location=device)


[GOLDMixAWGNDataset] kept 212,992 / 2,555,904 frames (mods=['BPSK', 'QPSK'], SNR_values=None, SNR_range=(-20, 30))

                            total          mse      -si-snr      pearson
------------------------------------------------------------------------
total                 1685.591275  1684.355342     1.235933    -0.003592
src1.real             1827.614910  1826.144195     1.470715     0.634825
src1.imag             1765.252123  1763.849984     1.402139    -0.638009
src1.avg              1796.433517  1794.997090     1.436427    -0.001592
src2.real             1583.651230  1582.499133     1.152097     0.651543
src2.imag             1565.846837  1564.928057     0.918780    -0.662725
src2.avg              1574.749033  1573.713595     1.035438    -0.005591
real avg across srcs  1705.633070  1704.321664     1.311406     0.643184
imag avg across srcs  1665.549480  1664.389020     1.160460    -0.650367

Summary → total=1685.591275  mse=1684.355342  -si-snr=1.235933  pearson=-0.003592

In [None]:
# # =============================================
# # Phase 6: Post-Training Evaluation & Plotting
# # =============================================
# import os, json, torch, numpy as np, matplotlib.pyplot as plt
# from pathlib import Path
# from torch.utils.data import Subset, DataLoader

# # ---------- locate latest experiment ----------
# root = Path("/home/bss/ctdcrn_runs")
# exp_dirs = sorted([d for d in root.iterdir() if d.is_dir() and d.name.startswith("exp")],
#                   key=lambda p: p.stat().st_mtime)
# assert exp_dirs, f"No experiments found in {root}"
# exp = exp_dirs[-1]
# print(f"[✓] Using latest experiment: {exp.name}")

# cfg_path = exp / "config.json"
# ckpt_path = exp / "best.pt"
# hist_path = exp / "history.json"

# cfg = json.load(open(cfg_path))
# print(json.dumps(cfg, indent=2))

# # ---------- rebuild model ----------


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = CTDCRN(
#     channels=cfg["model"]["channels"],
#     che_mid=cfg["model"]["che_mid"],
#     ksize=cfg["model"]["ksize"],
#     n_src=cfg["n_sources"],
#     n_cdcm_a=cfg["model"]["n_cdcm_a"],
#     n_cdcm_b=cfg["model"]["n_cdcm_b"],
#     slope=cfg["model"]["slope"]
# ).to(device)

# ck = torch.load(ckpt_path, map_location=device, weights_only=False)
# model.load_state_dict(ck["model"], strict=False)
# model.eval()
# print(f"[✓] Loaded checkpoint @ epoch {ck.get('epoch')}")

# # ---------- plot training curves ----------
# if hist_path.exists():
#     hist = json.load(open(hist_path))
#     plt.figure(figsize=(6,4))
#     plt.plot(hist["train_loss"], label="Train Loss")
#     plt.plot(hist["val_loss"], label="Val Loss")
#     plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.grid(True)
#     plt.title("Training vs Validation Loss"); plt.tight_layout()
#     plt.savefig(exp / "loss_curve.png", dpi=150)
#     plt.close()

#     plt.figure(figsize=(6,4))
#     plt.plot(hist["val_r_paper"], label="Val Pearson ρ")
#     plt.xlabel("Epoch"); plt.ylabel("ρ"); plt.legend(); plt.grid(True)
#     plt.title("Validation Pearson Correlation"); plt.tight_layout()
#     plt.savefig(exp / "pearson_curve.png", dpi=150)
#     plt.close()
#     print("[✓] Saved loss and Pearson-ρ curves")

# # ---------- build test loader ----------
# data_path = "/home/bss/data/GOLD_XYZ_OSC.0001_1024.hdf5"
# test_ds = GOLDMixAWGNDataset(
#     h5_path=data_path,
#     filter_mods=cfg["mods"],
#     snr_range=(-20,30),
#     sir_db_range=(-3,3),
#     awgn_snr_db_range=(cfg["test_snr"], cfg["test_snr"]),
#     normalize_frames=True,
#     return_meta=False,
# )
# N = len(test_ds)
# test_loader = DataLoader(test_ds, batch_size=16, shuffle=False,
#                          num_workers=4, collate_fn=collate_mix)

# # ---------- metric helpers ----------
# @torch.no_grad()
# def pearson_r_paper_PIT_torch(outs, tgts, eps=1e-8):
#     # identical to your phase-5 version
#     def _pearson_1d(a,b):
#         a=a-a.mean(-1,True); b=b-b.mean(-1,True)
#         return (a*b).sum(-1)/(a.norm(dim=-1)*b.norm(dim=-1)+eps)
#     def _collapse(x): return x.mean(1)
#     def lanes(x):
#         return (_collapse(x[:,0]) if x.ndim==4 else x[:,0,:],
#                 _collapse(x[:,1]) if x.ndim==4 else x[:,1,:])
#     y0r,y0i=lanes(outs[0]); y1r,y1i=lanes(outs[1])
#     s0r,s0i=lanes(tgts[0]); s1r,s1i=lanes(tgts[1])
#     r00=0.5*(_pearson_1d(y0r,s0r)+_pearson_1d(y0i,s0i))
#     r11=0.5*(_pearson_1d(y1r,s1r)+_pearson_1d(y1i,s1i))
#     r01=0.5*(_pearson_1d(y0r,s1r)+_pearson_1d(y0i,s1i))
#     r10=0.5*(_pearson_1d(y1r,s0r)+_pearson_1d(y1i,s0i))
#     best=torch.maximum(r00+r11,r01+r10)*0.5
#     return best.mean().item()

# # ---------- evaluate on test ----------
# mse_fn = complex_mse
# sisnr_fn = si_snr_complex
# crit = CTDCRNLoss(w_mse=1.0, w_sisnr=0.5)

# def eval_test():
#     model.eval()
#     tot, mse_sum, sisnr_sum, rho_sum = 0,0,0,0
#     per_src = [dict(mse=0,sisnr=0,rho=0) for _ in range(2)]
#     n=0
#     with torch.no_grad():
#         for xb, yb in test_loader:
#             xb=[t.to(device) if torch.is_tensor(t) else t for t in xb]
#             xb, yb = xb.to(device), [t.to(device) for t in yb]
#             outs = model(xb)
#             loss,_ = crit(outs, yb)
#             tot += loss.item()
#             # combined metrics
#             mse_sum += mse_fn(torch.cat(outs,1), torch.cat(yb,1)).item()
#             sisnr_sum += -sisnr_fn(torch.cat(outs,1), torch.cat(yb,1)).item()
#             rho_sum += pearson_r_paper_PIT_torch(outs, yb)
#             # per-source
#             for i in range(2):
#                 per_src[i]["mse"] += mse_fn(outs[i], yb[i]).item()
#                 per_src[i]["sisnr"] += -sisnr_fn(outs[i], yb[i]).item()
#                 xr,xi = outs[i][:,0].mean(1), outs[i][:,1].mean(1)
#                 yr,yi = yb[i][:,0].mean(1), yb[i][:,1].mean(1)
#                 per_src[i]["rho"] += float(_pearson_1d(xr,yr).mean())
#             n+=1
#     res = {
#         "test_loss": tot/n,
#         "mse": mse_sum/n,
#         "-si_snr": sisnr_sum/n,
#         "pearson_r": rho_sum/n,
#         "per_src": [{k:v/n for k,v in d.items()} for d in per_src],
#     }
#     return res

# results = eval_test()
# print(json.dumps(results, indent=2))
# json.dump(results, open(exp/"test_metrics.json","w"), indent=2)

# # ---------- visualize one batch ----------
# xb, yb = next(iter(test_loader))
# xb, yb = xb.to(device), [t.to(device) for t in yb]
# outs = model(xb)

# idx = 0  # pick first example
# for s in range(2):
#     gt = yb[s][idx].cpu().numpy()   # (2,1,T)
#     pr = outs[s][idx].detach().cpu().numpy()
#     T = gt.shape[-1]; t = np.arange(T)

#     fig,axs = plt.subplots(2,2,figsize=(10,5))
#     axs[0,0].plot(t, gt[0,0], label="GT-Real"); axs[0,0].plot(t, pr[0,0], label="Pred-Real", alpha=0.7)
#     axs[0,1].plot(t, gt[1,0], label="GT-Imag"); axs[0,1].plot(t, pr[1,0], label="Pred-Imag", alpha=0.7)
#     axs[0,0].legend(); axs[0,1].legend(); axs[0,0].set_title(f"Source {s} Waveforms")

#     # spectrograms
#     from matplotlib.colors import LogNorm
#     fgt, tgt, Sgt = plt.specgram(gt[0,0], NFFT=128, Fs=1.0, noverlap=64, scale='dB')
#     axs[1,0].imshow(Sgt, origin="lower", aspect="auto", cmap="magma",
#                     extent=[tgt.min(), tgt.max(), fgt.min(), fgt.max()])
#     fpr, tpr, Spr = plt.specgram(pr[0,0], NFFT=128, Fs=1.0, noverlap=64, scale='dB')
#     axs[1,1].imshow(Spr, origin="lower", aspect="auto", cmap="magma",
#                     extent=[tpr.min(), tpr.max(), fpr.min(), fpr.max()])
#     axs[1,0].set_title("GT Spectrogram"); axs[1,1].set_title("Pred Spectrogram")

#     plt.tight_layout()
#     plt.savefig(exp / f"source{s}_wave_spectrogram.png", dpi=150)
#     plt.close()
# print("[✓] Saved waveform + spectrogram plots")


In [None]:
# # ---------------- run ----------------
# best = {"val_loss": float("inf"), "epoch": 0}
# best_path = cfg.save_dir / "best.pt"

# for epoch in range(1, cfg.epochs + 1):
#     tr = train_one_epoch(epoch)
#     va = evaluate(val_loader)

#     # headings printed each epoch (more readable)
#     LOG.log("\n" + epoch_header_row())
#     LOG.log(epoch_row(epoch, optimizer.param_groups[0]['lr'], tr, va))

#     # track history
#     history["train_loss"].append(tr["loss"])
#     history["val_loss"].append(va["val_loss"])
#     history["val_r_paper"].append(va["val_r_paper"])

#     # slope reminder if saturated (based on r_paper)
#     if epoch >= 4 and epoch % 2 == 0:
#         recent = history["val_r_paper"][-4:]
#         if len(recent) == 4 and max(recent) - min(recent) < 1e-3:
#             LOG.log("  ⚠️  Validation seems saturated — consider LeakyReLU negative_slope 0.05–0.1.")

#     # checkpoints
#     if va["val_loss"] < best["val_loss"]:
#         best.update({"val_loss": va["val_loss"], "epoch": epoch})
#         save_ckpt(best_path, epoch, model, optimizer, scheduler, scaler, best)
#     save_ckpt(cfg.save_dir / "last.pt", epoch, model, optimizer, scheduler, scaler, best)


In [None]:
# import torch
# from collections import OrderedDict

# def smart_load_ckpt(best_path, model, map_location="cuda", verbose=True):
#     """
#     Loads a checkpoint that may have been saved from:
#       - torch.compile()'d model (keys start with '_orig_mod.')
#       - DataParallel/DistributedDataParallel ('module.' prefix)

#     It:
#       1) loads ckpt
#       2) normalizes keys (strip '_orig_mod.' and/or 'module.')
#       3) loads with strict=False
#       4) prints a summary of missing/unexpected keys
#     Returns: epoch (if present), and IncompatibleKeys.
#     """
#     ck = torch.load(best_path, map_location=map_location)
#     sd = ck.get("model", ck)  # support plain state_dict or dict with "model"

#     def strip_prefix(k):
#         if k.startswith("_orig_mod."):
#             k = k[len("_orig_mod."):]
#         if k.startswith("module."):
#             k = k[len("module."):]
#         return k

#     new_sd = OrderedDict((strip_prefix(k), v) for k, v in sd.items())

#     # Choose correct target module (compiled or not)
#     target = model._orig_mod if hasattr(model, "_orig_mod") else model

#     # Load with strict=False (architecture may differ slightly)
#     incompatible = target.load_state_dict(new_sd, strict=False)

#     if verbose:
#         miss = list(incompatible.missing_keys)
#         unexp = list(incompatible.unexpected_keys)
#         print(f"[smart_load_ckpt] Loaded: {best_path}")
#         print(f"  Missing keys   : {len(miss)}")
#         if miss:
#             print("   (showing a few)", miss[:10])
#         print(f"  Unexpected keys: {len(unexp)}")
#         if unexp:
#             print("   (showing a few)", unexp[:10])

#         # Quick sanity: how many matched?
#         matched = len(new_sd) - len(unexp)
#         print(f"  Matched params : {matched} / {len(new_sd)}")

#     return int(ck.get("epoch", -1)), incompatible


In [None]:
# @torch.no_grad()
# def evaluate_test(loader) -> Dict[str, float]:
#     model.eval()
#     t0 = time.time()
#     loss_sum, r_paper_sum, n_batches = 0.0, 0.0, 0
#     for batch in loader:
#         xb, yb = batch if len(batch) == 2 else (batch[0], batch[1])
#         xb, yb = move_batch(xb, yb)
#         with amp.autocast(device_type=amp_dev, enabled=use_amp):
#             outs = model(xb)
#             loss, _ = criterion(outs, yb)
#         r_paper = pearson_r_paper_PIT_torch(outs, yb)
#         loss_sum    += float(loss)
#         r_paper_sum += float(r_paper)
#         n_batches   += 1
#     n = max(1, n_batches)
#     return {
#         "test_loss": loss_sum / n,
#         "test_r_paper": r_paper_sum / n,
#         "time_s": time.time() - t0
#     }

In [None]:
# # Load best and evaluate
# best_epoch, incompatible = smart_load_ckpt(best_path, model, map_location=device)

# # (Optional) If you compiled the model now, ensure eval on model (compiled handles _orig_mod internally)
# test_stats = evaluate_test(test_loader)
# print(
#     f"\nBEST @ epoch {best_epoch} | TEST | "
#     f"loss={test_stats['test_loss']:.4f}  r_paper={test_stats['test_r_paper']:.4f}  "
#     f"time={test_stats['time_s']:.1f}s"
# )


In [None]:
# best_path= "ctdcrn_runs/exp30/best.pt"
# # ---------------- plots & history save ----------------
# # Save history JSON
# with open(cfg.save_dir / "history.json", "w") as f:
#     json.dump(history, f, indent=2)

# # Plot curves
# fig1 = plt.figure(figsize=(7,4))
# plt.plot(history["train_loss"], label="Train Loss")
# plt.plot(history["val_loss"], label="Val Loss")
# plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Loss Curves")
# plt.legend(); plt.grid(True, alpha=0.3)
# plt.tight_layout(); plt.savefig(cfg.save_dir / "loss_curves.png", dpi=150); plt.close(fig1)

# fig2 = plt.figure(figsize=(7,4))
# plt.plot(history["val_r_paper"], label="Val r_paper (-1..1)")
# plt.xlabel("Epoch"); plt.ylabel("Correlation")
# plt.title("Validation Pearson (PIT)")
# plt.legend(); plt.grid(True, alpha=0.3)
# plt.tight_layout(); plt.savefig(cfg.save_dir / "val_corr_curves.png", dpi=150); plt.close(fig2)

# LOG.log(f"\nSaved plots and history to: {cfg.save_dir}")

# # ---------------- final test using BEST checkpoint ----------------
# @torch.no_grad()
# def evaluate_test(loader) -> Dict[str, float]:
#     model.eval()
#     t0 = time.time()
#     loss_sum, r_paper_sum, n_batches = 0.0, 0.0, 0
#     for batch in loader:
#         xb, yb = batch if len(batch) == 2 else (batch[0], batch[1])
#         xb, yb = move_batch(xb, yb)
#         with amp.autocast(device_type=amp_dev, enabled=use_amp):
#             outs = model(xb)
#             loss, _ = criterion(outs, yb)
#         r_paper = pearson_r_paper_PIT_torch(outs, yb)
#         loss_sum    += float(loss)
#         r_paper_sum += float(r_paper)
#         n_batches   += 1
#     n = max(1, n_batches)
#     return {
#         "test_loss": loss_sum / n,
#         "test_r_paper": r_paper_sum / n,
#         "time_s": time.time() - t0
#     }

# # Load best checkpoint and run on test set
# best_epoch = load_ckpt(best_path, model, map_location=device)
# test_stats = evaluate_test(test_loader)
# LOG.log("\n" + "-"*126)
# LOG.log(f"BEST @ epoch {best_epoch}  |  "
#         f"TEST @ {TEST_SNR} dB  |  loss={test_stats['test_loss']:.4f}  "
#         f"r_paper={test_stats['test_r_paper']:.4f}  time={test_stats['time_s']:.1f}s")
# LOG.log("-"*126)

# # close log file
# LOG.close()


In [None]:
# # ===============================
# # Phase 5: Result Visualization
# # ===============================
# import os, json
# from pathlib import Path
# import torch
# import matplotlib.pyplot as plt
# import numpy as np
# from IPython.display import display

# # ----------------------------
# # 1) Epoch-wise result plots
# # ----------------------------
# def plot_training_curves(save_dir: Path, history_in_mem: dict | None = None, show: bool = True):
#     hist_path = save_dir / "history.json"
#     if history_in_mem is None:
#         with open(hist_path, "r") as f:
#             history = json.load(f)
#     else:
#         history = history_in_mem

#     train_loss = history.get("train_loss", [])
#     val_loss   = history.get("val_loss", [])
#     val_r      = history.get("val_r_paper", [])

#     # Loss curves
#     fig1 = plt.figure(figsize=(7,4))
#     if train_loss: plt.plot(train_loss, label="Train Loss")
#     if val_loss:   plt.plot(val_loss,   label="Val Loss")
#     plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Loss Curves")
#     plt.legend(); plt.grid(True, alpha=0.3); plt.tight_layout()
#     out1 = save_dir / "loss_curves.final.png"
#     plt.savefig(out1, dpi=150)
#     if show: display(fig1)
#     plt.close(fig1)

#     # Val Pearson ρ curve
#     fig2 = plt.figure(figsize=(7,4))
#     if val_r: plt.plot(val_r, label="Val Pearson ρ (PIT)")
#     plt.xlabel("Epoch"); plt.ylabel("ρ (−1..1)"); plt.title("Validation Pearson ρ (PIT)")
#     plt.legend(); plt.grid(True, alpha=0.3); plt.tight_layout()
#     out2 = save_dir / "val_corr_curves.final.png"
#     plt.savefig(out2, dpi=150)
#     if show: display(fig2)
#     plt.close(fig2)

#     print(f"[saved] {out1}\n[saved] {out2}")

# # ---------------------------------------------
# # 2) Predicted vs. Ground Truth (I/Q waveforms)
# # ---------------------------------------------
# @torch.no_grad()
# def _pearson_1d(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
#     a = a - a.mean(dim=-1, keepdim=True)
#     b = b - b.mean(dim=-1, keepdim=True)
#     num = (a * b).sum(dim=-1)
#     den = a.norm(dim=-1) * b.norm(dim=-1) + eps
#     return num / den  # (B,)

# @torch.no_grad()
# def _collapse_channels(x: torch.Tensor) -> torch.Tensor:
#     # (B, C, T) -> (B, T)
#     return x.mean(dim=1)

# @torch.no_grad()
# def _lanes_to_BT(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
#     """
#     x: (B,2,1,T) or (B,2,C,T) or (B,2,T) -> returns (xr, xi) each (B,T)
#     """
#     if x.dim() == 4:   # (B,2,C,T)
#         xr = _collapse_channels(x[:,0])  # (B,T)
#         xi = _collapse_channels(x[:,1])
#     elif x.dim() == 3: # (B,2,T)
#         xr, xi = x[:,0], x[:,1]
#     else:
#         raise ValueError(f"unexpected shape {tuple(x.shape)}")
#     return xr, xi

# @torch.no_grad()
# def _per_sample_best_perm(outs, tgts) -> torch.Tensor:
#     """
#     outs/tgts: lists of length 2, tensors (B,2,1,T) or (B,2,C,T)
#     returns: (B,) boolean tensor:
#         True  -> pairing (0->0, 1->1)
#         False -> pairing (0->1, 1->0)
#     """
#     y0r, y0i = _lanes_to_BT(outs[0]); y1r, y1i = _lanes_to_BT(outs[1])
#     s0r, s0i = _lanes_to_BT(tgts[0]); s1r, s1i = _lanes_to_BT(tgts[1])

#     r00 = 0.5*(_pearson_1d(y0r, s0r) + _pearson_1d(y0i, s0i))  # (B,)
#     r11 = 0.5*(_pearson_1d(y1r, s1r) + _pearson_1d(y1i, s1i))
#     r01 = 0.5*(_pearson_1d(y0r, s1r) + _pearson_1d(y0i, s1i))
#     r10 = 0.5*(_pearson_1d(y1r, s0r) + _pearson_1d(y1i, s0i))

#     sum_diag  = r00 + r11
#     sum_off   = r01 + r10
#     choose_diag = sum_diag >= sum_off   # (B,)
#     return choose_diag

# @torch.no_grad()
# def plot_pred_vs_gt(
#     model,
#     loader,
#     device,
#     save_dir: "Path",
#     pick="best",    # "best" (highest mean ρ) or int index
#     max_time=1024,  # plot first T samples if you want to zoom in
#     show: bool = True
# ):
#     model.eval()
#     batch = next(iter(loader))
#     xb, yb = batch if len(batch) == 2 else (batch[0], batch[1])
#     xb = xb.to(device)
#     yb = [t.to(device) for t in yb]

#     # Forward
#     outs = model(xb)  # list of length 2, each (B,2,1,T) or (B,2,C,T)

#     # Figure out the best pairing per sample
#     choose_diag = _per_sample_best_perm(outs, yb)  # (B,)

#     # Compute per-sample mean ρ for selection
#     def _mean_r_for_pairing(diag: torch.Tensor) -> torch.Tensor:
#         y0r, y0i = _lanes_to_BT(outs[0]); y1r, y1i = _lanes_to_BT(outs[1])
#         s0r, s0i = _lanes_to_BT(yb[0]);   s1r, s1i = _lanes_to_BT(yb[1])
#         r00 = 0.5*(_pearson_1d(y0r, s0r) + _pearson_1d(y0i, s0i))
#         r11 = 0.5*(_pearson_1d(y1r, s1r) + _pearson_1d(y1i, s1i))
#         r01 = 0.5*(_pearson_1d(y0r, s1r) + _pearson_1d(y0i, s1i))
#         r10 = 0.5*(_pearson_1d(y1r, s0r) + _pearson_1d(y1i, s0i))
#         best = torch.where(diag, r00 + r11, r01 + r10) * 0.5  # (B,)
#         return best
#     per_sample_r = _mean_r_for_pairing(choose_diag)  # (B,)

#     # Choose which item to plot
#     if pick == "best":
#         b = int(torch.argmax(per_sample_r).item())
#     elif isinstance(pick, int):
#         b = int(pick)
#     else:
#         b = 0
#     B = outs[0].size(0)
#     b = max(0, min(b, B-1))

#     # Build paired (pred, target) for the chosen sample
#     diag = bool(choose_diag[b].item())

#     def _extract_BT(x: torch.Tensor) -> np.ndarray:
#         """
#         Convert (2,1,T) or (2,C,T) tensor to (2,T) numpy.
#         - If C==1 -> squeeze channel
#         - If C>1  -> average across channels
#         """
#         if x.dim() != 3 or x.size(0) != 2:
#             raise ValueError(f"expected (2,C,T), got {tuple(x.shape)}")
#         if x.size(1) == 1:
#             x = x[:, 0, :]            # (2,T)
#         else:
#             x = x.mean(dim=1)         # (2,T)
#         return x.detach().cpu().numpy()

#     y0 = _extract_BT(outs[0][b])
#     y1 = _extract_BT(outs[1][b])
#     s0 = _extract_BT(yb[0][b])
#     s1 = _extract_BT(yb[1][b])

#     # Apply pairing
#     if diag:
#         pairs = [(y0, s0, "Source 1"), (y1, s1, "Source 2")]
#     else:
#         pairs = [(y0, s1, "Source 1 (paired with tgt2)"),
#                  (y1, s0, "Source 2 (paired with tgt1)")]

#     # Plot real/imag overlays
#     T = pairs[0][0].shape[-1]
#     tmax = min(T, max_time) if max_time is not None else T
#     tt = np.arange(tmax)

#     fig, axes = plt.subplots(2, 2, figsize=(12,6), sharex=True)
#     for k, (yp, yt, title) in enumerate(pairs):
#         axes[k,0].plot(tt, yt[0, :tmax], lw=1.0, label="GT Real")
#         axes[k,0].plot(tt, yp[0, :tmax], lw=0.9, linestyle="--", label="Pred Real")
#         axes[k,0].set_title(f"{title} — Real")
#         axes[k,0].grid(alpha=0.3); axes[k,0].legend()

#         axes[k,1].plot(tt, yt[1, :tmax], lw=1.0, label="GT Imag")
#         axes[k,1].plot(tt, yp[1, :tmax], lw=0.9, linestyle="--", label="Pred Imag")
#         axes[k,1].set_title(f"{title} — Imag")
#         axes[k,1].grid(alpha=0.3); axes[k,1].legend()

#     plt.suptitle(f"Predicted vs Ground Truth (sample {b}, "
#                  f"{'diag' if diag else 'offdiag'} pairing, ρ={per_sample_r[b].item():.3f})")
#     plt.tight_layout(rect=[0, 0.03, 1, 0.95])

#     outp = save_dir / f"pred_vs_gt_sample{b}.png"
#     plt.savefig(outp, dpi=150)
#     if show: display(fig)
#     plt.close(fig)
#     print(f"[saved] {outp}")

# # ----------------
# # Run the plots
# # ----------------
# plot_training_curves(cfg.save_dir, history_in_mem=history if 'history' in globals() else None, show=True)
# plot_pred_vs_gt(model, test_loader, device, cfg.save_dir, pick="best", max_time=1024, show=True)


In [None]:
# # --- ENV overrides for terminal runs ---
# import os

# def _env_bool(k, default=False):
#     v = os.getenv(k)
#     if v is None: return default
#     return v.lower() in ("1","true","yes","y","on")

# # MOD list (comma or space separated)
# mods_env = os.getenv("MODS")
# if mods_env:
#     # allow commas or spaces
#     toks = [t for t in mods_env.replace(",", " ").split() if t]
#     TWO_MODS = toks

# # core training knobs
# cfg.epochs        = int(os.getenv("EPOCHS",        cfg.epochs))
# cfg.batch_size    = int(os.getenv("BATCH_SIZE",    cfg.batch_size))
# cfg.lr            = float(os.getenv("LR",          cfg.lr))
# cfg.weight_decay  = float(os.getenv("WEIGHT_DECAY",cfg.weight_decay))
# cfg.grad_clip     = float(os.getenv("GRAD_CLIP",   cfg.grad_clip))
# cfg.amp           = _env_bool("AMP",               cfg.amp)
# cfg.compile_model = _env_bool("COMPILE",           cfg.compile_model)
# cfg.log_interval  = int(os.getenv("LOG_INTERVAL",  cfg.log_interval))

# # model width/depth
# MODEL_CHANNELS    = int(os.getenv("MODEL_CHANNELS", 32))
# MODEL_CHE_MID     = int(os.getenv("MODEL_CHE_MID", 64))
# MODEL_N_CDCM_A    = int(os.getenv("MODEL_N_CDCM_A", 4))
# MODEL_N_CDCM_B    = int(os.getenv("MODEL_N_CDCM_B", 4))

# # SNRs
# if os.getenv("TRAINVAL_SNR"):
#     a,b = [float(x) for x in os.getenv("TRAINVAL_SNR").split(",")]
#     TRAINVAL_SNR_RANGE = (a,b)
# if os.getenv("TEST_SNR"):
#     TEST_SNR = float(os.getenv("TEST_SNR"))

# # half sampling per (mod,SNR)
# HALF_PER_MOD = _env_bool("HALF_PER_MOD", False)
