# In [None]:
import os
import zipfile
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification, TrainingArguments, Trainer
import requests
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# Step 1: Download & Extract the Dataset
DATASET_URL = "https://github.com/Jakobovski/free-spoken-digit-dataset/archive/refs/heads/master.zip"
DATASET_PATH = "fsdd"
ARCHIVE_PATH = "fsdd.zip"

if not os.path.exists(DATASET_PATH):
    print("Downloading dataset...")
    response = requests.get(DATASET_URL, timeout=120)
    # Fail loudly on an HTTP error instead of saving an error page as the
    # archive and crashing later with a confusing BadZipFile.
    response.raise_for_status()
    with open(ARCHIVE_PATH, "wb") as f:
        f.write(response.content)

    # Extract into a staging directory and move it into place only once
    # extraction has finished. Otherwise an interrupted run would leave a
    # partial DATASET_PATH that the existence check above treats as complete.
    staging_path = DATASET_PATH + ".partial"
    with zipfile.ZipFile(ARCHIVE_PATH, "r") as zip_ref:
        zip_ref.extractall(staging_path)
    os.replace(staging_path, DATASET_PATH)

    print("Dataset extracted successfully.")

# Step 2: Load Dataset & Process Audio
audio_folder = os.path.join(DATASET_PATH, "free-spoken-digit-dataset-master", "recordings")
# Sort the listing: os.listdir order is arbitrary and filesystem-dependent, and
# the seeded train/test split later on is only reproducible if the input order is.
audio_files = sorted(f for f in os.listdir(audio_folder) if f.endswith(".wav"))

labels = []
waveforms = []
sampling_rate = 16000  # Standard for Wav2Vec2

# One Resample transform per source rate. Constructing it builds a filter
# kernel, so reuse it across files instead of rebuilding it every iteration.
resamplers = {}

for file in audio_files:
    file_path = os.path.join(audio_folder, file)
    waveform, sr = torchaudio.load(file_path)  # (channels, samples)
    if sr != sampling_rate:
        if sr not in resamplers:
            resamplers[sr] = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sampling_rate)
        waveform = resamplers[sr](waveform)

    # Average across channels so every stored waveform is 1-D. For mono files
    # this equals the old squeeze(0); for stereo files it avoids a 2-D entry.
    waveform = waveform.mean(dim=0)

    label = file.split("_")[0]  # Filenames look like "<digit>_<speaker>_<index>.wav"
    labels.append(label)
    waveforms.append(waveform)

# Encode labels: map the digit strings to contiguous integer class ids
# (fit learns the sorted class list, transform maps each label to its index).
label_encoder = LabelEncoder()
label_encoder.fit(labels)
labels = label_encoder.transform(labels)

# Train-Test Split: hold out 20% of the clips for evaluation, with a fixed
# seed so the split is repeatable for a given input order.
X_train, X_test, y_train, y_test = train_test_split(
    waveforms,
    labels,
    test_size=0.2,
    random_state=42,
)

# Load the Wav2Vec2 audio preprocessor.
# Digit classification has no text to tokenize, so only the feature-extractor
# half of Wav2Vec2Processor is ever called. Loading just the feature extractor
# needs only the checkpoint's preprocessor config, not tokenizer files. The call
# signature used by the dataset (sampling_rate, padding, max_length, ...) is
# the same. The name `processor` is kept for the code that follows.
from transformers import Wav2Vec2FeatureExtractor

processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")

# Step 3: Create Custom Dataset for Wav2Vec2
class FSDDDataset(Dataset):
    """Map-style dataset pairing raw waveforms with integer class labels.

    Each item is run through ``processor``. It is returned as a dict of model
    inputs with the processor's leading batch dimension removed, plus
    ``"labels"`` as an int64 scalar tensor.

    Args:
        waveforms: sequence of per-clip tensors. Presumably 1-D and already
            resampled to ``sampling_rate`` -- verify against the caller.
        labels: sequence of integer class ids aligned with ``waveforms``.
        processor: callable with the Hugging Face feature-extractor call
            signature. It must accept the ``sampling_rate``, ``return_tensors``,
            ``padding``, ``max_length`` and ``truncation`` keywords and return
            a mapping of tensors with a leading batch dimension of 1.
        sampling_rate: rate in Hz reported to the processor. The default of
            16000 matches the rate this script resamples to.
        max_length: clips shorter than this many samples are padded up to it.
        truncation: if True, clips longer than ``max_length`` are cut to it, so
            every item has exactly ``max_length`` samples. If False (the
            default), long clips keep their full length and the batch
            collator is responsible for padding them.
    """

    def __init__(self, waveforms, labels, processor, sampling_rate=16000,
                 max_length=16000, truncation=False):
        self.waveforms = waveforms
        self.labels = labels
        self.processor = processor
        self.sampling_rate = sampling_rate
        self.max_length = max_length
        self.truncation = truncation

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, idx):
        inputs = self.processor(
            self.waveforms[idx].numpy(),
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=self.truncation,
        )
        # The processor always returns a batch of one. Drop that dimension from
        # every tensor it produced, not just input_values, so all outputs agree.
        item = {key: value.squeeze(0) for key, value in inputs.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item


# Wrap each split in the Dataset adapter; both splits share one processor.
train_dataset, test_dataset = (
    FSDDDataset(X_train, y_train, processor),
    FSDDDataset(X_test, y_test, processor),
)

# Step 4: Load Pretrained Wav2Vec2 Model for Classification
num_classes = len(label_encoder.classes_)
# Record the class names in the model config. Otherwise the saved checkpoint
# only knows generic LABEL_0..LABEL_n names. Index i matches the integer id
# that label_encoder assigned to classes_[i].
id2label = {i: str(name) for i, name in enumerate(label_encoder.classes_)}
label2id = {name: i for i, name in id2label.items()}
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=num_classes,
    id2label=id2label,
    label2id=label2id,
)

# Step 5: Fine-Tuning with Trainer API
_training_kwargs = dict(
    output_dir="./results",
    save_strategy="epoch",  # must match the eval strategy for load_best_model_at_end
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=True,
)
# transformers renamed `evaluation_strategy` to `eval_strategy` (deprecated in
# v4.41) and later releases reject the old keyword. Try the new name first and
# fall back for older installations that don't know it yet.
try:
    training_args = TrainingArguments(eval_strategy="epoch", **_training_kwargs)
except TypeError:
    training_args = TrainingArguments(evaluation_strategy="epoch", **_training_kwargs)

from torch.nn.utils.rnn import pad_sequence

# Custom collate function to handle variable-length input tensors
def collate_fn(batch):
    """Combine dataset items into one padded batch for the Trainer.

    Args:
        batch: list of dicts. Each has an ``"input_values"`` tensor (assumed
            1-D, one clip) and a scalar ``"labels"`` entry, given either as a
            0-d tensor or a plain int.

    Returns:
        dict with ``"input_values"`` of shape (len(batch), longest_clip),
        right-padded with 0.0, and ``"labels"`` as a 1-D int64 tensor.
    """
    input_values = [item["input_values"] for item in batch]
    # Right-pad every clip with zeros up to the longest clip in this batch.
    input_values_padded = pad_sequence(input_values, batch_first=True, padding_value=0.0)

    # Stack the per-item scalars directly rather than passing a Python list of
    # tensors to torch.tensor. as_tensor also accepts plain int labels.
    labels_tensor = torch.stack([torch.as_tensor(item["labels"]) for item in batch]).to(torch.long)

    return {"input_values": input_values_padded, "labels": labels_tensor}

def compute_metrics(eval_pred):
    """Return classification accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: transformers ``EvalPrediction`` holding the model outputs
            (``predictions``, logits over classes) and ``label_ids``.

    Returns:
        ``{"accuracy": float}``. Trainer reports it as ``eval_accuracy``.
    """
    logits = eval_pred.predictions
    # Predictions may arrive as a tuple if the model emits extra outputs;
    # the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float(np.mean(predictions == eval_pred.label_ids))}


# Use the custom collate function and accuracy metric in the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=processor,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

# Step 6: Evaluate Model
# With load_best_model_at_end=True this evaluates the best checkpoint. Note that
# eval_loss is cross-entropy, not accuracy; accuracy comes from compute_metrics.
results = trainer.evaluate()
print(f"Test Loss: {results['eval_loss']:.4f}")
print(f"Test Accuracy: {100 * results['eval_accuracy']:.2f}%")
