In [27]:
import torch
import torch.nn as nn
import torchaudio
import glob
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import wandb
from torch.utils.data import Dataset, DataLoader

In [28]:
# Authenticate with Weights & Biases without embedding the secret in the notebook.
# The key is read from the WANDB_API_KEY environment variable; when it is unset,
# key=None lets wandb fall back to its own credential lookup (netrc / prompt).
# NOTE: the API key that used to be hard-coded here has been exposed and should be rotated.
wandb.login(key=os.environ.get("WANDB_API_KEY"))




False

### Conformer Architecture

In [29]:
class CasualMHSA(nn.Module):
    """Multi-head self-attention in which a frame cannot attend to later frames.

    Args:
        d_model: Embedding size handed to ``nn.MultiheadAttention``.
        n_head: Number of attention heads.
        dropout: Dropout probability used inside the attention layer.
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)

    def forward(self, x):
        """Apply masked self-attention to a 3-D input and return the attended tensor."""
        # Unpacking three values keeps the requirement that the input is 3-D.
        _, seq_len, _ = x.shape
        # Additive mask: -inf strictly above the diagonal (future positions), 0 elsewhere.
        future_block = torch.full((seq_len, seq_len), float('-inf'), device=x.device).triu(diagonal=1)
        attended = self.mha(x, x, x, attn_mask=future_block)[0]
        return attended

class ConformerConvModule(nn.Module):
    """Convolution sub-module: pointwise conv -> GLU -> depthwise conv -> norm -> SiLU -> pointwise conv -> dropout.

    Args:
        d_model: Number of feature channels going in and coming out.
        kernel_size: Width of the depthwise convolution; it is padded with
            ``(kernel_size - 1) // 2`` samples on each side.
        dropout: Dropout probability applied after the last pointwise convolution.
    """

    def __init__(self, d_model, kernel_size=15, dropout=0.1):
        super().__init__()
        side_padding = (kernel_size - 1) // 2
        # The first pointwise conv doubles the channels; the GLU halves them again.
        self.pointwise_conv1 = nn.Conv1d(d_model, 2 * d_model, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(
            d_model, d_model, kernel_size, padding=side_padding, groups=d_model
        )
        # Despite the attribute name, this is a GroupNorm with a single group.
        self.batch_norm = nn.GroupNorm(num_groups=1, num_channels=d_model)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the convolution stack over ``x`` and return it with axes 1 and 2 swapped back."""
        # Conv1d convolves over the last axis, so the feature axis is moved to position 1 first.
        feats = x.transpose(1, 2)
        pipeline = (
            self.pointwise_conv1,
            self.glu,
            self.depthwise_conv,
            self.batch_norm,
            self.activation,
            self.pointwise_conv2,
            self.dropout,
        )
        for stage in pipeline:
            feats = stage(feats)
        return feats.transpose(1, 2)

class FeedForwardModule(nn.Module):
    """Two-layer feed-forward network with a SiLU non-linearity and dropout after each layer.

    Args:
        d_model: Input and output feature size.
        expansion_factor: Multiplier giving the hidden width (``d_model * expansion_factor``).
        dropout: Dropout probability used after the activation and after the output layer.
    """

    def __init__(self, d_model, expansion_factor=4, dropout=0.1):
        super().__init__()
        hidden_dim = d_model * expansion_factor
        self.layer1 = nn.Linear(d_model, hidden_dim)
        self.activation = nn.SiLU()
        self.dropout1 = nn.Dropout(dropout)
        self.layer2 = nn.Linear(hidden_dim, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Expand, activate, project back to ``d_model`` and return the result."""
        hidden = self.layer1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout1(hidden)
        projected = self.layer2(hidden)
        return self.dropout2(projected)

class ConformerBlock(nn.Module):
    """One residual block: half-step FFN -> convolution -> causal attention -> half-step FFN -> LayerNorm.

    Every sub-module sees a LayerNorm-ed copy of the running tensor and its
    output is added back as a residual; the two feed-forward branches are
    scaled by 0.5 before being added.

    Args:
        d_model: Feature size shared by all sub-modules.
        n_head: Number of attention heads.
        kernel_size: Depthwise kernel width of the convolution module.
        dropout: Dropout probability passed to every sub-module.
    """

    def __init__(self, d_model, n_head, kernel_size=15, dropout=0.1):
        super().__init__()
        self.ffn1 = FeedForwardModule(d_model, dropout=dropout)
        self.conv_module = ConformerConvModule(d_model, kernel_size=kernel_size, dropout=dropout)
        self.self_attn = CasualMHSA(d_model, n_head, dropout=dropout)
        self.ffn2 = FeedForwardModule(d_model, dropout=dropout)
        # One pre-norm per residual branch, plus a closing norm on the block output.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the four residual branches in order and normalise the result."""
        ffn_in = self.ffn1(self.norm1(x))
        hidden = x + 0.5 * ffn_in

        conv_out = self.conv_module(self.norm2(hidden))
        hidden = hidden + conv_out

        attn_out = self.self_attn(self.norm3(hidden))
        hidden = hidden + attn_out

        ffn_out = self.ffn2(self.norm4(hidden))
        hidden = hidden + 0.5 * ffn_out

        return self.final_norm(hidden)

### Neural AEC 

In [30]:
class AEC(nn.Module):
    """Echo canceller that predicts a complex mask for the microphone STFT.

    The real/imaginary parts of the microphone and reference spectra are
    flattened per frame, projected to ``d_model``, passed through a stack of
    ``ConformerBlock`` layers and mapped to a two-channel (real, imaginary)
    mask that is complex-multiplied with the microphone spectrum.

    Args:
        d_model: Width of the Conformer stack.
        n_fft: FFT size of the STFTs fed to ``forward``; fixes the number of
            frequency bins to ``n_fft // 2 + 1``.
        n_head: Number of attention heads per block.
        num_layers: Number of ``ConformerBlock`` layers.
        kernel_size: Depthwise kernel width inside each block.
    """

    def __init__(self, d_model=128, n_fft=512, n_head=8, num_layers=4, kernel_size=15):
        super().__init__()
        self.n_fft = n_fft
        self.n_freq = n_fft // 2 + 1
        # Real + imaginary parts of both the microphone and the reference spectrum.
        input_dim = self.n_freq * 4

        # Projects the concatenated spectra to the model width; forward() relies on it.
        self.input_proj = nn.Linear(input_dim, d_model)
        self.layers = nn.ModuleList([
            ConformerBlock(d_model, n_head, kernel_size=kernel_size)
            for _ in range(num_layers)
        ])
        self.mask_proj = nn.Linear(d_model, self.n_freq * 2)
        self.tanh = nn.Tanh()

    def forward(self, mic_stft, ref_stft):
        """Estimate the echo-suppressed spectrum.

        Args:
            mic_stft: Microphone STFT of shape ``(batch, n_freq, frames, 2)``
                with real and imaginary parts on the last axis.
            ref_stft: Reference STFT with the same shape as ``mic_stft``.

        Returns:
            Tensor shaped like ``mic_stft`` holding the masked spectrum.

        Raises:
            ValueError: If the two inputs differ in shape, the last axis is not
                of size 2, or the number of frequency bins does not match
                ``n_fft // 2 + 1``.
        """
        if mic_stft.dim() != 4 or mic_stft.shape[-1] != 2:
            raise ValueError(
                f"expected STFT of shape (batch, freq, frames, 2), got {tuple(mic_stft.shape)}"
            )
        if mic_stft.shape != ref_stft.shape:
            raise ValueError(
                f"mic and ref STFT shapes differ: {tuple(mic_stft.shape)} vs {tuple(ref_stft.shape)}"
            )
        batch, n_freq, n_frames, _ = mic_stft.shape
        if n_freq != self.n_freq:
            raise ValueError(
                f"model was built for {self.n_freq} frequency bins (n_fft={self.n_fft}), got {n_freq}"
            )

        # (batch, freq, frames, 2) -> (batch, frames, freq * 2): one feature vector per frame.
        mic_flat = mic_stft.permute(0, 2, 1, 3).reshape(batch, n_frames, n_freq * 2)
        ref_flat = ref_stft.permute(0, 2, 1, 3).reshape(batch, n_frames, n_freq * 2)

        x = torch.cat([mic_flat, ref_flat], dim=2)
        x = self.input_proj(x)

        for layer in self.layers:
            x = layer(x)

        # NOTE(review): self.tanh was constructed but never applied in the original code;
        # it is applied here on the assumption that a mask bounded to [-1, 1] was intended.
        mask = self.tanh(self.mask_proj(x))
        mask = mask.reshape(batch, n_frames, n_freq, 2).permute(0, 2, 1, 3)

        mic_real = mic_stft[..., 0]
        mic_imag = mic_stft[..., 1]
        mask_real = mask[..., 0]
        mask_imag = mask[..., 1]

        # Complex multiplication (a + ib)(c + id) = (ac - bd) + i(ad + bc).
        est_real = mic_real * mask_real - mic_imag * mask_imag
        est_imag = mic_real * mask_imag + mic_imag * mask_real
        est_stft = torch.stack([est_real, est_imag], dim=-1)

        return est_stft

### Prepare Dataset

In [31]:
class STFTDataset(Dataset):
    """Dataset of (microphone, reference, clean) STFT triplets read from three folders.

    Files are paired by position after sorting each folder's ``*.wav`` listing.
    The first 500 triplets form the validation split and the rest the training
    split. Every waveform is cropped or zero-padded to a fixed length so that
    the default ``DataLoader`` collation can stack the resulting spectrograms.

    Args:
        mic: Directory with the microphone recordings.
        ref: Directory with the reference recordings.
        clean: Directory with the target recordings.
        split: ``"val"`` selects the first 500 triplets; any other value
            selects the remainder.
        n_fft: FFT size of the STFT.
        hop_length: Hop size of the STFT.
        duration: Target clip length in seconds.

    Raises:
        ValueError: If the three folders hold different numbers of ``.wav`` files.
    """

    def __init__(self, mic, ref, clean, split="train", n_fft=512, hop_length=256, duration=10):
        mics = sorted(glob.glob(os.path.join(mic, "*.wav")))
        refs = sorted(glob.glob(os.path.join(ref, "*.wav")))
        cleans = sorted(glob.glob(os.path.join(clean, "*.wav")))

        # Pairing is purely positional, so unequal listings would silently mis-pair files.
        if not (len(mics) == len(refs) == len(cleans)):
            raise ValueError(
                f"file counts differ: mic={len(mics)}, ref={len(refs)}, clean={len(cleans)}"
            )

        if split == "val":
            self.mic_files = mics[:500]
            self.ref_files = refs[:500]
            self.clean_files = cleans[:500]
        else:
            self.mic_files = mics[500:]
            self.ref_files = refs[500:]
            self.clean_files = cleans[500:]

        self.n_fft, self.hop_length = n_fft, hop_length
        # NOTE(review): 16000 presumes 16 kHz audio; the sample rate returned by
        # torchaudio.load is not checked - confirm against the dataset.
        self.max_len = int(16000 * duration)
        # Built once instead of on every __getitem__ call.
        self.window = torch.hann_window(n_fft)
        print(f"{split.upper()} Dataset: {len(self.mic_files)} samples.")

    def __len__(self): return len(self.mic_files)

    def _fix_length(self, audio):
        """Crop or right-pad ``audio`` with zeros so its last axis has ``max_len`` samples."""
        n_samples = audio.shape[-1]
        if n_samples > self.max_len:
            return audio[..., :self.max_len]
        if n_samples < self.max_len:
            return torch.nn.functional.pad(audio, (0, self.max_len - n_samples))
        return audio

    def _stft(self, audio):
        """Return the STFT of ``audio`` with real/imaginary parts on a trailing axis of size 2."""
        spec = torch.stft(audio, self.n_fft, self.hop_length, window=self.window, return_complex=True)
        # view_as_real gives the same layout as the deprecated return_complex=False.
        return torch.view_as_real(spec).squeeze(0)

    def __getitem__(self, idx):
        """Load triplet ``idx`` and return ``(mic, ref, clean)`` spectrograms of equal shape."""
        mic, _ = torchaudio.load(self.mic_files[idx])
        ref, _ = torchaudio.load(self.ref_files[idx])
        clean, _ = torchaudio.load(self.clean_files[idx])

        # Bring all three signals to the same fixed length before the STFT.
        mic, ref, clean = (self._fix_length(audio) for audio in (mic, ref, clean))

        return self._stft(mic), self._stft(ref), self._stft(clean)

### Visualize

In [32]:
def stft_to_mag_db(stft_tensor):
    """Convert a real/imag STFT (last axis of size 2) into a magnitude spectrogram in dB.

    A small constant (1e-9) is added under the square root and again before the
    logarithm so that silent bins stay finite.
    """
    real_part = stft_tensor[..., 0]
    imag_part = stft_tensor[..., 1]
    magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2 + 1e-9)
    return 20 * torch.log10(magnitude + 1e-9)

def visual_check(model, device, epoch, config):
    """Log spectrograms of a fixed example (mic, reference, target, estimate) to W&B.

    Args:
        model: Network called as ``model(mic, ref)`` on real/imag STFT tensors.
        device: Device the STFTs are moved to before inference.
        epoch: Epoch number shown in the figure title and logged under ``"epoch"``.
        config: Mapping providing ``n_fft``, ``hop_length`` and the three
            ``fixed_*_path`` entries.

    The model is switched to eval mode for the forward pass and restored to
    whatever mode it was in on entry, even if loading or plotting fails.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            window = torch.hann_window(config['n_fft'])

            def load_fixed(p):
                """Read one wav file and return its STFT with a trailing real/imag axis."""
                wav, _ = torchaudio.load(p)
                spec = torch.stft(wav, config['n_fft'], config['hop_length'],
                                  window=window, return_complex=True)
                # Same layout as the deprecated return_complex=False.
                return torch.view_as_real(spec).to(device)

            mic = load_fixed(config['fixed_mic_path'])
            ref = load_fixed(config['fixed_ref_path'])
            clean = load_fixed(config['fixed_clean_path'])

            # The clips may differ in length; the model concatenates mic and ref per
            # frame, so trim everything to the shortest frame count (axis 2).
            n_frames = min(mic.shape[2], ref.shape[2], clean.shape[2])
            mic, ref, clean = (spec[:, :, :n_frames] for spec in (mic, ref, clean))

            est = model(mic, ref)

            imgs = [stft_to_mag_db(mic[0]), stft_to_mag_db(ref[0]),
                    stft_to_mag_db(clean[0]), stft_to_mag_db(est[0])]
            titles = ["Microphone", "Reference", "Clean (Target)", f"Estimate (Epoch {epoch})"]

            fig, axs = plt.subplots(4, 1, figsize=(10, 12), facecolor='black')
            try:
                for ax, img, title in zip(axs, imgs, titles):
                    ax.imshow(img.cpu().numpy(), origin='lower', aspect='auto', cmap='magma')
                    ax.set_title(title, color='white', fontsize=10)
                    ax.axis('off')  # Hide the x and y axes

                plt.tight_layout()
                # No explicit step: other wandb.log calls auto-increment the run's step,
                # and W&B drops records whose step is lower than the current one.
                wandb.log({"Spectrogram_Visual": wandb.Image(fig), "epoch": epoch})
            finally:
                # Always release the figure, even when logging fails.
                plt.close(fig)
    finally:
        model.train(was_training)

### Training Loop

In [33]:
def main():
    """Train the AEC model, validate after every epoch and log to Weights & Biases.

    The STFT settings in ``CONFIG`` are forwarded to the datasets and the model
    so that training, validation and ``visual_check`` all use the same
    resolution. The W&B run is closed even when training raises.
    """
    CONFIG = {
        "n_fft": 400, 
        "hop_length": 200, 
        "d_model": 128, 
        "num_layers": 4,
        "batch_size": 16,
        "lr": 1e-5, 
        "epochs": 1, 
        "log_step": 10, 
        "test_step": 125,
        "data_path": {
            "mic": "D:/AEC-Challenge/datasets/synthetic/nearend_mic_signal",
            "ref": "D:/AEC-Challenge/datasets/synthetic/farend_speech",
            "clean": "D:/AEC-Challenge/datasets/synthetic/nearend_speech"
        },
        "fixed_mic_path": "source/nearend_mic_fileid_1609.wav",
        "fixed_ref_path": "source/farend_speech_fileid_1609.wav",
        "fixed_clean_path": "source/nearend_speech_fileid_1609.wav"
    }

    wandb.init(project="AEC_STFT_Conformer", config=CONFIG)
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Use the configured STFT resolution everywhere instead of the class defaults.
        stft_args = {"n_fft": CONFIG['n_fft'], "hop_length": CONFIG['hop_length']}

        # Load Train & Val
        train_loader = DataLoader(
            STFTDataset(**CONFIG['data_path'], split="train", **stft_args),
            batch_size=CONFIG['batch_size'], shuffle=True)
        val_loader = DataLoader(
            STFTDataset(**CONFIG['data_path'], split="val", **stft_args),
            batch_size=CONFIG['batch_size'], shuffle=False)

        model = AEC(d_model=CONFIG['d_model'], n_fft=CONFIG['n_fft'],
                    num_layers=CONFIG['num_layers']).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
        loss_fn = nn.MSELoss()

        for epoch in range(CONFIG['epochs']):
            model.train()
            train_loss_epoch = 0.0
            for i, (mic, ref, clean) in enumerate(train_loader):
                mic, ref, clean = mic.to(device), ref.to(device), clean.to(device)
                est = model(mic, ref)

                # Hybrid loss: MSE on real/imag parts plus MSE on magnitudes.
                loss_complex = loss_fn(est, clean)
                loss_mag = loss_fn(torch.abs(torch.view_as_complex(est)),
                                   torch.abs(torch.view_as_complex(clean)))
                loss = loss_complex + loss_mag

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss_epoch += loss.item()
                if i % CONFIG['log_step'] == 0:
                    wandb.log({"train_loss": loss.item()})

            n_train_batches = len(train_loader)
            avg_train_loss = train_loss_epoch / n_train_batches if n_train_batches else float("nan")

            # Validation phase (complex MSE only, as before).
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for mic, ref, clean in val_loader:
                    mic, ref, clean = mic.to(device), ref.to(device), clean.to(device)
                    val_loss += loss_fn(model(mic, ref), clean).item()

            # An empty loader yields NaN rather than a ZeroDivisionError.
            n_val_batches = len(val_loader)
            avg_val_loss = val_loss / n_val_batches if n_val_batches else float("nan")
            wandb.log({"train_loss_epoch": avg_train_loss,
                       "val_loss": avg_val_loss,
                       "epoch": epoch + 1})
            print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")

            # NOTE(review): test_step is compared against the epoch count, so with
            # epochs=1 and test_step=125 this branch never runs - confirm the intent.
            if (epoch + 1) % CONFIG['test_step'] == 0:
                visual_check(model, device, epoch + 1, CONFIG)
                torch.save(model.state_dict(), f"aec_epoch_{epoch+1}.pth")
    except BaseException:
        # Mark the run as failed instead of leaving it open, then re-raise.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

In [36]:
# Run the training entry point defined earlier in the notebook.
main()

TRAIN Dataset: 9500 samples.
VAL Dataset: 500 samples.


RuntimeError: Could not load libtorchcodec. Likely causes:
          1. FFmpeg is not properly installed in your environment. We support
             versions 4, 5, 6, 7, and 8. On Windows, ensure you've installed
             the "full-shared" version which ships DLLs.
          2. The PyTorch version (2.9.1+cu130) is not compatible with
             this version of TorchCodec. Refer to the version compatibility
             table:
             https://github.com/pytorch/torchcodec?tab=readme-ov-file#installing-torchcodec.
          3. Another runtime dependency; see exceptions below.
        The following exceptions were raised as we tried to load libtorchcodec:
        
[start of libtorchcodec loading traceback]
FFmpeg version 8: Could not load this library: C:\Users\Admin.ADMIN-PC\AppData\Local\Programs\Python\Python313\Lib\site-packages\torchcodec\libtorchcodec_core8.dll
FFmpeg version 7: Could not load this library: C:\Users\Admin.ADMIN-PC\AppData\Local\Programs\Python\Python313\Lib\site-packages\torchcodec\libtorchcodec_core7.dll
FFmpeg version 6: Could not load this library: C:\Users\Admin.ADMIN-PC\AppData\Local\Programs\Python\Python313\Lib\site-packages\torchcodec\libtorchcodec_core6.dll
FFmpeg version 5: Could not load this library: C:\Users\Admin.ADMIN-PC\AppData\Local\Programs\Python\Python313\Lib\site-packages\torchcodec\libtorchcodec_core5.dll
FFmpeg version 4: Could not load this library: C:\Users\Admin.ADMIN-PC\AppData\Local\Programs\Python\Python313\Lib\site-packages\torchcodec\libtorchcodec_core4.dll
[end of libtorchcodec loading traceback].



In [None]:
"""
import torch
print(f"PyTorch Version: {torch.__version__}")
print(f"Is GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Version built with PyTorch: {torch.version.cuda}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
"""


PyTorch Version: 2.9.1+cu130
Is GPU available: True
CUDA Version built with PyTorch: 13.0
GPU Name: NVIDIA GeForce RTX 3060
