In [4]:
# import torch
# from torch.utils.data import Dataset, DataLoader
import pandas as pd
from pycochleagram.utils import wav_to_array
from pycochleagram.cochleagram import human_cochleagram
import cv2
import numpy as np
import scipy
import torchvision.transforms as T
import torch
from tqdm import tqdm
from torch.utils.data import Dataset
import pickle

In [2]:
class VISDataset(Dataset):
    def __init__(self, root, coch_root, dataset_file, window_duration=0.5, datum_len=45, transform=T.Compose([T.Resize(256, antialias=False)]),is_eval=False):
        self.root = root
        self.coch_root = coch_root
        
        with open(f'{self.root}/{dataset_file}', 'r') as f:
            self.file_list = [file.strip() for file in f.readlines()]
        
        self.is_eval = is_eval
        self.transform = transform
        
        self.video_fps = 30
        self.window_duration = window_duration
        self.n_frames = int(window_duration * self.video_fps)
        assert datum_len % self.n_frames == 0
        self.n_tiles = datum_len // self.n_frames
        
        self.data = []
        
        for file in tqdm(self.file_list):
            try:    
                vid = cv2.VideoCapture(f'{self.root}/{file}_denoised.mp4')
                frames = []
                while True:
                    ret, frame = vid.read()
                    if not ret:
                        break
                    frames.append(frame)
                vid.release()
                
                wav, sample_rate = wav_to_array(f'{self.root}/{file}_denoised.wav')
                annotations = pd.read_csv(f'{self.root}/{file}_times.txt', sep=' ', names=['Time', 'Material', 'Contact Type', 'Motion Type'])
                
                cochleagrams = scipy.io.loadmat(f'{self.coch_root}/{file}_sf.mat')['sfs']
                
                for ind, row in annotations.iterrows():
                    datum = {}
                    peak_time = row['Time']
                    peak_vid = int(peak_time * self.video_fps)
                    frames_rgb = np.stack(frames[peak_vid-self.n_frames//2:1+peak_vid+self.n_frames//2])
                    frames_spacetime = np.stack([self.get_spacetime(frames[i-1:i+2]) for i in range(peak_vid-self.n_frames//2, 1+peak_vid+self.n_frames//2)])
                    frames_rgb = np.repeat(frames_rgb, self.n_tiles, axis=0).transpose(0, 3, 1, 2)
                    frames_spacetime = np.repeat(frames_spacetime, self.n_tiles, axis=0)                    
                    frames_rgb = self.transform(torch.tensor(frames_rgb))
                    frames_spacetime = self.transform(torch.tensor(frames_spacetime))
                    
                    start_time = peak_time - window_duration/2
                    end_time = peak_time + window_duration/2
                    start_frame = int(start_time * sample_rate)
                    end_frame = int(end_time * sample_rate)
                    peak = wav[start_frame:end_frame]
                    coch = human_cochleagram(peak, sample_rate, n=40, low_lim=100, hi_lim=10000, sample_factor=1, downsample=90, nonlinearity='power')
                            
                    datum['frames_rgb'] = frames_rgb
                    datum['frames_spacetime'] = frames_spacetime
                    datum['og_cochleagram'] = torch.tensor(cochleagrams[ind])
                    datum['cochleagram'] = torch.tensor(coch).transpose(1, 0)
                    datum['material'] = row['Material']
                    self.data.append(datum)
                    
            except:
                pass
    
    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)           
    
    def get_spacetime(self, frames):
        return np.stack([cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) for frame in frames])

    def dump(self, fname):
        with open(fname, 'wb') as f:
            pickle.dump(self.data, f)

In [3]:
ds = VISDataset('../../data/vis-data-256', '../../data/vis-data', 'train_sample.txt')

 20%|██        | 1/5 [00:00<00:03,  1.15it/s]

here


  freqs_to_plot = np.log10(freqs)


torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([45, 42]) torch.Size([45, 42])
here
torch.Size([4

 40%|████      | 2/5 [00:13<00:23,  7.76s/it][mov,mp4,m4a,3gp,3g2,mj2 @ 0x88c5640] moov atom not found
[mov,mp4,m4a,3gp,3g2,mj2 @ 0x892d340] moov atom not found
[mov,mp4,m4a,3gp,3g2,mj2 @ 0x892cf00] moov atom not found
100%|██████████| 5/5 [00:13<00:00,  2.70s/it]

torch.Size([45, 42]) torch.Size([45, 42])





In [26]:
print(len(ds))
print(ds.__getitem__(3).keys())

93
dict_keys(['frames_rgb', 'frames_spacetime', 'og_cochleagram', 'material'])
