<a href="https://colab.research.google.com/github/HamdanXI/nlp_adventure/blob/main/speech-privacy/bigger-training-data-lower-performance-wav2vec2-gender-prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install datasets



In [4]:
################################################################################
# 1. IMPORTS
################################################################################
import torch
import torchaudio
import numpy as np
import random
import matplotlib.pyplot as plt

from datasets import load_dataset, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2Model
)
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict
from typing import Any, Dict

################################################################################
# 2. LOAD THE DATASET
################################################################################
dataset = load_dataset("HamdanXI/speech-accent-archive-v2")["train"]
print("Full dataset size:", len(dataset))
print("Columns:", dataset.column_names)
print("Unique speakerid:", len(set(dataset["speakerid"])))

################################################################################
# 3. CHOOSE (A) "TRAIN" + (B) "TEST" SPEAKERS
################################################################################
A = 120
B = 40

# We'll focus on English native speakers only, for male/female
male_speakers_all = list({
    ex["speakerid"] for ex in dataset
    if ex["sex"] == "male" and ex["native_language"].lower() == "english"
})
female_speakers_all = list({
    ex["speakerid"] for ex in dataset
    if ex["sex"] == "female" and ex["native_language"].lower() == "english"
})

# Pick train/test subsets of speakers
random.seed(42)
chosen_male_train = random.sample(male_speakers_all, A)
chosen_female_train = random.sample(female_speakers_all, A)

# Remove them so we don't re-pick them for test
for spk in chosen_male_train:
    male_speakers_all.remove(spk)
for spk in chosen_female_train:
    female_speakers_all.remove(spk)

chosen_male_test = random.sample(male_speakers_all, B)
chosen_female_test = random.sample(female_speakers_all, B)

chosen_speakers_train = chosen_male_train + chosen_female_train
chosen_speakers_test  = chosen_male_test + chosen_female_test

print(f"Chosen training male:   {chosen_male_train}")
print(f"Chosen training female: {chosen_female_train}")
print(f"Chosen test male:       {chosen_male_test}")
print(f"Chosen test female:     {chosen_female_test}")

# Filter dataset into train / test based on speaker IDs
train_ds = dataset.filter(lambda x: x["speakerid"] in chosen_speakers_train)
test_ds  = dataset.filter(lambda x: x["speakerid"] in chosen_speakers_test)

################################################################################
# 4. LOAD Wav2Vec2 PROCESSOR & MODEL (FOR EMBEDDING EXTRACTION)
################################################################################
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
# Use Wav2Vec2Model (no CTC head) for speaker-embedding extraction
wav2vec2_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

# Optionally freeze the entire model if you do NOT want to fine-tune it:
for param in wav2vec2_model.parameters():
    param.requires_grad = False

wav2vec2_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wav2vec2_model.to(device)

################################################################################
# 5. EMBEDDING EXTRACTION FUNCTION
################################################################################
def extract_wav2vec2_embedding(batch):
    """
    1) Resample audio to 16k if needed.
    2) Pass through Wav2Vec2 processor and model to get hidden states.
    3) Mean-pool over time dimension to get a single speaker embedding.
    4) Return that embedding plus the label (0=male, 1=female).
    """
    sr = batch["audio"]["sampling_rate"]
    audio_array = batch["audio"]["array"]
    wave_tensor = torch.tensor(audio_array, dtype=torch.float32)

    if sr != 16000:
        wave_tensor = torchaudio.functional.resample(wave_tensor, sr, 16000)

    # Process into model input format
    inputs = processor(wave_tensor, sampling_rate=16000, return_tensors="pt")

    # Move to GPU if available
    input_values = inputs["input_values"].to(device)
    attention_mask = inputs["attention_mask"].to(device) if "attention_mask" in inputs else None

    with torch.no_grad():
        outputs = wav2vec2_model(input_values, attention_mask=attention_mask)
        # outputs.last_hidden_state shape: (batch_size=1, time_steps, hidden_dim)
        hidden_states = outputs.last_hidden_state[0]  # shape: (time_steps, hidden_dim)

    # Mean pooling over time dimension -> single (hidden_dim,) embedding
    embedding = hidden_states.mean(dim=0).cpu().numpy()  # shape: (hidden_dim,)

    # Convert sex to label: 0 = male, 1 = female
    label = 0 if batch["sex"] == "male" else 1

    return {
        "speakerid": batch["speakerid"],
        "embedding": embedding,
        "label":     label
    }

################################################################################
# 6. BUILD NEW DATASET WITH EMBEDDINGS
################################################################################
# We map the train_ds and test_ds to new columns: "embedding", "label"
# CAUTION: Doing map over large sets might be slow in pure Python. For demonstration it's fine.
train_with_emb = train_ds.map(extract_wav2vec2_embedding)
test_with_emb  = test_ds.map(extract_wav2vec2_embedding)

# Keep only "embedding" and "label" columns
train_with_emb = train_with_emb.remove_columns(
    [col for col in train_with_emb.column_names if col not in ["embedding", "label"]]
)
test_with_emb = test_with_emb.remove_columns(
    [col for col in test_with_emb.column_names if col not in ["embedding", "label"]]
)

print(f"train_with_emb length: {len(train_with_emb)}")
print(f"test_with_emb length:  {len(test_with_emb)}")

################################################################################
# 7. CONVERT EMBEDDINGS TO NUMPY/torch FOR CLASSIFICATION
################################################################################
# Turn the datasets into lists of (embedding, label)
X_train = np.array([example["embedding"] for example in train_with_emb], dtype=np.float32)
y_train = np.array([example["label"] for example in train_with_emb], dtype=np.int64)

X_test = np.array([example["embedding"] for example in test_with_emb], dtype=np.float32)
y_test = np.array([example["label"] for example in test_with_emb], dtype=np.int64)

print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
print("X_test shape:",  X_test.shape)
print("y_test shape:",  y_test.shape)

################################################################################
# 8. TRAIN A SMALL MLP CLASSIFIER ON THE EMBEDDINGS
################################################################################
# If the base Wav2Vec2 model is "base-960h", the hidden dimension is 768.
input_dim = X_train.shape[1]

class SexNet(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=128):
        super(SexNet, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 2)  # 2 outputs (male vs. female)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Create torch datasets
X_train_tensor = torch.from_numpy(X_train)
y_train_tensor = torch.from_numpy(y_train)

X_test_tensor = torch.from_numpy(X_test)
y_test_tensor = torch.from_numpy(y_test)

train_dataset_torch = torch.utils.data.TensorDataset(X_train_tensor, y_train_tensor)
test_dataset_torch  = torch.utils.data.TensorDataset(X_test_tensor, y_test_tensor)

train_loader = torch.utils.data.DataLoader(train_dataset_torch, batch_size=8, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset_torch, batch_size=8, shuffle=False)

# Instantiate the model
model_nn = SexNet(input_dim=input_dim, hidden_dim=128).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_nn.parameters(), lr=1e-3)
epochs = 10

model_nn.train()
for epoch in range(epochs):
    running_loss = 0.0
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        outputs = model_nn(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")

################################################################################
# 9. EVALUATE ON TEST SET
################################################################################
model_nn.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device)
        logits = model_nn(batch_x)
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(batch_y.numpy())

accuracy = accuracy_score(all_labels, all_preds)
print(f"\nTest Accuracy on speaker-level utterances: {accuracy*100:.2f}%")

################################################################################
# DONE
################################################################################
print("Finished speaker-attribute classification with Wav2Vec2 embeddings.")

Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]

Full dataset size: 2138
Columns: ['age', 'age_onset', 'birthplace', 'filename', 'native_language', 'sex', 'speakerid', 'country', 'file_missing?', 'file_exists', '__index_level_0__', 'audio']
Unique speakerid: 2138
Chosen training male:   [137, 71, 673, 584, 547, 161, 132, 1642, 2172, 1226, 76, 2121, 127, 536, 555, 1469, 73, 1709, 521, 1895, 1222, 538, 1295, 695, 2077, 443, 1878, 876, 1672, 427, 535, 869, 1962, 1800, 1075, 129, 937, 889, 637, 86, 1312, 148, 1072, 2165, 767, 951, 517, 112, 92, 551, 740, 1536, 1733, 131, 516, 1970, 1477, 821, 1220, 508, 120, 509, 496, 133, 863, 155, 900, 1325, 1650, 835, 81, 776, 1384, 2171, 638, 952, 145, 1363, 554, 1388, 156, 824, 1560, 1649, 136, 884, 446, 1225, 1078, 1100, 74, 138, 1188, 65, 1163, 439, 522, 1340, 1876, 134, 734, 1564, 1377, 1176, 855, 586, 1720, 825, 552, 113, 153, 1530, 1544, 1051, 720, 662, 1242, 1067, 754, 534]
Chosen training female: [871, 1353, 1716, 2162, 490, 981, 1105, 1070, 1113, 636, 639, 1925, 574, 1374, 1372, 1293, 1214, 

Filter:   0%|          | 0/2138 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2138 [00:00<?, ? examples/s]

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/240 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

train_with_emb length: 240
test_with_emb length:  80
X_train shape: (240, 768)
y_train shape: (240,)
X_test shape: (80, 768)
y_test shape: (80,)
Epoch 1/10 - Loss: 0.7040
Epoch 2/10 - Loss: 0.6956
Epoch 3/10 - Loss: 0.6949
Epoch 4/10 - Loss: 0.6940
Epoch 5/10 - Loss: 0.6942
Epoch 6/10 - Loss: 0.6929
Epoch 7/10 - Loss: 0.6927
Epoch 8/10 - Loss: 0.6921
Epoch 9/10 - Loss: 0.6918
Epoch 10/10 - Loss: 0.6910

Test Accuracy on speaker-level utterances: 50.00%
Finished speaker-attribute classification with Wav2Vec2 embeddings.
