In [None]:
! pip install -r requirements.txt

In [None]:
import os
import whisper
from pyannote.audio import Pipeline
from pathlib import Path
import torch
import torchaudio
#device = "mps" if torch.backends.mps.is_available() else "cpu"

# --- Paths -------------------------------------------------------------------
# The data folder is expected next to the notebook's parent directory:
#   <project_root>/data_samples/
notebook_dir = Path.cwd()

project_root = notebook_dir.parent
data_dir = project_root / "data_samples"
data_dir.mkdir(exist_ok=True)

# Relative paths...modify according to your directory structure
audio_file_path = data_dir / "RES0029.mp3"
output_file_path = data_dir / "transcription2.txt"

# Fail fast with a clear message before spending time loading the models.
if not audio_file_path.is_file():
    raise FileNotFoundError(f"Audio file not found: {audio_file_path}")

# --- Models ------------------------------------------------------------------
whisper_model = whisper.load_model("base")

# The Hugging Face access token is read from the environment so that it never
# has to be written into the notebook.
HF_TOKEN = os.getenv("HF_TOKEN")
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN)

# NOTE(review): pyannote.audio 3.x can return None instead of raising when the
# gated pipeline cannot be downloaded (e.g. missing/invalid token or model terms
# not accepted). Surface that here rather than as "NoneType is not callable" later.
if diarization_pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'. "
        "Check that HF_TOKEN is set and that the model's user conditions "
        "have been accepted on Hugging Face."
    )


In [None]:
# Run speaker diarization on the recording and flatten the result into a list
# of plain dicts, one per speaker turn, holding the turn's start time, end time
# and the speaker label assigned by the pipeline.
diarization = diarization_pipeline(audio_file_path)

segments = [
    {"start": turn.start, "end": turn.end, "speaker": speaker}
    for turn, _, speaker in diarization.itertracks(yield_label=True)
]

print(f"Found {len(segments)} speaker segments.")



In [None]:
def transcribe_segment(audio_path, start_time, end_time, model):
    """Transcribe the part of an audio file between two timestamps.

    The file is loaded with torchaudio, cut down to the requested window,
    written to a throwaway WAV file and handed to ``model.transcribe``. The
    temporary file is removed again before the function returns.

    Args:
        audio_path: Path of the audio file to read.
        start_time: Start of the window, in seconds.
        end_time: End of the window, in seconds.
        model: Object exposing ``transcribe(path)`` that returns a mapping
            with a ``"text"`` entry (e.g. a loaded Whisper model).

    Returns:
        The ``"text"`` entry of the transcription result, or ``""`` when the
        requested window contains no audio samples.
    """
    import tempfile

    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert seconds to sample indices. Clamp so that a negative start cannot
    # wrap around to the end of the audio and the end never precedes the start.
    start_sample = max(0, int(start_time * sample_rate))
    end_sample = max(start_sample, int(end_time * sample_rate))
    segment_waveform = waveform[:, start_sample:end_sample]

    # Nothing to transcribe (zero-length turn or window past the end of file).
    if segment_waveform.shape[-1] == 0:
        return ""

    # Use a private temporary directory so that nothing is left behind in the
    # working directory and the file is cleaned up even if transcription fails.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = os.path.join(temp_dir, "segment.wav")
        torchaudio.save(temp_path, segment_waveform, sample_rate)
        result = model.transcribe(temp_path)

    return result["text"]


In [None]:
# Human-readable names for the labels produced by the diarization pipeline.
# NOTE(review): which label ends up being the doctor vs. the patient is an
# assumption here — verify against the recording and adjust if needed.
speaker_map = {
    "SPEAKER_00": "Doctor",
    "SPEAKER_01": "Patient"
}

# Transcribe every diarized turn and prefix it with the mapped speaker name.
# Labels missing from speaker_map fall back to the raw diarization label.
speaker_transcripts = []
for seg in segments:
    text = transcribe_segment(audio_file_path, seg["start"], seg["end"], whisper_model)
    speaker_name = speaker_map.get(seg["speaker"], seg["speaker"])
    speaker_transcripts.append(f"{speaker_name}: {text}")

# One line per speaker turn, in the order the turns were collected.
with open(output_file_path, "w", encoding="utf-8") as f:
    for line in speaker_transcripts:
        f.write(line + "\n")

print(f"Transcription completed and saved to {output_file_path}")