# Drum Sound VAE
This notebook contains the implementation of a Variational Autoencoder (VAE) for generating drum sounds.

## Project Set Up

In [None]:
!pip install pytorch-ignite torchvision

In [1]:
import glob
import os

import librosa
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
import torchaudio
import torchaudio.transforms as transforms
from IPython.display import Audio, display
from sklearn.preprocessing import MinMaxScaler
from torch import nn
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset
from torchvision.utils import make_grid

In [None]:
# Fix the PyTorch RNG seed so that runs are reproducible
SEED = 1221
torch.manual_seed(SEED)

# Use the first CUDA device when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

## Train Audio VAE

In [None]:
# Extract the kick sample archive into /content
import zipfile

with zipfile.ZipFile("/content/4000Kicks.zip", mode="r") as archive:
  archive.extractall(path="/content")

In [None]:
# Save log power spectrums to files and get normalisation scaler
def compute_specs(path, new_SR, fft_length, hop_length):
  """Compute, normalise and save a log-magnitude spectrogram for each audio file.

  Every file matched by `path` is converted to mono, resampled to `new_SR`,
  padded or truncated to exactly `new_SR` samples (one second), and turned into
  a log-magnitude STFT. All spectrograms are then min-max normalised with a
  single scaler fitted on the global minimum and maximum, and saved next to the
  audio file with a ``.pt`` extension.

  Args:
    path: Glob pattern selecting the audio files.
    new_SR: Target sample rate; also the number of samples kept per file.
    fft_length: FFT size (and Hann window length) passed to `torch.stft`.
    hop_length: Hop size passed to `torch.stft`.

  Returns:
    The fitted `MinMaxScaler`, needed later to invert the normalisation.

  Raises:
    FileNotFoundError: If `path` matches no files.
  """
  # Get relevant audio paths
  audio_file_paths = glob.glob(path)
  if not audio_file_paths:
    raise FileNotFoundError(f"No audio files matched the pattern: {path!r}")

  # Loop-invariant helpers: analysis window and the fade-out ramp (1/20 s)
  hann_window = torch.hann_window(fft_length)
  fade_length = int(new_SR/20)
  fade_out = torch.linspace(1, 0, fade_length)

  # One resampler per source sample rate instead of one per file
  resamplers = {}

  # Log magnitudes are kept in memory until the global range is known
  magnitude_list = []
  min_value = float("inf")
  max_value = float("-inf")

  for audio_file_path in audio_file_paths:
    waveform, sample_rate = torchaudio.load(audio_file_path)

    # Convert to mono, keeping a leading channel dimension
    waveform = torch.mean(waveform, dim=0).unsqueeze(0)

    # Resample to the target sample rate
    if sample_rate not in resamplers:
      resamplers[sample_rate] = transforms.Resample(sample_rate, new_SR)
    waveform = resamplers[sample_rate](waveform)

    # Pad or truncate audio to be exactly new_SR samples long
    pad_width = new_SR - waveform.size(1)
    if pad_width > 0:
      waveform = torch.nn.functional.pad(waveform, (0, pad_width), 'constant', 0)
    else:
      waveform = waveform[:, :new_SR]
      # Fade out the tail so a truncated sample does not end with a click
      waveform[:, -fade_length:] *= fade_out

    # Log-magnitude STFT; the small constant avoids log(0)
    stft = torch.stft(waveform, fft_length, hop_length, return_complex=True, window=hann_window)
    log_magnitude = torch.log(torch.abs(stft) + 1e-20)
    magnitude_list.append(log_magnitude)

    # Track the global range without concatenating every spectrogram
    min_value = min(min_value, log_magnitude.min().item())
    max_value = max(max_value, log_magnitude.max().item())

  # Fit the scaler on the global extremes only
  scaler = MinMaxScaler()
  scaler.fit([[min_value], [max_value]])

  for audio_file_path, log_magnitude in zip(audio_file_paths, magnitude_list):
    # Same location as the audio file, with the extension swapped for .pt
    spec_file_name = os.path.splitext(audio_file_path)[0] + ".pt"
    flat_normalized = scaler.transform(log_magnitude.reshape(-1, 1).numpy())
    normalized_magnitude = torch.from_numpy(flat_normalized.reshape(tuple(log_magnitude.shape)))
    torch.save(normalized_magnitude, spec_file_name)
  return scaler


# STFT / resampling configuration shared by the rest of the notebook
new_SR, fft_length, hop_length = 44100, 512, 256

# Compute and save the spectrograms for the kick samples
path = r'/content/4000Kicks/*.wav'
kick_scaler = compute_specs(
    path,
    new_SR=new_SR,
    fft_length=fft_length,
    hop_length=hop_length,
)

In [3]:
# VAE dataset class
class AudioDataset(Dataset):
    """Dataset of pre-computed spectrogram tensors stored as ``.pt`` files.

    Args:
        dir: Directory containing the ``.pt`` files. A trailing path separator
            is optional. If `dir` is not an existing directory it is treated as
            a path prefix, i.e. the pattern ``dir + "*.pt"`` is used.
    """

    def __init__(self, dir):
        # `dir` shadows the builtin, but the name is kept so existing callers
        # (positional or keyword) keep working.
        if os.path.isdir(dir):
            pattern = os.path.join(dir, "*.pt")
        else:
            pattern = dir + "*.pt"
        # glob order is filesystem-dependent; sort so that a seeded
        # random_split produces the same partition on every machine.
        self.spec_file_paths = sorted(glob.glob(pattern))

    def __len__(self):
        return len(self.spec_file_paths)

    def __getitem__(self, index):
        # Load the stored tensor and return it as float32
        spec = torch.load(self.spec_file_paths[index]).to(torch.float32)
        return spec

In [None]:
def set_up_dataloaders(batch_size, data, train_split=0.8):
    """Randomly split `data` and wrap both parts in DataLoaders.

    Args:
        batch_size: Batch size used by both loaders.
        data: Dataset to split.
        train_split: Fraction of `data` assigned to the training set; the
            remainder becomes the validation set.

    Returns:
        ``(train_loader, val_loader)``. Only the training loader shuffles.
    """
    # train/validation split sizes
    train_size = int(train_split * len(data))
    test_size = len(data) - train_size

    print(f"Train size: {train_size}")
    print(f"Test size: {test_size}\n")

    train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    # Validation metrics do not depend on batch order, so do not shuffle
    # (this also avoids consuming global RNG state on every evaluation pass).
    val_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

batch_size = 32

# Build the spectrogram dataset and split it into train/validation loaders
data = AudioDataset(dir="/content/4000Kicks/")
kick_train_loader, kick_val_loader = set_up_dataloaders(batch_size=batch_size, data=data)

In [5]:
class audio_VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel spectrograms.

    The encoder is five double-convolution stages (3x3 kernels, no padding)
    with 2x2 max-pooling between them; the flattened result feeds two linear
    heads producing ``mu`` and ``logvar``. The decoder mirrors this with
    transposed convolutions and bilinear upsampling back to the feature-map
    sizes recorded by the encoder.

    The linear layers hard-code an 8 x 3 bottleneck feature map, so inputs
    must reduce to that size (e.g. the notebook's 1 x 257 x 173 spectrograms).

    Args:
        num_channels: Base channel count; deeper stages use ``2 * num_channels``.
        drop_out: Accepted for interface compatibility; no layer uses it.
        latent_size: Dimensionality of the latent vector.
    """

    # Feature-map sizes (height, width) before each pooling step for a
    # 257 x 173 input (fft_length=512, hop_length=256, 44100 samples).
    # Used by decode() until encode() has recorded the sizes of a real input,
    # so generate() works on a freshly constructed or freshly loaded model.
    DEFAULT_UPSAMPLE_SIZES = ((253, 169), (122, 80), (57, 36), (24, 14))

    def __init__(self, num_channels=16, drop_out=0.2, latent_size=4):
        super().__init__()
        # Set base number of channels
        self.num_channels = num_channels

        # Upsampling targets for the decoder; overwritten by every encode() call
        self.desired_size = [list(size) for size in self.DEFAULT_UPSAMPLE_SIZES]

        wide = num_channels*2

        # Encoder conv stages
        self.conv11 = self._double_conv(nn.Conv2d, 1, num_channels, num_channels)
        self.conv12 = self._double_conv(nn.Conv2d, num_channels, num_channels, num_channels)
        self.conv13 = self._double_conv(nn.Conv2d, num_channels, wide, wide)
        self.conv14 = self._double_conv(nn.Conv2d, wide, wide, wide)
        self.conv15 = self._double_conv(nn.Conv2d, wide, wide, wide)

        # Flatten
        self.flatten = nn.Flatten()

        # Encoder fully connected heads: fc1 -> mu, fc2 -> logvar
        self.fc1 = nn.Sequential(
            nn.Linear(8*3*wide, latent_size)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(8*3*wide, latent_size)
        )

        # Decoder fully connected layer
        self.fc3 = nn.Sequential(
            nn.Linear(latent_size, 8*3*wide)
        )

        # Decoder deconv stages
        self.conv21 = self._double_conv(nn.ConvTranspose2d, wide, wide, wide)
        self.conv22 = self._double_conv(nn.ConvTranspose2d, wide, wide, wide)
        self.conv23 = self._double_conv(nn.ConvTranspose2d, wide, wide, num_channels)
        self.conv24 = self._double_conv(nn.ConvTranspose2d, num_channels, num_channels, num_channels)
        # NOTE(review): the ReLU directly before the Sigmoid restricts the
        # output to [0.5, 1) rather than (0, 1). It is kept unchanged because
        # removing it would alter the output of already-trained checkpoints.
        self.conv25 = self._double_conv(
            nn.ConvTranspose2d, num_channels, num_channels, 1,
            nn.Sigmoid(),
        )

        # Initialise weights
        self.apply(self._init_weights)

    @staticmethod
    def _double_conv(conv_cls, in_channels, mid_channels, out_channels, *extra_layers):
        """Build conv -> ReLU -> conv -> ReLU (+ optional trailing layers)."""
        return nn.Sequential(
            conv_cls(in_channels=in_channels, out_channels=mid_channels, kernel_size=(3,3), bias=False),
            nn.ReLU(),
            conv_cls(in_channels=mid_channels, out_channels=out_channels, kernel_size=(3,3), bias=False),
            nn.ReLU(),
            *extra_layers,
        )

    def encode(self, x):
        """Run the conv/pool stack and return the flattened features.

        Side effect: records the feature-map size before each pooling step in
        ``self.desired_size`` so that decode() can upsample back to it.
        """
        pooled_stages = (self.conv11, self.conv12, self.conv13, self.conv14)
        for stage_index, stage in enumerate(pooled_stages):
            x = stage(x)
            self.desired_size[stage_index] = [x.shape[2], x.shape[3]]
            x = F.max_pool2d(x, kernel_size=(2, 2), padding=0)
        x = self.conv15(x)

        x = self.flatten(x)
        return x

    def decode(self, x):
        """Map latent vectors back to spectrograms.

        Applies the linear layer, reshapes to the 8 x 3 bottleneck map, then
        alternates deconv stages with bilinear upsampling to the sizes stored
        in ``self.desired_size`` (deepest stage first).
        """
        x = self.fc3(x)
        x = x.view(-1, self.num_channels*2, 8, 3)
        upsampled_stages = (self.conv21, self.conv22, self.conv23, self.conv24)
        for stage, target_size in zip(upsampled_stages, reversed(self.desired_size)):
            x = stage(x)
            x = F.interpolate(x, size=target_size, mode='bilinear')
        x = self.conv25(x)
        return x

    # Init weights
    # Function from Queen Mary Deep Learning for Audio and Music module
    def _init_weights(self, module):
        """Kaiming-uniform init for Linear/Conv2d weights; zero their biases."""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            torch.nn.init.kaiming_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            torch.nn.init.ones_(module.weight)

    def generate(self, z):
        """Decode latent vectors `z` into spectrograms."""
        reconstructed = self.decode(z)
        return reconstructed

    def reparam(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        z = mu + eps*std
        return z

    def _encode_distribution(self, x):
        """Encode `x` and return the ``(mu, logvar)`` of its latent distribution."""
        features = self.encode(x)
        return self.fc1(features), self.fc2(features)

    def get_latent(self, x):
        """Return a latent sample for `x` (stochastic: uses reparam())."""
        mu, logvar = self._encode_distribution(x)
        z = self.reparam(mu, logvar)
        return z

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            ``(reconstructed, mu, logvar)``.
        """
        mu, logvar = self._encode_distribution(x)

        # Get latent representation
        z = self.reparam(mu, logvar)

        # Decode
        reconstructed = self.decode(z)

        return reconstructed, mu, logvar

In [None]:
# Instantiate the kick VAE on the selected device, together with its optimizer
kick_model = audio_VAE()
kick_model = kick_model.to(device)

kick_optimizer = optim.Adam(params=kick_model.parameters(), lr=1e-3)

# Reconstruction loss: binary cross-entropy summed over all elements
bce_loss = nn.BCELoss(reduction='sum')

In [None]:
# *** NOT ORIGINAL CODE ***
# Original from https://arxiv.org/abs/1312.6114
# Code from https://medium.com/@judyyes10/generate-images-using-variational-autoencoder-vae-4d429d9bdb5
# Kullback-Leibler divergence
def kld_loss(x_pred, x, mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and N(0, I), summed over all elements.

    `x_pred` and `x` are accepted but not used.
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_element.sum()

In [None]:
# Evaluate model Performance
def evaluate(model, data_loader, KLD_multiplier):
    """Compute the mean per-batch loss and KLD of `model` over `data_loader`.

    Relies on the notebook-level `device`, `bce_loss` and `kld_loss`.
    Puts the model in eval mode and runs without gradient tracking.

    Returns:
        ``(epoch_loss, epoch_KLD)``: total loss (BCE + KLD_multiplier * KLD)
        and KLD, each summed over the batches and divided by the number of
        batches.
    """
    model.eval()
    loss_sum = 0
    kld_sum = 0

    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            reconstruction, mu, logvar = model(batch)

            reconstruction_term = bce_loss(reconstruction, batch)
            kld_term = kld_loss(reconstruction, batch, mu, logvar)

            loss_sum += reconstruction_term + KLD_multiplier*kld_term
            kld_sum += kld_term

    # Average over the number of batches
    num_batches = len(data_loader)
    return loss_sum / num_batches, kld_sum / num_batches

# Train Model
def train(model, train_loader, valid_loader, optimizer, num_epochs, saved_model, evaluate_every_n_epochs=1, KLD_multiplier=100):
    """Train `model` and checkpoint the weights with the best validation loss.

    Relies on the notebook-level `device`, `bce_loss`, `kld_loss` and
    `evaluate`. The loss is ``BCE + KLD_multiplier * KLD``.

    Args:
        model: Model returning ``(reconstruction, mu, logvar)``.
        train_loader: Loader yielding training batches.
        valid_loader: Loader yielding validation batches.
        optimizer: Optimizer over the model parameters.
        num_epochs: Number of passes over `train_loader`.
        saved_model: File path the best state dict is written to.
        evaluate_every_n_epochs: Validate (and possibly save) every n epochs.
        KLD_multiplier: Weight applied to the KLD term.

    Returns:
        ``(train_losses, valid_losses, train_KLD_losses, valid_KLD_losses, model)``.
        The lists hold 0-dim CPU tensors; `model` is the model after the final
        epoch, which is not necessarily the best checkpoint.
    """
    # For storing losses
    train_losses = []
    valid_losses = []
    train_KLD_losses = []
    valid_KLD_losses = []

    # Any finite validation loss beats this
    best_valid_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        epoch_KLD = 0
        for x in train_loader:
            # Input to device
            x = x.to(device)

            # Forward pass and losses
            x_pred, mu, logvar = model(x)
            BCE = bce_loss(x_pred, x)
            KLD = kld_loss(x_pred, x, mu, logvar)
            loss = BCE + KLD_multiplier*KLD

            # Back prop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Detach before accumulating so the running totals do not keep
            # every batch's autograd history alive for the whole epoch.
            epoch_loss += loss.detach()
            epoch_KLD += KLD.detach()
        # Epoch metrics (mean over batches)
        epoch_loss /= len(train_loader)
        epoch_KLD /= len(train_loader)

        # Save/print epoch metrics
        train_losses.append(epoch_loss.detach().cpu())
        train_KLD_losses.append(epoch_KLD.detach().cpu())
        print(f'Epoch:{epoch+1}, Training Loss:{epoch_loss:.4f}')
        print(f'Training KLD:{epoch_KLD:.4f}')
        # Evaluate the network on the validation data
        if((epoch+1) % evaluate_every_n_epochs == 0):
            valid_loss, valid_KLD = evaluate(model, valid_loader, KLD_multiplier)

            print(f'Validation loss: {valid_loss:.6f}')
            print()
            valid_losses.append(valid_loss.detach().cpu())
            valid_KLD_losses.append(valid_KLD.detach().cpu())

            # If model is best model for validation set, save model
            if valid_loss<best_valid_loss:
              print("New best model, saving...\n")
              best_valid_loss = valid_loss
              # Honour the caller-supplied checkpoint path
              torch.save(model.state_dict(), saved_model)

    return train_losses, valid_losses, train_KLD_losses, valid_KLD_losses, model

num_epochs = 20

# Train the kick VAE. The returned model is the same object as `kick_model`;
# it is bound to `snare_model` because later cells refer to it by that name.
train_losses, valid_losses, train_KLD_losses, valid_KLD_losses, snare_model = train(
    model=kick_model,
    train_loader=kick_train_loader,
    valid_loader=kick_val_loader,
    optimizer=kick_optimizer,
    num_epochs=num_epochs,
    saved_model='best_model.pkl',
)

In [None]:
# Total loss per epoch on the training and validation sets
plt.plot(range(num_epochs), train_losses, 'dodgerblue', label='training')
plt.plot(range(num_epochs), valid_losses, 'orange', label='validation')
plt.gca().set(
    xlim=(0, num_epochs),
    xlabel='Epoch',
    ylabel='Loss',
    title='Loss on Training/Validation Set',
)
plt.legend();

In [None]:
# KL divergence per epoch on the training and validation sets
plt.plot(range(num_epochs), train_KLD_losses, 'dodgerblue', label='training')
plt.plot(range(num_epochs), valid_KLD_losses, 'orange', label='validation')
plt.gca().set(
    xlim=(0, num_epochs),
    xlabel='Epoch',
    ylabel='KLD',
    title='KL Divergence on Training/Validation Set',
)
plt.legend();

In [None]:
# *** ADAPTED FROM https://medium.com/@judyyes10/generate-images-using-variational-autoencoder-vae-4d429d9bdb5

# Visualise input vs output
def compare_images(model, sample_images):
    """Show `sample_images` and the model's reconstructions in one image grid.

    The inputs fill the first part of the grid and the reconstructions follow,
    eight images per row.
    """
    # Visualisation only: no need to build an autograd graph
    with torch.no_grad():
        reconstructed_images = model(sample_images)[0]
    comparison = torch.cat([sample_images, reconstructed_images])
    comparison_image = make_grid(comparison.detach().cpu(), nrow=8)
    plt.figure(figsize=(5, 5))
    # make_grid returns channels-first; imshow expects channels-last
    plt.imshow(comparison_image.permute(1, 2, 0))
    plt.show()

# Take the first training batch as the sample to visualise
x = next(iter(kick_train_loader))
sample_images = x.to(device)

compare_images(snare_model, sample_images)

## Save/Load Audio Model

In [6]:
# File path of the saved kick model state dict (despite the `_dir` suffix,
# this is a file, not a directory)
kick_model_dir = "/content/Kick_VAE_Model_4.pkl"

In [None]:
# Save model
#torch.save(snare_model.state_dict(), kick_model_dir)

In [None]:
# Load the saved kick model

kick_model = audio_VAE()
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime
kick_model.load_state_dict(torch.load(kick_model_dir, map_location=device))

kick_model = kick_model.to(device)

kick_model.eval()

In [None]:
# Find the per-dimension latent range over the training set
def get_latent_space_scale(model, data_loader, latent_size):
  """Return the per-dimension minimum and maximum latent values over `data_loader`.

  Each batch is passed through ``model.get_latent``; the first `latent_size`
  latent dimensions are tracked. Note that the notebook's ``get_latent``
  samples from the latent distribution, so the result is stochastic.

  Args:
    model: Module exposing ``get_latent(batch)``.
    data_loader: Iterable of input batches.
    latent_size: Number of latent dimensions to track.

  Returns:
    ``[minima, maxima]``, two lists of `latent_size` floats.

  Raises:
    ValueError: If `data_loader` yields no batches.
  """
  # Run inputs on whatever device the model lives on (no reliance on a global)
  first_param = next(model.parameters(), None)

  # Start from +/-inf so the true extremes are found even when every value of
  # a dimension has the same sign (zero-initialised bounds would clamp at 0).
  lowest_min = torch.full((latent_size,), float('inf'))
  highest_max = torch.full((latent_size,), float('-inf'))
  seen_batch = False

  with torch.no_grad():
    for pattern in data_loader:
      if first_param is not None:
        pattern = pattern.to(first_param.device)
      latent_space = model.get_latent(pattern).cpu()

      # Per-dimension extremes of this batch, merged into the running bounds
      batch_min = latent_space.min(dim=0).values[:latent_size]
      batch_max = latent_space.max(dim=0).values[:latent_size]
      lowest_min = torch.minimum(lowest_min, batch_min)
      highest_max = torch.maximum(highest_max, batch_max)
      seen_batch = True

  if not seen_batch:
    raise ValueError("data_loader produced no batches; cannot compute latent-space boundaries")

  return [lowest_min.tolist(), highest_max.tolist()]

# Latent-space boundaries for the kick model (latent_size=4 is the audio_VAE default)
kick_boundaries = get_latent_space_scale(
    model=kick_model,
    data_loader=kick_train_loader,
    latent_size=4,
)

In [None]:
# Apply boundaries to scaler

# Row 0 of kick_boundaries holds the minima and row 1 the maxima, so fitting
# on them maps the training-set latent range onto [0, 1] per dimension.
kick_latent_scaler = MinMaxScaler(feature_range=(0, 1))
kick_latent_scaler.fit(kick_boundaries)

In [None]:
# Convert prediction to waveform
def pred2wave(x_pred, scaler, n_iter=100):
  """Convert a predicted normalised log-magnitude spectrogram into a waveform.

  Undoes the min-max normalisation with `scaler`, undoes the log, and
  estimates the phase with Griffin-Lim. Uses the notebook-level `hop_length`.

  Args:
    x_pred: Model output; the leading dimension is squeezed if it has size 1.
    scaler: Fitted scaler whose ``inverse_transform`` undoes the normalisation.
    n_iter: Number of Griffin-Lim iterations.

  Returns:
    The waveform returned by ``librosa.griffinlim``.
  """
  # Drop the leading dimension, detach from autograd and move to the CPU
  spec = x_pred.squeeze(0).detach().cpu()

  # Invert the amplitude normalisation (the scaler works on a single column)
  flat_log_mag = scaler.inverse_transform(spec.view(-1, 1))
  log_mag = torch.from_numpy(flat_log_mag.reshape(spec.shape))

  # Invert the log to recover linear magnitudes
  mag = torch.exp(log_mag).numpy()

  # Griffin-Lim estimates the missing phase
  return librosa.griffinlim(mag, n_iter=n_iter, window='hann', hop_length=hop_length)

## UI

In [None]:
#@title Kick Generator
sr = 44100
#@markdown Latent Space:\
#@markdown _0 and 1 are the minimum and maximum values from the training set._
#@markdown _Values outside of this range are out of distribution._

dim_1 = 0.59 #@param {type:'slider', min:-0.5, max:1.5, step:0.01}
dim_2 = 0.12 #@param {type:'slider', min:-0.5, max:1.5, step:0.01}
dim_3 = 0.28 #@param {type:'slider', min:-0.5, max:1.5, step:0.01}
dim_4 = 0.3 #@param {type:'slider', min:-0.5, max:1.5, step:0.01}
#@markdown Filename:\
#@markdown _Do not include extension._
file_name = "kick" # @param {type:"string"}

kick_path = "/content/DrumSound/"+file_name+".wav"

# Map the 0-1 slider values back onto the model's latent range
latent_space = torch.tensor(
    kick_latent_scaler.inverse_transform([[dim_1, dim_2, dim_3, dim_4]]),
    dtype=torch.float32,
)
# Decode the latent vector; inference only, so skip gradient tracking
with torch.no_grad():
    x_pred = kick_model.generate(latent_space.to(device))

# Convert to audio, play it, then peak-normalise and save it
waveform = pred2wave(x_pred, kick_scaler)
display(Audio(waveform, rate=sr))
waveform = librosa.util.normalize(np.ravel(waveform))
# Make sure the output directory exists before writing
os.makedirs(os.path.dirname(kick_path), exist_ok=True)
sf.write(kick_path, waveform, sr, subtype='PCM_24')