In [None]:
!pip install -U datasets[audio]
!pip install torchaudio
!pip install torchcodec

In [None]:
import torch
import torch.nn as nn
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import wandb
import random
import math
import os
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

In [None]:
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (the CPU generator, the current CUDA device and all CUDA devices), then
    asks cuDNN for deterministic kernels and turns off its autotuner.

    Args:
        seed: Integer seed handed to every generator. Defaults to 42.
    """
    # Same order as before: Python, NumPy, torch CPU, torch CUDA (current), torch CUDA (all).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def init_weights(m):
    """Re-initialise one layer in place; meant to be passed to ``Module.apply``.

    * ``nn.Linear``: Xavier-uniform weights.
    * ``nn.Conv1d`` / ``nn.Conv2d``: Kaiming-normal weights (``fan_out``,
      ReLU gain).
    * Any other module type is left untouched.

    For the handled layer types the bias, when present, is reset to zero.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    else:
        # Not a layer type this initialiser handles.
        return

    # Shared tail for both handled layer families.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

In [None]:
class CasualMHSA(nn.Module):
    """Multi-head self-attention in which a frame cannot attend to later frames.

    Thin wrapper around ``nn.MultiheadAttention`` (``batch_first=True``) that
    builds an additive mask on every call: ``-inf`` strictly above the
    diagonal, ``0`` on and below it.

    Args:
        d_model: Embedding size of the input and output.
        n_head: Number of attention heads.
        dropout: Dropout probability forwarded to ``nn.MultiheadAttention``.
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)

    def forward(self, x):
        """Apply masked self-attention to ``x`` of shape (batch, time, d_model)."""
        _, seq_len, _ = x.shape  # unpacking also enforces a 3-D input
        # Upper triangle (future positions) is -inf, everything else is 0.
        future_mask = torch.full(
            (seq_len, seq_len), float('-inf'), device=x.device
        ).triu(diagonal=1)
        attended, _weights = self.mha(x, x, x, attn_mask=future_mask)
        return attended

class ConformerConvModule(nn.Module):
    """Convolution sub-module of the Conformer block.

    Pipeline (channels-first internally): pointwise conv (d_model -> 2*d_model)
    -> GLU over channels -> depthwise conv -> normalisation -> SiLU ->
    pointwise conv -> dropout.

    The depthwise convolution uses ``padding='same'`` so the number of frames
    is preserved for every ``kernel_size``. The previous explicit
    ``(kernel_size - 1) // 2`` padding returned one frame fewer for even kernel
    sizes, which made the residual addition in the enclosing block fail. For
    odd kernel sizes the padding is unchanged.

    Note: the padding covers both past and future frames, so this module is
    not causal.

    Args:
        d_model: Channel count of the input and output.
        kernel_size: Width of the depthwise convolution. Defaults to 15.
        dropout: Dropout probability applied to the module output.
    """

    def __init__(self, d_model, kernel_size=15, dropout=0.1):
        super().__init__()
        self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(
            d_model, d_model, kernel_size, padding='same', groups=d_model
        )
        # Despite the attribute name this is a single-group GroupNorm, not
        # BatchNorm; the name is kept so existing checkpoints still load.
        self.batch_norm = nn.GroupNorm(num_groups=1, num_channels=d_model)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the module on ``x`` of shape (batch, time, d_model); shape is preserved."""
        x = x.transpose(1, 2)              # -> (batch, d_model, time) for Conv1d
        x = self.pointwise_conv1(x)        # channels doubled
        x = self.glu(x)                    # gated back down to d_model channels
        x = self.depthwise_conv(x)
        x = self.batch_norm(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)
        x = self.dropout(x)
        return x.transpose(1, 2)           # back to (batch, time, d_model)

class FeedForwardModule(nn.Module):
    """Position-wise feed-forward network: Linear -> SiLU -> Dropout -> Linear -> Dropout.

    Args:
        d_model: Size of the input and output features.
        expansion_factor: Multiplier for the hidden width
            (hidden size is ``d_model * expansion_factor``). Defaults to 4.
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, d_model, expansion_factor=4, dropout=0.1):
        super().__init__()
        hidden_dim = d_model * expansion_factor
        self.layer1 = nn.Linear(d_model, hidden_dim)
        self.activation = nn.SiLU()
        self.dropout1 = nn.Dropout(dropout)
        self.layer2 = nn.Linear(hidden_dim, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Transform ``x`` along its last dimension; the shape is unchanged."""
        hidden = self.layer1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout1(hidden)
        projected = self.layer2(hidden)
        return self.dropout2(projected)

class ConformerBlock(nn.Module):
    """One Conformer-style block made of four pre-norm residual branches.

    Branch order as implemented here: half-weighted feed-forward ->
    convolution module -> masked self-attention -> half-weighted feed-forward,
    followed by a closing LayerNorm. (Convolution runs before attention in
    this implementation.)

    Args:
        d_model: Feature size carried through the block.
        n_head: Number of attention heads.
        kernel_size: Depthwise kernel width of the convolution module.
        dropout: Dropout probability shared by all sub-modules.
    """

    def __init__(self, d_model, n_head, kernel_size=15, dropout=0.1):
        super().__init__()
        # Creation order is kept as-is: it determines both the state-dict
        # layout and the order in which random initial weights are drawn.
        self.ffn1 = FeedForwardModule(d_model, dropout=dropout)
        self.conv_module = ConformerConvModule(d_model, kernel_size=kernel_size, dropout=dropout)
        self.self_attn = CasualMHSA(d_model, n_head, dropout=dropout)
        self.ffn2 = FeedForwardModule(d_model, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Process ``x`` of shape (batch, time, d_model); the shape is preserved."""
        # First feed-forward branch, added with weight 0.5.
        ffn1_out = self.ffn1(self.norm1(x))
        x = x + 0.5 * ffn1_out

        # Convolution branch.
        conv_out = self.conv_module(self.norm2(x))
        x = x + conv_out

        # Masked self-attention branch.
        attn_out = self.self_attn(self.norm3(x))
        x = x + attn_out

        # Second feed-forward branch, added with weight 0.5.
        ffn2_out = self.ffn2(self.norm4(x))
        x = x + 0.5 * ffn2_out

        return self.final_norm(x)

In [None]:
class AEC_V2(nn.Module):
    """Two-headed echo-cancellation network operating on STFT frames.

    A shared stack of ``ConformerBlock`` layers feeds two linear heads:

    * ``mask_proj``: a complex mask (real and imaginary part per frequency
      bin) that is multiplied with the microphone STFT.
    * ``vad_proj``: two logits per frame (class 0: silence/echo, class 1:
      near-end speech).

    Args:
        d_model: Width of the shared backbone.
        n_fft: FFT size; the model expects ``n_fft // 2 + 1`` frequency bins.
        n_head: Attention heads per backbone layer.
        num_layers: Number of backbone layers.
        kernel_size: Depthwise kernel width inside each backbone layer.
    """

    def __init__(self, d_model=128, n_fft=512, n_head=8, num_layers=4, kernel_size=15):
        super().__init__()
        self.n_fft = n_fft
        self.n_freq = n_fft // 2 + 1

        # Per frame: real + imaginary parts of every bin, for mic and reference.
        features_per_frame = self.n_freq * 4
        self.input_proj = nn.Linear(features_per_frame, d_model)

        self.layers = nn.ModuleList(
            ConformerBlock(d_model, n_head, kernel_size=kernel_size)
            for _ in range(num_layers)
        )

        # Head 1: echo-cancellation mask (main output).
        self.mask_proj = nn.Linear(d_model, self.n_freq * 2)

        # Head 2: auxiliary VAD classifier.
        # Class 0: silence/echo, class 1: near-end speech.
        self.vad_proj = nn.Linear(d_model, 2)

    def forward(self, mic_stft, ref_stft):
        """Estimate the enhanced STFT and per-frame VAD logits.

        Args:
            mic_stft: Microphone STFT of shape (batch, freq, time, 2), with
                real/imaginary parts on the last axis.
            ref_stft: Reference STFT, reshaped the same way as ``mic_stft``.

        Returns:
            Tuple ``(est_stft, vad_logits)`` with shapes
            (batch, freq, time, 2) and (batch, time, 2).
        """
        batch, n_bins, n_frames, _ = mic_stft.shape

        def to_frames(stft):
            # (batch, freq, time, 2) -> (batch, time, freq * 2)
            return stft.permute(0, 2, 1, 3).reshape(batch, n_frames, n_bins * 2)

        hidden = torch.cat([to_frames(mic_stft), to_frames(ref_stft)], dim=2)
        hidden = self.input_proj(hidden)

        # Shared backbone.
        for block in self.layers:
            hidden = block(hidden)

        # Head 1: complex mask, rearranged to (batch, freq, time, 2).
        mask = self.mask_proj(hidden).view(batch, n_frames, n_bins, 2).permute(0, 2, 1, 3)
        mask_real = mask[..., 0]
        mask_imag = mask[..., 1]
        mic_real = mic_stft[..., 0]
        mic_imag = mic_stft[..., 1]

        # Complex product mic * mask, written out in real arithmetic.
        est_real = mic_real * mask_real - mic_imag * mask_imag
        est_imag = mic_real * mask_imag + mic_imag * mask_real
        est_stft = torch.stack([est_real, est_imag], dim=-1)

        # Head 2: VAD logits, shape (batch, time, 2).
        vad_logits = self.vad_proj(hidden)

        return est_stft, vad_logits

In [None]:
class HF_STFTDataset_V2(Dataset):
    """Wraps a dataset of mic / ref / clean recordings and returns their STFTs.

    Every item of ``hf_dataset`` must provide ``item['mic']['array']``,
    ``item['ref']['array']``, ``item['clean']['array']`` and
    ``item['vad_label']`` (a sequence of integer labels).

    Each recording is cropped or zero-padded to ``duration`` seconds. The crop
    offset is drawn once per item and shared by all three signals and by the
    VAD labels, so the returned tensors stay time-aligned. (Previously each
    signal drew its own random offset, which misaligned mic, ref, clean and
    the labels for recordings longer than ``duration``.)

    Args:
        hf_dataset: Indexable, sized collection of items as described above.
        n_fft: FFT size.
        win_length: Length of the Hann analysis window.
        hop_length: Hop between frames in samples. The VAD labels are assumed
            to be one per ``hop_length`` samples (TODO confirm against the
            dataset; the original comment states 10 ms per label).
        duration: Target clip length in seconds.
        sample_rate: Samples per second used to convert ``duration`` into a
            sample count. Defaults to 16000, the previously hard-coded value.
    """

    def __init__(self, hf_dataset, n_fft=512, win_length=320, hop_length=160, duration=10,
                 sample_rate=16000):
        self.dataset = hf_dataset
        self.n_fft, self.win_length, self.hop_length = n_fft, win_length, hop_length
        self.max_len = int(sample_rate * duration)
        self.window = torch.hann_window(self.win_length)

    def __len__(self):
        return len(self.dataset)

    def process_len(self, wav_array, start=None):
        """Crop or zero-pad a waveform to exactly ``self.max_len`` samples.

        Args:
            wav_array: 1-D (samples) or 2-D (channels, samples) waveform;
                anything ``torch.as_tensor`` accepts.
            start: Sample offset of the crop. ``None`` (default) keeps the
                original behaviour: a random offset when the waveform is
                longer than ``max_len``, otherwise 0. Pass the offset returned
                for another signal to cut the same time span from this one.

        Returns:
            Tuple ``(wav, start)``: a float32 tensor of shape
            (channels, max_len) and the offset that was used.
        """
        wav = torch.as_tensor(wav_array, dtype=torch.float32)
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)

        total = wav.shape[1]
        if start is None:
            start = random.randint(0, total - self.max_len) if total > self.max_len else 0
        else:
            start = max(int(start), 0)

        segment = wav[:, start:start + self.max_len]
        # Right-pad with zeros when the waveform ends before the window does.
        missing = self.max_len - segment.shape[1]
        if missing > 0:
            segment = torch.nn.functional.pad(segment, (0, missing))
        return segment, start

    def get_stft(self, wav):
        """Return the STFT of ``wav`` with real/imaginary parts on the last axis.

        ``torch.stft`` is called with ``return_complex=True`` and the result is
        converted using ``view_as_real``. A leading dimension of size 1 is
        squeezed away, giving (freq, frames, 2) for a single-channel input.
        """
        stft_complex = torch.stft(wav, n_fft=self.n_fft, hop_length=self.hop_length,
                                  win_length=self.win_length, window=self.window,
                                  center=True, return_complex=True)
        return torch.view_as_real(stft_complex).squeeze(0)

    def __getitem__(self, idx):
        """Return ``(mic_stft, ref_stft, clean_stft, vad_label)`` for item ``idx``.

        ``vad_label`` is a ``torch.long`` vector with one entry per STFT frame
        of the microphone signal.
        """
        item = self.dataset[idx]

        # The offset is chosen from the mic signal and reused for the others
        # so that all three crops cover the same time span.
        mic_wav, start = self.process_len(item['mic']['array'])
        ref_wav, _ = self.process_len(item['ref']['array'], start=start)
        clean_wav, _ = self.process_len(item['clean']['array'], start=start)

        # Computed once; also defines how many label frames are needed.
        mic_stft = self.get_stft(mic_wav)
        num_frames = mic_stft.shape[1]

        # Cut the labels that belong to the cropped span.
        vad_full = torch.tensor(item['vad_label'], dtype=torch.long)
        start_frame = start // self.hop_length
        vad_label = vad_full[start_frame:start_frame + num_frames]

        # The slice can only be too short, never too long; pad with class 0.
        shortfall = num_frames - vad_label.shape[0]
        if shortfall > 0:
            vad_label = torch.nn.functional.pad(vad_label, (0, shortfall))

        return mic_stft, self.get_stft(ref_wav), self.get_stft(clean_wav), vad_label

In [None]:
def multi_task_loss(est, target, vad_logits, vad_labels, beta=0.1):
    """Combined training objective: ``L = L_AEC + beta * L_VAD``.

    * ``L_AEC``: mean absolute error between the magnitudes plus mean absolute
      error between the real/imaginary tensors themselves.
    * ``L_VAD``: cross-entropy over the frame-level VAD classes.

    Args:
        est: Estimated STFT with real/imaginary parts on the last axis.
        target: Target STFT, same layout as ``est``.
        vad_logits: VAD logits of shape (batch, time, classes).
        vad_labels: Integer class labels of shape (batch, time).
        beta: Weight of the VAD term. Defaults to 0.1.

    Returns:
        Tuple ``(total_loss, loss_aec, loss_vad)`` of scalar tensors.
    """
    def _magnitude(stft):
        # The small constant keeps sqrt differentiable at zero.
        return torch.sqrt(stft[..., 0]**2 + stft[..., 1]**2 + 1e-9)

    # 1. Echo-cancellation term: magnitude error + complex error.
    magnitude_term = nn.functional.l1_loss(_magnitude(est), _magnitude(target))
    complex_term = nn.functional.l1_loss(est, target)
    loss_aec = magnitude_term + complex_term

    # 2. VAD term: cross-entropy expects the class axis second, i.e.
    #    (batch, classes, time) logits against (batch, time) labels.
    class_first_logits = vad_logits.transpose(1, 2)
    loss_vad = nn.functional.cross_entropy(class_first_logits, vad_labels)

    return loss_aec + beta * loss_vad, loss_aec, loss_vad

In [None]:
def stft_to_mag_db(stft_tensor):
    """Convert an STFT with real/imaginary parts on the last axis to magnitude in dB.

    A small constant (1e-9) is added both under the square root and before
    the logarithm so that all-zero bins stay finite.

    Args:
        stft_tensor: Tensor whose last axis holds the real part at index 0
            and the imaginary part at index 1.

    Returns:
        Tensor of ``20 * log10(magnitude)`` values, without the last axis.
    """
    eps = 1e-9
    real_part = stft_tensor[..., 0]
    imag_part = stft_tensor[..., 1]
    magnitude = torch.sqrt(real_part**2 + imag_part**2 + eps)
    return 20 * torch.log10(magnitude + eps)

def visual_check_v2(model, device, step, val_dataset):
    """Log a four-panel spectrogram comparison for the first validation item.

    Runs the model on ``val_dataset[0]`` without gradients and logs one image
    to Weights & Biases under the key ``"Inference_Progress"`` showing, top to
    bottom: microphone, reference, clean target and model estimate (all as
    magnitude in dB).

    The model's train/eval mode is restored to whatever it was on entry, and
    the figure is closed, even when inference, plotting or logging raises.
    (Previously the model was always switched to train mode on success and
    left in eval mode, with the figure open, on failure.)

    Args:
        model: Network called as ``model(mic, ref)`` and returning
            ``(est_stft, vad_logits)``.
        device: Device the inputs are moved to before inference.
        step: Global step used for the W&B log entry and the panel title.
        val_dataset: Dataset whose items are
            ``(mic_stft, ref_stft, clean_stft, vad_label)`` tuples.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # The VAD label of the sample is not needed for the plot.
            mic, ref, clean, _ = val_dataset[0]
            mic, ref = mic.unsqueeze(0).to(device), ref.unsqueeze(0).to(device)
            est_stft, _ = model(mic, ref)

            imgs = [stft_to_mag_db(mic[0]), stft_to_mag_db(ref[0]), stft_to_mag_db(clean), stft_to_mag_db(est_stft[0])]
            titles = ["Microphone", "Reference", "Clean (Target)", f"Estimate (Step {step})"]

            fig, axs = plt.subplots(4, 1, figsize=(10, 12), facecolor='black')
            try:
                for ax, title, img in zip(axs, titles, imgs):
                    ax.imshow(img.cpu().numpy(), origin='lower', aspect='auto', cmap='magma')
                    ax.set_title(title, color='white')
                    ax.axis('off')
                fig.tight_layout()
                wandb.log({"Inference_Progress": wandb.Image(fig)}, step=step)
            finally:
                # Always release the figure; this runs repeatedly during training.
                plt.close(fig)
    finally:
        model.train(was_training)

In [None]:
def main():
    """Train the multi-task AEC + VAD model end to end.

    Steps: seed the RNGs, start a W&B run, load the dataset and split off the
    first 500 items for validation, build the model/optimizer/scheduler, then
    train for ``CONFIG['epochs']`` epochs. Losses and learning rate are logged
    every 10 steps; every ``val_interval`` steps a spectrogram comparison is
    logged and a checkpoint is written. The final weights are saved once
    training ends.

    Changes compared with the earlier version:
      * ``warmup_pct`` is now actually passed to ``OneCycleLR`` as
        ``pct_start``; before, the scheduler's own default was used and the
        configured value had no effect.
      * The final weights are saved as ``aec_v2_final.pth``; before, they were
        only written if the last step happened to hit ``val_interval``.
      * The never-used validation ``DataLoader`` was dropped.
    """
    CONFIG = {
        "repo_id": "PandaLT/microsoft-AEC-dataset", # New dataset that includes vad_label
        "n_fft": 512, "win_length": 320, "hop_length": 160,
        "d_model": 128, "batch_size": 32, "lr": 2e-4, "epochs": 10,
        "warmup_pct": 0.05, "beta": 0.1, # beta: weight of the VAD loss
        "val_interval": 100, "seed": 42
    }

    set_seed(CONFIG['seed'])
    wandb.init(project="AEC_Stage2_MultiTask_VAD", config=CONFIG)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading Multi-task Dataset...")
    full_dataset = load_dataset(CONFIG['repo_id'], split='train')
    # First 500 items are held out; only item 0 is used (for the visual check).
    val_data = full_dataset.select(range(500))
    train_data = full_dataset.select(range(500, len(full_dataset)))

    params = {k: CONFIG[k] for k in ["n_fft", "win_length", "hop_length"]}
    val_ds = HF_STFTDataset_V2(val_data, **params)
    train_loader = DataLoader(HF_STFTDataset_V2(train_data, **params), batch_size=CONFIG['batch_size'], shuffle=True, num_workers=2, pin_memory=True)

    model = AEC_V2(d_model=CONFIG['d_model'], n_fft=CONFIG['n_fft']).to(device)
    model.apply(init_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
    # The scheduler is stepped once per batch, so it needs the total batch count.
    total_steps = len(train_loader) * CONFIG['epochs']
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=CONFIG['lr'],
        total_steps=total_steps,
        pct_start=CONFIG['warmup_pct'],  # fraction of steps spent ramping the LR up
    )

    global_step = 0
    for epoch in range(CONFIG['epochs']):
        model.train()
        print(f"\n>>> Epoch {epoch+1}")
        for mic, ref, clean, vad_label in train_loader:
            mic, ref, clean, vad_label = mic.to(device), ref.to(device), clean.to(device), vad_label.to(device)

            est_stft, vad_logits = model(mic, ref)
            total_loss, loss_aec, loss_vad = multi_task_loss(est_stft, clean, vad_logits, vad_label, beta=CONFIG['beta'])

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            scheduler.step()

            global_step += 1
            if global_step % 10 == 0:
                wandb.log({
                    "total_loss": total_loss.item(),
                    "aec_loss": loss_aec.item(),
                    "vad_loss": loss_vad.item(),
                    "lr": scheduler.get_last_lr()[0]
                }, step=global_step)

            if global_step % CONFIG['val_interval'] == 0:
                visual_check_v2(model, device, global_step, val_ds)
                torch.save(model.state_dict(), f"aec_v2_step_{global_step}.pth")

    # Persist the end-of-training weights regardless of where the last
    # periodic checkpoint fell.
    torch.save(model.state_dict(), "aec_v2_final.pth")

    wandb.finish()

In [None]:
# Entry point of the notebook: runs the complete training pipeline
# (dataset download, W&B logging, checkpoint files written to the
# current working directory).
main()