In [1]:
# to use my GPU
# !pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu118


from datasets import concatenate_datasets, load_dataset, Audio, DatasetDict
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor, AutoModelForAudioClassification, WhisperForAudioClassification, WhisperProcessor,  TrainingArguments, Trainer, EarlyStoppingCallback
import numpy as np
from random import randint

import torch
import evaluate
from torch.cuda import device_count
import os
import librosa
import pandas as pd
from pathlib import Path


In [2]:

model_name = "distil-whisper/distil-large-v3" # "whisper-large-v3_ADReSSO"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32


feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(
    model_name, num_labels=2, ignore_mismatched_sizes=True
)
del model.config.__dict__["max_length"]
del model.config.__dict__["suppress_tokens"]
del model.config.__dict__["begin_suppress_tokens"]
model.to(device)
model.save_pretrained("model")

# model1 = WhisperForAudioClassification.from_pretrained(model_name, num_labels=2)
# processor1 = WhisperProcessor.from_pretrained(model_name)

preprocess = lambda examples: feature_extractor(
    [i["array"][(n := randint(0, len(i["array"]) - (m := min(len(i["array"]), feature_extractor.sampling_rate*30)))) : n + m] for i in examples["audio"]],
    sampling_rate=feature_extractor.sampling_rate,
    do_normalize=True,
    # max_length=16_000*args.sample_duration,
    # truncation=True,
)

#### LOAD DATASET HERE ############
AD_PATH = 'ADReSSo21/diagnosis/train/audio/ad'
CN_PATH = 'ADReSSo21/diagnosis/train/audio/cn'

ad_dataset = (
    load_dataset(AD_PATH)
    .cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
)
ad_dataset = ad_dataset.map(lambda example: {"label": 0})

cn_dataset = (
    load_dataset(CN_PATH)
    .cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
)
cn_dataset = cn_dataset.map(lambda example: {"label": 1})

dataset = concatenate_datasets([ad_dataset["train"], cn_dataset["train"]])
dataset = DatasetDict({
    "train": dataset
})
dataset["train"], dataset["valid"] = dataset["train"].train_test_split(0.25).values()
# dataset["valid"], dataset["test"] = dataset["valid"].train_test_split(0.2).values()
dataset = dataset.map(preprocess, remove_columns="audio", batched=True)


Some weights of WhisperForAudioClassification were not initialized from the model checkpoint at distil-whisper/distil-large-v3 and are newly initialized: ['model.classifier.bias', 'model.classifier.weight', 'model.projector.bias', 'model.projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Resolving data files:   0%|          | 0/87 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/79 [00:00<?, ?it/s]

Map:   0%|          | 0/124 [00:00<?, ? examples/s]

Exception ignored from cffi callback <function SoundFile._init_virtual_io.<locals>.vio_read at 0x000001CE950EADE0>:
Traceback (most recent call last):
  File "c:\Users\qscre\miniconda3\envs\aml\Lib\site-packages\soundfile.py", line 1246, in vio_read
    data_read = file.readinto(buf)
                ^^^^^^^^^^^^^^^^^^
KeyboardInterrupt: 


Map:   0%|          | 0/42 [00:00<?, ? examples/s]

In [3]:
# Convert dataset to Tensor (Data must be in PyTorch format, eg: {"input_values": audio_tensor, "label": label_tensor})
# class AudioDataset(Dataset):
#     def __init__(self, audio_data, labels):
#         self.audio_data = audio_data
#         self.labels = labels

#     def __len__(self):
#         return len(self.audio_data)

#     def __getitem__(self, idx):
#         # Convert to tensor and return
#         audio_tensor = torch.tensor(self.audio_data[idx]).float()  # Make sure it's float
#         label_tensor = torch.tensor(self.labels[idx]).long()  # Make sure it's long for classification
#         return {"input_values": audio_tensor, "label": label_tensor}
    
# train_dataset = {"input_features": [torch.tensor(feature, dtype=torch.float) for feature in dataset["train"]["input_features"]], 
#                  "label": [torch.tensor(feature, dtype=torch.float) for feature in dataset["train"]["label"]]}
# valid_dataset = {"input_features": [torch.tensor(feature, dtype=torch.float) for feature in dataset["valid"]["input_features"]], 
#                  "label": [torch.tensor(feature, dtype=torch.float) for feature in dataset["valid"]["label"]]}

train_dataset = dataset["train"].with_format("torch")
val_dataset = dataset["valid"].with_format("torch")
# test_dataset = dataset["test"].with_format("torch")


In [4]:
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
specificity = evaluate.load("nevikw39/specificity")

## training part
training_args = TrainingArguments(
    output_dir="models/whisper-large-v3_ADReSSo" + ("_fp16" if False else ""),
    fp16=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    learning_rate=3e-5,
    per_device_train_batch_size=2, # 8
    per_device_eval_batch_size=2,# 8
    gradient_accumulation_steps=4,
    # gradient_checkpointing=True,
    num_train_epochs=10, # 100
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,

)

trainer = Trainer (
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset, # dataset["valid"],
    # eval_dataset=encoded_dataset["test"].select(np.random.choice(71, 42)),
    tokenizer=feature_extractor,
    compute_metrics=lambda eval_pred: (
        accuracy.compute(
            predictions=(pred := np.argmax(eval_pred.predictions, axis=1)),
            references=eval_pred.label_ids,
        ) | f1.compute(
            predictions=pred,
            references=eval_pred.label_ids,
        ) | specificity.compute(
            predictions=pred,
            references=eval_pred.label_ids,
        )
    ),
    callbacks=[EarlyStoppingCallback(10)],
)

#trainer.train()





  0%|          | 0/150 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 0.7083, 'grad_norm': 3.1401898860931396, 'learning_rate': 1.9999999999999998e-05, 'epoch': 0.65}


  0%|          | 0/21 [00:00<?, ?it/s]

{'eval_loss': 0.6848700642585754, 'eval_accuracy': 0.5238095238095238, 'eval_f1': 0.0, 'eval_specificity': 1.0, 'eval_runtime': 159.7882, 'eval_samples_per_second': 0.263, 'eval_steps_per_second': 0.131, 'epoch': 0.97}
{'loss': 0.6411, 'grad_norm': 3.3013556003570557, 'learning_rate': 2.8888888888888888e-05, 'epoch': 1.29}
{'loss': 0.5405, 'grad_norm': 4.796957969665527, 'learning_rate': 2.6666666666666667e-05, 'epoch': 1.94}


  0%|          | 0/21 [00:00<?, ?it/s]

{'eval_loss': 0.5502444505691528, 'eval_accuracy': 0.7142857142857143, 'eval_f1': 0.7391304347826086, 'eval_specificity': 0.5909090909090909, 'eval_runtime': 159.9654, 'eval_samples_per_second': 0.263, 'eval_steps_per_second': 0.131, 'epoch': 2.0}
{'loss': 0.2554, 'grad_norm': 2.6861839294433594, 'learning_rate': 2.4444444444444445e-05, 'epoch': 2.58}


  0%|          | 0/21 [00:00<?, ?it/s]

{'eval_loss': 1.016992449760437, 'eval_accuracy': 0.6904761904761905, 'eval_f1': 0.7450980392156863, 'eval_specificity': 0.45454545454545453, 'eval_runtime': 174.3948, 'eval_samples_per_second': 0.241, 'eval_steps_per_second': 0.12, 'epoch': 2.97}
{'loss': 0.3521, 'grad_norm': 40.06829833984375, 'learning_rate': 2.222222222222222e-05, 'epoch': 3.23}
{'loss': 0.5699, 'grad_norm': 13.027278900146484, 'learning_rate': 1.9999999999999998e-05, 'epoch': 3.87}


  0%|          | 0/21 [00:00<?, ?it/s]

{'eval_loss': 0.5279026031494141, 'eval_accuracy': 0.6904761904761905, 'eval_f1': 0.6666666666666666, 'eval_specificity': 0.7272727272727273, 'eval_runtime': 196.6405, 'eval_samples_per_second': 0.214, 'eval_steps_per_second': 0.107, 'epoch': 4.0}
{'loss': 0.0903, 'grad_norm': 57.81535339355469, 'learning_rate': 1.7777777777777777e-05, 'epoch': 4.52}


  0%|          | 0/21 [00:00<?, ?it/s]

{'eval_loss': 0.8971467018127441, 'eval_accuracy': 0.7380952380952381, 'eval_f1': 0.7027027027027027, 'eval_specificity': 0.8181818181818182, 'eval_runtime': 199.6499, 'eval_samples_per_second': 0.21, 'eval_steps_per_second': 0.105, 'epoch': 4.97}
{'loss': 0.0963, 'grad_norm': 0.09985890984535217, 'learning_rate': 1.5555555555555555e-05, 'epoch': 5.16}
{'loss': 0.0036, 'grad_norm': 0.07194121181964874, 'learning_rate': 1.3333333333333333e-05, 'epoch': 5.81}


  0%|          | 0/21 [00:00<?, ?it/s]

{'eval_loss': 0.7317893505096436, 'eval_accuracy': 0.8095238095238095, 'eval_f1': 0.8095238095238095, 'eval_specificity': 0.7727272727272727, 'eval_runtime': 198.2276, 'eval_samples_per_second': 0.212, 'eval_steps_per_second': 0.106, 'epoch': 6.0}
{'loss': 0.0023, 'grad_norm': 0.047676365822553635, 'learning_rate': 1.111111111111111e-05, 'epoch': 6.45}


  0%|          | 0/21 [00:00<?, ?it/s]

{'eval_loss': 1.071899652481079, 'eval_accuracy': 0.7619047619047619, 'eval_f1': 0.7368421052631579, 'eval_specificity': 0.8181818181818182, 'eval_runtime': 200.0074, 'eval_samples_per_second': 0.21, 'eval_steps_per_second': 0.105, 'epoch': 6.97}
{'loss': 0.0017, 'grad_norm': 0.03788022696971893, 'learning_rate': 8.888888888888888e-06, 'epoch': 7.1}
{'loss': 0.0014, 'grad_norm': 0.03436962142586708, 'learning_rate': 6.666666666666667e-06, 'epoch': 7.74}


  0%|          | 0/21 [00:00<?, ?it/s]

{'eval_loss': 1.3117010593414307, 'eval_accuracy': 0.7142857142857143, 'eval_f1': 0.6666666666666666, 'eval_specificity': 0.8181818181818182, 'eval_runtime': 200.2861, 'eval_samples_per_second': 0.21, 'eval_steps_per_second': 0.105, 'epoch': 8.0}
{'loss': 0.0013, 'grad_norm': 0.031182970851659775, 'learning_rate': 4.444444444444444e-06, 'epoch': 8.39}


  0%|          | 0/21 [00:00<?, ?it/s]

{'eval_loss': 1.361391544342041, 'eval_accuracy': 0.7142857142857143, 'eval_f1': 0.6666666666666666, 'eval_specificity': 0.8181818181818182, 'eval_runtime': 196.6871, 'eval_samples_per_second': 0.214, 'eval_steps_per_second': 0.107, 'epoch': 8.97}
{'loss': 0.0012, 'grad_norm': 0.03122571110725403, 'learning_rate': 2.222222222222222e-06, 'epoch': 9.03}
{'loss': 0.0012, 'grad_norm': 0.02977282926440239, 'learning_rate': 0.0, 'epoch': 9.68}


  0%|          | 0/21 [00:00<?, ?it/s]

{'eval_loss': 1.3726603984832764, 'eval_accuracy': 0.7142857142857143, 'eval_f1': 0.6666666666666666, 'eval_specificity': 0.8181818181818182, 'eval_runtime': 198.9486, 'eval_samples_per_second': 0.211, 'eval_steps_per_second': 0.106, 'epoch': 9.68}
{'train_runtime': 11724.7202, 'train_samples_per_second': 0.106, 'train_steps_per_second': 0.013, 'train_loss': 0.21776453108216326, 'epoch': 9.68}


TrainOutput(global_step=150, training_loss=0.21776453108216326, metrics={'train_runtime': 11724.7202, 'train_samples_per_second': 0.106, 'train_steps_per_second': 0.013, 'total_flos': 1.756691463168e+18, 'train_loss': 0.21776453108216326, 'epoch': 9.67741935483871})

In [5]:
# print("test_dataset", test_dataset)

# eval_results = trainer.evaluate(test_dataset)

# eval_results

In [8]:

# Evaluate the model on the evaluation dataset and capture the evaluation results
eval_result = trainer.evaluate()

# Print the evaluation metrics
print("Evaluation results:", eval_result)

  0%|          | 0/21 [00:00<?, ?it/s]

KeyboardInterrupt: 