# Imports

In [1]:
import os
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
from torchaudio.models import conv_tasnet_base, ConvTasNet
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import random
from pytorch_lightning.loggers import WandbLogger
import wandb
import auraloss
import collections
from tqdm import tqdm
import pretty_midi
import matplotlib.pyplot as plt
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB, CONVTASNET_BASE_LIBRI2MIX
print(torch.cuda.is_available())
import plotly.graph_objects as go
from torch.optim import lr_scheduler


True


# Set Seeds

In [2]:
# Seed every random number generator used in this notebook with the same value
# so that runs are repeatable.
seed_value = 3407
for _seed_fn in (
    torch.manual_seed,           # PyTorch default generator
    random.seed,                 # Python stdlib RNG
    np.random.seed,              # NumPy legacy global RNG
    torch.cuda.manual_seed,      # current CUDA device
    torch.cuda.manual_seed_all,  # every CUDA device
):
    _seed_fn(seed_value)
del _seed_fn  # do not leak the loop helper into the notebook namespace

# Select the 'high' float32 matmul precision mode (see the PyTorch docs for the trade-off).
torch.set_float32_matmul_precision('high')

# Construct the Dataset

In [3]:
# Root folder of the dataset (presumably a local MUSDB18-HQ copy, judging by the folder name).
# NOTE(review): absolute, machine-specific path — edit before running elsewhere.
path = "D:/Github/phd-drum-sep/Data/musdb18hq/"

In [4]:
# Sanity check: show the entries found directly under the dataset root.
os.listdir(path)

['test', 'train']

In [5]:
# Track folders of each split.
# os.listdir() returns entries in arbitrary order, so the listing is sorted to make
# the excerpt index built from `train` (and therefore the batch order) reproducible.
train = sorted(os.listdir(path + 'train'))
test = sorted(os.listdir(path + 'test'))

In [6]:
# Names of the separation targets.
# NOTE(review): the dataset class below loads 'drums.wav', while this list says 'drum' —
# confirm the spelling before using these names to build file paths.
sources = ['drum', 'bass', 'other', 'vocals']

In [7]:
# Build the training index: one entry per one-second excerpt of every training mixture.
# Each entry stores the track folder, the excerpt's first frame, the excerpt length and
# the total number of frames in the track.
SEGMENT_FRAMES = 44100  # frames per excerpt (one second at 44.1 kHz)

all_scenes = {}
counter = 0
# Wrap the list itself (not an enumerate() generator) so tqdm knows the total.
for val in tqdm(train):
    p = path + 'train/' + val + "/"
    info = torchaudio.info(f"{p}mixture.wav")
    seconds = info.num_frames // SEGMENT_FRAMES
    # range(seconds - 1) leaves the last whole second of each track out of the index,
    # so every indexed excerpt ends strictly before the end of the file.
    for i in range(seconds - 1):
        start_point = i * SEGMENT_FRAMES
        all_scenes[counter] = {'music_path': p, 'start_point': start_point, 'length': 44100, 'frames' : info.num_frames}
        counter += 1

100it [00:00, 4639.41it/s]


In [8]:
def turn_transcription_into_roll(transcription, frames, fs=44100, pitches=(35, 38, 42, 47, 49)):
    """Render the first instrument of a transcription as a binary, sample-rate drum roll.

    Args:
        transcription: object exposing ``instruments``; the notes of
            ``instruments[0]`` (each with ``pitch``, ``start`` and ``end`` in
            seconds) are rendered.
        frames: number of columns (audio frames) in the returned roll.
        fs: frames per second used to convert note times into frame indices.
        pitches: MIDI pitches to keep, one output row per pitch, in this order.

    Returns:
        ``numpy.ndarray`` of shape ``(len(pitches), frames)`` and dtype float64,
        holding 1 wherever a note of that row's pitch is active and 0 elsewhere.
        Notes whose pitch is not listed in ``pitches`` are ignored, and a
        transcription without instruments yields an all-zero roll.
    """
    # Allocate only the rows that are returned instead of a full pitch grid.
    roll = np.zeros((len(pitches), int(frames)))

    if not transcription.instruments:
        return roll

    row_of_pitch = {pitch: row for row, pitch in enumerate(pitches)}

    for note in transcription.instruments[0].notes:
        row = row_of_pitch.get(note.pitch)
        if row is None:
            continue  # pitch is not one of the requested rows

        # Convert start and end times to frame indices.
        start_frame = max(int(np.floor(note.start * fs)), 0)
        end_frame = int(np.ceil(note.end * fs))
        if end_frame <= start_frame:
            continue  # zero-length or inverted note: nothing to mark

        # Slicing clips at the end of the roll, so over-long notes are safe.
        roll[row, start_frame:end_frame] = 1

    return roll

# Data Loaders

In [9]:
class AudioDataGenerator(Dataset):
    """Dataset of fixed-length excerpts: (mixture audio, drum stem audio, drum roll).

    Args:
        data: indexable collection of excerpt descriptions; each item must provide
            'music_path' (track folder), 'start_point' (first frame of the excerpt)
            and 'frames' (total number of frames in the track).
        sample_rate: audio sample rate in frames per second.
        segment_length: excerpt duration in seconds.
    """

    def __init__(self, data, sample_rate=HDEMUCS_HIGH_MUSDB.sample_rate, segment_length = 1):
        self.data = data
        self.sample_rate = sample_rate
        # Stored in frames; int() keeps the frame count whole when the duration is fractional.
        self.segment_length = int(sample_rate * segment_length)
        # Parsed MIDI transcriptions keyed by file path, so a track's file is parsed
        # once instead of once per excerpt.
        self._midi_cache = {}

    def __len__(self):
        return len(self.data)

    def load_audio(self, path, start_point, filename):
        """Load `segment_length` frames of `filename` in `path`, starting at `start_point`."""
        segment, _ = torchaudio.load(f"{path}/{filename}", frame_offset=start_point, num_frames=self.segment_length)
        return segment

    def load_roll(self, path, start_point, frames):
        """Return the drum roll of the excerpt as a float tensor.

        The transcription is read from 'mixture.wav.mid' inside `path`, rendered over
        `frames` columns and then cut to the excerpt starting at `start_point`.
        """
        midi = path + '/mixture.wav.mid'
        transcription = self._midi_cache.get(midi)
        if transcription is None:
            transcription = pretty_midi.PrettyMIDI(midi)
            self._midi_cache[midi] = transcription
        roll = turn_transcription_into_roll(transcription, frames)
        roll = roll[:, start_point: start_point + self.segment_length]
        return torch.from_numpy(roll).float()

    def __getitem__(self, idx):
        """Return `(mixture, drums, roll)` tensors for excerpt `idx`."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.data[idx]
        audio_path = sample['music_path']
        start_point = sample['start_point']

        mixture_tensor = self.load_audio(audio_path, start_point, 'mixture.wav')
        drum_tensor = self.load_audio(audio_path, start_point, 'drums.wav')
        roll_tensor = self.load_roll(audio_path, start_point, sample['frames'])
        return mixture_tensor, drum_tensor, roll_tensor

## Lightning Data Module

In [10]:
class AudioDataModule(pl.LightningDataModule):
    """Lightning data module that serves one `AudioDataGenerator` as training data.

    The constructor arguments other than `data` are stored and forwarded unchanged
    to the training `DataLoader`.
    """

    def __init__(self, data, batch_size=32, num_workers=0, persistent_workers=False, shuffle=False):
        super().__init__()
        self.data = data
        self.batch_size, self.num_workers = batch_size, num_workers
        self.persistent_workers = persistent_workers
        self.shuffle = shuffle

    def setup(self, stage=None):
        # No train/validation/test split is made: all of `data` goes into one dataset.
        self.dataset = AudioDataGenerator(self.data)

    def train_dataloader(self):
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
        )
        return DataLoader(self.dataset, **loader_options)

    # val_dataloader() / test_dataloader() are not implemented; add them if
    # validation or test data becomes available.

# Building the Model

In [11]:

class Decoder(torch.nn.Module):
    """Apply masks to the mixture weights and map the result back to `length` samples.

    Args:
        channels: stored on the instance; not used by `forward`.
        length: number of output values produced per masked weight vector.
        basis_signals: size of the last dimension of the weights and masks.
    """

    def __init__(self, channels = 2, length = 40, basis_signals = 500):
        super().__init__()
        self.channels = channels
        self.basis_signals = basis_signals
        self.length = length
        # Bias-free linear map: basis_signals -> length.
        self.inverse = torch.nn.Linear(basis_signals, length, bias=False)

    def forward(self, masks, weight_mixture, norm):
        """Return `inverse(masks * weight_mixture) * norm`.

        `weight_mixture` and `norm` each get a singleton axis inserted at dim 2 so
        that they broadcast against `masks`.
        """
        masked = masks * weight_mixture.unsqueeze(2)
        return self.inverse(masked) * norm.unsqueeze(2)
        
class Encoder(torch.nn.Module):
    """Gated encoder: maps each length-`length` row of the input to `basis_signals` weights.

    Args:
        length: number of samples in each input row (the last input dimension).
        basis_signals: number of weights produced per row.
        eps: added to the row norm before dividing, to avoid division by zero.
    """

    def __init__(self, length=40, basis_signals=500, eps=1e-8):
        super().__init__()

        self.length = length
        self.basis_signals = basis_signals
        self.eps = eps

        # Kernel-size-1 convolutions act as bias-free linear maps over the sample axis.
        self.u = torch.nn.Conv1d(length, basis_signals, kernel_size=1, bias=False)
        self.v = torch.nn.Conv1d(length, basis_signals, kernel_size=1, bias=False)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Encode `x` of shape (B, K, L) with L == `length`.

        Returns:
            w: (B, K, basis_signals) non-negative weights, `relu(U x) * sigmoid(V x)`
               computed independently for every row.
            norm: (B, K, 1) L2 norm of every input row (taken before normalisation).
        """
        norm = torch.norm(x, p=2, dim=2, keepdim=True)  # B x K x 1
        normalised_x = x / (norm + self.eps)            # B x K x L

        # Conv1d expects the sample axis as channels. transpose() swaps the axes;
        # a view()/reshape here would only reinterpret memory and mix samples of
        # different rows, so a row's weights would no longer match its norm.
        frames = normalised_x.transpose(1, 2)           # B x L x K

        gate = torch.sigmoid(self.v(frames))            # B x N x K
        w = torch.mul(self.relu(self.u(frames)), gate)  # B x N x K

        # Back to one weight vector per row; contiguous so callers may .view() it.
        w = w.transpose(1, 2).contiguous()              # B x K x N

        return w, norm


class DepthConvLayer(torch.nn.Module):
    """Convolutional block that returns a residual output and a skip output.

    A 1x1 convolution expands `input_channels` to `hidden_channels`, a dilated
    per-channel (groups == hidden_channels) convolution follows, and two separate
    1x1 convolutions project back to `input_channels`. The same PReLU and
    GroupNorm instances are applied after both convolutions.
    """

    def __init__(self, input_channels, hidden_channels, conv_kernel_size, padding, dilation=1):
        super().__init__()

        self.conv_kernel_size = conv_kernel_size
        self.padding = padding
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        # Shared activation and normalisation (a single group over all hidden channels).
        self.prelu = torch.nn.PReLU()
        self.group_norm = torch.nn.GroupNorm(1, hidden_channels, eps=1e-08)

        # Point-wise expansion to the hidden width.
        self.conv = torch.nn.Conv1d(input_channels, hidden_channels, 1)

        # One filter per hidden channel, dilated along the last axis.
        self.dilated_conv = torch.nn.Conv1d(
            hidden_channels,
            hidden_channels,
            conv_kernel_size,
            padding=padding,
            dilation=dilation,
            groups=hidden_channels,
        )

        # Two independent projections back to the input width.
        self.residual_output = torch.nn.Conv1d(hidden_channels, input_channels, 1)
        self.skip_output = torch.nn.Conv1d(hidden_channels, input_channels, 1)

    def _activate_and_norm(self, features):
        """PReLU followed by group normalisation."""
        return self.group_norm(self.prelu(features))

    def forward(self, x):
        """Return `(residual, skip)`, both with `input_channels` channels."""
        hidden = self._activate_and_norm(self.conv(x))
        hidden = self._activate_and_norm(self.dilated_conv(hidden))
        return self.residual_output(hidden), self.skip_output(hidden)

class Seperator(torch.nn.Module):
    """Stack of dilated convolution layers that estimates per-source masks.

    Args:
        input_dim: size of the last dimension of the input (weights per row).
        output_dim: channels produced by the output convolution; must be a multiple
            of `input_dim` (the multiple is the number of sources).
        bottleneck_conv_size: channel width carried between the layers.
        skip_conv_size: hidden channel width inside every `DepthConvLayer`.
        layers_per_block: layers per block; layer `l` of a block uses dilation 2**l.
        number_blocks: how many times the block of layers is repeated.
        conv_kernel_size: kernel size of the dilated convolutions.
    """

    def __init__(self, input_dim, output_dim, bottleneck_conv_size=128, skip_conv_size = 512,
                 layers_per_block=8, number_blocks=3, conv_kernel_size=3):
        super().__init__()

        self.layers_per_block = layers_per_block
        self.number_blocks = number_blocks

        # Bottleneck projection applied right after normalisation.
        self.first_conv = torch.nn.Conv1d(input_dim, bottleneck_conv_size, 1)

        self.dilation_factors = []
        self.padding = []

        # TCN for feature extraction
        self.TCN = torch.nn.ModuleList([])
        for _ in range(self.number_blocks):
            for l in range(self.layers_per_block):
                dilation = 2 ** l
                # Padding that preserves the sequence length for odd kernel sizes;
                # for the default kernel size 3 this equals the dilation.
                padding = dilation * (conv_kernel_size - 1) // 2

                self.dilation_factors.append(dilation)
                self.padding.append(padding)

                self.TCN.append(DepthConvLayer(bottleneck_conv_size, skip_conv_size,
                                               conv_kernel_size=conv_kernel_size,
                                               dilation=dilation, padding=padding))

        self.output = torch.nn.Sequential(torch.nn.PReLU(), torch.nn.Conv1d(bottleneck_conv_size, output_dim, 1))
        # Masks compete across the source axis (dim 2 of the returned tensor).
        self.softmax = torch.nn.Softmax(dim=2)

    def forward(self, x):
        """Estimate masks for `x` of shape (B, K, L) with L == `input_dim`.

        Returns:
            Tensor of shape (B, K, S, L) with S == output_dim // input_dim; the
            values sum to 1 along the source axis S.
        """
        B, K, L = x.size()

        output = x.permute(0, 2, 1)  # B x L x K
        # Normalise every sample on its own. Passing the full shape (batch included)
        # as normalized_shape would compute statistics across the batch, making a
        # sample's output depend on whatever else is in its batch.
        output = torch.nn.functional.layer_norm(output, output.shape[1:], eps=1e-8)
        output = self.first_conv(output)

        skip_connection = 0.

        for layer in self.TCN:
            residual, skip = layer(output)
            output = output + residual
            skip_connection = skip_connection + skip

        output = self.output(skip_connection)  # B x output_dim x K

        # Split the channel axis into (sources, L); -1 infers the source count
        # instead of hard-coding 2.
        output = output.permute(0, 2, 1).reshape(B, K, -1, L)
        output = self.softmax(output)
        return output


In [12]:

class DrumConvTasNet(pl.core.LightningModule):
    """Lightning module that estimates the drum stem from a mixture plus a drum roll.

    The audio and the drum roll are concatenated along the channel axis, encoded,
    masked by the separator and decoded; the first source of the decoder output is
    then projected to two output channels.

    Args:
        length: number of samples per input excerpt.
        basis_signals: number of encoder weights per channel.
        num_sources: the separator's output width is `basis_signals * num_sources`.
        **kwargs: recorded by `save_hyperparameters()`; not otherwise used.
    """

    def __init__(self,
                length=44100,
                basis_signals=1000,
                num_sources=2,
                **kwargs):
        super().__init__()
        self.save_hyperparameters()

        # loss function
        self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
                    fft_sizes=[1024, 2048, 4096],
                    hop_sizes=[256, 512, 1024],
                    win_lengths=[1024, 2048, 4096],
                    scale="mel",
                    n_bins=150,
                    sample_rate=44100,
                    # Fall back to the CPU so the model can also be built without a GPU.
                    device="cuda" if torch.cuda.is_available() else "cpu"
                )

        self.loss_fn_2 = auraloss.time.SISDRLoss()

        self.loss_fn_3 = torch.nn.L1Loss()

        self.loss_used = 0

        # this is where we construct the model components
        self.encoder = Encoder(length=length,
                               basis_signals=basis_signals)
        self.seperation = Seperator(basis_signals, basis_signals*num_sources)
        self.decoder = Decoder(length=length,
                               basis_signals=basis_signals)

        # 7 input channels: the forward pass concatenates audio and drum roll on
        # dim 1 (presumably 2 audio channels + 5 drum-roll rows — confirm with the dataset).
        self.out_conv = nn.Conv1d(in_channels=7, out_channels=2, kernel_size=1)
        self.out = nn.Conv1d(in_channels=2, out_channels=2, kernel_size=1)

    def compute_loss(self, outputs, ref_signals):
        """Sum of the multi-resolution STFT, SI-SDR and L1 losses."""
        loss = self.loss_fn(outputs, ref_signals) + self.loss_fn_2(outputs, ref_signals) + self.loss_fn_3(outputs, ref_signals)
        return loss

    def forward(self, audio, drumroll):
        """Return the estimated drum audio for a batch.

        Args:
            audio: mixture excerpts, (B, audio_channels, length).
            drumroll: drum rolls, (B, roll_rows, length).

        Returns:
            Tensor of shape (B, 2, length).
        """
        to_mix = torch.cat([audio, drumroll], dim=1)
        out, norm = self.encoder(to_mix)
        masks = self.seperation(out)
        outputs = self.decoder(masks, out, norm)  # B x K x sources x length

        # Take source 0 of every channel by indexing the source axis directly.
        # Reordering the axes with .view() would only relabel memory, and the
        # resulting "source 0" slice would interleave both sources of the first
        # few channels while dropping the remaining channels.
        first_source = outputs[:, :, 0, :]        # B x K x length

        out_2 = self.out_conv(first_source)
        out_2 = self.out(out_2)

        return out_2

    def training_step(self, batch, batch_idx):
        """Run one optimisation step; every 64th batch also logs audio examples to wandb."""
        audio, drum, drumroll = batch

        outputs = self(audio, drumroll)

        if batch_idx % 64 == 0:
            # wandb receives arrays with the channel axis last, hence the transposes.
            input_signal = audio[0].cpu().detach().numpy().T
            generated_signal = outputs[0].cpu().detach().numpy().T
            drum_signal = drum[0].cpu().detach().numpy().T
            wandb.log({'audio_input': [wandb.Audio(input_signal, caption="Input", sample_rate=44100)]})
            wandb.log({'audio_reference': [wandb.Audio(drum_signal, caption="Reference", sample_rate=44100)]})
            wandb.log({'audio_output': [wandb.Audio(generated_signal, caption="Output", sample_rate=44100)]})

            # The first five drum-roll rows are logged as audio as well.
            for i in range(5):
                wandb.log({f'drum_{i + 1}': [wandb.Audio(drumroll[0].cpu().detach().numpy()[i, :], caption="Output", sample_rate=44100)]})

        loss = self.compute_loss(outputs, drum)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        """Adam (lr=0.001) with a StepLR schedule (step_size=1, gamma=0.99)."""
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
        return [optimizer], [scheduler]
        

## Lightning Callbacks

In [13]:
class SaveModelEveryNSteps(pl.Callback):
    """Callback that writes a checkpoint every `save_step_frequency` training steps.

    Args:
        save_step_frequency: a checkpoint is saved whenever
            `trainer.global_step + 1` is a multiple of this value.
        save_path: directory for the checkpoints; created if it does not exist.
            Defaults to the location this notebook has always used.
    """

    def __init__(self, save_step_frequency=256, save_path="D://Github//phd-drum-sep//models//DrumConvTasNet//"):
        super().__init__()
        self.save_step_frequency = save_step_frequency
        self.save_path = save_path
        os.makedirs(self.save_path, exist_ok=True)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Save `step_<n>.ckpt` when the step counter reaches a multiple of the frequency."""
        step = trainer.global_step + 1
        if step % self.save_step_frequency == 0:
            checkpoint_path = os.path.join(self.save_path, f"step_{step}.ckpt")
            trainer.save_checkpoint(checkpoint_path)

# Train Loop

In [14]:
# Build the model with its default hyperparameters
# (length=44100, basis_signals=1000, num_sources=2).
model = DrumConvTasNet()

In [15]:
# Weights & Biases logger for the 'DrumConvTasNet' project.
# NOTE(review): log_model='all' presumably uploads every checkpoint — confirm this is intended.
wandb_logger = WandbLogger(project='DrumConvTasNet', log_model='all')

In [16]:
# Data module over the one-second excerpt index built above.
# NOTE(review): `shuffle` is not passed, so AudioDataModule's default (False) applies
# and batches follow the order of `all_scenes` — consider shuffle=True for training.
audio_data_module = AudioDataModule(all_scenes, batch_size=16, num_workers=0, persistent_workers=False)

In [17]:
# Trainer configuration: GPU training for up to 1000 epochs, logging to wandb,
# with the periodic checkpoint callback and gradient clipping at 5.
trainer = pl.Trainer(
    max_epochs=1000,
    accelerator="gpu", 
    devices=-1,  # presumably "use every available device" — see the Lightning Trainer docs
    logger=wandb_logger,
    callbacks=[SaveModelEveryNSteps()],
    # accumulate_grad_batches=2,
    gradient_clip_val=5,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Start training; the data module only provides a training dataloader.
trainer.fit(model, audio_data_module)

[34m[1mwandb[0m: Currently logged in as: [33mhephyrius[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                    | Params
-------------------------------------------------------
0 | loss_fn    | MultiResolutionSTFTLoss | 0     
1 | loss_fn_2  | SISDRLoss               | 0     
2 | loss_fn_3  | L1Loss                  | 0     
3 | encoder    | Encoder                 | 88.2 M
4 | seperation | Seperator               | 5.2 M 
5 | decoder    | Decoder                 | 44.1 M
6 | out_conv   | Conv1d                  | 16    
7 | out        | Conv1d                  | 6     
-------------------------------------------------------
137 M     Trainable params
0         Non-trainable params
137 M     Total params
549.988   Total estimated model params size (MB)
C:\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoad

Training: |                                                                                      | 0/? [00:00<…

torch.Size([32, 2, 7, 44100])
torch.Size([32, 2, 7, 44100])
