Todd Goldfarb

tcgoldfarb@gmail.com


### Code

In [None]:
### Imports
!pip install numpy matplotlib librosa pydub torch soundfile


import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import soundfile
import pydub
from pydub import AudioSegment
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as functional
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import json
import os



In [None]:
############################################
### AUDIO/SPECTROGRAM RELATED OPERATIONS ###
############################################

### CONVERSION FUNCTIONS ###
def ConvertWAVToArray(audioPath):
  """Load an audio file as a sample array.

  Passing sr=None asks librosa to keep the file's native sample rate
  instead of resampling.

  Returns:
    (audio, sr): the sample array and its sample rate, as returned by
    librosa.load.
  """
  samples, sampleRate = librosa.load(audioPath, sr=None)
  return samples, sampleRate

def ConvertToTensor(np):
  """Convert a numpy array into a float32 torch tensor with a leading channel axis.

  NOTE: the parameter name shadows the module-level numpy alias inside this
  function; it is kept because it is part of the existing signature.

  Returns:
    A tensor of shape (1, *np.shape).
  """
  # float() casts to float32; unsqueeze(0) prepends the channel dimension.
  return torch.from_numpy(np).float().unsqueeze(0)

def ConvertTensorToSpectrogram(tensor):
  """Convert a torch tensor into a numpy spectrogram array.

  All size-1 dimensions (e.g. batch and channel) are squeezed away, the
  tensor is detached from the autograd graph and moved to host memory, so
  the function works for tensors on any device.

  Returns:
    A numpy array holding the tensor's values.
  """
  # .cpu() is a no-op for CPU tensors and is required before .numpy()
  # for tensors that live on an accelerator.
  spectrogram = tensor.squeeze().detach().cpu().numpy()
  return spectrogram

def ConvertArrayToSpectrogram(array):
  """Turn an audio sample array into a dB-scaled magnitude spectrogram plus phase.

  The STFT is computed with librosa's default parameters. The magnitude is
  converted to decibels relative to its own maximum (ref=np.max), so the
  loudest bin maps to 0 dB and everything else is negative.

  Returns:
    (spectrogram, phase): the dB magnitude array and the phase angle array
    (radians) of the STFT.
  """
  complexSTFT = librosa.stft(array)

  # Split the complex STFT into magnitude (converted to dB) and phase.
  spectrogramDB = librosa.amplitude_to_db(np.abs(complexSTFT), ref=np.max)
  phaseAngles = np.angle(complexSTFT)

  return spectrogramDB, phaseAngles

### OUTPUT THE SPECTROGRAM TO A WAV FILE WITH A SPECIFIED OUTPUT PATH ###
def OutputSpectrogramToWAV(spectrogram_db, phase, sr, outputPath, ref=15.0):
    """Reconstruct audio from a dB spectrogram and phase, and write it to a WAV file.

    Args:
        spectrogram_db: magnitude spectrogram in decibels.
        phase: STFT phase angles (radians), same shape as spectrogram_db.
        sr: sample rate used when writing the file.
        outputPath: destination path of the WAV file.
        ref: reference amplitude that 0 dB maps back to. The default (15.0)
            is the value this function always used.
            NOTE(review): the forward conversion in this file uses
            ref=np.max, so the true reference is not stored anywhere;
            pass the original peak magnitude here to restore the original
            loudness.
    """
    # Convert dB back to linear amplitude: ref * 10 ** (dB / 20)
    spectrogram_amplitude = librosa.db_to_amplitude(spectrogram_db, ref=ref)

    # Recombine magnitude and phase into the complex-valued STFT
    stft_complex = spectrogram_amplitude * np.exp(1j * phase)

    # Inverse STFT back to a time-domain signal
    audio_reconstructed = librosa.istft(stft_complex)

    # Write to WAV
    soundfile.write(outputPath, audio_reconstructed, sr)

### DAMAGES (REMOVES) AUDIO AT SPECIFIC TIME FRAME ###
def DamageSpectrogram(spectrogram, sr, n_fft=2048, hop_length=512,
                      startSec=2.25, endSec=2.50, fillValue=0):
    """Return a copy of a spectrogram with a time window overwritten.

    Args:
        spectrogram: 2-D array whose second axis is indexed by STFT frame.
        sr: sample rate of the audio the spectrogram was computed from.
        n_fft: unused; kept for interface compatibility.
        hop_length: number of samples between successive STFT frames.
        startSec: start of the damaged window, in seconds.
        endSec: end of the damaged window (exclusive), in seconds.
        fillValue: value written into the damaged columns. The default 0
            preserves the previous behavior.
            NOTE(review): for a dB spectrogram produced with ref=np.max,
            0 dB is the loudest level, not silence; a value such as -80
            may be what is actually wanted. Confirm before changing.

    Returns:
        (damagedSpectrogram, startSec, endSec). The input array is not
        modified.

    Raises:
        ValueError: if the window is negative or ends before it starts.
    """
    if startSec < 0 or endSec < startSec:
        raise ValueError("damage window must satisfy 0 <= startSec <= endSec")

    # Seconds of audio covered by one spectrogram column
    time_step = hop_length / sr

    startCol = int(startSec / time_step)
    endCol = int(endSec / time_step)

    damagedSpectrogram = np.copy(spectrogram)

    # Apply damage; a window past the end of the array is silently clipped
    # by numpy slicing.
    damagedSpectrogram[:, startCol:endCol] = fillValue

    return damagedSpectrogram, startSec, endSec

def PlotSpectrogram(spectrogram, sr, isDamaged=False):
  """Display a spectrogram with a time x-axis and a log-frequency y-axis.

  Args:
    spectrogram: 2-D array to draw.
    sr: sample rate, forwarded to librosa for axis scaling.
    isDamaged: selects the plot title (broken vs. clean audio).
  """
  if isDamaged:
    plotTitle = 'Spectrogram (Broken Audio)'
  else:
    plotTitle = 'Spectrogram (Clean Audio)'

  plt.figure(figsize=(12, 8))
  librosa.display.specshow(spectrogram, sr=sr, x_axis='time', y_axis='log')
  plt.colorbar(format='%+2.0f dB')
  plt.title(plotTitle)
  plt.tight_layout()
  plt.show()

In [None]:
##########################
### DATASET GENERATION ###
##########################

## PREPARE AUDIO LIST
def SegmentAudio(audioPathList, secPerChunk=3.0, outputDir="Audio_Inputs_Segmented"):
  """Cut each audio file into equal-length chunks and write them as WAV files.

  Only complete chunks are written; a trailing remainder shorter than
  secPerChunk is discarded. Chunk files are named
  input_<fileIndex>_<chunkIndex>.wav inside outputDir.

  Args:
    audioPathList: iterable of paths to the source audio files.
    secPerChunk: chunk length in seconds.
    outputDir: directory the chunks are written to (created if missing).

  Returns:
    List of the chunk file paths, in the order they were written.

  Raises:
    ValueError: if the chunk length rounds down to zero samples.
  """
  os.makedirs(outputDir, exist_ok=True)

  preparedAudioPathList = []

  for fileIndex, audioPath in enumerate(audioPathList):
    # Load the audio file at its native sample rate
    audio, sr = librosa.load(audioPath, sr=None)

    samplesPerChunk = int(secPerChunk * sr)
    if samplesPerChunk <= 0:
      raise ValueError("secPerChunk must span at least one sample")

    # Number of COMPLETE chunks. Floor division keeps the final chunk when
    # the audio length is an exact multiple of the chunk size (the previous
    # ceil(...) - 1 dropped it).
    totalChunks = len(audio) // samplesPerChunk

    for chunkIndex in range(totalChunks):
      startSample = chunkIndex * samplesPerChunk
      chunk = audio[startSample:startSample + samplesPerChunk]

      # OUTPUT CHUNK
      outputPath = os.path.join(outputDir, 'input_' + str(fileIndex) + '_' + str(chunkIndex) + '.wav')
      soundfile.write(outputPath, chunk, sr)
      preparedAudioPathList.append(outputPath)

  return preparedAudioPathList


### GENERATES A FULL DATAPOINT
def GenerateDataPoint(audioPath):
  """Build one training sample from an audio file.

  The file is loaded, converted to a dB spectrogram (plus phase), and a
  damaged copy of the spectrogram is produced with DamageSpectrogram's
  default window. Both spectrograms are converted to tensors.

  Returns:
    (sr, phase, cleanSpecTensor, damagedSpecTensor, startSec, endSec)
  """
  waveform, sr = ConvertWAVToArray(audioPath)

  # Clean spectrogram and its damaged counterpart
  cleanSpectrogram, phase = ConvertArrayToSpectrogram(waveform)
  damagedSpectrogram, startSec, endSec = DamageSpectrogram(cleanSpectrogram, sr)

  return (
    sr,
    phase,
    ConvertToTensor(cleanSpectrogram),
    ConvertToTensor(damagedSpectrogram),
    startSec,
    endSec,
  )

### GENERATES BOTH GENERATOR AND DISCRIMINATOR DATASETS
### PASS IN A LIST OF AUDIOPATHS and A NUMBER OF SAMPLES PER AUDIO PIECE
def GenerateDatasets(audioPathList, samplesPerAudio=1):
  """Build the generator and discriminator datasets from a list of audio files.

  For every path, GenerateDataPoint is called samplesPerAudio times. Each
  call contributes one row to both datasets, laid out as:
    [audioPath, sr, cleanSpecTensor, damagedSpecTensor, startTime, endTime, phase]

  NOTE: segmentation via SegmentAudio is currently bypassed (temporary
  testing); the paths are used exactly as given.

  Returns:
    (GeneratorDataset, DiscriminatorDataset): two separate lists of rows.
  """
  generatorRows = []
  discriminatorRows = []

  # TEMPORARY TESTING: skip SegmentAudio(audioPathList) and use the raw paths
  pathsToProcess = audioPathList

  for path in pathsToProcess:
    for _ in range(samplesPerAudio):
      sr, phase, cleanSpecTensor, damagedSpecTensor, startTime, endTime = GenerateDataPoint(path)

      row = [
        path,               # For output filePath
        sr,                 # For converting back to WAV
        cleanSpecTensor,    # Generator: actualLoss target / Discriminator: realLoss
        damagedSpecTensor,  # Generator: contextInput / Discriminator: source of the fake spectrogram
        startTime,          # Start of the damaged window
        endTime,            # End of the damaged window
        phase,              # Used for WAV output
      ]
      generatorRows.append(row)
      # Independent list object so the two datasets never share row containers
      discriminatorRows.append(list(row))

  return generatorRows, discriminatorRows


## Build the training datasets from the raw audio files listed below.
## NOTE(review): earlier comments here described 5-second segments damaged at
## 2.25 - 2.75 s, which does not match the code in this file: SegmentAudio cuts
## 3.0 s chunks (and its call inside GenerateDatasets is commented out, so the
## files are currently used whole), and DamageSpectrogram's window is 2.25 - 2.50 s.
audioPathList = ["dataset1.wav", "dataset2.wav", "dataset3.wav", "dataset4.wav"]
samplesPerAudio = 1
GeneratorDataset, DiscriminatorDataset = GenerateDatasets(audioPathList, samplesPerAudio)

  audio, sr = librosa.load(audioPath, sr=None)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


FileNotFoundError: [Errno 2] No such file or directory: 'dataset1.wav'

In [None]:
### MODEL DEFINITIONS ###


### Feature map way too big for 1 second of audio
### Output being stretched by stride? - MAYBE
### Capping high and low values (LOW VALUES DONE, HIGH VALUES PROBABLY NOT HELPFUL)
### Batch normalization?
### activation functions?
## proper DB scaling?
# Phase?
# More data?
## ADD PADDING INSTEAD OF INTERPOLATING (could be stretching out those values)
## New loss criterion

class Generator(nn.Module):
    """Convolutional encoder/decoder that fills a damaged time window of a spectrogram.

    Seven convolution layers encode the input (six of them with stride 2),
    seven transposed-convolution layers decode it. After each of the first
    six decoder layers, a convolved copy of the matching encoder output is
    resized to the decoder output's spatial size and added to it. No
    activation or normalization layers are applied.

    forward() returns the input spectrogram with only the columns of the
    damaged window replaced by the network's prediction.
    """

    def __init__(self):
        # Zero-argument super() does not look up the global name `Generator`,
        # so construction keeps working even if that name is later rebound.
        super().__init__()

        # ENCODER/DOWNSAMPLING
        self.Encoder1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(3, 3), stride=1, padding=1, padding_mode='replicate')
        self.Encoder2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=2, padding=1, padding_mode='replicate')
        self.Encoder3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=2, padding=1, padding_mode='replicate')
        self.Encoder4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=2, padding=1, padding_mode='replicate')
        self.Encoder5 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=(3, 3), stride=2, padding=1, padding_mode='replicate')
        self.Encoder6 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=(3, 3), stride=2, padding=1, padding_mode='replicate')
        self.Encoder7 = nn.Conv2d(in_channels=2048, out_channels=4096, kernel_size=(3, 3), stride=2, padding=1, padding_mode='replicate')

        # DECODER/UPSAMPLING
        self.Decoder1 = nn.ConvTranspose2d(in_channels=4096, out_channels=2048, kernel_size=(3, 3), stride=2, padding=1)
        self.Decoder2 = nn.ConvTranspose2d(in_channels=2048, out_channels=1024, kernel_size=(3, 3), stride=2, padding=1)
        self.Decoder3 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=(3, 3), stride=2, padding=1)
        self.Decoder4 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=(3, 3), stride=2, padding=1)
        self.Decoder5 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=(3, 3), stride=2, padding=1)
        self.Decoder6 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(3, 3), stride=2, padding=1)
        self.Decoder7 = nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=(3, 3), stride=1, padding=1)

        # SKIP CONNECTIONS
        # The four shallow skips use padding=0 (output shrinks by 2 per axis);
        # forward() resizes every skip to the decoder output before adding.
        self.Skip64 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=1, padding=0, padding_mode='replicate')
        self.Skip128 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=1, padding=0, padding_mode='replicate')
        self.Skip256 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=1, padding=0, padding_mode='replicate')
        self.Skip512 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=1, padding=0, padding_mode='replicate')
        self.Skip1024 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=(3, 3), stride=1, padding=1, padding_mode='replicate')
        self.Skip2048 = nn.Conv2d(in_channels=2048, out_channels=2048, kernel_size=(3, 3), stride=1, padding=1, padding_mode='replicate')

    def _decodeWithSkip(self, decoder, skip, decoderInput, encoderFeature):
        """Run one decoder layer and add the resized skip feature to its output."""
        # The decoder layer runs exactly once; its output also provides the
        # target size for the skip feature.
        decoded = decoder(decoderInput)
        skipFeature = functional.interpolate(skip(encoderFeature), size=decoded.shape[2:], mode='bilinear', align_corners=False)
        return decoded + skipFeature

    def forward(self, contextInput, startTime, endTime, sr, hopLength=512):
        """Predict the damaged window and splice it into the input spectrogram.

        Args:
            contextInput: damaged spectrogram batch; indexed here as
                (batch, channel, frequency, time) and fed to a 1-channel
                convolution.
            startTime: start of the damaged window in seconds.
            endTime: end of the damaged window in seconds (exclusive).
            sr: sample rate of the underlying audio.
            hopLength: STFT hop length in samples used to map seconds to
                spectrogram columns.

        startTime, endTime and sr may be Python numbers or single-element
        tensors (as a batch-size-1 DataLoader yields); the column indices
        are obtained with int(), which does not accept multi-element tensors.

        Returns:
            A copy of contextInput's first channel in which only the columns
            startCol:endCol come from the network's output.
        """
        # ENCODE
        E1Output = self.Encoder1(contextInput)
        E2Output = self.Encoder2(E1Output)
        E3Output = self.Encoder3(E2Output)
        E4Output = self.Encoder4(E3Output)
        E5Output = self.Encoder5(E4Output)
        E6Output = self.Encoder6(E5Output)
        E7Output = self.Encoder7(E6Output)

        # DECODE, adding the matching encoder feature after every layer
        D1Output = self._decodeWithSkip(self.Decoder1, self.Skip2048, E7Output, E6Output)
        D2Output = self._decodeWithSkip(self.Decoder2, self.Skip1024, D1Output, E5Output)
        D3Output = self._decodeWithSkip(self.Decoder3, self.Skip512, D2Output, E4Output)
        D4Output = self._decodeWithSkip(self.Decoder4, self.Skip256, D3Output, E3Output)
        D5Output = self._decodeWithSkip(self.Decoder5, self.Skip128, D4Output, E2Output)
        D6Output = self._decodeWithSkip(self.Decoder6, self.Skip64, D5Output, E1Output)

        # Raw network output; no dB rescaling is applied here.
        scaledOutput = self.Decoder7(D6Output)

        ## SPLICE THE PREDICTION BACK INTO THE ORIGINAL SPECTROGRAM
        ## CREATING A "HEALED" SPECTROGRAM
        # Seconds of audio per spectrogram column
        timeStep = hopLength / sr
        startCol = int(startTime / timeStep)
        endCol = int(endTime / timeStep)

        # Copy so the caller's tensor is left untouched by the in-place splice
        finalOutput = contextInput[:, 0:1].clone()
        finalOutput[:, :, :, startCol:endCol] = scaledOutput[:, :, :, startCol:endCol]

        return finalOutput

In [None]:
### TRAINING ###
# This cell rebinds the name `Generator` from the class to a model instance
# (later cells call Generator.eval() / Generator(...) on it). If the cell is
# re-run, restore the class first so a fresh model can be constructed.
if not isinstance(Generator, type):
  Generator = type(Generator)
Generator = Generator()
#Discriminator = Discriminator()

### HYPER PARAMETERS
# batchSize must stay 1: Generator.forward converts startTime/endTime/sr with
# int(), which only works for single-element tensors.
batchSize = 1
numWorkers = 2
numEpochs = 100
learningRate = 0.0002

# LOSS AND OPTIMIZERS
# Loss modules are stateless, so they are built once instead of every epoch.
# BCELoss is not used in this cell; it is kept defined as before.
BCELoss = nn.BCELoss()
MSELoss = nn.MSELoss()
optimizerGenerator = optim.Adam(Generator.parameters(), lr=learningRate, betas=(0.5, 0.999))
#optimizerDiscriminator = optim.Adam(Discriminator.parameters(), lr=learningRate, betas=(0.5, 0.999))

# TRAINING LOOP
DataloaderG = DataLoader(GeneratorDataset, batch_size=batchSize, shuffle=True, num_workers=numWorkers)
### Each batch yielded by DataloaderG is indexed as follows:
### [0] is the audioPath
### [1] is the sr (used as Input)
### [2] is the cleanSpecTensor (Used for actualLoss)
### [3] is the damagedSpecTensor (Used as Context Input)
### [4] is the startTime input (Used as a startTime Input)
### [5] is the endTime input (Used as an endTime Input)

# Weight of the reconstruction loss in the total generator loss
actualWeight = 1

for epoch in range(numEpochs):
  print("------------ " + "EPOCH: " + str(epoch) + " ------------")

  ### TRAIN GENERATOR ###
  for GeneratorData in DataloaderG:
    Generator.zero_grad()
    # Call the module itself (not .forward) so that any registered hooks run
    genOutput = Generator(GeneratorData[3],
                          GeneratorData[4],
                          GeneratorData[5],
                          GeneratorData[1])

    ### THE GENERATED VS REAL ###
    actualLoss = MSELoss(genOutput, GeneratorData[2])

    trueLoss = (actualLoss * actualWeight)
    print("GENERATOR STEP: trueLoss = ")
    print(trueLoss)
    trueLoss.backward()

    optimizerGenerator.step()

In [None]:
### TESTING POST-TRAINING
testURL = "Audio_Inputs_Segmented/input_0_0.wav"
# The healed audio goes to its own file. Previously it was written to testURL,
# which overwrote the test input with the reconstruction.
outputURL = "Audio_Inputs_Segmented/output_0_0.wav"
sr, phase, cleanSpecTensor, damagedSpecTensor, startTime, endTime = GenerateDataPoint(testURL)

Generator.eval()
with torch.no_grad():
  # unsqueeze(0) adds the batch dimension in front of the channel dimension
  output = Generator(damagedSpecTensor.unsqueeze(0), startTime, endTime, sr)
  # Floor the prediction at -80 dB
  # NOTE(review): presumably chosen to match the dynamic range of the dB
  # spectrograms produced in this file - confirm.
  output = torch.clamp(output, min=-80)

print(output.shape)

outputSpectrogram = ConvertTensorToSpectrogram(output)
targetSpectrogram = ConvertTensorToSpectrogram(cleanSpecTensor)
PlotSpectrogram(targetSpectrogram, sr)
inputSpectrogram = ConvertTensorToSpectrogram(damagedSpecTensor[0:1, :, :].unsqueeze(0))
PlotSpectrogram(inputSpectrogram, sr, True)


PlotSpectrogram(outputSpectrogram, sr, True)
print(outputSpectrogram.shape)
print(type(outputSpectrogram))
# Reuse the phase of the clean signal to invert the healed spectrogram
OutputSpectrogramToWAV(outputSpectrogram, phase, sr, outputURL)

In [None]:
### Accuracy Calculations

# Compare the healed output against the CLEAN target. The previous version
# compared against the damaged input with exact float equality; since the
# generator copies every undamaged column straight from that input, the
# result only reflected the size of the damaged window, not model quality.
originalArray = targetSpectrogram
reconstructedArray = outputSpectrogram

# Two dB values count as matching when they differ by at most this much.
# NOTE(review): 1.0 dB is a starting point, not a calibrated threshold.
toleranceDB = 1.0

# ELEMENT-WISE COMPARISON (absolute tolerance only)
comparison = np.isclose(originalArray, reconstructedArray, rtol=0.0, atol=toleranceDB)

# Count TRUE values
correctElements = np.sum(comparison)

totalElements = np.size(originalArray)

accuracyPercentage = (correctElements / totalElements) * 100

print(f"Accuracy: {accuracyPercentage}%")