In [None]:
import os
import pickle
from collections import defaultdict
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import mido
from mido import MidiFile, MidiTrack, Message

import notochord

# show floats with 8 decimals in DataFrame/Series displays
pd.options.display.float_format = '{:0.8f}'.format

# load data

In [None]:
# Directory containing the MRP recording (data.log). Set the MRP_DATA_DIR
# environment variable rather than editing this hard-coded placeholder.
data_dir = os.environ.get('MRP_DATA_DIR', "/path/to/mrp_data")

In [None]:
def ext_files_in_dir(directory, ext=".txt"):
    """Return the names (not full paths) of entries in `directory` ending with `ext`."""
    return [name for name in os.listdir(directory) if name.endswith(ext)]

def mrp_txt_to_df(file):
    """Parse an MRP text log into a DataFrame.

    Each whitespace-separated line holds a time, an OSC address, an OSC type-tag
    string and three numeric values.

    Args:
        file: path or file-like object accepted by ``pd.read_csv``.
    Returns:
        DataFrame with columns time, osc, types, v0, v1, v2 (time/v* as float).
    """
    return pd.read_csv(
        file,
        names=('time', 'osc', 'types', 'v0', 'v1', 'v2'),
        converters={
            'time': float, 'osc': str, 'types': str,
            'v0': float, 'v1': float, 'v2': float},
        # raw string: a plain '\s' literal is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+)
        sep=r'\s+')

def save_pkl(data, file):
    """Serialise `data` into the file at path `file` using pickle."""
    with open(file, mode='wb') as handle:
        pickle.dump(data, handle)

def load_pkl(file):
    """Deserialise and return the object pickled in `file`.

    Only load files you created yourself: unpickling can execute arbitrary code.
    """
    with open(file, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj

In [None]:
# the raw recording: one row per OSC message
df = mrp_txt_to_df(f"{data_dir}/data.log")

In [None]:
# how many messages of each OSC address the recording contains
osc_counts = df.osc.value_counts()
osc_types = osc_counts.index
osc_counts

# events over time by MIDI number

In [None]:
def scatter(osc,x,y,v0=None,v2=0,s=2,alpha=0.03, ax=plt, df=df, **kw):
    """Scatter-plot the rows of `df` whose OSC address equals `osc`.

    Args:
        osc: OSC address to select.
        x, y: column names for the horizontal / vertical coordinates.
        v0: if given, additionally keep only rows whose v0 equals it.
        v2: multiplier; the plotted y is ``df[y] + df.v2 * v2``.
        s, alpha: marker size and opacity.
        ax: object providing ``.scatter`` (pyplot module or an Axes).
        df: table to plot; note the default is bound when this cell runs.
        **kw: forwarded to ``ax.scatter``.
    """
    mask = df.osc == osc
    if v0 is not None:
        mask = mask & (df.v0 == v0)
    sel = df.loc[mask]
    y_vals = sel[y] + sel.v2 * v2
    ax.scatter(sel[x], y_vals, alpha=alpha, marker='.', s=s, ls='', **kw)

In [None]:
fig, ax = plt.subplots(figsize=(80,44))
# MIDI notes as large dots: status 143 (noteOff) in black, 159 (noteOn) default colour
scatter('/mrp/midi', 'time', 'v1', v0=143, s=32, alpha=1, c='k', ax=ax)
scatter('/mrp/midi', 'time', 'v1', v0=159, s=32, alpha=1, ax=ax)
# scatter('/mrp/midi','time','v1', v0=176, s=512, alpha=1, c='y', ax=ax)
# expression streams, drawn at the note's pitch offset by a scaled value
for address, offset_scale, color in [
        ('/mrp/quality/intensity', 0.5, 'r'),
        ('/mrp/quality/pitch/vibrato', 1, 'b'),
        ('/mrp/quality/brightness', 1/10, 'g'),
        ('/mrp/quality/harmonic', 1, 'm'),
]:
    scatter(address, 'time', 'v1', v2=offset_scale, alpha=0.05, ax=ax, c=color)
# plt.savefig(data_dir+'/name.png')

# extract notes into dictionary

In [None]:
# keep only noteOn/noteOff messages (status bytes up to 159)
is_note_msg = (df.osc == '/mrp/midi') & (df.v0 <= 159)
notes_onoff = df.loc[is_note_msg]
# notes are first collected per MIDI pitch, one fresh list per pitch
notes_by_pitch = {int(p): [] for p in notes_onoff.v1.unique()}
# whether a note is currently sounding at each pitch (validates on/off pairing)
note_is_on = dict.fromkeys(notes_by_pitch, False)

In [None]:
# sanity check: timestamps never decrease
df.time.diff().fillna(0).ge(0).all()

In [None]:
# Single pass over the recording: pair noteOn/noteOff messages per pitch and
# attach every expression sample to the note currently sounding at its pitch.
# NOTE: appends to notes_by_pitch / mutates note_is_on, so re-run the previous
# cell before re-running this one.
n_orphan_expr = 0  # expression samples with no sounding note at their pitch
for _, row in tqdm(df.iterrows(), total=len(df)):
    pitch = int(row.v1)
    osc = row.osc.split('/')[-1]  # short label for event type
    t = row.time
    if osc == 'midi':
        if row.v0 == 159:  # noteOn
            if note_is_on[pitch]:
                tqdm.write(f'bad noteOn {pitch} at {t}')
                continue
            note_is_on[pitch] = True
            # start a new note; expression keys default to empty lists
            note = defaultdict(list)
            note['start_time'] = t
            notes_by_pitch[pitch].append(note)
        elif row.v0 == 143:  # noteOff
            if not note_is_on[pitch]:
                tqdm.write(f'bad noteOff {pitch} at {t}')
                continue
            note_is_on[pitch] = False
            note = notes_by_pitch[pitch][-1]
            # freeze expression curves as pd.Series (values indexed by time
            # relative to note start)
            for k in set(note) & {
                    'intensity', 'vibrato', 'brightness', 'harmonic'}:
                values, times = zip(*note[k])
                note[k] = pd.Series(values, index=times)
            note['end_time'] = t
        else:
            # skip any non-note MIDI
            tqdm.write(f'skip {row.osc} {row.v0}')
    else:  # OSC expression data
        if not note_is_on.get(pitch, False):
            # No sounding note at this pitch (unknown pitch, before the first
            # noteOn, or after noteOff). The sample cannot be attributed, and a
            # closed note's curves are already frozen Series.
            n_orphan_expr += 1
            continue
        note = notes_by_pitch[pitch][-1]
        # use times relative to note start
        note[osc].append((row.v2, t - note['start_time']))

if n_orphan_expr:
    print(f'skipped {n_orphan_expr} expression samples with no sounding note')

# validate that all notes ended; an unterminated note has no end_time and
# unconverted curves, which would break the duration computations below
for k, v in note_is_on.items():
    if v:
        print(f'note {k} not closed; dropping it')
        notes_by_pitch[k].pop()

In [None]:
# duration of every extracted note, in pitch order
durs = pd.Series([
    n['end_time'] - n['start_time']
    for seq in notes_by_pitch.values()
    for n in seq])
durs.describe()

In [None]:
# notes per pitch; show (number of distinct pitches, total notes)
note_counts = {pitch: len(seq) for pitch, seq in notes_by_pitch.items()}
n_pitches, n_notes = len(note_counts), sum(note_counts.values())
n_pitches, n_notes

In [None]:
def plot_all_curves(k, ylim=(-0.05,1.1), xlim=(1e-3,3e1), alpha=0.1):
    """Overlay expression curve `k` of every recorded note on one log-time axis.

    Args:
        k: expression key, e.g. 'intensity', 'brightness', 'vibrato', 'harmonic'.
        ylim, xlim: axis limits; x is time since note start on a log scale.
        alpha: opacity of each curve.

    Notes without a `k` curve (no pd.Series stored) are skipped.
    """
    fig, ax = plt.subplots(figsize=(24,16))
    for p in notes_by_pitch:
        for note in notes_by_pitch[p]:
            # .get: indexing the defaultdict would insert an empty list into
            # every note that lacks this curve
            curve = note.get(k)
            if isinstance(curve, pd.Series):
                curve.plot(logx=True, alpha=alpha, ax=ax, ylim=ylim, xlim=xlim)
    ax.set(title=f'{k} curves of all recorded notes',
           xlabel='time since note start', ylabel=k)

In [None]:
plot_all_curves(k='intensity')

In [None]:
plot_all_curves(k='brightness', ylim=(0, 10))

In [None]:
plot_all_curves(k='vibrato', ylim=(-1.1, 1.1))

In [None]:
plot_all_curves(k='harmonic', ylim=(0, 2), alpha=0.5)

In [None]:
# Flatten notes_by_pitch into a dict keyed by (jittered) start time, setting
# aside very short notes. NOTE: not idempotent -- re-running jitters times again.
MIN_NOTE_DUR = 5e-2  # notes shorter than this (in log time units) are set aside
jitter_rng = np.random.default_rng(0)  # seeded so the jitter is reproducible
all_times = []
notes_by_time = {}
short_notes = []
for p in notes_by_pitch:
    for note in notes_by_pitch[p]:
        note['pitch'] = p
        dur = note['end_time'] - note['start_time']
        if dur < MIN_NOTE_DUR:
            short_notes.append(note)
            continue
        # Jitter BEFORE keying: notes that start together (chords) would
        # otherwise share a key and silently overwrite each other.
        # The jitter also prevents exactly simultaneous events later on.
        shift = jitter_rng.random()*1e-5
        note['start_time'] += shift
        note['end_time'] += shift + jitter_rng.random()*1e-5
        if note['start_time'] in notes_by_time:
            raise ValueError(f"duplicate note start time {note['start_time']}")
        notes_by_time[note['start_time']] = note
        all_times.append(note['start_time'])
        all_times.append(note['end_time'])
len(all_times), len(set(all_times))

In [None]:
# number of notes set aside as too short; for a per-pitch breakdown use
# pd.Series([note['pitch'] for note in short_notes]).value_counts()
len(short_notes)

In [None]:
# durations of the notes kept in notes_by_time
kept_notes = list(notes_by_time.values())
durs = (pd.Series([n['end_time'] for n in kept_notes])
        - pd.Series([n['start_time'] for n in kept_notes]))
durs.describe()

In [None]:
# distribution of the shortest kept notes
durs.loc[durs < 0.1].plot.hist(bins=100)

# Expression → velocity


In [None]:
# Compute a velocity score (1..127) for every kept note from its expression
# curves and store it as note['vel_score'].
intens = []
harm = []
bright = []
vib = []
def reduce(s):
    """Mean of an expression curve, or 0 when the note has no such curve."""
    return s.mean() if isinstance(s, pd.Series) else 0
for note in tqdm(notes_by_time.values()):
    # .get avoids inserting empty-list keys into the defaultdict notes
    intens.append(reduce(note.get('intensity')))
    harm.append(reduce(note.get('harmonic')))
    bright.append(reduce(note.get('brightness')))
    vib.append(reduce(note.get('vibrato')))

score = np.array(intens) + np.array(harm) + np.abs(np.array(vib)) + np.array(bright)**0.5/8

n_scores = len(score)
score_max = score.max() if n_scores else 0
# scale by the maximum; an all-zero (or empty) score vector would divide by 0
score_scale = score / score_max if score_max > 0 else np.zeros_like(score)
# rank in [0, 1]; a single note gets rank 0 instead of 0/0
score_rank = np.argsort(np.argsort(score)) / max(n_scores - 1, 1)
# mix of both, mapped onto the MIDI velocity range 1..127
score_mix = ((score_scale + score_rank)/2)**0.5 * 126 + 1
# distinct loop name so the `score` array is not overwritten by a scalar
for note, vel_score in zip(notes_by_time.values(), score_mix):
    note['vel_score'] = vel_score

In [None]:
# cache the preprocessed notes so later sessions can resume below
save_pkl(data=notes_by_time, file='notes_by_time_gig.pkl')

# Resume from here if preprocessing has already been done (loads the pickled notes)

In [None]:
# alternative dataset: notes_by_time = load_pkl('notes_by_time.pkl')
# (only unpickle files you produced yourself)
notes_by_time = load_pkl(file='notes_by_time_gig.pkl')

# Notochord

In [None]:
NOTOCHORD_CKPT = 'notochord-latest.ckpt'
model = notochord.Notochord.from_checkpoint(NOTOCHORD_CKPT)
model.reset()

In [None]:
# Convert notes to Notochord events keyed by absolute time:
# noteOn uses the velocity score computed above, noteOff has velocity 0.
# `inst` is module level on purpose: noto_continue / noto_variation read it.
inst = 1
events_by_time = {}
for note in notes_by_time.values():
    pitch = note['pitch']
    for key in (note['start_time'], note['end_time']):
        if key in events_by_time:
            raise ValueError(f'two events at identical time {key}')
    events_by_time[note['start_time']] = {
        'pitch': pitch,
        'inst': inst,
        'vel': max(1, note['vel_score'])
    }
    events_by_time[note['end_time']] = {
        'pitch': pitch,
        'inst': inst,
        'vel': 0
    }

# delta time since the previous event (the first event: since time 0)
t = 0
for k in sorted(events_by_time):
    events_by_time[k]['time'] = k - t
    t = k

# set of performed pitches, taken from notes_by_time so this also works when
# resuming from the pickle (note_counts only exists on the full path)
all_pitch = {note['pitch'] for note in notes_by_time.values()}

In [None]:
# Validate that no event was lost: every kept note contributes exactly one
# noteOn and one noteOff. (Dict keys are unique by construction, so comparing
# len(keys) with len(set(keys)) would be vacuous.)
assert len(events_by_time) == 2 * len(notes_by_time), (
    f'{len(events_by_time)} events for {len(notes_by_time)} notes')

In [None]:
# noteOn (True) vs noteOff (False) counts should be equal
pd.Series([ev['vel'] >= 0.5 for ev in events_by_time.values()]).value_counts()

In [None]:
# summary of inter-event delta times
deltas = list(map(lambda ev: ev['time'], events_by_time.values()))
pd.Series(deltas).describe()

### write MIDI

In [None]:
# function to dump notochord events to MIDI file
def to_mid(events, file='output.mid', pc=None):
    """Write Notochord-style events to a single-track MIDI file.

    Args:
        events: dict with sortable time keys; each value has 'pitch', 'vel'
            (< 0.5 means noteOff) and 'time' (delta since the previous event,
            in seconds).
        file: output path.
        pc: optional program number sent as a program_change first.
    Returns:
        The saved ``MidiFile``.
    """
    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)
    if pc is not None:
        track.append(Message('program_change', program=pc))
    # no set_tempo message is written, so assume the default 500000 us/beat
    ticks_per_second = mid.ticks_per_beat / (500000 / 1000000)
    elapsed = 0.0   # seconds since start, accumulated from the deltas
    last_tick = 0
    for k in tqdm(sorted(events)):
        event = events[k]
        elapsed += event['time']
        # quantise the absolute time and take differences, so per-event
        # truncation error does not accumulate into drift
        tick = int(round(elapsed * ticks_per_second))
        delta_ticks = max(0, tick - last_tick)
        last_tick += delta_ticks
        if event['vel'] < 0.5:
            track.append(Message(
                'note_off', note=event['pitch'], velocity=100, time=delta_ticks))
        else:
            # MIDI velocities must lie in 1..127 (mido rejects larger values)
            vel = min(127, max(1, int(event['vel'] + 0.5)))
            track.append(Message(
                'note_on', note=event['pitch'], velocity=vel, time=delta_ticks))
    mid.save(file)
    return mid

In [None]:
# To write the data as a MIDI file for model training etc., use instead:
# data_mid = to_mid(events_by_time, 'training/data.mid', pc=0)
data_mid = to_mid(events_by_time, file='test.mid')

## continuation

In [None]:
def noto_continue(
        max_note_len = 10,
        total_notes = 1000,
        max_heat = 45,
        ):
    """Prompt Notochord with the recorded performance, then generate freely.

    Reads the module-level ``model``, ``events_by_time``, ``all_pitch`` and ``inst``.

    Args:
        max_note_len: time a generated note may be held before it is forced
            into the set of noteOff candidates.
        total_notes: number of noteOns to generate. Afterwards only noteOffs
            are allowed, until every held note has been released.
        max_heat: "heat safety" limit. Each pitch accumulates the time it is
            held and cools by elapsed time while released (never below 0).
            A pitch above this value may not start and is forced off.
    Returns:
        dict mapping absolute time (accumulated from the generated deltas,
        starting at 0) to the generated event.
    """
    # feed the data as a prompt
    model.reset()
    for k in tqdm(sorted(events_by_time), desc='prompt'):
        model.feed(**events_by_time[k])

    # free generation
    gen_events = {}
    t = 0
    held_pitch_starts = {}  # pitch -> absolute start time of its held note
    note_count = 0
    cumulative_on = {p: 0 for p in all_pitch}  # per-pitch heat
    for _ in tqdm(range(total_notes*2), desc='generation'):
        held_pitch = set(held_pitch_starts)
        hot_pitches = {p for p, heat in cumulative_on.items() if heat > max_heat}

        if note_count >= total_notes:
            # note budget spent: only release held notes, stop once none remain
            if not held_pitch:
                break
            on_map = {inst: set()}
        else:
            on_map = {inst: all_pitch - held_pitch - hot_pitches}

        # notes held too long (or overheating) must be ended first
        long_note_ps = {
            p for p, st in held_pitch_starts.items()
            if (t - st > max_note_len) or p in hot_pitches}
        off_map = {inst: long_note_ps if long_note_ps else held_pitch}

        event = model.query_vipt(
            note_on_map=on_map,
            note_off_map=off_map,
            # min_time=0.001,
            max_time=16,
            truncate_quantile_time=(0.25, 0.9),
            steer_density=0.55,
            )
        model.feed(**event)
        t = max(event['time'], 1e-5) + t
        gen_events[t] = event

        # update heat for the interval that just elapsed, using the held set as
        # it was before this event (separate loop name keeps `p` intact)
        for hp in cumulative_on:
            if hp in held_pitch_starts:
                cumulative_on[hp] = cumulative_on[hp] + event['time']
            else:
                cumulative_on[hp] = max(0, cumulative_on[hp] - event['time'])

        p = event['pitch']
        if event['vel'] >= 0.5:
            note_count += 1
            held_pitch_starts[p] = t
        else:
            held_pitch_starts.pop(p)

    return gen_events

In [None]:
# gen_events = noto_continue()
# mid = to_mid(gen_events)
# # duration in seconds and events
# max(gen_events), len(gen_events)

## variations

In [None]:
def noto_variation(
        max_note_len = 7,
        pitch_temp = 0.85,
        truncate_quantile_time=(0.1,None),
        max_heat = 45,
    ):
    """Prompt Notochord with the performance, then re-generate it event by event.

    Each recorded event is replaced by a sampled event of the same kind
    (noteOn/noteOff). NoteOns stay within an octave of the performed pitch,
    within +-16 of its velocity and near its timing.
    Reads the module-level ``model``, ``events_by_time``, ``all_pitch`` and ``inst``.

    Args:
        max_note_len: time a generated note may be held before it is forced
            into the set of noteOff candidates.
        pitch_temp, truncate_quantile_time: sampling options passed to
            ``model.query_feed``.
        max_heat: "heat safety" limit. Each pitch accumulates the time it is
            held and cools by elapsed time while released (never below 0).
            A pitch above this value may not start and is forced off.
    Returns:
        dict mapping absolute time (accumulated from the generated deltas,
        starting at 0) to the generated event.
    Raises:
        ValueError: if no pitch is available for a required noteOn.
    """
    # feed the data as a prompt
    model.reset()
    for k in tqdm(sorted(events_by_time), desc='prompt'):
        model.feed(**events_by_time[k])

    temp_kw = dict(
        pitch_temp = pitch_temp,
        truncate_quantile_time=truncate_quantile_time,
    )
    gen_events = {}
    t = 0

    held_pitch_starts = {}  # pitch -> absolute start time of its held note
    cumulative_on = {p: 0 for p in all_pitch}  # per-pitch heat
    for k in tqdm(sorted(events_by_time), desc='generation'):
        # the performed event to imitate
        data_event = events_by_time[k]

        hot_pitches = {p for p, heat in cumulative_on.items() if heat > max_heat}
        held_pitch = set(held_pitch_starts)

        if data_event['vel'] >= 0.5:
            # noteOn: prefer pitches within an octave of the performed one
            valid_pitch = (
                (set(range(data_event['pitch']-12, data_event['pitch']+13)) & all_pitch)
                # (all_pitch - {data_event['pitch']}) # inverse pitch
                - held_pitch
                - hot_pitches
                )
            if not valid_pitch:
                # escape hatch if all pitches are excluded above
                valid_pitch = all_pitch - held_pitch - hot_pitches
            if not valid_pitch:
                raise ValueError(f"""
                    can't start any more notes: 
                    {held_pitch=} 
                    {hot_pitches=} 
                    {all_pitch=}""")

            # with exactly one candidate, force it rather than restricting to a set
            if len(valid_pitch) == 1:
                kw = dict(next_pitch=next(iter(valid_pitch)))
            else:
                kw = dict(include_pitch=valid_pitch)

            # reduce long gaps
            gap = min(data_event['time'], 2)

            gen_event = model.query_feed(
                next_inst=inst,
                min_time=gap/2,
                max_time=gap*1.5,
                min_vel=max(1, data_event['vel']-16),
                max_vel=min(127, data_event['vel']+16),
                **kw,
                **temp_kw
            )
        else:
            # noteOff: end over-long/overheating notes first, else any held note
            long_note_ps = {
                p for p, st in held_pitch_starts.items()
                if (t - st > max_note_len) or p in hot_pitches}
            valid_pitch = long_note_ps if long_note_ps else held_pitch
            if len(valid_pitch) == 1:
                kw = dict(next_pitch=next(iter(valid_pitch)))
            else:
                kw = dict(include_pitch=valid_pitch)

            gen_event = model.query_feed(
                next_inst=inst,
                next_vel=0,
                min_time=data_event['time']/2,
                max_time=data_event['time']*1.5,
                **kw,
                **temp_kw
            )

        t = max(gen_event['time'], 1e-5) + t
        gen_events[t] = gen_event
        # update heat; use a separate loop name so the event pitch is not
        # overwritten (previously `p` ended as the last key of cumulative_on,
        # so the wrong pitch was marked held / popped)
        for hp in cumulative_on:
            if hp in held_pitch_starts:
                cumulative_on[hp] = cumulative_on[hp] + gen_event['time']
            else:
                cumulative_on[hp] = max(0, cumulative_on[hp] - gen_event['time'])
        # update held notes
        p = gen_event['pitch']
        if gen_event['vel'] >= 0.5:
            held_pitch_starts[p] = t
        else:
            held_pitch_starts.pop(p)

    return gen_events

In [None]:
gen_events = noto_variation()
mid = to_mid(gen_events)
# (time of the last event, number of generated events)
end_time_s, n_events = max(gen_events), len(gen_events)
end_time_s, n_events

## Events → notes with durations

In [None]:
# Pair generated noteOns with their noteOffs: each noteOn event gains
# 'abstime' and 'duration' (mutated in place) and is collected in gen_notes.
held_notes = {}
gen_notes = {}
for onset, ev in gen_events.items():
    pitch = ev['pitch']
    if ev['vel'] < 0.5:
        started = held_notes.pop(pitch)
        started['duration'] = onset - started['abstime']
        continue
    ev['abstime'] = onset
    gen_notes[onset] = ev
    held_notes[pitch] = ev


## notes back to curves

In [None]:
def resamp_hybrid(s, target_len, attack=0.2, expression_sr=100):
    """Resample an expression curve to last `target_len`.

    The first `attack` time units keep their original timing. The rest of the
    curve is stretched or compressed so that its last sample lands on
    `target_len`. The output is sampled on a uniform grid of
    ``int(target_len * expression_sr)`` points over [3e-3, target_len - 3e-3].

    Args:
        s: pd.Series of values indexed by increasing time since note start.
            Anything else (e.g. a missing curve) is returned unchanged.
        target_len: desired duration.
        attack: length of the initial, unstretched segment.
        expression_sr: output samples per time unit.
    Returns:
        The resampled pd.Series indexed by the new time grid. Returns `s` itself
        if it is not a Series or is empty.
    """
    if not isinstance(s, pd.Series) or s.empty:
        return s
    src_t = s.index.to_numpy(dtype=float)
    max_t = src_t.max()
    # map [attack, max_t] onto [attack, target_len]; if either span is
    # degenerate there is no tail to stretch, so keep the original timing
    if max_t > attack and target_len > attack:
        stretch = (target_len - attack) / (max_t - attack)
    else:
        stretch = 1
    # element-wise so every value stays aligned with its own timestamp
    mod_t = np.where(src_t < attack, src_t, (src_t - attack) * stretch + attack)
    new_t = np.linspace(3e-3, target_len - 3e-3, int(target_len * expression_sr))
    return pd.Series(np.interp(new_t, mod_t, s.to_numpy(dtype=float)), index=new_t)

In [None]:
data_notes = list(notes_by_time.values())
# velocity score of every recorded note, aligned index-for-index with data_notes
vel_scores = np.array([dn['vel_score'] for dn in data_notes])

In [None]:
# expression curves rendered for each generated note: note key -> OSC address
CURVE_OSC = {
    'intensity': '/mrp/quality/intensity',
    'brightness': '/mrp/quality/brightness',
    'vibrato': '/mrp/quality/pitch/vibrato',
    'harmonic': '/mrp/quality/harmonic',
}

def get_curve_df(curve, osc, t, pitch=None):
    """Turn one resampled expression curve into MRP log rows.

    Args:
        curve: pd.Series of values indexed by time relative to the note start.
        osc: OSC address written into every row.
        t: absolute start time of the note.
        pitch: MIDI pitch written to v1. Defaults to the module-level
            ``note['pitch']`` (the previous behaviour).
    Returns:
        DataFrame with the same columns as the recorded log ``df``.
    """
    if pitch is None:
        pitch = note['pitch']
    return pd.DataFrame({
        'time': curve.index + t,   # this curve's own timestamps
        'osc': osc,
        'types': 'iif',
        'v0': 15,
        'v1': pitch,
        'v2': curve.values,
    }, columns=df.columns)

df_out = []
for t, note in tqdm(gen_notes.items()):
    # borrow expression from the recorded note with the closest velocity score
    data_note = data_notes[np.argmin(np.abs(note['vel'] - vel_scores))]
    for key, osc_addr in CURVE_OSC.items():
        # .get: missing curves are skipped without mutating the data note
        curve = resamp_hybrid(data_note.get(key), note['duration'])
        if isinstance(curve, pd.Series):
            df_out.append(get_curve_df(curve, osc_addr, t, note['pitch']))

df_out = pd.concat(df_out) if df_out else pd.DataFrame(columns=df.columns)
df_out

## Events → DataFrame

In [None]:
# convert notochord events to a MIDI DataFrame
def gen_events_to_df(gen_events, columns):
    """Convert generated events into MRP log rows of MIDI note messages.

    Events with ``vel >= 0.5`` (the noteOn threshold used throughout this
    notebook) become noteOn rows (status 159, velocity 127). All others become
    noteOff rows (status 143, velocity 0).

    Args:
        gen_events: dict mapping absolute time to an event with 'pitch', 'vel'.
        columns: the six log column names (time, osc, types, v0, v1, v2).
    Returns:
        DataFrame with one row per event, in `gen_events` order.
    """
    rows = []
    for t, event in gen_events.items():
        if event['vel'] >= 0.5:
            rows.append([t, '/mrp/midi', 'iii', 159, event['pitch'], 127])
        else:
            rows.append([t, '/mrp/midi', 'iii', 143, event['pitch'], 0])
    return pd.DataFrame(rows, columns=columns)

gen_df = gen_events_to_df(gen_events, columns=df.columns)
gen_df

In [None]:
# store MIDI velocities as text; presumably so they keep integer formatting
# after concatenation with float curve values (NOTE(review): confirm intent)
gen_df['v2'] = gen_df['v2'].astype(str)

In [None]:
# merge expression rows with the MIDI rows in time order
all_df = pd.concat([df_out, gen_df]).sort_values('time')
all_df

In [None]:
# dataframe to MRP log file
def df_to_mrp(df: pd.DataFrame, file=None):
    """Write a DataFrame as an MRP log, one space-separated row per line.

    Times are rounded to 5 decimals and numeric v2 values to 6 decimals. The
    input DataFrame is not modified.

    Args:
        df: rows with the log columns (time, osc, types, v0, v1, v2).
        file: output path; defaults to a timestamped ``mrp_YYYY_MM_DD-HHMMSS.log``.
    Returns:
        List of the written lines, without trailing newlines.
    """
    if file is None:
        file = f'mrp_{datetime.now().strftime("%Y_%m_%d-%H%M%S")}.log'
    out = df.copy()  # round a copy instead of mutating the caller's frame
    out['time'] = np.round(out['time'], 5)
    if pd.api.types.is_numeric_dtype(out['v2']):
        out['v2'] = np.round(out['v2'], 6)
    else:
        # mixed object column (e.g. velocities stored as str next to float
        # curve values): round only the numeric entries
        out['v2'] = out['v2'].map(
            lambda x: round(x, 6) if isinstance(x, (float, np.floating)) else x)
    rows = [' '.join(row.astype(str)) for _, row in out.iterrows()]
    with open(file, 'w') as f:
        for row_str in rows:
            f.write(row_str + '\n')
    return rows

In [None]:
rows = df_to_mrp(df=all_df)

In [None]:
# first and last lines written to the log
first_row, last_row = rows[0], rows[-1]
first_row, last_row