In [None]:
import argparse
import os

import librosa
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2Tokenizer
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor,Wav2Vec2Config

import raw_dataset as dataset

In [None]:
# Dataset loading: audio reading helper, fixed-length padding, and the SafeSpeak test Dataset

In [None]:
def torchaudio_load(filepath):
    """Read an audio file via librosa and wrap it as a tensor with a leading axis.

    Args:
        filepath: Path of the audio file, handed straight to ``librosa.load``.

    Returns:
        A two-element list ``[waveform, sample_rate]``. ``waveform`` is a
        ``torch.Tensor`` built from the loaded samples with one extra axis
        inserted at position 0; ``sample_rate`` is whatever ``librosa.load``
        reports for the requested ``sr=16000``.
    """
    samples, sample_rate = librosa.load(filepath, sr=16000)
    # Prepend a singleton axis so downstream code can squeeze(0) it away again.
    with_leading_axis = samples[np.newaxis, ...]
    return [torch.Tensor(with_leading_axis), sample_rate]

def pad_dataset(wav, cut=64600):
    """Crop or repeat-pad a waveform so it is exactly ``cut`` samples long.

    Args:
        wav: Waveform tensor. A leading singleton axis (e.g. shape ``(1, N)``)
            is squeezed away before processing.
        cut: Target number of samples. Defaults to 64600, the length this
            notebook originally hard-coded.

    Returns:
        A tensor of ``cut`` samples: the first ``cut`` samples when the input
        is long enough, otherwise the input repeated end-to-end and truncated.

    Raises:
        ValueError: If the waveform has no samples, since there is nothing to
            repeat (previously this surfaced as a ZeroDivisionError).
    """
    waveform = wav.squeeze(0)
    waveform_len = waveform.shape[0]
    if waveform_len == 0:
        raise ValueError("pad_dataset: cannot pad an empty waveform")
    if waveform_len >= cut:
        return waveform[:cut]
    # Too short: tile enough whole copies to cover `cut`, then trim the excess.
    num_repeats = cut // waveform_len + 1
    padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0]
    return padded_waveform

class SafeSpeakTest(Dataset):
    """Dataset over the audio files found directly inside one directory.

    Each item is ``(waveform, filename)`` where ``waveform`` is the first
    element returned by ``torchaudio_load`` for that file and ``filename`` is
    the bare file name (no directory part).

    Args:
        path_to_audio: Directory that holds the audio files.
    """

    def __init__(self, path_to_audio="./SafeSpeak-2024/kaggle_data/wavs"):
        super().__init__()
        self.path_to_audio = path_to_audio
        # os.listdir returns entries in arbitrary order; sort them so runs are
        # reproducible, and keep regular files only so stray sub-directories
        # (e.g. .ipynb_checkpoints) are not handed to the audio loader.
        self.files = sorted(
            name
            for name in os.listdir(self.path_to_audio)
            if os.path.isfile(os.path.join(self.path_to_audio, name))
        )
        self.filepaths = [os.path.join(self.path_to_audio, name) for name in self.files]

    def __len__(self):
        """Return the number of audio files discovered at construction time."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return ``(waveform, filename)``."""
        waveform, _ = torchaudio_load(self.filepaths[idx])
        return waveform, self.files[idx]

    def collate_fn(self, samples):
        """Batch ``samples`` with PyTorch's ``default_collate``."""
        return default_collate(samples)

In [None]:
# Score every file in the SafeSpeak test set and write "<filename> <score>"
# lines to result.txt.
feat_model_path = './cotrain_finetune.pt'

# Single source of truth for placement: everything below follows `device`, so
# the cell also runs on CPU-only machines instead of crashing on `.cuda()`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints from a trusted source. Recent PyTorch releases may also require
# an explicit `weights_only=False` to load a fully pickled model; confirm
# against the installed version.
# map_location remaps storages saved on a GPU onto the device chosen above.
ADD_model = torch.load(feat_model_path, map_location=device)
ADD_model.eval()

processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-300m").to(device)

# Needed so the forward pass exposes `.hidden_states` (indexed below).
model.config.output_hidden_states = True

with open('result.txt', 'w') as cm_score_file:
    safespeak_raw = SafeSpeakTest()
    for idx in tqdm(range(len(safespeak_raw))):
        waveform, filename = safespeak_raw[idx]
        # The feature extractor runs on CPU data, so pad on CPU and only move
        # the extracted input values to the compute device.
        waveform = pad_dataset(waveform)
        input_values = processor(waveform, sampling_rate=16000,
                                 return_tensors="pt").input_values.to(device)
        # Pure inference: keep both models under no_grad so no autograd graph
        # is built or retained across iterations.
        with torch.no_grad():
            w2v2 = model(input_values).hidden_states[5]
            # Add a leading axis and swap the last two axes before handing the
            # features to ADD_model (layout expected by ADD_model -- TODO confirm).
            w2v2 = w2v2.unsqueeze(dim=0).transpose(2, 3)
            feats, w2v2_outputs = ADD_model(w2v2)
            # dim=1 spells out what the implicit-dim softmax picked for a 2-D
            # input; assumes w2v2_outputs is (batch, classes) -- TODO confirm.
            score = F.softmax(w2v2_outputs, dim=1)[:, 0]
        cm_score_file.write('%s %s \n' % (filename, score.item()))