In [2]:
import io
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.utils import download_from_url, extract_archive
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, Vocab
from torchtext.data import Field
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from tensorflow.keras.preprocessing.text import Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output
    
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [4]:
filename = '../dataset_text/miditokens_waitFix.txt'
with open(filename) as f:
    miditokens = f.readlines()
    
miditokens_tempo_and_sig = [tokens.strip().split(' ') for tokens in miditokens]

miditokens = []
for song in miditokens_tempo_and_sig:
    sig = song[2]
    if sig in ['timesig:4/4', 'timesig:3/4', 'timesig:2/4', 'timesig:6/8']:
        miditokens.append(song)
print("Number of songs: {0}".format(len(miditokens)))

tokenizer = Tokenizer(oov_token='x') # token -> int
tokenizer.fit_on_texts(miditokens)
midiTokensAsInt = tokenizer.texts_to_sequences(miditokens)

Number of songs: 550


In [5]:
with open("../dataset_text/seq2note_int.txt", "r") as f:
    lines = f.readlines()

In [89]:
SEQ_LEN = 1
LINES = 0

for song in midiTokensAsInt:
    for idx in range(0, len(song) - SEQ_LEN, 1):
        LINES += 1
print(LINES)

BATCH_SIZE = 1
VOCAB_SIZE = len(tokenizer.word_index)
steps = LINES // BATCH_SIZE

def batchGenerator(trainData, lines, VOCAB_SIZE, LINES, BATCH_SIZE=32):
    lastLine = 0
    while True:
        
        # https://towardsdatascience.com/how-to-generate-music-using-a-lstm-neural-network-in-keras-68786834d4c5
        
        X_train = []
        y_train = []
        
        for idx in range(lastLine, min(lastLine + BATCH_SIZE, LINES), 1):
            sample = lines[idx].split(", ")
            X_train.append([int(i) for i in sample[0].split(" ")])
            y_train.append([int(sample[1])])
        
        yield torch.tensor(X_train).to(device), torch.tensor(y_train).to(device)
        
        lastLine += BATCH_SIZE
        if lastLine > LINES:
            lastLine = 0

batchGen = batchGenerator(midiTokensAsInt, lines, VOCAB_SIZE, LINES-1, BATCH_SIZE)

1384532


In [90]:
test = next(batchGen)
print(test[0], test[1])

tensor([[151, 207, 153,   4,  22,   2,  23,  28,  73,   3,  74,  57,   3,  58,
           4,  53,   3,  54,  98,   3,  99,  94,   2,  95,  42,   2,  43,  37,
           2,  38,  28,  88,   2,  89,   4,  42,   2,  43,  37,   2,  38,  28,
          88,   2,  89,   4,  15,  49,  73,   3]], device='cuda:0') tensor([[74]], device='cuda:0')


In [91]:
ntokens = VOCAB_SIZE # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

In [126]:
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

import time
def train():
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    src_mask = model.generate_square_subsequent_mask(SEQ_LEN).to(device)
    for i in enumerate(range(steps)):
        data, targets = next(batchGen)
        optimizer.zero_grad()
        if data.size(0) != SEQ_LEN:
            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
        output = model(data, src_mask)
        print(output.size())
        _, top_ix = torch.topk(output, k=1, axis=2)
        print(top_ix)
        #print(targets)
        loss = criterion(top_ix, targets.squeeze())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // SEQ_LEN, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
    return total_loss

In [127]:
best_loss = float("inf")
epochs = 40 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    loss = train()
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s'.format(epoch, (time.time() - epoch_start_time),))
    print('-' * 89)

    if loss < best_loss:
        best_loss = loss
        best_model = model

    scheduler.step()

torch.Size([1, 50, 310])


RuntimeError: "host_softmax" not implemented for 'Long'

In [None]:
import numpy as np

intToNote = dict(map(reversed, tokenizer.word_index.items()))

songTokens = tokenizer.texts_to_sequences([midi2text(open_midi("../testmidis/sonic.mid"))])[0]
pattern = [songTokens[:50]] 

x = torch.tensor(pattern).to(device)

if x.size(0) != bptt:
    src_mask = model.generate_square_subsequent_mask(x.size(0)).to(device)

y_pred = best_model(x, src_mask)

_, top_ix = torch.topk(y_pred[0], k=1)
choices = top_ix.tolist()
print([intToNote[i] for [i] in choices])
#words.append(int_to_vocab[choice])

# Conversion functions

In [None]:
from music21 import *

def open_midi(midi_path):
    mf = converter.parse(midi_path)
    return mf

# Restricts possible velocities to 8 values, keeping the number of unique note events smaller
# Resembles ppp, pp, p, mp, mf, f, ff, fff dynamics 
def vModifier(velocity):
    if (velocity == 0):
        return 0
    
    velocity = min(127, ((velocity // 16) + 1) * 16)
    return velocity

def tModifier(tempo):
    if (tempo == 0):
        return 0
    
    tempo = ((tempo // 10) + 1) * 10
    return tempo

# Check if there are notes which should have ended before given offset
def checkForNoteOffEvent(currentOffset, noteOffEvents):
    notesToEnd = []
    
    for noteOffEvent in noteOffEvents: # for (notename, endingOffset)
        if noteOffEvent[1] <= currentOffset:
            notesToEnd.append(noteOffEvent)
            
    return notesToEnd

def midi2text(midifile):
    previousElementOffset = 0.0
    offsetChanged = False

    tempoRetrieved = False
    timeSigRetrieved = False
    
    currentVelocity = 0

    tokens = []
    noteOffEvents = []

    tokens.append("START")

    for element in midifile.flat.elements:
        #print(type(element))

        currentElementOffset = element.offset

        notesToEnd = checkForNoteOffEvent(currentElementOffset, noteOffEvents)

        if (len(notesToEnd) != 0):
            for noteToEnd in notesToEnd:
                difference = float(noteToEnd[1]) - float(previousElementOffset)
                if (difference > 0.01):
                    tokens.append("wait:" + str(round(difference, 5)))
                    previousElementOffset = noteToEnd[1]
                tokens.append("note:" + str(noteToEnd[0]) + ":OFF")
                noteOffEvents.remove(noteToEnd)

        # If offset has increased and we're looking at new notes, add a wait event before adding the new notes
        if (float(currentElementOffset) > float(previousElementOffset + 0.01) and (isinstance(element, note.Note) or isinstance(element, chord.Chord))):
            offsetChanged = True
            difference = float(currentElementOffset - previousElementOffset)
            tokens.append("wait:" + str(round(difference, 5)))

        if (isinstance(element, tempo.MetronomeMark) and not tempoRetrieved):
            tempoRetrieved = True
            tokens.append("tempo:" + str(tModifier(element.number)))

        if (isinstance(element, meter.TimeSignature) and not timeSigRetrieved):
            timeSigRetrieved = True
            tokens.append("timesig:" + str(element.ratioString))

        if (isinstance(element, note.Note)): # This is a note event, add a token for this note
            if (currentVelocity != vModifier(element.volume.velocity)):
                currentVelocity = vModifier(element.volume.velocity)
                tokens.append("velocity:" + str(currentVelocity))
            tokens.append("note:" + str(element.pitch))
            noteOffEvents.append((str(element.pitch), float(currentElementOffset + element.duration.quarterLength), 5))

        if (isinstance(element, chord.Chord)): # This is a chord event, add a token for each note in chord
            for chordnote in element:
                if (currentVelocity != vModifier(element.volume.velocity)):
                    currentVelocity = vModifier(element.volume.velocity)
                    tokens.append("velocity:" + str(currentVelocity))
                tokens.append("note:" + str(chordnote.pitch))
                noteOffEvents.append((str(chordnote.pitch), float(currentElementOffset + element.duration.quarterLength)))

        if (offsetChanged):
            previousElementOffset = currentElementOffset
            offsetChanged = False

    # Finally make sure that all notes that end after the offset of the last element of mf.flat.elements are given an off event.
    for noteToEnd in noteOffEvents.copy():
        difference = float(noteToEnd[1]) - float(previousElementOffset)
        if (difference > 0.01):
            tokens.append("wait:" + str(round(difference, 5)))
            previousElementOffset = noteToEnd[1]
        tokens.append("note:" + str(noteToEnd[0]) + ":OFF")
        noteOffEvents.remove(noteToEnd)
        
    if (len(noteOffEvents) != 0):
        print("Not all notes have note-off events")

    tokens.append("END")
    return tokens

def text2midi(tokens):
    s = stream.Stream()
    
    currentVelocity = 80

    currentOffset = 0
    currentToken = 0

    for token in tokens:

        splitToken = token.split(":")

        if token.startswith("tempo"):
            s.append(tempo.MetronomeMark(number=float(splitToken[1])))

        if token.startswith("timesig"):
            s.append(meter.TimeSignature(splitToken[1]))
            
        if token.startswith("velocity"):
            currentVelocity = int(splitToken[1])

        if token.startswith("note") and not token.lower().endswith("off"):
            noteDuration = 0
            noteName = splitToken[1]

            for element in tokens[currentToken+1:]:
                splitToken2 = element.split(":")
                if (element.startswith("wait")):
                    noteDuration += float(splitToken2[1])
                if (element.startswith("note") and element.lower().endswith("off")):
                    if (noteName == splitToken2[1]):
                        newNote = note.Note(nameWithOctave=splitToken[1],  
                               quarterLength=float(noteDuration))
                        newNote.volume.velocity = currentVelocity
                        s.insert(currentOffset, newNote)
                        break

        if token.startswith("wait"):
            currentOffset += float(splitToken[1]) 

        currentToken += 1

    return s

In [63]:
sequence_length   = 50
number_of_classes = 310
# Creates random tensor of your output shape
output = torch.rand(32, sequence_length, number_of_classes)
output = output.permute(0, 1, 2)
print(output.size())

# Creates tensor with random targets
target = torch.randint(310, (32,))
print(target.size())

print(output.size()[1:], target.size()[2:])

# Define loss function and calculate loss
criterion = nn.CrossEntropyLoss()

loss = criterion(output, target.squeeze())
print(loss)

torch.Size([32, 50, 310])
torch.Size([32])
torch.Size([50, 310]) torch.Size([])


ValueError: Expected target size (32, 310), got torch.Size([32])

In [65]:
# the output of one batch of 10 sentences, every sentence has 78 words, and every word has a score after logsoftmax
p = torch.rand(10,78,5)
#p = p.permute(0, 2, 1)
# the gold label of one batch of 10 sentences, every sentence has 78 words, and each word have one index of 0~4 which indicate its property, the detail meaning as above 
y = torch.ones(10,78).long()
# use the NLLLoss function
loss = nn.NLLLoss()
# get the loss value
r = loss(p,y)

ValueError: Expected target size (10, 5), got torch.Size([10, 78])