# Imports

In [1]:
import os
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import random
from pytorch_lightning.loggers import WandbLogger
import wandb
import auraloss
import collections
from tqdm import tqdm
import pretty_midi
import matplotlib.pyplot as plt
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
print(torch.cuda.is_available())
import plotly.graph_objects as go
from torch.optim import lr_scheduler
from IPython.display import Audio
from torchaudio.transforms import Fade
import musdb
import museval
import gc
import pandas as pd

True


In [2]:
class Track:
    """Lightweight record describing one evaluation track.

    Stores the track name together with the paths to its MIDI transcription,
    drum-only audio and full-mix audio. The remaining attributes are fixed
    defaults set by the constructor.
    """

    def __init__(self, name, midi_path, drum_path, mix_path):
        # Caller-supplied identity and file locations.
        self.name, self.midi_path = name, midi_path
        self.drum_path, self.mix_path = drum_path, mix_path

        # Fixed defaults.
        # NOTE(review): these look like the attributes a musdb/museval track
        # exposes (targets, rate, subset) - confirm against the consumer.
        self.targets = dict(drums='', bass='')
        self.rate = 44100
        self.subset = 'test'

# Set Seeds

In [3]:
# Seed every random number generator this notebook touches with the same
# value: torch (CPU), Python's `random`, NumPy's legacy global RNG, and the
# torch CUDA generators (current device, then all devices).
seed_value = 3407
for seed_fn in (
    torch.manual_seed,
    random.seed,
    np.random.seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(seed_value)

# Select the 'high' float32 matmul precision mode.
torch.set_float32_matmul_precision('high')

In [4]:
# Collect the full path of every file in each MDB Drums sub-folder.
#
# os.listdir() returns entries in arbitrary order, and the track-building
# loop further down pairs these lists purely by position (midis[idx],
# drum[idx], mixes[idx]). Sorting every listing makes that pairing
# deterministic across machines and file systems.
# NOTE(review): positional pairing still assumes the four folders contain the
# same tracks with names that sort identically - confirm for this dataset.
mix_folder = 'D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/audio/full_mix/'
mixes = [mix_folder + m for m in sorted(os.listdir(mix_folder))]

drum_folder = 'D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/audio/drum_only/'
drum = [drum_folder + d for d in sorted(os.listdir(drum_folder))]

beats_folder = 'D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/annotations/beats/'
beats = [beats_folder + b for b in sorted(os.listdir(beats_folder))]

class_folder = 'D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/annotations/subclass/'
classes = [class_folder + c for c in sorted(os.listdir(class_folder))]

midi_folder = 'D:/Github/phd-drum-sep/data/MDBDrums-master/MDB Drums/midi/'
midis = [midi_folder + m for m in sorted(os.listdir(midi_folder))]

In [5]:
# Build one Track per subclass-annotation file. The MIDI, drum-only and
# full-mix files are matched to it purely by list position, so refuse to
# continue when the lists do not even have the same length (otherwise the
# loop would either raise IndexError midway or silently mis-pair files).
if not (len(classes) == len(midis) == len(drum) == len(mixes)):
    raise ValueError(
        'Dataset folders are out of sync: '
        f'{len(classes)} annotations, {len(midis)} MIDI files, '
        f'{len(drum)} drum stems, {len(mixes)} mixes'
    )

all_tracks = []
# Wrapping the list (not the enumerate object) lets tqdm know the total.
for idx, val in enumerate(tqdm(classes)):

    # Track name = annotation file name without folder and '_subclass.txt'.
    name = os.path.basename(val).replace('_subclass.txt', '')

    t = Track(name, midis[idx], drum[idx], mixes[idx])
    all_tracks.append(t)

23it [00:00, ?it/s]


# Construct the Data

In [6]:
def turn_transcription_into_roll(transcription, frames, fs=44100, pitches=(35, 38, 42, 47, 49)):
    """Render a MIDI transcription as a sample-rate binary activation roll.

    Only the notes of ``transcription.instruments[0]`` are read. Every note
    whose pitch is listed in ``pitches`` sets the frames between its start
    (rounded down) and its end (rounded up) to 1 in the row belonging to that
    pitch. Notes with any other pitch are ignored.

    The roll is built directly with one row per requested pitch. This avoids
    the IndexError a fixed 64-row scratch roll raises for MIDI pitches >= 64,
    and avoids allocating rows that would be thrown away.

    Args:
        transcription: Object exposing ``instruments[0].notes``, where each
            note has ``pitch``, ``start`` and ``end`` (seconds), e.g. a
            ``pretty_midi.PrettyMIDI`` instance.
        frames: Number of columns of the roll (converted with ``int``).
        fs: Frames per second used to convert note times to column indices.
        pitches: MIDI pitches to keep, in output row order.

    Returns:
        np.ndarray of shape ``(len(pitches), int(frames))``, dtype float64,
        containing 0 and 1.
    """
    roll_length = int(frames)
    row_of_pitch = {pitch: row for row, pitch in enumerate(pitches)}
    roll = np.zeros((len(pitches), roll_length))

    for note in transcription.instruments[0].notes:
        row = row_of_pitch.get(note.pitch)
        if row is None:
            # Pitch is not one of the requested drum classes.
            continue

        # Convert start and end times to frame indices; slicing clips any
        # part of the note that extends past the end of the roll.
        start_frame = int(np.floor(note.start * fs))
        end_frame = int(np.ceil(note.end * fs))
        roll[row, start_frame:end_frame] = 1

    return roll

# Model

In [7]:
class DrumDemucs(pl.LightningModule):
    """Hybrid Demucs separator conditioned on a drum-activation roll.

    The mixture audio and the drum roll are concatenated along the channel
    dimension and passed through a ``torchaudio.models.HDemucs`` built for 7
    input channels and two sources (``'drum'``, ``'noise'``). The first source
    is then reduced to 2 channels by two 1x1 convolutions.

    Training minimises the sum of a multi-resolution mel-STFT loss, an SI-SDR
    loss and an L1 loss between the output and the reference drum audio.
    """

    def __init__(self):
        super().__init__()

        # The loss device used to be hard-coded to "cuda". Fall back to the
        # CPU when no GPU is present so the module (and therefore
        # `load_from_checkpoint`) can still be constructed there.
        # NOTE(review): assumes auraloss moves its mel filterbank to `device`
        # at construction time - confirm for the installed auraloss version.
        loss_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
            fft_sizes=[1024, 2048, 4096],
            hop_sizes=[256, 512, 1024],
            win_lengths=[1024, 2048, 4096],
            scale="mel",
            n_bins=150,
            sample_rate=44100,
            device=loss_device,
        )

        self.loss_fn_2 = auraloss.time.SISDRLoss()

        self.loss_fn_3 = torch.nn.L1Loss()

        self.loss_used = 0

        sources = ['drum',
                   'noise',
                   ]

        # 7 input channels: `forward` concatenates audio and drum roll on
        # dim 1 (presumably 2 audio channels + 5 drum-roll rows - confirm).
        self.demucs_mixer = torchaudio.models.HDemucs(
            sources=sources,
            audio_channels=7,
            depth=6,
        )

        # 1x1 convolutions mapping the 7-channel source estimate to 2 channels.
        self.out_conv = nn.Conv1d(in_channels=7, out_channels=2, kernel_size=1)
        self.out = nn.Conv1d(in_channels=2, out_channels=2, kernel_size=1)

    def compute_loss(self, outputs, ref_signals):
        """Return the sum of the mel-STFT, SI-SDR and L1 losses."""
        loss = (
            self.loss_fn(outputs, ref_signals)
            + self.loss_fn_2(outputs, ref_signals)
            + self.loss_fn_3(outputs, ref_signals)
        )
        return loss

    def forward(self, audio, drumroll):
        """Separate the drum signal from ``audio`` given ``drumroll``.

        Args:
            audio: Mixture tensor; concatenated with ``drumroll`` on dim 1.
            drumroll: Drum-activation tensor with the same batch and time
                dimensions as ``audio``.

        Returns:
            Tensor with 2 channels on dim 1 (output of the final Conv1d).
        """
        to_mix = torch.cat([audio, drumroll], dim=1)
        out = self.demucs_mixer(to_mix)
        # Index 0 on dim 1 is the first entry of `sources`, i.e. 'drum'.
        out_2 = self.out_conv(out[:, 0, :, :])
        out_2 = self.out(out_2)

        return out_2

    def _log_audio_examples(self, audio, drum, outputs, drumroll):
        """Send the first item of the batch to wandb as audio clips."""
        input_signal = audio[0].cpu().detach().numpy().T
        generated_signal = outputs[0].cpu().detach().numpy().T
        drum_signal = drum[0].cpu().detach().numpy().T
        wandb.log({'audio_input': [wandb.Audio(input_signal, caption="Input", sample_rate=44100)]})
        wandb.log({'audio_reference': [wandb.Audio(drum_signal, caption="Reference", sample_rate=44100)]})
        wandb.log({'audio_output': [wandb.Audio(generated_signal, caption="Output", sample_rate=44100)]})

        # One clip per drum-roll row (the first five rows).
        roll = drumroll[0].cpu().detach().numpy()
        for i in range(5):
            wandb.log({f'drum_{i + 1}': [wandb.Audio(roll[i, :], caption="Output", sample_rate=44100)]})

    def training_step(self, batch, batch_idx):
        """Run one optimisation step on an ``(audio, drum, drumroll)`` batch."""
        audio, drum, drumroll = batch

        outputs = self.forward(audio, drumroll)

        # Log example audio every 64th batch.
        if batch_idx % 64 == 0:
            self._log_audio_examples(audio, drum, outputs, drumroll)

        loss = self.compute_loss(outputs, drum)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        """Adam (lr=0.001) with a StepLR schedule (step_size=1, gamma=0.99)."""
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
        return [optimizer], [scheduler]
        

In [8]:
def load_audio(path):
    """Load an audio file with torchaudio and return only its waveform.

    The sample rate returned by ``torchaudio.load`` is discarded.

    Args:
        path: Path of the audio file.

    Returns:
        The waveform tensor returned by ``torchaudio.load``.
    """
    waveform, _sample_rate = torchaudio.load(path)
    return waveform

def load_roll(path, frames):
    """Read a MIDI file and return its drum roll as a float32 tensor.

    Args:
        path: Path of the MIDI file, parsed with ``pretty_midi.PrettyMIDI``.
        frames: Roll length, forwarded to ``turn_transcription_into_roll``.

    Returns:
        ``torch.Tensor`` (float32) built from the NumPy roll.
    """
    midi = pretty_midi.PrettyMIDI(path)
    return torch.from_numpy(turn_transcription_into_roll(midi, frames)).float()


# SISNR

In [9]:
def expand(x, out_size=44100*4, step=4410):
    """Stretch a coarse roll by holding every column for ``step`` samples.

    Column ``i`` of ``x`` fills samples ``[i * step, (i + 1) * step)`` of the
    result. Samples not covered by any column stay 0 and anything that would
    fall beyond ``out_size`` is dropped.

    Args:
        x: Tensor whose second dimension is iterated; the result always has
            5 rows, so ``x`` is expected to have 5 rows.
        out_size: Number of samples in the output.
        step: Number of samples each input column is held for.

    Returns:
        float32 tensor of shape ``(5, out_size)``.
    """
    output_tensor = torch.zeros((5, out_size))
    # Repeat each column `step` times in one go, then keep what fits.
    held = x.repeat_interleave(step, dim=1)[:, :out_size]
    if held.shape[1] > 0:
        output_tensor[:, :held.shape[1]] = held
    return output_tensor

def compress(x, original_shape=(5, 40), step=4410):
    """Shrink a long roll by averaging consecutive blocks of ``step`` samples.

    Args:
        x (Tensor): Input whose second dimension is cut into blocks.
        original_shape (tuple): Shape of the result; its second entry is the
            number of blocks that are averaged. Default ``(5, 40)``.
        step (int): Block length in samples. Default 4410.

    Returns:
        Tensor of shape ``original_shape`` where column ``i`` holds the
        per-row mean of ``x[:, i * step:(i + 1) * step]``.
    """
    output_tensor = torch.zeros(original_shape)
    n_blocks = original_shape[1]
    for block in range(n_blocks):
        lo = block * step
        # Per-row mean of this block becomes one output column.
        output_tensor[:, block] = x[:, lo:lo + step].mean(dim=1)
    return output_tensor

def tournament_selection(population, losses, tournament_size=3):
    """
    Pick two parents, each as the winner of an independent tournament.

    For every tournament, ``tournament_size`` distinct individuals are drawn
    at random and the one with the lowest loss wins.

    Args:
    - population (list of Tensors): Individuals to choose from.
    - losses (list of floats): Loss of each individual (lower is fitter),
      aligned with ``population``.
    - tournament_size (int): Number of individuals drawn per tournament.

    Returns:
    - parent1, parent2 (tuple of Tensors): The two tournament winners.
    """
    population_size = len(population)

    def run_tournament():
        # Draw distinct contenders and return the one with the smallest loss.
        contenders = np.random.choice(range(population_size), size=tournament_size, replace=False)
        contender_losses = [losses[c] for c in contenders]
        return population[contenders[np.argmin(contender_losses)]]

    parent1 = run_tournament()
    parent2 = run_tournament()
    return parent1, parent2

def crossover(parent1, parent2):
    """Uniform crossover: each element is taken from one parent at random.

    A random boolean mask shaped like ``parent1`` decides, element by
    element, whether the value comes from ``parent1`` (True) or ``parent2``.
    """
    take_first = torch.randint(0, 2, size=parent1.shape, dtype=torch.bool)
    return torch.where(take_first, parent1, parent2)

def adaptive_mutation_rate(current_iteration, max_iterations, start_rate=0.75, end_rate=0.25):
    """
    Mutation rate that decays linearly from ``start_rate`` to ``end_rate``.

    Args:
    - current_iteration (int): Index of the current iteration (0-based).
    - max_iterations (int): Total number of iterations of the run.
    - start_rate (float): Rate returned at iteration 0.
    - end_rate (float): Lower bound of the returned rate.

    Returns:
    - float: The rate for ``current_iteration``, never below ``end_rate``.
    """
    progress = current_iteration / max_iterations
    span = start_rate - end_rate
    rate = start_rate - (span * progress)

    # Floor at end_rate.
    return end_rate if end_rate > rate else rate


In [10]:
class AudioData:
    """Minimal wrapper exposing its payload as the ``audio`` attribute."""

    def __init__(self, audio):
        # Stored as given; no copy or conversion is made.
        self.audio = audio

In [11]:
def find_best(mixture_tensor_, drum_tensor_, n_iters=100, population_size=32, elite_size=2, separator=None):
    """Search for the drum roll whose conditioned separation best matches a reference.

    A small genetic algorithm evolves binary rolls of shape ``(5, 40)``. Every
    candidate is stretched with ``expand`` and passed, together with the
    mixture, to the separation model; its fitness is the mean absolute error
    between the model output and ``drum_tensor_`` (lower is better).

    Fixes relative to the previous implementation:

    * The initial population is now random. ``randint_like(..., high=1)`` has
      an exclusive upper bound, so every individual used to start as zeros.
    * The roll that is returned is the roll that was evaluated. Previously
      candidates were inverted (``where(c < 0.5, 1, 0)``) before being given
      to the model but the un-inverted tensor was returned, so the caller
      scored the complement of what the model had seen.
    * Offspring that are not selected for mutation are kept unmutated instead
      of being thrown away, so the adaptive mutation rate actually controls
      how much of the population is mutated.
    * The best individual ever evaluated is returned, not the first entry of
      the final, never-evaluated generation.

    Args:
        mixture_tensor_: Mixture chunk with the batch on dim 0 (one item).
            Its length must match ``expand``'s default output size.
        drum_tensor_: Reference drum chunk, broadcastable against the model
            output of shape ``(population_size, channels, time)``.
        n_iters: Number of generations.
        population_size: Individuals per generation. Must be at least the
            tournament size used by ``tournament_selection`` (3 by default).
        elite_size: Number of best individuals copied unchanged into the
            next generation.
        separator: Model called as ``separator(mix, rolls)`` and exposing
            ``.device``. Defaults to the notebook-level ``model``.

    Returns:
        Integer tensor of shape ``(5, 40)`` with 0/1 entries: the candidate
        with the lowest loss seen during the search.
    """
    net = model if separator is None else separator
    shape = (5, 40)

    with torch.no_grad():
        # Random binary starting population (upper bound is exclusive).
        population = [torch.randint(0, 2, shape) for _ in range(population_size)]

        best_loss = float('inf')
        best_solution = population[0]

        # The mixture is the same for every candidate: build the batch once.
        mix = torch.cat([mixture_tensor_] * population_size, dim=0)

        for iteration in range(n_iters):
            # Stretch every candidate to sample resolution and evaluate the
            # whole generation in a single forward pass.
            batch_candidates = torch.stack([expand(individual) for individual in population]).to(net.device)
            sep = net(mix, batch_candidates)

            # Mean absolute error per candidate (one value per batch item).
            losses = (sep - drum_tensor_).abs().mean(dim=(1, 2)).tolist()

            sorted_indices = np.argsort(losses)
            sorted_population = [population[i] for i in sorted_indices]
            sorted_losses = [losses[i] for i in sorted_indices]

            # Remember the best individual seen so far.
            if sorted_losses[0] < best_loss:
                best_loss = sorted_losses[0]
                best_solution = sorted_population[0]

            # Elitism: carry the best individuals over unchanged.
            new_population = sorted_population[:elite_size]

            # The rate depends only on the generation, so compute it once.
            mutation_rate = adaptive_mutation_rate(iteration, n_iters)

            # Fill the rest of the next generation with offspring.
            while len(new_population) < population_size:
                parent1, parent2 = tournament_selection(sorted_population, sorted_losses)

                for offspring in (crossover(parent1, parent2), crossover(parent2, parent1)):
                    if len(new_population) >= population_size:
                        break
                    if torch.rand(1) < mutation_rate:
                        # Add -1/0/+1 per gene and clip back to {0, 1}.
                        mutation = torch.randint(-1, 2, size=offspring.shape)
                        offspring = (offspring + mutation).clamp(0, 1)
                    new_population.append(offspring)

            population = new_population

        return best_solution

In [12]:
def calculate_precision_recall_torch(transcription, prediction):
    """Per-row precision and recall of a binary prediction.

    Counts are taken along dim 1. An element is a true positive when both
    tensors equal 1, a false positive when the transcription is 0 and the
    prediction is 1, and a false negative when the transcription is 1 and
    the prediction is 0. Rows where a ratio is undefined (0 / 0) get 0.

    Returns:
        ``(precision, recall)``, two float tensors with one value per row.
    """
    predicted_on = prediction == 1
    reference_on = transcription == 1

    true_pos = torch.sum(reference_on & predicted_on, dim=1)
    false_pos = torch.sum((transcription == 0) & predicted_on, dim=1)
    false_neg = torch.sum(reference_on & (prediction == 0), dim=1)

    precision = true_pos.float() / (true_pos + false_pos).float()
    recall = true_pos.float() / (true_pos + false_neg).float()

    # 0 / 0 yields NaN; report those rows as 0.
    precision = torch.where(torch.isnan(precision), torch.zeros_like(precision), precision)
    recall = torch.where(torch.isnan(recall), torch.zeros_like(recall), recall)

    return precision, recall

In [13]:
def calculate_f_measure(precision, recall, beta=1):
    """
    F-beta score per class plus its mean over all classes.

    Parameters:
    - precision: Tensor of precision values per class.
    - recall: Tensor of recall values per class.
    - beta: Weight of recall in the harmonic mean.

    Returns:
    - f_measure: Tensor of F-measure for each class (0 where precision and
      recall are both 0).
    - average_f_measure: Scalar tensor, mean of ``f_measure``.
    """
    beta_sq = beta**2
    numerator = (1 + beta_sq) * precision * recall
    denominator = (beta_sq * precision) + recall

    # Replace zero denominators by 1 so the division is defined.
    safe_denominator = torch.where(denominator == 0, torch.ones_like(denominator), denominator)
    f_measure = numerator / safe_denominator

    # Any remaining NaN is reported as 0.
    f_measure = torch.where(torch.isnan(f_measure), torch.zeros_like(f_measure), f_measure)

    return f_measure, torch.mean(f_measure)

In [None]:
name = 'epoch_285'

# Results are written below <results_ht_{name}>/adt/. makedirs creates every
# missing level and is a no-op when the folder exists, so real failures
# (permissions, bad drive) are no longer hidden behind a bare `except`.
out_dir = f"D:/Github/phd-drum-sep/model-as-adt/results_ht_{name}/adt/"
os.makedirs(out_dir, exist_ok=True)

model = DrumDemucs.load_from_checkpoint(f'D:/Github/phd-drum-sep/analysis/demucs_small_model_analysis/checkpoint/{name}.ckpt')
model = model.eval()

results = museval.EvalStore(frames_agg='median', tracks_agg='median')

# Every track is processed in chunks of 4 seconds at 44.1 kHz.
chunk_len = int(44100 * 4)

for track in tqdm(all_tracks):

    # Mixture, truncated to a whole number of chunks.
    mixture_tensor = load_audio(track.mix_path).unsqueeze(0).to(model.device)
    snippet_length = (mixture_tensor.shape[2] // chunk_len) * chunk_len
    mixture_tensor = mixture_tensor[:, :, :snippet_length]

    # Reference drums: the stem is duplicated along the channel dimension.
    # NOTE(review): presumably turns a mono stem into the 2-channel layout
    # the model outputs - confirm the stems are mono.
    drum_tensor = load_audio(track.drum_path).unsqueeze(0)
    drum_tensor = torch.cat([drum_tensor, drum_tensor], dim=1).to(model.device)
    drum_tensor = drum_tensor[:, :, :snippet_length]

    # Ground-truth roll rendered at the (truncated) mixture length.
    roll_tensor = load_roll(track.midi_path, snippet_length).unsqueeze(0).to(model.device)

    device = mixture_tensor.device
    batch, channels, length = mixture_tensor.shape

    # Estimate one coarse roll per chunk.
    proposed_answers = []
    for start in tqdm(range(0, length, chunk_len)):
        end = start + chunk_len
        answer = find_best(mixture_tensor[:, :, start:end], drum_tensor[:, :, start:end])
        proposed_answers.append(answer)

    # Per-chunk precision / recall for each of the 5 drum classes.
    # (`drum_idx` and `roll_chunk` avoid clobbering the `drum` path list
    # defined in an earlier cell and the built-in `slice`.)
    pres = [[] for _ in range(5)]
    recs = [[] for _ in range(5)]
    for idx, answer in enumerate(proposed_answers):
        start = idx * chunk_len
        end = start + chunk_len
        roll_chunk = roll_tensor[:, :, start:end].to(model.device).squeeze(0)
        pred = expand(answer).to(model.device)
        pre, rec = calculate_precision_recall_torch(roll_chunk, pred)

        for drum_idx in range(5):
            pres[drum_idx].append(pre[drum_idx].item())
            recs[drum_idx].append(rec[drum_idx].item())

    track_dir = f'{out_dir}{track.name}'
    os.makedirs(track_dir, exist_ok=True)

    # One row per chunk, one column per drum class.
    pd.DataFrame(pres).T.to_csv(f'{track_dir}/precision.csv')
    pd.DataFrame(recs).T.to_csv(f'{track_dir}/recall.csv')

    # Raw coarse roll of every chunk.
    for idx, val in enumerate(proposed_answers):
        adf = pd.DataFrame(val.cpu().numpy())
        adf.to_csv(f'{track_dir}/{idx}.csv')
    