In [6]:
# =======================
#  Google Drive Setup
# =======================
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [8]:
#RUN THIS CODE IF TARGET FILES ARE IN FLAC FORMAT
import os
import soundfile as sf
from tqdm import tqdm

# -----------------------------
#  CONFIGURE PATHS
# -----------------------------
PF_ROOT = "/content/drive/MyDrive/pf_dataset"   # üëà your PF dataset root
BACKUP_FLAC = False                              # set False to delete original FLACs

# -----------------------------
#  MAIN CONVERSION LOOP
# -----------------------------
count = 0
folders = [os.path.join(PF_ROOT, d) for d in os.listdir(PF_ROOT) if os.path.isdir(os.path.join(PF_ROOT, d))]

for folder in tqdm(folders, desc="Converting FLAC ‚Üí WAV"):
    flac_path = os.path.join(folder, "target.flac")
    wav_path  = os.path.join(folder, "target.wav")

    if os.path.isfile(flac_path):
        try:
            data, fs = sf.read(flac_path, dtype="float32")
            sf.write(wav_path, data, fs)
            count += 1

            if not BACKUP_FLAC:
                os.remove(flac_path)
        except Exception as e:
            print(f"‚ùå Error converting {flac_path}: {e}")

print(f"\n‚úÖ Converted {count} target.flac ‚Üí target.wav successfully.")
if BACKUP_FLAC:
    print("‚ÑπÔ∏è Original FLAC files are kept. Set BACKUP_FLAC=False to delete them.")


Converting FLAC ‚Üí WAV: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:01<00:00,  2.19it/s]


‚úÖ Converted 2 target.flac ‚Üí target.wav successfully.
‚ÑπÔ∏è Original FLAC files are kept. Set BACKUP_FLAC=False to delete them.





In [2]:
import os, json, math, gc, glob
import numpy as np
import soundfile as sf
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


In [9]:

DATA_ROOT = '/content/drive/MyDrive/pf_dataset'
OUT_DIR   = '/content/drive/MyDrive/ckpts_pf'

# =======================
#  PostFilter Training Code
# =======================
import os, json, math, gc, glob
import numpy as np
import soundfile as sf
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

FS   = 16000
NFFT = 512
HOP  = 256
EPS  = 1e-8

# ---------- STFT helpers ----------
_WIN_CACHE = {}
def _hann(device):
    key = (device, NFFT)
    if key not in _WIN_CACHE:
        _WIN_CACHE[key] = torch.hann_window(NFFT, device=device)
    return _WIN_CACHE[key]

def stft(x):  return torch.stft(x, n_fft=NFFT, hop_length=HOP, window=_hann(x.device), return_complex=True)
def istft(Xc, length): return torch.istft(Xc, n_fft=NFFT, hop_length=HOP, window=_hann(Xc.device), length=length)

# ---------- Losses ----------
def si_sdr_loss(estimate, target, eps=1e-8):
    t_energy = torch.sum(target**2, dim=1, keepdim=True) + eps
    scale = torch.sum(estimate * target, dim=1, keepdim=True) / t_energy
    s_target = scale * target
    e_noise  = estimate - s_target
    num = torch.sum(s_target**2, dim=1) + eps
    den = torch.sum(e_noise**2,  dim=1) + eps
    return -10.0 * torch.log10(num / den + eps).mean()

def stft_mag_loss(y_hat, y, w=0.1):
    Yh, Yt = stft(y_hat), stft(y)
    return w * (Yh.abs() - Yt.abs()).abs().mean()

# ---------- Dataset ----------
class PFDataset(Dataset):
    def __init__(self, root, split="train", split_file=None):
        self.root = root
        if split_file and os.path.isfile(split_file):
            with open(split_file, "r") as f: sp = json.load(f)
            self.ids = sp[split]
        else:
            all_ids = sorted([d for d in os.listdir(root) if d.startswith("data")])
            k = int(0.9 * len(all_ids))
            self.ids = all_ids[:k] if split=="train" else all_ids[k:]
    def __len__(self): return len(self.ids)
    def __getitem__(self, idx):
      d = self.ids[idx]
      sub = os.path.join(self.root, d)
      x_path = os.path.join(sub, "nb_out.wav")
      y_path = os.path.join(sub, "target.wav")
      if not (os.path.isfile(x_path) and os.path.isfile(y_path)):
          # also try FLAC fallback if needed
          if os.path.isfile(os.path.join(sub, "nb_out.flac")):
              x_path = os.path.join(sub, "nb_out.flac")
          if os.path.isfile(os.path.join(sub, "target.flac")):
              y_path = os.path.join(sub, "target.flac")
          if not (os.path.isfile(x_path) and os.path.isfile(y_path)):
              return None

      # Always get a 2-D array [T, C]
      x_np, fs1 = sf.read(x_path, dtype="float32", always_2d=True)
      y_np, fs2 = sf.read(y_path, dtype="float32", always_2d=True)

      # Optional: resample to FS if mismatched
      if fs1 != FS or fs2 != FS:
          try:
              import librosa
              if fs1 != FS:
                  x_np = librosa.resample(x_np.T, orig_sr=fs1, target_sr=FS).T
                  fs1 = FS
              if fs2 != FS:
                  y_np = librosa.resample(y_np.T, orig_sr=fs2, target_sr=FS).T
                  fs2 = FS
          except Exception:
              return None

      # Collapse channels to mono: mean across channel axis (axis=1)
      if x_np.shape[1] > 1:
          x_np = np.mean(x_np, axis=1, keepdims=False)  # -> [T]
      else:
          x_np = x_np[:, 0]  # -> [T]
      if y_np.shape[1] > 1:
          y_np = np.mean(y_np, axis=1, keepdims=False)  # -> [T]
      else:
          y_np = y_np[:, 0]  # -> [T]

      # Length match and minimal length guard
      L = min(len(x_np), len(y_np))
      if L < 1024:  # skip too-short pairs; STFT needs enough samples
          return None

      x = torch.from_numpy(x_np[:L])  # [T]
      y = torch.from_numpy(y_np[:L])  # [T]
      return {"nb": x, "tgt": y, "length": L}


def collate_pad(batch):
    batch = [b for b in batch if b is not None]
    if not batch: return None
    Lmax=max(b["length"] for b in batch)
    xs,ys=[],[]
    for b in batch:
        pad=Lmax-b["length"]
        xs.append(torch.cat([b["nb"], torch.zeros(pad)]) if pad>0 else b["nb"])
        ys.append(torch.cat([b["tgt"],torch.zeros(pad)]) if pad>0 else b["tgt"])
    return {"nb":torch.stack(xs,0), "tgt":torch.stack(ys,0), "length":Lmax}

# ---------- Model ----------
class ConvBlock(nn.Module):
    def __init__(self,ch,dilation=1):
        super().__init__()
        self.net=nn.Sequential(
            nn.Conv2d(ch,ch,3,padding=dilation,dilation=dilation),
            nn.BatchNorm2d(ch), nn.LeakyReLU(0.1,True),
            nn.Conv2d(ch,ch,3,padding=1), nn.BatchNorm2d(ch))
        self.act=nn.LeakyReLU(0.1,True)
    def forward(self,x): return self.act(self.net(x)+x)

class PostFilterNet(nn.Module):
    def __init__(self, n_fft=512, hop=256, hidden=128, enc_ch=64, gru_layers=1):
        super().__init__()
        self.n_fft = n_fft
        self.hop = hop

        self.enc = nn.Sequential(
            nn.Conv2d(2, enc_ch, 3, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            ConvBlock(enc_ch, 1),
            ConvBlock(enc_ch, 2),
            ConvBlock(enc_ch, 4)
        )
        self.proj = nn.Conv2d(enc_ch, hidden, 1)
        self.gru = nn.GRU(hidden, hidden, num_layers=gru_layers,
                          batch_first=True, bidirectional=True)
        self.head = nn.Conv2d(2*hidden, 2, 1)
        self.tanh = nn.Tanh()
    def forward(self, x):
      print("HI")
      if x.ndim == 1:
        x = x.unsqueeze(0)
      elif x.ndim == 3:
        x = x.squeeze(1)
        print(x)
      elif x.ndim != 2:
        raise ValueError(f"Unexpected input shape {x.shape}, expected [B,T] or [B,1,T]")
      print(x.ndim)
      B, T = x.shape

      # --- 2Ô∏è‚É£ Safety guard for short audio ---
      if T < self.n_fft:
          pad = self.n_fft - T
          x = torch.nn.functional.pad(x, (0, pad))   # zero-pad to window size
          T = x.shape[1]

      # --- 3Ô∏è‚É£ Compute STFT ---
      X = stft(x)                         # [B,F,Tf] complex
      F, Ts = X.shape[1], X.shape[2]

      # --- 4Ô∏è‚É£ Real/imag stacking for Conv2D ---
      Xri = torch.stack([X.real, X.imag], dim=1)     # [B,2,F,Ts]
      h = self.enc(Xri)                              # [B,C,F,Ts]
      h = self.proj(h)                               # [B,H,F,Ts]

      # --- 5Ô∏è‚É£ GRU along time per frequency ---
      h = h.permute(0,2,3,1).reshape(B*F, Ts, -1)    # [B*F,Ts,H]
      y,_ = self.gru(h)                              # [B*F,Ts,2H]
      y = y.reshape(B, F, Ts, -1).permute(0,3,1,2)   # [B,2H,F,Ts]

      # --- 6Ô∏è‚É£ Complex mask prediction ---
      Mri = self.head(y)                             # [B,2,F,Ts]
      M = self.tanh(Mri[:,0]) + 1j*self.tanh(Mri[:,1])  # complex mask ‚àà (-1,1)

      # --- 7Ô∏è‚É£ Apply mask and iSTFT ---
      Y = torch.conj(M) * X                          # filtered STFT
      y_hat = istft(Y, length=T)                     # [B,T]

      # --- 8Ô∏è‚É£ Final cleanup ---
      y_hat = torch.nan_to_num(y_hat, nan=0.0, posinf=0.0, neginf=0.0)
      return y_hat


# ---------- Train ----------
def train_pf(data_root=DATA_ROOT,out_dir=OUT_DIR,epochs=20,batch=4,lr=2e-4,hidden=128,enc_ch=64):
    torch.manual_seed(42)
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:",device)

    tr_ds,va_ds=PFDataset(data_root,"train"),PFDataset(data_root,"val")
    tr_dl = DataLoader(
          dataset=tr_ds,
          batch_size=batch,
          shuffle=(len(tr_ds) > 1),   # shuffle only if multiple samples
          num_workers=2,
          pin_memory=True,
          collate_fn=collate_pad,
          drop_last=False
      )

    va_dl = DataLoader(
        dataset=va_ds,
        batch_size=batch,
        shuffle=False,              # validation never shuffled
        num_workers=2,
        pin_memory=True,
        collate_fn=collate_pad,
        drop_last=False
      )

    net=PostFilterNet(hidden=hidden,enc_ch=enc_ch).to(device)
    opt=torch.optim.Adam(net.parameters(),lr=lr,weight_decay=1e-6)
    sched=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,"min",0.5,3)
    scaler=torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))
    os.makedirs(out_dir,exist_ok=True)
    best_val=float('inf')

    for ep in range(1,epochs+1):
        net.train(); tr_loss=0; nbatch=0
        for b in tqdm(tr_dl,desc=f"Train {ep}/{epochs}",ncols=80):
            if b is None: continue
            nb=b["nb"].to(device); tgt=b["tgt"].to(device)
            with torch.cuda.amp.autocast(dtype=torch.float16,enabled=(device.type=="cuda")):
                y_hat=net(nb)
                loss=0.7*si_sdr_loss(y_hat,tgt)+0.3*torch.mean(torch.abs(y_hat-tgt))
                loss+=stft_mag_loss(y_hat,tgt,0.05)
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
            tr_loss+=float(loss.detach().cpu()); nbatch+=1
        tr_loss/=max(1,nbatch)

        net.eval(); va_loss=0; n_val=0
        with torch.no_grad():
            for b in va_dl:
                if b is None: continue
                nb=b["nb"].to(device); tgt=b["tgt"].to(device)
                y_hat=net(nb)
                loss=0.7*si_sdr_loss(y_hat,tgt)+0.3*torch.mean(torch.abs(y_hat-tgt))
                loss+=stft_mag_loss(y_hat,tgt,0.05)
                va_loss+=float(loss.detach().cpu()); n_val+=1
        va_loss/=max(1,n_val)
        print(f"Epoch {ep}: train {tr_loss:.4f} | val {va_loss:.4f}")
        sched.step(va_loss)

        ckpt={"epoch":ep,"net":net.state_dict(),"opt":opt.state_dict(),"sched":sched.state_dict(),"val":va_loss}
        torch.save(ckpt, os.path.join(out_dir,"pf_last.pt"))
        if va_loss<best_val:
            best_val=va_loss
            torch.save(ckpt, os.path.join(out_dir,"pf_best.pt"))
            print("‚úî Saved new best model")

In [50]:
train_pf(
    data_root=DATA_ROOT,
    out_dir=OUT_DIR,
    epochs=6,      # bump up when you have more data
    batch=4,       # small dataset ‚Üí batch 1
    lr=2e-4,
    hidden=128,
    enc_ch=64
)


  scaler=torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))


Device: cpu


Train 1/6:   0%|                                          | 0/1 [00:00<?, ?it/s]

HI
2


  with torch.cuda.amp.autocast(dtype=torch.float16,enabled=(device.type=="cuda")):
Train 1/6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:05<00:00,  5.91s/it]


HI
2
Epoch 1: train 0.0418 | val 7.8540
‚úî Saved new best model


Train 2/6:   0%|                                          | 0/1 [00:00<?, ?it/s]

HI
2


Train 2/6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:04<00:00,  4.67s/it]


HI
2
Epoch 2: train -1.2769 | val 6.5628
‚úî Saved new best model


Train 3/6:   0%|                                          | 0/1 [00:00<?, ?it/s]

HI
2


Train 3/6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:03<00:00,  3.49s/it]


HI
2
Epoch 3: train -2.8490 | val 5.2929
‚úî Saved new best model


Train 4/6:   0%|                                          | 0/1 [00:00<?, ?it/s]

HI
2


Train 4/6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:03<00:00,  3.80s/it]


HI
2
Epoch 4: train -3.0349 | val 4.8451
‚úî Saved new best model


Train 5/6:   0%|                                          | 0/1 [00:00<?, ?it/s]

HI
2


Train 5/6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:03<00:00,  3.63s/it]


HI
2
Epoch 5: train -3.0511 | val 4.8586


Train 6/6:   0%|                                          | 0/1 [00:00<?, ?it/s]

HI
2


Train 6/6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:03<00:00,  3.56s/it]


HI
2
Epoch 6: train -3.4858 | val 5.1780


In [13]:
@torch.no_grad()
def eval_one_no_target(sample_dir, ckpt_path, out_path=None):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x_path = os.path.join(sample_dir, "nb_out.wav")

    # 1. Load the beamformer output
    x, fs = sf.read(x_path, dtype="float32")
    if x.ndim == 2:
    # average the two stereo channels
      x = np.mean(x, axis=1)
    L = len(x)
    x_t = torch.from_numpy(x[:L]).unsqueeze(0).to(device)

    # 2. Load the PF model
    ck = torch.load(ckpt_path, map_location=device)
    net = PostFilterNet().to(device)
    net.load_state_dict(ck["net"])
    net.eval()

    # 3. Enhance the audio
    y_hat = net(x_t).squeeze(0).cpu().numpy()

    # 4. Save the enhanced output
    if out_path is None:
        out_path = os.path.join(sample_dir, "pf_enhanced.wav")
    sf.write(out_path, y_hat, fs)
    print(f"‚úÖ Enhanced audio saved ‚Üí {out_path}")


In [14]:
eval_one_no_target(
    sample_dir = "/content/drive/MyDrive/pf_dataset/data03",
    ckpt_path  = "/content/drive/MyDrive/ckpts_pf/pf_best.pt"
)


HI
2
‚úÖ Enhanced audio saved ‚Üí /content/drive/MyDrive/pf_dataset/data03/pf_enhanced.wav
