In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up on the next cell run without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
from transformers import AutoFeatureExtractor, Wav2Vec2Model

In [3]:
class BaseLine(nn.Module):
    """Pretrained Wav2Vec2 feature extractor followed by a small prediction head.

    The Wav2Vec2 output selected by ``pretrain_feat`` is transposed so that the
    feature channels sit on dim 1 and then passed to the backend:

    * ``"linear"``: the last dim is average-pooled to length 1, squeezed away,
      and a single ``nn.Linear(C, 1)`` produces one value per input row.
    * ``"resnet"``: the transposed features are passed unpooled to ``ResNet50``.

    Args:
        pretrain_feat: Which Wav2Vec2 output to use, ``"extract_features"`` or
            ``"last_hidden_state"``.
        backend: Prediction head, ``"linear"`` or ``"resnet"``.
        pretrain_path: Path or identifier handed to
            ``Wav2Vec2Model.from_pretrained``. ``None`` (the default) loads
            ``DEFAULT_PRETRAIN_PATH``.

    Raises:
        AssertionError: If ``pretrain_feat`` is not a supported output name.
        ValueError: If ``backend`` is not a supported head.
    """

    # Checkpoint loaded when the caller does not pass `pretrain_path`.
    DEFAULT_PRETRAIN_PATH = (
        "/usr/local/ay_data/0-model_weights/models--facebook--wav2vec2-base-960h"
    )
    SUPPORTED_FEATS = ("last_hidden_state", "extract_features")
    SUPPORTED_BACKENDS = ("linear", "resnet")

    def __init__(
        self, pretrain_feat="extract_features", backend="linear", pretrain_path=None
    ):
        super().__init__()

        # Validate both options before the (expensive) checkpoint load so a typo
        # fails immediately instead of surfacing as an AttributeError in forward().
        assert (
            pretrain_feat in self.SUPPORTED_FEATS
        ), f"pretrain_feat must be one of {self.SUPPORTED_FEATS}, got {pretrain_feat!r}"
        if backend not in self.SUPPORTED_BACKENDS:
            raise ValueError(
                f"backend must be one of {self.SUPPORTED_BACKENDS}, got {backend!r}"
            )
        self.pretrain_feat = pretrain_feat
        self.backend = backend

        self.pretrain_model = Wav2Vec2Model.from_pretrained(
            self.DEFAULT_PRETRAIN_PATH if pretrain_path is None else pretrain_path
        )

        # Channel count of the selected output, read from the loaded config so
        # that checkpoints other than the base model also work. For the default
        # base checkpoint this is 512 for "extract_features" (last conv layer
        # width) and 768 for "last_hidden_state" (transformer hidden size).
        config = self.pretrain_model.config
        if pretrain_feat == "extract_features":
            C_features = config.conv_dim[-1]
        else:
            C_features = config.hidden_size

        if backend == "resnet":
            # NOTE(review): ResNet50 is neither imported nor defined in the cells
            # visible here; this branch raises NameError unless it is provided
            # elsewhere in the notebook.
            self.backend_model = ResNet50(in_channels=C_features, classes=1)
        else:  # backend == "linear"
            self.pooler = nn.AdaptiveAvgPool1d(1)
            self.backend_model = nn.Linear(C_features, 1)

    def forward(self, x):
        """Run feature extraction followed by the prediction head on ``x``.

        Equivalent to ``make_prediction(extract_feature(x))``; the two steps are
        reused here so the full pass and the split pass cannot drift apart.
        """
        return self.make_prediction(self.extract_feature(x))

    def extract_feature(self, x):
        """Return the backend-ready features for input ``x``.

        The selected Wav2Vec2 output has its dims 1 and 2 swapped. For the
        ``"linear"`` backend the last dim is then average-pooled to size 1 and
        squeezed, leaving one vector per input row.
        """
        feature = self.pretrain_model(x)[self.pretrain_feat]
        feature = torch.transpose(feature, 1, 2)
        if self.backend == "linear":
            feature = torch.squeeze(self.pooler(feature), -1)
        return feature

    def make_prediction(self, feature):
        """Apply the prediction head to features produced by ``extract_feature``."""
        return self.backend_model(feature)

In [4]:
# Smoke test: push a batch of random inputs through the baseline model.
SEED = 0
BATCH_SIZE = 10
NUM_SAMPLES = 69000  # length of each random input row

# Seed before sampling the input and before building the model: both the input
# and the freshly initialised linear head are random, so without a seed the
# values displayed below change on every run.
torch.manual_seed(SEED)
x = torch.rand(BATCH_SIZE, NUM_SAMPLES)
model = BaseLine(backend="linear")

model(x)

  return self.fget.__get__(instance, owner)()
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at /usr/local/ay_data/0-model_weights/models--facebook--wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tensor([[0.3212],
        [0.3307],
        [0.3266],
        [0.3083],
        [0.3145],
        [0.3171],
        [0.3300],
        [0.3246],
        [0.3117],
        [0.3167]], grad_fn=<AddmmBackward0>)

In [9]:
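# Disabled: FLOP counting for one forward pass using torchtnt.
# NOTE: re-enabling this also needs `import copy`, which is not imported in the
# cells shown above.
# from torchtnt.utils.flops import FlopTensorDispatchMode
# with FlopTensorDispatchMode(model) as ftdm:
#     res = model(x).mean()
#     flops_forward = copy.deepcopy(ftdm.flop_counts)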