In [1]:
from music21 import converter, instrument, note, chord, stream, midi, instrument
from scipy import sparse
import time
import tqdm.auto
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import math

In [2]:
import os
# Debugging aid: forces synchronous CUDA kernel launches so errors are reported at
# the offending call. NOTE(review): this slows GPU work; consider removing it for
# real training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
def streamToData(stream_set):
    """Convert a music21 score into a sparse binary piano roll.

    Time is quantised to 0.25 quarter lengths per row (semiquavers). For every
    Note or Chord in every part, the rows it sounds in get a 1 in the column of
    each of its MIDI pitches.

    Args:
        stream_set: a music21 score-like object exposing ``.parts`` (e.g. the
            result of ``midi.translate.midiFileToStream``).

    Returns:
        scipy.sparse.csr_matrix of shape (n_rows, 128) with 0/1 entries.
    """
    step = 0.25  # quarter lengths per row
    parts = list(stream_set.parts)
    # Size the roll from the longest part. Sizing it from part 0 alone silently
    # dropped notes of longer parts, because numpy slicing past the end is a no-op.
    highest_time = max((p.flat.highestTime for p in parts), default=0.0)
    # np.int was removed in NumPy 1.24; the builtin int is the replacement.
    total_length = int(np.round(highest_time / step))
    output = np.zeros((total_length, 128))

    # Iterate the parts explicitly: iterating the score itself may also yield
    # non-part elements that have no `.flat`.
    for part in parts:
        for element in part.flat:
            if isinstance(element, note.Note):
                pitches = [element.pitch.midi]
            elif isinstance(element, chord.Chord):
                pitches = [int(np.round(p.midi)) for p in element.pitches]
            else:
                continue
            off = int(np.round(element.offset / step))
            length = int(np.round(element.quarterLength / step))
            output[off:off + length, pitches] = 1

    return sparse.csr_matrix(output)

def DataToStream(data_mat, time_signature = 0.25, max_run=6):
    """Convert a binary piano roll back into a music21 stream.

    Each row of ``data_mat`` is one time step. A row with one active pitch becomes
    a Note, a row with several becomes a Chord, and an empty row becomes a Rest.
    Consecutive identical rows are merged into one element of up to ``max_run`` rows.

    Args:
        data_mat: 2-D (time, pitch) matrix, either a scipy sparse matrix or anything
            ``np.asarray`` accepts; non-zero entries mark sounding pitches.
        time_signature: duration of one row in quarter lengths. Despite the name,
            this is a step size, not a time signature.
        max_run: maximum number of identical rows merged into a single element.

    Returns:
        music21 ``stream.Stream`` beginning with a Piano instrument.
    """
    melody_stream = stream.Stream()
    melody_stream.append(instrument.Piano())

    def _emit(element, n_steps):
        # Give the element the duration of its whole run, then append it.
        element.quarterLength = time_signature * n_steps
        melody_stream.append(element)

    if not sparse.issparse(data_mat):
        data_mat = np.asarray(data_mat)

    old_element = None
    counter = 1
    for i in range(data_mat.shape[0]):
        row = data_mat[i, :]
        if sparse.issparse(row):
            row = row.toarray()
        pitches = np.flatnonzero(row)

        if len(pitches) == 1:
            new_element = note.Note(int(pitches[0]))
        elif len(pitches) > 1:
            new_element = chord.Chord([note.Note(int(p)) for p in pitches])
        else:
            new_element = note.Rest()

        if old_element is not None and new_element == old_element and counter < max_run:
            counter += 1
        else:
            if old_element is not None:
                _emit(old_element, counter)
            counter = 1
        old_element = new_element

    # Flush the final run. Previously it was never appended, so the last
    # note/chord/rest of every piece was dropped.
    if old_element is not None:
        _emit(old_element, counter)

    return melody_stream

In [4]:
# Composer/artist classes; each has a '<name>_dataset.npz' archive loaded below.
classes = ['bach', 'backstreetboys', 'beatles', 'beethoven', 'brahms', 'britneyspears',
           'chopin', 'coldplay', 'debussy', 'haydn', 'liszt', 'mendelssohn',
           'mozart', 'nirvana', 'paganini', 'queen', 'rachmaninow', 'schubert',
           'schumann', 'tchaikovsky']
datasets = [folder + '_dataset.npz' for folder in classes]
# Derive the label range from the class list instead of hard-coding 20, so
# adding or removing a class keeps the labels in sync.
labels = np.arange(len(classes))

In [5]:
# Training / sampling configuration.
song_length = 320   # rows per snippet (streamToData uses 0.25 quarter lengths per row)
sample_size = 20    # snippets per batch
n_classes = len(classes)
# Fall back to the CPU when no GPU is present instead of failing at the first .to(device).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [6]:
# Load one song collection per class: data[c] is the first array stored in that
# class's .npz archive (indexed below as data[c][song_no]).
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted archives.
data = []
for file in datasets:
    # Use the archive as a context manager so its zip file handle is closed. Read
    # only the first member, the only one that was ever used.
    with np.load(file, allow_pickle=True) as subset:
        data.append(subset[subset.files[0]])
    

In [7]:
# Sanity check: one entry per class (expected 20).
len(data)

20

In [8]:
# Inspect one stored song (class 19, song 19); the output shows a sparse (time, 128) matrix.
data[19][19]

<2388x128 sparse matrix of type '<class 'numpy.float64'>'
	with 6871 stored elements in Compressed Sparse Row format>

In [9]:
def random_snippet(c, length=None, dataset=None):
    """Return a random fixed-length excerpt of a random song from class ``c``.

    The song is drawn uniformly from the class's songs that are at least
    ``length`` rows long. The start row is then drawn uniformly from every
    position that leaves a full window.

    Args:
        c: class index into ``dataset``.
        length: rows (time steps) in the excerpt; defaults to the global ``song_length``.
        dataset: sequence of per-class song collections; defaults to the global ``data``.

    Returns:
        The ``length``-row slice of the chosen song, same type as the stored song.

    Raises:
        ValueError: if no song of class ``c`` has at least ``length`` rows.
    """
    if length is None:
        length = song_length
    if dataset is None:
        dataset = data
    # Filtering first gives the same distribution as the old retry-on-short-song
    # recursion, but cannot recurse forever when no song is long enough.
    candidates = [s for s in dataset[c] if s.shape[0] >= length]
    if not candidates:
        raise ValueError(f"class {c} has no song with at least {length} rows")
    song = candidates[np.random.randint(len(candidates))]
    # Valid starts are 0..n_rows-length inclusive. The old code excluded the last
    # start and rejected songs exactly `length` rows long.
    song_start = np.random.randint(song.shape[0] - length + 1)
    return song[song_start:song_start + length, :]

In [10]:
# Smoke test: draw one snippet from class 3 (beethoven).
random_snippet(3)

<320x128 sparse matrix of type '<class 'numpy.float64'>'
	with 658 stored elements in Compressed Sparse Row format>

In [11]:
# Scratch check of torch.randint: 10 random integers in [0, 20). The result is not used.
torch.randint(20,(1,10)).detach().cpu()[0]

tensor([ 5,  8,  5,  9,  4, 10, 15, 18,  1, 19])

In [12]:
def standard_batch(size = sample_size):
    """Draw ``size`` random snippets, cycling through the classes in order.

    Snippet ``i`` comes from class ``i % n_classes``. With ``size <= n_classes``,
    batch position therefore equals class index.

    Args:
        size: number of snippets to draw.

    Returns:
        list of ``size`` snippets as returned by ``random_snippet``.
    """
    # Cycling replaces the hard-coded torch.arange(20) lookup. That lookup raised
    # IndexError for size > 20 and ignored the actual number of classes.
    return [random_snippet(i % n_classes) for i in range(size)]

In [13]:
# One snippet per class (batch position == class index).
batch1 = standard_batch()

In [14]:
# Convert the last snippet of the batch (class sample_size-1, i.e. tchaikovsky) to music21.
stream1 = DataToStream(batch1[sample_size-1], time_signature = 0.25)

In [15]:
# Render the snippet as MIDI in the notebook.
stream1.show('midi')

In [16]:
# Densify the sparse snippets into one (batch, time, pitch) array.
batch2 = np.array([a.todense() for a in batch1])
batch2.shape

(20, 320, 128)

In [17]:
def sample(size=None):
    """Draw a fresh batch of snippets as a float32 tensor on the global ``device``.

    Args:
        size: number of snippets; defaults to ``standard_batch``'s own default.

    Returns:
        torch.FloatTensor of shape (size, time, pitch).
    """
    snippets = standard_batch() if size is None else standard_batch(size)
    # Cast to float32 on the host before the transfer. The old code copied float64
    # to the GPU and cast there. Also use the shared `device` setting rather than a
    # shadowing local "cuda:0".
    dense = np.array([s.toarray() for s in snippets], dtype=np.float32)
    return torch.from_numpy(dense).to(device)
    

In [18]:
def torch_convert(song, target_device=None):
    """Convert one (time, pitch) song matrix into a (1, time, pitch) float32 tensor.

    Args:
        song: scipy sparse matrix or 2-D array-like of shape (time, pitch).
        target_device: torch device for the result; defaults to the global ``device``.

    Returns:
        torch.FloatTensor of shape (1, time, pitch).
    """
    if target_device is None:
        target_device = device
    dense = song.toarray() if sparse.issparse(song) else np.asarray(song)
    # Copy and cast on the host so only float32 data is transferred and the result
    # never aliases the caller's array.
    tensor = torch.from_numpy(np.array(dense, dtype=np.float32))
    return tensor.to(target_device).unsqueeze(0)

In [19]:
# Smoke test: a freshly sampled batch is (batch, time, pitch).
batch = sample()
batch.shape

torch.Size([20, 320, 128])

In [20]:
class LingLing_VAE(nn.Module):
    """LSTM sequence VAE over binary piano rolls.

    The encoder LSTM reads a song (all but its last step). Each sample's final
    hidden states of every layer are flattened and projected to a Gaussian
    posterior (``mu``, ``log_var``) of size ``latent_dim``. A sample ``z`` does two
    things: it initialises the decoder LSTM's hidden state through
    ``decompressor``, and it is concatenated to every decoder input step. The
    decoder is run with teacher forcing; outputs are per-pitch sigmoid probabilities.

    Args:
        input_dim: pitches per time step.
        hidden_dim: LSTM hidden size.
        latent_dim: size of the latent code.
        num_layers: number of stacked LSTM layers (encoder and decoder).
        device: device the module is moved to after construction.
    """

    def __init__(self, input_dim=128, hidden_dim=512, latent_dim = 20, num_layers = 3, device="cuda:0"):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.num_layers = num_layers

        self.encoder = nn.LSTM(self.input_dim, self.hidden_dim, num_layers=self.num_layers)
        # Flattened per-layer final hidden states -> (mu, log_var).
        self.compressor = nn.Linear(self.hidden_dim * self.num_layers, self.latent_dim * 2)
        # z -> initial hidden state for every decoder layer.
        self.decompressor = nn.Linear(self.latent_dim, self.hidden_dim * self.num_layers)
        # Each decoder step sees the previous frame concatenated with z.
        self.decoder = nn.LSTM(self.input_dim + self.latent_dim, self.hidden_dim, num_layers=self.num_layers)
        self.output = nn.Linear(self.hidden_dim, self.input_dim)

        self.to(device)

    def reparameterize(self, mu, log_var):
        """Draw z ~ N(mu, exp(log_var)) using the reparameterisation trick."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _flatten_layers(self, hidden):
        # (num_layers, batch, hidden) -> (batch, num_layers * hidden), keeping each
        # sample's layers together. The old `.view(-1, L*H)` on this layer-major
        # tensor mixed hidden states of different samples whenever batch > 1
        # (for batch == 1 both layouts are identical).
        return hidden.permute(1, 0, 2).reshape(hidden.shape[1], self.num_layers * self.hidden_dim)

    def _unflatten_layers(self, flat):
        # Inverse of _flatten_layers: (batch, num_layers * hidden) -> contiguous
        # (num_layers, batch, hidden), the layout nn.LSTM expects for h0.
        return flat.view(-1, self.num_layers, self.hidden_dim).permute(1, 0, 2).contiguous()

    def forward(self, song):
        """Encode ``song`` and reconstruct it with teacher forcing.

        Args:
            song: float tensor of shape (batch, time, input_dim).

        Returns:
            (reconstruction, mu, log_var): ``reconstruction`` has shape
            (batch, time - 1, input_dim) and holds sigmoid probabilities. The
            training loop compares it with steps 1..time-1. ``mu`` and ``log_var``
            have shape (batch, latent_dim).
        """
        # nn.LSTM here is sequence-first: (time - 1, batch, input_dim).
        inputs = song[:, :-1, :].permute(1, 0, 2)

        _, (hidden_rep, _) = self.encoder(inputs)
        latent_rep = self.compressor(self._flatten_layers(hidden_rep))
        latent_rep = latent_rep.view(-1, 2, self.latent_dim)
        mu = latent_rep[:, 0, :]
        log_var = latent_rep[:, 1, :]
        z = self.reparameterize(mu, log_var)

        # Broadcast z along time and append it to every decoder input step.
        latent_rep_context = z.repeat(inputs.shape[0], 1, 1)
        decoder_input_rep = torch.cat((inputs, latent_rep_context), 2)

        h0 = self._unflatten_layers(self.decompressor(z))
        # zeros_like keeps c0 on h0's device; the old code relied on a global `device`.
        c0 = torch.zeros_like(h0)

        o1, _ = self.decoder(decoder_input_rep, (h0, c0))
        reconstruction = torch.sigmoid(self.output(o1.permute(1, 0, 2)))
        return reconstruction, mu, log_var

    @torch.no_grad()
    def generate_song(self, song_start, mu, log_var, max_length=600):
        """Sample ``max_length`` new steps continuing ``song_start``.

        Runs without autograd: the old version built a graph across every
        generated step.

        Args:
            song_start: primer of shape (time, input_dim), time >= 1.
            mu: posterior mean of shape (1, latent_dim).
            log_var: posterior log-variance of shape (1, latent_dim).
            max_length: number of steps to generate.

        Returns:
            torch.LongTensor of shape (max_length, input_dim) with 0/1 entries, on
            the same device as ``mu``.
        """
        z = self.reparameterize(mu, log_var)
        h = self._unflatten_layers(self.decompressor(z))
        c = torch.zeros_like(h)
        z_step = z.view(1, 1, self.latent_dim)

        # Sequence-first primer: (time, 1, input_dim).
        primer = song_start.reshape(-1, 1, self.input_dim)
        # Warm the decoder on all but the last primer frame. The last frame is then
        # fed as the loop's first input. Previously the whole primer was consumed and
        # its last frame fed again, so that frame was processed twice.
        if primer.shape[0] > 1:
            warmup = primer[:-1]
            warmup_input = torch.cat((warmup, z_step.repeat(warmup.shape[0], 1, 1)), 2)
            _, (h, c) = self.decoder(warmup_input, (h, c))

        song_out = torch.zeros((max_length, self.input_dim), dtype=torch.long, device=z.device)
        prev_frame = primer[-1:]  # (1, 1, input_dim)
        for t in range(max_length):
            decoder_input_rep = torch.cat((prev_frame, z_step), 2)
            out, (h, c) = self.decoder(decoder_input_rep, (h, c))
            probs = torch.sigmoid(self.output(out.view(1, self.hidden_dim)))
            # Sample every pitch independently from its Bernoulli probability.
            frame = torch.bernoulli(probs.view(-1))
            song_out[t, :] = frame
            prev_frame = frame.view(1, 1, self.input_dim)

        return song_out

In [21]:
#model = LingLing_VAE()
#sample_batch = sample()
#recon,mu,log_var = model(sample_batch)
#print(recon.shape,mu.shape,log_var.shape)

In [22]:
#song_start = sample_batch[0]
#mu = mu[0].view(1,mu.shape[1])
#log_var = log_var[0].view(1,log_var.shape[1])

#song_start.detach()
#mu.detach()
#log_var.detach()

#generated_sample = model.generate_song(song_start, mu, log_var, max_length=360)

In [23]:
# Model, optimiser and loss. Adam's betas/eps/weight_decay are left at their
# defaults ((0.9, 0.999), 1e-8, 0), so only the learning rate is spelled out.
model = LingLing_VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-5)
# Summed (not averaged) BCE over every pitch and step of the batch.
criterion = nn.BCELoss(reduction='sum')

In [24]:
# Resume from a saved checkpoint. map_location loads the tensors straight onto
# `device`, so a checkpoint saved on another GPU or on the CPU still loads.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load('Ling_ling_VAE2.pt', map_location=device))

<All keys matched successfully>

In [25]:
# VAE training loop with a growing KL weight.
# Every `log_every` iterations it prints running averages. If the weighted KL term
# fell below `kld_target`, KLD_weight is multiplied by 10.
# NOTE(review): an initial KLD_weight of 1e10 is unusually large for a VAE; confirm intent.
model.train()  # a later cell calls model.eval(); restore training mode on re-runs
iterations = 10000 + 1
new_sample_cycle = 1   # draw a new batch every N iterations
log_every = 250
kld_target = 1000
KLD_weight = 1e10
# The weight multiplies a float32 tensor, so letting it grow past float32's max
# (~3.4e38) would make the loss inf/NaN. Cap it well below that.
KLD_weight_max = 1e30
total_loss, re_loss, kl_loss = 0, 0, 0
flat_dim = (song_length - 1) * 128

for i in range(iterations):
    if i % new_sample_cycle == 0:
        sample_batch = sample()

    optimizer.zero_grad()
    reconstruction, mu, log_var = model(sample_batch)

    # Teacher forcing: output step t is compared with input frame t+1.
    reconstruction_loss = criterion(reconstruction.view(-1, flat_dim),
                                    sample_batch[:, 1:, :].reshape(-1, flat_dim))
    # Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims and batch.
    KLD_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
    loss = reconstruction_loss + KLD_weight * KLD_loss

    total_loss += loss.item()
    re_loss += reconstruction_loss.item()
    kl_loss += KLD_weight * KLD_loss.item()

    loss.backward()
    optimizer.step()

    if i % log_every == log_every - 1:
        denom = log_every * sample_size
        print("Iteration :", i, " Loss :", loss.item() / sample_size, "Running Average Loss :", total_loss / denom)
        print("Reconstruction Loss", re_loss / denom, " KLD Loss", kl_loss / denom)
        if kl_loss / denom < kld_target:
            KLD_weight = min(KLD_weight * 10, KLD_weight_max)
        total_loss, re_loss, kl_loss = 0, 0, 0

Iteration : 249  Loss : 2097.5390625 Running Average Loss : 1507.4351134765625
Reconstruction Loss 1280.818251171875  KLD Loss 226.61685943603516
Iteration : 499  Loss : 3730.2890625 Running Average Loss : 2873.927471484375
Reconstruction Loss 1273.5427984375  KLD Loss 1600.3847122192383
Iteration : 749  Loss : 1835.0994140625 Running Average Loss : 2491.5426123046873
Reconstruction Loss 1262.494853515625  KLD Loss 1229.0477752685547
Iteration : 999  Loss : 3317.97265625 Running Average Loss : 2661.801773828125
Reconstruction Loss 1282.55030703125  KLD Loss 1379.251480102539
Iteration : 1249  Loss : 272.7484375 Running Average Loss : 2519.5558962890623
Reconstruction Loss 1273.818846875  KLD Loss 1245.737075805664
Iteration : 1499  Loss : 7606.2203125 Running Average Loss : 2501.6695111328127
Reconstruction Loss 1264.277122265625  KLD Loss 1237.3924255371094


KeyboardInterrupt: 

In [None]:
# Save the trained weights under a new name; this run was resumed from Ling_ling_VAE2.pt.
torch.save(model.state_dict(), 'Ling_ling_VAE3.pt')

In [26]:
# Load an external MIDI file and translate it into a music21 score.
path = "./Data/ac_dc_thunderstruck.mid"
mf = midi.MidiFile()
mf.open(path)
try:
    mf.read()
finally:
    # Close the file handle even if parsing fails.
    mf.close()
loaded_song = midi.translate.midiFileToStream(mf)

In [47]:
# Use the first 120 rows (30 quarter notes at 0.25 per row) of the loaded MIDI as a
# (time, 128) float32 primer tensor.
song_data = streamToData(loaded_song)
primer_rows = song_data[0:120, :]
# One dense conversion replaces the per-row densify + permute/squeeze round trip.
# It also avoids clobbering the earlier `batch1` / `batch2` / `device` globals.
song_start = torch.from_numpy(primer_rows.toarray().astype(np.float32)).to(device)

In [28]:
# Switch to inference mode before sampling. The model has only LSTM/Linear layers
# built without dropout, so this mainly documents intent.
model.eval()

LingLing_VAE(
  (encoder): LSTM(128, 512, num_layers=3)
  (compressor): Linear(in_features=1536, out_features=40, bias=True)
  (decompressor): Linear(in_features=20, out_features=1536, bias=True)
  (decoder): LSTM(148, 512, num_layers=3)
  (output): Linear(in_features=512, out_features=128, bias=True)
)

In [101]:
# Encode a fresh batch (one snippet per class) to get per-class latent parameters.
# This is inference only, so skip building an autograd graph.
with torch.no_grad():
    sample_batch = sample()
    recon, mu_vec, log_var_vec = model(sample_batch)
print(recon.shape, mu_vec.shape, log_var_vec.shape)

torch.Size([20, 319, 128]) torch.Size([20, 20]) torch.Size([20, 20])


In [102]:
# Pick the class to condition on. Batch position equals class index (one snippet
# per class, in order), so the name can be looked up in `classes`.
# NOTE(review): the MIDI files written below are named "mozart_*" but this selects
# chopin; keep the selection and the filenames in sync.
selected_class = classes.index('chopin')

song_start = sample_batch[selected_class]
mu = mu_vec[selected_class:selected_class + 1]
log_var = log_var_vec[selected_class:selected_class + 1]

In [103]:
# Convert the primer snippet back to a music21 stream and render it as MIDI.
song = song_start.detach().cpu().numpy()
song = sparse.csr_matrix(song)
start_stream = DataToStream(song, time_signature = 0.25)
start_stream.show('midi')

In [104]:
# Generate 240 new rows (60 quarter notes at 0.25 per row) from the primer and the
# chosen latent parameters, then convert to music21 and render as MIDI.
generated_sample = model.generate_song(song_start, mu, log_var, max_length=240)
song = generated_sample.detach().cpu().numpy()
song = sparse.csr_matrix(song)
generated_stream = DataToStream(song, time_signature = 0.25)
generated_stream.show('midi')

In [105]:
# Text dump of the generated stream (offsets in quarter lengths).
generated_stream.show('text')

{0.0} <music21.instrument.Piano 'Piano'>
{0.0} <music21.chord.Chord G#2 C4 G#4>
{0.25} <music21.chord.Chord G#2 E-3 C4 G#4>
{0.5} <music21.chord.Chord E-3 C4 G#4>
{0.75} <music21.chord.Chord G#2 E-3 C4 G#4>
{1.0} <music21.chord.Chord E-3 C4>
{1.25} <music21.chord.Chord E-3 C4 F4>
{1.5} <music21.chord.Chord C4 C#4>
{1.75} <music21.chord.Chord C4 C#4 F4>
{2.0} <music21.chord.Chord A3 C4 C#4 F4>
{2.25} <music21.chord.Chord C4 C#4 F4>
{2.5} <music21.note.Note C#>
{4.0} <music21.note.Note C#>
{4.5} <music21.note.Rest rest>
{5.0} <music21.note.Note C#>
{5.25} <music21.chord.Chord C#2 C3 E7>
{5.5} <music21.chord.Chord C#2 C3>
{6.0} <music21.note.Note C#>
{6.25} <music21.chord.Chord C#2 B-2 G#3>
{6.5} <music21.chord.Chord C#2 C#4>
{7.0} <music21.note.Note C#>
{7.5} <music21.chord.Chord C#2 C#4>
{7.75} <music21.note.Note C#>
{8.25} <music21.chord.Chord C#2 C#3>
{8.5} <music21.note.Note C#>
{9.0} <music21.chord.Chord E-3 C#4>
{9.25} <music21.note.Rest rest>
{10.0} <music21.note.Note G#>
{10.5} <

In [34]:
# Text dump of the primer stream, for comparison with the generated one.
start_stream.show('text')

{0.0} <music21.instrument.Piano 'Piano'>
{0.0} <music21.chord.Chord F#3 A4 D5>
{0.25} <music21.chord.Chord G3 B-4 D5>
{1.25} <music21.chord.Chord E3 G4 C5>
{2.25} <music21.chord.Chord F3 A4 C5>
{3.25} <music21.chord.Chord D3 F4 B4>
{4.25} <music21.chord.Chord C2 C3 C5>
{4.5} <music21.chord.Chord C2 C3 E4 G4>
{4.75} <music21.chord.Chord C4 C5>
{5.0} <music21.chord.Chord E4 G4>
{5.25} <music21.chord.Chord C4 C5>
{5.5} <music21.chord.Chord E4 G4>
{5.75} <music21.chord.Chord C4 D5>
{6.0} <music21.chord.Chord E4 G4>
{6.25} <music21.chord.Chord C4 E5>
{6.5} <music21.chord.Chord G4 B-4>
{6.75} <music21.chord.Chord C4 C5>
{7.0} <music21.chord.Chord G4 B-4>
{7.25} <music21.chord.Chord C4 D5>
{7.5} <music21.chord.Chord G4 B-4>
{7.75} <music21.chord.Chord C4 E5>
{8.0} <music21.chord.Chord G4 B-4>
{8.25} <music21.chord.Chord C4 F5>
{8.5} <music21.chord.Chord F4 A4>
{8.75} <music21.chord.Chord C4 F5>
{9.0} <music21.chord.Chord F4 A4>
{9.25} <music21.chord.Chord C4 F5>
{9.5} <music21.chord.Chord F4 

In [80]:
# Save the primer as MIDI. Create the output folder first so the write cannot fail
# on a fresh checkout.
os.makedirs('output_samples', exist_ok=True)
start_stream.write('midi', fp='output_samples/mozart_in2.mid')

'output_samples/mozart_in2.mid'

In [81]:
# Save the generated continuation as MIDI. Create the output folder first so the
# write cannot fail on a fresh checkout.
os.makedirs('output_samples', exist_ok=True)
generated_stream.write('midi', fp='output_samples/mozart_out2.mid')

'output_samples/mozart_out2.mid'