In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np  # Cần thêm numpy để tính toán mượt mà hơn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, Audio

### Building Model 

In [None]:
class CasualMHSA(nn.Module):
    """Multi-head self-attention in which each time step only sees itself and the past.

    (The class name is spelled "Casual"; it implements *causal* attention.)

    Args:
        d_model: Feature size of the input and output.
        n_head: Number of attention heads.
        dropout: Dropout probability handed to ``nn.MultiheadAttention``.
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)

    def forward(self, x):
        """Self-attend over ``x`` of shape ``[batch, time, d_model]``; the shape is preserved."""
        # Unpacking also rejects inputs that are not 3-D.
        _, seq_len, _ = x.shape
        # Additive mask: -inf strictly above the diagonal (future steps), 0 elsewhere.
        future_mask = torch.full((seq_len, seq_len), float('-inf'), device=x.device).triu(diagonal=1)
        return self.mha(x, x, x, attn_mask=future_mask)[0]

class ConformerConvModule(nn.Module):
    """Convolution sub-module: pointwise conv -> GLU -> depthwise conv -> norm -> SiLU -> pointwise conv -> dropout.

    Args:
        d_model: Number of feature channels.
        kernel_size: Width of the depthwise convolution kernel.
        dropout: Dropout probability applied to the module output.
    """

    def __init__(self, d_model, kernel_size=15, dropout=0.1):
        super().__init__()
        symmetric_pad = (kernel_size - 1) // 2
        # The first pointwise conv doubles the channels; GLU halves them again.
        self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(
            d_model,
            d_model,
            kernel_size,
            padding=symmetric_pad,
            groups=d_model,
        )
        # Despite the attribute name this is a single-group GroupNorm, not BatchNorm.
        self.batch_norm = nn.GroupNorm(num_groups=1, num_channels=d_model)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the convolution stack on ``x`` (``[batch, time, d_model]``) and return the same layout."""
        # Conv1d works on [batch, channels, time].
        feats = x.transpose(1, 2)
        pipeline = (
            self.pointwise_conv1,
            self.glu,
            self.depthwise_conv,
            self.batch_norm,
            self.activation,
            self.pointwise_conv2,
            self.dropout,
        )
        for stage in pipeline:
            feats = stage(feats)
        return feats.transpose(1, 2)

class FeedForwardModule(nn.Module):
    """Position-wise feed-forward network: Linear -> SiLU -> Dropout -> Linear -> Dropout.

    Args:
        d_model: Input and output feature size.
        expansion_factor: Multiplier giving the hidden width (``d_model * expansion_factor``).
        dropout: Dropout probability used after the activation and after the second linear layer.
    """

    def __init__(self, d_model, expansion_factor=4, dropout=0.1):
        super().__init__()
        hidden_dim = d_model * expansion_factor
        self.layer1 = nn.Linear(d_model, hidden_dim)
        self.activation = nn.SiLU()
        self.dropout1 = nn.Dropout(dropout)
        self.layer2 = nn.Linear(hidden_dim, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the network along the last dimension of ``x``."""
        hidden = self.activation(self.layer1(x))
        hidden = self.dropout1(hidden)
        projected = self.layer2(hidden)
        return self.dropout2(projected)

class ConformerBlock(nn.Module):
    """One Conformer-style layer built from pre-norm residual branches.

    Branch order as implemented here: half-step feed-forward, convolution module,
    causal self-attention, half-step feed-forward, then a closing LayerNorm.

    Args:
        d_model: Feature size shared by every sub-module.
        n_head: Number of attention heads.
        kernel_size: Depthwise kernel width of the convolution module.
        dropout: Dropout probability forwarded to every sub-module.
    """

    def __init__(self, d_model, n_head, kernel_size=15, dropout=0.1):
        super().__init__()
        self.ffn1 = FeedForwardModule(d_model, dropout=dropout)
        self.conv_module = ConformerConvModule(d_model, kernel_size=kernel_size, dropout=dropout)
        self.self_attn = CasualMHSA(d_model, n_head, dropout=dropout)
        self.ffn2 = FeedForwardModule(d_model, dropout=dropout)
        # One pre-norm per residual branch, in the order the branches run.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Transform ``x`` (``[batch, time, d_model]``); the shape is preserved."""
        # The two feed-forward branches are added with weight 0.5 each.
        after_ffn1 = x + 0.5 * self.ffn1(self.norm1(x))
        after_conv = after_ffn1 + self.conv_module(self.norm2(after_ffn1))
        after_attn = after_conv + self.self_attn(self.norm3(after_conv))
        after_ffn2 = after_attn + 0.5 * self.ffn2(self.norm4(after_attn))
        return self.final_norm(after_ffn2)


In [None]:
class WaveformAEC(nn.Module):
    """Waveform-domain acoustic echo canceller.

    Both the microphone and the reference waveform are cut into overlapping
    frames and linearly encoded. A stack of ``ConformerBlock`` layers turns the
    two encodings into a multiplicative mask for the microphone features, and
    the masked features are decoded back to frames and overlap-added into a
    waveform.

    Args:
        d_model: Feature size of the encoders and Conformer layers.
        n_layers: Number of ``ConformerBlock`` layers.
        n_head: Attention heads per layer.
        kernel_size: Depthwise kernel width inside each layer.
        win_len: Frame length in samples.
        stride: Hop between consecutive frames in samples.
    """

    def __init__(self,
                 d_model=128,
                 n_layers=4,
                 n_head=8,
                 kernel_size=15,
                 win_len=80,
                 stride=40):
        super().__init__()
        self.win_len = win_len
        self.stride = stride
        self.d_model = d_model
        self.mix_encoder = nn.Linear(win_len, d_model)
        self.ref_encoder = nn.Linear(win_len, d_model)
        self.input_proj = nn.Linear(d_model*2, d_model)

        self.layers = nn.ModuleList([
            ConformerBlock(d_model, n_head, kernel_size) for _ in range(n_layers)
        ])
        self.decoder = nn.Linear(d_model, win_len)
        self.output_act = nn.Tanh()

    def _frame(self, wav):
        """Cut ``wav`` (``[batch, samples]``) into frames of shape ``[batch, n_frames, win_len]``."""
        # unfold expects a 4-D image-like tensor, so view the signal as [B, 1, T, 1].
        return F.unfold(wav.unsqueeze(1).unsqueeze(3),
                        kernel_size=(self.win_len, 1),
                        stride=(self.stride, 1)).transpose(1, 2)

    def forward(self, mic_wav, ref_wav):
        """Estimate the echo-free signal.

        Args:
            mic_wav: Microphone signal, ``[batch, samples]`` (e.g. ``[4, 64000]``).
            ref_wav: Reference (far-end) signal, ``[batch, samples]``.

        Returns:
            Tensor of shape ``[batch, samples]`` with the same number of samples
            as ``mic_wav``.
        """
        n_samples = mic_wav.shape[-1]

        # Framing only covers win_len + k * stride samples. Zero-pad the tail up
        # to the next whole hop so no input sample is dropped (and inputs shorter
        # than one window still work); the padding is trimmed from the output.
        # Inputs that already align get pad == 0 and are processed unchanged.
        n_hops = -(-max(n_samples - self.win_len, 0) // self.stride)  # ceil division
        pad = n_hops * self.stride + self.win_len - n_samples
        if pad > 0:
            mic_wav = F.pad(mic_wav, (0, pad))
            ref_wav = F.pad(ref_wav, (0, pad))

        mic_frames = self._frame(mic_wav)
        ref_frames = self._frame(ref_wav)

        mic_feature = self.mix_encoder(mic_frames)
        ref_feature = self.ref_encoder(ref_frames)

        concat_feature = torch.cat([mic_feature, ref_feature], dim=-1)

        mask = self.input_proj(concat_feature)
        for layer in self.layers:
            mask = layer(mask)

        # The mask is applied to the encoded microphone frames only.
        mask_feature = mic_feature * mask
        out_frames = self.output_act(self.decoder(mask_feature))
        output_len = (out_frames.shape[1] - 1) * self.stride + self.win_len

        # fold sums overlapping frames and returns [B, 1, output_len, 1].
        predicted_wav = F.fold(
            out_frames.transpose(1, 2),
            output_size=(output_len, 1),
            kernel_size=(self.win_len, 1),
            stride=(self.stride, 1)
        )

        # Back to [batch, time], dropping any samples that only exist due to padding.
        return predicted_wav.squeeze(-1).squeeze(1)[..., :n_samples]

### Loss Function: SI-SNR

In [None]:
class NegSISNRLoss(nn.Module):
    """Negative scale-invariant signal-to-noise ratio (SI-SNR), averaged over the batch.

    Lower is better: minimising this loss maximises SI-SNR in dB.

    Args:
        eps: Small constant added to the energies so that neither the divisions
            nor the logarithm see an exact zero.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, estimate, target):
        """Compute ``-mean(SI-SNR)``.

        Args:
            estimate: Predicted signals, time on the last dimension.
            target: Reference signals with the same shape as ``estimate``.

        Returns:
            Scalar tensor holding the negated mean SI-SNR in dB.
        """
        # Remove the DC offset from both signals.
        estimate = estimate - torch.mean(estimate, dim=-1, keepdim=True)
        target = target - torch.mean(target, dim=-1, keepdim=True)

        # Project the estimate onto the target; the residual counts as noise.
        dot = torch.sum(estimate * target, dim=-1, keepdim=True)
        target_energy = torch.sum(target ** 2, dim=-1, keepdim=True) + self.eps
        scaled_target = (dot / target_energy) * target
        e_noise = estimate - scaled_target

        # eps is added to BOTH energies: a silent target (or an estimate that is
        # orthogonal to the target) makes the projected energy exactly zero, and
        # log10(0) = -inf would turn the whole batch loss into +inf.
        signal_energy = torch.sum(scaled_target ** 2, dim=-1) + self.eps
        noise_energy = torch.sum(e_noise ** 2, dim=-1) + self.eps
        sisnr = 10 * torch.log10(signal_energy / noise_energy)
        return -torch.mean(sisnr)

### Dataset and Data Loader Preparation

In [None]:
class HuggingFaceAECDataset(Dataset):
    """Fixed-length (mic, ref, clean) waveform triples read from a Hugging Face dataset.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        split: Dataset split to load.
        sample_rate: Sampling rate the three audio columns are cast to.
        duration: Clip length in seconds; every returned tensor has
            ``int(sample_rate * duration)`` samples.

    Raises:
        KeyError: If the dataset has no column matching ``mic``, ``ref`` or
            ``clean`` (compared case-insensitively).
    """

    _SIGNALS = ("mic", "ref", "clean")

    def __init__(self, dataset_name, split="train", sample_rate=16000, duration=4.0):
        self.sample_rate = sample_rate
        self.num_samples = int(sample_rate * duration)
        print(f"Loading {dataset_name} (split: {split})...")
        self.dataset = load_dataset(dataset_name, split=split)

        # Look the column names up once so that the cast below and __getitem__
        # always agree on the spelling ("Mic" vs "mic") the dataset really uses.
        self.columns = {
            signal: self._resolve_column(self.dataset.column_names, signal)
            for signal in self._SIGNALS
        }

        # Cast the audio columns to the requested sampling rate.
        for column in self.columns.values():
            self.dataset = self.dataset.cast_column(column, Audio(sampling_rate=sample_rate))

    @staticmethod
    def _resolve_column(available, wanted):
        """Return the entry of ``available`` equal to ``wanted``, preferring an exact match over a case-insensitive one."""
        if wanted in available:
            return wanted
        for name in available:
            if name.lower() == wanted.lower():
                return name
        raise KeyError(f"No column matching '{wanted}' in dataset columns {list(available)}")

    def __len__(self):
        return len(self.dataset)

    def _fit(self, wav, start):
        """Take ``num_samples`` samples of ``wav`` from ``start``, zero-padding the tail if it runs short."""
        segment = wav[start:start + self.num_samples]
        missing = self.num_samples - segment.shape[0]
        if missing > 0:
            segment = F.pad(segment, (0, missing))
        return segment

    def __getitem__(self, idx):
        """Return ``(mic, ref, clean)`` float32 tensors, each of length ``num_samples``."""
        item = self.dataset[idx]
        mic, ref, clean = (
            torch.tensor(item[self.columns[signal]]["array"], dtype=torch.float32)
            for signal in self._SIGNALS
        )

        # The crop offset is chosen from the mic length and shared by all three
        # signals so they stay time-aligned. Clips shorter than the target are
        # padded from offset 0 instead of cropped.
        L = mic.shape[0]
        if L < self.num_samples:
            start = 0
        else:
            start = random.randint(0, L - self.num_samples)

        # Each signal is fitted on its own, so a ref/clean recording whose length
        # differs from the mic still yields tensors of identical size.
        return self._fit(mic, start), self._fit(ref, start), self._fit(clean, start)

### Plot Figure and Count Parameters

In [None]:
def count_parameters(model):
    """Print how many parameters ``model`` has, both exactly and in millions."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    print(f"\n>> MODEL SUMMARY: {total:,} Parameters (Approx {total/1e6:.1f}M)\n")
    

In [None]:
def plot_metrics(step_history, epoch_history, save_path="training_chart.png"):
    """Save a two-panel training chart: loss on the left, SI-SNR (negated loss) on the right.

    Args:
        step_history: Dict with ``'steps'`` and ``'loss'`` lists for the training curve.
        epoch_history: Dict with the evaluation curve. Both ``'test_steps'`` /
            ``'test_loss'`` and plain ``'steps'`` / ``'loss'`` keys are accepted,
            because the training loop in this notebook builds the latter.
        save_path: File the figure is written to.
    """
    steps = step_history['steps']
    train_loss = step_history['loss']

    # Read whichever key scheme the caller used; missing keys mean "no test points yet".
    test_steps = epoch_history.get('test_steps', epoch_history.get('steps', []))
    test_loss = epoch_history.get('test_loss', epoch_history.get('loss', []))

    # Moving average over w logged points; mode='valid' shortens the curve, so
    # the smoothed values are aligned with the last len(loss_smooth) steps.
    w = 50
    if len(train_loss) > w:
        loss_smooth = np.convolve(train_loss, np.ones(w)/w, mode='valid')
        steps_smooth = steps[len(steps)-len(loss_smooth):]
    else:
        loss_smooth, steps_smooth = train_loss, steps

    plt.figure(figsize=(15, 6))

    # --- Panel 1: loss ---
    plt.subplot(1, 2, 1)
    plt.plot(steps, train_loss, alpha=0.3, color='lightblue', label='Batch Loss')
    plt.plot(steps_smooth, loss_smooth, color='blue', label='Train Loss (Smooth)')
    if test_steps:
        plt.plot(test_steps, test_loss, 'ro-', label='Test Loss (Epoch)')
    plt.title('Loss (Negative SISNR) - Thấp hơn là Tốt hơn')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.legend(); plt.grid(True, alpha=0.3)

    # --- Panel 2: quality (dB) ---
    plt.subplot(1, 2, 2)
    # The loss is negative SI-SNR, so flipping the sign gives SI-SNR in dB.
    train_sisnr = [-l for l in loss_smooth]
    plt.plot(steps_smooth, train_sisnr, color='green', label='Train SISNR')
    if test_steps:
        test_sisnr = [-l for l in test_loss]
        plt.plot(test_steps, test_sisnr, 'ro-', label='Test SISNR')
    plt.title('Quality (SISNR dB) - Cao hơn là Tốt hơn')
    plt.xlabel('Steps')
    plt.ylabel('dB')
    plt.legend(); plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path)
    print(f"Đã cập nhật biểu đồ tại: {save_path}")
    # Close the figure so repeated calls during training do not accumulate figures.
    plt.close()

### Training

In [None]:
def set_seed(seed=42):
    """Seed every random number generator used here so that runs are reproducible.

    Seeds Python's ``random``, NumPy and PyTorch (plus CUDA when available), and
    switches cuDNN to deterministic mode with autotuning turned off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f"Random Seed đã được cố định: {seed}")

In [None]:
def _evaluate_loss(model, loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over all of ``loader``, without gradients.

    The model is put in eval mode for the pass and switched back to train mode
    before returning.
    """
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for t_mic, t_ref, t_clean in loader:
            t_mic, t_ref, t_clean = t_mic.to(device), t_ref.to(device), t_clean.to(device)
            t_est = model(t_mic, t_ref)
            running_loss += criterion(t_est, t_clean).item()
    model.train()
    return running_loss / len(loader)


def main():
    """Train ``WaveformAEC`` with the negative SI-SNR loss.

    Every ``TEST_INTERVAL`` steps the whole test set is evaluated, the chart is
    redrawn and the best weights so far are written to ``best_aec_model.pth``.
    At the end of each epoch the current weights go to ``checkpoint_last.pth``.
    """
    SEED = 42
    set_seed(SEED)
    # --- CONFIG ---
    TRAIN_DATA = "PandaLT/Neural-AEC-No-Noise"
    TEST_DATA  = "PandaLT/Neural-AEC-Test-No-Noise"

    BATCH_SIZE = 16
    LR = 1e-5
    EPOCHS = 1

    LOG_INTERVAL = 50
    TEST_INTERVAL = 50

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {DEVICE}")

    # --- LOAD DATA ---
    train_ds = HuggingFaceAECDataset(TRAIN_DATA, split="train")
    test_ds  = HuggingFaceAECDataset(TEST_DATA, split="train")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # --- SETUP ---
    model = WaveformAEC().to(DEVICE)
    count_parameters(model)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = NegSISNRLoss()

    # --- HISTORY ---
    step_hist = {'steps': [], 'loss': []}
    # These key names are the ones plot_metrics reads for the test curve.
    test_hist = {'test_steps': [], 'test_loss': []}

    global_step = 0
    best_loss = float('inf')

    # --- LOOP ---
    for epoch in range(EPOCHS):
        model.train()  # make sure we are in training mode
        print(f"\n>>> Epoch {epoch+1}/{EPOCHS}")

        for i, (mic, ref, clean) in enumerate(train_loader):
            # 1. Training step
            mic, ref, clean = mic.to(DEVICE), ref.to(DEVICE), clean.to(DEVICE)

            optimizer.zero_grad()
            est = model(mic, ref)
            loss = criterion(est, clean)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            global_step += 1

            # 2. Log the training loss every LOG_INTERVAL steps
            if global_step % LOG_INTERVAL == 0:
                step_hist['steps'].append(global_step)
                step_hist['loss'].append(loss.item())
                print(f"  Step {global_step} | Train Loss: {loss.item():.4f}")

            # 3. Evaluate on the full test set every TEST_INTERVAL steps
            if global_step % TEST_INTERVAL == 0:
                print(f"  [Testing at step {global_step}...] ", end="")

                avg_test_loss = _evaluate_loss(model, test_loader, criterion, DEVICE)
                print(f"Test Loss: {avg_test_loss:.4f} | SISNR: {-avg_test_loss:.2f}dB")

                # Record the test result
                test_hist['test_steps'].append(global_step)
                test_hist['test_loss'].append(avg_test_loss)

                # Keep the weights with the lowest test loss seen so far
                if avg_test_loss < best_loss:
                    best_loss = avg_test_loss
                    torch.save(model.state_dict(), "best_aec_model.pth")
                    print(" @#$%^& Saved Best Model!")

                plot_metrics(step_hist, test_hist)

        # Backup checkpoint at the end of every epoch
        torch.save(model.state_dict(), "checkpoint_last.pth")