In [1]:
!pip install wav2vec

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import numpy as np
import pandas as pd
import transformers
from transformers import BertTokenizer, BertModel, AutoModel, AutoProcessor
import torch, torchaudio, torchtext
import wav2vec
import os
import warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [3]:
try:
    df_path = '/kaggle/input/multimodal-argument-mining/MM-USElecDeb60to16/MM-USElecDeb60to16.csv'
    df = pd.read_csv(df_path, index_col=0)
except FileNotFoundError:
    df_path = 'multimodal-dataset/files/MM-USElecDeb60to16/MM-USElecDeb60to16.csv'
    df = pd.read_csv(df_path, index_col=0)
    
train_df = df[df['Set'] == 'TRAIN']
val_df = df[df['Set'] == 'VALIDATION']
test_df = df[df['Set'] == 'TEST']

In [4]:
df.head()

Unnamed: 0,Text,Part,Document,Order,Sentence,Start,End,Annotator,Tag,Component,...,Speaker,SpeakerType,Set,Date,Year,Name,MainTag,NewBegin,NewEnd,idClip
0,"CHENEY: Gwen, I want to thank you, and I want ...",1,30_2004,0,0,2101,2221,,"{""O"": 27}",O,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,O,126.52,131.08,clip_0
1,"It's a very important event, and they've done ...",1,30_2004,1,1,2221,2304,,"{""O"": 19}",O,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,O,131.08,134.4,clip_1
2,It's important to look at all of our developme...,1,30_2004,2,2,2304,2418,,"{""O"": 23}",O,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,O,134.4,140.56,clip_2
3,"And, after 9/11, it became clear that we had t...",1,30_2004,3,3,2418,2744,,"{""O"": 16, ""Claim"": 50}",Claim,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,Claim,140.56,158.92,clip_3
4,And we also then finally had to stand up democ...,1,30_2004,4,4,2744,2974,,"{""O"": 4, ""Claim"": 13, ""Premise"": 25}",Premise,...,CHENEY,Candidate,TRAIN,05 Oct 2004,2004,Richard(Dick) B. Cheney,Mixed,158.92,172.92,clip_4


In [18]:
text_model_card = 'bert-base-uncased'
audio_model_card = 'facebook/wav2vec2-base-960h'

class MM_Dataset(torch.utils.data.Dataset):
    def __init__(self, df, audio_dir, sample_rate):
        self.audio_dir = audio_dir
        self.sample_rate = sample_rate

        self.audio_processor = AutoProcessor.from_pretrained(audio_model_card)
        self.audio_model = AutoModel.from_pretrained(audio_model_card).to(device)
        
        self.text_tokenizer = BertTokenizer.from_pretrained(text_model_card)
        self.text_model = BertModel.from_pretrained(text_model_card).to(device)
        self.dataset = []

        # Iterate over df
        for _, row in df.iterrows():
            path = os.path.join(self.audio_dir, f"{row['Document']}/{row['idClip']}.wav")
            if os.path.exists(path):
                # obtain audio features
                audio, sampling_rate = torchaudio.load(path)
                if sampling_rate != self.sample_rate:
                    audio = torchaudio.functional.resample(audio, sample_rate, self.sample_rate)
                with torch.inference_mode():
                    input_values = self.audio_processor(audio, sample_rate=self.sample_rate).input_values[0]
                    input_values = torch.tensor(input_values).to(device)
                    audio_model_output = self.audio_model(input_values)
                    audio_features = audio_model_output.last_hidden_state
                    # audio_features = audio_features[:, 0, :]

                # obtain BERT features
                text = row['Text']
                token_text = self.text_tokenizer(text, return_tensors='pt').to(device)
                with torch.inference_mode():
                    text_features = self.text_model(**token_text, return_dict=True)  
                    text_features = text_features.last_hidden_state[0]
                # append to dataset  
                self.dataset.append((text_features, audio_features, row['Component']))
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index]

In [19]:
test_df = df[:1]

audio_path = 'multimodal-dataset/files/MM-USElecDeb60to16/audio_clips'

audio_dataset = MM_Dataset(test_df, audio_path, 16000)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


In [20]:
for text_f, audio_f, label in audio_dataset:
    print(text_f.shape, audio_f.shape, label)
    break

torch.Size([29, 768]) torch.Size([2, 683, 768]) O
