In [1]:
"""
Script defining EvilMouDataSet Class and loaders to be used along with VAE model.
"""
import os, sys
import argparse
import numpy as np
import random
import torch
import time
import glob
from wfield import * #for loading wfield data in nice format
import torch
from torch.utils.data import Dataset, DataLoader
import h5py
from pathlib import Path
import pickle

In [3]:
class EvilMouDataSet(torch.utils.data.Dataset):
    """
    Paired behaviour-video / widefield dataset to be used with the VAE model.

    Sample ``idx`` is the pair ``(video_frame, wfield_frame)``, both min/max
    normalised to [0, 1] with global (whole-dataset) statistics:

    * ``video_frame``  -- ``df_video[:, :, idx]``; time is the LAST axis of the
      ``'cam1'`` HDF5 dataset.
    * ``wfield_frame`` -- ``df_wfield[idx]``; time is the FIRST axis of the
      array returned by ``mmap_dat`` (from ``wfield``).

    Memory note: the video array is read fully into RAM. For large recordings
    consider keeping a list of .h5 files and, in ``__getitem__``, loading only
    the file/slice that contains ``idx``.
    """

    # Class-level default so instances created without __init__ (e.g. old
    # pickles or ``__new__``) still have a window; __init__ overrides it.
    window = 7

    def __init__(self, video_dir, wfield_dir, transform=None,
                 window=7, max_wfield_frames=100):
        """
        Parameters
        ----------
        video_dir : str or Path
            Directory searched recursively for ``frame*.h5`` files. Only the
            first file in sorted path order is loaded; its ``'cam1'`` dataset
            becomes ``self.df_video``.
        wfield_dir : str
            Path passed to ``mmap_dat``.
        transform : optional
            Stored as ``self.transform``. NOTE: it is currently NOT applied in
            ``__getitem__``.
        window : int, default 7
            Number of trailing time points excluded from ``__len__`` (legacy of
            a 7-frame sliding window of which only the first frame is
            returned). The default preserves the original dataset length.
        max_wfield_frames : int or None, default 100
            Keep only the first ``max_wfield_frames`` widefield frames;
            ``None`` keeps all of them.

        Raises
        ------
        FileNotFoundError
            If no ``frame*.h5`` file exists under ``video_dir``.
        """
        # Sort so the "first" file is deterministic; rglob order is filesystem-dependent.
        all_frames = sorted(str(p) for p in Path(video_dir).rglob('frame*.h5'))
        if not all_frames:
            raise FileNotFoundError(f"no 'frame*.h5' files found under {video_dir!r}")

        # Read-only, and closed right away: '[:]' copies the data into memory.
        with h5py.File(all_frames[0], 'r') as f0:
            all_data = f0['cam1'][:]
        # TODO: to use every file, collect each file's 'cam1' array in a list and
        # call np.concatenate(arrays, axis=2) once (not repeatedly in a loop).

        self.df_video   = all_data
        # Reduce directly on the array; flatten() would make a full copy first.
        self.max_video  = all_data.max()
        self.min_video  = all_data.min()
        self.mean_video = all_data.mean()

        # A slice end of None keeps every frame.
        wfield_data      = mmap_dat(wfield_dir)[:max_wfield_frames]
        self.df_wfield   = wfield_data
        self.max_wfield  = np.max(wfield_data)
        self.min_wfield  = np.min(wfield_data)
        self.mean_wfield = np.mean(wfield_data)

        self.transform = transform
        self.window    = window

    def __len__(self):
        """
        Returns number of samples in dset.

        Bounded by the shorter of the two recordings (video frames on axis 2,
        widefield frames on axis 0), minus ``self.window``, and never negative.
        """
        n_frames = min(self.df_wfield.shape[0], self.df_video.shape[2])
        return max(0, int(n_frames) - self.window)

    @staticmethod
    def _minmax_scale(x, lo, hi):
        """Global min/max normalisation of ``x``; all zeros if ``hi == lo``."""
        span = hi - lo
        # Guard constant data: every x equals lo there, so dividing by 1 gives zeros
        # instead of NaN/inf from a division by zero.
        return np.true_divide(x - lo, span if span != 0 else 1)

    def __getitem__(self, idx):
        """
        Returns a single sample ``(video_data, wfield_data)`` from dset.

        Negative indices count from the end. Raises IndexError outside
        ``range(-len(self), len(self))``, so the dataset also terminates
        correctly under Python's legacy ``__getitem__`` iteration protocol.
        """
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")

        # Only the first frame of the old 7-frame window was ever returned, so
        # index it directly instead of normalising 7 frames and dropping 6.
        video_data = torch.from_numpy(
            self._minmax_scale(self.df_video[:, :, idx], self.min_video, self.max_video))
        wfield_data = torch.from_numpy(
            self._minmax_scale(self.df_wfield[idx], self.min_wfield, self.max_wfield))
        return video_data, wfield_data