# In[ ]:
import os
import numpy as np
from random import shuffle
from deep_pianist_identification import utils
from deep_pianist_identification.extractors import get_piano_roll
from pretty_midi import PrettyMIDI
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import groupby
from torch.utils.data import DataLoader, Dataset

# Folder holding one "{cav_number}_cav" sub-folder of voicing MIDI files per CAV
_PROJECT_ROOT = utils.get_project_root()
SPLIT_PATH = os.path.join(_PROJECT_ROOT, "references/_unsupervised_resources/voicings/midi_split/")
class VoicingLoaderReal(Dataset):
    """Piano-roll dataset built from the voicing MIDI files of a single CAV folder.

    Every MIDI file in ``SPLIT_PATH/{cav_number}_cav`` is expanded into many
    piano rolls. The expansion covers the original right-hand voicing plus each
    of its inversions. Each of those is transposed by -11..+11 semitones and
    rendered both with and without the left hand. Every item is labelled ``1``.
    """

    def __init__(self, cav_number, n_clips: int = None):
        """
        Args:
            cav_number: identifier of the CAV folder to load (``f"{cav_number}_cav"``).
            n_clips: if given, only a random subset of this many MIDI files is used.
        """
        super().__init__()
        self.cav_midis = self.get_midi_paths(cav_number, n_clips)
        # Each MIDI file yields many augmented piano rolls: flatten them into one list
        self.midis = [roll for midi in self.cav_midis for roll in self.create_midi(midi)]

    @staticmethod
    def get_midi_paths(cav_number, n_clips):
        """Return paths to the ``.mid`` files of one CAV folder, optionally a random subset.

        Raises:
            FileNotFoundError: if ``SPLIT_PATH/{cav_number}_cav`` is not a directory.
        """
        cav_path = os.path.join(SPLIT_PATH, f"{cav_number}_cav")
        # Explicit check: an ``assert`` would be stripped under ``python -O``
        if not os.path.isdir(cav_path):
            raise FileNotFoundError(f"CAV folder not found: {cav_path}")
        # Sorted so a seeded shuffle is reproducible (listdir order is arbitrary)
        cavs = [os.path.join(cav_path, i) for i in sorted(os.listdir(cav_path)) if i.endswith('.mid')]
        if n_clips is not None:
            shuffle(cavs)
            cavs = cavs[:n_clips]
        return cavs

    def __len__(self):
        return len(self.midis)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, int]:
        # Return the MIDI piano roll and the class index (i.e., True)
        return self.midis[idx], 1

    @staticmethod
    def _separate_left_hand_voices(lhand, rhand):
        """Keep one (the lowest) left-hand note per onset and move the others to the right hand.

        Args:
            lhand, rhand: lists of ``(start, end, pitch, velocity)`` note tuples.

        Returns:
            ``(new_lhand, new_rhand)``, where ``new_rhand`` is sorted by ``(start, pitch)``.
        """
        new_lhand, new_rhand = [], list(rhand)
        # groupby only merges *adjacent* items, so sort by onset (then pitch) first
        by_onset = sorted(lhand, key=lambda n: (n[0], n[2]))
        for _, group in groupby(by_onset, key=lambda n: n[0]):
            lowest, *upper = group
            # The lowest note is kept even when it is the only note at this onset
            new_lhand.append(lowest)
            new_rhand.extend(upper)
        return new_lhand, sorted(new_rhand, key=lambda n: (n[0], n[2]))

    @staticmethod
    def _invert_chords(rhand):
        """Return every inversion of a right-hand voicing, as one note list per inversion.

        Notes that share an onset form a chord. Inversion ``k`` raises the lowest
        ``k + 1`` notes of every chord by an octave, and each chord is re-sorted by
        pitch. The number of inversions is the size of the first chord minus one.
        An empty voicing has no inversions.
        """
        ordered = sorted(rhand, key=lambda n: (n[0], n[2]))
        chords = [list(grp) for _, grp in groupby(ordered, key=lambda n: n[0])]
        if not chords:
            return []
        inversions = []
        for k in range(len(chords[0]) - 1):
            voiced = []
            for chord in chords:
                # chord is already pitch-sorted, so its first k + 1 notes are the lowest
                raised = [(s, e, p + 12, v) for s, e, p, v in chord[:k + 1]]
                voiced.extend(sorted(raised + chord[k + 1:], key=lambda n: n[2]))
            inversions.append(voiced)
        return inversions

    @staticmethod
    def _space_onsets(notes, transpose, clip_length):
        """Give each onset group an equal, consecutive slot of ``[0, clip_length]`` and transpose it.

        Returns:
            ``(start, end, pitch + transpose, velocity)`` tuples in onset order.
        """
        ordered = sorted(notes, key=lambda n: n[0])
        onset_groups = [list(grp) for _, grp in groupby(ordered, key=lambda n: n[0])]
        # N groups need N + 1 boundaries; with only N, zip() silently drops the last group
        bounds = np.linspace(0, clip_length, len(onset_groups) + 1)
        return [
            (start, end, pitch + transpose, velocity)
            for start, end, group in zip(bounds, bounds[1:], onset_groups)
            for _, _, pitch, velocity in group
        ]

    def create_midi(self, cav_midi_path):
        """Yield augmented piano rolls for one voicing MIDI file.

        One roll is yielded for every combination of the following:
        - right-hand voicing (the original plus each inversion);
        - transposition (-11..+11 semitones);
        - left hand included or excluded.
        """
        pm = PrettyMIDI(cav_midi_path)
        # NOTE(review): assumes instrument 0 is the right hand and instrument 1 the left hand
        rhand = [(n.start, n.end, n.pitch, n.velocity) for n in pm.instruments[0].notes]
        lhand = [(n.start, n.end, n.pitch, n.velocity) for n in pm.instruments[1].notes]
        lhand, rhand = self._separate_left_hand_voices(lhand, rhand)
        for rhands in [rhand, *self._invert_chords(rhand)]:
            for transp in range(-11, 12):
                for include_lh in (True, False):
                    notes = rhands + lhand if include_lh else rhands
                    fmt = self._space_onsets(notes, transp, utils.CLIP_LENGTH)
                    # Skip empty renders (e.g. an empty right hand with the left hand excluded)
                    if fmt:
                        yield get_piano_roll(fmt)

class VoicingLoaderFake(VoicingLoaderReal):
    """Counter-example dataset: voicings drawn from every CAV folder *except* one.

    Uses the same augmentation as ``VoicingLoaderReal``, but labels every item ``0``.
    """

    def __init__(self, avoid_cav_number, n_clips: int = None):
        """
        Args:
            avoid_cav_number: CAV folder (``f"{avoid_cav_number}_cav"``) to exclude.
            n_clips: if given, only a random subset of this many MIDI files is used.
        """
        # get_midi_paths (overridden below) already shuffles and truncates to n_clips,
        # so self.cav_midis needs no further slicing here
        super().__init__(cav_number=avoid_cav_number, n_clips=n_clips)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, int]:
        # Counter-examples get class index 0 (the inherited method would label them 1)
        return self.midis[idx], 0

    @staticmethod
    def get_midi_paths(cav_number, n_clips):
        """Return ``.mid`` paths from every CAV folder other than ``f"{cav_number}_cav"``."""
        avoid = f"{cav_number}_cav"
        all_cavs = []
        # Sorted for a reproducible order; stray files in SPLIT_PATH are skipped
        for folder in sorted(os.listdir(SPLIT_PATH)):
            folder_path = os.path.join(SPLIT_PATH, folder)
            if folder == avoid or not os.path.isdir(folder_path):
                continue
            all_cavs.extend(
                os.path.join(folder_path, i) for i in sorted(os.listdir(folder_path)) if i.endswith('.mid')
            )
        if n_clips is not None:
            shuffle(all_cavs)
            all_cavs = all_cavs[:n_clips]
        return all_cavs


# Build the counter-example set (all CAVs except #15, 10 MIDI files) and wrap it in a loader
fake_dataset = VoicingLoaderFake(15, n_clips=10)
vl = DataLoader(fake_dataset, shuffle=True)
len(vl.dataset)