<a href="https://colab.research.google.com/github/Ekstaxy/Mix_Wave_U_Net/blob/main/Wave_U_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wave-U-Net Practice

In [2]:
from google.colab import drive
drive.mount('/content/drive')

import zipfile
import os

# Mount Google Drive
drive.mount('/content/drive')

# Extract your zip file (update the path to your actual zip file)
zip_path = '/content/drive/MyDrive/ENST-drums-audio.zip'  # Change this to your zip path
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall('/content/')

# Set your dataset directory (update based on extracted folder name)
dataset_dir = '/content/ENST-drums-audio'  # Change this to match your extracted folder

# Verify it worked
print("Dataset contents:", os.listdir(dataset_dir))

Mounted at /content/drive
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Dataset contents: ['ENST-drums-public']


## Import Package

In [3]:
import numpy as np
import pandas as pd
import math
import glob
import random
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio

!pip install pyloudnorm
import pyloudnorm
!pip install auraloss
import auraloss

import warnings
# Add this at the top of your code to suppress these specific warnings
warnings.filterwarnings("ignore", message="Possible clipped samples in output.")

# !pip install git+https://github.com/csteinmetz1/automix-toolkit
# import automix.utils
# import automix.data

Collecting pyloudnorm
  Downloading pyloudnorm-0.1.1-py3-none-any.whl.metadata (5.6 kB)
Downloading pyloudnorm-0.1.1-py3-none-any.whl (9.6 kB)
Installing collected packages: pyloudnorm
Successfully installed pyloudnorm-0.1.1
Collecting auraloss
  Downloading auraloss-0.4.0-py3-none-any.whl.metadata (8.0 kB)
Downloading auraloss-0.4.0-py3-none-any.whl (16 kB)
Installing collected packages: auraloss
Successfully installed auraloss-0.4.0


## Load in Data

In [4]:
class ENST_Drumset(Dataset):
    def __init__(
        self,
        root_dir: str,
        sr: float,
        length: int,
        drummers: List[int] = [1, 2],
        track_names: List[str] = [
            "kick",
            "snare",
            "hihat",
            "overhead_L",
            "overhead_R",
            "tom_1",
            "tom_2",
            "tom_3"
        ],
        indices: Tuple[int, int] = [0, 1],
        wet_mix: bool = False,
        hits: bool = False,
        num_examples_per_epoch: int = 1000,
        seed: int = 42
    ) -> None:
        super().__init__()

        self.root_dir = root_dir
        self.length = length
        self.sr = sr
        self.drummers = drummers
        self.track_names = track_names
        self.indices = indices
        self.wet_mix = wet_mix
        self.hits = hits
        self.num_examples_per_epoch = num_examples_per_epoch
        self.seed = seed
        self.max_num_tracks = 8
        self.mix_filepaths = []

        if not os.path.isdir(root_dir):
            raise FileNotFoundError(f"找不到指定的音訊檔案目錄：{root_dir}")

        for drummer in drummers:
            search_path = os.path.join(
                root_dir,
                f"drummer_{drummer}",
                "audio",
                "wet_mix" if wet_mix else "dry_mix",
                "*.wav",
            )
            self.mix_filepaths += glob.glob(search_path)

        # remove any mixes that is shorter than required length
        self.mix_filepaths = [
            fp
            for fp in self.mix_filepaths
            if torchaudio.info(fp).num_frames > self.length
        ]

        # remove any mixes that have "norm" in the filename
        self.mix_filepaths = [fp for fp in self.mix_filepaths if not "norm" in fp]

        # remove any mixes that are just hits
        if not self.hits:
            self.mix_filepaths = [fp for fp in self.mix_filepaths if "hits" not in fp]

        random.Random(seed).shuffle(self.mix_filepaths)
        self.mix_filepaths = self.mix_filepaths[indices[0] : indices[1]]

        if len(self.mix_filepaths) < 1:
            raise RuntimeError(f"No files found in {self.root_dir}.")
        else:
            print(f"Found {len(self.mix_filepaths)} examples from drummers: {drummers}")


    def __len__(self):
        return len(self.mix_filepaths)

    def __getitem__(self, idx):
        mix_idx = idx
        if mix_idx >= len(self.mix_filepaths):
            raise IndexError(f"idx : {mix_idx} out of sequence length : {len(self.mix_filepaths)}")
        mix_filepath = self.mix_filepaths[idx]
        example_id = os.path.basename(mix_filepath)
        drummer_id = os.path.normpath(mix_filepath).split(os.path.sep)[-4]

        md = torchaudio.info(mix_filepath)  # check length

        silent = True
        while silent:
            # get random offset
            offset = offset = np.random.randint(0, md.num_frames - self.length - 1)

            y, sr = torchaudio.load(
                uri = mix_filepath,
                frame_offset = offset,
                num_frames = self.length,
                normalize = True,
            )
            energy = (y**2).mean()
            if energy > 1e-8:
                silent = False

        y = y.float()
        y /= y.abs().max().clamp(1e-8)  # peak normalize
        # y_numpy = y.squeeze().numpy()
        # if y_numpy.ndim > 1:
        #     y_numpy = y_numpy.T
        # meter = pyloudnorm.Meter(self.sr)  # create BS.1770 meter
        # y_loudness = meter.integrated_loudness(y_numpy)

        # y_normalized_numpy = pyloudnorm.normalize.loudness(y_numpy, y_loudness, -24.0)
        # y_normalized_numpy = np.clip(y_normalized_numpy, -1.0, 1.0)
        # if y_normalized_numpy.ndim > 1:
        #     y_normalized_numpy = y_normalized_numpy.T
        # y = torch.from_numpy(y_normalized_numpy)  # Add channel dim

        x = torch.zeros((self.max_num_tracks, self.length))
        pad = [True] * self.max_num_tracks  # note which tracks are empty

        for tidx, track_name in enumerate(self.track_names):
            track_path = os.path.join(
                self.root_dir,
                drummer_id,
                "audio",
                track_name,
                example_id
            )
            if os.path.isfile(track_path):
                x_s, sr = torchaudio.load(
                    uri = track_path,
                    frame_offset = offset,
                    num_frames = self.length
                )
                # # Convert the PyTorch tensor to a NumPy array
                # x_s_numpy = x_s.squeeze().numpy()
                # # Transpose for pyloudnorm: (channels, samples) -> (samples, channels)
                # if x_s_numpy.ndim > 1:
                #     x_s_numpy = x_s_numpy.T
                # # Create a loudness meter instance from pyloudnorm
                # meter = pyloudnorm.Meter(self.sr)  # create BS.1770 meter
                # # Measure the loudness of the NumPy array
                # x_loudness = meter.integrated_loudness(x_s_numpy)
                # # Normalize the NumPy array to -24 LUFS

                # x_s_normalized_numpy = pyloudnorm.normalize.loudness(x_s_numpy, x_loudness, -24.0)
                # x_s_normalized_numpy = np.clip(x_s_normalized_numpy, -1.0, 1.0)
                # # Transpose back for PyTorch: (samples, channels) -> (channels, samples)
                # if x_s_normalized_numpy.ndim > 1:
                #     x_s_normalized_numpy = x_s_normalized_numpy.T
                # # Convert the normalized NumPy array back to a PyTorch tensor
                # x_s = torch.from_numpy(x_s_normalized_numpy)
                x_s /= x_s.abs().max().clamp(1e-6)
                x_s *= 10 ** (-12 / 20.0)
                x[tidx, :] = x_s
                pad[tidx] = False

        return x, y, torch.tensor(pad)

## Model Define

In [5]:
class DownSampling(nn.Module):
    def __init__(
        self,
        channel_in: int,
        channel_out: int,
        kernel_size: int = 15
    ):
        super().__init__()

        # 確保 kernel_size不可被 2 整除，padding = 'same'
        assert kernel_size % 2 != 0
        padding = kernel_size // 2

        self.conv1 = nn.Conv1d(
            channel_in,
            channel_out,
            kernel_size = kernel_size,
            padding = padding
        )
        self.batchnorm = nn.BatchNorm1d(channel_out)
        self.prelu = nn.PReLU(channel_out)
        self.conv2 = nn.Conv1d(
            channel_out,
            channel_out,
            kernel_size = kernel_size,
            stride = 2,
            padding = padding
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm(x)
        x = self.prelu(x)
        x_ds = self.conv2(x)
        return x_ds, x

class UpSampling(nn.Module):
    def __init__(
        self,
        channel_in: int,
        channel_out: int,
        kernel_size: int = 5,
        skip: str = 'add'
    ):
        super().__init__()

        assert kernel_size % 2 != 0
        padding = kernel_size // 2

        self.skip = skip

        self.conv = nn.Conv1d(
            channel_in,
            channel_out,
            kernel_size = kernel_size,
            padding = padding
        )
        self.batchnorm = nn.BatchNorm1d(channel_out)
        self.prelu = nn.PReLU(channel_out)
        self.upsampling = nn.Upsample(scale_factor = 2)

    def forward(self, x: torch.Tensor, skip: torch.Tensor):
        x = self.upsampling(x)

        if self.skip == 'add':
            x = x + skip
        elif self.skip == 'concat':
            x = torch.cat((x, skip), dim = 1)
        elif self.skip == 'none':
            pass
        else:
            raise NotImplementedError()

        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.prelu(x)

        return x

class MixWaveUNet(nn.Module):
    def __init__(
        self,
        n_inputs: int = 8,
        n_outputs: int = 2,
        ds_kernel: int = 15,    # 13
        us_kernel: int = 5,     # 13
        out_kernel: int = 5,
        layers: int = 10,
        channel_growth: int = 24,
        skip: str = 'concat'
    ):
        super().__init__()

        self.n_inputs = n_inputs
        self.n_outputs = n_outputs

        self.encoder = nn.ModuleList()
        for n in range(layers):
            if n == 0:
                channel_in = n_inputs
                channel_out = channel_growth
            else:
                channel_in = channel_out
                channel_out = channel_in + channel_growth

            self.encoder.append(DownSampling(channel_in, channel_out, kernel_size = ds_kernel))

        # print("Encoder Structure")
        # print(self.encoder)

        self.embedding = nn.Conv1d(channel_out, channel_out, kernel_size = 1)
        # print("Embedder Structure")
        # print(self.embedding)

        self.decoder = nn.ModuleList()
        for n in range(layers, 0, -1):
            channel_in = channel_out
            channel_out = channel_in - channel_growth

            if channel_out < channel_growth:
                channel_out = channel_growth

            if skip == 'concat':
                channel_in *= 2

            self.decoder.append(UpSampling(channel_in, channel_out, kernel_size=us_kernel, skip=skip))

        # print("Decoder Structure")
        # print(self.decoder)

        self.output_conv = nn.Conv1d(channel_out + n_inputs, n_outputs, kernel_size = out_kernel, padding = out_kernel // 2)

    def forward(self, x):
        x_in = x
        skips = []

        for enc in self.encoder:
            x, skip = enc(x)
            skips.append(skip)

        x = self.embedding(x)

        for dec in self.decoder:
            skip = skips.pop()
            x = dec(x, skip)

        x = torch.cat((x_in, x), dim = 1)
        y = self.output_conv(x)

        return y, torch.zeros(1)

## Train

In [6]:
dataset_dir = '/content/ENST-drums-audio/ENST-drums-public'
dataset_name = 'ESTN_Drum_Dataset'
sample_rate = 44100
train_length = 262144
val_length = 262144
test_length = 262144
output_length = 262144
max_num_track = 8
wet_mix = False

batch_size = 16
lr = 0.001
max_epochs = 50
patient = 5
num_workers = 1

model = MixWaveUNet()
loss_function = auraloss.freq.SumAndDifferenceSTFTLoss(
    fft_sizes = [512, 1024, 2048, 4096],  # Multiple scales for multi-resolution
    hop_sizes = [128, 256, 512, 1024],    # 25% overlap (hop = fft_size/4)
    win_lengths = [512, 1024, 2048, 4096], # Same as fft_sizes for full window
    window = "hann_window",
    w_sum = 1.0,
    w_diff = 1.0,
    output = "loss",
    sample_rate = sample_rate,
)
optimizer = optim.Adam(model.parameters(), lr = lr)
model_save_path = '/content/ckpt'

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)

Using device: cuda


MixWaveUNet(
  (encoder): ModuleList(
    (0): DownSampling(
      (conv1): Conv1d(8, 24, kernel_size=(15,), stride=(1,), padding=(7,))
      (batchnorm): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=24)
      (conv2): Conv1d(24, 24, kernel_size=(15,), stride=(2,), padding=(7,))
    )
    (1): DownSampling(
      (conv1): Conv1d(24, 48, kernel_size=(15,), stride=(1,), padding=(7,))
      (batchnorm): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=48)
      (conv2): Conv1d(48, 48, kernel_size=(15,), stride=(2,), padding=(7,))
    )
    (2): DownSampling(
      (conv1): Conv1d(48, 72, kernel_size=(15,), stride=(1,), padding=(7,))
      (batchnorm): BatchNorm1d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=72)
      (conv2): Conv1d(72, 72, kernel_size=(15,), stride=(2,), padding=(7,))
    )
   

In [7]:
train_dataset = ENST_Drumset(
    root_dir = dataset_dir,
    sr = sample_rate,
    length = train_length,
    drummers = [1, 2, 3],
    indices = [0, 168],
    num_examples_per_epoch = 1000,
    wet_mix = wet_mix,
)
val_dataset = ENST_Drumset(
    root_dir = dataset_dir,
    sr = sample_rate,
    length = val_length,
    drummers = [1, 2, 3],
    indices = [168, 189],
    num_examples_per_epoch = 1000,
    wet_mix = wet_mix,
)
test_dataset = ENST_Drumset(
    root_dir = dataset_dir,
    sr = sample_rate,
    length = test_length,
    drummers = [1, 2, 3],
    indices = [189, 210],
    num_examples_per_epoch = 1000,
    wet_mix = wet_mix,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = num_workers,
    persistent_workers=True,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size = batch_size,
    shuffle = False,
    num_workers=1,
    persistent_workers=True,
)

  if torchaudio.info(fp).num_frames > self.length
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  return AudioMetaData(


Found 168 examples from drummers: [1, 2, 3]
Found 21 examples from drummers: [1, 2, 3]
Found 21 examples from drummers: [1, 2, 3]


In [None]:
best_val_lost = 10000000
epochs_no_improve = 0

# Training Loop
for epoch in range(max_epochs):
    model.train()
    total_train_loss = 0

    # idx, x, y = next(iter(train_dataloader))
    # print(idx, x, y)

    for idx, data in enumerate(train_dataloader):
        x_batch, y_batch, pad = data
        # x_batch = x_batch.unsqueeze(1)
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        output, _ = model(x_batch)
        loss = loss_function(output, y_batch)

        loss.backward()

        optimizer.step()
        total_train_loss += loss.item()
        torch.cuda.empty_cache()

    avg_train_loss = total_train_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{max_epochs}")
    print(f"Train Loss: {avg_train_loss:.4f}")

    # Model Evaluation
    total_val_loss = 0
    with torch.no_grad():
        for idx, data in enumerate(val_dataloader):
            x_batch, y_batch, pad = data
            # x_batch = x_batch.unsqueeze(1)
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            output, _ = model(x_batch)
            output = output.clamp(min=-1.0, max=1.0)
            loss = loss_function(output, y_batch)

            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_dataloader)

    if best_val_lost > avg_val_loss:
        best_val_lost = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), model_save_path)

    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patient:
            print(f"Early stopping triggered after {epoch+1} epochs due to no improvement for {patient} epochs.")
            break


print("\nTraining complete.")
model.load_state_dict(torch.load(model_save_path))
model.eval()

  md = torchaudio.info(mix_filepath)  # check length
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 1/50
Train Loss: 4.8441


  md = torchaudio.info(mix_filepath)  # check length
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 2/50
Train Loss: 1.9976
Epoch 3/50
Train Loss: 1.6886
Epoch 4/50
Train Loss: 1.6050
Epoch 5/50
Train Loss: 1.5370
Epoch 6/50
Train Loss: 1.4700


## Test

In [None]:
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,  # Use a batch size of 1 for evaluation
    shuffle=False,
    num_workers=1,
    persistent_workers=True,
)

In [None]:
output_dir = './output_audio'
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

In [None]:
from IPython.display import Audio, display

# Set the model to evaluation mode
model.eval()

# Loop through the test dataset
with torch.no_grad():
    for idx, data in enumerate(test_dataloader):
        # Unpack the data
        x_batch, y_batch, pad = data

        # Move tensors to the correct device
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # Pass the mix through the trained model
        predicted_output, _ = model(x_batch)
        predicted_output = predicted_output.clamp(min=-1.0, max=1.0)

        # Move the tensors to the CPU for saving
        x_batch = x_batch.cpu()
        y_batch = y_batch.cpu()
        predicted_output = predicted_output.cpu()

        print(f"Saving and playing example {idx+1}...")

        # Save the original mix (y_batch)
        original_mix_path = os.path.join(output_dir, f"example_{idx+1}_original_mix.wav")
        torchaudio.save(
            original_mix_path,
            y_batch.squeeze(0),
            sample_rate,
        )

        # Save the model's two-channel output as a single stereo file
        separated_mix_path = os.path.join(output_dir, f"example_{idx+1}_separated_mix.wav")
        torchaudio.save(
            separated_mix_path,
            predicted_output.squeeze(0),
            sample_rate,
        )

        # Display the audio players for this example
        print("Original Mix:")
        display(Audio(original_mix_path))

        print("Separated Mix:")
        display(Audio(separated_mix_path))

print("\nEvaluation complete. All audio files saved to the 'output_audio' directory.")