In [2]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.pyplot as plt

In [3]:
class AudioEnhancer(nn.Module):
    def __init__(self, num_transformer_layers=2, num_heads=8, cnn_filters=[32, 64, 128, 256]):
        super(AudioEnhancer, self).__init__()
        
        # CNN Encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(2, cnn_filters[0], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(cnn_filters[0], cnn_filters[1], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(cnn_filters[1], cnn_filters[2], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(cnn_filters[2], cnn_filters[3], kernel_size=3, padding=1),
            nn.ReLU()
        )
        
        # Transformer
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=cnn_filters[-1], 
                nhead=num_heads, 
                dim_feedforward=512, 
                activation='relu', 
                batch_first=True
            ),
            num_layers=num_transformer_layers
        )
        
        # CNN Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(cnn_filters[3], cnn_filters[2], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(cnn_filters[2], cnn_filters[1], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(cnn_filters[1], cnn_filters[0], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(cnn_filters[0], 2, kernel_size=3, padding=1),
            nn.Tanh()
        )
    
    def forward(self, x):
        # Input: [batch_size, 2, 4800]
        
        # CNN Encoder
        x = self.encoder(x)  # Shape: [batch_size, cnn_filters[-1], 4800]
        
        # Permute for Transformer
        x = x.permute(0, 2, 1)  # Shape: [batch_size, 4800, cnn_filters[-1]]
        x = self.transformer(x)  # Shape: [batch_size, 4800, cnn_filters[-1]]
        x = x.permute(0, 2, 1)  # Shape: [batch_size, cnn_filters[-1], 4800]
        
        # CNN Decoder
        x = self.decoder(x)  # Shape: [batch_size, 2, 4800]
        
        return x



class PerceptualLoss(nn.Module):
    def __init__(self, feature_extractor):
        super(PerceptualLoss, self).__init__()
        self.feature_extractor = feature_extractor
        self.mse_loss = nn.MSELoss()
    
    def forward(self, pred, target):
        # Compute perceptual features
        pred_features = self.feature_extractor(pred)
        target_features = self.feature_extractor(target)
        
        # Perceptual loss
        perceptual_loss = self.mse_loss(pred_features, target_features)
        
        # Reconstruction loss
        reconstruction_loss = self.mse_loss(pred, target)
        
        return perceptual_loss + reconstruction_loss


# Dummy feature extractor for perceptual loss
class DummyFeatureExtractor(nn.Module):
    def __init__(self):
        super(DummyFeatureExtractor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.features(x)

In [4]:
import torch
model_save_path = '/home/j597s263/scratch/j597s263/Models/Audio/Exp_4.mod'
model = torch.load(model_save_path, weights_only=False)
model.eval()

print("Model Loaded")

Model Loaded


In [1]:
import torchaudio
import torch

def enhance_and_save_song(model, input_path, output_path, segment_duration=0.1, target_sample_rate=44000):
    """
    Enhances the quality of a lossy song using the trained model and saves the output.

    Args:
        model: The trained AudioEnhancer model (already loaded and in eval mode).
        input_path: Path to the lossy input audio file.
        output_path: Path to save the enhanced audio file.
        segment_duration: Duration of each segment in seconds (default 0.1).
        target_sample_rate: Target sample rate for processing (default 44kHz).
    """
    # Load the lossy audio
    waveform, sample_rate = torchaudio.load(input_path)
    print(f"Loaded lossy song with shape: {waveform.shape}, Sample rate: {sample_rate}")

    # Resample to target sample rate if needed
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)
        sample_rate = target_sample_rate
        print(f"Resampled to target sample rate: {target_sample_rate} Hz")

    # Calculate segment size
    segment_size = int(sample_rate * segment_duration)

    # Split the waveform into segments
    segments = [
        waveform[:, i:i + segment_size]
        for i in range(0, waveform.shape[1], segment_size)
        if waveform[:, i:i + segment_size].shape[1] == segment_size
    ]

    print(f"Total segments: {len(segments)}")

    # Enhance each segment
    device = next(model.parameters()).device
    enhanced_segments = []
    for segment in segments:
        segment = segment.unsqueeze(0).to(device)  # Add batch dimension
        with torch.no_grad():
            enhanced_segment = model(segment)  # Enhance the segment
        enhanced_segments.append(enhanced_segment.squeeze(0).cpu())  # Remove batch dimension and move to CPU

    # Reconstruct the enhanced waveform
    enhanced_waveform = torch.cat(enhanced_segments, dim=1)
    print(f"Enhanced waveform shape: {enhanced_waveform.shape}")

    # Save the enhanced audio to a file
    torchaudio.save(output_path, enhanced_waveform, sample_rate)
    print(f"Enhanced song saved to {output_path}")

In [5]:
input_path = '/home/j597s263/scratch/j597s263/Datasets/Audio/Test/Test.wav'
output_path = '/home/j597s263/scratch/j597s263/Datasets/Audio/Test/exp4.wav'

# Ensure the model is in evaluation mode
model.eval()

# Enhance and save the song
enhance_and_save_song(model, input_path, output_path, segment_duration=0.1, target_sample_rate=44000)

Loaded lossy song with shape: torch.Size([2, 4523771]), Sample rate: 32000
Resampled to target sample rate: 44000 Hz
Total segments: 1413
Enhanced waveform shape: torch.Size([2, 6217200])
Enhanced song saved to /home/j597s263/scratch/j597s263/Datasets/Audio/Test/exp4.wav
