# Imports

In [1]:
import os
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import random
from pytorch_lightning.loggers import WandbLogger
import wandb
import auraloss
import collections
from tqdm import tqdm
import pretty_midi
import matplotlib.pyplot as plt
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
print(torch.cuda.is_available())
import plotly.graph_objects as go
from torch.optim import lr_scheduler
from IPython.display import Audio

True


# Set Seeds

In [2]:
seed_value = 3407
torch.manual_seed(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.set_float32_matmul_precision('high')

# Construct Teh Datas

In [3]:
path = "D:/Github/phd-drum-sep/Data/musdb18hq/"

In [4]:
os.listdir(path)

['test', 'train']

In [5]:
train = list(os.listdir(path+'train'))
test = list(os.listdir(path+'test'))

In [6]:
sources = ['drum', 'bass', 'other', 'vocals']

In [7]:
all_scenes = {}
counter = 0
sample_rate = 44100
segment_length = sample_rate * 1

for idx, val in tqdm(enumerate(test)):
    p = path + 'test/' + val + "/"
    info = torchaudio.info(f"{p}mixture.wav")
    seconds = info.num_frames // 44100
    for i in range(0, seconds - 1, 1):
        start_point = i * 44100
        if start_point + 44100 < info.num_frames:
            all_scenes[counter] = {'music_path': p, 'start_point': start_point, 'length': 44100, 'frames' : info.num_frames}
            counter += 1

50it [00:00, 4544.40it/s]


In [8]:
def turn_transcription_into_roll(transcription, frames):
    # Determine your sampling frequency (frames per second)
    fs = 44100
    
    piano_roll_length = int(frames)
    
    # Initialize the piano roll array
    piano_roll = np.zeros((64, piano_roll_length))
    
    # Fill in the piano roll array
    for note in transcription.instruments[0].notes:
        # Convert start and end times to frame indices
        start_frame = int(np.floor(note.start * fs))
        end_frame = int(np.ceil(note.end * fs))
        
        # Set the corresponding frames to 1 (or note.velocity for a velocity-sensitive representation)
        piano_roll[note.pitch, start_frame:end_frame] = 1  # Or use note.velocity
        
    roll = np.vstack([piano_roll[35:36, :], piano_roll[38:39, :], piano_roll[42:43, :], piano_roll[47:48, :], piano_roll[49:50, :]])
    return roll

# model

In [9]:
class DrumDemucs(pl.LightningModule):
    def __init__(self):
        super(DrumDemucs, self).__init__()

        self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
                    fft_sizes=[1024, 2048, 4096],
                    hop_sizes=[256, 512, 1024],
                    win_lengths=[1024, 2048, 4096],
                    scale="mel", 
                    n_bins=150,
                    sample_rate=44100,
                    device="cuda"
                )

        self.loss_fn_2 = auraloss.time.SISDRLoss()

        self.loss_fn_3 = torch.nn.L1Loss()

        self.loss_used = 0

        sources = ['drum',
                   'noise',
                   ]
        
        self.demucs_mixer =  torchaudio.models.HDemucs(
            sources=sources,
            audio_channels=7,
            depth=6,
        )

        self.out_conv = nn.Conv1d(in_channels=7, out_channels=2, kernel_size=1)
        self.out = nn.Conv1d(in_channels=2, out_channels=2, kernel_size=1)      


    def compute_loss(self, outputs, ref_signals):
        loss = self.loss_fn(outputs, ref_signals) + self.loss_fn_2(outputs, ref_signals) +  self.loss_fn_3(outputs, ref_signals)
        return loss

    def forward(self, audio, drumroll):
        to_mix = torch.cat([audio, drumroll], axis=1)
        out = self.demucs_mixer(to_mix)
        out_2 = self.out_conv(out[:, 0, :, :])
        out_2 = self.out(out_2)
        # out_2 = torch.tanh(out_2)

        return out_2
    
    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        audio, drum, drumroll = batch
        
        outputs = self.forward(audio, drumroll)
        # print(outputs.size())

        if batch_idx % 64 == 0:
            input_signal = audio[0].cpu().detach().numpy().T
            generated_signal = outputs[0].cpu().detach().numpy().T
            drum_signal = drum[0].cpu().detach().numpy().T 
            wandb.log({'audio_input': [wandb.Audio(input_signal, caption="Input", sample_rate=44100)]})
            wandb.log({'audio_reference': [wandb.Audio(drum_signal, caption="Reference", sample_rate=44100)]})
            wandb.log({'audio_output': [wandb.Audio(generated_signal, caption="Output", sample_rate=44100)]})
             
            for i in range(5):
                wandb.log({f'drum_{i + 1}': [wandb.Audio(drumroll[0].cpu().detach().numpy()[i, :], caption="Output", sample_rate=44100)]})


        loss = self.compute_loss(outputs, drum)         

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
    

    def configure_optimizers(self):
        # Define your optimizer and optionally learning rate scheduler here
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
        return [optimizer], [scheduler]
        

In [10]:
model = DrumDemucs.load_from_checkpoint('D:/Github/phd-drum-sep/analysis/demucs_model_analysis/checkpoint/epoch_380.ckpt')
model.to('cpu')
model.eval()

DrumDemucs(
  (loss_fn): MultiResolutionSTFTLoss(
    (stft_losses): ModuleList(
      (0-2): 3 x STFTLoss(
        (spectralconv): SpectralConvergenceLoss()
        (logstft): STFTMagnitudeLoss(
          (distance): L1Loss()
        )
        (linstft): STFTMagnitudeLoss(
          (distance): L1Loss()
        )
      )
    )
  )
  (loss_fn_2): SISDRLoss()
  (loss_fn_3): L1Loss()
  (demucs_mixer): HDemucs(
    (freq_encoder): ModuleList(
      (0): _HEncLayer(
        (conv): Conv2d(14, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
        (norm1): Identity()
        (rewrite): Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))
        (norm2): Identity()
        (dconv): _DConv(
          (layers): ModuleList(
            (0): Sequential(
              (0): Conv1d(48, 12, kernel_size=(3,), stride=(1,), padding=(1,))
              (1): GroupNorm(1, 12, eps=1e-05, affine=True)
              (2): GELU(approximate='none')
              (3): Conv1d(12, 96, kernel_size=(1,), strid

In [11]:
cache_audio = collections.OrderedDict()
cache_drumroll = collections.OrderedDict()

def add_to_cache(cache, key, value, size=10):
    if key in cache:
        # Move to the end to avoid being removed soon
        cache.move_to_end(key)
    else:
        cache[key] = value
        # Remove the oldest item if cache exceeds the size limit
        if len(cache) >= size:
            cache.popitem(last=False)


def load_audio(path, start_point, filename):
    audio_tensors = []

    cache_key = (path+'/'+filename)
    if cache_key in cache_audio:
        waveform = cache_audio[cache_key]
    else:
        waveform, _ = torchaudio.load(f"{path}/{filename}")
        add_to_cache(cache_audio, cache_key, waveform)

    segment = waveform[:, start_point: start_point + segment_length]
    audio_tensors.append(segment)
    
    return torch.cat(audio_tensors, dim=0)

def load_roll(path, start_point, frames):
    midi = path + '/mixture.wav.mid'
    transcription = pretty_midi.PrettyMIDI(midi)

    cache_key = (path)
    if cache_key in cache_drumroll:
        roll = cache_drumroll[cache_key]
    else:
        roll = turn_transcription_into_roll(transcription, frames)
        add_to_cache(cache_drumroll, cache_key, roll)

    roll = roll[:, start_point: start_point + segment_length]
    return torch.from_numpy(roll).float()


# Seperate specific tracks

In [12]:
# for idx in tqdm(idxs[:1]):
audios = []
idx = 1100
sample =  all_scenes[idx]

audio_path = sample['music_path']

start_point = sample['start_point']

mixture_tensor = load_audio(audio_path, start_point,'mixture.wav').unsqueeze(0).to(model.device)
drum_tensor = load_audio(audio_path, start_point,'drums.wav').unsqueeze(0).to(model.device)
roll_tensor = load_roll(audio_path, start_point, sample['frames'])#

for i in range(5):
    roll_tensor_0 = (torch.zeros_like(roll_tensor) + torch.ones_like(roll_tensor)) / 2
    roll_tensor_0[i, :] = roll_tensor[i, :]
    roll_tensor_0 = roll_tensor_0.unsqueeze(0).to(model.device)
    sep = model(mixture_tensor, roll_tensor_0)
    audios.append(sep)

sep = model(mixture_tensor, roll_tensor.unsqueeze(0).to(model.device))
audios.append(sep)


# SISNR

In [13]:
def new_sdr(references, estimates):
    """
    Compute the SDR according to the MDX challenge definition.
    Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license)
    """
    assert references.dim() == 4
    assert estimates.dim() == 4
    delta = 1e-7  # avoid numerical errors
    num = torch.sum(torch.square(references), dim=(2, 3))
    den = torch.sum(torch.square(references - estimates), dim=(2, 3))
    num += delta
    den += delta
    scores = 10 * torch.log10(num / den)
    return scores


In [14]:
idxs = list(all_scenes)

sdr = {}
sdr_underlying = {}
current_track = ""
for idx in tqdm(idxs):

    with torch.no_grad():
        audio_path = sample['music_path']
        
        if audio_path != current_track:
            current_track = audio_path
            sdr[current_track] = []
            sdr_underlying[current_track] = []
        
        sample =  all_scenes[idx]
        start_point = sample['start_point']
    
        mixture_tensor = load_audio(audio_path, start_point,'mixture.wav').unsqueeze(0).to(model.device)
        drum_tensor = load_audio(audio_path, start_point,'drums.wav').unsqueeze(0).to(model.device)
        roll_tensor = load_roll(audio_path, start_point, sample['frames']).unsqueeze(0).to(model.device)
        sep = model(mixture_tensor, roll_tensor)
    
        x_ = mixture_tensor.detach().unsqueeze(2)
        y = drum_tensor.unsqueeze(2)
        y_hat = sep.detach().unsqueeze(2)
        
        nsdr = new_sdr(y, y_hat)
        nsdr_underlying = new_sdr(y, x_)

        sdr[current_track].append(nsdr)
    
        sdr_underlying[current_track].append(nsdr_underlying)

100%|████████████████████████████████████████████████████████████████████████████| 12396/12396 [49:22<00:00,  4.18it/s]


In [15]:
median_sdrs = {}

for name in list(sdr.keys()):
    left = []
    right = []
    sdrs = sdr[name]

    for value in sdrs:
        left.append(value[0][0].item())
        right.append(value[0][1].numpy().item())

    l = np.median(left)
    r = np.median(right)

    median_sdrs[name] = {'left_median': l, 'right_median':r}
    

In [16]:
lefts = []
rights = []
for name in list(median_sdrs.keys()):
    values = median_sdrs[name]
    lefts.append(values['left_median'])
    rights.append(values['right_median'])
    

In [17]:
l = np.median(lefts)
r = np.median(rights)
l, r

(3.7577788829803467, 3.9401825070381165)