# Imports

In [190]:
import os
import re
import sys
import torch
import hashlib
import itertools
import logging
import numpy as np

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')


from progress.bar import Bar
from concurrent.futures import ProcessPoolExecutor

from Utils.NotesSeq import NoteSeq as ns
from Utils.EventSeq import EventSeq as es
from Utils.ControlSeq import ControlSeq as cs
from Utils import utils

import warnings
warnings.filterwarnings("ignore")

In [191]:
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, Gumbel
from torch import optim

from config import device

In [192]:
from pretty_midi import PrettyMIDI, Note, Instrument

# options

In [193]:
sess_path = 'save/train.sess'
data_path = 'Processed-RAW'
saving_interval = 60.
reset_optimizer = False
enable_logging = False

# Dataset

In [194]:
class Dataset:
    def __init__(self, root, verbose=False):
        assert os.path.isdir(root), root
        paths = utils.find_files_by_extensions(root, ['.data'])

        self.root = root
        self.samples = []
        self.seqlens = []
        self.samples2 = []
        self.seqlens2 = []

        if verbose:
            paths = Bar(root).iter(list(paths))
        for path in paths:
            eventseq, eventseq2, controlseq, controlseq2 = torch.load(path)
            controlseq = cs.recover_compressed_array(controlseq)
            controlseq2 = cs.recover_compressed_array(controlseq2)
            assert len(eventseq) == len(controlseq)
            assert len(eventseq2) == len(controlseq2)
            self.samples.append((eventseq, controlseq))
            self.seqlens.append(len(eventseq))
            self.samples2.append((eventseq2,controlseq2))
            self.seqlens2.append(len(eventseq2))

        self.avglen = np.mean(self.seqlens)
        self.avglen2 = np.mean(self.seqlens2)
    
    def batches(self, batch_size, window_size, stride_size):
        indeces = [(i, range(j, j + window_size))
                   for i, seqlen in enumerate(self.seqlens)
                   for j in range(0, seqlen - window_size, stride_size)]
        while True:
            eventseq_batch = []
            controlseq_batch = []
            eventseq_batch2 = []
            controlseq_batch2 = []
            n = 0
            for ii in np.random.permutation(len(indeces)):
                i, r = indeces[ii]

                eventseq, controlseq = self.samples[i]
                eventseq2, controlseq2 = self.samples2[i]

                eventseq = eventseq[r.start:r.stop]
                eventseq2 = eventseq2[r.start:r.stop]

                controlseq = controlseq[r.start:r.stop]
                controlseq2 = controlseq2[r.start:r.stop]

                eventseq_batch.append(eventseq)
                controlseq_batch.append(controlseq)
                eventseq_batch2.append(eventseq2)
                controlseq_batch2.append(controlseq2)

                n += 1
                if n == batch_size:
                    yield (np.stack(eventseq_batch, axis=1),
                           np.stack(controlseq_batch, axis=1),
                           np.stack(eventseq_batch, axis=1),
                           np.stack(controlseq_batch, axis=1))
                    eventseq_batch.clear()
                    controlseq_batch.clear()
                    eventseq_batch2.clear()
                    controlseq_batch2.clear()
                    n = 0
    
    def __repr__(self):
        return (f'Dataset(root="{self.root}", '
                f'samples={len(self.samples)}, '
                f'avglen={self.avglen})')


In [195]:
dataset = Dataset(data_path, verbose=True)

In [196]:
dataset_size = len(dataset.samples)
assert dataset_size > 0

In [197]:
class P2X(nn.Module):

    def _initialize_weights(self):
        nn.init.xavier_normal_(self.event_embedding.weight)
        nn.init.xavier_normal_(self.inithid_fc.weight)
        self.inithid_fc.bias.data.fill_(0.)
        nn.init.xavier_normal_(self.concat_input_fc.weight)
        nn.init.xavier_normal_(self.output_fc.weight)
        self.output_fc.bias.data.fill_(0.)

    def __init__(self, event_dim, control_dim, init_dim, hidden_dim,
                 inithid_fc = None, gru_layers=3, gru_dropout=0.3):
        super().__init__()

        # Parameters initialization
        self.event_dim = event_dim
        self.control_dim = control_dim
        self.init_dim = init_dim
        self.hidden_dim = hidden_dim
        self.gru_layers = gru_layers
        self.concat_dim = event_dim + 1 + control_dim
        self.input_dim = hidden_dim
        self.output_dim = event_dim

        self.primary_event = self.event_dim - 1

        #Model Layers

        self.inithid_fc = nn.Linear(init_dim, gru_layers * hidden_dim)
        self.inithid_fc_activation = nn.Tanh()

        self.event_embedding = nn.Embedding(event_dim, event_dim)
        self.concat_input_fc = nn.Linear(self.concat_dim, self.input_dim)
        self.concat_input_fc_activation = nn.LeakyReLU(0.1, inplace=True)

        self.gru = nn.GRU(self.input_dim, self.hidden_dim,
                          num_layers=gru_layers, dropout=gru_dropout)
        self.output_fc = nn.Linear(hidden_dim * gru_layers, self.output_dim)
        self.output_fc_activation = nn.Softmax(dim=-1)

        self._initialize_weights()

    def forward(self, event, control=None, hidden=None):
        # One step forward

        assert len(event.shape) == 2
        assert event.shape[0] == 1
        batch_size = event.shape[1]
        event = self.event_embedding(event)

        if control is None:
            default = torch.ones(1, batch_size, 1).to(device)
            control = torch.zeros(1, batch_size, self.control_dim).to(device)
        else:
            default = torch.zeros(1, batch_size, 1).to(device)
            assert control.shape == (1, batch_size, self.control_dim)

        concat = torch.cat([event, default, control], -1)
        input = self.concat_input_fc(concat)  #nn.Linear(self.concat_dim, self.input_dim)
        input = self.concat_input_fc_activation(input)  #nn.LeakyReLU(0.1, inplace=True)

        _, hidden = self.gru(input, hidden)  #nn.GRU(self.input_dim, self.hidden_dim,num_layers=gru_layers, dropout=gru_dropout)

        output = hidden.permute(1, 0, 2).contiguous()
        output = output.view(batch_size, -1).unsqueeze(0)
        output = self.output_fc(output) #nn.Linear(hidden_dim * gru_layers, self.output_dim)
        return output, hidden # output is under the form of a logit
    
    def _sample_event(self, output, greedy=True, temperature=1.0):
        if greedy:
            return output.argmax(-1)
        else:
            output = output / temperature
            probs = self.output_fc_activation(output)
            return Categorical(probs).sample()
    
    def get_primary_event(self, batch_size):
        return torch.LongTensor([[self.primary_event] * batch_size]).to(device)

    def init_to_hidden(self, init):
        # [batch_size, init_dim]
        batch_size = init.shape[0]
        out = self.inithid_fc(init)
        out = self.inithid_fc_activation(out)
        out = out.view(self.gru_layers, batch_size, self.hidden_dim)
        return out
    
    def expand_controls(self, controls, steps):
        # [1 or steps, batch_size, control_dim]
        assert len(controls.shape) == 3
        assert controls.shape[2] == self.control_dim
        if controls.shape[0] > 1:
            assert controls.shape[0] >= steps
            return controls[:steps]
        return controls.repeat(steps, 1, 1)
    
    def generate(self, init, batch_size, init_dim, steps, events = None, controls = None,
                 verbose = True, greedy = 1.0, temperature = 1.0):
        batch_size = batch_size
        self.init_dim = init_dim

        assert init.shape[1] == self.init_dim
        assert steps > 0

        use_teacher_forcing = events is not None

        if use_teacher_forcing:
            assert len(events.shape) == 2
            assert events.shape[0] >= steps - 1
            events = events[:steps-1]

        event = self.get_primary_event(batch_size)

        use_control = controls is not None

        if use_control:
            controls = self.expand_controls(controls, steps)
        hidden = self.init_to_hidden(init)

        outputs = []
        step_iter = range(steps)

        if verbose:
            step_iter = Bar('Generating').iter(step_iter)

        for step in step_iter:
            #control = controls[step].unsqueeze(0) if use_control else None
            control = None
            output, hidden = self.forward(event, control, hidden)

            use_greedy = np.random.random() < greedy
            event = self._sample_event(output, greedy=use_greedy,
                                       temperature=temperature)
            
            #here outputs are served in the lo-git format
            outputs.append(output)
            #
            #if use_teacher_forcing and step < steps - 1: # avoid last one
            #    if np.random.random() <= teacher_forcing_ratio:
            #        event = events[step].unsqueeze(0)
        
        return torch.cat(outputs, 0)

In [198]:
class P2XSecondary(nn.Module):

    def _initialize_weights(self):
        nn.init.xavier_normal_(self.event_embedding.weight)
        nn.init.xavier_normal_(self.inithid_fc.weight)
        self.inithid_fc.bias.data.fill_(0.)
        nn.init.xavier_normal_(self.concat_input_fc.weight)
        nn.init.xavier_normal_(self.output_fc.weight)
        self.output_fc.bias.data.fill_(0.)

    def __init__(self, event_dim, control_dim, init_dim, hidden_dim,
                 inithid_fc = None, gru_layers=3, gru_dropout=0.3):
        super().__init__()

        # Parameters initialization
        self.event_dim = event_dim
        self.control_dim = control_dim
        self.init_dim = init_dim
        self.hidden_dim = hidden_dim
        self.gru_layers = gru_layers
        self.concat_dim = event_dim + event_dim + 1 + control_dim
        self.input_dim = hidden_dim
        self.output_dim = event_dim

        self.primary_event = self.event_dim - 1

        #Model Layers

        self.inithid_fc = nn.Linear(init_dim, gru_layers * hidden_dim)
        self.inithid_fc_activation = nn.Tanh()

        self.event_embedding = nn.Embedding(event_dim, event_dim)
        self.concat_input_fc = nn.Linear(self.concat_dim, self.input_dim)
        self.concat_input_fc_activation = nn.LeakyReLU(0.1, inplace=True)

        self.gru = nn.GRU(self.input_dim, self.hidden_dim,
                          num_layers=gru_layers, dropout=gru_dropout)
        self.output_fc = nn.Linear(hidden_dim * gru_layers, self.output_dim)
        self.output_fc_activation = nn.Softmax(dim=-1)

        self._initialize_weights()

    def forward(self, event, event2, control=None, hidden=None):
        # One step forward

        assert len(event.shape) == 2
        assert event.shape[0] == 1
        batch_size = event.shape[1]
        event = self.event_embedding(event)

        #print(event2.shape)
        if(len(event2.shape) <= 2):
            event2 = torch.tensor(event2).to(device).long()
            event2 = self.event_embedding(event2)

        """
        print(event2.shape)
        print(event.shape)
        print(len(event.shape))
        """

        if control is None:
            default = torch.ones(1, batch_size, 1).to(device)
            control = torch.zeros(1, batch_size, self.control_dim).to(device)
        else:
            default = torch.zeros(1, batch_size, 1).to(device)
            assert control.shape == (1, batch_size, self.control_dim)

        event = torch.cat([event,event2],-1)
        concat = torch.cat([event, default, control], -1)
        input = self.concat_input_fc(concat)  #nn.Linear(self.concat_dim, self.input_dim)
        input = self.concat_input_fc_activation(input)  #nn.LeakyReLU(0.1, inplace=True)

        _, hidden = self.gru(input, hidden)  #nn.GRU(self.input_dim, self.hidden_dim,num_layers=gru_layers, dropout=gru_dropout)

        output = hidden.permute(1, 0, 2).contiguous()
        output = output.view(batch_size, -1).unsqueeze(0)
        output = self.output_fc(output) #nn.Linear(hidden_dim * gru_layers, self.output_dim)

        #print("prima passata")

        return output, hidden # output is under the form of a logit
    
    def _sample_event(self, output, greedy=True, temperature=1.0):
        if greedy:
            return output.argmax(-1)
        else:
            output = output / temperature
            probs = self.output_fc_activation(output)
            return Categorical(probs).sample()
    
    def get_primary_event(self, batch_size):
        return torch.LongTensor([[self.primary_event] * batch_size]).to(device)

    def init_to_hidden(self, init):
        # [batch_size, init_dim]
        batch_size = init.shape[0]
        out = self.inithid_fc(init)
        out = self.inithid_fc_activation(out)
        out = out.view(self.gru_layers, batch_size, self.hidden_dim)
        return out
    
    def expand_controls(self, controls, steps):
        # [1 or steps, batch_size, control_dim]
        assert len(controls.shape) == 3
        assert controls.shape[2] == self.control_dim
        if controls.shape[0] > 1:
            assert controls.shape[0] >= steps
            return controls[:steps]
        return controls.repeat(steps, 1, 1)
    
    def generate(self, init, batch_size, init_dim, steps, events = None, controls = None, events2 = None,
                 verbose = True, greedy = 1.0, temperature = 1.0, teacher_forcing_ratio = 1.0):
        batch_size = batch_size
        self.init_dim = init_dim

        assert init.shape[1] == self.init_dim
        assert steps > 0

        use_teacher_forcing = events is not None

        if use_teacher_forcing:
            assert len(events.shape) == 2
            assert events.shape[0] >= steps - 1
            events = events[:steps-1]

        event = self.get_primary_event(batch_size)
        event2 = self.get_primary_event(batch_size)

        use_control = controls is not None

        if use_control:
            controls = self.expand_controls(controls, steps)
        hidden = self.init_to_hidden(init)

        outputs = []
        step_iter = range(steps)

        if verbose:
            step_iter = Bar('Generating').iter(step_iter)

        for step in step_iter:
            #control = controls[step].unsqueeze(0) if use_control else None
            control = None
            output, _ = self.forward(event, event2, control, hidden)

            use_greedy = np.random.random() < greedy
            event = self._sample_event(output, greedy=use_greedy,
                                       temperature=temperature)
            
            #here outputs are served in the logit format
            outputs.append(output)
            #
            if use_teacher_forcing and step < steps - 1: # avoid last one
                if np.random.random() <= teacher_forcing_ratio:
                    event = events[step].unsqueeze(0)
                    event2 = events2[step].unsqueeze(0)
        
        outputs = torch.cat(outputs, 0)

        return outputs
        


# Training

In [199]:
# Variables for training process

init_dim = 32
event_dim = es.dim()
control_dim = cs.dim()
hidden_dim = 512
gru_layers = 3
gru_droput = 0.3

In [200]:
learning_rate = 0.001
batch_size = 64
window_size = 200
stride_size = 10
use_transposition = False
control_ratio = 1.0
teacher_forcing_ratio = 1.0

In [201]:
model_config = {
    'init_dim': init_dim,
    'event_dim': event_dim,
    'control_dim': control_dim,
    'hidden_dim': hidden_dim,
    'gru_layers': gru_layers,
    'gru_dropout': gru_droput,
}

In [202]:
init = torch.randn(batch_size, init_dim).to(device)

In [203]:
model = P2X(**model_config).to(device)
model2 = P2XSecondary(**model_config).to(device)

In [204]:
params = list(model.parameters()) + list(model.parameters())
paramsB = list(model.parameters()) + list(model2.parameters())

optimizer = optim.Adam(paramsB, lr=learning_rate)

In [205]:
loss_function = nn.CrossEntropyLoss()

In [206]:

#try:
batch_gen = dataset.batches(batch_size, window_size, stride_size)
for iteration, data in enumerate(batch_gen):
    e1 = data[0]
    c1 = data[1]
    e2 = data[2]
    c2 = data[3]

    # First Model process

    events = torch.LongTensor(e1).to(device)
    assert events.shape[0] == window_size
    assert len(events.shape) == 2
    assert events.shape[0] >= window_size - 1

    if np.random.random() < control_ratio:
        controls = torch.FloatTensor(c1).to(device)
        assert controls.shape[0] == window_size
    else:
        controls = None

    init = torch.randn(batch_size, model.init_dim).to(device)
    outputs = model.generate(init,batch_size, model.init_dim, window_size, 
                            events[:-1], controls)
    
    assert outputs.shape[:2] == events.shape[:2]
    loss = loss_function(outputs.view(-1, event_dim), events.view(-1))
    #print(outputs)

    # Second Model process

    eventsB = torch.LongTensor(e2).to(device)
    assert eventsB.shape[0] == window_size
    assert len(eventsB.shape) == 2
    assert eventsB.shape[0] >= window_size - 1

    if np.random.random() < control_ratio:
        controls = torch.FloatTensor(c2).to(device)
        assert controls.shape[0] == window_size
    else:
        controls = None

    init = torch.randn(batch_size, model2.init_dim).to(device)
    outputsB = model2.generate(init,batch_size, model2.init_dim, window_size, 
                            events=eventsB[:-1], events2=outputs[:-1], controls=controls)
    
    assert outputsB.shape[:2] == eventsB.shape[:2]
    #eventsF = torch.cat((events,eventsB))
    #outputsF = torch.cat((outputs, outputsB),2)
    #print(outputsF)

    loss = loss + loss_function(outputsB.view(-1, event_dim), eventsB.view(-1))

    model.zero_grad()
    model2.zero_grad()
    loss.backward()

    #concatenated_vectors = torch.cat((outputs, outputsB),2)
    #print(outputsF.shape)
    #print(outputsB.shape)
    #print(events.shape)

    norm = utils.compute_gradient_norm(model.parameters())
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    
    optimizer.step()

    print(f'iter {iteration}, loss: {loss.item()}')

    

#except:
#    print("banane")


iter 0, loss: 10.96728515625
iter 1, loss: 10.79328727722168
iter 2, loss: 10.409486770629883
iter 3, loss: 11.087179183959961
iter 4, loss: 9.985830307006836
iter 5, loss: 9.799489974975586
