In [4]:
import os
import re
import numpy as np
import torch
import librosa
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CustomMambaBlock(nn.Module):
    def __init__(self, d_input, d_model, dropout=0.1):
        super().__init__()
        self.in_proj = nn.Linear(d_input, d_model)
        self.s_B = nn.Linear(d_model, d_model)
        self.s_C = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_input)
        self.norm = nn.LayerNorm(d_input)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        x_in = x  # сохраняем вход
        x = self.in_proj(x)
        B = self.s_B(x)
        C = self.s_C(x)
        x = x + B + C
        x = self.activation(x)
        x = self.out_proj(x)
        x = self.dropout(x)
        x = self.norm(x + x_in)  # residual + norm
        return x

class CustomMambaClassifier(nn.Module):
    def __init__(self, input_size=1024, d_model=256, num_layers=2, num_classes=7, dropout=0.1):
        super().__init__()
        self.input_proj = nn.Linear(input_size, d_model)
        self.blocks = nn.ModuleList([
            CustomMambaBlock(d_model, d_model, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x, lengths):
        # x: (batch, seq_length, input_size)
        x = self.input_proj(x)
        for block in self.blocks:
            x = block(x)
        pooled = []
        for i, l in enumerate(lengths):
            if l > 0:
                pooled.append(x[i, :l, :].mean(dim=0))
            else:
                pooled.append(torch.zeros(x.size(2), device=x.device))
        pooled = torch.stack(pooled, dim=0)
        return self.fc(pooled)

def get_model_mamba(params):
    return CustomMambaClassifier(
        input_size=params.get("input_size", 1024),
        d_model=params.get("d_model", 256),
        num_layers=params.get("num_layers", 2),
        num_classes=params.get("num_classes", 7),
        dropout=params.get("dropout", 0.1)
    )

label_to_emotion = {
    0: 'anger',
    1: 'disgust',
    2: 'fear',
    3: 'joy/happiness',
    4: 'neutral',
    5: 'sadness',
    6: 'surprise/enthusiasm'
}

class EmotionModel(Wav2Vec2PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.init_weights()

    def forward(self, input_values):
        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        return hidden_states


model_name = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
processor = Wav2Vec2Processor.from_pretrained(model_name)
audio_embedder = EmotionModel.from_pretrained(model_name).to(device)

model_params = {
        "input_size": 1024,
        "d_model": 256,
        "num_layers": 2,
        "num_classes": 7,
        "dropout": 0.2
    }


def process_audio(signal: np.ndarray, sampling_rate: int) -> np.ndarray:
    inputs = processor(signal, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
    input_values = inputs["input_values"].to(device)
    
    with torch.no_grad():
        outputs = audio_embedder(input_values)
        embeddings = outputs
        
    return embeddings.detach().cpu().numpy()

def load_classifier_model_from_checkpoint(checkpoint_path):
    classifier_model = get_model_mamba(model_params).to(device)
    classifier_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    classifier_model.eval()
    return classifier_model

classifier_checkpoint = "new_best_model.pt"
classifier_model = load_classifier_model_from_checkpoint(classifier_checkpoint)

audio_dir = "RESD_RAW/test/09_neutral_happiness" 

embedding_dict = {}

audio_files = [f for f in os.listdir(audio_dir) if f.lower().endswith('.wav')]

print(f"Найдено {len(audio_files)} аудио файлов.\n")

for file_name in audio_files:
    
    file_id = file_name
    file_path = os.path.join(audio_dir, file_name)
    
    signal, sr = librosa.load(file_path, sr=16000)

    embeddings = process_audio(signal, sr)

    tensor_emb = torch.tensor(embeddings, dtype=torch.float32).to(device)
    lengths = [tensor_emb.shape[1]] 

    with torch.no_grad():
        logits = classifier_model(tensor_emb, lengths)
        pred_idx = torch.argmax(logits, dim=1).item()
        predicted_emotion = label_to_emotion.get(pred_idx, "Unknown")
    
    embedding_dict[file_id] = {
        "embeddings": embeddings, 
        "predicted_emotion": predicted_emotion
    }
    
    print(f"Обработан файл: {file_name} (id={file_id}), эмбеддингов: {embeddings.shape}, предсказанная эмоция: {predicted_emotion}")
    

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of the model checkpoint at audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim were not used when initializing EmotionModel: ['classifier.dense.weight', 'classifier.out_proj.weight', 'classifier.dense.bias', 'classifier.out_proj.bias']
- This IS expected if you are initializing EmotionModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EmotionModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Найдено 5 аудио файлов.
Обработан файл: 09_neutral_happiness n_060.wav (id=09_neutral_happiness n_060.wav), эмбеддингов: (1, 538, 1024), предсказанная эмоция: fear
Обработан файл: 09_neutral_happiness n_080.wav (id=09_neutral_happiness n_080.wav), эмбеддингов: (1, 651, 1024), предсказанная эмоция: neutral
Обработан файл: 09_neutral_happiness n_052.wav (id=09_neutral_happiness n_052.wav), эмбеддингов: (1, 527, 1024), предсказанная эмоция: neutral
Обработан файл: 09_neutral_happiness h_071.wav (id=09_neutral_happiness h_071.wav), эмбеддингов: (1, 633, 1024), предсказанная эмоция: joy/happiness
Обработан файл: 09_neutral_happiness n_020.wav (id=09_neutral_happiness n_020.wav), эмбеддингов: (1, 275, 1024), предсказанная эмоция: surprise/enthusiasm
