In [None]:
import os
import shutil
import pandas as pd
from pyannote.audio import Pipeline
from dotenv import load_dotenv
from tqdm import tqdm
import torch
import torchaudio
import sounddevice as sd
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

In [None]:
base_dir = "../datasets/EDAIC-WOZ"

# Participant IDs in this inclusive range are excluded from the dataset:
# their session folders are deleted and their rows are dropped from the
# merged CSV. Kept in one place so both steps always agree.
EXCLUDED_ID_FIRST = 300
EXCLUDED_ID_LAST = 492

# Remove the session folders of the excluded participants.
for i in range(EXCLUDED_ID_FIRST, EXCLUDED_ID_LAST + 1):
    folder_name = f"{i}_P"
    folder_path = os.path.join(base_dir, folder_name)
    if os.path.exists(folder_path):
        shutil.rmtree(folder_path)

# Flatten the nested layout XXX_P/XXX_P -> XXX_P.
for dir_name in os.listdir(base_dir):
    outer_path = os.path.join(base_dir, dir_name)
    if os.path.isdir(outer_path):
        inner_path = os.path.join(outer_path, dir_name)
        if os.path.isdir(inner_path):
            # Move every entry from the second level up to the first.
            for filename in os.listdir(inner_path):
                src = os.path.join(inner_path, filename)
                dst = os.path.join(outer_path, filename)
                shutil.move(src, dst)
            # Remove the now-empty inner folder.
            os.rmdir(inner_path)

# Split files that are concatenated into a single table.
csv_files = ['dev_split.csv', 'test_split.csv', 'train_split.csv']
dfs = [pd.read_csv(os.path.join(base_dir, csv_file)) for csv_file in csv_files]

all_data = pd.concat(dfs, ignore_index=True)

# Drop the rows of the excluded participants. The explicit .copy() makes
# all_data an independent frame, so the later in-place `.loc[...] = ...`
# label fix does not trigger pandas' SettingWithCopyWarning.
all_data = all_data[
    ~all_data['Participant_ID'].between(EXCLUDED_ID_FIRST, EXCLUDED_ID_LAST)
].copy()

# Save the merged, filtered table next to the split files.
output_path = os.path.join(base_dir, "all_data.csv")
all_data.to_csv(output_path, index=False)

In [None]:
# Reconcile the PHQ labels: rows whose PHQ_Score is 10 or more but whose
# PHQ_Binary is still 0 are reported and then switched to PHQ_Binary = 1.
inconsistent_mask = (all_data['PHQ_Binary'] == 0) & (all_data['PHQ_Score'] >= 10)
inconsistent_participants = all_data.loc[inconsistent_mask, 'Participant_ID'].tolist()

print(f"Found {len(inconsistent_participants)} participants with inconsistent PHQ labels:")
for participant_id in inconsistent_participants:
    # Report the score of the first row recorded for this participant.
    phq_score = all_data.loc[all_data['Participant_ID'] == participant_id, 'PHQ_Score'].iloc[0]
    print(f"  Participant {participant_id}: PHQ_Score={phq_score}, PHQ_Binary=0 -> fixing to PHQ_Binary=1")

# Apply the correction to every flagged row at once.
all_data.loc[inconsistent_mask, 'PHQ_Binary'] = 1

print(f"\nFixed {len(inconsistent_participants)} inconsistent labels")

# Persist the corrected table over the previously saved CSV.
all_data.to_csv(output_path, index=False)

In [None]:
# Drop every participant labelled non-depressed (PHQ_Binary = 0):
# first their session folders on disk, then their rows in the CSV.
non_depressed_participants = all_data.loc[all_data['PHQ_Binary'] == 0, 'Participant_ID'].unique()

print(f"Found {len(non_depressed_participants)} non-depressed participants to remove:")

deleted_count = 0
for participant_id in non_depressed_participants:
    folder_name = f"{participant_id}_P"
    folder_path = os.path.join(base_dir, folder_name)
    if not os.path.exists(folder_path):
        # Nothing on disk for this participant.
        continue
    shutil.rmtree(folder_path)
    deleted_count += 1

print(f"Deleted {deleted_count} directories for non-depressed participants")

# Keep only the depressed participants in the table and save it again.
depressed_mask = all_data['PHQ_Binary'] == 1
all_data = all_data[depressed_mask]
all_data.to_csv(output_path, index=False)

In [None]:
def overlap(s1, e1, s2, e2):
    """Return the length shared by the intervals [s1, e1] and [s2, e2].

    The result is 0 when the intervals are disjoint or only touch.
    """
    latest_start = s2 if s2 > s1 else s1
    earliest_end = e2 if e2 < e1 else e1
    shared = earliest_end - latest_start
    return shared if shared > 0 else 0

def assign_speakers_to_segments(segments, diarization_result):
    """Label each transcript segment with the best-overlapping speaker.

    Args:
        segments: iterable of dicts with "start", "end" and "text" keys.
        diarization_result: object whose ``itertracks(yield_label=True)``
            yields ``(turn, _, label)`` tuples, ``turn`` having ``start``
            and ``end`` attributes.

    Returns:
        A list of dicts with keys "start_time", "stop_time", "speaker" and
        "value", one per input segment and in the same order. The speaker is
        "UNKNOWN" when no turn overlaps the segment.
    """
    labelled = []
    for seg in segments:
        seg_start, seg_end, seg_text = seg["start"], seg["end"], seg["text"]

        # A turn only wins with a strictly larger overlap, so the first turn
        # is kept on ties and zero-overlap turns never replace "UNKNOWN".
        best_label, best_overlap = "UNKNOWN", 0
        for turn, _, label in diarization_result.itertracks(yield_label=True):
            shared = overlap(seg_start, seg_end, turn.start, turn.end)
            if shared > best_overlap:
                best_label, best_overlap = label, shared

        labelled.append(dict(
            start_time=seg_start,
            stop_time=seg_end,
            speaker=best_label,
            value=seg_text,
        ))
    return labelled

In [None]:
def align_words_to_speakers(word_timestamps, diarization_result):
    """
    Assign a speaker to every transcribed word and merge consecutive words
    attributed to the same speaker into continuous segments.

    Args:
        word_timestamps: iterable of dicts with a "text" string and a
            "timestamp" pair ``(start, end)``. A missing (``None``) bound is
            tolerated: a missing start falls back to the end of the previous
            word, a missing end falls back to the word's own start.
        diarization_result: object whose ``itertracks(yield_label=True)``
            yields ``(turn, _, label)`` tuples, ``turn`` having ``start``
            and ``end`` attributes.

    Returns:
        A list of dicts with keys "start_time", "stop_time", "speaker" and
        "value". A word is attributed to the first turn that fully contains
        it, or to "UNKNOWN" when no turn does.
    """
    aligned_segments = []

    # Materialise the speaker turns once so they can be scanned per word.
    speaker_turns = [
        {"start": turn.start, "end": turn.end, "speaker": speaker_label}
        for turn, _, speaker_label in diarization_result.itertracks(yield_label=True)
    ]

    # State of the segment currently being built.
    current_segment_speaker = None
    current_segment_text = ""
    current_segment_start = 0
    last_word_end = 0

    for word_info in word_timestamps:
        word_start, word_end = word_info["timestamp"]
        word_text = word_info["text"]

        # The ASR output may leave a timestamp bound unset; comparing None
        # with a float below would raise TypeError, so fill the gaps first.
        if word_start is None:
            word_start = last_word_end
        if word_end is None:
            word_end = word_start

        # Find the speaker of the current word: first turn that contains it.
        word_speaker = "UNKNOWN"
        for turn in speaker_turns:
            if turn["start"] <= word_start and turn["end"] >= word_end:
                word_speaker = turn["speaker"]
                break

        if current_segment_speaker is None:
            # Start the very first segment.
            current_segment_speaker = word_speaker
            current_segment_start = word_start

        # On a speaker change, flush the previous segment and open a new one.
        if word_speaker != current_segment_speaker:
            if current_segment_text:  # only keep segments that carry text
                aligned_segments.append({
                    "start_time": current_segment_start,
                    "stop_time": last_word_end,  # end of the last word seen
                    "speaker": current_segment_speaker,
                    "value": current_segment_text.strip()
                })

            current_segment_speaker = word_speaker
            current_segment_start = word_start
            current_segment_text = ""

        # Append the word to the segment being built.
        current_segment_text += word_text
        last_word_end = word_end

    # Flush the segment still open once the loop is over.
    if current_segment_text:
        aligned_segments.append({
            "start_time": current_segment_start,
            "stop_time": last_word_end,
            "speaker": current_segment_speaker,
            "value": current_segment_text.strip()
        })

    return aligned_segments

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Half precision is selected only when CUDA is available.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-medium.en"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
).to(device)

processor = AutoProcessor.from_pretrained(model_id)

# Speech-recognition pipeline. Word-level timestamps are requested because
# the alignment step below attributes a speaker to each word individually.
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
    return_timestamps="word",
    generate_kwargs={"max_new_tokens": 400},
    chunk_length_s=30,
)

load_dotenv()
# Speaker-diarization pipeline from the Hugging Face Hub. The access token
# is read from the HUGGINGFACE_TOKEN environment variable.
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.getenv("HUGGINGFACE_TOKEN"),
)
if diarization_pipeline is None:
    # Fail here with a clear message instead of an AttributeError on .to().
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'. "
        "Check HUGGINGFACE_TOKEN and that the model's user conditions were accepted."
    )
diarization_pipeline.to(torch.device(device))

# Columns of every transcript file written below.
transcript_columns = ["start_time", "stop_time", "speaker", "value"]

sessions = sorted([d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))])

# Process each session on its own: diarize, transcribe, align, save.
for session in tqdm(sessions, desc="Processing sessions"):
    session_path = os.path.join(base_dir, session)
    base_name = session.split("_")[0]
    audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
    transcript_path = os.path.join(session_path, f"{base_name}_TRANSCRIPT.csv")

    print(f"\nProcessing session: {audio_path}")

    if not os.path.exists(audio_path):
        print(f"Audio file not found for session {session}, skipping.")
        continue

    # Remove the old transcript if present (differs from transcript_path
    # only by letter case in the file name).
    old_transcript = os.path.join(session_path, f"{base_name}_Transcript.csv")
    if os.path.exists(old_transcript):
        os.remove(old_transcript)

    print("Running diarization...")
    diarization = diarization_pipeline(audio_path)

    print("\nTranscribing")
    result = pipe(audio_path)
    word_timestamps = result['chunks']
    # Report a count only; dumping every word floods the notebook output.
    print(f"Transcribed {len(word_timestamps)} words")

    # With no recognised words, write a header-only transcript and move on.
    if not word_timestamps:
        print(f"No speech detected in {audio_path}. Saving empty transcript.")
        pd.DataFrame(columns=transcript_columns).to_csv(transcript_path, sep="\t", index=False)
        continue

    # Sorted so the printed speaker list is stable from run to run.
    speakers = sorted(
        {label for _, _, label in diarization.itertracks(yield_label=True)},
        key=str,
    )
    print(f"Speakers found: {len(speakers)} ({speakers})")

    print("Aligning transcription with speakers using word-level timestamps...")
    aligned = align_words_to_speakers(word_timestamps, diarization)

    # Fixing the columns keeps the header present even if alignment yields
    # no segments, so the file can always be read back.
    df = pd.DataFrame(aligned, columns=transcript_columns)
    df.to_csv(transcript_path, sep="\t", index=False)
    print(f"Saved {transcript_path}")

In [None]:
def play_audio_segment(audio_tensor, sample_rate):
    """Play an audio tensor through sounddevice and wait until it finishes.

    ``audio_tensor`` is expected as [channels, samples]; it is converted to
    a NumPy array and transposed to (samples, channels) before playback.
    """
    samples_by_channel = audio_tensor.numpy().transpose()
    sd.play(samples_by_channel, sample_rate)
    sd.wait()

def label_speakers(audio_path, transcript_path):
    """Interactively rename the diarization speakers of one transcript.

    For every distinct value in the transcript's "speaker" column, the first
    segment of that speaker is played back and the user is asked to label it
    as Ellie, Participant, or anything else (stored as "ignore"). The
    transcript file is then overwritten with the new speaker names.

    Args:
        audio_path: path of the audio file loaded with torchaudio.
        transcript_path: path of the tab-separated transcript to update.
    """
    transcript = pd.read_csv(transcript_path, sep="\t")
    waveform, sample_rate = torchaudio.load(audio_path)

    # Recognised answers; any other input falls back to "ignore".
    answer_to_label = {'e': "Ellie", 'p': "Participant"}

    speaker_labels = {}
    for spk in transcript['speaker'].unique():
        first_seg = transcript.loc[transcript['speaker'] == spk].iloc[0]
        # Convert the segment bounds to sample indices.
        start_sample = int(first_seg['start_time'] * sample_rate)
        end_sample = int(first_seg['stop_time'] * sample_rate)

        print(f"\nSpeaker: {spk} - playing audio segment from {start_sample/sample_rate:.2f}s to {end_sample/sample_rate:.2f}s")
        play_audio_segment(waveform[:, start_sample:end_sample], sample_rate)

        choice = input("Label this speaker as (E)llie, (P)articipant, (O)ther: ").strip().lower()
        speaker_labels[spk] = answer_to_label.get(choice, "ignore")

    transcript['speaker'] = transcript['speaker'].map(speaker_labels)
    transcript.to_csv(transcript_path, sep="\t", index=False)
    print(f"Updated transcript saved to {transcript_path}")

# Run the interactive speaker labeling over every session folder.
sessions = sorted(
    entry for entry in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, entry))
)

for session in tqdm(sessions, desc="Labeling speakers"):
    session_path = os.path.join(base_dir, session)
    base_name = session.partition("_")[0]
    audio_path = os.path.join(session_path, f"{base_name}_AUDIO.wav")
    transcript_path = os.path.join(session_path, f"{base_name}_TRANSCRIPT.csv")

    # Both the audio and its transcript are required to label a session.
    if not (os.path.exists(audio_path) and os.path.exists(transcript_path)):
        print(f"Skipping {session}: missing audio or transcript file")
        continue

    print(f"\n=== Processing session {session} ===")
    try:
        label_speakers(audio_path, transcript_path)
    except Exception as e:
        # Report the failure and carry on with the next session.
        print(f"Error processing {session}: {e}")