In [1]:
from music21 import converter, instrument, note, chord, stream, midi, instrument
from scipy import sparse
import time
import tqdm.auto
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import math

In [2]:
def DataToStream(data_mat, time_signature = 0.25):
    melody_stream = stream.Stream()
    melody_stream.append(instrument.Piano())
    t,k = data_mat.shape
    old_element = None
    counter = 1
    for i in range(t):
        arr = data_mat[i,:]
        r = np.int(np.sum(arr))
        
        if r == 1:
            #print(np.argmax(arr))
            new_element = note.Note(np.argmax(arr))
        elif r > 1:
            #print(arr)
            arr = arr.todense()
            #print(np.where(arr == 1)[1])
            pitches = np.where(arr == 1)[1]
            all_notes = np.array([note.Note(i) for i in pitches])
            new_element = chord.Chord(all_notes)
        else:
            #print(r)
            new_element = note.Rest()
        
        if new_element == old_element and counter < 6:
            counter += 1
        else:
            if old_element:
                old_element.quarterLength = time_signature*counter
                melody_stream.append(old_element)
            counter = 1
            
        old_element = new_element
            
    return melody_stream

In [3]:
classes = ['bach','backstreetboys','beatles','beethoven','brahms','britneyspears',
             'chopin','coldplay','debussy','haydn','liszt','mendelssohn',
            'mozart','nirvana','paganini','queen','rachmaninow','schubert',
            'schumann','tchaikovsky']
datasets = [folder + '_dataset.npz' for folder in classes]
labels = np.arange(20)

In [4]:
song_length = 320
sample_size = 20
n_classes = len(classes)
device = "cuda:0"

In [5]:
data = []
for file in datasets:
    d = []
    subset = np.load(file,allow_pickle=True)
    for item in subset.files:
        d.append(subset[item])
    data.append(d[0])
    

In [6]:
len(data)

20

In [7]:
data[19][19]

<2388x128 sparse matrix of type '<class 'numpy.float64'>'
	with 6871 stored elements in Compressed Sparse Row format>

In [8]:
def random_snippet(c):
    n_songs = len(data[c])
    song_no = np.random.randint(n_songs)
    song = data[c][song_no]
    L,D = song.shape
    start_max = L-song_length
    if start_max > 0:
        song_start = np.random.randint(start_max)
    else:
        #print(c,song_no)
        return random_snippet(c)
    
    return song[song_start:song_start+song_length,:]

In [9]:
random_snippet(3)

<320x128 sparse matrix of type '<class 'numpy.float64'>'
	with 1095 stored elements in Compressed Sparse Row format>

In [10]:
torch.randint(20,(1,10)).detach().cpu()[0]

tensor([ 8,  4, 13, 10,  6,  0,  7, 19,  3, 13])

In [11]:
def standard_batch(size = sample_size):
    batch = []
    #class_id = torch.randint(20,(1,sample_size)).detach().cpu()[0]
    class_id = [0,3,6,12,13]*int(sample_size/5)
    for i in range(size):
        snippet = random_snippet(class_id[i])
        batch.append(snippet)
    return batch

In [12]:
batch1 = standard_batch()

In [13]:
stream1 = DataToStream(batch1[sample_size-1], time_signature = 0.25)

In [14]:
stream1.show('midi')

In [15]:
batch2 = np.array([a.todense() for a in batch1])
batch2.shape

(20, 320, 128)

In [16]:
def sample():
    batch1 = standard_batch()
    batch2 = np.array([a.todense() for a in batch1])
    device = "cuda:0"
    batch3 = torch.from_numpy(batch2).to(device).float()
    return batch3
    

In [17]:
def torch_convert(song):
    batch2 = song.todense()
    device = "cuda:0"
    batch3 = torch.from_numpy(batch2).to(device).float()
    return batch3.view(1,song.shape[0],song.shape[1])

In [18]:
batch = sample()
batch.shape

torch.Size([20, 320, 128])

In [19]:
class LingLing_VAE(nn.Module):
    def __init__(self, dim = 128, nheads = 8, nlayers = 6, features = 1):
        super(LingLing_VAE, self).__init__()
        
        device = "cuda:0"
        self.dim = dim
        self.feature_len = features
        
        self.positional_encodings = nn.Embedding(song_length, dim)
        
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=nheads)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=nlayers)

        self.decoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=nheads)
        self.transformer_decoder = nn.TransformerEncoder(self.decoder_layer, num_layers=nlayers)
        
        
        self.positional_encodings.to(device)
        self.transformer_encoder.to(device)
        self.transformer_decoder.to(device)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        sample = mu + (eps * std)
        return sample
        

    def forward(self, song):
        
        pos = torch.arange(0, song_length).unsqueeze(0).repeat(sample_size, 1).to(device)
        pos_embeddings = self.positional_encodings(pos)
        pos_song = song + pos_embeddings

        latent_rep = self.transformer_encoder(pos_song)
        
        frac = song_length//self.feature_len
        mu = latent_rep[:,0, :]
        log_var = latent_rep[:,-1, :] 
        z = self.reparameterize(mu, log_var)
        #z = mu
        
        decoder_input_rep = z.unsqueeze(1).repeat(1,frac,1).to(device)
        decoder_input_rep = decoder_input_rep + pos_embeddings
        reconstruction = self.transformer_decoder(decoder_input_rep)
        reconstruction = torch.sigmoid(reconstruction)
                
        return reconstruction, mu, log_var

In [20]:
#sample_batch = sample()
#recon,mu,log_var = model(sample_batch)
#print(recon.shape,mu.shape,log_var.shape)

In [21]:
#torch.sum(recon)

In [22]:
model = LingLing_VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas =(0.9,0.999), weight_decay = 0.000001, eps=1e-08)
#criterion = nn.MSELoss(reduction='sum')
criterion = nn.BCELoss(reduction='sum')
#optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

In [23]:
model.load_state_dict(torch.load('Ling_ling_Transformer_VAE.pt'))

<All keys matched successfully>

In [78]:
optimizer.zero_grad()
iterations = 5000 + 1
total_loss = 0
re_loss = 0
kl_loss = 0
new_sample_cycle = 1
KLD_weight = 1000

for i in range(iterations):
    
    if i%new_sample_cycle == 0:
        sample_batch = sample()

    optimizer.zero_grad()
    reconstruction,mu,log_var = model(sample_batch)
        
    reconstruction_loss = criterion(reconstruction.view(-1,song_length*128), sample_batch.view(-1,song_length*128))
    
    KLD_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
    
    loss = reconstruction_loss + KLD_weight*KLD_loss
    
    total_loss += loss.item()
    re_loss += reconstruction_loss.item()
    kl_loss += KLD_weight*KLD_loss.item()
    
    loss.backward()
    optimizer.step()
    
    if i%250 == 249:
        print("Iteration :",i, " Loss :", loss.item()/sample_size, "Running Average Loss :", total_loss/(250*sample_size))
        print("Reconstruction Loss", re_loss/(250*sample_size)," KLD Loss", kl_loss/(250*sample_size))
        if (kl_loss/(250*sample_size)) < 10:
            KLD_weight = KLD_weight*10
        total_loss,re_loss,kl_loss = 0,0,0
        
    #if i%new_sample_cycle == 0: 
    #    print("Iteration :",i, " Starting Loss :", loss.item()/sample_size)

Iteration : 249  Loss : 3965.328125 Running Average Loss : 3875.99588125
Reconstruction Loss 3875.960603125  KLD Loss 0.03529240489006043
Iteration : 499  Loss : 3800.581640625 Running Average Loss : 3860.87831171875
Reconstruction Loss 3860.7133578125  KLD Loss 0.16495633125305176
Iteration : 749  Loss : 3473.268359375 Running Average Loss : 3832.774109375
Reconstruction Loss 3832.06891875  KLD Loss 0.705181360244751
Iteration : 999  Loss : 3812.240234375 Running Average Loss : 3853.047475
Reconstruction Loss 3849.8421703125  KLD Loss 3.2052993774414062
Iteration : 1249  Loss : 3952.744921875 Running Average Loss : 4164.86530703125
Reconstruction Loss 3820.4320078125  KLD Loss 344.4333076477051


KeyboardInterrupt: 

In [79]:
torch.save(model.state_dict(), 'Ling_ling_Transformer_VAE.pt')

In [80]:
batch1 = sample()
song = batch1[1].detach().cpu().numpy()
song = sparse.csr_matrix(song)
stream1 = DataToStream(song, time_signature = 0.25)

In [81]:
stream1.show('midi')

In [82]:
reconstructed_songs,mu,log_var = model(batch1)
test = reconstructed_songs[1]

In [83]:
test_song = (test>0.1).int()

In [84]:
test_song = test_song.detach().cpu().numpy()
test_song = sparse.csr_matrix(test_song)
test_song.shape

(320, 128)

In [85]:
test_stream = DataToStream(test_song, time_signature = 0.25)

In [86]:
test_stream.show('midi')

In [87]:
test_stream.show('text')

{0.0} <music21.instrument.Piano 'Piano'>
{0.0} <music21.chord.Chord G3 A3 C4 D4 G4 A4>
{0.25} <music21.chord.Chord G3 A3 C4 D4 G4>
{0.5} <music21.chord.Chord G3 A3 C4 D4 E4 G4 A4>
{0.75} <music21.chord.Chord G3 A3 C4 D4 E4 G4>
{1.0} <music21.chord.Chord G3 A3 C4 D4 E4 G4 A4>
{1.25} <music21.chord.Chord G3 A3 C4 D4 E4 G4>
{1.75} <music21.chord.Chord G3 A3 C4 D4 E4 G4 A4>
{2.0} <music21.chord.Chord G3 A3 B3 C4 D4 E4 G4>
{2.25} <music21.chord.Chord G3 A3 C4 D4 E4 G4 A4>
{2.75} <music21.chord.Chord G3 A3 C4 D4 G4>
{3.0} <music21.chord.Chord G3 A3 C4 D4 G4 A4>
{3.25} <music21.chord.Chord G3 A3 C4 D4 E4 G4 A4>
{3.5} <music21.chord.Chord G3 A3 C4 D4 G4 A4>
{3.75} <music21.chord.Chord G3 A3 C4 D4 E4 G4>
{4.0} <music21.chord.Chord G3 A3 C4 D4 G4 A4>
{4.5} <music21.chord.Chord G3 A3 C4 D4 E4 G4 A4>
{4.75} <music21.chord.Chord G3 A3 C4 D4 G4 A4>
{5.0} <music21.chord.Chord G3 A3 C4 D4 E4 G4>
{5.5} <music21.chord.Chord G3 A3 C4 D4 G4 A4>
{5.75} <music21.chord.Chord G3 A3 C4 D4 E4 G4 A4>
{6.0} <musi

In [88]:
stream1.show('text')

{0.0} <music21.instrument.Piano 'Piano'>
{0.0} <music21.chord.Chord E3 B-4 B-5 D6>
{0.25} <music21.chord.Chord C3 B-4 B-5 B5 C6>
{0.5} <music21.chord.Chord E3 B-4 B-5 C6>
{0.75} <music21.chord.Chord C3 B-4 B-5 C6>
{1.0} <music21.chord.Chord E3 G4 G5 C6>
{1.25} <music21.chord.Chord C3 G4 G5 C6>
{1.5} <music21.chord.Chord F3 A4 A5 C6>
{1.75} <music21.chord.Chord C3 A4 A5 B5 C6>
{2.0} <music21.chord.Chord F3 A4 A5 D6>
{2.25} <music21.chord.Chord C3 A4 A5 B5 C6>
{2.5} <music21.chord.Chord F3 A4 A5 C6>
{2.75} <music21.chord.Chord C3 A4 A5 C6>
{3.0} <music21.chord.Chord A2 F4 F5 C6>
{3.25} <music21.chord.Chord C3 F4 F5 C6>
{3.5} <music21.chord.Chord B-2 G4 G5 C6>
{3.75} <music21.chord.Chord C3 G4 G5 C6>
{4.0} <music21.chord.Chord G2 G4 G5 C6>
{4.25} <music21.chord.Chord C3 E4 E5>
{4.5} <music21.chord.Chord B-2 C6>
{4.75} <music21.note.Note C>
{5.0} <music21.chord.Chord G2 C6>
{5.25} <music21.note.Note C>
{5.5} <music21.chord.Chord F2 F3 F5>
{5.75} <music21.chord.Chord F2 F3 E5 F5>
{6.0} <mus