In [168]:

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [169]:
import asyncio
import io
import json
import logging
import os
from datetime import datetime, timezone
from datetime import timedelta
from uuid import uuid4

import pandas as pd
import torch
from pyannote.audio import Pipeline
from qdrant_client import QdrantClient, models

from app.database_redis import keys
from app.database_redis.connection import get_redis_client
from app.services.audio.audio import AudioSlicer
from app.services.audio.redis import Diarisation, Diarizer, Meeting, best_covering_connection
from app.settings import settings

logger = logging.getLogger(__name__)


In [170]:
# Qdrant client used by get_stored_knn / add_new_speaker_emb. "qdrant" is presumably the
# container/service hostname — TODO move it into `settings` next to the redis connection config.
client = QdrantClient("qdrant", timeout=10)

In [171]:

def get_stored_knn(emb: list, user_id):
    """Find the single closest stored embedding belonging to ``user_id``.

    Searches the Qdrant ``main`` collection, restricted to points whose
    ``user_id`` payload field equals ``user_id``.

    Returns:
        ``(speaker_id, score)`` of the best hit, or ``(None, None)`` when the
        user has no stored embeddings.
    """
    same_user = models.Filter(
        must=[models.FieldCondition(key="user_id", match=models.MatchValue(value=user_id))]
    )
    hits = client.search(
        collection_name="main",
        query_vector=emb,
        limit=1,
        query_filter=same_user,
    )
    if not hits:
        return None, None
    nearest = hits[0]
    return nearest.payload["speaker_id"], nearest.score


async def add_new_speaker_emb(emb, redis_client, user_id, speaker_id=None):
    """Store ``emb`` as a new sample for a speaker and return that speaker's id.

    Args:
        emb: the embedding vector — a numpy array (as produced by the pyannote pipeline)
            or any plain sequence of floats.
        redis_client: async redis client; the sample is also pushed onto the
            ``keys.SPEAKER_EMBEDDINGS`` list.
        user_id: owner of the speaker, stored in the Qdrant payload.
        speaker_id: existing speaker to attach the sample to; a fresh uuid4 string
            is generated when falsy.

    Returns:
        The speaker id the sample was stored under.
    """
    logger.info("Adding new speaker...")
    speaker_id = speaker_id if speaker_id else str(uuid4())
    # Normalise once: numpy arrays expose .tolist() (yielding Python floats); plain sequences are
    # copied. The original called .tolist() unconditionally, which crashed for the annotated `list`.
    vector = emb.tolist() if hasattr(emb, "tolist") else list(emb)
    client.upsert(
        collection_name="main",
        wait=True,
        points=[
            models.PointStruct(id=str(uuid4()), vector=vector, payload={"speaker_id": speaker_id, "user_id": user_id})
        ],
    )
    await redis_client.lpush(keys.SPEAKER_EMBEDDINGS, json.dumps((speaker_id, vector, user_id)))
    logger.info(f"Added new speaker {speaker_id}")
    return speaker_id


async def process_speaker_emb(
    emb: list,
    redis_client,
    user_id,
    same_speaker_threshold: float = 0.95,
    known_speaker_threshold: float = 0.75,
):
    """Resolve a speaker embedding to a persistent speaker id for ``user_id``.

    Looks up the nearest stored embedding for the user, then:

    * no stored embedding: create a new speaker from ``emb``;
    * score > ``same_speaker_threshold``: reuse the matched id and store nothing;
    * ``known_speaker_threshold`` < score <= ``same_speaker_threshold``: reuse the matched id
      and store ``emb`` as an additional sample for that speaker;
    * otherwise: create a new speaker from ``emb``.

    Returns:
        ``(speaker_id, score)``. ``speaker_id`` is a string. ``score`` is the nearest-neighbour
        score as returned by ``get_stored_knn`` (``None`` if the user had no embeddings). Note
        that it still refers to the nearest *existing* speaker when a new speaker was created.
    """
    speaker_id, score = get_stored_knn(emb, user_id)
    logger.info(f"score: {score}")

    if not speaker_id:
        speaker_id = await add_new_speaker_emb(emb, redis_client, user_id)
    elif score > same_speaker_threshold:
        pass  # confident match: nothing to store
    elif score > known_speaker_threshold:
        # Close but not identical: enrich the existing speaker with another sample.
        await add_new_speaker_emb(emb, redis_client, user_id, speaker_id=speaker_id)
    else:
        speaker_id = await add_new_speaker_emb(emb, redis_client, user_id)

    return str(speaker_id), score


def parse_segment(segment):
    """Flatten one ``itertracks(yield_label=True)`` item into ``(start, end, speaker_index)``.

    The first element must expose ``.start`` / ``.end``; the last element is a label such as
    ``"SPEAKER_03"`` whose text after the first underscore is parsed as an int (here ``3``).
    """
    turn = segment[0]
    label = segment[-1]
    speaker_index = int(label.split("_")[1])
    return turn.start, turn.end, speaker_index


async def get_next_chunk_start(diarization_result, length, shift, min_silence: float = 2.0):
    """Decide where the next audio chunk should start.

    Args:
        diarization_result: list of segment dicts with ``"start"`` / ``"end"`` keys, expressed
            relative to the current chunk.
        length: duration of the current chunk, in the same units as the segment times.
        shift: offset of the current chunk; added to the chosen chunk-relative position.
        min_silence: if the speech that ends last stops less than ``min_silence`` before the
            end of the chunk, it is treated as cut off by the chunk boundary ("interrupted").

    Returns:
        ``None`` when there are no segments. Otherwise:

        * interrupted speech: ``start + shift``, so that segment is re-processed whole in the
          next chunk;
        * completed speech: ``end + shift``.

        If restarting at the interrupted segment's start would not move past ``shift`` (the
        segment begins at the very start of the chunk), ``end + shift`` is returned instead.
        Otherwise the caller would request the same chunk forever.
    """
    if len(diarization_result) == 0:
        return None

    # Segments can overlap, so the final record is not necessarily the one that ends last.
    last_speech = max(diarization_result, key=lambda seg: seg["end"])

    ended_silence = length - last_speech["end"]
    logger.info(ended_silence)
    if ended_silence < min_silence:
        logger.info("interrupted")
        restart = last_speech["start"] + shift
        if restart > shift:
            return restart
        # Restarting here would make no progress; fall through and resume after the speech.
    else:
        logger.info("non-interrupted")
    return last_speech["end"] + shift


In [172]:
# Async redis connection built from app settings; shared by every cell below.
redis_client = await get_redis_client(settings.redis_host, settings.redis_port, settings.redis_password)

In [173]:
diarizer = Diarizer(redis_client)
# Take the next meeting id queued for diarisation. NOTE(review): presumably returns None when
# nothing is queued — later cells do not guard against that; confirm against Diarizer.
meeting_id = await diarizer.pop_inprogress()

In [174]:
diarizer

<app.services.audio.redis.Diarizer at 0x7fd3f0ad7790>

In [175]:
meeting_id

'meeting1'

In [176]:
# Load meeting state and express the diarizer's persisted position as seconds since meeting start.
meeting = Meeting(redis_client, meeting_id)
await meeting.load_from_redis()
seek = (meeting.diarizer_seek_timestamp - meeting.start_timestamp).total_seconds()

# Upper bound of the time window used when choosing which connection to slice audio from.
current_time = datetime.now(timezone.utc)

In [177]:
seek

0.0

In [130]:
current_time

datetime.datetime(2024, 5, 20, 18, 20, 51, 450760, tzinfo=datetime.timezone.utc)

In [131]:
# Connections attached to this meeting; each one's audio is read from /audio/<connection.id>.webm below.
connections = await meeting.get_connections()

In [132]:
connections

[<app.services.audio.redis.Connection at 0x7fd5f5c2ac80>,
 <app.services.audio.redis.Connection at 0x7fd5f5c2ad70>]

In [133]:
# Scratch inspection only: `connection` is reassigned via best_covering_connection further down.
connection = connections[0]

In [134]:
connection.start_timestamp

datetime.datetime(2024, 5, 20, 17, 38, 5, 107331, tzinfo=tzutc())

In [137]:
# Populate the connection's stored fields (e.g. end_timestamp, inspected next) from redis.
await connection.load_from_redis()

In [138]:
connection.end_timestamp

datetime.datetime(2024, 5, 20, 18, 21, 11, 370512, tzinfo=tzutc())

In [178]:
seek

0.0

In [179]:
meeting.diarizer_seek_timestamp

datetime.datetime(2024, 5, 20, 17, 38, 5, 107331, tzinfo=tzutc())

In [180]:
connections[0].end_timestamp

datetime.datetime(2024, 5, 20, 18, 21, 11, 370512, tzinfo=tzutc())

In [181]:
meeting.diarizer_seek_timestamp, current_time

(datetime.datetime(2024, 5, 20, 17, 38, 5, 107331, tzinfo=tzutc()),
 datetime.datetime(2024, 5, 20, 18, 30, 22, 34749, tzinfo=datetime.timezone.utc))

In [182]:
connections

[<app.services.audio.redis.Connection at 0x7fd5f5c2ac80>,
 <app.services.audio.redis.Connection at 0x7fd5f5c2ad70>]

In [183]:
# Presumably picks the connection whose recording best covers [diarizer seek, current_time] — confirm.
connection = best_covering_connection(meeting.diarizer_seek_timestamp, current_time, connections)

In [184]:
# Longest audio slice to diarise in one pass. Presumably seconds, since it is passed alongside
# `seek` (seconds) to AudioSlicer.from_ffmpeg_slice — TODO move to a config constant.
max_length = 240

In [185]:
seek

0.0

In [186]:
seek + max_length

240.0

In [148]:
connection.id

'02a0f16b-4a17-4315-b267-3f6b147b3465'

In [149]:
f"/audio/{connection.id}.webm"

'/audio/02a0f16b-4a17-4315-b267-3f6b147b3465.webm'

In [187]:
# Choose the connection to read from and cut up to `max_length` of its audio starting at `seek`.
# NOTE(review): `seek` is measured from meeting.start_timestamp but is used as an offset into this
# connection's file; that is only correct if the connection started at the meeting start — confirm.
connection = best_covering_connection(meeting.diarizer_seek_timestamp, current_time, connections)
audio_slicer = await AudioSlicer.from_ffmpeg_slice(f"/audio/{connection.id}.webm", seek, max_length)
slice_duration = audio_slicer.audio.duration_seconds
audio_data = await audio_slicer.export_data()

2024-05-20 18:31:12,349 - INFO - app.services.audio.audio - None


In [188]:
audio_slicer.audio

In [189]:
# Load the pyannote diarisation pipeline. The Hugging Face token is read from the environment:
# credentials must never be committed to a notebook.
# NOTE(review): the token previously hard-coded in this cell is exposed in history — revoke it.
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.environ["HF_TOKEN"],
)
# Fall back to CPU so the notebook still runs on machines without CUDA.
pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

<pyannote.audio.pipelines.speaker_diarization.SpeakerDiarization at 0x7fd3982879a0>

In [190]:

# Diarise the exported slice. `embeddings` is presumably one centroid per speaker label, in
# output.labels() order (pyannote's return_embeddings contract) — confirm for the installed version.
output, embeddings = pipeline(io.BytesIO(audio_data), return_embeddings=True)

In [191]:
# Materialise (segment, track, label) tuples so they can be inspected and reused.
segments = list(output.itertracks(yield_label=True))

In [192]:
seek

0.0

In [207]:

speakers = [await process_speaker_emb(e, redis_client, connection.user_id) for e in embeddings]

segments = [i for i in output.itertracks(yield_label=True)]
df = pd.DataFrame([parse_segment(s) for s in segments], columns=["start", "end", "speaker_id"])
df["speaker"] = df["speaker_id"].replace({i: s[0] for i, s in enumerate(speakers)})
df["score"] = df["speaker_id"].replace({i: s[1] for i, s in enumerate(speakers)})
result = df.drop(columns=["speaker_id"]).to_dict("records")

diarization = Diarisation(meeting_id, redis_client, result)
await diarization.lpush()

seek = await get_next_chunk_start(result, slice_duration, seek)
seek = seek or  seek + slice_duration

meeting.diarize_seek_timestamp = meeting.start_timestamp+timedelta(seconds=seek)
await diarizer.remove(meeting.meeting_id)
await meeting.update_redis()

In [219]:
Diarisation??

Init signature:
Diarisation(
    meeting_id: str,
    redis_client: redis.asyncio.client.Redis,
    data: List = None,
)
Docstring:      Data(key: str, redis_client: redis.asyncio.client.Redis, data: Union[List, dict] = None)
Source:        
class Diarisation(Data):
    def __init__(self, meeting_id: str, redis_client: Redis, data: List = None):

In [220]:
# Scratch: read back a pushed diarisation result for meeting 'meeting1' (hard-coded id).
# NOTE(review): rpop presumably removes the entry from redis, so this inspection is destructive —
# prefer a non-popping read when debugging.
diarization = Diarisation(meeting_id='meeting1',redis_client = redis_client)
d = await diarization.rpop()

In [221]:
d

'[{"start": 114.5246179966044, "end": 115.20373514431239, "speaker": "328e5dd7-2a38-4fef-ab2b-658425b40a37", "score": 1.0}, {"start": 195.6451612903226, "end": 196.62988115449917, "speaker": "d366dad6-8059-462b-9de8-11bdee4ef3c6", "score": 0.9999999}, {"start": 215.509337860781, "end": 215.7470288624788, "speaker": "d165a9a5-d6e6-40d9-a787-1e5c3c0dc763", "score": 0.99999994}, {"start": 215.7470288624788, "end": 215.93378607809848, "speaker": "328e5dd7-2a38-4fef-ab2b-658425b40a37", "score": 1.0}, {"start": 215.79796264855688, "end": 215.8149405772496, "speaker": "d165a9a5-d6e6-40d9-a787-1e5c3c0dc763", "score": 0.99999994}, {"start": 215.93378607809848, "end": 216.71477079796267, "speaker": "d165a9a5-d6e6-40d9-a787-1e5c3c0dc763", "score": 0.99999994}, {"start": 217.76740237691004, "end": 218.446519524618, "speaker": "d165a9a5-d6e6-40d9-a787-1e5c3c0dc763", "score": 0.99999994}, {"start": 249.43123938879458, "end": 250.77249575551784, "speaker": "14541c71-b3c1-483f-a1e7-9542ebe3ace0", "sco

In [None]:
diarizer

In [199]:
# Scratch re-run of the seek update. Not idempotent: each execution feeds the previous `seek` back in.
seek = await get_next_chunk_start(result, slice_duration, seek)

In [200]:
seek

436.893039049236

In [201]:
# BUG(review): if `seek` is None (no speech found) this evaluates None + slice_duration and raises
# TypeError, and `or` also treats a valid 0.0 offset as missing. The fallback needs the seek value
# from *before* get_next_chunk_start overwrote it, and should test `is None` explicitly.
seek = seek or  seek + slice_duration

In [206]:
meeting.start_timestamp+timedelta(seconds=seek)

datetime.datetime(2024, 5, 20, 17, 45, 22, 370, tzinfo=tzutc())

In [202]:
seek

436.893039049236

In [194]:
seek

0.0

In [160]:
current_time

datetime.datetime(2024, 5, 20, 18, 20, 51, 450760, tzinfo=datetime.timezone.utc)

In [162]:
from datetime import timedelta

In [164]:
timedelta(seconds=seek)

datetime.timedelta(0)

In [158]:
speakers

[('d366dad6-8059-462b-9de8-11bdee4ef3c6', 0.9999999),
 ('328e5dd7-2a38-4fef-ab2b-658425b40a37', 1.0),
 ('14541c71-b3c1-483f-a1e7-9542ebe3ace0', 1.0),
 ('d165a9a5-d6e6-40d9-a787-1e5c3c0dc763', 0.99999994),
 ('d98547a6-036c-4829-b986-85c7f6beeb77', 1.0)]