In [4]:
# Reload edited project modules automatically before each cell runs,
# so changes to imported .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
from copy import deepcopy

import torch
import torch.nn as nn
from transformers import AutoFeatureExtractor, Wav2Vec2Model

In [31]:
class TransformerBaseLine(nn.Module):
    """Wav2Vec2 backbone whose transformer encoder is split into two stages.

    Encoder layers ``[0, split_layer)`` form the main stem (``get_hidden_state``);
    layers ``[split_layer, end)`` form the content stem (``get_final_feature``).
    ``copy_final_stage`` deep-copies the content-stem layers into
    ``copied_transformer`` so a second, independently trainable branch can be
    run with ``get_final_feature_copyed``.

    Args:
        pretrain_feat: ``"last_hidden_state"`` or ``"extract_features"``. It is
            validated, stored, and used to set ``C_features``.
            NOTE(review): ``extract_feature`` always returns the encoder's final
            hidden states regardless of this flag — confirm that is intended.
        pretrained_path: checkpoint passed to ``Wav2Vec2Model.from_pretrained``.
        split_layer: index of the first encoder layer of the final stage
            (default 6, i.e. a 6/6 split of the 12-layer base encoder).
    """

    # Default checkpoint directory; pass ``pretrained_path`` to use another location.
    DEFAULT_PRETRAINED_PATH = "/usr/local/ay_data/0-model_weights/models--facebook--wav2vec2-base-960h"

    def __init__(
        self,
        pretrain_feat="extract_features",
        pretrained_path=DEFAULT_PRETRAINED_PATH,
        split_layer=6,
    ):
        super().__init__()

        assert pretrain_feat in ["last_hidden_state", "extract_features"], (
            f"pretrain_feat must be 'last_hidden_state' or 'extract_features', got {pretrain_feat!r}"
        )
        self.pretrain_feat = pretrain_feat
        # Channels of the selected pretrained feature: 512 for 'extract_features',
        # 768 for 'last_hidden_state'. Exposed so downstream heads can size themselves.
        self.C_features = 512 if pretrain_feat == "extract_features" else 768

        self.pretrain_model = Wav2Vec2Model.from_pretrained(pretrained_path)

        n_layers = len(self.pretrain_model.encoder.layers)
        if not 0 <= split_layer <= n_layers:
            raise ValueError(f"split_layer must be in [0, {n_layers}], got {split_layer}")
        self.split_layer = split_layer

    # ------------------------------------------------------------------ helpers

    def _embed(self, x):
        """Run the shared front end and return the input to the first encoder layer.

        Pipeline: conv feature extractor -> transpose to (batch, time, channels)
        -> feature projection -> (optional) time masking -> positional conv
        embedding added -> encoder layer norm -> encoder dropout.
        """
        model = self.pretrain_model
        encoder = model.encoder

        conv_features = model.feature_extractor(x).transpose(1, 2)
        # feature_projection returns a pair; only the projected hidden states are
        # needed here (the second element was named `extract_features` originally).
        hidden_states, _ = model.feature_projection(conv_features)
        # NOTE(review): private HF helper; with mask_time_indices=None it is expected
        # to be a no-op unless the pretrained model itself is in training mode.
        hidden_states = model._mask_hidden_states(
            hidden_states, mask_time_indices=None, attention_mask=None
        )

        hidden_states = hidden_states + encoder.pos_conv_embed(hidden_states)
        hidden_states = encoder.layer_norm(hidden_states)
        return encoder.dropout(hidden_states)

    def _run_layers(self, layers, hidden_states):
        """Apply ``layers`` in order, with LayerDrop while this module is training.

        LayerDrop (https://arxiv.org/abs/1909.11556): in training mode each layer is
        skipped when a uniform draw falls below ``encoder.config.layerdrop``.
        """
        layerdrop = self.pretrain_model.encoder.config.layerdrop
        for layer in layers:
            # One draw per layer even in eval mode, so the global RNG stream is
            # consumed exactly as in the per-method loops this helper replaces.
            dropout_probability = torch.rand([])
            if self.training and dropout_probability < layerdrop:
                continue
            hidden_states = layer(hidden_states, attention_mask=None, output_attentions=False)[0]
        return hidden_states

    def _copied_layers(self):
        """Return ``copied_transformer``, failing clearly if it was never built."""
        copied = getattr(self, "copied_transformer", None)
        if copied is None:
            raise AttributeError(
                "copied_transformer has not been built; call copy_final_stage() first"
            )
        return copied

    # ------------------------------------------------------------ copied stage

    def build_final_block(self):
        """Deep-copy the final-stage encoder layers into ``self.copied_transformer``.

        The copies start from the pretrained weights but share no parameters with
        ``pretrain_model``. Copies every layer from ``split_layer`` to the end
        (layers 6..11 for the default 12-layer split).
        """
        final_layers = self.pretrain_model.encoder.layers[self.split_layer:]
        self.copied_transformer = nn.ModuleList([deepcopy(layer) for layer in final_layers])

    def copy_final_stage(self):
        """Build the copied final stage (same as ``build_final_block``)."""
        self.build_final_block()

    # ----------------------------------------------------------------- forward

    def extract_feature(self, x):
        """Return the hidden states after the front end and *all* encoder layers."""
        return self._run_layers(self.pretrain_model.encoder.layers, self._embed(x))

    def get_hidden_state(self, x):
        """Return the hidden states after the front end and the main-stem layers."""
        main_layers = self.pretrain_model.encoder.layers[: self.split_layer]
        return self._run_layers(main_layers, self._embed(x))

    def get_final_feature(self, hidden_states):
        """Run the pretrained final-stage layers, then average over dim 1."""
        final_layers = self.pretrain_model.encoder.layers[self.split_layer:]
        return self._run_layers(final_layers, hidden_states).mean(1)

    def get_final_feature_copyed(self, hidden_states):
        """Run the copied final-stage layers, then average over dim 1.

        Raises:
            AttributeError: if ``copy_final_stage()`` has not been called.
        """
        return self._run_layers(self._copied_layers(), hidden_states).mean(1)

    # ------------------------------------------------------------ param groups

    def get_main_stem(self):
        """Modules up to and including the main-stem encoder layers."""
        encoder = self.pretrain_model.encoder
        return [
            self.pretrain_model.feature_extractor,
            self.pretrain_model.feature_projection,
            encoder.pos_conv_embed,
            encoder.layer_norm,
            encoder.layers[0 : self.split_layer],
        ]

    def get_content_stem(self):
        """The pretrained final-stage encoder layers."""
        return [self.pretrain_model.encoder.layers[self.split_layer:]]

    def get_vocoder_stem(self):
        """The copied final-stage layers (requires ``copy_final_stage()`` first)."""
        return [self._copied_layers()]

    def preprocess(self, x, **kwargs):
        """Select index 0 along dim 1 of a 3-D input, dropping that axis.

        NOTE(review): presumably ``x`` is (batch, channels, samples); ``kwargs`` are ignored.
        """
        return x[:, 0, :]

In [32]:
# Seed so the random dummy input is reproducible on every Restart & Run All.
torch.manual_seed(0)

# Smoke test on random input of shape (10, 69000).
x = torch.rand(10, 69000)
# Was `BaseLine()`, a name not defined in this notebook (NameError on a fresh kernel).
model = TransformerBaseLine()
# A freshly built nn.Module is in training mode, which enables the LayerDrop branch
# in extract_feature and makes the output random; switch to eval for inspection.
model.eval()
with torch.no_grad():
    features = model.extract_feature(x)

assert features.shape[0] == x.shape[0], "batch dimension should be preserved"
features

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at /usr/local/ay_data/0-model_weights/models--facebook--wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tensor([[[-9.8749e-02, -1.7376e-02,  2.5778e-01,  ..., -1.0865e-01,
          -6.8892e-03, -1.7159e-01],
         [-9.5807e-02, -1.5792e-02,  2.5806e-01,  ..., -1.0220e-01,
          -9.5984e-03, -1.6934e-01],
         [-9.4965e-02, -1.8909e-02,  2.5308e-01,  ..., -9.7469e-02,
          -3.2506e-03, -1.7182e-01],
         ...,
         [-9.7951e-02, -1.5857e-02,  2.5040e-01,  ..., -1.0613e-01,
          -3.4473e-03, -1.7160e-01],
         [-9.7289e-02, -1.5004e-02,  2.5057e-01,  ..., -1.0719e-01,
          -5.3351e-03, -1.7056e-01],
         [-9.5800e-02, -1.5553e-02,  2.5213e-01,  ..., -1.0299e-01,
          -5.1184e-03, -1.7331e-01]],

        [[-9.3925e-02, -1.3120e-02,  2.2336e-01,  ..., -9.2711e-02,
           6.2209e-03, -1.4708e-01],
         [-9.2376e-02, -1.3820e-02,  2.1882e-01,  ..., -8.3005e-02,
           9.1521e-03, -1.4644e-01],
         [-9.7433e-02, -1.3845e-02,  2.2223e-01,  ..., -9.0570e-02,
           7.3257e-03, -1.4658e-01],
         ...,
         [-9.0032e-02, -1