In [None]:
import argparse
from dataclasses import dataclass
import numpy as np
import soundfile as sf
from tqdm import tqdm
import torch
import torch.nn.functional as F
import fairseq
import pandas as pd
import os
import pickle
import torchaudio
import torchaudio.transforms as T

# Pick the compute device: the first CUDA GPU when one is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

@dataclass
class UserDirModule:
    """Minimal argument holder exposing a single ``user_dir`` attribute.

    An instance of this class is handed to ``fairseq.utils.import_user_module``
    further down in this cell; presumably that helper only reads ``.user_dir``
    from the object it receives -- confirm against the fairseq documentation.
    """
    # Path of the directory that contains the custom fairseq user module.
    user_dir: str

# Directory holding the emotion2vec fairseq user module, and the pretrained
# checkpoint to load. NOTE: despite its name, `checkpoint_dir` is the path of
# the .pt checkpoint *file*, not of a directory.
model_dir = '/home/legalalien/Documents/Jiawei/EmoTracjectory/emotion2vec/upstream'
checkpoint_dir = '/home/legalalien/Documents/Jiawei/EmoTracjectory/emotion2vec/emotion2vec_base.pt'
# 'utterance' -> one mean-pooled vector per file; 'frame' -> keep every frame.
granularity = 'utterance'

# Import the user module first so that the checkpoint can be loaded afterwards
# (presumably this registers the custom model/task classes with fairseq).
model_path = UserDirModule(model_dir)
fairseq.utils.import_user_module(model_path)
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_dir])
model = model[0]  # the loader returns a list of models; only one checkpoint was given
model.eval()
# Move the model to the `device` selected at the top of this cell instead of
# calling .cuda() unconditionally, so the cell also runs on CPU-only machines.
model.to(device)


def load_audio(file_path):
    """Read an audio file with torchaudio and return it at 16 kHz.

    Args:
        file_path: Path of the audio file to read.

    Returns:
        A ``(waveform, sample_rate)`` tuple. ``sample_rate`` is always 16000:
        files stored at another rate are resampled (a message is printed).
    """
    waveform, sample_rate = torchaudio.load(file_path)

    # Fast path: the file is already at the target rate, return it untouched.
    if sample_rate == 16000:
        return waveform, sample_rate

    print("This audio is not 16kHz, resampling to 16kHz:", file_path)
    to_16k = T.Resample(orig_freq=sample_rate, new_freq=16000)
    return to_16k(waveform), 16000


def process_data(
    source_csv_file='/home/legalalien/Documents/Jiawei/EmoTracjectory/data_pre/sorted_detailed_train.csv',
    target_pkl_file='/home/legalalien/Documents/Jiawei/EmoTracjectory/data_pre/emo2vec_train.pkl',
):
    """Extract features for every wav listed in a CSV and pickle them.

    Each path in the ``Wav-path`` column of ``source_csv_file`` is loaded with
    ``load_audio``, passed through the module-level ``model`` and the resulting
    features are collected into one NumPy array that is pickled to
    ``target_pkl_file``. The module-level ``granularity`` selects between
    keeping all frames (``'frame'``) and averaging them over axis 0
    (``'utterance'``).

    A row is skipped, with a printed message, when its file does not exist, is
    not a ``.wav`` file, or feature extraction fails. The saved array can
    therefore contain fewer rows than the CSV, and row ``i`` of the array does
    not necessarily belong to row ``i`` of the CSV.

    Args:
        source_csv_file: CSV file with a ``Wav-path`` column.
        target_pkl_file: Destination of the pickled feature array.

    Raises:
        ValueError: If ``granularity`` is neither ``'frame'`` nor ``'utterance'``.
        AssertionError: If a wav file is not mono (or is not at 16 kHz after
            loading).
    """
    # Validate the configuration once, up front, instead of inside the
    # per-file error handler where the problem used to be swallowed.
    if granularity not in ('frame', 'utterance'):
        raise ValueError("Unknown granularity: {}".format(granularity))

    csv_file = pd.read_csv(source_csv_file)
    results = []
    skipped = 0

    for wav_path in tqdm(csv_file['Wav-path'], total=len(csv_file)):
        if not os.path.exists(wav_path):
            print("File not found: {}".format(wav_path))
            skipped += 1
            continue
        # Without this guard a non-wav row would reuse the waveform of the
        # previous row (or hit an undefined variable on the first row).
        if not wav_path.endswith('.wav'):
            print("Skipping non-wav file: {}".format(wav_path))
            skipped += 1
            continue

        wav, sr = load_audio(wav_path)
        channel = sf.info(wav_path).channels
        assert sr == 16e3, "Sample rate should be 16kHz, but got {} in file {}".format(sr, wav_path)
        assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, wav_path)

        with torch.no_grad():
            # Use the module-level `device` rather than a hard-coded .cuda().
            source = wav.float().to(device)
            if task.cfg.normalize:
                source = F.layer_norm(source, source.shape)
            try:
                feats = model.extract_features(source, padding_mask=None)
                feats = feats['x'].squeeze(0).cpu().numpy()
            except Exception as exc:
                # Best effort: report the failing file and carry on with the
                # rest. KeyboardInterrupt/SystemExit are deliberately not caught.
                print("Error in extracting features from {}: {}".format(wav_path, exc))
                skipped += 1
                continue

        if granularity == 'utterance':
            feats = np.mean(feats, axis=0)
        results.append(feats)

    # NOTE(review): with 'frame' granularity, clips of different length yield
    # arrays of different shape; np.array() on such a list may fail or build an
    # object array depending on the NumPy version -- confirm before relying on
    # that mode.
    results = np.array(results)
    print("Extracted features shape: {}".format(results.shape))
    if skipped:
        print("Skipped {} of {} rows".format(skipped, len(csv_file)))
    with open(target_pkl_file, 'wb') as f:
        pickle.dump(results, f)
    print("Saved features to {}".format(target_pkl_file))


if __name__ == '__main__':
    # Run the whole extraction only when this code is executed directly.
    process_data()

In [3]:
# Load pkl file
import pickle

# Pickle file to inspect; this is the same path process_data() writes its features to.
file_path = '/home/legalalien/Documents/Jiawei/EmoTracjectory/data_pre/emo2vec_train.pkl'
def load_pkl(file_path):
    """Unpickle ``file_path`` and print the shape and type of its content.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        The unpickled object. It must expose a ``.shape`` attribute, because
        the shape is printed before returning.
    """
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(file_path, 'rb') as handle:
        loaded = pickle.load(handle)

    print("Loaded data shape: {}".format(loaded.shape))
    print("Loaded data type: {}".format(type(loaded)))
    return loaded

# Load the saved features into `data`; load_pkl also prints their shape and type.
data = load_pkl(file_path)

Loaded data shape: (15084, 768)
Loaded data type: <class 'numpy.ndarray'>
