In [11]:
import glob
import os
import scipy
import torch
import numpy as np
import flow_transforms
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.nn.functional import grid_sample
from imageio import imread
import random
import imageio

def load_flo(path):
    """Read a Middlebury-style ``.flo`` optical-flow file.

    Header: float32 magic 202021.25 (the bytes b'PIEH'), int32 width,
    int32 height; then width * height * 2 float32 values, row-major,
    with the two flow channels interleaved per pixel.

    Returns:
        np.ndarray of dtype float32 and shape (height, width, 2).

    Raises:
        AssertionError: if the magic number is missing or wrong.
        ValueError: if the file holds fewer values than the header promises.
    """
    with open(path, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        assert magic.size == 1 and magic[0] == 202021.25, 'Magic number incorrect. Invalid .flo file'
        # Width is stored first, then height.
        width = int(np.fromfile(f, np.int32, count=1)[0])
        height = int(np.fromfile(f, np.int32, count=1)[0])
        data = np.fromfile(f, np.float32, count=2 * width * height)
    # reshape (not np.resize): a truncated file must fail instead of being
    # silently padded with repeated values.
    return data.reshape(height, width, 2)

def default_loader(root, path_imgs, path_flo, path_occ):
    """Load one sample's images, flow field and occlusion mask.

    All paths are relative to `root`.

    Returns:
        A tuple ``(images, flow, occlusion)``:
        - ``images``: list of float32 arrays, one per entry of `path_imgs`;
        - ``flow``: the result of `load_flo`;
        - ``occlusion``: the raw `imread` result.
    """
    images = []
    for rel_path in path_imgs:
        images.append(imread(os.path.join(root, rel_path)).astype(np.float32))
    flow = load_flo(os.path.join(root, path_flo))
    occlusion = imread(os.path.join(root, path_occ))
    return images, flow, occlusion

In [12]:

class ExtractDataset:
    """Build multi-frame optical-flow samples from a Sintel-style directory tree.

    Layout read under ``root`` (see `get_data`)::

        root/flow/<scene>/<prefix>_<NNNN>.flo
        root/occlusions/<scene>/<prefix>_<NNNN>.png
        root/<data_type>/<scene>/<prefix>_<NNNN>.png

    For each eligible start frame of each scene, `extract_data`:
    - draws a random set of later end frames;
    - tracks every start-frame pixel through the chained per-frame flows;
    - writes ``start.png``, ``end.png``, ``occlusion.png`` and ``flow.flo``
      to ``save_dir/<scene>/<start frame number>/<sub index>/``.
    """

    def __init__(self, root, scene_dir_list, data_type, save_dir="./save"):
        self.root = root
        self.type = data_type  # image sub-directory under root, e.g. 'clean'
        self.scene_dir_list = scene_dir_list
        self.save_dir = save_dir
        # Per-scene lists of [image, flow, occlusion] (single frames).
        self.data = self.get_data(scene_dir_list)
        # Flat list of [[image_t, image_t+1], flow, occlusion] over all scenes.
        self.pair_data = self.get_data(scene_dir_list, True)

    def load_flo(self, path):
        """Read a ``.flo`` file written in the `save_flo` layout.

        Returns:
            float32 array of shape (height, width, 2).

        Raises:
            AssertionError: if the magic number is bad.
            ValueError: if the file is truncated.
        """
        with open(path, 'rb') as f:
            magic = np.fromfile(f, np.float32, count=1)
            assert magic.size == 1 and magic[0] == 202021.25, 'Magic number incorrect. Invalid .flo file'
            # Width is stored first, then height (matches save_flo).
            width = int(np.fromfile(f, np.int32, count=1)[0])
            height = int(np.fromfile(f, np.int32, count=1)[0])
            data = np.fromfile(f, np.float32, count=2 * width * height)
        # reshape (not np.resize) so a short read fails instead of repeating data.
        return data.reshape(height, width, 2)

    def default_loader(self, root, path_imgs, path_flo, path_occ):
        """Load images (float32), flow and occlusion for one sample; paths are relative to `root`."""
        imgs = [os.path.join(root, path) for path in path_imgs]
        flo = os.path.join(root, path_flo)
        occ = os.path.join(root, path_occ)
        return [imread(img).astype(np.float32) for img in imgs], self.load_flo(flo), imread(occ)

    def _load_image(self, rel_path):
        """Read one image located at `rel_path` under ``root`` as a float32 array."""
        return imread(os.path.join(self.root, rel_path)).astype(np.float32)

    def save_flo(self, filename, flow):
        """Write an (H, W, 2) array or CPU tensor as a ``.flo`` file.

        Layout: tag b'PIEH', int32 width, int32 height, then float32 values
        row-major with the two channels interleaved per pixel.

        Raises:
            ValueError: if `flow` is not of shape (H, W, 2).
        """
        data = np.asarray(flow, dtype=np.float32)
        if data.ndim != 3 or data.shape[2] != 2:
            raise ValueError('flow must have shape (H, W, 2), got {}'.format(data.shape))
        height, width = data.shape[:2]
        with open(filename, 'wb') as f:
            f.write(b'PIEH')
            np.array([width, height], dtype=np.int32).tofile(f)
            # Row-major (H, W, 2) already interleaves u and v per pixel.
            np.ascontiguousarray(data).tofile(f)

    def get_data(self, scene_dir_list, is_pair=False):
        """Index the samples of the given scenes; all returned paths are relative to ``root``.

        Args:
            scene_dir_list: scene directory names. A single string is
                treated as one scene.
            is_pair: select the return format (see below).

        Returns:
            If `is_pair` is False: one list per scene of
            ``[image, flow, occlusion]`` entries, kept only when the image exists.

            If `is_pair` is True: one flat list of
            ``[[image_t, image_t+1], flow, occlusion]`` entries over all scenes,
            kept only when both images exist.
        """
        if isinstance(scene_dir_list, str):
            scene_dir_list = [scene_dir_list]
        flow_root = os.path.join(self.root, 'flow')
        whole_file = []
        for scene_dir in scene_dir_list:
            single_dir_file = []
            for flow_path in sorted(glob.glob(os.path.join(flow_root, scene_dir, '*.flo'))):
                rel_flow = os.path.relpath(flow_path, flow_root)
                scene, filename = os.path.split(rel_flow)
                prefix, frame_nb = os.path.splitext(filename)[0].rsplit('_', 1)
                frame_nb = int(frame_nb)

                occ_mask = os.path.join('occlusions', scene, '{}_{:04d}.png'.format(prefix, frame_nb))
                flow_map = os.path.join('flow', rel_flow)

                if not is_pair:
                    img = os.path.join(self.type, scene, '{}_{:04d}.png'.format(prefix, frame_nb))
                    if os.path.isfile(os.path.join(self.root, img)):
                        single_dir_file.append([img, flow_map, occ_mask])
                else:
                    img1 = os.path.join(self.type, scene, '{}_{:04d}.png'.format(prefix, frame_nb))
                    img2 = os.path.join(self.type, scene, '{}_{:04d}.png'.format(prefix, frame_nb + 1))
                    if (os.path.isfile(os.path.join(self.root, img1))
                            and os.path.isfile(os.path.join(self.root, img2))):
                        whole_file.append([[img1, img2], flow_map, occ_mask])

            if not is_pair:
                whole_file.append(single_dir_file)

        return whole_file

    def extract_data(self):
        """Write randomly chosen (start, end) tracking samples for every scene.

        Output directory: ``save_dir/<scene>/<start frame number>/<sub index>/``.

        NOTE(review): ``flow.flo`` stores the tracked absolute (x, y) positions
        returned by `get_flo`, not displacements. Subtract the identity grid
        if standard flow is wanted.
        """
        os.makedirs(self.save_dir, exist_ok=True)
        for scene_dir in self.scene_dir_list:
            pairs = self.get_data([scene_dir], True)
            n_pairs = len(pairs)
            # start <= n_pairs - 5 keeps randint's upper bound (n_pairs - start - 3) >= 2.
            for start in range(n_pairs - 4):
                how_many_pick = random.randint(2, n_pairs - start - 3)
                # End indices are exclusive pair indices in [start + 2, n_pairs].
                ends = self.get_random_num(start, n_pairs, how_many_pick)

                start_path = pairs[start][0][0]
                frame_nb = int(os.path.splitext(os.path.basename(start_path))[0].rsplit('_', 1)[1])
                start_img = self._load_image(start_path).astype(np.uint8)

                # One walk over the flows per start frame serves every selected end.
                for sub_num, (end, grid, survived) in enumerate(self._iter_tracks(start, ends, pairs)):
                    out_dir = os.path.join(self.save_dir, scene_dir, str(frame_nb), str(sub_num))
                    os.makedirs(out_dir, exist_ok=True)
                    end_img = self._load_image(pairs[end - 1][0][1])
                    imageio.imwrite(os.path.join(out_dir, 'start.png'), start_img)
                    imageio.imwrite(os.path.join(out_dir, 'end.png'), end_img.astype(np.uint8))
                    imageio.imwrite(os.path.join(out_dir, 'occlusion.png'), survived.numpy().astype(np.uint8))
                    self.save_flo(os.path.join(out_dir, 'flow.flo'), grid)

    @staticmethod
    def _identity_grid(height, width):
        """Return an (H, W, 2) float32 grid of each pixel's own (x, y) plus an all-zero int32 (H, W) mask."""
        xs = torch.arange(width, dtype=torch.float32).expand(height, width)
        ys = torch.arange(height, dtype=torch.float32).unsqueeze(1).expand(height, width)
        return torch.stack((xs, ys), 2), torch.zeros(height, width, dtype=torch.int32)

    @staticmethod
    def _sample_flow(flow, grid):
        """Bilinearly sample an (H, W, 2) flow field at the sub-pixel (x, y) positions in `grid` (H, W, 2)."""
        height, width = flow.shape[0], flow.shape[1]
        # Pixel coords -> [-1, 1], matching align_corners=True.
        scale = grid.new_tensor([2.0 / max(width - 1, 1), 2.0 / max(height - 1, 1)])
        norm = (grid * scale - 1.0).unsqueeze(0)              # (1, H, W, 2)
        field = flow.permute(2, 0, 1).unsqueeze(0).float()    # (1, 2, H, W)
        sampled = grid_sample(field, norm, mode='bilinear',
                              padding_mode='border', align_corners=True)
        return sampled[0].permute(1, 2, 0)                    # (H, W, 2)

    @classmethod
    def _track_step(cls, grid, survived_mask, flow, occ_mask):
        """Advance tracked positions by one frame.

        The occlusion mask and the flow are both read at the *current*
        tracked positions, not at the original pixel grid. Positions are
        clamped to the image. Returns ``(new_grid, new_mask)``; the inputs
        are not modified.
        """
        height, width = flow.shape[0], flow.shape[1]
        xi = grid[..., 0].round().long()
        yi = grid[..., 1].round().long()
        # Once a non-zero occlusion value is hit, it is kept.
        survived_mask = survived_mask | occ_mask[yi, xi]
        moved = grid + cls._sample_flow(flow, grid)
        new_grid = torch.stack((moved[..., 0].clamp(0, width - 1),
                                moved[..., 1].clamp(0, height - 1)), 2)
        return new_grid, survived_mask

    def _iter_tracks(self, start, ends, train_samples):
        """Track every pixel of sample `start` through consecutive pair flows.

        Walks `train_samples` once, starting at index `start`. Yields
        ``(end, tracked_grid, survived_mask)`` whenever the walk has applied
        pairs ``start .. end-1`` for an exclusive end index in `ends`.

        Raises:
            ValueError: unless ``0 <= start < min(ends)`` and ``max(ends) <= len(train_samples)``.
        """
        ends = sorted(ends)
        if not ends:
            return
        if start < 0 or ends[0] <= start or ends[-1] > len(train_samples):
            raise ValueError('need 0 <= start < end <= {} (got start={}, ends={})'.format(
                len(train_samples), start, ends))
        grid = survived = None
        step = start
        for end in ends:
            while step < end:
                _, flow_path, occ_path = train_samples[step]
                flow = torch.from_numpy(self.load_flo(os.path.join(self.root, flow_path)))
                occ = torch.from_numpy(imread(os.path.join(self.root, occ_path))).int()
                if grid is None:
                    grid, survived = self._identity_grid(flow.shape[0], flow.shape[1])
                grid, survived = self._track_step(grid, survived, flow, occ)
                step += 1
            yield end, grid, survived

    def get_flo(self, start, end, train_samples):
        """Compose the pair flows ``train_samples[start:end]`` for every start-frame pixel.

        `start` and `end` are 0-based indices into `train_samples` (as built
        by ``get_data(..., True)``); `end` is exclusive.

        Returns:
            ``(start_img, end_img, tracked_grid, survived_mask)`` where:
            - ``start_img`` is the first image of pair `start`;
            - ``end_img`` is the second image of pair ``end - 1``;
            - ``tracked_grid`` is (H, W, 2) absolute (x, y) positions;
            - ``survived_mask`` is the OR of occlusion values met along each track.
        """
        (_, grid, survived), = self._iter_tracks(start, [end], train_samples)
        start_img = self._load_image(train_samples[start][0][0])
        end_img = self._load_image(train_samples[end - 1][0][1])
        return start_img, end_img, grid, survived

    def get_random_num(self, start, finish, num):
        """Return `num` sorted distinct integers drawn uniformly from ``[start + 2, finish]``.

        Raises:
            ValueError: if ``start >= finish``, if ``num <= 1``, or if fewer
                than `num` candidates exist.
        """
        if start >= finish:
            raise ValueError('finish number is same or less than start num')
        if num <= 1:
            raise ValueError('The number of data requires bigger than one')
        return sorted(random.sample(range(start + 2, finish + 1), num))
        
        

        

In [13]:
# Scenes to process; frames are read from the 'clean' image sub-directory under 'sintel'.
Test = ExtractDataset(
    root='sintel',
    scene_dir_list=['alley_1', 'ambush_4', 'temple_2'],
    data_type='clean',
)

In [14]:
# Seed the global RNG so the random end-frame selection inside extract_data
# (random.randint / random.sample) is reproducible across runs.
random.seed(0)
Test.extract_data()

KeyboardInterrupt: 