In [54]:

%load_ext autoreload
%autoreload 2

import sys 
sys.path.insert(0,'/app')
sys.path.insert(0,'..')
import diart
from audio.redis import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
from pathlib import Path

def list_directory_contents(path='.'):
    """
    Lists the contents of a directory using pathlib.

    :param path: The directory path to list contents of. Defaults to the current directory.
    :return: None
    """
    p = Path(path)
    return [entry for entry in p.iterdir()]


In [56]:
import torch
import torch.nn.functional as F

def average_embeddings(embeddings):
    """
    Averages embeddings along the second dimension.

    :param embeddings: A 3D tensor of shape (num_slices, num_embeddings, embedding_dim)
                       where num_slices is the number of time slices,
                       num_embeddings is the number of embeddings per time slice,
                       and embedding_dim is the dimensionality of each embedding.
    :return: A 2D tensor of shape (num_slices, embedding_dim) representing the averaged
             embeddings for each time slice.
    """
    return torch.mean(embeddings, dim=1)

def pairwise_cosine_similarity(embeddings):
    """
    Calculate the pairwise cosine similarity of a set of averaged embeddings.

    :param embeddings: A 2D tensor of shape (num_slices, embedding_dim)
                       where num_slices is the number of time slices and
                       embedding_dim is the dimensionality of the averaged embeddings.
    :return: A 2D tensor of shape (num_slices, num_slices) containing
             the pairwise cosine similarity scores between slices.
    """
    embeddings_norm = F.normalize(embeddings, p=2, dim=1)
    similarity_matrix = torch.mm(embeddings_norm, embeddings_norm.t())
    
    return similarity_matrix



In [57]:
redis_client = await get_inner_redis()

In [58]:
redis_keys = await redis_client.keys('*')

In [59]:
redis_keys

['Start:42381612-e5e2-46ba-bb5a-73d17d25f90f--0',
 'test',
 '7b2f5cfb-9171-4eab-9b35-bd65bc1a925c',
 'diarization_85ecfa13-14ce-4c59-a824-b3405d258f13--0',
 '21801789-400a-4eb1-ae76-1bb73b9771a1',
 'd1a63b4d-c2fc-40e6-bca9-bb875fbf97c3',
 'diarization_42381612-e5e2-46ba-bb5a-73d17d25f90f--0',
 'Segment:a72d3780-84f9-4e0a-8a60-6b3aba70bf23--0',
 '8da4cbc3-2fe4-4032-84e3-e8775ad913dd',
 'f15d36ab-7cbc-4529-bdf2-70e37775fe68',
 '796ceef6-df11-447d-9e1f-ee1eadcc478c',
 'Start:a72d3780-84f9-4e0a-8a60-6b3aba70bf23--0',
 '045d11f5-3d54-4dde-a572-f8043f146abb',
 '9bdf4c01-0709-4b98-b3bb-37196a84c033',
 'b72c112c-ec66-4e0c-8196-b78a32daa756',
 '9aac8334-114b-4721-8617-1858a030e986',
 '031a0b05-3029-470a-9608-3dbf6a202502',
 'Start:a5fc8dae-9463-4973-ae9c-f95ed56b2137--0',
 '471d87bf-0b4a-4363-ba10-83051c38bcd9',
 'Embeddings']

In [60]:
#[await redis_client.delete(k) for k in redis_keys]

In [61]:
from diart import SpeakerDiarization
from diart.sources import FileAudioSource
from diart.inference import StreamingInference
from diart.sinks import  RedisWriter
import json
import pandas as pd

In [62]:
from diart.sinks import  RTTMWriter

In [63]:
res = []

In [78]:
id = '42381612-e5e2-46ba-bb5a-73d17d25f90f--0' 

In [95]:
r=1
while r:
    r = await redis_client.rpop(f'diarization_{id}')
    if r:
        res.append(r)


In [96]:
len(res)

965

In [108]:
res[0]

{'start': 69.50833333333333,
 'end': 70.00833333333333,
 'speaker_id': 'speaker0',
 'centroids': [[4.3035756357057835,
   2.6926474636420608,
   2.6899493230885128,
   -4.165116617441527,
   -3.632945484263473,
   -7.08434839732945,
   -1.7988810682436451,
   -0.7240494233483332,
   7.758875464845914,
   -1.9481693013076438,
   -8.13655921118334,
   1.6333843934335164,
   -5.081057824470918,
   -4.143064765710733,
   2.0753976447304012,
   5.421707333007362,
   2.3927385698771104,
   2.4464143089717254,
   -6.668232839088887,
   3.2791755097568966,
   -0.485542195783637,
   -2.978622494963929,
   5.324479583650827,
   -0.1622489005189891,
   -0.10806973020953592,
   3.6767006495938404,
   -5.481223476468585,
   -5.383870865567587,
   -6.824676690157503,
   1.7298555001616478,
   5.214812894235365,
   2.046353729296243,
   7.836172005627304,
   7.293278624303639,
   0.5499065378244268,
   0.4492039378237678,
   4.63608896691585,
   -1.0819280102186895,
   4.17807868856471,
   2.79101040

In [97]:
res = [json.loads(r) for r in res]

TypeError: the JSON object must be str, bytes or bytearray, not dict

In [98]:
len(res)

965

In [99]:
import torch

In [100]:
embs = torch.Tensor(pd.DataFrame(res)['centroids'])

AttributeError: 'str' object has no attribute 'keys'

In [71]:
embs.shape

torch.Size([127, 20, 512])

In [72]:
# Average the embeddings within each time slice
averaged_embeddings = average_embeddings(embs)

# Calculate pairwise cosine similarity between averaged embeddings of time slices
similarity = pairwise_cosine_similarity(averaged_embeddings)

In [77]:
similarity[0].shape

torch.Size([127])

In [73]:
similarity.min()

tensor(0.5016)

In [90]:
similarity.mean()

tensor(0.9176)

In [74]:
similarity.max()

tensor(1.0000)

In [79]:
from audio.audio import AudioSlicer

In [80]:
get_connections('diarization',redis_client)

<coroutine object get_connections at 0x7fe55589d6c0>

In [81]:
audio_file_path = f'/audio/{id}.webm'

In [82]:
files = list_directory_contents('/audio/')

In [83]:
files[0]

PosixPath('/audio/42381612-e5e2-46ba-bb5a-73d17d25f90f--0.webm')

In [91]:
call = await AudioSlicer.from_file(files[0],format= 'webm')

In [93]:
call.slice(0,1000)

In [94]:
230/60

3.8333333333333335