In [10]:
import torch
from transformers import Wav2Vec2Model
class Wav2Vec2Base(torch.nn.Module):
    """Pretrained Wav2Vec2 backbone followed by dropout and a linear classifier.

    The backbone's ``last_hidden_state`` is passed through dropout and then
    projected by a linear layer to ``vocab_size`` outputs per frame.

    Args:
        vocab_size: Output width of the linear classifier.
        attention_dropout: Forwarded to ``Wav2Vec2Model.from_pretrained``.
        hidden_dropout: Forwarded to ``Wav2Vec2Model.from_pretrained``.
        feat_proj_dropout: Forwarded to ``Wav2Vec2Model.from_pretrained``.
        mask_time_prob: Forwarded to ``Wav2Vec2Model.from_pretrained``.
        layerdrop: Forwarded to ``Wav2Vec2Model.from_pretrained``.
        classifier_dropout: Drop probability of the dropout applied right
            before the classifier.
        pretrained: Checkpoint identifier handed to
            ``Wav2Vec2Model.from_pretrained``. ``None`` is rejected.

    Raises:
        ValueError: If ``pretrained`` is ``None``.
    """

    def __init__(self, vocab_size,attention_dropout=0.1, hidden_dropout=0.1, feat_proj_dropout = 0.1,
                    mask_time_prob=0.075,layerdrop=0.1,classifier_dropout=0.1,pretrained="facebook/wav2vec2-xls-r-300m"):
        super().__init__()
        # Guard clause: building the backbone from scratch is not implemented.
        if pretrained is None:
            raise ValueError("non preteained model is not supported yet")

        self.model = Wav2Vec2Model.from_pretrained(
            pretrained,
            attention_dropout=attention_dropout,
            hidden_dropout=hidden_dropout,
            feat_proj_dropout=feat_proj_dropout,
            mask_time_prob=mask_time_prob,
            layerdrop=layerdrop,
        )
        self.dropout = torch.nn.Dropout(p=classifier_dropout)
        # Size the classifier from the loaded backbone instead of hard-coding
        # 1024, so checkpoints with a different hidden width work as well.
        self.classifier = torch.nn.Linear(self.model.config.hidden_size, vocab_size)

    def forward(self, inp):
        """Run the backbone, then dropout, then the classifier.

        Args:
            inp: Passed unchanged to the backbone.
                Presumably a (batch, samples) waveform tensor - TODO confirm.

        Returns:
            The classifier output; its last dimension has size ``vocab_size``.
        """
        hidden = self.model(inp).last_hidden_state
        return self.classifier(self.dropout(hidden))

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
# Build the wrapper with an 80-way classifier on top of the default checkpoint.
model = Wav2Vec2Base(vocab_size=80)



In [16]:
# Display the convolutional feature encoder of the wrapped backbone.
model.model.feature_extractor

Wav2Vec2FeatureEncoder(
  (conv_layers): ModuleList(
    (0): Wav2Vec2LayerNormConvLayer(
      (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (activation): GELUActivation()
    )
    (1): Wav2Vec2LayerNormConvLayer(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (activation): GELUActivation()
    )
    (2): Wav2Vec2LayerNormConvLayer(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (activation): GELUActivation()
    )
    (3): Wav2Vec2LayerNormConvLayer(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (activation): GELUActivation()
    )
    (4): Wav2Vec2LayerNormConvLayer(
      (conv): Conv1d(512, 512, kernel_size=(3,)

In [80]:
# Smoke-test input: a batch of one random signal, 60000 samples long.
# A locally seeded generator keeps the cell reproducible without touching
# the global RNG state.
inp = torch.randn(1, 60000, generator=torch.Generator().manual_seed(0))

# A freshly constructed nn.Module is in training mode, so the wrapper's
# classifier dropout would still be active here. Switch to eval mode for
# inference; call model.train() again before any training cell.
model.eval()
with torch.no_grad():
    result = model(inp)

In [83]:
# forward() returns the classifier output tensor itself, so it has no
# `.last_hidden_state` attribute; inspect the tensor's shape directly.
result.shape

torch.Size([1, 187, 80])