<a href="https://colab.research.google.com/github/Youruler1/Speech-Processing-Lab-Material/blob/main/wav2vec1_0_(Updated).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler

# ---- cuDNN autotuning ----
# Let cuDNN benchmark candidate convolution algorithms and reuse the fastest
# one per input shape. This pays off here because every clip is padded or
# truncated to a fixed length, so input shapes stay constant between batches
# (except possibly a smaller final batch).
torch.backends.cudnn.benchmark = True

# ---- Audio Preprocessing ----
# Cache of Resample transforms keyed by (orig_freq, new_freq); building a
# transform recomputes its filter kernel, so reuse one per rate pair.
_RESAMPLER_CACHE = {}


def preprocess_audio(waveform, sample_rate, target_length=16000, target_sample_rate=16000):
    """Convert a waveform into a mono, fixed-length clip at ``target_sample_rate``.

    Args:
        waveform: Audio tensor of shape ``(channels, samples)``, or a 1-D
            ``(samples,)`` tensor which is treated as a single channel.
        sample_rate: Sample rate of ``waveform`` in Hz.
        target_length: Number of samples in the result. Shorter clips are
            right-padded with zeros; longer clips are truncated.
        target_sample_rate: Rate (Hz) to resample to; defaults to 16 kHz.

    Returns:
        Tensor of shape ``(1, target_length)``.
    """
    # A 1-D tensor has no channel axis. Without this, the downmix below would
    # average every sample into a single value.
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    # Downmix multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != target_sample_rate:
        key = (sample_rate, target_sample_rate)
        resampler = _RESAMPLER_CACHE.get(key)
        if resampler is None:
            resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
            _RESAMPLER_CACHE[key] = resampler
        waveform = resampler(waveform)

    num_samples = waveform.shape[1]
    if num_samples < target_length:
        waveform = F.pad(waveform, (0, target_length - num_samples))
    else:
        waveform = waveform[:, :target_length]
    return waveform

# ---- Dataset Wrapper ----
class SpeechCommandsDataset(torch.utils.data.Dataset):
    """Map-style dataset that adapts a SPEECHCOMMANDS-like source for classification.

    Each item of the wrapped ``dataset`` is unpacked as
    ``(waveform, sample_rate, label, ...)``. Extra trailing fields are ignored.
    The waveform is normalised with ``preprocess_audio``, and the string label
    is translated to an integer class index through ``label_map``.
    """

    def __init__(self, dataset, label_map):
        # Indexable source of (waveform, sample_rate, label, ...) records.
        self.dataset = dataset
        # Maps each label string to its integer class id.
        self.label_map = label_map

    def __len__(self):
        size = len(self.dataset)
        return size

    def __getitem__(self, idx):
        audio, sr, word, *_unused = self.dataset[idx]
        clip = preprocess_audio(audio, sr)
        class_id = self.label_map[word]
        return clip, class_id

# ---- Classifier Model ----
class Wav2VecClassifier(nn.Module):
    """Utterance classifier built from a feature encoder, a context network and a linear head.

    The forward pass runs ``feature_encoder`` and then ``context_network``.
    It mean-pools the result over dim 1 and projects the 512-wide pooled
    vector to ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # NOTE(review): FeatureEncoder and ContextNetwork are not defined in the
        # visible file. They must be provided elsewhere (e.g. an earlier notebook
        # cell), or construction raises NameError. Confirm.
        self.feature_encoder = FeatureEncoder()
        self.context_network = ContextNetwork()
        # Assumes the context network emits 512 features per step. TODO confirm.
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        encoded = self.feature_encoder(x)
        contextual = self.context_network(encoded)
        # Assumes dim 1 is the time axis, i.e. context output is
        # (batch, time, 512). TODO confirm against ContextNetwork.
        pooled = contextual.mean(dim=1)
        logits = self.fc(pooled)
        return logits

# ---- Training Loop with Mixed Precision ----
def train_model(model, train_loader, criterion, optimizer, num_epochs=5, device='cuda'):
    """Train ``model`` and print per-epoch loss/accuracy, using mixed precision on CUDA.

    Automatic mixed precision (``autocast`` + ``GradScaler``) is enabled only
    when ``device`` is a CUDA device. Both utilities are CUDA-specific, so on
    other devices training runs in full precision.

    Args:
        model: Classifier whose forward returns ``(batch, num_classes)`` logits.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``. The epoch
            loss assumes it returns the per-batch *mean* (the default for
            ``nn.CrossEntropyLoss``).
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Target device, as a string or ``torch.device``.

    Returns:
        A list with one ``(mean_loss, accuracy)`` tuple per epoch. Both values
        are ``nan`` for an epoch that saw no samples.
    """
    use_amp = torch.device(device).type == 'cuda'
    scaler = GradScaler(enabled=use_amp)
    model.train()
    history = []

    for epoch in range(num_epochs):
        total_loss, total_correct, total_samples = 0.0, 0, 0
        for waveforms, labels in train_loader:
            waveforms = waveforms.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(waveforms)
                loss = criterion(outputs, labels)

            # With AMP disabled, the scaler calls pass through unchanged.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            batch_size = labels.size(0)
            # Weight each batch-mean loss by its size, so a short final batch
            # does not skew the epoch average.
            total_loss += loss.item() * batch_size
            total_correct += (outputs.argmax(1) == labels).sum().item()
            total_samples += batch_size

        if total_samples:
            epoch_loss = total_loss / total_samples
            epoch_acc = total_correct / total_samples
        else:
            # Empty loader: report nan instead of dividing by zero.
            epoch_loss = epoch_acc = float('nan')
        history.append((epoch_loss, epoch_acc))
        print(f"Epoch {epoch+1}: Loss={epoch_loss:.4f}, Accuracy={epoch_acc:.4f}")

    return history

# ---- Main Execution ----
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): no `subset=` is passed. Per the torchaudio docs this loads
    # the whole corpus, including validation/test clips. Confirm intended.
    dataset = torchaudio.datasets.SPEECHCOMMANDS(root="./data", download=True)

    # Build the label vocabulary. `get_metadata` returns each item's label
    # without decoding its audio; iterating `dataset` directly would load every
    # clip from disk just to read one string field.
    if hasattr(dataset, "get_metadata"):
        all_labels = (dataset.get_metadata(i)[2] for i in range(len(dataset)))
    else:  # older torchaudio releases without get_metadata
        all_labels = (entry[2] for entry in dataset)
    labels = sorted(set(all_labels))
    label_map = {label: i for i, label in enumerate(labels)}

    train_dataset = SpeechCommandsDataset(dataset, label_map)
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
        pin_memory=device.type == "cuda",
        num_workers=4,
    )

    model = Wav2VecClassifier(num_classes=len(label_map)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    train_model(model, train_loader, criterion, optimizer, num_epochs=10, device=device)
