# Imports

In [1]:
import os
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import random
from pytorch_lightning.loggers import WandbLogger
import wandb
import auraloss
import collections
from tqdm import tqdm
import pretty_midi
import matplotlib.pyplot as plt
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
print(torch.cuda.is_available())
import plotly.graph_objects as go
from torch.optim import lr_scheduler


True


# Set Seeds

In [2]:
seed_value = 3407
torch.manual_seed(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.set_float32_matmul_precision('high')

# Construct Teh Datas

In [3]:
path = "D:/Github/phd-drum-sep/Data/musdb18hq/"

In [4]:
os.listdir(path)

['test', 'train']

In [5]:
train = list(os.listdir(path+'train'))
test = list(os.listdir(path+'test'))

In [6]:
sources = ['drum', 'bass', 'other', 'vocals']

In [7]:
all_scenes = {}
counter = 0
for idx, val in tqdm(enumerate(test)):
    p = path + 'test/' + val + "/"
    info = torchaudio.info(f"{p}mixture.wav")
    seconds = info.num_frames // 44100
    for i in range(0, seconds - 10, 5):
        start_point = i * 44100
        if start_point + 441000 < info.num_frames:
            all_scenes[counter] = {'music_path': p, 'start_point': start_point, 'length': 441000, 'frames' : info.num_frames}
            counter += 1

50it [00:00, 5556.55it/s]


In [8]:
def turn_transcription_into_roll(transcription, frames):
    # Determine your sampling frequency (frames per second)
    fs = 44100
    
    piano_roll_length = int(frames)
    
    # Initialize the piano roll array
    piano_roll = np.zeros((64, piano_roll_length))
    
    # Fill in the piano roll array
    for note in transcription.instruments[0].notes:
        # Convert start and end times to frame indices
        start_frame = int(np.floor(note.start * fs))
        end_frame = int(np.ceil(note.end * fs))
        
        # Set the corresponding frames to 1 (or note.velocity for a velocity-sensitive representation)
        piano_roll[note.pitch, start_frame:end_frame] = 1  # Or use note.velocity
        
    roll = np.vstack([piano_roll[35:36, :], piano_roll[38:39, :], piano_roll[42:43, :], piano_roll[47:48, :], piano_roll[49:50, :]])
    return roll

# model

In [9]:
class DrumDemucs(pl.LightningModule):
    def __init__(self):
        super(DrumDemucs, self).__init__()

        self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
                    fft_sizes=[1024, 2048, 4096],
                    hop_sizes=[256, 512, 1024],
                    win_lengths=[1024, 2048, 4096],
                    scale="mel", 
                    n_bins=150,
                    sample_rate=44100,
                    device="cuda"
                )

        self.loss_fn_2 = auraloss.time.SISDRLoss()

        self.loss_fn_3 = torch.nn.L1Loss()

        self.loss_used = 0

        sources = ['drum',
                   'noise',
                   ]
        
        self.demucs_mixer =  torchaudio.models.HDemucs(
            sources=sources,
            audio_channels=7,
            depth=6,
        )

        self.out_conv = nn.Conv1d(in_channels=7, out_channels=2, kernel_size=1)
        self.out = nn.Conv1d(in_channels=2, out_channels=2, kernel_size=1)      


    def compute_loss(self, outputs, ref_signals):
        loss = self.loss_fn(outputs, ref_signals) + self.loss_fn_2(outputs, ref_signals) +  self.loss_fn_3(outputs, ref_signals)
        return loss

    def forward(self, audio, drumroll):
        to_mix = torch.cat([audio, drumroll], axis=1)
        out = self.demucs_mixer(to_mix)
        out_2 = self.out_conv(out[:, 0, :, :])
        out_2 = self.out(out_2)
        # out_2 = torch.tanh(out_2)

        return out_2
    
    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        audio, drum, drumroll = batch
        
        outputs = self.forward(audio, drumroll)
        # print(outputs.size())

        if batch_idx % 64 == 0:
            input_signal = audio[0].cpu().detach().numpy().T
            generated_signal = outputs[0].cpu().detach().numpy().T
            drum_signal = drum[0].cpu().detach().numpy().T 
            wandb.log({'audio_input': [wandb.Audio(input_signal, caption="Input", sample_rate=44100)]})
            wandb.log({'audio_reference': [wandb.Audio(drum_signal, caption="Reference", sample_rate=44100)]})
            wandb.log({'audio_output': [wandb.Audio(generated_signal, caption="Output", sample_rate=44100)]})
             
            for i in range(5):
                wandb.log({f'drum_{i + 1}': [wandb.Audio(drumroll[0].cpu().detach().numpy()[i, :], caption="Output", sample_rate=44100)]})


        loss = self.compute_loss(outputs, drum)         

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
    

    def configure_optimizers(self):
        # Define your optimizer and optionally learning rate scheduler here
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
        return [optimizer], [scheduler]
        

In [10]:
model = DrumDemucs.load_from_checkpoint('D:/Github/phd-drum-sep/models/DrumSep/step_62720.ckpt')

In [11]:
sample_rate = 44100
segment_length = sample_rate * 10

def load_audio(path, start_point, filename):
    audio_tensors = []
    file = filename
    segment, _ = torchaudio.load(f"{path}/{file}", frame_offset=start_point, num_frames=segment_length)
    audio_tensors.append(segment)
    return torch.cat(audio_tensors, dim=0)

def load_roll(path, start_point, frames):
    midi = path + '/mixture.wav.mid'
    transcription = pretty_midi.PrettyMIDI(midi)
    roll = turn_transcription_into_roll(transcription, frames)
    # print(roll.shape)
    roll = roll[:, start_point: start_point + segment_length]
    return torch.from_numpy(roll).float()


# SISNR

In [12]:
import numpy as np

def calculate_si_snr(target, estimate):
    # Ensure the shapes of target and estimate match
    assert target.shape == estimate.shape, "Target and estimate should have the same shape"
    
    # Normalize target and estimate to zero mean
    target = target - np.mean(target)
    estimate = estimate - np.mean(estimate)
    
    # Calculate the scaling factor
    scaling_factor = np.dot(estimate, target) / np.dot(target, target)
    
    # Scale the target signal
    scaled_target = scaling_factor * target
    
    # Calculate the error signal
    error = scaled_target - estimate
    
    # Calculate the SI-SNR
    si_snr = 10 * np.log10(np.sum(scaled_target ** 2) / np.sum(error ** 2))
    
    return si_snr


In [22]:
def calculate_sdr(target, estimate):
    # Ensure the shapes of target and estimate match
    assert target.shape == estimate.shape, "Target and estimate should have the same shape"
    
    # Calculate signal power
    signal_power = np.sum(target ** 2)
    
    # Calculate error power
    error_power = np.sum((target - estimate) ** 2)
    
    # Calculate SDR
    sdr = 10 * np.log10(signal_power / error_power)
    
    return sdr

In [13]:
idxs = list(all_scenes)
si_snrs = []
si_snrs_underlying = []
sdr = []
sdr_underlying = []
for idx in tqdm(idxs):
    sample =  all_scenes[idx]

    audio_path = sample['music_path']

    start_point = sample['start_point']

    mixture_tensor = load_audio(audio_path, start_point,'mixture.wav').unsqueeze(0).to(model.device)
    drum_tensor = load_audio(audio_path, start_point,'drums.wav').unsqueeze(0).to(model.device)
    roll_tensor = load_roll(audio_path, start_point, sample['frames']).unsqueeze(0).to(model.device)
    sep = model(mixture_tensor, roll_tensor)

    sisnr_l = calculate_si_snr(drum_tensor.squeeze(0)[0].cpu().numpy(), sep.squeeze(0)[0].detach().cpu().numpy())
    sisnr_r = calculate_si_snr(drum_tensor.squeeze(0)[1].cpu().numpy(), sep.squeeze(0)[1].detach().cpu().numpy())
    sisnr_avg = (sisnr_l + sisnr_r) / 2
    si_snrs.append([sisnr_l, sisnr_r, sisnr_avg])

    sdr_l = calculate_sdr(drum_tensor.squeeze(0)[0].cpu().numpy(), sep.squeeze(0)[0].detach().cpu().numpy())
    sdr_r = calculate_sdr(drum_tensor.squeeze(0)[1].cpu().numpy(), sep.squeeze(0)[1].detach().cpu().numpy())
    sdr_avg = (sdr_l + sdr_r) / 2
    sdr.append([sdr_l, sdr_r, sdr_avg])

    ################################################################
    
    sisnr_l = calculate_si_snr(drum_tensor.squeeze(0)[0].cpu().numpy(), mixture_tensor.squeeze(0)[0].detach().cpu().numpy())
    sisnr_r = calculate_si_snr(drum_tensor.squeeze(0)[1].cpu().numpy(), mixture_tensor.squeeze(0)[1].detach().cpu().numpy())
    sisnr_avg = (sisnr_l + sisnr_r) / 2
    si_snrs_underlying.append([sisnr_l, sisnr_r, sisnr_avg])

    sdr_l = calculate_sdr(drum_tensor.squeeze(0)[0].cpu().numpy(), mixture_tensor.squeeze(0)[0].detach().cpu().numpy())
    sdr_r = calculate_sdr(drum_tensor.squeeze(0)[1].cpu().numpy(), mixture_tensor.squeeze(0)[1].detach().cpu().numpy())
    sdr_avg = (sdr_l + sdr_r) / 2
    sdr_underlying.append([sdr_l, sdr_r, sdr_avg])


  scaling_factor = np.dot(estimate, target) / np.dot(target, target)
  si_snr = 10 * np.log10(np.sum(scaled_target ** 2) / np.sum(error ** 2))
100%|██████████████████████████████████████████████████████████████████████████████| 2409/2409 [14:57<00:00,  2.68it/s]


In [16]:
import math
avg = 0
total = 0
for idx, val in enumerate(si_snrs):
    try:
        val = (si_snrs[idx][2] -  si_snrs_underlying[idx][2])
        if math.isnan(val) == False and math.isinf(val) == False:
            print((si_snrs[idx][2] -  si_snrs_underlying[idx][2]))
            avg += val
            total += 1
    except:
        pass
avg/total

-0.4773104190826416
7.617533504962921
6.7350660264492035
6.204380989074707
6.378434635698795
6.33437842130661
6.668040305376053
6.942460220307112
7.448592595756054
9.589654758456163
9.504331587813795
8.96923454478383
8.717327187769115
5.780076589435339
5.672363303601742
5.729688536375761
6.335941106081009
6.157513111829758
6.200115382671356
6.84120487421751
6.441983422264457
7.211280167102814
9.27529675886035
9.477374618873
8.151331208646297
7.394803203642368
7.608300596475601
8.057842627167702
8.520137369632721
7.978949435055256
6.569849103689194
8.338380130007863
9.409715700894594
9.432345429740963
7.973926570266485
7.077752277255058
9.028401225805283
9.765885714441538
6.890062242746353
6.7417846620082855
7.1610963344573975
7.640435136854649
5.903421230614185
8.336483184248209
8.783568926155567
6.877617239952087
4.877880532294512
2.606088719330728
2.065776502713561
1.6414556617382914
1.1653224751353264
4.585756356827915
7.933578193187714
8.213276378810406
7.805398352793418
8.04927914

In [18]:
avg = 0
total = 0
for idx, val in enumerate(si_snrs):
    try:
        val = (sdr[idx][2] -  sdr_underlying[idx][2])
        if math.isnan(val) == False and math.isinf(val) == False:
            # print((sdr[idx][2] -  sdr_underlying[idx][2]))
            avg += val
            total += 1
    except:
        pass
avg/total

7.236212988765851

In [19]:
idxs = list(all_scenes)
si_snrs_0 = []
si_snrs_underlying_0 = []
sdr_0 = []
sdr_underlying_0 = []
for idx in tqdm(idxs):
    sample =  all_scenes[idx]

    audio_path = sample['music_path']

    start_point = sample['start_point']

    mixture_tensor = load_audio(audio_path, start_point,'mixture.wav').unsqueeze(0).to(model.device)
    drum_tensor = load_audio(audio_path, start_point,'drums.wav').unsqueeze(0).to(model.device)
    roll_tensor = load_roll(audio_path, start_point, sample['frames']).unsqueeze(0).to(model.device)
    roll_tensor_0 = torch.zeros_like(roll_tensor).to(model.device)
    sep = model(mixture_tensor, roll_tensor_0)

    sisnr_l = calculate_si_snr(drum_tensor.squeeze(0)[0].cpu().numpy(), sep.squeeze(0)[0].detach().cpu().numpy())
    sisnr_r = calculate_si_snr(drum_tensor.squeeze(0)[1].cpu().numpy(), sep.squeeze(0)[1].detach().cpu().numpy())
    sisnr_avg = (sisnr_l + sisnr_r) / 2
    si_snrs_0.append([sisnr_l, sisnr_r, sisnr_avg])

    sdr_l = calculate_sdr(drum_tensor.squeeze(0)[0].cpu().numpy(), sep.squeeze(0)[0].detach().cpu().numpy())
    sdr_r = calculate_sdr(drum_tensor.squeeze(0)[1].cpu().numpy(), sep.squeeze(0)[1].detach().cpu().numpy())
    sdr_avg = (sdr_l + sdr_r) / 2
    sdr_0.append([sdr_l, sdr_r, sdr_avg])

    ################################################################
    
    sisnr_l = calculate_si_snr(drum_tensor.squeeze(0)[0].cpu().numpy(), mixture_tensor.squeeze(0)[0].detach().cpu().numpy())
    sisnr_r = calculate_si_snr(drum_tensor.squeeze(0)[1].cpu().numpy(), mixture_tensor.squeeze(0)[1].detach().cpu().numpy())
    sisnr_avg = (sisnr_l + sisnr_r) / 2
    si_snrs_underlying_0.append([sisnr_l, sisnr_r, sisnr_avg])

    sdr_l = calculate_sdr(drum_tensor.squeeze(0)[0].cpu().numpy(), mixture_tensor.squeeze(0)[0].detach().cpu().numpy())
    sdr_r = calculate_sdr(drum_tensor.squeeze(0)[1].cpu().numpy(), mixture_tensor.squeeze(0)[1].detach().cpu().numpy())
    sdr_avg = (sdr_l + sdr_r) / 2
    sdr_underlying_0.append([sdr_l, sdr_r, sdr_avg])


  scaling_factor = np.dot(estimate, target) / np.dot(target, target)
  si_snr = 10 * np.log10(np.sum(scaled_target ** 2) / np.sum(error ** 2))
100%|██████████████████████████████████████████████████████████████████████████████| 2409/2409 [14:55<00:00,  2.69it/s]


In [23]:
import math
avg = 0
total = 0
for idx, val in enumerate(si_snrs):
    try:
        val = (si_snrs_0[idx][2] -  si_snrs_underlying_0[idx][2])
        if math.isnan(val) == False and math.isinf(val) == False:
            # print((si_snrs_0[idx][2] -  si_snrs_underlying_0[idx][2]))
            avg += val
            total += 1
    except:
        pass

avg/total

-2.0011597871780396
-1.332143247127533
-4.964304864406586
-2.1132171154022217
-1.4989972114562988
-0.7988956570625305
-3.746318370103836
-11.400507986545563
-2.8164413571357727
1.1688870191574097
3.596876561641693
2.7645710110664368
1.1559030413627625
-2.6297730207443237
-2.0368246734142303
0.26597604155540466
-4.60450753569603
-3.972073197364807
-0.7188493013381958
-2.2567689418792725
-1.8664464354515076
-0.2548027038574219
4.992964416742325
3.389427661895752
-1.1439529061317444
-7.781703770160675
-0.9126344323158264
-0.5147570371627808
-5.443170070648193
-2.622312605381012
-2.226265072822571
-2.998054027557373
-0.4925847053527832
3.873268961906433
-3.9056283235549927
-4.598327577114105
-0.6390729546546936
1.0771051049232483
-10.400405079126358
-7.531547695398331
-12.302199304103851
-8.542415760457516
-15.765834264457226
-21.631499882787466
-5.747926346957684
-7.950068414211273
-7.191348951309919
-4.872061610221863
-5.603581964969635
-6.313891265308484
-6.895253993570805
-6.5573854791

-5.396237227772006

In [21]:
import math
avg = 0
total = 0
for idx, val in enumerate(si_snrs):
    try:
        val = (sdr_0[idx][2] -  sdr_underlying_0[idx][2])
        if math.isnan(val) == False and math.isinf(val) == False:
            # print((si_snrs_0[idx][2] -  si_snrs_underlying_0[idx][2]))
            avg += val
            total += 1
    except:
        pass

avg/total

-5.396237227772006

# SDR