In [9]:
import random
from operator import attrgetter
from tabulate import tabulate
from music21 import corpus
from scipy.optimize import minimize

from firms.graders import count_grader, log_count_grader, weighted_sum_grader_factory, log_weighted_sum_grader_factory, stem_counter_by_piece, log_weighted_sum_grader_weightless_factory
from firms.stemmers import index_key_by_pitch, index_key_by_simple_pitch, index_key_by_interval, index_key_by_contour, index_key_by_rythm, index_key_by_normalized_rythm
from firms.models import MemoryIRSystem, print_timing, flatten
from firms.sql_irsystems import SqlIRSystem

In [2]:
# Collect the corpus entries for every composer used in this demonstration.
# Each composer is looked up with the 'xml' file-type filter and the
# per-composer results are merged by firms.models.flatten.
print_timing("Loading pieces")
composers = ('bach', 'mozart', 'beethoven', 'schumann')
piece_paths = flatten([corpus.getComposer(composer, 'xml') for composer in composers])

2017-11-06 23:22:00.486892 Loading pieces


In [3]:
# Stemmers used to index every piece, keyed by the label that the weight
# tables below refer to.
index_methods = dict((
    ('By Pitch', index_key_by_pitch),
    ('By Simple Pitch', index_key_by_simple_pitch),
    ('By Contour', index_key_by_contour),
    ('By Interval', index_key_by_interval),
    ('By Rythm', index_key_by_rythm),
    ('By Normal Rythm', index_key_by_normalized_rythm),
))

# Hand-picked weight per index method. The key order matters: the
# optimisation cell iterates over `weights` to define its parameter order.
weights = {
    'By Pitch': 2,
    'By Simple Pitch': 1,
    'By Interval': 0.2,
    'By Contour': 0.1,
    'By Rythm': 0.1,
    'By Normal Rythm': 0.1,
}

# Alternative weight set (includes negative weights).
weights2 = {
    'By Pitch': 4.3,
    'By Simple Pitch': 2.5,
    'By Interval': 3.0,
    'By Contour': -1.94,
    'By Rythm': 1.36,
    'By Normal Rythm': -2.85,
}

# Graders handed to the IR system: one log-weighted-sum grader per weight set.
# Other graders imported at the top of the notebook can be enabled by adding
# entries such as 'Count': count_grader, 'Log Count': log_count_grader or
# 'Linear': weighted_sum_grader_factory(weights).
# NOTE(review): 'LogLinar' looks like a misspelling of 'LogLinear'; it is a
# runtime key, so it is kept verbatim here - confirm before renaming.
log_linear_weight_sets = (
    ('LogLinar', weights),
    ('LogLinear2', weights2),
)
scorer_methods = {
    label: log_weighted_sum_grader_factory(weight_set)
    for label, weight_set in log_linear_weight_sets
}

In [4]:
print_timing("Building IR system")
# In-memory alternative, kept for reference:
# irsystem = MemoryIRSystem(index_methods, scorer_methods, piece_paths)

# NOTE(review): the meaning of the trailing positional False is not visible
# from this notebook - confirm against the SqlIRSystem signature and consider
# passing it by keyword.
sqlsystem = SqlIRSystem('example.db.sqlite', index_methods, scorer_methods, piece_paths, False)

print_timing("Sampling ranges for demonstration")
# Tunables for the demonstration sample.
SAMPLE_SEED = 42        # fixed seed so the same excerpts are drawn on every run
MAX_SAMPLES = 50        # upper bound on the number of pieces sampled
EXCERPT_MEASURES = 5    # each query spans measures idx .. idx + 4 inclusive

# A private generator keeps the sampling reproducible without touching the
# state of the global `random` module.
sample_rng = random.Random(SAMPLE_SEED)

sample_paths = sample_rng.sample(piece_paths, min(MAX_SAMPLES, len(piece_paths)))
# Parsed lazily: each score is only loaded when the loop below reaches it.
sample_pieces = (corpus.parse(piece) for piece in sample_paths)
sample_streams = []   # one notes-and-rests excerpt per sampled piece
sample_details = []   # matching (title, part, first measure index) tuples
for piece in sample_pieces:
    parts = list(piece.recurse().parts)
    if not parts:
        # A score without parts has nothing to excerpt; skip it instead of
        # failing on an empty selection.
        continue
    part = sample_rng.choice(parts)
    num_of_measures = len(part.measures(0, None))
    # Clamp the upper bound: for parts shorter than EXCERPT_MEASURES the
    # original `randint(0, n - 5)` raised ValueError on an empty range.
    # Short parts now start at 0.
    # NOTE(review): assumes part.measures() tolerates an end index past the
    # last measure - confirm against the music21 documentation.
    last_start = max(0, num_of_measures - EXCERPT_MEASURES)
    idx = sample_rng.randint(0, last_start)
    excerpt = part.measures(idx, idx + EXCERPT_MEASURES - 1)
    sample_streams.append(excerpt.recurse().notesAndRests)
    sample_details.append((piece.metadata.title, part, idx))
print_timing("Done")

2017-11-06 23:22:39.892950 Building IR system
2017-11-06 23:22:39.895949 Sampling ranges for demonstration
2017-11-06 23:23:12.261235 Done


In [5]:
# Section banner for the optimisation part of the notebook.
rule = "=========================================================================="
for banner_line in (rule, "Optimization", rule):
    print(banner_line)

print_timing("Building linear model")
# Titles of the sampled pieces, in the same order as the queries below.
tp_names = [detail[0] for detail in sample_details]

# Run every sampled excerpt through the SQL IR system and keep, per query,
# the result of stem_counter_by_piece over the raw results.
matches = []
for (title, sampled_part, _first_measure), excerpt in zip(sample_details, sample_streams):
    print_timing("Querying %s (%s)" % (title, sampled_part.partName), 1)
    matches.append(stem_counter_by_piece(list(sqlsystem.raw_query(excerpt))))
print_timing("Done")

Optimization
2017-11-06 23:23:21.062047 Building linear model
	2017-11-06 23:23:21.062047 Querying bwv248.33-3.mxl (Tenor)
	2017-11-06 23:23:23.338577 Querying bwv421.mxl (Bass)
	2017-11-06 23:23:41.816420 Querying bwv177.5.mxl (Bass)
	2017-11-06 23:23:49.732160 Querying bwv384.mxl (Bass)
	2017-11-06 23:23:52.028255 Querying bwv52.6.mxl (Soprano)
	2017-11-06 23:23:57.104997 Querying bwv103.6.mxl (Alto)
	2017-11-06 23:24:11.711622 Querying bwv245.22.mxl (Alto)
	2017-11-06 23:24:22.339754 Querying bwv153.5.mxl (Tenor)
	2017-11-06 23:24:28.727255 Querying bwv156.6.mxl (Tenor)
	2017-11-06 23:24:35.801834 Querying bwv846.mxl (Piano)
	2017-11-06 23:25:01.840280 Querying movement1.mxl (Viola)
	2017-11-06 23:25:19.486589 Querying bwv287.mxl (Soprano)
	2017-11-06 23:25:21.829644 Querying bwv387.mxl (Tenor)
	2017-11-06 23:25:24.001366 Querying II. Aus meinen Tränen sprießen (MusicXML Part)
	2017-11-06 23:25:24.096866 Querying bwv341.mxl (Alto)
	2017-11-06 23:25:30.262867 Querying bwv270.mxl (Sop

In [46]:
# Fit one weight per index method by minimising the objective built from the
# query results collected above.
#
# The factory is the one imported in the notebook's import cell. The former
# mid-notebook `reload(graders)` (and the unused `import firms`) were
# development leftovers; to pick up live edits of firms.graders, enable
# `%autoreload 2` in a configuration cell instead.

# Freeze the parameter order once, so the objective, the initial guess, the
# bounds and the reported result all agree on which entry is which weight.
weight_names = list(weights)

to_optimize = log_weighted_sum_grader_weightless_factory(matches, tp_names, weight_names)
optimized = minimize(
    to_optimize,
    [1.0] * len(weight_names),                   # start from equal weights
    method='L-BFGS-B',
    # No analytic gradient is supplied, so `eps` is the step size used for
    # the numerical gradient approximation.
    options={'eps': 1e-02},
    bounds=[(0.01, None)] * len(weight_names),   # lower bound 0.01, no upper bound
)
print(optimized)

# Report the fitted weights next to the index method they belong to.
# NOTE(review): the recorded run stopped after a single iteration with an
# all-zero Jacobian, which suggests the objective is flat at this step size -
# confirm that the optimiser is actually making progress.
optimized_weights = dict(zip(weight_names, optimized.x))
print(optimized_weights)

			Ignoring II. Aus meinen Tränen sprießen
			Ignoring String Quartet
      fun: 4.3125
 hess_inv: <6x6 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 0.,  0.,  0.,  0.,  0.,  0.])
  message: b'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 21
      nit: 1
   status: 0
  success: True
        x: array([ 4.38148063,  4.38148063,  2.12716021,  0.46437347,  0.46437347,
        0.46437347])
