In [None]:
import numpy as np
import pandas as pd
import sklearn
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns 
import librosa 
import os
import glob
from scipy.stats import wasserstein_distance

# Importing examples of raw data from the GTZAN dataset: blues and country music

In [None]:
data_dir = '../datasets/raw_data/GTZAN_Dataset/genres_original'
# One sorted list of .wav paths per genre, so index i always refers to the
# same track regardless of the order glob returns files in.
blues_audio_files, country_audio_files = (
    sorted(glob.glob(f"{data_dir}/{genre}/*.wav")) for genre in ("blues", "country")
)

# Exploring Data Features

In [None]:
# Peek at the (audio, sampling rate) tuple librosa returns for one blues clip.
librosa.load(path=blues_audio_files[1])

In [None]:
# Load the first three blues clips as (audio samples, sampling rate) pairs.
(b1_audio, b1_freq), (b2_audio, b2_freq), (b3_audio, b3_freq) = (
    librosa.load(blues_audio_files[k]) for k in range(3)
)

In [None]:
# Load the first three country clips as (audio samples, sampling rate) pairs.
(c1_audio, c1_freq), (c2_audio, c2_freq), (c3_audio, c3_freq) = (
    librosa.load(country_audio_files[k]) for k in range(3)
)

# Array shapes, then sampling rates, of the three country clips.
print(*(clip.shape for clip in (c1_audio, c2_audio, c3_audio)))
print(*(c1_freq, c2_freq, c3_freq))

In [None]:
# Number of leading samples kept from every clip so all clips stack into
# equal-length columns (about 30 s, assuming librosa's default 22050 Hz rate).
n_pt = 661_000

In [None]:
# Stack the first 100 blues clips, each truncated to n_pt samples, as the
# columns of X, giving shape (n_pt, 100).
# The clips are collected in a list and stacked once: calling np.column_stack
# inside the loop re-copied the entire growing matrix on every iteration.
blues_clips = []
for i in range(100):
    clip, _ = librosa.load(blues_audio_files[i])
    if clip.shape[0] < n_pt:
        # Columns must have equal length; fail with a message naming the file.
        raise ValueError('{} has only {} samples (< n_pt={})'.format(
            blues_audio_files[i], clip.shape[0], n_pt))
    blues_clips.append(clip[:n_pt])
X = np.column_stack(blues_clips)
del blues_clips  # X holds its own copy; release the per-clip arrays
X.shape

In [None]:
# Stack the first 100 country clips, each truncated to n_pt samples, as the
# columns of Y, giving shape (n_pt, 100).
# The clips are collected in a list and stacked once: calling np.column_stack
# inside the loop re-copied the entire growing matrix on every iteration.
country_clips = []
for i in range(100):
    clip, _ = librosa.load(country_audio_files[i])
    if clip.shape[0] < n_pt:
        # Columns must have equal length; fail with a message naming the file.
        raise ValueError('{} has only {} samples (< n_pt={})'.format(
            country_audio_files[i], clip.shape[0], n_pt))
    country_clips.append(clip[:n_pt])
Y = np.column_stack(country_clips)
del country_clips  # Y holds its own copy; release the per-clip arrays
Y.shape

In [None]:
# Run time grows with the number of samples per series, so only the first
# n_pt2 samples of each clip go into the GW computation.
# NOTE(review): assuming librosa's default 22050 Hz rate, 400 samples is only
# ~18 ms from the start of each clip -- confirm that is the intended window.
n_pt2 = 400

In [None]:
# Keep the first n_pt2 samples of every clip, then place the blues clips
# (columns 0-99) and the country clips (columns 100-199) side by side.
X1, Y1 = X[:n_pt2], Y[:n_pt2]
XN = np.hstack([X1, Y1])

# GW-Distance
Implementation of the scalable Gromov-Wasserstein (GW) distance introduced and described by [Natalia Kravtsova, Reginald L. McGee II & Adriana T. Dawes](https://link.springer.com/article/10.1007/s11538-023-01175-y).

In [None]:
def gw_distance_matrix(XN, n_pt=None):
    '''
    Pairwise GWtau distance matrix between the columns of XN.

    Each column XN[:, k] is treated as a trajectory of points (t, XN[t, k])
    for t = 0, ..., n_pt - 1. For every trajectory the Euclidean length of
    each step between consecutive points is computed; the distance between
    columns i and j is the 1-D Wasserstein distance between their samples of
    step lengths.

    Requires numpy and scipy.stats.wasserstein_distance.

    Parameters
    ----------
    XN : ndarray, shape (n_rows, n)
        One series per column, one time point per row.
    n_pt : int, optional
        Number of leading time points (rows) of XN to use. Defaults to all
        rows. (Previously this had to equal XN.shape[0] exactly, otherwise
        the function crashed.)

    Returns
    -------
    GW : ndarray, shape (n, n)
        Symmetric distance matrix with zeros on the diagonal.

    Raises
    ------
    ValueError
        If n_pt exceeds XN.shape[0], or is smaller than 2 (at least two time
        points are needed to form one step).
    '''
    XN = np.asarray(XN)
    n = XN.shape[1]
    n_rows = XN.shape[0]
    if n_pt is None:
        n_pt = n_rows
    if n_pt > n_rows:
        raise ValueError(
            'n_pt={} exceeds the number of rows in XN ({})'.format(n_pt, n_rows))
    if n_pt < 2:
        raise ValueError('n_pt must be at least 2, got {}'.format(n_pt))

    # Time advances by exactly 1 per row, so the step from (t, x_t) to
    # (t + 1, x_{t+1}) has length sqrt(1 + (x_{t+1} - x_t)^2). All step
    # lengths are computed once per column here rather than once per pair
    # inside the double loop.
    steps = np.hypot(1.0, np.diff(XN[:n_pt].astype(float), axis=0))

    GW = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            GW[i, j] = wasserstein_distance(steps[:, i], steps[:, j])

    # Only the upper triangle was filled; mirror it to get the full matrix.
    return GW + GW.T

# GW distances between all 200 truncated clips (blues first, then country).
GW = gw_distance_matrix(XN, n_pt=n_pt2)

# GW distance matrix between the blues and country music genres

In [None]:
# Raw matrix view: rows/columns 0-99 are blues clips, 100-199 country clips.
fig, ax = plt.subplots()
image = ax.imshow(GW, cmap='cool', interpolation='none')
ax.set_title('GWtau')
fig.colorbar(image, ax=ax)

In [None]:
# Same matrix drawn with seaborn, which adds its own colorbar.
sns.heatmap(data=GW)

# Different data features for the 3-second time windows



In [None]:
# Tabular features for the 3-second windows.
# NOTE(review): this path is relative to a different location than data_dir
# ('../datasets/raw_data/GTZAN_Dataset/...') defined earlier -- confirm both
# resolve from the notebook's working directory.
F3 = pd.read_csv(filepath_or_buffer='datasets/GTZAN_Dataset/features_3_sec.csv')


In [None]:
# Preview only the first rows instead of dumping the whole frame
# (presumably one row per 3-second window, so it is long).
F3.head()

In [None]:
# Regroup mfcc10_mean into one column per track and one row per segment.
# Rows are assumed to be ordered track by track with 10 consecutive segments
# each (TODO confirm against the CSV's filename column). Using -1 lets numpy
# infer the number of tracks instead of hard-coding 999.
F3n = F3['mfcc10_mean'].to_numpy().reshape(-1, 10).T

In [None]:
# n = F3n.shape[1]
# # n_time_pts = X.shape[0]
# time = np.arange(10)

# # GWtau
# GW = np.zeros((n, n))
# for i in range(n):
#     for j in range(i + 1, n):
#         Traji = np.column_stack((time, F3n[:, i]))
#         Trajj = np.column_stack((time, F3n[:, j]))
#         vi = np.linalg.norm(np.diff(Traji, axis=0), axis=1)
#         vj = np.linalg.norm(np.diff(Trajj, axis=0), axis=1)
#         GW[i, j] = wasserstein_distance(vi, vj)

# GW = GW + GW.T

## mfcc10_mean

In [None]:
# GW distances between tracks using their 10 per-segment mfcc10_mean values.
GW = gw_distance_matrix(XN=F3n, n_pt=10)
sns.heatmap(data=GW)

In [None]:
# Column 0 of F3n: the per-segment values of the first track (assuming the
# track-by-track row ordering used in the reshape above).
F3n[:, 0]

## chroma_stft_var

In [None]:
# Same analysis for chroma_stft_var: one column per track, one row per
# segment (rows assumed grouped as 10 consecutive segments per track; -1
# lets numpy infer the track count instead of hard-coding 999).
F3n = F3['chroma_stft_var'].to_numpy().reshape(-1, 10).T

GW = gw_distance_matrix(F3n, n_pt=10)




In [None]:
# Matrix view of the chroma_stft_var GW distances.
fig, ax = plt.subplots()
image = ax.imshow(GW, cmap='gist_stern', interpolation='none')
ax.set_title('GWtau')
fig.colorbar(image, ax=ax)

In [None]:
# Plain Python list of every column name in the features table.
col_list = F3.columns.tolist()
print("Get the list from DataFrame column:\n", col_list)


In [None]:
# Compute one GW distance matrix per feature column (skipping the first two
# columns, as before) and collect [matrix, column name] pairs in my_list.
# Reuses gw_distance_matrix instead of a second inline copy of the same loop.
my_list = list()

for names in col_list[2:]:
    column = F3[names]
    # The distance computation subtracts consecutive values (np.diff), which
    # fails on non-numeric data such as text labels, so skip those columns.
    if not pd.api.types.is_numeric_dtype(column):
        continue
    # One column per track, one row per segment (rows assumed grouped as 10
    # consecutive segments per track -- TODO confirm against the CSV).
    F3n = column.to_numpy().reshape(-1, 10).T
    my_list.append([gw_distance_matrix(F3n, n_pt=10), names])

# Visualizing the GW distance for different audio features

In [None]:
# for DM in my_list:
#     plt.imshow(DM[0], cmap='gist_stern', interpolation='none')
#     plt.title('GWtau {}'.format(DM[1]))
#     plt.colorbar()
#     plt.show()

# Same visualization using the Seaborn module

In [None]:
# for DM in my_list:
#     ax = plt.axes()
#     sns.heatmap(DM[0], ax = ax)
#     ax.set_title('GWtau {}'.format(DM[1]))
#     plt.show()

In [None]:
# Show one feature's GW matrix on its own. It is taken from my_list explicitly
# (the last feature, which is what the commented-out loops above would leave
# in DM) because a fresh kernel never defines DM.
gw_matrix, feature_name = my_list[-1]
plt.imshow(gw_matrix, cmap='gist_stern', interpolation='none')
plt.title('GWtau {}'.format(feature_name))
plt.colorbar()
plt.show()

# All GW distance matrices in one figure

In [None]:
# Draw every feature's GW distance matrix in one grid, 8 panels per row.
# The number of rows follows len(my_list) instead of a fixed 7 x 8 layout,
# which raised IndexError as soon as my_list held more than 56 matrices.
n_cols = 8
n_rows = max(1, -(-len(my_list) // n_cols))  # ceiling division
# Keep the original per-panel size (70 x 80 inches spread over 8 x 7 panels).
# squeeze=False keeps axs 2-D even when there is a single row.
fig, axs = plt.subplots(n_rows, n_cols,
                        figsize=(70 / 8 * n_cols, 80 / 7 * n_rows),
                        squeeze=False)

for i, (array, name) in enumerate(my_list):
    row, col = divmod(i, n_cols)
    ax = axs[row, col]
    ax.imshow(array, cmap='viridis', interpolation='nearest')
    ax.set_title(name, fontsize=30)

# Hide the empty panels left over in the last row.
for ax in axs.flat[len(my_list):]:
    ax.axis('off')

# Lay out after the titles exist so tight_layout can make room for them.
fig.tight_layout(pad=2.0)
plt.show()