In [2]:
import os
import re
import sys
import torch
import hashlib
import itertools
import logging

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')


from progress.bar import Bar
from concurrent.futures import ProcessPoolExecutor

from Utils.NotesSeq import NoteSeq as ns
from Utils.EventSeq import EventSeq as es
from Utils.ControlSeq import ControlSeq as cs
from Utils import utils

import warnings
warnings.filterwarnings("ignore")

In [3]:
import numpy as np

In [4]:
def preprocess_midi(path):

    note_seq = ns.from_midi_file(path)
    iter = itertools.cycle(note_seq)

    es1 = list
    es2 = list

    cs1 = list
    cs2 = list
    ev = True
    #print("evaluating")
    for seq in iter:
        #print(seq)
        if ev:
            #print("1st iter")
            seq.adjust_time(-seq.notes[0].start)
            print(-seq.notes[0].start)
            event_seq = es.from_note_seq(seq)
            print(event_seq)
            control_seq = cs.from_event_seq(event_seq)
            es1 = event_seq.to_array()
            cs1 = control_seq.to_compressed_array()
            ev = False
        else:
            #print("2nd iter")
            seq.adjust_time(-seq.notes[0].start)
            event_seq = es.from_note_seq(seq)
            control_seq = cs.from_event_seq(event_seq)
            es2 = event_seq.to_array()
            cs2 = control_seq.to_compressed_array()
            break

    return es1, es2, cs1, cs2

In [5]:
midi_paths = list(utils.find_files_by_extensions('./Dataset-RAW', ['.mid', '.midi']))

In [6]:
es1, es2, cs1, cs2 = preprocess_midi(midi_paths[0])

-0.0
<Utils.EventSeq.EventSeq object at 0x132f71160>


In [7]:
from http.client import ImproperConnectionState
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable

In [8]:
device = torch.device('mps')

In [9]:
batch_size = 64
window_size = 200
stride_size = 10
control_ratio = 1.0
teacher_forcing_ratio = 1.0
init_dim = 32
event_dim = 240

In [10]:
from dataset import Dataset

def load_dataset():
    dataset = Dataset('./ProcessedControls-RAW', verbose=True)
    dataset_size = len(dataset.samples)
    assert dataset_size > 0
    return dataset

In [11]:
dataset = load_dataset()

In [12]:
dataset

Dataset(root="./ProcessedControls-RAW", samples=6, avglen=9079.166666666666)

In [13]:
from torch.utils.tensorboard import SummaryWriter
import time

writer = SummaryWriter()

last_saving_time = time.time()
loss_function = nn.CrossEntropyLoss()


2024-07-11 01:02:27,486 - DEBUG - Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.
2024-07-11 01:02:27,802 - DEBUG - Creating converter from 7 to 5
2024-07-11 01:02:27,803 - DEBUG - Creating converter from 5 to 7
2024-07-11 01:02:27,803 - DEBUG - Creating converter from 7 to 5
2024-07-11 01:02:27,803 - DEBUG - Creating converter from 5 to 7


In [14]:
from model import PerformanceRNN

model_artist = PerformanceRNN(240,24,32,512)
model_artist = model_artist.to(device)
model_AI = PerformanceRNN(240,24,32,512)
model_AI = model_artist.to(device)

In [15]:
optimizer1 = optim.Adam(model_artist.parameters(), lr=0.001)
optimizer2 = optim.Adam(model_AI.parameters(), lr=0.001)

In [16]:
def save_model(model, optimizer, model_config, sess_path):
    print('Saving to', sess_path)
    torch.save({'model_config': model_config,
                'model_state': model.state_dict(),
                'model_optimizer_state': optimizer.state_dict()}, sess_path)
    print('Done saving')

In [62]:
try:
    batch_gen = dataset.batches(batch_size, window_size, stride_size)

    for iteration, (events, controls, events2, controls2) in enumerate(batch_gen):

        #1st generation

        events = torch.LongTensor(events).to(device)
        assert events.shape[0] == window_size

        if np.random.random() < control_ratio:
            controls = torch.FloatTensor(controls).to(device)
            assert controls.shape[0] == window_size
        else:
            controls = None

        init = torch.randn(batch_size, init_dim).to(device)
        outputs = model_artist.generate(init, window_size, events=events[:-1], controls=controls,
                                 teacher_forcing_ratio=teacher_forcing_ratio, output_type='logit')
        assert outputs.shape[:2] == events.shape[:2]

        #2nd generation

        events2 = torch.LongTensor(events2).to(device)
        assert events2.shape[0] == window_size

        if np.random.random() < control_ratio:
            controls2 = torch.FloatTensor(controls2).to(device)
            assert controls2.shape[0] == window_size
        else:
            controls = None

        init2 = torch.randn(batch_size, init_dim).to(device)
        outputs2 = model_AI.generate(init, window_size, events=events[:-1], controls=controls,
                                 teacher_forcing_ratio=teacher_forcing_ratio, output_type='logit')
        assert outputs.shape[:2] == events.shape[:2]

        loss = loss_function(outputs.view(-1, event_dim), events2.view(-1))
        model_artist.zero_grad()
        loss.backward()
        #print(f"loss 1 = {loss}")

        loss2 = loss_function(outputs2.view(-1, event_dim), events2.view(-1))
        model_AI.zero_grad()
        loss2.backward()
        #print(f"loss 2 = {loss2}")
        norm = utils.compute_gradient_norm(model_artist.parameters())
        nn.utils.clip_grad_norm_(model_artist.parameters(), 1.0)

        norm2 = utils.compute_gradient_norm(model_AI.parameters())
        nn.utils.clip_grad_norm_(model_AI.parameters(), 1.0)
        
        optimizer1.step()
        optimizer2.step()
        
        writer.add_scalar('model_artist/loss', loss.item(), iteration)
        writer.add_scalar('model_artist/norm', norm.item(), iteration)
        writer.add_scalar('model_AI/loss', loss2.item(), iteration)
        writer.add_scalar('model_AI/norm', norm2.item(), iteration)

        print(f'iter {iteration}, loss: {loss.item()}, loss2:{loss2.item()}')

except KeyboardInterrupt:
    save1 = './save/Artist.sess'
    save2 = './save/AI.sess'

    save_model(model_artist, optimizer1, model_artist.state_dict(),save1)
    save_model(model_AI, optimizer1, model_artist.state_dict(),save2)
    print("ciao")

iter 0, loss: 5.4751667976379395, loss2:5.475212574005127
iter 1, loss: 5.1125078201293945, loss2:5.112485408782959
iter 2, loss: 6.118977069854736, loss2:6.119515419006348
iter 3, loss: 5.416440963745117, loss2:5.416049957275391
iter 4, loss: 5.041950702667236, loss2:5.041746139526367
iter 5, loss: 4.767392158508301, loss2:4.767796516418457
iter 6, loss: 4.612310886383057, loss2:4.612529754638672
iter 7, loss: 4.665270805358887, loss2:4.665220260620117
iter 8, loss: 4.540961742401123, loss2:4.540817737579346
iter 9, loss: 4.59427547454834, loss2:4.594080924987793
iter 10, loss: 4.486323356628418, loss2:4.486688613891602
iter 11, loss: 4.523890495300293, loss2:4.523595809936523
iter 12, loss: 4.38703727722168, loss2:4.3872761726379395
iter 13, loss: 4.3582916259765625, loss2:4.357961177825928
iter 14, loss: 4.221174716949463, loss2:4.221160888671875
iter 15, loss: 4.191904067993164, loss2:4.191351890563965
iter 16, loss: 4.131591320037842, loss2:4.1320414543151855
iter 17, loss: 4.1042

In [17]:
batch_size_gen = 8
mAI_path = 'save/AI.sess'
mArtist_path = 'save/Artist.sess'
output = 'output/'
max_len = 2000
greedy_ratio = 1.0
controls = None
control = 'NONE'

In [1]:
test = 'save/low_data_train.sess'

In [19]:
mTest_model = torch.load(test, map_location=device)

In [20]:
mTest_model['model_config']

{'init_dim': 32,
 'event_dim': 240,
 'control_dim': 24,
 'hidden_dim': 512,
 'gru_layers': 3,
 'gru_dropout': 0.3}

In [21]:
mAI_state = torch.load(mAI_path, map_location=device)
mAI = PerformanceRNN(**mTest_model['model_config']).to(device)
mArtist_state = torch.load(mArtist_path, map_location=device)
mArtist = PerformanceRNN(**mTest_model['model_config']).to(device)

In [22]:
mAI.load_state_dict(mAI_state['model_state'])
mAI.eval()
print(mAI)
print('-' * 70)

PerformanceRNN(
  (inithid_fc): Linear(in_features=32, out_features=1536, bias=True)
  (inithid_fc_activation): Tanh()
  (event_embedding): Embedding(240, 240)
  (concat_input_fc): Linear(in_features=265, out_features=512, bias=True)
  (concat_input_fc_activation): LeakyReLU(negative_slope=0.1, inplace=True)
  (gru): GRU(512, 512, num_layers=3, dropout=0.3)
  (output_fc): Linear(in_features=1536, out_features=240, bias=True)
  (output_fc_activation): Softmax(dim=-1)
)
----------------------------------------------------------------------


In [23]:
mArtist.load_state_dict(mArtist_state['model_state'])
mArtist.eval()
print(mArtist)
print('-' * 70)

PerformanceRNN(
  (inithid_fc): Linear(in_features=32, out_features=1536, bias=True)
  (inithid_fc_activation): Tanh()
  (event_embedding): Embedding(240, 240)
  (concat_input_fc): Linear(in_features=265, out_features=512, bias=True)
  (concat_input_fc_activation): LeakyReLU(negative_slope=0.1, inplace=True)
  (gru): GRU(512, 512, num_layers=3, dropout=0.3)
  (output_fc): Linear(in_features=1536, out_features=240, bias=True)
  (output_fc_activation): Softmax(dim=-1)
)
----------------------------------------------------------------------


In [27]:
def load_dataset2():
    dataset = Dataset('./Line-RAW', verbose=True)
    dataset_size = len(dataset.samples)
    assert dataset_size > 0
    return dataset

In [29]:
dataset = load_dataset2()

In [30]:
dataset

Dataset(root="./Line-RAW", samples=1, avglen=14385.0)

In [53]:
def get_primary_event(self, batch_size):
        return torch.LongTensor([[self.primary_event] * batch_size]).to(device)

def init_to_hidden(self, init):
        # [batch_size, init_dim]
        batch_size = init.shape[0]
        out = self.inithid_fc(init)
        out = self.inithid_fc_activation(out)
        out = out.view(self.gru_layers, batch_size, self.hidden_dim)
        return out

In [54]:
def compose(init, steps, events=None, controls=None, greedy=1.0,
                 temperature=1.0, teacher_forcing_ratio=1.0, output_type='index', verbose=True):
    
    batch_size = init.shape[0]
    assert init.shape[1] == init_dim
    assert steps > 0 #max_len

    event = get_primary_event(batch_size)
    hidden = init_to_hidden(init)

    outputs = []
    step_iter = range(steps)
    if verbose:
        step_iter = Bar('Generating').iter(step_iter)

    for step in step_iter:
        control = controls[step].unsqueeze(0) if use_control else None
        output, hidden = self.forward(event, control, hidden)

        use_greedy = np.random.random() < greedy
        event = self._sample_event(output, greedy=use_greedy,
                                    temperature=temperature)

        if output_type == 'index':
            outputs.append(event)
        elif output_type == 'softmax':
            outputs.append(self.output_fc_activation(output))
        elif output_type == 'logit':
            outputs.append(output)
        else:
            assert False
    
    return torch.cat(outputs, 0)


In [55]:
batch_gen = dataset.batches(batch_size, window_size, stride_size)
a = 0
with torch.no_grad():
    for iteration, (events, controls, events2, controls2) in enumerate(batch_gen):
        print(events[-1])
        print(controls.shape)
        print('-'*70)
        a += 1
        if (a == 3):
            break

[147 142  48  41 225 143  59 197 197  51 138  55  41 197  59 197 143 142
 152 127  55 220 220  38 197 127 197  58 197  42 197 197  31 220 124 197
 215  29 197 197 197 220 110 219  61  60  50 197 230 197  50 220  10  50
 220 197 197 197  61 140  55  33 210 155]
(200, 64, 24)
----------------------------------------------------------------------
[ 15 220   9 220 197 197 220 197 197  59  52 197  40 230 197 210 220  35
 140  45 197 197 210 197 197 142 220 197 197  59 138 197 220  62 140 197
 137 197 123 197 197 146  57 197 220 147 142 138  49  47  39 140  59 220
 126 220 197 222 131 197  52  55  46  47]
(200, 64, 24)
----------------------------------------------------------------------
[210  47 150  57  40 135 197 197 197  59 152 220 197 126 197 143 122  49
 220 220 215 197 197 210 197  62 197  54 197  54 197 220 220 220 143  55
 197 197 220 220 197 128 115 197 197  60 130 197 140 197 220  62 217 137
 197  47 150 197 225 144 197 197  45 222]
(200, 64, 24)
---------------------------------

In [40]:
batch_gen = dataset.batches(batch_size, window_size, stride_size)
with torch.no_grad():
    for iteration, (events, controls, events2, controls2) in enumerate(batch_gen):

        #1st generation

        events = torch.LongTensor(events).to(device)
        assert events.shape[0] == window_size

        if np.random.random() < control_ratio:
            controls = torch.FloatTensor(controls).to(device)
            assert controls.shape[0] == window_size
        else:
            controls = None

        init = torch.randn(batch_size, init_dim).to(device)
        outputs = model_artist.generate(init, window_size, events=events[:-1], controls=controls,
                                    teacher_forcing_ratio=teacher_forcing_ratio, output_type='logit')
        assert outputs.shape[:2] == events.shape[:2]

        #2nd generation

        events2 = torch.LongTensor(events2).to(device)
        assert events2.shape[0] == window_size

        if np.random.random() < control_ratio:
            controls2 = torch.FloatTensor(controls2).to(device)
            assert controls2.shape[0] == window_size
        else:
            controls = None

        init2 = torch.randn(batch_size, init_dim).to(device)
        outputs2 = model_AI.generate(init, window_size, events=events[:-1], controls=controls,
                                    teacher_forcing_ratio=teacher_forcing_ratio, output_type='logit')
        assert outputs2.shape[:2] == events.shape[:2]

        print(outputs)
        print('-'*70)
        print(outputs.shape)
        print('-'*70)
        print(events)
        print('-'*70)
        print(events.shape)
        print('-'*70)
        outputs1 = outputs.cpu().numpy().T 
        print('-'*70)
        outputs2 = outputs2.cpu().numpy().T 

        break

tensor([[[ 0.0168,  0.1439,  0.1157,  ...,  0.0753, -0.0046, -0.1751],
         [-0.1868, -0.1477,  0.1475,  ...,  0.1131,  0.0596,  0.1878],
         [-0.0701,  0.1899,  0.0099,  ...,  0.0040, -0.0280, -0.3033],
         ...,
         [-0.1647, -0.1041,  0.2173,  ...,  0.1843, -0.1301,  0.0569],
         [ 0.0723,  0.0560,  0.1791,  ...,  0.0599, -0.2430,  0.1499],
         [-0.2988,  0.0029,  0.1291,  ..., -0.0918,  0.1187,  0.1205]],

        [[-0.0377,  0.0994,  0.0782,  ...,  0.0400, -0.0678, -0.0984],
         [-0.1539, -0.0200,  0.1461,  ...,  0.0798, -0.0147,  0.1074],
         [-0.0821,  0.0928,  0.1158,  ...,  0.0106, -0.0716, -0.1778],
         ...,
         [-0.1606, -0.0632,  0.2021,  ...,  0.1169, -0.0907, -0.0140],
         [ 0.0475,  0.0626,  0.1557,  ...,  0.0476, -0.2077,  0.1325],
         [-0.2209, -0.0143,  0.1155,  ..., -0.0210, -0.0186,  0.0668]],

        [[-0.0554,  0.1093,  0.0833,  ...,  0.0146, -0.0959, -0.0410],
         [-0.1186,  0.0182,  0.1063,  ...,  0