In [None]:
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import librosa
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import soundfile as sf
from IPython.display import Audio


from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.cluster.hierarchy import fcluster
from sklearn.decomposition import PCA
import seaborn as sns

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score
import joblib
import random
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture
import os

# Diarization

## K-Means

In [None]:
def kmeans_plot_mfcc(audio, sr):
    """Compute 13 MFCCs for ``audio`` and plot them together with their deltas.

    Three stacked panels are drawn: the MFCCs, their first-order deltas and
    their second-order deltas. Only the MFCCs are returned; the deltas are
    used for plotting alone.

    Parameters
    ----------
    audio : array-like
        Audio time series, passed to ``librosa.feature.mfcc`` as ``y``.
    sr : int
        Sampling rate of ``audio``.

    Returns
    -------
    np.ndarray
        The MFCC matrix exactly as returned by ``librosa.feature.mfcc``.
    """
    # Explicit submodule import: older librosa releases do not expose
    # `librosa.display` after a bare `import librosa`.
    import librosa.display as librosa_display

    mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
    panels = [
        ('MFCCs', mfccs),
        ('Delta MFCCs', librosa.feature.delta(mfccs)),
        ('Delta-Delta MFCCs', librosa.feature.delta(mfccs, order=2)),
    ]

    fig, axes = plt.subplots(len(panels), 1, figsize=(15, 8))
    for ax, (title, values) in zip(axes, panels):
        image = librosa_display.specshow(values, x_axis='time', sr=sr, cmap='coolwarm', ax=ax)
        fig.colorbar(image, ax=ax, label=title)
        ax.set_title(title)

    fig.tight_layout()
    plt.show()
    # This function runs on every diarization request; release the figure so
    # figures do not pile up in memory under non-inline backends.
    plt.close(fig)
    return mfccs

In [None]:
def diarize_kmeans(audio_path, k=2):
    """Split an audio file into per-cluster WAV files using K-Means on MFCC frames.

    The audio is loaded with librosa's default sampling rate, its MFCC frames
    are standardised and clustered into ``k + 1`` groups, and the audio samples
    belonging to each group's frames are concatenated and written to
    ``kmeans/cluster_<n>.wav``.

    Parameters
    ----------
    audio_path : str
        Path of the audio file to diarize.
    k : int, default 2
        Number of speakers. ``k + 1`` clusters are fitted.
        NOTE(review): presumably the extra cluster is meant to absorb
        silence/background -- confirm.

    Returns
    -------
    list of str
        Absolute paths of the ``k + 1`` written files, in cluster order.
    """
    # Must match the hop used to compute the MFCC frames. kmeans_plot_mfcc does
    # not pass hop_length, so librosa's default of 512 samples applies.
    hop_length = 512
    output_dir = 'kmeans'
    n_clusters = k + 1

    audio, sr = librosa.load(audio_path)
    mfccs = kmeans_plot_mfcc(audio, sr)

    # scikit-learn expects one row per observation, i.e. one row per frame.
    mfccs_scaled = StandardScaler().fit_transform(mfccs.T)
    labels = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(mfccs_scaled)

    # sf.write does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    paths = []
    for cluster in range(n_clusters):
        cluster_indices = np.where(labels == cluster)[0]
        segments = [
            audio[idx * hop_length:(idx + 1) * hop_length]
            for idx in cluster_indices
        ]
        # np.concatenate raises on an empty list; emit an empty clip instead so
        # the caller still receives one path per cluster.
        cluster_audio = np.concatenate(segments) if segments else np.zeros(0, dtype=audio.dtype)

        output_file = f'{output_dir}/cluster_{cluster}.wav'
        sf.write(output_file, cluster_audio, sr)
        paths.append(os.path.abspath(output_file))
        print(f'Cluster {cluster} audio saved to {output_file}')

    return paths

## Hierarchical Clustering

In [None]:
def extract_audio_features(audio, window_size=2048, hop_length=1024, sample_rate=16000):
    """Build a per-frame feature matrix of MFCCs, spectral contrast and pitch.

    All three feature groups are computed with the same ``n_fft`` and
    ``hop_length`` so that they share one frame grid and can be stacked
    directly.

    Parameters
    ----------
    audio : np.ndarray
        Audio time series.
    window_size : int, default 2048
        FFT window size (``n_fft``) used for every feature.
    hop_length : int, default 1024
        Hop between successive frames, in samples.
    sample_rate : int, default 16000
        Sampling rate of ``audio``.

    Returns
    -------
    np.ndarray
        Array with one row per frame; columns are the 13 MFCCs, then the
        spectral-contrast bands, then one pitch value.
    """
    mfccs = librosa.feature.mfcc(
        y=audio, sr=sample_rate, n_mfcc=13, n_fft=window_size, hop_length=hop_length
    )
    spectral_contrast = librosa.feature.spectral_contrast(
        y=audio, sr=sample_rate, n_fft=window_size, hop_length=hop_length
    )

    # Track pitch on the same frame grid as the other features; with the
    # library defaults the pitch track has a different number of frames and
    # would have to be interpolated to fit.
    pitches, magnitudes = librosa.core.piptrack(
        y=audio, sr=sample_rate, n_fft=window_size, hop_length=hop_length
    )
    # Per frame, take the pitch candidate with the largest magnitude. Taking
    # the maximum over `pitches` would select the highest-frequency peak, not
    # the strongest one. Frames without any peak are all-zero and yield 0.
    strongest_bin = magnitudes.argmax(axis=0)
    pitch = pitches[strongest_bin, np.arange(pitches.shape[1])]

    features = np.vstack((mfccs, spectral_contrast, pitch)).T
    return features


def diarize_hierarchical(audio_path, k=2, window_size=2048, hop_length=1024):
    """Split an audio file into per-cluster WAV files using Ward hierarchical clustering.

    Per-frame features from ``extract_audio_features`` are standardised and
    clustered; the tree is cut into at most ``k + 1`` clusters. Along the way
    the function shows a dendrogram, two PCA scatter plots (without and with
    centroids) and a clustermap of the min-max scaled features. The audio of
    each cluster is written to ``hierarchical_clustering/cluster_<label>.wav``.

    Parameters
    ----------
    audio_path : str
        Path of the audio file to diarize; loaded at its native sampling rate.
    k : int, default 2
        Number of speakers. The tree is cut with ``maxclust = k + 1``.
    window_size : int, default 2048
        FFT window size forwarded to ``extract_audio_features``.
    hop_length : int, default 1024
        Hop between frames in samples; also used to map frames back to audio.

    Returns
    -------
    list of str
        Absolute paths of the written files, one per cluster label found.
    """
    output_dir = 'hierarchical_clustering'

    audio, sr = librosa.load(audio_path, sr=None)

    features = extract_audio_features(audio, window_size, hop_length, sr)
    features = StandardScaler().fit_transform(features)

    distance_matrix = pdist(features, metric='euclidean')
    Z = linkage(distance_matrix, method='ward')

    plt.figure(figsize=(12, 8))
    dendrogram(Z, truncate_mode='level', p=5)
    plt.title('Dendrogram')
    plt.xlabel('Sample Index')
    plt.ylabel('Distance')
    plt.show()

    clusters = fcluster(Z, t=k+1, criterion='maxclust')
    print("Cluster labels:", clusters)
    unique_labels, counts = np.unique(clusters, return_counts=True)

    for label, count in zip(unique_labels, counts):
        print(f"Cluster {label}: {count} samples")

    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(features)

    plt.figure(figsize=(10, 8))
    plt.scatter(pca_result[:, 0], pca_result[:, 1], c=clusters, cmap='viridis', alpha=0.5)
    plt.title('PCA of Clusters')
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    plt.colorbar(label='Cluster Label')
    plt.show()

    centroids = np.array([features[clusters == label].mean(axis=0) for label in unique_labels])
    centroids_2d = pca.transform(centroids)
    plt.figure(figsize=(10, 8))
    plt.scatter(pca_result[:, 0], pca_result[:, 1], c=clusters, cmap='viridis', alpha=0.5)
    plt.scatter(centroids_2d[:, 0], centroids_2d[:, 1], marker='X', color='red', s=100, label='Centroids')
    plt.title('PCA of Clusters with Centroids')
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    plt.colorbar(label='Cluster Label')
    plt.legend()
    plt.show()

    # sf.write does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    paths = []
    for label in unique_labels:
        frame_indices = np.where(clusters == label)[0]
        # `clusters` holds one label per feature frame, so each index stands
        # for `hop_length` audio samples rather than a single sample.
        cluster_audio = np.concatenate(
            [audio[i * hop_length:(i + 1) * hop_length] for i in frame_indices]
        )
        output_path = f'{output_dir}/cluster_{label}.wav'
        sf.write(output_path, cluster_audio, sr)
        print(f"Cluster {label} audio saved to {output_path}")
        paths.append(os.path.abspath(output_path))

    features_df = pd.DataFrame(features)
    features_df['Cluster'] = clusters
    features_df_sorted = features_df.sort_values(by='Cluster')

    heatmap_data = features_df_sorted.drop(columns='Cluster')
    palette = sns.color_palette("coolwarm", n_colors=len(unique_labels))
    # fcluster labels start at 1, so key the palette by the actual labels
    # instead of by 0-based position.
    label_to_color = {int(label): color for label, color in zip(unique_labels, palette)}
    frame_colors = features_df_sorted['Cluster'].map(label_to_color)

    # Keep the frame index on the scaled data: seaborn matches the colour
    # Series to the data by index.
    heatmap_scaled = pd.DataFrame(
        MinMaxScaler().fit_transform(heatmap_data),
        index=heatmap_data.index,
        columns=heatmap_data.columns,
    )

    # After the transpose, frames are the columns, so the per-frame cluster
    # colours belong on `col_colors`. Per-cell annotations and cell borders are
    # left out: with one column per audio frame they are unreadable.
    sns.clustermap(
        heatmap_scaled.T,
        cmap='coolwarm',
        col_colors=frame_colors,
        yticklabels=False,
        figsize=(14, 12),
    )
    plt.show()

    return paths

## Gaussian Mixture Based

In [None]:
def diarize_gmm(audio_path, k=2):
    """Split an audio file into per-cluster WAV files using a Gaussian mixture on MFCC frames.

    The audio is loaded with librosa's default sampling rate, its MFCC frames
    are standardised, a ``k + 1`` component Gaussian mixture is fitted, and
    the audio samples of each component's frames are concatenated and written
    to ``gmm/cluster_<n>.wav``.

    Parameters
    ----------
    audio_path : str
        Path of the audio file to diarize.
    k : int, default 2
        Number of speakers. ``k + 1`` mixture components are fitted.
        NOTE(review): presumably the extra component is meant to absorb
        silence/background -- confirm.

    Returns
    -------
    list of str
        Absolute paths of the ``k + 1`` written files, in component order.
    """
    # One hop value for both feature extraction and the frame -> sample mapping
    # below, so the two cannot drift apart.
    hop_length = 512
    output_dir = 'gmm'
    n_components = k + 1

    audio, sr = librosa.load(audio_path)

    # Transposed so that each row is one frame, as scikit-learn expects.
    mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13, hop_length=hop_length).T
    mfccs_scaled = StandardScaler().fit_transform(mfccs)

    gmm = GaussianMixture(n_components=n_components, random_state=42, warm_start=False)
    gmm.fit(mfccs_scaled)
    labels = gmm.predict(mfccs_scaled)

    # sf.write does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    paths = []
    for cluster in range(n_components):
        cluster_indices = np.where(labels == cluster)[0]
        segments = [
            audio[idx * hop_length:(idx + 1) * hop_length]
            for idx in cluster_indices
        ]
        # A mixture component can end up with no frames assigned to it, and
        # np.concatenate raises on an empty list; emit an empty clip instead so
        # the caller still receives one path per component.
        cluster_audio = np.concatenate(segments) if segments else np.zeros(0, dtype=audio.dtype)

        output_file = f'{output_dir}/cluster_{cluster}.wav'
        sf.write(output_file, cluster_audio, sr)
        paths.append(os.path.abspath(output_file))
        print(f'Cluster {cluster} audio saved to {output_file}')

    return paths

# Language Classification

In [None]:
from datasets import load_dataset
from huggingface_hub import login

# SECURITY: credentials must never be stored in the notebook. Any access token
# that was previously committed in this cell must be treated as leaked and
# revoked in the Hugging Face account settings.
# The token is read from the HF_TOKEN environment variable; when it is not set,
# login() is called without arguments so that it asks for the token interactively.
hf_token = os.environ.get('HF_TOKEN')
if hf_token:
    login(token=hf_token)
else:
    login()

In [None]:
def extract_mfcc(audio, sr):
    """Summarise a clip as the time-average of its 13 MFCCs.

    Parameters
    ----------
    audio : array-like
        Audio time series, passed to ``librosa.feature.mfcc`` as ``y``.
    sr : int
        Sampling rate of ``audio``.

    Returns
    -------
    np.ndarray
        One mean value per MFCC coefficient, averaged over all frames.
    """
    coefficients = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
    # Frames become rows after the transpose, so axis 0 averages over time.
    return coefficients.T.mean(axis=0)

In [None]:
def train_random_forest(samples_per_language=100):
    """Train and save a Hindi-vs-English random forest on Common Voice clips.

    Streams the first ``samples_per_language`` training clips of the ``hi`` and
    ``en`` configurations of ``mozilla-foundation/common_voice_17_0``, describes
    each clip by ``extract_mfcc``, trains a ``RandomForestClassifier`` on a
    stratified 80/20 split, prints accuracy / F1 / precision / recall, saves
    the model to ``models/random_forest_hi_en.joblib`` and shows the confusion
    matrix of the held-out 20 %.

    Parameters
    ----------
    samples_per_language : int, default 100
        Number of clips taken from the start of each language's stream.

    Returns
    -------
    None
    """
    model_path = 'models/random_forest_hi_en.joblib'
    # joblib.dump does not create missing directories; create the target up
    # front so a missing folder cannot discard a finished training run.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    def first_samples(language):
        """Return the first `samples_per_language` records of one language's train stream."""
        # NOTE(review): trust_remote_code=True lets the dataset repository run
        # its own loading script locally -- confirm this is acceptable.
        dataset = load_dataset(
            "mozilla-foundation/common_voice_17_0",
            language,
            split="train",
            streaming=True,
            trust_remote_code=True,
        )
        stream = iter(dataset)
        return [next(stream) for _ in range(samples_per_language)]

    combined_samples = first_samples("hi") + first_samples("en")
    # Seeded shuffle: train_test_split below is seeded too, but its result
    # depends on the input order, so an unseeded shuffle here would make every
    # run produce a different split and different metrics.
    random.Random(42).shuffle(combined_samples)

    X = np.array([
        extract_mfcc(sample['audio']['array'], sample['audio']['sampling_rate'])
        for sample in combined_samples
    ])
    y = np.array([sample['locale'] for sample in combined_samples])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)

    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average='weighted')
    precision = precision_score(y_test, y_pred, average='weighted')
    recall = recall_score(y_test, y_pred, average='weighted')

    print(f"Accuracy: {accuracy * 100:.2f}%")
    print(f"F1 Score: {f1:.2f}")
    print(f"Precision: {precision:.2f}")
    print(f"Recall: {recall:.2f}")

    joblib.dump(clf, model_path)

    conf_matrix = confusion_matrix(y_test, y_pred, labels=['en', 'hi'])

    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=['English', 'Hindi'], yticklabels=['English', 'Hindi'])
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.show()

In [None]:
def classify_random_forest(audio_path, retrain=False):
    """Predict the language of an audio file with the saved random forest.

    Parameters
    ----------
    audio_path : str
        Path of the audio file to classify.
    retrain : bool, default False
        When true, ``train_random_forest`` is run before the model is loaded.
        Training is also triggered when no saved model file exists yet, so a
        fresh checkout does not fail on a missing file.

    Returns
    -------
    np.ndarray
        Output of ``clf.predict`` for the single clip, i.e. an array holding
        one label.
    """
    model_path = 'models/random_forest_hi_en.joblib'

    if retrain or not os.path.exists(model_path):
        train_random_forest()

    clf = joblib.load(model_path)

    # NOTE(review): training computes MFCCs at each clip's own sampling rate,
    # while this call resamples to librosa's default rate. Confirm that the two
    # match, otherwise the features seen here differ from the training features.
    audio, sr = librosa.load(audio_path)
    mfcc = extract_mfcc(audio, sr)
    lang = clf.predict([mfcc])

    return lang

## Assembling the UI

In [None]:
# Dropdown label -> diarization function. The keys are shown in the UI and are
# used by `diarize` to look up the function to call.
cluster_drop_choices = {
    "KMeans Clustering": diarize_kmeans,
    "Hierarchical Clustering": diarize_hierarchical,
    "Gaussian Mixture Model Based Clustering": diarize_gmm,
}

In [None]:
def diarize(audio_path, num_speakers, method):
    """Run the diarization function registered under ``method``.

    Parameters
    ----------
    audio_path : str
        Path of the audio file to diarize.
    num_speakers : int
        Passed to the selected function as its second positional argument.
    method : str
        A key of ``cluster_drop_choices``.

    Returns
    -------
    Whatever the selected diarization function returns.
    """
    selected_diarizer = cluster_drop_choices[method]
    return selected_diarizer(audio_path, num_speakers)

In [None]:
def classify(audio_path, retrain_bool):
    """Classify the language of ``audio_path`` from the UI's string flag.

    Parameters
    ----------
    audio_path : str
        Path of the audio file to classify.
    retrain_bool : str
        Retraining is requested only when this is exactly ``'True'``; every
        other value means no retraining.

    Returns
    -------
    The result of ``classify_random_forest``.
    """
    should_retrain = retrain_bool == 'True'
    return classify_random_forest(audio_path, should_retrain)

In [None]:
# Number of audio players in the UI: up to 10 selectable speakers, plus the one
# extra cluster that every diarization method produces.
MAX_SEGMENTS = 11


def _diarize_for_ui(audio_path, num_speakers, method):
    """Run `diarize` and return exactly one update per audio player.

    Players that receive a segment are filled and shown; the remaining ones
    are cleared and hidden, so the number of returned values always equals the
    number of wired outputs regardless of the chosen speaker count.
    """
    if audio_path is None:
        raise gr.Error('Please provide an audio file first.')
    if method not in cluster_drop_choices:
        raise gr.Error('Please select a method of diarization.')

    paths = diarize(audio_path, int(num_speakers), method)[:MAX_SEGMENTS]
    shown = [gr.update(value=path, visible=True) for path in paths]
    hidden = [gr.update(value=None, visible=False) for _ in range(MAX_SEGMENTS - len(shown))]
    return shown + hidden


def _classify_for_ui(audio_path, retrain_bool):
    """Run `classify` and reveal the result textbox with the predicted label(s)."""
    if audio_path is None:
        raise gr.Error('Please provide an audio file first.')

    result = classify(audio_path, retrain_bool)
    # The classifier returns an array of labels; show the labels themselves
    # instead of the array's repr.
    text = ', '.join(str(label) for label in np.atleast_1d(result))
    return gr.update(value=text, visible=True)


with gr.Blocks() as demo:
    gr.Markdown('# Speaker Diarization and Language Classification')

    with gr.Row():
        audio_path = gr.Audio(type='filepath', label='Audio File')

    gr.Markdown('## Diarization')
    with gr.Row():
        # Materialise the choices as lists rather than passing dict_keys / range objects.
        method = gr.Dropdown(choices=list(cluster_drop_choices), label='Method of Diarization')
        num_speakers = gr.Dropdown(choices=list(range(1, MAX_SEGMENTS)), value=2, label='Number of Speakers')

    diarize_btn = gr.Button('Diarize')

    diarized_audio_outputs = [
        gr.Audio(label=f'Segment {i+1}', type='filepath', visible=False)
        for i in range(MAX_SEGMENTS)
    ]

    # A single handler sets both value and visibility of every player. The
    # output list is fixed when the UI is built, so it cannot depend on the
    # speaker count the user picks later.
    diarize_btn.click(
        fn=_diarize_for_ui,
        inputs=[audio_path, num_speakers, method],
        outputs=diarized_audio_outputs,
    )

    gr.Markdown('-------')
    gr.Markdown('## Language Classification')

    retrain_bool = gr.Dropdown(choices=['True', 'False'], value='False', label='Retrain?')
    classify_btn = gr.Button('Classify Language')

    classification_result_output = gr.Textbox(label='Classification Result', interactive=False, visible=False)

    classify_btn.click(
        fn=_classify_for_ui,
        inputs=[audio_path, retrain_bool],
        outputs=classification_result_output,
    )

In [None]:
# Start serving the Gradio app defined by `demo`.
demo.launch()