In [2]:
import torch
from DPHuBERT.wav2vec2.model import wav2vec2_model
import torch.nn as nn
import fairseq
import torchaudio

In [2]:
ckpt_path = "./DPHuBERT/checkpoints/DPHuBERT-sp0.75.pth"
ckpt = torch.load(ckpt_path)

In [24]:
dp_hubert_model = wav2vec2_model(**ckpt["config"])

In [6]:
class SSLModel(nn.Module):
    def __init__(self, device):
        super(SSLModel, self).__init__()

        cp_path = "./checkpoints_xlsr/xlsr2_300m.pt"  # Change the pre-trained XLSR model path.
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [cp_path]
        )
        self.model = model[0]
        self.device = device
        self.out_dim = 1024
        return

    def extract_feat(self, input_data):

        # put the model to GPU if it not there
        if (
            next(self.model.parameters()).device != input_data.device
            or next(self.model.parameters()).dtype != input_data.dtype
        ):
            self.model.to(input_data.device, dtype=input_data.dtype)
            self.model.train()

        if True:
            # input should be in shape (batch, length)
            if input_data.ndim == 3:
                input_tmp = input_data[:, :, 0]
            else:
                input_tmp = input_data

            # [batch, length, dim]
            emb = self.model(input_tmp, mask=False, features_only=True)["x"]
        return emb

In [6]:
path = "/data/a.varlamov/LJSpeech-1.1/wavs/LJ001-0008.wav"
audio, sr = torchaudio.load(path)

In [7]:
batch = torch.cat([audio] * 2)

In [16]:
batch.shape

torch.Size([8, 39325])

In [17]:
xlsr_model = SSLModel("cpu")

In [20]:
xlsr_model.extract_feat(batch).shape

torch.Size([2, 122, 1024])

In [39]:
dp_hubert_model.extract_features(batch)[0][-1].shape

torch.Size([2, 122, 768])

In [25]:
class DPHubertModel(nn.Module):
    def __init__(self, device, behaviour="last-layer", freeze=False):
        '''
        Args:
            device: obvious...
            behaviour: last-layer / weighted-sum
            freeze: to freeze weights of the pre-train or not
                for weighted-sum freezing will not let weights of sum train
        '''
        super(DPHubertModel, self).__init__()

        ckpt_path = "./DPHuBERT/checkpoints/DPHuBERT-sp0.75.pth"
        ckpt = torch.load(ckpt_path)
        self.model = wav2vec2_model(**ckpt["config"]).to(device)
        self.device = device
        self.out_dim = 768
        self.n_layers = 12
        self.behaviour = behaviour
        
        if behaviour == "weighted-sum":
            self.sum_weights = nn.parameter.Parameter(torch.tensor([0.] * 9 + [0.5, 0.5, 0.5])).reshape(self.n_layers, 1, 1, 1)
        
        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False

    def extract_feat(self, input_data):

        # put the model to GPU if it not there
        if (
            next(self.model.parameters()).device != input_data.device
            or next(self.model.parameters()).dtype != input_data.dtype
        ):
            self.model.to(input_data.device, dtype=input_data.dtype)
            self.model.train()

        if True:
            # input should be in shape (batch, length)
            if input_data.ndim == 3:
                input_tmp = input_data[:, :, 0]
            else:
                input_tmp = input_data

            # [batch, length, dim]
            if self.behaviour == "last-layer":
                emb = self.model.extract_features(input_tmp)[0][-1]  # getting features from the last layer of transformer
            elif self.behaviour == "weighted-sum":
                all_layers_out = self.model.extract_features(input_tmp)[0][1:]
                all_layers_out = torch.stack(all_layers_out)
                emb = (all_layers_out * self.sum_weights).sum(dim=0)
                return emb
        return emb

In [26]:
dp_hubert_model = DPHubertModel("cpu", freeze=False, behaviour="weighted-sum")

In [29]:
emb = dp_hubert_model.extract_feat(batch)

In [24]:
torch.stack(all_out).shape

torch.Size([12, 2, 122, 768])

In [20]:
tensor_list = [torch.randn(2, 122, 768) for _ in range(12)]  # list of 12 tensors
weights_tensor = torch.randn(12)  # 1D tensor with 12 elements

# Convert list of tensors into a single tensor of shape [12, 2, 122, 768]
tensor_stack = torch.stack(tensor_list)  # Shape: [12, 2, 122, 768]

# Reshape weights_tensor to [12, 1, 1, 1] for broadcasting
weights_tensor_reshaped = weights_tensor.view(12, 1, 1, 1)

# Perform element-wise multiplication
result = (tensor_stack * weights_tensor_reshaped).sum(dim=0)

In [21]:
result.shape

torch.Size([2, 122, 768])