Dataset Generation:

In [None]:
# --- Environment setup (Colab-style `!pip` shell escapes) ---
# NOTE(review): the recorded output of this cell shows pip "metadata-generation-failed"
# while building some package from source. The old numba==0.56.4 / llvmlite==0.39 pins
# are a likely suspect (they may have no wheel for the runtime's Python) — confirm.
!pip install -q --upgrade pip
!pip install -q numpy cython librosa soundfile
!pip install -q pyroomacoustics
!pip install -q --upgrade numba==0.56.4 llvmlite==0.39
!pip install -q tqdm

import numpy as np, torch, torch.nn as nn, torch.nn.functional as F
import librosa, soundfile as sf, json, random, glob
from pathlib import Path
from tqdm import tqdm, trange
import pyroomacoustics as pra

# Sample rate (Hz), clip duration (s) and clip length in samples.
FS, DUR, L = 16000, 10.0, int(16000 * 10)
# Microphone positions relative to the array reference point (metres, the
# pyroomacoustics convention). After .T, rows are x/y/z and columns are the 4 mics.
MIC_POS = np.array([[-0.05, 0, 0],
                    [ 0.05, 0, 0],
                    [-0.08, 0.045, 0.04],
                    [ 0.08, 0.045, 0.04]]).T  # (3,4)
# Train on GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on", device)

  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... [?25l[?25herror
[1;31merror[0m: [1mmetadata-generation-failed[0m

[31m×[0m Encountered error while generating package metadata.
[31m╰─>[0m See above for output.

[1;35mnote[0m: This is an issue with the package mentioned above, not pip.
[1;36mhint[0m: See above for details.
Running on cpu


In [None]:
# Sample rate (Hz), clip duration (s) and clip length in samples.
FS = 16000
DUR = 10.0
L = int(FS * DUR)

In [None]:
# Prefer the GPU when PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [None]:
def _load(path, fs=FS, length=L):
    """Read an audio file as a mono signal at `fs`, zero-padded/truncated to `length` samples.

    Multichannel files are averaged to mono. Resampling passes the sample rates as
    keywords (`orig_sr`/`target_sr`); librosa >= 0.10 makes them keyword-only, so
    the previous positional call raised TypeError whenever resampling was needed.
    """
    y, sr = sf.read(path)
    if y.ndim > 1:
        y = y.mean(axis=1)
    if sr != fs:
        y = librosa.resample(y.astype(np.float32), orig_sr=sr, target_sr=fs)
    return np.pad(y, (0, max(0, length - len(y))))[:length]

def _normalize_rms(x, target_db=-25):
    rms = np.sqrt(np.mean(x**2) + 1e-12)
    scalar = (10**(target_db/20)) / rms
    return x * scalar

def simulate_two_speakers(tgt_wav, itf_wav,
                          overlap=0.5,
                          room_rng=((3,3,1.5),(8,8,2.5)),
                          rt60_rng=(.1,.6),
                          seed=None,
                          target_rms_db=-25):
    """
    Simulate a 4-mic reverberant mixture of a target and an interfering talker.

    The room size, RT60 and interferer placement are drawn from `seed`. The array
    (MIC_POS) sits at the horizontal room centre at height 1.5 m. The target is
    0.06 m in -y from the array reference point.

    Parameters
    ----------
    tgt_wav, itf_wav : path-like – target / interferer audio (loaded with `_load`)
    overlap   : float – the interferer starts (1 - overlap) * 2.0 s after the target
    room_rng  : ((lo_x, lo_y, lo_z), (hi_x, hi_y, hi_z)) – uniform room-size bounds.
                The height lower bound is raised when needed so the ceiling stays
                above the highest microphone.
    rt60_rng  : (lo, hi) – uniform RT60 bounds in seconds
    seed      : int | None – seed for numpy.random.default_rng
    target_rms_db : float – RMS level both outputs are normalised to

    Returns
    -------
    mix_4ch   : (4, L) float32 – target + interferer, RMS-normalised
    clean_ref : (L,)   float32 – mic-0 target, RMS-normalised
    meta      : dict   – room / placement info

    Raises
    ------
    ValueError – e.g. from pyroomacoustics when the sampled room cannot reach the
                 sampled RT60 (callers retry with a new draw).
    """
    rng = np.random.default_rng(seed)
    array_height = 1.5   # make_feature_tensor recomputes the array position with this height

    # 1) random room. Keep the ceiling above the top microphone (the old 1.5 m lower
    #    bound allowed mics at z = 1.54 to end up outside the room).
    lo, hi = (np.array(bound, dtype=float) for bound in room_rng)
    lo[2] = max(lo[2], array_height + MIC_POS[2].max() + 0.2)
    dim  = rng.uniform(lo, hi, 3)
    rt60 = rng.uniform(*rt60_rng)
    # inverse_sabine returns (energy_absorption, max_order). Averaging that tuple
    # (as before) mixed the reflection order into the absorption value.
    e_absorption, max_order = pra.inverse_sabine(rt60, dim)
    # Cap the image-source order at the previous fixed value (17) to bound runtime.
    # Long RT60s in large rooms therefore get a truncated tail.
    room = pra.ShoeBox(dim, fs=FS, materials=pra.Material(e_absorption),
                       max_order=min(int(max_order), 17))

    # 2) mic array at the horizontal room centre, fixed height
    offset = dim / 2
    offset[2] = array_height
    room.add_microphone_array(pra.MicrophoneArray(MIC_POS + offset[:, None], FS))

    # 3) load clips (each padded/truncated to L samples by _load)
    tgt, itf = _load(tgt_wav), _load(itf_wav)

    # 4) add sources
    tgt_pos = np.array([0.0, -0.06, 0.0]) + offset
    room.add_source(tgt_pos, tgt)

    theta = rng.uniform(np.deg2rad(5), np.deg2rad(355))
    r     = rng.uniform(0.5, 3.0)
    # Keep the interferer at least 0.1 m from every wall. The old x/y lower bound
    # of 0 could put it exactly on a wall.
    itf_pos = np.clip(tgt_pos + r * np.array([np.cos(theta), np.sin(theta), 0.0]),
                      0.1, dim - 0.1)
    room.add_source(itf_pos, itf, delay=(1 - overlap) * 2.0)

    # 5) simulate mixture
    room.simulate()
    multi_mix = room.mic_array.signals[:, :L]            # (4, L)

    # 6) clean reference: target convolved with its mic-0 RIR only (same alignment
    #    as the zero-delay target inside the simulated mixture)
    clean_ref = np.convolve(tgt, room.rir[0][0])[:L]

    # 7) RMS normalise. NOTE(review): mix and clean get independent gains, so the clean
    #    reference is not at the scale of the target inside the mixture. This matters
    #    for scale-dependent (e.g. TF-domain MSE) losses — confirm this is intended.
    mix_norm   = _normalize_rms(multi_mix, target_db=target_rms_db)
    clean_norm = _normalize_rms(clean_ref,  target_db=target_rms_db)

    # 8) metadata
    meta = dict(
        room_dim = dim.tolist(),
        rt60     = float(rt60),
        src_pos  = tgt_pos.tolist(),
        itf_pos  = itf_pos.tolist()
    )

    return mix_norm.astype(np.float32), clean_norm.astype(np.float32), meta

In [None]:
!wget -q https://www.openslr.org/resources/12/dev-clean.tar.gz -O dev-clean.tar.gz
!tar -xzf dev-clean.tar.gz

In [None]:
LIB_DIR = Path("./LibriSpeech/dev-clean")
# NOTE(review): the training cells read from /content/drive/MyDrive/audio_dataset
# (without "/10s"), so they do not see the files written here — confirm which is intended.
OUT_DIR = Path("/content/drive/MyDrive/audio_dataset/10s")
OUT_DIR.mkdir(parents=True, exist_ok=True)
# Source utterances (.flac despite the name), skipping empty files. Sorted because
# rglob order depends on the filesystem; sorting keeps seeded random.sample
# reproducible across runs.
wav_files = sorted(p for p in LIB_DIR.rglob("*.flac") if p.stat().st_size > 0)
# Fail fast if the download/extract cell did not run; otherwise sampling pairs fails on every attempt.
assert len(wav_files) >= 2, f"need at least 2 .flac files under {LIB_DIR}, found {len(wav_files)}"

def pick_two():
    """Return two distinct entries of `wav_files` (target, interferer), drawn with `random.sample`."""
    target, interferer = random.sample(wav_files, 2)
    return [target, interferer]

def make_sim_dataset(n=100, max_retries=20, seed=None):
    """Simulate `n` mixtures and write mix_/clean_/meta_ files (5-digit index) to OUT_DIR.

    Each example retries with a fresh file pair and room seed when simulation raises,
    e.g. when pyroomacoustics rejects a room/RT60 combination. After `max_retries`
    failed attempts for one example, a RuntimeError chained to the last error is raised.
    Previously the loop retried forever; for example, an empty `wav_files` made
    `random.sample` fail on every attempt.

    seed : int | None – if given, seeds Python's `random` module so the file pairs
           and per-example room seeds are reproducible.
    """
    if seed is not None:
        random.seed(seed)
    for idx in trange(n, desc="Generating"):
        for attempt in range(1, max_retries + 1):
            try:
                src, itf = pick_two()
                mix, clean, meta = simulate_two_speakers(src, itf, seed=random.randint(0, 2**32 - 1))
                break
            except Exception as e:
                if attempt == max_retries:
                    raise RuntimeError(f"example {idx}: {max_retries} failed attempts") from e
                print(f"Retry {idx}: {e}")
        sf.write(OUT_DIR / f"mix_{idx:05d}.wav", mix.T, FS, subtype="PCM_16")
        sf.write(OUT_DIR / f"clean_{idx:05d}.wav", clean, FS, subtype="PCM_16")
        # Context manager so the JSON file is flushed and closed (previously leaked).
        with open(OUT_DIR / f"meta_{idx:05d}.json", "w") as fh:
            json.dump(meta, fh)
    print("✓ Done")

# run it
# make_sim_dataset(100)

In [None]:
# Generate 90 simulated examples into OUT_DIR.
make_sim_dataset(n=90)

Generating:  48%|████▊     | 43/90 [00:57<00:48,  1.03s/it]

Retry 43: zero-size array to reduction operation maximum which has no identity


Generating:  64%|██████▍   | 58/90 [01:11<00:31,  1.00it/s]

Retry 58: zero-size array to reduction operation maximum which has no identity
Retry 58: evaluation of parameters failed. room may be too large for required RT60.


Generating:  69%|██████▉   | 62/90 [01:15<00:25,  1.12it/s]

Retry 62: zero-size array to reduction operation maximum which has no identity


Generating:  70%|███████   | 63/90 [01:16<00:27,  1.02s/it]

Retry 63: zero-size array to reduction operation maximum which has no identity


Generating:  82%|████████▏ | 74/90 [01:28<00:14,  1.09it/s]

Retry 74: zero-size array to reduction operation maximum which has no identity


Generating: 100%|██████████| 90/90 [01:44<00:00,  1.16s/it]

✓ Done





UNet model: mask-estimation work begins here

In [None]:
def collate_fn(batch):
    """Transpose a list of (mix, clean, meta) samples into three parallel lists.

    Samples stay as lists rather than stacked tensors because they are numpy arrays
    plus meta dicts that the training loop featurises one example at a time.
    """
    mix, clean, meta = zip(*batch)   # each is a tuple of length B
    return list(mix), list(clean), list(meta)

# NOTE: the DataLoader is built after `MixDataset` / `ds` are defined further down.
# Building it here raised NameError on a fresh kernel, because `ds` did not exist yet.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ───────────────────────────────
# Small helpers
# ───────────────────────────────
def conv2d(in_ch, out_ch, k=3, s=1, p=1):
    """2-D convolution with bias; with the defaults it is a size-preserving 3x3 conv."""
    layer = nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=k,
        stride=s,
        padding=p,
        bias=True,
    )
    return layer

def deconv2d(in_ch, out_ch, k=3, s=2, p=1, op=1):
    """Transposed 2-D convolution with bias; the defaults upsample both spatial dims by exactly 2."""
    layer = nn.ConvTranspose2d(
        in_ch,
        out_ch,
        kernel_size=k,
        stride=s,
        padding=p,
        output_padding=op,
        bias=True,
    )
    return layer

class TCNBlock(nn.Module):
    """Residual dilated 1-D conv block: x + PReLU(BN(Conv1d(x))); non-causal, length-preserving."""
    def __init__(self, ch, dilation):
        super().__init__()
        # padding == dilation keeps the time length unchanged for kernel_size=3
        self.conv = nn.Conv1d(in_channels=ch, out_channels=ch, kernel_size=3,
                              dilation=dilation, padding=dilation)
        self.norm = nn.BatchNorm1d(num_features=ch)
        self.act  = nn.PReLU()

    def forward(self, x):
        residual = x
        out = self.conv(x)
        out = self.norm(out)
        out = self.act(out)
        return residual + out

class StackedTCN(nn.Module):
    """`stacks` repeats of `layers` TCN blocks (dilations 1, 2, ..., 2**(layers-1)), all with `ch` channels."""
    def __init__(self, ch=128, stacks=3, layers=8):
        super().__init__()
        dilations = [2 ** level for level in range(layers)] * stacks
        self.net = nn.Sequential(*(TCNBlock(ch, d) for d in dilations))

    def forward(self, x):   # x: [B, ch, T]
        return self.net(x)

In [None]:
class UNetTCNMaskNet(nn.Module):
    """
    3-level 2-D U-Net with a stacked TCN over time at the bottleneck.

    Input  : [B, 6,  F,  T]   (6 = |Y0|, 3×cosIPD, AF_src, AF_noise)
    Output : [B, out_ch, F, T] tanh-bounded mask (out_ch=2: complex mask [Real, Imag];
             set out_ch=1 for IRM)

    Each stride-2 encoder conv maps a size n to ceil(n/2), and each transposed conv
    doubles its input. For odd F/T (e.g. F=257), a decoder output can therefore be one
    bin/frame larger than the matching skip tensor. Decoder outputs are cropped to the
    skip tensor, and the final mask to the input, before use. Without this, torch.cat
    raised a size-mismatch RuntimeError.
    """
    def __init__(self, out_ch=2):
        super().__init__()

        # ── Encoder
        self.enc1 = nn.Sequential(conv2d(6,  32, s=2), nn.PReLU(), nn.BatchNorm2d(32))
        self.enc2 = nn.Sequential(conv2d(32, 64, s=2), nn.PReLU(), nn.BatchNorm2d(64))
        self.enc3 = nn.Sequential(conv2d(64,128, s=2), nn.PReLU(), nn.BatchNorm2d(128))

        # ── Decoder (dec2/dec3 input channels include the concatenated skip)
        self.dec1 = nn.Sequential(deconv2d(128, 64), nn.PReLU(), nn.BatchNorm2d(64))
        self.dec2 = nn.Sequential(deconv2d(128, 32), nn.PReLU(), nn.BatchNorm2d(32))
        self.dec3 = nn.Sequential(deconv2d(64,  out_ch), nn.PReLU())

        # ── Temporal module
        self.tcn = StackedTCN(ch=128)

    def crop_like(self, src, ref):
        """Center-crop `src` so its spatial (F,T) dims match `ref` (src must be >= ref in both)."""
        _, _, H, W = src.shape
        _, _, h, w = ref.shape
        if H == h and W == w:
            return src
        dH, dW = (H - h) // 2, (W - w) // 2
        return src[:, :, dH:dH + h, dW:dW + w]

    def forward(self, x):
        """
        x : [B, 6, F, T]   (F≈257)
        """
        # Encoder
        e1 = self.enc1(x)                 # [B,32, ceil(F/2), ceil(T/2)]
        e2 = self.enc2(e1)                # [B,64, ...]
        e3 = self.enc3(e2)                # [B,128, ...]

        # Run the TCN along time for every (batch, frequency-row) pair
        B, C, F_, T_ = e3.shape
        tcn_in = e3.permute(0, 2, 1, 3).reshape(B * F_, C, T_)   # [B·F',128,T']
        e3 = self.tcn(tcn_in).reshape(B, F_, C, T_).permute(0, 2, 1, 3)

        # Decoder with size-matched skips
        d1 = self.crop_like(self.dec1(e3), e2)
        d1 = torch.cat([d1, e2], dim=1)                  # [B,128, ...]
        d2 = self.crop_like(self.dec2(d1), e1)
        d2 = torch.cat([d2, e1], dim=1)                  # [B,64, ...]

        mask = self.crop_like(self.dec3(d2), x)          # [B,out_ch,F,T]
        return torch.tanh(mask)                          # keep mask bounded

In [None]:
# NOTE(review): make_sim_dataset writes to ".../audio_dataset/10s"; this root
# directory is a different dataset location — confirm which one training should use.
DATA_DIR = Path("/content/drive/MyDrive/audio_dataset")
mix_paths, clean_paths = (
    sorted(glob.glob(str(DATA_DIR / f"{kind}_*.wav"))) for kind in ("mix", "clean")
)

class MixDataset(torch.utils.data.Dataset):
    """(mix, clean, meta) examples read from mix_*.wav / clean_*.wav / meta_*.json in `data_dir`.

    Files are paired by their sorted position. __getitem__ returns the mixture as
    (channels, samples) float32, the clean signal as float32, and the meta dict.
    """
    def __init__(self, data_dir):
        root = Path(data_dir)
        self.mix_paths, self.clean_paths, self.meta_paths = (
            sorted(root.glob(pattern))
            for pattern in ("mix_*.wav", "clean_*.wav", "meta_*.json")
        )
        assert len(self.mix_paths)==len(self.clean_paths)==len(self.meta_paths)

    def __len__(self):
        return len(self.mix_paths)

    def __getitem__(self, idx):
        mix, _sr = sf.read(self.mix_paths[idx])        # (samples, channels)
        clean, _sr = sf.read(self.clean_paths[idx])    # (samples,)
        with open(self.meta_paths[idx]) as fh:
            meta = json.load(fh)
        return mix.T.astype(np.float32), clean.astype(np.float32), meta

In [None]:
def stft_multi(x, n_fft=512, hop=256):
    """Per-channel STFT: x (M, L) real -> (M, F, T) complex, with F = n_fft // 2 + 1 (Hann window)."""
    spectra = [
        librosa.stft(channel, n_fft=n_fft, hop_length=hop, window="hann")
        for channel in x
    ]
    return np.stack(spectra)

def make_feature_tensor(mix_4ch, meta, sr=16000, n_fft=512, hop=256):
    """
    Build the 6-channel spatial feature map for one 4-mic mixture.

    mix_4ch : (4,L) float32  mixture
    meta    : dict           contains room_dim, src_pos, itf_pos (simulator coordinates)
    sr, n_fft, hop : sample rate and STFT parameters. They are forwarded to the STFT;
                     previously n_fft/hop were ignored in favour of stft_multi's defaults.
    Returns feats [6,F,T] float32 (|Y0|, 3×cos IPD, AF_src, AF_noise),
            Y_stft [4,F,T] complex64 (for float32 input)
    """
    # 1) STFT
    Y = stft_multi(mix_4ch, n_fft=n_fft, hop=hop)   # (4,F,T)
    Y0, mag0 = Y[0], np.abs(Y[0])

    # 2) IPD (mic-0 reference)
    ipd     = np.angle(Y[1:]) - np.angle(Y0)   # (3,F,T)
    cos_ipd = np.cos(ipd)                      # (3,F,T)

    # 3) Expected IPDs. With the forward-DFT convention used by the STFT, a signal that
    #    reaches mic j tau_j seconds after mic 0 gives Y_j ≈ Y_0·exp(-j2πf·tau_j).
    #    Its expected IPD is therefore -2πf·tau_j.
    c      = 343.0
    n_freq = Y.shape[1]                        # not `F`: avoid shadowing torch.nn.functional
    f      = np.linspace(0, sr/2, n_freq)[:, None]   # (F,1) bin frequencies in Hz

    # absolute mic coordinates
    room_dim = np.array(meta["room_dim"])
    offset   = room_dim / 2; offset[2] = 1.5   # same logic as simulator
    p_abs    = (MIC_POS + offset[:,None])      # (3,4)
    p0_abs   = p_abs[:,0]
    pj_abs   = p_abs[:,1:]                     # (3,3)

    # --- target (near-field): tau_j = (dist_j - dist_0) / c
    s_abs    = np.array(meta["src_pos"])
    dist_0   = np.linalg.norm(p0_abs - s_abs)
    dist_j   = np.linalg.norm(pj_abs.T - s_abs, axis=1)       # (3,)
    # Sign fixed: this used +2πf·tau_j, the opposite convention to exp_noise below.
    exp_tgt  = -2*np.pi * f @ ((dist_j - dist_0)[None] / c)   # (F,3)

    # --- interferer (far-field planar): mic j is proj_j closer, so tau_j = -proj_j / c
    i_abs    = np.array(meta["itf_pos"])
    doa_noise= (i_abs - p0_abs) / np.linalg.norm(i_abs - p0_abs)
    d_vec    = pj_abs.T - p0_abs                               # (3,3)
    proj     = d_vec @ doa_noise                               # (3,)
    exp_noise= 2*np.pi * f @ (proj[None] / c)                  # (F,3)

    # angle features: summed agreement of observed and expected IPDs over mic pairs
    AF_src   = np.sum(np.cos(ipd.transpose(1,0,2) - exp_tgt[:,:,None]), axis=1)
    AF_noise = np.sum(np.cos(ipd.transpose(1,0,2) - exp_noise[:,:,None]), axis=1)

    feats = np.concatenate(
        [mag0[None], cos_ipd, AF_src[None], AF_noise[None]], axis=0
    ).astype(np.float32)                                      # (6,F,T)

    return torch.from_numpy(feats), torch.from_numpy(Y)

In [None]:
# Loss, model and optimiser for the TF-domain mask demo.
mse = nn.MSELoss()
net = UNetTCNMaskNet().to(device)
opt = torch.optim.Adam(params=net.parameters(), lr=2e-3)

def collate_fn(batch):
    """Regroup (mix, clean, meta) samples column-wise into three Python lists (no stacking)."""
    mixes, cleans, metas = zip(*batch)
    return tuple(list(column) for column in (mixes, cleans, metas))

ds = MixDataset(DATA_DIR)
loader = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=True, collate_fn=collate_fn)

for epoch in range(5):  # Demo: 5 epochs
    net.train()
    total_loss = 0.0
    for mix, clean, meta in tqdm(loader, desc=f"Epoch {epoch}"):
        feats, Y, X = [], [], []

        for m, c, meta_dict in zip(mix, clean, meta):
            feat, y = make_feature_tensor(m, meta_dict)
            x = librosa.stft(np.asarray(c), n_fft=512, hop_length=256, window="hann")
            feats.append(feat)
            Y.append(y)
            X.append(torch.from_numpy(x))

        feats = torch.stack(feats).float().to(device)  # [B,6,F,T]
        Y     = torch.stack(Y).to(device)              # [B,4,F,T]
        X     = torch.stack(X).to(device)              # [B,F,T] clean mic-0 STFT

        mask  = net(feats)                             # [B,2,F',T'] (may exceed F,T)
        # A stride-2 U-Net can overshoot odd F/T (F=257 here). Crop the mask, the
        # mixture and the target to their common extent before multiplying and scoring.
        F_c = min(mask.shape[-2], Y.shape[-2], X.shape[-2])
        T_c = min(mask.shape[-1], Y.shape[-1], X.shape[-1])
        mask, Y, X = mask[..., :F_c, :T_c], Y[..., :F_c, :T_c], X[..., :F_c, :T_c]

        M     = torch.complex(mask[:, 0], mask[:, 1])  # [B,F,T]
        S_hat = M * Y[:, 0]                            # only the reference mic is scored
        loss  = mse(S_hat.real, X.real) + mse(S_hat.imag, X.imag)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()

    print(f"[Epoch {epoch}] Avg Loss: {total_loss/len(loader):.4f}")

# Save model
torch.save(net.state_dict(), "unet_tcn_masknet.pt")
print("✓ Model saved to 'unet_tcn_masknet.pt'")

Epoch 0:   0%|          | 0/25 [00:01<?, ?it/s]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 66 but got size 1 for tensor number 1 in the list.

In [None]:
def conv2d(in_ch, out_ch, k=3, s=1, p=1):
    """Shorthand for a biased nn.Conv2d with kernel `k`, stride `s` and padding `p`."""
    options = dict(kernel_size=k, stride=s, padding=p, bias=True)
    return nn.Conv2d(in_ch, out_ch, **options)

def deconv2d(in_ch, out_ch, k=3, s=2, p=1, op=1):
    """Shorthand for a biased nn.ConvTranspose2d; defaults double the spatial size."""
    options = dict(kernel_size=k, stride=s, padding=p, output_padding=op, bias=True)
    return nn.ConvTranspose2d(in_ch, out_ch, **options)

class TCNBlock(nn.Module):
    """Non-causal residual block: dilated Conv1d(k=3, same length) -> BatchNorm1d -> PReLU, plus skip."""
    def __init__(self, ch, dilation):
        super().__init__()
        self.conv = nn.Conv1d(ch, ch, kernel_size=3, dilation=dilation, padding=dilation)
        self.norm = nn.BatchNorm1d(ch)
        self.act  = nn.PReLU()

    def forward(self, x):
        branch = self.act(self.norm(self.conv(x)))
        return x + branch

class StackedTCN(nn.Module):
    """Sequential TCN: `stacks` passes over dilations 2**0 ... 2**(layers-1), all with `ch` channels."""
    def __init__(self, ch=128, stacks=3, layers=8):
        super().__init__()
        self.net = nn.Sequential(*[
            TCNBlock(ch, 2 ** level)
            for _ in range(stacks)
            for level in range(layers)
        ])

    def forward(self, x):
        return self.net(x)

class UNetTCNMaskNet(nn.Module):
    """
    U-Net with three stride-2 levels and a stacked TCN over time at the bottleneck.

    Input  : [B, in_ch, F, T]   (default 6 feature maps)
    Output : [B, out_ch, F, T]  tanh-bounded mask (out_ch=2: [Real, Imag] of a complex mask)

    The output is cropped to the input's (F, T). For odd sizes such as F=257, the last
    transposed conv otherwise returns one extra bin (258), which callers had to crop.
    """
    def __init__(self, in_ch=6, out_ch=2):
        super().__init__()
        # encoder
        self.enc1 = nn.Sequential(conv2d(in_ch, 32, s=2), nn.PReLU(), nn.BatchNorm2d(32))
        self.enc2 = nn.Sequential(conv2d(32, 64, s=2), nn.PReLU(), nn.BatchNorm2d(64))
        self.enc3 = nn.Sequential(conv2d(64,128, s=2), nn.PReLU(), nn.BatchNorm2d(128))
        # temporal
        self.tcn  = StackedTCN(128)
        # decoder (dec2/dec3 input channels include the concatenated skip)
        self.dec1 = nn.Sequential(deconv2d(128, 64), nn.PReLU(), nn.BatchNorm2d(64))
        self.dec2 = nn.Sequential(deconv2d(128, 32), nn.PReLU(), nn.BatchNorm2d(32))
        self.dec3 = nn.Sequential(deconv2d(64,  out_ch), nn.PReLU())

    # ─────────────────────────────────────────────
    def crop_like(self, src, ref):
        """Center-crop `src` so its spatial (F,T) dims match `ref`.

        Assumes `src` is at least as large as `ref` in both dims. This holds here:
        each encoder conv maps n -> ceil(n/2), and each transposed conv doubles its
        input, so 2*ceil(n/2) >= n.
        """
        _, _, H, W = src.shape
        _, _, h, w = ref.shape
        if H == h and W == w:
            return src
        dH, dW = (H - h) // 2, (W - w) // 2
        return src[:, :, dH:dH + h, dW:dW + w]

    # ─────────────────────────────────────────────
    def forward(self, x):           # x: [B, in_ch, F, T]
        # Encoder
        e1 = self.enc1(x)           # [B,32, ceil(F/2), ceil(T/2)]
        e2 = self.enc2(e1)          # [B,64, ...]
        e3 = self.enc3(e2)          # [B,128, ...]

        # Temporal TCN applied along time for each (batch, frequency-row)
        B, C, F_, T_ = e3.shape
        tcn_in  = e3.permute(0, 2, 1, 3).reshape(B * F_, C, T_)  # [B·F',128,T']
        tcn_out = self.tcn(tcn_in).reshape(B, F_, C, T_).permute(0, 2, 1, 3)
        e3 = tcn_out                                                # [B,128,F',T']

        # Decoder with size-matched skips
        d1 = self.crop_like(self.dec1(e3), e2)  # up ×2, cropped to e2
        d1 = torch.cat([d1, e2], dim=1)         # [B,128, ...]

        d2 = self.crop_like(self.dec2(d1), e1)  # up ×2, cropped to e1
        d2 = torch.cat([d2, e1], dim=1)         # [B,64, ...]

        # final up ×2, cropped to the input's (F, T), then bounded
        mask = torch.tanh(self.crop_like(self.dec3(d2), x))   # [B,out_ch,F,T]
        return mask


In [None]:
def stft_multi(x, n_fft=512, hop=256):
    """STFT every channel of x (M, L) with a Hann window; returns an (M, F, T) complex array."""
    out = []
    for channel in x:
        out.append(librosa.stft(channel, n_fft=n_fft, hop_length=hop, window="hann"))
    return np.stack(out)

def make_feature_tensor(mix_4ch, meta, sr=16000, n_fft=512, hop=256):
    """
    6-channel spatial features for one 4-mic mixture.

    Returns (feats [6,F,T] float32, Y [4,F,T] complex STFT). The feats channels are:
    |Y0|, cos IPD of mics 1-3 vs mic 0, the target angle feature and the interferer
    angle feature. `sr`, `n_fft` and `hop` are forwarded to the STFT (n_fft/hop were
    previously ignored).
    """
    Y = stft_multi(mix_4ch, n_fft=n_fft, hop=hop)      # (4,F,T)
    mag0 = np.abs(Y[0])
    ipd  = np.angle(Y[1:]) - np.angle(Y[0])            # (3,F,T)
    cos_ipd = np.cos(ipd)                              # (3,F,T)

    # Expected IPDs: if mic j lags mic 0 by tau_j, then Y_j ≈ Y_0·exp(-j2πf·tau_j),
    # i.e. the expected IPD is -2πf·tau_j.
    c = 343.0
    n_freq = Y.shape[1]                                # not `F`: avoid shadowing torch's F
    f = np.linspace(0, sr/2, n_freq)[:, None]          # (F,1) bin frequencies
    room_dim = np.array(meta["room_dim"]); offset = room_dim/2; offset[2] = 1.5   # same as simulator
    p_abs = MIC_POS + offset[:, None]; p0 = p_abs[:, 0]; pj = p_abs[:, 1:]

    s = np.array(meta["src_pos"]);  i = np.array(meta["itf_pos"])
    # target, near-field: tau_j = (dist_j - dist_0)/c (the sign was previously flipped)
    dist0 = np.linalg.norm(p0 - s);  distj = np.linalg.norm(pj.T - s, axis=1)
    exp_tgt = -2*np.pi*f @ ((distj - dist0)[None] / c)     # (F,3)
    # interferer, far-field: mic j is proj_j closer, tau_j = -proj_j/c
    doa = (i - p0) / np.linalg.norm(i - p0);  proj = (pj.T - p0) @ doa
    exp_noise = 2*np.pi*f @ (proj[None] / c)               # (F,3)

    AF_src   = np.sum(np.cos(ipd.transpose(1,0,2) - exp_tgt[:,:,None]), axis=1)
    AF_noise = np.sum(np.cos(ipd.transpose(1,0,2) - exp_noise[:,:,None]), axis=1)

    feats = np.concatenate([mag0[None], cos_ipd, AF_src[None], AF_noise[None]], 0).astype(np.float32)
    return torch.from_numpy(feats), torch.from_numpy(Y)

In [None]:
def steering_vector(sr=16000, n_fft=512):
    """Near-field target steering vector relative to mic 0, shape (F, M), complex64.

    The target sits at (0, -0.06, 0) in array-relative coordinates. Entry [f, m] is
    exp(-j2π·f·tau_m), where tau_m is the extra propagation delay to mic m versus mic 0
    (c = 343 m/s).
    """
    freqs = np.linspace(0, sr / 2, n_fft // 2 + 1)             # (F,)
    src = np.array([0, -0.06, 0])
    ref_dist = np.linalg.norm(MIC_POS[:, 0] - src)
    mic_dists = np.linalg.norm(MIC_POS.T - src, axis=1)         # (M,)
    delays = (mic_dists - ref_dist)[None] / 343.0               # (1,M)
    phase = np.exp(-2j * np.pi * freqs[:, None] * delays)       # (F,M)
    return torch.from_numpy(phase).to(torch.complex64)

alpha_src = steering_vector()

In [None]:
def mvdr_framewise(Y, M, alpha_src, alpha_smooth=0.8, eps=1e-6):
    """
    Mask-driven, frame-recursive MVDR beamformer with a fixed target steering vector.

    The noise spatial covariance is tracked per frequency with an EMA over the noise
    estimate N_hat = Y - M·Y:
        Phi_n <- alpha_smooth·Phi_n + (1 - alpha_smooth)·n n^H        (Phi_n starts at I)
    Each frame is then beamformed towards d = alpha_src[f]:
        w = Phi_n^-1 d / (d^H Phi_n^-1 d),   X_hat[f, t] = w^H y[f, t]

    The previous version used the masked observation itself as the steering vector.
    That is degenerate: w^H y then collapses to 1/M, independent of y. It also
    mis-permuted that vector and multiplied the steering vector into an unused
    covariance, which raised a size-mismatch error.

    Inputs:
      Y            : (B, M, F, T) complex STFTs
      M            : (B, F, T) complex mask (shared by all mics)
      alpha_src    : (F, M) complex tensor – target steering vector per frequency
                     (e.g. from steering_vector()); moved to Y's device/dtype
      alpha_smooth : EMA weight on the past noise covariance
      eps          : diagonal loading and denominator floor
    Returns:
      X_hat : (B, F, T) complex – beamformed output
    Raises:
      ValueError if alpha_src is not an (F, M) tensor matching Y.
    """
    B, Mics, n_freq, T = Y.shape
    if not torch.is_tensor(alpha_src) or tuple(alpha_src.shape) != (n_freq, Mics):
        raise ValueError(
            f"alpha_src must be an (F, M) = ({n_freq}, {Mics}) steering-vector tensor")
    d = alpha_src.to(device=Y.device, dtype=Y.dtype)[None].expand(B, n_freq, Mics)  # (B,F,M)

    N_hat = Y - M[:, None] * Y                   # (B, M, F, T) noise estimate

    eye = torch.eye(Mics, dtype=Y.dtype, device=Y.device)
    loading = eps * eye                          # diagonal loading for the inverse
    Phi = eye.expand(B, n_freq, Mics, Mics).clone()   # (B, F, M, M)
    X_hat = []

    for t in range(T):
        n = N_hat[..., t].permute(0, 2, 1)                             # (B, F, M)
        Phi_inst = torch.einsum("bfi,bfj->bfij", n, n.conj())          # (B, F, M, M)
        Phi = alpha_smooth * Phi + (1 - alpha_smooth) * Phi_inst

        try:
            Phi_inv = torch.linalg.inv(Phi + loading)
        except RuntimeError:
            Phi_inv = torch.linalg.pinv(Phi + loading)

        numerator = torch.einsum("bfij,bfj->bfi", Phi_inv, d)          # Phi^-1 d
        denom = torch.einsum("bfi,bfi->bf", d.conj(), numerator).real + eps
        w = numerator / denom[..., None]                               # (B, F, M)

        y = Y[..., t].permute(0, 2, 1)                                 # (B, F, M)
        X_hat.append(torch.einsum("bfi,bfi->bf", w.conj(), y))         # w^H y, (B, F)

    return torch.stack(X_hat, dim=-1)         # (B, F, T)

def complex_istft(X, n_fft=512, hop=256):
    """
    Inverse STFT of a batch of one-sided spectrograms, computed in torch so gradients flow.

    X : (B, F, T) complex tensor -> (B, hop*(T-1)) real tensor on X's device
        (Hann window, centred frames).

    The previous version went through librosa via `.cpu().numpy()`. That raises on
    tensors that require grad and, in any case, cut the autograd graph for the
    SI-SDR loss. If F does not equal n_fft//2 + 1, the FFT size is inferred from F
    (n_fft = 2*(F-1)), as librosa.istft does.
    """
    if X.shape[-2] != n_fft // 2 + 1:
        n_fft = 2 * (X.shape[-2] - 1)
    window = torch.hann_window(n_fft, device=X.device, dtype=X.real.dtype)
    return torch.istft(X, n_fft=n_fft, hop_length=hop, window=window, center=True)

def si_sdr(est, ref, eps=1e-8):
    """Scale-invariant SDR in dB along the last axis; est, ref: (B, L) -> (B,)."""
    scale = (est * ref).sum(dim=-1, keepdim=True) / (ref.pow(2).sum(dim=-1, keepdim=True) + eps)
    projection = scale * ref                 # component of est along ref
    residual = est - projection
    signal_power = projection.pow(2).sum(dim=-1) + eps
    noise_power = residual.pow(2).sum(dim=-1) + eps
    return 10 * torch.log10(signal_power / noise_power)

In [None]:
class MixDataset(torch.utils.data.Dataset):
    """(mix, clean, meta) examples from mix_*.wav / clean_*.wav / meta_*.json in `data_dir`.

    __getitem__ returns the mixture as (channels, samples) float32, the clean signal
    as float32, and the meta dict.

    Raises ValueError at construction if the three file lists do not share the same
    example indices. Files are paired by sorted position, so equal counts alone would
    not catch e.g. mix_00003 being paired with clean_00004.
    """
    def __init__(self, data_dir):
        root = Path(data_dir)
        self.mix_paths   = sorted(root.glob("mix_*.wav"))
        self.clean_paths = sorted(root.glob("clean_*.wav"))
        self.meta_paths  = sorted(root.glob("meta_*.json"))
        ids = [[p.stem.split("_", 1)[1] for p in paths]
               for paths in (self.mix_paths, self.clean_paths, self.meta_paths)]
        if not (ids[0] == ids[1] == ids[2]):
            raise ValueError(f"mix/clean/meta files under {root} do not pair up by index")

    def __len__(self): return len(self.mix_paths)

    def __getitem__(self, idx):
        mix, _   = sf.read(self.mix_paths[idx]);  mix = mix.T.astype(np.float32)
        clean, _ = sf.read(self.clean_paths[idx])
        # `with` closes the handle; json.load(open(...)) leaked one per item.
        with open(self.meta_paths[idx]) as fh:
            meta = json.load(fh)
        return mix, clean.astype(np.float32), meta

def collate_fn(batch):
    """Split a batch of (mix, clean, meta) tuples into three order-preserving lists."""
    columns = zip(*batch)
    mixes, cleans, metas = columns
    return list(mixes), list(cleans), list(metas)

# NOTE(review): make_sim_dataset writes to ".../audio_dataset/10s"; this reads the
# parent directory instead — confirm which dataset is intended.
ds = MixDataset("/content/drive/MyDrive/audio_dataset")
loader = torch.utils.data.DataLoader(
    ds,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
)

In [None]:
net = UNetTCNMaskNet().to(device)
opt = torch.optim.Adam(net.parameters(), lr=2e-3)
steer = alpha_src.to(device)                   # (F, M) target steering vector on the model's device

for epoch in range(2):                         # demo 2 epochs
    net.train(); running = 0.0
    for mix, clean, meta in tqdm(loader, desc=f"Epoch {epoch}"):
        feats_b, Y_b, Xc_b, clean_b = [], [], [], []
        for m, c, md in zip(mix, clean, meta):
            feat, Ym = make_feature_tensor(m, md)
            feats_b.append(feat)
            Y_b.append(Ym)
            # Clean mic-0 reference STFT, same parameters as the mixture STFT
            Xc_b.append(torch.from_numpy(
                librosa.stft(c, n_fft=512, hop_length=256, window="hann")))
            clean_b.append(torch.from_numpy(c))

        feats      = torch.stack(feats_b).float().to(device)   # [B,6,F,T]
        Y          = torch.stack(Y_b).to(device)               # [B,4,F,T]
        X_clean    = torch.stack(Xc_b).to(device)              # [B,F,T]
        clean_wave = torch.stack(clean_b).to(device)           # [B,L]

        mask = net(feats)                                      # [B,2,Fm,Tm]

        # ── align mask, mixture and clean STFTs on their common (F, T) extent
        F_common = min(mask.shape[-2], Y.shape[-2], X_clean.shape[-2])
        T_common = min(mask.shape[-1], Y.shape[-1], X_clean.shape[-1])
        mask    = mask[:, :, :F_common, :T_common]
        Y       = Y[:, :, :F_common, :T_common]
        X_clean = X_clean[:, :F_common, :T_common]

        M = torch.complex(mask[:, 0], mask[:, 1])              # [B,F,T]

        X_hat = mvdr_framewise(Y, M, steer[:F_common])         # (B,F,T)
        x_hat = complex_istft(X_hat)                           # (B,L')

        # compare waveforms over their common length
        n_samp = min(x_hat.shape[-1], clean_wave.shape[-1])
        loss_si = -si_sdr(x_hat[..., :n_samp], clean_wave[..., :n_samp]).mean()

        # TF-domain MSE of the masked reference mic against the *clean* reference STFT.
        # The previous target was the mixture Y[:,0] itself, which only pushed the mask towards 1.
        S0 = M * Y[:, 0]
        loss_tf = F.mse_loss(S0.real, X_clean.real) + F.mse_loss(S0.imag, X_clean.imag)

        loss = 0.5*loss_si + 0.5*loss_tf
        opt.zero_grad(); loss.backward(); opt.step()
        running += loss.item()

    print(f"Epoch {epoch}  avg loss {running/len(loader):.4f}")

torch.save(net.state_dict(), "/content/drive/MyDrive/audio_dataset/masknet_mvdr.pth")
print("✓ model + MVDR training complete")

Epoch 0:   0%|          | 0/50 [00:01<?, ?it/s]


RuntimeError: The size of tensor a (4) must match the size of tensor b (257) at non-singleton dimension 3