In [None]:
# TODO: should the bounding boxes be removed?


In [2]:
# from scipy.ndimage.interpolation import zoom 
import numbers, math
import numpy as np
import torch.utils.data as data
import torch
import torchvision.transforms as transforms
from datasets.point_meta import Point_Meta
import os 
import os.path as osp
from PIL import Image
from utils import generate_label_map_laplacian
from utils import generate_label_map_gaussian

In [None]:
def load_txt_file(file_path):
    '''
    Read a text file and return its lines.

    Args:
        file_path: path to the file; normalised with ``osp.normpath`` first.

    Returns:
        (lines, num_lines): the file content split with ``str.splitlines``
        (line terminators removed) and the number of lines.

    Raises:
        AssertionError: if ``file_path`` does not exist.
    '''
    file_path = osp.normpath(file_path)
    assert osp.exists(file_path), 'text file is not existing!'

    # The ``with`` statement closes the handle; no explicit close() is needed.
    with open(file_path, 'r') as file:
        lines = file.read().splitlines()
    return lines, len(lines)
def annot_parser_300W(anno_path, n_pts):
    '''
    Parse a 300W-style landmark annotation file.

    The first 3 lines are skipped as a header (presumably the ``.pts``
    ``version`` / ``n_points`` / ``{`` lines -- confirm against the data).
    Each of the following ``n_pts`` lines must start with two numbers, x and y.

    Args:
        anno_path: path to the annotation file.
        n_pts: number of landmark lines to read.

    Returns:
        points: float32 array of shape (3, n_pts). Row 0 is x, row 1 is y,
            and row 2 is a visibility flag that is always set to 1 here.
        point_set: the set of landmark indices ``{0, ..., n_pts - 1}``.
    '''
    lines = load_txt_file(anno_path)[0]
    points = np.zeros((3, n_pts), dtype='float32')
    offset = 3  # number of header lines before the first landmark
    for idx in range(n_pts):
        # split() with no separator tolerates repeated and trailing whitespace.
        # The previous split(' ') + line.remove(' ') raised ValueError in those
        # cases, because split(' ') never produces a ' ' element.
        tokens = lines[idx + offset].split()
        points[0, idx] = float(tokens[0])
        points[1, idx] = float(tokens[1])
        points[2, idx] = 1.0
    return points, set(range(n_pts))
def pil_loader(path):
    '''
    Open the image stored at ``path`` and return an RGB copy of it.

    The file handle and the source image are both closed by the context
    managers. The returned object is the result of ``convert('RGB')``.
    '''
    with open(path, 'rb') as handle, Image.open(handle) as source:
        rgb_image = source.convert('RGB')
    return rgb_image

In [None]:
class GenDataset(data.Dataset):
    '''
    Landmark dataset whose samples are described by plain-text list files.

    Each sample has these parts (see ``load_list``):
      - an image path
      - a 300W-style annotation path
      - a bounding box
      - an optional face size

    ``__getitem__`` returns the image tensor together with landmark heatmaps
    built by ``generate_label_map_laplacian`` or ``generate_label_map_gaussian``.

    Args:
        transform: ``None`` or a callable ``(image, Point_Meta) -> (image, Point_Meta)``
            applied in ``_process_``.
        sigma: spread parameter forwarded to the heatmap generator.
        downsample: factor by which the heatmaps are smaller than the image.
            The heatmap size is ``height // downsample`` x ``width // downsample``.
        heatmap_type: ``'laplacian'`` or ``'gaussian'``.
        dataset_name: tag stored on every ``Point_Meta``; must not be ``None``.
    '''
    def __init__(self, transform, sigma, downsample, heatmap_type, dataset_name):
        self.transform = transform
        self.sigma = sigma
        self.downsample = downsample
        self.heatmap_type = heatmap_type
        self.dataset_name = dataset_name
        self.reset()
        print('The general dataset initialization done, sigma is {}, downsample is {}, dataset-name : {}, self is : {}'.format(sigma, downsample, dataset_name, self))

    def reset(self, num_pts=68):
        '''Drop every registered sample and set the landmark count per face.'''
        self.length = 0
        self.NUM_PTS = num_pts
        self.datas = []       # image paths
        self.labels = []      # one Point_Meta per image
        self.face_sizes = []  # float or None per image
        assert self.dataset_name is not None, 'The dataset name is None'

    def load_list(self, train_list_file_paths, num_pts):
        '''
        Parse one or more list files and register every entry they describe.

        Each non-blank line holds whitespace-separated fields::

            <image_path> <label_path | None> <x1> <y1> <x2> <y2> [face_size]

        Args:
            train_list_file_paths: a single path (str) or a list of paths.
            num_pts: number of landmarks per face, forwarded to ``load_data``.
        '''
        if isinstance(train_list_file_paths, str):
            train_list_file_paths = [train_list_file_paths]
        datas, labels, boxes, face_sizes = [], [], [], []
        for file_path in train_list_file_paths:
            # ``with`` guarantees the list file is closed even if reading fails.
            with open(file_path) as list_file:
                lines = list_file.read().splitlines()
            print('{} : {} lines'.format(file_path, len(lines)))
            for line_idx, line in enumerate(lines):
                # split() collapses runs of spaces/tabs and ignores leading or
                # trailing whitespace. The previous split(' ') + a single
                # remove('') only dropped the first empty field.
                alls = line.split()
                if not alls:  # tolerate blank lines
                    continue
                assert len(alls) in (6, 7), 'The {:04d}-th line is wrong : {:}'.format(line_idx, line)
                datas.append(alls[0])
                # The literal string 'None' marks an entry without annotation.
                # NOTE: append() currently still rejects such entries.
                labels.append(None if alls[1] == 'None' else alls[1])
                boxes.append(np.array([float(value) for value in alls[2:6]]))
                face_sizes.append(float(alls[6]) if len(alls) == 7 else None)
        self.load_data(datas, labels, boxes, face_sizes, num_pts)

    def load_data(self, datas, labels, boxes, face_sizes, num_pts):
        '''
        Register parallel lists of samples.

        Images are not decoded here. Only the annotation files are parsed,
        via ``append``.

        Args:
            datas: image paths.
            labels: annotation paths.
            boxes: one box array per image.
            face_sizes: one float or ``None`` per image.
            num_pts: must be 68.
        '''
        print('Start load data for the general datas')
        assert isinstance(datas, list), 'The type of the datas is not correct : {}'.format( type(datas) )
        assert isinstance(labels, list) and len(datas) == len(labels), 'The type of the labels is not correct : {}'.format( type(labels) )
        assert isinstance(boxes, list) and len(datas) == len(boxes), 'The type of the boxes is not correct : {}'.format( type(boxes) )
        assert isinstance(face_sizes, list) and len(datas) == len(face_sizes), 'The type of the face_sizes is not correct : {}'.format( type(face_sizes) )
        assert num_pts == 68, 'The number of point is inconsistent : {} vs {}'.format(68, num_pts)

        # All four lists have equal length (asserted above), so zip loses nothing.
        for data_path, label, box, face_size in zip(datas, labels, boxes, face_sizes):
            assert isinstance(data_path, str), 'The type of data is not correct : {}'.format(data_path)
            assert osp.isfile(data_path), '{} is not a file'.format(data_path)
            self.append(data_path, label, box, face_size)

        assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas))
        assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels))
        assert len(self.face_sizes) == self.length, 'The length and the face_sizes is not right {} vs {}'.format(self.length, len(self.face_sizes))
        print('Load data done for the general dataset, which has {} images.'.format(self.length))

    def append(self, data, label, box, face_size):
        '''Parse ``label`` into a Point_Meta and register one sample.'''
        assert osp.isfile(data), 'The image path is not a file {}'.format(data)
        # Checking ``label is not None`` makes an unlabeled entry fail with this
        # assertion instead of a TypeError raised by ``osp.isfile(None)``.
        assert label is not None and osp.isfile(label), 'The label path is not a file {}'.format(label)
        self.datas.append(data)
        np_points, _ = annot_parser_300W(label, self.NUM_PTS)
        meta = Point_Meta(self.NUM_PTS, np_points, box, data, self.dataset_name)
        self.labels.append(meta)
        self.face_sizes.append(face_size)
        self.length += 1

    def __getitem__(self, index):
        '''Load image ``index`` and return the tuple built by ``_process_``.'''
        image = pil_loader(self.datas[index])
        xtarget = self.labels[index].copy()
        return self._process_(image, xtarget, index)

    def __len__(self):
        return self.length

    def _process_(self, image, xtarget, index):
        '''
        Apply the transform and build the training tensors for one sample.

        Returns:
            (image, target, mask, points, torch_index, torch_indicate), where:
              - image: tensor whose dims 1 and 2 are height and width.
              - target: float heatmaps, transposed to channels first.
              - mask: byte mask, transposed to channels first.
              - points: (NUM_PTS, 3) float tensor.
              - torch_index: IntTensor([index]).
              - torch_indicate: ByteTensor holding whether the sample has landmarks.
        '''
        # NOTE(review): visibility is read before the transform and
        # apply_bound, so it does not reflect changes those make to
        # points[2, :] -- confirm this is intended.
        visible = xtarget.points[2, :].astype('bool')
        if self.transform is not None:
            image, xtarget = self.transform(image, xtarget)

        if isinstance(image, Image.Image):
            height, width = image.size[1], image.size[0]
            image = transforms.ToTensor()(image)
        elif isinstance(image, torch.FloatTensor):
            height, width = image.size(1), image.size(2)
        else:
            raise Exception('Unknown type of image : {}'.format(type(image)))

        if not xtarget.is_none():
            xtarget.apply_bound(width, height)
            points = xtarget.points.copy()
            points = torch.from_numpy(points.transpose((1, 0))).type(torch.FloatTensor)
            Hpoint = xtarget.points.copy()
        else:
            points = torch.from_numpy(np.zeros((self.NUM_PTS, 3))).type(torch.FloatTensor)
            # Same (3, NUM_PTS) layout as the labelled branch. This assumes the
            # generators expect a points array; the bare int NUM_PTS was
            # passed here before.
            Hpoint = np.zeros((3, self.NUM_PTS), dtype='float32')

        heatmap_h, heatmap_w = height // self.downsample, width // self.downsample
        if self.heatmap_type == 'laplacian':
            target, mask = generate_label_map_laplacian(Hpoint, heatmap_h, heatmap_w, self.sigma, self.downsample, visible) # H*W*C
        elif self.heatmap_type == 'gaussian':
            target, mask = generate_label_map_gaussian(Hpoint, heatmap_h, heatmap_w, self.sigma, self.downsample, visible) # H*W*C
        else:
            raise Exception('Unknown heatmap type : {}'.format(self.heatmap_type))
        target = torch.from_numpy(target.transpose((2, 0, 1))).type(torch.FloatTensor)
        mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)

        torch_index = torch.IntTensor([index])
        torch_indicate = torch.ByteTensor([not xtarget.is_none()])
        return image, target, mask, points, torch_index, torch_indicate

In [None]:
# Training list file (presumably passed to GenDataset.load_list).
# Set the W300_TRAIN_LIST environment variable to point elsewhere instead of
# editing this machine-specific absolute path.
path = os.environ.get(
    'W300_TRAIN_LIST',
    '/home/abhirup/Datasets/300W-Style/box-coords/300W-Original/300w.train.GTB',
)