In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import wandb
import random
import math
import os
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

Docstring for source.TSNN_coarse_fine_Stage.ipynb

Gated TrConv2D
Causal Conv2D
FT-GRU

Coarse Stage
Fine Stage
VAD
Deep Filter

TDC, Magnitude Compress
===> TSPNN

### Fundamental Block

In [43]:
class CausalConvBlock(nn.Module):
    """
    Conv2d -> BatchNorm2d -> PReLU with causal zero padding on the time axis.

    Inputs are laid out as [B, C, F, T]. The time axis is padded on the left
    only (by ``kernel_t - 1`` frames), so an output frame never sees later
    input frames. The frequency axis is padded symmetrically.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel_size: (freq, time) kernel size.
        stride: (freq, time) stride, forwarded to the convolution.
        padding: (freq, time) padding. Only the frequency entry is used; the
            time entry is ignored because time padding is derived from the
            kernel size.
    """

    def __init__(self, in_c, out_c, kernel_size, stride, padding):
        super().__init__()
        # Every geometry argument must be a 2-element (freq, time) pair;
        # unpacking all three keeps that requirement explicit.
        _, kernel_t = kernel_size
        _, _ = stride
        pad_f, _ = padding

        self.time_pad = kernel_t - 1  # left-only padding on the time axis
        self.freq_pad = pad_f         # applied to both ends of the frequency axis

        # All padding is applied manually in forward(), so the conv itself pads nothing.
        self.conv = nn.Conv2d(
            in_c, out_c, kernel_size=kernel_size, stride=stride, padding=(0, 0)
        )
        self.bn = nn.BatchNorm2d(out_c)
        self.act = nn.PReLU()

    def forward(self, x):
        """Apply the block to ``x`` of shape [B, C, F, T] and return the activated features."""
        # F.pad pads the last dimension first: (T_left, T_right, F_top, F_bottom).
        padded = F.pad(x, (self.time_pad, 0, self.freq_pad, self.freq_pad))
        return self.act(self.bn(self.conv(padded)))
    
class GatedTrConvBlock(nn.Module):
    """
    Gated transposed convolution (Fig. 1(B)): gate -> ConvTranspose2d -> BN -> PReLU.

    The decoder feature and its skip connection are concatenated on the channel
    axis, multiplied element-wise by a tanh gate computed from that same
    concatenation with a 1x1 convolution, and then upsampled. The transposed
    convolution emits ``kernel_t - 1`` extra frames at the end of the time
    axis; these are chomped off.

    Args:
        in_c: Channel count of the concatenated input (x channels + skip channels).
        out_c: Number of output channels.
        kernel_size: (freq, time) kernel size.
        stride: (freq, time) stride.
        padding: (freq, time) padding forwarded to ConvTranspose2d.
        output_padding: (freq, time) output padding forwarded to ConvTranspose2d.
    """

    def __init__(self, in_c, out_c, kernel_size, stride, padding, output_padding=(0,0)):
        super().__init__()
        _, kernel_t = kernel_size
        # Number of surplus trailing frames produced by the transposed conv.
        self.chomp_t = kernel_t - 1

        self.gate_conv = nn.Sequential(nn.Conv2d(in_c, in_c, kernel_size=1), nn.Tanh())
        self.tr_conv = nn.ConvTranspose2d(
            in_c,
            out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.bn = nn.BatchNorm2d(out_c)
        self.act = nn.PReLU()

    def forward(self, x, skip):
        """Fuse ``x`` with ``skip`` (both [B, C_i, F, T]) and return the upsampled features."""
        merged = torch.cat([x, skip], dim=1)

        # Element-wise gating of the fused features.
        gated = merged * self.gate_conv(merged)

        upsampled = self.tr_conv(gated)
        if self.chomp_t > 0:
            # Drop the trailing frames so the time length matches the input.
            upsampled = upsampled[:, :, :, :-self.chomp_t]

        return self.act(self.bn(upsampled))
    
class FT_GRU_Block(nn.Module):
    """
    Frequency-then-time recurrent block operating on [B, C, F, T] features.

    1. A bidirectional GRU scans the frequency axis of every frame; its output
       is projected back to ``in_c`` channels and added residually.
    2. A unidirectional GRU scans the time axis of every frequency bin; its
       output is projected back to ``in_c`` channels and added residually.
    3. BatchNorm2d and PReLU are applied to the result.

    Args:
        in_c: Channel count of the input (and of the output).
        hidden_f: Hidden size per direction of the frequency GRU.
        hidden_t: Hidden size of the time GRU.
    """

    def __init__(self, in_c, hidden_f, hidden_t):
        super().__init__()
        self.f_gru = nn.GRU(in_c, hidden_f, batch_first=True, bidirectional=True)
        self.f_linear = nn.Linear(hidden_f * 2, in_c)
        self.t_gru = nn.GRU(in_c, hidden_t, batch_first=True, bidirectional=False)
        self.t_linear = nn.Linear(hidden_t, in_c)

        self.bn = nn.BatchNorm2d(in_c)
        self.act = nn.PReLU()

    def forward(self,x):
        """Return features with the same [B, C, F, T] shape as ``x``."""
        batch, chans, n_freq, n_frames = x.shape

        # Frequency path: one sequence of length F per (batch, frame) pair.
        freq_seq = x.permute(0, 3, 2, 1).reshape(-1, n_freq, chans)   # [B*T, F, C]
        freq_out = self.f_linear(self.f_gru(freq_seq)[0])             # [B*T, F, C]
        x = x + freq_out.reshape(batch, n_frames, n_freq, chans).permute(0, 3, 2, 1)

        # Time path: one sequence of length T per (batch, frequency bin) pair.
        time_seq = x.permute(0, 2, 3, 1).reshape(-1, n_frames, chans)  # [B*F, T, C]
        time_out = self.t_linear(self.t_gru(time_seq)[0])              # [B*F, T, C]
        x = x + time_out.reshape(batch, n_freq, n_frames, chans).permute(0, 3, 1, 2)

        return self.act(self.bn(x))


### VAD block

In [None]:
class VADModule(nn.Module):
    """
    Voice-activity head producing two logits per frame.

    Pipeline: 1x1 Conv2d (-> 16 channels) + BN + PReLU, flatten channels and
    frequency into one feature vector per frame, bidirectional GRU over the
    frame sequence, 1x1 Conv1d + BN + PReLU, and a final 1x1 Conv1d to 2 classes.

    NOTE(review): the GRU is bidirectional and runs along the time axis, so the
    logits of a frame depend on later frames as well. Confirm this is intended
    if the head is used in a streaming setting.

    Args:
        in_channels: Channel count of the incoming feature map.
        freq_bins: Frequency height of the incoming feature map; it must match
            the actual input because the GRU input size is ``16 * freq_bins``.
    """

    def __init__(self, in_channels=32, freq_bins=5):
        super().__init__()
        self.conv2d = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.BatchNorm2d(16),
            nn.PReLU(),
        )
        # 8 hidden units per direction -> 16 features per frame.
        self.f_gru = nn.GRU(16 * freq_bins, 8, batch_first=True, bidirectional=True)

        self.conv1d_block = nn.Sequential(
            nn.Conv1d(16, 16, kernel_size=1),
            nn.BatchNorm1d(16),
            nn.PReLU(),
        )
        self.conv1d_out = nn.Conv1d(16, 2, kernel_size=1)

    def forward(self, x):
        """Map ``x`` of shape [B, C, F, T] to unnormalised VAD logits of shape [B, 2, T]."""
        batch, _, _, n_frames = x.shape

        feats = self.conv2d(x)                                               # [B, 16, F, T]
        per_frame = feats.permute(0, 3, 1, 2).reshape(batch, n_frames, -1)   # [B, T, 16*F]
        rnn_out, _ = self.f_gru(per_frame)                                   # [B, T, 16]

        hidden = self.conv1d_block(rnn_out.permute(0, 2, 1))                 # [B, 16, T]
        return self.conv1d_out(hidden)

### Deep Filter

In [None]:
class DeepFilterOp(nn.Module):
    """
    Parameter-free deep-filtering operator on a real/imag spectrogram.

    Each output time-frequency bin is the complex-weighted sum of the input
    bins in a neighbourhood spanning ``N_f`` bins below and above in frequency,
    ``N_t`` frames into the past and ``N_l`` frames into the future. Positions
    outside the spectrogram are treated as zeros.

    Args:
        N_f: Frequency half-width of the neighbourhood.
        N_t: Number of past frames included.
        N_l: Number of future (look-ahead) frames included.
    """

    def __init__(self, N_f=3, N_t=3, N_l=1):
        super().__init__()
        self.N_f, self.N_t, self.N_l = N_f, N_t, N_l
        self.k_f = 2 * N_f + 1        # neighbourhood height (frequency)
        self.k_t = N_t + N_l + 1      # neighbourhood width (time)
        self.num_neighbors = self.k_f * self.k_t

    def forward(self, coarse_spec, filters):
        """
        Apply per-bin complex filters to ``coarse_spec``.

        Args:
            coarse_spec: [B, 2, Freq, T] tensor, channel 0 real and channel 1 imaginary.
            filters: Tensor viewable as [B, 2, num_neighbors, Freq, T]; the first
                ``num_neighbors`` channels are real taps, the rest imaginary taps.
                Taps are ordered frequency-major, matching ``F.unfold``.

        Returns:
            Filtered spectrogram of shape [B, 2, Freq, T].
        """
        batch, chans, n_freq, n_frames = coarse_spec.shape

        # Pad time with N_t past / N_l future frames and frequency with N_f on both sides.
        padded = F.pad(coarse_spec, (self.N_t, self.N_l, self.N_f, self.N_f))

        # One (k_f x k_t) window per original bin: [B, C * num_neighbors, Freq * T].
        windows = F.unfold(padded, kernel_size=(self.k_f, self.k_t))
        windows = windows.view(batch, chans, self.num_neighbors, n_freq, n_frames)
        taps = filters.view(batch, 2, self.num_neighbors, n_freq, n_frames)

        win_re, win_im = windows[:, 0], windows[:, 1]
        tap_re, tap_im = taps[:, 0], taps[:, 1]

        # Complex multiply, then accumulate over the neighbourhood axis.
        real = torch.sum(win_re * tap_re - win_im * tap_im, dim=1)
        imag = torch.sum(win_re * tap_im + win_im * tap_re, dim=1)

        return torch.stack([real, imag], dim=1)

### Coarse and Fine

In [None]:
class CoarseStage(nn.Module):
    """
    Coarse stage: causal conv encoder, FT-GRU bottleneck, VAD head and gated
    transposed-conv decoder that predicts a complex mask for the mic spectrum.

    Layout:
        * ``enc_0`` .. ``enc_7``: CausalConvBlock encoders (2 -> 16 -> 32 channels).
        * ``gru_0``, ``gru_1``: FT_GRU_Block bottleneck.
        * ``vad``: VADModule fed with the bottleneck output.
        * ``dec_0`` .. ``dec_7``: GatedTrConvBlock decoders; ``dec_i`` consumes the
          previous decoder output concatenated with the skip from ``enc_{7-i}``.
        * ``mask_conv``: 1x1 conv producing the 2-channel (real, imag) mask.
    """

    def __init__(self):
        super().__init__()

        # (in_c, out_c, kernel (F, T), stride (F, T), padding (F, T))
        encoder_specs = [
            (2, 16, (5, 1), (1, 1), (2, 0)),
            (16, 16, (1, 5), (1, 1), (0, 0)),
            (16, 16, (6, 5), (2, 1), (2, 0)),
            (16, 32, (4, 3), (2, 1), (1, 0)),
            (32, 32, (6, 5), (2, 1), (2, 0)),
            (32, 32, (5, 3), (2, 1), (2, 0)),
            (32, 32, (3, 5), (2, 1), (1, 0)),
            (32, 32, (3, 3), (1, 1), (1, 0)),
        ]
        for idx, spec in enumerate(encoder_specs):
            setattr(self, f"enc_{idx}", CausalConvBlock(*spec))

        self.gru_0 = FT_GRU_Block(32, 32, 64)
        self.gru_1 = FT_GRU_Block(32, 32, 32)
        # freq_bins=8 is the bottleneck height reached from a 257-bin input.
        self.vad = VADModule(32, freq_bins=8)

        # (in_c = previous output + skip, out_c, kernel, stride, padding, output_padding)
        decoder_specs = [
            (32 + 32, 32, (3, 3), (1, 1), (1, 0), (0, 0)),  # gru_1 + enc_7
            (32 + 32, 32, (3, 5), (2, 1), (1, 0), (1, 0)),  # dec_0 + enc_6
            (32 + 32, 32, (5, 3), (2, 1), (2, 0), (1, 0)),  # dec_1 + enc_5
            (32 + 32, 32, (6, 5), (2, 1), (2, 0), (0, 0)),  # dec_2 + enc_4
            (32 + 32, 16, (4, 3), (2, 1), (1, 0), (0, 0)),  # dec_3 + enc_3
            (16 + 16, 16, (6, 5), (2, 1), (2, 0), (1, 0)),  # dec_4 + enc_2
            (16 + 16, 16, (1, 5), (1, 1), (0, 0), (0, 0)),  # dec_5 + enc_1
            (16 + 16, 16, (5, 1), (1, 1), (2, 0), (0, 0)),  # dec_6 + enc_0
        ]
        for idx, spec in enumerate(decoder_specs):
            setattr(self, f"dec_{idx}", GatedTrConvBlock(*spec))

        self.mask_conv = nn.Conv2d(16, 2, kernel_size=1)

    def forward(self, mic_cpr, ref_cpr, mic_spec_complex):
        """
        Args:
            mic_cpr: Compressed mic magnitude; concatenated with ``ref_cpr`` on the
                channel axis to form the 2-channel encoder input.
            ref_cpr: Compressed reference magnitude, same shape as ``mic_cpr``.
            mic_spec_complex: Mic spectrum [B, 2, F, T] (real, imag) that the
                predicted mask is applied to.

        Returns:
            Tuple ``(coarse_out, vad_out)``: the masked spectrum [B, 2, F, T] and
            the VAD head output.
        """
        feats = torch.cat([mic_cpr, ref_cpr], dim=1)

        # Encoder: keep every output as a skip connection.
        skips = []
        for idx in range(8):
            feats = getattr(self, f"enc_{idx}")(feats)
            skips.append(feats)

        feats = self.gru_1(self.gru_0(feats))
        vad_out = self.vad(feats)

        # Decoder: dec_i is paired with enc_{7-i}.
        for idx in range(8):
            feats = getattr(self, f"dec_{idx}")(feats, skips[7 - idx])

        mask = self.mask_conv(feats)
        mask_re, mask_im = mask[:, 0], mask[:, 1]

        # Complex multiplication of the mic spectrum with the mask.
        pred_r = mic_spec_complex[:, 0] * mask_re - mic_spec_complex[:, 1] * mask_im
        pred_i = mic_spec_complex[:, 0] * mask_im + mic_spec_complex[:, 1] * mask_re
        return torch.stack([pred_r, pred_i], dim=1), vad_out

class FineStage(nn.Module):
    """
    Fine stage: same encoder / FT-GRU / decoder topology as the coarse stage but
    wider, ending in a deep-filter applied to the coarse-stage output.

    Layout:
        * ``enc_0`` .. ``enc_7``: CausalConvBlock encoders (3 -> 16 -> 32 -> 64 channels).
        * ``gru_0``, ``gru_1``: FT_GRU_Block bottleneck.
        * ``dec_0`` .. ``dec_7``: GatedTrConvBlock decoders; ``dec_i`` consumes the
          previous decoder output concatenated with the skip from ``enc_{7-i}``.
        * ``df_conv``: 1x1 conv emitting the deep-filter coefficients.
        * ``df_op``: DeepFilterOp that applies them.
    """

    def __init__(self):
        super().__init__()

        # (in_c, out_c, kernel (F, T), stride (F, T), padding (F, T))
        encoder_specs = [
            (3, 16, (5, 1), (1, 1), (2, 0)),
            (16, 16, (1, 5), (1, 1), (0, 0)),
            (16, 32, (6, 5), (2, 1), (2, 0)),
            (32, 32, (4, 3), (2, 1), (1, 0)),
            (32, 64, (6, 5), (2, 1), (2, 0)),
            (64, 64, (5, 3), (2, 1), (2, 0)),
            (64, 64, (3, 5), (2, 1), (1, 0)),
            (64, 64, (3, 3), (1, 1), (1, 0)),
        ]
        for idx, spec in enumerate(encoder_specs):
            setattr(self, f"enc_{idx}", CausalConvBlock(*spec))

        self.gru_0 = FT_GRU_Block(64, 64, 128)
        self.gru_1 = FT_GRU_Block(64, 64, 64)

        # (in_c = previous output + skip, out_c, kernel, stride, padding, output_padding)
        decoder_specs = [
            (64 + 64, 64, (3, 3), (1, 1), (1, 0), (0, 0)),  # gru_1 + enc_7
            (64 + 64, 64, (3, 5), (2, 1), (1, 0), (1, 0)),  # dec_0 + enc_6
            (64 + 64, 64, (5, 3), (2, 1), (2, 0), (1, 0)),  # dec_1 + enc_5
            (64 + 64, 32, (6, 5), (2, 1), (2, 0), (0, 0)),  # dec_2 + enc_4
            (32 + 32, 32, (4, 3), (2, 1), (1, 0), (0, 0)),  # dec_3 + enc_3
            (32 + 32, 16, (6, 5), (2, 1), (2, 0), (1, 0)),  # dec_4 + enc_2
            (16 + 16, 16, (1, 5), (1, 1), (0, 0), (0, 0)),  # dec_5 + enc_1
            (16 + 16, 16, (5, 1), (1, 1), (2, 0), (0, 0)),  # dec_6 + enc_0
        ]
        for idx, spec in enumerate(decoder_specs):
            setattr(self, f"dec_{idx}", GatedTrConvBlock(*spec))

        # 70 = 2 (real/imag) x 7 (frequency taps) x 5 (time taps) for the
        # DeepFilterOp configured below; keep the two in sync.
        self.df_conv = nn.Conv2d(16, 70, kernel_size=1)
        self.df_op = DeepFilterOp(N_f=3, N_t=3, N_l=1)

    def forward(self, mic_cpr, ref_cpr, coarse_out_cpr, coarse_out_complex):
        """
        Args:
            mic_cpr: Compressed mic magnitude.
            ref_cpr: Compressed reference magnitude.
            coarse_out_cpr: Compressed magnitude of the coarse-stage output. The
                three compressed inputs are concatenated on the channel axis to
                form the 3-channel encoder input.
            coarse_out_complex: Coarse-stage spectrum [B, 2, F, T] (real, imag)
                that the predicted deep filter is applied to.

        Returns:
            Deep-filtered spectrum of shape [B, 2, F, T].
        """
        feats = torch.cat([mic_cpr, ref_cpr, coarse_out_cpr], dim=1)

        # Encoder: keep every output as a skip connection.
        skips = []
        for idx in range(8):
            feats = getattr(self, f"enc_{idx}")(feats)
            skips.append(feats)

        feats = self.gru_1(self.gru_0(feats))

        # Decoder: dec_i is paired with enc_{7-i}.
        for idx in range(8):
            feats = getattr(self, f"dec_{idx}")(feats, skips[7 - idx])

        df_coef = self.df_conv(feats)
        return self.df_op(coarse_out_complex, df_coef)

### Wrapper Model (TSPNN)

In [47]:
class TSPNN(nn.Module):
    """
    Two-stage wrapper: a CoarseStage followed by a FineStage, both driven by
    power-law compressed magnitudes of the input spectra.
    """

    def __init__(self):
        super().__init__()
        self.coarse_stage = CoarseStage()
        self.fine_stage = FineStage()
        self.alpha = 0.3  # exponent of the magnitude compression

    def compress(self, complex_spec):
        """
        Return the compressed magnitude ``|X| ** alpha`` of a real/imag spectrum.

        Args:
            complex_spec: [B, 2, F, T] tensor, channel 0 real and channel 1 imaginary.

        Returns:
            Tensor of shape [B, 1, F, T].
        """
        real, imag = complex_spec[:, 0], complex_spec[:, 1]
        # The 1e-8 floor keeps sqrt (and its gradient) finite at zero energy.
        magnitude = torch.sqrt(real**2 + imag**2 + 1e-8)
        return torch.pow(magnitude, self.alpha).unsqueeze(1)

    def forward(self, mic_complex, ref_complex):
        """
        Run both stages.

        Args:
            mic_complex: Microphone spectrogram [B, 2, F, T] (real, imag).
            ref_complex: Reference spectrogram [B, 2, F, T] (real, imag).

        Returns:
            Tuple ``(fine_out, vad, coarse_out)``: the fine-stage spectrum, the VAD
            output of the coarse stage (raw logits, no softmax is applied), and
            the coarse-stage spectrum.
        """
        # Compressed magnitudes shared by both stages.
        mic_cpr = self.compress(mic_complex)
        ref_cpr = self.compress(ref_complex)

        coarse_out, vad_prob = self.coarse_stage(mic_cpr, ref_cpr, mic_complex)

        # The fine stage additionally sees the compressed coarse estimate.
        fine_out = self.fine_stage(mic_cpr, ref_cpr, self.compress(coarse_out), coarse_out)

        return fine_out, vad_prob, coarse_out

In [53]:
# Smoke test: push random spectrograms through TSPNN and check the output shapes.
B, n_freq, T = 2, 257, 100
mic = torch.randn(B, 2, n_freq, T)  # [B, (real, imag), F, T]
ref = torch.randn(B, 2, n_freq, T)

print("Initializing TSPNN...")
model = TSPNN()

# Only shapes are inspected, so skip building the autograd graph.
with torch.no_grad():
    fine, vad, coarse = model(mic, ref)

print(f"Input shape: {mic.shape}")
print(f"Coarse out: {coarse.shape}")
print(f"VAD out: {vad.shape}")
print(f"Fine out: {fine.shape}")

# Both stages must preserve the spectrogram shape; the VAD head yields 2 logits per frame.
assert coarse.shape == mic.shape, f"unexpected coarse shape {tuple(coarse.shape)}"
assert fine.shape == mic.shape, f"unexpected fine shape {tuple(fine.shape)}"
assert vad.shape == (B, 2, T), f"unexpected VAD shape {tuple(vad.shape)}"

# Parameter count of the full two-stage model.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total Parameters: {total_params}")

Initializing TSPNN...
Input shape: torch.Size([2, 2, 257, 100])
Coarse out: torch.Size([2, 2, 257, 100])
VAD out: torch.Size([2, 2, 100])
Fine out: torch.Size([2, 2, 257, 100])
Total Parameters: 1411936


In [None]:
def set_seed(seed=42):
    """
    Seed every random number generator used in this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global generator, and PyTorch (CPU, the
    current CUDA device and all CUDA devices), then turns cuDNN's deterministic
    mode on and its benchmark auto-tuner off.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def init_weights(m):
    """
    Weight initialiser intended for ``module.apply(init_weights)``.

    * ``nn.Linear``: Xavier-uniform weights.
    * ``nn.Conv1d`` / ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU gain).
    * Any other module type is left untouched.

    For the handled types an existing bias is reset to zero.

    Args:
        m: The module to initialise in place.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    else:
        return

    # Shared by both handled layer families.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

In [None]:
class HF_STFTDataset_V2(Dataset):
    """
    Dataset wrapper that turns items of a Hugging Face style dataset into STFTs.

    Each underlying item must provide ``item['mic']['array']``,
    ``item['ref']['array']``, ``item['clean']['array']`` (waveforms) and
    ``item['vad_label']`` (one integer label per STFT hop).

    Every waveform is cropped or zero-padded to ``sample_rate * duration``
    samples. The crop offset is drawn once per item and shared by the mic, ref
    and clean signals and by the VAD labels, so all four stay time-aligned.

    Args:
        hf_dataset: Indexable dataset with the item schema described above.
        n_fft: FFT size.
        win_length: Hann window length in samples.
        hop_length: Hop size in samples.
        duration: Segment length in seconds.
        sample_rate: Sampling rate in Hz used to convert ``duration`` to samples.
    """

    def __init__(self, hf_dataset, n_fft=512, win_length=320, hop_length=160, duration=10,
                 sample_rate=16000):
        self.dataset = hf_dataset
        self.n_fft, self.win_length, self.hop_length = n_fft, win_length, hop_length
        self.max_len = int(sample_rate * duration)
        self.window = torch.hann_window(self.win_length)

    def __len__(self):
        return len(self.dataset)

    def process_len(self, wav_array, start=None):
        """
        Crop or zero-pad a waveform to exactly ``self.max_len`` samples.

        Args:
            wav_array: 1-D (samples) or 2-D (channels, samples) array-like waveform.
            start: Sample offset of the crop. When ``None`` a random offset is
                drawn if the waveform is longer than ``max_len``, otherwise 0 is
                used. Pass the offset returned for another signal of the same
                item to keep the two crops aligned.

        Returns:
            Tuple ``(wav, start)``: the float32 tensor of shape
            (channels, max_len) and the offset that was applied.
        """
        wav = torch.as_tensor(wav_array, dtype=torch.float32)
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)

        total = wav.shape[1]
        if start is None:
            start = random.randint(0, total - self.max_len) if total > self.max_len else 0

        segment = wav[:, start:start + self.max_len]
        # Zero-pad on the right when the signal ends before the window does.
        deficit = self.max_len - segment.shape[1]
        if deficit > 0:
            segment = torch.nn.functional.pad(segment, (0, deficit))
        return segment, start

    def get_stft(self, wav):
        """
        Compute the STFT of ``wav`` and return it with a trailing (real, imag) axis.

        For a single-channel (1, samples) input the result has shape
        (freq, frames, 2). NOTE(review): this layout differs from the
        [B, 2, F, T] layout documented on ``TSPNN.forward``; callers are expected
        to permute before feeding the model.
        """
        # Request a complex STFT and split it with view_as_real
        # (used instead of return_complex=False).
        stft_complex = torch.stft(wav, n_fft=self.n_fft, hop_length=self.hop_length,
                                  win_length=self.win_length, window=self.window,
                                  center=True, return_complex=True)
        return torch.view_as_real(stft_complex).squeeze(0)

    def __getitem__(self, idx):
        """Return ``(mic_stft, ref_stft, clean_stft, vad_label)`` for item ``idx``."""
        item = self.dataset[idx]

        # Draw the crop offset once (from the mic signal) and reuse it so that the
        # mic, reference and clean segments cover the same time span.
        mic_wav, start = self.process_len(item['mic']['array'])
        ref_wav, _ = self.process_len(item['ref']['array'], start=start)
        clean_wav, _ = self.process_len(item['clean']['array'], start=start)

        mic_stft = self.get_stft(mic_wav)
        num_frames = mic_stft.shape[1]

        # One VAD label per hop: slice the labels covering the cropped segment.
        vad_full = torch.tensor(item['vad_label'], dtype=torch.long)
        start_frame = start // self.hop_length
        vad_label = vad_full[start_frame:start_frame + num_frames]

        # The slice can be short (rounding / end of labels); pad with label 0.
        missing = num_frames - vad_label.shape[0]
        if missing > 0:
            vad_label = torch.nn.functional.pad(vad_label, (0, missing))

        return mic_stft, self.get_stft(ref_wav), self.get_stft(clean_wav), vad_label