In [134]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModel
import torchaudio
from torch.utils.data import DataLoader
import os
import sys
sys.path.append("..")
from src.evaluation.representations import compute_linear_eval


In [136]:
data_folder = "../data/speechcommands"

# Load the full dataset (v0.02 with 35 classes)
import torchaudio
dataset = torchaudio.datasets.SPEECHCOMMANDS(
    root=data_folder,
    download=True,
)


def collate_fn(batch):

    word_dict = {
        "yes": 0,
        "no": 1,
        "up": 2,
        "down": 3,
        "left": 4,
    }
    """
    Processes a batch of samples to make waveforms uniform in length.
    A batch is a list of tuples: (waveform, sample_rate, label, speaker_id, utterance_number)
    """
    target_length = 16000  # 1 second at 16kHz
    waveforms = []
    labels = []

    # Process each sample in the batch
    for waveform, _, label, _, _ in batch:
        # Pad the waveform if it's shorter than the target length
        if waveform.shape[1] < target_length:
            padding_needed = target_length - waveform.shape[1]
            padding = torch.zeros((1, padding_needed))
            waveform = torch.cat([waveform, padding], dim=1)
        # Truncate the waveform if it's longer
        else:
            waveform = waveform[:, :target_length]
        

        waveforms.append(waveform)
        labels.append(label) # You'll likely want to convert these strings to numbers later

    # Stack the processed waveforms into a single tensor
    waveforms_tensor = torch.stack(waveforms)
    
    return waveforms_tensor.squeeze(1), labels

# When creating your DataLoader, pass your custom function
# Assuming 'dataset' is your torchaudio.datasets.SPEECHCOMMANDS object
dataloader = DataLoader(
    dataset, 
    batch_size=100, 
    shuffle=False, 
    collate_fn=collate_fn  # This is the key part
)




100%|██████████| 2.26G/2.26G [00:54<00:00, 44.7MB/s]


In [132]:
dataset.


AttributeError: 'SPEECHCOMMANDS' object has no attribute 'labels'

In [65]:
from transformers import Wav2Vec2Model
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")

device = 'mps'
model.to(device)
labels = []
representations = []
from tqdm import tqdm
with torch.no_grad():
    for batch in tqdm(dataloader):
        waveform, label = batch
        
        labels += label
        representations.append(model(waveform.to(device)).last_hidden_state.cpu())
representations = torch.cat(representations, dim=0)


  WeightNorm.apply(module, name, dim)
Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_hid.weight', 'project_hid.bias', 'project_q.bias', 'quantizer.codevectors', 'quantizer.weight_proj.bias', 'project_q.weight', 'quantizer.weight_proj.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 1059/1059 [04:45<00:00,  3.71it/s]


In [67]:
feat_mean,feat_std = representations.mean(dim=1),representations.std(dim=1)
cat_feat = torch.cat([feat_mean,feat_std],dim=1)


In [84]:
speech_labels = [
    'backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow',
    'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no',
    'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree',
    'two', 'up', 'visual', 'wow', 'yes', 'zero'
]
int_labels = [speech_labels.index(label) for label in labels]
int_labels = torch.tensor(int_labels).int()










In [122]:
dataset_len = len(int_labels)
train_len = int(dataset_len * 0.8)
test_len = int(dataset_len * 0.2)

idx = torch.randperm(dataset_len)
train_idx = idx[:train_len]
test_idx = idx[train_len:]

train_representations = feat_mean[train_idx]
test_representations = feat_mean[test_idx]
train_labels = int_labels[train_idx]
test_labels = int_labels[test_idx]

train_dataset = torch.utils.data.TensorDataset(train_representations, train_labels)
test_dataset = torch.utils.data.TensorDataset(test_representations, test_labels)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1000, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)


In [123]:
len(test_idx)

21166

In [125]:
from torch import nn

class LinearClassifier(nn.Module):
    def __init__(self,input_dim,num_classes):
        super(LinearClassifier,self).__init__()
        self.linear = nn.Linear(input_dim,num_classes)
    
    def forward(self,x):
        return self.linear(x)
    
class AttentionClassifier(nn.Module):
    def __init__(self,input_dim,num_classes):
        super(AttentionClassifier,self).__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, input_dim))
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=input_dim,
            nhead=8,
            dim_feedforward=1536,
            dropout=0.1,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)
        self.classifier = nn.Linear(input_dim, num_classes)
        
    def forward(self, x):
        # x shape: (bs, 49, 768)
        batch_size = x.size(0)
        
        # Expand cls_token for batch
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (bs, 1, 768)
        
        # Use input sequence as memory for decoder
        memory = x  # (bs, 49, 768)
        
        # Pass cls_token through transformer decoder
        decoded = self.transformer_decoder(cls_tokens, memory)  # (bs, 1, 768)
        
        # Extract cls token and classify
        cls_output = decoded.squeeze(1)  # (bs, 768)
        logits = self.classifier(cls_output)  # (bs, 35)
        
        return logits

model = LinearClassifier(768,35)
model.to('mps')
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
from tqdm import tqdm

for epoch in range(30):
    lossval = 0   
    epoch_bar = tqdm(train_dataloader,desc=f'Epoch {epoch+1}')  
    for batch in epoch_bar:
        features, labels = batch
        optimizer.zero_grad()
        outputs = model(features.to('mps'))
        loss = criterion(outputs, labels.to('mps'))
        loss.backward()
        lossval += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        epoch_bar.set_postfix({'Loss': f'{loss.item():.4f}'})










Epoch 1: 100%|██████████| 85/85 [00:00<00:00, 138.89it/s, Loss=2.6595]
Epoch 2: 100%|██████████| 85/85 [00:00<00:00, 87.54it/s, Loss=2.2478] 
Epoch 3: 100%|██████████| 85/85 [00:00<00:00, 116.67it/s, Loss=1.9485]
Epoch 4: 100%|██████████| 85/85 [00:00<00:00, 109.93it/s, Loss=1.7662]
Epoch 5: 100%|██████████| 85/85 [00:00<00:00, 120.22it/s, Loss=1.6011]
Epoch 6: 100%|██████████| 85/85 [00:00<00:00, 91.33it/s, Loss=1.5108]
Epoch 7: 100%|██████████| 85/85 [00:00<00:00, 120.11it/s, Loss=1.3742]
Epoch 8: 100%|██████████| 85/85 [00:00<00:00, 119.63it/s, Loss=1.3044]
Epoch 9: 100%|██████████| 85/85 [00:00<00:00, 118.14it/s, Loss=1.1483]
Epoch 10: 100%|██████████| 85/85 [00:00<00:00, 87.47it/s, Loss=1.1267]
Epoch 11: 100%|██████████| 85/85 [00:00<00:00, 115.16it/s, Loss=1.1384]
Epoch 12: 100%|██████████| 85/85 [00:00<00:00, 117.58it/s, Loss=1.0657]
Epoch 13: 100%|██████████| 85/85 [00:00<00:00, 117.97it/s, Loss=1.0502]
Epoch 14: 100%|██████████| 85/85 [00:00<00:00, 85.33it/s, Loss=1.0164]
Epoc

In [126]:


model.eval()
predictions = []
labels = []
with torch.no_grad():
    for batch in test_dataloader:
        features, label = batch
        
        predictions.append(model(features.to('mps')).cpu().argmax(dim=1))
        labels.append(label)

predictions = torch.cat(predictions, dim=0)
labels = torch.cat(labels, dim=0)

accuracy = (predictions == labels).float().mean()
print(f"Accuracy: {accuracy:.4f}")





Accuracy: 0.8281


In [None]:
from torchaudio import transforms

In [133]:
#get the pca of the cat features

from sklearn.decomposition import PCA

pca = PCA(n_components=20)
pca.fit(cat_feat)
pca_feat = pca.transform(cat_feat)

from sklearn.metrics import normalized_mutual_info_score

from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=35,random_state=0)
kmeans.fit(pca_feat)

print(normalized_mutual_info_score(kmeans.labels_,int_labels))

0.12074510112527835
