In [None]:
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, default_collate
from tqdm import tqdm

from utils import setup_seed


# Global PyTorch configuration for the whole notebook.

# Make float32 the default floating-point dtype. `torch.set_default_dtype` is
# the supported replacement for the deprecated
# `torch.set_default_tensor_type(torch.FloatTensor)`.
torch.set_default_dtype(torch.float32)

# Autograd anomaly detection is a debugging aid for backward passes.
# NOTE(review): the inference loop in this notebook runs under torch.no_grad(),
# so this setting is probably unnecessary here -- confirm before removing.
torch.autograd.set_detect_anomaly(True)

# Raise an error whenever an operation without a deterministic implementation
# is used. NOTE(review): on CUDA some cuBLAS operations additionally require
# the CUBLAS_WORKSPACE_CONFIG environment variable -- confirm on the target GPU.
torch.use_deterministic_algorithms(True)

In [None]:
def padding(spec, ref_len):
    """Right-pad a 2-D tensor with zeros along its second dimension.

    Args:
        spec: 2-D tensor of shape ``(width, cur_len)``.
        ref_len: Target size of the second dimension. Must be strictly
            greater than ``cur_len``.

    Returns:
        A new tensor of shape ``(width, ref_len)`` whose first ``cur_len``
        columns are ``spec`` and whose remaining columns are zero. The result
        has the same dtype and device as ``spec``.

    Raises:
        AssertionError: If ``ref_len`` is not strictly greater than ``cur_len``.
    """
    width, cur_len = spec.shape
    assert ref_len > cur_len, (
        "ref_len (%s) must be greater than the current length (%s)" % (ref_len, cur_len)
    )
    padd_len = ref_len - cur_len
    # new_zeros inherits dtype *and device* from `spec`; a plain torch.zeros
    # would live on the default device and make torch.cat fail for tensors
    # that are stored elsewhere (e.g. on a GPU).
    return torch.cat((spec, spec.new_zeros((width, padd_len))), 1)

def repeat_padding(spec, ref_len):
    """Extend a 2-D tensor along dim 1 by tiling it, then crop to ``ref_len``.

    Args:
        spec: Tensor indexed as ``(width, cur_len)``.
        ref_len: Number of columns the result is cropped to.

    Returns:
        The tiled tensor restricted to its first ``ref_len`` columns.
    """
    cur_len = spec.shape[1]
    # Smallest number of whole copies that covers ref_len columns.
    n_copies = int(np.ceil(ref_len / cur_len))
    tiled = spec.repeat(1, n_copies)
    return tiled[:, :ref_len]

class SafeSpeakTest(Dataset):
    """Inference dataset that yields fixed-length waveforms from a directory.

    Every regular file directly inside ``path_to_audio`` is treated as one
    clip. Each loaded waveform is cropped or padded along dim 1 so that it has
    exactly ``feat_len`` entries there.

    NOTE(review): ``torchaudio_load`` is neither defined nor imported in the
    visible notebook cells; presumably it is ``torchaudio.load`` -- confirm and
    bring it into scope before running.

    Args:
        path_to_audio: Directory that contains the audio files.
        feat_len: Target length of dim 1 of every returned waveform.
        padding: How short clips are extended, either ``'zero'`` (append
            zeros) or ``'repeat'`` (tile the clip).

    Raises:
        ValueError: If ``padding`` is neither ``'zero'`` nor ``'repeat'``.
    """

    def __init__(self, path_to_audio="./SafeSpeak-2024/kaggle_data/wavs",
                 feat_len=128000, padding='repeat'):
        super().__init__()
        if padding not in ('zero', 'repeat'):
            raise ValueError('Padding should be zero or repeat!')
        self.path_to_audio = path_to_audio
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories; sort for a reproducible order and keep files only.
        self.files = sorted(
            name for name in os.listdir(self.path_to_audio)
            if os.path.isfile(os.path.join(self.path_to_audio, name))
        )
        self.filepaths = [os.path.join(self.path_to_audio, name) for name in self.files]
        self.feat_len = feat_len
        self.padding = padding

    def __len__(self):
        """Return the number of audio files found in the directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(waveform, file_name)``.

        The waveform is cropped or padded to ``feat_len`` along dim 1 and then
        has dim 0 squeezed away.
        """
        # The second value returned by the loader is not used here.
        wav, _sr = torchaudio_load(self.filepaths[idx])
        cur_len = wav.shape[1]

        if cur_len > self.feat_len:
            wav = wav[:, :self.feat_len]
        elif cur_len < self.feat_len:
            if self.padding == 'zero':
                wav = padding(wav, self.feat_len)
            elif self.padding == 'repeat':
                wav = repeat_padding(wav, self.feat_len)
            else:
                # Reachable only if the attribute was changed after construction.
                raise ValueError('Padding should be zero or repeat!')

        # squeeze(0) drops dim 0 only when it has size 1.
        # NOTE(review): assumes single-channel audio of shape (1, T); a
        # multi-channel clip would stay 2-D -- confirm against the data.
        return wav.squeeze(0), self.files[idx]

    def collate_fn(self, samples):
        """Batch ``(waveform, file_name)`` samples with the default collate."""
        return default_collate(samples)

In [None]:
def test_model(model_path='./ocnet_finetune.pt', batch_size=None, num_workers=None,
               output_path='result.txt'):
    """Score every clip of ``SafeSpeakTest`` and write the scores to a text file.

    One line ``"<file_name> <score>"`` is written per clip. When the model
    output is at least 2-D with more than one entry in its last dimension, the
    score is column 0 of the softmax over that dimension; otherwise the raw
    model output is written.

    Args:
        model_path: Path of the serialized model passed to ``torch.load``.
        batch_size: DataLoader batch size. When ``None``, a module-level
            ``args.batch_size`` is used if such an object exists, else 32.
        num_workers: DataLoader worker count. When ``None``, a module-level
            ``args.num_workers`` is used if such an object exists, else 0.
        output_path: Destination of the score file.
    """
    # The original code read an `args` global that the visible cells never
    # define; honour it when present, otherwise fall back to plain defaults.
    # NOTE(review): 32 and 0 are placeholder defaults -- adjust to the hardware.
    cli_args = globals().get('args')
    if batch_size is None:
        batch_size = getattr(cli_args, 'batch_size', 32)
    if num_workers is None:
        num_workers = getattr(cli_args, 'num_workers', 0)

    # Resolve the device first so that loading also works on CPU-only hosts.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only load checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # may reject a fully pickled model -- confirm against the installed version.
    model = torch.load(model_path, map_location=device)
    model = model.to(device)
    model.eval()

    test_set = SafeSpeakTest()
    testDataLoader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers,
                                collate_fn=test_set.collate_fn)

    with open(output_path, 'w') as cm_score_file, torch.no_grad():
        for wave, audio_fn in tqdm(testDataLoader):
            wave = wave.float().to(device)
            score = model(wave).float()

            # Multi-column output: keep column 0 of the softmax. `dim` is given
            # explicitly because the implicit choice is deprecated. A 1-D
            # output is treated as one raw score per clip.
            if score.dim() > 1 and score.size(-1) > 1:
                score = F.softmax(score, dim=-1)[:, 0].float()

            for j, file_name in enumerate(audio_fn):
                cm_score_file.write('%s %s\n' % (file_name, score[j].item()))

In [None]:
# Run the inference routine; it writes one "<file_name> <score>" line per clip
# to 'result.txt' in the current working directory.
test_model()