<a href="https://colab.research.google.com/github/AdamClarkStandke/GenerativeDeepLearning/blob/main/MusicGeneration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Music Generation

This code comes from David Foster's [repo](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/11_music/01_transformer/transformer.ipynb), which accompanies his book [Generative Deep Learning](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1492041947).



---



In [None]:
import os
import glob
import numpy as np
import time
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers, models, losses, callbacks

import music21

In [None]:
import os
import pickle as pkl
import music21
import keras
import tensorflow as tf

from fractions import Fraction


def parse_midi_files(file_list, parser, seq_len, parsed_data_path=None):
    """Parse MIDI files into overlapping note and duration token sequences.

    Every file is parsed with ``parser.parse(file)``, chordified and
    flattened. Each key, time signature, chord, rest and note element becomes
    one note token plus one duration token; a ``"START"`` / ``"0.0"`` pair is
    emitted at the beginning of every file. The token streams of all files
    are concatenated and cut into sliding windows of ``seq_len`` tokens.

    Args:
        file_list: Iterable of paths to the MIDI files.
        parser: Object exposing ``parse(path)`` (the notebook passes
            ``music21.converter``).
        seq_len: Number of tokens per window.
        parsed_data_path: Optional directory. When truthy, the two window
            lists are pickled there as ``notes`` and ``durations``.

    Returns:
        Tuple ``(notes_list, duration_list)`` of equal length; each entry is
        a space-joined string of ``seq_len`` tokens.
    """
    notes = []
    durations = []

    for i, file in enumerate(file_list):
        print(i + 1, "Parsing %s" % file)
        score = parser.parse(file).chordify()
        # ``Stream.flat`` is deprecated in recent music21 releases in favour
        # of ``flatten()``; keep a fallback for older versions.
        elements = score.flatten() if hasattr(score, "flatten") else score.flat

        notes.append("START")
        durations.append("0.0")

        for element in elements:
            note_name = None
            duration_name = None

            if isinstance(element, music21.key.Key):
                note_name = str(element.tonic.name) + ":" + str(element.mode)
                duration_name = "0.0"

            elif isinstance(element, music21.meter.TimeSignature):
                note_name = str(element.ratioString) + "TS"
                duration_name = "0.0"

            elif isinstance(element, music21.chord.Chord):
                # Only the last pitch of the chord is kept as the token.
                note_name = element.pitches[-1].nameWithOctave
                duration_name = str(element.duration.quarterLength)

            elif isinstance(element, music21.note.Rest):
                note_name = str(element.name)
                duration_name = str(element.duration.quarterLength)

            elif isinstance(element, music21.note.Note):
                note_name = str(element.nameWithOctave)
                duration_name = str(element.duration.quarterLength)

            if note_name and duration_name:
                notes.append(note_name)
                durations.append(duration_name)
        print(f"{len(notes)} notes parsed")

    print(f"Building sequences of length {seq_len}")
    # "+ 1" keeps the final full-length window, which would otherwise be
    # dropped; clamp at zero when there are fewer than seq_len tokens.
    n_windows = max(len(notes) - seq_len + 1, 0)
    notes_list = [
        " ".join(notes[start : start + seq_len]) for start in range(n_windows)
    ]
    duration_list = [
        " ".join(durations[start : start + seq_len])
        for start in range(n_windows)
    ]

    if parsed_data_path:
        # Make sure the cache directory exists before writing into it.
        os.makedirs(parsed_data_path, exist_ok=True)
        with open(os.path.join(parsed_data_path, "notes"), "wb") as f:
            pkl.dump(notes_list, f)
        with open(os.path.join(parsed_data_path, "durations"), "wb") as f:
            pkl.dump(duration_list, f)

    return notes_list, duration_list


def load_parsed_files(parsed_data_path):
    """Reload the pickled note and duration sequences.

    Args:
        parsed_data_path: Directory containing the ``notes`` and
            ``durations`` pickle files.

    Returns:
        Tuple ``(notes, durations)`` with the unpickled contents of the two
        files.
    """
    loaded = []
    for filename in ("notes", "durations"):
        with open(os.path.join(parsed_data_path, filename), "rb") as handle:
            loaded.append(pkl.load(handle))
    notes, durations = loaded
    return notes, durations


def get_midi_note(sample_note, sample_duration):
    """Convert a (note token, duration token) pair into a music21 object.

    Args:
        sample_note: Note token. Tokens containing ``"TS"`` become a time
            signature, tokens containing ``"major"``/``"minor"`` become a
            key, ``"rest"`` becomes a rest, tokens containing ``"."`` become
            a chord of the dot-separated pitches, ``"START"`` yields nothing
            and anything else is treated as a single pitch name.
        sample_duration: Duration token in quarter lengths, either a decimal
            string or a fraction such as ``"1/3"``. Only read for rests,
            chords and notes.

    Returns:
        The music21 object, or ``None`` for the ``"START"`` token.
    """

    def _duration():
        # Built lazily so key / time-signature / START tokens never parse it.
        return music21.duration.Duration(float(Fraction(sample_duration)))

    def _with_duration(element):
        element.duration = _duration()
        element.storedInstrument = music21.instrument.Violoncello()
        return element

    new_note = None

    if "TS" in sample_note:
        new_note = music21.meter.TimeSignature(sample_note.split("TS")[0])

    elif "major" in sample_note or "minor" in sample_note:
        tonic, mode = sample_note.split(":")
        new_note = music21.key.Key(tonic, mode)

    elif sample_note == "rest":
        new_note = _with_duration(music21.note.Rest())

    elif "." in sample_note:
        chord_notes = [
            _with_duration(music21.note.Note(pitch_name))
            for pitch_name in sample_note.split(".")
        ]
        new_note = music21.chord.Chord(chord_notes)

    elif sample_note != "START":
        new_note = _with_duration(music21.note.Note(sample_note))

    return new_note


class SinePositionEncoding(keras.layers.Layer):
    """Fixed sinusoidal position encoding.

    Produces the sine/cosine position signal from
    [Attention is All You Need](https://arxiv.org/abs/1706.03762): even
    feature indices carry a sine, odd feature indices a cosine, with
    wavelengths that grow geometrically up to `max_wavelength`.

    The layer reads only the shape of its input, which is indexed as
    `[..., sequence_length, feature_size]`, and returns an encoding
    broadcast to that same shape so it can be added to the embedded tokens.

    Args:
        max_wavelength: The maximum angular wavelength of the sine/cosine
            curves. Defaults to 10000.

    Example:
    ```python
    seq_len = 100
    vocab_size = 1000
    embedding_dim = 32
    inputs = keras.Input((seq_len,), dtype=tf.int32)
    embedding = keras.layers.Embedding(
        input_dim=vocab_size, output_dim=embedding_dim
    )(inputs)
    outputs = embedding + SinePositionEncoding()(embedding)
    ```

    References:
     - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
    """

    def __init__(self, max_wavelength=10000, **kwargs):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength

    def call(self, inputs):
        dtype = self.compute_dtype
        shape = tf.shape(inputs)
        # The sequence axis is second to last, the feature axis is last.
        seq_length, hidden_size = shape[-2], shape[-1]
        feature_index = tf.range(hidden_size)

        position = tf.cast(tf.range(seq_length), dtype)
        min_freq = tf.cast(1 / self.max_wavelength, dtype=dtype)
        # Each (sine, cosine) pair of features shares one exponent.
        exponent = tf.cast(2 * (feature_index // 2), dtype) / tf.cast(
            hidden_size, dtype
        )
        timescales = tf.pow(min_freq, exponent)
        # Outer product -> [seq_length, hidden_size]
        angles = tf.expand_dims(position, 1) * tf.expand_dims(timescales, 0)

        # Mask is 1 on odd feature indices (cosine) and 0 on even (sine).
        cos_mask = tf.cast(feature_index % 2, dtype)
        sin_mask = 1 - cos_mask
        encodings = tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask

        return tf.broadcast_to(encodings, shape)

    def get_config(self):
        config = super().get_config()
        config["max_wavelength"] = self.max_wavelength
        return config

In [None]:
# When True the MIDI files are parsed (and the result pickled); otherwise the
# previously pickled sequences are loaded instead.
PARSE_MIDI_FILES = True
# Directory the parsed note/duration sequences are pickled to.
PARSED_DATA_PATH = "/content/drive/MyDrive/music/bach/"
# Argument passed to `Dataset.repeat` when the training dataset is built.
DATASET_REPETITIONS = 1

# Model input length in tokens; windows of SEQ_LEN + 1 tokens are parsed so
# that inputs and one-step-shifted targets can both be sliced from them.
SEQ_LEN = 50
# Width of the concatenated note + duration embedding (each half gets
# EMBEDDING_DIM // 2).
EMBEDDING_DIM = 256
# `key_dim` and number of heads handed to the MultiHeadAttention layer.
KEY_DIM = 256
N_HEADS = 5
# Default dropout rate of the transformer block.
DROPOUT_RATE = 0.3
# Width of the hidden feed-forward layer in the transformer block.
FEED_FORWARD_DIM = 256
# NOTE(review): not referenced anywhere in the visible code.
LOAD_MODEL = False

# optimization
EPOCHS = 2000
BATCH_SIZE = 256

# `max_tokens` used when a sample is generated at the end of each epoch.
GENERATE_LEN = 50

In [None]:
# Collect the paths of every .mid file in the Bach dataset directory
file_list = glob.glob("/content/drive/MyDrive/music/bach/*.mid")
print(f"Found {len(file_list)} midi files")

Found 48 midi files


In [None]:
# The music21 converter module is used as the parser: parse_midi_files only
# calls `parser.parse(path)` on it.
parser = music21.converter

In [None]:
# Parse the second file in the list, split it at quarter-length offset 12,
# keep the first part of the split and chordify it for inspection.
example_score = (
    music21.converter.parse(file_list[1]).splitAtQuarterLength(12)[0].chordify()
)

In [None]:
# Print the example score in music21's "text" output format
example_score.show("text")

{0.0} <music21.metadata.Metadata object at 0x7e8801daf670>
{0.0} <music21.stream.Measure 1 offset=0.0>
    {0.0} <music21.instrument.Piano 'English Suite 1, 2. Allemande: Piano'>
    {0.0} <music21.clef.TrebleClef>
    {0.0} <music21.tempo.MetronomeMark andante Quarter=72>
    {0.0} <music21.key.Key of C major>
    {0.0} <music21.meter.TimeSignature 1/16>
    {0.0} <music21.chord.Chord A5>
{0.25} <music21.stream.Measure 2 offset=0.25>
    {0.0} <music21.meter.TimeSignature 4/4>
    {0.0} <music21.chord.Chord A5>
    {0.25} <music21.chord.Chord A3 A5>
    {0.5} <music21.chord.Chord A3 C#4 A5>
    {0.75} <music21.chord.Chord A3 C#4 E4>
    {1.0} <music21.chord.Chord A3 C#4 E4 A4>
    {1.25} <music21.chord.Chord A3 C#4 E4 A4 C#5>
    {1.5} <music21.chord.Chord A3 C#4 E4 A4 E5>
    {1.75} <music21.chord.Chord A3 C#4 A5>
    {1.8333} <music21.chord.Chord A3 A5>
    {2.0} <music21.chord.Chord A3 D4 B4 A5>
    {2.25} <music21.chord.Chord A3 D4 B4 A5>
    {2.5} <music21.chord.Chord A3 D4 B4 G#

In [None]:
# Either parse the raw MIDI files (pickling the result to PARSED_DATA_PATH)
# or reload the sequences pickled by an earlier run.
if PARSE_MIDI_FILES:
    notes, durations = parse_midi_files(
        file_list, parser, SEQ_LEN + 1, PARSED_DATA_PATH
    )
else:
    # load_parsed_files requires the directory that holds the pickles;
    # calling it without an argument raises a TypeError.
    notes, durations = load_parsed_files(PARSED_DATA_PATH)

1 Parsing /content/drive/MyDrive/music/bach/bwv806a.mid


  notes, durations = parse_midi_files(


590 notes parsed
2 Parsing /content/drive/MyDrive/music/bach/bwv806b.mid
1829 notes parsed
3 Parsing /content/drive/MyDrive/music/bach/bwv806c.mid
2560 notes parsed
4 Parsing /content/drive/MyDrive/music/bach/bwv806d.mid


  notes, durations = parse_midi_files(
  notes, durations = parse_midi_files(


3460 notes parsed
5 Parsing /content/drive/MyDrive/music/bach/bwv806e.mid


  notes, durations = parse_midi_files(


4364 notes parsed
6 Parsing /content/drive/MyDrive/music/bach/bwv806f.mid


  notes, durations = parse_midi_files(


5236 notes parsed
7 Parsing /content/drive/MyDrive/music/bach/bwv806g.mid


  notes, durations = parse_midi_files(


5890 notes parsed
8 Parsing /content/drive/MyDrive/music/bach/bwv806h.mid


  notes, durations = parse_midi_files(


6884 notes parsed
9 Parsing /content/drive/MyDrive/music/bach/bwv806i.mid


  notes, durations = parse_midi_files(


7649 notes parsed
10 Parsing /content/drive/MyDrive/music/bach/bwv806j.mid


  notes, durations = parse_midi_files(


8599 notes parsed
11 Parsing /content/drive/MyDrive/music/bach/bwv807a.mid


  notes, durations = parse_midi_files(


10558 notes parsed
12 Parsing /content/drive/MyDrive/music/bach/bwv807b.mid


  notes, durations = parse_midi_files(


11374 notes parsed
13 Parsing /content/drive/MyDrive/music/bach/bwv807c.mid
12169 notes parsed
14 Parsing /content/drive/MyDrive/music/bach/bwv807d.mid


  notes, durations = parse_midi_files(


12626 notes parsed
15 Parsing /content/drive/MyDrive/music/bach/bwv807e.mid


  notes, durations = parse_midi_files(


13219 notes parsed
16 Parsing /content/drive/MyDrive/music/bach/bwv807f.mid


  notes, durations = parse_midi_files(


14429 notes parsed
17 Parsing /content/drive/MyDrive/music/bach/bwv807g.mid


  notes, durations = parse_midi_files(


14941 notes parsed
18 Parsing /content/drive/MyDrive/music/bach/bwv807h.mid


  notes, durations = parse_midi_files(


15959 notes parsed
19 Parsing /content/drive/MyDrive/music/bach/bwv808a.mid


  notes, durations = parse_midi_files(


17287 notes parsed
20 Parsing /content/drive/MyDrive/music/bach/bwv808b.mid


  notes, durations = parse_midi_files(


18090 notes parsed
21 Parsing /content/drive/MyDrive/music/bach/bwv808c.mid


  notes, durations = parse_midi_files(


19162 notes parsed
22 Parsing /content/drive/MyDrive/music/bach/bwv808d.mid


  notes, durations = parse_midi_files(


19557 notes parsed
23 Parsing /content/drive/MyDrive/music/bach/bwv808e.mid


  notes, durations = parse_midi_files(


20097 notes parsed
24 Parsing /content/drive/MyDrive/music/bach/bwv808f.mid


  notes, durations = parse_midi_files(


20832 notes parsed
25 Parsing /content/drive/MyDrive/music/bach/bwv808g.mid


  notes, durations = parse_midi_files(


21167 notes parsed
26 Parsing /content/drive/MyDrive/music/bach/bwv808h.mid


  notes, durations = parse_midi_files(


22416 notes parsed
27 Parsing /content/drive/MyDrive/music/bach/bwv809a.mid


  notes, durations = parse_midi_files(


24164 notes parsed
28 Parsing /content/drive/MyDrive/music/bach/bwv809b.mid


  notes, durations = parse_midi_files(


24974 notes parsed
29 Parsing /content/drive/MyDrive/music/bach/bwv809c.mid


  notes, durations = parse_midi_files(


25676 notes parsed
30 Parsing /content/drive/MyDrive/music/bach/bwv809d.mid
26028 notes parsed
31 Parsing /content/drive/MyDrive/music/bach/bwv809e.mid


  notes, durations = parse_midi_files(


26574 notes parsed
32 Parsing /content/drive/MyDrive/music/bach/bwv809f.mid


  notes, durations = parse_midi_files(


27114 notes parsed
33 Parsing /content/drive/MyDrive/music/bach/bwv809g.mid


  notes, durations = parse_midi_files(


28614 notes parsed
34 Parsing /content/drive/MyDrive/music/bach/bwv810a.mid


  notes, durations = parse_midi_files(


30627 notes parsed
35 Parsing /content/drive/MyDrive/music/bach/bwv810b.mid


  notes, durations = parse_midi_files(


31424 notes parsed
36 Parsing /content/drive/MyDrive/music/bach/bwv810c.mid


  notes, durations = parse_midi_files(


32410 notes parsed
37 Parsing /content/drive/MyDrive/music/bach/bwv810d.mid


  notes, durations = parse_midi_files(
  notes, durations = parse_midi_files(


32859 notes parsed
38 Parsing /content/drive/MyDrive/music/bach/bwv810e.mid
33342 notes parsed
39 Parsing /content/drive/MyDrive/music/bach/bwv810f.mid


  notes, durations = parse_midi_files(
  notes, durations = parse_midi_files(


33619 notes parsed
40 Parsing /content/drive/MyDrive/music/bach/bwv810g.mid
34805 notes parsed
41 Parsing /content/drive/MyDrive/music/bach/bwv811a.mid


  notes, durations = parse_midi_files(


38204 notes parsed
42 Parsing /content/drive/MyDrive/music/bach/bwv811b.mid


  notes, durations = parse_midi_files(


39035 notes parsed
43 Parsing /content/drive/MyDrive/music/bach/bwv811c.mid
40063 notes parsed
44 Parsing /content/drive/MyDrive/music/bach/bwv811d.mid


  notes, durations = parse_midi_files(


40702 notes parsed
45 Parsing /content/drive/MyDrive/music/bach/bwv811e.mid


  notes, durations = parse_midi_files(


41505 notes parsed
46 Parsing /content/drive/MyDrive/music/bach/bwv811f.mid


  notes, durations = parse_midi_files(


42282 notes parsed
47 Parsing /content/drive/MyDrive/music/bach/bwv811g.mid


  notes, durations = parse_midi_files(


42797 notes parsed
48 Parsing /content/drive/MyDrive/music/bach/bwv811h.mid


  notes, durations = parse_midi_files(


44266 notes parsed
Building sequences of length 51


In [None]:
# Look at one parsed window: its note tokens and the matching duration tokens
example_notes, example_durations = notes[658], durations[658]
print("\nNotes piano\n", example_notes, "...")
print("\nDuration piano\n", example_durations, "...")


Notes piano
 E4 C#5 A4 C#5 F#4 G#4 A4 A4 A4 A4 E-4 A4 A4 G#4 F#4 G#4 G#4 G#4 G#4 G#4 G#4 G#4 D3 G#4 G#4 A4 B4 B4 C#5 B4 B4 C#5 C#5 C#5 C#5 C#5 C#5 C#5 C#5 C#5 C#5 C#5 G3 C#5 C#5 D5 D5 D5 E5 E5 E5 ...

Duration piano
 1/12 0.25 0.25 0.25 0.25 0.25 1/6 1/12 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 1/6 1/12 1/6 1/12 0.25 1/12 1/6 0.25 1/12 1/6 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 0.25 1/6 1/12 1/6 1/12 0.25 1/12 1/6 1/12 1/6 1/12 ...


In [None]:
def create_dataset(elements, shuffle_buffer=1000):
    """Batch a list of token strings and fit a vectorizer on it.

    Args:
        elements: Sequence of space-joined token strings.
        shuffle_buffer: Buffer size used to shuffle the *batches* of the
            returned dataset. Pass ``None`` (or 0) to leave the dataset
            unshuffled, which is required when several datasets are zipped
            together afterwards and must stay aligned.

    Returns:
        Tuple ``(ds, vectorize_layer, vocab)``: the batched dataset (batches
        of ``BATCH_SIZE``, remainder dropped), the adapted
        ``TextVectorization`` layer and its vocabulary list.
    """
    ds = tf.data.Dataset.from_tensor_slices(elements).batch(
        BATCH_SIZE, drop_remainder=True
    )
    vectorize_layer = layers.TextVectorization(
        standardize=None, output_mode="int"
    )
    vectorize_layer.adapt(ds)
    vocab = vectorize_layer.get_vocabulary()
    if shuffle_buffer:
        ds = ds.shuffle(shuffle_buffer)
    return ds, vectorize_layer, vocab


# The two streams are built WITHOUT shuffling: shuffling each one separately
# would pair note batches with unrelated duration batches once zipped.
notes_seq_ds, notes_vectorize_layer, notes_vocab = create_dataset(
    notes, shuffle_buffer=None
)
durations_seq_ds, durations_vectorize_layer, durations_vocab = create_dataset(
    durations, shuffle_buffer=None
)
# Shuffle once, after zipping, so every note batch keeps its own durations.
seq_ds = tf.data.Dataset.zip((notes_seq_ds, durations_seq_ds)).shuffle(1000)

In [None]:
# Show the start of the same example window as integer token ids
example_tokenised_notes = notes_vectorize_layer(example_notes)
example_tokenised_durations = durations_vectorize_layer(example_durations)
print("{:10} {:10}".format("note token", "duration token"))
first_note_ids = example_tokenised_notes.numpy()[:11]
first_duration_ids = example_tokenised_durations.numpy()[:11]
for note_int, duration_int in zip(first_note_ids, first_duration_ids):
    print(f"{note_int:10}{duration_int:10}")

note token duration token
        13         5
        11         2
         4         2
        11         2
        15         2
        20         2
         4         4
         4         5
         4         2
         4         2
        27         2


In [None]:
notes_vocab_size = len(notes_vocab)
durations_vocab_size = len(durations_vocab)

# Print the size and the first ten index -> token entries of each vocabulary
for title, vocab in (
    ("NOTES_VOCAB", notes_vocab),
    ("DURATIONS_VOCAB", durations_vocab),
):
    print(f"\n{title}: length = {len(vocab)}")
    for i, note in enumerate(vocab[:10]):
        print(f"{i}: {note}")


NOTES_VOCAB: length = 70
0: 
1: [UNK]
2: D5
3: E5
4: A4
5: B4
6: C5
7: G5
8: G4
9: F5

DURATIONS_VOCAB: length = 73
0: 
1: [UNK]
2: 0.25
3: 0.5
4: 1/6
5: 1/12
6: 1/3
7: 0.75
8: 0.0
9: 1.5


In [None]:
# Build (input, target) pairs where the target is the input shifted by one token
def prepare_inputs(notes, durations):
    """Tokenize a batch of note/duration strings and split it into x and y.

    Returns:
        ``(x, y)`` where ``x`` holds all tokens but the last of each sequence
        and ``y`` all tokens but the first; both are
        ``(note_tokens, duration_tokens)`` tuples.
    """
    note_tokens = notes_vectorize_layer(tf.expand_dims(notes, -1))
    duration_tokens = durations_vectorize_layer(tf.expand_dims(durations, -1))
    inputs = (note_tokens[:, :-1], duration_tokens[:, :-1])
    targets = (note_tokens[:, 1:], duration_tokens[:, 1:])
    return inputs, targets


ds = seq_ds.map(prepare_inputs)
ds = ds.repeat(DATASET_REPETITIONS)

In [None]:
# Pull a single batch out of the training dataset and print its
# (inputs, targets) structure.
example_input_output = ds.take(1).get_single_element()
print(example_input_output)

((<tf.Tensor: shape=(256, 50), dtype=int64, numpy=
array([[ 5,  4,  5, ...,  4,  8, 15],
       [ 4,  5, 11, ...,  8, 15,  3],
       [ 5, 11, 11, ..., 15,  3,  2],
       ...,
       [ 5, 13, 13, ..., 14, 14, 14],
       [13, 13, 13, ..., 14, 14,  3],
       [13, 13,  4, ..., 14,  3, 14]])>, <tf.Tensor: shape=(256, 50), dtype=int64, numpy=
array([[4, 6, 3, ..., 3, 3, 3],
       [6, 3, 2, ..., 3, 3, 2],
       [3, 2, 2, ..., 3, 2, 5],
       ...,
       [5, 4, 2, ..., 3, 2, 2],
       [4, 2, 2, ..., 2, 2, 3],
       [2, 2, 3, ..., 2, 3, 3]])>), (<tf.Tensor: shape=(256, 50), dtype=int64, numpy=
array([[ 4,  5, 11, ...,  8, 15,  3],
       [ 5, 11, 11, ..., 15,  3,  2],
       [11, 11, 11, ...,  3,  2, 11],
       ...,
       [13, 13, 13, ..., 14, 14,  3],
       [13, 13,  4, ..., 14,  3, 14],
       [13,  4, 11, ...,  3, 14,  2]])>, <tf.Tensor: shape=(256, 50), dtype=int64, numpy=
array([[6, 3, 2, ..., 3, 3, 2],
       [3, 2, 2, ..., 3, 2, 5],
       [2, 2, 3, ..., 2, 5, 4],
       ...,

In [None]:
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """Build a batched causal attention mask.

    Entry ``[b, i, j]`` is truthy when ``i >= j - n_src + n_dest``; for
    ``n_dest == n_src`` this is a lower-triangular matrix, so position ``i``
    can only attend to positions ``j <= i``.

    Args:
        batch_size: Scalar tensor with the number of copies to tile.
        n_dest: Number of query (destination) positions.
        n_src: Number of key (source) positions.
        dtype: Dtype the boolean mask is cast to.

    Returns:
        Tensor of shape ``[batch_size, n_dest, n_src]``.
    """
    dest_index = tf.range(n_dest)[:, None]
    src_index = tf.range(n_src)
    allowed = dest_index >= src_index - n_src + n_dest
    mask = tf.reshape(tf.cast(allowed, dtype), [1, n_dest, n_src])
    # Tile along the batch axis only: multiples are [batch_size, 1, 1].
    multiples = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
    )
    return tf.tile(mask, multiples)

In [None]:
class TransformerBlock(layers.Layer):
    """Single causal transformer block.

    Applies causally masked multi-head self-attention followed by a
    two-layer feed-forward network; each sub-layer is followed by dropout, a
    residual connection and layer normalization.

    Args:
        num_heads: Number of attention heads.
        key_dim: Size of each attention head for query and key.
        embed_dim: Output width of the attention layer and of the block.
        ff_dim: Width of the hidden feed-forward layer.
        name: Layer name.
        dropout_rate: Rate of both dropout layers.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``trainable``,
            ``dtype``). Accepting them is what lets ``from_config`` rebuild
            the layer from the dictionary returned by ``get_config``.
    """

    def __init__(
        self,
        num_heads,
        key_dim,
        embed_dim,
        ff_dim,
        name,
        dropout_rate=DROPOUT_RATE,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate
        self.attn = layers.MultiHeadAttention(
            num_heads, key_dim, output_shape=embed_dim
        )
        self.dropout_1 = layers.Dropout(self.dropout_rate)
        self.ln_1 = layers.LayerNormalization(epsilon=1e-6)
        self.ffn_1 = layers.Dense(self.ff_dim, activation="relu")
        self.ffn_2 = layers.Dense(self.embed_dim)
        self.dropout_2 = layers.Dropout(self.dropout_rate)
        self.ln_2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        """Run the block.

        Returns:
            Tuple ``(outputs, attention_scores)``.
        """
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        # Self-attention: queries and keys have the same length.
        causal_mask = causal_attention_mask(
            batch_size, seq_len, seq_len, tf.bool
        )
        attention_output, attention_scores = self.attn(
            inputs,
            inputs,
            attention_mask=causal_mask,
            return_attention_scores=True,
        )
        attention_output = self.dropout_1(attention_output)
        out1 = self.ln_1(inputs + attention_output)
        ffn_1 = self.ffn_1(out1)
        ffn_2 = self.ffn_2(ffn_1)
        ffn_output = self.dropout_2(ffn_2)
        return (self.ln_2(out1 + ffn_output), attention_scores)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "key_dim": self.key_dim,
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "ff_dim": self.ff_dim,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config

In [None]:
class TokenAndPositionEmbedding(layers.Layer):
    """Token embedding with an added sinusoidal position encoding.

    Args:
        vocab_size: Number of distinct token ids (embedding ``input_dim``).
        embed_dim: Width of the embedding vectors.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``,
            ``trainable``, ``dtype``). Accepting them is what lets
            ``from_config`` rebuild the layer from the dictionary returned by
            ``get_config``.
    """

    def __init__(self, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = layers.Embedding(
            input_dim=vocab_size,
            output_dim=embed_dim,
            embeddings_initializer="he_uniform",
        )
        self.pos_emb = SinePositionEncoding()

    def call(self, x):
        """Embed the token ids in ``x`` and add the position encoding."""
        embedding = self.token_emb(x)
        # The encoding layer only reads the shape of the embedded tensor.
        positions = self.pos_emb(embedding)
        return embedding + positions

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

In [None]:
# Two variable-length integer inputs: note token ids and duration token ids.
note_inputs = layers.Input(shape=(None,), dtype=tf.int32)
durations_inputs = layers.Input(shape=(None,), dtype=tf.int32)
# Each stream is embedded to EMBEDDING_DIM // 2 features so that the
# concatenation below is EMBEDDING_DIM wide.
note_embeddings = TokenAndPositionEmbedding(
    notes_vocab_size, EMBEDDING_DIM // 2
)(note_inputs)
duration_embeddings = TokenAndPositionEmbedding(
    durations_vocab_size, EMBEDDING_DIM // 2
)(durations_inputs)
embeddings = layers.Concatenate()([note_embeddings, duration_embeddings])
# The block is named "attention" because MusicGenerator.generate looks the
# layer up by that name.
x, attention_scores = TransformerBlock(
    N_HEADS, KEY_DIM, EMBEDDING_DIM, FEED_FORWARD_DIM, name="attention"
)(embeddings)
# Two softmax heads over the shared block output: one per vocabulary.
note_outputs = layers.Dense(
    notes_vocab_size, activation="softmax", name="note_outputs"
)(x)
duration_outputs = layers.Dense(
    durations_vocab_size, activation="softmax", name="duration_outputs"
)(x)
model = models.Model(
    inputs=[note_inputs, durations_inputs],
    outputs=[note_outputs, duration_outputs],  # attention_scores
)
# One sparse categorical cross-entropy loss per output head.
model.compile(
    "adam",
    loss=[
        losses.SparseCategoricalCrossentropy(),
        losses.SparseCategoricalCrossentropy(),
    ],
)
# Companion model over the same inputs that outputs the attention scores.
att_model = models.Model(
    inputs=[note_inputs, durations_inputs], outputs=attention_scores
)

In [None]:
# Print the layer-by-layer summary of the training model
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, None)]               0         []                            
                                                                                                  
 input_2 (InputLayer)        [(None, None)]               0         []                            
                                                                                                  
 token_and_position_embeddi  (None, None, 128)            8960      ['input_1[0][0]']             
 ng (TokenAndPositionEmbedd                                                                       
 ing)                                                                                             
                                                                                              

In [None]:
# Callback that samples a piece of music from the model after every epoch
class MusicGenerator(callbacks.Callback):
    """Generate music from the model being trained and save it as MIDI.

    Args:
        index_to_note: Note vocabulary; position ``i`` is the token for id
            ``i``. Ids 0 and 1 are expected to be the padding and ``[UNK]``
            entries of the ``TextVectorization`` vocabulary.
        index_to_duration: Duration vocabulary with the same layout.
        top_k: Stored for interface compatibility only; sampling always uses
            the full predicted distribution.
        output_dir: Directory the per-epoch MIDI files are written to.
    """

    # Token ids that must never be emitted: 0 is padding, 1 is [UNK].
    _RESERVED_TOKEN_IDS = (0, 1)

    def __init__(
        self,
        index_to_note,
        index_to_duration,
        top_k=10,
        output_dir="/content/drive/MyDrive/music/",
    ):
        super().__init__()
        self.index_to_note = index_to_note
        self.note_to_index = {
            note: index for index, note in enumerate(index_to_note)
        }
        self.index_to_duration = index_to_duration
        self.duration_to_index = {
            duration: index for index, duration in enumerate(index_to_duration)
        }
        self.top_k = top_k
        self.output_dir = output_dir
        # Attention sub-model, built lazily and reused across generate() calls.
        self._attention_model = None
        self._attention_model_source = None

    def sample_from(self, probs, temperature):
        """Sample an index from ``probs`` after temperature scaling.

        Returns:
            Tuple ``(index, scaled_probs)``.
        """
        probs = probs ** (1 / temperature)
        probs = probs / np.sum(probs)
        return np.random.choice(len(probs), p=probs), probs

    def _sample_valid_token(self, probs, temperature):
        """Sample until the drawn id is neither padding nor ``[UNK]``."""
        while True:
            index, scaled_probs = self.sample_from(probs, temperature)
            if index not in self._RESERVED_TOKEN_IDS:
                return index, scaled_probs

    def get_note(self, notes, durations, temperature):
        """Sample the next note/duration from the last predicted position.

        Args:
            notes: Note predictions; ``notes[0][-1]`` is the distribution
                that is sampled.
            durations: Duration predictions, indexed the same way.
            temperature: Sampling temperature.

        Returns:
            Tuple ``(new_note, sample_note_idx, sample_note, note_probs,
            sample_duration_idx, sample_duration, duration_probs)``.
        """
        sample_note_idx, note_probs = self._sample_valid_token(
            notes[0][-1], temperature
        )
        sample_note = self.index_to_note[sample_note_idx]

        sample_duration_idx, duration_probs = self._sample_valid_token(
            durations[0][-1], temperature
        )
        sample_duration = self.index_to_duration[sample_duration_idx]

        new_note = get_midi_note(sample_note, sample_duration)

        return (
            new_note,
            sample_note_idx,
            sample_note,
            note_probs,
            sample_duration_idx,
            sample_duration,
            duration_probs,
        )

    def _get_attention_model(self):
        """Return a model exposing the "attention" layer output, cached per model."""
        if (
            self._attention_model is None
            or self._attention_model_source is not self.model
        ):
            self._attention_model = models.Model(
                inputs=self.model.input,
                outputs=self.model.get_layer("attention").output,
            )
            self._attention_model_source = self.model
        return self._attention_model

    def generate(self, start_notes, start_durations, max_tokens, temperature):
        """Autoregressively extend a prompt until ``max_tokens`` tokens.

        Generation also stops as soon as a ``"START"`` token is sampled.

        Args:
            start_notes: Prompt note tokens (not modified).
            start_durations: Prompt duration tokens (not modified).
            max_tokens: Total sequence length, prompt included, to reach.
            temperature: Sampling temperature.

        Returns:
            One dict per generated step with the keys ``prompt``, ``midi``,
            ``chosen_note``, ``note_probs``, ``duration_probs`` and ``atts``.
            Every entry references the same, growing MIDI stream. The list is
            empty when the prompt already has ``max_tokens`` tokens.
        """
        attention_model = self._get_attention_model()

        # Work on copies so the caller's prompt lists are left untouched.
        start_notes = list(start_notes)
        start_durations = list(start_durations)

        start_note_tokens = [self.note_to_index.get(x, 1) for x in start_notes]
        start_duration_tokens = [
            self.duration_to_index.get(x, 1) for x in start_durations
        ]
        sample_note = None
        sample_duration = None
        info = []
        midi_stream = music21.stream.Stream()

        midi_stream.append(music21.clef.BassClef())

        for sample_note, sample_duration in zip(start_notes, start_durations):
            new_note = get_midi_note(sample_note, sample_duration)
            if new_note is not None:
                midi_stream.append(new_note)

        sounding_types = (
            music21.chord.Chord,
            music21.note.Note,
            music21.note.Rest,
        )

        while len(start_note_tokens) < max_tokens:
            x1 = np.array([start_note_tokens])
            x2 = np.array([start_duration_tokens])
            notes, durations = self.model.predict([x1, x2], verbose=0)

            # Resample while a sounding element comes with a zero duration.
            while True:
                (
                    new_note,
                    sample_note_idx,
                    sample_note,
                    note_probs,
                    sample_duration_idx,
                    sample_duration,
                    duration_probs,
                ) = self.get_note(notes, durations, temperature)

                zero_length_sound = (
                    isinstance(new_note, sounding_types)
                    and sample_duration == "0.0"
                )
                if not zero_length_sound:
                    break

            if new_note is not None:
                midi_stream.append(new_note)

            _, att = attention_model.predict([x1, x2], verbose=0)

            info.append(
                {
                    "prompt": [start_notes.copy(), start_durations.copy()],
                    "midi": midi_stream,
                    "chosen_note": (sample_note, sample_duration),
                    "note_probs": note_probs,
                    "duration_probs": duration_probs,
                    "atts": att[0, :, -1, :],
                }
            )
            start_note_tokens.append(sample_note_idx)
            start_duration_tokens.append(sample_duration_idx)
            start_notes.append(sample_note)
            start_durations.append(sample_duration)

            if sample_note == "START":
                break

        return info

    def on_epoch_end(self, epoch, logs=None):
        """Generate a sample from a bare START prompt and write it to disk."""
        info = self.generate(
            ["START"], ["0.0"], max_tokens=GENERATE_LEN, temperature=0.5
        )
        if not info:
            # Nothing was generated (GENERATE_LEN no longer than the prompt).
            return
        midi_stream = info[-1]["midi"].chordify()
        print(info[-1]["prompt"])
        midi_stream.write(
            "midi",
            fp=os.path.join(
                self.output_dir,
                "output-" + str(epoch).zfill(4) + ".mid",
            ),
        )

In [None]:
# Create the generation callback from the two vocabularies; it samples and
# writes a MIDI file at the end of every epoch.
music_generator = MusicGenerator(notes_vocab, durations_vocab)

# Train on the tokenized dataset for EPOCHS epochs with the callback attached
model.fit(
    ds,
    epochs=EPOCHS,
    callbacks=[
        music_generator,
    ],
)

# Save the final model
model.save("/content/drive/MyDrive/music/Transformer")

Epoch 1/2000
Epoch 2/2000
Epoch 3/2000
Epoch 4/2000
Epoch 5/2000
Epoch 6/2000
Epoch 7/2000
Epoch 8/2000
Epoch 9/2000
Epoch 10/2000
Epoch 11/2000
Epoch 12/2000
Epoch 13/2000
Epoch 14/2000
Epoch 15/2000
Epoch 16/2000
Epoch 17/2000
Epoch 18/2000
Epoch 19/2000
Epoch 20/2000
Epoch 21/2000
Epoch 22/2000
Epoch 23/2000
Epoch 24/2000
Epoch 25/2000
Epoch 26/2000
Epoch 27/2000
Epoch 28/2000
Epoch 29/2000
Epoch 30/2000
Epoch 31/2000
Epoch 32/2000
Epoch 33/2000
Epoch 34/2000
Epoch 35/2000
Epoch 36/2000
Epoch 37/2000
Epoch 38/2000
Epoch 39/2000
Epoch 40/2000
Epoch 41/2000
Epoch 42/2000
Epoch 43/2000
Epoch 44/2000
Epoch 45/2000
Epoch 46/2000
Epoch 47/2000
Epoch 48/2000
Epoch 49/2000
Epoch 50/2000
Epoch 51/2000
Epoch 52/2000
Epoch 53/2000
Epoch 54/2000
Epoch 55/2000
Epoch 56/2000
Epoch 57/2000
Epoch 58/2000
Epoch 59/2000
Epoch 60/2000
Epoch 61/2000
Epoch 62/2000
Epoch 63/2000
Epoch 64/2000
Epoch 65/2000
Epoch 66/2000
Epoch 67/2000
Epoch 68/2000
Epoch 69/2000
Epoch 70/2000
Epoch 71/2000
Epoch 72/2000
E



[['START', 'C:major', '1/8TS', 'A4', '3/2TS', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4'], ['0.0', '0.0', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25', '0.25']]
Epoch 271/2000
Epoch 272/2000
Epoch 273/2000
Epoch 274/2000
Epoch 275/2000
Epoch 276/2000
Epoch 277/2000
Epoch 278/2000
Epoch 279/2000
Epoch 280/2000
Epoch 281/2000
Epoch 282/2000
Epoch 283/2000
Epoch 284/2000
Epoch 285/2000
Epoch 286/2000
Epoch 287/2000
Epoch 288/2000
Epoch 289/2000
Epoch 290/20



[['START', 'C:major', '1/8TS', 'A4', '3/2TS', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'A4', 'rest', 'A4', 'rest', 'A4', 'rest', '1/8TS', 'B4', '3/2TS', 'B4', 'B4', 'E2', 'C5', 'B4', 'A4', 'G4', 'F#4', 'E4', 'E-4', 'C5', 'E-4', 'rest', 'G4', 'E5', 'E5', 'E5', 'F#3', 'G5'], ['0.0', '0.5', '0.0', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.25', '0.25', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5', '0.5']]
Epoch 367/2000
Epoch 368/2000
Epoch 369/2000
Epoch 370/2000
Epoch 371/2000
Epoch 372/2000
Epoch 373/2000
Epoch 374/2000
Epoch 375/2000
Epoch 376/2000
Epoch 377/2000
Epoch 378/2000
Epoch 379/2000
Epoch 380/2000
Epoch 381/2000
Epoch 382/2000
Epoch 383/2000
Epoch 384/2000
Epoch 385/2000
Epoch 386/2000
Epoch 387/2000
Epoch 388

In [None]:
# Generate up to 50 tokens from a bare START prompt with the trained model
info = music_generator.generate(
    ["START"], ["0.0"], max_tokens=50, temperature=0.5
)
# The last step's entry references the complete stream; chordify it for export
midi_stream = info[-1]["midi"].chordify()

In [None]:
# Write the generated stream to a timestamped MIDI file
timestr = time.strftime("%Y%m%d-%H%M%S")
output_path = os.path.join(
    "/content/drive/MyDrive/music/", "output-" + timestr + ".mid"
)
midi_stream.write("midi", fp=output_path)

'/content/drive/MyDrive/music/output-20231230-140018.mid'