In [None]:
import pandas as pd
import os
import torch
from scipy.io import wavfile
import numpy as np
import gc  # Import garbage collector

csv_file_path = r"../../Dataset/neurovoz_v3/data/audio_features/audio_features.csv"
df = pd.read_csv(csv_file_path)

audio_directory = r"../../Dataset/neurovoz_v3/data/audios"

# AudioPath entries may start with this prefix; it is stripped before the path
# is joined onto `audio_directory`.
_AUDIO_PATH_PREFIX = '../data/audios/'

# The device is loop-invariant, so pick it once. Fall back to the CPU when CUDA
# is unavailable instead of failing on every single file.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_normalized_mono(path):
    """Read a WAV file and return its first channel scaled to a peak of 1.

    Args:
        path: Path of the WAV file to read.

    Returns:
        1-D float64 NumPy array whose maximum absolute value is 1.

    Raises:
        ValueError: If the signal has fewer than two samples, is entirely
            zero, or contains non-finite values (peak normalisation would
            otherwise divide by zero or propagate NaNs).
    """
    _, samples = wavfile.read(path)

    # Multi-channel data is 2-D; only the first channel is analysed.
    if samples.ndim == 2:
        samples = samples[:, 0]

    # Promote to float BEFORE taking the absolute value: np.abs on integer PCM
    # wraps around for the most negative value (e.g. int16 -32768), which would
    # under-estimate the peak and mis-scale the signal.
    samples = samples.astype(np.float64)

    if samples.size < 2:
        raise ValueError("signal has fewer than two samples")

    peak = np.max(np.abs(samples))
    if not np.isfinite(peak) or peak == 0:
        raise ValueError("signal is silent or non-finite; cannot normalise")

    return samples / peak


def _decompose_modes(signal, device, alpha=None, tau=0, K=8, DC=False, init=1,
                     tol=1e-30, Niter=500):
    """Split a 1-D signal into K band-limited modes in the Fourier domain.

    NOTE(review): the update scheme (mirror extension, per-mode spectral
    filter, centre-frequency update, dual variable scaled by ``tau``) looks
    like Variational Mode Decomposition -- confirm against the reference the
    original author used.

    Args:
        signal: 1-D NumPy array holding the (already normalised) samples.
        device: torch.device on which all tensors are allocated.
        alpha: Bandwidth penalty applied to every mode. Defaults to the
            number of samples in ``signal`` (before even-length truncation).
        tau: Step size of the dual update; 0 leaves the dual variable at zero.
        K: Number of modes to extract.
        DC: If True, the first mode's centre frequency is pinned at 0.
        init: 1 spreads the initial centre frequencies evenly over [0, 0.5];
            any other value starts all of them at 0.
        tol: Convergence threshold on the mean squared change of the modes.
        Niter: Maximum number of iterations (size of the frequency history).

    Returns:
        Tuple ``(u, u_hat, omega)`` of NumPy arrays: the K time-domain modes
        (K x N), their centred spectra (N x K) and the recorded centre
        frequencies per iteration (iterations x K).

    Raises:
        ValueError: If fewer than two samples remain to decompose.
    """
    if alpha is None:
        alpha = float(signal.shape[0])

    f = torch.tensor(signal, dtype=torch.float32, device=device)

    # Ensure even length
    if len(f) % 2:
        f = f[:-1]
    if len(f) < 2:
        raise ValueError("signal has fewer than two samples")

    # Extend the signal by mirroring half of it onto each end.
    ltemp = len(f) // 2
    fMirr = torch.cat((
        torch.flip(f[:ltemp], dims=[0]),
        f,
        torch.flip(f[-ltemp:], dims=[0]),
    ))

    T = len(fMirr)
    half = T // 2
    t = torch.arange(1, T + 1, device=device) / T
    freqs = t - 0.5 - (1 / T)
    freqs_pos = freqs[half:T]  # upper half of the shifted frequency axis

    Alpha = torch.full((K,), alpha, device=device)

    # Centred spectrum with the lower half zeroed out.
    f_hat = torch.fft.fftshift(torch.fft.fft(fMirr.contiguous()))
    f_hat_plus = f_hat.clone()
    f_hat_plus[:half] = 0

    if init == 1:
        omega_curr = torch.linspace(0, 0.5, K, device=device)
    else:
        omega_curr = torch.zeros(K, device=device)

    if DC:
        omega_curr[0] = 0

    lambda_curr = torch.zeros(len(freqs), dtype=torch.cfloat, device=device)
    u_curr = torch.zeros((len(freqs), K), dtype=torch.cfloat, device=device)
    u_prev = torch.zeros((len(freqs), K), dtype=torch.cfloat, device=device)

    omega_history = torch.zeros((Niter, K), device=device)
    omega_history[0] = omega_curr

    # Start just above the tolerance so the loop body runs at least once.
    uDiff = tol + torch.finfo(torch.float32).eps
    n = 0

    while uDiff > tol and n < Niter - 1:
        u_prev.copy_(u_curr)

        # Sum of all modes except the first, taken from the previous iterate.
        sum_uk = torch.sum(u_prev, dim=1) - u_prev[:, 0]
        u_curr[:, 0] = (f_hat_plus - sum_uk - lambda_curr / 2) / (
            1 + Alpha[0] * (freqs - omega_curr[0]) ** 2
        )

        if not DC:
            power = torch.abs(u_curr[half:T, 0]) ** 2
            omega_curr[0] = torch.sum(freqs_pos * power) / torch.sum(power)

        for k in range(1, K):
            # Swap in the freshly updated mode k-1 and drop the stale mode k.
            sum_uk += u_curr[:, k - 1] - u_prev[:, k]
            u_curr[:, k] = (f_hat_plus - sum_uk - lambda_curr / 2) / (
                1 + Alpha[k] * (freqs - omega_curr[k]) ** 2
            )
            power = torch.abs(u_curr[half:T, k]) ** 2
            omega_curr[k] = torch.sum(freqs_pos * power) / torch.sum(power)

        lambda_curr += tau * (torch.sum(u_curr, dim=1) - f_hat_plus)
        omega_history[n + 1] = omega_curr
        n += 1
        uDiff = torch.sum(torch.abs(u_curr - u_prev) ** 2) / T

    Niter = min(Niter, n)
    omega = omega_history[:Niter]

    # Rebuild the full spectrum from the upper half via conjugate symmetry.
    u_hat = torch.zeros((T, K), dtype=torch.cfloat, device=device)
    u_hat[half:T, :] = u_curr[half:T, :]
    u_hat[1:half, :] = torch.conj(torch.flip(u_curr[half + 1:T, :], dims=[0]))
    u_hat[0, :] = torch.conj(u_hat[-1, :])

    u = torch.zeros((K, T), device=device)
    for k in range(K):
        u[k] = torch.real(torch.fft.ifft(torch.fft.ifftshift(u_hat[:, k].contiguous())))

    # Drop the mirrored extensions, keeping the central half.
    u = u[:, T // 4:3 * T // 4]

    u_hat_final = torch.zeros((u.shape[1], K), dtype=torch.cfloat, device=device)
    for k in range(K):
        u_hat_final[:, k] = torch.fft.fftshift(torch.fft.fft(u[k].contiguous()))

    return u.cpu().numpy(), u_hat_final.cpu().numpy(), omega.cpu().numpy()


i = 0  # number of files decomposed successfully
for row in df.itertuples():
    relative_path = row.AudioPath.strip()

    # Strip only the leading prefix; str.replace would also remove any later
    # occurrence inside the path.
    if relative_path.startswith(_AUDIO_PATH_PREFIX):
        relative_path = relative_path[len(_AUDIO_PATH_PREFIX):]

    file_path = os.path.join(audio_directory, relative_path)
    # Not consumed below; kept so the name stays available to later code.
    base_filename = os.path.splitext(os.path.basename(file_path))[0]

    try:
        with torch.no_grad():
            f = _load_normalized_mono(file_path)
            u, u_hat_final, omega = _decompose_modes(f, device)

        i += 1
        print(i)

        # The results are not used further here; release them right away.
        del u, u_hat_final, omega, f

    except Exception as e:
        print(f"Error processing {file_path}: {e}")

    finally:
        # Runs on success AND failure. Intermediate tensors are locals of the
        # helper and are already unreferenced here, so collecting first lets
        # the CUDA cache actually hand the memory back.
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()
