In [8]:
import numpy as np
import torch
import pydub
import pypianoroll 

In [6]:
# Path to the sample MIDI file used by the demo cells below. It is a relative
# path, so it is resolved against the current working directory.
local_datapath_midi = 'roja-ar-rahman-melody.mid'

In [33]:
class MIDIClip:
    """MIDI clip object which wraps a MIDI file and converts it to tensors."""

    def __init__(self, filepath: str) -> None:
        """Remember where the MIDI file lives; the file is not read here.

        Args:
            filepath (str): Path to the MIDI file (.mid).
        """
        self.filepath = filepath
        # TODO: expose clip metadata (e.g. duration in seconds, mono flag)
        # once a MIDI backend that provides it has been chosen.

    def pianoroll(
        self,
        resolution: int = 24,
        type: str = "binary",
        tensor_type: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Convert the MIDI file to a piano roll tensor.

        The piano roll representation is a common way to represent music
        data. It is a two-dimensional matrix X of shape (h x w) where the
        `h` rows represent pitches (and special symbols in certain
        variations) and the `w` columns represent time steps. The time steps
        depend on the model and level of granularity desired. For coherence,
        many ML models filter inputs to 4/4 time signatures and use sixteenth
        notes as a single time step (16 steps / bar).
        The entries of the matrix are non-negative.

        1) Binary: the entries indicate whether a note is played at a given
           pitch and time step.

               X ∈ {0, 1}^(h×w)

           *NOTE: velocity and complex time signature differences are
           IGNORED in this version!*

        2) Velocity: the entries are in the range 0-127 and give the
           velocity of the note played at a given pitch and time step.

               X ∈ [0,127]^(h×w)

        *NOTE: The number of pitches in a MIDI file is 128, however the
        dimension `h` may be reduced to narrow the range of notes to a
        specific instrument (e.g. 88 keys in a piano) or expanded to include
        silence or rests (e.g. 129/130 or 89/90). This method does neither.*

        Only the first track of the file is converted.

        Args:
            resolution (int, optional): Time steps per beat, passed through
                to `pypianoroll.read` (default: 24).
            type (str, optional): Kind of piano roll: "binary" (0 or 1) or
                "velocity" (0-127) (default: "binary").
            tensor_type (torch.dtype, optional): dtype of the returned
                tensor; can be changed for different hardware configs
                (default: torch.float32).

        Returns:
            torch.Tensor: Piano roll of the first track with pitches on
            axis 0 and time steps on axis 1 (128 x T), in `tensor_type`.

        Raises:
            ValueError: If `type` is not "binary" or "velocity", or if the
                file contains no tracks.

        Usage:
        > midi_obj = MIDIClip(midi_path)
        > midi_obj.pianoroll()
        > [[1,..,0,...,0],
           ...,
           [0,...,1,...,0],
           ...,
           [0,...,0,...,1]]
        """
        # Explicit raise instead of `assert`: assertions are removed under
        # `python -O`, which would let an invalid `type` fall through to the
        # velocity path without any error.
        if type not in ("binary", "velocity"):
            raise ValueError("type must be either binary or velocity")

        ppr_midi = pypianoroll.read(self.filepath, resolution=resolution)
        if not ppr_midi.tracks:
            # Fail with a clear message rather than a bare IndexError below.
            raise ValueError(f"no tracks found in MIDI file: {self.filepath!r}")
        if type == "binary":
            ppr_midi.binarize()

        piano_roll = ppr_midi.tracks[0].pianoroll.astype(int)
        # Swap the two axes so that pitches come first: (T, 128) -> (128, T).
        piano_roll = np.swapaxes(piano_roll, 0, 1)
        return torch.from_numpy(piano_roll).to(dtype=tensor_type)

    def plot(self, type: str = "static") -> None:
        """Plot the clip. Placeholder: currently does nothing.

        Args:
            type (str, optional): Intended plot flavour (default: "static").
                Not used yet.
        """
        # Planned back ends:
        # - dynamic plot via bokeh using https://github.com/dubreuia/visual_midi
        # - static plot using pypianoroll: https://salu133445.github.io/pypianoroll/visualization.html
        pass


In [34]:
# Demo: load the sample MIDI file and build its velocity piano roll.
midi_obj = MIDIClip(local_datapath_midi)
pr_obj = midi_obj.pianoroll(type="velocity")
# First column of the returned tensor; per pianoroll's (128 x T) layout this
# is the value of every pitch at the first time step.
print(pr_obj[:,0])
# dtype follows pianoroll's `tensor_type` argument (default torch.float32).
print(pr_obj.dtype)


tensor([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 80.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.])
torch.float32
