<a href="https://colab.research.google.com/github/EmanueleCosenza/Polyphemus/blob/main/midi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pwd

/home/cosenza/thesis/Polyphemus


In [2]:
!git branch

  main
* sparse


Library installation

In [3]:
#!tar -C data -xvzf data/lmd_matched.tar.gz

In [4]:
# Install the required music libraries
#!pip3 install muspy
#!pip3 install pypianoroll

In [5]:
# Install torch_geometric
#!v=$(python3 -c "import torch; print(torch.__version__)"); \
#pip3 install torch-scatter -f https://data.pyg.org/whl/torch-${v}.html; \
#pip3 install torch-sparse -f https://data.pyg.org/whl/torch-${v}.html; \
#pip3 install torch-geometric

Libraries and reproducibility

In [6]:
import numpy as np
import torch
import random
import os
import muspy
from itertools import product
import pypianoroll as pproll
import time
from tqdm.auto import tqdm

def set_seed(seed):
    """Seed every random number generator used in the notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Python-level generators and hashing
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU + every visible GPU)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning for repeatable kernels
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

seed = 42
set_seed(seed)

In [7]:
class MIDIPreprocessor():
    """Placeholder for an object-oriented version of the preprocessing code.

    All methods are stubs that do nothing and return ``None``.
    """

    def __init__(self):
        # `self` was missing, which made the class impossible to instantiate
        pass

    def preprocess_dataset(self, dir, early_exit=None):
        """Stub: preprocess every file found in `dir` (not implemented)."""
        pass

    def preprocess_file(self, f):
        """Stub: preprocess the single file `f` (not implemented)."""
        pass


# Todo: to config file (or separate files)
# Number of token slots stored for each (track, timestep) pair
MAX_SIMU_NOTES = 16 # 14 + SOS and EOS

# Special pitch tokens. preprocess_file clamps real pitches to 0..127,
# so the ids from 128 upwards are free for start/end/padding markers.
PITCH_SOS = 128
PITCH_EOS = 129
PITCH_PAD = 130
# Special duration tokens. preprocess_file stores durations as (dur - 1)
# with dur clamped to 1..MAX_DUR, i.e. 0..95, so 96 upwards are free.
DUR_SOS = 96
DUR_EOS = 97
DUR_PAD = 98

# Longest note duration (in timesteps) that is kept; longer notes are clamped
MAX_DUR = 96

# Number of time steps per quarter note
# To get bar resolution -> RESOLUTION*4
RESOLUTION = 8 
# Number of bars contained in each saved sequence
NUM_BARS = 2


def preprocess_file(filepath, dest_dir, num_samples):
    """Turn one MIDI file into fixed-length samples stored as ``.npz`` files.

    The file is skipped (and 0 is returned) when it cannot be parsed, when it
    declares a time signature other than 4/4, or when it lacks at least one
    drum, guitar, bass or strings track. Otherwise every combination of one
    drum, one bass and one guitar track (plus the merged strings track) is
    encoded and cut into windows of NUM_BARS bars with a stride of one bar.
    Each retained window is randomly transposed (drums excluded) and saved
    with ``np.savez`` under the keys ``seq_tensor`` and ``seq_acts``.

    Args:
        filepath: Path of the MIDI file to read.
        dest_dir: Directory in which the samples are written.
        num_samples: Number of samples already saved by previous calls; used
            as the offset for the numeric file names of the new samples.

    Returns:
        The number of samples saved for this file.
    """

    saved_samples = 0

    # Timesteps in one bar (4/4 time) and in one saved sequence
    bar_len = RESOLUTION*4
    seq_len = NUM_BARS*bar_len

    print("Preprocessing file " + filepath)

    # Load the file both as a pypianoroll song and a muspy song
    # (Need to load both since muspy.to_pypianoroll() is expensive)
    try:
        pproll_song = pproll.read(filepath, resolution=RESOLUTION)
        muspy_song = muspy.read(filepath)
    except Exception:
        # Deliberately broad: any parsing failure just skips the song
        print("Song skipped (Invalid song format)")
        return 0

    # Only accept songs that have a time signature of 4/4 and no time changes
    for t in muspy_song.time_signatures:
        if t.numerator != 4 or t.denominator != 4:
            print("Song skipped ({}/{} time signature)".
                            format(t.numerator, t.denominator))
            return 0

    # Gather tracks of pypianoroll song based on MIDI program number
    drum_tracks = []
    bass_tracks = []
    guitar_tracks = []
    strings_tracks = []

    for track in pproll_song.tracks:
        if track.is_drum:
            track.name = 'Drums'
            drum_tracks.append(track)
        elif 0 <= track.program <= 31:
            track.name = 'Guitar'
            guitar_tracks.append(track)
        elif 32 <= track.program <= 39:
            track.name = 'Bass'
            bass_tracks.append(track)
        else:
            # Tracks with program > 39 are all considered as strings tracks
            # and will be merged into a single track later on
            strings_tracks.append(track)

    # Filter song if it does not contain drum, guitar, bass or strings tracks
    if not drum_tracks or not guitar_tracks \
            or not bass_tracks or not strings_tracks:
        print("Song skipped (does not contain drum or "
                "guitar or bass or strings tracks)")
        return 0

    # Merge strings tracks into a single pypianoroll track
    strings = pproll.Multitrack(tracks=strings_tracks)
    strings_track = pproll.Track(pianoroll=strings.blend(mode='max'),
                                 program=48, name='Strings')

    # Single instruments can have multiple tracks.
    # Consider all possible combinations of drum, bass, and guitar tracks
    combinations = list(product(drum_tracks, bass_tracks, guitar_tracks))

    for comb_idx, combination in enumerate(combinations):

        print("Processing combination", comb_idx+1, "of", len(combinations))

        # Process combination (called 'subsong' from now on).
        # The drum track must stay first: transposition below skips track 0.
        drum_track, bass_track, guitar_track = combination
        tracks = [drum_track, bass_track, guitar_track, strings_track]

        pproll_subsong = pproll.Multitrack(
            tracks=tracks,
            tempo=pproll_song.tempo,
            resolution=RESOLUTION
        )
        muspy_subsong = muspy.from_pypianoroll(pproll_subsong)

        tracks_notes = [track.notes for track in muspy_subsong.tracks]

        # Obtain length of subsong (maximum of each track's length)
        length = 0
        for notes in tracks_notes:
            track_length = max(note.end for note in notes) if notes else 0
            length = max(length, track_length)
        length += 1

        # Add timesteps until length is a multiple of the bar length
        length += (-length) % bar_len

        tracks_tensors = []
        tracks_activations = []

        # Todo: adapt to velocity
        for notes in tracks_notes:

            # Initialize encoder-ready track tensor
            # track_tensor: (length x max_simu_notes x 2 (or 3 if velocity))
            # The last dimension contains pitches and durations (and velocities)
            # int16 is enough for small to medium duration values
            track_tensor = np.zeros((length, MAX_SIMU_NOTES, 2), np.int16)

            # Everything is padding except slot 0, which holds the SOS token
            track_tensor[:, :, 0] = PITCH_PAD
            track_tensor[:, 0, 0] = PITCH_SOS
            track_tensor[:, :, 1] = DUR_PAD
            track_tensor[:, 0, 1] = DUR_SOS

            # Keeps track of how many notes have been stored in each timestep
            # (int8 imposes that MAX_SIMU_NOTES < 256)
            # Starts at 1 because slot 0 is already taken by SOS
            notes_counter = np.ones(length, dtype=np.int8)

            # Todo: np.put_along_axis?
            for note in notes:
                # Insert note in the lowest position available in the timestep

                t = note.time

                if notes_counter[t] >= MAX_SIMU_NOTES-1:
                    # Skip note if there is no more space
                    # (the last slot is reserved for the EOS token)
                    continue

                pitch = max(min(note.pitch, 127), 0)
                track_tensor[t, notes_counter[t], 0] = pitch
                # Durations are clamped to 1..MAX_DUR and stored zero-based
                dur = max(min(MAX_DUR, note.duration), 1)
                track_tensor[t, notes_counter[t], 1] = dur-1
                notes_counter[t] += 1

            # Add end of sequence token right after the last stored note
            timesteps = np.arange(0, length)
            track_tensor[timesteps, notes_counter, 0] = PITCH_EOS
            track_tensor[timesteps, notes_counter, 1] = DUR_EOS

            # Get track activations, a boolean tensor indicating whether notes
            # are being played in a timestep (sustain does not count)
            # (needed for graph rep.)
            activations = np.array(notes_counter-1, dtype=bool)

            tracks_tensors.append(track_tensor)
            tracks_activations.append(activations)

        # (#tracks x length x max_simu_notes x 2 (or 3))
        subsong_tensor = np.stack(tracks_tensors, axis=0)

        # (#tracks x length)
        subsong_activations = np.stack(tracks_activations, axis=0)

        # Slide window over 'subsong_tensor' and 'subsong_activations' along the
        # time axis (2nd dimension) with the stride of a bar
        # Todo: np.lib.stride_tricks.as_strided(song_proll)
        for start in range(0, length-seq_len+1, bar_len):

            # Get the sequence and its activations (copied, since the
            # transposition below modifies the sequence in place)
            seq_tensor = np.copy(subsong_tensor[:, start:start+seq_len, :])
            seq_acts = np.copy(subsong_activations[:, start:start+seq_len])

            if NUM_BARS > 1:
                # (#tracks x NUM_BARS): True where the track plays in the bar
                bars = seq_acts.reshape(seq_acts.shape[0], NUM_BARS, -1)
                bars_acts = np.any(bars, axis=2)

                # Skip sequence if it contains more than one bar of consecutive
                # silence in at least one track. Adjacent bars are compared
                # track by track, so that silences belonging to different
                # tracks are never mistaken for a single long silence.
                silent_bars = np.logical_not(bars_acts)
                if np.any(silent_bars[:, :-1] & silent_bars[:, 1:]):
                    continue

                # Skip sequence if it contains one bar of complete silence
                # (in terms of note activations)
                silences = np.logical_not(np.any(bars_acts, axis=0))
                if np.any(silences):
                    continue

            else:
                # In the case of just 1 bar, skip it if all tracks are silenced
                bar_acts = np.any(seq_acts, axis=1)
                if not np.any(bar_acts):
                    continue

            # Randomly transpose the pitches of the sequence (-5 to 6 semitones)
            # Not considering pad, sos, eos tokens
            # Not transposing drums/percussions (track 0)
            shift = np.random.choice(np.arange(-5, 7), 1)
            # Basic slicing returns a view: edits land in seq_tensor
            non_perc = seq_tensor[1:, ...]
            non_perc_pitches = non_perc[..., 0]
            cond = (non_perc_pitches != PITCH_PAD) &                            \
                   (non_perc_pitches != PITCH_SOS) &                            \
                   (non_perc_pitches != PITCH_EOS)
            non_perc[cond, 0] += shift
            non_perc[cond, 0] = np.clip(non_perc[cond, 0], a_min=0, a_max=127)

            # Save sample (seq_tensor and seq_acts) to file
            curr_sample = str(num_samples + saved_samples)
            sample_filepath = os.path.join(dest_dir, curr_sample)
            np.savez(sample_filepath, seq_tensor=seq_tensor, seq_acts=seq_acts)

            saved_samples += 1


    print("File preprocessing finished. Saved samples:", saved_samples)
    print()

    return saved_samples



# Total number of files: 116189
# Number of unique files: 45129
def preprocess_dataset(dataset_dir, dest_dir, num_files=45129, early_exit=None):
    """Preprocess every uniquely named file found under `dataset_dir`.

    The directory tree is walked recursively. Files are de-duplicated by
    file name: only the first occurrence of a name is passed to
    `preprocess_file`, later occurrences are merely counted.

    Args:
        dataset_dir: Root directory of the dataset.
        dest_dir: Directory in which the produced samples are written.
        num_files: Expected number of unique files; only used as the total
            of the progress bar when `early_exit` is not set.
        early_exit: If not None, stop once this many unique files have been
            seen (it also becomes the total of the progress bar).
    """

    files_dict = {}
    seen = 0
    tot_samples = 0
    not_filtered = 0
    finished = False

    print("Starting preprocessing")

    total = early_exit if early_exit is not None else num_files
    progress_bar = tqdm(range(total))
    start = time.time()

    try:
        # Visit recursively the directories inside the dataset directory
        for dirpath, dirs, files in os.walk(dataset_dir):

            # Sort alphabetically the found directories
            # (to help guess the remaining time)
            dirs.sort()

            print("Current path:", dirpath)
            print()

            for f in files:

                seen += 1

                if f in files_dict:
                    # Skip already seen file
                    files_dict[f] += 1
                    continue

                # File never seen before, add to dictionary of files
                # (from filename to # of occurrences)
                files_dict[f] = 1

                # Preprocess file
                filepath = os.path.join(dirpath, f)
                n_saved = preprocess_file(filepath, dest_dir, tot_samples)

                tot_samples += n_saved
                if n_saved > 0:
                    not_filtered += 1

                progress_bar.update(1)

                # Todo: also print # of processed (not filtered) files
                #       and # of produced sequences (samples)
                print("Total number of seen files:", seen)
                print("Number of unique files:", len(files_dict))
                print("Total number of non filtered songs:", not_filtered)
                print("Total number of saved samples:", tot_samples)
                print()

                # Exit when a maximum number of files has been processed (if set)
                if early_exit is not None and len(files_dict) >= early_exit:
                    finished = True
                    break

            if finished:
                break
    finally:
        # Release the progress bar even if preprocessing is interrupted
        progress_bar.close()

    end = time.time()
    hours, rem = divmod(end-start, 3600)
    minutes, seconds = divmod(rem, 60)
    print("Preprocessing completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
              .format(int(hours),int(minutes),seconds))


In [8]:
#!rm -rf data/preprocessed/
#!mkdir data/preprocessed

In [9]:
#dataset_dir = 'data/lmd_matched/Y/G/'
#dest_dir = 'data/preprocessed'

Check preprocessed data:

In [10]:
#preprocess_dataset(dataset_dir, dest_dir, early_exit=10)

In [11]:
#filepath = os.path.join(dest_dir, "5.npz")
#data = np.load(filepath)

In [12]:
#print(data["seq_tensor"].shape)
#print(data["seq_acts"].shape)

In [13]:
#data["seq_tensor"][0, 1]

# Model

In [14]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset
from torch_geometric.data import Data
import itertools
from torch_geometric.data.collate import collate


class MIDIDataset(Dataset):
    """Dataset of preprocessed sequences, each returned as a single graph.

    Sample ``idx`` is read from ``<dir>/<idx>.npz``, which must contain the
    arrays ``seq_tensor`` (#tracks x #timesteps x #notes x 2) and ``seq_acts``
    (#tracks x #timesteps), as written by `preprocess_file`.

    For every bar a graph is built whose nodes are the active
    (track, timestep) pairs, numbered track by track and, within a track, by
    increasing timestep. The per-bar graphs are then merged into one object.

    Args:
        dir: Directory containing the ``.npz`` samples. The dataset length is
            the number of files found directly inside it.
        n_bars: Number of bars contained in each sample.
    """

    # These values have to agree with the tokens written at preprocessing
    # time (PITCH_PAD, and the number of distinct pitch/duration tokens).
    _PITCH_PAD = 130
    _NUM_PITCH_TOKENS = 131
    _NUM_DUR_TOKENS = 99

    def __init__(self, dir, n_bars=2):
        self.dir = dir
        _, _, files = next(os.walk(self.dir))
        self.len = len(files)
        self.n_bars = n_bars


    def __len__(self):
        return self.len


    @staticmethod
    def _index_activations(acts):
        """Return helper views of a (#tracks x #timesteps) activation matrix.

        Returns:
            a_t: `acts` transposed, (#timesteps x #tracks).
            inds: (N x 2) array of (timestep, track) pairs of the active
                entries, sorted by timestep and then by track.
            labels: (#timesteps x #tracks) integer array holding, for every
                active entry, its node id (entries of inactive positions
                are 0 and must not be used).
        """
        a_t = acts.transpose()
        inds = np.stack(np.where(a_t == 1)).transpose()

        # Node ids follow the row-major order of `acts` (track, then time)
        labels = np.zeros(acts.shape, dtype=np.int64)
        acts_inds = np.where(acts == 1)
        labels[acts_inds] = np.arange(len(acts_inds[0]))

        return a_t, inds, labels.transpose()


    @staticmethod
    def _with_inverse(edges):
        """Return `edges` followed by the same edges with endpoints swapped."""
        return edges + [(e[1], e[0], *e[2:]) for e in edges]


    @staticmethod
    def _one_hot(values, num_classes):
        """One-hot encode an integer array along a new trailing axis."""
        return np.eye(num_classes, dtype=float)[values]


    def _get_track_edges(self, acts, edge_type_ind=0):
        """Connect consecutive activations of the same track.

        Returns an (E x 4) int64 array with columns (source, destination,
        edge type, timestep distance); the edge type is
        ``edge_type_ind + track``. Every edge is followed by its inverse.
        The array is empty (shape (0,)) when there are no edges.
        """
        a_t, inds, labels = self._index_activations(acts)

        track_edges = []

        for track in range(a_t.shape[1]):
            tr_inds = list(inds[inds[:, 1] == track])
            edges = [(labels[tuple(src)], labels[tuple(dst)],
                      edge_type_ind+track, dst[0]-src[0])
                     for src, dst in zip(tr_inds, tr_inds[1:])]
            track_edges.extend(self._with_inverse(edges))

        return np.array(track_edges, dtype=np.int64)


    def _get_onset_edges(self, acts, edge_type_ind=4):
        """Connect all pairs of activations that share the same timestep.

        Returns an (E x 4) int64 array with columns (source, destination,
        edge type, 0), each edge followed by its inverse; empty if no
        timestep has at least two active tracks.
        """
        a_t, inds, labels = self._index_activations(acts)
        ts_inds = np.where(np.any(a_t, axis=1))[0]

        onset_edges = []

        for ts in ts_inds:
            ts_acts_inds = list(inds[inds[:, 0] == ts])
            if len(ts_acts_inds) < 2:
                continue
            edges = [(labels[tuple(src)], labels[tuple(dst)], edge_type_ind, 0)
                     for src, dst in itertools.combinations(ts_acts_inds, 2)]
            onset_edges.extend(self._with_inverse(edges))

        return np.array(onset_edges, dtype=np.int64)


    def _get_next_edges(self, acts, edge_type_ind=5):
        """Connect activations of consecutive active timesteps across tracks.

        For each pair of successive timesteps having at least one activation,
        every activation of the first is linked to every activation of the
        second that belongs to a different track. Returns an (E x 4) int64
        array with columns (source, destination, edge type, timestep
        distance), each edge followed by its inverse.
        """
        a_t, inds, labels = self._index_activations(acts)
        ts_inds = np.where(np.any(a_t, axis=1))[0]

        next_edges = []

        for ind_s, ind_e in zip(ts_inds[:-1], ts_inds[1:]):
            starts = inds[inds[:, 0] == ind_s]
            ends = inds[inds[:, 0] == ind_e]

            edges = [(labels[tuple(src)], labels[tuple(dst)],
                      edge_type_ind, ind_e-ind_s)
                     for src, dst in itertools.product(starts, ends)
                     if src[1] != dst[1]]
            next_edges.extend(self._with_inverse(edges))

        return np.array(next_edges, dtype=np.int64)


    def _get_super_edges(self, num_nodes, edge_type_ind=6):
        """Connect an extra node (id `num_nodes`) to every other node.

        Returns a (2*num_nodes x 4) int64 array: first the edges leaving the
        extra node, then the inverse ones.
        """
        num_nodes = int(num_nodes)
        super_edges = [(num_nodes, i, edge_type_ind, 0) for i in range(num_nodes)]

        # The inverse edges were previously built from an undefined name
        return np.array(self._with_inverse(super_edges), dtype=np.int64)


    def _get_node_features(self, acts, num_nodes):
        """Return a (num_nodes x #tracks) one-hot matrix of node tracks."""

        num_tracks = acts.shape[0]
        features = torch.zeros((num_nodes, num_tracks), dtype=torch.float)
        features[np.arange(num_nodes), np.stack(np.where(acts))[0]] = 1.

        return features


    def __getitem__(self, idx):
        """Load sample `idx` and return it as one merged graph object.

        Besides the graph fields, the returned object carries:
            bars: bar index of every node.
            x_seq: one-hot note tokens of the active (bar, track, timestep)
                positions only, (#nodes x #notes x (131 + 99)).
            x_acts: dense activations, (#bars x #tracks x #timesteps).
            src_mask: True where the pitch token is padding, same rows as
                `x_seq`.
        """

        # Load tensors (the context manager closes the archive handle)
        sample_path = os.path.join(self.dir, str(idx) + ".npz")
        with np.load(sample_path) as data:
            seq_tensor = data["seq_tensor"]
            seq_acts = data["seq_acts"]

        # From (#tracks x #timesteps x ...) to (#bars x #tracks x #timesteps x ...)
        seq_tensor = seq_tensor.reshape(seq_tensor.shape[0], self.n_bars, -1,
                                        seq_tensor.shape[2], seq_tensor.shape[3])
        seq_tensor = seq_tensor.transpose(1, 0, 2, 3, 4)
        seq_acts = seq_acts.reshape(seq_acts.shape[0], self.n_bars, -1)
        seq_acts = seq_acts.transpose(1, 0, 2)

        # Construct src_key_padding_mask
        src_mask = torch.from_numpy((seq_tensor[..., 0] == self._PITCH_PAD))

        # From decimals to one-hot (pitch and duration), then concatenate
        onehot_p = self._one_hot(seq_tensor[..., 0], self._NUM_PITCH_TOKENS)
        onehot_d = self._one_hot(seq_tensor[..., 1], self._NUM_DUR_TOKENS)
        new_seq_tensor = np.concatenate((onehot_p, onehot_d), axis=-1)

        graphs = []

        # Iterate over bars and construct a graph for each bar
        for i in range(self.n_bars):

            # Number of nodes
            n = torch.sum(torch.Tensor(seq_acts[i]), dtype=torch.long)

            # Get edges from boolean activations
            # Todo: optimize and refactor
            track_edges = self._get_track_edges(seq_acts[i])
            onset_edges = self._get_onset_edges(seq_acts[i])
            next_edges = self._get_next_edges(seq_acts[i])
            edges = [track_edges, onset_edges, next_edges]

            # Concatenate edge tensors (N x 4) (if any)
            # First two columns -> source and dest nodes
            # Third column -> edge_type, Fourth column -> timestep distance
            no_edges = all(len(e) == 0 for e in edges)
            if not no_edges:
                edge_list = np.concatenate([x for x in edges
                                              if x.size > 0])
                edge_list = torch.from_numpy(edge_list)

            # Adapt tensor to torch_geometric's Data
            # No edges: add fictitious self-edge
            edge_index = (torch.LongTensor([[0], [0]]) if no_edges else
                                   edge_list[:, :2].t().contiguous())
            attrs = (torch.Tensor([[0, 0]]) if no_edges else
                                           edge_list[:, 2:])

            # One hot timestep distance concatenated to edge type
            # (column 0 is the edge type, column d+1 flags distance d)
            edge_attrs = torch.zeros(attrs.size(0), 1+seq_acts.shape[-1])
            edge_attrs[:, 0] = attrs[:, 0]
            edge_attrs[np.arange(edge_attrs.size(0)), attrs.long()[:, 1]+1] = 1

            node_features = self._get_node_features(seq_acts[i], n)
            # Track 0 is the drum track
            is_drum = node_features[:, 0].bool()

            graphs.append(Data(edge_index=edge_index, edge_attrs=edge_attrs,
                               num_nodes=n, node_features=node_features,
                               is_drum=is_drum))


        # Merge the graphs corresponding to different bars into a single big graph
        graphs, _, inc_dict = collate(
            Data,
            data_list=graphs,
            increment=True,
            add_batch=True
        )

        # Change bars assignment vector name (otherwise, Dataloader's collate
        # would overwrite graphs.batch)
        graphs.bars = graphs.batch

        # Filter silences in order to get a sparse representation
        active = seq_acts.reshape(-1).astype(bool)
        new_seq_tensor = new_seq_tensor.reshape(-1, new_seq_tensor.shape[-2],
                                                new_seq_tensor.shape[-1])
        src_mask = src_mask.reshape(-1, src_mask.shape[-1])
        new_seq_tensor = new_seq_tensor[active]
        src_mask = src_mask[active]

        # Todo: start with torch at mount
        graphs.x_seq = torch.Tensor(new_seq_tensor)
        graphs.x_acts = torch.Tensor(seq_acts)
        graphs.src_mask = src_mask

        return graphs


In [15]:
from typing import Optional, Union, Tuple
from torch_geometric.typing import OptTensor, Adj
from typing import Callable
from torch_geometric.nn.inits import reset

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter as Param
from torch.nn import Parameter
from torch_scatter import scatter
from torch_sparse import SparseTensor, matmul, masked_select_nnz
from torch_geometric.nn.conv import MessagePassing

from torch_geometric.nn.inits import glorot, zeros


# TorchScript overload declarations for `masked_edge_index`. The bodies are
# intentionally a bare `pass`: only the `# type:` signatures are meaningful,
# the actual implementation is the undecorated definition below.
@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (Tensor, Tensor) -> Tensor
    pass


@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (SparseTensor, Tensor) -> SparseTensor
    pass


def masked_edge_index(edge_index, edge_mask):
    """Keep only the edges selected by `edge_mask`.

    When `edge_index` is a Tensor its columns are selected with
    `edge_mask`; otherwise the selection is delegated to
    `masked_select_nnz` with the COO layout.
    """
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]
    else:
        return masked_select_nnz(edge_index, edge_mask, layout='coo')

def masked_edge_attrs(edge_attrs, edge_mask):
    """Return the rows of `edge_attrs` selected by `edge_mask`."""
    selected_rows = edge_attrs[edge_mask, :]
    return selected_rows


class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    In this variant every message is additionally modulated by the edge
    attributes: :obj:`nn` maps the attributes of an edge to a vector of
    per-feature weights that multiplies the source node features
    element-wise, followed by a ReLU and dropout (see :meth:`message`).

    .. note::
        This implementation is as memory-efficient as possible by iterating
        over each individual relation type.
        Therefore, it may result in low GPU utilization in case the graph has a
        large number of relations.
        As an alternative approach, :class:`FastRGCNConv` does not iterate over
        each individual type, but may consume a large amount of memory to
        compensate.
        We advise to check out both implementations to see which one fits your
        needs.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
            In case no input features are given, this argument should
            correspond to the number of nodes in your graph.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
            maps edge features :obj:`edge_attr` of shape :obj:`[-1,
            num_edge_features]` to per-edge weights. Only the first
            :obj:`in_channels` (source size) outputs are used, so the output
            must have at least that many features.
        num_bases (int, optional): If set to not :obj:`None`, this layer will
            use the basis-decomposition regularization scheme where
            :obj:`num_bases` denotes the number of bases to use.
            (default: :obj:`None`)
        num_blocks (int, optional): If set to not :obj:`None`, this layer will
            use the block-diagonal-decomposition regularization scheme where
            :obj:`num_blocks` denotes the number of blocks to use.
            (default: :obj:`None`)
        dropout (float, optional): Dropout probability applied to every
            message. (default: :obj:`0.1`)
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"mean"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        nn: Callable,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        dropout: Optional[float] = 0.1,
        aggr: str = 'mean',
        root_weight: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(aggr=aggr, node_dim=0, **kwargs)

        if num_bases is not None and num_blocks is not None:
            raise ValueError('Can not apply both basis-decomposition and '
                             'block-diagonal-decomposition at the same time.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nn = nn
        self.dropout = dropout
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.num_blocks = num_blocks

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        # Size of the source node features
        self.in_channels_l = in_channels[0]

        if num_bases is not None:
            self.weight = Parameter(
                torch.Tensor(num_bases, in_channels[0], out_channels))
            self.comp = Parameter(torch.Tensor(num_relations, num_bases))

        elif num_blocks is not None:
            assert (in_channels[0] % num_blocks == 0
                    and out_channels % num_blocks == 0)
            self.weight = Parameter(
                torch.Tensor(num_relations, num_blocks,
                             in_channels[0] // num_blocks,
                             out_channels // num_blocks))
            self.register_parameter('comp', None)

        else:
            self.weight = Parameter(
                torch.Tensor(num_relations, in_channels[0], out_channels))
            self.register_parameter('comp', None)

        if root_weight:
            self.root = Param(torch.Tensor(in_channels[1], out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Param(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialize the relation weights, the edge network, the
        basis coefficients, the root weight and the bias."""
        glorot(self.weight)
        reset(self.nn)
        glorot(self.comp)
        glorot(self.root)
        zeros(self.bias)


    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
                edge_index: Adj, edge_type: OptTensor = None,
                edge_attr: OptTensor = None):
        r"""
        Args:
            x: The input node features. Can be either a :obj:`[num_nodes,
                in_channels]` node feature matrix, or an optional
                one-dimensional node index tensor (in which case input features
                are treated as trainable node embeddings).
                Furthermore, :obj:`x` can be of type :obj:`tuple` denoting
                source and destination node features.
            edge_index: The graph connectivity.
            edge_type: The one-dimensional relation type/index for each edge in
                :obj:`edge_index`.
                Should be only :obj:`None` in case :obj:`edge_index` is of type
                :class:`torch_sparse.tensor.SparseTensor`.
                (default: :obj:`None`)
            edge_attr: The edge attributes, one row per edge, that are fed
                to :obj:`nn` to weight the messages. Although it defaults
                to :obj:`None`, it is indexed unconditionally, so it has to
                be provided.

        Returns:
            The :obj:`[num_target_nodes, out_channels]` output features.
        """

        # Convert input features to a pair of node features or node indices.
        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x
        if x_l is None:
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))

        if isinstance(edge_index, SparseTensor):
            edge_type = edge_index.storage.value()
        assert edge_type is not None

        # propagate_type: (x: Tensor)
        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        if self.num_bases is not None:  # Basis-decomposition =================
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels)

        if self.num_blocks is not None:  # Block-diagonal-decomposition =====

            if x_l.dtype == torch.long:
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')

            for i in range(self.num_relations):
                rel_mask = edge_type == i
                tmp = masked_edge_index(edge_index, rel_mask)
                attr = masked_edge_attrs(edge_attr, rel_mask)
                # `message` requires the edge attributes, so they have to be
                # forwarded here exactly as in the non-regularized branch
                h = self.propagate(tmp, x=x_l, size=size, edge_attr=attr)
                h = h.view(-1, weight.size(1), weight.size(2))
                h = torch.einsum('abc,bcd->abd', h, weight[i])
                out += h.contiguous().view(-1, self.out_channels)

        else:  # No regularization/Basis-decomposition ========================
            for i in range(self.num_relations):
                rel_mask = edge_type == i
                tmp = masked_edge_index(edge_index, rel_mask)
                attr = masked_edge_attrs(edge_attr, rel_mask)

                if x_l.dtype == torch.long:
                    # NOTE(review): this path does not forward `edge_attr`,
                    # which `message` expects; it looks unsupported by the
                    # edge-weighted message as written -- confirm before use.
                    out += self.propagate(tmp, x=weight[i, x_l], size=size)
                else:
                    h = self.propagate(tmp, x=x_l, size=size,
                                       edge_attr=attr)
                    out = out + (h @ weight[i])

        root = self.root
        if root is not None:
            out += root[x_r] if x_r.dtype == torch.long else x_r @ root

        if self.bias is not None:
            out += self.bias

        return out


    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Weight the source features of each edge by ``nn(edge_attr)``.

        Only the first ``in_channels`` outputs of ``nn`` are used; the
        weighted features then go through a ReLU and dropout.
        """
        weights = self.nn(edge_attr)
        weights = weights[..., :self.in_channels_l]
        weights = weights.view(-1, self.in_channels_l)
        ret = x_j * weights
        ret = F.relu(ret)
        ret = F.dropout(ret, p=self.dropout, training=self.training)
        return ret


    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Aggregate `x` over `adj_t` with its values dropped.

        NOTE(review): no edge-attribute weighting is applied here, unlike
        in `message` -- confirm this is intended for sparse inputs.
        """
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_relations={self.num_relations})')

In [16]:
import torch
from torch import nn, Tensor
from torch_geometric.nn.conv import GCNConv#, RGCNConv
from torch_geometric.nn.norm import BatchNorm
from torch_geometric.nn.glob import GlobalAttention
import torch.nn.functional as F
import math
import torch.optim as optim
from torch_scatter import scatter_mean
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence


# Todo: check and think about max_len
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a sequence, then applies dropout.

    Even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``. The table is precomputed for ``max_len`` positions and
    stored as the non-trainable buffer ``pe`` of shape ``[max_len, 1, d_model]``.

    Args:
        d_model: size of the feature dimension (even or odd).
        dropout: dropout probability applied to the summed output.
        max_len: number of positions precomputed; inputs must not be longer.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 256):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             (-math.log(10000.0)/d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd index than even index, so
        # the cosine block must drop the last frequency (no-op when d_model is
        # even). Without the slice the assignment raises a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    

class MLP(nn.Module):
    """Stack of linear layers with dropout on every layer input.

    Args:
        input_dim: size of the input features.
        hidden_dim: width of every intermediate layer.
        output_dim: size of the output features.
        num_layers: total number of linear layers (must be >= 1).
        act: if true, ReLU follows every layer, the last one included.
        dropout: dropout probability applied before each linear layer.
    """

    def __init__(self, input_dim=256, hidden_dim=256, output_dim=256, num_layers=2,
                 act=True, dropout=0.1):
        super().__init__()

        assert num_layers >= 1

        # Layer widths from input to output; a single layer maps the input
        # straight to the output with no hidden width involved.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out)
             for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

        self.act = act
        self.p = dropout

    def forward(self, x):
        out = x
        for linear in self.layers:
            out = linear(F.dropout(out, p=self.p, training=self.training))
            if self.act:
                out = F.relu(out)
        return out
    
    
class EncoderRNN(nn.Module):
    """Single-step recurrent encoder: linear embedding followed by a bidirectional LSTM.

    Args:
        input_size: size of the raw input vector.
        hidden_size: embedding width and LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        dropout: accepted for signature compatibility; not used by this module.
    """

    def __init__(self, input_size=240, hidden_size=256, num_layers=2, 
                 dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Linear(in_features=input_size,
                                   out_features=hidden_size)
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True,
                            bidirectional=True)

    def forward(self, x, hidden):
        """Embed ``x``, reshape it to a ``[1, 1, -1]`` sequence and run the LSTM.

        Returns:
            The ``(output, hidden)`` pair produced by the LSTM.
        """
        step = self.embedding(x).view(1, 1, -1)
        return self.lstm(step, hidden)
    

class DecoderRNN(nn.Module):
    """Single-step recurrent decoder: linear embedding followed by a bidirectional LSTM.

    Args:
        hidden_size: embedding width and LSTM hidden size.
        output_size: size of the vectors being decoded (input of the embedding).
        num_layers: number of stacked LSTM layers.
        dropout: accepted for signature compatibility; not used by this module.
    """

    def __init__(self, hidden_size=256, output_size=240, num_layers=2, 
                 dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Linear(in_features=output_size,
                                   out_features=hidden_size)
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True,
                            bidirectional=True)
        # Projection back to the token space; `forward` does not apply it.
        self.out = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x, hidden):
        """Embed ``x``, reshape it to a ``[1, 1, -1]`` sequence and run the LSTM.

        Returns:
            The ``(output, hidden)`` pair produced by the LSTM.
        """
        step = self.embedding(x).view(1, 1, -1)
        return self.lstm(step, hidden)

    def initHidden(self):
        """Return a zero tensor of shape ``[1, 1, hidden_size]``.

        NOTE(review): relies on a module-level ``device`` name, and the shape
        is that of a single-layer unidirectional state -- confirm before
        passing the result to ``self.lstm``.
        """
        state_shape = (1, 1, self.hidden_size)
        return torch.zeros(state_shape, device=device)


class GraphEncoder(nn.Module):
    """Stack of relational graph convolutions with residual connections.

    Every layer is an ``RGCNConv`` optionally followed by batch norm; the
    layer input is added back to the ReLU-activated output. One linear edge
    network (``num_dists -> input_dim``) is shared by all layers.

    Args:
        input_dim: node feature size entering the first layer.
        hidden_dim: node feature size produced by every layer.
        n_layers: number of conv layers (at least one is always created).
        num_relations: number of edge types handled by each ``RGCNConv``.
        num_dists: size of the per-edge attribute vector.
        batch_norm: whether to normalize each layer's output.
        dropout: dropout probability applied to each layer's input.
    """

    def __init__(self, input_dim=256, hidden_dim=256, n_layers=3, 
                 num_relations=3, num_dists=32, batch_norm=False, dropout=0.1):
        super().__init__()

        self.layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.batch_norm = batch_norm
        self.p = dropout

        # The same edge network instance is handed to every conv layer.
        edge_nn = nn.Linear(num_dists, input_dim)

        # Only the first layer consumes `input_dim`; the rest map hidden->hidden.
        layer_inputs = [input_dim] + [hidden_dim] * (n_layers - 1)
        for in_dim in layer_inputs:
            self.layers.append(RGCNConv(in_dim, hidden_dim,
                                        num_relations, edge_nn))
            if self.batch_norm:
                self.norm_layers.append(BatchNorm(hidden_dim))

    def forward(self, data):
        """Propagate ``data.x`` through all layers and return the node features.

        ``data.edge_attrs`` column 0 is used as the edge type and the
        remaining columns as the edge attributes.
        """
        edge_index = data.edge_index
        edge_attrs = data.edge_attrs
        edge_type, edge_attr = edge_attrs[:, 0], edge_attrs[:, 1:]

        h = data.x
        for idx, conv in enumerate(self.layers):
            skip = h
            h = F.dropout(h, p=self.p, training=self.training)
            h = conv(h, edge_index, edge_type, edge_attr)
            if self.batch_norm:
                h = self.norm_layers[idx](h)
            h = skip + F.relu(h)

        return h
    
    
class CNNEncoder(nn.Module):
    """Small CNN that maps a 4x32 single-channel grid to a dense vector.

    Two 3x3 convolutions (1->8->16 channels) with a (1, 4) max-pool in
    between, followed by a two-layer dense head.

    Args:
        output_dim: size of the produced vector.
        dense_dim: width of the hidden dense layer.
        batch_norm: insert ``BatchNorm2d`` after each convolution.
        dropout: dropout probability used before each dense layer.
    """

    def __init__(self, output_dim=256, dense_dim=256, batch_norm=False,
                 dropout=0.1):
        super().__init__()

        def conv_stage(c_in, c_out):
            # conv -> (optional batch norm) -> in-place ReLU
            stage = [nn.Conv2d(c_in, c_out, 3, padding=1)]
            if batch_norm:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.ReLU(True))
            return stage

        self.conv = nn.Sequential(
            # 4*32 --> 8*4*32
            *conv_stage(1, 8),
            # 8*4*32 --> 8*4*8
            nn.MaxPool2d((1, 4), stride=(1, 4)),
            # 8*4*8 --> 16*4*8
            *conv_stage(8, 16),
        )

        self.flatten = nn.Flatten(start_dim=1)

        self.lin = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(16*4*8, dense_dim),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Linear(dense_dim, output_dim)
        )

    def forward(self, x):
        # Add the channel dimension expected by Conv2d.
        feats = self.conv(x.unsqueeze(1))
        return self.lin(self.flatten(feats))
    

class CNNDecoder(nn.Module):
    """Mirror of ``CNNEncoder``: dense vector -> 16x4x8 map -> 1x4x32 grid.

    Args:
        input_dim: size of the input vector.
        dense_dim: width of the hidden dense layer.
        batch_norm: insert ``BatchNorm2d`` after the first convolution.
        dropout: dropout probability used before each dense layer.
    """

    def __init__(self, input_dim=256, dense_dim=256, batch_norm=False,
                 dropout=0.1):
        super().__init__()

        self.lin = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_dim, dense_dim),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Linear(dense_dim, 16*4*8),
            nn.ReLU(True)
        )

        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(16, 4, 8))

        # Upsample the time axis back by 4, then reduce 16 -> 8 -> 1 channels.
        conv_layers = [
            nn.Upsample(scale_factor=(1, 4), mode='nearest'),
            nn.Conv2d(16, 8, 3, padding=1),
        ]
        if batch_norm:
            conv_layers.append(nn.BatchNorm2d(8))
        conv_layers.append(nn.ReLU(True))
        conv_layers.append(nn.Conv2d(8, 1, 3, padding=1))
        self.conv = nn.Sequential(*conv_layers)

    def forward(self, x):
        grid = self.conv(self.unflatten(self.lin(x)))
        # An extra singleton dimension is inserted at position 1 of the result.
        return grid.unsqueeze(1)
        

class Encoder(nn.Module):
    """VAE encoder: maps note content and track activations to ``(mu, log_var)``.

    Two branches are computed and merged:

    * attributes: per-node note tokens are embedded (separate pitch embeddings
      for drums and non-drums, shared duration embedding), collapsed into one
      chord vector per node, propagated with ``GraphEncoder`` and pooled per
      bar with ``GlobalAttention``;
    * structure: the activation grid ``x_acts`` is encoded per bar with
      ``CNNEncoder``.

    All hyper-parameters come in as keyword arguments and are stored as
    attributes. The code reads: ``d``, ``dropout``, ``d_token_pitches``,
    ``d_token_dur``, ``max_simu_notes``, ``gnn_n_layers``, ``n_relations``,
    ``batch_norm``, ``n_bars``, ``n_tracks`` and ``resolution``.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # Expose every hyper-parameter passed by the caller as an attribute.
        self.__dict__.update(kwargs)

        self.dropout_layer = nn.Dropout(p=self.dropout)

        # Token embeddings: pitch and duration each take half of d.
        self.notes_pitch_emb = nn.Linear(self.d_token_pitches, self.d//2)
        self.bn_npe = nn.BatchNorm1d(num_features=self.d//2)

        self.drums_pitch_emb = nn.Linear(self.d_token_pitches, self.d//2)
        self.bn_dpe = nn.BatchNorm1d(num_features=self.d//2)

        self.dur_emb = nn.Linear(self.d_token_dur, self.d//2)
        self.bn_de = nn.BatchNorm1d(num_features=self.d//2)

        # Collapses the (max_simu_notes - 1) token embeddings of a node.
        self.chord_encoder = nn.Linear(self.d * (self.max_simu_notes-1),
                                       self.d)

        # Graph encoder
        self.graph_encoder = GraphEncoder(dropout=self.dropout,
                                          input_dim=self.d,
                                          hidden_dim=self.d,
                                          n_layers=self.gnn_n_layers,
                                          num_relations=self.n_relations,
                                          batch_norm=self.batch_norm)

        gate_nn = nn.Sequential(
            MLP(input_dim=self.d, output_dim=1, num_layers=1, act=False,
                dropout=self.dropout),
            nn.BatchNorm1d(1)
        )
        self.graph_attention = GlobalAttention(gate_nn)

        self.bars_encoder_attr = nn.Linear(self.n_bars*self.d, self.d)

        self.cnn_encoder = CNNEncoder(dense_dim=self.d,
                                      output_dim=self.d,
                                      dropout=0,
                                      batch_norm=self.batch_norm)

        self.bars_encoder_struct = nn.Linear(self.n_bars*self.d, self.d)

        self.linear_merge = nn.Linear(2*self.d, self.d)
        self.bn_lm = nn.BatchNorm1d(num_features=self.d)

        # Linear layers that compute the final mu and log_var
        # Todo: as parameters
        self.linear_mu = nn.Linear(self.d, self.d)
        self.linear_log_var = nn.Linear(self.d, self.d)

    def _encode_chords(self, notes, pitch_emb, pitch_bn):
        """Embed the tokens of each node and collapse them into one chord vector.

        ``notes`` is indexed as [n_nodes, n_notes, d_token]; the first
        ``d_token_pitches`` features are the pitch part, the rest the duration.
        """
        n_nodes, n_notes = notes.size(0), notes.size(1)
        half = self.d//2

        pitch = pitch_emb(notes[..., :self.d_token_pitches])
        pitch = pitch_bn(pitch.view(-1, half)).view(n_nodes, n_notes, half)

        dur = self.dur_emb(notes[..., self.d_token_pitches:])
        dur = self.bn_de(dur.view(-1, half)).view(n_nodes, n_notes, half)

        tokens = torch.cat((pitch, dur), dim=-1)
        # [n_nodes x max_simu_notes-1 x d]

        chords = self.chord_encoder(
            tokens.view(-1, self.d*(self.max_simu_notes-1)))
        return self.dropout_layer(F.relu(chords))
        # [n_nodes x d]

    def forward(self, x_seq, x_acts, x_graph, src_mask):
        """Encode a batch and return the posterior parameters.

        Args:
            x_seq: per-node token sequences; the first token of every
                sequence is dropped before embedding.
            x_acts: activation grid, reshaped to
                ``[-1, n_tracks, resolution*4]`` for the CNN.
            x_graph: graph batch providing ``is_drum``, ``bars``, ``batch``
                and the edge data used by ``GraphEncoder``. Its ``x`` and
                ``distinct_bars`` attributes are overwritten here.
            src_mask: accepted for interface compatibility; not used.

        Returns:
            ``(mu, log_var)``, each produced by a ``d -> d`` linear layer.
        """
        # No start of seq token
        x_seq = x_seq[:, 1:, :]

        is_drum = x_graph.is_drum
        not_drum = torch.logical_not(is_drum)

        # Drums are processed before non-drums so that the shared duration
        # batch norm and the dropout RNG see the two groups in a fixed order.
        drums = self._encode_chords(x_seq[is_drum],
                                    self.drums_pitch_emb, self.bn_dpe)
        non_drums = self._encode_chords(x_seq[not_drum],
                                        self.notes_pitch_emb, self.bn_npe)

        # Merge drums and non-drums. The buffer takes dtype and device from
        # the embeddings themselves instead of a hard-coded torch.half, so it
        # is correct both under autocast and in plain fp32 execution.
        out = drums.new_zeros((x_seq.size(0), self.d))
        out[is_drum] = drums
        out[not_drum] = non_drums
        # [n_nodes x d]

        x_graph.x = out
        # Unique bar id across the batch, used as pooling index below.
        x_graph.distinct_bars = x_graph.bars + self.n_bars*x_graph.batch
        out = self.graph_encoder(x_graph)
        # [n_nodes x d]

        # Attention pooling is explicitly run outside of autocast.
        with torch.cuda.amp.autocast(enabled=False):
            out = self.graph_attention(out,
                                       batch=x_graph.distinct_bars)

        out = out.view(-1, self.n_bars * self.d)
        # [bs x n_bars * d]
        out_attr = self.bars_encoder_attr(out)
        # [bs x d]

        # Process structure
        out = self.cnn_encoder(x_acts.view(-1, self.n_tracks,
                                           self.resolution*4))
        # [bs * n_bars x d]
        out = out.view(-1, self.n_bars * self.d)
        # [bs x n_bars * d]
        out_struct = self.bars_encoder_struct(out)
        # [bs x d]

        # Merge attr state and struct state
        out = torch.cat((out_attr, out_struct), dim=1)
        out = self.dropout_layer(out)
        out = self.linear_merge(out)
        out = self.bn_lm(out)
        out = F.relu(out)

        # Compute mu and log(std^2)
        out = self.dropout_layer(out)
        mu = self.linear_mu(out)
        log_var = self.linear_log_var(out)

        return mu, log_var


class Decoder(nn.Module):
    """VAE decoder: maps a latent vector to token logits and an activation grid.

    The latent is expanded to ``2*d`` features and split in two halves:

    * structure half -> per-bar vectors -> ``CNNDecoder`` -> activation grid;
    * attribute half -> per-bar vectors, copied onto the nodes of each bar,
      propagated with a ``GraphEncoder`` and decoded into pitch and duration
      logits (separate pitch heads for drums and non-drums, shared duration
      head).

    All hyper-parameters come in as keyword arguments and are stored as
    attributes. The code reads: ``d``, ``dropout``, ``n_bars``,
    ``batch_norm``, ``gnn_n_layers``, ``n_relations``, ``max_simu_notes``,
    ``d_token_pitches`` and ``d_token_dur``.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # Expose every hyper-parameter passed by the caller as an attribute.
        self.__dict__.update(kwargs)

        self.dropout_layer = nn.Dropout(p=self.dropout)

        self.lin_divide = nn.Linear(self.d, 2 * self.d)
        self.bn_ld = nn.BatchNorm1d(num_features=2*self.d)

        self.bars_decoder_attr = nn.Linear(self.d, self.d * self.n_bars)
        self.bars_decoder_struct = nn.Linear(self.d, self.d * self.n_bars)

        self.cnn_decoder = CNNDecoder(input_dim=self.d,
                                      dense_dim=self.d,
                                      dropout=0,
                                      batch_norm=self.batch_norm)

        self.graph_decoder = GraphEncoder(dropout=self.dropout,
                                          input_dim=self.d,
                                          hidden_dim=self.d,
                                          n_layers=self.gnn_n_layers,
                                          num_relations=self.n_relations,
                                          batch_norm=self.batch_norm)

        # Expands one node vector into (max_simu_notes - 1) token vectors.
        self.chord_decoder = nn.Linear(self.d, self.d*(self.max_simu_notes-1))

        # Pitch and dur linear layers
        self.drums_pitch_emb = nn.Linear(self.d//2, self.d_token_pitches)
        self.notes_pitch_emb = nn.Linear(self.d//2, self.d_token_pitches)
        self.dur_emb = nn.Linear(self.d//2, self.d_token_dur)

    def forward(self, z, x_seq, x_acts, x_graph, src_mask, tgt_mask,
                inference=False):
        """Decode ``z`` into token logits and activation logits.

        Args:
            z: latent batch, [bs x d].
            x_seq: target token tensor; only its size is used, to shape the
                returned logits.
            x_acts: activation grid; only its size is used, to shape the
                returned structure output.
            x_graph: graph batch providing ``is_drum``, ``distinct_bars`` and
                the edge data used by ``GraphEncoder``. ``distinct_bars``
                must already be set (``Encoder.forward`` sets it); ``x`` is
                overwritten here.
            src_mask, tgt_mask, inference: accepted for interface
                compatibility; not used.

        Returns:
            ``(out, out_struct)``: token logits shaped like ``x_seq`` (softmax
            is left to the caller) and the structure output shaped like
            ``x_acts``.
        """
        # Obtain z_structure and z_attributes from z
        z = self.lin_divide(z)
        z = self.bn_ld(z)
        z = F.relu(z)
        z = self.dropout_layer(z)
        # [bs x 2*d]

        # Decode structure
        out_struct = self.bars_decoder_struct(z[:, :self.d])
        # [bs x n_bars * d]
        out_struct = self.cnn_decoder(out_struct.reshape(-1, self.d))
        out_struct = out_struct.view(x_acts.size())

        # Decode attributes
        out = self.bars_decoder_attr(z[:, self.d:])
        # [bs x n_bars * d]

        # Initialize node features with corresponding z_bar
        # and propagate with GNN.
        # NOTE(review): this assumes nodes are grouped by ascending
        # distinct_bars and that every bar owns at least one node -- confirm
        # against the dataset code.
        _, counts = torch.unique(x_graph.distinct_bars, return_counts=True)
        out = out.view(-1, self.d)
        out = torch.repeat_interleave(out, counts, dim=0)
        # [n_nodes x d]

        x_graph.x = out
        out = self.graph_decoder(x_graph)
        # [n_nodes x d]

        out = self.chord_decoder(out)
        # [n_nodes x (max_simu_notes-1) * d]
        out = out.view(-1, self.max_simu_notes-1, self.d)

        is_drum = x_graph.is_drum
        not_drum = torch.logical_not(is_drum)
        drums = out[is_drum]
        non_drums = out[not_drum]
        # [n_nodes(dr/non_dr) x max_simu_notes-1 x d]

        # Obtain final pitch and dur decodings
        # (softmax to be applied outside the model).
        # Dropout is applied to non-drums first, then drums, to keep the
        # random stream order stable.
        non_drums = self.dropout_layer(non_drums)
        drums = self.dropout_layer(drums)

        half = self.d//2
        drums_pitch = self.drums_pitch_emb(drums[..., :half])
        drums_dur = self.dur_emb(drums[..., half:])
        drums = torch.cat((drums_pitch, drums_dur), dim=-1)
        # [n_nodes(dr) x max_simu_notes-1 x d_token]

        non_drums_pitch = self.notes_pitch_emb(non_drums[..., :half])
        non_drums_dur = self.dur_emb(non_drums[..., half:])
        non_drums = torch.cat((non_drums_pitch, non_drums_dur), dim=-1)
        # [n_nodes(non_dr) x max_simu_notes-1 x d_token]

        # Merge drums and non-drums. The buffer takes dtype and device from
        # the logits themselves instead of a hard-coded torch.half, so it is
        # correct both under autocast and in plain fp32 execution.
        out = drums.new_zeros((x_seq.size(0), x_seq.size(1), x_seq.size(2)))
        out[is_drum] = drums
        out[not_drum] = non_drums

        out = out.view(x_seq.size())

        return out, out_struct


class VAE(nn.Module):
    """Variational autoencoder wrapping an ``Encoder`` and a ``Decoder``.

    Both sub-modules receive the same keyword arguments.
    """

    def __init__(self, **kwargs):
        super().__init__()

        self.encoder = Encoder(**kwargs)
        self.decoder = Decoder(**kwargs)

    def forward(self, x_seq, x_acts, x_graph, src_mask, tgt_mask,
                inference=False):
        """Encode, sample a latent with the reparameterization trick, decode.

        Returns:
            ``(out, mu, log_var)`` where ``out`` is whatever the decoder
            returns.
        """
        # Encoder pass
        mu, log_var = self.encoder(x_seq, x_acts, x_graph, src_mask)

        # Reparameterization trick: z = mu + std * eps, eps ~ N(0, I)
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        z = std * eps + mu

        # The decoder gets the sequence and the mask without their last
        # position.
        tgt = x_seq[..., :-1, :]
        shifted_mask = src_mask[:, :-1]

        # Decoder pass
        out = self.decoder(z, tgt, x_acts, x_graph, shifted_mask, tgt_mask,
                           inference=inference)

        return out, mu, log_var


Trainer

In [17]:
import torch.optim as optim
import matplotlib.pyplot as plt
import uuid
import copy
import time
from statistics import mean
from collections import defaultdict
import math


def generate_square_subsequent_mask(sz):
    """Return a ``sz x sz`` mask with -inf strictly above the diagonal and 0 elsewhere."""
    blocked = torch.full((sz, sz), float('-inf'))
    return torch.triu(blocked, diagonal=1)


def append_dict(dest_d, source_d):
    """Append every value of ``source_d`` to the list stored under the same key in ``dest_d``.

    ``dest_d`` must already map each key to a list or create one on access
    (e.g. ``defaultdict(list)``).
    """
    for key in source_d:
        dest_d[key].append(source_d[key])


def sigmoid(x, m=1, a=1, z=0):
    """Scaled and shifted logistic function ``m / (1 + exp(-a * (x - z)))``.

    Args:
        x: point at which to evaluate.
        m: upper asymptote (maximum value).
        a: steepness.
        z: location of the midpoint, where the result is ``m / 2``.

    Returns:
        The function value as a float. Arguments far on the low side, for
        which ``exp`` would overflow, yield the limit ``m / inf`` instead of
        raising ``OverflowError``.
    """
    try:
        return m / (1 + math.exp(-a*(x-z)))
    except OverflowError:
        # exp() exceeded the float range: the denominator is effectively +inf.
        return m / math.inf


class VAETrainer():
    
    def __init__(self, model_dir, checkpoint=False, model=None, optimizer=None,
                 init_lr=1e-4, lr_scheduler=None, device=torch.device("cuda"), 
                 print_every=1, save_every=1, eval_every=100, iters_to_accumulate=1,
                 **kwargs):
        
        self.__dict__.update(kwargs)
        
        self.model_dir = model_dir
        self.device = device
        self.print_every = print_every
        self.save_every = save_every
        self.eval_every = eval_every
        self.iters_to_accumulate = iters_to_accumulate
        
        # Criteria with ignored padding
        self.bce_unreduced = nn.BCEWithLogitsLoss(reduction='none')
        self.ce_p = nn.CrossEntropyLoss(ignore_index=130)
        self.ce_d = nn.CrossEntropyLoss(ignore_index=98)
        
        # Training stats
        self.tr_losses = defaultdict(list)
        self.tr_accuracies = defaultdict(list)
        self.val_losses = defaultdict(list)
        self.val_accuracies = defaultdict(list)
        self.lrs = []
        self.betas = []
        self.times = []
        
        self.model = model
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.lr_scheduler = lr_scheduler
        
        self.tot_batches = 0
        self.beta = 0
        self.min_val_loss = np.inf
        
        if checkpoint:
            self.load_checkpoint()
        
    
    def train(self, trainloader, validloader=None, epochs=1,
              early_exit=None):
        """Run the mixed-precision training loop with gradient accumulation.

        Args:
            trainloader: iterable of batches; each batch must support
                ``.to(device)`` and expose ``x_seq``, ``x_acts``, ``src_mask``
                and ``is_drum``.
            validloader: optional loader passed to ``self.evaluate`` every
                ``self.eval_every`` batches; the best model (lowest ``'tot'``
                validation loss) is saved as ``'best_model'``.
            epochs: number of passes over ``trainloader``.
            early_exit: if set, training stops for good as soon as more than
                ``early_exit`` batches have been processed.

        A final ``'checkpoint'`` is always saved when the loop ends.
        """
        self.model.train()

        print("Starting training.\n")

        # `start` is needed for the final timing report, so it is always
        # bound; it is only recorded in the history on a fresh run (when
        # resuming, `self.times` already holds the original start).
        start = time.time()
        if not self.times:
            self.times.append(start)

        progress_bar = tqdm(range(len(trainloader)))
        scaler = torch.cuda.amp.GradScaler()

        # Zero out the gradients
        self.optimizer.zero_grad()

        # Set when `early_exit` is hit so that the epoch loop stops as well.
        stop_training = False

        for epoch in range(epochs):

            self.cur_epoch = epoch

            for batch_idx, batch in enumerate(trainloader):

                self.cur_batch_idx = batch_idx

                # Get the inputs
                x_graph = batch.to(self.device)
                x_seq, x_acts, src_mask = x_graph.x_seq, x_graph.x_acts, x_graph.src_mask
                tgt_mask = generate_square_subsequent_mask(x_seq.size(-2)-1).to(self.device)

                inputs = (x_seq, x_acts, x_graph)

                with torch.cuda.amp.autocast():
                    # Forward pass, get the reconstructions
                    outputs, mu, log_var = self.model(x_seq, x_acts, x_graph, src_mask, tgt_mask)

                    # Compute the backprop loss and other required losses
                    tot_loss, losses = self._compute_losses(inputs, outputs, mu,
                                                            log_var)
                    # Scale down so accumulated gradients average over the
                    # accumulation window.
                    tot_loss = tot_loss / self.iters_to_accumulate

                # Drop the local references (the tensors stay reachable
                # through `inputs` / `x_graph`).
                del x_seq
                del x_acts
                del src_mask
                del tgt_mask

                # Backprop
                scaler.scale(tot_loss).backward()

                # Step the optimizer only once per accumulation window.
                if (self.tot_batches + 1) % self.iters_to_accumulate == 0:
                    scaler.step(self.optimizer)
                    scaler.update()
                    self.optimizer.zero_grad()
                    if self.lr_scheduler is not None:
                        self.lr_scheduler.step()
                    if self.beta_update:
                        self._update_beta()

                # Compute accuracies
                accs = self._compute_accuracies(inputs, outputs, x_graph.is_drum)

                # Update the stats
                append_dict(self.tr_losses, losses)
                # NOTE(review): assumes the scheduler object exposes an `lr`
                # attribute -- confirm for the scheduler actually used.
                last_lr = (self.lr_scheduler.lr
                           if self.lr_scheduler is not None else self.init_lr)
                self.lrs.append(last_lr)
                self.betas.append(self.beta)
                append_dict(self.tr_accuracies, accs)
                now = time.time()
                self.times.append(now)

                # Print stats
                if (self.tot_batches + 1) % self.print_every == 0:
                    print("Training on batch {}/{} of epoch {}/{} complete."
                          .format(batch_idx+1, len(trainloader), epoch+1, epochs))
                    self._print_stats()
                    print("\n----------------------------------------\n")

                # ------------------------------------
                # EVAL ON VL SET EVERY N GRADIENT UPDATES
                # ------------------------------------

                if validloader is not None and (self.tot_batches + 1) % self.eval_every == 0:

                    # Evaluate on val set
                    print("\nEvaluating on validation set...\n")
                    val_losses, val_accuracies = self.evaluate(validloader)

                    # Update stats
                    append_dict(self.val_losses, val_losses)
                    append_dict(self.val_accuracies, val_accuracies)

                    print("Val losses:")
                    print(val_losses)
                    print("Val accuracies:")
                    print(val_accuracies)

                    # Save model if val loss (tot) reached a new minimum
                    val_tot_loss = val_losses['tot']
                    if val_tot_loss < self.min_val_loss:
                        print("\nValidation loss improved.")
                        print("Saving new best model to disk...\n")
                        self._save_model('best_model')
                        self.min_val_loss = val_tot_loss

                    # evaluate() leaves the model in eval mode.
                    self.model.train()

                progress_bar.update(1)

                # When appropriate, save model and stats on disk
                if self.save_every > 0 and (self.tot_batches + 1) % self.save_every == 0:
                    print("\nSaving model to disk...\n")
                    self._save_model('checkpoint')

                # Stop prematurely if early_exit is set and reached
                if early_exit is not None and (self.tot_batches + 1) > early_exit:
                    stop_training = True
                    break

                self.tot_batches += 1

            # Leave the epoch loop too; otherwise every remaining epoch
            # would still train on one more batch.
            if stop_training:
                break

        end = time.time()
        # Todo: self.__print_time()
        hours, rem = divmod(end-start, 3600)
        minutes, seconds = divmod(rem, 60)
        print("Training completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
                  .format(int(hours),int(minutes),seconds))

        print("Saving model to disk...")
        self._save_model('checkpoint')

        print("Model saved.")
        
    
    def _update_beta(self):
        
        # Number of gradient updates
        i = self.tot_batches
        
        if i < self.anneal_start or i >= self.anneal_end:
            return
        
        n_steps = self.beta_max // self.step_size
        inc_every = (self.anneal_end - self.anneal_start) // n_steps
        
        curr_step = (i - self.anneal_start) // inc_every
        self.beta = self.step_size * (curr_step + 1)
    
        
    def _update_beta2(self):
        
        steps = self.tot_batches
        
        if steps < self.anneal_after:
            self.beta = 0
        else:
            
            # Compute steps in current cycle
            curr_cycle = (steps - self.anneal_after) // self.anneal_steps
            
            if curr_cycle >= self.n_cycles:
                self.beta = self.beta_max
                return
            
            steps_in_cycle = steps - self.anneal_after - curr_cycle * self.anneal_steps
            
            # Decide stage in cycle (sigmoidal increase or beta set to zero)
            if steps_in_cycle / self.anneal_steps < self.inc_to_zero_ratio:
                
                # Compute values to correctly shift and scale sigmoid
                sig_zero = self.anneal_steps * self.inc_to_zero_ratio / 2
                inc_steps = self.anneal_steps * self.inc_to_zero_ratio
                scale = 1 / ((inc_steps / 2) / self.sig_scaled_point)
                
                self.beta = sigmoid(steps_in_cycle, m=self.beta_max, a=scale, z=sig_zero)
                
            else:
                self.beta = 0
        
    
    def evaluate(self, loader):
        """Evaluate the model on every batch of ``loader`` and average the stats.

        The model is switched to evaluation mode and gradients are disabled.
        For each batch, the forward pass and the losses are computed under
        ``torch.cuda.amp.autocast()``; the accuracies are computed outside it.

        NOTE: the model is left in evaluation mode when this method returns.

        Args:
            loader: sized iterable yielding batch objects that support
                ``.to(device)`` and expose ``x_seq``, ``x_acts``, ``src_mask``
                and ``is_drum``.

        Returns:
            Tuple ``(avg_losses, avg_accs)`` of dicts mapping every key
            returned by ``_compute_losses`` / ``_compute_accuracies`` to its
            mean over the batches.
        """
        losses = defaultdict(list)
        accs = defaultdict(list)

        self.model.eval()

        # Using the progress bar as a context manager guarantees it is closed,
        # even when a batch raises.
        with torch.no_grad(), tqdm(total=len(loader)) as progress_bar:
            for batch in loader:

                # Get the inputs and move them to device
                x_graph = batch.to(self.device)
                x_seq, x_acts, src_mask = x_graph.x_seq, x_graph.x_acts, x_graph.src_mask
                # The target mask is one position shorter than the sequences
                tgt_mask = generate_square_subsequent_mask(x_seq.size(-2)-1).to(self.device)
                model_inputs = (x_seq, x_acts, x_graph)

                with torch.cuda.amp.autocast():
                    # Forward pass, get the reconstructions
                    outputs, mu, log_var = self.model(x_seq, x_acts, x_graph, src_mask, tgt_mask)

                    # Compute losses wrt batch
                    _, losses_b = self._compute_losses(model_inputs, outputs, mu,
                                                       log_var)

                # Compute accuracies wrt batch
                accs_b = self._compute_accuracies(model_inputs, outputs, x_graph.is_drum)

                # Save losses and accuracies
                append_dict(losses, losses_b)
                append_dict(accs, accs_b)

                progress_bar.update(1)

        # Compute avg losses and accuracies
        avg_losses = {k: mean(v) for k, v in losses.items()}
        avg_accs = {k: mean(v) for k, v in accs.items()}

        return avg_losses, avg_accs
                
        
    
    def _compute_losses(self, inputs, outputs, mu, log_var):
        
        x_seq, x_acts, _ = inputs
        seq_rec, acts_rec = outputs
        
        # Shift outputs for transformer decoder loss and filter silences
        x_seq = x_seq[..., 1:, :]
        #x_seq = x_seq[x_acts.bool()]
        #print(x_seq.size())
        #print(seq_rec.size())
                
        # Compute the losses
        acts_loss = self.bce_unreduced(acts_rec.view(-1), x_acts.view(-1).float())
        #weights = torch.zeros(acts_loss.size()).to(device)
        #weights[x_acts.view(-1) == 1] = 0.9
        #weights[x_acts.view(-1) == 0] = 0.1
        #acts_loss = torch.mean(weights*acts_loss)
        acts_loss = torch.mean(acts_loss)
        
        pitches_loss = self.ce_p(seq_rec.reshape(-1, seq_rec.size(-1))[:, :131],
                          x_seq.reshape(-1, x_seq.size(-1))[:, :131].argmax(dim=1))
        dur_loss = self.ce_d(seq_rec.reshape(-1, seq_rec.size(-1))[:, 131:],
                          x_seq.reshape(-1, x_seq.size(-1))[:, 131:].argmax(dim=1))
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
        kld_loss = torch.mean(kld_loss)
        rec_loss = pitches_loss + dur_loss + acts_loss
        #rec_loss = acts_loss
        tot_loss = rec_loss + self.beta*kld_loss
        
        losses = {
            'tot': tot_loss.item(),
            'pitches': pitches_loss.item(),
            'dur': dur_loss.item(),
            'acts': acts_loss.item(),
            'rec': rec_loss.item(),
            'kld': kld_loss.item(),
            'beta*kld': self.beta*kld_loss.item()
        }
        
        return tot_loss, losses
            
            
    def _compute_accuracies(self, inputs, outputs, is_drum):
        
        x_seq, x_acts, _ = inputs
        seq_rec, acts_rec = outputs
        
        # Shift outputs and filter silences
        x_seq = x_seq[..., 1:, :]
        #x_seq = x_seq[x_acts.bool()]
        #print(x_seq.size())
        #print(seq_rec.size())
        
        notes_acc = self._note_accuracy(seq_rec, x_seq)
        pitches_acc = self._pitches_accuracy(seq_rec, x_seq)
        pitches_acc_drums = self._pitches_accuracy(seq_rec, x_seq, 
                                                   is_drum, drums=True)
        pitches_acc_non_drums = self._pitches_accuracy(seq_rec, x_seq,
                                                       is_drum, drums=False)
        dur_acc = self._dur_accuracy(seq_rec, x_seq)
        acts_acc = self._acts_accuracy(acts_rec, x_acts)
        acts_precision = self._acts_precision(acts_rec, x_acts)
        acts_recall = self._acts_recall(acts_rec, x_acts)
        acts_f1 = (2 * acts_recall * acts_precision / 
                       (acts_recall + acts_precision))
        
        accs = {
            'notes': notes_acc.item(),
            'pitches': pitches_acc.item(),
            'pitches_drums': pitches_acc_drums.item(),
            'pitches_non_drums': pitches_acc_non_drums.item(),
            'dur': dur_acc.item(),
            'acts_acc': acts_acc.item(),
            'acts_precision': acts_precision.item(),
            'acts_recall': acts_recall.item(),
            'acts_f1': acts_f1.item()
        }
        
        return accs
    
    
    def _note_accuracy(self, seq_rec, x_seq):
        
        pitches_rec = F.softmax(seq_rec[..., :131], dim=-1)
        pitches_rec = torch.argmax(pitches_rec, dim=-1)
        pitches_true = torch.argmax(x_seq[..., :131], dim=-1)
        
        #print(torch.all(pitches_rec == 129))
        #print(pitches_rec)
        
        mask_p = (pitches_true != 130)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_pitches = (pitches_rec == pitches_true)
        preds_pitches = torch.logical_and(preds_pitches, mask_p)
        
        
        dur_rec = F.softmax(seq_rec[..., 131:], dim=-1)
        dur_rec = torch.argmax(dur_rec, dim=-1)
        dur_true = torch.argmax(x_seq[..., 131:], dim=-1)
        
        #print(torch.all(dur_rec == 97))
        
        mask_d = (dur_true != 98)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_dur = (dur_rec == dur_true)
        preds_dur = torch.logical_and(preds_dur, mask_d)
        
        return torch.sum(torch.logical_and(preds_pitches, 
                                           preds_dur)) / torch.sum(mask_p)
    
    
    def _acts_precision(self, acts_rec, x_acts):
        
        acts_rec = torch.sigmoid(acts_rec)
        acts_rec[acts_rec < 0.5] = 0
        acts_rec[acts_rec >= 0.5] = 1
        
        tp = torch.sum(x_acts[acts_rec == 1])
        
        return tp / torch.sum(acts_rec)
    
    
    def _acts_recall(self, acts_rec, x_acts):
        
        acts_rec = torch.sigmoid(acts_rec)
        acts_rec[acts_rec < 0.5] = 0
        acts_rec[acts_rec >= 0.5] = 1
        
        tp = torch.sum(x_acts[acts_rec == 1])
        
        return tp / torch.sum(x_acts)
    
    
    def _acts_accuracy(self, acts_rec, x_acts):
        
        acts_rec = torch.sigmoid(acts_rec)
        acts_rec[acts_rec < 0.5] = 0
        acts_rec[acts_rec >= 0.5] = 1
        
        #print("All zero acts?", torch.all(acts_rec == 0))
        #print("All one acts?", torch.all(acts_rec == 0))
        
        return torch.sum(acts_rec == x_acts) / x_acts.numel()
    
    
    def _pitches_accuracy(self, seq_rec, x_seq, is_drum=None, drums=None):
        
        if drums is not None:
            if drums:
                seq_rec = seq_rec[is_drum]
                x_seq = x_seq[is_drum]
            else:
                seq_rec = seq_rec[torch.logical_not(is_drum)]
                x_seq = x_seq[torch.logical_not(is_drum)]
        
        pitches_rec = F.softmax(seq_rec[..., :131], dim=-1)
        pitches_rec = torch.argmax(pitches_rec, dim=-1)
        pitches_true = torch.argmax(x_seq[..., :131], dim=-1)
        
        #print("All EOS pitches?", torch.all(pitches_rec == 129))
        
        mask = (pitches_true != 130)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_pitches = (pitches_rec == pitches_true)
        preds_pitches = torch.logical_and(preds_pitches, mask)
        
        return torch.sum(preds_pitches) / torch.sum(mask)
    
    
    def _dur_accuracy(self, seq_rec, x_seq):
        
        dur_rec = F.softmax(seq_rec[..., 131:], dim=-1)
        dur_rec = torch.argmax(dur_rec, dim=-1)
        dur_true = torch.argmax(x_seq[..., 131:], dim=-1)
        
        #print("All EOS durs?", torch.all(dur_rec == 97))
        
        mask = (dur_true != 98)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_dur = (dur_rec == dur_true)
        preds_dur = torch.logical_and(preds_dur, mask)
        
        return torch.sum(preds_dur) / torch.sum(mask)
    
    
    def _save_model(self, filename):
        path = os.path.join(self.model_dir, filename)
        torch.save({
            'epoch': self.cur_epoch,
            'batch': self.cur_batch_idx,
            'tot_batches': self.tot_batches,
            'betas': self.betas,
            'min_val_loss': self.min_val_loss,
            'print_every': self.print_every,
            'save_every': self.save_every,
            'eval_every': self.eval_every,
            'lrs': self.lrs,
            'tr_losses': self.tr_losses,
            'tr_accuracies': self.tr_accuracies,
            'val_losses': self.val_losses,
            'val_accuracies': self.val_accuracies,
            #'scheduler_state_dict': self.lr_scheduler.state_dict(), # Todo: fix
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, path)
        
    
    def load(self):
        """Restore the training bookkeeping from ``<self.model_dir>/checkpoint``.

        Counters, logging intervals and collected statistics are read from
        the checkpoint. Checkpoints that lack the ``'beta'`` / ``'times'``
        entries are supported as well.

        NOTE: the ``'model_state_dict'`` and ``'optimizer_state_dict'``
        entries are not loaded here.
        """
        checkpoint = torch.load(os.path.join(self.model_dir, 'checkpoint'))

        self.cur_epoch = checkpoint['epoch']
        self.cur_batch_idx = checkpoint['batch']
        self.tot_batches = checkpoint['tot_batches']
        self.save_every = checkpoint['save_every']
        self.eval_every = checkpoint['eval_every']
        if 'print_every' in checkpoint:
            self.print_every = checkpoint['print_every']
        self.lrs = checkpoint['lrs']
        self.tr_losses = checkpoint['tr_losses']
        self.tr_accuracies = checkpoint['tr_accuracies']
        self.val_losses = checkpoint['val_losses']
        self.val_accuracies = checkpoint['val_accuracies']
        self.min_val_loss = checkpoint['min_val_loss']

        betas = checkpoint.get('betas')
        if betas is not None:
            self.betas = betas

        # Indexing 'beta' / 'times' unconditionally raised a KeyError on
        # checkpoints that only store 'betas'.
        if 'beta' in checkpoint:
            self.beta = checkpoint['beta']
        elif isinstance(betas, list) and betas:
            # presumably 'betas' is the per-update history of beta, so its
            # last entry is the most recent value - TODO confirm
            self.beta = betas[-1]

        if 'times' in checkpoint:
            self.times = checkpoint['times']
        
        
    def _print_stats(self):
        
        hours, rem = divmod(self.times[-1]-self.times[0], 3600)
        minutes, seconds = divmod(rem, 60)
        print("Elapsed time from start (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
                  .format(int(hours), int(minutes), seconds))
        
        avg_lr = mean(self.lrs[-self.print_every:])
        
        # Take mean of the last non-printed batches for each stat
        
        avg_losses = {}
        for k, l in self.tr_losses.items():
            avg_losses[k] = mean(l[-self.print_every:])
        
        avg_accs = {}
        for k, l in self.tr_accuracies.items():
            avg_accs[k] = mean(l[-self.print_every:])
        
        print("Losses:")
        print(avg_losses)
        print("Accuracies:")
        print(avg_accs)
        


Training

In [18]:
from torch.utils.data import Subset
from torch.utils.data import random_split

# Data loading settings
# (alternative values kept for reference: bs = 8, nw = 4, n_bars = 16)
bs = 32       # batch size of both loaders
nw = 4        # number of DataLoader workers
n_bars = 16   # forwarded to MIDIDataset

# Directory of the preprocessed dataset
# (alternative kept for reference: 'data/preprocessed')
ds_dir = "/data/cosenza/datasets/preprocessed_16bars/"

dataset = MIDIDataset(ds_dir, n_bars=n_bars)
ds_len = len(dataset)

print('Dataset len:', ds_len)

# 70% training, 10% validation, the remaining samples for test
train_len = int(0.7 * ds_len)
valid_len = int(0.1 * ds_len)
test_len = ds_len - train_len - valid_len
tr_set, vl_set, ts_set = random_split(dataset, (train_len, valid_len, test_len))

# Only the training loader shuffles; no loader is built for the test split here
trainloader = DataLoader(tr_set, batch_size=bs, shuffle=True, num_workers=nw)
validloader = DataLoader(vl_set, batch_size=bs, shuffle=False, num_workers=nw)
#testloader = DataLoader(ts_set, batch_size=64, shuffle=False, num_workers=4)

tr_len, vl_len, ts_len = len(tr_set), len(vl_set), len(ts_set)

print('TR set len:', tr_len)
print('VL set len:', vl_len)
print('TS set len:', ts_len)

# Quick experiments on a small subset:
#n_samples = 128
#subset = Subset(dataset, np.arange(n_samples))
#loader = DataLoader(subset, batch_size=64, shuffle=True)

Dataset len: 2842739
TR set len: 1989917
VL set len: 284273
TS set len: 568549


In [19]:
tr_set[0]

Data(edge_index=[2, 1432], edge_attrs=[1432, 33], num_nodes=296, node_features=[296, 4], is_drum=[296], batch=[296], ptr=[17], bars=[296], x_seq=[296, 16, 230], x_acts=[16, 4, 32], src_mask=[296, 16])

In [20]:
import torch

# Index of the GPU to use when CUDA is available
GPU_IDX = 1

# Select the GPU only when it exists: calling torch.cuda.set_device() or
# torch.cuda.current_device() unconditionally fails on machines without CUDA
# (or with fewer GPUs), which made the CPU fallback below unreachable.
if torch.cuda.is_available():
    if GPU_IDX < torch.cuda.device_count():
        torch.cuda.set_device(GPU_IDX)
    else:
        print("GPU {} not found, using the default CUDA device".format(GPU_IDX))
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
#device = torch.device("cpu")

print("Device:", device)
if device.type == "cuda":
    print("Current device idx:", torch.cuda.current_device())

Device: cuda
Current device idx: 1


In [21]:
#!rm models/vae

In [22]:
from prettytable import PrettyTable


def print_params(model):
    """Print a table with the size of every trainable parameter of ``model``.

    Parameters with ``requires_grad == False`` are skipped. The table is
    followed by a line with the total count.

    Returns:
        The total number of trainable parameter elements.
    """
    # (name, number of elements) of each trainable parameter, in model order
    trainable = [(name, parameter.numel())
                 for name, parameter in model.named_parameters()
                 if parameter.requires_grad]

    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])

    total_params = sum(count for _, count in trainable)

    print(table)
    print(f"Total Trainable Params: {total_params}")

    return total_params

In [23]:
import math
import torch
from typing import Optional
from torch.optim import Optimizer

from lr_scheduler.lr_scheduler import LearningRateScheduler

class TransformerLRScheduler(LearningRateScheduler):
    r"""
    Transformer Learning Rate Scheduler proposed in "Attention Is All You Need",
    with the linear warmup disabled.

    ``update_steps`` counts the calls to :meth:`step`. The schedule is:

    * stage 0 (``update_steps < warmup_steps``): the rate is held at
      ``peak_lr`` (no linear warmup);
    * stage 1 (the following ``decay_steps`` updates): exponential decay from
      ``peak_lr`` towards ``peak_lr * final_lr_scale``;
    * stage 2 (afterwards): the rate is held at ``peak_lr * final_lr_scale``.

    Args:
        optimizer (Optimizer): Optimizer.
        peak_lr (float): Maximum learning rate.
        final_lr_scale (float): Final learning rate scale
        warmup_steps (int): Number of initial updates at ``peak_lr``
        decay_steps (int): Steps in decay stage
        init_lr: Initial learning rate passed to the base scheduler
    """
    def __init__(
            self,
            optimizer: Optimizer,
            peak_lr: float,
            final_lr_scale: float,
            warmup_steps: int,
            decay_steps: int,
            init_lr = 0
    ) -> None:
        assert isinstance(warmup_steps, int), "warmup_steps should be integer type"
        assert isinstance(decay_steps, int), "decay_steps should be integer type"

        super(TransformerLRScheduler, self).__init__(optimizer, init_lr)
        self.peak_lr = peak_lr
        self.final_lr_scale = final_lr_scale
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps

        # Rate reached at the end of the decay stage and held afterwards
        self.final_lr = self.peak_lr * final_lr_scale

        # Slope of the (currently disabled) linear warmup; guarded so that
        # warmup_steps == 0 does not raise a ZeroDivisionError
        self.warmup_rate = self.peak_lr / self.warmup_steps if self.warmup_steps > 0 else 0.0
        # Chosen so that exp(-decay_factor * decay_steps) == final_lr_scale
        self.decay_factor = (-math.log(final_lr_scale) / self.decay_steps
                             if self.decay_steps > 0 else 0.0)

        self.update_steps = 0

    def _decide_stage(self):
        """Return ``(stage, steps_in_stage)`` for the current ``update_steps``."""
        if self.update_steps < self.warmup_steps:
            return 0, self.update_steps

        steps_after_warmup = self.update_steps - self.warmup_steps
        if steps_after_warmup < self.decay_steps:
            return 1, steps_after_warmup

        # Decay completed
        return 2, None

    def step(self, val_loss: Optional[torch.FloatTensor] = None):
        """Advance the schedule by one update and apply the new rate.

        Args:
            val_loss: unused.

        Returns:
            The learning rate that has been set on the optimizer.
        """
        self.update_steps += 1
        stage, steps_in_stage = self._decide_stage()

        if stage == 0:
            #self.lr = self.update_steps * self.warmup_rate
            self.lr = self.peak_lr
        elif stage == 1:
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        elif stage == 2:
            # Hold the final rate instead of decaying without bound
            self.lr = self.final_lr
        else:
            raise ValueError("Undefined stage")

        self.set_lr(self.optimizer, self.lr)

        return self.lr

In [24]:
# Configuration of the experiment. The training cell unpacks the 'model',
# 'optimizer', 'scheduler' and 'beta_annealing' groups into VAE, optim.Adam,
# TransformerLRScheduler and VAETrainer respectively, and saves the whole
# dict in the model directory with torch.save.
parameters = {
    
    # Values taken from the data loading cell. Not consumed by the visible
    # training code; they are only stored with the saved parameters.
    'training': {
        'batch_size': bs,
        'num_workers': nw,
        'ds_len': ds_len,
        'tr_len': tr_len,
        'vl_len': vl_len,
        'ts_len': ts_len,
    },
    
    # Keyword arguments of VAE
    'model': {
        
        'dropout': 0, # was 0.1
        'batch_norm': True,
        
         
        'gnn_n_layers': 8, # was 3
        'actsnn_n_layers': 2,
        'd': 512,
        
        'rnn_n_layers': 1,
        'k_isgn': 3,
        
        # Token size: d_token = d_token_pitches + d_token_dur (131 + 99)
        'd_token': 230,
        'd_token_pitches': 131,
        'd_token_dur': 99,
        'n_bars': n_bars,
        'n_relations': 6,
        'n_tracks': 4,
        'resolution': 8,
        'max_simu_notes': 16
    },
    
    # Keyword arguments of TransformerLRScheduler
    'scheduler': {
        'peak_lr': 5e-5, # 1e-5 2bars, 5e-6 16bars
        'final_lr_scale': 0.01,
        'warmup_steps': 8000,
        'decay_steps': 800000 # 500000 16bars
    },
    
    # Keyword arguments of optim.Adam.
    # NOTE(review): TransformerLRScheduler.step() sets the optimizer learning
    # rate on every call, so 'lr' presumably only matters before the first
    # scheduler step - confirm.
    'optimizer': {
        'betas': (0.9, 0.98),
        'eps': 1e-09,
        'lr': 5e-6
    },
    
    # Keyword arguments of VAETrainer for the KL weight (beta) schedule
    'beta_annealing': {
        'beta_update': True,
        
        # Disabled alternative: these names match the attributes read by the
        # cyclical schedule (_update_beta2)
        #'anneal_after': 4000,
        #'beta_max': 0.01,
        #'anneal_steps': 500000,
        #'inc_to_zero_ratio': 1,
        #'sig_scaled_point': 7,
        #'n_cycles': 4
        
        # Active settings: these names match the attributes read by the
        # step-wise schedule (_update_beta)
        'anneal_start': 40000,
        'beta_max': 0.01,
        'step_size': 0.001,
        'anneal_end': 500000,
        #'inc_to_zero_ratio': 1,
        #'sig_scaled_point': 7
    }
}

Training from scratch:

In [None]:
import torch_geometric

print("Creating the model and moving it to the specified device...")

# Output directory of this run: <models_dir>/<model_name>
models_dir = 'models/'
model_name = '16barsGNNDEF'
model_dir = os.path.join(models_dir, model_name)
os.makedirs(models_dir, exist_ok=True)
# exist_ok=False: fail instead of reusing the directory of an existing run
os.makedirs(model_dir, exist_ok=False)

# Model, built from the 'model' parameters and moved to the device
vae = VAE(**parameters['model'], device=device).to(device)
#vae = torch_geometric.nn.DataParallel(vae, device_ids=[0, 1, 2])
print_params(vae)
print()

# Optimizer and learning rate scheduler
optimizer = optim.Adam(vae.parameters(), **parameters['optimizer'])
scheduler = TransformerLRScheduler(optimizer=optimizer, **parameters['scheduler'])

# Store the parameters of the run next to its checkpoints
params_path = os.path.join(model_dir, 'params')
torch.save(parameters, params_path)

print('--------------------------------------------------\n')

# Trainer: checkpoint every 100 batches, print stats after every batch,
# evaluate every 62000 batches
trainer = VAETrainer(
    model_dir,
    model=vae,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    save_every=100,
    print_every=1,
    eval_every=62000,
    iters_to_accumulate=1,
    device=device,
    **parameters['beta_annealing']
)

trainer.train(trainloader, validloader=validloader, epochs=100)

Creating the model and moving it to the specified device...
+---------------------------------------------------+------------+
|                      Modules                      | Parameters |
+---------------------------------------------------+------------+
|           encoder.notes_pitch_emb.weight          |   33536    |
|            encoder.notes_pitch_emb.bias           |    256     |
|               encoder.bn_npe.weight               |    256     |
|                encoder.bn_npe.bias                |    256     |
|           encoder.drums_pitch_emb.weight          |   33536    |
|            encoder.drums_pitch_emb.bias           |    256     |
|               encoder.bn_dpe.weight               |    256     |
|                encoder.bn_dpe.bias                |    256     |
|               encoder.dur_emb.weight              |   25344    |
|                encoder.dur_emb.bias               |    256     |
|                encoder.bn_de.weight               |    256     |
| 

  0%|          | 0/62185 [00:00<?, ?it/s]

Training on batch 1/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:06.12
Losses:
{'tot': 11.290946006774902, 'pitches': 5.531617164611816, 'dur': 5.171469211578369, 'acts': 0.587859034538269, 'rec': 11.290946006774902, 'kld': 64.67884826660156, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.0, 'pitches': 0.05178241804242134, 'pitches_drums': 0.10790582001209259, 'pitches_non_drums': 0.006426816340535879, 'dur': 0.021535351872444153, 'acts_acc': 0.7539520263671875, 'acts_precision': 0.16551458835601807, 'acts_recall': 0.11863473802804947, 'acts_f1': 0.13820748031139374}

----------------------------------------

Training on batch 2/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:07.27
Losses:
{'tot': 11.303545951843262, 'pitches': 5.5521464347839355, 'dur': 5.156982898712158, 'acts': 0.594416618347168, 'rec': 11.303545951843262, 'kld': 66.38317108154297, 'beta*kld': 0.0}
Accuracies:
{'notes': 9.053597023012117e-05, 'pitches': 0.03838725388050079, '

Training on batch 14/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:17.96
Losses:
{'tot': 6.263038635253906, 'pitches': 3.4850194454193115, 'dur': 2.2246220111846924, 'acts': 0.5533969402313232, 'rec': 6.263038635253906, 'kld': 67.29110717773438, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3855352997779846, 'pitches': 0.39540621638298035, 'pitches_drums': 0.39824533462524414, 'pitches_non_drums': 0.39348265528678894, 'dur': 0.5556188225746155, 'acts_acc': 0.79766845703125, 'acts_precision': 0.2836134433746338, 'acts_recall': 0.09299922734498978, 'acts_f1': 0.14006873965263367}

----------------------------------------

Training on batch 15/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:18.80
Losses:
{'tot': 5.712218761444092, 'pitches': 3.2338900566101074, 'dur': 1.9235286712646484, 'acts': 0.5548001527786255, 'rec': 5.712218761444092, 'kld': 66.64513397216797, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40136995911598206, 'pitches': 0.4230189025

Training on batch 27/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:30.25
Losses:
{'tot': 5.51398229598999, 'pitches': 3.1725573539733887, 'dur': 1.8105905055999756, 'acts': 0.5308347344398499, 'rec': 5.51398229598999, 'kld': 81.86231994628906, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3400026261806488, 'pitches': 0.38483089208602905, 'pitches_drums': 0.38791584968566895, 'pitches_non_drums': 0.38283073902130127, 'dur': 0.5569912195205688, 'acts_acc': 0.8190155029296875, 'acts_precision': 0.35944700241088867, 'acts_recall': 0.11466372758150101, 'acts_f1': 0.17386458814144135}

----------------------------------------

Training on batch 28/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:30.97
Losses:
{'tot': 5.185186862945557, 'pitches': 2.9866199493408203, 'dur': 1.668620228767395, 'acts': 0.5299468636512756, 'rec': 5.185186862945557, 'kld': 79.27806091308594, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3571232855319977, 'pitches': 0.41075342893

Training on batch 40/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:42.82
Losses:
{'tot': 5.122925758361816, 'pitches': 2.8437442779541016, 'dur': 1.764727234840393, 'acts': 0.5144545435905457, 'rec': 5.122925758361816, 'kld': 71.25537109375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3629480302333832, 'pitches': 0.4190284013748169, 'pitches_drums': 0.4694799482822418, 'pitches_non_drums': 0.3896416425704956, 'dur': 0.5630329251289368, 'acts_acc': 0.8369903564453125, 'acts_precision': 0.507798969745636, 'acts_recall': 0.10934875905513763, 'acts_f1': 0.17994779348373413}

----------------------------------------

Training on batch 41/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:43.62
Losses:
{'tot': 5.0426025390625, 'pitches': 2.9647462368011475, 'dur': 1.5622719526290894, 'acts': 0.5155845880508423, 'rec': 5.0426025390625, 'kld': 72.04402160644531, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.35556650161743164, 'pitches': 0.39434102177619934, '

Training on batch 53/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:55.14
Losses:
{'tot': 4.958611488342285, 'pitches': 2.8062124252319336, 'dur': 1.6366472244262695, 'acts': 0.5157519578933716, 'rec': 4.958611488342285, 'kld': 70.94557189941406, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3587322235107422, 'pitches': 0.39962348341941833, 'pitches_drums': 0.47086790204048157, 'pitches_non_drums': 0.3643661141395569, 'dur': 0.563219428062439, 'acts_acc': 0.83050537109375, 'acts_precision': 0.6808916926383972, 'acts_recall': 0.09155532717704773, 'acts_f1': 0.1614072173833847}

----------------------------------------

Training on batch 54/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:00:55.92
Losses:
{'tot': 4.785670280456543, 'pitches': 2.7546849250793457, 'dur': 1.5168112516403198, 'acts': 0.5141738057136536, 'rec': 4.785670280456543, 'kld': 72.68707275390625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37315765023231506, 'pitches': 0.4096215069293

Training on batch 66/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:07.01
Losses:
{'tot': 4.711211204528809, 'pitches': 2.768904447555542, 'dur': 1.43604576587677, 'acts': 0.5062607526779175, 'rec': 4.711211204528809, 'kld': 74.2129898071289, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3644680082798004, 'pitches': 0.40647760033607483, 'pitches_drums': 0.4250490367412567, 'pitches_non_drums': 0.3928827941417694, 'dur': 0.5900586247444153, 'acts_acc': 0.8248291015625, 'acts_precision': 0.7749820351600647, 'acts_recall': 0.0880359336733818, 'acts_f1': 0.15811088681221008}

----------------------------------------

Training on batch 67/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:07.90
Losses:
{'tot': 4.763686180114746, 'pitches': 2.7526638507843018, 'dur': 1.5193358659744263, 'acts': 0.49168646335601807, 'rec': 4.763686180114746, 'kld': 70.82512664794922, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.356740266084671, 'pitches': 0.40825703740119934, 

Training on batch 79/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:18.43
Losses:
{'tot': 4.490115642547607, 'pitches': 2.6187539100646973, 'dur': 1.3797460794448853, 'acts': 0.4916159212589264, 'rec': 4.490115642547607, 'kld': 76.72782897949219, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.38554108142852783, 'pitches': 0.40888527035713196, 'pitches_drums': 0.41513791680336, 'pitches_non_drums': 0.4038941264152527, 'dur': 0.6037519574165344, 'acts_acc': 0.83416748046875, 'acts_precision': 0.8902891278266907, 'acts_recall': 0.10074657946825027, 'acts_f1': 0.18100979924201965}

----------------------------------------

Training on batch 80/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:19.19
Losses:
{'tot': 4.653721332550049, 'pitches': 2.696575403213501, 'dur': 1.463333249092102, 'acts': 0.4938124418258667, 'rec': 4.653721332550049, 'kld': 76.05809020996094, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.36784160137176514, 'pitches': 0.403667777776718

Training on batch 92/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:29.94
Losses:
{'tot': 4.428382873535156, 'pitches': 2.4771080017089844, 'dur': 1.4702644348144531, 'acts': 0.48101019859313965, 'rec': 4.428382873535156, 'kld': 77.88085174560547, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40425384044647217, 'pitches': 0.43763765692710876, 'pitches_drums': 0.4809112250804901, 'pitches_non_drums': 0.4108635187149048, 'dur': 0.6083080768585205, 'acts_acc': 0.8464508056640625, 'acts_precision': 0.8633986711502075, 'acts_recall': 0.11821029335260391, 'acts_f1': 0.20794962346553802}

----------------------------------------

Training on batch 93/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:30.75
Losses:
{'tot': 4.406085014343262, 'pitches': 2.5984318256378174, 'dur': 1.3179295063018799, 'acts': 0.4897238612174988, 'rec': 4.406085014343262, 'kld': 77.97232055664062, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3862477242946625, 'pitches': 0.413934439

Training on batch 105/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:43.55
Losses:
{'tot': 4.355696201324463, 'pitches': 2.5260233879089355, 'dur': 1.3504998683929443, 'acts': 0.47917309403419495, 'rec': 4.355696201324463, 'kld': 82.68358612060547, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3759024143218994, 'pitches': 0.41479960083961487, 'pitches_drums': 0.43765753507614136, 'pitches_non_drums': 0.3982623517513275, 'dur': 0.6181852221488953, 'acts_acc': 0.8407440185546875, 'acts_precision': 0.8743202686309814, 'acts_recall': 0.12392942607402802, 'acts_f1': 0.217087984085083}

----------------------------------------

Training on batch 106/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:44.33
Losses:
{'tot': 4.401873588562012, 'pitches': 2.574451208114624, 'dur': 1.3467100858688354, 'acts': 0.4807121753692627, 'rec': 4.401873588562012, 'kld': 83.00200653076172, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3781414330005646, 'pitches': 0.4162815809

Training on batch 118/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:55.06
Losses:
{'tot': 4.247522354125977, 'pitches': 2.498579502105713, 'dur': 1.257826566696167, 'acts': 0.49111634492874146, 'rec': 4.247522354125977, 'kld': 85.47300720214844, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4003651440143585, 'pitches': 0.4302346408367157, 'pitches_drums': 0.43813493847846985, 'pitches_non_drums': 0.4250815808773041, 'dur': 0.6421345472335815, 'acts_acc': 0.8254547119140625, 'acts_precision': 0.8510758876800537, 'acts_recall': 0.11854247003793716, 'acts_f1': 0.20809967815876007}

----------------------------------------

Training on batch 119/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:01:55.76
Losses:
{'tot': 4.295773506164551, 'pitches': 2.452768087387085, 'dur': 1.3806132078170776, 'acts': 0.4623924195766449, 'rec': 4.295773506164551, 'kld': 85.13887786865234, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3828653395175934, 'pitches': 0.42030471563

Training on batch 131/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:07.31
Losses:
{'tot': 4.467079162597656, 'pitches': 2.5582973957061768, 'dur': 1.4499300718307495, 'acts': 0.45885181427001953, 'rec': 4.467079162597656, 'kld': 90.69322204589844, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37909597158432007, 'pitches': 0.40642011165618896, 'pitches_drums': 0.443280965089798, 'pitches_non_drums': 0.3854837715625763, 'dur': 0.582165539264679, 'acts_acc': 0.860504150390625, 'acts_precision': 0.8099891543388367, 'acts_recall': 0.14507973194122314, 'acts_f1': 0.24608279764652252}

----------------------------------------

Training on batch 132/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:08.08
Losses:
{'tot': 4.375945091247559, 'pitches': 2.5413424968719482, 'dur': 1.3653662204742432, 'acts': 0.46923643350601196, 'rec': 4.375945091247559, 'kld': 90.951904296875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37525027990341187, 'pitches': 0.4076815843

Training on batch 144/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:18.57
Losses:
{'tot': 4.237133979797363, 'pitches': 2.4642934799194336, 'dur': 1.315338134765625, 'acts': 0.4575025141239166, 'rec': 4.237133979797363, 'kld': 94.89729309082031, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3897474408149719, 'pitches': 0.4143422544002533, 'pitches_drums': 0.43623432517051697, 'pitches_non_drums': 0.39540740847587585, 'dur': 0.6187649369239807, 'acts_acc': 0.847381591796875, 'acts_precision': 0.840873658657074, 'acts_recall': 0.14293290674686432, 'acts_f1': 0.2443336397409439}

----------------------------------------

Training on batch 145/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:19.39
Losses:
{'tot': 4.204680442810059, 'pitches': 2.42799973487854, 'dur': 1.3088374137878418, 'acts': 0.46784311532974243, 'rec': 4.204680442810059, 'kld': 93.1676025390625, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.38286274671554565, 'pitches': 0.4274779558181

Training on batch 157/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:31.23
Losses:
{'tot': 4.110701084136963, 'pitches': 2.3260273933410645, 'dur': 1.3262609243392944, 'acts': 0.45841285586357117, 'rec': 4.110701084136963, 'kld': 93.12812042236328, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.411007821559906, 'pitches': 0.43143436312675476, 'pitches_drums': 0.47153541445732117, 'pitches_non_drums': 0.4022132158279419, 'dur': 0.627271831035614, 'acts_acc': 0.8429718017578125, 'acts_precision': 0.836006224155426, 'acts_recall': 0.13943053781986237, 'acts_f1': 0.23900021612644196}

----------------------------------------

Training on batch 158/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:32.59
Losses:
{'tot': 4.137833118438721, 'pitches': 2.429332733154297, 'dur': 1.2617344856262207, 'acts': 0.4467660188674927, 'rec': 4.137833118438721, 'kld': 93.57264709472656, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4044097363948822, 'pitches': 0.42590278387

Training on batch 170/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:43.48
Losses:
{'tot': 4.228262901306152, 'pitches': 2.4261457920074463, 'dur': 1.3409947156906128, 'acts': 0.4611223340034485, 'rec': 4.228262901306152, 'kld': 97.99620056152344, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.3944041132926941, 'pitches': 0.4225347936153412, 'pitches_drums': 0.4723186194896698, 'pitches_non_drums': 0.3915603458881378, 'dur': 0.6070175170898438, 'acts_acc': 0.8373565673828125, 'acts_precision': 0.8837087750434875, 'acts_recall': 0.13914549350738525, 'acts_f1': 0.2404332458972931}

----------------------------------------

Training on batch 171/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:44.24
Losses:
{'tot': 4.202975273132324, 'pitches': 2.4541592597961426, 'dur': 1.2812376022338867, 'acts': 0.4675785303115845, 'rec': 4.202975273132324, 'kld': 98.72468566894531, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4049144983291626, 'pitches': 0.43161216378

Training on batch 183/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:55.12
Losses:
{'tot': 3.975043296813965, 'pitches': 2.232896566390991, 'dur': 1.2949838638305664, 'acts': 0.44716280698776245, 'rec': 3.975043296813965, 'kld': 101.15029907226562, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42762699723243713, 'pitches': 0.46085596084594727, 'pitches_drums': 0.5206653475761414, 'pitches_non_drums': 0.41929587721824646, 'dur': 0.6432150602340698, 'acts_acc': 0.8508758544921875, 'acts_precision': 0.8735070824623108, 'acts_recall': 0.1443178802728653, 'acts_f1': 0.2477099448442459}

----------------------------------------

Training on batch 184/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:02:55.88
Losses:
{'tot': 4.153001308441162, 'pitches': 2.4360971450805664, 'dur': 1.2632381916046143, 'acts': 0.4536658525466919, 'rec': 4.153001308441162, 'kld': 103.30278015136719, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40665724873542786, 'pitches': 0.432972

Training on batch 196/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:06.67
Losses:
{'tot': 4.145449161529541, 'pitches': 2.454493522644043, 'dur': 1.224639892578125, 'acts': 0.4663156270980835, 'rec': 4.145449161529541, 'kld': 105.91291809082031, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.38964179158210754, 'pitches': 0.424456924200058, 'pitches_drums': 0.45310407876968384, 'pitches_non_drums': 0.4046877920627594, 'dur': 0.6320529580116272, 'acts_acc': 0.8251800537109375, 'acts_precision': 0.8470588326454163, 'acts_recall': 0.12418659031391144, 'acts_f1': 0.21661537885665894}

----------------------------------------

Training on batch 197/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:07.46
Losses:
{'tot': 3.9942286014556885, 'pitches': 2.306154727935791, 'dur': 1.234184980392456, 'acts': 0.45388883352279663, 'rec': 3.9942286014556885, 'kld': 106.42057800292969, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.4161105155944824, 'pitches': 0.44881206

Training on batch 209/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:19.77
Losses:
{'tot': 4.156214714050293, 'pitches': 2.4773168563842773, 'dur': 1.2292664051055908, 'acts': 0.4496312439441681, 'rec': 4.156214714050293, 'kld': 107.06781005859375, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.37765926122665405, 'pitches': 0.4124767780303955, 'pitches_drums': 0.41775187849998474, 'pitches_non_drums': 0.408379465341568, 'dur': 0.6236291527748108, 'acts_acc': 0.8364715576171875, 'acts_precision': 0.8588362336158752, 'acts_recall': 0.13229313492774963, 'acts_f1': 0.22927004098892212}

----------------------------------------

Training on batch 210/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:20.51
Losses:
{'tot': 3.919241428375244, 'pitches': 2.2220458984375, 'dur': 1.2546570301055908, 'acts': 0.44253844022750854, 'rec': 3.919241428375244, 'kld': 108.21607971191406, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.42043906450271606, 'pitches': 0.45244750

Training on batch 222/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:30.63
Losses:
{'tot': 4.068234443664551, 'pitches': 2.3786420822143555, 'dur': 1.2371468544006348, 'acts': 0.45244544744491577, 'rec': 4.068234443664551, 'kld': 112.42438507080078, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.40887153148651123, 'pitches': 0.440479040145874, 'pitches_drums': 0.5185185074806213, 'pitches_non_drums': 0.3893709182739258, 'dur': 0.6480576992034912, 'acts_acc': 0.8328094482421875, 'acts_precision': 0.8208802342414856, 'acts_recall': 0.13135696947574615, 'acts_f1': 0.22647370398044586}

----------------------------------------

Training on batch 223/62185 of epoch 1/100 complete.
Elapsed time from start (h:m:s): 00:03:31.42
Losses:
{'tot': 4.058856964111328, 'pitches': 2.382995843887329, 'dur': 1.2382436990737915, 'acts': 0.43761706352233887, 'rec': 4.058856964111328, 'kld': 110.774169921875, 'beta*kld': 0.0}
Accuracies:
{'notes': 0.38504624366760254, 'pitches': 0.41525584

Training from checkpoint:

In [None]:
# Where is the model? In models/model_name/checkpoint (because we don't need the best one yet)
# What were its parameters? model_state_dict
# What were the scheduler's parameters? scheduler_state_dict
# What were the optimizer's parameters (Adam)? optimizer_state_dict
# --> Ok! Now we have a model, a scheduler and an optimizer.

# What else do we need? A training checkpoint! This should be gathered inside Trainer!
# A training checkpoint is composed of many things such as:
#    last epoch, last batch, ...
# However, we also need a way to store the dataloaders. Why?
# At the beginning we have a full dataset in a directory. Then, we create two (three) subsets,
# corresponding to TR, VL (TS). 

## Reconstructions

In [None]:
# Load the saved training checkpoint. map_location remaps the stored tensors
# onto `device`, so a checkpoint written on a GPU machine can also be restored
# on a CPU-only one.
checkpoint = torch.load('models/just_pitches_warmup', map_location=device)

In [None]:
# Pull the model weights out of the checkpoint dictionary.
state_dict = checkpoint['model_state_dict']
# Fresh model built with default constructor arguments and moved to `device`.
# NOTE(review): presumably VAE() must use the same hyperparameters as the
# checkpointed model for the weights to load — confirm.
vae = VAE().to(device)

In [None]:
# Copy the checkpointed weights into the freshly built model.
vae.load_state_dict(state_dict)

In [None]:
# Unshuffled loader, so the same samples are reconstructed on every run.
loader = DataLoader(dataset, batch_size=32, shuffle=False)
# Cell output: number of samples in the dataset.
len(dataset)

In [None]:
# Reconstruct a single batch with the restored model.
# eval() switches off training-only behaviour (e.g. dropout) and no_grad()
# avoids building an autograd graph that is never used for reconstruction.
vae.eval()

with torch.no_grad():
    # Only the first batch is needed.
    inputs = next(iter(loader))

    x_seq, x_acts, x_graph, src_mask = inputs
    x_seq = x_seq.float().to(device)
    x_acts = x_acts.to(device)
    x_graph = x_graph.to(device)
    src_mask = src_mask.to(device)
    # Target mask sized one position shorter than x_seq along dim -2
    # (presumably a causal mask for the decoder — confirm against the helper).
    tgt_mask = generate_square_subsequent_mask(x_seq.size(-2)-1).to(device)

    # Forward pass, get the reconstructions
    outputs, mu, log_var = vae(x_seq, x_acts, x_graph, src_mask, tgt_mask)

# The first element of the model output is the sequence reconstruction.
seq_rec, _  = outputs

In [None]:
# Shape of the input batch.
x_seq.shape

In [None]:
# Shape of the (sparse) reconstruction returned by the model.
seq_rec.shape

In [None]:
# Shape of the activation tensor that comes with the batch.
x_acts.shape

Create dense reconstruction from sparse reconstruction:

In [None]:
# Target shape: x_seq without the first position along dim -2
# (presumably the start-of-sequence slot, which is not reconstructed — TODO confirm).
size = x_seq[..., 1:, :].size()

# Work on a flattened view: all leading dims collapsed, one (dim -2, dim -1)
# block per row.
seq_rec_dense = torch.zeros(size, dtype=torch.float).to(device)
seq_rec_dense = seq_rec_dense.view(-1, size[-2], size[-1])

# Filler block for inactive positions: feature 129 set at every step.
silence = torch.zeros(size[-2], size[-1]).to(device)
silence[:, 129] = 1. # eos token

# x_acts, cast to bool and flattened, selects the rows that receive an entry
# of seq_rec; every other row is filled with `silence`.
active = x_acts.bool().view(-1)
seq_rec_dense[active] = seq_rec
seq_rec_dense[~active] = silence

# Restore the original leading dimensions.
seq_rec_dense = seq_rec_dense.view(size)

In [None]:
# Print the dense reconstruction shape and the input shape, one per line.
print(seq_rec_dense.size(), x_seq.size(), sep="\n")

In [None]:
# First sample of the batch: original input and its dense reconstruction.
music_real, music_rec = x_seq[0], seq_rec_dense[0]

In [None]:
# Shape of the selected original sample.
music_real.shape

In [None]:
# NOTE: from_tensor_to_muspy and track_data are defined in a later cell of this
# notebook; on a fresh kernel that cell has to be executed before this one.
prefix = "data/music/"

# Convert the original sample, draw its pianoroll, save the figure and the MIDI file.
real = from_tensor_to_muspy(music_real, track_data)
muspy.show_pianoroll(real, yticklabel='off', grid_axis='off')
# savefig writes matplotlib's current figure (presumably the pianoroll just drawn).
plt.savefig(prefix + "real" + ".png")
muspy.write_midi(prefix + "real" + ".mid", real)

In [None]:
# Same export as for the original sample, applied to the reconstruction.
# Relies on `prefix` set in the previous cell and on from_tensor_to_muspy /
# track_data, which are defined in a later cell of this notebook.
rec = from_tensor_to_muspy(music_rec, track_data)
muspy.show_pianoroll(rec, yticklabel='off', grid_axis='off')
plt.savefig(prefix + "rec" + ".png")
muspy.write_midi(prefix + "rec" + ".mid", rec)

Plot music and save it to disk

In [None]:
#tracks = [drum_track, bass_track, guitar_track, strings_track]
import copy

def from_tensor_to_muspy(music_tensor, track_data):
    """Convert a music tensor into a `muspy.Music` object.

    Args:
        music_tensor: 4-D tensor indexed as [track, timestep, note slot, feature].
            Only the first 131 features are read: they are treated as scores
            over pitch tokens and the arg-max is taken as the pitch token.
        track_data: Sequence with one entry per track; entry[0] is the track
            name and entry[1] the MIDI program. A track named 'Drums' becomes
            a drum track and its program is ignored.

    Returns:
        A `muspy.Music` with one track per index of the tensor's first
        dimension, title 'prova' and resolution `RESOLUTION`.

    Notes:
        * Token 129 (end of sequence) stops the scan of a timestep's note slots.
        * Tokens above 127 (128 = start of sequence, 130 = padding) are not
          valid MIDI pitches and never produce a note.
        * Durations are not decoded from the tensor: every note lasts 4 time
          steps and has velocity 64.
    """
    n_pitch_tokens = 131   # width of the pitch slice of the feature axis
    eos_token = 129
    max_midi_pitch = 127
    fixed_dur = 4
    velocity = 64

    # A single vectorized arg-max over the whole tensor, converted to nested
    # Python lists. This replaces one argmax call (and one device sync when the
    # tensor lives on the GPU) per (track, timestep, note slot).
    pitch_tokens = music_tensor[..., :n_pitch_tokens].argmax(dim=-1).tolist()

    tracks = []

    for tr, track_tokens in enumerate(pitch_tokens):

        notes = []

        for ts, slot_tokens in enumerate(track_tokens):
            for pitch in slot_tokens:

                if pitch == eos_token:
                    break

                # Skips the start-of-sequence token and any stray padding token
                # that appears before the end-of-sequence token.
                if pitch > max_midi_pitch:
                    continue

                notes.append(muspy.Note(ts, pitch, fixed_dur, velocity))

        # `notes` is a fresh list for every track, so it can be handed over as is.
        name = track_data[tr][0]
        if name == 'Drums':
            track = muspy.Track(name='Drums', is_drum=True, notes=notes)
        else:
            track = muspy.Track(name=name,
                                program=track_data[tr][1],
                                notes=notes)
        tracks.append(track)

    meta = muspy.Metadata(title='prova')
    music = muspy.Music(tracks=tracks, metadata=meta, resolution=RESOLUTION)

    return music


# One (name, MIDI program) pair per track, in the same order as the track axis
# of the music tensors. from_tensor_to_muspy ignores the program of the 'Drums'
# entry, so its -1 is only a placeholder.
track_data = [('Drums', -1), ('Bass', 34), ('Guitar', 1), ('Strings', 41)]

In [None]:
prefix = "data/music/file"

# Export dataset samples 20..29 as file0..file9 (pianoroll image + MIDI file).
for i in range(10):
    music_tensor = dataset[20+i][0]
    music = from_tensor_to_muspy(music_tensor, track_data)

    out_stem = prefix + str(i)
    muspy.show_pianoroll(music, yticklabel='off', grid_axis='off')
    plt.savefig(out_stem + ".png")
    muspy.write_midi(out_stem + ".mid", music)

In [None]:
# Display the current `music` object (whichever cell assigned it last).
music

In [None]:
# Draw the current `music` object and write it to disk.
music_path = "data/music/file2.mid"
muspy.show_pianoroll(music, yticklabel='off', grid_axis='off')
# NOTE(review): the image goes to the working directory ('file2.png') while the
# MIDI file goes to data/music/ — confirm this is intended.
plt.savefig('file2.png')
muspy.write_midi(music_path, music)

In [None]:
print(dataset[0][0].size())

# Hand-built four-track example. Note arguments are passed positionally as
# (time, pitch, duration, velocity).
notes = [muspy.Note(1, 48, 20, 64)]

drums = muspy.Track(is_drum=True)
bass = muspy.Track(program=34, notes=notes)
guitar = muspy.Track(program=27, notes=[])
strings = muspy.Track(
    program=42,
    notes=[
        muspy.Note(0, 100, 4, 64),
        muspy.Note(4, 91, 20, 64),
    ],
)

tracks = [drums, bass, guitar, strings]
meta = muspy.Metadata(title='prova')

music = muspy.Music(tracks=tracks, metadata=meta, resolution=32)

In [None]:
!ls data/lmd_matched/M/T/O/TRMTOBP128E07822EF/63edabc86c087f07eca448b0edad53c3.mid

# Stuff

next edges

In [None]:
import itertools

# Random binary activation matrix (4 x 8) and its transpose (8 x 4);
# rows of a_t are timesteps, columns are tracks.
a = np.random.randint(2, size=(4,8))
a_t = a.T
print(a_t)

# (timestep, track) coordinates of every active cell, in row-major order.
inds = np.argwhere(a_t == 1)
# Timesteps that contain at least one active cell.
ts_acts = a_t.any(axis=1)
ts_inds = np.flatnonzero(ts_acts)

# Node label of cell (timestep, track) is track * 8 + timestep.
labels = np.arange(32).reshape(4, 8).T
print(labels)

# Connect every active cell of an active timestep to every active cell of the
# next active timestep, skipping pairs on the same track. Each edge is
# (source label, destination label, distance in timesteps).
next_edges = []
for ind_s, ind_e in zip(ts_inds[:-1], ts_inds[1:]):
    s = inds[inds[:, 0] == ind_s]
    e = inds[inds[:, 0] == ind_e]
    gap = ind_e - ind_s
    for src in s:
        for dst in e:
            if src[1] != dst[1]:
                next_edges.append((labels[tuple(src)], labels[tuple(dst)], gap))

print(next_edges)
    

onset edges

In [None]:
onset_edges = []
print(a_t)
print(labels)

# For every active timestep, connect each pair of simultaneously active cells
# in both directions with a time distance of 0.
for ts in ts_inds:
    active_cells = inds[inds[:, 0] == ts]
    # combinations() yields nothing when fewer than two cells are active.
    forward = [(labels[tuple(u)], labels[tuple(v)], 0)
               for u, v in itertools.combinations(active_cells, 2)]
    backward = [(dst, src, dist) for src, dst, dist in forward]
    onset_edges += forward
    onset_edges += backward

print(onset_edges)


track edges

In [None]:
print(a_t)
print(labels)
track_edges = []

# Within each track, link every active cell to the next active cell of the
# same track; the third edge component is the distance in timesteps.
for track in range(a_t.shape[1]):
    tr_inds = list(inds[inds[:,1] == track])
    e_inds = list(zip(tr_inds[:-1], tr_inds[1:]))
    print(e_inds)
    for src, dst in e_inds:
        track_edges.append((labels[tuple(src)], labels[tuple(dst)], dst[0] - src[0]))

print(track_edges)

In [None]:
# Stack both edge lists as (num_edges, 3) integer arrays whose rows are
# (source label, destination label, time distance). The explicit reshape keeps
# an empty edge list at shape (0, 3), so the concatenation cannot fail on
# mismatched dimensions.
track_edges = np.array(track_edges, dtype=np.int64).reshape(-1, 3)
onset_edges = np.array(onset_edges, dtype=np.int64).reshape(-1, 3)
np.concatenate((track_edges, onset_edges)).shape

In [None]:
pip install pypianoroll

In [None]:
import pypianoroll

In [None]:
# Read a MIDI file with pypianoroll and print the resulting object.
multitrack = pypianoroll.read("tests_fur-elise.mid")
print(multitrack)

In [None]:
# Pianoroll array of the first track.
multitrack.tracks[0].pianoroll

In [None]:
# Plot the multitrack as loaded.
multitrack.plot()

In [None]:
# Trim to the range [0, 12 * resolution) time steps, then binarize.
# NOTE(review): both calls presumably modify `multitrack` in place, so this
# cell is not idempotent with respect to the cells that displayed it before — confirm.
multitrack.trim(0, 12 * multitrack.resolution)
multitrack.binarize()

In [None]:
# Plot again after trimming and binarizing.
multitrack.plot()

In [None]:
# Shape of the first track's pianoroll array.
multitrack.tracks[0].pianoroll.shape