In [3]:
from __future__ import division
import torch
import random
import numpy as np
import os, glob
import cv2 as cv
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [2]:
def local_normalize_image(img):
    """Locally contrast-normalise an RGB image and return it with three identical channels.

    The image is converted to grayscale, a Gaussian-smoothed copy is
    subtracted, and the residual is divided by the square root of its
    Gaussian-smoothed square. The result is min-max rescaled to [0, 1].

    Args:
        img: Image accepted by ``cv.cvtColor(..., cv.COLOR_RGB2GRAY)``.
            The division by 255 presumes 8-bit intensities -- TODO confirm
            against the data actually stored in the pickles.

    Returns:
        A float32 NumPy array of shape (height, width, 3) whose three
        channels all hold the same normalised grayscale values.
    """
    luminance = cv.cvtColor(img, cv.COLOR_RGB2GRAY).astype(np.float32) / 255.0

    # Remove the smoothed (sigma=2) version of the image from itself.
    centered = luminance - cv.GaussianBlur(luminance, (0, 0), sigmaX=2, sigmaY=2)

    # Smoothed (sigma=20) squared residual -> local magnitude estimate.
    # The small constant keeps the division below away from zero.
    smoothed_square = cv.GaussianBlur(centered * centered, (0, 0), sigmaX=20, sigmaY=20)
    local_scale = cv.pow(smoothed_square, 0.5) + 0.0000001

    normalized = centered / local_scale

    # In-place min-max rescale to the [0, 1] range.
    cv.normalize(normalized, dst=normalized, alpha=0.0, beta=1.0, norm_type=cv.NORM_MINMAX)

    # Replicate the single channel three times along a new last axis.
    return np.stack((normalized, normalized, normalized), axis=2)

In [2]:
class DataLoader(object):
    """Locates pickled multi-view samples and prepares them for training.

    NOTE(review): this class has the same name as
    ``torch.utils.data.DataLoader`` imported at the top of the notebook and
    shadows it once this cell has run. The torch loader is therefore always
    referenced by its fully qualified name inside this class.

    Args:
        dataset_dir: Folder that is searched for ``*.pickle`` files.
        batch_size: Number of samples per batch.
        image_height, image_width: Height and width of a single view.
        num_epochs: Number of passes over the data; a falsy value (0/None)
            is normalised to ``None`` by :meth:`inputs`. It is stored but
            not otherwise used yet.
        num_views: Number of views laid side by side in one image strip.
    """
    def __init__(self, dataset_dir, batch_size, image_height, image_width, num_epochs, num_views):
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.image_height = image_height
        self.image_width = image_width
        self.num_epochs = num_epochs
        self.num_views = num_views

    def inputs(self, is_training=True):
        """Build a shuffled, batched loader over the sample files.

        Args:
            is_training: When true, the per-sample ``loader`` helper applies
                ``data_augmentation2`` to the sample it builds.

        Returns:
            A ``torch.utils.data.DataLoader`` that yields batches (lists) of
            ``*.pickle`` file paths, ``self.batch_size`` paths per batch and
            reshuffled on every pass.

        Raises:
            FileNotFoundError: If ``dataset_dir`` contains no ``*.pickle`` file.
        """
        def loader(example_pkl):
            # Turns one un-pickled example (a mapping with the keys
            # 'image_seq', 'depth_seq' and 'intrinsics') into a sample dict.
            image_seq = example_pkl['image_seq']
            depth_seq = example_pkl['depth_seq']
            intrinsics = example_pkl['intrinsics']

            # local_normalize_image returns a NumPy array, so it has to be
            # converted before torch reshaping can be applied to it. A new
            # list is built so the raw images are never overwritten.
            strip_shape = (self.image_height, self.image_width * self.num_views, 3)
            image_seq_norm = [
                torch.as_tensor(local_normalize_image(image)).reshape(strip_shape)
                for image in image_seq
            ]

            data_dict = {
                'image_seq': image_seq,
                'image_seq_norm': image_seq_norm,
                'depth_seq': depth_seq,
                'intrinsics': intrinsics,
            }

            if is_training:
                # data_augmentation2 is defined at notebook level (not as a
                # method of this class), so it is called as a plain function.
                data_dict = data_augmentation2(self, data_dict, self.image_height, self.image_width)

            return data_dict

        if not self.num_epochs:
            self.num_epochs = None

        # Sorted so that a seeded run shuffles a deterministic file order.
        filenames = sorted(glob.glob(os.path.join(self.dataset_dir, '*.pickle')))
        if not filenames:
            raise FileNotFoundError(
                "no '*.pickle' files found in {!r}".format(self.dataset_dir))

        # shuffle=True already reshuffles on every pass, so no separate
        # pre-shuffle of the file list is needed.
        datasets = torch.utils.data.DataLoader(
            dataset=filenames, batch_size=self.batch_size, shuffle=True)

        # TODO: the loader above only batches file paths. Reading each pickle
        # and passing it through ``loader`` (e.g. via a Dataset wrapper) is
        # still to be implemented.
        return datasets
    
        
        
                
                
            
        

In [9]:
def data_augmentation2(self, data_dict, out_h, out_w, is_training=True, scale_range=(1.0, 1.25)):
    """Randomly rescale and crop a multi-view sample and update its intrinsics.

    Every view of ``image_seq``, ``image_seq_norm`` and ``depth_seq`` is
    resized by one shared random factor and then cropped with one shared
    random offset, so the three stay pixel-aligned. The camera matrix is
    scaled and shifted to match.

    Args:
        self: Object providing ``num_views`` (the function is written to be
            called with the loader instance as first argument).
        data_dict: Mapping with the keys ``'image_seq'``,
            ``'image_seq_norm'``, ``'depth_seq'`` and ``'intrinsics'``.
            Each sequence entry is either one channels-last strip of shape
            (H, W * num_views, C) or a list/tuple of ``num_views`` views of
            shape (H, W, C). ``'intrinsics'`` is a 3x3 camera matrix.
            All views are assumed to share the same H and W -- TODO confirm
            against the pickled data.
        out_h, out_w: Height and width of each cropped view.
        is_training: Currently unused; kept for interface compatibility.
        scale_range: ``(min, max)`` bounds of the uniform zoom factor.
            NOTE(review): the previous draw was U[0, 1), which can yield
            empty images and images smaller than the crop. The default
            range here is a placeholder choice -- confirm the intended one.

    Returns:
        The same ``data_dict`` object, updated in place: the three sequence
        entries are strips of shape (out_h, out_w * num_views, C) and
        ``'intrinsics'`` is a float32 3x3 tensor. Integer inputs come back
        as float32 because resizing is done in floating point.

    Raises:
        ValueError: On an invalid ``scale_range`` or crop size, an input
            layout that cannot be split into ``num_views`` views, or a
            rescaled view that is smaller than the requested crop.
    """
    num_views = int(self.num_views)
    out_h, out_w = int(out_h), int(out_w)
    if out_h < 1 or out_w < 1:
        raise ValueError("out_h and out_w must be positive")
    min_scale, max_scale = (float(bound) for bound in scale_range)
    if not 0.0 < min_scale <= max_scale:
        raise ValueError("scale_range must satisfy 0 < min <= max")

    def make_intrinsics(fx, fy, cx, cy):
        # Assemble [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] from 0-d tensors,
        # keeping every entry in the dtype of fx.
        zero = torch.zeros_like(fx)
        one = torch.ones_like(fx)
        return torch.stack([
            torch.stack([fx, zero, cx]),
            torch.stack([zero, fy, cy]),
            torch.stack([zero, zero, one]),
        ], dim=0)

    def split_views(seq, num_views):
        # Return a list of num_views channels-last (H, W, C) tensors from
        # either a horizontal strip or an explicit list of views.
        if isinstance(seq, (list, tuple)):
            views = [torch.as_tensor(view) for view in seq]
        else:
            strip = torch.as_tensor(seq)
            if strip.dim() != 3 or strip.shape[1] % num_views != 0:
                raise ValueError(
                    "expected a (H, W * num_views, C) strip, got shape {}".format(tuple(strip.shape)))
            views = list(torch.chunk(strip, num_views, dim=1))
        if len(views) != num_views or any(view.dim() != 3 for view in views):
            raise ValueError("expected {} views of shape (H, W, C)".format(num_views))
        return views

    def resize_view(view, new_h, new_w):
        # F.interpolate works on (N, C, H, W), so move channels first, add a
        # batch axis, resize with the default (nearest) mode, and undo both.
        batched = view.permute(2, 0, 1).unsqueeze(0)
        if not batched.is_floating_point():
            batched = batched.float()
        resized = F.interpolate(batched, size=(new_h, new_w))
        return resized.squeeze(0).permute(1, 2, 0)

    def flip_intrinsics(intrinsics, width):
        """Return the camera matrix of a horizontally mirrored image.

        Not applied by the pipeline below yet.
        """
        intrinsics = torch.as_tensor(intrinsics, dtype=torch.float32)
        return make_intrinsics(
            intrinsics[0, 0],
            intrinsics[1, 1],
            width - intrinsics[0, 2],
            intrinsics[1, 2],
        )

    def flip_left_right(image_seq, num_views):
        """Mirror every view of a strip horizontally, keeping the view order.

        Not applied by the pipeline below yet.
        """
        mirrored = [torch.flip(view, dims=[1]) for view in split_views(image_seq, num_views)]
        return torch.cat(mirrored, dim=1)

    def random_scaling(data_dict, num_views):
        # Resize all views of all three sequences by one random factor and
        # report the exact per-axis resize ratios that resulted.
        images = split_views(data_dict['image_seq'], num_views)
        depths = split_views(data_dict['depth_seq'], num_views)
        images_norm = split_views(data_dict['image_seq_norm'], num_views)

        in_h, in_w = images[0].shape[0], images[0].shape[1]
        scale = min_scale + (max_scale - min_scale) * torch.rand(1).item()
        scaled_h = int(in_h * scale)
        scaled_w = int(in_w * scale)
        if scaled_h < out_h or scaled_w < out_w:
            raise ValueError(
                "rescaled view ({}x{}) is smaller than the crop ({}x{})".format(
                    scaled_h, scaled_w, out_h, out_w))

        scaled_images = [resize_view(view, scaled_h, scaled_w) for view in images]
        scaled_depths = [resize_view(view, scaled_h, scaled_w) for view in depths]
        scaled_images_norm = [resize_view(view, scaled_h, scaled_w) for view in images_norm]

        return scaled_images, scaled_depths, scaled_images_norm, scaled_h / in_h, scaled_w / in_w

    def random_cropping(data_dict, scaled_images, scaled_depths, scaled_images_norm,
                        num_views, out_h, out_w, scale_y, scale_x):
        # Cut the same random (out_h, out_w) window out of every view and
        # write the re-assembled strips and the new intrinsics to data_dict.
        in_h, in_w = scaled_images[0].shape[0], scaled_images[0].shape[1]

        # randint's upper bound is exclusive, hence the +1.
        offset_y = torch.randint(0, in_h - out_h + 1, (1,)).item()
        offset_x = torch.randint(0, in_w - out_w + 1, (1,)).item()

        # Resizing scales focal length and principal point; cropping shifts
        # the principal point by the crop offset. The ratios used here equal
        # scaled_size / crop_size whenever the crop size matches the size of
        # the input views.
        intrinsics = torch.as_tensor(data_dict['intrinsics'], dtype=torch.float32)
        fx = intrinsics[0, 0] * scale_x
        fy = intrinsics[1, 1] * scale_y
        cx = intrinsics[0, 2] * scale_x - offset_x
        cy = intrinsics[1, 2] * scale_y - offset_y
        data_dict['intrinsics'] = make_intrinsics(fx, fy, cx, cy)

        def crop_and_join(views):
            window = [
                view[offset_y:offset_y + out_h, offset_x:offset_x + out_w]
                for view in views
            ]
            return torch.cat(window, dim=1)

        data_dict['image_seq'] = crop_and_join(scaled_images)
        data_dict['depth_seq'] = crop_and_join(scaled_depths)
        data_dict['image_seq_norm'] = crop_and_join(scaled_images_norm)

        return data_dict

    scaled_images, scaled_depths, scaled_images_norm, scale_y, scale_x = random_scaling(
        data_dict, num_views)

    data_dict = random_cropping(
        data_dict, scaled_images, scaled_depths, scaled_images_norm,
        num_views, out_h, out_w, scale_y, scale_x)

    return data_dict
        