In [1]:
import os
import warnings
from pathlib import Path
from typing import Dict, List

import pandas as pd
import torch
from transformers import pipeline
from tqdm.auto import tqdm

import utils

In [2]:
BASE_PATH = "/mnt/data_lab513/ducvu/ADReSSo/ADReSSo-feature-extration"

## CONFIGURATION

In [3]:
# # Run medium model
# MODEL_NAME = "openai/whisper-medium.en"
# BATCH_SIZE = 8
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# TRAIN_PATH_DATA = f"{BASE_PATH}/data/diagnosis/train"
# TRAIN_AUDIO_PATH = f"{TRAIN_PATH_DATA}/audio"
# OUTPUT_PATH = f"{BASE_PATH}/output/transcripts"
# os.makedirs(OUTPUT_PATH, exist_ok=True)

# print(f"Using device: {DEVICE}")
# print(f"Model: {MODEL_NAME}")

# Run large model
MODEL_NAME = "openai/whisper-large-v3"
BATCH_SIZE = 8
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

TRAIN_PATH_DATA = f"{BASE_PATH}/data/diagnosis/train"
TRAIN_AUDIO_PATH = f"{TRAIN_PATH_DATA}/audio"
OUTPUT_PATH = f"{BASE_PATH}/output/transcripts"
os.makedirs(OUTPUT_PATH, exist_ok=True)

print(f"Using device: {DEVICE}")
print(f"Model: {MODEL_NAME}")

Using device: cuda
Model: openai/whisper-large-v3


In [4]:
audio_files_dict = utils.get_audio_files(TRAIN_AUDIO_PATH)
print(audio_files_dict)

Found 87 AD files
Found 79 CN files
{'ad': [PosixPath('/mnt/data_lab513/ducvu/ADReSSo/ADReSSo-feature-extration/data/diagnosis/train/audio/ad/adrso024.wav'), PosixPath('/mnt/data_lab513/ducvu/ADReSSo/ADReSSo-feature-extration/data/diagnosis/train/audio/ad/adrso025.wav'), PosixPath('/mnt/data_lab513/ducvu/ADReSSo/ADReSSo-feature-extration/data/diagnosis/train/audio/ad/adrso027.wav'), PosixPath('/mnt/data_lab513/ducvu/ADReSSo/ADReSSo-feature-extration/data/diagnosis/train/audio/ad/adrso028.wav'), PosixPath('/mnt/data_lab513/ducvu/ADReSSo/ADReSSo-feature-extration/data/diagnosis/train/audio/ad/adrso031.wav'), PosixPath('/mnt/data_lab513/ducvu/ADReSSo/ADReSSo-feature-extration/data/diagnosis/train/audio/ad/adrso032.wav'), PosixPath('/mnt/data_lab513/ducvu/ADReSSo/ADReSSo-feature-extration/data/diagnosis/train/audio/ad/adrso033.wav'), PosixPath('/mnt/data_lab513/ducvu/ADReSSo/ADReSSo-feature-extration/data/diagnosis/train/audio/ad/adrso035.wav'), PosixPath('/mnt/data_lab513/ducvu/ADReSSo/AD

## Transription function

In [None]:
# Run medium model on one GPU
transcriber = pipeline(
    "automatic-speech-recognition",
    model=MODEL_NAME,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    model_kwargs={
        "attn_implementation": "sdpa"  # Scaled dot-product attention (faster)
    }
)

transcriber.model.config.forced_decoder_ids = transcriber.tokenizer.get_decoder_prompt_ids(
    language="english",
    task="transcribe"
)


# # processor = AutoProcessor.from_pretrained(MODEL_NAME)
# torch_type=torch.float16 if torch.cuda.is_available() else torch.float32

# # Create pipeline
# transcriber = pipeline(
#     task="automatic-speech-recognition",
#     batch_size=BATCH_SIZE,
#     model=MODEL_NAME,
#     torch_dtype=torch_type,
#     device_map="auto", # Using this to multiGPU
#     generate_kwargs={
#         "language": "english"
#     },
#     model_kwargs={
#         "attn_implementation": "sdpa"  # Scaled dot-product attention (faster)
#     }
# )

`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/1259 [00:00<?, ?it/s]

In [9]:
def transcribe_audio_files(
    audio_files:Dict[str, List[Path]],
    transcriber, # model
) -> pd.DataFrame:
    """ Transcribe audio files without diarization
    """

    results = []

    for diagnosis, files in audio_files.items():
        for audio_file in tqdm(files, desc=f"{diagnosis.upper()}"):
            output = transcriber(
                str(audio_file),
                return_timestamps=True,
                generate_kwargs={
                    "task": "transcribe",
                    "language": "en",
                    "return_timestamps": True,
                    "num_beams": 5,
                }
            )

            # Handle different output formats
            if isinstance(output, dict):
                if "text" in output:
                    transcript = output["text"].strip()
                elif "chunks" in output:
                    transcript = " ".join([chunk["text"] for chunk in output["chunks"]]).strip()
                else:
                    transcript = ""
            else:
                transcript = str(output).strip()

            results.append({
                "files_id": audio_file.stem,
                "diagnosis": diagnosis,
                "transcript": transcript,
                })
    return pd.DataFrame(results)

In [10]:
df_transcripts = transcribe_audio_files(
    audio_files=audio_files_dict,
    transcriber=transcriber,
)

print(f"\nTotal transcription: {len(df_transcripts)}")
print(f"\nSample transcripts: {df_transcripts.head()}")

AD:   0%|          | 0/87 [00:00<?, ?it/s]

A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> to see related `.generate()` flags.
A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> to see related `.generate()` flags.
You seem to be using the pipeli

CN:   0%|          | 0/79 [00:00<?, ?it/s]

Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.
Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.
Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.
Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.
Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.
Whisper did not predict an ending timestamp, which can happen if audio is cut off in the m


Total transcription: 166

Sample transcripts:    files_id diagnosis                                         transcript
0  adrso024        ad  There's a cookie jar and the lid is off the co...
1  adrso025        ad  Tell me everything that's going on. Well, the ...
2  adrso027        ad  There's a little girl and a little boy standin...
3  adrso028        ad  How she would find her, and the mother's wishe...
4  adrso031        ad  What do you see going on? Well, the boy's on a...


In [11]:
print(df_transcripts["transcript"])

0      There's a cookie jar and the lid is off the co...
1      Tell me everything that's going on. Well, the ...
2      There's a little girl and a little boy standin...
3      How she would find her, and the mother's wishe...
4      What do you see going on? Well, the boy's on a...
                             ...                        
161    Bring in that picture. Tell me everything that...
162    Okay, I just want you to tell me what you see ...
163    What do you see going on in that picture? Oh, ...
164    Just the action. The girl is reaching for a co...
165    I'd like you to tell me everything that you se...
Name: transcript, Length: 166, dtype: object


## SAVE TO CSV

In [13]:
output_file = Path(OUTPUT_PATH) / f"adresso_transcripts_{MODEL_NAME.split('/')[-1]}.csv"
df_transcripts.to_csv(output_file, index=False)