In [None]:
import os
import torch
import shutil
import torchaudio
import pandas as pd
import sounddevice as sd
from dotenv import load_dotenv
from tqdm import tqdm
import whisperx
import gc
import json

In [None]:
base_dir = "../datasets/EDAIC-WOZ"

# Step 1: delete the session folders 300_P .. 492_P (upper bound inclusive).
for participant_number in range(300, 493):
    folder_path = os.path.join(base_dir, f"{participant_number}_P")
    if os.path.exists(folder_path):
        shutil.rmtree(folder_path)

# Step 2: flatten the doubled layout XXX_P/XXX_P/<files> into XXX_P/<files>.
for dir_name in os.listdir(base_dir):
    outer_path = os.path.join(base_dir, dir_name)
    inner_path = os.path.join(outer_path, dir_name)
    if not (os.path.isdir(outer_path) and os.path.isdir(inner_path)):
        continue
    # Move every entry of the nested folder one level up ...
    for filename in os.listdir(inner_path):
        shutil.move(
            os.path.join(inner_path, filename),
            os.path.join(outer_path, filename),
        )
    # ... then drop the nested folder, which is empty at this point.
    os.rmdir(inner_path)

# Step 3: merge the three split files into a single table.
csv_files = ['dev_split.csv', 'test_split.csv', 'train_split.csv']
dfs = [pd.read_csv(os.path.join(base_dir, csv_file)) for csv_file in csv_files]
all_data = pd.concat(dfs, ignore_index=True)

# Step 4: drop the rows of participants 300..492 (between() is inclusive on both ends),
# mirroring the folders removed in step 1.
in_removed_range = all_data['Participant_ID'].between(300, 492)
all_data = all_data[~in_removed_range]

# Step 5: persist the merged, filtered table next to the split files.
output_path = os.path.join(base_dir, "all_data.csv")
all_data.to_csv(output_path, index=False)

In [None]:
# Reconcile PHQ labels: a row with PHQ_Score >= 10 that still carries
# PHQ_Binary = 0 is treated as mislabelled and flipped to PHQ_Binary = 1.
score_at_or_above_cutoff = all_data['PHQ_Score'] >= 10
labelled_negative = all_data['PHQ_Binary'] == 0
inconsistent_mask = score_at_or_above_cutoff & labelled_negative
inconsistent_participants = all_data.loc[inconsistent_mask, 'Participant_ID'].tolist()

print(f"Found {len(inconsistent_participants)} participants with inconsistent PHQ labels:")
for participant_id in inconsistent_participants:
    # Report the score of the first row recorded for this participant.
    rows_of_participant = all_data['Participant_ID'] == participant_id
    phq_score = all_data.loc[rows_of_participant, 'PHQ_Score'].iloc[0]
    print(f"  Participant {participant_id}: PHQ_Score={phq_score}, PHQ_Binary=0 -> fixing to PHQ_Binary=1")

# Apply the correction to every flagged row in one assignment.
all_data.loc[inconsistent_mask, 'PHQ_Binary'] = 1

print(f"\nFixed {len(inconsistent_participants)} inconsistent labels")

# Overwrite the merged CSV with the corrected labels.
all_data.to_csv(output_path, index=False)

In [None]:
# Remove the session folder of every non-depressed participant (PHQ_Binary == 0),
# then keep only the depressed participants in the merged CSV.
is_non_depressed = all_data['PHQ_Binary'] == 0
non_depressed_participants = all_data.loc[is_non_depressed, 'Participant_ID'].unique()

print(f"Found {len(non_depressed_participants)} non-depressed participants to remove:")

deleted_count = 0
for participant_id in non_depressed_participants:
    folder_path = os.path.join(base_dir, f"{participant_id}_P")
    if not os.path.exists(folder_path):
        # Folder already gone: nothing to delete, nothing to count.
        continue
    shutil.rmtree(folder_path)
    deleted_count += 1

print(f"Deleted {deleted_count} directories for non-depressed participants")

# Keep only rows labelled PHQ_Binary == 1 and overwrite the CSV.
all_data = all_data[all_data['PHQ_Binary'] == 1]
all_data.to_csv(output_path, index=False)

In [None]:
# --- Pipeline configuration -------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Half precision is only selected when a CUDA device is available.
compute_type = "float16" if torch.cuda.is_available() else "float32"
batch_size = 32
model_id = "tiny.en"
language = "en"

# Scratch directory holding one intermediate JSON transcript per session.
temp_dir = os.path.join(base_dir, "temp_results")
os.makedirs(temp_dir, exist_ok=True)

# Upper bound on the number of sessions pushed through the pipeline.
# 1 keeps the run short while developing; set to None to process every session.
max_sessions = 1

# A session is any sub-directory of base_dir whose name ends in "_P"
# (this also keeps temp_dir itself out of the list).
sessions = sorted(
    d for d in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, d)) and d.endswith('_P')
)
sessions = sessions[:max_sessions]

# Speech-recognition model used by the transcription stage.
model = whisperx.load_model(model_id, device, compute_type=compute_type, language=language)

# --- Stage 1: transcription ---------------------------------------------------
# For each session, transcribe <id>_AUDIO.wav and store the raw result as JSON
# in temp_dir so the later stages can pick it up.
for session in tqdm(sessions, desc="Trascrizione Audio"):
    base_name = session.split("_")[0]
    session_path = os.path.join(base_dir, session)
    audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
    intermediate_path = os.path.join(temp_dir, f"{session}_transcript.json")

    if os.path.exists(audio_path):
        print(f"\nSto processando: {audio_path}")

        audio = whisperx.load_audio(audio_path)
        result = model.transcribe(audio, batch_size=batch_size)
        print(result)

        with open(intermediate_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=4)
        print(f"Trascrizione intermedia salvata in: {intermediate_path}")
    else:
        # No audio for this session: report it and move on to the next one.
        print(f"Audio non trovato: {audio_path}. Salto la sessione.")

# Release the recognition model before the alignment model is loaded.
del model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# --- Stage 2: word-level alignment ---------------------------------------------
model_a, metadata = whisperx.load_align_model(language_code=language, device=device)

for session in tqdm(sessions, desc="Allineamento Audio"):
    intermediate_path = os.path.join(temp_dir, f"{session}_transcript.json")
    if not os.path.exists(intermediate_path):
        print(f"File intermedio non trovato: {intermediate_path}. Salto.")
        continue
    print(f"\nSto allineando: {session}")

    # Read back the transcription written by the transcription stage.
    with open(intermediate_path, 'r', encoding='utf-8') as f:
        result = json.load(f)

    # An empty transcription has nothing to align; its JSON is left untouched.
    if not result["segments"]:
        print("Nessun segmento da allineare.")
        continue

    # Alignment works on the waveform, so the audio has to be loaded again.
    base_name = session.split("_")[0]
    session_path = os.path.join(base_dir, session)
    audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
    audio = whisperx.load_audio(audio_path)

    result = whisperx.align(
        result["segments"],
        model_a,
        metadata,
        audio,
        device,
        return_char_alignments=False,
    )
    print(result)

    # Overwrite the intermediate file with the aligned result, dumped exactly as
    # returned by whisperx.align (no keys are renamed or dropped here).
    with open(intermediate_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=4)

    print(f"Allineamento salvato in: {intermediate_path}")

# Release the alignment model before the diarization pipeline is loaded.
del model_a
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# --- Stage 3: speaker diarization and per-speaker transcript assembly ----------
# load_dotenv() is called so that HUGGINGFACE_TOKEN can be supplied through a
# .env file; os.getenv returns None when the variable is not set.
load_dotenv()
diarize_model = whisperx.diarize.DiarizationPipeline(
    use_auth_token=os.getenv("HUGGINGFACE_TOKEN"),
    device=device
)

# Column layout of every transcript written below (header-only files included).
transcript_columns = ["start_time", "stop_time", "speaker", "value"]

for session in tqdm(sessions, desc="Diarizzazione Speaker"):
    session_path = os.path.join(base_dir, session)
    base_name = session.split("_")[0]
    audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
    intermediate_path = os.path.join(temp_dir, f"{session}_transcript.json")
    transcript_path = os.path.join(session_path, f"{base_name}_TRANSCRIPT.csv")

    if not os.path.exists(intermediate_path):
        print(f"File intermedio non trovato: {intermediate_path}. Salto.")
        continue

    print(f"\nSto diarizzando: {session}")

    # Load the (aligned) transcription stored in the intermediate JSON file.
    with open(intermediate_path, 'r', encoding='utf-8') as f:
        result = json.load(f)

    # No segments at all: still write a header-only transcript.
    if not result["segments"]:
        print(f"Nessun segmento rilevato in {audio_path}. Salvataggio di un transcript vuoto.")
        pd.DataFrame(columns=transcript_columns).to_csv(transcript_path, sep="\t", index=False)
        continue

    # Diarize the full recording, then attach a speaker to each aligned word.
    audio = whisperx.load_audio(audio_path)
    diarize_segments = diarize_model(audio, min_speakers=2, max_speakers=4)

    result = whisperx.assign_word_speakers(diarize_segments, result)
    print(result)

    # Flatten the per-segment word lists into one chronological word list.
    all_words = []
    for segment in result["segments"]:
        all_words.extend(segment.get("words", []))

    if not all_words:
        print(f"Nessuna parola trovata dopo l'allineamento per {session}. Salto la creazione dei segmenti.")
        pd.DataFrame(columns=transcript_columns).to_csv(transcript_path, sep="\t", index=False)
        continue

    # Merge consecutive words of the same speaker into one transcript row.
    final_segments = []
    current_segment = None
    last_word_end = None
    # Text of untimed words met before the first row is opened.
    pending_words = []

    for word_info in all_words:
        if 'start' not in word_info:
            # The aligner leaves some tokens (e.g. digits) without timestamps.
            # Keep their text with the surrounding row instead of dropping it;
            # row timing is not affected.
            untimed_text = word_info.get("word")
            if untimed_text:
                if current_segment is None:
                    pending_words.append(untimed_text)
                else:
                    current_segment["value"] += " " + untimed_text
            continue

        # A timed word still needs an end time and a speaker to be placed.
        if 'speaker' not in word_info or 'end' not in word_info:
            continue

        speaker = word_info["speaker"]

        if current_segment is None:
            # First usable word: open the first row, prefixed by any untimed
            # words that preceded it.
            current_segment = {
                "start_time": word_info["start"],
                "speaker": speaker,
                "value": " ".join(pending_words + [word_info["word"]]),
            }
            pending_words = []
        elif speaker != current_segment["speaker"]:
            # Speaker change: close the running row at the previous word's end.
            current_segment["stop_time"] = last_word_end
            final_segments.append(current_segment)
            current_segment = {"start_time": word_info["start"], "speaker": speaker, "value": word_info["word"]}
        else:
            current_segment["value"] += " " + word_info["word"]

        last_word_end = word_info["end"]

    # Close the row that was still open when the words ran out.
    if current_segment is not None:
        current_segment["stop_time"] = last_word_end
        final_segments.append(current_segment)

    # Build the final table with a fixed column order.
    if final_segments:
        df = pd.DataFrame(final_segments)[transcript_columns]
    else:
        print(f"Non è stato possibile creare segmenti finali per {session}. Salvo un file vuoto.")
        df = pd.DataFrame(columns=transcript_columns)

    df.to_csv(transcript_path, sep="\t", index=False)
    print(f"Output finale salvato in: {transcript_path}")

    # (Optional) remove the intermediate file once it has been consumed:
    # os.remove(intermediate_path)

# Release the diarization pipeline.
del diarize_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
def play_audio_segment(audio_tensor, sample_rate):
    """Play a waveform tensor with sounddevice and wait for playback to end.

    Args:
        audio_tensor: torch tensor laid out as [channels, samples].
        sample_rate: sampling rate passed to ``sd.play``.
    """
    # detach()/cpu() leave a plain CPU tensor unchanged, but let tensors that
    # require grad or live on another device be converted to NumPy as well.
    # The transpose yields the (samples, channels) layout handed to sd.play.
    audio_np = audio_tensor.detach().cpu().numpy().T
    sd.play(audio_np, sample_rate)
    sd.wait()

def label_speakers(audio_path, transcript_path):
    """Interactively rename the speakers of one session transcript.

    For every distinct speaker in the tab-separated transcript, the audio of
    that speaker's first row is played and the user is asked to classify it.
    'e' maps to "Ellie", 'p' to "Participant", any other answer to "ignore".
    The transcript file is then overwritten in place with the new labels.

    Args:
        audio_path: path of the session audio, loaded with torchaudio.
        transcript_path: path of the tab-separated transcript with the columns
            start_time, stop_time and speaker (times are multiplied by the
            sample rate, i.e. treated as seconds).
    """
    df = pd.read_csv(transcript_path, sep="\t")
    waveform, sample_rate = torchaudio.load(audio_path)

    speaker_labels = {}
    # Skip missing speaker values: NaN never compares equal to itself, so the
    # row lookup below would come back empty and .iloc[0] would raise.
    speakers = df['speaker'].dropna().unique()

    for spk in speakers:
        # Use the first row attributed to this speaker as the audio sample.
        first_seg = df[df['speaker'] == spk].iloc[0]
        start_sample = int(first_seg['start_time'] * sample_rate)
        end_sample = int(first_seg['stop_time'] * sample_rate)

        segment_audio = waveform[:, start_sample:end_sample]

        print(f"\nSpeaker: {spk} - playing audio segment from {start_sample/sample_rate:.2f}s to {end_sample/sample_rate:.2f}s")
        play_audio_segment(segment_audio, sample_rate)

        choice = input("Label this speaker as (E)llie, (P)articipant, (O)ther: ").strip().lower()
        if choice == 'e':
            speaker_labels[spk] = "Ellie"
        elif choice == 'p':
            speaker_labels[spk] = "Participant"
        else:
            speaker_labels[spk] = "ignore"

    # Rows whose speaker was missing stay missing after the mapping.
    df['speaker'] = df['speaker'].map(speaker_labels)
    df.to_csv(transcript_path, sep="\t", index=False)
    print(f"Updated transcript saved to {transcript_path}")

# Process all sessions for speaker labeling.
# Only participant folders ("<id>_P") count as sessions, the same rule the
# transcription cell uses; this keeps helper directories such as temp_results
# out of the loop.
sessions = sorted(
    d for d in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, d)) and d.endswith('_P')
)

for session in tqdm(sessions, desc="Labeling speakers"):
    session_path = os.path.join(base_dir, session)
    base_name = session.split("_")[0]
    audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
    transcript_path = os.path.join(session_path, f"{base_name}_TRANSCRIPT.csv")

    if os.path.exists(audio_path) and os.path.exists(transcript_path):
        print(f"\n=== Processing session {session} ===")
        try:
            label_speakers(audio_path, transcript_path)
        except Exception as e:
            # Best effort: report the failure and carry on with the next session.
            print(f"Error processing {session}: {e}")
    else:
        print(f"Skipping {session}: missing audio or transcript file")