In [None]:
# Fetch the pno-ai repository; the preprocess / train / helpers modules imported
# below presumably resolve from it after changing into pno-ai/.
!git clone https://github.com/chathasphere/pno-ai.git 

In [None]:
# Download the MAESTRO v2.0.0 MIDI archive.
!wget https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip

In [None]:
# Extract quietly into ./data/.
!unzip -q maestro-v2.0.0-midi.zip -d data/

In [None]:
# MIDI parsing dependency.
!pip install pretty_midi 

In [None]:
# Move only the 2004 subset of MAESTRO into the repo's data directory.
# NOTE(review): if pno-ai/data/maestro-v2.0.0 does not exist yet, `mv` renames the
# 2004 folder to that path instead of nesting it — confirm the expected layout.
!mv data/maestro-v2.0.0/2004 pno-ai/data/maestro-v2.0.0 

In [None]:
cd pno-ai/

In [None]:
import os, time, datetime
import torch 
import torch.nn as nn
from random import shuffle
from preprocess import PreprocessingPipeline
from train import train
import argparse
import numpy as np
import torch.nn.functional as F
from preprocess import SequenceEncoder 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from helpers import clones, d
import math 
from tqdm import tqdm

In [None]:
class AttentionError(Exception):
    """Raised for invalid attention configurations or inputs."""


class MultiheadedAttention(nn.Module):
    """
    Narrow multiheaded attention. Each attention head inspects a
    fraction of the embedding space and expresses attention vectors for each
    sequence position as a weighted average of all (earlier) positions.

    Args:
        d_model: embedding dimension of the input; must be divisible by `heads`.
        heads: number of attention heads (each sees d_model // heads features).
        dropout: dropout probability applied to the attention weights.
        relative_pos: if True, add relative-position scores computed from a
            learned table of `max_length` distance embeddings per head.
        max_length: longest sequence the relative-position table supports.

    Raises:
        AttentionError: if `heads` does not divide `d_model`, or (in forward)
            if relative_pos is enabled and the sequence exceeds `max_length`.
    """

    def __init__(self, d_model, heads=8, dropout=0.1, relative_pos=True,
                 max_length=256):

        super().__init__()
        if d_model % heads != 0:
            raise AttentionError("Number of heads does not divide model dimension")
        self.d_model = d_model
        self.heads = heads
        s = d_model // heads
        # query / key / value projections; applied to the last axis of the
        # (b, t, h, s) view, so the same weights are shared by every head
        self.linears = torch.nn.ModuleList([nn.Linear(s, s, bias=False) for _ in range(3)])
        self.recombine_heads = nn.Linear(heads * s, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.max_length = max_length
        # relative positional embeddings
        self.relative_pos = relative_pos
        if relative_pos:
            # Registered as a Parameter so it is trained, saved in state_dict
            # and moved by .to(). (It used to be a plain tensor that stayed at
            # its random initialisation.) Checkpoints saved before this change
            # have no `Er` entry and need load_state_dict(..., strict=False).
            self.Er = nn.Parameter(torch.randn([heads, max_length, s]))
        else:
            self.Er = None

    def forward(self, x, mask):
        """
        Causal self-attention over `x`.

        Args:
            x: float tensor of shape (b, t, e) with e divisible by `heads`.
            mask: None, or a 2-D tensor indexed like (b, t); query positions
                where it equals 0 get all of their scores set to -1e9.

        Returns:
            Tensor of shape (b, t, d_model).
        """
        # batch size, sequence length, embedding dimension
        b, t, e = x.size()
        h = self.heads
        # head dimension: each head inspects a slice of the embedding
        s = e // h
        device = x.device
        x = x.view(b, t, h, s)
        # project, then move heads ahead of time: (b, h, t, s)
        queries, keys, values = [w(x).transpose(1, 2) for w in self.linears]
        if self.relative_pos:
            if t > self.max_length:
                raise AttentionError(
                    f"Sequence length {t} exceeds the {self.max_length} positions "
                    "covered by the relative position embeddings")
            # last t distance embeddings, applied identically across the batch
            embedding_start = self.max_length - t
            Er = self.Er[:, embedding_start:, :].unsqueeze(0)
            QEr = torch.matmul(queries, Er.transpose(-1, -2))
            QEr = self._mask_positions(QEr)
            # relative-position scores, batch and head dims combined
            SRel = self._skew(QEr).contiguous().view(b * h, t, t)
        else:
            SRel = torch.zeros([b * h, t, t], device=device)
        queries, keys, values = (
            m.contiguous().view(b * h, t, s) for m in (queries, keys, values))
        # Scale q and k before the product. NOTE(review): the combined factor is
        # sqrt(e) (full embedding width), not sqrt(s) (head width) as in the
        # original Transformer; kept unchanged to preserve trained behaviour.
        queries = queries / (e ** (1 / 4))
        keys = keys / (e ** (1 / 4))

        # (b*h, t, t)
        scores = torch.bmm(queries, keys.transpose(1, 2)) + SRel

        # causal mask: query i may not attend to keys j > i
        subsequent_mask = torch.triu(torch.ones(1, t, t, device=device), 1)
        scores = scores.masked_fill(subsequent_mask == 1, -1e9)
        if mask is not None:
            # repeat per head to line up with the (b*h, ...) layout, then blank
            # the score rows of query positions where mask == 0
            mask = mask.repeat_interleave(h, 0)
            pad_idx = (mask == 0).nonzero().transpose(0, 1)
            scores[pad_idx[0], pad_idx[1], :] = -1e9

        # convert scores to probabilities
        attn_probs = self.dropout(F.softmax(scores, dim=2))
        # weighted average of values, then recombine the heads
        out = torch.bmm(attn_probs, values).view(b, h, t, s)
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)
        return self.recombine_heads(out)

    def _mask_positions(self, qe):
        """
        Zero out invalid relative positions in `qe` (queries dot distance
        embeddings, last two dims (L, L)).

        Column L-1 holds distance 0. Query row r may only look back r steps, so
        columns 0 .. L-2-r are zeroed: query 0 keeps only distance 0, and query
        L-1 keeps every distance.
        """
        L = qe.shape[-1]
        mask = torch.triu(torch.ones(L, L, device=qe.device), 1).flip(1)
        return qe.masked_fill((mask == 1), 0)

    def _skew(self, qe):
        """
        Rearrange (…, t, t) relative scores from (query, distance) indexing to
        (query, key) indexing: left-pad one zero column, reshape to
        (…, t+1, t), and drop the first row.
        """
        padded_qe = F.pad(qe, [1, 0])
        s = padded_qe.shape
        padded_qe = padded_qe.view(s[0], s[1], s[3], s[2])
        # take out first (padded) row
        return padded_qe[:, :, 1:, :]

In [None]:
class MusicTransformerError(Exception):
    """Raised for invalid MusicTransformer configurations."""


class MusicTransformer(nn.Module):
    """Generative, autoregressive transformer model. Train on a
    dataset of encoded musical sequences."""

    def __init__(self, n_tokens, seq_length=None, d_model=64,
            n_heads=4, depth=2, d_feedforward=512, dropout=0.1,
            positional_encoding=False, relative_pos=True, gen=False):
        """
        Args:
            n_tokens: number of commands/states in encoded musical sequence
            seq_length: length of (padded) input/target sequences; required
                when positional_encoding is False (learned position table)
            d_model: dimensionality of embedded sequences
            n_heads: number of attention heads
            depth: number of stacked transformer layers
            d_feedforward: dimensionality of dense sublayer
            dropout: probability of dropout in dropout sublayer
            positional_encoding: (bool) if True, add a fixed sinusoidal table
                covering 5000 positions; otherwise a learned nn.Embedding of
                seq_length positions
            relative_pos: (bool) if True, use relative positional embeddings
            gen: (bool) stored as self.gen; not read by this class

        Raises:
            MusicTransformerError: if positional_encoding is False and
                seq_length is None
        """
        super().__init__()
        self.gen = gen
        # number of commands in an encoded musical sequence
        self.n_tokens = n_tokens
        # embedding layer (used in forward for integer token input)
        self.d_model = d_model
        self.embed = SequenceEmbedding(n_tokens, d_model)
        # positional encoding layer
        self.positional_encoding = positional_encoding
        if self.positional_encoding:
            pos = torch.zeros(5000, d_model)
            position = torch.arange(5000).unsqueeze(1)
            # geometric progression of wave lengths
            div_term = torch.exp(torch.arange(0.0, d_model, 2) *
                                 -(math.log(10000.0) / d_model))
            # even positions
            pos[:, 0::2] = torch.sin(position * div_term)
            # odd positions
            pos[:, 1::2] = torch.cos(position * div_term)
            # add batch dimension; as a buffer it is saved with the model and
            # follows .to(device)
            self.register_buffer('pos', pos.unsqueeze(0))
        else:
            if seq_length is None:
                raise MusicTransformerError("seq_length not provided for positional embeddings")
            self.pos = nn.Embedding(seq_length, d_model)
        # last layer, outputs logits of next token in sequence
        self.to_scores = nn.Linear(d_model, n_tokens)
        self.layers = clones(DecoderLayer(d_model, n_heads,
            d_feedforward, dropout, relative_pos), depth)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: token ids of shape (b, t) with an integer dtype, which are
                embedded with self.embed; or already-embedded float input of
                shape (b, t, d_model).
            mask: optional padding mask passed to every attention layer.

        Returns:
            Logits of shape (b, t, n_tokens).
        """
        if not torch.is_floating_point(x):
            # token ids: embed them (integer input used to crash on the
            # three-way unpack below because self.embed was never applied)
            x = self.embed(x)
        b, t, e = x.size()
        if self.positional_encoding:
            positions = self.pos[:, :t, :]
        else:
            positions = self.pos(torch.arange(t,
                device=x.device))[None, :, :].expand(b, t, e)
        x = x + positions
        # pass input batch and mask through layers
        for layer in self.layers:
            x = layer(x, mask)
        # one last normalization for good measure
        z = self.norm(x)
        return self.to_scores(z)

class DecoderLayer(nn.Module):
    """
    One post-norm transformer decoder layer: masked self-attention followed by
    a position-wise feed-forward network. Each sublayer output goes through
    dropout, is added back to its input, and is then layer-normalised.
    """

    def __init__(self, size, n_heads, d_feedforward, dropout,
            relative_pos):

        super().__init__()
        self.self_attn = MultiheadedAttention(size, heads=n_heads,
                dropout=dropout, relative_pos=relative_pos)
        self.feed_forward = PositionwiseFeedForward(
                d_model=size, d_ff=d_feedforward, dropout=dropout)
        self.size = size
        # layer norms over the embedding dimension, one per sublayer
        self.norm1 = nn.LayerNorm(size)
        self.norm2 = nn.LayerNorm(size)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        # sublayer 1: causal self-attention, residual add, normalise
        attended = self.norm1(x + self.dropout1(self.self_attn(x, mask)))
        # sublayer 2: feed-forward, residual add, normalise
        return self.norm2(attended + self.dropout2(self.feed_forward(attended)))

class PositionwiseFeedForward(nn.Module):
    """
    Two-layer MLP applied independently at every position:
    Linear(d_model -> d_ff) -> ReLU -> dropout -> Linear(d_ff -> d_model).
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

class SequenceEmbedding(nn.Module):
    """
    Token embedding lookup whose output is multiplied by sqrt(model_size),
    the hidden width of the model.
    """
    def __init__(self, vocab_size, model_size):
        super().__init__()
        self.d_model = model_size
        self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=model_size)

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        vectors = self.emb(x)
        return vectors * scale

In [None]:
class TransGenerator(nn.Module):
    """
    GAN generator: linearly projects a noise tensor onto `seq_length` time
    steps and runs it through a MusicTransformer, producing token logits for
    every step.
    """
    def __init__(self, z_dim, n_tokens, seq_length=None, d_model=64,
            n_heads=4, depth=2, d_feedforward=512, dropout=0.1,
            positional_encoding=False, relative_pos=True):
        """
        Args:
            z_dim: size of the last axis of the input noise
            n_tokens: number of commands/states in encoded musical sequence
            seq_length: length of generated sequences. Required despite the
                None default, because the noise is projected from z_dim to
                seq_length steps.
            d_model: dimensionality of embedded sequences
            n_heads: number of attention heads
            depth: number of stacked transformer layers
            d_feedforward: dimensionality of dense sublayer
            dropout: probability of dropout in dropout sublayer
            positional_encoding: (bool) passed through to MusicTransformer
            relative_pos: (bool) if True, use relative positional embeddings

        Raises:
            ValueError: if seq_length is None (previously an opaque TypeError
                from nn.Linear)
        """
        super().__init__()
        if seq_length is None:
            raise ValueError("TransGenerator requires seq_length "
                             "(noise is projected from z_dim to seq_length steps)")
        self.z_dim = z_dim
        self.init_mlp = nn.Linear(z_dim, seq_length)
        self.norm = nn.LayerNorm(seq_length)
        self.relu = nn.ReLU()
        self.trans_block = MusicTransformer(n_tokens, seq_length, d_model, n_heads, depth, d_feedforward, dropout, positional_encoding, relative_pos)

    def forward(self, noise):
        """
        Args:
            noise: tensor of shape (N, C, z_dim). After projection and
                transposition it becomes (N, seq_length, C), so C must equal
                d_model for the positional terms to add up.

        Returns:
            Logits of shape (N, seq_length, n_tokens).
        """
        # (N, C, z_dim) -> (N, C, seq_length)
        x = self.init_mlp(noise)
        x = self.norm(x)
        # -> (N, seq_length, C)
        x = self.relu(x).transpose(1, 2)
        return self.trans_block(x)

In [None]:
class TransDiscriminator(nn.Module):
    """
    GAN discriminator: scores each sequence with a single logit (higher means
    "real").

    Real token ids and generator logits are both mapped into the transformer's
    embedding space first. Previously neither input could pass through the
    transformer: token ids are 2-D, and logits have n_tokens features rather
    than d_model.
    """
    def __init__(self, n_tokens, seq_length=None, d_model=64,
            n_heads=4, depth=2, d_feedforward=512, dropout=0.1,
            positional_encoding=False, relative_pos=True):
        """
        Args:
            n_tokens: number of commands/states in encoded musical sequence
            seq_length: length of (padded) input/target sequences
            d_model: dimensionality of embedded sequences
            n_heads: number of attention heads
            depth: number of stacked transformer layers
            d_feedforward: dimensionality of dense sublayer
            dropout: probability of dropout in dropout sublayer
            positional_encoding: (bool) passed through to MusicTransformer
            relative_pos: (bool) if True, use relative positional embeddings
        """
        super().__init__()
        self.n_tokens = n_tokens
        self.trans_block = MusicTransformer(n_tokens, seq_length, d_model, n_heads, depth, d_feedforward, dropout, positional_encoding, relative_pos, True)
        self.clf = nn.Linear(n_tokens, 1)

    def forward(self, x):
        """
        Args:
            x: token ids of shape (b, t) with an integer dtype, generator logits
                of shape (b, t, n_tokens), or already-embedded input of shape
                (b, t, d_model).

        Returns:
            Tensor of shape (b, 1) holding one raw logit per sequence.
        """
        x = self._to_embeddings(x)
        # (b, t, n_tokens)
        scores = self.trans_block(x)
        # average over time -> (b, n_tokens)
        pooled = torch.mean(scores, dim=1)
        return self.clf(pooled)

    def _to_embeddings(self, x):
        """Map token ids or generator logits to (b, t, d_model) embeddings."""
        embedding = self.trans_block.embed
        if not torch.is_floating_point(x):
            # real data: ordinary embedding lookup
            return embedding(x)
        if (x.dim() == 3 and x.size(-1) == self.n_tokens
                and self.n_tokens != embedding.d_model):
            # Generator logits: take the expected embedding under
            # softmax(logits), scaled like SequenceEmbedding. This keeps the
            # path differentiable so the generator receives gradients.
            probs = F.softmax(x, dim=-1)
            return torch.matmul(probs, embedding.emb.weight) * math.sqrt(embedding.d_model)
        # already embedded
        return x

# input shape = (N, Seq_len)
# label shape = (N, Seq_len)
# output shape = (N, Seq_len, N_tokens)
# embedded output = (N, Seq_len, d_model)

In [None]:
sampling_rate = 125
n_velocity_bins = 32
seq_length = 256

# Encode the MIDI files under data/ into event sequences. Going by the argument
# names, the data is presumably augmented with time stretches and transpositions
# and split 90/10 into training and validation sets — confirm in preprocess.py.
# NOTE(review): min_encoded_length (257) exceeds max_encoded_length (seq_length,
# 256). If PreprocessingPipeline crops to max_encoded_length before filtering by
# min_encoded_length, no sequence survives — confirm against preprocess.py.
pipeline = PreprocessingPipeline(input_dir="data", stretch_factors=[0.975, 1, 1.025],
        split_size=30, sampling_rate=sampling_rate, n_velocity_bins=n_velocity_bins,
        transpositions=range(-2,3), training_val_split=0.9, max_encoded_length=seq_length,
                                min_encoded_length=257)
# time the preprocessing run and report it in minutes
pipeline_start = time.time()
pipeline.run()
runtime = time.time() - pipeline_start
print(f"MIDI pipeline runtime: {runtime / 60 : .1f}m")

# date-stamped checkpoint name, e.g. saved_models/tf_MMDDYYYY
today = datetime.date.today().strftime('%m%d%Y')
checkpoint = f"saved_models/tf_{today}"

# encoded sequences produced by the pipeline, keyed by split
training_sequences = pipeline.encoded_sequences['training']
validation_sequences = pipeline.encoded_sequences['validation'] 

In [None]:
def prepare_batches(sequences, batch_size):
    """
    Split a list of sequences into consecutive batches of at most `batch_size`.

    Each batch is yielded as an (input_sequences, target_sequences) pair. Both
    lists hold the same, unshifted sequences; there is no one-step offset
    between input and target. Within a batch, sequences are ordered longest
    first.

    Args:
        sequences: list of sized sequences (e.g. encoded event lists).
        batch_size: maximum number of sequences per batch; the final batch may
            be smaller.

    Yields:
        (input_sequences, target_sequences): two distinct lists containing the
        same sequence objects.
    """
    for start in range(0, len(sequences), batch_size):
        # longest first (kept from the original, which required sorted order
        # for packing batches)
        batch = sorted(sequences[start:start + batch_size], key=len, reverse=True)
        yield list(batch), list(batch)

In [None]:
def batch_to_tensors(batch, n_tokens, max_length):
    """
    Make input, input mask, and target tensors for a batch of sequences.

    Args:
        batch: (input_sequences, target_sequences) pair of lists of integer
            token ids.
        n_tokens: vocabulary size (not used here; kept for interface
            compatibility).
        max_length: padded length; every sequence must be at most this long.

    Returns:
        (x, y, x_mask):
            x, y: long tensors of shape (len(input_sequences), max_length),
                zero-padded on the right.
            x_mask: uint8 tensor, 1 where x is non-zero.
        All three are moved to CUDA when it is available.
    """
    input_sequences, target_sequences = batch
    batch_size = len(input_sequences)

    # 0 is the padding element
    x = torch.zeros(batch_size, max_length, dtype=torch.long)
    y = torch.zeros(batch_size, max_length, dtype=torch.long)

    # Build long tensors directly; going through float32 torch.Tensor loses
    # precision for ids >= 2**24.
    for row, sequence in enumerate(input_sequences):
        x[row, :len(sequence)] = torch.tensor(sequence, dtype=torch.long)
    for row, sequence in enumerate(target_sequences):
        # use the target's own length (it need not match its input's)
        y[row, :len(sequence)] = torch.tensor(sequence, dtype=torch.long)

    # NOTE(review): 0 doubles as the padding value, so a genuine token 0 is
    # masked as well — confirm the encoder never emits id 0.
    x_mask = (x != 0).type(torch.uint8)

    if torch.cuda.is_available():
        return x.cuda(), y.cuda(), x_mask.cuda()
    return x, y, x_mask

In [None]:
# Longest encoded sequence across both splits; batches are zero-padded to it.
max_length = max(
    len(seq)
    for split in (training_sequences, validation_sequences)
    for seq in split
)

In [None]:
# Same settings as the preprocessing cell; they must agree for n_tokens to
# match the encoder's vocabulary.
sampling_rate = 125
n_velocity_bins = 32
seq_length = 256 
# Vocabulary size. Presumably 128 note-on + 128 note-off events, one time-shift
# token per sampling step and one token per velocity bin — confirm against
# SequenceEncoder.
n_tokens = 256 + sampling_rate + n_velocity_bins
# z_dim = 50: TransGenerator.forward expects noise of shape (N, d_model, 50).
# NOTE(review): real batches are padded to max_length (taken from the data),
# while generated sequences have seq_length steps. Both must stay within the 256
# positions covered by MultiheadedAttention's relative embeddings.
generator = TransGenerator(50, n_tokens, seq_length, 
            d_model = 64, n_heads = 8, d_feedforward=256, 
            depth = 4, dropout = 0.0, positional_encoding=True, relative_pos=True)
discriminator = TransDiscriminator(n_tokens, seq_length, 
            d_model = 64, n_heads = 8, d_feedforward=256, 
            depth = 4, dropout = 0.2, positional_encoding=True, relative_pos=True)
# move parameters and buffers to the device returned by helpers.d()
generator.to(d()) 
discriminator.to(d()) 

In [None]:
# number of passes over the training sequences
epoch = 20
batch_size = 64
# The discriminator returns raw logits (no sigmoid), so use the sigmoid-fused
# binary cross entropy.
criterion = nn.BCEWithLogitsLoss()
# learning rates for the discriminator and generator optimisers
d_lr = 0.0002
g_lr = 0.0002
gen_opt = torch.optim.Adam(generator.parameters(), lr=g_lr) 
disc_opt = torch.optim.Adam(discriminator.parameters(), lr=d_lr) 
# Checkpoint destination on Google Drive; presumably Drive must be mounted in
# the Colab session first — confirm.
PATH = '/content/drive/MyDrive/musictransgan.pt'

In [None]:
# Per-batch histories: generator/discriminator losses, plus the discriminator's
# mean predicted probability of "real" on real data, on generated data during
# its own step, and on generated data during the generator step.
G_losses = []
D_losses = []
disc_real = []
disc_fake = []
disc_fool = []
# Print progress every `log_interval` batches (the original used the builtin
# `iter` here, which raises TypeError).
log_interval = 50

for e in range(epoch):
    generator.train()
    discriminator.train()
    training_batches = prepare_batches(training_sequences, batch_size)
    for k, batch in enumerate(training_batches):
        x, _, _ = batch_to_tensors(batch, n_tokens, max_length)
        # The final batch can be smaller than batch_size; size the noise to
        # match. TransGenerator expects noise shaped (N, d_model, z_dim).
        n_batch = x.size(0)
        noise_shape = (n_batch, generator.trans_block.d_model, generator.z_dim)

        # ---- Discriminator step: real -> 1, generated -> 0 ----
        disc_opt.zero_grad()
        y_disc_real = discriminator(x).view(-1)
        disc_loss_real = criterion(y_disc_real, torch.ones_like(y_disc_real))

        noise = torch.randn(noise_shape, device=d())
        x_gen_fake = generator(noise)
        # detached, so this step does not update the generator
        y_disc_fake = discriminator(x_gen_fake.detach()).view(-1)
        disc_loss_fake = criterion(y_disc_fake, torch.zeros_like(y_disc_fake))

        # previously never assigned (NameError)
        disc_loss = disc_loss_real + disc_loss_fake
        disc_loss.backward()
        disc_opt.step()

        # ---- Generator step: push the discriminator towards "real" ----
        gen_opt.zero_grad()
        noise_gen = torch.randn(noise_shape, device=d())
        x_gen_fool = generator(noise_gen)
        y_disc_fool = discriminator(x_gen_fool).view(-1)
        loss_gen = criterion(y_disc_fool, torch.ones_like(y_disc_fool))
        loss_gen.backward()
        gen_opt.step()

        # Statistics
        D_losses.append(disc_loss.item())
        G_losses.append(loss_gen.item())
        disc_real.append(torch.sigmoid(y_disc_real).mean().item())
        disc_fake.append(torch.sigmoid(y_disc_fake).mean().item())
        disc_fool.append(torch.sigmoid(y_disc_fool).mean().item())

        if k % log_interval == 0:
            print('Epoch: {}, Iter: {}, G_loss: {}, D_loss: {}'.format(
                e + 1, k, loss_gen.item(), disc_loss.item()))

    # checkpoint after every epoch (overwrites PATH)
    torch.save({
        'epoch': e,
        'batch_size': batch_size,
        'lr': [d_lr, g_lr],
        'generator_model_state_dict': generator.state_dict(),
        'generator_optimizer_state_dict': gen_opt.state_dict(),
        'discriminator_model_state_dict': discriminator.state_dict(),
        'discriminator_optimizer_state_dict': disc_opt.state_dict(),
        'G_losses': G_losses,
        'D_losses': D_losses,
        'disc_real': disc_real,
        'disc_fake': disc_fake,
        'disc_fool': disc_fool,
    }, PATH)

In [None]:
"""
Generate a MIDI event sequence of a fixed length by randomly sampling from a model's distribution of sequences. Optionally, "seed" the sequence with a prime. A well-trained model will create music that responds to the prime and develops upon it.
"""
prime_sequence = []
temperature = 1
# number of events to sample
n_events = 1024
# Longest context fed to the model. MultiheadedAttention's relative embeddings
# cover 256 positions by default.
max_context = 256
# NOTE(review): `model` is not defined anywhere in this notebook; it must be a
# token-level MusicTransformer (input: LongTensor of shape (1, t)) — TODO confirm.
# deactivate training mode
model.eval()
device = next(model.parameters()).device
if len(prime_sequence) == 0:
    # if no prime is provided, randomly select a starting event
    input_sequence = [np.random.randint(model.n_tokens)]
else:
    input_sequence = prime_sequence.copy()

# add singleton batch dimension; keep at most max_context events of context
input_tensor = torch.LongTensor(input_sequence[-max_context:]).unsqueeze(0).to(device)
# sampling only: no autograd graph is needed
with torch.no_grad():
    for _ in range(n_events):
        # logits for the *next* token, shape (n_tokens,)
        out = model(input_tensor)[0, -1, :]
        probs = F.softmax(out / temperature, dim=0)
        # sample the next event from the distribution
        c = torch.multinomial(probs, 1)
        # Grow the context, sliding only once it reaches max_context. The
        # original dropped the oldest token every step, so the context never
        # exceeded the prime length (a single token when no prime is given).
        input_tensor = torch.cat([input_tensor, c[None]], dim=1)[:, -max_context:]
        input_sequence.append(c.item())