-
Notifications
You must be signed in to change notification settings - Fork 152
/
Copy pathcfp.py
89 lines (72 loc) · 2.5 KB
/
cfp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python
# encoding: utf-8
'''
@author: wujiyang
@contact: wujiyang@hust.edu.cn
@file: cfp.py
@time: 2018/12/26 16:19
@desc: the CFP-FP test dataset loader; it is similar to the LFW and AgeDB loaders, except that each fold contains 700 pairs
'''
import numpy as np
import cv2
import os
import torch.utils.data as data
import torch
import torchvision.transforms as transforms
def img_loader(path):
    """Read the image at *path* with OpenCV and return it as a numpy array.

    Two-dimensional (single-channel) results are stacked into three
    identical channels along the last axis.

    Args:
        path: filesystem path of the image file.

    Returns:
        The decoded image as a numpy array, or None (after printing
        ``'Cannot load image <path>'``) if the file cannot be opened or
        OpenCV cannot decode it.
    """
    try:
        # Probe the file first: cv2.imread does not raise on missing or
        # unreadable paths, so this surfaces those cases as an OSError.
        with open(path, 'rb'):
            pass
    except IOError:
        print('Cannot load image ' + path)
        return None
    img = cv2.imread(path)
    if img is None:
        # The file exists but OpenCV failed to decode it; cv2.imread
        # signals this by returning None instead of raising.
        print('Cannot load image ' + path)
        return None
    if len(img.shape) == 2:
        img = np.stack([img] * 3, 2)
    return img
class CFP_FP(data.Dataset):
    """CFP-FP face-verification pair dataset.

    Each non-blank line of ``file_list`` must contain at least three
    whitespace-separated fields, ``<left image> <right image> <flag>``.
    The image names are relative to ``root``, and ``flag`` is parsed as an
    int. The flag is presumably the same/different-identity label; confirm
    this against the pair file. Consecutive runs of ``pairs_per_fold`` pairs
    are assigned to the same fold index, starting at 0.

    Args:
        root: directory that the image names in ``file_list`` are relative to.
        file_list: path to the text file listing the pairs.
        transform: optional callable applied to each of the four images
            returned by ``__getitem__``. When None, images are converted with
            ``torch.from_numpy`` instead.
        loader: callable mapping an image path to a numpy image. It may
            return None to signal failure.
        pairs_per_fold: number of consecutive pairs per fold. The default of
            700 matches the original hard-coded value.

    Raises:
        ValueError: if ``pairs_per_fold`` is not positive, or if a non-blank
            line has fewer than three fields or a non-integer flag.
    """

    def __init__(self, root, file_list, transform=None, loader=img_loader,
                 pairs_per_fold=700):
        if pairs_per_fold <= 0:
            raise ValueError('pairs_per_fold must be positive, got %r' % (pairs_per_fold,))
        self.root = root
        self.file_list = file_list
        self.transform = transform
        self.loader = loader
        # Parallel lists, one entry per pair, in file order.
        self.nameLs = []
        self.nameRs = []
        self.folds = []
        self.flags = []

        with open(file_list) as f:
            lines = f.read().splitlines()

        for line_no, line in enumerate(lines, 1):
            fields = line.split()  # any run of whitespace separates fields
            if not fields:
                # Tolerate blank lines (e.g. trailing whitespace at EOF).
                continue
            if len(fields) < 3:
                raise ValueError('%s:%d: expected "<left> <right> <flag>", got %r'
                                 % (file_list, line_no, line))
            # The fold is derived from the pair's position among accepted
            # pairs, so skipped blank lines do not shift fold boundaries.
            self.folds.append(len(self.nameLs) // pairs_per_fold)
            self.nameLs.append(fields[0])
            self.nameRs.append(fields[1])
            self.flags.append(int(fields[2]))

    def __getitem__(self, index):
        """Return ``[left, flipped left, right, flipped right]`` for pair *index*.

        Each image is a horizontal flip (``cv2.flip`` with flipCode 1) or the
        original as loaded. Every element is passed through ``self.transform``
        if one is set, otherwise converted with ``torch.from_numpy``.

        Raises:
            IOError: if ``self.loader`` returns None for either image.
        """
        img_l = self._load(self.nameLs[index])
        img_r = self._load(self.nameRs[index])
        imglist = [img_l, cv2.flip(img_l, 1), img_r, cv2.flip(img_r, 1)]
        if self.transform is not None:
            return [self.transform(img) for img in imglist]
        return [torch.from_numpy(img) for img in imglist]

    def _load(self, name):
        """Load ``root/name`` via ``self.loader``, raising IOError if it yields None."""
        path = os.path.join(self.root, name)
        img = self.loader(path)
        if img is None:
            # Fail here with the offending path instead of letting
            # cv2.flip(None, 1) raise an unrelated-looking error.
            raise IOError('Cannot load image ' + path)
        return img

    def __len__(self):
        """Number of pairs read from ``file_list``."""
        return len(self.nameLs)
if __name__ == '__main__':
    # Smoke test: iterate the whole CFP-FP pair list once and, for every
    # batch, print the shape of the first sample in each of the four
    # batched image tensors (left, flipped left, right, flipped right).
    root = '/media/sda/CFP-FP/CFP_FP_aligned_112'
    file_list = '/media/sda/CFP-FP/cfp-fp-pair.txt'
    transform = transforms.Compose([
        transforms.ToTensor(),  # range [0, 255] -> [0.0, 1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # range [0.0, 1.0] -> [-1.0, 1.0]
    ])
    dataset = CFP_FP(root, file_list, transform=transform)
    testloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False)
    # The loop variable must not be named `data`: at module scope that would
    # rebind the `torch.utils.data` alias imported at the top of the file.
    for batch in testloader:
        for imgs in batch:
            print(imgs[0].shape)