# Imports

In [1]:
import os
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import random
from pytorch_lightning.loggers import WandbLogger
import wandb
import auraloss
import collections
from tqdm import tqdm
import pretty_midi
import matplotlib.pyplot as plt
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
print(torch.cuda.is_available())
import plotly.graph_objects as go
from torch.optim import lr_scheduler

True


In [2]:
seed_value = 3407
# seed_everything(seed_value, workers=True)
torch.manual_seed(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.set_float32_matmul_precision('high')

In [3]:
path = "D:/Github/phd-drum-sep/Data/musdb18_stems/"
adtof_path = "D:/Github/phd-drum-sep/Data/musdb18hq/"

In [4]:
import musdb
mus = musdb.DB(root=path, subsets="train")

In [5]:
all_scenes = {}
all_songs = {}

counter = 0
length_s = 4
fs = 16000

for track in tqdm(mus):
    mix = track.audio
    mix = np.swapaxes(mix,0,1)
    mix = torch.tensor(mix).type(torch.FloatTensor)
    
    drums = track.targets['drums'].audio
    drums = np.swapaxes(drums,0,1)
    drums = torch.tensor(drums).type(torch.FloatTensor)

    p2 = adtof_path + 'train/' + track.name + "/"

    if track.name not in list(all_songs.keys()):
        all_songs[track.name] = {'mix':mix, 'drums': drums}

    seconds = mix.shape[1] // 16000
    for i in range(0, seconds - length_s, length_s):

        start_point = i * 16000
        if start_point + (16000 * length_s) < mix.shape[1]:
            all_scenes[counter] = {'music_name': track.name, 'adtof_path': p2, 'start_point': start_point, 'length': 16000 * length_s, 'frames' : mix.shape[1]}
            counter += 1


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:51<00:00,  1.11s/it]


In [6]:
drums.shape

torch.Size([2, 12251136])

# Set Seeds

# Construct Teh Datas

In [7]:
os.listdir(adtof_path)

train = list(os.listdir(path+'train'))
test = list(os.listdir(path+'test'))

In [8]:
sources = ['drum', 'bass', 'other', 'vocals']

In [9]:
def turn_transcription_into_roll(transcription, frames):
    # Determine your sampling frequency (frames per second)
    fs = 16000
    
    piano_roll_length = int(frames)
    
    # Initialize the piano roll array
    piano_roll = np.zeros((64, piano_roll_length))
    
    # Fill in the piano roll array
    for note in transcription.instruments[0].notes:
        # Convert start and end times to frame indices
        start_frame = int(np.floor(note.start * fs))
        end_frame = int(np.ceil(note.end * fs))
        
        # Set the corresponding frames to 1 (or note.velocity for a velocity-sensitive representation)
        piano_roll[note.pitch, start_frame:end_frame] = 1  # Or use note.velocity
        
    roll = np.vstack([piano_roll[35:36, :], piano_roll[38:39, :], piano_roll[42:43, :], piano_roll[47:48, :], piano_roll[49:50, :]])
    return roll

# Data Loaders

In [10]:
class AudioDataGenerator(Dataset):
    def __init__(self, data, songs, sample_rate=16000, segment_length = 4):
        self.data = data
        self.songs = songs
        self.sample_rate = sample_rate
        self.segment_length = sample_rate * segment_length
        
        #caching
        self.cache_drumroll = collections.OrderedDict()

    def __len__(self):
        return len(self.data)
    
    def add_to_cache(self, cache, key, value, size):
        if key in cache:
            # Move to the end to avoid being removed soon
            cache.move_to_end(key)
        else:
            cache[key] = value
            # Remove the oldest item if cache exceeds the size limit
            if len(cache) >= size:
                cache.popitem(last=False)

    def load_audio(self, music_name, start_point, filename):
        waveform = self.songs[music_name][filename]
        segment_tensor = waveform[:, start_point: start_point + self.segment_length]
        return segment_tensor

    def load_roll(self, path, start_point, frames):
        midi = path + '/mixture.wav.mid'
        transcription = pretty_midi.PrettyMIDI(midi)

        cache_key = (path)
        if cache_key in self.cache_drumroll:
            roll = self.cache_drumroll[cache_key]
        else:
            roll = turn_transcription_into_roll(transcription, frames)
            self.add_to_cache(self.cache_drumroll, cache_key, roll, 400)

        roll = roll[:, start_point: start_point + self.segment_length]
        return torch.from_numpy(roll).float()

    def __getitem__(self, idx):

        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.data[idx]

        # Load audio as a tensor
        music_name = sample['music_name']
        adtof_path = sample['adtof_path']

        start_point = sample['start_point']

        mixture_tensor = self.load_audio(music_name, start_point,'mix')
        drum_tensor = self.load_audio(music_name, start_point,'drums')
        roll_tensor = self.load_roll(adtof_path, start_point, sample['frames'])
        
        return mixture_tensor, drum_tensor, roll_tensor

## Lightning Data Module

In [11]:
class AudioDataModule(pl.LightningDataModule):
    def __init__(self, data, songs, batch_size=32, num_workers=0, persistent_workers=False, shuffle=False):
        super().__init__()
        self.data = data
        self.songs = songs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.persistent_workers=persistent_workers
        self.shuffle = shuffle

    def setup(self, stage=None):
        # Split your data here if necessary, e.g., into train, validation, test
        self.dataset = AudioDataGenerator(self.data, self.songs)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers = self.num_workers, persistent_workers=self.persistent_workers)

    # Implement val_dataloader() and test_dataloader() if you have validation and test data

# making the model

In [12]:
class NewSDRLoss(nn.Module):
    """
    New Signal-to-Distortion Ratio (SDR) loss module based on the MDX challenge definition.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none': no reduction will be applied,
            'mean': the mean of the output will be calculated,
            'sum': the sum of the output will be calculated. Default: 'mean'.
    """
    def __init__(self, reduction='mean'):
        super(NewSDRLoss, self).__init__()
        self.reduction = reduction

    def forward(self, estimates, references):
        """
        Computes the SDR loss between the estimated and reference signals.

        Args:
            estimates (Tensor): Estimated signals, shape (batch, channels, height, width).
            references (Tensor): Reference signals, shape (batch, channels, height, width).

        Returns:
            Tensor: The calculated SDR loss.
        """
        assert references.dim() == 4 and estimates.dim() == 4, "Inputs must be 4D tensors."

        delta = 1e-7  # Avoid numerical errors
        num = torch.sum(torch.square(references), dim=(2, 3))
        den = torch.sum(torch.square(references - estimates), dim=(2, 3))
        num += delta
        den += delta
        scores = 10 * torch.log10(num / den)

        if self.reduction == 'mean':
            return -scores.mean()
        elif self.reduction == 'sum':
            return -scores.sum()
        else:  # 'none'
            return -scores

In [13]:
  
class DrumDemucs(pl.LightningModule):
    def __init__(self):
        super(DrumDemucs, self).__init__()

        self.loss_fn = NewSDRLoss()

        self.loss_used = 0

        sources = ['drum',
                   'noise',
                   ]
        
        self.demucs_mixer =  torchaudio.models.HDemucs(
            sources=sources,
            audio_channels=7,
            depth=6,
            nfft=2048,
        )

        self.out_conv = nn.Conv1d(in_channels=7, out_channels=2, kernel_size=1)
        
    def compute_loss(self, outputs, ref_signals):
        loss = self.loss_fn(outputs.unsqueeze(2), ref_signals.unsqueeze(2))
        return loss

    def forward(self, audio, drumroll):
        to_mix = torch.cat([audio, drumroll], axis=1)
        out = self.demucs_mixer(to_mix)
        out_2 = self.out_conv(out[:, 0, :, :])
        # out_2 = torch.tanh(out_2)

        return out_2
    
    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        audio, drum, drumroll = batch
        
        outputs = self.forward(audio, drumroll)
        # print(outputs.size())

        if batch_idx % 256 == 0:
            for i in range(2):
                input_signal = audio[i].cpu().detach().numpy().T
                generated_signal = outputs[i].cpu().detach().numpy().T
                drum_signal = drum[i].cpu().detach().numpy().T 
                wandb.log({f'audio_input_{i}': [wandb.Audio(input_signal, caption="Input", sample_rate=16000)]})
                wandb.log({f'audio_reference_{i}': [wandb.Audio(drum_signal, caption="Reference", sample_rate=16000)]})
                wandb.log({f'audio_output_{i}': [wandb.Audio(generated_signal, caption="Output", sample_rate=16000)]})
             
            for i in range(5):
                wandb.log({f'drum_{i + 1}': [wandb.Audio(drumroll[0].cpu().detach().numpy()[i, :], caption="Output", sample_rate=16000)]})


        loss = self.compute_loss(outputs, drum)         

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
    

    def configure_optimizers(self):
        # Define your optimizer and optionally learning rate scheduler here
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
        return [optimizer], [scheduler]
        
        

# Train Loop

In [14]:
model = DrumDemucs()

In [15]:
wandb_logger = WandbLogger(project='CompressedDrumDemucsSmall', log_model='all')

In [16]:
audio_data_module = AudioDataModule(all_scenes, all_songs, batch_size=32, num_workers=0, persistent_workers=False)

In [17]:
trainer = pl.Trainer(
    max_epochs=250,
    accelerator="gpu", 
    devices=-1,
    logger=wandb_logger,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
trainer.fit(model, audio_data_module)

[34m[1mwandb[0m: Currently logged in as: [33mhephyrius[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type       | Params
--------------------------------------------
0 | loss_fn      | NewSDRLoss | 0     
1 | demucs_mixer | HDemucs    | 78.9 M
2 | out_conv     | Conv1d     | 16    
--------------------------------------------
78.9 M    Trainable params
0         Non-trainable params
78.9 M    Total params
315.685   Total estimated model params size (MB)
C:\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])
torch.Size([2, 7552000])


RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same