In [17]:
!pip install pretty_midi -q
!pip install torch>=2.0.0 -q
!pip install tensorboardX -q
!pip install lightning -q

import torch
print(torch.__version__)

[0m2.0.0


In [18]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 25 16:28:14 2020

@author: prang
@edited: carvalho
"""

import bz2
import os
from abc import ABC, abstractmethod
from bisect import bisect_left
from operator import attrgetter

import _pickle as cPickle
import librosa
import numpy as np
import pretty_midi
import torch  # type: ignore

# %%

# MIDI extensions
EXT = ['.mid', '.midi', '.MID', '.MIDI']
# Primes number
PRIMES = [43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107,
          109, 113, 127, 131, 137, 149, 157, 163, 167, 173, 179, 191, 197,
          211, 223, 227, 233, 239, 251, 257, 263, 269, 277, 281, 293, 307,
          311, 317, 331, 337, 347, 353, 359, 367, 373, 379, 383, 389, 397,
          401, 409, 419, 431, 439, 443, 449, 457, 461, 467, 479, 487, 491,
          499, 503, 509, 521, 541, 547, 557, 563, 569, 577, 587, 593, 599,
          607, 613, 617, 631, 641, 647, 653, 659, 673, 677, 683, 691, 701,
          709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797,
          809, 821, 827, 839, 853, 857, 863, 877, 881, 887, 907, 911, 919,
          929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997, 1009, 1013,
          1019, 1031, 1039, 1049, 1061, 1069, 1087, 1091, 1097, 1103, 1109,
          1117, 1123, 1129, 1151, 1163, 1171, 1181, 1187, 1193, 1201, 1213,
          1217, 1223, 1229, 1237, 1249, 1259, 1277, 1283, 1289, 1297, 1301,
          1307, 1319, 1327, 1361, 1367, 1373, 1381, 1399, 1409, 1423, 1427,
          1433, 1439, 1447, 1451, 1459, 1471, 1481, 1487, 1493, 1499, 1511,
          1523, 1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579, 1583, 1597,
          1601, 1607, 1613, 1619, 1627, 1637, 1657, 1663, 1667, 1693, 1697,
          1709, 1721, 1733, 1741, 1747, 1753, 1759, 1777, 1783, 1787, 1801,
          1811, 1823, 1831, 1847, 1861, 1867, 1871, 1877, 1889, 1901, 1907,
          1913, 1931, 1949, 1973, 1979, 1987, 1993, 1997, 2003, 2011, 2017,
          2027, 2039, 2053, 2063]


# Usefull functions
def takeClosest(myList, myNumber):
    """
    Assumes myList is sorted. Returns closest value to myNumber.

    If two numbers are equally close, return the smallest number.
    """
    pos = bisect_left(myList, myNumber)
    if pos == 0:
        return myList[0]
    if pos == len(myList):
        return myList[-1]
    before = myList[pos - 1]
    after = myList[pos]
    if after - myNumber < myNumber - before:
        return after
    else:
        return before


# Abstract class for the different input representations
class Representation(ABC):

    nb_bars = 0
    prbar_path = ""
    barfiles = []

    def __init__(self, root_dir, nbframe_per_bar=16, mono=False, export=False):
        """
        Args:
            root_dir (string) : Path of the directory with all the MIDI files
            nbframe_per_bar (int) : Number of frame contained in a bar
            export (bool) : Force the bar to be exported in .pt file or not
        """
        assert any((fname.endswith(tuple(EXT))) for fname in os.listdir(
            root_dir)), "There are no MIDI files in %s" % root_dir
        # root directory path which contains the files
        self.rootdir = root_dir
        # midi files names
        self.midfiles = [fname for fname in os.listdir(
            root_dir) if (fname.endswith(tuple(EXT)))]
        # Number of frame per bar
        self.nbframe_per_bar = nbframe_per_bar
        # Force export or not
        self.export = export
        # Monophonic data (separate voices)
        self.mono = mono
        # number of tracks contained in the dataset
        self.nb_tracks = len(self.midfiles)

    def __len__(self):
        """
        Return the total number of bars
        """
        return self.nb_bars  # type: ignore

    def __getitem__(self, index):
        """
        Return the tensor representation of the bar at the given index

        NOTE: LOADING THE TENSOR FROM THE .PT FILE TAKES A LONGER TIME AND SPACE THAN BZ2
        """
        # return torch.load(f"{self.prbar_path}/{self.barfiles[index]}")
        with bz2.BZ2File(f"{self.prbar_path}/{self.barfiles[index]}", "rb") as filepath:
            return cPickle.load(filepath)

    @abstractmethod
    def per_bar_export(self):
        """
        This function take all the midi files, load them into a pretty_midi object.
        For a complete documentation of pretty_midi go to :
            http://craffel.github.io/pretty-midi/

        The midi file is then processed to obtain the given representation of each bar with
        Finally, it will export each of theses bars in a separate .pt
        """
        pass


################################# PIANO-ROLL #################################

class Pianoroll(Representation):

    def __init__(self, root_dir, nbframe_per_bar=16, mono=False, export=False):
        super().__init__(root_dir, nbframe_per_bar=nbframe_per_bar, mono=mono, export=export)
        # Path witch contains the sliced piano-roll
        if mono:
            self.prbar_path = root_dir + \
                "/pianoroll_bar_mono_" + str(self.nbframe_per_bar)
        else:
            self.prbar_path = root_dir + \
                "/pianoroll_bar_" + str(self.nbframe_per_bar)

        if not os.path.exists(self.prbar_path):
            try:
                os.mkdir(self.prbar_path)
            except OSError:
                print("Creation of the directory %s failed" % self.prbar_path)
            else:
                print("Successfully created the directory %s " %
                      self.prbar_path)
            # Export the piano-roll bat
            self.per_bar_export()
        else:
            if export:
                self.per_bar_export()
        # .pt files names
        self.barfiles = [fname for fname in os.listdir(
            self.prbar_path) if fname.endswith('.pbz2')]  # or fname.endswith('.pt')]
        # total number of bars
        self.nb_bars = len(self.barfiles)

    def sliced_and_save_pianoroll(self, pianoroll, downbeats, fs, num_bar):

        for i in range(len(downbeats)-1):
            sp = pianoroll[:, int(round(downbeats[i]*fs))                           :int(round(downbeats[i+1]*fs))-1]
            if sp.shape[1] > 256:
                sp = sp[:, 0:256]
            elif sp.shape[1] < 256 and sp.shape[1] > 0:
                sp = np.pad(sp, ((0, 0), (0, 256 - sp.shape[1])), 'edge')
            if sp.shape[1] > 0:
                # downsample
                sp = sp[:, ::int(256/self.nbframe_per_bar)]
                # convert to tensor
                sp = torch.Tensor(sp)
                assert (
                    sp.shape[1] == self.nbframe_per_bar), "Error, a piano-roll have the wrong size : %s" % sp.shape[1]
                # binarize
                sp[sp != 0] = 1

                # Save the tensor
                # torch.save(sp.permute(1, 0), f"{self.prbar_path}/prbar{str(num_bar)}.pt")
                with bz2.BZ2File(f"{self.prbar_path}/prbar{str(num_bar)}.pbz2", "w") as filepath:
                    cPickle.dump(sp.permute(1, 0), filepath)

                num_bar += 1

        return num_bar

    def per_bar_export(self):

        num_error = 0
        num_bar = 0
        # load each .mid file in a pretty_midi object
        for index in range(len(self.midfiles)):
            try:
                midi_data = pretty_midi.PrettyMIDI(
                    self.rootdir + '/' + self.midfiles[index])
                downbeats = midi_data.get_downbeats()
                fs = 257 / (midi_data.get_end_time() / len(downbeats))
                # If monophonic data is required, we separate each voice
                if self.mono:
                    for inst in midi_data.instruments:
                        if not inst.is_drum:
                            pianoroll = inst.get_piano_roll(fs=fs)
                            num_bar = self.sliced_and_save_pianoroll(
                                pianoroll, downbeats, fs, num_bar)
                else:
                    pianoroll = midi_data.get_piano_roll(fs=fs)  # type: ignore
                    num_bar = self.sliced_and_save_pianoroll(
                        pianoroll, downbeats, fs, num_bar)
            except KeyError:
                num_error += 1
        print("total number of file : ", len(self.midfiles))
        print('num error : ', num_error)


################################# MIDI-like ##################################

class Midilike(Representation):

    def __init__(self, root_dir, nbframe_per_bar=16, mono=False, export=False):
        super().__init__(root_dir, nbframe_per_bar=nbframe_per_bar, mono=mono, export=export)
        # One hot encoding of the vocabulary
        self.vocabulary = self.get_vocab_encoding()
        # Path witch contains the sliced piano-roll
        self.prbar_path = root_dir + "/MIDIlike_bar"
        if not os.path.exists(self.prbar_path):
            try:
                os.mkdir(self.prbar_path)
            except OSError:
                print("Creation of the directory %s failed" % self.prbar_path)
            else:
                print("Successfully created the directory %s " %
                      self.prbar_path)
            # Export the piano-roll bat
            self.per_bar_export()
        else:
            if export:
                self.per_bar_export()
        # .pt files names
        self.barfiles = [fname for fname in os.listdir(
            self.prbar_path) if fname.endswith('.pbz2')]
        # total number of bars
        self.nb_bars = len(self.barfiles)

    def get_vocab_encoding(self):
        """
        Return a dictionnary with the corresponding indexes of every word contained in
        the vocabulary (one hot encoding).

        e.g : vocab = {'NOTE_ON<1>' : 0
                       'NOTE_ON<2>' : 1
                           ...
                       'TIME_SHIFT<1> : 128
                           ...             }
        """
        vocab = {}
        current_ind = 0

        # All the NOTE_ON events
        rootstr = "NOTE_ON<"
        for i in range(0, 128):
            event = rootstr + str(i) + '>'
            vocab[event] = current_ind
            current_ind += 1
        # All the NOTE_OFF events
        rootstr = "NOTE_OFF<"
        for i in range(0, 128):
            event = rootstr + str(i) + '>'
            vocab[event] = current_ind
            current_ind += 1
        # All the TIME_SHIFT events
        rootstr = "TIME_SHIFT<"
        for i in range(10, 1001, 10):
            event = rootstr + str(i) + '>'
            vocab[event] = current_ind
            current_ind += 1
        # All the SET_VELOCITY events
        rootstr = "SET_VELOCITY<"
        for i in range(0, 128, 4):
            event = rootstr + str(i) + '>'
            vocab[event] = current_ind
            current_ind += 1
        # The NOTHING event
        vocab['NOTHING'] = current_ind

        return vocab

    def string_representation(self, Vinst):
        """
        Return the representation with string ("NOTE_ON<56>, ...) from the
        corresponding integer representation Vinst (list of int)
        """
        str_rep = []
        for i in Vinst:
            str_rep.append(list(self.vocabulary.keys())[
                           list(self.vocabulary.values()).index(int(i))])

        return str_rep

    def per_bar_export(self):
        """
        This function take all the midi files, load them into a pretty_midi object.
        For a complete documentation of pretty_midi go to :
            http://craffel.github.io/pretty-midi/

        The midi file is then processed to obtain a MIDI-like event-based representation.
        More info on this representation here :
            https://arxiv.org/pdf/1809.04281.pdf

        Ex :  SET_VELOCITY<80>, NOTE_ON<60>
              TIME_SHIFT<500>, NOTE_ON<64>
              TIME_SHIFT<500>, NOTE_ON<67>
              TIME_SHIFT<1000>, NOTE_OFF<60>, NOTE_OFF<64>, NOTE_OFF<67>
              TIME_SHIFT<500>, SET_VELOCITY<100>, NOTE_ON<65>
              TIME_SHIFT<500>, NOTE_OFF<65>

        Finally, it will export each of theses bars in a separate .pt
        """
        # number of error with the key analyzer
        num_error = 0
        # array of all the tensor representing each bar
        all_bars = []
        # load each .mid file in a pretty_midi object
        for index in range(len(self.midfiles)):
            try:
                midi_data = pretty_midi.PrettyMIDI(
                    self.rootdir + '/' + self.midfiles[index])
                # start_time = midi_data.estimate_beat_start()
                downbeats = midi_data.get_downbeats()
                current_velocity = 64.
                # Possible value for the velocity
                velocity_list = [i for i in range(0, 128, 4)]
                # Possible value for the time shifts
                timeshift_list = [i for i in range(10, 1001, 10)]
                for i in range(len(downbeats)-1):
                    list_notes = []
                    V = []
                    for inst in midi_data.instruments:
                        if not inst.is_drum:
                            for n in inst.notes:
                                if (n.start < downbeats[i+1] and n.end >= downbeats[i]):
                                    list_notes.append(n)
                    # Sort list by pitch
                    list_notes = sorted(
                        list_notes, key=lambda i: i.pitch, reverse=False)
                    if len(list_notes) == 0:
                        gap = (downbeats[i+1] - downbeats[i])*1000
                        while (gap > 1000):
                            V.append(self.vocabulary['TIME_SHIFT<1000>'])
                            gap = gap - 1000
                        timeshift = takeClosest(timeshift_list, gap)
                        V.append(
                            self.vocabulary['TIME_SHIFT<' + str(timeshift) + '>'])
                    else:
                        # iterate over list_notes to construct the representation
                        current_time = downbeats[i]
                        while (list_notes):
                            closest_note_on = min(
                                list_notes, key=attrgetter('start'))
                            closest_note_off = min(
                                list_notes, key=attrgetter('end'))
                            if closest_note_off.end > closest_note_on.start:
                                gap = (closest_note_on.start -
                                       current_time)*1000
                                if gap > timeshift_list[0]/2:
                                    while (gap > 1000):
                                        V.append(
                                            self.vocabulary['TIME_SHIFT<1000>'])
                                        gap = gap - 1000
                                    timeshift = takeClosest(
                                        timeshift_list, gap)
                                    V.append(
                                        self.vocabulary['TIME_SHIFT<' + str(timeshift) + '>'])
                                if takeClosest(velocity_list, closest_note_on.velocity) != current_velocity:
                                    veloc = takeClosest(
                                        velocity_list, closest_note_on.velocity)
                                    V.append(
                                        self.vocabulary['SET_VELOCITY<' + str(veloc) + '>'])
                                    current_velocity = veloc
                                V.append(
                                    self.vocabulary['NOTE_ON<' + str(closest_note_on.pitch) + '>'])
                                if closest_note_on.start > current_time:
                                    current_time = closest_note_on.start
                                if closest_note_on.end > downbeats[i+1]:
                                    list_notes.remove(closest_note_on)
                                else:
                                    # Set a value > end to start to not taking it in account anymore
                                    closest_note_on.start = closest_note_on.end + 10
                            else:
                                gap = (closest_note_off.end -
                                       current_time)*1000
                                if gap > timeshift_list[0]/2:
                                    while (gap > 1000):
                                        V.append(
                                            self.vocabulary['TIME_SHIFT<1000>'])
                                        gap = gap - 1000
                                    timeshift = takeClosest(
                                        timeshift_list, gap)
                                    V.append(
                                        self.vocabulary['TIME_SHIFT<' + str(timeshift) + '>'])
                                V.append(
                                    self.vocabulary['NOTE_OFF<' + str(closest_note_off.pitch) + '>'])
                                current_time = closest_note_off.end
                                list_notes.remove(closest_note_off)
                    # Store the tensor in all_bars
                    all_bars.append(torch.tensor(V))
            except KeyError:
                num_error += 1
        print('num error : ', num_error)
        # Cleaning of the tensor : supressing ones with more than 160 events
        # and padding to have a constant size equal to 160
        empty_bar = False
        total_num = 0
        for i, vec in enumerate(all_bars):
            # add the empty bar only one time
            if len(vec) == 1:
                if not empty_bar:
                    clean_vec = torch.tensor([self.vocabulary['NOTHING']]*64)
                    clean_vec[0] = vec
                    # torch.save(clean_vec.unsqueeze(
                    #     1), self.prbar_path + "/Mlikebar_" + str(i) + ".pt")
                    with bz2.BZ2File(f"{self.prbar_path}/Mlikebar_{str(i)}.pbz2", "w") as filepath:
                        cPickle.dump(clean_vec.unsqueeze(1), filepath)
                    empty_bar = True
                    total_num += 1
            elif len(vec) < 64:
                clean_vec = torch.tensor([self.vocabulary['NOTHING']]*64)
                clean_vec[:len(vec)] = vec
                # torch.save(clean_vec.unsqueeze(1), self.prbar_path +
                #            "/Mlikebar_" + str(i) + ".pt")
                with bz2.BZ2File(f"{self.prbar_path}/Mlikebar_{str(i)}.pbz2", "w") as filepath:
                    cPickle.dump(clean_vec.unsqueeze(1), filepath)
                total_num += 1
            elif len(vec) == 64:
                # torch.save(vec.unsqueeze(1), self.prbar_path +
                #            "/Mlikebar_" + str(i) + ".pt")
                with bz2.BZ2File(f"{self.prbar_path}/Mlikebar_{str(i)}.pbz2", "w") as filepath:
                    cPickle.dump(vec.unsqueeze(1), filepath)
                total_num += 1
        print("Initial number of bar : {}\n \
               After cleaning : {}\n \
               Number of suppression : {}".format(len(all_bars), total_num, len(all_bars) - total_num))


############################### MIDI-like mono ###############################

class Midimono(Representation):

    def __init__(self, root_dir, nbframe_per_bar=16, mono=True, export=False):
        super().__init__(root_dir, nbframe_per_bar=nbframe_per_bar, mono=mono, export=export)
        # Path witch contains the sliced piano-roll
        self.prbar_path = root_dir + "/MIDIMono_bar"
        if not os.path.exists(self.prbar_path):
            try:
                os.mkdir(self.prbar_path)
            except OSError:
                print("Creation of the directory %s failed" % self.prbar_path)
            else:
                print("Successfully created the directory %s " %
                      self.prbar_path)
            # Export the piano-roll bat
            self.per_bar_export()
        else:
            if export:
                self.per_bar_export()
        # .pt files names
        self.barfiles = [fname for fname in os.listdir(
            self.prbar_path) if fname.endswith('.pbz2')]
        # total number of bars
        self.nb_bars = len(self.barfiles)

    def get_polyphonic_bars(self, pr_dataset):

        indices = set()
        for i in range(len(pr_dataset)):
            for j, frame in enumerate(pr_dataset[i]):
                if frame.nonzero().nelement() > 1:
                    indices.add(i)

        return indices

    def to_pianoroll(self, v):

        pianoroll = torch.zeros(16, 128)
        current_note = -1
        for i, e in enumerate(v):
            if e < 128:
                pianoroll[i, int(e)] = 1
                current_note = int(e)
            elif e == 128:
                if current_note != 129:
                    pianoroll[i, current_note] = 1

        return pianoroll

    def per_bar_export(self):

        PR = Pianoroll(self.rootdir, nbframe_per_bar=16, mono=True)
        poly_bars = self.get_polyphonic_bars(PR)
        num_vec = 0
        for i in range(len(PR)):
            if i not in poly_bars:
                vec = torch.zeros(16)
                current_note = -1
                for j, frame in enumerate(PR[i]):
                    if frame.nonzero().nelement() == 0:
                        if current_note != 129 and current_note != -1:
                            # note_off event
                            vec[j] = 129
                            current_note = 129
                        else:
                            # rest event
                            vec[j] = 128
                            if current_note == -1:
                                current_note = 129
                    else:
                        if current_note == int(frame.nonzero()):
                            # rest event
                            vec[j] = 128
                        else:
                            # note_on event
                            vec[j] = int(frame.nonzero())
                            current_note = int(frame.nonzero())
                # Save the tensor
                # torch.save(vec.unsqueeze(1), self.prbar_path +
                #            "/MVAEbar_" + str(num_vec) + ".pt")
                with bz2.BZ2File(f"{self.prbar_path}/MVAEbar_{str(num_vec)}.pbz2", "w") as filepath:
                    cPickle.dump(vec.unsqueeze(1), filepath)
                num_vec += 1


################################# NoteTuple ##################################

class Notetuple(Representation):

    def __init__(self, root_dir, nbframe_per_bar=16, mono=False, export=False):
        super().__init__(root_dir, nbframe_per_bar=nbframe_per_bar, mono=mono, export=export)
        # vocabs
        self.vocabs = self.get_vocabs_encoding()
        self.ts_major = self.vocabs[0]
        self.ts_minor = self.vocabs[1]
        self.dur_major = self.vocabs[2]
        self.dur_minor = self.vocabs[3]
        # Path witch contains the sliced piano-roll
        self.prbar_path = root_dir + "/NoteTuple_bar"
        if not os.path.exists(self.prbar_path):
            try:
                os.mkdir(self.prbar_path)
            except OSError:
                print("Creation of the directory %s failed" % self.prbar_path)
            else:
                print("Successfully created the directory %s " %
                      self.prbar_path)
            # Export the bar
            self.per_bar_export()
        else:
            if export:
                self.per_bar_export()
        # .pt files names
        self.barfiles = [fname for fname in os.listdir(
            self.prbar_path) if fname.endswith('.pbz2')]
        # total number of bars
        self.nb_bars = len(self.barfiles)

    def get_vocabs_encoding(self):
        # timeshift major_ticks_vocab
        ts_major = {}
        ind = 0
        for val in [i for i in range(0, 9601, 800)]:
            ts_major[val] = ind
            ind += 1
        ts_major[-1] = ind
        # timeshift minor_ticks_vocab
        ts_minor = {}
        ind = 0
        for val in [i for i in range(0, 800, 10)]:
            ts_minor[val] = ind
            ind += 1
        ts_minor[-1] = ind
        # duration major_ticks_vocab
        dur_major = {}
        ind = 0
        for val in [i for i in range(0, 9501, 500)]:
            dur_major[val] = ind
            ind += 1
        dur_major[-1] = ind
        # duration minor_ticks_vocab
        dur_minor = {}
        ind = 0
        for val in [i for i in range(0, 500, 10)]:
            dur_minor[val] = ind
            ind += 1
        dur_minor[-1] = ind
        return ts_major, ts_minor, dur_major, dur_minor

    def value_to_class(self, bar):
        # Change the value of the timeshift and duration to a class number
        for i, tupl in enumerate(bar):
            for j, v in enumerate(tupl):
                if j == 0:
                    bar[i][j] = self.ts_major[int(v)]
                if j == 1:
                    bar[i][j] = self.ts_minor[int(v)]
                if j == 2:
                    if v == -1:
                        bar[i][j] = 128
                if j == 3:
                    bar[i][j] = self.dur_major[int(v)]
                if j == 4:
                    bar[i][j] = self.dur_minor[int(v)]
        return bar

    def class_to_value(self, bar):
        # Change the class number of the timeshift and duration to the real value
        for i, tupl in enumerate(bar):
            for j, v in enumerate(tupl):
                if j == 0:
                    bar[i][j] = list(self.ts_major.keys())[
                        list(self.ts_major.values()).index(int(v))]
                if j == 1:
                    bar[i][j] = list(self.ts_minor.keys())[
                        list(self.ts_minor.values()).index(int(v))]
                if j == 2:
                    if v == 128:
                        bar[i][j] = -1
                if j == 3:
                    bar[i][j] = list(self.dur_major.keys())[
                        list(self.dur_major.values()).index(int(v))]
                if j == 4:
                    bar[i][j] = list(self.dur_minor.keys())[
                        list(self.dur_minor.values()).index(int(v))]
        return bar

    def per_bar_export(self):
        num_error = 0
        # to store all the bars
        all_bars = []
        for index in range(len(self.midfiles)):
            try:
                # load each .mid file in a pretty_midi object
                midi_data = pretty_midi.PrettyMIDI(
                    self.rootdir + '/' + self.midfiles[index])
                # start_time = midi_data.estimate_beat_start()
                downbeats = midi_data.get_downbeats()
                # Possible value for the time shifts (from 0 to 10s)
                # 13 major ticks
                timeshift_major_ticks = [i for i in range(0, 9601, 800)]
                # 77 minor ticks
                timeshift_minor_ticks = [i for i in range(0, 800, 10)]
                # Possible value for the duration
                # 13 major ticks
                dur_major_ticks = [i for i in range(0, 9501, 500)]
                # 77 minor ticj=ks
                dur_minor_ticks = [i for i in range(0, 500, 10)]
                for i in range(len(downbeats)-1):
                    list_notes = []
                    V = []
                    for inst in midi_data.instruments:
                        if not inst.is_drum:
                            for n in inst.notes:
                                if (n.start < downbeats[i+1] and n.start >= downbeats[i]):
                                    list_notes.append(n)
                    list_notes = sorted(
                        list_notes, key=lambda i: i.pitch, reverse=False)
                    # iterate over list_notes to construct the representation
                    current_time = downbeats[i]
                    while (list_notes):
                        closest_note_on = min(
                            list_notes, key=attrgetter('start'))
                        time_shift = (closest_note_on.start -
                                      current_time)*1000
                        tmat = timeshift_major_ticks[int(time_shift//800)]
                        tmit = timeshift_minor_ticks[int(
                            (time_shift % 800)//10)]
                        duration = (closest_note_on.end -
                                    closest_note_on.start)*1000
                        dmat = dur_major_ticks[int(duration//500)]
                        dmit = dur_minor_ticks[int((duration % 500)//10)]
                        current_time = closest_note_on.start
                        V.append(
                            (tmat, tmit, closest_note_on.pitch, dmat, dmit))
                        list_notes.remove(closest_note_on)
                    # Store the tensor in all_bars
                    all_bars.append(torch.tensor(V))
            except KeyError:
                num_error += 1
        print('num error : ', num_error)
        # Save all tensor
        total_num = 0
        for i, vec in enumerate(all_bars):
            if len(vec) < 32 and len(vec) > 0:
                clean_vec = torch.zeros(32, 5).fill_(-1)
                clean_vec[:len(vec)] = vec
                clean_vec = self.value_to_class(clean_vec)
                # torch.save((clean_vec, len(vec)), self.prbar_path +
                #            "/Ntuplebar" + str(i) + ".pt")
                with bz2.BZ2File(f"{self.prbar_path}/Ntuplebar{str(i)}.pbz2", "w") as filepath:
                    cPickle.dump((clean_vec, len(vec)), filepath)
                total_num += 1
            elif len(vec) == 32:
                vec = self.value_to_class(vec)
                # torch.save((clean_vec, len(vec)), self.prbar_path +
                #            "/Ntuplebar" + str(i) + ".pt")
                with bz2.BZ2File(f"{self.prbar_path}/Ntuplebar{str(i)}.pbz2", "w") as filepath:
                    cPickle.dump((clean_vec, len(vec)),  # type: ignore
                                 filepath)
                total_num += 1
        print("Initial number of bar : {}\n \
               After cleaning : {}\n \
               Number of suppression : {}".format(len(all_bars), total_num, len(all_bars) - total_num))


################################# Signal-like ################################

class Signallike(Representation):

    def __init__(self, root_dir, nbframe_per_bar=16, mono=False, export=False):
        super().__init__(root_dir, nbframe_per_bar=nbframe_per_bar, mono=mono, export=export)
        # Path to export the .pt files
        if self.mono:
            self.prbar_path = root_dir + \
                "/Signallike_bar_mono_" + str(self.nbframe_per_bar)
        else:
            self.prbar_path = root_dir + \
                "/Signallike_bar_" + str(self.nbframe_per_bar)
        if not os.path.exists(self.prbar_path):
            try:
                os.mkdir(self.prbar_path)
            except OSError:
                print("Creation of the directory %s failed" % self.prbar_path)
            else:
                print("Successfully created the directory %s " %
                      self.prbar_path)
            # Export the piano-roll bat
            self.per_bar_export()
        else:
            if export:
                self.per_bar_export()
        # .pt files names
        self.barfiles = [fname for fname in os.listdir(
            self.prbar_path) if fname.endswith('.pbz2')]
        # total number of bars
        self.nb_bars = len(self.barfiles)
        # Size of the signal representation
        self.signal_size = len(self.__getitem__(0).flatten())

    def back_to_pianoroll(self, V):
        """
        Inverse the process : get a piano-roll from a signal-like representation V.
        """
        PR = ((np.abs(librosa.core.stft(V, n_fft=2048, window='blackman'))))[
            PRIMES[:128]]
        return abs(PR)

    def get_polyphonic_bars(self, pr_dataset):
        indices = set()
        for i in range(len(pr_dataset)):
            for j, frame in enumerate(pr_dataset[i]):
                if frame.nonzero().nelement() > 1:
                    indices.add(i)
        return indices

    def per_bar_export(self):
        """
        This function take the self.midfiles[index], load it into a pretty_midi object.
        For a complete documentation of pretty_midi go to :
            http://craffel.github.io/pretty-midi/

        The midi file is then processed with stft to obtain a signal-like representation
        and exported in a .pt file.
        """
        PR = Pianoroll(
            self.rootdir, nbframe_per_bar=self.nbframe_per_bar, mono=self.mono)
        if self.mono:
            poly_bars = self.get_polyphonic_bars(PR)
        else:
            poly_bars = []
        for i in range(len(PR)):
            if i not in poly_bars:
                final_vals = np.zeros(
                    (1025, PR[i].permute(1, 0).shape[1])).astype(complex)
                final_vals[PRIMES[:128], :] = np.array(PR[i].permute(
                    1, 0)) + 1j * ((np.array(PR[i].permute(1, 0)) > 0))
                V = torch.Tensor(librosa.core.istft(
                    final_vals, window='blackman'))
                # torch.save(V.reshape(64, -1), self.prbar_path +
                #            "/Slikebar_" + str(i) + ".pt")
                with bz2.BZ2File(f"{self.prbar_path}/Slikebar_{str(i)}.pbz2", "w") as filepath:
                    cPickle.dump(V.reshape(64, -1), filepath)

################################# DFT-128 ################################


def dft_reduction(data, normalize=False, return_complex=False, only_dft=False):
    """GET DFT"""
    dft = np.fft.fft(data)

    if normalize:
        # GET ENERGY
        energy = dft[0].real
        # REDUCE AND NORMALIZE DFT
        reduced_dft = dft[1: int(len(dft) / 2.0) + 1]
        norm_dft = [df / energy for df in reduced_dft]
        # GET MAGNITUDE
        mag = [abs(CP) for CP in norm_dft]
    else:
        norm_dft = dft
        energy = dft[0].real
        mag = [abs(CP) for CP in norm_dft]

    if return_complex:
        if only_dft:
            return norm_dft
        return norm_dft, energy, mag

    real_dft = []
    for complex_coefficient in norm_dft:
        real_dft.append(complex_coefficient.real)
        real_dft.append(complex_coefficient.imag)

    if only_dft:
        return real_dft
    # RETURN
    return real_dft, energy, mag


class DFT128(Representation):
    def __init__(self, root_dir, nbframe_per_bar=16, mono=False, export=False, use_symmetry=True):
        super().__init__(root_dir, nbframe_per_bar=nbframe_per_bar, mono=mono, export=export)

        self.prbar_path = root_dir + "/DFT128_bar" + str(self.nbframe_per_bar)
        self.use_symmetry = use_symmetry

        if not os.path.exists(self.prbar_path):
            try:
                os.mkdir(self.prbar_path)
            except OSError:
                print("Creation of the directory %s failed" % self.prbar_path)
            else:
                print("Successfully created the directory %s " %
                      self.prbar_path)

            # Export the piano-roll bat
            self.per_bar_export()
        elif export:
            self.per_bar_export()

        # .pt files names
        self.barfiles = [fname for fname in os.listdir(
            self.prbar_path) if fname.endswith('.pbz2')]
        # total number of bars
        self.nb_bars = len(self.barfiles)
        # Size of the signal representation
        self.signal_size = len(self.__getitem__(0).flatten())

    def back_to_pianoroll(self, V):
        """
        Inverse the process : get a piano-roll from a DFT128 representation V.
        """
        return None

    def per_bar_export(self):
        """
        This function take the self.midfiles[index], load it into a pretty_midi object.
                For a complete documentation of pretty_midi go to :
            http://craffel.github.io/pretty-midi/

        The midi file is then processed with DFT encodings to obtain a DFT128 representation
        """
        PR = Pianoroll(
            self.rootdir, nbframe_per_bar=self.nbframe_per_bar, mono=self.mono)
        for i in range(len(PR)):
            dft_results = np.apply_along_axis(
                dft_reduction, 1, PR[i].numpy(), only_dft=True)
            if self.use_symmetry:
                dft_results = torch.Tensor(dft_results[:, :65*2])
            else:
                dft_results = torch.Tensor(dft_results)
            with bz2.BZ2File(f"{self.prbar_path}/DFT128bar_{str(i)}.pbz2", "w") as filepath:
                cPickle.dump(dft_results, filepath)

# Training Network

In [19]:
arguments = {
    '--path': '/kaggle/input/jsb-chorales-signallike-embeddings/dataset',
    #'--save_dt': '/kaggle/working/dataset',
    '--o': '/kaggle/working/output',
    '--runname': 'DFT128_01',
    '--save': True,

    '--gpu': True if torch.cuda.is_available() else False,
    '--mps': True if torch.backends.mps.is_available() else False,
    '--gpudev': 0,

    '--lr': 1e-4,
    '--bsize': 16,
    '--nbframe': 64,
    '--inputrep': 'dft128',
    '--epochs': 10,
}

print(arguments)

{'--path': '/kaggle/input/jsb-chorales-signallike-embeddings/dataset', '--o': '/kaggle/working/output', '--runname': 'DFT128_01', '--save': True, '--gpu': True, '--mps': False, '--gpudev': 0, '--lr': 0.0001, '--bsize': 16, '--nbframe': 64, '--inputrep': 'dft128', '--epochs': 10}


In [20]:
import os
import time

import lightning as L
import torch

from docopt import docopt
from tqdm import tqdm

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Set detect anomaly
torch.autograd.set_detect_anomaly(True)  # type: ignore

# Parameters
train_path = arguments['--path'] + '/train'
test_path = arguments['--path'] + '/test'
batch_size = int(arguments['--bsize'])
nb_frame = int(arguments['--nbframe'])
if arguments['--o'] == 'None':
    output_dr = os.getcwd() + '/output'
else:
    output_dr = arguments['--o']

# load the dataset
if arguments['--inputrep'] == "pianoroll":
    dataset = Pianoroll(train_path, nbframe_per_bar=nb_frame)
    testset = Pianoroll(test_path, nbframe_per_bar=nb_frame)
    input_dim = 128
    seq_length = nb_frame
elif arguments['--inputrep'] == "midilike":
    dataset = Midilike(train_path)
    testset = Midilike(test_path)
    input_dim = 1
elif arguments['--inputrep'] == "midimono":
    dataset = Midimono(train_path)
    testset = Midimono(test_path)
    input_dim = 1
elif arguments['--inputrep'] == "signallike":
    dataset = Signallike(
        train_path, nbframe_per_bar=nb_frame*2, mono=True)
    testset = Signallike(
        test_path, nbframe_per_bar=nb_frame*2, mono=True)
    input_dim = dataset.signal_size//64
elif arguments['--inputrep'] == "notetuple":
    dataset = Notetuple(train_path)
    testset = Notetuple(test_path)
    input_dim = 5
elif arguments['--inputrep'] == "dft128":
    dataset = DFT128(train_path, nbframe_per_bar=nb_frame)
    testset = DFT128(test_path, nbframe_per_bar=nb_frame)
    input_dim = 130
    seq_length = nb_frame
else:
    raise NotImplementedError(
        "Representation {} not implemented".format(arguments['--inputrep']))

# Init the dataloader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=2,  # type: ignore
                                          pin_memory=True, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, num_workers=2,  # type: ignore
                                          pin_memory=True, shuffle=False, drop_last=True)

In [21]:
# Model parameters
enc_hidden_size = 1024
cond_hidden_size = 1024
dec_hidden_size = 1024
cond_outdim = 512
num_layers_enc = 2
num_layers_dec = 2
num_subsequences = 4
latent_size = 256

if arguments['--inputrep'] in ['pianoroll', 'signallike', 'dft128']:
    output_dim = input_dim
elif arguments['--inputrep'] == "midilike":
    output_dim = len(dataset.vocabulary)  # type: ignore
    seq_length = 64
elif arguments['--inputrep'] == "midimono":
    output_dim = 130
    seq_length = 16
elif arguments['--inputrep'] == "notetuple":
    output_dim = sum([len(v) for v in
                      dataset.vocabs]) + 129  # type: ignore
    seq_length = 32

device = 'cpu'
if arguments['--gpu'] and torch.cuda.is_available():  # type: ignore
    device = 'cuda'
elif arguments['--mps'] and torch.backends.mps.is_available() and torch.backends.mps.is_built():  # type: ignore
    device = 'mps'

# Instanciate model
encoder = Encoder_RNN(input_dim, enc_hidden_size,
                        latent_size, num_layers_enc, device=device)
decoder = Decoder_RNN_hierarchical(output_dim, latent_size, cond_hidden_size,  # type: ignore
                                     cond_outdim, dec_hidden_size=dec_hidden_size, num_layers=num_layers_dec,
                                     num_subsequences=num_subsequences, seq_length=seq_length)  # type: ignore

if arguments['--inputrep'] == "notetuple":
    model = LightningVAE(encoder, decoder, arguments['--inputrep'],
                           vocab=dataset.vocabs)  # type: ignore
else:
    model = LightningVAE(encoder, decoder, arguments['--inputrep'])

In [22]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 25 13:41:24 2020

@author: prang
"""

import random
from typing import Any

import lightning as L
import torch  # type: ignore
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn  # type: ignore


class Encoder_RNN(nn.Module):

    def __init__(self, input_dim, hidden_size, latent_size, num_layers,
                 dropout=0.5, packed_seq=False, device='cpu'):
        """ This initializes the encoder """
        super(Encoder_RNN, self).__init__()

        # Parameters
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.packed_seq = packed_seq
        self.batch_first = True
        self.device = device

        # Layers
        self.RNN = nn.LSTM(input_dim, hidden_size, batch_first=self.batch_first,
                           num_layers=num_layers, bidirectional=True,
                           dropout=dropout)

    def forward(self, x, h0, c0, batch_size):

        # Pack sequence if needed
        if self.packed_seq:
            x = torch.nn.utils.rnn.pack_padded_sequence(x[0], x[1],
                                                        batch_first=self.batch_first,
                                                        enforce_sorted=False)
        # Forward pass
        _, (h, _) = self.RNN(x, (h0, c0))

        # Be sure to not have NaN values
        assert ((h == h).all()), 'NaN value in the output of the RNN, try to \
                                lower your learning rate'
        h = h.view(self.num_layers, 2, batch_size, -1)
        h = h[-1]
        h = torch.cat([h[0], h[1]], dim=1)

        return h

    def init_hidden(self, batch_size=1):
        # Bidirectional -> num_layers * 2
        return (torch.zeros(self.num_layers * 2, batch_size, self.hidden_size,
                            dtype=torch.float, device=self.device),) * 2


class Decoder_RNN_hierarchical(nn.Module):

    def __init__(self, input_size, latent_size, cond_hidden_size, cond_outdim,
                 dec_hidden_size, num_layers, num_subsequences, seq_length,
                 teacher_forcing_ratio=0, dropout=0.5):
        """ This initializes the decoder """
        super(Decoder_RNN_hierarchical, self).__init__()

        # Parameters
        self.num_subsequences = num_subsequences
        self.input_size = input_size
        self.num_layers = num_layers
        self.seq_length = seq_length
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.subseq_size = self.seq_length // self.num_subsequences

        # Layers
        self.tanh = nn.Tanh()
        self.fc_init_cond = nn.Linear(
            latent_size, cond_hidden_size * num_layers)
        self.conductor_RNN = nn.LSTM(latent_size // num_subsequences, cond_hidden_size,
                                     batch_first=True, num_layers=num_layers,
                                     bidirectional=False, dropout=dropout)
        self.conductor_output = nn.Linear(cond_hidden_size, cond_outdim)
        self.fc_init_dec = nn.Linear(cond_outdim, dec_hidden_size * num_layers)
        self.decoder_RNN = nn.LSTM(cond_outdim + input_size, dec_hidden_size,
                                   batch_first=True, num_layers=num_layers,
                                   bidirectional=False, dropout=dropout)
        self.decoder_output = nn.Linear(dec_hidden_size, input_size)

    def forward(self, latent, target, batch_size, teacher_forcing, device):

        # Get the initial state of the conductor
        h0_cond = self.tanh(self.fc_init_cond(latent))
        h0_cond = h0_cond.view(self.num_layers, batch_size, -1).contiguous()
        # Divide the latent code in subsequences
        latent = latent.view(batch_size, self.num_subsequences, -1)
        # Pass through the conductor
        subseq_embeddings, _ = self.conductor_RNN(latent, (h0_cond,)*2)
        subseq_embeddings = self.conductor_output(subseq_embeddings)

        # Get the initial states of the decoder
        h0s_dec = self.tanh(self.fc_init_dec(subseq_embeddings))
        h0s_dec = h0s_dec.view(self.num_layers, batch_size,
                               self.num_subsequences, -1).contiguous()
        # Init the output seq and the first token to 0 tensors
        out = torch.zeros(batch_size, self.seq_length, self.input_size,
                          dtype=torch.float, device=device)
        token = torch.zeros(batch_size, self.subseq_size, self.input_size,
                            dtype=torch.float, device=device)
        # Autoregressivly output tokens
        for sub in range(self.num_subsequences):
            subseq_embedding = subseq_embeddings[:, sub, :].unsqueeze(1)
            subseq_embedding = subseq_embedding.expand(
                -1, self.subseq_size, -1)
            h0_dec = h0s_dec[:, :, sub, :].contiguous()
            c0_dec = h0s_dec[:, :, sub, :].contiguous()
            # Concat the previous token and the current sub embedding as input
            dec_input = torch.cat((token, subseq_embedding), -1)
            # Pass through the decoder
            token, (h0_dec, c0_dec) = self.decoder_RNN(
                dec_input, (h0_dec, c0_dec))
            token = self.decoder_output(token)
            # Fill the out tensor with the token
            out[:, sub*self.subseq_size: ((sub+1)*self.subseq_size), :] = token
            # If teacher_forcing replace the output token by the real one sometimes
            if teacher_forcing:
                if random.random() <= self.teacher_forcing_ratio:
                    token = target[:, sub *
                                   self.subseq_size: ((sub+1)*self.subseq_size), :]
        return out


class VAE(nn.Module):

    def __init__(self, encoder, decoder, input_representation, teacher_forcing=True, device='cpu'):
        super(VAE, self).__init__()
        """ This initializes the complete VAE """

        # Parameters
        self.input_rep = input_representation
        self.tf = teacher_forcing
        self.encoder = nn.ModuleList(encoder)
        self.decoder = nn.ModuleList(decoder)
        self.device = device

        # Layers
        self.hidden_to_mu = nn.Linear(
            2 * encoder.hidden_size, encoder.latent_size)
        self.hidden_to_sig = nn.Linear(
            2 * encoder.hidden_size, encoder.latent_size)

    def forward(self, x):

        if self.input_rep == 'notetuple':
            batch_size = x[0].size(0)
        else:
            batch_size = x.size(0)

        # Encoder pass
        h_enc, c_enc = self.encoder.init_hidden(batch_size)  # type: ignore
        hidden = self.encoder(x, h_enc, c_enc, batch_size)
        # Reparametrization
        mu = self.hidden_to_mu(hidden)
        sig = self.hidden_to_sig(hidden)
        eps = torch.randn_like(mu).detach().to(self.device)
        latent = (sig.exp().sqrt() * eps) + mu

        # Decoder pass
        if self.input_rep == 'midilike':
            # One hot encoding of the target for teacher forcing purpose
            target = torch.nn.functional.one_hot(x.squeeze(2).long(),
                                                 self.input_size).float()
            x_reconst = self.decoder(latent, target, batch_size,
                                     teacher_forcing=self.tf, device=self.device)
        else:
            x_reconst = self.decoder(latent, x, batch_size,
                                     teacher_forcing=self.tf, device=self.device)

        return mu, sig, latent, x_reconst

    def batch_pass(self, x, loss_fn, optimizer, w_kl, test=False):

        # Zero grad
        self.zero_grad()

        # Forward pass
        mu, sig, latent, x_reconst = self(x)

        # Compute losses
        kl_div = - 0.5 * torch.sum(1 + sig - mu.pow(2) - sig.exp())
        if self.input_rep in ["midilike", "MVAErep"]:
            reconst_loss = loss_fn(x_reconst.permute(
                0, 2, 1), x.squeeze(2).long())
        elif self.input_rep == "notetuple":
            x_reconst = x_reconst.permute(0, 2, 1)
            x_in, l = x
            loss_ts_maj = loss_fn(
                x_reconst[:, :len(self.vocab[0]), :],  # type: ignore
                x_in[:, :, 0].long())
            current = len(self.vocab[0])  # type: ignore

            loss_ts_min = loss_fn(
                x_reconst[:, current:current +
                          len(self.vocab[1]), :],  # type: ignore
                x_in[:, :, 1].long())
            current += len(self.vocab[1])  # type: ignore

            loss_pitch = loss_fn(
                x_reconst[:, current:current + 129, :], x_in[:, :, 2].long())
            current += 129

            loss_dur_maj = loss_fn(
                x_reconst[:, current:current +
                          len(self.vocab[2]), :],  # type: ignore
                x_in[:, :, 3].long())
            current += len(self.vocab[2])  # type: ignore

            loss_dur_min = loss_fn(
                x_reconst[:, current:current +
                          len(self.vocab[3]), :],  # type: ignore
                x_in[:, :, 4].long())
            reconst_loss = loss_ts_maj + loss_ts_min + \
                loss_pitch + loss_dur_maj + loss_dur_min
        else:
            reconst_loss = loss_fn(x_reconst, x)

        # Backprop and optimize
        if not test:
            loss = reconst_loss + (w_kl * kl_div)
            loss.backward()
            optimizer.step()
        else:
            loss = reconst_loss + kl_div

        return loss, kl_div, reconst_loss

    def generate(self, latent):

        # Create dumb target
        input_shape = (1, self.decoder.seq_length, self.decoder.input_size)
        db_trg = torch.zeros(input_shape)  # type: ignore
        # Forward pass in the decoder
        generated_bar = self.decoder(latent.unsqueeze(0), db_trg, batch_size=1,
                                     device=self.device, teacher_forcing=False)

        return generated_bar


class LightningVAE(L.LightningModule):

    def __init__(self, encoder, decoder, input_representation, vocab=None, teacher_forcing=True):
        super(LightningVAE, self).__init__()
        """ This initializes the complete VAE """
        # Parameters
        self.input_rep = input_representation
        self.tf = teacher_forcing
        self.encoder = encoder
        self.decoder = decoder

        self.w_kl = 0

        self.vocab = vocab
        if input_representation == 'notetuple' and vocab is None:
            raise ValueError(
                'Vocab must be provided for notetuple input representation')

        # Layers
        self.hidden_to_mu = nn.Linear(
            2 * encoder.hidden_size, encoder.latent_size)
        self.hidden_to_sig = nn.Linear(
            2 * encoder.hidden_size, encoder.latent_size)

        if input_representation in ['pianoroll', 'signallike']:
            self.loss_fn = torch.nn.MSELoss(reduction='sum')
        else:
            self.loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')

        self.save_hyperparameters(ignore=['encoder', 'decoder'])

    def forward(self, x):

        if self.input_rep == 'notetuple':
            batch_size = x[0].size(0)
        else:
            batch_size = x.size(0)

        # Encoder pass
        h_enc, c_enc = self.encoder.init_hidden(batch_size)
        hidden = self.encoder(x, h_enc, c_enc, batch_size)

        # Reparametrization
        mu = self.hidden_to_mu(hidden)
        sig = self.hidden_to_sig(hidden)
        eps = torch.randn_like(mu).detach().to(self.device)
        latent = (sig.exp().sqrt() * eps) + mu

        # Decoder pass
        if self.input_rep == 'midilike':
            # One hot encoding of the target for teacher forcing purpose
            target = torch.nn.functional.one_hot(x.squeeze(2).long(),
                                                 self.input_size).float()
            x_reconst = self.decoder(latent, target, batch_size,
                                     teacher_forcing=self.tf, device=self.device)
        else:
            x_reconst = self.decoder(latent, x, batch_size,
                                     teacher_forcing=self.tf, device=self.device)

        return mu, sig, latent, x_reconst

    def notetuple_reconstruction_loss(self, x_reconst, x):
        """Compute the reconstruction loss for a
        given input in notetuple format and its reconstruction"""
        x_reconst = x_reconst.permute(0, 2, 1)
        x_in, l = x
        loss_ts_maj = self.loss_fn(
            x_reconst[:, :len(self.vocab[0]), :],  # type: ignore
            x_in[:, :, 0].long())
        current = len(self.vocab[0])  # type: ignore

        loss_ts_min = self.loss_fn(
            x_reconst[:, current:current +
                      len(self.vocab[1]), :],  # type: ignore
            x_in[:, :, 1].long())
        current += len(self.vocab[1])  # type: ignore

        loss_pitch = self.loss_fn(
            x_reconst[:, current:current + 129, :],
            x_in[:, :, 2].long())
        current += 129

        loss_dur_maj = self.loss_fn(
            x_reconst[:, current:current +
                      len(self.vocab[2]), :],  # type: ignore
            x_in[:, :, 3].long())
        current += len(self.vocab[2])  # type: ignore

        loss_dur_min = self.loss_fn(
            x_reconst[:, current:current +
                      len(self.vocab[3]), :],  # type: ignore
            x_in[:, :, 4].long())
        reconst_loss = loss_ts_maj + loss_ts_min + \
            loss_pitch + loss_dur_maj + loss_dur_min

        return reconst_loss

    def compute_reconstruction_loss(self, x, x_reconst):
        """ Compute the reconstruction loss for a given input and its reconstruction """

        if self.input_rep in ["midilike", "MVAErep"]:
            reconst_loss = self.loss_fn(x_reconst.permute(
                0, 2, 1), x.squeeze(2).long())
        elif self.input_rep == "notetuple":
            reconst_loss = self.notetuple_reconstruction_loss(x_reconst, x)
        else:  # pianoroll, signallike
            reconst_loss = self.loss_fn(x_reconst, x)

        return reconst_loss

    def training_step(self, batch, batch_idx):
        x = batch

        # Zero grad
        self.zero_grad()

        # Forward pass
        mu, sig, _, x_reconst = self(x)

        # Compute losses
        kl_div = - 0.5 * torch.sum(1 + sig - mu.pow(2) - sig.exp())
        reconst_loss = self.compute_reconstruction_loss(x, x_reconst)

        # Backprop and optimize
        loss = reconst_loss + (self.w_kl * kl_div)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("train_kl_div", kl_div, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("train_reconst_loss", reconst_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return {"loss": loss, "kl_div": kl_div, "reconst_loss": reconst_loss}

    def validation_step(self, batch, batch_idx):
        x = batch

        # Zero grad
        self.zero_grad()

        # Forward pass
        mu, sig, _, x_reconst = self(x)

        # Compute losses
        kl_div = - 0.5 * torch.sum(1 + sig - mu.pow(2) - sig.exp())
        reconst_loss = self.compute_reconstruction_loss(x, x_reconst)

        # Backprop and optimize
        loss = reconst_loss + kl_div
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val_kl_div", kl_div, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val_reconst_loss", reconst_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return {"loss": loss, "kl_div": kl_div, "reconst_loss": reconst_loss}

    def on_train_epoch_end(self) -> None:
        """
        Called when the epoch ends.
        """
        epoch = self.current_epoch

        if self.input_rep in ["pianoroll"]:
            if epoch < 150 and epoch > 0 and epoch % 10 == 0:
                self.w_kl += 1e-5
            elif epoch > 150 and epoch % 10 == 0:
                self.w_kl += 1e-4
        elif self.input_rep in ["midilike", "signallike", "dft128"] and epoch % 10 == 0 and epoch > 0:
            self.w_kl += 1e-8
        elif self.input_rep == "midimono" and epoch % 10 == 0 and epoch > 0:
            self.w_kl += 1e-4
        elif self.input_rep == "notetuple" and epoch % 10 == 0 and epoch > 0:
            self.w_kl += 1e-6

        self.log("w_kl", self.w_kl, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return super().on_train_epoch_end()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In [23]:
os.makedirs(f'{output_dr}/{arguments["--runname"]}/models', exist_ok=True)

callbacks = [
    L.pytorch.callbacks.ModelCheckpoint(monitor='val_loss',  # type: ignore
                                        save_top_k=1, mode='min',
                                        dirpath=f'{output_dr}/{arguments["--runname"]}/models/',
                                        filename=arguments['--runname'] + \
                                        '-{epoch}-{val_loss:.2f}',
                                        save_last=True),
    L.pytorch.callbacks.EarlyStopping(monitor='val_loss',  # type: ignore
                                      patience=5,
                                      mode='min'),
    L.pytorch.callbacks.LearningRateMonitor(  # type: ignore
        logging_interval='step'),
]

trainer = L.Trainer(max_epochs=10, default_root_dir=f'{output_dr}/{arguments["--runname"]}/',
                    enable_checkpointing=True, callbacks=callbacks)

last_model = f'{output_dr}/{arguments["--runname"]}/models/last.ckpt'
if os.path.exists(last_model):
    trainer.fit(model,
                train_dataloaders=data_loader,
                val_dataloaders=test_loader,
                ckpt_path=last_model)
else:
    trainer.fit(model,
                train_dataloaders=data_loader,
                val_dataloaders=test_loader)

INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
INFO: Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
INFO: ----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO: Restoring states from the checkpoint path at /kaggle/working/output/DFT128_01/models/last.ckpt
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO: LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
INFO: 
  | Name          | Type                     | Params
-----------------------------------------------------------
0 | enc

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

In [None]:
!cd /kaggle/working
!zip -r model.zip output/DFT128_01/

from IPython.display import FileLink 
FileLink(r'model.zip')

updating: output/DFT128_01/ (stored 0%)
updating: output/DFT128_01/models/ (stored 0%)
updating: output/DFT128_01/models/DFT128_01-epoch=5-val_loss=-6961369.50.ckpt

In [None]:
!cd /kaggle/working
#!rm lightning_logs.zip
#!rm model.zip
#!rm -rf output
#!mkdir output/models/DFT128_01/
#!ls  output/models/DFT128_01
#!rm output/models/DFT128_01/last.ckpt
#!mv output/models/DFT128_01/last-v1.ckpt output/models/DFT128_01/last.ckpt

#!ls  output/models/DFT128_01
