In [279]:
import numpy as np
import pretty_midi
import torch
from collections import deque,defaultdict
from functools import reduce
import json

In [280]:
def melody_to_numpy(fpath="ashover12.mid", unit_time=0.125):
    # 首先取出midi中的音乐信息，将第一个音轨提出来，然后取所有音符元素。
    music = pretty_midi.PrettyMIDI(fpath)
    notes = music.instruments[0].notes

    # 纪录一个时间戳 t和保存向量的列表 roll。
    t = 0.
    roll = list()
    # print(notes[0], notes[-1])

    # 在 Notes里遍历，找出所有音符的音高和长度，同时不放过休止符。
    for note in notes:
        # print(t, note)

        # 两个相邻的音符不是无缝连接，说明休止符存在。计算其相对于最小分辨率 unit_time的相对时长 T，建立一个(T, 130)的矩阵，将第129维置1.
        elapsed_time = note.start - t
        if elapsed_time > 0.:
            steps = torch.zeros((int(round(elapsed_time / unit_time)), 130))
            steps[range(int(round(elapsed_time / unit_time))), 129] += 1.
            roll.append(steps)

        # 如果是无缝连接，那么检查当前音符：
        n_units = int(round((note.end - note.start) / unit_time))
        steps = torch.zeros((n_units, 130))
        steps[0, note.pitch] += 1
        steps[range(1, n_units), 128] += 1

        # 其中除第一列记录pitch外，其他列都记录sustain的128.最后合成为一个矩阵：
        roll.append(steps)
        t = note.end
    return torch.cat(roll, 0)

In [281]:
def numpy_to_midi(sample_roll, output='sample.mid'):
    music = pretty_midi.PrettyMIDI()
    piano_program = pretty_midi.instrument_name_to_program(
        'Acoustic Grand Piano')
    piano = pretty_midi.Instrument(program=piano_program)
    t = 0
    for i in sample_roll:
        if 'torch' in str(type(i)):
            pitch = int(i.max(0)[1])
        else:
            pitch = int(np.argmax(i))
        if pitch < 128:
            note = pretty_midi.Note(
                velocity=100, pitch=pitch, start=t, end=t + 1 / 8)
            t += 1 / 8
            piano.notes.append(note)
        elif pitch == 128:
            if len(piano.notes) > 0:
                note = piano.notes.pop()
            else:
                p = np.random.randint(60, 72)
                note = pretty_midi.Note(
                    velocity=100, pitch=int(p), start=0, end=t)
            note = pretty_midi.Note(
                velocity=100,
                pitch=note.pitch,
                start=note.start,
                end=note.end + 1 / 8)
            piano.notes.append(note)
            t += 1 / 8
        elif pitch == 129:
            t += 1 / 8
    music.instruments.append(piano)
    music.write(output)

In [282]:
def merge_sorted_arrays(a1,a2):
    result = deque()
    a1 = deque(a1)
    a2 = deque(a2)
    while a1 and a2:
        if tuple(a1[0][0]) < tuple(a2[0][0]):
            result.append(a1.popleft())
        elif tuple(a1[0][0]) == tuple(a2[0][0]) and tuple(a1[0][1]) < tuple(a2[0][1]):
            result.append(a1.popleft())
        else:
            result.append(a2.popleft())
    result += a1
    result += a2
    return np.array(result)

In [283]:
def group_by_first(input):
    result = defaultdict(list)
    for k,v in input:
        result[str(k)].append(v)
    return result

In [284]:
def intersect_sorted_arrays(a1,a2):
    result = deque()
    a1 = deque(a1)
    a2 = deque(a2)
    while a1 and a2:
        if tuple(a1[0]) == tuple(a2[0]):
            result.append(a1.popleft())
        elif tuple(a1[0]) < tuple(a2[0]):
            a1.popleft()
        else:
            a2.popleft()
    return np.array(result)

In [285]:
def json_pattern(points,vecs,occs):
    return{
        'points':[p.tolist() for p in points],
        'vectors':[v.tolist() for v in vecs],
        'occurrences':np.array(occs).tolist()
    }

In [286]:
def to_json(points,pats,vecs,occs):
    patterns = [json_pattern(p,vecs[i],occs[i]) for i,p in enumerate(pats)]
    return {'points':points.tolist,'patterns':patterns}

In [287]:
if __name__ == "__main__":
    print("table:")
    music_path = '../dataset/梁祝.mid'
    midiarray = melody_to_numpy(music_path)
    
    # midi_batch = []
    # for i in midiarray:
    #     num = 0
    #     for j in i:
    #         if j == 1:
    #             midi_batch.append(num)
    #             num = 0
    #             break
    #         num += 1
    
    # numpy_to_midi(midiarray)
    # print(midiarray)
    points = np.unique(midiarray,axis = 0)
    # print(points)
    vector_table = [[(q - p, p) for q in points] for p in points]
    # print(vector_table)
    half_table = [r[i+1:] for i,r in enumerate(vector_table) if i < len(r) - 1]
    # print(half_table)
    print("merge:")
    table_list = reduce(merge_sorted_arrays, half_table)
    # print(table_list)
    print("group:")
    patterns = group_by_first(table_list).values()
    # print(patterns)
    pdict = {str(p): i for i,p in enumerate(points)}
    # print(pdict)
    simple_table = [[r[0] for r in c] for c in vector_table]
    # print(simple_table)
    tsls = [[simple_table[pdict[str(o)]] for o in p] for p in patterns]
    # print(tsls)
    print("intersect:")
    vectors = [reduce(intersect_sorted_arrays, ts) for ts in tsls]
    # print(vectors)
    print("json:")
    # for i,v in enumerate(vectors):
    #     print(len(v))
    #     for p in patterns:
    #         print(len(p))
    #         occurrences = [p+v]

    # print(occurrences)
    occurrences = [[p+v for p in patterns[i]] for i,v in enumerate(vectors)]

    # result = to_json(points,patterns,vectors,occurrences)
    # print(result)    
    result = to_json(points,patterns,vectors,occurrences)
    print("save:")
    with open("test.json","w") as outfile:
        json.dump(result,outfile)

table:
merge:
group:
intersect:
json:


TypeError: 'dict_values' object does not support indexing