## Imports

In [None]:
import pickle
import os
import sys
import numpy as np
import torch
import torchaudio
import matplotlib
import matplotlib.pyplot as plt
import argparse
import psutil
import GPUtil

# Navigate up one level to the 'pretraining' directory, where 'dataloader.py' is located
sys.path.append(os.path.abspath('../'))

import dataloader

# Root directory that holds every pretraining experiment run compared in this notebook.
exp_dir = '/home/bosfab01/SpeakerVerificationBA/pretraining/exp'

# Run directories; later cells read progress.pkl / result.csv / args.pkl from them.
base_path2 = os.path.join(exp_dir, 'mask01-base-f128-t2-b24-lr1e-4-m400-pretrain_joint-asli')
base_path3 = os.path.join(exp_dir, 'mask01-base-f128-t2-b24-lr1e-4-m400-pretrain_joint-asli-20240412-172636')
base_path1_at_same_time = os.path.join(exp_dir, 'mask01-base-f128-t2-b24-lr1e-4-m400-pretrain_joint-asli-20240413-164417')
base_path_original = os.path.join(exp_dir, 'pretrained-base-f128-t2-b24-lr1e-4-m400-pretrain_joint-asli-original-20240416-103133')
base_path_shuffled = os.path.join(exp_dir, 'pretrained-base-f128-t2-b24-lr1e-4-m400-pretrain_joint-asli-shuffled-20240416-102831')
path_original_3 = os.path.join(exp_dir, 'pretrained-base-f128-t2-b24-lr1e-4-m400-pretrain_joint-asli-original-20240418-211014')
path_original_correctMean = os.path.join(exp_dir, 'pretrained-20240501-162648-original-base-f128-t2-b48-lr1e-4-m390-pretrain_joint-asli')
path_original_correctMean = '/home/bosfab01/SpeakerVerificationBA/pretraining/exp/pretrained-20240501-162648-original-base-f128-t2-b48-lr1e-4-m390-pretrain_joint-asli'

## Model Configuration

In [None]:
# Load the training arguments that were pickled into the run directory `path_original_3`.
# NOTE: pickle.load can execute arbitrary code -- only open files from trusted experiment runs.
args_file_path2 = os.path.join(path_original_3, 'args.pkl')
with open(args_file_path2, 'rb') as f:
    args2 = pickle.load(f)

# Only an argparse.Namespace can be converted (via vars) into the dict that is tabulated below.
if not isinstance(args2, argparse.Namespace):
    # Raise instead of calling exit(): exit() shuts down the notebook kernel / interpreter.
    raise TypeError(
        f"The loaded 'args' object is not an argparse.Namespace. Its type is: {type(args2)}"
    )
args_dict2 = vars(args2)

# Column width: wide enough for the 'Argument' header and the longest argument name
# (also avoids max() failing on an empty Namespace).
max_key_length = max([len('Argument')] + [len(key) for key in args_dict2])

# Print the arguments as a table; the '+' of the rule sits exactly under the '|' separator.
print(f"{'Argument':<{max_key_length}} | Value")
print("-" * (max_key_length + 1) + "+" + "-" * 30)  # Adjust 30 if you expect wider values

for key, value in args_dict2.items():
    print(f"{key:<{max_key_length}} | {value}")

## Epochs, Iterations and Time Required

### Compare the training time of the different runs (2 vs. 3 GPUs, original vs. shuffled spectrograms)

In [None]:
# Load the pickled training-progress log (progress.pkl) of every run compared below.
# NOTE: pickle.load runs arbitrary code; only load files produced by our own training runs.
def _load_progress(run_dir):
    """Return the unpickled contents of ``<run_dir>/progress.pkl``."""
    progress_file = os.path.join(run_dir, 'progress.pkl')
    with open(progress_file, 'rb') as fh:
        return pickle.load(fh)


progress2 = _load_progress(base_path2)
progress3 = _load_progress(base_path3)
progress_original = _load_progress(base_path_original)
progress_shuffled = _load_progress(base_path_shuffled)
progress_original_3 = _load_progress(path_original_3)
progress_original_correctMean = _load_progress(path_original_correctMean)

In [None]:
# Every progress entry is indexable: field 1 is the iteration count, field 3 the elapsed time [s].
def _progress_field(progress, field_idx):
    """Collect field ``field_idx`` of every progress entry into a NumPy array."""
    return np.array([entry[field_idx] for entry in progress])


iteration2 = _progress_field(progress2, 1)
iteration3 = _progress_field(progress3, 1)
iteration_original = _progress_field(progress_original, 1)
iteration_shuffled = _progress_field(progress_shuffled, 1)
iter_original = _progress_field(progress_original_3, 1)
# Doubled -- presumably to express the b48 run in b24-equivalent iterations (see the run names).
iter_original_correctMean = _progress_field(progress_original_correctMean, 1) * 2

time2 = _progress_field(progress2, 3)
time3 = _progress_field(progress3, 3)
time_original = _progress_field(progress_original, 3)
time_shuffled = _progress_field(progress_shuffled, 3)
time_orig = _progress_field(progress_original_3, 3)
time_original_correctMean = _progress_field(progress_original_correctMean, 3)

In [None]:
# Extrapolate the wall-clock time of a full run of `n_iterations` iterations from the average
# time per iteration seen in the progress logs (elapsed time / iteration count of one entry).
n_iterations = 800000


def _seconds_per_iteration(times, iterations, idx=-1):
    """Average seconds per iteration as measured at progress entry ``idx`` (default: the last one)."""
    return times[idx] / iterations[idx]


# NOTE(review): the 2- and 3-GPU runs are measured at entry 10 instead of the last entry -- confirm intended.
time_per_iteration2 = _seconds_per_iteration(time2, iteration2, 10)
time_per_iteration3 = _seconds_per_iteration(time3, iteration3, 10)
time_per_iteration_original = _seconds_per_iteration(time_original, iteration_original)
time_per_iteration_shuffled = _seconds_per_iteration(time_shuffled, iteration_shuffled)
time_per_iter_orig = _seconds_per_iteration(time_orig, iter_original)
time_per_iter_orig_correctMean = _seconds_per_iteration(time_original_correctMean, iter_original_correctMean)

total_time2 = time_per_iteration2 * n_iterations
total_time3 = time_per_iteration3 * n_iterations
total_time_original = time_per_iteration_original * n_iterations
total_time_shuffled = time_per_iteration_shuffled * n_iterations
total_time_orig = time_per_iter_orig * n_iterations
# NOTE(review): iter_original_correctMean was already multiplied by 2 when it was built, and the
# extrapolation is halved again here -- verify the batch-size difference is not compensated twice.
total_time_orig_correctMean = time_per_iter_orig_correctMean * n_iterations/2

print(f"Total time for 2 GPUs: {total_time2/3600:.2f} hours")
print(f"Total time for 3 GPUs: {total_time3/3600:.2f} hours")
print(f"Total time for original: {total_time_original/3600:.2f} hours")
print(f"Total time for shuffled: {total_time_shuffled/3600:.2f} hours")
print(f"Total time for original 3 GPUs: {total_time_orig/3600:.2f} hours")
print(f"Total time for original correctMean: {total_time_orig_correctMean/3600:.2f} hours")

In [None]:
# Elapsed wall-clock time against iteration for every run, overlaid in a single figure.
_time_curves = [
    (iteration2, time2, '2 GPUs'),
    (iteration3, time3, '3 GPUs'),
    (iteration_original, time_original, 'original'),
    (iteration_shuffled, time_shuffled, 'shuffled'),
    (iter_original, time_orig, 'original 3 GPUs'),
]
for _iters, _secs, _name in _time_curves:
    plt.plot(_iters, _secs, label=_name)
plt.xlabel('Iteration')
plt.ylabel('Time [s]')
plt.legend()
plt.show()


## Training and Evaluation Loss

In [None]:
# Results of the 3-GPU run: one row per logged step. The first four columns are train accuracy,
# train loss, eval accuracy and eval MSE.
result_file_path3 = os.path.join(base_path3, 'result.csv')
result3 = np.genfromtxt(result_file_path3, delimiter=',')

acc_train3, loss_train3, acc_eval3, mse_eval3 = (result3[:, col] for col in range(4))

### Compare the training and evaluation loss for the model trained with original and shuffled spectrograms

In [None]:
# Results of the runs trained on original vs. shuffled spectrograms. Column layout of result.csv:
# 0 acc_train | 1 loss1_train | 2 loss2_train | 3 acc_eval | 4 loss1_eval | 5 loss2_eval | 6 learning rate
result_file_path_original = os.path.join(base_path_original, 'result.csv')
result_file_path_shuffled = os.path.join(base_path_shuffled, 'result.csv')

result_original = np.genfromtxt(result_file_path_original, delimiter=',')
result_shuffled = np.genfromtxt(result_file_path_shuffled, delimiter=',')
print("shape of result_original: ", result_original.shape)
print("shape of result_shuffled: ", result_shuffled.shape)

(acc_train_original, loss1_train_original, loss2_train_original,
 acc_eval_original, loss1_eval_original, loss2_eval_original,
 learning_rate_original) = (result_original[:, c] for c in range(7))

(acc_train_shuffled, loss1_train_shuffled, loss2_train_shuffled,
 acc_eval_shuffled, loss1_eval_shuffled, loss2_eval_shuffled,
 learning_rate_shuffled) = (result_shuffled[:, c] for c in range(7))

# Table layout shared by the two metric tables: iteration (in thousands) followed by six metrics.
header_format = " {:>5}  | {:<10} | {:<10} | {:<10} | {:<10} | {:<10} | {:<10}"
row_format = "{:>5}k  | {:<10.5f} | {:<10.5f} | {:<10.5f} | {:<10.5f} | {:<10.5f} | {:<10.5f}"

In [None]:
# Metric table for the run trained on original spectrograms (iteration column in thousands).
print("Original Spectrograms:")
separator = "-" * 86  # width of one header line produced by header_format
print(separator)
print(header_format.format("iter", "acc_train", "loss1_tr", "loss2_tr", "acc_ev", "loss1_ev", "loss2_ev"))
print(separator)
for row_idx, acc_tr in enumerate(acc_train_original):
    print(row_format.format(
        iteration_original[row_idx] / 1e3,
        acc_tr,
        loss1_train_original[row_idx],
        loss2_train_original[row_idx],
        acc_eval_original[row_idx],
        loss1_eval_original[row_idx],
        loss2_eval_original[row_idx],
    ))
print(separator + "\n")

In [None]:
# Metric table for the run trained on shuffled spectrograms (iteration column in thousands).
print("Shuffled Spectrograms:")
separator = "-" * 86  # width of one header line produced by header_format
print(separator)
print(header_format.format("iter", "acc_train", "loss1_tr", "loss2_tr", "acc_ev", "loss1_ev", "loss2_ev"))
print(separator)
for row_idx, acc_tr in enumerate(acc_train_shuffled):
    print(row_format.format(
        iteration_shuffled[row_idx] / 1e3,
        acc_tr,
        loss1_train_shuffled[row_idx],
        loss2_train_shuffled[row_idx],
        acc_eval_shuffled[row_idx],
        loss1_eval_shuffled[row_idx],
        loss2_eval_shuffled[row_idx],
    ))
print(separator + "\n")

In [None]:
# result.csv of every run, built from the run directories defined in the first cell so the two
# places cannot drift apart. The Gong reference results live outside the exp directory.
result_gong_path = '/home/bosfab01/SpeakerVerificationBA/pretraining/result_gong.csv'
result_original_path = os.path.join(base_path_original, 'result.csv')
result_shuffled_path = os.path.join(base_path_shuffled, 'result.csv')
result_3GPU_path = os.path.join(base_path3, 'result.csv')
result_original_3GPU_path = os.path.join(path_original_3, 'result.csv')
result_original_correctMean_path = os.path.join(path_original_correctMean, 'result.csv')

result_gong = np.genfromtxt(result_gong_path, delimiter=',')
result_original = np.genfromtxt(result_original_path, delimiter=',')
result_shuffled = np.genfromtxt(result_shuffled_path, delimiter=',')
result_3GPU = np.genfromtxt(result_3GPU_path, delimiter=',')
result_original_3GPU = np.genfromtxt(result_original_3GPU_path, delimiter=',')
result_original_correctMean = np.genfromtxt(result_original_correctMean_path, delimiter=',')

# Print the shape of every loaded results matrix.
for _name, _result in [
    ('result_gong', result_gong),
    ('result_original', result_original),
    ('result_shuffled', result_shuffled),
    ('result_3GPU', result_3GPU),
    ('result_original_3GPU', result_original_3GPU),
    ('result_original_correctMean', result_original_correctMean),
]:
    print(f"shape of {_name}: ", _result.shape)

In [None]:
def get_column(array):
    """Yield the columns of a 2-D NumPy array from left to right, each as the view ``array[:, i]``."""
    yield from (array[:, col] for col in range(array.shape[1]))

# Split every results matrix into named per-column series.
# gong / 3GPU: accuracy and loss for training, accuracy and MSE for evaluation, learning rate.
(acc_tr_gong, loss_tr_gong,
 acc_ev_gong, mse_ev_gong, lr_gong) = get_column(result_gong)
(acc_tr_3GPU, loss_tr_3GPU,
 acc_ev_3GPU, mse_ev_3GPU, _) = get_column(result_3GPU)
# The remaining runs log two losses (loss1, loss2) for training and for evaluation, then the learning rate.
(acc_tr_original, loss1_tr_original, loss2_tr_original,
 acc_ev_original, loss1_ev_original, loss2_ev_original, _) = get_column(result_original)
(acc_tr_shuffled, loss1_tr_shuffled, loss2_tr_shuffled,
 acc_ev_shuffled, loss1_ev_shuffled, loss2_ev_shuffled, _) = get_column(result_shuffled)
(acc_tr_original_3GPU, loss1_tr_original_3GPU, loss2_tr_original_3GPU,
 acc_ev_original_3GPU, loss1_ev_original_3GPU, loss2_ev_original_3GPU,
 lr_original_3GPU) = get_column(result_original_3GPU)
(acc_tr_original_correctMean, loss1_tr_original_correctMean, loss2_tr_original_correctMean,
 acc_ev_original_correctMean, loss1_ev_original_correctMean, loss2_ev_original_correctMean,
 lr_original_correctMean) = get_column(result_original_correctMean)

# Gong rows are mapped to iterations 4000, 8000, ... (assumed logging interval).
iteration_gong = 4000 * np.arange(1, acc_tr_gong.shape[0] + 1)

In [None]:
# 2x2 overview: training accuracy / loss (top row) and evaluation accuracy / MSE (bottom row)
# for the Gong reference, the 3-GPU run ('original 1024') and the two newer runs.
import matplotlib.pyplot as plt

# The two reference curves are cut to the length of the 'original 998' run plus 5 points.
n_show = len(acc_tr_original_3GPU) + 5

# Each panel: (axes position, y-label, gong, 'original 1024', 'original correctMean', 'original 998').
# For the joint runs the training loss is shown as loss1 + 10 * loss2 -- presumably the
# training loss weighting; confirm against the training script.
panels = [
    ((0, 0), 'Accuracy Training', acc_tr_gong, acc_tr_3GPU,
     acc_tr_original_correctMean, acc_tr_original_3GPU),
    ((0, 1), 'Loss Training', loss_tr_gong, loss_tr_3GPU,
     loss1_tr_original_correctMean + 10 * loss2_tr_original_correctMean,
     loss1_tr_original_3GPU + 10 * loss2_tr_original_3GPU),
    ((1, 0), 'Accuracy Evaluation', acc_ev_gong, acc_ev_3GPU,
     acc_ev_original_correctMean, acc_ev_original_3GPU),
    ((1, 1), 'MSE Evaluation', mse_ev_gong, mse_ev_3GPU,
     loss2_ev_original_correctMean, loss2_ev_original_3GPU),
]

fig, axs = plt.subplots(2, 2, figsize=(14, 10))
for pos, ylabel, gong_curve, curve_1024, correct_mean_curve, curve_998 in panels:
    ax = axs[pos]
    ax.plot(iteration_gong[:n_show] / 1e3, gong_curve[:n_show], label='gong')
    ax.plot(iteration3[:n_show] / 1e3, curve_1024[:n_show], label='original 1024')
    ax.plot(iter_original_correctMean / 1e3, correct_mean_curve, label='original correctMean')
    ax.plot(iter_original / 1e3, curve_998, label='original 998')
    ax.set_xlabel('Iteration [k]')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid()

plt.show()

In [None]:
# Learning-rate schedules of the runs, log-scaled y-axis, iteration in thousands.
fig, ax = plt.subplots(figsize=(12, 6))
for _x_k, _lr, _name in (
    (iteration_original / 1e3, learning_rate_original, 'original'),
    (iteration_shuffled / 1e3, learning_rate_shuffled, 'shuffled'),
    (iter_original / 1e3, lr_original_3GPU, 'original 998'),
    (iteration_gong / 1e3, lr_gong, 'gong'),
):
    ax.plot(_x_k, _lr, label=_name)
ax.set_yscale('log')
ax.set_xlabel('Iteration [k]')
ax.set_ylabel('Learning Rate')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Evaluation loss1 of the shuffled-spectrogram run against iteration (in thousands).
plt.plot(iteration_shuffled / 1e3, loss1_ev_shuffled, label='shuffled')
plt.xlabel('Iteration [k]')
plt.ylabel('loss1 Evaluation')
plt.legend()  # without a legend call the 'shuffled' label is never shown
plt.grid()
plt.show()

## testing the dataloader

### Demonstration of the mismatch between the target length and the input length

In [None]:
# DataLoader over the combined training manifest, used here only to inspect one batch.
# freqm/timem/mixup are 0 and noise is off -- presumably disabling augmentation (see dataloader.py).
train_audio_conf = {
    'num_mel_bins': 128,
    'target_length': 1024,
    'freqm': 0,
    'timem': 0,
    'mixup': 0,
    'dataset': 'asli',
    'mean': -3.6925695,
    'std': 4.020388,
    'noise': False,
    'mode': 'train',
    'shuffle_frames': False,
}
train_dataset = dataloader.AudioDataset(
    dataset_json_file='/home/bosfab01/SpeakerVerificationBA/data/audioset2M_librispeech960.json',
    audio_conf=train_audio_conf,
    label_csv='/home/bosfab01/SpeakerVerificationBA/data/label_information.csv',
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=24,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
    drop_last=True,
)

### Get one batch from the dataloader and display its spectrograms

In [None]:
# Pull the first batch from the loader and report the shapes of what it yields.
data_iterator = iter(train_loader)
audio_input, labels = next(data_iterator)

for _caption, _tensor in (("Audio input shape:", audio_input), ("Labels shape:", labels)):
    print(_caption, _tensor.shape)

In [None]:
# Stem-plot the last 10 frames (spectra) of the first sample in the batch, from frame -10 to -1,
# to check whether the number of frames matches target_length.
n_last = 10
colors = plt.cm.viridis(np.linspace(0, 1, n_last))  # one 'viridis' colour per frame
offsets = range(n_last, 0, -1)  # 10, 9, ..., 1  ->  frames -10 ... -1

for color, offset in zip(colors, offsets):
    markerline, stemlines, baseline = plt.stem(audio_input[0, -offset, :], linefmt='-', basefmt=" ")
    plt.setp(stemlines, 'linewidth', 2, 'color', color)  # colour and line width per frame
    plt.setp(markerline, 'marker', '')  # no marker at the stem tips

plt.legend([f'Spectrum {-offset}' for offset in offsets])
plt.xlabel('Frequency Bins')
plt.ylabel('Magnitude')
plt.title('Spectra of the Last 10 Frames')
plt.grid(True)
plt.show()

# Print the raw values of the very last frame of the first sample.
print(audio_input[0, -1, :])

In [None]:
def plot_spectrogram(spectrogram, ax, title="Spectrogram"):
    """Draw a spectrogram on ``ax`` with time on the x-axis and the lowest mel bin at the bottom.

    Args:
        spectrogram: 2-D array of shape (time_frames, mel_bins); either a torch.Tensor (on any
            device, with or without autograd history) or anything NumPy can convert.
        ax: matplotlib Axes to draw on.
        title: title of the axes.
    """
    if isinstance(spectrogram, torch.Tensor):
        # detach(): tensors that require grad cannot be converted to NumPy directly;
        # cpu(): CUDA tensors must be moved to host memory first.
        spectrogram = spectrogram.detach().cpu().numpy()
    data = np.asarray(spectrogram)
    # Transpose so that frequency runs along the vertical axis and time along the horizontal one.
    ax.imshow(data.T, aspect='auto', origin='lower', cmap='viridis')
    ax.set_title(title)
    ax.set_xlim(0, data.shape[0])
    ax.set_xlabel('Time Frames')
    ax.set_ylabel('Mel Frequency Bins')

# Show the spectrograms of the first three samples of the batch, one small figure each.
for sample_idx in range(3):
    fig, ax = plt.subplots(figsize=(10, 1.5))
    plot_spectrogram(audio_input[sample_idx], ax, title=f'Spectrogram of sample {sample_idx}')
    plt.show()

## classification objective

In [None]:
import sys
import os
import torch
import numpy as np
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)
sys.path.append(parent_directory)
from ssast_model import ASTModel
import soundfile as sf
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from PIL import Image
import qrcode
import torchaudio
import pickle
import librosa

# Load one preprocessed speech segment and turn it into a (1, num_samples) float32 tensor.
# Assumes a mono file (sf.read returns a 1-D array then) -- TODO confirm for other inputs.
file_path = '/home/bosfab01/SpeakerVerificationBA/data/preprocessed/0a4b5c0f-facc-4d3b-8a41-bc9148d62d95/0_segment_0.flac'
try:
    audio_signal, sample_rate = sf.read(file_path)
except Exception as e:
    print(f"An error occurred while reading the file: {e}")
    raise

# Timestamp in seconds of every sample (for plotting).
time = np.arange(len(audio_signal)) / sample_rate

audio_tensor = torch.from_numpy(audio_signal)
print("Shape of audio tensor:", audio_tensor.shape)

# float32 for torchaudio, plus a leading channel dimension.
audio_tensor = audio_tensor.float().unsqueeze(0)
print("Shape of audio tensor:", audio_tensor.shape)

# Kaldi-compatible log-mel filterbank: 128 mel bins, 10 ms frame shift, Hann window, no dither.
fbank_features = torchaudio.compliance.kaldi.fbank(
    audio_tensor,
    sample_frequency=sample_rate,
    htk_compat=True,
    use_energy=False,
    window_type='hanning',
    num_mel_bins=128,
    dither=0.0,
    frame_shift=10,
)
print(f"Shape of fbank features: {fbank_features.shape}")

# Normalise with the dataset statistics (same mean/std values as the DataLoader audio_conf);
# dividing by 2 * std presumably mirrors the training pipeline -- TODO confirm in dataloader.py.
dataset_mean = -3.6925695
dataset_std = 4.020388
normalized = (fbank_features - dataset_mean) / (2 * dataset_std)

# Leading batch dimension -> (1, frames, mel_bins).
test_input = normalized.unsqueeze(0)
print(f"Shape of fbank features: {test_input.shape}")



# SSAST model in pretraining mode (128x2 patches with stride 128x2, 998 input frames, base size).
model = ASTModel(fshape=128, tshape=2, fstride=128, tstride=2, input_fdim=128, input_tdim=998, model_size='base', pretrain_stage=True)
# Wrapped in DataParallel before loading -- presumably because the checkpoint's keys carry the
# 'module.' prefix of a DataParallel-wrapped model; unwrapped again right after loading.
model = torch.nn.DataParallel(model)
# NOTE(review): this checkpoint comes from the *shuffled* run -- confirm that is the one intended here.
checkpoint_path = '/home/bosfab01/SpeakerVerificationBA/pretraining/exp/pretrained-20240429-112534-shuffled-base-f128-t2-b48-lr1e-4-m390-pretrain_joint-asli/models/audio_model.54.pth'
# map_location='cpu': the model is used on the CPU below; without it torch.load restores tensors
# onto the devices they were saved from, which fails on machines without those GPUs.
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)
model = model.module
model.to('cpu')
model.eval()
print(next(model.parameters()).device)  # Should print 'cpu'



# Mask positions are chosen in the model's time-patch grid (998 frames / tstride 2 -> 499 patches):
# windows of `hop_width` patches centred every `hop_length` patches.
hop_width = 20
hop_length = 50  # just for visualization, not the actual hop length used in the data preparation
hops = range(hop_length, 998//2 - hop_width//2, hop_length)
print(hops)
half_width = hop_width // 2
mask_indices = [idx for center in hops for idx in range(center - half_width, center + half_width)]
print("len(mask_indices):", len(mask_indices))

# Map patch indices [0, 499) to spectrogram frames [0, 998): patch p covers frames 2p and 2p + 1.
expanded_mask_indices = [frame for idx in mask_indices for frame in (2 * idx, 2 * idx + 1)]

# Zero the selected frames of the (1, frames, mel_bins) input.
mask = torch.ones_like(test_input)
mask[0, expanded_mask_indices, :] = 0
# NOTE(review): masked_spectrogram is not passed to the model call below (test_input is) -- confirm intended.
masked_spectrogram = test_input * mask

# The model call expects the patch indices as a tensor.
mask_indices = torch.tensor(mask_indices)
print("shape of mask_indices:", mask_indices.shape)

# Run the pretraining model's 'show_classification_head' task for the chosen mask positions;
# it returns c_vec, x_vec, prob and nce (gradients disabled: inspection only).
with torch.no_grad():
    c_vec, x_vec, prob, nce = model(test_input, task='show_classification_head', mask_indices=mask_indices)

# compare input and output shapes
print(test_input.shape)
print(c_vec.shape)
print(x_vec.shape)
print(prob.shape)
# 1 / (number of masked positions): the probability of a uniform choice, as a baseline.
print("expected probability:", 1 / len(mask_indices))
print("actual probabilities:", prob.cpu().numpy())

# Overlay the first three c_vec / x_vec pairs of the first batch item.
# Same colour = same index; solid = c_vec, dotted = x_vec. Indexed labels keep the legend unambiguous.
plt.figure(figsize=(10, 6))
for k, color in enumerate(['blue', 'green', 'yellow']):
    plt.plot(c_vec[0, k, :].numpy(), label=f'c_vec[{k}]', color=color)
    plt.plot(x_vec[0, k, :].numpy(), label=f'x_vec[{k}]', color=color, linestyle='dotted')
plt.legend()
plt.grid()
plt.show()


# Unnormalised score matrix of the first batch item: entry [i, j] = c_vec[i] . x_vec[j].
probabilities = c_vec[0].cpu().numpy() @ x_vec[0].cpu().numpy().T
plt.figure(figsize=(6,4))
plt.imshow(probabilities, cmap='viridis', aspect='auto')
# imshow draws matrix rows along the y-axis: rows index c_vec, columns index x_vec.
plt.xlabel('x_vec')
plt.ylabel('c_vec')
plt.title('Pre-Probability Matrix')
plt.colorbar()
plt.show()

## Own implementation of fbank (to make sure I understand the process)

### torchaudio

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
import soundfile as sf


# Reference: torchaudio's Kaldi-compatible fbank of the same segment, to compare against
# the hand-written implementation in the next section.
file_path = '/home/bosfab01/SpeakerVerificationBA/data/preprocessed/0a4b5c0f-facc-4d3b-8a41-bc9148d62d95/0_segment_0.flac'
try:
    audio_signal, sample_rate = sf.read(file_path)
except Exception as e:
    print(f"An error occurred while reading the file: {e}")
    raise

# Timestamp in seconds of every sample (for plotting).
time = np.arange(len(audio_signal)) / sample_rate

audio_tensor = torch.from_numpy(audio_signal).float()  # float32 for torchaudio
print("Data type of audio tensor:", audio_tensor.dtype)

# Leading channel dimension (assumes a mono file): (1, num_samples).
audio_tensor = audio_tensor.unsqueeze(0)
print("Shape of audio tensor:", audio_tensor.shape)

fbank_options = dict(
    htk_compat=True,
    use_energy=False,
    window_type='hanning',
    num_mel_bins=128,
    dither=0.0,
    frame_shift=10,
)
fbank_features = torchaudio.compliance.kaldi.fbank(audio_tensor, sample_frequency=sample_rate, **fbank_options)
print(f"Shape of fbank features: {fbank_features.shape}")

### own function for fbank

In [None]:
import math
from typing import Tuple
import torch
from torch import Tensor

# Machine epsilon of float32 (C++ numeric_limits<float>::epsilon() == 2**-23 ~= 1.19e-07);
# used as a lower bound on the frame energy before taking its logarithm.
EPSILON = torch.tensor(torch.finfo(torch.float32).eps)
# Factor that converts milliseconds to seconds.
MILLISECONDS_TO_SECONDS = 1e-3


def _get_epsilon(device, dtype):
    """Return the module constant ``EPSILON`` on ``device`` with ``dtype`` so it can be mixed with data tensors."""
    eps_on_target = EPSILON.to(dtype=dtype, device=device)
    return eps_on_target


def _next_power_of_2(x: int) -> int:
    r"""Returns the smallest power of 2 that is greater than x"""
    return 1 if x == 0 else 2 ** (x - 1).bit_length()


def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
    r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
    representing how the window is shifted along the waveform. Each row is a frame.

    Args:
        waveform (Tensor): Tensor of size ``num_samples``
        window_size (int): Frame length
        window_shift (int): Frame shift
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length.  If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends.

    Returns:
        Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
    """
    assert waveform.dim() == 1
    num_samples = waveform.size(0)
    strides = (window_shift * waveform.stride(0), waveform.stride(0))

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
        else:
            m = 1 + (num_samples - window_size) // window_shift
    else:
        reversed_waveform = torch.flip(waveform, [0])
        m = (num_samples + (window_shift // 2)) // window_shift
        pad = window_size // 2 - window_shift // 2
        pad_right = reversed_waveform
        if pad > 0:
            # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
            # but we want [2, 1, 0, 0, 1, 2]
            pad_left = reversed_waveform[-pad:]
            waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
        else:
            # pad is negative so we want to trim the waveform at the front
            waveform = torch.cat((waveform[-pad:], pad_right), dim=0)

    sizes = (m, window_size)
    return waveform.as_strided(sizes, strides)


def _feature_window_function(
    window_size: int,
    device: torch.device,
    dtype: int,
) -> Tensor:
    r"""Returns a window function with the given type and size"""
    return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)

def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
    r"""Returns the log energy of size (m) for a strided_input (m,*)"""
    device, dtype = strided_input.device, strided_input.dtype
    log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log()  # size (m)
    if energy_floor == 0.0:
        return log_energy
    return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))


def _get_waveform_and_window_properties(
    waveform: Tensor,
    channel: int,
    sample_frequency: float,
    frame_shift: float,
    frame_length: float,
    round_to_power_of_two: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
    r"""Select one channel of ``waveform`` and convert the frame timing from milliseconds to samples.

    Args:
        waveform: tensor indexed as (channel, sample).
        channel: channel to use; negative values are treated as 0.
        sample_frequency: sampling rate in Hz (must be > 0).
        frame_shift: hop between frames in milliseconds.
        frame_length: window length in milliseconds.
        round_to_power_of_two: round the padded window size up to a power of two.
        preemphasis_coefficient: only validated here (must lie in [0, 1]).

    Returns:
        (1-D waveform of the channel, window_shift, window_size, padded_window_size), sizes in samples.
    """
    channel = 0 if channel < 0 else channel
    n_channels = waveform.size(0)
    assert channel < n_channels, "Invalid channel {} for size {}".format(channel, n_channels)
    signal = waveform[channel, :]  # size (n)

    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    if round_to_power_of_two:
        padded_window_size = _next_power_of_2(window_size)
    else:
        padded_window_size = window_size

    n_samples = len(signal)
    assert 2 <= window_size <= n_samples, "choose a window size {} that is [2, {}]".format(window_size, n_samples)
    assert 0 < window_shift, "`window_shift` must be greater than 0"
    assert padded_window_size % 2 == 0, (
        "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
    )
    assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
    return signal, window_shift, window_size, padded_window_size


def _get_window(
    waveform: Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    remove_dc_offset: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
    r"""Gets a window and its log energy

    Frames the 1-D ``waveform`` into rows of ``window_size`` samples every ``window_shift`` samples
    (via ``_get_strided``), then, per frame: optionally removes the DC offset, optionally computes
    the raw log energy, optionally applies pre-emphasis, multiplies by the window function and
    zero-pads on the right to ``padded_window_size`` columns.

    Args:
        waveform: 1-D signal of size n.
        padded_window_size: number of columns of the returned frames (zero padding is added
            when it differs from ``window_size``).
        window_size: frame length in samples.
        window_shift: hop between frame starts in samples.
        snip_edges: forwarded to ``_get_strided`` (how frames at the signal ends are handled).
        raw_energy: if True, the log energy is taken before pre-emphasis and windowing;
            otherwise after windowing and padding.
        energy_floor: forwarded to ``_get_log_energy``; 0.0 disables the extra floor.
        remove_dc_offset: subtract each frame's mean first.
        preemphasis_coefficient: pre-emphasis factor; 0.0 disables pre-emphasis.

    Returns:
        (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    # size (m, window_size)
    strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)

    if remove_dc_offset:
        # Subtract each row/frame by its mean
        row_means = torch.mean(strided_input, dim=1).unsqueeze(1)  # size (m, 1)
        strided_input = strided_input - row_means

    if raw_energy:
        # Compute the log energy of each row/frame before applying preemphasis and
        # window function
        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)  # size (m)

    if preemphasis_coefficient != 0.0:
        # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
        # Replicate-padding one column on the left supplies the "previous sample" for j = 0
        # (the first sample itself), so dropping the last column aligns column j with j-1.
        offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
            0
        )  # size (m, window_size + 1)
        strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]

    # Apply window_function to each row/frame
    # (the window type is whatever _feature_window_function returns for these three arguments)
    window_function = _feature_window_function(window_size, device, dtype).unsqueeze(
        0
    )  # size (1, window_size)
    strided_input = strided_input * window_function  # size (m, window_size)

    # Pad columns with zero until we reach size (m, padded_window_size)
    if padded_window_size != window_size:
        padding_right = padded_window_size - window_size
        strided_input = torch.nn.functional.pad(
            strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
        ).squeeze(0)

    # Compute energy after window function (not the raw one); the appended zeros do not
    # change the sum of squares.
    if not raw_energy:
        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)  # size (m)

    return strided_input, signal_log_energy


def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
    """Optionally remove the per-column mean of a 2-D tensor.

    Args:
        tensor (Tensor): Input of size (m, n).
        subtract_mean (bool): When False the input tensor itself is returned.

    Returns:
        Tensor: Size (m, n); every column has zero mean when ``subtract_mean`` is True.
    """
    if not subtract_mean:
        return tensor
    return tensor - tensor.mean(dim=0, keepdim=True)



def inverse_mel_scale_scalar(mel_freq: float) -> float:
    """Convert a mel value back to Hz: 700 * (exp(mel / 1127) - 1)."""
    growth = math.exp(mel_freq / 1127.0)
    return 700.0 * (growth - 1.0)


def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
    """Element-wise mel -> Hz conversion for tensors (same formula as the scalar version)."""
    exponent = mel_freq / 1127.0
    return 700.0 * (exponent.exp() - 1.0)


def mel_scale_scalar(freq: float) -> float:
    """Convert a frequency in Hz to mel: 1127 * ln(1 + f / 700)."""
    ratio = 1.0 + freq / 700.0
    return 1127.0 * math.log(ratio)


def mel_scale(freq: Tensor) -> Tensor:
    """Element-wise Hz -> mel conversion for tensors (same formula as the scalar version)."""
    ratio = 1.0 + freq / 700.0
    return 1127.0 * ratio.log()


def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
    vtln_low: float,
    vtln_high: float,
) -> Tuple[Tensor, Tensor]:
    """Build Kaldi-style triangular mel filters on the FFT frequency grid (no VTLN warping).

    Args:
        num_bins (int): Number of triangular filters; must be at least 3.
        window_length_padded (int): Padded FFT length in samples; must be even.
        sample_freq (float): Sampling frequency in Hz.
        low_freq (float): Lower edge of the first filter in Hz.
        high_freq (float): Upper edge of the last filter in Hz; values <= 0 are an offset
            from the Nyquist frequency.
        vtln_low (float): Unused; kept so callers of the VTLN-capable signature still work.
        vtln_high (float): Unused, for the same reason.

    Returns:
        (Tensor, Tensor): ``bins`` of size (``num_bins``, ``window_length_padded // 2``) with the
        weight of every FFT bin below Nyquist for each filter, and ``center_freqs`` (filter
        centres in Hz) of size (``num_bins``, 1).
    """
    # The previous check was ``num_bins > 3``, which also rejected exactly 3 bins and
    # contradicted its own message (Kaldi's MelBanks requires num_bins >= 3).
    assert num_bins >= 3, "Must have at least 3 mel bins"
    assert window_length_padded % 2 == 0
    num_fft_bins = window_length_padded // 2
    nyquist = 0.5 * sample_freq

    if high_freq <= 0.0:
        high_freq += nyquist

    assert (
        (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
    ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)

    # Width of one FFT bin in Hz
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # num_bins + 1 equal mel steps: each triangle spans two steps and starts at the
    # centre of its predecessor, so the outer half-triangles reach the band edges.
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    bin_index = torch.arange(num_bins).unsqueeze(1)  # size (num_bins, 1)
    left_mel = mel_low_freq + bin_index * mel_freq_delta
    center_mel = mel_low_freq + (bin_index + 1.0) * mel_freq_delta
    right_mel = mel_low_freq + (bin_index + 2.0) * mel_freq_delta

    center_freqs = inverse_mel_scale(center_mel)  # size (num_bins, 1)
    # mel value of every FFT bin, size (1, num_fft_bins)
    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)

    # size (num_bins, num_fft_bins)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    # left_mel < center_mel < right_mel, so the smaller slope traces the triangle;
    # negative values (outside the triangle) are clamped to zero.
    bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))

    return bins, center_freqs


def fbank_own(
    waveform: Tensor,
    channel: int = -1,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
) -> Tensor:
    r"""Create mel filterbank features from a raw audio signal, following the pipeline of
    Kaldi's compute-fbank-feats (no dithering, no energy column, no VTLN warping).

    Args:
        waveform (Tensor): Tensor of audio of size (c, n).
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        energy_floor (float, optional): Floor on the per-frame log energy computed while framing.
            That energy is not part of the returned features. (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length.  If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract the mean of each feature column over all frames
            (Default: ``False``)
        use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear. (Default: ``True``)
        use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
        vtln_high (float, optional): Forwarded to ``get_mel_banks``; no VTLN warping is applied here.
            (Default: ``-500.0``)
        vtln_low (float, optional): Forwarded to ``get_mel_banks``; no VTLN warping is applied here.
            (Default: ``100.0``)

    Returns:
        Tensor: Features of size (m, ``num_mel_bins``) where m is calculated in ``_get_strided``,
        or an empty 1-D tensor when the signal is shorter than ``min_duration`` seconds.
    """
    device, dtype = waveform.device, waveform.dtype

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0, device=device, dtype=dtype)

    # strided_input, size (m, padded_window_size); the per-frame log energy is not used
    strided_input, _ = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        snip_edges,
        raw_energy,
        energy_floor,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1)
    spectrum = torch.fft.rfft(strided_input).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)

    # filter weights, size (num_mel_bins, padded_window_size // 2)
    mel_banks, _ = get_mel_banks(
        num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high
    )
    mel_banks = mel_banks.to(device=device, dtype=dtype)

    # zero weight for the Nyquist bin, size (num_mel_bins, padded_window_size // 2 + 1)
    mel_banks = torch.nn.functional.pad(mel_banks, (0, 1), mode="constant", value=0)

    # weighted sum of the spectrum under each filter, size (m, num_mel_bins)
    mel_energies = torch.mm(spectrum, mel_banks.T)
    if use_log_fbank:
        # floor at float epsilon to avoid log(0)
        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
    return mel_energies

In [None]:
# call fbank_own

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
import soundfile as sf


# Load one preprocessed test segment; report and re-raise any read failure
file_path = '/home/bosfab01/SpeakerVerificationBA/data/preprocessed/0a4b5c0f-facc-4d3b-8a41-bc9148d62d95/0_segment_0.flac'
try:
    audio_signal, sample_rate = sf.read(file_path)
except Exception as e:
    print(f"An error occurred while reading the file: {e}")
    raise

# Time stamp (seconds) of every sample, for plotting
time = np.arange(len(audio_signal)) / sample_rate

# NumPy samples -> float32 tensor
audio_tensor = torch.from_numpy(audio_signal).float()
print("Data type of audio tensor:", audio_tensor.dtype)

# Add a leading channel axis: (n) -> (1, n), the (c, n) layout fbank_own expects
audio_tensor = audio_tensor.unsqueeze(0)
print("Shape of audio tensor:", audio_tensor.shape)

fbank_features_own = fbank_own(
    audio_tensor,
    sample_frequency=sample_rate,
    num_mel_bins=128,
    frame_shift=10,
)

In [None]:
def _plot_fbank(features, title):
    """Draw a (frames x mel bins) feature matrix as a heat map, mel bins on the y axis."""
    plt.figure(figsize=(10, 1.5))
    plt.imshow(features.T, aspect='auto', origin='lower', cmap='viridis')
    return plt.title(title)


# Own implementation first, then the Kaldi reference
_plot_fbank(fbank_features_own, 'Own fbank Features')
_plot_fbank(fbank_features, 'Kaldi fbank Features')



In [None]:
# Overlay one frame (index 90) of both feature matrices, one value per mel bin
frame_index = 90
kaldi_frame = fbank_features[frame_index, :]
plt.plot(kaldi_frame, label='Kaldi')
own_frame = fbank_features_own[frame_index, :]
plt.plot(own_frame, label='Own')
plt.legend()


### Remove more options (no VTLN, DC removal always on)

In [None]:
import math
from typing import Tuple
import torch
from torch import Tensor

# Machine epsilon of float32 (Kaldi's std::numeric_limits<float>::epsilon()),
# 2**-23 == 1.1920928955078125e-07; used as a floor before taking logarithms.
EPSILON = torch.tensor(torch.finfo(torch.float32).eps)
# Multiply milliseconds by this to get seconds.
MILLISECONDS_TO_SECONDS = 0.001


def _get_epsilon(device, dtype):
    """Return ``EPSILON`` as a 0-d tensor moved to ``device`` and cast to ``dtype``."""
    return EPSILON.to(dtype=dtype, device=device)


def _next_power_of_2(x: int) -> int:
    """Return the smallest power of two that is >= ``x`` for positive ``x`` (1 for ``x == 0``)."""
    if x == 0:
        return 1
    return 1 << (x - 1).bit_length()


def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
    r"""Split a 1-D waveform into overlapping frames.

    Args:
        waveform (Tensor): 1-D signal of size (n).
        window_size (int): Frame length in samples.
        window_shift (int): Hop between consecutive frame starts, in samples.
        snip_edges (bool): If True, only frames that fit completely inside the signal are
            produced. If False, the signal is reflected at both ends and the number of frames
            depends only on ``window_shift``.

    Returns:
        Tensor: Size (m, ``window_size``). A strided view of ``waveform`` when ``snip_edges`` is
        True (size (0, 0) if the signal is shorter than one frame), otherwise a view of a
        padded copy.
    """
    assert waveform.dim() == 1
    num_samples = waveform.size(0)

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
        m = 1 + (num_samples - window_size) // window_shift
    else:
        reversed_waveform = torch.flip(waveform, [0])
        m = (num_samples + (window_shift // 2)) // window_shift
        pad = window_size // 2 - window_shift // 2
        pad_right = reversed_waveform
        if pad > 0:
            # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
            # but we want [2, 1, 0, 0, 1, 2]
            pad_left = reversed_waveform[-pad:]
            waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
        else:
            # pad is negative so we want to trim the waveform at the front
            waveform = torch.cat((waveform[-pad:], pad_right), dim=0)

    # Take the strides from the tensor actually being viewed: torch.cat above returns a
    # new contiguous tensor whose element stride can differ from the input's (e.g. when
    # the input is one column of an (n, c) tensor). Computing them before the padding
    # framed such inputs with the wrong stride.
    element_stride = waveform.stride(0)
    strides = (window_shift * element_stride, element_stride)
    return waveform.as_strided((m, window_size), strides)


def _feature_window_function(
    window_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> Tensor:
    """Return the analysis window applied to every frame.

    This reduced version always uses a symmetric Hann window (``periodic=False``).
    The ``dtype`` parameter was previously annotated ``int``, but it is a ``torch.dtype``.

    Args:
        window_size (int): Window length in samples.
        device (torch.device): Device of the returned tensor.
        dtype (torch.dtype): Floating dtype of the returned tensor.

    Returns:
        Tensor: Size (``window_size``).
    """
    return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)

def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
    """Per-frame log energy, floored at ``epsilon`` and, unless disabled, at ``log(energy_floor)``.

    Args:
        strided_input (Tensor): Frames of size (m, window).
        epsilon (Tensor): 0-d lower bound on the energy before taking the log.
        energy_floor (float): Absolute energy floor; 0.0 disables it.

    Returns:
        Tensor: Size (m).
    """
    energy = strided_input.pow(2).sum(1)
    log_energy = torch.max(energy, epsilon).log()
    if energy_floor != 0.0:
        floor = torch.tensor(
            math.log(energy_floor), device=strided_input.device, dtype=strided_input.dtype
        )
        log_energy = torch.max(log_energy, floor)
    return log_energy


def _get_waveform_and_window_properties(
    waveform: Tensor,
    channel: int,
    sample_frequency: float,
    frame_shift: float,
    frame_length: float,
    round_to_power_of_two: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
    r"""Select one channel and convert the framing options from milliseconds to samples.

    Args:
        waveform (Tensor): Audio of size (c, n).
        channel (int): Channel index; negative values select channel 0.
        sample_frequency (float): Sampling rate in Hz.
        frame_shift (float): Hop between frames in milliseconds.
        frame_length (float): Frame length in milliseconds.
        round_to_power_of_two (bool): Round the padded frame length up to a power of two.
        preemphasis_coefficient (float): Only validated here; must lie in [0, 1].

    Returns:
        (Tensor, int, int, int): the selected channel of size (n), the frame shift, the
        frame length and the padded frame length, all in samples.
    """
    channel_index = channel if channel > 0 else 0
    num_channels = waveform.size(0)
    assert channel_index < num_channels, "Invalid channel {} for size {}".format(channel_index, num_channels)
    mono = waveform[channel_index, :]  # size (n)

    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    if round_to_power_of_two:
        padded_window_size = _next_power_of_2(window_size)
    else:
        padded_window_size = window_size

    num_samples = len(mono)
    assert 2 <= window_size <= num_samples, "choose a window size {} that is [2, {}]".format(
        window_size, num_samples
    )
    assert 0 < window_shift, "`window_shift` must be greater than 0"
    assert padded_window_size % 2 == 0, (
        "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
    )
    assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
    return mono, window_shift, window_size, padded_window_size


def _get_window(
    waveform: Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
    r"""Frame the signal, then remove DC, pre-emphasise, window and zero-pad every frame.

    Returns:
        (Tensor, Tensor): frames of size (m, ``padded_window_size``) and the per-frame log
        energy of size (m). With ``raw_energy`` the energy is measured right after DC
        removal, otherwise on the final padded frames.
    """
    eps = _get_epsilon(waveform.device, waveform.dtype)

    frames = _get_strided(waveform, window_size, window_shift, snip_edges)  # (m, window_size)
    # DC removal: each frame minus its own mean
    frames = frames - frames.mean(dim=1, keepdim=True)

    log_energy = _get_log_energy(frames, eps, energy_floor) if raw_energy else None

    if preemphasis_coefficient != 0.0:
        # x[i, j] -= coeff * x[i, max(0, j - 1)]: replicate padding makes the first
        # sample its own predecessor
        previous = torch.nn.functional.pad(frames.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
        frames = frames - preemphasis_coefficient * previous[:, :-1]

    window = _feature_window_function(window_size, waveform.device, waveform.dtype)
    frames = frames * window.unsqueeze(0)  # broadcast (1, window_size) over all frames

    extra = padded_window_size - window_size
    if extra != 0:
        # zero-pad on the right up to (m, padded_window_size)
        frames = torch.nn.functional.pad(frames.unsqueeze(0), (0, extra), mode="constant", value=0).squeeze(0)

    if log_energy is None:
        log_energy = _get_log_energy(frames, eps, energy_floor)

    return frames, log_energy


def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
    """Return ``tensor`` (size (m, n)) with each column's mean removed, or unchanged if disabled."""
    return (tensor - tensor.mean(dim=0, keepdim=True)) if subtract_mean else tensor



def inverse_mel_scale_scalar(mel_freq: float) -> float:
    """Mel -> Hz for a single value (inverse of 1127 * ln(1 + f / 700))."""
    return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)


def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
    """Mel -> Hz, element-wise over a tensor."""
    hz_ratio = (mel_freq / 1127.0).exp() - 1.0
    return 700.0 * hz_ratio


def mel_scale_scalar(freq: float) -> float:
    """Hz -> mel for a single value: 1127 * ln(1 + f / 700)."""
    log_term = math.log(1.0 + freq / 700.0)
    return 1127.0 * log_term


def mel_scale(freq: Tensor) -> Tensor:
    """Hz -> mel, element-wise over a tensor."""
    log_term = (1.0 + freq / 700.0).log()
    return 1127.0 * log_term


def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
) -> Tuple[Tensor, Tensor]:
    """Triangular mel filter bank on the FFT grid (Kaldi layout, no VTLN, no input checks).

    Args:
        num_bins (int): Number of triangular filters.
        window_length_padded (int): Padded FFT length in samples.
        sample_freq (float): Sampling frequency in Hz.
        low_freq (float): Lower edge of the first filter in Hz.
        high_freq (float): Upper edge of the last filter in Hz; values <= 0 are an offset
            from the Nyquist frequency.

    Returns:
        (Tensor, Tensor): filter weights of size (num_bins, window_length_padded / 2) and
        the filter centre frequencies in Hz of size (num_bins, 1).
    """
    nyquist = 0.5 * sample_freq
    top_freq = high_freq + nyquist if high_freq <= 0.0 else high_freq

    mel_low = mel_scale_scalar(low_freq)
    mel_high = mel_scale_scalar(top_freq)
    # num_bins + 1 equal steps: each triangle spans two steps and starts at the
    # centre of its predecessor
    step = (mel_high - mel_low) / (num_bins + 1)

    index = torch.arange(num_bins).unsqueeze(1)  # (num_bins, 1)
    left_mel = mel_low + index * step
    center_mel = mel_low + (index + 1.0) * step
    right_mel = mel_low + (index + 2.0) * step

    center_freqs = inverse_mel_scale(center_mel)

    # mel value of each FFT bin below Nyquist, (1, window_length_padded / 2)
    bin_width_hz = sample_freq / window_length_padded
    fft_mel = mel_scale(bin_width_hz * torch.arange(window_length_padded / 2)).unsqueeze(0)

    rising = (fft_mel - left_mel) / (center_mel - left_mel)
    falling = (right_mel - fft_mel) / (right_mel - center_mel)
    # the smaller edge is the triangle; clamp the negative outside to zero
    weights = torch.max(torch.zeros(1), torch.min(rising, falling))

    return weights, center_freqs


def fbank_own(
    waveform: Tensor,
    channel: int = -1,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 128,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
) -> Tensor:
    r"""Mel filterbank features with a reduced, Kaldi-style option set.

    DC removal is always applied and VTLN is not supported.

    Args:
        waveform (Tensor): Audio of size (c, n).
        channel (int): Channel to use; negative values select channel 0.
        energy_floor (float): Floor on the per-frame log energy computed while framing
            (that energy is not returned).
        frame_length (float): Frame length in milliseconds.
        frame_shift (float): Frame shift in milliseconds.
        high_freq (float): Upper mel cutoff in Hz (<= 0: offset from Nyquist).
        low_freq (float): Lower mel cutoff in Hz.
        min_duration (float): Signals shorter than this many seconds yield an empty tensor.
        num_mel_bins (int): Number of triangular mel filters.
        preemphasis_coefficient (float): Pre-emphasis coefficient in [0, 1].
        raw_energy (bool): Measure the frame energy before pre-emphasis and windowing.
        round_to_power_of_two (bool): Zero-pad frames to a power-of-two FFT length.
        sample_frequency (float): Sampling rate in Hz.
        snip_edges (bool): Keep only frames that fit completely inside the signal.
        subtract_mean (bool): Remove the per-column mean of the output.
        use_log_fbank (bool): Return log filterbank energies instead of linear ones.
        use_power (bool): Use the power spectrum instead of the magnitude.

    Returns:
        Tensor: Size (m, ``num_mel_bins``), or an empty 1-D tensor for too-short input.
    """
    device, dtype = waveform.device, waveform.dtype

    signal, shift, size, padded_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )
    if len(signal) < min_duration * sample_frequency:
        return torch.empty(0, device=device, dtype=dtype)

    frames, _ = _get_window(
        signal, padded_size, size, shift, snip_edges, raw_energy, energy_floor, preemphasis_coefficient
    )  # (m, padded_size)

    magnitude = torch.fft.rfft(frames).abs()  # (m, padded_size // 2 + 1)
    spectrum = magnitude.pow(2.0) if use_power else magnitude

    filters, _ = get_mel_banks(num_mel_bins, padded_size, sample_frequency, low_freq, high_freq)
    filters = filters.to(device=device, dtype=dtype)
    # zero weight for the Nyquist bin -> (num_mel_bins, padded_size // 2 + 1)
    filters = torch.nn.functional.pad(filters, (0, 1), mode="constant", value=0)

    features = torch.mm(spectrum, filters.T)  # (m, num_mel_bins)
    if use_log_fbank:
        features = torch.max(features, _get_epsilon(device, dtype)).log()
    return _subtract_column_mean(features, subtract_mean)

### Remove as much as possible (settings hard-coded for 16 kHz audio and 128 mel bins)

In [None]:
import math
from typing import Tuple
import torch
from torch import Tensor


def fbank_own(
    waveform: Tensor,
) -> Tensor:
    """Log mel filterbank of a 16 kHz mono signal with all settings hard-coded.

    Settings:
        - 400-sample frames (25 ms) every 160 samples (10 ms), per-frame DC removal.
        - Pre-emphasis 0.97, symmetric Hann window, zero padding to 512 samples.
        - Power spectrum, 128 triangular mel filters from 20 Hz to 8 kHz.
        - Natural log floored at float32 epsilon.

    Previously the frame count (998) and the strides (160, 1) were hard-coded. Any input
    other than exactly 160000 contiguous samples therefore failed in ``as_strided`` or was
    framed incorrectly.

    Args:
        waveform (Tensor): Mono audio of size (n) or (1, n), sampled at 16 kHz.

    Returns:
        Tensor: Size (m, 128) with m = 1 + (n - 400) // 160 (998 for 160000 samples);
        size (0, 128) when the signal is shorter than one frame.

    Raises:
        ValueError: If the squeezed waveform is not one-dimensional.
    """
    device, dtype = waveform.device, waveform.dtype

    frame_length = 400            # 25 ms at 16 kHz
    frame_shift = 160             # 10 ms at 16 kHz
    padded_length = 512           # next power of two >= frame_length
    num_mel_bins = 128
    sample_frequency = 16000.0

    # (1, n) -> (n); a 1-D input is left unchanged
    waveform = torch.squeeze(waveform)
    if waveform.dim() != 1:
        raise ValueError(f"expected mono audio of size (n) or (1, n), got {tuple(waveform.shape)}")
    num_samples = waveform.size(0)
    if num_samples < frame_length:
        return torch.empty((0, num_mel_bins), device=device, dtype=dtype)
    num_frames = 1 + (num_samples - frame_length) // frame_shift  # 998 for 160000 samples

    def get_window(waveform: Tensor) -> Tensor:
        """Frame, DC-remove, pre-emphasise, window and zero-pad: (n) -> (m, 512)."""
        # Strides in elements of the actual tensor, so non-contiguous views frame correctly
        step = waveform.stride(0)
        strided_input = waveform.as_strided((num_frames, frame_length), (frame_shift * step, step))  # (m, 400)

        # Subtract each row/frame by its mean
        strided_input = strided_input - torch.mean(strided_input, dim=1).unsqueeze(1)

        # strided_input[i,j] -= 0.97 * strided_input[i, max(0, j-1)] for all i,j
        offset_strided_input = torch.nn.functional.pad(
            strided_input.unsqueeze(0), (1, 0), mode="replicate"
        ).squeeze(0)  # (m, 401)
        strided_input = strided_input - 0.97 * offset_strided_input[:, :-1]

        # Symmetric Hann window on every frame
        window_function = torch.hann_window(
            frame_length, periodic=False, device=waveform.device, dtype=waveform.dtype
        ).unsqueeze(0)  # (1, 400)
        strided_input = strided_input * window_function

        # Zero-pad every frame from 400 to 512 samples
        return torch.nn.functional.pad(
            strided_input.unsqueeze(0), (0, padded_length - frame_length), mode="constant", value=0
        ).squeeze(0)

    def get_mel_banks(num_bins: int) -> Tensor:
        """Triangular mel filters over the 256 FFT bins below Nyquist, size (num_bins, 256)."""

        def mel_scale_scalar(freq: float) -> float:
            return 1127.0 * math.log(1.0 + freq / 700.0)

        def mel_scale(freq: Tensor) -> Tensor:
            return 1127.0 * (1.0 + freq / 700.0).log()

        num_fft_bins = padded_length // 2                         # 256
        fft_bin_width = sample_frequency / padded_length          # 31.25 Hz
        mel_low_freq = mel_scale_scalar(20.0)                     # 31.748578341466644
        mel_high_freq = mel_scale_scalar(0.5 * sample_frequency)  # 2840.0377117383778 (8 kHz)

        # num_bins + 1 equal steps: each triangle spans two steps
        mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

        bin_index = torch.arange(num_bins).unsqueeze(1)  # (num_bins, 1)
        left_mel = mel_low_freq + bin_index * mel_freq_delta
        center_mel = mel_low_freq + (bin_index + 1.0) * mel_freq_delta
        right_mel = mel_low_freq + (bin_index + 2.0) * mel_freq_delta

        mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)  # (1, 256)

        up_slope = (mel - left_mel) / (center_mel - left_mel)
        down_slope = (right_mel - mel) / (right_mel - center_mel)
        # the smaller slope traces the triangle; clamp negatives to zero
        return torch.max(torch.zeros(1), torch.min(up_slope, down_slope))  # (num_bins, 256)

    strided_input = get_window(waveform)  # (m, 512)
    spectrum = torch.fft.rfft(strided_input).abs().pow(2.0)  # power spectrum, (m, 257)

    mel_banks = get_mel_banks(num_mel_bins).to(device=device, dtype=dtype)  # (128, 256)
    # zero weight for the Nyquist bin -> (128, 257)
    mel_banks = torch.nn.functional.pad(mel_banks, (0, 1), mode="constant", value=0)

    mel_energies = torch.mm(spectrum, mel_banks.T)  # (m, 128)

    # floor at float32 epsilon to avoid log(0)
    epsilon = torch.tensor(torch.finfo(torch.float).eps).to(device=device, dtype=dtype)
    return torch.max(mel_energies, epsilon).log()

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
import soundfile as sf


# Load one preprocessed test segment; report and re-raise any read failure
file_path = '/home/bosfab01/SpeakerVerificationBA/data/preprocessed/0a4b5c0f-facc-4d3b-8a41-bc9148d62d95/0_segment_0.flac'
try:
    audio_signal, sample_rate = sf.read(file_path)
except Exception as e:
    print(f"An error occurred while reading the file: {e}")
    raise

# Time stamp (seconds) of every sample
time = np.arange(len(audio_signal)) / sample_rate

# float32 tensor of size (n), then a leading channel axis -> (1, n)
audio_tensor = torch.from_numpy(audio_signal).float()
audio_tensor_batch = audio_tensor.unsqueeze(0)

# Hard-coded own implementation vs. torchaudio's Kaldi-compatible reference
fbank_features_own = fbank_own(waveform=audio_tensor_batch)
fbank_features_torch = torchaudio.compliance.kaldi.fbank(
    audio_tensor_batch,
    sample_frequency=sample_rate,
    htk_compat=True,
    use_energy=False,
    window_type='hanning',
    num_mel_bins=128,
    dither=0.0,
    frame_shift=10,
)

# Heat maps (mel bins on the y axis): own implementation first, then the reference
for features, heading in ((fbank_features_own, 'Own fbank Features'),
                          (fbank_features_torch, 'Kaldi fbank Features')):
    plt.figure(figsize=(10, 1.5))
    plt.imshow(features.T, aspect='auto', origin='lower', cmap='viridis')
    plt.title(heading)
plt.show()

# Overlay frame 90 of both feature matrices
plt.plot(fbank_features_torch[90, :], label='Kaldi')
plt.plot(fbank_features_own[90, :], label='Own')
plt.legend()

### Convert to plain NumPy (no PyTorch)

In [None]:
import numpy as np

def fbank_own(waveform):
    """Log mel filterbank features in plain NumPy, hard-coded for 16 kHz mono audio.

    Mirrors the hard-coded PyTorch version:
        - 400-sample frames (25 ms) every 160 samples (10 ms), per-frame DC removal.
        - Pre-emphasis 0.97, where the first sample acts as its own predecessor.
        - Symmetric Hann window, zero padding to 512 samples, power spectrum.
        - 128 triangular mel filters from 20 Hz to 8 kHz.
        - Natural log floored at float32 epsilon.

    The previous version built frame indices with ``as_strided`` over an int64 index
    array but took the byte strides from the waveform. That only worked for 8-byte
    input and silently mis-framed float32 audio.

    Args:
        waveform: Mono samples, array-like of size (n) or (1, n); any real dtype.

    Returns:
        np.ndarray: float64 array of size (m, 128) with m = 1 + (n - 400) // 160
        (998 for 160000 samples); size (0, 128) when n < 400.

    Raises:
        ValueError: If the input is not mono.
    """
    frame_length = 400        # 25 ms at 16 kHz
    frame_shift = 160         # 10 ms at 16 kHz
    padded_length = 512       # next power of two >= frame_length
    num_mel_bins = 128
    sample_frequency = 16000.0
    low_freq = 20.0
    preemphasis_coefficient = 0.97

    signal = np.asarray(waveform, dtype=np.float64)
    if signal.ndim == 2 and signal.shape[0] == 1:
        signal = signal[0]  # (1, n) -> (n)
    if signal.ndim != 1:
        raise ValueError(f"expected mono audio of shape (n,) or (1, n), got {signal.shape}")
    num_samples = signal.shape[0]
    if num_samples < frame_length:
        return np.empty((0, num_mel_bins))

    def get_window(signal):
        """Frame, DC-remove, pre-emphasise, window and zero-pad: (n,) -> (m, 512)."""
        num_frames = 1 + (num_samples - frame_length) // frame_shift
        # indices[i, j] = i * frame_shift + j, shape (m, 400)
        indices = frame_shift * np.arange(num_frames)[:, None] + np.arange(frame_length)[None, :]
        frames = signal[indices]
        frames = frames - frames.mean(axis=1, keepdims=True)
        # x[i, j] -= 0.97 * x[i, max(0, j - 1)], like the replicate padding in the torch version
        previous = np.concatenate((frames[:, :1], frames[:, :-1]), axis=1)
        frames = frames - preemphasis_coefficient * previous
        frames = frames * np.hanning(frame_length)
        return np.pad(frames, ((0, 0), (0, padded_length - frame_length)))

    def get_mel_banks(num_bins):
        """Triangular mel filters, shape (num_bins, 257); the Nyquist column is zero."""

        def mel_scale(freq):
            return 1127.0 * np.log(1.0 + freq / 700.0)

        num_fft_bins = padded_length // 2                  # 256
        fft_bin_width = sample_frequency / padded_length   # 31.25 Hz
        mel_low_freq = mel_scale(low_freq)
        mel_high_freq = mel_scale(0.5 * sample_frequency)
        mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

        left_mel = mel_low_freq + np.arange(num_bins)[:, None] * mel_freq_delta  # (num_bins, 1)
        center_mel = left_mel + mel_freq_delta
        right_mel = center_mel + mel_freq_delta
        mel = mel_scale(fft_bin_width * np.arange(num_fft_bins))[None, :]        # (1, 256)

        up_slope = (mel - left_mel) / (center_mel - left_mel)
        down_slope = (right_mel - mel) / (right_mel - center_mel)
        banks = np.maximum(0.0, np.minimum(up_slope, down_slope))                # (num_bins, 256)
        # zero weight for the Nyquist bin, as in the torch version
        return np.pad(banks, ((0, 0), (0, 1)))

    frames = get_window(signal)                                        # (m, 512)
    spectrum = np.abs(np.fft.rfft(frames, n=padded_length)) ** 2       # (m, 257)
    energies = spectrum @ get_mel_banks(num_mel_bins).T                # (m, 128)
    return np.log(np.maximum(energies, np.finfo(np.float32).eps))

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchaudio
import soundfile as sf


# Compare the NumPy fbank implementation (fbank_own) with torchaudio's
# Kaldi-compatible fbank on one preprocessed utterance.
file_path = '/home/bosfab01/SpeakerVerificationBA/data/preprocessed/0a4b5c0f-facc-4d3b-8a41-bc9148d62d95/0_segment_0.flac'
try:
    audio_signal, sample_rate = sf.read(file_path)
except Exception as e:
    print(f"An error occurred while reading the file: {e}")
    raise

# sf.read returns (frames,) for mono files and (frames, channels) otherwise.
# Both feature extractors below are fed a single channel.
if audio_signal.ndim != 1:
    raise ValueError(f"Expected mono audio, got array of shape {audio_signal.shape}")

# fbank_own takes no sample rate: its mel filterbank hard-codes an 8 kHz
# Nyquist frequency, i.e. 16 kHz audio. Any other rate would make the
# comparison below silently meaningless.
if sample_rate != 16000:
    raise ValueError(f"fbank_own assumes 16 kHz audio, but the file is {sample_rate} Hz")

# Time axis in seconds, for plotting the waveform.
time = np.arange(len(audio_signal)) / sample_rate

# torchaudio's Kaldi fbank expects a float32 tensor shaped (channels, time),
# e.g. (160000,) -> (1, 160000) for a 10 s clip.
audio_tensor = torch.from_numpy(audio_signal).float()
audio_tensor_batch = audio_tensor.unsqueeze(0)

# Own implementation works directly on the NumPy waveform.
fbank_features_own = fbank_own(
    waveform=audio_signal,
)

# Reference implementation; expected shape (num_frames, 128), e.g. (998, 128).
fbank_features_torch = torchaudio.compliance.kaldi.fbank(
    audio_tensor_batch,
    sample_frequency=sample_rate,
    htk_compat=True,
    use_energy=False,
    window_type='hanning',
    num_mel_bins=128,
    dither=0.0,
    frame_shift=10,
)

# Quantify the agreement instead of relying only on visual inspection.
fbank_own_np = np.asarray(fbank_features_own)
fbank_torch_np = fbank_features_torch.numpy()
if fbank_own_np.shape == fbank_torch_np.shape:
    print(f"Max abs difference (own vs Kaldi): {np.abs(fbank_own_np - fbank_torch_np).max():.3e}")
else:
    print(f"Shape mismatch: own {fbank_own_np.shape} vs Kaldi {fbank_torch_np.shape}")

# Spectrogram-style views: x axis = frame index, y axis = mel bin.
plt.figure(figsize=(10, 1.5))
plt.imshow(fbank_features_own.T, aspect='auto', origin='lower', cmap='viridis')
plt.title('Own fbank Features')

plt.figure(figsize=(10, 1.5))
plt.imshow(fbank_features_torch.T, aspect='auto', origin='lower', cmap='viridis')
plt.title('Kaldi fbank Features')

plt.show()

# Overlay one frame from both implementations on its own figure.
frame_idx = 90
plt.figure(figsize=(10, 3))
plt.plot(fbank_features_torch[frame_idx, :], label='Kaldi')
plt.plot(fbank_features_own[frame_idx, :], label='Own')
plt.xlabel('Mel bin')
plt.ylabel('Log energy')
plt.title(f'Frame {frame_idx}: Kaldi vs own fbank')
plt.legend()
plt.show()

In [None]:
# Display float32 machine limits. Its eps (~1.19209e-07) is the floor fbank_own
# clamps filterbank energies to before taking the log.
torch.finfo(torch.float32)