In [1]:
import torch
from tqdm import tqdm

from audiocraft.data.audio_dataset import AudioDataset, AudioMeta
import pretty_midi
import numpy as np

from midi_to_pianoroll import parse_midi, create_piano_roll, parse_midi_slice

    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.1.0+cu118)
    Python  3.8.18 (you have 3.8.10)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
args = {
    # "batch_size": 64,
    # "num_workers": 10,
    "segment_duration": 30,
    "num_samples": 500000,    
    "sample_rate": 32000,
    "channels": 1,
}

In [3]:

class MIDIAudioDataset(AudioDataset):
    def __init__(self, meta, **kwargs):
        super().__init__(meta, **kwargs)
        self.return_info = True

    def load_midi(self, midi_file_path):
        """
        Load a MIDI file and return its piano roll representation.
        """
        midi_data = pretty_midi.PrettyMIDI(midi_file_path)
        # Extract piano roll, shape [88, time_steps], then expand to [2, 88, time_steps]
        piano_roll = midi_data.get_piano_roll(fs=50)
        onset_roll = midi_data.get_onset_strengths(fs=50)
        piano_roll = np.stack([onset_roll, piano_roll], axis=0)
        return piano_roll

    def __getitem__(self, index):
        # Fetch the audio segment using parent class method
        audio_segment, segment_info = super().__getitem__(index)
        
        # Calculate the corresponding MIDI segment time
        start_time = segment_info.seek_time
        end_time = start_time + 30  # 30 seconds segment

        # Construct MIDI file path (this may vary based on your file structure)
        midi_file_path = f"{segment_info.meta.path.rsplit('.',1)[0]}.midi"
        
        # Load the MIDI file and extract the segment
        print(start_time, end_time)
        sliced_midi = parse_midi_slice(midi_file_path, start_time, end_time)
        # Convert the MIDI segment to piano roll representation
        piano_roll_segment = create_piano_roll(sliced_midi, 50)
        # Convert numpy array to torch tensor
        piano_roll_segment = torch.from_numpy(piano_roll_segment).float()

        return audio_segment, piano_roll_segment, segment_info, sliced_midi

In [4]:
midiaudiodataset = MIDIAudioDataset.from_meta("/home/jongmin/userdata/SemCodec/egs/midiaudio_test", **args)

In [5]:
midiaudiodataset.max_audio_duration

In [6]:
midiaudiodataset.segment_duration

30

In [7]:
midiaudiodataset.return_info

True

In [8]:
data = midiaudiodataset[0]

743.8253091364437 773.8253091364437


In [9]:
audio = data[0]
pr = data[1]
info = data[2]

In [10]:
audio.shape, pr.shape

(torch.Size([1, 960000]), torch.Size([2, 88, 1500]))