In [1]:
import os
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from torch.utils.data import Dataset

class AudioDataset(Dataset):
    """Dataset that turns audio file paths into mean-pooled encoder embeddings.

    Each item is computed lazily in ``__getitem__``: the file is loaded with
    ``torchaudio.load``, downmixed to a single channel, resampled to
    ``TARGET_SAMPLE_RATE``, passed through ``processor`` and ``model``, and the
    model's ``last_hidden_state`` is averaged over dimension 1.

    Args:
        file_list: Sequence of audio file paths, indexed by item position.
        processor: Callable that converts a waveform into the keyword inputs
            of ``model`` (a ``Wav2Vec2Processor`` in this notebook).
        model: Callable encoder whose output exposes ``last_hidden_state``
            (a ``Wav2Vec2Model`` in this notebook).
        device: Device the processed inputs are moved to before the forward
            pass; it should be the device ``model`` lives on.
    """

    # Sampling rate every clip is converted to before it reaches the processor.
    TARGET_SAMPLE_RATE = 16000

    def __init__(self, file_list, processor, model, device):
        self.file_list = file_list
        self.processor = processor
        self.model = model
        self.device = device

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load file ``idx`` and return its mean-pooled hidden representation.

        Returns:
            ``torch.mean(last_hidden_state, dim=1)`` for the clip, left on the
            device the model produced it on. For a single clip this is
            presumably shaped (1, hidden_size) -- confirm against the model.
        """
        file_path = self.file_list[idx]
        # torchaudio.load returns (waveform, sample_rate); by default the
        # waveform is laid out as (channels, frames).
        waveform, sample_rate = torchaudio.load(file_path)

        # Downmix ANY multi-channel layout to mono. Checking only for exactly
        # two channels would let e.g. 5.1 audio through as a 2-D tensor, which
        # the processor would then treat as several separate clips.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to the target rate if necessary.
        if sample_rate != self.TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.TARGET_SAMPLE_RATE
            )
            waveform = resampler(waveform)

        # Drop the channel dimension so the processor receives a 1-D signal.
        waveform = waveform.squeeze(0)

        # Use the processor/model handed to the constructor rather than
        # whatever module-level globals happen to carry the same names.
        inputs = self.processor(
            waveform,
            sampling_rate=self.TARGET_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        # The processor builds CPU tensors; move them next to the model so the
        # forward pass does not fail with a device mismatch on CUDA.
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        # Pure feature extraction: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Collapse dimension 1 of the hidden states into a single vector per clip.
        return torch.mean(outputs.last_hidden_state, dim=1)

# Build the Wav2Vec2 processor and encoder once at module level so that every
# dataset object created later can share them instead of reloading the weights.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
model = model.to(device)
# Put the encoder in evaluation mode. As the last expression of the cell, the
# returned module is also echoed as the cell output.
model.eval()


Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Wav2Vec2Encoder(
    (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
  

In [3]:
import random

def split_data(audio_dir, train_ratio=0.8, valid_ratio=0.1, seed=None):
    """Shuffle the ``.mp3`` files in ``audio_dir`` and split them three ways.

    Args:
        audio_dir: Directory whose immediate entries are scanned; only names
            ending in ``.mp3`` are kept.
        train_ratio: Fraction of the files assigned to the training split.
        valid_ratio: Fraction of the files assigned to the validation split.
        seed: Optional seed for a private RNG used for the shuffle. When
            ``None`` (the default) the module-level ``random`` generator is
            used, so a prior ``random.seed(...)`` call still controls the split.

    Returns:
        ``(train_files, valid_files, test_files)`` as lists of paths joined
        onto ``audio_dir``. The test split receives every file left over after
        the (floored) train and validation counts.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more than 1.
    """
    # A small tolerance keeps float sums such as 0.7 + 0.3 from being rejected.
    if train_ratio < 0 or valid_ratio < 0 or train_ratio + valid_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and valid_ratio must be non-negative and sum to at most 1, "
            f"got train_ratio={train_ratio}, valid_ratio={valid_ratio}"
        )

    # os.listdir returns entries in arbitrary order; sorting first makes the
    # shuffle result depend only on the RNG state, not on the filesystem.
    files = sorted(
        os.path.join(audio_dir, f) for f in os.listdir(audio_dir) if f.endswith('.mp3')
    )
    if seed is None:
        random.shuffle(files)
    else:
        random.Random(seed).shuffle(files)

    total_files = len(files)
    train_count = int(total_files * train_ratio)
    valid_count = int(total_files * valid_ratio)

    train_files = files[:train_count]
    valid_files = files[train_count:train_count + valid_count]
    test_files = files[train_count + valid_count:]

    return train_files, valid_files, test_files

audio_dir = '../preprocess_skyfall/audio_1-4seconds'

# Seed the global RNG that the shuffle inside split_data draws from, so the
# train/valid/test membership does not change on every Restart & Run All.
random.seed(42)
train_files, valid_files, test_files = split_data(audio_dir)


In [4]:
from torch.utils.data import DataLoader

# One dataset per split, all sharing the same processor, model and device.
train_dataset, valid_dataset, test_dataset = (
    AudioDataset(split_files, processor, model, device)
    for split_files in (train_files, valid_files, test_files)
)

# Only the training loader reshuffles; validation and test keep file order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader, test_loader = (
    DataLoader(split_dataset, batch_size=32, shuffle=False)
    for split_dataset in (valid_dataset, test_dataset)
)
