<a href="https://colab.research.google.com/github/EmanueleCosenza/Polyphemus/blob/main/midi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Show the working directory: the relative data paths used below are resolved from it
!pwd

/home/cosenza/thesis/Polyphemus


In [2]:
# Show which git branch of the project is checked out
!git branch

  main
* sparse


Library installation (the commands below are commented out; uncomment them when running in a fresh environment)

In [3]:
#!tar -C data -xvzf data/lmd_matched.tar.gz

In [4]:
# Install the required music libraries
#!pip3 install muspy
#!pip3 install pypianoroll

In [5]:
# Install torch_geometric
#!v=$(python3 -c "import torch; print(torch.__version__)"); \
#pip3 install torch-scatter -f https://data.pyg.org/whl/torch-${v}.html; \
#pip3 install torch-sparse -f https://data.pyg.org/whl/torch-${v}.html; \
#pip3 install torch-geometric

Reproducibility

In [6]:
import numpy as np
import torch
import random
import os

def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), forces cuDNN to pick deterministic algorithms with its
    autotuner disabled, and exports ``PYTHONHASHSEED``.

    Note: ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it
    here only affects Python processes launched afterwards, not the string
    hashing of the running kernel.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    os.environ['PYTHONHASHSEED'] = str(seed)


seed = 42
set_seed(seed)

In [7]:
import os
import muspy
from itertools import product
import pypianoroll as pproll
import time
from tqdm.auto import tqdm


class MIDIPreprocessor:
    """Placeholder for a class-based version of the module-level
    ``preprocess_file`` / ``preprocess_dataset`` functions.

    Not implemented yet: every method is a no-op that returns None.
    """

    def __init__(self):
        # `self` was missing, which made `MIDIPreprocessor()` raise TypeError.
        pass

    def preprocess_dataset(self, dir, early_exit=None):
        """Not implemented yet; returns None."""
        pass

    def preprocess_file(self, f):
        """Not implemented yet; returns None."""
        pass


# Todo: to config file (or separate files)

# Token slots per timestep: slot 0 holds SOS, at most 14 slots hold notes and
# one slot is always kept free for EOS (see preprocess_file); the rest is PAD.
MAX_SIMU_NOTES = 16 # 14 + SOS and EOS

# Pitch vocabulary: MIDI pitches 0..127 followed by three special tokens.
PITCH_SOS = 128
PITCH_EOS = 129
PITCH_PAD = 130
# Duration vocabulary: durations are clipped to MAX_DUR and stored as
# (duration - 1), i.e. 0..95, followed by three special tokens.
DUR_SOS = 96
DUR_EOS = 97
DUR_PAD = 98

# Longest encodable note duration, in timesteps; longer notes are clipped.
MAX_DUR = 96

# Number of time steps per quarter note
# To get bar resolution -> RESOLUTION*4 (timesteps per 4/4 bar)
RESOLUTION = 8 
# Number of bars in every saved sample.
NUM_BARS = 1


def _encode_track(notes, length):
    """Encode the notes of one track as a token tensor plus onset activations.

    Args:
        notes: iterable of muspy notes (``.time``, ``.pitch``, ``.duration``
            are read); every ``note.time`` must be < ``length``.
        length: number of timesteps of the track.

    Returns:
        track_tensor: int16 array (length, MAX_SIMU_NOTES, 2). For each
            timestep the last axis holds (pitch, duration - 1): slot 0 is
            SOS, then the notes starting at that timestep, then one EOS, then
            PAD. Notes beyond MAX_SIMU_NOTES - 2 per timestep are dropped.
        activations: bool array (length,), True where at least one note
            starts (sustained notes do not count).
    """
    track_tensor = np.zeros((length, MAX_SIMU_NOTES, 2), np.int16)
    track_tensor[..., 0] = PITCH_PAD
    track_tensor[..., 1] = DUR_PAD
    track_tensor[:, 0, 0] = PITCH_SOS
    track_tensor[:, 0, 1] = DUR_SOS

    # Next free slot for every timestep (slot 0 is taken by SOS).
    # The counter reaches at most MAX_SIMU_NOTES - 1, so int8 requires
    # MAX_SIMU_NOTES <= 128.
    notes_counter = np.ones(length, dtype=np.int8)

    for note in notes:
        t = note.time
        slot = notes_counter[t]
        # Keep the last slot free for EOS: extra simultaneous notes are skipped.
        if slot >= MAX_SIMU_NOTES - 1:
            continue
        track_tensor[t, slot, 0] = note.pitch
        track_tensor[t, slot, 1] = min(MAX_DUR, note.duration) - 1
        notes_counter[t] += 1

    # End-of-sequence token right after the last note of every timestep.
    timesteps = np.arange(length)
    track_tensor[timesteps, notes_counter, 0] = PITCH_EOS
    track_tensor[timesteps, notes_counter, 1] = DUR_EOS

    return track_tensor, notes_counter > 1


def _has_consecutive_silent_bars(seq_acts, num_bars):
    """Return True if some track is silent in two consecutive bars.

    Args:
        seq_acts: bool array (num_tracks, num_bars * timesteps_per_bar).
        num_bars: number of bars in the sequence.
    """
    bars_active = seq_acts.reshape(seq_acts.shape[0], num_bars, -1).any(axis=2)
    silent = ~bars_active
    # Compare each bar with the next one *within the same track* (the previous
    # implementation pooled bar indices across tracks and mixed them up).
    return bool(np.any(silent[:, :-1] & silent[:, 1:]))


def _random_transpose(seq_tensor, min_shift=-5, max_shift=6):
    """Randomly transpose the pitched tracks of ``seq_tensor`` in place.

    Track 0 (drums, by the track order used in preprocess_file) and the
    SOS/EOS/PAD tokens are left untouched. The shift is drawn uniformly among
    the values in [min_shift, max_shift] that keep every transposed pitch in
    the MIDI range 0..127, so notes can neither become negative nor collide
    with the special tokens (PITCH_SOS..PITCH_PAD).
    """
    pitches = seq_tensor[1:, :, :, 0]   # basic slicing -> view of seq_tensor
    is_note = ((pitches != PITCH_PAD) & (pitches != PITCH_SOS)
               & (pitches != PITCH_EOS))

    shifts = np.arange(min_shift, max_shift + 1)
    if is_note.any():
        lowest = int(pitches[is_note].min())
        highest = int(pitches[is_note].max())
        shifts = shifts[(lowest + shifts >= 0) & (highest + shifts <= 127)]
    if shifts.size == 0:
        return

    shift = int(np.random.choice(shifts))
    pitches[is_note] += shift


def preprocess_file(filepath, dest_dir, num_samples):
    """Split one MIDI file into NUM_BARS-bar samples and save them as ``.npz``.

    The file is skipped (0 is returned) if it cannot be parsed, if any of its
    time signatures is not 4/4, or if it lacks a drum track, a "guitar" track
    (programs 0-31), a "bass" track (programs 32-39) or a "strings" track
    (every other program). All strings tracks are merged into one. Every
    (drum, bass, guitar) combination forms a 4-track subsong
    [drums, bass, guitar, strings], cut with a one-bar stride into windows of
    NUM_BARS bars. Windows that are completely silent (NUM_BARS == 1) or in
    which a track is silent for two consecutive bars (NUM_BARS > 1) are
    discarded; the others are randomly transposed and saved with
    ``seq_tensor`` (4, NUM_BARS*RESOLUTION*4, MAX_SIMU_NOTES, 2) and
    ``seq_acts`` (4, NUM_BARS*RESOLUTION*4).

    Args:
        filepath: path of the MIDI file.
        dest_dir: existing output directory.
        num_samples: index of the first sample; files are named
            ``<num_samples>.npz``, ``<num_samples + 1>.npz``, ...

    Returns:
        Number of samples saved for this file.
    """
    saved_samples = 0
    bar_len = RESOLUTION * 4         # timesteps in a 4/4 bar
    seq_len = NUM_BARS * bar_len     # timesteps in a saved sample

    print("Preprocessing file " + filepath)

    # Load the file both as a pypianoroll song and a muspy song
    # (Need to load both since muspy.to_pypianoroll() is expensive)
    try:
        pproll_song = pproll.read(filepath, resolution=RESOLUTION)
        muspy_song = muspy.read(filepath)
    except Exception:
        # Deliberate best-effort: unreadable files are skipped, not fatal.
        print("Song skipped (Invalid song format)")
        return 0

    # Only accept songs whose time signatures (if any) are all 4/4
    for t in muspy_song.time_signatures:
        if t.numerator != 4 or t.denominator != 4:
            print("Song skipped ({}/{} time signature)".
                            format(t.numerator, t.denominator))
            return 0

    # Gather tracks of pypianoroll song based on MIDI program number
    drum_tracks = []
    bass_tracks = []
    guitar_tracks = []
    strings_tracks = []

    for track in pproll_song.tracks:
        if track.is_drum:
            track.name = 'Drums'
            drum_tracks.append(track)
        elif 0 <= track.program <= 31:
            track.name = 'Guitar'
            guitar_tracks.append(track)
        elif 32 <= track.program <= 39:
            track.name = 'Bass'
            bass_tracks.append(track)
        else:
            # Tracks with program > 39 are all considered as strings tracks
            # and will be merged into a single track later on
            strings_tracks.append(track)

    if not drum_tracks or not guitar_tracks \
            or not bass_tracks or not strings_tracks:
        print("Song skipped (does not contain drum or "
                "guitar or bass or strings tracks)")
        return 0

    # Merge strings tracks into a single pypianoroll track
    strings = pproll.Multitrack(tracks=strings_tracks)
    strings_track = pproll.Track(pianoroll=strings.blend(mode='max'),
                                 program=48, name='Strings')

    # Single instruments can have multiple tracks.
    # Consider all possible combinations of drum, bass, and guitar tracks
    combinations = list(product(drum_tracks, bass_tracks, guitar_tracks))

    for comb_idx, (drum_track, bass_track, guitar_track) in enumerate(combinations):

        print("Processing combination", comb_idx+1, "of", len(combinations))

        # Track order matters: _random_transpose assumes drums are track 0.
        pproll_subsong = pproll.Multitrack(
            tracks=[drum_track, bass_track, guitar_track, strings_track],
            tempo=pproll_song.tempo,
            resolution=RESOLUTION
        )
        muspy_subsong = muspy.from_pypianoroll(pproll_subsong)
        tracks_notes = [track.notes for track in muspy_subsong.tracks]

        # Subsong length: latest note end over all tracks, plus one step,
        # rounded up to a whole number of bars.
        length = max((max(note.end for note in notes)
                      for notes in tracks_notes if notes), default=0) + 1
        length += (-length) % bar_len

        encoded = [_encode_track(notes, length) for notes in tracks_notes]
        # (#tracks x length x MAX_SIMU_NOTES x 2)
        subsong_tensor = np.stack([tensor for tensor, _ in encoded], axis=0)
        # (#tracks x length)
        subsong_activations = np.stack([acts for _, acts in encoded], axis=0)

        # Slide a NUM_BARS window along the time axis with a one-bar stride
        for start in range(0, length - seq_len + 1, bar_len):

            seq_tensor = subsong_tensor[:, start:start + seq_len].copy()
            seq_acts = subsong_activations[:, start:start + seq_len].copy()

            if NUM_BARS > 1:
                if _has_consecutive_silent_bars(seq_acts, NUM_BARS):
                    continue
            elif not seq_acts.any():
                # Single bar: skip it if all tracks are silent
                continue

            _random_transpose(seq_tensor)

            sample_filepath = os.path.join(dest_dir,
                                           str(num_samples + saved_samples))
            np.savez(sample_filepath, seq_tensor=seq_tensor, seq_acts=seq_acts)
            saved_samples += 1

    print("File preprocessing finished. Saved samples:", saved_samples)
    print()

    return saved_samples



# Total number of files: 116189
# Number of unique files: 45129
def preprocess_dataset(dataset_dir, dest_dir, num_files=45129, early_exit=None):
    """Preprocess every unique MIDI file found under ``dataset_dir``.

    Files are de-duplicated by file name, because the same song can appear in
    several sub-directories. Each new file is handed to ``preprocess_file``,
    which writes consecutively numbered ``.npz`` samples into ``dest_dir``.

    Args:
        dataset_dir: root directory, walked recursively.
        dest_dir: output directory (created if missing).
        num_files: expected number of unique files; only used as the progress
            bar total.
        early_exit: if not None, stop after this many unique files.

    Returns:
        Total number of saved samples.
    """
    files_dict = {}     # file name -> number of occurrences
    seen = 0
    tot_samples = 0
    not_filtered = 0
    finished = False

    # np.savez inside preprocess_file fails if the directory does not exist.
    os.makedirs(dest_dir, exist_ok=True)

    print("Starting preprocessing")

    total = early_exit if early_exit is not None else num_files
    start = time.time()

    # The context manager guarantees the progress bar is closed, even on errors.
    with tqdm(total=total) as progress_bar:

        # Visit recursively the directories inside the dataset directory
        for dirpath, dirs, files in os.walk(dataset_dir):

            # os.walk yields entries in arbitrary order: sorting directories
            # (in place, so the walk follows it) and files makes runs
            # reproducible and helps guess the remaining time.
            dirs.sort()

            print("Current path:", dirpath)
            print()

            for f in sorted(files):

                seen += 1

                if f in files_dict:
                    # Skip already seen file
                    files_dict[f] += 1
                    continue
                files_dict[f] = 1

                filepath = os.path.join(dirpath, f)
                n_saved = preprocess_file(filepath, dest_dir, tot_samples)

                tot_samples += n_saved
                if n_saved > 0:
                    not_filtered += 1

                progress_bar.update(1)

                print("Total number of seen files:", seen)
                print("Number of unique files:", len(files_dict))
                print("Total number of non filtered songs:", not_filtered)
                print("Total number of saved samples:", tot_samples)
                print()

                # Exit when a maximum number of files has been processed (if set)
                if early_exit is not None and len(files_dict) >= early_exit:
                    finished = True
                    break

            if finished:
                break

    end = time.time()
    hours, rem = divmod(end-start, 3600)
    minutes, seconds = divmod(rem, 60)
    print("Preprocessing completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
              .format(int(hours),int(minutes),seconds))

    return tot_samples


In [None]:
# Start from an empty output directory: samples are named 0.npz, 1.npz, ... from a
# running counter, so files left over from a previous run would otherwise remain
# next to (and be counted together with) the new ones.
!rm -rf data/preprocessed/
!mkdir data/preprocessed

In [8]:
# Input: a single sub-directory of lmd_matched (presumably the Lakh MIDI Dataset's
# matched subset) to keep the run short; output: where the .npz samples are written.
dataset_dir = 'data/lmd_matched/Y/G/'
dest_dir = 'data/preprocessed'

Preprocess the first 50 unique files, then check the preprocessed data:

In [None]:
# Quick run: stop after 50 unique files (drop early_exit to process the whole directory)
preprocess_dataset(dataset_dir, dest_dir, early_exit=50)

In [9]:
# Load one saved sample. The NpzFile is used as a context manager so its
# archive is closed right away; the arrays are copied into a plain dict
# first, so `data["seq_tensor"]` and `data["seq_acts"]` keep working below.
filepath = os.path.join(dest_dir, "5.npz")
with np.load(filepath) as npz_file:
    data = {key: npz_file[key] for key in npz_file.files}

In [10]:
# Shapes of the two arrays stored in the sample.
for array_name in ("seq_tensor", "seq_acts"):
    print(data[array_name].shape)

(4, 32, 16, 2)
(4, 32)


In [11]:
# Token slots of track 0 (drums, per the track order in preprocess_file) at timestep 1:
# slot 0 is SOS, then one (pitch, duration-1) pair per note onset, then EOS, then PAD.
data["seq_tensor"][0, 1]

array([[128,  96],
       [129,  97],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98],
       [130,  98]], dtype=int16)

# Model

In [12]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset
from torch_geometric.data import Data
import itertools


def unpackbits(x, num_bits):
    """Unpack the ``num_bits`` lowest bits of every element of an integer array.

    Args:
        x: integer-like numpy array of any shape.
        num_bits: number of low-order bits to extract.

    Returns:
        int array of shape ``x.shape + (num_bits,)``; entry ``[..., k]`` is
        bit ``k`` of the element (least-significant bit first).

    Raises:
        ValueError: if ``x`` has a floating-point dtype.
    """
    if np.issubdtype(x.dtype, np.floating):
        raise ValueError("numpy data type needs to be int-like")

    out_shape = list(x.shape) + [num_bits]
    column = x.reshape(-1, 1)
    # One power of two per bit position, computed in x's own dtype.
    powers = (2 ** np.arange(num_bits, dtype=x.dtype)).reshape(1, num_bits)
    bits = (column & powers).astype(bool).astype(int)
    return bits.reshape(out_shape)


class MIDIDataset(Dataset):
    """Token + graph dataset over the ``.npz`` samples written by ``preprocess_file``.

    Sample ``idx`` is read from ``<dir>/<idx>.npz`` and must contain
    ``seq_tensor`` (int tokens, shape (tracks, timesteps, notes, 2) with
    pitch/duration on the last axis) and ``seq_acts`` (bool activations,
    shape (tracks, timesteps)). Every active (track, timestep) cell becomes
    a graph node.
    """

    def __init__(self, dir):
        self.dir = dir

    def __len__(self):
        """Number of ``.npz`` sample files currently in ``self.dir``."""
        # Only sample files are counted: __getitem__ loads "<idx>.npz", so any
        # other file in the directory would make indices run past the samples.
        with os.scandir(self.dir) as entries:
            return sum(1 for entry in entries
                       if entry.is_file() and entry.name.endswith(".npz"))

    @staticmethod
    def _node_labels(acts):
        """Map every active (track, timestep) cell to its node id.

        Node ids follow the row-major order of ``np.where(acts == 1)`` (all
        active timesteps of track 0, then of track 1, ...), the same order
        used by ``_get_node_features``. The result is transposed to
        (timesteps, tracks) to match the indexing of the edge builders;
        inactive cells hold 0.
        """
        labels = np.zeros(acts.shape)
        acts_inds = np.where(acts == 1)
        labels[acts_inds] = np.arange(len(acts_inds[0]))
        return labels.transpose()

    def __get_track_edges(self, acts, edge_type_ind=0):
        """Edges between consecutive active timesteps of the same track.

        Returns an int64 array of rows (src, dst, edge_type_ind, timestep
        distance), or an empty array if there are none.
        """
        a_t = acts.transpose()
        # (timestep, track) index pairs of the active cells
        inds = np.stack(np.where(a_t == 1)).transpose()
        labels = self._node_labels(acts)

        track_edges = []

        for track in range(a_t.shape[1]):
            tr_inds = list(inds[inds[:,1] == track])
            e_inds = [(tr_inds[i],
                    tr_inds[i+1]) for i in range(len(tr_inds)-1)]
            edges = [(labels[tuple(e[0])], labels[tuple(e[1])], edge_type_ind, e[1][0]-e[0][0]) for e in e_inds]
            track_edges.extend(edges)

        # Explicit int64: dtype='long' is int32 on Windows.
        return np.array(track_edges, dtype=np.int64)

    def __get_onset_edges(self, acts, edge_type_ind=1):
        """Edges between every pair of nodes active at the same timestep.

        Both directions are added; the timestep distance is 0.
        """
        a_t = acts.transpose()
        inds = np.stack(np.where(a_t == 1)).transpose()
        ts_acts = np.any(a_t, axis=1)
        ts_inds = np.where(ts_acts)[0]
        labels = self._node_labels(acts)

        onset_edges = []

        for i in ts_inds:
            ts_acts_inds = list(inds[inds[:,0] == i])
            if len(ts_acts_inds) < 2:
                continue
            e_inds = list(itertools.combinations(ts_acts_inds, 2))
            edges = [(labels[tuple(e[0])], labels[tuple(e[1])], edge_type_ind, 0) for e in e_inds]
            inv_edges = [(e[1], e[0], *e[2:]) for e in edges]
            onset_edges.extend(edges)
            onset_edges.extend(inv_edges)

        return np.array(onset_edges, dtype=np.int64)

    def __get_next_edges(self, acts, edge_type_ind=2):
        """Edges from each node to the nodes of *other* tracks at the next active timestep.

        The timestep distance is the gap between the two active timesteps.
        """
        a_t = acts.transpose()
        inds = np.stack(np.where(a_t == 1)).transpose()
        ts_acts = np.any(a_t, axis=1)
        ts_inds = np.where(ts_acts)[0]
        labels = self._node_labels(acts)

        next_edges = []

        for i in range(len(ts_inds)-1):

            ind_s = ts_inds[i]
            ind_e = ts_inds[i+1]
            s = inds[inds[:,0] == ind_s]
            e = inds[inds[:,0] == ind_e]

            # Same-track pairs are already covered by the track edges.
            e_inds = [t for t in list(itertools.product(s, e)) if t[0][1] != t[1][1]]
            edges = [(labels[tuple(e[0])],labels[tuple(e[1])], edge_type_ind, ind_e-ind_s) for e in e_inds]

            next_edges.extend(edges)

        return np.array(next_edges, dtype=np.int64)

    def _get_node_features(self, acts, num_nodes):
        """One-hot track id of every node, shape (num_nodes, num_tracks)."""
        num_tracks = acts.shape[0]
        features = torch.zeros((num_nodes, num_tracks), dtype=torch.float)
        features[np.arange(num_nodes), np.stack(np.where(acts))[0]] = 1.

        return features

    def __getitem__(self, idx):
        """Return ``(seq_tensor, seq_acts, graph, src_mask)`` for sample ``idx``.

        - seq_tensor: float tensor (tracks, timesteps, notes, 131 + 99), the
          one-hot pitch and duration tokens concatenated.
        - seq_acts: float tensor (tracks, timesteps) of activations.
        - graph: torch_geometric ``Data`` with ``edge_index``, ``edge_attrs``
          (column 0: edge type, then a one-hot timestep distance),
          ``num_nodes`` and ``node_features``.
        - src_mask: bool tensor (tracks, timesteps, notes), True on PAD slots.
        """
        sample_path = os.path.join(self.dir, str(idx) + ".npz")
        # Context manager closes the underlying archive right away.
        with np.load(sample_path) as data:
            seq_tensor = data["seq_tensor"]
            seq_acts = data["seq_acts"]

        # Vocabulary sizes: PAD is the largest id of each vocabulary.
        n_pitch_tokens = PITCH_PAD + 1
        n_dur_tokens = DUR_PAD + 1

        # Construct src_key_padding_mask
        src_mask = torch.from_numpy(seq_tensor[..., 0] == PITCH_PAD)

        # From decimals to one-hot (pitch)
        pitches = seq_tensor[:, :, :, 0]
        onehot_p = np.zeros((pitches.size, n_pitch_tokens), dtype=float)
        onehot_p[np.arange(pitches.size), pitches.reshape(-1)] = 1.
        onehot_p = onehot_p.reshape(pitches.shape + (n_pitch_tokens,))

        # From decimals to one-hot (dur)
        durs = seq_tensor[:, :, :, 1]
        onehot_d = np.zeros((durs.size, n_dur_tokens), dtype=float)
        onehot_d[np.arange(durs.size), durs.reshape(-1)] = 1.
        onehot_d = onehot_d.reshape(durs.shape + (n_dur_tokens,))

        # Concatenate pitches and durations
        new_seq_tensor = np.concatenate((onehot_p, onehot_d), axis=-1)

        # Construct graph from boolean activations
        # Todo: optimize and refactor
        track_edges = self.__get_track_edges(seq_acts)
        onset_edges = self.__get_onset_edges(seq_acts)
        next_edges = self.__get_next_edges(seq_acts)
        non_empty = [e for e in (track_edges, onset_edges, next_edges)
                     if e.size > 0]

        if non_empty:
            # Concatenated edges (N x 4): src, dst, edge_type, timestep distance
            edge_list = torch.from_numpy(np.concatenate(non_empty))
            edge_index = edge_list[:, :2].t().contiguous()
            attrs = edge_list[:, 2:]
        else:
            # No edges: add fictitious self-edge
            edge_index = torch.LongTensor([[0], [0]])
            attrs = torch.Tensor([[0, 0]])

        # Column 0: edge type; columns 1..: one-hot timestep distance
        edge_attrs = torch.zeros(attrs.size(0), 1+seq_acts.shape[-1])
        edge_attrs[:, 0] = attrs[:, 0]
        edge_attrs[np.arange(edge_attrs.size(0)), attrs.long()[:, 1]+1] = 1

        # One node per active (track, timestep) cell; plain int for PyG.
        n = int(seq_acts.sum())
        node_features = self._get_node_features(seq_acts, n)
        graph = Data(edge_index=edge_index, edge_attrs=edge_attrs,
                     num_nodes=n, node_features=node_features)

        # Todo: start with torch at mount
        return torch.Tensor(new_seq_tensor), torch.Tensor(seq_acts), graph, src_mask


In [13]:
from typing import Optional, Union, Tuple
from torch_geometric.typing import OptTensor, Adj
from typing import Callable
from torch_geometric.nn.inits import reset

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter as Param
from torch.nn import Parameter
from torch_scatter import scatter
from torch_sparse import SparseTensor, matmul, masked_select_nnz
from torch_geometric.nn.conv import MessagePassing

from torch_geometric.nn.inits import glorot, zeros


# TorchScript overload declaration: dense edge index tensor input.
# (The `# type:` comments are read by TorchScript and must stay first in the body.)
@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (Tensor, Tensor) -> Tensor
    pass


# TorchScript overload declaration: torch_sparse SparseTensor input.
@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (SparseTensor, Tensor) -> SparseTensor
    pass


def masked_edge_index(edge_index, edge_mask):
    """Keep only the edges selected by the boolean ``edge_mask``.

    A dense ``edge_index`` tensor has its columns filtered (presumably the
    PyG [2, num_edges] layout); a ``SparseTensor`` is filtered with
    ``masked_select_nnz`` in COO layout.
    """
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]
    else:
        return masked_select_nnz(edge_index, edge_mask, layout='coo')

def masked_edge_attrs(edge_attrs, edge_mask):
    """Return the rows of ``edge_attrs`` whose ``edge_mask`` entry is True.

    ``edge_attrs`` must be at least 2-D (one row per edge); row order is kept.
    """
    kept_rows = edge_attrs[edge_mask, :]
    return kept_rows


class RGCNConv(MessagePassing):
    r"""Relational graph convolution with edge-conditioned messages.

    Adapted from PyTorch Geometric's ``RGCNConv`` (the relational graph
    convolutional operator from the `"Modeling Relational Data with Graph
    Convolutional Networks" <https://arxiv.org/abs/1703.06103>`_ paper).
    The difference is in :meth:`message`: every neighbour feature
    :math:`\mathbf{x}_j` is first multiplied by a square matrix produced by
    ``nn`` from the edge features :math:`\mathbf{e}_{j,i}`, and only then by
    the relation weight. With float inputs, no decomposition and the default
    ``aggr='mean'`` the layer computes (row-vector convention)

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{x}_i \mathbf{\Theta}_{\textrm{root}}
        + \sum_{r \in \mathcal{R}} \frac{1}{|\mathcal{N}_r(i)|}
        \sum_{j \in \mathcal{N}_r(i)} \mathbf{x}_j \,
        h_{\mathbf{\Theta}}(\mathbf{e}_{j,i}) \, \mathbf{\Theta}_r
        + \mathbf{b},

    where :math:`\mathcal{R}` is the set of relations (edge types). Each
    relation is processed separately, which is memory-efficient but may give
    low GPU utilisation with many relations.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
            If no input features are given, this should be the number of
            nodes in the graph.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
            maps edge features of shape :obj:`[-1, num_edge_features]` to
            shape :obj:`[-1, in_channels_l * in_channels_l]` (reshaped to a
            square matrix per edge in :meth:`message`), where
            ``in_channels_l`` is the source input size.
        num_bases (int, optional): If not :obj:`None`, use basis
            decomposition with :obj:`num_bases` bases. (default: :obj:`None`)
        num_blocks (int, optional): If not :obj:`None`, use
            block-diagonal decomposition with :obj:`num_blocks` blocks.
            (default: :obj:`None`)
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"mean"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Notes:
        If ``x`` is ``None`` or a ``torch.long`` index tensor, rows of the
        relation weight (width ``out_channels``) are used as input features.
        The edge-conditioned message then only works when
        ``in_channels == out_channels``. With a ``SparseTensor`` edge index,
        PyG may use :meth:`message_and_aggregate`, which ignores the edge
        features.
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        nn: Callable,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        aggr: str = 'mean',
        root_weight: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(aggr=aggr, node_dim=0, **kwargs)

        if num_bases is not None and num_blocks is not None:
            raise ValueError('Can not apply both basis-decomposition and '
                             'block-diagonal-decomposition at the same time.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nn = nn
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.num_blocks = num_blocks

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        self.in_channels_l = in_channels[0]

        if num_bases is not None:
            self.weight = Parameter(
                torch.Tensor(num_bases, in_channels[0], out_channels))
            self.comp = Parameter(torch.Tensor(num_relations, num_bases))

        elif num_blocks is not None:
            assert (in_channels[0] % num_blocks == 0
                    and out_channels % num_blocks == 0)
            self.weight = Parameter(
                torch.Tensor(num_relations, num_blocks,
                             in_channels[0] // num_blocks,
                             out_channels // num_blocks))
            self.register_parameter('comp', None)

        else:
            self.weight = Parameter(
                torch.Tensor(num_relations, in_channels[0], out_channels))
            self.register_parameter('comp', None)

        if root_weight:
            self.root = Param(torch.Tensor(in_channels[1], out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Param(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weights (Glorot), the edge network, and zero the bias."""
        glorot(self.weight)
        reset(self.nn)
        glorot(self.comp)
        glorot(self.root)
        zeros(self.bias)

    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
                edge_index: Adj, edge_type: OptTensor = None,
                edge_attr: OptTensor = None):
        r"""
        Args:
            x: The input node features. Can be either a :obj:`[num_nodes,
                in_channels]` node feature matrix, or an optional
                one-dimensional node index tensor (in which case input features
                are treated as trainable node embeddings).
                Furthermore, :obj:`x` can be of type :obj:`tuple` denoting
                source and destination node features.
            edge_index: Dense edge index or :class:`torch_sparse.SparseTensor`.
            edge_type: The one-dimensional relation type/index for each edge in
                :obj:`edge_index` (compared with ``== r``, so float ids work).
                Should be only :obj:`None` in case :obj:`edge_index` is of type
                :class:`torch_sparse.tensor.SparseTensor`.
                (default: :obj:`None`)
            edge_attr: Edge features, one row per edge, fed to ``nn``.
                (default: :obj:`None`)
        """

        # Convert input features to a pair of node features or node indices.
        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x
        if x_l is None:
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))

        if isinstance(edge_index, SparseTensor):
            edge_type = edge_index.storage.value()
        assert edge_type is not None

        # propagate_type: (x: Tensor)
        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        if self.num_bases is not None:  # Basis-decomposition =================
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels)

        if self.num_blocks is not None:  # Block-diagonal-decomposition =====

            if x_l.dtype == torch.long:
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')

            for i in range(self.num_relations):
                mask = edge_type == i
                tmp = masked_edge_index(edge_index, mask)
                attr = (None if edge_attr is None
                        else masked_edge_attrs(edge_attr, mask))
                # `message` needs edge_attr; it was previously not passed here,
                # so this branch always failed inside propagate.
                h = self.propagate(tmp, x=x_l, size=size, edge_attr=attr)
                h = h.view(-1, weight.size(1), weight.size(2))
                h = torch.einsum('abc,bcd->abd', h, weight[i])
                out += h.contiguous().view(-1, self.out_channels)

        else:  # No regularization/Basis-decomposition ========================
            for i in range(self.num_relations):
                mask = edge_type == i
                tmp = masked_edge_index(edge_index, mask)
                # Edge features are optional (e.g. the fused SparseTensor path).
                attr = (None if edge_attr is None
                        else masked_edge_attrs(edge_attr, mask))

                if x_l.dtype == torch.long:
                    out += self.propagate(tmp, x=weight[i, x_l], size=size,
                                          edge_attr=attr)
                else:
                    h = self.propagate(tmp, x=x_l, size=size,
                                       edge_attr=attr)
                    out = out + (h @ weight[i])

        root = self.root
        if root is not None:
            out += root[x_r] if x_r.dtype == torch.long else x_r @ root

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Multiply every source feature row by its edge's (in x in) matrix from ``nn``."""
        weight = self.nn(edge_attr)
        weight = weight.view(-1, self.in_channels_l, self.in_channels_l)
        return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Fused sparse path: plain aggregation, edge features are not used."""
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_relations={self.num_relations})')

In [14]:
import torch
from torch import nn, Tensor
from torch_geometric.nn.conv import GCNConv#, RGCNConv
import torch.nn.functional as F
import math
import torch.optim as optim
from torch_scatter import scatter_mean


# Todo: check and think about max_len
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (as in "Attention Is All You Need") plus dropout.

    Args:
        d_model: embedding size. Odd sizes are supported: the extra
            even-index channel only receives a sine term.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence that can be encoded.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 256):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *                     \
                             (-math.log(10000.0)/d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position*div_term)
        # There are d_model // 2 odd channels: with an odd d_model this is one
        # fewer than the sine channels, which used to raise a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(position*div_term[:d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim], with
               seq_len <= max_len
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    

class MLP(nn.Module):
    """Two-layer perceptron applied to flattened inputs.

    Every dimension after the first (batch) one is flattened, then
    Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, output_dim)
    is applied.
    """

    def __init__(self, input_dim=256, hidden_dim=256, output_dim=256):
        super().__init__()
        modules = [
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        out = self.layers(x)
        return out


class GraphEncoder(nn.Module):
    """Stack of RGCNConv layers with dropout before and ReLU after each layer.

    Each layer owns a single ``nn.Linear`` edge network mapping
    ``edge_features_dim`` edge features to ``d_in * d_in`` values, where
    ``d_in`` is that layer's input size (``input_dim`` for the first layer,
    ``dim_hidden`` afterwards). At least one layer is always built.

    Args:
        input_dim: node feature size of ``data.x``.
        dim_hidden: output size of every layer.
        num_layers: number of RGCNConv layers.
        num_relations: number of edge (relation) types.
        edge_features_dim: number of continuous edge features per edge.
        dropout: dropout probability applied to node features before each layer.
    """

    def __init__(self, input_dim=256, dim_hidden=256, num_layers=3, num_relations=3,
                 edge_features_dim=32, dropout=0.1):
        super().__init__()
        # Input size of each layer; max(num_layers, 1) layers in total.
        layer_in_dims = [input_dim] + [dim_hidden] * (num_layers - 1)
        self.layers = nn.ModuleList([
            RGCNConv(d_in, dim_hidden, num_relations,
                     nn.Linear(edge_features_dim, d_in * d_in))
            for d_in in layer_in_dims
        ])
        self.p = dropout

    def forward(self, data):
        """Run every layer over a (batched) graph and return node encodings.

        ``data`` must expose ``x``, ``edge_index`` and ``edge_attrs``; column 0
        of ``edge_attrs`` is passed as the edge type and the remaining columns
        as edge features.
        """
        edge_type = data.edge_attrs[:, 0]
        edge_attr = data.edge_attrs[:, 1:]

        h = data.x
        for conv in self.layers:
            h = F.dropout(h, p=self.p, training=self.training)
            h = F.relu(conv(h, data.edge_index, edge_type, edge_attr))
        return h


class Encoder(nn.Module):
    """VAE encoder: per-sequence Transformer encoder followed by a graph encoder.

    Every non-silent token sequence is embedded, encoded by a Transformer and
    mean-pooled into one vector. These vectors, concatenated with the per-node
    track features, become the node features of the graph. A GNN encodes the
    graph, node encodings are averaged per graph, and two linear heads produce
    ``mu`` and ``log_var``.

    Args:
        d_token: size of each input token vector.
        d_transf: Transformer model size (also the size of the pooled sequence
            encodings fed to the graph encoder).
        nhead_transf: number of attention heads.
        num_layers_transf: number of Transformer encoder layers.
        n_tracks: size of the per-node track features (``x_graph.node_features``).
        dropout: dropout probability used throughout.
        d_graph: hidden/output size of the graph encoder.
        d_z: latent size (size of ``mu`` and ``log_var``).
    """

    def __init__(self, d_token=230, d_transf=256, nhead_transf=2,
                 num_layers_transf=2, n_tracks=4, dropout=0.1,
                 d_graph=256, d_z=256):
        super().__init__()

        # Todo: one separate encoder for drums
        # Transformer Encoder
        self.embedding = nn.Linear(d_token, d_transf)
        self.pos_encoder = PositionalEncoding(d_transf, dropout=dropout)
        transf_layer = nn.TransformerEncoderLayer(
            d_model=d_transf,
            nhead=nhead_transf,
            dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(
            transf_layer,
            num_layers=num_layers_transf
        )

        # Graph encoder. Node features are [track features | pooled sequence
        # encoding], i.e. n_tracks + d_transf wide (previously hard-coded to
        # n_tracks + 256, which broke any non-default d_transf).
        self.graph_encoder = GraphEncoder(dropout=dropout,
                                          input_dim=n_tracks + d_transf,
                                          dim_hidden=d_graph)

        # Linear layers that compute the final mu and log_var
        self.linear_mu = nn.Linear(d_graph, d_z)
        self.linear_log_var = nn.Linear(d_graph, d_z)

    def forward(self, x_seq, x_acts, x_graph, src_mask):
        """Encode a batch into the parameters of the approximate posterior.

        Args:
            x_seq: token sequences; all leading dims are flattened, the last
                two are (seq_len, d_token).
            x_acts: activations with one entry per flattened sequence; nonzero
                entries select the sequences that are encoded.
            x_graph: batched graph exposing ``node_features``, ``edge_index``,
                ``edge_attrs`` and ``batch``. Its ``x`` attribute is
                overwritten here.
            src_mask: key padding mask with one row per flattened sequence.

        Returns:
            (mu, log_var): one row per graph index in ``x_graph.batch``.
        """
        # Collapse track (and optionally batch) dimension
        x_seq = x_seq.view(-1, x_seq.size(-2), x_seq.size(-1))

        # Filter silences: keep only sequences whose activation is nonzero
        active = x_acts.view(-1).bool()
        x_seq = x_seq[active]
        src_mask = src_mask[active]

        # Embed and switch to sequence-first layout (batch_first=False)
        embs = self.embedding(x_seq).permute(1, 0, 2)
        pos_encs = self.pos_encoder(embs)

        transformer_encs = self.transformer_encoder(pos_encs,
                                                    src_key_padding_mask=src_mask)

        # NOTE(review): this mean also covers padded positions; if src_mask is
        # a boolean "True = padding" mask, a masked mean may be intended — confirm.
        pooled_encs = torch.mean(transformer_encs, 0)

        # Concatenate track one hot features with chord encodings
        # and compute node encodings
        x_graph.x = torch.cat((x_graph.node_features, pooled_encs), 1)
        node_encs = self.graph_encoder(x_graph)

        # One vector per graph: average the node encodings of each graph
        encoding = scatter_mean(node_encs, x_graph.batch, dim=0)

        # Compute mu and log(std^2)
        mu = self.linear_mu(encoding)
        log_var = self.linear_log_var(encoding)

        return mu, log_var


class Decoder(nn.Module):
    """VAE decoder: activation MLP, graph decoder and Transformer decoder.

    From the latent ``z`` it predicts:

    (a) activation logits, with an MLP;
    (b) per-node decodings, with a GNN.

    The node decodings serve as the memory of a Transformer decoder. Using
    ``x_seq`` as its input, that decoder produces token logits for the
    non-silent sequences.

    Args:
        d_z: latent size.
        n_tracks: size of the per-node track features; together with
            ``resolution`` it also sets the activation output size
            (``n_tracks * resolution`` logits per latent vector).
        resolution: see ``n_tracks``.
        d_token: size of each token vector (input and output).
        d_model: Transformer decoder model size; also the graph decoder's
            output size, since its outputs are the decoder memory.
        d_transf: token embedding size (must equal ``d_model``).
        nhead_transf: number of attention heads.
        num_layers_transf: number of Transformer decoder layers.
        dropout: dropout probability used throughout.
    """

    def __init__(self, d_z=256, n_tracks=4, resolution=32, d_token=230, d_model=256,
                 d_transf=256, nhead_transf=2, num_layers_transf=2, dropout=0.1):
        super().__init__()

        # Boolean activations decoder (CNN/MLP)
        self.acts_decoder = MLP(d_z, d_model, n_tracks*resolution)

        # GNN. Node features are [track features | z] (n_tracks + d_z wide,
        # previously hard-coded to n_tracks + 256); its output is the
        # Transformer memory, so it must be d_model wide.
        self.graph_decoder = GraphEncoder(dropout=dropout,
                                          input_dim=n_tracks + d_z,
                                          dim_hidden=d_model)

        # Transformer Decoder
        self.embedding = nn.Linear(d_token, d_transf)
        self.pos_encoder = PositionalEncoding(d_transf, dropout=dropout)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead_transf,
            dropout=dropout
        )
        self.transf_decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_layers_transf
        )

        # Last linear layer
        self.lin = nn.Linear(d_model, d_token)

    def forward(self, z, x_seq, x_acts, x_graph, src_mask, tgt_mask):
        """Decode ``z`` into sequence logits and activation logits.

        Args:
            z: latent vectors, one row per graph in ``x_graph``.
            x_seq: decoder input tokens; leading dims are flattened and
                filtered by ``x_acts`` as in the encoder.
            x_acts: activations; ``acts_out`` is reshaped to this shape.
            x_graph: batched graph; its ``x`` attribute is overwritten here.
            src_mask: target key padding mask, one row per flattened sequence.
            tgt_mask: causal attention mask for the target sequence.

        Returns:
            (seq_out, acts_out): token logits for the non-silent sequences
            (shaped like the filtered ``x_seq``) and activation logits
            (shaped like ``x_acts``).
        """
        # Compute activations from z
        acts_out = self.acts_decoder(z).view(x_acts.size())

        # Initialize node features with z: every node gets its graph's latent
        # vector. bincount (unlike torch.unique) also counts graphs with no
        # nodes, so the counts stay aligned with the rows of z.
        counts = torch.bincount(x_graph.batch, minlength=z.size(0))
        z_node_features = torch.repeat_interleave(z, counts, dim=0)

        # Add one-hot encoding of tracks and propagate with the GNN
        # Todo: use also edge info
        x_graph.x = torch.cat((x_graph.node_features, z_node_features), 1)
        node_decs = self.graph_decoder(x_graph)

        # Prepare transformer memory: 16 identical copies of each node
        # decoding along the sequence axis. The node axis acts as the batch
        # axis, which assumes one graph node per non-silent sequence — TODO confirm.
        node_decs = node_decs.repeat(16, 1, 1)

        # Filter silences (same selection as in the encoder)
        x_seq = x_seq.view(-1, x_seq.size(-2), x_seq.size(-1))
        active = x_acts.view(-1).bool()
        x_seq = x_seq[active]
        src_mask = src_mask[active]

        # Todo: same embeddings as encoder?
        embs = self.embedding(x_seq).permute(1, 0, 2)  # sequence-first
        pos_encs = self.pos_encoder(embs)

        seq_out = self.transf_decoder(pos_encs, node_decs,
                                      tgt_key_padding_mask=src_mask,
                                      tgt_mask=tgt_mask)

        # Raw logits over the whole token vector (the trainer slices the first
        # 131 channels as pitch classes and the rest as duration classes);
        # no softmax/sigmoid is applied here.
        seq_out = self.lin(seq_out)

        # Back to batch-first, shaped like the filtered input
        seq_out = seq_out.permute(1, 0, 2).view(x_seq.size())

        return seq_out, acts_out


class VAE(nn.Module):
    """Variational autoencoder combining ``Encoder`` and ``Decoder``.

    Only ``dropout`` is forwarded to the sub-modules; any extra keyword
    arguments are accepted and ignored.
    """

    def __init__(self, dropout=0.1, **kwargs):
        super().__init__()
        self.encoder = Encoder(dropout=dropout)
        self.decoder = Decoder(dropout=dropout)

    def forward(self, x_seq, x_acts, x_graph, src_mask, tgt_mask):
        """Encode, sample ``z`` with the reparameterization trick, and decode.

        The decoder receives every token but the last (``x_seq[..., :-1, :]``)
        with the padding mask trimmed to match.

        Returns:
            (decoder_output, mu, log_var)
        """
        # One padding-mask row per sequence, whatever the leading dims are
        flat_mask = src_mask.view(-1, src_mask.size(-1))

        mu, log_var = self.encoder(x_seq, x_acts, x_graph, flat_mask)

        # z = mu + eps * std, with eps ~ N(0, I) and std = exp(log_var / 2)
        std = (0.5 * log_var).exp()
        z = mu + torch.randn_like(std) * std

        out = self.decoder(z, x_seq[..., :-1, :], x_acts, x_graph,
                           flat_mask[:, :-1], tgt_mask)
        return out, mu, log_var


Trainer

In [24]:
import torch.optim as optim
import matplotlib.pyplot as plt
import uuid
import copy
import time
from statistics import mean
from collections import defaultdict


def generate_square_subsequent_mask(sz):
    """Build a causal attention mask of shape (sz, sz).

    Entries strictly above the main diagonal are -inf (future positions are
    blocked); the diagonal and everything below it are 0.
    """
    mask = torch.full((sz, sz), float('-inf'))
    return mask.triu(diagonal=1)


def append_dict(dest_d, source_d):
    """Append each value of ``source_d`` to the list under the same key in ``dest_d``.

    ``dest_d`` is expected to create missing lists itself (e.g. a
    ``defaultdict(list)``).
    """
    for key in source_d:
        dest_d[key].append(source_d[key])


class VAETrainer():
    """Training / evaluation driver for the VAE.

    Records per-batch training losses, accuracies, learning rates and
    timestamps. It validates every ``eval_every`` batches (saving the best
    model as 'best_model') and checkpoints every ``save_every`` batches (as
    'checkpoint'), both inside ``models_dir/name``.

    Args:
        model: called as ``model(x_seq, x_acts, x_graph, src_mask, tgt_mask)``
            and returning ``((seq_rec, acts_rec), mu, log_var)``.
        models_dir: parent directory of the run folder.
        optimizer: optimizer over the model's parameters.
        init_lr: learning rate recorded in ``lrs`` when no scheduler is set.
        name: run folder name; a random UUID if None. The folder must not
            already exist.
        lr_scheduler: optional scheduler, stepped after every optimizer step.
            Its current lr is read from an ``lr`` attribute when present,
            otherwise from ``get_last_lr()`` (torch schedulers).
        device: device every batch is moved to.
        print_every: print averaged stats every this many batches.
        save_every: checkpoint every this many batches (disabled if <= 0).
        eval_every: validate every this many batches.
    """

    # Token layout shared by losses and accuracies: channels [0, 131) are
    # pitch classes, the remaining ones duration classes. The padding class
    # index inside each slice is ignored.
    _N_PITCH = 131
    _PITCH_PAD = 130
    _DUR_PAD = 98

    def __init__(self, model, models_dir, optimizer, init_lr,
                 name=None, lr_scheduler=None, device=torch.device("cuda"),
                 print_every=1, save_every=1, eval_every=100):

        self.model = model
        self.models_dir = models_dir
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.name = name if name is not None else str(uuid.uuid4())
        self.lr_scheduler = lr_scheduler
        self.device = device
        self.print_every = print_every
        self.save_every = save_every
        self.eval_every = eval_every

        # Create model dir, raise error if it already exists
        self.model_dir = os.path.join(self.models_dir, self.name)
        os.makedirs(self.model_dir, exist_ok=False)

        # Criteria with ignored padding
        self.bce_unreduced = nn.BCEWithLogitsLoss(reduction='none')
        self.ce_p = nn.CrossEntropyLoss(ignore_index=self._PITCH_PAD)
        self.ce_d = nn.CrossEntropyLoss(ignore_index=self._DUR_PAD)

        # Training stats
        self.tr_losses = defaultdict(list)
        self.tr_accuracies = defaultdict(list)
        self.val_losses = defaultdict(list)
        self.val_accuracies = defaultdict(list)
        self.lrs = []
        self.times = []

    def train(self, trainloader, validloader=None, epochs=1,
              early_exit=None):
        """Train for ``epochs`` passes over ``trainloader``.

        Args:
            trainloader: sized iterable of ``(x_seq, x_acts, x_graph, src_mask)``.
            validloader: optional loader evaluated every ``eval_every`` batches;
                the model with the lowest validation 'tot' loss is saved as
                'best_model'.
            epochs: number of passes over ``trainloader``.
            early_exit: if set (positive), stop the whole run after this many
                batches in total, across all epochs.
        """
        self.beta = 0 # Todo: _update_params()
        min_val_loss = np.inf

        self.model.train()

        print("Starting training.\n")

        start = time.time()
        self.times.append(start)

        tot_batches = 0
        stop = False

        # A single bar for the whole run (it was previously sized for one epoch)
        progress_bar = tqdm(total=len(trainloader) * epochs)

        for epoch in range(epochs):

            self.cur_epoch = epoch

            for batch_idx, inputs in enumerate(trainloader):

                self.cur_batch_idx = batch_idx

                # Zero out the gradients
                self.optimizer.zero_grad()

                x_seq, x_acts, x_graph, src_mask, tgt_mask = self._to_device(inputs)
                inputs = (x_seq, x_acts, x_graph)

                # Forward pass, get the reconstructions
                outputs, mu, log_var = self.model(x_seq, x_acts, x_graph,
                                                  src_mask, tgt_mask)

                # Compute the backprop loss and other required losses
                tot_loss, losses = self._compute_losses(inputs, outputs, mu,
                                                        log_var)

                # Backprop and update lr
                tot_loss.backward()
                self.optimizer.step()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

                # Update the stats
                append_dict(self.tr_losses, losses)
                self.lrs.append(self._current_lr())
                accs = self._compute_accuracies(inputs, outputs)
                append_dict(self.tr_accuracies, accs)
                self.times.append(time.time())

                tot_batches += 1

                # Print stats
                if tot_batches % self.print_every == 0:
                    print("Training on batch {}/{} of epoch {}/{} complete."
                          .format(batch_idx+1, len(trainloader), epoch+1, epochs))
                    self._print_stats()
                    print("\n----------------------------------------\n")

                # Evaluate on the validation set every eval_every gradient updates
                if validloader is not None and tot_batches % self.eval_every == 0:
                    min_val_loss = self._validate(validloader, min_val_loss)

                progress_bar.update(1)

                # When appropriate, save model and stats on disk
                if self.save_every > 0 and tot_batches % self.save_every == 0:
                    print("\nSaving model to disk...\n")
                    self._save_model('checkpoint')

                # Stop the whole run (not just this epoch) once early_exit
                # batches have been processed
                if early_exit is not None and tot_batches >= early_exit:
                    stop = True
                    break

            if stop:
                break

        progress_bar.close()

        end = time.time()
        # Todo: self.__print_time()
        hours, rem = divmod(end-start, 3600)
        minutes, seconds = divmod(rem, 60)
        print("Training completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
              .format(int(hours), int(minutes), seconds))

        print("Saving model to disk...")
        self._save_model('checkpoint')

        print("Model saved.")

    def evaluate(self, loader):
        """Average losses and accuracies of the model over ``loader``.

        Switches the model to eval mode (callers that keep training must switch
        it back) and runs without gradients.

        Returns:
            (avg_losses, avg_accs): dicts mapping each stat to its batch mean.
        """
        losses = defaultdict(list)
        accs = defaultdict(list)

        self.model.eval()
        progress_bar = tqdm(total=len(loader))

        with torch.no_grad():
            for inputs in loader:

                x_seq, x_acts, x_graph, src_mask, tgt_mask = self._to_device(inputs)
                inputs = (x_seq, x_acts, x_graph)

                # Forward pass, get the reconstructions
                outputs, mu, log_var = self.model(x_seq, x_acts, x_graph,
                                                  src_mask, tgt_mask)

                # Compute losses and accuracies wrt batch
                _, losses_b = self._compute_losses(inputs, outputs, mu, log_var)
                append_dict(losses, losses_b)
                append_dict(accs, self._compute_accuracies(inputs, outputs))

                progress_bar.update(1)

        progress_bar.close()

        avg_losses = {k: mean(l) for k, l in losses.items()}
        avg_accs = {k: mean(l) for k, l in accs.items()}

        return avg_losses, avg_accs

    def _to_device(self, inputs):
        """Move a ``(x_seq, x_acts, x_graph, src_mask)`` batch to ``self.device``.

        Also builds the causal ``tgt_mask`` for the decoder, sized for the
        decoder input (sequence length minus one).
        """
        x_seq, x_acts, x_graph, src_mask = inputs
        x_seq = x_seq.float().to(self.device)
        x_acts = x_acts.to(self.device)
        x_graph = x_graph.to(self.device)
        src_mask = src_mask.to(self.device)
        tgt_mask = generate_square_subsequent_mask(x_seq.size(-2)-1).to(self.device)
        return x_seq, x_acts, x_graph, src_mask, tgt_mask

    def _current_lr(self):
        """Learning rate to record for the batch just processed."""
        if self.lr_scheduler is None:
            return self.init_lr
        if hasattr(self.lr_scheduler, 'lr'):
            return self.lr_scheduler.lr
        # torch.optim.lr_scheduler schedulers expose get_last_lr() instead
        return self.lr_scheduler.get_last_lr()[0]

    def _validate(self, validloader, min_val_loss):
        """Evaluate on ``validloader``, record the stats and save the best model.

        Returns the (possibly updated) minimum validation 'tot' loss.
        """
        print("\nEvaluating on validation set...\n")
        val_losses, val_accuracies = self.evaluate(validloader)
        append_dict(self.val_losses, val_losses)
        append_dict(self.val_accuracies, val_accuracies)

        print("Val losses:")
        print(val_losses)
        print("Val accuracies:")
        print(val_accuracies)

        # Save model if val loss (tot) reached a new minimum
        val_tot_loss = val_losses['tot']
        if val_tot_loss < min_val_loss:
            print("\nValidation loss improved.")
            print("Saving new best model to disk...\n")
            self._save_model('best_model')
            min_val_loss = val_tot_loss

        # evaluate() left the model in eval mode
        self.model.train()
        return min_val_loss

    def _compute_losses(self, inputs, outputs, mu, log_var):
        """Compute the training loss and a dict of scalar loss stats.

        Args:
            inputs: ``(x_seq, x_acts, _)`` already on the device.
            outputs: ``(seq_rec, acts_rec)`` from the model; ``seq_rec`` holds
                logits for the non-silent sequences only.
            mu, log_var: posterior parameters from the encoder.

        Returns:
            (tot_loss, losses): the differentiable total loss and a dict of
            floats with keys 'tot', 'pitches', 'dur', 'acts', 'rec', 'kld'
            and 'beta*kld'.
        """
        x_seq, x_acts, _ = inputs
        seq_rec, acts_rec = outputs

        # Targets are the inputs shifted by one step; keep only the
        # non-silent sequences, matching what the decoder outputs.
        x_seq = x_seq[..., 1:, :]
        x_seq = x_seq[x_acts.bool()]

        # Activations: element-wise BCE, weighted 0.9 for active (== 1) and
        # 0.1 for silent (== 0) entries. zeros_like keeps the weights on the
        # loss's device (previously a notebook-global `device` was used).
        flat_acts = x_acts.view(-1)
        acts_loss = self.bce_unreduced(acts_rec.view(-1), flat_acts.float())
        weights = torch.zeros_like(acts_loss)
        weights[flat_acts == 1] = 0.9
        weights[flat_acts == 0] = 0.1
        acts_loss = torch.mean(weights*acts_loss)

        # Pitch / duration: cross-entropy on the two slices of the token
        # vector, one-hot targets converted to class indices; padding ignored.
        logits = seq_rec.reshape(-1, seq_rec.size(-1))
        targets = x_seq.reshape(-1, x_seq.size(-1))
        n_p = self._N_PITCH
        pitches_loss = self.ce_p(logits[:, :n_p], targets[:, :n_p].argmax(dim=1))
        dur_loss = self.ce_d(logits[:, n_p:], targets[:, n_p:].argmax(dim=1))

        # KL(N(mu, exp(log_var)) || N(0, I)), summed over batch and dims
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        rec_loss = pitches_loss + dur_loss + acts_loss
        tot_loss = rec_loss + self.beta*kld_loss

        losses = {
            'tot': tot_loss.item(),
            'pitches': pitches_loss.item(),
            'dur': dur_loss.item(),
            'acts': acts_loss.item(),
            'rec': rec_loss.item(),
            'kld': kld_loss.item(),
            'beta*kld': self.beta*kld_loss.item()
        }

        return tot_loss, losses

    @torch.no_grad()
    def _compute_accuracies(self, inputs, outputs):
        """Note, pitch, duration and activation accuracies as Python floats."""
        x_seq, x_acts, _ = inputs
        seq_rec, acts_rec = outputs

        # Shift targets like in _compute_losses and filter silences
        x_seq = x_seq[..., 1:, :]
        x_seq = x_seq[x_acts.bool()]

        accs = {
            'notes': self._note_accuracy(seq_rec, x_seq).item(),
            'pitches': self._pitches_accuracy(seq_rec, x_seq).item(),
            'dur': self._dur_accuracy(seq_rec, x_seq).item(),
            'acts': self._acts_accuracy(acts_rec, x_acts).item()
        }

        return accs

    @staticmethod
    def _class_hits(logits, one_hot, pad_idx):
        """Return (correct & non-padding, non-padding) boolean masks.

        Predicted / true classes are the argmax over the last dimension of
        ``logits`` / the one-hot ``one_hot``; positions whose true class is
        ``pad_idx`` are padding.
        """
        pred = logits.argmax(dim=-1)
        true = one_hot.argmax(dim=-1)
        valid = true != pad_idx
        return (pred == true) & valid, valid

    def _note_accuracy(self, seq_rec, x_seq):
        """Fraction of non-padding positions with both pitch AND duration correct.

        The denominator counts positions whose pitch is not padding.
        """
        n_p = self._N_PITCH
        pitch_hits, pitch_valid = self._class_hits(
            seq_rec[..., :n_p], x_seq[..., :n_p], self._PITCH_PAD)
        dur_hits, _ = self._class_hits(
            seq_rec[..., n_p:], x_seq[..., n_p:], self._DUR_PAD)
        return torch.sum(pitch_hits & dur_hits) / torch.sum(pitch_valid)

    def _acts_accuracy(self, acts_rec, x_acts):
        """Fraction of activations predicted correctly (sigmoid >= 0.5 -> 1)."""
        preds = (torch.sigmoid(acts_rec) >= 0.5).float()
        return torch.sum(preds == x_acts) / x_acts.numel()

    def _pitches_accuracy(self, seq_rec, x_seq):
        """Fraction of non-padding positions with the correct pitch class."""
        n_p = self._N_PITCH
        hits, valid = self._class_hits(
            seq_rec[..., :n_p], x_seq[..., :n_p], self._PITCH_PAD)
        return torch.sum(hits) / torch.sum(valid)

    def _dur_accuracy(self, seq_rec, x_seq):
        """Fraction of non-padding positions with the correct duration class."""
        n_p = self._N_PITCH
        hits, valid = self._class_hits(
            seq_rec[..., n_p:], x_seq[..., n_p:], self._DUR_PAD)
        return torch.sum(hits) / torch.sum(valid)

    def _save_model(self, filename):
        """Save model/optimizer state and all recorded stats to ``model_dir/filename``."""
        path = os.path.join(self.model_dir, filename)
        torch.save({
            'epoch': self.cur_epoch,
            'batch': self.cur_batch_idx,
            'save_every': self.save_every,
            'eval_every': self.eval_every,
            'lrs': self.lrs,
            'tr_losses': self.tr_losses,
            'tr_accuracies': self.tr_accuracies,
            'val_losses': self.val_losses,
            'val_accuracies': self.val_accuracies,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, path)

    def _print_stats(self):
        """Print elapsed time and the mean of each stat over the last ``print_every`` batches."""
        hours, rem = divmod(self.times[-1]-self.times[0], 3600)
        minutes, seconds = divmod(rem, 60)
        print("Elapsed time from start (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
              .format(int(hours), int(minutes), seconds))

        window = self.print_every
        avg_losses = {k: mean(l[-window:]) for k, l in self.tr_losses.items()}
        avg_accs = {k: mean(l[-window:]) for k, l in self.tr_accuracies.items()}

        print("Losses:")
        print(avg_losses)
        print("Accuracies:")
        print(avg_accs)
        


Training

In [25]:
# Parent folder for all runs; each VAETrainer creates its own sub-folder here.
models_dir = "models/"
os.makedirs(name=models_dir, exist_ok=True)

In [26]:
from torch.utils.data import Subset
from torch.utils.data import random_split

ds_dir = "data/preprocessed"

dataset = MIDIDataset(ds_dir)
print('Dataset len:', len(dataset))

# 70% training / 30% validation. random_split draws from torch's global RNG,
# so the split only repeats if the seeding cell ran first with no other
# RNG use in between.
n_total = len(dataset)
train_len = int(0.7 * n_total)
valid_len = n_total - train_len
tr_set, vl_set = random_split(dataset, (train_len, valid_len))

loader_opts = {'batch_size': 64, 'num_workers': 8}
trainloader = DataLoader(tr_set, shuffle=True, **loader_opts)
validloader = DataLoader(vl_set, shuffle=False, **loader_opts)

print('TR set len:', len(tr_set))
print('VL set len:', len(vl_set))

Dataset len: 30091
TR set len: 21063
VL set len: 9028


In [27]:
import torch

# Prefer the first GPU, fall back to the CPU. The CUDA-only calls are guarded:
# calling them unconditionally crashed this cell on CPU-only machines,
# defeating the fallback.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Device:", device)
if device.type == "cuda":
    print("Current device idx:", torch.cuda.current_device())

Device: cuda
Current device idx: 0


In [28]:
#!rm models/vae

In [29]:
from prettytable import PrettyTable


def print_params(model):
    """Print a per-tensor table of trainable parameters and return their total.

    Parameters with ``requires_grad=False`` are left out of both the table
    and the total.
    """
    table = PrettyTable(["Modules", "Parameters"])
    counts = [(name, p.numel()) for name, p in model.named_parameters()
              if p.requires_grad]
    for name, n_params in counts:
        table.add_row([name, n_params])
    total_params = sum(n for _, n in counts)

    print(table)
    print(f"Total Trainable Params: {total_params}")

    return total_params

In [30]:
import math
import torch
from typing import Optional
from torch.optim import Optimizer

from lr_scheduler.lr_scheduler import LearningRateScheduler

class TransformerLRScheduler(LearningRateScheduler):
    r"""
    Transformer Learning Rate Scheduler proposed in "Attention Is All You Need"

    Schedule (one ``step()`` call per optimizer update, ``s`` = steps taken):

    * warmup, ``s < warmup_steps``: ``lr = s * peak_lr / warmup_steps``
      (linear ramp from 0 towards ``peak_lr``);
    * decay, afterwards: ``lr = peak_lr * final_lr_scale ** (t / decay_steps)``
      with ``t = s - warmup_steps``. The lr reaches
      ``peak_lr * final_lr_scale`` after ``decay_steps`` decay steps and keeps
      decaying exponentially after that (no floor).

    Args:
        optimizer (Optimizer): Optimizer.
        init_lr (float): Initial learning rate (passed to the base class).
        peak_lr (float): Maximum learning rate.
        final_lr (float): Final learning rate. NOTE: stored but not used by
            this schedule, which decays without bound.
        final_lr_scale (float): Decay factor reached after ``decay_steps``.
        warmup_steps (int): Warmup the learning rate linearly for the first N updates
        decay_steps (int): Number of decay steps over which the lr is scaled by
            ``final_lr_scale``.
    """
    def __init__(
            self,
            optimizer: Optimizer,
            init_lr: float,
            peak_lr: float,
            final_lr: float,
            final_lr_scale: float,
            warmup_steps: int,
            decay_steps: int,
    ) -> None:
        assert isinstance(warmup_steps, int), "warmup_steps should be integer type"
        assert isinstance(decay_steps, int), "decay_steps should be integer type"

        super(TransformerLRScheduler, self).__init__(optimizer, init_lr)
        self.final_lr = final_lr
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps

        # Per-step lr increment during warmup
        self.warmup_rate = self.peak_lr / self.warmup_steps
        # exp(-decay_factor * decay_steps) == final_lr_scale
        self.decay_factor = -math.log(final_lr_scale) / self.decay_steps

        self.init_lr = init_lr
        self.update_steps = 0

    def _decide_stage(self):
        """Return ``(stage, steps_in_stage)``: 0 = warmup, 1 = exponential decay."""
        if self.update_steps < self.warmup_steps:
            return 0, self.update_steps
        return 1, self.update_steps - self.warmup_steps

    def step(self, val_loss: Optional[torch.FloatTensor] = None):
        """Advance the schedule by one update and set the new lr on the optimizer.

        ``val_loss`` is accepted for interface compatibility and ignored.

        Returns:
            The learning rate now set on the optimizer.
        """
        self.update_steps += 1
        stage, steps_in_stage = self._decide_stage()

        if stage == 0:
            self.lr = self.update_steps * self.warmup_rate
        elif stage == 1:
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        else:
            # Only reachable if a subclass overrides _decide_stage()
            raise ValueError("Undefined stage")

        self.set_lr(self.optimizer, self.lr)

        return self.lr

In [23]:
print("Creating the model and moving it to the specified device...")

vae = VAE(dropout=0).to(device)
print_params(vae)
print()

# Base Adam lr; TransformerLRScheduler.step() overwrites the optimizer lr
# after every update. The trainer also records init_lr when no scheduler is set.
init_lr = 5e-6
gamma = 0.999  # only used by the disabled ExponentialLR alternative below
optimizer = optim.Adam(vae.parameters(), lr=init_lr, betas=(0.9, 0.98), eps=1e-09)
#scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma)

# Linear warmup to peak_lr over warmup_steps updates, then exponential decay
# by a factor final_lr_scale every decay_steps updates.
scheduler_kwargs = dict(
    init_lr=1e-10,
    peak_lr=5e-4,
    final_lr=1e-7,
    final_lr_scale=0.1,
    warmup_steps=4000,
    decay_steps=80000,
)
scheduler = TransformerLRScheduler(optimizer=optimizer, **scheduler_kwargs)

print('--------------------------------------------------\n')

# NOTE: the run name is fixed and VAETrainer refuses an existing run folder,
# so re-running this cell requires removing models/NULL first.
trainer = VAETrainer(
    vae, models_dir, optimizer, init_lr,
    name='NULL',
    lr_scheduler=scheduler,
    save_every=100,
    eval_every=1000,
    device=device,
)
trainer.train(trainloader, validloader=validloader, epochs=100)

Creating the model and moving it to the specified device...
+----------------------------------------------------------------+------------+
|                            Modules                             | Parameters |
+----------------------------------------------------------------+------------+
|                    encoder.embedding.weight                    |   58880    |
|                     encoder.embedding.bias                     |    256     |
| encoder.transformer_encoder.layers.0.self_attn.in_proj_weight  |   196608   |
|  encoder.transformer_encoder.layers.0.self_attn.in_proj_bias   |    768     |
| encoder.transformer_encoder.layers.0.self_attn.out_proj.weight |   65536    |
|  encoder.transformer_encoder.layers.0.self_attn.out_proj.bias  |    256     |
|      encoder.transformer_encoder.layers.0.linear1.weight       |   524288   |
|       encoder.transformer_encoder.layers.0.linear1.bias        |    2048    |
|      encoder.transformer_encoder.layers.0.linear2.weight  

  0%|          | 0/330 [00:00<?, ?it/s]

Training on batch 1/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:01.67
Losses:
{'tot': 53.208763122558594, 'pitches': 4.865426063537598, 'dur': 4.830996513366699, 'acts': 43.5123405456543, 'rec': 53.208763122558594, 'kld': 173284278272.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.0, 'pitches': 0.0033809165470302105, 'dur': 0.0003756574005819857, 'acts': 0.531005859375}

----------------------------------------

Training on batch 2/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:01.96
Losses:
{'tot': 28.05336570739746, 'pitches': 4.899317741394043, 'dur': 4.787333011627197, 'acts': 18.366714477539062, 'rec': 28.05336570739746, 'kld': 30109499392.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.00039354583714157343, 'pitches': 0.005903187673538923, 'dur': 0.00511609623208642, 'acts': 0.5301513671875}

----------------------------------------

Training on batch 3/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:02.23
Losses:
{'tot': 

Training on batch 20/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:07.07
Losses:
{'tot': 15.675088882446289, 'pitches': 4.797269821166992, 'dur': 4.668727397918701, 'acts': 6.209091663360596, 'rec': 15.675088882446289, 'kld': 6942013952.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.00038955980562604964, 'pitches': 0.009738994762301445, 'dur': 0.008570315316319466, 'acts': 0.5411376953125}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 18.40020935971972, 'pitches': 4.788230368788813, 'dur': 4.667317910933159, 'acts': 8.944661058170695, 'rec': 18.40020935971972, 'kld': 12940105203.380281, 'beta*kld': 0.0}
Val accuracies:
{'notes': 5.435056613043318e-05, 'pitches': 0.009691607174266812, 'dur': 0.0076331764330077444, 'acts': 0.5340163539832746}

Saving model to disk...

Training on batch 21/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:38.93
Losses:
{'tot': 15.36347484588623, 'pitches': 4.800722122192383, 'dur': 4.649045467376709, 'acts': 5.9137067794799805, 'rec': 15.36347484588623, 'kld': 3136240640.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.0, 'pitches': 0.008760402910411358, 'dur': 0.005694261752068996, 'acts': 0.5364990234375}

----------------------------------------

Training on batch 22/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:39.21
Losses:
{'tot': 14.551982879638672, 'pitches': 4.805830955505371, 'dur': 4.678613185882568, 'acts': 5.0675392150

Training on batch 39/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:44.30
Losses:
{'tot': 12.842111587524414, 'pitches': 4.58919095993042, 'dur': 4.373100757598877, 'acts': 3.879819393157959, 'rec': 12.842111587524414, 'kld': 2855031552.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.0, 'pitches': 0.053779616951942444, 'dur': 0.04024069383740425, 'acts': 0.534912109375}

----------------------------------------

Training on batch 40/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:44.60
Losses:
{'tot': 11.182708740234375, 'pitches': 4.608540058135986, 'dur': 4.310759544372559, 'acts': 2.263408660888672, 'rec': 11.182708740234375, 'kld': 535425088.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.00041169204632751644, 'pitches': 0.06093042343854904, 'dur': 0.05846027284860611, 'acts': 0.5399169921875}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 11.437341522163068, 'pitches': 4.571016184041198, 'dur': 4.324074453031513, 'acts': 2.54225093965799, 'rec': 11.437341522163068, 'kld': 797985121.1830986, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.000379427764515563, 'pitches': 0.0727007135381581, 'dur': 0.05689459695467647, 'acts': 0.5353763204225352}

Saving model to disk...

Training on batch 41/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:17.19
Losses:
{'tot': 10.193792343139648, 'pitches': 4.60163688659668, 'dur': 4.292516708374023, 'acts': 1.2996385097503662, 'rec': 10.193792343139648, 'kld': 133260688.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.0, 'pitches': 0.0646701380610466, 'dur': 0.05859375, 'acts': 0.5384521484375}

----------------------------------------

Training on batch 42/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:17.47
Losses:
{'tot': 10.516519546508789, 'pitches': 4.542551040649414, 'dur': 4.281126976013184, 'acts': 1.692841649055481, 'rec': 10.

Training on batch 59/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:22.48
Losses:
{'tot': 9.002607345581055, 'pitches': 4.253903865814209, 'dur': 3.8347644805908203, 'acts': 0.9139388799667358, 'rec': 9.002607345581055, 'kld': 39434680.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.0050293379463255405, 'pitches': 0.3076278269290924, 'dur': 0.24015088379383087, 'acts': 0.52783203125}

----------------------------------------

Training on batch 60/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:22.77
Losses:
{'tot': 8.802912712097168, 'pitches': 4.274204730987549, 'dur': 3.777921199798584, 'acts': 0.7507871389389038, 'rec': 8.802912712097168, 'kld': 20555460.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.002862985711544752, 'pitches': 0.3226993978023529, 'dur': 0.2523517310619354, 'acts': 0.5311279296875}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 9.228265157887634, 'pitches': 4.232507302727498, 'dur': 3.7679056214614652, 'acts': 1.2278521955013275, 'rec': 9.228265157887634, 'kld': 144290823.80985916, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.005844503547668352, 'pitches': 0.32559566535580325, 'dur': 0.26372714791919144, 'acts': 0.5369142344300176}

Saving model to disk...

Training on batch 61/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:55.06
Losses:
{'tot': 9.972620964050293, 'pitches': 4.224850177764893, 'dur': 3.8033814430236816, 'acts': 1.9443891048431396, 'rec': 9.972620964050293, 'kld': 272492064.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.004004368558526039, 'pitches': 0.3265380561351776, 'dur': 0.24899891018867493, 'acts': 0.5284423828125}

----------------------------------------

Training on batch 62/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:55.35
Losses:
{'tot': 9.317845344543457, 'pitches': 4.241974830627441, 'dur': 3.6893458366394043, 'acts': 

Training on batch 79/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:00.29
Losses:
{'tot': 7.907594680786133, 'pitches': 3.8888633251190186, 'dur': 3.2176520824432373, 'acts': 0.8010791540145874, 'rec': 7.907594680786133, 'kld': 29816506.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.019682059064507484, 'pitches': 0.40196821093559265, 'dur': 0.3406510353088379, 'acts': 0.533203125}

----------------------------------------

Training on batch 80/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:00.58
Losses:
{'tot': 7.470861434936523, 'pitches': 3.868804931640625, 'dur': 3.1199610233306885, 'acts': 0.48209547996520996, 'rec': 7.470861434936523, 'kld': 8994133.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.02574750781059265, 'pitches': 0.40074750781059265, 'dur': 0.3658638000488281, 'acts': 0.532470703125}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 7.68850084089897, 'pitches': 3.848155767145291, 'dur': 3.1069588812304216, 'acts': 0.7333861845479884, 'rec': 7.68850084089897, 'kld': 44890133.02112676, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.024695273882276575, 'pitches': 0.4002225917409843, 'dur': 0.3643051249460435, 'acts': 0.5397295884683099}

Saving model to disk...

Training on batch 81/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:33.61
Losses:
{'tot': 7.768494606018066, 'pitches': 3.869361639022827, 'dur': 3.0988965034484863, 'acts': 0.8002368211746216, 'rec': 7.768494606018066, 'kld': 51372012.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.02697262540459633, 'pitches': 0.3977455794811249, 'dur': 0.37600645422935486, 'acts': 0.5452880859375}

----------------------------------------

Training on batch 82/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:33.91
Losses:
{'tot': 7.397106170654297, 'pitches': 3.779360771179199, 'dur': 3.0832347869873047, 'acts': 0.53451

Training on batch 99/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:38.92
Losses:
{'tot': 6.632215976715088, 'pitches': 3.595416784286499, 'dur': 2.5763301849365234, 'acts': 0.46046897768974304, 'rec': 6.632215976715088, 'kld': 13206340.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.06840750575065613, 'pitches': 0.3927914798259735, 'dur': 0.3747701346874237, 'acts': 0.5416259765625}

----------------------------------------

Training on batch 100/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:39.23
Losses:
{'tot': 6.669652462005615, 'pitches': 3.556648015975952, 'dur': 2.612837314605713, 'acts': 0.5001674890518188, 'rec': 6.669652462005615, 'kld': 8893510.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.06344984471797943, 'pitches': 0.4012157917022705, 'dur': 0.332446813583374, 'acts': 0.5419921875}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 6.567966118664809, 'pitches': 3.553146896227984, 'dur': 2.5016246547161693, 'acts': 0.513194558696008, 'rec': 6.567966118664809, 'kld': 19148421.205985915, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.08228301088994658, 'pitches': 0.4005805042008279, 'dur': 0.37466468168816097, 'acts': 0.5434088908450704}

Saving model to disk...

Training on batch 101/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:11.91
Losses:
{'tot': 7.251069068908691, 'pitches': 3.4864633083343506, 'dur': 2.505316972732544, 'acts': 1.2592886686325073, 'rec': 7.251069068908691, 'kld': 174298912.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.0827505812048912, 'pitches': 0.4052059054374695, 'dur': 0.38267287611961365, 'acts': 0.5433349609375}

----------------------------------------

Training on batch 102/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:12.23
Losses:
{'tot': 6.605900287628174, 'pitches': 3.551401138305664, 'dur': 2.525984048843384, 'acts': 0.52

Training on batch 119/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:17.10
Losses:
{'tot': 6.00332498550415, 'pitches': 3.422694683074951, 'dur': 2.1467270851135254, 'acts': 0.43390315771102905, 'rec': 6.00332498550415, 'kld': 8190747.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.18255189061164856, 'pitches': 0.3989239037036896, 'dur': 0.38893160223960876, 'acts': 0.55078125}

----------------------------------------

Training on batch 120/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:17.39
Losses:
{'tot': 5.80135440826416, 'pitches': 3.402132272720337, 'dur': 2.0872998237609863, 'acts': 0.3119221329689026, 'rec': 5.80135440826416, 'kld': 2943133.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.23179849982261658, 'pitches': 0.3974812924861908, 'dur': 0.4144037663936615, 'acts': 0.552734375}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 5.878528675562899, 'pitches': 3.3857958686183878, 'dur': 2.118374158798809, 'acts': 0.3743586459420097, 'rec': 5.878528675562899, 'kld': 8045932.941901408, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.20824999236304995, 'pitches': 0.4005805042008279, 'dur': 0.39850102155141426, 'acts': 0.5498743191571303}

Saving model to disk...

Training on batch 121/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:50.54
Losses:
{'tot': 6.047859191894531, 'pitches': 3.384915590286255, 'dur': 2.087686777114868, 'acts': 0.5752567648887634, 'rec': 6.047859191894531, 'kld': 26818248.0, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.20172414183616638, 'pitches': 0.40344828367233276, 'dur': 0.4073275923728943, 'acts': 0.54248046875}

----------------------------------------

Training on batch 122/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:50.83
Losses:
{'tot': 5.789438247680664, 'pitches': 3.369262456893921, 'dur': 2.067885160446167, 'acts': 0.35229

Training on batch 139/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:55.88
Losses:
{'tot': 5.476077079772949, 'pitches': 3.231300115585327, 'dur': 1.9804061651229858, 'acts': 0.26437070965766907, 'rec': 5.476077079772949, 'kld': 1909819.25, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.21323251724243164, 'pitches': 0.40718337893486023, 'dur': 0.42495274543762207, 'acts': 0.5611572265625}

----------------------------------------

Training on batch 140/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:56.17
Losses:
{'tot': 5.296639919281006, 'pitches': 3.2104239463806152, 'dur': 1.8862732648849487, 'acts': 0.19994260370731354, 'rec': 5.296639919281006, 'kld': 1100555.25, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.21797075867652893, 'pitches': 0.4062768816947937, 'dur': 0.42906275391578674, 'acts': 0.5584716796875}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 5.467488198213174, 'pitches': 3.250897162397143, 'dur': 1.946654932599672, 'acts': 0.26993611801258277, 'rec': 5.467488198213174, 'kld': 3879837.5277288733, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.22015769464868895, 'pitches': 0.4005805042008279, 'dur': 0.4121519169757064, 'acts': 0.5591611190580986}

Saving model to disk...

Training on batch 141/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:28.71
Losses:
{'tot': 5.240694999694824, 'pitches': 3.1864278316497803, 'dur': 1.8508408069610596, 'acts': 0.2034262716770172, 'rec': 5.240694999694824, 'kld': 1322080.125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.19413593411445618, 'pitches': 0.4082629978656769, 'dur': 0.428254097700119, 'acts': 0.56298828125}

----------------------------------------

Training on batch 142/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:28.99
Losses:
{'tot': 5.471745014190674, 'pitches': 3.2781291007995605, 'dur': 1.9666919708251953, 'acts': 0.22

Training on batch 159/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:33.98
Losses:
{'tot': 5.246070384979248, 'pitches': 3.150061845779419, 'dur': 1.8718777894973755, 'acts': 0.22413063049316406, 'rec': 5.246070384979248, 'kld': 1850397.125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.2561489939689636, 'pitches': 0.4072382152080536, 'dur': 0.4279690682888031, 'acts': 0.5743408203125}

----------------------------------------

Training on batch 160/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:04:34.27
Losses:
{'tot': 5.218161106109619, 'pitches': 3.177996873855591, 'dur': 1.8475292921066284, 'acts': 0.19263505935668945, 'rec': 5.218161106109619, 'kld': 929822.3125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.21677471697330475, 'pitches': 0.40113452076911926, 'dur': 0.42666125297546387, 'acts': 0.5626220703125}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 5.215033541262989, 'pitches': 3.144829254754832, 'dur': 1.8598079395965792, 'acts': 0.21039634449800976, 'rec': 5.215033541262989, 'kld': 2196881.5411531692, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.24031295858218638, 'pitches': 0.4005805042008279, 'dur': 0.42461275844506813, 'acts': 0.5701018857284331}

Saving model to disk...

Training on batch 161/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:07.64
Losses:
{'tot': 5.264692306518555, 'pitches': 3.1878561973571777, 'dur': 1.8654061555862427, 'acts': 0.211430162191391, 'rec': 5.264692306518555, 'kld': 1394447.625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.22653362154960632, 'pitches': 0.40243902802467346, 'dur': 0.42461198568344116, 'acts': 0.576904296875}

----------------------------------------

Training on batch 162/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:07.94
Losses:
{'tot': 5.132849216461182, 'pitches': 3.1102280616760254, 'dur': 1.826188087463379, 'acts': 

Training on batch 179/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:12.93
Losses:
{'tot': 5.021176338195801, 'pitches': 3.109670639038086, 'dur': 1.748045563697815, 'acts': 0.16346028447151184, 'rec': 5.021176338195801, 'kld': 944134.4375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.22419007122516632, 'pitches': 0.40043196082115173, 'dur': 0.4414686858654022, 'acts': 0.579833984375}

----------------------------------------

Training on batch 180/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:13.25
Losses:
{'tot': 5.054254055023193, 'pitches': 3.0999484062194824, 'dur': 1.7764590978622437, 'acts': 0.17784655094146729, 'rec': 5.054254055023193, 'kld': 768753.0625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.25206923484802246, 'pitches': 0.40744921565055847, 'dur': 0.43566590547561646, 'acts': 0.5858154296875}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 5.039306959635775, 'pitches': 3.0764472098417683, 'dur': 1.7916946234837385, 'acts': 0.1711651123535465, 'rec': 5.039306959635775, 'kld': 1219074.0184859154, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.241723121891559, 'pitches': 0.4005805042008279, 'dur': 0.43557779679835684, 'acts': 0.5833379181338029}

Saving model to disk...

Training on batch 181/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:46.46
Losses:
{'tot': 5.04530143737793, 'pitches': 3.0906379222869873, 'dur': 1.8159615993499756, 'acts': 0.13870203495025635, 'rec': 5.04530143737793, 'kld': 551899.125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.2531055808067322, 'pitches': 0.4013975262641907, 'dur': 0.4413819909095764, 'acts': 0.59130859375}

----------------------------------------

Training on batch 182/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:46.77
Losses:
{'tot': 4.947247505187988, 'pitches': 3.0408637523651123, 'dur': 1.7468092441558838, 'acts': 0.1595

Training on batch 199/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:52.01
Losses:
{'tot': 4.97079610824585, 'pitches': 3.0434155464172363, 'dur': 1.7851202487945557, 'acts': 0.1422601342201233, 'rec': 4.97079610824585, 'kld': 603320.5, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.2300955355167389, 'pitches': 0.39331209659576416, 'dur': 0.45063695311546326, 'acts': 0.5946044921875}

----------------------------------------

Training on batch 200/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:05:52.35
Losses:
{'tot': 4.9555463790893555, 'pitches': 3.013582229614258, 'dur': 1.7618556022644043, 'acts': 0.18010878562927246, 'rec': 4.9555463790893555, 'kld': 1175121.25, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.21355685591697693, 'pitches': 0.40706998109817505, 'dur': 0.44788628816604614, 'acts': 0.60107421875}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 4.904134310467143, 'pitches': 3.0185586465916163, 'dur': 1.7437336780655552, 'acts': 0.14184198266183826, 'rec': 4.904134310467143, 'kld': 743530.5911641725, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.22422338759815189, 'pitches': 0.4005805042008279, 'dur': 0.44453366730414645, 'acts': 0.599824287522007}

Saving model to disk...

Training on batch 201/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:24.48
Losses:
{'tot': 4.922852039337158, 'pitches': 3.002279758453369, 'dur': 1.7750083208084106, 'acts': 0.14556406438350677, 'rec': 4.922852039337158, 'kld': 659311.6875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.22464342415332794, 'pitches': 0.3914421498775482, 'dur': 0.4401743412017822, 'acts': 0.5948486328125}

----------------------------------------

Training on batch 202/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:24.79
Losses:
{'tot': 5.0061774253845215, 'pitches': 3.031975269317627, 'dur': 1.8183673620224, 'acts': 0.1

Training on batch 219/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:29.92
Losses:
{'tot': 4.783545017242432, 'pitches': 2.982226610183716, 'dur': 1.6988468170166016, 'acts': 0.10247164964675903, 'rec': 4.783545017242432, 'kld': 297601.28125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.2410256415605545, 'pitches': 0.39393940567970276, 'dur': 0.45268064737319946, 'acts': 0.6107177734375}

----------------------------------------

Training on batch 220/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:06:30.23
Losses:
{'tot': 4.768223285675049, 'pitches': 2.9725918769836426, 'dur': 1.669410228729248, 'acts': 0.12622101604938507, 'rec': 4.768223285675049, 'kld': 518772.4375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.27963176369667053, 'pitches': 0.40736478567123413, 'dur': 0.46298426389694214, 'acts': 0.6209716796875}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 4.794309327300166, 'pitches': 2.9695225702205175, 'dur': 1.7032920407577299, 'acts': 0.1214947210441173, 'rec': 4.794309327300166, 'kld': 495064.95461047534, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.2671543561236959, 'pitches': 0.4005805042008279, 'dur': 0.452019118717019, 'acts': 0.6196254676496479}

Saving model to disk...

Training on batch 221/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:03.77
Losses:
{'tot': 4.799081802368164, 'pitches': 2.987393856048584, 'dur': 1.6918443441390991, 'acts': 0.11984371393918991, 'rec': 4.799081802368164, 'kld': 392535.125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.25030574202537537, 'pitches': 0.4007337987422943, 'dur': 0.4684060215950012, 'acts': 0.60986328125}

----------------------------------------

Training on batch 222/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:04.09
Losses:
{'tot': 4.8922271728515625, 'pitches': 3.025879144668579, 'dur': 1.7543227672576904, 'acts': 0.112

Training on batch 239/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:09.24
Losses:
{'tot': 4.740001678466797, 'pitches': 2.967430353164673, 'dur': 1.6491986513137817, 'acts': 0.1233723983168602, 'rec': 4.740001678466797, 'kld': 501580.125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.2414940893650055, 'pitches': 0.39829882979393005, 'dur': 0.459319531917572, 'acts': 0.63720703125}

----------------------------------------

Training on batch 240/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:09.55
Losses:
{'tot': 4.760018825531006, 'pitches': 2.9572982788085938, 'dur': 1.6798819303512573, 'acts': 0.12283830344676971, 'rec': 4.760018825531006, 'kld': 486833.53125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.24098536372184753, 'pitches': 0.4059264659881592, 'dur': 0.45305249094963074, 'acts': 0.6461181640625}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 4.698114358203512, 'pitches': 2.926281657017453, 'dur': 1.6636051376100998, 'acts': 0.10822758619004572, 'rec': 4.698114358203512, 'kld': 374625.2694487236, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.24275091269486387, 'pitches': 0.40061178656531055, 'dur': 0.45815291706944855, 'acts': 0.6407367545114436}

Saving model to disk...

Training on batch 241/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:42.77
Losses:
{'tot': 4.723637104034424, 'pitches': 2.9137022495269775, 'dur': 1.709928274154663, 'acts': 0.10000652074813843, 'rec': 4.723637104034424, 'kld': 246525.25, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.2377237230539322, 'pitches': 0.40385496616363525, 'dur': 0.4580082595348358, 'acts': 0.6307373046875}

----------------------------------------

Training on batch 242/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:43.07
Losses:
{'tot': 4.789113998413086, 'pitches': 2.924435615539551, 'dur': 1.7410192489624023, 'acts': 0.

Training on batch 259/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:48.05
Losses:
{'tot': 4.642887115478516, 'pitches': 2.991431474685669, 'dur': 1.556716799736023, 'acts': 0.09473889321088791, 'rec': 4.642887115478516, 'kld': 257838.59375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.22512038052082062, 'pitches': 0.3848314583301544, 'dur': 0.46308186650276184, 'acts': 0.6611328125}

----------------------------------------

Training on batch 260/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:07:48.35
Losses:
{'tot': 4.603305816650391, 'pitches': 2.9211127758026123, 'dur': 1.5877248048782349, 'acts': 0.09446794539690018, 'rec': 4.603305816650391, 'kld': 255245.609375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.24727272987365723, 'pitches': 0.40242424607276917, 'dur': 0.4723232388496399, 'acts': 0.671142578125}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 4.608230765436737, 'pitches': 2.885288100847056, 'dur': 1.624832054259072, 'acts': 0.09811059868251773, 'rec': 4.608230765436737, 'kld': 296135.0331481074, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.25129545291124933, 'pitches': 0.4019340751036792, 'dur': 0.4626078208987142, 'acts': 0.6645370268485915}

Saving model to disk...

Training on batch 261/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:21.61
Losses:
{'tot': 4.631411552429199, 'pitches': 2.935112476348877, 'dur': 1.5982296466827393, 'acts': 0.09806913137435913, 'rec': 4.631411552429199, 'kld': 364577.65625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.24813896417617798, 'pitches': 0.394540935754776, 'dur': 0.46224743127822876, 'acts': 0.671875}

----------------------------------------

Training on batch 262/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:21.90
Losses:
{'tot': 4.580989360809326, 'pitches': 2.857893466949463, 'dur': 1.6238945722579956, 'acts': 0.0992011

Training on batch 279/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:27.12
Losses:
{'tot': 4.469095230102539, 'pitches': 2.815141439437866, 'dur': 1.5664324760437012, 'acts': 0.08752162754535675, 'rec': 4.469095230102539, 'kld': 260739.5, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.26966291666030884, 'pitches': 0.41733548045158386, 'dur': 0.47150883078575134, 'acts': 0.692138671875}

----------------------------------------

Training on batch 280/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:08:27.42
Losses:
{'tot': 4.44773530960083, 'pitches': 2.840688467025757, 'dur': 1.5146576166152954, 'acts': 0.09238891303539276, 'rec': 4.44773530960083, 'kld': 272283.28125, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.2555099129676819, 'pitches': 0.3993276059627533, 'dur': 0.48262980580329895, 'acts': 0.6912841796875}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Val losses:
{'tot': 4.518468850095507, 'pitches': 2.836031391587056, 'dur': 1.591845541772708, 'acts': 0.09059191967400027, 'rec': 4.518468850095507, 'kld': 255071.27946192783, 'beta*kld': 0.0}
Val accuracies:
{'notes': 0.2552443559111004, 'pitches': 0.4045502306290076, 'dur': 0.47078694864897663, 'acts': 0.6900488625110035}

Saving model to disk...

Training on batch 281/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:00.98
Losses:
{'tot': 4.250892639160156, 'pitches': 2.714843273162842, 'dur': 1.4514883756637573, 'acts': 0.0845608189702034, 'rec': 4.250892639160156, 'kld': 255412.546875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.23461538553237915, 'pitches': 0.4025641083717346, 'dur': 0.4991452991962433, 'acts': 0.6910400390625}

----------------------------------------

Training on batch 282/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:01.28
Losses:
{'tot': 4.433719158172607, 'pitches': 2.7752506732940674, 'dur': 1.568697452545166, 'acts': 0

Training on batch 299/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:06.30
Losses:
{'tot': 4.490170478820801, 'pitches': 2.868933916091919, 'dur': 1.530999779701233, 'acts': 0.09023642539978027, 'rec': 4.490170478820801, 'kld': 276449.15625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.2311355322599411, 'pitches': 0.4054945111274719, 'dur': 0.4673992693424225, 'acts': 0.7200927734375}

----------------------------------------

Training on batch 300/330 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:09:06.59
Losses:
{'tot': 4.329893112182617, 'pitches': 2.7252652645111084, 'dur': 1.5167964696884155, 'acts': 0.08783131092786789, 'rec': 4.329893112182617, 'kld': 236200.9375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.2552318572998047, 'pitches': 0.4226508140563965, 'dur': 0.4718916714191437, 'acts': 0.7073974609375}

----------------------------------------


Evaluating on validation set...



  0%|          | 0/142 [00:00<?, ?it/s]

Process Process-176:
Process Process-56:
Process Process-53:
Process Process-49:
Process Process-50:
Process Process-55:
Process Process-54:
Process Process-52:
Process Process-51:
Process Process-173:
Traceback (most recent call last):
Traceback (most recent call last):
Process Process-172:
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/multiprocessing/process.py", line 300, in _bootstrap
    util._exit_function()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/mu

  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/multiprocessing/queues.py", line 192, in _finalize_join
    thread.join()
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/threading.py", line 1060, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/multiprocessing/queues.py", line 192, in _finalize_join
    thread.join()
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/threading.py", line 1060, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
  File "/home/cosenza/anaconda3/envs/thesis/lib/python3.7/threading.py", line 1044, in join
    self._wait_for_tstate_lo

KeyboardInterrupt: 

## Reconstructions

In [None]:
# Load a previously saved training checkpoint. map_location places every tensor
# on the current `device`, so a checkpoint written on a GPU machine can also be
# opened on a CPU-only one.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load('models/just_pitches_warmup', map_location=device)

In [None]:
# Rebuild the architecture (default constructor arguments) on the target device,
# then take the model weights out of the checkpoint dictionary.
vae = VAE().to(device)
state_dict = checkpoint['model_state_dict']

In [None]:
# Copy the checkpoint weights into the freshly built model (strict key matching)
vae.load_state_dict(state_dict=state_dict, strict=True)

In [None]:
# Unshuffled loader, so batches follow the dataset's own index order;
# the cell then displays the dataset size.
loader = DataLoader(dataset, shuffle=False, batch_size=32)
len(dataset)

In [None]:
# Run the model on the first batch only, in inference mode. eval() disables
# dropout (and any other train-time-only behaviour), and no_grad() skips
# building the autograd graph, which is not needed for inspection.
vae.eval()

with torch.no_grad():
    for idx, inputs in enumerate(loader):

        x_seq, x_acts, x_graph, src_mask = inputs
        x_seq = x_seq.float().to(device)
        x_acts = x_acts.to(device)
        x_graph = x_graph.to(device)
        src_mask = src_mask.to(device)
        # Presumably a causal mask over the target sequence, which is one
        # position shorter than x_seq along dim -2.
        tgt_mask = generate_square_subsequent_mask(x_seq.size(-2)-1).to(device)

        # Forward pass, get the reconstructions
        outputs, mu, log_var = vae(x_seq, x_acts, x_graph, src_mask, tgt_mask)

        # Only the first batch is needed
        break

seq_rec, _ = outputs

In [None]:
# Shape of the input batch
x_seq.shape

In [None]:
# Shape of the sparse reconstruction returned by the model
seq_rec.shape

In [None]:
# Shape of the activation tensor used to scatter the reconstruction
x_acts.shape

Build a dense reconstruction from the sparse one: rows for active positions are filled with the model's predictions, and all other rows with an EOS token.

In [None]:
# Dense tensor with the input's layout minus the first entry along dim -2,
# allocated directly on `device` rather than created on CPU and copied.
seq_rec_dense = torch.zeros(x_seq.size(), dtype=torch.float, device=device)
seq_rec_dense = seq_rec_dense[..., 1:, :]
size = seq_rec_dense.size()

# Flatten all leading dims so rows line up with the flattened activation mask
seq_rec_dense = seq_rec_dense.view(-1, seq_rec_dense.size(-2), seq_rec_dense.size(-1))

# Rows for inactive positions are filled with an EOS token in every slot
silence = torch.zeros(seq_rec_dense.size(-2), seq_rec_dense.size(-1), device=device)
silence[:, PITCH_EOS] = 1.

# Mask computed once and reused for both the active and inactive assignment
active = x_acts.bool().view(-1)
seq_rec_dense[active] = seq_rec
seq_rec_dense[~active] = silence

seq_rec_dense = seq_rec_dense.view(size)

In [None]:
# The dense reconstruction should match the input except along dim -2 (one shorter)
print(seq_rec_dense.shape)
print(x_seq.shape)

In [None]:
# First sample of the batch: ground truth and its reconstruction
music_real, music_rec = x_seq[0], seq_rec_dense[0]

In [None]:
# Shape of a single ground-truth sample
music_real.shape

In [None]:
prefix = "data/music/"
# Create the output directory if needed; savefig / write_midi fail otherwise
os.makedirs(prefix, exist_ok=True)

# Convert the ground-truth sample to muspy, plot it, and save PNG + MIDI
real = from_tensor_to_muspy(music_real, track_data)
muspy.show_pianoroll(real, yticklabel='off', grid_axis='off')
plt.savefig(prefix + "real" + ".png")
muspy.write_midi(prefix + "real" + ".mid", real)

In [None]:
# Same conversion/plot/save steps for the reconstruction (reuses `prefix`)
rec = from_tensor_to_muspy(music_rec, track_data)
muspy.show_pianoroll(rec, grid_axis='off', yticklabel='off')
plt.savefig(prefix + "rec" + ".png")
muspy.write_midi(prefix + "rec" + ".mid", rec)

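Convert tensors back to muspy `Music` objects, then plot them and save them to disk. Note: the reconstruction cells above already call `from_tensor_to_muspy` and `track_data`, which are defined in the next cell, so run that cell first.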

In [None]:
#tracks = [drum_track, bass_track, guitar_track, strings_track]
import copy

def from_tensor_to_muspy(music_tensor, track_data, note_duration=4, velocity=64):
    """Decode a [track, timestep, note, feature] tensor into a muspy.Music.

    The pitch of each note slot is the argmax over its first 131 features
    (token ids 0..PITCH_PAD, i.e. MIDI pitches 0-127 plus SOS/EOS/PAD). Within
    a time step, slots are scanned in order: scanning stops at the first EOS
    token, and SOS/PAD tokens are skipped because they are not MIDI pitches.

    Args:
        music_tensor: 4-D tensor indexed as [track, timestep, note, feature];
            only the first 131 features are read.
        track_data: sequence of (name, program) pairs, one per track. A track
            named 'Drums' becomes a drum track and its program is ignored.
        note_duration: duration (in time steps) given to every note; durations
            are not decoded from the tensor yet.
        velocity: velocity given to every note.

    Returns:
        A muspy.Music with one Track per tensor track, at resolution RESOLUTION.
    """
    # One argmax over all note slots, then plain Python ints: far fewer
    # per-element tensor ops (and device syncs when the tensor is on GPU).
    pitches = torch.argmax(music_tensor[..., :PITCH_PAD + 1], dim=-1).tolist()

    tracks = []
    for tr, track_pitches in enumerate(pitches):
        notes = []
        for ts, slot_pitches in enumerate(track_pitches):
            for pitch in slot_pitches:
                if pitch == PITCH_EOS:
                    break  # end of the note sequence for this time step
                # Only real MIDI pitches become notes; SOS and PAD are skipped
                # (emitting PAD=130 would create an invalid MIDI pitch).
                if pitch < PITCH_SOS:
                    notes.append(muspy.Note(ts, pitch, note_duration, velocity))

        name = track_data[tr][0]
        if name == 'Drums':
            track = muspy.Track(name='Drums', is_drum=True, notes=notes)
        else:
            track = muspy.Track(name=name,
                                program=track_data[tr][1],
                                notes=notes)
        tracks.append(track)

    meta = muspy.Metadata(title='prova')
    return muspy.Music(tracks=tracks, metadata=meta, resolution=RESOLUTION)


# (track name, MIDI program) for each tensor track, in track order.
# The 'Drums' program value is not used when building the drum track.
track_data = [
    ('Drums', -1),
    ('Bass', 34),
    ('Guitar', 1),
    ('Strings', 41),
]

In [None]:
# Save piano-roll PNGs and MIDI files for a run of consecutive dataset samples.
prefix = "data/music/file"
first_idx, n_samples = 20, 10  # dataset[first_idx : first_idx + n_samples]
# savefig / write_midi raise if the target folder does not exist.
os.makedirs(os.path.dirname(prefix), exist_ok=True)

for i in range(n_samples):
    music_tensor = dataset[first_idx + i][0]
    music = from_tensor_to_muspy(music_tensor, track_data)
    muspy.show_pianoroll(music, yticklabel='off', grid_axis='off')
    plt.savefig(prefix + str(i) + ".png")
    muspy.write_midi(prefix + str(i) + ".mid", music)

In [None]:
# Inspect the current `music` object (after the loop above: the sample from its last iteration).
music

In [None]:
# Plot and save the current `music` object.
# NOTE(review): after the loop above, `music` is the sample from the loop's last
# iteration, yet it is written to data/music/file2.mid, overwriting the file the
# loop saved for i=2. The PNG also goes to the working directory instead of
# data/music/ — confirm both are intended.
music_path = "data/music/file2.mid"
png_path = 'file2.png'

muspy.show_pianoroll(music, yticklabel='off', grid_axis='off')
plt.savefig(png_path)
muspy.write_midi(music_path, music)

In [None]:
# Hand-built toy Music object with four tracks (drums, bass, guitar, strings).
print(dataset[0][0].shape)

notes = [muspy.Note(1, 48, 20, 64)]
drums = muspy.Track(is_drum=True)
bass = muspy.Track(program=34, notes=notes)
guitar = muspy.Track(program=27, notes=[])
strings = muspy.Track(
    program=42,
    notes=[muspy.Note(0, 100, 4, 64), muspy.Note(4, 91, 20, 64)],
)
tracks = [drums, bass, guitar, strings]

meta = muspy.Metadata(title='prova')
music = muspy.Music(tracks=tracks, metadata=meta, resolution=32)

In [None]:
# Check that this sample file under data/lmd_matched is present on disk.
!ls data/lmd_matched/M/T/O/TRMTOBP128E07822EF/63edabc86c087f07eca448b0edad53c3.mid

# Scratch: building graph edges on a toy example

Next edges: connect notes in consecutive active time steps that belong to different tracks.

In [None]:
import itertools

# Toy activations: 4 tracks x 8 time steps, transposed to (time step, track).
a = np.random.randint(2, size=(4,8))
a_t = a.transpose()
print(a_t)
# (time step, track) coordinates of every active cell, in row-major order.
inds = np.stack(np.where(a_t == 1)).transpose()
ts_acts = np.any(a_t, axis=1)
ts_inds = np.where(ts_acts)[0]  # time steps with at least one active track

# labels[ts, track] is the node id of that cell (track * 8 + ts).
labels = np.arange(32).reshape(4, 8).transpose()
print(labels)

# Link notes in consecutive active time steps that lie on different tracks;
# the edge attribute is the gap in time steps.
next_edges = []
for prev_ts, cur_ts in zip(ts_inds[:-1], ts_inds[1:]):
    prev_notes = inds[inds[:, 0] == prev_ts]
    cur_notes = inds[inds[:, 0] == cur_ts]
    for src, dst in itertools.product(prev_notes, cur_notes):
        if src[1] != dst[1]:
            next_edges.append((labels[tuple(src)], labels[tuple(dst)], cur_ts - prev_ts))

print(next_edges)
    

Onset edges: connect every pair of notes that start at the same time step, in both directions.

In [None]:
# Every pair of notes sharing a time step is linked in both directions
# (edge attribute 0: no time gap).
onset_edges = []
print(a_t)
print(labels)

for ts in ts_inds:
    same_onset = list(inds[inds[:, 0] == ts])
    # combinations() of fewer than two notes is empty, so no explicit guard is needed.
    forward = [(labels[tuple(x)], labels[tuple(y)], 0)
               for x, y in itertools.combinations(same_onset, 2)]
    onset_edges += forward
    onset_edges += [(dst, src, *rest) for src, dst, *rest in forward]

print(onset_edges)


Track edges: connect each note to the next note on the same track.

In [None]:
print(a_t)
print(labels)
track_edges = []

# Link each active note to the next active note on the same track. `inds` is
# in row-major order (sorted by time step), so consecutive rows are consecutive
# in time; the edge attribute is the gap in time steps.
for track in range(a_t.shape[1]):
    track_notes = list(inds[inds[:, 1] == track])
    e_inds = list(zip(track_notes[:-1], track_notes[1:]))
    print(e_inds)
    track_edges.extend(
        (labels[tuple(src)], labels[tuple(dst)], dst[0] - src[0])
        for src, dst in e_inds
    )

print(track_edges)

In [None]:
# Stack all edges into one (num_edges, 3) array of (src, dst, attribute).
# reshape(-1, 3) keeps the array 2-D even when a list is empty: np.array([])
# has shape (0,), which np.concatenate cannot join with an (n, 3) array.
track_edges = np.array(track_edges).reshape(-1, 3)
onset_edges = np.array(onset_edges).reshape(-1, 3)
np.concatenate((track_edges, onset_edges)).shape

In [None]:
pip install pypianoroll

In [None]:
import pypianoroll

In [None]:
# Load a sample MIDI file (path relative to the working directory) with
# pypianoroll and print its summary.
multitrack = pypianoroll.read("tests_fur-elise.mid")
print(multitrack)

In [None]:
# Raw piano-roll array of the first track (displayed in full).
multitrack.tracks[0].pianoroll

In [None]:
# Plot the loaded Multitrack before any processing.
multitrack.plot()

In [None]:
# Keep only the first 12 * resolution time steps, then binarize the piano rolls.
# Neither result is reassigned, so both calls are relied on to modify
# `multitrack` in place.
multitrack.trim(0, 12 * multitrack.resolution)
multitrack.binarize()

In [None]:
# Re-plot after trimming and binarizing.
multitrack.plot()

In [None]:
np.shape(multitrack.tracks[0].pianoroll)  # first track's piano roll: (time steps, pitches)