# Open Set Emotion Recognition

## Library Imports

In [1]:
import warnings
warnings.filterwarnings("ignore")
import sys
import os
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%matplotlib inline
from collections import Counter, defaultdict
import torch.nn as nn
import os
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
tqdm.pandas()
from tqdm import tqdm
import librosa
import re
from collections import Counter
import torch
from torchvision import models, transforms
from PIL import Image
from scipy import signal
from scipy.io import wavfile
import matplotlib.pyplot as plt
import numpy as np
import time
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score,f1_score,confusion_matrix
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchaudio
from transformers import HubertModel, HubertConfig
from sentence_transformers import SentenceTransformer
from functools import lru_cache

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

Using device: cuda


In [2]:
OTHER_LABEL = 6

## Dataset Creation

### MELD

In [3]:
# class MELDDataset(Dataset):
#     def __init__(self, meld_dir, split, transform=None):
#         train_df = pd.read_csv("../MELD_Dataset/train_sent_emo.csv")
#         labels = train_df['Emotion'].unique().tolist()
#         self.label_to_int = {label: i for i, label in enumerate(labels)}

#         self.meld_dir = meld_dir
#         self.transform = transform
#         self.split = split
#         self.img_path = os.path.join(self.meld_dir, 'mel_spectrograms', f'{self.split}_img')
#         self.img_path = os.path.join(self.meld_dir, 'log_spectrogram', f'{self.split}_audio')

#         # load and create sentence embeddings
#         self.dialogues = self.load_dialogues()
#         self.sbert = SentenceTransformer('multi-qa-mpnet-base-dot-v1', device=device)
#         sentences = self.dialogues['Utterance'].tolist()
#         sentences = [text.replace("\x92", "'") for text in sentences]
#         self.sentence_embeddings = self.sbert.encode(sentences, convert_to_tensor=True, show_progress_bar=True, batch_size=128, device=device)

#         self.spectrograms = self.load_spectrograms()
#         self.resnet_model = models.resnet50(pretrained=True)
#         self.feature_extractor = torch.nn.Sequential(*list(self.resnet_model.children())[:-1]).to(device)
#         self.feature_extractor.eval()

#     def load_dialogues(self):
#         dialogue_file = os.path.join(self.meld_dir, f'{self.split}_sent_emo.csv')
#         dialogues = pd.read_csv(dialogue_file)
#         return dialogues

#     def load_spectrograms(self):
#         images = os.listdir(self.img_path)
#         return images

#     def __len__(self):
#         assert(len(self.sentence_embeddings) == len(self.spectrograms))
#         return len(self.dialogues)

#     def preprocess_img(self, img):
#         preprocessor = transforms.Compose([
#             transforms.Resize(256),
#             transforms.CenterCrop(224),
#             transforms.ToTensor(),
#         ])
#         img_t =  preprocessor(img).to(device)
#         return img_t

#     def extract_audio_features_from_spectrogram(self, img):
#         # Pass the input through the model
#         with torch.no_grad():
#             output = self.feature_extractor(img)
#         return output

#     def __getitem__(self, idx):
#         row = self.dialogues.iloc[idx]
#         text = self.sentence_embeddings[idx]
#         spectrogram_data = Image.open(os.path.join(self.img_path, f'dia{row["Dialogue_ID"]}_utt{row["Utterance_ID"]}.png'))
#         spectrogram_data = self.preprocess_img(spectrogram_data)
#         spectrogram_data = spectrogram_data[0:3, :, :]
#         spectrogram_data = spectrogram_data.unsqueeze(0)
#         spectrogram_data = self.extract_audio_features_from_spectrogram(spectrogram_data)
#         spectrogram_data = spectrogram_data.view(-1, 2048)[0]
#         label = row['Emotion']
#         label = torch.tensor(self.label_to_int[label])
#         return text, spectrogram_data, label

# train_meld = MELDDataset("../MELD_Dataset", "train")
# # test_meld = MELDDataset("../MELD_Dataset", "test")
# # dev_meld = MELDDataset("../MELD_Dataset", "dev")

# # concat all 3 datasets into 1 dataset
# meld_dataset = train_meld # + test_meld + dev_meld

In [4]:
# len(meld_dataset)

#### IEMOCAP

In [5]:
class IemocapDataset(Dataset):
    def __init__(self, iemocap_dataset_full_path, iemocap_spectrogram_dir, iemocap_log_spectrogram_dir, is_closed_label_set_flag, labels_to_int, split, modality = "both", transform=None):
        self.IEMOCAP_MAIN_FOLDER = os.path.join(iemocap_dataset_full_path,"IEMOCAP_full_release")
        self.TRANSCRIPTION_FOLDER = os.path.join("dialog", "transcriptions")
        self.AUDIO_FOLDER = os.path.join("sentences", "wav")
        self.CATEGORICAL_LABELS_PATH = os.path.join("dialog", "EmoEvaluation", "Categorical")
        self.split = split
        self.transform = transform
        self.is_closed_label_set_flag = is_closed_label_set_flag
        self.iemocap_spectrogram_dir = iemocap_spectrogram_dir
        self.iemocap_log_spectrogram_dir = iemocap_log_spectrogram_dir
        self.modality = modality
        
        self.errors = defaultdict(int)
        self.unique_labels = []
        self.audio_files = []
        self.sentences_list = []
        self.dataset = self.create_dataset()
        self.labels_to_int = labels_to_int

        self.sbert = SentenceTransformer('multi-qa-mpnet-base-dot-v1', device=device)
        self.sentence_embeddings = self.sbert.encode(self.sentences_list, convert_to_tensor=True, show_progress_bar=True, batch_size=128, device=device)
        
        # config = HubertConfig.from_pretrained("facebook/hubert-large-ls960-ft")
        # self.hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft", config=config)
        self.create_spectrograms(self.iemocap_spectrogram_dir)
        self.create_log_spectrograms(self.iemocap_log_spectrogram_dir)
        self.print_summary()

        self.resnet_model = models.resnet50(pretrained=True)
        self.feature_extractor = torch.nn.Sequential(*list(self.resnet_model.children())[:-1]).to(device)
        self.feature_extractor.eval()
        
    def get_evaluator_filenames_with_video_file_prefix(self, input_list, prefix_value):
        regex_pattern = re.compile(f'^{re.escape(prefix_value)}.*\.txt$')
        matching_strings = [s for s in input_list if regex_pattern.match(s)]
        return matching_strings
    
    def get_utterance_to_evaluationCounter_mapping_from_evaluation_files(self, evaluation_files):
        utterance_to_all_evaluations = {}

        for evaluation_file in evaluation_files:
            utterance_to_evaluationList = {}
            with open(evaluation_file,'r') as f:
                contents = f.read()
                utterance_evaluations = contents.split("\n")
                for evaluation in utterance_evaluations:
                    evaluation = evaluation.strip()
                    if(len(evaluation)==0):
                        continue
                    matches = re.findall(r':[^;]+;', evaluation)
                    matches = [match[1:-1] for match in matches]
                    utterance_to_evaluationList[evaluation.split()[0]] = matches
            
            # Combine lists from dict1
            for key, value_list in utterance_to_evaluationList.items():
                utterance_to_all_evaluations[key] = utterance_to_all_evaluations.get(key, []) + value_list

        utterance_to_evaluationsCounter = {k:Counter(v).most_common(1)[0][0] for k,v in utterance_to_all_evaluations.items()}
        return utterance_to_evaluationsCounter
    
    def is_label_a_closed_label(self,evaluation):
        return evaluation in ["Frustration","Excited","Neutral state","Anger","Sadness","Happiness"]
    
    def create_dataset(self):
        dataset = []
        for session_num in range(1,6):
            for transcription_filename in os.listdir(os.path.join(self.IEMOCAP_MAIN_FOLDER,f"Session{session_num}", self.TRANSCRIPTION_FOLDER)):
                if(transcription_filename[0]!="."): 

                    filename_without_extension = transcription_filename.split(".")[0]
                    
                    categorical_labels_folder_full_path = os.path.join(self.IEMOCAP_MAIN_FOLDER, f"Session{session_num}", self.CATEGORICAL_LABELS_PATH)
                    evaluation_filenames = [self.get_evaluator_filenames_with_video_file_prefix(os.listdir(categorical_labels_folder_full_path), filename_without_extension)[0]]
                    evaluation_files_full_paths_for_this_file = [os.path.join(self.IEMOCAP_MAIN_FOLDER, f"Session{session_num}", self.CATEGORICAL_LABELS_PATH, f) for f in evaluation_filenames]
                    evaluations_per_utterance = self.get_utterance_to_evaluationCounter_mapping_from_evaluation_files(evaluation_files_full_paths_for_this_file)
                    
                    transcription_file_full_path = os.path.join(self.IEMOCAP_MAIN_FOLDER, f"Session{session_num}", self.TRANSCRIPTION_FOLDER, transcription_filename) 
                    with open(transcription_file_full_path,'r') as f:
                        contents = f.read()
                        lines = contents.split("\n")

                        # Iterate through utterances where every utterance looks like:
                        # Ses01F_impro01_F000 [006.2901-008.2357]: Excuse me.
                        for line in lines:

                            # Remove extra spaces and check if the line is not an empty link (usually at EOF)
                            line = line.strip()
                            if(len(line)==0):
                                break

                            # Remove idx of first space, ], -
                            try:
                                space_idx = line.index(" ")
                                timestampEndBracket_idx = line.index("]")
                                timestampHyphen_idx = line.index("-")
                            except:
                                self.errors["Problematic Transcription Line"]+=1
                                continue
                            else:
                                audio_filename = line[:space_idx]        # output audio file name = utterance name
                                text = line[timestampEndBracket_idx+3:]         # the transcription of the utterance
                                evaluation = evaluations_per_utterance.get(audio_filename,"KEY_ERROR")
                                if(evaluation=="KEY_ERROR"):
                                    self.errors["Unavailable Label for an utterance"]+=1

                                utterance_audios_per_video_folder = audio_filename[:line.rindex('_')]       # Only need Ses01F_impro01 from Ses01F_impro01_F000
                                audio_file_full_path = os.path.join(self.IEMOCAP_MAIN_FOLDER, f"Session{session_num}", self.AUDIO_FOLDER, utterance_audios_per_video_folder, audio_filename+".wav")         # name of the video file

                                if(evaluation!="KEY_ERROR" and os.path.isfile(audio_file_full_path)==True and self.is_label_a_closed_label(evaluation)==self.is_closed_label_set_flag):
                                # if(evaluation!="KEY_ERROR" and os.path.isfile(audio_file_full_path)==True):    
                                    self.audio_files.append(audio_file_full_path)
                                    self.sentences_list.append(text)
                                    dataset.append((text,audio_file_full_path,evaluation))
                                    if evaluation not in self.unique_labels:
                                        self.unique_labels.append(evaluation)
        return dataset
    
    def print_summary(self):
        print("SUMMARY:\n")
        for k,v in self.errors.items():
            print(f"{k}: {v}")
    
    def create_spectrograms(self,iemocap_spectrogram_dir):
        log_dir = os.path.join(os.path.dirname(os.getcwd()),'iemocap','log_dir')
        output_dir = os.path.join(os.path.dirname(os.getcwd()),iemocap_spectrogram_dir)
        log_file_path = os.path.join(log_dir,'processed_files_spectrogram.log')
        error_log_path = os.path.join(log_dir,'error_files_spectrogram.log')

        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        processed_files = set()
        if os.path.exists(log_file_path):
            with open(log_file_path, 'r') as file:
                processed_files = set(file.read().splitlines())

        processed_files_count = 0
        throttle_delay = 1 
        def create_spectrogram(filename, audio_file_path, output_file_path):
            y, sr = librosa.load(audio_file_path)
            S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
            S_dB = librosa.power_to_db(S, ref=np.max)
            plt.figure(figsize=(10, 4))
            librosa.display.specshow(S_dB, sr=sr, fmax=8000)
            plt.tight_layout()
            plt.savefig(output_file_path)
            plt.close()

        for filenum in tqdm(range(len(self.audio_files))):
            filename = self.audio_files[filenum]
            if filename.endswith(".wav") and filename not in processed_files:

                audio_file_path = os.path.join(filename)
                output_file_path = os.path.join(output_dir, os.path.splitext(os.path.basename(filename))[0])
                try:
                    create_spectrogram(filename, audio_file_path, output_file_path)
                    processed_files.add(filename)
                    processed_files_count += 1
                    with open(log_file_path, 'a') as log_file:
                        log_file.write(f"{filename}\n")
                except Exception as e:
                    print(f"Error processing {filename}: {e}")
                    with open(error_log_path, 'a') as error_log:
                        error_log.write(f"{filename}: {e}\n")
                finally:
                    time.sleep(throttle_delay)

        print(f"Batch conversion completed for spectrograms. Processed {processed_files_count} files.")

    
    def create_log_spectrograms(self,iemocap_log_spectrogram_dir):
        def log_specgram(audio, sample_rate, window_size=20,
                        step_size=10, eps=1e-10):
            nperseg = int(round(window_size * sample_rate / 1e3))
            noverlap = int(round(step_size * sample_rate / 1e3))
            freqs, times, spec = signal.spectrogram(audio, fs=sample_rate, window='hann', 
                                                    nperseg=nperseg, noverlap=noverlap, detrend=False)
            return freqs, np.log(spec.T.astype(np.float32) + eps)

        def process_audio_file(filepath, output_dir):
            sample_rate, audio = wavfile.read(filepath)
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            _, spectrogram = log_specgram(audio, sample_rate)
            plt.figure(figsize=(10, 4))  
            plt.xticks([])
            plt.yticks([])
            plt.imshow(spectrogram.T, aspect='auto', origin='lower')
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, os.path.splitext(os.path.basename(filepath))[0]+".png"))
            plt.close()  

        log_dir = os.path.join(os.path.dirname(os.getcwd()),'iemocap','log_dir')
        output_dir = os.path.join(os.path.dirname(os.getcwd()),iemocap_log_spectrogram_dir)

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        log_file_path = os.path.join(log_dir, 'processed_files_log_spectrogram.log')
        error_log_path = os.path.join(log_dir, 'error_files_log_spectrogram.log')

        throttle_delay = 1 

        processed_files = set()
        if os.path.exists(log_file_path):
            with open(log_file_path, 'r') as file:
                processed_files = set(file.read().splitlines())

        processed_files_count = 0
        for filenum in tqdm(range(len(self.audio_files))):
            filepath = self.audio_files[filenum]
            if filepath.endswith(".wav") and filepath not in processed_files:
                try:
                    process_audio_file(filepath, output_dir)
                    processed_files.add(filepath)
                    processed_files_count += 1
                    with open(log_file_path, 'a') as log_file:
                        log_file.write(f"{filepath}\n")
                except Exception as e:
                    print(f"Error processing {filepath}: {e}")
                    with open(error_log_path, 'a') as error_log:
                        error_log.write(f"{filepath}: {e}\n")
                finally:
                    time.sleep(throttle_delay)

        print(f"Batch conversion completed for log spectrograms. Processed {processed_files_count} files.")

    def preprocess_img(self, img):
        preprocessor = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        img_t =  preprocessor(img).to(device)
        return img_t

    def extract_audio_features_from_spectrogram(self, img):
        # Pass the input through the model
        with torch.no_grad():
            output = self.feature_extractor(img)
        return output
    
    def __len__(self):
        return len(self.dataset)
    

    @lru_cache(maxsize=None)
    def cached_audio_features(self, img_path):
        spectrogram_data = Image.open(img_path)
        spectrogram_data = self.preprocess_img(spectrogram_data)
        spectrogram_data = spectrogram_data[0:3, :, :]
        spectrogram_data = spectrogram_data.unsqueeze(0)
        spectrogram_data = self.extract_audio_features_from_spectrogram(spectrogram_data)
        spectrogram_data = spectrogram_data.view(-1, 2048)[0]
        return spectrogram_data

    # @lru_cache(maxsize=None)
    # def cached_audio_features(self,audio_path):
    #     waveform, sample_rate = torchaudio.load(audio_path)

    #     # Resample if necessary (HuBERT uses 16 kHz sample rate)
    #     if sample_rate != 16000:
    #         waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
    #         sample_rate = 16000

    #     # Ensure single channel audio (mono)
    #     if waveform.shape[0] > 1:
    #         waveform = torch.mean(waveform, dim=0, keepdim=True)

    #     # Normalize audio
    #     waveform = waveform / torch.max(torch.abs(waveform))

    #     # Step 3: Encode audio waveforms using HuBERT model
    #     # Pass the waveform through the model
    #     with torch.no_grad():
    #         outputs = self.hubert_model(waveform)

    #     # Get the hidden states
    #     hidden_states = outputs.last_hidden_state
    #     frame_averaged_hidden_states = torch.mean(hidden_states, dim=1)
    #     frame_averaged_hidden_states = frame_averaged_hidden_states.view(-1,1024)[0]
    #     return frame_averaged_hidden_states
    
    def __getitem__(self, idx):
        _, audio, label = self.dataset[idx]

        if self.modality == "text" or self.modality == "both":
            text = self.sentence_embeddings[idx]

        if self.modality == "audio" or self.modality == "both":
            img_path = os.path.join(os.path.dirname(os.getcwd()),self.iemocap_spectrogram_dir,os.path.splitext(os.path.basename(audio))[0]+".png")
            audio_features = self.cached_audio_features(img_path)

        if self.is_closed_label_set_flag==False:
            label = OTHER_LABEL
        else:
            label = self.labels_to_int[label]

        if (self.modality == "audio" or self.modality == "both") and self.transform:
            audio = self.transform(audio)
        
        if self.modality == "both":
            return text, audio_features, label
        elif self.modality == "text":
            return text, label
        else:
            return audio_features, label

In [6]:
IEMOCAP_FULL_PATH = os.path.join(os.path.dirname(os.getcwd()),"IEMOCAP_full_release")
labels_to_int = {'Neutral state': 0,
                'Frustration': 1,
                'Anger': 2,
                'Sadness': 3,
                'Happiness': 4,
                'Excited': 5,
                'Surprise': 6,
                'Fear': 7,
                'Other': 8,
                'Disgust': 9}

openIemocapDataset = IemocapDataset(iemocap_dataset_full_path=IEMOCAP_FULL_PATH,
                                iemocap_spectrogram_dir=os.path.join("iemocap","spectrogram"),
                                iemocap_log_spectrogram_dir=os.path.join("iemocap","log_spectrogram"),
                                is_closed_label_set_flag=False,
                                labels_to_int=labels_to_int,
                                split=None,
                                transform=None)
closedIemocapDataset = IemocapDataset(iemocap_dataset_full_path=IEMOCAP_FULL_PATH,
                                iemocap_spectrogram_dir=os.path.join("iemocap","spectrogram"),
                                iemocap_log_spectrogram_dir=os.path.join("iemocap","log_spectrogram"),
                                is_closed_label_set_flag=True,
                                labels_to_int=labels_to_int,
                                split=None,
                                transform=None)

closedTextOnlyIemocapDataset = IemocapDataset(iemocap_dataset_full_path=IEMOCAP_FULL_PATH,
                                    iemocap_spectrogram_dir=os.path.join("iemocap","spectrogram"),
                                    iemocap_log_spectrogram_dir=os.path.join("iemocap","log_spectrogram"),
                                    is_closed_label_set_flag=True,
                                    labels_to_int=labels_to_int,
                                    split=None,
                                    transform=None,
                                    modality = "text")

closedAudioOnlyIemocapDataset = IemocapDataset(iemocap_dataset_full_path=IEMOCAP_FULL_PATH,
                                    iemocap_spectrogram_dir=os.path.join("iemocap","spectrogram"),
                                    iemocap_log_spectrogram_dir=os.path.join("iemocap","log_spectrogram"),
                                    is_closed_label_set_flag=True,
                                    labels_to_int=labels_to_int,
                                    split=None,
                                    transform=None,
                                    modality = "audio")

Batches: 100%|██████████| 2/2 [00:04<00:00,  2.16s/it]
100%|██████████| 219/219 [00:00<?, ?it/s]


Batch conversion completed for spectrograms. Processed 0 files.


100%|██████████| 219/219 [00:00<?, ?it/s]


Batch conversion completed for log spectrograms. Processed 0 files.
SUMMARY:

Problematic Transcription Line: 152
Unavailable Label for an utterance: 128


Batches: 100%|██████████| 77/77 [00:09<00:00,  8.13it/s]
100%|██████████| 9740/9740 [00:00<00:00, 1300745.72it/s]


Batch conversion completed for spectrograms. Processed 0 files.


100%|██████████| 9740/9740 [00:00<00:00, 1483657.92it/s]


Batch conversion completed for log spectrograms. Processed 0 files.
SUMMARY:

Problematic Transcription Line: 152
Unavailable Label for an utterance: 128


Batches: 100%|██████████| 77/77 [00:09<00:00,  8.37it/s]
100%|██████████| 9740/9740 [00:00<00:00, 1773420.77it/s]


Batch conversion completed for spectrograms. Processed 0 files.


100%|██████████| 9740/9740 [00:00<00:00, 1618690.90it/s]


Batch conversion completed for log spectrograms. Processed 0 files.
SUMMARY:

Problematic Transcription Line: 152
Unavailable Label for an utterance: 128


Batches: 100%|██████████| 77/77 [00:09<00:00,  8.34it/s]
100%|██████████| 9740/9740 [00:00<00:00, 1756342.26it/s]


Batch conversion completed for spectrograms. Processed 0 files.


100%|██████████| 9740/9740 [00:00<00:00, 1375598.39it/s]


Batch conversion completed for log spectrograms. Processed 0 files.
SUMMARY:

Problematic Transcription Line: 152
Unavailable Label for an utterance: 128


In [7]:
print(Counter([i[2] for i in closedIemocapDataset]))
len(closedIemocapDataset)

Counter({1: 3787, 5: 2505, 2: 1204, 3: 1178, 0: 571, 4: 495})


9740

In [8]:
print(Counter([i[2] for i in openIemocapDataset]))
len(openIemocapDataset)

Counter({6: 219})


219

In [9]:
def get_stratified_split(dataset, test_size):

    indices = list(range(len(dataset)))
    dataset_labels = [item[-1] for item in dataset]
    stratified_splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=42)

    # Get the indices for training and testing sets
    for train_idx, test_idx in stratified_splitter.split(indices, dataset_labels):
        train_dataset = torch.utils.data.Subset(dataset, train_idx)
        test_dataset = torch.utils.data.Subset(dataset, test_idx)
    
    return train_dataset,test_dataset

In [10]:
# TEXT + AUDIO SPLITS
train_dataset, temp_dataset = get_stratified_split(closedIemocapDataset,0.2)
val_dataset_1, test_dataset_1 = get_stratified_split(temp_dataset,0.5)
val_dataset_2, test_dataset_2 = get_stratified_split(openIemocapDataset,0.5)
val_dataset = val_dataset_1+val_dataset_2
test_dataset = test_dataset_1+test_dataset_2

# TEXT SPLIT
train_dataset_text_only, val_dataset_text_only = get_stratified_split(closedTextOnlyIemocapDataset,0.2)

# AUDIO SPLIT
train_dataset_audio_only, val_dataset_audio_only = get_stratified_split(closedAudioOnlyIemocapDataset,0.2)

In [11]:
# Create data loaders.
batch_size = 64

# AUDIO + TEXT DATALOADERS (CONSISTS OF BOTH OPEN AND CLOSED LABELS)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# AUDIO + TEXT DATALOADERS (ONLY CLOSED LABELS - USED FOR CHECKING MODEL PERFORMANCE AFTER KEEPING ASIDE THE OPEN SET CHALLENGE)
val_closed_set_dataloader = DataLoader(val_dataset_1, batch_size=batch_size)
test_closed_set_dataloader = DataLoader(test_dataset_1, batch_size=batch_size)

# TEXT DATALOADERS
train_dataloader_text_only = DataLoader(train_dataset_text_only, batch_size=batch_size)
val_dataloader_text_only = DataLoader(val_dataset_text_only, batch_size=batch_size)

# AUDIO DATALOADERS
train_dataloader_audio_only = DataLoader(train_dataset_audio_only, batch_size=batch_size)
val_dataloader_audio_only = DataLoader(val_dataset_audio_only, batch_size=batch_size)

## Model Architectures Definition

In [12]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0, path='model_checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_acc_max = np.Inf

    def __call__(self, val_acc, model):

        if self.best_score is None:
            self.best_score = val_acc
            self.save_checkpoint(val_acc, model)
        elif val_acc < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_acc
            self.save_checkpoint(val_acc, model)
            self.counter = 0

    def save_checkpoint(self, val_acc, model):
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_acc

### Text Unimodal

In [13]:
class TextEmotionModel(nn.Module):
    def __init__(self, num_classes):
        super(TextEmotionModel, self).__init__()
        # sequential model with 2 layers, followed by dropout and relu layers and output layer
        self.fc = nn.Sequential(
            nn.Linear(768, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )


    def forward(self, text):
        return self.fc(text)


text_model = TextEmotionModel(6)
text_model.to(device)

TextEmotionModel(
  (fc): Sequential(
    (0): Linear(in_features=768, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): ReLU()
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Dropout(p=0.2, inplace=False)
    (7): ReLU()
    (8): Linear(in_features=512, out_features=6, bias=True)
  )
)

### Audio Unimodal

In [14]:
class AudioEmotionModel(nn.Module):
    def __init__(self, num_classes):
        super(AudioEmotionModel, self).__init__()
        # sequential model with 2 layers, followed by dropout and relu layers and output layer
        self.fc = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )


    def forward(self, audio):
        return self.fc(audio)


audio_model = AudioEmotionModel(6)
audio_model.to(device)

AudioEmotionModel(
  (fc): Sequential(
    (0): Linear(in_features=2048, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): ReLU()
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Dropout(p=0.2, inplace=False)
    (7): ReLU()
    (8): Linear(in_features=512, out_features=6, bias=True)
  )
)

### Multimodal

In [15]:
class AudioTextEmotionModel(nn.Module):
    def __init__(self, num_classes):
        super(AudioTextEmotionModel, self).__init__()
        # sequential model with 2 layers, followed by dropout and relu layers and output layer
        self.fc = nn.Sequential(
            nn.Linear(2048 + 768, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )


    def forward(self, text, audio):
        combined = torch.cat([audio, text], axis=1)
        return self.fc(combined)


model = AudioTextEmotionModel(6)
model.to(device)


AudioTextEmotionModel(
  (fc): Sequential(
    (0): Linear(in_features=2816, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): ReLU()
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Dropout(p=0.2, inplace=False)
    (7): ReLU()
    (8): Linear(in_features=512, out_features=6, bias=True)
  )
)

## Loss Function and Optimizer (of the 2 unimodal models and the multimodal model)

In [16]:
# TEXT ONLY 
loss_fn_text = nn.CrossEntropyLoss()
optimizer_text = torch.optim.RMSprop(text_model.parameters(), lr=1e-3, momentum=0.9)
scheduler_text = ReduceLROnPlateau(optimizer_text, mode='max', factor=0.1, patience=3, verbose=True)

# AUDIO ONLY
loss_fn_audio = nn.CrossEntropyLoss()
optimizer_audio = torch.optim.RMSprop(audio_model.parameters(), lr=1e-3, momentum=0.9)
scheduler_audio = ReduceLROnPlateau(optimizer_audio, mode='max', factor=0.1, patience=3, verbose=True)

# TEXT + AUDIO
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, verbose=True)

num_epochs = 100

## Train (the 2 unimodal models and the multimodal model on only closed set labels to test the performance. We have kept aside the open set challenge for now)

In [17]:
def accuracy(dataloader, model):
    size = len(dataloader.dataset)
    total_correct = 0
    model.eval()
    for batch, x_and_y in enumerate(dataloader):
        x_and_y_device = [x_and_y[i].to(device) for i in range(len(x_and_y))]

        # Compute prediction error
        pred = model(*x_and_y_device[:-1])
        predicted = torch.argmax(pred,dim=1).cpu()
        label = x_and_y_device[-1]
        actual = label.cpu()
        correct = predicted == actual
        total_correct += correct.sum().item()
    return total_correct/size

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, x_and_y in enumerate(dataloader):
        x_and_y_device = [x_and_y[i].to(device) for i in range(len(x_and_y))]

        # Compute prediction error
        pred = model(*x_and_y[:-1])
        label = x_and_y_device[-1]
        loss = loss_fn(pred, label)
        # acc = accuracy_score(torch.argmax(pred,dim=1).cpu(), label.cpu())
        # f1 = f1_score(torch.argmax(pred,dim=1).cpu(),label.cpu(),average="micro")

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(x_and_y_device[0])
            # print(f"loss: {loss:>7f}\t\tAccuracy: {acc:>7f}\t\tF1 Score: {f1:>7f}  [{current:>5d}/{size:>5d}]")
            
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

### Text Unimodal Training

In [18]:
early_stopping_text = EarlyStopping(patience=9, delta=0, path="text_model.pt")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train(train_dataloader_text_only, text_model, loss_fn_text, optimizer_text)
    train_accuracy = accuracy(train_dataloader_text_only,text_model)
    val_accuracy = accuracy(val_dataloader_text_only,text_model)
    print(f"Accuracy on Train Set => {train_accuracy} | Accuracy on Closed Validation Set => {val_accuracy}")
    scheduler_text.step(val_accuracy)
    early_stopping_text(val_accuracy,text_model)
    if early_stopping_text.early_stop:
        print("Early stopping")
        break

Epoch 1
-------------------------------
loss: 1.897210  [   64/ 7792]
loss: 1.562696  [ 6464/ 7792]


Accuracy on Train Set => 0.5376026694045175 | Accuracy on Closed Validation Set => 0.46919917864476385
Epoch 2
-------------------------------
loss: 1.384188  [   64/ 7792]
loss: 1.518552  [ 6464/ 7792]
Accuracy on Train Set => 0.5966375770020534 | Accuracy on Closed Validation Set => 0.5148870636550308
Epoch 3
-------------------------------
loss: 1.238571  [   64/ 7792]
loss: 1.424977  [ 6464/ 7792]
Accuracy on Train Set => 0.6368069815195072 | Accuracy on Closed Validation Set => 0.5200205338809035
Epoch 4
-------------------------------
loss: 1.122440  [   64/ 7792]
loss: 1.345962  [ 6464/ 7792]
Accuracy on Train Set => 0.6799281314168378 | Accuracy on Closed Validation Set => 0.5313141683778234
Epoch 5
-------------------------------
loss: 1.064067  [   64/ 7792]
loss: 1.126402  [ 6464/ 7792]
Accuracy on Train Set => 0.6958418891170431 | Accuracy on Closed Validation Set => 0.5128336755646817
EarlyStopping counter: 1 out of 9
Epoch 6
-------------------------------
loss: 0.988964 

### Audio Unimodal Training

In [19]:
early_stopping_audio = EarlyStopping(patience=9, delta=0, path="audio_model.pt")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train(train_dataloader_audio_only, audio_model, loss_fn_audio, optimizer_audio)
    train_accuracy = accuracy(train_dataloader_audio_only,audio_model)
    val_accuracy = accuracy(val_dataloader_audio_only,audio_model)
    print(f"Accuracy on Train Set => {train_accuracy} | Accuracy on Closed Validation Set => {val_accuracy}")
    scheduler_audio.step(val_accuracy)
    early_stopping_audio(val_accuracy,audio_model)
    if early_stopping_audio.early_stop:
        print("Early stopping")
        break

Epoch 1
-------------------------------
loss: 1.893719  [   64/ 7792]
loss: 1.655638  [ 6464/ 7792]
Accuracy on Train Set => 0.398870636550308 | Accuracy on Closed Validation Set => 0.39373716632443534
Epoch 2
-------------------------------
loss: 1.484247  [   64/ 7792]
loss: 1.615576  [ 6464/ 7792]
Accuracy on Train Set => 0.436088295687885 | Accuracy on Closed Validation Set => 0.425564681724846
Epoch 3
-------------------------------
loss: 1.444758  [   64/ 7792]
loss: 1.566702  [ 6464/ 7792]
Accuracy on Train Set => 0.4554671457905544 | Accuracy on Closed Validation Set => 0.4322381930184805
Epoch 4
-------------------------------
loss: 1.432117  [   64/ 7792]
loss: 1.548898  [ 6464/ 7792]
Accuracy on Train Set => 0.47369096509240244 | Accuracy on Closed Validation Set => 0.4368583162217659
Epoch 5
-------------------------------
loss: 1.376343  [   64/ 7792]
loss: 1.554894  [ 6464/ 7792]
Accuracy on Train Set => 0.4817761806981519 | Accuracy on Closed Validation Set => 0.42659137

### Multimodal Training

In [20]:
early_stopping = EarlyStopping(patience=9, delta=0)
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    print(f"Accuracy on Train Set => {accuracy(train_dataloader,model)} | Accuracy on Closed Validation Set => {accuracy(val_closed_set_dataloader,model)}")
    scheduler.step(accuracy(val_closed_set_dataloader,model))
    early_stopping(accuracy(val_closed_set_dataloader,model),model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

Epoch 1
-------------------------------
loss: 1.812260  [   64/ 7792]
loss: 1.492659  [ 6464/ 7792]
Accuracy on Train Set => 0.5338809034907598 | Accuracy on Closed Validation Set => 0.4948665297741273
Epoch 2
-------------------------------
loss: 1.211595  [   64/ 7792]
loss: 1.335443  [ 6464/ 7792]
Accuracy on Train Set => 0.5848305954825462 | Accuracy on Closed Validation Set => 0.5225872689938398
Epoch 3
-------------------------------
loss: 1.117370  [   64/ 7792]
loss: 1.269976  [ 6464/ 7792]
Accuracy on Train Set => 0.6279517453798767 | Accuracy on Closed Validation Set => 0.5359342915811088
Epoch 4
-------------------------------
loss: 1.060799  [   64/ 7792]
loss: 1.185912  [ 6464/ 7792]
Accuracy on Train Set => 0.6507956878850103 | Accuracy on Closed Validation Set => 0.5318275154004107
EarlyStopping counter: 1 out of 9
Epoch 5
-------------------------------
loss: 1.077716  [   64/ 7792]
loss: 1.003616  [ 6464/ 7792]
Accuracy on Train Set => 0.7014887063655031 | Accuracy on 

## Evaluation (of the multimodal model on closed and open set labels both)

In [31]:
def set_dropout_to_train(eval_model):
    for module in eval_model.modules():
        if isinstance(module, nn.Dropout):
            module.train()

def predict(label, model, text, spectrogram_data, n_simulations=100, threshold=1, other_label=OTHER_LABEL):
    predictions = [model(text, spectrogram_data).detach().cpu() for _ in range(n_simulations)]
    predictions = torch.stack(predictions)
    predictions = F.softmax(predictions, dim=2)

    mean_predictions = torch.mean(predictions,dim=0)
    std_predictions = torch.mean(torch.std(predictions,dim=0),dim=1)
    _,predicted_class = torch.max(mean_predictions,1)
    high_uncertainty = std_predictions>threshold
    predicted_class[high_uncertainty]=other_label
    return predicted_class

def evaluate(model, dataloader, device, threshold=0.6):
    # After setting the model to evaluation mode, call this function
    model.eval()
    set_dropout_to_train(model)

    size = len(dataloader.dataset)
    total_correct = 0
    total_confusion_matrix = torch.zeros((7,7))
    # total_correct_pred_of_other_label, total_actual_other_label = 0,0
    for batch, (text, spectrogram_data, label) in enumerate(dataloader):
        text, spectrogram_data, label = text.to(device), spectrogram_data.to(device), label.to(device)

        predicted = predict(label, model, text, spectrogram_data, threshold=threshold)
        predicted = predicted.cpu()
        actual = label.cpu()
        correct = predicted == actual
        total_correct += correct.sum().item()
        cm = confusion_matrix(predicted,actual)
        if(cm.shape[0]!=7 and cm.shape[1]!=7):
            row_of_zeros = np.zeros((7-cm.shape[0],cm.shape[1]))
            array_with_row = np.concatenate((cm, row_of_zeros), axis=0)

            # Add a column of zeros at the end
            column_of_zeros = np.zeros((7, 7-cm.shape[1]))
            array_with_row_and_column = np.concatenate((array_with_row, column_of_zeros), axis=1)
            cm = array_with_row_and_column

        total_confusion_matrix+= cm

    print(total_confusion_matrix)
    return total_correct/size

In [34]:
evaluate(model,val_dataloader,device,threshold=0.17)

tensor([[  9.,   6.,   1.,   3.,   1.,   3.,   2.],
        [ 31., 251.,  58.,  37.,  17.,  74.,  22.],
        [  2.,  32.,  48.,   5.,   2.,   9.,   5.],
        [  5.,  33.,   4.,  63.,   1.,  16.,   3.],
        [  1.,   6.,   1.,   2.,   4.,  31.,   2.],
        [  9.,  51.,   8.,   8.,  24., 177.,  16.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.]], dtype=torch.float64)


0.5096952908587258

In [33]:
evaluate(model,test_dataloader,device,threshold=0.17)

tensor([[ 15.,  11.,   0.,   2.,   2.,   3.,   2.],
        [ 25., 252.,  72.,  44.,   6.,  53.,  46.],
        [  0.,  31.,  34.,   0.,   1.,   9.,  11.],
        [  8.,  31.,   5.,  61.,   6.,   5.,  18.],
        [  1.,   5.,   0.,   4.,  11.,  15.,   3.],
        [  8.,  48.,  10.,   7.,  24., 165.,  30.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.]], dtype=torch.float64)


0.496309963099631

In [34]:
# learn pytorch basic with some basic models and datasets
# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
# https://pytorch.org/tutorials/beginner/basics/transforms_tutorial.html
# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
# https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html
# https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html
# https://pytorch.org/tutorials/beginner/basics/nnqs_tutorial.html