In [3]:
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torch
import logging
import torch.nn as nn

import os, sys
sys.path.insert(0, os.path.join(os.pardir, os.pardir))

# from utils.data import save_pickle,load_pickle


import pickle


# Public names of the pickle helpers defined below.
__all__ = ["save_pickle", "load_pickle"]


# Helper: serialize an object to a pickle file.
def save_pickle(fname, obj):
    """
    Serialize an object and write it to disk in pickle format.
    :param fname: Path of the file to write. It is opened in binary write
        mode, so any existing content is replaced.
    :param obj: The object to serialize.
    """
    with open(fname, "wb") as sink:
        # Explicit Pickler with default settings, same as pickle.dump(obj, sink).
        pickle.Pickler(sink).dump(obj)


# Helper: read an object back from a pickle file.
def load_pickle(fname):
    """
    Read a pickled object back from disk.

    Only use this on files you trust: unpickling can execute arbitrary code.
    :param fname: Path of the pickle file to read (opened in binary read mode).
    :return obj: The object reconstructed from the file.
    """
    with open(fname, "rb") as source:
        # Explicit Unpickler with default settings, same as pickle.load(source).
        return pickle.Unpickler(source).load()

# Wrapper around a pretrained wav2vec 2.0 encoder with an MLP projection head.
class wav2vec():
    """Pretrained wav2vec 2.0 encoder followed by a small MLP projection head.

    The processor and encoder are loaded from ``model_name``. The head maps
    each encoder frame from the encoder's hidden size to 256 and then to 128
    features.

    :param model_name: Checkpoint id or path passed to ``from_pretrained``.
    :param device: Device the encoder and the MLP head are moved to.
    """

    def __init__(self, model_name="facebook/wav2vec2-base-960h", device="cpu"):
        self.model_name = model_name
        self.device = device
        self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
        self.model = Wav2Vec2Model.from_pretrained(self.model_name).to(self.device)
        self.logger = logging.getLogger(__name__)
        self.logger.info("wav2vec model loaded.")
        # Read the feature width from the checkpoint config (768 for the base
        # model) so that other checkpoint sizes work too.
        hidden_size = self.model.config.hidden_size
        # The head must live on the same device as the encoder output.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Linear(256, 128),
        ).to(self.device)

    def __call__(self, *args, **kwargs):
        """Allow ``obj(audio)`` as a shorthand for ``obj.forward(audio)``."""
        return self.forward(*args, **kwargs)

    def forward(self, audio, return_tensors="pt", padding="longest", sampling_rate=16000):
        """Encode raw audio and project every encoder frame with the MLP head.

        :param audio: Raw waveform(s) handed to the processor unchanged.
        :param return_tensors: Tensor format requested from the processor.
        :param padding: Padding strategy requested from the processor.
        :param sampling_rate: Sampling rate of ``audio`` in Hz, forwarded to
            the processor so it can check it against the rate it expects.
        :return: Tensor whose last dimension is 128; the leading dimensions
            are those of the encoder's ``last_hidden_state``.
        """
        input_values = self.processor(
            audio,
            sampling_rate=sampling_rate,
            return_tensors=return_tensors,
            padding=padding,
        ).input_values.to(self.device)
        representation = self.model(input_values)
        last_hidden_state = representation.last_hidden_state
        # nn.Linear acts on the last axis, which already holds the hidden
        # features here, so the hidden states are passed through un-permuted.
        return self.mlp(last_hidden_state)

    def save(self, path):
        """Write the encoder weights (``self.model.state_dict()``) to ``path``.

        NOTE(review): the MLP head in ``self.mlp`` is not part of the saved
        file; persist it separately if it has been trained.
        """
        torch.save(self.model.state_dict(), path)
        self.logger.info("Wav2Vec2 model saved.")

    def load(self, path):
        """Load encoder weights from ``path`` into ``self.model``.

        The file is mapped onto ``self.device``. The MLP head is left as is.
        """
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        self.logger.info("Wav2Vec2 model loaded.")

In [4]:
# Single checkpoint id, so the processor and the encoder are always loaded
# from the same pretrained weights.
MODEL_NAME = "facebook/wav2vec2-base-960h"

# The processor prepares raw audio for the model; Wav2Vec2Model is the bare
# encoder (the checkpoint's lm_head weights are not used, see the log below).
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = Wav2Vec2Model.from_pretrained(MODEL_NAME)

Downloading:   0%|          | 0.00/159 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/163 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/291 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2Model

# Settings for this cell: input clip, checkpoint id and the sampling rate the
# audio is resampled to before it is handed to the processor.
AUDIO_PATH = "../../preprocess_skyfall/audio_1-4seconds/segment_1_00-01-24,935_00-01-26,111.mp3"
MODEL_NAME = "facebook/wav2vec2-base-960h"
TARGET_SAMPLE_RATE = 16000

# Load the audio file (waveform is indexed channel-first, see the checks below)
waveform, sample_rate = torchaudio.load(AUDIO_PATH)

# Downmix to mono; works for stereo and for clips with more than two channels
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)

# Resample to the target rate if necessary
if sample_rate != TARGET_SAMPLE_RATE:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)
    waveform = resampler(waveform)

# Remove the (now size-1) channel dimension
waveform = waveform.squeeze(0)

# Load the processor and model
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = Wav2Vec2Model.from_pretrained(MODEL_NAME)

# Process the waveform; the feature extractor documents numpy / list input,
# so the tensor is converted explicitly.
inputs = processor(waveform.numpy(), sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True)

# Pass the input through the model (inference only, no gradients needed)
with torch.no_grad():
    outputs = model(**inputs)

# Extract the last hidden states
hidden_states = outputs.last_hidden_state

# Aggregate to a single representation by averaging over dim 1
audio_representation = torch.mean(hidden_states, dim=1)


Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2Model: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
# Show the size of the pooled embedding produced by averaging the hidden states over dim 1.
print(audio_representation.size())

torch.Size([1, 768])
