In [1]:
import librosa
import torch
import torchaudio
import numpy as np

from sklearn.metrics import classification_report
from datasets import load_dataset, load_metric
from transformers import AutoConfig, Wav2Vec2FeatureExtractor

In [2]:
data_files = {
    "train": "../../KEMDy20_v1_1/Splitting/Train.csv",
    "test": "../../KEMDy20_v1_1/Splitting/Test.csv"
}

In [3]:
dataset = load_dataset("csv", data_files = data_files)
train_dataset = dataset["train"]
test_dataset = dataset["test"]

print(train_dataset)
print(test_dataset)

Found cached dataset csv (C:/Users/Yechani/.cache/huggingface/datasets/csv/default-1ac06c035d845d89/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


  0%|          | 0/2 [00:00<?, ?it/s]

Dataset({
    features: ['Emotion', 'Path'],
    num_rows: 25890
})
Dataset({
    features: ['Emotion', 'Path'],
    num_rows: 6312
})


In [4]:
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from transformers.file_utils import ModelOutput

@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

In [5]:
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model
)


class Wav2Vec2ClassificationHead(nn.Module):
    """Head for wav2vec classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pooling_mode = config.pooling_mode
        self.config = config

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

        self.init_weights()

    def freeze_feature_extractor(self):
        self.wav2vec2.feature_extractor._freeze_parameters()

    def merged_strategy(
            self,
            hidden_states,
            mode="mean"
    ):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise Exception(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")

        return outputs

    def forward(
            self,
            input_values,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            labels=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name_or_path = "../../KEMDy20_v1_1/Pretrained_Model2/"
config = AutoConfig.from_pretrained(model_name_or_path)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
sampling_rate = feature_extractor.sampling_rate
model = Wav2Vec2ForSpeechClassification.from_pretrained(model_name_or_path).to(device)

In [20]:
def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["Path"])
    resampler = torchaudio.transforms.Resample(sampling_rate)
    speech = resampler(speech_array).squeeze().numpy()[::-1]

    batch["speech"] = speech
    return batch

def predict(batch):
    features = feature_extractor(batch["speech"], sampling_rate=sampling_rate, return_tensors="pt", padding=True)

    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)

    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits 

    pred_ids = torch.argmax(logits, dim=-1).detach().cpu().numpy()
    batch["predicted"] = pred_ids
    return batch

def predict_logits(batch):
    features = feature_extractor(batch["speech"], sampling_rate=sampling_rate, return_tensors="pt", padding=True)

    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)

    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits 

    pred_ids = logits.detach().cpu().numpy()
    batch["predicted"] = pred_ids
    return batch

In [21]:
train_dataset = train_dataset.map(speech_file_to_array_fn)
test_dataset = test_dataset.map(speech_file_to_array_fn)

Map:   0%|          | 0/25890 [00:00<?, ? examples/s]

Map:   0%|          | 0/6312 [00:00<?, ? examples/s]

In [22]:
train_logits = train_dataset.map(predict_logits, batched=True, batch_size=8)
test_logits = test_dataset.map(predict_logits, batched=True, batch_size=8)

Map:   0%|          | 0/25890 [00:00<?, ? examples/s]

Map:   0%|          | 0/6312 [00:00<?, ? examples/s]

In [32]:
wav2vec_train_logits = np.array(train_logits["predicted"])
wav2vec_train_pred = np.exp(wav2vec_train_logits)/np.sum(np.exp(wav2vec_train_logits), axis=1, keepdims=True)

wav2vec_test_logits = np.array(test_logits["predicted"])
wav2vec_test_pred = np.exp(wav2vec_test_logits)/np.sum(np.exp(wav2vec_test_logits), axis=1, keepdims=True)

In [33]:
np.savez("../../KEMDy20_v1_1/wav2vec_R_train_pred.npz", predict_prob=wav2vec_train_pred)
np.savez("../../KEMDy20_v1_1/wav2vec_R_test_pred.npz", predict_prob=wav2vec_test_pred)

In [34]:
import numpy as np
kobert_test_pred = np.load("../../KEMDy20_v1_1/kobert_test_pred.npz")["predict_prob"]

In [35]:
wav2vec_bert_test_pred = (wav2vec_test_pred*0.5 + kobert_test_pred*0.5)

In [36]:
label_names = [config.id2label[i] for i in range(config.num_labels)]
label_names

['angry', 'disqust', 'fear', 'happy', 'neutral', 'sad', 'surprise']

In [37]:
y_true = [config.label2id[name] for name in test_logits["Emotion"]]
y_wav2vec_pred = np.argmax(wav2vec_test_pred, axis=1)
y_bert_pred = np.argmax(kobert_test_pred, axis=1)
y_wav2vec_bert_pred = np.argmax(wav2vec_bert_test_pred, axis=1)

In [38]:
print(classification_report(y_true, y_wav2vec_pred, target_names=label_names, digits=4))

              precision    recall  f1-score   support

       angry     0.7304    0.6732    0.7006       664
     disqust     0.7253    0.7638    0.7440       470
        fear     0.7143    0.7732    0.7426       291
       happy     0.7725    0.7465    0.7592      1128
     neutral     0.8719    0.9186    0.8947      2741
         sad     0.7766    0.6852    0.7280       629
    surprise     0.8067    0.7404    0.7721       389

    accuracy                         0.8096      6312
   macro avg     0.7711    0.7573    0.7630      6312
weighted avg     0.8075    0.8096    0.8077      6312



In [39]:
print(classification_report(y_true, y_bert_pred, target_names=label_names, digits=4))

              precision    recall  f1-score   support

       angry     0.9038    0.8343    0.8677       664
     disqust     0.9085    0.8872    0.8977       470
        fear     0.9498    0.9107    0.9298       291
       happy     0.8971    0.8271    0.8607      1128
     neutral     0.8734    0.9486    0.9094      2741
         sad     0.8961    0.8362    0.8651       629
    surprise     0.9468    0.8689    0.9062       389

    accuracy                         0.8924      6312
   macro avg     0.9108    0.8733    0.8909      6312
weighted avg     0.8937    0.8924    0.8918      6312



In [40]:
print(classification_report(y_true, y_wav2vec_bert_pred, target_names=label_names, digits=4))

              precision    recall  f1-score   support

       angry     0.9132    0.8238    0.8662       664
     disqust     0.8989    0.8894    0.8941       470
        fear     0.9597    0.9003    0.9291       291
       happy     0.9322    0.8165    0.8705      1128
     neutral     0.8704    0.9726    0.9187      2741
         sad     0.9045    0.8283    0.8647       629
    surprise     0.9569    0.8560    0.9037       389

    accuracy                         0.8980      6312
   macro avg     0.9194    0.8696    0.8924      6312
weighted avg     0.9009    0.8980    0.8969      6312

