<h1>Make a classification prediction</h1>
Please ensure `topmagd_data_bin.zip` and `masd_data_bin.zip` are unzipped before trying to access them when loading the model

In [4]:
import preprocess # Pre-written functions for OctupleMIDI
import miditoolkit, io, torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [19]:
## Get file
# filename = "Blue room midi file.mid"
filename = "reggae.mid"
# filename = "2 of a kind.mid"

# Set params
task = "topmagd" # {"topmagd", "masd"}
fold = 0 # Only have checkpoints for 0

In [20]:
# Convert file to OctupleMIDI
midi = miditoolkit.midi.parser.MidiFile(f"midi/{filename}")
enc = preprocess.MIDI_to_encoding(midi)
print("MIDI file converted to OctupleMIDI: " + str(enc[0]) + " ...")
oct_midi_str = preprocess.encoding_to_str(enc)
print("OctupleMIDI as a string: " + str(oct_midi_str[100:120]) + " ...")

MIDI file converted to OctupleMIDI: (1, 0, 4, 67, 16, 20, 9, 39) ...
OctupleMIDI as a string: 3> <3-33> <4-49> <5- ...


In [21]:
# Load model
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained(
    '.',
    checkpoint_file=f"final_checkpoints/checkpoint_last_genre_{task}_0_checkpoint_last_musicbert_small.pt",
    data_name_or_path=f"{task}_data_bin/{fold}",
    user_dir="musicbert"
).cuda()

# Function for retrieving Name of category
label_fn = lambda label: roberta.task.label_dictionary.string(
    [label + label_dict.nspecial]
)

In [22]:
# Encode OctupleMIDI into tokens and predict
label_dict = roberta.task.label_dictionary
tokenized = label_dict.encode_line(oct_midi_str).long()
pred = torch.sigmoid(roberta.predict(f'{task}_head', tokenized, True)).tolist()[0]

In [23]:
## Print confidence levels
print(f"Genre confidence levels for file '{filename}':")

# Zip certainty values with labels
genres = zip([label_fn(j) for j in range(0, len(pred))], pred)

# Sort descending by certainty
genres = sorted(tuple(genres), key=lambda x: x[1], reverse=True) 

# Format and print 
for g in genres:
    p = 100*g[1]
    print("    {}: {:.2f}%".format(g[0], p))

Genre confidence levels for file 'reggae.mid':
    Pop_Rock: 70.10%
    Electronic: 57.40%
    Rap: 18.46%
    Latin: 13.09%
    International: 10.22%
    Jazz: 9.49%
    RnB: 8.63%
    New-Age: 4.25%
    Folk: 3.26%
    Reggae: 2.80%
    Country: 2.39%
    Vocal: 1.92%
    Blues: 1.16%


<h1>Evaluate classification predictions</h1>
Assesses the effect of the imbalance in data by iterating multiple files that are <b>not</b> labelled as 'Pop Rock' and counts incorrect predictions.
<ul><li>Please unzip `topmagd_data_raw.zip` and `lmd_matched.zip` before using this functionality</ul>
<hr>
<h5>Supporting functions:</h5>

In [24]:
## Function to find files that aren't pop rock and test predictions:
from random import randint
def get_non_pop_rock(fold, count, subset, start_rand=False):
    filepath = f"topmagd_data_raw/{fold}/{subset}"
    linestart = 1 if not start_rand else randint(1, get_line_count(f"{filepath}.label")) # Random start point
    
    # Get line nums in label file not containing 'Pop_rock'
    line_nums = set([])
    with open(f"{filepath}.label") as f:
        # length = sum(1 for l in f)
        print(f"Starting from line {linestart}.")
        
        n = 1
        for line in f:
            if n < linestart:
                n += 1
                continue # skip iteration if before starting point
                
            if "Pop_Rock" not in line.split():
                line_nums.add(n)
            n += 1
            
            # Finish if hit count
            if len(line_nums) >= count:
                break

    # Get id's from line numbers in id file
    ids = set([])
    with open(f"{filepath}.id") as f:
        n = 1
        for line in f:
            if n in line_nums:
                ids.add(line.strip() + ".mid") # Strip newline and add file extension
            n += 1
            
            # Finish if hit count
            if len(ids) >= count:
                break
    return ids

    
# Helper func for use in random starting points
def get_line_count(file_path):
    with open(file_path, 'r') as file:
        line_count = sum(1 for l in file)
    return line_count

In [25]:
# function to fetch paths of filenames (as we don't know parent folders)
import os
def get_paths(test_files, count):
    paths = set([])
    print("Searching lmd_matched")
    wlk = os.walk("lmd_matched")
    for root, _, files in wlk:
        for file in files:
            if file in test_files:
                paths.add(os.path.join(root, file))
                print(f"Found {file} - {len(paths)}/{count}")
                if len(paths) >= count:
                    return paths
    return paths

In [26]:
# Function to predict the genre/style of files in a filelist
def predict_files(files, detail=False):
    preds = []
    for file in files:
        # Get oct:
        midi = miditoolkit.midi.parser.MidiFile(file)
        enc = preprocess.MIDI_to_encoding(midi)
        oct_midi_str = preprocess.encoding_to_str(enc)
        # Encode:
        tokenized = label_dict.encode_line(oct_midi_str).long()
        # Predict:
        pred = torch.sigmoid(roberta.predict('topmagd_head', tokenized, True)).tolist()[0]
        # Format:
        genres = zip([label_fn(j) for j in range(0, len(pred))], pred)
        genres = sorted(tuple(genres), key=lambda x: x[1], reverse=True) 
        preds.append(genres[0])
        # Print details 
        if detail:
            print(f"Genre predictions for {file}:")
            for g in genres:
                p = 100*g[1]
                print("    {}: {:.2f}%".format(g[0], p)) # Format prediction
            print("--")
    return preds

<h5>Tweak parameters and evaluate:</h5>

In [27]:
fld = 0 # Fold number for data
cnt = 50 # Number of files to test
sub = "test" # Subset of fold to test. Possible values: {"train", "test"}
start_rand = True

# Get id's
files = get_non_pop_rock(fld, cnt, sub, start_rand=start_rand)
# Fetch paths for id's
paths = get_paths(files, cnt)
# Predict for each file
preds = predict_files(paths)

# Print stats
print(f"Classification predictions of non-'Pop Rock' samples")
print(f"{preds}\n")

incorr = 0
confs = [] # Confidence levels showing how certain the model was with its incorrect prediction
for p in preds:
    if p[0] == "Pop_Rock":
        incorr += 1
        confs.append(p[1]*100)
        
av_conf = sum(confs)/len(confs)
print("For files NOT labelled 'Pop_Rock', model incorrectly predicted 'Pop_Rock' {} out of {} times averaging {:.2f}% confidence".format(incorr, cnt, av_conf))


Starting from line 1529.
Searching lmd_matched
Found 7aacc203eed5109931465906ab66107f.mid - 1/50
Found c2dd5209834257c64acd6d0bcaed201f.mid - 2/50
Found 4d7d87c3aa300ef8bee345558426c34c.mid - 3/50
Found 01bc0adc044ca7061edfd808dcdcb170.mid - 4/50
Found 81059f7c4f020ab4a404706dd976c884.mid - 5/50
Found a1f4cc4b564b1bc12739563be42de256.mid - 6/50
Found 77f13e906ed5e77c0e3f32c823292469.mid - 7/50
Found 7e835b3e1d53b20aba77ff3899ff750c.mid - 8/50
Found c2dd5209834257c64acd6d0bcaed201f.mid - 9/50
Found 4e269ff0883fb58e84f95bdcdf4280d0.mid - 10/50
Found e936c5d842307753024dbc967c6fa339.mid - 11/50
Found 7aacc203eed5109931465906ab66107f.mid - 12/50
Found a8e070874d727af3f0c6a35d96b2e467.mid - 13/50
Found fb853997ca50718ce2c6d51d87e28712.mid - 14/50
Found 1c400411d020c6070b3081b1c75a8036.mid - 15/50
Found 0ecccc464ac97dce6f3239919fc50c89.mid - 16/50
Found bbd41c52600531e2ccd6e3590b1e942c.mid - 17/50
Found e1a44ea887e62185f8b0c61606ace589.mid - 18/50
Found 7e835b3e1d53b20aba77ff3899ff750c.mid -