<a href="https://colab.research.google.com/github/EmanueleCosenza/Polyphemus/blob/main/midi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pwd

In [None]:
!git branch

Libraries installation

In [None]:
#!tar -C data -xvzf data/lmd_matched.tar.gz

In [None]:
# Install the required music libraries
#!pip3 install muspy
#!pip3 install pypianoroll

In [None]:
# Install torch_geometric
#!v=$(python3 -c "import torch; print(torch.__version__)"); \
#pip3 install torch-scatter -f https://data.pyg.org/whl/torch-${v}.html; \
#pip3 install torch-sparse -f https://data.pyg.org/whl/torch-${v}.html; \
#pip3 install torch-geometric

Libraries and reproducibility

In [1]:
import numpy as np
import torch
import random
import os
import muspy
from itertools import product
import pypianoroll as pproll
import time
from tqdm.auto import tqdm

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

seed = 42
set_seed(seed)

In [2]:
class MIDIPreprocessor():
    
    def __init__():
        pass

    def preprocess_dataset(self, dir, early_exit=None):
        pass
    
    def preprocess_file(self, f):
        pass


# Todo: to config file (or separate files)
MAX_SIMU_NOTES = 16 # 14 + SOS and EOS

PITCH_SOS = 128
PITCH_EOS = 129
PITCH_PAD = 130
DUR_SOS = 96
DUR_EOS = 97
DUR_PAD = 98

MAX_DUR = 96

# Number of time steps per quarter note
# To get bar resolution -> RESOLUTION*4
RESOLUTION = 8 
NUM_BARS = 2


def preprocess_file(filepath, dest_dir, num_samples):

    saved_samples = 0

    print("Preprocessing file " + filepath)

    # Load the file both as a pypianoroll song and a muspy song
    # (Need to load both since muspy.to_pypianoroll() is expensive)
    try:
        pproll_song = pproll.read(filepath, resolution=RESOLUTION)
        muspy_song = muspy.read(filepath)
    except Exception as e:
        print("Song skipped (Invalid song format)")
        return 0
    
    # Only accept songs that have a time signature of 4/4 and no time changes
    for t in muspy_song.time_signatures:
        if t.numerator != 4 or t.denominator != 4:
            print("Song skipped ({}/{} time signature)".
                            format(t.numerator, t.denominator))
            return 0

    # Gather tracks of pypianoroll song based on MIDI program number
    drum_tracks = []
    bass_tracks = []
    guitar_tracks = []
    strings_tracks = []

    for track in pproll_song.tracks:
        if track.is_drum:
            #continue
            track.name = 'Drums'
            drum_tracks.append(track)
        elif 0 <= track.program <= 31:
            track.name = 'Guitar'
            guitar_tracks.append(track)
        elif 32 <= track.program <= 39:
            track.name = 'Bass'
            bass_tracks.append(track)
        else:
            # Tracks with program > 39 are all considered as strings tracks
            # and will be merged into a single track later on
            strings_tracks.append(track)

    # Filter song if it does not contain drum, guitar, bass or strings tracks
    #if not guitar_tracks \
    if not drum_tracks or not guitar_tracks \
            or not bass_tracks or not strings_tracks:
        print("Song skipped (does not contain drum or "
                "guitar or bass or strings tracks)")
        return 0
    
    # Merge strings tracks into a single pypianoroll track
    strings = pproll.Multitrack(tracks=strings_tracks)
    strings_track = pproll.Track(pianoroll=strings.blend(mode='max'),
                                 program=48, name='Strings')

    combinations = list(product(drum_tracks, bass_tracks, guitar_tracks))
    #combinations = list(product(bass_tracks, guitar_tracks))

    # Single instruments can have multiple tracks.
    # Consider all possible combinations of drum, bass, and guitar tracks
    for i, combination in enumerate(combinations):

        print("Processing combination", i+1, "of", len(combinations))
        
        # Process combination (called 'subsong' from now on)
        drum_track, bass_track, guitar_track = combination
        tracks = [drum_track, bass_track, guitar_track, strings_track]
        #bass_track, guitar_track = combination
        #tracks = [bass_track, guitar_track, strings_track]
        
        pproll_subsong = pproll.Multitrack(
            tracks=tracks,
            tempo=pproll_song.tempo,
            resolution=RESOLUTION
        )
        muspy_subsong = muspy.from_pypianoroll(pproll_subsong)
        
        tracks_notes = [track.notes for track in muspy_subsong.tracks]
        
        # Obtain length of subsong (maximum of each track's length)
        length = 0
        for notes in tracks_notes:
            track_length = max(note.end for note in notes) if notes else 0
            length = max(length, track_length)
        length += 1

        # Add timesteps until length is a multiple of RESOLUTION
        length = length if length%(RESOLUTION*4) == 0 \
                            else length + (RESOLUTION*4-(length%(RESOLUTION*4)))


        tracks_tensors = []
        tracks_activations = []

        # Todo: adapt to velocity
        for notes in tracks_notes:

            # Initialize encoder-ready track tensor
            # track_tensor: (length x max_simu_notes x 2 (or 3 if velocity))
            # The last dimension contains pitches and durations (and velocities)
            # int16 is enough for small to medium duration values
            track_tensor = np.zeros((length, MAX_SIMU_NOTES, 2), np.int16)

            track_tensor[:, :, 0] = PITCH_PAD
            track_tensor[:, 0, 0] = PITCH_SOS
            track_tensor[:, :, 1] = DUR_PAD
            track_tensor[:, 0, 1] = DUR_SOS

            # Keeps track of how many notes have been stored in each timestep
            # (int8 imposes that MAX_SIMU_NOTES < 256)
            notes_counter = np.ones(length, dtype=np.int8)

            # Todo: np.put_along_axis?
            for note in notes:
                # Insert note in the lowest position available in the timestep
                
                t = note.time
                
                if notes_counter[t] >= MAX_SIMU_NOTES-1:
                    # Skip note if there is no more space
                    continue
                
                pitch = max(min(note.pitch, 127), 0)
                track_tensor[t, notes_counter[t], 0] = pitch
                dur = max(min(MAX_DUR, note.duration), 1)
                track_tensor[t, notes_counter[t], 1] = dur-1
                notes_counter[t] += 1
            
            # Add end of sequence token
            track_tensor[np.arange(0, length), notes_counter, 0] = PITCH_EOS
            track_tensor[np.arange(0, length), notes_counter, 1] = DUR_EOS
            
            # Get track activations, a boolean tensor indicating whether notes
            # are being played in a timestep (sustain does not count)
            # (needed for graph rep.)
            activations = np.array(notes_counter-1, dtype=bool)
            
            tracks_tensors.append(track_tensor)
            tracks_activations.append(activations)
        
        # (#tracks x length x max_simu_notes x 2 (or 3))
        subsong_tensor = np.stack(tracks_tensors, axis=0)

        # (#tracks x length)
        subsong_activations = np.stack(tracks_activations, axis=0)


        # Slide window over 'subsong_tensor' and 'subsong_activations' along the
        # time axis (2nd dimension) with the stride of a bar
        # Todo: np.lib.stride_tricks.as_strided(song_proll)
        for i in range(0, length-NUM_BARS*RESOLUTION*4+1, RESOLUTION*4):
            
            # Get the sequence and its activations
            seq_tensor = subsong_tensor[:, i:i+NUM_BARS*RESOLUTION*4, :]
            seq_acts = subsong_activations[:, i:i+NUM_BARS*RESOLUTION*4]
            seq_tensor = np.copy(seq_tensor)
            seq_acts = np.copy(seq_acts)

            if NUM_BARS > 1:
                # Skip sequence if it contains more than one bar of consecutive
                # silence in at least one track
                bars = seq_acts.reshape(seq_acts.shape[0], NUM_BARS, -1)
                bars_acts = np.any(bars, axis=2)

                if 1 in np.diff(np.where(bars_acts == 0)[1]):
                    continue
                    
                # Skip sequence if it contains one bar of complete silence
                # (in terms of note activations)
                silences = np.logical_not(np.any(bars_acts, axis=0))
                if np.any(silences):
                    continue
                
            else:
                # In the case of just 1 bar, skip it if all tracks are silenced
                bar_acts = np.any(seq_acts, axis=1)
                if not np.any(bar_acts):
                    continue
            
            # Randomly transpose the pitches of the sequence (-5 to 6 semitones)
            # Not considering pad, sos, eos tokens
            # Not transposing drums/percussions
            shift = np.random.choice(np.arange(-5, 7), 1)
            cond = (seq_tensor[1:, :, :, 0] != PITCH_PAD) &                     \
                   (seq_tensor[1:, :, :, 0] != PITCH_SOS) &                     \
                   (seq_tensor[1:, :, :, 0] != PITCH_EOS)
            #cond = (seq_tensor[:, :, :, 0] != PITCH_PAD) &                     \
            #       (seq_tensor[:, :, :, 0] != PITCH_SOS) &                     \
            #       (seq_tensor[:, :, :, 0] != PITCH_EOS)
            non_perc = seq_tensor[1:, ...]
            #non_perc = seq_tensor
            non_perc[cond, 0] += shift
            non_perc[cond, 0] = np.clip(non_perc[cond, 0], a_min=0, a_max=127)

            # Save sample (seq_tensor and seq_acts) to file
            curr_sample = str(num_samples + saved_samples)
            sample_filepath = os.path.join(dest_dir, curr_sample)
            np.savez(sample_filepath, seq_tensor=seq_tensor, seq_acts=seq_acts)

            saved_samples += 1


    print("File preprocessing finished. Saved samples:", saved_samples)
    print()

    return saved_samples



# Total number of files: 116189
# Number of unique files: 45129
def preprocess_dataset(dataset_dir, dest_dir, num_files=45129, early_exit=None):

    files_dict = {}
    seen = 0
    tot_samples = 0
    not_filtered = 0
    finished = False
    
    print("Starting preprocessing")
    
    progress_bar = tqdm(range(early_exit)) if early_exit is not None else tqdm(range(num_files))
    start = time.time()

    # Visit recursively the directories inside the dataset directory
    for dirpath, dirs, files in os.walk(dataset_dir):

        # Sort alphabetically the found directories
        # (to help guess the remaining time) 
        dirs.sort()
        
        print("Current path:", dirpath)
        print()

        for f in files:
            
            seen += 1

            if f in files_dict:
                # Skip already seen file
                files_dict[f] += 1
                continue

            # File never seen before, add to dictionary of files
            # (from filename to # of occurrences)
            files_dict[f] = 1

            # Preprocess file
            filepath = os.path.join(dirpath, f)
            n_saved = preprocess_file(filepath, dest_dir, tot_samples)

            tot_samples += n_saved
            if n_saved > 0:
                not_filtered += 1
            
            progress_bar.update(1)
            
            # Todo: also print # of processed (not filtered) files
            #       and # of produced sequences (samples)
            print("Total number of seen files:", seen)
            print("Number of unique files:", len(files_dict))
            print("Total number of non filtered songs:", not_filtered)
            print("Total number of saved samples:", tot_samples)
            print()

            # Exit when a maximum number of files has been processed (if set)
            if early_exit != None and len(files_dict) >= early_exit:
                finished = True
                break

        if finished:
            break
    
    end = time.time()
    hours, rem = divmod(end-start, 3600)
    minutes, seconds = divmod(rem, 60)
    print("Preprocessing completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
              .format(int(hours),int(minutes),seconds))


In [3]:
#!rm -rf data/preprocessed/
#!mkdir data/preprocessed

In [4]:
#dataset_dir = 'data/lmd_matched/Y/G/'
#dest_dir = 'data/preprocessed'

Check preprocessed data:

In [5]:
#preprocess_dataset(dataset_dir, dest_dir, early_exit=10)

In [6]:
#filepath = os.path.join(dest_dir, "5.npz")
#data = np.load(filepath)

In [7]:
#print(data["seq_tensor"].shape)
#print(data["seq_acts"].shape)

In [8]:
#data["seq_tensor"][0, 1]

# Model

In [9]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset
from torch_geometric.data import Data
import itertools
from torch_geometric.data.collate import collate


class MIDIDataset(Dataset):

    def __init__(self, dir, n_bars=2):
        self.dir = dir
        self.files = list(os.scandir(self.dir))
        self.len = len(self.files)
        self.n_bars = n_bars

        
    def __len__(self):
        return self.len

    
    def _get_track_edges(self, acts, edge_type_ind=0):

        a_t = acts.transpose()
        inds = np.stack(np.where(a_t == 1)).transpose()
        
        # Create node labels
        labels = np.zeros(acts.shape)
        acts_inds = np.where(acts == 1)
        num_nodes = len(acts_inds[0])
        labels[acts_inds] = np.arange(num_nodes)
        labels = labels.transpose()

        track_edges = []

        for track in range(a_t.shape[1]):
            tr_inds = list(inds[inds[:,1] == track])
            e_inds = [(tr_inds[i],
                    tr_inds[i+1]) for i in range(len(tr_inds)-1)]
            edges = [(labels[tuple(e[0])], labels[tuple(e[1])], edge_type_ind+track, e[1][0]-e[0][0]) for e in e_inds]
            inv_edges = [(e[1], e[0], *e[2:]) for e in edges]
            track_edges.extend(edges)
            track_edges.extend(inv_edges)

        return np.array(track_edges, dtype='long')

    
    def _get_onset_edges(self, acts, edge_type_ind=4):

        a_t = acts.transpose()
        inds = np.stack(np.where(a_t == 1)).transpose()
        ts_acts = np.any(a_t, axis=1)
        ts_inds = np.where(ts_acts)[0]

        # Create node labels
        labels = np.zeros(acts.shape)
        acts_inds = np.where(acts == 1)
        num_nodes = len(acts_inds[0])
        labels[acts_inds] = np.arange(num_nodes)
        labels = labels.transpose()

        onset_edges = []

        for i in ts_inds:
            ts_acts_inds = list(inds[inds[:,0] == i])
            if len(ts_acts_inds) < 2:
                continue
            e_inds = list(itertools.combinations(ts_acts_inds, 2))
            edges = [(labels[tuple(e[0])], labels[tuple(e[1])], edge_type_ind, 0) for e in e_inds]
            inv_edges = [(e[1], e[0], *e[2:]) for e in edges]
            onset_edges.extend(edges)
            onset_edges.extend(inv_edges)

        return np.array(onset_edges, dtype='long')


    def _get_next_edges(self, acts, edge_type_ind=5):

        a_t = acts.transpose()
        inds = np.stack(np.where(a_t == 1)).transpose()
        ts_acts = np.any(a_t, axis=1)
        ts_inds = np.where(ts_acts)[0]

        # Create node labels
        labels = np.zeros(acts.shape)
        acts_inds = np.where(acts == 1)
        num_nodes = len(acts_inds[0])
        labels[acts_inds] = np.arange(num_nodes)
        labels = labels.transpose()

        next_edges = []

        for i in range(len(ts_inds)-1):

            ind_s = ts_inds[i]
            ind_e = ts_inds[i+1]
            s = inds[inds[:,0] == ind_s]
            e = inds[inds[:,0] == ind_e]

            e_inds = [t for t in list(itertools.product(s, e)) if t[0][1] != t[1][1]]
            edges = [(labels[tuple(e[0])], labels[tuple(e[1])], edge_type_ind, ind_e-ind_s) for e in e_inds]
            inv_edges = [(e[1], e[0], *e[2:]) for e in edges]
            
            next_edges.extend(edges)
            next_edges.extend(inv_edges)
            
        return np.array(next_edges, dtype='long')
    
    def _get_super_edges(self, num_nodes, edge_type_ind=6):
    
        super_edges = [(num_nodes, i, edge_type_ind, 0) for i in range(num_nodes)]
        inv_edges = [(e[1], e[0], *e[2:]) for e in edges]
        
        super_edges.extend(inv_edges)
        
        return np.array(super_edges, dtype='long')
        
    
    def _get_node_features(self, acts, num_nodes):
        
        num_tracks = acts.shape[0]
        features = torch.zeros((num_nodes, num_tracks), dtype=torch.float)
        features[np.arange(num_nodes), np.stack(np.where(acts))[0]] = 1.
        
        return features


    def __getitem__(self, idx):

        # Load tensors
        sample_path = os.path.join(self.dir, self.files[idx].name)
        data = np.load(sample_path)
        seq_tensor = data["seq_tensor"]
        seq_acts = data["seq_acts"]
        
        # From (#tracks x #timesteps x ...) to (#bars x #tracks x #timesteps x ...)
        seq_tensor = seq_tensor.reshape(seq_tensor.shape[0], self.n_bars, -1,
                                        seq_tensor.shape[2], seq_tensor.shape[3])
        seq_tensor = seq_tensor.transpose(1, 0, 2, 3, 4)
        seq_acts = seq_acts.reshape(seq_acts.shape[0], self.n_bars, -1)
        seq_acts = seq_acts.transpose(1, 0, 2)
        
        # Construct src_key_padding_mask (PAD = 130)
        src_mask = torch.from_numpy((seq_tensor[..., 0] == 130))

        # From decimals to one-hot (pitch)
        pitches = seq_tensor[..., 0]
        onehot_p = np.zeros(
            (pitches.shape[0]*pitches.shape[1]*pitches.shape[2]*pitches.shape[3],
             131), 
            dtype=float
        )
        onehot_p[np.arange(0, onehot_p.shape[0]), pitches.reshape(-1)] = 1.
        onehot_p = onehot_p.reshape(pitches.shape[0], pitches.shape[1], 
                                    pitches.shape[2], pitches.shape[3], 131)
        
        # From decimals to one-hot (dur)
        durs = seq_tensor[..., 1]
        onehot_d = np.zeros(
            (durs.shape[0]*durs.shape[1]*durs.shape[2]*durs.shape[3],
             99),
            dtype=float
        )
        onehot_d[np.arange(0, onehot_d.shape[0]), durs.reshape(-1)] = 1.
        onehot_d = onehot_d.reshape(durs.shape[0], durs.shape[1], 
                                    durs.shape[2], durs.shape[3], 99)
        
        # Concatenate pitches and durations
        new_seq_tensor = np.concatenate((onehot_p, onehot_d),
                                        axis=-1)
        
        graphs = []
        
        # Iterate over bars and construct a graph for each bar
        for i in range(self.n_bars):
            
            # Number of nodes
            n = torch.sum(torch.Tensor(seq_acts[i]), dtype=torch.long)
            
            # Get edges from boolean activations
            # Todo: optimize and refactor
            track_edges = self._get_track_edges(seq_acts[i])
            onset_edges = self._get_onset_edges(seq_acts[i])
            next_edges = self._get_next_edges(seq_acts[i])
            #super_edges = self._get_super_edges(n)
            edges = [track_edges, onset_edges, next_edges]
            
            # Concatenate edge tensors (N x 4) (if any)
            # First two columns -> source and dest nodes
            # Third column -> edge_type, Fourth column -> timestep distance
            no_edges = (len(track_edges) == 0 and 
                        len(onset_edges) == 0 and len(next_edges) == 0)
            if not no_edges:
                edge_list = np.concatenate([x for x in edges
                                              if x.size > 0])
                edge_list = torch.from_numpy(edge_list)
                
            # Adapt tensor to torch_geometric's Data
            # No edges: add fictitious self-edge
            edge_index = (torch.LongTensor([[0], [0]]) if no_edges else
                                   edge_list[:, :2].t().contiguous())
            attrs = (torch.Tensor([[0, 0]]) if no_edges else
                                           edge_list[:, 2:])

            # One hot timestep distance concatenated to edge type
            edge_attrs = torch.zeros(attrs.size(0), 1+seq_acts.shape[-1])
            edge_attrs[:, 0] = attrs[:, 0]
            edge_attrs[np.arange(edge_attrs.size(0)), attrs.long()[:, 1]+1] = 1
            #edge_attrs = torch.Tensor(attrs.float())
            
            node_features = self._get_node_features(seq_acts[i], n)
            is_drum = node_features[:, 0].bool()
            
            graphs.append(Data(edge_index=edge_index, edge_attrs=edge_attrs,
                               num_nodes=n, node_features=node_features,
                               is_drum=is_drum))
            
            
        # Merge the graphs corresponding to different bars into a single big graph
        graphs, _, inc_dict = collate(
            Data,
            data_list=graphs,
            increment=True,
            add_batch=True
        )
        
        # Change bars assignment vector name (otherwise, Dataloader's collate
        # would overwrite graphs.batch)
        graphs.bars = graphs.batch
        
        # Filter silences in order to get a sparse representation
        new_seq_tensor = new_seq_tensor.reshape(-1, new_seq_tensor.shape[-2],
                                                new_seq_tensor.shape[-1])
        src_mask = src_mask.reshape(-1, src_mask.shape[-1])
        new_seq_tensor = new_seq_tensor[seq_acts.reshape(-1).astype(bool)]
        src_mask = src_mask[seq_acts.reshape(-1).astype(bool)]
        
        new_seq_tensor = torch.Tensor(new_seq_tensor)
        seq_acts = torch.Tensor(seq_acts)
        graphs.x_seq = new_seq_tensor
        graphs.x_acts = seq_acts
        graphs.src_mask = src_mask
        
        # Todo: start with torch at mount
        #return torch.Tensor(new_seq_tensor), torch.Tensor(seq_acts), graphs, src_mask
        return graphs


In [10]:
from typing import Optional, Union, Tuple
from torch_geometric.typing import OptTensor, Adj
from typing import Callable
from torch_geometric.nn.inits import reset

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter as Param
from torch.nn import Parameter
from torch_scatter import scatter
from torch_sparse import SparseTensor, matmul, masked_select_nnz
from torch_geometric.nn.conv import MessagePassing

from torch_geometric.nn.inits import glorot, zeros


@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (Tensor, Tensor) -> Tensor
    pass


@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (SparseTensor, Tensor) -> SparseTensor
    pass


def masked_edge_index(edge_index, edge_mask):
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]
    else:
        return masked_select_nnz(edge_index, edge_mask, layout='coo')

def masked_edge_attrs(edge_attrs, edge_mask):
    return edge_attrs[edge_mask, :]


class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    .. note::
        This implementation is as memory-efficient as possible by iterating
        over each individual relation type.
        Therefore, it may result in low GPU utilization in case the graph has a
        large number of relations.
        As an alternative approach, :class:`FastRGCNConv` does not iterate over
        each individual type, but may consume a large amount of memory to
        compensate.
        We advise to check out both implementations to see which one fits your
        needs.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
            In case no input features are given, this argument should
            correspond to the number of nodes in your graph.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
            maps edge features :obj:`edge_attr` of shape :obj:`[-1,
            num_edge_features]` to shape
            :obj:`[-1, in_channels * out_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`.
        num_bases (int, optional): If set to not :obj:`None`, this layer will
            use the basis-decomposition regularization scheme where
            :obj:`num_bases` denotes the number of bases to use.
            (default: :obj:`None`)
        num_blocks (int, optional): If set to not :obj:`None`, this layer will
            use the block-diagonal-decomposition regularization scheme where
            :obj:`num_blocks` denotes the number of blocks to use.
            (default: :obj:`None`)
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"mean"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        #num_dists: int,
        nn: Callable,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        dropout: Optional[float] = 0.1,
        aggr: str = 'mean',
        root_weight: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(aggr=aggr, node_dim=0, **kwargs)

        if num_bases is not None and num_blocks is not None:
            raise ValueError('Can not apply both basis-decomposition and '
                             'block-diagonal-decomposition at the same time.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nn = nn
        #self.num_dists = num_dists
        self.dropout = dropout
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.num_blocks = num_blocks

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        self.in_channels_l = in_channels[0]

        if num_bases is not None:
            self.weight = Parameter(
                torch.Tensor(num_bases, in_channels[0], out_channels))
            self.comp = Parameter(torch.Tensor(num_relations, num_bases))
        
        elif num_blocks is not None:
            assert (in_channels[0] % num_blocks == 0
                    and out_channels % num_blocks == 0)
            self.weight = Parameter(
                torch.Tensor(num_relations, num_blocks,
                             in_channels[0] // num_blocks,
                             out_channels // num_blocks))
            self.register_parameter('comp', None)

        else:
            self.weight = Parameter(
                torch.Tensor(num_relations, in_channels[0], out_channels))
            self.register_parameter('comp', None)

        if root_weight:
            self.root = Param(torch.Tensor(in_channels[1], out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Param(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        
        #self.dist_weights = Parameter(torch.Tensor(self.num_dists))

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        reset(self.nn)
        glorot(self.comp)
        glorot(self.root)
        #zeros(self.dist_weights)
        zeros(self.bias)


    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
                edge_index: Adj, edge_type: OptTensor = None,
                edge_attr: OptTensor = None):
        r"""
        Args:
            x: The input node features. Can be either a :obj:`[num_nodes,
                in_channels]` node feature matrix, or an optional
                one-dimensional node index tensor (in which case input features
                are treated as trainable node embeddings).
                Furthermore, :obj:`x` can be of type :obj:`tuple` denoting
                source and destination node features.
            edge_type: The one-dimensional relation type/index for each edge in
                :obj:`edge_index`.
                Should be only :obj:`None` in case :obj:`edge_index` is of type
                :class:`torch_sparse.tensor.SparseTensor`.
                (default: :obj:`None`)
        """

        # Convert input features to a pair of node features or node indices.
        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x
        if x_l is None:
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))

        if isinstance(edge_index, SparseTensor):
            edge_type = edge_index.storage.value()
        assert edge_type is not None

        # propagate_type: (x: Tensor)
        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        if self.num_bases is not None:  # Basis-decomposition =================
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels)

        if self.num_blocks is not None:  # Block-diagonal-decomposition =====

            if x_l.dtype == torch.long and self.num_blocks is not None:
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')

            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                h = self.propagate(tmp, x=x_l, size=size)
                h = h.view(-1, weight.size(1), weight.size(2))
                h = torch.einsum('abc,bcd->abd', h, weight[i])
                out += h.contiguous().view(-1, self.out_channels)

        else:  # No regularization/Basis-decomposition ========================
            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                attr = masked_edge_attrs(edge_attr, edge_type == i)

                if x_l.dtype == torch.long:
                    out += self.propagate(tmp, x=weight[i, x_l], size=size)
                else:
                    h = self.propagate(tmp, x=x_l, size=size,
                                       edge_attr=attr)
                    out = out + (h @ weight[i])

        root = self.root
        if root is not None:
            out += root[x_r] if x_r.dtype == torch.long else x_r @ root

        if self.bias is not None:
            out += self.bias

        return out


    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        #weights = self.dist_weights[edge_attr.view(-1).long()]
        #weights = torch.diag(weights)
        #return torch.matmul(weights, x_j)
        weights = self.nn(edge_attr)
        weights = weights[..., :self.in_channels_l]
        weights = weights.view(-1, self.in_channels_l)
        ret = x_j * weights
        ret = F.relu(ret)
        ret = F.dropout(ret, p=self.dropout, training=self.training)
        return ret
    
        
        
    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_relations={self.num_relations})')

In [11]:
import torch
from torch import nn, Tensor
from torch_geometric.nn.conv import GCNConv#, RGCNConv
from torch_geometric.nn.norm import BatchNorm
from torch_geometric.nn.glob import GlobalAttention
import torch.nn.functional as F
import math
import torch.optim as optim
from torch_scatter import scatter_mean
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence


# Todo: check and think about max_len
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 256):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *                     \
                             (-math.log(10000.0)/d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position*div_term)
        pe[:, 0, 1::2] = torch.cos(position*div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    

class MLP(nn.Module):
    
    def __init__(self, input_dim=256, hidden_dim=256, output_dim=256, num_layers=2,
                 act=True, dropout=0.1):
        super().__init__()
        
        assert num_layers >= 1
        
        self.layers = nn.ModuleList()
        
        if num_layers == 1:
            self.layers.append(nn.Linear(input_dim, output_dim))
        else:
            self.layers.append(nn.Linear(input_dim, hidden_dim))
        
            for i in range(num_layers-2):
                self.layers.append(nn.Linear(hidden_dim, hidden_dim))

            self.layers.append(nn.Linear(hidden_dim, output_dim))
        
        self.act = act
        self.p = dropout
        

    def forward(self, x):
        
        for layer in self.layers:
            x = F.dropout(x, p=self.p, training=self.training)
            x = layer(x)
            if self.act:
                x = F.relu(x)
        
        return x
    
    
class EncoderRNN(nn.Module):
    
    def __init__(self, input_size=240, hidden_size=256, num_layers=2, 
                 dropout=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Linear(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True,
                            num_layers=num_layers, bidirectional=True)

    def forward(self, x, hidden):
        embedded = self.embedding(x).view(1, 1, -1)
        output = embedded
        output, hidden = self.lstm(output, hidden)
        return output, hidden
    

class DecoderRNN(nn.Module):
    
    def __init__(self, hidden_size=256, output_size=240, num_layers=2, 
                 dropout=0.1):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Linear(output_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True,
                            num_layers=num_layers, bidirectional=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        output = self.embedding(x).view(1, 1, -1)
        output, hidden = self.lstm(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)


class GraphEncoder(nn.Module):
    
    def __init__(self, input_dim=256, hidden_dim=256, n_layers=3, 
                 num_relations=3, num_dists=32, batch_norm=False, dropout=0.1):
        super().__init__()
        
        self.layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        edge_nn = nn.Linear(num_dists, input_dim)
        self.batch_norm = batch_norm
        
        self.layers.append(RGCNConv(input_dim, hidden_dim,
                                    num_relations, edge_nn))
        if self.batch_norm:
            self.norm_layers.append(BatchNorm(hidden_dim))
        
        for i in range(n_layers-1):
            #edge_nn = nn.Linear(num_dists, input_dim)
            self.layers.append(RGCNConv(hidden_dim, hidden_dim,
                                        num_relations, edge_nn))
            if self.batch_norm:
                self.norm_layers.append(BatchNorm(hidden_dim))
            
        self.p = dropout
        

    def forward(self, data):
        x, edge_index, edge_attrs = data.x, data.edge_index, data.edge_attrs
        #batch = data.distinct_bars
        edge_type = edge_attrs[:, 0]
        edge_attr = edge_attrs[:, 1:]
        
        for i in range(len(self.layers)):
            residual = x
            x = F.dropout(x, p=self.p, training=self.training)
            x = self.layers[i](x, edge_index, edge_type, edge_attr)
            if self.batch_norm:
                x = self.norm_layers[i](x)
            x = F.relu(x)
            x = residual + x

        return x
    
    
class CNNEncoder(nn.Module):
    
    def __init__(self, output_dim=256, dense_dim=256, batch_norm=False,
                 dropout=0.1):
        super().__init__()
        
        if batch_norm:
            self.conv = nn.Sequential(
                # 4*32 --> 8*4*32
                nn.Conv2d(1, 8, 3, padding=1),
                nn.BatchNorm2d(8),
                nn.ReLU(True),
                # 8*4*32 --> 8*4*8
                nn.MaxPool2d((1, 4), stride=(1, 4)),
                # 8*4*8 --> 16*4*8
                nn.Conv2d(8, 16, 3, padding=1),
                nn.BatchNorm2d(16),
                nn.ReLU(True)
            )
        else:
            self.conv = nn.Sequential(
                # 4*32 --> 8*4*32
                nn.Conv2d(1, 8, 3, padding=1),
                nn.ReLU(True),
                # 8*4*32 --> 8*4*8
                nn.MaxPool2d((1, 4), stride=(1, 4)),
                # 8*4*8 --> 16*4*8
                nn.Conv2d(8, 16, 3, padding=1),
                nn.ReLU(True)
            )
        
        self.flatten = nn.Flatten(start_dim=1)
        
        self.lin = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(16*4*8, dense_dim),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Linear(dense_dim, output_dim)
        )
        
        
    def forward(self, x):
        
        x = x.unsqueeze(1)
        x = self.conv(x)
        x = self.flatten(x)
        x = self.lin(x)
        
        return x
    

class CNNDecoder(nn.Module):
    
    def __init__(self, input_dim=256, dense_dim=256, batch_norm=False,
                 dropout=0.1):
        super().__init__()
        
        self.lin = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_dim, dense_dim),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Linear(dense_dim, 16*4*8),
            nn.ReLU(True)
        )

        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(16, 4, 8))

        if batch_norm:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=(1, 4), mode='nearest'),
                nn.Conv2d(16, 8, 3, padding=1),
                nn.BatchNorm2d(8),
                nn.ReLU(True),
                nn.Conv2d(8, 1, 3, padding=1)
            )
        else:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=(1, 4), mode='nearest'),
                nn.Conv2d(16, 8, 3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(8, 1, 3, padding=1)
            )
        
        
    def forward(self, x):
        
        x = self.lin(x)
        x = self.unflatten(x)
        x = self.conv(x)
        
        return x.unsqueeze(1)
        

class Encoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        
        self.__dict__.update(kwargs)
        
        self.dropout_layer = nn.Dropout(p=self.dropout)
        
        
        self.notes_pitch_emb = nn.Linear(self.d_token_pitches, 
                                               self.d//2)
        
        self.bn_npe = nn.BatchNorm1d(num_features=self.d//2)
        
        self.drums_pitch_emb = nn.Linear(self.d_token_pitches, 
                                               self.d//2)
        
        self.bn_dpe = nn.BatchNorm1d(num_features=self.d//2)
        
        self.dur_emb = nn.Linear(self.d_token_dur, self.d//2)
        
        self.bn_de = nn.BatchNorm1d(num_features=self.d//2)
        
        self.chord_encoder = nn.Linear(self.d * (self.max_simu_notes-1),
                                       self.d)

        # Graph encoder
        self.graph_encoder = GraphEncoder(dropout=self.dropout, 
                                          input_dim=self.d,
                                          hidden_dim=self.d,
                                          n_layers=self.gnn_n_layers,
                                          num_relations=self.n_relations,
                                          batch_norm=self.batch_norm)
        
        gate_nn = nn.Sequential(
            MLP(input_dim=self.d, output_dim=1, num_layers=1, act=False,
                      dropout=self.dropout),
            nn.BatchNorm1d(1)
        )
        self.graph_attention = GlobalAttention(gate_nn)
        
        #self.context_bar_rnn = nn.GRU(input_size=self.d,
        #                              hidden_size=self.d//2,
        #                              num_layers=1,
        #                              bidirectional=True,
        #                              batch_first=True, 
        #                              dropout=self.dropout)
        
        self.bars_encoder_attr = nn.Linear(self.n_bars*self.d,
                                           self.d)
        
        
        self.cnn_encoder = CNNEncoder(dense_dim=self.d,
                                      output_dim=self.d,
                                      dropout=0,
                                      batch_norm=self.batch_norm)
        
        self.bars_encoder_struct = nn.Linear(self.n_bars*self.d,
                                             self.d)
        #self.struct_bar_rnn = nn.GRU(input_size=self.d,
        #                              hidden_size=self.d//2,
        #                              num_layers=1,
        #                              batch_first=True, 
        #                              dropout=self.dropout)
        
        self.linear_merge = nn.Linear(2*self.d, self.d)
        
        self.bn_lm = nn.BatchNorm1d(num_features=self.d)
        
        # Linear layers that compute the final mu and log_var
        # Todo: as parameters
        self.linear_mu = nn.Linear(self.d, self.d)
        self.linear_log_var = nn.Linear(self.d, self.d)

        
    def forward(self, x_seq, x_acts, x_graph, src_mask):
        
        # No start of seq token
        x_seq = x_seq[:, 1:, :]
        
        # Get drums and non drums tensors
        drums = x_seq[x_graph.is_drum]
        src_mask_drums = src_mask[x_graph.is_drum]
        non_drums = x_seq[torch.logical_not(x_graph.is_drum)]
        src_mask_non_drums = src_mask[torch.logical_not(x_graph.is_drum)]
        
        # Permute dimensions to batch_first = False
        #drums = drums.permute(1, 0, 2)
        #non_drums = non_drums.permute(1, 0, 2)
        
        # Compute note/drums embeddings
        s = drums.size()
        drums_pitch = self.drums_pitch_emb(drums[..., :self.d_token_pitches])
        drums_pitch = self.bn_dpe(drums_pitch.view(-1, self.d//2))
        drums_pitch = drums_pitch.view(s[0], s[1], self.d//2)
        drums_dur = self.dur_emb(drums[..., self.d_token_pitches:])
        drums_dur = self.bn_de(drums_dur.view(-1, self.d//2))
        drums_dur = drums_dur.view(s[0], s[1], self.d//2)
        drums = torch.cat((drums_pitch, drums_dur), dim=-1)
        #drums = self.dropout_layer(drums)
        # [n_nodes x max_simu_notes x d]
        
        s = non_drums.size()
        non_drums_pitch = self.notes_pitch_emb(non_drums[..., :self.d_token_pitches])
        non_drums_pitch = self.bn_npe(non_drums_pitch.view(-1, self.d//2))
        non_drums_pitch = non_drums_pitch.view(s[0], s[1], self.d//2)
        non_drums_dur = self.dur_emb(non_drums[..., self.d_token_pitches:])
        non_drums_dur = self.bn_de(non_drums_dur.view(-1, self.d//2))
        non_drums_dur = non_drums_dur.view(s[0], s[1], self.d//2)
        non_drums = torch.cat((non_drums_pitch, non_drums_dur), dim=-1)
        #non_drums = self.dropout_layer(non_drums)
        # [n_nodes x max_simu_notes x d]
        
        #len_drums = self.max_simu_notes - torch.sum(src_mask_drums, dim=1)
        #len_non_drums = self.max_simu_notes - torch.sum(src_mask_non_drums, dim=1)
        
        #drums = pack_padded_sequence(drums, len_drums.cpu().view(-1),
        #                             enforce_sorted=False)
        #non_drums = pack_padded_sequence(non_drums, len_non_drums.cpu().view(-1),
        #                                 enforce_sorted=False)

        # Compute chord embeddings both for drums and non drums
        drums = self.chord_encoder(drums.view(-1, self.d*(self.max_simu_notes-1)))
        non_drums = self.chord_encoder(non_drums.view(-1, self.d*(self.max_simu_notes-1)))
        drums = F.relu(drums)
        non_drums = F.relu(non_drums)
        drums = self.dropout_layer(drums)
        non_drums = self.dropout_layer(non_drums)
        # [n_nodes x d]
        
        #hidden = torch.zeros(drums.size(1), )
        #drums = self.chord_encoder_drums(drums)[-1]
        #non_drums = self.chord_encoder(non_drums)[-1]
        #drums = torch.mean(drums, dim=0)
        #non_drums = torch.mean(non_drums, dim=0)
        
        #drums = self.dropout_layer(drums)
        #non_drums = self.dropout_layer(non_drums)
        
        # Merge drums and non-drums
        #out = torch.zeros((x_seq.size(0), self.d_model), 
        #                  device=self.device)
        out = torch.zeros((x_seq.size(0), self.d), 
                          device=self.device, dtype=torch.half)
        out[x_graph.is_drum] = drums
        out[torch.logical_not(x_graph.is_drum)] = non_drums
        # [n_nodes x d]
        
        #x_graph.x = torch.cat((x_graph.node_features, out), 1)
        x_graph.x = out
        x_graph.distinct_bars = x_graph.bars + self.n_bars*x_graph.batch
        out = self.graph_encoder(x_graph)
        # [n_nodes x d]
        
        with torch.cuda.amp.autocast(enabled=False):
            out = self.graph_attention(out,
                                       batch=x_graph.distinct_bars)
            # [bs x n_bars x d]
            
        out = out.view(-1, self.n_bars * self.d)
        # [bs x n_bars * d]
        out_attr = self.bars_encoder_attr(out)
        # [bs x d]
        
        # Process structure
        out = self.cnn_encoder(x_acts.view(-1, self.n_tracks,
                                                self.resolution*4))
        # [bs * n_bars x d]
        out = out.view(-1, self.n_bars * self.d)
        # [bs x n_bars * d]
        out_struct = self.bars_encoder_struct(out)
        # [bs x d]
        
        # Merge attr state and struct state
        out = torch.cat((out_attr, out_struct), dim=1)
        out = self.dropout_layer(out)
        out = self.linear_merge(out)
        out = self.bn_lm(out)
        out = F.relu(out)

        # Compute mu and log(std^2)
        out = self.dropout_layer(out)
        mu = self.linear_mu(out)
        log_var = self.linear_log_var(out)
        
        return mu, log_var


class Decoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        
        self.__dict__.update(kwargs)
        
        self.dropout_layer = nn.Dropout(p=self.dropout)

        self.lin_divide = nn.Linear(self.d, 2 * self.d)
        
        self.bn_ld = nn.BatchNorm1d(num_features=2*self.d)
        
        self.bars_decoder_attr = nn.Linear(self.d, self.d * self.n_bars)
        #self.context_bar_rnn = nn.GRU(input_size=self.d,
        #                              hidden_size=self.d//2,
        #                              num_layers=1,
        #                              bidirectional=True,
        #                              batch_first=True,
        #                              dropout=self.dropout)
        
        self.bars_decoder_struct = nn.Linear(self.d, self.d * self.n_bars)
        #self.struct_bar_rnn = nn.GRU(input_size=self.d//2,
        #                              hidden_size=self.d//2,
        #                              num_layers=1,
        #                              batch_first=True,
        #                              dropout=self.dropout)
        
        self.cnn_decoder = CNNDecoder(input_dim=self.d,
                                      dense_dim=self.d,
                                      dropout=0,
                                      batch_norm=self.batch_norm)
        
        self.graph_decoder = GraphEncoder(dropout=self.dropout,
                                          input_dim=self.d,
                                          hidden_dim=self.d,
                                          n_layers=self.gnn_n_layers,
                                          num_relations=self.n_relations,
                                          batch_norm=self.batch_norm)
        
        #gate_nn = nn.Sequential(
        #    MLP(input_dim=self.d, output_dim=1, num_layers=1, act=False,
        #              dropout=self.dropout),
        #    nn.BatchNorm1d(1)
        #)
        #feat_nn = nn.Sequential(
        #    MLP(input_dim=self.d, output_dim=self.d//2, num_layers=1,
        #              dropout=self.dropout),
        #    nn.BatchNorm1d(self.d//2)
        #)
        #self.graph_attention = GlobalAttention(gate_nn, feat_nn)
        
        self.chord_decoder = nn.Linear(self.d, self.d*(self.max_simu_notes-1))
        #self.chord_decoder = nn.GRU(input_size=self.d,
        #                            hidden_size=self.d,
        #                            num_layers=1,
        #                            dropout=self.dropout)
        #self.chord_decoder_drums = nn.GRU(input_size=self.d,
        #                                  hidden_size=self.d,
        #                                  num_layers=1,
        #                                  dropout=self.dropout)
        
        # Pitch and dur linear layers
        self.drums_pitch_emb = nn.Linear(self.d//2, self.d_token_pitches)
        self.notes_pitch_emb = nn.Linear(self.d//2, self.d_token_pitches)
        self.dur_emb = nn.Linear(self.d//2, self.d_token_dur)


    def forward(self, z, x_seq, x_acts, x_graph, src_mask, tgt_mask,
                inference=False):
        # z: [bs x d]
        
        # Obtain z_structure and z_attributes from z
        #z = self.dropout_layer(z)
        z = self.lin_divide(z)
        z = self.bn_ld(z)
        z = F.relu(z)
        z = self.dropout_layer(z)
        # [bs x 2*d]
        
        out_struct = z[:, :self.d]
        # [bs x d] 
        out_struct = self.bars_decoder_struct(out_struct)
        # [bs x n_bars * d]
        
        out_struct = self.cnn_decoder(out_struct.reshape(-1, self.d))
        out_struct = out_struct.view(x_acts.size())
        
        # Decode attributes
        out = z[:, self.d:]
        # [bs x d]
        out = self.bars_decoder_attr(out)
        # [bs x n_bars * d]
        
        # Initialize node features with corresponding z_bar
        # and propagate with GNN
        _, counts = torch.unique(x_graph.distinct_bars, return_counts=True)
        out = out.view(-1, self.d)
        out = torch.repeat_interleave(out, counts, axis=0)
        # [n_nodes x d]
        
        # Add one-hot encoding of tracks
        # Todo: use also edge info
        #x_graph.x = torch.cat((x_graph.node_features, out), 1)
        x_graph.x = out
        out = self.graph_decoder(x_graph)
        # [n_nodes x d]
        #print("Node decodings:", node_decs.size())
        
        
        #out = torch.matmul(out, self.chord_decoder.weight)
        out = self.chord_decoder(out)
        # [n_nodes x max_simu_notes * d]
        out = out.view(-1, self.max_simu_notes-1, self.d)
        
        drums = out[x_graph.is_drum]
        non_drums = out[torch.logical_not(x_graph.is_drum)]
        # [n_nodes(dr/non_dr) x max_simu_notes x d]
        
        # Obtain final pitch and dur decodings
        # (softmax to be applied outside the model)
        non_drums = self.dropout_layer(non_drums)
        drums = self.dropout_layer(drums)
        
        drums_pitch = self.drums_pitch_emb(drums[..., :self.d//2])
        drums_dur = self.dur_emb(drums[..., self.d//2:])
        drums = torch.cat((drums_pitch, drums_dur), dim=-1)
        #drums_pitch = torch.matmul(drums[..., :self.d//2], self.drums_pitch_emb.weight)
        #drums_dur = torch.matmul(drums[..., self.d//2:], self.dur_emb.weight)
        #drums = torch.cat((drums_pitch, drums_dur), dim=-1)
        # [n_nodes(dr) x max_simu_notes x d_token]
        non_drums_pitch = self.notes_pitch_emb(non_drums[..., :self.d//2])
        non_drums_dur = self.dur_emb(non_drums[..., self.d//2:])
        non_drums = torch.cat((non_drums_pitch, non_drums_dur), dim=-1)
        #non_drums_pitch = torch.matmul(non_drums[..., :self.d//2], self.notes_pitch_emb.weight)
        #non_drums_dur = torch.matmul(non_drums[..., self.d//2:], self.dur_emb.weight)
        #non_drums = torch.cat((non_drums_pitch, non_drums_dur), dim=-1)
        # [n_nodes(non_dr) x max_simu_notes x d_token]
        
        # Merge drums and non-drums
        out = torch.zeros((x_seq.size(0), x_seq.size(1), x_seq.size(2)),
                          device=self.device, dtype=torch.half)
        out[x_graph.is_drum] = drums
        out[torch.logical_not(x_graph.is_drum)] = non_drums
        
        out = out.view(x_seq.size())

        return out, out_struct


class VAE(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        
        self.encoder = Encoder(**kwargs)
        self.decoder = Decoder(**kwargs)
    
    
    def forward(self, x_seq, x_acts, x_graph, src_mask, tgt_mask,
                inference=False):
        
        #src_mask = src_mask.view(-1, src_mask.size(-1))
        
        # Encoder pass
        mu, log_var = self.encoder(x_seq, x_acts, x_graph, src_mask)
        
        # Reparameterization trick
        z = torch.exp(0.5*log_var)
        z = z * torch.randn_like(z)
        #print("eps:", eps.size())
        z = z + mu
        
        # Shifting target sequence and mask for transformer decoder
        tgt = x_seq[..., :-1, :]
        src_mask = src_mask[:, :-1]
        
        # Decoder pass
        out = self.decoder(z, tgt, x_acts, x_graph, src_mask, tgt_mask,
                           inference=inference)
        
        return out, mu, log_var


Trainer

In [12]:
import torch.optim as optim
import matplotlib.pyplot as plt
import uuid
import copy
import time
from statistics import mean
from collections import defaultdict
import math


def generate_square_subsequent_mask(sz):
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)


def append_dict(dest_d, source_d):
        
    for k, v in source_d.items():
        dest_d[k].append(v)


def sigmoid(x, m=1, a=1, z=0):
    return m / (1 + math.exp(-a*(x-z)))


class VAETrainer():
    
    def __init__(self, model_dir, checkpoint=False, model=None, optimizer=None,
                 init_lr=1e-4, lr_scheduler=None, device=torch.device("cuda"), 
                 print_every=1, save_every=1, eval_every=100, iters_to_accumulate=1,
                 **kwargs):
        
        self.__dict__.update(kwargs)
        
        self.model_dir = model_dir
        self.device = device
        self.print_every = print_every
        self.save_every = save_every
        self.eval_every = eval_every
        self.iters_to_accumulate = iters_to_accumulate
        
        # Criteria with ignored padding
        self.bce_unreduced = nn.BCEWithLogitsLoss(reduction='none')
        self.ce_p = nn.CrossEntropyLoss(ignore_index=130)
        self.ce_d = nn.CrossEntropyLoss(ignore_index=98)
        
        # Training stats
        self.tr_losses = defaultdict(list)
        self.tr_accuracies = defaultdict(list)
        self.val_losses = defaultdict(list)
        self.val_accuracies = defaultdict(list)
        self.lrs = []
        self.betas = []
        self.times = []
        
        self.model = model
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.lr_scheduler = lr_scheduler
        
        self.tot_batches = 0
        self.beta = 0
        self.min_val_loss = np.inf
        
        if checkpoint:
            self.load_checkpoint()
        
    
    def train(self, trainloader, validloader=None, epochs=1,
              early_exit=None):
        
        self.model.train()
        
        print("Starting training.\n")
        
        if not self.times:
            start = time.time()
            self.times.append(start)
        
        progress_bar = tqdm(range(len(trainloader)))
        scaler = torch.cuda.amp.GradScaler()
                
        # Zero out the gradients
        self.optimizer.zero_grad()
        
        for epoch in range(epochs):
            
            self.cur_epoch = epoch
            
            for batch_idx, inputs in enumerate(trainloader):
                
                self.cur_batch_idx = batch_idx
                
                # Get the inputs
                x_graph = inputs.to(self.device)
                x_seq, x_acts, src_mask = x_graph.x_seq, x_graph.x_acts, x_graph.src_mask
                tgt_mask = generate_square_subsequent_mask(x_seq.size(-2)-1).to(self.device)
                
                inputs = (x_seq, x_acts, x_graph)

                with torch.cuda.amp.autocast():
                    # Forward pass, get the reconstructions
                    outputs, mu, log_var = self.model(x_seq, x_acts, x_graph, src_mask, tgt_mask)

                    # Compute the backprop loss and other required losses
                    tot_loss, losses = self._compute_losses(inputs, outputs, mu,
                                                             log_var)
                    tot_loss = tot_loss / self.iters_to_accumulate
                
                # Free GPU
                del x_seq
                del x_acts
                del src_mask
                del tgt_mask
                
                # Backprop
                scaler.scale(tot_loss).backward()
                #tot_loss.backward()
                #self.optimizer.step()
                    
                if (self.tot_batches + 1) % self.iters_to_accumulate == 0:
                    scaler.step(self.optimizer)
                    scaler.update()
                    self.optimizer.zero_grad()
                    if self.lr_scheduler is not None:
                        self.lr_scheduler.step()
                    if self.beta_update:    
                        self._update_beta()
                
                # Compute accuracies
                accs = self._compute_accuracies(inputs, outputs, x_graph.is_drum)
                
                # Update the stats
                append_dict(self.tr_losses, losses)
                last_lr = (self.lr_scheduler.lr 
                               if self.lr_scheduler is not None else self.init_lr)
                self.lrs.append(last_lr)
                self.betas.append(self.beta)
                append_dict(self.tr_accuracies, accs)
                now = time.time()
                self.times.append(now)
                
                # Print stats
                if (self.tot_batches + 1) % self.print_every == 0:
                    print("Training on batch {}/{} of epoch {}/{} complete."
                          .format(batch_idx+1, len(trainloader), epoch+1, epochs))
                    self._print_stats()
                    #print("Tot_loss: {:.4f} acts_loss: {:.4f} "
                          #.format(running_loss/self.print_every, acts_loss), end='')
                    #print("pitches_loss: {:.4f} dur_loss: {:.4f} kld_loss: {:.4f}"
                          #.format(pitches_loss, dur_loss, kld_loss))
                    print("\n----------------------------------------\n")
                    
                # ------------------------------------
                # EVAL ON VL SET EVERY N GRADIENT UPDATES
                # ------------------------------------
                
                if validloader is not None and (self.tot_batches + 1) % self.eval_every == 0:
                    
                    # Evaluate on val set
                    print("\nEvaluating on validation set...\n")
                    val_losses, val_accuracies = self.evaluate(validloader)
                    
                    # Update stats
                    append_dict(self.val_losses, val_losses)
                    append_dict(self.val_accuracies, val_accuracies)
                    
                    print("Val losses:")
                    print(val_losses)
                    print("Val accuracies:")
                    print(val_accuracies)
                    
                    # Save model if val loss (tot) reached a new minimum
                    tot_loss = val_losses['tot']
                    if tot_loss < self.min_val_loss:
                        print("\nValidation loss improved.")
                        print("Saving new best model to disk...\n")
                        self._save_model('best_model')
                        self.min_val_loss = tot_loss
                    
                    self.model.train()
                
                progress_bar.update(1)     
                    
                # When appropriate, save model and stats on disk
                if self.save_every > 0 and (self.tot_batches + 1) % self.save_every == 0:
                    print("\nSaving model to disk...\n")
                    self._save_model('checkpoint')
                
                # Stop prematurely if early_exit is set and reached
                if early_exit is not None and (self.tot_batches + 1) > early_exit:
                    break
                
                self.tot_batches += 1
            

        end = time.time()
        # Todo: self.__print_time()
        hours, rem = divmod(end-start, 3600)
        minutes, seconds = divmod(rem, 60)
        print("Training completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
                  .format(int(hours),int(minutes),seconds))
        
        print("Saving model to disk...")
        self._save_model('checkpoint')
        
        print("Model saved.")
        
    
    def _update_beta(self):
        
        # Number of gradient updates
        i = self.tot_batches
        
        if i < self.anneal_start or i >= self.anneal_end:
            return
        
        n_steps = self.beta_max // self.step_size
        inc_every = (self.anneal_end - self.anneal_start) // n_steps
        
        curr_step = (i - self.anneal_start) // inc_every
        self.beta = self.step_size * (curr_step + 1)
    
        
    def _update_beta2(self):
        
        steps = self.tot_batches
        
        if steps < self.anneal_after:
            self.beta = 0
        else:
            
            # Compute steps in current cycle
            curr_cycle = (steps - self.anneal_after) // self.anneal_steps
            
            if curr_cycle >= self.n_cycles:
                self.beta = self.beta_max
                return
            
            steps_in_cycle = steps - self.anneal_after - curr_cycle * self.anneal_steps
            
            # Decide stage in cycle (sigmoidal increase or beta set to zero)
            if steps_in_cycle / self.anneal_steps < self.inc_to_zero_ratio:
                
                # Compute values to correctly shift and scale sigmoid
                sig_zero = self.anneal_steps * self.inc_to_zero_ratio / 2
                inc_steps = self.anneal_steps * self.inc_to_zero_ratio
                scale = 1 / ((inc_steps / 2) / self.sig_scaled_point)
                
                self.beta = sigmoid(steps_in_cycle, m=self.beta_max, a=scale, z=sig_zero)
                
            else:
                self.beta = 0
        
    
    def evaluate(self, loader):
        
        losses = defaultdict(list)
        accs = defaultdict(list)
        
        self.model.eval()
        progress_bar = tqdm(range(len(loader)))
        
        with torch.no_grad():
            for batch_idx, inputs in enumerate(loader):

                # Get the inputs and move them to device
                x_graph = inputs.to(self.device)
                x_seq, x_acts, src_mask = x_graph.x_seq, x_graph.x_acts, x_graph.src_mask
                #x_seq, x_acts, x_graph, src_mask = inputs
                #x_seq = x_seq.float().to(self.device)
                #x_acts = x_acts.to(self.device)
                #x_graph = x_graph.to(self.device)
                #src_mask = src_mask.to(self.device)
                tgt_mask = generate_square_subsequent_mask(x_seq.size(-2)-1).to(self.device)
                inputs = (x_seq, x_acts, x_graph)

                with torch.cuda.amp.autocast():
                    # Forward pass, get the reconstructions
                    outputs, mu, log_var = self.model(x_seq, x_acts, x_graph, src_mask, tgt_mask)

                    # Compute losses and accuracies wrt batch
                    _, losses_b = self._compute_losses(inputs, outputs, mu,
                                                         log_var)
                    
                accs_b = self._compute_accuracies(inputs, outputs, x_graph.is_drum)
                
                # Save losses and accuracies
                append_dict(losses, losses_b)
                append_dict(accs, accs_b)
                
                progress_bar.update(1)
        
        
        # Compute avg losses and accuracies
        avg_losses = {}
        for k, l in losses.items():
            avg_losses[k] = mean(l)
            
        avg_accs = {}
        for k, l in accs.items():
            avg_accs[k] = mean(l)
            
        return avg_losses, avg_accs
                
        
    
    def _compute_losses(self, inputs, outputs, mu, log_var):
        
        x_seq, x_acts, _ = inputs
        seq_rec, acts_rec = outputs
        
        # Shift outputs for transformer decoder loss and filter silences
        x_seq = x_seq[..., 1:, :]
        #x_seq = x_seq[x_acts.bool()]
        #print(x_seq.size())
        #print(seq_rec.size())
                
        # Compute the losses
        acts_loss = self.bce_unreduced(acts_rec.view(-1), x_acts.view(-1).float())
        #weights = torch.zeros(acts_loss.size()).to(device)
        #weights[x_acts.view(-1) == 1] = 0.9
        #weights[x_acts.view(-1) == 0] = 0.1
        #acts_loss = torch.mean(weights*acts_loss)
        acts_loss = torch.mean(acts_loss)
        
        pitches_loss = self.ce_p(seq_rec.reshape(-1, seq_rec.size(-1))[:, :131],
                          x_seq.reshape(-1, x_seq.size(-1))[:, :131].argmax(dim=1))
        dur_loss = self.ce_d(seq_rec.reshape(-1, seq_rec.size(-1))[:, 131:],
                          x_seq.reshape(-1, x_seq.size(-1))[:, 131:].argmax(dim=1))
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
        kld_loss = torch.mean(kld_loss)
        rec_loss = pitches_loss + dur_loss + acts_loss
        #rec_loss = acts_loss
        tot_loss = rec_loss + self.beta*kld_loss
        
        losses = {
            'tot': tot_loss.item(),
            'pitches': pitches_loss.item(),
            'dur': dur_loss.item(),
            'acts': acts_loss.item(),
            'rec': rec_loss.item(),
            'kld': kld_loss.item(),
            'beta*kld': self.beta*kld_loss.item()
        }
        
        return tot_loss, losses
            
            
    def _compute_accuracies(self, inputs, outputs, is_drum):
        
        x_seq, x_acts, _ = inputs
        seq_rec, acts_rec = outputs
        
        # Shift outputs and filter silences
        x_seq = x_seq[..., 1:, :]
        #x_seq = x_seq[x_acts.bool()]
        #print(x_seq.size())
        #print(seq_rec.size())
        
        notes_acc = self._note_accuracy(seq_rec, x_seq)
        pitches_acc = self._pitches_accuracy(seq_rec, x_seq)
        pitches_acc_drums = self._pitches_accuracy(seq_rec, x_seq, 
                                                   is_drum, drums=True)
        pitches_acc_non_drums = self._pitches_accuracy(seq_rec, x_seq,
                                                       is_drum, drums=False)
        dur_acc = self._dur_accuracy(seq_rec, x_seq)
        acts_acc = self._acts_accuracy(acts_rec, x_acts)
        acts_precision = self._acts_precision(acts_rec, x_acts)
        acts_recall = self._acts_recall(acts_rec, x_acts)
        acts_f1 = (2 * acts_recall * acts_precision / 
                       (acts_recall + acts_precision))
        
        accs = {
            'notes': notes_acc.item(),
            'pitches': pitches_acc.item(),
            'pitches_drums': pitches_acc_drums.item(),
            'pitches_non_drums': pitches_acc_non_drums.item(),
            'dur': dur_acc.item(),
            'acts_acc': acts_acc.item(),
            'acts_precision': acts_precision.item(),
            'acts_recall': acts_recall.item(),
            'acts_f1': acts_f1.item()
        }
        
        return accs
    
    
    def _note_accuracy(self, seq_rec, x_seq):
        
        pitches_rec = F.softmax(seq_rec[..., :131], dim=-1)
        pitches_rec = torch.argmax(pitches_rec, dim=-1)
        pitches_true = torch.argmax(x_seq[..., :131], dim=-1)
        
        #print(torch.all(pitches_rec == 129))
        #print(pitches_rec)
        
        mask_p = (pitches_true != 130)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_pitches = (pitches_rec == pitches_true)
        preds_pitches = torch.logical_and(preds_pitches, mask_p)
        
        
        dur_rec = F.softmax(seq_rec[..., 131:], dim=-1)
        dur_rec = torch.argmax(dur_rec, dim=-1)
        dur_true = torch.argmax(x_seq[..., 131:], dim=-1)
        
        #print(torch.all(dur_rec == 97))
        
        mask_d = (dur_true != 98)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_dur = (dur_rec == dur_true)
        preds_dur = torch.logical_and(preds_dur, mask_d)
        
        return torch.sum(torch.logical_and(preds_pitches, 
                                           preds_dur)) / torch.sum(mask_p)
    
    
    def _acts_precision(self, acts_rec, x_acts):
        
        acts_rec = torch.sigmoid(acts_rec)
        acts_rec[acts_rec < 0.5] = 0
        acts_rec[acts_rec >= 0.5] = 1
        
        tp = torch.sum(x_acts[acts_rec == 1])
        
        return tp / torch.sum(acts_rec)
    
    
    def _acts_recall(self, acts_rec, x_acts):
        
        acts_rec = torch.sigmoid(acts_rec)
        acts_rec[acts_rec < 0.5] = 0
        acts_rec[acts_rec >= 0.5] = 1
        
        tp = torch.sum(x_acts[acts_rec == 1])
        
        return tp / torch.sum(x_acts)
    
    
    def _acts_accuracy(self, acts_rec, x_acts):
        
        acts_rec = torch.sigmoid(acts_rec)
        acts_rec[acts_rec < 0.5] = 0
        acts_rec[acts_rec >= 0.5] = 1
        
        #print("All zero acts?", torch.all(acts_rec == 0))
        #print("All one acts?", torch.all(acts_rec == 0))
        
        return torch.sum(acts_rec == x_acts) / x_acts.numel()
    
    
    def _pitches_accuracy(self, seq_rec, x_seq, is_drum=None, drums=None):
        
        if drums is not None:
            if drums:
                seq_rec = seq_rec[is_drum]
                x_seq = x_seq[is_drum]
            else:
                seq_rec = seq_rec[torch.logical_not(is_drum)]
                x_seq = x_seq[torch.logical_not(is_drum)]
        
        pitches_rec = F.softmax(seq_rec[..., :131], dim=-1)
        pitches_rec = torch.argmax(pitches_rec, dim=-1)
        pitches_true = torch.argmax(x_seq[..., :131], dim=-1)
        
        #print("All EOS pitches?", torch.all(pitches_rec == 129))
        
        mask = (pitches_true != 130)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_pitches = (pitches_rec == pitches_true)
        preds_pitches = torch.logical_and(preds_pitches, mask)
        
        return torch.sum(preds_pitches) / torch.sum(mask)
    
    
    def _dur_accuracy(self, seq_rec, x_seq):
        
        dur_rec = F.softmax(seq_rec[..., 131:], dim=-1)
        dur_rec = torch.argmax(dur_rec, dim=-1)
        dur_true = torch.argmax(x_seq[..., 131:], dim=-1)
        
        #print("All EOS durs?", torch.all(dur_rec == 97))
        
        mask = (dur_true != 98)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_dur = (dur_rec == dur_true)
        preds_dur = torch.logical_and(preds_dur, mask)
        
        return torch.sum(preds_dur) / torch.sum(mask)
    
    
    def _save_model(self, filename):
        path = os.path.join(self.model_dir, filename)
        torch.save({
            'epoch': self.cur_epoch,
            'batch': self.cur_batch_idx,
            'tot_batches': self.tot_batches,
            'betas': self.betas,
            'min_val_loss': self.min_val_loss,
            'print_every': self.print_every,
            'save_every': self.save_every,
            'eval_every': self.eval_every,
            'lrs': self.lrs,
            'tr_losses': self.tr_losses,
            'tr_accuracies': self.tr_accuracies,
            'val_losses': self.val_losses,
            'val_accuracies': self.val_accuracies,
            #'scheduler_state_dict': self.lr_scheduler.state_dict(), # Todo: fix
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, path)
        
    
    def load(self):
        
        checkpoint = torch.load(os.path.join(self.model_dir, 'checkpoint'))
        
        self.cur_epoch = checkpoint['epoch']
        self.cur_batch_idx = checkpoint['batch']
        self.save_every = checkpoint['save_every']
        self.eval_every = checkpoint['eval_every']
        self.lrs = checkpoint['lrs']
        self.tr_losses = checkpoint['tr_losses']
        self.tr_accuracies = checkpoint['tr_accuracies']
        self.val_losses = checkpoint['val_losses']
        self.val_accuracies = checkpoint['val_accuracies']
        self.times = checkpoint['times']
        self.min_val_loss = checkpoint['min_val_loss']
        self.beta = checkpoint['beta']
        self.tot_batches = checkpoint['tot_batches']
        
        
    def _print_stats(self):
        
        hours, rem = divmod(self.times[-1]-self.times[0], 3600)
        minutes, seconds = divmod(rem, 60)
        print("Elapsed time from start (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
                  .format(int(hours), int(minutes), seconds))
        
        avg_lr = mean(self.lrs[-self.print_every:])
        
        # Take mean of the last non-printed batches for each stat
        
        avg_losses = {}
        for k, l in self.tr_losses.items():
            avg_losses[k] = mean(l[-self.print_every:])
        
        avg_accs = {}
        for k, l in self.tr_accuracies.items():
            avg_accs[k] = mean(l[-self.print_every:])
        
        print("Losses:")
        print(avg_losses)
        print("Accuracies:")
        print(avg_accs)
        


Training

In [13]:
from torch.utils.data import Subset
from torch.utils.data import random_split

bs = 256
nw = 10
n_bars = 2

#bs = 32
#nw = 10
#n_bars = 16

ds_dir = "/data/cosenza/datasets/MMD/preprocessed_" + str(n_bars) + "bars_par/"

dataset = MIDIDataset(ds_dir, n_bars=n_bars)
ds_len = len(dataset)

print('Dataset len:', len(dataset))

train_len = int(0.7 * len(dataset)) 
valid_len = int(0.1 * len(dataset))
test_len = len(dataset) - train_len - valid_len
tr_set, vl_set, ts_set = random_split(dataset, (train_len, valid_len, test_len))

trainloader = DataLoader(tr_set, batch_size=bs, shuffle=True, num_workers=nw)
validloader = DataLoader(vl_set, batch_size=bs, shuffle=False, num_workers=nw)

tr_len = len(tr_set)
vl_len = len(vl_set)
ts_len = len(ts_set)

print('TR set len:', len(tr_set))
print('VL set len:', len(vl_set))
print('TS set len:', len(ts_set))

#n_samples = 128
#subset = Subset(dataset, np.arange(n_samples))
#loader = DataLoader(subset, batch_size=64, shuffle=True)

Dataset len: 27251322
TR set len: 19075925
VL set len: 2725132
TS set len: 5450265


In [14]:
tr_set[0]

Data(edge_index=[2, 142], edge_attrs=[142, 33], num_nodes=35, node_features=[35, 4], is_drum=[35], batch=[35], ptr=[3], bars=[35], x_seq=[35, 16, 230], x_acts=[2, 4, 32], src_mask=[35, 16])

In [15]:
import torch
torch.cuda.set_device(0)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#device = torch.device("cpu")
print("Device:", device)
print("Current device idx:", torch.cuda.current_device())

Device: cuda
Current device idx: 0


In [16]:
#!rm models/vae

In [17]:
from prettytable import PrettyTable


def print_params(model):
    
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    
    for name, parameter in model.named_parameters():
        
        if not parameter.requires_grad:
            continue
            
        param = parameter.numel()
        table.add_row([name, param])
        total_params += param
        
    print(table)
    print(f"Total Trainable Params: {total_params}")
    
    return total_params

In [18]:
import math
import torch
from typing import Optional
from torch.optim import Optimizer

from lr_scheduler.lr_scheduler import LearningRateScheduler

class TransformerLRScheduler(LearningRateScheduler):
    r"""
    Transformer Learning Rate Scheduler proposed in "Attention Is All You Need"
    Args:
        optimizer (Optimizer): Optimizer.
        peak_lr (float): Maximum learning rate.
        final_lr_scale (float): Final learning rate scale
        warmup_steps (int): Warmup the learning rate linearly for the first N updates
        decay_steps (int): Steps in decay stage
    """
    def __init__(
            self,
            optimizer: Optimizer,
            peak_lr: float,
            final_lr_scale: float,
            warmup_steps: int,
            decay_steps: int,
            init_lr = 0
    ) -> None:
        assert isinstance(warmup_steps, int), "warmup_steps should be integer type"
        assert isinstance(decay_steps, int), "total_steps should be integer type"

        super(TransformerLRScheduler, self).__init__(optimizer, init_lr)
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps

        self.warmup_rate = self.peak_lr / self.warmup_steps
        self.decay_factor = -math.log(final_lr_scale) / self.decay_steps

        self.update_steps = 0

    def _decide_stage(self):
        if self.update_steps < self.warmup_steps:
            return 0, self.update_steps

        if self.warmup_steps <= self.update_steps:
            return 1, self.update_steps - self.warmup_steps

        return 2, None

    def step(self, val_loss: Optional[torch.FloatTensor] = None):
        self.update_steps += 1
        stage, steps_in_stage = self._decide_stage()

        if stage == 0:
            #self.lr = self.update_steps * self.warmup_rate
            self.lr = self.peak_lr
        elif stage == 1:
            # gamma = final_lr_scale^(1/decay_steps)
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        else:
            raise ValueError("Undefined stage")

        self.set_lr(self.optimizer, self.lr)

        return self.lr

In [19]:
parameters = {
    
    'training': {
        'batch_size': bs,
        'num_workers': nw,
        'ds_len': ds_len,
        'tr_len': tr_len,
        'vl_len': vl_len,
        'ts_len': ts_len,
    },
    
    'model': {
        
        'dropout': 0, # was 0.1
        'batch_norm': True,
        
         
        'gnn_n_layers': 8, # was 3
        'actsnn_n_layers': 2,
        'd': 512,
        
        'rnn_n_layers': 1,
        'k_isgn': 3,
        
        'd_token': 230,
        'd_token_pitches': 131,
        'd_token_dur': 99,
        'n_bars': n_bars,
        'n_relations': 6,
        'n_tracks': 4,
        'resolution': 8,
        'max_simu_notes': 16
    },
    
    'scheduler': {
        'peak_lr': 1e-4, # 1e-4 2bars, 5e-5 16bars
        'final_lr_scale': 0.01,
        'warmup_steps': 8000,
        'decay_steps': 800000 # 500000 16bars
    },
    
    'optimizer': {
        'betas': (0.9, 0.98),
        'eps': 1e-09,
        'lr': 5e-6
    },
    
    'beta_annealing': {
        'beta_update': True,
        
        #'anneal_after': 4000,
        #'beta_max': 0.01,
        #'anneal_steps': 500000,
        #'inc_to_zero_ratio': 1,
        #'sig_scaled_point': 7,
        #'n_cycles': 4
        
        'anneal_start': 40000,
        'beta_max': 0.01,
        'step_size': 0.001,
        'anneal_end': 500000,
        #'inc_to_zero_ratio': 1,
        #'sig_scaled_point': 7
    }
}

Training from scratch:

In [None]:
import torch_geometric

print("Creating the model and moving it to the specified device...")

# Create model dir
models_dir = 'models/'
model_name = 'MMD16'
model_dir = os.path.join(models_dir, model_name)
os.makedirs(models_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=False)

# Creating the model
vae = VAE(**parameters['model'], device=device).to(device)
#vae = torch_geometric.nn.DataParallel(vae, device_ids=[0, 1, 2])
print_params(vae)
print()

# Creating optimizer and scheduler
optimizer = optim.Adam(vae.parameters(), **parameters['optimizer'])
scheduler = TransformerLRScheduler(
    optimizer=optimizer,
    **parameters['scheduler']
)

# Save parameters
params_path = os.path.join(model_dir, 'params')
torch.save(parameters, params_path)

print('--------------------------------------------------\n')

trainer = VAETrainer(
    model_dir,
    model=vae,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    save_every=100,
    print_every=1,
    eval_every=18631,
    iters_to_accumulate=1,
    device=device,
    **parameters['beta_annealing']
)
trainer.train(trainloader, validloader=None, epochs=100)

Creating the model and moving it to the specified device...
+---------------------------------------------------+------------+
|                      Modules                      | Parameters |
+---------------------------------------------------+------------+
|           encoder.notes_pitch_emb.weight          |   33536    |
|            encoder.notes_pitch_emb.bias           |    256     |
|               encoder.bn_npe.weight               |    256     |
|                encoder.bn_npe.bias                |    256     |
|           encoder.drums_pitch_emb.weight          |   33536    |
|            encoder.drums_pitch_emb.bias           |    256     |
|               encoder.bn_dpe.weight               |    256     |
|                encoder.bn_dpe.bias                |    256     |
|               encoder.dur_emb.weight              |   25344    |
|                encoder.dur_emb.bias               |    256     |
|                encoder.bn_de.weight               |    256     |
| 

  0%|          | 0/242302 [00:00<?, ?it/s]

Training on batch 1/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:09.06
Losses:
{'tot': 11.794047355651855, 'pitches': 5.367997646331787, 'dur': 5.738558769226074, 'acts': 0.6874915361404419, 'rec': 11.794047355651855, 'kld': 66.8394775390625, 'beta*kld': 0.0}
Accuracies:
{'notes': 2.9663027817150578e-05, 'pitches': 0.003945182543247938, 'pitches_drums': 0.001259544980712235, 'pitches_non_drums': 0.005569041706621647, 'dur': 0.0076530613005161285, 'acts_acc': 0.573455810546875, 'acts_precision': 0.21410371363162994, 'acts_recall': 0.45041778683662415, 'acts_f1': 0.2902422845363617}

----------------------------------------

Training on batch 2/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:10.16
Losses:
{'tot': 11.790508270263672, 'pitches': 5.3876214027404785, 'dur': 5.713562965393066, 'acts': 0.6893234252929688, 'rec': 11.790508270263672, 'kld': 67.15559387207031, 'beta*kld': 0.0}
Accuracies:
{'notes': 3.186540197930299e-05, 'pitches':

Training on batch 14/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:20.43
Losses:
{'tot': 6.187596321105957, 'pitches': 3.328681707382202, 'dur': 2.2161800861358643, 'acts': 0.6427344679832458, 'rec': 6.187596321105957, 'kld': 67.92916870117188, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.43362948298454285, 'pitches': 0.4365137219429016, 'pitches_drums': 0.4990752935409546, 'pitches_non_drums': 0.39874526858329773, 'dur': 0.6125845313072205, 'acts_acc': 0.672882080078125, 'acts_precision': 0.2384955734014511, 'acts_recall': 0.3599933087825775, 'acts_f1': 0.2869119644165039}

----------------------------------------

Training on batch 15/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:21.16
Losses:
{'tot': 6.724354267120361, 'pitches': 3.600785732269287, 'dur': 2.4881935119628906, 'acts': 0.6353751420974731, 'rec': 6.724354267120361, 'kld': 66.74177551269531, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37824317812919617, 'pitches': 0.382044941186

Training on batch 27/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:30.19
Losses:
{'tot': 5.63378381729126, 'pitches': 3.0901129245758057, 'dur': 1.9402254819869995, 'acts': 0.6034456491470337, 'rec': 5.63378381729126, 'kld': 77.23260498046875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.32133838534355164, 'pitches': 0.39968767762184143, 'pitches_drums': 0.4314066767692566, 'pitches_non_drums': 0.38010531663894653, 'dur': 0.5471159219741821, 'acts_acc': 0.7718353271484375, 'acts_precision': 0.31190571188926697, 'acts_recall': 0.28239619731903076, 'acts_f1': 0.29641830921173096}

----------------------------------------

Training on batch 28/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:31.00
Losses:
{'tot': 5.380305767059326, 'pitches': 3.101240634918213, 'dur': 1.6711406707763672, 'acts': 0.6079244613647461, 'rec': 5.380305767059326, 'kld': 74.31295013427734, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3437096178531647, 'pitches': 0.406282484

Training on batch 40/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:40.25
Losses:
{'tot': 5.210914611816406, 'pitches': 3.022393226623535, 'dur': 1.6054575443267822, 'acts': 0.583064079284668, 'rec': 5.210914611816406, 'kld': 75.76699829101562, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.35792839527130127, 'pitches': 0.39786747097969055, 'pitches_drums': 0.39920949935913086, 'pitches_non_drums': 0.3967825472354889, 'dur': 0.5833054184913635, 'acts_acc': 0.808868408203125, 'acts_precision': 0.4232954680919647, 'acts_recall': 0.20442461967468262, 'acts_f1': 0.2757025361061096}

----------------------------------------

Training on batch 41/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:40.94
Losses:
{'tot': 4.956630229949951, 'pitches': 2.8287503719329834, 'dur': 1.5522159337997437, 'acts': 0.5756641626358032, 'rec': 4.956630229949951, 'kld': 74.87813568115234, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3670280873775482, 'pitches': 0.42069143056

Training on batch 53/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:49.83
Losses:
{'tot': 4.848611354827881, 'pitches': 2.7593846321105957, 'dur': 1.524255633354187, 'acts': 0.5649709701538086, 'rec': 4.848611354827881, 'kld': 75.85971069335938, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.38456350564956665, 'pitches': 0.42216068506240845, 'pitches_drums': 0.4537188708782196, 'pitches_non_drums': 0.4019453525543213, 'dur': 0.602336049079895, 'acts_acc': 0.835174560546875, 'acts_precision': 0.48144328594207764, 'acts_recall': 0.17525096237659454, 'acts_f1': 0.2569640278816223}

----------------------------------------

Training on batch 54/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:50.62
Losses:
{'tot': 4.87716007232666, 'pitches': 2.745396375656128, 'dur': 1.5661741495132446, 'acts': 0.5655896663665771, 'rec': 4.87716007232666, 'kld': 72.08973693847656, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3741186857223511, 'pitches': 0.41066741943359

Training on batch 66/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:59.84
Losses:
{'tot': 4.44960880279541, 'pitches': 2.533803939819336, 'dur': 1.3580811023712158, 'acts': 0.5577236413955688, 'rec': 4.44960880279541, 'kld': 76.64447021484375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.38884925842285156, 'pitches': 0.4410989582538605, 'pitches_drums': 0.479467511177063, 'pitches_non_drums': 0.4141824543476105, 'dur': 0.6264380216598511, 'acts_acc': 0.8286590576171875, 'acts_precision': 0.6864931583404541, 'acts_recall': 0.17822100222110748, 'acts_f1': 0.2829779088497162}

----------------------------------------

Training on batch 67/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:00.55
Losses:
{'tot': 4.590671539306641, 'pitches': 2.64908504486084, 'dur': 1.3951010704040527, 'acts': 0.5464855432510376, 'rec': 4.590671539306641, 'kld': 78.12358093261719, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37791159749031067, 'pitches': 0.424965739250183

Training on batch 79/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:09.65
Losses:
{'tot': 4.731412410736084, 'pitches': 2.654513120651245, 'dur': 1.5311245918273926, 'acts': 0.5457749962806702, 'rec': 4.731412410736084, 'kld': 80.70964050292969, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.35188913345336914, 'pitches': 0.4127041697502136, 'pitches_drums': 0.45563796162605286, 'pitches_non_drums': 0.3894372582435608, 'dur': 0.5816531777381897, 'acts_acc': 0.84393310546875, 'acts_precision': 0.7057845592498779, 'acts_recall': 0.18515610694885254, 'acts_f1': 0.2933535873889923}

----------------------------------------

Training on batch 80/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:10.50
Losses:
{'tot': 4.46766471862793, 'pitches': 2.574687957763672, 'dur': 1.341643214225769, 'acts': 0.5513336658477783, 'rec': 4.46766471862793, 'kld': 79.855224609375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3723582923412323, 'pitches': 0.4240924119949341, 

Training on batch 92/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:19.82
Losses:
{'tot': 4.314304828643799, 'pitches': 2.52597713470459, 'dur': 1.2499347925186157, 'acts': 0.5383930206298828, 'rec': 4.314304828643799, 'kld': 84.46043395996094, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3894313871860504, 'pitches': 0.4362627863883972, 'pitches_drums': 0.45673882961273193, 'pitches_non_drums': 0.4216929078102112, 'dur': 0.6466720104217529, 'acts_acc': 0.8392486572265625, 'acts_precision': 0.6624921560287476, 'acts_recall': 0.18283936381340027, 'acts_f1': 0.2865849435329437}

----------------------------------------

Training on batch 93/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:20.57
Losses:
{'tot': 4.555361270904541, 'pitches': 2.7074015140533447, 'dur': 1.3110780715942383, 'acts': 0.5368820428848267, 'rec': 4.555361270904541, 'kld': 84.8551254272461, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37055885791778564, 'pitches': 0.399758577346

Training on batch 105/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:31.95
Losses:
{'tot': 4.603834629058838, 'pitches': 2.6855297088623047, 'dur': 1.3764151334762573, 'acts': 0.5418896675109863, 'rec': 4.603834629058838, 'kld': 87.55838012695312, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.36901649832725525, 'pitches': 0.4042353630065918, 'pitches_drums': 0.43932226300239563, 'pitches_non_drums': 0.38058528304100037, 'dur': 0.60957932472229, 'acts_acc': 0.824188232421875, 'acts_precision': 0.6453018188476562, 'acts_recall': 0.1665060967206955, 'acts_f1': 0.26470962166786194}

----------------------------------------

Training on batch 106/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:32.74
Losses:
{'tot': 4.415637493133545, 'pitches': 2.5784032344818115, 'dur': 1.2995980978012085, 'acts': 0.5376364588737488, 'rec': 4.415637493133545, 'kld': 86.80528259277344, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3751640319824219, 'pitches': 0.412823885

Training on batch 118/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:41.96
Losses:
{'tot': 4.3781418800354, 'pitches': 2.5159754753112793, 'dur': 1.3237794637680054, 'acts': 0.5383867025375366, 'rec': 4.3781418800354, 'kld': 90.45896911621094, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37557917833328247, 'pitches': 0.42179518938064575, 'pitches_drums': 0.4641196131706238, 'pitches_non_drums': 0.3982337713241577, 'dur': 0.6223713755607605, 'acts_acc': 0.8221435546875, 'acts_precision': 0.6631812453269958, 'acts_recall': 0.16497202217578888, 'acts_f1': 0.26421764492988586}

----------------------------------------

Training on batch 119/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:42.77
Losses:
{'tot': 4.3431572914123535, 'pitches': 2.5472359657287598, 'dur': 1.2632368803024292, 'acts': 0.5326844453811646, 'rec': 4.3431572914123535, 'kld': 90.69886779785156, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37326672673225403, 'pitches': 0.4106163680

Training on batch 131/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:52.24
Losses:
{'tot': 4.249125003814697, 'pitches': 2.4804680347442627, 'dur': 1.2419025897979736, 'acts': 0.5267542600631714, 'rec': 4.249125003814697, 'kld': 91.9327163696289, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3896908462047577, 'pitches': 0.4217754602432251, 'pitches_drums': 0.4608994126319885, 'pitches_non_drums': 0.39278629422187805, 'dur': 0.6273742318153381, 'acts_acc': 0.8334503173828125, 'acts_precision': 0.7192118167877197, 'acts_recall': 0.18931841850280762, 'acts_f1': 0.29973694682121277}

----------------------------------------

Training on batch 132/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:53.04
Losses:
{'tot': 4.248368263244629, 'pitches': 2.4740004539489746, 'dur': 1.2512954473495483, 'acts': 0.5230720043182373, 'rec': 4.248368263244629, 'kld': 95.60496520996094, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.373824805021286, 'pitches': 0.417232930

Training on batch 144/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:02.19
Losses:
{'tot': 4.243363857269287, 'pitches': 2.4247536659240723, 'dur': 1.3118252754211426, 'acts': 0.5067849159240723, 'rec': 4.243363857269287, 'kld': 97.51380920410156, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.39807236194610596, 'pitches': 0.44137343764305115, 'pitches_drums': 0.5069432258605957, 'pitches_non_drums': 0.39707890152931213, 'dur': 0.6377520561218262, 'acts_acc': 0.85870361328125, 'acts_precision': 0.7392328381538391, 'acts_recall': 0.2056538462638855, 'acts_f1': 0.32178688049316406}

----------------------------------------

Training on batch 145/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:02.96
Losses:
{'tot': 4.453500270843506, 'pitches': 2.561290740966797, 'dur': 1.3752692937850952, 'acts': 0.516940176486969, 'rec': 4.453500270843506, 'kld': 95.82859802246094, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.38380900025367737, 'pitches': 0.410095125

Training on batch 157/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:12.26
Losses:
{'tot': 4.167882919311523, 'pitches': 2.394155502319336, 'dur': 1.2623701095581055, 'acts': 0.5113575458526611, 'rec': 4.167882919311523, 'kld': 98.11409759521484, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.38439664244651794, 'pitches': 0.4386520981788635, 'pitches_drums': 0.48647505044937134, 'pitches_non_drums': 0.40850475430488586, 'dur': 0.6485815048217773, 'acts_acc': 0.8468780517578125, 'acts_precision': 0.780868411064148, 'acts_recall': 0.19688676297664642, 'acts_f1': 0.31448087096214294}

----------------------------------------

Training on batch 158/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:13.03
Losses:
{'tot': 4.222934722900391, 'pitches': 2.4447827339172363, 'dur': 1.2678546905517578, 'acts': 0.5102970600128174, 'rec': 4.222934722900391, 'kld': 98.77474975585938, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.38327455520629883, 'pitches': 0.436473

Training on batch 170/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:22.33
Losses:
{'tot': 4.22914457321167, 'pitches': 2.5147950649261475, 'dur': 1.205750584602356, 'acts': 0.508598804473877, 'rec': 4.22914457321167, 'kld': 102.0362319946289, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3682056963443756, 'pitches': 0.4151398539543152, 'pitches_drums': 0.46023011207580566, 'pitches_non_drums': 0.3890816271305084, 'dur': 0.6576279401779175, 'acts_acc': 0.845123291015625, 'acts_precision': 0.8029172420501709, 'acts_recall': 0.1983242630958557, 'acts_f1': 0.3180810511112213}

----------------------------------------

Training on batch 171/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:23.03
Losses:
{'tot': 4.317383766174316, 'pitches': 2.4652628898620605, 'dur': 1.3555631637573242, 'acts': 0.4965576231479645, 'rec': 4.317383766174316, 'kld': 102.10542297363281, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3745667040348053, 'pitches': 0.4211921393871

Training on batch 183/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:32.28
Losses:
{'tot': 4.089475631713867, 'pitches': 2.424278497695923, 'dur': 1.1617461442947388, 'acts': 0.5034510493278503, 'rec': 4.089475631713867, 'kld': 107.94913482666016, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37917080521583557, 'pitches': 0.44058141112327576, 'pitches_drums': 0.4709521532058716, 'pitches_non_drums': 0.416938841342926, 'dur': 0.6519858241081238, 'acts_acc': 0.8423614501953125, 'acts_precision': 0.7887700796127319, 'acts_recall': 0.19568823277950287, 'acts_f1': 0.31357958912849426}

----------------------------------------

Training on batch 184/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:33.18
Losses:
{'tot': 4.240344047546387, 'pitches': 2.519108533859253, 'dur': 1.2087699174880981, 'acts': 0.5124655961990356, 'rec': 4.240344047546387, 'kld': 108.68437194824219, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3636079430580139, 'pitches': 0.4228878

Training on batch 196/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:42.59
Losses:
{'tot': 4.071811199188232, 'pitches': 2.339733362197876, 'dur': 1.2357778549194336, 'acts': 0.4962998032569885, 'rec': 4.071811199188232, 'kld': 109.28887176513672, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37720292806625366, 'pitches': 0.42544686794281006, 'pitches_drums': 0.46128538250923157, 'pitches_non_drums': 0.4004063010215759, 'dur': 0.6343781352043152, 'acts_acc': 0.85150146484375, 'acts_precision': 0.7732188105583191, 'acts_recall': 0.21233294904232025, 'acts_f1': 0.3331734240055084}

----------------------------------------

Training on batch 197/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:43.34
Losses:
{'tot': 3.9209794998168945, 'pitches': 2.2161991596221924, 'dur': 1.2131420373916626, 'acts': 0.4916382431983948, 'rec': 3.9209794998168945, 'kld': 109.23309326171875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.426115483045578, 'pitches': 0.470449

Training on batch 209/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:56.55
Losses:
{'tot': 4.0305023193359375, 'pitches': 2.390475034713745, 'dur': 1.1433145999908447, 'acts': 0.49671244621276855, 'rec': 4.0305023193359375, 'kld': 110.73977661132812, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3883543908596039, 'pitches': 0.43259134888648987, 'pitches_drums': 0.4648107588291168, 'pitches_non_drums': 0.41092649102211, 'dur': 0.6561340689659119, 'acts_acc': 0.8471832275390625, 'acts_precision': 0.7528562545776367, 'acts_recall': 0.21407198905944824, 'acts_f1': 0.33335551619529724}

----------------------------------------

Training on batch 210/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:57.28
Losses:
{'tot': 4.017096042633057, 'pitches': 2.320265054702759, 'dur': 1.209691047668457, 'acts': 0.48714008927345276, 'rec': 4.017096042633057, 'kld': 111.80357360839844, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41015681624412537, 'pitches': 0.45546

Training on batch 222/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:06.56
Losses:
{'tot': 4.122218132019043, 'pitches': 2.4913392066955566, 'dur': 1.1310157775878906, 'acts': 0.4998631179332733, 'rec': 4.122218132019043, 'kld': 117.43737030029297, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3752436637878418, 'pitches': 0.4169407784938812, 'pitches_drums': 0.43119195103645325, 'pitches_non_drums': 0.40782177448272705, 'dur': 0.6705348491668701, 'acts_acc': 0.840362548828125, 'acts_precision': 0.7228506803512573, 'acts_recall': 0.21234527230262756, 'acts_f1': 0.32826048135757446}

----------------------------------------

Training on batch 223/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:07.25
Losses:
{'tot': 4.219362735748291, 'pitches': 2.5025596618652344, 'dur': 1.2379165887832642, 'acts': 0.4788867235183716, 'rec': 4.219362735748291, 'kld': 115.29954528808594, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3760586380958557, 'pitches': 0.41840

Training on batch 235/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:16.48
Losses:
{'tot': 3.9602012634277344, 'pitches': 2.3314285278320312, 'dur': 1.1393814086914062, 'acts': 0.48939135670661926, 'rec': 3.9602012634277344, 'kld': 121.59980010986328, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4228873550891876, 'pitches': 0.45529913902282715, 'pitches_drums': 0.5072494745254517, 'pitches_non_drums': 0.4115773141384125, 'dur': 0.6707901954650879, 'acts_acc': 0.846466064453125, 'acts_precision': 0.7186411023139954, 'acts_recall': 0.21395228803157806, 'acts_f1': 0.32973620295524597}

----------------------------------------

Training on batch 236/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:17.26
Losses:
{'tot': 3.975447177886963, 'pitches': 2.351484775543213, 'dur': 1.1423887014389038, 'acts': 0.48157355189323425, 'rec': 3.975447177886963, 'kld': 120.26751708984375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3947693705558777, 'pitches': 0.431

Training on batch 248/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:26.69
Losses:
{'tot': 3.9601008892059326, 'pitches': 2.399667501449585, 'dur': 1.0783205032348633, 'acts': 0.4821127951145172, 'rec': 3.9601008892059326, 'kld': 125.782958984375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3925654888153076, 'pitches': 0.43587902188301086, 'pitches_drums': 0.4889172911643982, 'pitches_non_drums': 0.40293368697166443, 'dur': 0.6726007461547852, 'acts_acc': 0.8555755615234375, 'acts_precision': 0.7925228476524353, 'acts_recall': 0.2411637008190155, 'acts_f1': 0.3697982728481293}

----------------------------------------

Training on batch 249/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:27.52
Losses:
{'tot': 4.002651691436768, 'pitches': 2.3537709712982178, 'dur': 1.1540967226028442, 'acts': 0.4947836697101593, 'rec': 4.002651691436768, 'kld': 125.6275863647461, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3747212588787079, 'pitches': 0.42050960

Training on batch 261/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:36.78
Losses:
{'tot': 4.049405097961426, 'pitches': 2.3421740531921387, 'dur': 1.236528754234314, 'acts': 0.47070208191871643, 'rec': 4.049405097961426, 'kld': 126.07728576660156, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3912124037742615, 'pitches': 0.42338132858276367, 'pitches_drums': 0.46018022298812866, 'pitches_non_drums': 0.4013204574584961, 'dur': 0.6311318874359131, 'acts_acc': 0.86700439453125, 'acts_precision': 0.8295924663543701, 'acts_recall': 0.2565765380859375, 'acts_f1': 0.391935259103775}

----------------------------------------

Training on batch 262/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:37.58
Losses:
{'tot': 4.215815544128418, 'pitches': 2.509296178817749, 'dur': 1.2227437496185303, 'acts': 0.48377537727355957, 'rec': 4.215815544128418, 'kld': 128.83047485351562, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37686216831207275, 'pitches': 0.41457334

Training on batch 274/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:46.90
Losses:
{'tot': 3.916898250579834, 'pitches': 2.342641830444336, 'dur': 1.0926934480667114, 'acts': 0.48156318068504333, 'rec': 3.916898250579834, 'kld': 129.9416046142578, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3951004147529602, 'pitches': 0.4431367516517639, 'pitches_drums': 0.4981667101383209, 'pitches_non_drums': 0.4075809419155121, 'dur': 0.6731802225112915, 'acts_acc': 0.84686279296875, 'acts_precision': 0.7808098793029785, 'acts_recall': 0.2226591855287552, 'acts_f1': 0.34650689363479614}

----------------------------------------

Training on batch 275/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:47.80
Losses:
{'tot': 3.978350877761841, 'pitches': 2.310793876647949, 'dur': 1.2003308534622192, 'acts': 0.46722620725631714, 'rec': 3.978350877761841, 'kld': 131.35279846191406, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3944099247455597, 'pitches': 0.4370113313

Training on batch 287/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:56.83
Losses:
{'tot': 3.7538304328918457, 'pitches': 2.2072696685791016, 'dur': 1.0685548782348633, 'acts': 0.4780058264732361, 'rec': 3.7538304328918457, 'kld': 134.17079162597656, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4422599673271179, 'pitches': 0.47980791330337524, 'pitches_drums': 0.556786060333252, 'pitches_non_drums': 0.42360877990722656, 'dur': 0.697070300579071, 'acts_acc': 0.84600830078125, 'acts_precision': 0.7903422713279724, 'acts_recall': 0.21562577784061432, 'acts_f1': 0.33881425857543945}

----------------------------------------

Training on batch 288/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:57.58
Losses:
{'tot': 3.837916374206543, 'pitches': 2.2714147567749023, 'dur': 1.096052885055542, 'acts': 0.47044867277145386, 'rec': 3.837916374206543, 'kld': 139.96005249023438, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4050202965736389, 'pitches': 0.448964

Training on batch 300/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:06.80
Losses:
{'tot': 3.980302572250366, 'pitches': 2.2542965412139893, 'dur': 1.2605749368667603, 'acts': 0.4654310345649719, 'rec': 3.980302572250366, 'kld': 138.36770629882812, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42741429805755615, 'pitches': 0.46502941846847534, 'pitches_drums': 0.5684706568717957, 'pitches_non_drums': 0.40663155913352966, 'dur': 0.6429337859153748, 'acts_acc': 0.8594818115234375, 'acts_precision': 0.7108287811279297, 'acts_recall': 0.2216353863477707, 'acts_f1': 0.3379106819629669}

----------------------------------------


Saving model to disk...

Training on batch 301/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:11.13
Losses:
{'tot': 4.097280979156494, 'pitches': 2.387305974960327, 'dur': 1.2358455657958984, 'acts': 0.4741295576095581, 'rec': 4.097280979156494, 'kld': 141.07862854003906, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.39950355887

Training on batch 313/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:20.24
Losses:
{'tot': 4.04277229309082, 'pitches': 2.450697183609009, 'dur': 1.1155294179916382, 'acts': 0.47654592990875244, 'rec': 4.04277229309082, 'kld': 147.1757049560547, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.386711984872818, 'pitches': 0.4192187488079071, 'pitches_drums': 0.45292550325393677, 'pitches_non_drums': 0.39351630210876465, 'dur': 0.6640361547470093, 'acts_acc': 0.844451904296875, 'acts_precision': 0.7541487216949463, 'acts_recall': 0.2071235626935959, 'acts_f1': 0.32499006390571594}

----------------------------------------

Training on batch 314/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:21.02
Losses:
{'tot': 4.065858364105225, 'pitches': 2.4277877807617188, 'dur': 1.165560007095337, 'acts': 0.472510427236557, 'rec': 4.065858364105225, 'kld': 147.13218688964844, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.386434942483902, 'pitches': 0.4180243909358

Training on batch 326/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:30.24
Losses:
{'tot': 3.9504170417785645, 'pitches': 2.3822004795074463, 'dur': 1.0960769653320312, 'acts': 0.4721395671367645, 'rec': 3.9504170417785645, 'kld': 147.9156951904297, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.39218655228614807, 'pitches': 0.43917447328567505, 'pitches_drums': 0.48675256967544556, 'pitches_non_drums': 0.40821003913879395, 'dur': 0.6762914657592773, 'acts_acc': 0.8455810546875, 'acts_precision': 0.7734184861183167, 'acts_recall': 0.21339263021945953, 'acts_f1': 0.3344952464103699}

----------------------------------------

Training on batch 327/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:30.98
Losses:
{'tot': 3.892024278640747, 'pitches': 2.2670414447784424, 'dur': 1.1588149070739746, 'acts': 0.4661678671836853, 'rec': 3.892024278640747, 'kld': 150.1416015625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4045565128326416, 'pitches': 0.454784661

Training on batch 339/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:40.26
Losses:
{'tot': 4.008444786071777, 'pitches': 2.3835642337799072, 'dur': 1.1431407928466797, 'acts': 0.48173975944519043, 'rec': 4.008444786071777, 'kld': 148.4400177001953, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.39912906289100647, 'pitches': 0.4400295913219452, 'pitches_drums': 0.4947222173213959, 'pitches_non_drums': 0.40204495191574097, 'dur': 0.6609267592430115, 'acts_acc': 0.8299560546875, 'acts_precision': 0.7939682602882385, 'acts_recall': 0.1924438327550888, 'acts_f1': 0.30979809165000916}

----------------------------------------

Training on batch 340/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:40.94
Losses:
{'tot': 3.939293384552002, 'pitches': 2.3289711475372314, 'dur': 1.1469292640686035, 'acts': 0.463392972946167, 'rec': 3.939293384552002, 'kld': 150.2874298095703, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41226378083229065, 'pitches': 0.445668071

Training on batch 352/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:50.15
Losses:
{'tot': 4.071332931518555, 'pitches': 2.38334059715271, 'dur': 1.223412036895752, 'acts': 0.46458038687705994, 'rec': 4.071332931518555, 'kld': 153.62136840820312, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4086180329322815, 'pitches': 0.44133639335632324, 'pitches_drums': 0.5010564923286438, 'pitches_non_drums': 0.3960314095020294, 'dur': 0.6516327857971191, 'acts_acc': 0.847900390625, 'acts_precision': 0.7533053159713745, 'acts_recall': 0.21513773500919342, 'acts_f1': 0.3346906304359436}

----------------------------------------

Training on batch 353/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:50.88
Losses:
{'tot': 4.0890703201293945, 'pitches': 2.3583571910858154, 'dur': 1.2764501571655273, 'acts': 0.45426276326179504, 'rec': 4.0890703201293945, 'kld': 155.17425537109375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.39510124921798706, 'pitches': 0.43786641

Training on batch 365/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:59.99
Losses:
{'tot': 3.7773923873901367, 'pitches': 2.257617712020874, 'dur': 1.0644617080688477, 'acts': 0.4553130269050598, 'rec': 3.7773923873901367, 'kld': 156.68814086914062, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4255639612674713, 'pitches': 0.4657607674598694, 'pitches_drums': 0.5490601658821106, 'pitches_non_drums': 0.41638681292533875, 'dur': 0.6962983012199402, 'acts_acc': 0.854034423828125, 'acts_precision': 0.837193489074707, 'acts_recall': 0.21286913752555847, 'acts_f1': 0.33943241834640503}

----------------------------------------

Training on batch 366/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:00.71
Losses:
{'tot': 3.869351387023926, 'pitches': 2.2682504653930664, 'dur': 1.141262173652649, 'acts': 0.45983898639678955, 'rec': 3.869351387023926, 'kld': 159.92330932617188, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40480029582977295, 'pitches': 0.44243

Training on batch 378/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:10.22
Losses:
{'tot': 3.823883295059204, 'pitches': 2.2108705043792725, 'dur': 1.1545393466949463, 'acts': 0.4584733843803406, 'rec': 3.823883295059204, 'kld': 167.79611206054688, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.411323606967926, 'pitches': 0.4481867849826813, 'pitches_drums': 0.49712297320365906, 'pitches_non_drums': 0.4091835618019104, 'dur': 0.665584921836853, 'acts_acc': 0.84588623046875, 'acts_precision': 0.7503276467323303, 'acts_recall': 0.1969214826822281, 'acts_f1': 0.3119678497314453}

----------------------------------------

Training on batch 379/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:10.94
Losses:
{'tot': 3.82848858833313, 'pitches': 2.2321720123291016, 'dur': 1.1365231275558472, 'acts': 0.4597933888435364, 'rec': 3.82848858833313, 'kld': 168.44833374023438, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4224129021167755, 'pitches': 0.4607212245464

Training on batch 391/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:20.01
Losses:
{'tot': 4.061030387878418, 'pitches': 2.396545886993408, 'dur': 1.2062233686447144, 'acts': 0.4582611322402954, 'rec': 4.061030387878418, 'kld': 170.50765991210938, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3964586555957794, 'pitches': 0.42480796575546265, 'pitches_drums': 0.4951235353946686, 'pitches_non_drums': 0.3817928731441498, 'dur': 0.6437363028526306, 'acts_acc': 0.845916748046875, 'acts_precision': 0.7663487792015076, 'acts_recall': 0.19291777908802032, 'acts_f1': 0.30824029445648193}

----------------------------------------

Training on batch 392/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:20.81
Losses:
{'tot': 3.9113306999206543, 'pitches': 2.3792614936828613, 'dur': 1.0739141702651978, 'acts': 0.4581551253795624, 'rec': 3.9113306999206543, 'kld': 160.79345703125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4002583920955658, 'pitches': 0.44324934

Training on batch 404/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:34.43
Losses:
{'tot': 3.9774837493896484, 'pitches': 2.4195542335510254, 'dur': 1.0994341373443604, 'acts': 0.4584953188896179, 'rec': 3.9774837493896484, 'kld': 172.54971313476562, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3889915347099304, 'pitches': 0.42855411767959595, 'pitches_drums': 0.45312824845314026, 'pitches_non_drums': 0.4096127152442932, 'dur': 0.6754415035247803, 'acts_acc': 0.84088134765625, 'acts_precision': 0.8020378351211548, 'acts_recall': 0.18232958018779755, 'acts_f1': 0.29711511731147766}

----------------------------------------

Training on batch 405/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:35.22
Losses:
{'tot': 3.7878499031066895, 'pitches': 2.2461910247802734, 'dur': 1.0992916822433472, 'acts': 0.4423670768737793, 'rec': 3.7878499031066895, 'kld': 174.30728149414062, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41100090742111206, 'pitches': 0.4

Training on batch 417/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:44.76
Losses:
{'tot': 4.011955738067627, 'pitches': 2.364499807357788, 'dur': 1.2049444913864136, 'acts': 0.44251152873039246, 'rec': 4.011955738067627, 'kld': 177.2836151123047, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.406100332736969, 'pitches': 0.44088175892829895, 'pitches_drums': 0.5378972887992859, 'pitches_non_drums': 0.3892589509487152, 'dur': 0.6414523720741272, 'acts_acc': 0.856689453125, 'acts_precision': 0.7446730136871338, 'acts_recall': 0.18901529908180237, 'acts_f1': 0.3015023171901703}

----------------------------------------

Training on batch 418/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:45.56
Losses:
{'tot': 3.711568832397461, 'pitches': 2.1659083366394043, 'dur': 1.0958828926086426, 'acts': 0.4497776925563812, 'rec': 3.711568832397461, 'kld': 181.31350708007812, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4349132180213928, 'pitches': 0.484230369329

Training on batch 430/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:54.82
Losses:
{'tot': 3.821821928024292, 'pitches': 2.2258248329162598, 'dur': 1.1637189388275146, 'acts': 0.43227821588516235, 'rec': 3.821821928024292, 'kld': 187.01904296875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.411068856716156, 'pitches': 0.457083523273468, 'pitches_drums': 0.5521238446235657, 'pitches_non_drums': 0.38678261637687683, 'dur': 0.6580058932304382, 'acts_acc': 0.86590576171875, 'acts_precision': 0.8098859190940857, 'acts_recall': 0.20445382595062256, 'acts_f1': 0.32648685574531555}

----------------------------------------

Training on batch 431/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:55.65
Losses:
{'tot': 3.8123903274536133, 'pitches': 2.275324583053589, 'dur': 1.080902099609375, 'acts': 0.4561637341976166, 'rec': 3.8123903274536133, 'kld': 184.50888061523438, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.39846351742744446, 'pitches': 0.4378008246

Training on batch 443/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:05.13
Losses:
{'tot': 3.8169002532958984, 'pitches': 2.2617337703704834, 'dur': 1.1047717332839966, 'acts': 0.4503946304321289, 'rec': 3.8169002532958984, 'kld': 186.13345336914062, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41308653354644775, 'pitches': 0.4586881995201111, 'pitches_drums': 0.5209165811538696, 'pitches_non_drums': 0.41418716311454773, 'dur': 0.668311357498169, 'acts_acc': 0.843048095703125, 'acts_precision': 0.8024193644523621, 'acts_recall': 0.1833794116973877, 'acts_f1': 0.29853394627571106}

----------------------------------------

Training on batch 444/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:05.94
Losses:
{'tot': 3.6928043365478516, 'pitches': 2.1469948291778564, 'dur': 1.1022652387619019, 'acts': 0.4435443580150604, 'rec': 3.6928043365478516, 'kld': 191.62306213378906, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40996241569519043, 'pitches': 0.45

Training on batch 456/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:15.42
Losses:
{'tot': 3.659226417541504, 'pitches': 2.2302167415618896, 'dur': 0.9798403978347778, 'acts': 0.4491690397262573, 'rec': 3.659226417541504, 'kld': 193.8253173828125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.43640828132629395, 'pitches': 0.46649467945098877, 'pitches_drums': 0.5090743899345398, 'pitches_non_drums': 0.4323173463344574, 'dur': 0.7129359245300293, 'acts_acc': 0.8412322998046875, 'acts_precision': 0.8157311677932739, 'acts_recall': 0.18232089281082153, 'acts_f1': 0.29803022742271423}

----------------------------------------

Training on batch 457/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:16.21
Losses:
{'tot': 3.863088846206665, 'pitches': 2.334545135498047, 'dur': 1.077065110206604, 'acts': 0.4514787495136261, 'rec': 3.863088846206665, 'kld': 189.446533203125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41781753301620483, 'pitches': 0.45394933

Training on batch 469/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:25.38
Losses:
{'tot': 3.7755372524261475, 'pitches': 2.2401936054229736, 'dur': 1.0928871631622314, 'acts': 0.4424564838409424, 'rec': 3.7755372524261475, 'kld': 192.5618896484375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4131942093372345, 'pitches': 0.4530647099018097, 'pitches_drums': 0.5214827060699463, 'pitches_non_drums': 0.4096449911594391, 'dur': 0.689801812171936, 'acts_acc': 0.84857177734375, 'acts_precision': 0.8100833892822266, 'acts_recall': 0.18486158549785614, 'acts_f1': 0.301028311252594}

----------------------------------------

Training on batch 470/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:26.17
Losses:
{'tot': 3.6897780895233154, 'pitches': 2.2493317127227783, 'dur': 0.9904413819313049, 'acts': 0.45000505447387695, 'rec': 3.6897780895233154, 'kld': 197.312744140625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40133607387542725, 'pitches': 0.44732785

Training on batch 482/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:35.31
Losses:
{'tot': 3.93754243850708, 'pitches': 2.3157219886779785, 'dur': 1.1915465593338013, 'acts': 0.4302741289138794, 'rec': 3.93754243850708, 'kld': 202.25880432128906, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.396347314119339, 'pitches': 0.44411322474479675, 'pitches_drums': 0.5460848808288574, 'pitches_non_drums': 0.3840583562850952, 'dur': 0.647889256477356, 'acts_acc': 0.8584136962890625, 'acts_precision': 0.7951388955116272, 'acts_recall': 0.19067443907260895, 'acts_f1': 0.30758899450302124}

----------------------------------------

Training on batch 483/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:36.06
Losses:
{'tot': 3.917802572250366, 'pitches': 2.3177976608276367, 'dur': 1.1590182781219482, 'acts': 0.440986692905426, 'rec': 3.917802572250366, 'kld': 199.71241760253906, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3885962963104248, 'pitches': 0.4411004483

Training on batch 495/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:45.38
Losses:
{'tot': 3.8107171058654785, 'pitches': 2.2568039894104004, 'dur': 1.123022437095642, 'acts': 0.43089067935943604, 'rec': 3.8107171058654785, 'kld': 206.47314453125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41902902722358704, 'pitches': 0.4533701539039612, 'pitches_drums': 0.5228166580200195, 'pitches_non_drums': 0.41290563344955444, 'dur': 0.6682378053665161, 'acts_acc': 0.8528900146484375, 'acts_precision': 0.8252730369567871, 'acts_recall': 0.1871076077222824, 'acts_f1': 0.3050529956817627}

----------------------------------------

Training on batch 496/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:46.17
Losses:
{'tot': 3.865614652633667, 'pitches': 2.328871726989746, 'dur': 1.0857617855072021, 'acts': 0.450981080532074, 'rec': 3.865614652633667, 'kld': 210.47579956054688, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40080708265304565, 'pitches': 0.43525606

Training on batch 508/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:58.76
Losses:
{'tot': 3.814249038696289, 'pitches': 2.2793757915496826, 'dur': 1.103191614151001, 'acts': 0.43168169260025024, 'rec': 3.814249038696289, 'kld': 211.28329467773438, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40233123302459717, 'pitches': 0.44027844071388245, 'pitches_drums': 0.48796331882476807, 'pitches_non_drums': 0.4052247107028961, 'dur': 0.6689331531524658, 'acts_acc': 0.8500518798828125, 'acts_precision': 0.8132884502410889, 'acts_recall': 0.1708933413028717, 'acts_f1': 0.28243884444236755}

----------------------------------------

Training on batch 509/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:59.49
Losses:
{'tot': 3.7426602840423584, 'pitches': 2.1796364784240723, 'dur': 1.144828200340271, 'acts': 0.4181954264640808, 'rec': 3.7426602840423584, 'kld': 211.05267333984375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41447803378105164, 'pitches': 0.46

Training on batch 521/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:08.76
Losses:
{'tot': 3.555528402328491, 'pitches': 2.1006603240966797, 'dur': 1.0293306112289429, 'acts': 0.42553725838661194, 'rec': 3.555528402328491, 'kld': 215.2797393798828, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4525563418865204, 'pitches': 0.48865675926208496, 'pitches_drums': 0.5946537256240845, 'pitches_non_drums': 0.4222261905670166, 'dur': 0.7074949741363525, 'acts_acc': 0.855804443359375, 'acts_precision': 0.7241241931915283, 'acts_recall': 0.17197692394256592, 'acts_f1': 0.27794331312179565}

----------------------------------------

Training on batch 522/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:09.50
Losses:
{'tot': 3.8082327842712402, 'pitches': 2.216203451156616, 'dur': 1.1702765226364136, 'acts': 0.4217529296875, 'rec': 3.8082327842712402, 'kld': 217.2522430419922, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41503414511680603, 'pitches': 0.45320898

Training on batch 534/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:19.25
Losses:
{'tot': 3.5276856422424316, 'pitches': 2.1120593547821045, 'dur': 0.9755257368087769, 'acts': 0.4401007890701294, 'rec': 3.5276856422424316, 'kld': 220.14199829101562, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4476073682308197, 'pitches': 0.48055213689804077, 'pitches_drums': 0.5769807696342468, 'pitches_non_drums': 0.4138646721839905, 'dur': 0.7139877080917358, 'acts_acc': 0.835540771484375, 'acts_precision': 0.7863418459892273, 'acts_recall': 0.16123485565185547, 'acts_f1': 0.267599880695343}

----------------------------------------

Training on batch 535/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:20.07
Losses:
{'tot': 3.836670160293579, 'pitches': 2.281416654586792, 'dur': 1.1075263023376465, 'acts': 0.4477272629737854, 'rec': 3.836670160293579, 'kld': 216.7996826171875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4031989574432373, 'pitches': 0.44235342

Training on batch 547/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:29.00
Losses:
{'tot': 3.99935245513916, 'pitches': 2.342879295349121, 'dur': 1.238713026046753, 'acts': 0.4177601933479309, 'rec': 3.99935245513916, 'kld': 225.24285888671875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3924674689769745, 'pitches': 0.4367044270038605, 'pitches_drums': 0.5194206237792969, 'pitches_non_drums': 0.38621389865875244, 'dur': 0.6358581781387329, 'acts_acc': 0.8599395751953125, 'acts_precision': 0.7445194125175476, 'acts_recall': 0.17079304158687592, 'acts_f1': 0.2778477072715759}

----------------------------------------

Training on batch 548/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:29.82
Losses:
{'tot': 3.7624173164367676, 'pitches': 2.311396837234497, 'dur': 1.016271948814392, 'acts': 0.43474864959716797, 'rec': 3.7624173164367676, 'kld': 224.87725830078125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3827817738056183, 'pitches': 0.431260257

Training on batch 560/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:39.05
Losses:
{'tot': 3.6865365505218506, 'pitches': 2.163313865661621, 'dur': 1.1013842821121216, 'acts': 0.42183852195739746, 'rec': 3.6865365505218506, 'kld': 229.422607421875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41626957058906555, 'pitches': 0.4585772752761841, 'pitches_drums': 0.5251230597496033, 'pitches_non_drums': 0.41402432322502136, 'dur': 0.6771953701972961, 'acts_acc': 0.8528900146484375, 'acts_precision': 0.7834395170211792, 'acts_recall': 0.1778581142425537, 'acts_f1': 0.2899020314216614}

----------------------------------------

Training on batch 561/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:39.88
Losses:
{'tot': 3.8080894947052, 'pitches': 2.2291998863220215, 'dur': 1.1345921754837036, 'acts': 0.44429755210876465, 'rec': 3.8080894947052, 'kld': 228.68350219726562, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41681215167045593, 'pitches': 0.46154090

Training on batch 573/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:49.12
Losses:
{'tot': 3.887516975402832, 'pitches': 2.322384834289551, 'dur': 1.1364591121673584, 'acts': 0.4286731481552124, 'rec': 3.887516975402832, 'kld': 235.849853515625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3932614028453827, 'pitches': 0.4332282245159149, 'pitches_drums': 0.4522646963596344, 'pitches_non_drums': 0.42083629965782166, 'dur': 0.6636680364608765, 'acts_acc': 0.845306396484375, 'acts_precision': 0.7270183563232422, 'acts_recall': 0.16134469211101532, 'acts_f1': 0.2640824615955353}

----------------------------------------

Training on batch 574/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:49.76
Losses:
{'tot': 3.724151134490967, 'pitches': 2.1883747577667236, 'dur': 1.1252813339233398, 'acts': 0.4104950428009033, 'rec': 3.724151134490967, 'kld': 238.34451293945312, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4159397482872009, 'pitches': 0.4492205381

Training on batch 586/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:58.91
Losses:
{'tot': 3.9727888107299805, 'pitches': 2.3113954067230225, 'dur': 1.2490036487579346, 'acts': 0.4123898148536682, 'rec': 3.9727888107299805, 'kld': 240.98175048828125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4102937579154968, 'pitches': 0.44897523522377014, 'pitches_drums': 0.5330343246459961, 'pitches_non_drums': 0.40855294466018677, 'dur': 0.6472134590148926, 'acts_acc': 0.8576812744140625, 'acts_precision': 0.7991647720336914, 'acts_recall': 0.19306613504886627, 'acts_f1': 0.3109995126724243}

----------------------------------------

Training on batch 587/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:59.66
Losses:
{'tot': 3.7377119064331055, 'pitches': 2.3332247734069824, 'dur': 0.9855398535728455, 'acts': 0.41894716024398804, 'rec': 3.7377119064331055, 'kld': 243.00491333007812, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4066888391971588, 'pitches': 0.

Training on batch 599/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:09.26
Losses:
{'tot': 3.689293384552002, 'pitches': 2.19480037689209, 'dur': 1.0876612663269043, 'acts': 0.40683186054229736, 'rec': 3.689293384552002, 'kld': 247.647705078125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4232042133808136, 'pitches': 0.4597887396812439, 'pitches_drums': 0.5340009927749634, 'pitches_non_drums': 0.40640512108802795, 'dur': 0.6812323927879333, 'acts_acc': 0.8641357421875, 'acts_precision': 0.808867335319519, 'acts_recall': 0.20554819703102112, 'acts_f1': 0.3277970850467682}

----------------------------------------

Training on batch 600/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:10.06
Losses:
{'tot': 3.709538221359253, 'pitches': 2.236954927444458, 'dur': 1.059822916984558, 'acts': 0.4127604365348816, 'rec': 3.709538221359253, 'kld': 247.48434448242188, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41734352707862854, 'pitches': 0.45609188079833

Training on batch 612/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:23.01
Losses:
{'tot': 3.6438663005828857, 'pitches': 2.1267805099487305, 'dur': 1.0827234983444214, 'acts': 0.4343622326850891, 'rec': 3.6438663005828857, 'kld': 250.0550537109375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.43785303831100464, 'pitches': 0.4697619378566742, 'pitches_drums': 0.5695902109146118, 'pitches_non_drums': 0.4000000059604645, 'dur': 0.6686086058616638, 'acts_acc': 0.829498291015625, 'acts_precision': 0.8066188097000122, 'acts_recall': 0.15914097428321838, 'acts_f1': 0.26583442091941833}

----------------------------------------

Training on batch 613/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:23.72
Losses:
{'tot': 3.6555511951446533, 'pitches': 2.158262014389038, 'dur': 1.087565541267395, 'acts': 0.4097236096858978, 'rec': 3.6555511951446533, 'kld': 253.77670288085938, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4344416558742523, 'pitches': 0.46498

Training on batch 625/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:32.62
Losses:
{'tot': 3.6809914112091064, 'pitches': 2.159064531326294, 'dur': 1.1050382852554321, 'acts': 0.4168885350227356, 'rec': 3.6809914112091064, 'kld': 261.83343505859375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42897799611091614, 'pitches': 0.45803821086883545, 'pitches_drums': 0.5220860242843628, 'pitches_non_drums': 0.4085328280925751, 'dur': 0.6683531403541565, 'acts_acc': 0.8488311767578125, 'acts_precision': 0.7785934805870056, 'acts_recall': 0.17754867672920227, 'acts_f1': 0.28915831446647644}

----------------------------------------

Training on batch 626/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:33.42
Losses:
{'tot': 3.7088265419006348, 'pitches': 2.1959171295166016, 'dur': 1.0889068841934204, 'acts': 0.4240027368068695, 'rec': 3.7088265419006348, 'kld': 255.19017028808594, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4162498116493225, 'pitches': 0.4

Training on batch 638/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:42.65
Losses:
{'tot': 3.6376967430114746, 'pitches': 2.168583631515503, 'dur': 1.0489394664764404, 'acts': 0.4201735854148865, 'rec': 3.6376967430114746, 'kld': 262.6335754394531, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4241778254508972, 'pitches': 0.4612032473087311, 'pitches_drums': 0.5313194394111633, 'pitches_non_drums': 0.4097112715244293, 'dur': 0.6711602807044983, 'acts_acc': 0.844512939453125, 'acts_precision': 0.8539857864379883, 'acts_recall': 0.18057410418987274, 'acts_f1': 0.2981127202510834}

----------------------------------------

Training on batch 639/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:43.42
Losses:
{'tot': 3.534097671508789, 'pitches': 2.084735631942749, 'dur': 1.045083999633789, 'acts': 0.40427812933921814, 'rec': 3.534097671508789, 'kld': 263.49749755859375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4596846103668213, 'pitches': 0.491292417

Training on batch 651/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:52.86
Losses:
{'tot': 3.8461921215057373, 'pitches': 2.2698962688446045, 'dur': 1.1535181999206543, 'acts': 0.42277756333351135, 'rec': 3.8461921215057373, 'kld': 266.40435791015625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41757676005363464, 'pitches': 0.4553452134132385, 'pitches_drums': 0.5735470652580261, 'pitches_non_drums': 0.3906000554561615, 'dur': 0.6609634757041931, 'acts_acc': 0.838348388671875, 'acts_precision': 0.813725471496582, 'acts_recall': 0.18186981976032257, 'acts_f1': 0.29729369282722473}

----------------------------------------

Training on batch 652/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:53.63
Losses:
{'tot': 3.6694815158843994, 'pitches': 2.2801101207733154, 'dur': 0.9736343026161194, 'acts': 0.415737122297287, 'rec': 3.6694815158843994, 'kld': 270.78155517578125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40003451704978943, 'pitches': 0.43

Training on batch 664/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:02.75
Losses:
{'tot': 3.5394113063812256, 'pitches': 2.058382272720337, 'dur': 1.0608762502670288, 'acts': 0.4201529026031494, 'rec': 3.5394113063812256, 'kld': 272.6866149902344, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4440610110759735, 'pitches': 0.4824218153953552, 'pitches_drums': 0.6019719243049622, 'pitches_non_drums': 0.40374988317489624, 'dur': 0.6745031476020813, 'acts_acc': 0.8406829833984375, 'acts_precision': 0.825947105884552, 'acts_recall': 0.1884223371744156, 'acts_f1': 0.306844562292099}

----------------------------------------

Training on batch 665/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:03.54
Losses:
{'tot': 3.903212547302246, 'pitches': 2.4663002490997314, 'dur': 1.0142923593521118, 'acts': 0.4226197898387909, 'rec': 3.903212547302246, 'kld': 277.5787048339844, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3638755679130554, 'pitches': 0.4114019572

Training on batch 677/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:12.74
Losses:
{'tot': 3.6140222549438477, 'pitches': 2.140336036682129, 'dur': 1.0652650594711304, 'acts': 0.4084210991859436, 'rec': 3.6140222549438477, 'kld': 275.66839599609375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4396588206291199, 'pitches': 0.4655076861381531, 'pitches_drums': 0.542735755443573, 'pitches_non_drums': 0.4127158224582672, 'dur': 0.6877624988555908, 'acts_acc': 0.84930419921875, 'acts_precision': 0.7358356714248657, 'acts_recall': 0.18540328741073608, 'acts_f1': 0.29618018865585327}

----------------------------------------

Training on batch 678/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:13.49
Losses:
{'tot': 3.7019338607788086, 'pitches': 2.1770055294036865, 'dur': 1.1244394779205322, 'acts': 0.4004889130592346, 'rec': 3.7019338607788086, 'kld': 280.57244873046875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4240538775920868, 'pitches': 0.463845

Training on batch 690/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:22.62
Losses:
{'tot': 3.658499002456665, 'pitches': 2.1293959617614746, 'dur': 1.1269605159759521, 'acts': 0.40214258432388306, 'rec': 3.658499002456665, 'kld': 283.33587646484375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42897728085517883, 'pitches': 0.4738157093524933, 'pitches_drums': 0.5731503367424011, 'pitches_non_drums': 0.4050845503807068, 'dur': 0.6681612730026245, 'acts_acc': 0.8569488525390625, 'acts_precision': 0.7823104858398438, 'acts_recall': 0.1980985403060913, 'acts_f1': 0.3161426782608032}

----------------------------------------

Training on batch 691/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:23.39
Losses:
{'tot': 3.6511714458465576, 'pitches': 2.1113057136535645, 'dur': 1.1263221502304077, 'acts': 0.4135434627532959, 'rec': 3.6511714458465576, 'kld': 286.9012756347656, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4242307245731354, 'pitches': 0.47220

Training on batch 703/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:35.92
Losses:
{'tot': 3.4765572547912598, 'pitches': 2.013414144515991, 'dur': 1.0670021772384644, 'acts': 0.3961409330368042, 'rec': 3.4765572547912598, 'kld': 287.8099670410156, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.433165580034256, 'pitches': 0.4886319041252136, 'pitches_drums': 0.5793003439903259, 'pitches_non_drums': 0.4297238290309906, 'dur': 0.684370219707489, 'acts_acc': 0.8579559326171875, 'acts_precision': 0.7767361402511597, 'acts_recall': 0.20515407621860504, 'acts_f1': 0.32457923889160156}

----------------------------------------

Training on batch 704/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:36.70
Losses:
{'tot': 3.5561087131500244, 'pitches': 2.1006999015808105, 'dur': 1.0534237623214722, 'acts': 0.4019848704338074, 'rec': 3.5561087131500244, 'kld': 293.7475891113281, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41845664381980896, 'pitches': 0.464565

Training on batch 716/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:46.01
Losses:
{'tot': 3.7001054286956787, 'pitches': 2.182466506958008, 'dur': 1.1027737855911255, 'acts': 0.4148651659488678, 'rec': 3.7001054286956787, 'kld': 296.4206237792969, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4054015874862671, 'pitches': 0.4432692229747772, 'pitches_drums': 0.5326825976371765, 'pitches_non_drums': 0.3858337104320526, 'dur': 0.6721153855323792, 'acts_acc': 0.8406524658203125, 'acts_precision': 0.8548168540000916, 'acts_recall': 0.20102106034755707, 'acts_f1': 0.3254973292350769}

----------------------------------------

Training on batch 717/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:46.80
Losses:
{'tot': 3.57381534576416, 'pitches': 2.155726909637451, 'dur': 1.0126190185546875, 'acts': 0.40546929836273193, 'rec': 3.57381534576416, 'kld': 295.1914367675781, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41924893856048584, 'pitches': 0.457906365

Training on batch 729/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:55.80
Losses:
{'tot': 3.525526523590088, 'pitches': 2.0923449993133545, 'dur': 1.048410177230835, 'acts': 0.3847713768482208, 'rec': 3.525526523590088, 'kld': 309.32452392578125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4280492067337036, 'pitches': 0.4696628451347351, 'pitches_drums': 0.5501183271408081, 'pitches_non_drums': 0.4095781147480011, 'dur': 0.6937563419342041, 'acts_acc': 0.865692138671875, 'acts_precision': 0.8894366025924683, 'acts_recall': 0.2293444722890854, 'acts_f1': 0.36466002464294434}

----------------------------------------

Training on batch 730/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:56.50
Losses:
{'tot': 3.628673791885376, 'pitches': 2.1686596870422363, 'dur': 1.0676910877227783, 'acts': 0.3923230469226837, 'rec': 3.628673791885376, 'kld': 300.4228515625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42123112082481384, 'pitches': 0.470143765211

Training on batch 742/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:05.63
Losses:
{'tot': 3.6948256492614746, 'pitches': 2.179412364959717, 'dur': 1.1152654886245728, 'acts': 0.4001479148864746, 'rec': 3.6948256492614746, 'kld': 313.20355224609375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4100671708583832, 'pitches': 0.4597722887992859, 'pitches_drums': 0.5714512467384338, 'pitches_non_drums': 0.3861650228500366, 'dur': 0.6638281941413879, 'acts_acc': 0.8455352783203125, 'acts_precision': 0.7858136296272278, 'acts_recall': 0.19204622507095337, 'acts_f1': 0.3086588382720947}

----------------------------------------

Training on batch 743/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:06.43
Losses:
{'tot': 3.601970672607422, 'pitches': 2.1635749340057373, 'dur': 1.0359227657318115, 'acts': 0.4024730920791626, 'rec': 3.601970672607422, 'kld': 320.0757751464844, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41708990931510925, 'pitches': 0.457906

Training on batch 755/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:15.94
Losses:
{'tot': 3.8132200241088867, 'pitches': 2.1831910610198975, 'dur': 1.2303929328918457, 'acts': 0.39963600039482117, 'rec': 3.8132200241088867, 'kld': 315.187744140625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4073842465877533, 'pitches': 0.45246633887290955, 'pitches_drums': 0.5780743956565857, 'pitches_non_drums': 0.38910889625549316, 'dur': 0.6373029947280884, 'acts_acc': 0.8503875732421875, 'acts_precision': 0.7297857403755188, 'acts_recall': 0.18967220187187195, 'acts_f1': 0.30109062790870667}

----------------------------------------

Training on batch 756/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:16.58
Losses:
{'tot': 3.634645938873291, 'pitches': 2.101804256439209, 'dur': 1.1546846628189087, 'acts': 0.37815725803375244, 'rec': 3.634645938873291, 'kld': 313.4503173828125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4250401258468628, 'pitches': 0.4693

Training on batch 768/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:25.98
Losses:
{'tot': 3.5801949501037598, 'pitches': 2.1806185245513916, 'dur': 0.9980804324150085, 'acts': 0.4014960527420044, 'rec': 3.5801949501037598, 'kld': 321.88720703125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4182819724082947, 'pitches': 0.4562273621559143, 'pitches_drums': 0.5406989455223083, 'pitches_non_drums': 0.39491650462150574, 'dur': 0.6881903409957886, 'acts_acc': 0.8442840576171875, 'acts_precision': 0.7862491607666016, 'acts_recall': 0.19432994723320007, 'acts_f1': 0.31163573265075684}

----------------------------------------

Training on batch 769/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:26.79
Losses:
{'tot': 3.507418155670166, 'pitches': 2.1184325218200684, 'dur': 0.9855987429618835, 'acts': 0.40338677167892456, 'rec': 3.507418155670166, 'kld': 317.169921875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42826956510543823, 'pitches': 0.465073019

Training on batch 781/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:35.93
Losses:
{'tot': 3.40164852142334, 'pitches': 2.0223138332366943, 'dur': 1.0017526149749756, 'acts': 0.37758201360702515, 'rec': 3.40164852142334, 'kld': 327.814453125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.43735548853874207, 'pitches': 0.47498422861099243, 'pitches_drums': 0.5668364763259888, 'pitches_non_drums': 0.4140242338180542, 'dur': 0.6998808979988098, 'acts_acc': 0.86346435546875, 'acts_precision': 0.7909277081489563, 'acts_recall': 0.21768516302108765, 'acts_f1': 0.341405987739563}

----------------------------------------

Training on batch 782/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:36.74
Losses:
{'tot': 3.7039718627929688, 'pitches': 2.2164793014526367, 'dur': 1.0855305194854736, 'acts': 0.40196195244789124, 'rec': 3.7039718627929688, 'kld': 330.98236083984375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4330504238605499, 'pitches': 0.46192649006

Training on batch 794/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:45.78
Losses:
{'tot': 3.500709056854248, 'pitches': 2.055762529373169, 'dur': 1.054249882698059, 'acts': 0.3906964659690857, 'rec': 3.500709056854248, 'kld': 332.97369384765625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4541079103946686, 'pitches': 0.4894251525402069, 'pitches_drums': 0.5742241740226746, 'pitches_non_drums': 0.4295298159122467, 'dur': 0.6857036352157593, 'acts_acc': 0.854095458984375, 'acts_precision': 0.8165760636329651, 'acts_recall': 0.21039733290672302, 'acts_f1': 0.33458593487739563}

----------------------------------------

Training on batch 795/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:46.50
Losses:
{'tot': 3.4885241985321045, 'pitches': 2.0738027095794678, 'dur': 1.0345656871795654, 'acts': 0.38015586137771606, 'rec': 3.4885241985321045, 'kld': 334.3622741699219, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4447229206562042, 'pitches': 0.4830735

Training on batch 807/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:58.57
Losses:
{'tot': 3.5269582271575928, 'pitches': 2.091813087463379, 'dur': 1.0333170890808105, 'acts': 0.40182799100875854, 'rec': 3.5269582271575928, 'kld': 340.9834289550781, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4396120607852936, 'pitches': 0.47623637318611145, 'pitches_drums': 0.5481758117675781, 'pitches_non_drums': 0.4149986207485199, 'dur': 0.6798686981201172, 'acts_acc': 0.8425750732421875, 'acts_precision': 0.8246445655822754, 'acts_recall': 0.19910094141960144, 'acts_f1': 0.3207584321498871}

----------------------------------------

Training on batch 808/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:10:59.27
Losses:
{'tot': 3.6840832233428955, 'pitches': 2.2062630653381348, 'dur': 1.095583438873291, 'acts': 0.3822367191314697, 'rec': 3.6840832233428955, 'kld': 341.1230773925781, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42739933729171753, 'pitches': 0.4592

Training on batch 820/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:08.42
Losses:
{'tot': 3.4859917163848877, 'pitches': 2.065661668777466, 'dur': 1.0280117988586426, 'acts': 0.3923181891441345, 'rec': 3.4859917163848877, 'kld': 342.55059814453125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4464150369167328, 'pitches': 0.4891757071018219, 'pitches_drums': 0.5807941555976868, 'pitches_non_drums': 0.42184504866600037, 'dur': 0.6882731914520264, 'acts_acc': 0.846710205078125, 'acts_precision': 0.8414464592933655, 'acts_recall': 0.20148196816444397, 'acts_f1': 0.3251158595085144}

----------------------------------------

Training on batch 821/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:09.20
Losses:
{'tot': 3.576345920562744, 'pitches': 2.1090755462646484, 'dur': 1.0752137899398804, 'acts': 0.39205652475357056, 'rec': 3.576345920562744, 'kld': 340.01336669921875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.43092671036720276, 'pitches': 0.4695

Training on batch 833/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:18.26
Losses:
{'tot': 3.476583957672119, 'pitches': 2.1354451179504395, 'dur': 0.9351127743721008, 'acts': 0.40602612495422363, 'rec': 3.476583957672119, 'kld': 351.28070068359375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41787856817245483, 'pitches': 0.45723482966423035, 'pitches_drums': 0.526361882686615, 'pitches_non_drums': 0.41708680987358093, 'dur': 0.6994001269340515, 'acts_acc': 0.8381805419921875, 'acts_precision': 0.859375, 'acts_recall': 0.20247513055801392, 'acts_f1': 0.3277337849140167}

----------------------------------------

Training on batch 834/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:19.03
Losses:
{'tot': 3.3513691425323486, 'pitches': 2.0307047367095947, 'dur': 0.9373620748519897, 'acts': 0.3833024501800537, 'rec': 3.3513691425323486, 'kld': 346.45159912109375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.43919050693511963, 'pitches': 0.47570863366

Training on batch 846/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:28.23
Losses:
{'tot': 3.5537705421447754, 'pitches': 2.0766804218292236, 'dur': 1.08866286277771, 'acts': 0.3884272575378418, 'rec': 3.5537705421447754, 'kld': 353.04754638671875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42150604724884033, 'pitches': 0.4701147675514221, 'pitches_drums': 0.5451778769493103, 'pitches_non_drums': 0.41945338249206543, 'dur': 0.663134753704071, 'acts_acc': 0.847747802734375, 'acts_precision': 0.8398385047912598, 'acts_recall': 0.20805200934410095, 'acts_f1': 0.33348920941352844}

----------------------------------------

Training on batch 847/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:29.06
Losses:
{'tot': 3.5912046432495117, 'pitches': 2.1451292037963867, 'dur': 1.0486809015274048, 'acts': 0.39739465713500977, 'rec': 3.5912046432495117, 'kld': 353.76214599609375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42907196283340454, 'pitches': 0.45

Training on batch 859/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:38.20
Losses:
{'tot': 3.7391958236694336, 'pitches': 2.221522569656372, 'dur': 1.1302402019500732, 'acts': 0.38743311166763306, 'rec': 3.7391958236694336, 'kld': 363.8403015136719, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4275151193141937, 'pitches': 0.46329542994499207, 'pitches_drums': 0.5752446055412292, 'pitches_non_drums': 0.3980259895324707, 'dur': 0.6552307605743408, 'acts_acc': 0.8487548828125, 'acts_precision': 0.8128876686096191, 'acts_recall': 0.2011425644159317, 'acts_f1': 0.32248803973197937}

----------------------------------------

Training on batch 860/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:38.86
Losses:
{'tot': 3.5535521507263184, 'pitches': 2.056736946105957, 'dur': 1.1401387453079224, 'acts': 0.3566766381263733, 'rec': 3.5535521507263184, 'kld': 360.17431640625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.45977506041526794, 'pitches': 0.491712063

Training on batch 872/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:48.12
Losses:
{'tot': 3.559091091156006, 'pitches': 2.1812198162078857, 'dur': 0.9772993326187134, 'acts': 0.4005718231201172, 'rec': 3.559091091156006, 'kld': 364.99993896484375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40929242968559265, 'pitches': 0.4551169276237488, 'pitches_drums': 0.50208580493927, 'pitches_non_drums': 0.4227930009365082, 'dur': 0.7044640183448792, 'acts_acc': 0.839874267578125, 'acts_precision': 0.7801465392112732, 'acts_recall': 0.19234560430049896, 'acts_f1': 0.30860456824302673}

----------------------------------------

Training on batch 873/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:48.85
Losses:
{'tot': 3.6101691722869873, 'pitches': 2.1135828495025635, 'dur': 1.1163409948349, 'acts': 0.3802453577518463, 'rec': 3.6101691722869873, 'kld': 367.5814208984375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.44125595688819885, 'pitches': 0.479128420

Training on batch 885/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:58.05
Losses:
{'tot': 3.5380492210388184, 'pitches': 2.0931143760681152, 'dur': 1.076387882232666, 'acts': 0.36854684352874756, 'rec': 3.5380492210388184, 'kld': 374.6556701660156, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4303669035434723, 'pitches': 0.4724184274673462, 'pitches_drums': 0.5853211283683777, 'pitches_non_drums': 0.40465977787971497, 'dur': 0.6659060716629028, 'acts_acc': 0.867767333984375, 'acts_precision': 0.7767552733421326, 'acts_recall': 0.22147716581821442, 'acts_f1': 0.3446763753890991}

----------------------------------------

Training on batch 886/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:11:58.75
Losses:
{'tot': 3.6136062145233154, 'pitches': 2.1658742427825928, 'dur': 1.0921114683151245, 'acts': 0.3556206226348877, 'rec': 3.6136062145233154, 'kld': 373.2100830078125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4197503328323364, 'pitches': 0.46267

Training on batch 898/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:12:08.28
Losses:
{'tot': 3.651345729827881, 'pitches': 2.1985089778900146, 'dur': 1.0686359405517578, 'acts': 0.38420093059539795, 'rec': 3.651345729827881, 'kld': 376.5452880859375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40757033228874207, 'pitches': 0.4574373662471771, 'pitches_drums': 0.49148327112197876, 'pitches_non_drums': 0.43116384744644165, 'dur': 0.6775220632553101, 'acts_acc': 0.851654052734375, 'acts_precision': 0.7476697564125061, 'acts_recall': 0.20035682618618011, 'acts_f1': 0.3160264492034912}

----------------------------------------

Training on batch 899/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:12:09.05
Losses:
{'tot': 3.4030377864837646, 'pitches': 1.9778773784637451, 'dur': 1.0512809753417969, 'acts': 0.37387943267822266, 'rec': 3.4030377864837646, 'kld': 381.10186767578125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4256545603275299, 'pitches': 0.46

Training on batch 911/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:12:21.67
Losses:
{'tot': 3.5001916885375977, 'pitches': 2.021904706954956, 'dur': 1.1095739603042603, 'acts': 0.368712842464447, 'rec': 3.5001916885375977, 'kld': 381.4197692871094, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.43925297260284424, 'pitches': 0.48096397519111633, 'pitches_drums': 0.6256383657455444, 'pitches_non_drums': 0.4099006950855255, 'dur': 0.6721480488777161, 'acts_acc': 0.856414794921875, 'acts_precision': 0.8255578279495239, 'acts_recall': 0.21541990339756012, 'acts_f1': 0.3416818380355835}

----------------------------------------

Training on batch 912/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:12:22.39
Losses:
{'tot': 3.482538938522339, 'pitches': 2.0456619262695312, 'dur': 1.0703473091125488, 'acts': 0.3665296137332916, 'rec': 3.482538938522339, 'kld': 389.48797607421875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4402821958065033, 'pitches': 0.4869841

Training on batch 924/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:12:31.49
Losses:
{'tot': 3.451585054397583, 'pitches': 2.0890166759490967, 'dur': 0.9944161772727966, 'acts': 0.3681522011756897, 'rec': 3.451585054397583, 'kld': 391.5955810546875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.43086713552474976, 'pitches': 0.4668576121330261, 'pitches_drums': 0.5451115965843201, 'pitches_non_drums': 0.4136142432689667, 'dur': 0.705234706401825, 'acts_acc': 0.8539276123046875, 'acts_precision': 0.8483033776283264, 'acts_recall': 0.2185651808977127, 'acts_f1': 0.34757718443870544}

----------------------------------------

Training on batch 925/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:12:32.22
Losses:
{'tot': 3.6649763584136963, 'pitches': 2.112424373626709, 'dur': 1.1882246732711792, 'acts': 0.36432725191116333, 'rec': 3.6649763584136963, 'kld': 390.3612060546875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4377208650112152, 'pitches': 0.4725945

Training on batch 937/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:12:41.44
Losses:
{'tot': 3.4309701919555664, 'pitches': 2.077216625213623, 'dur': 0.9585173726081848, 'acts': 0.39523637294769287, 'rec': 3.4309701919555664, 'kld': 394.3873291015625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4398055076599121, 'pitches': 0.47817859053611755, 'pitches_drums': 0.5700782537460327, 'pitches_non_drums': 0.422854483127594, 'dur': 0.6996244192123413, 'acts_acc': 0.832977294921875, 'acts_precision': 0.8166779279708862, 'acts_recall': 0.18866011500358582, 'acts_f1': 0.3065129220485687}

----------------------------------------

Training on batch 938/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:12:42.20
Losses:
{'tot': 3.397416830062866, 'pitches': 2.0169458389282227, 'dur': 1.0109909772872925, 'acts': 0.36947983503341675, 'rec': 3.397416830062866, 'kld': 395.14111328125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4436318278312683, 'pitches': 0.486266046

Training on batch 950/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:12:51.40
Losses:
{'tot': 3.6306567192077637, 'pitches': 2.1086552143096924, 'dur': 1.1549079418182373, 'acts': 0.36709368228912354, 'rec': 3.6306567192077637, 'kld': 408.76678466796875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42565152049064636, 'pitches': 0.4689245820045471, 'pitches_drums': 0.5693320035934448, 'pitches_non_drums': 0.4046887457370758, 'dur': 0.6718603372573853, 'acts_acc': 0.8607177734375, 'acts_precision': 0.7685788869857788, 'acts_recall': 0.21881960332393646, 'acts_f1': 0.3406529724597931}

----------------------------------------

Training on batch 951/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:12:52.19
Losses:
{'tot': 3.4996490478515625, 'pitches': 2.106335163116455, 'dur': 0.9901852011680603, 'acts': 0.40312862396240234, 'rec': 3.4996490478515625, 'kld': 398.6356201171875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42980849742889404, 'pitches': 0.4558

Training on batch 963/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:01.34
Losses:
{'tot': 3.729671001434326, 'pitches': 2.2530715465545654, 'dur': 1.0811741352081299, 'acts': 0.39542537927627563, 'rec': 3.729671001434326, 'kld': 407.7514343261719, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3948511481285095, 'pitches': 0.43042072653770447, 'pitches_drums': 0.5172069668769836, 'pitches_non_drums': 0.3835376501083374, 'dur': 0.6549170613288879, 'acts_acc': 0.83416748046875, 'acts_precision': 0.745244562625885, 'acts_recall': 0.1782001256942749, 'acts_f1': 0.2876245379447937}

----------------------------------------

Training on batch 964/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:02.14
Losses:
{'tot': 3.3851969242095947, 'pitches': 2.03981876373291, 'dur': 0.9741348624229431, 'acts': 0.3712431788444519, 'rec': 3.3851969242095947, 'kld': 402.84393310546875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4440148174762726, 'pitches': 0.4771575331

Training on batch 976/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:10.95
Losses:
{'tot': 3.448009967803955, 'pitches': 2.0211732387542725, 'dur': 1.051364779472351, 'acts': 0.37547194957733154, 'rec': 3.448009967803955, 'kld': 418.1736755371094, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4333062171936035, 'pitches': 0.4773784279823303, 'pitches_drums': 0.5710234642028809, 'pitches_non_drums': 0.4175197184085846, 'dur': 0.6856074333190918, 'acts_acc': 0.84765625, 'acts_precision': 0.7935290336608887, 'acts_recall': 0.20257152616977692, 'acts_f1': 0.32275131344795227}

----------------------------------------

Training on batch 977/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:11.70
Losses:
{'tot': 3.38635516166687, 'pitches': 1.9921250343322754, 'dur': 1.0371769666671753, 'acts': 0.3570529818534851, 'rec': 3.38635516166687, 'kld': 409.90386962890625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.43926548957824707, 'pitches': 0.4733396470546722

Training on batch 989/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:20.75
Losses:
{'tot': 3.473740577697754, 'pitches': 2.034090757369995, 'dur': 1.0708523988723755, 'acts': 0.3687974214553833, 'rec': 3.473740577697754, 'kld': 417.6308288574219, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.44584211707115173, 'pitches': 0.4798395335674286, 'pitches_drums': 0.5535422563552856, 'pitches_non_drums': 0.42722293734550476, 'dur': 0.6665193438529968, 'acts_acc': 0.854766845703125, 'acts_precision': 0.7784526944160461, 'acts_recall': 0.21625222265720367, 'acts_f1': 0.3384765088558197}

----------------------------------------

Training on batch 990/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:21.42
Losses:
{'tot': 3.446553945541382, 'pitches': 2.03117036819458, 'dur': 1.0532265901565552, 'acts': 0.36215704679489136, 'rec': 3.446553945541382, 'kld': 412.7839660644531, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4399546980857849, 'pitches': 0.4697473645

Training on batch 1002/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:33.77
Losses:
{'tot': 3.6713924407958984, 'pitches': 2.2293453216552734, 'dur': 1.043880581855774, 'acts': 0.39816659688949585, 'rec': 3.6713924407958984, 'kld': 415.09210205078125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4255765676498413, 'pitches': 0.44838443398475647, 'pitches_drums': 0.5096489191055298, 'pitches_non_drums': 0.4078178107738495, 'dur': 0.6943065524101257, 'acts_acc': 0.83203125, 'acts_precision': 0.7980263233184814, 'acts_recall': 0.18925033509731293, 'acts_f1': 0.3059461712837219}

----------------------------------------

Training on batch 1003/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:34.54
Losses:
{'tot': 3.371103048324585, 'pitches': 2.0154292583465576, 'dur': 0.9880444407463074, 'acts': 0.36762934923171997, 'rec': 3.371103048324585, 'kld': 419.2145080566406, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4506446421146393, 'pitches': 0.4852916598

Training on batch 1015/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:43.90
Losses:
{'tot': 3.6373825073242188, 'pitches': 2.164496898651123, 'dur': 1.1126824617385864, 'acts': 0.3602031469345093, 'rec': 3.6373825073242188, 'kld': 431.2904052734375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41714078187942505, 'pitches': 0.4437793791294098, 'pitches_drums': 0.500388503074646, 'pitches_non_drums': 0.40894603729248047, 'dur': 0.6670503616333008, 'acts_acc': 0.8572540283203125, 'acts_precision': 0.7937704920768738, 'acts_recall': 0.21718847751617432, 'acts_f1': 0.34105798602104187}

----------------------------------------

Training on batch 1016/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:44.70
Losses:
{'tot': 3.4422245025634766, 'pitches': 2.076392650604248, 'dur': 0.9995998740196228, 'acts': 0.3662317991256714, 'rec': 3.4422245025634766, 'kld': 432.529541015625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4267929196357727, 'pitches': 0.4650

Training on batch 1028/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:54.01
Losses:
{'tot': 3.359752893447876, 'pitches': 2.035055160522461, 'dur': 0.9635481834411621, 'acts': 0.3611496090888977, 'rec': 3.359752893447876, 'kld': 436.91961669921875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.44741812348365784, 'pitches': 0.4846227765083313, 'pitches_drums': 0.5547837615013123, 'pitches_non_drums': 0.44074374437332153, 'dur': 0.7109155058860779, 'acts_acc': 0.8592071533203125, 'acts_precision': 0.772870659828186, 'acts_recall': 0.2236013561487198, 'acts_f1': 0.3468535542488098}

----------------------------------------

Training on batch 1029/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:13:54.86
Losses:
{'tot': 3.630855083465576, 'pitches': 2.136660575866699, 'dur': 1.1241039037704468, 'acts': 0.3700904846191406, 'rec': 3.630855083465576, 'kld': 432.6048583984375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4271310567855835, 'pitches': 0.45276811

Training on batch 1041/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:04.31
Losses:
{'tot': 3.5410256385803223, 'pitches': 2.1111137866973877, 'dur': 1.072933316230774, 'acts': 0.35697847604751587, 'rec': 3.5410256385803223, 'kld': 432.1111755371094, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4320010542869568, 'pitches': 0.46203458309173584, 'pitches_drums': 0.5522972345352173, 'pitches_non_drums': 0.40802833437919617, 'dur': 0.6723997592926025, 'acts_acc': 0.8562469482421875, 'acts_precision': 0.7975594401359558, 'acts_recall': 0.2251369059085846, 'acts_f1': 0.3511503040790558}

----------------------------------------

Training on batch 1042/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:05.03
Losses:
{'tot': 3.5578038692474365, 'pitches': 2.1294918060302734, 'dur': 1.0732451677322388, 'acts': 0.3550669550895691, 'rec': 3.5578038692474365, 'kld': 432.01458740234375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.43553486466407776, 'pitches': 0.

Training on batch 1054/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:14.14
Losses:
{'tot': 3.4109864234924316, 'pitches': 2.0262115001678467, 'dur': 1.0285041332244873, 'acts': 0.3562707304954529, 'rec': 3.4109864234924316, 'kld': 448.9580078125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4312594532966614, 'pitches': 0.46799716353416443, 'pitches_drums': 0.5430273413658142, 'pitches_non_drums': 0.41380226612091064, 'dur': 0.6882691979408264, 'acts_acc': 0.8558502197265625, 'acts_precision': 0.869332492351532, 'acts_recall': 0.23414179682731628, 'acts_f1': 0.36892038583755493}

----------------------------------------

Training on batch 1055/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:14.88
Losses:
{'tot': 3.3420746326446533, 'pitches': 1.9613090753555298, 'dur': 1.0084956884384155, 'acts': 0.3722699284553528, 'rec': 3.3420746326446533, 'kld': 446.0151672363281, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.47214260697364807, 'pitches': 0.496

Training on batch 1067/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:23.83
Losses:
{'tot': 3.395444869995117, 'pitches': 2.0041942596435547, 'dur': 1.014426827430725, 'acts': 0.37682393193244934, 'rec': 3.395444869995117, 'kld': 456.8260498046875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4147968888282776, 'pitches': 0.4718646705150604, 'pitches_drums': 0.5303711295127869, 'pitches_non_drums': 0.4316596984863281, 'dur': 0.6825991868972778, 'acts_acc': 0.8408203125, 'acts_precision': 0.8715277910232544, 'acts_recall': 0.21592241525650024, 'acts_f1': 0.3460983633995056}

----------------------------------------

Training on batch 1068/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:24.63
Losses:
{'tot': 3.3909008502960205, 'pitches': 2.0094523429870605, 'dur': 1.0122959613800049, 'acts': 0.3691524565219879, 'rec': 3.3909008502960205, 'kld': 457.461669921875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4420175552368164, 'pitches': 0.489015847444

Training on batch 1080/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:34.01
Losses:
{'tot': 3.4384849071502686, 'pitches': 2.031425952911377, 'dur': 1.059930682182312, 'acts': 0.34712809324264526, 'rec': 3.4384849071502686, 'kld': 461.3748474121094, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4268066883087158, 'pitches': 0.48342251777648926, 'pitches_drums': 0.6045026779174805, 'pitches_non_drums': 0.39887961745262146, 'dur': 0.6809384226799011, 'acts_acc': 0.86602783203125, 'acts_precision': 0.8178725838661194, 'acts_recall': 0.24409376084804535, 'acts_f1': 0.3759772777557373}

----------------------------------------

Training on batch 1081/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:34.76
Losses:
{'tot': 3.3939526081085205, 'pitches': 1.9689536094665527, 'dur': 1.069448709487915, 'acts': 0.35555028915405273, 'rec': 3.3939526081085205, 'kld': 455.1085510253906, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.446276992559433, 'pitches': 0.49817

Training on batch 1093/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:44.33
Losses:
{'tot': 3.4521431922912598, 'pitches': 2.1005496978759766, 'dur': 0.9786285161972046, 'acts': 0.3729647397994995, 'rec': 3.4521431922912598, 'kld': 461.6707763671875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42989885807037354, 'pitches': 0.4697172939777374, 'pitches_drums': 0.5254618525505066, 'pitches_non_drums': 0.4270188808441162, 'dur': 0.6941871643066406, 'acts_acc': 0.83966064453125, 'acts_precision': 0.7827818393707275, 'acts_recall': 0.20681560039520264, 'acts_f1': 0.32718655467033386}

----------------------------------------

Training on batch 1094/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:45.13
Losses:
{'tot': 3.6017932891845703, 'pitches': 2.2009494304656982, 'dur': 1.0292145013809204, 'acts': 0.37162917852401733, 'rec': 3.6017932891845703, 'kld': 463.3453369140625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.41755715012550354, 'pitches': 0.4

Training on batch 1106/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:57.59
Losses:
{'tot': 3.368277072906494, 'pitches': 1.9345264434814453, 'dur': 1.0845826864242554, 'acts': 0.34916773438453674, 'rec': 3.368277072906494, 'kld': 462.45611572265625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4669084846973419, 'pitches': 0.49539458751678467, 'pitches_drums': 0.6298868060112, 'pitches_non_drums': 0.4150857627391815, 'dur': 0.6829190254211426, 'acts_acc': 0.864898681640625, 'acts_precision': 0.8181003332138062, 'acts_recall': 0.24936270713806152, 'acts_f1': 0.3822215795516968}

----------------------------------------

Training on batch 1107/242302 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:14:58.36
Losses:
{'tot': 3.473883628845215, 'pitches': 2.060175657272339, 'dur': 1.0594549179077148, 'acts': 0.3542531132698059, 'rec': 3.473883628845215, 'kld': 464.9404296875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4376811683177948, 'pitches': 0.471723020076

Training from checkpoint:

In [None]:
# Where is the model? in models/model_name/checkpoint (cause we don't need the best yet)
# What were its parameters? model_state_dict
# What were the scheduler's parameters? optimizer_state_dict
# What were the optimizer's parameters (Adam)? scheduler_state_dict
# --> Ok! Now we have a model, a scheduler and an optimizer.

# What else do we need? A training checkpoint! This should be gathered inside Trainer!
# A training checkpoint is composed of many things such as:
#    last epoch, last batch, ...
# However, we also need a way to store the dataloaders. Why?
# At the beginning we have a full dataset in a directory. Then, we create two (three) subsets,
# corresponding to TR, VL (TS). 

## Reconstructions

In [None]:
checkpoint = torch.load('models/just_pitches_warmup')

In [None]:
state_dict = checkpoint['model_state_dict']
vae = VAE().to(device)

In [None]:
vae.load_state_dict(state_dict)

In [None]:
loader = DataLoader(dataset, batch_size=32, shuffle=False)
len(dataset)

In [None]:
for idx, inputs in enumerate(loader):
    
    x_seq, x_acts, x_graph, src_mask = inputs
    x_seq = x_seq.float().to(device)
    x_acts = x_acts.to(device)
    x_graph = x_graph.to(device)
    src_mask = src_mask.to(device)
    tgt_mask = generate_square_subsequent_mask(x_seq.size(-2)-1).to(device)

    # Forward pass, get the reconstructions
    outputs, mu, log_var = vae(x_seq, x_acts, x_graph, src_mask, tgt_mask)
    
    break

seq_rec, _  = outputs

In [None]:
x_seq.size()

In [None]:
seq_rec.size()

In [None]:
x_acts.size()

Create dense reconstruction from sparse reconstruction:

In [None]:
seq_rec_dense = torch.zeros(x_seq.size(), dtype=torch.float).to(device)
seq_rec_dense = seq_rec_dense[..., 1:, :]
size = seq_rec_dense.size()

seq_rec_dense = seq_rec_dense.view(-1, seq_rec_dense.size(-2), seq_rec_dense.size(-1))

silence = torch.zeros(seq_rec_dense.size(-2), seq_rec_dense.size(-1)).to(device)
silence[:, 129] = 1. # eos token

seq_rec_dense[x_acts.bool().view(-1)] = seq_rec
seq_rec_dense[torch.logical_not(x_acts.bool().view(-1))] = silence

seq_rec_dense = seq_rec_dense.view(size)

In [None]:
print(seq_rec_dense.size())
print(x_seq.size())

In [None]:
music_real = x_seq[0]
music_rec = seq_rec_dense[0]

In [None]:
music_real.size()

In [None]:
prefix = "data/music/"

real = from_tensor_to_muspy(music_real, track_data)
muspy.show_pianoroll(real, yticklabel='off', grid_axis='off')
plt.savefig(prefix + "real" + ".png")
muspy.write_midi(prefix + "real" + ".mid", real)

In [None]:
rec = from_tensor_to_muspy(music_rec, track_data)
muspy.show_pianoroll(rec, yticklabel='off', grid_axis='off')
plt.savefig(prefix + "rec" + ".png")
muspy.write_midi(prefix + "rec" + ".mid", rec)

Plot music and save it to disk

In [None]:
#tracks = [drum_track, bass_track, guitar_track, strings_track]
import copy

def from_tensor_to_muspy(music_tensor, track_data):
    
    powers = torch.tensor([2**n for n in reversed(range(9))], dtype=torch.float)
    tracks = []
    
    for tr in range(music_tensor.size(0)):
        
        notes = []
        
        for ts in range(music_tensor.size(1)):
            for note in range(music_tensor.size(2)):
                
                pitch = music_tensor[tr, ts, note, :131]
                pitch = torch.argmax(pitch)

                if pitch == 129:
                    break
                
                if pitch != 128:
                    #dur = music_tensor[tr, ts, note, 131:]
                    #dur = torch.dot(dur, powers).long()
                    
                    dur = 4
                    
                    #notes.append(muspy.Note(ts, pitch.item(), dur.item(), 64))
                    notes.append(muspy.Note(ts, pitch.item(), dur, 64))
        
        if track_data[tr][0] == 'Drums':
            track = muspy.Track(name='Drums', is_drum=True, notes=copy.deepcopy(notes))
        else:
            track = muspy.Track(name=track_data[tr][0], 
                                program=track_data[tr][1],
                                notes=copy.deepcopy(notes))
        tracks.append(track)
    
    meta = muspy.Metadata(title='prova')
    music = muspy.Music(tracks=tracks, metadata=meta, resolution=RESOLUTION)
    
    return music


track_data = [('Drums', -1), ('Bass', 34), ('Guitar', 1), ('Strings', 41)]

In [None]:
prefix = "data/music/file"

for i in range(10):
    music_tensor = dataset[20+i][0]
    music = from_tensor_to_muspy(music_tensor, track_data)
    muspy.show_pianoroll(music, yticklabel='off', grid_axis='off')
    plt.savefig(prefix + str(i) + ".png")
    muspy.write_midi(prefix + str(i) + ".mid", music)

In [None]:
music

In [None]:
music_path = "data/music/file2.mid"
muspy.show_pianoroll(music, yticklabel='off', grid_axis='off')
plt.savefig('file2.png')
muspy.write_midi(music_path, music)

In [None]:
print(dataset[0][0].size())
notes = []
notes.append(muspy.Note(1, 48, 20, 64))
drums = muspy.Track(is_drum=True)
bass = muspy.Track(program=34, notes=notes)
guitar = muspy.Track(program=27, notes=[])
strings = muspy.Track(program=42, notes=[muspy.Note(0, 100, 4, 64), muspy.Note(4, 91, 20, 64)])

tracks = [drums, bass, guitar, strings]

meta = muspy.Metadata(title='prova')
music = muspy.Music(tracks=tracks, metadata=meta, resolution=32)

In [None]:
!ls data/lmd_matched/M/T/O/TRMTOBP128E07822EF/63edabc86c087f07eca448b0edad53c3.mid

# Stuff

next edges

In [None]:
import itertools

a = np.random.randint(2, size=(4,8))
a_t = a.transpose()
print(a_t)
inds = np.stack(np.where(a_t == 1)).transpose()
ts_acts = np.any(a_t, axis=1)
ts_inds = np.where(ts_acts)[0]

labels = np.arange(32).reshape(4, 8).transpose()
print(labels)

next_edges = []
for i in range(len(ts_inds)-1):
    ind_s = ts_inds[i]
    ind_e = ts_inds[i+1]
    s = inds[inds[:,0] == ind_s]
    e = inds[inds[:,0] == ind_e]
    e_inds = [t for t in list(itertools.product(s, e)) if t[0][1] != t[1][1]]
    edges = [(labels[tuple(e[0])],labels[tuple(e[1])], ind_e-ind_s) for e in e_inds]
    next_edges.extend(edges)

print(next_edges)
    

onset edges

In [None]:
onset_edges = []
print(a_t)
print(labels)

for i in ts_inds:
    ts_acts_inds = list(inds[inds[:,0] == i])
    if len(ts_acts_inds) < 2:
        continue
    e_inds = list(itertools.combinations(ts_acts_inds, 2))
    edges = [(labels[tuple(e[0])], labels[tuple(e[1])], 0) for e in e_inds]
    inv_edges = [(e[1], e[0], *e[2:]) for e in edges]
    onset_edges.extend(edges)
    onset_edges.extend(inv_edges)

print(onset_edges)


track edges

In [None]:
print(a_t)
print(labels)
track_edges = []

for track in range(a_t.shape[1]):
    tr_inds = list(inds[inds[:,1] == track])
    e_inds = [(tr_inds[i],
               tr_inds[i+1]) for i in range(len(tr_inds)-1)]
    print(e_inds)
    edges = [(labels[tuple(e[0])], labels[tuple(e[1])], e[1][0]-e[0][0]) for e in e_inds]
    track_edges.extend(edges)

print(track_edges)

In [None]:
track_edges = np.array(track_edges)
onset_edges = np.array(onset_edges)
np.concatenate((track_edges, onset_edges)).shape

In [None]:
pip install pypianoroll

In [None]:
import pypianoroll

In [None]:
multitrack = pypianoroll.read("tests_fur-elise.mid")
print(multitrack)

In [None]:
multitrack.tracks[0].pianoroll

In [None]:
multitrack.plot()

In [None]:
multitrack.trim(0, 12 * multitrack.resolution)
multitrack.binarize()

In [None]:
multitrack.plot()

In [None]:
multitrack.tracks[0].pianoroll.shape