In [1]:
import os
import inspect
from math import floor
from tqdm import tqdm

import torch
import numpy as np
import librosa
from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification

In [None]:
# Load the HuBERT emotion-recognition checkpoint once; later cells reuse `model`.
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-er")
model.eval()  # inference only: make the eval (no-dropout) mode explicit

# Printing the library's forward() source was an exploration step used to decide
# which submodule to replace; keep it available behind a flag instead of dumping
# it into the notebook output on every run.
SHOW_FORWARD_SOURCE = False
if SHOW_FORWARD_SOURCE:
    print(inspect.getsource(model.forward))

In [3]:
class IdentityModule(torch.nn.Module):
    """No-op layer whose ``forward`` hands its input back untouched.

    Installed in place of a model head so the model returns the representation
    that would otherwise have been fed into that head. It has no parameters;
    ``torch.nn.Module.__init__`` is inherited as-is.
    """

    def forward(self, x):
        # Nothing to compute: the input tensor *is* the output.
        return x
    
# Swap the classification head for a pass-through so that `model(...).logits`
# exposes the features that used to feed the head (see the shape check below).
setattr(model, "classifier", IdentityModule())

In [4]:
# Sanity check: embed the first 16000 samples (1 s at 16 kHz) of one clip with
# the head-less model and inspect the output size in the next cell.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-er")
audio, sr = librosa.load('/kaggle/input/audio-abaw5/batch1/batch1/100-29-1080x1920.mp3', sr=16000)

with torch.no_grad():
    inputs = feature_extractor(
        audio[:16000],
        sampling_rate=16000,
        padding=True,
        return_tensors="pt",
    )
    logits = model(**inputs)

preprocessor_config.json:   0%|          | 0.00/213 [00:00<?, ?B/s]

2024-02-11 23:40:15.222780: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-11 23:40:15.222876: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-11 23:40:15.350040: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
# With the head replaced by an identity, this is the per-clip embedding size.
logits.logits.size()

torch.Size([1, 256])

In [6]:
# Dataset layout used by the extraction loop: <data_dir>/<folder>/<folder>/*.mp3
data_dir = '/kaggle/input/audio-abaw5'
folders = ['batch1', 'batch2', 'new_vids']

# Window length in samples: 1.5 s at the 16 kHz rate requested from librosa.load.
step = 24000

# Accumulators filled by the extraction loop (two distinct lists).
names, global_features = [], []

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-er")

In [7]:
def _embed_window(window, sr, extractor, encoder):
    """Embed one audio window and return the first row of ``.logits`` as numpy.

    ``window`` is passed through ``extractor`` (asked for ``"pt"`` tensors) and
    the resulting keyword inputs through ``encoder``, both under ``no_grad``.
    """
    with torch.no_grad():
        inputs = extractor(window, sampling_rate=sr, padding=True,
                           return_tensors="pt")
        output = encoder(**inputs)
    return output.logits[0].detach().numpy()


def process_signal(local_name, audio, step=24000, sr=16000,
                   extractor=None, encoder=None):
    """Split ``audio`` into ``step``-sample windows and embed each one.

    Args:
        local_name: prefix for the window names (e.g. the clip's file stem).
        audio: 1-D sample array (indexed by ``audio.shape[0]`` and slicing).
        step: window length in samples.
        sr: sampling rate forwarded to the feature extractor.
        extractor: feature extractor callable; defaults to the notebook-level
            ``feature_extractor`` (looked up at call time).
        encoder: model callable; defaults to the notebook-level ``model``.

    Returns:
        ``(names, features)``: ``names[i]`` is ``"<local_name>/<i+1 zero-padded
        to 5 digits>"`` and ``features[i]`` is the numpy embedding of window i.
        Full windows come first. If samples are left over, one extra window made
        of the LAST ``step`` samples is added, overlapping the previous window.
        A clip shorter than ``step`` is embedded whole. Empty audio yields
        ``([], [])``.
    """
    # Resolve defaults at call time, not def time, so existing calls keep using
    # whatever the notebook globals currently hold.
    if extractor is None:
        extractor = feature_extractor
    if encoder is None:
        encoder = model

    n_samples = audio.shape[0]
    n_full = n_samples // step

    names = []
    features = []
    for s in range(n_full):
        window = audio[s * step:(s + 1) * step]
        features.append(_embed_window(window, sr, extractor, encoder))
        names.append(f'{local_name}/{s + 1:05d}')

    if n_samples > n_full * step:
        # Clamp at 0: previously the start became a negative index for clips
        # shorter than one window, which embedded only a short tail of the clip.
        start = max(0, n_samples - step)
        features.append(_embed_window(audio[start:], sr, extractor, encoder))
        names.append(f'{local_name}/{n_full + 1:05d}')

    return names, features

In [8]:
# Extract windowed embeddings for every .mp3 in each folder.
# Accumulators are (re)initialised here so re-running this cell does not
# duplicate rows. Rows are collected in a list and stacked once at the end:
# concatenating per file copied the whole array each time, left a plain list
# when only one file was processed, and crashed when a file produced no windows.
names = []
feature_rows = []

for folder in folders:
    dirpath = os.path.join(data_dir, folder, folder)
    print(f'in {folder}')

    # os.listdir order is arbitrary; sort so the row order is reproducible.
    for filename in tqdm(sorted(os.listdir(dirpath))):
        stem, ext = os.path.splitext(filename)
        if ext.lower() != '.mp3':
            print(filename)  # report files that are skipped
            continue

        audio, sr = librosa.load(os.path.join(dirpath, filename), sr=16000)
        window_names, window_feats = process_signal(stem, audio, step, sr=sr)

        names += window_names
        feature_rows.extend(window_feats)

global_features = (np.stack(feature_rows) if feature_rows
                   else np.empty((0, 0), dtype=np.float32))

in batch1


100%|██████████| 475/475 [2:25:47<00:00, 18.42s/it]  


in batch2


100%|██████████| 73/73 [20:26<00:00, 16.80s/it]


in new_vids


100%|██████████| 50/50 [31:15<00:00, 37.51s/it] 


In [9]:
# Every window name must pair with exactly one feature row; the export cell zips
# the two, and zip() would silently truncate on a mismatch.
assert global_features.shape[0] == len(names), (global_features.shape, len(names))
global_features.shape, len(names)

((73460, 256), 73460)

In [10]:
import pickle

# Map each window name ("<clip stem>/<5-digit index>") to its embedding row.
# Names do not include the source folder, so clips with the same stem in two
# folders would collide and the dict would silently drop rows. Check first.
assert len(names) == len(global_features), (len(names), len(global_features))
filename2featuresAll = dict(zip(names, global_features))
assert len(filename2featuresAll) == len(names), (
    f"{len(names) - len(filename2featuresAll)} duplicate window names; "
    "rows would be lost")

with open('wav2vec_hubert_fea.pickle', 'wb') as handle:
    pickle.dump(filename2featuresAll, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [11]:
from IPython.display import FileLink

# Render a clickable download link for the exported feature pickle.
FileLink(path='wav2vec_hubert_fea.pickle')