In [None]:
!pip install mido
import tensorflow as tf
import numpy as np
import mido

In [None]:
# MIDI ticks folded into one time step of the piano-roll encoding.
resolution = 24

# One-character symbols used by the token encodings.
notes = 'abcdefghijkl'    # indexed by note % 12
# Indexed by note // 12. MIDI notes run 0..127, so the quotient runs 0..10 and
# eleven symbols are required; with only nine, any note >= 108 had no symbol
# and the lookup raised IndexError.
octaves = 'ABCDEFGHIJK'
shifts = 'abcdefghijkl'   # indexed by an interval of 0..11 semitones

# Pitch window covered by one piano-roll frame: min_note <= note < max_note.
min_note = 12
max_note = 96
notes_range = range(min_note, max_note)

def tick_based_encoding(song, track_index=2, ticks_per_step=None, note_range=None):
  """Encode one MIDI track as a piano roll of note onsets.

  Args:
    song: object with a `tracks` sequence; each message exposes `time`
      (delta ticks), `type` and, for note messages, `note`.
    track_index: which track of `song.tracks` to encode (default 2).
    ticks_per_step: ticks folded into one frame; defaults to the module-level
      `resolution`.
    note_range: `range` of pitches covered by a frame; defaults to the
      module-level `notes_range`.

  Returns:
    A list of frames. Each frame is a list of `len(note_range)` floats where
    column `note - note_range.start` is 1.0 if that note starts in the frame.
    Notes outside `note_range` are ignored.
  """
  if ticks_per_step is None:
    ticks_per_step = resolution
  if note_range is None:
    note_range = notes_range
  lowest = note_range.start
  width = len(note_range)

  data = []
  abs_ticks = 0
  for msg in song.tracks[track_index]:
    # Accumulate raw ticks and divide once: dividing every delta separately
    # truncates each one and makes later notes drift earlier and earlier.
    abs_ticks += msg.time
    step = abs_ticks // ticks_per_step
    while len(data) <= step:
      data.append([0.0] * width)
    if msg.type == 'note_on':
      # Frames only hold `width` columns, so the pitch must be shifted into
      # the window; indexing with the raw MIDI number overruns the row.
      column = msg.note - lowest
      if 0 <= column < width:
        data[step][column] = 1.0
  return data

def sequentional(song, track_index=2, octave_symbols=None, note_symbols=None):
    """Encode one MIDI track as a flat stream of time / octave / note tokens.

    For every `note_on` the stream receives the elapsed ticks (an int, possibly
    0), the octave symbol for `note // 12` and the note symbol for `note % 12`.
    For every `note_off` the elapsed ticks are emitted only when positive.

    Args:
        song: object with a `tracks` sequence of messages exposing `time`,
            `type` and, for note messages, `note`.
        track_index: which track of `song.tracks` to encode (default 2).
        octave_symbols: symbols indexed by `note // 12`; defaults to the
            module-level `octaves`.
        note_symbols: symbols indexed by `note % 12`; defaults to the
            module-level `notes`.

    Returns:
        A list mixing ints (delta ticks) and one-character strings.
    """
    if octave_symbols is None:
        octave_symbols = octaves
    if note_symbols is None:
        note_symbols = notes

    data = []
    pending = 0
    for msg in song.tracks[track_index]:
        # Every message carries a delta time. Ticks that belong to messages
        # other than note_on/note_off are carried over to the next emitted
        # time token instead of being dropped, so later notes keep their
        # position (encode_shifts accumulates time the same way).
        pending += msg.time
        if msg.type == 'note_off':
            if pending > 0:
                data.append(pending)
                pending = 0
        elif msg.type == 'note_on':
            data.append(pending)
            pending = 0
            data.append(octave_symbols[msg.note // 12])
            data.append(note_symbols[msg.note % 12])
    return data

def encode_shifts(song, track_index=2, start_note=60, shift_symbols=None):
    """Encode one MIDI track as delta times plus relative pitch shifts.

    For every `note_on` the stream receives the ticks elapsed since the
    previous `note_on` (an int) followed by the interval from the previous
    note: one '>' per 12 semitones upward, one '<' per 12 semitones downward,
    an optional '-' for a remaining downward interval, and finally the symbol
    for the remaining 0..11 semitones.

    Args:
        song: object with a `tracks` sequence of messages exposing `time`,
            `type` and, for note messages, `note`.
        track_index: which track of `song.tracks` to encode (default 2).
        start_note: pitch the first interval is measured from (default 60,
            the same starting pitch from_shifts decodes with).
        shift_symbols: symbols indexed by an interval of 0..11; defaults to
            the module-level `shifts`.

    Returns:
        A list mixing ints (delta ticks) and one-character strings.
    """
    if shift_symbols is None:
        shift_symbols = shifts

    data = []
    elapsed = 0
    prev = start_note
    for msg in song.tracks[track_index]:
        # Ticks of every message count towards the next note's delta.
        elapsed += msg.time
        if msg.type != 'note_on':
            continue
        data.append(elapsed)
        elapsed = 0

        shift = msg.note - prev
        prev = msg.note
        # Reduce the interval into -11..11, one 12-semitone marker at a time.
        while shift > 11:
            data.append('>')
            shift -= 12
        while shift < -11:
            data.append('<')
            shift += 12
        if shift < 0:
            data.append('-')
            shift = -shift
        data.append(shift_symbols[shift])
    return data


# Parse the MIDI file and convert it to piano-roll frames.
song = mido.MidiFile('/content/MoonlightExtended.mid')
data = tick_based_encoding(song)

# Preview the first ten frames.
data[:10]

In [None]:
# Distinct symbols in the encoded song. Piano-roll frames are lists, which are
# unhashable and made set(data) raise TypeError, so they are turned into tuples
# before de-duplication; scalar tokens of the other encodings pass through.
vocab = list({tuple(item) if isinstance(item, list) else item for item in data})
len(vocab), len(data)

In [None]:
# Training-pipeline settings.
BATCH_SIZE = 64         # sequences per training batch
BUFFER_SIZE = 10000     # size of the shuffle buffer
seq_length_slow = 150   # input steps per training sequence
seq_length_fast = 10    # input steps for the (currently disabled) fast dataset



def split_input_target_seq(chunk):
    """Pair a chunk with itself shifted one step ahead.

    Returns `(inputs, targets)` where `inputs` is everything but the last
    element and `targets` is everything but the first, so `targets[i]` is the
    element that follows `inputs[i]`.
    """
    return chunk[:-1], chunk[1:]

def split_input_target(chunk):
    """Split a chunk into its history and the single element to predict.

    Returns `(inputs, target)` where `inputs` is everything but the last
    element and `target` is the last element itself.
    """
    history, final = chunk[:-1], chunk[-1]
    return history, final



# How many times the encoded song is replayed to form the training stream.
DATA_REPEATS = 100

# The token encodings need an id lookup at this point (disabled while the
# piano-roll encoding is in use):
#   char2idx = {u: i for i, u in enumerate(vocab)}
#   idx2char = vocab
#   text_as_int = np.array([char2idx[c] for c in data])

# Repeat inside the tf.data pipeline instead of multiplying the Python list in
# place: the element stream is the same, but `data` keeps its original length,
# re-running this cell no longer grows it again, and only one copy of the song
# is held in memory.
char_dataset = tf.data.Dataset.from_tensor_slices(data).repeat(DATA_REPEATS)

# Cut the stream into windows of seq_length_slow inputs plus one target step.
sequences_slow = char_dataset.batch(seq_length_slow + 1, drop_remainder=True)
sequences_slow = sequences_slow.map(split_input_target)

# Shuffle the windows and group them into fixed-size training batches.
dataset_slow = sequences_slow.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)


In [None]:
# Display the batched training dataset as a sanity check.
dataset_slow

In [None]:
# Settings of the token-based (embedding) model, currently disabled:
#   vocab_size = len(vocab)      # length of the vocabulary
#   embedding_dim = 256          # the embedding dimension

# Number of units in the recurrent (GRU) layer.
rnn_units = 512

def loss(labels, logits):
  """Categorical cross-entropy between target frames and raw model outputs.

  The final Dense layer in build_model has no activation, so the network
  emits unnormalised logits. `from_logits=True` makes the loss apply the
  softmax itself; with `from_logits=False` the raw values were treated as
  probabilities.

  NOTE(review): targets produced by tick_based_encoding can have zero or
  several active notes per frame, which is not a one-hot distribution; a
  per-note binary cross-entropy may fit better. Confirm before changing.
  """
  return tf.keras.losses.categorical_crossentropy(labels, logits, from_logits=True)

def build_model(rnn_units, batch_size):
  """Build the GRU network that predicts the next piano-roll frame.

  Args:
    rnn_units: number of units in the GRU layer.
    batch_size: fixed batch size of the model. A stateful RNN keeps one state
      row per batch element, so the batch size has to be known when the model
      is built; previously this argument was accepted but never used.

  Returns:
    A `tf.keras.Sequential` taking input of shape
    `(batch_size, steps, len(notes_range))` and returning one vector of
    `len(notes_range)` logits per sequence (the GRU returns only its last
    output).
  """
  n_features = len(notes_range)
  model = tf.keras.Sequential([
    # Fixes the batch size and feature width; the time axis stays variable.
    tf.keras.Input(shape=(None, n_features), batch_size=batch_size),
    tf.keras.layers.GRU(rnn_units,
                        stateful=True,
                        recurrent_initializer='glorot_uniform'),
    # No activation: the outputs are logits.
    tf.keras.layers.Dense(n_features),
  ])
  return model

In [None]:
# Instantiate the training network, then attach the optimizer and the loss.
model = build_model(rnn_units=rnn_units, batch_size=BATCH_SIZE)
model.compile(optimizer='adam', loss=loss)

In [None]:
# Number of passes over dataset_slow.
EPOCHS = 20

# Train, then write the learned parameters to the 'weights' checkpoint.
model.fit(dataset_slow, epochs=EPOCHS)
model.save_weights('weights')

In [None]:
# Scratch check: add a leading batch axis to the first ten encoded frames.
tf.expand_dims(data[:10],0)

In [None]:
# Rebuild the network for generation with a batch of a single sequence.
# build_model takes (rnn_units, batch_size) and returns one model, so it is
# called with exactly those arguments and its result is not unpacked.
modelslow = build_model(rnn_units, batch_size=1)

# Restore the parameters written by model.save_weights('weights').
modelslow.load_weights('weights')
# Input is (batch, steps, features): one sequence, any length, one column per
# note in notes_range.
modelslow.build(tf.TensorShape([1, None, len(notes_range)]))

# No separate "fast" network is trained in this notebook. Keep the name bound
# to the same network so code that still refers to `modelfast` resolves.
modelfast = modelslow

In [None]:
def generate_text(model, start_string):
  """Sample a continuation of `start_string` from `model`, one symbol at a time.

  Args:
    model: network used for sampling. It is called on a `(1, n)` batch of
      symbol ids and must support `reset_states()`. Its output is squeezed on
      axis 0 and handed to `tf.random.categorical`, so it presumably has shape
      `(1, steps, vocab)` (TODO confirm against the model actually passed).
    start_string: list of seed symbols; every symbol must be a key of
      `char2idx`.

  Returns:
    `start_string` followed by the 1000 generated symbols, as one list.

  NOTE(review): `char2idx` and `idx2char` are read from module scope. In this
  notebook their definitions are commented out, so they have to be defined
  before this function is called.
  """
  num_generate = 1000

  input_eval = tf.expand_dims([char2idx[s] for s in start_string], 0)
  text_generated = []

  # Use the model that was passed in (previously the argument was ignored in
  # favour of the globals `modelslow` / `modelfast`).
  model.reset_states()
  for i in range(num_generate):
    # Re-draw the sampling temperature every 128 symbols, uniform in [0.4, 0.9).
    if i % 128 == 0:
      temperature = np.random.random() / 2 + 0.4

    predictions = tf.squeeze(model(input_eval), 0) / temperature
    # Sample from the scaled logits and keep the draw for the last step.
    predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

    # Only the new symbol is fed back; earlier context has to come from the
    # state the model keeps between calls.
    input_eval = tf.expand_dims([predicted_id], 0)
    text_generated.append(idx2char[predicted_id])

  return start_string + text_generated

In [None]:
# Generate with the batch-size-1 network (`modelslow`), not the training model,
# which was built for BATCH_SIZE sequences at a time.
gen = generate_text(modelslow, start_string=[0])

In [None]:

def from_seq(tokens=None, path='/content/output.mid'):
  """Decode a time / octave / note token stream into a one-track MIDI file.

  Integer tokens add to the delay before the next note, octave symbols select
  the current octave (initially 4), and each note symbol emits a note at
  `octave * 12 + index`.

  Args:
    tokens: iterable of tokens; defaults to the module-level `gen`.
    path: where the MIDI file is written.

  NOTE(review): every note is closed by a note_off with delta time 0, so notes
  have zero length in the written file. Confirm this is intended.
  """
  if tokens is None:
    tokens = gen

  song = mido.MidiFile()
  track = mido.MidiTrack()
  song.ticks_per_beat = 96
  delay = 0
  octave = 4
  for msg in tokens:
    if isinstance(msg, (int, np.integer)):
      delay += int(msg)
    elif msg in octaves:
      octave = octaves.index(msg)
    elif msg in notes:
      # Keep the pitch inside the MIDI range 0..127, as from_shifts also
      # bounds its pitches before writing them.
      note = min(127, max(octave * 12 + notes.index(msg), 0))
      track.append(mido.Message('note_on', note=note, time=delay))
      track.append(mido.Message('note_off', note=note, time=0))
      delay = 0

  song.tracks.append(track)
  song.save(path)

def from_shifts(tokens=None, path='/content/output.mid'):
  """Decode a delta-time / pitch-shift token stream into a one-track MIDI file.

  Integer tokens add to the delay before the next note. '>' and '<' move the
  reference pitch up or down by 12, '-' marks the next interval as downward,
  and an interval symbol emits a note that many semitones from the reference
  pitch (initially 60). The emitted note becomes the new reference.

  Args:
    tokens: iterable of tokens; defaults to the module-level `gen`.
    path: where the MIDI file is written.

  NOTE(review): every note is closed by a note_off with delta time 0, so notes
  have zero length in the written file. Confirm this is intended.
  """
  if tokens is None:
    tokens = gen

  song = mido.MidiFile()
  track = mido.MidiTrack()
  song.ticks_per_beat = 96
  delay = 0
  prev = 60
  downward = False
  for msg in tokens:
    if isinstance(msg, int):
      delay += msg
    elif msg == '>':
      prev += 12
    elif msg == '<':
      prev -= 12
    elif msg == '-':
      downward = True
    elif msg in shifts:
      interval = shifts.index(msg)
      note = prev - interval if downward else prev + interval
      downward = False
      # Bound the pitch to 0..120.
      # NOTE(review): MIDI allows up to 127; confirm 120 is the intended cap.
      note = min(120, max(note, 0))
      prev = note
      track.append(mido.Message('note_on', note=note, time=delay))
      track.append(mido.Message('note_off', note=note, time=0))
      delay = 0

  song.tracks.append(track)
  song.save(path)


In [None]:
# Decode the generated tokens in `gen` and write them out as a MIDI file.
from_shifts()