In [1]:
from datasets import concatenate_datasets, load_dataset, Audio, DatasetDict
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, TrainingArguments, Trainer, EarlyStoppingCallback
import numpy as np
import torch
import evaluate
from random import randint

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the feature extractor and the saved model from "models/whisper-large-v3_ADReSSo"
model_name = "distil-whisper/distil-large-v3" # Model name
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Load the model from the saved directory
model = AutoModelForAudioClassification.from_pretrained(
    "models/whisper-large-v3_ADReSSo/checkpoint-93", num_labels=2, ignore_mismatched_sizes=True
)
model.to(device)

# Preprocessing function for the audio data
preprocess = lambda examples: feature_extractor(
    [i["array"][(n := randint(0, len(i["array"]) - (m := min(len(i["array"]), feature_extractor.sampling_rate * 30)))) : n + m] for i in examples["audio"]],
    sampling_rate=feature_extractor.sampling_rate,
    do_normalize=True,
)

#### LOAD DATASET HERE ############
AD_PATH = 'ADReSSo21/diagnosis/train/audio/ad'
CN_PATH = 'ADReSSo21/diagnosis/train/audio/cn'

ad_dataset = (
    load_dataset(AD_PATH)
    .cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
)
ad_dataset = ad_dataset.map(lambda example: {"label": 0})

cn_dataset = (
    load_dataset(CN_PATH)
    .cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
)
cn_dataset = cn_dataset.map(lambda example: {"label": 1})

dataset = concatenate_datasets([ad_dataset["train"], cn_dataset["train"]])
dataset = DatasetDict({"train": dataset})
dataset["train"], dataset["valid"] = dataset["train"].train_test_split(0.25).values()
dataset = dataset.map(preprocess, remove_columns="audio", batched=True)

train_dataset = dataset["train"].with_format("torch")
val_dataset = dataset["valid"].with_format("torch")

# Load evaluation metrics
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
specificity = evaluate.load("nevikw39/specificity")

# Define training arguments (for evaluation purposes)
training_args = TrainingArguments(
    output_dir="models/whisper-large-v3_ADReSSo" + ("_fp16" if False else ""),
    fp16=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=10,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=feature_extractor,
    compute_metrics=lambda eval_pred: (
        accuracy.compute(
            predictions=(pred := np.argmax(eval_pred.predictions, axis=1)),
            references=eval_pred.label_ids,
        ) | f1.compute(
            predictions=pred,
            references=eval_pred.label_ids,
        ) | specificity.compute(
            predictions=pred,
            references=eval_pred.label_ids,
        )
    ),
    callbacks=[EarlyStoppingCallback(10)],
)




Resolving data files:   0%|          | 0/87 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/87 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/87 [00:00<?, ? examples/s]

Resolving data files:   0%|          | 0/79 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/79 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/79 [00:00<?, ? examples/s]

Map:   0%|          | 0/124 [00:00<?, ? examples/s]

Map:   0%|          | 0/42 [00:00<?, ? examples/s]



In [2]:
# Evaluate the model using the validation dataset
eval_result = trainer.evaluate()

# Print the evaluation metrics
print("Evaluation results:", eval_result)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


  0%|          | 0/21 [00:00<?, ?it/s]

Evaluation results: {'eval_loss': 0.4877254366874695, 'eval_model_preparation_time': 0.0033, 'eval_accuracy': 0.9047619047619048, 'eval_f1': 0.92, 'eval_specificity': 0.8823529411764706, 'eval_runtime': 6.2097, 'eval_samples_per_second': 6.764, 'eval_steps_per_second': 3.382}
