# Annealing


In [None]:
import math
from bisect import bisect_right,bisect_left

import torch
import numpy as np
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

class CyclicCosAnnealingLR(_LRScheduler):
    r"""
    Cosine annealing that restarts at every milestone.

    ``milestones`` cuts the epoch axis into consecutive cycles
    ``[0, m_0), [m_0, m_1), ...``. Inside a cycle the learning rate of each
    parameter group follows the CosineAnnealingLR curve, where
    :math:`\eta_{max}` is the initial lr, :math:`T_{cur}` is the number of
    epochs since the cycle started and :math:`T_{max}` is the cycle length:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    At each milestone the lr jumps back to :math:`\eta_{max}` (a warm restart).
    When last_epoch >= last set milestone, lr stays at \eta_{min}.

    The schedule is inspired by
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list of ints): List of epoch indices. Must be increasing
            and non-empty.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: If ``milestones`` is empty or not sorted in increasing
            order.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer,milestones, eta_min=0, last_epoch=-1):
        if len(milestones) == 0:
            # get_lr() indexes milestones[-1]; fail early with a clear message.
            raise ValueError('Milestones should be a non-empty list of'
                             ' increasing integers. Got {}'.format(milestones))
        if not list(milestones) == sorted(milestones):
            # Format the offending value into the message instead of passing
            # it as a second exception argument.
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.eta_min = eta_min
        self.milestones=milestones
        super(CyclicCosAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for ``last_epoch``."""
        # Past the final milestone there is no cycle left: hold the floor value.
        if self.last_epoch >= self.milestones[-1]:
            return [self.eta_min for base_lr in self.base_lrs]

        # Index of the first milestone strictly greater than last_epoch, i.e.
        # the milestone that closes the current cycle.
        idx = bisect_right(self.milestones,self.last_epoch)

        # The first cycle starts at epoch 0; later ones start at the previous milestone.
        left_barrier = 0 if idx==0 else self.milestones[idx-1]
        right_barrier = self.milestones[idx]

        width = right_barrier - left_barrier
        curr_pos = self.last_epoch- left_barrier

        return [self.eta_min + (base_lr - self.eta_min) *
               (1 + math.cos(math.pi * curr_pos/ width)) / 2
                for base_lr in self.base_lrs]


class CyclicLinearLR(_LRScheduler):
    r"""
    Linear learning-rate decay that restarts at every milestone.

    ``milestones`` cuts the epoch axis into consecutive cycles
    ``[0, m_0), [m_0, m_1), ...``. Inside a cycle the learning rate of each
    parameter group decays linearly, where :math:`\eta_{max}` is the initial
    lr, :math:`T_{cur}` is the number of epochs since the cycle started and
    :math:`T_{max}` is the cycle length:

    .. math::

        \eta_t = \eta_{min} + (\eta_{max} - \eta_{min})(1 -\frac{T_{cur}}{T_{max}})

    At each milestone the lr jumps back to :math:`\eta_{max}`.
    When last_epoch >= last set milestone, lr stays at \eta_{min}.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list of ints): List of epoch indices. Must be increasing
            and non-empty.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: If ``milestones`` is empty or not sorted in increasing
            order.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer,milestones, eta_min=0, last_epoch=-1):
        if len(milestones) == 0:
            # get_lr() indexes milestones[-1]; fail early with a clear message.
            raise ValueError('Milestones should be a non-empty list of'
                             ' increasing integers. Got {}'.format(milestones))
        if not list(milestones) == sorted(milestones):
            # Format the offending value into the message instead of passing
            # it as a second exception argument.
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.eta_min = eta_min
        self.milestones=milestones
        super(CyclicLinearLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for ``last_epoch``."""
        # Past the final milestone there is no cycle left: hold the floor value.
        if self.last_epoch >= self.milestones[-1]:
            return [self.eta_min for base_lr in self.base_lrs]

        # Index of the first milestone strictly greater than last_epoch, i.e.
        # the milestone that closes the current cycle.
        idx = bisect_right(self.milestones,self.last_epoch)

        # The first cycle starts at epoch 0; later ones start at the previous milestone.
        left_barrier = 0 if idx==0 else self.milestones[idx-1]
        right_barrier = self.milestones[idx]

        width = right_barrier - left_barrier
        curr_pos = self.last_epoch- left_barrier

        return [self.eta_min + (base_lr - self.eta_min) *
               (1. - 1.0*curr_pos/ width)
                for base_lr in self.base_lrs]

'''
#################################
# TEST FOR SCHEDULER
#################################
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim

net = nn.Sequential(nn.Linear(2,2))
milestones = [(2**x)*300 for x in range(30)]
optimizer = optim.SGD(net.parameters(),lr=1e-3,momentum=0.9,weight_decay=0.0005,nesterov=True)
scheduler = CyclicCosAnnealingLR(optimizer,milestones=milestones,eta_min=1e-6)

lr_log = []

for i in range(20*300):
    optimizer.step()
    scheduler.step()
    for param_group in optimizer.param_groups:
        lr_log.append(param_group['lr'])

plt.plot(lr_log)
plt.show()
'''


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SepConvNet(nn.Module):
    """Two-branch convolutional encoder with a dense bottleneck and four decoders.

    The input goes through two parallel convolution stacks that apply the
    ``(f1, t1)`` and ``(f2, t2)`` kernels in opposite order; their outputs are
    summed, flattened, squeezed through ``fc0`` and expanded again by four
    independent ``fc -> hdeconv -> vdeconv`` heads.

    NOTE(review): a second class with the same name is defined later in this
    cell and replaces this one once the cell has run.

    Args:
        t1, f1: Second and first kernel dimension of the ``vconv`` layers.
        t2, f2: Second and first kernel dimension of the ``hconv`` layers.
        N1, N2: Channel counts after the first and second convolution.
        input_shape: Size of the last two input dimensions; used only to size
            the dense layers.
        NN: Width of the dense bottleneck.
    """

    def __init__(self, t1, f1, t2, f2, N1, N2, input_shape=[513, 345], NN=128):
        super().__init__()
        v_kernel = (f1, t1)
        h_kernel = (f2, t2)

        # Layers are created in a fixed order so parameter registration (and
        # therefore seeded initialisation) is stable.
        self.vconv_left = nn.Conv2d(1, N1, kernel_size=v_kernel, padding=0)
        self.hconv_left = nn.Conv2d(N1, N2, kernel_size=h_kernel)
        self.hconv_right = nn.Conv2d(1, N1, kernel_size=h_kernel)
        self.vconv_right = nn.Conv2d(N1, N2, kernel_size=v_kernel, padding=0)

        # Number of activations left after both un-padded convolutions.
        flat_features = N2 * (input_shape[0] - f1 - f2 + 2) * (input_shape[1] - t1 - t2 + 2)

        self.fc0 = nn.Linear(flat_features, NN)
        for head in range(1, 5):
            setattr(self, f"fc{head}", nn.Linear(NN, flat_features))
        for head in range(1, 5):
            setattr(self, f"hdeconv{head}", nn.ConvTranspose2d(N2, N1, kernel_size=h_kernel))
        for head in range(1, 5):
            setattr(self, f"vdeconv{head}", nn.ConvTranspose2d(N1, 1, kernel_size=v_kernel))

    def forward(self, x):
        """Return a 4-tuple with the output of each decoder head."""
        left = self.hconv_left(self.vconv_left(x))
        right = self.vconv_right(self.hconv_right(x))
        merged = left + right

        # Remember the feature-map shape so each head can undo the flattening.
        batch, channels, height, width = merged.shape
        code = F.relu(self.fc0(merged.view(batch, -1)))

        outputs = []
        for head in range(1, 5):
            expanded = F.relu(getattr(self, f"fc{head}")(code))
            feature_map = expanded.view(batch, channels, height, width)
            feature_map = getattr(self, f"hdeconv{head}")(feature_map)
            outputs.append(getattr(self, f"vdeconv{head}")(feature_map))

        return tuple(outputs)
    
import torch
import torch.nn as nn
import torch.nn.functional as F
class SepConvNet(nn.Module):
    """Convolutional encoder with a dense bottleneck and four decoder heads.

    The input passes through ``vconv`` (kernel ``(f1, t1)``) and ``hconv``
    (kernel ``(f2, t2)``), is flattened, squeezed through ``fc0`` and expanded
    again by four independent ``fc -> hdeconv -> vdeconv`` heads.

    Args:
        t1, f1: Second and first kernel dimension of ``vconv`` / ``vdeconv``.
        t2, f2: Second and first kernel dimension of ``hconv`` / ``hdeconv``.
        N1, N2: Channel counts after the first and second convolution.
        input_shape: Size of the last two input dimensions; used only to size
            the dense layers.
        NN: Width of the dense bottleneck.
    """

    def __init__(self,t1,f1,t2,f2,N1,N2,input_shape=[513,862],NN=128):
        super().__init__()
        v_kernel = (f1, t1)
        h_kernel = (f2, t2)

        # Layers are created in a fixed order so parameter registration (and
        # therefore seeded initialisation) is stable.
        self.vconv = nn.Conv2d(1, N1, kernel_size=v_kernel, padding=0)
        self.hconv = nn.Conv2d(N1, N2, kernel_size=h_kernel)

        # Number of activations left after both un-padded convolutions.
        flat_features = N2 * (input_shape[0] - f1 - f2 + 2) * (input_shape[1] - t1 - t2 + 2)

        self.fc0 = nn.Linear(flat_features, NN)
        for head in range(1, 5):
            setattr(self, f"fc{head}", nn.Linear(NN, flat_features))
        for head in range(1, 5):
            setattr(self, f"hdeconv{head}", nn.ConvTranspose2d(N2, N1, kernel_size=h_kernel))
        for head in range(1, 5):
            setattr(self, f"vdeconv{head}", nn.ConvTranspose2d(N1, 1, kernel_size=v_kernel))

    def forward(self, x):
        """Return a 4-tuple with the output of each decoder head."""
        encoded = self.hconv(self.vconv(x))

        # Remember the feature-map shape so each head can undo the flattening.
        batch, channels, height, width = encoded.shape
        code = F.relu(self.fc0(encoded.view(batch, -1)))

        outputs = []
        for head in range(1, 5):
            expanded = F.relu(getattr(self, f"fc{head}")(code))
            feature_map = expanded.view(batch, channels, height, width)
            feature_map = getattr(self, f"hdeconv{head}")(feature_map)
            outputs.append(getattr(self, f"vdeconv{head}")(feature_map))

        return tuple(outputs)



KeyboardInterrupt: 

# Load

In [None]:
from torch.utils.data.dataset import Dataset
import torch
from torchvision import transforms
#from skimage import io, transform
import os
import numpy as np
import re

class SourceSepTrain(Dataset):
    """Training dataset yielding ``(mixture, bass, vocals, drums, others)``.

    ``path`` is the directory holding the saved mixture tensors (presumably
    magnitude spectrograms of every song segment -- confirm against the
    preprocessing step). The four source folders ``Bass``, ``Vocals``,
    ``Drums`` and ``Others`` are expected next to it and must contain a file
    with the same name for every mixture file.

    Args:
        path: Directory with the mixture files.
        transforms: Optional callable applied to every loaded item.
    """

    # Sibling folders of the mixture directory, in the order items are returned.
    _STEM_DIRS = ('Bass', 'Vocals', 'Drums', 'Others')

    def __init__(self, path='../Processed/Mixtures', transforms=None):
        self.path = path
        # Parent of the mixture folder; normpath drops a trailing separator so
        # dirname() really returns the parent.
        self._root = os.path.dirname(os.path.normpath(path))
        # os.listdir() order is arbitrary; sort so indices are reproducible.
        self.list = sorted(os.listdir(self.path))
        self.transforms = transforms

    def __getitem__(self, index):
        """Load the mixture at *index* together with its four sources."""
        name = self.list[index]
        # Read from the directory given to the constructor rather than a
        # hard-coded location, so a custom ``path`` is honoured.
        mixture = torch.load(os.path.join(self.path, name))
        stems = [torch.load(os.path.join(self._root, folder, name))
                 for folder in self._STEM_DIRS]
        if self.transforms is not None:
            mixture = self.transforms(mixture)
            stems = [self.transforms(stem) for stem in stems]
        return (mixture, *stems)

    def __len__(self):
        """Return the number of mixture files found in ``path``."""
        return len(self.list)


class SourceSepVal(Dataset):
    """Validation dataset yielding ``(mixture, bass, vocals, drums, others)``.

    ``path`` is the directory holding the saved mixture tensors (presumably
    magnitude spectrograms of every song segment -- confirm against the
    preprocessing step). The four source folders ``Bass``, ``Vocals``,
    ``Drums`` and ``Others`` are expected next to it and must contain a file
    with the same name for every mixture file.

    Args:
        path: Directory with the mixture files.
        transforms: Optional callable applied to every loaded item.
    """

    # Sibling folders of the mixture directory, in the order items are returned.
    _STEM_DIRS = ('Bass', 'Vocals', 'Drums', 'Others')

    def __init__(self, path='../Val/Mixtures', transforms=None):
        self.path = path
        # Parent of the mixture folder; normpath drops a trailing separator so
        # dirname() really returns the parent.
        self._root = os.path.dirname(os.path.normpath(path))
        # os.listdir() order is arbitrary; sort so indices are reproducible.
        self.list = sorted(os.listdir(self.path))
        self.transforms = transforms

    def __getitem__(self, index):
        """Load the mixture at *index* together with its four sources."""
        name = self.list[index]
        # Read from the directory given to the constructor rather than a
        # hard-coded location, so a custom ``path`` is honoured.
        mixture = torch.load(os.path.join(self.path, name))
        stems = [torch.load(os.path.join(self._root, folder, name))
                 for folder in self._STEM_DIRS]
        if self.transforms is not None:
            mixture = self.transforms(mixture)
            stems = [self.transforms(stem) for stem in stems]
        return (mixture, *stems)

    def __len__(self):
        """Return the number of mixture files found in ``path``."""
        return len(self.list)

class SourceSepTest(Dataset):
    """Test dataset yielding ``(mixture, phase_file, mixture_file)``.

    ``path`` is the directory holding the saved mixture tensors. For every
    item the name of the matching phase file is derived from the mixture file
    name by replacing ``_m`` with ``_p`` and ``.pt`` with ``.npy``; the phase
    itself is not loaded here.

    Args:
        path: Directory with the mixture files.
        transforms: Optional callable applied to the loaded mixture.
    """

    def __init__(self, path='../Val/Mixtures',transforms=None):
        self.path = path
        # os.listdir() order is arbitrary; sort so indices are reproducible.
        self.list = sorted(os.listdir(self.path))
        self.transforms = transforms

    def __getitem__(self, index):
        """Load the mixture at *index* and return it with the related file names."""
        name = self.list[index]
        phase_file = name.replace('_m', '_p').replace('.pt', '.npy')

        # Read from the directory given to the constructor rather than a
        # hard-coded location. The individual sources are not part of the
        # returned tuple, so they are not loaded.
        mixture = torch.load(os.path.join(self.path, name))
        if self.transforms is not None:
            mixture = self.transforms(mixture)

        return (mixture, phase_file, name)

    def __len__(self):
        """Return the number of mixture files found in ``path``."""
        return len(self.list)


# Preprocess

In [None]:
import librosa
import numpy as np
import torch
import os
import re

# Output sub-folders created under every processed-data root.
_LEAF_DIRS = ("Mixtures", "Phases", "Bass", "Vocals", "Drums", "Others")


def _output_dirs(root):
    """Return the six output folders under *root*, in ``_LEAF_DIRS`` order."""
    return tuple(os.path.join(root, leaf) for leaf in _LEAF_DIRS)


# Paths
base_path = "../dsd100/"
path_mixtures = os.path.join(base_path, "Mixtures/Dev/")
path_sources = os.path.join(base_path, "Sources/Dev/")
processed_path = "../Processed/"

(destination_path, phase_path, bass_path,
 vocals_path, drums_path, others_path) = _output_dirs(processed_path)
source_dest_paths = [vocals_path, bass_path, drums_path, others_path]

# Validation Paths
path_val_mixtures = os.path.join(base_path, "Mixtures/Test/")
path_val_sources = os.path.join(base_path, "Sources/Test/")
val_path = "../Val/"

(validation_path, val_phase_path, val_bass_path,
 val_vocals_path, val_drums_path, val_others_path) = _output_dirs(val_path)
source_val_paths = [val_vocals_path, val_bass_path, val_drums_path, val_others_path]

# Test Paths (same input folders as validation, separate output root)
path_test_mixtures = path_val_mixtures
path_test_sources = path_val_sources
test_path = "../Test/"

(testing_path, test_phase_path, test_bass_path,
 test_vocals_path, test_drums_path, test_others_path) = _output_dirs(test_path)
source_test_paths = [test_vocals_path, test_bass_path, test_drums_path, test_others_path]


def process(file_path, direc, destination_path, phase_bool, destination_phase_path):
    """
    Cut one audio file into 0.3 s segments and save the STFT of each segment.

    For every segment the STFT magnitude is saved as a tensor with a leading
    singleton dimension to ``<destination_path>/<index>_<segment>_m.pt``,
    where ``<index>`` is the first run of digits found in ``direc``. When
    ``phase_bool`` is true the phase is also saved, as
    ``<destination_phase_path>/<index>_<segment>_p.npy``.

    Any exception is caught and reported with ``print``; nothing is raised.
    """
    segment_seconds = 0.3
    try:
        # Only whole segments are processed; a shorter tail is dropped.
        n_segments = int(librosa.get_duration(filename=file_path) // segment_seconds)
        digit_runs = re.compile(r'\d+').findall(direc)
        if len(digit_runs) == 0:
            raise ValueError("Directory name does not contain valid index.")
        song_index = digit_runs[0]

        for segment in range(n_segments):
            samples, _ = librosa.load(
                file_path, sr=44100,
                offset=segment * segment_seconds, duration=segment_seconds,
            )
            spectrum = librosa.stft(samples, n_fft=1024, hop_length=256, window='hann', center=True)
            magnitude, phase = librosa.magphase(spectrum)

            os.makedirs(destination_path, exist_ok=True)
            magnitude_tensor = torch.from_numpy(np.expand_dims(magnitude, axis=0))
            torch.save(magnitude_tensor, os.path.join(destination_path, f"{song_index}_{segment}_m.pt"))

            if not phase_bool:
                continue
            os.makedirs(destination_phase_path, exist_ok=True)
            np.save(os.path.join(destination_phase_path, f"{song_index}_{segment}_p.npy"), phase)

    except Exception as e:
        print(f"Error processing {file_path}: {e}")


def _immediate_entries(directory):
    """Return sorted (sub-directory names, file names) found directly in *directory*."""
    _, dirs, files = next(os.walk(directory), (directory, [], []))
    return sorted(dirs), sorted(files)


def _stem_matches(file_name, folder_key):
    """Tell whether *file_name* (without extension) names the same stem as *folder_key*."""
    stem = os.path.splitext(file_name)[0].lower()
    # Prefix comparison in both directions covers singular/plural pairs such
    # as "other" / "others" and "mixture" / "mixtures".
    return bool(stem) and (folder_key.startswith(stem) or stem.startswith(folder_key))


def _pair_files_with_destinations(files, destination_paths):
    """Return ``(file_name, destination)`` pairs, matched by name where possible."""
    remaining = sorted(files)
    pairs = []
    unmatched_destinations = []
    for destination in destination_paths:
        key = os.path.basename(os.path.normpath(destination)).lower()
        hit = next((name for name in remaining if _stem_matches(name, key)), None)
        if hit is None:
            unmatched_destinations.append(destination)
        else:
            remaining.remove(hit)
            pairs.append((hit, destination))
    # Destinations without a name match fall back to positional pairing over
    # the (sorted, hence deterministic) left-over files.
    pairs.extend(zip(remaining, unmatched_destinations))
    return pairs


def process_directory(base_path, destination_paths, phase_bool=False, phase_path=None, processor=None):
    """
    Process the audio files of every immediate sub-directory of ``base_path``.

    Each file is handed to ``processor`` together with the destination folder
    it belongs to. A file is paired with the destination whose folder name
    matches the file name (for example ``vocals.wav`` -> ``.../Vocals``), so
    the result does not depend on the order in which the file system lists
    the files. Files or destinations without a name match are paired
    positionally in sorted file order.

    Args:
        base_path: Directory whose sub-directories each hold one song.
        destination_paths: Output folders, one per expected file.
        phase_bool: Forwarded to ``processor``; whether phase data is saved.
        phase_path: Forwarded to ``processor``; where phase data is saved.
        processor: Callable invoked as ``processor(file_path, direc,
            destination, phase_bool, phase_path)``. Defaults to ``process``.
    """
    if processor is None:
        processor = process

    # Only the first level is visited: joining deeper directory names with
    # base_path would produce paths that do not exist.
    song_dirs, _ = _immediate_entries(base_path)
    for direc in song_dirs:
        print(f"Processing directory: {direc}")
        song_dir = os.path.join(base_path, direc)
        _, files = _immediate_entries(song_dir)
        if not files:
            continue
        for file, destination in _pair_files_with_destinations(files, destination_paths):
            processor(os.path.join(song_dir, file), direc, destination, phase_bool, phase_path)


# ---------------- Training, validation and test data ----------------
# Each job: (banner, mixtures root, mixtures destination, phase destination,
#            sources root, per-source destinations).
_jobs = (
    ("Processing training data...",
     path_mixtures, destination_path, phase_path, path_sources, source_dest_paths),
    ("Processing validation data...",
     path_val_mixtures, validation_path, val_phase_path, path_val_sources, source_val_paths),
    ("Processing test data...",
     path_test_mixtures, testing_path, test_phase_path, path_test_sources, source_test_paths),
)

for _banner, _mix_root, _mix_dest, _phase_dest, _src_root, _src_dests in _jobs:
    print(_banner)
    # Mixtures also store the phase; the separated sources do not.
    process_directory(_mix_root, [_mix_dest], phase_bool=True, phase_path=_phase_dest)
    process_directory(_src_root, _src_dests, phase_bool=False)

print("Processing complete.")


# Train

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from build_model_original import SepConvNet  # Ensure this file and model are correctly implemented
from data_loader import SourceSepTrain, SourceSepVal  # Ensure these are correctly implemented
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import MultiStepLR
import os
from tqdm import tqdm


# Paths and Parameters
mean_var_path = "../Processed/"
# Checkpoints are written to ./Weights by train().
if not os.path.exists('Weights'):
    os.makedirs('Weights')

# Network geometry handed to SepConvNet.
inp_size = [513, 52]
t1 = 1
f1 = 513
t2 = 15
f2 = 1
N1 = 50
N2 = 30
NN = 128

# Weights of the penalty terms in MixedSquaredError.
alpha = 0.005
beta = 0.05
beta_vocals = 0.08

batch_size = 30
num_epochs = 50

writer = SummaryWriter()  # TensorBoard writer


# Utility Classes
class Average:
    """Running ratio of accumulated values to accumulated counts.

    ``update(val, n)`` adds ``val`` to ``sum`` as-is and ``n`` to ``count``,
    so ``avg`` equals ``sum(val) / sum(n)``; ``val`` is expected to already be
    a total over ``n`` items.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.sum, self.count = 0, 0

    def update(self, val, n=1):
        """Record a total ``val`` that covers ``n`` items."""
        self.count += n
        self.sum += val

    @property
    def avg(self):
        """Current ratio ``sum / count`` (raises ZeroDivisionError when empty)."""
        total, seen = self.sum, self.count
        return total / seen


# Custom Loss Function
class MixedSquaredError(nn.Module):
    """Squared-error loss with subtracted dissimilarity penalties.

    The result is the summed squared error of the bass, vocals and drums
    predictions against their targets, minus ``alpha`` times the pairwise
    squared differences between those three predictions, minus ``beta`` times
    the squared error of bass and drums against the "others" target, minus
    ``beta_vocals`` times the squared error of vocals against the "others"
    target. ``alpha``, ``beta`` and ``beta_vocals`` are module-level settings.
    ``pred_others`` is accepted but not used.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_bass, pred_vocals, pred_drums, pred_others, gt_bass, gt_vocals, gt_drums, gt_others):
        def sq_dist(a, b):
            # Sum of squared element-wise differences.
            return torch.sum((a - b).pow(2))

        # Reconstruction error of the three modelled sources.
        L_sq = sq_dist(pred_bass, gt_bass) + sq_dist(pred_vocals, gt_vocals) + sq_dist(pred_drums, gt_drums)
        # Distance of each prediction from the "others" target.
        L_other = sq_dist(pred_bass, gt_others) + sq_dist(pred_drums, gt_others)
        L_othervocals = sq_dist(pred_vocals, gt_others)
        # Pairwise distance between the predictions themselves.
        L_diff = sq_dist(pred_bass, pred_vocals) + sq_dist(pred_bass, pred_drums) + sq_dist(pred_vocals, pred_drums)

        loss = L_sq - alpha * L_diff
        loss = loss - beta * L_other
        loss = loss - beta_vocals * L_othervocals
        return loss


# Time-Frequency Masking Function
def TimeFreqMasking(bass, vocals, drums, others, cuda=False):
    """Turn four network outputs into soft masks that share one denominator.

    Each mask is the absolute value of its input divided by the sum of all
    four absolute values plus 1e-8 (to avoid division by zero). ``cuda`` is
    accepted for interface compatibility and is not used.

    Returns:
        Tuple ``(bass, vocals, drums, others)`` of masks.
    """
    magnitudes = [torch.abs(source) for source in (bass, vocals, drums, others)]
    total = magnitudes[0] + magnitudes[1] + magnitudes[2] + magnitudes[3] + 1e-8
    bass_mask, vocals_mask, drums_mask, others_mask = (mag / total for mag in magnitudes)
    return bass_mask, vocals_mask, drums_mask, others_mask


# Training Function
def _normalize(batch):
    """Standardise *batch* with its own mean and standard deviation."""
    return (batch - torch.mean(batch)) / torch.std(batch)


def _to_device(tensors, cuda):
    """Return *tensors* as a tuple, moved to the GPU when *cuda* is true."""
    return tuple(t.cuda() for t in tensors) if cuda else tuple(tensors)


def _separate(net, mixture, mixture_n, cuda):
    """Run *net* on the normalised mixture and mask the raw mixture with its outputs."""
    masks = TimeFreqMasking(*net(mixture_n), cuda)
    return tuple(mixture * mask for mask in masks)


def train():
    """Train SepConvNet, validate after every epoch and save a checkpoint per epoch.

    Uses the module-level hyper-parameters, the module-level TensorBoard
    ``writer`` and the default ``SourceSepTrain`` / ``SourceSepVal`` datasets.
    Checkpoints are written to
    ``Weights/Weights_epoch<N>_valLoss<loss>.pth``.

    Returns:
        The trained network.
    """
    cuda = torch.cuda.is_available()
    net = SepConvNet(t1, f1, t2, f2, N1, N2, inp_size, NN)
    criterion = MixedSquaredError()
    if cuda:
        net = net.cuda()
        criterion = criterion.cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    scheduler = MultiStepLR(optimizer, milestones=[60, 120])

    print("Preparing training data ...")
    train_loader = DataLoader(SourceSepTrain(transforms=None), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(SourceSepVal(transforms=None), batch_size=batch_size, shuffle=False)
    print("Data preparation done.")

    # Monotonic TensorBoard steps, so every batch gets its own point instead
    # of all batches of an epoch being written at the same step.
    train_step = 0
    val_step = 0

    for epoch in range(num_epochs):
        train_loss = Average()

        # Training Phase
        net.train()
        for inp, gt_bass, gt_vocals, gt_drums, gt_others in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
            # Statistics are computed per batch, before moving to the device.
            inp_n = _normalize(inp)
            inp, inp_n, gt_bass, gt_vocals, gt_drums, gt_others = _to_device(
                (inp, inp_n, gt_bass, gt_vocals, gt_drums, gt_others), cuda)

            optimizer.zero_grad()
            pred_bass, pred_vocals, pred_drums, pred_others = _separate(net, inp, inp_n, cuda)
            loss = criterion(pred_bass, pred_vocals, pred_drums, pred_others, gt_bass, gt_vocals, gt_drums, gt_others)
            writer.add_scalar('Train Loss', loss.item(), train_step)
            train_step += 1

            loss.backward()
            optimizer.step()
            train_loss.update(loss.item(), inp.size(0))

        # Validation Phase
        val_loss = Average()
        net.eval()
        with torch.no_grad():
            for val_inp, gt_bass, gt_vocals, gt_drums, gt_others in val_loader:
                val_inp_n = _normalize(val_inp)
                # The raw mixture must live on the same device as the masks it
                # is multiplied with, so it is moved along with the rest.
                val_inp, val_inp_n, gt_bass, gt_vocals, gt_drums, gt_others = _to_device(
                    (val_inp, val_inp_n, gt_bass, gt_vocals, gt_drums, gt_others), cuda)

                pred_bass, pred_vocals, pred_drums, pred_others = _separate(net, val_inp, val_inp_n, cuda)
                vloss = criterion(pred_bass, pred_vocals, pred_drums, pred_others, gt_bass, gt_vocals, gt_drums, gt_others)
                writer.add_scalar('Validation Loss', vloss.item(), val_step)
                val_step += 1
                val_loss.update(vloss.item(), val_inp.size(0))

        # Advance the schedule only after the epoch's optimizer steps.
        scheduler.step()

        print(f"Epoch {epoch+1}, Train Loss: {train_loss.avg:.4f}, Validation Loss: {val_loss.avg:.4f}")
        torch.save(net.state_dict(), f"Weights/Weights_epoch{epoch+1}_valLoss{val_loss.avg:.4f}.pth")

    return net


# Test Function
def test(model):
    """Put *model* into evaluation mode.

    Placeholder: no test logic is implemented yet, and nothing is returned.
    """
    # eval() is the only action performed so far.
    model.eval()
    # Implement test logic here


# Main
# Running this cell/module as a script starts training; the returned network
# is discarded.
if __name__ == "__main__":
    train()


# Test


In [None]:
import torch
import numpy as np
import glob
import re
import os
from build_model_original import SepConvNet  # Ensure this file is implemented correctly
from torch.utils.data import DataLoader
from data_loader import SourceSepTest  # Ensure this dataset class is implemented correctly
from post_processing import reconstruct  # Ensure this function reconstructs the audio correctly
from train_model import TimeFreqMasking  # Reuse the masking function from training
from tqdm import tqdm

# Main testing script
if __name__ == '__main__':
    # Network geometry; must match the values used when the weights were trained.
    inp_size = [513, 52]
    t1, f1 = 1, 513
    t2, f2 = 15, 1
    N1, N2, NN = 50, 30, 128
    batch_size = 1

    # Where the phases are read from and where the audio is written to.
    destination_path = '../AudioResults/'
    phase_path = '../Val/Phases/'
    vocals_directory = os.path.join(destination_path, 'vocals')
    drums_directory = os.path.join(destination_path, 'drums')
    bass_directory = os.path.join(destination_path, 'bass')
    others_directory = os.path.join(destination_path, 'others')

    # Make sure every output folder exists before anything is written.
    for directory in (destination_path, vocals_directory, drums_directory, bass_directory, others_directory):
        os.makedirs(directory, exist_ok=True)

    # Build the network and restore the trained weights on the CPU.
    net = SepConvNet(t1, f1, t2, f2, N1, N2, inp_size, NN)

    weights_path = '../Weights/Weights_epoch10_valLoss27994.5971.pth'  # Replace with your best model's weight file path
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights not found at {weights_path}")

    state = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()

    # One item per batch, in dataset order.
    test_set = SourceSepTest(transforms=None)  # Replace with any transformations, if applicable
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    n_batches = len(test_loader)

    # File names carry two numbers that are forwarded to reconstruct().
    regex = re.compile(r'\d+')

    for i, (test_inp, test_phase_file, file_str) in tqdm(enumerate(test_loader), total=n_batches, desc="Testing"):
        print(f'Testing sample {i + 1}/{n_batches}')

        # The phase saved for this segment is needed to rebuild audio.
        test_phase_path = os.path.join(phase_path, test_phase_file[0])
        if not os.path.exists(test_phase_path):
            raise FileNotFoundError(f"Phase file not found at {test_phase_path}")
        test_phase = np.load(test_phase_path)  # NumPy array

        # Standardise the input with its own statistics.
        mean, std = torch.mean(test_inp), torch.std(test_inp)
        test_inp_n = (test_inp - mean) / std

        with torch.no_grad():
            # Network outputs -> soft masks -> masked input spectrogram.
            masks = TimeFreqMasking(*net(test_inp_n))
            bass_mag, vocals_mag, drums_mag, others_mag = [mask * test_inp for mask in masks]

        indices = regex.findall(file_str[0])
        if len(indices) < 2:
            raise ValueError(f"Expected at least two indices in the file string, found: {file_str[0]}")
        index_start, index_end = indices[0], indices[1]

        # Drop singleton dimensions and hand NumPy arrays to reconstruct().
        bass_np, vocals_np, drums_np, others_np = [
            mag.squeeze().cpu().numpy()
            for mag in (bass_mag, vocals_mag, drums_mag, others_mag)
        ]
        reconstruct(
            test_phase,
            bass_np,
            vocals_np,
            drums_np,
            others_np,
            index_start,
            index_end,
            destination_path,
        )

    print("Testing complete. Check the results in the specified output directories.")


  net.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))


FileNotFoundError: [WinError 3] The system cannot find the path specified: 'Val/Mixtures'

: 

# Post


In [None]:
import librosa
import numpy as np
import os
import torch
import soundfile as sf  # Import soundfile for writing audio files

def reconstruct(phase, bass_mag, vocals_mag, drums_mag, others_mag, song_num, segment_num, destination_path):
    """Rebuild audio for the four sources of one segment and write them as WAV files.

    Each magnitude is multiplied with ``phase``, inverted with
    ``librosa.istft`` and written at 44100 Hz to
    ``<destination_path>/<source>/<song_num>_<segment_num>.wav`` for the
    sources ``vocals``, ``bass``, ``drums`` and ``others``. Magnitudes given
    as torch tensors are converted to NumPy arrays first.
    """
    def as_numpy(mag):
        # Tensors are detached and moved to the CPU before conversion.
        if isinstance(mag, torch.Tensor):
            return mag.detach().cpu().numpy()
        return mag

    magnitudes = {
        'bass': as_numpy(bass_mag),
        'vocals': as_numpy(vocals_mag),
        'drums': as_numpy(drums_mag),
        'others': as_numpy(others_mag),
    }

    phase = np.asarray(phase)

    # All spectra are built, then all signals are inverted, then all files are
    # written, so a failure in an early stage leaves no partial output.
    order = ('vocals', 'bass', 'drums', 'others')
    spectra = {stem: magnitudes[stem] * phase for stem in order}
    signals = {
        stem: librosa.istft(spectra[stem], win_length=1024, hop_length=256, window='hann', center=True)
        for stem in order
    }

    for stem in order:
        target = os.path.join(destination_path, stem, f'{song_num}_{segment_num}.wav')
        sf.write(target, signals[stem], 44100)

    return



# Eval

In [None]:
import mir_eval  
import numpy as np
from scipy.io import wavfile
import librosa


####################### MODIFY ##############################
#### additional for loop to evaluate multiple songs #########
# increase step to decrease time
step = 10
bass_gt_path, bass_rec_path = 'bass.wav', 'bass_rec.wav'
vocal_gt_path, vocal_rec_path = 'vocals.wav', 'vocals_rec.wav'
drums_gt_path, drums_rec_path = 'drums.wav', 'drums_rec.wav'
other_gt_path, other_rec_path = 'other.wav', 'other_rec.wav'
############################################################


# The ground truth is read from 30 * 0.3 s onwards for 170 * 0.3 s; the
# recovered signals are read in full.
_gt_window = dict(sr=44100, offset=30*0.3, duration=170*0.3)

bass_gt, rate11 = librosa.load(bass_gt_path, **_gt_window)
bass_rec, rate21 = librosa.load(bass_rec_path, sr=44100)

vocals_gt, rate12 = librosa.load(vocal_gt_path, **_gt_window)
vocals_rec, rate22 = librosa.load(vocal_rec_path, sr=44100)

drums_gt, rate13 = librosa.load(drums_gt_path, **_gt_window)
drums_rec, rate23 = librosa.load(drums_rec_path, sr=44100)

other_gt, rate14 = librosa.load(other_gt_path, **_gt_window)
other_rec, rate24 = librosa.load(other_rec_path, sr=44100)


# Ground truth: cut to the length of the (not yet decimated) recovered signal,
# keep every `step`-th sample and turn it into a single row.
bass_gt = bass_gt[:bass_rec.shape[0]:step][np.newaxis, :]
vocals_gt = vocals_gt[:vocals_rec.shape[0]:step][np.newaxis, :]
drums_gt = drums_gt[:drums_rec.shape[0]:step][np.newaxis, :]
other_gt = other_gt[:other_rec.shape[0]:step][np.newaxis, :]

final_gt = np.vstack((bass_gt, vocals_gt, drums_gt, other_gt))
print(final_gt.shape)


# Recovered signals: same decimation, one row per source.
bass_rec = bass_rec[::step][np.newaxis, :]
vocals_rec = vocals_rec[::step][np.newaxis, :]
drums_rec = drums_rec[::step][np.newaxis, :]
other_rec = other_rec[::step][np.newaxis, :]

final_rec = np.vstack((bass_rec, vocals_rec, drums_rec, other_rec))
print(final_rec.shape)


# Rows are ordered bass, vocals, drums, other in both matrices.
SDR, SIR, SAR, perm = mir_eval.separation.bss_eval_sources(final_gt, final_rec)

for _metric in (SDR, SIR, SAR, perm):
    print(_metric)

# Join

In [None]:
import librosa
import soundfile as sf  # Use soundfile for writing wav files
import numpy as np
import os
import re
import glob

# Root folder that receives the stitched songs, and the per-stem folders that
# hold the audio segments to be stitched.
destination_path = '../Recovered_Songs_bigger5/'
vocals_directory = '../AudioResults/vocals'
drums_directory = '../AudioResults/drums'
bass_directory = '../AudioResults/bass'
others_directory = '../AudioResults/others'
test_songs_list = []
vocals_list = []

# Create necessary directories.  exist_ok=True replaces the separate
# "exists?" check, so there is no window between checking and creating.
for directory in (destination_path, vocals_directory, drums_directory,
                  bass_directory, others_directory):
    os.makedirs(directory, exist_ok=True)

# Collect all unique test songs: the first run of digits in each file name
# found under the vocals folder is used as that file's song id.
digit_run = re.compile(r'\d+')
for _root, _subfolders, names in os.walk(vocals_directory):
    print('Finding list of songs')
    for name in names:
        found = digit_run.search(name)
        if found is None:
            continue
        song_id = found.group()
        if song_id not in test_songs_list:
            test_songs_list.append(song_id)

# Iterate through each test song and combine its audio segments, stem by stem.

SAMPLE_RATE = 44100  # rate every segment is loaded at and every output is written at

# Stems in processing order: (output sub-folder name, folder holding the segments).
STEM_DIRECTORIES = (
    ('vocals', vocals_directory),
    ('bass', bass_directory),
    ('drums', drums_directory),
    ('others', others_directory),
)

_DIGIT_RUN = re.compile(r'(\d+)')


def _natural_key(path):
    """Return a sort key that compares the numbers inside a file name numerically.

    Plain string sorting puts ``..._10`` before ``..._2``; this key does not.
    """
    # Splitting on a captured group alternates text / digits, so odd positions
    # are always digit runs and keys of different files stay comparable.
    tokens = _DIGIT_RUN.split(os.path.basename(path))
    return [int(tok) if pos % 2 else tok for pos, tok in enumerate(tokens)]


def _segments_for(song_id, stem_directory):
    """Return the files in ``stem_directory`` that belong to ``song_id``, naturally sorted.

    The glob ``song_id + "*"`` alone would also match other songs sharing the
    prefix (song ``1`` would pick up ``10...``), so a file is kept only when
    its complete leading digit run equals ``song_id``.
    """
    candidates = glob.glob(os.path.join(stem_directory, song_id + "*"))
    own = []
    for path in candidates:
        leading = _DIGIT_RUN.match(os.path.basename(path))
        if leading and leading.group() == song_id:
            own.append(path)
    return sorted(own, key=_natural_key)


def _stitch_stem(song_id, stem, stem_directory):
    """Join all segments of one stem of ``song_id`` and write them as one wav file.

    The output goes to ``<destination_path>/<stem>/<song_id>.wav``.  If no
    segment is found, an empty file is still written (with a warning) so every
    song ends up with one file per stem.
    """
    out_dir = os.path.join(destination_path, stem)
    os.makedirs(out_dir, exist_ok=True)
    sound_output_path = os.path.join(out_dir, f'{song_id}.wav')

    # Load everything first and join once; growing an array with np.append
    # re-copies all previously loaded audio on every segment.
    chunks = [librosa.load(segment, sr=SAMPLE_RATE)[0]
              for segment in _segments_for(song_id, stem_directory)]
    if chunks:
        combined = np.concatenate(chunks)
    else:
        print(f'Warning: no {stem} segments found for song {song_id}; '
              'writing an empty file')
        combined = np.array([])

    # The rate is the constant used for loading, so it is defined even when
    # no segment was loaded.
    sf.write(sound_output_path, combined, SAMPLE_RATE)


for test_songs in test_songs_list:
    print(f'Testing: {test_songs}')
    for stem, stem_directory in STEM_DIRECTORIES:
        print(f'Stitching {stem.capitalize()}')
        _stitch_stem(test_songs, stem, stem_directory)

print("All songs have been stitched and saved successfully.")
