# Open Set Emotion Recognition

## Library Imports

In [39]:
import warnings
warnings.filterwarnings("ignore")
import sys
import os
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%matplotlib inline
from collections import defaultdict
import torch.nn.functional as F
import torch.nn as nn
import os
import pandas as pd
import torch
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from scipy.io import wavfile
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
tqdm.pandas()
import librosa
import re
from collections import Counter

## Dataset Creation

### MELD

In [40]:
class MELDDataset(Dataset):
    def __init__(self, meld_dir, split, transform=None):
        self.meld_dir = meld_dir
        self.transform = transform
        self.split = split
        self.audio_files = self.load_audio_files()
        self.dialogues = self.load_dialogues()

    def load_audio_files(self):
        audio_dir = os.path.join(self.meld_dir, f'{self.split}_audio')
        audio_files = os.listdir(audio_dir)
        return audio_files

    def load_dialogues(self):
        dialogue_file = os.path.join(self.meld_dir, f'{self.split}_sent_emo.csv')
        dialogues = pd.read_csv(dialogue_file)
        return dialogues

    def __len__(self):
        return len(self.dialogues)

    def __getitem__(self, idx):
        row = self.dialogues.iloc[idx]
        text = row['Utterance']
        audio_dir = os.path.join(self.meld_dir, f'{self.split}_audio')
        audio_data = librosa.load(os.path.join(audio_dir, f'dia{row["Dialogue_ID"]}_utt{row["Utterance_ID"]}.wav'))
        label = row['Emotion']
        if self.transform:
            audio_data[0] = self.transform(audio_data[0])
        return text, audio_data, label

train_meld = MELDDataset("../MELD_Dataset", "train")
test_meld = MELDDataset("../MELD_Dataset", "test")
dev_meld = MELDDataset("../MELD_Dataset", "dev")

# concat all 3 datasets into 1 dataset
meld_dataset = train_meld + test_meld + dev_meld

In [34]:
meld_dataset[0]

('also I was the point person on my company\x92s transition from the KL-5 to GR-6 system.',
 (array([-0.00198962, -0.02142129, -0.02587057, ..., -0.06124197,
         -0.06868309, -0.04373461], dtype=float32),
  22050),
 'neutral')

#### IEMOCAP

In [36]:
class IemocapDataset(Dataset):
    def __init__(self, iemocap_dataset_full_path, transform=None):
        self.IEMOCAP_MAIN_FOLDER = os.path.join(iemocap_dataset_full_path, "IEMOCAP_full_release")
        self.TRANSCRIPTION_FOLDER = os.path.join("dialog", "transcriptions")
        self.AUDIO_FOLDER = os.path.join("sentences", "wav")
        self.CATEGORICAL_LABELS_PATH = os.path.join("dialog", "EmoEvaluation", "Categorical")
        self.transform = transform

        self.errors = defaultdict(int)
        self.dataset = self.create_dataset()
        self.print_summary()

    def get_evaluator_filenames_with_video_file_prefix(self, input_list, prefix_value):
        regex_pattern = re.compile(f'^{re.escape(prefix_value)}.*\.txt$')
        matching_strings = [s for s in input_list if regex_pattern.match(s)]
        return matching_strings

    def get_utterance_to_evaluationCounter_mapping_from_evaluation_files(self, evaluation_files):
        utterance_to_all_evaluations = {}

        for evaluation_file in evaluation_files:
            utterance_to_evaluationList = {}
            with open(evaluation_file,'r') as f:
                contents = f.read()
                utterance_evaluations = contents.split("\n")
                for evaluation in utterance_evaluations:
                    evaluation = evaluation.strip()
                    if len(evaluation) == 0:
                        continue
                    matches = re.findall(r':[^;]+;', evaluation)
                    matches = [match[1:-1] for match in matches]
                    utterance_to_evaluationList[evaluation.split()[0]] = matches

            # Combine lists from dict1
            for key, value_list in utterance_to_evaluationList.items():
                utterance_to_all_evaluations[key] = utterance_to_all_evaluations.get(key, []) + value_list

        utterance_to_evaluationsCounter = {k:Counter(v).most_common(1)[0][0] for k,v in utterance_to_all_evaluations.items()}
        return utterance_to_evaluationsCounter

    def create_dataset(self):
        dataset = []
        for session_num in range(1,6):
            for transcription_filename in os.listdir(os.path.join(self.IEMOCAP_MAIN_FOLDER,f"Session{session_num}", self.TRANSCRIPTION_FOLDER)):
                if transcription_filename[0] != ".":
                    filename_without_extension = transcription_filename.split(".")[0]

                    categorical_labels_folder_full_path = os.path.join(self.IEMOCAP_MAIN_FOLDER, f"Session{session_num}", self.CATEGORICAL_LABELS_PATH)
                    evaluation_filenames = self.get_evaluator_filenames_with_video_file_prefix(os.listdir(categorical_labels_folder_full_path), filename_without_extension)
                    evaluation_files_full_paths_for_this_file = [os.path.join(self.IEMOCAP_MAIN_FOLDER, f"Session{session_num}", self.CATEGORICAL_LABELS_PATH, f) for f in evaluation_filenames]
                    evaluations_per_utterance = self.get_utterance_to_evaluationCounter_mapping_from_evaluation_files(evaluation_files_full_paths_for_this_file)

                    transcription_file_full_path = os.path.join(self.IEMOCAP_MAIN_FOLDER, f"Session{session_num}", self.TRANSCRIPTION_FOLDER, transcription_filename)
                    with open(transcription_file_full_path,'r') as f:
                        contents = f.read()
                        lines = contents.split("\n")

                        # Iterate through utterances where every utterance looks like:
                        # Ses01F_impro01_F000 [006.2901-008.2357]: Excuse me.
                        for line in lines:

                            # Remove extra spaces and check if the line is not an empty link (usually at EOF)
                            line = line.strip()
                            if(len(line)==0):
                                break

                            # Remove idx of first space, ], -
                            try:
                                space_idx = line.index(" ")
                                timestampEndBracket_idx = line.index("]")
                                timestampHyphen_idx = line.index("-")
                            except:
                                self.errors["Problematic Transcription Line"]+=1
                                continue
                            else:
                                audio_filename = line[:space_idx]        # output audio file name = utterance name
                                text = line[timestampEndBracket_idx+3:]         # the transcription of the utterance
                                evaluation = evaluations_per_utterance.get(audio_filename,"KEY_ERROR")
                                if(evaluation=="KEY_ERROR"):
                                    self.errors["Unavailable Label for an utterance"]+=1

                                utterance_audios_per_video_folder = audio_filename[:line.rindex('_')]       # Only need Ses01F_impro01 from Ses01F_impro01_F000
                                audio_file_full_path = os.path.join(self.IEMOCAP_MAIN_FOLDER, f"Session{session_num}", self.AUDIO_FOLDER, utterance_audios_per_video_folder, audio_filename+".wav")         # name of the video file

                                if evaluation!="KEY_ERROR" and os.path.isfile(audio_file_full_path)==True:
                                    dataset.append((text,audio_file_full_path,evaluation))
        return dataset

    def print_summary(self):
        print("SUMMARY:\n")
        for k,v in self.errors.items():
            print(f"{k}: {v}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text, audio, label = self.dataset[idx]

        if self.transform:
            audio[0] = self.transform(audio[0])
        return text, librosa.load(audio), label

iemocap_dataset = IemocapDataset("../IEMOCAP_Dataset")

SUMMARY:

Problematic Transcription Line: 152
Unavailable Label for an utterance: 48


In [37]:
iemocap_dataset[0]

('Excuse me.',
 (array([-0.00476289, -0.0055054 , -0.00418305, ..., -0.00345229,
         -0.0044057 , -0.00205744], dtype=float32), 22050),
 'Neutral state')

## Preprocessing

In [47]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states = True,
                                  )
model.eval()

def encode_sentence(sentence):
    encoded_dict = tokenizer.encode_plus(
                        sentence,
                        add_special_tokens = True,
                        max_length = 64,
                        pad_to_max_length = True,
                        return_attention_mask = True,  # Construct attention masks.
                        return_tensors = 'pt',
                   )

    with torch.no_grad():
        outputs = model(encoded_dict['input_ids'], encoded_dict['attention_mask'])
        hidden_states = outputs[2]

    token_vecs_cat = torch.stack(hidden_states[-4:], dim=0)
    token_vecs_cat = torch.mean(token_vecs_cat, 0)
    sentence_embedding = torch.mean(token_vecs_cat, 1)

    return sentence_embedding[0].numpy()

def preprocess_text(text):
    # apostrophe ' is not rendered properly so replacing special character with apostrophe
    text = text.replace("\x92", "'")
    return encode_sentence(text)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


array([-1.97807834e-01, -3.21971893e-01,  1.74944237e-01,  5.88698089e-02,
       -8.24544728e-02, -9.44081992e-02, -2.41122276e-01,  4.06191707e-01,
       -4.01808284e-02, -4.22902219e-02, -2.42862441e-02,  4.39610481e-02,
        4.06122029e-01, -6.70870207e-03,  2.19272390e-01, -1.17840350e-01,
        5.50280735e-02, -9.10694078e-02,  4.62006390e-01, -6.63721561e-03,
       -2.88291685e-02,  2.31405236e-02,  1.55743495e-01,  5.23088813e-01,
       -7.16724917e-02,  1.73188880e-01, -1.97734818e-01,  2.80617148e-01,
       -2.52348661e-01,  2.27498002e-02,  2.16145024e-01, -1.03119127e-02,
       -1.54864728e-01,  1.01391830e-01, -1.45695031e-01, -4.98875342e-02,
       -1.62815645e-01,  7.69490898e-02, -5.53310633e-01,  7.87323341e-04,
       -6.94538474e-01, -3.27924848e-01,  3.03773314e-01, -4.50806618e-02,
       -3.67972553e-01,  1.69997901e-01,  9.03725624e-02, -6.69086426e-02,
        4.09864068e-01, -2.58970767e-01, -2.32696936e-01, -4.69171628e-02,
       -1.00126803e-01, -

In [None]:
def compute_and_plot_mfccs(filepath, save_plot=True):
    # Load the audio file
    y, sr = librosa.load(filepath)
    # Compute the MFCCs
    mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)  # You can adjust n_mfcc as needed

    plt.figure(figsize=(10, 4))
    librosa.display.specshow(mfccs, sr=sr, x_axis='time')
    plt.colorbar()
    plt.title(f'MFCC of {filepath}')
    plt.tight_layout()

    # Save or show the plot
    if save_plot:
        plt.savefig(f'{filepath}_mfcc.png')
        plt.show()

        plt.close()  # Close the plot to free memory
    else:
        plt.show()

    return mfccs

def create_spectrogram(filepath):
    # Load the audio file
    y, sr = librosa.load(filepath)

    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)

    S_dB = librosa.power_to_db(S, ref=np.max)

    plt.figure(figsize=(10, 4))
    librosa.display.specshow(S_dB, sr=sr, x_axis='time', y_axis='mel', fmax=8000)
    plt.colorbar(format='%+2.0f dB')
    plt.title(f'Mel-frequency spectrogram of {filepath}')
    plt.tight_layout()
    plt.show()

def log_specgram(audio, sample_rate, window_size=20, step_size=10, eps=1e-10):
    audio_length = len(audio)
    window_size_samples = int(round(window_size * sample_rate / 1e3))
    step_size_samples = int(round(step_size * sample_rate / 1e3))

    # Adjust nperseg and noverlap for very short audio
    nperseg = min(window_size_samples, audio_length // 3)
    noverlap = min(step_size_samples, nperseg // 2, audio_length // 4)

    freqs, times, spec = signal.spectrogram(audio,
                                    fs=sample_rate,
                                    window='hann',
                                    nperseg=nperseg,
                                    noverlap=noverlap,
                                    detrend=False)
    return freqs, np.log(spec.T.astype(np.float32) + eps)

def audio2spectrogram_and_save(filepath):
    samplerate, test_sound = wavfile.read(filepath, mmap=True)
    if test_sound.ndim > 1:
        test_sound = test_sound.mean(axis=1)
    _, spectrogram = log_specgram(test_sound, samplerate)
    plt.figure(figsize=(10, 4))  # Adjust size as needed
    plt.imshow(spectrogram.T, aspect='auto', origin='lower')
    plt.title(f'Spectrogram of {filepath}')
    plt.colorbar()
    plt.tight_layout()
    # Save the plot instead of showing it
    if not os.path.exists("./spectrogram_dir"):
        os.makedirs("./spectrogram_dir")
    plt.savefig(os.path.join("./spectrogram_dir", f'{filepath}_spectrogram.png'))
    plt.show()
    plt.close()  # Close the plot to free memory
    return spectrogram



def audio2wave(filepath):
    fig = plt.figure(figsize=(5,5))
    samplerate, test_sound  = wavfile.read(filepath,mmap=True)
    plt.plot(test_sound)


# Example usage11
filepaths = [f'/content/dia0_utt{i}.wav' for i in range(10)]
for filepath in filepaths:
    spectrogram = audio2spectrogram_and_save(filepath)
    print(f'{filepath}: Spectrogram shape: {spectrogram.shape}')

def preprocess_audio(audio):

    return audio

## Model

In [5]:
class AudioEmotionModel(nn.Module):
    def __init__(self):
        super(AudioEmotionModel, self).__init__()
        self.conv1 = nn.Conv1d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool1d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool1d(x, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool1d(x, 2)
        x = F.relu(self.conv4(x))
        x = F.max_pool1d(x, 2)
        x = x.view(x.size(0), -1)
        return x

class AudioTextEmotionModel(nn.Module):
    def __init__(self, audio_model, text_model, num_classes):
        super(AudioTextEmotionModel, self).__init__()
        self.audio_model = audio_model
        self.text_model = text_model
        self.fc = nn.Linear(2 * audio_model.conv4.out_channels, num_classes)

    def forward(self, audio, text):
        audio_out = self.audio_model(audio)
        text_out = self.text_model(text)
        combined = torch.cat([audio_out, text_out], dim=1)
        return self.fc(combined)

## Loss Function and Optimizer
One final step before we can simply call `model.fit`

## Train!

## Evaluation

In [18]:
def evaluate(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for text, audio, label in dataloader:
            audio = audio.to(device)
            text = text.to(device)
            label = label.to(device)
            outputs = model(audio, text)
            _, predicted = torch.max(outputs.data, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    return correct / total

init class1
init class2
False
True
True
5
6
5
