In [None]:
import os
import torch
import shutil
import pandas as pd
from tqdm import tqdm
from dotenv import load_dotenv
import whisperx
import gc
import json
import re
import nltk
from deepmultilingualpunctuation import PunctuationModel

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# https://github.com/MahmoudAshraf97/whisper-diarization/blob/main/Whisper_Transcription_%2B_NeMo_Diarization.ipynb
def get_words_speaker_mapping(wrd_ts, spk_ts):
    """Label every transcribed word with the speaker turn it starts in.

    Words and turns are walked together with a single forward-only pointer:
    the pointer advances while the word starts after the current turn ends,
    and the word then takes that turn's speaker. A word that starts after
    every turn has ended is given the last turn's speaker.

    Args:
        wrd_ts: Iterable of dicts with a "word" key and, normally, "start"
            and "end" keys. A word without "start" is placed at the end time
            of the previous word (0.0 for the first word); a word without
            "end" ends where it starts. This keeps un-timestamped tokens in
            the transcript instead of failing with ``KeyError``.
        spk_ts: pandas DataFrame with "end" and "speaker" columns, one row
            per speaker turn. Rows are assumed to be in chronological
            order -- TODO confirm against the diarization output.

    Returns:
        A list with one dict per input word, with the keys "word",
        "start_time", "end_time" and "speaker".

    Raises:
        IndexError: If there is at least one word but ``spk_ts`` has no rows.
    """
    # Extract the two columns once. Going through ``spk_ts.iloc[i]`` builds a
    # whole row Series for every loop test, which is slow on long recordings.
    if len(spk_ts) == 0:
        turn_ends, turn_speakers = [], []
    else:
        turn_ends = spk_ts["end"].tolist()
        turn_speakers = spk_ts["speaker"].tolist()
    last_turn = len(turn_ends) - 1

    turn_idx = 0
    prev_end = 0.0
    wrd_spk_mapping = []

    for wrd_dict in wrd_ts:
        if not turn_speakers:
            raise IndexError("spk_ts contains no speaker turns; cannot label words")

        wrd = wrd_dict["word"]

        # Fall back to the neighbouring timestamp when one is missing.
        ws = wrd_dict.get("start")
        if ws is None:
            ws = prev_end
        we = wrd_dict.get("end")
        if we is None:
            we = ws
        prev_end = we

        # Never step past the last turn, so late words still get a speaker.
        while turn_idx < last_turn and ws > turn_ends[turn_idx]:
            turn_idx += 1

        wrd_spk_mapping.append({
            "word": wrd,
            "start_time": ws,
            "end_time": we,
            "speaker": turn_speakers[turn_idx],
        })

    return wrd_spk_mapping

sentence_ending_punctuations = ".?!"

def get_realigned_ws_mapping_with_punctuation(word_speaker_mapping, max_words_in_sentence=60):
    """Smooth speaker labels so a speaker change does not split a sentence.

    Whenever two consecutive words have different speakers and the first of
    them does not end a sentence, the enclosing sentence is located and all of
    its words are relabelled with the sentence's most frequent speaker. The
    relabelling is applied only when that speaker owns at least half of the
    sentence's words.

    Args:
        word_speaker_mapping: List of dicts with at least "word" and "speaker"
            keys, in transcript order.
        max_words_in_sentence: Maximum number of words to scan to the left and
            to the right of the speaker change while looking for the sentence
            boundaries.

    Returns:
        A new list of shallow-copied dicts with the "speaker" values updated.
        The input list and its dicts are not modified.
    """
    words_list = [d["word"] for d in word_speaker_mapping]
    speaker_list = [d["speaker"] for d in word_speaker_mapping]

    def is_word_sentence_end(idx):
        # An empty token has no last character, so it never ends a sentence.
        if idx < 0:
            return False
        word = words_list[idx]
        return bool(word) and word[-1] in sentence_ending_punctuations

    def get_first_word_idx(current_idx):
        # Walk left while still inside the same speaker's run and the same
        # sentence. Return -1 when no sentence start was reached.
        left_idx = current_idx
        while (left_idx > 0 and
               current_idx - left_idx < max_words_in_sentence and
               speaker_list[left_idx - 1] == speaker_list[left_idx] and
               not is_word_sentence_end(left_idx - 1)):
            left_idx -= 1
        return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1

    def get_last_word_idx(current_idx):
        # Walk right to the word that ends the sentence, whatever its speaker.
        # Return -1 when no sentence end was reached within the limit.
        right_idx = current_idx
        while (right_idx < len(words_list) - 1 and
               right_idx - current_idx < max_words_in_sentence and
               not is_word_sentence_end(right_idx)):
            right_idx += 1
        return right_idx if right_idx == len(words_list) - 1 or is_word_sentence_end(right_idx) else -1

    # Scan for mid-sentence speaker changes and repair them.
    k = 0
    while k < len(word_speaker_mapping) - 1:
        if speaker_list[k] != speaker_list[k + 1] and not is_word_sentence_end(k):
            left_idx = get_first_word_idx(k)
            right_idx = get_last_word_idx(k) if left_idx > -1 else -1

            if left_idx != -1 and right_idx != -1:
                sub_speaker_list = speaker_list[left_idx : right_idx + 1]
                # dict.fromkeys keeps first-appearance order, so a tie always
                # goes to the speaker seen first in the sentence. Iterating a
                # set of strings would make ties depend on hash randomization
                # and give different transcripts from one run to the next.
                dominant_speaker = max(dict.fromkeys(sub_speaker_list), key=sub_speaker_list.count)
                # Only relabel when the dominant speaker owns at least half
                # of the sentence.
                if sub_speaker_list.count(dominant_speaker) >= len(sub_speaker_list) / 2:
                    for i in range(left_idx, right_idx + 1):
                        speaker_list[i] = dominant_speaker
                # Resume after this sentence whether or not it was relabelled.
                k = right_idx
        k += 1

    # Build the realigned list without touching the caller's dicts.
    realigned_list = []
    for i, d in enumerate(word_speaker_mapping):
        new_dict = d.copy()
        new_dict["speaker"] = speaker_list[i]
        realigned_list.append(new_dict)

    return realigned_list


def get_sentences_speaker_mapping(word_speaker_mapping):
    """
    Group speaker-labelled words into sentences.

    A new sentence is opened for the first word, whenever the speaker differs
    from the speaker of the sentence being built, or whenever NLTK's Punkt
    tokenizer reports a sentence break in the accumulated text followed by the
    next word. Otherwise the word is appended to the open sentence and the
    sentence's end time is moved to the end of that word.

    Each returned sentence is a dict with the keys "speaker", "start_time",
    "end_time" and "text" (the words joined by single spaces).
    """
    contains_break = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak

    grouped = []
    open_sentence = None

    for entry in word_speaker_mapping:
        token = entry["word"]
        who = entry["speaker"]
        t_start = entry["start_time"]
        t_end = entry["end_time"]

        # Same speaker and no sentence break yet: keep extending the sentence.
        if (
            open_sentence is not None
            and who == open_sentence["speaker"]
            and not contains_break(open_sentence["text"] + " " + token)
        ):
            open_sentence["end_time"] = t_end
            open_sentence["text"] += " " + token
            continue

        # Otherwise close whatever was open and start afresh with this word.
        if open_sentence is not None:
            grouped.append(open_sentence)
        open_sentence = {
            "speaker": who,
            "start_time": t_start,
            "end_time": t_end,
            "text": token,
        }

    # Flush the sentence still being built when the words run out.
    if open_sentence is not None:
        grouped.append(open_sentence)

    return grouped

In [None]:
def transcribe_sessions(dataset_dir, sessions, device, language="en",
                        model_id="tiny.en", batch_size=32, compute_type=None):
    """Transcribe each session's audio with WhisperX and cache the raw result.

    For every session the audio is read from
    ``<dataset_dir>/<session>/<base>_AUDIO.wav``, where ``<base>`` is the part
    of the session name before the first underscore. The transcription result
    is written as JSON to
    ``<dataset_dir>/temp_results/<session>_transcript.json``. Sessions whose
    JSON file already exists are skipped, so the function can be re-run to
    resume.

    Args:
        dataset_dir: Root directory of the dataset.
        sessions: Iterable of session directory names.
        device: Device string passed to WhisperX (e.g. "cuda" or "cpu").
        language: Language code passed to ``whisperx.load_model``.
        model_id: Whisper checkpoint to load. NOTE(review): the ".en"
            checkpoints are English-only; pass a multilingual checkpoint when
            ``language`` is not "en".
        batch_size: Batch size passed to ``model.transcribe``.
        compute_type: Compute type for the model. ``None`` selects "float16"
            when ``device`` is a CUDA device and "float32" otherwise.
    """
    if compute_type is None:
        # Decide from the device actually requested, not from whether CUDA
        # happens to be installed. Otherwise device="cpu" on a GPU machine
        # would ask for half precision on the CPU.
        compute_type = "float16" if str(device).startswith("cuda") else "float32"
    temp_dir = os.path.join(dataset_dir, "temp_results")
    os.makedirs(temp_dir, exist_ok=True)

    model = whisperx.load_model(model_id, device, compute_type=compute_type, language=language)

    try:
        for session in tqdm(sessions, desc="Trascrizione Audio"):
            session_path = os.path.join(dataset_dir, session)
            base_name = session.split("_")[0]
            audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
            intermediate_path = os.path.join(temp_dir, f"{session}_transcript.json")
            if os.path.exists(intermediate_path):
                print(f"Skippo {audio_path} perché la trascrizione esiste già in {intermediate_path}")
                continue
            print(f"\nSto processando: {audio_path}")

            audio = whisperx.load_audio(audio_path)
            result = model.transcribe(audio, batch_size=batch_size)

            # Write to a sibling file and rename it into place. An interrupted
            # run then never leaves a truncated JSON that the existence check
            # above would treat as a finished transcription.
            partial_path = intermediate_path + ".tmp"
            with open(partial_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=4)
            os.replace(partial_path, intermediate_path)
            print(f"Trascrizione intermedia salvata in: {intermediate_path}")
    finally:
        # Release the model even when a session fails, so the next pipeline
        # stage does not start with this model still held in memory.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
def align_sessions(dataset_dir, sessions, device, language="en"):
    """Align cached transcriptions to the audio and cache the aligned result.

    For every session this reads
    ``<dataset_dir>/temp_results/<session>_transcript.json``, aligns its
    "segments" against ``<dataset_dir>/<session>/<base>_AUDIO.wav`` with
    ``whisperx.align``, and writes the result as JSON to
    ``<dataset_dir>/temp_results/<session>_aligned.json``. Sessions whose
    aligned file already exists are skipped.

    Args:
        dataset_dir: Root directory of the dataset.
        sessions: Iterable of session directory names.
        device: Device string passed to WhisperX.
        language: Language code used to pick the alignment model.

    Raises:
        FileNotFoundError: If a session has no cached transcription JSON.
    """
    temp_dir = os.path.join(dataset_dir, "temp_results")
    model_a, metadata = whisperx.load_align_model(language_code=language, device=device)

    try:
        for session in tqdm(sessions, desc="Allineamento Audio"):
            intermediate_path = os.path.join(temp_dir, f"{session}_transcript.json")
            aligned_path = os.path.join(temp_dir, f"{session}_aligned.json")
            session_path = os.path.join(dataset_dir, session)
            base_name = session.split("_")[0]
            audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
            if os.path.exists(aligned_path):
                print(f"Skippo {audio_path} perché l'allineamento esiste già in {aligned_path}")
                continue
            print(f"\nSto allineando: {session}")

            # Load the cached transcription produced by the previous stage.
            with open(intermediate_path, 'r', encoding='utf-8') as f:
                result = json.load(f)

            audio = whisperx.load_audio(audio_path)

            # Run the forced alignment.
            result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

            # Write to a sibling file and rename it into place, so an
            # interrupted run cannot leave a truncated JSON that the existence
            # check above would treat as a finished alignment.
            partial_path = aligned_path + ".tmp"
            with open(partial_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=4)
            os.replace(partial_path, aligned_path)
            print(f"Allineamento salvato in: {aligned_path}")
    finally:
        # Release the alignment model even when a session fails.
        del model_a
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
def diarize_sessions(dataset_dir, sessions, device, min_speakers=2, max_speakers=3):
    """Diarize each session and write a speaker-labelled sentence transcript.

    For every session this:
      1. loads the aligned words from
         ``<dataset_dir>/temp_results/<session>_aligned.json``;
      2. runs the WhisperX diarization pipeline on the session audio;
      3. assigns a speaker to every word;
      4. restores sentence-ending punctuation with the punctuation model;
      5. realigns speaker labels to sentence boundaries and groups the words
         into sentences;
      6. writes a tab-separated CSV with the columns ``start_time``,
         ``stop_time``, ``speaker`` and ``value`` to both
         ``temp_results/<base>_TRANSCRIPT.csv`` and
         ``<session>/<base>_TRANSCRIPT.csv``.

    Sessions that already have either CSV are skipped. If only the cached copy
    exists, it is copied into the session directory.

    Args:
        dataset_dir: Root directory of the dataset.
        sessions: Iterable of session directory names.
        device: Device string passed to the diarization pipeline.
        min_speakers: Lower bound on the number of speakers for diarization.
        max_speakers: Upper bound on the number of speakers for diarization.
    """
    temp_dir = os.path.join(dataset_dir, "temp_results")
    load_dotenv()
    diarize_model = whisperx.diarize.DiarizationPipeline(
        use_auth_token=os.getenv("HUGGINGFACE_TOKEN"),
        device=device
    )
    punct_model = PunctuationModel(model="kredor/punctuate-all")

    # Loop-invariant helpers for the punctuation-restoration step.
    ending_puncts = ".?!"
    model_puncts = ".,;:!?"
    is_acronym = re.compile(r"\b(?:[a-zA-Z]\.){2,}").fullmatch

    try:
        for session in tqdm(sessions, desc="Diarizzazione Speaker"):
            session_path = os.path.join(dataset_dir, session)
            base_name = session.split("_")[0]
            audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
            aligned_path = os.path.join(temp_dir, f"{session}_aligned.json")
            temp_transcript_path = os.path.join(temp_dir, f"{base_name}_TRANSCRIPT.csv")
            transcript_path = os.path.join(session_path, f"{base_name}_TRANSCRIPT.csv")
            old_transcript_path = os.path.join(session_path, f"{base_name}_transcript.csv")
            # NOTE(review): old_transcript_path differs from transcript_path
            # only in letter case. On a case-insensitive filesystem both names
            # refer to the same file, so this removes the current transcript
            # too -- confirm this is intended.
            if os.path.exists(old_transcript_path):
                os.remove(old_transcript_path)

            if os.path.exists(temp_transcript_path) or os.path.exists(transcript_path):
                print(f"Skippo {audio_path} perché la trascrizione esiste già in {temp_transcript_path}")
                # Lets a temp_results folder produced on another machine
                # populate the session directories.
                if not os.path.exists(transcript_path):
                    shutil.copy2(temp_transcript_path, transcript_path)
                continue

            print(f"\nSto diarizzando: {session}")
            with open(aligned_path, 'r', encoding='utf-8') as f:
                aligned_result = json.load(f)
            # The aligned JSON is expected to carry the per-word list under
            # 'word_segments'. Treat a missing or null entry as "no words"
            # rather than failing later when iterating None.
            word_timestamps = aligned_result.get("word_segments") or []

            # Diarize the full audio.
            audio = whisperx.load_audio(audio_path)
            speaker_timestamps = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

            # Assign a speaker to each word.
            wsm = get_words_speaker_mapping(word_timestamps, speaker_timestamps)

            words_list = [item['word'] for item in wsm]
            # Skip the punctuation model when there is nothing to punctuate.
            labeled_words = punct_model.predict(words_list) if words_list else []
            # Append predicted sentence-ending punctuation to words that have
            # none (or that are acronyms), which improves sentence grouping.
            for word_dict, labeled_tuple in zip(wsm, labeled_words):
                word = word_dict["word"]
                if (word and labeled_tuple[1] in ending_puncts and (word[-1] not in model_puncts or is_acronym(word))):
                    word += labeled_tuple[1]
                    # An acronym already ends in "."; avoid doubling it.
                    if word.endswith(".."):
                        word = word.rstrip(".")
                    word_dict["word"] = word

            wsm = get_realigned_ws_mapping_with_punctuation(wsm)
            # get_sentences_speaker_mapping takes only the word mapping.
            # Passing speaker_timestamps as well raised TypeError on every
            # session.
            ssm = get_sentences_speaker_mapping(wsm)

            # Save in the final CSV format. Passing ``columns`` fixes the
            # column order and keeps the header when there are no sentences.
            final_segments = [
                {"start_time": s["start_time"], "stop_time": s["end_time"], "speaker": s["speaker"], "value": s["text"].strip()}
                for s in ssm
            ]
            df = pd.DataFrame(final_segments, columns=["start_time", "stop_time", "speaker", "value"])

            # Write each file next to its target and rename it into place, so
            # an interrupted run cannot leave a truncated CSV that the skip
            # check would accept.
            for target_path in (temp_transcript_path, transcript_path):
                partial_path = target_path + ".tmp"
                df.to_csv(partial_path, sep="\t", index=False)
                os.replace(partial_path, target_path)
    finally:
        # Release both models (the punctuation model was previously never
        # freed), even when a session fails.
        del diarize_model, punct_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
edaic_dir = "../datasets/EDAIC-WOZ"

# Every sub-directory whose name ends in "_P" is treated as one session.
sessions = sorted(
    entry
    for entry in os.listdir(edaic_dir)
    if os.path.isdir(os.path.join(edaic_dir, entry)) and entry.endswith('_P')
)

# Run the three stages in order; each one reads the files cached by the previous one.
transcribe_sessions(dataset_dir=edaic_dir, sessions=sessions, device=device)
align_sessions(dataset_dir=edaic_dir, sessions=sessions, device=device)
diarize_sessions(dataset_dir=edaic_dir, sessions=sessions, device=device)

In [None]:
daic_dir = "../datasets/DAIC-WOZ"

# Only these four sessions are processed for DAIC-WOZ.
# NOTE(review): presumably the sessions singled out by the project linked below -- confirm.
# https://github.com/adbailey1/daic_woz_process/tree/master
sessions = [
    "318_P",
    "321_P",
    "341_P",
    "362_P",
]

# Same three-stage pipeline as for E-DAIC.
transcribe_sessions(dataset_dir=daic_dir, sessions=sessions, device=device)
align_sessions(dataset_dir=daic_dir, sessions=sessions, device=device)
diarize_sessions(dataset_dir=daic_dir, sessions=sessions, device=device)