<a href="https://colab.research.google.com/github/CQDCQD/MLody-PGSS-2025-/blob/main/MLody_Final_Version.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Import libraries

In [None]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import librosa
#audio processing library
from librosa import load
from scipy.signal import butter, filtfilt
import torch.nn as nn
import torch.nn.functional as F
import os
from torch.utils.data import Dataset
from google.colab import auth
from google.colab import drive
import datetime
from torch.nn.utils.rnn import pad_sequence

drive.mount('/content/drive', force_remount=True)


#Z score normalization

In [None]:
def computeZScoreMeanStd(dataset):
    """Compute the global mean and standard deviation of mel and STFT features.

    Iterates the dataset one sample at a time and accumulates the sum and the
    sum of squares of every element, separately for the mel and STFT tensors.

    Args:
        dataset: An indexable dataset whose items are 4-tuples
            ``(mel, stft, maskA, maskB)``; only the first two are used.

    Returns:
        dict: ``{"mel": {"mean", "std"}, "stft": {"mean", "std"}}``.

    Raises:
        ValueError: If the dataset yields no elements.
    """
    melSum = 0.0
    melSqSum = 0.0
    stftSum = 0.0
    stftSqSum = 0.0
    # The two feature types generally have different sizes, so each needs its
    # own element count; sharing the mel count skews the STFT statistics.
    melCount = 0
    stftCount = 0

    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    for mel, stft, _, _ in loader:
        mel = mel[0]
        stft = stft[0]

        # Accumulate in float64 to limit round-off over many samples.
        melSum += mel.double().sum().item()
        melSqSum += (mel.double() ** 2).sum().item()
        stftSum += stft.double().sum().item()
        stftSqSum += (stft.double() ** 2).sum().item()
        melCount += mel.numel()
        stftCount += stft.numel()

    if melCount == 0 or stftCount == 0:
        raise ValueError("Cannot compute z-score statistics from an empty dataset.")

    melMean = melSum / melCount
    stftMean = stftSum / stftCount

    # E[x^2] - E[x]^2 can dip slightly below zero through round-off; clamp so
    # the square root never produces NaN.
    melStd = np.sqrt(max(melSqSum / melCount - melMean ** 2, 0.0))
    stftStd = np.sqrt(max(stftSqSum / stftCount - stftMean ** 2, 0.0))

    return {
        "mel": {"mean": melMean, "std": melStd},
        "stft": {"mean": stftMean, "std": stftStd}
    }


#Globals

In [None]:
# --- Audio input settings ---
sr = 22050  # sample rate in Hz
file_name = "SET FILE NAME"
local = True  # True when running locally, False when running on a server

# --- Model checkpoint locations ---
# The save path is stamped with the time this cell was executed.
modelSavePath = "/content/PATH{}.pth".format(
    datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)
modelLoadPath = "/content/PATH.pth"

# --- Input / output folders ---
loadDrive = "/content/PATH/"
saveDrive = "/content/PATH/"
file_path = f"{loadDrive}{file_name}"
outputAPath = f"{saveDrive}Split A/"
outputBPath = f"{saveDrive}Split B/"

# --- Cached training data folders ---
melPath = "/content/cached_data/mel_folder"
stftPath = "/content/cached_data/stft_folder"
maskPathA = "/content/cached_data/mask_folderA"
maskPathB = "/content/cached_data/mask_folderB"

# --- Precomputed z-score statistics used by preprocessing ---
calculatedZmelMean = 0.0027987720467882882
calculatedZmelSTD = 0.020739795957941308
calculatedZSTFTMean = 0.16094034763575416
calculatedZSTFTSTD = 0.13678209049507903

def getCalculatedZMelMean():
    """Return the precomputed global mean used to z-scale mel features."""
    return calculatedZmelMean

def getCalculatedZMelSTD():
    """Return the precomputed global standard deviation used to z-scale mel features."""
    return calculatedZmelSTD

def getCalculatedZSTFTMean():
    """Return the precomputed global mean used to z-scale STFT features."""
    return calculatedZSTFTMean

def getCalculatedZSTFTSTD():
    """Return the precomputed global standard deviation used to z-scale STFT features."""
    return calculatedZSTFTSTD

def getSavePath():
    """Return the base folder that outputs are written under (`saveDrive`)."""
    return saveDrive

def getLoadPath():
    """Return the base folder that inputs are read from (`loadDrive`)."""
    return loadDrive

def getOutputAPath():
    """Return the output folder for the first separated voice (`outputAPath`)."""
    return outputAPath

def getOutputBPath():
    """Return the output folder for the second separated voice (`outputBPath`)."""
    return outputBPath

def getSR():
    """Return the global sample rate `sr`."""
    return sr

def getFilePath():
    """Return the full path of the input audio file (`file_path`)."""
    return file_path

def getZScoreNormalization():
    """Return `(mean, std)` obtained from the global `globalZScoreNormalization` object.

    NOTE(review): `globalZScoreNormalization` is not defined in this cell. It is
    only assigned in the Main cell, via a call to `init`, and no definition of
    `init` is visible in this file. Calling this getter before that assignment
    raises NameError - confirm whether it is still needed.
    """
    globalMean = globalZScoreNormalization.mean()
    globalStd = globalZScoreNormalization.std()
    return globalMean, globalStd

def localMode():
    """Return the `local` flag (True when running locally)."""
    return local

def getModelSavePath():
    """Return the timestamped path that trained weights are saved to."""
    return modelSavePath

def getModelLoadPath():
    """Return the path that trained weights are loaded from."""
    return modelLoadPath

#Import Audio

In [None]:
def importAudio(file_path):
    """
    Load an audio file as a mono waveform at the project sample rate.

    Before loading, this authenticates the Colab user and (re)mounts Google
    Drive at /content/drive, so the path may point into Drive.

    Args:
        file_path (str): Path of the audio file to load.

    Returns:
        tuple: ``(waveform, sampleRate)`` exactly as returned by librosa's
        ``load`` when called with ``sr=getSR()`` and ``mono=True``.
    """
    # Colab-specific setup: credentials first, then a forced remount.
    auth.authenticate_user()
    drive.mount('/content/drive', force_remount=True)

    waveform, sampleRate = load(file_path, sr=getSR(), mono=True)
    return waveform, sampleRate


#Preprocessing

In [None]:
def preprocessAudio(audio):
    """Turn a raw waveform into normalized mel and STFT feature maps.

    The steps run in this order: amplitude normalization, high-pass filtering,
    pre-emphasis, framing, STFT / mel spectrogram computation, log compression,
    and global z-score scaling with the precomputed statistics.

    Args:
        audio: The waveform to process.

    Returns:
        tuple: ``(audioMel, audioSTFT)``, the processed mel spectrogram and
        STFT magnitude spectrogram.
    """
    sampleRate = getSR()

    # Bring the amplitude into a consistent [0, 1] range so louder recordings
    # look the same to the model.
    normalized = amplitudeNormalize(audio)

    # Remove low-frequency noise.
    filtered = highPassFilter(normalized, sampleRate)

    # Boost high frequencies.
    emphasized = preEmphasis(filtered)

    # Cut the signal into overlapping, windowed frames.
    frames = framing(emphasized, sampleRate)

    # Frequency-domain view of the frames (fine detail).
    stftFeatures = STFTSpectrogram(frames)

    # Mel-scale view of the same frames (coarser, perceptually motivated).
    melFeatures = MelSpectrogram(frames, sampleRate)

    # Log compression reduces the dynamic range of both representations.
    melFeatures = logCompress(melFeatures)
    stftFeatures = logCompress(stftFeatures)

    # Standardize with the global statistics so samples are comparable.
    melFeatures = globalZScale(
        melFeatures, getCalculatedZMelMean(), getCalculatedZMelSTD()
    )
    stftFeatures = globalZScale(
        stftFeatures, getCalculatedZSTFTMean(), getCalculatedZSTFTSTD()
    )

    return melFeatures, stftFeatures


"""
Functions used above ----------------------------------------------------------
"""


def amplitudeNormalize(audio):
    """Map samples from [-1, 1] onto [0, 1].

    Assumes the input is in the [-1, 1] range that librosa produces when
    loading audio; values outside that range are shifted and halved the same
    way, without clipping.
    """
    shifted = audio + 1
    return shifted / 2

def highPassFilter(audio, sr):
    """Remove low-frequency content with a zero-phase Butterworth high-pass.

    Args:
        audio: The signal to filter.
        sr: Sample rate of ``audio`` in Hz.

    Returns:
        The filtered signal, same length as the input.
    """
    cutoffHz = 75     # frequencies below this are attenuated
    filterOrder = 4   # higher order means a sharper cutoff

    # Digital filter design expects the cutoff as a fraction of the Nyquist
    # frequency (half the sample rate), which keeps the filter correct for
    # any sample rate.
    nyquistHz = 0.5 * sr
    numerator, denominator = butter(
        filterOrder, cutoffHz / nyquistHz, btype='high', analog=False
    )

    # filtfilt runs the filter forward and backward, so no phase shift is
    # introduced and the timing of the signal is preserved.
    return filtfilt(numerator, denominator, audio)

def preEmphasis(audio, alpha=0.97):
    """
    Apply pre-emphasis to the audio signal.

    Computes ``y[0] = x[0]`` and ``y[n] = x[n] - alpha * x[n-1]`` for n >= 1,
    which boosts high frequencies relative to low ones.

    Args:
        audio (np.ndarray): The input audio signal (1-D).
        alpha (float): The pre-emphasis coefficient. Defaults to 0.97, the
            value this pipeline has always used.

    Returns:
        np.ndarray: The pre-emphasized audio signal, same length as the input.
        An empty input yields an empty output.
    """
    audio = np.asarray(audio)

    # Indexing audio[0] would fail on an empty signal; nothing to emphasize.
    if audio.size == 0:
        return audio.copy()

    return np.append(audio[0], audio[1:] - alpha * audio[:-1])

def framing(signal, sampleRate):
    """Split a 1-D signal into overlapping, Hamming-windowed frames.

    Frames are 25 ms long (converted to samples with ``sampleRate``) and start
    every 256 samples. Details of the framing are printed on every call.

    Args:
        signal (np.ndarray): The 1-D signal to frame.
        sampleRate: Sample rate of ``signal`` in Hz.

    Returns:
        np.ndarray: Array of shape (numFrames, samplesPerFrame), or an empty
        1-D array when the signal is shorter than one frame.

    Raises:
        ValueError: If the frame length is not positive, or the hop is not in
            the range (0, frame length].
    """
    frameLenMs = 25     # length of each frame in milliseconds
    hopSamples = 256    # distance between consecutive frame starts
    samplesPerFrame = int(sampleRate * frameLenMs / 1000)

    overlapSamples = samplesPerFrame - hopSamples
    overlapMs = (overlapSamples / sampleRate) * 1000
    print("\n--- Framing Details ---")
    print(f"Desired frame length: {frameLenMs} ms ({samplesPerFrame} samples)")
    print(f"Desired hop size: {hopSamples} samples")
    print(f"Resulting overlap: {overlapSamples} samples ({overlapMs:.2f} ms)")

    # Parameter validation.
    if samplesPerFrame <= 0:
        raise ValueError(f"Calculated frameLengthSamples is {samplesPerFrame}. Ensure sampleRate and frameLen are positive.")
    if not (0 < hopSamples <= samplesPerFrame):
        raise ValueError(f"Calculated frameStepSamples is {hopSamples}. It must be positive and less than or equal to frameLengthSamples ({samplesPerFrame}).")

    signalLength = len(signal)
    if signalLength < samplesPerFrame:
        print(f"Warning: Signal length ({signalLength}) is less than frameLengthSamples ({samplesPerFrame}). No frames can be formed.")
        return np.array([])

    # At least one frame fits from here on, so the count is always >= 1.
    frameCount = 1 + (signalLength - samplesPerFrame) // hopSamples

    # Row i holds samples [i * hop, i * hop + samplesPerFrame).
    frameStarts = hopSamples * np.arange(frameCount)
    sampleIndex = frameStarts[:, None] + np.arange(samplesPerFrame)
    frames = signal[sampleIndex]

    # Hamming window tapers each frame's edges for smoothness.
    return frames * np.hamming(samplesPerFrame)

def STFTSpectrogram(frames_2d_array):
    """
    Compute the magnitude spectrum of every frame of a pre-framed signal.

    Each row is zero-padded to the next power of two at or above the frame
    length and transformed with a real-input FFT.

    Args:
        frames_2d_array (np.ndarray): Windowed frames, one per row, with shape
            (num_frames, frame_length_samples).

    Returns:
        np.ndarray: Real-valued magnitudes with shape (n_fft_bins, num_frames).
        An empty input yields an empty array of shape (1, 0).
    """
    if frames_2d_array.size == 0:
        print("Warning: Input frames_2d_array is empty for STFTSpectrogram.")
        return np.array([[]])

    samplesPerFrame = frames_2d_array.shape[1]

    # Smallest power of two that is >= the frame length.
    fftSize = 1 << (samplesPerFrame - 1).bit_length()

    # rfft keeps only the non-negative frequency bins of a real signal.
    spectrum = np.fft.rfft(frames_2d_array, n=fftSize, axis=1)
    magnitude = np.abs(spectrum)

    # Transpose to (frequency_bins, time_frames).
    return magnitude.T



def MelSpectrogram(frames_2d_array, sr):
    """
    Compute a mel power spectrogram from pre-framed, windowed audio.

    The frames are transformed with a real-input FFT (zero-padded to the next
    power of two), squared to get power, and projected onto a 128-band mel
    filter bank spanning 60 Hz to 7600 Hz.

    Args:
        frames_2d_array (np.ndarray): Windowed frames, one per row, with shape
            (num_frames, frame_length_samples).
        sr (int): The sample rate.

    Returns:
        np.ndarray: Mel power spectrogram with shape (n_mels, num_frames).
        An empty input yields an empty array of shape (1, 0).
    """
    if frames_2d_array.size == 0:
        print("Warning: Input frames_2d_array is empty for MelSpectrogram.")
        return np.array([[]])

    # Mel filter bank settings.
    melBandCount = 128
    lowestHz = 60
    highestHz = 7600

    # FFT size: smallest power of two that is >= the frame length.
    samplesPerFrame = frames_2d_array.shape[1]
    fftSize = int(2 ** np.ceil(np.log2(samplesPerFrame)))

    # Power spectrum per frame, shape (num_frames, n_fft_bins).
    spectrum = np.fft.rfft(frames_2d_array, n=fftSize, axis=1)
    power = np.abs(spectrum) ** 2

    # The filter bank must be built for the same FFT size as the spectrum;
    # its shape is (n_mels, n_fft_bins).
    filterBank = librosa.filters.mel(
        sr=sr,
        n_fft=fftSize,
        n_mels=melBandCount,
        fmin=lowestHz,
        fmax=highestHz,
        htk=False
    )

    # (n_mels, n_fft_bins) x (n_fft_bins, num_frames) -> (n_mels, num_frames)
    return np.dot(filterBank, power.T)


def logCompress(spec):
    """
    Apply logarithmic compression, log(1 + x), to a spectrogram.

    Using log1p rather than a plain log keeps zero-valued bins defined
    (log1p(0) == 0) and is numerically stable for small values.

    Args:
        spec (np.ndarray): The spectrogram to compress.

    Returns:
        np.ndarray: The log-compressed spectrogram, same shape as the input.
    """
    compressed = np.log1p(spec)
    return compressed

def globalZScale(audio, mean, std):
    """Standardize features with a global mean and standard deviation.

    Subtracts ``mean`` and divides by ``std`` so that features share a common
    scale across samples. The statistics are supplied by the caller.

    Args:
        audio: The feature array to standardize.
        mean: Global mean to remove.
        std: Global standard deviation to divide by.

    Returns:
        The standardized features.
    """
    centered = audio - mean
    return centered / std

# Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os, random
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

# U-Net block
class UNetBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    changes, from ``inChannels`` to ``outChannels``.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        # Layers are kept in a Sequential named ``block`` so parameter names
        # stay ``block.0.*`` / ``block.2.*`` for saved checkpoints.
        layers = [
            nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run the input through both conv + ReLU stages."""
        features = self.block(x)
        return features

# Deep U-Net
class DeepUNet(nn.Module):
    """Small U-Net with two down-sampling and two up-sampling stages.

    Encoder features are concatenated onto the decoder path as skip
    connections. The final 1x1 convolution keeps ``baseChannels`` output
    channels.
    """

    def __init__(self, inputChannels=1, baseChannels=4):
        super().__init__()
        c1, c2, c4 = baseChannels, baseChannels * 2, baseChannels * 4

        # Encoder
        self.enc1 = UNetBlock(inputChannels, c1)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = UNetBlock(c1, c2)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = UNetBlock(c2, c4)

        # Decoder: each decoder block sees upsampled + skip channels.
        self.up2 = nn.ConvTranspose2d(c4, c2, 2, stride=2)
        self.dec2 = UNetBlock(c4, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = UNetBlock(c2, c1)

        self.final = nn.Conv2d(c1, c1, kernel_size=1)

    def forward(self, x):
        """Encode, decode with skip connections, and project with a 1x1 conv."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        deepest = self.bottleneck(self.pool2(skip2))

        # Pooling floors odd sizes, so the upsampled map can be smaller than
        # the skip tensor; the skip is resized to match before concatenation.
        up = self.up2(deepest)
        resizedSkip2 = F.interpolate(
            skip2, size=up.shape[2:], mode='bilinear', align_corners=False
        )
        up = self.dec2(torch.cat([up, resizedSkip2], dim=1))

        up = self.up1(up)
        resizedSkip1 = F.interpolate(
            skip1, size=up.shape[2:], mode='bilinear', align_corners=False
        )
        up = self.dec1(torch.cat([up, resizedSkip1], dim=1))

        return self.final(up)

class DualUNetBLSTM(nn.Module):
    """Two U-Nets (mel and STFT) fused by a bidirectional LSTM.

    Each input goes through its own DeepUNet. The STFT features are resized to
    the mel feature size, both are concatenated along channels, and every time
    step is flattened into one vector for the LSTM. A 1x1 convolution then maps
    each LSTM output to two masks with ``stftFreqBins`` values each.
    """

    def __init__(self, inputChannels=1, baseChannels=4, lstmHidden=64, melFreqBins=128, stftFreqBins=513):
        super().__init__()
        self.baseChannels = baseChannels
        self.lstmHidden = lstmHidden
        self.melFreqBins = melFreqBins
        self.stftFreqBins = stftFreqBins

        self.melUNet = DeepUNet(inputChannels, baseChannels)
        self.stftUNet = DeepUNet(inputChannels, baseChannels)

        # Per time step the LSTM sees both branches' channels for every
        # mel-sized frequency bin.
        lstmInputSize = 2 * baseChannels * melFreqBins
        self.lstm = nn.LSTM(
            input_size=lstmInputSize,
            hidden_size=lstmHidden,
            batch_first=True,
            bidirectional=True
        )
        # The bidirectional output has 2 * lstmHidden features.
        self.outputConv = nn.Conv2d(lstmHidden * 2, 2 * stftFreqBins, kernel_size=1)

    def forward(self, melInput, stftInput):
        """Predict two sigmoid masks of shape [B, 2, stftFreqBins, T]."""
        melFeat = self.melUNet(melInput)
        stftFeat = self.stftUNet(stftInput)

        # Bring the STFT branch onto the mel branch's (freq, time) grid.
        targetSize = melFeat.shape[2:]
        if stftFeat.shape[2:] != targetSize:
            stftFeat = F.interpolate(stftFeat, size=targetSize, mode='bilinear', align_corners=False)

        fused = torch.cat([melFeat, stftFeat], dim=1)
        batchSize, _, _, numFrames = fused.shape

        # [B, C, F, T] -> [B, T, C*F]: one feature vector per time step.
        sequence = fused.permute(0, 3, 1, 2).reshape(batchSize, numFrames, -1)

        lstmOut, _ = self.lstm(sequence)
        # [B, T, hidden*2] -> [B, hidden*2, 1, T] so a 1x1 conv can act per step.
        lstmOut = lstmOut.permute(0, 2, 1).unsqueeze(2)

        logits = self.outputConv(lstmOut)  # [B, 2*stftFreqBins, 1, T]
        logits = logits.view(batchSize, 2, self.stftFreqBins, numFrames)

        return torch.sigmoid(logits)


def predictMasks(melInput, stftInput):
    """Load the saved model and predict the two voice masks for one input.

    NOTE: a function with this same name is defined again further down in this
    file; when both cells are run, the later definition is the one in effect.

    Args:
        melInput: Mel features as a numpy array or torch tensor. A 2-D input
            is treated as a single sample, a 3-D input as a batch.
        stftInput: STFT features, same conventions as ``melInput``.

    Returns:
        tuple: ``(maskA, maskB)`` CPU tensors for the first item in the batch.
    """
    # Fall back to the CPU when no GPU is present instead of crashing,
    # matching the device selection used for training.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _toModelInput(features):
        # numpy -> float32 tensor, then add the missing batch/channel axes.
        if isinstance(features, np.ndarray):
            features = torch.tensor(features, dtype=torch.float32)
        if features.ndim == 2:
            features = features.unsqueeze(0).unsqueeze(0)
        elif features.ndim == 3:
            features = features.unsqueeze(1)
        return features.to(device)

    melInput = _toModelInput(melInput)
    stftInput = _toModelInput(stftInput)

    model = DualUNetBLSTM()
    model.load_state_dict(torch.load(getModelLoadPath(), map_location=device))
    model.to(device).eval()

    with torch.no_grad():
        output = model(melInput, stftInput)

    return output[0, 0].cpu(), output[0, 1].cpu()


#Output

In [None]:
def reconstructSingleVoice(originalAudioPath, predictedMask, sampleRate=getSR()):
    """Rebuild one voice's waveform from the mixture and a predicted mask.

    The mixture is reloaded and taken through the same normalization,
    high-pass filter and pre-emphasis as preprocessing. Its log-magnitude STFT
    is multiplied by the squared mask, the result is converted back to linear
    magnitude, and Griffin-Lim estimates a waveform from it. Pre-emphasis is
    then undone once and the peak is scaled down if it exceeds 1.0.

    Args:
        originalAudioPath: Path of the mixture audio file.
        predictedMask: 2-D mask (frequency x time) as a numpy array or torch
            tensor. It is zero-padded or trimmed to the STFT's shape.
        sampleRate: Sample rate to load the audio at. The default is read
            from ``getSR()`` once, when this function is defined.

    Returns:
        np.ndarray: The reconstructed waveform.
    """
    from scipy.signal import lfilter  # local import: not in the file's import cell

    # The mask may arrive as a torch tensor; convert it so all the math below
    # is plain numpy.
    if hasattr(predictedMask, "detach"):
        predictedMask = predictedMask.detach().cpu().numpy()
    predictedMask = np.asarray(predictedMask, dtype=np.float32) ** 2

    # --- 1. Load audio ---
    signal, sr = librosa.load(originalAudioPath, sr=sampleRate)

    # --- 2. Normalize amplitude to [0, 1] ---
    signal = (signal + 1) / 2

    # --- 3. High-pass filter (75 Hz) ---
    nyquist = 0.5 * sr
    cutoff = 75
    b, a = butter(4, cutoff / nyquist, btype='high')
    signal = filtfilt(b, a, signal)

    # --- 4. Pre-emphasis ---
    alpha = 0.97
    emphasized = np.append(signal[0], signal[1:] - alpha * signal[:-1])

    # --- 5. STFT ---
    # NOTE(review): preprocessing frames the signal into 25 ms windows that are
    # zero-padded to the FFT size, whereas this uses a full 1024-sample window;
    # confirm the two spectrogram layouts are meant to differ.
    stft = librosa.stft(emphasized, n_fft=1024, hop_length=256, win_length=1024, window='hamming')
    stftMagLog = np.log1p(np.abs(stft))

    # --- 6. Pad/trim the mask to the STFT's (frequency, time) shape ---
    maskResized = librosa.util.fix_length(predictedMask, size=stftMagLog.shape[1], axis=1)
    maskResized = librosa.util.fix_length(maskResized, size=stftMagLog.shape[0], axis=0)

    # --- 7. Apply mask in the log domain, then return to linear magnitude ---
    maskedMag = np.expm1(maskResized * stftMagLog)

    # --- 8. Griffin-Lim estimates the phase and reconstructs the waveform ---
    reconstructed = librosa.griffinlim(
        maskedMag,
        n_iter=64,
        hop_length=256,
        win_length=1024,
        window='hamming'
    )

    # --- 9. Undo pre-emphasis (exactly once) ---
    # Inverse of y[n] = x[n] - alpha * x[n-1] is x[n] = y[n] + alpha * x[n-1],
    # i.e. the IIR filter 1 / (1 - alpha * z^-1). Applying it a second time
    # would over-boost the low frequencies.
    reconstructed = lfilter([1.0], [1.0, -alpha], reconstructed)

    # --- 10. Scale down only if the result would clip ---
    max_val = np.max(np.abs(reconstructed)) if reconstructed.size else 0.0
    if max_val > 1.0:
        reconstructed = reconstructed / max_val * 0.95

    return reconstructed


#Training

In [None]:
# === Dataset Class ===
class SpectrogramDataset(Dataset):
    """Dataset of cached (mel, stft, maskA, maskB) spectrogram files.

    The four folders are listed for ``.npy`` files and sorted; the i-th file
    of each folder is treated as belonging to the same sample. Every sample is
    returned with exactly ``max_T`` time frames: shorter samples are
    zero-padded at the end, longer ones are cropped at a random offset.
    """

    MEL_BINS = 128        # expected mel frequency bins
    STFT_BINS = 513       # expected STFT / mask frequency bins
    MIN_REQUIRED_T = 20   # samples with fewer frames than this are skipped

    def __init__(self, melFolder, stftFolder, maskFolderA, maskFolderB, max_T=320):
        self.melFiles = self._listNpy(melFolder)
        self.stftFiles = self._listNpy(stftFolder)
        self.maskFilesA = self._listNpy(maskFolderA)
        self.maskFilesB = self._listNpy(maskFolderB)
        self.max_T = max_T

        # Files are paired purely by sorted position, so differing counts
        # mean some samples cannot be paired.
        counts = (len(self.melFiles), len(self.stftFiles),
                  len(self.maskFilesA), len(self.maskFilesB))
        if len(set(counts)) > 1:
            print(f"⚠️ File counts differ (mel, stft, maskA, maskB) = {counts}")

    @staticmethod
    def _listNpy(folder):
        """Return the sorted full paths of all .npy files in ``folder``."""
        return sorted(os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".npy"))

    @staticmethod
    def _load(path):
        """Load one .npy file as a float32 tensor with a leading channel axis."""
        return torch.tensor(np.load(path), dtype=torch.float32).unsqueeze(0)

    def __len__(self):
        return len(self.melFiles)

    def __getitem__(self, idx):
        """Return ``(mel, stft, maskA, maskB)`` or None if the sample is unusable.

        None is returned (after printing the reason) for wrong frequency
        sizes, too-short samples, or any error while loading.
        """
        try:
            mel = self._load(self.melFiles[idx])       # [1, 128, T]
            stft = self._load(self.stftFiles[idx])     # [1, 513, T]
            maskA = self._load(self.maskFilesA[idx])   # [1, 513, T]
            maskB = self._load(self.maskFilesB[idx])   # [1, 513, T]

            # Frequency dimension checks
            if mel.shape[1] != self.MEL_BINS:
                print(f"⚠️ Mel freq != 128 at {idx}, skipping.")
                return None
            if any(x.shape[1] != self.STFT_BINS for x in [stft, maskA, maskB]):
                print(f"⚠️ STFT/mask freq != 513 at {idx}, skipping.")
                return None

            tensors = [mel, stft, maskA, maskB]
            lengths = [x.shape[2] for x in tensors]

            if min(lengths) < self.MIN_REQUIRED_T:
                print(f"⚠️ Sample {idx} is too short: mel={mel.shape[2]}, stft={stft.shape[2]}")
                return None

            # Align all four tensors to their common length first. Sizing the
            # padding from the mel alone left the others at a different
            # length whenever the cached files disagreed by a few frames.
            T = min(lengths)
            tensors = [x[:, :, :T] for x in tensors]

            if T < self.max_T:
                pad_amt = self.max_T - T
                tensors = [F.pad(x, (0, pad_amt)) for x in tensors]
            elif T > self.max_T:
                # Same random window for all four so they stay aligned.
                start = random.randint(0, T - self.max_T)
                tensors = [x[:, :, start:start + self.max_T] for x in tensors]

            mel, stft, maskA, maskB = tensors
            return mel, stft, maskA, maskB

        except Exception as e:
            # Best effort: a bad file must not abort the whole epoch. The
            # collate function filters out the None entries.
            print(f"⚠️ Skipping file {idx}: {e}")
            return None





# === Collate Function ===
def collate_fn(batch):
    """Stack the valid samples of a batch, ignoring entries that are None.

    Args:
        batch: List of ``(mel, stft, maskA, maskB)`` tuples; entries may be
            None for samples the dataset could not produce.

    Returns:
        A 4-tuple of stacked tensors, or None when no valid sample remains.
    """
    usable = [sample for sample in batch if sample is not None]
    if not usable:
        return None  # the training loop skips such batches

    melItems, stftItems, maskAItems, maskBItems = zip(*usable)
    stacked = (
        torch.stack(melItems),
        torch.stack(stftItems),
        torch.stack(maskAItems),
        torch.stack(maskBItems),
    )
    return stacked


import torch.nn.functional as F

# === Dice Loss ===
def dice_loss(pred, target, epsilon=1e-6):
    """Return 1 minus the mean Dice coefficient over the batch.

    The coefficient is computed per sample by summing over dimensions 1-3.
    ``epsilon`` is added to numerator and denominator, so two all-zero
    tensors score a perfect 1 instead of dividing by zero.
    """
    perSample = (1, 2, 3)
    overlap = torch.sum(pred * target, dim=perSample)
    totalMass = torch.sum(pred, dim=perSample) + torch.sum(target, dim=perSample)
    coefficient = (2 * overlap + epsilon) / (totalMass + epsilon)
    return 1 - coefficient.mean()

# === Full Loss Function ===
def compute_total_loss(predA, predB, maskA, maskB, stft, epoch, criterion):
    """Combine reconstruction, mask, and regularization terms into one loss.

    Terms:
        * reconstruction: ``criterion`` between the predicted and target masks,
          each multiplied by the STFT input;
        * mask: 0.5 * binary cross-entropy + 0.5 * Dice loss;
        * complementarity: mean |predA + predB - 1|;
        * certainty penalty / confidence gain / entropy: push predictions
          away from 0.5.

    Args:
        predA, predB: Predicted masks with values in [0, 1].
        maskA, maskB: Target masks.
        stft: STFT input the masks are applied to.
        epoch: Current epoch; accepted for interface compatibility, unused.
        criterion: Loss used for the reconstruction term.

    Returns:
        torch.Tensor: Scalar float32 loss.
    """
    # F.binary_cross_entropy refuses to run inside a CUDA autocast region, and
    # the 1e-6 offsets below are too small for float16. The whole loss is
    # therefore evaluated in float32 with autocast switched off locally.
    with torch.autocast(device_type="cuda", enabled=False):
        predA, predB = predA.float(), predB.float()
        maskA, maskB = maskA.float(), maskB.float()
        stft = stft.float()

        recon_loss = criterion(predA * stft, maskA * stft) + criterion(predB * stft, maskB * stft)

        bce_loss = F.binary_cross_entropy(predA, maskA) + F.binary_cross_entropy(predB, maskB)
        dice = _dice_term(predA, maskA) + _dice_term(predB, maskB)
        mask_loss = 0.5 * bce_loss + 0.5 * dice

        # The two masks should sum to one in every bin.
        complementarity = torch.mean(torch.abs(predA + predB - 1.0))

        # Gaussian bump centred on 0.5: large when predictions are undecided.
        k = 200
        certainty_penalty = torch.mean(torch.exp(-((predA - 0.5) ** 2) * k)) + \
                            torch.mean(torch.exp(-((predB - 0.5) ** 2) * k))

        # Rewarded (subtracted below): distance of predictions from 0.5.
        confidence_gain = torch.mean(torch.abs(predA - 0.5)) + torch.mean(torch.abs(predB - 0.5))

        entropyA = - (predA * torch.log(predA + 1e-6) + (1 - predA) * torch.log(1 - predA + 1e-6)).mean()
        entropyB = - (predB * torch.log(predB + 1e-6) + (1 - predB) * torch.log(1 - predB + 1e-6)).mean()

        # Weights
        certainty_weight = 0.7
        confidence_weight = 0.25
        entropy_weight = 0.05
        complementarity_weight = 0.05

        total_loss = (
            recon_loss
            + mask_loss
            + complementarity_weight * complementarity
            + certainty_weight * certainty_penalty
            - confidence_weight * confidence_gain
            + entropy_weight * (entropyA + entropyB)
        )
    return total_loss


def _dice_term(pred, target, epsilon=1e-6):
    """Return 1 - mean per-sample Dice coefficient (same formula as dice_loss)."""
    intersection = (pred * target).sum(dim=(1, 2, 3))
    union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
    return 1 - ((2 * intersection + epsilon) / (union + epsilon)).mean()

# === Training Loop ===
def trainModel(model, dataloader, numEpochs, lr, device='cuda', patience=12):
    """Train the model with early stopping and save the best weights.

    Each epoch runs one pass over ``dataloader``. Batches that are None or
    that raise are skipped (the error is printed). The epoch loss is the mean
    over the batches that actually trained. Training stops early once the loss
    has not improved by more than 1e-5 for ``patience`` consecutive epochs.
    The best weights are restored and written to ``getModelSavePath()``.

    Args:
        model: The network to train; called as ``model(mel, stft)``.
        dataloader: Yields ``(mel, stft, maskA, maskB)`` batches or None.
        numEpochs: Maximum number of epochs.
        lr: Learning rate for AdamW.
        device: Device (string or torch.device) to train on.
        patience: Epochs without improvement tolerated before stopping.

    Raises:
        RuntimeError: If no batch in an epoch could be processed, since the
            resulting "loss" and saved model would be meaningless.
    """
    import copy

    device = torch.device(device)
    # Mixed precision only makes sense (and only works) on CUDA.
    useAmp = device.type == "cuda"

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    best_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    epochs_no_improve = 0

    scaler = torch.cuda.amp.GradScaler(enabled=useAmp)

    for epoch in range(numEpochs):
        model.train()
        totalLoss = 0.0
        processedBatches = 0
        skippedBatches = 0
        lastError = None
        lastGood = None  # (predA, maskA, maskB) of the last trained batch, for logging

        for batch in dataloader:
            if batch is None:
                skippedBatches += 1
                continue

            mel, stft, maskA, maskB = batch

            try:
                mel, stft, maskA, maskB = mel.to(device), stft.to(device), maskA.to(device), maskB.to(device)

                with torch.cuda.amp.autocast(enabled=useAmp):
                    output = model(mel, stft)
                    predA = output[:, 0].unsqueeze(1)
                    predB = output[:, 1].unsqueeze(1)

                    loss = compute_total_loss(predA, predB, maskA, maskB, stft, epoch, criterion)

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                totalLoss += loss.detach().item()
                processedBatches += 1
                lastGood = (predA.detach(), maskA, maskB)

            except Exception as e:
                skippedBatches += 1
                lastError = e
                print(f"⚠️ Skipping batch due to error: {e}")
                continue

        if processedBatches == 0:
            # Every batch was empty or failed: an average of 0 would look like
            # a perfect epoch and an untrained model would be saved as "best".
            raise RuntimeError(
                f"Epoch {epoch+1}: no batch could be processed "
                f"({skippedBatches} skipped)."
            ) from lastError

        # Average over batches that trained, not over all batches attempted.
        avgLoss = totalLoss / processedBatches

        # === Logging ===
        logPredA, logMaskA, logMaskB = lastGood
        with torch.no_grad():
            print(f"Epoch {epoch+1}/{numEpochs}")
            print("Loss:", avgLoss)
            if skippedBatches:
                print("Skipped batches:", skippedBatches)
            print("predA mean:", logPredA.mean().item(), "std:", logPredA.std().item())
            print("MSE loss:", criterion(logPredA.float(), logMaskA).item())
            print("maskA vs maskB diff:", torch.mean(torch.abs(logMaskA - logMaskB)).item())

        # === Early stopping ===
        if avgLoss < best_loss - 1e-5:
            best_loss = avgLoss
            best_model_wts = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"⏹️ Early stopping at epoch {epoch+1}")
                break

    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), getModelSavePath())
    print("✅ Best model saved.")


# === Run Training ===
def runTraining():
    """Build the model, dataset and loader, then start training.

    Uses the GPU when one is available, otherwise the CPU. The dataset is
    read from the cached mel / STFT / mask folders defined in the globals.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"🚀 Running on: {device}")

    model = DualUNetBLSTM().to(device)

    dataset = SpectrogramDataset(melPath, stftPath, maskPathA, maskPathB)

    # Loader settings; lower batch_size if memory runs out, and raise
    # num_workers for faster loading.
    loaderOptions = dict(
        batch_size=128,
        shuffle=True,
        collate_fn=collate_fn,  # drops samples the dataset returned as None
        num_workers=4,
        pin_memory=True,
    )
    loader = DataLoader(dataset, **loaderOptions)

    trainModel(model, loader, numEpochs=150, lr=1e-4, device=device)


# === Inference Helper ===
def predictMasks(melInput, stftInput):
    """Load the saved model and predict the two voice masks for one input.

    Args:
        melInput: Mel features as a numpy array or torch tensor. A 2-D input
            is treated as a single sample, a 3-D input as a batch.
        stftInput: STFT features, same conventions as ``melInput``.

    Returns:
        tuple: ``(maskA, maskB)`` CPU tensors for the first item in the batch.
    """
    # Fall back to the CPU when no GPU is present instead of crashing,
    # matching the device selection used for training.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _toModelInput(features):
        # numpy -> float32 tensor, then add the missing batch/channel axes.
        if isinstance(features, np.ndarray):
            features = torch.tensor(features, dtype=torch.float32)
        if features.ndim == 2:
            features = features.unsqueeze(0).unsqueeze(0)
        elif features.ndim == 3:
            features = features.unsqueeze(1)
        return features.to(device)

    melInput = _toModelInput(melInput)
    stftInput = _toModelInput(stftInput)

    # The weights are reloaded from disk on every call.
    model = DualUNetBLSTM()
    model.load_state_dict(torch.load(getModelLoadPath(), map_location=device))
    model.to(device).eval()

    with torch.no_grad():
        output = model(melInput, stftInput)

    return output[0, 0].cpu(), output[0, 1].cpu()


#Main


In [None]:
# Now on to where the code actually does stuff.
import soundfile as sf  # os has no save(); audio is written with soundfile, as in the later cells

# Import the audio from a file (whose name is included in globals).
# We use a function call instead of directly accessing the variable as good
# practice so we can't accidentally change the variable.
# importAudio returns a tuple; only its first element (the samples) goes on
# to preprocessing.
# Please view the file "import_audio.py" to see how this function works (if you want).
audio, _ = importAudio(getFilePath())

# Now that we know how the audio is imported, we can preprocess it to get it
# ready for the model. We need to do this to mimic how humans hear music to
# increase the ability of the model to separate voices like a human would.

# Gets some data used in preprocessing (please ignore for now).
globalZScoreNormalization = init(dataset = None) # Need to give dataset ---------------------------

# This function does several steps of preprocessing. Please view the file
# "preprocessing.py" to see how this function works (if you want).
audioMel, audioSTFT = preprocessAudio(audio)

# Run the neural network model to figure out the masks for each voice used to
# isolate it. Please view model.py to see how this function works.
maskA, maskB = predictMasks(audioMel, audioSTFT)

# Reconstruct the audio for each voice using the masks.
outputA = reconstructSingleVoice(getFilePath(), maskA)
outputB = reconstructSingleVoice(getFilePath(), maskB)

# Write both voices as WAV files. One shared timestamp keeps the pair together,
# and the ".wav" suffix lets soundfile pick the output format.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
sf.write(getOutputAPath() + f"_{timestamp}.wav", outputA, samplerate=getSR())
sf.write(getOutputBPath() + f"_{timestamp}.wav", outputB, samplerate=getSR())



#Blank Model Test Code

TODO list
~~audio import~~


Test audio import

In [None]:
# Smoke test for the import step: later cells read the result as
# audio[0] (audio data) and audio[1] (sample rate).
audio = importAudio(getFilePath())

Check file path is valid

In [None]:
import os
from google.colab import drive

# Remount Drive so the check below sees its current contents.
drive.mount('/content/drive', force_remount=True)

file_path = getFilePath()  # path configured in the globals cell

if os.path.exists(file_path):
    print(f"Success: The file exists at {file_path}")
else:
    print(f"Error: The file was NOT found at {file_path}")
    # Help with debugging: show what is in the directory the file should be in.
    parent_dir = os.path.dirname(file_path)
    print(f"Checking the parent directory: {parent_dir}")
    if not os.path.exists(parent_dir):
        print(f"Error: The parent directory was NOT found at {parent_dir}")
    else:
        print(f"Contents of {parent_dir}:")
        try:
            entries = os.listdir(parent_dir)
            for entry in entries:
                print(entry)
        except Exception as e:
            print(f"Could not list directory contents: {e}")

Save the imported audio to make sure it is correct

In [None]:
import soundfile as sf
import datetime
import os # Keep os for path operations

# Target folder for ad-hoc exports; create it on first use.
save_dir = getSavePath() + "Misc/"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
    print(f"Created directory: {save_dir}")

# Timestamped file name so repeated runs never overwrite each other.
stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
output_filename = os.path.join(save_dir, f"imported_audio_{stamp}.wav")

# 'audio' is assumed to be the tuple from importAudio:
# element 0 is the audio data, element 1 the sample rate.
audio_data, sample_rate = audio[0], audio[1]

# Write the WAV with soundfile and report where it went.
sf.write(output_filename, audio_data, sample_rate)
print(f"Audio saved successfully to: {output_filename}")

Apply preprocessing to make spectrograms

In [None]:
# Assuming 'audio' is a tuple from importAudio (audio_data, sr)
# Pass only the audio data (the first element) to preprocessAudio
preprocessed_mel, preprocessed_stft = preprocessAudio(audio[0])

# Sanity-check the shapes; the plotting cell below treats axis 1 as the
# column (time-frame) axis.
print(f"Shape of preprocessed_mel after preprocessAudio: {preprocessed_mel.shape}")
print(f"Shape of preprocessed_stft after preprocessAudio: {preprocessed_stft.shape}")

In [None]:
import matplotlib.pyplot as plt
import datetime # Ensure datetime is imported if not already
import numpy as np # Ensure numpy is imported

def _save_spectrogram_slice(spec_db, label, sr, hop_length, y_axis, extent, title, filename):
    """Print diagnostics for one dB-scaled slice, plot it and save it as a PNG.

    ``label`` ("Mel" or "STFT") only affects the printed messages and the
    fallback axis label. ``extent`` is used only by the ``plt.imshow``
    fallback that runs when ``specshow`` raises ``ValueError``.
    """
    # Imported explicitly because not every librosa release exposes the
    # display submodule through a bare `import librosa`.
    import librosa.display

    print(f"  {label} interval dB shape: {spec_db.shape}")
    print(f"  {label} interval dB dtype: {spec_db.dtype}")
    print(f"  {label} interval dB min/max: {np.min(spec_db):.2f} / {np.max(spec_db):.2f}")

    plt.figure(figsize=(10, 4))
    try:
        try:
            librosa.display.specshow(spec_db, sr=sr, hop_length=hop_length, x_axis='time', y_axis=y_axis)
        except ValueError as e:
            print(f"  Error plotting {label} spectrogram with specshow: {e}")
            print("  Attempting to plot with plt.imshow...")
            # imshow needs origin='lower' for the usual spectrogram
            # orientation and extent for correct axis labels.
            plt.imshow(spec_db, aspect='auto', origin='lower', extent=extent, cmap='viridis')
            plt.ylabel(f'{label} Bins')
            plt.xlabel('Time (s)')

        plt.colorbar(format='%+2.0f dB')
        plt.title(title)
        plt.savefig(filename)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close()


def print_and_save_spectrums(mel_spectrogram, stft_spectrogram, sr, save_path_prefix=None):
    """Save one Mel PNG and one STFT PNG per roughly one-second slice.

    Files are written to ``getSavePath() + "Spectrograms/<timestamp>/"`` as
    ``<save_path_prefix>_mel_<i>s.png`` and ``<save_path_prefix>_stft_<i>s.png``.

    Args:
        mel_spectrogram: 2-D array; axis 1 is sliced as the column (frame) axis.
        stft_spectrogram: 2-D array sliced with the same column ranges; its
            magnitude (``np.abs``) is plotted.
        sr: Sample rate, used to convert columns to seconds.
        save_path_prefix: File-name prefix. ``None`` (the default) produces
            ``"spectrum_output_<current timestamp>"`` at call time, so each
            call gets a fresh prefix instead of one frozen when the function
            was defined.
    """
    if save_path_prefix is None:
        save_path_prefix = ("spectrum_output_" + str(datetime.datetime.now())).replace(" ", "_")

    # From STFTSpectrogram and MelSpectrogram: hopSize = 256
    hopSize = 256
    time_per_column = hopSize / sr

    # Whole number of columns per slice; at least 1 so the slicing loop
    # below always advances.
    columns_per_second = max(1, int(1 / time_per_column))

    total_columns = mel_spectrogram.shape[1]
    total_duration_seconds = total_columns * time_per_column

    print(f"Total duration of spectrograms: {total_duration_seconds:.2f} seconds")
    print(f"Columns per second: {columns_per_second}")
    print(f"Shape of Mel spectrogram: {mel_spectrogram.shape}")
    print(f"Shape of STFT spectrogram: {stft_spectrogram.shape}")

    # Ensure the save directory exists inside the mounted Google Drive.
    save_dir = os.path.dirname(getSavePath() + "Spectrograms/" + str(datetime.datetime.now()) + "/" + save_path_prefix)

    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print(f"Created directory: {save_dir}")

    # Step through the columns themselves rather than through whole seconds,
    # so the last (possibly shorter) slice is plotted as well. Because
    # columns_per_second is truncated, the "{i}s" labels are approximate.
    for i, start_col in enumerate(range(0, total_columns, columns_per_second)):
        end_col = min(start_col + columns_per_second, total_columns)

        print(f"Processing time interval: {i}s to {i+1}s (columns {start_col} to {end_col})")

        extent_x = [start_col * time_per_column, end_col * time_per_column]

        # --- Mel spectrogram (converted to dB for visualization) ---
        mel_interval = mel_spectrogram[:, start_col:end_col]
        mel_interval_db = librosa.power_to_db(mel_interval, ref=np.max)
        mel_filename = f"{save_dir}/{save_path_prefix}_mel_{i}s.png"
        _save_spectrogram_slice(
            mel_interval_db, 'Mel', sr, hopSize, 'mel',
            extent_x + [0, mel_interval_db.shape[0]],
            f'Mel Spectrogram ({i}s - {i+1}s)', mel_filename,
        )
        print(f"Saved Mel spectrogram for interval {i}s - {i+1}s to {mel_filename}")

        # --- STFT spectrogram (magnitude, converted to dB) ---
        stft_interval = np.abs(stft_spectrogram[:, start_col:end_col])
        if stft_interval.shape[1] == 0:
            # The STFT has fewer columns than the Mel spectrogram; there is
            # nothing to plot for this slice.
            print(f"  No STFT columns in interval {i}s - {i+1}s; skipping STFT plot")
            continue
        stft_interval_db = librosa.amplitude_to_db(stft_interval, ref=np.max)
        stft_filename = f"{save_dir}/{save_path_prefix}_stft_{i}s.png"
        _save_spectrogram_slice(
            stft_interval_db, 'STFT', sr, hopSize, 'log',
            extent_x + [0, stft_interval_db.shape[0]],
            f'STFT Spectrogram ({i}s - {i+1}s)', stft_filename,
        )
        print(f"Saved STFT spectrogram for interval {i}s - {i+1}s to {stft_filename}")

# Example usage (assuming audio and sr are loaded and preprocessAudio is run)
# audio, sr_loaded = importAudio(getFilePath())
# mel_spec, stft_spec = preprocessAudio(audio)
# print_and_save_spectrums(mel_spec, stft_spec, sr_loaded, save_path_prefix='audio_example') # Example save path

Print spectrograms

In [None]:
# Plot the spectrograms from the preprocessing cell and save them as PNGs,
# using the project sample rate from the globals.
print_and_save_spectrums(preprocessed_mel, preprocessed_stft, getSR())

Random model testing

In [None]:


# Predict one mask per voice from the spectrograms made in the preprocessing cell.
maskA, maskB = predictMasks(preprocessed_mel, preprocessed_stft)

Try using output

In [None]:
import soundfile as sf
import datetime
import os # Keep os for path operations

# Apply each predicted mask to the original recording to get one voice per mask.
audioA = reconstructSingleVoice(getFilePath(), maskA)
audioB = reconstructSingleVoice(getFilePath(), maskB)

def saveAudioFile(file, path):
    """Write ``file`` as a WAV inside a directory named ``path`` + current timestamp.

    The directory is created if missing. The file itself is named
    ``exported_audio_<YYYYmmdd_HHMMSS>.wav`` and written at the sample rate
    returned by ``getSR()``.
    """
    target_dir = path + str(datetime.datetime.now())
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
        print(f"Created directory: {target_dir}")

    # Timestamped file name inside the directory.
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    output_filename = os.path.join(target_dir, f"exported_audio_{stamp}.wav")

    # Write the audio data using soundfile.
    sf.write(output_filename, file, getSR())

    print(f"Audio saved successfully to: {output_filename}")


# Export each separated voice under its own output location from the globals.
saveAudioFile(audioA, getOutputAPath())
saveAudioFile(audioB, getOutputBPath())


All blank model running

In [None]:
def saveAudioFile(file, path):
    """Save audio data to a timestamped WAV in a timestamp-suffixed directory.

    ``path`` has the current date/time appended to form the directory name,
    which is created when it does not exist yet. The WAV is written with
    soundfile at the sample rate returned by ``getSR()``.
    """
    out_dir = path + str(datetime.datetime.now())
    already_there = os.path.exists(out_dir)
    if not already_there:
        os.makedirs(out_dir)
        print(f"Created directory: {out_dir}")

    wav_name = f"exported_audio_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
    output_filename = os.path.join(out_dir, wav_name)

    sample_rate = getSR()
    sf.write(output_filename, file, sample_rate)

    print(f"Audio saved successfully to: {output_filename}")

# End-to-end run: import -> preprocess -> predict masks -> reconstruct each voice.
audio = importAudio(getFilePath())
# audio[0] is the audio data; audio[1] (the sample rate) is not needed here.
preprocessed_mel, preprocessed_stft = preprocessAudio(audio[0])
maskA, maskB = predictMasks(preprocessed_mel, preprocessed_stft)
audioA = reconstructSingleVoice(getFilePath(), maskA)
audioB = reconstructSingleVoice(getFilePath(), maskB)



#Training Working


Import global Z score dataset

In [None]:
# Remount Drive before touching the dataset folders.
drive.mount('/content/drive', force_remount=True)


# Training dataset built from the mel/STFT/mask folder paths in the globals.
dataset = SpectrogramDataset(melPath, stftPath, maskPathA, maskPathB)

In [None]:
# Compute the mel/STFT mean and std over the dataset. The result is not
# stored; it is only shown as the cell output.
computeZScoreMeanStd(dataset)

In [None]:

# Start training with the settings defined in runTraining.
runTraining()


In [None]:
# Check GPU performance: notebook shell escape that runs the nvidia-smi
# command-line tool.

!nvidia-smi


#Run the model

In [None]:
import soundfile as sf


# importAudio returns a tuple; only the first element is needed here.
audio, _ = importAudio(getFilePath())

# Spectrograms -> one predicted mask per voice.
audioMel, audioSTFT = preprocessAudio(audio)
maskA, maskB = predictMasks(audioMel, audioSTFT)

# Rebuild each voice from the original recording and its mask (A first, then B).
outputA, outputB = (
    reconstructSingleVoice(getFilePath(), voiceMask) for voiceMask in (maskA, maskB)
)

# One shared timestamp so both files of a run are easy to pair up.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

for getPrefix, voiceAudio in ((getOutputAPath, outputA), (getOutputBPath, outputB)):
    sf.write(getPrefix() + f"_{timestamp}.wav", voiceAudio, samplerate=getSR())



In [None]:
# Display the currently configured input file path as the cell output.
getFilePath()

Test Reconstruction to make sure training masks are correct

In [None]:
# Inputs: one mixed recording and the stored training mask for voice A.
originalFile = "/content/drive/My Drive/We Love Parth!!/Data Collection/Model Data Storage (DO NOT RENAME)/All Data/combined iwitw angadi melody iwitw barsotti harmony.wav"
testMaskPath = "/content/drive/My Drive/We Love Parth!!/Data Collection/Model Data Storage (DO NOT RENAME)/All Data/mask_folderA/combined iwitw angadi melody iwitw barsotti harmony.npy"

# Load the mask and apply it to the recording. Listening to the result shows
# whether the training mask isolates the intended voice.
mask = np.load(testMaskPath)
reconstructedAudio = reconstructSingleVoice(originalFile, mask, sampleRate=getSR())

# Save under Misc/ with a timestamped name.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
reconstructedPath = "/content/drive/My Drive/We Love Parth!!/Data Collection/Model Data Storage (DO NOT RENAME)/Misc/" + f"_{timestamp}.wav"

sf.write(reconstructedPath, reconstructedAudio, samplerate=getSR())



#cache data

In [None]:
import os
import shutil
from google.colab import auth
from google.colab import drive
import datetime

drive.mount('/content/drive', force_remount=True)

def cache_data_to_local(src_root, dst_root, subfolders):
    """Copy the ``.npy`` files of each subfolder from ``src_root`` to ``dst_root``.

    Destination directories are created as needed. A file that already
    exists at the destination is left untouched, so the function can be
    re-run cheaply. The per-folder summary counts every ``.npy`` file found
    in the source, including those that were skipped.

    Args:
        src_root: Directory containing the source subfolders.
        dst_root: Directory that receives the mirrored subfolders.
        subfolders: Iterable of subfolder names to mirror.
    """
    os.makedirs(dst_root, exist_ok=True)

    for name in subfolders:
        source_dir = os.path.join(src_root, name)
        target_dir = os.path.join(dst_root, name)
        os.makedirs(target_dir, exist_ok=True)

        npy_names = sorted(
            filter(lambda entry: entry.endswith('.npy'), os.listdir(source_dir))
        )
        for npy_name in npy_names:
            target_file = os.path.join(target_dir, npy_name)
            if os.path.exists(target_file):
                continue  # already cached
            shutil.copy2(os.path.join(source_dir, npy_name), target_file)
        print(f"✅ Cached {len(npy_names)} files from {name}")

# Example usage: mirror the four dataset folders from Drive into the local
# directory that the validation-split cell reads from.
google_drive_root = "/content/drive/My Drive/We Love Parth!!/Data Collection/Model Data Storage (DO NOT RENAME)/All Data"
local_cache_root = "/content/cached_data"

os.makedirs(local_cache_root, exist_ok=True)

# One folder each for mel inputs, STFT inputs and the two target masks.
subfolders = ["mel_folder3", "stft_folder3", "mask_folderA3", "mask_folderB3"]
cache_data_to_local(google_drive_root, local_cache_root, subfolders)


Create validation split

In [None]:
import os
import random
import shutil

# === Settings ===
x = 100  # Number of validation sets to move

base_path = "/content/cached_data"

# Training folders, in the order: mel, stft, mask A, mask B.
mel_folder, stft_folder, maskA_folder, maskB_folder = (
    os.path.join(base_path, name)
    for name in ("mel_folder3", "stft_folder3", "mask_folderA3", "mask_folderB3")
)
train_folders = (mel_folder, stft_folder, maskA_folder, maskB_folder)

# Each validation folder sits next to its training folder with a "_val" suffix.
val_folders = tuple(train_dir + "_val" for train_dir in train_folders)
val_mel, val_stft, val_maskA, val_maskB = val_folders

for folder in val_folders:
    os.makedirs(folder, exist_ok=True)

# === Pick random subset ===
# The mel folder decides which file names exist; at most x of them are drawn.
mel_files = sorted(f for f in os.listdir(mel_folder) if f.endswith(".npy"))
val_files = random.sample(mel_files, min(x, len(mel_files)))

# === Move matching sets ===
# The same file name is moved out of all four folders so each set stays complete.
for fname in val_files:
    for train_dir, val_dir in zip(train_folders, val_folders):
        shutil.move(os.path.join(train_dir, fname), os.path.join(val_dir, fname))

print(f"✅ Moved {len(val_files)} sets to validation folders.")


#Test Masks

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from google.colab import auth
from google.colab import drive
import datetime

drive.mount('/content/drive', force_remount=True)

# ==== MODIFY THESE PATHS ====
mel_path   = "/content/drive/My Drive/We Love Parth!!/Data Collection/Model Data Storage (DO NOT RENAME)/All Data/mel_folder3/combined iwitw albright melody iwitw albright harmony.npy"
stft_path  = "/content/drive/My Drive/We Love Parth!!/Data Collection/Model Data Storage (DO NOT RENAME)/All Data/stft_folder3/combined iwitw albright melody iwitw albright harmony.npy"
maskA_path = "/content/drive/My Drive/We Love Parth!!/Data Collection/Model Data Storage (DO NOT RENAME)/All Data/mask_folderA3/combined iwitw albright melody iwitw albright harmony.npy"
maskB_path = "/content/drive/My Drive/We Love Parth!!/Data Collection/Model Data Storage (DO NOT RENAME)/All Data/mask_folderB3/combined iwitw albright melody iwitw albright harmony.npy"

model_path = getModelSavePath()  # or replace with .pt path if saved manually
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==== LOAD DATA ====
mel, stft, maskA, maskB = (
    np.load(npy_path) for npy_path in (mel_path, stft_path, maskA_path, maskB_path)
)

# Truncate all four arrays to the shortest length along axis 1 so they line up.
min_T = min(arr.shape[1] for arr in (mel, stft, maskA, maskB))
mel, stft, maskA, maskB = (arr[:, :min_T] for arr in (mel, stft, maskA, maskB))

# Convert to torch tensors with two leading singleton axes.
mel_tensor = torch.tensor(mel)[None, None].float().to(device)  # (1, 1, 128, T)
stft_tensor = torch.tensor(stft)[None, None].float().to(device)  # (1, 1, 513, T)

# ==== LOAD MODEL ====
model = DualUNetBLSTM()
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device).eval()

# ==== INFERENCE ====
with torch.no_grad():
    pred = model(mel_tensor, stft_tensor)
# Channel 0 is the mask for voice A, channel 1 the mask for voice B.
predA, predB = (pred[0, channel].cpu().numpy() for channel in (0, 1))

# ==== PLOT ====
# Ground truth on the left, prediction on the right; voice A on top.
fig, axs = plt.subplots(2, 2, figsize=(14, 8))

panels = (
    (axs[0, 0], maskA, "Ground Truth Mask A"),
    (axs[0, 1], predA, "Predicted Mask A"),
    (axs[1, 0], maskB, "Ground Truth Mask B"),
    (axs[1, 1], predB, "Predicted Mask B"),
)
for ax, image, panel_title in panels:
    ax.imshow(image, aspect='auto', origin='lower')
    ax.set_title(panel_title)

plt.tight_layout()
plt.show()


#bottom

#output testing

In [None]:
# Notebook shell escape: install mir_eval, which the evaluation cell below
# imports for bss_eval_sources.
!pip install mir_eval


In [None]:
import numpy as np
from mir_eval.separation import bss_eval_sources

def compute_mask_accuracy(predA, predB, maskA, maskB, threshold=0.5):
    """Return the fraction of elements where binarised predictions match the masks.

    Every array is binarised with ``> threshold``. Matches for voice A and
    voice B are pooled and divided by the combined element count of
    ``maskA`` and ``maskB``.

    Args:
        predA, predB: Predicted masks as NumPy arrays.
        maskA, maskB: Ground-truth masks as NumPy arrays.
        threshold: Cut-off used to binarise all four arrays.
    """
    def _binarize(values):
        return (values > threshold).astype(np.float32)

    matches = sum(
        np.sum(_binarize(pred) == _binarize(mask))
        for pred, mask in ((predA, maskA), (predB, maskB))
    )
    return matches / (maskA.size + maskB.size)


def evaluate_on_loader(val_loader):
    """Score predicted masks against ground truth for every sample in a loader.

    For each sample the masks from ``predictMasks`` are compared with the
    ground-truth masks in two ways: SDR/SIR/SAR from
    ``mir_eval.separation.bss_eval_sources`` on the masked spectrograms, and
    thresholded mask accuracy from ``compute_mask_accuracy``. A sample that
    raises is reported and skipped.

    Args:
        val_loader: Iterable yielding either ``None`` (skipped) or a
            ``(mel, stft, maskA, maskB)`` batch of tensors whose first axis
            indexes the samples.

    Returns:
        ``(mean SDR, mean SIR, mean SAR, mean mask accuracy)``. Every value
        is NaN when no sample could be evaluated.
    """
    sdr_list, sir_list, sar_list = [], [], []
    acc_list = []

    def _mean_or_nan(values):
        # np.mean([]) emits a RuntimeWarning; return NaN explicitly instead.
        return np.mean(values) if values else float("nan")

    for batch in val_loader:
        if batch is None:
            continue  # e.g. safe_collate found no usable sample
        mel, stft, maskA, maskB = batch

        for i in range(mel.size(0)):
            mel_i = mel[i].squeeze().cpu().numpy()
            stft_i = stft[i].squeeze().cpu().numpy()
            maskA_i = maskA[i].squeeze().cpu().numpy()
            maskB_i = maskB[i].squeeze().cpu().numpy()

            try:
                # NOTE: predictMasks reloads the model weights on every call.
                predA, predB = predictMasks(mel_i, stft_i)

                # predictMasks returns CPU tensors; convert once so all the
                # arithmetic below is plain NumPy.
                predA = predA.numpy()
                predB = predB.numpy()

                # Apply the masks to the mixture spectrogram.
                estA = predA * stft_i
                estB = predB * stft_i
                trueA = maskA_i * stft_i
                trueB = maskB_i * stft_i

                # NOTE(review): bss_eval_sources is documented for time-domain
                # signals; here it receives flattened masked spectrograms.
                # Confirm this is the intended metric.
                ref = np.stack([trueA.flatten(), trueB.flatten()])
                est = np.stack([estA.flatten(), estB.flatten()])

                sdr, sir, sar, _ = bss_eval_sources(ref, est)
                acc = compute_mask_accuracy(predA, predB, maskA_i, maskB_i)

            except Exception as e:
                # Deliberate best-effort: one bad sample must not abort the run.
                print(f"⚠️ Skipped sample due to: {e}")
                continue

            # Append only after every metric succeeded, so the four lists
            # always describe the same set of samples.
            sdr_list.append(np.mean(sdr))
            sir_list.append(np.mean(sir))
            sar_list.append(np.mean(sar))
            acc_list.append(acc)

    mean_sdr = _mean_or_nan(sdr_list)
    mean_sir = _mean_or_nan(sir_list)
    mean_sar = _mean_or_nan(sar_list)
    mean_acc = _mean_or_nan(acc_list)

    print("\n✅ Evaluation complete.")
    print(f"SDR: {mean_sdr:.4f}")
    print(f"SIR: {mean_sir:.4f}")
    print(f"SAR: {mean_sar:.4f}")
    print(f"Mask Accuracy: {mean_acc:.4f}")

    return mean_sdr, mean_sir, mean_sar, mean_acc

from torch.utils.data import DataLoader, random_split

# Validation dataset: reads the "_val" folders under /content/cached_data,
# the folders that the validation-split cell fills.
val_dataset = SpectrogramDataset(
    melFolder='/content/cached_data/mel_folder3_val',
    stftFolder='/content/cached_data/stft_folder3_val',
    maskFolderA='/content/cached_data/mask_folderA3_val',
    maskFolderB='/content/cached_data/mask_folderB3_val',
    max_T=1024  # or whatever your default is
)

def safe_collate(batch):
    """Collate a batch after dropping ``None`` samples.

    Returns ``None`` when no sample remains, so the consumer can skip the
    batch. Otherwise returns the result of PyTorch's default collate on the
    remaining samples.
    """
    usable = list(filter(lambda sample: sample is not None, batch))
    return torch.utils.data.default_collate(usable) if usable else None


# One sample per batch, in a fixed order. safe_collate turns a batch with no
# usable sample into None, which evaluate_on_loader skips.
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=safe_collate
)


# Print and return the mean SDR / SIR / SAR and mask accuracy.
evaluate_on_loader(val_loader)
