# Imports

In [1]:
import os
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
from torchaudio.models import conv_tasnet_base, ConvTasNet
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import random
from pytorch_lightning.loggers import WandbLogger
import wandb
import auraloss
import collections
from tqdm import tqdm
import pretty_midi
import matplotlib.pyplot as plt
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB, CONVTASNET_BASE_LIBRI2MIX
print(torch.cuda.is_available())
import plotly.graph_objects as go
from torch.optim import lr_scheduler


True


# Set Seeds

In [2]:
seed_value = 3407
torch.manual_seed(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.set_float32_matmul_precision('high')

# Construct Teh Datas

In [3]:
path = "D:/Github/phd-drum-sep/Data/musdb18hq/"

In [4]:
os.listdir(path)

['test', 'train']

In [5]:
train = list(os.listdir(path+'train'))
test = list(os.listdir(path+'test'))

In [6]:
sources = ['drum', 'bass', 'other', 'vocals']

In [7]:
all_scenes = {}
counter = 0
for idx, val in tqdm(enumerate(train)):
    p = path + 'train/' + val + "/"
    info = torchaudio.info(f"{p}mixture.wav")
    seconds = info.num_frames // 44100
    for i in range(0, seconds - 10, 10):
        start_point = i * 44100
        if start_point + 441000 < info.num_frames:
            all_scenes[counter] = {'music_path': p, 'start_point': start_point, 'length': 441000, 'frames' : info.num_frames}
            counter += 1

100it [00:00, 6713.79it/s]


In [8]:
def turn_transcription_into_roll(transcription, frames):
    # Determine your sampling frequency (frames per second)
    fs = 44100
    
    piano_roll_length = int(frames)
    
    # Initialize the piano roll array
    piano_roll = np.zeros((64, piano_roll_length))
    
    # Fill in the piano roll array
    for note in transcription.instruments[0].notes:
        # Convert start and end times to frame indices
        start_frame = int(np.floor(note.start * fs))
        end_frame = int(np.ceil(note.end * fs))
        
        # Set the corresponding frames to 1 (or note.velocity for a velocity-sensitive representation)
        piano_roll[note.pitch, start_frame:end_frame] = 1  # Or use note.velocity
        
    roll = np.vstack([piano_roll[35:36, :], piano_roll[38:39, :], piano_roll[42:43, :], piano_roll[47:48, :], piano_roll[49:50, :]])
    return roll

# Data Loaders

In [9]:
class AudioDataGenerator(Dataset):
    def __init__(self, data, sample_rate=HDEMUCS_HIGH_MUSDB.sample_rate, segment_length = 10):
        self.data = data
        self.sample_rate = sample_rate
        self.segment_length = sample_rate * segment_length

    def __len__(self):
        return len(self.data)
    
    def load_audio(self, path, start_point, filename):
        audio_tensors = []
        file = filename
        segment, _ = torchaudio.load(f"{path}/{file}", frame_offset=start_point, num_frames=self.segment_length)
        audio_tensors.append(segment)
        return torch.cat(audio_tensors, dim=0)

    def load_roll(self, path, start_point, frames):
        midi = path + '/mixture.wav.mid'
        transcription = pretty_midi.PrettyMIDI(midi)
        roll = turn_transcription_into_roll(transcription, frames)
        # print(roll.shape)
        roll = roll[:, start_point: start_point + self.segment_length]
        return torch.from_numpy(roll).float()

    def __getitem__(self, idx):

        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.data[idx]

        # Load audio as a tensor
        audio_path = sample['music_path']

        start_point = sample['start_point']

        mixture_tensor = self.load_audio(audio_path, start_point,'mixture.wav')
        drum_tensor = self.load_audio(audio_path, start_point,'drums.wav')
        roll_tensor = self.load_roll(audio_path, start_point, sample['frames'])
        return mixture_tensor, drum_tensor, roll_tensor

## Lightning Data Module

In [10]:
class AudioDataModule(pl.LightningDataModule):
    def __init__(self, data, batch_size=32, num_workers=0, persistent_workers=False, shuffle=False):
        super().__init__()
        self.data = data
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.persistent_workers=persistent_workers
        self.shuffle = shuffle

    def setup(self, stage=None):
        # Split your data here if necessary, e.g., into train, validation, test
        self.dataset = AudioDataGenerator(self.data)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers = self.num_workers, persistent_workers=self.persistent_workers)

    # Implement val_dataloader() and test_dataloader() if you have validation and test data

# making the model

In [None]:
import torch
import torch.nn as nn

In [12]:
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Sequential(
            nn.Conv1d(in_channels, in_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv1d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv1d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class CBAMBlock(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16, attention_kernel_size=7):
        super(CBAMBlock, self).__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(attention_kernel_size)

    def forward(self, x):
        x = x * self.channel_attention(x)
        x = x * self.spatial_attention(x)
        return x


In [11]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        # Define down-sampling layers for one encoder head
        self.down1 = nn.Sequential(nn.Conv1d(7, 16, kernel_size=15, stride=1, padding=7), nn.BatchNorm1d(16), nn.PReLU())
        self.down2 = nn.Sequential(nn.Conv1d(16, 32, kernel_size=15, stride=2, padding=7), nn.BatchNorm1d(32), nn.PReLU())
        self.down3 = nn.Sequential(nn.Conv1d(32, 64, kernel_size=13, stride=2, padding=6), nn.BatchNorm1d(64), nn.PReLU())
        self.down4 = nn.Sequential(nn.Conv1d(64, 128, kernel_size=11, stride=2, padding=5), nn.BatchNorm1d(128), nn.PReLU())
        self.down5 = nn.Sequential(nn.Conv1d(128, 256, kernel_size=9, stride=2, padding=4), nn.BatchNorm1d(256), nn.PReLU())
        self.down6 = nn.Sequential(nn.Conv1d(256, 512, kernel_size=7, stride=2, padding=3), nn.BatchNorm1d(512), nn.PReLU())
        self.down7 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=5, stride=2, padding=2), nn.BatchNorm1d(1024), nn.PReLU())
        self.down8 = nn.Sequential(nn.Conv1d(1024, 2048, kernel_size=3, stride=2, padding=1), nn.BatchNorm1d(2048), nn.PReLU())
        self.down9 = nn.Sequential(nn.Conv1d(2048, 4096, kernel_size=3, stride=2, padding=1), nn.BatchNorm1d(4096), nn.PReLU())
        # self.down8 = nn.Sequential(nn.Conv1d(1024, 2048, kernel_size=3, stride=2, padding=1), nn.BatchNorm1d(2048), nn.ReLU())
    
    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        d9 = self.down9(d8)
        # ... (additional down-sampling steps)
        return [d1, d2, d3, d4, d5, d6, d7, d8, d9]  # Return all intermediate outputs for skip connections

In [13]:
class WaveUNet(pl.LightningModule):
    def __init__(self):
        super(WaveUNet, self).__init__()
        # Create 4 separate encoder heads
        self.encoders = Encoder()

        # Define the up-sampling layers
        # Adjust the number of input channels according to the concatenated encoder outputs
        self.up9 = nn.Sequential(nn.ConvTranspose1d(4096, 2048, kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm1d(2048), nn.PReLU())
        self.up8 = nn.Sequential(nn.ConvTranspose1d(2048, 1024, kernel_size=3, stride=2, padding=1, output_padding=0), nn.BatchNorm1d(1024), nn.PReLU())
        self.up7 = nn.Sequential(nn.ConvTranspose1d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1), nn.BatchNorm1d(512), nn.PReLU())
        self.up6 = nn.Sequential(nn.ConvTranspose1d(512, 256, kernel_size=7, stride=2, padding=3, output_padding=0), nn.BatchNorm1d(256), nn.PReLU())
        self.up5 = nn.Sequential(nn.ConvTranspose1d(256, 128, kernel_size=9, stride=2, padding=4, output_padding=0), nn.BatchNorm1d(128), nn.PReLU())
        self.up4 = nn.Sequential(nn.ConvTranspose1d(128, 64, kernel_size=11, stride=2, padding=5, output_padding=1), nn.BatchNorm1d(64), nn.PReLU())
        self.up3 = nn.Sequential(nn.ConvTranspose1d(64, 32, kernel_size=13, stride=2, padding=6, output_padding=1), nn.BatchNorm1d(32), nn.PReLU())
        self.up2 = nn.Sequential(nn.ConvTranspose1d(32, 16, kernel_size=15, stride=2, padding=7, output_padding=1), nn.BatchNorm1d(16), nn.PReLU())
        self.up1 = nn.Sequential(nn.ConvTranspose1d(16, 16, kernel_size=15, stride=1, padding=7, output_padding=0), nn.BatchNorm1d(16), nn.PReLU())
        
         # Additional convolutions for processing concatenated skip connections
        self.conv_skip_8 = nn.Sequential(nn.Conv1d(4096, 2048, kernel_size=1), nn.BatchNorm1d(2048), nn.PReLU())  # 1024 (from up3) + 512 (skip)
        self.conv_skip_7 = nn.Sequential(nn.Conv1d(2048, 1024, kernel_size=1), nn.BatchNorm1d(1024), nn.PReLU())  # 1024 (from up3) + 512 (skip)
        self.conv_skip_6 = nn.Sequential(nn.Conv1d(1024, 512, kernel_size=1), nn.BatchNorm1d(512), nn.PReLU())  # 1024 (from up3) + 512 (skip)
        self.conv_skip_5 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1), nn.BatchNorm1d(256), nn.PReLU())  # 512 (from up3) + 256 (skip)
        self.conv_skip_4 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1), nn.BatchNorm1d(128), nn.PReLU())  # 256 (from up3) + 128 (skip)
        self.conv_skip_3 = nn.Sequential(nn.Conv1d(128, 64, kernel_size=1), nn.BatchNorm1d(64), nn.PReLU())  # 128 (from up3) + 64 (skip)
        self.conv_skip_2 = nn.Sequential(nn.Conv1d(64, 32, kernel_size=1), nn.BatchNorm1d(32), nn.PReLU())  # 64 (from up3) + 32 (skip)
        self.conv_skip_1 = nn.Sequential(nn.Conv1d(32, 16, kernel_size=1), nn.BatchNorm1d(16), nn.PReLU())  # 32 (from up2) + 16 (skip)

        self.cbam_8 = CBAMBlock(2048)
        self.cbam_7 = CBAMBlock(1024)
        self.cbam_6 = CBAMBlock(512)
        self.cbam_5 = CBAMBlock(256)
        self.cbam_4 = CBAMBlock(128)
        self.cbam_3 = CBAMBlock(64)
        self.cbam_2 = CBAMBlock(32)
        self.cbam_1 = CBAMBlock(16)

        # Output layer
        self.out = nn.Conv1d(16, 2, kernel_size=1)

        self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
                fft_sizes=[1024, 2048, 4096],
                hop_sizes=[256, 512, 1024],
                win_lengths=[1024, 2048, 4096],
                scale="mel", 
                n_bins=160,
                sample_rate=44100,
                device="cuda"
            )

        self.loss_fn_2 = auraloss.time.SISDRLoss()
    
        self.loss_fn_3 = torch.nn.L1Loss()


    def forward(self, audio, drumroll):
        skip_connections = []
        encoded_output = []

        audio_input = torch.cat([audio, drumroll], axis=1)
        
        outputs = self.encoders(audio_input)
        encoded_output.append(outputs[-1])  # Use the last output for the main path
        skip_connections.append(outputs[:-1])  # Collect earlier outputs for skip connections

        # Combine the encoded representations with feature embeddings
        combined = torch.cat([o for o in encoded_output], dim=1)

        #8
        u9 = self.up9(combined)

        #8
        cb8 = self.cbam_8(skip_connections[0][-1])
        skip_8 = torch.cat([u9, cb8], dim=1)  
        skip_8 = self.conv_skip_8(skip_8)  # Convolution to adjust channel size
        u8 = self.up8(skip_8)


        #7
        cb7 = self.cbam_8(skip_connections[0][-2])
        skip_7 = torch.cat([u8, cb7], dim=1)  
        skip_7 = self.conv_skip_7(skip_7)  # Convolution to adjust channel size
        u7 = self.up7(skip_7)

        #6
        cb6 = self.cbam_8(skip_connections[0][-3])
        skip_6 = torch.cat([u7, cb6], dim=1)  
        skip_6 = self.conv_skip_6(skip_6)  # Convolution to adjust channel size
        u6 = self.up6(skip_6)

        #5
        cb5 = self.cbam_8(skip_connections[0][-4])
        skip_5 = torch.cat([u6, cb5], dim=1)  
        skip_5 = self.conv_skip_5(skip_5)  # Convolution to adjust channel size
        u5 = self.up5(skip_5)

        #4
        cb4 = self.cbam_8(skip_connections[0][-5])
        skip_4 = torch.cat([u5, cb4], dim=1)  
        skip_4 = self.conv_skip_4(skip_4)  # Convolution to adjust channel size
        u4 = self.up4(skip_4)

        #3
        cb3 = self.cbam_8(skip_connections[0][-6])
        skip_3 = torch.cat([u4, cb3], dim=1)  
        skip_3 = self.conv_skip_3(skip_3)  # Convolution to adjust channel size
        u3 = self.up3(skip_3)

        #2
        cb2 = self.cbam_8(skip_connections[0][-7])
        skip_2 = torch.cat([u3, cb2], dim=1)  
        skip_2 = self.conv_skip_2(skip_2)  # Convolution to adjust channel size
        u2 = self.up2(skip_2)

        #1
        cb1 = self.cbam_8(skip_connections[0][-8])
        skip_1 = torch.cat([u2, cb1], dim=1)  
        skip_1 = self.conv_skip_1(skip_1)  # Convolution to adjust channel size
        u1 = self.up1(skip_1)
        
        # Output
        out = self.out(u1)
        return out

    def compute_loss(self, outputs, ref_signals):
        loss = self.loss_fn(outputs, ref_signals) + self.loss_fn_2(outputs, ref_signals) +  self.loss_fn_3(outputs, ref_signals)
        return loss

    
    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        audio, drum, drumroll = batch
        
        outputs = self.forward(audio, drumroll)
        # print(outputs.size())

        if batch_idx % 64 == 0:
            input_signal = audio[0].cpu().detach().numpy().T
            generated_signal = outputs[0].cpu().detach().numpy().T
            drum_signal = drum[0].cpu().detach().numpy().T 
            wandb.log({'audio_input': [wandb.Audio(input_signal, caption="Input", sample_rate=44100)]})
            wandb.log({'audio_reference': [wandb.Audio(drum_signal, caption="Reference", sample_rate=44100)]})
            wandb.log({'audio_output': [wandb.Audio(generated_signal, caption="Output", sample_rate=44100)]})
             
            for i in range(5):
                wandb.log({f'drum_{i + 1}': [wandb.Audio(drumroll[0].cpu().detach().numpy()[i, :], caption="Output", sample_rate=44100)]})


        loss = self.compute_loss(outputs, drum)         

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
    

    def configure_optimizers(self):
        # Define your optimizer and optionally learning rate scheduler here
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
        return [optimizer], [scheduler]
        

## Lightning Callbacks

In [14]:
class SaveModelEveryNSteps(pl.Callback):
    def __init__(self, save_step_frequency=256,):
        self.save_step_frequency = save_step_frequency
        self.save_path = "D://Github//phd-drum-sep//models//DrumWaveUNet//"
        os.makedirs(self.save_path , exist_ok=True)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if (trainer.global_step + 1) % self.save_step_frequency == 0:
            checkpoint_path = os.path.join(self.save_path, f"step_{trainer.global_step + 1}.ckpt")
            trainer.save_checkpoint(checkpoint_path)

# Train Loop

In [15]:
model = WaveUNet()

In [16]:
wandb_logger = WandbLogger(project='DrumWaveUNet', log_model='all')

In [17]:
audio_data_module = AudioDataModule(all_scenes, batch_size=4, num_workers=0, persistent_workers=False)

In [18]:
trainer = pl.Trainer(
    max_epochs=1000,
    accelerator="gpu", 
    devices=-1,
    logger=wandb_logger,
    callbacks=[SaveModelEveryNSteps()],
    # accumulate_grad_batches=4,
    gradient_clip_val=5,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [19]:
trainer.fit(model, audio_data_module)

[34m[1mwandb[0m: Currently logged in as: [33mhephyrius[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name        | Type                    | Params
---------------------------------------------------------
0  | encoders    | Encoder                 | 35.4 M
1  | up9         | Sequential              | 25.2 M
2  | up8         | Sequential              | 6.3 M 
3  | up7         | Sequential              | 2.6 M 
4  | up6         | Sequential              | 918 K 
5  | up5         | Sequential              | 295 K 
6  | up4         | Sequential              | 90.3 K
7  | up3         | Sequential              | 26.7 K
8  | up2         | Sequential              | 7.7 K 
9  | up1         | Sequential              | 3.9 K 
10 | conv_skip_8 | Sequential              | 8.4 M 
11 | conv_skip_7 | Sequential              | 2.1 M 
12 | conv_skip_6 | Sequential              | 525 K 
13 | conv_skip_5 | Sequential              | 131 K 
14 | conv_skip_4 | Sequential              | 33.2 K
15 | conv_skip_3 | Sequential              | 8.4 K 
16 | conv_skip_

Training: |                                                                                      | 0/? [00:00<…

RuntimeError: Given groups=1, weight of size [128, 2048, 1], expected input[4, 1024, 1] to have 2048 channels, but got 1024 channels instead