# Simple RNN Example

In [81]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [82]:
%reload_ext autoreload

In [83]:
import sys
sys.path.append('../')

Here we'll train the simple RNN model that can be found in the `src/models` directory and use a few of the functions provided by neural-seq to process the midi files. For this we'll train it on a single Phil Collins song. If the model has been implemented correctly then we should see the loss decrease and the model should essentially memorise the performance. This is a useful sanity check, but also it's interesting to play around with the sampling strategy; even though an RNN that has memorised a single track isn't all that academically interesting, by adding some stochastisity to the generated output it demonstrates some potential as a creative tool.

In [84]:
import torch
from torch import nn
from torch.functional import F

In [85]:
from src.utils.midi_encode import MIDIData
from src.encoders.simple_drum import DrumEncoder
from src.utils.midi_data import fetch_midi_data, play_midi

## Encoding MIDI data

Let's fetch the dataset first. This will download the [The Lakh MIDI](https://colinraffel.com/projects/lmd/) which includes the song we're after.

In [None]:
fetch_midi_data()

In [86]:
ls ../.data/midi_data/clean_midi/Phil\ Collins/

[31mA Groovy Kind of Love.mid[m[m*     [31mIn The Air Tonight.1.mid[m[m*
[31mAgainst All Odds.mid[m[m*          [31mIn The Air Tonight.mid[m[m*
[31mAnother Day in Paradise.1.mid[m[m* [31mNo Son of Mine.mid[m[m*
[31mAnother Day in Paradise.2.mid[m[m* [31mOne More Night.mid[m[m*
[31mAnother Day in Paradise.mid[m[m*   [31mSussudio.mid[m[m*
[31mDon't Lose My Number.mid[m[m*      [31mTrue Colors.mid[m[m*
[31mEasy Lover.mid[m[m*                [31mYou Can't Hurry Love.mid[m[m*
[31mI Wish It Would Rain Down.mid[m[m*


Next we'll specify an encoder. The `simple_drum` encoder will convert the MIDI file (supplied to it as a [PrettyMIDI](http://craffel.github.io/pretty-midi/) object) into a list of 'action' and 'duration' tokens that specify the pitch (the drum in this case) that is played and the duration until the next set of pitches is played. This is essentially a polyphonic encoding scheme but since the expected instrument is percussive and the sounds are short-lived, all the lengths are fixed to a 32nd-note duration.

In [87]:
encoder = DrumEncoder()

Specify the songs that we're going to encode:

In [88]:
songs = [
    'In The Air Tonight.mid',
]

In [89]:
def encode_songs(songs, encoder):
    encodings = []
    for s in songs:
        midi_data = MIDIData.from_file(f'../.data/midi_data/clean_midi/Phil Collins/{s}', encoder, instrument='drum kit')
        encodings += (midi_data.encode())
    return encodings

Notice that an instrument filter was applied when loading the MIDI data which filtered out the drum kit. Any of the names from the [General MIDI Instrument list](https://soundprogramming.net/file-formats/general-midi-instrument-list/) can be specified.

In [90]:
encoded = encode_songs(songs, encoder)

Found instrument matching filter: Drum Kit


By inspecting the tokens that are in the encoded performance we can extract the vocabulary (the unique tokens that make up the encoding).

In [91]:
vocab = list(set(encoded))
vocab_idx = {char: i for i, char in enumerate(vocab)}
vocab_sz = len(vocab)

The first ten tokens look like this:

In [92]:
','.join(encoded[:10])

'P-42,D-8,P-42,D-8,P-42,D-8,P-42,D-8,P-46,D-4'

## Constructing a model and dataset

Below we'll construct a simple RNN model (this is essentially the same as the one included in `src/models`)

In [118]:
class RNN(nn.Module):
    def __init__(self, vocab_sz, emb_sz=300, n_hid=600, n_layers=1):
        super().__init__()
        self.encoder = nn.Embedding(vocab_sz, emb_sz)
        self.lstm = nn.LSTM(input_size=emb_sz, hidden_size=n_hid, num_layers=n_layers)
        self.decoder = nn.Linear(n_hid, vocab_sz)
        
    def forward(self, mb, hidden=None):
        x = self.encoder(mb)
        x, hidd = self.lstm(x, hidden)
        return self.decoder(x), hidd

The following function will take the encoding of the Phil Collins drum track, chop it up into sequences of length 64 and create a target dataset (the following token for each position). The `get_batches` function will divide this up into minibatches that we can then pass into the model for training.

In [94]:
def create_dataset(encoding, bptt=64):
    encoding = [vocab_idx[enc] for enc in encoding]
    seqs = []
    targs = []
    for i in range(len(encoding)//(bptt + 1)):
        seqs.append(encoding[i*bptt:(i+1)*bptt])
        targs.append(encoding[(i*bptt)+1:((i+1)*bptt) + 1])
    return torch.LongTensor(seqs).permute((1, 0)), torch.LongTensor(targs).permute((1, 0))

In [95]:
def get_batches(inputs, targets, bs=64):
    batches = []
    for i in range(inputs.shape[1]//32):
        batches.append((inputs[:, i*bs:(i+1)*bs], targets[:, i*bs:(i+1)*bs]))
    return batches

## Training the model

The training procedure below will train the model for 25 epochs using an Adam optimiser. The loss is calculated using the cross-entropy.

In [96]:
rnn = RNN(vocab_sz)

In [97]:
x, y = create_dataset(encoded)
optim = torch.optim.Adam(rnn.parameters(), lr=3e-3)
for i in range(25):
    epoch_loss = 0
    num = 0
    for xb, yb in get_batches(x, y):        
        rnn.zero_grad()
        out, _, _ = rnn(x)
        loss = F.cross_entropy(out.reshape(-1, vocab_sz), y.reshape(-1))
        epoch_loss += loss.item()
        num += 1
        loss.backward()
        optim.step()
    print(f'epoch {i} loss {epoch_loss/num}')

epoch 0 loss 2.7856833934783936
epoch 1 loss 2.0310475826263428
epoch 2 loss 1.5986616611480713
epoch 3 loss 1.3271511793136597
epoch 4 loss 1.1650668382644653
epoch 5 loss 1.0346001386642456
epoch 6 loss 0.92838054895401
epoch 7 loss 0.8601027727127075
epoch 8 loss 0.775689959526062
epoch 9 loss 0.7118246555328369
epoch 10 loss 0.6587616205215454
epoch 11 loss 0.5966862440109253
epoch 12 loss 0.5482185482978821
epoch 13 loss 0.5103653073310852
epoch 14 loss 0.4751216769218445
epoch 15 loss 0.4349558353424072
epoch 16 loss 0.396340936422348
epoch 17 loss 0.3698026239871979
epoch 18 loss 0.33964991569519043
epoch 19 loss 0.3158837556838989
epoch 20 loss 0.29283303022384644
epoch 21 loss 0.26924362778663635
epoch 22 loss 0.25207337737083435
epoch 23 loss 0.2305152416229248
epoch 24 loss 0.21670326590538025


## Sampling from the model

We can now generate continuations for a given prompt! In order to do this we'll

1. Convert the prompt (which is at least semi human-readable) into a model input of suitable type and dimension, i.e. of size (seq, batch_size=1)

2. Feed in the input to the model and perform [nucleus sampling](https://arxiv.org/abs/1904.09751). This involves collection the most probable continuations that sum up to a given threshold (0.9 in the example below), renormalising the resulting selection and then sampling from it.

3. Feed in the resulting prompt again until we reach the specified duration (256 32nd notes, or 8 bars, in the example below). Notice that the `encoder.duration` method is used to do this.

The temperature (`temp`) for the initial softmax can be twiddled with to flatten out the distribution, making other tokens more likely or make it more confident in its prediction thus reducing stochasticity. A value of 1 is quite conservative whereas 1.5 results some... interesting beats.

In [120]:
def convert_to_input(prompt):
    seq = [vocab_idx[sym] for sym in prompt]
    return torch.LongTensor(seq).unsqueeze(-1)

Note 

In [148]:
temp = 1.3
with torch.no_grad():
    prompt = ['P-42', 'D-8']
    hidd = None
    while encoder.duration(prompt) < 256:
        seq = convert_to_input(prompt)
        out, hidd = rnn(seq[-1].unsqueeze(-1), hidd)
        # Get the last vector of logprobs
        logprobs = out.reshape(-1, vocab_sz)[-1]
                
        nucleus_probs = []
        nucleus_indices = []
        
        sorted_probs = F.softmax(logprobs/temp, -1).sort(descending=True)
        for p, idx in zip(sorted_probs[0], sorted_probs[1]):
            nucleus_probs.append(p)
            nucleus_indices.append(idx)
            if sum(nucleus_probs) > 0.9:
                break

        unnormalised = torch.Tensor(nucleus_probs)
        probs =  unnormalised * (1/sum(torch.Tensor(unnormalised)))
        # We need to refer back to the original indices to grab the correct vocab elements
        prediction = nucleus_indices[torch.distributions.Categorical(probs).sample().item()]
        prompt.append(vocab[prediction])

In [149]:
prompt[:10]

['P-42', 'D-8', 'P-64', 'P-50', 'P-42', 'D-8', 'P-42', 'P-41', 'P-35', 'D-4']

## Listen

Finally we can listen to the generated output!

In [150]:
play_midi(encoder.decode(prompt, tempo=120))

Playing... (Ctrl+C to stop)
