In [1]:
import torchaudio.transforms as T
import torch
import torchaudio
from torch.utils.data import Dataset
import os
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Prefer the second GPU (the original target); fall back so the notebook still
# runs on machines with a single GPU or none at all.
device = 'cuda:1' if torch.cuda.device_count() > 1 else ('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class AudioEnhancer(nn.Module):
    """Conv encoder -> Transformer encoder -> transposed-conv decoder for stereo audio.

    Every (transposed) convolution uses kernel_size=3, padding=1 and stride 1, so the
    time length of the input is preserved end to end. The final Tanh bounds the
    output to [-1, 1].

    Args:
        num_transformer_layers: number of stacked TransformerEncoderLayer blocks.
        num_heads: attention heads per layer; must divide cnn_filters[-1].
        cnn_filters: channel widths of the encoder convs, in order. The decoder
            mirrors them back down to 2 channels. Any non-empty sequence works;
            the default reproduces the original 4-stage network.
        dim_feedforward: hidden width of each transformer layer's feed-forward MLP.
        batch_first: forwarded to TransformerEncoderLayer. Defaults to False so
            newly built models match the original construction (see NOTE in forward).

    Raises:
        ValueError: if cnn_filters is empty.
    """

    def __init__(self, num_transformer_layers=2, num_heads=8, cnn_filters=(32, 64, 128, 256),
                 dim_feedforward=512, batch_first=False):
        super().__init__()
        # Copy into a list: a tuple default avoids the shared-mutable-default pitfall,
        # and callers passing a list still work.
        widths = list(cnn_filters)
        if not widths:
            raise ValueError("cnn_filters must contain at least one channel width")
        io_channels = 2  # stereo in, stereo out

        # CNN Encoder: Conv1d + ReLU per stage (indices 0..2k-1 as in the original).
        encoder_layers = []
        for c_in, c_out in zip([io_channels] + widths[:-1], widths):
            encoder_layers += [nn.Conv1d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU()]
        self.encoder = nn.Sequential(*encoder_layers)

        # Transformer over the deepest feature width (d_model == last encoder width,
        # which also fixes the d_model/encoder mismatch the hard-coded version had
        # whenever more than four widths were passed).
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(d_model=widths[-1], nhead=num_heads, dim_feedforward=dim_feedforward,
                                    activation='relu', batch_first=batch_first),
            num_layers=num_transformer_layers
        )

        # CNN Decoder: mirror of the encoder down to 2 channels.
        mirrored = widths[::-1] + [io_channels]
        decoder_layers = []
        for c_in, c_out in zip(mirrored[:-1], mirrored[1:]):
            decoder_layers += [nn.ConvTranspose1d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU()]
        decoder_layers[-1] = nn.Tanh()  # last activation is Tanh, not ReLU
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Enhance a batch of stereo waveforms.

        Args:
            x: tensor of shape [batch_size, 2, time] (the notebook uses time=48000).
        Returns:
            Tensor of shape [batch_size, 2, time] with values in [-1, 1].
        """
        # CNN Encoder: [batch, 2, time] -> [batch, C, time], C = cnn_filters[-1]
        x = self.encoder(x)

        # NOTE(review): x is permuted to [batch, time, C] here, but with the default
        # batch_first=False nn.TransformerEncoderLayer reads its input as
        # [seq, batch, feature]. Attention therefore mixes the *batch* axis at each
        # time step instead of attending over time; with batch_size == 1 (as at
        # inference) every position only attends to itself. Presumably the saved
        # checkpoint was built with this default, so behaviour is left unchanged;
        # true temporal attention over 48000 steps would also need a 48000 x 48000
        # attention matrix per head. Construct with batch_first=True for new models
        # if attention across time is intended.
        x = x.permute(0, 2, 1)
        x = self.transformer(x)
        x = x.permute(0, 2, 1)  # back to [batch, C, time]

        # CNN Decoder: [batch, C, time] -> [batch, 2, time]
        x = self.decoder(x)

        return x

class PerceptualLoss(nn.Module):
    """Equally weighted sum of a feature-space MSE and a raw-tensor MSE.

    The feature term compares feature_extractor(pred) with feature_extractor(target);
    the reconstruction term compares pred and target directly.
    """

    def __init__(self, feature_extractor):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.mse_loss = nn.MSELoss()

    def forward(self, pred, target):
        embed = self.feature_extractor
        # pred is embedded before target, same order as the reference implementation.
        feature_term = self.mse_loss(embed(pred), embed(target))
        waveform_term = self.mse_loss(pred, target)
        return feature_term + waveform_term

class DummyFeatureExtractor(nn.Module):
    """Two Conv1d+ReLU stages mapping 2 input channels to 32 feature channels.

    kernel_size=3 with padding=1 keeps the time length unchanged. The weights are
    whatever nn.Conv1d initialises them to; nothing in this class loads or trains them.
    """

    def __init__(self):
        super().__init__()
        stack = []
        for c_in, c_out in ((2, 16), (16, 32)):
            stack.append(nn.Conv1d(c_in, c_out, kernel_size=3, padding=1))
            stack.append(nn.ReLU())
        self.features = nn.Sequential(*stack)

    def forward(self, x):
        feats = self.features(x)
        return feats

# Build the training loss. The feature extractor is freshly (randomly) initialised
# here, not loaded from disk. NOTE(review): loss_fn is not used by the inference
# cells visible in this notebook; likely a leftover from training.
feature_extractor = DummyFeatureExtractor()
feature_extractor = feature_extractor.to(device)
loss_fn = PerceptualLoss(feature_extractor)
loss_fn = loss_fn.to(device)

In [3]:
import torch

# Redefined so this cell can run on its own; prefer GPU 1, else any GPU, else CPU.
device = 'cuda:1' if torch.cuda.device_count() > 1 else ('cuda' if torch.cuda.is_available() else 'cpu')

MODEL_PATH = "/home/j597s263/Models/Audio.mod"

# Load the entire model. The file is a whole pickled nn.Module (not a state_dict),
# so it needs weights_only=False; recent torch versions default to True and would
# refuse it. Unpickling can execute arbitrary code: only load trusted checkpoints.
# The AudioEnhancer class must already be defined in this kernel for unpickling.
# map_location lets a checkpoint saved on a different GPU load on `device`.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model.eval()  # Set to evaluation mode (disables the transformer's Dropout layers)
model.to(device)  # no-op after map_location; kept so the cell still displays the model

  model = torch.load("/home/j597s263/Models/Audio.mod")


AudioEnhancer(
  (encoder): Sequential(
    (0): Conv1d(2, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): ReLU()
    (2): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (3): ReLU()
    (4): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (5): ReLU()
    (6): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))
    (7): ReLU()
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dro

In [8]:
def split_into_chunks(waveform, chunk_size, drop_last=True):
    """
    Splits the audio waveform into consecutive, non-overlapping chunks along time (dim 1).
    Args:
        waveform: Tensor of shape [channels, total_samples] (e.g. [2, N] for stereo).
        chunk_size: Samples per chunk (e.g., 48000 for 1 second at 48 kHz). Must be positive.
        drop_last: If True (default, the original behaviour), a trailing chunk shorter
            than chunk_size is discarded, so up to chunk_size - 1 samples at the end of
            the audio are lost. If False, the short final chunk is kept.
    Returns:
        List of views into waveform, each of shape [channels, chunk_size]; the last one
        may be shorter when drop_last=False.
    Raises:
        ValueError: if chunk_size is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    num_samples = waveform.shape[1]
    # Slice each window once (the old version sliced twice: once to test, once to keep).
    chunks = [waveform[:, start:start + chunk_size] for start in range(0, num_samples, chunk_size)]
    # Only the final window can be short, so dropping it matches filtering every chunk.
    if drop_last and chunks and chunks[-1].shape[1] < chunk_size:
        chunks.pop()
    return chunks

def reconstruct_from_chunks(chunks):
    """
    Joins chunks back into one waveform along the time axis (dim 1).
    Args:
        chunks: Sequence of tensors shaped [channels, n_i] with a common channel count.
    Returns:
        Tensor of shape [channels, sum(n_i)], chunks laid out in the given order.
    """
    time_dim = 1
    joined = torch.cat(chunks, dim=time_dim)
    return joined

# Chunked inference over a full recording
def process_audio_in_batches(model, waveform, sample_rate, chunk_size=48000, device=None):
    """
    Runs the model over the waveform one chunk at a time and concatenates the outputs.

    Chunking bounds peak memory for long recordings. Every sample is processed: a
    trailing chunk shorter than chunk_size is run through the model as-is. The previous
    version dropped it, so the output was shorter than the input, and inputs shorter
    than one chunk crashed in torch.cat. Chunks are enhanced independently (batch size
    1 each), so there is no context across chunk boundaries.

    Args:
        model: The trained model; called as model(x) with x of shape [1, channels, n].
        waveform: Input waveform of shape [channels, total_samples], on any device.
        sample_rate: Sample rate of the audio. Unused; kept for API compatibility.
        chunk_size: Samples per chunk (default 48000, i.e. 1 second at 48 kHz). Must be positive.
        device: Device to run the model on. Defaults to the device of the model's first
            parameter, or the waveform's device if the model has no parameters.
    Returns:
        CPU tensor: the per-chunk outputs concatenated along time. For a
        length-preserving model such as AudioEnhancer its shape is [2, total_samples].
    Raises:
        ValueError: if chunk_size is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if device is None:
        # Follow the model instead of relying on a notebook-global `device`.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else waveform.device

    num_samples = waveform.shape[1]
    if num_samples == 0:
        # Nothing to enhance; torch.cat would fail on an empty list.
        return waveform[:, :0].detach().cpu()

    enhanced_chunks = []
    model.eval()  # Ensure model is in evaluation mode (e.g. disables dropout)
    with torch.no_grad():
        for start in range(0, num_samples, chunk_size):
            # The final slice may be shorter than chunk_size; it is kept, not dropped.
            chunk = waveform[:, start:start + chunk_size].unsqueeze(0).to(device)  # [1, C, n]
            enhanced_chunk = model(chunk)
            enhanced_chunks.append(enhanced_chunk.squeeze(0).cpu())  # Remove batch dimension

    # Reconstruct full audio
    return torch.cat(enhanced_chunks, dim=1)

# Load and preprocess the test audio
TARGET_SAMPLE_RATE = 48000
test_file = '/home/j597s263/Datasets/Audio/Test/Lossy'
waveform, sample_rate = torchaudio.load(test_file)  # waveform: [channels, samples]

# Ensure stereo. Repeat along the channel axis: torch.stack([w, w]) would add a new
# leading dimension and turn mono [1, N] into a 3-D [2, 1, N] tensor.
if waveform.shape[0] == 1:
    waveform = waveform.repeat(2, 1)  # [1, N] -> [2, N]
elif waveform.shape[0] > 2:
    waveform = waveform[:2]  # NOTE(review): keeps the first two channels; a downmix may be preferable

# Match the sample rate the model expects
if sample_rate != TARGET_SAMPLE_RATE:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)
    waveform = resampler(waveform)
    sample_rate = TARGET_SAMPLE_RATE  # keep in sync with the resampled data

# Process the audio in chunks. Each chunk is moved to the GPU inside
# process_audio_in_batches, so the full waveform can stay on the CPU.
enhanced_waveform = process_audio_in_batches(model, waveform, sample_rate)

# Save the enhanced output
output_file = "/home/j597s263/Datasets/Audio/Test/Enhanced.wav"
torchaudio.save(output_file, enhanced_waveform, sample_rate=TARGET_SAMPLE_RATE)
print(f"Enhanced audio saved to: {output_file}")

Enhanced audio saved to: /home/j597s263/Datasets/Audio/Test/Enhanced.wav


In [None]:
testf = '/home/j597s263/Datasets/Audio/Test/Lossy'
waveform, sample_rate = torchaudio.load(testf)  # waveform: [channels, samples]
# Ensure stereo: repeat a mono channel along dim 0 (torch.stack would add a new
# leading dimension and produce a 3-D [2, 1, N] tensor).
if waveform.shape[0] == 1:
    waveform = waveform.repeat(2, 1)
elif waveform.shape[0] > 2:
    waveform = waveform[:2]  # NOTE(review): keeps the first two channels only
# Match the 48 kHz rate the model expects
if sample_rate != 48000:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=48000)
    waveform = resampler(waveform)

# Prepare for model input: the WHOLE file as a single batch item, not a 1-second chunk.
waveform = waveform.unsqueeze(0).to(device)  # Add batch dimension, shape: [1, 2, total_samples]


In [None]:
# Enhance the whole (un-chunked) waveform in a single forward pass.
# NOTE(review): this overwrites `enhanced_waveform` from the chunked cell above, and
# memory use grows with file length -- confirm long files fit on the GPU.
with torch.no_grad():
    enhanced_waveform = model(waveform).squeeze(0)  # [1, 2, total_samples] -> [2, total_samples]
enhanced_waveform = enhanced_waveform.cpu()
