In [2]:
import numpy as np 
import pandas as pd
from fastai.data_block import get_files
from fastprogress import master_bar, progress_bar
import tensorflow as tf
import torch 
import os
import pickle

In [3]:
IN_PATH  = 'yt8m/3/frame/validate'
OUT_PATH = 'yt8m-processed'

files = [str(f) for f in get_files(IN_PATH, '.tfrecord')]
files

['yt8m/3/frame/validate/validate1032.tfrecord',
 'yt8m/3/frame/validate/validate2324.tfrecord',
 'yt8m/3/frame/validate/validate2351.tfrecord',
 'yt8m/3/frame/validate/validate3253.tfrecord',
 'yt8m/3/frame/validate/validate2328.tfrecord',
 'yt8m/3/frame/validate/validate3697.tfrecord',
 'yt8m/3/frame/validate/validate1343.tfrecord',
 'yt8m/3/frame/validate/validate0468.tfrecord',
 'yt8m/3/frame/validate/validate2820.tfrecord',
 'yt8m/3/frame/validate/validate1076.tfrecord',
 'yt8m/3/frame/validate/validate0155.tfrecord',
 'yt8m/3/frame/validate/validate1981.tfrecord',
 'yt8m/3/frame/validate/validate2442.tfrecord',
 'yt8m/3/frame/validate/validate1389.tfrecord',
 'yt8m/3/frame/validate/validate1376.tfrecord',
 'yt8m/3/frame/validate/validate2468.tfrecord',
 'yt8m/3/frame/validate/validate1112.tfrecord',
 'yt8m/3/frame/validate/validate0205.tfrecord',
 'yt8m/3/frame/validate/validate2948.tfrecord',
 'yt8m/3/frame/validate/validate0033.tfrecord',
 'yt8m/3/frame/validate/validate1497.tfr

In [4]:
raw_dataset = tf.data.TFRecordDataset(files)

In [50]:
A_SZ, V_SZ, L_SZ = 128,1024,303
SZ= int(2e7)

a = torch.empty((SZ,A_SZ), dtype=torch.uint8) # audio
v = torch.empty((SZ,V_SZ), dtype=torch.uint8) # video
d = dict()

In [51]:
def mmap_bytetensor(f, shape, shared=False,size=None):
    return torch.ByteTensor(
        torch.ByteStorage.from_file(f, shared=shared, size=size if size is not None else os.path.getsize(f))).view(*shape)

In [58]:
i_f = i_av = 0
TOTAL = 47073
for i,raw_record in enumerate(progress_bar(raw_dataset.take(-1), total=TOTAL)):
        
    example = tf.train.SequenceExample()
    example.ParseFromString(raw_record.numpy())
    labels = list(example.context.feature['labels'].int64_list.value)
    vid    = str(example.context.feature['id'].bytes_list.value[0].decode(encoding='UTF-8'))
    audio  = torch.ByteTensor(
        [np.frombuffer(f.bytes_list.value[0], dtype=np.uint8) for f in example.feature_lists.feature_list['audio'].feature])
    video  = torch.ByteTensor(
        [np.frombuffer(f.bytes_list.value[0], dtype=np.uint8) for f in example.feature_lists.feature_list['rgb'].feature])
    frames = len(audio)
    segments = list(zip(*[getattr(example.context.feature[f],t).value for f,t in [('segment_labels','int64_list'), 
        ('segment_scores','float_list'), ('segment_start_times','int64_list'),('segment_end_times','int64_list')]]))

    assert frames == len(video)
    assert frames <= L_SZ
    d[vid] = (i_av, frames, labels, segments)
    a[i_av:i_av+frames] = audio
    v[i_av:i_av+frames] = video
    del audio, video, labels, segments
    i_av += frames

    if ((i_av + L_SZ)>= SZ) or (i==TOTAL-1):
        print(f'Saving {i_f} at {i}')
        aw = mmap_bytetensor(f'{OUT_PATH}/val_a{i_f}.bin',(-1,A_SZ),  shared=True, size=i_av*A_SZ)
        vw = mmap_bytetensor(f'{OUT_PATH}/val_v{i_f}.bin',(-1,V_SZ),  shared=True, size=i_av*V_SZ)
        
        aw[:] = a[:i_av]
        vw[:] = v[:i_av]
        with open(f'{OUT_PATH}/val_d{i_f}.pkl', 'wb') as handle:
            pickle.dump(d, handle, protocol=pickle.HIGHEST_PROTOCOL)
        d = dict()
        i_av = 0
        i_f += 1
        del aw,vw

    if i == 2:
        break

    


In [60]:
d

{'vgqO': (0,
  188,
  [146, 158, 439, 454],
  [(279, 0.0, 115, 120),
   (279, 0.0, 145, 150),
   (279, 0.0, 90, 95),
   (279, 0.0, 150, 155),
   (279, 0.0, 55, 60)]),
 'O6qO': (188,
  152,
  [6, 8, 414],
  [(414, 0.0, 105, 110),
   (414, 1.0, 35, 40),
   (414, 1.0, 100, 105),
   (414, 1.0, 65, 70),
   (414, 0.0, 80, 85)]),
 'OLqO': (340,
  228,
  [2, 44, 46, 107, 109, 114, 162, 854, 1164],
  [(854, 0.0, 180, 185),
   (1164, 0.0, 185, 190),
   (854, 0.0, 30, 35),
   (1164, 1.0, 75, 80),
   (1164, 1.0, 65, 70),
   (1164, 1.0, 40, 45),
   (854, 0.0, 130, 135),
   (1164, 0.0, 140, 145),
   (854, 0.0, 185, 190)])}

[(279, 0.0, 115, 120),
 (279, 0.0, 145, 150),
 (279, 0.0, 90, 95),
 (279, 0.0, 150, 155),
 (279, 0.0, 55, 60)]

In [None]:
ifnone(v.numel(), 10)

In [None]:
for i in progress_bar(range(1,89)): # 89
    a = torch.load(f'{OUT_PATH}/a{i}.pth')
    aw = mmap_bytetensor(f'{OUT_PATH}/a{i}.bin',(-1,128),  shared=True, size=a.numel())
    aw[:] = a
    v = torch.load(f'{OUT_PATH}/v{i}.pth')
    vw = mmap_bytetensor(f'{OUT_PATH}/v{i}.bin',(-1,1024), shared=True, size=v.numel())
    vw[:] = v
    os.remove(f'{OUT_PATH}/a{i}.pth')
    os.remove(f'{OUT_PATH}/v{i}.pth')

In [None]:
vvw=read_bytetensor('yt8m-processed/v0.bin', (-1,1024))

In [None]:
n = 89
[read_bytetensor(f'yt8m-processed/v{i}.bin', (-1,1024)) for i in range(n)]