In [52]:
import numpy as np
import pandas as pd
import pretty_midi
import collections
from pathlib import Path
import fluidsynth

from IPython import display

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

# preprocessed data

In [268]:
def get_preprocessed_files(processed_data_dir):

    preprocessed_data = collections.defaultdict(list)
    for midi_dir in processed_data_dir.glob('*/'):
        if not midi_dir.is_dir():
            continue
        piece_name = midi_dir.name
        chord_file_name = f'{piece_name}_right_C_full.npy'
        data_file = midi_dir / chord_file_name
        if not data_file.exists():
            print(data_file)
            raise Exception('shit is fucked!!')

        # record piece name and the path of the file containing the full chord data
        preprocessed_data['file'].append(str(data_file))
        preprocessed_data['piece_name'].append(piece_name)

        # compute the size of the file after 
        chord_roll = np.load(data_file)
        chord_roll = chord_roll[range(0, chord_roll.shape[0], 7),:]
        preprocessed_data['roll_length'].append(chord_roll.shape[0])

    df_preprocess = pd.DataFrame({ key: np.asarray(val) for key, val in preprocessed_data.items() })
    return df_preprocess

def get_df_meta(datadir: str):
    files = collections.defaultdict(list)
    for filepath in datadir.glob('*.mid'):
        files['file'].append(str(filepath))
        composer = filepath.stem.split('_')[0]
        files['composer'].append(composer)

        pm = pretty_midi.PrettyMIDI(str(filepath))
        files['end_time'].append(pm.get_end_time())


        tempos, probabilities = pm.estimate_tempi()
        assert np.isclose(sum(probabilities), 1)
        tempo_bpm = np.dot(tempos, probabilities) # expected tempo in beats/min
        seconds_per_beat = (1/tempo_bpm)*60
        time_sig_denom = pm.time_signature_changes[0].denominator
        note_dist = seconds_per_beat / (16 / time_sig_denom)
        note_dist *= 1

        files['expected_tempo'].append(tempo_bpm)
        files['sampling_note_duration'].append(note_dist)
        roll = pm.get_piano_roll(fs=1/note_dist)
        files['roll_length'].append(roll.shape[1])

    df_meta = pd.DataFrame({ key: np.asarray(val) for key, val in files.items() })
    return df_meta

def get_simple_chords(processed_data_dir):

    preprocessed_data = collections.defaultdict(list)
    for midi_dir in processed_data_dir.glob('*/'):
        if not midi_dir.is_dir():
            continue
        piece_name = midi_dir.name
        chord_file_name = f'{piece_name}_full.npy'
        data_file = midi_dir / chord_file_name
        if not data_file.exists():
            print(data_file)
            raise Exception('shit is fucked!!')

        # record piece name and the path of the file containing the full chord data
        preprocessed_data['file'].append(str(data_file))
        preprocessed_data['piece_name'].append(piece_name)

        # compute the size of the file after 
        chord_roll = np.load(data_file)
        preprocessed_data['roll_length'].append(chord_roll.shape[0])

    df_preprocess = pd.DataFrame({ key: np.asarray(val) for key, val in preprocessed_data.items() })
    return df_preprocess

In [216]:
simple_chords = get_simple_chords(Path('./data/prog_benchmark_processed_handless'))
np.load(simple_chords.iloc[0]['file']).shape

(10, 69)

# dataset

In [269]:
class PreprocessedChords(Dataset):
    
    def __init__(self, df_meta: pd.DataFrame, seq_length: int = 25, max_windows=None):
        self.df_meta = df_meta.copy()

        
        n_windows = self.df_meta['roll_length'].values - seq_length - 1
        if max_windows is not None:
            n_windows[n_windows > max_windows] = max_windows
            
        self.df_meta['n_windows'] = n_windows
        self.df_meta['n_windows'] = self.df_meta['n_windows'].astype(int)
        
        file_idx_ends = []
        n_windows = self.df_meta['n_windows'].values
        file_idx_ends = [n_windows[0] - 1]
        for windows_in_file in n_windows[1:]:
            file_idx_ends.append(windows_in_file + file_idx_ends[-1] )

        self.df_meta['file_idx_ends'] = file_idx_ends
        self.seq_length = seq_length
        
        self.roll_cache = {}
        
    def __len__(self):
        return self.df_meta['n_windows'].sum()
    
    def __getitem__(self, idx):
        file_idx = self.get_file_idx(idx)
        window_idx = self.get_window_idx(idx)
        
        seq, label = self.get_rolls(file_idx, window_idx, idx)
        
        seq = torch.from_numpy(seq).float()
        label = torch.from_numpy(label).float()
        return seq, label
    
    def get_file_idx(self, idx):
        file_idx = None
        file_idx_ends = self.df_meta['file_idx_ends'].values
        for i in range(len(file_idx_ends)):
            if idx <= file_idx_ends[i]:
                file_idx = i
                break
        if file_idx is None:
            raise ValueError(f'file_idx could not be found for {idx=}')
        return file_idx
    
    def get_window_idx(self, idx):
        file_idx = self.get_file_idx(idx)
        file_idx_ends = self.df_meta['file_idx_ends'].values
        if file_idx == 0:
            idx_start = 0
        else:
            idx_start = file_idx_ends[file_idx - 1]
            
        window_idx = int(idx - idx_start)
        return window_idx
    
    def midi_to_pianoroll(self, file, sample_dist=0.02):
        pm = pretty_midi.PrettyMIDI(file)
        
        sampling_rate = 1/sample_dist
        piano_roll = pm.get_piano_roll(fs=sampling_rate)
        return piano_roll
    
    def get_rolls(self, file_idx, window_idx, idx):
        file_path = self.df_meta.iloc[file_idx]['file']
        
        if file_idx in self.roll_cache:
            roll = self.roll_cache[file_idx]
        else:
            roll = np.load(file_path)
            roll = roll[range(0, roll.shape[0], 7),:]
            self.roll_cache[file_idx] = roll
            
        roll_window = roll[window_idx:window_idx+self.seq_length+1, :]
        # if roll_window.shape[0] != self.seq_length + 1:
        #     print(f'{roll_window.shape[0]=}')
        #     raise Exception(f'fuck {idx=}')
        
        seq = roll_window[:-1]
        label = roll_window[-1]
        return seq, label

# model

In [230]:
class ChordLSTM(nn.Module):
    def __init__(self, vocab_size=1848, hidden_size=None):
        super(ChordLSTM, self).__init__()
            
        if hidden_size is None:
            hidden_size = vocab_size // 8
            
        self.lstm = nn.LSTM(input_size=vocab_size, batch_first=True, num_layers=1, hidden_size=hidden_size)
        
        self.norm = nn.BatchNorm1d(num_features=hidden_size)
        
        self.predict_layer = nn.Sequential(
            nn.Linear(hidden_size, vocab_size),
            nn.Softmax(dim=1)
        )
    def forward(self, x):
        output, (h_n, c_n) = self.lstm(x)
        
        last_hidden_vector = output[:, -1, :]
        linear_input = self.norm(last_hidden_vector)
        output = self.predict_layer(linear_input)
        
        return output

# train

In [271]:
len(dset_train)

19348

In [272]:
torch.manual_seed(0)

# set filepaths
_reporoot = Path('/home/ian/projects/music-rnn')
processed_data_dir = _reporoot / 'data/chopin_processed_bin'
output_dir = Path('./')
if not output_dir.exists():
    output_dir.mkdir()
metrics_file = output_dir / 'metrics.csv'

# hyperparameters
seq_length = 30
learning_rate = 3e-4
batch_size = 8
num_workers = 0
n_iters = 19348*5
output_interval = 19348 // 2
hidden_size = 80

df_preprocess = get_preprocessed_files(processed_data_dir)
# df_preprocess = get_simple_chords(Path('./data/prog_benchmark_processed_handless'))

rng = np.random.default_rng(12345)
idx = np.arange(df_preprocess.shape[0])
n_train = int(0.8*idx.shape[0])
train_idx = rng.choice(idx, size=n_train, replace=False)
test_idx = idx[~np.in1d(idx, train_idx)]
df_train = df_preprocess.iloc[train_idx]
df_test = df_preprocess.iloc[test_idx]

dset_train = PreprocessedChords(df_meta=df_train, 
                    seq_length=seq_length)


dset_test = PreprocessedChords(df_meta=df_test, 
                    seq_length=seq_length,
                    max_windows=60)


train_dataloader = DataLoader(dset_train, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
test_dataloader = DataLoader(dset_test, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True)

model = ChordLSTM(hidden_size=hidden_size)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

metrics = collections.defaultdict(list)

iter_idx = -1
train_iterator = iter(train_dataloader)

train_losses = []

while iter_idx < n_iters:
    iter_idx += 1
    # print(f'{iter_idx=}', flush=True)

    try:
        features, labels = next(train_iterator)
    except StopIteration:
        train_iterator = iter(train_dataloader)
        features, labels = next(train_iterator)


    # compute prediction and loss
    pred = model(features)
    loss = loss_fn(pred, labels)

    # backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_losses.append(loss.item())


    # compute metrics every 10 iterations
    if iter_idx > 0 and iter_idx % output_interval == 0:
        print(f'{iter_idx=}', flush=True)

        metrics['iter'].append(iter_idx)

        # compute train loss
        train_loss = np.mean(np.asarray(train_losses))
        metrics['train_loss'].append(train_loss)
        train_losses = []


        # test loop
        test_loss_fn = nn.CrossEntropyLoss()
        test_loss = 0
        frames_correct = 0
        num_batches = len(test_dataloader)
        model.eval()
        with torch.no_grad():
            for features, labels in test_dataloader:
                pred = model(features)
                test_loss += test_loss_fn(pred, labels).item()

                pred_chords = pred.argmax(axis=1)
                label_chords = labels.argmax(axis=1) 
                equal = torch.eq(pred_chords, label_chords)
                frames_correct += torch.sum(equal)
        model.train()

        frac_frames_correct = frames_correct / (num_batches*batch_size)
        avg_test_loss = test_loss / num_batches

        metrics['test_loss'].append(avg_test_loss)
        metrics['frac_frames_correct'].append(frac_frames_correct)

        # save metrics
        df_metrics = pd.DataFrame({ key: np.asarray(val) for key, val in metrics.items() })
        # df_metrics.to_csv(metrics_file)

        # for key, val in metrics.items():
        #     print(f'iter_idx={iter_idx}, {key}={val[-1]}')

        # save model
        # state_file = output_dir / f'model_weights_iter{iter_idx}.pth'
        # torch.save(model.state_dict(), state_file)

iter_idx=9674
iter_idx=19348
iter_idx=29022
iter_idx=38696
iter_idx=48370
iter_idx=58044
iter_idx=67718
iter_idx=77392
iter_idx=87066
iter_idx=96740


In [273]:
df_metrics

Unnamed: 0,iter,train_loss,test_loss,frac_frames_correct
0,9674,7.166122,7.355737,0.169643
1,19348,7.017081,7.35115,0.176786
2,29022,6.989837,7.350761,0.173214
3,38696,6.975683,7.360286,0.166071
4,48370,6.96465,7.36436,0.164286
5,58044,6.953945,7.365532,0.160714
6,67718,6.945099,7.359767,0.167857
7,77392,6.939502,7.355138,0.167857
8,87066,6.933152,7.355483,0.169643
9,96740,6.926656,7.350853,0.173214


In [274]:
from preprocess.process_output import pianoRoll_to_midi, full_chord_to_pianoRoll
_SAMPLING_RATE = 16000

header_file = './data/chp_headers/header_right_full.csv'
# header_file = './data/prog_benchmark_processed_handless/header_full.csv'
CHORD_COLUMNS = pd.read_csv(header_file, index_col=0).values.flatten().tolist()

def display_audio(pm: pretty_midi.PrettyMIDI, seconds=30):
    waveform= pm.fluidsynth(fs=_SAMPLING_RATE)
    # Take a sample of the generated waveform to mitigate kernel resets
    # waveform_short = waveform[:seconds*_SAMPLING_RATE]
    return display.Audio(waveform, rate=_SAMPLING_RATE)

def sequence_to_midi(sequence, timestep):
    seq = np.asarray(sequence)
    df_seq = pd.DataFrame(data=seq, columns=CHORD_COLUMNS)
    df_roll = full_chord_to_pianoRoll(df_seq)
    midi = pianoRoll_to_midi(df_roll, timestep=timestep)
    return midi

In [275]:
# get df_meta
df_meta = get_df_meta(Path('./data/classical'))
df_meta = df_meta[df_meta['composer'] == 'chpn']

In [276]:
# get name_time_map
df_time = df_meta.copy()

def get_name(row):
    file = row['file']
    name = Path(file).stem
    return name

df_time['piece_name'] = df_time.apply(get_name, axis=1)
df_time = df_time[['piece_name', 'sampling_note_duration']]
name_time_map = { row['piece_name']: row['sampling_note_duration'] for _, row in df_time.iterrows() }

In [250]:
state_file = './right_chords/run0/model_weights.pth'
model = ChordLSTM(hidden_size=200)
model.load_state_dict(torch.load(state_file))
model.eval()

ChordLSTM(
  (lstm): LSTM(1848, 200, batch_first=True)
  (norm): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predict_layer): Sequential(
    (0): Linear(in_features=200, out_features=1848, bias=True)
    (1): Softmax(dim=1)
  )
)

In [252]:
dset_test.df_meta

Unnamed: 0,file,piece_name,roll_length,n_windows,file_idx_ends
0,/home/ian/projects/music-rnn/data/chopin_proce...,chpn_op66,5857,5826,5825
1,/home/ian/projects/music-rnn/data/chopin_proce...,chpn-p15,7630,7599,13424
8,/home/ian/projects/music-rnn/data/chopin_proce...,chpn_op53,8279,8248,21672
11,/home/ian/projects/music-rnn/data/chopin_proce...,chpn_op25_e3,2089,2058,23730
22,/home/ian/projects/music-rnn/data/chopin_proce...,chpn-p10,388,357,24087
23,/home/ian/projects/music-rnn/data/chopin_proce...,chpn_op25_e12,1432,1401,25488
26,/home/ian/projects/music-rnn/data/chopin_proce...,chpn-p11,688,657,26145
44,/home/ian/projects/music-rnn/data/chopin_proce...,chpn_op25_e4,1125,1094,27239
45,/home/ian/projects/music-rnn/data/chopin_proce...,chpn_op33_4,5427,5396,32635
47,/home/ian/projects/music-rnn/data/chopin_proce...,chpn_op35_2,10527,10496,43131


In [278]:
dset_train.df_meta.head()

Unnamed: 0,file,piece_name,roll_length,n_windows,file_idx_ends
41,/home/ian/projects/music-rnn/data/chopin_proce...,chpn_op25_e11,377,346,345
31,/home/ian/projects/music-rnn/data/chopin_proce...,chpn-p9,57,26,371
28,/home/ian/projects/music-rnn/data/chopin_proce...,chpn-p23,55,24,395
27,/home/ian/projects/music-rnn/data/chopin_proce...,chpn-p6,183,152,547
15,/home/ian/projects/music-rnn/data/chopin_proce...,chpn_op25_e1,49,18,565


In [288]:
start_idx = 0
seq_len = 12*3
dset = dset_test

sequence, label = dset[start_idx]
file_idx = dset.get_file_idx(start_idx)
piece_name = dset.df_meta.iloc[file_idx]['piece_name']
timestep = name_time_map[piece_name]*2
# timestep = 0.5

primer = np.asarray(sequence).astype(int)
primer_midi = sequence_to_midi(sequence=primer, timestep=timestep)
display.display(display_audio(primer_midi))

next_frames = []
model.eval()
with torch.no_grad():
    for idx in range(seq_len):
        # sequence, label = dset_test[start_idx + idx]
        pred = model(sequence.view(1, *sequence.shape))
        next_frame = (pred > 0.5).float()
        next_frames.append(next_frame)
        extended_sequence = torch.cat([sequence, next_frame], dim=0)
        sequence = extended_sequence[1:]
        # sequence = extended_sequence

out = torch.cat(next_frames,dim=0)
out = np.asarray(out).astype(int)
midi = sequence_to_midi(sequence=out, timestep=timestep)
display_audio(midi)

In [285]:
primer_midi.write('./full_chord_chopin_results/train_example_primer.mid')
midi.write('./full_chord_chopin_results/train_example_output.mid')

In [90]:
dset.df_meta.iloc[file_idx]['file']

'/home/ian/projects/music-rnn/data/chopin_processed_bin/chpn_op25_e11/chpn_op25_e11_right_C_full.npy'

In [287]:
state_file = './full_chord_chopin_results/model_weights.pth'
torch.save(model.state_dict(), state_file)