In [1]:
!pip install transformers



In [11]:
import torch.nn as nn
import torchaudio

In [51]:
audio_path = "/content/colorofsky.wav"
waveform, sample_rate = torchaudio.load(audio_path)

print(waveform)
print(sample_rate)

tensor([[ 0.0000,  0.0000,  0.0000,  ..., -0.0090, -0.0098, -0.0111],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0032, -0.0009,  0.0003]])
48000


In [52]:
if sample_rate != 16000:
  resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
  waveform = resample(waveform)

if waveform.shape[0] > 1:
  waveform = waveform.mean(dim=0, keepdim=True)

In [37]:
from transformers import Wav2Vec2Model

# Load the pre-trained model from Hugging Face
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")



In [38]:
model

Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Wav2Vec2Encoder(
    (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
  

In [53]:
encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")



In [50]:
projection = nn.Linear(encoder.config.hidden_size, 1024)

In [41]:
waveform.shape

torch.Size([1, 49536])

In [54]:
import torch

with torch.no_grad():
  waveform = encoder(waveform)

print(waveform)

Wav2Vec2BaseModelOutput(last_hidden_state=tensor([[[-0.0781,  0.1688,  0.0316,  ...,  0.3240,  0.4987, -0.3718],
         [-0.0874,  0.0913, -0.1814,  ...,  0.2342,  0.5450, -0.5056],
         [-0.0818,  0.0838, -0.1880,  ...,  0.2107,  0.5439, -0.5101],
         ...,
         [ 0.1371,  0.2922, -0.0030,  ...,  0.2240,  0.2370, -0.3827],
         [ 0.1242,  0.2938, -0.0057,  ...,  0.3437,  0.2205, -0.4508],
         [ 0.1906,  0.3058, -0.1881,  ...,  0.2661,  0.1275, -0.3255]]]), extract_features=tensor([[[ 0.3463, -0.0382, -0.0162,  ..., -0.4728, -0.1189, -0.3435],
         [ 0.3463, -0.0382, -0.0162,  ..., -0.4728, -0.1189, -0.3435],
         [ 0.3463, -0.0382, -0.0162,  ..., -0.4728, -0.1189, -0.3435],
         ...,
         [ 0.5760, -0.3437,  0.6435,  ...,  0.0543, -0.2448,  0.3804],
         [ 0.4357, -0.1091,  0.5273,  ...,  0.0175,  0.0106,  0.2667],
         [ 0.7384, -0.2758,  0.4743,  ..., -0.2888, -0.2321,  0.1672]]]), hidden_states=None, attentions=None)


In [55]:
waveform = waveform.last_hidden_state

print(waveform)
print(waveform.shape)

tensor([[[-0.0781,  0.1688,  0.0316,  ...,  0.3240,  0.4987, -0.3718],
         [-0.0874,  0.0913, -0.1814,  ...,  0.2342,  0.5450, -0.5056],
         [-0.0818,  0.0838, -0.1880,  ...,  0.2107,  0.5439, -0.5101],
         ...,
         [ 0.1371,  0.2922, -0.0030,  ...,  0.2240,  0.2370, -0.3827],
         [ 0.1242,  0.2938, -0.0057,  ...,  0.3437,  0.2205, -0.4508],
         [ 0.1906,  0.3058, -0.1881,  ...,  0.2661,  0.1275, -0.3255]]])
torch.Size([1, 154, 768])


In [56]:
waveform

tensor([[[-0.0781,  0.1688,  0.0316,  ...,  0.3240,  0.4987, -0.3718],
         [-0.0874,  0.0913, -0.1814,  ...,  0.2342,  0.5450, -0.5056],
         [-0.0818,  0.0838, -0.1880,  ...,  0.2107,  0.5439, -0.5101],
         ...,
         [ 0.1371,  0.2922, -0.0030,  ...,  0.2240,  0.2370, -0.3827],
         [ 0.1242,  0.2938, -0.0057,  ...,  0.3437,  0.2205, -0.4508],
         [ 0.1906,  0.3058, -0.1881,  ...,  0.2661,  0.1275, -0.3255]]])

In [57]:
waveform = projection(waveform)

In [58]:
print(waveform)
print(waveform.shape)

tensor([[[ 0.0921, -0.1642,  0.1744,  ...,  0.0982, -0.1150,  0.1578],
         [ 0.1408, -0.0401,  0.1541,  ...,  0.1644, -0.1875,  0.1902],
         [ 0.1378, -0.0298,  0.1452,  ...,  0.1770, -0.2041,  0.1955],
         ...,
         [ 0.0021, -0.1432,  0.2494,  ..., -0.0727, -0.1309, -0.5114],
         [-0.0091, -0.0453,  0.2223,  ..., -0.0582, -0.0601, -0.4734],
         [ 0.0040, -0.1545,  0.1704,  ..., -0.1562,  0.0898, -0.5249]]],
       grad_fn=<ViewBackward0>)
torch.Size([1, 154, 1024])
