In [10]:
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from IPython.display import Audio
from scipy.io import wavfile
import warnings
warnings.filterwarnings("ignore")
# import librosa

import torch
# from transformers import AutoFeatureExtractor, AutoProcessor, WhisperForConditionalGeneration
# from datasets import load_dataset
import torch.nn.functional as F
from python_speech_features import logfbank

plt.ion()   # interactive mode

In [11]:
# Switch the kernel's working directory to the AV-HuBERT package folder.
# Presumably this is what lets the bare `import utils` below resolve to
# AV-HuBERT's own `utils.py` -- confirm on a fresh kernel.
%cd E:/university/FYT/repos/multi_modal_ser/finetune_encoder/audio_video/av_hubert/avhubert

import sys
# Make the fairseq checkout that sits next to AV-HuBERT importable.
# NOTE(review): `append` puts this entry last on sys.path, so a fairseq that is
# already installed in the environment would presumably be found first -- confirm
# which copy actually gets imported.
sys.path.append("E:/university/FYT/repos/multi_modal_ser/finetune_encoder/audio_video/av_hubert/fairseq")
# Aliased so the name does not clash with other modules called `utils`
# (a later cell imports `utils` from fairseq).
import utils as avhubert_utils

E:\university\FYT\repos\multi_modal_ser\finetune_encoder\audio_video\av_hubert\avhubert


In [12]:
# Constants
# Sample rate passed to `logfbank` (as `samplerate`) when filterbank features
# are extracted from the raw waveforms.
# NOTE(review): the rate stored in each .wav file is discarded when the audio is
# loaded, so nothing verifies that the recordings actually match this value.
AUDIORATE = 16000

### Define Dataset

In [20]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, Subset


class MMSERDataset(Dataset):
    """Multi-modal speech emotion recognition (SER) dataset.

    Each sample pairs log filterbank audio features with the frames of the
    matching face video, plus the transcript and the emotion class id.

    Indexing:
        * ``ds[i]`` (int) returns one un-batched sample dict, see ``__getitem__``.
        * ``ds[a:b]`` (slice) returns a padded batch dict, see ``__getitem__``.

    All raw waveforms are read into memory at construction time; videos are
    decoded lazily on every item access.
    """

    # Feature width of a single (un-stacked) audio frame. 26 is the value the
    # stacking order is derived from; it matches `logfbank`'s default number of
    # filters (assumed -- confirm if `logfbank` is ever called with `nfilt`).
    N_FILTERBANK = 26
    # Label workbooks are named `<cutmap_path>1.xlsx` ... `<cutmap_path>5.xlsx`.
    NUM_SESSIONS = 5

    def stacker(self, feats, stack_order):
        """Concatenate consecutive audio frames.

        The input is zero-padded at the end so that its length is a multiple of
        ``stack_order`` before neighbouring frames are merged.

        Args:
            feats: numpy.ndarray of shape [T, F].
            stack_order: int, number of neighbouring frames to concatenate.

        Returns:
            numpy.ndarray of shape [ceil(T / stack_order), F * stack_order].
        """
        feat_dim = feats.shape[1]
        remainder = len(feats) % stack_order
        if remainder != 0:
            padding = np.zeros([stack_order - remainder, feat_dim]).astype(feats.dtype)
            feats = np.concatenate([feats, padding], axis=0)
        return feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order * feat_dim)

    def __load_label__(self, cutmap_path):
        """Build ``self.df_``, the per-sample label table.

        Reads every sheet of every session workbook, joins the transcripts
        (``self.df_text``) on ``smp_id``, maps emotion names to class ids via
        ``self.emo2id`` and drops samples whose emotion is not in that mapping.
        If ``self.smp_id_prefix`` is not None, only samples whose ``smp_id``
        starts with it are kept.

        Args:
            cutmap_path: path prefix of the workbooks; the session number and
                ``.xlsx`` are appended to it.
        """
        sheet_frames = []
        for ses in range(1, self.NUM_SESSIONS + 1):
            workbook_path = cutmap_path + str(ses) + '.xlsx'
            # The context manager closes the workbook handle once its sheets are parsed.
            with pd.ExcelFile(workbook_path) as xl:
                sheet_frames.extend(xl.parse(sheet) for sheet in xl.sheet_names)
        # Concatenate once instead of growing a DataFrame inside the loop.
        labels = pd.concat(sheet_frames)

        labels = pd.merge(labels, self.df_text, on=["smp_id"])
        labels["emotion_id"] = labels["emotion"].map(self.emo2id)
        # Emotions missing from `emo2id` map to NaN; those samples are removed.
        labels = labels[labels["emotion_id"].notna()].reset_index(drop=True)
        labels["session"] = labels["smp_id"].str.split("_").str[0]
        if self.smp_id_prefix is not None:
            keep = labels["smp_id"].str.startswith(self.smp_id_prefix)
            labels = labels[keep].reset_index(drop=True)
        self.df_ = labels

    def __load_text__(self, text_path):
        """Read the transcript CSV into ``self.df_text`` (joined on ``smp_id`` later)."""
        self.df_text = pd.read_csv(text_path)

    def __load_audio__(self, fn_path):
        """Read every sample's waveform into memory.

        Sets ``self.fn_list`` (sample ids, in ``self.df_`` row order) and
        ``self.raw_list`` (the matching waveforms).

        Args:
            fn_path: directory containing ``<smp_id>.wav`` files.
        """
        self.fn_list = list(self.df_["smp_id"])
        # NOTE(review): the sample rate returned by `wavfile.read` is discarded;
        # features are later computed assuming AUDIORATE.
        self.raw_list = [
            wavfile.read(os.path.join(fn_path, fn + '.wav'))[1]
            for fn in self.fn_list
        ]

    def __load_video__(self, idx):
        """Decode the video of one sample.

        Args:
            idx: the sample id (file stem), not an integer position.

        Returns:
            Float tensor with two leading singleton dimensions added in front of
            whatever ``avhubert_utils.load_video`` returns (assumed [T, H, W],
            giving [1, 1, T, H, W] -- confirm against AV-HuBERT's utils).
        """
        frames = avhubert_utils.load_video(os.path.join(self.video_path, idx + ".mp4"))
        # NOTE(review): frames are returned as decoded. AV-HuBERT's usual
        # Normalize / CenterCrop / Normalize transform is not applied here because
        # it needs `task.cfg`, which is unavailable when the dataset is built.
        return torch.FloatTensor(frames).unsqueeze(dim=0).unsqueeze(dim=0)

    def __init__(self,
                 fn_path,
                 cutmap_path,
                 text_path,
                 video_path,
                 emo2id,
                 audio_in_features = 104,
                 smp_id_prefix = "Ses01F_impro"):
        """Load transcripts, labels and raw audio.

        Args:
            fn_path: directory with ``<smp_id>.wav`` files.
            cutmap_path: path prefix of the label workbooks.
            text_path: transcript CSV (needs an ``smp_id`` column for the join
                and a ``transcript`` column for item access).
            video_path: directory with ``<smp_id>.mp4`` files.
            emo2id: dict mapping emotion name to class id; emotions that are
                not listed are dropped from the dataset.
            audio_in_features: width of the stacked audio feature vector; the
                stacking order is ``audio_in_features // N_FILTERBANK``.
            smp_id_prefix: keep only samples whose id starts with this prefix;
                pass None to keep every sample. The default reproduces the
                previously hard-coded subset.
        """
        self.emo2id = emo2id
        self.audio_in_features = audio_in_features
        self.video_path = video_path
        self.smp_id_prefix = smp_id_prefix
        # Order matters: labels need the transcripts, audio needs the labels.
        self.__load_text__(text_path)
        self.__load_label__(cutmap_path)
        self.__load_audio__(fn_path)

    def __len__(self):
        """Number of samples (rows of ``self.df_``)."""
        return self.df_.shape[0]

    def __getsingle__(self, idx):
        """Compute the audio and video features of the sample at position ``idx``.

        Returns:
            (audio_feats, video_feats): audio is a float32 tensor of shape
            [1, F', T] where T equals ``video_feats.shape[2]``; video is the
            output of ``__load_video__``.
        """
        raw_audio = self.raw_list[idx]
        video_feats = self.__load_video__(self.fn_list[idx])
        audio_feats = logfbank(raw_audio, samplerate=AUDIORATE).astype(np.float32)  # [T, F]
        # [T / stack_order, F * stack_order]
        audio_feats = self.stacker(audio_feats, self.audio_in_features // self.N_FILTERBANK)

        # Align the audio length with the number of video frames:
        # zero-pad when the audio is shorter, truncate when it is longer.
        diff = audio_feats.shape[0] - video_feats.shape[2]
        if diff < 0:
            tail = np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)
            audio_feats = np.concatenate([audio_feats, tail])
        elif diff > 0:
            audio_feats = audio_feats[:-diff]

        with torch.no_grad():
            audio_feats = torch.from_numpy(audio_feats.astype(np.float32)).T  # [F', T]
            # NOTE(review): because of the transpose above, this normalises each
            # feature channel across time. AV-HuBERT's own data pipeline appears
            # to normalise each frame across features (layer_norm on [T, F]) --
            # confirm which one the encoder expects. Behaviour left unchanged.
            audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:])
            audio_feats = audio_feats.unsqueeze(dim=0)  # [1, F', T]
        return audio_feats, video_feats

    def collate(self, audio, video, max_size=500):
        """Zero-pad per-sample features to a common length and stack them.

        Args:
            audio: list of tensors shaped [1, F', T_i].
            video: list of tensors shaped [1, 1, T_i, ...].
            max_size: unused; kept so existing callers keep working.

        Returns:
            padded_audio: [B, T_max, F'].
            padded_video: [B, 1, T_max, ...].
            mask: same shape and dtype as ``padded_audio``; 1 where the padded
                audio is non-zero, 0 elsewhere.
                NOTE(review): this is element-wise and marks *valid* values,
                whereas fairseq-style padding masks are usually [B, T] with
                True on *padded* steps -- confirm what the consumer expects.
        """
        # Drop only the known leading singleton dims. A blanket `.squeeze()`
        # would also remove the time axis of a one-frame clip.
        audio_seqs = [a.squeeze(0).transpose(0, 1) for a in audio]  # each [T_i, F']
        video_seqs = [v.squeeze(0).squeeze(0) for v in video]       # each [T_i, ...]
        padded_audio = pad_sequence(audio_seqs, batch_first=True)
        padded_video = pad_sequence(video_seqs, batch_first=True).unsqueeze(1)
        mask = (padded_audio != 0).to(padded_audio.dtype)
        return padded_audio, padded_video, mask

    def __getitem__(self, idx):
        """Return one sample (int index) or a padded batch (slice).

        Int index -> dict with keys ``sess``, ``fn``, ``audio`` ([1, F', T]),
        ``video``, ``text`` and ``labels`` (int class id).

        Slice -> dict with the same keys holding lists of per-sample values,
        except ``audio`` ([B, F', T_max]) and ``video`` ([B, 1, T_max, ...]),
        which are padded tensors, plus ``padding_mask`` ([B, T_max, F']) as
        produced by ``collate``.

        Raises:
            ValueError: if a slice selects no samples.
        """
        if isinstance(idx, slice):
            # Load every selected sample exactly once; decoding video is expensive.
            items = [self[i] for i in range(*idx.indices(len(self)))]
            if not items:
                raise ValueError(f"slice {idx} selects no samples")
            ret_dict = {key: [item[key] for item in items] for key in items[0]}

            padded_audio, padded_video, padding_mask = self.collate(ret_dict["audio"], ret_dict["video"])
            ret_dict["padding_mask"] = padding_mask
            ret_dict["audio"] = padded_audio.transpose(1, 2)
            ret_dict["video"] = padded_video
            return ret_dict

        audio_feats, video_feats = self.__getsingle__(idx)
        return {
            "sess": self.df_["session"].iloc[idx],
            "fn": self.fn_list[idx],
            "audio": audio_feats,
            "video": video_feats,
            "text": self.df_["transcript"].iloc[idx],
            # The column is float because unmapped emotions were NaN before
            # filtering; class ids are integral, so hand back a plain int.
            "labels": int(self.df_["emotion_id"].iloc[idx]),
        }

### Dataset Init

In [21]:
# Emotion name -> class id. "exc" deliberately shares id 0 with "hap", so the
# two are treated as a single class; emotions not listed here are dropped by
# the dataset.
emo2id_dict = dict(hap=0, ang=1, neu=2, sad=3, exc=0)

# Building the dataset reads the label workbooks, the transcript CSV and every
# selected .wav file up front; videos are decoded on item access.
mmser_ds = MMSERDataset(
    fn_path="E:/datasets/preprocessed/spectrogram/raw",
    cutmap_path='E:/datasets/preprocessed/extractionmap/cut_extractionmap',
    text_path="E:/datasets/preprocessed/transcipt/transcript.csv",
    video_path="E:/datasets/preprocessed/face_raw",
    emo2id=emo2id_dict,
)



In [22]:
# print(mmser_ds[22:24])
# print(len(mmser_ds))
# print(mmser_ds.df_["emotion_id"].value_counts())
# mmser_ds.df_["emotion_id"].value_counts().plot(kind="pie")

### Save Dataset

In [23]:
# Persist the whole dataset object -- label table plus the waveforms already
# held in memory -- to a single file.
# NOTE(review): torch.save pickles the object, so loading it back presumably
# requires the MMSERDataset class to be defined/importable first, and newer
# PyTorch versions may need `weights_only=False` in torch.load -- confirm.
torch.save(mmser_ds, "E:/datasets/preprocessed/dataset/avhubert_ds.pt")

In [24]:
# Smoke test: fetching one item exercises the audio feature extraction and the
# video loading path. Its value is not shown because only the last expression
# of a cell is displayed.
mmser_ds[2]
# Preview the label/transcript table that backs the dataset.
mmser_ds.df_.head()

Unnamed: 0,Unnamed: 0_x,iframe,fframe,emotion,speaker,smp_id,emovectorA,emovectorB,emovectorC,Unnamed: 0_y,transcript,emotion_id,session
0,0,100642,131771,neu,L,Ses01F_impro01_F000,2.5,2.5,2.5,0,Excuse me.,2.0,Ses01F
1,2,160161,182280,neu,L,Ses01F_impro01_F001,2.5,2.5,2.5,2,Yeah.,2.0,Ses01F
2,4,238196,288280,neu,L,Ses01F_impro01_F002,2.5,2.5,2.5,4,Is there a problem?,2.0,Ses01F
3,9,439361,503840,neu,L,Ses01F_impro01_F005,2.5,3.5,2.0,9,Well what's the problem? Let me change it.,2.0,Ses01F
4,23,1364321,1408320,ang,L,Ses01F_impro01_F012,2.0,3.5,3.5,23,That's out of control.,1.0,Ses01F


### Load Model

In [25]:
%cd E:/university/FYT/repos/multi_modal_ser/finetune_encoder/audio_video/av_hubert/avhubert

from fairseq import checkpoint_utils, options, tasks, utils
ckpt_path = "E:/check_pts/avhubert.pt"
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])  
model = models[0]
if hasattr(models[0], 'decoder'):
    print(f"Checkpoint: fine-tuned")
    model = models[0].encoder.w2v_model
else:
    print(f"Checkpoint: pre-trained w/o fine-tuning")

E:\university\FYT\repos\multi_modal_ser\finetune_encoder\audio_video\av_hubert\avhubert


AssertionError: Could not infer task type from {'_name': 'av_hubert_pretraining', 'is_s2s': True, 'data': '/checkpoint/bshi/data/lrs3//exp/ls-hubert/tune-modality/all_tsv/', 'label_dir': '/checkpoint/bshi/data/lrs3//exp/ls-hubert/tune-modality/all_bpe/unigram1000/', 'normalize': True, 'labels': ['wrd'], 'single_target': True, 'stack_order_audio': 4, 'tokenizer_bpe_name': 'sentencepiece', 'max_sample_size': 500, 'modalities': ['video'], 'image_aug': True, 'pad_audio': True, 'random_crop': False, 'tokenizer_bpe_model': '/checkpoint/bshi/data/lrs3//lang/spm/spm_unigram1000.model', 'fine_tuning': True}. Available tasks: dict_keys(['audio_pretraining', 'cross_lingual_lm', 'denoising', 'hubert_pretraining', 'language_modeling', 'legacy_masked_lm', 'masked_lm', 'multilingual_denoising', 'multilingual_masked_lm', 'translation', 'multilingual_translation', 'online_backtranslation', 'semisupervised_translation', 'sentence_prediction', 'sentence_ranking', 'speech_to_text', 'simul_speech_to_text', 'simul_text_to_text', 'translation_from_pretrained_bart', 'translation_from_pretrained_xlm', 'translation_lev', 'translation_multi_simple_epoch', 'dummy_lm', 'dummy_masked_lm', 'dummy_mt'])