In [25]:
import os
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import random
from pytorch_lightning.loggers import WandbLogger
import wandb
import auraloss
import collections
from tqdm import tqdm
import pretty_midi
import matplotlib.pyplot as plt
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
print(torch.cuda.is_available())
import plotly.graph_objects as go
from torch.optim import lr_scheduler
from IPython.display import Audio
from torchaudio.transforms import Fade
import musdb
import museval
import gc
import pandas as pd
import sklearn
from sklearn.metrics import recall_score, precision_score, f1_score, jaccard_score, accuracy_score,zero_one_loss

True


In [2]:
class Track:
    def __init__(self, name, midi_path, drum_path, mix_path):
        self.name = name
        self.midi_path = midi_path
        self.drum_path = drum_path
        self.mix_path = mix_path
        self.targets = {'drums': '', 'bass': ''}
        self.rate = 44100
        self.subset = 'test'

In [3]:
class AudioData:
    def __init__(self, audio):
        self.audio = audio

In [4]:
seed_value = 3407
torch.manual_seed(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.set_float32_matmul_precision('high')

In [5]:
mix_folder = 'D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/audio/full_mix/'
mixes = os.listdir(mix_folder)
mixes = [mix_folder + m for m in mixes]

drum_folder = 'D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/audio/drum_only/'
drum = os.listdir(drum_folder)
drum = [drum_folder + d for d in drum]

beats_folder = 'D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/annotations/beats/'
beats = os.listdir(beats_folder)
beats = [beats_folder + b for b in beats]#

class_folder = 'D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/annotations/subclass/'
classes = os.listdir(class_folder)
classes = [class_folder + c for c in classes]

midi_folder = 'D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/midi/'
midis = os.listdir(midi_folder)
midis = [midi_folder + m for m in midis]

In [6]:
all_tracks = []
for idx, val in tqdm(enumerate(classes)):

    name = val.replace('D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/annotations/subclass/', '')
    name = name.replace('_subclass.txt', '')

    t = Track(name, midis[idx], drum[idx], mixes[idx])
    all_tracks.append(t)

23it [00:00, ?it/s]


In [7]:
def turn_transcription_into_roll(transcription, frames):
    # Determine your sampling frequency (frames per second)
    fs = 44100
    
    piano_roll_length = int(frames)
    
    # Initialize the piano roll array
    piano_roll = np.zeros((64, piano_roll_length))
    
    # Fill in the piano roll array
    for note in transcription.instruments[0].notes:
        # Convert start and end times to frame indices
        start_frame = int(np.floor(note.start * fs))
        end_frame = int(np.ceil(note.end * fs))
        
        # Set the corresponding frames to 1 (or note.velocity for a velocity-sensitive representation)
        piano_roll[note.pitch, start_frame:end_frame] = 1  # Or use note.velocity
        
    roll = np.vstack([piano_roll[35:36, :], piano_roll[38:39, :], piano_roll[42:43, :], piano_roll[47:48, :], piano_roll[49:50, :]])
    return roll

In [8]:
def expand(x, out_size=44100*4, step=4410):
    output_tensor = torch.zeros((5, out_size))
    for i in range(x.shape[1]):  # Iterate over the second dimension
        start_idx = i * step
        end_idx = start_idx + step
        output_tensor[:, start_idx:end_idx] = x[:, i].unsqueeze(1)
    return output_tensor

def compress(x, original_shape=(5, 40), step=4410):
    """
    Compresses a tensor from a larger size to its original smaller size by averaging blocks of values.
    
    Args:
    - x (Tensor): The input tensor to be compressed, expected to have the shape (5, 44100) or similar.
    - original_shape (tuple): The shape of the output tensor, default is (5, 40).
    - step (int): The size of the block to average over, default is 4410.
    
    Returns:
    - Tensor: The compressed tensor with shape specified by `original_shape`.
    """
    output_tensor = torch.zeros(original_shape)
    for i in range(original_shape[1]):  # Iterate over the second dimension of the target shape
        start_idx = i * step
        end_idx = start_idx + step
        # Take the mean of each block and assign it to the corresponding position in the output tensor
        output_tensor[:, i] = x[:, start_idx:end_idx].mean(dim=1)
    return output_tensor


In [9]:
def load_audio(path):
    audio_tensors = []
    waveform, _ = torchaudio.load(path)
    return waveform

def load_roll(path, frames):
    transcription = pretty_midi.PrettyMIDI(path)
    roll = turn_transcription_into_roll(transcription, frames)

    return torch.from_numpy(roll).float()


In [26]:
out_dir = f"D:/Github/phd-drum-sep/model-as-adt/results_ht_epoch_280/adt/"
rows = []
for track in tqdm(all_tracks):

    mixture_tensor = load_audio(track.mix_path)
    shape = mixture_tensor.shape[1]
    
    snippet_length = (mixture_tensor.shape[1] // (44100 * 4)) * (44100 * 4)
    mixture_tensor = mixture_tensor[:, :snippet_length]

    roll_tensor = load_roll(track.midi_path, shape)
    roll_tensor = roll_tensor[:, :snippet_length]

    proposed_answers = []

    track_dir = f'{out_dir}{track.name}'
    track_folder = os.listdir(track_dir)
    segments = len(track_folder) - 2

    chunk_len = int(44100 * 4)

    for i in range(segments):
        proposed_transcription = pd.read_csv(f'{track_dir}/{i}.csv')
        proposed_transcription = np.asarray(proposed_transcription)
        proposed_transcription = np.asarray(proposed_transcription.T)[1:].T
        proposed_transcription = expand(torch.from_numpy(proposed_transcription))

        start = i * chunk_len
        end = start + chunk_len

        drum_chunk_ = roll_tensor[:, start:end].numpy()
        proposed_transcription_ = proposed_transcription.numpy()

        for drum in range(5):
            drum_chunk = drum_chunk_[drum, :]
            proposed_transcription = proposed_transcription_[0, :]
            recall = recall_score(drum_chunk, proposed_transcription, average='macro', zero_division=0)
            precision = precision_score(drum_chunk, proposed_transcription, average='macro', zero_division=0)
            f1 = f1_score(drum_chunk, proposed_transcription, average='macro')
            jaccard = jaccard_score(drum_chunk, proposed_transcription, average='macro')
            acc = accuracy_score(drum_chunk, proposed_transcription)
            zo = zero_one_loss(drum_chunk, proposed_transcription)
            rows.append([track.name, i, drum, recall, precision, f1, jaccard, acc, zo])


    

    

  0%|                                                                                                                                                                                                                | 0/23 [00:07<?, ?it/s]


In [27]:
df_results = pd.DataFrame(rows, columns=['track_name', 'slice', 'drum', 'recall', 'precision', 'f1', 'jaccard', 'accuracy', 'zero_one'])

Unnamed: 0,track_name,slice,drum,recall,precision,f1,jaccard,accuracy,zero_one
0,MusicDelta_80sRock,0,0,0.5,0.399977,0.44443,0.399977,0.799955,0.200045
1,MusicDelta_80sRock,0,1,0.5,0.449989,0.473678,0.449989,0.899977,0.100023
2,MusicDelta_80sRock,0,2,1.0,1.0,1.0,1.0,1.0,0.0
3,MusicDelta_80sRock,0,3,1.0,1.0,1.0,1.0,1.0,0.0
4,MusicDelta_80sRock,0,4,1.0,1.0,1.0,1.0,1.0,0.0
5,MusicDelta_80sRock,1,0,0.457501,0.461646,0.459209,0.363648,0.670437,0.329563
6,MusicDelta_80sRock,1,1,0.39189,0.453114,0.420284,0.362491,0.724983,0.275017
7,MusicDelta_80sRock,1,2,0.4,0.5,0.444444,0.4,0.8,0.2
8,MusicDelta_80sRock,1,3,0.4,0.5,0.444444,0.4,0.8,0.2
9,MusicDelta_80sRock,1,4,0.397435,0.484371,0.436618,0.387497,0.774994,0.225006
