In [1]:
#!sudo apt install libcudnn8 libcudnn8-dev -y

In [2]:
import os
import torch
import shutil
import pandas as pd
from tqdm import tqdm
from dotenv import load_dotenv
import whisperx
import gc
import json
import re
import nltk
from deepmultilingualpunctuation import PunctuationModel

device ="cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [3]:
# https://github.com/MahmoudAshraf97/whisper-diarization/blob/main/Whisper_Transcription_%2B_NeMo_Diarization.ipynb
def get_words_speaker_mapping(wrd_ts, spk_ts):
    turn_idx = 0
    wrd_spk_mapping = []

    for wrd_dict in wrd_ts:
        ws, we, wrd = wrd_dict["start"], wrd_dict["end"], wrd_dict["word"]

        while turn_idx < len(spk_ts) - 1 and ws > spk_ts.iloc[turn_idx]["end"]:
            turn_idx += 1

        current_row = spk_ts.iloc[turn_idx]
        wrd_spk_mapping.append({
            "word": wrd,
            "start_time": ws,
            "end_time": we,
            "speaker": current_row["speaker"]
        })

    return wrd_spk_mapping

sentence_ending_punctuations = ".?!"

def get_realigned_ws_mapping_with_punctuation(word_speaker_mapping, max_words_in_sentence=60):
    words_list = [d["word"] for d in word_speaker_mapping]
    speaker_list = [d["speaker"] for d in word_speaker_mapping]

    # Funzioni di utilità per trovare l'inizio e la fine di una frase
    def is_word_sentence_end(idx):
        return idx >= 0 and words_list[idx][-1] in sentence_ending_punctuations

    def get_first_word_idx(current_idx):
        left_idx = current_idx
        while (left_idx > 0 and
               current_idx - left_idx < max_words_in_sentence and
               speaker_list[left_idx - 1] == speaker_list[left_idx] and
               not is_word_sentence_end(left_idx - 1)):
            left_idx -= 1
        return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1

    def get_last_word_idx(current_idx):
        right_idx = current_idx
        while (right_idx < len(words_list) - 1 and
               right_idx - current_idx < max_words_in_sentence and
               not is_word_sentence_end(right_idx)):
            right_idx += 1
        return right_idx if right_idx == len(words_list) - 1 or is_word_sentence_end(right_idx) else -1

    # Itera e corregge
    k = 0
    while k < len(word_speaker_mapping) - 1:
        if speaker_list[k] != speaker_list[k + 1] and not is_word_sentence_end(k):
            left_idx = get_first_word_idx(k)
            right_idx = get_last_word_idx(k) if left_idx > -1 else -1

            if left_idx != -1 and right_idx != -1:
                sub_speaker_list = speaker_list[left_idx : right_idx + 1]
                # Assegna lo speaker più frequente a tutta la frase
                '''dominant_speaker = max(set(sub_speaker_list), key=sub_speaker_list.count)
                for i in range(left_idx, right_idx + 1):
                    speaker_list[i] = dominant_speaker
                k = right_idx'''
                dominant_speaker = max(set(sub_speaker_list), key=sub_speaker_list.count)
                # Aggiungi un controllo di robustezza
                if sub_speaker_list.count(dominant_speaker) >= len(sub_speaker_list) / 2:
                    for i in range(left_idx, right_idx + 1):
                        speaker_list[i] = dominant_speaker
                k = right_idx
        k += 1

    # Crea la lista riallineata
    realigned_list = []
    for i, d in enumerate(word_speaker_mapping):
        new_dict = d.copy()
        new_dict["speaker"] = speaker_list[i]
        realigned_list.append(new_dict)

    return realigned_list


def get_sentences_speaker_mapping(word_speaker_mapping):
    """
    Raggruppa le parole (con speaker assegnato) in frasi.
    Una nuova frase inizia quando cambia lo speaker o quando si incontra
    una fine di frase grammaticale.
    """
    sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak

    sentences = []
    current_sentence = None

    for wrd_dict in word_speaker_mapping:
        word, speaker = wrd_dict["word"], wrd_dict["speaker"]
        start_time, end_time = wrd_dict["start_time"], wrd_dict["end_time"]

        is_new_sentence = (
            current_sentence is None or
            speaker != current_sentence["speaker"] or
            sentence_checker(current_sentence["text"] + " " + word)
        )

        if is_new_sentence:
            if current_sentence:
                sentences.append(current_sentence)
            current_sentence = {
                "speaker": speaker,
                "start_time": start_time,
                "end_time": end_time,
                "text": word,
            }
        else:
            current_sentence["end_time"] = end_time
            current_sentence["text"] += " " + word

    if current_sentence:
        sentences.append(current_sentence)

    return sentences

In [4]:
def transcribe_sessions(dataset_dir, sessions, device, language="en"):
    compute_type = "float16" if device == 'cuda' else "float32"
    batch_size=32
    model_id = "large-v2"
    temp_dir = os.path.join(dataset_dir, "temp_results")
    os.makedirs(temp_dir, exist_ok=True) 

    model = whisperx.load_model(model_id, device, compute_type=compute_type, language=language,local_files_only=False)

    for session in tqdm(sessions, desc="Trascrizione Audio"):
        session_path = os.path.join(dataset_dir, session)
        base_name = session.split("_")[0]
        audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
        intermediate_path = os.path.join(temp_dir, f"{session}_transcript.json")
        if os.path.exists(intermediate_path):
            print(f"Skippo {audio_path} perché la trascrizione esiste già in {intermediate_path}")
            continue
        print(f"\nSto processando: {audio_path}")
        
        audio = whisperx.load_audio(audio_path)
        result = model.transcribe(audio, batch_size=batch_size)

        with open(intermediate_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=4)
        print(f"Trascrizione intermedia salvata in: {intermediate_path}")

    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [5]:
def align_sessions(dataset_dir, sessions, device, language="en"):
    temp_dir = os.path.join(dataset_dir, "temp_results")
    model_a, metadata = whisperx.load_align_model(language_code=language, device=device)

    for session in tqdm(sessions, desc="Allineamento Audio"):
        intermediate_path = os.path.join(temp_dir, f"{session}_transcript.json")
        aligned_path = os.path.join(temp_dir, f"{session}_aligned.json")
        if os.path.exists(aligned_path):
            print(f"Skippo {audio_path} perché l'allineamento esiste già in {aligned_path}")
            continue
        print(f"\nSto allineando: {session}")

        # Carica il risultato della trascrizione
        with open(intermediate_path, 'r', encoding='utf-8') as f:
            result = json.load(f)

        # Per l'allineamento è necessario ricaricare l'audio
        session_path = os.path.join(dataset_dir, session)
        base_name = session.split("_")[0]
        audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
        audio = whisperx.load_audio(audio_path)

        # Esegui l'allineamento
        result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
        with open(aligned_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=4)
        print(f"Allineamento salvato in: {aligned_path}")

    del model_a
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
def diarize_sessions(dataset_dir, sessions, device):
    temp_dir = os.path.join(dataset_dir, "temp_results")
    load_dotenv()
    diarize_model = whisperx.diarize.DiarizationPipeline(
        use_auth_token=os.getenv("HUGGINGFACE_TOKEN"), 
        device=device
    )
    punct_model = PunctuationModel(model="kredor/punctuate-all")

    for session in tqdm(sessions, desc="Diarizzazione Speaker"):
        session_path = os.path.join(dataset_dir, session)
        base_name = session.split("_")[0]
        audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
        aligned_path = os.path.join(temp_dir, f"{session}_aligned.json")
        temp_transcript_path = os.path.join(temp_dir, f"{base_name}_TRANSCRIPT.csv")
        transcript_path = os.path.join(session_path, f"{base_name}_TRANSCRIPT.csv")
        old_transcript_path = os.path.join(session_path, f"{base_name}_transcript.csv")
        if os.path.exists(old_transcript_path):
            os.remove(old_transcript_path)
        
        if os.path.exists(temp_transcript_path) or os.path.exists(transcript_path):
            print(f"Skippo {audio_path} perché la trascrizione esiste già in {temp_transcript_path}")
            if not os.path.exists(transcript_path): # per trasferire temp_results fatto in un altra macchina
                shutil.copy2(temp_transcript_path, transcript_path)
            continue
                
        print(f"\nSto diarizzando: {session}")
        with open(aligned_path, 'r', encoding='utf-8') as f:
            aligned_result = json.load(f)
        # `whisperx.align` salva i risultati in 'segments' e 'word_segments'.
        word_timestamps = aligned_result.get("word_segments")

        # Esegui la diarizzazione sull'audio completo
        audio = whisperx.load_audio(audio_path)
        speaker_timestamps = diarize_model(audio, min_speakers=2, max_speakers=3) 

        # Assegna gli speaker alle parole
        wsm = get_words_speaker_mapping(word_timestamps, speaker_timestamps)
        
        words_list = [item['word'] for item in wsm]
        labeled_words = punct_model.predict(words_list)
        # Questa parte serve a migliorare il raggruppamento in frasi
        ending_puncts = ".?!"
        model_puncts = ".,;:!?"
        is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)
        for word_dict, labeled_tuple in zip(wsm, labeled_words):
            word = word_dict["word"]
            if (word and labeled_tuple[1] in ending_puncts and (word[-1] not in model_puncts or is_acronym(word))):
                word += labeled_tuple[1]
                if word.endswith(".."): word = word.rstrip(".")
                word_dict["word"] = word

        wsm = get_realigned_ws_mapping_with_punctuation(wsm)
        ssm = get_sentences_speaker_mapping(wsm)

        # 6. Salva nel formato CSV finale
        final_segments = []
        for s in ssm:
            final_segments.append({"start_time": s["start_time"], "stop_time": s["end_time"], "speaker": s["speaker"], "value": s["text"].strip()})

        df = pd.DataFrame(final_segments)
        if not df.empty:
            df = df[["start_time", "stop_time", "speaker", "value"]]
        else:
            df = pd.DataFrame(columns=["start_time", "stop_time", "speaker", "value"])

        df.to_csv(temp_transcript_path, sep="\t", index=False)
        df.to_csv(transcript_path, sep="\t", index=False)

    del diarize_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [7]:
edaic_dir = "../datasets/EDAIC-WOZ"

sessions = sorted([d for d in os.listdir(edaic_dir) \
                    if os.path.isdir(os.path.join(edaic_dir, d)) and d.endswith('_P')])
transcribe_sessions(edaic_dir, sessions, device)
#align_sessions(edaic_dir, sessions, device)
diarize_sessions(edaic_dir, sessions, device)

  if ismodule(module) and hasattr(module, '__file__'):
Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint c:\Users\anto-\Miniconda3\envs\speech_project\lib\site-packages\whisperx\assets\pytorch_model.bin`


>>Performing voice activity detection using Pyannote...
Model was trained with pyannote.audio 0.0.1, yours is 3.3.2. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.6.0+cu124. Bad things might happen unless you revert torch to 1.x.


Trascrizione Audio: 100%|██████████| 29/29 [00:00<00:00, 28981.37it/s]

Skippo ../datasets/EDAIC-WOZ\602_P\602_AUDIO.wav perché la trascrizione esiste già in ../datasets/EDAIC-WOZ\temp_results\602_P_transcript.json
Skippo ../datasets/EDAIC-WOZ\604_P\604_AUDIO.wav perché la trascrizione esiste già in ../datasets/EDAIC-WOZ\temp_results\604_P_transcript.json
Skippo ../datasets/EDAIC-WOZ\617_P\617_AUDIO.wav perché la trascrizione esiste già in ../datasets/EDAIC-WOZ\temp_results\617_P_transcript.json
Skippo ../datasets/EDAIC-WOZ\624_P\624_AUDIO.wav perché la trascrizione esiste già in ../datasets/EDAIC-WOZ\temp_results\624_P_transcript.json
Skippo ../datasets/EDAIC-WOZ\633_P\633_AUDIO.wav perché la trascrizione esiste già in ../datasets/EDAIC-WOZ\temp_results\633_P_transcript.json
Skippo ../datasets/EDAIC-WOZ\636_P\636_AUDIO.wav perché la trascrizione esiste già in ../datasets/EDAIC-WOZ\temp_results\636_P_transcript.json
Skippo ../datasets/EDAIC-WOZ\637_P\637_AUDIO.wav perché la trascrizione esiste già in ../datasets/EDAIC-WOZ\temp_results\637_P_transcript.json


Device set to use cuda:0
It can be re-enabled by calling
   >>> import torch
   >>> torch.backends.cuda.matmul.allow_tf32 = True
   >>> torch.backends.cudnn.allow_tf32 = True
See https://github.com/pyannote/pyannote-audio/issues/1370 for more details.




Sto diarizzando: 602_P


  std = sequences.std(dim=-1, correction=1)
Diarizzazione Speaker:   3%|▎         | 1/29 [00:15<07:06, 15.23s/it]


Sto diarizzando: 604_P


Diarizzazione Speaker:   7%|▋         | 2/29 [00:26<05:55, 13.18s/it]


Sto diarizzando: 617_P


You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
Diarizzazione Speaker:  10%|█         | 3/29 [00:53<08:18, 19.18s/it]


Sto diarizzando: 624_P


Diarizzazione Speaker:  14%|█▍        | 4/29 [01:09<07:30, 18.02s/it]


Sto diarizzando: 633_P


Diarizzazione Speaker:  17%|█▋        | 5/29 [01:40<09:07, 22.83s/it]


Sto diarizzando: 636_P


Diarizzazione Speaker:  21%|██        | 6/29 [02:13<10:03, 26.25s/it]


Sto diarizzando: 637_P


Diarizzazione Speaker:  24%|██▍       | 7/29 [02:45<10:19, 28.17s/it]


Sto diarizzando: 638_P


Diarizzazione Speaker:  28%|██▊       | 8/29 [02:55<07:49, 22.36s/it]


Sto diarizzando: 640_P


Diarizzazione Speaker:  31%|███       | 9/29 [03:11<06:44, 20.24s/it]


Sto diarizzando: 641_P


Diarizzazione Speaker:  34%|███▍      | 10/29 [03:25<05:48, 18.36s/it]


Sto diarizzando: 649_P


Diarizzazione Speaker:  38%|███▊      | 11/29 [03:44<05:33, 18.51s/it]


Sto diarizzando: 655_P


Diarizzazione Speaker:  41%|████▏     | 12/29 [04:00<05:02, 17.81s/it]


Sto diarizzando: 658_P


Diarizzazione Speaker:  45%|████▍     | 13/29 [04:19<04:49, 18.07s/it]


Sto diarizzando: 659_P


Diarizzazione Speaker:  48%|████▊     | 14/29 [04:47<05:19, 21.28s/it]


Sto diarizzando: 661_P


Diarizzazione Speaker:  52%|█████▏    | 15/29 [05:12<05:11, 22.25s/it]


Sto diarizzando: 673_P


Diarizzazione Speaker:  55%|█████▌    | 16/29 [05:23<04:06, 18.97s/it]


Sto diarizzando: 677_P


Diarizzazione Speaker:  59%|█████▊    | 17/29 [05:41<03:41, 18.46s/it]


Sto diarizzando: 680_P


Diarizzazione Speaker:  62%|██████▏   | 18/29 [06:08<03:51, 21.00s/it]


Sto diarizzando: 682_P


Diarizzazione Speaker:  66%|██████▌   | 19/29 [06:21<03:08, 18.84s/it]


Sto diarizzando: 684_P


Diarizzazione Speaker:  69%|██████▉   | 20/29 [06:42<02:53, 19.33s/it]


Sto diarizzando: 688_P


Diarizzazione Speaker:  72%|███████▏  | 21/29 [07:03<02:39, 19.93s/it]


Sto diarizzando: 689_P


Diarizzazione Speaker:  76%|███████▌  | 22/29 [07:20<02:13, 19.04s/it]


Sto diarizzando: 691_P


Diarizzazione Speaker:  79%|███████▉  | 23/29 [07:45<02:04, 20.76s/it]


Sto diarizzando: 696_P


Diarizzazione Speaker:  83%|████████▎ | 24/29 [07:59<01:34, 18.83s/it]


Sto diarizzando: 698_P


Diarizzazione Speaker:  86%|████████▌ | 25/29 [08:26<01:25, 21.31s/it]


Sto diarizzando: 699_P


Diarizzazione Speaker:  90%|████████▉ | 26/29 [08:39<00:56, 18.86s/it]


Sto diarizzando: 705_P


Diarizzazione Speaker:  93%|█████████▎| 27/29 [08:54<00:35, 17.61s/it]


Sto diarizzando: 709_P


Diarizzazione Speaker:  97%|█████████▋| 28/29 [09:10<00:17, 17.06s/it]


Sto diarizzando: 716_P


Diarizzazione Speaker: 100%|██████████| 29/29 [09:28<00:00, 19.62s/it]


In [9]:
daic_dir = "../datasets/DAIC-WOZ" 

sessions = ["318_P", "321_P", "341_P", "362_P"] # https://github.com/adbailey1/daic_woz_process/tree/master
transcribe_sessions(daic_dir, sessions, device)
align_sessions(daic_dir, sessions, device)
diarize_sessions(daic_dir, sessions, device)

Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint c:\Users\anto-\Miniconda3\envs\speech_project\lib\site-packages\whisperx\assets\pytorch_model.bin`


>>Performing voice activity detection using Pyannote...
Model was trained with pyannote.audio 0.0.1, yours is 3.3.2. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.6.0+cu124. Bad things might happen unless you revert torch to 1.x.


Trascrizione Audio:   0%|          | 0/4 [00:00<?, ?it/s]


Sto processando: ../datasets/DAIC-WOZ\318_P\318_AUDIO.wav


Trascrizione Audio:  25%|██▌       | 1/4 [02:21<07:05, 141.88s/it]

Trascrizione intermedia salvata in: ../datasets/DAIC-WOZ\temp_results\318_P_transcript.json

Sto processando: ../datasets/DAIC-WOZ\321_P\321_AUDIO.wav


Trascrizione Audio:  50%|█████     | 2/4 [05:52<06:04, 182.05s/it]

Trascrizione intermedia salvata in: ../datasets/DAIC-WOZ\temp_results\321_P_transcript.json

Sto processando: ../datasets/DAIC-WOZ\341_P\341_AUDIO.wav


Trascrizione Audio:  75%|███████▌  | 3/4 [09:49<03:27, 207.22s/it]

Trascrizione intermedia salvata in: ../datasets/DAIC-WOZ\temp_results\341_P_transcript.json

Sto processando: ../datasets/DAIC-WOZ\362_P\362_AUDIO.wav


Trascrizione Audio: 100%|██████████| 4/4 [12:20<00:00, 185.11s/it]

Trascrizione intermedia salvata in: ../datasets/DAIC-WOZ\temp_results\362_P_transcript.json



Allineamento Audio:   0%|          | 0/4 [00:00<?, ?it/s]


Sto allineando: 318_P


Allineamento Audio:  25%|██▌       | 1/4 [00:07<00:22,  7.44s/it]

Allineamento salvato in: ../datasets/DAIC-WOZ\temp_results\318_P_aligned.json

Sto allineando: 321_P


Allineamento Audio:  50%|█████     | 2/4 [00:16<00:16,  8.12s/it]

Allineamento salvato in: ../datasets/DAIC-WOZ\temp_results\321_P_aligned.json

Sto allineando: 341_P


Allineamento Audio:  75%|███████▌  | 3/4 [00:25<00:08,  8.83s/it]

Allineamento salvato in: ../datasets/DAIC-WOZ\temp_results\341_P_aligned.json

Sto allineando: 362_P


Allineamento Audio: 100%|██████████| 4/4 [00:31<00:00,  7.98s/it]

Allineamento salvato in: ../datasets/DAIC-WOZ\temp_results\362_P_aligned.json



Device set to use cuda:0
Diarizzazione Speaker:   0%|          | 0/4 [00:00<?, ?it/s]


Sto diarizzando: 318_P


  std = sequences.std(dim=-1, correction=1)
Diarizzazione Speaker:  25%|██▌       | 1/4 [00:11<00:33, 11.05s/it]


Sto diarizzando: 321_P


Diarizzazione Speaker:  50%|█████     | 2/4 [00:26<00:27, 13.67s/it]


Sto diarizzando: 341_P


Diarizzazione Speaker:  75%|███████▌  | 3/4 [00:43<00:14, 14.98s/it]


Sto diarizzando: 362_P


Diarizzazione Speaker: 100%|██████████| 4/4 [00:54<00:00, 13.57s/it]
