In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [None]:
# Input data locations (relative paths): one for the vocals, one for the mixtures.
vocals_directory = "data/vocals"
mixture_directory = "data/mixture"

In [None]:
# Training hyper-parameters; tweak these to tune how well training performs.
patch_size, stride = 128, 64
batch_size, epochs = 16, 20
model_save_path = "vocal_isolator.pth"
# Train/test on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class SpectrogramDataset(Dataset):
    """Dataset that loads complex spectrograms from Excel workbooks.

    Every workbook is read with ``pandas.read_excel`` and must provide three
    header-less sheets: ``Re_X`` (real part), ``Im_X`` (imaginary part) and
    ``Sampling Rate`` (its top-left cell is returned as the sampling rate).

    Args:
        filepaths: Sequence of workbook paths; one sample per path.
        sequence_length: Stored on the instance; not used by this class itself.
        label_map: Mapping from class name to integer class index. When
            omitted, ``{"instrumental": 0, "acapella": 1, "full": 2}`` is used.
            The dataset keeps its own copy, so instances never share (or
            mutate) a common default dictionary.
        labels: Optional per-file labels aligned with ``filepaths``. Each entry
            is either a key of ``label_map`` or an integer class index. When
            omitted, each label is inferred from the file path (see
            ``_infer_label``).

    Raises:
        ValueError: If ``labels`` is given but its length differs from
            ``filepaths``.
        KeyError: If ``labels`` contains a name missing from ``label_map``.

    NOTE(review): the original code never populated ``self.labels``, so every
    lookup failed. Inferring the class from the path assumes the class name
    appears somewhere in the file or folder name - confirm this convention,
    or pass ``labels`` explicitly.
    """

    _SHEET_NAMES = ("Re_X", "Im_X", "Sampling Rate")

    def __init__(self, filepaths, sequence_length=1, label_map=None, labels=None):
        self.filepaths = filepaths
        self.sequence_length = sequence_length

        if label_map is None:
            label_map = {"instrumental": 0, "acapella": 1, "full": 2}
        self.label_map = dict(label_map)  # private copy: avoids a shared mutable default

        if labels is None:
            self.labels = [self._infer_label(path) for path in filepaths]
        else:
            labels = list(labels)
            if len(labels) != len(filepaths):
                raise ValueError(
                    f"Got {len(labels)} labels for {len(filepaths)} file paths"
                )
            self.labels = [self._encode_label(label) for label in labels]

    def _encode_label(self, label):
        """Translate a class name through ``label_map``; pass class indices through as ``int``."""
        if isinstance(label, str):
            return self.label_map[label]
        return int(label)

    def _infer_label(self, filepath):
        """Return the class index whose name occurs in ``filepath``, or ``None``.

        Matching is case-insensitive. If several class names occur, the one
        whose last occurrence is right-most in the path (closest to the file
        name) wins.
        """
        lowered = str(filepath).lower()
        best_position, best_index = -1, None
        for name, class_index in self.label_map.items():
            position = lowered.rfind(str(name).lower())
            if position > best_position:
                best_position, best_index = position, class_index
        return best_index

    def __len__(self):
        """Number of samples, i.e. number of file paths."""
        return len(self.filepaths)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A ``(signal, label, fs)`` tuple: the complex spectrogram as a
            tensor, the class index as a tensor, and the sampling rate exactly
            as stored in the workbook.

        Raises:
            ValueError: If no label is known for the requested file.
        """
        filepath = self.filepaths[index]
        label = self.labels[index]
        if label is None:
            raise ValueError(
                f"No label could be inferred for {filepath!r}; "
                "pass `labels` explicitly or include a class name in the path"
            )

        # Read all three sheets with a single call instead of opening the
        # workbook once per sheet.
        sheets = pd.read_excel(filepath, sheet_name=list(self._SHEET_NAMES), header=None)
        real_part = sheets["Re_X"].values
        imag_part = sheets["Im_X"].values
        fs = sheets["Sampling Rate"].values[0][0]

        # Combine real and imaginary parts into one complex-valued array.
        signal = real_part + 1j * imag_part

        return torch.tensor(signal), torch.tensor(label), fs


In [None]:
class UNet(nn.Module):  # TODO: decide whether a simpler CNN would do instead of a U-Net
    """U-Net style encoder/decoder built from 2-D convolutions.

    Four encoder stages (each followed by 2x2 max-pooling) feed a bottleneck
    block; four decoder stages upsample with transposed convolutions and
    concatenate the matching encoder output (skip connection) before
    convolving. A final 1x1 convolution maps to ``out_channels``.

    Args:
        in_channels: Channels of the input tensor (default 1).
        out_channels: Channels of the output tensor (default 1).
        base_channels: Width of the first encoder stage (default 64); every
            deeper stage doubles it, so the bottleneck has ``16 * base_channels``.

    With the default arguments the layers, their attribute names and their
    creation order are the same as before, so existing checkpoints still load.
    """

    # Four 2x2 max-pools halve each spatial dimension four times (2 ** 4 = 16).
    _DOWNSAMPLE_FACTOR = 2 ** 4

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()

        def conv_block(block_in, block_out):
            """Two rounds of 3x3 conv -> batch norm -> ReLU; spatial size is preserved."""
            return nn.Sequential(
                nn.Conv2d(block_in, block_out, kernel_size=3, padding=1),  # 3x3 filter for feature detection
                nn.BatchNorm2d(block_out),  # stabilizes and speeds up training
                nn.ReLU(inplace=True),  # non-linearity for complex pattern learning
                nn.Conv2d(block_out, block_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(block_out),
                nn.ReLU(inplace=True),
            )

        # Channel widths per depth: base, 2x, 4x, 8x and 16x (bottleneck).
        c1, c2, c3, c4, c5 = (base_channels * 2 ** depth for depth in range(5))

        # Encoding - reduce spatial dimensions while widening the feature maps
        self.encoder1 = conv_block(in_channels, c1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # halve resolution to allow for larger context
        self.encoder2 = conv_block(c1, c2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = conv_block(c2, c3)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = conv_block(c3, c4)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.center = conv_block(c4, c5)  # bottleneck at the lowest resolution

        # Decoding - upsample and fuse with the encoder output of the same depth.
        # Each dec* block takes twice the upsampled width because the skip
        # connection is concatenated along the channel dimension.
        self.up4 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.dec4 = conv_block(c5, c4)
        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = conv_block(c4, c3)
        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = conv_block(c3, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = conv_block(c2, c1)

        self.final = nn.Conv2d(c1, out_channels, kernel_size=1)  # 1x1 conv down to the output channels

    def forward(self, x):
        """Run the network.

        Args:
            x: Input whose last two dimensions are spatial (height, width) -
               presumably shaped (batch, in_channels, height, width); confirm
               against the caller.

        Returns:
            Tensor with ``out_channels`` channels and the same height and
            width as ``x``.

        Inputs whose height or width is not a multiple of 16 are zero-padded
        on the bottom/right before encoding and the output is cropped back.
        Without this the upsampled maps and the skip connections would differ
        in size and ``torch.cat`` would fail. Inputs that are already aligned
        take exactly the original code path.
        """
        height, width = x.shape[-2:]
        pad_h = (-height) % self._DOWNSAMPLE_FACTOR
        pad_w = (-width) % self._DOWNSAMPLE_FACTOR
        padded = bool(pad_h or pad_w)
        if padded:
            # F.pad order is (left, right, top, bottom) for the last two dims.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))

        # encoding
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool1(e1))
        e3 = self.encoder3(self.pool2(e2))
        e4 = self.encoder4(self.pool3(e3))
        center = self.center(self.pool4(e4))

        # decoding (skip connections are concatenated on the channel dimension)
        d4 = self.dec4(torch.cat([self.up4(center), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        out = self.final(d1)

        if padded:
            out = out[..., :height, :width]
        return out

In [None]:
# training function to be called
def train(model, dataloader, criterion, optimizer, num_epochs, checkpoint_path="checkpoints",
          checkpoint_interval=5):
    """Train ``model`` for ``num_epochs`` epochs and save periodic checkpoints.

    Args:
        model: The ``nn.Module`` to optimize; it is switched to training mode.
        dataloader: Iterable of batches. The first two items of every batch
            are used as ``(inputs, targets)``; any further items (for example
            a sampling rate) are ignored.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.
        checkpoint_path: Directory for checkpoints; created if missing.
        checkpoint_interval: Save a checkpoint every this many epochs
            (default 5). ``None`` or a value <= 0 disables checkpointing.

    Returns:
        The same ``model`` instance, after training.

    Each checkpoint is a dict with ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``loss`` (the epoch's average loss), saved as
    ``sep_epoch_<epoch>.pth``.
    """
    os.makedirs(checkpoint_path, exist_ok=True)
    model.train()

    # Send batches to wherever the model's weights actually live, so inputs
    # and weights can never end up on different devices. Only a model without
    # parameters falls back to the notebook-level `device`.
    first_param = next(model.parameters(), None)
    batch_device = first_param.device if first_param is not None else device

    for epoch in range(1, num_epochs + 1):
        running_loss = 0.0
        num_batches = 0

        for batch in dataloader:
            # Index instead of tuple-unpacking so batches carrying extra
            # trailing items do not raise "too many values to unpack".
            spectrograms, labels = batch[0], batch[1]
            spectrograms, labels = spectrograms.to(batch_device), labels.to(batch_device)

            # Forward pass
            outputs = model(spectrograms)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average loss over the epoch; an empty dataloader yields NaN rather
        # than a ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")

        print(f"Epoch {epoch}/{num_epochs} — Average Loss: {epoch_loss:.4f}")

        # Periodic checkpoints
        if checkpoint_interval and checkpoint_interval > 0 and epoch % checkpoint_interval == 0:
            ckpt = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            }
            torch.save(ckpt, os.path.join(checkpoint_path, f"sep_epoch_{epoch}.pth"))

    return model