In [55]:
import soundfile as sf
import os
import numpy as np
import torchaudio
from IPython.display import Audio
import io

import torch
import torch.nn as nn


from config import Wav2Vec2Config
from model import Wav2Vec2ForPreTraining,Wav2Vec2FeatureEncoder,Wav2Vec2GumbelVectorQuantizer,_compute_mask_indices,Wav2Vec2Encoder,Wav2Vec2FeatureProjection


def resample_audio_torchaudio(file_path, original_sample_rate=44100, target_sample_rate=16000):
    """Load an audio file, resample it, and return it as a mono waveform.

    Args:
        file_path: Path handed to ``torchaudio.load``.
        original_sample_rate: Sample rate (Hz) the file is required to have.
        target_sample_rate: Sample rate (Hz) the audio is resampled to.

    Returns:
        Tuple ``(waveform, target_sample_rate)``. ``waveform`` has the leading
        channel axis removed, so it is 1-D for torchaudio's default
        (channels, time) layout.

    Raises:
        ValueError: If the file's sample rate differs from
            ``original_sample_rate``.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    if sample_rate != original_sample_rate:
        raise ValueError(f"Expected sample rate to be {original_sample_rate}, but got {sample_rate}")

    resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sample_rate)
    waveform = resampler(waveform)

    # Down-mix multi-channel audio by averaging over the channel axis.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Drop only the leading channel axis. A bare squeeze() would also collapse a
    # length-1 time axis and return a 0-d scalar instead of a 1-D waveform.
    return waveform.squeeze(0), target_sample_rate

def load_audio(path):
    """Read an audio file and average it over dim 0.

    Args:
        path: Path handed to ``torchaudio.load``.

    Returns:
        Tuple ``(mono, sample_rate)`` where ``mono`` is the loaded waveform
        averaged over its first axis (the channel axis in torchaudio's default
        channels-first layout) and ``sample_rate`` is the rate reported by
        ``torchaudio.load``. No resampling is performed.
    """
    samples, rate = torchaudio.load(path)
    mono = samples.mean(dim=0)
    return mono, rate

def load_dataset(file_list):
    """Load every ``.mp3`` path in ``file_list`` and stack the clips into one tensor.

    Paths that do not end in ``.mp3`` are skipped. Each remaining file is
    loaded with ``resample_audio_torchaudio`` using its default sample rates,
    and only the waveform (not the returned sample rate) is kept.

    Args:
        file_list: Iterable of file path strings.

    Returns:
        The result of ``torch.stack`` over the loaded waveforms, i.e. the clips
        stacked along a new leading axis. ``torch.stack`` requires all clips to
        have the same shape and at least one clip to be present.
    """
    clips = [
        resample_audio_torchaudio(path)[0]
        for path in file_list
        if path.endswith('.mp3')
    ]
    return torch.stack(clips)


# Load clips 1, 2 and 3 of the Gould WTC recording (paths are relative to the
# notebook's working directory) and stack them into a single tensor.
dataset = load_dataset([f'data/mp3_train_files/Gould/Gould - WTC_clip_{i}.mp3' for i in range(1,4)])
# Shown as the cell output; the recorded run gives torch.Size([3, 80000]).
dataset.shape

torch.Size([3, 80000])

X -> Z: encode the raw waveforms X into latent feature representations Z

In [51]:
# Build a configuration with default settings; later cells reuse `config`.
config = Wav2Vec2Config()


# Feature encoder that turns the raw waveform batch into latent representations.
feature_encoder = Wav2Vec2FeatureEncoder(config)
# Uncomment to inspect the encoder's layers:
# print(feature_encoder)


# `dataset` is the stacked waveform tensor built in the loading cell.
latent_reps= feature_encoder(dataset)
print(f"{latent_reps.shape=}") # batch, num_channels, conv_output_len


latent_reps.shape=torch.Size([3, 512, 249])


Project Z to the encoder's hidden dimension (512 -> 768 in the run shown below)

In [49]:
# Projection module applied to the latent representations; its printed repr
# (below) shows a LayerNorm, a Linear layer and a Dropout.
feature_projection = Wav2Vec2FeatureProjection(config)
print(feature_projection)

# Swap axes 1 and 2 so the feature axis comes last before projecting.
# The call returns two tensors: the projected `hidden_states` and
# `extract_features`.
hidden_states, extract_features = feature_projection(latent_reps.transpose(1,2))

print(f"{hidden_states.shape=}")


# TODO: then mask here (no masking is applied to `hidden_states` in this cell)

Wav2Vec2FeatureProjection(
  (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (projection): Linear(in_features=512, out_features=768, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)
hidden_states.shape=torch.Size([3, 249, 768])


In [52]:
# Run the projected features through a freshly initialised Wav2Vec2Encoder.
encoder = Wav2Vec2Encoder(config)

encoder_outputs = encoder(hidden_states)

# NOTE(review): this rebinds `hidden_states` to the encoder's first output, so
# re-running the cell feeds the encoder its own previous output.
hidden_states = encoder_outputs[0]

In [54]:
# Display the encoder output and the extracted features as a tuple.
# NOTE(review): this prints both full tensors; showing `.shape` would be shorter.
hidden_states,extract_features

tensor([[[-1.9127, -0.6875, -1.4736,  ..., -2.1643,  0.8411, -0.0380],
         [-1.4061, -1.0729, -1.2450,  ..., -1.4550, -0.0604,  0.0437],
         [-1.2815, -0.5890, -1.6554,  ..., -1.7457,  0.7717, -0.3420],
         ...,
         [-1.1253,  0.2196, -1.4514,  ..., -1.4712,  1.0693, -0.1597],
         [-0.9389,  0.5538, -0.6476,  ..., -1.0568,  0.7462, -1.1023],
         [-1.1520,  0.5902, -0.0117,  ..., -0.3745,  1.1961, -0.8357]],

        [[ 0.1138, -0.5742, -1.3258,  ..., -1.4461,  0.8528, -0.3117],
         [-0.4957, -1.0862, -0.8850,  ..., -2.2565,  0.1077,  0.7734],
         [-0.7777, -0.1501,  0.3099,  ..., -1.7332,  0.9596, -0.0120],
         ...,
         [-0.2798, -0.0862, -1.3628,  ..., -1.7271,  0.7087, -0.3084],
         [ 0.5742,  0.3760,  0.4267,  ..., -1.7190,  1.3869, -0.7468],
         [-2.1264,  1.0481,  0.1977,  ..., -0.9605,  0.5391,  0.3180]],

        [[-0.9347,  0.1893, -0.9451,  ..., -0.4437, -0.2000, -0.5965],
         [-1.6724, -0.2969, -1.0850,  ..., -0

In [59]:
# Linear layer mapping config.hidden_size -> config.proj_codevector_dim.
project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
# Dropout whose probability is config.feat_quantizer_dropout.
dropout_features = nn.Dropout(config.feat_quantizer_dropout)

# Project the encoder output held in `hidden_states`.
transformer_features = project_hid(hidden_states)

# NOTE(review): `extract_features` is rebound in place, so re-running this
# cell applies the dropout to the already-dropped-out tensor.
extract_features = dropout_features(extract_features)

In [66]:
# Check the shape of the extracted features; the recorded run gives
# torch.Size([3, 249, 512]).
extract_features.shape

torch.Size([3, 249, 512])

In [72]:
# Quantizer module and a linear layer mapping
# config.codevector_dim -> config.proj_codevector_dim.
quantizer = Wav2Vec2GumbelVectorQuantizer(config)
project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)

# Compute mask indices over the first two axes of `extract_features` and wrap
# the result in a tensor. mask_prob=0.2 and mask_length=2 are hard-coded here.
mask_time_indices = torch.tensor(_compute_mask_indices(shape=(extract_features.shape[0], extract_features.shape[1]), mask_prob=0.2, mask_length=2))

# The quantizer returns the quantized features and a codevector perplexity value.
quantized_features, codevector_perplexity = quantizer(extract_features,mask_time_indices=mask_time_indices) 

# Rebind `quantized_features` to its projection.
quantized_features = project_q(quantized_features)

# The recorded run prints torch.Size([3, 249, 256]).
print(quantized_features.shape)

torch.Size([3, 249, 256])


In [73]:
# Rebuild a default configuration (same call as in the feature-encoder cell)
# and instantiate the full pre-training model from it.
config = Wav2Vec2Config()
model = Wav2Vec2ForPreTraining(config)

In [82]:
# Take the first clip and add a leading axis of size 1 so it forms a batch of one.
dataset[0].unsqueeze(0)

tensor([[ 0.0000,  0.0000,  0.0000,  ..., -0.0183, -0.0127, -0.0075]])

In [60]:
# Placeholder values, presumably mirroring the arguments of the model's forward
# pass. NOTE(review): none of these names are passed to `model(...)` in the
# visible cells, and `mask_time_indices` shares its name with the mask tensor
# built in the quantizer cell, so running this cell rebinds that name.
input_values = []  # Float values of input raw speech waveform.
attention_mask = []  # bool tensor (batch_size, seq_len)
mask_time_indices = []  # bool tensor (batch_size, seq_len)
sampled_negative_indices = []  # bool tensor (batch_size, sequence_length, num_negatives)
# Plain Python booleans for the flags. torch.BoolTensor(n) allocates an
# uninitialised tensor of length n, so BoolTensor(0) was an empty tensor and
# BoolTensor(1) held an arbitrary value rather than False / True.
output_attentions = False
output_hidden_states = False
return_dict = True


In [83]:
# Forward pass on the first clip as a batch of one. Only the waveform is
# passed, and the recorded output shows loss=None.
model(dataset[0].unsqueeze(0))

Wav2Vec2ForPreTrainingOutput(loss=None, projected_states=tensor([[[ 0.3807, -0.4252,  0.3157,  ..., -0.1948,  0.0573,  1.0749],
         [-0.6598, -0.2952, -0.6157,  ..., -0.2388,  0.6263,  0.7527],
         [-0.4502, -0.5760,  0.4205,  ..., -0.6499,  0.4563,  0.8842],
         ...,
         [-0.4406,  0.0493, -0.0269,  ...,  0.1158,  0.1032,  0.3635],
         [ 0.2953, -0.0126,  0.1722,  ..., -0.0835,  0.6201,  0.1347],
         [-0.5455, -0.4256, -0.0034,  ..., -0.0527, -0.3368,  0.4376]]],
       grad_fn=<ViewBackward0>), projected_quantized_states=tensor([[[-0.0051,  0.0159,  0.0079,  ...,  0.0671,  0.0092,  0.0474],
         [-0.0108, -0.0311,  0.0222,  ...,  0.0232, -0.0338, -0.0647],
         [ 0.0004,  0.0127,  0.0092,  ..., -0.0064, -0.0087,  0.0164],
         ...,
         [ 0.0073, -0.0303, -0.0054,  ..., -0.0123,  0.0036, -0.0212],
         [-0.0235,  0.0589,  0.0228,  ...,  0.0305, -0.0358, -0.0126],
         [ 0.0592, -0.0347,  0.0043,  ...,  0.0401,  0.0288, -0.0091]]],