In [None]:
import sys
sys.path.append("./..")
from config import config
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path
import os, sys
from glob import glob
import shutil
import json
import librosa
import torch
import transformers

from modelsage import *

In [None]:
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)

In [None]:
# Load the age/gender model and its matching processor via from_pretrained.
# NOTE(review): the 'cude_device' key looks like a typo of 'cuda_device';
# confirm against the config module before renaming it anywhere.
model_name = config["age_gender_model"]
device = torch.device(config['cude_device'])

processor = Wav2Vec2Processor.from_pretrained(model_name)

# The model is moved to the configured device right after loading.
age_gender_net = AgeGenderModel.from_pretrained(model_name)
model = age_gender_net.to(device)

In [None]:
# Read the segment metadata (indexed below as all_data[podcast][segment]).
json_path = config['json_path']
with open(json_path, 'r') as json_file:
    all_data = json.load(json_file)

In [None]:
# Map each WAV in the podcast directory to its path, keyed by the file name
# up to the first '.'. Paths are visited in sorted order, so if two files
# share a key the later one wins.
audio_paths = {
    os.path.basename(wav_path).partition('.')[0]: wav_path
    for wav_path in sorted(glob(os.path.join(config['podcast']['path'], "*.wav")))
}

In [None]:
class ProcessPipeline(torch.utils.data.Dataset):
    """Dataset over the podcast segments that still lack an 'age_gender' entry.

    Each item is the waveform of one segment, loaded lazily with librosa.

    Args:
        all_data: nested mapping ``all_data[podcast][segment]`` whose values
            hold at least ``'start'`` and ``'end'`` (forwarded to
            ``librosa.load`` as offset/duration, i.e. seconds).
        audios: mapping from podcast name to the path of its audio file.
        max_duration: segments longer than this are skipped (default 250,
            the previously hard-coded cutoff).
        sample_rate: rate the audio is resampled to on load (default 16000,
            the previously hard-coded value).
    """

    def __init__(self, all_data, audios, max_duration=250, sample_rate=16000):
        print("Organizing the data")
        self.X = []
        self.segments = []  # never populated here; kept so the attribute still exists
        self.audios = audios
        self.data = all_data
        self.max_duration = max_duration
        self.sample_rate = sample_rate

        for pod_name in tqdm(list(all_data.keys())):
            for seg_name, seg in all_data[pod_name].items():
                # Already scored: nothing to do for this segment.
                if 'age_gender' in seg:
                    continue
                duration = seg['end'] - seg['start']
                if duration > max_duration:
                    continue
                self.X.append([duration, pod_name, seg_name])

        # Ascending by [duration, podcast, segment] so that each batch groups
        # segments of similar length and needs little padding.
        self.X.sort()

    def __len__(self):
        """Number of segments waiting to be processed."""
        return len(self.X)

    def __getitem__(self, index):
        """Return ``(podcast, seg_name, wav)`` for the index-th pending segment.

        ``wav`` is whatever ``librosa.load`` returns for the segment's
        ``[start, end)`` window at ``self.sample_rate``.
        """
        _, podcast, seg_name = self.X[index]
        seg = self.data[podcast][seg_name]

        start = seg['start']
        wav, _ = librosa.load(
            self.audios[podcast],
            offset=start,
            duration=seg['end'] - start,
            sr=self.sample_rate,
        )
        return podcast, seg_name, wav


In [None]:
def colate_fun(x):
    """Collate dataset samples into a zero-padded float32 batch.

    Args:
        x: non-empty sequence of ``(podcast, seg_name, wav)`` samples, where
            ``wav`` is a 1-D sequence of audio samples.

    Returns:
        A tuple ``(data, aud, lengths)``:
        ``data`` is a list of ``[podcast, seg_name, wav_as_float32_array]``,
        ``aud`` is a ``(len(x), max_len)`` float32 array with every waveform
        left-aligned and zero-padded on the right, and ``lengths`` holds the
        unpadded length of each waveform.

    Raises:
        ValueError: if ``x`` is empty (there is no length to pad to).
    """
    if len(x) == 0:
        raise ValueError("colate_fun() received an empty batch")

    data = [
        [podcast, seg_name, np.asarray(wav, dtype=np.float32)]
        for podcast, seg_name, wav in x
    ]
    lengths = [len(entry[2]) for entry in data]

    # Allocate the padded batch directly as float32 instead of building a
    # float64 array and converting it afterwards.
    aud = np.zeros((len(data), max(lengths)), dtype=np.float32)
    for row, entry in zip(aud, data):
        row[:len(entry[2])] = entry[2]
    return data, aud, lengths

In [None]:
# Build the dataset of unscored segments and a loader that serves them in
# dataset order (no shuffling), padded per batch by colate_fun.
process_data = ProcessPipeline(all_data, audio_paths)
process_loader = torch.utils.data.DataLoader(
    process_data,
    batch_size=32,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=colate_fun,
)

# Score every batch and attach the result to its segment in all_data.
for batch_info, batch_audio, lengths in tqdm(process_loader):
    batch_scores = process_func(batch_audio, 16000, processor, model, device)

    for (podcast, seg_name, seg_wav), output in zip(batch_info, batch_scores):
        # The first score is multiplied by 100 and kept to one decimal;
        # the other three are kept to two decimals.
        all_data[podcast][seg_name]['age_gender'] = {
            'Age': float(round(output[0]*100, 1)),
            'Female': float(round(output[1], 2)),
            'Male': float(round(output[2], 2)),
            'Child': float(round(output[3], 2)),
        }

In [None]:
# Persist the annotated metadata back to config['json_path'].
# Opening that file directly with 'w' would truncate it before json.dump
# runs, so a serialization error would destroy the existing data. Write to a
# temporary sibling file first and swap it in only once the dump succeeded.
tmp_json_path = f"{config['json_path']}.tmp"
with open(tmp_json_path, 'w') as fp:
    json.dump(all_data, fp)
os.replace(tmp_json_path, config['json_path'])