In [1]:
import soundfile as sf
import os
import numpy as np
import torchaudio
from IPython.display import Audio
import io

import torch
import torch.nn as nn


from config import Wav2Vec2Config
from model import Wav2Vec2ForPreTraining,Wav2Vec2FeatureEncoder,Wav2Vec2GumbelVectorQuantizer,_compute_mask_indices,Wav2Vec2Encoder,Wav2Vec2FeatureProjection


def resample_audio_torchaudio(file_path, original_sample_rate=44100, target_sample_rate=16000):
    """Load an audio file, downmix it to mono and resample it.

    Args:
        file_path: Path handed to ``torchaudio.load``.
        original_sample_rate: Sample rate the file is required to have.
            Pass ``None`` to accept whatever rate the file reports.
        target_sample_rate: Sample rate of the returned waveform.

    Returns:
        ``(waveform, target_sample_rate)`` where ``waveform`` is the mono,
        resampled signal with the channel axis removed.

    Raises:
        ValueError: If ``original_sample_rate`` is given and the file's
            sample rate differs from it.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    if original_sample_rate is not None and sample_rate != original_sample_rate:
        raise ValueError(f"Expected sample rate to be {original_sample_rate}, but got {sample_rate}")

    # Downmix before resampling: averaging channels and resampling are both
    # linear, so the result is the same up to rounding, but only one channel
    # has to go through the resampler instead of all of them.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # After the check above `sample_rate` equals `original_sample_rate`
    # whenever one was requested, so it is always the right source rate.
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
    waveform = resampler(waveform)

    # Drop only the leading channel axis. A bare squeeze() would also collapse
    # a one-sample clip to a 0-d tensor, which breaks stacking clips later.
    return waveform.squeeze(0), target_sample_rate

def load_audio(path):
    """Read an audio file and average its channels into a single signal.

    Returns:
        ``(mono, sample_rate)``: the mean over the first axis of the tensor
        returned by ``torchaudio.load``, and the rate reported for the file.
    """
    channels, rate = torchaudio.load(path)
    mono = channels.mean(0)
    return mono, rate

def load_dataset(file_list):
    """Load every ``.mp3`` path in ``file_list`` and stack the clips into one tensor.

    Paths that do not end in ``.mp3`` are skipped. Each remaining path is passed
    to ``resample_audio_torchaudio``; only the waveform it returns is kept, the
    sample rate is discarded.

    Returns:
        The result of ``torch.stack`` over the loaded waveforms, in input order.
    """
    clips = [
        resample_audio_torchaudio(mp3_path)[0]
        for mp3_path in file_list
        if mp3_path.endswith('.mp3')
    ]
    return torch.stack(clips)


# Paths of the training clips, numbered 1 through 500.
clip_paths = [
    f'data/mp3_train_files/Gould/Gould - WTC_clip_{i}.mp3'
    for i in range(1, 501)
]
dataset = load_dataset(clip_paths)
dataset.shape

torch.Size([500, 80000])

X -> Z: encode the raw waveform batch X into latent features Z with the feature encoder

In [4]:
# Build the feature encoder from a default config and run it on the raw clips.
config = Wav2Vec2Config()


feature_encoder = Wav2Vec2FeatureEncoder(config)
# print(feature_encoder)


# NOTE(review): the saved output of this cell is a KeyboardInterrupt, i.e. the
# forward pass over the full `dataset` did not finish in the recorded run.
# Consider feeding a small slice while exploring.
latent_reps= feature_encoder(dataset)
print(f"{latent_reps.shape=}") # batch, num_channels, conv_output_len


KeyboardInterrupt: 

Project Z to the model's hidden dimension (512 -> 768 in the output below)

In [3]:
# Project the encoder features to the width expected by the transformer encoder.
feature_projection = Wav2Vec2FeatureProjection(config)
print(feature_projection)

# Swap axes 1 and 2 so the channel axis comes last before projecting.
# The projection returns two tensors: the projected `hidden_states` and
# `extract_features` (presumably the un-projected features; see their
# differing last dimensions in the recorded outputs).
hidden_states, extract_features = feature_projection(latent_reps.transpose(1,2))

print(f"{hidden_states.shape=}")


# TODO: time-step masking would be applied to `hidden_states` at this point.

Wav2Vec2FeatureProjection(
  (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (projection): Linear(in_features=512, out_features=768, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)
hidden_states.shape=torch.Size([3, 249, 768])


In [4]:
# Run the projected features through the transformer encoder.
encoder = Wav2Vec2Encoder(config)

encoder_outputs = encoder(hidden_states)

# Keep only the first element of the encoder output.
# NOTE(review): this rebinds `hidden_states`, so re-running the cell feeds the
# encoder its own previous output instead of the projected features.
hidden_states = encoder_outputs[0]

In [5]:
#hidden_states,extract_features

In [6]:
# Linear head mapping encoder outputs to the code-vector projection size, and
# the dropout applied to the features that will be quantized.
project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
dropout_features = nn.Dropout(config.feat_quantizer_dropout)

transformer_features = project_hid(hidden_states)

# NOTE(review): `extract_features` is overwritten with its own dropped-out
# version, so running this cell twice applies dropout twice.
extract_features = dropout_features(extract_features)

In [7]:
# Inspect the shape of the features that will be fed to the quantizer.
extract_features.shape

torch.Size([3, 249, 512])

In [8]:
# Quantize the extracted features and project the code vectors to the same
# size as `transformer_features` (config.proj_codevector_dim).
quantizer = Wav2Vec2GumbelVectorQuantizer(config)
project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)

# One mask entry per (batch item, time step). mask_prob=0.2 and mask_length=2
# are hard-coded here rather than read from `config`.
# NOTE(review): no dtype is passed to torch.tensor, so the mask keeps whatever
# dtype _compute_mask_indices returns (presumably bool) - confirm.
mask_time_indices = torch.tensor(_compute_mask_indices(shape=(extract_features.shape[0], extract_features.shape[1]), mask_prob=0.2, mask_length=2))

# The quantizer returns the quantized features together with a perplexity value.
quantized_features, codevector_perplexity = quantizer(extract_features,mask_time_indices=mask_time_indices) 

quantized_features = project_q(quantized_features)

print(quantized_features.shape)

torch.Size([3, 249, 256])


In [2]:
# Build the full pre-training model from a fresh default config.
config = Wav2Vec2Config()
model = Wav2Vec2ForPreTraining(config)

# NOTE(review): this rebinds `_compute_mask_indices`, which the first cell
# already imported from the local `model` module; from here on the name refers
# to the transformers implementation. Imports would normally live in the first cell.
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices


In [3]:
# Run only the model's feature extractor to learn the batch size and the
# number of time steps the masks must cover.
extract_features = model.wav2vec2.feature_extractor(dataset).transpose(1, 2)
batch_size, seq_len, _ = extract_features.shape
 
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long) # no padding tokens


# Choose which time steps to mask, using the masking settings from `config`.
mask_time_indices = _compute_mask_indices(
        shape=(batch_size, seq_len),
        mask_prob=config.mask_time_prob,
        mask_length=config.mask_time_length,
        attention_mask=attention_mask,
        min_masks=config.mask_time_min_masks
    )

# Sample negative indices for the mask chosen above.
# NOTE(review): this reads `model.config` while the masking call reads
# `config`; `model` was built from `config`, so presumably they agree - confirm.
sampled_negative_indices = _sample_negative_indices(
    features_shape=(batch_size, seq_len),
    num_negatives=model.config.num_negatives,
    mask_time_indices=mask_time_indices,
)

# Convert both helper results to tensors for the model call; the mask is
# explicitly cast to bool, the negative indices keep their original dtype.
mask_time_indices = torch.tensor(mask_time_indices,dtype=torch.bool)

sampled_negative_indices = torch.tensor(sampled_negative_indices)



: 

In [43]:
# NOTE(review): these two lines repeat the start of the masking cell verbatim,
# so the feature extractor runs over the whole `dataset` a second time just to
# re-derive batch_size and seq_len.
extract_features = model.wav2vec2.feature_extractor(dataset).transpose(1, 2)
batch_size, seq_len, _ = extract_features.shape

In [44]:
# Arguments for the pre-training forward pass in the next cell.
#
# Reused unchanged from the masking cell (the original no-op self-assignments
# such as `attention_mask = attention_mask` are dropped):
#   attention_mask           - torch.long tensor of ones, (batch_size, seq_len)
#   mask_time_indices        - bool tensor, (batch_size, seq_len)
#   sampled_negative_indices - index tensor built from _sample_negative_indices
#                              (expected (batch_size, seq_len, num_negatives) - confirm)
input_values = dataset  # float values of the raw input waveforms
output_attentions = True
output_hidden_states = False
# Must be a real boolean. The previous `torch.BoolTensor(1)` allocates an
# uninitialized one-element tensor, so its truth value was arbitrary and the
# later `out.loss` attribute access could fail from run to run.
return_dict = True


In [84]:
# Collect the forward-pass arguments once, then unpack them into the model call.
forward_kwargs = {
    "input_values": input_values,
    "attention_mask": attention_mask,
    "mask_time_indices": mask_time_indices,
    "sampled_negative_indices": sampled_negative_indices,
    "output_attentions": output_attentions,
    "output_hidden_states": output_hidden_states,
    "return_dict": return_dict,
}
out = model(**forward_kwargs)

# Keep the loss under its own name for display in the next cell.
loss = out.loss

Pre model


In [85]:
# Display the loss returned by the forward pass above.
loss

tensor(259.9305, grad_fn=<AddBackward0>)