Directly trains a binary mask per time frame (not per frequency bin) by appending the DoA vector to the flattened STFT magnitudes of all microphones at every time frame.

Uses a surrogate gradient (straight-through estimator) so that backpropagation sees continuous gradients while the forward pass still outputs binary masks.

In [None]:
import torch
import torchaudio
import soundfile as sf
from torchaudio.transforms import Spectrogram, InverseSpectrogram

def normalize_mag(mag, threshold=1e-4):
    """Rescale every time frame of a magnitude spectrogram to a common L2 norm.

    Each frame (a vector along the last axis) whose L2 norm exceeds
    ``threshold`` is rescaled so that its norm equals the largest frame norm
    found along the time axis. Frames whose norm is at or below ``threshold``
    are not divided by their own (near-zero) norm; they are simply multiplied
    by the maximum norm.

    Args:
        mag: Magnitudes shaped ``[..., T, F]`` (time on the second-to-last
            axis, frequency on the last), e.g. ``[B, T, F]`` or ``[T, F]``.
        threshold: Frame norms at or below this value are treated as silent.

    Returns:
        A tensor with the same shape as ``mag``.
    """
    frame_norms = torch.norm(mag, dim=-1, keepdim=True)        # [..., T, 1]
    # Reduce over the time axis counted from the end. With a hard-coded
    # ``dim=1`` an unbatched [T, F] input would reduce over the size-1 keepdim
    # axis instead, which silently turns the whole normalization into a no-op.
    max_norm = frame_norms.max(dim=-2, keepdim=True).values    # [..., 1, 1]
    safe_norms = torch.where(
        frame_norms > threshold, frame_norms, torch.ones_like(frame_norms)
    )
    return mag * (max_norm / safe_norms)

def normalize_audio_from_file(filepath, output_path, n_fft=512, hop_length=128):
    """Frame-normalize an audio file in the STFT domain and write the result.

    The file is loaded, down-mixed to mono, transformed with an STFT, its
    magnitudes are passed through ``normalize_mag`` (the phase is left
    untouched), and the result is inverted back to a waveform that is written
    to ``output_path`` at the input's sample rate.

    Args:
        filepath: Path of the audio file to read.
        output_path: Path of the audio file to write.
        n_fft: FFT size used for both the STFT and its inverse.
        hop_length: Hop size used for both the STFT and its inverse.
    """
    # Load audio and average the channels down to mono.
    waveform, sr = torchaudio.load(filepath)           # [C, T]
    waveform = waveform.mean(dim=0, keepdim=True)      # [1, T]
    n_samples = waveform.shape[-1]

    # STFT and ISTFT setup (power=None keeps the complex spectrogram).
    spec = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)
    istft = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)

    # Forward STFT.
    X = spec(waveform)                                 # [1, F, T_stft]
    mag = X.abs()
    phase = X.angle()

    # normalize_mag works on [B, T, F]; transpose in and back out.
    norm_mag = normalize_mag(mag.permute(0, 2, 1)).permute(0, 2, 1)  # [1, F, T_stft]

    # Recombine the normalized magnitude with the original phase.
    X_norm = norm_mag * torch.exp(1j * phase)

    # Pass the original length so the inverse transform does not truncate the
    # signal to a multiple of hop_length.
    output_waveform = istft(X_norm, length=n_samples).squeeze(0)     # [T]

    # Save output.
    sf.write(output_path, output_waveform.cpu().numpy(), sr)

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import soundfile as sf
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Spectrogram, InverseSpectrogram

class PerFrameGainNet(nn.Module):
    """MLP that predicts one binary gain per STFT time frame.

    For every frame the magnitudes of all channels are flattened, the 3-D DoA
    vector is appended, and a three-layer MLP with a sigmoid output produces a
    value in [0, 1]. The forward pass thresholds that value at 0.5 while the
    backward pass uses the gradient of the un-thresholded value
    (straight-through estimator).

    Args:
        n_freq_bins: Number of frequency bins per channel.
        hidden_dim: Width of the two hidden layers.
        n_channels: Number of input channels (defaults to 4).
    """

    def __init__(self, n_freq_bins: int, hidden_dim: int = 256, n_channels: int = 4):
        super().__init__()
        self.n_channels = n_channels
        # input_dim = n_channels * n_freq_bins (flattened magnitudes) + 3 (DoA)
        self.input_dim = n_channels * n_freq_bins + 3
        self.net = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # ensures g_t in [0, 1]
        )

    def forward(self, mag4ch: torch.Tensor, doa: torch.Tensor):
        """Predict per-frame binary gains.

        Args:
            mag4ch: Magnitudes shaped ``[B, T, F, C]``.
            doa: Direction-of-arrival vectors shaped ``[B, 3]``.

        Returns:
            Tensor shaped ``[B, T]`` whose forward values are the thresholded
            (0/1) gains, with gradients flowing through the sigmoid output.

        Raises:
            ValueError: If ``F * C + 3`` does not match the network input size.
        """
        B, T, n_freq, n_ch = mag4ch.shape
        if n_freq * n_ch + 3 != self.input_dim:
            raise ValueError(
                f"expected {self.input_dim - 3} magnitude features per frame, "
                f"got {n_freq} bins x {n_ch} channels"
            )

        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # permute, are accepted as well.
        mag_flat = mag4ch.reshape(B, T, n_freq * n_ch)     # [B, T, C*F]
        doa_exp = doa.unsqueeze(1).expand(B, T, 3)         # [B, T, 3]
        inp = torch.cat([mag_flat, doa_exp], dim=-1)       # [B, T, C*F+3]
        g = self.net(inp.reshape(B * T, -1)).view(B, T)    # [B, T]

        # Binary STE: threshold at 0.5 in the forward pass, but let the
        # gradient flow through the original g in the backward pass.
        g_binary = (g > 0.5).to(g.dtype)
        return g + (g_binary - g).detach()


class STFTHelper:
    """Thin wrapper around torchaudio's Spectrogram / InverseSpectrogram.

    Args:
        n_fft: FFT size for the forward and inverse transforms.
        hop_length: Hop size for the forward and inverse transforms.
        power: Forwarded to ``Spectrogram``; ``None`` keeps the complex STFT.
        pad_mode: Padding mode forwarded to ``Spectrogram``.
        device: Device the two transforms are moved to.
    """

    def __init__(self, n_fft=512, hop_length=128, power=None, pad_mode='reflect', device='cpu'):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.spectrogram = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power, pad_mode=pad_mode).to(device)
        self.istft = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length).to(device)

    def multichannel_stft(self, waveform_4ch: torch.Tensor):
        """Compute the STFT of a multichannel waveform.

        Args:
            waveform_4ch: Waveform shaped ``[B, T, C]``.

        Returns:
            ``(mag, phase)``, both shaped ``[B, T_stft, F, C]``.
        """
        B, T, C = waveform_4ch.shape
        # Fold the channels into the batch axis so a single STFT call handles all.
        x_reshaped = waveform_4ch.permute(0, 2, 1).contiguous()  # [B, C, T]
        x_reshaped = x_reshaped.view(B * C, T)                   # [B*C, T]
        X = self.spectrogram(x_reshaped)                         # [B*C, F, T_stft]
        FreqBins, Tstft = X.shape[-2], X.shape[-1]
        X = X.view(B, C, FreqBins, Tstft)                        # [B, C, F, T_stft]
        X = X.permute(0, 3, 2, 1).contiguous()                   # [B, T_stft, F, C]
        return X.abs(), X.angle()

    def mono_stft(self, waveform_1ch: torch.Tensor):
        """Return STFT magnitudes ``[B, T_stft, F]`` of a ``[B, T]`` waveform."""
        X = self.spectrogram(waveform_1ch)                       # [B, F, T_stft]
        X = X.permute(0, 2, 1).contiguous()                      # [B, T_stft, F]
        return X.abs()

    def _to_freq_major(self, spec: torch.Tensor):
        """Return ``spec`` as ``[B, F, T_stft]``, transposing ``[B, T_stft, F]`` input."""
        if spec.ndim == 3 and spec.shape[1] != self.n_fft // 2 + 1:
            # assume input is [B, T_stft, F]
            return spec.permute(0, 2, 1).contiguous()
        return spec

    def istft_reconstruct(self, mag: torch.Tensor, phase: torch.Tensor = None, length: int = None):
        """Invert a magnitude (and optional phase) spectrogram to a waveform.

        Args:
            mag: Magnitudes shaped ``[B, F, T_stft]`` or ``[B, T_stft, F]``; the
                layout is inferred by comparing axis 1 with ``n_fft // 2 + 1``.
            phase: Optional phase in either of the two layouts above. When
                omitted, zero phase is used.
            length: Optional number of output samples, forwarded to the
                inverse transform.

        Returns:
            The reconstructed waveform returned by the inverse transform.
        """
        mag = self._to_freq_major(mag)
        if phase is not None:
            # Bring the phase into the same layout as the magnitude; otherwise
            # a [B, T_stft, F] pair would fail to broadcast after mag is
            # transposed.
            phase = self._to_freq_major(phase)
            complex_spec = mag * torch.exp(1j * phase)
        else:
            # Zero phase: the inverse STFT needs a complex tensor, so attach an
            # explicit all-zero imaginary part instead of passing the real mag.
            complex_spec = torch.complex(mag, torch.zeros_like(mag))
        return self.istft(complex_spec, length=length)

# Loss Function
def stft_gain_loss(pred_gain: torch.Tensor, clean_mag: torch.Tensor, ref_mag: torch.Tensor):
    """Mean-squared error between the gained reference and the clean magnitudes.

    Args:
        pred_gain: Per-frame gains; a trailing axis is appended so they
            broadcast across the last axis of ``ref_mag``.
        clean_mag: Target magnitudes.
        ref_mag: Reference-channel magnitudes that the gains are applied to.

    Returns:
        Scalar MSE between ``pred_gain[..., None] * ref_mag`` and ``clean_mag``.
    """
    frame_gain = pred_gain[..., None]          # add a broadcastable frequency axis
    return F.mse_loss(frame_gain * ref_mag, clean_mag)

In [None]:
class FourChannelDataset(Dataset):
    """
    Expects a list of tuples (path_4ch_wav, path_clean_wav, doa_vector)
    Each 4ch wav has shape [T, 4]; clean wav has shape [T]

    Every item is returned as ``(mag4ch, mag_clean, doa)`` where ``mag4ch`` is
    ``[Tstft, F, 4]`` with its channel 0 frame-normalized, ``mag_clean`` is
    ``[Tstft, F]`` (left unnormalized) and ``doa`` is the DoA vector as a
    float32 tensor.
    """
    def __init__(self, file_list, stft_helper, device='cpu'):
        self.file_list = file_list
        self.stft_helper = stft_helper
        self.device = device

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        path_4ch, path_clean, doa_vec = self.file_list[idx]

        # Load audio
        audio_4ch, fs1 = sf.read(path_4ch)
        audio_clean, fs2 = sf.read(path_clean)
        # Explicit raise instead of assert: asserts are stripped under ``-O``.
        if fs1 != fs2:
            raise ValueError("Sampling rate mismatch")
        audio_4ch = torch.from_numpy(audio_4ch).float().to(self.device)
        audio_clean = torch.from_numpy(audio_clean).float().to(self.device)
        doa = torch.from_numpy(np.array(doa_vec, dtype=np.float32)).to(self.device)

        # Compute STFT magnitudes (a batch axis of size 1 is added for the helper)
        mag4ch, _ = self.stft_helper.multichannel_stft(audio_4ch.unsqueeze(0))  # [1, Tstft, F, 4]
        mag_clean = self.stft_helper.mono_stft(audio_clean.unsqueeze(0))        # [1, Tstft, F]

        mag4ch = mag4ch.squeeze(0)        # [Tstft, F, 4]
        mag_clean = mag_clean.squeeze(0)  # [Tstft, F]

        # Normalize only the reference (channel 0). normalize_mag is written
        # for [B, T, F] input, so add a batch axis around the call; passing the
        # bare [Tstft, F] slice makes it reduce over the wrong axis.
        ref_norm = normalize_mag(mag4ch[..., 0].unsqueeze(0)).squeeze(0)  # [Tstft, F]
        mag4ch[..., 0] = ref_norm
        return mag4ch, mag_clean, doa

In [None]:
file_list = [("/kaggle/input/singlechannelaudios/multichannel_input.wav", "/kaggle/input/singlechannelaudios/target.wav", [0.0,-1.0,0.0])]  # <-- fill with your data

device = 'cuda' if torch.cuda.is_available() else 'cpu'
stft_helper = STFTHelper(n_fft=512, hop_length=128, power=None, pad_mode='reflect', device=device)

dataset = FourChannelDataset(file_list, stft_helper, device=device)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Instantiate model; the number of frequency bins is read off the first item.
dummy_item = dataset[0]
_, mag_clean_dummy, _ = dummy_item
T_stft, n_freq_bins = mag_clean_dummy.shape
model = PerFrameGainNet(n_freq_bins=n_freq_bins , hidden_dim=256).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

best_loss = float('inf')
best_state = None
best_gains = None  # gains of the last batch of the best epoch so far
num_epochs = 10  # adjust as needed
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for mag4ch, mag_clean, doa in loader:
        # mag4ch: [B, Tstft, F, 4], mag_clean: [B, Tstft, F], doa: [B,3]
        mag4ch = mag4ch.to(device)
        mag_clean = mag_clean.to(device)
        doa = doa.to(device)

        optimizer.zero_grad()
        gains = model(mag4ch, doa)               # [B, Tstft]
        ref_mag = mag4ch[..., 0]                  # [B, Tstft, F]
        loss = stft_gain_loss(gains, mag_clean, ref_mag)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch figure is a per-sample average.
        running_loss += loss.item() * mag4ch.size(0)

    epoch_loss = running_loss / len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.6f}")
    if epoch_loss < best_loss:
        # Detach so the stored gains do not keep the autograd graph alive.
        best_gains = gains.detach().clone()
        best_loss = epoch_loss
        # state_dict() tensors alias the live parameters and dict.copy() is
        # shallow, so clone every tensor; otherwise later optimizer steps
        # overwrite the "best" snapshot in place.
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

# Save best model weights
torch.save(best_state, "best_gain_model.pth")

In [None]:
import torch
import torchaudio
import torchaudio.transforms as T

def apply_gains_and_write(input_path: str, gains: torch.Tensor, output_path: str,
                          n_fft=512, hop_length=128):
    """Apply per-frame gains to the first channel of a file and save the result.

    Args:
        input_path: Path of the (multichannel) audio file to read.
        gains: Per-frame gains shaped ``[T_stft]`` or ``[1, T_stft]``. They may
            live on any device and may carry autograd history.
        output_path: Path the enhanced waveform is written to.
        n_fft: FFT size for the STFT and its inverse.
        hop_length: Hop size for the STFT and its inverse.

    Returns:
        The enhanced waveform that was written to ``output_path``.

    Raises:
        ValueError: If the number of gains does not match the number of STFT
            frames of the input (a single gain is still broadcast).
    """
    # Load multichannel audio and keep only the first channel.
    waveform, sample_rate = torchaudio.load(input_path)  # [C, T]
    channel1 = waveform[0:1, :]                           # [1, T]
    n_samples = channel1.shape[-1]

    # Compute STFT
    stft_transform = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)
    X = stft_transform(channel1)                          # [1, F, T_stft]
    mag = X.abs()                                         # [1, F, T_stft]
    phase = X.angle()                                     # [1, F, T_stft]

    # Gains produced during training may sit on the GPU and require grad;
    # bring them onto the spectrogram's device and drop the graph so the
    # multiply and the final save both work.
    gains = gains.detach().to(device=mag.device, dtype=mag.dtype)
    if gains.ndim == 1:
        gains = gains.unsqueeze(0)                        # [1, T_stft]
    gains = gains.unsqueeze(1)                            # [1, 1, T_stft]
    if gains.shape[-1] not in (1, mag.shape[-1]):
        raise ValueError(
            f"got {gains.shape[-1]} gains for {mag.shape[-1]} STFT frames"
        )

    # Apply gain and recombine with the original phase.
    est_mag = mag * gains                                 # [1, F, T_stft]
    complex_spec = est_mag * torch.exp(1j * phase)        # [1, F, T_stft]

    # Inverse STFT; passing the original length avoids truncating the output
    # to a multiple of hop_length.
    istft_transform = T.InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)
    enhanced_waveform = istft_transform(complex_spec, length=n_samples)  # [1, T]

    # Save output
    torchaudio.save(output_path, enhanced_waveform.cpu(), sample_rate)

    return enhanced_waveform