<a href="https://colab.research.google.com/github/Ahtesham519/Genrative_Deep_learning_v2_2023/blob/main/Transformer_utils.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import pickle as pkl
import music21
import keras
import tensorflow as tf

from fractions import Fraction

def _element_tokens(element):
  """Return the (note_name, duration_name) token pair for one score element.

  Key signatures, time signatures, chords, rests and notes are encoded; any
  other element type yields ``(None, None)`` so the caller can skip it.
  """
  if isinstance(element, music21.key.Key):
    return f"{element.tonic.name}:{element.mode}", "0.0"
  if isinstance(element, music21.meter.TimeSignature):
    return f"{element.ratioString}TS", "0.0"
  if isinstance(element, music21.chord.Chord):
    # Only the last entry of ``pitches`` is kept, so a chord becomes a single
    # note token.
    return (
        element.pitches[-1].nameWithOctave,
        str(element.duration.quarterLength),
    )
  if isinstance(element, music21.note.Rest):
    return str(element.name), str(element.duration.quarterLength)
  if isinstance(element, music21.note.Note):
    return str(element.nameWithOctave), str(element.duration.quarterLength)
  return None, None


def parse_midi_files(file_list, parser, seq_len, parsed_data_path=None):
  """Parse MIDI files into fixed-length note/duration token sequences.

  Every file is parsed with ``parser.parse(file)`` and chordified.  It then
  contributes a ``"START"`` token (duration ``"0.0"``) followed by one token
  per key signature, time signature, chord, rest and note in its flattened
  score.  The token streams of all files are concatenated and cut into
  overlapping windows of ``seq_len`` consecutive tokens (stride 1).

  Args:
    file_list: Iterable of file paths handed to ``parser.parse``.
    parser: Object whose ``parse(path)`` returns a music21 stream
      (presumably ``music21.converter``).
    seq_len: Number of tokens per output sequence.
    parsed_data_path: Optional directory.  When given, the sequences are
      pickled into files named ``notes`` and ``durations`` inside it.

  Returns:
    ``(notes_list, duration_list)``: two index-aligned lists of
    space-separated token strings.  Both are empty when fewer than
    ``seq_len`` tokens were collected.
  """
  notes = []
  durations = []

  for i, path in enumerate(file_list):
    print(i + 1, "Parsing %s" % path)
    score = parser.parse(path).chordify()

    notes.append("START")
    durations.append("0.0")

    for element in score.flat:
      note_name, duration_name = _element_tokens(element)
      if note_name and duration_name:
        notes.append(note_name)
        durations.append(duration_name)
    print(f"{len(notes)} notes parsed")

  print(f"Building sequences of length {seq_len}")
  # The "+ 1" keeps the last full window notes[len(notes) - seq_len:].
  n_windows = len(notes) - seq_len + 1
  notes_list = [" ".join(notes[i : i + seq_len]) for i in range(n_windows)]
  duration_list = [
      " ".join(durations[i : i + seq_len]) for i in range(n_windows)
  ]

  if parsed_data_path:
    with open(os.path.join(parsed_data_path, "notes"), "wb") as f:
      pkl.dump(notes_list, f)
    with open(os.path.join(parsed_data_path, "durations"), "wb") as f:
      pkl.dump(duration_list, f)

  return notes_list, duration_list

def load_parsed_files(parsed_data_path):
  """Load the pickled note and duration sequences from ``parsed_data_path``.

  Reads the files ``notes`` and ``durations`` (in that order) from the given
  directory.

  Returns:
    ``(notes, durations)`` exactly as they were pickled.

  Raises:
    FileNotFoundError: if either file is missing.
  """
  # NOTE: pickle.load can execute arbitrary code; only point this at
  # directories whose contents you trust.
  loaded = []
  for filename in ("notes", "durations"):
    with open(os.path.join(parsed_data_path, filename), "rb") as handle:
      loaded.append(pkl.load(handle))
  notes, durations = loaded
  return notes, durations


def _apply_duration_and_instrument(music_obj, quarter_length):
  """Give ``music_obj`` a Duration of ``quarter_length`` and a cello instrument."""
  music_obj.duration = music21.duration.Duration(quarter_length)
  music_obj.storedInstrument = music21.instrument.Violoncello()
  return music_obj


def get_midi_note(sample_note, sample_duration):
  """Convert a (note token, duration token) pair into a music21 object.

  Token formats follow those emitted by ``parse_midi_files``:

  * ``"<ratio>TS"``          -> ``music21.meter.TimeSignature``
  * ``"<tonic>:<mode>"``     -> ``music21.key.Key``
  * ``"rest"``               -> ``music21.note.Rest``
  * ``"A4.C5.E5"`` (dotted)  -> ``music21.chord.Chord`` of those pitches
  * ``"START"``              -> ``None``
  * anything else            -> ``music21.note.Note``

  Args:
    sample_note: Note token.
    sample_duration: Quarter-length token parsable by ``fractions.Fraction``
      (e.g. ``"0.5"`` or ``"1/3"``).  It is not parsed for time-signature,
      key and ``"START"`` tokens.

  Returns:
    The reconstructed music21 object, or ``None`` for ``"START"``.
  """
  if "TS" in sample_note:
    return music21.meter.TimeSignature(sample_note.split("TS")[0])

  # Key tokens are "<tonic>:<mode>".  Matching on the separator rather than on
  # "major"/"minor" also round-trips any other mode the tokeniser emits.
  if ":" in sample_note:
    tonic, mode = sample_note.split(":", 1)
    return music21.key.Key(tonic, mode)

  if sample_note == "START":
    return None

  quarter_length = float(Fraction(sample_duration))

  if sample_note == "rest":
    return _apply_duration_and_instrument(music21.note.Rest(), quarter_length)

  if "." in sample_note:
    chord_notes = [
        _apply_duration_and_instrument(music21.note.Note(pitch), quarter_length)
        for pitch in sample_note.split(".")
    ]
    return music21.chord.Chord(chord_notes)

  return _apply_duration_and_instrument(
      music21.note.Note(sample_note), quarter_length
  )

class SinePositionEncoding(keras.layers.Layer):
  """Sinusoidal positional-encoding layer.

  Returns a tensor with the same shape as its input that holds fixed sine and
  cosine position encodings.  Only the *shape* of the input is used.  Axis -2
  of the input is treated as the sequence axis and axis -1 as the feature
  axis.  The ``(seq_length, hidden_size)`` encoding is broadcast over any
  leading axes.

  For position ``p`` and feature index ``i`` the angle is
  ``p * (1 / max_wavelength) ** (2 * (i // 2) / hidden_size)``.  Even feature
  indices hold ``sin(angle)`` and odd indices hold ``cos(angle)``.

  Args:
    max_wavelength: Sets the lowest frequency.  Per-feature frequencies
      decrease geometrically from 1 towards ``1 / max_wavelength``.
    **kwargs: Forwarded to ``keras.layers.Layer``.
  """

  def __init__(
      self,
      max_wavelength=10000,
      **kwargs,
  ):
    # The parameter used to be misspelled ``max_wavelenth``, which made every
    # construction fail with a NameError.  Keep accepting that keyword.
    if "max_wavelenth" in kwargs:
      max_wavelength = kwargs.pop("max_wavelenth")
    super().__init__(**kwargs)
    self.max_wavelength = max_wavelength

  def call(self, inputs, start_index=0):
    """Return position encodings broadcast to the shape of ``inputs``.

    Args:
      inputs: Tensor of rank >= 2.  Axis -2 is the sequence axis and axis -1
        is the feature axis.
      start_index: Integer offset added to every position (default 0).  Use
        it when ``inputs`` is a slice that starts partway into a longer
        sequence.
    """
    input_shape = tf.shape(inputs)
    seq_length = input_shape[-2]
    hidden_size = input_shape[-1]
    position = tf.cast(tf.range(seq_length) + start_index, self.compute_dtype)
    min_freq = tf.cast(1 / self.max_wavelength, dtype=self.compute_dtype)
    # Feature pairs (2k, 2k + 1) share the frequency min_freq ** (2k / hidden).
    timescales = tf.pow(
        min_freq,
        tf.cast(2 * (tf.range(hidden_size) // 2), self.compute_dtype)
        / tf.cast(hidden_size, self.compute_dtype),
    )
    # (seq_length, 1) * (1, hidden_size) -> (seq_length, hidden_size)
    angles = tf.expand_dims(position, 1) * tf.expand_dims(timescales, 0)
    # 1.0 on odd feature indices (cosine), 0.0 on even ones (sine).
    cos_mask = tf.cast(tf.range(hidden_size) % 2, self.compute_dtype)
    sin_mask = 1 - cos_mask

    positional_encodings = (
        tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask
    )

    return tf.broadcast_to(positional_encodings, input_shape)

  def get_config(self):
    """Return the layer config, including ``max_wavelength``."""
    config = super().get_config()
    config.update(
        {
            "max_wavelength": self.max_wavelength,
        }
    )
    return config