In [None]:
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification, 
    DataCollatorWithPadding, 
    TrainingArguments,
    Trainer,
    AutoModelForAudioClassification,
    Wav2Vec2Processor,
    utils,
)

import pandas as pd
from datasets import load_dataset, load_metric
import numpy as np
from datasets import DatasetDict
import torch
from sklearn.svm import SVC
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns
import pickle

# Only show errors from the transformers library (hides its info/warning logs).
utils.logging.set_verbosity_error()
# Ask PyTorch to release cached, currently unused GPU memory before models are loaded.
torch.cuda.empty_cache()

# Global matplotlib font size for the figures produced below.
plt.rcParams.update({'font.size': 18})


# Hugging Face model ids of the two base checkpoints.
# TER / SER presumably stand for text / speech emotion recognition — TODO confirm.
TER_MODEL_ID = "xlm-roberta-base"
SER_MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-russian"
# Per-device train/eval batch size; the trainers also derive their gradient
# accumulation steps from it (64 // batch_size).
batch_size = 4


class SpeechTextEmotionClassifier:
    """Late-fusion emotion classifier built from a text model and a speech model.

    A sequence-classification model (``TER_MODEL_ID``) and an
    audio-classification model (``SER_MODEL_ID``), both with 4 output labels,
    are fine-tuned separately with Hugging Face ``Trainer``. Their prediction
    outputs are concatenated and fed to an ``SVC`` (``self.combined_model``)
    that produces the final class probabilities.
    """

    # Default data locations, copied per instance so they are never shared/mutated.
    _DEFAULT_DF_PATHS = {
        'train': 'data/toloka_marked_train.csv',
        'test': 'data/toloka_marked_test.csv',
    }
    _DEFAULT_SPEECH_PATHS = {
        'train': 'data/speech_train.npz',
        'test': 'data/speech_test.npz',
    }
    # Raw dataset columns dropped once the model inputs have been derived.
    _RAW_COLUMNS = ['path', 'hypo', 'result', 'label', 'speech']
    # Confusion-matrix tick labels (Russian: anger, joy, neutral, sadness)
    # and the cell-centre positions they are placed at.
    _EMOTION_NAMES = ['злость', 'радость', 'нейтральная', 'грусть']
    _TICK_POSITIONS = [0.5, 1.5, 2.5, 3.5]

    def __init__(
        self,
        df_paths=None,
        speech_paths=None,
    ):
        """Load models, tokenizer/processor and the train/validation/test data.

        Args:
            df_paths: mapping of split name (``'train'``, ``'test'``) to a CSV
                file. Defaults to the ``data/toloka_marked_*.csv`` files.
                The CSVs are expected to provide the columns ``path``,
                ``hypo``, ``result`` and ``label`` (these are the columns the
                ``construct_*_df`` methods remove).
            speech_paths: mapping of split name to an ``.npz`` archive whose
                arrays become the ``speech`` column of that split. Defaults to
                the ``data/speech_*.npz`` files.
        """
        if df_paths is None:
            df_paths = dict(self._DEFAULT_DF_PATHS)
        if speech_paths is None:
            speech_paths = dict(self._DEFAULT_SPEECH_PATHS)

        self.metric = load_metric("accuracy")

        self.ter_model = AutoModelForSequenceClassification.from_pretrained(TER_MODEL_ID, num_labels=4)

        self.ser_model = AutoModelForAudioClassification.from_pretrained(SER_MODEL_ID, num_labels=4)
        # Keep the convolutional feature encoder fixed during fine-tuning.
        self.ser_model.freeze_feature_encoder()

        dataset = load_dataset(
            'csv',
            data_files=df_paths,
        )
        for split, path in speech_paths.items():
            dataset[split] = self._add_speech(path, dataset[split])

        # The provided 'test' file is split in half (fixed seed) into the
        # validation and the final test set.
        test_split = dataset['test'].train_test_split(shuffle=True, seed=200, test_size=0.5)
        self.dataset = DatasetDict({
            'train': dataset['train'],
            'validation': test_split['train'],
            'test': test_split['test'],
        })

        self.ser_processor = Wav2Vec2Processor.from_pretrained(SER_MODEL_ID)

        self.ter_tokenizer = AutoTokenizer.from_pretrained(TER_MODEL_ID)

        self.combined_model = SVC(
            C=1.0,
            kernel='poly',
            degree=5,
            gamma='scale',
            coef0=0.0,
            probability=True,
            tol=0.001,
            max_iter=-1,
            decision_function_shape='ovr',
            class_weight=None,
            random_state=12,
        )

        self.ter_output_dir = "ter_outputs"
        self.ser_output_dir = "ser_outputs"

    @staticmethod
    def _add_speech(path, dataset):
        """Return ``dataset`` with a ``speech`` column read from the ``.npz`` at ``path``.

        Arrays are taken in the archive's own order, one per dataset row
        (assumes that order matches the CSV row order — TODO confirm).
        """
        with np.load(path) as archive:
            speech = [archive[name] for name in archive.files]

        return dataset.add_column('speech', speech)

    def _build_training_args(self, output_dir, lr, n_epochs, tp16, disable_tqdm):
        """Create the ``TrainingArguments`` shared by the text and speech trainers."""
        return TrainingArguments(
            output_dir=output_dir,
            group_by_length=True,
            evaluation_strategy="epoch",
            logging_strategy="epoch",
            learning_rate=lr,
            per_device_train_batch_size=batch_size,
            # Accumulate gradients up to 64 samples per device per optimizer step.
            gradient_accumulation_steps=64 // batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=n_epochs,
            warmup_ratio=0.1,
            metric_for_best_model="accuracy",
            fp16=tp16,
            save_strategy="epoch",
            load_best_model_at_end=True,
            save_total_limit=2,
            disable_tqdm=disable_tqdm,
        )

    def compute_metrics(self, eval_pred):
        """Compute accuracy from an evaluation prediction object.

        The predicted class is the argmax over axis 1 of
        ``eval_pred.predictions``; references are ``eval_pred.label_ids``.
        """
        predictions = np.argmax(eval_pred.predictions, axis=1)
        return self.metric.compute(predictions=predictions, references=eval_pred.label_ids)

    def preprocess_audio_df(self, batch):
        """Expose the speech arrays and labels under the names the trainer uses.

        NOTE: the arrays are passed through unchanged; ``self.ser_processor``
        is not applied here (presumably the ``.npz`` files already contain
        model-ready input values — TODO confirm).
        """
        batch["input_values"] = batch["speech"]
        batch["labels"] = batch["label"]
        return batch

    def construct_speech_df(self):
        """Return the dataset mapped to ``input_values`` / ``labels`` columns."""
        return self.dataset.map(
            self.preprocess_audio_df,
            remove_columns=list(self._RAW_COLUMNS),
            batched=True,
        )

    def preprocess_text_df(self, batch):
        """Tokenize the ``hypo`` texts into ``input_ids`` and copy the labels."""
        batch["input_ids"] = self.ter_tokenizer(batch["hypo"]).input_ids
        batch["labels"] = batch["label"]
        return batch

    def construct_text_df(self):
        """Return the dataset mapped to ``input_ids`` / ``labels`` columns."""
        return self.dataset.map(
            self.preprocess_text_df,
            remove_columns=list(self._RAW_COLUMNS),
            batched=True,
        )

    def create_ter_trainer(
        self,
        df,
        lr=3e-4,
        n_epochs=5,
        tp16=True,
        disable_tqdm=True,
    ):
        """Build a ``Trainer`` for the text model.

        Args:
            df: dataset dict with ``'train'`` and ``'validation'`` splits
                (as returned by ``construct_text_df``).
            lr: learning rate.
            n_epochs: number of training epochs.
            tp16: forwarded to ``TrainingArguments(fp16=...)``.
            disable_tqdm: forwarded to ``TrainingArguments``.

        Returns:
            A ``Trainer`` writing to ``self.ter_output_dir``.
        """
        training_args = self._build_training_args(
            self.ter_output_dir, lr, n_epochs, tp16, disable_tqdm,
        )

        return Trainer(
            model=self.ter_model,
            args=training_args,
            compute_metrics=self.compute_metrics,
            train_dataset=df['train'],
            eval_dataset=df['validation'],
            tokenizer=self.ter_tokenizer,
        )

    def train_model(
        self,
        trainer,
        save_model=False,
    ):
        """Run ``trainer.train()`` and optionally save the resulting model.

        The model is saved to the trainer's own output directory, so a speech
        trainer no longer overwrites the text model's directory.
        """
        trainer.train()
        if save_model:
            trainer.save_model(trainer.args.output_dir)

    def create_ser_trainer(
        self,
        df,
        lr=3e-4,
        n_epochs=5,
        tp16=True,
        disable_tqdm=True,
    ):
        """Build a ``Trainer`` for the speech model.

        Takes the same arguments as ``create_ter_trainer``; ``df`` is the
        dataset dict returned by ``construct_speech_df``.

        Returns:
            A ``Trainer`` writing to ``self.ser_output_dir``.
        """
        training_args = self._build_training_args(
            self.ser_output_dir, lr, n_epochs, tp16, disable_tqdm,
        )

        return Trainer(
            model=self.ser_model,
            args=training_args,
            compute_metrics=self.compute_metrics,
            train_dataset=df['train'],
            eval_dataset=df['validation'],
            tokenizer=self.ser_processor.feature_extractor,
        )

    def load_pretrained_model(
        self,
        model_type,
        df,
        model_dir=None,
        lr=3e-4,
        n_epochs=5,
        tp16=True,
        disable_tqdm=True,
    ):
        """Replace one sub-model with saved weights and wrap it in a new trainer.

        Args:
            model_type: ``'ter'`` for the text model or ``'ser'`` for the
                speech model.
            df: dataset dict passed on to the matching ``create_*_trainer``.
            model_dir: directory to load from; when empty/``None`` the
                model's default output directory is used.
            lr, n_epochs, tp16, disable_tqdm: forwarded to the trainer factory.

        Returns:
            The newly created ``Trainer``.

        Raises:
            ValueError: if ``model_type`` is neither ``'ter'`` nor ``'ser'``.
        """
        if model_type == 'ter':
            self.ter_model = AutoModelForSequenceClassification.from_pretrained(
                model_dir or self.ter_output_dir
            )
            return self.create_ter_trainer(df, lr, n_epochs, tp16, disable_tqdm)
        if model_type == 'ser':
            self.ser_model = AutoModelForAudioClassification.from_pretrained(
                model_dir or self.ser_output_dir
            )
            return self.create_ser_trainer(df, lr, n_epochs, tp16, disable_tqdm)
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'ter' or 'ser'"
        )

    def predict_text_model(self, text_trainer, df):
        """Return the text trainer's raw predictions for ``df``."""
        return text_trainer.predict(df).predictions

    def predict_speech_model(self, speech_trainer, df):
        """Return the speech trainer's raw predictions for ``df``."""
        return speech_trainer.predict(df).predictions

    def predict_and_concat(self, text_trainer, speech_trainer, text_df, speech_df):
        """Concatenate text and speech predictions along axis 1 (text first)."""
        text_logits = self.predict_text_model(text_trainer, text_df)
        speech_logits = self.predict_speech_model(speech_trainer, speech_df)

        return np.concatenate((text_logits, speech_logits), axis=1)

    def train_combined(self, text_trainer, speech_trainer, text_df, speech_df):
        """Fit the SVC on the concatenated predictions of both sub-models.

        Labels are read from ``text_df["labels"]`` so they always belong to
        the split the features were computed on (the same source
        ``eval_metrics`` uses for its references).
        """
        features = self.predict_and_concat(text_trainer, speech_trainer, text_df, speech_df)

        self.combined_model.fit(features, text_df["labels"])

    def predict_combined(self, text_trainer, speech_trainer, text_df, speech_df):
        """Return the SVC's class probabilities for the given split."""
        features = self.predict_and_concat(text_trainer, speech_trainer, text_df, speech_df)
        return self.combined_model.predict_proba(features)

    def eval_metrics(self, text_trainer, speech_trainer, text_df, speech_df):
        """Print the combined model's accuracy and plot its confusion matrix.

        References are taken from ``text_df["labels"]``. Plot texts are in
        Russian: title "confusion matrix", axes "predicted values" /
        "true values".
        """
        probabilities = self.predict_combined(text_trainer, speech_trainer, text_df, speech_df)
        predictions = np.argmax(probabilities, axis=1)
        references = text_df["labels"]
        self.metric.add_batch(predictions=predictions, references=references)

        print(self.metric.compute())

        cm = confusion_matrix(references, predictions)
        plt.figure(figsize=(8, 8))
        plt.title('Матрица ошибок')

        sns.set(font_scale=2)
        sns.heatmap(cm, annot=True, fmt="d")
        plt.xlabel('Предсказанные значения')
        plt.xticks(self._TICK_POSITIONS, self._EMOTION_NAMES, rotation=45)
        plt.yticks(self._TICK_POSITIONS, self._EMOTION_NAMES, rotation=45)
        plt.ylabel('Верные значения')
        plt.show()

    def save_combined_model(self, combined_model_output_file = 'combined_model'):
        """Pickle ``self.combined_model`` to ``combined_model_output_file``."""
        with open(combined_model_output_file, 'wb') as fh:
            pickle.dump(self.combined_model, fh)

    def load_combined_model(self, combined_model_output_file = 'combined_model'):
        """Replace ``self.combined_model`` with the object pickled in the given file.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files
        that were produced by ``save_combined_model`` from a trusted source.
        """
        with open(combined_model_output_file, 'rb') as fh:
            self.combined_model = pickle.load(fh)
        
# Build the classifier with its default data paths; the constructor loads both
# base models and the datasets.
clf = SpeechTextEmotionClassifier()

# Derive the model-specific views of the dataset: tokenized text and speech inputs.
text_df = clf.construct_text_df()
speech_df = clf.construct_speech_df()

In [None]:
# Fine-tune the text model for 5 epochs and save the final model.
text_trainer = clf.create_ter_trainer(text_df, tp16=True, n_epochs=5, disable_tqdm=True)
clf.train_model(text_trainer, save_model=True)
# Fine-tune the speech model for 8 epochs; no explicit final save is requested here
# (the trainer still writes per-epoch checkpoints to its output directory).
speech_trainer = clf.create_ser_trainer(speech_df, tp16=True, n_epochs=8, disable_tqdm=False)
clf.train_model(speech_trainer, save_model=False)

# NOTE(review): nothing in this cell fits clf.combined_model, so this pickles the
# SVC in whatever state it currently is — unfitted if the cells are run top to
# bottom. Confirm whether this call belongs after clf.train_combined(...).
clf.save_combined_model()

In [None]:
# Rebuild both trainers around previously saved checkpoints instead of retraining.
# The checkpoint directories (step numbers 193 / 1351) come from an earlier run;
# adjust them to whatever checkpoints exist locally.
text_trainer = clf.load_pretrained_model('ter', text_df, model_dir='ter_outputs/checkpoint-193', n_epochs=5, tp16=True, disable_tqdm=True)
speech_trainer = clf.load_pretrained_model('ser', speech_df, model_dir='ser_outputs/checkpoint-1351', n_epochs=5, tp16=True, disable_tqdm=True)
# Restore the pickled SVC from the default 'combined_model' file.
clf.load_combined_model()

In [None]:
# Fit the SVC on the concatenated text + speech predictions for the training split.
clf.train_combined(text_trainer, speech_trainer, text_df['train'], speech_df['train'])

In [None]:
import scikitplot as skplt
from sklearn import metrics


# Disable pandas' chained-assignment (SettingWithCopy) warning for the whole session.
pd.options.mode.chained_assignment = None

def show_learning_plots(log_history):
    """Plot loss and evaluation accuracy per epoch from a trainer log history.

    Args:
        log_history: list of log records (dicts) containing at least the keys
            ``loss``, ``eval_loss``, ``eval_accuracy`` and ``epoch`` across
            the records, e.g. ``trainer.state.log_history``.

    Draws two side-by-side line plots: train/eval loss (left) and eval
    accuracy (right). Returns nothing.

    Raises:
        KeyError: if one of the four expected columns is absent from the log.
    """
    # Work on an explicit copy so the column assignments below are well-defined
    # (the original chained `.loc` assignment is a silent no-op under pandas
    # copy-on-write).
    history = pd.DataFrame(log_history)[['loss', 'eval_loss', 'eval_accuracy', 'epoch']].copy()

    # Records without an eval loss are training log records.
    is_train = history['eval_loss'].isna()
    # Merge both losses into one column: eval records take their eval_loss.
    history['loss'] = history['loss'].fillna(history['eval_loss'])
    history['split'] = is_train.map({True: 'train', False: 'eval'})
    history = history.drop(columns=['eval_loss'])

    sns.set_theme()
    plt.rcParams.update({'font.size': 18, "figure.figsize": (20, 8)})
    sns.set(font_scale=2)

    fig, (ax1, ax2) = plt.subplots(1, 2)

    sns.lineplot(data=history, x='epoch', y='loss', hue='split', legend='full', ax=ax1)
    sns.lineplot(data=history, x='epoch', y='eval_accuracy', legend='full', ax=ax2)

    
def eval_metrics(trainer, ds):
    """Evaluate a single trainer on ``ds``: accuracy, confusion matrix, ROC, F1.

    Args:
        trainer: object whose ``predict(ds)`` result exposes a ``predictions``
            array with one row per sample and one column per class.
        ds: dataset with a ``labels`` column holding the reference class ids.

    Prints the accuracy (via the module-level ``clf.metric``) and the F1 score
    for several averaging modes, and shows a confusion-matrix heatmap and ROC
    curves. Plot texts are in Russian: title "confusion matrix", axes
    "predicted values" / "true values".
    """
    output = trainer.predict(ds)
    scores = np.asarray(output.predictions, dtype=np.float64)
    y_true = ds["labels"]

    predictions = np.argmax(scores, axis=1)
    clf.metric.add_batch(predictions=predictions, references=y_true)

    print(clf.metric.compute())

    cm = confusion_matrix(y_true, predictions)

    plt.figure(figsize=(8, 8))
    plt.title('Матрица ошибок')
    # Set the font scale before drawing so it applies to this heatmap as well.
    sns.set(font_scale=2)
    sns.heatmap(cm, annot=True, fmt="d")

    plt.xlabel('Предсказанные значения')
    plt.xticks([0.5, 1.5, 2.5, 3.5], ['злость', 'радость', 'нейтральная', 'грусть'], rotation=45)
    plt.yticks([0.5, 1.5, 2.5, 3.5], ['злость', 'радость', 'нейтральная', 'грусть'], rotation=45)
    plt.ylabel('Верные значения')
    plt.show()

    plt.rcParams.update({'font.size': 18, "figure.figsize": (14, 14)})

    # plot_roc_curve expects per-class probabilities, not raw model outputs:
    # apply a numerically stable softmax over the class axis first.
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    y_probas = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    skplt.metrics.plot_roc_curve(y_true, y_probas)

    print('f1 score:')
    for average in ['micro', 'macro', 'weighted', None]:
        print(average, metrics.f1_score(y_true, predictions, average=average))
    