In [1]:
#首先加载定义好的模型
import torch
import numpy 
import torch.nn as nn
import torch.nn.functional as F
import os
import torchaudio
import collections
import numpy
import heapq
from collections import deque
from transformers import Wav2Vec2ForPreTraining, Wav2Vec2Config
import argparse
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
import numpy as np
import torch.optim as optim
from torch.utils import data
import pytorch_lightning.core.lightning as pl

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Wav2vec2Wrapper(nn.Module):
    """Wrap the pretrained ``facebook/wav2vec2-base`` model with optional SpecAugment-style masking.

    The CNN feature extractor and the feature projection always run under
    ``torch.no_grad()``; only the transformer encoder is exposed through
    ``trainable_params``.

    Args:
        pretrain: if True, use the continued-pretraining masking settings and
            make ``forward`` also return the boolean mask of masked frames.
            If False, SpecAugment is applied only while ``self.training`` is True.
    """

    def __init__(self, pretrain=True):
        super().__init__()
        self.wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base", revision='2dcc7b7f9b11f0ef271067e62599a27317a03114').wav2vec2
        # Disable gradient checkpointing for DDP.
        self.wav2vec2.encoder.config.gradient_checkpointing = False
        self.pretrain = pretrain
        if pretrain:
            self.mask_time_length = 15
            self.mask_time_prob = 0.06  # mask_prob passed to _compute_mask_indices for the time axis
            self.observe_time_prob = 0.0  # approx. fraction of selected frames left unmasked
            self.mask_feature_prob = 0
            # Unused while mask_feature_prob is 0; defined so that enabling feature
            # masking in pretrain mode does not raise AttributeError.
            self.mask_feature_length = 64
        else:
            # SpecAugment
            self.mask_time_length = 15
            self.mask_time_prob = 0.08
            self.observe_time_prob = 0.0

            self.mask_feature_length = 64
            self.mask_feature_prob = 0.05

    def trainable_params(self):
        """Return the transformer-encoder parameters (the CNN front end stays frozen)."""
        ret = list(self.wav2vec2.encoder.parameters())
        return ret

    def forward(self, x, length=None):
        """Encode a batch of waveforms.

        Args:
            x: input passed to the wav2vec2 CNN feature extractor
                (presumably raw waveforms shaped (batch, samples) — TODO confirm).
            length: optional per-example input lengths in samples. When given,
                padded frames are excluded from attention and from time masking.

        Returns:
            ``relu(encoder last hidden state)`` shaped (batch, frames, hidden).
            In pretrain mode, returns ``(reps, masked_indices)`` where
            ``masked_indices`` is the boolean (batch, frames) mask of frames
            chosen for time masking, or None if time masking is disabled.
        """
        masked_indices = None
        with torch.no_grad():
            x = self.wav2vec2.feature_extractor(x)
            x = x.transpose(1, 2)  # channel-first conv output -> (batch, frames, channels); new huggingface API
            x, _ = self.wav2vec2.feature_projection(x)  # new huggingface API returns a tuple
            mask = None
            if length is not None:
                length = self.get_feat_extract_output_lengths(length)
                mask = prepare_mask(length, x.shape[:2], x.dtype, x.device)
            if self.pretrain or self.training:
                batch_size, sequence_length, hidden_size = x.size()

                # apply SpecAugment along time axis
                if self.mask_time_prob > 0:
                    # NOTE(review): the `device=` keyword and tensor return value match
                    # older transformers releases of _compute_mask_indices; newer ones
                    # return a numpy array — confirm against the pinned version.
                    mask_time_indices = _compute_mask_indices(
                        (batch_size, sequence_length),
                        self.mask_time_prob,
                        self.mask_time_length,
                        min_masks=2,
                        device=x.device
                    )
                    # Never mask padding frames; without lengths every frame is real.
                    masked_indices = mask_time_indices if mask is None else mask_time_indices & mask
                    flip_mask = torch.rand((batch_size, sequence_length), device=masked_indices.device) > self.observe_time_prob
                    x[masked_indices & flip_mask] = self.wav2vec2.masked_spec_embed.to(x.dtype)

                # apply SpecAugment along feature axis
                if self.mask_feature_prob > 0:
                    mask_feature_indices = _compute_mask_indices(
                        (batch_size, hidden_size),
                        self.mask_feature_prob,
                        self.mask_feature_length,
                        device=x.device,
                        min_masks=1
                    )
                    x[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
        x = self.wav2vec2.encoder(x, attention_mask=mask)[0]
        reps = F.relu(x)
        if self.pretrain:
            return reps, masked_indices
        return reps

    # From huggingface
    def get_feat_extract_output_lengths(self, input_length):
        """
        Computes the output length of the convolutional layers
        """
        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1
        for kernel_size, stride in zip(self.wav2vec2.config.conv_kernel, self.wav2vec2.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length

def prepare_mask(length, shape, dtype, device):
    """Build a boolean mask that is True on the first ``length[b]`` positions of each row ``b``.

    Modified from huggingface. Uses a direct position-vs-length comparison,
    so a row with length 0 is fully False. The previous flip/cumsum trick
    indexed column ``-1`` in that case and marked the whole row as valid.

    Args:
        length: per-row valid lengths (1-D tensor of size ``batch``), or a
            scalar applied to every row.
        shape: ``(batch, seq_len)`` of the mask to build.
        dtype: unused; kept for signature compatibility (the result is always bool).
        device: device on which the mask is created.

    Returns:
        Bool tensor of shape ``shape``.
    """
    batch_size, seq_len = shape
    positions = torch.arange(seq_len, device=device).expand(batch_size, seq_len)
    lengths = torch.as_tensor(length, device=device).reshape(-1, 1)
    return positions < lengths

In [3]:
class PretrainedRNNHead(pl.LightningModule):
    """Pretrained speech backbone with BiLSTM and linear classification heads.

    Only the backbone is executed in ``forward``. ``rnn_head`` and
    ``linear_head`` are built (so checkpoint weights can be loaded into them)
    but are not applied here.

    Args:
        n_classes: number of outputs of ``linear_head``.
        backend: ``'wav2vec2'`` (768-dim features) or ``'wav2vec'`` (512-dim features).
        wav2vecpath: checkpoint path; required when ``backend == 'wav2vec'``.
    """

    def __init__(self, n_classes, backend='wav2vec2', wav2vecpath=None):
        assert backend in ['wav2vec2', 'wav2vec']
        super().__init__()
        self.backend = backend
        if backend != 'wav2vec2':
            # NOTE(review): `Wav2vecWrapper` is not defined in this notebook;
            # this branch raises NameError unless it is provided elsewhere.
            assert wav2vecpath is not None
            self.wav2vec = Wav2vecWrapper(wav2vecpath)
            feature_dim = 512
        else:
            self.wav2vec2 = Wav2vec2Wrapper(pretrain=False)
            feature_dim = 768
        self.rnn_head = nn.LSTM(feature_dim, 256, 1, bidirectional=True)
        # NOTE(review): a bidirectional LSTM with hidden size 256 emits 512 features,
        # yet this Linear expects 768 inputs. Kept as-is because the loaded
        # checkpoint presumably stores a (n_classes, 768) weight; revisit before
        # chaining rnn_head -> linear_head.
        self.linear_head = nn.Sequential(nn.ReLU(), nn.Linear(768, n_classes))

    def trainable_params(self):
        """Return rnn_head, linear_head, then the backbone's trainable_params(), in that order."""
        backbone = getattr(self, self.backend)
        return [
            *self.rnn_head.parameters(),
            *self.linear_head.parameters(),
            *backbone.trainable_params(),
        ]

    def forward(self, x, length):
        """Run the selected backbone on ``x``; the RNN/linear heads are not applied."""
        backbone = getattr(self, self.backend)
        return backbone(x, length)

In [4]:
# Build the 4-class model and load the fine-tuned checkpoint.
# map_location='cpu' lets a checkpoint saved from GPU load on CPU-only machines
# (the next cell falls back to CPU); the model is moved to `device` afterwards.
CHECKPOINT_PATH = '/home/ni/FT-w2v2-ser/output_dir/Session4_pt/eager.pt'
model = PretrainedRNNHead(4)
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [5]:
# Prefer the GPU when one is available, otherwise run on CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# eval() turns off the wrapper's training-only SpecAugment masking.
model.eval()
model = model.to(device)

In [6]:
def _load_and_encode(wav_file_path, target_sample_rate):
    """Load one wav file, down-mix it to mono, resample it, and run the global `model` on it."""
    waveform, sample_rate = torchaudio.load(wav_file_path)
    waveform = waveform.to(device)
    # Multi-channel audio: average the channels into a single one.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
    with torch.no_grad():
        lengths = torch.tensor([waveform.shape[1]]).to(device)
        return model(waveform, lengths)


def _fit_frames(features, tgt_dim):
    """Zero-pad at the end, or center-crop, dimension 1 of a rank-3 tensor to exactly `tgt_dim`."""
    _, n_frames, _ = features.shape
    if n_frames < tgt_dim:
        return torch.nn.functional.pad(features, (0, 0, 0, tgt_dim - n_frames), mode='constant', value=0)
    if n_frames > tgt_dim:
        start_index = (n_frames - tgt_dim) // 2
        return features[:, start_index:start_index + tgt_dim, :]
    return features


def wav_preprocessing(
        txtpath: str,
        wavpath: str,
        tgt_dim: int = 256,
        target_sample_rate: int = 16000,
):
    """Extract fixed-length features and integer emotion labels for one annotation split.

    Every ``*.txt`` file in ``txtpath`` is processed in sorted order, so the
    output order is reproducible. A line counts as an utterance if it has
    exactly 8 whitespace-separated tokens and its first three look like
    ``[start - end]``. Token 3 is the utterance id and token 4 the emotion
    label. Only labels in ang/exc/hap/sad/neu are kept; exc and hap are
    merged into class 1. The wav is read from
    ``wavpath/<txt file stem>/<utterance id>.wav`` and encoded with the
    notebook-global ``model`` on the global ``device``. The resulting frame
    sequence is zero-padded or center-cropped to ``tgt_dim`` frames.

    Args:
        txtpath: directory holding the annotation ``.txt`` files.
        wavpath: directory holding one sub-folder of wav files per ``.txt`` file.
        tgt_dim: number of frames each feature sequence is padded/cropped to.
        target_sample_rate: rate audio is resampled to before encoding
            (16 kHz for wav2vec2).

    Returns:
        ``(features, labels)``: features concatenated along dim 0, one row per
        kept utterance (recorded output: ``(131, 256, 768)``), and a numpy
        array of integer labels of length N.

    Raises:
        ValueError: if no utterance with a known label is found.
    """
    emotion_label_dict = {'ang': 0, 'exc': 1, 'hap': 1, 'sad': 2, 'neu': 3}
    features = []
    labels = []
    for file_name in sorted(os.listdir(txtpath)):
        if not file_name.endswith('.txt'):
            continue
        wav_folder_name = os.path.join(wavpath, file_name[:-4])
        txt_file_path = os.path.join(txtpath, file_name)
        with open(txt_file_path, 'r') as f:
            for line in f:
                line_list = line.split()
                is_utterance = (len(line_list) == 8 and line_list[0][0] == '['
                                and line_list[1] == '-' and line_list[2][-1] == ']')
                if not is_utterance:
                    continue
                label = emotion_label_dict.get(line_list[4])
                if label is None:
                    continue  # emotion not used by this 4-class setup
                wav_file_path = os.path.join(wav_folder_name, line_list[3] + '.wav')
                encoded = _load_and_encode(wav_file_path, target_sample_rate)
                features.append(_fit_frames(encoded, tgt_dim))
                labels.append(label)

    if not features:
        raise ValueError(f'no labelled utterances found under {txtpath!r}')

    wav2vec_last = torch.cat(features, dim=0).squeeze(1)
    print(wav2vec_last.shape)
    label_last = np.stack(labels, axis=0)
    print(label_last.shape)

    return wav2vec_last, label_last

In [129]:
# Session 5 (male speaker), improvised dialogue 3: annotation and audio folders.
IMPRO3_TXT_DIR = '/home/ni/提取FT-w2v2特征/IEMOCAP数据集_分部分/Session5/EmoEvaluation_5M_impro_3'
IMPRO3_WAV_DIR = '/home/ni/提取FT-w2v2特征/IEMOCAP数据集_分部分/Session5/wav_5M_impro_3'
wav2vec_last1, label_last1 = wav_preprocessing(txtpath=IMPRO3_TXT_DIR, wavpath=IMPRO3_WAV_DIR)

torch.Size([131, 256, 768])
(131,)


In [130]:
# Persist features and labels as .npy files. The conversion is guarded so that
# re-running this cell (when wav2vec_last1 is already an ndarray) does not fail.
if torch.is_tensor(wav2vec_last1):
    wav2vec_last1 = wav2vec_last1.cpu().numpy()
numpy.save('data_Session5M_w2v2_impro_3.npy', wav2vec_last1)
# label_last1 is already a numpy array (built with np.stack in wav_preprocessing).
numpy.save('data_Session5M_label_impro_3.npy', label_last1)