In [1]:
import os
import glob
import torch
import numpy as np
import kaldiio
from torch.utils.data import Dataset

In [12]:
class AudioFeatureDataset(Dataset):
    """Map-style dataset returning one Kaldi feature matrix per How2 sample id.

    On construction the dataset reads the list of sample ids for the requested
    split and builds an ``id -> ark location`` dictionary from the ``.scp``
    index files found in the feature directory. Indexing the dataset loads the
    matrix for the corresponding id with ``kaldiio.load_mat``.

    Attributes:
        HOW2_PATH: Dataset root exactly as passed by the caller.
        BASE_PATH: ``<root>/how2-300h-v1/data/`` (always ends with a separator).
        AUDIO_PATH: ``<root>/fbank_pitch_181506/`` (always ends with a separator).
        ids: Sample ids of the selected split, in id-file order.
        mapping: Dict from sample id to the ark location string.
    """

    def __init__(self, HOW2_PATH, subfolder):
        """Read the id list of ``subfolder`` and index the feature files.

        Args:
            HOW2_PATH: Root directory of the How2 dataset. A trailing path
                separator is optional.
            subfolder: Split to load; one of ``'train'``, ``'val'``, ``'dev5'``.

        Raises:
            ValueError: If ``subfolder`` is not one of the supported splits, or
                if an ``.scp`` file contains a malformed line.
            OSError: If the id file or an ``.scp`` file cannot be opened.
        """
        if subfolder not in ['train', 'val', 'dev5']:
            raise ValueError('subfolder must be either train, val or dev5')

        # Set relative paths for 300h.
        # os.path.join copes with a root given with or without a trailing
        # separator; the final '' component keeps the trailing separator that
        # the rest of the class (and any external user of these attributes)
        # relies on when concatenating file names.
        self.HOW2_PATH = HOW2_PATH
        self.BASE_PATH = os.path.join(HOW2_PATH, 'how2-300h-v1', 'data', '')
        self.AUDIO_PATH = os.path.join(HOW2_PATH, 'fbank_pitch_181506', '')

        # Read id file
        self.ids = self.get_ids(os.path.join(self.BASE_PATH, subfolder, 'id'))

        # Map id-audio
        self.mapping = self.make_dict()

    def get_ids(self, id_file):
        """Return the sample ids listed in ``id_file``, one per line.

        Leading/trailing whitespace of the whole file is ignored. An empty (or
        whitespace-only) file yields an empty list rather than ``['']``, so the
        dataset length is 0 instead of a single unusable entry.
        """
        with open(id_file) as f:
            content = f.read().strip()
        return content.split('\n') if content else []

    def make_dict(self):
        """Build the ``id -> ark location`` dictionary from the raw ``.scp`` files.

        Only ``.scp`` files whose *file name* contains ``'raw'`` are used; the
        check is done on the base name so that a directory called e.g.
        ``raw_data`` in the dataset path does not make every ``.scp`` file match.
        Each non-blank line is ``<id> <location>``; the ``ARK_PATH/`` placeholder
        at the start of the location is replaced by ``AUDIO_PATH``.

        Raises:
            ValueError: If a non-blank line does not contain both an id and a
                location.
        """
        # Sorted so that, should an id occur in several files, the entry that
        # wins does not depend on the file-system listing order.
        all_scp = sorted(
            path for path in glob.glob(self.AUDIO_PATH + '*.scp')
            if 'raw' in os.path.basename(path)
        )
        mapping = dict()

        for scpfile in all_scp:
            with open(scpfile) as f:
                for lineno, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        # Tolerate blank lines (e.g. a trailing empty line).
                        continue
                    # Split on the first whitespace run only: everything after
                    # the id is the location, which may itself contain spaces.
                    fields = line.split(None, 1)
                    if len(fields) != 2:
                        raise ValueError(
                            f'{scpfile}:{lineno}: expected "<id> <path>", got {line!r}'
                        )
                    video_id, audio_path = fields
                    audio_path = audio_path.replace('ARK_PATH/', self.AUDIO_PATH, 1)
                    mapping[video_id] = audio_path
        return mapping

    def __len__(self):
        """Number of sample ids in the selected split."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Load and return the feature matrix of the ``idx``-th sample id.

        Raises:
            IndexError: If ``idx`` is out of range for ``ids``.
            KeyError: If the id has no entry in ``mapping``.
        """
        sample_id = self.ids[idx]
        sample_ark = self.mapping[sample_id]
        return kaldiio.load_mat(sample_ark)

In [13]:
# Root directory of the How2 dataset; it is expected to contain the
# 'how2-300h-v1/data/' and 'fbank_pitch_181506/' sub-directories.
HOW2_PATH = './data/how2-dataset/'

# Training split: the constructor reads the split's id file and indexes the .scp files.
d = AudioFeatureDataset(HOW2_PATH, 'train')

In [14]:
# Smoke test: load the feature matrix of the first training id.
d[0]

array([[ 1.33435392e+01,  1.34493484e+01,  1.15592442e+01, ...,
         6.57519177e-02, -2.48693824e-02,  4.65756506e-02],
       [ 1.43025455e+01,  1.60679817e+01,  1.69987659e+01, ...,
        -1.93527699e-01, -2.48693824e-02,  1.37654662e-01],
       [ 1.52859964e+01,  1.67500744e+01,  1.98471756e+01, ...,
        -2.42214203e-01,  4.43451107e-02,  1.50330648e-01],
       ...,
       [ 1.24373922e+01,  1.46604671e+01,  1.52252274e+01, ...,
         1.04104526e-01,  7.18828619e-01,  3.59221399e-02],
       [ 1.22473936e+01,  1.43004045e+01,  1.46340475e+01, ...,
         1.27116099e-01,  7.18828619e-01,  1.46151185e-02],
       [ 1.34896917e+01,  1.40385408e+01,  1.35591764e+01, ...,
         9.38771591e-02,  7.18828619e-01,  1.35497674e-02]], dtype=float32)

In [15]:
# Number of sample ids read from the train split's id file.
print(len(d.ids))

184949


In [16]:
# Validation split, built from the 'val' id file under the same dataset root.
dt = AudioFeatureDataset(HOW2_PATH, 'val')

In [17]:
# Smoke test: load the feature matrix of the first validation id.
dt[0]

array([[ 1.19268036e+01,  1.26650295e+01,  1.38440723e+01, ...,
        -4.81955916e-01,  8.70175362e-02,  3.98578793e-02],
       [ 1.14035606e+01,  1.15549412e+01,  1.14114571e+01, ...,
        -7.07235694e-01,  6.60834312e-02,  2.88045555e-02],
       [ 1.19495325e+01,  1.18188782e+01,  1.28029146e+01, ...,
        -1.00062466e+00,  7.65504837e-02,  9.19905752e-02],
       ...,
       [ 1.29609966e+01,  1.61369553e+01,  1.65433693e+01, ...,
        -7.71601379e-01, -6.99882507e-02,  1.88073963e-01],
       [ 1.28473492e+01,  1.63066597e+01,  1.69675446e+01, ...,
        -6.42870069e-01, -2.81200409e-02,  1.94079176e-01],
       [ 1.32132759e+01,  1.58654318e+01,  1.69289818e+01, ...,
        -5.20575285e-01,  1.37481689e-02,  1.22016631e-01]], dtype=float32)

In [18]:
# Size of the validation split (__len__ returns the number of ids).
len(dt)

2022