<a href="https://colab.research.google.com/github/Dudestin/SuperPiano/blob/master/Super_Chamber_Piano.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Super Chamber Piano (Ver 2.0): Notewise Music NN

***

## A Mini Musical Neural Net

***

All thanks and credit for this beautiful Colab go to edtky (GitHub), whose repo and code it is based on: https://github.com/edtky/mini-musical-neural-net

***



# Model Specs and Default Parameters

## 2-Layer LSTM

### Hyperparameters (reference values; the form defaults in the *Compile the Model* cell may differ)

1. sequence_len / n_word = 512 (previous: 128)
2. batch_size / n_seq = 128
3. hidden_dim = 512
4. top_k words = 3 (previous: 5)
5. predict seq_len = 512 (previous: 1024)
6. epoch = 300

***

# System setup

In [None]:
#@title Install dependencies
!pip install git+https://github.com/kroger/pyknon.git
!pip install pretty_midi
!pip install pypianoroll
!pip install mir_eval
!apt install fluidsynth #Pip does not work for some reason. Only apt works
!pip install midi2audio
!git clone https://github.com/asigalov61/arc-diagrams
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 /content/font.sf2

# Setup modules, functions, variables, and GPU check

In [None]:
#@title Load all modules, check the available devices (GPU/CPU), and setup MIDI parameters
import numpy as np

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

import keras
from keras.utils import to_categorical

import time

import pretty_midi
from midi2audio import FluidSynth
from google.colab import output
from IPython.display import display, Javascript, HTML, Audio

import librosa
import numpy as np
import pretty_midi
import pypianoroll
from pypianoroll import Multitrack, Track
import matplotlib
import matplotlib.pyplot as plt
#matplotlib.use('SVG')
# For plotting
import mir_eval.display
import librosa.display
%matplotlib inline


from mido import MidiFile
%cd /content/arc-diagrams/
from arc_diagram import plot_arc_diagram


dtype = torch.float
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assume that we are on a CUDA machine, then this should print a CUDA device:
print('Available Device:', device)

!mkdir /content/midis

sample_freq_variable = 12 #@param {type:"number"}
note_range_variable = 62 #@param {type:"number"}
note_offset_variable = 33 #@param {type:"number"}
number_of_instruments = 2 #@param {type:"number"}
chamber_option = True #@param {type:"boolean"}

%cd /content/

# (OPTIONS)

In [None]:
#@title (OPTION 1) Convert your own MIDIs to Notewise TXT DataSet (before running this cell, upload your MIDI DataSet to /content/midis folder)
import tqdm.auto
import argparse
import random
import os
import numpy as np
from math import floor
from pyknon.genmidi import Midi
from pyknon.music import NoteSeq, Note
import music21
from music21 import instrument, volume
from music21 import midi as midiModule
from pathlib import Path
import glob, sys
from music21 import converter, instrument
# Convert every MIDI file matching `folder` into a notewise token string
# (tokens such as p22 / endp22 / v30 / wait4) and write it to a *_notewise.txt file.
# Per file: music21 stream -> (pitch, timestep, duration, instrument) tuples ->
# piano-roll array `score_arr` -> one digit string per instrument per timestep ->
# 12 shifted copies ("modulation") -> notewise tokens -> one space-separated string.
%cd /content
notes=[]
InstrumentID=0
folder = '/content/midis/*mid'
for file in tqdm.auto.tqdm(glob.glob(folder)):
    # NOTE(review): keeps only the last 53 characters of the path; a path longer than
    # that is truncated and open() below will fail. Using `file` directly looks intended.
    filename = file[-53:]
    print(filename)

    # fname = "../midi-files/mozart/sonat-3.mid"
    fname = filename

    # Read the raw MIDI file and translate it into a music21 Stream.
    mf=music21.midi.MidiFile()
    mf.open(fname)
    mf.read()
    mf.close()
    midi_stream=music21.midi.translate.midiFileToStream(mf)
    # Bare expression: no effect outside an interactive cell (inspection leftover).
    midi_stream



    # Local copies of the form parameters from the setup cell.
    sample_freq=sample_freq_variable
    note_range=note_range_variable
    note_offset=note_offset_variable
    chamber=chamber_option
    numInstruments=number_of_instruments

    s = midi_stream
    #print(s.duration.quarterLength)

    # Result unused (inspection leftover).
    s[0].elements


    # One row per time step (quarterLength * sample_freq, rounded down, plus one).
    maxTimeStep = floor(s.duration.quarterLength * sample_freq)+1
    score_arr = np.zeros((maxTimeStep, numInstruments, note_range))

    #print(maxTimeStep, "\n", score_arr.shape)

    # define two types of filters because notes and chords have different structures for storing their data
    # chords have an extra layer because they consist of multiple notes

    noteFilter=music21.stream.filters.ClassFilter('Note')
    chordFilter=music21.stream.filters.ClassFilter('Chord')


    # pitch.midi-note_offset: pitch is the MIDI number of a note; note_offset is the MIDI
    #                         pitch mapped to index 0 (default 33 = A1 -> index 0).

    # n.offset: the timestamps of each note, relative to the start of the score
    #           by multiplying with the sample_freq, you make all the timestamps integers

    # n.duration.quarterLength: the duration of that note in quarter notes (quarter = 1.0, half = 2.0)
    #                           multiply by sample_freq to represent duration in terms of timesteps

    notes = []
    instrumentID = 0
    parts = instrument.partitionByInstrument(s)
    # NOTE(review): `instru` is overwritten on every iteration, so it ends up holding the
    # instrument of the LAST part only and is then applied to every note below. If there
    # are no parts, `instru` is undefined and the chamber branch raises NameError.
    for i in range(len(parts.parts)):
      instru = parts.parts[i].getInstrument()


    for n in s.recurse().addFilter(noteFilter):
        if chamber:
          # instrument index: 0 = piano ('p'), 1 = violin ('v') -- see `instr` below
          if instru.instrumentName == 'Violin':
            notes.append((n.pitch.midi-note_offset, floor(n.offset*sample_freq),
              floor(n.duration.quarterLength*sample_freq), 1))

        # NOTE(review): no `else` -- violin notes are appended a second time as piano (0).
        notes.append((n.pitch.midi-note_offset, floor(n.offset*sample_freq),
              floor(n.duration.quarterLength*sample_freq), 0))

    #print(len(notes))
    # Result unused (inspection leftover).
    notes[-5:]

    # do the same using a chord filter

    for c in s.recurse().addFilter(chordFilter):
        # unlike the noteFilter, this line of code is necessary as there are multiple notes in each chord
        # pitchesInChord is a list of notes at each chord eg. (<music21.pitch.Pitch D5>, <music21.pitch.Pitch F5>)
        pitchesInChord=c.pitches

        # append every chord note as violin (1) ...
        # NOTE(review): done unconditionally (no chamber/instrument check), so every chord
        # note appears in BOTH instruments.
        for p in pitchesInChord:
            notes.append((p.midi-note_offset, floor(c.offset*sample_freq),
                          floor(c.duration.quarterLength*sample_freq), 1))

        # ... and again as piano (0)
        for p in pitchesInChord:
            notes.append((p.midi-note_offset, floor(c.offset*sample_freq),
                          floor(c.duration.quarterLength*sample_freq), 0))
    #print(len(notes))
    # Result unused (inspection leftover).
    notes[-5:]

    # the variable/list "notes" is a collection of all the notes in the song, not ordered in any significant way

    for n in notes:

        # pitch is the first variable in n, previously obtained by n.midi-note_offset
        pitch=n[0]

        # fold notes that fall out of the note range back in by whole octaves
        # i.e. less than 0 or greater than or equal to note_range
        while pitch<0:
            pitch+=12
        while pitch>=note_range:
            pitch-=12

        # 4th element (n[3]) is the instrument => violin notes get an extra lower bound
        if n[3]==1:      #Violin lowest note is v22 (22 + default offset 33 = MIDI 55 = G3)
            while pitch<22:
                pitch+=12

        # fill the 3-D piano roll score_arr of shape (maxTimeStep, numInstruments, note_range):
        # score_arr axis 0 = timestep
        # score_arr axis 1 = instrument index
        # score_arr axis 2 = pitch index within note_range

        # n[0] = pitch
        # n[1] = timestep
        # n[2] = duration
        # n[3] = instrument
        #print(n[3])
        # NOTE(review): the bare except silently drops any note whose indices are out of
        # bounds (e.g. instrument 1 when numInstruments == 1) and hides any other error.
        # Numpy slicing past the end does not raise, so long holds are silently clipped.
        try:
          score_arr[n[1], n[3], pitch]=1                  # Strike note
          score_arr[n[1]+1:n[1]+n[2], n[3], pitch]=2      # Continue holding note
        except:
          continue

    #print(score_arr.shape)
    # first 5 timesteps (result unused -- inspection leftover)
    score_arr[:5,0,]


    # Loop with an immediate break and no body: no effect (debug leftover).
    for timestep in score_arr:
        #print(list(reversed(range(len(timestep)))))
        break

    # Instrument index -> token prefix.
    # NOTE(review): only indices 0 and 1 are mapped; numInstruments > 2 raises KeyError below.
    instr={}
    instr[0]="p"
    instr[1]="v"

    score_string_arr=[]

    # loop through all timesteps
    for timestep in score_arr:

        # selecting the instruments: i=0 means piano and i=1 means violin
        for i in list(reversed(range(len(timestep)))):   # List violin note first, then piano note

            # one string per instrument per timestep: prefix + one digit per pitch (0/1/2)
            score_string_arr.append(instr[i]+''.join([str(int(note)) for note in timestep[i]]))

    #print(type(score_string_arr), len(score_string_arr))
    # Result unused (inspection leftover).
    score_string_arr[:5]

    modulated=[]
    # get the note range from the array (string length minus the instrument letter)
    note_range=len(score_string_arr[0])-1

    # Data augmentation: 12 transposed copies of the whole piece, concatenated one after
    # another (outer loop = transposition). Copy i shifts every note by (6 - i) semitones,
    # i.e. +6 .. -5; notes shifted outside the window are dropped.
    for i in range(0,12):
        for chord in score_string_arr:

            # minus the instrument letter eg. 'p'
            # add 6 zeros on each side of the string
            padded='000000'+chord[1:]+'000000'

            # add back the instrument letter eg. 'p'
            # append window of len=note_range back into
            # eg. if we have "00012345000"
            # iteratively, we want to get "p00012", "p00123", "p01234", "p12345", "p23450", "p34500", "p45000",
            modulated.append(chord[0]+padded[i:i+note_range])

    # len(modulated) == 12 * len(score_string_arr)
    #print(len(modulated))
    # Result unused (inspection leftover).
    modulated[:5]

    # input of this step is the list of modulated per-timestep strings
    long_string = modulated

    translated_list=[]

    # for every timestep of the string
    for j in range(len(long_string)):

        # chord at timestep j eg. 'p00000000000000000000000000000000000100'
        chord=long_string[j]
        next_chord=""

        # range is from next_timestep to max_timestep
        for k in range(j+1, len(long_string)):

            # checking if instrument of next chord is same as current chord
            if long_string[k][0]==chord[0]:

                # if same, set next chord as next element in modulation
                # otherwise, keep going until you find a chord with the same instrument
                # when you do, set it as the next chord
                next_chord=long_string[k]
                break

        # set prefix as the instrument
        # set chord and next_chord to be without the instrument prefix
        # next_chord is necessary to check when notes end
        prefix=chord[0]
        chord=chord[1:]
        next_chord=next_chord[1:]

        # checking for non-zero notes at one particular timestep
        # i is an integer indicating the index of each note the chord
        for i in range(len(chord)):

            if chord[i]=="0":
                continue

            # set note as 2 elements: instrument and index of note
            # examples: p22, p16, p4
            #p = music21.pitch.Pitch()
            #nt = music21.note.Note(p)
            #n.volume.velocity = 20
            #nt.volume.client == nt
            #V = nt.volume.velocity
            #print(V)
            #note=prefix+str(i)+' V' + str(V)
            note=prefix+str(i)

            # if note in chord is 1, then append the note eg. p22 to the list
            if chord[i]=="1":
                translated_list.append(note)

            # If chord[i]=="2" do nothing - we're continuing to hold the note

            # unless next_chord[i] is back to "0" and it's time to end the note.
            if next_chord=="" or next_chord[i]=="0":
                translated_list.append("end"+note)

        # wait marks the end of every timestep; emitted after the piano row, which is
        # listed after the violin row for the same timestep
        if prefix=="p":
            translated_list.append("wait")

    #print(len(translated_list))
    # Result unused (inspection leftover).
    translated_list[:10]

    # this section transforms the list of notes into a string of notes

    # initialize i as zero and empty string
    i=0
    translated_string=""


    while i<len(translated_list):

        # stack all the repeated waits together using an integer to indicate the no. of waits
        # eg. "wait wait" => "wait2"; a single token covers at most 2*sample_freq+1 waits,
        # longer runs continue in the next token
        wait_count=1
        if translated_list[i]=='wait':
            while wait_count<=sample_freq*2 and i+wait_count<len(translated_list) and translated_list[i+wait_count]=='wait':
                wait_count+=1
            translated_list[i]='wait'+str(wait_count)

        # add next note (every token is followed by a space)
        translated_string+=translated_list[i]+" "
        i+=wait_count

    # Results unused (inspection leftovers).
    translated_string[:100]
    len(translated_string)

    #print("chordwise encoding type and length:", type(modulated), len(modulated))
    #print("notewise encoding type and length:", type(translated_string), len(translated_string))

    # default settings: sample_freq=12, note_range=62

    # NOTE(review): output path is "../" + fname + suffix. With cwd /content and fname an
    # absolute "/content/midis/x.mid" path this resolves next to the source MIDI
    # ("/content/midis/x.mid_notewise.txt"), which the concatenation step globs for --
    # it works by accident rather than by design.
    chordwise_folder = "../"
    notewise_folder = "../"

    # export chordwise encoding
#    f=open(chordwise_folder+fname+"_chordwise"+".txt","w+")
#    f.write(" ".join(modulated))
#    f.close()

    # export notewise encoding
    # NOTE(review): not closed if write() raises; a `with` block would be safer.
    f=open(notewise_folder+fname+"_notewise"+".txt","w+")
    f.write(translated_string)
    f.close()

folder = '/content/midis/*notewise.txt'

# Concatenate every per-file notewise encoding into one training corpus, which the
# "Load and Encode" cell can then read.
# - All matching files are used: the former `glob(...)[-53:]` slice silently kept only
#   the last 53 files.
# - sorted() makes the corpus order reproducible (raw glob order is arbitrary).
# Every token in the per-file strings is followed by a space, so plain concatenation
# keeps tokens from different files separated.
with open('/content/notewise_custom_dataset.txt', 'w') as outfile:
    for fname in sorted(glob.glob(folder)):
        with open(fname) as infile:
            outfile.write(infile.read())

#folder = '/content/midis/*chordwise.txt'

#filenames = glob.glob('/content')
#with open('chordwise_custom_dataset.txt', 'w') as outfile:
#    for fname in glob.glob(folder)[-53:]:
#        with open(fname) as infile:
#            for line in infile:
#                outfile.write(line)

In [None]:
#@title (OPTION 2) Download ready-to-use Piano and Chamber Notewise DataSets
# Downloads two prebuilt notewise datasets (piano+violin and piano-only) from the
# SuperPiano GitHub repo, extracts them into /content and deletes the archives.
# NOTE(review): the archive contents are not visible here; the "Load and Encode" cell
# defaults to /content/notewise_chamber.txt -- confirm one of these zips provides it.
%cd /content/
# -nc: do not re-download when the zip already exists.
!wget -nc 'https://github.com/asigalov61/SuperPiano/raw/master/Super%20Chamber%20Piano%20Violin%20Notewise%20DataSet.zip'
# -o: overwrite previously extracted files without prompting.
!unzip -o '/content/Super Chamber Piano Violin Notewise DataSet.zip'
!rm '/content/Super Chamber Piano Violin Notewise DataSet.zip'

!wget -nc 'https://github.com/asigalov61/SuperPiano/raw/master/Super%20Chamber%20Piano%20Only%20Notewise%20DataSet.zip'
!unzip -o '/content/Super Chamber Piano Only Notewise DataSet.zip'
!rm '/content/Super Chamber Piano Only Notewise DataSet.zip'

In [None]:
#@title Load and Encode TXT Notes DataSet
select_training_dataset_file = "/content/notewise_chamber.txt" #@param {type:"string"}

# Any text file of whitespace-separated notewise tokens can be used here.
MIDI_data = select_training_dataset_file

with open(MIDI_data, 'r') as file:
    text = file.read()

# Tokenize once; the same token list feeds the vocabulary and the encoding.
dataset_tokens = text.split()

# Vocabulary: distinct tokens in sorted order, so ids are stable for a given dataset.
words = sorted(set(dataset_tokens))
n = len(words)

# token -> id and id -> token lookup tables
word2int = {word: idx for idx, word in enumerate(words)}
int2word = dict(enumerate(words))

# The whole dataset as a 1-D integer array of token ids.
encoded = np.array([word2int[token] for token in dataset_tokens])

# Main Model Setup

In [None]:
#@title Define all functions
# define model using the pytorch nn module
class WordLSTM(nn.ModuleList):
    """Two-layer LSTM next-token model over one-hot encoded notewise tokens.

    Architecture: LSTMCell(vocab -> hidden) -> LSTMCell(hidden -> hidden) ->
    Dropout(0.35) -> Linear(hidden -> vocab), unrolled manually over time.

    NOTE(review): subclasses ``nn.ModuleList`` although it is used like a plain
    ``nn.Module`` (sub-modules are registered by attribute assignment). The base class
    is left unchanged so the class's public type does not change.
    """

    def __init__(self, sequence_len, vocab_size, hidden_dim, batch_size):
        """
        Args:
            sequence_len: time steps per training sequence.
            vocab_size: number of distinct tokens (one-hot input size and logit size).
            hidden_dim: hidden/cell state size of both LSTM layers.
            batch_size: sequences per training batch (used by ``init_hidden``).
        """
        super(WordLSTM, self).__init__()

        self.vocab_size = vocab_size
        self.sequence_len = sequence_len
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim

        self.lstm_1 = nn.LSTMCell(input_size=vocab_size, hidden_size=hidden_dim)
        self.lstm_2 = nn.LSTMCell(input_size=hidden_dim, hidden_size=hidden_dim)
        self.dropout = nn.Dropout(p=0.35)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def _device(self):
        """Device of the model parameters; states and inputs are created there."""
        return next(self.parameters()).device

    def forward(self, x, hc):
        """Unroll both LSTM layers over a batch of one-hot sequences.

        Args:
            x: one-hot inputs indexed as ``x[t] -> (batch, vocab_size)``; the training
               loop passes a tensor of shape (seq_len, batch, vocab_size).
            hc: ``(h, c)`` initial states, each (batch, hidden_dim); used as the initial
                state of both layers.

        Returns:
            Logits of shape (seq_len * batch, vocab_size), time-major.
        """
        hc_1, hc_2 = hc, hc
        step_logits = []
        # Iterate over the steps actually present in x and keep everything on x's device,
        # instead of relying on self.sequence_len and the global `device`.
        for t in range(x.size(0)):
            hc_1 = self.lstm_1(x[t], hc_1)
            h_1, _ = hc_1

            hc_2 = self.lstm_2(h_1, hc_2)
            h_2, _ = hc_2

            step_logits.append(self.fc(self.dropout(h_2)))

        return torch.stack(step_logits).view(-1, self.vocab_size)

    def init_hidden(self):
        """Zero (h, c) states of shape (batch_size, hidden_dim) on the model's device."""
        dev = self._device()
        return (torch.zeros(self.batch_size, self.hidden_dim, device=dev),
                torch.zeros(self.batch_size, self.hidden_dim, device=dev))

    def init_hidden_generator(self):
        """Zero (h, c) states of shape (1, hidden_dim) for single-sequence generation.

        Created on the model's device. The former hard-coded CPU tensors failed
        against a model moved to CUDA.
        """
        dev = self._device()
        return (torch.zeros(1, self.hidden_dim, device=dev),
                torch.zeros(1, self.hidden_dim, device=dev))

    def predict(self, seed_seq, top_k=5, pred_len=128, encoder=None):
        """Prime on a seed string, then generate token ids by top-k sampling.

        Args:
            seed_seq: whitespace-separated seed tokens (the prompt).
            top_k: sample each next token among the k most probable ones
                (clamped to vocab_size).
            pred_len: number of tokens to generate after the seed.
            encoder: optional token -> id mapping; defaults to the global ``word2int``
                built by the "Load and Encode" cell.

        Returns:
            numpy float array of length ``len(seed tokens) + pred_len`` holding token
            ids: the seed ids first, then the generated ids.

        Raises:
            ValueError: if ``seed_seq`` contains no tokens.
            KeyError: if a seed token is not in the vocabulary.

        Runs without autograd and restores the module's previous train/eval mode.
        """
        if encoder is None:
            encoder = word2int

        seed_tokens = seed_seq.split()
        if not seed_tokens:
            raise ValueError("seed_seq must contain at least one token")
        seed_len = len(seed_tokens)

        out_seq = np.empty(seed_len + pred_len)
        out_seq[:seed_len] = [encoder[tok] for tok in seed_tokens]

        dev = self._device()

        def one_hot(idx):
            # (1, vocab_size) float one-hot row on the model's device.
            return F.one_hot(torch.tensor([int(idx)], device=dev),
                             num_classes=self.vocab_size).float()

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                hc_1 = hc_2 = self.init_hidden_generator()

                # Prime the hidden states on every seed token except the last ...
                for tok in seed_tokens[:-1]:
                    hc_1 = self.lstm_1(one_hot(encoder[tok]), hc_1)
                    hc_2 = self.lstm_2(hc_1[0], hc_2)

                # ... which becomes the first input of the generation loop.
                word = one_hot(encoder[seed_tokens[-1]])
                k = min(top_k, self.vocab_size)

                for t in range(pred_len):
                    hc_1 = self.lstm_1(word, hc_1)
                    hc_2 = self.lstm_2(hc_1[0], hc_2)

                    # fully connected layer without dropout at inference time
                    probs = F.softmax(self.fc(hc_2[0]), dim=1)
                    p, top_word = probs.topk(k)

                    # reshape(-1) keeps 1-D arrays even when k == 1; float64 keeps the
                    # renormalised probabilities within np.random.choice's tolerance.
                    p = p.cpu().double().numpy().reshape(-1)
                    candidates = top_word.cpu().numpy().reshape(-1)
                    next_id = np.random.choice(candidates, p=p / p.sum())

                    out_seq[seed_len + t] = next_id
                    word = one_hot(next_id)
        finally:
            self.train(was_training)

        return out_seq


def get_batches(arr, n_seqs, n_words):
    """Yield (x, y) batches of token ids for next-token training.

    The array is trimmed to a whole number of batches and reshaped to ``n_seqs`` rows,
    so each row is one contiguous token stream; batch ``b`` is the column window
    ``[b * n_words, (b + 1) * n_words)``. ``y`` is ``x`` shifted left by one token. Its
    last column is the token that follows the window in the same row, or the row's
    first token for the final window (wrap-around).

    Args:
        arr: 1-D numpy array of token ids.
        n_seqs: number of sequences per batch (aka batch_size).
        n_words: number of tokens per sequence.

    Yields:
        (x, y): integer arrays of shape (n_seqs, n_words). Nothing is yielded when arr
        holds fewer than ``n_seqs * n_words`` tokens.
    """
    batch_total = n_seqs * n_words
    n_batches = arr.size // batch_total

    # Drop the tail that does not fill a complete batch; one row per sequence.
    arr = arr[:n_batches * batch_total].reshape((n_seqs, -1))

    for n in range(0, arr.shape[1], n_words):
        x = arr[:, n:n + n_words]

        # targets are the inputs shifted by one
        y = np.zeros_like(x)
        y[:, :-1] = x[:, 1:]
        # The target for the window's last position must come from `arr`: the former
        # x[:, n + n_words] was always out of bounds (x is only n_words wide), so every
        # batch silently fell back to x[:, 0].
        if n + n_words < arr.shape[1]:
            y[:, -1] = arr[:, n + n_words]
        else:
            y[:, -1] = arr[:, 0]

        yield x, y

In [None]:
#@title Compile the Model
training_batch_size = 1024 #@param {type:"slider", min:0, max:1024, step:16}
attention_span_in_tokens = 256 #@param {type:"slider", min:0, max:512, step:64}
hidden_dimension_size = 256 #@param {type:"slider", min:0, max:512, step:64}
test_validation_ratio = 0.1 #@param {type:"slider", min:0, max:1, step:0.1}
learning_rate = 0.001 #@param {type:"number"}


# Build the two-layer LSTM; the vocabulary size comes from the encoder built when the
# dataset was loaded.
net = WordLSTM(
    sequence_len=attention_span_in_tokens,
    vocab_size=len(word2int),
    hidden_dim=hidden_dimension_size,
    batch_size=training_batch_size,
)
net.to(device)  # moves the parameters to the GPU when one is available

# Cross-entropy on the flattened logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

# Contiguous split of the token stream: the first (1 - test_validation_ratio) trains,
# the remainder validates.
val_idx = int(len(encoded) * (1 - test_validation_ratio))
train_data = encoded[:val_idx]
val_data = encoded[val_idx:]

# Accumulators filled during training.
val_losses = []
samples = []

# (OPTIONS)

In [None]:
#@title (OPTION 1) Train the Model
number_of_training_epochs = 300 #@param {type:"slider", min:1, max:300, step:1}

import tqdm

# Tokens per training sequence. This must match the model's unrolled length
# (attention_span_in_tokens). The loop previously passed hidden_dimension_size here,
# which only worked while both sliders held the same value.
tokens_per_sequence = attention_span_in_tokens


def _to_time_major_one_hot(batch):
    """(n_seqs, n_words) int token ids -> float one-hot tensor (n_words, n_seqs, vocab) on `device`."""
    one_hot = to_categorical(batch, num_classes=net.vocab_size).transpose([1, 0, 2])
    return torch.from_numpy(one_hot).to(device=device, dtype=torch.float)


def _to_flat_targets(batch):
    """(n_seqs, n_words) int token ids -> LongTensor (n_words * n_seqs,), time-major like the net's output."""
    return torch.from_numpy(batch.T).type(torch.LongTensor).contiguous().view(-1).to(device)


def _validation_loss():
    """Mean loss over all full validation batches, or None if val_data is too short for one batch."""
    net.eval()  # no dropout while measuring
    batch_losses = []
    with torch.no_grad():
        val_hc = net.init_hidden()
        for val_x, val_y in get_batches(val_data, training_batch_size, tokens_per_sequence):
            val_output = net(_to_time_major_one_hot(val_x), val_hc)
            batch_losses.append(criterion(val_output, _to_flat_targets(val_y)).item())
    net.train()
    return sum(batch_losses) / len(batch_losses) if batch_losses else None


# track time
start_time = time.time()
net.train()

for epoch in tqdm.tqdm(range(number_of_training_epochs)):

    # Zero hidden/cell states; the same initial states are used for every batch
    # (states are not carried across batches).
    hc = net.init_hidden()

    for i, (x, y) in enumerate(get_batches(train_data, training_batch_size, tokens_per_sequence)):

        x_train = _to_time_major_one_hot(x)
        targets = _to_flat_targets(y)

        optimizer.zero_grad()

        # logits of shape (n_words * n_seqs, vocab), time-major like `targets`
        output = net(x_train, hc)
        loss = criterion(output, targets)

        loss.backward()
        optimizer.step()

        # feedback every 100 batches
        if i % 100 == 0:
            val_loss = _validation_loss()
            if val_loss is not None:
                val_losses.append(val_loss)

            duration = round(time.time() - start_time, 1)
            start_time = time.time()

            # The reported loss is the training loss of the current batch
            # (it was previously mislabelled "Test Loss").
            print("Epoch: {}, Batch: {}, Duration: {} sec, Train Loss: {}, Validation Loss: {}".format(
                epoch, i, duration, loss.item(), val_loss))

# Rhythm-Aware Score Prediction Model

In [None]:
#@title My Original Time Decayed Cross-Entropy

import matplotlib.pyplot as plt
import numpy as np
import warnings

# ===============================================================
# Global Parameters & Device Setup
# ===============================================================
# Sequence lengths (from t=0) to calculate accuracy for.
# Prefix lengths (counted from t=0) at which accuracy is computed.
ACCURACY_SEQUENCE_LENGTHS = [1, 4, 16, 64, 128, 256, 512]
# How often checkpoints are saved.
save_checkpoint_every = 5

# Prefer the first CUDA GPU and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ===============================================================
# Custom Loss Function (No changes from original)
# ===============================================================
class TemporalWeightedCrossEntropyLoss(nn.Module):
    def __init__(self,
                 time_weight_type: str | None = 'linear',
                 time_weight_gamma: float = 0.9,
                 weight: torch.Tensor | None = None,
                 ignore_index: int = -100,
                 reduction: str = 'mean',
                 label_smoothing: float = 0.0):
        super().__init__()
        if time_weight_type not in ['linear', 'exponential', None]:
            raise ValueError("time_weight_type must be 'linear', 'exponential', or None")
        if reduction not in ['mean', 'sum']:
            raise ValueError("reduction must be 'mean' or 'sum'")
        self.time_weight_type = time_weight_type
        self.time_weight_gamma = time_weight_gamma
        self.reduction = reduction
        self.criterion = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_index,
            reduction='none',
            label_smoothing=label_smoothing
        )

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if input.dim() == 3:
            # input shape: [batch, seq_len, classes]
            # target shape: [batch, seq_len]
            # nn.CrossEntropyLoss expects input as [batch, classes, seq_len]
            input_permuted = input.transpose(1, 2)
            loss_per_step = self.criterion(input_permuted, target) # shape: [batch, seq_len]

            if self.time_weight_type is None:
                if self.reduction == 'mean': return loss_per_step.mean()
                else: return loss_per_step.sum()

            weights = self._get_weights(loss_per_step.shape[1], loss_per_step.device)
            weighted_loss = loss_per_step * weights

            if self.reduction == 'mean': return weighted_loss.mean()
            else: return weighted_loss.sum()

        elif input.dim() == 2:
            warnings.warn(
                "TemporalWeightedCrossEntropyLoss received a 2D input. "
                "Temporal weighting is being skipped. "
                "Standard CrossEntropyLoss is applied instead."
            )
            loss_per_element = self.criterion(input, target)
            if self.reduction == 'mean': return loss_per_element.mean()
            else: return loss_per_element.sum()
        else:
            raise ValueError(f"Unsupported input dimension: {input.dim()}. Expected 2 or 3.")

    def _get_weights(self, seq_len: int, device: torch.device) -> torch.Tensor:
        if self.time_weight_type == 'linear':
            return torch.linspace(1.0, 0.1, seq_len, device=device)
        elif self.time_weight_type == 'exponential':
            return torch.tensor(
                [self.time_weight_gamma**t for t in range(seq_len)],
                device=device, dtype=torch.float32
            )
        else:
            return torch.ones(seq_len, device=device)

    def plot_weights(self, seq_len: int = 50, save_path: str | None = None):
        weights = self._get_weights(seq_len, device='cpu').numpy()
        timesteps = np.arange(seq_len)
        plt.figure(figsize=(10, 6))
        plt.plot(timesteps, weights, marker='o', linestyle='--')
        title = f"Weight Decay Function (type: {self.time_weight_type or 'none'}"
        if self.time_weight_type == 'exponential': title += f", gamma: {self.time_weight_gamma}"
        title += ")"
        plt.title(title)
        plt.xlabel("Time Step")
        plt.ylabel("Weight")
        plt.grid(True, linestyle=':', alpha=0.6)
        plt.ylim(0, 1.1)
        if save_path:
            plt.savefig(save_path)
            print(f"Weight plot saved to {save_path}")
        else:
            plt.show()
        plt.close()



# --- Google Colabでの使用例 ---
%matplotlib inline

print("--- 1. 線形減衰のグラフ ---")
loss_fn_linear = TemporalWeightedCrossEntropyLoss(time_weight_type='linear')
# save_pathを指定しないことでインライン表示される
loss_fn_linear.plot_weights(seq_len=64)


print("\n--- 2. 指数減衰のグラフ ---")
loss_fn_exp = TemporalWeightedCrossEntropyLoss(
    time_weight_type='exponential',
    time_weight_gamma=0.9
)
loss_fn_exp.plot_weights(seq_len=64)


In [None]:
# Modified rhythm-to-score prediction model with dual mode support

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import re
import copy
import random
import os
import json
from torch.utils.tensorboard import SummaryWriter
import datetime

# Global constants
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
save_checkpoint_every = 5
# 1, followed by the powers of two from 4 up to 256.
ACCURACY_SEQUENCE_LENGTHS = [1] + [2 ** k for k in range(2, 9)]

# ===============================================================
# Data Processing Functions
# ===============================================================
def extract_rhythm_from_score(score_string):
    """Extract rhythm information from a notewise score by masking pitch information.

    Token rules (tokens are whitespace-separated):
      * ``waitN``                         -> unchanged
      * ``endp...`` / ``endv...``         -> ``endp<mask>`` / ``endv<mask>``
      * ``pN`` / ``vN`` with all-digit N  -> ``p<mask>`` / ``v<mask>``
      * anything else (incl. a bare ``end``) -> unchanged

    Returns:
        The resulting tokens joined by single spaces.
    """
    masked = []
    for token in score_string.split():
        if token.startswith('wait'):
            masked.append(token)
        elif token.startswith('end'):
            # Length check: a bare "end" token previously raised IndexError on token[3].
            if len(token) > 3 and token[3] in ('p', 'v'):
                masked.append(token[:4] + '<mask>')
            else:
                masked.append(token)
        elif token[0] in ('p', 'v') and len(token) > 1 and token[1:].isdigit():
            masked.append(token[0] + '<mask>')
        else:
            masked.append(token)
    return ' '.join(masked)

def prepare_rhythm_to_score_data(text_data, sequence_length=512):
    """Cut a token stream into overlapping, aligned (rhythm, score) training windows.

    Each window holds ``sequence_length + 1`` consecutive tokens. That gives
    ``sequence_length`` teacher-forcing steps ``(r(t+1), x(t)) -> x(t+1)``.
    Windows start every ``sequence_length // 2`` tokens (at least 1), so they
    overlap by roughly half. Tokens after the start of the last complete window
    that fall outside it are not used.

    Args:
        text_data: Whitespace-separated score tokens.
        sequence_length: Number of training steps per window.

    Returns:
        ``(rhythm_sequences, score_sequences)``: parallel lists of space-joined
        strings. Each rhythm string is the pitch-masked version of its score string.
    """
    all_tokens = text_data.strip().split()
    window = sequence_length + 1
    # sequence_length // 2 is 0 for sequence_length == 1, and range() rejects a zero step.
    stride = max(1, sequence_length // 2)

    rhythm_sequences, score_sequences = [], []
    for start in range(0, len(all_tokens) - sequence_length, stride):
        score_text = ' '.join(all_tokens[start:start + window])
        rhythm_sequences.append(extract_rhythm_from_score(score_text))
        score_sequences.append(score_text)

    print(f"Created {len(rhythm_sequences)} training sequences")
    return rhythm_sequences, score_sequences

def prepare_vocabularies(rhythm_sequences, score_sequences, use_rhythm=True):
    """Build token <-> index vocabularies for the rhythm and score streams.

    ``'<pad>'`` is always index 0 and ``'<unk>'`` index 1, and the remaining
    tokens follow in sorted order. Special tokens that also occur in the data are
    not added a second time. That keeps ``'<pad>'`` at index 0, the index the
    training loss is configured to ignore.

    Args:
        rhythm_sequences: Space-joined rhythm strings (only read when ``use_rhythm``).
        score_sequences: Space-joined score strings.
        use_rhythm: Whether to build the rhythm vocabulary.

    Returns:
        ``(rhythm2int, int2rhythm, score2int, int2score)``. The two rhythm maps are
        ``None`` when ``use_rhythm`` is False.
    """
    special_tokens = ['<pad>', '<unk>']

    def build(sequences):
        # Collect distinct tokens, excluding the reserved specials, then prepend them.
        seen = {token for seq in sequences for token in seq.split()}
        ordered = special_tokens + sorted(seen.difference(special_tokens))
        return ({token: i for i, token in enumerate(ordered)},
                dict(enumerate(ordered)))

    score2int, int2score = build(score_sequences)
    if use_rhythm:
        rhythm2int, int2rhythm = build(rhythm_sequences)
    else:
        rhythm2int, int2rhythm = None, None

    return rhythm2int, int2rhythm, score2int, int2score

# ===============================================================
# Modified Model Class with Dual Mode Support
# ===============================================================
class RhythmToScoreLSTM(nn.Module):
    """LSTM next-token model over notewise score tokens, optionally rhythm-conditioned.

    In rhythm mode (``use_rhythm=True``), each step's input is the one-hot rhythm
    token r(t+1) concatenated with the one-hot score token x(t). In score-only mode
    it is just the one-hot x(t). The model outputs logits for x(t+1).

    Args:
        rhythm_vocab_size: Size of the rhythm vocabulary. It is not part of the input
            width in score-only mode, but is still stored for checkpointing.
        score_vocab_size: Size of the score vocabulary (input and output width).
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        use_rhythm: Rhythm+score mode (True) or score-only mode (False).
    """

    def __init__(self, rhythm_vocab_size, score_vocab_size, hidden_dim, num_layers=2, use_rhythm=True):
        super(RhythmToScoreLSTM, self).__init__()
        self.rhythm_vocab_size = rhythm_vocab_size
        self.score_vocab_size = score_vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.use_rhythm = use_rhythm

        # In rhythm mode the rhythm and score one-hot vectors are concatenated.
        if use_rhythm:
            self.input_dim = rhythm_vocab_size + score_vocab_size
        else:
            self.input_dim = score_vocab_size

        # The LSTM consumes one-hot vectors directly (there is no embedding layer).
        # nn.LSTM's dropout acts between layers, so it is disabled for a single layer.
        self.lstm = nn.LSTM(self.input_dim, hidden_dim, num_layers,
                            batch_first=True, dropout=0.35 if num_layers > 1 else 0)

        # Output projection to score-vocabulary logits.
        self.fc = nn.Linear(hidden_dim, score_vocab_size)
        self.dropout = nn.Dropout(0.35)

    def _device(self):
        """Return the device that holds this model's parameters."""
        return next(self.parameters()).device

    def forward(self, rhythm_input, score_input, hidden=None):
        """
        Forward pass with optional rhythm input.

        Args:
            rhythm_input: [batch_size, seq_len] - indices for r(t+1) (ignored if use_rhythm=False)
            score_input: [batch_size, seq_len] - indices for x(t)
            hidden: Optional LSTM (h, c) state; zeros are used when None.

        Returns:
            output: [batch_size, seq_len, score_vocab_size] - logits for x(t+1)
            hidden: Updated LSTM (h, c) state
        """
        score_onehot = F.one_hot(score_input, num_classes=self.score_vocab_size).float()

        if self.use_rhythm:
            rhythm_onehot = F.one_hot(rhythm_input, num_classes=self.rhythm_vocab_size).float()
            combined_input = torch.cat([rhythm_onehot, score_onehot], dim=-1)
        else:
            # rhythm_input is ignored in score-only mode.
            combined_input = score_onehot

        if hidden is not None:
            lstm_out, hidden = self.lstm(combined_input, hidden)
        else:
            lstm_out, hidden = self.lstm(combined_input)

        output = self.fc(self.dropout(lstm_out))
        return output, hidden

    def init_hidden(self, batch_size):
        """Return zero ``(h, c)`` states of shape (num_layers, batch_size, hidden_dim).

        The states are created on the device of the model's parameters. They
        therefore match the model even when it is not on the module-level ``device``.
        """
        model_device = self._device()
        h = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=model_device)
        c = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=model_device)
        return (h, c)

    def generate(self, initial_score_token, score2int, int2score, rhythm_sequence=None,
                 rhythm2int=None, max_length=512, temperature=1.0):
        """
        Autoregressively sample a score sequence.

        Args:
            initial_score_token: Score token x(0) used to start generation. Unknown
                tokens map to '<unk>'.
            score2int, int2score: Score vocabulary mappings.
            rhythm_sequence: Space-separated rhythm tokens (required if use_rhythm=True).
                Token i+1 conditions step i, so at most len(tokens) - 1 tokens are
                generated.
            rhythm2int: Rhythm vocabulary mapping (required if use_rhythm=True).
            max_length: Maximum number of tokens to generate.
            temperature: Softmax temperature; must be > 0.

        Returns:
            The generated tokens joined by spaces. The result excludes the initial
            token, and is "" when rhythm_sequence has one token or fewer.

        Raises:
            ValueError: if temperature <= 0, or if rhythm mode lacks its inputs.

        The model's train/eval mode is restored on exit.
        """
        if temperature <= 0:
            raise ValueError("temperature must be > 0")

        was_training = self.training
        self.eval()
        try:
            if self.use_rhythm:
                if rhythm_sequence is None or rhythm2int is None:
                    raise ValueError("rhythm_sequence and rhythm2int are required when use_rhythm=True")

                rhythm_tokens = rhythm_sequence.split()
                if len(rhythm_tokens) <= 1:
                    return ""

                # Each generated step consumes the next rhythm token r(t+1).
                max_length = min(max_length, len(rhythm_tokens) - 1)

            model_device = self._device()
            current_score_idx = score2int.get(initial_score_token, score2int['<unk>'])
            generated_tokens = []
            hidden = self.init_hidden(1)

            with torch.no_grad():
                for step in range(max_length):
                    score_input = torch.tensor([[current_score_idx]], dtype=torch.long, device=model_device)

                    if self.use_rhythm:
                        rhythm_idx = rhythm2int.get(rhythm_tokens[step + 1], rhythm2int['<unk>'])
                        rhythm_input = torch.tensor([[rhythm_idx]], dtype=torch.long, device=model_device)
                    else:
                        # Placeholder; forward() ignores it in score-only mode.
                        rhythm_input = torch.zeros((1, 1), dtype=torch.long, device=model_device)

                    output, hidden = self.forward(rhythm_input, score_input, hidden)

                    logits = output.squeeze(0).squeeze(0) / temperature
                    probs = F.softmax(logits, dim=-1)
                    current_score_idx = torch.multinomial(probs, 1).item()

                    generated_tokens.append(int2score.get(current_score_idx, '<unk>'))

            return ' '.join(generated_tokens)
        finally:
            self.train(was_training)

# ===============================================================
# Training Function with Dual Mode Support
# ===============================================================
def _encode_training_pairs(rhythm_sequences, score_sequences, rhythm2int, score2int,
                           max_seq_length, use_rhythm):
    """Turn aligned (rhythm, score) strings into lists of vocabulary indices.

    Returns a list of ``(rhythm_indices, score_indices)`` tuples, where
    ``rhythm_indices`` is ``None`` in score-only mode. The following rules apply:

    * In rhythm mode, pairs whose token counts differ are dropped.
    * Pairs with fewer than two tokens are dropped.
    * Longer pairs are truncated to ``max_seq_length + 1`` tokens.
    * Unknown tokens map to ``'<unk>'``.
    """
    score_unk = score2int['<unk>']
    rhythm_unk = rhythm2int['<unk>'] if use_rhythm else None
    limit = max_seq_length + 1
    encoded = []
    for rhythm_seq, score_seq in zip(rhythm_sequences, score_sequences):
        score_tokens = score_seq.split()
        if use_rhythm:
            rhythm_tokens = rhythm_seq.split()
            if len(rhythm_tokens) != len(score_tokens):
                continue
        # One (input, target) step needs at least two tokens.
        if len(score_tokens) < 2:
            continue
        score_indices = [score2int.get(t, score_unk) for t in score_tokens[:limit]]
        if use_rhythm:
            rhythm_indices = [rhythm2int.get(t, rhythm_unk) for t in rhythm_tokens[:limit]]
        else:
            rhythm_indices = None
        encoded.append((rhythm_indices, score_indices))
    return encoded


def _collate_batch(batch_data, pad_idx_rhythm, pad_idx_score, use_rhythm):
    """Build right-padded teacher-forcing rows for one batch.

    A sequence of T+1 tokens yields T steps: rhythm input r(1..T), score input
    x(0..T-1) and target x(1..T). Returns ``(rhythm_rows, score_rows, target_rows)``,
    or ``None`` when the batch contains no trainable step.
    """
    max_len = max(len(score) - 1 for _, score in batch_data)
    if max_len <= 0:
        return None

    rhythm_rows, score_rows, target_rows = [], [], []
    for rhythm_indices, score_indices in batch_data:
        seq_len = len(score_indices) - 1
        pad_len = max_len - seq_len
        if use_rhythm:
            rhythm_rows.append(rhythm_indices[1:seq_len + 1] + [pad_idx_rhythm] * pad_len)
        else:
            rhythm_rows.append([0] * max_len)  # placeholder; ignored in score-only mode
        score_rows.append(score_indices[:seq_len] + [pad_idx_score] * pad_len)
        target_rows.append(score_indices[1:seq_len + 1] + [pad_idx_score] * pad_len)
    return rhythm_rows, score_rows, target_rows


def train_rhythm_to_score_model(model, rhythm_sequences, score_sequences,
                                rhythm2int, score2int, int2rhythm, int2score, epochs,
                                learning_rate=0.001, batch_size=32, max_seq_length=512,
                                start_epoch=0, optimizer_state=None, save_checkpoints=True,
                                use_rhythm=True):
    """Train the model in either rhythm+score or score-only mode.

    Training uses Adam, gradient-norm clipping at 1.0, and cross-entropy that ignores
    the score ``'<pad>'`` index. Per-batch and per-epoch losses and the prefix
    accuracies from ``calculate_accuracy`` are logged to TensorBoard under
    ``runs/<mode>_<timestamp>``. When ``save_checkpoints`` is set, a checkpoint is
    written every ``save_checkpoint_every`` epochs and once more after training.
    ``save_checkpoint_every`` is a module-level global read each epoch.

    Args:
        model: A ``RhythmToScoreLSTM``, expected to already be on ``device``.
        rhythm_sequences, score_sequences: Parallel lists of space-joined strings.
        rhythm2int, score2int, int2rhythm, int2score: Vocabularies (the rhythm maps
            may be ``None`` in score-only mode).
        epochs: Total epoch count; epochs ``start_epoch .. epochs - 1`` are run.
        learning_rate: Adam learning rate.
        batch_size: Sequences per optimisation step.
        max_seq_length: Sequences are truncated to ``max_seq_length + 1`` tokens.
        start_epoch: First epoch index (non-zero when resuming).
        optimizer_state: Optional Adam ``state_dict`` to resume from.
        save_checkpoints: Whether to write checkpoints.
        use_rhythm: Rhythm+score (True) or score-only (False) mode.

    Returns:
        None; the model is trained in place.
    """
    mode_str = "rhythm_score" if use_rhythm else "score_only"
    log_dir = f"runs/{mode_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = SummaryWriter(log_dir)
    print(f"TensorBoard logs will be saved to: {log_dir}")

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)
        print("Optimizer state loaded from checkpoint")

    pad_idx_score = score2int['<pad>']
    pad_idx_rhythm = rhythm2int['<pad>'] if use_rhythm else 0
    # Ignore exactly the index the targets are padded with.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx_score)

    dataset = _encode_training_pairs(rhythm_sequences, score_sequences, rhythm2int,
                                     score2int, max_seq_length, use_rhythm)

    print(f"Total training samples after filtering: {len(dataset)}")
    if len(dataset) == 0:
        print("No valid training samples found!")
        writer.close()
        return

    model.train()
    pbar = tqdm.tqdm(range(start_epoch, epochs), desc="Training Progress", unit="epoch")
    # Batches per epoch, counting the final partial batch, so resumed step numbers line up.
    batches_per_epoch = (len(dataset) + batch_size - 1) // batch_size
    global_step = start_epoch * batches_per_epoch
    avg_loss = None  # mean loss of the most recent epoch that had a successful batch

    try:
        for epoch in pbar:
            total_loss = 0.0
            num_batches = 0
            random.shuffle(dataset)

            for i in range(0, len(dataset), batch_size):
                collated = _collate_batch(dataset[i:i + batch_size],
                                          pad_idx_rhythm, pad_idx_score, use_rhythm)
                if collated is None:
                    continue

                rhythm_input_tensor, score_input_tensor, score_target_tensor = (
                    torch.tensor(rows, dtype=torch.long).to(device) for rows in collated
                )

                optimizer.zero_grad()

                try:
                    output, _ = model(rhythm_input_tensor, score_input_tensor)
                    loss = criterion(output.reshape(-1, model.score_vocab_size),
                                     score_target_tensor.reshape(-1))
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()

                    total_loss += loss.item()
                    num_batches += 1
                    global_step += 1
                    writer.add_scalar('Loss/batch', loss.item(), global_step)

                except RuntimeError as e:
                    print(f"Error in batch {i//batch_size}: {e}")
                    continue

            if num_batches > 0:
                avg_loss = total_loss / num_batches
                writer.add_scalar('Loss/epoch', avg_loss, epoch)

                # NOTE: measured on a sample of the training data, despite the "val" tag.
                accuracies = calculate_accuracy(
                    model, dataset, rhythm2int, score2int, device,
                    max_samples=500,
                    sequence_lengths=ACCURACY_SEQUENCE_LENGTHS,
                    use_rhythm=use_rhythm
                )
                for seq_len, acc in accuracies.items():
                    writer.add_scalar(f'Accuracy/val_seq_{seq_len}', acc, epoch)

                postfix_dict = {'Loss': f'{avg_loss:.4f}'}
                postfix_dict.update({f'Acc@{L}': f'{accuracies.get(L, 0):.3f}'
                                     for L in [1, 16, 64, 256] if L in accuracies})
                pbar.set_postfix(postfix_dict)

            if save_checkpoints and (epoch + 1) % save_checkpoint_every == 0:
                save_checkpoint(model, rhythm2int, int2rhythm, score2int, int2score,
                                optimizer, epoch, avg_loss if avg_loss is not None else 0,
                                use_rhythm=use_rhythm)
    finally:
        # Flush TensorBoard logs even if training is interrupted.
        pbar.close()
        writer.close()

    if save_checkpoints:
        final_loss = avg_loss if avg_loss is not None else 0
        save_checkpoint(model, rhythm2int, int2rhythm, score2int, int2score,
                        optimizer, epochs - 1, final_loss, use_rhythm=use_rhythm)

# ===============================================================
# Accuracy Calculation Function with Dual Mode Support
# ===============================================================
def calculate_accuracy(model, dataset, rhythm2int, score2int, device,
                      max_samples=500, sequence_lengths=None, use_rhythm=True):
    """Greedy next-token accuracy over the first L steps of sampled sequences.

    Up to ``max_samples`` sequences are drawn with ``random.sample`` from
    ``dataset``. Each item is a ``(rhythm_indices, score_indices)`` pair, with
    ``rhythm_indices`` unused when ``use_rhythm`` is False. Each sequence is run
    once with teacher forcing, and the argmax prediction for x(t+1) is compared with
    the true token. For each L in ``sequence_lengths``, the first ``min(L, T)``
    steps are scored, skipping positions whose target is ``'<pad>'``.

    Args:
        model: Callable as ``model(rhythm_tensor, score_tensor) -> (logits, _)``.
        dataset: List of index-sequence pairs.
        rhythm2int: Unused; kept for interface compatibility.
        score2int: Score vocabulary (only ``'<pad>'`` is read).
        device: Device the input tensors are moved to.
        max_samples: Maximum number of sequences evaluated.
        sequence_lengths: Prefix lengths L (defaults to ``ACCURACY_SEQUENCE_LENGTHS``).
        use_rhythm: Whether to feed r(t+1) or a zero placeholder.

    Returns:
        Dict mapping each L to correct/total over all scored positions, or 0 when
        nothing was scored. The model's prior train/eval mode is restored.
    """
    if sequence_lengths is None:
        sequence_lengths = ACCURACY_SEQUENCE_LENGTHS

    pad_idx_score = score2int['<pad>']
    correct_counts = {L: 0 for L in sequence_lengths}
    total_counts = {L: 0 for L in sequence_lengths}

    sample_size = min(len(dataset), max_samples)
    if sample_size == 0:
        return {L: 0 for L in sequence_lengths}
    sampled_data = random.sample(dataset, sample_size)

    # Restore the caller's mode exactly instead of forcing train() afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for rhythm_indices, score_indices in sampled_data:
                if len(score_indices) < 2:
                    continue
                seq_len = len(score_indices) - 1

                if use_rhythm:
                    rhythm_input = rhythm_indices[1:seq_len + 1]   # r(1) .. r(T)
                else:
                    rhythm_input = [0] * seq_len                   # placeholder input
                score_input = score_indices[:seq_len]              # x(0) .. x(T-1)
                score_target = score_indices[1:seq_len + 1]        # x(1) .. x(T)

                rhythm_tensor = torch.tensor([rhythm_input], dtype=torch.long).to(device)
                score_tensor = torch.tensor([score_input], dtype=torch.long).to(device)
                target = torch.tensor(score_target, dtype=torch.long).to(device)

                output, _ = model(rhythm_tensor, score_tensor)
                _, predicted = torch.max(output, dim=2)
                predicted = predicted.squeeze(0)

                valid = target != pad_idx_score
                hits = (predicted == target) & valid

                for L in sequence_lengths:
                    eval_len = min(L, seq_len)
                    if eval_len <= 0:
                        continue
                    n_valid = int(valid[:eval_len].sum().item())
                    if n_valid > 0:
                        correct_counts[L] += int(hits[:eval_len].sum().item())
                        total_counts[L] += n_valid
    finally:
        model.train(was_training)

    return {L: correct_counts[L] / total_counts[L] if total_counts[L] > 0 else 0
            for L in sequence_lengths}

# ===============================================================
# Checkpoint Functions with Mode Support
# ===============================================================
def save_checkpoint(model, rhythm2int, int2rhythm, score2int, int2score,
                    optimizer, epoch, loss, checkpoint_dir="checkpoints", use_rhythm=True):
    """Persist model/optimizer state and vocabularies so training can be resumed.

    The checkpoint is written into ``<checkpoint_dir>/<mode>/``, where mode is
    ``rhythm_score`` or ``score_only``. It is written twice:
    ``checkpoint_epoch_<epoch>.pt`` and ``latest_checkpoint.pt``.

    Returns:
        Path of the epoch-specific checkpoint file.
    """
    target_dir = os.path.join(checkpoint_dir, "rhythm_score" if use_rhythm else "score_only")
    os.makedirs(target_dir, exist_ok=True)

    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
        rhythm2int=rhythm2int,
        int2rhythm=int2rhythm,
        score2int=score2int,
        int2score=int2score,
        use_rhythm=use_rhythm,
        # Constructor arguments needed to rebuild the model in load_checkpoint.
        model_config=dict(
            rhythm_vocab_size=model.rhythm_vocab_size,
            score_vocab_size=model.score_vocab_size,
            hidden_dim=model.hidden_dim,
            num_layers=model.num_layers,
            use_rhythm=use_rhythm,
        ),
    )

    epoch_path = os.path.join(target_dir, f'checkpoint_epoch_{epoch}.pt')
    for destination in (epoch_path, os.path.join(target_dir, 'latest_checkpoint.pt')):
        torch.save(payload, destination)

    print(f"\nCheckpoint saved: {epoch_path}")
    return epoch_path

def load_checkpoint(checkpoint_path, device):
    """Rebuild a RhythmToScoreLSTM and its vocabularies from a checkpoint file.

    NOTE(review): torch.load unpickles the file; only load checkpoints from
    trusted sources.

    Returns:
        ``(model, rhythm2int, int2rhythm, score2int, int2score, checkpoint, start_epoch)``.
        ``model`` is on ``device`` and ``start_epoch`` is the saved epoch + 1.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    cfg = checkpoint['model_config']
    # Older checkpoints may lack this key; they are treated as rhythm+score models.
    use_rhythm = cfg.get('use_rhythm', True)

    model = RhythmToScoreLSTM(
        rhythm_vocab_size=cfg['rhythm_vocab_size'],
        score_vocab_size=cfg['score_vocab_size'],
        hidden_dim=cfg['hidden_dim'],
        num_layers=cfg['num_layers'],
        use_rhythm=use_rhythm
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    vocab_keys = ('rhythm2int', 'int2rhythm', 'score2int', 'int2score')
    rhythm2int, int2rhythm, score2int, int2score = (checkpoint[key] for key in vocab_keys)
    last_epoch = checkpoint['epoch']

    print(f"Checkpoint loaded from epoch {last_epoch}. Last loss: {checkpoint['loss']:.4f}")
    print(f"Model mode: {'rhythm+score' if use_rhythm else 'score-only'}")

    return model, rhythm2int, int2rhythm, score2int, int2score, checkpoint, last_epoch + 1

# ===============================================================
# Main Training Pipeline with Mode Selection
# ===============================================================
def _load_training_sequences(dataset_file, attention_span_in_tokens):
    """Read the token file and split it into aligned rhythm/score windows.

    Returns ``(sequence_length, rhythm_sequences, score_sequences)``, where the
    window length is ``attention_span_in_tokens`` capped at 1024.
    """
    with open(dataset_file, 'r', encoding='utf-8') as file:
        text_data = file.read()
    sequence_length = min(attention_span_in_tokens, 1024)
    rhythm_sequences, score_sequences = prepare_rhythm_to_score_data(text_data, sequence_length)
    return sequence_length, rhythm_sequences, score_sequences


def rhythm_to_score_training_pipeline(
    use_rhythm=True,  # Toggle between rhythm+score and score-only modes
    resume_from_checkpoint=False,
    checkpoint_path=None,
    select_training_dataset_file="path/to/your/data.txt",
    number_of_training_epochs=50,
    attention_span_in_tokens=128,
    training_batch_size=128,
    hidden_dimension_size=256,
    learning_rate=0.001
):
    """Load data, build or restore the model, and train it.

    Args:
        use_rhythm: True for rhythm+score mode, False for score-only mode. When
            resuming, the mode stored in the checkpoint takes precedence.
        resume_from_checkpoint: Try to resume from ``checkpoint_path``. A fresh model
            is trained instead if the file is missing or fails to load.
        checkpoint_path: Path to a checkpoint written by ``save_checkpoint``.
        select_training_dataset_file: UTF-8 text file of whitespace-separated tokens.
        number_of_training_epochs: Total epoch count; a resumed run continues up to it.
        attention_span_in_tokens: Training window length (capped at 1024).
        training_batch_size: Batch size (capped at the number of windows).
        hidden_dimension_size: LSTM hidden size for a freshly created model.
        learning_rate: Adam learning rate.

    Returns:
        ``(model, rhythm2int, int2rhythm, score2int, int2score)``. All five are
        ``None`` when no training windows could be built, so callers can always
        unpack the result.
    """
    print(f"Training mode: {'rhythm+score' if use_rhythm else 'score-only'}")

    # Only checkpoint *loading* falls back to fresh training; errors raised while
    # training are not masked as loading errors.
    resumed = None
    if resume_from_checkpoint:
        if checkpoint_path and os.path.exists(checkpoint_path):
            try:
                resumed = load_checkpoint(checkpoint_path, device)
            except Exception as e:
                print(f"Error loading checkpoint: {e}\nStarting fresh training instead...")
        else:
            print(f"Checkpoint not found: {checkpoint_path!r}\nStarting fresh training instead...")

    if resumed is not None:
        model, rhythm2int, int2rhythm, score2int, int2score, checkpoint, start_epoch = resumed

        # The architecture is fixed by the checkpoint, so its mode wins.
        loaded_use_rhythm = checkpoint.get('use_rhythm', True)
        if loaded_use_rhythm != use_rhythm:
            print(f"Warning: Loaded model was trained in {'rhythm+score' if loaded_use_rhythm else 'score-only'} mode, "
                  f"but requested mode is {'rhythm+score' if use_rhythm else 'score-only'}. Using loaded mode.")
            use_rhythm = loaded_use_rhythm
    else:
        print("Preparing rhythm-score pairs...")

    sequence_length, rhythm_sequences, score_sequences = _load_training_sequences(
        select_training_dataset_file, attention_span_in_tokens
    )
    if len(rhythm_sequences) == 0:
        print("Error: No training sequences created.")
        return None, None, None, None, None

    actual_batch_size = min(training_batch_size, len(rhythm_sequences))

    if resumed is not None:
        print(f"Resuming training from epoch {start_epoch}")
        optimizer_state = checkpoint['optimizer_state_dict']
    else:
        rhythm2int, int2rhythm, score2int, int2score = prepare_vocabularies(
            rhythm_sequences, score_sequences, use_rhythm=use_rhythm
        )

        if use_rhythm:
            print(f"Rhythm vocabulary size: {len(rhythm2int)}")
        print(f"Score vocabulary size: {len(score2int)}")

        # Placeholder rhythm vocabulary size in score-only mode (not used as input width).
        rhythm_vocab_size = len(rhythm2int) if use_rhythm else 1

        model = RhythmToScoreLSTM(
            rhythm_vocab_size=rhythm_vocab_size,
            score_vocab_size=len(score2int),
            hidden_dim=hidden_dimension_size,
            num_layers=2,
            use_rhythm=use_rhythm
        ).to(device)

        print(f"Model initialized on {device}")
        print(f"Starting training for {number_of_training_epochs} epochs...")
        print(f"Batch size: {actual_batch_size}")
        start_epoch = 0
        optimizer_state = None

    train_rhythm_to_score_model(
        model, rhythm_sequences, score_sequences,
        rhythm2int, score2int, int2rhythm, int2score,
        epochs=number_of_training_epochs,
        learning_rate=learning_rate, batch_size=actual_batch_size,
        max_seq_length=sequence_length, start_epoch=start_epoch,
        optimizer_state=optimizer_state,
        use_rhythm=use_rhythm
    )

    print("Training completed!")
    return model, rhythm2int, int2rhythm, score2int, int2score


## Train

In [None]:
#@title Train
# Despite the title, this cell only starts TensorBoard. It reads the logs that
# train_rhythm_to_score_model writes under runs/<mode>_<timestamp>.
%load_ext tensorboard
%tensorboard --logdir runs/

In [None]:
#@title Hyperparameter and Training (non-Rhythm)
# Intended as the score-only run (use_rhythm unchecked): predict x(t+1) from x(t) alone.

number_of_training_epochs = 500 #@param {type:"slider", min:1, max:3000, step:1}
resume_from_checkpoint = False #@param {type:"boolean"}
checkpoint_path        = "" #@param {type:"string"}
save_checkpoint_every  = 10 #@param {type:"slider", min:1, max:100, step:10}
attention_span_in_tokens = 512 #@param {type:"slider", min:1, max:512, step:1}
training_batch_size    = 1024 #@param {type:"slider", min:1, max:1024, step:1}
hidden_dimension_size  = 256 #@param {type:"slider", min:1, max:512, step:1}
learning_rate          = 0.005 #@param {type:"number"}
use_rhythm             = False #@param {type:"boolean"}


# save_checkpoint_every reassigns the module-level global that
# train_rhythm_to_score_model reads when deciding when to checkpoint.
# NOTE(review): select_training_dataset_file is not set in this section;
# presumably an earlier cell defines it.
model, rhythm2int, int2rhythm, score2int, int2score = rhythm_to_score_training_pipeline(
    use_rhythm=use_rhythm,  # False: x(t) -> x(t+1); True: (r(t+1), x(t)) -> x(t+1)
    resume_from_checkpoint=resume_from_checkpoint,
    checkpoint_path=checkpoint_path,
    select_training_dataset_file=select_training_dataset_file,
    number_of_training_epochs=number_of_training_epochs,
    attention_span_in_tokens=attention_span_in_tokens,
    training_batch_size=training_batch_size,
    hidden_dimension_size=hidden_dimension_size,
    learning_rate=learning_rate
)

In [None]:
#@title Hyperparameter and Training (Rhythm)
# Intended as the rhythm-conditioned run (use_rhythm checked): predict x(t+1) from (r(t+1), x(t)).

number_of_training_epochs = 500 #@param {type:"slider", min:1, max:3000, step:1}
resume_from_checkpoint = False #@param {type:"boolean"}
checkpoint_path        = "" #@param {type:"string"}
save_checkpoint_every  = 100 #@param {type:"slider", min:1, max:100, step:10}
attention_span_in_tokens = 512 #@param {type:"slider", min:1, max:512, step:1}
training_batch_size    = 1024 #@param {type:"slider", min:1, max:1024, step:1}
hidden_dimension_size  = 256 #@param {type:"slider", min:1, max:512, step:1}
learning_rate          = 0.005 #@param {type:"number"}
use_rhythm             = True #@param {type:"boolean"}


# save_checkpoint_every reassigns the module-level global that
# train_rhythm_to_score_model reads when deciding when to checkpoint.
# NOTE(review): select_training_dataset_file is not set in this section;
# presumably an earlier cell defines it.
model, rhythm2int, int2rhythm, score2int, int2score = rhythm_to_score_training_pipeline(
    use_rhythm=use_rhythm,  # True: condition on the rhythm token r(t+1) as well as x(t)
    resume_from_checkpoint=resume_from_checkpoint,
    checkpoint_path=checkpoint_path,
    select_training_dataset_file=select_training_dataset_file,
    number_of_training_epochs=number_of_training_epochs,
    attention_span_in_tokens=attention_span_in_tokens,
    training_batch_size=training_batch_size,
    hidden_dimension_size=hidden_dimension_size,
    learning_rate=learning_rate
)

In [None]:
#@title Upload Tensorboard logs
!tensorboard dev upload --logdir logs \
  --name "Rhythm-awared Following Score Prediction Model" \
  --description "リズム入力を考慮した後続楽譜推定モデル"

In [None]:
#@title Upload checkpoint & runs
# Zip the TensorBoard logs (runs/) and checkpoints (checkpoints/), copy both
# archives to the root of the mounted Google Drive, then delete the local zips.
# NOTE(review): assumes training ran with /content as the working directory,
# since runs/ and checkpoints/ are created relative to it. Confirm.
!zip -r /content/runs.zip /content/runs
!zip -r /content/checkpoints.zip /content/checkpoints
from google.colab import drive
drive.mount('/content/drive')
!cp /content/runs.zip /content/checkpoints.zip /content/drive/MyDrive/
!rm /content/runs.zip /content/checkpoints.zip

## Evaluate the Rhythm2Score model

In [None]:
#@title Generate TXT and MIDI file
prompt = "wait9 p<mask> p<mask> p<mask> p<mask> wait4 endp<mask>" #@param {type:"string"}

tokens_to_generate = 8192 #@param {type:"slider", min:0, max:8192, step:16}
time_coefficient   = 4 #@param {type:"slider", min:1, max:16, step:1}
top_k_coefficient  = 5 #@param {type:"slider", min:2, max:50, step:1}
%cd /content/

if model is not None:
    generated = generate_score_from_rhythm(model, prompt, rhythm2int, int2score)
with open("../content/output.txt", "w") as outfile:
    outfile.write(generated)
import tqdm
import os
import dill as pickle
from pathlib import Path
import random
import numpy as np
import pandas as pd
from math import floor
from pyknon.genmidi import Midi
from pyknon.music import NoteSeq, Note
import music21
import random
import os, argparse

# default settings: sample_freq=12, note_range=62

def decoder(filename, filedir='/content/'):
    """Decode a notewise token text file into a MIDI file with two parts.

    Args:
        filename: Name of the token file inside ``filedir``, e.g. ``'output.txt'``.
        filedir: Directory containing the token file. The ``.mid`` file is written
            into the same directory with the same stem. Defaults to ``'/content/'``.

    Token handling, as implemented below:
        * ``wait<N>`` advances the time cursor by N steps.
        * Tokens starting with ``end`` are otherwise skipped; a trailing ``eoc``
          advances the cursor by 1.
        * ``v<N>`` adds a violin note and other ``<letter><N>`` tokens add a piano
          note, both at MIDI pitch N + note_offset. The duration counts the steps
          until a matching ``end`` token or a re-strike of the same token within
          the next 199 tokens, and defaults to 12 steps otherwise.
        * ``""``, ``" "``, ``"<eos>"`` and ``"<unk>"`` are ignored.

    Relies on the notebook globals ``sample_freq_variable``,
    ``note_offset_variable`` and ``time_coefficient``.
    """
    notetxt = filedir + filename

    with open(notetxt, 'r') as file:
        notestring = file.read()

    score = notestring.split(" ")

    # Parameters shared with the encoding step (globals defined in earlier cells).
    sample_freq = sample_freq_variable
    note_offset = note_offset_variable

    # Length of one time step, in music21 quarter-length units.
    speed = time_coefficient / sample_freq
    piano_notes = []
    violin_notes = []
    time_offset = 0

    # Expand legacy "p_octave_<N>[eoc]" tokens in place into "p<N>" followed by
    # "p<N+12>[eoc]". The eoc marker moves to the added upper note.
    i = 0
    while i < len(score):
        if score[i][:9] == "p_octave_":
            add_wait = ""
            if score[i][-3:] == "eoc":
                add_wait = "eoc"
                score[i] = score[i][:-3]
            this_note = score[i][9:]
            score[i] = "p" + this_note
            score.insert(i + 1, "p" + str(int(this_note) + 12) + add_wait)
            i += 1  # skip over the inserted upper note
        i += 1

    # Loop through every event in the score.
    for i in tqdm.tqdm(range(len(score))):

        # Blank, "<eos>" and "<unk>" tokens carry no event.
        if score[i] in ["", " ", "<eos>", "<unk>"]:
            continue

        # Note-off token: only a trailing "eoc" (end of chord) advances time.
        elif score[i][:3] == "end":
            if score[i][-3:] == "eoc":
                time_offset += 1
            continue

        # "wait<N>" advances the time cursor by N steps.
        elif score[i][:4] == "wait":
            time_offset += int(score[i][4:])
            continue

        # Anything else is treated as a note-on token.
        else:
            # Look ahead (up to 199 tokens) for this note's end token, or for a
            # re-strike of the same token, summing the waits in between.
            duration = 1
            has_end = False
            note_string_len = len(score[i])
            for j in range(1, 200):
                if i + j == len(score):
                    break
                if score[i + j][:4] == "wait":
                    duration += int(score[i + j][4:])
                if score[i + j][:3 + note_string_len] == "end" + score[i] or score[i + j][:note_string_len] == score[i]:
                    has_end = True
                    break
                if score[i + j][-3:] == "eoc":
                    duration += 1

            # No matching end was found: fall back to a fixed 12-step note.
            if not has_end:
                duration = 12

            add_wait = 0
            if score[i][-3:] == "eoc":
                score[i] = score[i][:-3]
                add_wait = 1

            try:
                new_note = music21.note.Note(int(score[i][1:]) + note_offset)
                new_note.duration = music21.duration.Duration(duration * speed)
                new_note.offset = time_offset * speed
                if score[i][0] == "v":
                    violin_notes.append(new_note)
                else:
                    piano_notes.append(new_note)
            except Exception:
                # Non-numeric pitches (e.g. masked "p<mask>") cannot be decoded.
                # Catching Exception rather than using a bare except lets Ctrl-C still interrupt.
                print("Unknown note: " + score[i])

            time_offset += add_wait

    # The lists of notes for each instrument are complete at this point.
    # Note: the "piano" part is assigned music21's Accordion instrument.
    piano = music21.instrument.fromString("Accordion")
    violin = music21.instrument.fromString("Violin")

    # Each part's instrument object goes first in its stream.
    piano_notes.insert(0, piano)
    violin_notes.insert(0, violin)
    piano_stream = music21.stream.Stream(piano_notes)
    violin_stream = music21.stream.Stream(violin_notes)
    # Merge both parts into a single two-instrument stream.
    note_stream = music21.stream.Stream([piano_stream, violin_stream])

    # Write <stem>.mid next to the input file (e.g. output.txt -> output.mid).
    midi_path = filedir + os.path.splitext(filename)[0] + ".mid"
    note_stream.write('midi', fp=midi_path)
    print(f"Done! Decoded midi file saved to '{midi_path}'")


# Decode /content/output.txt into /content/output.mid, then send the MIDI file to the
# browser as a download (google.colab.files only works inside Colab).
decoder('output.txt')
from google.colab import files
files.download('/content/output.mid')

In [None]:
#@title (OPTION 1) Save the trained Model from memory
# Persist the whole in-memory model object `net` (not just a state_dict), which is
# why the OPTION 2 load cell needs torch.load(..., weights_only=False).
# Despite the ".h5" extension this file is written by torch.save (PyTorch's own
# serialization format), not Keras/HDF5.
torch.save(obj=net, f='/content/trained_model.h5')

In [None]:
#@title (OPTION 2) Download Super Chamber Piano Pre-Trained Chamber Model
# Download the pre-trained model into /content/ (the cwd after %cd), where the
# load cell below expects /content/trained_model.h5.
# NOTE(review): if /content/trained_model.h5 already exists (e.g. saved via
# OPTION 1 or by an earlier run of this cell), wget keeps the existing file and
# saves this download as trained_model.h5.1, so the load cell would NOT pick it up.
%cd /content/
!wget 'https://github.com/asigalov61/SuperPiano/raw/master/trained_model.h5'

In [None]:
#@title (OPTION 2) Load existing/pre-trained Model checkpoint
# The checkpoint holds the whole model object (see the OPTION 1 save cell), not a
# state_dict, so full unpickling (weights_only=False) is required; all tensors are
# mapped onto the CPU regardless of the device they were saved from.
# SECURITY(review): unpickling can execute arbitrary code; only load checkpoints
# from a trusted source (your own OPTION 1 save or the SuperPiano repo download).
model = torch.load(f='/content/trained_model.h5',
                   map_location=torch.device('cpu'),
                   weights_only=False)
# Switch to inference mode before generating (disables dropout etc.).
model.eval()

# Generate, plot, and listen to the output

In [None]:
#@title Generate TXT and MIDI file
seed_prompt = "p24" #@param {type:"string"}
tokens_to_generate = 8192 #@param {type:"slider", min:0, max:8192, step:16}
time_coefficient = 4 #@param {type:"slider", min:1, max:16, step:1}
top_k_coefficient = 5 #@param {type:"slider", min:2, max:50, step:1}
%cd /content/
with open("../content/output.txt", "w") as outfile:
    outfile.write(' '.join([int2word[int(int_)] for int_ in model.predict(seed_seq=seed_prompt, pred_len=tokens_to_generate, top_k=top_k_coefficient)]))
import tqdm
import os
import dill as pickle
from pathlib import Path
import random
import numpy as np
import pandas as pd
from math import floor
from pyknon.genmidi import Midi
from pyknon.music import NoteSeq, Note
import music21
import random
import os, argparse

# default settings: sample_freq=12, note_range=62

def _strip_eoc(token):
    """Return ``token`` without its trailing "eoc" marker, if it has one."""
    return token[:-3] if token[-3:] == "eoc" else token


def _expand_octave_tokens(tokens):
    """Return a copy of ``tokens`` with every ``p_octave_<N>`` split into two notes.

    ``p_octave_<N>`` becomes ``p<N>`` followed by ``p<N+12>``. An ``eoc`` suffix on
    the original token is carried over to the second (upper) note only.
    """
    expanded = []
    for token in tokens:
        if token[:9] != "p_octave_":
            expanded.append(token)
            continue
        suffix = "eoc" if token[-3:] == "eoc" else ""
        pitch = _strip_eoc(token)[9:]
        expanded.append("p" + pitch)
        expanded.append("p" + str(int(pitch) + 12) + suffix)
    return expanded


def _note_duration(score, i, max_lookahead=200, default_duration=12):
    """Return how many time steps the note token ``score[i]`` lasts.

    Scans the next ``max_lookahead - 1`` tokens. It accumulates ``wait<N>``
    amounts plus one step per ``eoc``-suffixed token. It stops at this note's
    ``end`` token or at a re-attack of the same note. Returns
    ``default_duration`` if neither appears in the window.
    """
    # Compare tokens exactly once "eoc" is stripped from both sides. A plain
    # prefix test would let "p2" be terminated by "p24"/"endp24". Also, stripping
    # only after the scan meant "p24eoc" could never match "endp24".
    note = _strip_eoc(score[i])
    duration = 1
    for token in score[i + 1:i + max_lookahead]:
        if token[:4] == "wait":
            duration += int(token[4:])
        base = _strip_eoc(token)
        if base == note or base == "end" + note:
            return duration
        if token[-3:] == "eoc":
            duration += 1
    return default_duration


def decoder(filename, filedir='/content/', piano_instrument="Accordion",
            violin_instrument="Violin"):
    """Decode a space-separated notewise token file into a two-part MIDI file.

    Reads ``filedir + filename`` and writes
    ``filedir + <filename without extension> + ".mid"``.

    Token handling:
      * ``""``, ``" "``, ``"<eos>"`` and ``"<unk>"`` are skipped.
      * ``wait<N>`` advances the time cursor by N steps.
      * ``end...`` tokens only terminate notes (found by look-ahead). With an
        ``eoc`` suffix they also advance time by one step.
      * ``p_octave_<N>`` is expanded into ``p<N>`` and ``p<N+12>`` first.
      * Any other token is a note. A leading ``v`` selects the violin part,
        anything else the piano part. The remaining characters are the pitch,
        shifted by ``note_offset_variable``. An ``eoc`` suffix advances time by
        one step after the note.

    Steps become music21 quarter lengths via
    ``time_coefficient / sample_freq_variable``. Both of those and
    ``note_offset_variable`` are read as module-level globals.

    Args:
        filename: name of the token file inside ``filedir``.
        filedir: directory for both the input text and the output MIDI file.
        piano_instrument: music21 instrument name for the non-violin part.
            The default "Accordion" preserves this notebook's historic choice.
        violin_instrument: music21 instrument name for the ``v`` part.
    """
    with open(filedir + filename, 'r') as token_file:
        score = _expand_octave_tokens(token_file.read().split(" "))

    # Parameters shared with the encoding step (notebook globals).
    sample_freq = sample_freq_variable
    note_offset = note_offset_variable
    # One time step lasts this many music21 quarter lengths.
    speed = time_coefficient / sample_freq

    piano_notes = []
    violin_notes = []
    time_offset = 0

    for i in tqdm.tqdm(range(len(score))):
        token = score[i]

        # Padding / special tokens carry no musical information.
        if token in ["", " ", "<eos>", "<unk>"]:
            continue

        # End-of-note tokens matter only to the look-ahead in _note_duration.
        # An "eoc" on them still advances time by one step.
        if token[:3] == "end":
            if token[-3:] == "eoc":
                time_offset += 1
            continue

        if token[:4] == "wait":
            time_offset += int(token[4:])
            continue

        duration = _note_duration(score, i)
        note_token = _strip_eoc(token)
        try:
            new_note = music21.note.Note(int(note_token[1:]) + note_offset)
            new_note.duration = music21.duration.Duration(duration * speed)
            new_note.offset = time_offset * speed
        except Exception:
            # Best effort: report malformed tokens (e.g. a non-numeric pitch) and
            # keep decoding. `Exception`, not a bare except, so Ctrl-C still works.
            print("Unknown note: " + note_token)
        else:
            if note_token[0] == "v":
                violin_notes.append(new_note)
            else:
                piano_notes.append(new_note)

        # An "eoc" suffix on a note advances time by one step after it.
        if token[-3:] == "eoc":
            time_offset += 1

    # Each part becomes its own sub-stream, headed by its instrument object.
    piano_notes.insert(0, music21.instrument.fromString(piano_instrument))
    violin_notes.insert(0, music21.instrument.fromString(violin_instrument))
    note_stream = music21.stream.Stream([
        music21.stream.Stream(piano_notes),
        music21.stream.Stream(violin_notes),
    ])

    out_path = filedir + os.path.splitext(filename)[0] + ".mid"
    note_stream.write('midi', fp=out_path)
    print(f"Done! Decoded midi file saved to '{out_path}'")


# Decode '/content/output.txt' into '/content/output.mid' with the `decoder`
# defined above, then offer the MIDI file as a browser download (Colab only).
decoder('output.txt')
from google.colab import files
files.download('/content/output.mid')

In [None]:
#@title Plot, Graph, and Listen to the Output :)
graphs_length_inches = 18 #@param {type:"slider", min:0, max:20, step:1}
notes_graph_height = 6 #@param {type:"slider", min:0, max:20, step:1}
highest_displayed_pitch = 92 #@param {type:"slider", min:1, max:128, step:1}
lowest_displayed_pitch = 24 #@param {type:"slider", min:1, max:128, step:1}
# Width/height (inches) are used for both figures below. The pitch bounds are
# MIDI note numbers passed to plot_piano_roll as a [lowest, highest) slice, so
# the highest pitch itself is not drawn.
# NOTE(review): a 0-inch slider value would request a zero-size figure; min:1 may be safer.

%cd /content/

# Parse the MIDI file written by the "Generate TXT and MIDI file" cell.
midi_data = pretty_midi.PrettyMIDI('/content/output.mid')

def plot_piano_roll(pm, start_pitch, end_pitch, fs=100, ax=None):
    """Draw the piano roll of ``pm`` for MIDI pitches ``[start_pitch, end_pitch)``.

    Args:
        pm: object exposing ``get_piano_roll(fs)`` (a ``pretty_midi.PrettyMIDI``
            in this notebook).
        start_pitch: lowest MIDI pitch shown (inclusive). It also sets the y-axis
            ``fmin`` so the note labels line up.
        end_pitch: upper MIDI pitch bound (exclusive, slice semantics).
        fs: piano-roll frames per second. It is passed as ``sr`` with
            ``hop_length=1`` so the x axis is in seconds.
        ax: optional matplotlib Axes to draw into. The default None lets
            librosa use the current Axes, as before.

    Returns:
        The artist returned by ``librosa.display.specshow``. The previous version
        discarded it, so callers always received None.
    """
    roll = pm.get_piano_roll(fs)[start_pitch:end_pitch]
    return librosa.display.specshow(roll,
                                    hop_length=1, sr=fs, x_axis='time', y_axis='cqt_note',
                                    fmin=pretty_midi.note_number_to_hz(start_pitch),
                                    ax=ax)



# Plot the output: a pypianoroll multitrack view, then a librosa piano roll.
# In pypianoroll >= 1.0 (what an unpinned `pip install` gives), files are read
# with `pypianoroll.read`, and the first positional argument of `Multitrack(...)`
# is a track *name*. Fall back to the legacy path constructor on old 0.5.x installs.
if hasattr(pypianoroll, 'read'):
    track = pypianoroll.read('/content/output.mid')
else:
    track = Multitrack('/content/output.mid')
print(track)
# Legacy plot() returns (fig, axs), while >= 1.0 returns a list of Axes. Both
# draw into a fresh figure, so fetch it with gcf() to resize it. No
# plt.figure() is needed beforehand; that only produced an orphan empty figure.
track.plot()
fig = plt.gcf()
fig.set_size_inches(graphs_length_inches, notes_graph_height)

# specshow draws on the current Axes, so open a dedicated figure for it.
plt.figure(figsize=[graphs_length_inches, notes_graph_height])
plot_piano_roll(midi_data, int(lowest_displayed_pitch), int(highest_displayed_pitch))
plt.show(block=False)

# Render the MIDI to audio with the FluidR3 soundfont and play it inline.
FluidSynth("/content/font.sf2", 16000).midi_to_audio('/content/output.mid', '/content/output.wav')
Audio('/content/output.wav', rate=16000)

In [None]:
#@title Reward yourself by making a nice Arc diagram from the generated output/MIDI file
# Work inside the cloned arc-diagrams repo; the diagram is later downloaded from
# /content/arc-diagrams/output.png.
%cd '/content/arc-diagrams'

# MIDI file to visualise and the title drawn on the diagram.
midi_file = '/content/output.mid'
plot_title = "Super Chamber Piano Output Arc Diagram"

# Alternative input (a path relative to the arc-diagrams repo, presumably
# shipped with it): uncomment to plot a classic piece instead.
# midi_file = 'midis/fuer_elise.mid'
# plot_title = "Für Elise (Beethoven)"


def stringify_notes(midi_file, track_number):
    """Encode one MIDI track's note events as a compact string.

    Every ``note_on`` message contributes ``"<note>n"`` and every ``note_off``
    message contributes ``"<note>f"``, in track order. All other message types are
    ignored. A ``note_on`` with velocity 0 is still written as ``n``.

    Args:
        midi_file: path to a MIDI file readable by ``mido.MidiFile``.
        track_number: index into the file's track list (negative indices count
            from the end).

    Returns:
        The encoded string, which is '' for a track without note events.
    """
    # Only the requested track is encoded; previously every track was
    # stringified via repeated `+=` and all but one were thrown away.
    track = MidiFile(midi_file).tracks[track_number]
    suffix_for_type = {'note_on': 'n', 'note_off': 'f'}
    return ''.join(
        f'{msg.note}{suffix_for_type[msg.type]}'
        for msg in track
        if msg.type in suffix_for_type
    )
# Best-effort plot and download. Failures are reported (with their cause) but
# do not stop the notebook.
# NOTE(review): depending on the MIDI writer, track 0 may be a tempo/conductor
# track with no notes, which gives an empty string here; confirm which track to use.
try:
    plot_arc_diagram(stringify_notes(midi_file, 0), plot_title)
    from google.colab import files
    files.download('/content/arc-diagrams/output.png')
except Exception as err:  # not a bare except: keep Ctrl-C working
    print('Could not plot the diagram. Try again/another composition.')
    print(f'Reason: {type(err).__name__}: {err}')



# Save what you want to Google Drive (standard Google Drive mount code)

In [None]:
# Mount Google Drive under /content/drive so generated files can be copied there.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Diagnostic: print the version of the `python3` found on the shell PATH
# (NOTE(review): this may not be the interpreter running this kernel).
!python3 --version

In [None]:
# Diagnostic: dump the id -> token mapping used to turn model output into text.
print (int2word)