In [1]:
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from CustomDataLoader import MyDataset
from torch.utils.data import DataLoader
from CRN import CRNBasedMagnitudeEstimation
from CNNPhaseEnhance import CNNPhaseEnhancement
from CleanSpeechEstimate import CleanSpeechEstimation
import torch.optim.lr_scheduler as lr_scheduler

from torchsummary import summary


# Note: torch.stft only exposes a separate real/imaginary dimension when the complex numbers are returned as individual real values (return_complex=False); otherwise, with return_complex=True, the result is a single complex-valued tensor without that extra dimension.

class SpeechEnhancementModel(nn.Module):
    """Two-stage speech enhancement pipeline driven by the STFT of a waveform.

    ``forward`` computes the STFT of the noisy input, applies power-law
    compression, feeds the compressed magnitude to the magnitude-estimation
    stage, feeds that stage's output together with the compressed phase to the
    phase-enhancement stage, and hands the resulting mask to the clean-speech
    estimator.

    Args:
        crn_params: Keyword arguments forwarded to ``CRNBasedMagnitudeEstimation``.
        cnn_params: Keyword arguments forwarded to ``CNNPhaseEnhancement``.
        power_law: Exponent used for power-law compression (default 0.3).
    """

    def __init__(self, crn_params, cnn_params, power_law=0.3):
        super(SpeechEnhancementModel, self).__init__()
        self.power_law = power_law
        self.magnitude_estimation = CRNBasedMagnitudeEstimation(**crn_params)
        self.phase_enhancement = CNNPhaseEnhancement(**cnn_params)
        self.clean_speech_estimation = CleanSpeechEstimation(power_law)

    def stft(self, x, fft_size=512, hop_size=256, window='hann'):
        """Return the complex STFT of ``x``.

        Args:
            x: Input tensor passed straight to ``torch.stft``.
            fft_size: FFT size (``n_fft``).
            hop_size: Hop length between frames.
            window: Window name; only ``'hann'`` is supported.

        Returns:
            The complex-valued tensor produced by ``torch.stft`` with
            ``return_complex=True``.

        Raises:
            ValueError: If ``window`` is anything other than ``'hann'``.
        """
        if window != 'hann':
            raise ValueError("only hann window is supported.")

        # Create the window on the same device as the input; torch.stft
        # requires both to live on the same device, so a default (CPU) window
        # would break for inputs that were moved to an accelerator.
        window_tensor = torch.hann_window(fft_size, device=x.device)

        return torch.stft(x, fft_size, hop_size, window=window_tensor, return_complex=True)

    def power_law_compression(self, signal):
        """Power-law compress a complex tensor and return ``[magnitude, phase]``.

        The real and imaginary parts are compressed independently as
        ``sign(v) * |v| ** power_law``; magnitude and phase are then derived
        from the compressed parts.

        Args:
            signal: Complex tensor (anything exposing ``.real`` and ``.imag``).

        Returns:
            A two-element list ``[compressed_mag, compressed_phase]``.
        """
        real = signal.real
        imag = signal.imag

        compressed_real = torch.pow(torch.abs(real), self.power_law) * torch.sign(real)
        compressed_img = torch.pow(torch.abs(imag), self.power_law) * torch.sign(imag)

        compressed_mag = torch.sqrt(torch.square(compressed_real) + torch.square(compressed_img))
        # atan2 keeps the correct quadrant (arctan(imag / real) folds the left
        # half-plane onto the right one) and is well defined when the real
        # part is zero, where the plain division would yield inf or NaN.
        compressed_phase = torch.atan2(compressed_img, compressed_real)

        return [compressed_mag, compressed_phase]

    def inv_stft(self, complex_tensor, n_fft=512, hop_length=16, win_length=32):
        """Invert an STFT with ``torch.istft`` and return the time-domain signal.

        NOTE(review): the default ``hop_length``/``win_length`` here (16/32) do
        not match the parameters used by ``stft`` (hop 256, Hann window of
        length 512), and no window is passed. Confirm the intended values
        before using this to invert the output of ``stft``.
        """
        time_domain_signal = torch.istft(complex_tensor, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        return time_domain_signal

    def forward(self, noisy_signal):
        """Run the full enhancement pipeline on ``noisy_signal``.

        Args:
            noisy_signal: Noisy input passed to ``stft``.

        Returns:
            Whatever the clean-speech estimation stage returns.
        """
        print(noisy_signal.shape)
        # Apply stft
        stft_matrix = self.stft(noisy_signal)

        print("After STFT :",stft_matrix.shape)

        # Apply power law compression
        compressed_mag, compressed_phase = self.power_law_compression(stft_matrix)

        print("------ CRN Mag estimation block----------")
        # Stage 1: CRN-based Magnitude Estimation
        magnitude_mask = self.magnitude_estimation(compressed_mag)

        print("---- CNN based phase enhancement--------")
        # Stage 2: CNN-based Phase Enhancement
        complex_mask = self.phase_enhancement(magnitude_mask, compressed_phase)

        # Clean Speech Estimation
        # NOTE(review): the raw input (not its STFT) is handed to the
        # estimator here - confirm that this is what CleanSpeechEstimation
        # expects to multiply the mask with.
        estimated_clean_speech = self.clean_speech_estimation(complex_mask, noisy_signal)

        return estimated_clean_speech

# Example usage: configure the two model stages, load the data and train.

# Constructor arguments for the magnitude-estimation stage.
crn_params = {
    'in_channels': 1, # 64
    'out_channels': 128,
    'hidden_size': 32,
    'num_subbands': 8
}
# Constructor arguments for the phase-enhancement stage.
cnn_params = {
    'in_channels': 2,
    'out_channels': 64
}


csv_file = './DataLoad/files.csv'
train = './DataLoad/train'
label = './DataLoad/label'
batches = 1

# Fix the seed so the split, the shuffling and the weight initialisation are
# reproducible across kernel restarts.
torch.manual_seed(0)

mdataset = MyDataset(csv=csv_file,train_dir=train,label_dir=label)

# Split the dataset in half based on its actual size; hard-coded lengths make
# random_split raise as soon as the number of files changes.
num_train = len(mdataset) // 2
num_test = len(mdataset) - num_train
train_set, test_set = torch.utils.data.random_split(mdataset, [num_train, num_test])

train_loader = DataLoader(train_set,batch_size=batches,shuffle=True)
test_loader = DataLoader(test_set,batch_size=batches,shuffle=True)


sample_rate = 16000
duration = 10
channels = 2

num_epochs = 1
learning_rate = 0.0004


# Initialize the model, loss function, and optimizer
model = SpeechEnhancementModel(crn_params, cnn_params)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# StepLR multiplies the learning rate by `gamma` every `step_size` epochs, so
# gamma must be < 1 to decay it (gamma=10 would grow it tenfold each step).
scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# noisy_signal = torch.randn(10, 64, 256, 100)  # (batch_size, channels, freq_bins, time_steps)
# print(output.shape)  # Expected output shape: (batch_size, freq_bins, time_steps)


model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for noisy_waveform, clean_waveform in train_loader:
        optimizer.zero_grad()

        # Forward pass
        outputs = model(noisy_waveform)

        # Compute loss
        loss = criterion(outputs, clean_waveform)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Advance the learning-rate schedule once per epoch, after the optimizer steps.
    scheduler.step()

    # max(..., 1) avoids a ZeroDivisionError if the training loader is empty.
    print(f"Epoch {epoch+1}, Loss: {running_loss / max(len(train_loader), 1)}")

print("Training finished!")


# summary(model,torch.tensor(mdataset.__getitem__(0)[0]))
# k = model(torch.tensor(mdataset.__getitem__(0)[0]))

torch.Size([1, 1323000])
After STFT : torch.Size([1, 257, 5168])
------ CRN Mag estimation block----------
0
32
64
96
128
160
192
224
After concatenated subbands torch.Size([8, 128, 40, 5168])
after permute torch.Size([128, 40, 41344])
After the gru : torch.Size([128, 40, 64])
After the point wise conv torch.Size([128, 64, 40, 1])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (40x64 and 32x257)

In [3]:
# Read the metadata of the first dataset file and report its channel count.
first_file_path = mdataset.give_path(0)
info = torchaudio.info(first_file_path)
num_channels = info.num_channels
print(num_channels)

1


In [6]:
# Fetch the first dataset item and inspect its first element and that element's shape.
k = mdataset[0]
first_element = k[0]
print(first_element)
print(first_element.shape)

[-0.00372978 -0.00413349 -0.00312093 ...  0.08128223  0.09374231
  0.06016465]
(1323000,)


In [5]:
class DepthwiseSeparableConv2D(nn.Module):
    """Depthwise 2-D convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # groups == in_channels gives every input channel its own filter.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
        )
        # The 1x1 convolution mixes channels and sets the output channel count.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise and then the pointwise convolution to ``x``."""
        return self.pointwise(self.depthwise(x))

class ConvBlock(nn.Module):
    """Depthwise-separable convolution, ReLU, and optional max-pool downsampling.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        downsample: When true, a ``MaxPool2d`` with kernel and stride ``(1, 2)``
            is applied after the activation, halving the last dimension.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,downsample=True):
        super(ConvBlock, self).__init__()
        self.conv = DepthwiseSeparableConv2D(in_channels, out_channels, kernel_size, stride, padding)
        self.relu = nn.ReLU()
        self.downsample = downsample
        if self.downsample:
            self.pool = nn.MaxPool2d(kernel_size=(1,2),stride=(1,2))

    def forward(self, x):
        """Return ``pool(relu(conv(x)))``, skipping the pool when ``downsample`` is false."""
        x = self.conv(x)
        x = self.relu(x)
        if self.downsample:
            # The pooled tensor must be assigned back; calling the pool and
            # discarding its result leaves the output un-downsampled.
            x = self.pool(x)
        return x

class FGRU(nn.Module):
    """Wrapper around ``nn.GRU`` (bidirectional by default) that returns only the output sequence.

    Args:
        input_size: Feature size of each input step.
        hidden_size: Hidden state size of the GRU.
        num_layers: Number of stacked GRU layers.
        bidirectional: Whether the GRU runs in both directions.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=True):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
        )

    def forward(self, x):
        """Run the GRU over ``x`` and drop the final hidden state."""
        outputs, _final_hidden = self.gru(x)
        return outputs

class TemporalGRU(nn.Module):
    """Wrapper around a unidirectional ``nn.GRU`` that returns only the output sequence.

    Args:
        input_size: Feature size of each input step.
        hidden_size: Hidden state size of the GRU.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )

    def forward(self, x):
        """Run the GRU over ``x`` and drop the final hidden state."""
        outputs, _final_hidden = self.gru(x)
        return outputs

class CRNBasedMagnitudeEstimation(nn.Module):
    """Mask estimator built from per-subband conv blocks, two GRUs and two linear layers.

    ``forward`` slices the input into ``num_subbands`` equal-width bands along
    its second dimension, runs each band through its own ``ConvBlock``,
    concatenates the results, and then applies a bidirectional GRU, a 1x1
    convolution, a unidirectional GRU and two linear layers.

    Args:
        in_channels: Used only to size the conv blocks' input channels
            (``in_channels // num_subbands``).
        out_channels: Accepted but not referenced anywhere in this class.
        hidden_size: Hidden size of both GRUs, output channels of the 1x1
            convolution, and width of the first linear layer.
        num_subbands: Number of bands (and conv blocks).
        kernel_size, stride, padding: Forwarded to every ``ConvBlock``.

    NOTE(review): several of the reshapes in ``forward`` assume tensor sizes
    that the preceding layers do not obviously produce; the shape comments
    below are the original author's intent and should be verified by running
    the module on a small input.
    """
    def __init__(self, in_channels, out_channels, hidden_size, num_subbands, kernel_size=(1, 3), stride=1, padding=1):
        super(CRNBasedMagnitudeEstimation, self).__init__()
        self.num_subbands = num_subbands
        
        # Convolutional Block with Channelwise Feature Reorientation

        # filters used each conv respectively
        filters = [32,64,96,128]

        # NOTE(review): `filters` has four entries but is indexed with
        # i in range(num_subbands), so num_subbands > 4 raises IndexError.
        # NOTE(review): `in_channels // num_subbands` evaluates to 0 whenever
        # in_channels < num_subbands - confirm the intended input channel count.
        # The first block (i == 0) is built without downsampling; all others with it.
        self.conv_blocks = nn.ModuleList([
            ConvBlock(in_channels // num_subbands, filters[i], kernel_size, stride, padding,downsample=(i>0))
            for i in range(num_subbands)
        ])

        # Frequency-axis Bidirectional GRU
        # Input feature size is fixed to filters[1] (64).
        self.fgru = FGRU(filters[1], hidden_size)

        # Point-wise Convolution
        self.pointwise_conv = nn.Conv2d(filters[1], hidden_size, kernel_size=1)

        # Temporal GRU
        self.temporal_gru = TemporalGRU(hidden_size, hidden_size)

        # Fully Connected Layers for Intermediate Mask Estimation
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Estimate the mask from ``x``.

        ``x`` must be 3-dimensional: its size is unpacked into
        ``(channels, freq_bins, time_steps)``.
        """
        channels, freq_bins, time_steps = x.size()
        
        # Channelwise Feature Reorientation
        subband_outputs = []
        # Integer division: when freq_bins is not a multiple of num_subbands,
        # the trailing freq_bins % num_subbands bins are never visited by the loop.
        subband_size = freq_bins // self.num_subbands
        for i in range(self.num_subbands):
            start = i * subband_size
            end = start + subband_size
            subband = x[:, start:end, :]
            # unsqueeze(0) adds a leading dimension so the block receives a 4-D
            # tensor; squeeze(0) removes it again afterwards.
            # NOTE(review): after unsqueeze the conv sees `channels` as its
            # channel dimension, while the blocks were built with
            # in_channels // num_subbands input channels - confirm these agree.
            subband_output = self.conv_blocks[i](subband.unsqueeze(0))
            subband_outputs.append(subband_output.squeeze(0))
        
        # Concatenate subband outputs
        # NOTE(review): torch.cat along dim=1 requires all other dimensions to
        # match, but the blocks are built with different output channel counts
        # (filters[i]) - verify.
        x = torch.cat(subband_outputs, dim=1)
        
        # Frequency-axis Bidirectional GRU
        x = x.permute(1,0,2).contiguous().view(freq_bins, -1)  # (freq_bins, channels * time_steps)

        # unsqueeze(1) inserts a size-1 second dimension for the GRU call; squeeze(1) drops it.
        x = self.fgru(x.unsqueeze(1)).squeeze(1)
        
        # Point-wise Convolution
        x = x.permute(1,0).view(-1, freq_bins, time_steps).unsqueeze(0) # (batch_size, hidden_size, freq_bins, time_steps)
        x = self.pointwise_conv(x)
        
        # Temporal GRU
        x = x.permute(3, 0, 1, 2).contiguous().view(time_steps, -1)  # (time_steps, batch_size, hidden_size * freq_bins)
        x = self.temporal_gru(x.unsqueeze(1)).squeeze(1)
        
        # Fully Connected Layers for Intermediate Mask Estimation
        x = x[-1]  # take the last time step output
        x = self.fc1(x)
        x = self.fc2(x)
        
        return x


In [None]:
class CNNPhaseEnhancement(nn.Module):
    """Small CNN that maps a magnitude mask plus noisy phase to a 2-channel mask.

    The two inputs are concatenated along dimension 1, projected with a 1x1
    convolution, passed through conv -> ReLU -> conv, and finally projected to
    two channels (real and imaginary parts) with another 1x1 convolution.

    Args:
        in_channels: Channel count of the concatenated input.
        out_channels: Channel count used by the internal convolutions.
        kernel_size: Kernel size of the two inner convolutions.
        stride: Stride of the two inner convolutions.
        padding: Padding of the two inner convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(1, 3), stride=1, padding=1):
        super().__init__()

        # 1x1 projection of the concatenated (magnitude mask, noisy phase) features.
        self.combine_features = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # Two-layer convolutional body with a ReLU in between.
        self.conv1 = nn.Conv2d(
            out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding
        )
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding
        )

        # Pointwise head: channel 0 is the real part, channel 1 the imaginary part.
        self.complex_mask_estimation = nn.Conv2d(out_channels, 2, kernel_size=1)

    def forward(self, magnitude_mask, noisy_phase):
        """Return the 2-channel complex mask for the given mask and phase tensors."""
        features = torch.cat((magnitude_mask, noisy_phase), dim=1)

        pipeline = (
            self.combine_features,
            self.conv1,
            self.relu,
            self.conv2,
            self.complex_mask_estimation,
        )
        for layer in pipeline:
            features = layer(features)

        return features

In [None]:
class CleanSpeechEstimation(nn.Module):
    """Apply a complex mask to a noisy signal and undo power-law compression.

    Args:
        power_law: Compression exponent; decompression raises the magnitude
            to ``1 / power_law``.
    """

    def __init__(self, power_law=0.3):
        super(CleanSpeechEstimation, self).__init__()
        self.power_law = power_law

    def forward(self, complex_mask, noisy_signal):
        """Return the decompressed, masked signal.

        Args:
            complex_mask: 4-D tensor whose dimension 1 holds the real part at
                index 0 and the imaginary part at index 1.
            noisy_signal: Tensor multiplied element-wise (with broadcasting)
                by the complex mask.

        Returns:
            Complex tensor ``|z| ** (1 / power_law) * sgn(z)`` where ``z`` is
            the masked signal.
        """
        # Estimated clean speech signal
        real_mask = complex_mask[:, 0, :, :]
        imag_mask = complex_mask[:, 1, :, :]
        estimated_complex_speech = (real_mask + 1j * imag_mask) * noisy_signal

        # Power law decompression
        # torch.sign rejects complex inputs; torch.sgn is its complex-aware
        # counterpart (z / |z|, and 0 for z == 0), so the magnitude is
        # decompressed while the phase is preserved.
        magnitude = torch.abs(estimated_complex_speech)
        unit_phase = torch.sgn(estimated_complex_speech)
        estimated_clean_speech = torch.pow(magnitude, 1 / self.power_law) * unit_phase

        return estimated_clean_speech