In [17]:
import torch
from torch.utils.data import Dataset
import pandas as pd
import torchaudio

class Wav2VecAudioDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads audio files and runs them through a Wav2Vec2 processor.

    Each item is ``(input_values, attention_mask, label)``:
      * ``input_values`` -- the processor's output for one clip with the batch dim removed;
      * ``attention_mask`` -- the processor's mask with the batch dim removed, or ``None``
        when the processor does not return one;
      * ``label`` -- the row's ``label`` value as a Python ``int``.

    Args:
        df: DataFrame with ``file_path`` and ``label`` columns. Its index is reset so
            positional lookups match ``range(len(self))``.
        processor: callable such as a ``Wav2Vec2Processor``; called with a 1-D numpy
            waveform plus ``sampling_rate``, ``return_tensors`` and ``padding`` keywords.
        target_sr: sample rate the processor/model expects. Files stored at a different
            rate are resampled to it before processing. Defaults to 16000.
    """

    def __init__(self, df, processor, target_sr=16000):
        self.df = df.reset_index(drop=True)
        self.processor = processor
        self.target_sr = target_sr

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = row["file_path"]
        label = int(row["label"])

        # torchaudio.load returns a (channels, frames) tensor and the file's native rate.
        waveform, sample_rate = torchaudio.load(path)
        # Downmix to mono. A bare squeeze() would leave a 2-D array for multi-channel
        # files, and the processor would then treat each channel as a separate clip.
        waveform = waveform.mean(dim=0)
        # Previously the native rate was discarded and the processor was always told
        # 16 kHz, so non-16 kHz files were silently mis-featurised.
        if sample_rate != self.target_sr:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sample_rate, new_freq=self.target_sr
            )
        speech_array = waveform.numpy()

        inputs = self.processor(speech_array, sampling_rate=self.target_sr, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)
        attention_mask = inputs.attention_mask.squeeze(0) if "attention_mask" in inputs else None

        return input_values, attention_mask, label



In [18]:
from sklearn.model_selection import train_test_split

df = pd.read_csv("labeled_data.csv")

# Two-stage stratified split with a fixed seed: 10% of all rows form the test set,
# then 10% of the remainder forms the validation set (~81% / 9% / 10% overall).
train_val_df, test_df = train_test_split(
    df,
    stratify=df["label"],
    test_size=0.1,
    random_state=42,
)
train_df, val_df = train_test_split(
    train_val_df,
    stratify=train_val_df["label"],
    test_size=0.1,
    random_state=42,
)

In [None]:
from transformers import Wav2Vec2Processor
from torch.utils.data import DataLoader

# One processor for the pretrained checkpoint, shared by every split.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

# The three datasets differ only in which rows of the split frames they read.
train_dataset, val_dataset, test_dataset = (
    Wav2VecAudioDataset(split_frame, processor)
    for split_frame in (train_df, val_df, test_df)
)

def collate_fn(batch):
    """Pad a list of dataset items into batched tensors for a DataLoader.

    Args:
        batch: list of ``(input_values, attention_mask, label)`` tuples, as produced by
            ``Wav2VecAudioDataset``. ``input_values`` are assumed to be 1-D tensors of
            varying length; ``attention_mask`` is a 1-D tensor or ``None``.

    Returns:
        ``(input_values_padded, attention_masks_padded, labels)``:
          * ``input_values_padded`` -- ``[B, T_max]``, right-padded with 0;
          * ``attention_masks_padded`` -- ``[B, T_max]`` right-padded with 0, or ``None``
            when no item carries a mask;
          * ``labels`` -- int64 tensor of shape ``[B]``.

    Raises:
        ValueError: if only some of the items in the batch carry an attention mask.
    """
    input_values = [item[0] for item in batch]
    attention_masks = [item[1] for item in batch]
    labels = torch.tensor([item[2] for item in batch], dtype=torch.long)

    input_values_padded = torch.nn.utils.rnn.pad_sequence(input_values, batch_first=True)

    has_mask = [mask is not None for mask in attention_masks]
    if all(has_mask):
        attention_masks_padded = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True)
    elif not any(has_mask):
        attention_masks_padded = None
    else:
        # The old any(...) check sent a list containing None to pad_sequence, which
        # failed with an unhelpful error; report the actual inconsistency instead.
        raise ValueError(
            "collate_fn: attention_mask is present for some batch items but None for others"
        )

    return input_values_padded, attention_masks_padded, labels


# Only the training split is shuffled; validation/test keep a fixed order so
# evaluation is repeatable. Every split pads its batches with collate_fn.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=8)
test_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=8)



In [20]:
from transformers import Wav2Vec2Model, Wav2Vec2Processor
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.optim as optim

class Wav2VecClassifier(nn.Module):
    """Clip-level classifier: Wav2Vec2 encoder -> self-attention -> mean pool -> MLP head.

    Args:
        num_classes: number of output logits.
        pretrained_name: Hugging Face checkpoint for the encoder
            (default ``"facebook/wav2vec2-base"``).
        num_heads: heads in the self-attention layer; must divide the encoder's hidden size.
        dropout: dropout probability inside the classification head.
    """

    def __init__(self, num_classes, pretrained_name="facebook/wav2vec2-base", num_heads=4, dropout=0.2):
        super().__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained(pretrained_name)
        # Read the width from the encoder config instead of hard-coding 768, so other
        # checkpoints (e.g. large variants) also work.
        hidden_size = self.wav2vec.config.hidden_size
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_heads, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, input_values, attention_mask=None):
        """Return ``[B, num_classes]`` logits.

        ``attention_mask`` is a sample-level mask over ``input_values``, or ``None``. When
        ``None``, every encoder frame (including any produced from zero padding) is attended
        and averaged.
        """
        outputs = self.wav2vec(input_values=input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # [B, T, H]

        if attention_mask is None:
            attn_output, _ = self.attention(hidden_states, hidden_states, hidden_states)
            pooled = attn_output.mean(dim=1)  # [B, H]
        else:
            # Convert the sample-level mask to the encoder's frame rate. This is a private
            # transformers helper (the one Wav2Vec2ForSequenceClassification uses);
            # NOTE(review): re-check it after transformers upgrades.
            frame_mask = self.wav2vec._get_feature_vector_attention_mask(
                hidden_states.shape[1], attention_mask
            ).bool()  # [B, T], True = real frame
            # key_padding_mask expects True where a key must be ignored.
            attn_output, _ = self.attention(
                hidden_states, hidden_states, hidden_states, key_padding_mask=~frame_mask
            )
            # Average only over real frames so padding does not dilute the clip embedding.
            weights = frame_mask.unsqueeze(-1).to(attn_output.dtype)
            pooled = (attn_output * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

        return self.classifier(pooled)



def train(model, dataloader, optimizer, device):
    """Run one training epoch.

    Args:
        model: module called as ``model(input_values, attention_mask)`` that returns
            ``[B, C]`` logits.
        dataloader: iterable of ``(input_values, attention_mask, labels)`` batches;
            ``attention_mask`` may be ``None``.
        optimizer: optimizer over ``model``'s parameters.
        device: device that each batch is moved to.

    Returns:
        ``(mean_loss, accuracy_pct)``. The loss is averaged per sample, so a smaller final
        batch is not over-weighted. Both values are ``0.0`` if no samples were seen.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for input_values, attention_mask, labels in dataloader:
        input_values = input_values.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(input_values, attention_mask)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # cross_entropy returns the batch mean; weight it by batch size so the epoch
        # figure is a true per-sample mean.
        loss_sum += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, correct / total * 100


def evaluate(model, dataloader, device):
    """Compute classification accuracy (in percent) without updating the model.

    Args:
        model: module called as ``model(input_values, attention_mask)`` that returns
            ``[B, C]`` logits.
        dataloader: iterable of ``(input_values, attention_mask, labels)`` batches;
            ``attention_mask`` may be ``None``.
        device: device that each batch is moved to.

    Returns:
        Accuracy in percent, or ``0.0`` if the dataloader yields no samples (instead of
        raising ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for input_values, attention_mask, labels in dataloader:
            input_values = input_values.to(device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            logits = model(input_values, attention_mask)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        return 0.0
    return correct / total * 100


# Seed torch's global RNG before building the model. This makes the random init of
# the attention/classifier layers, dropout, and the shuffled train_loader order
# repeatable across runs (full bit-exactness on GPU is not guaranteed).
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): 7 classes assumes the labels in labeled_data.csv are 0..6 — confirm.
model = Wav2VecClassifier(num_classes=7).to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_hid.bias', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'quantizer.codevectors', 'project_q.weight', 'project_q.bias', 'project_hid.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
EPOCHS = 10
# Start below any achievable accuracy so the first epoch always writes a checkpoint.
# With 0.0, a run whose validation accuracy never rises above 0% saves nothing, and
# the later torch.load of the checkpoint fails with FileNotFoundError.
best_val_acc = float("-inf")

for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, train_loader, optimizer, device)
    val_acc = evaluate(model, val_loader, device)

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

    # Keep only the weights from the epoch with the highest validation accuracy so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model_wav2vec_attention.pth")
        print("Saved new best model!")

In [None]:

# Restore the best-validation checkpoint before scoring the held-out test split.
# map_location remaps stored tensors onto the current device, so a checkpoint saved
# in a GPU session can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load("best_model_wav2vec_attention.pth", map_location=device))
test_acc = evaluate(model, test_loader, device)
print(f"Test Accuracy: {test_acc:.2f}%")