# Room Impulse Response Generation with GAN

Reference:

Ratnarajah, A., Tang, Z., & Manocha, D. (2021). IR-GAN: Room impulse response generator for far-field speech recognition. Proc. Interspeech.


In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchaudio
import numpy as np

In [3]:
## data

# ! gdown --id 1YX1XEpJ2W1cZD4Dn7d5CRBVPOFLUKG4B --output ../data/RIR.zip
# ! unzip -q ../data/RIR.zip -d ../data/

In [5]:
# Define a dataset for RIRs using torchaudio to load wav files
class RIRDataset(Dataset):
    """Dataset of room impulse responses stored as ``.wav`` files in one directory.

    Each item is the waveform tensor returned by ``torchaudio.load``, zero-padded
    or truncated along its second (time) axis to exactly ``slice_len`` samples.
    The sample rate reported by ``torchaudio.load`` is discarded.

    Args:
        data_dir: Directory scanned (non-recursively) for files ending in ``.wav``.
        slice_len: Number of samples every returned waveform has.
    """

    def __init__(self, data_dir, slice_len=16384):
        self.data_dir = data_dir
        self.slice_len = slice_len
        # os.listdir yields entries in arbitrary, platform-dependent order;
        # sorting keeps the index -> file mapping stable across runs and machines.
        self.file_paths = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith('.wav')
        )

    def __len__(self):
        """Return the number of ``.wav`` files found in ``data_dir``."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load file ``idx`` and return its waveform with ``slice_len`` samples."""
        waveform, _sample_rate = torchaudio.load(self.file_paths[idx])
        num_samples = waveform.size(1)
        if num_samples < self.slice_len:
            # Short response: right-pad the time axis with zeros.
            waveform = torch.nn.functional.pad(waveform, (0, self.slice_len - num_samples))
        elif num_samples > self.slice_len:
            # Long response: keep only the first slice_len samples.
            waveform = waveform[:, :self.slice_len]
        return waveform

In [7]:
# Build the dataset from the unpacked RIR directory, then sanity-check it:
# first the number of clips, then the tensor shape of the first item.
dataset = RIRDataset(data_dir='../data/RIR')

print(len(dataset))
first_shape = dataset[0].shape
print(first_shape)

930
torch.Size([1, 16384])


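## Generative models and GANs

A generative adversarial network (GAN) trains two networks against each other: a *generator* that maps random latent vectors to synthetic samples, and a *discriminator* that tries to tell those samples apart from real data. The cell below defines a WaveGAN-style generator and discriminator for 1-D signals such as room impulse responses.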

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv1dTranspose(nn.Module):
    """1-D upsampling convolution that lengthens its input by a factor of ``stride``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Length of the convolution kernel.
        stride: Upsampling factor along the time axis.
        padding: Accepted for backward compatibility only; it is not used. The
            padding is always chosen so the output length is ``stride`` times
            the input length.
        upsample: ``'zeros'`` uses a strided ``nn.ConvTranspose1d``; ``'nn'``
            uses nearest-neighbour interpolation followed by a stride-1
            ``nn.Conv1d`` (exact length only for odd ``kernel_size``).

    Raises:
        NotImplementedError: If ``upsample`` is neither ``'zeros'`` nor ``'nn'``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=4, padding='same', upsample='zeros'):
        super(Conv1dTranspose, self).__init__()
        self.upsample = upsample
        self.stride = stride
        self.kernel_size = kernel_size
        if self.upsample == 'zeros':
            # ConvTranspose1d output length is
            #   (L - 1) * stride - 2 * padding + kernel_size + output_padding.
            # Pick padding / output_padding so this equals L * stride exactly;
            # padding = kernel_size // 2 alone gives L * stride - (stride - 1).
            excess = kernel_size - stride
            deconv_padding = max(0, (excess + 1) // 2)
            output_padding = 2 * deconv_padding - excess
            self.deconv = nn.ConvTranspose1d(
                in_channels, out_channels, kernel_size,
                stride=stride, padding=deconv_padding, output_padding=output_padding,
            )
        elif self.upsample == 'nn':
            self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=(kernel_size // 2))
        else:
            raise NotImplementedError

    def forward(self, x):
        """Upsample ``x`` of shape ``(batch, in_channels, L)`` along its last axis."""
        if self.upsample == 'zeros':
            return self.deconv(x)
        # 'nn' mode: repeat every sample `stride` times, then smooth with a convolution.
        x = F.interpolate(x, scale_factor=self.stride, mode='nearest')
        return self.conv(x)


class WaveGANGenerator(nn.Module):
    """WaveGAN-style generator mapping latent vectors to fixed-length waveforms.

    A linear projection is reshaped to ``(dim * dim_mul, 16)`` and then upsampled
    by stride-4 ``Conv1dTranspose`` layers (plus one final stride-2 or stride-4
    layer for the longer slice lengths). The last layer is followed by ``tanh``.

    Args:
        z_dim: Size of the latent vector.
        slice_len: Number of output samples; one of 16384, 32768 or 65536.
        nch: Number of output channels.
        kernel_len: Kernel length of every upsampling convolution.
        dim: Base channel multiplier.
        use_batchnorm: If True, apply ``nn.BatchNorm1d`` before every ReLU.
        upsample: Upsampling mode forwarded to ``Conv1dTranspose``.
    """

    def __init__(self, z_dim=100, slice_len=16384, nch=1, kernel_len=25, dim=64, use_batchnorm=False, upsample='zeros'):
        super(WaveGANGenerator, self).__init__()
        assert slice_len in [16384, 32768, 65536]
        dim_mul = 16 if slice_len == 16384 else 32
        self.dim = dim
        self.dim_mul = dim_mul
        self.use_batchnorm = use_batchnorm
        self.upsample = upsample

        def norm(channels):
            # A real per-layer module (or a parameter-free pass-through) so the
            # normalisation is registered with the network and can be pickled.
            return nn.BatchNorm1d(channels) if use_batchnorm else nn.Identity()

        # Projection; reshaped to (dim * dim_mul, 16) in forward().
        self.fc = nn.Linear(z_dim, 16 * dim * dim_mul)
        self.bn_fc = norm(dim * dim_mul)

        # Four stride-4 upsampling layers shared by every slice length: 16 -> 4096 samples.
        self.upconv_0 = Conv1dTranspose(dim * dim_mul, dim * (dim_mul // 2), kernel_len, stride=4, upsample=upsample)
        self.bn_0 = norm(dim * (dim_mul // 2))
        self.upconv_1 = Conv1dTranspose(dim * (dim_mul // 2), dim * (dim_mul // 4), kernel_len, stride=4, upsample=upsample)
        self.bn_1 = norm(dim * (dim_mul // 4))
        self.upconv_2 = Conv1dTranspose(dim * (dim_mul // 4), dim * (dim_mul // 8), kernel_len, stride=4, upsample=upsample)
        self.bn_2 = norm(dim * (dim_mul // 8))
        self.upconv_3 = Conv1dTranspose(dim * (dim_mul // 8), dim * (dim_mul // 16), kernel_len, stride=4, upsample=upsample)
        self.bn_3 = norm(dim * (dim_mul // 16))

        if slice_len == 16384:
            # 4096 -> 16384 samples, straight to the output channels.
            self.upconv_4 = Conv1dTranspose(dim * (dim_mul // 16), nch, kernel_len, stride=4, upsample=upsample)
        else:
            # 4096 -> 16384, then one extra layer: x2 for 32768, x4 for 65536.
            final_stride = 2 if slice_len == 32768 else 4
            self.upconv_4 = Conv1dTranspose(dim * (dim_mul // 16), dim, kernel_len, stride=4, upsample=upsample)
            self.bn_4 = norm(dim)
            self.upconv_5 = Conv1dTranspose(dim, nch, kernel_len, stride=final_stride, upsample=upsample)

    def forward(self, z):
        """Generate waveforms of shape ``(batch, nch, slice_len)`` from ``z`` of shape ``(batch, z_dim)``."""
        # FC and reshape for convolution: dim * dim_mul channels of 16 samples each.
        output = self.fc(z)
        output = output.view(-1, self.dim * self.dim_mul, 16)
        output = F.relu(self.bn_fc(output))

        output = F.relu(self.bn_0(self.upconv_0(output)))
        output = F.relu(self.bn_1(self.upconv_1(output)))
        output = F.relu(self.bn_2(self.upconv_2(output)))
        output = F.relu(self.bn_3(self.upconv_3(output)))

        if hasattr(self, 'upconv_5'):
            output = F.relu(self.bn_4(self.upconv_4(output)))
            output = torch.tanh(self.upconv_5(output))
        else:
            output = torch.tanh(self.upconv_4(output))

        return output


class WaveGANDiscriminator(nn.Module):
    """WaveGAN-style discriminator scoring single-channel waveforms.

    Five stride-4 convolutions (plus a sixth for the longer slice lengths)
    downsample the input; the result is flattened and projected to one raw
    score per example. No sigmoid is applied.

    Args:
        slice_len: Number of input samples; one of 16384, 32768 or 65536.
        kernel_len: Kernel length of every convolution.
        dim: Base channel multiplier.
        use_batchnorm: If True, apply ``nn.BatchNorm1d`` after every convolution
            except the first.
        phaseshuffle_rad: Maximum absolute shift, in samples, applied by
            ``phaseshuffle`` between layers; 0 disables it.
    """

    def __init__(self, slice_len=16384, kernel_len=25, dim=64, use_batchnorm=False, phaseshuffle_rad=0):
        super(WaveGANDiscriminator, self).__init__()
        assert slice_len in [16384, 32768, 65536]
        self.dim = dim
        self.kernel_len = kernel_len
        self.use_batchnorm = use_batchnorm
        self.phaseshuffle_rad = phaseshuffle_rad

        pad = kernel_len // 2

        def norm(channels):
            # A real per-layer module (or a parameter-free pass-through) so the
            # normalisation is registered with the network and can be pickled.
            return nn.BatchNorm1d(channels) if use_batchnorm else nn.Identity()

        def conv_out_len(length, stride):
            # Standard Conv1d output-length formula for dilation 1.
            return (length + 2 * pad - kernel_len) // stride + 1

        self.conv_0 = nn.Conv1d(1, dim, kernel_len, stride=4, padding=pad)
        self.conv_1 = nn.Conv1d(dim, dim * 2, kernel_len, stride=4, padding=pad)
        self.bn_1 = norm(dim * 2)
        self.conv_2 = nn.Conv1d(dim * 2, dim * 4, kernel_len, stride=4, padding=pad)
        self.bn_2 = norm(dim * 4)
        self.conv_3 = nn.Conv1d(dim * 4, dim * 8, kernel_len, stride=4, padding=pad)
        self.bn_3 = norm(dim * 8)
        self.conv_4 = nn.Conv1d(dim * 8, dim * 16, kernel_len, stride=4, padding=pad)
        self.bn_4 = norm(dim * 16)

        # Track the feature-map size through the stack so the final linear
        # layer always matches what forward() actually flattens.
        out_channels = dim * 16
        out_len = slice_len
        for _ in range(5):
            out_len = conv_out_len(out_len, 4)

        if slice_len in [32768, 65536]:
            last_stride = 4 if slice_len == 65536 else 2
            self.conv_5 = nn.Conv1d(dim * 16, dim * 32, kernel_len, stride=last_stride, padding=pad)
            self.bn_5 = norm(dim * 32)
            out_channels = dim * 32
            out_len = conv_out_len(out_len, last_stride)

        self.fc = nn.Linear(out_channels * out_len, 1)

    def forward(self, x):
        """Score ``x`` of shape ``(batch, 1, slice_len)``; returns shape ``(batch, 1)``."""
        output = F.leaky_relu(self.conv_0(x), 0.2)
        output = self.phaseshuffle(output)
        output = F.leaky_relu(self.bn_1(self.conv_1(output)), 0.2)
        output = self.phaseshuffle(output)
        output = F.leaky_relu(self.bn_2(self.conv_2(output)), 0.2)
        output = self.phaseshuffle(output)
        output = F.leaky_relu(self.bn_3(self.conv_3(output)), 0.2)
        output = self.phaseshuffle(output)
        output = F.leaky_relu(self.bn_4(self.conv_4(output)), 0.2)

        if hasattr(self, 'conv_5'):
            output = F.leaky_relu(self.bn_5(self.conv_5(output)), 0.2)

        # Flatten channels x time into one feature vector per example.
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output

    def phaseshuffle(self, x):
        """Shift ``x`` along its last axis by a random offset in ``[-rad, rad]``.

        One offset is drawn per call and applied to the whole batch. The vacated
        samples are filled by reflection padding, so the length is unchanged.
        Returns ``x`` untouched when ``phaseshuffle_rad`` is 0.
        """
        if self.phaseshuffle_rad > 0:
            phase = torch.randint(-self.phaseshuffle_rad, self.phaseshuffle_rad + 1, (1,)).item()
            if phase > 0:
                # Pad on the left, drop the same amount on the right.
                x = F.pad(x, (phase, 0), mode='reflect')
                x = x[:, :, :-phase]
            elif phase < 0:
                # Pad on the right, drop the same amount on the left.
                x = F.pad(x, (0, -phase), mode='reflect')
                x = x[:, :, -phase:]
        return x


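## An example use of sound field reconstruction

The cells below define a training loop for the generator and discriminator defined above and run it on a directory of RIR recordings.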

In [None]:


# Function for training the GAN
def train_gan(data_dir, train_batch_size=64, epochs=20, latent_dim=100, lr=0.0002, data_slice_len=16384, device=None):
    """Train a WaveGAN generator/discriminator pair on a directory of RIR wav files.

    Uses the non-saturating GAN objective with ``nn.BCEWithLogitsLoss`` and one
    discriminator step followed by one generator step per batch.

    Args:
        data_dir: Directory of ``.wav`` files passed to ``RIRDataset``.
        train_batch_size: Mini-batch size.
        epochs: Number of passes over the dataset.
        latent_dim: Size of the generator's latent vector.
        lr: Adam learning rate for both networks.
        data_slice_len: Waveform length used for the dataset and both networks.
        device: Device to train on; defaults to CUDA when available, else CPU.

    Returns:
        Tuple ``(generator, discriminator)`` of the trained modules.

    Raises:
        ValueError: If ``data_dir`` contains no usable wav files.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # Load dataset
    dataset = RIRDataset(data_dir, slice_len=data_slice_len)
    if len(dataset) == 0:
        raise ValueError(f"No .wav files found in {data_dir!r}")
    dataloader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True)

    # The networks must agree with the latent size sampled below and with the
    # slice length the dataset produces, so forward both explicitly.
    generator = WaveGANGenerator(z_dim=latent_dim, slice_len=data_slice_len).to(device)
    discriminator = WaveGANDiscriminator(slice_len=data_slice_len).to(device)

    # Define loss function and optimizers
    criterion = nn.BCEWithLogitsLoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    num_steps = len(dataloader)
    for epoch in range(epochs):
        for i, real_data in enumerate(dataloader):
            real_data = real_data.to(device)
            batch_size = real_data.size(0)

            z = torch.randn(batch_size, latent_dim, device=device)
            fake_data = generator(z)

            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # --- Discriminator step ---
            optimizer_d.zero_grad()
            d_real_loss = criterion(discriminator(real_data), real_labels)
            # detach() keeps this step from back-propagating into the generator.
            d_fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            optimizer_d.step()

            # --- Generator step: make the discriminator label fakes as real ---
            optimizer_g.zero_grad()
            g_loss = criterion(discriminator(fake_data), real_labels)
            g_loss.backward()
            optimizer_g.step()

            # Print progress
            if i % 10 == 0:
                # Adjacent f-strings: a backslash continuation inside the literal
                # would embed the next line's indentation in the message.
                print(
                    f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{num_steps}], "
                    f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}"
                )

    return generator, discriminator
    

In [None]:
# Training configuration.
# data_dir points at the same directory the dataset cell above loads from
# (the RIR archive is unpacked into ../data/).
data_dir = '../data/RIR'  # Directory containing RIR wav files
train_batch_size = 64
epochs = 20
latent_dim = 100
lr = 0.0002
data_slice_len = 16384

# Create training directory if it doesn't exist
os.makedirs("./train_dir", exist_ok=True)

# Train GAN (keyword arguments so each setting is matched by name, not position)
train_gan(
    data_dir,
    train_batch_size=train_batch_size,
    epochs=epochs,
    latent_dim=latent_dim,
    lr=lr,
    data_slice_len=data_slice_len,
)


## Extended Reading


Ratnarajah, A., Zhang, S. X., Yu, M., Tang, Z., Manocha, D., & Yu, D. (2022, May). FAST-RIR: Fast neural diffuse room impulse response generator. In ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (pp. 571-575). IEEE.

Fernandez-Grande, E., Karakonstantis, X., Caviedes-Nozal, D., & Gerstoft, P. (2023). [Generative models for sound field reconstruction.](https://pubs.aip.org/asa/jasa/article/153/2/1179/2866890) The Journal of the Acoustical Society of America, 153(2), 1179-1190.