In [2]:
import os
import numpy as np
import pandas as pd
import glob
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sympy.codegen.cnodes import static
from torch.utils.data import Dataset, DataLoader

In [3]:
# Dataset locations, written relative to the notebook's working directory.
# Layout is <root>/<stem>/<split>, where stem is "mixture" or "vocal".
_DATA_ROOT = "data"
mixture_directory_train = f"{_DATA_ROOT}/mixture/train"
mixture_directory_test = f"{_DATA_ROOT}/mixture/test"
vocal_directory_train = f"{_DATA_ROOT}/vocal/train"
vocal_directory_test = f"{_DATA_ROOT}/vocal/test"

In [4]:
# Tunable settings for the training run; adjust these to experiment.
model_save_path = "vocal_isolator.pth"  # file the trained weights are written to
epochs = 20
batch_size = 4
# NOTE(review): patch_size and stride are not referenced by the cells shown here;
# presumably intended for patch-based training - confirm before relying on them.
patch_size = 128
stride = 64

# Train/evaluate on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
class SpectrogramDataset(Dataset):
    """Paired (mixture, vocal) spectrogram dataset backed by ``.npy`` files.

    Files are paired by name: ``mix_<key>`` in ``mixed_dir`` is matched with
    ``vocal_<key>`` in ``vocal_dir``. Keys present in only one of the two
    directories are ignored.

    Args:
        mixed_dir: Directory containing the ``mix_*.npy`` arrays.
        vocal_dir: Directory containing the ``vocal_*.npy`` arrays.
        use_magnitude: When True (default) every item is a pair of min-max
            normalised float32 magnitude tensors of shape ``(1, H, W)``, with
            H and W zero-padded up to the next multiple of 16. When False the
            loaded arrays are returned as tensors with no normalisation,
            channel axis or padding.

    Raises:
        AssertionError: If the two directories share no key.
    """

    def __init__(self, mixed_dir, vocal_dir, use_magnitude=True):
        self.mixed_map = self._make_file_map(mixed_dir, prefix='mix_')
        self.vocals_map = self._make_file_map(vocal_dir, prefix='vocal_')

        # Only songs that exist in both directories can form a training pair.
        self.keys = sorted(set(self.mixed_map) & set(self.vocals_map))
        assert self.keys, "No matching files found between mixed and vocal directories"

        self.pairs = [(self.mixed_map[k], self.vocals_map[k]) for k in self.keys]
        self.use_magnitude = use_magnitude

    @staticmethod
    def _make_file_map(directory, prefix):
        """Map ``<filename without prefix>`` -> path for every ``<prefix>*.npy`` in ``directory``."""
        file_map = {}
        for path in glob.glob(os.path.join(directory, "*.npy")):
            filename = os.path.basename(path)
            if filename.startswith(prefix):
                key = filename[len(prefix):]  # strip the prefix, keep the extension
                file_map[key] = path
        return file_map

    @staticmethod
    def _load_npy(filepath):
        """Load one array from disk with ``np.load``."""
        return np.load(filepath)

    @staticmethod
    def _pad_to_multiple_16(tensor, multiple=16):
        """Zero-pad the last two axes of a 3-D tensor up to a multiple of ``multiple``."""
        _, h, w = tensor.shape
        pad_h = (multiple - h % multiple) % multiple
        pad_w = (multiple - w % multiple) % multiple
        # F.pad lists amounts for the last axis first: (left, right, top, bottom).
        return F.pad(tensor, (0, pad_w, 0, pad_h))

    @staticmethod
    def _normalized_magnitude(stft):
        """Return ``|stft|`` min-max scaled to [0, 1] as a float32 array.

        The cast to float32 is deliberate: a complex128 input would otherwise
        yield float64 data, which does not match float32 model weights.
        """
        magnitude = np.abs(stft).astype(np.float32)
        lowest = magnitude.min()
        highest = magnitude.max()
        # The epsilon keeps the division defined for a constant spectrogram.
        return (magnitude - lowest) / (highest - lowest + 1e-8)

    def __len__(self):
        """Number of (mixture, vocal) file pairs."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Load and return the ``index``-th ``(mixture, vocal)`` tensor pair."""
        mixed_path, vocals_path = self.pairs[index]

        mixed_stft = self._load_npy(mixed_path)
        vocals_stft = self._load_npy(vocals_path)

        if not self.use_magnitude:
            # Raw mode: hand the loaded arrays back untouched.
            return torch.tensor(mixed_stft), torch.tensor(vocals_stft)

        tensors = []
        for stft in (mixed_stft, vocals_stft):
            magnitude = self._normalized_magnitude(stft)
            # Add the channel axis expected by the 2-D convolutions: (1, H, W).
            spectrogram = torch.tensor(magnitude, dtype=torch.float32).unsqueeze(0)
            tensors.append(self._pad_to_multiple_16(spectrogram))

        return tensors[0], tensors[1]
        
def pad_collate_fn(batch):
    """
    Collate (mixed, vocals) pairs into two stacked batch tensors.

    Every spectrogram is padded on its bottom/right edges so that all items
    share one height and width. The target size is the largest height and the
    largest width found among the *mixed* tensors of the batch.

    Returns a tuple ``(mixed_batch, vocals_batch)``.
    """
    target_h = max(mixed.shape[1] for mixed, _ in batch)
    target_w = max(mixed.shape[2] for mixed, _ in batch)

    def _grow(spec):
        # F.pad takes amounts for the last axis first: (left, right, top, bottom).
        extra_w = target_w - spec.shape[2]
        extra_h = target_h - spec.shape[1]
        return F.pad(spec, (0, extra_w, 0, extra_h))

    mixed_batch = torch.stack([_grow(mixed) for mixed, _ in batch])
    vocals_batch = torch.stack([_grow(vocals) for _, vocals in batch])
    return mixed_batch, vocals_batch


In [6]:
class UNet(nn.Module):  # TODO: decide whether a plain CNN would be sufficient instead of a U-Net
    """U-Net that maps a 1-channel spectrogram to a 1-channel mask-like output.

    The network has four 2x max-pooling stages, so internally the spatial size
    must be a multiple of 16. ``forward`` zero-pads any other input size on
    the bottom/right edges and crops the result back, so the output always has
    the same height and width as the input. The final sigmoid keeps every
    output value in [0, 1].
    """

    # Four stride-2 poolings -> feature maps shrink by 2**4 at the bottleneck.
    _DOWNSAMPLE_FACTOR = 16

    def __init__(self):
        super().__init__()

        def conv_block(in_channels, out_channels):
            """Two (3x3 conv -> batch norm -> ReLU) stages; spatial size is preserved."""
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),  # 3x3 filter for feature detection
                nn.BatchNorm2d(out_channels),  # stabilises and speeds up training
                nn.ReLU(inplace=True),  # non-linearity for complex pattern learning
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        # Encoding - reduce spatial dimensions and abstract features
        self.encoder1 = conv_block(1, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # halve resolution to widen the context
        self.encoder2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.center = conv_block(512, 1024)  # bottleneck between encoder and decoder

        # Decoding - upsample and reconstruct the isolated vocals.
        # Each dec* block takes twice the channels of its up* output because
        # the matching encoder feature map is concatenated in (skip connection).
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = conv_block(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = conv_block(128, 64)

        self.final = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid()
        )  # reduce back to 1 channel (vocal spectrogram)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(N, 1, H, W)``; returns ``(N, 1, H, W)``."""
        height, width = x.shape[-2:]
        # Pad up to the next multiple of 16 so every pool/upsample pair lines up
        # and the skip-connection concatenations see matching sizes.
        pad_h = (-height) % self._DOWNSAMPLE_FACTOR
        pad_w = (-width) % self._DOWNSAMPLE_FACTOR
        padded = bool(pad_h or pad_w)
        if padded:
            x = F.pad(x, (0, pad_w, 0, pad_h))

        # encoding
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool1(e1))
        e3 = self.encoder3(self.pool2(e2))
        e4 = self.encoder4(self.pool3(e3))
        center = self.center(self.pool4(e4))

        # decoding
        d4 = self.dec4(torch.cat([self.up4(center), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        out = self.final(d1)

        if padded:
            # Drop the rows/columns that only exist because of the padding above.
            out = out[..., :height, :width]
        return out

In [7]:
# helper functions for training function
def compute_sdr(target, prediction, eps=1e-8):
    """Signal-to-distortion ratio of ``prediction`` against ``target``, in dB.

    Computed as ``10 * log10(sum(target**2) / sum((target - prediction)**2))``
    over all elements of the tensors.

    Args:
        target: Reference tensor.
        prediction: Estimated tensor, broadcastable against ``target``.
        eps: Small constant added to both the signal and the noise energy so
            the result stays finite for an all-zero target or a perfect
            prediction.

    Returns:
        A 0-dim tensor holding the SDR in decibels.
    """
    noise = target - prediction
    signal_energy = torch.sum(target ** 2)
    noise_energy = torch.sum(noise ** 2)
    return 10 * torch.log10((signal_energy + eps) / (noise_energy + eps))

def save_prediction_as_xlsx(prediction, filename):
    """Dump a predicted tensor to an Excel file as a bare grid of numbers.

    Size-1 axes are squeezed away and the values are moved to host memory
    before writing. The sheet is written without index or header cells.
    """
    # Detach from autograd and copy to the CPU so the data can become a NumPy array.
    grid = prediction.detach().cpu().squeeze().numpy()
    pd.DataFrame(grid).to_excel(filename, index=False, header=False)

In [8]:
# training function to be called
def train(model, train_loader, val_loader, criterion, optimizer, device, num_epochs, checkpoint_path="vocal_isolator.pth"):
    """Train ``model`` and validate it after every epoch.

    After each epoch the function prints the mean training loss, then the mean
    validation loss, MAE and SDR, overwrites ``checkpoint_path`` with the
    current weights and writes the prediction for the first validation sample
    to ``predicted_epoch_<epoch>.xlsx`` (``<epoch>`` is 0-based).

    Args:
        model: Network to optimise; it is moved to ``device``.
        train_loader: Iterable of ``(mixed, vocals)`` training batches.
        val_loader: Iterable of ``(mixed, vocals)`` validation batches.
        criterion: Loss called as ``criterion(outputs, vocals)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Device on which batches and the model are placed.
        num_epochs: Number of passes over ``train_loader``.
        checkpoint_path: File that receives ``model.state_dict()`` each epoch.

    Returns:
        The trained model.
    """
    print("Starting training loop...")
    model.to(device)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs} starting...")

        # TRAINING
        model.train()  # enable batch-norm statistics updates
        running_loss = 0.0

        for mixed, vocals in train_loader:
            mixed = mixed.to(device)
            vocals = vocals.to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(mixed)
            loss = criterion(outputs, vocals)

            # Backward pass
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean loss per batch; reported 1-based to match the "starting" message.
        train_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs} — Train Loss: {train_loss:.4f}")

        # VALIDATION
        model.eval()  # freeze batch-norm statistics for evaluation
        val_loss, mae, sdr = 0.0, 0.0, 0.0
        with torch.no_grad():
            for mixed, vocals in val_loader:
                mixed, vocals = mixed.to(device), vocals.to(device)
                outputs = model(mixed)
                val_loss += criterion(outputs, vocals).item()
                mae += F.l1_loss(outputs, vocals).item()
                # compute_sdr expects (target, prediction): the ground-truth
                # vocals are the reference signal, the network output the estimate.
                sdr += compute_sdr(vocals, outputs).item()
        val_loss /= len(val_loader)
        mae /= len(val_loader)
        sdr /= len(val_loader)
        print(f"Validation: Loss = {val_loss:.4f}, MAE = {mae:.4f}, SDR = {sdr:.4f} dB")

        torch.save(model.state_dict(), checkpoint_path)

        # Save the prediction for the first validation sample. No gradients are
        # needed here, so avoid building an autograd graph for it.
        with torch.no_grad():
            mixed_sample, _ = next(iter(val_loader))
            mixed_sample = mixed_sample.to(device)
            predicted = model(mixed_sample[:1])
        save_prediction_as_xlsx(predicted, f"predicted_epoch_{epoch}.xlsx")
        print(f"Saved sample prediction: predicted_epoch_{epoch}.xlsx")

    print(f"Training complete, model saved at {checkpoint_path}")
    return model

In [None]:
# training block
# Training paths, batch size, epoch count, device and checkpoint name all come
# from the configuration cells above so there is a single place to change them.
# The validation split has no constant in the path cell, so it is named here.
val_mixture_directory = "data/mixture/val"
val_vocal_directory = "data/vocal/val"

train_dataset = SpectrogramDataset(mixed_dir=mixture_directory_train, vocal_dir=vocal_directory_train)
val_dataset = SpectrogramDataset(mixed_dir=val_mixture_directory, vocal_dir=val_vocal_directory)

# pad_collate_fn equalises spectrogram sizes inside each batch before stacking.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate_fn, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=pad_collate_fn, num_workers=0)

# Quick sanity check: load a few samples and show their shapes.
print("Testing dataset loading...")
for i in range(min(3, len(train_dataset))):
    x, y = train_dataset[i]
    print(f"Sample {i} : x shape {x.shape}, y shape {y.shape}")

print(f"Training pairs: {len(train_dataset)}")
print(f"Validation pairs: {len(val_dataset)}")

# `device` is defined in the parameters cell; train() also moves the model there.
model = UNet().to(device)

criterion = torch.nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Keep the returned model in a variable so the cell does not echo the full
# module repr as its output.
model = train(
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=epochs,
    checkpoint_path=model_save_path,
)

Testing dataset loading...
Sample 0 : x shape torch.Size([1, 528, 5184]), y shape torch.Size([1, 528, 5184])
Sample 1 : x shape torch.Size([1, 528, 5184]), y shape torch.Size([1, 528, 5184])
