In [None]:
import os
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
from transformers import (
    Trainer,
    TrainingArguments,
    Wav2Vec2ForCTC,
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2Model,
    Wav2Vec2Processor,
)

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Define the AudioCNN class with adjustments for input size
class AudioCNN(nn.Module):
    """Small 1-D CNN that classifies fixed-length feature vectors.

    Three stages of conv(kernel 3, padding 1) -> ReLU -> max-pool(2) turn a
    ``(batch, input_size)`` input into ``(batch, 64, input_size // 8)``; two
    fully connected layers then map the flattened result to class logits.

    Args:
        num_classes: Number of output logits.
        input_size: Length of every input vector; it fixes ``fc1``'s width.
    """

    def __init__(self, num_classes, input_size):
        super().__init__()
        # Attribute names and creation order are kept so state_dict keys and
        # seeded parameter initialisation are unchanged.
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        # The padded convs keep the length; each pool floors it by half, so
        # after three pools the length is input_size // 8 across 64 channels.
        flattened_width = 64 * (input_size // 8)
        self.fc1 = nn.Linear(flattened_width, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a ``(batch, input_size)`` tensor to ``(batch, num_classes)`` logits."""
        out = x.unsqueeze(1)  # add the single input channel: (batch, 1, input_size)
        for conv in (self.conv1, self.conv2, self.conv3):
            out = self.pool(F.relu(conv(out)))
        out = out.view(out.size(0), -1)
        hidden = F.relu(self.fc1(out))
        return self.fc2(hidden)

# Dataset class for Wav2Vec2
class Wav2Vec2Dataset(Dataset):
    """Processed waveforms and labels loaded from pickled shards.

    Shards are named ``data_{i}_{data_type}.pkl`` for ``i`` in ``1..max_files``
    (``1..50`` when ``max_files`` is None); missing shards are skipped. Each
    shard must unpickle to an iterable of 2-element lists or tuples of the form
    ``(waveform, (label_tensor, label_string))``.

    Items are returned as ``{"input_values": ..., "labels": ...}`` dicts.

    Args:
        dataset_path: Directory containing the shards.
        data_type: Split name embedded in the shard filename (e.g. "train").
        processor: Callable applied to each squeezed waveform (as a NumPy
            array); its result must expose ``input_values``.
        max_files: Highest shard number to try; ``None`` means 50 and ``0``
            loads nothing.
        sampling_rate: Sampling rate forwarded to ``processor``.

    Raises:
        ValueError: If an item does not have the structure above, or its
            label is not a tensor.
    """

    def __init__(self, dataset_path, data_type, processor, max_files=None, sampling_rate=16000):
        self.data = []
        self.labels = []
        self.sampling_rate = sampling_rate
        # Compare against None explicitly: a truthiness test would turn an
        # explicit max_files=0 into the 50-shard default.
        last_file = 50 if max_files is None else max_files
        for file_number in range(1, last_file + 1):
            file_path = os.path.join(dataset_path, f"data_{file_number}_{data_type}.pkl")
            if not os.path.exists(file_path):
                continue
            # NOTE(review): pickle.load can execute arbitrary code -- only load
            # shards that come from a trusted source.
            with open(file_path, 'rb') as file:
                file_data = pickle.load(file)
            for item in file_data:
                if (
                    not isinstance(item, (list, tuple))
                    or len(item) != 2
                    or not isinstance(item[1], tuple)
                    or len(item[1]) != 2
                ):
                    raise ValueError("Invalid data format")

                # The label tuple is (label_tensor, label_string); only the
                # tensor is used as the training target.
                waveform, (label_tensor, _label_string) = item
                if not isinstance(label_tensor, torch.Tensor):
                    raise ValueError("Label is not a tensor")

                input_values = processor(
                    waveform.squeeze().numpy(),
                    return_tensors="pt",
                    sampling_rate=self.sampling_rate,
                ).input_values
                # Stored on the CPU: device placement is left to the consumer
                # (the HF Trainer in this notebook moves each batch itself), so
                # the whole dataset does not have to sit in GPU memory.
                self.data.append(input_values.squeeze(0).cpu())
                self.labels.append(label_tensor.cpu())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"input_values": self.data[idx], "labels": self.labels[idx]}

# Initialize the Wav2Vec2 processor and a sequence-classification model for fine-tuning.
# NOTE(review): Wav2Vec2ForCTC does not fit this task: its output head is sized by
# the tokenizer vocabulary (so `num_labels` has no effect) and its CTC loss expects
# transcription token sequences. The labels here are single class ids (the AudioCNN
# below is trained on them with CrossEntropyLoss over 10 classes), so a
# classification head is the matching fine-tuning objective.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
model_wav2vec2 = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-large-960h",
    num_labels=10,  # keep in sync with AudioCNN's num_classes
)

# Load datasets
dataset_path = 'emi_dataset/'
train_dataset = Wav2Vec2Dataset(dataset_path, "train", processor)
valid_dataset = Wav2Vec2Dataset(dataset_path, "valid", processor)

# Define training arguments
training_args = TrainingArguments(
    output_dir='wav2vec2_finetuned_results',
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='wav2vec2_finetuned_logs',
    evaluation_strategy="epoch",
    save_strategy="epoch",
)

# Initialize Trainer. The feature extractor is passed so that the Trainer's
# default padding collator can batch variable-length waveforms.
trainer = Trainer(
    model=model_wav2vec2,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=processor.feature_extractor,
)

# Fine-tune Wav2Vec2
trainer.train()

# Save the fine-tuned model
model_wav2vec2.save_pretrained('fine_tuned_wav2vec2')

# Reload only the fine-tuned encoder for feature extraction; the
# classification head is not part of Wav2Vec2Model and is dropped.
model_wav2vec2 = Wav2Vec2Model.from_pretrained('fine_tuned_wav2vec2')

# Function to load data and extract features using the fine-tuned Wav2Vec2
def load_data(dataset_path, data_type, processor, model_wav2vec2, sampling_rate=16000, max_files=None):
    """Encode every waveform of one split into a time-averaged Wav2Vec2 feature vector.

    Reads shards ``data_{i}_{data_type}.pkl`` for ``i`` in ``1..max_files``
    (``1..50`` when ``max_files`` is None), skipping missing ones. Each shard
    item is unpacked as ``(waveform, label)`` and ``label[0]`` is the target.

    Args:
        dataset_path: Directory containing the shards.
        data_type: Split name embedded in the shard filename (e.g. "train").
        processor: Callable applied to each squeezed waveform (as a NumPy
            array); its result must expose ``input_values``.
        model_wav2vec2: Module whose output exposes ``last_hidden_state``.
            Inputs are sent to the device of its first parameter. It is run in
            eval mode, and its previous train/eval mode is restored afterwards.
        sampling_rate: Sampling rate forwarded to ``processor``.
        max_files: Highest shard number to try; ``None`` means 50.

    Returns:
        ``(features, labels)``. ``features`` stacks, on the CPU, one
        ``last_hidden_state`` mean over dim 1 per sample. ``labels`` is
        ``torch.tensor`` of the collected ``label[0]`` values.

    Raises:
        RuntimeError: If no samples were found for ``data_type``.
    """
    last_file = 50 if max_files is None else max_files
    try:
        model_device = next(model_wav2vec2.parameters()).device
    except StopIteration:  # parameter-less module: nothing to align with
        model_device = torch.device("cpu")

    # Extract features with train-time randomness (e.g. dropout) disabled,
    # then put the model back into whatever mode the caller had it in.
    was_training = model_wav2vec2.training
    model_wav2vec2.eval()
    data = []
    labels = []
    try:
        with torch.no_grad():
            for file_number in range(1, last_file + 1):
                file_path = os.path.join(dataset_path, f"data_{file_number}_{data_type}.pkl")
                if not os.path.exists(file_path):
                    continue
                # NOTE(review): pickle.load can execute arbitrary code -- only
                # load shards that come from a trusted source.
                with open(file_path, 'rb') as file:
                    file_data = pickle.load(file)
                for waveform, label in file_data:
                    input_values = processor(
                        waveform.squeeze().numpy(),
                        return_tensors="pt",
                        sampling_rate=sampling_rate,
                    ).input_values
                    hidden = model_wav2vec2(input_values.to(model_device)).last_hidden_state
                    # Average over dim 1, then [0] drops the batch axis of this
                    # single-sample batch; .cpu() so all features can be stacked.
                    data.append(hidden.mean(dim=1)[0].cpu())
                    labels.append(label[0])
    finally:
        model_wav2vec2.train(was_training)

    if not data:
        raise RuntimeError(f"No '{data_type}' samples found in {dataset_path}")
    return torch.stack(data), torch.tensor(labels)

# Extract pooled features with the fine-tuned encoder and wrap each split in a
# DataLoader. All splits use the same batch size; only training is shuffled.
FEATURE_BATCH_SIZE = 32

train_data, train_labels = load_data(dataset_path, "train", processor, model_wav2vec2)
train_loader = DataLoader(
    TensorDataset(train_data, train_labels),
    batch_size=FEATURE_BATCH_SIZE,
    shuffle=True,
)

validate_data, validate_labels = load_data(dataset_path, "valid", processor, model_wav2vec2)
validate_loader = DataLoader(
    TensorDataset(validate_data, validate_labels),
    batch_size=FEATURE_BATCH_SIZE,
)

test_data, test_labels = load_data(dataset_path, "test", processor, model_wav2vec2)
test_loader = DataLoader(
    TensorDataset(test_data, test_labels),
    batch_size=FEATURE_BATCH_SIZE,
)

# The AudioCNN classifier is initialized, trained, and evaluated in the next cell.


Some weights of the model checkpoint at facebook/wav2vec2-large-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You s

Epoch,Training Loss,Validation Loss
1,No log,60.492317
2,No log,50.131699
3,30.488200,43.997604
4,30.488200,46.836647
5,22.756700,43.698139
6,22.756700,42.935936


In [None]:
# Initialize the AudioCNN model; input_size is the width of the extracted features.
audio_cnn = AudioCNN(num_classes=10, input_size=train_data.size(1)).to(device)  # Adjust num_classes as needed

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(audio_cnn.parameters(), lr=0.0001)  # Adjust the learning rate here

# Training loop for AudioCNN
num_epochs = 30  # Adjust the number of epochs as needed
for epoch in range(num_epochs):
    audio_cnn.train()
    running_loss = 0.0
    for batch_idx, (features, labels) in enumerate(train_loader):
        features = features.to(device)
        # view(-1) rather than squeeze(): squeeze() would turn a final batch of
        # size 1 into a 0-d target, which CrossEntropyLoss rejects.
        labels = labels.to(device).long().view(-1)

        optimizer.zero_grad()
        outputs = audio_cnn(features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Step {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    # Print epoch loss (guarding against an empty loader)
    epoch_loss = running_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1}, Average Loss: {epoch_loss:.4f}")


def _accuracy(model, loader):
    """Return the percentage of samples in *loader* that *model* classifies correctly.

    The model is switched to eval mode. Returns NaN when *loader* yields no
    samples, instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for features, labels in loader:
            features = features.to(device)
            labels = labels.to(device).view(-1)  # 1-D even for a batch of size 1
            predicted = model(features).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else float("nan")


# Validation and test evaluation for AudioCNN
print(f'Validation Accuracy: {_accuracy(audio_cnn, validate_loader):.2f}%')
print(f'Test Accuracy: {_accuracy(audio_cnn, test_loader):.2f}%')
