In [2]:
import torch
from DPHuBERT.wav2vec2.model import wav2vec2_model
import torch.nn as nn
import fairseq
import torchaudio

In [2]:
ckpt_path = "./DPHuBERT/checkpoints/DPHuBERT-sp0.75.pth"
ckpt = torch.load(ckpt_path)

In [24]:
dp_hubert_model = wav2vec2_model(**ckpt["config"])

In [33]:
class SSLModel(nn.Module):
    def __init__(self, device):
        super(SSLModel, self).__init__()

        cp_path = "./checkpoints_xlsr/xlsr2_300m.pt"  # Change the pre-trained XLSR model path.
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [cp_path]
        )
        self.model = model[0]
        self.device = device
        self.out_dim = 1024
        return

    def extract_feat(self, input_data):

        # put the model to GPU if it not there
        if (
            next(self.model.parameters()).device != input_data.device
            or next(self.model.parameters()).dtype != input_data.dtype
        ):
            self.model.to(input_data.device, dtype=input_data.dtype)
            self.model.train()

        if True:
            # input should be in shape (batch, length)
            if input_data.ndim == 3:
                input_tmp = input_data[:, :, 0]
            else:
                input_tmp = input_data

            # [batch, length, dim]
            emb = self.model(input_tmp, mask=False, features_only=True)["x"]
        return emb

In [6]:
path = "/data/a.varlamov/LJSpeech-1.1/wavs/LJ001-0008.wav"
audio, sr = torchaudio.load(path)

In [44]:
batch = torch.cat([audio] * 32)

In [16]:
batch.shape

torch.Size([8, 39325])

In [35]:
xlsr_model = SSLModel("cpu")

In [34]:
xlsr_model.extract_feat(batch).shape

NameError: name 'xlsr_model' is not defined

In [39]:
dp_hubert_model.extract_features(batch)[0][-1].shape

torch.Size([2, 122, 768])

In [25]:
class DPHubertModel(nn.Module):
    def __init__(self, device, behaviour="last-layer", freeze=False):
        '''
        Args:
            device: obvious...
            behaviour: last-layer / weighted-sum
            freeze: to freeze weights of the pre-train or not
                for weighted-sum freezing will not let weights of sum train
        '''
        super(DPHubertModel, self).__init__()

        ckpt_path = "./DPHuBERT/checkpoints/DPHuBERT-sp0.75.pth"
        ckpt = torch.load(ckpt_path)
        self.model = wav2vec2_model(**ckpt["config"]).to(device)
        self.device = device
        self.out_dim = 768
        self.n_layers = 12
        self.behaviour = behaviour
        
        if behaviour == "weighted-sum":
            self.sum_weights = nn.parameter.Parameter(torch.tensor([0.] * 9 + [0.5, 0.5, 0.5])).reshape(self.n_layers, 1, 1, 1)
        
        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False

    def extract_feat(self, input_data):

        # put the model to GPU if it not there
        if (
            next(self.model.parameters()).device != input_data.device
            or next(self.model.parameters()).dtype != input_data.dtype
        ):
            self.model.to(input_data.device, dtype=input_data.dtype)
            self.model.train()

        if True:
            # input should be in shape (batch, length)
            if input_data.ndim == 3:
                input_tmp = input_data[:, :, 0]
            else:
                input_tmp = input_data

            # [batch, length, dim]
            if self.behaviour == "last-layer":
                emb = self.model.extract_features(input_tmp)[0][-1]  # getting features from the last layer of transformer
            elif self.behaviour == "weighted-sum":
                all_layers_out = self.model.extract_features(input_tmp)[0][1:]
                all_layers_out = torch.stack(all_layers_out)
                emb = (all_layers_out * self.sum_weights).sum(dim=0)
                return emb
        return emb

In [26]:
dp_hubert_model = DPHubertModel("cpu", freeze=False, behaviour="weighted-sum")

In [29]:
emb = dp_hubert_model.extract_feat(batch)

In [24]:
torch.stack(all_out).shape

torch.Size([12, 2, 122, 768])

In [20]:
tensor_list = [torch.randn(2, 122, 768) for _ in range(12)]  # list of 12 tensors
weights_tensor = torch.randn(12)  # 1D tensor with 12 elements

# Convert list of tensors into a single tensor of shape [12, 2, 122, 768]
tensor_stack = torch.stack(tensor_list)  # Shape: [12, 2, 122, 768]

# Reshape weights_tensor to [12, 1, 1, 1] for broadcasting
weights_tensor_reshaped = weights_tensor.view(12, 1, 1, 1)

# Perform element-wise multiplication
result = (tensor_stack * weights_tensor_reshaped).sum(dim=0)

In [21]:
result.shape

torch.Size([2, 122, 768])

---

# wav2vec XLSR works faster then DPHubert, which is more than 10 times smaller!

In [45]:
%%time
xlsr_model.extract_feat(batch)

CPU times: user 7min 58s, sys: 8min 15s, total: 16min 14s
Wall time: 7.98 s


tensor([[[-3.1684e-01, -1.3144e-01, -9.9839e-02,  ..., -2.2259e-02,
           2.6934e-02,  3.4119e-01],
         [-3.1124e-02, -7.2893e-02,  4.6221e-02,  ..., -9.5193e-02,
           8.7276e-02,  5.3428e-01],
         [ 6.8431e-02, -4.2156e-02,  7.4829e-02,  ..., -1.1923e-01,
           7.5676e-02,  6.3278e-01],
         ...,
         [ 1.0732e-01,  4.2675e-02,  1.1723e-01,  ..., -1.0207e-01,
           1.0415e-01,  6.0296e-01],
         [-3.0037e-02,  5.5227e-04,  7.0944e-02,  ..., -9.7161e-02,
           1.1185e-01,  4.5757e-01],
         [-2.8352e-01, -6.8307e-02, -6.7748e-02,  ..., -3.2461e-02,
           3.2547e-02,  2.2996e-01]],

        [[-3.1684e-01, -1.3144e-01, -9.9839e-02,  ..., -2.2259e-02,
           2.6934e-02,  3.4119e-01],
         [-3.1124e-02, -7.2892e-02,  4.6222e-02,  ..., -9.5193e-02,
           8.7275e-02,  5.3428e-01],
         [ 6.8431e-02, -4.2156e-02,  7.4829e-02,  ..., -1.1923e-01,
           7.5676e-02,  6.3278e-01],
         ...,
         [ 1.0732e-01,  4

In [46]:
%%time
dp_hubert_model.extract_feat(batch)

CPU times: user 2min 2s, sys: 46.8 s, total: 2min 49s
Wall time: 2.25 s


tensor([[[ 0.3416,  1.8239,  2.2060,  ...,  0.6883,  1.1687,  0.0100],
         [-0.5105,  1.5057,  0.5117,  ..., -0.2433, -0.6276, -0.4823],
         [ 0.8634, -0.2174,  0.8490,  ..., -0.0038, -0.2467, -1.8248],
         ...,
         [-0.8037,  3.2378,  0.4703,  ...,  0.1115, -1.5998,  0.4930],
         [-1.2858,  1.7123,  0.8499,  ..., -0.8332, -0.9068,  0.0979],
         [-1.0495,  2.3076,  0.5729,  ..., -1.8239, -0.9515,  1.3305]],

        [[-1.1864,  0.9392,  2.9589,  ...,  0.6217, -0.3574,  1.2964],
         [-1.1078,  0.6414,  2.1782,  ..., -0.3556,  0.0243, -1.0702],
         [ 0.8445,  1.0507, -0.0178,  ...,  0.2536, -0.3710, -1.5432],
         ...,
         [-0.3000,  1.9700, -0.2216,  ..., -0.6348, -1.2263, -0.3300],
         [ 0.2855,  1.4401,  0.8531,  ..., -1.4564, -1.0461,  0.6909],
         [-0.3108,  2.3947,  0.1691,  ..., -0.1243, -1.5212, -0.4023]],

        [[-0.8744,  0.4075,  2.9414,  ...,  0.9608, -0.6673, -0.1667],
         [-1.0468,  0.6904,  1.7022,  ..., -0

In [48]:
%%time
for i in range(32):
    xlsr_model.extract_feat(batch[0].unsqueeze(0))

CPU times: user 34min 43s, sys: 3min 31s, total: 38min 14s
Wall time: 22.4 s


In [49]:
%%time
for i in range(32):
    dp_hubert_model.extract_feat(batch[0].unsqueeze(0))

CPU times: user 17min 38s, sys: 51.6 s, total: 18min 30s
Wall time: 11 s
