In [7]:
import os
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import Resample
import numpy as np
import soundfile as sf

In [8]:
# Convolutional denoising autoencoder operating on single-channel 1-D signals
class DenoisingAutoencoder(nn.Module):
    """Two-stage strided-convolution encoder mirrored by a transposed-conv decoder.

    The encoder maps 1 -> 64 -> 128 channels and the decoder maps
    128 -> 64 -> 1 channels.  Every (transposed) convolution uses
    ``kernel_size=4, stride=2, padding=1``, so each encoder stage halves the
    temporal length and each decoder stage doubles it; the output length
    equals the input length whenever the latter is divisible by 4.  The final
    ``Tanh`` bounds the output to ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()

        # Kernel geometry shared by every (transposed) convolution in the model.
        geometry = dict(kernel_size=4, stride=2, padding=1)

        downsampling_layers = [
            nn.Conv1d(1, 64, **geometry),
            nn.ReLU(),
            nn.Conv1d(64, 128, **geometry),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*downsampling_layers)

        upsampling_layers = [
            nn.ConvTranspose1d(128, 64, **geometry),
            nn.ReLU(),
            nn.ConvTranspose1d(64, 1, **geometry),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*upsampling_layers)

    def forward(self, x):
        """Encode ``x`` and immediately decode it, returning the reconstruction."""
        return self.decoder(self.encoder(x))

In [9]:
# Generator: three-stage convolutional encoder/decoder used for inpainting
class UNetGenerator(nn.Module):
    """Encoder-decoder generator for single-channel 1-D signals.

    The encoder maps 1 -> 64 -> 128 -> 256 channels and the decoder mirrors it
    back to a single channel, finishing with ``Tanh``.  All (transposed)
    convolutions use ``kernel_size=4, stride=2, padding=1``, so the output
    length equals the input length when the latter is divisible by 8.

    Note: despite the class name, the encoder and decoder are chained
    sequentially; no skip connections between them are implemented here.
    """

    def __init__(self):
        super().__init__()

        # Kernel geometry shared by every (transposed) convolution.
        geometry = dict(kernel_size=4, stride=2, padding=1)

        # First encoder stage has no batch norm; the remaining ones do.
        down = [nn.Conv1d(1, 64, **geometry), nn.ReLU()]
        for in_ch, out_ch in ((64, 128), (128, 256)):
            down += [
                nn.Conv1d(in_ch, out_ch, **geometry),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
            ]
        self.encoder = nn.Sequential(*down)

        # Decoder stages are normalised except for the final Tanh-bounded output stage.
        up = []
        for in_ch, out_ch in ((256, 128), (128, 64)):
            up += [
                nn.ConvTranspose1d(in_ch, out_ch, **geometry),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
            ]
        up += [nn.ConvTranspose1d(64, 1, **geometry), nn.Tanh()]
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder."""
        return self.decoder(self.encoder(x))



In [10]:
# Define Discriminator
class Discriminator(nn.Module):
    """Convolutional real/fake classifier for single-channel 1-D signals.

    Four stride-2 ``Conv1d`` stages (1 -> 64 -> 128 -> 256 -> 512 channels,
    ``LeakyReLU(0.2)`` activations, batch norm on all but the first stage) are
    followed by ``Flatten``, a ``Linear(512 * 4, 1)`` head and ``Sigmoid``.

    The linear head expects exactly 4 time steps from the conv stack, which
    only happens for inputs of roughly 64 samples.  To accept other lengths
    (e.g. 16000-sample clips), ``forward`` adaptively average-pools the conv
    features to 4 time steps right before flattening.  When the conv stack
    already yields 4 steps the pooling is an identity, so results for inputs
    that worked before are unchanged.  The pooling has no parameters and lives
    in ``forward`` only, so ``state_dict`` keys (``model.<idx>.*``) are the
    same as before and existing checkpoints still load.
    """

    # Number of time steps the linear head consumes per channel.
    _HEAD_STEPS = 4

    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(512 * self._HEAD_STEPS, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a probability in ``[0, 1]`` of shape ``(batch, 1)``.

        Args:
            x: tensor of shape ``(batch, 1, length)``.
        """
        for layer in self.model:
            if isinstance(layer, nn.Flatten):
                # Normalise the temporal length so the fixed-size linear head
                # works for any input length.
                x = nn.functional.adaptive_avg_pool1d(x, self._HEAD_STEPS)
            x = layer(x)
        return x



In [11]:
# Pipeline Functions
def denoise_audio(autoencoder, noisy_audio, device='cuda'):
    """Run ``autoencoder`` on ``noisy_audio`` for inference and return its output.

    The input tensor is moved to ``device``; the model itself is NOT moved and
    must already live on that device.  Gradients are disabled, and if the
    model is an ``nn.Module`` it is switched to evaluation mode for the call
    so that layers such as Dropout/BatchNorm behave deterministically and
    their running statistics are not modified.  The model's previous
    train/eval mode is restored before returning.

    Args:
        autoencoder: model (or any callable) applied to the audio tensor.
        noisy_audio: input tensor.
        device: device the input is moved to (default ``'cuda'``).

    Returns:
        The model output, detached from any autograd graph.
    """
    noisy_audio = noisy_audio.to(device)
    is_module = isinstance(autoencoder, torch.nn.Module)
    was_training = autoencoder.training if is_module else False
    if is_module:
        autoencoder.eval()
    try:
        with torch.no_grad():
            denoised_audio = autoencoder(noisy_audio)
    finally:
        # Leave the caller's model in the mode it was handed to us in.
        if is_module:
            autoencoder.train(was_training)
    return denoised_audio

def inpaint_audio(generator, denoised_audio, device='cuda'):
    """Run ``generator`` on ``denoised_audio`` for inference and return its output.

    The input tensor is moved to ``device``; the model itself is NOT moved and
    must already live on that device.  Gradients are disabled, and if the
    generator is an ``nn.Module`` it is switched to evaluation mode for the
    call: BatchNorm layers then normalise with their stored running statistics
    instead of per-batch statistics, and those running statistics are not
    updated as a side effect of inference.  The generator's previous
    train/eval mode is restored before returning.

    Args:
        generator: model (or any callable) applied to the audio tensor.
        denoised_audio: input tensor.
        device: device the input is moved to (default ``'cuda'``).

    Returns:
        The generator output, detached from any autograd graph.
    """
    denoised_audio = denoised_audio.to(device)
    is_module = isinstance(generator, torch.nn.Module)
    was_training = generator.training if is_module else False
    if is_module:
        generator.eval()
    try:
        with torch.no_grad():
            inpainted_audio = generator(denoised_audio)
    finally:
        # Leave the caller's model in the mode it was handed to us in.
        if is_module:
            generator.train(was_training)
    return inpainted_audio



In [16]:
# Dataset of distorted audio clips: loads, resamples, downmixes to mono and
# pads/truncates to a fixed length (no denoising or inpainting happens here)
class DistortedAudioDataset(torch.utils.data.Dataset):
    """Serve fixed-length mono waveforms from every file in a folder.

    Args:
        distorted_audio_folder: directory containing the audio files.
        fixed_length: number of samples each returned clip is padded or
            truncated to.
        target_sample_rate: sample rate (Hz) clips are resampled to before
            padding/truncation.  Pass ``None`` to keep each file's native
            rate (the behaviour before resampling was added).

    Each item is a ``(waveform, file_name)`` pair where ``waveform`` has shape
    ``(1, fixed_length)``.
    """

    def __init__(self, distorted_audio_folder, fixed_length=16000, target_sample_rate=16000):
        self.distorted_audio_folder = distorted_audio_folder
        self.fixed_length = fixed_length
        self.target_sample_rate = target_sample_rate
        # Only regular files: sub-directories returned by listdir cannot be
        # loaded as audio and would crash __getitem__.
        self.distorted_audio_files = sorted(
            name
            for name in os.listdir(distorted_audio_folder)
            if os.path.isfile(os.path.join(distorted_audio_folder, name))
        )
        # One resampler per source rate, built lazily and reused across items.
        self._resamplers = {}

    def __len__(self):
        return len(self.distorted_audio_files)

    def _resample(self, audio, sample_rate):
        """Resample ``audio`` from ``sample_rate`` to ``target_sample_rate`` if needed."""
        if self.target_sample_rate is None or sample_rate == self.target_sample_rate:
            return audio
        resampler = self._resamplers.get(sample_rate)
        if resampler is None:
            resampler = Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
            self._resamplers[sample_rate] = resampler
        return resampler(audio)

    def _process_audio(self, audio):
        """Downmix ``(channels, samples)`` audio to mono and force ``fixed_length`` samples."""
        # Convert to mono if the audio has more than one channel
        if audio.size(0) > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)
        # Pad (with trailing zeros) or truncate audio to fixed length
        if audio.size(1) > self.fixed_length:
            audio = audio[:, :self.fixed_length]
        elif audio.size(1) < self.fixed_length:
            padding = self.fixed_length - audio.size(1)
            audio = nn.functional.pad(audio, (0, padding), 'constant', 0)
        return audio

    def __getitem__(self, idx):
        file_name = self.distorted_audio_files[idx]
        distorted_path = os.path.join(self.distorted_audio_folder, file_name)
        distorted_audio, sample_rate = torchaudio.load(distorted_path)
        # Resample BEFORE truncation so fixed_length is counted in samples at
        # the target rate rather than at the file's native rate.
        distorted_audio = self._resample(distorted_audio, sample_rate)
        distorted_audio = self._process_audio(distorted_audio)
        return distorted_audio, file_name

# Main Execution: run every clip through denoiser -> generator and write the
# result to disk, then save the model weights.
if __name__ == "__main__":
    distorted_audio_folder = r"E:\UCSC Quarters\1. First Quarter - Fall 2024\CSE290D-Neural Computation\Historical Music Audio Restoration\HarmonyGAN\fma_small\fma_small\000"
    output_folder = "reconstructed_folder"
    model_save_folder = "Models"
    # Sample rate stamped on the written files.
    # NOTE(review): assumes the dataset yields 16 kHz audio — confirm.
    output_sample_rate = 16000

    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(model_save_folder, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load or define models
    autoencoder = DenoisingAutoencoder().to(device)
    generator = UNetGenerator().to(device)
    discriminator = Discriminator().to(device)

    # Load pre-trained models if available
    # NOTE: while these stay commented out, the models below run (and are
    # saved) with their freshly initialised weights.
    # autoencoder.load_state_dict(torch.load('autoencoder.pth'))
    # generator.load_state_dict(torch.load('generator.pth'))

    # Inference only: evaluation mode makes BatchNorm use its stored running
    # statistics and keeps those statistics from being mutated by the loop
    # below before the checkpoints are written.
    autoencoder.eval()
    generator.eval()

    dataset = DistortedAudioDataset(distorted_audio_folder)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

    for distorted_audio, file_name in dataloader:
        distorted_audio = distorted_audio.to(device)

        # Step 1: Denoise
        denoised_audio = denoise_audio(autoencoder, distorted_audio, device)

        # Step 2: Inpaint
        reconstructed_audio = inpaint_audio(generator, denoised_audio, device)

        # Save Reconstructed Audio
        # Drop the batch dimension (batch_size is 1), leaving (channels, samples).
        reconstructed_audio = reconstructed_audio.squeeze(0).detach().cpu().numpy()
        reconstructed_audio = np.clip(reconstructed_audio, -1.0, 1.0)
        # soundfile expects (samples, channels), hence the transpose; the
        # output keeps the input file's name (file_name is a 1-element batch).
        sf.write(os.path.join(output_folder, file_name[0]), reconstructed_audio.T, output_sample_rate)

    # Save Models
    torch.save(autoencoder.state_dict(), os.path.join(model_save_folder, "autoencoder.pth"))
    torch.save(generator.state_dict(), os.path.join(model_save_folder, "generator.pth"))
    torch.save(discriminator.state_dict(), os.path.join(model_save_folder, "discriminator.pth"))
    print("Models and reconstructed audio saved successfully.")

Models and reconstructed audio saved successfully.
