In [3]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.pyplot as plt


class AudioEnhancer(nn.Module):
    def __init__(self, num_transformer_layers=2, num_heads=8, cnn_filters=[32, 64, 128, 256]):
        super(AudioEnhancer, self).__init__()
        
        # CNN Encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(2, cnn_filters[0], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(cnn_filters[0], cnn_filters[1], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(cnn_filters[1], cnn_filters[2], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(cnn_filters[2], cnn_filters[3], kernel_size=3, padding=1),
            nn.ReLU()
        )
        
        # Transformer
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=cnn_filters[-1], 
                nhead=num_heads, 
                dim_feedforward=512, 
                activation='relu', 
                batch_first=True
            ),
            num_layers=num_transformer_layers
        )
        
        # CNN Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(cnn_filters[3], cnn_filters[2], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(cnn_filters[2], cnn_filters[1], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(cnn_filters[1], cnn_filters[0], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(cnn_filters[0], 2, kernel_size=3, padding=1),
            nn.Tanh()
        )
    
    def forward(self, x):
        # Input: [batch_size, 2, 4800]
        
        # CNN Encoder
        x = self.encoder(x)  # Shape: [batch_size, cnn_filters[-1], 4800]
        
        # Permute for Transformer
        x = x.permute(0, 2, 1)  # Shape: [batch_size, 4800, cnn_filters[-1]]
        x = self.transformer(x)  # Shape: [batch_size, 4800, cnn_filters[-1]]
        x = x.permute(0, 2, 1)  # Shape: [batch_size, cnn_filters[-1], 4800]
        
        # CNN Decoder
        x = self.decoder(x)  # Shape: [batch_size, 2, 4800]
        
        return x



class PerceptualLoss(nn.Module):
    def __init__(self, feature_extractor):
        super(PerceptualLoss, self).__init__()
        self.feature_extractor = feature_extractor
        self.mse_loss = nn.MSELoss()
    
    def forward(self, pred, target):
        # Compute perceptual features
        pred_features = self.feature_extractor(pred)
        target_features = self.feature_extractor(target)
        
        # Perceptual loss
        perceptual_loss = self.mse_loss(pred_features, target_features)
        
        # Reconstruction loss
        reconstruction_loss = self.mse_loss(pred, target)
        
        return perceptual_loss + reconstruction_loss


# Dummy feature extractor for perceptual loss
class DummyFeatureExtractor(nn.Module):
    def __init__(self):
        super(DummyFeatureExtractor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.features(x)

In [4]:
import torch
model_save_path = '/home/j597s263/scratch/j597s263/Models/Lad_1.mod'
model = torch.load(model_save_path, weights_only=False)
model.eval()

print("Model Loaded")

Model Loaded


In [8]:
import torchaudio
import torch

def enhance_song(model, lossy_path, segment_duration, target_segment_size, save_path):
    """
    Enhances the quality of a lossy song using the trained model.

    Args:
        model: The trained AudioEnhancer model.
        lossy_path: Path to the lossy audio file.
        segment_duration: Duration of each segment in seconds (0.1 in this case).
        target_segment_size: The target size for each segment (lossless size).
        save_path: Path to save the enhanced audio.

    Returns:
        None
    """
    # Load the lossy audio
    waveform, sample_rate = torchaudio.load(lossy_path)
    print(f"Loaded lossy song with shape: {waveform.shape}, Sample rate: {sample_rate}")

    # Calculate the lossy segment size
    lossy_segment_size = int(sample_rate * segment_duration)

    # Split the lossy waveform into segments
    lossy_segments = [
        waveform[:, i:i + lossy_segment_size]
        for i in range(0, waveform.shape[1], lossy_segment_size)
        if waveform[:, i:i + lossy_segment_size].shape[1] == lossy_segment_size
    ]

    print(f"Total lossy segments: {len(lossy_segments)}")

    # Initialize the model for inference
    model.eval()
    device = next(model.parameters()).device

    # Enhance each segment
    enhanced_segments = []
    for segment in lossy_segments:
        # Pad the lossy segment to match the target segment size
        padded_segment = even_distribute_pad(segment, target_segment_size)
        padded_segment = padded_segment.unsqueeze(0).to(device)  # Add batch dimension

        # Enhance the segment
        with torch.no_grad():
            enhanced_segment = model(padded_segment)  # Forward pass
        enhanced_segments.append(enhanced_segment.squeeze(0).cpu())  # Remove batch dim

    # Reconstruct the enhanced waveform
    enhanced_waveform = torch.cat(enhanced_segments, dim=1)
    print(f"Enhanced waveform shape: {enhanced_waveform.shape}")

    # Save the enhanced song
    torchaudio.save(save_path, enhanced_waveform, sample_rate)
    print(f"Enhanced song saved to {save_path}")


def even_distribute_pad(segment, target_size):
    """
    Evenly distributes zeros between elements to match the target size.
    Operates directly in PyTorch without any NumPy conversions.
    """
    current_size = segment.shape[1]
    if current_size >= target_size:
        return segment  # No padding needed

    # Create an empty tensor with target size initialized to zero
    padded_segment = torch.zeros((segment.shape[0], target_size), dtype=segment.dtype, device=segment.device)

    # Calculate indices to place original values
    step = target_size / current_size
    indices = (torch.arange(current_size) * step).long()

    # Assign original values to the calculated indices
    padded_segment[:, indices] = segment

    return padded_segment

In [9]:
model_path = '/home/j597s263/scratch/j597s263/Models/Lad_1.mod'
lossy_audio_path = '/home/j597s263/scratch/j597s263/Datasets/Audio/Test/Test.wav'
save_path = '/home/j597s263/scratch/j597s263/Datasets/Audio/Test/Test(en).wav'

model = torch.load(model_path, weights_only=False)
model.eval()

# Define parameters
segment_duration = 0.1  # Match the dataset configuration
target_segment_size = int(48000 * segment_duration)  

# Enhance the song
enhance_song(model, lossy_audio_path, segment_duration, target_segment_size, save_path)

Loaded lossy song with shape: torch.Size([2, 4523771]), Sample rate: 32000
Total lossy segments: 1413
Enhanced waveform shape: torch.Size([2, 6782400])
Enhanced song saved to /home/j597s263/scratch/j597s263/Datasets/Audio/Test/Test(en).wav
