In [None]:
import torch
import torch.nn as nn
import random

In [2]:
class CasualMHSA(nn.Module):
    """Multi-head self-attention in which a time step cannot attend to later steps.

    Wraps a batch-first ``nn.MultiheadAttention`` and supplies an additive
    mask on every call.
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        """Build the wrapped attention layer.

        Args:
            d_model: Embedding size passed to ``nn.MultiheadAttention``.
            n_head: Number of attention heads.
            dropout: Dropout probability forwarded to the attention layer.
        """
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_head,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, x):
        """Run masked self-attention over a 3-D input.

        Args:
            x: Tensor whose shape unpacks into three dims; dim 1 is used as
                the sequence length for the mask.

        Returns:
            The attention output (attention weights are discarded).
        """
        _, seq_len, _ = x.shape
        # Additive mask: -inf strictly above the diagonal, 0 on and below it.
        blocked = torch.full((seq_len, seq_len), float('-inf'), device=x.device)
        future_mask = blocked.triu(diagonal=1)
        attended, _weights = self.mha(x, x, x, attn_mask=future_mask)
        return attended

class ConformerConvModule(nn.Module):
    """Convolution sub-block of the Conformer layer.

    Pipeline (applied along dim 2 after a transpose):
    pointwise conv (doubling channels) -> GLU -> depthwise conv ->
    single-group GroupNorm -> SiLU -> pointwise conv -> dropout.
    """

    def __init__(self, d_model, kernel_size=15, dropout=0.1):
        """Create the layers.

        Args:
            d_model: Channel count going in and out of the module.
            kernel_size: Width of the depthwise convolution kernel.
            dropout: Dropout probability applied at the end.
        """
        super().__init__()
        side_padding = (kernel_size - 1) // 2
        self.pointwise_conv1 = nn.Conv1d(d_model, 2 * d_model, kernel_size=1)
        # GLU halves the channel dim again, back to d_model.
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(
            d_model,
            d_model,
            kernel_size,
            padding=side_padding,
            groups=d_model,
        )
        # Despite the attribute name, this is a GroupNorm with one group.
        self.batch_norm = nn.GroupNorm(num_groups=1, num_channels=d_model)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the convolution pipeline.

        Args:
            x: Input tensor; dims 1 and 2 are swapped before the convolutions
                and swapped back afterwards.

        Returns:
            Tensor with dims 1 and 2 in the same order as the input.
        """
        feats = x.transpose(1, 2)
        feats = self.pointwise_conv1(feats)
        feats = self.glu(feats)
        feats = self.depthwise_conv(feats)
        feats = self.batch_norm(feats)
        feats = self.activation(feats)
        feats = self.pointwise_conv2(feats)
        feats = self.dropout(feats)
        return feats.transpose(1, 2)

class FeedForwardModule(nn.Module):
    """Two-layer feed-forward block: Linear -> SiLU -> Dropout -> Linear -> Dropout."""

    def __init__(self, d_model, expansion_factor=4, dropout=0.1):
        """Create the layers.

        Args:
            d_model: Input and output feature size.
            expansion_factor: Multiplier giving the hidden width.
            dropout: Probability used by both dropout layers.
        """
        super().__init__()
        hidden_dim = d_model * expansion_factor
        self.layer1 = nn.Linear(d_model, hidden_dim)
        self.activation = nn.SiLU()
        self.dropout1 = nn.Dropout(dropout)
        self.layer2 = nn.Linear(hidden_dim, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Expand, activate, project back; dropout after the activation and after the projection."""
        hidden = self.layer1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout1(hidden)
        projected = self.layer2(hidden)
        return self.dropout2(projected)

class ConformerBlock(nn.Module):
    """One Conformer-style layer built from pre-norm residual branches.

    Branch order in ``forward``: half-weighted feed-forward, convolution
    module, masked self-attention, half-weighted feed-forward, then a final
    LayerNorm.
    """

    def __init__(self, d_model, n_head, kernel_size=15, dropout=0.1):
        """Create the four sub-modules and their LayerNorms.

        Args:
            d_model: Feature size shared by every sub-module.
            n_head: Number of attention heads.
            kernel_size: Depthwise kernel width for the convolution module.
            dropout: Dropout probability forwarded to every sub-module.
        """
        super().__init__()
        self.ffn1 = FeedForwardModule(d_model, dropout=dropout)
        self.conv_module = ConformerConvModule(
            d_model, kernel_size=kernel_size, dropout=dropout
        )
        self.self_attn = CasualMHSA(d_model, n_head, dropout=dropout)
        self.ffn2 = FeedForwardModule(d_model, dropout=dropout)
        # One LayerNorm in front of each residual branch.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the four residual branches in sequence, then the final norm."""
        first_ffn = self.ffn1(self.norm1(x))
        x = x + 0.5 * first_ffn

        conv_out = self.conv_module(self.norm2(x))
        x = x + conv_out

        attn_out = self.self_attn(self.norm3(x))
        x = x + attn_out

        second_ffn = self.ffn2(self.norm4(x))
        x = x + 0.5 * second_ffn

        return self.final_norm(x)

class AECModel(nn.Module):
    """Mask-estimation model operating on microphone and reference STFTs.

    The real/imag parts of both spectrograms are flattened per frame,
    projected to ``d_model``, passed through a stack of ``ConformerBlock``
    layers, and projected to a two-channel mask that is multiplied with the
    microphone STFT as a complex number.
    """

    def __init__(self, d_model=128, n_fft=512, n_head=8, num_layers=4):
        """Create the projections and the Conformer stack.

        Args:
            d_model: Hidden feature size.
            n_fft: FFT size; the model expects ``n_fft // 2 + 1`` frequency bins.
            n_head: Attention heads per Conformer block.
            num_layers: Number of Conformer blocks.
        """
        super().__init__()
        self.n_fft = n_fft
        self.n_freq = n_fft // 2 + 1
        # Per frame: (real, imag) for each bin of both mic and ref.
        features_per_frame = 4 * self.n_freq
        self.input_proj = nn.Linear(features_per_frame, d_model)
        blocks = [ConformerBlock(d_model, n_head) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.mask_proj = nn.Linear(d_model, 2 * self.n_freq)

    def forward(self, mic_stft, ref_stft):
        """Estimate the masked microphone spectrogram.

        Args:
            mic_stft: 4-D tensor unpacked as (batch, freq, time, channel);
                channel 0 is used as the real part and channel 1 as the
                imaginary part.
            ref_stft: Tensor reshaped the same way as ``mic_stft``.

        Returns:
            Tensor with real and imaginary estimates stacked on the last dim.
        """
        batch, n_bins, n_frames, _ = mic_stft.shape

        def _frames_first(spec):
            # (batch, freq, time, 2) -> (batch, time, freq * 2)
            return spec.permute(0, 2, 1, 3).reshape(batch, n_frames, n_bins * 2)

        features = torch.cat(
            [_frames_first(mic_stft), _frames_first(ref_stft)], dim=2
        )
        hidden = self.input_proj(features)
        for block in self.layers:
            hidden = block(hidden)

        mask = self.mask_proj(hidden).view(batch, n_frames, n_bins, 2)
        mask = mask.permute(0, 2, 1, 3)

        mic_re = mic_stft[..., 0]
        mic_im = mic_stft[..., 1]
        mask_re = mask[..., 0]
        mask_im = mask[..., 1]

        # Complex product (mic_re + j*mic_im) * (mask_re + j*mask_im).
        est_re = mic_re * mask_re - mic_im * mask_im
        est_im = mic_re * mask_im + mic_im * mask_re
        return torch.stack([est_re, est_im], dim=-1)

In [None]:
class STFT_Dataset():
    """Map-style dataset returning aligned STFTs and frame-level VAD labels.

    Each item of the wrapped dataset is read through the keys
    ``item['mic']['array']``, ``item['ref']['array']``,
    ``item['clean']['array']`` and ``item['vad_label']``.  The three
    waveforms are cropped/padded to a fixed length using the SAME offset so
    that microphone, reference, clean target and VAD labels stay
    time-aligned.
    """

    def __init__(self, hf_dataset, n_fft=512, win_length=320, hop_length=160, duration=10,
                 sample_rate=16000):
        """Store the STFT configuration.

        Args:
            hf_dataset: Indexable collection of items (see class docstring).
            n_fft: FFT size passed to ``torch.stft``.
            win_length: Length of the Hann analysis window.
            hop_length: Hop between successive frames, in samples.
            duration: Length of each training segment, in seconds.
            sample_rate: Samples per second used to convert ``duration``
                into a sample count.
        """
        self.dataset = hf_dataset
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        # Legacy alias: earlier revisions stored the hop under this misspelled name.
        self.hoplength = hop_length
        self.sample_rate = sample_rate
        self.max_len = int(sample_rate * duration)
        self.window = torch.hann_window(self.win_length)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def process_len(self, wav_nparray, start=None):
        """Crop or zero-pad a waveform to exactly ``self.max_len`` samples.

        Args:
            wav_nparray: Waveform as a 1-D or 2-D array-like; a 1-D input
                gets a leading dim of size 1.
            start: Crop offset in samples.  When ``None`` a random offset is
                drawn if the waveform is longer than ``max_len``, otherwise
                0 is used.  Pass the offset returned for one signal to crop
                its companion signals at the same position.

        Returns:
            ``(segment, start)`` where ``segment`` has ``max_len`` samples on
            dim 1 and ``start`` is the offset that was applied.

        Raises:
            ValueError: If an explicit ``start`` is negative.
        """
        wav = torch.as_tensor(wav_nparray, dtype=torch.float32)
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)

        total = wav.shape[1]
        if start is None:
            surplus = total - self.max_len
            start = random.randint(0, surplus) if surplus > 0 else 0
        elif start < 0:
            raise ValueError(f"start must be non-negative, got {start}")

        segment = wav[:, start:start + self.max_len]
        shortfall = self.max_len - segment.shape[1]
        if shortfall > 0:
            # Zero-pad the tail so every segment has the same length.
            segment = torch.nn.functional.pad(segment, (0, shortfall))
        return segment, start

    def get_stft(self, wav):
        """Return the STFT of ``wav`` with real/imag parts on the last dim.

        The leading dim is squeezed when it has size 1, so a single-channel
        waveform yields a (freq, frames, 2) tensor.
        """
        stft_complex = torch.stft(wav, n_fft=self.n_fft, hop_length=self.hop_length,
                                  win_length=self.win_length, window=self.window,
                                  center=True, return_complex=True)
        return torch.view_as_real(stft_complex).squeeze(0)

    def __getitem__(self, idx):
        """Return ``(mic_stft, ref_stft, clean_stft, vad_label)`` for one item."""
        item = self.dataset[idx]

        # Crop the microphone signal first, then reuse its offset so the
        # reference, clean target and VAD labels cover the same time span.
        mic_wav, start = self.process_len(item['mic']['array'])
        ref_wav, _ = self.process_len(item['ref']['array'], start=start)
        clean_wav, _ = self.process_len(item['clean']['array'], start=start)

        mic_stft = self.get_stft(mic_wav)
        ref_stft = self.get_stft(ref_wav)
        clean_stft = self.get_stft(clean_wav)

        # VAD labels come from the dataset as a sequence; the slicing below
        # assumes one label per hop (10 ms at the default settings).
        vad_full = torch.as_tensor(item['vad_label'], dtype=torch.long)

        num_frames = mic_stft.shape[1]  # number of STFT frames T
        start_frame = start // self.hop_length
        vad_label = vad_full[start_frame:start_frame + num_frames]

        # Pad with zeros when the label sequence ends before the segment
        # does (rounding, or audio that was zero-padded).
        missing = num_frames - vad_label.shape[0]
        if missing > 0:
            vad_label = torch.nn.functional.pad(vad_label, (0, missing))

        return mic_stft, ref_stft, clean_stft, vad_label


In [None]:
def get_vad_predict(vad_model, wav_est, sr=16000):
    """Placeholder for VAD inference.

    All three arguments are currently ignored and an empty string is
    returned.
    """
    return ""

In [None]:
def multi_task_loss(est, target, vad_logits, vad_labels, beta=0.1):
    """Combine a spectral L1 loss with a weighted VAD cross-entropy loss.

    Args:
        est: Estimated spectrogram; index 0/1 of the last dim are read as
            the real/imaginary parts.
        target: Reference spectrogram, indexed the same way as ``est``.
        vad_logits: VAD scores; dims 1 and 2 are swapped before the
            cross-entropy (presumably (batch, frames, classes) -- confirm
            against the caller).
        vad_labels: Integer class targets for the cross-entropy.
        beta: Weight applied to the VAD term.

    Returns:
        ``(total_loss, loss_aec, loss_vad)``.
    """

    def _magnitude(spec):
        # Small epsilon inside the square root, exactly as in the loss definition.
        real_part, imag_part = spec[..., 0], spec[..., 1]
        return torch.sqrt(real_part ** 2 + imag_part ** 2 + 1e-9)

    # AEC term: L1 on magnitudes plus L1 on the raw real/imag values.
    l1 = nn.L1Loss()
    magnitude_term = l1(_magnitude(est), _magnitude(target))
    complex_term = l1(est, target)
    loss_aec = magnitude_term + complex_term

    # VAD term: CrossEntropyLoss expects the class dim at position 1.
    cross_entropy = nn.CrossEntropyLoss()
    loss_vad = cross_entropy(vad_logits.transpose(1, 2), vad_labels)

    return loss_aec + beta * loss_vad, loss_aec, loss_vad