In [1]:
from os.path import join, dirname, split
import inspect

from transformers import AutoProcessor, AutoModelForCTC
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model, Wav2Vec2ForCTC
from transformers import Wav2Vec2Model, Wav2Vec2Config
import torch.nn as nn
import torchaudio

ROOT_DIR = split(dirname("/workspace/NN/notebooks/test.ipynb"))[0]
print(ROOT_DIR)

/workspace/NN


In [2]:
configuration = Wav2Vec2Config()

processor = AutoProcessor.from_pretrained(
            "Eyvaz/wav2vec2-base-russian-demo-kaggle",
            cache_dir=join(ROOT_DIR, "weights", "loaded_weights", ),
         )
model : Wav2Vec2ForCTC = AutoModelForCTC.from_pretrained(
    "Eyvaz/wav2vec2-base-russian-demo-kaggle",
    cache_dir=join(ROOT_DIR, "weights", "loaded_weights", ),
)


bundle = torchaudio.pipelines.WAV2VEC2_BASE
old_feature_extractor = bundle.get_model(
            dl_kwargs={
                "file_name": join(ROOT_DIR, "weights", "loaded_weights", "wav2vec2_fairseq_base_ls960.pth")
            }
        )


Some weights of the model checkpoint at Eyvaz/wav2vec2-base-russian-demo-kaggle were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v', 'wav2vec2.masked_spec_embed']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Eyvaz/wav2vec2-base-russian-demo-kaggle and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.we

In [3]:
model.wav2vec2

Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): Wav2Vec2Encoder(
    (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
  

In [19]:
feature_extractor = nn.Sequential(
    model.wav2vec2.feature_extractor,
    model.wav2vec2.feature_projection,
)

for param in feature_extractor.parameters():
    param.requires_grad = False

feature_extractor

Sequential(
  (0): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (1): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
)

In [20]:
old_feature_extractor

Wav2Vec2Model(
  (feature_extractor): FeatureExtractor(
    (conv_layers): ModuleList(
      (0): ConvLayerBlock(
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
      )
      (1-4): 4 x ConvLayerBlock(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
      )
      (5-6): 2 x ConvLayerBlock(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
      )
    )
  )
  (encoder): Encoder(
    (feature_projection): FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (pos_conv_embed): ConvolutionalPositionalEmbedding(
        (conv): ParametrizedConv1d(
          768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16
          (parametriza

In [10]:
from torch.nn.parameter import Parameter
from collections import OrderedDict
import torch.nn as nn


layers = nn.Sequential(
    OrderedDict([
            ("1d_conv", nn.Conv2d(99, 16, 1)),
            ("bath_after_1d_conv", nn.BatchNorm2d(16)),
            ("relu_after_1d_conv", nn.ReLU(inplace=True)),
        ]))
param_groups = [{'params': list((list(layers[0].parameters())[0])), 'lr': 1e-5}]

param_groups

[{'params': [tensor([[[ 0.0846]],
   
           [[-0.0029]],
   
           [[-0.0239]],
   
           [[-0.0774]],
   
           [[ 0.0194]],
   
           [[ 0.0944]],
   
           [[ 0.0296]],
   
           [[ 0.0539]],
   
           [[ 0.0137]],
   
           [[ 0.0676]],
   
           [[ 0.0296]],
   
           [[-0.0049]],
   
           [[ 0.0296]],
   
           [[ 0.0815]],
   
           [[ 0.0906]],
   
           [[-0.0257]],
   
           [[-0.0939]],
   
           [[-0.0768]],
   
           [[ 0.0888]],
   
           [[-0.0206]],
   
           [[-0.0549]],
   
           [[ 0.0142]],
   
           [[ 0.0305]],
   
           [[ 0.0509]],
   
           [[-0.0466]],
   
           [[ 0.0473]],
   
           [[ 0.0940]],
   
           [[-0.0382]],
   
           [[ 0.0825]],
   
           [[-0.0464]],
   
           [[ 0.0810]],
   
           [[-0.0163]],
   
           [[-0.0121]],
   
           [[ 0.0733]],
   
           [[ 0.0934]],
   
          