In [1]:
from music21 import converter, instrument, note, chord, stream, midi, instrument
from scipy import sparse
import time
import tqdm.auto
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import math

In [2]:
import os
# Debugging aid: CUDA_LAUNCH_BLOCKING=1 asks CUDA to run kernels synchronously so a
# device-side error surfaces at the Python line that launched it (slower execution).
# NOTE(review): torch is imported in the first cell, before this assignment;
# presumably still effective because no CUDA work has happened yet - confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
def DataToStream(data_mat, time_signature = 0.25):
    """Convert a piano-roll matrix into a music21 ``Stream`` played on a piano.

    Every row of ``data_mat`` is one time step. A row whose entries sum to 1
    becomes a ``note.Note`` built from the index of its largest entry, a row
    summing to more than 1 becomes a ``chord.Chord`` of all columns equal to 1,
    and any other row becomes a ``note.Rest``. Consecutive equal elements are
    merged into one element lasting up to 6 time steps.

    Parameters
    ----------
    data_mat : 2-D SciPy sparse matrix
        Rows are time steps, columns are pitches. Rows are densified with
        ``.todense()``, so a sparse matrix is required.
    time_signature : float
        Duration, in quarter lengths, of a single time step.

    Returns
    -------
    music21.stream.Stream
        The instrument followed by the merged notes, chords and rests.
    """
    melody_stream = stream.Stream()
    melody_stream.append(instrument.Piano())
    t, _ = data_mat.shape
    old_element = None
    counter = 1

    def _emit(element, repeats):
        # Give the finished element its merged duration and add it to the stream.
        element.quarterLength = time_signature * repeats
        melody_stream.append(element)

    for i in range(t):
        arr = data_mat[i, :]
        # Built-in int: the np.int alias was removed in NumPy 1.24.
        r = int(np.sum(arr))

        if r == 1:
            new_element = note.Note(np.argmax(arr))
        elif r > 1:
            arr = arr.todense()
            pitches = np.where(arr == 1)[1]
            new_element = chord.Chord([note.Note(p) for p in pitches])
        else:
            new_element = note.Rest()

        if old_element is not None and new_element == old_element and counter < 6:
            # Same element as the previous step: extend the current run.
            counter += 1
        else:
            if old_element is not None:
                _emit(old_element, counter)
            counter = 1

        old_element = new_element

    # Flush the run still pending after the last row; without this the final
    # note/chord/rest of every song was silently dropped.
    if old_element is not None:
        _emit(old_element, counter)

    return melody_stream

In [4]:
# Class names; each one is also the filename prefix of its '<name>_dataset.npz' archive.
classes = ['bach','backstreetboys','beatles','beethoven','brahms','britneyspears',
             'chopin','coldplay','debussy','haydn','liszt','mendelssohn',
            'mozart','nirvana','paganini','queen','rachmaninow','schubert',
            'schumann','tchaikovsky']
datasets = [folder + '_dataset.npz' for folder in classes]
# One integer label per class, derived from the list so the two cannot drift apart.
labels = np.arange(len(classes))

In [5]:
song_length = 320   # number of time steps in each snippet cut by random_snippet
sample_size = 20    # default number of snippets per batch
n_classes = len(classes)
# Use the first GPU when present, otherwise fall back to the CPU instead of crashing.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [6]:
# data[c] holds the first array stored in the archive of class c.
# NOTE(security): allow_pickle=True unpickles arbitrary objects from the archive;
# only load dataset files from a trusted source.
data = []
for file in datasets:
    # The context manager closes the archive; only the first stored array is read
    # (the previous version loaded every array and then kept just the first).
    with np.load(file, allow_pickle=True) as subset:
        data.append(subset[subset.files[0]])
    

In [7]:
# Sanity check: one entry was appended per dataset file.
len(data)

20

In [8]:
# Peek at a single song (index 19 of the last class) to see its type and shape.
data[19][19]

<2388x128 sparse matrix of type '<class 'numpy.float64'>'
	with 6871 stored elements in Compressed Sparse Row format>

In [9]:
def random_snippet(c, songs=None, length=None):
    """Cut a random window of ``length`` consecutive time steps from a random song.

    Parameters
    ----------
    c : int
        Class index into the notebook-level ``data`` list. Ignored when
        ``songs`` is given.
    songs : sequence of 2-D matrices, optional
        Songs to draw from; defaults to ``data[c]``.
    length : int, optional
        Number of rows in the returned window; defaults to ``song_length``.

    Returns
    -------
    The rows ``start:start + length`` of the chosen song, all columns kept.

    Raises
    ------
    ValueError
        If no candidate song has at least ``length`` rows. The previous version
        recursed forever in that situation.
    """
    if songs is None:
        songs = data[c]
    if length is None:
        length = song_length

    # Only songs that can hold a full window are candidates. A song with exactly
    # `length` rows qualifies (it was wrongly rejected before).
    eligible = [s for s in songs if s.shape[0] >= length]
    if not eligible:
        raise ValueError(
            "no song in class {} has at least {} time steps".format(c, length))

    song = eligible[np.random.randint(len(eligible))]
    # randint's upper bound is exclusive, so +1 makes the last valid start reachable.
    start = np.random.randint(song.shape[0] - length + 1)
    return song[start:start + length, :]

In [10]:
# Smoke test: draw one snippet from class 3 and display its shape.
random_snippet(3)

<320x128 sparse matrix of type '<class 'numpy.float64'>'
	with 578 stored elements in Compressed Sparse Row format>

In [11]:
# Scratch check: ten random integers below 20, returned as a 1-D CPU tensor.
torch.randint(20,(1,10)).detach().cpu()[0]

tensor([14, 18,  1, 19,  6,  3, 16,  7, 15, 18])

In [12]:
def standard_batch(size = sample_size):
    """Return a list of ``size`` random snippets, cycling through the classes.

    Snippet ``i`` is drawn from class ``i % len(data)``. For ``size`` up to the
    number of classes this is exactly one snippet per class, in class order,
    as before. Larger sizes now wrap around instead of raising IndexError on
    the hard-coded 20-entry class list.

    Raises
    ------
    ValueError
        If ``data`` is empty.
    """
    n_available = len(data)
    if n_available == 0:
        raise ValueError("no datasets loaded")
    return [random_snippet(i % n_available) for i in range(size)]

In [13]:
# Draw one batch of snippets to inspect by ear and by shape in the next cells.
batch1 = standard_batch()

In [14]:
# Convert the last snippet of the batch to a music21 stream (0.25 quarter lengths per step).
stream1 = DataToStream(batch1[sample_size-1], time_signature = 0.25)

In [15]:
# Render the stream with music21's 'midi' output format.
stream1.show('midi')

In [16]:
# Densify every sparse snippet and stack them into one array, then check its shape.
batch2 = np.array([a.todense() for a in batch1])
batch2.shape

(20, 320, 128)

In [17]:
def sample():
    """Draw a fresh batch and return it as a float tensor on ``device``.

    Each snippet from ``standard_batch()`` is densified and the results are
    stacked along a new leading axis. The target device is the notebook-level
    ``device`` setting rather than a second hard-coded copy of it.
    """
    snippets = standard_batch()
    # toarray() yields plain ndarrays (todense() returns the legacy np.matrix type).
    dense = np.array([s.toarray() for s in snippets])
    return torch.from_numpy(dense).to(device).float()
    

In [18]:
def torch_convert(song):
    """Turn one sparse song matrix into a float tensor with a leading batch axis of 1.

    The result has shape ``(1, song.shape[0], song.shape[1])`` and lives on the
    notebook-level ``device`` (no longer a locally hard-coded one).
    """
    dense = song.toarray()  # plain ndarray instead of the legacy np.matrix
    return torch.from_numpy(dense).to(device).float().unsqueeze(0)

In [19]:
# Smoke test of sample(): draw one batch tensor and display its shape.
batch = sample()
batch.shape

torch.Size([20, 320, 128])

In [20]:
class LingLing_VAE(nn.Module):
    """Sequence VAE: LSTM encoder -> Gaussian latent code -> LSTM decoder.

    The encoder's final hidden states (all layers) are projected to the mean
    and log-variance of the latent code ``z``. The decoder is initialised from
    ``z`` and additionally receives ``z`` concatenated to every input frame.
    It is trained to output, at every step, the probabilities of the next frame.

    Parameters
    ----------
    input_dim : int
        Number of features per time step.
    hidden_dim : int
        Hidden size of both LSTMs.
    latent_dim : int
        Size of the latent code.
    num_layers : int
        Number of stacked layers in both LSTMs.
    device : str or torch.device, optional
        Where to place the parameters. Defaults to ``"cuda:0"`` when CUDA is
        available (the previous hard-coded value) and to the CPU otherwise.

    NOTE: ``forward`` now keeps batch elements separate when flattening and
    unflattening the per-layer hidden states (see ``_flatten_layers``).
    Parameter names and shapes are unchanged, but weights trained with the old
    layout at batch size > 1 were fitted to mixed-up inputs and should be
    retrained.
    """

    def __init__(self, input_dim=128, hidden_dim=512, latent_dim = 20, num_layers = 3, device=None):
        super().__init__()

        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.num_layers = num_layers

        self.encoder = nn.LSTM(self.input_dim,self.hidden_dim,num_layers = self.num_layers)
        # Produces mu and log_var side by side, hence latent_dim*2 outputs.
        self.compressor = nn.Linear(self.hidden_dim*self.num_layers,self.latent_dim*2)
        self.decompressor = nn.Linear(self.latent_dim,self.hidden_dim*self.num_layers)
        self.decoder = nn.LSTM(self.input_dim+self.latent_dim,self.hidden_dim,num_layers=self.num_layers)
        self.output = nn.Linear(self.hidden_dim,self.input_dim)

        self.to(device)

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps`` standard normal."""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu + (eps * std)

    def _flatten_layers(self, hidden):
        """(num_layers, batch, hidden) -> (batch, num_layers*hidden), one row per sample.

        A plain ``view(-1, num_layers*hidden)`` would reinterpret memory and
        place hidden states of different samples in the same row when batch > 1.
        """
        return hidden.permute(1, 0, 2).reshape(hidden.shape[1], -1)

    def _initial_state(self, z):
        """Build the decoder's (h0, c0) from latent codes ``z`` of shape (batch, latent)."""
        hidden_out = self.decompressor(z)
        # Split each sample's row into its layers first, then move layers to the front.
        h0 = hidden_out.view(-1, self.num_layers, self.hidden_dim).permute(1, 0, 2).contiguous()
        c0 = torch.zeros_like(h0)  # same device/dtype as h0, no global `device` needed
        return h0, c0

    def forward(self, song):
        """Encode ``song`` and reconstruct frames 1..T-1 from frames 0..T-2.

        Parameters
        ----------
        song : tensor of shape (batch, T, input_dim)

        Returns
        -------
        reconstruction : (batch, T-1, input_dim) probabilities in (0, 1)
        mu, log_var : (batch, latent_dim) each
        """
        # Drop the last frame (it is only ever a target) and go time-major for the LSTM.
        inputs = song[:, :-1, :].permute(1,0,2)

        _,(hidden_rep, _) = self.encoder(inputs)
        latent_rep = self.compressor(self._flatten_layers(hidden_rep))
        latent_rep = latent_rep.view(-1,2,self.latent_dim)

        mu = latent_rep[:,0, :]
        log_var = latent_rep[:,1, :]
        z = self.reparameterize(mu, log_var)

        # Repeat z along the time axis so every decoder step sees it.
        latent_rep_context = z.repeat(inputs.shape[0],1,1)
        decoder_input_rep = torch.cat((inputs,latent_rep_context), 2)

        o1, _ = self.decoder(decoder_input_rep, self._initial_state(z))
        o1 = o1.permute(1,0,2)
        reconstruction = torch.sigmoid(self.output(o1))

        return reconstruction, mu, log_var

    def generate_song(self, song_start, mu, log_var, max_length=600):
        """Sample ``max_length`` new frames that continue ``song_start``.

        Parameters
        ----------
        song_start : tensor of shape (T, input_dim), T >= 1
            Seed frames used to warm up the decoder.
        mu, log_var : tensors of shape (1, latent_dim)
            Parameters of the latent distribution to sample ``z`` from.
        max_length : int
            Number of frames to generate.

        Returns
        -------
        Long tensor of shape (max_length, input_dim) holding 0/1 samples.

        Raises
        ------
        ValueError
            If ``song_start`` has no frames.
        """
        if song_start.shape[0] == 0:
            raise ValueError("song_start must contain at least one frame")

        # Pure inference: do not record an autograd graph across the steps.
        with torch.no_grad():
            z = self.reparameterize(mu, log_var)
            state = self._initial_state(z)

            frames = song_start.unsqueeze(1)  # (T, 1, input_dim), time-major
            step_context = z.unsqueeze(0)     # (1, 1, latent_dim)

            # Warm up on every seed frame except the last. The last one is the
            # first input of the sampling loop, so it is consumed exactly once
            # (it used to be fed to the decoder twice).
            if frames.shape[0] > 1:
                warmup_context = z.repeat(frames.shape[0] - 1, 1, 1)
                _, state = self.decoder(torch.cat((frames[:-1], warmup_context), 2), state)

            song_out = torch.zeros((max_length, self.input_dim), dtype=torch.long, device=z.device)
            prev_chords = frames[-1:]

            for t in range(max_length):
                decoder_input_rep = torch.cat((prev_chords, step_context), 2)
                out, state = self.decoder(decoder_input_rep, state)

                probs = torch.sigmoid(self.output(out.view(out.shape[1], out.shape[2])))
                # Sample each feature independently from its Bernoulli probability.
                formatted_out = torch.bernoulli(probs.view(-1))

                song_out[t,:] = formatted_out
                prev_chords = formatted_out.view(1, 1, self.input_dim)

        return song_out

In [21]:
#model = LingLing_VAE()
#sample_batch = sample()
#recon,mu,log_var = model(sample_batch)
#print(recon.shape,mu.shape,log_var.shape)

In [22]:
#song_start = sample_batch[0]
#mu = mu[0].view(1,mu.shape[1])
#log_var = log_var[0].view(1,log_var.shape[1])

#song_start.detach()
#mu.detach()
#log_var.detach()

#generated_sample = model.generate_song(song_start, mu, log_var, max_length=360)

In [23]:
# Build the model with its default sizes.
model = LingLing_VAE()
# Adam over all model parameters, with a small weight decay.
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas =(0.9,0.999), weight_decay = 0.000001, eps=1e-08)
# Binary cross-entropy summed (not averaged) over every element of the batch.
criterion = nn.BCELoss(reduction='sum')
#optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

In [24]:
# Resume from the saved checkpoint. map_location places the stored tensors on the
# model's own device, so a checkpoint written on another device still loads.
# NOTE: this is the same file the save cell below overwrites, so it must already
# exist on a fresh run.
model.load_state_dict(torch.load('Ling_ling_VAE2.pt', map_location=next(model.parameters()).device))

<All keys matched successfully>

In [25]:
# --- Training configuration -------------------------------------------------
iterations = 25000 + 1
new_sample_cycle = 1   # draw a fresh batch every this many iterations
log_interval = 250     # iterations between progress reports
# NOTE(review): the weight is very large; the logged weighted KLD values imply a
# tiny raw KL term - confirm this is intended.
KLD_weight = 1e9

# Running sums over the current logging window.
total_loss = 0
re_loss = 0
kl_loss = 0

# Flattened size of one target sequence: all frames but the first, times the features.
flat_dim = (song_length-1)*model.input_dim

for i in range(iterations):

    if i%new_sample_cycle == 0:
        sample_batch = sample()

    # Gradients are cleared at the start of every step (a separate call before
    # the loop is not needed).
    optimizer.zero_grad()
    reconstruction,mu,log_var = model(sample_batch)

    # The model predicts frame t+1 from frame t, so targets are shifted by one frame.
    reconstruction_loss = criterion(reconstruction.view(-1,flat_dim), sample_batch[:, 1:, :].view(-1,flat_dim))

    # KL divergence between N(mu, exp(log_var)) and N(0, I), summed over the batch.
    KLD_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

    loss = reconstruction_loss + KLD_weight*KLD_loss

    total_loss += loss.item()
    re_loss += reconstruction_loss.item()
    kl_loss += KLD_weight*KLD_loss.item()

    loss.backward()
    optimizer.step()

    if i%log_interval == log_interval - 1:
        window = log_interval*sample_size  # per-sample normaliser for the window sums
        print("Iteration :",i, " Loss :", loss.item()/sample_size, "Running Average Loss :", total_loss/window)
        print("Reconstruction Loss", re_loss/window," KLD Loss", kl_loss/window)
        # Raise the KL weight tenfold whenever the weighted KL term has dropped below 1000.
        if (kl_loss/window) < 1000:
            KLD_weight = KLD_weight*10
        total_loss,re_loss,kl_loss = 0,0,0
        
    #if i%new_sample_cycle == 0: 
    #    print("Iteration :",i, " Starting Loss :", loss.item()/sample_size)

Iteration : 249  Loss : 6968.57421875 Running Average Loss : 8963.6670328125
Reconstruction Loss 4214.4225734375  KLD Loss 4749.244451522827
Iteration : 499  Loss : 7743.296875 Running Average Loss : 8206.2008734375
Reconstruction Loss 4021.03095  KLD Loss 4185.16993522644
Iteration : 749  Loss : 6770.7125 Running Average Loss : 7031.8821578125
Reconstruction Loss 3713.86446171875  KLD Loss 3318.0177211761475
Iteration : 999  Loss : 7104.03984375 Running Average Loss : 7345.9416234375
Reconstruction Loss 3522.0354640625  KLD Loss 3823.906183242798
Iteration : 1249  Loss : 7081.2109375 Running Average Loss : 6413.2742
Reconstruction Loss 3287.19535390625  KLD Loss 3126.0788440704346
Iteration : 1499  Loss : 7330.40703125 Running Average Loss : 6351.1268359375
Reconstruction Loss 3083.802865625  KLD Loss 3267.3239707946777
Iteration : 1749  Loss : 5828.5421875 Running Average Loss : 5825.2931109375
Reconstruction Loss 2940.60113125  KLD Loss 2884.6919536590576
Iteration : 1999  Loss : 55

Iteration : 14749  Loss : 6724.9578125 Running Average Loss : 6156.464515625
Reconstruction Loss 1641.889499609375  KLD Loss 4514.575004577637
Iteration : 14999  Loss : 8316.6203125 Running Average Loss : 7302.9844625
Reconstruction Loss 1659.02065390625  KLD Loss 5643.963813781738
Iteration : 15249  Loss : 6343.821484375 Running Average Loss : 6367.197425
Reconstruction Loss 1645.973111328125  KLD Loss 4721.224308013916
Iteration : 15499  Loss : 8296.8421875 Running Average Loss : 6461.177353125
Reconstruction Loss 1637.373437109375  KLD Loss 4823.803901672363
Iteration : 15749  Loss : 7206.3734375 Running Average Loss : 7259.509115625
Reconstruction Loss 1624.426391015625  KLD Loss 5635.082721710205
Iteration : 15999  Loss : 5400.818359375 Running Average Loss : 6959.215628125
Reconstruction Loss 1629.72591328125  KLD Loss 5329.489707946777
Iteration : 16249  Loss : 6007.8515625 Running Average Loss : 5484.847025
Reconstruction Loss 1612.69086953125  KLD Loss 3872.1561431884766
Itera

In [26]:
# Persist the trained weights; this overwrites the checkpoint that was loaded before training.
torch.save(model.state_dict(), 'Ling_ling_VAE2.pt')

In [31]:
# Switch the module to evaluation mode before generating samples.
model.eval()

LingLing_VAE(
  (encoder): LSTM(128, 512, num_layers=3)
  (compressor): Linear(in_features=1536, out_features=40, bias=True)
  (decompressor): Linear(in_features=20, out_features=1536, bias=True)
  (decoder): LSTM(148, 512, num_layers=3)
  (output): Linear(in_features=512, out_features=128, bias=True)
)

In [68]:
# Encode a fresh batch to obtain latent parameters for generation.
sample_batch = sample()
# Inference only: skip autograd bookkeeping for this forward pass.
with torch.no_grad():
    recon,mu,log_var = model(sample_batch)
print(recon.shape,mu.shape,log_var.shape)

torch.Size([20, 319, 128]) torch.Size([20, 20]) torch.Size([20, 20])


In [69]:
#0 - bach, 3-Beethoven, 6-Chopin, 12-Mozart, 13-Nirvana
# Position in the batch to continue; standard_batch puts class i at position i.
selected = 19
song_start = sample_batch[selected]
# NOTE: `mu` and `log_var` are rebound from the whole batch to a single row, so
# re-run the previous cell before running this one again.
mu = mu[selected].detach().view(1, -1)
log_var = log_var[selected].detach().view(1, -1)

In [70]:
# Bring the seed snippet back to the CPU as a NumPy array.
song = song_start.detach().cpu().numpy()
# DataToStream calls .todense() on rows, so wrap the array as a sparse matrix.
song = sparse.csr_matrix(song)
start_stream = DataToStream(song, time_signature = 0.25)
start_stream.show('midi')

In [71]:
# Sample 240 new frames that continue the seed snippet.
generated_sample = model.generate_song(song_start, mu, log_var, max_length=240)
song = generated_sample.detach().cpu().numpy()
# DataToStream calls .todense() on rows, so wrap the array as a sparse matrix.
song = sparse.csr_matrix(song)
generated_stream = DataToStream(song, time_signature = 0.25)
generated_stream.show('midi')

In [72]:
# Text listing of the generated stream (offset followed by element).
generated_stream.show('text')

{0.0} <music21.instrument.Piano 'Piano'>
{0.0} <music21.chord.Chord E3 F#3 E4 F#4 B4 G#5>
{0.25} <music21.chord.Chord C#2 F2 G2 G#2 A3 D5>
{0.5} <music21.note.Note G#>
{0.75} <music21.chord.Chord E2 C4 C#4 C#5>
{1.0} <music21.chord.Chord G3 C4 E4 F4>
{1.25} <music21.chord.Chord A2 G#3 B-6>
{1.5} <music21.chord.Chord B3 G#4>
{1.75} <music21.chord.Chord G2 G4 B4>
{2.0} <music21.chord.Chord E2 B2 C#4 D4 E-4>
{2.25} <music21.chord.Chord C2 C#3 C4 E4 F#4>
{2.5} <music21.chord.Chord A2 A5>
{2.75} <music21.chord.Chord B2 A3 D4>
{3.0} <music21.chord.Chord G2 G#2 F#3 A3 B3 C4 E4 F#4 B-4 E5>
{3.25} <music21.note.Note B>
{3.5} <music21.chord.Chord D2 D3>
{3.75} <music21.chord.Chord G#2 E3 F#3 G4 F#5 G#5>
{4.0} <music21.chord.Chord D2 F2 F#4>
{4.25} <music21.chord.Chord B1 G3 F4 G5 C6>
{4.5} <music21.chord.Chord E2 B-3 B3 C#4 C5 E-5>
{4.75} <music21.note.Note G>
{5.0} <music21.chord.Chord G#2 G3 D4 G5>
{5.25} <music21.chord.Chord F#2 C#3>
{5.5} <music21.chord.Chord F#1 D4 E5>
{5.75} <music21.note.

In [73]:
# Text listing of the seed stream, for comparison with the generated one.
start_stream.show('text')

{0.0} <music21.instrument.Piano 'Piano'>
{0.0} <music21.chord.Chord D2 A2 F3 A3 C#4>
{1.0} <music21.chord.Chord D2 A2 F3 A3 D4>
{1.25} <music21.note.Note B->
{1.5} <music21.chord.Chord B-2 G4>
{1.75} <music21.note.Note B->
{2.0} <music21.chord.Chord B-2 A4>
{2.25} <music21.chord.Chord G3 C4 E4 B-4>
{2.5} <music21.chord.Chord G3 C4 E4 B4>
{2.75} <music21.chord.Chord G3 C4 E4>
{3.0} <music21.chord.Chord G3 C4 E4 C5>
{3.25} <music21.chord.Chord A2 F5>
{4.0} <music21.chord.Chord A2 E5>
{4.25} <music21.chord.Chord A3 C4 F4 E5>
{5.0} <music21.chord.Chord A3 C4 F4 D5>
{5.25} <music21.chord.Chord G2 D5>
{5.5} <music21.chord.Chord G2 E4 D5>
{5.75} <music21.chord.Chord G2 D5>
{6.0} <music21.chord.Chord G2 F4 D5>
{6.25} <music21.chord.Chord B-3 C4 E4 F#4 C5>
{6.5} <music21.chord.Chord B-3 C4 E4 G4 C5>
{6.75} <music21.chord.Chord B-3 C4 E4 C5>
{7.0} <music21.chord.Chord B-3 C4 E4 G#4 C5>
{7.25} <music21.chord.Chord F2 B-4 C5>
{7.75} <music21.chord.Chord F2 B-4>
{8.0} <music21.chord.Chord F2 B-4 B4

In [188]:
# Create the output folder first so the writes do not fail on a fresh checkout.
os.makedirs('output_samples', exist_ok=True)
start_stream.write('midi', fp='output_samples/misc_in.mid')
generated_stream.write('midi', fp='output_samples/misc_out.mid')

'output_samples/chopin_out.mid'