In [1]:
from music21 import converter, instrument, note, chord, stream, midi, instrument
from scipy import sparse
import time
import tqdm.auto
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import math

In [2]:
def DataToStream(data_mat, time_signature = 0.25):
    melody_stream = stream.Stream()
    melody_stream.append(instrument.Piano())
    t,k = data_mat.shape
    old_element = None
    counter = 1
    for i in range(t):
        arr = data_mat[i,:]
        r = np.int(np.sum(arr))
        
        if r == 1:
            #print(np.argmax(arr))
            new_element = note.Note(np.argmax(arr))
        elif r > 1:
            #print(arr)
            arr = arr.todense()
            #print(np.where(arr == 1)[1])
            pitches = np.where(arr == 1)[1]
            all_notes = np.array([note.Note(i) for i in pitches])
            new_element = chord.Chord(all_notes)
        else:
            #print(r)
            new_element = note.Rest()
        
        if new_element == old_element and counter < 6:
            counter += 1
        else:
            if old_element:
                old_element.quarterLength = time_signature*counter
                melody_stream.append(old_element)
            counter = 1
            
        old_element = new_element
            
    return melody_stream

In [3]:
classes = ['bach','backstreetboys','beatles','beethoven','brahms','britneyspears',
             'chopin','coldplay','debussy','haydn','liszt','mendelssohn',
            'mozart','nirvana','paganini','queen','rachmaninow','schubert',
            'schumann','tchaikovsky']
datasets = [folder + '_dataset.npz' for folder in classes]
labels = np.arange(20)

In [4]:
song_length = 1*60*6
sample_size = 20
n_classes = len(classes)
device = "cuda:0"

In [5]:
data = []
for file in datasets:
    d = []
    subset = np.load(file,allow_pickle=True)
    for item in subset.files:
        d.append(subset[item])
    data.append(d[0])
    

In [6]:
len(data)

20

In [7]:
data[19][19]

<2388x128 sparse matrix of type '<class 'numpy.float64'>'
	with 6871 stored elements in Compressed Sparse Row format>

In [8]:
def random_snippet(c):
    n_songs = len(data[c])
    song_no = np.random.randint(n_songs)
    song = data[c][song_no]
    L,D = song.shape
    start_max = L-song_length
    if start_max > 0:
        song_start = np.random.randint(start_max)
    else:
        #print(c,song_no)
        return random_snippet(c)
    
    return song[song_start:song_start+song_length,:]

In [9]:
random_snippet(3)

<360x128 sparse matrix of type '<class 'numpy.float64'>'
	with 809 stored elements in Compressed Sparse Row format>

In [10]:
torch.randint(20,(1,10)).detach().cpu()[0]

tensor([ 8, 10,  3, 18, 15, 17, 19, 12,  2,  1])

In [11]:
def standard_batch(size = sample_size):
    batch = []
    #class_id = torch.randint(20,(1,sample_size)).detach().cpu()[0]
    class_id = [0,6]*int(sample_size/2)
    for i in range(size):
        snippet = random_snippet(class_id[i])
        batch.append(snippet)
    return batch

In [12]:
batch1 = standard_batch()

In [13]:
stream1 = DataToStream(batch1[sample_size-1], time_signature = 0.25)

In [14]:
stream1.show('midi')

In [15]:
batch2 = np.array([a.todense() for a in batch1])
batch2.shape

(20, 360, 128)

In [16]:
def sample():
    batch1 = standard_batch()
    batch2 = np.array([a.todense() for a in batch1])
    device = "cuda:0"
    batch3 = torch.from_numpy(batch2).to(device).float()
    return batch3
    

In [17]:
def torch_convert(song):
    batch2 = song.todense()
    device = "cuda:0"
    batch3 = torch.from_numpy(batch2).to(device).float()
    return batch3.view(1,song.shape[0],song.shape[1])

In [18]:
batch = sample()
batch.shape

torch.Size([20, 360, 128])

In [61]:
class LingLing_VAE(nn.Module):
    def __init__(self, dim = 128, feature_len = 256, kernel_num = 64, channel_num = 1):
        super(LingLing_VAE, self).__init__()
        
        device = "cuda:0"
        
        self.dim = dim
        self.latent_dim = feature_len//2
        self.feature_len = self.latent_dim*2
        self.kernel_num = kernel_num
        
        
        self.encoder = nn.Sequential(
            self._conv(channel_num, kernel_num // 4),
            self._conv(kernel_num // 4, kernel_num // 2),
            self._conv(kernel_num // 2, kernel_num),
        )
        
        self.decoder = nn.Sequential(
            self._deconv(kernel_num, kernel_num // 2),
            self._deconv(kernel_num // 2, kernel_num // 4),
            self._deconv(kernel_num // 4, channel_num),
            nn.Sigmoid()
        )
        
        self.n_features = (song_length//8)*(self.dim//8)*kernel_num
        
        self.compressor = nn.Linear(self.n_features,self.feature_len)
        self.decompressor = nn.Linear(self.latent_dim,self.n_features)
        
        self.compressor.to(device)
        self.decompressor.to(device)
        self.encoder.to(device)
        self.decoder.to(device)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        sample = mu + (eps * std)
        return sample
        

    def forward(self, song):
        
        encoded = self.encoder(song.unsqueeze(1))
        features = self.compressor(encoded.view(-1,self.n_features))        
        features = features.view(-1, 2, self.latent_dim)
        
        mu = features[:, 0, :]
        log_var = features[:, 1, :]
        #print(mu.shape,log_var.shape)
        #z = self.reparameterize(mu, log_var)
        z = mu
        
        decoder_input = self.decompressor(z)
        decoder_input = decoder_input.view(-1, self.kernel_num,(song_length//8),(self.dim//8))
        
        reconstruction = self.decoder(decoder_input).view(-1,song_length,self.dim)
                
        return reconstruction, mu, log_var
    
    def _conv(self, channel_size, kernel_num):
        return nn.Sequential(
            nn.Conv2d(
                channel_size, kernel_num,
                kernel_size=4, stride=2, padding=1,
            ),
            nn.BatchNorm2d(kernel_num),
            nn.ReLU(),
        )

    def _deconv(self, channel_num, kernel_num):
        return nn.Sequential(
            nn.ConvTranspose2d(
                channel_num, kernel_num,
                kernel_size=4, stride=2, padding=1,
            ),
            nn.BatchNorm2d(kernel_num),
            nn.ReLU(),
        )


In [62]:
#sample_batch = sample()
#recon,mu,log_var = model(sample_batch)
#print(recon.shape,mu.shape,log_var.shape)

In [63]:
#torch.sum(recon)

In [77]:
model = LingLing_VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas =(0.9,0.999), weight_decay = 0.00001, eps=1e-08)
criterion = nn.MSELoss(reduction='sum')
#criterion = nn.BCELoss(reduction='sum')
optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

In [78]:
model.load_state_dict(torch.load('Ling_ling_VAE.pt'))

<All keys matched successfully>

In [80]:
optimizer.zero_grad()
iterations = 5000 + 1
total_loss = 0
re_loss = 0
new_sample_cycle = 1
KLD_weight = 100000

for i in range(iterations):
    
    if i%new_sample_cycle == 0:
        sample_batch = sample()

    optimizer.zero_grad()
    reconstruction,mu,log_var = model(sample_batch)
    
    #print(reconstruction,sample_batch)
    
    reconstruction_loss = criterion(reconstruction.view(-1,song_length*128), sample_batch.view(-1,song_length*128))
    
    #KLD_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
    
    loss = reconstruction_loss #+ KLD_weight*KLD_loss
    #print(loss.item(),bce_loss.item(),KLD_weight*KLD_loss.item())
    
    total_loss += loss.item()
    re_loss += reconstruction_loss.item()
    
    loss.backward()
    optimizer.step()
    
    if i%250 == 249:
        print("Iteration :",i, " Loss :", loss.item()/sample_size, "Running Average Loss :", total_loss/(250*sample_size))
        #print("Reconstruction Loss", re_loss/(250*sample_size)," KLD Loss", KLD_weight*KLD_loss.item()/sample_size)
        total_loss = 0
        re_loss = 0
        
    #if i%new_sample_cycle == 0: 
    #    print("Iteration :",i, " Starting Loss :", loss.item()/sample_size)

Iteration : 249  Loss : 11520.0 Running Average Loss : 11520.0
Iteration : 499  Loss : 11520.0 Running Average Loss : 11520.000015625
Iteration : 749  Loss : 11520.0 Running Average Loss : 11520.0
Iteration : 999  Loss : 11520.0 Running Average Loss : 11520.0
Iteration : 1249  Loss : 11520.0 Running Average Loss : 11520.0
Iteration : 1499  Loss : 11520.0 Running Average Loss : 11520.000003125


KeyboardInterrupt: 

In [67]:
torch.save(model.state_dict(), 'Ling_ling_VAE.pt')

In [40]:
sample_batch = sample()

In [41]:
song = sample_batch[1].detach().cpu().numpy()
song = sparse.csr_matrix(song)
stream1 = DataToStream(song, time_signature = 0.25)

In [42]:
stream1.show('midi')

In [43]:
def make_it_music(not_a_song, max_notes = 3, song_length = 360):
    threshold = -1
    avg_notes = 100
    while avg_notes > max_notes:
        threshold = threshold + 0.01
        s = (not_a_song>threshold).sum()
        avg_notes = s.item()/song_length

    print(avg_notes, threshold)
    return (not_a_song>threshold).int()

In [44]:
def make_it_music2(not_a_song, max_notes = 3, song_length = 360):
    threshold = 0.5

    s = (not_a_song>threshold).sum()
    avg_notes = s.item()/song_length

    print(avg_notes, threshold)
    return (not_a_song>threshold).int()

In [54]:
batch1 = sample()
song = batch1[0].detach().cpu().numpy()
song = sparse.csr_matrix(song)
stream1 = DataToStream(song, time_signature = 0.25)

In [55]:
stream1.show('midi')

In [56]:
reconstructed_songs,mu,log_var = model(batch1)
test = reconstructed_songs[0]

RuntimeError: CUDA out of memory. Tried to allocate 58.00 MiB (GPU 0; 6.00 GiB total capacity; 4.39 GiB already allocated; 48.91 MiB free; 4.55 GiB reserved in total by PyTorch)

In [57]:
test_song = make_it_music(test, max_notes=3)

2.911111111111111 0.09000000000000075


In [58]:
test_song = test_song.detach().cpu().numpy()
test_song = sparse.csr_matrix(test_song)
test_song.shape

(360, 128)

In [59]:
test_stream = DataToStream(test_song, time_signature = 0.25)

In [51]:
test_stream.show('midi')

In [52]:
test_stream.show('text')

{0.0} <music21.instrument.Piano 'Piano'>
{0.0} <music21.chord.Chord B-2 G5>
{0.25} <music21.chord.Chord E-2 C3 G#5>
{0.5} <music21.chord.Chord C3 G5>
{0.75} <music21.chord.Chord A2 D3 F5>
{1.0} <music21.chord.Chord D3 G#5>
{1.25} <music21.chord.Chord E-3 G5>
{1.5} <music21.chord.Chord E-3 B-5>
{1.75} <music21.chord.Chord C3 G#5>
{2.0} <music21.chord.Chord C3 E5 C6>
{2.25} <music21.chord.Chord G2 B-5>
{2.5} <music21.chord.Chord G2 E-5>
{2.75} <music21.chord.Chord G#2 D5>
{3.0} <music21.chord.Chord G#2 E-5>
{3.25} <music21.chord.Chord B-1 B-2 G#5>
{3.5} <music21.chord.Chord B-2 G5>
{3.75} <music21.chord.Chord G#1 B-1 B-2 F5 B-5 C6 C#6>
{4.0} <music21.chord.Chord G#1 B-1 C#2 E-5 C6>
{4.25} <music21.chord.Chord E-1 E-2 G4 B-4 E-5 C6>
{4.5} <music21.chord.Chord E-2 G4 B-4 E-5>
{4.75} <music21.chord.Chord E-3 G4 B-4 E-5>
{5.5} <music21.chord.Chord E-3 C#4 G4 B-4 E-5>
{5.75} <music21.chord.Chord D3 B-3 G4 B-4 E-5>
{6.0} <music21.chord.Chord D3 G4 B-4 E-5>
{6.75} <music21.chord.Chord C3 G4 B-4

In [53]:
stream1.show('text')

{0.0} <music21.instrument.Piano 'Piano'>
{0.0} <music21.chord.Chord B-2 G5>
{0.25} <music21.chord.Chord C3 G#5>
{0.5} <music21.chord.Chord C3 G5>
{0.75} <music21.chord.Chord D3 F5>
{1.0} <music21.chord.Chord D3 G#5>
{1.25} <music21.chord.Chord E-3 G5>
{1.5} <music21.chord.Chord E-3 B-5>
{1.75} <music21.chord.Chord C3 G#5>
{2.0} <music21.chord.Chord C3 C6>
{2.25} <music21.chord.Chord G2 B-5>
{2.5} <music21.chord.Chord G2 E-5>
{2.75} <music21.chord.Chord G#2 D5>
{3.0} <music21.chord.Chord G#2 E-5>
{3.25} <music21.chord.Chord B-2 G#5>
{3.5} <music21.chord.Chord B-2 G5>
{3.75} <music21.chord.Chord B-1 F5>
{4.0} <music21.chord.Chord B-1 E-5>
{4.25} <music21.chord.Chord E-2 G4 B-4 E-5>
{4.75} <music21.chord.Chord E-3 G4 B-4 E-5>
{5.75} <music21.chord.Chord D3 G4 B-4 E-5>
{6.75} <music21.chord.Chord C3 G4 B-4 E-5>
{7.25} <music21.chord.Chord C3 E-5 G5 C6>
{8.0} <music21.chord.Chord C3 E-5 G5>
{8.25} <music21.chord.Chord C2 F5 D6>
{8.75} <music21.chord.Chord C2 E-5 C6>
{9.25} <music21.chord.Ch