In [None]:
import huggingface_hub
from kaggle_secrets import UserSecretsClient

import os

# Kaggle's secret store keeps the credentials out of the notebook source.
user_secrets = UserSecretsClient()

huggingface_hub.login(token=user_secrets.get_secret("HUGGINGFACE_TOKEN"))

# Set the variable through os.environ rather than the ``%env`` magic: ``%env NAME=value``
# echoes the assignment, which would write the API key into the saved cell output.
os.environ["WANDB_API_KEY"] = user_secrets.get_secret("WANDB_API_KEY")

#  Preprocess

In [None]:
from datasets import Audio, ClassLabel, load_dataset
from datasets import IterableDatasetDict

# Open both splits in streaming mode and keep them together in one dict-like container.
dusha = IterableDatasetDict(
    {
        split: load_dataset("KELONMYOSA/dusha_emotion_audio", split=split, streaming=True)
        for split in ("train", "test")
    }
)

# Class names are read from the training split's "label" feature before any column changes.
labels = dusha["train"].features["label"].names

# Decode audio at 16 kHz and drop the "file" column, which is not used afterwards.
dusha = dusha.cast_column("audio", Audio(sampling_rate=16_000)).remove_columns("file")

In [None]:
# Two-way lookup between class names and their (stringified) positions in ``labels``.
id2label = {str(index): name for index, name in enumerate(labels)}
label2id = {name: index_str for index_str, name in id2label.items()}
num_labels = len(id2label)

In [None]:
from transformers import Wav2Vec2Processor

# NOTE(review): this processor comes from a different checkpoint than the model that is
# fine-tuned later (facebook/wav2vec2-xls-r-300m) — confirm the feature-extractor settings match.
processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-xls-r-1b-russian")


def preprocess_function(examples):
    """Convert a batch of decoded audio examples into model input features.

    Args:
        examples: Batched mapping with an ``"audio"`` column; each entry exposes
            its waveform under the ``"array"`` key.

    Returns:
        The processor output for the batch, with one *unpadded* entry per example.
    """
    audio_arrays = [x["array"] for x in examples["audio"]]
    # Deliberately no padding and no tensor conversion here.  ``map`` feeds this
    # function 256 clips at a time, so padding would stretch every clip to the
    # longest one in that group.  The data collator then re-pads from
    # ``input_values`` alone, which would make that padding indistinguishable
    # from real audio.  Padding is left entirely to the collator.
    return processor(audio_arrays, sampling_rate=processor.feature_extractor.sampling_rate)

In [None]:
# Run feature extraction lazily over both splits, 256 examples per call, and drop the
# "audio" column once it has been consumed.
dusha = dusha.map(
    preprocess_function,
    batched=True,
    batch_size=256,
    remove_columns="audio",
)
dusha = dusha.with_format("torch")

# Only the training stream is shuffled (through a 256-example buffer, fixed seed).
dusha["train"] = dusha["train"].shuffle(seed=0, buffer_size=256)

# Model

In [None]:
!pip install -U accelerate

In [None]:
from dataclasses import dataclass
from typing import Tuple, Dict, List, Optional, Union, Any
import torch
import torch.nn as nn
from packaging import version
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.file_utils import ModelOutput
from transformers import Wav2Vec2Processor, Trainer, is_apex_available
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel, Wav2Vec2Model


@dataclass
class SpeechClassifierOutput(ModelOutput):
    """Container returned by ``Wav2Vec2ForSpeechClassification.forward``.

    Every field defaults to ``None``; the forward pass fills in whatever is
    available for the current call.
    """

    # Training loss; stays None when the forward pass receives no labels.
    loss: Optional[torch.FloatTensor] = None
    # Output of the classification head.
    logits: Optional[torch.FloatTensor] = None
    # Copied from the backbone output's ``hidden_states`` attribute.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Copied from the backbone output's ``attentions`` attribute.
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class Wav2Vec2ClassificationHead(nn.Module):
    """Classification head: dropout -> linear -> tanh -> dropout -> linear.

    Sizes are taken from ``config``: the first linear layer maps
    ``hidden_size`` to ``hidden_size``, the second maps ``hidden_size`` to
    ``num_labels``; both dropouts share the ``final_dropout`` probability.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        """Map pooled features to logits; extra keyword arguments are ignored."""
        projected = self.dense(self.dropout(features))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)


class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    """Wav2Vec2 encoder followed by pooling over dim 1 and a classification head.

    The encoder output is reduced along dim 1 to one vector per example, which
    ``Wav2Vec2ClassificationHead`` maps to ``config.num_labels`` logits.  The
    pooling mode is read from the optional ``config.pooling_mode`` attribute
    and falls back to ``"mean"`` when the config does not define it.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # Optional extra config field; older configs without it keep the "mean" default.
        self.pooling_mode = getattr(config, "pooling_mode", "mean")

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

        self.init_weights()

    def freeze_feature_extractor(self):
        """Freeze the parameters of the backbone's feature extractor."""
        self.wav2vec2.feature_extractor._freeze_parameters()

    def merged_strategy(
            self,
            hidden_states,
            mode="mean",
            padding_mask=None,
    ):
        """Pool ``hidden_states`` along dim 1.

        Args:
            hidden_states: Encoder output; assumed to be laid out as
                (batch, frames, hidden) — pooling reduces dim 1.
            mode: One of ``"mean"``, ``"sum"`` or ``"max"``.
            padding_mask: Optional boolean mask over the first two dims of
                ``hidden_states``, true for frames that should take part in
                the pooling.  ``None`` pools over every frame.

        Returns:
            ``hidden_states`` with dim 1 reduced away.

        Raises:
            ValueError: If ``mode`` is not a supported pooling mode.
        """
        if mode not in ("mean", "sum", "max"):
            raise ValueError(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")

        if padding_mask is None:
            if mode == "mean":
                return torch.mean(hidden_states, dim=1)
            if mode == "sum":
                return torch.sum(hidden_states, dim=1)
            return torch.max(hidden_states, dim=1)[0]

        # Add a trailing axis so the mask broadcasts across the hidden dimension.
        keep = padding_mask.to(torch.bool).unsqueeze(-1)
        if mode == "max":
            # Masked frames get the lowest representable value so they never win the max.
            lowest = torch.finfo(hidden_states.dtype).min
            return hidden_states.masked_fill(~keep, lowest).max(dim=1)[0]

        summed = hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
        if mode == "sum":
            return summed
        # Average over the kept frames only; clamp guards against an all-masked row.
        return summed / keep.sum(dim=1).clamp(min=1)

    def forward(
            self,
            input_values,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            labels=None,
    ):
        """Encode the inputs, pool over dim 1, classify and optionally compute a loss.

        Args:
            input_values: Model inputs, passed to the backbone unchanged.
            attention_mask: Optional sample-level mask passed to the backbone.
                When given, it is also converted to a frame-level mask so that
                padded frames are excluded from pooling.
            output_attentions: Forwarded to the backbone.
            output_hidden_states: Forwarded to the backbone.
            return_dict: Return a ``SpeechClassifierOutput`` instead of a tuple;
                defaults to ``config.use_return_dict``.
            labels: Optional targets.  When ``config.problem_type`` is unset it
                is inferred from ``num_labels`` and the dtype of ``labels``.

        Returns:
            A ``SpeechClassifierOutput``, or a tuple ``(loss?, logits, ...)``
            when ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        padding_mask = None
        if attention_mask is not None:
            # Translate the sample-level mask into one entry per encoder frame.
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode, padding_mask=padding_mask)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    # Flatten both sides: (batch, 1) logits against (batch,) labels
                    # would otherwise broadcast to a (batch, batch) comparison.
                    loss = loss_fct(logits.reshape(-1), labels.reshape(-1))
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Tuple output: logits followed by everything after the first two backbone outputs.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@dataclass
class DataCollatorCTCWithPadding:
    """Pad a list of examples into one batch and attach their labels.

    Only the ``"input_values"`` entry of each example is handed to
    ``processor.pad``; the ``"label"`` entries are gathered into a single
    ``"labels"`` tensor.

    Label dtype: scalar integer labels (Python ints, integer NumPy scalars or
    0-dim integer tensors, as yielded by a torch-formatted dataset) become
    ``torch.long``; everything else becomes ``torch.float``.
    """

    # Object exposing ``pad(...)``.  The annotation is a string so it is not
    # evaluated when the class is defined.
    processor: "Wav2Vec2Processor"
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into a padded batch with a ``"labels"`` tensor.

        Args:
            features: Examples, each providing ``"input_values"`` and ``"label"``.

        Returns:
            The padded batch produced by ``processor.pad`` plus ``"labels"``.
        """
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [feature["label"] for feature in features]

        # Inspect the first label as a tensor so that ints, NumPy scalars and
        # 0-dim tensors are all recognised; a plain isinstance(..., int) check
        # would send tensor-valued class ids down the float path.
        probe = torch.as_tensor(label_features[0])
        is_class_index = probe.dim() == 0 and not probe.is_floating_point()
        d_type = torch.long if is_class_index else torch.float

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if isinstance(label_features[0], torch.Tensor):
            # torch.tensor() cannot build from a list of multi-element tensors; stack handles any rank.
            batch["labels"] = torch.stack(list(label_features)).to(d_type)
        else:
            batch["labels"] = torch.tensor(label_features, dtype=d_type)

        return batch


# apex is optional; ``amp`` is only imported when the package is installed.
if is_apex_available():
    from apex import amp

# Always bind the flag so that reading it on an older torch cannot raise NameError.
_is_native_amp_available = version.parse(torch.__version__) >= version.parse("1.6")


class CTCTrainer(Trainer):
    """Trainer with a hand-written ``training_step``."""

    def training_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            num_items_in_batch=None,
    ) -> torch.Tensor:
        """Run one forward/backward pass and return the detached loss.

        Args:
            model: The model to train.
            inputs: One batch, as produced by the data collator.
            num_items_in_batch: Accepted so that Trainer versions which pass
                this third argument can still call the override; unused here.

        Returns:
            The detached loss, already divided by the gradient-accumulation steps.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            # Multi-GPU data parallelism yields one loss per device; backward()
            # needs a scalar, so average them.
            loss = loss.mean()

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()

# Metrics

In [None]:
import numpy as np
from transformers import EvalPrediction

# Switches compute_metrics between MSE (True) and accuracy (False).
is_regression = False


def compute_metrics(p: EvalPrediction):
    """Compute the evaluation metric from an ``EvalPrediction``.

    When ``p.predictions`` is a tuple, only its first element is used.

    Returns:
        ``{"mse": ...}`` when ``is_regression`` is true, otherwise
        ``{"accuracy": ...}`` based on the argmax over axis 1.
    """
    raw = p.predictions
    if isinstance(raw, tuple):
        raw = raw[0]

    if is_regression:
        squeezed = np.squeeze(raw)
        return {"mse": ((squeezed - p.label_ids) ** 2).mean().item()}

    predicted = np.argmax(raw, axis=1)
    return {"accuracy": (predicted == p.label_ids).astype(np.float32).mean().item()}

# Train

In [None]:
from transformers import Seq2SeqTrainingArguments, AutoConfig

# Checkpoint used for BOTH the configuration and the pretrained weights, so the
# architecture settings always belong to the weights that get loaded.
base_checkpoint = "facebook/wav2vec2-xls-r-300m"

config = AutoConfig.from_pretrained(
    base_checkpoint,
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
    finetuning_task="wav2vec2_emotion_ru",
    pooling_mode="mean"
)
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

model = Wav2Vec2ForSpeechClassification.from_pretrained(base_checkpoint, config=config)
# Keep the backbone's feature extractor fixed during fine-tuning.
model.freeze_feature_extractor()

# NOTE(review): Seq2SeqTrainingArguments is kept as in the original; none of its
# seq2seq-specific options are set here — confirm whether plain TrainingArguments was intended.
training_args = Seq2SeqTrainingArguments(
    output_dir="KELONMYOSA/wav2vec2-xls-r-300m-emotion-ru",
    evaluation_strategy="steps",
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=8,
    max_steps=12000,
    warmup_steps=500,
    # Saving and evaluation share the same interval.
    save_steps=4000,
    eval_steps=4000,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)

In [None]:
import torch
import time
import gc
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo


def clear_gpu_memory():
    """Collect garbage, then release PyTorch's cached CUDA memory.

    Garbage collection runs first so that memory freed by it is already back
    in the caching allocator when the cache is emptied; in the opposite order
    those blocks would stay cached.
    """
    gc.collect()
    torch.cuda.empty_cache()


def wait_until_enough_gpu_memory(min_memory_available, max_retries=10, sleep_time=5):
    """Block until the current CUDA device reports enough free memory.

    Free memory is read through NVML for the device index returned by
    ``torch.cuda.current_device()``.

    Args:
        min_memory_available: Required amount of free memory, in bytes.
        max_retries: Maximum number of checks before giving up.
        sleep_time: Seconds to wait between two consecutive checks.

    Raises:
        RuntimeError: If none of the ``max_retries`` checks found enough free memory.
    """
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(torch.cuda.current_device())

    for attempt in range(max_retries):
        info = nvmlDeviceGetMemoryInfo(handle)
        if info.free >= min_memory_available:
            return
        # Announce a retry and sleep only when another check will actually follow.
        if attempt < max_retries - 1:
            print(f"Waiting for {min_memory_available} bytes of free GPU memory. Retrying in {sleep_time} seconds...")
            time.sleep(sleep_time)

    raise RuntimeError(f"Failed to acquire {min_memory_available} bytes of free GPU memory after {max_retries} retries.")

In [None]:
from transformers import TrainerCallback
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import IterableDataset


class ShuffleCallback(TrainerCallback):
    """At the start of each epoch, free GPU memory and advance the dataset's epoch."""

    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        """Wait for 2 GiB of free GPU memory, then bump the epoch of a bare iterable dataset."""
        required_free_bytes = 2 * 1024 ** 3
        clear_gpu_memory()
        wait_until_enough_gpu_memory(required_free_bytes)

        dataset = train_dataloader.dataset
        if isinstance(dataset, IterableDatasetShard):
            return  # set_epoch() is handled by the Trainer
        if isinstance(dataset, IterableDataset):
            dataset.set_epoch(dataset._epoch + 1)

In [None]:
# Wire the model, both data streams, the collator, the metric function and the
# epoch callback into the custom trainer.
trainer = CTCTrainer(
    model=model,
    args=training_args,
    train_dataset=dusha["train"],
    eval_dataset=dusha["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[ShuffleCallback()],
)

In [None]:
# Start fine-tuning with the trainer built in the previous cell.
trainer.train()

In [None]:
# Metadata passed as keyword arguments to ``trainer.push_to_hub``.
kwargs = dict(
    dataset_tags="KELONMYOSA/dusha_emotion_audio",
    dataset="Dusha",
    language="ru",
    model_name="Speech emotion recognition",
    finetuned_from="facebook/wav2vec2-xls-r-300m",
    tasks="emotion-speech-recognition",
    tags="emotion",
)

trainer.push_to_hub(**kwargs)