In [1]:
import torchaudio.transforms as T
import torch
import torchaudio
from torch.utils.data import Dataset
import os
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import random

In [2]:
class AudioDataset(Dataset):
    def __init__(self, lossless_dir, lossy_dir):
        """
        Initializes the dataset by processing all audio files in the given directories.
        Each observation is a pair of corresponding lossy and lossless stereo segments.
        """
        self.data = []

        # Get sorted lists of lossless and lossy files
        lossless_files = sorted(
            [os.path.join(lossless_dir, f) for f in os.listdir(lossless_dir) if os.path.isfile(os.path.join(lossless_dir, f))]
        )
        lossy_files = sorted(
            [os.path.join(lossy_dir, f) for f in os.listdir(lossy_dir) if os.path.isfile(os.path.join(lossy_dir, f))]
        )

        # Ensure equal number of files
        assert len(lossless_files) == len(lossy_files), "Mismatch in number of lossless and lossy files!"

        # Process files and create dataset
        for idx, (lossless_file, lossy_file) in enumerate(zip(lossless_files, lossy_files)):
            self.data.extend(self.process_pair(lossless_file, lossy_file))
            if (idx + 1) % 10 == 0:
                print(f"Processed {idx + 1} files...")

        print(f"Dataset created with {len(self.data)} segment pairs.")

    def process_pair(self, lossless_path, lossy_path):
        """
        Processes a pair of lossless and lossy files into aligned stereo segments.
        """
        # Preprocess lossless and lossy audio
        lossless_segments, lossless_segment_size = self.preprocess(lossless_path)
        lossy_segments, lossy_segment_size = self.preprocess(lossy_path)

        # Handle floating-point mismatch in segment counts
        if len(lossless_segments) > len(lossy_segments):
            for _ in range(len(lossless_segments) - len(lossy_segments)):
                # Add padded segments to lossy segments to match lossless count
                padding_segment = torch.zeros_like(lossless_segments[0])
                lossy_segments.append(padding_segment)

        # Randomly pad lossy segments to match lossless size
        padded_lossy_segments = [
            self.random_pad(lossy_segment, lossless_segment.shape[1])
            for lossy_segment, lossless_segment in zip(lossy_segments, lossless_segments)
        ]

        return list(zip(padded_lossy_segments, lossless_segments))

    def preprocess(self, file_path):
        """
        Loads an audio file, calculates dynamic segment size, and splits into stereo segments.
        """
        waveform, sample_rate = torchaudio.load(file_path)

        # Calculate dynamic segment size
        segment_size = int((sample_rate / 1000) * 2)

        # Split waveform into fixed-size segments
        segments = [
            waveform[:, i:i + segment_size]
            for i in range(0, waveform.shape[1], segment_size)
            if waveform[:, i:i + segment_size].shape[1] == segment_size
        ]
        return segments, segment_size

    def random_pad(self, segment, target_size):
        """
        Randomly pads the input segment to match the target size.
        Args:
            segment: Input segment of shape [2, current_size].
            target_size: Desired target size (number of samples).
        Returns:
            Padded segment of shape [2, target_size].
        """
        current_size = segment.shape[1]
        if current_size >= target_size:
            return segment  # No padding needed

        # Calculate padding size
        padding_size = target_size - current_size
        front_pad = random.randint(0, padding_size)
        back_pad = padding_size - front_pad

        # Apply padding
        return torch.nn.functional.pad(segment, (front_pad, back_pad))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Returns a single pair of lossy and lossless stereo segments.
        """
        return self.data[idx]

In [None]:
counter=0
dataset = AudioDataset(lossless_dir='/home/j597s263/scratch/j597s263/Datasets/Audio/Lossless/', lossy_dir='/home/j597s263/scratch/j597s263/Datasets/Audio/Lossy/')

print(f"Total number of observations (segment pairs): {len(dataset)}")

# Example: Retrieve a single observation
lossy_segment, lossless_segment = dataset[0]
print(f"Lossy segment shape: {lossy_segment.shape}")
print(f"Lossless segment shape: {lossless_segment.shape}")

Processed 10 files...
Processed 20 files...
Processed 30 files...
Processed 40 files...
Processed 50 files...
Processed 60 files...
Processed 70 files...
Processed 80 files...
Processed 90 files...
Processed 100 files...


In [7]:
dataset[0][0]

tensor([[ 0.0000,  0.0000,  0.0000,  ..., -0.0019, -0.0102, -0.0161],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0071, -0.0132, -0.0156]])

In [4]:
class AudioEnhancer(nn.Module):
    def __init__(self, num_transformer_layers=2, num_heads=8, cnn_filters=[32, 64, 128, 256]):
        super(AudioEnhancer, self).__init__()
        
        # CNN Encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(2, cnn_filters[0], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(cnn_filters[0], cnn_filters[1], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(cnn_filters[1], cnn_filters[2], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(cnn_filters[2], cnn_filters[3], kernel_size=3, padding=1),
            nn.ReLU()
        )
        
        # Transformer
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(d_model=cnn_filters[-1], nhead=num_heads, dim_feedforward=512, activation='relu'),
            num_layers=num_transformer_layers
        )
        
        # CNN Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(cnn_filters[3], cnn_filters[2], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(cnn_filters[2], cnn_filters[1], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(cnn_filters[1], cnn_filters[0], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(cnn_filters[0], 2, kernel_size=3, padding=1),
            nn.Tanh()
        )
    
    def forward(self, x):
        # Input: [batch_size, 2, 48000]
        
        # CNN Encoder
        x = self.encoder(x)  # Shape: [batch_size, cnn_filters[-1], 48000]
        
        # Permute for Transformer
        x = x.permute(0, 2, 1)  # Shape: [batch_size, 48000, cnn_filters[-1]]
        x = self.transformer(x)  # Shape: [batch_size, 48000, cnn_filters[-1]]
        x = x.permute(0, 2, 1)  # Shape: [batch_size, cnn_filters[-1], 48000]
        
        # CNN Decoder
        x = self.decoder(x)  # Shape: [batch_size, 2, 48000]
        
        return x


In [5]:
class PerceptualLoss(nn.Module):
    def __init__(self, feature_extractor):
        super(PerceptualLoss, self).__init__()
        self.feature_extractor = feature_extractor
        self.mse_loss = nn.MSELoss()
    
    def forward(self, pred, target):
        # Compute perceptual features
        pred_features = self.feature_extractor(pred)
        target_features = self.feature_extractor(target)
        
        # Perceptual loss
        perceptual_loss = self.mse_loss(pred_features, target_features)
        
        # Reconstruction loss
        reconstruction_loss = self.mse_loss(pred, target)
        
        return perceptual_loss + reconstruction_loss

In [6]:
def train_model(model, dataloader, optimizer, loss_fn, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for lossy, lossless in dataloader:
            # Move to GPU if available
            lossy, lossless = lossy.to(device), lossless.to(device)
            
            # Forward pass
            output = model(lossy)
            
            # Compute loss
            loss = loss_fn(output, lossless)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")        

In [7]:
audio = DataLoader(dataset, batch_size=4, shuffle=True)

In [9]:
device = 'cuda:1'
model = AudioEnhancer(num_transformer_layers=2, num_heads=8).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)



In [11]:
class DummyFeatureExtractor(nn.Module):
    def __init__(self):
        super(DummyFeatureExtractor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.features(x)

# Initialize feature extractor
feature_extractor = DummyFeatureExtractor().to(device)
loss_fn = PerceptualLoss(feature_extractor).to(device)

In [12]:
# Train the model
train_model(model, audio, optimizer, loss_fn, num_epochs=10)

Epoch 1/10, Loss: 0.0048
Epoch 2/10, Loss: 0.0015
Epoch 3/10, Loss: 0.0014
Epoch 4/10, Loss: 0.0013
Epoch 5/10, Loss: 0.0013
Epoch 6/10, Loss: 0.0012
Epoch 7/10, Loss: 0.0012
Epoch 8/10, Loss: 0.0012
Epoch 9/10, Loss: 0.0012
Epoch 10/10, Loss: 0.0012


In [13]:
model

AudioEnhancer(
  (encoder): Sequential(
    (0): Conv1d(2, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): ReLU()
    (2): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (3): ReLU()
    (4): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (5): ReLU()
    (6): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))
    (7): ReLU()
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dro

In [14]:
torch.save(model, '/home/j597s263/Models/Audio.mod')