In [1]:
import numpy as np
import pandas as pd
import pretty_midi
import collections
from pathlib import Path
import fluidsynth

from IPython import display

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

# load data

In [4]:
def get_df_meta(datadir: str):
    files = collections.defaultdict(list)
    for filepath in datadir.glob('*.mid'):
        files['file'].append(str(filepath))
        composer = filepath.stem.split('_')[0]
        files['composer'].append(composer)

        pm = pretty_midi.PrettyMIDI(str(filepath))
        files['end_time'].append(pm.get_end_time())


        tempos, probabilities = pm.estimate_tempi()
        assert np.isclose(sum(probabilities), 1)
        tempo_bpm = np.dot(tempos, probabilities) # expected tempo in beats/min
        seconds_per_beat = (1/tempo_bpm)*60
        time_sig_denom = pm.time_signature_changes[0].denominator
        note_dist = seconds_per_beat / (16 / time_sig_denom)
        note_dist *= 1

        files['expected_tempo'].append(tempo_bpm)
        files['sampling_note_duration'].append(note_dist)
        roll = pm.get_piano_roll(fs=1/note_dist)
        files['roll_length'].append(roll.shape[1])

    df_meta = pd.DataFrame({ key: np.asarray(val) for key, val in files.items() })
    return df_meta

def get_simple_chords(processed_data_dir):

    preprocessed_data = collections.defaultdict(list)
    for midi_dir in processed_data_dir.glob('*/'):
        if not midi_dir.is_dir():
            continue
        piece_name = midi_dir.name
        chord_file_name = f'{piece_name}_full.npy'
        data_file = midi_dir / chord_file_name
        if not data_file.exists():
            print(data_file)
            raise Exception('shit is fucked!!')

        # record piece name and the path of the file containing the full chord data
        preprocessed_data['file'].append(str(data_file))
        preprocessed_data['piece_name'].append(piece_name)

        # compute the size of the file after 
        chord_roll = np.load(data_file)
        preprocessed_data['roll_length'].append(chord_roll.shape[0])

    df_preprocess = pd.DataFrame({ key: np.asarray(val) for key, val in preprocessed_data.items() })
    return df_preprocess

In [5]:
simple_chords = get_simple_chords(Path('./data/prog_benchmark_processed_handless'))
np.load(simple_chords.iloc[0]['file']).shape

(10, 69)

# dataset class

In [19]:
class PreprocessedChords(Dataset):
    
    def __init__(self, df_meta: pd.DataFrame, seq_length: int = 25, max_windows=None):
        self.df_meta = df_meta.copy()

        
        n_windows = self.df_meta['roll_length'].values - seq_length - 1
        if max_windows is not None:
            n_windows[n_windows > max_windows] = max_windows
            
        self.df_meta['n_windows'] = n_windows
        self.df_meta['n_windows'] = self.df_meta['n_windows'].astype(int)
        
        file_idx_ends = []
        n_windows = self.df_meta['n_windows'].values
        file_idx_ends = [n_windows[0] - 1]
        for windows_in_file in n_windows[1:]:
            file_idx_ends.append(windows_in_file + file_idx_ends[-1] )

        self.df_meta['file_idx_ends'] = file_idx_ends
        self.seq_length = seq_length
        
        self.roll_cache = {}
        
    def __len__(self):
        return self.df_meta['n_windows'].sum()
    
    def __getitem__(self, idx):
        file_idx = self.get_file_idx(idx)
        window_idx = self.get_window_idx(idx)
        
        seq, label = self.get_rolls(file_idx, window_idx, idx)
        
        seq = torch.from_numpy(seq).float()
        label = torch.from_numpy(label).float()
        return seq, label
    
    def get_file_idx(self, idx):
        file_idx = None
        file_idx_ends = self.df_meta['file_idx_ends'].values
        for i in range(len(file_idx_ends)):
            if idx <= file_idx_ends[i]:
                file_idx = i
                break
        if file_idx is None:
            raise ValueError(f'file_idx could not be found for {idx=}')
        return file_idx
    
    def get_window_idx(self, idx):
        file_idx = self.get_file_idx(idx)
        file_idx_ends = self.df_meta['file_idx_ends'].values
        if file_idx == 0:
            idx_start = 0
        else:
            idx_start = file_idx_ends[file_idx - 1]
            
        window_idx = int(idx - idx_start)
        return window_idx
    
    def midi_to_pianoroll(self, file, sample_dist=0.02):
        pm = pretty_midi.PrettyMIDI(file)
        
        sampling_rate = 1/sample_dist
        piano_roll = pm.get_piano_roll(fs=sampling_rate)
        return piano_roll
    
    def get_rolls(self, file_idx, window_idx, idx):
        file_path = self.df_meta.iloc[file_idx]['file']
        
        if file_idx in self.roll_cache:
            roll = self.roll_cache[file_idx]
        else:
            roll = np.load(file_path)
            # roll = roll[range(0, roll.shape[0], 7),:]
            self.roll_cache[file_idx] = roll
            
        roll_window = roll[window_idx:window_idx+self.seq_length+1, :]
        # if roll_window.shape[0] != self.seq_length + 1:
        #     print(f'{roll_window.shape[0]=}')
        #     raise Exception(f'fuck {idx=}')
        
        seq = roll_window[:-1]
        label = roll_window[-1]
        return seq, label

# model

In [7]:
class ChordLSTM(nn.Module):
    def __init__(self, vocab_size=1848, hidden_size=None):
        super(ChordLSTM, self).__init__()
            
        if hidden_size is None:
            hidden_size = vocab_size // 8
            
        self.lstm = nn.LSTM(input_size=vocab_size, batch_first=True, num_layers=1, hidden_size=hidden_size)
        
        self.norm = nn.BatchNorm1d(num_features=hidden_size)
        
        self.predict_layer = nn.Sequential(
            nn.Linear(hidden_size, vocab_size),
            nn.Softmax(dim=1)
        )
    def forward(self, x):
        output, (h_n, c_n) = self.lstm(x)
        
        last_hidden_vector = output[:, -1, :]
        linear_input = self.norm(last_hidden_vector)
        output = self.predict_layer(linear_input)
        
        return output

# train

In [39]:
len(dset_train)

384

In [40]:
torch.manual_seed(0)

# set filepaths
_reporoot = Path('/home/ian/projects/music-rnn')
processed_data_dir = _reporoot / 'data/chopin_processed_bin'
output_dir = Path('./')
if not output_dir.exists():
    output_dir.mkdir()
metrics_file = output_dir / 'metrics.csv'

# hyperparameters
seq_length = 6
learning_rate = 3e-4
batch_size = 8
num_workers = 0
n_iters = 384*30
output_interval = 384 // 2 
hidden_size = 50

# df_preprocess = get_preprocessed_files(processed_data_dir)
df_preprocess = get_simple_chords(Path('./data/prog_benchmark_processed_handless'))


g4_mask = df_preprocess['piece_name'].str.contains('G4')
df_train = df_preprocess[~g4_mask]
df_test = df_preprocess[g4_mask]

# rng = np.random.default_rng(12345)
# idx = np.arange(df_preprocess.shape[0])
# n_train = int(0.8*idx.shape[0])
# train_idx = rng.choice(idx, size=n_train, replace=False)
# test_idx = idx[~np.in1d(idx, train_idx)]
# df_train = df_preprocess.iloc[train_idx]
# df_test = df_preprocess.iloc[test_idx]

dset_train = PreprocessedChords(df_meta=df_train, 
                    seq_length=seq_length)


dset_test = PreprocessedChords(df_meta=df_test, 
                    seq_length=seq_length,
                    max_windows=60)


train_dataloader = DataLoader(dset_train, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
test_dataloader = DataLoader(dset_test, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True)

model = ChordLSTM(hidden_size=hidden_size, vocab_size=69)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

metrics = collections.defaultdict(list)

iter_idx = -1
train_iterator = iter(train_dataloader)

train_losses = []

while iter_idx < n_iters:
    iter_idx += 1
    # print(f'{iter_idx=}', flush=True)

    try:
        features, labels = next(train_iterator)
    except StopIteration:
        train_iterator = iter(train_dataloader)
        features, labels = next(train_iterator)


    # compute prediction and loss
    pred = model(features)
    loss = loss_fn(pred, labels)

    # backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_losses.append(loss.item())


    # compute metrics every 10 iterations
    if iter_idx > 0 and iter_idx % output_interval == 0:
        print(f'{iter_idx=}', flush=True)

        metrics['iter'].append(iter_idx)

        # compute train loss
        train_loss = np.mean(np.asarray(train_losses))
        metrics['train_loss'].append(train_loss)
        train_losses = []


        # test loop
        test_loss_fn = nn.CrossEntropyLoss()
        test_loss = 0
        frames_correct = 0
        num_batches = len(test_dataloader)
        model.eval()
        with torch.no_grad():
            for features, labels in test_dataloader:
                pred = model(features)
                test_loss += test_loss_fn(pred, labels).item()

                pred_chords = pred.argmax(axis=1)
                label_chords = labels.argmax(axis=1) 
                equal = torch.eq(pred_chords, label_chords)
                frames_correct += torch.sum(equal)
        model.train()

        frac_frames_correct = frames_correct / (num_batches*batch_size)
        avg_test_loss = test_loss / num_batches

        metrics['test_loss'].append(avg_test_loss)
        metrics['frac_frames_correct'].append(frac_frames_correct)

        # save metrics
        df_metrics = pd.DataFrame({ key: np.asarray(val) for key, val in metrics.items() })
        # df_metrics.to_csv(metrics_file)

        # for key, val in metrics.items():
        #     print(f'iter_idx={iter_idx}, {key}={val[-1]}')

        # save model
        # state_file = output_dir / f'model_weights_iter{iter_idx}.pth'
        # torch.save(model.state_dict(), state_file)

iter_idx=192
iter_idx=384
iter_idx=576
iter_idx=768
iter_idx=960
iter_idx=1152
iter_idx=1344
iter_idx=1536
iter_idx=1728
iter_idx=1920
iter_idx=2112
iter_idx=2304
iter_idx=2496
iter_idx=2688
iter_idx=2880
iter_idx=3072
iter_idx=3264
iter_idx=3456
iter_idx=3648
iter_idx=3840
iter_idx=4032
iter_idx=4224
iter_idx=4416
iter_idx=4608
iter_idx=4800
iter_idx=4992
iter_idx=5184
iter_idx=5376
iter_idx=5568
iter_idx=5760
iter_idx=5952
iter_idx=6144
iter_idx=6336
iter_idx=6528
iter_idx=6720
iter_idx=6912
iter_idx=7104
iter_idx=7296
iter_idx=7488
iter_idx=7680
iter_idx=7872
iter_idx=8064
iter_idx=8256
iter_idx=8448
iter_idx=8640
iter_idx=8832
iter_idx=9024
iter_idx=9216
iter_idx=9408
iter_idx=9600
iter_idx=9792
iter_idx=9984
iter_idx=10176
iter_idx=10368
iter_idx=10560
iter_idx=10752
iter_idx=10944
iter_idx=11136
iter_idx=11328
iter_idx=11520


In [41]:
from preprocess.process_output import pianoRoll_to_midi, full_chord_to_pianoRoll
_SAMPLING_RATE = 16000

# header_file = './data/chp_headers/header_right_full.csv'
header_file = './data/prog_benchmark_processed_handless/header_full.csv'
CHORD_COLUMNS = pd.read_csv(header_file, index_col=0).values.flatten().tolist()

def display_audio(pm: pretty_midi.PrettyMIDI, seconds=30):
    waveform= pm.fluidsynth(fs=_SAMPLING_RATE)
    # Take a sample of the generated waveform to mitigate kernel resets
    # waveform_short = waveform[:seconds*_SAMPLING_RATE]
    return display.Audio(waveform, rate=_SAMPLING_RATE)

def sequence_to_midi(sequence, timestep):
    seq = np.asarray(sequence)
    df_seq = pd.DataFrame(data=seq, columns=CHORD_COLUMNS)
    df_roll = full_chord_to_pianoRoll(df_seq)
    midi = pianoRoll_to_midi(df_roll, timestep=timestep)
    return midi

In [45]:
dset_train.df_meta.head()

Unnamed: 0,file,piece_name,roll_length,n_windows,file_idx_ends
0,data/prog_benchmark_processed_handless/prog_1-...,prog_1-3-4-5-1_B4_sev,10,3,2
1,data/prog_benchmark_processed_handless/prog_4-...,prog_4-5-1-4-5-1_F4_sev,12,5,7
2,data/prog_benchmark_processed_handless/prog_1-...,prog_1-4-5-1_D4_triad,8,1,8
3,data/prog_benchmark_processed_handless/prog_1-...,prog_1-4-5-4-1_E4_sev,10,3,11
4,data/prog_benchmark_processed_handless/prog_1-...,prog_1-6-5-4-1_E4_sev,10,3,14


# training set example

In [50]:
start_idx = 12
seq_len = 6*3
dset = dset_train

sequence, label = dset[start_idx]
file_idx = dset.get_file_idx(start_idx)
print(dset.df_meta.iloc[file_idx]['file'])
piece_name = dset.df_meta.iloc[file_idx]['piece_name']
# timestep = name_time_map[piece_name]
timestep = 0.5

primer = np.asarray(sequence).astype(int)
primer_midi = sequence_to_midi(sequence=primer, timestep=timestep)
display.display(display_audio(primer_midi))

next_frames = []
model.eval()
with torch.no_grad():
    for idx in range(seq_len):
        # sequence, label = dset_test[start_idx + idx]
        pred = model(sequence.view(1, *sequence.shape))
        next_frame = (pred > 0.5).float()
        next_frames.append(next_frame)
        extended_sequence = torch.cat([sequence, next_frame], dim=0)
        sequence = extended_sequence[1:]
        # sequence = extended_sequence

out = torch.cat(next_frames,dim=0)
out = np.asarray(out).astype(int)
midi = sequence_to_midi(sequence=out, timestep=timestep)
display_audio(midi)

data/prog_benchmark_processed_handless/prog_1-6-5-4-1_E4_sev/prog_1-6-5-4-1_E4_sev_full.npy


In [54]:
primer_midi.write('simple_chord_results/training_set_primer.mid')
midi.write('simple_chord_results/training_set_output.mid')

# test set example

In [63]:
dset_test.df_meta

Unnamed: 0,file,piece_name,roll_length,n_windows,file_idx_ends
5,data/prog_benchmark_processed_handless/prog_1-...,prog_1-6-5-4-1_G4_triad,10,3,2
12,data/prog_benchmark_processed_handless/prog_1-...,prog_1-2-3-4-5-1_G4_sev,12,5,7
16,data/prog_benchmark_processed_handless/prog_1-...,prog_1-5-4-5-1_G4_triad,10,3,10
17,data/prog_benchmark_processed_handless/prog_1-...,prog_1-4-5-4-1_G4_sev,10,3,13
40,data/prog_benchmark_processed_handless/prog_1-...,prog_1-4-5-4-1_G4_triad,10,3,16
42,data/prog_benchmark_processed_handless/prog_1-...,prog_1-5-3-4-1_G4_triad,10,3,19
43,data/prog_benchmark_processed_handless/prog_1-...,prog_1-6-5-4-1_G4_sev,10,3,22
45,data/prog_benchmark_processed_handless/prog_1-...,prog_1-3-4-5-1_G4_sev,10,3,25
57,data/prog_benchmark_processed_handless/prog_1-...,prog_1-2-3-4-5-1_G4_triad,12,5,30
61,data/prog_benchmark_processed_handless/prog_1-...,prog_1-3-6-4-1_G4_triad,10,3,33


In [66]:
file_idx = 4
test_file = dset_test.df_meta.iloc[file_idx]['file']
roll = np.load(test_file)
roll.shape

(10, 69)

In [74]:
# start_idx = 11
seq_len = 6*3
dset = dset_test

file_idx = 4
test_file = dset_test.df_meta.iloc[file_idx]['file']
roll = np.load(test_file).astype(float)
sequence = torch.as_tensor(roll[:6, :]).float()
# timestep = name_time_map[piece_name]
timestep = 0.5

primer = np.asarray(sequence).astype(int)
primer_midi = sequence_to_midi(sequence=primer, timestep=timestep)
display.display(display_audio(primer_midi))

next_frames = []
model.eval()
with torch.no_grad():
    for idx in range(seq_len):
        # sequence, label = dset_test[start_idx + idx]
        pred = model(sequence.view(1, *sequence.shape))
        next_frame = (pred > 0.5).float()
        next_frames.append(next_frame)
        extended_sequence = torch.cat([sequence, next_frame], dim=0)
        sequence = extended_sequence[1:]
        # sequence = extended_sequence

out = torch.cat(next_frames,dim=0)
out = np.asarray(out).astype(int)
midi = sequence_to_midi(sequence=out, timestep=timestep)
display_audio(midi)

In [75]:
primer_midi.write('simple_chord_results/test_set_primer.mid')
midi.write('simple_chord_results/test_set_output.mid')

In [76]:
state_file = './simple_chord_results/model_weights.pth'
torch.save(model.state_dict(), state_file)