# Part 0: imports & constants

In [6]:
import numpy as np
import music21 as m21
import pretty_midi as pm
import tempfile
import math

# Relative folder that holds the MIDI files used by the test cells of this notebook.
test_midi_folder = 'midiFiles/'

# Part 1: file to list of PCVs 

## Part 1.1 code

In [11]:
# Canonical (sharp-based) names of the twelve pitch classes; the position of a
# name in this list is its index in a 12-dimensional pitch class vector.
twelve_tones_vector_name = [
    'C', 'C#', 'D', 'D#', 'E', 'F',
    'F#', 'G', 'G#', 'A', 'A#', 'B',
]

# Alternative spellings ('-' stands for a flat) mapped to their canonical name.
altered_notation_dict = {
    'B#': 'C',
    'D-': 'C#',
    'E-': 'D#',
    'F-': 'E',
    'E#': 'F',
    'G-': 'F#',
    'A-': 'G#',
    'B-': 'A#',
    'C-': 'B',
}

# Identity mapping: a canonical name normalizes to itself.
pitch_pitch_dict = dict(zip(twelve_tones_vector_name, twelve_tones_vector_name))

# Each canonical name is matched to an index between 0 and 11, so that it fits
# inside a 12-d vector.
pitch_index_dict = {name: index for index, name in enumerate(twelve_tones_vector_name)}

# Any pitch name listed above (altered or canonical) is mapped to its canonical
# form as defined in 'twelve_tones_vector_name'.
normalize_notation_dict = {**altered_notation_dict, **pitch_pitch_dict}

def recursively_map_offset(filename, only_note_name=True):
    '''
    Parses a file with music21 and lists every sounding pitch together with
    its time span.

    The parsed stream is walked recursively. Each note encountered yields one
    entry; each chord encountered yields one entry per pitch it contains, all
    of them sharing the time span of the chord.

    The time span is a tuple (start offset, end offset), where the start is
    the offset of the element inside the whole parsed stream and the end is
    the start plus the element's duration. Both values are expressed as a
    count of quarter notes.

    Params:
    filename: path of the file handed to music21's converter.
    only_note_name: Boolean. When True, the first member of each entry is the
                    pitch name as a string (no octave indication). When False,
                    it is the music21 object itself (the note for a note, the
                    pitch for a member of a chord).

    returns:
    A list of tuples (pitch, (start offset, end offset)).
    '''
    score = m21.converter.parse(filename)
    pitch_spans = []
    for element in score.recurse():
        is_note = isinstance(element, m21.note.Note)
        if not is_note and not isinstance(element, m21.chord.Chord):
            # Rests, clefs, metadata, containers... carry no pitch.
            continue
        onset = element.getOffsetInHierarchy(score)
        span = (onset, onset + element.duration.quarterLength)
        # A note stands for itself, a chord is decomposed into its pitches.
        members = [element] if is_note else element.pitches
        for member in members:
            pitch_spans.append((member.name if only_note_name else member, span))
    return pitch_spans


def remove_drums_from_midi_file(midi_filename):
    '''
    Writes a copy of a MIDI file from which the drum tracks have been removed.
    Works only if the MIDI file has metadata clearly indicating the instruments
    that are percussive (the ones pretty_midi flags with `is_drum`). Does not
    remove percussive instruments that are pitched (like the glockenspiel for
    instance).

    Param:
    midi_filename: path of the MIDI file to read. This file is not modified.

    returns:
    The filepath of a newly created temporary '.mid' file holding the music
    without the drum instruments. The file is NOT deleted automatically: the
    caller is responsible for removing it once it is no longer needed.
    '''
    sound = pm.PrettyMIDI(midi_filename)

    # Keep every instrument that is not flagged as a drum kit.
    sound.instruments[:] = [
        instrument for instrument in sound.instruments if not instrument.is_drum
    ]

    # Reserve a uniquely named file. It is created with delete=False and closed
    # right away so that it still exists after this function returns and can be
    # reopened by name on every platform.
    with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as tmp_file:
        temp_midi_filepath = tmp_file.name
    sound.write(temp_midi_filepath)

    return temp_midi_filepath


def only_keep_pitches_in_boundaries(pitch_offset_list, beat1_offset, beat2_offset):
    '''
    Keeps the entries whose time span touches the window [beat1_offset, beat2_offset].

    Params:
    pitch_offset_list: list of tuples (pitch, (start offset, end offset)).
    beat1_offset: beginning of the window.
    beat2_offset: end of the window.

    returns:
    A new list, in the original order, with the entries that end at or after
    beat1_offset and start at or before beat2_offset.
    '''
    return [
        entry for entry in pitch_offset_list
        if entry[1][1] >= beat1_offset and entry[1][0] <= beat2_offset
    ]


def slice_according_to_beat(pitch_offset_list, beat1_offset, beat2_offset):
    '''
    Cuts the notes so that they fit inside the window [beat1_offset, beat2_offset].

    The offsets must be expressed as a count of quarter notes. Only the notes
    that at least END AFTER beat1_offset and START BEFORE beat2_offset
    contribute to the result; the part of a note lying outside the window is
    trimmed away.

    Params:
    pitch_offset_list: list of tuples (pitch, (start offset, end offset)).
    beat1_offset: beginning of the window.
    beat2_offset: end of the window.

    returns:
    A list of tuples (pitch, (start offset, end offset)) clipped to the window.
    The list is empty when the window itself is empty or reversed.
    '''
    if beat1_offset >= beat2_offset:
        return []

    clipped_notes = []
    for entry in only_keep_pitches_in_boundaries(pitch_offset_list, beat1_offset, beat2_offset):
        note_start, note_end = entry[1][0], entry[1][1]
        # A note starting before the window is moved to its beginning, a note
        # ending after the window is stopped at its end.
        clip_start = max(note_start, beat1_offset)
        clip_end = min(note_end, beat2_offset)
        # A note may have a duration of 0, or only touch a boundary of the
        # window: such empty spans are dropped.
        if clip_start < clip_end:
            clipped_notes.append((entry[0], (clip_start, clip_end)))
    return clipped_notes

def sum_into_pitch_class_vector(pitch_offset_list, start_beat, end_beat):
    '''
    Accumulates, per pitch class, the time spent inside a window.

    Params:
    pitch_offset_list: list of tuples (pitch name, (start offset, end offset)).
                       Each name must be a key of 'normalize_notation_dict'.
    start_beat: beginning of the window.
    end_beat: end of the window.

    returns:
    A numpy vector of 12 values, indexed as in 'pitch_index_dict', where each
    value is the sum of the (window-limited) durations of the entries of that
    pitch class.
    '''
    totals = np.zeros(12)
    for entry in pitch_offset_list:
        # Only the part of the note lying inside the window must be counted.
        overlap_start = max(start_beat, entry[1][0])
        overlap_end = min(end_beat, entry[1][1])
        canonical_name = normalize_notation_dict[entry[0]]
        totals[pitch_index_dict[canonical_name]] += overlap_end - overlap_start
    return totals


def get_max_beat(pitch_offset_list):
    '''
    Returns the last beat of a piece, rounded up to a whole quarter note.

    Param:
    pitch_offset_list: non-empty list of tuples (pitch, (start offset, end offset)).

    returns:
    The ceiling of the largest end offset found in the list.

    raises:
    ValueError if the list is empty.
    '''
    if len(pitch_offset_list) == 0:
        raise ValueError('The pitch/offset list is empty: there is no note to analyse.')
    return math.ceil(max(entry[1][1] for entry in pitch_offset_list))


def pitch_class_set_vector_from_pitch_offset_list(pitch_offset_array, aw_size=0.5):
    '''
    Cuts a piece into consecutive analysis windows and builds one pitch class
    vector per window.

    Params:
    pitch_offset_array: non-empty list of tuples
                        (pitch name, (start offset, end offset)).
    aw_size: size of the analysis window, expressed as a number of quarter
             notes. Must be strictly positive and at most half of the last
             beat of the piece.

    returns:
    A numpy float64 matrix with one row per window and 12 columns. Row b covers
    the beats from b*aw_size to (b+1)*aw_size.

    raises:
    ValueError if the list is empty, if aw_size is not strictly positive or if
    aw_size exceeds half the duration of the piece.
    '''
    if aw_size <= 0:
        raise ValueError('The analysis window\'s size must be strictly positive.')

    max_beat = get_max_beat(pitch_offset_array)

    if aw_size > max_beat/2:
        raise ValueError('The analysis window\'s size should not exceed half the duration of the musical piece.')

    # The last window may extend past the end of the piece.
    chunk_number = math.ceil(max_beat/aw_size)
    res_vector = np.zeros((chunk_number, 12), dtype=np.float64)

    for b in range(chunk_number):
        start_beat = b*aw_size
        stop_beat = (b+1)*aw_size
        analysis_windows = slice_according_to_beat(pitch_offset_array, start_beat, stop_beat)
        res_vector[b] = sum_into_pitch_class_vector(analysis_windows, start_beat, stop_beat)

    return res_vector


def produce_pitch_class_matrix_from_filename(filename, remove_percussions = True, aw_size = 0.5):
    '''
    Builds the matrix of pitch class vectors of a music file.

    Params:
    filename: path of the file. The format is deduced from the extension,
              regardless of its case: '.mid' and '.midi' are handled, '.wav'
              is recognised but not supported yet.
    remove_percussions: Boolean, indicating whether the drum tracks must be
                        removed from a MIDI file before the analysis.
    aw_size: size of the analysis window, expressed as a number of quarter
             notes.

    returns:
    For a MIDI file, the matrix with one row per analysis window and 12
    columns. For a WAV file, None.

    raises:
    ValueError if the extension is neither MIDI nor WAV.
    '''
    import os  # only needed here, to delete the temporary drum-free file

    lowered_name = filename.lower()

    if lowered_name.endswith(('.mid', '.midi')):
        if not remove_percussions:
            pitch_offset_list = recursively_map_offset(filename)
        else:
            midi_filename = remove_drums_from_midi_file(filename)
            try:
                pitch_offset_list = recursively_map_offset(midi_filename)
            finally:
                # The drum-free copy is a throwaway file: remove it on a
                # best-effort basis, a failed cleanup must not hide the result.
                try:
                    os.remove(midi_filename)
                except OSError:
                    pass
        return pitch_class_set_vector_from_pitch_offset_list(pitch_offset_list, aw_size)

    if lowered_name.endswith('.wav'):
        #TODO: add the WAV analysis code there.
        return None

    raise ValueError('The file should be in MIDI or WAV format')

## Part 1.2 tests

In [18]:
#### Tests on a MIDI transcription of Bach's Prelude
bach_prelude_midi = test_midi_folder + '210606-Prelude_No._1_BWV_846_in_C_Major.mid'

#### MAX BEAT TEST
# Expected last beat of the piece, in quarter notes (presumably 35 bars of 4 beats).
BACH_PRELUDE_MAX_BEAT = 35 * 4
bp_po_list = recursively_map_offset(bach_prelude_midi)
# The last beat is computed inline (largest end offset, rounded up) so that this
# cell only relies on names that are defined at module level.
bp_max_beat = math.ceil(max(span[1] for _, span in bp_po_list))
assert BACH_PRELUDE_MAX_BEAT == bp_max_beat

#### DEFAULT AW SIZE TEST
# The default window is half a quarter note: two rows per beat are expected.
bp_pcm = produce_pitch_class_matrix_from_filename(bach_prelude_midi)
assert np.shape(bp_pcm)[0] == 2*BACH_PRELUDE_MAX_BEAT

#### AW SIZE = 1 TEST

bp_pcm_aw1 = produce_pitch_class_matrix_from_filename(bach_prelude_midi, aw_size=1)
assert np.shape(bp_pcm_aw1)[0] == BACH_PRELUDE_MAX_BEAT

#### AW_SIZE = MAX_BEAT/2
# Largest window size accepted: the piece is split into exactly two rows.

bp_pcm_aw_half = produce_pitch_class_matrix_from_filename(bach_prelude_midi, aw_size=BACH_PRELUDE_MAX_BEAT/2)
assert np.shape(bp_pcm_aw_half)[0] == 2



array([[36.  ,  1.  , 32.75,  0.  , 24.5 , 20.5 ,  2.  , 24.  ,  3.5 ,
        18.  ,  3.5 , 18.  ],
       [53.5 ,  0.  , 19.5 ,  4.5 , 19.75, 20.75,  5.  , 52.5 ,  4.  ,
         4.  ,  2.  ,  9.25]], dtype=float32)

## Part 2: Apply DFT and Generate UTM

### Part 2.1 code

In [None]:
def build_dft_utm_from_one_row(res):
    '''
    Fills, in place, the upper triangle of a matrix from its first row.

    Each entry [i][k] with k >= i is set to the sum of the first-row entry
    [0][k] and of the entry [i-1][k-1], its upper-left neighbour. Row i thus
    accumulates i+1 consecutive entries of the first row.

    Param:
    res: numpy array whose first two dimensions have the same size and whose
         row 0 is already filled. It is modified in place.

    returns:
    The same array, once filled.
    '''
    size = np.shape(res)[0]
    for row in range(1, size):
        # Whole row at once: columns row..size-1 receive the first-row values
        # plus the previous row shifted by one column.
        res[row][row:size] = res[0][row:size] + res[row-1][row-1:size-1]
    return res

def apply_dft_to_pitch_class_matrix(pc_mat, build_utm = True):
    '''
    Applies the discrete Fourier transform to every pitch class vector of a
    matrix, and optionally builds the upper triangular matrix (UTM) from it.

    Params:
    pc_mat: 2-d matrix with one pitch class vector per row.
    build_utm: Boolean, indicating whether the UTM must be built from the
               transformed vectors.

    returns:
    A numpy complex128 array. Without the UTM, its shape is
    (number of vectors, number of coefficients). With the UTM, its shape is
    (number of vectors, number of vectors, number of coefficients), its row 0
    holding the transformed vectors.
    The number of coefficients is half the size of a vector, plus one to hold
    room for the 0th coefficient (7 for 12 pitch classes).
    '''
    pcv_nmb, pc_nmb = np.shape(pc_mat)
    #+1 to hold room for the 0th coefficient
    coeff_nmb = int(pc_nmb/2)+1

    # One DFT per row. Only the first coefficients are kept: for real input the
    # remaining ones (7 to 11 for 12 pitch classes) are the complex conjugates
    # of coefficients 5 to 1 and carry no extra information.
    res = np.fft.fft(pc_mat, axis=1)[:, :coeff_nmb].astype(np.complex128)

    if build_utm:
        new_res = np.full((pcv_nmb, pcv_nmb, coeff_nmb), (0. + 0.j), np.complex128)
        new_res[0] = res
        res = build_dft_utm_from_one_row(new_res)

    return res

### Part 2.2 Tests