In [None]:
import torch
import os
import pandas as pd
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from pyannote.audio import Pipeline

# Field names of a space-separated RTTM line, in file order.
_RTTM_COLUMNS = [
    "Type", "File ID", "Channel ID", "start_time",
    "Duration", "NA1", "NA2", "Speaker ID", "NA3", "NA4",
]


def _merge_consecutive_turns(df_rttm: pd.DataFrame) -> pd.DataFrame:
    """Merge runs of consecutive RTTM rows spoken by the same speaker.

    Args:
        df_rttm: Frame with at least the ``start_time``, ``Duration`` and
            ``Speaker ID`` columns, in file order.

    Returns:
        One row per run, indexed by the 0-based run number (index name
        ``number``), with columns ``start``, ``end``, ``speaker_id`` and
        ``duration``.
    """
    speaker = df_rttm["Speaker ID"]
    # A new run starts wherever the speaker differs from the previous row.
    # The first row is compared against the NaN shifted in, so it always
    # opens run 0 once the cumulative count is made 0-based.
    run_number = speaker.ne(speaker.shift()).cumsum() - 1

    turns = df_rttm.assign(
        end_time=df_rttm["start_time"] + df_rttm["Duration"],
        number=run_number,
    )
    merged = turns.groupby("number").agg(
        start=pd.NamedAgg(column="start_time", aggfunc="min"),
        end=pd.NamedAgg(column="end_time", aggfunc="max"),
        speaker_id=pd.NamedAgg(column="Speaker ID", aggfunc="first"),
    )
    merged["duration"] = merged["end"] - merged["start"]
    return merged


def speaker_diarization(
            audio_file: str,
            output_rttm_file_path: str,
            output_csv_file_path: str,
            hf_token: str = "hf_token") -> pd.DataFrame:
    """Diarize an audio file and merge consecutive same-speaker segments.

    The diarization result is written to ``output_rttm_file_path`` in RTTM
    format, read back into a DataFrame, merged into one row per run of
    consecutive rows with the same speaker, and written to
    ``output_csv_file_path`` (without the run-number index).

    Args:
        audio_file: Path of the audio file passed to the pipeline.
        output_rttm_file_path: Where the RTTM output is written.
        output_csv_file_path: Where the merged segments are written as CSV.
        hf_token: Hugging Face access token used to load the pipeline.
            The default is a placeholder; pass a real token.

    Returns:
        DataFrame indexed by run number with columns ``start``, ``end``,
        ``speaker_id`` and ``duration``. It is empty when the RTTM file
        contains no segments.

    Raises:
        RuntimeError: If the pretrained pipeline could not be loaded.
    """
    diarization_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=hf_token,
    )
    # from_pretrained reports a failed download (bad token, user conditions
    # not accepted) by returning None instead of raising; fail loudly here
    # rather than with an opaque error on the next line.
    if diarization_pipeline is None:
        raise RuntimeError(
            "Could not load 'pyannote/speaker-diarization-3.1'; check the "
            "Hugging Face token and that the model's user conditions "
            "have been accepted."
        )

    # Run on the GPU when one is available.
    if torch.cuda.is_available():
        diarization_pipeline.to(torch.device("cuda"))

    diarization_result = diarization_pipeline(audio_file)

    with open(output_rttm_file_path, "w", encoding="utf-8") as rttm_file:
        diarization_result.write_rttm(rttm_file)

    # Load the RTTM file that was just written into a DataFrame.
    try:
        df_rttm = pd.read_csv(
            output_rttm_file_path,
            sep=" ",
            names=_RTTM_COLUMNS,
        )
    except pd.errors.EmptyDataError:
        # No segments were written: continue with an empty table so the
        # caller still gets a (header-only) CSV and an empty DataFrame.
        df_rttm = pd.DataFrame(columns=_RTTM_COLUMNS)

    # Split into per-speaker runs based on the speaker ID.
    df_rttm_grouped = _merge_consecutive_turns(df_rttm)

    df_rttm_grouped.to_csv(output_csv_file_path, index=False, encoding="utf-8")

    return df_rttm_grouped

if __name__ == "__main__":
    # Example run: diarize the sample recording and print the returned table.
    sample_paths = {
        "audio_file": "./CH/chapter5/data/multi_speaker_audio_2023_1min.wav",
        "output_rttm_file_path": "./CH/chapter5/data/diarization_output.rttm",
        "output_csv_file_path": "./CH/chapter5/data/diarization_output.csv",
    }

    merged_segments = speaker_diarization(**sample_paths)
    print(merged_segments)