In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import os
import pickle

# Define a new combined model
class Wav2Vec2AudioClassifier(nn.Module):
    """Wav2Vec2 backbone followed by a small MLP classification head.

    The backbone's ``last_hidden_state`` is averaged over dim 1 and the pooled
    vector is passed through ``Linear -> ReLU -> Linear`` to produce
    ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        feature_extractor_model: Checkpoint name/path passed to
            ``Wav2Vec2Model.from_pretrained``.
        freeze_feature_extractor: When True (default, matching the previous
            behaviour of running the backbone under ``torch.no_grad()``), the
            backbone receives no gradients and is kept in eval mode even while
            the rest of the model is in training mode. Set to False to
            fine-tune the backbone together with the head.
    """

    def __init__(self, num_classes, feature_extractor_model="facebook/wav2vec2-large-960h",
                 freeze_feature_extractor=True):
        super().__init__()
        self.freeze_feature_extractor = freeze_feature_extractor
        self.wav2vec2 = Wav2Vec2Model.from_pretrained(feature_extractor_model)
        if freeze_feature_extractor:
            # Make the freeze explicit so the parameters report
            # requires_grad=False, and start the backbone in eval mode.
            self.wav2vec2.requires_grad_(False)
            self.wav2vec2.eval()
        self.classifier = nn.Sequential(
            nn.Linear(self.wav2vec2.config.hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def train(self, mode=True):
        """Switch train/eval mode, keeping a frozen backbone in eval mode.

        Without this, ``model.train()`` would also put the frozen backbone into
        training mode, enabling its dropout (and any other training-only
        behaviour) and making the "fixed" features noisy.
        """
        super().train(mode)
        if self.freeze_feature_extractor:
            self.wav2vec2.eval()
        return self

    def forward(self, input_values):
        """Return class logits for a batch of processed waveforms.

        Args:
            input_values: Tensor accepted by the Wav2Vec2 backbone
                (presumably ``(batch, num_samples)`` -- confirm against the
                processor output).

        Returns:
            Tensor of logits with ``num_classes`` entries per example.
        """
        if self.freeze_feature_extractor:
            # Frozen backbone: skip building the autograd graph for it.
            with torch.no_grad():
                features = self.wav2vec2(input_values).last_hidden_state
        else:
            features = self.wav2vec2(input_values).last_hidden_state
        features = features.mean(dim=1)  # Global average pooling
        return self.classifier(features)

# Function to load data
def load_data(dataset_path, data_type, processor, sampling_rate=16000, num_files=50):
    """Load pickled ``(waveform, label)`` pairs and preprocess each waveform.

    Looks in ``dataset_path`` for files named ``data_<i>_<data_type>.pkl`` with
    ``i`` running from 1 to ``num_files`` inclusive. Files that do not exist
    are skipped silently.

    Args:
        dataset_path: Directory containing the pickle shards.
        data_type: Split name embedded in the file name (e.g. ``"train"``).
        processor: Callable invoked as
            ``processor(array, return_tensors="pt", sampling_rate=...)`` whose
            result exposes ``.input_values``.
        sampling_rate: Sampling rate forwarded to ``processor``.
        num_files: Highest shard index to look for (default 50, the previously
            hard-coded value).

    Returns:
        ``(data, labels)`` where ``data`` is a list with one ``input_values``
        entry per example and ``labels`` is ``torch.tensor`` of the first
        element of every stored label.
    """
    data = []
    labels = []
    for file_number in range(1, num_files + 1):
        file_path = os.path.join(dataset_path, f"data_{file_number}_{data_type}.pkl")
        if not os.path.exists(file_path):
            continue

        # SECURITY: pickle.load can execute arbitrary code; only point this at
        # dataset files from a trusted source.
        with open(file_path, 'rb') as file:
            file_data = pickle.load(file)

        # The file is closed at this point; preprocessing does not need it.
        for waveform, label in file_data:
            input_values = processor(
                waveform.squeeze().numpy(),
                return_tensors="pt",
                sampling_rate=sampling_rate,
            ).input_values
            data.append(input_values)
            # Stored labels are indexable; only their first element is kept.
            labels.append(label[0])

    return data, torch.tensor(labels)

# Seed PyTorch so classifier initialisation and DataLoader shuffling are
# repeatable across Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Initialize processor and data loaders
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
dataset_path = 'emi_dataset/'

# NOTE: torch.cat along dim 0 requires every per-example tensor to have the
# same trailing shape, i.e. all waveforms must have the same length.
train_data, train_labels = load_data(dataset_path, "train", processor)
train_dataset = TensorDataset(torch.cat(train_data, dim=0), train_labels)  # Concatenate all data tensors
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

validate_data, validate_labels = load_data(dataset_path, "valid", processor)
validate_dataset = TensorDataset(torch.cat(validate_data, dim=0), validate_labels)
validate_loader = DataLoader(validate_dataset, batch_size=32)

# Initialize model, loss, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Wav2Vec2AudioClassifier(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 30
for epoch in range(num_epochs):
    model.train()
    # The backbone is run under torch.no_grad() in forward(), i.e. it is used
    # as a frozen feature extractor. Keep it in eval mode so its dropout and
    # other training-only behaviour do not perturb the features the
    # classifier head is trained on.
    model.wav2vec2.eval()
    for batch_idx, (input_values, labels) in enumerate(train_loader):
        input_values, labels = input_values.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(input_values)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Step {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    # Validation loop: average of per-batch losses and overall accuracy.
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    for input_values, labels in validate_loader:
        input_values, labels = input_values.to(device), labels.to(device)
        with torch.no_grad():
            outputs = model(input_values)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_loss /= len(validate_loader)
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {(100 * correct / total):.2f}%')


  from .autonotebook import tqdm as notebook_tqdm
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/30, Step 1/59, Loss: 2.3147
Epoch 1/30, Step 11/59, Loss: 2.2564
Epoch 1/30, Step 21/59, Loss: 2.3417
Epoch 1/30, Step 31/59, Loss: 2.3026
Epoch 1/30, Step 41/59, Loss: 2.3153
Epoch 1/30, Step 51/59, Loss: 2.3002
Epoch 1, Validation Loss: 2.2968, Validation Accuracy: 13.67%
Epoch 2/30, Step 1/59, Loss: 2.2895
Epoch 2/30, Step 11/59, Loss: 2.3030
Epoch 2/30, Step 21/59, Loss: 2.2979
Epoch 2/30, Step 31/59, Loss: 2.3080
Epoch 2/30, Step 41/59, Loss: 2.3007
Epoch 2/30, Step 51/59, Loss: 2.3034
Epoch 2, Validation Loss: 2.2913, Validation Accuracy: 16.33%
Epoch 3/30, Step 1/59, Loss: 2.2953
Epoch 3/30, Step 11/59, Loss: 2.2981
Epoch 3/30, Step 21/59, Loss: 2.3241
Epoch 3/30, Step 31/59, Loss: 2.2908
Epoch 3/30, Step 41/59, Loss: 2.2896
Epoch 3/30, Step 51/59, Loss: 2.2953
Epoch 3, Validation Loss: 2.2957, Validation Accuracy: 13.00%
Epoch 4/30, Step 1/59, Loss: 2.2971
Epoch 4/30, Step 11/59, Loss: 2.2985
Epoch 4/30, Step 21/59, Loss: 2.3003
Epoch 4/30, Step 31/59, Loss: 2.3058
Epoc

In [None]:
# Build the test DataLoader from the "test" shards (no shuffling).
test_data, test_labels = load_data(dataset_path, "test", processor)
test_dataset = TensorDataset(torch.cat(test_data, dim=0), test_labels)
test_loader = DataLoader(test_dataset, batch_size=32)

# Evaluate: accumulate the per-batch loss and the number of correct
# predictions, with autograd disabled for the whole pass.
model.eval()
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for input_values, labels in test_loader:
        input_values = input_values.to(device)
        labels = labels.to(device)

        outputs = model(input_values)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        # Predicted class = index of the largest logit in each row.
        predicted = torch.max(outputs, 1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Mean of the per-batch losses, and accuracy as a percentage.
test_loss /= len(test_loader)
test_accuracy = 100 * correct / total
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
