# Handling the Dataset

In [None]:
import pretty_midi
import pypianoroll
import numpy as np
import os

# Path to a sample MIDI file (not referenced elsewhere in the visible part of this notebook).
midi_example = "data/10,000_Maniacs/A_Campfire_Song.mid"

## Function Definitions for Processing MIDI Data

In [None]:
def get_tracks_with_most_notes(pretty_midi_file, top_n=4, ignore_instruments=("Drums",)):
  '''
  Picks the instruments that contain the most notes.

  Inputs:
    pretty_midi_file (PrettyMIDI object) - Only its `instruments` list is read; each
      instrument must expose `name` and `notes`.
    top_n (int) - This tells the function how many instruments should be included. The result
      list is sorted from max note count to least note count.
    ignore_instruments (collection of str) - Instruments whose `name` is in this collection
      are skipped entirely.

  Output:
    list of instruments. When fewer than top_n instruments are eligible, the last
    (least busy) one is repeated until the list has top_n entries. The repeats are the
    same object, not copies. When no instrument is eligible the list is empty.
  '''
  candidates = [
      instrument for instrument in pretty_midi_file.instruments
      if instrument.name not in ignore_instruments
  ]
  # A single stable sort keeps the file order among instruments with equal note counts.
  candidates.sort(key=lambda instrument: len(instrument.notes), reverse=True)

  top_instruments_list = candidates[:top_n]

  # If we do not have up to top_n tracks, let's just duplicate the last one.
  missing_count = top_n - len(top_instruments_list)
  if top_instruments_list and missing_count > 0:
    top_instruments_list.extend([top_instruments_list[-1]] * missing_count)

  return top_instruments_list


def get_highest_notes(instrument_track):
  '''
  Groups the track's notes by start time and keeps the highest-pitched one per start.

  Input(s):
    instrument_track: object whose `notes` is an iterable of notes with `start` and `pitch`.

  Output(s):
    dict mapping each note start to a [note, pitch] pair. When several notes share a start
    and the top pitch, the one that appears first in `notes` is kept.
  '''
  best_by_start = {}
  for candidate in instrument_track.notes:
    current_best = best_by_start.get(candidate.start)
    # Only a strictly higher pitch replaces the note already stored for this start.
    if current_best is None or candidate.pitch > current_best[1]:
      best_by_start[candidate.start] = [candidate, candidate.pitch]
  return best_by_start


def filter_for_highest_notes_only(instrument_track):
  '''
  Builds a copy of an instrument that keeps a single note per start time.

  Input(s):
    instrument_track: PrettyMIDI Instrument object with notes in notes attribute.

  Output(s):
    instrument: new PrettyMIDI Instrument with the same program, is_drum and name, holding
      only the highest note for each start time, ordered by start time.
  '''
  notes_by_start = get_highest_notes(instrument_track)

  # Each dict value is a [note, pitch] pair; keep the note itself, in start-time order.
  kept_notes = [notes_by_start[start][0] for start in sorted(notes_by_start)]

  filtered = pretty_midi.Instrument(
      instrument_track.program,
      instrument_track.is_drum,
      instrument_track.name,
  )
  filtered.notes = kept_notes
  return filtered


def convert_rolls_to_multitrack(piano_rolls, instrument_rolls, output_filepath):
    '''
    Builds a pypianoroll Multitrack from piano/instrument rolls and writes it to a file.

    One track is created per row of instrument_rolls; piano_rolls[i] supplies that
    track's piano roll.

    Inputs:
        piano_rolls: np array shape (4, 800, 128) *One instance of data*
        instrument_rolls: np array shape (4, 128) *One instance of data*
        output_filepath: str, name/path of file to output.

    Output:
        The pypianoroll Multitrack that was written (resolution=4, no drum tracks).
    '''
    # First, we create track objects.
    track_list = []
    for i, instrument_roll in enumerate(instrument_rolls):
        track = pypianoroll.BinaryTrack()

        # get_onehotencoding_instrument stores program p at index p - 1, so program 0
        # wraps around to index 127. Undo that shift modulo 128: index 127 becomes
        # program 0 rather than 128, which is outside the MIDI program range 0-127.
        program = int((np.argmax(instrument_roll) + 1) % 128)

        track.program = program
        track.is_drum = False
        track.pianoroll = piano_rolls[i]

        track_list.append(track)

    multitrack = pypianoroll.Multitrack(resolution=4, tracks=track_list)
    pypianoroll.write(output_filepath, multitrack)
    return multitrack


def get_onehotencoding_instrument(track):
  '''
  Encodes a track's MIDI program as a one-hot vector of length 128.

  Input(s):
    track: object with an integer `program` attribute.

  Output(s):
    float np array of shape (128,) with a single 1.0 at index `program - 1`.
    Program 0 gives index -1, which numpy resolves to the last slot (127).
  '''
  instrument_count = 128 # There are 128 possible instruments according to https://fmslogo.sourceforge.io/manual/midi-instrument.html
  onehot = np.zeros(instrument_count)
  onehot[track.program - 1] = 1.0
  return onehot

def convert_multitrack_to_rolls(multitrack):
  '''
  Converts a multitrack into stacked piano rolls and instrument rolls.

  Input(s):
    multitrack: object whose `tracks` each expose `pianoroll` and `program`.

  Output(s):
    pianoroll_arr: piano rolls of shape (4, t, 128), where t is the number of timesteps.
    instrumentroll_arr: instrument rolls of shape (4, 128).
    (The leading 4 assumes the multitrack holds four tracks; it equals the track count.)
  '''
  tracks = multitrack.tracks

  # Each piano roll is (t, 128); give it a leading axis of 1 so the rolls stack along axis 0.
  pianoroll_arr = np.vstack([np.expand_dims(track.pianoroll, axis=0) for track in tracks])

  # One one-hot instrument encoding per track, stacked row by row.
  instrumentroll_arr = np.vstack([get_onehotencoding_instrument(track) for track in tracks])

  return pianoroll_arr, instrumentroll_arr

def parse_rolls_from_midi_filepath(filepath):
  '''
  Loads a MIDI file and converts its busiest instruments to piano/instrument rolls.

  Input:
    filepath: The relative/absolute filepath to the Midi file.
  Outputs:
    piano_roll: The piano roll of the top 4 instruments of the track. Shape (4, t, 128). Where t is the number of timesteps.
      Note: 128 is the number of possible pitches.
    instrument_roll: The instrument rolls of the top 4 instruments of the track. Shape (4, 128).
      Note: 128 is different from above, here it is the number of possible instruments in MIDI.

    Both outputs are None when the file cannot be parsed or converted; the failing path
    and the error are printed instead of raising.
  '''
  try:
    midi_file = pretty_midi.PrettyMIDI(filepath)

    # Gets the top 4 (default) instruments with most notes.
    busiest_instruments = get_tracks_with_most_notes(midi_file)

    # Gets the highest note in each instrument (only 1 note per instrument per timestep).
    midi_file.instruments = [
        filter_for_highest_notes_only(instrument) for instrument in busiest_instruments
    ]

    # Convert from PrettyMIDI object to pypianoroll Multitrack object to easily get the piano roll.
    # NOTE(review): resolution is 6 here while convert_rolls_to_multitrack writes with
    # resolution=4 - confirm the mismatch is intended.
    multitrack = pypianoroll.from_pretty_midi(midi_file, resolution = 6)

    return convert_multitrack_to_rolls(multitrack)
  except Exception as error:
    # Best effort: one unreadable file should not stop a batch conversion. Catching
    # Exception (not a bare except) lets KeyboardInterrupt / SystemExit propagate.
    print("Error: ", filepath, repr(error))
    return None, None

## Creating a PyTorch Dataset

In [None]:
import numpy as np
import torch
import torch.nn as nn

In [None]:
batch_size = 10 # Samples per batch; read by the DataLoader and stored on HT_VAE.
data_size = 1000 # Hardcoded to only train the model with 1000 data points.

In [None]:
class MyDataset(torch.utils.data.Dataset):
  '''
  Dataset of (prev_x, data_x) pairs loaded from two .npy files as float32 arrays.

  Inputs:
    data_x_path (str) - Path of the .npy file with the targets.
    prev_x_path (str) - Path of the .npy file with the inputs, aligned row by row
      with data_x_path.
  '''
  def __init__(self,
               data_x_path="hooktheory-npy-data/data_x_augmented.npy",
               prev_x_path="hooktheory-npy-data/prev_x_augmented.npy"):
    self.data_x = np.load(data_x_path).astype(np.float32)
    self.prev_x = np.load(prev_x_path).astype(np.float32)

  def __len__(self):
    # Amount of data we use: capped by the module-level data_size, but never more than
    # what was actually loaded, so every index below the length is valid for both arrays.
    return min(data_size, len(self.data_x), len(self.prev_x))

  def __getitem__(self, idx):
    # Returned as (input, target), the order the training loop unpacks.
    return self.prev_x[idx], self.data_x[idx]

In [None]:
# Build the dataset and wrap it in a shuffling DataLoader that yields batches of batch_size items.
dataset = MyDataset()
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

## Creating a Variational Autoencoder via PyTorch

In [None]:
# Parameters
EPOCHS = 20 # Number of passes over the training loader.
latent_size = 8 # Width of the mu / log_var outputs of HT_VAE.
data_size = 1000 # NOTE(review): re-assigns the value already set earlier in this notebook.
kld_factor = 0.001 # Weight applied to the KL term in vae_loss.
kld_mod = -1 # NOTE(review): not referenced in the visible code - confirm whether it is still needed.
learning_rate = 0.001 # Learning rate passed to the Adam optimizer.

In [None]:
class HT_VAE(nn.Module):
    '''
    Fully connected variational autoencoder over inputs of 16 x 128 values.

    The encoder maps a flattened input (2048 values) through 1024 -> 512 -> 128 units to
    `mu` and `log_var` of width latent_size; the decoder maps a latent sample through
    128 -> 512 -> 2048 units and reshapes to [batch, 16, 128] with sigmoid outputs.

    Reads the module-level `batch_size` and `latent_size` at construction time.
    '''
    def __init__(self):
        super().__init__()
        # Kept for backward compatibility; decode() no longer depends on it.
        self.batch_size = batch_size
        self.latent_size = latent_size
        # Encoding #
        self.enc_fc1 = nn.Linear(2048, 1024)
        self.enc_bn1 = nn.BatchNorm1d(1024)
        self.enc_fc2 = nn.Linear(1024, 512)
        self.enc_bn2 = nn.BatchNorm1d(512)
        self.enc_fc3 = nn.Linear(512, 128)
        self.enc_bn3 = nn.BatchNorm1d(128)

        self.enc_mu = nn.Linear(128, self.latent_size)
        self.enc_logvar = nn.Linear(128, self.latent_size)

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # Decoding #
        self.dec_fc1 = nn.Linear(self.latent_size, 128)
        self.dec_bn1 = nn.BatchNorm1d(128)
        self.dec_fc2 = nn.Linear(128, 512)
        self.dec_bn2 = nn.BatchNorm1d(512)
        self.dec_fc3 = nn.Linear(512, 2048)
        # Not applied in decode(); kept so the set of registered parameters/buffers
        # (and therefore saved state_dict keys) stays the same.
        self.dec_bn3 = nn.BatchNorm1d(2048)

        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=2)  # Not used by forward(); kept as an attribute.

        self.dropout = nn.Dropout(p=0.2)

        self.initialize_weights()

    def initialize_weights(self):
        '''
        Applies Xavier-uniform init to weights and zeros to biases.

        Covers Linear/Conv2d layers and GRU parameters; other modules (e.g. BatchNorm)
        keep their default initialization.
        '''
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0.0)
            elif isinstance(layer, nn.GRU):
                for name, param in layer.named_parameters():
                    if 'weight' in name:
                        nn.init.xavier_uniform_(param)
                    elif 'bias' in name:
                        nn.init.constant_(param, 0.0)

    def encode(self, prev_x):
        '''
        Maps an input batch to the parameters of the approximate posterior.

        Input: prev_x of shape [batch, 16, 128] (anything that flattens to 2048 per sample).
        Outputs: (mu, log_var), each of shape [batch, latent_size].
        '''
        data = torch.flatten(prev_x, start_dim=1) # [b, 2048]

        # Three FC -> BatchNorm -> Tanh stages; dropout only after the second one.
        h0 = self.enc_fc1(data) # from [b, 2048] to [b, 1024]
        h0 = self.enc_bn1(h0)
        h0 = self.tanh(h0)

        h1 = self.enc_fc2(h0) # from [b, 1024] to [b, 512]
        h1 = self.enc_bn2(h1)
        h1 = self.tanh(h1)
        h1 = self.dropout(h1)

        h2 = self.enc_fc3(h1) # from [b, 512] to [b, 128]
        h2 = self.enc_bn3(h2)
        h2 = self.tanh(h2)

        # Parallel mu and logvar generation, each followed by RELU.
        # NOTE(review): the ReLU forces mu >= 0, and ReLU + 1 forces log_var >= 1, so the
        # posterior can never match the N(0, 1) prior that the KL loss pulls towards.
        # Left unchanged because it alters training behavior - confirm this is intended.
        mu = self.enc_mu(h2) # from [b, 128] to [b, latent_size]
        mu = self.relu(mu)

        log_var = self.enc_logvar(h2) # from [b, 128] to [b, latent_size]
        log_var = self.relu(log_var)
        log_var = log_var + 1

        return mu, log_var

    def reparameterize(self, mu, log_var):
        '''Samples z = mu + eps * std with eps ~ N(0, 1), keeping the sample differentiable.'''
        std = torch.exp(0.5 * log_var)  # Standard deviation
        eps = torch.randn_like(std)     # Sample epsilon from standard normal distribution
        z = mu + eps * std              # Reparameterization trick
        return z

    def decode(self, z):
        '''
        Maps latent samples back to the data space.

        Input: z of shape [batch, latent_size].
        Output: tensor of shape [batch, 16, 128] with values in [0, 1].
        '''
        h0 = self.dec_fc1(z)
        h0 = self.dec_bn1(h0)
        h0 = self.relu(h0)
        h0 = self.dropout(h0)

        h1 = self.dec_fc2(h0)
        h1 = self.dec_bn2(h1)
        h1 = self.relu(h1)

        h2 = self.dec_fc3(h1)
        # Take the batch dimension from the data rather than the configured batch_size,
        # so batches of any size can be decoded.
        h2 = h2.view(h2.size(0), 16, 128)
        output = self.sigmoid(h2)
        return output

    def forward(self, prev_x):
        '''Encodes prev_x, samples a latent vector and decodes it; returns (output, mu, log_var).'''
        mu, log_var = self.encode(prev_x)

        z = self.reparameterize(mu, log_var)

        generated_bar = self.decode(z)
        return generated_bar, mu, log_var

In [None]:
# Instantiate the VAE; its constructor reads the module-level batch_size and latent_size.
model = HT_VAE()

## Training Loop
### Loss Function

In [None]:
def vae_gaussian_kl_loss(mu, log_var):
  '''
  KL divergence between N(mu, exp(log_var)) and the standard normal N(0, 1).

  The per-element terms are summed over dim 1, then averaged over the remaining entries.
  '''
  per_element = 1 + log_var - mu.pow(2) - log_var.exp()
  per_sample = -0.5 * per_element.sum(dim=1)
  return per_sample.mean()

def reconstruction_loss(x_reconstructed, x):
  '''
  Binary cross-entropy between the reconstruction and the target, averaged over all elements.

  x_reconstructed must hold probabilities in [0, 1] and have the same shape as x.
  '''
  return nn.functional.binary_cross_entropy(x_reconstructed, x, reduction="mean")

def vae_loss(y_pred, y_true, epoch):
  '''
  Total VAE loss: reconstruction loss plus the KL term scaled by kld_factor.

  Inputs:
    y_pred: tuple ordered as (mu, log_var, reconstructed_x).
    y_true: target tensor compared against reconstructed_x.
    epoch: accepted but not used by the computation.
  '''
  mu, log_var, reconstructed_x = y_pred
  recon_term = reconstruction_loss(reconstructed_x, y_true)
  kl_term = kld_factor * vae_gaussian_kl_loss(mu, log_var)
  return recon_term + kl_term

In [None]:
# Adam optimizer over all of the model's parameters, using the configured learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Alias read by train_one_epoch.
training_loader = train_dataloader

In [None]:
def train_one_epoch(epoch_index, tb_writer):
  '''
  Runs one optimization pass over the module-level training_loader.

  For every batch: resets gradients, runs the model, computes vae_loss, prints it,
  backpropagates and steps the module-level optimizer.

  Inputs:
    epoch_index (int) - Used in the printed progress line and forwarded to vae_loss.
    tb_writer - Accepted for interface compatibility; not used.

  Output:
    float - Mean of the per-batch losses, or 0.0 when the loader yields no batches.
  '''
  running_loss = 0.
  batch_count = 0
  for i, (prev_x, current_x) in enumerate(training_loader):
    optimizer.zero_grad()

    generated_bar, mu, log_var = model(prev_x)

    # vae_loss expects its predictions ordered as (mu, log_var, reconstruction).
    loss = vae_loss((mu, log_var, generated_bar), current_x, epoch_index)
    batch_loss = loss.item()
    print(f"Epoch {epoch_index} Loss {i}: {batch_loss}")
    loss.backward()
    optimizer.step()

    running_loss += batch_loss
    batch_count += 1

  return running_loss / batch_count if batch_count else 0.0

In [None]:
# Train for EPOCHS passes over the training data.
for epoch in range(EPOCHS):
  model.train(True) # Training mode (affects the model's Dropout and BatchNorm layers).
  train_one_epoch(epoch, None) # No writer is supplied for the tb_writer argument.

In [None]:
def convert_output_to_pr(output, mean_probability=False, input_probability=0.1, tensor_output=False):
    '''
    Converts model output into a one-hot piano roll, one active pitch per timestep.

    Input:
        output: Tensor of shape [instrument_size, timestep, pitch_range] (or an
            iterable of [timestep, pitch_range] tensors).
        mean_probability, input_probability: accepted for interface compatibility;
            not used by the conversion.
        tensor_output: when True, return a torch tensor instead of a numpy array.

    Output:
        Integer array of shape [instrument_size, timestep, 128] where, for every
        timestep, only the pitch with the largest value is set to 1. The timestep
        axis is kept even when a track has a single timestep.
    '''
    pitch_onehots = np.eye(128)

    track_outputs = []
    for track in output:
        # Index of the strongest pitch at each timestep, shape (timestep,).
        pitch_indices = track.max(dim=1).indices.detach().cpu().numpy()
        # Row lookup turns the indices into one-hot rows, shape (timestep, 128).
        track_outputs.append(pitch_onehots[pitch_indices])

    pr_output = np.stack(track_outputs).astype(int)

    return torch.from_numpy(pr_output) if tensor_output else pr_output
            
def save_output_to_mid(output, filename, instrument=np.eye(128)[33].astype(int)[None, :]):
    '''
    Converts model output to a one-hot piano roll and writes it to `filename`.

    Input:
        output: model output accepted by convert_output_to_pr.
        filename: path handed to convert_rolls_to_multitrack as the output file.
        instrument: instrument rolls, one one-hot row per track to write. The default
            is a single row with index 33 set, so only the first track is written.
    '''
    piano_rolls = convert_output_to_pr(output, mean_probability=True)
    convert_rolls_to_multitrack(piano_rolls, instrument, filename)
    