In [1]:
import numpy as np
import librosa
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.io import wavfile

# Define the root folder containing multiple subfolders
root_folder = "TEST/DR1/FAKS0"  # Change this to your folder path

# Recursively collect all .wav files; sort so the row order of the matrix is reproducible
file_paths = sorted(str(file) for file in Path(root_folder).rglob("*.wav"))
if not file_paths:
    raise FileNotFoundError(f"No .wav files found under {root_folder!r}")

# Read every file at its native sample rate and peak-normalise it to [-1, 1]
wav_data = []
for file in file_paths:
    data, _ = librosa.load(file, sr=None)  # Keep original sample rate
    # Normalise by the largest *absolute* value: dividing by np.max(data) lets
    # negative peaks exceed -1 and yields NaN/inf for a zero (silent) peak.
    peak = np.max(np.abs(data))
    if peak > 0:
        data = data / peak
    wav_data.append(data)

# Convert list to a matrix (pad shorter signals with zeros)
max_length = max(len(x) for x in wav_data)  # Find max length of any audio file
wav_matrix = np.array([np.pad(x, (0, max_length - len(x))) for x in wav_data])
assert np.isfinite(wav_matrix).all(), "normalised audio contains NaN/inf"

print("Total WAV files found:", len(file_paths))
print("WAV Matrix shape:", wav_matrix.shape)  # (num_files, max_length)

Total WAV files found: 10
WAV Matrix shape: (10, 79565)


In [3]:
# Replace NaNs (if any) with zeros before dictionary learning
wav_matrix[np.isnan(wav_matrix)] = 0
# Only reshape the 2-D (num_files, max_length) matrix built by the loading cell,
# so re-running this cell does not reshape an already-reshaped array again.
if wav_matrix.ndim == 2:
    n_files, n_time = wav_matrix.shape
    # NOTE(review): this is a row-major reshape, NOT a transpose — each of the
    # n_time rows holds n_files consecutive samples of the concatenated signals,
    # not one sample from every file. If a (time, file) layout was intended, use
    # wav_matrix.T[:, :, np.newaxis] instead — TODO confirm the intended layout.
    wav_matrix = wav_matrix.reshape((n_time, n_files, 1))
wav_matrix.shape

(79565, 10, 1)

In [2]:
# Zero any NaNs in place, then display whether any remain (should be False)
nan_mask = np.isnan(wav_matrix)
wav_matrix[nan_mask] = 0
np.isnan(wav_matrix).any()

False

In [5]:
import os
import sys

import numpy as np
from numpy.linalg import norm

# Make the local mdla checkout importable. Set the MDLA_PATH environment
# variable instead of editing this absolute path; the membership check keeps
# re-runs from appending the same entry again and again.
MDLA_PATH = os.environ.get("MDLA_PATH", "/mnt/d_disk/ch22b007/mdla")
if MDLA_PATH not in sys.path:
    sys.path.append(MDLA_PATH)

from mdla import (
    MiniBatchMultivariateDictLearning,
    MultivariateDictLearning,
    multivariate_sparse_encode,
)

X = wav_matrix
# Fail early with a clear message rather than deep inside mdla's fit().
assert not np.isnan(X).any(), "wav_matrix contains NaN; run the NaN-cleaning cell first"

n_kernels = 10
# TODO: pass a random seed for reproducibility if MultivariateDictLearning supports one
dico = MultivariateDictLearning(n_kernels=n_kernels, max_iter=10).fit(X)
residual, code = multivariate_sparse_encode(X, dico)
print('Objective error for each samples is:')
# Objective = Frobenius norm of the residual + number of entries in the sample's code
for i, (residual_i, code_i) in enumerate(zip(residual, code)):
    print('Sample', i, ':', norm(residual_i, 'fro') + len(code_i))

ValueError: Input contains NaN.

In [None]:
# Learned dictionary kernels collected into a single ndarray
K = np.array(dico.kernels_)
K.shape

In [None]:
np.asarray(code).shape  # shape of the sparse codes returned by multivariate_sparse_encode

In [None]:
X[:, :, :].max()  # global maximum of the (3-D) input passed to dictionary learning

In [None]:
# Inspect the first sample of X. NOTE(review): X is rebound to a 2-D coefficient
# matrix by the matching-pursuit cell, after which this 3-D indexing fails.
X[0, :, :]

In [None]:
# Convert the mdla outputs to ndarrays so they can be sliced (e.g. residual[0, :, :]).
residual = np.array(residual)
try:
    code = np.array(code)
except ValueError:
    # When per-sample codes differ in length, NumPy >= 1.24 refuses to build a
    # ragged array implicitly; keep one object entry per sample instead.
    ragged_code = np.empty(len(code), dtype=object)
    for sample_idx, sample_code in enumerate(code):
        ragged_code[sample_idx] = sample_code
    code = ragged_code
print(np.shape(residual))
print(np.shape(code))

In [None]:
# Residual of the first sample; needs `residual` converted to an ndarray first
# (tuple indexing fails on the plain list returned by multivariate_sparse_encode).
residual[0, :, :]

In [None]:
from gammatone_utils import *
# from scikits.talkbox import segment_axis
# from scikits.audiolab import Sndfile, play
import soundfile as sf
import sounddevice as sd
import pydub 
from pydub import AudioSegment
from pydub.playback import play
import matplotlib.pyplot as plt
# from encoding1 import * 
plt.style.use(style='ggplot')  # apply the ggplot look to every figure drawn below

def matching_pursuit(signal, dict_kernels, threshold=0.1, max_iter=2000):
    """
    Matching pursuit algorithm for encoding.

    At each iteration the kernel with the largest absolute correlation with the
    current residual is selected; its projection coefficient is added to that
    kernel's weight and the projection is subtracted from the residual.

    :param signal: input signal, 1-D array of length n_samples
    :param dict_kernels: dictionary of kernels, shape (n_kernels, n_samples);
        each ROW is a kernel
    :param threshold: stop once the L2 norm of the residual drops below this value
    :param max_iter: maximum number of iterations
    :return: array of scalar weighting factors (one per kernel) such that
        ``coeff.dot(dict_kernels)`` approximates ``signal``
    """
    res = np.asarray(signal)
    coeff = np.zeros(dict_kernels.shape[0])
    # Squared kernel norms do not change between iterations; compute them once.
    sq_norms = np.sum(dict_kernels ** 2, axis=1)
    for _ in range(max_iter):
        inner_prod = dict_kernels.dot(res)
        # Select by |correlation|: a strongly negative match is as good as a positive one.
        best = int(np.argmax(np.abs(inner_prod)))
        if inner_prod[best] == 0:
            # Residual is orthogonal to every kernel, so no further progress is
            # possible (this also avoids 0/0 when the selected kernel is all zeros).
            break
        step_coeff = inner_prod[best] / sq_norms[best]
        # Accumulate: the same kernel may be selected more than once.
        coeff[best] += step_coeff
        res = res - step_coeff * dict_kernels[best, :]
        if np.linalg.norm(res) < threshold:
            break
    return coeff

def segment_axis(arr, frame_size, overlap, end='pad'):
    """
    Split a 1-D signal into overlapping frames.

    :param arr: 1-D input signal
    :param frame_size: number of samples per frame
    :param overlap: samples shared by consecutive frames (0 <= overlap < frame_size)
    :param end: 'pad' zero-pads the tail so every input sample falls inside some
        frame; any other value drops trailing samples that do not fill a frame
    :return: read-only strided view of shape (n_frames, frame_size)
    :raises ValueError: if overlap >= frame_size
    """
    step = frame_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than frame_size")
    arr = np.asarray(arr)
    if end == 'pad':
        # Smallest n_frames with (n_frames - 1) * step + frame_size >= len(arr);
        # -(-a // b) is ceil(a / b) for non-negative integers.
        n_frames = -(-max(len(arr) - frame_size, 0) // step) + 1
        pad_width = (n_frames - 1) * step + frame_size - len(arr)
        arr = np.pad(arr, (0, pad_width), mode='constant')

    return np.lib.stride_tricks.sliding_window_view(arr, frame_size)[::step]

# Parametrization
b = 1.019          # bandwidth parameter passed to gammatone_matrix
resolution = 160   # frame length in samples (also the kernel length)
step = 8           # passed to gammatone_matrix — presumably the shift between kernel positions
n_channels = 128   # number of ERB-spaced centre frequencies
overlap = 50       # samples shared by consecutive analysis frames

# Compute gammatone-based dictionary. Each channel's gammatone_matrix result is
# computed once and reused (it used to be recomputed three times per channel).
# NOTE(review): assumes gammatone_matrix is deterministic; [0] holds kernels as
# rows, and [1]/[2] presumably per-kernel centre frequencies / temporal centres.
channel_fcs = erb_space(150, 8000, n_channels)
channel_banks = [gammatone_matrix(b, fc, resolution, step) for fc in channel_fcs]

# Stack every channel's kernels row-wise into one dictionary
D_multi = np.r_[tuple(bank[0] for bank in channel_banks)]
freq_c = np.array([bank[1] for bank in channel_banks]).flatten()
centers = np.array([bank[2] + i*resolution
                    for i, bank in enumerate(channel_banks)]).flatten()

# Load test signal: open the file once, read only the first `length_sound`
# frames, and close the handle afterwards (it was previously read twice and leaked).
filename = 'TEST/DR1/FAKS0/SX403.wav'
length_sound = 20000
with sf.SoundFile(filename) as snd:
    nf = len(snd)          # total number of frames in the file
    fs = snd.samplerate
    y = snd.read(frames=length_sound)  # may return fewer samples for short files

# Split into overlapping frames and apply a Hann window to each frame
Y = segment_axis(y, resolution, overlap=overlap, end='pad')
Y = np.hanning(resolution) * Y

# Encode every windowed frame independently with matching pursuit:
# row k of X holds the kernel weights for frame k.
n_frames = Y.shape[0]
X = np.zeros((n_frames, D_multi.shape[0]))
for frame_idx, frame in enumerate(Y):
    X[frame_idx, :] = matching_pursuit(frame, D_multi)

# Overlap-add the decoded frames back onto a zero buffer, hop samples apart
hop = resolution - overlap
out = np.zeros(int((np.ceil(len(y) / resolution) + 1) * resolution))
for k, frame_coeffs in enumerate(X):
    start = k * hop
    out[start:start + resolution] += np.dot(frame_coeffs, D_multi)
squared_error = np.sum((y - out[:len(y)]) ** 2)


# Time axis built from the samples actually read: y is shorter than
# length_sound when the file has fewer frames, which broke the plot before.
time_s = np.arange(len(y)) / float(fs)
reconstruction = out[:len(y)]

plt.figure(1)
plt.subplot(311)
plt.plot(time_s, y, 'b', label="Input Signal")
plt.legend()
plt.subplot(312)
plt.plot(time_s, reconstruction, 'r', label="Reconstruction")
plt.legend()
plt.subplot(313)
# Point-wise squared error between input and reconstruction
plt.plot(time_s, (y - reconstruction)**2, 'g', label="Squared residual")
plt.legend()
plt.xlabel("Time in s")
plt.show()

# 2nd plot: spike train — one marker per non-zero matching-pursuit coefficient
plt.figure(2)
spikes_pos = np.array(np.nonzero(X))  # row 0: frame index (row of X), row 1: kernel index
# NOTE(review): `centers` is built from gammatone_matrix's per-channel output, yet
# it is indexed here by *frame* index. The spike time is more plausibly
# spikes_pos[0] * (resolution - overlap) + the kernel's within-frame centre — TODO confirm.
temporal_position = centers[spikes_pos[0]]
centre_freq = freq_c[spikes_pos[1]]
plt.scatter(temporal_position, centre_freq, marker='+', s=1)
plt.title("Spike train")
plt.xlabel("Temporal position")
plt.ylabel("Centre frequency")
plt.show()

# 3rd plot: example of gammatone-based dictionary
fig = plt.figure(3)
fig.suptitle("Gammatone filters", fontsize="x-large")
freqs = [1000, 300, 40]
# Kernel length for this demo only. A separate name keeps the global
# `resolution` (the analysis frame length) intact for later re-runs.
demo_resolution = 5000
# Subplot row r uses freqs[r] with its temporal centre shifted by offsets[r]
offsets = [0, 300, 1000]
for center in [100, 1500, 3000]:
    for row, (fc, offset) in enumerate(zip(freqs, offsets), start=1):
        plt.subplot(3, 1, row)
        plt.plot(gammatone_function(demo_resolution, fc, center + offset), linewidth=1.5)
plt.subplot(312)
plt.ylabel("Kernel values")
plt.subplot(313)
# The x axis is the sample index of the kernel, not seconds
plt.xlabel("Time (samples)")
plt.show()


In [None]:
# First frequency value passed to gammatone_function in the filter demo plot
freqs[0]

In [None]:
# A single long gammatone kernel for inspection. Dedicated names are used so the
# global analysis parameters (resolution, step, n_channels, overlap) defined for
# the dictionary are not overwritten; this cell used to reset `resolution` to 3600.
g_1_resolution = 3600
g_1_b = 1.019
g_1 = gammatone_function(g_1_resolution, fc=600, center=200, fs=16000, b=g_1_b)

In [None]:
# Zoom into the first 600 samples of the inspected kernel
ax = plt.gca()
ax.plot(g_1)
ax.set_xlim([0, 600])

In [None]:
D_multi.shape  # (number of kernels, kernel length)