In [1]:
# Folder holding the reference (Hooktheory) MIDI files; used as `folder_path` in the "Hooktheory" section.
# NOTE(review): this is an absolute, machine-specific path - it must be edited before running elsewhere.
hooktheory_midi_dir = "/Users/4rr311/Documents/VectorA/KHTN/Nam4/HKII/Thesis/Brainstorming/DataCrawling/ProcessedData/midi_from_json_songs"

# Folder holding the model-generated MIDI files; used as `folder_path` in the "Model Output" section.
# The second assignment overrides the first, so only the "lora gpt2" folder is actually analyzed.
# Swap or comment out one of the two lines to analyze the "musecoco arch" outputs instead.
model_output_dir = "/Users/4rr311/Documents/VectorA/KHTN/Nam4/HKII/Thesis/Brainstorming/Evaluation/data_for_testing/model_output/data for visualization/musecoco arch/midi"
model_output_dir = "/Users/4rr311/Documents/VectorA/KHTN/Nam4/HKII/Thesis/Brainstorming/Evaluation/data_for_testing/model_output/data for visualization/lora gpt2"

In [2]:
import pretty_midi
import numpy as np
from collections import Counter
import os
import plotly.express as px

In [3]:
# current_idx = 0

# indices_to_analyze = [
#     0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
#     # 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
#     # 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
#     # 100, 101, 102, 103, 104, 105, 106, 107, 108, 109
# ]

In [4]:
def load_midi(file_path):
    """Parse the MIDI file at ``file_path`` into a ``pretty_midi.PrettyMIDI`` object."""
    midi_data = pretty_midi.PrettyMIDI(file_path)
    return midi_data

def get_notes(midi_data):
    """Return the notes of every non-drum instrument, ordered by start time.

    Notes that share a start time keep the order in which they were collected
    (instrument by instrument), because the sort is stable.
    """
    melodic_notes = [
        note
        for instrument in midi_data.instruments
        if not instrument.is_drum
        for note in instrument.notes
    ]
    melodic_notes.sort(key=lambda n: n.start)
    return melodic_notes

def pitch_counts(notes):
    """Count how many times each pitch number occurs in ``notes``.

    Returns a plain dict mapping pitch -> occurrence count, with keys in
    first-seen order.
    """
    tally = {}
    for note in notes:
        tally[note.pitch] = tally.get(note.pitch, 0) + 1
    return tally

def pitch_class_distribution(notes):
    """Count how many times each pitch class (``pitch % 12``) occurs in ``notes``.

    Returns a plain dict mapping pitch class -> occurrence count, with keys in
    first-seen order.
    """
    histogram = Counter()
    for note in notes:
        histogram[note.pitch % 12] += 1
    return dict(histogram)

def pitch_class_transition_matrix(notes):
    """Build the row-normalized 12x12 pitch-class transition matrix of ``notes``.

    Entry ``[i, j]`` is the fraction of transitions leaving pitch class ``i``
    that go to pitch class ``j``, taken over consecutive notes in the given
    order.

    Rows for pitch classes with no outgoing transition are left as all zeros.
    Dividing them by their zero row sum would produce NaN and a RuntimeWarning.

    Args:
        notes: sequence of objects exposing an integer ``pitch`` attribute.

    Returns:
        ``numpy.ndarray`` of shape (12, 12); every non-empty row sums to 1.
    """
    pitch_classes = [note.pitch % 12 for note in notes]
    counts = np.zeros((12, 12))
    for current_pc, next_pc in zip(pitch_classes, pitch_classes[1:]):
        counts[current_pc, next_pc] += 1

    row_totals = counts.sum(axis=1, keepdims=True)
    # Only divide rows that have at least one transition; the others keep the zeros from `out`.
    return np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0)

def pitch_range(notes):
    """Return the distance between the highest and the lowest pitch in ``notes``.

    Raises ``ValueError`` when ``notes`` is empty.
    """
    highest = max(note.pitch for note in notes)
    lowest = min(note.pitch for note in notes)
    return highest - lowest

def average_pitch_intervals(notes):
    """Return ``(mean, standard deviation)`` of the signed pitch steps between consecutive notes."""
    steps = [later.pitch - earlier.pitch for earlier, later in zip(notes, notes[1:])]
    mean_step = np.mean(steps)
    step_spread = np.std(steps)
    return mean_step, step_spread

def average_inter_onset_intervals(notes):
    """Return ``(mean, standard deviation)`` of the gaps between consecutive note start times."""
    onset_gaps = [later.start - earlier.start for earlier, later in zip(notes, notes[1:])]
    mean_gap = np.mean(onset_gaps)
    gap_spread = np.std(onset_gaps)
    return mean_gap, gap_spread

def note_count(notes):
    """Return the number of notes in ``notes``."""
    total = len(notes)
    return total

def note_length_transition_matrix(notes):
    """Build the row-normalized transition matrix between note durations.

    Durations are ``end - start`` of each note. Rows and columns are indexed
    by the sorted distinct durations, which are matched by exact float
    equality. Entry ``[i, j]`` is the fraction of transitions leaving duration
    ``i`` that go to duration ``j``, taken over consecutive notes in the given
    order.

    Rows with no outgoing transition are left as all zeros, for example a
    duration that only occurs on the last note. Dividing them by their zero
    row sum would produce NaN and a RuntimeWarning.

    Args:
        notes: sequence of objects exposing ``start`` and ``end`` attributes.

    Returns:
        Square ``numpy.ndarray`` with one row/column per distinct duration
        (shape ``(0, 0)`` for empty input).
    """
    note_lengths = [note.end - note.start for note in notes]
    unique_lengths = sorted(set(note_lengths))
    length_to_idx = {length: idx for idx, length in enumerate(unique_lengths)}

    counts = np.zeros((len(unique_lengths), len(unique_lengths)))
    for current_length, next_length in zip(note_lengths, note_lengths[1:]):
        counts[length_to_idx[current_length], length_to_idx[next_length]] += 1

    row_totals = counts.sum(axis=1, keepdims=True)
    # Only divide rows that have at least one transition; the others keep the zeros from `out`.
    return np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0)

def analyze_midi(file_path):
    """Load one MIDI file and compute its note statistics.

    Returns an empty dict when the file has no non-drum notes. Otherwise
    returns a dict with the keys ``'pitch_counts'``,
    ``'pitch_class_distribution'`` and ``'note_count'``.
    """
    notes = get_notes(load_midi(file_path))

    # Nothing to measure for a file without notes.
    if len(notes) == 0:
        return {}

    # The other metrics defined in this notebook (transition matrices, pitch range,
    # pitch / inter-onset interval statistics) are intentionally not computed here.
    return {
        'pitch_counts': pitch_counts(notes),
        'pitch_class_distribution': pitch_class_distribution(notes),
        'note_count': note_count(notes),
    }

def analyze_midi_folder(folder_path, max_file_count=10000):
    """Recursively analyze the MIDI files found under ``folder_path``.

    Files ending in ``.mid`` or ``.midi`` (case-sensitive) are analyzed with
    ``analyze_midi`` in ``os.walk`` order. The walk stops as soon as
    ``max_file_count`` files have been analyzed, so the rest of the tree is
    not traversed.

    Args:
        folder_path: root directory to search.
        max_file_count: maximum number of MIDI files to analyze.

    Returns:
        Dict mapping the bare file name to its ``analyze_midi`` result. Files
        with the same name in different sub-folders share a key, so the one
        visited last wins.

    Side effects:
        Prints one progress line per analyzed file and resets the module-level
        ``current_idx`` to 0. The reset is kept for compatibility with the
        disabled index-based filter.
    """
    global current_idx

    results = {}
    current_file_count = 0
    limit_reached = False

    for root, _dirs, files in os.walk(folder_path):
        for file_name in files:
            if not file_name.endswith(('.mid', '.midi')):
                continue

            if current_file_count >= max_file_count:
                limit_reached = True
                break

            current_file_count += 1
            file_path = os.path.join(root, file_name)
            print(f"Analyzing file {current_file_count} of {max_file_count} - {file_name}")
            results[file_name] = analyze_midi(file_path)

        # Leave the directory walk too; a bare `break` above only exits the inner loop.
        if limit_reached:
            break

    current_idx = 0

    return results

In [5]:
def average_analysis_results(analysis_results):
    """Average the per-file statistics over all analyzed files.

    Every file in ``analysis_results`` counts towards the denominator,
    including files whose analysis is empty (no notes). Totals are accumulated
    first and divided once at the end, which avoids the rounding error that
    builds up when many pre-divided terms are summed.

    Args:
        analysis_results: dict mapping file name -> analysis dict as produced
            by ``analyze_midi``. Keys other than ``'pitch_counts'``,
            ``'pitch_class_distribution'`` and ``'note_count'`` are ignored.

    Returns:
        Dict with:
          * ``'pitch_counts'``: pitch -> average count per file,
          * ``'pitch_class_distribution'``: pitch class -> average count per file,
          * ``'note_count'``: average number of notes per file (``0`` when
            there are no files).
        Dict keys appear in first-seen order.
    """
    n_files = len(analysis_results)

    pitch_totals = {}
    pitch_class_totals = {}
    note_total = 0

    for analysis in analysis_results.values():
        for pitch, count in analysis.get('pitch_counts', {}).items():
            pitch_totals[pitch] = pitch_totals.get(pitch, 0) + count
        for pitch_class, count in analysis.get('pitch_class_distribution', {}).items():
            pitch_class_totals[pitch_class] = pitch_class_totals.get(pitch_class, 0) + count
        note_total += analysis.get('note_count', 0)

    # Without files there is nothing to divide by; return the empty averages.
    if n_files == 0:
        return {
            "pitch_counts": {},
            "pitch_class_distribution": {},
            "note_count": 0,
        }

    return {
        "pitch_counts": {pitch: total / n_files for pitch, total in pitch_totals.items()},
        "pitch_class_distribution": {
            pitch_class: total / n_files for pitch_class, total in pitch_class_totals.items()
        },
        "note_count": note_total / n_files,
    }

# Hooktheory

In [6]:
# Analyze a 100-file sample of the Hooktheory folder.
# (29038 was the value used previously - presumably the full corpus size; restore it for a full run.)
n_file_to_analyze = 100
folder_path = hooktheory_midi_dir

analysis_results = analyze_midi_folder(folder_path, n_file_to_analyze)

Analyzing file 1 of 100 - d_deerhunter_back-to-the-middle_Solo_AQodPXyKoDl.mid
Analyzing file 2 of 100 - t_the-beatles_a-hard-days-night_Intro_ZwxK_QNbxed.mid
Analyzing file 3 of 100 - n_niall-horan_still_Bridge_jDgXdDVegKl.mid
Analyzing file 4 of 100 - a_ashe_love-is-not-enough_Pre-Chorus_yvgPQdedxYq.mid
Analyzing file 5 of 100 - b_bryan-scary_operaland_Chorus_DpgvRAkdgad.mid
Analyzing file 6 of 100 - j_jimmy-fontanez_urban-lullaby_Verse_Wegl_wa-mrY.mid
Analyzing file 7 of 100 - h_hirohiko-takayama_hudsons-adventure-island-iii---thunder-clash_Instrumental_yvmrLZGzxOW.mid
Analyzing file 8 of 100 - h_hitomi-yaida_ashita-kara-no-tegami_Intro and Verse_AaoGbakPxeQ.mid
Analyzing file 9 of 100 - b_billy-idol_white-wedding_Chorus_eWxLdzOpxaK.mid
Analyzing file 10 of 100 - l_lawrence_the-heartburn-song_Verse_ROmNkROngNw.mid
Analyzing file 11 of 100 - m_miguel_goingtohell_Verse_zngREQNMmJj.mid
Analyzing file 12 of 100 - b_bensound_elevator-bossa-nova_Intro and Verse_d_gwnZ_QmGV.mid
Analyzing f

In [7]:
# Preview the analysis of the first `n_file_to_print` files
n_file_to_print = 1
current_file = 0

for file_name, analysis in analysis_results.items():
    # Stop once the requested number of files has been shown
    if current_file >= n_file_to_print:
        break

    current_file += 1
    print(f"{current_file} of {n_file_to_analyze} result - {file_name}:")

    for key, value in analysis.items():
        print(f"  {key}: {value}")

1 of 100 result - d_deerhunter_back-to-the-middle_Solo_AQodPXyKoDl.mid:
  pitch_counts: {71: 14, 59: 2, 62: 2, 66: 16, 35: 2, 73: 6, 68: 24, 57: 2, 61: 12, 64: 14, 33: 2, 69: 2, 42: 2, 40: 2, 37: 2}
  pitch_class_distribution: {11: 18, 2: 2, 6: 18, 1: 20, 8: 24, 9: 6, 4: 16}
  note_count: 104


In [8]:
# Average the per-file metrics over the analyzed files
average_results = average_analysis_results(analysis_results)

# Show each averaged metric on its own line
print("Average results:")
for key in average_results:
    value = average_results[key]
    print(f"  {key}: {value}")

# Save average results to file
# with open('average_results.json', 'w') as f:
#     json.dump(average_results, f, indent=4)

# Save analysis results to file
# with open('analysis_results.json', 'w') as f:
#     json.dump(analysis_results, f, indent=4)

Average results:
  pitch_counts: {71: 4.09, 59: 4.319999999999999, 62: 7.209999999999999, 66: 5.18, 35: 1.4400000000000004, 73: 2.08, 68: 4.749999999999999, 57: 3.4499999999999984, 61: 4.339999999999998, 64: 5.969999999999999, 33: 1.3500000000000005, 69: 5.6099999999999985, 42: 0.39999999999999997, 40: 0.6300000000000001, 37: 0.9000000000000001, 65: 6.089999999999997, 67: 7.349999999999999, 72: 3.4899999999999993, 38: 1.0800000000000005, 60: 4.889999999999995, 77: 1.1900000000000004, 75: 1.5600000000000005, 44: 0.21000000000000002, 56: 1.6000000000000003, 63: 4.509999999999997, 32: 0.6800000000000002, 58: 3.549999999999999, 51: 0.6700000000000002, 55: 2.6399999999999992, 27: 0.48, 34: 1.3600000000000005, 36: 1.0500000000000003, 76: 1.4200000000000006, 74: 2.609999999999999, 78: 1.0900000000000003, 79: 0.9400000000000003, 81: 0.37, 82: 0.3800000000000001, 70: 3.3299999999999987, 45: 0.25, 80: 0.4700000000000001, 83: 0.19999999999999998, 85: 0.21000000000000002, 87: 0.07, 39: 0.480000000

In [9]:
# Horizontal bar chart of the average count per pitch (pitch on the y-axis, average count on the x-axis)
pitch_count_averages = average_results['pitch_counts']

fig = px.bar(
    x=list(pitch_count_averages.values()),
    y=list(pitch_count_averages.keys()),
    orientation='h',
)

# Tick every 5 pitches on the y-axis and at every integer on the x-axis
y_axis_tickvals = list(range(0, max(pitch_count_averages.keys()) + 1, 5))
x_axis_tickvals = list(range(0, int(max(pitch_count_averages.values())) + 2))
fig.update_yaxes(tickvals=y_axis_tickvals)
fig.update_xaxes(tickvals=x_axis_tickvals)

# Light fill with a dark outline so the bars stay readable when printed in black and white
fig.update_traces(
    marker_color='rgb(158,202,225)',
    marker_line_color='rgb(8,48,107)',
    marker_line_width=1.5,
    opacity=0.6,
)

# Titles, white background, overlay bar mode, larger font and a fixed figure size
fig.update_layout(
    title=f'Hooktheory (n = {n_file_to_analyze})',
    xaxis_title='Số lần xuất hiện trung bình',
    yaxis_title='Cao độ',
    plot_bgcolor='white',
    barmode='overlay',
    font=dict(size=13),
    autosize=False,
    width=450,
    height=1200,
)

# Vertical grid lines
fig.update_layout(xaxis=dict(showgrid=True, gridwidth=1, gridcolor='rgb(158,202,225)'))

fig.show()


# Vertical bar chart of the average count per pitch class
pitch_class_averages = average_results['pitch_class_distribution']

fig = px.bar(
    x=list(pitch_class_averages.keys()),
    y=list(pitch_class_averages.values()),
)

# One x-axis tick per pitch class, from 0 up to the largest one present
x_axis_tickvals = list(range(0, max(pitch_class_averages.keys()) + 1))
fig.update_xaxes(tickvals=x_axis_tickvals)

# Light fill with a dark outline so the bars stay readable when printed in black and white
fig.update_traces(
    marker_color='rgb(158,202,225)',
    marker_line_color='rgb(8,48,107)',
    marker_line_width=1.5,
    opacity=0.6,
)

# Titles, white background, larger font and a fixed figure size
fig.update_layout(
    title=f'Hooktheory (n = {n_file_to_analyze})',
    yaxis_title='Số lần xuất hiện trung bình',
    xaxis_title='Cao độ cơ bản',
    plot_bgcolor='white',
    font=dict(size=13),
    autosize=False,
    width=650,
    height=400,
)

# Horizontal grid lines
fig.update_layout(yaxis=dict(showgrid=True, gridwidth=1, gridcolor='rgb(158,202,225)'))

fig.show()

# Average number of notes per analyzed file
print(f"Số lượng nốt nhạc trung bình trong một đoạn nhạc từ Hooktheory (n = {n_file_to_analyze}): {average_results['note_count']}")

Số lượng nốt nhạc trung bình trong một đoạn nhạc từ Hooktheory (n = 100): 116.44000000000007


# Model Output

In [10]:
folder_path = model_output_dir

# Count every .mid / .midi file below the folder so that all of them get analyzed
n_file_to_analyze = sum(
    1
    for _root, _dirs, files in os.walk(folder_path)
    for file_name in files
    if file_name.endswith('.mid') or file_name.endswith('.midi')
)

analysis_results = analyze_midi_folder(folder_path, n_file_to_analyze)

Analyzing file 1 of 906 - 166_2024_07_28_11_pm_164.mid
Analyzing file 2 of 906 - 257_2024_07_29_10_am_82.mid
Analyzing file 3 of 906 - 71_2024_07_28_09_am_69.mid
Analyzing file 4 of 906 - 258_2024_07_29_10_am_83.mid
Analyzing file 5 of 906 - 230_2024_07_29_06_am_55.mid
Analyzing file 6 of 906 - 107_2024_07_28_03_pm_105.mid
Analyzing file 7 of 906 - 52_2024_07_28_06_am_50.mid
Analyzing file 8 of 906 - 118_2024_07_28_04_pm_116.mid
Analyzing file 9 of 906 - 210_2024_07_29_04_am_35.mid
Analyzing file 10 of 906 - 97_2024_07_28_01_pm_95.mid
Analyzing file 11 of 906 - 249_2024_07_29_09_am_74.mid
Analyzing file 12 of 906 - 0_2024_07_27_11_pm_2.mid
Analyzing file 13 of 906 - 55_2024_07_28_07_am_53.mid
Analyzing file 14 of 906 - 244_2024_07_29_09_am_69.mid
Analyzing file 15 of 906 - 263_2024_07_29_11_am_88.mid
Analyzing file 16 of 906 - 189_2024_07_29_01_am_14.mid
Analyzing file 17 of 906 - 161_2024_07_28_10_pm_159.mid
Analyzing file 18 of 906 - 235_2024_07_29_07_am_60.mid
Analyzing file 19 of 9

In [11]:
# Preview the analysis of the first `n_file_to_print` files
n_file_to_print = 1
current_file = 0

for file_name, analysis in analysis_results.items():
    # Stop once the requested number of files has been shown
    if current_file >= n_file_to_print:
        break

    current_file += 1
    print(f"{current_file} of {n_file_to_analyze} result - {file_name}:")

    for key, value in analysis.items():
        print(f"  {key}: {value}")

1 of 906 result - 166_2024_07_28_11_pm_164.mid:
  pitch_counts: {57: 15, 50: 20, 69: 10, 53: 7, 62: 11, 60: 21, 67: 24, 72: 17, 74: 14, 59: 9, 65: 14, 52: 16, 48: 11, 45: 6, 55: 5, 36: 1, 43: 2, 77: 4, 64: 49, 70: 2, 76: 46, 38: 1, 58: 1, 44: 1, 40: 1, 56: 1, 41: 2, 93: 1, 79: 1}
  pitch_class_distribution: {9: 32, 2: 46, 5: 27, 0: 50, 7: 32, 11: 9, 4: 112, 10: 3, 8: 2}
  note_count: 313


In [12]:
# Average the per-file metrics over the analyzed files
average_results = average_analysis_results(analysis_results)

# Show each averaged metric on its own line
print("Average results:")
for key in average_results:
    value = average_results[key]
    print(f"  {key}: {value}")

# Save average results to file
# with open('average_results.json', 'w') as f:
#     json.dump(average_results, f, indent=4)

# Save analysis results to file
# with open('analysis_results.json', 'w') as f:
#     json.dump(analysis_results, f, indent=4)

Average results:
  pitch_counts: {57: 12.963696369636967, 50: 12.23102310231022, 69: 21.930693069306898, 53: 7.93729372937295, 62: 16.27722772277226, 60: 18.06270627062703, 67: 16.584158415841575, 72: 12.834983498349825, 74: 12.88778877887788, 59: 10.867986798679858, 65: 8.511551155115509, 52: 14.168316831683168, 48: 11.244224422442231, 45: 5.537953795379538, 55: 14.029702970296999, 36: 5.607260726072607, 43: 5.419141914191417, 77: 4.29042904290429, 64: 24.21452145214518, 70: 0.74917491749175, 76: 19.369636963696355, 38: 2.132013201320132, 58: 0.5379537953795379, 44: 0.7029702970297037, 40: 2.5940594059405964, 56: 1.0330033003300332, 41: 4.366336633663371, 93: 0.029702970297029698, 79: 3.9966996699669988, 75: 0.7194719471947201, 51: 1.772277227722772, 46: 1.478547854785479, 63: 0.5808580858085812, 87: 0.17821782178217815, 66: 1.4389438943894393, 47: 2.412541254125413, 71: 5.8580858085808645, 54: 1.0792079207920795, 61: 0.7590759075907597, 49: 0.7953795379537953, 81: 1.7029702970297018,

In [13]:
# Horizontal bar chart of the average count per pitch (pitch on the y-axis, average count on the x-axis)
pitch_count_averages = average_results['pitch_counts']

fig = px.bar(
    x=list(pitch_count_averages.values()),
    y=list(pitch_count_averages.keys()),
    orientation='h',
)

# Tick every 5 pitches on the y-axis and at every integer on the x-axis
y_axis_tickvals = list(range(0, max(pitch_count_averages.keys()) + 1, 5))
x_axis_tickvals = list(range(0, int(max(pitch_count_averages.values())) + 2))
fig.update_yaxes(tickvals=y_axis_tickvals)
fig.update_xaxes(tickvals=x_axis_tickvals)

# Light fill with a dark outline so the bars stay readable when printed in black and white
fig.update_traces(
    marker_color='rgb(158,202,225)',
    marker_line_color='rgb(8,48,107)',
    marker_line_width=1.5,
    opacity=0.6,
)

# Titles, white background, overlay bar mode, larger font and a fixed figure size
fig.update_layout(
    title=f'Model output (n = {n_file_to_analyze})',
    xaxis_title='Số lần xuất hiện trung bình',
    yaxis_title='Cao độ',
    plot_bgcolor='white',
    barmode='overlay',
    font=dict(size=13),
    autosize=False,
    width=450,
    height=1200,
)

# Vertical grid lines
fig.update_layout(xaxis=dict(showgrid=True, gridwidth=1, gridcolor='rgb(158,202,225)'))

fig.show()


# Vertical bar chart of the average count per pitch class
pitch_class_averages = average_results['pitch_class_distribution']

fig = px.bar(
    x=list(pitch_class_averages.keys()),
    y=list(pitch_class_averages.values()),
)

# One x-axis tick per pitch class, from 0 up to the largest one present
x_axis_tickvals = list(range(0, max(pitch_class_averages.keys()) + 1))
fig.update_xaxes(tickvals=x_axis_tickvals)

# Light fill with a dark outline so the bars stay readable when printed in black and white
fig.update_traces(
    marker_color='rgb(158,202,225)',
    marker_line_color='rgb(8,48,107)',
    marker_line_width=1.5,
    opacity=0.6,
)

# Titles, white background, larger font and a fixed figure size
fig.update_layout(
    title=f'Model output (n = {n_file_to_analyze})',
    yaxis_title='Số lần xuất hiện trung bình',
    xaxis_title='Cao độ cơ bản',
    plot_bgcolor='white',
    font=dict(size=13),
    autosize=False,
    width=650,
    height=400,
)

# Horizontal grid lines
fig.update_layout(yaxis=dict(showgrid=True, gridwidth=1, gridcolor='rgb(158,202,225)'))

fig.show()

# Average number of notes per analyzed file
print(f"Số lượng nốt nhạc trung bình trong một đoạn nhạc từ Model output (n = {n_file_to_analyze}): {average_results['note_count']}")

Số lượng nốt nhạc trung bình trong một đoạn nhạc từ Model output (n = 906): 303.1254125412542


# Model Output (C major & A minor key only)

In [14]:
# Pitch classes to keep: the notes of the C major / A minor scale
selected_pitchs = [0, 2, 4, 5, 7, 9, 11]  # C, D, E, F, G, A, B

# AVERAGE PITCH COUNTS
# Re-order the averaged pitch counts by ascending pitch
unsorted_pitch_counts = average_results['pitch_counts']
average_results['pitch_counts'] = {
    pitch: unsorted_pitch_counts[pitch] for pitch in sorted(unsorted_pitch_counts)
}

print(average_results['pitch_counts'])

# Keep only the pitches whose pitch class (pitch % 12) is selected; keys become strings
filtered_pitch_counts = {}
for pitch, count in average_results['pitch_counts'].items():
    if pitch % 12 in selected_pitchs:
        filtered_pitch_counts[str(pitch)] = count

print(filtered_pitch_counts)

# AVERAGE PITCH CLASS DISTRIBUTION
# Re-order the averaged pitch class distribution by ascending pitch class
unsorted_pitch_classes = average_results['pitch_class_distribution']
average_results['pitch_class_distribution'] = {
    pitch_class: unsorted_pitch_classes[pitch_class] for pitch_class in sorted(unsorted_pitch_classes)
}

print(average_results['pitch_class_distribution'])

# Keep only the selected pitch classes; keys become strings
filtered_pitch_class_distribution = {}
for pitch_class, count in average_results['pitch_class_distribution'].items():
    if pitch_class in selected_pitchs:
        filtered_pitch_class_distribution[str(pitch_class)] = count

print(filtered_pitch_class_distribution)

{24: 0.056105610561056105, 25: 0.11551155115511547, 26: 0.07260726072607258, 28: 0.029702970297029695, 29: 0.04950495049504951, 31: 0.13531353135313534, 32: 0.1419141914191419, 33: 2.419141914191421, 34: 0.0429042904290429, 35: 1.402640264026402, 36: 5.607260726072607, 37: 0.34653465346534673, 38: 2.132013201320132, 39: 0.21122112211221125, 40: 2.5940594059405964, 41: 4.366336633663371, 42: 1.8481848184818481, 43: 5.419141914191417, 44: 0.7029702970297037, 45: 5.537953795379538, 46: 1.478547854785479, 47: 2.412541254125413, 48: 11.244224422442231, 49: 0.7953795379537953, 50: 12.23102310231022, 51: 1.772277227722772, 52: 14.168316831683168, 53: 7.93729372937295, 54: 1.0792079207920795, 55: 14.029702970296999, 56: 1.0330033003300332, 57: 12.963696369636967, 58: 0.5379537953795379, 59: 10.867986798679858, 60: 18.06270627062703, 61: 0.7590759075907597, 62: 16.27722772277226, 63: 0.5808580858085812, 64: 24.21452145214518, 65: 8.511551155115509, 66: 1.4389438943894393, 67: 16.584158415841575

In [15]:
# Horizontal bar chart of the filtered average pitch counts (pitch on the y-axis, average count on the x-axis)
fig = px.bar(
    x=list(filtered_pitch_counts.values()),
    y=list(filtered_pitch_counts.keys()),
    orientation='h',
)

# Label every remaining pitch on the y-axis (as strings)
y_axis_tickvals = [str(pitch_label) for pitch_label in filtered_pitch_counts]
fig.update_yaxes(tickvals=y_axis_tickvals)

# Tick every 2 units on the x-axis
x_axis_tickvals = list(range(0, int(max(filtered_pitch_counts.values())) + 2, 2))
fig.update_xaxes(tickvals=x_axis_tickvals)

# Light fill with a dark outline so the bars stay readable when printed in black and white
fig.update_traces(
    marker_color='rgb(158,202,225)',
    marker_line_color='rgb(8,48,107)',
    marker_line_width=1.5,
    opacity=0.6,
)

# Titles, white background, overlay bar mode, larger font and a fixed figure size
fig.update_layout(
    title=f'Model output (n = {n_file_to_analyze})',
    xaxis_title='Average count',
    yaxis_title='Pitch',
    plot_bgcolor='white',
    barmode='overlay',
    font=dict(size=13),
    autosize=False,
    width=450,
    height=850,
)

# Vertical grid lines
fig.update_layout(xaxis=dict(showgrid=True, gridwidth=1, gridcolor='rgb(158,202,225)'))

fig.show()

# Vertical bar chart of the filtered average pitch class distribution
# (no explicit x-axis tick values are set for this figure)
fig = px.bar(
    x=list(filtered_pitch_class_distribution.keys()),
    y=list(filtered_pitch_class_distribution.values()),
)

# Light fill with a dark outline so the bars stay readable when printed in black and white
fig.update_traces(
    marker_color='rgb(158,202,225)',
    marker_line_color='rgb(8,48,107)',
    marker_line_width=1.5,
    opacity=0.6,
)

# Titles, white background, larger font and a fixed figure size
fig.update_layout(
    title=f'Model output (n = {n_file_to_analyze})',
    yaxis_title='Average count',
    xaxis_title='Pitch class',
    plot_bgcolor='white',
    font=dict(size=13),
    autosize=False,
    width=550,
    height=400,
)

# Horizontal grid lines
fig.update_layout(yaxis=dict(showgrid=True, gridwidth=1, gridcolor='rgb(158,202,225)'))

fig.show()