<a href="https://colab.research.google.com/github/EmanueleCosenza/Polyphemus/blob/main/midi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pwd

In [None]:
!git branch

Library installation

In [None]:
#!tar -C data -xvzf data/lmd_matched.tar.gz

In [None]:
# Install the required music libraries
#!pip3 install muspy
#!pip3 install pypianoroll

In [None]:
# Install torch_geometric
#!v=$(python3 -c "import torch; print(torch.__version__)"); \
#pip3 install torch-scatter -f https://data.pyg.org/whl/torch-${v}.html; \
#pip3 install torch-sparse -f https://data.pyg.org/whl/torch-${v}.html; \
#pip3 install torch-geometric

Reproducibility

In [None]:
import numpy as np
import torch
import random
import os

def set_seed(seed):
    """Seed every random number generator used by the notebook.

    Covers Python's hash seed environment variable, the `random` module,
    NumPy's global generator and PyTorch (CPU and all CUDA devices), and
    switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: integer seed applied to all generators.
    """
    # Python-level sources of randomness
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU first, then every CUDA device)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask cuDNN for deterministic kernels and disable its autotuner
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

seed = 42
set_seed(seed)

In [None]:
import os
import muspy
from itertools import product
import pypianoroll as pproll
import time
from tqdm.auto import tqdm


class MIDIPreprocessor():
    """Placeholder for an object-oriented version of the preprocessing code.

    All methods are currently stubs that do nothing and return None.
    """

    def __init__(self):
        # `self` was missing from the signature, which made the class
        # impossible to instantiate
        pass

    def preprocess_dataset(self, dir, early_exit=None):
        """Stub: not implemented yet."""
        pass

    def preprocess_file(self, f):
        """Stub: not implemented yet."""
        pass


# Todo: to config file (or separate files)
# Number of token slots per timestep and track in the encoded tensors
# (slot 0 holds SOS, the slot after the last note holds EOS)
MAX_SIMU_NOTES = 16 # 14 + SOS and EOS

# Special pitch tokens, placed right after the 128 MIDI pitches (0-127)
PITCH_SOS = 128
PITCH_EOS = 129
PITCH_PAD = 130
# Special duration tokens; real durations are stored as (duration - 1),
# i.e. in the range 0..MAX_DUR-1
DUR_SOS = 96
DUR_EOS = 97
DUR_PAD = 98

# Durations (in timesteps) longer than this are clipped to this value
MAX_DUR = 96

# Number of time steps per quarter note
# To get bar resolution -> RESOLUTION*4
RESOLUTION = 8 
# Number of bars contained in each saved sequence
NUM_BARS = 1


def preprocess_file(filepath, dest_dir, num_samples):
    """Turn one MIDI file into fixed-length training samples stored on disk.

    The song is loaded, filtered (4/4 only; it must contain drum, guitar,
    bass and strings tracks), expanded into every drum/bass/guitar track
    combination, encoded as per-timestep note tokens, cut into windows of
    NUM_BARS bars (stride of one bar) and saved with np.savez.

    Args:
        filepath: path of the MIDI file to read.
        dest_dir: existing directory where the samples are written.
        num_samples: number of samples already saved by previous calls. New
            files are named num_samples, num_samples+1, ... (np.savez adds
            the '.npz' extension).

    Returns:
        The number of samples saved for this file (0 if the song is skipped).
    """

    saved_samples = 0

    # Timesteps in one bar (4/4) and in one saved sequence
    bar_len = RESOLUTION*4
    seq_len = NUM_BARS*bar_len
    # Highest valid MIDI pitch; larger values are reserved for special tokens
    max_midi_pitch = 127

    print("Preprocessing file " + filepath)

    # Load the file both as a pypianoroll song and a muspy song
    # (Need to load both since muspy.to_pypianoroll() is expensive)
    try:
        pproll_song = pproll.read(filepath, resolution=RESOLUTION)
        muspy_song = muspy.read(filepath)
    except Exception:
        print("Song skipped (Invalid song format)")
        return 0

    # Only accept songs that have a time signature of 4/4 and no time changes
    for t in muspy_song.time_signatures:
        if t.numerator != 4 or t.denominator != 4:
            print("Song skipped ({}/{} time signature)".
                            format(t.numerator, t.denominator))
            return 0

    # Gather tracks of pypianoroll song based on MIDI program number
    drum_tracks = []
    bass_tracks = []
    guitar_tracks = []
    strings_tracks = []

    for track in pproll_song.tracks:
        if track.is_drum:
            track.name = 'Drums'
            drum_tracks.append(track)
        elif 0 <= track.program <= 31:
            track.name = 'Guitar'
            guitar_tracks.append(track)
        elif 32 <= track.program <= 39:
            track.name = 'Bass'
            bass_tracks.append(track)
        else:
            # Tracks with program > 39 are all considered as strings tracks
            # and will be merged into a single track later on
            strings_tracks.append(track)

    # Filter song if it does not contain drum, guitar, bass or strings tracks
    if not drum_tracks or not guitar_tracks \
            or not bass_tracks or not strings_tracks:
        print("Song skipped (does not contain drum or "
                "guitar or bass or strings tracks)")
        return 0

    # Merge strings tracks into a single pypianoroll track
    strings = pproll.Multitrack(tracks=strings_tracks)
    strings_track = pproll.Track(pianoroll=strings.blend(mode='max'),
                                 program=48, name='Strings')

    # Single instruments can have multiple tracks.
    # Consider all possible combinations of drum, bass, and guitar tracks
    combinations = list(product(drum_tracks, bass_tracks, guitar_tracks))

    for comb_idx, combination in enumerate(combinations):

        print("Processing combination", comb_idx+1, "of", len(combinations))

        # Process combination (called 'subsong' from now on)
        # The drum track must stay first: transposition below skips index 0
        drum_track, bass_track, guitar_track = combination
        tracks = [drum_track, bass_track, guitar_track, strings_track]

        pproll_subsong = pproll.Multitrack(
            tracks=tracks,
            tempo=pproll_song.tempo,
            resolution=RESOLUTION
        )
        muspy_subsong = muspy.from_pypianoroll(pproll_subsong)

        tracks_notes = [track.notes for track in muspy_subsong.tracks]

        # Obtain length of subsong (maximum of each track's length)
        length = 0
        for notes in tracks_notes:
            track_length = max(note.end for note in notes) if notes else 0
            length = max(length, track_length)
        length += 1

        # Add timesteps until length is a multiple of a whole bar
        length += -length % bar_len

        tracks_tensors = []
        tracks_activations = []

        # Todo: adapt to velocity
        for notes in tracks_notes:

            # Initialize encoder-ready track tensor
            # track_tensor: (length x max_simu_notes x 2 (or 3 if velocity))
            # The last dimension contains pitches and durations (and velocities)
            # int16 is enough for small to medium duration values
            track_tensor = np.zeros((length, MAX_SIMU_NOTES, 2), np.int16)

            track_tensor[:, :, 0] = PITCH_PAD
            track_tensor[:, 0, 0] = PITCH_SOS
            track_tensor[:, :, 1] = DUR_PAD
            track_tensor[:, 0, 1] = DUR_SOS

            # Keeps track of how many notes have been stored in each timestep
            # (starts at 1 because slot 0 is SOS; int8 is enough as long as
            # MAX_SIMU_NOTES stays small)
            notes_counter = np.ones(length, dtype=np.int8)

            # Todo: np.put_along_axis?
            for note in notes:
                # Insert note in the lowest position available in the timestep
                t = note.time
                slot = notes_counter[t]

                if slot >= MAX_SIMU_NOTES-1:
                    # Skip note if there is no more space
                    # (the last slot must remain free for EOS)
                    continue

                track_tensor[t, slot, 0] = note.pitch
                track_tensor[t, slot, 1] = min(MAX_DUR, note.duration)-1
                notes_counter[t] += 1

            # Add end of sequence token right after the last note
            timesteps = np.arange(0, length)
            track_tensor[timesteps, notes_counter, 0] = PITCH_EOS
            track_tensor[timesteps, notes_counter, 1] = DUR_EOS

            # Get track activations, a boolean tensor indicating whether notes
            # are being played in a timestep (sustain does not count)
            # (needed for graph rep.)
            activations = np.array(notes_counter-1, dtype=bool)

            tracks_tensors.append(track_tensor)
            tracks_activations.append(activations)

        # (#tracks x length x max_simu_notes x 2 (or 3))
        subsong_tensor = np.stack(tracks_tensors, axis=0)

        # (#tracks x length)
        subsong_activations = np.stack(tracks_activations, axis=0)

        # Slide window over 'subsong_tensor' and 'subsong_activations' along the
        # time axis (2nd dimension) with the stride of a bar
        # Todo: np.lib.stride_tricks.as_strided(song_proll)
        for start in range(0, length-seq_len+1, bar_len):

            # Get (a private copy of) the sequence and its activations
            seq_tensor = np.copy(subsong_tensor[:, start:start+seq_len, :])
            seq_acts = np.copy(subsong_activations[:, start:start+seq_len])

            if NUM_BARS > 1:
                # Skip sequence if it contains more than one bar of consecutive
                # silence in at least one track. The check is done per track:
                # looking at bar indices flattened across tracks would also
                # reject sequences where two different tracks are silent in
                # adjacent bars.
                bars = seq_acts.reshape(seq_acts.shape[0], NUM_BARS, -1)
                silent_bars = ~np.any(bars, axis=2)

                if np.any(silent_bars[:, :-1] & silent_bars[:, 1:]):
                    continue
            else:
                # In the case of just 1 bar, skip it if all tracks are silenced
                if not np.any(seq_acts):
                    continue

            # Randomly transpose the pitches of the sequence (-5 to 6 semitones)
            # Not considering pad, sos, eos tokens
            # Not transposing drums/percussions (track 0)
            shift = np.random.choice(np.arange(-5, 7), 1)
            non_perc = seq_tensor[1:, ...]
            pitches = non_perc[..., 0]
            cond = (pitches != PITCH_PAD) &                                     \
                   (pitches != PITCH_SOS) &                                     \
                   (pitches != PITCH_EOS)
            if np.any(cond):
                # Limit the shift so that every transposed pitch stays a valid
                # MIDI pitch; otherwise it could turn into a special token id
                # (128-130) or fall outside the one-hot vocabulary
                lowest = int(pitches[cond].min())
                highest = int(pitches[cond].max())
                shift = np.clip(shift, -lowest, max_midi_pitch-highest)
            non_perc[cond, 0] += shift

            # Save sample (seq_tensor and seq_acts) to file
            curr_sample = str(num_samples + saved_samples)
            sample_filepath = os.path.join(dest_dir, curr_sample)
            np.savez(sample_filepath, seq_tensor=seq_tensor, seq_acts=seq_acts)

            saved_samples += 1

    print("File preprocessing finished. Saved samples:", saved_samples)
    print()

    return saved_samples



# Total number of files: 116189
# Number of unique files: 45129
def preprocess_dataset(dataset_dir, dest_dir, num_files=45129, early_exit=None):
    """Preprocess every uniquely-named file found under a dataset directory.

    The directory tree is walked recursively. A file whose name was already
    encountered is counted but not processed again; every other file is
    passed to preprocess_file. Progress and running statistics are printed
    after each processed file.

    Args:
        dataset_dir: root directory of the dataset.
        dest_dir: directory receiving the samples (created if missing).
        num_files: expected number of unique files, used only to size the
            progress bar when early_exit is not set.
        early_exit: if not None, stop once this many unique files have been
            processed.
    """

    files_dict = {}
    seen = 0
    tot_samples = 0
    not_filtered = 0
    finished = False

    print("Starting preprocessing")

    # preprocess_file writes straight into dest_dir
    os.makedirs(dest_dir, exist_ok=True)

    total = early_exit if early_exit is not None else num_files
    start = time.time()

    # The context manager makes sure the bar is closed even on errors
    with tqdm(total=total) as progress_bar:

        # Visit recursively the directories inside the dataset directory
        for dirpath, dirs, files in os.walk(dataset_dir):

            # Sort alphabetically the found directories
            # (to help guess the remaining time)
            dirs.sort()

            print("Current path:", dirpath)
            print()

            # Files are visited in sorted order too, so that sample numbering
            # and the sequence of random transpositions do not depend on the
            # order in which the filesystem lists them
            for f in sorted(files):

                seen += 1

                if f in files_dict:
                    # Skip already seen file
                    files_dict[f] += 1
                    continue

                # File never seen before, add to dictionary of files
                # (from filename to # of occurrences)
                files_dict[f] = 1

                # Preprocess file
                filepath = os.path.join(dirpath, f)
                n_saved = preprocess_file(filepath, dest_dir, tot_samples)

                tot_samples += n_saved
                if n_saved > 0:
                    not_filtered += 1

                progress_bar.update(1)

                print("Total number of seen files:", seen)
                print("Number of unique files:", len(files_dict))
                print("Total number of non filtered songs:", not_filtered)
                print("Total number of saved samples:", tot_samples)
                print()

                # Exit when a maximum number of files has been processed (if set)
                if early_exit is not None and len(files_dict) >= early_exit:
                    finished = True
                    break

            if finished:
                break

    end = time.time()
    hours, rem = divmod(end-start, 3600)
    minutes, seconds = divmod(rem, 60)
    print("Preprocessing completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
              .format(int(hours),int(minutes),seconds))


In [None]:
!rm -rf data/preprocessed/
!mkdir data/preprocessed

In [None]:
# Dataset subdirectory to preprocess and directory receiving the samples
dataset_dir = 'data/lmd_matched/Y/G/'
dest_dir = 'data/preprocessed'

Run the preprocessing on a subset of the dataset, then check the preprocessed data:

In [None]:
preprocess_dataset(dataset_dir, dest_dir, early_exit=50)

In [None]:
# Load one saved sample for inspection (sample files are numbered from 0,
# so this assumes that at least 6 samples were written)
filepath = os.path.join(dest_dir, "5.npz")
data = np.load(filepath)

In [None]:
print(data["seq_tensor"].shape)
print(data["seq_acts"].shape)

In [None]:
data["seq_tensor"][0, 1]

# Model

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset
from torch_geometric.data import Data
import itertools


def unpackbits(x, num_bits):
    """Expand an integer array into its binary digits, least significant first.

    Args:
        x: numpy array with an integer-like dtype.
        num_bits: number of bits to extract from each element.

    Returns:
        Integer array of shape x.shape + (num_bits,) containing 0/1 values,
        where index k along the last axis is bit k of the element.

    Raises:
        ValueError: if x has a floating point dtype.
    """
    if np.issubdtype(x.dtype, np.floating):
        raise ValueError("numpy data type needs to be int-like")

    out_shape = list(x.shape) + [num_bits]

    # One power of two per requested bit, in the dtype of the input
    bit_weights = 2**np.arange(num_bits, dtype=x.dtype)

    # Test every element (as a column) against every bit weight (as a row)
    flat = x.reshape(-1)
    is_set = (flat[:, None] & bit_weights[None, :]) != 0

    return is_set.astype(int).reshape(out_shape)


class MIDIDataset(Dataset):
    """Dataset of preprocessed samples stored as '<index>.npz' files.

    Each file is expected to contain 'seq_tensor' (tracks x timesteps x
    notes x 2, holding pitch and duration tokens) and 'seq_acts' (tracks x
    timesteps, boolean activations), as written by preprocess_file.

    Every item is a tuple (one-hot pitch tensor, activations, graph,
    padding mask), where the graph has one node per active (track,
    timestep) cell.
    """

    # Pitch vocabulary: 128 MIDI pitches followed by SOS, EOS and PAD (130)
    _PITCH_PAD = 130
    _NUM_PITCH_TOKENS = 131

    def __init__(self, dir):
        self.dir = dir

    def __len__(self):
        # Only count sample files: any other file in the directory (hidden
        # files, logs, ...) would make the length exceed the valid indices
        _, _, files = next(os.walk(self.dir))
        return sum(1 for f in files if f.endswith(".npz"))

    @staticmethod
    def _node_labels(acts):
        """Map every active cell to its node index.

        Nodes are numbered in the order returned by np.where(acts == 1)
        (track by track). The result is transposed to (timesteps x tracks)
        so that it can be indexed with (timestep, track) pairs.
        """
        labels = np.zeros(acts.shape, dtype=np.int64)
        active = np.where(acts == 1)
        labels[active] = np.arange(len(active[0]))
        return labels.transpose()

    def __get_track_edges(self, acts, edge_type_ind=0):
        """Edges linking consecutive active timesteps within the same track.

        Returns an (N x 4) long array with columns (source node, target node,
        edge type, timestep distance), or an empty array if there are none.
        """
        a_t = acts.transpose()
        # (timestep, track) pairs of the active cells, sorted by timestep
        inds = np.stack(np.where(a_t == 1)).transpose()
        labels = self._node_labels(acts)

        track_edges = []

        for track in range(a_t.shape[1]):
            tr_inds = list(inds[inds[:,1] == track])
            for src, dst in zip(tr_inds, tr_inds[1:]):
                track_edges.append((labels[tuple(src)], labels[tuple(dst)],
                                    edge_type_ind, dst[0]-src[0]))

        return np.array(track_edges, dtype='long')

    def __get_onset_edges(self, acts, edge_type_ind=1):
        """Bidirectional edges between tracks that are active at the same timestep.

        Returns an (N x 4) long array with columns (source node, target node,
        edge type, 0), or an empty array if there are none.
        """
        a_t = acts.transpose()
        inds = np.stack(np.where(a_t == 1)).transpose()
        # Timesteps in which at least one track is active
        ts_inds = np.where(np.any(a_t, axis=1))[0]
        labels = self._node_labels(acts)

        onset_edges = []

        for i in ts_inds:
            ts_acts_inds = list(inds[inds[:,0] == i])
            if len(ts_acts_inds) < 2:
                continue
            edges = [(labels[tuple(src)], labels[tuple(dst)], edge_type_ind, 0)
                     for src, dst in itertools.combinations(ts_acts_inds, 2)]
            inv_edges = [(e[1], e[0], *e[2:]) for e in edges]
            onset_edges.extend(edges)
            onset_edges.extend(inv_edges)

        return np.array(onset_edges, dtype='long')

    def __get_next_edges(self, acts, edge_type_ind=2):
        """Edges from each active timestep to the next one, across different tracks.

        Returns an (N x 4) long array with columns (source node, target node,
        edge type, timestep distance), or an empty array if there are none.
        """
        a_t = acts.transpose()
        inds = np.stack(np.where(a_t == 1)).transpose()
        ts_inds = np.where(np.any(a_t, axis=1))[0]
        labels = self._node_labels(acts)

        next_edges = []

        for ind_s, ind_e in zip(ts_inds, ts_inds[1:]):
            sources = inds[inds[:,0] == ind_s]
            targets = inds[inds[:,0] == ind_e]

            # Same-track pairs are already covered by the track edges
            for src, dst in itertools.product(sources, targets):
                if src[1] != dst[1]:
                    next_edges.append((labels[tuple(src)], labels[tuple(dst)],
                                       edge_type_ind, ind_e-ind_s))

        return np.array(next_edges, dtype='long')

    def _get_node_features(self, acts, num_nodes):
        """One-hot encode the track each node belongs to (num_nodes x tracks)."""
        num_tracks = acts.shape[0]
        features = torch.zeros((num_nodes, num_tracks), dtype=torch.float)
        features[np.arange(num_nodes), np.stack(np.where(acts))[0]] = 1.

        return features

    def __getitem__(self, idx):
        """Load sample `idx` and build the model inputs.

        Returns:
            A tuple (seq, acts, graph, src_mask): the one-hot pitch tensor,
            the activations as a float tensor, a torch_geometric Data graph
            and a boolean mask that is True on PAD positions.
        """
        # Load tensors (the context manager closes the underlying file, which
        # np.load otherwise keeps open for .npz archives)
        sample_path = os.path.join(self.dir, str(idx) + ".npz")
        with np.load(sample_path) as data:
            seq_tensor = data["seq_tensor"]
            seq_acts = data["seq_acts"]

        # Construct src_key_padding_mask
        src_mask = torch.from_numpy((seq_tensor[..., 0] == self._PITCH_PAD))

        # From decimals to one-hot (pitch)
        # Durations are currently not encoded: only pitches are returned
        pitches = seq_tensor[:, :, :, 0]
        onehot_p = np.zeros((pitches.size, self._NUM_PITCH_TOKENS), dtype=float)
        onehot_p[np.arange(0, onehot_p.shape[0]), pitches.reshape(-1)] = 1.
        onehot_p = onehot_p.reshape(-1, pitches.shape[1], seq_tensor.shape[2],
                                    self._NUM_PITCH_TOKENS)
        new_seq_tensor = onehot_p

        # Construct graph from boolean activations
        # Todo: optimize and refactor
        track_edges = self.__get_track_edges(seq_acts)
        onset_edges = self.__get_onset_edges(seq_acts)
        next_edges = self.__get_next_edges(seq_acts)
        edges = [track_edges, onset_edges, next_edges]

        # Concatenate edge tensors (N x 4) (if any)
        # Third column -> edge_type, Fourth column -> timestep distance
        no_edges = (len(track_edges) == 0 and
                    len(onset_edges) == 0 and len(next_edges) == 0)
        if not no_edges:
            edge_list = np.concatenate([x for x in edges
                                          if x.size > 0])
            edge_list = torch.from_numpy(edge_list)

        # Adapt tensor to torch_geometric's Data
        # No edges: add fictitious self-edge
        edge_index = (torch.LongTensor([[0], [0]]) if no_edges else
                               edge_list[:, :2].t().contiguous())
        attrs = (torch.Tensor([[0, 0]]) if no_edges else
                                       edge_list[:, 2:])

        # Column 0 holds the edge type, the remaining columns one-hot encode
        # the timestep distance between the two endpoints
        edge_attrs = torch.zeros(attrs.size(0), 1+seq_acts.shape[-1])
        edge_attrs[:, 0] = attrs[:, 0]
        edge_attrs[np.arange(edge_attrs.size(0)), attrs.long()[:, 1]+1] = 1

        # One node per active (track, timestep) cell (sparse representation)
        n = torch.sum(torch.Tensor(seq_acts), dtype=torch.long)
        node_features = self._get_node_features(seq_acts, n)
        graph = Data(edge_index=edge_index, edge_attrs=edge_attrs,
                     num_nodes=n, node_features=node_features)

        # Todo: start with torch at mount
        return torch.Tensor(new_seq_tensor), torch.Tensor(seq_acts), graph, src_mask


In [None]:
import torch
from torch import nn, Tensor
from torch_geometric.nn.conv import GCNConv#, RGCNConv
import torch.nn.functional as F
import math
import torch.optim as optim
from torch_scatter import scatter_mean


# Todo: check and think about max_len
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a sequence, then dropout.

    Even feature indices receive sines and odd indices cosines of the
    position scaled by geometrically decreasing frequencies. Both even and
    odd values of d_model are supported.

    Args:
        d_model: size of the embedding dimension.
        dropout: dropout probability applied after adding the encodings.
        max_len: maximum sequence length that can be encoded.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 256):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *                     \
                             (-math.log(10000.0)/d_model))
        angles = position*div_term

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine slot fewer than sine slots,
        # so the last frequency is dropped (no-op when d_model is even)
        pe[:, 0, 1::2] = torch.cos(angles)[:, :d_model//2]
        # Buffer: moves with the module across devices but is not trained
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor with the same shape as x.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class GraphEncoder(nn.Module):
    """Stack of relational graph convolutions, each preceded by dropout and
    followed by a ReLU.

    NOTE(review): RGCNConv is commented out in the torch_geometric import
    of this cell and is called here with an extra edge network / edge
    attribute argument, so it presumably refers to a custom layer defined
    elsewhere - confirm that it is in scope before instantiating.

    Args:
        input_dim: size of the input node features.
        dim_hidden: output size of every layer.
        num_layers: number of convolution layers.
        num_relations: number of edge types.
        edge_features_dim: size of the edge attributes fed to each edge network.
        dropout: dropout probability applied to the input of each layer.
    """

    def __init__(self, input_dim=256, dim_hidden=256, num_layers=3, num_relations=3,
                 edge_features_dim=32, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()

        # Input size of each layer: the first one sees the raw node features,
        # the others the output of the previous layer
        layer_in_dims = [input_dim] + [dim_hidden]*(num_layers-1)
        for in_dim in layer_in_dims:
            edge_nn = nn.Linear(edge_features_dim, in_dim*in_dim)
            self.layers.append(
                RGCNConv(in_dim, dim_hidden, num_relations, edge_nn)
            )

        self.p = dropout

    def forward(self, data):
        """Propagate the node features `data.x` through all the layers.

        The first column of `data.edge_attrs` is used as the edge type, the
        remaining columns as edge attributes.
        """
        edge_index = data.edge_index
        edge_type = data.edge_attrs[:, 0]
        edge_attr = data.edge_attrs[:, 1:]

        hidden = data.x
        for conv in self.layers:
            hidden = F.dropout(hidden, p=self.p, training=self.training)
            hidden = F.relu(conv(hidden, edge_index, edge_type, edge_attr))

        return hidden


class Encoder(nn.Module):
    """Encodes a batch of sequences and their graphs into the parameters of
    the latent Gaussian.

    Every active (track, timestep) cell is encoded with a transformer over
    its note tokens and mean-pooled, the pooled vectors become node features
    of the graph, and the graph encoder output is averaged per graph.

    Args:
        d_token: size of the input (one-hot) token vectors.
        d_transf: model dimension of the transformer encoder.
        nhead_transf: number of attention heads.
        num_layers_transf: number of transformer encoder layers.
        n_tracks: number of tracks (size of the one-hot track node features).
        dropout: dropout probability used throughout.
        d_graph: hidden/output size of the graph encoder.
        d_z: size of the latent vector (mu and log_var).
    """

    def __init__(self, d_token=131, d_transf=256, nhead_transf=2,
                 num_layers_transf=2, n_tracks=4, dropout=0.1,
                 d_graph=256, d_z=256):
        super().__init__()

        # Todo: one separate encoder for drums
        # Transformer Encoder
        self.embedding = nn.Linear(d_token, d_transf)
        self.pos_encoder = PositionalEncoding(d_transf, dropout=dropout)
        transf_layer = nn.TransformerEncoderLayer(
            d_model=d_transf,
            nhead=nhead_transf,
            dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(
            transf_layer,
            num_layers=num_layers_transf
        )

        # Graph encoder: node input = one-hot track + pooled transformer output
        self.graph_encoder = GraphEncoder(dropout=dropout,
                                          input_dim=n_tracks+d_transf,
                                          dim_hidden=d_graph)

        # (LSTM)

        # Linear layers that compute the final mu and log_var
        self.linear_mu = nn.Linear(d_graph, d_z)
        self.linear_log_var = nn.Linear(d_graph, d_z)

    def forward(self, x_seq, x_acts, x_graph, src_mask):
        """Compute mu and log-variance of the latent distribution.

        Args:
            x_seq: token tensor whose last two dimensions are
                (notes, d_token); all leading dimensions are flattened.
            x_acts: activations, flattened and used to select the rows of
                x_seq and src_mask that correspond to graph nodes.
            x_graph: batched graph providing `node_features`, `batch` and the
                edge tensors; its `x` attribute is overwritten here.
            src_mask: padding mask with one row per flattened row of x_seq.

        Returns:
            The tuple (mu, log_var), one row per graph in the batch.
        """
        # Collapse track (and optionally batch) dimension
        x_seq = x_seq.view(-1, x_seq.size(-2), x_seq.size(-1))

        # Filter silences
        active = x_acts.view(-1).bool()
        x_seq = x_seq[active]
        src_mask = src_mask[active]

        # Compute embeddings
        embs = self.embedding(x_seq)

        # batch_first = False
        embs = embs.permute(1, 0, 2)

        pos_encs = self.pos_encoder(embs)

        transformer_encs = self.transformer_encoder(pos_encs,
                                                    src_key_padding_mask=src_mask)

        # Pool over the sequence dimension
        pooled_encs = torch.mean(transformer_encs, 0)

        # Concatenate track one hot features with chord encodings
        # and compute node encodings
        x_graph.x = torch.cat((x_graph.node_features, pooled_encs), 1)
        node_encs = self.graph_encoder(x_graph)

        # Compute final graph latent vector(s)
        # (taking into account the batch size)
        encoding = scatter_mean(node_encs, x_graph.batch, dim=0)

        # Compute mu and log(std^2)
        mu = self.linear_mu(encoding)
        log_var = self.linear_log_var(encoding)

        return mu, log_var


class Decoder(nn.Module):
    """Decodes a latent vector into track activations and note sequences.

    Activations are predicted from z with a linear layer. z is also copied
    to every node of the graph, propagated through a graph network and used
    as memory of a transformer decoder that predicts the note tokens.

    Args:
        d_z: size of the latent vector.
        n_tracks: number of tracks.
        resolution: number of timesteps per sequence.
        d_token: size of the (one-hot) token vectors.
        d_model: model dimension of the transformer decoder.
        d_transf: size of the token embeddings; must be equal to d_model.
        nhead_transf: number of attention heads.
        num_layers_transf: number of transformer decoder layers.
        dropout: dropout probability used throughout.
        max_simu_notes: number of times each node decoding is repeated to
            form the transformer memory.

    Raises:
        ValueError: if d_transf differs from d_model.
    """

    def __init__(self, d_z=256, n_tracks=4, resolution=32, d_token=131, d_model=256,
                 d_transf=256, nhead_transf=2, num_layers_transf=2, dropout=0.1,
                 max_simu_notes=16):
        super().__init__()

        if d_transf != d_model:
            # The embeddings are fed directly to the transformer decoder
            raise ValueError("d_transf must be equal to d_model")

        self.max_simu_notes = max_simu_notes

        # (LSTM)

        # Boolean activations decoder (CNN/MLP)
        self.acts_decoder = nn.Linear(d_z, n_tracks*resolution)

        # GNN: node input = one-hot track + z, output used as transformer memory
        self.graph_decoder = GraphEncoder(dropout=dropout,
                                          input_dim=n_tracks+d_z,
                                          dim_hidden=d_model)

        # Transformer Decoder
        self.embedding = nn.Linear(d_token, d_transf)
        self.pos_encoder = PositionalEncoding(d_transf, dropout=dropout)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead_transf,
            dropout=dropout
        )
        self.transf_decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_layers_transf
        )

        # Last linear layer
        self.lin = nn.Linear(d_model, d_token)

    def forward(self, z, x_seq, x_acts, x_graph, src_mask, tgt_mask):
        """Decode z with teacher forcing.

        Args:
            z: latent vectors, one row per graph in the batch.
            x_seq: target tokens fed to the transformer decoder; the last two
                dimensions are (notes, d_token).
            x_acts: ground-truth activations; they give the shape of the
                predicted activations and select the rows of x_seq and
                src_mask that correspond to graph nodes.
            x_graph: batched graph providing `node_features`, `batch` and the
                edge tensors; its `x` attribute is overwritten here.
            src_mask: padding mask with one row per flattened row of x_seq.
            tgt_mask: attention mask applied to the target sequence.

        Returns:
            The tuple (seq_out, acts_out): unnormalized token scores for
            every active cell and unnormalized activation scores shaped like
            x_acts.
        """
        # Compute activations from z
        acts_out = self.acts_decoder(z)
        acts_out = acts_out.view(x_acts.size())

        # Initialize node features with z and propagate with GNN
        # (every node receives the z of the graph it belongs to)
        _, counts = torch.unique(x_graph.batch, return_counts=True)
        z_node_features = torch.repeat_interleave(z, counts, dim=0)

        # Add one-hot encoding of tracks
        # Todo: use also edge info
        x_graph.x = torch.cat((x_graph.node_features, z_node_features), 1)
        node_decs = self.graph_decoder(x_graph)

        # Prepare transformer memory (sequence dimension first)
        node_decs = node_decs.repeat(self.max_simu_notes, 1, 1)

        # Filter silences
        x_seq = x_seq.view(-1, x_seq.size(-2), x_seq.size(-1))
        active = x_acts.view(-1).bool()
        x_seq = x_seq[active]
        src_mask = src_mask[active]

        # Todo: same embeddings as encoder?
        embs = self.embedding(x_seq)
        embs = embs.permute(1, 0, 2)
        pos_encs = self.pos_encoder(embs)

        seq_out = self.transf_decoder(pos_encs, node_decs,
                                      tgt_key_padding_mask=src_mask,
                                      tgt_mask=tgt_mask)

        seq_out = self.lin(seq_out)

        # Back to batch-first layout, no activation applied
        seq_out = seq_out.permute(1, 0, 2)
        seq_out = seq_out.view(x_seq.size())

        return seq_out, acts_out


class VAE(nn.Module):
    """Variational autoencoder made of an Encoder and a Decoder.

    Args:
        dropout: dropout probability passed to both sub-modules.
        **kwargs: accepted but currently unused.
    """

    def __init__(self, dropout=0.1, **kwargs):
        super().__init__()

        self.encoder = Encoder(dropout=dropout)
        self.decoder = Decoder(dropout=dropout)

    def forward(self, x_seq, x_acts, x_graph, src_mask, tgt_mask):
        """Encode the input, sample z and decode it with teacher forcing.

        Returns:
            The tuple (out, mu, log_var), where out is the decoder output.
        """
        # One mask row per (flattened) token sequence
        flat_mask = src_mask.view(-1, src_mask.size(-1))

        mu, log_var = self.encoder(x_seq, x_acts, x_graph, flat_mask)

        # Reparameterization trick: z = mu + std * noise
        std = torch.exp(0.5*log_var)
        z = mu + torch.randn_like(std)*std

        # The decoder input drops the last token of every sequence
        decoder_in = x_seq[..., :-1, :]
        decoder_mask = flat_mask[:, :-1]

        out = self.decoder(z, decoder_in, x_acts, x_graph, decoder_mask,
                           tgt_mask)

        return out, mu, log_var


Trainer

In [None]:
import torch.optim as optim
import matplotlib.pyplot as plt
import uuid
import copy
import time
from statistics import mean
from collections import defaultdict


def generate_square_subsequent_mask(sz):
    """Build a (sz x sz) causal mask: -inf strictly above the diagonal and
    zeros everywhere else."""
    all_blocked = torch.full((sz, sz), float('-inf'))
    return all_blocked.triu(diagonal=1)


class VAETrainer():
    
    def __init__(self, model, models_path, optimizer, init_lr,
                 name=None, lr_scheduler=None, device=torch.device("cuda"), 
                 print_every=1, save_every=1):
        
        self.model = model
        self.models_path = models_path
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.name = name if name is not None else str(uuid.uuid4())
        self.lr_scheduler = lr_scheduler
        self.device = device
        self.print_every = print_every
        self.save_every = save_every
        
        self.model_path = os.path.join(self.models_path, self.name)
        
        # Criteria with ignored padding
        self.bce_unreduced = nn.BCEWithLogitsLoss(reduction='none')
        self.ce_p = nn.CrossEntropyLoss(ignore_index=130)
        self.ce_d = nn.CrossEntropyLoss(ignore_index=98)
        
        # Training stats
        self.losses = defaultdict(list)
        self.accuracies = defaultdict(list)
        self.lrs = []
        self.times = []
        
    
    def train(self, trainloader, validloader=None, epochs=1,
              early_exit=None):
        
        n_batches = len(trainloader)

        beta = 0 # Todo: _update_params()
        
        self.model.train()
        
        print("Starting training.\n")
        
        start = time.time()
        self.times.append(start)
        
        tot_batches = 0
        
        for epoch in range(epochs):
            
            self.cur_epoch = epoch
            progress_bar = tqdm(range(n_batches))
            
            for batch_idx, inputs in enumerate(trainloader):
                
                self.cur_batch_idx = batch_idx
                
                # Zero out the gradients
                self.optimizer.zero_grad()
                
                # Get the inputs
                x_seq, x_acts, x_graph, src_mask = inputs
                x_seq = x_seq.float().to(self.device)
                x_acts = x_acts.to(self.device)
                x_graph = x_graph.to(self.device)
                src_mask = src_mask.to(self.device)
                tgt_mask = generate_square_subsequent_mask(x_seq.size(-2)-1).to(self.device)
                inputs = (x_seq, x_acts, x_graph)

                # Forward pass, get the reconstructions
                outputs, mu, log_var = self.model(x_seq, x_acts, x_graph, src_mask, tgt_mask)
                
                # Compute the backprop loss and other required losses
                tot_loss, losses = self._compute_losses(inputs, outputs, mu,
                                                         log_var, beta)
                
                # Backprop and update lr
                tot_loss.backward()
                self.optimizer.step()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                    
                # Update the stats
                self._append_losses(losses)
                
                last_lr = (self.lr_scheduler.lr 
                               if self.lr_scheduler is not None else self.init_lr)
                self.lrs.append(last_lr)
                
                accs = self._compute_accuracies(inputs, outputs)
                self._append_accuracies(accs)
                
                now = time.time()
                self.times.append(now)
                
                # Print stats
                if (tot_batches + 1) % self.print_every == 0:
                    print("Training on batch {}/{} of epoch {}/{} complete."
                          .format(batch_idx+1, n_batches, epoch+1, epochs))
                    self._print_stats()
                    #print("Tot_loss: {:.4f} acts_loss: {:.4f} "
                          #.format(running_loss/self.print_every, acts_loss), end='')
                    #print("pitches_loss: {:.4f} dur_loss: {:.4f} kld_loss: {:.4f}"
                          #.format(pitches_loss, dur_loss, kld_loss))
                    print("\n----------------------------------------\n")
                    
                # When appropriate, save model and stats on disk
                if self.save_every > 0 and (tot_batches + 1) % self.save_every == 0:
                    print("\nSaving model to disk...\n")
                    self._save_model()
                
                progress_bar.update(1)
                
                # Stop prematurely if early_exit is set and reached
                if early_exit is not None and (tot_batches + 1) > early_exit:
                    break
                
                tot_batches += 1
            

        end = time.time()
        # Todo: self.__print_time()
        hours, rem = divmod(end-start, 3600)
        minutes, seconds = divmod(rem, 60)
        print("Training completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
                  .format(int(hours),int(minutes),seconds))
        
        print("Saving model to disk...")
        self._save_model()
        
        print("Model saved.")
        
    
    def _compute_losses(self, inputs, outputs, mu, log_var, beta):
        
        x_seq, x_acts, _ = inputs
        seq_rec, acts_rec = outputs
        
        # Shift outputs for transformer decoder loss and filter silences
        x_seq = x_seq[..., 1:, :]
        x_seq = x_seq[x_acts.bool()]
        #print(x_seq.size())
        #print(seq_rec.size())
                
        # Compute the losses
        
        #acts_loss = self.bce_unreduced(acts_rec.view(-1), x_acts.view(-1).float())
        #weights = torch.zeros(acts_loss.size()).to(device)
        #weights[x_acts.view(-1) == 1] = 0.9
        #weights[x_acts.view(-1) == 0] = 0.1
        #acts_loss = 50 * torch.mean(weights*acts_loss)
        
        pitches_loss = self.ce_p(seq_rec.reshape(-1, seq_rec.size(-1))[:, :131],
                          x_seq.reshape(-1, x_seq.size(-1))[:, :131].argmax(dim=1))
        #dur_loss = self.ce_d(seq_rec.reshape(-1, seq_rec.size(-1))[:, 131:],
        #                  x_seq.reshape(-1, x_seq.size(-1))[:, 131:].argmax(dim=1))
        #dur_loss = mask * dur_loss
        #dur_loss = torch.sum(dur_loss) / torch.sum(mask)
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        rec_loss = pitches_loss #+ dur_loss# + acts_loss
        tot_loss = rec_loss + beta*kld_loss
        
        losses = {
            'tot': tot_loss.item(),
            'pitches': pitches_loss.item(),
            #'dur': dur_loss.item(),
            #'acts': acts_loss.item(),
            'rec': rec_loss.item(),
            'kld': kld_loss.item(),
            'beta*kld': beta*kld_loss.item()
        }
        
        return tot_loss, losses

    
    def _append_losses(self, losses):
        
        for k, loss in losses.items():
            self.losses[k].append(loss)
            
            
    def _compute_accuracies(self, inputs, outputs):
        
        x_seq, x_acts, _ = inputs
        seq_rec, acts_rec = outputs
        
        # Shift outputs for transformer decoder loss
        x_seq = x_seq[..., 1:, :]
        x_seq = x_seq[x_acts.bool()]
        #print(x_seq.size())
        #print(seq_rec.size())
        
        #notes_acc = self._note_accuracy(seq_rec, x_seq)
        pitches_acc = self._pitches_accuracy(seq_rec, x_seq)
        #dur_acc = self._dur_accuracy(seq_rec, x_seq)
        #acts_acc = self._acts_accuracy(acts_rec, x_acts)
        
        accs = {
            #'notes': notes_acc.item(),
            'pitches': pitches_acc.item(),
            #'dur': dur_acc.item(),
            #'acts': acts_acc.item()
        }
        
        return accs
        
        
    def _append_accuracies(self, accs):
        
        for k, acc in accs.items():
            self.accuracies[k].append(acc)
    
    
    def _note_accuracy(self, seq_rec, x_seq):
        
        pitches_rec = F.softmax(seq_rec[..., :131], dim=-1)
        pitches_rec = torch.argmax(pitches_rec, dim=-1)
        pitches_true = torch.argmax(x_seq[..., :131], dim=-1)
        
        print(torch.all(pitches_rec == 129))
        #print(pitches_rec)
        
        mask_p = (pitches_true != 130)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_pitches = (pitches_rec == pitches_true)
        preds_pitches = torch.logical_and(preds_pitches, mask_p)
        
        
        dur_rec = F.softmax(seq_rec[..., 131:], dim=-1)
        dur_rec = torch.argmax(dur_rec, dim=-1)
        dur_true = torch.argmax(x_seq[..., 131:], dim=-1)
        
        print(torch.all(dur_rec == 97))
        
        mask_d = (dur_true != 98)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_dur = (dur_rec == dur_true)
        preds_dur = torch.logical_and(preds_dur, mask_d)
        
        return torch.sum(torch.logical_and(preds_pitches, 
                                           preds_dur)) / torch.sum(mask_p)
    
    
    def _acts_accuracy(self, acts_rec, x_acts):
        
        acts_rec = torch.sigmoid(acts_rec)
        acts_rec[acts_rec < 0.5] = 0
        acts_rec[acts_rec >= 0.5] = 1
        
        print(torch.all(acts_rec == 0))
        print(acts_rec)
        
        return torch.sum(acts_rec == x_acts) / x_acts.numel()
    
    
    def _pitches_accuracy(self, seq_rec, x_seq):
        
        pitches_rec = F.softmax(seq_rec[..., :131], dim=-1)
        pitches_rec = torch.argmax(pitches_rec, dim=-1)
        pitches_true = torch.argmax(x_seq[..., :131], dim=-1)
        
        print(torch.all(pitches_rec == 129))
        
        mask = (pitches_true != 130)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_pitches = (pitches_rec == pitches_true)
        preds_pitches = torch.logical_and(preds_pitches, mask)
        
        return torch.sum(preds_pitches) / torch.sum(mask)
    
    
    def _dur_accuracy(self, seq_rec, x_seq):
        
        dur_rec = F.softmax(seq_rec[..., 131:], dim=-1)
        dur_rec = torch.argmax(dur_rec, dim=-1)
        dur_true = torch.argmax(x_seq[..., 131:], dim=-1)
        
        mask = (dur_true != 98)
        #mask = torch.logical_and(pitches_true != 128,
         #                        pitches_true != 129)
        #mask = torch.logical_and(mask,
         #                        pitches_true != 130)
        
        preds_dur = (dur_rec == dur_true)
        preds_dur = torch.logical_and(preds_dur, mask)
        
        return torch.sum(preds_dur) / torch.sum(mask)
    
    
    def _save_model(self):
        torch.save({
            'epoch': self.cur_epoch,
            'batch': self.cur_batch_idx,
            'save_every': self.save_every,
            'lrs': self.lrs,
            'losses': self.losses,
            'accuracies': self.accuracies,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, self.model_path)
        
        
    def _print_stats(self):
        
        hours, rem = divmod(self.times[-1]-self.times[0], 3600)
        minutes, seconds = divmod(rem, 60)
        print("Elapsed time from start (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
                  .format(int(hours), int(minutes), seconds))
        
        avg_lr = mean(self.lrs[-self.print_every:])
        
        # Take mean of the last non-printed batches for each stat
        
        avg_losses = {}
        for k, l in self.losses.items():
            avg_losses[k] = mean(l[-self.print_every:])
        
        avg_accs = {}
        for k, l in self.accuracies.items():
            avg_accs[k] = mean(l[-self.print_every:])
        
        print("Losses:")
        print(avg_losses)
        print("Accuracies:")
        print(avg_accs)
        


Training

In [None]:
# Directory handed to the trainer as the location of its checkpoint file;
# created here if it does not exist yet.
models_path = "models/"
os.makedirs(models_path, exist_ok=True)

In [None]:
from torch.utils.data import Subset

# Build the dataset from the preprocessed directory and wrap it in a loader
# that yields shuffled mini-batches of 64 samples.
ds_dir = "data/preprocessed"
dataset = MIDIDataset(ds_dir)
loader = DataLoader(dataset, batch_size=64, shuffle=True)
# Total number of samples available
print(len(dataset))

# Alternative for quick experiments: restrict the loader to the first
# n_samples items of the dataset.
#n_samples = 128
#subset = Subset(dataset, np.arange(n_samples))
#loader = DataLoader(subset, batch_size=64, shuffle=True)

In [None]:
# Sanity check: size of the first element of the sample at index 1
# (presumably the x_seq tensor -- TODO confirm against MIDIDataset).
dataset[1][0].size()

In [None]:
import torch

# Pick the GPU when one is available, otherwise fall back to the CPU.
# CUDA-specific calls are only made when CUDA is actually usable, so this
# cell also runs on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# To force CPU execution: device = torch.device("cpu")
print("Device:", device)
if device.type == "cuda":
    print("Current device idx:", torch.cuda.current_device())

In [None]:
#!rm models/vae

In [None]:
from prettytable import PrettyTable


def print_params(model):
    """Print a per-parameter table of the trainable parameters of `model`.

    Parameters with requires_grad set to False are left out.

    Returns:
        The total number of trainable scalar parameters.
    """
    table = PrettyTable(["Modules", "Parameters"])

    # (name, number of elements) for every trainable parameter tensor
    trainable = [(name, parameter.numel())
                 for name, parameter in model.named_parameters()
                 if parameter.requires_grad]

    for name, count in trainable:
        table.add_row([name, count])
    total_params = sum(count for _, count in trainable)

    print(table)
    print(f"Total Trainable Params: {total_params}")

    return total_params

In [None]:
import math
import torch
from typing import Optional
from torch.optim import Optimizer

from lr_scheduler.lr_scheduler import LearningRateScheduler

class TransformerLRScheduler(LearningRateScheduler):
    r"""
    Transformer Learning Rate Scheduler proposed in "Attention Is All You Need"

    The rate grows linearly for `warmup_steps` updates, reaching `peak_lr`,
    then decays exponentially so that it equals `peak_lr * final_lr_scale`
    after `decay_steps` further updates. The decay keeps going past that
    point but the rate is never allowed to drop below `final_lr`.

    Args:
        optimizer (Optimizer): Optimizer.
        init_lr (float): Initial learning rate (forwarded to the base class;
            the warmup itself ramps up from `peak_lr / warmup_steps`).
        peak_lr (float): Maximum learning rate.
        final_lr (float): Final learning rate, used as a lower bound during decay.
        final_lr_scale (float): Final learning rate scale
        warmup_steps (int): Warmup the learning rate linearly for the first N updates
        decay_steps (int): Steps in decay stages
    """
    def __init__(
            self,
            optimizer: Optimizer,
            init_lr: float,
            peak_lr: float,
            final_lr: float,
            final_lr_scale: float,
            warmup_steps: int,
            decay_steps: int,
    ) -> None:
        assert isinstance(warmup_steps, int), "warmup_steps should be integer type"
        assert isinstance(decay_steps, int), "decay_steps should be integer type"

        super(TransformerLRScheduler, self).__init__(optimizer, init_lr)
        self.final_lr = final_lr
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps

        # Per-step increment during warmup and exponent slope during decay
        self.warmup_rate = self.peak_lr / self.warmup_steps
        self.decay_factor = -math.log(final_lr_scale) / self.decay_steps

        self.init_lr = init_lr
        self.update_steps = 0

    def _decide_stage(self):
        """Return (stage, steps spent in that stage): 0 = warmup, 1 = decay."""
        if self.update_steps < self.warmup_steps:
            return 0, self.update_steps

        return 1, self.update_steps - self.warmup_steps

    def step(self, val_loss: Optional[torch.FloatTensor] = None):
        """Advance the schedule by one update and apply the new rate.

        `val_loss` is accepted for interface compatibility and ignored.

        Returns:
            The learning rate that has just been set on the optimizer.
        """
        self.update_steps += 1
        stage, steps_in_stage = self._decide_stage()

        if stage == 0:
            self.lr = self.update_steps * self.warmup_rate
        elif stage == 1:
            decayed_lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
            # final_lr acts as a floor for the otherwise unbounded decay
            self.lr = max(self.final_lr, decayed_lr)
        else:
            raise ValueError("Undefined stage")

        self.set_lr(self.optimizer, self.lr)

        return self.lr

In [None]:
#from lr_scheduler.transformer_lr_scheduler import TransformerLRScheduler

print("Creating the model and moving it to the specified device...")

# Model without dropout, placed on the selected device
vae = VAE(dropout=0)
vae = vae.to(device)
print_params(vae)
print()

# Optimizer. Its initial lr is overwritten by the scheduler from the first
# update on.
init_lr = 5e-6
gamma = 0.999
optimizer = optim.Adam(
    params=vae.parameters(),
    lr=init_lr,
    betas=(0.9, 0.98),
    eps=1e-9,
)
#scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma)

# Linear warmup to peak_lr over 4000 updates, then exponential decay
scheduler = TransformerLRScheduler(
    optimizer=optimizer,
    init_lr=1e-10,
    peak_lr=5e-4,
    final_lr=1e-7,
    final_lr_scale=0.1,
    warmup_steps=4000,
    decay_steps=80000,
)

print('--------------------------------------------------\n')

# Checkpoint every 100 batches under models_path
trainer = VAETrainer(
    model=vae,
    models_path=models_path,
    optimizer=optimizer,
    init_lr=init_lr,
    name='just_pitches_warmup_rgcn_ecc',
    lr_scheduler=scheduler,
    device=device,
    save_every=100,
)
trainer.train(trainloader=loader, epochs=100)

tensor(False, device='cuda:0')
Training on batch 378/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:10.98
Losses:
{'tot': 0.1445312798023224, 'pitches': 0.1445312798023224, 'rec': 0.1445312798023224, 'kld': 102150.8671875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9464519023895264}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 379/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:11.34
Losses:
{'tot': 0.2019306868314743, 'pitches': 0.2019306868314743, 'rec': 0.2019306868314743, 'kld': 104155.53125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9313154816627502}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 380/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:11.71
Losses:
{'tot': 0.15902890264987946, 'pitches': 0.15902890264987946, 'rec': 0.15902890264987946, 'kld': 108894.796875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.94344264268875

tensor(False, device='cuda:0')
Training on batch 402/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:20.01
Losses:
{'tot': 0.12579357624053955, 'pitches': 0.12579357624053955, 'rec': 0.12579357624053955, 'kld': 104288.125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9540650248527527}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 403/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:20.36
Losses:
{'tot': 0.20271269977092743, 'pitches': 0.20271269977092743, 'rec': 0.20271269977092743, 'kld': 99960.734375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9278752207756042}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 404/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:20.74
Losses:
{'tot': 0.13546450436115265, 'pitches': 0.13546450436115265, 'rec': 0.13546450436115265, 'kld': 89545.84375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95021647214889

tensor(False, device='cuda:0')
Training on batch 426/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:29.53
Losses:
{'tot': 0.11418938636779785, 'pitches': 0.11418938636779785, 'rec': 0.11418938636779785, 'kld': 106081.09375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9560787081718445}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 427/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:29.89
Losses:
{'tot': 0.13859528303146362, 'pitches': 0.13859528303146362, 'rec': 0.13859528303146362, 'kld': 102708.421875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.951545238494873}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 428/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:30.25
Losses:
{'tot': 0.18185444176197052, 'pitches': 0.18185444176197052, 'rec': 0.18185444176197052, 'kld': 108305.125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9361614584922

tensor(False, device='cuda:0')
Training on batch 450/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:38.26
Losses:
{'tot': 0.12838472425937653, 'pitches': 0.12838472425937653, 'rec': 0.12838472425937653, 'kld': 101144.5546875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9513123035430908}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 451/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:38.63
Losses:
{'tot': 0.1183507889509201, 'pitches': 0.1183507889509201, 'rec': 0.1183507889509201, 'kld': 104403.3125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9521738886833191}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 452/471 of epoch 26/100 complete.
Elapsed time from start (h:m:s): 01:15:38.98
Losses:
{'tot': 0.15213043987751007, 'pitches': 0.15213043987751007, 'rec': 0.15213043987751007, 'kld': 107069.046875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.945525288581

  0%|          | 0/471 [00:00<?, ?it/s]

tensor(False, device='cuda:0')
Training on batch 1/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:15:46.28
Losses:
{'tot': 0.10309943556785583, 'pitches': 0.10309943556785583, 'rec': 0.10309943556785583, 'kld': 117974.78125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9659834504127502}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 2/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:15:46.67
Losses:
{'tot': 0.120806023478508, 'pitches': 0.120806023478508, 'rec': 0.120806023478508, 'kld': 110413.390625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.955244779586792}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 3/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:15:47.05
Losses:
{'tot': 0.12879088521003723, 'pitches': 0.12879088521003723, 'rec': 0.12879088521003723, 'kld': 113681.5390625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9591218829154968}

--

tensor(False, device='cuda:0')
Training on batch 25/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:15:55.29
Losses:
{'tot': 0.15163028240203857, 'pitches': 0.15163028240203857, 'rec': 0.15163028240203857, 'kld': 105819.375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9399664402008057}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 26/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:15:55.63
Losses:
{'tot': 0.13529907166957855, 'pitches': 0.13529907166957855, 'rec': 0.13529907166957855, 'kld': 97166.9453125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9509127736091614}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 27/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:15:55.96
Losses:
{'tot': 0.11465170979499817, 'pitches': 0.11465170979499817, 'rec': 0.11465170979499817, 'kld': 101106.3046875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9561631679534

tensor(False, device='cuda:0')
Training on batch 49/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:03.77
Losses:
{'tot': 0.10427696257829666, 'pitches': 0.10427696257829666, 'rec': 0.10427696257829666, 'kld': 102387.484375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9655172228813171}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 50/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:04.10
Losses:
{'tot': 0.12100806832313538, 'pitches': 0.12100806832313538, 'rec': 0.12100806832313538, 'kld': 104080.453125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9582801461219788}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 51/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:04.44
Losses:
{'tot': 0.1065930426120758, 'pitches': 0.1065930426120758, 'rec': 0.1065930426120758, 'kld': 111461.296875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.96223384141922

tensor(False, device='cuda:0')
Training on batch 73/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:13.06
Losses:
{'tot': 0.08845777809619904, 'pitches': 0.08845777809619904, 'rec': 0.08845777809619904, 'kld': 106354.6953125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9645732641220093}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 74/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:13.42
Losses:
{'tot': 0.1645064502954483, 'pitches': 0.1645064502954483, 'rec': 0.1645064502954483, 'kld': 101741.9375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9421245455741882}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 75/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:13.77
Losses:
{'tot': 0.08008336275815964, 'pitches': 0.08008336275815964, 'rec': 0.08008336275815964, 'kld': 100115.78125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9681836366653442

tensor(False, device='cuda:0')
Training on batch 97/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:21.55
Losses:
{'tot': 0.12510956823825836, 'pitches': 0.12510956823825836, 'rec': 0.12510956823825836, 'kld': 104099.9375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.953987717628479}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 98/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:21.92
Losses:
{'tot': 0.11925598233938217, 'pitches': 0.11925598233938217, 'rec': 0.11925598233938217, 'kld': 98617.703125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9558441638946533}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 99/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:22.29
Losses:
{'tot': 0.11277778446674347, 'pitches': 0.11277778446674347, 'rec': 0.11277778446674347, 'kld': 102007.953125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.958662152290344

tensor(False, device='cuda:0')
Training on batch 121/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:30.21
Losses:
{'tot': 0.13031332194805145, 'pitches': 0.13031332194805145, 'rec': 0.13031332194805145, 'kld': 96890.59375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9535603523254395}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 122/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:30.58
Losses:
{'tot': 0.14550481736660004, 'pitches': 0.14550481736660004, 'rec': 0.14550481736660004, 'kld': 104871.125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9451494216918945}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 123/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:30.94
Losses:
{'tot': 0.17454469203948975, 'pitches': 0.17454469203948975, 'rec': 0.17454469203948975, 'kld': 109304.4375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.934662222862243

tensor(False, device='cuda:0')
Training on batch 145/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:38.81
Losses:
{'tot': 0.14310875535011292, 'pitches': 0.14310875535011292, 'rec': 0.14310875535011292, 'kld': 101194.015625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9494423866271973}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 146/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:39.15
Losses:
{'tot': 0.08942829817533493, 'pitches': 0.08942829817533493, 'rec': 0.08942829817533493, 'kld': 105066.359375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9653105139732361}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 147/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:39.53
Losses:
{'tot': 0.14641401171684265, 'pitches': 0.14641401171684265, 'rec': 0.14641401171684265, 'kld': 107145.5859375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9456349

tensor(False, device='cuda:0')
Training on batch 169/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:48.66
Losses:
{'tot': 0.1258522868156433, 'pitches': 0.1258522868156433, 'rec': 0.1258522868156433, 'kld': 106852.921875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9515151381492615}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 170/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:49.03
Losses:
{'tot': 0.08639182150363922, 'pitches': 0.08639182150363922, 'rec': 0.08639182150363922, 'kld': 106647.671875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9694201946258545}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 171/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:49.40
Losses:
{'tot': 0.13615119457244873, 'pitches': 0.13615119457244873, 'rec': 0.13615119457244873, 'kld': 102176.1171875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9509500861

tensor(False, device='cuda:0')
Training on batch 193/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:57.81
Losses:
{'tot': 0.13051234185695648, 'pitches': 0.13051234185695648, 'rec': 0.13051234185695648, 'kld': 103574.828125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9503129720687866}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 194/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:58.16
Losses:
{'tot': 0.08932150900363922, 'pitches': 0.08932150900363922, 'rec': 0.08932150900363922, 'kld': 113816.640625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9672915935516357}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 195/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:16:58.52
Losses:
{'tot': 0.10101178288459778, 'pitches': 0.10101178288459778, 'rec': 0.10101178288459778, 'kld': 111631.28125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.962818861

tensor(False, device='cuda:0')
Training on batch 217/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:06.73
Losses:
{'tot': 0.13928857445716858, 'pitches': 0.13928857445716858, 'rec': 0.13928857445716858, 'kld': 108228.765625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9531975984573364}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 218/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:07.09
Losses:
{'tot': 0.12565521895885468, 'pitches': 0.12565521895885468, 'rec': 0.12565521895885468, 'kld': 104592.9453125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9533103704452515}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 219/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:07.47
Losses:
{'tot': 0.13053636252880096, 'pitches': 0.13053636252880096, 'rec': 0.13053636252880096, 'kld': 102945.8359375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.950819

tensor(False, device='cuda:0')
Training on batch 241/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:15.55
Losses:
{'tot': 0.12426484376192093, 'pitches': 0.12426484376192093, 'rec': 0.12426484376192093, 'kld': 98338.0078125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9566679000854492}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 242/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:15.92
Losses:
{'tot': 0.10799355059862137, 'pitches': 0.10799355059862137, 'rec': 0.10799355059862137, 'kld': 103880.96875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9626726508140564}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 243/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:16.25
Losses:
{'tot': 0.14432144165039062, 'pitches': 0.14432144165039062, 'rec': 0.14432144165039062, 'kld': 111573.5703125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95114940

tensor(False, device='cuda:0')
Training on batch 265/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:25.27
Losses:
{'tot': 0.14258049428462982, 'pitches': 0.14258049428462982, 'rec': 0.14258049428462982, 'kld': 99783.46875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9481511116027832}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 266/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:25.61
Losses:
{'tot': 0.12499143183231354, 'pitches': 0.12499143183231354, 'rec': 0.12499143183231354, 'kld': 99594.78125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9547051787376404}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 267/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:25.94
Losses:
{'tot': 0.14629626274108887, 'pitches': 0.14629626274108887, 'rec': 0.14629626274108887, 'kld': 112265.359375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.948101282119

tensor(False, device='cuda:0')
Training on batch 289/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:33.75
Losses:
{'tot': 0.15275971591472626, 'pitches': 0.15275971591472626, 'rec': 0.15275971591472626, 'kld': 102749.921875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9455645084381104}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 290/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:34.13
Losses:
{'tot': 0.19224728643894196, 'pitches': 0.19224728643894196, 'rec': 0.19224728643894196, 'kld': 99567.265625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9309453368186951}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 291/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:34.50
Losses:
{'tot': 0.14660227298736572, 'pitches': 0.14660227298736572, 'rec': 0.14660227298736572, 'kld': 107625.8125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95023196935

tensor(False, device='cuda:0')
Training on batch 313/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:42.27
Losses:
{'tot': 0.13698235154151917, 'pitches': 0.13698235154151917, 'rec': 0.13698235154151917, 'kld': 103043.578125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.948728621006012}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 314/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:42.66
Losses:
{'tot': 0.1289287656545639, 'pitches': 0.1289287656545639, 'rec': 0.1289287656545639, 'kld': 99082.109375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9535603523254395}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 315/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:43.06
Losses:
{'tot': 0.16020125150680542, 'pitches': 0.16020125150680542, 'rec': 0.16020125150680542, 'kld': 102309.359375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9405940771102

tensor(False, device='cuda:0')
Training on batch 337/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:51.21
Losses:
{'tot': 0.14746391773223877, 'pitches': 0.14746391773223877, 'rec': 0.14746391773223877, 'kld': 102749.0390625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9516128897666931}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 338/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:51.59
Losses:
{'tot': 0.15935617685317993, 'pitches': 0.15935617685317993, 'rec': 0.15935617685317993, 'kld': 102184.5546875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9416356682777405}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 339/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:17:51.95
Losses:
{'tot': 0.1378423124551773, 'pitches': 0.1378423124551773, 'rec': 0.1378423124551773, 'kld': 102032.3203125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95032840

tensor(False, device='cuda:0')
Training on batch 361/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:00.95
Losses:
{'tot': 0.15565823018550873, 'pitches': 0.15565823018550873, 'rec': 0.15565823018550873, 'kld': 93381.015625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9438652992248535}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 362/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:01.31
Losses:
{'tot': 0.15546800196170807, 'pitches': 0.15546800196170807, 'rec': 0.15546800196170807, 'kld': 99489.109375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9420458674430847}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 363/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:01.68
Losses:
{'tot': 0.15759840607643127, 'pitches': 0.15759840607643127, 'rec': 0.15759840607643127, 'kld': 105473.0703125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.945852518

tensor(False, device='cuda:0')
Training on batch 385/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:09.49
Losses:
{'tot': 0.152801975607872, 'pitches': 0.152801975607872, 'rec': 0.152801975607872, 'kld': 103111.6875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9465478658676147}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 386/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:09.83
Losses:
{'tot': 0.09925774484872818, 'pitches': 0.09925774484872818, 'rec': 0.09925774484872818, 'kld': 107687.3046875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9609586596488953}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 387/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:10.21
Losses:
{'tot': 0.12139809876680374, 'pitches': 0.12139809876680374, 'rec': 0.12139809876680374, 'kld': 111411.984375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.955172419548034

tensor(False, device='cuda:0')
Training on batch 409/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:18.57
Losses:
{'tot': 0.1019904837012291, 'pitches': 0.1019904837012291, 'rec': 0.1019904837012291, 'kld': 107104.9765625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9630076289176941}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 410/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:18.92
Losses:
{'tot': 0.09092268347740173, 'pitches': 0.09092268347740173, 'rec': 0.09092268347740173, 'kld': 100741.921875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9683042764663696}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 411/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:19.29
Losses:
{'tot': 0.12141311168670654, 'pitches': 0.12141311168670654, 'rec': 0.12141311168670654, 'kld': 98014.0, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9568291902542114

tensor(False, device='cuda:0')
Training on batch 433/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:27.63
Losses:
{'tot': 0.16624921560287476, 'pitches': 0.16624921560287476, 'rec': 0.16624921560287476, 'kld': 105129.046875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9395084977149963}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 434/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:27.97
Losses:
{'tot': 0.12809021770954132, 'pitches': 0.12809021770954132, 'rec': 0.12809021770954132, 'kld': 104181.9921875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.951869547367096}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 435/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:28.33
Losses:
{'tot': 0.14309245347976685, 'pitches': 0.14309245347976685, 'rec': 0.14309245347976685, 'kld': 105575.890625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.94749033

tensor(False, device='cuda:0')
Training on batch 457/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:37.27
Losses:
{'tot': 0.13813574612140656, 'pitches': 0.13813574612140656, 'rec': 0.13813574612140656, 'kld': 111740.015625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9500599503517151}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 458/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:37.65
Losses:
{'tot': 0.09317752718925476, 'pitches': 0.09317752718925476, 'rec': 0.09317752718925476, 'kld': 106120.046875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9632701277732849}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 459/471 of epoch 27/100 complete.
Elapsed time from start (h:m:s): 01:18:38.03
Losses:
{'tot': 0.13287478685379028, 'pitches': 0.13287478685379028, 'rec': 0.13287478685379028, 'kld': 106973.96875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.946413516

  0%|          | 0/471 [00:00<?, ?it/s]

tensor(False, device='cuda:0')
Training on batch 1/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:18:42.99
Losses:
{'tot': 0.10745228081941605, 'pitches': 0.10745228081941605, 'rec': 0.10745228081941605, 'kld': 97738.484375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9588091373443604}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 2/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:18:43.35
Losses:
{'tot': 0.13909801840782166, 'pitches': 0.13909801840782166, 'rec': 0.13909801840782166, 'kld': 102973.1015625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9490616917610168}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 3/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:18:43.70
Losses:
{'tot': 0.1285875141620636, 'pitches': 0.1285875141620636, 'rec': 0.1285875141620636, 'kld': 104933.078125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9557126760482788}

tensor(False, device='cuda:0')
Training on batch 25/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:18:51.63
Losses:
{'tot': 0.12157328426837921, 'pitches': 0.12157328426837921, 'rec': 0.12157328426837921, 'kld': 105983.875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9544715285301208}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 26/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:18:51.97
Losses:
{'tot': 0.12628699839115143, 'pitches': 0.12628699839115143, 'rec': 0.12628699839115143, 'kld': 106171.59375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9526530504226685}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 27/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:18:52.34
Losses:
{'tot': 0.11587315052747726, 'pitches': 0.11587315052747726, 'rec': 0.11587315052747726, 'kld': 103905.0625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9580885171890259}

tensor(False, device='cuda:0')
Training on batch 49/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:00.80
Losses:
{'tot': 0.1522531807422638, 'pitches': 0.1522531807422638, 'rec': 0.1522531807422638, 'kld': 104050.390625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9485062956809998}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 50/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:01.17
Losses:
{'tot': 0.1315256953239441, 'pitches': 0.1315256953239441, 'rec': 0.1315256953239441, 'kld': 102611.6015625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9531126618385315}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 51/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:01.55
Losses:
{'tot': 0.14371822774410248, 'pitches': 0.14371822774410248, 'rec': 0.14371822774410248, 'kld': 105140.890625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9457530379295349

tensor(False, device='cuda:0')
Training on batch 73/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:09.28
Losses:
{'tot': 0.12720967829227448, 'pitches': 0.12720967829227448, 'rec': 0.12720967829227448, 'kld': 108952.3359375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9506781697273254}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 74/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:09.64
Losses:
{'tot': 0.12443269789218903, 'pitches': 0.12443269789218903, 'rec': 0.12443269789218903, 'kld': 103820.90625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9522878527641296}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 75/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:10.03
Losses:
{'tot': 0.18088845908641815, 'pitches': 0.18088845908641815, 'rec': 0.18088845908641815, 'kld': 95631.5703125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.93497061729

tensor(False, device='cuda:0')
Training on batch 97/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:19.41
Losses:
{'tot': 0.11171197891235352, 'pitches': 0.11171197891235352, 'rec': 0.11171197891235352, 'kld': 104845.359375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9595257639884949}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 98/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:19.80
Losses:
{'tot': 0.11239931732416153, 'pitches': 0.11239931732416153, 'rec': 0.11239931732416153, 'kld': 107268.0, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.959454357624054}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 99/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:20.18
Losses:
{'tot': 0.12411954253911972, 'pitches': 0.12411954253911972, 'rec': 0.12411954253911972, 'kld': 108295.4375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9517804384231567}



tensor(False, device='cuda:0')
Training on batch 121/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:27.97
Losses:
{'tot': 0.13457924127578735, 'pitches': 0.13457924127578735, 'rec': 0.13457924127578735, 'kld': 104086.1953125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9499121308326721}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 122/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:28.42
Losses:
{'tot': 0.1429058313369751, 'pitches': 0.1429058313369751, 'rec': 0.1429058313369751, 'kld': 97789.1484375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9505717158317566}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 123/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:28.81
Losses:
{'tot': 0.12078162282705307, 'pitches': 0.12078162282705307, 'rec': 0.12078162282705307, 'kld': 100809.21875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95816409587

tensor(False, device='cuda:0')
Training on batch 145/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:36.83
Losses:
{'tot': 0.1016961932182312, 'pitches': 0.1016961932182312, 'rec': 0.1016961932182312, 'kld': 103474.015625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9635480046272278}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 146/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:37.17
Losses:
{'tot': 0.1253073811531067, 'pitches': 0.1253073811531067, 'rec': 0.1253073811531067, 'kld': 108857.109375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9547511339187622}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 147/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:37.52
Losses:
{'tot': 0.11726132035255432, 'pitches': 0.11726132035255432, 'rec': 0.11726132035255432, 'kld': 106995.03125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.956173598766326

tensor(False, device='cuda:0')
Training on batch 169/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:45.64
Losses:
{'tot': 0.09814304858446121, 'pitches': 0.09814304858446121, 'rec': 0.09814304858446121, 'kld': 102536.6875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9624733328819275}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 170/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:46.01
Losses:
{'tot': 0.10447771847248077, 'pitches': 0.10447771847248077, 'rec': 0.10447771847248077, 'kld': 103036.4296875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9603841304779053}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 171/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:46.37
Losses:
{'tot': 0.13528816401958466, 'pitches': 0.13528816401958466, 'rec': 0.13528816401958466, 'kld': 95375.53125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.94974279403

tensor(False, device='cuda:0')
Training on batch 193/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:55.23
Losses:
{'tot': 0.15120331943035126, 'pitches': 0.15120331943035126, 'rec': 0.15120331943035126, 'kld': 105439.8203125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.948457658290863}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 194/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:55.59
Losses:
{'tot': 0.12314796447753906, 'pitches': 0.12314796447753906, 'rec': 0.12314796447753906, 'kld': 112644.640625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9594945907592773}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 195/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:19:55.94
Losses:
{'tot': 0.11740802973508835, 'pitches': 0.11740802973508835, 'rec': 0.11740802973508835, 'kld': 101164.640625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95702004

tensor(False, device='cuda:0')
Training on batch 217/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:03.73
Losses:
{'tot': 0.08828025311231613, 'pitches': 0.08828025311231613, 'rec': 0.08828025311231613, 'kld': 114150.515625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9654855132102966}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 218/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:04.11
Losses:
{'tot': 0.09274878352880478, 'pitches': 0.09274878352880478, 'rec': 0.09274878352880478, 'kld': 113665.65625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9668228030204773}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 219/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:04.48
Losses:
{'tot': 0.15964211523532867, 'pitches': 0.15964211523532867, 'rec': 0.15964211523532867, 'kld': 103473.46875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9411976933

tensor(False, device='cuda:0')
Training on batch 241/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:12.39
Losses:
{'tot': 0.11170990765094757, 'pitches': 0.11170990765094757, 'rec': 0.11170990765094757, 'kld': 108270.015625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9630467295646667}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 242/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:12.76
Losses:
{'tot': 0.10121878236532211, 'pitches': 0.10121878236532211, 'rec': 0.10121878236532211, 'kld': 106294.9140625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.960770308971405}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 243/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:13.13
Losses:
{'tot': 0.10627414286136627, 'pitches': 0.10627414286136627, 'rec': 0.10627414286136627, 'kld': 110602.796875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.96037662

tensor(False, device='cuda:0')
Training on batch 265/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:20.81
Losses:
{'tot': 0.12332470715045929, 'pitches': 0.12332470715045929, 'rec': 0.12332470715045929, 'kld': 103214.96875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9538003206253052}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 266/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:21.14
Losses:
{'tot': 0.11438950896263123, 'pitches': 0.11438950896263123, 'rec': 0.11438950896263123, 'kld': 107654.953125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9577092528343201}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 267/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:21.47
Losses:
{'tot': 0.11513593792915344, 'pitches': 0.11513593792915344, 'rec': 0.11513593792915344, 'kld': 114160.7109375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.96011006

tensor(False, device='cuda:0')
Training on batch 289/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:30.23
Losses:
{'tot': 0.12905670702457428, 'pitches': 0.12905670702457428, 'rec': 0.12905670702457428, 'kld': 110705.421875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9539228081703186}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 290/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:30.62
Losses:
{'tot': 0.08458080887794495, 'pitches': 0.08458080887794495, 'rec': 0.08458080887794495, 'kld': 108460.421875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9692832827568054}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 291/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:30.99
Losses:
{'tot': 0.17405886948108673, 'pitches': 0.17405886948108673, 'rec': 0.17405886948108673, 'kld': 107618.078125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.94043248

tensor(False, device='cuda:0')
Training on batch 313/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:39.04
Losses:
{'tot': 0.07207910716533661, 'pitches': 0.07207910716533661, 'rec': 0.07207910716533661, 'kld': 104839.046875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9749494791030884}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 314/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:39.42
Losses:
{'tot': 0.1530456840991974, 'pitches': 0.1530456840991974, 'rec': 0.1530456840991974, 'kld': 104341.6875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9444659948348999}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 315/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:39.78
Losses:
{'tot': 0.1136527881026268, 'pitches': 0.1136527881026268, 'rec': 0.1136527881026268, 'kld': 107165.9375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9564348459243774}


tensor(False, device='cuda:0')
Training on batch 337/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:47.77
Losses:
{'tot': 0.096275195479393, 'pitches': 0.096275195479393, 'rec': 0.096275195479393, 'kld': 103848.859375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9649820923805237}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 338/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:48.18
Losses:
{'tot': 0.10956251621246338, 'pitches': 0.10956251621246338, 'rec': 0.10956251621246338, 'kld': 104288.28125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9607762098312378}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 339/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:48.54
Losses:
{'tot': 0.14608293771743774, 'pitches': 0.14608293771743774, 'rec': 0.14608293771743774, 'kld': 101868.953125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.947471559047699

tensor(False, device='cuda:0')
Training on batch 361/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:56.51
Losses:
{'tot': 0.13802918791770935, 'pitches': 0.13802918791770935, 'rec': 0.13802918791770935, 'kld': 108223.828125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9507094621658325}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 362/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:56.86
Losses:
{'tot': 0.10387273132801056, 'pitches': 0.10387273132801056, 'rec': 0.10387273132801056, 'kld': 108147.9375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9605367183685303}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 363/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:20:57.22
Losses:
{'tot': 0.11719230562448502, 'pitches': 0.11719230562448502, 'rec': 0.11719230562448502, 'kld': 105618.0078125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.958134412

tensor(False, device='cuda:0')
Training on batch 385/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:06.00
Losses:
{'tot': 0.12151878327131271, 'pitches': 0.12151878327131271, 'rec': 0.12151878327131271, 'kld': 106869.171875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9551842212677002}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 386/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:06.36
Losses:
{'tot': 0.09951169788837433, 'pitches': 0.09951169788837433, 'rec': 0.09951169788837433, 'kld': 96978.71875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9654510617256165}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 387/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:06.77
Losses:
{'tot': 0.11074624955654144, 'pitches': 0.11074624955654144, 'rec': 0.11074624955654144, 'kld': 103508.3828125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.958750486

tensor(False, device='cuda:0')
Training on batch 409/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:14.94
Losses:
{'tot': 0.12133733928203583, 'pitches': 0.12133733928203583, 'rec': 0.12133733928203583, 'kld': 102802.734375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9561623930931091}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 410/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:15.27
Losses:
{'tot': 0.14898623526096344, 'pitches': 0.14898623526096344, 'rec': 0.14898623526096344, 'kld': 108218.328125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9468302726745605}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 411/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:15.63
Losses:
{'tot': 0.1259814351797104, 'pitches': 0.1259814351797104, 'rec': 0.1259814351797104, 'kld': 106790.953125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95604854822

tensor(False, device='cuda:0')
Training on batch 433/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:23.55
Losses:
{'tot': 0.13111238181591034, 'pitches': 0.13111238181591034, 'rec': 0.13111238181591034, 'kld': 103336.484375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.952531635761261}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 434/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:24.04
Losses:
{'tot': 0.15941889584064484, 'pitches': 0.15941889584064484, 'rec': 0.15941889584064484, 'kld': 96715.6640625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9391649961471558}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 435/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:24.43
Losses:
{'tot': 0.12415948510169983, 'pitches': 0.12415948510169983, 'rec': 0.12415948510169983, 'kld': 102622.3984375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95830142

tensor(False, device='cuda:0')
Training on batch 457/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:32.46
Losses:
{'tot': 0.1397886872291565, 'pitches': 0.1397886872291565, 'rec': 0.1397886872291565, 'kld': 96171.984375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9480569362640381}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 458/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:32.82
Losses:
{'tot': 0.11706309020519257, 'pitches': 0.11706309020519257, 'rec': 0.11706309020519257, 'kld': 105641.71875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9587137699127197}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 459/471 of epoch 28/100 complete.
Elapsed time from start (h:m:s): 01:21:33.18
Losses:
{'tot': 0.16156402230262756, 'pitches': 0.16156402230262756, 'rec': 0.16156402230262756, 'kld': 97679.5625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9432653188705444

  0%|          | 0/471 [00:00<?, ?it/s]

tensor(False, device='cuda:0')
Training on batch 1/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:21:37.45
Losses:
{'tot': 0.12337972223758698, 'pitches': 0.12337972223758698, 'rec': 0.12337972223758698, 'kld': 102107.3515625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9591836929321289}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 2/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:21:37.80
Losses:
{'tot': 0.10111132264137268, 'pitches': 0.10111132264137268, 'rec': 0.10111132264137268, 'kld': 103607.65625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9654757380485535}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 3/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:21:38.15
Losses:
{'tot': 0.11402030289173126, 'pitches': 0.11402030289173126, 'rec': 0.11402030289173126, 'kld': 104809.96875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.956701040267944

tensor(False, device='cuda:0')
Training on batch 25/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:21:47.24
Losses:
{'tot': 0.1345367431640625, 'pitches': 0.1345367431640625, 'rec': 0.1345367431640625, 'kld': 109731.390625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9461439251899719}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 26/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:21:47.62
Losses:
{'tot': 0.12485505640506744, 'pitches': 0.12485505640506744, 'rec': 0.12485505640506744, 'kld': 100076.609375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9582718014717102}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 27/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:21:47.97
Losses:
{'tot': 0.12228457629680634, 'pitches': 0.12228457629680634, 'rec': 0.12228457629680634, 'kld': 100113.75, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9569256901741028}


tensor(False, device='cuda:0')
Training on batch 49/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:21:56.07
Losses:
{'tot': 0.07968085259199142, 'pitches': 0.07968085259199142, 'rec': 0.07968085259199142, 'kld': 110997.4140625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9743177890777588}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 50/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:21:56.45
Losses:
{'tot': 0.0734378919005394, 'pitches': 0.0734378919005394, 'rec': 0.0734378919005394, 'kld': 102180.8984375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9723352193832397}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 51/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:21:56.83
Losses:
{'tot': 0.0877549797296524, 'pitches': 0.0877549797296524, 'rec': 0.0877549797296524, 'kld': 102307.6953125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.97083979845047

tensor(False, device='cuda:0')
Training on batch 73/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:04.75
Losses:
{'tot': 0.12312790006399155, 'pitches': 0.12312790006399155, 'rec': 0.12312790006399155, 'kld': 100221.40625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9569892287254333}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 74/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:05.09
Losses:
{'tot': 0.11736756563186646, 'pitches': 0.11736756563186646, 'rec': 0.11736756563186646, 'kld': 101788.234375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9549393653869629}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 75/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:05.48
Losses:
{'tot': 0.11246772855520248, 'pitches': 0.11246772855520248, 'rec': 0.11246772855520248, 'kld': 93982.828125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9581497907638

tensor(False, device='cuda:0')
Training on batch 97/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:13.36
Losses:
{'tot': 0.09544522315263748, 'pitches': 0.09544522315263748, 'rec': 0.09544522315263748, 'kld': 101908.21875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9644343256950378}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 98/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:13.71
Losses:
{'tot': 0.10954586416482925, 'pitches': 0.10954586416482925, 'rec': 0.10954586416482925, 'kld': 92547.3203125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9616407155990601}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 99/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:14.03
Losses:
{'tot': 0.14050187170505524, 'pitches': 0.14050187170505524, 'rec': 0.14050187170505524, 'kld': 102683.7421875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.94341945648

tensor(False, device='cuda:0')
Training on batch 121/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:22.94
Losses:
{'tot': 0.11428695917129517, 'pitches': 0.11428695917129517, 'rec': 0.11428695917129517, 'kld': 105202.9296875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9558248519897461}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 122/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:23.28
Losses:
{'tot': 0.11747909337282181, 'pitches': 0.11747909337282181, 'rec': 0.11747909337282181, 'kld': 109172.4375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9583333134651184}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 123/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:23.65
Losses:
{'tot': 0.10788872092962265, 'pitches': 0.10788872092962265, 'rec': 0.10788872092962265, 'kld': 100558.375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.962423324584

tensor(False, device='cuda:0')
Training on batch 145/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:31.48
Losses:
{'tot': 0.1354193538427353, 'pitches': 0.1354193538427353, 'rec': 0.1354193538427353, 'kld': 107726.609375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9490106701850891}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 146/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:31.84
Losses:
{'tot': 0.0844171866774559, 'pitches': 0.0844171866774559, 'rec': 0.0844171866774559, 'kld': 109066.96875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9709510803222656}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 147/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:32.22
Losses:
{'tot': 0.10042843222618103, 'pitches': 0.10042843222618103, 'rec': 0.10042843222618103, 'kld': 109307.734375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.958904087543487

tensor(False, device='cuda:0')
Training on batch 169/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:39.99
Losses:
{'tot': 0.10417384654283524, 'pitches': 0.10417384654283524, 'rec': 0.10417384654283524, 'kld': 112082.546875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9617612957954407}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 170/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:40.34
Losses:
{'tot': 0.11289668083190918, 'pitches': 0.11289668083190918, 'rec': 0.11289668083190918, 'kld': 113281.125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9595789313316345}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 171/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:40.70
Losses:
{'tot': 0.10208510607481003, 'pitches': 0.10208510607481003, 'rec': 0.10208510607481003, 'kld': 112765.5234375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9622783660

tensor(False, device='cuda:0')
Training on batch 193/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:48.41
Losses:
{'tot': 0.10989845544099808, 'pitches': 0.10989845544099808, 'rec': 0.10989845544099808, 'kld': 120673.78125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9610331058502197}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 194/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:48.78
Losses:
{'tot': 0.1002361923456192, 'pitches': 0.1002361923456192, 'rec': 0.1002361923456192, 'kld': 104677.3125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9652103781700134}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 195/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:49.12
Losses:
{'tot': 0.1069435104727745, 'pitches': 0.1069435104727745, 'rec': 0.1069435104727745, 'kld': 108757.2578125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9633304476737976

tensor(False, device='cuda:0')
Training on batch 217/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:57.76
Losses:
{'tot': 0.10625937581062317, 'pitches': 0.10625937581062317, 'rec': 0.10625937581062317, 'kld': 100543.5234375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9639856219291687}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 218/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:58.13
Losses:
{'tot': 0.12674982845783234, 'pitches': 0.12674982845783234, 'rec': 0.12674982845783234, 'kld': 107472.7421875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9519084095954895}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 219/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:22:58.49
Losses:
{'tot': 0.11938486993312836, 'pitches': 0.11938486993312836, 'rec': 0.11938486993312836, 'kld': 108443.3125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95640635

tensor(False, device='cuda:0')
Training on batch 241/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:06.59
Losses:
{'tot': 0.11831460148096085, 'pitches': 0.11831460148096085, 'rec': 0.11831460148096085, 'kld': 105732.4375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9532710313796997}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 242/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:06.96
Losses:
{'tot': 0.12119381129741669, 'pitches': 0.12119381129741669, 'rec': 0.12119381129741669, 'kld': 104199.0, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9586840271949768}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 243/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:07.31
Losses:
{'tot': 0.11729804426431656, 'pitches': 0.11729804426431656, 'rec': 0.11729804426431656, 'kld': 112998.671875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.953680753707885

tensor(False, device='cuda:0')
Training on batch 265/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:14.94
Losses:
{'tot': 0.08809537440538406, 'pitches': 0.08809537440538406, 'rec': 0.08809537440538406, 'kld': 100387.921875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.967007040977478}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 266/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:15.29
Losses:
{'tot': 0.11977213621139526, 'pitches': 0.11977213621139526, 'rec': 0.11977213621139526, 'kld': 105843.359375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9574980139732361}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 267/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:15.66
Losses:
{'tot': 0.13807936012744904, 'pitches': 0.13807936012744904, 'rec': 0.13807936012744904, 'kld': 104710.78125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9527272582

tensor(False, device='cuda:0')
Training on batch 289/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:23.48
Losses:
{'tot': 0.12158874422311783, 'pitches': 0.12158874422311783, 'rec': 0.12158874422311783, 'kld': 100199.0078125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9565567374229431}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 290/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:23.86
Losses:
{'tot': 0.09704498946666718, 'pitches': 0.09704498946666718, 'rec': 0.09704498946666718, 'kld': 109799.765625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9659348726272583}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 291/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:24.23
Losses:
{'tot': 0.12615735828876495, 'pitches': 0.12615735828876495, 'rec': 0.12615735828876495, 'kld': 102345.4765625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.951978

tensor(False, device='cuda:0')
Training on batch 313/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:32.87
Losses:
{'tot': 0.0931788980960846, 'pitches': 0.0931788980960846, 'rec': 0.0931788980960846, 'kld': 113806.328125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9642390012741089}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 314/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:33.21
Losses:
{'tot': 0.13847476243972778, 'pitches': 0.13847476243972778, 'rec': 0.13847476243972778, 'kld': 113524.625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9483587145805359}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 315/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:33.56
Losses:
{'tot': 0.11405501514673233, 'pitches': 0.11405501514673233, 'rec': 0.11405501514673233, 'kld': 111505.671875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.96048992872238

tensor(False, device='cuda:0')
Training on batch 337/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:41.28
Losses:
{'tot': 0.10036159306764603, 'pitches': 0.10036159306764603, 'rec': 0.10036159306764603, 'kld': 109080.203125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9626205563545227}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 338/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:41.61
Losses:
{'tot': 0.14597661793231964, 'pitches': 0.14597661793231964, 'rec': 0.14597661793231964, 'kld': 101448.625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9484493136405945}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 339/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:41.97
Losses:
{'tot': 0.11793441325426102, 'pitches': 0.11793441325426102, 'rec': 0.11793441325426102, 'kld': 105879.8046875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9587935209

tensor(False, device='cuda:0')
Training on batch 361/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:49.80
Losses:
{'tot': 0.0877404734492302, 'pitches': 0.0877404734492302, 'rec': 0.0877404734492302, 'kld': 117352.1875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9608637094497681}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 362/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:50.19
Losses:
{'tot': 0.11791620403528214, 'pitches': 0.11791620403528214, 'rec': 0.11791620403528214, 'kld': 99870.9453125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9592822790145874}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 363/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:50.56
Losses:
{'tot': 0.11917903274297714, 'pitches': 0.11917903274297714, 'rec': 0.11917903274297714, 'kld': 107418.125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9540570378303528

tensor(False, device='cuda:0')
Training on batch 385/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:58.31
Losses:
{'tot': 0.1253775656223297, 'pitches': 0.1253775656223297, 'rec': 0.1253775656223297, 'kld': 110421.96875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9573770761489868}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 386/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:58.67
Losses:
{'tot': 0.12932689487934113, 'pitches': 0.12932689487934113, 'rec': 0.12932689487934113, 'kld': 107747.328125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9544781446456909}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 387/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:23:59.15
Losses:
{'tot': 0.12417256087064743, 'pitches': 0.12417256087064743, 'rec': 0.12417256087064743, 'kld': 111034.1953125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95608109235

tensor(False, device='cuda:0')
Training on batch 409/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:24:06.92
Losses:
{'tot': 0.11757104098796844, 'pitches': 0.11757104098796844, 'rec': 0.11757104098796844, 'kld': 101356.5234375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9558823704719543}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 410/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:24:07.29
Losses:
{'tot': 0.09355002641677856, 'pitches': 0.09355002641677856, 'rec': 0.09355002641677856, 'kld': 103768.9375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9633727669715881}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 411/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:24:07.68
Losses:
{'tot': 0.13267643749713898, 'pitches': 0.13267643749713898, 'rec': 0.13267643749713898, 'kld': 103423.296875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.956351995

tensor(False, device='cuda:0')
Training on batch 433/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:24:16.29
Losses:
{'tot': 0.133271262049675, 'pitches': 0.133271262049675, 'rec': 0.133271262049675, 'kld': 103686.6328125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9511567950248718}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 434/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:24:16.68
Losses:
{'tot': 0.10814874619245529, 'pitches': 0.10814874619245529, 'rec': 0.10814874619245529, 'kld': 103492.703125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9581412076950073}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 435/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:24:17.07
Losses:
{'tot': 0.1511930525302887, 'pitches': 0.1511930525302887, 'rec': 0.1511930525302887, 'kld': 106081.9140625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.942822873592376

tensor(False, device='cuda:0')
Training on batch 457/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:24:25.03
Losses:
{'tot': 0.15816596150398254, 'pitches': 0.15816596150398254, 'rec': 0.15816596150398254, 'kld': 102601.703125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.944803774356842}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 458/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:24:25.41
Losses:
{'tot': 0.16837777197360992, 'pitches': 0.16837777197360992, 'rec': 0.16837777197360992, 'kld': 100194.734375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.943828284740448}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 459/471 of epoch 29/100 complete.
Elapsed time from start (h:m:s): 01:24:25.76
Losses:
{'tot': 0.09122777730226517, 'pitches': 0.09122777730226517, 'rec': 0.09122777730226517, 'kld': 108950.921875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9723926186

  0%|          | 0/471 [00:00<?, ?it/s]

tensor(False, device='cuda:0')
Training on batch 1/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:30.25
Losses:
{'tot': 0.07696886360645294, 'pitches': 0.07696886360645294, 'rec': 0.07696886360645294, 'kld': 98793.703125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9721577763557434}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 2/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:30.64
Losses:
{'tot': 0.10214164108037949, 'pitches': 0.10214164108037949, 'rec': 0.10214164108037949, 'kld': 97919.84375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9586437940597534}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 3/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:31.02
Losses:
{'tot': 0.10653982311487198, 'pitches': 0.10653982311487198, 'rec': 0.10653982311487198, 'kld': 97900.09375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.957932710647583}

-

tensor(False, device='cuda:0')
Training on batch 25/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:39.30
Losses:
{'tot': 0.1024145781993866, 'pitches': 0.1024145781993866, 'rec': 0.1024145781993866, 'kld': 103247.734375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9621380567550659}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 26/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:39.64
Losses:
{'tot': 0.13397862017154694, 'pitches': 0.13397862017154694, 'rec': 0.13397862017154694, 'kld': 104575.921875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9557809233665466}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 27/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:39.99
Losses:
{'tot': 0.10690240561962128, 'pitches': 0.10690240561962128, 'rec': 0.10690240561962128, 'kld': 108567.78125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.961856067180633

tensor(False, device='cuda:0')
Training on batch 49/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:48.87
Losses:
{'tot': 0.10979843884706497, 'pitches': 0.10979843884706497, 'rec': 0.10979843884706497, 'kld': 100440.890625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.962445080280304}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 50/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:49.22
Losses:
{'tot': 0.11073977500200272, 'pitches': 0.11073977500200272, 'rec': 0.11073977500200272, 'kld': 103526.6875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9609914422035217}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 51/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:49.55
Losses:
{'tot': 0.08003200590610504, 'pitches': 0.08003200590610504, 'rec': 0.08003200590610504, 'kld': 105334.453125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.97017467021942

tensor(False, device='cuda:0')
Training on batch 73/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:57.32
Losses:
{'tot': 0.09755296260118484, 'pitches': 0.09755296260118484, 'rec': 0.09755296260118484, 'kld': 101855.125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9647803902626038}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 74/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:57.66
Losses:
{'tot': 0.09282185137271881, 'pitches': 0.09282185137271881, 'rec': 0.09282185137271881, 'kld': 106025.1875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.966486930847168}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 75/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:24:58.02
Losses:
{'tot': 0.08131232857704163, 'pitches': 0.08131232857704163, 'rec': 0.08131232857704163, 'kld': 103919.5390625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9691558480262756

tensor(False, device='cuda:0')
Training on batch 97/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:05.71
Losses:
{'tot': 0.10000070929527283, 'pitches': 0.10000070929527283, 'rec': 0.10000070929527283, 'kld': 88655.84375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9655172228813171}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 98/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:06.05
Losses:
{'tot': 0.0815393328666687, 'pitches': 0.0815393328666687, 'rec': 0.0815393328666687, 'kld': 104753.359375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9697946310043335}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 99/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:06.39
Losses:
{'tot': 0.08341504633426666, 'pitches': 0.08341504633426666, 'rec': 0.08341504633426666, 'kld': 104048.9140625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.964752495288848

tensor(False, device='cuda:0')
Training on batch 121/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:13.92
Losses:
{'tot': 0.0846593901515007, 'pitches': 0.0846593901515007, 'rec': 0.0846593901515007, 'kld': 119049.640625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9715784788131714}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 122/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:14.25
Losses:
{'tot': 0.10862723737955093, 'pitches': 0.10862723737955093, 'rec': 0.10862723737955093, 'kld': 103502.8828125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9610339403152466}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 123/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:14.59
Losses:
{'tot': 0.08537057787179947, 'pitches': 0.08537057787179947, 'rec': 0.08537057787179947, 'kld': 104182.046875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9688870906

tensor(False, device='cuda:0')
Training on batch 145/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:23.42
Losses:
{'tot': 0.09727994352579117, 'pitches': 0.09727994352579117, 'rec': 0.09727994352579117, 'kld': 111418.625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9656533598899841}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 146/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:23.78
Losses:
{'tot': 0.08904919773340225, 'pitches': 0.08904919773340225, 'rec': 0.08904919773340225, 'kld': 108489.890625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9657307863235474}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 147/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:24.13
Losses:
{'tot': 0.08422790467739105, 'pitches': 0.08422790467739105, 'rec': 0.08422790467739105, 'kld': 103589.8671875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9667673707

tensor(False, device='cuda:0')
Training on batch 169/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:31.90
Losses:
{'tot': 0.1229545846581459, 'pitches': 0.1229545846581459, 'rec': 0.1229545846581459, 'kld': 108690.984375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9553030133247375}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 170/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:32.24
Losses:
{'tot': 0.11872880160808563, 'pitches': 0.11872880160808563, 'rec': 0.11872880160808563, 'kld': 96438.0, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9547491073608398}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 171/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:32.58
Losses:
{'tot': 0.12209818512201309, 'pitches': 0.12209818512201309, 'rec': 0.12209818512201309, 'kld': 108347.703125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9569932818412781}

tensor(False, device='cuda:0')
Training on batch 193/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:40.15
Losses:
{'tot': 0.118560791015625, 'pitches': 0.118560791015625, 'rec': 0.118560791015625, 'kld': 116006.34375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95564204454422}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 194/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:40.48
Losses:
{'tot': 0.09049858897924423, 'pitches': 0.09049858897924423, 'rec': 0.09049858897924423, 'kld': 106240.9375, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9682797789573669}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 195/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:40.82
Losses:
{'tot': 0.12423056364059448, 'pitches': 0.12423056364059448, 'rec': 0.12423056364059448, 'kld': 111729.40625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9554166793823242}

-

tensor(False, device='cuda:0')
Training on batch 217/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:48.64
Losses:
{'tot': 0.12395550310611725, 'pitches': 0.12395550310611725, 'rec': 0.12395550310611725, 'kld': 109357.15625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9533699154853821}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 218/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:49.00
Losses:
{'tot': 0.09874580055475235, 'pitches': 0.09874580055475235, 'rec': 0.09874580055475235, 'kld': 98189.390625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9672130942344666}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 219/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:49.35
Losses:
{'tot': 0.12244235724210739, 'pitches': 0.12244235724210739, 'rec': 0.12244235724210739, 'kld': 103768.546875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9544538259

tensor(False, device='cuda:0')
Training on batch 241/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:57.29
Losses:
{'tot': 0.1185510903596878, 'pitches': 0.1185510903596878, 'rec': 0.1185510903596878, 'kld': 107262.265625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9545810222625732}

----------------------------------------


Saving model to disk...

tensor(False, device='cuda:0')
Training on batch 242/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:58.67
Losses:
{'tot': 0.11972299963235855, 'pitches': 0.11972299963235855, 'rec': 0.11972299963235855, 'kld': 105428.21875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.951953113079071}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 243/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:25:59.03
Losses:
{'tot': 0.10758668184280396, 'pitches': 0.10758668184280396, 'rec': 0.10758668184280396, 'kld': 106773.359375, 'beta*kld': 0.0}
Accuracies:
{

tensor(False, device='cuda:0')
Training on batch 265/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:26:07.18
Losses:
{'tot': 0.0947435200214386, 'pitches': 0.0947435200214386, 'rec': 0.0947435200214386, 'kld': 104720.1640625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9659134149551392}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 266/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:26:07.57
Losses:
{'tot': 0.11306414008140564, 'pitches': 0.11306414008140564, 'rec': 0.11306414008140564, 'kld': 98646.890625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9576923251152039}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 267/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:26:07.91
Losses:
{'tot': 0.09467194974422455, 'pitches': 0.09467194974422455, 'rec': 0.09467194974422455, 'kld': 105898.625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.96628350019454

tensor(False, device='cuda:0')
Training on batch 289/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:26:15.51
Losses:
{'tot': 0.17818738520145416, 'pitches': 0.17818738520145416, 'rec': 0.17818738520145416, 'kld': 102292.1953125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9351009130477905}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 290/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:26:15.84
Losses:
{'tot': 0.11381904035806656, 'pitches': 0.11381904035806656, 'rec': 0.11381904035806656, 'kld': 107118.21875, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9606791734695435}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 291/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:26:16.22
Losses:
{'tot': 0.11736701428890228, 'pitches': 0.11736701428890228, 'rec': 0.11736701428890228, 'kld': 105815.890625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.95948368

tensor(False, device='cuda:0')
Training on batch 313/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:26:24.10
Losses:
{'tot': 0.14021004736423492, 'pitches': 0.14021004736423492, 'rec': 0.14021004736423492, 'kld': 106464.0078125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9550561904907227}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 314/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:26:24.46
Losses:
{'tot': 0.10315755009651184, 'pitches': 0.10315755009651184, 'rec': 0.10315755009651184, 'kld': 104242.3203125, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.9637883305549622}

----------------------------------------

tensor(False, device='cuda:0')
Training on batch 315/471 of epoch 30/100 complete.
Elapsed time from start (h:m:s): 01:26:24.80
Losses:
{'tot': 0.12052535265684128, 'pitches': 0.12052535265684128, 'rec': 0.12052535265684128, 'kld': 108161.890625, 'beta*kld': 0.0}
Accuracies:
{'pitches': 0.952964

## Reconstructions

In [None]:
# Restore a saved checkpoint. `map_location` remaps the stored tensors onto
# the device currently in use, so a checkpoint saved on a GPU can also be
# opened on a machine without one.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load('models/just_pitches_warmup', map_location=device)

In [None]:
# Pull the model weights out of the checkpoint dictionary
state_dict = checkpoint['model_state_dict']
# Build a fresh model with its default constructor arguments and move it to
# the target device; the weights are loaded in the next cell
vae = VAE().to(device)

In [None]:
# Copy the checkpointed weights into the freshly built model
vae.load_state_dict(state_dict)

In [None]:
# Unshuffled loader, so batches come out in dataset order
loader = DataLoader(dataset, batch_size=32, shuffle=False)
# Cell output: number of samples in the dataset
len(dataset)

In [None]:
# Reconstruct a single batch: the loop body runs once and then breaks, so only
# the first batch produced by `loader` is used.
#
# This is inference only, so switch the model to eval mode (disables
# train-time behaviour such as dropout) and skip autograd bookkeeping.
# NOTE: call vae.train() again before resuming training with this instance.
vae.eval()

with torch.no_grad():
    for idx, inputs in enumerate(loader):

        x_seq, x_acts, x_graph, src_mask = inputs
        x_seq = x_seq.float().to(device)
        x_acts = x_acts.to(device)
        x_graph = x_graph.to(device)
        src_mask = src_mask.to(device)
        # Mask size is one less than the second-to-last axis of x_seq
        tgt_mask = generate_square_subsequent_mask(x_seq.size(-2)-1).to(device)

        # Forward pass, get the reconstructions
        outputs, mu, log_var = vae(x_seq, x_acts, x_graph, src_mask, tgt_mask)

        break

# Only the first element of the model output is used below
seq_rec, _ = outputs

In [None]:
# Shape of the input batch fed to the model
x_seq.size()

In [None]:
# Shape of the reconstruction returned by the model
seq_rec.size()

In [None]:
# Shape of the activation tensor; it is flattened into a boolean mask below
x_acts.size()

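Create a dense reconstruction from the sparse one: `seq_rec` only holds the entries flagged as active in `x_acts`; every other entry is filled with a "silence" pattern whose positions are all set to the EOS token.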

In [None]:
# Dense target buffer: same shape as the input batch, minus the first position
# along the second-to-last axis (presumably the SOS token -- TODO confirm).
# Allocated directly on x_seq's device instead of building a full-size CPU
# tensor, copying it over and slicing it afterwards.
seq_rec_dense = torch.zeros_like(x_seq[..., 1:, :], dtype=torch.float)
size = seq_rec_dense.size()

# Collapse all leading axes so that each row is one (positions x features) block
seq_rec_dense = seq_rec_dense.view(-1, seq_rec_dense.size(-2), seq_rec_dense.size(-1))

# Block used for inactive entries: every position is a one-hot EOS token
silence = torch.zeros(seq_rec_dense.size(-2), seq_rec_dense.size(-1), device=device)
silence[:, PITCH_EOS] = 1.

# Active entries receive the model output, all the others the silence block
is_active = x_acts.bool().view(-1)
seq_rec_dense[is_active] = seq_rec
seq_rec_dense[torch.logical_not(is_active)] = silence

# Restore the original leading axes
seq_rec_dense = seq_rec_dense.view(size)

In [None]:
# Compare shapes: the dense reconstruction was built from x_seq[..., 1:, :],
# so it is one shorter than x_seq along the second-to-last axis
print(seq_rec_dense.size())
print(x_seq.size())

In [None]:
# First sample of the batch: ground truth and its dense reconstruction
music_real = x_seq[0]
music_rec = seq_rec_dense[0]

In [None]:
# Shape of a single sample (batch axis removed)
music_real.size()

In [None]:
# Output directory; the trailing slash matters because file names are built
# by plain string concatenation below
prefix = "data/music/"
# savefig / write_midi raise if the directory does not exist yet
os.makedirs(prefix, exist_ok=True)

# NOTE(review): from_tensor_to_muspy and track_data are defined in a later
# cell of this notebook; that cell has to be run before this one.
# Convert the ground-truth sample, plot it, and save both a PNG and a MIDI file
real = from_tensor_to_muspy(music_real, track_data)
muspy.show_pianoroll(real, yticklabel='off', grid_axis='off')
plt.savefig(prefix + "real" + ".png")
muspy.write_midi(prefix + "real" + ".mid", real)

In [None]:
# Same export as for the ground truth, applied to the reconstructed sample.
# Relies on `prefix` from the previous cell; NOTE(review): a later cell rebinds
# `prefix` to "data/music/file", which changes the output names if this cell
# is re-run afterwards.
rec = from_tensor_to_muspy(music_rec, track_data)
muspy.show_pianoroll(rec, yticklabel='off', grid_axis='off')
plt.savefig(prefix + "rec" + ".png")
muspy.write_midi(prefix + "rec" + ".mid", rec)

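Plot music and save it to disk. Note: the next cell defines `from_tensor_to_muspy` and `track_data`, which the reconstruction cells above already call, so it has to be executed before them.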

In [None]:
#tracks = [drum_track, bass_track, guitar_track, strings_track]
import copy

def from_tensor_to_muspy(music_tensor, track_data, dur=4):
    """Convert a dense music tensor into a ``muspy.Music`` object.

    Args:
        music_tensor: 4-D tensor indexed as ``[track, timestep, note, feature]``.
            Only the first ``PITCH_PAD + 1`` (131) features of every note are
            read; their argmax is taken as the pitch token.
        track_data: Sequence with one ``(name, program)`` entry per track. A
            track named ``'Drums'`` becomes a drum track and its program is
            not read.
        dur: Duration, in time steps, given to every note. Durations are not
            decoded from the tensor. Defaults to 4.

    Returns:
        A ``muspy.Music`` with one track per entry of the first tensor axis,
        titled ``'prova'`` and using ``RESOLUTION`` time steps per quarter
        note. Each note starts at its timestep index and has velocity 64.
    """
    # Decode every pitch token with a single argmax instead of one call per
    # note, then continue with plain Python ints.
    num_pitch_tokens = PITCH_PAD + 1
    pitch_tokens = music_tensor[..., :num_pitch_tokens].argmax(dim=-1).tolist()

    tracks = []

    for tr, track_tokens in enumerate(pitch_tokens):

        notes = []

        for ts, ts_tokens in enumerate(track_tokens):
            for pitch in ts_tokens:

                # EOS closes the note list of this timestep
                if pitch == PITCH_EOS:
                    break

                # Tokens from PITCH_SOS upwards (SOS, PAD) are markers, not
                # playable pitches, so they are never emitted as notes.
                if pitch < PITCH_SOS:
                    notes.append(muspy.Note(ts, pitch, dur, 64))

        name = track_data[tr][0]
        if name == 'Drums':
            track = muspy.Track(name='Drums', is_drum=True, notes=notes)
        else:
            track = muspy.Track(name=name,
                                program=track_data[tr][1],
                                notes=notes)
        tracks.append(track)

    meta = muspy.Metadata(title='prova')
    music = muspy.Music(tracks=tracks, metadata=meta, resolution=RESOLUTION)

    return music


track_data = [('Drums', -1), ('Bass', 34), ('Guitar', 1), ('Strings', 41)]

In [None]:
prefix = "data/music/file"

# Render and export ten consecutive dataset samples, starting at index 20.
# Output files are numbered 0..9.
for i, sample_idx in enumerate(range(20, 30)):
    music_tensor = dataset[sample_idx][0]
    music = from_tensor_to_muspy(music_tensor, track_data)
    muspy.show_pianoroll(music, yticklabel='off', grid_axis='off')
    out_stem = prefix + str(i)
    plt.savefig(out_stem + ".png")
    muspy.write_midi(out_stem + ".mid", music)

In [None]:
music

In [None]:
music_path = "data/music/file2.mid"
muspy.show_pianoroll(music, yticklabel='off', grid_axis='off')
plt.savefig('file2.png')
muspy.write_midi(music_path, music)

In [None]:
print(dataset[0][0].size())

# Hand-built four-track example: empty drums and guitar, a single bass note
# and two string notes.
notes = [muspy.Note(1, 48, 20, 64)]
string_notes = [
    muspy.Note(0, 100, 4, 64),
    muspy.Note(4, 91, 20, 64),
]

drums = muspy.Track(is_drum=True)
bass = muspy.Track(program=34, notes=notes)
guitar = muspy.Track(program=27, notes=[])
strings = muspy.Track(program=42, notes=string_notes)

tracks = [drums, bass, guitar, strings]

meta = muspy.Metadata(title='prova')
music = muspy.Music(
    tracks=tracks,
    metadata=meta,
    resolution=32,
)

In [None]:
!ls data/lmd_matched/M/T/O/TRMTOBP128E07822EF/63edabc86c087f07eca448b0edad53c3.mid

# Stuff (scratch experiments: custom RGCN layer, graph edge construction, pypianoroll tests)

In [None]:
from typing import Optional, Union, Tuple
from torch_geometric.typing import OptTensor, Adj
from typing import Callable
from torch_geometric.nn.inits import reset

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter as Param
from torch.nn import Parameter
from torch_scatter import scatter
from torch_sparse import SparseTensor, matmul, masked_select_nnz
from torch_geometric.nn.conv import MessagePassing

from torch_geometric.nn.inits import glorot, zeros


# TorchScript overload declaration for a dense ``Tensor`` edge_index; the
# signature is given by the type comment inside the stub.
@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (Tensor, Tensor) -> Tensor
    pass


# TorchScript overload declaration for a ``SparseTensor`` edge_index; the
# signature is given by the type comment inside the stub.
@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (SparseTensor, Tensor) -> SparseTensor
    pass


def masked_edge_index(edge_index, edge_mask):
    """Keep only the edges selected by ``edge_mask``.

    A dense ``Tensor`` edge index is filtered along its second dimension;
    any other input is handed to ``masked_select_nnz`` with COO layout.
    """
    if not isinstance(edge_index, Tensor):
        return masked_select_nnz(edge_index, edge_mask, layout='coo')
    return edge_index[:, edge_mask]

def masked_edge_attrs(edge_attrs, edge_mask):
    """Return the rows of ``edge_attrs`` selected by ``edge_mask``."""
    selected = edge_attrs[edge_mask, :]
    return selected


class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    In this variant the messages are additionally conditioned on edge
    features: :meth:`message` multiplies every source feature vector by an
    edge-specific square matrix computed as :obj:`nn(edge_attr)` before
    aggregation.

    .. note::
        This implementation is as memory-efficient as possible by iterating
        over each individual relation type.
        Therefore, it may result in low GPU utilization in case the graph has a
        large number of relations.
        As an alternative approach, :class:`FastRGCNConv` does not iterate over
        each individual type, but may consume a large amount of memory to
        compensate.
        We advise to check out both implementations to see which one fits your
        needs.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
            In case no input features are given, this argument should
            correspond to the number of nodes in your graph.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
            maps edge features :obj:`edge_attr` of shape :obj:`[-1,
            num_edge_features]` to shape
            :obj:`[-1, in_channels * in_channels]` (source dimensionality);
            :meth:`message` reshapes the output into one square matrix per
            edge, *e.g.*, defined by :class:`torch.nn.Sequential`.
        num_bases (int, optional): If set to not :obj:`None`, this layer will
            use the basis-decomposition regularization scheme where
            :obj:`num_bases` denotes the number of bases to use.
            (default: :obj:`None`)
        num_blocks (int, optional): If set to not :obj:`None`, this layer will
            use the block-diagonal-decomposition regularization scheme where
            :obj:`num_blocks` denotes the number of blocks to use.
            (default: :obj:`None`)
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"mean"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Raises:
        ValueError: If both :obj:`num_bases` and :obj:`num_blocks` are given.
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        nn: Callable,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        aggr: str = 'mean',
        root_weight: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(aggr=aggr, node_dim=0, **kwargs)

        if num_bases is not None and num_blocks is not None:
            raise ValueError('Can not apply both basis-decomposition and '
                             'block-diagonal-decomposition at the same time.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nn = nn
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.num_blocks = num_blocks

        # Normalise to a (source, target) pair; the source size is kept on
        # the instance because forward() and message() need it.
        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        self.in_channels_l = in_channels[0]

        if num_bases is not None:
            # Shared bases plus per-relation mixing coefficients
            self.weight = Parameter(
                torch.Tensor(num_bases, in_channels[0], out_channels))
            self.comp = Parameter(torch.Tensor(num_relations, num_bases))

        elif num_blocks is not None:
            assert (in_channels[0] % num_blocks == 0
                    and out_channels % num_blocks == 0)
            self.weight = Parameter(
                torch.Tensor(num_relations, num_blocks,
                             in_channels[0] // num_blocks,
                             out_channels // num_blocks))
            self.register_parameter('comp', None)

        else:
            # One full weight matrix per relation
            self.weight = Parameter(
                torch.Tensor(num_relations, in_channels[0], out_channels))
            self.register_parameter('comp', None)

        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels[1], out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the layer weights and the edge network ``nn``."""
        # NOTE(review): comp/root/bias may be registered as None; the init
        # helpers are presumably no-ops in that case -- verify against
        # torch_geometric.nn.inits.
        glorot(self.weight)
        reset(self.nn)
        glorot(self.comp)
        glorot(self.root)
        zeros(self.bias)


    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
                edge_index: Adj, edge_type: OptTensor = None,
                edge_attr: OptTensor = None):
        r"""
        Args:
            x: The input node features. Can be either a :obj:`[num_nodes,
                in_channels]` node feature matrix, or an optional
                one-dimensional node index tensor (in which case input features
                are treated as trainable node embeddings).
                Furthermore, :obj:`x` can be of type :obj:`tuple` denoting
                source and destination node features.
            edge_index: The graph connectivity, either a dense tensor or a
                :class:`torch_sparse.tensor.SparseTensor`.
            edge_type: The one-dimensional relation type/index for each edge in
                :obj:`edge_index`.
                Should be only :obj:`None` in case :obj:`edge_index` is of type
                :class:`torch_sparse.tensor.SparseTensor`.
                (default: :obj:`None`)
            edge_attr: The edge features, indexed by edge along the first
                dimension. Required unless the block-diagonal decomposition
                is used. (default: :obj:`None`)

        Returns:
            A :obj:`[num_target_nodes, out_channels]` tensor.

        Raises:
            ValueError: If the block-diagonal decomposition is combined with
                index-valued (:obj:`torch.long`) inputs, or if
                :obj:`edge_attr` is missing where it is required.
        """

        # Convert input features to a pair of node features or node indices.
        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x
        if x_l is None:
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))

        # A SparseTensor adjacency carries the relation ids as its values
        if isinstance(edge_index, SparseTensor):
            edge_type = edge_index.storage.value()
        assert edge_type is not None

        # propagate_type: (x: Tensor)
        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        if self.num_bases is not None:  # Basis-decomposition =================
            # Build the per-relation weights as combinations of the bases
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels)

        if self.num_blocks is not None:  # Block-diagonal-decomposition =====

            if x_l.dtype == torch.long:
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')

            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                # NOTE(review): no edge_attr is forwarded here although
                # message() takes one -- confirm this branch is only used
                # with a SparseTensor edge_index.
                h = self.propagate(tmp, x=x_l, size=size)
                h = h.view(-1, weight.size(1), weight.size(2))
                h = torch.einsum('abc,bcd->abd', h, weight[i])
                out += h.contiguous().view(-1, self.out_channels)

        else:  # No regularization/Basis-decomposition ========================
            # Fail early with a readable message instead of the TypeError
            # raised when None is indexed in masked_edge_attrs.
            if edge_attr is None:
                raise ValueError('edge_attr is required: edge features are '
                                 'selected per relation before propagation.')

            for i in range(self.num_relations):
                # Edges (and their features) belonging to relation i
                edge_mask = edge_type == i
                tmp = masked_edge_index(edge_index, edge_mask)
                attr = masked_edge_attrs(edge_attr, edge_mask)

                if x_l.dtype == torch.long:
                    # NOTE(review): edge_attr is not forwarded on this
                    # index-valued path -- confirm it is actually used.
                    out += self.propagate(tmp, x=weight[i, x_l], size=size)
                else:
                    h = self.propagate(tmp, x=x_l, size=size,
                                       edge_attr=attr)
                    out = out + (h @ weight[i])

        root = self.root
        if root is not None:
            out += root[x_r] if x_r.dtype == torch.long else x_r @ root

        if self.bias is not None:
            out += self.bias

        return out


    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Transform each source feature vector with its edge-specific matrix.

        ``nn(edge_attr)`` is reshaped into one
        ``in_channels_l x in_channels_l`` matrix per edge, which then
        multiplies the corresponding row of ``x_j``.
        """
        # assumes x_j is [num_edges, in_channels_l] -- TODO confirm
        weight = self.nn(edge_attr)
        weight = weight.view(-1, self.in_channels_l, self.in_channels_l)
        return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Fused sparse path: aggregate ``x`` over ``adj_t`` with ``self.aggr``.

        The adjacency values are dropped before the product, and edge
        features are not used here.
        """
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_relations={self.num_relations})')

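Next edges: connect every active (timestep, track) node to the active nodes of the *other* tracks at the next active timestep. The third entry of each edge is the distance in timesteps.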

In [None]:
import itertools

# Random binary activation matrix (4 tracks x 8 timesteps), transposed so
# that rows are timesteps and columns are tracks.
a = np.random.randint(2, size=(4,8))
a_t = a.T
print(a_t)

# (timestep, track) coordinates of all active entries, in row-major order
inds = np.argwhere(a_t == 1)
# Timesteps with at least one active track
ts_acts = a_t.any(axis=1)
ts_inds = np.flatnonzero(ts_acts)

# Node id of every (timestep, track) position
labels = np.arange(32).reshape(4, 8).T
print(labels)

# Link the active nodes of each active timestep to the active nodes of the
# following active timestep, skipping pairs that belong to the same track.
next_edges = []
for ind_s, ind_e in zip(ts_inds[:-1], ts_inds[1:]):
    src_nodes = inds[inds[:, 0] == ind_s]
    dst_nodes = inds[inds[:, 0] == ind_e]
    for src, dst in itertools.product(src_nodes, dst_nodes):
        if src[1] == dst[1]:
            continue
        next_edges.append((labels[tuple(src)], labels[tuple(dst)], ind_e - ind_s))

print(next_edges)
    

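Onset edges: connect, in both directions, every pair of tracks that are active at the same timestep. The third entry of each edge is always 0.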

In [None]:
onset_edges = []
print(a_t)
print(labels)

# For every active timestep, link each pair of simultaneously active nodes
# in both directions (forward pairs first, then their inverses); the edge
# attribute is always 0. Timesteps with a single active node yield no pairs.
for i in ts_inds:
    ts_acts_inds = list(inds[inds[:,0] == i])
    pairs = itertools.combinations(ts_acts_inds, 2)
    edges = [(labels[tuple(u)], labels[tuple(v)], 0) for u, v in pairs]
    inv_edges = [(dst, src, attr) for src, dst, attr in edges]
    onset_edges += edges
    onset_edges += inv_edges

print(onset_edges)


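Track edges: within each track, connect each active timestep to the next active one. The third entry of each edge is the distance in timesteps.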

In [None]:
print(a_t)
print(labels)
track_edges = []

# Within each track, chain consecutive active timesteps; the edge attribute
# is the number of timesteps between the two activations.
for track in range(a_t.shape[1]):
    tr_inds = list(inds[inds[:,1] == track])
    e_inds = list(zip(tr_inds[:-1], tr_inds[1:]))
    print(e_inds)
    for src, dst in e_inds:
        track_edges.append((labels[tuple(src)], labels[tuple(dst)], dst[0]-src[0]))

print(track_edges)

In [None]:
# Stack the (source, destination, attribute) triples into (num_edges, 3)
# arrays. The explicit dtype and reshape keep an empty edge list as an integer
# array of shape (0, 3), so the concatenation below still works.
track_edges = np.array(track_edges, dtype=int).reshape(-1, 3)
onset_edges = np.array(onset_edges, dtype=int).reshape(-1, 3)
np.concatenate((track_edges, onset_edges)).shape

In [None]:
pip install pypianoroll

In [None]:
import pypianoroll

In [None]:
multitrack = pypianoroll.read("tests_fur-elise.mid")
print(multitrack)

In [None]:
multitrack.tracks[0].pianoroll

In [None]:
multitrack.plot()

In [None]:
multitrack.trim(0, 12 * multitrack.resolution)
multitrack.binarize()

In [None]:
multitrack.plot()

In [None]:
multitrack.tracks[0].pianoroll.shape