# Imports

In [1]:
import numpy as np
from matplotlib import pyplot as plt
import torch.optim as optim
import kagglehub  # for dataset
import os
import string
import torch
import random
from sklearn.utils import shuffle
from collections import defaultdict
from math import ceil

In [2]:
try:
    import mido  # for parsing midi files
    from mido import MidiFile, MidiTrack, Message, MetaMessage
except ModuleNotFoundError:
    !pip install mido
    import mido
    from mido import MidiFile, MidiTrack, Message, MetaMessage

Collecting mido
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/54.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mido
Successfully installed mido-1.3.3


In [3]:
try:
    from midi2audio import FluidSynth
    from IPython import display
except ModuleNotFoundError:
    !apt install fluidsynth
    !pip install --upgrade pyfluidsynth
    !pip install midi2audio
    from midi2audio import FluidSynth
    from IPython import display

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  fluid-soundfont-gm libevdev2 libfluidsynth3 libgudev-1.0-0 libinput-bin libinput10
  libinstpatch-1.0-2 libmd4c0 libmtdev1 libqt5core5a libqt5dbus5 libqt5gui5 libqt5network5
  libqt5svg5 libqt5widgets5 libwacom-bin libwacom-common libwacom9 libxcb-icccm4 libxcb-image0
  libxcb-keysyms1 libxcb-render-util0 libxcb-util1 libxcb-xinerama0 libxcb-xinput0 libxcb-xkb1
  libxkbcommon-x11-0 qsynth qt5-gtk-platformtheme qttranslations5-l10n timgm6mb-soundfont
Suggested packages:
  fluid-soundfont-gs qt5-image-formats-plugins qtwayland5 jackd
The following NEW packages will be installed:
  fluid-soundfont-gm fluidsynth libevdev2 libfluidsynth3 libgudev-1.0-0 libinput-bin libinput10
  libinstpatch-1.0-2 libmd4c0 libmtdev1 libqt5core5a libqt5dbus5 libqt5gui5 libqt5network5
  libqt5svg5 libqt5widgets5 libwacom-bin libwacom-common libwacom9 libxcb-icc

# Midi Functions

code from: https://raw.githubusercontent.com/TianyangZhan/AutoMusicGeneration/refs/heads/master/midi_parser.py

In [4]:
# GLOBAL PARAMETERS
unit_time = 0.02  # unit: second 	# the time unit for each time slice (column in the piano roll)
highest_note = 127  # pitch value
lowest_note = 0  # pitch value
pitch_dimension = highest_note - lowest_note + 1

In [5]:
def parseMidi(midi_file):
	'''
		parse midi file into a piano roll and save temporal values

		params:		midi_file: a midi files for parsing

		output:	[pianoroll, tempo, resolution]
					pianoroll:	a matrix of size (timestep x pitch_dimension)
					tempo: the tempo value from midi file
					resolution: the resolution value from midi file

	'''

	midi_data = MidiFile(midi_file)

	# get music tempo info
	resolution = midi_data.ticks_per_beat

	track_tempos = [event.tempo for track in midi_data.tracks for event in track if str(event.type) == "set_tempo"]
	# track_tempos += [0]
	try:
		tempo = int(60000000/max(track_tempos)) # get the max track tempo
	except:
		# print(midi_file)
		tempo = 60000000

	ticks_per_time = resolution*tempo*unit_time/60.0

	#Get maximum ticks across all tracks
	total_ticks =0
	for track in midi_data.tracks:
		sum_ticks = sum([event.time for event in track if str(event.type) in ['note_on','note_off','end_of_track']])
		total_ticks = max(total_ticks,sum_ticks)

	time_slices = int(ceil(total_ticks/ticks_per_time))

	# slice file into piano roll matrix
	piano_roll = np.zeros((pitch_dimension, time_slices), dtype=int)
	note_states = defaultdict(lambda:-1)

	for track in midi_data.tracks:

		total_ticks = 0

		for event in track:

			if str(event.type) == 'note_on' and event.velocity > 0 and event.note in range(lowest_note,highest_note+1):
			# note is played

				total_ticks += event.time
				time_slice_idx = int(total_ticks/ticks_per_time)
				# count note as played
				note_idx = event.note - lowest_note
				piano_roll[note_idx][time_slice_idx] = 1
				note_states[note_idx] = time_slice_idx

			elif (str(event.type) == 'note_off' or str(event.type) == 'note_on') and event.note in range(lowest_note,highest_note+1):
			# note is not played

				total_ticks += event.time
				time_slice_idx = int(total_ticks/ticks_per_time)

				if note_states[note_idx] != -1:	 # note was played
					piano_roll[note_idx][note_states[note_idx] : time_slice_idx] = 1
					note_states[note_idx] = -1

	return piano_roll.T, tempo, resolution

In [6]:
#preprocess data directory
def getData(file_paths):
	'''
		parse midi files into a piano rolls

		params:		file_paths: paths to midi files (.midi or .mid)

		output:		pianoroll_lst:	a list of (N = #files) matrices of size (timestep x pitch_dimension)

	'''
	# print("Parsing MIDI files", file_paths)

	pianoroll_lst = []

	for path in file_paths:
		pr,_,_ = parseMidi(path) # don't need temporal values for training
		pianoroll_lst.append(pr)

	return pianoroll_lst

In [7]:
def createTrainData(pianoroll_lst, x_length, y_length, tight_window=False):
	'''
		create X and Y samples from piano roll matrix with a sliding window

		params:		pianoroll_lst: a list of piano roll matrices
					x_length: the length of input sequence. for best performance: x_length > y_length
					y_length: the length of output sequence. for best performance: y_length < x_length
					tight_window: default: False. the step size for shifting the sliding window.
								  tight_window=True: shift sliding window by y_length
								  tight_window=False: shift sliding window by x_length

		output:		[x,y]: shuffled data for training
	'''

	x = []
	y = []

	for piano_roll in pianoroll_lst:
		pos = 0
		while pos + x_length + y_length < piano_roll.shape[0]:
			x.append(piano_roll[pos:pos+x_length])
			y.append(piano_roll [pos+x_length: pos+x_length+y_length])
			if tight_window:
				pos += y_length
			else:
				pos += x_length

	return shuffle(np.array(x),np.array(y))

In [8]:
# NN output to pianoroll
def outputPianoRoll(output, note_threshold=0.1):
	'''
		convert a list of output to piano roll

		params:		output: a list of prediction result sequence
					note_threshold: default: 0.1. the threshold for a note to be played

		output:		pianoroll_lst:	a list of matrices of size (timestep x pitch_dimension)
	'''
	pianoroll_lst = []
	for sequence in output:

		for timeslice in sequence:
			result = np.zeros(timeslice.shape)
			note_on = [i for i in range(len(timeslice)) if timeslice[i] > note_threshold]
			result[note_on] = 1
			pianoroll_lst.append(result)

	return np.array(pianoroll_lst)

In [9]:
# pianoroll to MIDI
def outputMidi(output_dir, piano_roll, tempo=120, resolution=480, scale=1, velocity=65):
	'''
		convert the piano roll to midi file

		params:		output_dir: the directory to store output file
					piano_roll: a list of (N = #files) matrices of size (timestep x pitch_dimension)
					tempo: default: 120			the tempo value from midi file
					resolution: default: 480	the resolution value from midi file
					scale: default:1			the number of ticks per time slice.	for best performance: = length of sequence in one prediction
					velocity: default:65		the speed/strength to play a note
	'''


	ticks_per_time=(resolution*tempo*unit_time)/60.0

	mid = MidiFile(ticks_per_beat = int(resolution))

	track = MidiTrack()
	track.append(MetaMessage('set_tempo', tempo = int(60000000/tempo), time=0))

	note_events = ["note_off","note_on"]
	last_state = np.zeros(pitch_dimension)
	last_index = 0

	for current_index, current_state in enumerate(np.concatenate((piano_roll, last_state.reshape(1, -1)), axis=0)): # terminate note at the end

		delta = current_state - last_state
		last_state = current_state

		for i in range(len(delta)):
			if delta[i] == 1 or delta[i] == -1: # play/stop note
				event = Message(note_events[delta[i] > 0], time=int(scale*(current_index-last_index)*ticks_per_time), velocity=velocity, note=(lowest_note+i))
				track.append(event)
				last_index = current_index
			else:
				pass # don't change note state

	end = MetaMessage('end_of_track', time=1)
	track.append(end)

	mid.tracks.append(track)
	mid.save(output_dir)

In [10]:
def play_midi(path: str) -> display.Audio:
    """
    Function from CSC311 lab 6.
    Displays a playable wav of the given midi file.
    """
    FluidSynth("font.sf2").midi_to_audio(path, 'tmp.wav')
    return display.Audio("tmp.wav")

In [11]:
def display_piano_roll(piano_roll: np.ndarray, name: str) -> None:
    """
    https://medium.com/analytics-vidhya/convert-midi-file-to-numpy-array-in-python-7d00531890c
    :param piano_roll:
    """
    plt.plot(range(piano_roll.shape[0]), np.multiply(piano_roll, range(1, piano_roll.shape[1] + 1)), marker='.', markersize=1, linestyle='')
    plt.title(name)
    plt.xlabel('time step')
    plt.ylabel('note')
    plt.ylim(bottom=1)
    plt.show()

# Data

In [12]:
path = kagglehub.dataset_download("hansespinosa2/nin-video-game-midis", force_download=True)
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/hansespinosa2/nin-video-game-midis?dataset_version_number=2...


100%|██████████| 9.24M/9.24M [00:00<00:00, 71.1MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2


In [13]:
# I found some unreadable files / no name / no tempo
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Nintendo/Wii/MarioKartWii/MushroomGorgeFourHands.mid
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Nintendo/Wii/SuperPaperMario/FrancisBattleTwoPianos.mid
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Nintendo/Wii/SuperMarioGalaxy/EndTitleTwoPianos.mid
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Nintendo/SNES/LufiaTheFortressofDoom/EndingFourHands.mid
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Nintendo/SNES/Terranigma/EvergreenForest.mid
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Nintendo/NDS/Pok�monDiamondVersionPok�monPearlVersion/Route205DayTwoPianos.mid
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Other/MUL/TheLegendofZeldaBreathoftheWild/SpiritOrbObtainedFourHands.mid
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Other/MUL/SonicUnleashed/WindmillIsleDayTwoPianos.mid
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Sony/PS5/RatchetClankRiftApart/OdetoNefarious.mid
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Nintendo/GBA/MarioKartSuperCircuit/Credits.mid

# remove hidden files
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Nintendo/3DS/FireEmblemAwakening/.mid
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Other/PC/BugFablesTheEverlastingSapling/.mid

# large file
!rm /root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Other/PC/Lenen2EarthenMiraculousSword/MonoEyeIronicFATE.mid

In [14]:
def get_file_paths(path: str) -> list[str]:
    """
    Returns a list of all absolute file paths for each song.
    """
    file_paths = []
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if os.path.isdir(full_path):
            file_paths.extend(get_file_paths(full_path))
        else:
            file_paths.append(full_path)
    return file_paths

In [15]:
def sort_files_by_name(file_paths: list[str]) -> list[str]:
    """
    assume song names are unique.
    for consistent train/val/test split after setting random seed.
    """
    names = dict()
    for full_path in file_paths:
        name = ''.join(full_path.split('/')[12:15][::-1])
        names[name] = full_path
    keys = list(names.keys())
    keys.sort()
    return [names[key] for key in keys]

In [16]:
# get a list of all absolute paths to the midi files in the dataset
file_paths = get_file_paths(path + '/nin_midi_files/nin_midi_files')

# Model

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Seq2Seq(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, batch_size, sequence_length):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        super(Seq2Seq, self).__init__()

        self.encoder = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout = 0.4)  # dropout should be validated
        # self.encoder = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)  # dropout should be validated
        self.dropout = nn.Dropout(p=0.1)
        # consider birdirectional= False, bias = True, proj_size, num_layer =1
        # or consider dropout(encoder), its only between different stacked layers
        self.decoder = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        self.fc = nn.Linear(hidden_size, input_size)
        self.sigmoid = nn.Sigmoid()


    def forward(self, X, Y):
        _, (h, c) = self.encoder(X)
        output, _ = self.decoder(Y, (h, c))
        output = self.fc(output)
        # output = self.softmax(output)
        # output = self.sigmoid(output)  # TODO

        #output = self.sigmoid(output)
        #output = torch.from_numpy(np.where(output > 0.5, 1, 0)).type(torch.float32)

        return output

In [18]:
random.seed(10)
sorted_file_paths = sort_files_by_name(file_paths)
random.shuffle(sorted_file_paths)

train_files = sorted_file_paths[:3500]
valid_files = sorted_file_paths[3500:]

In [21]:
# piano_roll, tempo, resolution = parseMidi(touhou_bad_apple)
#X, Y = createTrainData([piano_roll], 500, 50)
X, Y = createTrainData(getData(train_files[0:70]), 149, 1)  # slow
# train_data_set = createTrainData(getData(train_files[0:20]), 100, 10)
#valid_data = createTrainData(getData(valid_files), 100, 10)
# train_data =  torch.from_numpy(train_data_set[0]).type(torch.float32)
# train_labels = torch.from_numpy(train_data_set[1]).type(torch.float32)
# print(train_data.shape)
# print(X)
# print(Y)
#file_paths1 = get_file_paths(path + '/nin_midi_files/nin_midi_files')
#FluidSynth("font.sf2").midi_to_audio(file_paths[0], 'abc.wav')
#display.Audio("abc.wav")
x = torch.from_numpy(X).type(torch.float32)
y = torch.from_numpy(Y).type(torch.float32)
lst = []
for batch_size in range(x.shape[0]):
  temp = np.vstack((np.array(x[batch_size]), np.array(y[batch_size])))
  #print(temp.shape)
  lst.append(temp)

print(lst[0].shape)
print(x.shape)

(150, 128)
torch.Size([4450, 149, 128])


In [24]:
# rnn = nn.LSTM(10, 20, 2)
# input = torch.randn(5, 3, 10)
# h0 = torch.randn(2, 3, 20)
# c0 = torch.randn(2, 3, 20)
# output, (hn, cn) = rnn(input, (h0, c0))
# print(output.shape)

# test the actual class
# print("X,Y: ", x.shape,y.shape)
# x -> batch size , sequence length , note vocab
# y-> batch size , target sequence length , note vocab
#model = Seq2Seq()
# encoder = nn.LSTM(128, 128, 1, batch_first=True)
# context, (h ,c) = encoder(x)
# print("context:",context[0].shape)
# decoder = nn.LSTM(128, 128, 1, batch_first=True)
# context, (h, c) = decoder(y, (h,c))
# print(h.shape)
# print(output.shape)
# fc = nn.Linear(128,128)
# output = fc(output)
# print(output.shape)
# sig = nn.Sigmoid()
# output = sig(output)
# print(output.shape)
# print(output[0][0].shape)
# print(output[0][0])
# output = np.where(output > 0.5,1, 0)
# print(output[0][0])
# print(len(output[0][0]))
#print(output[0])
#output = model(input)
#print(output.shape)

In [25]:
from torch.utils.data import Dataset, DataLoader


train_data_loader = DataLoader(lst, batch_size=100, shuffle=True)
num = 0
for data in train_data_loader:
    print(data.shape)
    print(data[:,100:,:].shape)
    print(data[:,:100,:].shape)
    if num ==4:
      break;
    num+=1

torch.Size([100, 150, 128])
torch.Size([100, 50, 128])
torch.Size([100, 100, 128])
torch.Size([100, 150, 128])
torch.Size([100, 50, 128])
torch.Size([100, 100, 128])
torch.Size([100, 150, 128])
torch.Size([100, 50, 128])
torch.Size([100, 100, 128])
torch.Size([100, 150, 128])
torch.Size([100, 50, 128])
torch.Size([100, 100, 128])
torch.Size([100, 150, 128])
torch.Size([100, 50, 128])
torch.Size([100, 100, 128])


In [26]:
def collate_data(batch, input_seq_len):
  X, labels = [], []
  for data in batch:
    #print(data.shape)
    labels.append(data[input_seq_len:,:])
    X.append(data[:input_seq_len,:])

  X = torch.from_numpy(np.array(X)).type(torch.float32)
  labels = torch.from_numpy(np.array(labels)).type(torch.float32)
  #print(X.shape ,labels.shape, "COLLATE")

  return X, labels

In [27]:
#!pip install torchmetrics
import torch
try:
    import torchmetrics
except ModuleNotFoundError:
    !pip install torchmetrics
    import torchmetrics
from torchmetrics.classification import MultilabelF1Score

def accuracy(model, dataset, input_seq_len, max=1000):
    """
    Estimate the accuracy of `model` over the `dataset`.
    We will take the **most probable class**
    as the class predicted by the model.

    Parameters:
        `model`   - An object of class nn.Module
        `dataset` - A dataset of the same type as `train_data`.
        `max`     - The max number of samples to use to estimate
                    model accuracy

    Returns: a floating-point value between 0 and 1.
    """

    correct, total = 0, 0
    dataloader = DataLoader(dataset,
                            batch_size=1,  # use batch size 1 to prevent padding
                            collate_fn= lambda batch :collate_data(batch, input_seq_len=100) )
    pred = []
    targets = []
    for i, (x, t) in enumerate(dataloader):
        z = model(x,t)
        cond = z > 0.5
        z2 =torch.sigmoid(z)
        z2 = torch.where(cond, 1, 0)
        #print(z2.shape, t.shape, "shapes")
        pred.append(z2)
        targets.append(t)




    a = torch.cat(pred, dim =0)
    b = torch.cat(targets, dim =0)
    print(a.shape, b.shape)
    metric = torchmetrics.classification.Accuracy(task = "multiclass", num_classes=128)
    acc = metric(a, b)
    return acc.item()

def accuracy2(model, dataset, input_seq_len, max=1000):
    """
    Estimate the accuracy of `model` over the `dataset`.
    We will take the **most probable class**
    as the class predicted by the model.

    Parameters:
        `model`   - An object of class nn.Module
        `dataset` - A dataset of the same type as `train_data`.
        `max`     - The max number of samples to use to estimate
                    model accuracy

    Returns: a floating-point value between 0 and 1.
    """

    correct, total = 0, 0
    dataloader = DataLoader(dataset,
                            batch_size=1,  # use batch size 1 to prevent padding
                            collate_fn= lambda batch :collate_data(batch, input_seq_len=100) )
    pred = []
    targets = []
    count  =0
    for i, (x, t) in enumerate(dataloader):
        z = model(x,t)
        #z = torch.sigmoid(z)
        #n = torch.nn.Sigmoid()
        #z = n(z)
        z = (z > 0.1).int()
        pred.append(z.squeeze(0))
        targets.append(t.squeeze(0))
        count = i
        #print(z.shape, t.shape, x.shape, "SHAPES") #  both are 1, 10, 128


    print(z)
    a = torch.cat(pred, dim =0)
    b = torch.cat(targets, dim =0)
    print(a.shape, b.shape)
    metric = torchmetrics.classification.MultilabelAccuracy(num_labels=128)
    acc = metric(a, b)
    return acc.item()
    #return np.count(a != b)

In [28]:
# device = 'cuda' if torch.cuda.is_available()  else 'cpu'

In [29]:
!pip install --upgrade torch
minecraft_sweden = '/root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Other/PC/Minecraft/Sweden.mid'
touhou_bad_apple = '/root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Other/PC/Touhou4LotusLandStory/BadApple.mid'
ffxiv_uldah = '/root/.cache/kagglehub/datasets/hansespinosa2/nin-video-game-midis/versions/2/nin_midi_files/nin_midi_files/Other/PC/FinalFantasyXIV/TheTwinFacesofFateTheThemeofUldah.mid'



In [31]:
train_data_loader = DataLoader(lst, batch_size=100, shuffle=True)
for data in train_data_loader:
  X,label = collate_data(data, 500)
  print(X.shape, label.shape)
  break;

# might have to add in sequence ength for data,target as paramter
def train(train_data, val_data, learning_rate, batch_size, num_epochs, model, input_seq_len):

    dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                            collate_fn= lambda batch :collate_data(batch, input_seq_len))
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    #loss_fn = nn.CrossEntropyLoss()  # change back to BCE later maybe

    #imbalance_ratio = 100
    positive_weight = torch.tensor([imbalance_ratio])
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=positive_weight)

    iters, train_loss, train_acc, val_acc = [], [], [], []
    iter_count = 0 # count the number of iterations that has passed
    plot_every = 2
    for epoch in range(num_epochs):
        print("Epoch: ", epoch)
        for ind, (data, labels) in enumerate(dataloader):
            z = model(data, labels)
            print(z.shape)
            print("model shape", z.shape)
            #z = np.where(z > 0.5,1, 0)
            #??z = torch.tensor(z, dtype=torch.float32,requires_grad=True)
            loss = loss_fn(z, labels)
            optimizer.zero_grad()
            loss.backward()
            #   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
            optimizer.step()
            print(data.shape, "data")
            print(labels.shape)
            print(torch.max(z[0]))
            #break;

            iter_count += 1
            if iter_count % plot_every == 0:
                iters.append(iter_count)
                ta = accuracy2(model, train_data, input_seq_len)
                #va = accuracy(model, val_data)
                train_loss.append(float(loss))
                train_acc.append(ta)
                #val_acc.append(va)
                print(iter_count, "Loss:", float(loss), "Train Acc:", ta, "Val Acc:", None)
model = Seq2Seq(128, 128, 2, 10, 10)
# model.to(device)
train_loader = DataLoader(lst,batch_size = 1, collate_fn= lambda batch :collate_data(batch, 100))
total_notes = 0
num_notes_played = 0
for (x,t) in train_loader:
  total_notes += t.numel()
  num_notes_played += t.sum().item()
imbalance_ratio = (total_notes - num_notes_played) / num_notes_played
print(imbalance_ratio)

train(lst, None, 0.001, 100, num_epochs=100, model=model, input_seq_len=149)  # last param is input len for 150 len
#data



torch.Size([100, 150, 128]) torch.Size([100, 0, 128])
140.79100762226238
Epoch:  0
torch.Size([100, 1, 128])
model shape torch.Size([100, 1, 128])
torch.Size([100, 149, 128]) data
torch.Size([100, 1, 128])
tensor(0.1052, grad_fn=<MaxBackward1>)
torch.Size([100, 1, 128])
model shape torch.Size([100, 1, 128])
torch.Size([100, 149, 128]) data
torch.Size([100, 1, 128])
tensor(0.1108, grad_fn=<MaxBackward1>)


KeyboardInterrupt: 

In [None]:
# Example tensors
y_true = torch.tensor([0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]).type(torch.float32)
y_pred = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).type(torch.float32)

loss = nn.CrossEntropyLoss()
print(loss(y_pred, y_true))

y_true = torch.tensor([0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]).type(torch.float32)
y_pred = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).type(torch.float32)

loss_bce = nn.BCEWithLogitsLoss()
print(loss_bce(y_pred, y_true))

y1 = torch.tensor([0.1,0.2,0.4,0.6])
y2 = torch.tensor([1,0,1,0])
#!pip install torcheval
#import torch, torcheval, torchmetrics
#from torcheval.metrics.functional import multiclass_f1_score

# Initialize the metric
metric = torchmetrics.classification.Accuracy(task = "multiclass", num_classes=4)
acc = metric(y1, y2)
print(acc)


In [None]:
def evaluate(model, song, pred_len):
    # song should be formatted as tensor
    # seq = song.shape[0]-10
    # a = model(song[:seq], song[seq:])
    preds = 0
    while preds < 2:
        seq = song.shape[0] - pred_len
        preds += 1
        print(song[:seq].shape)
        z = model(song[(seq -2400):seq], song[seq:])
        # z = torch.sigmoid(z)
        # cond = z > 0.5
        z = torch.sigmoid(z)
        z = (z > 0.54).int()
        # z2 = torch.where(cond, 1, 0)
        # print(z.shape)
        # print(z)
        # print(z == 1)
        song = torch.cat([song, z])
    return song


# not 1
print(valid_files[4])
#song = getData([valid_files[1]])
song, tempo, resolution = parseMidi(valid_files[4])
print(song.shape)
song = torch.tensor(song).type(torch.float32)
#song = torch.from_numpy(song)
#song = torch.from_numpy(song).type(torch.float32)
print(song.shape)
ext = evaluate(model, song.squeeze(0), 2000)
print(ext.shape)
#print(song.shape)
new_midi = outputMidi("new_midi.midi", ext.detach().numpy(), tempo = tempo, resolution = resolution)
display_piano_roll(ext.detach().numpy(), "new_midi")
play_midi("new_midi.midi")

In [None]:
def evaluate2(model, song, pred_len):
    # song should be formatted as tensor
    # seq = song.shape[0]-10
    # a = model(song[:seq], song[seq:])
    preds = 0
    # src = song[-200:-10,:]
    src = song[-1400:,:]
    tgt = song[-10:,:]
    while preds < 2500:
        seq = song.shape[0] - pred_len
        # print(song[:seq].shape)
        z = model(src, tgt)
        z = torch.softmax(z, 1)
        # cond = z > 0.5
        z = (z > 0.0085).int()
        # z = (z > 0.5).int()
        # z2 = torch.where(cond, 1, 0)
        # print(z.shape)
        # print(z)
        # print(z == 1)
        # src = torch.cat((src[-199:,:], tgt)).type(torch.float32)
        tgt = z.type(torch.float32)
        song = torch.cat([song, z]).type(torch.float32)
        preds += z.shape[0]
    return song


# not 1
print(valid_files[0])
#song = getData([valid_files[1]])
song, tempo, resolution = parseMidi(valid_files[0])
print(song.shape)
song = torch.tensor(song).type(torch.float32)
#song = torch.from_numpy(song)
#song = torch.from_numpy(song).type(torch.float32)
print(song.shape)
ext = evaluate2(model, song.squeeze(0), 10)
print(ext.shape)
new_midi = outputMidi("new_midi.midi", ext.detach().numpy(), tempo = tempo, resolution = resolution)
display_piano_roll(ext.detach().numpy(), "new_midi")
play_midi("new_midi.midi")

In [None]:
#piano_roll, tempo, resolution = parseMidi(valid_files[0])
new_midi = outputMidi("new_midi.midi", ext.detach().numpy(), tempo = tempo, resolution = resolution)
print(ext.shape)
display_piano_roll(song.detach().numpy(), "new_midi")
# play_midi("new_midi.midi")

In [None]:
play_midi(valid_files[0])

In [None]:
display_piano_roll(ext.detach().numpy(), "original")