In [4]:
from model import ComposeNet, VOCABULARY
from dataset import MidiDataset
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import librosa 
import torch

In [11]:
# Dummy single-example dataset: every field is one int64 sequence of 255 ones.
# NOTE(review): 255 presumably has to line up with model._context_len used by MidiDataset — TODO confirm.
df = {
    field: [torch.ones(255, dtype=torch.long)]
    for field in ('notes', 'velocities', 'durations', 'times')
}

In [9]:
# Build the model with ComposeNet's default constructor arguments; no checkpoint is loaded in this cell.
# NOTE(review): there is no .to(device) call here, so the weights stay on whatever device
# ComposeNet's constructor chooses — TODO confirm; inputs must be placed on that same device.
model = ComposeNet()

In [21]:
# Smoke test: push one batch of the dummy data through the model and inspect the output shapes.
midi = MidiDataset(df, context_len = model._context_len)
train_loader = DataLoader(midi, batch_size=1)

data = next(iter(train_loader))

# Put the inputs on the device the model's weights live on rather than hard-coding "cuda",
# so the cell also runs on CPU-only machines and can never mismatch the model's placement.
device = next(model.parameters()).device
model_input = {key: data[key].to(device) for key in ("notes", "velocities", "durations")}

print("Input Tensor = ", *(t.shape for t in model_input.values()))

# Inference only: eval() switches off train-time layers (dropout etc., if the model has any),
# and no_grad() avoids building an autograd graph we never backpropagate through.
model.eval()
with torch.no_grad():
    output_notes, output_velocities, output_durations = model(model_input)

output_notes = output_notes.detach().cpu().numpy()
output_velocities = output_velocities.detach().cpu().numpy()
output_durations = output_durations.detach().cpu().numpy()

print("Output Tensor = ", output_notes.shape, output_velocities.shape, output_durations.shape)

# NOTE(review): argmax is an index into VOCABULARY; feeding it straight to midi_to_note assumes
# vocabulary index i corresponds to MIDI pitch i — confirm against VOCABULARY's layout.
print("Note = ", librosa.midi_to_note(output_notes[0].argmax()))
print("Velocity = ", output_velocities[0].argmax())
# NOTE(review): output_durations was recorded with shape (1, 255, 1), so this argmax is a position
# along the sequence axis, not a duration class — confirm the intended duration head shape.
print("Duration = ", output_durations[0].argmax())


Input Tensor =  torch.Size([1, 255]) torch.Size([1, 255]) torch.Size([1, 255])
Output Tensor =  (1, 131) (1, 128) (1, 255, 1)
Note =  C-1
Velocity =  105
Duration =  0


In [13]:
# Model Specifications
# Footprint = bytes held by every parameter tensor plus every registered buffer tensor,
# reported in MiB (divided by 1024**2).
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

total_bytes = param_size + buffer_size
size_all_mb = total_bytes / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 0.119MB


In [None]:
# One bar per vocabulary entry showing the note head's raw output for the single batch item.
# np.arange gives integer positions 0..len-1; the previous np.linspace(0, n, n) stepped by
# n/(n-1), so bars drifted away from the token indices they represent.
token_ids = np.arange(len(VOCABULARY))
fig, ax = plt.subplots()
ax.bar(token_ids, output_notes.reshape((-1,)))
ax.set(title="Note head output per vocabulary index", xlabel="Vocabulary index", ylabel="Model output")
plt.show()