In [1]:
from qdrant_client import QdrantClient
from qdrant_client.models import  Distance, VectorParams

client = QdrantClient(url="http://localhost:6333")

from qdrant_client.models import PointStruct
from uuid import uuid4
import json

In [2]:
import redis

redis_client = redis.Redis(host='localhost', port=6399, db=0,decode_responses=True)



In [105]:
class SpeechDAL:
    def get_all_transcriptions(self) -> list:
        # Get all data
        raw_data = redis_client.lrange(f"transcriber", start=0, end=-1)
        return [json.loads(data) for data in raw_data]

    def get_all_diarizations(self) -> list:
        # Get all data
        raw_data = redis_client.lrange(f"diarizer", start=0, end=-1)
        return [json.loads(data) for data in raw_data]

    def lpush_diarizations(self,diarizations: str) -> None:
        for d in diarizations:
            redis_client.lpush(f"diarizer", json.dumps(d))

    def lpush_transcriptions(self, transcriptions) -> None:
        for t in transcriptions:
            redis_client.lpush(f"transcriber", json.dumps(t))

    def ltrim_transcriptions(self, meeting_id, slice: tuple):
        redis_client.ltrim(f"diarizer", int(slice[0]), int(slice[1]))

    def ltrim_diarizations(self, meeting_id, slice: tuple):
        redis_client.ltrim(f"diarizer", int(slice[0]), int(slice[1]))


In [3]:
from pathlib import  Path
import torch
from pyannote.audio import Pipeline
import io
from audio import AudioSlicer
import pandas as pd

import numpy as np
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
import matplotlib.pyplot as plt
import numpy as np
import random

Path.ls = lambda self:[item for item in self.iterdir()]

def parse_segment(segment):
    return segment[0].start, segment[0].end, int(segment[-1].split("_")[1])


pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token="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",
        )
pipeline.to(torch.device("cuda"))


torchvision is not available - cannot save figures


<pyannote.audio.pipelines.speaker_diarization.SpeakerDiarization at 0x7fcfcb38b160>

In [4]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import pairwise_distances
from scipy.cluster.hierarchy import linkage, fcluster

# Configuration parameters
config = {
    'clustering': {
        'method': 'centroid',
        'min_cluster_size': 3,
        'threshold': 1.#0.7# 0.7045654963945799 #
    }
}

def centroid_linkage_distance(matrix):
    """Custom function to compute linkage with centroid method."""
    Z = linkage(matrix, method='centroid')
    return Z

def compute_centroids(embeddings, labels):
    """Compute centroids of clusters and return centroids in the order of labels."""
    unique_labels = np.unique(labels)
    centroids = np.zeros((len(unique_labels), embeddings.shape[1]))
    
    for i, label in enumerate(unique_labels):
        centroids[i] = embeddings[labels == label].mean(axis=0)
    
    return centroids, unique_labels

def map_labels_to_centroid_indices(labels, unique_labels):
    """Map labels to the indices of centroids."""
    label_to_index = {label: idx for idx, label in enumerate(unique_labels)}
    centroid_indices = np.array([label_to_index[label] if label in unique_labels else -1 for label in labels])
    return centroid_indices

def cluster_embeddings(embeddings, config):
    distance_matrix = pairwise_distances(embeddings, metric='cosine')
    Z = centroid_linkage_distance(distance_matrix)
    labels = fcluster(Z, t=config['clustering']['threshold'], criterion='distance')
    unique_labels, counts = np.unique(labels, return_counts=True)
    valid_labels = unique_labels[counts >= config['clustering']['min_cluster_size']]
    valid_indices = np.where(np.isin(labels, valid_labels))[0]
    centroids, unique_labels = compute_centroids(embeddings, labels)
    centroid_indices = map_labels_to_centroid_indices(labels, unique_labels)
    valid_centroid_indices = np.array([index for index, centroid in enumerate(centroids) if index in np.unique(centroid_indices[valid_indices])])
    return centroid_indices, centroids, valid_centroid_indices

def find_largest_span(points):
    # Use max to find the point with the largest span
    largest_span_point = max(points, key=lambda p: p.payload['span'])
    return largest_span_point


In [5]:
audio_file_path = Path('/home/andrew/volumes/audio/dd421e44-b243-4665-88c5-2c156687ddc8.webm')

In [41]:
class Clusterize:
    def __init__(self,path,config):
        self.config = config
        self.path = path
        self.embeddings_list = []
        self.segments = []
    
    async def extract(self,start,duration):
        audio_slicer = await AudioSlicer.from_ffmpeg_slice(audio_file_path, start, duration)
        audio_data = await audio_slicer.export_data()
        output, embs = pipeline(io.BytesIO(audio_data), return_embeddings=True) #audio service output

        segs = [i for i in output.itertracks(yield_label=True)]
        segs = pd.DataFrame([parse_segment(s) for s in segs], columns=["start", "end", "speaker_id"])
        segs['start']+=start
        segs['end']+=start
        
        self.embeddings_list.append(embs)
        self.segments.append(segs)

        self.embeddings = np.concatenate(self.embeddings_list, axis=0)
        self.clusters,self.centroids,self.valid_clusters = cluster_embeddings(self.embeddings, self.config)
        self.df = pd.concat(self.segments)
        self.df = self.df.set_index('speaker_id').join(pd.DataFrame({'cluster':self.clusters})).reset_index(drop=True)
        self.df['span'] = self.df['end'] - self.df['start']
        cluster_span = self.df.groupby('cluster')[['span']].sum()
        self.cluster_span = cluster_span.to_dict()['span']
        
        
    
        
    def assing_speaker_ids(self):
        new_points = []
        for cluster in list(self.cluster_span.keys()):
            point_id = None

            centroid = self.centroids[cluster]
            search_results = client.search(collection_name="speakers", query_vector=list(centroid), limit=5,score_threshold=0.65)
            
            
            
            if len(search_results)==0:
                if cluster in self.valid_clusters:
                    point_id = str(uuid4())
                    new_points.append(PointStruct(id = point_id,vector = list(centroid),payload={"span":self.cluster_span[cluster]}))

            else:
                search_result = find_largest_span(search_results)
                point_id = search_result.id 
                remote_span = search_result.payload['span']
                

                span = self.cluster_span[cluster]
                if span > remote_span:
                    new_points.append(PointStruct(id = point_id,vector = list(centroid),payload={"span":span}))  
                    
            self.df.loc[self.df['cluster']==cluster,'speaker'] = point_id
                    
        if len(new_points)>0:
            operation_info = client.upsert(
                collection_name="speakers",
                wait=True,
                points=new_points
            )    
        self.df = self.df.sort_values('start')


In [46]:
self.cluster_span.keys()

dict_keys([0, 1])

In [83]:
self = Clusterize(audio_file_path,config)

In [84]:
n=0

In [140]:
await self.extract(n,20)
self.assing_speaker_ids()
n+=20

In [141]:
renameing = {id:n for n,id in enumerate(self.df['speaker'].unique().tolist())}

In [142]:
renameing

{'01d8d910-6c0d-436b-a021-b1cfb9e7c1ad': 0,
 '2326f850-4157-4f2a-b826-e87318c1b942': 1}

In [143]:
diar_df = self.df.replace(renameing)

In [144]:
audio_slicer = await AudioSlicer.from_ffmpeg_slice(audio_file_path, 0, 100)
audio_slicer.audio

In [145]:
def prep_transcripts(transcriptions) -> pd.DataFrame:
    dfs = []
    for n, (t, seek) in enumerate(transcriptions):
        df = pd.DataFrame(t)[[2, 3, 4]]
        df.columns = ["start", "end", "speech"]
        df["start"] = pd.to_timedelta(df["start"], unit="s") + pd.Timestamp(seek)
        df["end"] = pd.to_timedelta(df["end"], unit="s") + pd.Timestamp(seek)
        df["trans_chunk"] = n
        dfs.append(df)
    return pd.concat(dfs).reset_index(drop=True)

In [146]:
speech_dal = SpeechDAL()
transcriptions = speech_dal.get_all_transcriptions()

In [147]:
trans_df = prep_transcripts(transcriptions)

In [148]:
start_timestamp = pd.Timestamp(redis_client.get('start_timestamp'))

In [149]:
start_timestamp

Timestamp('2024-06-21 12:36:33.258864+0000', tz='UTC')

In [150]:
diar_df['start'] = pd.to_timedelta(diar_df['start'],unit = 's')+start_timestamp
diar_df['end'] = pd.to_timedelta(diar_df['end'],unit = 's')+start_timestamp

In [151]:
segments = trans_df.to_dict("records")

if len(segments) > 0:
    for seg in segments:
        diar_df["intersection"] = np.minimum(diar_df["end"], seg["end"]) - np.maximum(
            diar_df["start"], seg["start"]
        )
        speaker_ = diar_df[
            (diar_df["intersection"] == diar_df["intersection"].max())
            & (diar_df["intersection"] > pd.Timedelta(0))
        ][["speaker"]]
        if len(speaker_) > 0:
            seg["speaker"] = speaker_["speaker"].iloc[0]


        # ToDo: Dima
        else:
            seg["speaker"] = np.nan


In [153]:
pd.DataFrame(segments).sort_values('start')

Unnamed: 0,start,end,speech,trans_chunk,speaker
54,2024-06-21 12:36:33.368864+00:00,2024-06-21 12:36:36.108864+00:00,Что у нас с тобой по спикерам? Пройдем все.,19,0
55,2024-06-21 12:36:37.248864+00:00,2024-06-21 12:36:39.268864+00:00,"По спикерам давай, по диарайзу.",19,1
56,2024-06-21 12:36:42.108864+00:00,2024-06-21 12:36:44.688864+00:00,Последний день перед отпуском на основной раб...,19,1
57,2024-06-21 12:36:44.708864+00:00,2024-06-21 12:36:47.548864+00:00,"Блин, то по носу, то золотом. Все мое, блин.",19,1
58,2024-06-21 12:36:48.668864+00:00,2024-06-21 12:36:50.128864+00:00,У тебя два дня отпуск?,19,0
59,2024-06-21 12:36:51.408864+00:00,2024-06-21 12:36:53.148864+00:00,"Нет, у меня четыре со вторника.",19,0
45,2024-06-21 12:36:53.258864+00:00,2024-06-21 12:36:55.498864+00:00,"по пятницу на основной работе,",18,1
46,2024-06-21 12:36:55.518864+00:00,2024-06-21 12:36:58.058864+00:00,а здесь я отдохну,18,1
47,2024-06-21 12:36:58.058864+00:00,2024-06-21 12:36:59.758864+00:00,"в среду, в четверг, завтра я в строю буду.",18,1
48,2024-06-21 12:37:00.658864+00:00,2024-06-21 12:37:01.218864+00:00,Ага.,18,0
