In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
from helper.label import classes
from helper.audio_extraction.get_file_list import get_file_list
from helper.audio_extraction.padded_and_windowed import extract_windowed_features
from FeedForward import ChordAI

In [3]:
config = torch.load('models/config.pth')
config

{'INPUT_SIZE': 528, 'OUTPUT_SIZE': 24}

In [4]:
model = ChordAI(config['INPUT_SIZE'], config['OUTPUT_SIZE']).to(device)

In [5]:
state_dict = torch.load("models/chord_model.pth")

In [6]:
model.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
test_files = get_file_list('./audio')

In [8]:
_, test_data = extract_windowed_features(test_files, classes, test_rate=0.2)



windowed_features:  936
test_windowed_features:  240


In [9]:
def predict(model, inputs, target, class_mapping):
    model.eval()
    with torch.no_grad():
        
        inputs = torch.tensor(inputs, dtype=torch.float32)
        inputs = inputs.to(device)
        inputs = inputs.unsqueeze(0)
        predictions = model(inputs)
        predicted_index = predictions[0].argmax(0)
        predicted = list(class_mapping)[predicted_index]
        expected = list(class_mapping)[target]
    return predicted, expected

In [10]:
predicted_list = []
data_right = 0
data_len = 0
for i in range(0, len(test_data), 25):
    inputs, target = test_data[i][0], test_data[i][1]
    predicted, expected = predict(model, inputs, target, classes)
    predicted_list.append(predicted)
    
    predicted_list.append({
        "expected": expected,
        "predicted": predicted
    })
    data_len += 1
    if predicted == expected:
        data_right += 1

print("Accuracy: ", data_right/data_len)

Accuracy:  1.0


In [11]:
for i in range(0, len(predicted_list)):
    print(predicted_list[i])

F#
{'expected': 'F#', 'predicted': 'F#'}
D#m
{'expected': 'D#m', 'predicted': 'D#m'}
Em
{'expected': 'Em', 'predicted': 'Em'}
Bm
{'expected': 'Bm', 'predicted': 'Bm'}
Am
{'expected': 'Am', 'predicted': 'Am'}
Gm
{'expected': 'Gm', 'predicted': 'Gm'}
G#m
{'expected': 'G#m', 'predicted': 'G#m'}
G
{'expected': 'G', 'predicted': 'G'}
Dm
{'expected': 'Dm', 'predicted': 'Dm'}
A
{'expected': 'A', 'predicted': 'A'}
