In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local .py files are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
from transformers import AutoFeatureExtractor, WavLMModel

In [8]:
class BaseLine(nn.Module):
    """Pretrained WavLM front-end followed by a small prediction head.

    The WavLM model is run on the input and one of its named outputs
    (``pretrain_feat``) is taken as the feature. Dimensions 1 and 2 of that
    feature are swapped, and then:

    * ``backend='linear'``: the last dimension is average-pooled down to
      size 1, squeezed away, and the result is fed to an ``nn.Linear`` head.
    * ``backend='resnet'``: the transposed feature is passed unpooled to
      ``ResNet50``.

    Args:
        pretrain_feat: Which WavLM output to use, ``"extract_features"`` or
            ``"last_hidden_state"``.
        backend: Head type, ``'linear'`` or ``'resnet'``.
        num_classes: Output width of the head.
        pretrain_path: Checkpoint path or identifier handed to
            ``WavLMModel.from_pretrained``. ``None`` selects
            ``DEFAULT_PRETRAIN_PATH``.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If ``pretrain_feat`` is not a supported name.
        ValueError: If ``backend`` is not a supported name.
    """

    # Checkpoint used when the caller does not pass ``pretrain_path``.
    DEFAULT_PRETRAIN_PATH = "/usr/local/ay_data/0-model_weights/microsoft_wavlm-base"
    SUPPORTED_FEATURES = ("last_hidden_state", "extract_features")
    SUPPORTED_BACKENDS = ("linear", "resnet")

    def __init__(self, pretrain_feat="extract_features", backend='linear', num_classes=1,
                 pretrain_path=None, **kwargs):
        super().__init__()

        # Validate both options up front, before the (slow) checkpoint load.
        assert pretrain_feat in self.SUPPORTED_FEATURES, (
            f"pretrain_feat must be one of {self.SUPPORTED_FEATURES}, got {pretrain_feat!r}"
        )
        if backend not in self.SUPPORTED_BACKENDS:
            # Previously an unknown backend left ``backend_model`` undefined and
            # only failed later, inside ``forward``.
            raise ValueError(
                f"backend must be one of {self.SUPPORTED_BACKENDS}, got {backend!r}"
            )
        self.pretrain_feat = pretrain_feat

        if pretrain_path is None:
            pretrain_path = self.DEFAULT_PRETRAIN_PATH
        self.pretrain_model = WavLMModel.from_pretrained(pretrain_path)

        # Channel count of the selected feature, read from the loaded config
        # instead of hard-coding the wavlm-base values (512 for
        # 'extract_features', 768 for 'last_hidden_state').
        # NOTE(review): assumes the checkpoint does not enable an adapter that
        # changes the output width of 'last_hidden_state' -- confirm.
        config = self.pretrain_model.config
        if pretrain_feat == "extract_features":
            C_features = config.conv_dim[-1]
        else:
            C_features = config.hidden_size

        self.backend = backend
        if backend == 'resnet':
            # NOTE(review): ResNet50 is not defined or imported in the visible
            # cells; it must be in scope before this branch can run.
            self.backend_model = ResNet50(
                in_channels=C_features, classes=num_classes
            )
        else:  # backend == 'linear'
            self.pooler = nn.AdaptiveAvgPool1d(1)
            self.backend_model = nn.Linear(C_features, num_classes)

    def get_feature(self, x):
        """Return the feature that is fed to the head for input ``x``.

        The selected WavLM output has its dimensions 1 and 2 swapped; for the
        linear backend it is additionally average-pooled over the last
        dimension and that dimension is squeezed away.
        """
        feature = self.pretrain_model(x)[self.pretrain_feat]
        feature = torch.transpose(feature, 1, 2)
        if self.backend == 'linear':
            feature = torch.squeeze(self.pooler(feature), -1)
        return feature

    def forward(self, x):
        """Run feature extraction followed by the head and return its output."""
        return self.make_prediction(self.get_feature(x))

    def extract_feature(self, x):
        """Alias of :meth:`get_feature`."""
        return self.get_feature(x)

    def make_prediction(self, feature):
        """Apply the head to a feature produced by :meth:`get_feature`."""
        return self.backend_model(feature)

In [9]:
# Smoke test: push a batch of random inputs through the linear-head model and
# display the raw output. The RNG is seeded first so the random input and the
# randomly initialised linear head are reproducible under
# Restart Kernel -> Run All.
torch.manual_seed(0)
x = torch.rand(10, 69000)
model = BaseLine(backend='linear', pretrain_feat="last_hidden_state")

model(x)

Some weights of the model checkpoint at /usr/local/ay_data/0-model_weights/microsoft_wavlm-base were not used when initializing WavLMModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing WavLMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing WavLMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of WavLMModel were not initialized from the model checkpoint at /usr/local/ay_data/0-model_weights/microsoft_wavlm-base and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this mo

tensor([[-0.2070],
        [-0.2199],
        [-0.2235],
        [-0.2120],
        [-0.2166],
        [-0.2089],
        [-0.1970],
        [-0.1958],
        [-0.2072],
        [-0.2172]], grad_fn=<AddmmBackward0>)