In [None]:
!pip install -U datasets[audio]
!pip install torchaudio
!pip install torchcodec

In [None]:
# This cell previously held the incomplete statement `from gg` (a SyntaxError that
# broke Restart & Run All). main() writes checkpoints under /content/drive/MyDrive,
# so the intent was presumably to mount Google Drive on Colab — confirm.
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    # Not on Colab: os.makedirs in main() will create the path locally instead.
    pass

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import wandb
import random
import os
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

In [None]:
# Authenticate with Weights & Biases; main() later calls wandb.init() / wandb.log().
# NOTE(review): this may prompt interactively unless an API key is already configured
# (e.g. via the WANDB_API_KEY environment variable) — confirm for unattended runs.
wandb.login()

In [None]:
def set_seed(seed=42):
    """Seed every RNG this notebook touches and force deterministic cuDNN.

    Args:
        seed: integer applied to Python's ``random``, NumPy's legacy global RNG,
            and torch (CPU, the current CUDA device, and all CUDA devices).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def init_weights(m):
    """Initialise a single module in place; meant for ``model.apply(init_weights)``.

    Conv1d / Conv2d / ConvTranspose2d weights: Kaiming-normal (fan_out, ReLU gain).
    Linear weights: Xavier-uniform. Biases of those layers are set to zero.
    Every other module type is left untouched.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d)
    if isinstance(m, conv_types):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
    else:
        return
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

In [None]:
class CausalConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> PReLU with causal padding along the last (time) axis.

    The time axis is left-padded with ``k_t - 1`` zero frames and never right-padded,
    so each output frame sees only current and past input frames. The second-to-last
    (frequency) axis is padded symmetrically by ``padding[0]``; ``padding[1]`` is
    accepted for signature symmetry but ignored.
    """
    def __init__(self, in_c, out_c, kernel_size, stride, padding):
        super().__init__()
        _, kernel_t = kernel_size
        pad_f, _ = padding
        self.time_pad = kernel_t - 1
        self.freq_pad = pad_f
        # forward() does all the padding, so the convolution itself pads nothing.
        self.conv = nn.Conv2d(in_c, out_c, kernel_size, stride, padding=(0, 0))
        self.bn = nn.BatchNorm2d(out_c)
        self.act = nn.PReLU()

    def forward(self, x):
        # F.pad order: (time_left, time_right, freq_top, freq_bottom).
        padded = F.pad(x, (self.time_pad, 0, self.freq_pad, self.freq_pad))
        return self.act(self.bn(self.conv(padded)))

class GatedTrConvBlock(nn.Module):
    """Gated skip fusion followed by a transposed convolution trimmed to stay causal.

    forward(x, skip): concatenates ``x`` and ``skip`` along channels, scales the result
    by a tanh gate from a 1x1 conv, applies ConvTranspose2d, drops the trailing
    ``k_t - 1`` time frames, then BatchNorm2d + PReLU. ``in_c`` must equal the channel
    count of ``x`` plus that of ``skip``. The trim keeps the output causal assuming a
    time stride of 1 and zero time padding, which is how this file uses it.
    """
    def __init__(self, in_c, out_c, kernel_size, stride, padding, output_padding=(0,0)):
        super().__init__()
        _, kernel_t = kernel_size
        # The transposed conv lengthens time by k_t - 1 frames; forward() removes them.
        self.chomp_t = kernel_t - 1
        self.gate_conv = nn.Sequential(nn.Conv2d(in_c, in_c, 1), nn.Tanh())
        self.tr_conv = nn.ConvTranspose2d(in_c, out_c, kernel_size, stride, padding, output_padding)
        self.bn = nn.BatchNorm2d(out_c)
        self.act = nn.PReLU()

    def forward(self, x, skip):
        fused = torch.cat([x, skip], dim=1)
        fused = fused * self.gate_conv(fused)
        y = self.tr_conv(fused)
        # Guard k_t == 1: slicing with :-0 would yield an empty tensor.
        if self.chomp_t > 0:
            y = y[..., :-self.chomp_t]
        return self.act(self.bn(y))

class FT_GRU_Block(nn.Module):
    """Dual-path recurrent block over a [B, C, F, T] feature map (shape preserved).

    1. Frequency path: a bidirectional GRU scans the F axis separately for every
       (batch, frame); a Linear maps it back to C channels and it is added residually.
    2. Time path: a unidirectional GRU scans the T axis separately for every
       (batch, bin), so this path is causal in time; also added residually.
    BatchNorm2d + PReLU close the block.
    """
    def __init__(self, in_c, hidden_f, hidden_t):
        super().__init__()
        self.f_gru = nn.GRU(in_c, hidden_f, batch_first=True, bidirectional=True)
        self.f_linear = nn.Linear(hidden_f * 2, in_c)  # x2: forward + backward states
        self.t_gru = nn.GRU(in_c, hidden_t, batch_first=True, bidirectional=False)
        self.t_linear = nn.Linear(hidden_t, in_c)
        self.bn = nn.BatchNorm2d(in_c)
        self.act = nn.PReLU()

    def forward(self, x):
        batch, chans, n_freq, n_time = x.shape

        # Frequency path: B*T sequences of length F.
        seq_f = x.permute(0, 3, 2, 1).reshape(batch * n_time, n_freq, chans)
        res_f = self.f_linear(self.f_gru(seq_f)[0])
        x = x + res_f.reshape(batch, n_time, n_freq, chans).permute(0, 3, 2, 1)

        # Time path: B*F sequences of length T.
        seq_t = x.permute(0, 2, 3, 1).reshape(batch * n_freq, n_time, chans)
        res_t = self.t_linear(self.t_gru(seq_t)[0])
        x = x + res_t.reshape(batch, n_freq, n_time, chans).permute(0, 3, 1, 2)

        return self.act(self.bn(x))

In [None]:
class VADModule(nn.Module):
    """Voice-activity head producing 2-class logits per frame.

    Input [B, in_channels, freq_bins, T] -> 1x1 conv to 16 channels -> flatten
    (channel, freq) into one vector per frame -> bidirectional GRU over frames
    (8 hidden units per direction = 16 features) -> 1x1 Conv1d stack -> logits [B, 2, T].
    ``freq_bins`` must equal the bottleneck frequency size.
    NOTE(review): the GRU is bidirectional along time, so this head is non-causal even
    though the encoder/decoder are causal — confirm that is intended.
    """
    def __init__(self, in_channels=32, freq_bins=8): # Bottleneck Freq = 8 for 257 input
        super().__init__()
        squeeze = 16
        self.conv2d = nn.Sequential(nn.Conv2d(in_channels, squeeze, 1), nn.BatchNorm2d(squeeze), nn.PReLU())
        self.f_gru = nn.GRU(squeeze * freq_bins, 8, batch_first=True, bidirectional=True)
        self.conv1d_block = nn.Sequential(nn.Conv1d(squeeze, squeeze, 1), nn.BatchNorm1d(squeeze), nn.PReLU())
        self.conv1d_out = nn.Conv1d(squeeze, 2, 1)

    def forward(self, x):
        n_batch, _, _, n_time = x.shape
        feats = self.conv2d(x)                                            # [B, 16, F, T]
        frames = feats.permute(0, 3, 1, 2).reshape(n_batch, n_time, -1)   # [B, T, 16*F]
        seq_out = self.f_gru(frames)[0]                                   # [B, T, 16]
        hidden = self.conv1d_block(seq_out.transpose(1, 2))               # [B, 16, T]
        return self.conv1d_out(hidden)                                    # [B, 2, T]

In [None]:
class DeepFilterOp(nn.Module):
    """Parameter-free deep filtering: a per-(f, t) complex FIR over a local neighbourhood.

    Output (f, t) is the complex sum over bins f-N_f..f+N_f and frames t-N_t..t+N_l
    (N_l frames of look-ahead); neighbours outside the spectrogram are zero-padded.

    forward(coarse_spec, filters):
        coarse_spec: [B, 2, F, T] real/imag spectrogram.
        filters: 2 * num_neighbors channels per (f, t): the first num_neighbors are
            real taps, the rest imaginary taps, neighbours ordered frequency-major
            (the ordering produced by F.unfold).
        returns: [B, 2, F, T].
    """
    def __init__(self, N_f=3, N_t=3, N_l=1):
        super().__init__()
        self.N_f = N_f
        self.N_t = N_t
        self.N_l = N_l
        self.k_f = 2 * N_f + 1        # bins below + centre + bins above
        self.k_t = N_t + N_l + 1      # past frames + current + look-ahead
        self.num_neighbors = self.k_f * self.k_t

    def forward(self, coarse_spec, filters):
        n_batch, n_chan, n_freq, n_time = coarse_spec.shape
        padded = F.pad(coarse_spec, (self.N_t, self.N_l, self.N_f, self.N_f))
        taps = F.unfold(padded, kernel_size=(self.k_f, self.k_t))
        taps = taps.view(n_batch, n_chan, self.num_neighbors, n_freq, n_time)
        coef = filters.view(n_batch, 2, self.num_neighbors, n_freq, n_time)
        x_re, x_im = taps[:, 0], taps[:, 1]
        w_re, w_im = coef[:, 0], coef[:, 1]
        # (x_re + j x_im) * (w_re + j w_im), summed over the neighbour axis.
        y_re = (x_re * w_re - x_im * w_im).sum(dim=1)
        y_im = (x_re * w_im + x_im * w_re).sum(dim=1)
        return torch.stack((y_re, y_im), dim=1)

In [None]:
class CoarseStage(nn.Module):
    """Coarse stage: causal conv U-Net with a dual-path GRU bottleneck and a VAD head.

    Encoder: 8 CausalConvBlocks (a 257-bin input is reduced to 8 bins at the bottleneck).
    Bottleneck: two FT_GRU_Blocks; the VAD head reads the second one.
    Decoder: 8 GatedTrConvBlocks, decoder i fused with encoder output 7 - i.
    A 1x1 conv yields a 2-channel complex mask that multiplies the mic spectrum.

    forward(mic_cpr, ref_cpr, mic_spec_complex) -> (coarse_out [B, 2, F, T], vad_logits)
    """
    def __init__(self):
        super().__init__()
        # (in_c, out_c, kernel (F, T), stride, padding)
        enc_specs = [
            (2, 16, (5, 1), (1, 1), (2, 0)),
            (16, 16, (1, 5), (1, 1), (0, 0)),
            (16, 16, (6, 5), (2, 1), (2, 0)),
            (16, 32, (4, 3), (2, 1), (1, 0)),
            (32, 32, (6, 5), (2, 1), (2, 0)),
            (32, 32, (5, 3), (2, 1), (2, 0)),
            (32, 32, (3, 5), (2, 1), (1, 0)),
            (32, 32, (3, 3), (1, 1), (1, 0)),
        ]
        # (in_c = prev + skip channels, out_c, kernel, stride, padding, output_padding)
        dec_specs = [
            (32 + 32, 32, (3, 3), (1, 1), (1, 0), (0, 0)),
            (32 + 32, 32, (3, 5), (2, 1), (1, 0), (1, 0)),
            (32 + 32, 32, (5, 3), (2, 1), (2, 0), (1, 0)),
            (32 + 32, 32, (6, 5), (2, 1), (2, 0), (0, 0)),
            (32 + 32, 16, (4, 3), (2, 1), (1, 0), (0, 0)),
            (16 + 16, 16, (6, 5), (2, 1), (2, 0), (1, 0)),
            (16 + 16, 16, (1, 5), (1, 1), (0, 0), (0, 0)),
            (16 + 16, 16, (5, 1), (1, 1), (2, 0), (0, 0)),
        ]
        # Registration order (enc_*, gru_*, vad, dec_*, mask_conv) keeps state_dict keys stable.
        for idx, spec in enumerate(enc_specs):
            setattr(self, f"enc_{idx}", CausalConvBlock(*spec))
        self.gru_0 = FT_GRU_Block(32, 32, 64)
        self.gru_1 = FT_GRU_Block(32, 32, 32)
        self.vad = VADModule(32, freq_bins=8)
        for idx, spec in enumerate(dec_specs):
            setattr(self, f"dec_{idx}", GatedTrConvBlock(*spec))
        self.mask_conv = nn.Conv2d(16, 2, kernel_size=1)

    def forward(self, mic_cpr, ref_cpr, mic_spec_complex):
        h = torch.cat([mic_cpr, ref_cpr], dim=1)
        skips = []
        for idx in range(8):
            h = getattr(self, f"enc_{idx}")(h)
            skips.append(h)
        h = self.gru_1(self.gru_0(h))
        vad_out = self.vad(h)
        # Decoder idx consumes the mirrored encoder output: e7, e6, ..., e0.
        for idx, skip in enumerate(reversed(skips)):
            h = getattr(self, f"dec_{idx}")(h, skip)
        mask = self.mask_conv(h)
        mic_re, mic_im = mic_spec_complex[:, 0], mic_spec_complex[:, 1]
        mask_re, mask_im = mask[:, 0], mask[:, 1]
        coarse_out = torch.stack([mic_re * mask_re - mic_im * mask_im,
                                  mic_re * mask_im + mic_im * mask_re], dim=1)
        return coarse_out, vad_out

In [None]:
class FineStage(nn.Module):
    """Fine stage: a wider causal U-Net that predicts deep-filter taps for the coarse output.

    Input features are the mic, reference and coarse compressed magnitudes stacked to
    [B, 3, F, T]. The topology mirrors CoarseStage with 16/32/64 channels. The decoder
    output is mapped by a 1x1 conv to 2 * num_neighbors filter taps (real then imaginary),
    which DeepFilterOp applies to the complex coarse estimate.

    forward(mic_cpr, ref_cpr, coarse_cpr, coarse_out_complex) -> fine_out [B, 2, F, T]
    """
    def __init__(self):
        super().__init__()
        self.enc_0 = CausalConvBlock(3, 16, (5,1), (1,1), (2,0))
        self.enc_1 = CausalConvBlock(16, 16, (1,5), (1,1), (0,0))
        self.enc_2 = CausalConvBlock(16, 32, (6,5), (2,1), (2,0))
        self.enc_3 = CausalConvBlock(32, 32, (4,3), (2,1), (1,0))
        self.enc_4 = CausalConvBlock(32, 64, (6,5), (2,1), (2,0))
        self.enc_5 = CausalConvBlock(64, 64, (5,3), (2,1), (2,0))
        self.enc_6 = CausalConvBlock(64, 64, (3,5), (2,1), (1,0))
        self.enc_7 = CausalConvBlock(64, 64, (3,3), (1,1), (1,0))
        self.gru_0 = FT_GRU_Block(64, 64, 128)
        self.gru_1 = FT_GRU_Block(64, 64, 64)
        # Decoder i fuses the previous output with encoder output e(7-i): in_c = C(prev) + C(skip).
        self.dec_0 = GatedTrConvBlock(64+64, 64, (3,3), (1,1), (1,0))                        # g1 + e7
        self.dec_1 = GatedTrConvBlock(64+64, 64, (3,5), (2,1), (1,0), output_padding=(1,0))  # d0 + e6
        self.dec_2 = GatedTrConvBlock(64+64, 64, (5,3), (2,1), (2,0), output_padding=(1,0))  # d1 + e5
        # e4 comes from enc_4, which outputs 64 channels, so the fused input is 64 + 64 = 128.
        self.dec_3 = GatedTrConvBlock(64+64, 32, (6,5), (2,1), (2,0))                        # d2 + e4
        self.dec_4 = GatedTrConvBlock(32+32, 32, (4,3), (2,1), (1,0), output_padding=(0,0))  # d3 + e3
        self.dec_5 = GatedTrConvBlock(32+32, 16, (6,5), (2,1), (2,0), output_padding=(1,0))  # d4 + e2
        self.dec_6 = GatedTrConvBlock(16+16, 16, (1,5), (1,1), (0,0))                        # d5 + e1
        self.dec_7 = GatedTrConvBlock(16+16, 16, (5,1), (1,1), (2,0))                        # d6 + e0
        self.df_op = DeepFilterOp(N_f=3, N_t=3, N_l=1)
        # One real and one imaginary tap per neighbour (2 * 35 = 70 for the default op).
        self.df_conv = nn.Conv2d(16, 2 * self.df_op.num_neighbors, kernel_size=1)

    def forward(self, mic_cpr, ref_cpr, coarse_cpr, coarse_out_complex):
        x = torch.cat([mic_cpr, ref_cpr, coarse_cpr], dim=1)
        e0 = self.enc_0(x); e1 = self.enc_1(e0); e2 = self.enc_2(e1); e3 = self.enc_3(e2)
        e4 = self.enc_4(e3); e5 = self.enc_5(e4); e6 = self.enc_6(e5); e7 = self.enc_7(e6)
        g0 = self.gru_0(e7); g1 = self.gru_1(g0)
        d0 = self.dec_0(g1, e7); d1 = self.dec_1(d0, e6); d2 = self.dec_2(d1, e5); d3 = self.dec_3(d2, e4)
        d4 = self.dec_4(d3, e3); d5 = self.dec_5(d4, e2); d6 = self.dec_6(d5, e1); d7 = self.dec_7(d6, e0)
        df_coef = self.df_conv(d7)
        # The deep filter is applied to the complex coarse estimate, not to the mic signal.
        fine_out = self.df_op(coarse_out_complex, df_coef)
        return fine_out

In [None]:
class TSPNN(nn.Module):
    """Two-stage network: a coarse complex-masking stage followed by a deep-filter fine stage.

    Inputs are real/imag spectrograms [B, 2, F, T] for the microphone and the far-end
    reference. Both stages consume power-law compressed magnitudes (|X| ** alpha); the
    fine stage also sees the compressed coarse estimate.
    """
    def __init__(self):
        super().__init__()
        self.coarse_stage = CoarseStage()
        self.fine_stage = FineStage()
        self.alpha = 0.3  # magnitude compression exponent

    def compress(self, complex_spec):
        """Return the compressed magnitude |X| ** alpha: [B, 2, F, T] -> [B, 1, F, T].

        The 1e-8 inside the square root keeps the gradient finite where |X| == 0.
        """
        mag = torch.sqrt(complex_spec[:, 0]**2 + complex_spec[:, 1]**2 + 1e-8)
        mag_compressed = torch.pow(mag, self.alpha)
        return mag_compressed.unsqueeze(1)

    def forward(self, mic_complex, ref_complex):
        """Returns (fine_out [B, 2, F, T], vad_logits [B, 2, T], coarse_out [B, 2, F, T])."""
        mic_cpr = self.compress(mic_complex)
        ref_cpr = self.compress(ref_complex)
        # The VAD head returns unnormalised logits (they are fed to cross-entropy).
        coarse_out, vad_logits = self.coarse_stage(mic_cpr, ref_cpr, mic_complex)
        coarse_cpr = self.compress(coarse_out)
        fine_out = self.fine_stage(mic_cpr, ref_cpr, coarse_cpr, coarse_out)
        return fine_out, vad_logits, coarse_out

### Datasets

In [None]:
class AEC_Dataset(Dataset):
    """Map-style dataset yielding time-aligned STFTs and frame-level VAD labels.

    Each ``hf_dataset[idx]`` must provide ``item['mic']['array']``, ``item['ref']['array']``
    and ``item['clean']['array']`` (NumPy waveforms) plus ``item['vad_label']``.
    Waveforms are cropped/padded to ``duration`` seconds, assuming 16 kHz audio, using ONE
    shared crop offset per item so mic, reference and clean stay sample-aligned.

    ``__getitem__`` returns ``(mic, ref, clean, vad)``: three [2, n_fft // 2 + 1, T]
    real/imag spectrograms and a LongTensor of T VAD labels.
    """
    def __init__(self, hf_dataset, n_fft=512, win_length=320, hop_length=160, duration=8):
        # Note: Paper uses duration 8s
        self.dataset = hf_dataset
        self.n_fft, self.win_length, self.hop_length = n_fft, win_length, hop_length
        self.max_len = int(16000 * duration)
        self.window = torch.hann_window(self.win_length)

    def __len__(self):
        return len(self.dataset)

    def process_len(self, wav_array, start=None):
        """Crop or zero-pad a waveform to exactly ``max_len`` samples.

        Args:
            wav_array: 1-D (or [C, N]) NumPy array of samples.
            start: crop offset in samples. ``None`` draws a random offset when the signal
                is longer than ``max_len`` (0 otherwise). Pass the offset returned for one
                signal to crop its companion signals identically.
        Returns:
            ``(wav, start)``: a float tensor [C, max_len] and the offset that was used.
        """
        wav = torch.from_numpy(wav_array).float()
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)
        length = wav.shape[1]
        if start is None:
            start = random.randint(0, length - self.max_len) if length > self.max_len else 0
        wav = wav[:, start:start + self.max_len]
        if wav.shape[1] < self.max_len:
            wav = torch.nn.functional.pad(wav, (0, self.max_len - wav.shape[1]))
        return wav, start

    def get_stft(self, wav):
        """STFT of a [1, N] waveform as a real tensor [2, F, T] (real, imag)."""
        stft_complex = torch.stft(wav, n_fft=self.n_fft, hop_length=self.hop_length,
                                  win_length=self.win_length, window=self.window,
                                  center=True, return_complex=True)
        stft_real_imag = torch.view_as_real(stft_complex).squeeze(0)  # [F, T, 2]
        return stft_real_imag.permute(2, 0, 1)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # Crop all three signals with the mic's offset; independent random crops would
        # misalign mic, echo reference and clean target for clips longer than max_len.
        mic_wav, start = self.process_len(item['mic']['array'])
        ref_wav, _ = self.process_len(item['ref']['array'], start)
        clean_wav, _ = self.process_len(item['clean']['array'], start)

        mic_spec = self.get_stft(mic_wav)
        ref_spec = self.get_stft(ref_wav)
        clean_spec = self.get_stft(clean_wav)

        # Align the VAD labels with the STFT frames of the cropped segment.
        # NOTE(review): assumes vad_label holds one label per hop_length-sample frame — confirm.
        num_frames = mic_spec.shape[2]
        start_frame = start // self.hop_length
        vad_full = torch.tensor(item['vad_label'], dtype=torch.long)
        vad_label = vad_full[start_frame:start_frame + num_frames]
        if vad_label.shape[0] < num_frames:
            vad_label = torch.nn.functional.pad(vad_label, (0, num_frames - vad_label.shape[0]))

        # Returns: [2, F, T] for audio, [T] for VAD
        return mic_spec, ref_spec, clean_spec, vad_label

### Loss Function

In [None]:
def tspnn_loss(est_fine, est_coarse, target, vad_logits, vad_labels, w=0.3, beta=0.06):
    """Combined TSPNN training loss (paper Eq. 7 and 9).

        L = [w * L_coarse + (1 - w) * L_fine] + beta * L_vad

    L_coarse / L_fine: phase-aware MAE (Eq. 6) of each [B, 2, F, T] estimate against
    ``target`` — L1 on magnitude plus L1 on the real and imaginary parts.
    L_vad: cross-entropy of ``vad_logits`` (class axis 1) against integer ``vad_labels``.

    Returns:
        (total_loss, loss_coarse, loss_fine, loss_vad)
    """
    def _magnitude(spec):
        return torch.sqrt(spec[:, 0] ** 2 + spec[:, 1] ** 2 + 1e-8)

    def phase_aware_mae(pred, tgt):
        mag_term = F.l1_loss(_magnitude(pred), _magnitude(tgt))
        real_term = F.l1_loss(pred[:, 0], tgt[:, 0])
        imag_term = F.l1_loss(pred[:, 1], tgt[:, 1])
        return mag_term + real_term + imag_term

    loss_coarse = phase_aware_mae(est_coarse, target)
    loss_fine = phase_aware_mae(est_fine, target)
    loss_vad = F.cross_entropy(vad_logits, vad_labels)

    weighted_main = w * loss_coarse + (1 - w) * loss_fine
    return weighted_main + beta * loss_vad, loss_coarse, loss_fine, loss_vad

### Visualize

In [None]:
def stft_to_mag_db(stft_tensor):
    """Convert a real/imag spectrogram [2, F, T] into magnitude in dB, shape [F, T].

    Small epsilons keep both the square root and the log10 finite at exact zeros.
    """
    real, imag = stft_tensor[0], stft_tensor[1]
    magnitude = (real ** 2 + imag ** 2 + 1e-9).sqrt()
    return torch.log10(magnitude + 1e-9) * 20

def visual_check_tspnn(model, device, step, val_dataset):
    """Log a 4-panel dB spectrogram comparison for ``val_dataset[0]`` to wandb.

    Panels, top to bottom: microphone input, coarse estimate, fine estimate, clean
    target. Runs in eval mode under ``torch.no_grad()`` and always switches the model
    back to train mode before returning.
    """
    model.eval()
    with torch.no_grad():
        mic, ref, clean, _vad = val_dataset[0]
        # The model expects a leading batch axis: [1, 2, F, T].
        mic_b = mic.unsqueeze(0).to(device)
        ref_b = ref.unsqueeze(0).to(device)
        clean_d = clean.to(device)

        fine_out, _vad_out, coarse_out = model(mic_b, ref_b)

        panels = [
            ("Microphone (Input)", stft_to_mag_db(mic_b[0])),
            ("Coarse Output (P-AEC)", stft_to_mag_db(coarse_out[0])),
            ("Fine Output (R-AEC)", stft_to_mag_db(fine_out[0])),
            ("Clean (Target)", stft_to_mag_db(clean_d)),
        ]

        fig, axs = plt.subplots(len(panels), 1, figsize=(10, 14), facecolor='white')
        for ax, (title, img) in zip(axs, panels):
            ax.imshow(img.cpu().numpy(), origin='lower', aspect='auto', cmap='magma')
            ax.set_title(title)
            ax.axis('off')

        plt.tight_layout()
        wandb.log({"Inference_Comparison": wandb.Image(fig)}, step=step)
        plt.close(fig)
    model.train()

In [None]:
def _evaluate(model, loader, device, w, beta):
    """Mean total TSPNN loss of ``model`` over ``loader`` without gradients.

    Leaves the model in eval mode; the caller is responsible for restoring train mode.
    """
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for v_mic, v_ref, v_clean, v_vad in loader:
            v_mic, v_ref, v_clean = v_mic.to(device), v_ref.to(device), v_clean.to(device)
            v_vad = v_vad.to(device)
            v_fine, v_logits, v_coarse = model(v_mic, v_ref)
            v_loss, _, _, _ = tspnn_loss(v_fine, v_coarse, v_clean, v_logits, v_vad, w=w, beta=beta)
            loss_sum += v_loss.item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return loss_sum / max(len(loader), 1)


def main(output_dir="/content/drive/MyDrive/AEC/TSPNN_Checkpoints"):
    """Train TSPNN on the HF AEC dataset with wandb logging and periodic checkpoints.

    Args:
        output_dir: directory receiving ``tspnn_step_<N>.pth`` state_dict checkpoints
            (the default is a Google Drive path as mounted on Colab).
    """
    CONFIG = {
        "repo_id": "PandaLT/microsoft-AEC-vad-dataset",
        "n_fft": 512,
        "win_length": 320,
        "hop_length": 160,
        "batch_size": 16,
        "lr": 1e-3,
        "epochs": 10,
        "w_loss": 0.3,
        "beta_loss": 0.06,
        "val_interval": 100,
        "seed": 42
    }

    os.makedirs(output_dir, exist_ok=True)

    set_seed(CONFIG['seed'])

    wandb.init(project="TSPNN_AEC_Implementation", config=CONFIG)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading Dataset from HuggingFace...")
    full_dataset = load_dataset(CONFIG['repo_id'], split='train')

    # Simple split: first 200 examples for validation, the rest for training.
    val_data = full_dataset.select(range(200))
    train_data = full_dataset.select(range(200, len(full_dataset)))

    params = {
        "n_fft": CONFIG["n_fft"],
        "win_length": CONFIG["win_length"],
        "hop_length": CONFIG["hop_length"],
        "duration": 8  # TSPNN paper uses 8s
    }

    train_ds = AEC_Dataset(train_data, **params)
    # NOTE(review): AEC_Dataset random-crops clips longer than `duration`, so validation
    # batches (and the visual check) can differ between evaluations — confirm acceptable.
    val_ds = AEC_Dataset(val_data, **params)

    train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=2)

    print("Initializing TSPNN Model...")
    model = TSPNN().to(device)
    model.apply(init_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
    # `verbose` is not passed: it is deprecated for LR schedulers since PyTorch 2.2 and
    # removed in later releases. LR changes are printed and logged explicitly below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    global_step = 0
    print("Starting Training...")

    for epoch in range(CONFIG['epochs']):
        model.train()
        epoch_loss = 0.0

        for mic, ref, clean, vad_labels in train_loader:
            # Inputs: [B, 2, F, T]
            mic, ref, clean = mic.to(device), ref.to(device), clean.to(device)
            vad_labels = vad_labels.to(device)

            fine_out, vad_logits, coarse_out = model(mic, ref)

            # Combined loss (paper Eq. 7 and 9).
            loss, l_coarse, l_fine, l_vad = tspnn_loss(
                fine_out, coarse_out, clean, vad_logits, vad_labels,
                w=CONFIG['w_loss'], beta=CONFIG['beta_loss']
            )

            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping is often helpful for GRUs
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            optimizer.step()

            epoch_loss += loss.item()
            global_step += 1

            if global_step % 10 == 0:
                wandb.log({
                    "train_loss": loss.item(),
                    "loss_coarse": l_coarse.item(),
                    "loss_fine": l_fine.item(),
                    "loss_vad": l_vad.item(),
                    "lr": optimizer.param_groups[0]['lr'],
                    "epoch": epoch
                }, step=global_step)
                print(f"Step {global_step} | Loss: {loss.item():.4f} | Fine: {l_fine.item():.4f}")

            if global_step % CONFIG['val_interval'] == 0:
                # Use the same loss weights as training so train/val losses are comparable.
                avg_val_loss = _evaluate(model, val_loader, device,
                                         w=CONFIG['w_loss'], beta=CONFIG['beta_loss'])
                wandb.log({"val_loss": avg_val_loss}, step=global_step)
                print(f"--- Validation Loss: {avg_val_loss:.4f} ---")

                visual_check_tspnn(model, device, global_step, val_ds)

                torch.save(model.state_dict(), f"{output_dir}/tspnn_step_{global_step}.pth")
                model.train()

        # End of epoch: plateau scheduling on the mean training loss.
        prev_lr = optimizer.param_groups[0]['lr']
        scheduler.step(epoch_loss / len(train_loader))
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr < prev_lr:
            print(f"Reducing learning rate to {new_lr:.4e}.")
        print(f">>> End of Epoch {epoch+1}")

    wandb.finish()