In [1]:
############################
# 1) Install Dependencies #
############################
!pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
!pip install --upgrade transformers datasets evaluate librosa seaborn
!pip install bitsandbytes
!pip install pydub



Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu118
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.6.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (27 kB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.21.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cu118/torchaudio-2.6.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (6.6 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu11==11.8.89 (from torch)
  Downloading nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu11==11.8.87 (from torch)
  Downloading nvidia_cuda_cupti_cu11-11.8.87-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cu

In [2]:

############################
# 2) Import Libraries     #
############################
import os
import requests
import shutil
from zipfile import ZipFile
import torch
import librosa
import datasets
import evaluate
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import load_from_disk, Dataset, DatasetDict, concatenate_datasets
from transformers import (
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2Processor,
    TrainingArguments,
    Trainer
)
from torch.nn.utils.rnn import pad_sequence
from bitsandbytes.optim import Adam8bit
from sklearn.metrics import classification_report, confusion_matrix
from pydub import AudioSegment
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
############################
# 3) Check GPU & Versions #
############################
print("Torch version:", torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("GPU available?", torch.cuda.is_available())
print("Using device:", device)
!nvidia-smi


Torch version: 2.6.0+cu118
GPU available? True
Using device: cuda
Fri Feb 21 20:03:43 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   33C    P0             47W /  400W |       5MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+------

## Data acquisition

In [3]:
urls = [
    "https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc2.zip",
    "https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc3.zip",
    "https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc4.zip",
    "https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc5.zip",
    "https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc6.zip",
    "https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc7.zip",
    "https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc8.zip",
    "https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc9.zip",
    "https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc10.zip",
    "https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc11.zip"
]

# 2) SET UP FOLDERS
# Create root directories for your final data
os.makedirs("Healthy", exist_ok=True)
os.makedirs("MCI", exist_ok=True)
os.makedirs("AD", exist_ok=True)

# Create a temporary folder for extracted files
temp_folder = "tmp_extracted"
os.makedirs(temp_folder, exist_ok=True)

# 3) DOWNLOAD AND EXTRACT EACH ZIP
for i, url in enumerate(urls):
    zip_filename = f"downloaded_{i}.zip"   # A local name to store the downloaded file

    # Download the file
    print(f"Downloading from {url}...")
    response = requests.get(url)
    with open(zip_filename, "wb") as f:
        f.write(response.content)
    print(f"Saved {zip_filename}")

    # Extract all contents into the temp_folder
    print(f"Extracting {zip_filename}...")
    with ZipFile(zip_filename, 'r') as zip_ref:
        zip_ref.extractall(temp_folder)

    # Optionally delete the ZIP file after extraction to save space
    os.remove(zip_filename)

temp_folder = "tmp_extracted"  # your temporary extraction folder



# 4) MOVE FILES INTO THE RIGHT FOLDERS (WITH MP3 TO WAV CONVERSION)
for root, dirs, files in os.walk(temp_folder):
    for filename in files:
        full_path = os.path.join(root, filename)

        # Convert MP3 to WAV if needed
        if filename.endswith(".mp3"):
            wav_filename = filename.replace(".mp3", ".wav")
            wav_path = os.path.join(root, wav_filename)

            # Convert MP3 to WAV
            audio = AudioSegment.from_mp3(full_path)
            audio.export(wav_path, format="wav")

            # Remove the original MP3
            os.remove(full_path)

            # Update full_path to the new WAV file
            full_path = wav_path
            filename = wav_filename

        # Move to corresponding folder
        if filename.startswith("AD"):
            shutil.move(full_path, os.path.join("AD", filename))
        elif filename.startswith("MCI"):
            shutil.move(full_path, os.path.join("MCI", filename))
        elif filename.startswith("HC"):
            shutil.move(full_path, os.path.join("Healthy", filename))
        else:
            print(f"File '{filename}' doesn't match AD/MCI/HC. Skipping or placing it elsewhere.")

print("MP3 conversion and file moving completed.")


# 5) CLEAN UP
shutil.rmtree(temp_folder, ignore_errors=True)
print("Temporary folder removed.")


Downloading from https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc2.zip...
Saved downloaded_0.zip
Extracting downloaded_0.zip...
Downloading from https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc3.zip...
Saved downloaded_1.zip
Extracting downloaded_1.zip...
Downloading from https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc4.zip...
Saved downloaded_2.zip
Extracting downloaded_2.zip...
Downloading from https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc5.zip...
Saved downloaded_3.zip
Extracting downloaded_3.zip...
Downloading from https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc6.zip...
Saved downloaded_4.zip
Extracting downloaded_4.zip...
Downloading from https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc7.zip...
Saved downloaded_5.zip
Extracting downloaded_5.zip...
Downloading from https://ars.els-cdn.com/content/image/1-s2.0-S0885230821001340-mmc8.zip...
Saved downloaded_6.zip
Extractin

In [9]:

############################
# 1) Data Preprocessing   #
############################

ROOT_DIR = os.getcwd()
DATASET_PATH = ROOT_DIR  # Root folder where audio files are stored
OUTPUT_PATH = os.path.join(ROOT_DIR, "ProcessedFiles")
os.makedirs(OUTPUT_PATH, exist_ok=True)

LABEL_MAP = {"Healthy": 0, "MCI": 1, "AD": 2}

def load_audio(file_path, target_sr=16000):
    audio, sr = librosa.load(file_path, sr=target_sr)
    return np.array(audio, dtype=np.float32), sr  # Ensure float32 output

def chunk_audio(example, max_length=16000*60):  # 60 seconds max
    audio = example["audio"]
    if len(audio) > max_length:
        example["audio"] = audio[:max_length]
    return example

def augment_audio(example):
    audio = np.array(example["audio"], dtype=np.float32)

    if np.random.rand() > 0.5:
        audio = librosa.effects.time_stretch(audio, rate=np.random.uniform(0.9, 1.1))
    if np.random.rand() > 0.5:
        audio = librosa.effects.pitch_shift(audio, sr=16000, n_steps=np.random.randint(-2, 3))
    if np.random.rand() > 0.5:
        audio = audio + 0.005 * np.random.normal(0, 1, len(audio))

    example["audio"] = np.array(audio, dtype=np.float32)
    return example




audio_files = []
labels = []

for category in LABEL_MAP.keys():
    category_path = os.path.join(DATASET_PATH, category)
    if not os.path.exists(category_path):
        continue
    for file in os.listdir(category_path):
        if file.endswith(".wav"):
            audio_files.append(os.path.join(category_path, file))
            labels.append(LABEL_MAP[category])

data_df = pd.DataFrame({"file_path": audio_files, "label": labels})

# Splitting dataset
train_files, test_files, train_labels, test_labels = train_test_split(
    data_df["file_path"], data_df["label"],
    test_size=0.2, stratify=data_df["label"], random_state=42
)
train_files, val_files, train_labels, val_labels = train_test_split(
    train_files, train_labels,
    test_size=0.1, stratify=train_labels, random_state=42
)

def process_data(files, labels):
    data = []
    for file, label in tqdm(zip(files, labels), total=len(files)):
        audio, sr = load_audio(file)
        data.append({"audio": np.array(audio, dtype=np.float32), "label": label, "sampling_rate": sr})  # Enforce float32
    return data

train_data = process_data(train_files, train_labels)
val_data = process_data(val_files, val_labels)
test_data = process_data(test_files, test_labels)

dataset = DatasetDict({
    "train": Dataset.from_list(train_data),
    "validation": Dataset.from_list(val_data),
    "test": Dataset.from_list(test_data),
})

dataset = dataset.map(chunk_audio)
dataset = dataset.map(augment_audio)

# Balance classes by oversampling AD & MCI
healthy_samples = dataset["train"].filter(lambda x: x["label"] == 0)
mci_samples = dataset["train"].filter(lambda x: x["label"] == 1)
ad_samples = dataset["train"].filter(lambda x: x["label"] == 2)

oversampled_ad = concatenate_datasets([ad_samples] * 3)
oversampled_mci = concatenate_datasets([mci_samples] * 3)

dataset["train"] = concatenate_datasets([healthy_samples, oversampled_mci, oversampled_ad])

dataset.save_to_disk(OUTPUT_PATH)
print(f"Dataset saved to {OUTPUT_PATH}")



100%|██████████| 259/259 [00:15<00:00, 17.18it/s]
100%|██████████| 29/29 [00:02<00:00, 13.43it/s]
100%|██████████| 73/73 [00:04<00:00, 18.10it/s]


Map:   0%|          | 0/259 [00:00<?, ? examples/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

Map:   0%|          | 0/73 [00:00<?, ? examples/s]

Map:   0%|          | 0/259 [00:00<?, ? examples/s]

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

Map:   0%|          | 0/73 [00:00<?, ? examples/s]

Filter:   0%|          | 0/259 [00:00<?, ? examples/s]

Filter:   0%|          | 0/259 [00:00<?, ? examples/s]

Filter:   0%|          | 0/259 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/495 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/29 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/73 [00:00<?, ? examples/s]

Dataset saved to /content/ProcessedFiles


In [10]:


############################
# 2) Load Dataset & Model #
############################

dataset = load_from_disk(OUTPUT_PATH)


In [None]:

model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    model_name,
    num_labels=3,
)
model.gradient_checkpointing_enable()
optimizer = Adam8bit(model.parameters(), lr=2e-5)



In [12]:


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    report = classification_report(labels, predictions, target_names=["Healthy", "MCI", "AD"], output_dict=True)

    print("Classification Report:\n", classification_report(labels, predictions, target_names=["Healthy", "MCI", "AD"]))

    return {
        "accuracy": report["accuracy"],
        "f1_macro": report["macro avg"]["f1-score"],
        "f1_healthy": report["Healthy"]["f1-score"],
        "f1_mci": report["MCI"]["f1-score"],
        "f1_ad": report["AD"]["f1-score"],
    }

############################
# 3) Optimized Data Collator #
############################

def data_collator(features):
    waveforms = [torch.tensor(f["audio"]) for f in features]
    labels = torch.tensor([f["label"] for f in features])

    input_values = pad_sequence(waveforms, batch_first=True, padding_value=0)

    inputs = processor(
        input_values.numpy(),
        sampling_rate=16000,
        padding=True,
        return_tensors="pt"
    )

    inputs["labels"] = labels
    return inputs

############################
# 4) Training Arguments   #
############################

training_args = TrainingArguments(
    output_dir="./wav2vec2_classification",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=15,
    weight_decay=0.2,
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
    fp16=True,
    remove_unused_columns=False,
    gradient_accumulation_steps=2,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    optimizers=(optimizer, None),
)




In [None]:

trainer.train()

model.save_pretrained("./wav2vec2_trained")
processor.save_pretrained("./wav2vec2_trained")

print("Training complete! Model saved to ./wav2vec2_trained")


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Healthy,F1 Mci,F1 Ad
1,0.9279,1.026291,0.37931,0.304928,0.48,0.0,0.434783
2,1.0344,1.210188,0.241379,0.242105,0.2,0.0,0.526316
3,0.7392,1.060025,0.413793,0.371795,0.615385,0.125,0.375


Classification Report:
               precision    recall  f1-score   support

     Healthy       0.67      0.38      0.48        16
         MCI       0.00      0.00      0.00         7
          AD       0.29      0.83      0.43         6

    accuracy                           0.38        29
   macro avg       0.32      0.40      0.30        29
weighted avg       0.43      0.38      0.35        29

Classification Report:
               precision    recall  f1-score   support

     Healthy       0.50      0.12      0.20        16
         MCI       0.00      0.00      0.00         7
          AD       0.38      0.83      0.53         6

    accuracy                           0.24        29
   macro avg       0.29      0.32      0.24        29
weighted avg       0.36      0.24      0.22        29

Classification Report:
               precision    recall  f1-score   support

     Healthy       0.80      0.50      0.62        16
         MCI       0.11      0.14      0.12         7
   

Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Healthy,F1 Mci,F1 Ad
1,0.9279,1.026291,0.37931,0.304928,0.48,0.0,0.434783
2,1.0344,1.210188,0.241379,0.242105,0.2,0.0,0.526316
3,0.7392,1.060025,0.413793,0.371795,0.615385,0.125,0.375
4,0.9303,1.326774,0.37931,0.351277,0.434783,0.142857,0.47619


Classification Report:
               precision    recall  f1-score   support

     Healthy       0.71      0.31      0.43        16
         MCI       0.14      0.14      0.14         7
          AD       0.33      0.83      0.48         6

    accuracy                           0.38        29
   macro avg       0.40      0.43      0.35        29
weighted avg       0.50      0.38      0.37        29

